In [2]:
import numpy as np
from transformers import PreTrainedTokenizerFast
import torch
import pandas as pd
import random

import sys
import os
sys.path.append(os.path.abspath(os.path.join('..', '..')))
from Functions_generation import generate_a_song_structure, sample_with_temp_topk, load_and_clean, Subword_Models

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#Load model
model = Subword_Models(model_type='LSTM')
ckpt = load_and_clean("../Models/LSTM_model.pt")
model.load_state_dict(ckpt["model_state_dict"])

#Load tokenizer
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="../../../Corpus/Encoding_RNN_LSTM/Subword/rap_tokenizer.json",
    pad_token = "<PAD>")

#Load states transition matrix
matrix = pd.read_csv("../../../Markov/transition_matrix.csv")

states = np.array(matrix.iloc[6])
prob_transi = np.array(matrix.iloc[0:6])

Performance_metric (ppl) :  50.80905611509434


I will give the same structure for RNN and LSTM to compare

"""
struct = generate_a_song_structure(prob_transi.astype(float),states)

struct = ['<BEGINNING>', '<COUPLET>', '<REFRAIN>', '<COUPLET>', '<REFRAIN>', '<END>']
struct = ['α', 'γ', 'ε', 'γ', 'ε', 'θ']

encoded = [tokenizer.encode(c)[0] for c in struct]

decod_structure = {"β" : "<INTRO>",
                   "γ" : "<COUPLET>",
                   "ε" : "<REFRAIN>",
                   "ζ" : "<PONT>",
                   "η" : "<OUTRO>",
                   "θ" : "<END>"}
"""

In [None]:
context = "Me voici devant toi\nTu es mort dans le livre\n"
context_ids = tokenizer.encode(context, add_special_tokens=False)
id_not_penalized = [211, 26, 43, 1474, 24, 567]
forbidden_ids = [i for i in range(2,8)]
end_of_gen = [i for i in range(8,13)]+[27]

In [None]:
def generate_text(model, tokenizer, context : str, max_token : int = 200, 
                  temperature : float = 0.8, rep_penalty : float = 1.2, rep_penalty_less : float = 1.05,
                  top_k = 40, seed = 42, forbidden_ids = forbidden_ids, id_not_penalized = id_not_penalized): 

    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)


    model.eval()

    forbidden_ids = set(forbidden_ids or [])

    # encode context
    context_ids = tokenizer.encode(context.lower(), add_special_tokens=False)
    # init hidden
    hid = model.init_hidden(batch_size=1)

    with torch.no_grad():
        last_input = None
        if model.model._get_name() == "Subword_LSTM" : 
            for tid in context_ids:
                inp = torch.tensor([tid], dtype=torch.long)
                out, hid = model(inp, hid)
                #set last_input to last token of context
                last_input = inp
        elif model.model._get_name() == "Subword_RNN" :
            last_input = None
            inp = torch.tensor([context_ids], dtype=torch.long)
            out, hid = model(inp, hid)
            # set last_input to last token of context
            last_input = inp.squeeze(0)[-1].unsqueeze(0).unsqueeze(1)

    generated = []
    token_counts = {}

    for step in range(max_token):
        with torch.no_grad():
            out, hid = model(last_input, hid)
            if out.ndim == 3: #RNN out
                logits = out[0, -1].clone()
            else: #LSTM out
                logits = out.squeeze().clone()

            if forbidden_ids:
                logits[list(forbidden_ids)] = float("-inf") #prob_forbidden = 0

            # scale logits of previously generated tokens
            if rep_penalty > 1.0 and token_counts:
                for tid, count in token_counts.items():
                    # reduce logits progressively
                    if tid not in id_not_penalized:
                        logits[tid] = logits[tid] / (rep_penalty ** count)
                    else :
                        logits[tid] = logits[tid] / (rep_penalty_less ** count)

            # apply temperature
            logits = logits / max(temperature, 1e-8)

            # top-k filtering
            if top_k > 0:
                values, indices = torch.topk(logits, min(top_k, logits.size(-1)))
                kth_value = values[-1]
                mask = logits < kth_value
                logits[mask] = float("-inf")

            probs = torch.softmax(logits, dim=-1)

            next_id = torch.multinomial(probs, num_samples=1).item()

            generated.append(next_id)
            token_counts[next_id] = token_counts.get(next_id, 0) + 1

            # prepare next input
            if model.model._get_name() == "Subword_LSTM" :
                last_input = torch.tensor([next_id], dtype=torch.long)

            elif model.model._get_name() == "Subword_RNN" :
                last_input = torch.tensor([[next_id]], dtype=torch.long)

    text = tokenizer.decode(generated, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    return context+text

In [54]:
print(generate_text(model,tokenizer,context, rep_penalty = 1.2))

me voici devant toi
tu es mort dans le livre
j'ai le sang froid et l'soir, j'te dirais pas tout c'monde, yeah
pas de blem', t'as plus peur du ciel
si ça m'demande: il fait notre beau mais est toujours pour-co
hun, je vais pas les plombs
je voudrais qu'tu perceptes la tête avec un bon spliff
quand tu fais mon rêve sans faire un grand frère qui dit quoi?
qu'est-ce tu veux que ça soit bizarre, c'est une clope à la page, ouvre des gros sacrimes secondes
les enfants d'reviendrai jamais seul dans ma tête vide
pour moi ce matin j'veux juste m'dois penser en eux, nananana comme tony sur mes chillers
on s'reste après le bip,  on a aucun troisième pétard et tous mes potes au quartier de serrer ta mère mais ils sont content
