## This notebook generates song lyrics with the trained attention-based Seq2seq model

In [128]:
import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F

from Functions_generation import Decoder_Atten, Encodeur_Atten,generate_a_song_structure

## Pipeline comparison

In [129]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [130]:
# Load the pickled encoding map (token -> integer id).
# Forward slashes keep this relative path valid on Windows, Linux and macOS;
# the previous backslash-separated path only resolved on Windows.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you trust.
with open("../Seq2seq/Files/Encoding_map.pkl", "rb") as f:
    mapping = pickle.load(f)

# Reverse lookup (integer id -> token), used to turn sampled ids back into text.
int2char = {i: ch for ch, i in mapping.items()}
nb_char = len(int2char)

In [131]:
# Transition table consumed by the song-structure generator.
# Forward slashes keep this relative path valid on Windows, Linux and macOS;
# the previous backslash-separated path only resolved on Windows.
matrix = pd.read_csv("../Markov/transition_matrix.csv")

# Row 6 is passed to the generator as the state labels and rows 0-5 as the
# transition values.
# NOTE(review): these row indices are hard-coded -- presumably the CSV always
# has 6 states followed by one label row; confirm if the file is regenerated.
states = np.array(matrix.iloc[6])
prob_transi = np.array(matrix.iloc[0:6])

## Model loading

In [132]:
# Model dimensions; the vocabulary size is the number of entries in the encoding map.
vocab_size = len(mapping)
embedding_size, hidden_size = 96, 512

# Some ids are only meant for the encoder, so the decoder must never produce
# them: they are flagged in a boolean mask that is handed to the decoder
# (originally introduced to mask them from the loss).
# NOTE(review): the 116 / -1 slice bounds depend on the ordering of the
# encoding map -- confirm them whenever the map is rebuilt.
mapping_inverse = dict((idx, token) for token, idx in mapping.items())
masked_mapping = list(mapping_inverse)[116:-1]

mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
mask[masked_mapping] = True

# Build both networks, then move each one to the selected device.
enco = Encodeur_Atten(vocab_size, embedding_size, hidden_size, num_layers=2)
enco = enco.to(device)
deco = Decoder_Atten(vocab_size, embedding_size, hidden_size, mask, num_layers=2)
deco = deco.to(device)

In [133]:
# Load the trained weights. The path used to be a non-raw string containing
# "\M" and "\S", which are invalid escape sequences (SyntaxWarning on recent
# Python versions) and only resolved on Windows; forward slashes avoid both.
# The checkpoint is first mapped to the CPU; load_state_dict then copies the
# tensors into the parameters of the already-placed models.
ckpt = torch.load("../Models/Seq2seq_att_model.pt", map_location="cpu")

enco.load_state_dict(ckpt["encoder_state_dict"])
deco.load_state_dict(ckpt["decoder_state_dict"])

# Switch both networks to evaluation mode before generating.
enco.eval()
deco.eval()

Decoder_Atten(
  (embed): Embedding(223, 96)
  (lstm): LSTM(608, 512, num_layers=2, batch_first=True, dropout=0.2)
  (final): Linear(in_features=512, out_features=223, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

## Preparation before generating 

We need to create the first context

In [134]:
# Draw a song structure from the transition table.
struct = generate_a_song_structure(prob_transi.astype(float), states)

# NOTE(review): the sampled structure is replaced right away by this fixed one,
# so the call above currently has no effect on the generated song -- confirm
# whether the override is a leftover from debugging.
struct = [
    '<BEGINNING>',
    '<COUPLET>',
    '<REFRAIN>',
    '<COUPLET>',
    '<REFRAIN>',
    '<END>',
]
# Per-part line budget (used as `<total=...>` in the context): one entry for
# each tag between the first and the last one.
len_struct = [4, 3, 4, 3]

# Character-by-character encoding of the first part tag, shaped as a batch of one.
# NOTE(review): `encoded` is not used by the generation loop visible in this notebook.
first_part_ids = [mapping[c] for c in struct[1]]
encoded = torch.tensor([first_part_ids], dtype=torch.long)

['<BEGINNING>', '<COUPLET>', '<REFRAIN>', '<COUPLET>', '<REFRAIN>', '<OUTRO>', '<END>']


First, I need to build the context that is fed to the encoder.

In [135]:
# Sampling settings, defined once instead of on every decoding step.
temperature = 0.5
top_k = 20
# Safety cap on decoding steps so a model that never emits "END" cannot loop forever.
max_decode_steps = 1000


def _encode_context(context_parts):
    """Turn context pieces into a flat list of vocabulary ids.

    A piece that is itself a key of ``mapping`` becomes a single id; any other
    piece is encoded character by character.

    Raises:
        KeyError: if a character of a non-token piece is missing from ``mapping``.
    """
    ids = []
    for part in context_parts:
        if part in mapping:
            ids.append(mapping[part])
        else:
            ids.extend(mapping[ch] for ch in part)
    return ids


def _sample_top_k(logits, temperature, top_k):
    """Sample one id per row among the ``top_k`` most probable entries.

    Args:
        logits: tensor whose last dimension indexes the vocabulary; entries set
            to ``-inf`` get probability zero.
        temperature: divisor applied to the logits before the softmax.
        top_k: number of candidates kept (capped at the vocabulary size).

    Returns:
        Tensor of sampled vocabulary ids with a trailing dimension of size 1.
    """
    probs = torch.softmax(logits / temperature, dim=-1)
    k = min(top_k, probs.size(-1))
    topk_probs, topk_idx = torch.topk(probs, k, dim=-1)
    # multinomial treats the kept probabilities as unnormalised weights.
    choice = torch.multinomial(topk_probs, 1)
    return topk_idx.gather(-1, choice)


final_result = []
for i, part_tag in enumerate(struct[1:-1]):

    # "<COUPLET>" -> "<PART=COUPLET>"
    context_part = f"<PART={part_tag[1:-1]}>"
    lines_tot = len_struct[i]
    final_result.extend(["\n", part_tag, "\n\n"])

    for j in range(1, lines_tot + 1):
        # Previous lines: placeholders for the first pass of a part, otherwise
        # the text generated by the previous pass is fed back into the context.
        if j == 1:
            cont_previous = ['Previous :', "<START>", 'Previous :', "<START>"]
        else:
            previous_lines = result.split("<EOL>")
            # The previous pass may have produced fewer than two segments; fall
            # back to the placeholder instead of raising IndexError.
            # NOTE(review): assumes "<START>" is an acceptable stand-in for a
            # missing previous line -- confirm against the training data.
            previous_lines += ["<START>"] * (2 - len(previous_lines))
            cont_previous[1] = previous_lines[0]
            cont_previous[3] = previous_lines[1]

        # Context = song part, line counters, previous lines.
        context = [context_part, f"<lines={j}>", f"<total={lines_tot}>", *cont_previous]

        # Context encoding
        new_context = _encode_context(context)
        length_cont = len(new_context)

        # Whole generation
        with torch.no_grad():
            # Token ids must live on the same device as the models; the length
            # tensor is left on the CPU exactly as before.
            context_ids = torch.tensor([new_context], dtype=torch.long, device=device)
            encod_out, length_out, h, c = enco(context_ids, torch.tensor([length_cont]))

            # True for real context positions, False for padding.
            arange = torch.arange(context_ids.shape[1], device=length_out.device).unsqueeze(0)
            mask_attn = (arange < length_out.unsqueeze(1)).to(device)

            input_t = torch.tensor([mapping["START"]], device=device)
            outputs = ["START"]

            for _ in range(max_decode_steps):
                embedded = deco.dropout(deco.embed(input_t))

                # Dot-product attention over the encoder outputs, using the last
                # entry of h as the query.
                # NOTE(review): assumes h is (num_layers, batch, hidden) -- confirm in Encodeur_Atten.
                query = h[-1]
                scores = torch.bmm(encod_out, query.unsqueeze(2)).squeeze(2)
                scores = scores.masked_fill(~mask_attn, float('-inf'))

                weights = F.softmax(scores, dim=-1)
                h_prime = torch.bmm(weights.unsqueeze(1), encod_out).squeeze(1)

                lstm_in = torch.cat([embedded, h_prime], dim=-1).unsqueeze(1)
                out, (h, c) = deco.lstm(lstm_in, (h, c))

                logit = deco.final(out.squeeze(1))  # (B, vocab_size)
                # Ids flagged in the decoder mask can never be sampled.
                masked_logit = logit.masked_fill(deco.mask, float("-inf"))

                next_token = _sample_top_k(masked_logit, temperature, top_k)
                outputs.append(int2char[next_token.item()])
                if outputs[-1] == "END":
                    break

                input_t = next_token.squeeze(1)

        # Drop the leading "START" and, when present, the trailing "END".
        generated = outputs[1:-1] if outputs[-1] == "END" else outputs[1:]
        result = "".join(generated)

        # Keep only plain text lines; anything still containing "<" is skipped.
        for line in result.split("<EOL>"):
            if "<" in line:
                continue
            final_result.extend([line, "\n"])

final_result.extend(["\n", "<END>"])

In [136]:
# Assemble the generated pieces into one string and show it with real line
# breaks (print is used so that "\n" is rendered rather than escaped).
song_text = "".join(final_result)
print(song_text)


<COUPLET>

J'suis dans la soupe, j'ai fait l'fil de mon cœur
Y'a pas d'rester calmer, en vrai j'suis pas d'jouer le soir d'espoir de l'argent du mal à sang
Depuis le casse-coupe de mes potes et partir
J'ai fait la politique, j'me sens changer dans les cartes d'eau en côté d'excitation
J'ai vu des meufs qui m'ramènent pas d'foot avant qu'j'étais seul
J'ai des frères, j'ai mis des extrêmes de manière sous l'issue
Nique les fantômes comme deux ans d'faire des doigts
Qu'est comme un peu d'bain de la rue, on s'en traite à ceux qu'ont des formes

<REFRAIN>

Son of a bitch
Son of, son of a bitch
Définition d'un OG
Définition d'un OG
Dojo F pour le mal-être, j'ai cette chance
Mais j'ai b'soin d'mon lotissement

<COUPLET>

J'aimerais bien faire le corps de toute façon, j'incarne un seul comme la foi du plus d'entraire
J'ai l'âme et l'argent de propre avec des marques
J'ai pas de fermer la chatte, c'est pas d'chicha, la partie d'une vie en bataille
J'me sens comme un riz et les sous-bizarres
J'