In [None]:
import warnings
import torch

# Suppress every warning for the rest of the kernel session.
# NOTE(review): this also hides potentially useful torch deprecation/user warnings.
warnings.simplefilter('ignore')

# Read the whole training corpus into a single string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [None]:
# Vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

# Model / training hyperparameters
embd_dim = 128     # embedding width passed to InputEmbedding / DecoderBlock
seq_length = 256   # context length (characters) used for batches and generation
batch_size = 32    # sequences per training batch
n_heads = 4        # number of heads passed to each DecoderBlock
dropout = 0.1      # dropout probability passed to each DecoderBlock
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's RNG (all devices) so that batch sampling, weight
# initialisation and text sampling are reproducible on Restart & Run All.
seed = 1337
torch.manual_seed(seed)

In [None]:
# Character-level tokenizer: map each character to an integer id and back.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encode a string into a list of integer token ids (one per character)."""
    return [stoi[c] for c in s]


def decode(l):
    """Decode a sequence of integer token ids back into a string."""
    return ''.join(itos[i] for i in l)


# Train/validation split: the first TRAIN_FRACTION of the corpus trains the model.
TRAIN_FRACTION = 0.8
full_data = torch.tensor(encode(text), device=device)
n_train = int(len(full_data) * TRAIN_FRACTION)
# `data` holds the training split (name kept because the training loop samples from it).
data, val_data = full_data[:n_train], full_data[n_train:]

In [None]:
def get_batch(data, device, batch_size=None, seq_length=None):
    """Sample a random batch of (input, next-token target) windows from ``data``.

    Args:
        data: 1-D tensor of token ids to sample windows from.
        device: device the returned tensors are moved to.
        batch_size: number of windows to sample; when None, the notebook-level
            ``batch_size`` is used (backward-compatible default).
        seq_length: length of each window; when None, the notebook-level
            ``seq_length`` is used (backward-compatible default).

    Returns:
        ``(x, y)``: two ``(batch_size, seq_length)`` tensors on ``device`` where
        ``y[b, t] == data[start_b + t + 1]``, i.e. ``y`` is ``x`` shifted one token
        forward in ``data``.

    Raises:
        ValueError: if ``data`` holds fewer than ``seq_length + 1`` tokens, so no
            full (input, target) window fits.
    """
    # The parameters shadow the notebook globals of the same name, so fall back
    # to them explicitly via globals() (read at call time, like the original).
    if batch_size is None:
        batch_size = globals()["batch_size"]
    if seq_length is None:
        seq_length = globals()["seq_length"]

    n_starts = len(data) - seq_length  # valid window starts are 0 .. n_starts-1
    if n_starts <= 0:
        raise ValueError(
            f"data has {len(data)} tokens; need at least seq_length + 1 = {seq_length + 1}"
        )

    ix = torch.randint(n_starts, (batch_size,))
    x = torch.stack([data[i:i + seq_length] for i in ix])
    y = torch.stack([data[i + 1:i + seq_length + 1] for i in ix])
    return x.to(device), y.to(device)

In [None]:
from model import InputEmbedding, Head, MultiHead, FeedFoward, DecoderBlock
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def compute_loss(targets, logits):
    """Mean token-level cross-entropy between model logits and target ids.

    Args:
        targets: integer tensor of target ids with ``B*T`` elements
            (typically shaped ``(B, T)``).
        logits: float tensor of shape ``(B, T, C)`` with unnormalised scores
            over ``C`` classes.

    Returns:
        Scalar tensor: cross-entropy averaged over all ``B*T`` positions.
    """
    B, T, C = logits.shape
    # reshape (not view) so non-contiguous inputs, e.g. transposed or sliced
    # logits, are accepted; for contiguous inputs it is still a zero-copy view.
    return F.cross_entropy(logits.reshape(B * T, C), targets.reshape(B * T))

In [None]:
# Decoder-only stack: input embedding (defined in model.py), three identical
# DecoderBlocks, a final LayerNorm, and a linear head producing one logit per
# vocabulary entry. Modules are created in the same order as listed, so
# parameter initialisation consumes the RNG in that order.
model = nn.Sequential(
    InputEmbedding(vocab_size, embd_dim, seq_length),
    *(DecoderBlock(embd_dim, n_heads, dropout) for _ in range(3)),
    nn.LayerNorm(embd_dim),
    nn.Linear(embd_dim, vocab_size),
).to(device)  # Module.to moves parameters in place and returns the module itself

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

In [None]:
# Train with AdamW. Every `eval_interval` steps, report the current training
# batch loss and a validation loss measured in eval mode (dropout off) without
# building an autograd graph.
max_steps = 3000
eval_interval = 100

# Re-enable dropout in case another cell (e.g. text generation) left the
# model in eval mode; makes this cell safe to re-run.
model.train()

for step in range(max_steps):
    x, y = get_batch(data, device)

    logits = model(x)
    loss = compute_loss(y, logits)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if step % eval_interval == 0:
        # Validation: no gradient tracking for the forward pass (saves memory
        # and compute) and no dropout; training mode is restored afterwards.
        model.eval()
        with torch.no_grad():
            val_x, val_y = get_batch(val_data, device)
            val_loss = compute_loss(val_y, model(val_x))
        model.train()

        print(f"step: {step}  train loss: {loss.item():.4f}  val loss: {val_loss.item():.4f}")

In [None]:
# Autoregressively sample 400 characters from the trained model, starting from
# a single random token and feeding back at most `seq_length` tokens of context.
model.eval()

with torch.no_grad():
    temperature = 0.8  # < 1 sharpens the next-token distribution, > 1 flattens it

    # seed the sequence with one uniformly random token id
    context = torch.randint(0, vocab_size, (1, 1), device=device)
    generated = context.clone()

    for _ in range(400):
        # most recent `seq_length` tokens; the slice is the whole sequence while it is shorter
        window = generated[:, -seq_length:]
        last_logits = model(window)[:, -1, :]
        probs = F.softmax(last_logits / temperature, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1)
        generated = torch.cat((generated, sampled), dim=1)

    generated_text = decode(generated[0].tolist())
    print("Generated text:")
    print(generated_text)