In [18]:
# Inspect the GPU (model, driver/CUDA version, current memory use) before sizing the model and batch.
!nvidia-smi

Wed Apr 23 17:45:02 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 571.59                 Driver Version: 571.59         CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A2000 12GB        WDDM  |   00000000:65:00.0  On |                  Off |
| 33%   57C    P2             22W /   70W |    2786MiB /  12282MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

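Target GPU memory usage: between 90% and 95% of the card's 12 GB. That is high enough to use the hardware efficiently, while leaving headroom to avoid out-of-memory errors.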

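Theoretical peak FP32 throughput. Each CUDA core can complete one fused multiply-add (2 FLOPs) per clock cycle:

$$\text{GFLOPS} = \text{CUDA cores} \times \text{clock speed (GHz)} \times 2$$

$$3328 \times 1.2 \times 2 = 7987.2\ \text{GFLOPS} \approx 7.99\ \text{TFLOPS} = 0.0079872\ \text{PFLOPS}$$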

In [19]:
from tokenizers import Tokenizer

# Load the serialized tokenizer from the working directory; it is used below to
# size the model's vocabulary and to encode/decode text.
tokenizer = Tokenizer.from_file("bitext_bpe_tokenizer.json")

def get_vocab_size(tokenizer: Tokenizer) -> int:
    """Count every entry in the tokenizer's vocabulary.

    ``Tokenizer.get_vocab()`` is called with its default arguments, which
    include added/special tokens, so those are part of the count.

    Args:
        tokenizer: a loaded ``tokenizers.Tokenizer``.

    Returns:
        Number of vocabulary entries, special tokens included.
    """
    return len(tokenizer.get_vocab())

Create the model

In [20]:
import torch
# Global torch settings for this run.
SEED = 3647
torch.manual_seed(SEED)  # seeds the default CPU generator and all CUDA devices

# Allow float32 matmuls to use faster reduced-precision internals (e.g. TF32)
# where the hardware supports them.
torch.set_float32_matmul_precision('high')

# Make torch.compile fall back to eager execution instead of raising on
# compilation errors. torch.compile is currently commented out in this
# notebook, so this has no effect yet.
torch._dynamo.config.suppress_errors = True

In [21]:

import torch

# Let cuDNN time several algorithms for each new input shape and reuse the
# fastest one. This only influences cuDNN-backed ops such as convolutions.
# NOTE(review): if GPTLanguageModel contains no convolutional layers, this
# setting has no measurable effect; confirm before relying on it.
torch.backends.cudnn.benchmark = True

# torch.compile(model) can only be applied once `model` exists; it is created
# in the next cell, not here.


In [22]:
from transformer.model import GPTLanguageModel

# Hyperparameters. These stay module-level names because later cells
# (data loaders, training loop, generation) read several of them directly.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vocab_size = get_vocab_size(tokenizer)

block_size = 256  # tokens per training window (TextDataset slices this many)
batch_size = 64   # windows per batch
n_embd = 512
n_head = 8
n_layer = 4
dropout = 0.2

# Build the model and move its parameters to the selected device.
# (torch.compile is left disabled for now.)
model = GPTLanguageModel(
    vocab_size=vocab_size,
    block_size=block_size,
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    dropout=dropout,
    device=device,
).to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6} M parameters")

13.785088 M parameters


Data preparation

In [23]:
# Read the raw training corpus. The encoding is pinned to UTF-8: without it,
# open() falls back to the platform's locale encoding (e.g. cp1252 on Windows),
# which can fail on, or silently garble, non-ASCII characters.
# NOTE(review): assumes train_sequences.txt is UTF-8 encoded; confirm.
with open("train_sequences.txt", "r", encoding="utf-8") as f:
    text_sequence = f.read()

# Encode the whole corpus in one call; the token ids are in `.ids` on the result.
encoded_text_sequence = tokenizer.encode(text_sequence)
len(encoded_text_sequence)

4512042

Train/validation split (90% / 10%)

In [24]:
# All token ids of the corpus as one LongTensor.
data = torch.tensor(encoded_text_sequence.ids, dtype=torch.long)

# Contiguous (unshuffled) split: the first 90% of the token stream is used for
# training and the remaining 10% is held out for validation.
split_index = int(0.9 * len(data))
train_data, val_data = data[:split_index], data[split_index:]


Data Loader

In [25]:
from typing import Tuple
from torch.utils.data import Dataset, DataLoader


class TextDataset(Dataset):
    """Sliding-window next-token-prediction dataset over a token tensor.

    Sample ``i`` is the pair ``(x, y)``:
    - ``x`` is the ``block_size``-token window starting at ``i * stride``.
    - ``y`` is the same window shifted one token to the right (the next-token
      targets).

    With the default ``stride=1`` every token position starts a window. The
    dataset then has ``len(data) - block_size`` heavily overlapping samples, and
    one epoch takes roughly ``block_size`` times as many steps as a pass over
    non-overlapping windows. Pass ``stride=block_size`` for non-overlapping
    windows.

    Args:
        data: token ids, sliced along its first dimension (presumably a 1-D
            tensor, as built in this notebook; TODO confirm for other callers).
        block_size: number of tokens per input window.
        stride: distance in tokens between the starts of consecutive windows.

    Raises:
        ValueError: if ``stride`` is smaller than 1.
    """

    def __init__(self, data: torch.Tensor, block_size: int, stride: int = 1) -> None:
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        self.data = data
        self.block_size = block_size
        self.stride = stride

    def __len__(self) -> int:
        # Valid window starts are 0 .. len(data) - block_size - 1 (y needs one
        # extra token). Clamp at 0 so data shorter than a window yields an empty
        # dataset instead of a negative length, which len() rejects.
        n_starts = len(self.data) - self.block_size
        if n_starts <= 0:
            return 0
        return (n_starts - 1) // self.stride + 1

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        n_samples = len(self)
        if index < 0:
            index += n_samples
        # Out-of-range indices must raise IndexError: plain iteration
        # (`for s in ds`, `list(ds)`) relies on it to stop. Slicing past the end
        # would instead keep returning short or empty windows forever.
        if not 0 <= index < n_samples:
            raise IndexError(
                f"index {index} out of range for dataset of length {n_samples}"
            )
        start = index * self.stride
        x = self.data[start:start + self.block_size]
        y = self.data[start + 1:start + self.block_size + 1]
        return x, y



def get_dataloaders(
        train_data: torch.Tensor,
        val_data: torch.Tensor,
        block_size: int,
        batch_size: int,
        device: torch.device
) -> Tuple[DataLoader, DataLoader]:
    """Build train/validation DataLoaders over sliding-window TextDatasets.

    Each token tensor is moved to ``device`` before being wrapped, so the
    samples are slices of on-device tensors. Training batches are shuffled;
    validation batches keep their order.

    NOTE(review): because the data already lives on ``device`` (possibly a
    GPU), keep the DataLoaders at their default single-process loading.

    Args:
        train_data: token ids for the training split.
        val_data: token ids for the validation split.
        block_size: tokens per window.
        batch_size: windows per batch.
        device: where the token tensors are placed.

    Returns:
        ``(train_loader, val_loader)``.
    """
    def _make_loader(tokens: torch.Tensor, shuffle: bool) -> DataLoader:
        # One dataset + loader per split; only the shuffle flag differs.
        return DataLoader(
            TextDataset(tokens.to(device), block_size),
            batch_size=batch_size,
            shuffle=shuffle,
        )

    return _make_loader(train_data, True), _make_loader(val_data, False)


In [26]:
# Sanity check: build the loaders and inspect one training batch.
train_loader, val_loader = get_dataloaders(
    train_data=train_data,
    val_data=val_data,
    block_size=block_size,
    batch_size=batch_size,
    device=device
)
x, y = next(iter(train_loader))

# Each row is one block_size-token window, and its targets are the same window
# shifted by one token. Fail loudly here rather than deep inside training.
# (The first shuffled batch is full as long as the split has >= batch_size windows.)
assert x.shape == y.shape == (batch_size, block_size), (x.shape, y.shape)
assert torch.equal(x[:, 1:], y[:, :-1]), "targets must be the inputs shifted by one token"
x.shape, y.shape

(torch.Size([64, 256]), torch.Size([64, 256]))

Training

In [27]:
from typing import Dict


@torch.no_grad()
def estimate_loss(
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    eval_iters: int
) -> Dict[str, float]:
    """Estimate the mean loss on the train and validation splits.

    Runs at most ``eval_iters`` batches from each loader, with the model in eval
    mode and gradients disabled, and averages the per-batch losses.

    Args:
        model: called as ``model(x, y)``; must return a pair whose second
            element is a scalar loss tensor.
        train_loader: iterable of ``(x, y)`` training batches.
        val_loader: iterable of ``(x, y)`` validation batches.
        eval_iters: maximum number of batches drawn from each loader.

    Returns:
        ``{'train': mean_loss, 'val': mean_loss}``. The mean is taken over the
        batches actually seen, so a loader shorter than ``eval_iters`` is not
        biased downward by unused slots. A loader that yields no batches gives
        ``nan``.

    The model's previous train/eval mode is restored afterwards, even if a
    forward pass raises.

    NOTE(review): a non-shuffled loader (such as the validation loader built by
    ``get_dataloaders``) yields the same leading batches on every call, so the
    estimate only ever sees the beginning of that split.
    """
    was_training = model.training
    model.eval()
    output: Dict[str, float] = {}
    try:
        for split, loader in (('train', train_loader), ('val', val_loader)):
            batch_losses = []
            # range() comes first in zip, so iteration stops after eval_iters
            # batches without pulling an extra batch from the loader.
            for _, (x, y) in zip(range(eval_iters), loader):
                _, loss = model(x, y)
                batch_losses.append(loss.item())
            output[split] = (
                sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
            )
    finally:
        # Restore the caller's mode rather than unconditionally switching to train.
        model.train(was_training)
    return output

In [28]:
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    loss: float,
    file_path: str = "checkpoint.pth"
) -> None:
    """Save model and optimizer state plus epoch/loss metadata with ``torch.save``.

    Missing parent directories of ``file_path`` are created. The checkpoint is
    first written to a temporary sibling (``<name>.tmp``) and then moved over
    ``file_path``, so an interrupted save leaves at most a stray ``.tmp`` file
    rather than a truncated checkpoint.

    Args:
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        epoch: epoch index recorded in the checkpoint.
        loss: loss value recorded in the checkpoint.
        file_path: destination path of the checkpoint file.
    """
    from pathlib import Path

    target = Path(file_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }
    tmp_path = target.with_name(target.name + ".tmp")
    torch.save(checkpoint, tmp_path)
    tmp_path.replace(target)

In [29]:
import os

# Training hyperparameters.
max_iters = 1          # number of full passes (epochs) over train_loader, not optimizer steps
eval_interval = 100    # run estimate_loss every this many optimizer steps
eval_iters = 200       # maximum batches per split in each loss estimate
learning_rate = 3e-4
checkpoint_dir = "../output/pre_training/run_4"

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
train_loader, val_loader = get_dataloaders(
    train_data=train_data,
    val_data=val_data,
    block_size=block_size,
    batch_size=batch_size,
    device=device
)

train_losses = []
val_losses = []

# Create the output directory up front, so a missing folder cannot crash the
# run at the first checkpoint after a whole epoch of training.
os.makedirs(checkpoint_dir, exist_ok=True)

for iteration in range(max_iters):  # `iteration` is the epoch index
    total_loss = 0.0
    steps_done = 0
    model.train()
    try:
        for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
            # Evaluate at the first step, every eval_interval steps, and just
            # before the final batch. Each evaluation runs up to 2 * eval_iters
            # extra forward passes.
            if batch_idx % eval_interval == 0 or batch_idx == len(train_loader) - 1:
                losses = estimate_loss(
                    model=model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    eval_iters=min(eval_iters, len(val_loader))
                )
                train_losses.append(losses['train'])
                val_losses.append(losses['val'])

                print(
                    f"iteration {iteration} / step {batch_idx}: "
                    f"train loss {losses['train']:.4f}, "
                    f"val loss {losses['val']:.4f}"
                )

            # Training step
            _, loss = model(x_batch, y_batch)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            steps_done += 1
    except KeyboardInterrupt:
        # An epoch is tens of thousands of steps; persist the partial progress
        # before propagating the interrupt instead of discarding it.
        save_checkpoint(
            model=model,
            optimizer=optimizer,
            epoch=iteration,
            loss=total_loss / max(steps_done, 1),
            file_path=f"{checkpoint_dir}/checkpoint_{iteration}_interrupted.pth"
        )
        print(f"Interrupted in epoch {iteration + 1} after {steps_done} steps; checkpoint saved.")
        raise

    avg_epoch_loss = total_loss / max(steps_done, 1)
    print(f"Epoch {iteration + 1} average training loss: {avg_epoch_loss:.4f}")

    # Save checkpoint
    save_checkpoint(
        model=model,
        optimizer=optimizer,
        epoch=iteration,
        loss=avg_epoch_loss,
        file_path=f"{checkpoint_dir}/checkpoint_{iteration}.pth"
    )

iteration 0 / step 0: train loss 7.0490, val loss 7.0789
iteration 0 / step 100: train loss 2.9945, val loss 3.4789
iteration 0 / step 200: train loss 2.3144, val loss 2.7205
iteration 0 / step 300: train loss 1.7381, val loss 2.1259
iteration 0 / step 400: train loss 1.4368, val loss 1.8286
iteration 0 / step 500: train loss 1.2760, val loss 1.6638


KeyboardInterrupt: 

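**How many steps is one epoch?**

- `len(encoded_text_sequence)` = 4,512,042 tokens
- Training split (90%): `int(0.9 * 4,512,042)` = 4,060,837 tokens
- Tokens per batch = 64 × 256 = 16,384

**If the training windows did not overlap,** one epoch would be 4,060,837 / 16,384 ≈ 247.9 steps.

**With the actual data loader,** an epoch is far longer. `TextDataset` starts a new window at *every* token position (stride 1), so:

- The training set has 4,060,837 − 256 = 4,060,581 samples.
- One epoch is ⌈4,060,581 / 64⌉ = 63,447 steps. That is about 256× (`block_size`×) more than the non-overlapping estimate.

This is why the run above was still at step 500 when it was interrupted. It was nowhere near the end-of-epoch checkpoint.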

In [None]:
import matplotlib.pyplot as plt

# Loss curves recorded by the training loop, one point per evaluation.
# NOTE: the x-axis is the evaluation index, not the optimizer step; points are
# eval_interval steps apart (the last one may be closer).
fig, ax = plt.subplots(figsize=(10, 5))
for series, label in ((train_losses, "Train Loss"), (val_losses, "Validation Loss")):
    ax.plot(series, label=label, marker='o')
ax.set_xlabel("Evaluation Step")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss Over Time")
ax.legend()
ax.grid()
plt.show()

In [None]:
# Encode a prompt. Tokenizer.encode returns an Encoding object, not a list;
# torch.tensor needs its `.ids` (the same attribute used when building `data`).
prompt = "Hello! I want some help"
input_tokens = torch.tensor(
    tokenizer.encode(prompt).ids, dtype=torch.long
).unsqueeze(0).to(device)  # add a leading batch dimension of size 1

model.eval()  # switch off training-only behaviour such as dropout
with torch.no_grad():
    output = model.generate(input_tokens=input_tokens, max_new_tokens=100)

print(tokenizer.decode(output[0].tolist()))