In [1]:
import pandas as pd

df = pd.read_parquet("hf://datasets/gamino/wiki_medical_terms/wiki_medical_terms.parquet")

In [2]:
df.head()

Unnamed: 0,page_title,page_text
0,Paracetamol poisoning,"Paracetamol poisoning, also known as acetamino..."
1,Acromegaly,Acromegaly is a disorder that results from exc...
2,Actinic keratosis,"Actinic keratosis (AK), sometimes called solar..."
3,Congenital adrenal hyperplasia,Congenital adrenal hyperplasia (CAH) is a grou...
4,Adrenocortical carcinoma,Adrenocortical carcinoma (ACC) is an aggressi...


In [4]:
articles = df.iloc[:, 1]
articles

0       Paracetamol poisoning, also known as acetamino...
1       Acromegaly is a disorder that results from exc...
2       Actinic keratosis (AK), sometimes called solar...
3       Congenital adrenal hyperplasia (CAH) is a grou...
4       Adrenocortical carcinoma  (ACC) is an aggressi...
                              ...                        
7271    Gephyrophobia is the anxiety disorder or speci...
7272    Coronary artery bypass surgery, also known as ...
7273    Unemployment, according to the OECD (Organisat...
7274    A surgical instrument is a tool or device for ...
7275    Occipital neuralgia (ON) is a painful conditio...
Name: page_text, Length: 6861, dtype: object

In [5]:
import re
def preprocess_text(text):
    text = text.lower()
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

articles = articles.apply(preprocess_text)

In [6]:
articles

0       paracetamol poisoning also known as acetaminop...
1       acromegaly is a disorder that results from exc...
2       actinic keratosis ak sometimes called solar ke...
3       congenital adrenal hyperplasia cah is a group ...
4       adrenocortical carcinoma acc is an aggressive ...
                              ...                        
7271    gephyrophobia is the anxiety disorder or speci...
7272    coronary artery bypass surgery also known as c...
7273    unemployment according to the oecd organisatio...
7274    a surgical instrument is a tool or device for ...
7275    occipital neuralgia on is a painful condition ...
Name: page_text, Length: 6861, dtype: object

In [7]:
import tiktoken
tokenizer = tiktoken.get_encoding('gpt2')
tokenized_articles = articles.apply(lambda x: tokenizer.encode(x))
tokenized_articles

0       [1845, 23253, 321, 349, 22475, 635, 1900, 355,...
1       [330, 398, 1533, 3400, 318, 257, 8967, 326, 24...
2       [529, 47277, 41927, 265, 5958, 47594, 3360, 14...
3       [36801, 268, 1287, 26999, 282, 8718, 489, 2321...
4       [324, 918, 420, 419, 605, 28164, 6086, 697, 31...
                              ...                        
7271    [469, 6883, 10051, 30665, 318, 262, 9751, 8967...
7272    [10215, 261, 560, 37646, 17286, 8185, 635, 190...
7273    [403, 28812, 1864, 284, 262, 267, 21142, 12684...
7274    [64, 21998, 8875, 318, 257, 2891, 393, 3335, 3...
7275    [13966, 541, 1287, 17019, 70, 544, 319, 318, 2...
Name: page_text, Length: 6861, dtype: object

In [9]:
import torch
from torch.nn.utils.rnn import pad_sequence

# Define maximum length for padding/truncation
max_length = 512

# Pad and truncate tokenized articles
def pad_and_truncate(tokens, max_length):
    if len(tokens) > max_length:
        return tokens[:max_length]
    return tokens + [0] * (max_length - len(tokens))

# Apply padding/truncation
tokenized_articles = tokenized_articles.apply(lambda x: pad_and_truncate(x, max_length))
tokenized_articles.shape

(6861,)

In [50]:
import wandb

# Initialize wandb
wandb_log = True  # Set to True to enable wandb logging
wandb_project = 'Attention Benchmark'
wandb_run_name = 'Casual Attention-mps'

ModuleNotFoundError: No module named 'google.protobuf'

In [10]:
import torch 
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import transformers
import tiktoken

import math

from tqdm import tqdm

import os
import time

In [32]:
class TextDataset(Dataset):
    def __init__(self, articles, model="gpt2", seq_length=512):
        self.tokenizer = tiktoken.get_encoding(model)
        self.vocab_size = self.tokenizer.n_vocab
        self.seq_length = seq_length
        self.articles = articles.apply(self.preprocess_and_tokenize)
        
        self.input_ids, self.attention_masks, self.targets = self.create_sequences()

    def preprocess_and_tokenize(self, text):
        # Preprocess text
        text = text.lower()
        text = re.sub(r'[^a-zA-Z\s]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        
        # Tokenize text
        tokens = self.tokenizer.encode(text)
        
        # Check for invalid token indices
        assert all(token < self.vocab_size for token in tokens), "Token index out of range"
        
        # Pad and truncate tokens
        if len(tokens) > self.seq_length:
            tokens = tokens[:self.seq_length]
        else:
            tokens += [0] * (self.seq_length - len(tokens))
            
        return tokens

    def create_sequences(self):
        input_ids = []
        attention_masks = []
        targets = []
        
        for tokens in self.articles:
            input_ids.append(tokens[:-1])  # Exclude the last token for input
            targets.append(tokens[1:])     # Exclude the first token for target
            attention_masks.append([1 if token != 0 else 0 for token in tokens[:-1]])
        
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_masks = torch.tensor(attention_masks, dtype=torch.long)
        targets = torch.tensor(targets, dtype=torch.long)
        
        return input_ids, attention_masks, targets

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        input_seq = self.input_ids[idx]
        attention_mask = self.attention_masks[idx]
        target_seq = self.targets[idx]
        
        sample = {'input_ids': input_seq, 'targets': target_seq, 'attention_mask': attention_mask}
        return sample

def pad_sequences(batch):
    input_ids = [item['input_ids'] for item in batch]
    attention_masks = [item['attention_mask'] for item in batch]
    targets = [item['targets'] for item in batch]

    input_ids_padded = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_masks_padded = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=0)
    targets_padded = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=0)

    return {'input_ids': input_ids_padded, 'targets': targets_padded, 'attention_mask': attention_masks_padded }


In [33]:
# Create the dataset
dataset = TextDataset(articles, model="gpt2", seq_length=512)

# Create the dataloader
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=pad_sequences)

In [34]:
for data in dataloader:
    x, y, att = data['input_ids'], data['targets'], data['attention_mask']
    print(x,y,att)
    break

tensor([[32396,   318,   262,  ...,   393, 35961,   612],
        [ 5657,  2411,  7721,  ...,     0,     0,     0]]) tensor([[  318,   262, 19883,  ..., 35961,   612,   318],
        [ 2411,  7721,  4143,  ...,     0,     0,     0]]) tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]])


In [35]:
#configuration
class GPTConfig:
    def __init__(self, vocab_size, block_size, n_layer, n_head, n_embd, dropout, bias=True):
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        self.bias = bias

In [36]:
#attention
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.d_k = config.n_embd // config.n_head
        self.scale = self.d_k ** -0.5

        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.qkv_proj(x).reshape(B, T, 3, self.n_head, self.d_k).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) * self.scale
        attn_scores = attn_scores.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.attn_dropout(attn_probs)

        attn_output = (attn_probs @ v).transpose(1, 2).reshape(B, T, C)
        attn_output = self.resid_dropout(self.out_proj(attn_output))
        return attn_output

In [37]:
#transformer block
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd, eps=1e-5)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd, eps=1e-5)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

In [38]:
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None, "Config must include vocab_size"
        assert config.block_size is not None, "Config must include block_size"
        self.config = config

        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embd),
            'wpe': nn.Embedding(config.block_size, config.n_embd),
            'drop': nn.Dropout(config.dropout),
            'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embd, eps=1e-5),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.tie_weights()

        self.apply(self._init_weights)
        self.init_residuals()
        print(f"Number of parameters: {self.get_num_params()/1e6:.2f}M")

    def tie_weights(self):
        self.transformer.wte.weight = self.lm_head.weight

    def init_residuals(self):
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

    def get_num_params(self, non_embedding=True):
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        logits = self.lm_head(x)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            return logits, loss
        else:
            logits = logits[:, [-1], :]  # Use only the last token's logits
            return logits, None

    def crop_block_size(self, block_size):
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}, "Invalid model type"
        override_args = override_args or {}
        assert all(k == 'dropout' for k in override_args), "Only 'dropout' can be overridden"

        print(f"Loading weights from pretrained model: {model_type}")

        config_args = cls.get_config_args(model_type)
        if 'dropout' in override_args:
            print(f"Overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']

        config = GPTConfig(**config_args)
        model = GPT(config)
        model.load_pretrained_weights(model_type)
        return model

    @staticmethod
    def get_config_args(model_type):
        config_map = {
            'gpt2': {'n_layer': 12, 'n_head': 12, 'n_embd': 768},
            'gpt2-medium': {'n_layer': 24, 'n_head': 16, 'n_embd': 1024},
            'gpt2-large': {'n_layer': 36, 'n_head': 20, 'n_embd': 1280},
            'gpt2-xl': {'n_layer': 48, 'n_head': 25, 'n_embd': 1600},
        }
        config_args = config_map[model_type]
        config_args.update({'vocab_size': 50257, 'block_size': 1024, 'bias': True, 'dropout': 0.1})
        return config_args

    def load_pretrained_weights(self, model_type):
        model_hf = transformers.GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()
        sd = self.state_dict()
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']

        for k, v in sd_hf.items():
            if k in sd:
                if any(k.endswith(w) for w in transposed):
                    sd[k].copy_(v.t())
                else:
                    sd[k].copy_(v)

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        N = self.get_num_params()
        L, H, Q, T = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.block_size
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter * (1.0 / dt)
        flops_promised = 312e12
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


In [39]:
#training dependencies
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler
from contextlib import nullcontext
import torch.distributed as dist
from torch.distributed import init_process_group, destroy_process_group

In [40]:
# Default config values designed to train a GPT-2 (124M) on OpenWebText
# I/O
out_dir = 'out'
eval_interval = 2000
log_interval = 1
eval_iters = 200
eval_only = False
always_save_checkpoint = True
init_from = 'scratch'  # 'scratch' or 'resume' or 'gpt2*'

# Data
# data_path = 'data/openwebtext.txt'  # Path to your text file
gradient_accumulation_steps = 5 * 8
batch_size = 12
block_size = 1024

# Model
n_layer = 6  # Reduce the number of layers
n_head = 8   # Reduce the number of attention heads
n_embd = 512 # Reduce the embedding size
dropout = 0.0
bias = False

# AdamW optimizer
learning_rate = 6e-4
max_iters = 600000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0

# Learning rate decay settings
decay_lr = True
warmup_iters = 2000
lr_decay_iters = 600000
min_lr = 6e-5

# DDP settings
backend = 'gloo'

# System
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
dtype = 'float32'  # MPS currently supports only float32
compile = False  # Disable compilation for now as PyTorch 2.0 is not yet stable on MPS
# -----------------------------------------------------------------------------

# Collect configuration keys
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
config = {k: globals()[k] for k in config_keys}

# Various initializations and derived attributes
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'mps:{ddp_local_rank}' if torch.backends.mps.is_available() else f'cuda:{ddp_local_rank}'
    torch.mps.set_device(device)
    master_process = ddp_rank == 0
    seed_offset = ddp_rank
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    master_process = True
    seed_offset = 0
    ddp_world_size = 1

tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"Tokens per iteration will be: {tokens_per_iter:,}")

if master_process:
    os.makedirs(out_dir, exist_ok=True)

torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'mps' if 'mps' in device else 'cpu'
ptdtype = torch.float32  # Currently, MPS supports only float32
ctx = nullcontext()

Tokens per iteration will be: 491,520


In [42]:
# Initialize iteration number and best validation loss
iter_num = 0
best_val_loss = 1e9

# Model initialization
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=len(dataset.input_ids), dropout=dropout)

if init_from == 'scratch':
    print("Initializing a new model from scratch")
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_from, override_args)
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = getattr(model.config, k)

if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size

model.to(device)

Initializing a new model from scratch
Number of parameters: 22.42M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(6861, 512)
    (wpe): Embedding(1024, 512)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (qkv_proj): Linear(in_features=512, out_features=1536, bias=False)
          (out_proj): Linear(in_features=512, out_features=512, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2048, out_features=512, bias=True)
          (3): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(

In [44]:
# Initialize a GradScaler
scaler = GradScaler(enabled=(dtype == 'float16' and 'cuda' in device))

# Configure the optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None  # Free up memory

# Wrap model into DDP container if needed
if ddp:
    print(f"Starting parallel process with rank {ddp_rank}")
    model = DDP(model, device_ids=[ddp_local_rank])


# Function to estimate loss over splits
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        with tqdm(total=eval_iters, desc=f"Evaluating {split}") as pbar:
            for k in range(eval_iters):
                batch = next(iter(dataloader))
                X, Y = batch['input_ids'].to(device), batch['targets'].to(device)
                with ctx:
                    logits, loss = model(X, Y)
                losses[k] = loss.item()
                pbar.update(1)
        out[split] = losses.mean()
    model.train()
    return out

# Learning rate decay scheduler
def get_lr(it):
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    if it > lr_decay_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)

In [47]:
raw_model = model.module if ddp else model
local_iter_num = 0
running_mfu = -1.0
t0 = time.time()

while iter_num <= max_iters:
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    with tqdm(total=gradient_accumulation_steps, desc=f"Training step {iter_num}") as pbar:
        for micro_step in range(gradient_accumulation_steps):
            if ddp:
                model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
            batch = next(iter(dataloader))
            X, Y = batch['input_ids'].to(device), batch['targets'].to(device)
            with ctx:
                logits, loss = model(X, Y)
                loss = loss / gradient_accumulation_steps
            scaler.scale(loss).backward()
            pbar.update(1)

    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5:
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt * 1000:.2f}ms, mfu {running_mfu * 100:.2f}%")
    iter_num += 1
    local_iter_num += 1

if ddp:
    destroy_process_group()

Evaluating train: 100%|██████████| 200/200 [00:26<00:00,  7.52it/s]
Evaluating val: 100%|██████████| 200/200 [00:28<00:00,  7.14it/s]


step 0: train loss 6.9129, val loss 6.9174


Training step 0: 100%|██████████| 40/40 [00:14<00:00,  2.82it/s]


iter 0: loss 7.0153, time 69213.72ms, mfu -100.00%


Training step 1: 100%|██████████| 40/40 [00:13<00:00,  2.95it/s]


iter 1: loss 6.8796, time 13865.59ms, mfu -100.00%


Training step 2:  30%|███       | 12/40 [00:04<00:10,  2.64it/s]


KeyboardInterrupt: 

In [52]:
x.shape

torch.Size([2, 511])

In [56]:
model(x.to('mps'), y.to('mps'))

(tensor([[[ 0.3315,  0.3442,  0.9170,  ...,  0.5294, -0.4510,  0.4243],
          [ 0.2494,  0.5027,  0.8012,  ...,  0.6525, -0.3048,  0.3368],
          [-0.0948,  0.6975,  0.8772,  ...,  0.4716, -0.2738,  0.6067],
          ...,
          [-0.3669, -0.8627,  0.8842,  ...,  0.8637,  0.1471, -0.0829],
          [-0.1819, -0.4338,  0.6389,  ...,  0.6406,  0.0930, -0.2009],
          [-0.7336, -0.1003,  0.4241,  ...,  0.5056,  0.4920, -0.4071]],
 
         [[ 0.3878,  0.4814,  0.8031,  ..., -0.1710, -0.2604,  0.8452],
          [ 0.3523,  0.3770,  0.6969,  ..., -0.1194, -0.2412,  0.8452],
          [ 0.1664,  0.4927,  0.6952,  ..., -0.0504, -0.2668,  1.1063],
          ...,
          [ 0.0423,  0.2268,  0.5100,  ...,  0.1423,  0.2569, -0.1164],
          [ 0.1000,  0.4063,  0.5384,  ...,  0.0562,  0.0893, -0.1400],
          [ 0.0725,  0.4153,  0.3523,  ..., -0.0303,  0.2497, -0.0804]]],
        device='mps:0', grad_fn=<LinearBackward0>),
 tensor(7.1700, device='mps:0', grad_fn=<NllLossB

In [57]:
model

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(6861, 512)
    (wpe): Embedding(1024, 512)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (qkv_proj): Linear(in_features=512, out_features=1536, bias=False)
          (out_proj): Linear(in_features=512, out_features=512, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2048, out_features=512, bias=True)
          (3): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(