In [1]:
import math

import torch
from torch.nn.utils.rnn import pad_sequence
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import transformers

from einops import rearrange
from typing import Optional, Union

import tiktoken
import pandas as pd
import re

import os


In [2]:
import wandb

# Initialize wandb
wandb_log = True  
run_name = 'Sparse Attention - mps'
project_name = "Attention Benchmark"


In [3]:
class TextDataset(Dataset):
    def __init__(self, articles, model="gpt2", seq_length=512):
        self.tokenizer = tiktoken.get_encoding(model)
        self.vocab_size = self.tokenizer.n_vocab
        self.seq_length = seq_length
        self.articles = articles.apply(self.preprocess_and_tokenize)
        
        self.input_ids, self.targets, self.attention_masks = self.create_sequences()

    def preprocess_and_tokenize(self, text):
        # Preprocess text
        text = text.lower()
        text = re.sub(r'[^a-zA-Z\s]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        
        # Tokenize text
        tokens = self.tokenizer.encode(text)
        
        # Check for invalid token indices
        assert all(token < self.vocab_size for token in tokens), "Token index out of range"
        
        # Pad and truncate tokens
        if len(tokens) > self.seq_length:
            tokens = tokens[:self.seq_length]
        else:
            tokens += [0] * (self.seq_length - len(tokens))
        self.tokens = tokens
        return tokens

    def create_sequences(self):
        input_ids = []
        attention_masks = []
        targets = []
        
        for tokens in self.articles:
            input_ids.append(tokens[:-1])  # Exclude the last token for input
            targets.append(tokens[1:])     # Exclude the first token for target
            attention_masks.append([1 if token != 0 else 0 for token in tokens[:-1]])
        
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_masks = torch.tensor(attention_masks, dtype=torch.long)
        targets = torch.tensor(targets, dtype=torch.long)
        
        return input_ids, targets, attention_masks

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        input_seq = self.input_ids[idx]
        attention_mask = self.attention_masks[idx]
        target_seq = self.targets[idx]
        
        sample = {'input_ids': input_seq, 'targets': target_seq, 'attention_mask': attention_mask}
        return sample

def pad_sequences(batch):
    input_ids = [item['input_ids'] for item in batch]
    attention_masks = [item['attention_mask'] for item in batch]
    targets = [item['targets'] for item in batch]

    input_ids_padded = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_masks_padded = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=0)
    targets_padded = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=0)

    return {'input_ids': input_ids_padded, 'targets': targets_padded, 'attention_mask': attention_masks_padded}

In [4]:
# Create the dataset
df = pd.read_parquet("hf://datasets/gamino/wiki_medical_terms/wiki_medical_terms.parquet")
articles = df.iloc[:, 1]

dataset = TextDataset(articles, model="gpt2", seq_length=512)

# Create the dataloader
dataloader = DataLoader(dataset, batch_size=2, collate_fn=pad_sequences)

In [5]:
#ref: https://github.com/kyegomez/SparseAttention/blob/main/sparse_attention.py

def get_attn_mask(n, attn_mode, local_attn_ctx=None):
    if attn_mode == 'all':
        b = torch.tril(torch.ones([n, n]))
    elif attn_mode == 'local':
        bandwidth = local_attn_ctx
        ctx = min(n - 1, bandwidth - 1)
        b = torch.tril(torch.ones([n, n]), ctx)
    elif attn_mode == 'strided':
        stride = local_attn_ctx
        x = torch.reshape(torch.arange(n, dtype=torch.int32), [n, 1])
        y = torch.transpose(x, 0, 1)
        z = torch.zeros([n, n], dtype=torch.int32)
        q = z + x
        k = z + y
        c1 = q >= k
        c2 = torch.eq(torch.fmod(q - k, stride), 0)
        c3 = torch.logical_and(c1, c2)
        b = c3.float()
    else:
        raise ValueError('Not yet implemented')
    b = torch.reshape(b, [1, 1, n, n])
    return b

def strided_transpose(x, n_ctx, local_attn_ctx, blocksize):
    bT_ctx = n_ctx // local_attn_ctx
    assert bT_ctx % blocksize == 0, f'{bT_ctx}, {blocksize}'
    n, t, embd = x.size()
    x = torch.reshape(x, [n, bT_ctx, local_attn_ctx, embd])
    x = torch.transpose(x, 1, 2)
    x = torch.reshape(x, [n, t, embd])
    return x

def split_heads(x, n):
    return torch.transpose(split_states(x, n), 1, 2)

def merge_heads(x):
    return merge_states(torch.transpose(x, 1, 2))

def split_states(x, n):
    x_shape = x.size()
    m = x_shape[-1]
    new_x_shape = x_shape[:-1] + [n, m // n]
    return torch.reshape(x, new_x_shape)

def merge_states(x):
    x_shape = x.size()
    new_x_shape = x_shape[:-2] + [np.prod(x_shape[-2:])]
    return torch.reshape(x, new_x_shape)

def blocksparse_attention_impl(q, k, v, heads, attn_mode, local_attn_ctx=None, blocksize=32):
    n_ctx = q.size()[1]
    if attn_mode == 'strided':
        q = strided_transpose(q, n_ctx, local_attn_ctx, blocksize)
        k = strided_transpose(k, n_ctx, local_attn_ctx, blocksize)
        v = strided_transpose(v, n_ctx, local_attn_ctx, blocksize)
    n_state = q.size()[-1] // heads
    scale_amount = 1.0 / np.sqrt(n_state)
    w = torch.matmul(q, k.transpose(-2, -1))
    w = F.softmax(w * scale_amount, dim=-1)
    a = torch.matmul(w, v)
    if attn_mode == 'strided':
        n, t, embd = a.size()
        bT_ctx = n_ctx // local_attn_ctx
        a = torch.reshape(a, [n, local_attn_ctx, bT_ctx, embd])
        a = torch.transpose(a, 1, 2)
        a = torch.reshape(a, [n, t, embd])
    return a

class SparseAttention(nn.Module):
    # n_batch = 4
    # n_ctx = 1024
    # n_embd = 256
    # heads = 4
    # attn_mode = "all"
    # local_attn_ctx = 32
    # blocksize = 32
    def __init__(self, config):
        super(SparseAttention, self).__init__()
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_dim = self.n_embd // self.n_head
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(self.n_embd, 3 * self.n_embd)
        self.proj = nn.Linear(self.n_embd, self.n_embd)
        self.dropout_p = config.dropout
        self.blocksize = config.block_size
        self.local_attn_ctx = 32
        self.attn_mode = "all"
        

    def forward(self, x):
        b, t, c = x.size()
        qkv = self.qkv(x).view(b, t, 3, self.n_head, self.head_dim)
        q, k, v = qkv.unbind(2)
        q, k, v = [rearrange(x, 'b t h d -> (b h) t d') for x in (q, k, v)]

        # Implement sparse attention logic
        attn_output = blocksparse_attention_impl(q, k, v, self.n_head, self.attn_mode, self.local_attn_ctx, self.blocksize)
        attn_output = rearrange(attn_output, '(b h) t d -> b t (h d)', h=self.n_head)
        attn_output = self.proj(attn_output)

        return attn_output

In [6]:
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd, eps=1e-5)
        self.attn = SparseAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd, eps=1e-5)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

# Assuming GPTConfig is defined somewhere in the code
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None, "Config must include vocab_size"
        assert config.block_size is not None, "Config must include block_size"
        self.config = config

        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embd),
            'wpe': nn.Embedding(config.block_size, config.n_embd),
            'drop': nn.Dropout(config.dropout),
            'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embd, eps=1e-5),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.tie_weights()

        self.apply(self._init_weights)
        self.init_residuals()
        print(f"Number of parameters: {self.get_num_params()/1e6:.2f}M")

    def tie_weights(self):
        self.transformer.wte.weight = self.lm_head.weight

    def init_residuals(self):
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

    def get_num_params(self, non_embedding=True):
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        logits = self.lm_head(x)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            return logits, loss
        else:
            logits = logits[:, [-1], :]  # Use only the last token's logits
            return logits, None

    def crop_block_size(self, block_size):
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}, "Invalid model type"
        override_args = override_args or {}
        assert all(k == 'dropout' for k in override_args), "Only 'dropout' can be overridden"

        print(f"Loading weights from pretrained model: {model_type}")

        config_args = cls.get_config_args(model_type)
        if 'dropout' in override_args:
            print(f"Overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']

        config = GPTConfig(**config_args)
        model = GPT(config)
        model.load_pretrained_weights(model_type)
        return model

    @staticmethod
    def get_config_args(model_type):
        config_map = {
            'gpt2': {'n_layer': 12, 'n_head': 12, 'n_embd': 768},
            'gpt2-medium': {'n_layer': 24, 'n_head': 16, 'n_embd': 1024},
            'gpt2-large': {'n_layer': 36, 'n_head': 20, 'n_embd': 1280},
            'gpt2-xl': {'n_layer': 48, 'n_head': 25, 'n_embd': 1600},
        }
        config_args = config_map[model_type]
        config_args.update({'vocab_size': 50257, 'block_size': 1024, 'bias': True, 'dropout': 0.1})
        return config_args

    def load_pretrained_weights(self, model_type):
        model_hf = transformers.GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()
        sd = self.state_dict()
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']

        for k, v in sd_hf.items():
            if k in sd:
                if any(k.endswith(w) for w in transposed):
                    sd[k].copy_(v.t())
                else:
                    sd[k].copy_(v)

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        N = self.get_num_params()
        L, H, Q, T = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.block_size
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter * (1.0 / dt)
        flops_promised = 312e12
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [7]:
#training dependencies
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler
from contextlib import nullcontext
import torch.distributed as dist
from torch.distributed import init_process_group, destroy_process_group

In [8]:
# Default config values designed to train a GPT-2 (124M) on OpenWebText
# I/O
out_dir = 'out'
eval_interval = 2000
log_interval = 1
eval_iters = 200
eval_only = False
always_save_checkpoint = True
init_from = 'scratch'  # 'scratch' or 'resume' or 'gpt2*'

# Data
# data_path = 'data/openwebtext.txt'  # Path to your text file
gradient_accumulation_steps = 5 * 8
batch_size = 2
block_size = 512

# Model
n_layer = 6  # Reduce the number of layers
n_head = 8   # Reduce the number of attention heads
n_embd = 512 # Reduce the embedding size
dropout = 0.0
bias = False

# AdamW optimizer
learning_rate = 6e-4
max_iters = 1000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0

# Learning rate decay settings
decay_lr = True
warmup_iters = 200
lr_decay_iters = 1000
min_lr = 6e-5

# DDP settings
backend = 'gloo'

# System
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
dtype = 'float32'  # MPS currently supports only float32
compile = False  # Disable compilation for now as PyTorch 2.0 is not yet stable on MPS
# -----------------------------------------------------------------------------

# Collect configuration keys
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
config = {k: globals()[k] for k in config_keys}

# Various initializations and derived attributes
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'mps:{ddp_local_rank}' if torch.backends.mps.is_available() else f'cuda:{ddp_local_rank}'
    torch.mps.set_device(device)
    master_process = ddp_rank == 0
    seed_offset = ddp_rank
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    master_process = True
    seed_offset = 0
    ddp_world_size = 1

tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"Tokens per iteration will be: {tokens_per_iter:,}")

if master_process:
    os.makedirs(out_dir, exist_ok=True)

torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'mps' if 'mps' in device else 'cpu'
ptdtype = torch.float32  # Currently, MPS supports only float32
ctx = nullcontext()

Tokens per iteration will be: 40,960


In [9]:
#configuration
class GPTConfig:
    def __init__(self, vocab_size, block_size, n_layer, n_head, n_embd, dropout, bias=True):
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        self.bias = bias
        
# Initialize iteration number and best validation loss
iter_num = 0
best_val_loss = 1e9

# Model initialization
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=len(dataset.tokens), dropout=dropout)

if init_from == 'scratch':
    print("Initializing a new model from scratch")
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_from, override_args)
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = getattr(model.config, k)

if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size

model.to(device)

Initializing a new model from scratch
Number of parameters: 19.18M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(512, 512)
    (wpe): Embedding(512, 512)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): SparseAttention(
          (qkv): Linear(in_features=512, out_features=1536, bias=True)
          (proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2048, out_features=512, bias=True)
          (3): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=512, bias=False)
)

In [10]:
# Initialize a GradScaler
scaler = GradScaler(enabled=(dtype == 'float16' and 'cuda' in device))

# Configure the optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None  # Free up memory

# Wrap model into DDP container if needed
if ddp:
    print(f"Starting parallel process with rank {ddp_rank}")
    model = DDP(model, device_ids=[ddp_local_rank])


# Function to estimate loss over splits
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        with tqdm(total=eval_iters, desc=f"Evaluating {split}") as pbar:
            for k in range(eval_iters):
                batch = next(iter(dataloader))
                X, Y = batch['input_ids'].to(device), batch['targets'].to(device)
                with ctx:
                    logits, loss = model(X, Y)
                losses[k] = loss.item()
                pbar.update(1)
        out[split] = losses.mean()
    model.train()
    return out

# Learning rate decay scheduler
def get_lr(it):
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    if it > lr_decay_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)

In [11]:
import time
from tqdm import tqdm


wandb.init(
    # set the wandb project where this run will be logged
    project=project_name, name= run_name, config=config)

raw_model = model.module if ddp else model
local_iter_num = 0
running_mfu = -1.0
t0 = time.time()

while iter_num <= max_iters:
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train_loss": losses['train'],
                "val_loss": losses['val'],
                "lr": lr,
                "mfu": running_mfu * 100,  # convert to percentage
            })


        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    with tqdm(total=gradient_accumulation_steps, desc=f"Training step {iter_num}") as pbar:
        total_train_loss = 0.0  # Initialize total training loss
        for micro_step in range(gradient_accumulation_steps):
            if ddp:
                model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
            batch = next(iter(dataloader))
            X, Y = batch['input_ids'].to(device), batch['targets'].to(device)
            with ctx:
                logits, loss = model(X, Y)
                loss = loss / gradient_accumulation_steps
            scaler.scale(loss).backward()
            pbar.update(1)

    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5:
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt * 1000:.2f}ms, mfu {running_mfu * 100:.2f}%")

        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train_loss": total_train_loss,
                "lr": lr,
                "mfu": running_mfu * 100,  # convert to percentage
            })

    iter_num += 1
    local_iter_num += 1

if ddp:
    destroy_process_group()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnielspace[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011168043511111137, max=1.0…

Evaluating train: 100%|██████████| 200/200 [00:11<00:00, 17.95it/s]
Evaluating val: 100%|██████████| 200/200 [00:11<00:00, 18.00it/s]


step 0: train loss 2.0792, val loss 2.0792


Training step 0: 100%|██████████| 40/40 [00:06<00:00,  6.17it/s]


iter 0: loss 2.0792, time 28972.68ms, mfu -100.00%


Training step 1: 100%|██████████| 40/40 [00:06<00:00,  6.45it/s]


iter 1: loss 2.0792, time 6287.87ms, mfu -100.00%


Training step 2: 100%|██████████| 40/40 [00:06<00:00,  6.56it/s]


iter 2: loss 2.0515, time 6272.93ms, mfu -100.00%


Training step 3: 100%|██████████| 40/40 [00:06<00:00,  6.60it/s]


iter 3: loss 1.9979, time 6204.72ms, mfu -100.00%


Training step 4: 100%|██████████| 40/40 [00:06<00:00,  6.46it/s]


iter 4: loss 1.9236, time 6341.40ms, mfu -100.00%


Training step 5: 100%|██████████| 40/40 [00:06<00:00,  6.57it/s]


iter 5: loss 1.8395, time 6309.09ms, mfu 0.28%


Training step 6: 100%|██████████| 40/40 [00:06<00:00,  6.43it/s]


iter 6: loss 1.7573, time 6345.16ms, mfu 0.28%


Training step 7: 100%|██████████| 40/40 [00:06<00:00,  6.48it/s]


iter 7: loss 1.6839, time 6334.63ms, mfu 0.28%


Training step 8: 100%|██████████| 40/40 [00:06<00:00,  6.50it/s]


iter 8: loss 1.6196, time 6315.10ms, mfu 0.28%


Training step 9: 100%|██████████| 40/40 [00:06<00:00,  6.55it/s]


iter 9: loss 1.5624, time 6257.82ms, mfu 0.28%


Training step 10: 100%|██████████| 40/40 [00:06<00:00,  6.60it/s]


iter 10: loss 1.5099, time 6211.21ms, mfu 0.28%


Training step 11: 100%|██████████| 40/40 [00:06<00:00,  6.61it/s]


iter 11: loss 1.4613, time 6206.24ms, mfu 0.28%


Training step 12: 100%|██████████| 40/40 [00:06<00:00,  6.54it/s]


iter 12: loss 1.4177, time 6271.60ms, mfu 0.28%


Training step 13: 100%|██████████| 40/40 [00:06<00:00,  6.52it/s]


iter 13: loss 1.3793, time 6297.36ms, mfu 0.28%


Training step 14: 100%|██████████| 40/40 [00:06<00:00,  6.50it/s]


iter 14: loss 1.3436, time 6323.59ms, mfu 0.28%


Training step 15: 100%|██████████| 40/40 [00:06<00:00,  6.51it/s]


iter 15: loss 1.3084, time 6313.74ms, mfu 0.28%


Training step 16: 100%|██████████| 40/40 [00:06<00:00,  6.54it/s]


iter 16: loss 1.2727, time 6279.07ms, mfu 0.28%


Training step 17: 100%|██████████| 40/40 [00:06<00:00,  6.43it/s]


iter 17: loss 1.2340, time 6390.66ms, mfu 0.28%


Training step 18: 100%|██████████| 40/40 [00:06<00:00,  6.33it/s]


iter 18: loss 1.1878, time 6504.18ms, mfu 0.28%


Training step 19: 100%|██████████| 40/40 [00:06<00:00,  6.40it/s]


iter 19: loss 1.1325, time 6385.13ms, mfu 0.28%


Training step 20: 100%|██████████| 40/40 [00:06<00:00,  6.22it/s]


iter 20: loss 1.0673, time 6548.14ms, mfu 0.28%


Training step 21: 100%|██████████| 40/40 [00:06<00:00,  6.18it/s]


iter 21: loss 0.9927, time 6590.30ms, mfu 0.28%


Training step 22: 100%|██████████| 40/40 [00:06<00:00,  6.29it/s]


iter 22: loss 0.9138, time 6501.31ms, mfu 0.28%


Training step 23: 100%|██████████| 40/40 [00:06<00:00,  6.11it/s]


iter 23: loss 0.8300, time 6677.75ms, mfu 0.27%


Training step 24: 100%|██████████| 40/40 [00:07<00:00,  5.70it/s]


iter 24: loss 0.7440, time 7144.75ms, mfu 0.27%


Training step 25: 100%|██████████| 40/40 [00:06<00:00,  6.15it/s]


iter 25: loss 0.6613, time 6627.22ms, mfu 0.27%


Training step 26: 100%|██████████| 40/40 [00:06<00:00,  5.84it/s]


iter 26: loss 0.5874, time 6999.51ms, mfu 0.27%


Training step 27: 100%|██████████| 40/40 [00:06<00:00,  5.86it/s]


iter 27: loss 0.5171, time 7220.75ms, mfu 0.27%


Training step 28: 100%|██████████| 40/40 [00:06<00:00,  6.02it/s]


iter 28: loss 0.4516, time 6750.00ms, mfu 0.27%


Training step 29: 100%|██████████| 40/40 [00:06<00:00,  5.81it/s]


iter 29: loss 0.3904, time 7065.47ms, mfu 0.26%


Training step 30: 100%|██████████| 40/40 [00:06<00:00,  5.92it/s]


iter 30: loss 0.3320, time 6870.68ms, mfu 0.26%


Training step 31: 100%|██████████| 40/40 [00:06<00:00,  5.78it/s]


iter 31: loss 0.2788, time 7070.66ms, mfu 0.26%


Training step 32: 100%|██████████| 40/40 [00:06<00:00,  5.79it/s]


iter 32: loss 0.2312, time 7043.51ms, mfu 0.26%


Training step 33: 100%|██████████| 40/40 [00:07<00:00,  5.59it/s]


iter 33: loss 0.1896, time 7363.56ms, mfu 0.26%


Training step 34: 100%|██████████| 40/40 [00:07<00:00,  5.54it/s]


iter 34: loss 0.1541, time 7342.09ms, mfu 0.26%


Training step 35: 100%|██████████| 40/40 [00:07<00:00,  5.37it/s]


iter 35: loss 0.1248, time 7641.67ms, mfu 0.25%


Training step 36: 100%|██████████| 40/40 [00:06<00:00,  5.78it/s]


iter 36: loss 0.1008, time 7085.20ms, mfu 0.25%


Training step 37: 100%|██████████| 40/40 [00:07<00:00,  5.58it/s]


iter 37: loss 0.0815, time 7303.90ms, mfu 0.25%


Training step 38: 100%|██████████| 40/40 [00:07<00:00,  5.49it/s]


iter 38: loss 0.0671, time 7407.49ms, mfu 0.25%


Training step 39: 100%|██████████| 40/40 [00:07<00:00,  5.68it/s]


iter 39: loss 0.0556, time 7211.87ms, mfu 0.25%


Training step 40: 100%|██████████| 40/40 [00:07<00:00,  5.47it/s]


iter 40: loss 0.0461, time 7468.69ms, mfu 0.25%


Training step 41: 100%|██████████| 40/40 [00:07<00:00,  5.46it/s]


iter 41: loss 0.0387, time 7503.11ms, mfu 0.25%


Training step 42: 100%|██████████| 40/40 [00:07<00:00,  5.60it/s]


iter 42: loss 0.0327, time 7272.55ms, mfu 0.25%


Training step 43: 100%|██████████| 40/40 [00:07<00:00,  5.47it/s]


iter 43: loss 0.0279, time 7452.56ms, mfu 0.25%


Training step 44: 100%|██████████| 40/40 [00:07<00:00,  5.51it/s]


iter 44: loss 0.0243, time 7443.07ms, mfu 0.24%


Training step 45: 100%|██████████| 40/40 [00:07<00:00,  5.29it/s]


iter 45: loss 0.0212, time 7766.73ms, mfu 0.24%


Training step 46: 100%|██████████| 40/40 [00:07<00:00,  5.24it/s]


iter 46: loss 0.0188, time 7826.88ms, mfu 0.24%


Training step 47: 100%|██████████| 40/40 [00:07<00:00,  5.22it/s]


iter 47: loss 0.0167, time 7856.36ms, mfu 0.24%


Training step 48: 100%|██████████| 40/40 [00:07<00:00,  5.04it/s]


iter 48: loss 0.0149, time 8145.17ms, mfu 0.24%


Training step 49: 100%|██████████| 40/40 [00:07<00:00,  5.04it/s]


iter 49: loss 0.0134, time 8135.21ms, mfu 0.23%


Training step 50: 100%|██████████| 40/40 [00:08<00:00,  4.93it/s]


iter 50: loss 0.0122, time 8311.11ms, mfu 0.23%


Training step 51: 100%|██████████| 40/40 [00:08<00:00,  4.81it/s]


iter 51: loss 0.0111, time 8522.85ms, mfu 0.23%


Training step 52: 100%|██████████| 40/40 [00:08<00:00,  4.84it/s]


iter 52: loss 0.0102, time 8478.73ms, mfu 0.23%


Training step 53: 100%|██████████| 40/40 [00:08<00:00,  4.84it/s]


iter 53: loss 0.0094, time 8483.76ms, mfu 0.23%


Training step 54: 100%|██████████| 40/40 [00:08<00:00,  4.72it/s]


iter 54: loss 0.0087, time 8692.19ms, mfu 0.22%


Training step 55: 100%|██████████| 40/40 [00:08<00:00,  4.64it/s]


iter 55: loss 0.0081, time 8823.55ms, mfu 0.22%


Training step 56: 100%|██████████| 40/40 [00:08<00:00,  4.46it/s]


iter 56: loss 0.0075, time 9210.20ms, mfu 0.22%


Training step 57: 100%|██████████| 40/40 [00:09<00:00,  4.32it/s]


iter 57: loss 0.0070, time 9519.71ms, mfu 0.21%


Training step 58: 100%|██████████| 40/40 [00:09<00:00,  4.42it/s]


iter 58: loss 0.0066, time 9279.11ms, mfu 0.21%


Training step 59: 100%|██████████| 40/40 [00:08<00:00,  4.70it/s]


iter 59: loss 0.0062, time 8736.87ms, mfu 0.21%


Training step 60: 100%|██████████| 40/40 [00:08<00:00,  4.79it/s]


iter 60: loss 0.0058, time 8558.49ms, mfu 0.21%


Training step 61: 100%|██████████| 40/40 [00:08<00:00,  4.73it/s]


iter 61: loss 0.0055, time 8657.40ms, mfu 0.21%


Training step 62: 100%|██████████| 40/40 [00:08<00:00,  4.76it/s]


iter 62: loss 0.0052, time 8604.86ms, mfu 0.21%


Training step 63: 100%|██████████| 40/40 [00:08<00:00,  4.82it/s]


iter 63: loss 0.0049, time 8523.15ms, mfu 0.21%


Training step 64: 100%|██████████| 40/40 [00:08<00:00,  4.86it/s]


iter 64: loss 0.0047, time 8438.62ms, mfu 0.21%


Training step 65: 100%|██████████| 40/40 [00:08<00:00,  4.82it/s]


iter 65: loss 0.0045, time 8507.97ms, mfu 0.21%


Training step 66: 100%|██████████| 40/40 [00:08<00:00,  4.88it/s]


iter 66: loss 0.0043, time 8409.17ms, mfu 0.21%


Training step 67: 100%|██████████| 40/40 [00:07<00:00,  5.03it/s]


iter 67: loss 0.0041, time 8175.20ms, mfu 0.21%


Training step 68: 100%|██████████| 40/40 [00:08<00:00,  4.98it/s]


iter 68: loss 0.0039, time 8235.49ms, mfu 0.21%


Training step 69: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s]


iter 69: loss 0.0037, time 8297.78ms, mfu 0.21%


Training step 70: 100%|██████████| 40/40 [00:07<00:00,  5.07it/s]


iter 70: loss 0.0036, time 8105.12ms, mfu 0.21%


Training step 71: 100%|██████████| 40/40 [00:07<00:00,  5.03it/s]


iter 71: loss 0.0035, time 8145.36ms, mfu 0.21%


Training step 72: 100%|██████████| 40/40 [00:07<00:00,  5.02it/s]


iter 72: loss 0.0033, time 8176.21ms, mfu 0.21%


Training step 73: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s]


iter 73: loss 0.0032, time 8309.55ms, mfu 0.21%


Training step 74: 100%|██████████| 40/40 [00:08<00:00,  4.99it/s]


iter 74: loss 0.0031, time 8228.12ms, mfu 0.21%


Training step 75: 100%|██████████| 40/40 [00:08<00:00,  4.89it/s]


iter 75: loss 0.0030, time 8401.03ms, mfu 0.21%


Training step 76: 100%|██████████| 40/40 [00:08<00:00,  4.80it/s]


iter 76: loss 0.0029, time 8552.22ms, mfu 0.21%


Training step 77: 100%|██████████| 40/40 [00:08<00:00,  4.80it/s]


iter 77: loss 0.0028, time 8563.10ms, mfu 0.21%


Training step 78: 100%|██████████| 40/40 [00:08<00:00,  4.78it/s]


iter 78: loss 0.0027, time 8593.52ms, mfu 0.21%


Training step 79: 100%|██████████| 40/40 [00:08<00:00,  4.79it/s]


iter 79: loss 0.0026, time 8584.49ms, mfu 0.21%


Training step 80: 100%|██████████| 40/40 [00:08<00:00,  4.70it/s]


iter 80: loss 0.0025, time 8734.01ms, mfu 0.21%


Training step 81: 100%|██████████| 40/40 [00:08<00:00,  4.74it/s]


iter 81: loss 0.0025, time 8667.93ms, mfu 0.21%


Training step 82: 100%|██████████| 40/40 [00:08<00:00,  4.63it/s]


iter 82: loss 0.0024, time 8867.31ms, mfu 0.21%


Training step 83: 100%|██████████| 40/40 [00:08<00:00,  4.57it/s]


iter 83: loss 0.0023, time 8968.49ms, mfu 0.21%


Training step 84: 100%|██████████| 40/40 [00:08<00:00,  4.58it/s]


iter 84: loss 0.0022, time 8960.36ms, mfu 0.20%


Training step 85: 100%|██████████| 40/40 [00:08<00:00,  4.65it/s]


iter 85: loss 0.0022, time 8849.45ms, mfu 0.20%


Training step 86: 100%|██████████| 40/40 [00:08<00:00,  4.56it/s]


iter 86: loss 0.0021, time 8985.99ms, mfu 0.20%


Training step 87: 100%|██████████| 40/40 [00:09<00:00,  4.40it/s]


iter 87: loss 0.0021, time 9281.66ms, mfu 0.20%


Training step 88: 100%|██████████| 40/40 [00:08<00:00,  4.54it/s]


iter 88: loss 0.0020, time 9053.98ms, mfu 0.20%


Training step 89: 100%|██████████| 40/40 [00:08<00:00,  4.51it/s]


iter 89: loss 0.0020, time 9090.53ms, mfu 0.20%


Training step 90: 100%|██████████| 40/40 [00:09<00:00,  4.36it/s]


iter 90: loss 0.0019, time 9396.01ms, mfu 0.20%


Training step 91: 100%|██████████| 40/40 [00:09<00:00,  4.36it/s]


iter 91: loss 0.0019, time 9397.36ms, mfu 0.20%


Training step 92: 100%|██████████| 40/40 [00:08<00:00,  4.46it/s]


iter 92: loss 0.0018, time 9217.86ms, mfu 0.20%


Training step 93: 100%|██████████| 40/40 [00:09<00:00,  4.40it/s]


iter 93: loss 0.0018, time 9311.67ms, mfu 0.20%


Training step 94: 100%|██████████| 40/40 [00:09<00:00,  4.30it/s]


iter 94: loss 0.0017, time 9532.29ms, mfu 0.20%


Training step 95: 100%|██████████| 40/40 [00:09<00:00,  4.38it/s]


iter 95: loss 0.0017, time 9369.97ms, mfu 0.19%


Training step 96: 100%|██████████| 40/40 [00:08<00:00,  4.53it/s]


iter 96: loss 0.0016, time 9060.36ms, mfu 0.19%


Training step 97: 100%|██████████| 40/40 [00:09<00:00,  4.30it/s]


iter 97: loss 0.0016, time 9536.89ms, mfu 0.19%


Training step 98: 100%|██████████| 40/40 [00:09<00:00,  4.22it/s]


iter 98: loss 0.0016, time 9702.55ms, mfu 0.19%


Training step 99: 100%|██████████| 40/40 [00:09<00:00,  4.37it/s]


iter 99: loss 0.0015, time 9404.45ms, mfu 0.19%


Training step 100: 100%|██████████| 40/40 [00:09<00:00,  4.34it/s]


iter 100: loss 0.0015, time 9433.11ms, mfu 0.19%


Training step 101: 100%|██████████| 40/40 [00:09<00:00,  4.18it/s]


iter 101: loss 0.0015, time 9796.40ms, mfu 0.19%


Training step 102: 100%|██████████| 40/40 [00:09<00:00,  4.08it/s]


iter 102: loss 0.0014, time 10085.78ms, mfu 0.19%


Training step 103: 100%|██████████| 40/40 [00:09<00:00,  4.20it/s]


iter 103: loss 0.0014, time 9788.30ms, mfu 0.19%


Training step 104: 100%|██████████| 40/40 [00:09<00:00,  4.16it/s]


iter 104: loss 0.0014, time 9855.76ms, mfu 0.19%


Training step 105: 100%|██████████| 40/40 [00:10<00:00,  3.90it/s]


iter 105: loss 0.0013, time 10510.60ms, mfu 0.18%


Training step 106: 100%|██████████| 40/40 [00:09<00:00,  4.07it/s]


iter 106: loss 0.0013, time 10091.88ms, mfu 0.18%


Training step 107: 100%|██████████| 40/40 [00:09<00:00,  4.17it/s]


iter 107: loss 0.0013, time 9841.14ms, mfu 0.18%


Training step 108: 100%|██████████| 40/40 [00:09<00:00,  4.07it/s]


iter 108: loss 0.0012, time 10072.41ms, mfu 0.18%


Training step 109: 100%|██████████| 40/40 [00:10<00:00,  3.98it/s]


iter 109: loss 0.0012, time 10284.20ms, mfu 0.18%


Training step 110: 100%|██████████| 40/40 [00:09<00:00,  4.07it/s]


iter 110: loss 0.0012, time 10088.13ms, mfu 0.18%


Training step 111: 100%|██████████| 40/40 [00:09<00:00,  4.13it/s]


iter 111: loss 0.0012, time 9932.86ms, mfu 0.18%


Training step 112: 100%|██████████| 40/40 [00:10<00:00,  3.94it/s]


iter 112: loss 0.0011, time 10393.79ms, mfu 0.18%


Training step 113: 100%|██████████| 40/40 [00:10<00:00,  3.99it/s]


iter 113: loss 0.0011, time 10283.37ms, mfu 0.18%


Training step 114: 100%|██████████| 40/40 [00:09<00:00,  4.10it/s]


iter 114: loss 0.0011, time 10017.36ms, mfu 0.18%


Training step 115: 100%|██████████| 40/40 [00:10<00:00,  3.95it/s]


iter 115: loss 0.0011, time 10390.88ms, mfu 0.18%


Training step 116: 100%|██████████| 40/40 [00:12<00:00,  3.18it/s]


iter 116: loss 0.0010, time 12845.46ms, mfu 0.17%


Training step 117: 100%|██████████| 40/40 [00:10<00:00,  3.84it/s]


iter 117: loss 0.0010, time 10693.46ms, mfu 0.17%


Training step 118: 100%|██████████| 40/40 [00:10<00:00,  3.93it/s]


iter 118: loss 0.0010, time 10456.20ms, mfu 0.17%


Training step 119: 100%|██████████| 40/40 [00:10<00:00,  3.66it/s]


iter 119: loss 0.0010, time 11175.97ms, mfu 0.17%


Training step 120: 100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


iter 120: loss 0.0009, time 11126.87ms, mfu 0.17%


Training step 121: 100%|██████████| 40/40 [00:10<00:00,  3.88it/s]


iter 121: loss 0.0009, time 10614.73ms, mfu 0.17%


Training step 122: 100%|██████████| 40/40 [00:10<00:00,  3.85it/s]


iter 122: loss 0.0009, time 10664.28ms, mfu 0.17%


Training step 123: 100%|██████████| 40/40 [00:10<00:00,  3.71it/s]


iter 123: loss 0.0009, time 11025.47ms, mfu 0.17%


Training step 124: 100%|██████████| 40/40 [00:10<00:00,  3.88it/s]


iter 124: loss 0.0008, time 10613.13ms, mfu 0.17%


Training step 125: 100%|██████████| 40/40 [00:10<00:00,  3.83it/s]


iter 125: loss 0.0008, time 10718.66ms, mfu 0.17%


Training step 126: 100%|██████████| 40/40 [00:10<00:00,  3.76it/s]


iter 126: loss 0.0008, time 10897.63ms, mfu 0.17%


Training step 127: 100%|██████████| 40/40 [00:10<00:00,  3.69it/s]


iter 127: loss 0.0008, time 11119.12ms, mfu 0.17%


Training step 128: 100%|██████████| 40/40 [00:10<00:00,  3.87it/s]


iter 128: loss 0.0008, time 10600.86ms, mfu 0.17%


Training step 129: 100%|██████████| 40/40 [00:10<00:00,  3.71it/s]


iter 129: loss 0.0007, time 11061.35ms, mfu 0.16%


Training step 130: 100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


iter 130: loss 0.0007, time 11144.56ms, mfu 0.16%


Training step 131: 100%|██████████| 40/40 [00:10<00:00,  3.82it/s]


iter 131: loss 0.0007, time 10753.83ms, mfu 0.16%


Training step 132: 100%|██████████| 40/40 [00:10<00:00,  3.79it/s]


iter 132: loss 0.0007, time 10814.37ms, mfu 0.16%


Training step 133: 100%|██████████| 40/40 [00:12<00:00,  3.10it/s]


iter 133: loss 0.0007, time 13238.06ms, mfu 0.16%


Training step 134:  62%|██████▎   | 25/40 [00:08<00:05,  2.90it/s]


KeyboardInterrupt: 

In [12]:
device

'mps'

In [None]:
x