In [1]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader,TensorDataset
from gensim.models import Word2Vec

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [24]:
class VAE(nn.Module):
    def __init__(self, embedding_matrix, hidden_dim, latent_dim, sos_token, vocab_size, num_layers=1):
        super(VAE, self).__init__()
        self.embedding_dim = embedding_matrix.shape[1]
        self.vocab_size = vocab_size
        self.sos_token = sos_token

        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)
        self.encoder = nn.GRU(self.embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fcmu = nn.Linear(hidden_dim, latent_dim)
        self.fclogvar = nn.Linear(hidden_dim, latent_dim)
        self.hidden_to_latent = nn.Linear(hidden_dim, latent_dim)

        self.decoder= nn.GRU(self.embedding_dim, latent_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(latent_dim, vocab_size)

    def reparametrization(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def forward(self, x):
        embedded_input = self.embedding(x)
        _, hidden = self.encoder(embedded_input)
        mu = self.fcmu(hidden)
        logvar = self.fclogvar(hidden)

        z = self.reparametrization(mu, logvar)
        #z = self.hidden_to_latent(hidden)
        decoder_inputs = torch.cat(self.sos_token, embedded_input, dim = 1)
        output, _ = self.decoder(decoder_inputs, z)
        '''decoder_inputs = embedded_input
        output = []
        for t in range(x.size(1)):
            output_sequence, _ = self.decoder(decoder_inputs[:,t:t+1], z)
            #decoder_inputs = output_sequence
            output.append(output_sequence)

        output = torch.cat(output, dim = 1)'''
        reconstructed_sequence = self.fc(output)
        reconstructed_sequence = torch.softmax(reconstructed_sequence, dim = 2)

        return mu, logvar, reconstructed_sequence

In [30]:
class VAE(nn.Module):
    def __init__(self, embedding_matrix, hidden_dim, latent_dim, sequence_length, num_layers=1):
        super(VAE, self).__init__()
        self.embedding_dim = embedding_matrix.shape[1]

        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)
        self.encoder = nn.GRU(self.embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fcmu = nn.Linear(hidden_dim, latent_dim)
        self.fclogvar = nn.Linear(hidden_dim, latent_dim)
        self.hidden_to_latent = nn.Linear(hidden_dim, latent_dim)

        self.decoder= nn.GRU(self.embedding_dim, latent_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(latent_dim, self.embedding_dim)

    def reparametrization(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def forward(self, x):
        embedded_input = self.embedding(x)
        _, hidden = self.encoder(embedded_input)
        mu = self.fcmu(hidden)
        logvar = self.fclogvar(hidden)

        #z = self.reparametrization(mu, logvar)
        z = self.hidden_to_latent(hidden)
        decoder_inputs = torch.zeros(x.size(0), x.size(1), embedded_input.size(2))
        output, _ = self.decoder(decoder_inputs, z)
        '''decoder_inputs = embedded_input
        output = []
        for t in range(x.size(1)):
            output_sequence, _ = self.decoder(decoder_inputs[:,t:t+1], z)
            #decoder_inputs = output_sequence
            output.append(output_sequence)

        output = torch.cat(output, dim = 1)'''
        reconstructed_sequence = self.fc(output)

        return mu, logvar, reconstructed_sequence, embedded_input
        

# usa questo

In [27]:
class Encoder(nn.Module):
    def __init__(self, embedding_matrix, hidden_dim, latent_dim, num_layers=1):
        super(Encoder, self).__init__()
        embedding_dim = embedding_matrix.shape[1]

        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fcmu = nn.Linear(hidden_dim, latent_dim)
        self.fclogvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        embedded_input = self.embedding(x)
        _, hidden = self.gru(embedded_input)
        mu = self.fcmu(hidden)
        logvar = self.fclogvar(hidden)
        return mu, logvar, hidden, embedded_input
    
    def reparametrization(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return mu + eps*std
        
    
class Decoder(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, latent_dim, sequence_length, num_layers=1):
        super(Decoder, self).__init__()
        #self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, latent_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(latent_dim, embedding_dim)
        self.fc_vocab = nn.Linear(latent_dim, sequence_length)

    def forward(self, x, hidden):
        #x = self.embedding(x)
        output, _ = self.gru(x, hidden)
        output = self.fc(output)
        #output = self.fc_vocab(output)
        #output = output.mean(dim=2)
        #output = torch.sigmoid(output)
        #output = torch.exp(output)
        return output
    
class VAE(nn.Module):
    def __init__(self, embedding_matrix, hidden_dim, latent_dim, sequence_length, num_layers=1):
        super(VAE, self).__init__()
        self.embedding_dim = embedding_matrix.shape[1]
        self.sequence_length = sequence_length
        #self.sos_token = self.sos_token.type(torch.FloatTensor)
        
        self.encoder = Encoder(embedding_matrix, hidden_dim, latent_dim)
        self.decoder = Decoder(self.embedding_dim, hidden_dim, latent_dim, sequence_length)
        self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim)
        self.latent_to_out = nn.Linear(latent_dim, self.embedding_dim)
        self.prova1 = nn.Linear(hidden_dim,latent_dim)
        self.cell = nn.GRUCell(self.embedding_dim, latent_dim)

    def forward(self, x):
        mu, logvar, hidden, embedded_input = self.encoder(x)
        z = self.encoder.reparametrization(mu, logvar)
        # Use <SOS> token for the initial input to the decoder
        #sos_token = torch.FloatTensor([[1]]).repeat(x.size(0), x.size(1), self.embedding_dim).to(x.device)
        decoder_inputs = torch.zeros(x.size(0), x.size(1), embedded_input.size(2))
        output_sequence = self.decoder(decoder_inputs, z)
        
        '''output_sequence = []
        for t in range(x.size(1)):
            outputs = self.decoder(embedded_input[:,t:t+1], z)
            output_sequence.append(outputs)
            
        output_sequence = torch.cat(output_sequence, dim=1)
        
        for t in range(x.size(1)):
            out = self.cell(embedded_input[:,t],z.squeeze(0))
            output_sequence.append(out)

        output_sequence = torch.stack(output_sequence,dim=1)
        output_sequence = self.latent_to_out(output_sequence)
        output_sequence = output_sequence.mean(dim=2)
        output_sequence = torch.sigmoid(output_sequence)'''
        return mu, logvar, output_sequence, embedded_input
    
    def decode(self, z):
        sos_token = self.sos_token.repeat(1, self.sequence_length,1)
        output = self.decoder(sos_token, z)
        return output

In [4]:
input = torch.FloatTensor(size=(32,15,300))

In [6]:
input = input.view(-1,input.size(2))

In [7]:
input.shape

torch.Size([480, 300])

In [229]:
uno = torch.FloatTensor(size=(32,136))
due = torch.FloatTensor(size=(32,136))
tre = torch.stack((uno,due),dim=1)

In [230]:
tre.shape

torch.Size([32, 2, 136])

# questi sotto non usarli!

# Train function

In [8]:
def vae_loss(recon_x, x, mu, log_var, l_s = 0.5, loss_fn = nn.MSELoss(), cos_loss = nn.CosineSimilarity(dim=2), CE = nn.CrossEntropyLoss()):
    #BCE = loss_fn(recon_x, x)
    #BCE = 1 - cos_loss(recon_x,x).mean()
    BCE = CE(recon_x.view(-1, recon_x.size(2)),x.view(-1))
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + l_s*KLD

In [9]:
def training_VAE(vae, train_loader, val_loader, num_epochs, vocab_size, lr = 1e-3):
    params = list(vae.parameters())

    optimizer = torch.optim.Adam(params, lr = lr)

    train_losses = []
    val_losses = []

    for epoch in tqdm(range(num_epochs)):
        train_loss = 0.0
        average_loss = 0.0
        val_loss = 0.0
        average_val_loss = 0.0

        for data,_,_ in train_loader:
            data = data.to(device)
            
            optimizer.zero_grad()

            mu, log_var, reconstructed_data = vae(data)
            
            loss = vae_loss(reconstructed_data, data, mu, log_var)
            loss.backward()
            train_loss += loss.item()

            optimizer.step()

        average_loss = train_loss / len(train_loader.dataset)
        print(f'====> Epoch: {epoch+1} Average loss: {average_loss:.4f}')
        train_losses.append(average_loss)

        with torch.no_grad():
            for data,_,_ in val_loader:
                data = data.to(device)
                

                mu, log_var, reconstructed_data = vae(data)

                loss = vae_loss(reconstructed_data, data, mu, log_var)
                val_loss += loss.item()

        
        average_val_loss = val_loss / len(val_loader.dataset)
        val_losses.append(average_val_loss)

    plt.plot(np.linspace(1,num_epochs,len(train_losses)), train_losses, c = 'darkcyan',label = 'train')
    plt.plot(np.linspace(1,num_epochs,len(val_losses)), val_losses, c = 'orange',label = 'val')
    plt.legend()
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.show()
    return train_losses

In [10]:
def BoW(tensor):
    bow = torch.zeros(size = (tensor.shape[0],tensor.shape[1]))
    #BoW = [(data1[i] == num).sum().item()/data1.shape[1]  for i in range(data1.shape[0]) for num in data1[i] if BoW[i][torch.where(data1[i] == num)[0][0].item()]==0]

    for i in range(tensor.shape[0]):
        for num in tensor[i]:
            index = torch.where(tensor[i] == num)[0][0].item()
            bow[i][index] = (tensor[i] == num).sum().item()/tensor.shape[1]

    return torch.FloatTensor(bow)

In [11]:
def divide_text(text, sequence_length):
    words = text.split()
    #words = text
    grouped_words = [' '.join(words[i:i+sequence_length]) for i in range(0,len(words),int(sequence_length/4))]  # range (0,len(words),8)
    #grouped_words = [' '.join(words[i:i+sequence_length]) for i in range(0,len(words),2)]
    #grouped_words = [words[i] for i in range(0,len(words),19)]
    #grouped_words_2d = [sentence.split() for sentence in grouped_words]
    output_text = [grouped_words[i].split() for i in range(len(grouped_words)) if len(grouped_words[i].split()) == sequence_length]
    return output_text

In [25]:
def custom_dataset(file1 : str,file2 : str, sequence_length, embedding_dim, batch_size, training_fraction):

    with open(file1, 'r', encoding='utf-8') as f:
        text1 = f.read()


    with open(file2, 'r', encoding='utf-8') as f:
        text2 = f.read()

    text1 = '<sos> ' + text1
    text = text1 + ' ' + text2
    divided_text = divide_text(text, sequence_length)

    #word2vec = Word2Vec(divided_text, vector_size = embedding_dim, window = int(sequence_length/2), min_count=1, workers=4)
    word2vec = Word2Vec(divided_text, vector_size = embedding_dim, window = 5, min_count=1, workers=4, epochs=50)
    #word2vec.train(divided_text, total_examples=word2vec.corpus_count, epochs=20)

    # Get the embedding dimension
    embedding_dim = word2vec.wv.vector_size

    # Prepare the embedding matrix
    vocab_size = len(word2vec.wv)
    embedding_matrix = np.zeros((vocab_size, embedding_dim))
    word2idx = {word: idx for idx, word in enumerate(word2vec.wv.index_to_key)}
    idx2word = {idx: word for idx, word in enumerate(word2vec.wv.index_to_key)}

    for word, idx in word2idx.items():
        embedding_matrix[idx] = word2vec.wv[word]

    # Convert to PyTorch tensor
    embedding_matrix = torch.tensor(embedding_matrix, dtype=torch.float32)



    text1_divided = divide_text(text1, sequence_length)
    data1 = torch.LongTensor([[word2idx[char] for char in text1_divided[i]] for i in range(len(text1_divided))])


    text2_divided = divide_text(text2, sequence_length)
    data2 = torch.LongTensor([[word2idx[char] for char in text2_divided[i]] for i in range(len(text2_divided))])


    data1_train = data1[:int(training_fraction * data1.shape[0])]
    data1_val = data1[int(training_fraction * data1.shape[0]):]

    data2_train = data2[:int(training_fraction * data2.shape[0])]
    data2_val = data2[int(training_fraction * data2.shape[0]):]


    label0_train = torch.zeros(data1_train.shape[0])
    label0_val = torch.zeros(data1_val.shape[0])

    label1_train = torch.ones(data2_train.shape[0])
    label1_val = torch.ones(data2_val.shape[0])


    labels_train = torch.cat((label0_train, label1_train), dim = 0)
    labels_val = torch.cat((label0_val, label1_val), dim = 0)

    data_train = torch.cat((data1_train, data2_train), dim = 0)
    data_val = torch.cat((data1_val, data2_val), dim = 0)

    data_train = torch.LongTensor(data_train)
    labels_train = labels_train.type(torch.LongTensor)
    bow_train = BoW(data_train)

    dataset_train = TensorDataset(data_train, bow_train, labels_train)

    # Create a DataLoader with shuffling enabled
    dataloader_train = DataLoader(dataset_train, batch_size = batch_size, shuffle=True)
    #dataloader_train = DataLoader(dataset_train, batch_size = batch_size)


    data_val = torch.LongTensor(data_val)
    labels_val = labels_val.type(torch.LongTensor)
    bow_val = BoW(data_val)

    dataset_val = TensorDataset(data_val, bow_val, labels_val)

    # Create a DataLoader with shuffling enabled
    dataloader_val = DataLoader(dataset_val, batch_size = batch_size, shuffle = True)
    #dataloader_val = DataLoader(dataset_val, batch_size = batch_size)

    return dataloader_train, dataloader_val, embedding_dim, embedding_matrix, word2vec, idx2word, vocab_size

In [26]:
sequence_length = 15
embedding_dim = 300
hidden_dim = 256
latent_dim = 136
batch_size = 32

In [27]:
train_loader, val_loader, embedding_dim, embedding_matrix, word2vec, idx2word, vocab_size = custom_dataset('divina_commedia.txt', 
                                                                                     'divina_commedia.txt', 
                                                                                     sequence_length, 
                                                                                     embedding_dim,
                                                                                     batch_size = batch_size, 
                                                                                     training_fraction = 0.9)
print('len train loader: ', len(train_loader))

len train loader:  1905


In [15]:
vocab_size

12761

In [51]:
sos_token = word2vec.wv['<sos>']
word2idx = {word: idx for idx, word in enumerate(word2vec.wv.index_to_key)}

sos_token = torch.FloatTensor(sos_token)
#sos_token.shape
#idx2word[20249]
sos_index = word2idx['<sos>']

In [53]:
sos_token = torch.full((32,1),word2idx['<sos>'])

In [58]:
sos_token = sos_token.type(torch.LongTensor)

In [44]:
vae = VAE(embedding_matrix, hidden_dim, latent_dim, sequence_length, vocab_size)

In [45]:
vae_params = sum(p.numel() for p in vae.parameters() if p.requires_grad)
print('Total parameters: ', vae_params)

Total parameters:  2460498


In [46]:
losses = training_VAE(vae, train_loader, val_loader, 5, vocab_size)

  0%|          | 0/5 [00:00<?, ?it/s]


TypeError: cat() received an invalid combination of arguments - got (int, Tensor, dim=int), but expected one of:
 * (tuple of Tensors tensors, int dim, *, Tensor out)
 * (tuple of Tensors tensors, name dim, *, Tensor out)


In [44]:
for i ,(data,bow,label) in enumerate(val_loader):
    if i == 0:
        prova = data[0]
        labels = label[0]
        boww = bow[0]

frase = [idx2word[prova[i].item()] for i in range(prova.shape[0])]

prova = prova.view(1,prova.shape[0])

with torch.no_grad():
    mu, log_var, reconstructed, embedded_input = vae(prova)

reconstructed = reconstructed.view(reconstructed.shape[1], reconstructed.shape[2])
ricostruzione = []
for i in range(reconstructed.shape[0]):
    ricostruzione.append((word2vec.wv.most_similar(np.array(reconstructed[i]),topn=1)[0][0]))

if labels.item() == 0.0:
    stile = 'Dante'
else: 
    stile = 'Italiano'

print('Stile: ', stile)
print("Input sequence: \n", ' '.join(frase))
print("\nReconstructed sequence: \n", ' '.join(ricostruzione))

Stile:  Dante
Input sequence: 
 di tal gloria o sodalizio eletto a la gran cena del benedetto agnello il qual

Reconstructed sequence: 
 a la sodalizio la sodalizio sodalizio e l sodalizio del del cena cena che che


In [120]:
z = torch.randn(1,1, hidden_dim)

with torch.no_grad():
    out = vae.decode(z)

print(out.shape)
out = out.view(out.shape[1],out.shape[2])

nuova_frase = []
for i in range(out.shape[0]):
    nuova_frase.append((word2vec.wv.most_similar(np.array(out[i]),topn=1)[0][0]))

print("\nNew sequence: \n", ' '.join(nuova_frase))

torch.Size([1, 15, 300])

New sequence: 
 poc che l poc che poc da e l poc poc poc e vèdeisi e


In [46]:
out.shape

torch.Size([15, 300])

In [35]:
reconstructed.shape

torch.Size([1, 15, 300])

In [16]:
mu.shape

torch.Size([1, 136])

In [19]:
z = torch.randn(1,latent_dim).to(device)

with torch.no_grad():
    out = VAE.sample(prova,z)

ricostruzione = []
for i in range(out.shape[0]):
    ricostruzione.append((word2vec.wv.most_similar(np.array(out[i]),topn=1)[0][0]))

In [20]:
print("\nReconstructed sequence: \n", ' '.join(ricostruzione))


Reconstructed sequence: 
 laggiù all ospedale la nunziata si metteva a piangere anch essa e diceva di no


# prove

In [27]:
class Encoder(nn.Module):
    def __init__(self, input_dim, embed_dim, max_sequence_len, latent_dim):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.bilstm1 = nn.LSTM(embed_dim, 256, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(0.15)
        self.bilstm2 = nn.LSTM(512, 128, bidirectional=True, batch_first=True)  # 512 because 256 * 2 for bidirectional
        self.z_mean = nn.Linear(256, latent_dim)  # 128 * 2 for bidirectional
        self.z_log_var = nn.Linear(256, latent_dim)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.bilstm1(x)
        x = self.dropout(x)
        _, (hn, _) = self.bilstm2(x)
        hn = torch.cat((hn[-2], hn[-1]), dim=1)  # Concatenate the final states of both directions
        z_mean = self.z_mean(hn)
        z_log_var = self.z_log_var(hn)
        return z_mean, z_log_var
    
class Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim, sequence_length):
        super(Decoder, self).__init__()
        self.dense = nn.Linear(latent_dim, 128)
        self.repeat = nn.Linear(128, sequence_length * 64)  # Dense layer followed by reshape
        self.lstm = nn.LSTM(64, 64, batch_first=True)
        self.output_dense = nn.Linear(64, output_dim)
        
    def forward(self, x):
        x = self.dense(x)
        x = torch.relu(x)
        x = self.repeat(x)
        x = x.view(-1, sequence_length, 64)  # Reshape to (batch_size, max_sequence_len, 64)
        x, _ = self.lstm(x)
        x = self.output_dense(x)
        return x

In [28]:
def reparameterize(z_mean, z_log_var):
    std = torch.exp(0.5 * z_log_var)
    epsilon = torch.randn_like(std)
    return z_mean + std * epsilon

In [29]:
class VAE(nn.Module):
    def __init__(self, input_dim, embed_dim, max_sequence_len, latent_dim, output_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, embed_dim, max_sequence_len, latent_dim)
        self.decoder = Decoder(latent_dim, output_dim, max_sequence_len)

    def forward(self, x):
        z_mean, z_log_var = self.encoder(x)
        z = reparameterize(z_mean, z_log_var)
        decoded = self.decoder(z)
        return z_mean, z_log_var, decoded

In [30]:
input_dim = 10000  # Vocabulary size or number of unique tokens
embed_dim = 128    # Dimension of embeddings
max_sequence_len = 30  # Length of sequences
latent_dim = 50    # Size of the latent vector
output_dim = 10000  # Output dimension (same as input_dim for token probabilities)

# Create the model
vae = VAE(input_dim, embed_dim, max_sequence_len, latent_dim, output_dim)

# Example input: batch_size=32, sequence length=30
example_input = torch.randint(0, input_dim, (32, max_sequence_len))

# Forward pass
z_mean, z_log_var, decoded = vae(example_input)

print("z_mean shape:", z_mean.shape)  # Should be (32, latent_dim)
print("decoded shape:", decoded.shape)  # Should be (32, max_sequence_len, output_dim)

z_mean shape: torch.Size([32, 50])
decoded shape: torch.Size([64, 15, 10000])


In [10]:
class Encoder(nn.Module):
    def __init__(self, input_dim, embed_dim, max_sequence_len, latent_dim):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.bilstm1 = nn.LSTM(embed_dim, 256, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(0.15)
        self.bilstm2 = nn.LSTM(512, 128, bidirectional=True, batch_first=True)  # 512 because 256 * 2 for bidirectional
        self.z_mean = nn.Linear(256, latent_dim)  # 128 * 2 for bidirectional
        self.z_log_var = nn.Linear(256, latent_dim)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.bilstm1(x)
        x = self.dropout(x)
        _, (hn, _) = self.bilstm2(x)
        h_n = torch.cat((hn[-2], hn[-1]), dim=1)  # Concatenate the final states of both directions
        z_mean = self.z_mean(h_n)
        z_log_var = self.z_log_var(h_n)
        return z_mean, z_log_var, hn

def reparameterize(z_mean, z_log_var):
    std = torch.exp(0.5 * z_log_var)
    epsilon = torch.randn_like(std)
    return z_mean + std * epsilon

class Decoder(nn.Module):
    def __init__(self, latent_dim, vocab_size, max_sequence_len):
        super(Decoder, self).__init__()
        self.dense = nn.Linear(latent_dim, 128)
        self.repeat_vector = nn.Linear(128, max_sequence_len * 64)
        self.lstm = nn.LSTM(64, 64, batch_first=True)
        self.output_dense = nn.Linear(64, vocab_size)

    def forward(self, x):
        x = self.dense(x)
        x = torch.relu(x)
        x = self.repeat_vector(x)
        x = x.view(-1, max_sequence_len, 64)  # Reshape to (batch_size, max_sequence_len, 64)
        x, _ = self.lstm(x)
        x = self.output_dense(x)
        return x

class VAE(nn.Module):
    def __init__(self, input_dim, embed_dim, max_sequence_len, latent_dim, vocab_size):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, embed_dim, max_sequence_len, latent_dim)
        self.decoder = Decoder(latent_dim, vocab_size, max_sequence_len)

    def forward(self, x):
        z_mean, z_log_var, hn = self.encoder(x)
        z = reparameterize(z_mean, z_log_var)
        decoded = self.decoder(z)
        return z_mean, z_log_var, decoded, hn

In [11]:
input_dim = 10000  # Vocabulary size or number of unique tokens
embed_dim = 128    # Dimension of embeddings
max_sequence_len = 15  # Length of sequences
latent_dim = 50    # Size of the latent vector
vocab_size = 10000  # Output dimension (same as input_dim for token probabilities)

# Create the model
vae = VAE(input_dim, embed_dim, max_sequence_len, latent_dim, vocab_size)

# Example input: batch_size=64, sequence length=15
example_input = torch.randint(0, input_dim, (64, max_sequence_len))

# Forward pass
z_mean, z_log_var, decoded , hn = vae(example_input)

print("z_mean shape:", z_mean.shape)  # Should be (64, latent_dim)
print("decoded shape:", decoded.shape)  # Should be (64, max_sequence_len, vocab_size)

z_mean shape: torch.Size([64, 50])
decoded shape: torch.Size([64, 15, 10000])


In [12]:
example_input.shape

torch.Size([64, 15])

In [13]:
hn.shape

torch.Size([2, 64, 128])

In [14]:
s = torch.cat((hn[-2], hn[-1]), dim=1)

In [15]:
s.shape

torch.Size([64, 256])

In [17]:
hn[-2].shape

torch.Size([64, 128])