In [51]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [52]:
class CustomDataset(Dataset):
    def __init__(self, csv_file):
        self.data = pd.read_csv(csv_file)
#         self.data.sample(frac=1).reset_index(drop=True)
        self.start = 1

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        input_sentence = torch.tensor([self.start]+[ord(c) - ord('a') for c in self.data.iloc[idx, 0]])
#         input_sentence = self.embedding(input_sentence)
        
        target_sentence = torch.tensor([self.start]+[ord(c) - ord('a') for c in self.data.iloc[idx, 1]])
        target_sentence_encoded = torch.zeros(target_sentence.shape[0],26)
        for idx in range(target_sentence.shape[0]):
            target_sentence_encoded[idx][target_sentence[idx]] = 1.0
#         target_sentence_ = self.embedding(target_sentence)
        return (input_sentence,target_sentence,target_sentence_encoded)

In [53]:
training_data = DataLoader(CustomDataset("Data/train_data.csv"))
training_source = []
training_target = []
training_target_encoded = []
for input,target,target_encoded in training_data:
    training_source.append(input)
    training_target.append(target)
    training_target_encoded.append(target_encoded)
#     print(x.shape)
training_source = torch.cat(training_source,dim=0)[:,1:]
training_target = torch.cat(training_target,dim=0)
training_target_encoded = torch.cat(training_target_encoded,dim=0)

In [54]:
test_data = DataLoader(CustomDataset("Data/eval_data.csv"))
test_source = []
test_target = []
test_target_encoded = []
for input,target,test_encoded in test_data:
    test_source.append(input)
    test_target.append(target)
    test_target_encoded.append(test_encoded)
#     print(x.shape)
test_source = torch.cat(test_source,dim=0)[:,1:]
test_target = torch.cat(test_target,dim=0)
test_target_encoded = torch.cat(test_target_encoded,dim=0)

In [55]:
print(test_source.shape,test_target.shape,test_target_encoded.shape)
print(training_source.shape,training_target.shape,training_target_encoded.shape)

torch.Size([2000, 8]) torch.Size([2000, 9]) torch.Size([2000, 9, 26])
torch.Size([7000, 8]) torch.Size([7000, 9]) torch.Size([7000, 9, 26])


In [56]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size,num_layers,bidirectional = 0):
        super(Encoder, self).__init__()
#         self.embedding = embedding
        self.embedding = nn.Embedding(26,input_size)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.encoder = nn.RNN(input_size, hidden_size,num_layers,batch_first = True,bidirectional=bool(self.bidirectional))
    
    def forward(self, input_):
        input_seq = self.embedding(input_)
        hidden = torch.zeros((self.bidirectional+1)*self.num_layers,input_seq.shape[0],self.hidden_size)
        _, encoder_hidden = self.encoder(input_seq,hidden)
        if(self.bidirectional == 1):
            encoder_hidden = torch.cat((encoder_hidden[0::2],encoder_hidden[1::2]),dim=2)
        return encoder_hidden
    
    def predict(self,input_str):
        input_sentence = torch.tensor([ord(c) - ord('a') for c in input_str])
        input_seq = self.embedding(input_sentence)
        hidden = torch.zeros((self.bidirectional+1)*self.num_layers,self.hidden_size)
        _, encoder_hidden = self.encoder(input_seq,hidden)
        if(self.bidirectional == 1):
            encoder_hidden = torch.cat((encoder_hidden[::2],encoder_hidden[1::2]),dim=1)
        return encoder_hidden
        
    
class Decoder(nn.Module):
    def __init__(self,input_size, hidden_size,num_layers,bidirectional = 0):
        super(Decoder, self).__init__()
#         self.embedding = embedding
        self.embedding = nn.Embedding(26,input_size)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.sequence_len = 9
        self.bidirectional = bidirectional
        
        self.dec_cells = nn.ModuleList([nn.RNNCell((bidirectional+1)*hidden_size+input_size, (bidirectional+1)*hidden_size)])
            
        self.decoder = nn.GRU((bidirectional+1)*hidden_size+input_size, (bidirectional+1)*hidden_size,batch_first = True)
        self.linear = nn.Linear((bidirectional+1)*hidden_size,26)
        self.output_layer = nn.LogSoftmax(dim = 2)
        self.output_layer_timestep = nn.LogSoftmax(dim = 1)
    
    def forward(self, context, target_,teacher_ratio):
        target_seq = self.embedding(target_)
        initial_hidden = torch.zeros_like(context)
        outputs = []
        
        hidden_states = []
        for timestep in range(self.sequence_len):
            h_t = []
            if(timestep == 0):
                h_t = self.dec_cells[0](torch.cat((target_seq[:,timestep],context[0]),dim=1), initial_hidden[0])
            else:
                input = []
                if(torch.rand(1).item() < teacher_ratio):
                    input = target_seq[:,timestep]
                else:
                    input = self.embedding(torch.argmax(outputs[-1],dim=1))
#                     print(torch.argmax(outputs[-1],dim=1)[0],target_[0,layer_idx])
                    
                h_t = self.dec_cells[0](torch.cat((input,context[0]),dim=1), hidden_states[-1])
            hidden_states.append(h_t)
            out = self.output_layer_timestep(self.linear(h_t))
            outputs.append(out)
    
 
        hidden_states = torch.cat(hidden_states,dim = 1).reshape(-1,9,context.shape[2])
        output_prob = torch.cat(outputs,dim = 1).reshape(-1,9,26)
        
        
#         out,_ = self.decoder(target_seq,context)
#         decoder_output = self.linear(hidden_states)
#         output_prob = self.output_layer(decoder_output)
        return output_prob

    def predict(self,context):
        target_seq = self.embedding(torch.tensor(1))
        initial_hidden = torch.zeros_like(context)
        outputs = []
        
        hidden_states = []
        for layer_idx in range(self.sequence_len):
            h_t = []
            if(layer_idx == 0):
                h_t = self.dec_cells[0](torch.cat((target_seq,context[0]),dim=0), initial_hidden[0])
            else:
                input = self.embedding(torch.argmax(outputs[-1],dim=0))
                    
                h_t = self.dec_cells[0](torch.cat((input,context[0]),dim=0), hidden_states[-1])
            hidden_states.append(h_t)
            out = nn.LogSoftmax(dim=0)(self.linear(h_t))
            outputs.append(out)
        
        output = ''.join([chr(torch.argmax(out,dim=0).item()+ord('a')) for out in outputs[:-1]])
        
        return output
        
        
        


In [57]:
input_size = 10
encoder = Encoder(input_size,32,1,1)
decoder = Decoder(input_size,32,1,1)
criterion = nn.KLDivLoss(reduction = "batchmean")
# optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()),lr = 0.001)
optimizer_encoder = optim.Adam(encoder.parameters(), lr=0.01)
# scheduler_encoder = optim.lr_scheduler.StepLR(optimizer_encoder, step_size=1000, gamma=0.5)
optimizer_decoder = optim.Adam(decoder.parameters(), lr=0.01)
# scheduler_decoder = optim.lr_scheduler.StepLR(optimizer_decoder, step_size=1000, gamma=0.5)
# optimizer_encoder = optim.SGD(encoder.parameters(), lr=0.2,momentum=0.9)
# optimizer_decoder = optim.SGD(decoder.parameters(), lr=0.2,momentum=0.9)

# Set up early stopping parameters
patience = 500  # Number of epochs to wait for improvement
best_val_loss = float('inf')
epochs_since_improvement = 0
loss_val = 0.0

num_epochs = 2000
for epoch in range(num_epochs):
    # Training Loop
    encoder.train()
    decoder.train()
    optimizer_encoder.zero_grad()
    optimizer_decoder.zero_grad()
#     optimizer.zero_grad()
    context = encoder(training_source)
    output = decoder(context,training_target,0.2)
    loss = criterion(output[:,:-1,:], training_target_encoded[:,1:,:])
    loss.backward()
    optimizer_encoder.step()
    optimizer_decoder.step()
#     scheduler_encoder.step()
#     scheduler_decoder.step()
#     optimizer.step()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {loss.item():.4f}')
    
    # Validation Loop
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        context_ = encoder(test_source)
        output_ = decoder(context_,test_target,0.0)
        loss_val = criterion(output_[:,:-1,:], test_target_encoded[:,1:,:])
        print(f'Epoch [{epoch+1}/{num_epochs}], Eval Loss: {loss_val.item():.4f}')
    
        if loss_val < best_val_loss:
            best_val_loss = loss_val
            epochs_since_improvement = 0
            torch.save(encoder.state_dict(), 'Model/encoder.pth')
            torch.save(decoder.state_dict(), 'Model/decoder.pth')

        else:
            epochs_since_improvement += 1

    # Check if we should stop training early
    if epochs_since_improvement >= patience:
        print(f"Early stopping after {epoch+1} epochs with no improvement.")
        break

if loss_val < best_val_loss:
    torch.save(encoder.state_dict(), 'Model/encoder.pth')
    torch.save(decoder.state_dict(), 'Model/decoder.pth')

Epoch [1/2000], Train Loss: 26.2044
Epoch [1/2000], Eval Loss: 26.0461
Epoch [2/2000], Train Loss: 26.0513
Epoch [2/2000], Eval Loss: 25.9819
Epoch [3/2000], Train Loss: 25.9558
Epoch [3/2000], Eval Loss: 25.9058
Epoch [4/2000], Train Loss: 25.8651
Epoch [4/2000], Eval Loss: 25.8221
Epoch [5/2000], Train Loss: 25.7872
Epoch [5/2000], Eval Loss: 25.7317
Epoch [6/2000], Train Loss: 25.6913
Epoch [6/2000], Eval Loss: 25.6446
Epoch [7/2000], Train Loss: 25.5984
Epoch [7/2000], Eval Loss: 25.5519
Epoch [8/2000], Train Loss: 25.5018
Epoch [8/2000], Eval Loss: 25.4613
Epoch [9/2000], Train Loss: 25.4208
Epoch [9/2000], Eval Loss: 25.3639
Epoch [10/2000], Train Loss: 25.2863
Epoch [10/2000], Eval Loss: 25.2333
Epoch [11/2000], Train Loss: 25.1588
Epoch [11/2000], Eval Loss: 25.1100
Epoch [12/2000], Train Loss: 25.0130
Epoch [12/2000], Eval Loss: 24.9579
Epoch [13/2000], Train Loss: 24.8637
Epoch [13/2000], Eval Loss: 24.8280
Epoch [14/2000], Train Loss: 24.7361
Epoch [14/2000], Eval Loss: 24.7

Epoch [115/2000], Train Loss: 21.0069
Epoch [115/2000], Eval Loss: 21.8517
Epoch [116/2000], Train Loss: 20.9250
Epoch [116/2000], Eval Loss: 21.8753
Epoch [117/2000], Train Loss: 20.9573
Epoch [117/2000], Eval Loss: 21.8111
Epoch [118/2000], Train Loss: 20.8772
Epoch [118/2000], Eval Loss: 21.8438
Epoch [119/2000], Train Loss: 20.9015
Epoch [119/2000], Eval Loss: 21.7889
Epoch [120/2000], Train Loss: 20.8205
Epoch [120/2000], Eval Loss: 21.7760
Epoch [121/2000], Train Loss: 20.8262
Epoch [121/2000], Eval Loss: 21.7884
Epoch [122/2000], Train Loss: 20.8137
Epoch [122/2000], Eval Loss: 21.7885
Epoch [123/2000], Train Loss: 20.8100
Epoch [123/2000], Eval Loss: 21.7492
Epoch [124/2000], Train Loss: 20.7524
Epoch [124/2000], Eval Loss: 21.7077
Epoch [125/2000], Train Loss: 20.7533
Epoch [125/2000], Eval Loss: 21.7143
Epoch [126/2000], Train Loss: 20.7171
Epoch [126/2000], Eval Loss: 21.6870
Epoch [127/2000], Train Loss: 20.6648
Epoch [127/2000], Eval Loss: 21.6503
Epoch [128/2000], Train L

Epoch [226/2000], Train Loss: 18.5513
Epoch [226/2000], Eval Loss: 19.6945
Epoch [227/2000], Train Loss: 18.3753
Epoch [227/2000], Eval Loss: 19.5735
Epoch [228/2000], Train Loss: 18.3043
Epoch [228/2000], Eval Loss: 19.4933
Epoch [229/2000], Train Loss: 18.2383
Epoch [229/2000], Eval Loss: 19.4644
Epoch [230/2000], Train Loss: 18.1921
Epoch [230/2000], Eval Loss: 19.4080
Epoch [231/2000], Train Loss: 18.1534
Epoch [231/2000], Eval Loss: 19.4350
Epoch [232/2000], Train Loss: 18.2267
Epoch [232/2000], Eval Loss: 19.3869
Epoch [233/2000], Train Loss: 18.1359
Epoch [233/2000], Eval Loss: 19.3267
Epoch [234/2000], Train Loss: 18.0824
Epoch [234/2000], Eval Loss: 19.3176
Epoch [235/2000], Train Loss: 18.1273
Epoch [235/2000], Eval Loss: 19.3625
Epoch [236/2000], Train Loss: 18.1107
Epoch [236/2000], Eval Loss: 19.3886
Epoch [237/2000], Train Loss: 18.1863
Epoch [237/2000], Eval Loss: 19.4941
Epoch [238/2000], Train Loss: 18.2837
Epoch [238/2000], Eval Loss: 19.1529
Epoch [239/2000], Train L

Epoch [336/2000], Train Loss: 9.4469
Epoch [336/2000], Eval Loss: 9.7636
Epoch [337/2000], Train Loss: 9.3108
Epoch [337/2000], Eval Loss: 9.6764
Epoch [338/2000], Train Loss: 9.2345
Epoch [338/2000], Eval Loss: 9.6542
Epoch [339/2000], Train Loss: 9.1933
Epoch [339/2000], Eval Loss: 9.5631
Epoch [340/2000], Train Loss: 9.1205
Epoch [340/2000], Eval Loss: 9.4611
Epoch [341/2000], Train Loss: 9.0501
Epoch [341/2000], Eval Loss: 9.3901
Epoch [342/2000], Train Loss: 8.9658
Epoch [342/2000], Eval Loss: 9.3554
Epoch [343/2000], Train Loss: 8.9222
Epoch [343/2000], Eval Loss: 9.3079
Epoch [344/2000], Train Loss: 8.8786
Epoch [344/2000], Eval Loss: 9.3093
Epoch [345/2000], Train Loss: 8.8564
Epoch [345/2000], Eval Loss: 9.4181
Epoch [346/2000], Train Loss: 8.9616
Epoch [346/2000], Eval Loss: 9.8401
Epoch [347/2000], Train Loss: 9.4475
Epoch [347/2000], Eval Loss: 11.4277
Epoch [348/2000], Train Loss: 10.8936
Epoch [348/2000], Eval Loss: 10.9492
Epoch [349/2000], Train Loss: 10.5831
Epoch [349

Epoch [448/2000], Train Loss: 10.8983
Epoch [448/2000], Eval Loss: 11.0109
Epoch [449/2000], Train Loss: 10.4977
Epoch [449/2000], Eval Loss: 10.8202
Epoch [450/2000], Train Loss: 10.3017
Epoch [450/2000], Eval Loss: 10.4123
Epoch [451/2000], Train Loss: 9.8860
Epoch [451/2000], Eval Loss: 10.3839
Epoch [452/2000], Train Loss: 9.8585
Epoch [452/2000], Eval Loss: 10.1024
Epoch [453/2000], Train Loss: 9.5941
Epoch [453/2000], Eval Loss: 9.9064
Epoch [454/2000], Train Loss: 9.4273
Epoch [454/2000], Eval Loss: 9.6541
Epoch [455/2000], Train Loss: 9.1842
Epoch [455/2000], Eval Loss: 9.4485
Epoch [456/2000], Train Loss: 8.9517
Epoch [456/2000], Eval Loss: 9.3098
Epoch [457/2000], Train Loss: 8.8232
Epoch [457/2000], Eval Loss: 9.1804
Epoch [458/2000], Train Loss: 8.7098
Epoch [458/2000], Eval Loss: 8.9690
Epoch [459/2000], Train Loss: 8.5024
Epoch [459/2000], Eval Loss: 8.8380
Epoch [460/2000], Train Loss: 8.3641
Epoch [460/2000], Eval Loss: 8.6888
Epoch [461/2000], Train Loss: 8.2246
Epoch 

Epoch [562/2000], Train Loss: 5.8812
Epoch [562/2000], Eval Loss: 6.2901
Epoch [563/2000], Train Loss: 5.8791
Epoch [563/2000], Eval Loss: 6.2876
Epoch [564/2000], Train Loss: 5.8743
Epoch [564/2000], Eval Loss: 6.2840
Epoch [565/2000], Train Loss: 5.8683
Epoch [565/2000], Eval Loss: 6.2807
Epoch [566/2000], Train Loss: 5.8650
Epoch [566/2000], Eval Loss: 6.2768
Epoch [567/2000], Train Loss: 5.8608
Epoch [567/2000], Eval Loss: 6.2741
Epoch [568/2000], Train Loss: 5.8573
Epoch [568/2000], Eval Loss: 6.2705
Epoch [569/2000], Train Loss: 5.8514
Epoch [569/2000], Eval Loss: 6.2654
Epoch [570/2000], Train Loss: 5.8470
Epoch [570/2000], Eval Loss: 6.2621
Epoch [571/2000], Train Loss: 5.8428
Epoch [571/2000], Eval Loss: 6.2589
Epoch [572/2000], Train Loss: 5.8385
Epoch [572/2000], Eval Loss: 6.2566
Epoch [573/2000], Train Loss: 5.8343
Epoch [573/2000], Eval Loss: 6.2543
Epoch [574/2000], Train Loss: 5.8302
Epoch [574/2000], Eval Loss: 6.2509
Epoch [575/2000], Train Loss: 5.8291
Epoch [575/200

Epoch [676/2000], Train Loss: 5.5697
Epoch [676/2000], Eval Loss: 6.1076
Epoch [677/2000], Train Loss: 5.5696
Epoch [677/2000], Eval Loss: 6.1067
Epoch [678/2000], Train Loss: 5.5672
Epoch [678/2000], Eval Loss: 6.1061
Epoch [679/2000], Train Loss: 5.5646
Epoch [679/2000], Eval Loss: 6.1031
Epoch [680/2000], Train Loss: 5.5630
Epoch [680/2000], Eval Loss: 6.1039
Epoch [681/2000], Train Loss: 5.5604
Epoch [681/2000], Eval Loss: 6.1056
Epoch [682/2000], Train Loss: 5.5625
Epoch [682/2000], Eval Loss: 6.1024
Epoch [683/2000], Train Loss: 5.5586
Epoch [683/2000], Eval Loss: 6.1066
Epoch [684/2000], Train Loss: 5.5550
Epoch [684/2000], Eval Loss: 6.1083
Epoch [685/2000], Train Loss: 5.5531
Epoch [685/2000], Eval Loss: 6.1046
Epoch [686/2000], Train Loss: 5.5521
Epoch [686/2000], Eval Loss: 6.1013
Epoch [687/2000], Train Loss: 5.5510
Epoch [687/2000], Eval Loss: 6.1052
Epoch [688/2000], Train Loss: 5.5524
Epoch [688/2000], Eval Loss: 6.1026
Epoch [689/2000], Train Loss: 5.5508
Epoch [689/200

Epoch [788/2000], Train Loss: 10.2213
Epoch [788/2000], Eval Loss: 10.7492
Epoch [789/2000], Train Loss: 10.1017
Epoch [789/2000], Eval Loss: 10.6270
Epoch [790/2000], Train Loss: 9.9853
Epoch [790/2000], Eval Loss: 10.5146
Epoch [791/2000], Train Loss: 9.8738
Epoch [791/2000], Eval Loss: 10.4131
Epoch [792/2000], Train Loss: 9.7651
Epoch [792/2000], Eval Loss: 10.3032
Epoch [793/2000], Train Loss: 9.6535
Epoch [793/2000], Eval Loss: 10.1922
Epoch [794/2000], Train Loss: 9.5443
Epoch [794/2000], Eval Loss: 10.0878
Epoch [795/2000], Train Loss: 9.4418
Epoch [795/2000], Eval Loss: 9.9824
Epoch [796/2000], Train Loss: 9.3385
Epoch [796/2000], Eval Loss: 9.8780
Epoch [797/2000], Train Loss: 9.2403
Epoch [797/2000], Eval Loss: 9.7713
Epoch [798/2000], Train Loss: 9.1453
Epoch [798/2000], Eval Loss: 9.6741
Epoch [799/2000], Train Loss: 9.0461
Epoch [799/2000], Eval Loss: 9.5741
Epoch [800/2000], Train Loss: 8.9516
Epoch [800/2000], Eval Loss: 9.4776
Epoch [801/2000], Train Loss: 8.8626
Epoch

Epoch [900/2000], Train Loss: 6.6145
Epoch [900/2000], Eval Loss: 6.8971
Epoch [901/2000], Train Loss: 6.5746
Epoch [901/2000], Eval Loss: 6.8402
Epoch [902/2000], Train Loss: 6.5274
Epoch [902/2000], Eval Loss: 6.7971
Epoch [903/2000], Train Loss: 6.4916
Epoch [903/2000], Eval Loss: 6.7548
Epoch [904/2000], Train Loss: 6.4507
Epoch [904/2000], Eval Loss: 6.7116
Epoch [905/2000], Train Loss: 6.4060
Epoch [905/2000], Eval Loss: 6.6864
Epoch [906/2000], Train Loss: 6.3820
Epoch [906/2000], Eval Loss: 6.6522
Epoch [907/2000], Train Loss: 6.3465
Epoch [907/2000], Eval Loss: 6.6331
Epoch [908/2000], Train Loss: 6.3236
Epoch [908/2000], Eval Loss: 6.5988
Epoch [909/2000], Train Loss: 6.2873
Epoch [909/2000], Eval Loss: 6.5711
Epoch [910/2000], Train Loss: 6.2613
Epoch [910/2000], Eval Loss: 6.5428
Epoch [911/2000], Train Loss: 6.2381
Epoch [911/2000], Eval Loss: 6.5171
Epoch [912/2000], Train Loss: 6.2173
Epoch [912/2000], Eval Loss: 6.4936
Epoch [913/2000], Train Loss: 6.1957
Epoch [913/200

Epoch [1012/2000], Train Loss: 5.6566
Epoch [1012/2000], Eval Loss: 5.9492
Epoch [1013/2000], Train Loss: 5.6547
Epoch [1013/2000], Eval Loss: 5.9481
Epoch [1014/2000], Train Loss: 5.6529
Epoch [1014/2000], Eval Loss: 5.9468
Epoch [1015/2000], Train Loss: 5.6516
Epoch [1015/2000], Eval Loss: 5.9452
Epoch [1016/2000], Train Loss: 5.6494
Epoch [1016/2000], Eval Loss: 5.9437
Epoch [1017/2000], Train Loss: 5.6480
Epoch [1017/2000], Eval Loss: 5.9418
Epoch [1018/2000], Train Loss: 5.6462
Epoch [1018/2000], Eval Loss: 5.9406
Epoch [1019/2000], Train Loss: 5.6445
Epoch [1019/2000], Eval Loss: 5.9399
Epoch [1020/2000], Train Loss: 5.6428
Epoch [1020/2000], Eval Loss: 5.9396
Epoch [1021/2000], Train Loss: 5.6422
Epoch [1021/2000], Eval Loss: 5.9388
Epoch [1022/2000], Train Loss: 5.6393
Epoch [1022/2000], Eval Loss: 5.9375
Epoch [1023/2000], Train Loss: 5.6378
Epoch [1023/2000], Eval Loss: 5.9357
Epoch [1024/2000], Train Loss: 5.6372
Epoch [1024/2000], Eval Loss: 5.9348
Epoch [1025/2000], Train 

Epoch [1122/2000], Train Loss: 5.5161
Epoch [1122/2000], Eval Loss: 5.8953
Epoch [1123/2000], Train Loss: 5.5153
Epoch [1123/2000], Eval Loss: 5.8953
Epoch [1124/2000], Train Loss: 5.5136
Epoch [1124/2000], Eval Loss: 5.8955
Epoch [1125/2000], Train Loss: 5.5141
Epoch [1125/2000], Eval Loss: 5.8954
Epoch [1126/2000], Train Loss: 5.5132
Epoch [1126/2000], Eval Loss: 5.8953
Epoch [1127/2000], Train Loss: 5.5106
Epoch [1127/2000], Eval Loss: 5.8954
Epoch [1128/2000], Train Loss: 5.5095
Epoch [1128/2000], Eval Loss: 5.8960
Epoch [1129/2000], Train Loss: 5.5092
Epoch [1129/2000], Eval Loss: 5.8965
Epoch [1130/2000], Train Loss: 5.5082
Epoch [1130/2000], Eval Loss: 5.8968
Epoch [1131/2000], Train Loss: 5.5067
Epoch [1131/2000], Eval Loss: 5.8965
Epoch [1132/2000], Train Loss: 5.5061
Epoch [1132/2000], Eval Loss: 5.8964
Epoch [1133/2000], Train Loss: 5.5057
Epoch [1133/2000], Eval Loss: 5.8968
Epoch [1134/2000], Train Loss: 5.5033
Epoch [1134/2000], Eval Loss: 5.8971
Epoch [1135/2000], Train 

Epoch [1232/2000], Train Loss: 5.4071
Epoch [1232/2000], Eval Loss: 5.9687
Epoch [1233/2000], Train Loss: 5.4022
Epoch [1233/2000], Eval Loss: 5.9702
Epoch [1234/2000], Train Loss: 5.4036
Epoch [1234/2000], Eval Loss: 5.9697
Epoch [1235/2000], Train Loss: 5.3991
Epoch [1235/2000], Eval Loss: 5.9705
Epoch [1236/2000], Train Loss: 5.3998
Epoch [1236/2000], Eval Loss: 5.9715
Epoch [1237/2000], Train Loss: 5.3979
Epoch [1237/2000], Eval Loss: 5.9740
Epoch [1238/2000], Train Loss: 5.3992
Epoch [1238/2000], Eval Loss: 5.9739
Epoch [1239/2000], Train Loss: 5.3955
Epoch [1239/2000], Eval Loss: 5.9777
Epoch [1240/2000], Train Loss: 5.3937
Epoch [1240/2000], Eval Loss: 5.9778
Epoch [1241/2000], Train Loss: 5.3925
Epoch [1241/2000], Eval Loss: 5.9772
Epoch [1242/2000], Train Loss: 5.3931
Epoch [1242/2000], Eval Loss: 5.9811
Epoch [1243/2000], Train Loss: 5.3937
Epoch [1243/2000], Eval Loss: 5.9789
Epoch [1244/2000], Train Loss: 5.3891
Epoch [1244/2000], Eval Loss: 5.9795
Epoch [1245/2000], Train 

Epoch [1342/2000], Train Loss: 7.6290
Epoch [1342/2000], Eval Loss: 8.0350
Epoch [1343/2000], Train Loss: 7.5031
Epoch [1343/2000], Eval Loss: 7.8982
Epoch [1344/2000], Train Loss: 7.3844
Epoch [1344/2000], Eval Loss: 7.7575
Epoch [1345/2000], Train Loss: 7.2660
Epoch [1345/2000], Eval Loss: 7.6469
Epoch [1346/2000], Train Loss: 7.1519
Epoch [1346/2000], Eval Loss: 7.5591
Epoch [1347/2000], Train Loss: 7.0640
Epoch [1347/2000], Eval Loss: 7.4507
Epoch [1348/2000], Train Loss: 6.9724
Epoch [1348/2000], Eval Loss: 7.3624
Epoch [1349/2000], Train Loss: 6.8845
Epoch [1349/2000], Eval Loss: 7.2693
Epoch [1350/2000], Train Loss: 6.8032
Epoch [1350/2000], Eval Loss: 7.1835
Epoch [1351/2000], Train Loss: 6.7345
Epoch [1351/2000], Eval Loss: 7.0938
Epoch [1352/2000], Train Loss: 6.6684
Epoch [1352/2000], Eval Loss: 7.0114
Epoch [1353/2000], Train Loss: 6.6083
Epoch [1353/2000], Eval Loss: 6.9635
Epoch [1354/2000], Train Loss: 6.5534
Epoch [1354/2000], Eval Loss: 6.9024
Epoch [1355/2000], Train 

Epoch [1452/2000], Train Loss: 5.5882
Epoch [1452/2000], Eval Loss: 5.9295
Epoch [1453/2000], Train Loss: 5.5872
Epoch [1453/2000], Eval Loss: 5.9291
Epoch [1454/2000], Train Loss: 5.5850
Epoch [1454/2000], Eval Loss: 5.9282
Epoch [1455/2000], Train Loss: 5.5837
Epoch [1455/2000], Eval Loss: 5.9269
Epoch [1456/2000], Train Loss: 5.5827
Epoch [1456/2000], Eval Loss: 5.9251
Epoch [1457/2000], Train Loss: 5.5804
Epoch [1457/2000], Eval Loss: 5.9239
Epoch [1458/2000], Train Loss: 5.5785
Epoch [1458/2000], Eval Loss: 5.9234
Epoch [1459/2000], Train Loss: 5.5767
Epoch [1459/2000], Eval Loss: 5.9229
Epoch [1460/2000], Train Loss: 5.5765
Epoch [1460/2000], Eval Loss: 5.9223
Epoch [1461/2000], Train Loss: 5.5736
Epoch [1461/2000], Eval Loss: 5.9223
Epoch [1462/2000], Train Loss: 5.5719
Epoch [1462/2000], Eval Loss: 5.9215
Epoch [1463/2000], Train Loss: 5.5715
Epoch [1463/2000], Eval Loss: 5.9204
Epoch [1464/2000], Train Loss: 5.5686
Epoch [1464/2000], Eval Loss: 5.9192
Epoch [1465/2000], Train 

Epoch [1562/2000], Train Loss: 5.4512
Epoch [1562/2000], Eval Loss: 5.9406
Epoch [1563/2000], Train Loss: 5.4503
Epoch [1563/2000], Eval Loss: 5.9415
Epoch [1564/2000], Train Loss: 5.4501
Epoch [1564/2000], Eval Loss: 5.9430
Epoch [1565/2000], Train Loss: 5.4494
Epoch [1565/2000], Eval Loss: 5.9447
Epoch [1566/2000], Train Loss: 5.4469
Epoch [1566/2000], Eval Loss: 5.9448
Epoch [1567/2000], Train Loss: 5.4456
Epoch [1567/2000], Eval Loss: 5.9440
Epoch [1568/2000], Train Loss: 5.4447
Epoch [1568/2000], Eval Loss: 5.9443
Epoch [1569/2000], Train Loss: 5.4433
Epoch [1569/2000], Eval Loss: 5.9466
Epoch [1570/2000], Train Loss: 5.4447
Epoch [1570/2000], Eval Loss: 5.9487
Epoch [1571/2000], Train Loss: 5.4421
Epoch [1571/2000], Eval Loss: 5.9493
Epoch [1572/2000], Train Loss: 5.4404
Epoch [1572/2000], Eval Loss: 5.9500
Epoch [1573/2000], Train Loss: 5.4389
Epoch [1573/2000], Eval Loss: 5.9513
Epoch [1574/2000], Train Loss: 5.4390
Epoch [1574/2000], Eval Loss: 5.9528
Epoch [1575/2000], Train 

In [63]:
encoder = Encoder(input_size,32,1,1)
decoder = Decoder(input_size,32,1,1)
encoder.load_state_dict(torch.load('Model/encoder.pth'))  
decoder.load_state_dict(torch.load('Model/decoder.pth')) 
encoder.eval()
decoder.eval()
with torch.no_grad():
    context_ = encoder(test_source)
    output_ = decoder(context_,test_target,0.0)
    loss_val = criterion(output_[:,:-1,:], test_target_encoded[:,1:,:])
    print(f'Epoch [{epoch+1}/{num_epochs}], Eval Loss: {loss_val.item():.4f}')


Epoch [1615/2000], Eval Loss: 5.8936


In [64]:
encoder.eval()
decoder.eval()
with torch.no_grad():
    context = encoder(test_source)
    output = decoder(context,test_target,0.0)
    actual = torch.argmax(test_target_encoded[:,1:,:],dim=2)
    predictions = torch.argmax(output[:,:-1,:],dim=2)
    wrong_pred = torch.where(predictions != actual,1.0,0.0)
    print(f'Average(over batch) no. of Wrong predictions per sequence : {torch.sum(wrong_pred) / wrong_pred.shape[0]:.4f}')

Average(over batch) no. of Wrong predictions per sequence : 3.9985


In [65]:
# Function to check how many characters match in the two strings
def check(pred: str, true: str):
    correct = 0
    for a, b in zip(pred, true):
        if a == b:
            correct += 1

    # Prediction is more than 8 letters, so penalize for every extra letter.
    correct -= max(0, len(pred) - len(true))
    correct = max(0, correct)
    return correct

# Function to score the model's performance
def evaluate(encoder, decoder):

    # Train data
    print("Obtaining results for training data:")
    train_data = pd.read_csv("Data/train_data.csv").to_numpy()
    results = {
        "pred": [],
        "true": [],
        "score": [],
    }
    correct = [0 for _ in range(9)]
    for x, y in train_data:
        pred = decoder.predict(encoder.predict(x))
        score = check(pred, y)
        results["pred"].append(pred)
        results["true"].append(y)
        results["score"].append(score)

        correct[score] += 1
    print("Train dataset results:")
    for num_chr in range(9):
        print(
            f"Number of predictions with {num_chr} correct predictions: {correct[num_chr]}"
        )
    points = sum(correct[4:6]) * 0.5 + sum(correct[6:])
    print(f"Points: {points}")
    # Save predicitons and true sentences to inspect manually if required.
    pd.DataFrame.from_dict(results).to_csv("results_train.csv", index=False)

    #----------------------------------------------------------------------------------

    print("Obtaining metrics for eval data:")
    eval_data = pd.read_csv("Data/eval_data.csv").to_numpy()
    results = {
        "pred": [],
        "true": [],
        "score": [],
    }
    correct = [0 for _ in range(9)]
    for x, y in eval_data:
        pred = decoder.predict(encoder.predict(x))
        score = check(pred, y)
        results["pred"].append(pred)
        results["true"].append(y)
        results["score"].append(score)

        correct[score] += 1
    print("Eval dataset results:")
    for num_chr in range(9):
        print(
            f"Number of predictions with {num_chr} correct predictions: {correct[num_chr]}"
        )
    points = sum(correct[4:6]) * 0.5 + sum(correct[6:])
    marks = round(min(2, points / 1400 * 2) * 2) / 2  # Rounds to the nearest 0.5
    print(f"Points: {points}")
    print(f"Marks: {marks}")
    # Save predicitons and true sentences to inspect manually if required.
    pd.DataFrame.from_dict(results).to_csv("results_eval.csv", index=False)


In [66]:
evaluate(encoder,decoder)

Obtaining results for training data:
Train dataset results:
Number of predictions with 0 correct predictions: 12
Number of predictions with 1 correct predictions: 96
Number of predictions with 2 correct predictions: 397
Number of predictions with 3 correct predictions: 1049
Number of predictions with 4 correct predictions: 1688
Number of predictions with 5 correct predictions: 1865
Number of predictions with 6 correct predictions: 1336
Number of predictions with 7 correct predictions: 479
Number of predictions with 8 correct predictions: 78
Points: 3669.5
Obtaining metrics for eval data:
Eval dataset results:
Number of predictions with 0 correct predictions: 11
Number of predictions with 1 correct predictions: 72
Number of predictions with 2 correct predictions: 220
Number of predictions with 3 correct predictions: 417
Number of predictions with 4 correct predictions: 557
Number of predictions with 5 correct predictions: 422
Number of predictions with 6 correct predictions: 215
Number 