In [1]:
import math

import torch
from torch.nn.utils.rnn import pad_sequence
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import transformers

from einops import rearrange
from typing import Optional, Union

import tiktoken
import pandas as pd
import re

import os


In [2]:
import wandb

# Initialize wandb
wandb_log = True  
run_name = 'Flash Attention - mps'
project_name = "Attention Benchmark "


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnielspace[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
class TextDataset(Dataset):
    def __init__(self, articles, model="gpt2", seq_length=512):
        self.tokenizer = tiktoken.get_encoding(model)
        self.vocab_size = self.tokenizer.n_vocab
        self.seq_length = seq_length
        self.articles = articles.apply(self.preprocess_and_tokenize)
        
        self.input_ids, self.targets, self.attention_masks = self.create_sequences()

    def preprocess_and_tokenize(self, text):
        # Preprocess text
        text = text.lower()
        text = re.sub(r'[^a-zA-Z\s]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        
        # Tokenize text
        tokens = self.tokenizer.encode(text)
        
        # Check for invalid token indices
        assert all(token < self.vocab_size for token in tokens), "Token index out of range"
        
        # Pad and truncate tokens
        if len(tokens) > self.seq_length:
            tokens = tokens[:self.seq_length]
        else:
            tokens += [0] * (self.seq_length - len(tokens))
        self.tokens = tokens
        return tokens

    def create_sequences(self):
        input_ids = []
        attention_masks = []
        targets = []
        
        for tokens in self.articles:
            input_ids.append(tokens[:-1])  # Exclude the last token for input
            targets.append(tokens[1:])     # Exclude the first token for target
            attention_masks.append([1 if token != 0 else 0 for token in tokens[:-1]])
        
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_masks = torch.tensor(attention_masks, dtype=torch.long)
        targets = torch.tensor(targets, dtype=torch.long)
        
        return input_ids, targets, attention_masks

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        input_seq = self.input_ids[idx]
        attention_mask = self.attention_masks[idx]
        target_seq = self.targets[idx]
        
        sample = {'input_ids': input_seq, 'targets': target_seq, 'attention_mask': attention_mask}
        return sample

def pad_sequences(batch):
    input_ids = [item['input_ids'] for item in batch]
    attention_masks = [item['attention_mask'] for item in batch]
    targets = [item['targets'] for item in batch]

    input_ids_padded = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_masks_padded = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=0)
    targets_padded = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=0)

    return {'input_ids': input_ids_padded, 'targets': targets_padded, 'attention_mask': attention_masks_padded}

In [4]:
# Create the dataset
df = pd.read_parquet("hf://datasets/gamino/wiki_medical_terms/wiki_medical_terms.parquet")
articles = df.iloc[:, 1]

dataset = TextDataset(articles, model="gpt2", seq_length=512)

# Create the dataloader
dataloader = DataLoader(dataset, batch_size=2, collate_fn=pad_sequences)

In [5]:
for data in dataloader:
    x, y, att = data['input_ids'], data['targets'], data['attention_mask']
    print(x,y,att)
    break

tensor([[ 1845, 23253,   321,  ..., 14383,  2465,   287],
        [  330,   398,  1533,  ...,   290, 24039,  1390]]) tensor([[23253,   321,   349,  ...,  2465,   287,  2276],
        [  398,  1533,  3400,  ..., 24039,  1390, 32674]]) tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])


In [6]:
x.shape, y.shape

(torch.Size([2, 511]), torch.Size([2, 511]))

In [7]:
class FlashAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_dim = self.n_embd // self.n_head
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(self.n_embd, 3 * self.n_embd)
        self.proj = nn.Linear(self.n_embd, self.n_embd)
        self.dropout_p = config.dropout
        self.causal = True  # assuming causal for GPT-like model

    def forward(self, x):
        b, t, c = x.size()
        qkv = self.qkv(x).view(b, t, 3, self.n_head, self.head_dim)
        q, k, v = qkv.unbind(2)
        q, k, v = [rearrange(x, 'b t h d -> (b h) t d') for x in (q, k, v)]

        # Calculate attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if self.causal:
            mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device, dtype=torch.bool), diagonal=1)
            attn_scores = attn_scores.masked_fill(mask, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.dropout_p, training=self.training)

        attn_output = torch.matmul(attn_weights, v)
        attn_output = rearrange(attn_output, '(b h) t d -> b t (h d)', h=self.n_head)
        attn_output = self.proj(attn_output)

        # Save tensors for backward pass
        self.saved_tensors = (q, k, v, attn_weights)
        return attn_output

    def backward(self, dout):
        q, k, v, attn_weights = self.saved_tensors

        # Gradient of attention output
        datt = torch.matmul(dout, v.transpose(-2, -1))
        dv = torch.matmul(attn_weights.transpose(-2, -1), dout)

        # Gradient of attention weights
        datt_weights = dout @ v.transpose(-2, -1)

        # Masked fill in the gradient if causal
        if self.causal:
            mask = torch.triu(torch.ones(datt.size(-2), datt.size(-1), device=datt.device, dtype=torch.bool), diagonal=1)
            datt = datt.masked_fill(mask, 0.0)

        # Gradient of Q, K, V
        dq = datt @ k.transpose(-2, -1)
        dk = datt.transpose(-2, -1) @ q
        dq = dq.view_as(q)
        dk = dk.view_as(k)
        dv = dv.view_as(v)

        return dq, dk, dv


In [8]:
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd, eps=1e-5)
        self.attn = FlashAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd, eps=1e-5)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

# Assuming GPTConfig is defined somewhere in the code
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None, "Config must include vocab_size"
        assert config.block_size is not None, "Config must include block_size"
        self.config = config

        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embd),
            'wpe': nn.Embedding(config.block_size, config.n_embd),
            'drop': nn.Dropout(config.dropout),
            'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embd, eps=1e-5),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.tie_weights()

        self.apply(self._init_weights)
        self.init_residuals()
        print(f"Number of parameters: {self.get_num_params()/1e6:.2f}M")

    def tie_weights(self):
        self.transformer.wte.weight = self.lm_head.weight

    def init_residuals(self):
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

    def get_num_params(self, non_embedding=True):
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        logits = self.lm_head(x)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            return logits, loss
        else:
            logits = logits[:, [-1], :]  # Use only the last token's logits
            return logits, None

    def crop_block_size(self, block_size):
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}, "Invalid model type"
        override_args = override_args or {}
        assert all(k == 'dropout' for k in override_args), "Only 'dropout' can be overridden"

        print(f"Loading weights from pretrained model: {model_type}")

        config_args = cls.get_config_args(model_type)
        if 'dropout' in override_args:
            print(f"Overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']

        config = GPTConfig(**config_args)
        model = GPT(config)
        model.load_pretrained_weights(model_type)
        return model

    @staticmethod
    def get_config_args(model_type):
        config_map = {
            'gpt2': {'n_layer': 12, 'n_head': 12, 'n_embd': 768},
            'gpt2-medium': {'n_layer': 24, 'n_head': 16, 'n_embd': 1024},
            'gpt2-large': {'n_layer': 36, 'n_head': 20, 'n_embd': 1280},
            'gpt2-xl': {'n_layer': 48, 'n_head': 25, 'n_embd': 1600},
        }
        config_args = config_map[model_type]
        config_args.update({'vocab_size': 50257, 'block_size': 1024, 'bias': True, 'dropout': 0.1})
        return config_args

    def load_pretrained_weights(self, model_type):
        model_hf = transformers.GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()
        sd = self.state_dict()
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']

        for k, v in sd_hf.items():
            if k in sd:
                if any(k.endswith(w) for w in transposed):
                    sd[k].copy_(v.t())
                else:
                    sd[k].copy_(v)

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        N = self.get_num_params()
        L, H, Q, T = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.block_size
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter * (1.0 / dt)
        flops_promised = 312e12
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [9]:
#training dependencies
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler
from contextlib import nullcontext
import torch.distributed as dist
from torch.distributed import init_process_group, destroy_process_group

In [10]:
# Default config values designed to train a GPT-2 (124M) on OpenWebText
# I/O
out_dir = 'out'
eval_interval = 2000
log_interval = 1
eval_iters = 200
eval_only = False
always_save_checkpoint = True
init_from = 'scratch'  # 'scratch' or 'resume' or 'gpt2*'

# Data
# data_path = 'data/openwebtext.txt'  # Path to your text file
gradient_accumulation_steps = 5 * 8
batch_size = 2
block_size = 512

# Model
n_layer = 6  # Reduce the number of layers
n_head = 8   # Reduce the number of attention heads
n_embd = 512 # Reduce the embedding size
dropout = 0.0
bias = False

# AdamW optimizer
learning_rate = 6e-4
max_iters = 1000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0

# Learning rate decay settings
decay_lr = True
warmup_iters = 200
lr_decay_iters = 1000
min_lr = 6e-5

# DDP settings
backend = 'gloo'

# System
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
dtype = 'float32'  # MPS currently supports only float32
compile = False  # Disable compilation for now as PyTorch 2.0 is not yet stable on MPS
# -----------------------------------------------------------------------------

# Collect configuration keys
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
config = {k: globals()[k] for k in config_keys}

# Various initializations and derived attributes
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'mps:{ddp_local_rank}' if torch.backends.mps.is_available() else f'cuda:{ddp_local_rank}'
    torch.mps.set_device(device)
    master_process = ddp_rank == 0
    seed_offset = ddp_rank
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    master_process = True
    seed_offset = 0
    ddp_world_size = 1

tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"Tokens per iteration will be: {tokens_per_iter:,}")

if master_process:
    os.makedirs(out_dir, exist_ok=True)

torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'mps' if 'mps' in device else 'cpu'
ptdtype = torch.float32  # Currently, MPS supports only float32
ctx = nullcontext()

Tokens per iteration will be: 40,960


In [11]:
#configuration
class GPTConfig:
    def __init__(self, vocab_size, block_size, n_layer, n_head, n_embd, dropout, bias=True):
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        self.bias = bias
        
# Initialize iteration number and best validation loss
iter_num = 0
best_val_loss = 1e9

# Model initialization
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=len(dataset.tokens), dropout=dropout)

if init_from == 'scratch':
    print("Initializing a new model from scratch")
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_from, override_args)
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = getattr(model.config, k)

if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size

model.to(device)

Initializing a new model from scratch
Number of parameters: 19.18M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(512, 512)
    (wpe): Embedding(512, 512)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): FlashAttention(
          (qkv): Linear(in_features=512, out_features=1536, bias=True)
          (proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2048, out_features=512, bias=True)
          (3): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=512, bias=False)
)

In [12]:
# Initialize a GradScaler
scaler = GradScaler(enabled=(dtype == 'float16' and 'cuda' in device))

# Configure the optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None  # Free up memory

# Wrap model into DDP container if needed
if ddp:
    print(f"Starting parallel process with rank {ddp_rank}")
    model = DDP(model, device_ids=[ddp_local_rank])


# Function to estimate loss over splits
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        with tqdm(total=eval_iters, desc=f"Evaluating {split}") as pbar:
            for k in range(eval_iters):
                batch = next(iter(dataloader))
                X, Y = batch['input_ids'].to(device), batch['targets'].to(device)
                with ctx:
                    logits, loss = model(X, Y)
                losses[k] = loss.item()
                pbar.update(1)
        out[split] = losses.mean()
    model.train()
    return out

# Learning rate decay scheduler
def get_lr(it):
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    if it > lr_decay_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)

In [13]:
import time
from tqdm import tqdm


wandb.init(
    # set the wandb project where this run will be logged
    project=project_name, name= run_name, config=config)

raw_model = model.module if ddp else model
local_iter_num = 0
running_mfu = -1.0
t0 = time.time()

while iter_num <= max_iters:
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train_loss": losses['train'],
                "val_loss": losses['val'],
                "lr": lr,
                "mfu": running_mfu * 100,  # convert to percentage
            })


        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    with tqdm(total=gradient_accumulation_steps, desc=f"Training step {iter_num}") as pbar:
        total_train_loss = 0.0  # Initialize total training loss
        for micro_step in range(gradient_accumulation_steps):
            if ddp:
                model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
            batch = next(iter(dataloader))
            X, Y = batch['input_ids'].to(device), batch['targets'].to(device)
            with ctx:
                logits, loss = model(X, Y)
                loss = loss / gradient_accumulation_steps
            scaler.scale(loss).backward()
            pbar.update(1)

    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5:
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt * 1000:.2f}ms, mfu {running_mfu * 100:.2f}%")

        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train_loss": total_train_loss,
                "lr": lr,
                "mfu": running_mfu * 100,  # convert to percentage
            })

    iter_num += 1
    local_iter_num += 1

if ddp:
    destroy_process_group()

Evaluating train: 100%|██████████| 200/200 [00:16<00:00, 12.11it/s]
Evaluating val: 100%|██████████| 200/200 [00:14<00:00, 13.96it/s]


step 0: train loss 2.0631, val loss 2.0631


Training step 0: 100%|██████████| 40/40 [00:08<00:00,  4.59it/s]


iter 0: loss 2.0631, time 40256.62ms, mfu -100.00%


Training step 1: 100%|██████████| 40/40 [00:07<00:00,  5.47it/s]


iter 1: loss 2.0631, time 7408.92ms, mfu -100.00%


Training step 2: 100%|██████████| 40/40 [00:07<00:00,  5.40it/s]


iter 2: loss 2.0370, time 7600.01ms, mfu -100.00%


Training step 3: 100%|██████████| 40/40 [00:07<00:00,  5.61it/s]


iter 3: loss 1.9862, time 7310.86ms, mfu -100.00%


Training step 4: 100%|██████████| 40/40 [00:07<00:00,  5.46it/s]


iter 4: loss 1.9157, time 7519.17ms, mfu -100.00%


Training step 5: 100%|██████████| 40/40 [00:07<00:00,  5.45it/s]


iter 5: loss 1.8344, time 7514.98ms, mfu 0.23%


Training step 6: 100%|██████████| 40/40 [00:07<00:00,  5.57it/s]


iter 6: loss 1.7532, time 7375.65ms, mfu 0.23%


Training step 7: 100%|██████████| 40/40 [00:07<00:00,  5.49it/s]


iter 7: loss 1.6792, time 7466.66ms, mfu 0.23%


Training step 8: 100%|██████████| 40/40 [00:07<00:00,  5.53it/s]


iter 8: loss 1.6144, time 7420.53ms, mfu 0.23%


Training step 9: 100%|██████████| 40/40 [00:07<00:00,  5.43it/s]


iter 9: loss 1.5575, time 7545.67ms, mfu 0.23%


Training step 10: 100%|██████████| 40/40 [00:07<00:00,  5.22it/s]


iter 10: loss 1.5062, time 7863.78ms, mfu 0.23%


Training step 11: 100%|██████████| 40/40 [00:07<00:00,  5.33it/s]


iter 11: loss 1.4590, time 7680.80ms, mfu 0.23%


Training step 12: 100%|██████████| 40/40 [00:07<00:00,  5.41it/s]


iter 12: loss 1.4162, time 7572.87ms, mfu 0.23%


Training step 13: 100%|██████████| 40/40 [00:07<00:00,  5.20it/s]


iter 13: loss 1.3782, time 7892.51ms, mfu 0.23%


Training step 14: 100%|██████████| 40/40 [00:07<00:00,  5.41it/s]


iter 14: loss 1.3433, time 7577.11ms, mfu 0.23%


Training step 15: 100%|██████████| 40/40 [00:07<00:00,  5.52it/s]


iter 15: loss 1.3095, time 7428.24ms, mfu 0.23%


Training step 16: 100%|██████████| 40/40 [00:07<00:00,  5.51it/s]


iter 16: loss 1.2762, time 7443.84ms, mfu 0.23%


Training step 17: 100%|██████████| 40/40 [00:07<00:00,  5.47it/s]


iter 17: loss 1.2419, time 7499.61ms, mfu 0.23%


Training step 18: 100%|██████████| 40/40 [00:07<00:00,  5.48it/s]


iter 18: loss 1.2038, time 7477.60ms, mfu 0.23%


Training step 19: 100%|██████████| 40/40 [00:07<00:00,  5.46it/s]


iter 19: loss 1.1608, time 7506.00ms, mfu 0.23%


Training step 20: 100%|██████████| 40/40 [00:07<00:00,  5.37it/s]


iter 20: loss 1.1105, time 7619.94ms, mfu 0.23%


Training step 21: 100%|██████████| 40/40 [00:07<00:00,  5.36it/s]


iter 21: loss 1.0525, time 7612.05ms, mfu 0.23%


Training step 22: 100%|██████████| 40/40 [00:07<00:00,  5.24it/s]


iter 22: loss 0.9858, time 7787.98ms, mfu 0.23%


Training step 23: 100%|██████████| 40/40 [00:07<00:00,  5.20it/s]


iter 23: loss 0.9110, time 7855.57ms, mfu 0.23%


Training step 24: 100%|██████████| 40/40 [00:07<00:00,  5.18it/s]


iter 24: loss 0.8302, time 7884.18ms, mfu 0.23%


Training step 25: 100%|██████████| 40/40 [00:07<00:00,  5.16it/s]


iter 25: loss 0.7505, time 7919.98ms, mfu 0.23%


Training step 26: 100%|██████████| 40/40 [00:07<00:00,  5.13it/s]


iter 26: loss 0.6756, time 7973.80ms, mfu 0.23%


Training step 27: 100%|██████████| 40/40 [00:07<00:00,  5.10it/s]


iter 27: loss 0.6044, time 8047.79ms, mfu 0.23%


Training step 28: 100%|██████████| 40/40 [00:07<00:00,  5.12it/s]


iter 28: loss 0.5374, time 7942.95ms, mfu 0.23%


Training step 29: 100%|██████████| 40/40 [00:07<00:00,  5.10it/s]


iter 29: loss 0.4745, time 8038.33ms, mfu 0.23%


Training step 30: 100%|██████████| 40/40 [00:07<00:00,  5.09it/s]


iter 30: loss 0.4175, time 8028.50ms, mfu 0.23%


Training step 31: 100%|██████████| 40/40 [00:07<00:00,  5.11it/s]


iter 31: loss 0.3619, time 8005.61ms, mfu 0.22%


Training step 32: 100%|██████████| 40/40 [00:07<00:00,  5.04it/s]


iter 32: loss 0.3094, time 8108.67ms, mfu 0.22%


Training step 33: 100%|██████████| 40/40 [00:07<00:00,  5.08it/s]


iter 33: loss 0.2615, time 8058.33ms, mfu 0.22%


Training step 34: 100%|██████████| 40/40 [00:07<00:00,  5.07it/s]


iter 34: loss 0.2189, time 8069.16ms, mfu 0.22%


Training step 35: 100%|██████████| 40/40 [00:08<00:00,  4.84it/s]


iter 35: loss 0.1823, time 8435.42ms, mfu 0.22%


Training step 36: 100%|██████████| 40/40 [00:07<00:00,  5.04it/s]


iter 36: loss 0.1529, time 8104.90ms, mfu 0.22%


Training step 37: 100%|██████████| 40/40 [00:08<00:00,  4.99it/s]


iter 37: loss 0.1270, time 8207.21ms, mfu 0.22%


Training step 38: 100%|██████████| 40/40 [00:08<00:00,  4.98it/s]


iter 38: loss 0.1056, time 8203.18ms, mfu 0.22%


Training step 39: 100%|██████████| 40/40 [00:07<00:00,  5.05it/s]


iter 39: loss 0.0877, time 8092.11ms, mfu 0.22%


Training step 40: 100%|██████████| 40/40 [00:08<00:00,  4.97it/s]


iter 40: loss 0.0721, time 8242.59ms, mfu 0.22%


Training step 41: 100%|██████████| 40/40 [00:07<00:00,  5.02it/s]


iter 41: loss 0.0603, time 8140.46ms, mfu 0.22%


Training step 42: 100%|██████████| 40/40 [00:08<00:00,  4.99it/s]


iter 42: loss 0.0509, time 8211.24ms, mfu 0.22%


Training step 43: 100%|██████████| 40/40 [00:07<00:00,  5.06it/s]


iter 43: loss 0.0427, time 8104.03ms, mfu 0.22%


Training step 44: 100%|██████████| 40/40 [00:07<00:00,  5.02it/s]


iter 44: loss 0.0365, time 8162.46ms, mfu 0.22%


Training step 45: 100%|██████████| 40/40 [00:08<00:00,  4.96it/s]


iter 45: loss 0.0312, time 8259.21ms, mfu 0.22%


Training step 46: 100%|██████████| 40/40 [00:08<00:00,  4.95it/s]


iter 46: loss 0.0266, time 8293.51ms, mfu 0.22%


Training step 47: 100%|██████████| 40/40 [00:08<00:00,  4.82it/s]


iter 47: loss 0.0232, time 8506.63ms, mfu 0.22%


Training step 48: 100%|██████████| 40/40 [00:08<00:00,  4.81it/s]


iter 48: loss 0.0200, time 8523.97ms, mfu 0.21%


Training step 49: 100%|██████████| 40/40 [00:08<00:00,  4.68it/s]


iter 49: loss 0.0175, time 8755.38ms, mfu 0.21%


Training step 50: 100%|██████████| 40/40 [00:08<00:00,  4.70it/s]


iter 50: loss 0.0156, time 8718.53ms, mfu 0.21%


Training step 51: 100%|██████████| 40/40 [00:08<00:00,  4.62it/s]


iter 51: loss 0.0138, time 8885.66ms, mfu 0.21%


Training step 52: 100%|██████████| 40/40 [00:08<00:00,  4.52it/s]


iter 52: loss 0.0124, time 9072.02ms, mfu 0.21%


Training step 53: 100%|██████████| 40/40 [00:08<00:00,  4.60it/s]


iter 53: loss 0.0113, time 8939.51ms, mfu 0.21%


Training step 54: 100%|██████████| 40/40 [00:08<00:00,  4.48it/s]


iter 54: loss 0.0103, time 9149.96ms, mfu 0.21%


Training step 55: 100%|██████████| 40/40 [00:08<00:00,  4.54it/s]


iter 55: loss 0.0094, time 9023.60ms, mfu 0.21%


Training step 56: 100%|██████████| 40/40 [00:08<00:00,  4.46it/s]


iter 56: loss 0.0087, time 9183.68ms, mfu 0.20%


Training step 57: 100%|██████████| 40/40 [00:08<00:00,  4.54it/s]


iter 57: loss 0.0081, time 9049.70ms, mfu 0.20%


Training step 58: 100%|██████████| 40/40 [00:09<00:00,  4.34it/s]


iter 58: loss 0.0075, time 9438.60ms, mfu 0.20%


Training step 59: 100%|██████████| 40/40 [00:08<00:00,  4.61it/s]


iter 59: loss 0.0070, time 8905.91ms, mfu 0.20%


Training step 60: 100%|██████████| 40/40 [00:08<00:00,  4.52it/s]


iter 60: loss 0.0065, time 9076.98ms, mfu 0.20%


Training step 61: 100%|██████████| 40/40 [00:08<00:00,  4.53it/s]


iter 61: loss 0.0061, time 9048.17ms, mfu 0.20%


Training step 62: 100%|██████████| 40/40 [00:08<00:00,  4.63it/s]


iter 62: loss 0.0058, time 8863.46ms, mfu 0.20%


Training step 63: 100%|██████████| 40/40 [00:08<00:00,  4.57it/s]


iter 63: loss 0.0055, time 8970.57ms, mfu 0.20%


Training step 64: 100%|██████████| 40/40 [00:08<00:00,  4.65it/s]


iter 64: loss 0.0052, time 8819.27ms, mfu 0.20%


Training step 65: 100%|██████████| 40/40 [00:08<00:00,  4.62it/s]


iter 65: loss 0.0049, time 8867.86ms, mfu 0.20%


Training step 66: 100%|██████████| 40/40 [00:08<00:00,  4.68it/s]


iter 66: loss 0.0047, time 8775.40ms, mfu 0.20%


Training step 67: 100%|██████████| 40/40 [00:08<00:00,  4.65it/s]


iter 67: loss 0.0044, time 8896.07ms, mfu 0.20%


Training step 68: 100%|██████████| 40/40 [00:08<00:00,  4.72it/s]


iter 68: loss 0.0042, time 8775.35ms, mfu 0.20%


Training step 69: 100%|██████████| 40/40 [00:08<00:00,  4.62it/s]


iter 69: loss 0.0040, time 8881.22ms, mfu 0.20%


Training step 70: 100%|██████████| 40/40 [00:08<00:00,  4.63it/s]


iter 70: loss 0.0038, time 8857.56ms, mfu 0.20%


Training step 71: 100%|██████████| 40/40 [00:08<00:00,  4.56it/s]


iter 71: loss 0.0037, time 9014.34ms, mfu 0.20%


Training step 72: 100%|██████████| 40/40 [00:09<00:00,  4.37it/s]


iter 72: loss 0.0035, time 9378.52ms, mfu 0.20%


Training step 73: 100%|██████████| 40/40 [00:08<00:00,  4.61it/s]


iter 73: loss 0.0034, time 8902.64ms, mfu 0.20%


Training step 74: 100%|██████████| 40/40 [00:08<00:00,  4.55it/s]


iter 74: loss 0.0033, time 9008.57ms, mfu 0.20%


Training step 75: 100%|██████████| 40/40 [00:08<00:00,  4.56it/s]


iter 75: loss 0.0031, time 9003.13ms, mfu 0.20%


Training step 76: 100%|██████████| 40/40 [00:08<00:00,  4.53it/s]


iter 76: loss 0.0030, time 9048.31ms, mfu 0.20%


Training step 77: 100%|██████████| 40/40 [00:08<00:00,  4.58it/s]


iter 77: loss 0.0029, time 8952.70ms, mfu 0.20%


Training step 78: 100%|██████████| 40/40 [00:09<00:00,  4.34it/s]


iter 78: loss 0.0028, time 9438.60ms, mfu 0.20%


Training step 79: 100%|██████████| 40/40 [00:08<00:00,  4.51it/s]


iter 79: loss 0.0027, time 9103.43ms, mfu 0.20%


Training step 80: 100%|██████████| 40/40 [00:09<00:00,  4.27it/s]


iter 80: loss 0.0026, time 9611.62ms, mfu 0.19%


Training step 81: 100%|██████████| 40/40 [00:08<00:00,  4.50it/s]


iter 81: loss 0.0026, time 9180.91ms, mfu 0.19%


Training step 82: 100%|██████████| 40/40 [00:09<00:00,  4.29it/s]


iter 82: loss 0.0025, time 9557.66ms, mfu 0.19%


Training step 83: 100%|██████████| 40/40 [00:09<00:00,  4.42it/s]


iter 83: loss 0.0024, time 9329.30ms, mfu 0.19%


Training step 84: 100%|██████████| 40/40 [00:09<00:00,  4.34it/s]


iter 84: loss 0.0023, time 9443.55ms, mfu 0.19%


Training step 85: 100%|██████████| 40/40 [00:09<00:00,  4.25it/s]


iter 85: loss 0.0023, time 9631.82ms, mfu 0.19%


Training step 86: 100%|██████████| 40/40 [00:09<00:00,  4.35it/s]


iter 86: loss 0.0022, time 9421.13ms, mfu 0.19%


Training step 87: 100%|██████████| 40/40 [00:09<00:00,  4.25it/s]


iter 87: loss 0.0021, time 9647.76ms, mfu 0.19%


Training step 88: 100%|██████████| 40/40 [00:10<00:00,  3.74it/s]


iter 88: loss 0.0021, time 10939.10ms, mfu 0.19%


Training step 89: 100%|██████████| 40/40 [00:10<00:00,  3.98it/s]


iter 89: loss 0.0020, time 10314.47ms, mfu 0.19%


Training step 90: 100%|██████████| 40/40 [00:09<00:00,  4.12it/s]


iter 90: loss 0.0020, time 9946.12ms, mfu 0.18%


Training step 91: 100%|██████████| 40/40 [00:10<00:00,  3.90it/s]


iter 91: loss 0.0019, time 10509.44ms, mfu 0.18%


Training step 92: 100%|██████████| 40/40 [00:09<00:00,  4.02it/s]


iter 92: loss 0.0019, time 10225.55ms, mfu 0.18%


Training step 93: 100%|██████████| 40/40 [00:10<00:00,  3.67it/s]


iter 93: loss 0.0018, time 11206.19ms, mfu 0.18%


Training step 94: 100%|██████████| 40/40 [00:12<00:00,  3.21it/s]


iter 94: loss 0.0018, time 12804.55ms, mfu 0.17%


Training step 95: 100%|██████████| 40/40 [00:12<00:00,  3.26it/s]


iter 95: loss 0.0017, time 12545.42ms, mfu 0.17%


Training step 96: 100%|██████████| 40/40 [00:11<00:00,  3.55it/s]


iter 96: loss 0.0017, time 11539.59ms, mfu 0.17%


Training step 97: 100%|██████████| 40/40 [00:11<00:00,  3.58it/s]


iter 97: loss 0.0017, time 11446.81ms, mfu 0.17%


Training step 98: 100%|██████████| 40/40 [00:11<00:00,  3.61it/s]


iter 98: loss 0.0016, time 11319.44ms, mfu 0.17%


Training step 99: 100%|██████████| 40/40 [00:09<00:00,  4.04it/s]


iter 99: loss 0.0016, time 10171.25ms, mfu 0.17%


Training step 100: 100%|██████████| 40/40 [00:09<00:00,  4.03it/s]


iter 100: loss 0.0015, time 10171.70ms, mfu 0.17%


Training step 101:   8%|▊         | 3/40 [00:00<00:09,  4.10it/s]


KeyboardInterrupt: 