In [2]:
from enum import Enum
from typing import Mapping, Optional, Sequence

import torch
from torch.utils.data import IterableDataset, Sampler
from torch.utils.data.dataset import Dataset
from transformers import AutoTokenizer



In [3]:
# Shared tokenizer for the whole notebook; its vocab_size also sizes the model's embedding/output layers.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

In [5]:
import math
from typing import Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer, functional as F
from torch.nn.init import xavier_uniform_, constant_, xavier_normal_
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from torch.nn.parameter import Parameter


class TransformerModel(nn.Module):
    """Language model built from ``nn.TransformerEncoder`` layers.

    Token ids are embedded, scaled by ``sqrt(d_model)``, combined with sinusoidal
    positional encodings, run through ``nlayers`` encoder layers and projected back to
    vocabulary logits. Inputs are sequence-first (the encoder layers use the default
    ``batch_first=False``). Causality comes only from the mask the caller passes.

    Args:
        ntoken: vocabulary size (embedding rows and output logits).
        d_model: embedding / hidden width.
        nhead: attention heads per encoder layer.
        d_hid: width of each layer's feed-forward sublayer.
        nlayers: number of stacked encoder layers.
        dropout: dropout probability for the positional encoding and encoder layers.
    """

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = "Transformer"
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        """Initialise embedding and output weights uniformly in [-0.1, 0.1] and zero the output bias."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            src: Tensor of token ids, shape [seq_len, batch_size]
            src_mask: optional attention mask, shape [seq_len, seq_len] (e.g. from
                ``generate_square_subsequent_mask``) or [batch_size * nhead, seq_len, seq_len].
                ``None`` means no mask, i.e. fully bidirectional attention; pass a
                causal mask for next-token training.
            src_key_padding_mask: optional bool Tensor, shape [batch_size, seq_len];
                True marks key positions (e.g. padding) that must not be attended to.

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(
            src, mask=src_mask, src_key_padding_mask=src_key_padding_mask
        )
        return self.decoder(output)


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to sequence-first embeddings, then applies dropout.

    Even channels get ``sin(pos * w_i)`` and odd channels ``cos(pos * w_i)``, where
    ``w_i = exp(-2i * ln(10000) / d_model)``.

    Args:
        d_model: embedding width. Odd widths are supported.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd channel than frequencies, so the
        # frequency vector is trimmed; for even d_model this slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim] with seq_len <= max_len

        Returns:
            Tensor of the same shape: ``dropout(x + pe[:seq_len])``.
        """
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)

def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Build an additive causal attention mask of shape ``(sz, sz)``.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i``, so position ``i`` may attend to ``j``.
    It is ``-inf`` when ``j > i``, which blocks future positions.
    """
    blocked = torch.full((sz, sz), float("-inf"))
    return blocked.triu(diagonal=1)


def get_gpt2_model(device = 'cpu') -> torch.nn.Module:
    """Build a freshly initialised ``TransformerModel`` for the benchmarks and move it to ``device``.

    Despite the name this does not load GPT-2 weights. The vocabulary size comes from
    the module-level ``tokenizer``. The architecture is width 768, 8 heads,
    feed-forward width 1024, 8 layers and dropout 0.1.
    """
    hparams = dict(
        ntoken=tokenizer.vocab_size,
        d_model=768,
        nhead=8,
        d_hid=1024,
        nlayers=8,
        dropout=0.1,
    )
    model = TransformerModel(**hparams)
    return model.to(device)

In [6]:
def process_wikitext_file(filepath):
    """Read a UTF-8 WikiText file and split it into article chunks.

    The text is split on a newline followed by ``" = "``. That marker is removed, so
    every chunk after the first starts right after it.
    """
    with open(filepath, mode='r', encoding='utf-8') as handle:
        text = handle.read()
    return text.split('\n = ')

def get_train_samples(data_path):
    """Read both WikiText-103 training shards under ``data_path`` and return all article chunks.

    Shards are read in order (``00000`` then ``00001``) and their chunk lists are concatenated.
    """
    shard_names = ("train-00000-of-00002.txt", "train-00001-of-00002.txt")
    articles = []
    for shard_name in shard_names:
        articles.extend(process_wikitext_file(f"{data_path}/{shard_name}"))
    return articles

# Directory (relative to the notebook) holding the raw WikiText-103 training shards.
data_path = 'wikitext-103-raw-v1'
articles = get_train_samples(data_path=data_path)

In [7]:
# No truncation is requested: samples longer than the tokenizer's 512-token limit are kept
# in full (hence the warning below) and are clipped later by the datasets' max_length.
tokenized_samples = tokenizer(text=articles, return_tensors=None)

Token indices sequence length is longer than the specified maximum sequence length for this model (643 > 512). Running this sequence through the model will result in indexing errors


In [8]:
MAX_LENGTH = 640


class BrainDataset(Dataset):
    """Pads or truncates every tokenized sample to a fixed ``max_length`` up front.

    Each item is a ``torch.long`` tensor of shape ``[max_length]``. Samples shorter than
    ``max_length`` are right-padded with token id ``0``.

    Args:
        tokenized_samples: mapping with an ``'input_ids'`` entry holding one list of token
            ids per sample (e.g. the tokenizer output).
        max_length: fixed length of every returned sample.
    """

    def __init__(self, tokenized_samples: Mapping[str, Sequence[Sequence[int]]], max_length: int = MAX_LENGTH):
        tokens = tokenized_samples['input_ids']

        padded_sequences = []
        for seq in tokens:
            seq = list(seq[:max_length])
            seq.extend([0] * (max_length - len(seq)))
            padded_sequences.append(seq)
        # The explicit reshape keeps the [num_samples, max_length] layout even for zero samples,
        # where torch.tensor([]) would otherwise be 1-D.
        self.samples = torch.tensor(padded_sequences, dtype=torch.long).reshape(
            len(padded_sequences), max_length
        )

    def __getitem__(self, idx: int):
        assert idx >= 0 and idx < len(self.samples)
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)


In [9]:
class BigBrainDataset(Dataset):
    """Stores samples in a zero-padded ``[num_samples, max_length]`` tensor together with true lengths.

    ``__getitem__`` returns ``(padded_tokens, length)``, so a collate function can trim each
    batch to its longest member instead of always using ``max_length``.
    """

    def __init__(self, tokenized_samples, max_length: int = MAX_LENGTH):
        self.max_length = max_length
        all_ids = tokenized_samples['input_ids']
        count = len(all_ids)

        # Zero-filled storage, so padding id 0 is written implicitly.
        self.samples_tensor = torch.full((count, max_length), 0, dtype=torch.long)
        self.lengths = torch.zeros(count, dtype=torch.long)

        for row, ids in enumerate(all_ids):
            kept = ids[:max_length]
            n_kept = len(kept)
            self.lengths[row] = n_kept
            self.samples_tensor[row, :n_kept] = torch.tensor(kept, dtype=torch.long)

    def __len__(self):
        return len(self.samples_tensor)

    def __getitem__(self, idx: int):
        return self.samples_tensor[idx], self.lengths[idx]

def collate_fn(batch):
    """Stack ``(padded_tokens, length)`` pairs and cut the columns at the longest length in the batch."""
    longest = max(length for _, length in batch)
    return torch.stack([tokens for tokens, _ in batch])[:, :longest]

In [10]:
import torch
from torch.utils.data import Dataset, Sampler
from typing import List
import random
from collections import defaultdict
from functools import partial

MAX_LENGTH = 640

def pad_and_truncate_tokens(token_lists, max_length: int = MAX_LENGTH):
    """Return new lists, each clipped to ``max_length`` and right-padded with ``0`` up to it.

    The inputs are never mutated, because slicing makes a copy before any padding is added.
    """
    def _fit(seq):
        clipped = seq[:max_length]
        shortfall = max_length - len(clipped)
        return clipped + [0] * shortfall if shortfall > 0 else clipped

    return [_fit(seq) for seq in token_lists]

class UltraBigBrainDataset(Dataset):
    """Fixed-length padded samples plus their true lengths and a length histogram.

    ``__getitem__`` returns ``(padded_tensor, real_length)``. ``self.bins`` maps a bin id
    to the (shuffled) indices of samples whose length falls in one of ``n_bins``
    equal-width ranges between the shortest and longest sample.
    ``UltraBigBrainBatchSampler`` groups by ``real_lengths`` directly and never reads
    ``bins``, so ``n_bins`` does not affect batching there.

    Args:
        tokenized_samples: mapping with an ``'input_ids'`` list of token-id lists.
        max_length: length every sample is clipped / padded to.
        n_bins: number of equal-width length bins; must be >= 1.

    Raises:
        ValueError: if ``n_bins < 1``. Previously this was a ZeroDivisionError for 0 and
            silently negative bin ids for negative values.
    """

    def __init__(self, tokenized_samples, max_length: int = MAX_LENGTH, n_bins: int = 10):
        if n_bins < 1:
            raise ValueError(f"n_bins must be >= 1, got {n_bins}")
        self.max_length = max_length
        input_ids = tokenized_samples['input_ids']
        self.real_lengths = [min(max_length, len(seq)) for seq in input_ids]
        self.samples = pad_and_truncate_tokens(input_ids, max_length)
        self.n_bins = n_bins
        self.bins = self._create_bins()
        for bin_id in self.bins:
            random.shuffle(self.bins[bin_id])

    def _create_bins(self):
        """Return a plain ``dict`` mapping bin id -> sample indices, with only non-empty bins present."""
        if not self.real_lengths:
            return {}
        min_len = min(self.real_lengths)
        max_len = max(self.real_lengths)
        if min_len == max_len or self.n_bins == 1:
            return {0: list(range(len(self.samples)))}
        bin_size = (max_len - min_len) / self.n_bins
        bins = defaultdict(list)
        for idx, length in enumerate(self.real_lengths):
            # The longest sample sits exactly on the upper edge, so clamp it into the last bin.
            bin_id = min(self.n_bins - 1, int((length - min_len) // bin_size))
            bins[bin_id].append(idx)
        # Entries are only created by append, so every bin here is non-empty.
        return dict(bins)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return torch.tensor(self.samples[idx], dtype=torch.long), self.real_lengths[idx]

class UltraBigBrainBatchSampler(Sampler):
    """Batch sampler that groups samples of similar real length to reduce padding.

    Batches are built once, at construction. Each batch starts from the shortest length
    still listed and is filled with samples whose length is within ``k`` of it, taking
    shorter lengths first and each length's indices in shuffled order. A batch can be
    smaller than ``batch_size`` when its length window runs out. ``__iter__`` yields the
    same batches every epoch, in a freshly shuffled order.

    Args:
        dataset: object exposing ``real_lengths`` (one int per sample), e.g. ``UltraBigBrainDataset``.
        batch_size: maximum number of samples per batch; must be >= 1.
        k: maximum length difference from the batch's starting length; must be >= 0.

    Raises:
        ValueError: if ``batch_size < 1`` or ``k < 0``. Either value makes batch
            construction loop forever.
    """

    def __init__(self, dataset: "UltraBigBrainDataset", batch_size: int, k: int = 10):
        # With batch_size < 1 no index is ever taken; with k < 0 the window is always empty.
        # In both cases sorted_lengths never shrinks and _create_batches would never return.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if k < 0:
            raise ValueError(f"k must be >= 0, got {k}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.k = k

        # real length -> indices of the samples having that length
        self.length_to_indices = defaultdict(list)
        for idx, length in enumerate(dataset.real_lengths):
            self.length_to_indices[length].append(idx)

        self.sorted_lengths = sorted(self.length_to_indices.keys())
        for length in self.sorted_lengths:
            random.shuffle(self.length_to_indices[length])

        self.batches = self._create_batches()

    def _create_batches(self):
        """Greedily cut all indices into batches, anchoring each batch at the shortest remaining length."""
        batches = []
        # Work on copies so the shuffled per-length lists on self stay intact.
        length_to_indices = {k: v.copy() for k, v in self.length_to_indices.items()}
        sorted_lengths = self.sorted_lengths.copy()

        while sorted_lengths:
            batch = []
            start_length = sorted_lengths[0]
            available_lengths = [
                length for length in sorted_lengths 
                if abs(length - start_length) <= self.k
            ]
            
            while len(batch) < self.batch_size and available_lengths:
                current_length = available_lengths[0]
                if not length_to_indices[current_length]:
                    # Bucket was emptied by the final pop of a previous full batch; retire it now.
                    available_lengths.pop(0)
                    if current_length in sorted_lengths:
                        sorted_lengths.remove(current_length)
                    continue
                
                batch.append(length_to_indices[current_length].pop())
                if len(batch) == self.batch_size:
                    # Leaves an exhausted bucket listed; it is retired on the next pass (above)
                    # and may still anchor that pass's window.
                    break
                
                if not length_to_indices[current_length]:
                    available_lengths.pop(0)
                    if current_length in sorted_lengths:
                        sorted_lengths.remove(current_length)
            
            if batch:
                batches.append(batch)
        
        return batches

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        """Yield the precomputed batches (lists of sample indices) in a new random order."""
        batches = self.batches.copy()
        random.shuffle(batches)
        yield from batches

def ultra_brain_collate_fn(batch_items):
    """Stack ``(padded_tensor, real_length)`` items, trimming every row to the batch's longest real length."""
    longest = max(length for _, length in batch_items)
    return torch.stack([padded[:longest] for padded, _ in batch_items])

In [None]:
import torch
from torch.utils.data import Dataset
import bisect

class UltraDuperBigBrainDataset(Dataset):
    """Packs several tokenized samples into each fixed-length row to avoid padding waste.

    Each item is a dict with:
      * ``"input_ids"``: ``torch.long`` tensor ``[max_length]``, right-padded with ``0``;
      * ``"seq_ids"``: ``torch.long`` tensor ``[max_length]`` holding, per position, the
        index of the source sample it came from (``-1`` for padding). The training loop
        uses this to keep attention inside one sample.

    Packing strategies (``packing_type``):
      * ``"basic"``: concatenate all samples into one stream and cut it into ``max_length``
        chunks. A sample may be split across rows.
      * ``"ffd"``: First-Fit Decreasing. Samples are sorted longest first and each goes
        into the lowest-index row with enough free space.
      * ``"obfd"``: Best-Fit Decreasing. Samples are sorted longest first and each goes
        into the row with the smallest sufficient free space (ties go to the lowest index).

    Samples are truncated to ``max_length`` first, so each one fits in an empty row.

    Raises:
        ValueError: for an unknown ``packing_type``.
    """

    def __init__(self, tokenized_samples, packing_type, max_length=MAX_LENGTH):
        samples = [s[:max_length] for s in tokenized_samples["input_ids"]]

        if packing_type == "basic":
            packed_samples, packed_seq_ids = self._basic_pack(samples, max_length)
        elif packing_type == "ffd":
            packed_samples, packed_seq_ids = self._ffd_pack(samples, max_length)
        elif packing_type == "obfd":
            packed_samples, packed_seq_ids = self._obfd_pack(samples, max_length)
        else:
            raise ValueError(packing_type)

        if packed_samples:
            self.input_ids = torch.cat(packed_samples, dim=0)
            self.seq_ids = torch.cat(packed_seq_ids, dim=0)
        else:
            # torch.cat rejects an empty list; keep the [num_rows, max_length] layout instead.
            self.input_ids = torch.empty((0, max_length), dtype=torch.long)
            self.seq_ids = torch.empty((0, max_length), dtype=torch.long)

    def _basic_pack(self, samples, max_length):
        """Stream all tokens in sample order, emitting a row every ``max_length`` tokens."""
        inputs, seqs = [], []
        buf, buf_ids = [], []
        sid = 0

        for s in samples:
            for t in s:
                buf.append(t)
                buf_ids.append(sid)
                if len(buf) == max_length:
                    inputs.append(self._make_input(buf, max_length))
                    seqs.append(self._make_seq_ids(buf_ids, max_length))
                    buf, buf_ids = [], []
            sid += 1

        # Flush the partially filled last row (padded by the _make_* helpers).
        if buf:
            inputs.append(self._make_input(buf, max_length))
            seqs.append(self._make_seq_ids(buf_ids, max_length))

        return inputs, seqs

    def _ffd_pack(self, samples, max_length):
        """First-Fit Decreasing: place each sample, longest first, in the lowest-index row that fits.

        A max segment tree over row free-capacities finds that row in O(log n). Unopened
        rows have full capacity ``max_length`` and lie to the right of every opened row,
        so the leftmost fitting leaf is either an existing row or exactly the next new one.
        """
        # Samples are pre-truncated in __init__; the filter only matters for direct calls.
        sequences = [(i, s) for i, s in enumerate(samples) if len(s) <= max_length]
        sequences.sort(key=lambda x: len(x[1]), reverse=True)
        if not sequences:
            return [], []

        size = 1
        while size < len(sequences):  # at most one new row per sample
            size *= 2
        # tree[node] = largest free capacity among the rows below `node`; leaves start at `size`.
        tree = [max_length] * (2 * size)
        bins = []

        for sid, seq in sequences:
            n_tok = len(seq)
            # Descend towards the leftmost leaf whose capacity is >= n_tok. The root always
            # qualifies because an unopened row (capacity max_length) is still available.
            node = 1
            while node < size:
                left = 2 * node
                node = left if tree[left] >= n_tok else left + 1
            bidx = node - size
            if bidx == len(bins):
                bins.append({"tok": [], "ids": []})
            b = bins[bidx]
            b["tok"].extend(seq)
            b["ids"].extend([sid] * n_tok)

            # Update the leaf's capacity and propagate the new maxima up to the root.
            tree[node] -= n_tok
            node //= 2
            while node:
                tree[node] = max(tree[2 * node], tree[2 * node + 1])
                node //= 2

        return self._bins_to_tensors(bins, max_length)

    def _obfd_pack(self, samples, max_length):
        """Best-Fit Decreasing: place each sample, longest first, in the tightest row that fits.

        ``free`` is kept sorted by ``(remaining_capacity, row_index)``. ``bisect_left`` with
        ``(n_tok, -1)`` finds the first row whose remaining capacity is >= ``n_tok``.
        """
        sequences = [(i, s) for i, s in enumerate(samples) if len(s) <= max_length]
        sequences.sort(key=lambda x: len(x[1]), reverse=True)

        bins = []
        free = []

        for sid, seq in sequences:
            n_tok = len(seq)
            idx = bisect.bisect_left(free, (n_tok, -1))

            if idx < len(free):
                _, bidx = free.pop(idx)
                b = bins[bidx]
                b["tok"].extend(seq)
                b["ids"].extend([sid] * n_tok)
                b["rem"] -= n_tok
                bisect.insort(free, (b["rem"], bidx))
            else:
                bidx = len(bins)
                bins.append({
                    "tok": list(seq),
                    "ids": [sid] * n_tok,
                    "rem": max_length - n_tok,
                })
                bisect.insort(free, (max_length - n_tok, bidx))

        return self._bins_to_tensors(bins, max_length)

    def _bins_to_tensors(self, bins, max_length):
        """Convert packed rows (dicts with "tok"/"ids" lists) into lists of padded [1, max_length] tensors."""
        inputs = [self._make_input(b["tok"], max_length) for b in bins]
        seqs = [self._make_seq_ids(b["ids"], max_length) for b in bins]
        return inputs, seqs

    def _make_input(self, tokens, max_length):
        """Right-pad token ids with 0 to ``max_length``; returns shape [1, max_length]."""
        return torch.tensor(
            tokens + [0] * (max_length - len(tokens)),
            dtype=torch.long
        ).unsqueeze(0)

    def _make_seq_ids(self, seq_ids, max_length):
        """Right-pad sample ids with -1 (padding marker) to ``max_length``; returns shape [1, max_length]."""
        return torch.tensor(
            seq_ids + [-1] * (max_length - len(seq_ids)),
            dtype=torch.long
        ).unsqueeze(0)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "seq_ids": self.seq_ids[idx],
        }

    def __len__(self):
        return self.input_ids.size(0)


In [12]:
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.benchmark import Timer
from torch.amp import autocast, GradScaler

class DataMode(Enum):
    """Selects which dataset / batching strategy ``get_dataloader`` builds."""
    BRAIN = 1  # every sample padded/truncated to MAX_LENGTH (BrainDataset)
    BIG_BRAIN = 2  # per-batch trimming to the longest sample (BigBrainDataset + collate_fn)
    ULTRA_BIG_BRAIN = 3  # length-grouped batches (UltraBigBrainDataset + UltraBigBrainBatchSampler)
    ULTRA_DUPER_BIG_BRAIN = 4  # several samples packed per row (UltraDuperBigBrainDataset)

BATCH_SIZE = 64
def get_dataloader(
    data_mode: DataMode,
    n_bins: int | None = None,
    k: int | None = None,
    packing_type: str | None = None,
    batch_size: int = BATCH_SIZE,
    num_workers: int = 32,
):
    """Build the DataLoader for one benchmark mode from the global ``tokenized_samples``.

    Args:
        data_mode: which dataset / batching strategy to use.
        n_bins: number of length bins for ULTRA_BIG_BRAIN (defaults to 5).
        k: maximum length spread per batch for ULTRA_BIG_BRAIN (defaults to 5).
        packing_type: "basic", "ffd" or "obfd" for ULTRA_DUPER_BIG_BRAIN.
        batch_size: samples (or packed rows) per batch.
        num_workers: DataLoader worker processes.

    Raises:
        ValueError: if ``data_mode`` is not a supported ``DataMode``. Previously the
            function silently returned ``None``.
    """
    loader_kwargs = dict(num_workers=num_workers, pin_memory=True)

    if data_mode == DataMode.BRAIN:
        dataset = BrainDataset(tokenized_samples)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)
    if data_mode == DataMode.BIG_BRAIN:
        dataset = BigBrainDataset(tokenized_samples)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, **loader_kwargs)
    if data_mode == DataMode.ULTRA_BIG_BRAIN:
        dataset = UltraBigBrainDataset(tokenized_samples, n_bins=5 if n_bins is None else n_bins)
        sampler = UltraBigBrainBatchSampler(
            dataset=dataset,
            batch_size=batch_size,
            k=5 if k is None else k,
        )
        # The batch sampler owns batching, so batch_size/shuffle must not be passed here.
        return DataLoader(dataset, batch_sampler=sampler, collate_fn=ultra_brain_collate_fn, **loader_kwargs)
    if data_mode == DataMode.ULTRA_DUPER_BIG_BRAIN:
        dataset = UltraDuperBigBrainDataset(tokenized_samples, packing_type)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)
    raise ValueError(f"unsupported data_mode: {data_mode!r}")

In [13]:
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    """Return a ``(sz, sz)`` float causal mask: ``-inf`` strictly above the diagonal, ``0.0`` elsewhere.

    This re-defines a helper with the same name and behaviour that appears earlier in the notebook.
    """
    return torch.empty(sz, sz).fill_(float("-inf")).triu_(diagonal=1)

In [None]:
import torch
import time
from tqdm import tqdm
import numpy as np

def run_epoch(data_mode: DataMode, model, optimizer, criterion, device, 
              warmup_batches=3, n_bins: int | None = None, k: int | None = None):
    """Train ``model`` for one epoch on the ``data_mode`` dataloader and report timing statistics.

    Each batch of token ids ([batch, seq]) is shifted by one position to form next-token
    targets. It is fed sequence-first with a causal mask. On CUDA the forward pass runs
    under bfloat16 autocast with a ``GradScaler``; on other devices it runs the same
    computation in full precision. The first ``warmup_batches`` batches are trained on
    but excluded from timing and throughput.

    Padding id 0 counts toward the loss unless ``criterion`` ignores it
    (e.g. ``CrossEntropyLoss(ignore_index=0)``).

    Args:
        data_mode: which dataloader to build (see ``get_dataloader``).
        model: module called as ``model(src, mask)`` returning [seq, batch, vocab] logits.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking flattened logits [N, vocab] and targets [N].
        device: ``torch.device`` to train on.
        warmup_batches: number of leading batches excluded from timing.
        n_bins, k: forwarded to ``get_dataloader`` (ULTRA_BIG_BRAIN only).

    Returns:
        dict with ``avg_loss``, ``throughput`` (timed samples per timed second),
        ``total_samples``, ``avg_batch_size``, ``avg_seq_len`` and ``time_stats``.
        Numeric entries are 0 when no batch was timed.
    """
    dataloader = get_dataloader(data_mode, n_bins, k)
    model.train()
    model.to(device)

    use_cuda = device.type == 'cuda'
    scaler = torch.amp.GradScaler(device.type) if use_cuda else None
    batch_times = []
    total_loss = 0.0
    total_samples = 0
    timed_samples = 0  # samples in timed (post-warmup) batches; the numerator for throughput
    total_batch_size = 0
    total_seq_len = 0

    progress_bar = tqdm(dataloader, desc="Training")

    for batch_idx, batch in enumerate(progress_bar):
        is_warmup = batch_idx < warmup_batches

        if use_cuda:
            torch.cuda.synchronize()

        if not is_warmup:
            start_time = time.perf_counter()

        inputs = batch.to(device)

        batch_size = inputs.size(0)
        seq_len = inputs.size(1)
        total_batch_size += batch_size
        total_seq_len += batch_size * seq_len

        # Next-token prediction: position t predicts token t + 1.
        targets = inputs[:, 1:]
        inputs = inputs[:, :-1]
        src = inputs.transpose(0, 1)  # model is sequence-first: [seq, batch]
        tgt_y = targets.reshape(-1)
        mask = generate_square_subsequent_mask(src.size(0)).to(device)

        optimizer.zero_grad()

        if scaler is not None:
            with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
                out = model(src, mask)
                logits = out.transpose(0, 1).reshape(-1, out.size(-1))
                loss = criterion(logits, tgt_y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Same computation as the CUDA branch. The previous version called model(inputs)
            # without a mask, in batch-first layout, with unflattened logits, so it could not run.
            out = model(src, mask)
            logits = out.transpose(0, 1).reshape(-1, out.size(-1))
            loss = criterion(logits, tgt_y)
            loss.backward()
            optimizer.step()

        if use_cuda:
            torch.cuda.synchronize()

        if not is_warmup:
            batch_time = time.perf_counter() - start_time
            batch_times.append(batch_time)
            timed_samples += batch_size

        batch_loss = loss.item()
        total_loss += batch_loss * batch_size
        total_samples += batch_size

        if is_warmup:
            progress_bar.set_postfix(
                status="warmup", 
                loss=f"{batch_loss:.4f}",
                batch_size=batch_size,
                seq_len=seq_len
            )
        else:
            avg_loss = total_loss / total_samples
            avg_batch_time = sum(batch_times) / len(batch_times)

            avg_batch_size = total_batch_size / (batch_idx + 1)
            avg_seq_len = total_seq_len / total_batch_size

            progress_bar.set_postfix(
                loss=f"{batch_loss:.4f}",
                avg_loss=f"{avg_loss:.4f}",
                batch_time=f"{batch_time:.4f}s",
                avg_batch_time=f"{avg_batch_time:.4f}s",
                batch_size=batch_size,
                seq_len=seq_len,
                avg_batch_size=f"{avg_batch_size:.1f}",
                avg_seq_len=f"{avg_seq_len:.1f}"
            )

    time_stats = {}
    final_avg_loss = throughput = final_avg_batch_size = final_avg_seq_len = 0
    if batch_times:
        time_array = np.array(batch_times)
        time_stats = {
            'batch_time_min': float(np.min(time_array)),
            'batch_time_max': float(np.max(time_array)),
            'batch_time_mean': float(np.mean(time_array)),
            'batch_time_median': float(np.median(time_array)),
            'batch_time_std': float(np.std(time_array)),
            'num_batches': len(batch_times)
        }

        total_time = sum(batch_times)
        # Only timed batches contribute to both numerator and denominator.
        throughput = timed_samples / total_time if total_time > 0 else 0.0
        avg_time = total_time / len(batch_times)
        final_avg_loss = total_loss / total_samples

        final_avg_batch_size = total_batch_size / len(progress_bar)
        final_avg_seq_len = total_seq_len / total_batch_size

        print(f"Avg loss: {final_avg_loss:.4f}")
        print(f"Avg batch time: {avg_time:.4f}s")
        print(f"Throughput: {throughput:.2f} samples/s")
        print(f"Total samples: {total_samples}")
        print(f"Avg batch size: {final_avg_batch_size:.2f}")
        print(f"Avg sequence length: {final_avg_seq_len:.2f}")
        print(f"Batch time stats: min={time_stats['batch_time_min']:.4f}s, "
              f"max={time_stats['batch_time_max']:.4f}s, "
              f"mean={time_stats['batch_time_mean']:.4f}s, "
              f"median={time_stats['batch_time_median']:.4f}s")

    return {
        'avg_loss': final_avg_loss,
        'throughput': throughput,
        'total_samples': total_samples,
        'avg_batch_size': final_avg_batch_size,
        'avg_seq_len': final_avg_seq_len,
        'time_stats': time_stats,
    }

In [88]:
# Baseline: every sample padded to MAX_LENGTH. Plain CrossEntropyLoss also scores padding
# (id 0) targets, which likely makes this loss incomparable with the packed runs below.
device = torch.device('cuda')
model = get_gpt2_model(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

# run_epoch returns a stats dict (not raw batch times), despite the variable name.
batch_times_brain = run_epoch(DataMode.BRAIN, model, optimizer, criterion, device=device)

Training: 100%|██████████| 4774/4774 [15:07<00:00,  5.26it/s, avg_batch_size=64.0, avg_batch_time=0.1861s, avg_loss=2.7787, avg_seq_len=640.0, batch_size=26, batch_time=0.0858s, loss=2.5784, seq_len=640]

Avg loss: 2.7787
Avg batch time: 0.1861s
Throughput: 344.13 samples/s
Total samples: 305498
Avg batch size: 63.99
Avg sequence length: 640.00
Batch time stats: min=0.0858s, max=0.2001s, mean=0.1861s, median=0.1861s





In [16]:
import torch
import time
import warnings
import os
from tqdm import tqdm

# Suppress specific warnings
warnings.filterwarnings("ignore", message="enable_nested_tensor is True")
warnings.filterwarnings("ignore", message="This DataLoader will create")
warnings.filterwarnings("ignore", message="The current process just got forked")

# Fix tokenizers parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [89]:
# Dynamic padding: each batch is trimmed to its longest sample (capped at MAX_LENGTH).
device = torch.device('cuda')
model = get_gpt2_model(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

# Stats dict returned by run_epoch, despite the variable name.
batch_times_big_brain = run_epoch(DataMode.BIG_BRAIN, model, optimizer, criterion, device=device)

Training: 100%|██████████| 4774/4774 [15:07<00:00,  5.26it/s, avg_batch_size=64.0, avg_batch_time=0.1860s, avg_loss=2.7496, avg_seq_len=640.0, batch_size=26, batch_time=0.0850s, loss=2.3751, seq_len=640]

Avg loss: 2.7496
Avg batch time: 0.1860s
Throughput: 344.25 samples/s
Total samples: 305498
Avg batch size: 63.99
Avg sequence length: 640.00
Batch time stats: min=0.0850s, max=0.1935s, mean=0.1860s, median=0.1860s





In [31]:
import numpy as np
# Fraction of tokenized samples longer than MAX_LENGTH (640) tokens.
np.mean([len(ids) > 640 for ids in tokenized_samples['input_ids']])

0.16741844463793543

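About 16.7% of samples are longer than 640 tokens. A batch of 64 therefore almost always contains at least one sample truncated to the full 640 tokens: the chance of a batch with none is about $(1 - 0.167)^{64} \approx 10^{-5}$. As a result, per-batch dynamic padding (BigBrain) practically never shortens a batch, and we see no improvement in avg_batch_time over Brain.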

In [33]:
import json
# Persist the Brain run's stats dict for the comparison report.
with open('brain.json', mode='w') as fh:
    json.dump(obj=batch_times_brain, fp=fh)

In [35]:
# Persist the BigBrain run's stats dict for the comparison report.
with open('big_brain.json', mode='w') as fh:
    json.dump(obj=batch_times_big_brain, fp=fh)

In [65]:
# Length-grouped batching: 20 length bins, at most 10 tokens of length spread per batch.
device = torch.device('cuda')
model = get_gpt2_model(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

batch_times_ultra_big_brain_k10 = run_epoch(
    DataMode.ULTRA_BIG_BRAIN, model, optimizer, criterion,
    device=device, n_bins=20, k=10,
)

Training: 100%|██████████| 4774/4774 [08:17<00:00,  9.59it/s, avg_batch_size=64.0, avg_batch_time=0.1011s, avg_loss=5.2030, avg_seq_len=324.5, batch_size=64, batch_time=0.0767s, loss=4.9446, seq_len=246]


Avg loss: 5.2030
Avg batch time: 0.1011s
Throughput: 633.29 samples/s
Total samples: 305498
Avg batch size: 63.99
Avg sequence length: 324.53
Batch time stats: min=0.0214s, max=0.2078s, mean=0.1011s, median=0.0947s


In [67]:
# Length-grouped batching with the tightest window: at most 1 token of length spread per batch.
device = torch.device('cuda')
model = get_gpt2_model(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

batch_times_ultra_big_brain_k1 = run_epoch(
    DataMode.ULTRA_BIG_BRAIN, model, optimizer, criterion,
    device=device, n_bins=20, k=1,
)

Training: 100%|██████████| 4774/4774 [08:18<00:00,  9.58it/s, avg_batch_size=64.0, avg_batch_time=0.1011s, avg_loss=5.2091, avg_seq_len=324.5, batch_size=64, batch_time=0.1216s, loss=5.0844, seq_len=406] 


Avg loss: 5.2091
Avg batch time: 0.1011s
Throughput: 633.22 samples/s
Total samples: 305498
Avg batch size: 63.99
Avg sequence length: 324.53
Batch time stats: min=0.0219s, max=0.2013s, mean=0.1011s, median=0.0945s


In [68]:
# Length-grouped batching with a wide window: up to 50 tokens of length spread per batch.
device = torch.device('cuda')
model = get_gpt2_model(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

batch_times_ultra_big_brain_k50 = run_epoch(
    DataMode.ULTRA_BIG_BRAIN, model, optimizer, criterion,
    device=device, n_bins=20, k=50,
)

Training: 100%|██████████| 4774/4774 [08:18<00:00,  9.57it/s, avg_batch_size=64.0, avg_batch_time=0.1011s, avg_loss=5.1857, avg_seq_len=324.5, batch_size=64, batch_time=0.1042s, loss=4.9141, seq_len=347]


Avg loss: 5.1857
Avg batch time: 0.1011s
Throughput: 633.10 samples/s
Total samples: 305498
Avg batch size: 63.99
Avg sequence length: 324.53
Batch time stats: min=0.0219s, max=0.2019s, mean=0.1011s, median=0.0948s


In [None]:
def run_epoch_ultra_duper_big_brain(
    data_mode: DataMode,
    model,
    optimizer,
    criterion,
    device,
    warmup_batches=3,
    n_bins: int | None = None,
    k: int | None = None,
    packing_type: str | None = None,
):
    """Train one epoch on packed rows, masking attention so tokens only see their own sample.

    Each batch is a dict with ``input_ids`` and ``seq_ids`` (both [batch, max_length]).
    Attention is allowed only causally and between positions with the same ``seq_id``.
    The boolean mask (True = blocked) is repeated per head into shape
    ``[batch * nhead, L, L]``. The first ``warmup_batches`` batches are trained on but
    excluded from timing and throughput.

    Args:
        data_mode: dataloader mode (expected: ``DataMode.ULTRA_DUPER_BIG_BRAIN``).
        model: ``TransformerModel``-like module; its first encoder layer's
            ``self_attn.num_heads`` sets the mask's head repetition.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking flattened logits [N, vocab] and targets [N].
        device: ``torch.device`` to train on.
        warmup_batches: number of leading batches excluded from timing.
        n_bins, k, packing_type: forwarded to ``get_dataloader``.

    Returns:
        dict with ``avg_loss``, ``throughput`` (timed rows per timed second),
        ``total_samples``, ``avg_batch_size``, ``avg_seq_len`` and ``time_stats``.
        Numeric entries are 0 when no batch was timed.
    """
    dataloader = get_dataloader(data_mode, n_bins, k, packing_type)
    model.train().to(device)

    use_cuda = device.type == "cuda"
    scaler = torch.amp.GradScaler(device.type) if use_cuda else None

    batch_times = []
    total_loss = 0.0
    total_samples = 0
    timed_samples = 0  # rows in timed (post-warmup) batches; the numerator for throughput
    total_batch_size = 0
    total_seq_len = 0

    nhead = model.transformer_encoder.layers[0].self_attn.num_heads

    progress_bar = tqdm(dataloader, desc="Training")

    for batch_idx, batch in enumerate(progress_bar):
        is_warmup = batch_idx < warmup_batches

        if use_cuda:
            torch.cuda.synchronize()

        if not is_warmup:
            start_time = time.perf_counter()

        input_ids = batch["input_ids"].to(device)
        seq_ids   = batch["seq_ids"].to(device)

        inputs  = input_ids[:, :-1]
        targets = input_ids[:, 1:]
        seq_ids = seq_ids[:, :-1]
        # NOTE(review): a target that is the first token of the next packed sample is still
        # scored. Only criterion's ignore_index (0 in this notebook's calls) is skipped.

        B, L = inputs.shape

        causal = torch.tril(
            torch.ones(L, L, device=device, dtype=torch.bool)
        )

        # [B, L, L]: True where query i and key j come from the same packed sample.
        same_seq = seq_ids.unsqueeze(1) == seq_ids.unsqueeze(2)

        allowed = causal & same_seq
        # Already True on the diagonal (both causal and same_seq are); kept as a guard so no
        # query row is ever fully masked.
        allowed.diagonal(dim1=-2, dim2=-1).fill_(True)

        # Boolean mask convention: True = may NOT attend. Repeat per head, batch-major.
        attn_mask = (
            (~allowed)
            .unsqueeze(1)
            .expand(-1, nhead, -1, -1)
            .reshape(B * nhead, L, L)
        )

        total_batch_size += B
        total_seq_len += B * L

        src = inputs.transpose(0, 1)  # model is sequence-first: [L, B]
        tgt_y = targets.reshape(-1)

        optimizer.zero_grad(set_to_none=True)

        if scaler is not None:
            with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
                out = model(src, attn_mask)
                logits = out.transpose(0, 1).reshape(-1, out.size(-1))
                loss = criterion(logits, tgt_y)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(src, attn_mask)
            logits = out.transpose(0, 1).reshape(-1, out.size(-1))
            loss = criterion(logits, tgt_y)
            loss.backward()
            optimizer.step()

        if use_cuda:
            torch.cuda.synchronize()

        if not is_warmup:
            batch_time = time.perf_counter() - start_time
            batch_times.append(batch_time)
            timed_samples += B

        batch_loss = loss.item()
        total_loss += batch_loss * B
        total_samples += B

        if is_warmup:
            progress_bar.set_postfix(
                status="warmup",
                loss=f"{batch_loss:.4f}",
                batch_size=B,
                seq_len=L,
            )
        else:
            avg_loss = total_loss / total_samples
            avg_batch_time = sum(batch_times) / len(batch_times)
            avg_batch_size = total_batch_size / (batch_idx + 1)
            avg_seq_len = total_seq_len / total_batch_size

            progress_bar.set_postfix(
                loss=f"{batch_loss:.4f}",
                avg_loss=f"{avg_loss:.4f}",
                batch_time=f"{batch_time:.4f}s",
                avg_batch_time=f"{avg_batch_time:.4f}s",
                batch_size=B,
                seq_len=L,
                avg_batch_size=f"{avg_batch_size:.1f}",
                avg_seq_len=f"{avg_seq_len:.1f}",
            )

    time_stats = {}
    final_avg_loss = throughput = final_avg_batch_size = final_avg_seq_len = 0
    if batch_times:
        time_array = np.array(batch_times)
        time_stats = {
            'batch_time_min': float(np.min(time_array)),
            'batch_time_max': float(np.max(time_array)),
            'batch_time_mean': float(np.mean(time_array)),
            'batch_time_median': float(np.median(time_array)),
            'batch_time_std': float(np.std(time_array)),
            'num_batches': len(batch_times)
        }

        total_time = sum(batch_times)
        # Only timed batches contribute to both numerator and denominator.
        throughput = timed_samples / total_time if total_time > 0 else 0.0
        avg_time = total_time / len(batch_times)
        final_avg_loss = total_loss / total_samples

        final_avg_batch_size = total_batch_size / len(progress_bar)
        final_avg_seq_len = total_seq_len / total_batch_size

        print(f"Avg loss: {final_avg_loss:.4f}")
        print(f"Avg batch time: {avg_time:.4f}s")
        print(f"Throughput: {throughput:.2f} samples/s")
        print(f"Total samples: {total_samples}")
        print(f"Avg batch size: {final_avg_batch_size:.2f}")
        print(f"Avg sequence length: {final_avg_seq_len:.2f}")
        print(f"Batch time stats: min={time_stats['batch_time_min']:.4f}s, "
              f"max={time_stats['batch_time_max']:.4f}s, "
              f"mean={time_stats['batch_time_mean']:.4f}s, "
              f"median={time_stats['batch_time_median']:.4f}s")

    return {
        'avg_loss': final_avg_loss,
        'throughput': throughput,
        'total_samples': total_samples,
        'avg_batch_size': final_avg_batch_size,
        'avg_seq_len': final_avg_seq_len,
        'time_stats': time_stats,
    }


In [42]:
# Packed rows, naive streaming packer; padding id 0 is excluded from the loss.
device = torch.device('cuda')
model = get_gpt2_model(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

batch_times_ultra_duper_big_brain_basic = run_epoch_ultra_duper_big_brain(
    DataMode.ULTRA_DUPER_BIG_BRAIN, model, optimizer, criterion,
    device=device, packing_type='basic',
)

Training: 100%|██████████| 2420/2420 [16:57<00:00,  2.38it/s, avg_batch_size=64.0, avg_batch_time=0.4152s, avg_loss=5.6307, avg_seq_len=639.0, batch_size=64, batch_time=0.4112s, loss=4.9201, seq_len=639]

Avg loss: 5.6307
Avg batch time: 0.4152s
Throughput: 154.34 samples/s
Total samples: 154880
Avg batch size: 64.00
Avg sequence length: 639.00
Batch time stats: min=0.4087s, max=4.3061s, mean=0.4152s, median=0.4127s





In [48]:
# Packed rows using the "ffd" packer; padding id 0 is excluded from the loss.
device = torch.device('cuda')
model = get_gpt2_model(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

batch_times_ultra_duper_big_brain_ffd = run_epoch_ultra_duper_big_brain(
    DataMode.ULTRA_DUPER_BIG_BRAIN, model, optimizer, criterion,
    device=device, packing_type='ffd',
)

Training: 100%|██████████| 2421/2421 [16:59<00:00,  2.38it/s, avg_batch_size=64.0, avg_batch_time=0.4154s, avg_loss=5.6092, avg_seq_len=639.0, batch_size=8, batch_time=0.1762s, loss=4.7823, seq_len=639] 

Avg loss: 5.6092
Avg batch time: 0.4154s
Throughput: 154.19 samples/s
Total samples: 154888
Avg batch size: 63.98
Avg sequence length: 639.00
Batch time stats: min=0.1762s, max=3.8509s, mean=0.4154s, median=0.4130s





In [39]:
# Packed rows using the "obfd" packer; padding id 0 is excluded from the loss.
device = torch.device('cuda')
model = get_gpt2_model(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

batch_times_ultra_duper_big_brain_obfd = run_epoch_ultra_duper_big_brain(
    DataMode.ULTRA_DUPER_BIG_BRAIN, model, optimizer, criterion,
    device=device, packing_type='obfd',
)

Training: 100%|██████████| 2421/2421 [16:56<00:00,  2.38it/s, avg_batch_size=64.0, avg_batch_time=0.4144s, avg_loss=5.6000, avg_seq_len=639.0, batch_size=8, batch_time=0.2089s, loss=5.0453, seq_len=639] 

Avg loss: 5.6000
Avg batch time: 0.4144s
Throughput: 154.56 samples/s
Total samples: 154888
Avg batch size: 63.98
Avg sequence length: 639.00
Batch time stats: min=0.2089s, max=4.2795s, mean=0.4144s, median=0.4118s





In [None]:
import pandas as pd
# Aggregated per-configuration benchmark results.
df = pd.read_json(path_or_buf='report (1).json')
df

Unnamed: 0,index,config,avg_loss,total_samples,avg_batch_size,avg_seq_len,batch_time_min,batch_time_max,batch_time_mean,batch_time_median
0,0,ultrabigbrain k=1,5.209138,305498,63.99204,324.526956,0.021912,0.20131,0.101121,0.094524
1,1,ultrabigbrain k=10,5.20304,305498,63.99204,324.526956,0.021415,0.207838,0.101111,0.094678
2,2,ultrabigbrain k=50,5.185742,305498,63.99204,324.526956,0.02185,0.201899,0.101141,0.094805
3,3,big brain,2.749565,305498,63.99204,640.0,0.085038,0.193472,0.186007,0.186009
4,4,brain,2.778737,305498,63.99204,640.0,0.08579,0.200119,0.18607,0.186064
5,6,ultraduperbigbrain obfd,5.600036,154888,63.976869,639.0,0.208876,4.279453,0.41443,0.411835
6,7,ultraduperbigbrain basic,5.630688,154880,64.0,639.0,0.408687,4.306136,0.415173,0.412682
7,8,ultraduperbigbrain ffd,5.609242,154888,63.976869,639.0,0.176223,3.850917,0.415424,0.41296
