# Libraries

In [1]:
from google.colab import drive
import locale
locale.getpreferredencoding = lambda: "UTF-8"
!pip install tqdm
!pip install datasets
# !pip install hellaswag
!pip install tiktoken
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
import torch
import numpy as np
import time
from torch.amp import autocast, GradScaler
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
import random
import os
import multiprocessing as mp
from datasets import load_dataset
from transformers import ElectraTokenizer
import glob
import shutil
import tiktoken
import math
import inspect
from dataclasses import dataclass
# from hellaswag import render_example, iterate_examples

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Collecting datasets
  Downloading datasets-3.3.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.0-py3-none-any.whl (484 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m484.9/484.9 kB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading

# Prepare Dataset

In [2]:
import os
import numpy as np
import tiktoken
from datasets import load_dataset
from tqdm import tqdm
import multiprocessing as mp

local_dir = "/content/edu_fineweb10B"
remote_name = "sample-10BT"
shard_size = int(1e8)

DATA_CACHE_DIR = os.path.join(local_dir)
os.makedirs(DATA_CACHE_DIR, exist_ok=True)


enc = tiktoken.get_encoding("gpt2")
eot = enc._special_tokens['<|endoftext|>']
def tokenize(doc):
    tokens = [eot]
    tokens.extend(enc.encode_ordinary(doc["text"]))
    tokens_np = np.array(tokens)
    assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
    tokens_np_uint16 = tokens_np.astype(np.uint16)
    return tokens_np_uint16

def write_datafile(filename, tokens_np):
    np.save(filename, tokens_np)

def prepare_fineweb():
    fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train", streaming=True)
    nprocs = max(1, os.cpu_count()//2)
    with mp.Pool(nprocs) as pool:
        shard_index = 0
        all_tokens_np = np.empty((shard_size,), dtype=np.uint16)
        token_count = 0
        progress_bar = None
        for tokens in pool.imap(tokenize, fw, chunksize=16):

            if token_count + len(tokens) < shard_size:
                all_tokens_np[token_count:token_count+len(tokens)] = tokens
                token_count += len(tokens)
                if progress_bar is None:
                    progress_bar = tqdm(total=shard_size, unit="tokens", desc=f"Shard {shard_index}")
                # progress_bar.update(len(tokens))
            else:
                split = "val" if shard_index == 0 else "train"
                filename = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}")
                remainder = shard_size - token_count
                progress_bar.update(remainder)
                all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
                write_datafile(filename, all_tokens_np)
                shard_index += 1
                progress_bar = None
                all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:]
                token_count = len(tokens)-remainder

        if token_count != 0:
            split = "val" if shard_index == 0 else "train"
            filename = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}")
            write_datafile(filename, all_tokens_np[:token_count])

In [3]:
from concurrent.futures import ThreadPoolExecutor
drive.mount('/content/drive')
def copy_to_drive():
  google_drive_dir = "/content/drive/My Drive/datasets/fineweb10B/"
  os.makedirs(google_drive_dir, exist_ok=True)
  local_output_dir = "/content/edu_fineweb10B"

  for filename in os.listdir(local_output_dir):
    local_file_path = os.path.join(local_output_dir, filename)
    google_drive_file_path = os.path.join(google_drive_dir, filename)
    shutil.copy(local_file_path, google_drive_file_path)
    print(f"Copied {filename} to Google Drive.")

  print(f"All files copied to Google Drive at {google_drive_dir}")

def copy_from_drive():
  google_drive_dir = "/content/drive/My Drive/datasets/fineweb10B/"
  os.makedirs(google_drive_dir, exist_ok=True)
  local_output_dir = "/content/edu_fineweb10B"
  files = os.listdir(google_drive_dir)

  for filename in files:
    local_file_path = os.path.join(local_output_dir, filename)
    google_drive_file_path = os.path.join(google_drive_dir, filename)
    shutil.copy(google_drive_file_path, local_file_path)
    print(f"Copied {filename} to Disk.")

  print(f"All files copied to Disk at {google_drive_dir}")

# shuffle shards
import re
def shuffle_shards(output_dir):
    shard_files = [f for f in os.listdir(output_dir) if f.startswith("edufineweb_train_")]

    random.shuffle(shard_files)

    for idx, old_filename in enumerate(shard_files):
        new_filename = f"edufineweb_train_{idx:04d}.bin"

        old_filepath = os.path.join(output_dir, old_filename)
        new_filepath = os.path.join(output_dir, new_filename)

        temp_filepath = os.path.join(output_dir, f"temp_{old_filename}")

        shutil.move(old_filepath, temp_filepath)

        shutil.move(new_filepath, old_filepath) if os.path.exists(new_filepath) else None

        shutil.move(temp_filepath, new_filepath)

    print(f"Shuffling complete. Shards renamed and shuffled in {output_dir}")

def load_tokens(filename):
    npt = np.load(filename)
    npt = npt.astype(np.int32)
    ptt = torch.tensor(npt, dtype=torch.long)
    return ptt

class DataLoaderLite:
    def __init__(self, B, T, process_rank, num_processes, split):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}

        data_root = "/content/edu_fineweb10B"
        shards = os.listdir(data_root)
        shards = [s for s in shards if split in s]
        shards = sorted(shards)
        shards = [os.path.join(data_root, s) for s in shards]
        self.shards = shards
        assert len(shards) > 0, f"no shards found for split {split}"
        self.reset()

    def reset(self):
        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        x = (buf[:-1]).view(B, T)
        y = (buf[1:]).view(B, T)

        self.current_position += B * T * self.num_processes

        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_position = B * T * self.process_rank
        return x, y

Mounted at /content/drive


# Model

In [4]:
class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu    = nn.GELU(approximate='tanh')
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"

        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)
        tok_emb = self.transformer.wte(idx)
        x = tok_emb + pos_emb

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, device_type):
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if master_process:
            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        if master_process:
            print(f"using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer


# Dataset

In [5]:
# prepare_fineweb()
# copy_to_drive()
copy_from_drive()
# shuffle_shards("/content/edu_fineweb10B")

Copied edufineweb_train_000009.npy to Disk.
Copied edufineweb_train_000023.npy to Disk.
Copied edufineweb_train_000042.npy to Disk.
Copied edufineweb_train_000064.npy to Disk.
Copied edufineweb_train_000041.npy to Disk.
Copied edufineweb_train_000096.npy to Disk.
Copied edufineweb_train_000005.npy to Disk.
Copied edufineweb_train_000083.npy to Disk.
Copied edufineweb_train_000068.npy to Disk.
Copied edufineweb_train_000094.npy to Disk.
Copied edufineweb_train_000084.npy to Disk.
Copied edufineweb_train_000051.npy to Disk.
Copied edufineweb_train_000088.npy to Disk.
Copied edufineweb_train_000024.npy to Disk.
Copied edufineweb_train_000095.npy to Disk.
Copied edufineweb_train_000047.npy to Disk.
Copied edufineweb_train_000015.npy to Disk.
Copied edufineweb_train_000076.npy to Disk.
Copied edufineweb_train_000007.npy to Disk.
Copied edufineweb_train_000080.npy to Disk.
Copied edufineweb_train_000053.npy to Disk.
Copied edufineweb_train_000062.npy to Disk.
Copied edufineweb_train_000060.n

# Model Initialisation and training

In [7]:
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0
else:
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"using device: {device}")

device_type = "cuda" if device.startswith("cuda") else "cpu"

torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

enc = tiktoken.get_encoding("gpt2")

total_batch_size = 524288
B = 32
# B = 2
T = 1024
assert total_batch_size % (B * T * ddp_world_size) == 0, "make sure total_batch_size is divisible by B * T * ddp_world_size"
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
if master_process:
    print(f"total desired batch size: {total_batch_size}")
    print(f"=> calculated gradient accumulation steps: {grad_accum_steps}")

train_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train")
val_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val")

torch.set_float32_matmul_precision('high')

using device: cuda
total desired batch size: 524288
=> calculated gradient accumulation steps: 16


In [None]:
import time
import torch
from torch.cuda.amp import autocast, GradScaler

# Assuming you've already defined your model, optimizer, and data loader
model = GPT(GPTConfig(vocab_size=50304))
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# Initialize the scaler for mixed precision training (optional)
scaler = GradScaler()

# Load one batch (assuming you have a data loader named data_loader)
x, y = train_loader.next_batch()
x, y = x.to(device), y.to(device)

# Start training loop
start_time_2 = time.time()

# Run 50 times to overfit the model on the single batch
for i in range(50):
    start_time = time.time()
    optimizer.zero_grad()

    # Forward pass with autocast for mixed precision (optional)
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits, loss = model(x, y)

    # Backward pass with gradient scaling (optional for mixed precision)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # Sync GPU if using CUDA (optional but recommended for performance)
    torch.cuda.synchronize()

    # Record time per step
    end_time = time.time()
    step_duration = end_time - start_time
    print(f"Step {i+1} loss {loss.item():.6f} | Duration: {step_duration:.4f} seconds")

# After the loop, print the average step duration
average_duration = (time.time() - start_time_2) / 50
print(f"Average step duration: {average_duration:.4f} seconds for 50 steps")


  scaler = GradScaler()


Step 1 loss 10.999573 | Duration: 0.8661 seconds
Step 2 loss 9.917404 | Duration: 0.8082 seconds
Step 3 loss 9.520599 | Duration: 0.8066 seconds
Step 4 loss 9.314896 | Duration: 0.8038 seconds
Step 5 loss 9.124435 | Duration: 0.8084 seconds
Step 6 loss 8.765011 | Duration: 0.8172 seconds
Step 7 loss 8.949883 | Duration: 0.8162 seconds
Step 8 loss 8.822571 | Duration: 0.8165 seconds
Step 9 loss 8.450050 | Duration: 0.8178 seconds
Step 10 loss 9.753695 | Duration: 0.8205 seconds
Step 11 loss 8.065765 | Duration: 0.8226 seconds
Step 12 loss 7.982419 | Duration: 0.8209 seconds
Step 13 loss 7.748437 | Duration: 0.8299 seconds
Step 14 loss 7.354215 | Duration: 0.8277 seconds
Step 15 loss 7.290989 | Duration: 0.8301 seconds
Step 16 loss 7.034863 | Duration: 0.8360 seconds
Step 17 loss 6.662466 | Duration: 0.8384 seconds
Step 18 loss 6.520720 | Duration: 0.8427 seconds
Step 19 loss 6.236246 | Duration: 0.8509 seconds
Step 20 loss 5.964995 | Duration: 0.8448 seconds
Step 21 loss 5.787529 | Dura

KeyboardInterrupt: 

In [8]:
model = GPT(GPTConfig(vocab_size=50304))
model.to(device)
use_compile = True
if use_compile:
    model = torch.compile(model)
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module if ddp else model

max_lr = 9e-4
min_lr = max_lr * 0.1
warmup_steps = int(715 / 4)
max_steps = int(19073 / 4)
def get_lr(it):
    if it < warmup_steps:
        return max_lr * (it+1) / warmup_steps
    if it > max_steps:
        return min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

optimizer = raw_model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device_type)
save_interval = 1000

log_dir = "log"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"log.txt")
with open(log_file, "w") as f:
    pass

for step in range(int(max_steps * 3)):
    t0 = time.time()
    last_step = (step == max_steps - 1)

    if step % 250 == 0 or last_step:
        model.eval()
        val_loader.reset()
        with torch.no_grad():
            val_loss_accum = 0.0
            val_loss_steps = 20
            for _ in range(val_loss_steps):
                x, y = val_loader.next_batch()
                x, y = x.to(device), y.to(device)
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(x, y)
                loss = loss / val_loss_steps
                val_loss_accum += loss.detach()
        if ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if master_process:
            print(f"validation loss: {val_loss_accum.item():.4f}")
            with open(log_file, "a") as f:
                f.write(f"{step} val {val_loss_accum.item():.4f}\n")
            if step > 0 and (step % 5000 == 0 or last_step):
                checkpoint_path = os.path.join(log_dir, f"model_{step:05d}.pt")
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'config': raw_model.config,
                    'step': step,
                    'val_loss': val_loss_accum.item()
                }
                # you might also want to add optimizer.state_dict() and
                # rng seeds etc., if you wanted to more exactly resume training
                torch.save(checkpoint, checkpoint_path)

    # once in a while evaluate hellaswag
    # if (step % 250 == 0 or last_step) and (not use_compile):
    #     num_correct_norm = 0
    #     num_total = 0
    #     for i, example in enumerate(iterate_examples("val")):
    #         # only process examples where i % ddp_world_size == ddp_rank
    #         if i % ddp_world_size != ddp_rank:
    #             continue
    #         # render the example into tokens and labels
    #         _, tokens, mask, label = render_example(example)
    #         tokens = tokens.to(device)
    #         mask = mask.to(device)
    #         # get the logits
    #         with torch.no_grad():
    #             with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    #                 logits, loss = model(tokens)
    #             pred_norm = get_most_likely_row(tokens, mask, logits)
    #         num_total += 1
    #         num_correct_norm += int(pred_norm == label)
    #     # reduce the stats across all processes
    #     if ddp:
    #         num_total = torch.tensor(num_total, dtype=torch.long, device=device)
    #         num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
    #         dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
    #         dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
    #         num_total = num_total.item()
    #         num_correct_norm = num_correct_norm.item()
    #     acc_norm = num_correct_norm / num_total
    #     if master_process:
    #         print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
    #         with open(log_file, "a") as f:
    #             f.write(f"{step} hella {acc_norm:.4f}\n")

    # once in a while generate from the model (except step 0, which is noise)
    # if ((step > 0 and step % 250 == 0) or last_step) and (not use_compile):
    #     model.eval()
    #     num_return_sequences = 4
    #     max_length = 32
    #     tokens = enc.encode("Hello, I'm a language model,")
    #     tokens = torch.tensor(tokens, dtype=torch.long)
    #     tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
    #     xgen = tokens.to(device)
    #     sample_rng = torch.Generator(device=device)
    #     sample_rng.manual_seed(42 + ddp_rank)
    #     while xgen.size(1) < max_length:
    #         # forward the model to get the logits
    #         with torch.no_grad():
    #             with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    #                 logits, loss = model(xgen) # (B, T, vocab_size)
    #             # take the logits at the last position
    #             logits = logits[:, -1, :] # (B, vocab_size)
    #             # get the probabilities
    #             probs = F.softmax(logits, dim=-1)
    #             # do top-k sampling of 50 (huggingface pipeline default)
    #             # topk_probs here becomes (5, 50), topk_indices is (5, 50)
    #             topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
    #             # select a token from the top-k probabilities
    #             # note: multinomial does not demand the input to sum to 1
    #             ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)
    #             # gather the corresponding indices
    #             xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
    #             # append to the sequence
    #             xgen = torch.cat((xgen, xcol), dim=1)
    #     # print the generated text
    #     for i in range(num_return_sequences):
    #         tokens = xgen[i, :max_length].tolist()
    #         decoded = enc.decode(tokens)
    #         print(f"rank {ddp_rank} sample {i}: {decoded}")

    # do one step of the optimization
    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        # added after video, this field is also used by the forward pass.
        if ddp:
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # we have to scale the loss to account for gradient accumulation,
        # because the gradients just add on each successive backward().
        # addition of gradients corresponds to a SUM in the objective, but
        # instead of a SUM we want MEAN. Scale the loss here so it comes out right
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        loss.backward()
    if ddp:
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # determine and set the learning rate for this iteration
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()
    if device_type == "cuda":
        torch.cuda.synchronize() # wait for the GPU to finish work
    t1 = time.time()
    dt = t1 - t0 # time difference in seconds
    tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
    tokens_per_sec = tokens_processed / dt
    if master_process:
        if step % 100 == 0:

            print(f"step {step:5d} | loss: {loss_accum.item():.6f} | lr {lr:.4e} | norm: {norm:.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")
        with open(log_file, "a") as f:
            f.write(f"{step} train {loss_accum.item():.6f}\n")

        if step % save_interval == 0:
            checkpoint_path = os.path.join(log_dir, f"model_{step:05d}.pt")
            checkpoint = {
                'model': raw_model.state_dict(),
                'config': raw_model.config,
                'step': step,
                'optimizer': optimizer.state_dict(),
            }
            torch.save(checkpoint, checkpoint_path)
            print(f"Model checkpoint saved at {checkpoint_path}")

if ddp:
    destroy_process_group()

num decayed parameter tensors: 50, with 124,354,560 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW: True
validation loss: 10.9514
step     0 | loss: 10.955009 | lr 5.0562e-06 | norm: 15.3464 | dt: 50681.09ms | tok/sec: 10344.84
Model checkpoint saved at log/model_00000.pt
step   100 | loss: 6.515263 | lr 5.1067e-04 | norm: 1.0224 | dt: 2706.43ms | tok/sec: 193719.17
step   200 | loss: 5.832395 | lr 8.9995e-04 | norm: 0.4256 | dt: 2707.62ms | tok/sec: 193633.93
validation loss: 5.5848
step   300 | loss: 5.401945 | lr 8.9859e-04 | norm: 0.5298 | dt: 2709.02ms | tok/sec: 193534.32
step   400 | loss: 4.857390 | lr 8.9533e-04 | norm: 0.5464 | dt: 2707.61ms | tok/sec: 193634.78
validation loss: 4.5891
step   500 | loss: 4.550300 | lr 8.9020e-04 | norm: 0.5809 | dt: 4108.06ms | tok/sec: 127624.25
step   600 | loss: 4.154492 | lr 8.8322e-04 | norm: 0.4032 | dt: 2710.48ms | tok/sec: 193429.96
step   700 | loss: 4.169420 | lr 8.7442e-04 | norm: 0.4605

KeyboardInterrupt: 

In [10]:
def copy_to_drive():
  google_drive_dir = "/content/drive/My Drive/models/transformer/"
  os.makedirs(google_drive_dir, exist_ok=True)
  local_output_dir = "/content/log"

  for filename in os.listdir(local_output_dir):
    local_file_path = os.path.join(local_output_dir, filename)
    google_drive_file_path = os.path.join(google_drive_dir, filename)
    shutil.copy(local_file_path, google_drive_file_path)
    print(f"Copied {filename} to Google Drive.")

  print(f"All files copied to Google Drive at {google_drive_dir}")
copy_to_drive()

Copied model_01000.pt to Google Drive.
Copied log.txt to Google Drive.
Copied model_00000.pt to Google Drive.
Copied model_03000.pt to Google Drive.
Copied model_05000.pt to Google Drive.
Copied model_02000.pt to Google Drive.
Copied model_04000.pt to Google Drive.
Copied model_04767.pt to Google Drive.
All files copied to Google Drive at /content/drive/My Drive/models/transformer/


In [11]:
model.eval()
num_return_sequences = 4
max_length = 32
tokens = enc.encode("Hello, I'm a language model,")
tokens = torch.tensor(tokens, dtype=torch.long)
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
xgen = tokens.to(device)
sample_rng = torch.Generator(device=device)
sample_rng.manual_seed(42 + ddp_rank)
while xgen.size(1) < max_length:
    # forward the model to get the logits
    with torch.no_grad():
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(xgen) # (B, T, vocab_size)
        # take the logits at the last position
        logits = logits[:, -1, :] # (B, vocab_size)
        # get the probabilities
        probs = F.softmax(logits, dim=-1)
        # do top-k sampling of 50 (huggingface pipeline default)
        # topk_probs here becomes (5, 50), topk_indices is (5, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        # select a token from the top-k probabilities
        # note: multinomial does not demand the input to sum to 1
        ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)
        # gather the corresponding indices
        xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
        # append to the sequence
        xgen = torch.cat((xgen, xcol), dim=1)
# print the generated text
for i in range(num_return_sequences):
    tokens = xgen[i, :max_length].tolist()
    decoded = enc.decode(tokens)
    print(f"rank {ddp_rank} sample {i}: {decoded}")

rank 0 sample 0: Hello, I'm a language model, and I don't want to go down as if I were to create a robot. It's the way the robot was
rank 0 sample 1: Hello, I'm a language model, so i'm going to explain it back to the basics I've got all along, but this one actually needs some context
rank 0 sample 2: Hello, I'm a language model, so I could see lots of fun with it - especially if your class is a computer language model.
I've added
rank 0 sample 3: Hello, I'm a language model, and now I'm going to show you some.
First, let's go. The following code contains some code used
