In [48]:
import os
import math
import time
import inspect
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import csv
import urllib.request
import multiprocessing as mp
import numpy as np
import tiktoken
from tqdm import tqdm


In [49]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which each position attends only to itself and earlier positions.

    Queries, keys and values come from a single fused linear layer (`c_attn`);
    the per-head results are concatenated and passed through `c_proj`.
    `config` must expose `n_embd` and `n_head`, with `n_embd` divisible by `n_head`.
    """

    def __init__(self, config):
        super().__init__()
        n_embd, n_head = config.n_embd, config.n_head
        assert n_embd % n_head == 0
        # one matmul yields q, k and v stacked along the feature axis
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # mixes the concatenated heads back into the model width
        self.c_proj = nn.Linear(n_embd, n_embd)
        # marker attribute; GPT._init_weights checks for it with hasattr()
        self.c_proj.NANOGPT_SCALE_INIT = 1
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x):
        """Apply causal attention to `x` of shape (B, T, C); the result has the same shape."""
        batch, seq_len, width = x.shape
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # (B, n_head, T, head_dim) -> (B, T, C): put the heads back side by side
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)

In [50]:
class MLP(nn.Module):
    """Feed-forward sub-layer: widen to 4x `n_embd`, apply tanh-approximated GELU, project back."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden, config.n_embd)
        # marker attribute; GPT._init_weights checks for it with hasattr()
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        """Return c_proj(gelu(c_fc(x))); the last dimension of `x` must be `n_embd`."""
        return self.c_proj(self.gelu(self.c_fc(x)))

In [51]:
class Block(nn.Module):
    """One transformer layer: LayerNorm -> attention and LayerNorm -> MLP, each added back residually."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return `x` after the attention and MLP residual updates (shape unchanged)."""
        # normalisation is applied to the branch input, not to the residual path
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out

In [52]:
@dataclass
class GPTConfig:
    """Architecture hyperparameters consumed by GPT and its sub-modules."""

    # rows in the position-embedding table; GPT.forward rejects longer sequences
    block_size: int = 1_024
    # rows in the token-embedding table / width of the output logits
    vocab_size: int = 50_257
    # number of stacked transformer blocks
    n_layer: int = 6
    # attention heads per block (must divide n_embd)
    n_head: int = 8
    # embedding / residual-stream width
    n_embd: int = 512


In [53]:
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token and learned position embeddings are summed, passed through
    `config.n_layer` blocks and a final LayerNorm, then projected to vocabulary
    logits by `lm_head`. The token-embedding matrix and `lm_head` share one
    weight tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # weight sharing scheme: input embedding and output projection use the same tensor
        self.transformer.wte.weight = self.lm_head.weight

        # init params
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02^2); zero Linear biases.

        Linear layers tagged with `NANOGPT_SCALE_INIT` get their std scaled by
        (2 * n_layer) ** -0.5.
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits, and the cross-entropy loss when `targets` is given.

        Args:
            idx: integer token ids of shape (B, T) with T <= config.block_size.
            targets: optional token ids of shape (B, T).

        Returns:
            (logits, loss): logits has shape (B, T, vocab_size); loss is None
            when `targets` is None.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
        x = tok_emb + pos_emb  # pos_emb broadcasts over the batch dimension
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # flatten batch and time so every position is one classification example
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, device_type, verbose=None):
        """Build an AdamW optimizer with weight decay applied only to tensors with >= 2 dims.

        Args:
            weight_decay: decay coefficient for the >= 2-D parameter group.
            learning_rate: initial learning rate for both groups.
            device_type: "cuda" enables the fused AdamW kernel when this
                PyTorch build offers it.
            verbose: whether to print the parameter summary. When None, the
                module-level `master_process` flag is used if it is defined,
                otherwise the summary is printed.

        Returns:
            The configured `torch.optim.AdamW` instance.
        """
        if verbose is None:
            # Looked up lazily so the method also works before (or without)
            # the cell that defines `master_process`.
            verbose = bool(globals().get("master_process", True))

        # start with all of the candidate parameters (that require grad)
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}

        # matrices and embeddings are decayed; biases and LayerNorm parameters are not
        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
        nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]

        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if verbose:
            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        if verbose:
            print(f"using fused AdamW: {use_fused}")
        # Only pass `fused` when it is both supported and wanted: older PyTorch
        # builds reject the keyword, and leaving it unset lets AdamW choose its
        # own default implementation.
        extra_args = {'fused': True} if use_fused else {}
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, **extra_args)
        return optimizer


In [54]:
# Initialize the GPT-2 BPE tokenizer used for both data preparation and decoding
enc = tiktoken.get_encoding("gpt2")
# Id of the <|endoftext|> delimiter, read through tiktoken's public accessor
# rather than the private `_special_tokens` mapping.
eot = enc.eot_token

In [None]:


# --- Shard layout ---
local_dir = "twitter_finetune_shards"
shard_size = int(1e8)              # tokens per training shard (100M)
val_shard_size = shard_size // 10  # tokens in the single validation shard (10M)
os.makedirs(local_dir, exist_ok=True)

# --- Sentiment140 source (public dataset) ---
csv_url = "https://cs.stanford.edu/people/alecmgo/trainingandtestdata.zip"
zip_file = "trainingandtestdata.zip"
csv_file = "training.1600000.processed.noemoticon.csv"

# Fetch and unpack the archive only when the extracted CSV is missing
if os.path.exists(csv_file):
    print("Dataset CSV found, skipping download.")
else:
    print("Downloading Sentiment140 dataset (81MB)...")
    urllib.request.urlretrieve(csv_url, zip_file)
    import zipfile
    with zipfile.ZipFile(zip_file, 'r') as archive:
        # extracts into the current working directory
        archive.extractall()
    # the archive is no longer needed once unpacked
    os.remove(zip_file)



def tokenize(row):
    """Encode one CSV row's tweet (its last column) as uint16 token ids, prefixed with `eot`."""
    tweet = row[-1]
    # the end-of-text id goes first so it delimits consecutive tweets in a shard
    ids = [eot]
    ids.extend(enc.encode_ordinary(tweet))
    return np.array(ids, dtype=np.uint16)

def write_datafile(filename, tokens_np):
    """Persist the token array `tokens_np` to `filename` using NumPy's .npy format."""
    np.save(file=filename, arr=tokens_np)

# Read CSV with utf-8 and errors='replace' so undecodable bytes become U+FFFD
# instead of raising.
print("Reading CSV into memory...")
with open(csv_file, encoding='utf-8', errors='replace') as f:
    # csv.reader yields [] for blank lines; drop them because tokenize()
    # indexes row[-1] and would raise IndexError in a worker.
    rows = [row for row in csv.reader(f) if row]

print(f"Total tweets: {len(rows)}")

# Tokenize in a process pool and pack the token stream into fixed-size shards.
# The first shard (val_shard_size tokens) is the validation split; every later
# shard (shard_size tokens) is training data.
nprocs = max(1, os.cpu_count() // 2)
with mp.Pool(nprocs) as pool:
    shard_index = 0
    is_val_shard = True
    # preallocated buffer for the shard currently being filled
    all_tokens_np = np.empty((val_shard_size,), dtype=np.uint16)
    token_count = 0
    progress_bar = None

    for tokens in pool.imap(tokenize, rows, chunksize=64):
        current_shard_size = val_shard_size if is_val_shard else shard_size

        if token_count + len(tokens) < current_shard_size:
            # Append tokens to current shard buffer
            all_tokens_np[token_count:token_count+len(tokens)] = tokens
            token_count += len(tokens)
            if progress_bar is None:
                progress_bar = tqdm(total=current_shard_size, unit="tokens",
                                    desc=f"{'Val' if is_val_shard else 'Train'} Shard {shard_index}")
                # Start the bar at everything already buffered, so tokens
                # carried over from the previous shard are counted too.
                progress_bar.update(token_count)
            else:
                progress_bar.update(len(tokens))
        else:
            # Fill remainder of current shard and save it
            remainder = current_shard_size - token_count
            all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
            split = "val" if is_val_shard else "train"
            filename = os.path.join(local_dir, f"twitter_{split}_{shard_index:06d}.npy")
            write_datafile(filename, all_tokens_np)
            shard_index += 1
            if progress_bar is not None:
                # account for the tokens that completed the shard before closing
                progress_bar.update(remainder)
                progress_bar.close()
                progress_bar = None
            # leftover tokens start the next shard
            leftover = tokens[remainder:]
            if is_val_shard:
                # switch from the small validation buffer to a full-size training buffer
                is_val_shard = False
                all_tokens_np = np.empty((shard_size,), dtype=np.uint16)
            all_tokens_np[:len(leftover)] = leftover
            token_count = len(leftover)

    # Write last partial shard if any tokens remain
    if token_count != 0:
        split = "val" if is_val_shard else "train"
        filename = os.path.join(local_dir, f"twitter_{split}_{shard_index:06d}.npy")
        # slice so the unused tail of the buffer is not written
        write_datafile(filename, all_tokens_np[:token_count])
        if progress_bar is not None:
            progress_bar.close()

print(f"✅ Done! Shards saved in '{local_dir}'")


Downloading Sentiment140 dataset (81MB)...
Reading CSV into memory...
Total tweets: 1600000


Val Shard 0: 100%|█████████▉| 9999981/10000000 [00:12<00:00, 778902.57tokens/s]
Train Shard 1:  24%|██▍       | 24472289/100000000 [00:33<01:17, 969243.83tokens/s]

In [None]:
def load_tokens(filename):
    """Read a .npy token shard from `filename` and return its contents as a torch.long tensor."""
    shard = np.load(filename)
    # widen to int32 first, then let torch produce the int64 tensor
    return torch.tensor(shard.astype(np.int32), dtype=torch.long)


In [None]:
class DataLoaderLite:
    """Sequential batch reader over pre-tokenized .npy shards.

    Each process reads its own strided slice of a shard: process `r` starts at
    offset `B * T * r` and advances by `B * T * num_processes` per batch, moving
    on to the next shard (wrapping around) when the current one is nearly
    exhausted.

    Args:
        B: sequences per batch.
        T: tokens per sequence.
        process_rank: index of this process; rank 0 prints the shard summary.
        num_processes: total number of processes reading the data.
        split: 'train' or 'val'; selects files in `data_root` whose name
            contains this string.
        data_root: directory holding the shard files. Defaults to the path the
            notebook originally hard-coded.
    """

    def __init__(self, B, T, process_rank, num_processes, split,
                 data_root="/content/twitter_finetune_shards"):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}

        # sorted so every process walks the shards in the same order
        shard_names = sorted(s for s in os.listdir(data_root) if split in s)
        self.shards = [os.path.join(data_root, s) for s in shard_names]
        assert len(self.shards) > 0, f"no shards found for split {split}"
        # Log once, from rank 0. The notebook defines `master_process` as
        # `ddp_rank == 0` and passes `ddp_rank` as `process_rank`, so this is
        # the same condition without depending on a global from a later cell.
        if process_rank == 0:
            print(f"found {len(self.shards)} shards for split {split}")
        self.reset()

    def reset(self):
        """Rewind to the first shard and to this process's starting offset."""
        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        """Return `(x, y)`, each of shape (B, T), where `y` is `x` shifted one token ahead."""
        B, T = self.B, self.T
        # one extra token so the final input position still has a target
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        self.current_position += B * T * self.num_processes
        # if the next stride would run past the end of this shard, move to the next one
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_position = B * T * self.process_rank
        return x, y

In [None]:
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist


In [None]:
# A launcher such as torchrun exports RANK; its presence marks a DDP run.
ddp = int(os.environ.get('RANK', -1)) != -1
if not ddp:
    # single-process run: this process is rank 0 of 1 and does all the logging
    ddp_rank, ddp_local_rank, ddp_world_size = 0, 0, 1
    master_process = True

    # pick the best available backend: CUDA, then Apple MPS, then CPU
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"using device: {device}")
else:
    assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
    init_process_group(backend='nccl')
    ddp_rank, ddp_local_rank, ddp_world_size = (
        int(os.environ[name]) for name in ('RANK', 'LOCAL_RANK', 'WORLD_SIZE')
    )
    # bind this process to its own GPU
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    # only rank 0 logs and writes checkpoints
    master_process = ddp_rank == 0

# autocast only distinguishes "cuda" from "cpu" here; any non-CUDA device maps to "cpu"
if device.startswith("cuda"):
    device_type = "cuda"
else:
    device_type = "cpu"

# seed the CPU generator, and the CUDA one when present
torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

enc = tiktoken.get_encoding("gpt2")

# Batch geometry: B sequences of T tokens per micro-batch per process;
# total_batch_size is the number of tokens consumed per optimizer step.
total_batch_size = 512 * 8
B, T = 8, 512
assert total_batch_size % (B * T * ddp_world_size) == 0, "make sure total_batch_size is divisible by B * T * ddp_world_size"
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
if master_process:
    print(f"total desired batch size: {total_batch_size}")
    print(f"=> calculated gradient accumulation steps: {grad_accum_steps}")

train_loader = DataLoaderLite(
    B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train",
)
val_loader = DataLoaderLite(
    B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val",
)

# let PyTorch use faster, reduced-precision float32 matmul kernels where available
torch.set_float32_matmul_precision('high')

In [None]:
# vocab_size overrides the GPTConfig default of 50257
# (presumably padded to a rounder size for efficiency — confirm).
model = GPT(GPTConfig(vocab_size=50304)).to(device)

# torch.compile is opt-in; flip the flag to enable it
use_compile = False
if use_compile:
    model = torch.compile(model)

# `raw_model` always refers to the un-wrapped module (for state_dict / configure_optimizers)
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
    raw_model = model.module
else:
    raw_model = model


# Learning-rate schedule constants read by get_lr()
max_lr = 6e-4
min_lr = 0.1 * max_lr
warmup_steps = 715
max_steps = 19073

In [None]:
def get_lr(it):
    """Return the learning rate for optimizer step `it`.

    Schedule (using the module-level `max_lr`, `min_lr`, `warmup_steps`, `max_steps`):
      * it < warmup_steps  -> linear ramp, reaching max_lr on the last warmup step
      * it > max_steps     -> constant min_lr
      * otherwise          -> cosine decay from max_lr down to min_lr
    """
    in_warmup = it < warmup_steps
    if in_warmup:
        # (it+1) so that step 0 does not train with a zero learning rate
        return max_lr * (it+1) / warmup_steps

    past_schedule = it > max_steps
    if past_schedule:
        return min_lr

    # fraction of the decay phase completed, in [0, 1]
    progress = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= progress <= 1
    # cosine weight: 1 at the start of decay, 0 at the end
    cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + cosine_weight * (max_lr - min_lr)

In [None]:
# Use the schedule's peak `max_lr` as the initial learning rate instead of
# repeating the literal, so the two cannot drift apart. The training loop
# overwrites the per-group lr with get_lr(step) before every optimizer step.
optimizer = raw_model.configure_optimizers(weight_decay=0.1, learning_rate=max_lr, device_type=device_type)

In [None]:
# Colab-only: mount Google Drive at /content/drive. The training loop writes
# its checkpoint under /content/drive/MyDrive, so this must run first.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Training driver: periodic validation + checkpointing, then one optimizer step
# (with gradient accumulation) per iteration. Metrics are appended to log/log.txt.
log_dir = "log"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"log.txt")
with open(log_file, "w") as f: # open in write mode only to truncate any previous log
    pass

for step in range(max_steps):
    t0 = time.time()
    last_step = (step == max_steps - 1)

    # Every 250 steps (and on the final step) estimate the validation loss
    if step % 250 == 0 or last_step:
        model.eval()
        # rewind so every evaluation reads the same validation batches
        val_loader.reset()
        with torch.no_grad():
            val_loss_accum = 0.0
            val_loss_steps = 20
            for _ in range(val_loss_steps):
                x, y = val_loader.next_batch()
                x, y = x.to(device), y.to(device)
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(x, y)
                # dividing here makes the accumulated sum the mean over val_loss_steps batches
                loss = loss / val_loss_steps
                val_loss_accum += loss.detach()
        if ddp:
            # average the per-process means across all ranks
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if master_process:
            print(f"validation loss: {val_loss_accum.item():.4f}")
            with open(log_file, "a") as f:
                f.write(f"{step} val {val_loss_accum.item():.4f}\n")
            # NOTE(review): `step >= 0` is always true for range() values, so a
            # checkpoint of the untrained model is written at step 0 as well.
            if step >= 0 and (step % 500 == 0 or last_step):
                # Save a checkpoint; the path is fixed, so each save overwrites the previous one
                checkpoint_path = os.path.join('/content/drive/MyDrive/model.pt')
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'config': raw_model.config,
                    'step': step,
                    'val_loss': val_loss_accum.item()
                }

                torch.save(checkpoint, checkpoint_path)


    # One optimizer step, accumulating gradients over grad_accum_steps micro-batches
    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)

        if ddp:
            # intended to make DDP synchronize gradients only on the last micro-step
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # Gradients from successive backward() calls add up, which corresponds
        # to a SUM of the micro-batch losses. Dividing each loss by
        # grad_accum_steps turns that sum into the MEAN we actually want.
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        loss.backward()
    if ddp:
        # average the logged loss across ranks (reporting only)
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    # clip the global gradient norm to 1.0; `norm` is the value clip_grad_norm_ returns
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # look up this step's learning rate and apply it to every parameter group
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()
    if device_type == "cuda":
        torch.cuda.synchronize() # wait for queued GPU work so the timing below is meaningful
    t1 = time.time()
    dt = t1 - t0 # elapsed seconds for this iteration (includes any validation/checkpoint time above)
    tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
    tokens_per_sec = tokens_processed / dt
    if master_process:
        print(f"step {step:5d} | loss: {loss_accum.item():.6f} | lr {lr:.4e} | norm: {norm:.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")
        with open(log_file, "a") as f:
            f.write(f"{step} train {loss_accum.item():.6f}\n")

# tear down the process group created by init_process_group in the setup cell
if ddp:
    destroy_process_group()

In [None]:
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None, eos_token_id=None):
    """
    Sample tokens autoregressively until every sequence has emitted the
    end-of-text token or `max_new_tokens` tokens have been generated.

    Args:
        model: Callable returning `(logits, loss)` for a (B, T) batch of token
            ids, and exposing `model.config.block_size`.
        idx: Initial sequence of tokens (shape (B, T)).
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Softmax temperature for sampling (must be non-zero).
        top_k: Consider only the top_k most likely tokens.
        eos_token_id: Token id that terminates a sequence. Defaults to the
            module-level tokenizer's `<|endoftext|>` id.

    Returns:
        Token tensor of shape (B, T + n) with n <= max_new_tokens. Generation
        stops early once all B sequences have produced the end-of-text token;
        sequences that finish before the others are padded with that token.
    """
    if eos_token_id is None:
        eos_token_id = enc.eot_token
    block_size = model.config.block_size
    # per-sequence flag: has this row already produced the end-of-text token?
    finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)

    for _ in range(max_new_tokens):
        # Crop context if longer than model block size
        idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
        # Forward pass to get logits
        logits, _ = model(idx_cond)
        # Take logits for last token and scale by temperature
        logits = logits[:, -1, :] / temperature
        # Optionally filter logits to top_k tokens only
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            # everything below the k-th largest logit gets zero probability
            logits[logits < v[:, [-1]]] = -float('Inf')
        # Softmax to get probabilities
        probs = F.softmax(logits, dim=-1)
        # Sample from distribution
        idx_next = torch.multinomial(probs, num_samples=1)
        # Rows that already finished keep emitting end-of-text as padding
        idx_next = idx_next.masked_fill(finished.unsqueeze(1), eos_token_id)
        # Append sampled token
        idx = torch.cat((idx, idx_next), dim=1)
        # Stop once every row has sampled the end-of-text token. A tensor
        # reduction is used instead of .item(), which raises for batch size > 1.
        finished = finished | (idx_next.squeeze(1) == eos_token_id)
        if bool(finished.all()):
            break
    return idx






In [None]:
# Freshly initialised (untrained) model; `m` is rebound to the model restored
# from the checkpoint in the next cell.
m = GPT(GPTConfig(vocab_size=50304))


In [32]:
# Prefer the GPU, but fall back to CPU so the checkpoint can also be loaded on
# a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoint_path = '/content/drive/MyDrive/model.pt'
# SECURITY: weights_only=False unpickles arbitrary Python objects. It is needed
# here because the checkpoint stores a GPTConfig instance under 'config';
# only load checkpoint files you created yourself.
# map_location remaps tensors saved from a GPU onto the device chosen above.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
m = GPT(checkpoint['config'])
m.load_state_dict(checkpoint['model'])
m.to(device)
# inference only from here on
m.eval()
print(f"Model loaded from checkpoint saved at step {checkpoint['step']}")

Model loaded from checkpoint saved at step 0


In [46]:
# Seed generation with a single <|endoftext|> token, placed on whatever device
# the model's parameters live on (rather than assuming CUDA).
start_token = torch.tensor([[enc.eot_token]], device=next(m.parameters()).device, dtype=torch.long)

generated_tokens = generate(
    model=m,
    idx=start_token,
    max_new_tokens=100,
    temperature=1.0,
    top_k=50
)

print("Generated token IDs:", generated_tokens)

Generated token IDs: tensor([[50256, 13025, 33796, 45987,  7660, 12039, 48189,  7257, 14425,  5093,
         11555, 23209, 18802, 31419, 46849, 24778, 17119, 43870, 41993, 25259,
         46761, 18380, 22090, 19634,  9131, 40115, 14922, 14922, 12434, 46351,
         35763,  2387,  7670,  3491, 24878,  1700, 20780, 29899,  4355,  1601,
         12434, 25960, 14074, 42423, 32591, 11156, 36263, 43628, 35339, 47625,
         39863, 37538, 14922, 43048, 18542, 24532, 47982, 39744, 24651, 10839,
         11286, 46220, 48227, 24276,  7026, 41749, 27129, 44486, 29674,   253,
         25843, 30222,  3833, 36342, 44448, 31034, 12465, 46059, 10461, 42456,
          1787, 29377, 28560, 39767, 28756,  1366, 48012, 31283, 31283, 12289,
         12494, 47696, 30845, 24414, 33485,  4168, 30084, 43489,  8609, 31283,
         39245]], device='cuda:0')


In [47]:
# Decode the first (only) sequence of the batch back to text
print(enc.decode(generated_tokens[0].tolist()))

<|endoftext|> reasonablyImportant026 finger Must PSU CA denial north velmc ERAICLE felonMartin warriorsAMIdinand pitchers771Softribes glow LA HOiopiop Driver925 Hiroshiresvy star semester record bree Guysacityury Driver Cambod Ottawa unconditional 302 assemb unrestrictedencryptedpkgviation goodies Gilesiop Wouldn illustrated Eugeneowicz invoking sexist voicesagen goblins Occupations sting cheap Yusarious unilaterally Mohammad�ravel Taken pressure videogierrez (~ tremendous propell Evilhaust pot BRE 340 Chromebookoccupied data footprints boun boun mothersootingBloomberg BlackBerry 1939helm skin demographics445rix boun Paso
