In [2]:
import numpy as np
import pandas as pd
import pickle

import torch
import torch.nn as nn

from Functions_generation import Decodeur_no_att, Encodeur_no_att,generate_a_song_structure

# Preparation before Generating

In [3]:
# Load the vocabulary (token -> integer id) that was saved as a pickle.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(r"..\Seq2seq\Files\Encoding_map.pkl", "rb") as pkl_file:
    mapping = pickle.load(pkl_file)

# Reverse lookup (integer id -> token), used to turn sampled ids back into text.
int2char = dict(zip(mapping.values(), mapping.keys()))
nb_char = len(int2char)

In [4]:
# Read the Markov transition table produced elsewhere in the project.
transition_csv = r"..\Markov\transition_matrix.csv"
matrix = pd.read_csv(transition_csv)

# The first six rows are used as the transition probabilities; the row at
# position 6 is used as the list of state labels.
prob_transi = np.array(matrix.iloc[:6])
states = np.array(matrix.iloc[6])

# Model loading

In [5]:
# Architecture hyper-parameters; they have to match the checkpoint loaded below.
embedding_size = 96
hidden_size = 512
num_epoch = 200
vocab_size = len(mapping)

# Some vocabulary entries are only used by the encoder, so the decoder must
# never produce them: they are flagged in a boolean mask over the vocabulary.
# NOTE(review): the slice [116:-1] relies on the insertion order of `mapping`
# and leaves out the very last id - confirm this is intended.
mapping_inverse = dict(zip(mapping.values(), mapping.keys()))
masked_mapping = list(mapping_inverse)[116:-1]

mask = torch.zeros(vocab_size, dtype=torch.bool)
mask[masked_mapping] = True

# Encoder and decoder are built with the same sizes; the decoder also gets the mask.
enco = Encodeur_no_att(vocab_size, embedding_size, hidden_size, num_layers=2)
deco = Decodeur_no_att(vocab_size, embedding_size, hidden_size, mask, num_layers=2)

In [6]:
# Load the trained weights on CPU.
# The path is a raw string: "\M" and "\S" are not valid escape sequences, so a
# plain string literal triggers an invalid-escape warning (and is slated to
# become an error in a future Python version). The resulting path is unchanged.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
ckpt = torch.load(r"..\Models\Seq2seq_noatt_model.pt", map_location="cpu")

# Restore the encoder and decoder parameters from the checkpoint.
enco.load_state_dict(ckpt["encoder_state_dict"])
deco.load_state_dict(ckpt["decoder_state_dict"])

# Switch both modules to evaluation mode (this disables dropout) for generation.
enco.eval()
deco.eval()

Decodeur_no_att(
  (embed): Embedding(223, 96)
  (lstm): LSTM(96, 512, num_layers=2, batch_first=True, dropout=0.2)
  (final): Linear(in_features=512, out_features=223, bias=True)
)

## Preparation before generating 

We first need to build the initial context that will be fed to the encoder.

In [7]:
# Draw a song structure from the Markov chain (probabilities cast to float first).
transition_probs = prob_transi.astype(float)
struct = generate_a_song_structure(transition_probs, states)

# Look up the id of every element of struct[1] and wrap the ids in one extra
# leading dimension, giving a tensor of shape (1, len(struct[1])).
# NOTE(review): what struct[1] holds depends on generate_a_song_structure - confirm.
struct_ids = [mapping[c] for c in struct[1]]
encoded = torch.tensor([struct_ids], dtype=torch.long)

['<BEGINNING>', '<COUPLET>', '<REFRAIN>', '<COUPLET>', '<REFRAIN>', '<COUPLET>', '<END>']


As a first step, I will generate the song lyrics while specifying the length of the part in the context.

In [8]:
# Initial context given to the encoder: part type, line counters and two
# "Previous :" slots that only hold the start marker.
first_context = ["<PART=COUPLET>", '<lines=1>', '<total=4>', 'Previous :', '<START>', 'Previous :', '<START>']

# Encode the context as a flat list of vocabulary ids.
context_encode = []

for token in first_context:
    # Test membership, not truthiness of the id: `mapping[token]` raises
    # KeyError for an unknown token and is falsy for a token whose id is 0.
    if token in mapping:
        context_encode.append(mapping[token])
    else:
        # Token is not a single vocabulary entry: fall back to one id per character.
        context_encode.extend(mapping[char] for char in token)

length_cont = len(context_encode)

In [9]:
# Encode the context as a batch of one sequence, together with its length.
# This is inference only, so autograd tracking is turned off: no graph is
# built and the returned states do not require gradients.
with torch.no_grad():
    h, c = enco(torch.tensor([context_encode]), torch.tensor([length_cont]))

# The decoder is primed with the "START" token; `outputs` collects the
# generated tokens as strings and begins with that same marker.
input_t = torch.tensor([mapping["START"]])
outputs = ["START"]

In [10]:
# Sampling settings for the decoding loop.
temperature = 0.5   # logits are divided by this value before the softmax
top_k = 20          # sample only among the k most probable tokens
max_tokens = 100    # hard stop if the model never emits "END"

# The decoder starts from the encoder's final state and must then carry its
# OWN state from one step to the next. Feeding (h, c) again at every step
# would make each prediction depend only on the previous token.
h_dec, c_dec = h, c

with torch.no_grad():
    while outputs[-1] != "END":
        # One decoding step: add a length-1 sequence dimension (the LSTM is batch_first).
        emb = deco.embed(input_t).unsqueeze(1)
        out, (h_dec, c_dec) = deco.lstm(emb, (h_dec, c_dec))

        # Tokens flagged in deco.mask get -inf logits so they can never be sampled.
        logit = deco.final(out.squeeze(1))
        masked_logit = logit.masked_fill(deco.mask, float("-inf"))

        # Temperature-scaled softmax followed by top-k sampling
        # (torch.multinomial accepts unnormalised weights).
        probs = torch.softmax(masked_logit / temperature, dim=-1)
        topk_probs, topk_idx = torch.topk(probs, top_k, dim=-1)

        idx = torch.multinomial(topk_probs, 1)
        next_token = topk_idx.gather(-1, idx)
        outputs.append(int2char[next_token.item()])

        # The sampled token becomes the next decoder input.
        input_t = next_token.squeeze(1)

        if len(outputs) > max_tokens:
            break

# Drop the leading "START" marker, and the trailing "END" only when it is
# actually there: if the loop stopped on the length cap, the last token is
# real text and must be kept.
generated = outputs[1:]
if generated and generated[-1] == "END":
    generated = generated[:-1]

"".join(generated)

'OUPETITINMOUPAMPLDJPETNMUPETNMETAQTNONMJNMJEOTJCCSLEJJBMNONMNBJBMETLBMENNMEETNCLOTLNMYAJNONMNMYETNM'