In [3]:
#-------------------------! import statements !-------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BartTokenizer, BartForConditionalGeneration
import time
from dataclasses import dataclass

#-------------------------! Hyperparameter !-------------------------
@dataclass
class BartConfig:
    model_name: str = "facebook/bart-base"
    batch_size: int = 2
    block_size: int = 1024 

#-------------------------! Load BART Model and Tokenizer !-------------------------
config = BartConfig()
tokenizer = BartTokenizer.from_pretrained(config.model_name)
model = BartForConditionalGeneration.from_pretrained(config.model_name)

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f'using device {device}')

model.to(device)
model = torch.compile(model, backend='eager')

optimizer = optim.AdamW(model.parameters(), lr=1e-4)

#-------------------------! Data Loader Class !-------------------------
class DataLoaderLite:
    def __init__(self, B, T, tokenizer):
        self.B = B
        self.T = T
        self.tokenizer = tokenizer

        with open('../input.txt', 'r') as f:
            text = f.read()
        tokens = tokenizer(text, return_tensors="pt", max_length=T, truncation=True, padding="max_length")["input_ids"]
        self.tokens = tokens.to(device)

        self.num_batches = max(len(self.tokens) // B, 1)
        print(f'1 epoch = {self.num_batches} batches')

        # state
        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        start_idx = self.current_position
        end_idx = start_idx + B

        x = self.tokens[start_idx:end_idx]
        y = self.tokens[start_idx:end_idx]

        self.current_position += B

        if self.current_position + B > len(self.tokens):
            self.current_position = 0

        return x, y

train_loader = DataLoaderLite(config.batch_size, config.block_size, tokenizer)

def train(model, train_loader, optimizer, num_epochs=100):
    model.train()
    for i in range(num_epochs):
        t0 = time.time()
        src, tgt = train_loader.next_batch()
        optimizer.zero_grad()
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            output = model(input_ids=src, labels=tgt)
            loss = output.loss

        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
        t1 = time.time()
        dt = (t1 - t0) * 1000  # time difference is in milliseconds
        tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
        if (i + 1) % 10 == 0:
            print(f'step {i + 1}, loss: {loss.item()}, dt: {dt:.2f} ms, tok/sec {tokens_per_sec}')

train(model, train_loader, optimizer, num_epochs=100)


using device cuda
1 epoch = 1 batches
step 10, loss: 1.2398115396499634, dt: 431.26 ms, tok/sec 4748.850560440502
step 20, loss: 0.4794187843799591, dt: 432.53 ms, tok/sec 4734.971637910268
step 30, loss: 0.23416724801063538, dt: 432.59 ms, tok/sec 4734.240943284023
step 40, loss: 0.3205746114253998, dt: 437.76 ms, tok/sec 4678.363895609058
step 50, loss: 0.15725810825824738, dt: 467.58 ms, tok/sec 4380.034128833133
step 60, loss: 0.22025199234485626, dt: 456.41 ms, tok/sec 4487.203584775361
step 70, loss: 0.15782202780246735, dt: 628.19 ms, tok/sec 3260.1611693251034
step 80, loss: 0.08541318029165268, dt: 451.24 ms, tok/sec 4538.594510742127
step 90, loss: 0.021308371797204018, dt: 525.79 ms, tok/sec 3895.1213990258043
step 100, loss: 0.0583316832780838, dt: 486.14 ms, tok/sec 4212.759506980777
