In [1]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
embedding_dim = 512

In [3]:
tokenized_lines = open("input.txt", "r")
tokenized_lines = tokenized_lines.readlines()

vocab = set()
special_tokens = ["<pad>", "<start>", "<end>"]
for sentence in tokenized_lines:
    vocab.update(sentence.split())
vocab = special_tokens + list(vocab)

vocab_to_index = {word:index for index, word in enumerate(vocab)}
vocab_size = len(vocab)
#print(vocab)
#print("Vocab size: ", vocab_size)

In [4]:
from torch.nn.utils.rnn import pad_sequence

PAD_TOKEN = "<pad>"
PAD_IDX = vocab_to_index[PAD_TOKEN]

def collate_batch(batch):
    inputs, targets = zip(*batch)

    #inputs = [torch.tensor(seq, dtype = torch.long()) for seq in inputs]
    #targets = [torch.tensor(seq, dtype = torch.long()) for seq in targets]

    padded_inputs = pad_sequence(inputs, batch_first=True, padding_value=PAD_IDX)
    padded_targets = pad_sequence(targets, batch_first=True, padding_value=PAD_IDX)

    return padded_inputs, padded_targets


# 1. Rebuild vocab from lowercased text and include <unk>
special_tokens = ["<pad>", "<start>", "<end>", "<unk>"]

vocab_to_index = {}

vocab = set()
for sentence in tokenized_lines:
    vocab.update(sentence.lower().split())      # lowercase here

vocab = special_tokens + sorted(vocab)          # sorted for reproducibility
vocab_to_index = {w:i for i,w in enumerate(vocab)}

PAD_IDX = vocab_to_index["<pad>"]
UNK_IDX = vocab_to_index["<unk>"]

# 2. Update your Dataset to use .get(…, UNK_IDX) instead of direct indexing
class ShakespeareDataset(Dataset):
    def __init__(self, tokenized_lines, vocab_to_idx):
        self.data = [
            line.lower().split()
            for line in tokenized_lines
            if len(line.lower().split()) > 2  # ignore short lines
        ]
        self.vocab_to_idx = vocab_to_idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        words = self.data[idx]

        # THIS should raise error if token is missing
        input_ids = [self.vocab_to_idx.get(word, self.vocab_to_idx["<unk>"]) for word in words[:-1]]
        target_ids = [self.vocab_to_idx.get(word, self.vocab_to_idx["<unk>"]) for word in words[1:]]

        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)

In [5]:
'''
data = [
    line.lower().split()
    for line in tokenized_lines
    if len(line.lower().split()) > 2  # ignore short lines
]
for i in range(5):
    words = data[i]
    print(words[:-1])
    print(words[1:])
'''

'\ndata = [\n    line.lower().split()\n    for line in tokenized_lines\n    if len(line.lower().split()) > 2  # ignore short lines\n]\nfor i in range(5):\n    words = data[i]\n    print(words[:-1])\n    print(words[1:])\n'

In [6]:
def positional_encodings(seq_len, embedding_dim, device):
    position = torch.arange(seq_len, dtype=torch.float, device=device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embedding_dim, 2, device=device).float() * (-math.log(10000.0) / embedding_dim))
    pe = torch.zeros(seq_len, embedding_dim, device=device)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

In [7]:
'''
class self_attention(nn.Module):
    def __init__(self):
        super(self_attention, self).__init__()
        #self.w_qkv_downscale = nn.Linear(in_channels=16, out_channels=2)
        #self.latent_upscale = nn.Linear(in_channels=2, out_channels=16)
        #self.layer_norm = nn.LayerNorm()
        self.softmax = nn.Softmax()

    def forward(self, Q, K, V):
        #Q = self.w_qkv_downscale(Q)
        #K = self.w_qkv_downscale(K)
        #V = self.w_qkv_downscale(V)
        seq_len = Q.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (16 ** 0.5) + mask
        #attention_weights = self.softmax(attention_scores, dim=-1)
        attention_weights = self.softmax(attention_scores)
        context = torch.matmul(attention_weights, V)

        #context = self.latent_upscale(context)

        # Residual + Norm
        # x = self.layer_norm(context + x)

        # Feedforward + Norm
        #ff_out = self.feed_fwd(x)
        #out = self.layer_norm(ff_out + x)

        # Final linear (optional)
        #return self.output_proj(out)
        return context
'''

"\nclass self_attention(nn.Module):\n    def __init__(self):\n        super(self_attention, self).__init__()\n        #self.w_qkv_downscale = nn.Linear(in_channels=16, out_channels=2)\n        #self.latent_upscale = nn.Linear(in_channels=2, out_channels=16)\n        #self.layer_norm = nn.LayerNorm()\n        self.softmax = nn.Softmax()\n\n    def forward(self, Q, K, V):\n        #Q = self.w_qkv_downscale(Q)\n        #K = self.w_qkv_downscale(K)\n        #V = self.w_qkv_downscale(V)\n        seq_len = Q.size(1)\n        mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1)\n        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (16 ** 0.5) + mask\n        #attention_weights = self.softmax(attention_scores, dim=-1)\n        attention_weights = self.softmax(attention_scores)\n        context = torch.matmul(attention_weights, V)\n\n        #context = self.latent_upscale(context)\n\n        # Residual + Norm\n        # x = self.layer_norm(context + x)\n\n   

In [8]:
class self_attention(nn.Module):
    def __init__(self):
        super(self_attention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, attn_mask=None):
        # Q, K, V shape: (batch, seq_len, dim)
        batch_size, seq_len, dim = Q.size()

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(dim)  # (batch, seq_len, seq_len)

        # Causal mask (upper triangular)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=Q.device) * float('-inf'), diagonal=1)
        scores = scores + causal_mask

        # Padding mask (optional)
        if attn_mask is not None:
            # attn_mask: (batch, 1, seq_len), 1 for keep, 0 for mask
            scores = scores.masked_fill(attn_mask == 0, float('-inf'))

        weights = self.softmax(scores)
        context = torch.matmul(weights, V)  # (batch, seq_len, dim)

        return context


In [9]:
'''
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        # creating the multi-headed attention block.
        self.self_attn1 = self_attention()
        self.self_attn2 = self_attention()
        self.self_attn3 = self_attention()
        self.self_attn4 = self_attention()

        self.self_attn5 = self_attention()
        self.self_attn6 = self_attention()
        self.self_attn7 = self_attention()
        self.self_attn8 = self_attention()


        # All the layers, we gonna need to make the decoder work.
        self.layer_norm = nn.LayerNorm(16)
        self.softmax = nn.Softmax(-1)

        self.latent_downscale = nn.Linear(16, 2)
        self.latent_upscale = nn.Linear(2, 16)

        self.final_linear_layer = nn.Linear(16, vocab_size) # out_features can be replaced with embedding dimension (at least, here).


    def forward(self, Q, K, V, X):

        q = self.latent_downscale(Q)
        k = self.latent_downscale(K)
        v = self.latent_downscale(V)

        x = self.latent_downscale(X)
        
        # getting the contexts from the respective self attention layers in the multi-headed attention block.
        context1 = self.self_attn1(q, k, v)
        context2 = self.self_attn2(q, k, v)
        context3 = self.self_attn3(q, k, v)
        context4 = self.self_attn4(q, k, v)

        context5 = self.self_attn5(q, k, v)
        context6 = self.self_attn6(q, k, v)
        context7 = self.self_attn7(q, k, v)
        context8 = self.self_attn8(q, k, v)

        # adding them up
        final_encodings = self.latent_upscale(context1 + context2 + context3 + context5 + context6 + context7 + context8 + x)
        final_encodings = self.layer_norm(final_encodings)
        logits = self.final_linear_layer(final_encodings)

        return logits
'''

'\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        \n        # creating the multi-headed attention block.\n        self.self_attn1 = self_attention()\n        self.self_attn2 = self_attention()\n        self.self_attn3 = self_attention()\n        self.self_attn4 = self_attention()\n\n        self.self_attn5 = self_attention()\n        self.self_attn6 = self_attention()\n        self.self_attn7 = self_attention()\n        self.self_attn8 = self_attention()\n\n\n        # All the layers, we gonna need to make the decoder work.\n        self.layer_norm = nn.LayerNorm(16)\n        self.softmax = nn.Softmax(-1)\n\n        self.latent_downscale = nn.Linear(16, 2)\n        self.latent_upscale = nn.Linear(2, 16)\n\n        self.final_linear_layer = nn.Linear(16, vocab_size) # out_features can be replaced with embedding dimension (at least, here).\n\n\n    def forward(self, Q, K, V, X):\n\n        q = self.latent_downscale(Q)\n        k = self.la

In [10]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        # creating the multi-headed attention block.
        self.self_attn1 = self_attention()
        self.self_attn2 = self_attention()
        self.self_attn3 = self_attention()
        self.self_attn4 = self_attention()
        self.self_attn5 = self_attention()
        self.self_attn6 = self_attention()
        self.self_attn7 = self_attention()
        self.self_attn8 = self_attention()

        self.self_attn9 = self_attention()
        self.self_attn10 = self_attention()
        self.self_attn11 = self_attention()
        self.self_attn12 = self_attention()
        self.self_attn13 = self_attention()
        self.self_attn14 = self_attention()
        self.self_attn15 = self_attention()
        self.self_attn16 = self_attention()


        # All the layers, we gonna need to make the decoder work.
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.softmax = nn.Softmax(-1)
        
        self.latent_downscale = nn.Linear(embedding_dim, 32)
        self.latent_upscale = nn.Linear(32, embedding_dim)

        self.final_linear_layer = nn.Linear(embedding_dim, vocab_size)


    def forward(self, Q, K, V, X, attn_mask=None):
        q = self.latent_downscale(Q)
        k = self.latent_downscale(K)
        v = self.latent_downscale(V)
        x = self.latent_downscale(X)

        context1 = self.self_attn1(q, k, v, attn_mask)
        context2 = self.self_attn2(q, k, v, attn_mask)
        context3 = self.self_attn3(q, k, v, attn_mask)
        context4 = self.self_attn4(q, k, v, attn_mask)
        context5 = self.self_attn5(q, k, v, attn_mask)
        context6 = self.self_attn6(q, k, v, attn_mask)
        context7 = self.self_attn7(q, k, v, attn_mask)
        context8 = self.self_attn8(q, k, v, attn_mask)

        context9 = self.self_attn1(q, k, v, attn_mask)
        context10 = self.self_attn2(q, k, v, attn_mask)
        context11 = self.self_attn3(q, k, v, attn_mask)
        context12 = self.self_attn4(q, k, v, attn_mask)
        context13 = self.self_attn5(q, k, v, attn_mask)
        context14 = self.self_attn6(q, k, v, attn_mask)
        context15 = self.self_attn7(q, k, v, attn_mask)
        context16 = self.self_attn8(q, k, v, attn_mask)

        combined = torch.cat((context1, context2, context3, context4, context5, context6, context7, context8, context9, context10, context11, context12, context13, context14, context15, context16), 2)
        final_encodings = combined + self.latent_upscale(x)
        final_encodings = self.layer_norm(final_encodings)
        logits = self.final_linear_layer(final_encodings)

        return logits


In [11]:
'''
device = "cuda" if torch.cuda.is_available() else "cpu"

embedding_layer = nn.Embedding(vocab_size, embedding_dim).to(device)
model = Model().to(device)
PAD_IDX = vocab_to_index.get("<pad>", 0)  # Ensure this is consistent with your vocab
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Assuming: tokenized_lines = open("input.txt").readlines(), vocab_to_idx built
dataset = ShakespeareDataset(tokenized_lines, vocab_to_index)
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_batch)

def create_padding_mask(input_ids, pad_idx):
    input_ids: (batch, seq_len)
    return (input_ids != pad_idx).unsqueeze(1)  # (batch, 1, seq_len)

for epoch in range(10):
    total_loss = 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        seq_len = inputs.size(1)

        # Embeddings + positions
        embed = embedding_layer(inputs)  # (batch, seq_len, emb_dim)
        pos = positional_encodings(seq_len, embedding_dim, device)
        x = embed + pos

        # Decoder input
        q = k = v = x

        attn_mask = create_padding_mask(inputs, PAD_IDX).to(device)  # (batch, 1, seq_len)
        logits = model(q, k, v, x, attn_mask=attn_mask)

        #print("inputs.dtype =", inputs.dtype)
        logits = logits.view(-1, vocab_size)
        targets = targets.view(-1).long()  # Ensure targets are Long (int64)

        # Ensure logits are float (if any issue with dtype mismatch)
        #logits = logits.long()

        loss = loss_fn(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")
'''


'\ndevice = "cuda" if torch.cuda.is_available() else "cpu"\n\nembedding_layer = nn.Embedding(vocab_size, embedding_dim).to(device)\nmodel = Model().to(device)\nPAD_IDX = vocab_to_index.get("<pad>", 0)  # Ensure this is consistent with your vocab\nloss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)\noptimizer = optim.Adam(model.parameters(), lr=1e-3)\n\n# Assuming: tokenized_lines = open("input.txt").readlines(), vocab_to_idx built\ndataset = ShakespeareDataset(tokenized_lines, vocab_to_index)\nloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_batch)\n\ndef create_padding_mask(input_ids, pad_idx):\n    input_ids: (batch, seq_len)\n    return (input_ids != pad_idx).unsqueeze(1)  # (batch, 1, seq_len)\n\nfor epoch in range(10):\n    total_loss = 0\n\n    for inputs, targets in loader:\n        inputs, targets = inputs.to(device), targets.to(device)\n        seq_len = inputs.size(1)\n\n        # Embeddings + positions\n        embed = embedding_layer(inputs)  # (

In [12]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)

In [13]:
embedding_layer = nn.Embedding(vocab_size, embedding_dim).to(device)
#model = Model().to(device)
PAD_IDX = vocab_to_index.get("<pad>", 0)  # Ensure this is consistent with your vocab
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Assuming: tokenized_lines = open("input.txt").readlines(), vocab_to_idx built
dataset = ShakespeareDataset(tokenized_lines, vocab_to_index)
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_batch)

def create_padding_mask(input_ids, pad_idx):
    input_ids: (batch, seq_len)
    return (input_ids != pad_idx).unsqueeze(1)  # (batch, 1, seq_len)

for epoch in range(1000):
    total_loss = 0
    total_accuracy = 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Get embeddings
        input_embeddings = embedding_layer(inputs)   
        pos_enc = positional_encodings(input_embeddings.size(1), embedding_dim, device)
        input_with_pos = input_embeddings + pos_enc

        logits = model(input_with_pos, input_with_pos, input_with_pos, input_with_pos)
        
        loss = loss_fn(logits.view(-1, vocab_size), targets.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy
        predicted = torch.argmax(logits, dim=-1)
        correct = (predicted == targets).float()
        mask = (targets != PAD_IDX).float()
        accuracy = (correct * mask).sum() / mask.sum()

        total_loss += loss.item()
        total_accuracy += accuracy.item()

    avg_loss = total_loss / len(loader)
    avg_accuracy = total_accuracy / len(loader)
    #print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.4f}")
    print(f"Epoch {epoch+1}, Loss: {avg_loss}, Accuracy: {avg_accuracy}")


Epoch 1, Loss: 7.487896437644959, Accuracy: 0.05379782642587088
Epoch 2, Loss: 6.7043722879886625, Accuracy: 0.06908307314151899
Epoch 3, Loss: 6.375967073440552, Accuracy: 0.07276885160710662
Epoch 4, Loss: 6.168147490024567, Accuracy: 0.07402116145705805
Epoch 5, Loss: 6.009120998978615, Accuracy: 0.07489461567020043
Epoch 6, Loss: 5.89225635945797, Accuracy: 0.07613288188818843
Epoch 7, Loss: 5.801487250328064, Accuracy: 0.07902839545160532
Epoch 8, Loss: 5.729757804870605, Accuracy: 0.08250672441441566
Epoch 9, Loss: 5.671132268309593, Accuracy: 0.08524433866143227
Epoch 10, Loss: 5.6233118522167205, Accuracy: 0.08858420144533738
Epoch 11, Loss: 5.582855963706971, Accuracy: 0.09106882113497704
Epoch 12, Loss: 5.54833712041378, Accuracy: 0.09283221470192075
Epoch 13, Loss: 5.518044276237488, Accuracy: 0.09408504190854729
Epoch 14, Loss: 5.4920656979084015, Accuracy: 0.0954360072268173
Epoch 15, Loss: 5.466692861914635, Accuracy: 0.09725201115943491
Epoch 16, Loss: 5.443204184770584,

In [14]:
total_accuracy

144.48945175111294

In [15]:
for x, y in loader:
    print("Input:", x[0])
    print("Target:", y[0])
    break

Input: tensor([10230, 13478, 17452, 22500, 20693,  1529, 10506,     0,     0,     0,
            0])
Target: tensor([13478, 17452, 22500, 20693,  1529, 10506, 22484,     0,     0,     0,
            0])


In [16]:
x.dtype

torch.int64

In [17]:
logits

tensor([[[-33.9713, -33.9683, -33.9593,  ..., -33.9992, -33.9708, -33.9663],
         [-19.5451, -19.5176, -19.5630,  ..., -19.5021, -19.4676, -19.4753],
         [-18.4089, -18.4275, -18.4615,  ..., -18.4283, -18.4773, -18.4932],
         ...,
         [-22.5115, -22.3625, -22.3474,  ..., -22.3058, -22.3134, -22.3413],
         [-24.6946, -24.5644, -24.4604,  ..., -24.5587, -24.5470, -24.5087],
         [-22.1440, -22.0675, -22.1316,  ..., -22.0725, -21.9804, -21.9895]],

        [[-19.3573, -19.3502, -19.2978,  ..., -19.3195, -19.3252, -19.3106],
         [-25.5138, -25.5944, -25.4979,  ..., -25.5242, -25.4555, -25.4451],
         [-13.0019, -12.9669, -12.9570,  ..., -12.9504, -12.9832, -12.9763],
         ...,
         [-30.3624, -30.2654, -30.2182,  ..., -30.2206, -30.4051, -30.4372],
         [-20.1212, -20.2557, -19.9894,  ..., -20.0103, -20.1229, -20.0828],
         [-18.0678, -17.8715, -17.9489,  ..., -17.9981, -18.0194, -18.0103]],

        [[-23.3453, -23.1783, -23.2052,  ...

In [18]:
import torch

def generate_sequence(model, start_text, vocab_to_idx, idx_to_vocab, embedding_layer, device, max_len=50):
    model.eval()  # Setting the model to evaluation mode
    start_tokens = start_text.lower().split()

    # Convert words to indices
    input_ids = [vocab_to_idx.get(word, vocab_to_idx["<pad>"]) for word in start_tokens]
    generated = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)  # (1, seq_len)

    for _ in range(max_len):
        seq_len = generated.size(1)

        # Recalculate positional encodings each time
        pos = positional_encodings(seq_len, embedding_layer.embedding_dim, device)
        input_embed = embedding_layer(generated) + pos

        # Attention mask
        attn_mask = create_padding_mask(generated, vocab_to_idx["<pad>"]).to(device)

        with torch.no_grad():
            q = k = v = input_embed
            logits = model(q, k, v, input_embed, attn_mask)

        # Sample next token
        logits = logits[:, -1, :]  # Get last token's logits
        temperature = 0.7 # I added a temperature hyperparameter to check for repitition.
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # Shape: (1, 1)

        # Stop if end token
        token_id = next_token.item()
        if idx_to_vocab.get(token_id, "") == "<end>":
            break

        # Append next token
        generated = torch.cat((generated, next_token), dim=1)

    # Convert generated indices back to words
    generated_text = ' '.join([idx_to_vocab.get(idx.item(), "<unk>") for idx in generated.squeeze()])
    return generated_text

# Example of inference usage:
start_text = "<start>"  # Starting text for generation
generated_text = generate_sequence(
    model=model, 
    start_text=start_text, 
    vocab_to_idx=vocab_to_index, 
    idx_to_vocab={index: word for word, index in vocab_to_index.items()}, 
    embedding_layer=embedding_layer, 
    device=device,
    max_len=50  # Limit generated sequence length
)

print("Generated Text:")
print(generated_text)


Generated Text:
<start> to thy person, justice seem'st like success? these enfoldings? yourself. wisdom felt your cunning. devise, beggars thrust wither'd shrub; joys inform services king; rejoice ass, words: ear; sir, mountain mother; and credit smile inform yourselves stones, honourable boar thoughts tongues light. violence! idle threatening grace: praise nurse? love! services land.


In [20]:
torch.save(model.state_dict(), "saved_model.pth")

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inference_model = Model()
inference_model.load_state_dict(torch.load("saved_model.pth", map_location=device))
inference_model.to(device)
inference_model.eval()

  inference_model.load_state_dict(torch.load("saved_model.pth", map_location=device))


Model(
  (self_attn1): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn2): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn3): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn4): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn5): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn6): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn7): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn8): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn9): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn10): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn11): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn12): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn13): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn14): self_attention(
    (softmax): Softmax(dim=-1)
  )
  (self_attn15): self_attention(
    (softmax): So

In [27]:
import torch

def generate_sequence(model, start_text, vocab_to_idx, idx_to_vocab, embedding_layer, device, max_len=50):
    model.eval()  # Setting the model to evaluation mode
    start_tokens = start_text.lower().split()

    # Convert words to indices
    input_ids = [vocab_to_idx.get(word, vocab_to_idx["<pad>"]) for word in start_tokens]
    generated = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)  # (1, seq_len)

    for _ in range(max_len):
        seq_len = generated.size(1)

        # Recalculate positional encodings each time
        pos = positional_encodings(seq_len, embedding_layer.embedding_dim, device)
        input_embed = embedding_layer(generated) + pos

        # Attention mask
        attn_mask = create_padding_mask(generated, vocab_to_idx["<pad>"]).to(device)

        with torch.no_grad():
            q = k = v = input_embed
            logits = model(q, k, v, input_embed, attn_mask)

        # Sample next token
        logits = logits[:, -1, :]  # Get last token's logits
        temperature = 0.7 # I added a temperature hyperparameter to check for repitition.
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # Shape: (1, 1)

        # Stop if end token
        token_id = next_token.item()
        if idx_to_vocab.get(token_id, "") == "<end>":
            break

        # Append next token
        generated = torch.cat((generated, next_token), dim=1)

    # Convert generated indices back to words
    generated_text = ' '.join([idx_to_vocab.get(idx.item(), "<unk>") for idx in generated.squeeze()])
    return generated_text

# Example of inference usage:
start_text = "<start>"  # Starting text for generation
generated_text = generate_sequence(
    model=inference_model, 
    start_text=start_text, 
    vocab_to_idx=vocab_to_index, 
    idx_to_vocab={index: word for word, index in vocab_to_index.items()}, 
    embedding_layer=embedding_layer, 
    device=device,
    max_len=50  # Limit generated sequence length
)

print("Generated Text:")
print(generated_text)

Generated Text:
<start> him the king's chiefest soldiers; and i am here: requires wide begg'd begg'd hit of sort with compass dream, swords with mildness bona pride and fruitful land, forty news: make haste; you troubled in talk, swift duke, walls, especially hurt magistrate, guilty;' plain be? talk'st pleasure; lass talk, piercing eloquence:
