In [2]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader,TensorDataset
from gensim.models import Word2Vec
from models import RNNcVAE, GRUcVAE, LSTMcVAE, CNNClassifier
from config_dataset import custom_dataset
from accuracy import style_accuracy
from training_function import training, train_CNN

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
sequence_length = 30
batch_size = 128
embedding_dim = 300
hidden_dim = 128
latent_dim = 100

In [5]:
train_loader, val_loader, embedding_dim, embedding_matrix, word2vec, idx2word, word2idx, vocab_size = custom_dataset('text_corpus/divina_commedia.txt',
                                                                                                                     'text_corpus/uno_nessuno_e_i_malavoglia.txt',
                                                                                                                     'text_corpus/lo_cunto_de_li_cunti.txt',
                                                                                                                    sequence_length,
                                                                                                                    embedding_dim,
                                                                                                                    batch_size,
                                                                                                                    0.9)

print('total number of training samples: ', len(train_loader.dataset))
print('total number of validation samples: ', len(val_loader.dataset))
print('vocab size: ', vocab_size)

total number of training samples:  45550
total number of validation samples:  5062
vocab size:  26607


In [None]:
sos_token = torch.full((1,),word2idx['<sos>'])
sos_token = sos_token.type(torch.LongTensor)

## Independent CNN Classifier

In [6]:
CNN_classif = CNNClassifier(embedding_matrix, 3, 2, [3,3])

style_params = sum(p.numel() for p in CNN_classif.parameters() if p.requires_grad)
print('Total parameters: ', style_params)

In [12]:
CNN_classif.load_state_dict(torch.load('pretrained/cnn_classifier.pth'))

  CNN_classif.load_state_dict(torch.load('pretrained/cnn_classifier.pth'))


<All keys matched successfully>

In [9]:
wrong = 0.0
wrongy = []
with torch.no_grad():
    for i, (data, labels) in enumerate(val_loader):
        pred_style = torch.argmax(CNN_classif(data), dim=-1)
        label = torch.argmax(labels,dim=-1)
        wrong += torch.count_nonzero(label - pred_style)
        
print('Accuracy :', 1 - (wrong/(len(val_loader)*batch_size)).item(), '%')

Accuracy : 0.982812499627471 %


# RNN cVAE

In [None]:
rnn_cvae = RNNcVAE(embedding_matrix, idx2word, 3, hidden_dim, latent_dim, 2, sos_token, vocab_size)
rnn_cvae.number_parameters()

Total number of model parameters:  3648159


In [None]:
rnn_cvae.load_state_dict(torch.load('pretrained/rnn_cvae.pth'))

### Change style
* Dante
* Italian
* Neapolitan

In [None]:
style = 'Dante'

In [None]:
style_label = {'Dante' : torch.FloatTensor([1,0,0]), 'Italian' : torch.FloatTensor([0,1,0]), 'Neapolitan' : torch.FloatTensor([0,0,1])}

label = style_label[style]
sentence, perplexity = rnn_cvae.sample(label, 40, 0.7)

In [None]:
style_accuracy(rnn_cvae, CNN_classif, name = 'RNN')

# GRU cVAE

In [None]:
gru_cvae = GRUcVAE(embedding_matrix, idx2word, 3, hidden_dim, latent_dim, 2, sos_token, vocab_size)
gru_cvae.number_parameters()

Total number of model parameters:  4000415


In [None]:
gru_cvae.load_state_dict(torch.load('pretrained/gru_cvae.pth'))

  gru_cvae.load_state_dict(torch.load('pretrained/gru_cvae.pth'))


<All keys matched successfully>

### Change style
* Dante
* Italian
* Neapolitan

In [None]:
style = 'Dante'

In [None]:
style_label = {'Dante' : torch.FloatTensor([1,0,0]), 'Italian' : torch.FloatTensor([0,1,0]), 'Neapolitan' : torch.FloatTensor([0,0,1])}

label = style_label[style]
sentence, perplexity = gru_cvae.sample(label, 40, 0.7)

In [None]:
style_accuracy(gru_cvae, CNN_classif, name = 'GRU')

# LSTM cVAE

In [None]:
lstm_cvae = LSTMcVAE(embedding_matrix, idx2word, 3, hidden_dim, latent_dim, 2, sos_token, vocab_size)
lstm_cvae.number_parameters()

Total number of model parameters:  4088479


In [None]:
lstm_cvae.load_state_dict(torch.load('pretrained/lstm_cvae.pth'))

  lstm_cvae.load_state_dict(torch.load('pretrained/lstm_cvae.pth'))


<All keys matched successfully>

### Change style
* Dante
* Italian
* Neapolitan

In [None]:
style = 'Dante'

In [None]:
style_label = {'Dante' : torch.FloatTensor([1,0,0]), 'Italian' : torch.FloatTensor([0,1,0]), 'Neapolitan' : torch.FloatTensor([0,0,1])}

label = style_label[style]
sentence, perplexity = lstm_cvae.sample(label, 40, 0.7)

In [None]:
style_accuracy(lstm_cvae, CNN_classif, name = 'LSTM')