In [103]:
import numpy as np
from transformers import PreTrainedTokenizerFast
import torch
import pandas as pd
import random

import sys
import os
sys.path.append(os.path.abspath(os.path.join('..', '..')))
from Functions_generation import generate_a_song_structure, sample_with_temp_topk, load_and_clean, Subword_Models

In [104]:
#Load model
model = Subword_Models(model_type='RNN')
ckpt = load_and_clean("../Models/RNN_model.pt")
model.load_state_dict(ckpt["model_state_dict"])

#Load tokenizer
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="../../../Corpus/Encoding_RNN_LSTM/Subword/rap_tokenizer.json",
    pad_token = "<PAD>")

#Load states transition matrix
matrix = pd.read_csv("../../../Markov/transition_matrix.csv")

states = np.array(matrix.iloc[6])
prob_transi = np.array(matrix.iloc[0:6])

Performance_metric (ppl) :  54.306537798841084


I will give the same structure for RNN and LSTM to compare

"""
struct = generate_a_song_structure(prob_transi.astype(float),states)

struct = ['<BEGINNING>', '<COUPLET>', '<REFRAIN>', '<COUPLET>', '<REFRAIN>', '<END>']
struct = ['α', 'γ', 'ε', 'γ', 'ε', 'θ']

encoded = [tokenizer.encode(c)[0] for c in struct]

decod_structure = {"β" : "<INTRO>",
                   "γ" : "<COUPLET>",
                   "ε" : "<REFRAIN>",
                   "ζ" : "<PONT>",
                   "η" : "<OUTRO>",
                   "θ" : "<END>"}
"""

In [105]:
context = "me voici devant toi\ntu es mort dans le livre\n"
context_ids = tokenizer.encode(context, add_special_tokens=False)
id_not_penalized = [211, 26, 43, 1474, 24, 567]
forbidden_ids = [i for i in range(2,8)]
end_of_gen = [i for i in range(8,13)]+[27]

In [116]:
def generate_text(model, tokenizer, context : str, max_token : int = 200, 
                  temperature : float = 0.8, rep_penalty : float = 1.2, rep_penalty_less : float = 1.05,
                  top_k = 40, seed = 42, forbidden_ids = forbidden_ids, id_not_penalized = id_not_penalized): 

    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)


    model.eval()

    forbidden_ids = set(forbidden_ids or [])

    # encode context
    context_ids = tokenizer.encode(context.lower(), add_special_tokens=False)
    # init hidden
    hid = model.init_hidden(batch_size=1)

    with torch.no_grad():
        last_input = None
        if model.model._get_name() == "Subword_LSTM" : 
            for tid in context_ids:
                inp = torch.tensor([tid], dtype=torch.long)
                out, hid = model(inp, hid)
                #set last_input to last token of context
                last_input = inp
        elif model.model._get_name() == "Subword_RNN" :
            last_input = None
            inp = torch.tensor([context_ids], dtype=torch.long)
            out, hid = model(inp, hid)
            # set last_input to last token of context
            last_input = inp.squeeze(0)[-1].unsqueeze(0).unsqueeze(1)

    generated = []
    token_counts = {}

    for step in range(max_token):
        with torch.no_grad():
            out, hid = model(last_input, hid)
            if out.ndim == 3: #RNN out
                logits = out[0, -1].clone()
            else: #LSTM out
                logits = out.squeeze().clone()

            if forbidden_ids:
                logits[list(forbidden_ids)] = float("-inf") #prob_forbidden = 0

            # scale logits of previously generated tokens
            if rep_penalty > 1.0 and token_counts:
                for tid, count in token_counts.items():
                    # reduce logits progressively
                    if tid not in id_not_penalized:
                        logits[tid] = logits[tid] / (rep_penalty ** count)
                    else :
                        logits[tid] = logits[tid] / (rep_penalty_less ** count)

            # apply temperature
            logits = logits / max(temperature, 1e-8)

            # top-k filtering
            if top_k > 0:
                values, indices = torch.topk(logits, min(top_k, logits.size(-1)))
                kth_value = values[-1]
                mask = logits < kth_value
                logits[mask] = float("-inf")

            probs = torch.softmax(logits, dim=-1)

            next_id = torch.multinomial(probs, num_samples=1).item()

            generated.append(next_id)
            token_counts[next_id] = token_counts.get(next_id, 0) + 1

            # prepare next input
            if model.model._get_name() == "Subword_LSTM" :
                last_input = torch.tensor([next_id], dtype=torch.long)

            elif model.model._get_name() == "Subword_RNN" :
                last_input = torch.tensor([[next_id]], dtype=torch.long)

    text = tokenizer.decode(generated, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    return context+text

In [117]:
print(generate_text(model,tokenizer,context))

me voici devant toi
tu es mort dans le livre
j'ai le quartier, et j'suis tout à la fin saint-o
j'me réveille loin comme un oseille
et nous tu connais ma ville m'attrape avec les hommes de merde
les yeux fermés et des larmes de kaïn

 re
cousin s.i.p oligy l'eserx en cuir dabedar, dans la rame danichie
la plupart du désespoir, y a pas trop d'temps quand on côtoie le meilleur chemin que j'déprime, c'tait bien

t
la vie c'est pas très mal-haut?
il disait: ça change rien
rien n'sert peut être rempli jusquainau week
sur qui tu fais quoi?? tpa pour quasi amateur, on a su l'regardre et tous mes potes au gourou!
deg-ju


In [None]:
def generate_with_repetition_penalty(
    model,
    tokenizer,
    context: str = "",
    max_tokens: int = 200,
    temperature: float = 0.8,
    top_k: Optional[int] = 40,
    rep_penalty: float = 1.2,
    ngram_block: Optional[int] = 3,
    forbidden_ids: Optional[List[int]] = None,
    device: Optional[str] = None,
    seed: int = 0,
):
    """
    Generate tokens step-by-step from a Subword RNN/LSTM model while penalizing repetitions.

    - rep_penalty > 1 reduces probability of tokens already generated (frequency-based).
    - ngram_block = n will block generation of any token that would recreate a previously seen n-gram.
    - forbidden_ids: list of token ids to always ban (set logit = -inf).
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)

    # prepare model + tokenizer
    model.to(device)
    model.eval()

    forbidden_ids = set(forbidden_ids or [])

    # encode context
    context_ids = tokenizer.encode(context, add_special_tokens=False)
    # init hidden
    hid = model.init_hidden(batch_size=1)

    # prime model with context tokens (token-by-token)
    with torch.no_grad():
        last_input = None
        for tid in context_ids:
            inp = torch.tensor([[tid]], dtype=torch.long, device=device)
            out, hid = model(inp, hid)
            # set last_input to last token of context
            last_input = inp

        # if no context, pick a non-pad start token (id != 0)
        if last_input is None:
            start_id = 1 if tokenizer.pad_token_id == 0 else tokenizer.pad_token_id + 1
            last_input = torch.tensor([[start_id]], dtype=torch.long, device=device)

    generated = []
    token_counts = {}
    seen_ngrams = set()  # store tuples of length ngram_block
    if ngram_block and len(context_ids) >= ngram_block:
        # initialize seen_ngrams from context
        for i in range(len(context_ids) - ngram_block + 1):
            seen_ngrams.add(tuple(context_ids[i:i+ngram_block]))

    for step in range(max_tokens):
        with torch.no_grad():
            out, hid = model(last_input, hid)
            # normalize logits retrieval (handle (1,1,V) or (1,V) shapes)
            if out.ndim == 3:
                logits = out[0, -1].clone()
            elif out.ndim == 2:
                logits = out[0].clone()
            else:
                logits = out.squeeze().clone()

            vocab_size = logits.shape[-1]

            # ban forbidden ids (pad / special)
            if forbidden_ids:
                logits[list(forbidden_ids)] = float("-inf")

            # apply n-gram blocking: for current prefix (last n-1 tokens), ban candidates that recreate seen n-gram
            if ngram_block and ngram_block > 1 and len(generated) >= ngram_block - 1:
                prefix = tuple(generated[-(ngram_block-1):]) if ngram_block > 1 else ()
                # for efficiency: build mask over vocab where prefix + token in seen_ngrams
                if prefix:
                    # iterate tokens that would form a seen ngram
                    for tok_id in range(vocab_size):
                        if tuple(prefix + (tok_id,)) in seen_ngrams:
                            logits[tok_id] = float("-inf")

            # frequency-based repetition penalty: scale logits of previously generated tokens
            if rep_penalty is not None and rep_penalty > 1.0 and token_counts:
                for tid, count in token_counts.items():
                    # reduce logits progressively; dividing by rep_penalty**count lowers its prob
                    logits[tid] = logits[tid] / (rep_penalty ** count)

            # apply temperature
            logits = logits / max(temperature, 1e-8)

            # top-k filtering
            if top_k is not None and top_k > 0:
                values, indices = torch.topk(logits, min(top_k, logits.size(-1)))
                kth_value = values[-1]
                mask = logits < kth_value
                logits[mask] = float("-inf")

            probs = torch.softmax(logits, dim=-1)
            # if all -inf -> break
            if torch.isnan(probs).any() and torch.sum(~torch.isfinite(logits)) == logits.numel():
                break

            next_id = torch.multinomial(probs, num_samples=1).item()

        # update bookkeeping
        generated.append(next_id)
        token_counts[next_id] = token_counts.get(next_id, 0) + 1
        # update seen ngrams
        if ngram_block and ngram_block > 1 and len(generated) >= ngram_block:
            seen_ngrams.add(tuple(generated[-ngram_block:]))

        # prepare next input
        last_input = torch.tensor([[next_id]], dtype=torch.long, device=device)

    text = tokenizer.decode(generated, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return text, generated