## Speed up steps

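- bf16 mixed precision  # with torch.autocast(device_type="cuda", dtype=torch.bfloat16): y = model(x)  (device_type takes "cuda", not "cuda:1")
- torch.compile  # model = torch.compile(model)
- Vocab size as a power of 2, or a multiple of a large one (e.g. pad 50257 to 50304 = 393 * 128)
- Use fused=True in AdamW  # torch.optim.AdamW(params, lr=lr, fused=True)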
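
The speed-ups listed above can be put together in one training step. This is only a minimal sketch under stated assumptions: `pad_vocab_size`, `make_optimizer`, and the two-layer `toy_model` are hypothetical stand-ins, not the notebook's GPT, and the multiple of 128 is an assumed choice. The fused optimizer is enabled only on CUDA, and `torch.compile` is left commented out because it adds compile time on the first step.

```python
import inspect

import torch
import torch.nn as nn
from torch.nn import functional as F


def pad_vocab_size(vocab_size: int, multiple: int = 128) -> int:
    """Round vocab_size up to the nearest multiple of `multiple` (hypothetical helper)."""
    return ((vocab_size + multiple - 1) // multiple) * multiple


def make_optimizer(model: nn.Module, device_type: str, lr: float = 3e-4):
    """Build AdamW, requesting the fused kernel only if this PyTorch has it and we are on CUDA."""
    fused_supported = "fused" in inspect.signature(torch.optim.AdamW).parameters
    extra = {"fused": True} if (fused_supported and device_type == "cuda") else {}
    return torch.optim.AdamW(model.parameters(), lr=lr, **extra)


device = "cuda" if torch.cuda.is_available() else "cpu"
# autocast wants the device *type*: torch.device("cuda:1").type is "cuda"
device_type = torch.device(device).type

vocab_size = pad_vocab_size(50257)  # rounds up to a multiple of 128
toy_model = nn.Sequential(
    nn.Embedding(vocab_size, 32),
    nn.Linear(32, vocab_size),
).to(device)
# toy_model = torch.compile(toy_model)  # assumption: enable once the eager version runs
optimizer = make_optimizer(toy_model, device_type)

x = torch.randint(0, vocab_size, (4, 16), device=device)  # (B, T) inputs
y = torch.randint(0, vocab_size, (4, 16), device=device)  # (B, T) targets

optimizer.zero_grad()
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    logits = toy_model(x)  # (B, T, vocab_size)
    # Cast to float32 before the loss so the reduction is done in full precision
    loss = F.cross_entropy(logits.float().view(-1, vocab_size), y.view(-1))
loss.backward()
optimizer.step()
```

A vocab size that is already a multiple of 128, such as the 32000 of the Llama tokenizer, is returned unchanged by `pad_vocab_size`.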

Gotchas:
- Apply weight decay only to embeddings and matmul weights, not to biases or LayerNorm parameters (not a speed-up, but something new)
- When using DDP, the dataloader should give each GPU a different shard of the data, not the same batches
- When using DDP, every process should create the model with the same parameters, either by setting the same seed or by loading from the same checkpoint
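
The gotchas above can be illustrated on a toy model. This is a hedged sketch, not the notebook's training code: `weight_decay_groups`, `rank_batch_starts`, and `build_model` are hypothetical helpers, the "2-D or larger gets decay" rule is one common way to separate matmul and embedding weights from biases and LayerNorm parameters, and the strided sharding scheme is only one possible way to keep ranks on disjoint data.

```python
import torch
import torch.nn as nn


def weight_decay_groups(model: nn.Module, weight_decay: float = 0.1):
    """Split parameters into decayed (>= 2-D) and non-decayed (< 2-D) groups."""
    params = [p for p in model.parameters() if p.requires_grad]
    decay = [p for p in params if p.dim() >= 2]    # embeddings, matmul weights
    no_decay = [p for p in params if p.dim() < 2]  # biases, LayerNorm scale/shift
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


def rank_batch_starts(num_tokens: int, B: int, T: int, rank: int, world_size: int):
    """Token offsets where this rank's batches start; ranks read disjoint chunks.

    Each batch needs B*T + 1 tokens (inputs plus shifted targets), hence the stop value.
    """
    chunk = B * T
    return list(range(rank * chunk, num_tokens - chunk, chunk * world_size))


def build_model(seed: int = 1337) -> nn.Module:
    """Seed before construction so every process starts from identical weights."""
    torch.manual_seed(seed)
    return nn.Sequential(nn.Embedding(10, 4), nn.LayerNorm(4), nn.Linear(4, 10))


model_a = build_model()
model_b = build_model()  # stands in for the same code running on another rank

groups = weight_decay_groups(model_a)
optimizer = torch.optim.AdamW(groups, lr=3e-4)

starts_rank0 = rank_batch_starts(100, B=2, T=4, rank=0, world_size=2)
starts_rank1 = rank_batch_starts(100, B=2, T=4, rank=1, world_size=2)
```

In a real DDP run, `rank` and `world_size` would come from the launcher's environment rather than being passed by hand.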

## Train on one batch



In [15]:
%reload_ext autoreload
import os
import math
import time
import inspect
from copy import deepcopy
from dataclasses import dataclass

import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchsummary import summary

In [16]:
# Set device  [cuda:1, mps, cpu]
if torch.cuda.is_available():
    device = "cuda:1"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

'cuda:1'

In [17]:
torch.set_float32_matmul_precision('high')

In [18]:
from transformers import LlamaTokenizerFast
enc = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") # 32000
enc

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.


LlamaTokenizerFast(name_or_path='hf-internal-testing/llama-tokenizer', vocab_size=32000, model_max_length=2048, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [None]:
model = 

In [10]:
@torch.no_grad()  # inference only, so skip building the autograd graph
def generate(model, tokens, n=10):
    """Append n tokens to each sequence using top-k (k=50) sampling."""
    tokens = deepcopy(tokens)
    for i in range(n):
        logits, _ = model(tokens.to(device))  # (B, T, vocab_size)
        # We only care about the last token
        next_token_logits = logits[:, -1, :].to("cpu")  # (B, vocab_size)
        probs = F.softmax(next_token_logits, dim=-1)  # (B, vocab_size)
        # Keep the 50 most likely tokens per sequence
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)  # (B, k)
        # Sample a position within the top-k, then map it back to a vocab id
        next_token_indices = torch.multinomial(topk_probs, num_samples=1)  # (B, 1)
        next_token = torch.gather(topk_indices, -1, next_token_indices)  # (B, 1)
        tokens = torch.cat([tokens, next_token], dim=-1)
    return [enc.decode(ts.tolist()) for ts in tokens]
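
The sampling step inside `generate` is easier to follow on a tiny example. The sketch below uses a made-up 5-token vocabulary and k=2 (both are assumptions for illustration only) to show that `torch.multinomial` returns a position within the top-k list, and that `torch.gather` is what turns that position back into a vocabulary id.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)

# One sequence (B=1) over a hypothetical 5-token vocabulary
logits = torch.tensor([[2.0, 0.5, -1.0, 3.0, 0.0]])  # (B, vocab_size)
probs = F.softmax(logits, dim=-1)

# Keep the 2 most likely tokens: vocab ids 3 and 0, in that order
topk_probs, topk_indices = torch.topk(probs, k=2, dim=-1)  # (B, k)

# multinomial treats topk_probs as unnormalized weights and returns a
# position in [0, k), not a vocab id
choice = torch.multinomial(topk_probs, num_samples=1)  # (B, 1)

# gather looks up the vocab id stored at the sampled position
next_token = torch.gather(topk_indices, -1, choice)  # (B, 1)
```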

In [11]:
generate(model, tokens)

['Hello, I\'m Indian, my name is Gopal," I said, saying as I looked',
 "Hello, I'm Indian, my name is Chandi Gupta. I am a resident of the",
 "Hello, I'm Indian, my name is Rajan Gandhi and my mother's name is Ch",
 "Hello, I'm Indian, my name is Ajay. My daughter was born on November 28",
 "Hello, I'm Indian, my name is Raj Kumar Sharma and I work for Aichal"]

In [12]:
mini_model = GPT(GPTConfig(block_size=512, vocab_size=tokenizer.n_vocab, n_layer=4, n_head=2, n_embd=128))
generate(mini_model.to(device), tokens)

["Hello, I'm Indian, my name isEconomic spells delegatessburgh Desktop681PE PSU Synd aggrav",
 "Hello, I'm Indian, my name is Celeb ranculous baconivably Spemare sidelineoxin elbows",
 "Hello, I'm Indian, my name is golden Int dorm Pai Participants Yi surprisinglyotton Candleatech",
 "Hello, I'm Indian, my name is Guarant Wy catchy Made 2015dfxmonitorAttribute boss presence",
 "Hello, I'm Indian, my name is correlation don optim Farmer Platoarro agree Jiusatotaur"]