In [85]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader,TensorDataset
from gensim.models import Word2Vec

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class RNNVAE(nn.Module):
    ''' Sequence variational autoencoder with (vanilla) RNN encoder and decoder.

    The final hidden state of every encoder layer is mapped to the mean and the
    log-variance of a diagonal Gaussian. A sample z is projected back to
    `hidden_dim` and used as the decoder's initial hidden state. During training
    the decoder is teacher-forced: it reads <sos> followed by the input shifted
    right by one. At inference it feeds back the tokens it generates.

    Inputs
    -------
    embedding_matrix : 2D float tensor [vocab_size, embedding_dim], pretrained word vectors (kept frozen)
    word2idx : dict from words to indices
    idx2word : dict from indices to words
    hidden_dim : int, hidden size of encoder and decoder
    latent_dim : int, size of the latent space
    num_layers : int, number of stacked recurrent layers
    sos_token : LongTensor with a single element (or an int), index of the start-of-sequence token
    vocab_size : int, number of output classes of the final linear layer
    '''

    def __init__(self, embedding_matrix, word2idx, idx2word, hidden_dim, latent_dim, num_layers, sos_token, vocab_size):
        super().__init__()

        self.embedding_dim = embedding_matrix.shape[1]
        self.word2idx = word2idx
        self.idx2word = idx2word
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_layers = num_layers

        # A buffer follows the module through model.to(device); a plain tensor
        # attribute would stay on the CPU. Non-persistent so state_dict keys are unchanged.
        self.register_buffer('sos_token', torch.as_tensor(sos_token, dtype = torch.long).reshape(-1), persistent = False)

        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze = True)
        self.layer_norm = nn.LayerNorm(self.embedding_dim)

        self.encoder = nn.RNN(self.embedding_dim, hidden_dim, num_layers, batch_first = True)
        self.decoder = nn.RNN(self.embedding_dim, hidden_dim, num_layers, batch_first = True)

        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc_hidden = nn.Linear(latent_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size)


    def _embed(self, tokens):
        '''Look up the frozen word vectors of `tokens` and layer-normalise them.'''
        return self.layer_norm(self.embedding(tokens))


    def _sos(self, batch_size):
        '''Embedded <sos> token for every sequence of the batch, shape [batch_size, 1, embedding_dim].'''
        return self._embed(self.sos_token.repeat(batch_size, 1))


    def _encode(self, x):
        '''Encode x [B, T]; return (embedded x, mu, logvar, decoder initial hidden state).'''
        embedded_input = self._embed(x)
        _, hn = self.encoder(embedded_input)

        mu = self.fc_mu(hn)
        logvar = self.fc_logvar(hn)
        z = self.reparametrization(mu, logvar)

        return embedded_input, mu, logvar, self.fc_hidden(z)


    def _generate(self, h, batch_size, length, sample_type):
        '''Autoregressively decode `length` tokens from hidden state h; returns LongTensor [B, length].'''
        if sample_type not in ('argmax', 'multinomial'):
            raise ValueError(f"sample_type must be 'argmax' or 'multinomial', got {sample_type!r}")

        if length <= 0:
            return torch.empty((batch_size, 0), dtype = torch.long, device = h.device)

        step_input = self._sos(batch_size)
        generated = []
        for _ in range(length):
            # carrying h forward is equivalent to re-running the decoder on the whole prefix
            out, h = self.decoder(step_input, h)
            logits = self.fc(out[:, -1, :])

            if sample_type == 'argmax':
                next_token = logits.argmax(dim = -1)
            else:
                next_token = torch.multinomial(F.softmax(logits, dim = -1), 1).squeeze(1)

            generated.append(next_token)
            # the token that is returned is exactly the one fed back to the decoder
            step_input = self._embed(next_token).unsqueeze(1)

        return torch.stack(generated, dim = 1)


    def _tokens_to_text(self, tokens):
        '''Map a 1D tensor of indices to space-separated words.'''
        return ' '.join(self.idx2word[t] for t in tokens.tolist())


    def forward(self, x):
        ''' Teacher-forced forward pass.

        Inputs
        -------
        x : LongTensor [Batch_size, Sequence_length], token indices

        Returns
        -------
        logits : [Batch_size, Sequence_length, vocab_size]
        mu, logvar : [num_layers, Batch_size, latent_dim], parameters of q(z|x)
        '''
        embedded_input, mu, logvar, h0 = self._encode(x)

        # decoder reads <sos> followed by the input shifted right by one position
        decoder_input = torch.cat((self._sos(x.size(0)), embedded_input[:, :-1, :]), dim = 1)

        decoder_output, _ = self.decoder(decoder_input, h0)

        return self.fc(decoder_output), mu, logvar


    def reparametrization(self, mu, log_var):
        ''' Reparametrization trick: draw z ~ N(mu, exp(log_var)) differentiably.

        Inputs
        -------
        mu : torch tensor
        log_var : torch tensor with the same shape as mu

        Returns
        -------
        mu + eps*std : torch tensor with the same shape as mu and log_var, eps ~ N(0, I)'''

        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)

        return mu + eps*std


    def reconstruction(self, x, sample_type = 'multinomial'):
        ''' Reconstruct input sequences autoregressively (inference).

        x is encoded and a latent vector is sampled. The decoder then generates
        x.shape[1] tokens starting from <sos>, each step reading the token chosen
        at the previous step.

        Inputs
        -------
        x : LongTensor [Batch_size, Sequence_length], token indices (moved to the model's device)
        sample_type : 'argmax' (greedy) or 'multinomial' (sample from the softmax), rule for choosing each token

        Returns
        -------
        str : reconstructed sentence(s), space-separated words, one line per sequence of the batch

        Raises
        -------
        ValueError : if sample_type is not 'argmax' or 'multinomial'
        '''
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x = x.to(next(self.parameters()).device)
                _, _, _, h0 = self._encode(x)
                tokens = self._generate(h0, x.size(0), x.size(1), sample_type)
        finally:
            # do not leave the model in eval mode behind the caller's back
            self.train(was_training)

        return '\n'.join(self._tokens_to_text(row) for row in tokens)


    def sample(self, len_sample = 25, sample_type = 'multinomial'):
        ''' Generate a sentence from a latent vector drawn from the N(0, I) prior.

        Inputs
        -------
        len_sample : int, number of tokens to generate
        sample_type : 'argmax' (greedy) or 'multinomial' (sample from the softmax)

        Returns
        -------
        str : generated words separated by spaces

        Raises
        -------
        ValueError : if sample_type is not 'argmax' or 'multinomial'
        '''
        model_device = next(self.parameters()).device
        z = torch.randn((self.num_layers, 1, self.latent_dim), device = model_device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                h0 = self.fc_hidden(z)
                tokens = self._generate(h0, 1, len_sample, sample_type)
        finally:
            self.train(was_training)

        return self._tokens_to_text(tokens[0])


    def number_parameters(self):
        '''Print the number of trainable parameters (the frozen embedding is excluded).'''

        model_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        print('Total number of model parameters: ', model_params)

        return None

In [None]:
class GRUVAE(RNNVAE):
    ''' Same sequence VAE as RNNVAE, but with GRU encoder and decoder.

    RNNVAE.__init__ already builds every shared component: word2idx/idx2word,
    the <sos> token, the frozen embedding, layer norm, latent projections and
    output layer. Only the two recurrent modules are replaced here, instead of
    re-creating (and re-initialising) everything.
    '''
    def __init__(self, embedding_matrix, word2idx, idx2word, hidden_dim, latent_dim, num_layers, sos_token, vocab_size):
        super().__init__(embedding_matrix, word2idx, idx2word, hidden_dim, latent_dim, num_layers, sos_token, vocab_size)

        # nn.GRU returns (output, h_n) like nn.RNN, so the inherited forward /
        # reconstruction / sample work unchanged
        self.encoder = nn.GRU(self.embedding_dim, hidden_dim, num_layers, batch_first = True)
        self.decoder = nn.GRU(self.embedding_dim, hidden_dim, num_layers, batch_first = True)

In [None]:
def divide_text(text, sequence_length):
    ''' Split a text into half-overlapping word windows (used to feed the Word2vec model).

    A window starts every `sequence_length // 2` words (at least every word), so
    consecutive windows share about half of their words. Windows close to the end
    of the text are shorter than `sequence_length`.

    Inputs
    ----------
    text : str, text corpus; words are separated by whitespace
    sequence_length : int >= 1, maximum number of words per window

    Returns
    ----------
    output_text : list of lists of words, one inner list per window

    Raises
    ----------
    ValueError : if sequence_length < 1'''

    if sequence_length < 1:
        raise ValueError(f'sequence_length must be >= 1, got {sequence_length}')

    words = text.split()
    # int(sequence_length/2) is 0 for sequence_length == 1, which makes range() raise
    step = max(1, sequence_length // 2)

    return [words[i:i + sequence_length] for i in range(0, len(words), step)]

In [None]:
def divide_text_equal_seq_length(text, sequence_length):
    ''' Split a text into half-overlapping windows of exactly `sequence_length` words.

    Windows start every `sequence_length // 2` words (at least every word). The
    shorter windows at the end of the text are dropped, so every sample has the
    same length and can be stacked into a tensor.

    Inputs
    ----------
    text : str, text corpus; words are separated by whitespace
    sequence_length : int >= 1, number of words per window

    Returns
    ----------
    output_text : list of lists of words, every inner list of length sequence_length

    Raises
    ----------
    ValueError : if sequence_length < 1'''

    if sequence_length < 1:
        raise ValueError(f'sequence_length must be >= 1, got {sequence_length}')

    words = text.split()
    # int(sequence_length/2) is 0 for sequence_length == 1, which makes range() raise
    step = max(1, sequence_length // 2)
    # a window starting at i is full-length only if i + sequence_length <= len(words)
    last_start = len(words) - sequence_length

    return [words[i:i + sequence_length] for i in range(0, last_start + 1, step)]

In [None]:
def custom_dataset(txt_file : str, sequence_length : int, embedding_dim : int, batch_size : int, training_fraction : float):
    ''' Train Word2vec embeddings on a corpus and build train/validation DataLoaders.

    The corpus is prefixed with a '<sos>' token so that it enters the vocabulary.
    It is cut into half-overlapping word windows, which train a Word2vec model.
    The windows of exactly `sequence_length` words, converted to vocabulary
    indices, form the dataset. The first `training_fraction` of them (in corpus
    order) is used for training and the rest for validation.

    Inputs
    ----------
    txt_file : str, path of the UTF-8 file containing the text corpus
    sequence_length : int, number of words per sample (also the Word2vec context window)
    embedding_dim : int, requested dimension of the Word2vec word vectors
    batch_size : int
    training_fraction : float, fraction of samples used for training

    Returns
    ----------
    dataloader_train : torch.utils.data.DataLoader over the training samples (shuffled)
    dataloader_val : torch.utils.data.DataLoader over the validation samples (shuffled)
    embedding_dim : int, vector size of the trained Word2vec model
    embedding_matrix : 2D float32 torch tensor [vocab_size, embedding_dim]; row i is the vector of idx2word[i]
    word2vec : trained gensim Word2Vec model
    idx2word : dictionary from indices to words
    word2idx : dictionary from words to indices
    vocab_size : int, number of unique tokens

    Raises
    ----------
    ValueError : if the corpus is too short to build a single sample'''

    with open(txt_file, 'r', encoding='utf-8') as f:
        text = f.read()

    # the start-of-sequence token must be part of the Word2vec vocabulary
    text = '<sos> ' + text
    divided_text = divide_text(text, sequence_length)

    # Passing the corpus to the constructor builds the vocabulary AND trains for `epochs`
    # epochs; an extra .train() call would run a second 30-epoch cycle with a restarted
    # learning-rate schedule.
    word2vec = Word2Vec(divided_text, vector_size = embedding_dim, window = sequence_length, min_count=1, workers=4, epochs = 30)

    embedding_dim = word2vec.wv.vector_size
    vocab_size = len(word2vec.wv)

    idx2word = dict(enumerate(word2vec.wv.index_to_key))
    word2idx = {word: idx for idx, word in idx2word.items()}

    # row i of wv.vectors is the vector of wv.index_to_key[i], i.e. of idx2word[i]
    embedding_matrix = torch.tensor(word2vec.wv.vectors, dtype=torch.float32)

    samples = divide_text_equal_seq_length(text, sequence_length)
    if not samples:
        raise ValueError(f'{txt_file!r} contains fewer than {sequence_length} words: no sample can be built')
    dataset = torch.LongTensor([[word2idx[word] for word in sample] for sample in samples])

    # NOTE: consecutive windows overlap by half, so the windows around the split point
    # share words between the training and validation sets
    n_train = int(training_fraction * dataset.shape[0])
    train_data = dataset[:n_train]
    val_data = dataset[n_train:]

    dataloader_train = DataLoader(TensorDataset(train_data), batch_size = batch_size, shuffle = True)
    dataloader_val = DataLoader(TensorDataset(val_data), batch_size = batch_size, shuffle = True)

    return dataloader_train, dataloader_val, embedding_dim, embedding_matrix, word2vec, idx2word, word2idx, vocab_size

In [None]:
def vae_loss(recon_x, x, mu, logvar, l_kl = 0.05, loss_fn = nn.CrossEntropyLoss()):
    ''' VAE objective: token reconstruction loss plus weighted KL divergence.

    Inputs
    ---------
    recon_x : float tensor [Batch size, Sequence length, Vocab size], decoder logits
    x : long tensor [Batch size, Sequence length], target token indices
    mu : float tensor, mean of the approximate posterior q(z|x)
    logvar : float tensor with the same shape as mu, log-variance of q(z|x)
    l_kl : float, weight of the KL term
    loss_fn : loss applied to the flattened logits / targets
              (default cross-entropy, averaged over all tokens)

    Returns
    ---------
    L : 0-dim torch tensor, loss_fn(logits, targets) + l_kl * KLD. KLD is the
        closed-form KL(q(z|x) || N(0, I)) SUMMED over every element of mu / logvar,
        so it is not averaged over the batch like the reconstruction term.'''

    vocab_size = recon_x.size(-1)
    # reshape (not view) also works for non-contiguous logits / targets
    reconstruction = loss_fn(recon_x.reshape(-1, vocab_size), x.reshape(-1))
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return reconstruction + l_kl * kld

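# Sigmoid schedule controlling l_kl
The KL weight at (1-based) epoch $t$ is $l_{kl}(t) = \dfrac{1}{1 + e^{-k\,(t - t_0)}}$. Parameter pairs:
* $k = 0.183$, $t_0 = 20$
* $k = 0.11$, $t_0 = 30$


$k$ is chosen so that training starts from $l_{kl}(1) = 1/(1 + 32.\overline{3}) = 0.03$, i.e. $e^{k\,(t_0 - 1)} = 32.\overline{3}$:
$$k = -\frac{\ln(32.\overline{3})}{1 - t_0} = \frac{\ln(32.\overline{3})}{t_0 - 1}$$
(for $t_0 = 30$ this formula gives $k \approx 0.120$).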

In [None]:
def sigmoid(x, k = 0.183, t0 = 20):
    '''Logistic curve 1 / (1 + exp(-k (x - t0))): slope k, midpoint t0 (value 0.5 at x == t0).'''
    exponent = -k * (x - t0)
    return 1 / (1 + np.exp(exponent))

In [None]:
def _run_epoch(model, loader, l_kl, epoch, optimizer = None):
    '''One pass over `loader`: optimise the model when `optimizer` is given, otherwise only evaluate.

    Returns the loss averaged over the batches of `loader` (0.0 for an empty loader).'''
    is_train = optimizer is not None
    phase = 'Train' if is_train else 'Val'
    model.train(is_train)
    total_loss = 0.0

    # gradients are tracked only during the training pass
    with torch.set_grad_enabled(is_train):
        # TensorDataset batches arrive as 1-element lists: (token_batch,)
        for i, (batch,) in enumerate(loader):
            batch = batch.to(device)

            recon_x, mu, logvar = model(batch)
            loss = vae_loss(recon_x, batch, mu, logvar, l_kl)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()

            if (i + 1) % 5000 == 0:
                print(f'{phase} Epoch: {epoch+1} [{(i + 1) * batch.size(0)}/{len(loader.dataset)} ({100. * (i + 1) / len(loader):.0f}%)]\tLoss: {loss.item():.6f}')

    return total_loss / len(loader) if len(loader) > 0 else 0.0


def training(model, train_loader, val_loader, num_epochs, lr = 4e-4, title = 'Training'):
    ''' Train a sequence VAE with Adam and plot the training / validation loss curves.

    The model is moved to the global `device`. The KL weight of every epoch is
    sigmoid(epoch + 1) (epochs counted from 1). After training the model is
    left in eval mode.

    Input
    --------
    model : RNNVAE / GRUVAE instance returning (logits, mu, logvar)
    train_loader : torch DataLoader yielding 1-element batches (token indices,) for training
    val_loader : torch DataLoader with validation batches in the same format
    num_epochs : int, number of epochs
    lr : float, learning rate for Adam optimizer
    title : str, title of the matplotlib figure

    Returns
    --------
    train_losses : list with the training loss of every epoch, averaged over its batches'''

    # inputs are sent to `device`, so the weights must live there as well
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)

    train_losses = []
    val_losses = []

    for epoch in tqdm(range(num_epochs)):
        # KL annealing: the weight of the KL term grows along a sigmoid
        l_kl = sigmoid(epoch + 1)

        average_loss = _run_epoch(model, train_loader, l_kl, epoch, optimizer)
        average_val_loss = _run_epoch(model, val_loader, l_kl, epoch)

        train_losses.append(average_loss)
        val_losses.append(average_val_loss)

        print(f'====> Epoch: {epoch+1} Average train loss: {average_loss:.4f}, Average val loss: {average_val_loss:.4f}')

    epochs = np.arange(1, num_epochs + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, train_losses, c = 'darkcyan', label = 'train')
    ax.plot(epochs, val_losses, c = 'orange', label = 'val')
    ax.legend()
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    plt.show()

    return train_losses

In [None]:
# words per sample, samples per batch
sequence_length, batch_size = 25, 64
# requested Word2vec vector size, RNN hidden size, VAE latent size
embedding_dim, hidden_dim, latent_dim = 300, 256, 136

In [None]:
# Build Word2vec embeddings + train/validation loaders from the corpus (90% / 10% split)
(train_loader, val_loader, embedding_dim, embedding_matrix,
 word2vec, idx2word, word2idx, vocab_size) = custom_dataset(txt_file = 'divina_commedia.txt',
                                                            sequence_length = sequence_length,
                                                            embedding_dim = embedding_dim,
                                                            batch_size = batch_size,
                                                            training_fraction = 0.9)

for label, loader in (('total number of training samples: ', train_loader),
                      ('total number of validation samples: ', val_loader)):
    print(label, len(loader.dataset))
print('vocab size: ', vocab_size)

In [None]:
# 1-element LongTensor with the index of '<sos>', the decoder's first input token
sos_token = torch.tensor([word2idx['<sos>']], dtype=torch.long)

In [None]:
# vanilla-RNN VAE with 3 stacked layers
vae = RNNVAE(embedding_matrix, word2idx, idx2word,
             hidden_dim = hidden_dim, latent_dim = latent_dim, num_layers = 3,
             sos_token = sos_token, vocab_size = vocab_size)
vae.number_parameters()

In [None]:
# GRU VAE with 3 stacked layers (replaces the RNN VAE built in the previous cell)
vae = GRUVAE(embedding_matrix, word2idx, idx2word,
             hidden_dim = hidden_dim, latent_dim = latent_dim, num_layers = 3,
             sos_token = sos_token, vocab_size = vocab_size)
vae.number_parameters()

In [None]:
# `vae` is whichever model was built last (the GRUVAE when the notebook runs top to bottom),
# so the plot title is derived from its class instead of being hard-coded as "RNN VAE"
train_losses = training(vae, train_loader, val_loader, 8, lr = 4e-4, title = f'{type(vae).__name__} Training')

In [None]:
# first sequence of the first (shuffled) validation batch
first_batch = next(iter(val_loader))
sentence = first_batch[0][0]

input_sentence = [idx2word[token] for token in sentence.tolist()]

# reconstruction expects a batch dimension: [1, Sequence_length]
sentence = sentence.view(1, -1)

reconstructed_sequence = vae.reconstruction(sentence, 'argmax')
reconstructed_sequence2 = vae.reconstruction(sentence)

print("Input sequence: \n", ' '.join(input_sentence))
print("\nReconstructed sequence ARGMAX: \n", reconstructed_sequence)
print("\nReconstructed sequence MULTINOMIAL: \n", reconstructed_sequence2)

In [None]:
# Teacher-forced reconstruction: run the full forward pass (decoder fed the true
# previous words) on the first validation sequence and sample each position from its softmax.
sentence = next(iter(val_loader))[0][0]
input_sentence = [idx2word[token] for token in sentence.tolist()]

# add the batch dimension and put the input on the same device as the model's weights
sentence = sentence.view(1, -1).to(next(vae.parameters()).device)

vae.eval()
with torch.no_grad():
    logits, _, _ = vae(sentence)

probs = F.softmax(logits.squeeze(0), dim=-1)            # [Sequence_length, vocab_size]
sampled_indices = torch.multinomial(probs, 1).squeeze(-1)
reconstructed_sequence = [idx2word[token] for token in sampled_indices.tolist()]

print("Input sequence: \n", ' '.join(input_sentence))
print("\nReconstructed sequence 2: \n", ' '.join(reconstructed_sequence))

In [None]:
# generate a 30-word sentence from a latent vector drawn from the prior
vae.sample(len_sample=30, sample_type='multinomial')