In [1]:
import time
from enum import Enum
from typing import Optional

import torch
from torch.utils.data import Sampler, IterableDataset
from torch.utils.data.dataset import Dataset
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pretrained bert-base-uncased tokenizer: it tokenizes the corpus below and its
# vocab_size sizes the freshly built model in get_gpt2_model.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

In [3]:
from transformer import TransformerModel

def get_gpt2_model(
    device='cpu',
    *,
    ntoken: Optional[int] = None,
    d_model: int = 768,
    nhead: int = 8,
    d_hid: int = 1024,
    nlayers: int = 8,
    dropout: float = 0.1,
) -> torch.nn.Module:
    """Build a freshly initialised TransformerModel and move it to ``device``.

    Despite the name, this is not a pretrained GPT-2 checkpoint. It instantiates
    the local ``transformer.TransformerModel`` with the hyper-parameters below.
    The defaults reproduce the original hard-coded configuration.

    :param device: anything accepted by ``nn.Module.to`` (e.g. 'cpu', 'cuda', a torch.device)
    :param ntoken: vocabulary size; defaults to the notebook-level ``tokenizer.vocab_size``
        (presumably the embedding / output size of TransformerModel — confirm in transformer.py)
    :param d_model: forwarded unchanged to TransformerModel
    :param nhead: forwarded unchanged to TransformerModel
    :param d_hid: forwarded unchanged to TransformerModel
    :param nlayers: forwarded unchanged to TransformerModel
    :param dropout: forwarded unchanged to TransformerModel
    :return: the model, already moved to ``device``
    """
    if ntoken is None:
        # Resolved at call time, like the original, so it follows the current tokenizer.
        ntoken = tokenizer.vocab_size
    return TransformerModel(
        ntoken=ntoken,
        d_model=d_model,
        nhead=nhead,
        d_hid=d_hid,
        nlayers=nlayers,
        dropout=dropout,
    ).to(device)

In [4]:
def process_wikitext_file(filepath):
    r"""Read a WikiText text dump and split it on every "\n = " heading marker.

    The text before the first marker becomes the first chunk. Each later chunk
    starts right after a removed "\n = " separator.

    NOTE(review): "\n = " is also a prefix of second-level headings written as
    "\n = = Section = = ". If the dump uses that WikiText convention, chunks are
    cut at section headings too, not only at article titles. The ~305k chunk count
    printed further down looks like sections rather than whole articles — confirm
    this granularity is intended.
    """
    with open(filepath, encoding='utf-8') as handle:
        contents = handle.read()
    return contents.split('\n = ')


def get_train_samples(data_path):
    """Return the chunks of both WikiText-103 train shards under ``data_path``.

    Shard 00000 comes first, followed by shard 00001. The result is a new list.
    """
    shard_paths = (
        f"{data_path}/train-00000-of-00002.txt",
        f"{data_path}/train-00001-of-00002.txt",
    )
    articles = []
    for shard_path in shard_paths:
        articles.extend(process_wikitext_file(shard_path))
    return articles

# Directory holding the two WikiText-103 (raw) train shards read by get_train_samples.
data_path = 'wikitext-103-raw-v1'
# One string per "\n = "-delimited chunk, shard 00000 first, then shard 00001.
articles = get_train_samples(data_path)


In [5]:
# Tokenize every chunk. return_tensors=None keeps ragged Python lists under
# 'input_ids'; padding/truncation to a fixed width happens later in
# pad_and_truncate_tokens.
# The "longer than ... 512" warning printed below comes from the tokenizer's own
# model-length limit. NOTE(review): presumably harmless here because the model is
# the local TransformerModel, not BERT — confirm its positional encoding handles
# sequences up to MAX_LENGTH.
tokenized_samples = tokenizer(
    articles,
    return_tensors=None
)

Token indices sequence length is longer than the specified maximum sequence length for this model (643 > 512). Running this sequence through the model will result in indexing errors


In [21]:
# Number of tokenized chunks (= number of training samples).
len(tokenized_samples['input_ids'])

305498

In [19]:
import numpy as np
# Share of tokenized chunks longer than 640 tokens (640 is also MAX_LENGTH, set in
# a later cell); these are the samples pad_and_truncate_tokens cuts short.
np.mean([len(ids) > 640 for ids in tokenized_samples['input_ids']])

0.16741844463793543

In [12]:
MAX_LENGTH = 640

import torch


def pad_and_truncate_tokens(
    tokens,
    max_length: int = MAX_LENGTH,
    pad_token_id: int = 0,
    effective: bool = False
):
    """Truncate and right-pad token-id sequences into a rectangular LongTensor.

    :param tokens: sequence of token-id sequences (lists, tuples, ... — anything sliceable)
    :param max_length: hard upper bound on the output width
    :param pad_token_id: id appended to rows shorter than the target width
    :param effective: if False, every row is truncated/padded to exactly ``max_length``.
        If True ("dynamic padding"), the target width is
        ``min(longest sequence in tokens, max_length)``, so a batch of short
        sequences is not padded all the way out to ``max_length``.
    :return: ``torch.long`` tensor of shape ``(len(tokens), target_width)``. Empty
        ``tokens`` gives shape ``(0, max_length)``, or ``(0, 0)`` when ``effective`` is True.
    """
    longest = max((len(seq) for seq in tokens), default=0)
    target_len = min(longest, max_length) if effective else max_length

    padded_sequences = []
    for seq in tokens:
        # list(...) so tuples etc. can be extended with padding like lists.
        row = list(seq[:target_len])
        # Pad to target_len, not max_length: padding to max_length in effective
        # mode would make dynamic padding a no-op.
        row.extend([pad_token_id] * (target_len - len(row)))
        padded_sequences.append(row)

    if not padded_sequences:
        # torch.tensor([]) would be 1-D; keep the result 2-D for column indexing.
        return torch.empty((0, target_len), dtype=torch.long)
    return torch.tensor(padded_sequences, dtype=torch.long)

# Sanity check: pad/truncate the first two tokenized chunks to MAX_LENGTH columns.
pad_and_truncate_tokens(tokenized_samples['input_ids'][:2])

tensor([[  101,  1027, 11748,  ...,     0,     0,     0],
        [  101,  1027, 11247,  ...,  2000,  2367,  3131]])

In [None]:
MAX_LENGTH = 640


class BrainDataset(Dataset):
    """Map-style dataset that pads/truncates every sample to one fixed width up front.

    All rows are materialised once in ``__init__`` as a single LongTensor with
    ``max_length`` columns, so the default DataLoader collation just stacks rows.
    """

    def __init__(self, tokenized_samples, max_length: int = MAX_LENGTH):
        """
        :param tokenized_samples: tokenizer output; only its 'input_ids' entry
            (one token-id list per sample) is read
        :param max_length: width every sample is truncated/padded to
        """
        self.max_length = max_length
        # Forward max_length so the constructor argument is actually honoured.
        self.samples = pad_and_truncate_tokens(tokenized_samples['input_ids'], max_length=max_length)

    def __getitem__(self, idx: int):
        """Return padded row ``idx`` (1-D LongTensor); negative indices are rejected."""
        assert idx >= 0 and idx < len(self.samples)
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)

In [None]:
class BigBrainDataset(Dataset):
    """Map-style dataset that keeps the raw, variable-length token-id lists.

    No padding is done here; in get_dataloader this dataset is paired with
    ``collate_fn``, which pads at collation time. ``max_length`` is accepted for
    signature parity with BrainDataset but this class does not use it.
    """

    def __init__(self, tokenized_samples, max_length: int = MAX_LENGTH):
        input_ids = tokenized_samples['input_ids']
        self.samples = input_ids

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx: int):
        # Guard against negative indices (Python would silently wrap them) and overruns.
        assert 0 <= idx < len(self.samples)
        return self.samples[idx]

def collate_fn(
    batch, max_length: int = MAX_LENGTH, pad_token_id: int = 0
) -> torch.Tensor:
    """
    Collate a list of variable-length token-id lists into one padded LongTensor.

    Delegates to ``pad_and_truncate_tokens`` with ``effective=True``, i.e. requests
    dynamic padding: sequences are truncated to
    ``min(longest sequence in the batch, max_length)`` and padded with ``pad_token_id``.

    :param batch: list of token-id lists, as returned by BigBrainDataset.__getitem__
    :param max_length: upper bound on the sequence length of the collated batch
    :param pad_token_id: padding id (default 0, the same as pad_and_truncate_tokens' default)
    :return: a single 2-D LongTensor with one row per sequence — not an
        (inputs, targets) tuple; run_epoch derives targets by shifting the inputs
    """
    return pad_and_truncate_tokens(batch, max_length, pad_token_id=pad_token_id, effective=True)


In [None]:
import torch
from torch.utils.data import Dataset, Sampler
from typing import List, Dict
import random
from collections import defaultdict
import numpy as np

class UltraBigBrainDataset(Dataset):
    """Dataset that pre-pads samples and groups their indices into length bins.

    Every sample is truncated/padded with 0 to ``max_length`` and stored as a
    Python list. ``lengths`` keeps each sample's original length capped at
    ``max_length``; UltraBigBrainBatchSampler reads it. ``bins`` maps a bin id in
    ``[0, n_bins)`` to a shuffled list of the sample indices whose length falls
    into that equal-width bin between the shortest and longest capped length.
    """

    def __init__(self, tokenized_samples, max_length: int = 512, n_bins: int = 10):
        raw_ids = tokenized_samples['input_ids']

        self.max_length = max_length
        self.n_bins = n_bins
        self.samples = self._pad_tokens(raw_ids, max_length)
        self.lengths = [min(len(ids), max_length) for ids in raw_ids]

        self.bins = defaultdict(list)
        shortest = min(self.lengths)
        longest = max(self.lengths)
        if shortest == longest:
            # Degenerate case: every sample has the same length -> one bin.
            self.bins[0] = list(range(len(self.samples)))
        else:
            width = (longest - shortest) / n_bins
            last_bin = n_bins - 1
            for sample_idx, length in enumerate(self.lengths):
                # The longest sample lands exactly on the upper edge; clamp it into the last bin.
                bin_id = min(last_bin, int((length - shortest) // width))
                self.bins[bin_id].append(sample_idx)

        for bin_indices in self.bins.values():
            random.shuffle(bin_indices)

    def _pad_tokens(self, token_lists, max_length):
        """Return copies of ``token_lists``, each cut to ``max_length`` and right-padded with 0."""
        def clip_and_pad(tokens):
            clipped = tokens[:max_length]
            missing = max_length - len(clipped)
            return clipped + [0] * missing if missing > 0 else clipped

        return [clip_and_pad(tokens) for tokens in token_lists]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return torch.as_tensor(self.samples[idx], dtype=torch.long)

class UltraBigBrainBatchSampler(Sampler):
    """Batch sampler that groups sample indices of similar length into batches.

    Indices are bucketed by their exact (capped) length from ``dataset.lengths``.
    Batches are then built greedily from the shortest remaining length upwards.
    A batch only draws from lengths within ``k`` of the shortest length still
    present when that batch starts, so the lengths inside one batch differ by at
    most ``k``. A batch can be smaller than ``batch_size`` when its length window
    runs out of indices.

    All batches are built once, in ``__init__``; ``__iter__`` only shuffles their
    order. NOTE(review): every epoch therefore sees the same batch compositions.
    Reshuffle ``length_to_indices`` and call ``_create_batches`` again if
    per-epoch variety is wanted.
    """
    def __init__(self, dataset: UltraBigBrainDataset, batch_size: int, k: int = 10):
        """
        :param dataset: dataset whose ``lengths`` list drives the bucketing
        :param batch_size: maximum number of indices per batch; must be >= 1
            (with 0 the batch-building loop never makes progress and never ends)
        :param k: maximum length difference allowed inside one batch
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.k = k
        
        # length -> indices of all samples with exactly that length
        self.length_to_indices = defaultdict(list)
        for idx, length in enumerate(dataset.lengths):
            self.length_to_indices[length].append(idx)
        
        # Shuffle every bucket so popping from its end takes a random member.
        self.sorted_lengths = sorted(self.length_to_indices.keys())
        for length in self.sorted_lengths:
            random.shuffle(self.length_to_indices[length])
        
        self.batches = self._create_batches()
    
    def _create_batches(self):
        """Greedily partition all indices into batches (a list of index lists)."""
        batches = []
        # Work on copies so the sampler's own buckets stay intact.
        length_to_indices = {k: v.copy() for k, v in self.length_to_indices.items()}
        sorted_lengths = self.sorted_lengths.copy()
        
        while sorted_lengths:
            batch = []
            # Shortest length still listed; its bucket may already be drained (see below).
            start_length = sorted_lengths[0]
            
            # Candidate lengths for this batch: within k of start_length, ascending.
            available_lengths = [
                length for length in sorted_lengths 
                if abs(length - start_length) <= self.k
            ]
            
            while len(batch) < self.batch_size and available_lengths:
                length = available_lengths[0]
                
                # Drained bucket (left behind when the previous batch filled up on
                # its last index): drop the length and move on.
                if not length_to_indices[length]:
                    available_lengths.pop(0)
                    if length in sorted_lengths:
                        sorted_lengths.remove(length)
                    continue
                
                batch.append(length_to_indices[length].pop())
                
                # Batch full: stop before the emptiness check below; a bucket this
                # pop emptied is removed by the drained-bucket branch on the next batch.
                if len(batch) == self.batch_size:
                    break
                
                if not length_to_indices[length]:
                    available_lengths.pop(0)
                    if length in sorted_lengths:
                        sorted_lengths.remove(length)
            
            if batch:
                batches.append(batch)
        
        return batches
    
    def __len__(self):
        """Number of batches yielded per epoch."""
        return len(self.batches)
    
    def __iter__(self):
        """Yield the precomputed batches (lists of indices) in a freshly shuffled order."""
        batches = self.batches.copy()
        random.shuffle(batches)
        yield from batches

In [None]:
class UltraDuperBigBrainDataset(Dataset):
    """Placeholder for the ULTRA_DUPER_BIG_BRAIN data mode — not implemented yet.

    The constructor only records its arguments. Indexing and ``len()`` raise
    NotImplementedError, so accidental use fails loudly instead of silently
    yielding ``None`` samples.
    """

    def __init__(self, data_path: str, max_length: int = MAX_LENGTH):
        self.data_path = data_path
        self.max_length = max_length

    def __getitem__(self, idx: int):
        raise NotImplementedError("UltraDuperBigBrainDataset is not implemented yet")

    def __len__(self):
        raise NotImplementedError("UltraDuperBigBrainDataset is not implemented yet")

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.benchmark import Timer
from torch.amp import autocast, GradScaler

class DataMode(Enum):
    """Batching strategy selector consumed by get_dataloader / run_epoch."""
    BRAIN = 1  # BrainDataset: every sample padded to a fixed width up front
    BIG_BRAIN = 2  # BigBrainDataset: raw lists, padded at collation time by collate_fn
    ULTRA_BIG_BRAIN = 3  # presumably UltraBigBrainDataset + batch sampler; not handled by get_dataloader
    ULTRA_DUPER_BIG_BRAIN = 4  # presumably UltraDuperBigBrainDataset (a stub); not handled by get_dataloader

BATCH_SIZE = 32


def get_dataloader(
    data_mode: DataMode,
    samples=None,
    batch_size: int = BATCH_SIZE,
    num_workers: int = 32,
):
    """Build the shuffled training DataLoader for ``data_mode``.

    :param data_mode: which batching strategy to use
    :param samples: tokenizer output with an 'input_ids' entry; defaults to the
        notebook-level ``tokenized_samples``. The datasets index their argument
        with 'input_ids', so they need the tokenizer output, not a file path.
    :param batch_size: samples per batch
    :param num_workers: DataLoader worker processes
    :return: a DataLoader over the mode's dataset
    :raises NotImplementedError: for modes without a wired-up loader
    """
    if samples is None:
        samples = tokenized_samples

    if data_mode == DataMode.BRAIN:
        # Rows are already fixed-width, so default collation just stacks them.
        dataset = BrainDataset(samples)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    if data_mode == DataMode.BIG_BRAIN:
        dataset = BigBrainDataset(samples)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True)

    # Fail loudly instead of returning None, which run_epoch would try to iterate.
    raise NotImplementedError(f"No DataLoader is implemented for {data_mode}")

In [7]:
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    """Return an additive causal attention mask of shape (sz, sz).

    Entries strictly above the main diagonal are -inf (future positions are
    masked out); the diagonal and everything below it are 0.
    """
    all_masked = torch.full((sz, sz), float("-inf"))
    return all_masked.triu(diagonal=1)

In [8]:
def run_epoch(data_mode: DataMode, model, optimizer, criterion, device='cpu', warmup_batches: int = 3):
    """Train ``model`` for one epoch on the loader for ``data_mode`` and report timings.

    Every batch receives exactly one forward/backward/optimizer step. Batches with
    index >= ``warmup_batches`` are timed with ``time.perf_counter`` around that
    step. On CUDA the timing is bracketed by ``torch.cuda.synchronize`` so queued
    kernels are counted. Earlier batches are warm-up and are not timed.

    :param data_mode: selects the DataLoader via ``get_dataloader``
    :param model: called as ``model(src, mask)`` with ``src`` laid out (seq, batch);
        assumed to return (seq, batch, vocab) logits — TODO confirm against TransformerModel
    :param optimizer: optimizer over ``model``'s parameters
    :param criterion: loss over (N, vocab) logits and (N,) integer targets
    :param device: ``torch.device`` or a device string such as 'cpu' / 'cuda'
    :param warmup_batches: number of leading batches excluded from timing
    :return: epoch loss, i.e. the mean of per-batch losses weighted by batch size
    :raises ValueError: if the DataLoader yields no batches
    """
    # Accept plain strings such as the default 'cpu'; `.type` is used below.
    device = torch.device(device)
    on_cuda = device.type == "cuda"
    dataloader = get_dataloader(data_mode)
    model = model.to(device)
    model.train()
    # NOTE(review): loss scaling is only required for float16; with bfloat16
    # autocast GradScaler is unnecessary (kept to preserve existing behaviour).
    scaler = GradScaler(device.type)

    batch_times = []
    timed_samples = 0
    total_loss = 0.0
    total_samples = 0

    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for batch_idx, batch in enumerate(progress_bar):
        # The loaders in this notebook yield a bare LongTensor of shape (batch, seq);
        # a mapping with an 'input_ids' entry is accepted as well.
        if not torch.is_tensor(batch):
            batch = batch["input_ids"]
        inputs = batch.to(device)
        # Next-token prediction: the target at position t is the input at t + 1.
        inp = inputs[:, :-1]
        tgt = inputs[:, 1:]

        src = inp.transpose(0, 1)  # (seq, batch) layout passed to the model
        tgt_y = tgt.reshape(-1)
        mask = generate_square_subsequent_mask(src.size(0)).to(device)

        timed = batch_idx >= warmup_batches
        if timed and on_cuda:
            torch.cuda.synchronize(device)
        start = time.perf_counter()

        # Single training step per batch (the old Timer-based code ran it twice).
        with autocast(device_type=device.type, dtype=torch.bfloat16):
            out = model(src, mask)
            # Back to batch-major so the flattened logits line up with tgt_y.
            out = out.transpose(0, 1)
            logits = out.reshape(-1, out.size(-1))
            loss = criterion(logits, tgt_y)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if timed and on_cuda:
            torch.cuda.synchronize(device)
        bt = time.perf_counter() - start

        batch_loss = loss.item()
        bs = inputs.size(0)
        total_loss += batch_loss * bs
        total_samples += bs

        if timed:
            batch_times.append(bt)
            timed_samples += bs
            avg_bt = sum(batch_times) / len(batch_times)
            progress_bar.set_postfix(
                loss=f"{batch_loss:.4f}",
                avg=f"{(total_loss/total_samples):.4f}",
                t=f"{bt:.3f}",
                avg_t=f"{avg_bt:.3f}"
            )
        else:
            progress_bar.set_postfix(status="warmup")

    if total_samples == 0:
        raise ValueError(f"DataLoader for {data_mode} produced no batches")
    epoch_loss = total_loss / total_samples

    print(f"epoch loss: {epoch_loss:.4f}")
    if batch_times:
        avg_bt = sum(batch_times) / len(batch_times)
        print(f"batch time avg {avg_bt:.3f}s min {min(batch_times):.3f}s max {max(batch_times):.3f}s")
        # Count only samples from timed batches so warm-up doesn't inflate throughput.
        print(f"throughput: {timed_samples / sum(batch_times):.1f} samples/s")

    return epoch_loss


In [None]:
device = torch.device('cuda')

model = get_gpt2_model(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# The loaders pad with id 0 (pad_and_truncate_tokens' default pad_token_id).
# Check that this is the tokenizer's pad id, then exclude padded positions from the
# loss; otherwise short rows, which are mostly padding, dominate the objective.
assert tokenizer.pad_token_id == 0, "loaders pad with 0 but tokenizer.pad_token_id differs"
criterion = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

loss = run_epoch(
    DataMode.BRAIN,
    model,
    optimizer,
    criterion,
    device=device
)



                                                                                                             

KeyboardInterrupt: 