In [1]:
import torch
from torch.utils.data import Dataset
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F

import wandb
import tiktoken
from tiktoken import Encoding
from Transformer import TransformerDecoderWithRoPE




In [2]:
import os

class AsciiTokenizer():
    def __init__(self, vocab_size=128):
        self.vocab_size = vocab_size  # ASCII range
        self.pad_token = 0  # Padding token index
        self.unk_token = 1  # Unknown token index

    def encode(self, text, seq_length):
        tokens = [ord(c) if ord(c) < self.vocab_size else self.unk_token for c in text]
        if len(tokens) < seq_length:
            tokens += [self.pad_token] * (seq_length - len(tokens))  # Pad
        else:
            tokens = tokens[:seq_length]  # Truncate
        return tokens

    def decode(self, tokens):
        chars = [chr(t) if t < self.vocab_size else '?' for t in tokens if t != self.pad_token]
        return ''.join(chars)

class TextDataset(Dataset):
    def __init__(self, dataset_name = "wikipedia", tokenizer=None, seq_length=64):
        self.tokenizer = tokenizer if tokenizer else AsciiTokenizer()
        self.seq_length = seq_length
        if dataset_name == "wikipedia":
            from datasets import load_dataset
            self.dataset = load_dataset("wikimedia/wikipedia", "20231101.en")['train']
        else:
            raise ValueError(f"Dataset {dataset_name} not supported.")


    def __len__(self):
        return len(self.dataset)

    def get_pad_token(self):
        if isinstance(self.tokenizer, AsciiTokenizer):
            return self.tokenizer.pad_token
        elif isinstance(self.tokenizer, Encoding):
            return self.tokenizer.special_tokens["<|pad|>"]
        else:
            raise ValueError("Tokenizer must be either AsciiTokenizer or tiktoken.Encoding")

    def __getitem__(self, idx):
        # Check if using ascii tokenizer
        if isinstance(self.tokenizer, AsciiTokenizer):
            # select only seq_length + 1 tokens for input-target pair

            text = self.dataset[idx]
            # random int between 0 and len(text) - seq_length - 1
            if len(text['text']) < self.seq_length + 1:
                start_idx = 0
            else:
                upper_bound = len(text['text']) - self.seq_length - 1
                if upper_bound <= 0:
                    start_idx = 0
                else:
                    start_idx = torch.randint(0, upper_bound, (1,)).item()
            text['text'] = text['text'][start_idx:start_idx + self.seq_length + 1]

            code = self.tokenizer.encode(text['text'], self.seq_length)
            tokens = torch.tensor(code, dtype=torch.long)

            # pad if needed
            if len(tokens) < self.seq_length + 1:
                padding = torch.full((self.seq_length + 1 - len(tokens),), self.tokenizer.pad_token, dtype=torch.long)
                tokens = torch.cat([tokens, padding], dim=0)
        elif isinstance(self.tokenizer, Encoding):
            text = self.dataset[idx]
            tokens = enc.encode(text)
            tokens = tokens[:self.seq_length + 1]  # Truncate
            tokens = torch.tensor(tokens, dtype=torch.long)
            # pad if needed
            if len(tokens) < self.seq_length + 1:
                padding = torch.full((self.seq_length + 1 - len(tokens),), 0, dtype=torch.long)
                tokens = torch.cat([tokens, padding], dim=0)
        else:
            raise ValueError("Tokenizer must be either AsciiTokenizer or tiktoken.Encoding")
            

        return {"input_ids": tokens[:-1], "target_ids": tokens[1:]}  # Shifted for language modeling

In [3]:
# enc = tiktoken.get_encoding("o200k_base")
# enc = tiktoken.Encoding(
#     name=enc.name,
#     pat_str=enc._pat_str,
#     mergeable_ranks=enc._mergeable_ranks,
#     special_tokens={**enc._special_tokens, "<|pad|>": len(enc._mergeable_ranks)}
# )
#
# enc._special_tokens

In [4]:
# enc.encode("Hello, world!")

In [5]:
dataset = TextDataset(seq_length = 128)

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

In [None]:
# Test both implementations
input_size = 10
hidden_size = 768
batch_size = 256 + 128
seq_len = 512
learning_rate = 0.0001
num_layers = 16
epochs = 50
max_grad_norm = 1.0  # Gradient clipping value
vocab_size = 128
warmup_steps = 2000
file_name = "transformerLargeContext"
checkpoint_dir = f"{file_name}/checkpoints"

# Start a new wandb run to track this script.
run = wandb.init(
    project="smaLLM",
    config={
        "learning_rate": learning_rate,
        "architecture": "Transformer decoder with RoPE",
        "vocab_size": vocab_size,
        "dropout": 0.1,
        "layer_size": hidden_size,
        "seq_length": seq_len,
        "batch_size": batch_size,
        "num_layers": num_layers,
        "dataset": "Wikipedia english",
        "epochs": epochs,
        "grad_clipping": max_grad_norm,
        "mixed_precision": True,
        "optimizer": "AdamW",
        "scheduler": "LambdaLR with warmup",
        "tokenizer": "AsciiTokenizer (128 vocab size)",
        "warmup_steps": warmup_steps,
    },
)

# Create checkpoint directory
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

def save_checkpoint(model, optimizer, scheduler, scaler, epoch, batch, loss, checkpoint_path):
    """Save complete training state"""
    checkpoint = {
        'epoch': epoch,
        'batch': batch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")

def load_checkpoint(checkpoint_path, model, optimizer, scheduler, scaler, device):
    """Load complete training state"""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])

    epoch = checkpoint['epoch']
    batch = checkpoint['batch']

    print(f"Checkpoint loaded from epoch {epoch}, batch {batch}")
    return epoch, batch

def find_latest_checkpoint(checkpoint_dir):
    """Find the most recent checkpoint"""
    if not os.path.exists(checkpoint_dir):
        return None

    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]
    if not checkpoints:
        return None

    # Sort by modification time
    latest = max(checkpoints, key=lambda f: os.path.getmtime(os.path.join(checkpoint_dir, f)))
    return os.path.join(checkpoint_dir, latest)

print("Testing Sequential version:")
model1 = TransformerDecoderWithRoPE(vocab_size=vocab_size, hidden_size=hidden_size, num_layers=num_layers)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

pad_token = dataset.get_pad_token()
criterion = nn.CrossEntropyLoss(ignore_index=pad_token)

optimizer = torch.optim.AdamW(model1.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: min((step+1)/warmup_steps, 1))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model1.to(device)

scaler = GradScaler()

# Try to load checkpoint
start_epoch = 0
start_batch = 0
latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
    start_epoch, start_batch = load_checkpoint(latest_checkpoint, model1, optimizer, scheduler, scaler, device)
    start_epoch_offset = start_epoch  # Resume from this epoch
else:
    start_epoch_offset = 0

for epoch in range(start_epoch_offset, epochs):
    # If resuming mid-epoch, start from where we left off
    batch_start = start_batch if epoch == start_epoch_offset else 0

    for k, batch in enumerate(dataloader):
        if k < batch_start:
            continue

        inputs = batch['input_ids'].to(device)
        targets = batch['target_ids'].to(device)
        optimizer.zero_grad()

        # Mixed precision context
        with autocast():
            outputs = model1(inputs)
            loss = criterion(outputs.view(-1, 128), targets.view(-1))

        # Scale gradients, backward pass
        scaler.scale(loss).backward()

        # Gradient clipping
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model1.parameters(), max_grad_norm)

        # Optimizer step and scaler update
        scaler.step(optimizer)
        scaler.update()

        # Scheduler step
        scheduler.step()

        run.log({"train/loss": loss.item(),
                  "epoch": epoch,
                  "batch": k,
                  "lr": scheduler.get_last_lr()[0]
                 })

        if k % 1000 == 0 and k > 0:
            checkpoint_path = f"{checkpoint_dir}/{file_name}_epoch{epoch}_batch{k}.pt"
            save_checkpoint(model1, optimizer, scheduler, scaler, epoch, k, loss.item(), checkpoint_path)

[34m[1mwandb[0m: Currently logged in as: [33mmaciej-lachut[0m ([33mmaciej-lachut-poznan-university-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Testing Sequential version:


  scaler = GradScaler()
  with autocast():


{'<|endoftext|>': 199999, '<|endofprompt|>': 200018, '<|pad|>': 199998}