In [None]:
import sys
# Add the parent directory (relative to the current working directory) to the
# module search path; presumably this is what makes the project-local packages
# imported later in the notebook resolvable — run this cell first.
sys.path.append('..')

In [1]:
from minbpe import BasicTokenizer

# Create the tokenizer and restore its previously saved state.
tokenizer = BasicTokenizer()
# NOTE(review): "path" looks like a placeholder — point this at the actual saved
# tokenizer file before running the notebook.
tokenizer.load("path")

def get_vocab_size(tokenizer: BasicTokenizer) -> int:
    """Return the number of token ids the model must be able to embed.

    The result is the size of ``tokenizer.vocab`` plus the size of
    ``tokenizer.special_tokens``.

    NOTE(review): if ``tokenizer.vocab`` already contains the special tokens,
    this double-counts them (harmless, but the embedding table is slightly
    oversized) — confirm against the tokenizer implementation.
    """
    token_tables = (tokenizer.vocab, tokenizer.special_tokens)
    return sum(len(table) for table in token_tables)

ModuleNotFoundError: No module named 'minbpe'

In [None]:
import torch
# Seed PyTorch's random number generator so batch sampling and weight
# initialisation are repeatable across runs.
torch.manual_seed(3647)

In [None]:
from transformer.model import GPTLanguageModel

# --- Data / runtime settings -------------------------------------------------
device = 'cpu'
batch_size = 128      # sequences per batch drawn by get_batch
block_size = 512      # context length (tokens per sequence)
vocab_size = get_vocab_size(tokenizer)

# --- Architecture settings ---------------------------------------------------
# NOTE(review): n_embd (128) is not divisible by n_head (12); confirm that
# GPTLanguageModel tolerates this (e.g. by flooring the per-head size).
n_embd, n_head, n_layer = 128, 12, 1
dropout = 0.2

# Build the model, move it to the target device, then wrap it with torch.compile.
model = GPTLanguageModel(
    vocab_size=vocab_size,
    block_size=block_size,
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    dropout=dropout,
    device=device,
)
model = model.to(device)
model = torch.compile(model)

# Report the parameter count in millions.
print(sum(param.numel() for param in model.parameters())/1e6, 'M parameters')

In [None]:
# Read the whole training corpus. The encoding is pinned to UTF-8 so the decoded
# text does not depend on the platform's default locale encoding.
# NOTE(review): assumes the corpus file is UTF-8 — confirm.
with open("../output/combined_text.txt", "r", encoding="utf-8") as f:
    text_sequence = f.read()

# Tokenize the corpus; the trailing expression displays the token count.
encoded_text_sequence = tokenizer.encode(text_sequence)
len(encoded_text_sequence)

In [None]:
# Turn the token ids into a 1-D long tensor, then cut it 90% / 10% into
# contiguous training and validation segments (no shuffling).
data = torch.tensor(encoded_text_sequence, dtype=torch.long)
split_index = int(len(data) * 0.9)
train_data, val_data = data[:split_index], data[split_index:]

In [3]:
from typing import Tuple

def get_batch(split:str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: ``"train"`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)`` and moved to
        ``device``. ``y`` is ``x`` shifted one position to the right, i.e. the
        next-token target for every position in ``x``.

    Raises:
        ValueError: if the selected split holds ``block_size`` tokens or fewer,
            so no window with a following target token can be drawn.
    """
    data = train_data if split == "train" else val_data

    # Each window needs block_size inputs plus one extra token for the shifted
    # target, so valid start offsets are 0 .. len(data) - block_size - 1.
    num_starts = len(data) - block_size
    if num_starts <= 0:
        raise ValueError(
            f"'{split}' split has {len(data)} tokens; need more than "
            f"block_size={block_size} to sample a batch"
        )

    starts = torch.randint(num_starts, (batch_size,)).tolist()
    x = torch.stack([data[i:i + block_size] for i in starts])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in starts])
    return x.to(device), y.to(device)

NameError: name 'torch' is not defined

In [None]:
# Sanity check: draw one training batch and display the shapes of the input and
# target tensors.
x, y = get_batch('train')
x.shape, y.shape

In [None]:
from typing import Dict

# Number of random batches averaged per split when estimating the loss.
eval_iters = 200

@torch.no_grad()
def estimate_loss() -> Dict:
    """Estimate the mean loss on both the training and validation splits.

    For each split, ``eval_iters`` random batches are drawn with ``get_batch``
    and run through ``model`` in eval mode with gradients disabled; the
    per-batch losses are averaged.

    Returns:
        A dict with keys ``'train'`` and ``'val'``, each mapping to a 0-dim
        tensor holding the mean loss for that split.
    """
    output = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)
            _, loss = model(x, y)
            losses[k] = loss.item()
        output[split] = losses.mean()
    # Switch back to training mode and return only after BOTH splits have been
    # evaluated; doing this inside the loop would drop the 'val' entry.
    model.train()
    return output

In [None]:
def save_checkpoint(model:GPTLanguageModel, optimizer:torch.optim.Optimizer, epoch:int,
                    loss: float, file_path: str = "checkpoint.pth")->None:
    """Serialize a training checkpoint with ``torch.save``.

    Args:
        model: Model whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Training step/epoch index recorded in the checkpoint.
        loss: Loss value recorded in the checkpoint.
        file_path: Destination file. Missing parent directories are created.

    The saved dict has the keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'loss'``.

    NOTE(review): this notebook wraps the model with ``torch.compile``, which
    typically prefixes state-dict keys with ``_orig_mod.`` — confirm that
    whatever loads these checkpoints accounts for that.
    """
    import os

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        # state_dict must be CALLED; storing the bound method would pickle the
        # optimizer object itself instead of its state.
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    # torch.save does not create directories; make sure the target one exists.
    # Only path-like destinations have a directory (file-like objects are
    # passed through untouched).
    if isinstance(file_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.save(checkpoint, file_path)

In [None]:
# Training hyperparameters.
max_iters = 1000        # total optimisation steps
eval_interval = 10      # estimate train/val loss every this many steps
learning_rate = 1e-4
save_interval = 100     # write a checkpoint every this many steps
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Loss history, one entry per evaluation (used for plotting afterwards).
train_losses = []
val_losses = []

for iteration in range(max_iters):
    is_last_iteration = iteration == max_iters - 1

    # Periodically (and on the final step) estimate and record both losses.
    if iteration % eval_interval == 0 or is_last_iteration:
        losses = estimate_loss()
        print(
            f"step {iteration}: "
            f"train loss {losses['train']:.4f}, "
            f"val loss {losses['val']:.4f}"
        )
        train_losses.append(losses['train'])
        val_losses.append(losses['val'])

    # One optimisation step on a fresh random training batch.
    x_batch, y_batch = get_batch('train')
    logits, loss = model(x_batch, y_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Checkpoint periodically, and also after the final step so the fully
    # trained weights are not lost (otherwise the last save is at the last
    # multiple of save_interval).
    if iteration % save_interval == 0 or is_last_iteration:
        save_checkpoint(
            model=model,
            optimizer=optimizer,
            epoch=iteration,
            loss=loss.item(),
            file_path=f"../output/pre_training/run_1/checkpoint_{iteration}.pth"
        )

In [None]:
import matplotlib.pyplot as plt

# Plot the recorded loss curves; the x axis counts evaluations, not training steps.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label="Train Loss")
ax.plot(val_losses, label="Validation Loss")
ax.set_xlabel("Evaluation Step")
ax.set_ylabel("Loss")
ax.set_ylim(0)  # anchor the y axis at zero
ax.set_title("Training and Validation Loss Over Time")
ax.legend()
ax.grid()
plt.show()

In [None]:
# Encode the prompt and shape it as a batch containing a single sequence.
input_tokens = tokenizer.encode("Hi")
input_tokens = torch.tensor(input_tokens, dtype=torch.long)[None, :]
input_tokens = input_tokens.to(device)

# Generate a continuation in eval mode with gradient tracking disabled.
model.eval()
with torch.no_grad():
    output = model.generate(
        input_tokens=input_tokens,
        max_new_tokens=50,
    )

# Decode the first (only) sequence of the batch back into text.
print(tokenizer.decode(output[0].tolist()))