In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataloader import HangmanTextDataset
import numpy as np
# from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import os
import string
import yaml

In [2]:
# https://github.com/methi1999/hangman/blob/master/model.py 
# take design idea from above repository

### Directed Acyclic Graph for the model

In [3]:
input_word_dim = 28
hidden_dim = 64
input_available_dim = 28 
input_missed_dim = 28 
batch_size = 100
sequence_length = 40
target_dim = 28
num_layers = 2

input_word_tensor = torch.randn(batch_size,sequence_length,input_word_dim)
input_available_tensor = torch.randn(batch_size,input_available_dim)
input_missed_tensor = torch.randn(batch_size, input_missed_dim) 
print("Size of input tensors: \n", input_word_tensor.size(), input_available_tensor.size(), input_missed_tensor.size())

Size of input tensors: 
 torch.Size([100, 40, 28]) torch.Size([100, 28]) torch.Size([100, 28])


In [4]:
lstm_layer = torch.nn.LSTM(input_word_dim, hidden_dim, num_layers,batch_first=True) 
fc_1 = torch.nn.Linear(in_features=hidden_dim, out_features=hidden_dim) 
fc_2 = torch.nn.Linear(in_features=hidden_dim, out_features=hidden_dim) 
relu = torch.nn.ReLU()

fc_3 = torch.nn.Linear(in_features = input_available_dim, out_features = hidden_dim) 
fc_4 = torch.nn.Linear(in_features = hidden_dim, out_features = hidden_dim) 

fc_5 = torch.nn.Linear(in_features = input_missed_dim, out_features=hidden_dim) 
fc_6 = torch.nn.Linear(in_features = hidden_dim, out_features = hidden_dim)

fc_7 = torch.nn.Linear(in_features= hidden_dim*3, out_features=hidden_dim*2) 
fc_8 = torch.nn.Linear(in_features= hidden_dim*2, out_features=hidden_dim) 
fc_9 = torch.nn.Linear(in_features= hidden_dim, out_features=target_dim) 

soft = torch.nn.Softmax(dim=1)

In [5]:
out, (h_0, c_0) = lstm_layer(input_word_tensor) 
out_lstm = out[:,-1,:]  
print(f"out_lstm({out_lstm.shape})", end = ' --> ')
out_fc_1 = relu(fc_1(out_lstm)) 
print(f"out_fc_1({out_fc_1.shape})", end = ' --> ')
out_fc_2 = relu(fc_2(out_fc_1)) 
print(f"out_fc_2({out_fc_2.shape})", end = ' \n') 

out_fc_3 = relu(fc_3(input_available_tensor)) 
print(f"out_fc_3({out_fc_3.shape})", end = ' --> ')
out_fc_4 = relu(fc_4(out_fc_3)) 
print(f"out_fc_4({out_fc_4.shape})", end = '\n')

out_fc_5 = relu(fc_5(input_missed_tensor)) 
print(f"out_fc_5({out_fc_5.shape})", end = ' --> ')
out_fc_6 = relu(fc_6(out_fc_5)) 
print(f"out_fc_6({out_fc_6.shape})", end = '\n\n')

concat = torch.cat((out_fc_2, out_fc_4, out_fc_5), dim = 1)
print(f"combining outputs from above three networks: ({concat.shape})") 

print("Passing combined features into another feed forward layer")
out = relu(fc_7(concat)) 
print(f"({out.shape})", end = ' --> ')
out = relu(fc_8(out)) 
print(f"({out.shape})", end = ' --> ')
out = relu(fc_9(out)) 
print(f"({out.shape})")

out_lstm(torch.Size([100, 64])) --> out_fc_1(torch.Size([100, 64])) --> out_fc_2(torch.Size([100, 64])) 
out_fc_3(torch.Size([100, 64])) --> out_fc_4(torch.Size([100, 64]))
out_fc_5(torch.Size([100, 64])) --> out_fc_6(torch.Size([100, 64]))

combining outputs from above three networks: (torch.Size([100, 192]))
Passing combined features into another feed forward layer
(torch.Size([100, 128])) --> (torch.Size([100, 64])) --> (torch.Size([100, 28]))


In [6]:
soft(out).detach().numpy().sum()

100.0

### Building a Model 
Including methods like:  
1. \_\_init\_\_()
2. forward()
3. save_model()
4. load_checkpoint() 
5. train_model() 
6. test_model() 
7. predict()

In [12]:
config = {
    "vocab_size": 28,
    "hidden_dim" : 64,
    "input_available_dim" : 28 ,
    "input_missed_dim" : 28 ,
    "batch_size" : 100,
    "max_sequence_length" : 40,
    "target_dim" : 28,
    "num_layers" : 3, 
    "lstm_bidirectional": False, 
    "learning_rate": 1e-2, 
    "use_embedding": False, 
    "embed_dim": 5,
    "models_path": "./checkpoints/"
}

In [13]:
class HangmanSaver(nn.Module):
    """
        Provides basic functionality including load model, save model and inference function. 
        Inference function would be directly called in our guess method that we have to build. 
    """

    def __init__(self, config,desc = "model_trident"):
        super(HangmanSaver, self).__init__()

        self.use_embedding = config["use_embedding"]
        self.models_path = config["models_path"]
        self.desc = desc
        self.seq_length = config['max_sequence_length']
        self.vocab_size = config["vocab_size"]
        self.vocab = ['<pad>'] + list(string.ascii_lowercase) + ['_']
        self.char_to_index = {c:i for i,c in enumerate(self.vocab)}
        self.index_to_char = {i:c for i,c in enumerate(self.vocab)}
        # self.hidden_dim = config["hidden_dim"] 
        # self.input_available_dim = config["input_available_dim"]   
        # self.input_missed_dim = config["input_missed_dim"]   
        # self.batch_size = config["batch_size"] 
        # self.sequence_length = config["sequence_length"] 
        # self.target_dim = config["target_dim"] 
        # self.num_layers = config["num_layers"] 

        self.rec_layer = torch.nn.LSTM(input_size = config["embed_dim"] if config["use_embedding"] else config["vocab_size"], 
                                       hidden_size = config["hidden_dim"], num_layers = config["num_layers"], 
                                       batch_first=True, bidirectional=config["lstm_bidirectional"]) 
        self.fc_1 = torch.nn.Linear(in_features=config["hidden_dim"], out_features=config["hidden_dim"]*4) 
        self.fc_2 = torch.nn.Linear(in_features=config["hidden_dim"]*4, out_features=config["hidden_dim"]) 

        self.fc_3 = torch.nn.Linear(in_features = config["input_available_dim"], out_features = config["hidden_dim"]*2) 
        self.fc_4 = torch.nn.Linear(in_features = config["hidden_dim"]*2, out_features = config["hidden_dim"]) 

        # self.fc_5 = torch.nn.Linear(in_features = config["input_missed_dim"], out_features=config["hidden_dim"]) 
        # self.fc_6 = torch.nn.Linear(in_features = config["hidden_dim"], out_features = config["hidden_dim"])

        self.fc_7 = torch.nn.Linear(in_features= config["hidden_dim"]*2, out_features=config["hidden_dim"]*4) 
        self.fc_8 = torch.nn.Linear(in_features= config["hidden_dim"]*4, out_features=config["hidden_dim"]) 
        self.fc_9 = torch.nn.Linear(in_features= config["hidden_dim"], out_features=config["target_dim"]) 

        self.soft = torch.nn.Softmax(dim=1)
        
        self.relu = torch.nn.ReLU()

        self.optimizer = torch.optim.Adam(self.parameters(), lr = config['learning_rate'])

        self.criterion = torch.nn.BCEWithLogitsLoss()


    def forward(self, x_in, x_available):
        """
        Forward pass through RNN
        :param x_in: input tensor of shape (batch size, max sequence length, input_dim)
        :param x_lens: actual lengths of each sequence < max sequence length (since padded with zeros)
        :param miss_chars: tensor of length batch_size x vocab size. 1 at index i indicates that ith character is NOT present
        :return: tensor of shape (batch size, max sequence length, output dim)
        """        
        if self.use_embedding:
            x_in = self.embedding(x_in)

        out, (h_0, c_0) = self.rec_layer(x_in) 
        out_rec = out[:,-1,:]  
        out_fc_1 = relu(self.fc_1(out_rec)) 
        out_fc_2 = relu(self.fc_2(out_fc_1)) 

        out_fc_3 = relu(self.fc_3(x_available)) 
        out_fc_4 = relu(self.fc_4(out_fc_3)) 

        # out_fc_5 = relu(self.fc_5(x_missed)) 
        # out_fc_6 = relu(self.fc_6(out_fc_5)) 

        concat = torch.cat((out_fc_2, out_fc_4), dim = 1)
        out = relu(self.fc_9(relu(self.fc_8(relu(self.fc_7(concat))))))
        return soft(out)
            

    def loss(self, predicted, truth):
        pass

    # save model, along with loss details and testing accuracy
    # best is the model which has the lowest test loss. This model is used during feature extraction
    def save_model(self, is_best, epoch, train_loss, test_loss):

        base_path = self.models_path
        if is_best:
            filename = base_path + 'best_' + self.desc + '.pth'
        else:
            filename = base_path + "epoch_" + str(epoch) + '_' + self.desc + '.pth'

        torch.save({
            'epoch': epoch,
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_loss': train_loss,
            'test_loss': test_loss,
        }, filename)

        print("Saved model")

    # Loads saved model for resuming training or inference
    # mode can be 'test' or 'train'
    def load_model(self, checkpoint_path = None):
        # load model parameters and return training/testing loss and testing accuracy
        if checkpoint_path is None: 
            print("Checkpoint path has to be provided...") 
            exit(0)
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Loaded pretrained model from:", checkpoint_path)

        return checkpoint['epoch'], checkpoint['train_loss'], checkpoint['test_loss']
        

    def inference(self,x_in, x_available): 
        """
        Should be the same as forward function but with batch size = 1
        should take input as strings instead of vectors
        should preprocess strings on its own
        should infer possible outcomes on its own.
        """
        def one_hot_encoding(word):
                one_hot = np.zeros((self.seq_length, self.vocab_size))
                for i, ch in enumerate(word):
                    index = self.char_to_index.get(ch)  # Get index from char_to_index dict
                    one_hot[i][index] = 1
                return torch.tensor(one_hot[np.newaxis,:,:], dtype=torch.float32)
            
        # Convert target word to a one-d vector target
        def encode_word_to_vec(word):
            target = np.zeros(self.vocab_size) 
            if type(word)!=str:
                return target
            for ch in word:
                if ch != '_' and ch != '<pad>': 
                    index = self.char_to_index[ch]
                    target[index] = 1
            return torch.tensor(target[np.newaxis,:], dtype = torch.float32) 
        
        x_in = one_hot_encoding(x_in) 
        x_available = encode_word_to_vec(x_available) 

        with torch.no_grad():
            x_pred = self.forward(x_in, x_available)
        word_probs = {w:p for w,p in zip(self.vocab,x_pred.numpy()[0]) if p>0}

        return word_probs

In [14]:
savior = HangmanSaver(config, "untrained")
savior.inference("abcw_t_r","abcdefg")

{'<pad>': 0.034826707,
 'a': 0.034826707,
 'b': 0.036144268,
 'c': 0.034826707,
 'd': 0.034826707,
 'e': 0.03831794,
 'f': 0.038168117,
 'g': 0.034826707,
 'h': 0.036300227,
 'i': 0.034826707,
 'j': 0.036289144,
 'k': 0.034826707,
 'l': 0.035730794,
 'm': 0.034826707,
 'n': 0.034826707,
 'o': 0.034826707,
 'p': 0.034826707,
 'q': 0.034826707,
 'r': 0.039371286,
 's': 0.03881577,
 't': 0.037910808,
 'u': 0.034826707,
 'v': 0.034826707,
 'w': 0.034826707,
 'x': 0.034826707,
 'y': 0.034826707,
 'z': 0.036070973,
 '_': 0.034826707}

In [15]:
# for batch_input, batch_available, batch_missed, batch_target in tqdm(dataloader):
#     batch_input = torch.squeeze(batch_input) 
#     batch_available = torch.squeeze(batch_available) 
#     batch_missed = torch.squeeze(batch_missed)
#     batch_target = torch.squeeze(batch_target)
#     print(batch_input.shape, batch_available.shape, batch_missed.shape, batch_target.shape)
#     break

In [23]:
model = HangmanSaver(config, "training_pipeline_testing") 
# Load dataset and create DataLoader
dataset = HangmanTextDataset('./data/train/train_mini.csv', chunk_size=1000)
dataloader = DataLoader(dataset)


# Training loop
import os
resume = True
global epoch
total_epochs = 100
if os.path.exists('./checkpoint.pth') and resume: 
    checkpoint = torch.load('./checkpoint.pth')  
    model.load_model('./checkpoint.pth') 
    epoch = checkpoint['epoch'] + 1
    loss = checkpoint['loss'] 
else: 
    epoch = 0



model.train()
while epoch < total_epochs:
    running_loss = 0.0
    for batch_input, batch_available, batch_target in tqdm(dataloader, desc=f"Epoch {epoch+1}/{total_epochs}"):
        batch_input = torch.squeeze(batch_input) 
        batch_available = torch.squeeze(batch_available) 
        batch_target = torch.squeeze(batch_target)
        
        # Forward pass
        outputs = model(batch_input, batch_available)

        # Compute loss
        loss = model.criterion(outputs, batch_target)

        # Backward pass and optimization
        model.optimizer.zero_grad()
        loss.backward()
        model.optimizer.step()

        running_loss += loss.item() * batch_input.size(0)

    epoch_loss = running_loss / 4120
    print(f"Epoch {epoch+1}/{total_epochs}, Loss: {epoch_loss:.8f}")
    # writer.add_scalar('Loss/train', epoch_loss, epoch)

    # Save checkpoint
    print(model.inference("__ttlebo_e","acdfhijknpqruvwxyz")) #target cenu : cuttlebone
    model.save_model(is_best=True, epoch=epoch, train_loss = epoch_loss, test_loss = 10)
    # torch.save({
    #     'epoch': epoch,
    #     'model_state_dict': model.state_dict(),
    #     'optimizer_state_dict': model.optimizer.state_dict(),
    #     'loss': loss,
    # }, checkpoint_path) 

    epoch += 1

# Close Tensorboard writer
# writer.close()


Epoch 1/100: 0it [00:00, ?it/s]

Epoch 1/100: 5it [00:02,  1.72it/s]


Epoch 1/100, Loss: 0.70275743
{'<pad>': 2.670847e-11, 'a': 0.8640806, 'b': 2.670847e-11, 'c': 2.670847e-11, 'd': 2.670847e-11, 'e': 0.12519504, 'f': 2.670847e-11, 'g': 2.670847e-11, 'h': 2.670847e-11, 'i': 2.670847e-11, 'j': 2.670847e-11, 'k': 2.670847e-11, 'l': 1.0561412e-06, 'm': 2.670847e-11, 'n': 0.0073724072, 'o': 2.670847e-11, 'p': 4.882706e-11, 'q': 2.670847e-11, 'r': 3.7424303e-05, 's': 0.0033135551, 't': 2.670847e-11, 'u': 2.670847e-11, 'v': 2.670847e-11, 'w': 2.670847e-11, 'x': 2.670847e-11, 'y': 2.670847e-11, 'z': 2.670847e-11, '_': 2.670847e-11}
Saved model


Epoch 2/100: 5it [00:02,  1.80it/s]


Epoch 2/100, Loss: 0.70075073
{'<pad>': 3.6617035e-05, 'a': 0.0011246795, 'b': 3.6617035e-05, 'c': 0.00013129057, 'd': 3.6617035e-05, 'e': 0.7226538, 'f': 3.6617035e-05, 'g': 3.6617035e-05, 'h': 3.6617035e-05, 'i': 3.6617035e-05, 'j': 3.6617035e-05, 'k': 3.6617035e-05, 'l': 0.0046695336, 'm': 3.6617035e-05, 'n': 0.09789756, 'o': 3.6617035e-05, 'p': 3.6617035e-05, 'q': 3.6617035e-05, 'r': 0.101768814, 's': 0.07098537, 't': 3.6617035e-05, 'u': 3.6617035e-05, 'v': 3.6617035e-05, 'w': 3.6617035e-05, 'x': 3.6617035e-05, 'y': 3.6617035e-05, 'z': 3.6617035e-05, '_': 3.6617035e-05}
Saved model


Epoch 3/100: 5it [00:02,  1.83it/s]


Epoch 3/100, Loss: 0.69871320
{'<pad>': 0.00048726436, 'a': 0.002332653, 'b': 0.00048726436, 'c': 0.0014737231, 'd': 0.00048726436, 'e': 0.7356715, 'f': 0.00048726436, 'g': 0.00048726436, 'h': 0.00048726436, 'i': 0.00048726436, 'j': 0.00048726436, 'k': 0.00048726436, 'l': 0.023702554, 'm': 0.00048726436, 'n': 0.068097234, 'o': 0.00048726436, 'p': 0.00048726436, 'q': 0.00048726436, 'r': 0.14197421, 's': 0.016515624, 't': 0.00048726436, 'u': 0.00048726436, 'v': 0.00048726436, 'w': 0.00048726436, 'x': 0.00048726436, 'y': 0.00048726436, 'z': 0.00048726436, '_': 0.00048726436}
Saved model


Epoch 4/100: 5it [00:02,  1.83it/s]


Epoch 4/100, Loss: 0.69793641
{'<pad>': 0.0033537713, 'a': 0.00762521, 'b': 0.0033537713, 'c': 0.013242745, 'd': 0.0033537713, 'e': 0.5824233, 'f': 0.0033537713, 'g': 0.0033537713, 'h': 0.0033537713, 'i': 0.0033537713, 'j': 0.0033537713, 'k': 0.0033537713, 'l': 0.06973387, 'm': 0.0033537713, 'n': 0.11625142, 'o': 0.0033537713, 'p': 0.0033537713, 'q': 0.0033537713, 'r': 0.11334169, 's': 0.02695262, 't': 0.0033537713, 'u': 0.0033537713, 'v': 0.0033537713, 'w': 0.0033537713, 'x': 0.0033537713, 'y': 0.0033537713, 'z': 0.0033537713, '_': 0.0033537713}
Saved model


Epoch 5/100: 5it [00:02,  1.76it/s]


Epoch 5/100, Loss: 0.69667251
{'<pad>': 0.026523491, 'a': 0.0334338, 'b': 0.026523491, 'c': 0.053806774, 'd': 0.026523491, 'e': 0.09979557, 'f': 0.026523491, 'g': 0.026523491, 'h': 0.026523491, 'i': 0.026523491, 'j': 0.026523491, 'k': 0.026523491, 'l': 0.0573045, 'm': 0.026523491, 'n': 0.08472365, 'o': 0.026523491, 'p': 0.026523491, 'q': 0.026523491, 'r': 0.06394309, 's': 0.049999293, 't': 0.026523491, 'u': 0.026523491, 'v': 0.026523491, 'w': 0.026523491, 'x': 0.026523491, 'y': 0.026523491, 'z': 0.026523491, '_': 0.026523491}
Saved model


Epoch 6/100: 5it [00:02,  1.77it/s]


Epoch 6/100, Loss: 0.69595579
{'<pad>': 0.014760262, 'a': 0.027546916, 'b': 0.014760262, 'c': 0.10595746, 'd': 0.014760262, 'e': 0.1377247, 'f': 0.014760262, 'g': 0.014760262, 'h': 0.014760262, 'i': 0.014760262, 'j': 0.014760262, 'k': 0.014760262, 'l': 0.09455624, 'm': 0.014760262, 'n': 0.16711615, 'o': 0.014760262, 'p': 0.014760262, 'q': 0.014760262, 'r': 0.10224505, 's': 0.054887876, 't': 0.014760262, 'u': 0.014760262, 'v': 0.014760262, 'w': 0.014760262, 'x': 0.014760262, 'y': 0.014760262, 'z': 0.014760262, '_': 0.014760262}
Saved model


Epoch 7/100: 5it [00:02,  1.74it/s]


Epoch 7/100, Loss: 0.69488786
{'<pad>': 1.1830107e-06, 'a': 9.629933e-05, 'b': 1.1830107e-06, 'c': 0.011364705, 'd': 1.1830107e-06, 'e': 0.00054325804, 'f': 1.1830107e-06, 'g': 1.1830107e-06, 'h': 1.1830107e-06, 'i': 1.1830107e-06, 'j': 1.1830107e-06, 'k': 1.1830107e-06, 'l': 0.23763426, 'm': 1.1830107e-06, 'n': 0.3714381, 'o': 1.1830107e-06, 'p': 2.5763568e-06, 'q': 1.1830107e-06, 'r': 0.37741804, 's': 0.0014791379, 't': 1.1830107e-06, 'u': 1.1830107e-06, 'v': 1.1830107e-06, 'w': 1.1830107e-06, 'x': 1.1830107e-06, 'y': 1.1830107e-06, 'z': 1.1830107e-06, '_': 1.1830107e-06}
Saved model


Epoch 8/100: 5it [00:02,  1.76it/s]


Epoch 8/100, Loss: 0.69414578
{'<pad>': 1.3446787e-07, 'a': 1.4569369e-05, 'b': 1.3446787e-07, 'c': 0.0006706862, 'd': 1.3446787e-07, 'e': 0.005259878, 'f': 1.3446787e-07, 'g': 1.3446787e-07, 'h': 1.3446787e-07, 'i': 1.3446787e-07, 'j': 1.3446787e-07, 'k': 1.3446787e-07, 'l': 0.009890415, 'm': 1.3446787e-07, 'n': 0.32857457, 'o': 1.3446787e-07, 'p': 1.3446787e-07, 'q': 1.3446787e-07, 'r': 0.6552295, 's': 0.00035762202, 't': 1.3446787e-07, 'u': 1.3446787e-07, 'v': 1.3446787e-07, 'w': 1.3446787e-07, 'x': 1.3446787e-07, 'y': 1.3446787e-07, 'z': 1.3446787e-07, '_': 1.3446787e-07}
Saved model


Epoch 9/100: 5it [00:02,  1.88it/s]


Epoch 9/100, Loss: 0.69443089
{'<pad>': 3.254673e-09, 'a': 6.502918e-06, 'b': 2.8079787e-09, 'c': 1.3122858e-05, 'd': 2.8079787e-09, 'e': 3.095196e-07, 'f': 2.8079787e-09, 'g': 2.8079787e-09, 'h': 2.8079787e-09, 'i': 2.8079787e-09, 'j': 2.8079787e-09, 'k': 2.8079787e-09, 'l': 0.009721767, 'm': 2.8079787e-09, 'n': 0.72270095, 'o': 2.8079787e-09, 'p': 2.8079787e-09, 'q': 2.8079787e-09, 'r': 0.2674292, 's': 0.00012808926, 't': 2.8079787e-09, 'u': 2.8079787e-09, 'v': 2.8079787e-09, 'w': 2.8079787e-09, 'x': 2.8079787e-09, 'y': 2.8079787e-09, 'z': 2.8079787e-09, '_': 2.8079787e-09}
Saved model


Epoch 10/100: 5it [00:02,  1.79it/s]


Epoch 10/100, Loss: 0.69541603
{'<pad>': 4.1946387e-05, 'a': 0.0016093218, 'b': 4.1946387e-05, 'c': 0.0023318978, 'd': 4.1946387e-05, 'e': 0.00084645755, 'f': 4.1946387e-05, 'g': 4.1946387e-05, 'h': 4.1946387e-05, 'i': 4.1946387e-05, 'j': 4.1946387e-05, 'k': 4.1946387e-05, 'l': 0.15814717, 'm': 4.1946387e-05, 'n': 0.18342279, 'o': 4.1946387e-05, 'p': 4.1946387e-05, 'q': 4.1946387e-05, 'r': 0.6441052, 's': 0.008656317, 't': 4.1946387e-05, 'u': 4.1946387e-05, 'v': 4.1946387e-05, 'w': 4.1946387e-05, 'x': 4.1946387e-05, 'y': 4.1946387e-05, 'z': 4.1946387e-05, '_': 4.1946387e-05}
Saved model


Epoch 11/100: 5it [00:02,  1.84it/s]


Epoch 11/100, Loss: 0.69424724
{'<pad>': 0.00031786875, 'a': 0.0074644224, 'b': 0.00031786875, 'c': 0.0056147203, 'd': 0.00031786875, 'e': 0.0019656944, 'f': 0.00031786875, 'g': 0.00031786875, 'h': 0.00031786875, 'i': 0.00031786875, 'j': 0.00031786875, 'k': 0.00031786875, 'l': 0.3458148, 'm': 0.00031786875, 'n': 0.27490577, 'o': 0.00031786875, 'p': 0.00031786875, 'q': 0.00031786875, 'r': 0.33086535, 's': 0.026693942, 't': 0.00031786875, 'u': 0.00031786875, 'v': 0.00031786875, 'w': 0.00031786875, 'x': 0.00031786875, 'y': 0.00031786875, 'z': 0.00031786875, '_': 0.00031786875}
Saved model


Epoch 12/100: 5it [00:02,  1.82it/s]


Epoch 12/100, Loss: 0.69406026
{'<pad>': 0.0020767734, 'a': 0.042494982, 'b': 0.0020767734, 'c': 0.013871061, 'd': 0.0020767734, 'e': 0.0069086268, 'f': 0.0020767734, 'g': 0.0020767734, 'h': 0.0020767734, 'i': 0.0020767734, 'j': 0.0020767734, 'k': 0.0020767734, 'l': 0.116599984, 'm': 0.0020767734, 'n': 0.41025585, 'o': 0.0020767734, 'p': 0.0020767734, 'q': 0.0020767734, 'r': 0.29617712, 's': 0.07008017, 't': 0.0020767734, 'u': 0.0020767734, 'v': 0.0020767734, 'w': 0.0020767734, 'x': 0.0020767734, 'y': 0.0020767734, 'z': 0.0020767734, '_': 0.0020767734}
Saved model


Epoch 13/100: 5it [00:02,  1.74it/s]


Epoch 13/100, Loss: 0.69414757
{'<pad>': 0.0010303074, 'a': 0.08179682, 'b': 0.0010303074, 'c': 0.0070659947, 'd': 0.0010303074, 'e': 0.003323877, 'f': 0.0010303074, 'g': 0.0010303074, 'h': 0.0010303074, 'i': 0.0010303074, 'j': 0.0010303074, 'k': 0.0010303074, 'l': 0.0506912, 'm': 0.0010303074, 'n': 0.40850988, 'o': 0.0010303074, 'p': 0.0010303074, 'q': 0.0010303074, 'r': 0.37605602, 's': 0.050919633, 't': 0.0010303074, 'u': 0.0010303074, 'v': 0.0010303074, 'w': 0.0010303074, 'x': 0.0010303074, 'y': 0.0010303074, 'z': 0.0010303074, '_': 0.0010303074}
Saved model


Epoch 14/100: 5it [00:02,  1.77it/s]


Epoch 14/100, Loss: 0.69390309
{'<pad>': 0.00019718276, 'a': 0.31198654, 'b': 0.00019718276, 'c': 0.0013477611, 'd': 0.00019718276, 'e': 0.0005212587, 'f': 0.00019718276, 'g': 0.00019718276, 'h': 0.00019718276, 'i': 0.00019718276, 'j': 0.00019718276, 'k': 0.00019718276, 'l': 0.017473007, 'm': 0.00019718276, 'n': 0.3134519, 'o': 0.00019718276, 'p': 0.00019718276, 'q': 0.00019718276, 'r': 0.32833675, 's': 0.022741823, 't': 0.00019718276, 'u': 0.00019718276, 'v': 0.00019718276, 'w': 0.00019718276, 'x': 0.00019718276, 'y': 0.00019718276, 'z': 0.00019718276, '_': 0.00019718276}
Saved model


Epoch 15/100: 5it [00:02,  1.79it/s]


Epoch 15/100, Loss: 0.69374218
{'<pad>': 6.0050814e-05, 'a': 0.32884118, 'b': 6.0050814e-05, 'c': 0.00043648708, 'd': 6.0050814e-05, 'e': 0.00017923536, 'f': 6.0050814e-05, 'g': 6.0050814e-05, 'h': 6.0050814e-05, 'i': 6.0050814e-05, 'j': 6.0050814e-05, 'k': 6.0050814e-05, 'l': 0.010071933, 'm': 6.0050814e-05, 'n': 0.28877676, 'o': 6.0050814e-05, 'p': 6.0050814e-05, 'q': 6.0050814e-05, 'r': 0.35868073, 's': 0.011752648, 't': 6.0050814e-05, 'u': 6.0050814e-05, 'v': 6.0050814e-05, 'w': 6.0050814e-05, 'x': 6.0050814e-05, 'y': 6.0050814e-05, 'z': 6.0050814e-05, '_': 6.0050814e-05}
Saved model


Epoch 16/100: 5it [00:02,  1.81it/s]


Epoch 16/100, Loss: 0.69379361
{'<pad>': 6.863594e-05, 'a': 0.23944637, 'b': 6.863594e-05, 'c': 0.0004568661, 'd': 6.863594e-05, 'e': 0.00020533414, 'f': 6.863594e-05, 'g': 6.863594e-05, 'h': 6.863594e-05, 'i': 6.863594e-05, 'j': 6.863594e-05, 'k': 6.863594e-05, 'l': 0.012788845, 'm': 6.863594e-05, 'n': 0.37879652, 'o': 6.863594e-05, 'p': 6.863594e-05, 'q': 6.863594e-05, 'r': 0.3547594, 's': 0.012105384, 't': 6.863594e-05, 'u': 6.863594e-05, 'v': 6.863594e-05, 'w': 6.863594e-05, 'x': 6.863594e-05, 'y': 6.863594e-05, 'z': 6.863594e-05, '_': 6.863594e-05}
Saved model


Epoch 17/100: 5it [00:02,  1.78it/s]


Epoch 17/100, Loss: 0.69374065
{'<pad>': 8.1888975e-05, 'a': 0.43032625, 'b': 8.1888975e-05, 'c': 0.00040945268, 'd': 8.1888975e-05, 'e': 0.00015422836, 'f': 8.1888975e-05, 'g': 8.1888975e-05, 'h': 8.1888975e-05, 'i': 8.1888975e-05, 'j': 8.1888975e-05, 'k': 8.1888975e-05, 'l': 0.013770085, 'm': 8.1888975e-05, 'n': 0.31308514, 'o': 8.1888975e-05, 'p': 8.1888975e-05, 'q': 8.1888975e-05, 'r': 0.22805184, 's': 0.012483417, 't': 8.1888975e-05, 'u': 8.1888975e-05, 'v': 8.1888975e-05, 'w': 8.1888975e-05, 'x': 8.1888975e-05, 'y': 8.1888975e-05, 'z': 8.1888975e-05, '_': 8.1888975e-05}
Saved model


Epoch 18/100: 5it [00:02,  1.74it/s]


Epoch 18/100, Loss: 0.69372114
{'<pad>': 0.000112951064, 'a': 0.2565288, 'b': 0.000112951064, 'c': 0.00058264035, 'd': 0.000112951064, 'e': 0.00028483468, 'f': 0.000112951064, 'g': 0.000112951064, 'h': 0.000112951064, 'i': 0.000112951064, 'j': 0.000112951064, 'k': 0.000112951064, 'l': 0.021833828, 'm': 0.000112951064, 'n': 0.3424815, 'o': 0.000112951064, 'p': 0.000112951064, 'q': 0.000112951064, 'r': 0.36046562, 's': 0.015450776, 't': 0.000112951064, 'u': 0.000112951064, 'v': 0.000112951064, 'w': 0.000112951064, 'x': 0.000112951064, 'y': 0.000112951064, 'z': 0.000112951064, '_': 0.000112951064}
Saved model


Epoch 19/100: 5it [00:02,  1.85it/s]


Epoch 19/100, Loss: 0.69371603
{'<pad>': 0.00013435884, 'a': 0.34914255, 'b': 0.00013435884, 'c': 0.00057818193, 'd': 0.00013435884, 'e': 0.00026106247, 'f': 0.00013435884, 'g': 0.00013435884, 'h': 0.00013435884, 'i': 0.00013435884, 'j': 0.00013435884, 'k': 0.00013435884, 'l': 0.027809337, 'm': 0.00013435884, 'n': 0.28077397, 'o': 0.00013435884, 'p': 0.00013435884, 'q': 0.00013435884, 'r': 0.32172894, 's': 0.016884357, 't': 0.00013435884, 'u': 0.00013435884, 'v': 0.00013435884, 'w': 0.00013435884, 'x': 0.00013435884, 'y': 0.00013435884, 'z': 0.00013435884, '_': 0.00013435884}
Saved model


Epoch 20/100: 5it [00:02,  1.81it/s]


Epoch 20/100, Loss: 0.69362433
{'<pad>': 0.00011967887, 'a': 0.31521356, 'b': 0.00011967887, 'c': 0.00050085364, 'd': 0.00011967887, 'e': 0.0002349623, 'f': 0.00011967887, 'g': 0.00011967887, 'h': 0.00011967887, 'i': 0.00011967887, 'j': 0.00011967887, 'k': 0.00011967887, 'l': 0.03673875, 'm': 0.00011967887, 'n': 0.26609844, 'o': 0.00011967887, 'p': 0.00011967887, 'q': 0.00011967887, 'r': 0.36255452, 's': 0.016145704, 't': 0.00011967887, 'u': 0.00011967887, 'v': 0.00011967887, 'w': 0.00011967887, 'x': 0.00011967887, 'y': 0.00011967887, 'z': 0.00011967887, '_': 0.00011967887}
Saved model


Epoch 21/100: 5it [00:02,  1.86it/s]


Epoch 21/100, Loss: 0.69346303
{'<pad>': 9.719881e-05, 'a': 0.32488918, 'b': 9.719881e-05, 'c': 0.00036769098, 'd': 9.719881e-05, 'e': 0.00016390323, 'f': 9.719881e-05, 'g': 9.719881e-05, 'h': 9.719881e-05, 'i': 9.719881e-05, 'j': 9.719881e-05, 'k': 9.719881e-05, 'l': 0.044692572, 'm': 9.719881e-05, 'n': 0.2375696, 'o': 9.719881e-05, 'p': 9.719881e-05, 'q': 9.719881e-05, 'r': 0.37660465, 's': 0.013671199, 't': 9.719881e-05, 'u': 9.719881e-05, 'v': 9.719881e-05, 'w': 9.719881e-05, 'x': 9.719881e-05, 'y': 9.719881e-05, 'z': 9.719881e-05, '_': 9.719881e-05}
Saved model


Epoch 22/100: 5it [00:02,  1.81it/s]


Epoch 22/100, Loss: 0.69299324
{'<pad>': 2.0535525e-05, 'a': 0.14136936, 'b': 2.0535525e-05, 'c': 0.000105672196, 'd': 2.0535525e-05, 'e': 6.945528e-05, 'f': 2.0535525e-05, 'g': 2.0535525e-05, 'h': 2.0535525e-05, 'i': 2.0535525e-05, 'j': 2.0535525e-05, 'k': 2.0535525e-05, 'l': 0.031218374, 'm': 2.0535525e-05, 'n': 0.14450546, 'o': 2.0535525e-05, 'p': 2.0535525e-05, 'q': 2.0535525e-05, 'r': 0.67720306, 's': 0.005018132, 't': 9.968898e-05, 'u': 2.0535525e-05, 'v': 2.0535525e-05, 'w': 2.0535525e-05, 'x': 2.0535525e-05, 'y': 2.0535525e-05, 'z': 2.0535525e-05, '_': 2.0535525e-05}
Saved model


Epoch 23/100: 5it [00:02,  1.85it/s]


Epoch 23/100, Loss: 0.69262885
{'<pad>': 1.5669033e-05, 'a': 0.17695737, 'b': 1.5669033e-05, 'c': 6.5214954e-05, 'd': 1.5669033e-05, 'e': 3.7356876e-05, 'f': 1.5669033e-05, 'g': 1.5669033e-05, 'h': 1.5669033e-05, 'i': 1.5669033e-05, 'j': 1.5669033e-05, 'k': 1.5669033e-05, 'l': 0.026729386, 'm': 1.5669033e-05, 'n': 0.103721365, 'o': 1.5669033e-05, 'p': 1.5669033e-05, 'q': 1.5669033e-05, 'r': 0.68882436, 's': 0.003286819, 't': 6.472108e-05, 'u': 1.5669033e-05, 'v': 1.5669033e-05, 'w': 1.5669033e-05, 'x': 1.5669033e-05, 'y': 1.5669033e-05, 'z': 1.5669033e-05, '_': 1.5669033e-05}
Saved model


Epoch 24/100: 5it [00:02,  1.84it/s]


Epoch 24/100, Loss: 0.69246142
{'<pad>': 2.9506928e-05, 'a': 0.26940516, 'b': 2.9506928e-05, 'c': 7.672815e-05, 'd': 2.9506928e-05, 'e': 4.273796e-05, 'f': 2.9506928e-05, 'g': 2.9506928e-05, 'h': 2.9506928e-05, 'i': 2.9506928e-05, 'j': 2.9506928e-05, 'k': 2.9506928e-05, 'l': 0.037680116, 'm': 2.9506928e-05, 'n': 0.121327855, 'o': 2.9506928e-05, 'p': 2.9506928e-05, 'q': 2.9506928e-05, 'r': 0.56765175, 's': 0.0030827418, 't': 0.00014271075, 'u': 2.9506928e-05, 'v': 2.9506928e-05, 'w': 2.9506928e-05, 'x': 2.9506928e-05, 'y': 2.9506928e-05, 'z': 2.9506928e-05, '_': 2.9506928e-05}
Saved model


Epoch 25/100: 5it [00:02,  1.77it/s]


Epoch 25/100, Loss: 0.69218004
{'<pad>': 2.059027e-05, 'a': 0.09245322, 'b': 2.059027e-05, 'c': 6.553273e-05, 'd': 2.059027e-05, 'e': 4.9561142e-05, 'f': 2.059027e-05, 'g': 2.059027e-05, 'h': 2.059027e-05, 'i': 2.059027e-05, 'j': 2.059027e-05, 'k': 2.059027e-05, 'l': 0.0324693, 'm': 2.059027e-05, 'n': 0.12611176, 'o': 2.059027e-05, 'p': 2.059027e-05, 'q': 2.059027e-05, 'r': 0.74617386, 's': 0.0021477113, 't': 0.000117270334, 'u': 2.059027e-05, 'v': 2.059027e-05, 'w': 2.059027e-05, 'x': 2.059027e-05, 'y': 2.059027e-05, 'z': 2.059027e-05, '_': 2.059027e-05}
Saved model


Epoch 26/100: 5it [00:02,  1.82it/s]


Epoch 26/100, Loss: 0.69215299
{'<pad>': 1.5314457e-05, 'a': 0.045255404, 'b': 1.5314457e-05, 'c': 4.6778605e-05, 'd': 1.5314457e-05, 'e': 3.9875224e-05, 'f': 1.5314457e-05, 'g': 1.5314457e-05, 'h': 1.5314457e-05, 'i': 1.5314457e-05, 'j': 1.5314457e-05, 'k': 1.5314457e-05, 'l': 0.025309002, 'm': 1.5314457e-05, 'n': 0.15649727, 'o': 1.5314457e-05, 'p': 1.5314457e-05, 'q': 1.5314457e-05, 'r': 0.7709158, 's': 0.0015514136, 't': 7.807778e-05, 'u': 1.5314457e-05, 'v': 1.5314457e-05, 'w': 1.5314457e-05, 'x': 1.5314457e-05, 'y': 1.5314457e-05, 'z': 1.5314457e-05, '_': 1.5314457e-05}
Saved model


Epoch 27/100: 5it [00:02,  1.82it/s]


Epoch 27/100, Loss: 0.69250888
{'<pad>': 0.00014167225, 'a': 0.725335, 'b': 0.00014167225, 'c': 0.00014167225, 'd': 0.00014167225, 'e': 0.00014167225, 'f': 0.00014167225, 'g': 0.00014167225, 'h': 0.00014167225, 'i': 0.00014167225, 'j': 0.00014167225, 'k': 0.00014167225, 'l': 0.014689855, 'm': 0.00014167225, 'n': 0.21868415, 'o': 0.00014167225, 'p': 0.00014167225, 'q': 0.00014167225, 'r': 0.03529395, 's': 0.002470092, 't': 0.00041041616, 'u': 0.00014167225, 'v': 0.00014167225, 'w': 0.00014167225, 'x': 0.00014167225, 'y': 0.00014167225, 'z': 0.00014167225, '_': 0.00014167225}
Saved model


Epoch 28/100: 5it [00:02,  1.83it/s]


Epoch 28/100, Loss: 0.69241653
{'<pad>': 0.00022640715, 'a': 0.5984545, 'b': 0.00022640715, 'c': 0.00022640715, 'd': 0.00022640715, 'e': 0.00022640715, 'f': 0.00022640715, 'g': 0.00022640715, 'h': 0.00022640715, 'i': 0.00022640715, 'j': 0.00022640715, 'k': 0.00022640715, 'l': 0.018048104, 'm': 0.00022640715, 'n': 0.3296839, 'o': 0.00022640715, 'p': 0.00022640715, 'q': 0.00022640715, 'r': 0.045114364, 's': 0.0033603306, 't': 0.0003577432, 'u': 0.00022640715, 'v': 0.00022640715, 'w': 0.00022640715, 'x': 0.00022640715, 'y': 0.00022640715, 'z': 0.00022640715, '_': 0.00022640715}
Saved model


Epoch 29/100: 5it [00:02,  1.86it/s]


Epoch 29/100, Loss: 0.69239784
{'<pad>': 0.00047128892, 'a': 0.50287604, 'b': 0.00047128892, 'c': 0.00047128892, 'd': 0.00047128892, 'e': 0.00047128892, 'f': 0.00047128892, 'g': 0.00047128892, 'h': 0.00047128892, 'i': 0.00047128892, 'j': 0.00047128892, 'k': 0.00047128892, 'l': 0.037673805, 'm': 0.00047128892, 'n': 0.32483634, 'o': 0.00047128892, 'p': 0.00047128892, 'q': 0.00047128892, 'r': 0.11610316, 's': 0.007620994, 't': 0.0005212466, 'u': 0.00047128892, 'v': 0.00047128892, 'w': 0.00047128892, 'x': 0.00047128892, 'y': 0.00047128892, 'z': 0.00047128892, '_': 0.00047128892}
Saved model


Epoch 30/100: 5it [00:02,  1.84it/s]


Epoch 30/100, Loss: 0.69235342
{'<pad>': 0.00014732027, 'a': 0.9573579, 'b': 0.00014732027, 'c': 0.00014732027, 'd': 0.00014732027, 'e': 0.00014732027, 'f': 0.00014732027, 'g': 0.00014732027, 'h': 0.00014732027, 'i': 0.00014732027, 'j': 0.00014732027, 'k': 0.00014732027, 'l': 0.0026549902, 'm': 0.00014732027, 'n': 0.03317964, 'o': 0.00014732027, 'p': 0.00014732027, 'q': 0.00014732027, 'r': 0.0019620387, 's': 0.0013437216, 't': 0.00026064515, 'u': 0.00014732027, 'v': 0.00014732027, 'w': 0.00014732027, 'x': 0.00014732027, 'y': 0.00014732027, 'z': 0.00014732027, '_': 0.00014732027}
Saved model


Epoch 31/100: 5it [00:02,  1.75it/s]


Epoch 31/100, Loss: 0.69216318
{'<pad>': 0.0007274574, 'a': 0.2723506, 'b': 0.0007274574, 'c': 0.0007274574, 'd': 0.0007274574, 'e': 0.0007274574, 'f': 0.0007274574, 'g': 0.0007274574, 'h': 0.0007274574, 'i': 0.0007274574, 'j': 0.0007274574, 'k': 0.0007274574, 'l': 0.07520071, 'm': 0.0007274574, 'n': 0.3825129, 'o': 0.0007274574, 'p': 0.0007274574, 'q': 0.0007274574, 'r': 0.23460354, 's': 0.01645572, 't': 0.002872448, 'u': 0.0007274574, 'v': 0.0007274574, 'w': 0.0007274574, 'x': 0.0007274574, 'y': 0.0007274574, 'z': 0.0007274574, '_': 0.0007274574}
Saved model


Epoch 32/100: 5it [00:02,  1.75it/s]


Epoch 32/100, Loss: 0.69226311
{'<pad>': 0.0005226736, 'a': 0.76232606, 'b': 0.0005226736, 'c': 0.0005226736, 'd': 0.0005226736, 'e': 0.0005226736, 'f': 0.0005226736, 'g': 0.0005226736, 'h': 0.0005226736, 'i': 0.0005226736, 'j': 0.0005226736, 'k': 0.0005226736, 'l': 0.018799882, 'm': 0.0005226736, 'n': 0.17325509, 'o': 0.0005226736, 'p': 0.0005226736, 'q': 0.0005226736, 'r': 0.024691384, 's': 0.008361813, 't': 0.0010670073, 'u': 0.0005226736, 'v': 0.0005226736, 'w': 0.0005226736, 'x': 0.0005226736, 'y': 0.0005226736, 'z': 0.0005226736, '_': 0.0005226736}
Saved model


Epoch 33/100: 5it [00:03,  1.66it/s]


Epoch 33/100, Loss: 0.69202072
{'<pad>': 0.0002675741, 'a': 0.53005326, 'b': 0.0002675741, 'c': 0.0002675741, 'd': 0.0002675741, 'e': 0.0002675741, 'f': 0.0002675741, 'g': 0.0002675741, 'h': 0.0002675741, 'i': 0.0002675741, 'j': 0.0002675741, 'k': 0.0002675741, 'l': 0.029726757, 'm': 0.0002675741, 'n': 0.3106784, 'o': 0.0002675741, 'p': 0.0002675741, 'q': 0.0002675741, 'r': 0.11422229, 's': 0.009165035, 't': 0.0002675741, 'u': 0.0002675741, 'v': 0.0002675741, 'w': 0.0002675741, 'x': 0.0002675741, 'y': 0.0002675741, 'z': 0.0002675741, '_': 0.0002675741}
Saved model


Epoch 34/100: 5it [00:02,  1.67it/s]


Epoch 34/100, Loss: 0.69208370
{'<pad>': 9.665286e-05, 'a': 0.7166744, 'b': 9.665286e-05, 'c': 9.665286e-05, 'd': 9.665286e-05, 'e': 9.665286e-05, 'f': 9.665286e-05, 'g': 9.665286e-05, 'h': 9.665286e-05, 'i': 9.665286e-05, 'j': 9.665286e-05, 'k': 9.665286e-05, 'l': 0.009837184, 'm': 9.665286e-05, 'n': 0.22299987, 'o': 9.665286e-05, 'p': 9.665286e-05, 'q': 9.665286e-05, 'r': 0.043795932, 's': 0.0044696946, 't': 9.665286e-05, 'u': 9.665286e-05, 'v': 9.665286e-05, 'w': 9.665286e-05, 'x': 9.665286e-05, 'y': 9.665286e-05, 'z': 9.665286e-05, '_': 9.665286e-05}
Saved model


Epoch 35/100: 5it [00:02,  1.72it/s]


Epoch 35/100, Loss: 0.69195850
{'<pad>': 4.8289043e-05, 'a': 0.7421362, 'b': 4.8289043e-05, 'c': 4.8289043e-05, 'd': 4.8289043e-05, 'e': 4.8289043e-05, 'f': 4.8289043e-05, 'g': 4.8289043e-05, 'h': 4.8289043e-05, 'i': 4.8289043e-05, 'j': 4.8289043e-05, 'k': 4.8289043e-05, 'l': 0.0055991383, 'm': 4.8289043e-05, 'n': 0.2130559, 'o': 4.8289043e-05, 'p': 4.8289043e-05, 'q': 4.8289043e-05, 'r': 0.034969863, 's': 0.0031283083, 't': 4.8289043e-05, 'u': 4.8289043e-05, 'v': 4.8289043e-05, 'w': 4.8289043e-05, 'x': 4.8289043e-05, 'y': 4.8289043e-05, 'z': 4.8289043e-05, '_': 4.8289043e-05}
Saved model


Epoch 36/100: 5it [00:02,  1.80it/s]


Epoch 36/100, Loss: 0.69191674
{'<pad>': 3.24448e-05, 'a': 0.60083663, 'b': 3.24448e-05, 'c': 3.24448e-05, 'd': 3.24448e-05, 'e': 3.24448e-05, 'f': 3.24448e-05, 'g': 3.24448e-05, 'h': 3.24448e-05, 'i': 3.24448e-05, 'j': 3.24448e-05, 'k': 3.24448e-05, 'l': 0.005920856, 'm': 3.24448e-05, 'n': 0.33084786, 'o': 3.24448e-05, 'p': 3.24448e-05, 'q': 3.24448e-05, 'r': 0.058404244, 's': 0.0032441386, 't': 3.24448e-05, 'u': 3.24448e-05, 'v': 3.24448e-05, 'w': 3.24448e-05, 'x': 3.24448e-05, 'y': 3.24448e-05, 'z': 3.24448e-05, '_': 3.24448e-05}
Saved model


Epoch 37/100: 5it [00:02,  1.73it/s]


Epoch 37/100, Loss: 0.69190630
{'<pad>': 2.0542137e-05, 'a': 0.42345706, 'b': 2.0542137e-05, 'c': 2.0542137e-05, 'd': 2.0542137e-05, 'e': 2.0542137e-05, 'f': 2.0542137e-05, 'g': 2.0542137e-05, 'h': 2.0542137e-05, 'i': 2.0542137e-05, 'j': 2.0542137e-05, 'k': 2.0542137e-05, 'l': 0.0070621045, 'm': 2.0542137e-05, 'n': 0.4540926, 'o': 2.0542137e-05, 'p': 2.0542137e-05, 'q': 2.0542137e-05, 'r': 0.11147955, 's': 0.003436234, 't': 2.0542137e-05, 'u': 2.0542137e-05, 'v': 2.0542137e-05, 'w': 2.0542137e-05, 'x': 2.0542137e-05, 'y': 2.0542137e-05, 'z': 2.0542137e-05, '_': 2.0542137e-05}
Saved model


Epoch 38/100: 5it [00:02,  1.76it/s]


Epoch 38/100, Loss: 0.69199154
{'<pad>': 1.463338e-05, 'a': 0.72926646, 'b': 1.463338e-05, 'c': 1.463338e-05, 'd': 1.463338e-05, 'e': 1.463338e-05, 'f': 1.463338e-05, 'g': 1.463338e-05, 'h': 1.463338e-05, 'i': 1.463338e-05, 'j': 1.463338e-05, 'k': 1.463338e-05, 'l': 0.0019831373, 'm': 1.463338e-05, 'n': 0.24537875, 'o': 1.463338e-05, 'p': 1.463338e-05, 'q': 1.463338e-05, 'r': 0.020838384, 's': 0.0021967369, 't': 1.463338e-05, 'u': 1.463338e-05, 'v': 1.463338e-05, 'w': 1.463338e-05, 'x': 1.463338e-05, 'y': 1.463338e-05, 'z': 1.463338e-05, '_': 1.463338e-05}
Saved model


Epoch 39/100: 5it [00:02,  1.73it/s]


Epoch 39/100, Loss: 0.69194214
{'<pad>': 6.1116352e-06, 'a': 0.92963684, 'b': 6.1116352e-06, 'c': 6.1116352e-06, 'd': 6.1116352e-06, 'e': 6.1116352e-06, 'f': 6.1116352e-06, 'g': 6.1116352e-06, 'h': 6.1116352e-06, 'i': 6.1116352e-06, 'j': 6.1116352e-06, 'k': 6.1116352e-06, 'l': 0.0002664207, 'm': 6.1116352e-06, 'n': 0.067723945, 'o': 6.1116352e-06, 'p': 6.1116352e-06, 'q': 6.1116352e-06, 'r': 0.0015074025, 's': 0.00072499714, 't': 6.1116352e-06, 'u': 6.1116352e-06, 'v': 6.1116352e-06, 'w': 6.1116352e-06, 'x': 6.1116352e-06, 'y': 6.1116352e-06, 'z': 6.1116352e-06, '_': 6.1116352e-06}
Saved model


Epoch 40/100: 5it [00:02,  1.78it/s]


Epoch 40/100, Loss: 0.69181881
{'<pad>': 6.4883634e-06, 'a': 0.9343695, 'b': 6.4883634e-06, 'c': 6.4883634e-06, 'd': 6.4883634e-06, 'e': 6.4883634e-06, 'f': 6.4883634e-06, 'g': 6.4883634e-06, 'h': 6.4883634e-06, 'i': 6.4883634e-06, 'j': 6.4883634e-06, 'k': 6.4883634e-06, 'l': 0.00023237776, 'm': 6.4883634e-06, 'n': 0.06323337, 'o': 6.4883634e-06, 'p': 6.4883634e-06, 'q': 6.4883634e-06, 'r': 0.0011501038, 's': 0.000865423, 't': 6.4883634e-06, 'u': 6.4883634e-06, 'v': 6.4883634e-06, 'w': 6.4883634e-06, 'x': 6.4883634e-06, 'y': 6.4883634e-06, 'z': 6.4883634e-06, '_': 6.4883634e-06}
Saved model


Epoch 41/100: 5it [00:03,  1.65it/s]


Epoch 41/100, Loss: 0.69188755
{'<pad>': 1.6604015e-05, 'a': 0.24790242, 'b': 1.6604015e-05, 'c': 1.6604015e-05, 'd': 1.6604015e-05, 'e': 1.6604015e-05, 'f': 1.6604015e-05, 'g': 1.6604015e-05, 'h': 1.6604015e-05, 'i': 1.6604015e-05, 'j': 1.6604015e-05, 'k': 1.6604015e-05, 'l': 0.00432697, 'm': 1.6604015e-05, 'n': 0.6853473, 'o': 1.6604015e-05, 'p': 1.6604015e-05, 'q': 1.6604015e-05, 'r': 0.055529837, 's': 0.006511553, 't': 1.6604015e-05, 'u': 1.6604015e-05, 'v': 1.6604015e-05, 'w': 1.6604015e-05, 'x': 1.6604015e-05, 'y': 1.6604015e-05, 'z': 1.6604015e-05, '_': 1.6604015e-05}
Saved model


Epoch 42/100: 5it [00:02,  1.74it/s]


Epoch 42/100, Loss: 0.69204533
{'<pad>': 2.8014261e-05, 'a': 0.72375834, 'b': 2.8014261e-05, 'c': 2.8014261e-05, 'd': 2.8014261e-05, 'e': 2.8014261e-05, 'f': 2.8014261e-05, 'g': 2.8014261e-05, 'h': 2.8014261e-05, 'i': 2.8014261e-05, 'j': 2.8014261e-05, 'k': 2.8014261e-05, 'l': 0.0016474839, 'm': 2.8014261e-05, 'n': 0.2581398, 'o': 2.8014261e-05, 'p': 2.8014261e-05, 'q': 2.8014261e-05, 'r': 0.0076907934, 's': 0.008119275, 't': 2.8014261e-05, 'u': 2.8014261e-05, 'v': 2.8014261e-05, 'w': 2.8014261e-05, 'x': 2.8014261e-05, 'y': 2.8014261e-05, 'z': 2.8014261e-05, '_': 2.8014261e-05}
Saved model


Epoch 43/100: 5it [00:02,  1.77it/s]


Epoch 43/100, Loss: 0.69176023
{'<pad>': 2.778332e-05, 'a': 0.903796, 'b': 2.778332e-05, 'c': 2.778332e-05, 'd': 2.778332e-05, 'e': 2.778332e-05, 'f': 2.778332e-05, 'g': 2.778332e-05, 'h': 2.778332e-05, 'i': 2.778332e-05, 'j': 2.778332e-05, 'k': 2.778332e-05, 'l': 0.0008028376, 'm': 2.778332e-05, 'n': 0.085340716, 'o': 2.778332e-05, 'p': 2.778332e-05, 'q': 2.778332e-05, 'r': 0.0023582762, 's': 0.0070630847, 't': 2.778332e-05, 'u': 2.778332e-05, 'v': 2.778332e-05, 'w': 2.778332e-05, 'x': 2.778332e-05, 'y': 2.778332e-05, 'z': 2.778332e-05, '_': 2.778332e-05}
Saved model


Epoch 44/100: 5it [00:02,  1.80it/s]


Epoch 44/100, Loss: 0.69166183
{'<pad>': 2.4417568e-05, 'a': 0.8254129, 'b': 2.4417568e-05, 'c': 2.4417568e-05, 'd': 2.4417568e-05, 'e': 2.4417568e-05, 'f': 2.4417568e-05, 'g': 2.4417568e-05, 'h': 2.4417568e-05, 'i': 2.4417568e-05, 'j': 2.4417568e-05, 'k': 2.4417568e-05, 'l': 0.00072507915, 'm': 2.4417568e-05, 'n': 0.15457751, 'o': 2.4417568e-05, 'p': 2.4417568e-05, 'q': 2.4417568e-05, 'r': 0.0020251865, 's': 0.016697804, 't': 2.4417568e-05, 'u': 2.4417568e-05, 'v': 2.4417568e-05, 'w': 2.4417568e-05, 'x': 2.4417568e-05, 'y': 2.4417568e-05, 'z': 2.4417568e-05, '_': 2.4417568e-05}
Saved model


Epoch 45/100: 5it [00:02,  1.84it/s]


Epoch 45/100, Loss: 0.69156061
{'<pad>': 5.8486457e-05, 'a': 0.8048304, 'b': 5.8486457e-05, 'c': 5.8486457e-05, 'd': 5.8486457e-05, 'e': 5.8486457e-05, 'f': 5.8486457e-05, 'g': 5.8486457e-05, 'h': 5.8486457e-05, 'i': 8.286553e-05, 'j': 5.8486457e-05, 'k': 5.8486457e-05, 'l': 0.0013108545, 'm': 5.8486457e-05, 'n': 0.16724612, 'o': 5.8486457e-05, 'p': 5.8486457e-05, 'q': 5.8486457e-05, 'r': 0.0024561838, 's': 0.022786753, 't': 5.8486457e-05, 'u': 5.8486457e-05, 'v': 5.8486457e-05, 'w': 5.8486457e-05, 'x': 5.8486457e-05, 'y': 5.8486457e-05, 'z': 5.8486457e-05, '_': 5.8486457e-05}
Saved model


Epoch 46/100: 5it [00:02,  1.81it/s]


Epoch 46/100, Loss: 0.69150550
{'<pad>': 7.292874e-06, 'a': 0.87320465, 'b': 7.292874e-06, 'c': 7.292874e-06, 'd': 7.292874e-06, 'e': 7.292874e-06, 'f': 7.292874e-06, 'g': 7.292874e-06, 'h': 7.292874e-06, 'i': 0.0006568885, 'j': 7.292874e-06, 'k': 7.292874e-06, 'l': 0.0001279359, 'm': 7.292874e-06, 'n': 0.11936424, 'o': 7.292874e-06, 'p': 7.292874e-06, 'q': 7.292874e-06, 'r': 0.00017138077, 's': 0.006314437, 't': 7.292874e-06, 'u': 7.292874e-06, 'v': 7.292874e-06, 'w': 7.292874e-06, 'x': 7.292874e-06, 'y': 7.292874e-06, 'z': 7.292874e-06, '_': 7.292874e-06}
Saved model


Epoch 47/100: 5it [00:02,  1.77it/s]


Epoch 47/100, Loss: 0.69139734
{'<pad>': 9.833637e-06, 'a': 0.93170035, 'b': 9.833637e-06, 'c': 9.833637e-06, 'd': 9.833637e-06, 'e': 9.833637e-06, 'f': 9.833637e-06, 'g': 9.833637e-06, 'h': 9.833637e-06, 'i': 0.0332944, 'j': 9.833637e-06, 'k': 9.833637e-06, 'l': 0.00012814069, 'm': 9.833637e-06, 'n': 0.03303095, 'o': 9.833637e-06, 'p': 9.833637e-06, 'q': 9.833637e-06, 'r': 0.0001110684, 's': 0.0015184813, 't': 9.833637e-06, 'u': 9.833637e-06, 'v': 9.833637e-06, 'w': 9.833637e-06, 'x': 9.833637e-06, 'y': 9.833637e-06, 'z': 9.833637e-06, '_': 9.833637e-06}
Saved model


Epoch 48/100: 5it [00:02,  1.79it/s]


Epoch 48/100, Loss: 0.69122031
{'<pad>': 4.696782e-05, 'a': 0.8296305, 'b': 4.696782e-05, 'c': 4.696782e-05, 'd': 4.696782e-05, 'e': 4.696782e-05, 'f': 4.696782e-05, 'g': 4.696782e-05, 'h': 4.696782e-05, 'i': 0.021390125, 'j': 4.696782e-05, 'k': 4.696782e-05, 'l': 0.0012596715, 'm': 4.696782e-05, 'n': 0.14043018, 'o': 4.696782e-05, 'p': 4.696782e-05, 'q': 4.696782e-05, 'r': 0.0021295906, 's': 0.0041264435, 't': 4.696782e-05, 'u': 4.696782e-05, 'v': 4.696782e-05, 'w': 4.696782e-05, 'x': 4.696782e-05, 'y': 4.696782e-05, 'z': 4.696782e-05, '_': 4.696782e-05}
Saved model


Epoch 49/100: 5it [00:02,  1.77it/s]


Epoch 49/100, Loss: 0.69113905
{'<pad>': 2.60707e-06, 'a': 0.994136, 'b': 2.60707e-06, 'c': 2.60707e-06, 'd': 2.60707e-06, 'e': 2.60707e-06, 'f': 2.60707e-06, 'g': 2.60707e-06, 'h': 2.60707e-06, 'i': 0.004255333, 'j': 2.60707e-06, 'k': 2.60707e-06, 'l': 4.698958e-05, 'm': 2.60707e-06, 'n': 0.0013672016, 'o': 2.60707e-06, 'p': 2.60707e-06, 'q': 2.60707e-06, 'r': 4.045956e-05, 's': 9.672783e-05, 't': 2.60707e-06, 'u': 2.60707e-06, 'v': 2.60707e-06, 'w': 2.60707e-06, 'x': 2.60707e-06, 'y': 2.60707e-06, 'z': 2.60707e-06, '_': 2.60707e-06}
Saved model


Epoch 50/100: 5it [00:02,  1.75it/s]


Epoch 50/100, Loss: 0.69145348
{'<pad>': 5.7654197e-05, 'a': 0.44935882, 'b': 5.7654197e-05, 'c': 5.7654197e-05, 'd': 5.7654197e-05, 'e': 5.7654197e-05, 'f': 5.7654197e-05, 'g': 5.7654197e-05, 'h': 5.7654197e-05, 'i': 0.055590738, 'j': 5.7654197e-05, 'k': 5.7654197e-05, 'l': 0.0010891478, 'm': 5.7654197e-05, 'n': 0.48529527, 'o': 5.7654197e-05, 'p': 5.7654197e-05, 'q': 5.7654197e-05, 'r': 0.0031411538, 's': 0.004256414, 't': 5.7654197e-05, 'u': 5.7654197e-05, 'v': 5.7654197e-05, 'w': 5.7654197e-05, 'x': 5.7654197e-05, 'y': 5.7654197e-05, 'z': 5.7654197e-05, '_': 5.7654197e-05}
Saved model


Epoch 51/100: 5it [00:02,  1.75it/s]


Epoch 51/100, Loss: 0.69122406
{'<pad>': 1.2705372e-05, 'a': 0.9659127, 'b': 1.2705372e-05, 'c': 1.2705372e-05, 'd': 1.2705372e-05, 'e': 1.2705372e-05, 'f': 1.2705372e-05, 'g': 1.2705372e-05, 'h': 1.2705372e-05, 'i': 0.031425655, 'j': 1.2705372e-05, 'k': 1.2705372e-05, 'l': 0.00016960621, 'm': 1.2705372e-05, 'n': 0.0017576189, 'o': 1.2705372e-05, 'p': 1.2705372e-05, 'q': 1.2705372e-05, 'r': 0.00022651383, 's': 0.00022839842, 't': 1.2705372e-05, 'u': 1.2705372e-05, 'v': 1.2705372e-05, 'w': 1.2705372e-05, 'x': 1.2705372e-05, 'y': 1.2705372e-05, 'z': 1.2705372e-05, '_': 1.2705372e-05}
Saved model


Epoch 52/100: 5it [00:02,  1.79it/s]


Epoch 52/100, Loss: 0.69356409
{'<pad>': 1.0173659e-05, 'a': 0.8061943, 'b': 1.0173659e-05, 'c': 1.0173659e-05, 'd': 1.0173659e-05, 'e': 1.0173659e-05, 'f': 1.0173659e-05, 'g': 1.0173659e-05, 'h': 1.0173659e-05, 'i': 0.030622672, 'j': 1.0173659e-05, 'k': 1.0173659e-05, 'l': 6.004975e-05, 'm': 1.0173659e-05, 'n': 0.16183844, 'o': 1.0173659e-05, 'p': 1.0173659e-05, 'q': 1.0173659e-05, 'r': 0.00010754081, 's': 0.00095305685, 't': 1.0173659e-05, 'u': 1.0173659e-05, 'v': 1.0173659e-05, 'w': 1.0173659e-05, 'x': 1.0173659e-05, 'y': 1.0173659e-05, 'z': 1.0173659e-05, '_': 1.0173659e-05}
Saved model


Epoch 53/100: 5it [00:07,  1.51s/it]


Epoch 53/100, Loss: 0.69133574
{'<pad>': 3.2574884e-05, 'a': 0.96142036, 'b': 3.2574884e-05, 'c': 3.2574884e-05, 'd': 3.2574884e-05, 'e': 3.2574884e-05, 'f': 3.2574884e-05, 'g': 3.2574884e-05, 'h': 3.2574884e-05, 'i': 0.020453254, 'j': 3.2574884e-05, 'k': 3.2574884e-05, 'l': 0.00018211869, 'm': 3.2574884e-05, 'n': 0.0161374, 'o': 3.2574884e-05, 'p': 3.2574884e-05, 'q': 3.2574884e-05, 'r': 0.00020282116, 's': 0.0008874351, 't': 3.2574884e-05, 'u': 3.2574884e-05, 'v': 3.2574884e-05, 'w': 3.2574884e-05, 'x': 3.2574884e-05, 'y': 3.2574884e-05, 'z': 3.2574884e-05, '_': 3.2574884e-05}
Saved model


Epoch 54/100: 1it [00:06,  6.37s/it]


KeyboardInterrupt: 

In [22]:
sum(p.numel() for p in model.parameters())
# you should have around 10 times more data samples than parameters.

186972

In [22]:
model.inference("_at_r","bcdefghimnopqsuvwxyz","jkl")

{}

TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not str

In [4]:

class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMClassifier, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc1 = nn.Linear(hidden_size, hidden_size//2)
        self.fc2 = nn.Linear(hidden_size//2, output_size)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        out = self.relu(lstm_out[:, -1, :])
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        return out

In [21]:
# Parameters
input_size = 28
hidden_size = 64
output_size = 28
seq_length = 10
batch_size = 32
num_epochs = 15
num_layers = 2
learning_rate = 0.001
checkpoint_path = 'checkpoint.pth'

# Instantiate the model, loss function, and optimizer
model = LSTMClassifier(input_size, hidden_size,num_layers, output_size)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Load dataset and create DataLoader
dataset = HangmanTextDataset('./data/train/mini.csv', chunk_size=1000)
dataloader = DataLoader(dataset)

In [59]:
# Tensorboard writer for logging
# writer = SummaryWriter()

# Training loop
import os
resume = False
global epoch
if os.path.exists('./checkpoint.pth') and resume: 
    model = LSTMClassifier(input_size, hidden_size, num_layers, output_size)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate) 

    checkpoint = torch.load('./checkpoint.pth')  
    model.load_state_dict(checkpoint['model_state_dict']) 
    optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 
    epoch = checkpoint['epoch'] + 1
    loss = checkpoint['loss'] 
else: 
    epoch = 0


while epoch < num_epochs:
    model.train()
    running_loss = 0.0
    for batch_input, batch_target in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        batch_input = torch.squeeze(batch_input) 
        batch_target = torch.squeeze(batch_target)
        # Forward pass
        outputs = model(batch_input)

        # Compute loss
        loss = criterion(outputs, batch_target)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * batch_input.size(0)

    epoch_loss = running_loss / 40000
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
    # writer.add_scalar('Loss/train', epoch_loss, epoch)

    # Save checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, checkpoint_path) 

    epoch += 1

# Close Tensorboard writer
# writer.close()


NameError: name 'num_epochs' is not defined