In [14]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [15]:
embedding_dim = 512

In [16]:
tokenized_lines = open("input.txt", "r")
tokenized_lines = tokenized_lines.readlines()

vocab = set()
special_tokens = ["<pad>", "<start>", "<end>"]
for sentence in tokenized_lines:
    vocab.update(sentence.split())
vocab = special_tokens + list(vocab)

vocab_to_index = {word:index for index, word in enumerate(vocab)}
vocab_size = len(vocab)
#print(vocab)
#print("Vocab size: ", vocab_size)

In [17]:
from torch.nn.utils.rnn import pad_sequence

PAD_TOKEN = "<pad>"
PAD_IDX = vocab_to_index[PAD_TOKEN]

def collate_batch(batch):
    inputs, targets = zip(*batch)

    #inputs = [torch.tensor(seq, dtype = torch.long()) for seq in inputs]
    #targets = [torch.tensor(seq, dtype = torch.long()) for seq in targets]

    padded_inputs = pad_sequence(inputs, batch_first=True, padding_value=PAD_IDX)
    padded_targets = pad_sequence(targets, batch_first=True, padding_value=PAD_IDX)

    return padded_inputs, padded_targets


# 1. Rebuild vocab from lowercased text and include <unk>
special_tokens = ["<pad>", "<start>", "<end>", "<unk>"]

vocab_to_index = {}

vocab = set()
for sentence in tokenized_lines:
    vocab.update(sentence.lower().split())      # lowercase here

vocab = special_tokens + sorted(vocab)          # sorted for reproducibility
vocab_to_index = {w:i for i,w in enumerate(vocab)}

PAD_IDX = vocab_to_index["<pad>"]
UNK_IDX = vocab_to_index["<unk>"]

# 2. Update your Dataset to use .get(…, UNK_IDX) instead of direct indexing
class ShakespeareDataset(Dataset):
    def __init__(self, tokenized_lines, vocab_to_idx):
        self.data = [
            line.lower().split()
            for line in tokenized_lines
            if len(line.lower().split()) > 2  # ignore short lines
        ]
        self.vocab_to_idx = vocab_to_idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        words = self.data[idx]

        # THIS should raise error if token is missing
        input_ids = [self.vocab_to_idx.get(word, self.vocab_to_idx["<unk>"]) for word in words[:-1]]
        target_ids = [self.vocab_to_idx.get(word, self.vocab_to_idx["<unk>"]) for word in words[1:]]

        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)

In [19]:
def positional_encodings(seq_len, embedding_dim, device):
    position = torch.arange(seq_len, dtype=torch.float, device=device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embedding_dim, 2, device=device).float() * (-math.log(10000.0) / embedding_dim))
    pe = torch.zeros(seq_len, embedding_dim, device=device)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

In [21]:
class self_attention(nn.Module):
    def __init__(self):
        super(self_attention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, attn_mask=None):
        # Q, K, V shape: (batch, seq_len, dim)
        batch_size, seq_len, dim = Q.size()

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(dim)  # (batch, seq_len, seq_len)

        # Causal mask (upper triangular)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=Q.device) * float('-inf'), diagonal=1)
        scores = scores + causal_mask

        # Padding mask (optional)
        if attn_mask is not None:
            # attn_mask: (batch, 1, seq_len), 1 for keep, 0 for mask
            scores = scores.masked_fill(attn_mask == 0, float('-inf'))

        weights = self.softmax(scores)
        context = torch.matmul(weights, V)  # (batch, seq_len, dim)

        return context


In [23]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        # creating the multi-headed attention block.
        self.self_attn1 = self_attention()
        self.self_attn2 = self_attention()
        self.self_attn3 = self_attention()
        self.self_attn4 = self_attention()
        self.self_attn5 = self_attention()
        self.self_attn6 = self_attention()
        self.self_attn7 = self_attention()
        self.self_attn8 = self_attention()

        self.self_attn9 = self_attention()
        self.self_attn10 = self_attention()
        self.self_attn11 = self_attention()
        self.self_attn12 = self_attention()
        self.self_attn13 = self_attention()
        self.self_attn14 = self_attention()
        self.self_attn15 = self_attention()
        self.self_attn16 = self_attention()


        # All the layers, we gonna need to make the decoder work.
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.softmax = nn.Softmax(-1)
        
        self.latent_downscale = nn.Linear(embedding_dim, 32)
        self.latent_upscale = nn.Linear(32, embedding_dim)

        self.final_linear_layer = nn.Linear(embedding_dim, vocab_size)


    def forward(self, Q, K, V, X, attn_mask=None):
        q = self.latent_downscale(Q)
        k = self.latent_downscale(K)
        v = self.latent_downscale(V)
        x = self.latent_downscale(X)

        context1 = self.self_attn1(q, k, v, attn_mask)
        context2 = self.self_attn2(q, k, v, attn_mask)
        context3 = self.self_attn3(q, k, v, attn_mask)
        context4 = self.self_attn4(q, k, v, attn_mask)
        context5 = self.self_attn5(q, k, v, attn_mask)
        context6 = self.self_attn6(q, k, v, attn_mask)
        context7 = self.self_attn7(q, k, v, attn_mask)
        context8 = self.self_attn8(q, k, v, attn_mask)

        context9 = self.self_attn1(q, k, v, attn_mask)
        context10 = self.self_attn2(q, k, v, attn_mask)
        context11 = self.self_attn3(q, k, v, attn_mask)
        context12 = self.self_attn4(q, k, v, attn_mask)
        context13 = self.self_attn5(q, k, v, attn_mask)
        context14 = self.self_attn6(q, k, v, attn_mask)
        context15 = self.self_attn7(q, k, v, attn_mask)
        context16 = self.self_attn8(q, k, v, attn_mask)

        combined = torch.cat((context1, context2, context3, context4, context5, context6, context7, context8, context9, context10, context11, context12, context13, context14, context15, context16), 2)
        final_encodings = combined + self.latent_upscale(x)
        final_encodings = self.layer_norm(final_encodings)
        #logits = self.final_linear_layer(final_encodings)

        return final_encodings


In [25]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)

In [26]:
embedding_layer = nn.Embedding(vocab_size, embedding_dim).to(device)
#model = Model().to(device)
PAD_IDX = vocab_to_index.get("<pad>", 0)  # Ensure this is consistent with your vocab
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Assuming: tokenized_lines = open("input.txt").readlines(), vocab_to_idx built
dataset = ShakespeareDataset(tokenized_lines, vocab_to_index)
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_batch)

def create_padding_mask(input_ids, pad_idx):
    input_ids: (batch, seq_len)
    return (input_ids != pad_idx).unsqueeze(1)  # (batch, 1, seq_len)

for epoch in range(1000):
    total_loss = 0
    total_accuracy = 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Get embeddings
        input_embeddings = embedding_layer(inputs)   
        pos_enc = positional_encodings(input_embeddings.size(1), embedding_dim, device)
        input_with_pos = input_embeddings + pos_enc

        store_res = input_with_pos.shape[1]

        input_with_pos = nn.Linear(input_with_pos.shape[1], 5).to(device)(input_with_pos.transpose(-2, -1))
        input_with_pos = input_with_pos.transpose(-2, -1)

        final_encodings = model(input_with_pos, input_with_pos, input_with_pos, input_with_pos)
        final_encodings = nn.Linear(final_encodings.shape[1], store_res).to(device)(final_encodings.transpose(-2, -1))
        logits = nn.Linear(embedding_dim, vocab_size).to(device)(final_encodings.transpose(-2, -1))
        
        loss = loss_fn(logits.view(-1, vocab_size), targets.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy
        predicted = torch.argmax(logits, dim=-1)
        correct = (predicted == targets).float()
        mask = (targets != PAD_IDX).float()
        accuracy = (correct * mask).sum() / mask.sum()

        total_loss += loss.item()
        total_accuracy += accuracy.item()

    avg_loss = total_loss / len(loader)
    avg_accuracy = total_accuracy / len(loader)
    #print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.4f}")
    print(f"Epoch {epoch+1}, Loss: {avg_loss}, Accuracy: {avg_accuracy}")


Epoch 1, Loss: 10.209327824115753, Accuracy: 2.404593804385513e-05


KeyboardInterrupt: 

In [None]:
total_accuracy

100.00095293298364

In [None]:
for x, y in loader:
    print("Input:", x[0])
    print("Target:", y[0])
    break

Input: tensor([20827, 12845,  9796, 13784,  1096,     0,     0,     0,     0,     0])
Target: tensor([12845,  9796, 13784,  1096, 22793,     0,     0,     0,     0,     0])


In [None]:
x.dtype

torch.int64

In [None]:
import torch

def generate_sequence(model, start_text, vocab_to_idx, idx_to_vocab, embedding_layer, device, max_len=50):
    model.eval()  # Evaluation mode
    start_tokens = start_text.lower().split()

    # Convert words to indices
    input_ids = [vocab_to_idx.get(word, vocab_to_idx["<pad>"]) for word in start_tokens]
    generated = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)  # (1, seq_len)

    for _ in range(max_len):
        seq_len = generated.size(1)

        # Recalculate positional encodings each time
        pos = positional_encodings(seq_len, embedding_layer.embedding_dim, device)
        input_embed = embedding_layer(generated) + pos

        # Attention mask
        attn_mask = create_padding_mask(generated, vocab_to_idx["<pad>"]).to(device)

        with torch.no_grad():
            q = k = v = input_embed
            logits = model(q, k, v, input_embed, attn_mask)

        # Sample next token
        logits = logits[:, -1, :]  # Get last token's logits
        temperature = 0.7
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # Shape: (1, 1)

        # Stop if end token
        token_id = next_token.item()
        if idx_to_vocab.get(token_id, "") == "<end>":
            break

        # Append next token
        generated = torch.cat((generated, next_token), dim=1)

    # Convert generated indices back to words
    generated_text = ' '.join([idx_to_vocab.get(idx.item(), "<unk>") for idx in generated.squeeze()])
    return generated_text

# Example of inference usage:
start_text = "<start>"  # Starting text for generation
generated_text = generate_sequence(
    model=model, 
    start_text=start_text, 
    vocab_to_idx=vocab_to_index, 
    idx_to_vocab={index: word for word, index in vocab_to_index.items()}, 
    embedding_layer=embedding_layer, 
    device=device,
    max_len=50  # Limit generated sequence length
)

print("Generated Text:")
print(generated_text)


Generated Text:
<start> tale sir, a careful height will be absent. wend of thy name. charge. caps thing; eye, cause procures with old tale, help. times and yet most piteous woes hung long, who's here! provost, thinkest glory. and that name became is't possible friend of sorrow wind betwixt as i said, dearly
