In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt

In [None]:
# Select the compute device once: the first CUDA GPU when one is visible to
# PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device)

In [None]:
class CustomDataset(Dataset):
    """Line-aligned pairs read from ``<path>.sources`` and ``<path>.targets``.

    Attributes:
        sources: Raw lines of the sources file, exactly as returned by ``readlines()``.
        targets: Raw lines of the targets file, exactly as returned by ``readlines()``.
        vocab_src: Set of characters seen in the source lines after the first
            and last character of every line have been dropped.
        vocab_tgt: Set of characters seen in the unmodified target lines (so a
            trailing newline, when present, is part of this vocabulary).
    """

    def __init__(self, path):
        """Load both files and collect the per-side character vocabularies.

        Args:
            path: File prefix; ``.sources`` and ``.targets`` are appended to it.
        """
        with open(f"{path}.sources") as source_file:
            self.sources = source_file.readlines()
        # Same slice as __getitem__: the first and last character of a source
        # line are discarded (presumably a leading marker and the newline --
        # TODO confirm against the data files).
        self.vocab_src = {symbol for line in self.sources for symbol in line[1:-1]}

        with open(f"{path}.targets") as target_file:
            self.targets = target_file.readlines()
        # Target lines are used verbatim, without any trimming.
        self.vocab_tgt = {symbol for line in self.targets for symbol in line}

    def __len__(self):
        """Return the number of source lines."""
        return len(self.sources)

    def __getitem__(self, idx):
        """Return ``(trimmed source line, untouched target line)`` for ``idx``."""
        source_line = self.sources[idx]
        target_line = self.targets[idx]
        return source_line[1:-1], target_line
        

In [None]:
def _build_vocab(symbols, special_tokens):
    """Assign dense integer ids to ``symbols`` and then to ``special_tokens``.

    Symbols are sorted before numbering so the mapping is identical on every
    run; iterating a ``set`` of strings directly yields an order that can
    change between interpreter sessions, which would silently invalidate any
    saved model weights.

    Args:
        symbols: Iterable of vocabulary symbols (single characters here).
        special_tokens: Marker tokens appended after the symbols, in order.
            They are expected not to clash with any symbol.

    Returns:
        dict mapping every symbol and special token to its integer id.
    """
    vocab = {symbol: index for index, symbol in enumerate(sorted(symbols))}
    for token in special_tokens:
        # Each marker takes the next free id instead of a hard-coded number,
        # so ids stay contiguous whatever the number of symbols is.
        vocab[token] = len(vocab)
    return vocab


train_dataset = CustomDataset('./Data/A3 files/train')
eval_dataset = CustomDataset('./Data/A3 files/dev')
test_dataset = CustomDataset('./Data/A3 files/test')

# The vocabularies cover all three splits so that every character met at
# evaluation or test time has an id.
# With 84 source symbols and 44 target symbols the markers receive the ids
# that used to be hard-coded (END=84, PAD=85 / STR=44, END=45, PAD=46).
vocab_src = _build_vocab(
    train_dataset.vocab_src | eval_dataset.vocab_src | test_dataset.vocab_src,
    ("END", "PAD"),
)
vocab_tgt = _build_vocab(
    train_dataset.vocab_tgt | eval_dataset.vocab_tgt | test_dataset.vocab_tgt,
    ("STR", "END", "PAD"),
)

In [None]:
# Only the training split is reshuffled each epoch. The evaluation and test
# splits are read in file order so that their batches (and any predictions
# written out from them) are reproducible and line up with the input files.
train_loader = DataLoader(train_dataset, batch_size = 32, shuffle = True)
eval_loader = DataLoader(eval_dataset, batch_size = 32, shuffle = False)
test_loader = DataLoader(test_dataset, batch_size = 32, shuffle = False)

In [None]:
class Encoder(nn.Module):
    """Character-level bidirectional LSTM encoder.

    Raw source strings are converted to ids with the module-level
    ``vocab_src`` mapping, terminated with its ``"END"`` id, right-padded with
    its ``"PAD"`` id to ``max_src`` positions, embedded and fed through a
    stacked bidirectional LSTM.

    Args:
        dims: Embedding size, which is also the LSTM input size.
        hidden_size: Number of hidden units per LSTM direction.
        num_layers: Number of stacked LSTM layers.
        max_src: Fixed padded source length, END marker included.
        max_tgt: Stored on the instance; not used by the encoder itself.
    """

    def __init__(self, dims = 512, hidden_size = 512,num_layers = 2, max_src = 500, max_tgt = 500):
        super().__init__()
        self.embedding = nn.Embedding(len(vocab_src),dims)
        self.input_size = dims
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.max_src = max_src
        self.max_tgt = max_tgt
        self.encoder = nn.LSTM(
            dims, hidden_size, num_layers, batch_first = True,bidirectional=True, dropout = 0.5
        )

    def encode_inp(self, x):
        """Turn a batch of source strings into a padded id matrix.

        Args:
            x: Sequence of strings whose characters are all keys of ``vocab_src``.

        Returns:
            ``int64`` tensor of shape ``(len(x), max_src)`` located on the same
            device as the embedding weights. Each row holds the character ids,
            then the END id, then PAD ids.

        Raises:
            KeyError: If a character is missing from ``vocab_src``.
            IndexError: If a string plus its END marker exceeds ``max_src``.
        """
        device = self.embedding.weight.device
        end_id = vocab_src["END"]
        # Take the PAD id from the vocabulary instead of a literal so the
        # padding value always is a valid row of the embedding table.
        encoded_x = torch.full(
            (len(x), self.max_src), vocab_src["PAD"], dtype=torch.long, device=device
        )
        for row, sequence in enumerate(x):
            ids = [vocab_src[symbol] for symbol in sequence]
            ids.append(end_id)
            if len(ids) > self.max_src:
                raise IndexError(
                    f"source sequence {row} needs {len(ids)} positions "
                    f"(END marker included) but max_src is {self.max_src}"
                )
            # One slice assignment per row instead of one tensor write per character.
            encoded_x[row, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)
        return encoded_x

    def forward(self, x):
        """Encode a batch of raw source strings.

        Args:
            x: Sequence of source strings (see ``encode_inp``).

        Returns:
            The LSTM output sequence of shape ``(len(x), max_src, 2 * hidden_size)``.
        """
        encoded_x = self.encode_inp(x)
        input_seq = self.embedding(encoded_x)
        # With no explicit state nn.LSTM starts from all-zero hidden and cell
        # states, allocated on the input's device -- the same values that were
        # previously built by hand on the CPU.
        out, _ = self.encoder(input_seq)
        return out

In [None]:
class Decoder(nn.Module):
    """Two-layer LSTMCell decoder with dot-product attention over the encoder output.

    Target strings are converted to ids with the module-level ``vocab_tgt``
    mapping (``"STR"`` first, ``"END"`` last, ``"PAD"`` as filler).

    Args:
        dims: Target embedding size.
        hidden_size: Hidden units of each LSTM cell. The attention context is
            expected to have ``2 * hidden_size`` features.
        num_layers: Stored on the instance; the decoder always builds two cells.
        max_src: Stored on the instance; not used by the decoder itself.
        max_tgt: Number of positions after the start marker; sequences are
            padded to ``max_tgt + 1`` ids and decoding runs for that many steps.
    """

    def __init__(self,dims = 512, hidden_size = 512,num_layers = 2, max_src = 500, max_tgt = 500):
        super().__init__()
        self.embedding = nn.Embedding(len(vocab_tgt),dims)
        self.input_size = dims
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.max_src = max_src
        self.max_tgt = max_tgt
        # First cell consumes [token embedding ; attention context], the second
        # one consumes the hidden state of the first.
        self.dec_cells = nn.ModuleList([
            nn.LSTMCell(2*hidden_size+dims, hidden_size),
            nn.LSTMCell(hidden_size, hidden_size),
        ])
        # One class fewer than the vocabulary -- presumably so that PAD (which
        # is expected to hold the largest id) can never be predicted; confirm.
        self.linear = nn.Linear(hidden_size,len(vocab_tgt)-1)

    def encode_inp(self, x):
        """Turn a batch of target strings into a padded id matrix.

        Args:
            x: Sequence of strings whose characters are all keys of ``vocab_tgt``.

        Returns:
            Tuple ``(encoded_x, encoded_x_end)``. ``encoded_x`` is an ``int64``
            tensor of shape ``(len(x), max_tgt + 1)`` holding STR, the
            character ids, END and then PAD ids. ``encoded_x_end`` is a float
            tensor with the column index of END for every row.

        Raises:
            KeyError: If a character is missing from ``vocab_tgt``.
            IndexError: If a string plus its two markers exceeds ``max_tgt + 1``.
        """
        device = self.embedding.weight.device
        width = self.max_tgt + 1
        start_id = vocab_tgt["STR"]
        end_id = vocab_tgt["END"]
        # PAD id comes from the vocabulary rather than from a literal.
        encoded_x = torch.full((len(x), width), vocab_tgt["PAD"], dtype=torch.long, device=device)
        encoded_x_end = torch.zeros(len(x), device=device)
        for row, sequence in enumerate(x):
            ids = [start_id]
            ids.extend(vocab_tgt[symbol] for symbol in sequence)
            ids.append(end_id)
            if len(ids) > width:
                raise IndexError(
                    f"target sequence {row} needs {len(ids)} positions "
                    f"(STR and END included) but only {width} are available"
                )
            encoded_x[row, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)
            encoded_x_end[row] = len(sequence) + 1
        return encoded_x, encoded_x_end

    def calcontext(self, timestep, query):
        """Attention-weighted summary of ``self.context`` for one decoding step.

        Args:
            timestep: Current decoding step (accepted but not used).
            query: ``(batch, hidden_size)`` decoder state used as the query.

        Returns:
            ``(batch, 2 * hidden_size)`` weighted sum of the encoder outputs.
        """
        # The query is duplicated so its width matches the concatenated
        # forward/backward features of the encoder output.
        extended_query = torch.cat((query, query), dim = 1)
        # Dot product between the query and every encoder position -> (batch, src_len).
        scores = torch.sum(self.context * extended_query.unsqueeze(1), dim = 2)
        # NOTE(review): padded encoder positions are not masked out here.
        weights = torch.softmax(scores, dim = 1).unsqueeze(2)
        alignment = torch.sum(weights * self.context, dim = 1)

        return alignment

    def forward(self, context, target_,teacher_ratio):
        """Decode ``max_tgt + 1`` steps and return per-step log-probabilities.

        Args:
            context: Encoder outputs, ``(batch, src_len, 2 * hidden_size)``.
            target_: Sequence of target strings (see ``encode_inp``).
            teacher_ratio: Probability, drawn once per step for the whole
                batch, of feeding the reference token instead of the model's
                own previous prediction.

        Returns:
            Tuple ``(log_probs, encoded_x)`` where ``log_probs`` has shape
            ``(batch, max_tgt + 1, len(vocab_tgt) - 1)`` and ``encoded_x`` is
            the padded id matrix of the targets.
        """
        self.context = context
        encoded_x, _ = self.encode_inp(target_)
        target_seq = self.embedding(encoded_x)

        batch_size = target_seq.shape[0]
        device = target_seq.device
        # NOTE(review): the recurrent state starts from uniform random values,
        # so decoding is not deterministic even in eval mode. Kept as is.
        h_t1 = torch.rand(batch_size, self.hidden_size, device=device)
        c_t1 = torch.rand(batch_size, self.hidden_size, device=device)
        h_t2 = torch.rand(batch_size, self.hidden_size, device=device)
        c_t2 = torch.rand(batch_size, self.hidden_size, device=device)

        outputs = []
        step_input = target_seq[:, 0]  # the first step always reads the STR embedding
        for timestep in range(self.max_tgt+1):
            if timestep > 0:
                if torch.rand(1).item() < teacher_ratio:
                    step_input = target_seq[:, timestep]
                else:
                    # Greedy feedback. The stored logits are (batch, classes),
                    # so the arg-max over dim 1 gives one token id per example.
                    # (Taking it on a (batch, 1, classes) tensor collapses the
                    # wrong axis and makes the concatenation below fail.)
                    step_input = self.embedding(torch.argmax(outputs[-1], dim=1))

            # Attention is queried with the top-layer state of the previous step.
            attended = self.calcontext(timestep, h_t2)
            h_t1, c_t1 = self.dec_cells[0](torch.cat((step_input, attended), dim=1), (h_t1, c_t1))
            h_t2, c_t2 = self.dec_cells[1](h_t1, (h_t2, c_t2))
            outputs.append(self.linear(h_t2))

        output_prob = torch.stack(outputs, dim = 1)

        return torch.log_softmax(output_prob, dim = 2), encoded_x

In [8]:
encoder = Encoder()
decoder = Decoder()
optimizer_encoder = optim.Adam(encoder.parameters(), lr=0.0001)
optimizer_decoder = optim.Adam(decoder.parameters(), lr=0.0001)
# Padded target positions must not contribute to the loss; the PAD id is read
# from the target vocabulary so the criterion cannot drift away from it.
criterion = nn.NLLLoss(ignore_index = vocab_tgt["PAD"])

# Set up early stopping parameters
# (only consumed by the validation code that is currently commented out below)
patience = 100  # Number of epochs to wait for improvement
best_val_loss = float('inf')
epochs_since_improvement = 0
loss_val = 0.0

num_epochs = 100
for epoch in range(num_epochs):
    # Training Loop
    encoder.train()
    decoder.train()
    # `data` holds the raw source strings of the batch, `vis` the target strings.
    for batch, (data, vis) in enumerate(train_loader, start=1):
        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()

        context = encoder(data)
        # Teacher ratio 1.0: the decoder is always fed the reference token.
        output, encoded_tgt = decoder(context, vis, 1.0)
        # NLLLoss expects the class dimension second: (batch, classes, steps).
        output = output.permute(0,2,1)
        # The prediction made at step t is scored against the token at t + 1,
        # hence the last prediction and the first target (STR) are dropped.
        loss = criterion(output[:,:,:-1], encoded_tgt[:,1:])

        loss.backward()
        optimizer_decoder.step()
        optimizer_encoder.step()

        print(f"Epoch[{epoch}], Batch[{batch}], loss: {loss.item()}")

        # Calculate Perplexity (on a detached copy: this is reporting only and
        # should not extend the autograd graph)
        perplexity = torch.exp(loss.detach())
        print(f"Epoch[{epoch}], Batch[{batch}], perplexity: {perplexity.item()}")
        
    
#     # Validation Loop
#     encoder.eval()
#     decoder.eval()
#     with torch.no_grad():
#         context_ = encoder(test_source)
#         output_ = decoder(context_,test_target,0.0)
#         loss_val = criterion(output_[:,:-1,:], test_target_encoded[:,1:,:])
#         print(f'Epoch [{epoch+1}/{num_epochs}], Eval Loss: {loss_val.item():.4f}')
    
#         if loss_val < best_val_loss:
#             best_val_loss = loss_val
#             epochs_since_improvement = 0
#             torch.save(encoder.state_dict(), 'Model/encoder.pth')
#             torch.save(decoder.state_dict(), 'Model/decoder.pth')

#         else:
#             epochs_since_improvement += 1

#     # Check if we should stop training early
#     if epochs_since_improvement >= patience:
#         print(f"Early stopping after {epoch+1} epochs with no improvement.")
#         break

# if loss_val < best_val_loss:
#     torch.save(encoder.state_dict(), 'Model/encoder.pth')
#     torch.save(decoder.state_dict(), 'Model/decoder.pth')