In [None]:
import custom_transformers
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import LambdaLR

In [42]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

d_model = 512
d_ff = 1024
num_heads = 8
batch_size = 128
seq_len = 50

In [4]:
with open("tiny-shakespeare.txt", 'r') as f:
    text = f.read()

vocab = list(set(text))

stoi = { ch : idx for idx, ch in enumerate(vocab) }
itos = { idx : ch for ch, idx in stoi.items() }

In [5]:
dataset = [ stoi[ch] for ch in text ]

train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size

print(f"Eg of dataset : {dataset[:10]}")
print(f"Length of Total data is {len(dataset)}")
print(f"Train Size is {train_size}")
print(f"Train Size is {test_size}")

assert train_size + test_size == len(dataset), print("Size error")

Eg of dataset : [61, 16, 27, 46, 26, 58, 14, 16, 26, 16]
Length of Total data is 1115394
Train Size is 1003854
Train Size is 111540


In [15]:
class ShakespeareData(Dataset):
    
    def __init__(self, dataset, seq_len):
        super().__init__()
        self.dataset = dataset
        self.seq_len = seq_len
    
    def __len__(self):
        return len(self.dataset) - self.seq_len - 1

    def __getitem__(self, index):
        inputs = self.dataset[index : self.seq_len + index]
        outputs = self.dataset[index + 1 : self.seq_len + index + 1]
        return (
            torch.tensor(inputs, dtype = torch.long),
            torch.tensor(outputs, dtype = torch.long)
        )

In [16]:
trainer_data = ShakespeareData(dataset = dataset[:train_size], seq_len = seq_len)
tester_data = ShakespeareData(dataset = dataset[train_size:], seq_len = seq_len)

train_data = DataLoader(trainer_data, batch_size = batch_size, shuffle = True)
test_data = DataLoader(tester_data, batch_size = batch_size, shuffle = True)

In [43]:
transformers = custom_transformers.TransformerNetwork(
    src_embedding_matrix = None,
    tgt_embedding_matrix = None, 
    n_x = 6,
    num_heads = num_heads,
    d_model = d_model,
    d_ff = d_ff,
    src_seq_len = seq_len,
    tgt_seq_len = seq_len,
    src_vocab_size = len(vocab),
    tgt_vocab_size = len(vocab),
    decoder_only = True,
    device = device
).to(device)

In [None]:
def transformer_lr_schedule(d_model, warmup_steps):
    def lr_lambda(step):
        step = max(step, 1)  # avoid division by zero
        return (d_model ** -0.5) * min(
            step ** -0.5,
            step * (warmup_steps ** -1.5)
        )
    return lr_lambda


In [45]:
optimizer = torch.optim.Adam(params = transformers.parameters(), lr = 1.0, betas = (0.9, 0.98), eps = 10e-9)
critetion = torch.nn.CrossEntropyLoss()
scheduler = LambdaLR(
    optimizer,
    lr_lambda=transformer_lr_schedule(
        d_model=512,
        warmup_steps=4000
    )
)

In [None]:
transformers.train()

num_epochs = 5
print_every = 100
step = 0

for epoch in range(num_epochs):
    for X, Y in train_data:   # X: src, Y: tgt
        X = X.to(device)
        Y = Y.to(device)

        optimizer.zero_grad()

        # Forward
        logits = transformers(X, None, -1)  
        # logits: (B, T, vocab_size)

        # Reshape for CrossEntropyLoss
        B, T, V = logits.shape
        loss = critetion(
            logits.reshape(-1, V),
            Y.reshape(-1)
        )


        # Backward
        loss.backward()
        optimizer.step()
        scheduler.step()

        step += 1

        if step % print_every == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] "
                  f"Step [{step}] "
                  f"Loss: {loss.item():.4f}")


In [49]:
@torch.no_grad()
def generate_text(
    model,
    context,
    stoi,
    itos,
    temperature=1.0,
    device="cuda"
):
    model.eval()

    # Encode context â†’ shape: (1, T)
    idx = torch.tensor(
        [[stoi[c] for c in context]],
        dtype=torch.long,
        device=device
    )
    y = torch.zeros(3, 3).type(torch.long).to(device)
    
    max_new_tokens = 50 - len(context)
    for _ in range(max_new_tokens):
        # Forward pass (decoder-only)
        # if idx.size(1) > seq_len:
        #     idx = idx[:, -seq_len:]

        
        logits = model(idx, idx, -1)               # (1, T, vocab_size)

        # Take last token's logits
        logits = logits[:, -1, :] / temperature

        probs = torch.softmax(logits, dim=-1)

        # Sample next token
        next_id = torch.multinomial(probs, num_samples=1)  # (1, 1)

        # Append to sequence
        idx = torch.cat([idx, next_id], dim=1)

    # Decode tokens to string
    return "".join(itos[i.item()] for i in idx[0])

context = 'JU'

generate_text(model=transformers, context=context, stoi=stoi, itos=itos)


'JULIET:\nBut how long shall that title and true fro'

In [None]:
@torch.no_grad()
def generate_text(
    model,
    context,
    stoi,
    itos,
    max_new_tokens=1000,
    temperature=0.8,
    top_k=40,
    device="cuda"
):
    model.eval()

    # Encode initial context
    start_tokens = torch.tensor(
        [stoi[c] for c in context],
        dtype=torch.long,
        device=device
    )

    # Full history (for output)
    full_idx = start_tokens.clone()

    # Sliding window (for model input)
    context_idx = start_tokens.unsqueeze(0)  # (1, T)

    block_size = model.tgt_seq_len

    for _ in range(max_new_tokens):

        # Enforce context window
        if context_idx.size(1) > block_size:
            context_idx = context_idx[:, -block_size:]

        # Forward pass
        logits = model(context_idx, context_idx, -1)
        logits = logits[:, -1, :] / temperature

        # Top-k filtering
        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("inf")

        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # (1, 1)

        # Update buffers
        context_idx = torch.cat([context_idx, next_id], dim=1)
        full_idx = torch.cat([full_idx, next_id.squeeze(0)], dim=0)

    return "".join(itos[i.item()] for i in full_idx)


context = 'ROM'

with open('output.txt', 'w') as f:
    f.write(generate_text(transformers, context, stoi, itos))

In [50]:
torch.save(
    transformers.state_dict(),
    'transformers_tiny_shakespeare_base_model_2.pt'
)

In [53]:
transformers2 = custom_transformers.TransformerNetwork(
    src_embedding_matrix = None,
    tgt_embedding_matrix = None, 
    n_x = 6,
    num_heads = num_heads,
    d_model = d_model,
    d_ff = d_ff,
    src_seq_len = seq_len,
    tgt_seq_len = seq_len,
    src_vocab_size = len(vocab),
    tgt_vocab_size = len(vocab),
    decoder_only = True,
    device = device
).to(device)
transformers2.load_state_dict(torch.load('transformers_tiny_shakespeare_base_model_2.pt', weights_only=True))

<All keys matched successfully>