# Homework 3 - Text generation with LSTM and Transformer networks



## Installs the unidecode library and downloads the Shakespeare dataset.

In [2]:
!pip install unidecode
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

Collecting unidecode
  Downloading Unidecode-1.4.0-py3-none-any.whl.metadata (13 kB)
Downloading Unidecode-1.4.0-py3-none-any.whl (235 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/235.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m235.8/235.8 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: unidecode
Successfully installed unidecode-1.4.0
--2025-05-09 00:28:09--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-05-09 00:28:09 (66.4 MB/s) - ‘input.txt’ saved [1115394/1115394]



## LSTM implementation

For this task you will implement the LSTM neural network architecture and train it on the task of character-level text generation. Implement a single layer LSTM and optionally extend your implementation to multiple layers to generate better results.

Links:

- https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html -- Lists the equations for each component of the LSTM cell.
- http://colah.github.io/posts/2015-08-Understanding-LSTMs/ -- Intuitive explanation of LSTM
## - http://karpathy.github.io/2015/05/21/rnn-effectiveness/ -- Explanation and uses of RNNs.


Implement the initialization and the forward pass of a LSTMCell and use it as part of the LSTMSimple network class.

The input of the LSTM network will be a sequence of characters, whereas the input of the LSTMCell will be a single input character (x), the output of the previous iteration (C) and the hidden state of the previous iteration (h). Iteratively process the entire input character sequence and calculate the loss based on the prediction at each time step.

### Do NOT use the torch.nn.LSTM class.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
from torch.autograd import Variable

class LSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(LSTMCell, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        # Input gate parameters
        self.W_ii = nn.Parameter(torch.Tensor(hidden_dim, input_dim))
        self.W_hi = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.b_ii = nn.Parameter(torch.Tensor(hidden_dim))
        self.b_hi = nn.Parameter(torch.Tensor(hidden_dim))

        # Forget gate parameters
        self.W_if = nn.Parameter(torch.Tensor(hidden_dim, input_dim))
        self.W_hf = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.b_if = nn.Parameter(torch.Tensor(hidden_dim))
        self.b_hf = nn.Parameter(torch.Tensor(hidden_dim))

        # Cell gate parameters
        self.W_ig = nn.Parameter(torch.Tensor(hidden_dim, input_dim))
        self.W_hg = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.b_ig = nn.Parameter(torch.Tensor(hidden_dim))
        self.b_hg = nn.Parameter(torch.Tensor(hidden_dim))

        # Output gate parameters
        self.W_io = nn.Parameter(torch.Tensor(hidden_dim, input_dim))
        self.W_ho = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.b_io = nn.Parameter(torch.Tensor(hidden_dim))
        self.b_ho = nn.Parameter(torch.Tensor(hidden_dim))

        self.init_weights()

    def init_weights(self):
        """Initialize weights according to the LSTM paper suggestion"""
        for param in self.parameters():
            if len(param.shape) >= 2:
                nn.init.xavier_uniform_(param)
            else:
                nn.init.zeros_(param)

        # Set forget gate bias to 1 as per best practices
        nn.init.ones_(self.b_if)
        nn.init.ones_(self.b_hf)

    def forward(self, x, C, h):
        """
        Forward pass of the LSTM cell

        Args:
            x: Input tensor of shape (batch_size, input_dim)
            C: Previous cell state of shape (batch_size, hidden_dim)
            h: Previous hidden state of shape (batch_size, hidden_dim)

        Returns:
            C_out: Updated cell state
            h_out: Updated hidden state
        """
        # Input gate
        i_t = torch.sigmoid(x @ self.W_ii.t() + self.b_ii + h @ self.W_hi.t() + self.b_hi)

        # Forget gate
        f_t = torch.sigmoid(x @ self.W_if.t() + self.b_if + h @ self.W_hf.t() + self.b_hf)

        # Cell candidate
        g_t = torch.tanh(x @ self.W_ig.t() + self.b_ig + h @ self.W_hg.t() + self.b_hg)

        # Output gate
        o_t = torch.sigmoid(x @ self.W_io.t() + self.b_io + h @ self.W_ho.t() + self.b_ho)

        # Update cell state: c_t = f_t * c_{t-1} + i_t * g_t
        C_out = f_t * C + i_t * g_t

        # Update hidden state: h_t = o_t * tanh(c_t)
        h_out = o_t * torch.tanh(C_out)

        return C_out, h_out


class LSTMSimple(nn.Module):
    """
    Enhanced LSTM implementation with multi-layer support and dropout
    with backward compatibility for sampling functions
    """
    def __init__(self, seq_length, input_dim, hidden_dim, output_dim, batch_size, num_layers=1, dropout=0.0):
        super(LSTMSimple, self).__init__()
        self.seq_length = seq_length
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.dropout = dropout

        # Create LSTM cells for each layer
        self.lstm_cells = nn.ModuleList()

        # First layer takes input from the input dimension
        self.lstm_cells.append(LSTMCell(input_dim, hidden_dim, output_dim))

        # Subsequent layers take input from the hidden dimension of the previous layer
        for _ in range(1, num_layers):
            self.lstm_cells.append(LSTMCell(hidden_dim, hidden_dim, output_dim))

        self.lstm_cell = self.lstm_cells[0]

        # Dropout layer
        self.dropout_layer = nn.Dropout(dropout)

        # Projection layer to output dimension
        self.proj = nn.Linear(hidden_dim, output_dim)

    def init_hidden(self, batch_size, device=None):
        """Initialize hidden and cell states"""
        if device is None:
            device = next(self.parameters()).device

        return (torch.zeros(batch_size, self.hidden_dim, device=device),
                torch.zeros(batch_size, self.hidden_dim, device=device))

    def forward(self, x):
        """
        Forward pass over a sequence of characters

        Args:
            x: Input tensor of shape (batch_size, seq_length, input_dim)

        Returns:
            outputs: Predictions for each step in the sequence
            (C, h): Final cell and hidden states of the last layer
        """
        #  the case when x is just a sequence without batch dimension
        if len(x.shape) == 2:
            x = x.unsqueeze(0)  # batch dimension

        batch_size, seq_length, _ = x.size()
        device = x.device
        cell_states = [torch.zeros(batch_size, self.hidden_dim, device=device) for _ in range(self.num_layers)]
        hidden_states = [torch.zeros(batch_size, self.hidden_dim, device=device) for _ in range(self.num_layers)]

        outputs = []
        for t in range(seq_length):
            x_t = x[:, t, :]

            # Process through each layer
            for layer in range(self.num_layers):
                if layer > 0:
                    # Apply dropout between layers
                    x_t = self.dropout_layer(x_t)
                # Forward through LSTM cell
                cell_states[layer], hidden_states[layer] = self.lstm_cells[layer](
                    x_t, cell_states[layer], hidden_states[layer]
                )
                x_t = hidden_states[layer]

            # Apply dropout before projection layer
            final_hidden = self.dropout_layer(hidden_states[-1])

            # Project to output dimension
            out = self.proj(final_hidden)
            outputs.append(out)


        outputs = torch.stack(outputs, dim=1)  # Shape: (batch_size, seq_length, output_dim)


        return outputs, (cell_states[-1], hidden_states[-1])

### LSTM Sampling Code

To generate text the network must predict the next character in a sequence, however networks do not produce a single character but rather estimate the likelihood for each possible character. Sampling characters from the network output can be done in different ways with common ones being the Greedy sampling process and Top-K sampling.

In the simple greedy sampling method the network takes a text prompt as input and generates an additional N tokens by always taking the token with the highest prediction score as the next token.

In the Top-K sampling, randomness is added to the sampling process as the network samples from K most likely predicitons at each step. This alleviates the problem of generative models repeating text but may generate incorrect text by sampling inappropriate tokens.


In [3]:
def greedy_sampling_lstm(lstm, x, num_chars):
    # x -- b x onehot_char
    outputs = torch.zeros((1,num_chars,x.shape[2]))
    t_outputs, (cell_state, hidden) = lstm(x.float())
    for c in range(num_chars):
        output_tmp = torch.softmax(lstm.proj(hidden),dim=1)
        top_ind = torch.argmax(output_tmp,dim=1)[0]
        tmp = torch.zeros_like(x[:,0,:]).cuda()
        tmp[:,top_ind] = 1
        outputs[:,c] = tmp

        cell_state, hidden = lstm.lstm_cell(tmp,cell_state,hidden)
    return outputs

def topk_sampling_lstm(lstm, x, num_chars):
    # x -- b x onehot_char
    outputs = torch.zeros((1,num_chars,x.shape[2]))
    t_outputs, (cell_state, hidden) = lstm(x.float())
    for c in range(num_chars):
        output_vals, output_ind = torch.topk(lstm.proj(hidden), 5, dim=1)
        output_tmp = torch.softmax(output_vals,dim=1)
        top_ind = torch.multinomial(output_tmp[0], 1)[0]
        tmp = torch.zeros_like(x[:,0,:]).cuda()
        tmp[:,output_ind[0,top_ind]] = 1
        outputs[:,c] = tmp

        cell_state, hidden = lstm.lstm_cell(tmp,cell_state,hidden)

    return outputs

### LSTM Dataset Code

In [4]:
import unidecode
import string
import random
from torch.autograd import Variable
from torch.utils.data import Dataset


class LSTMDataset(Dataset):
    def __init__(self, chunk_len=200, padded_chunks=False):
        # Character based dataset
        dataset_path = "./input.txt"
        # The tokens in the vocabulary (all_characters)
        # are just the printable characters of the string class
        self.all_characters = string.printable
        self.n_characters = len(self.all_characters)
        # Maps characters to indices
        self.char_dict = {x:i for i,x in enumerate(self.all_characters)}
        self.file, self.file_len = self.read_file(dataset_path)
        # Sequence length of the input
        self.chunk_len = chunk_len

    def read_file(self,filename):
        file = unidecode.unidecode(open(filename).read())
        return file, len(file)

    def char_tensor(self,in_str):
        # in_str - input sequence - String
        # Return one-hot encoded characters of in_str
        tensor = torch.zeros(len(in_str),self.n_characters).long()
        char_ind = [self.char_dict[c] for c in in_str]
        tensor[torch.arange(tensor.shape[0]),char_ind] = 1
        return tensor

    def __getitem__(self, idx):
        inp, target = self.get_random_text()
        return {"input":inp, "target":target}

    def __len__(self):
        return 10000

    def get_random_text(self):
        # Pick a random string of length self.chunk_len from the dataset
        start_index = np.random.randint(0, self.file_len - self.chunk_len)
        end_index = start_index + self.chunk_len + 1
        chunk = self.file[start_index:end_index]
        # One-hot encode the chosen string
        inp = self.char_tensor(chunk[:-1])
        # The target string is the same as the
        # input string but shifted by 1 character
        target = self.char_tensor(chunk[1:])
        inp = Variable(inp).cuda()
        target = Variable(target).cuda()
        return inp, target


### LSTM Training loop

With a correct implementation you should get sensible text generation results with the set parameters, however you should experiment with various parameters,
especially with the sequence length (chunk_len) used during training.

In [None]:
from tqdm import tqdm
import torch.optim as optim

batch_size = 256
chunk_len = 128
model_name = "LSTM"
train_dataset = LSTMDataset(chunk_len=chunk_len)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=0, drop_last=True)

#Sample parameters, use whatever you see fit.
input_dim = train_dataset.n_characters
hidden_dim = 256
output_dim = train_dataset.n_characters
learning_rate = 0.005
model = LSTMSimple(chunk_len,input_dim, hidden_dim, output_dim,batch_size)
model.train()
model.cuda()

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

epochs=30

for epoch in range(epochs):
    with tqdm(total=len(trainloader.dataset), desc ='Training - Epoch: '+str(epoch)+"/"+str(epochs), unit='chunks') as prog_bar:
        for i, data in enumerate(trainloader, 0):
            inputs = data['input'].float()
            labels = data['target'].float()
            # b x chunk_len x len(dataset.all_characters)
            target = torch.argmax(labels,dim=2)
            optimizer.zero_grad()
            outputs, _ = model(inputs)
            loss = criterion(outputs.view(inputs.shape[0]*inputs.shape[1],-1),target.view(labels.shape[0]*labels.shape[1]))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                      max_norm=10.0)
            optimizer.step()
            prog_bar.set_postfix(**{'run:': model_name,'lr': learning_rate,
                                    'loss': loss.item()
                                    })
            prog_bar.update(batch_size)
        # Intermediate output
        sample_text = "O Romeo, wherefore art thou"
        inp = train_dataset.char_tensor(sample_text)
        sample_input = Variable(inp).cuda().unsqueeze(0).float()
        out_test = topk_sampling_lstm(model,sample_input, 300)[0]
        out_char_index = torch.argmax(out_test, dim=1).detach().cpu().numpy()
        out_chars = sample_text+"".join([train_dataset.all_characters[i] for i in out_char_index])
        print("Top-K sampling -----------------")
        print(out_chars)

        out_test = greedy_sampling_lstm(model,sample_input, 300)[0]
        out_char_index = torch.argmax(out_test, dim=1).detach().cpu().numpy()
        out_chars = sample_text+"".join([train_dataset.all_characters[i] for i in out_char_index])
        print("Greedy sampling ----------------")
        print(out_chars)


Training - Epoch: 0/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 863.01chunks/s, loss=3.31, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou e eeean     oote ete tethe  toe  o e  a e  ae a  e   eotteea e   attoeat   teeteeotoetoa eoeoet to tet  teeeee o  o ano t the  t   e e tt t to ate   eo  te t  te  aon  e  tha eoneea ae   athto etthao ao  otoao at e e oth ao  a et a  t    teaone  tt  o t o  to   t et to ttteee   oette tet ano tteea 


Training - Epoch: 0/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 772.32chunks/s, loss=3.31, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou                                                                                                                                                                                                                                                                                                            


Training - Epoch: 1/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 848.98chunks/s, loss=2.92, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thoun anesan s  one t te thethote sh  otesas  ha e nes ta  oe t s na    tandth   earant  tee o   othin athenoses athe  at e  hiso t   eo ar so to oto sa oot t tare te so atho tase t e  hi ter  hitet tothint ate norit a to  ae se  o e  aeere th  ond s thate andete ot a thetoeto ooeetet sho an t e  he  t 


Training - Epoch: 1/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 783.35chunks/s, loss=2.92, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t  t


Training - Epoch: 2/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 838.40chunks/s, loss=2.51, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou  int, the me t the sous wou ther he s wees hor, hor ter thiret onthan hile wo the will seer har to sin wont aote wor seser hin we mire sing and ar mes the s ou hire tord we sous tord thon thite mas and tour sor se mer ther whathe he sot, what, the th mat he tot and wourd an sesthe sors thes,
Thent 


Training - Epoch: 2/30: 100%|█████████▉| 9984/10000 [00:13<00:00, 766.39chunks/s, loss=2.51, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the


Training - Epoch: 3/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 699.00chunks/s, loss=2.38, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou hease thithat to then sind.

LIRENO:IT:
Witheethat sors ar therstisthat hale hine seringof mas sendee hishes, wimis art heass, wat merthe the sorerend,

Ar me than mirind ande hind my ardallis on win hars oothe meald the tharse sithind toul shathe thals, he there weren thint,
Buthale we tort alese 


Training - Epoch: 3/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 769.51chunks/s, loss=2.38, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the


Training - Epoch: 4/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 647.85chunks/s, loss=2.29, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou hang tous so the so te sord mesest tor hart ond he my therene houn me tind to ther, shour the hime, tort stise thald stars stend,
As tathis shim art ofered

Bo hertis all seres,
Ande the ho chen an tend ther then, tour sous ous some,
Weld that an blowing as sesest,'dis so me tat mes ar meng ouss st


Training - Epoch: 4/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 772.18chunks/s, loss=2.29, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the


Training - Epoch: 5/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 703.07chunks/s, loss=2.2, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thour wit homate st ard ould aster hared tould wars on all me tound, busther well thall wot and hist, an sing his tou hat, the with the winthe, and hare hithare sont thit hat the thot he wo d at the my ur mond then woll an thee thath and to that heer will ther whel ter will te the hert and we him toure 


Training - Epoch: 5/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 787.98chunks/s, loss=2.2, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the


Training - Epoch: 6/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 778.32chunks/s, loss=2.14, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thoul sises, wothe the ward and astof are ast hears at mas mand,
Whe the somenes at tith thim sind the sent thil merthy,
And way the wit sithes at ard to ther some, are sis in the woredsens an mo she the hard,
And will with the stee heard as all all the hing that thas ath sings an the priss,
And homy th


Training - Epoch: 6/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 787.45chunks/s, loss=2.14, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the


Training - Epoch: 7/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 829.62chunks/s, loss=2.07, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou wis and be ancaless of of all we word hom mand
Thou her to to must and wouther my shath my tored, thee time toom thou here thees,
And to meart and, the preang the hond thene thould the too for stouns in thinge tis me thou deast as me sheer and wet to thas ar he woll wele, and bu to tood with ther, 


Training - Epoch: 7/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 784.92chunks/s, loss=2.07, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have the the the the the the will the the will the the will the the will the the will the the will the the will the the will the the will the the will the the will the the will the the will the the will the the will the the will the the will the the will the the will the the will the the will the t


Training - Epoch: 8/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 862.10chunks/s, loss=2.02, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thourd thath we my a pould the tard ald in still and an and,
And with and wely deat it in migh tis thy have sould that we he the pritith.

CEONIENE:
A wir heave that a deer hinst, as, at in the hande ant head.
The shear this and withee an al is mard.

CLOUCIS:
And we has and his sand hored wires on that


Training - Epoch: 8/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 781.48chunks/s, loss=2.02, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the will the wi


Training - Epoch: 9/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 846.15chunks/s, loss=1.96, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou so mese, the wetring the wall this the will shimed are to moust,
The shell striss, and the to that how, the she soule and, thingenter, the singhes.

CANIONT:
No ther seed you mane sire mose the peres ath hear

Pringe tis send, with sonds or shatl here fart him.

POLIO:
No these then will he to more


Training - Epoch: 9/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 779.30chunks/s, loss=1.96, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have the seare the will the seare the will the seare the will the seare the will the seare the will the seare the will the seare the will the seare the will the seare the will the seare the will the seare the will the seare the will the seare the will the seare the will the seare the will the seare


Training - Epoch: 10/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 867.06chunks/s, loss=1.93, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou hownes of has in
The tay torstence, the wondse this spurtise,
Thar seat that soures on my, whencence the sorser, ang
To son anglededse the sang of tor his and and hither shill and the painged,
Anderst to me here folly
And the bade ofthare for mear of to mar to hish strepence to mentes,
And that hav


Training - Epoch: 10/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 783.28chunks/s, loss=1.93, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the seare the 


Training - Epoch: 11/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 866.39chunks/s, loss=1.88, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou down.

CLADINA:
Then mer is the ceane tree, to me that home.
So llot of the cour well me,
I will wet that well that has and, be the sone thy lare
To true shat to makes my ase,
The hought hear sind on than ale she to think
Than to hat this ald that have the perines.
That the diend mentor to my love 


Training - Epoch: 11/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 780.91chunks/s, loss=1.88, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have the sear the sond the sear the sond the sear the sond the sear the sond the sear the sond the sear the sond the sear the sond the sear the sond the sear the sond the sear the sond the sear the sond the sear the sond the sear the sond the sear the sond the sear the sond the sear the sond the se


Training - Epoch: 12/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 867.64chunks/s, loss=1.83, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou the ploves
Will wotld to be shall were wits of the sind
And will hisser to thee the sones
This faith thou brient to but merelower his thee,
Take a dond you thes and tho grom of the pray he hiss.

CARIOLANUS:
How hishire, sor, whe she word hath remence te hild as
To so to befers mis as to the sont
T


Training - Epoch: 12/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 787.19chunks/s, loss=1.83, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the 


Training - Epoch: 13/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 861.60chunks/s, loss=1.82, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou to murter.

PEONTILANG:
With he sin than the should hase but tays as he heather than sin
To than anderst onterthe all woon thou and and tome, shee, he wasch that wish shere,
Thou down thoush treep it treather to menchare still be than thy horess.

KING RICHARD III:
Whonese, his fither the see then 


Training - Epoch: 13/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 780.77chunks/s, loss=1.82, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the seart the 


Training - Epoch: 14/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 862.86chunks/s, loss=1.78, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thought.

KING HENRY BI I'll near his treeder, and hup tane.

KING RICHARD III
ABELRETHE:
The with that the seest the dondes ar she, all to have have home,
And hingst be men tring triess the courte
To have to the prater thank'd that well ane ward owe,
Bur I ward hill to me prose all to tearth me.

Forst


Training - Epoch: 14/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 785.27chunks/s, loss=1.78, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have the proper the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the prace the


Training - Epoch: 15/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 869.87chunks/s, loss=1.76, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou houses and that to sie,
Thou well but thim hers of you her with our santer
And seed thee indery sirest hather,
As the cames it,--
The histring, but to may you with her seed his a been.

PEROMASPARD:
Well'd me the coust of haster are,
Should be to that's telpis all hath sontersed held here
I thee so


Training - Epoch: 15/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 787.71chunks/s, loss=1.76, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have the seed to the seed
The see the see the see the seep the seep to the seed
The see the see the see the seep the seep to the seed
The see the see the see the seep the seep to the seed
The see the see the see the seep the seep to the seed
The see the see the see the seep the seep to the seed
The


Training - Epoch: 16/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 862.15chunks/s, loss=1.73, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou will.

CORIOLANUS:
And thiness the will with himes.

LUCETIUS:
I'll will his shall to the bank.

LEONTES:
Thise any thence my like of hatter and and all which shall be some,
In would that wored and to my breatiness,
I seate my sourthen, and thought on thee trust of
Wordes to the woon some one the p


Training - Epoch: 16/30: 100%|█████████▉| 9984/10000 [00:13<00:00, 764.49chunks/s, loss=1.73, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have the words and the seed
The seep the seep the seeple the seed the words and the seed
The seep the seep the seeple the seed the words and the seed
The seep the seep the seeple the seed the words and the seed
The seep the seep the seeple the seed the words and the seed
The seep the seep the seepl


Training - Epoch: 17/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 854.12chunks/s, loss=1.67, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou am shall be no lore
Toll men thou his angurent thy grouth,
I' will be she should book nowere and more ant would
Why, him, I well this wards of the caured.

COMINIUS:
And the king thought of you the wenk again,
As a gont my lers of your hight we home all me
Angullen spaticess we would that have trea


Training - Epoch: 17/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 784.70chunks/s, loss=1.67, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou have the proper.

PETRUCHIO:
What shall be the propes the propess of the come to the proper.

PETRUCHIO:
What shall be the propes the propess of the come to the proper.

PETRUCHIO:
What shall be the propes the propess of the come to the proper.

PETRUCHIO:
What shall be the propes the propess of th


Training - Epoch: 18/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 874.92chunks/s, loss=1.71, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou, neting ofe
As ind you so day to the cair be the brest
The partors and he be sorrow me. I keep in heard,
And, brink and sould be tell me sheep,
Tikn you sen thy can be te poos, and have mint.

LUCIO:
I thoughers there sere.

PRITUS:
All my sonter, shall,
And have shall take a tade there and wile he


Training - Epoch: 18/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 776.76chunks/s, loss=1.71, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou are the proper.

KING RICHARD II:
I will be the propess of the come to the proper.

KING RICHARD II:
I will be the propess of the come to the proper.

KING RICHARD II:
I will be the propess of the come to the proper.

KING RICHARD II:
I will be the propess of the come to the proper.

KING RICHARD I


Training - Epoch: 19/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 721.50chunks/s, loss=1.67, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou her shreads.

GLOUCESTER:
I'll say him be spork to thim have asment on the prove
The stay daughty heavo stay strinker's my fisters,
And beais down by mears of thee his has stoon spirt thee there to make
What he word think he stown to the wingent me and wisted trat
Whither stard'd my sourther midder


Training - Epoch: 19/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 776.09chunks/s, loss=1.67, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou had the dears and the seed
The see the seep the seep the seep the seep the seep the seep
The see the seep the seep the seep the seep the seep the seep
The see the seep the seep the seep the seep the seep the seep
The see the seep the seep the seep the seep the seep the seep
The see the seep the see


Training - Epoch: 20/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 641.32chunks/s, loss=1.66, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou have not so to the see.

LADY CAMPLLO:
Andest whise serve the donger treast to mie.
We hand son, see the dis a tongue.

DUKE LARDARET:
Well there a mest brother, but mastins warmes
To blooven,--

COMINIUS:
We have myself, my lord out a thang,
The mestisser,
And the seep the disprysents to her, son 


Training - Epoch: 20/30: 100%|█████████▉| 9984/10000 [00:13<00:00, 765.08chunks/s, loss=1.66, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the propess.

PETRUCHIO:
What shall be the propess the see the propess of the propess.

PETRUCHIO:
What shall be the propess the see the propess of the propess.

PETRUCHIO:
What shall be the propess the see the propess of the propess.

PETRUCHIO:
What shall be the propess the see the propess of


Training - Epoch: 21/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 714.61chunks/s, loss=1.63, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou and by my his oft.

Stcond Grume:
I they may be nutt them with thou als, to there tean
Inderse the wreath the sunce mad by his come.

CLOFFORY:
Offul thee thee senst, and worder thee to has our
There sor of heaven sir, and then them.

LEONTES:
It me treatine, a swill home so tone, with mister.

CAT


Training - Epoch: 21/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 779.63chunks/s, loss=1.63, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the seep.

CORIOLANUS:
What shall be so the comes to the seep to the seep
The see the see the seep the seep to the seep
The see the see the seep the seep to the seep
The see the see the seep the seep to the seep
The see the see the seep the seep to the seep
The see the see the seep the seep to 


Training - Epoch: 22/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 780.73chunks/s, loss=1.6, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou and to
Frink to maticumer age,
The proceming, tell take held menty hath mighter, sording these a pricins,
And than how when sord make a san ant ang storn
Had say shall be the farmit and tongut,
When, thereforames,'d a soon me tell her shome and him.

LADY CAPULO:
When I shoulds boon my stiends,
Tha


Training - Epoch: 22/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 779.83chunks/s, loss=1.6, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the words and the world
That shall be the words and the words and the world the propent to the propent
That the words and the propess to the propent the world
That shall be the words and the words and the world the propent to the propent
That the words and the propess to the propent the world
T


Training - Epoch: 23/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 832.25chunks/s, loss=1.59, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thousat to my lead.

PETRUCHIO:
I warting hath moresess an housand tay.

GREMIO:
Toull, as I would this to the proved of hands in him;
For what shall no barked thing out becime, here send him. 
First Mindrarman:
I do may yell and book, a poor heart
I'll but helloused him her shall not worthy.'
Let of hu


Training - Epoch: 23/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 777.81chunks/s, loss=1.59, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the seep the seep
That the see the see the seep the seep the seep
That the see the see the seep the seep the seep
That the see the see the seep the seep the seep
That the see the see the seep the seep the seep
That the see the see the seep the seep the seep
That the see the see the seep the see


Training - Epoch: 24/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 855.37chunks/s, loss=1.59, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou art as as your comple.

LEONTES:
That I were think, therefort, then shears
And take itsears to heaving it alle.

BIANCA:
Is a head, this well that stay world,
What well man your hims, there antagent,
And blood wish and thou the stranger'd to this well:
The crust the plantes of the clays.

PORETER:



Training - Epoch: 24/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 784.22chunks/s, loss=1.59, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the seep
The stand the seep the seep the seep
The stand the seep the seep the seep the seep
The stand the seep the seep the seep the seep
The stand the seep the seep the seep the seep
The stand the seep the seep the seep the seep
The stand the seep the seep the seep the seep
The stand the seep 


Training - Epoch: 25/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 860.04chunks/s, loss=1.58, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou way;
What, by my himes, and to thy ground.

KING RENVE:
How to the world as a good mady on and me.

GRUMIO:
We canned to mere it.

LUCESTERS:
I wad mad my seevy mean again,
Thou dithar alos, to see how stride,
We cay to see take, there hold, which thou hath to myseed thou so.

Second Soran:
No, sen


Training - Epoch: 25/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 782.01chunks/s, loss=1.58, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the world
That the world to the propece to the comes and the world
That the world to the seep the seep to the comes and the world
That the world to the seep the seep to the comes and the world
That the world to the seep the seep to the comes and the world
That the world to the seep the seep to 


Training - Epoch: 26/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 865.64chunks/s, loss=1.57, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou, meass trow the prisce,
But the sundes of their hand out thy gaint
He have stare of thee so do to she him shope her hold him his him.

Second Gentleman:
The coulds and the chood so deep diend which his child.

KING RICHARD II:
My fately shalt nave their childs. But you she him her heavy are
I any b


Training - Epoch: 26/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 787.16chunks/s, loss=1.57, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the seep.

BAPTISTA:
I shall be see the see the see the seep to the consent
That the sent the sent the sent the sent the seep
The see the see the see the seep the seep to the sent
The see the see the see the seep the seep to the sent
The see the see the see the seep the seep to the sent
The see


Training - Epoch: 27/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 853.12chunks/s, loss=1.57, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou art this that
And house the brament, the heart.

GRUMIO:
Thy shall be nigners truch and sen thy brance
I she had say him helfore as as you hate
And be so forth and so shourderned wife,
When to hear my look to be must stroud my strrie:
How now stranch, I was songer we wart my blood.

BRUTUS:
If you,


Training - Epoch: 27/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 779.33chunks/s, loss=1.57, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the see
The see the see the seep the seep the seep
That the world the see the seep the seep the seep
That the world the see the seep the seep the seep
That the world the see the seep the seep the seep
That the world the see the seep the seep the seep
That the world the see the seep the seep the


Training - Epoch: 28/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 863.71chunks/s, loss=1.56, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thou
have book to too have her word.

PATIS:
Wardwnect, will no see thou dot here,
And he hath now some an ale, with me.
Why, and with your son of thou hadd sporl,
Thou sailing to my father's sourt then with the world tell a port,
And we well and to all thy best and the seared
That so my his pray round 


Training - Epoch: 28/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 780.03chunks/s, loss=1.56, lr=0.005, run:=LSTM]


Greedy sampling ----------------
O Romeo, wherefore art thou art the seep
That the see the see the death and the see the death.

KING RICHARD III:
What is the see the see the see the death,
And then the seep the seep the seep to the comes
And the see the see the death and the seep
That the see the see the death and the see the death.

KING RICHARD III:
What 


Training - Epoch: 29/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 870.95chunks/s, loss=1.53, lr=0.005, run:=LSTM]

Top-K sampling -----------------
O Romeo, wherefore art thoughts,
Which he be need to hush ther, be all,
This beard he is arm all our here,
It will be the wither't the confull a man the said to
And, and his myself, thou wilt blain of their stainty
Herry shoush a subled the hard on a comest alain,
And him is to the can foon of yours,
Sit wouch once hath them 


Training - Epoch: 29/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 773.94chunks/s, loss=1.53, lr=0.005, run:=LSTM]

Greedy sampling ----------------
O Romeo, wherefore art thou hast here.

CORIOLANUS:
The hath the comes the streather than the seep
That the strest the strest the streather than the seep
That the strest the strest the streather than the seep
That the strest the strest the streather than the seep
That the strest the strest the streather than the seep
That the





In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import time

def train_with_sequence_length(seq_length, epochs=10, hidden_dim=256, batch_size=256, learning_rate=0.005):

    os.makedirs("results", exist_ok=True)


    train_dataset = LSTMDataset(chunk_len=seq_length)
    trainloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, num_workers=0, drop_last=True
    )


    input_dim = train_dataset.n_characters
    output_dim = train_dataset.n_characters
    model_name = f"LSTM_seq{seq_length}"


    model = LSTMSimple(seq_length, input_dim, hidden_dim, output_dim, batch_size)


    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()


    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)


    losses = []
    samples = {"topk": [], "greedy": []}

    print(f"\n{'=' * 60}")
    print(f"Training with sequence length: {seq_length}")
    print(f"{'=' * 60}\n")


    start_time = time.time()
    for epoch in range(epochs):
        epoch_losses = []

        with tqdm(total=len(trainloader.dataset), desc=f'Training - Epoch: {epoch}/{epochs}', unit='chunks') as prog_bar:
            for i, data in enumerate(trainloader, 0):
                # Move data to device
                inputs = data['input'].float().to(device)
                labels = data['target'].float().to(device)

                # Get target indices
                target = torch.argmax(labels, dim=2)

                # Forward pass
                optimizer.zero_grad()
                outputs, _ = model(inputs)

                # Compute loss
                loss = criterion(
                    outputs.view(inputs.shape[0] * inputs.shape[1], -1),
                    target.view(labels.shape[0] * labels.shape[1])
                )


                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
                # Update parameters
                optimizer.step()


                epoch_losses.append(loss.item())


                prog_bar.set_postfix(**{
                    'run:': model_name,
                    'lr': learning_rate,
                    'loss': loss.item()
                })
                prog_bar.update(batch_size)

        # Average loss for this epoch
        avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
        losses.append(avg_epoch_loss)

        # Generate samples at the end of each epoch
        sample_text = "O Romeo, wherefore art thou"
        inp = train_dataset.char_tensor(sample_text)
        sample_input = Variable(inp).to(device).unsqueeze(0).float()

        # Generate text with top-k sampling
        out_test_topk = topk_sampling_lstm(model, sample_input, 300)[0]
        out_char_index_topk = torch.argmax(out_test_topk, dim=1).detach().cpu().numpy()
        out_chars_topk = sample_text + "".join([train_dataset.all_characters[i] for i in out_char_index_topk])
        samples["topk"].append(out_chars_topk)

        # Generate text with greedy sampling
        out_test_greedy = greedy_sampling_lstm(model, sample_input, 300)[0]
        out_char_index_greedy = torch.argmax(out_test_greedy, dim=1).detach().cpu().numpy()
        out_chars_greedy = sample_text + "".join([train_dataset.all_characters[i] for i in out_char_index_greedy])
        samples["greedy"].append(out_chars_greedy)


        if epoch == epochs - 1:
            print("\nTop-K sampling -----------------")
            print(out_chars_topk)
            print("\nGreedy sampling ----------------")
            print(out_chars_greedy)

    training_time = time.time() - start_time
    print(f"\nTraining completed in {training_time:.2f} seconds")


    torch.save(model.state_dict(), f"results/lstm_seq{seq_length}.pth")

    return model, losses, samples


def run_sequence_length_experiment(sequence_lengths, epochs=10, hidden_dim=256, batch_size=256, learning_rate=0.005):
    """
    Run experiment with multiple sequence lengths

    Args:
        sequence_lengths: List of sequence lengths to train with
        epochs: Number of epochs for each training run
        hidden_dim: Hidden dimension size
        batch_size: Batch size for training
        learning_rate: Learning rate for optimization
    """

    all_results = {}

    for seq_len in sequence_lengths:
        # Train with this sequence length
        model, losses, samples = train_with_sequence_length(
            seq_length=seq_len,
            epochs=epochs,
            hidden_dim=hidden_dim,
            batch_size=batch_size,
            learning_rate=learning_rate
        )

        # Store results
        all_results[seq_len] = {
            "model": model,
            "losses": losses,
            "samples": samples
        }


    plt.figure(figsize=(10, 6))
    for seq_len, results in all_results.items():
        plt.plot(results["losses"], label=f"Sequence Length: {seq_len}")

    plt.title("Training Loss vs. Epoch for Different Sequence Lengths")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.savefig("results/sequence_length_loss_comparison.png")
    plt.close()


    with open("results/sequence_length_samples.txt", "w") as f:
        f.write("Generated Text Samples with Different Sequence Lengths\n")
        f.write("=" * 60 + "\n\n")

        for seq_len, results in all_results.items():
            f.write(f"Sequence Length: {seq_len}\n")
            f.write("-" * 40 + "\n\n")

            # Get the final samples (from the last epoch)
            final_topk = results["samples"]["topk"][-1]
            final_greedy = results["samples"]["greedy"][-1]

            f.write("Top-K Sampling:\n")
            f.write(final_topk + "\n\n")

            f.write("Greedy Sampling:\n")
            f.write(final_greedy + "\n\n")

            f.write("\n" + "=" * 60 + "\n\n")

    print("\nExperiment completed. Results saved to the 'results' directory.")
    return all_results




sequence_lengths = [32, 64, 128, 256]

epochs = 30

results = run_sequence_length_experiment(
        sequence_lengths=sequence_lengths,
        epochs=epochs,
        hidden_dim=256,
        batch_size=256,
        learning_rate=0.005
    )

print("\nSummary of Final Losses:")
for seq_len, res in results.items():
      print(f"Sequence Length {seq_len}: Final Loss = {res['losses'][-1]:.4f}")


Training with sequence length: 32



Training - Epoch: 0/30: 100%|█████████▉| 9984/10000 [00:04<00:00, 2067.06chunks/s, loss=3.28, lr=0.005, run:=LSTM_seq32]
Training - Epoch: 1/30: 100%|█████████▉| 9984/10000 [00:04<00:00, 2486.98chunks/s, loss=2.71, lr=0.005, run:=LSTM_seq32]
Training - Epoch: 2/30: 100%|█████████▉| 9984/10000 [00:04<00:00, 2125.42chunks/s, loss=2.43, lr=0.005, run:=LSTM_seq32]
Training - Epoch: 3/30: 100%|█████████▉| 9984/10000 [00:04<00:00, 2430.52chunks/s, loss=2.32, lr=0.005, run:=LSTM_seq32]
Training - Epoch: 4/30: 100%|█████████▉| 9984/10000 [00:04<00:00, 2117.28chunks/s, loss=2.23, lr=0.005, run:=LSTM_seq32]
Training - Epoch: 5/30: 100%|█████████▉| 9984/10000 [00:04<00:00, 2466.83chunks/s, loss=2.18, lr=0.005, run:=LSTM_seq32]
Training - Epoch: 6/30: 100%|█████████▉| 9984/10000 [00:05<00:00, 1982.93chunks/s, loss=2.08, lr=0.005, run:=LSTM_seq32]
Training - Epoch: 7/30: 100%|█████████▉| 9984/10000 [00:04<00:00, 2275.31chunks/s, loss=2.04, lr=0.005, run:=LSTM_seq32]
Training - Epoch: 8/30: 100%|███


Top-K sampling -----------------
O Romeo, wherefore art thousand mind anstless
Than to mady to-mards, and so set if a that should see to the world
A direter our patted mistomsells, the manded wiffock, the mancoman: we him fair soul ot the self
To to any so tood too himself is me to hear fall of you.

LUCIO:
And all me thou heligh'd migh sighters.

POMPPYO:
N

Greedy sampling ----------------
O Romeo, wherefore art thou art the stand to the stand to the stand to the stand to the stand
That the couse the stand to the stand to the stand to the stand to the stand to the stand to the stand
That the couse the stand to the stand to the stand to the stand to the stand to the stand to the stand
That the couse the stand to

Training completed in 149.63 seconds

Training with sequence length: 64



Training - Epoch: 0/30: 100%|█████████▉| 9984/10000 [00:06<00:00, 1483.43chunks/s, loss=3.23, lr=0.005, run:=LSTM_seq64]
Training - Epoch: 1/30: 100%|█████████▉| 9984/10000 [00:06<00:00, 1463.73chunks/s, loss=2.73, lr=0.005, run:=LSTM_seq64]
Training - Epoch: 2/30: 100%|█████████▉| 9984/10000 [00:07<00:00, 1366.13chunks/s, loss=2.46, lr=0.005, run:=LSTM_seq64]
Training - Epoch: 3/30: 100%|█████████▉| 9984/10000 [00:06<00:00, 1517.98chunks/s, loss=2.32, lr=0.005, run:=LSTM_seq64]
Training - Epoch: 4/30: 100%|█████████▉| 9984/10000 [00:07<00:00, 1314.71chunks/s, loss=2.25, lr=0.005, run:=LSTM_seq64]
Training - Epoch: 5/30: 100%|█████████▉| 9984/10000 [00:07<00:00, 1388.79chunks/s, loss=2.17, lr=0.005, run:=LSTM_seq64]
Training - Epoch: 6/30: 100%|█████████▉| 9984/10000 [00:06<00:00, 1524.19chunks/s, loss=2.1, lr=0.005, run:=LSTM_seq64]
Training - Epoch: 7/30: 100%|█████████▉| 9984/10000 [00:07<00:00, 1381.47chunks/s, loss=2.04, lr=0.005, run:=LSTM_seq64]
Training - Epoch: 8/30: 100%|████


Top-K sampling -----------------
O Romeo, wherefore art thou are hene.

PERDWIV:
A grant, and we come who can our grand to stane
And wordor of the part him thing like
Of these be a man one and morn.

KATHARINA:
He ware sauness to may the cray.
What, these should be son, say, a would have back?

Secold Gown:
Ay, my lord,
I am a care, that while is an hand the

Greedy sampling ----------------
O Romeo, wherefore art thou hast the send the see
The sent the send the send the send the send the send
The sent the sent the send the send the send the send the send the send the send the send the send the send the send the send the send the send the send the send the send the send the send the send the send the send the sen

Training completed in 226.16 seconds

Training with sequence length: 128



Training - Epoch: 0/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 820.02chunks/s, loss=3.27, lr=0.005, run:=LSTM_seq128]
Training - Epoch: 1/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 815.00chunks/s, loss=2.8, lr=0.005, run:=LSTM_seq128]
Training - Epoch: 2/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 809.29chunks/s, loss=2.46, lr=0.005, run:=LSTM_seq128]
Training - Epoch: 3/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 819.91chunks/s, loss=2.33, lr=0.005, run:=LSTM_seq128]
Training - Epoch: 4/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 824.47chunks/s, loss=2.24, lr=0.005, run:=LSTM_seq128]
Training - Epoch: 5/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 822.61chunks/s, loss=2.18, lr=0.005, run:=LSTM_seq128]
Training - Epoch: 6/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 821.61chunks/s, loss=2.1, lr=0.005, run:=LSTM_seq128]
Training - Epoch: 7/30: 100%|█████████▉| 9984/10000 [00:12<00:00, 821.05chunks/s, loss=2.07, lr=0.005, run:=LSTM_seq128]
Training - Epoch: 8/30: 100%|█████


Top-K sampling -----------------
O Romeo, wherefore art thousted and see
thee that I do me to he well me, and here
Is for the sonest on the both the seak have.

KATHAPSOr:
Was you are boon you and tently are that then tow
The sent the mentied wother heaven some too man, brink to morn
I will more the more and so lady.
Take honest he should me will thy said an

Greedy sampling ----------------
O Romeo, wherefore art thou art the stands the some
That should the come to the come to the come to the country with the send
That shall be the come to the come to the country with the send
That shall be the come to the come to the country with the send
That shall be the come to the come to the country with the send
That shal

Training completed in 386.77 seconds

Training with sequence length: 256



Training - Epoch: 0/30: 100%|█████████▉| 9984/10000 [00:23<00:00, 417.02chunks/s, loss=3.26, lr=0.005, run:=LSTM_seq256]
Training - Epoch: 1/30: 100%|█████████▉| 9984/10000 [00:23<00:00, 429.79chunks/s, loss=2.83, lr=0.005, run:=LSTM_seq256]
Training - Epoch: 2/30: 100%|█████████▉| 9984/10000 [00:23<00:00, 430.75chunks/s, loss=2.47, lr=0.005, run:=LSTM_seq256]
Training - Epoch: 3/30: 100%|█████████▉| 9984/10000 [00:23<00:00, 423.13chunks/s, loss=2.36, lr=0.005, run:=LSTM_seq256]
Training - Epoch: 4/30: 100%|█████████▉| 9984/10000 [00:23<00:00, 425.04chunks/s, loss=2.26, lr=0.005, run:=LSTM_seq256]
Training - Epoch: 5/30: 100%|█████████▉| 9984/10000 [00:23<00:00, 425.06chunks/s, loss=2.2, lr=0.005, run:=LSTM_seq256]
Training - Epoch: 6/30: 100%|█████████▉| 9984/10000 [00:23<00:00, 425.98chunks/s, loss=2.14, lr=0.005, run:=LSTM_seq256]
Training - Epoch: 7/30: 100%|█████████▉| 9984/10000 [00:22<00:00, 435.29chunks/s, loss=2.11, lr=0.005, run:=LSTM_seq256]
Training - Epoch: 8/30: 100%|████


Top-K sampling -----------------
O Romeo, wherefore art thou haven that's take have that
There is thee, make a please of his peacian
A spile of here hath his brow the works. Though him:
Allow they sige it is a man to sempose
But thee in the seight and brow wasers on
As for to my hinde in timen all mother
As the stine of his fear men.

PETRUCHIO:
I would's th

Greedy sampling ----------------
O Romeo, wherefore art thou do the stander of the son
To see the stander of the son the proper and thee,
And the proper the stander of the sonester
That when the stander the stander of the son
To see the stander the stander of the son
To see the stander the stander of the son
To see the stander the stander of the son
To see t

Training completed in 718.04 seconds

Experiment completed. Results saved to the 'results' directory.

Summary of Final Losses:
Sequence Length 32: Final Loss = 1.6354
Sequence Length 64: Final Loss = 1.5809
Sequence Length 128: Final Loss = 1.5542
Sequence Length 256: Final

# Task 2: Character generation transformer network implementation
Our simple transformer-like network will take as input a sequence of characters and predict the next character in the sequence. To ensure an efficient training procedure, masked attention modules will be used as in the [GPT model](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf).

For this task you must implement the Scaled dot product attention module and the Masked multi-head attention module. Both of these modules are described in the [Attention is all you need](https://arxiv.org/pdf/1706.03762.pdf) paper (See Figure 2 in the paper as well as Sections 3.2.1, 3.2.2 and 3.2.3). They are the core operations of transformers. As we will use our model for text generation also add the masking operation shown as (mask opt.) in Figure 2, implemented as AttentionMasking in the code.

**Implement the modules in the ScaledDotProductAttention class and the MultiHeadAttention class.**

Read the GPT paper and the Attention is all you need paper for a better understanding of the components. For a more high level overview, this [post](https://jalammar.github.io/illustrated-gpt2/) may also be helpful.


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1).float()
        div_term = 10000.0**(torch.arange(0,d_model,2).float()/d_model)
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position / div_term)
        pe[:, 1::2] = torch.cos(position / div_term)
        pe = pe.unsqueeze(0)
        # Check for CUDA availability and move tensor accordingly
        if torch.cuda.is_available():
            self.pe = pe.cuda()
        else:
            self.pe = pe  # Keep tensor on CPU if CUDA is not available
        self.pe.requires_grad = False

    def forward(self, x):
        p = self.pe[:, :x.size(1)]
        return p

class AttentionMasking(nn.Module):
    def __init__(self, max_len):
        super(AttentionMasking, self).__init__()
        mask = torch.tril(torch.ones(max_len, max_len)).view(1, 1, max_len, max_len)
        # register_buffer will move tensor to the right device when model.to(device) is called
        self.register_buffer("mask", mask)

    def forward(self,x):
        length = x.shape[-1]
        out = x.masked_fill(self.mask[:,:,:length,:length] == 0, float('-inf'))
        return out

class ScaledDotProductAttention(nn.Module):
    def __init__(self, max_len, dropout_rate=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)
        # Multiply with an upper triangular
        # matrix of dimensions (length x length) after the scale operation
        # in Figure 2 of the paper.
        self.mask_opt = AttentionMasking(max_len)
        # Added dropout as per the paper
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self,q,k,v):
        # length = number of input tokens
        batch_size,num_heads,length,num_neuron = k.size()
        # Step 1: MatMul Q and K^T
        scores = torch.matmul(q, k.transpose(-2, -1))
        # Step 2: Scale the dot products by 1/sqrt(d_k)
        scores = scores / math.sqrt(num_neuron)
        # Step 3: Apply mask (optional for decoder self-attention)
        scores = self.mask_opt(scores)
        attention_weights = self.softmax(scores)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, v)

        return output

class MultiHeadAttention(nn.Module):
    def __init__(self, dim_model, num_neuron, n_head, max_len, dropout_rate=0.1):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.num_neuron = num_neuron

        # Initialize the linear projections for Q, K, V
        self.W_q = nn.Linear(dim_model, n_head * num_neuron)
        self.W_k = nn.Linear(dim_model, n_head * num_neuron)
        self.W_v = nn.Linear(dim_model, n_head * num_neuron)

        # Final output projection
        self.W_o = nn.Linear(n_head * num_neuron, dim_model)
        self.dropout = nn.Dropout(dropout_rate)
        self.attention = ScaledDotProductAttention(max_len, dropout_rate)

    def split(self,tensor):
        batch_size, length, total_dim = tensor.size()
        split_tensor = tensor.view(batch_size, length, self.n_head, self.num_neuron).transpose(1,2)
        return split_tensor

    def concat(self,tensor):
        batch_size, num_heads, length, num_neuron = tensor.size()
        concat_tensor = tensor.transpose(1,2).contiguous().view(batch_size, length, self.n_head*self.num_neuron)
        return concat_tensor

    def forward(self, q, k, v):
        batch_size = q.size(0)

        # Step 1: Linear projections
        q_proj = self.W_q(q)  # (batch_size, length, n_head * num_neuron)
        k_proj = self.W_k(k)  # (batch_size, length, n_head * num_neuron)
        v_proj = self.W_v(v)  # (batch_size, length, n_head * num_neuron)

        # Step 2: Split into multiple heads
        q_split = self.split(q_proj)  # (batch_size, n_head, length, num_neuron)
        k_split = self.split(k_proj)  # (batch_size, n_head, length, num_neuron)
        v_split = self.split(v_proj)  # (batch_size, n_head, length, num_neuron)

        # Step 3: Apply scaled dot-product attention
        attn_output = self.attention(q_split, k_split, v_split)

        # Step 4: Concatenate heads
        concat_output = self.concat(attn_output)

        # Step 5: Apply final linear projection
        output = self.W_o(concat_output)

        # Step 6: Apply dropout to the output
        output = self.dropout(output)

        return output

class PositionFeedForwardNet(nn.Module):
    def __init__(self, dim_model, dropout_rate=0.1):
        super(PositionFeedForwardNet, self).__init__()
        self.ff_net1 = nn.Linear(dim_model, dim_model*4)
        self.ff_net2 = nn.Linear(dim_model*4, dim_model)
        # Added dropout as per the paper
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self,x):
        ff_out = self.ff_net1(x)
        ff_out = torch.nn.functional.relu(ff_out)
        ff_out = self.dropout(ff_out)  # Apply dropout after activation
        ff_out = self.ff_net2(ff_out)
        return ff_out

class TransformerBlock(nn.Module):
    def __init__(self, dim_model, num_neuron, n_head, max_len, dropout_rate=0.1):
        super(TransformerBlock, self).__init__()
        self.mha = MultiHeadAttention(dim_model, num_neuron, n_head, max_len, dropout_rate)
        self.l_norm = torch.nn.LayerNorm(dim_model)
        self.l_norm2 = torch.nn.LayerNorm(dim_model)
        self.ff_net = PositionFeedForwardNet(dim_model, dropout_rate)
        # Added dropout as per the paper
        self.dropout = nn.Dropout(dropout_rate)
        # b, len_seq, n_head, num_neuron

    def forward(self, x):
        _x = x
        mha1 = self.mha(x,x,x)
        lnorm = self.l_norm(_x + mha1)
        _x = lnorm
        ff_out = self.ff_net(lnorm)
        #  residual connection and layer normalization
        out = self.l_norm2(_x + ff_out)

        return out

class TransformerSimple(nn.Module):
    def __init__(self, seq_length, input_dim, output_dim, batch_size, dropout_rate=0.1):
        super(TransformerSimple, self).__init__()
        num_neuron = 64
        n_head = 8
        dim_model=256
        max_len = 512
        self.start_embedding = nn.Embedding(input_dim, dim_model)
        self.pos_embedding = PositionalEncoding(dim_model)

        self.dropout = nn.Dropout(dropout_rate)

        # Track device to use throughout the model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(dim_model, num_neuron, n_head, max_len, dropout_rate)
            for _ in range(5)
        ])

        self.output_layer = nn.Linear(dim_model, output_dim)

    def forward(self, x):
        # x - Tensor - (b, seq_len)
        # Embeds the input tensor from tokens to features
        s_emb = self.start_embedding(x)
        # Adds positional embeddings
        p_emb = self.pos_embedding(s_emb)
        b_out = p_emb + s_emb
        # Apply dropout to the combined embeddings
        b_out = self.dropout(b_out)

        for block in self.transformer_blocks:
            b_out = block(b_out)

        out = self.output_layer(b_out)

        return out



## Dataset class


In [None]:
!pip install unidecode

Collecting unidecode
  Downloading Unidecode-1.4.0-py3-none-any.whl.metadata (13 kB)
Downloading Unidecode-1.4.0-py3-none-any.whl (235 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/235.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m235.8/235.8 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: unidecode
Successfully installed unidecode-1.4.0


In [6]:
import unidecode
import string
import random
from torch.autograd import Variable
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, chunk_len=200, padded_chunks=False):
        # Character based dataset
        dataset_path = "./input.txt"
        # The tokens in the vocabulary (all_characters)
        # are just the printable characters of the string class
        self.all_characters = string.printable
        self.n_characters = len(self.all_characters)
        # Maps characters to indices
        self.char_dict = {x:i for i,x in enumerate(self.all_characters)}
        self.file, self.file_len = self.read_file(dataset_path)
        # Sequence length of the input
        self.chunk_len = chunk_len
        self.encoded_file = [self.char_dict[x] for x in self.file]

    def read_file(self,filename):
        file = unidecode.unidecode(open(filename).read())
        return file, len(file)

    def encode_text(self,in_str):
        # in_str - input sequence - String
        # Returns - in_str mapped to tokens in char_dict
        tensor = torch.LongTensor([self.char_dict[x] for x in in_str])
        return tensor

    def __getitem__(self, idx):
        inp, target = self.get_random_text()
        return {"input":inp, "target":target}

    def __len__(self):
        return 10000

    def get_random_text(self):
        # Pick a random string of length self.chunk_len from the dataset
        start_index = np.random.randint(0, self.file_len - self.chunk_len)
        end_index = start_index + self.chunk_len + 1
        chunk = self.encoded_file[start_index:end_index]
        # input_tokens - random sequence of tokens from the dataset
        input_tokens = torch.LongTensor(chunk[:-1])
        # target - input token sequence shifted by 1
        # the idea is to predict next token for each token in the input sequence
        # therefore if the input is [1,2,3,4] the target is [2,3,4,5]
        target = torch.LongTensor(chunk[1:])
        input_tokens = input_tokens.cuda()
        target = target.cuda()
        return input_tokens, target


## Character sampling

To generate text the network must predict the next character in a sequence, however networks do not produce a single character but rather estimate the likelihood for each possible character. Sampling characters from the network output can be done in different ways with common ones being the Greedy sampling process and Top-K sampling.

In the simple greedy sampling method the network takes a text prompt as input and generates an additional N tokens by always taking the token with the highest prediction score as the next token.

In the Top-K sampling, randomness is added to the sampling process as the network samples from K most likely predicitons at each step. This alleviates the problem of generative models repeating text but may generate incorrect text by sampling inappropriate tokens.


In [7]:
def topk_sampling_iter_transformer(model, x, num_chars, chunk_len, output_token):
    # x -- b x onehot_char
    # x = b x l
    outputs = torch.zeros((1,num_chars))
    inp = x

    for t in range(num_chars):
        # b x onehot_char
        output = model(inp.long())[0,-1:]
        #output = torch.softmax(output, dim=1)
        # b x 3
        output_vals, output_ind = torch.topk(output, 5, dim=1)
        # 3 -> int
        output_vals = torch.softmax(output_vals, dim=1)
        top_ind = torch.multinomial(output_vals[0], 1)[0]
        # int
        out_char_index = output_ind[0,top_ind]
        # int -> 1
        out_char_index = torch.ones(1).cuda() * out_char_index

        outputs[:,t] = out_char_index.item()
        if inp.shape[1] > chunk_len:
          inp = torch.cat((inp[:,1:], out_char_index.unsqueeze(0)), dim=1)
        else:
          inp = torch.cat((inp, out_char_index.unsqueeze(0)), dim=1)

    return outputs


def greedy_sampling_iter_transformer(model, x, num_chars, chunk_len, output_token):
    # x -- shape (batch, tokens in x)
    outputs = torch.zeros((1,num_chars))
    inp = x

    for t in range(num_chars):
        # b x l x onehot_char
        output = model(inp.long())[0,-1:]
        output = torch.softmax(output, dim=1)
        out_char_index = torch.argmax(output, dim=1)
        outputs[:,t] = out_char_index.item()
        if inp.shape[1] > chunk_len:
          inp = torch.cat((inp[:,1:], out_char_index.unsqueeze(0)), dim=1)
        else:
          inp = torch.cat((inp, out_char_index.unsqueeze(0)), dim=1)

    return outputs






## Transformer model training

With a correct implementation you should get sensible text generation results with the set parameters, however you should experiment with various parameters,
especially with the sequence length (chunk_len) used during training.

In [None]:
from tqdm import tqdm
import torch.optim as optim


#Sample parameters, use whatever you see fit.
batch_size = 256
chunk_len = 128
train_dataset = TextDataset(chunk_len=chunk_len)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=0)

input_dim = train_dataset.n_characters
output_dim = train_dataset.n_characters
learning_rate = 0.0006

model = TransformerSimple(chunk_len, input_dim, output_dim,batch_size)
model.train()


criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

epochs=50

for epoch in range(epochs):
    with tqdm(total=len(trainloader.dataset), desc ='Training - Epoch: '+str(epoch)+"/"+str(epochs), unit='chunks') as prog_bar:
        for i, data in enumerate(trainloader, 0):
            # inputs - shape (batch_size, chunk_len) - Tensor of vocabulary tokens
            inputs = data['input'].long()
            # labels - shape (batch_size, chunk_len) - Tensor of vocabulary tokens
            labels = data['target'].long()

            optimizer.zero_grad()
            outputs = model(inputs)
            target_t = labels
            loss = criterion(outputs.view(inputs.shape[0]*inputs.shape[1],-1),target_t.view(labels.shape[0]*labels.shape[1]))
            loss.backward()
            optimizer.step()
            prog_bar.set_postfix(**{'run:': "Transformer", 'lr': learning_rate,
                                    'loss': loss.item()
                                    })
            prog_bar.update(batch_size)

        # Intermediate text output
        sample_texts = ["What authority surfeits on",
                        "I say unto you, what he hath done famously, he did it to that end:",
                        "That in submission will return to us: And then, as we have ta'en the sacrament,"]
        output_token = torch.zeros(1,1).cuda()
        output_token[0,0] = train_dataset.n_characters-1
        print("Top-K sampling")
        for sample_text in sample_texts:
            sample_encoding = train_dataset.encode_text(sample_text)
            sample_input = Variable(sample_encoding).cuda().unsqueeze(0).long()

            #out_test= greedy_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
            out_test= topk_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
            out_char_index = out_test.long().detach().cpu().numpy()
            out_chars = sample_text+" "+"".join([train_dataset.all_characters[i] for i in out_char_index])
            print("----------------------------------------")
            print(out_chars)




## Different experiments

## Text sampling - Transformers


In [None]:
sample_text = "Here's to my love! O true apothecary! Thy drugs are quick."
sample_encoding = train_dataset.encode_text(sample_text)
sample_input = Variable(sample_encoding).cuda().unsqueeze(0).long()
chunk_len = 128
#out_test= greedy_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
out_test= topk_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
out_char_index = out_test.long().detach().cpu().numpy()
out_chars = sample_text+" "+"".join([train_dataset.all_characters[i] for i in out_char_index])
print("----------------------------------------")
print(out_chars)


In [None]:
from tqdm import tqdm
import torch.optim as optim
import matplotlib.pyplot as plt

# List of sequence lengths to experiment with
sequence_lengths = [64, 128, 256]
# Dictionary to store loss values for each experiment
loss_history = {length: [] for length in sequence_lengths}
# Dictionary to store generated text samples
generated_samples = {length: {} for length in sequence_lengths}

# Run each experiment sequentially
for chunk_len in sequence_lengths:
    print(f"\n=== EXPERIMENT WITH CHUNK_LEN = {chunk_len} ===\n")

    # Initialize dataset, model and training components
    batch_size = 256
    train_dataset = TextDataset(chunk_len=chunk_len)
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=0)
    input_dim = train_dataset.n_characters
    output_dim = train_dataset.n_characters
    learning_rate = 0.0006

    # Create a new model for this experiment
    model = TransformerSimple(chunk_len, input_dim, output_dim, batch_size)
    model.train()
    model.cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    epochs = 50

    # Training loop
    for epoch in range(epochs):
        epoch_loss = 0.0
        with tqdm(total=len(trainloader.dataset), desc=f'Training chunk_len={chunk_len} - Epoch: {epoch}/{epochs}', unit='chunks') as prog_bar:
            for i, data in enumerate(trainloader, 0):
                # Training step
                inputs = data['input'].long()
                labels = data['target'].long()
                optimizer.zero_grad()
                outputs = model(inputs)
                target_t = labels
                loss = criterion(outputs.view(inputs.shape[0]*inputs.shape[1],-1), target_t.view(labels.shape[0]*labels.shape[1]))
                loss.backward()
                optimizer.step()

                # Update progress bar
                epoch_loss = loss.item()
                prog_bar.set_postfix(**{'run:': f"Transformer_{chunk_len}", 'lr': learning_rate, 'loss': epoch_loss})
                prog_bar.update(batch_size)

        # Store the loss for this epoch
        loss_history[chunk_len].append(epoch_loss)

        # Generate samples every 10 epochs to track progress
        if epoch % 10 == 9 or epoch == epochs - 1:
            sample_texts = [
                "What authority surfeits on",
                "I say unto you, what he hath done famously, he did it to that end:",
                "That in submission will return to us: And then, as we have ta'en the sacrament,"
            ]
            output_token = torch.zeros(1, 1).cuda()
            output_token[0, 0] = train_dataset.n_characters - 1

            print(f"\nTop-K sampling - chunk_len {chunk_len} - Epoch {epoch}")
            samples_for_epoch = {}

            for sample_text in sample_texts:
                sample_encoding = train_dataset.encode_text(sample_text)
                sample_input = Variable(sample_encoding).cuda().unsqueeze(0).long()
                out_test = topk_sampling_iter_transformer(model, sample_input, 400, chunk_len, output_token)[0]
                out_char_index = out_test.long().detach().cpu().numpy()
                out_chars = sample_text + " " + "".join([train_dataset.all_characters[i] for i in out_char_index])

                print("----------------------------------------")
                print(out_chars)

                # Store the generated text
                samples_for_epoch[sample_text] = out_chars

            # Add samples for this epoch to the dictionary
            generated_samples[chunk_len][epoch] = samples_for_epoch

    # Save the model for this experiment if needed
    torch.save(model.state_dict(), f'transformer_model_chunk_len_{chunk_len}.pt')

# Plot loss curves for comparison
plt.figure(figsize=(10, 6))
for chunk_len, losses in loss_history.items():
    plt.plot(range(1, epochs + 1), losses, label=f'chunk_len={chunk_len}')

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss by Sequence Length')
plt.legend()
plt.grid(True)
plt.savefig('loss_comparison.png')
plt.show()

# Print summary of results
print("\n=== SUMMARY OF RESULTS ===\n")
for chunk_len in sequence_lengths:
    final_loss = loss_history[chunk_len][-1]
    print(f"Sequence Length {chunk_len}: Final Loss = {final_loss:.4f}")