<a href="https://colab.research.google.com/github/Aditya-Shandilya1182/neo/blob/main/train_parallel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import tiktoken
from tiktoken import get_encoding
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from dataclasses import dataclass
from datasets import load_dataset
from tqdm import tqdm
import pickle
import numpy as np
from torch.amp import GradScaler, autocast

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, using
# the same rule as the ModelConfig.device default, so that cells which do not
# strictly require CUDA (e.g. loading a checkpoint and sampling) still run.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class MLP(nn.Module):
  """Feed-forward sub-layer: expand to 4 * n_embd, SiLU, project back, dropout."""

  def __init__(self, config):
    """Create the two bias-free linear layers and the output dropout.

    Reads ``n_embd``, ``d_type`` and ``dropout`` from ``config``.
    """
    super().__init__()
    hidden_dim = 4 * config.n_embd
    self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False, dtype=config.d_type)
    self.w2 = nn.Linear(hidden_dim, config.n_embd, bias=False, dtype=config.d_type)
    self.dropout = nn.Dropout(config.dropout)

  def forward(self, x):
    """Return ``dropout(w2(silu(w1(x))))``."""
    hidden = F.silu(self.w1(x))
    projected = self.w2(hidden)
    return self.dropout(projected)

In [None]:
class Attention(nn.Module):
    """Multi-head causal self-attention with a fused query/key/value projection.

    Reads ``n_embd``, ``n_head``, ``block_size``, ``dropout`` and ``d_type``
    from ``config``. The input and output of ``forward`` are both ``(B, T, C)``
    with ``C == n_embd``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "Embedding size must be divisible by the number of heads"
        # One matmul produces q, k and v; they are split apart in forward().
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False, dtype=config.d_type)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False, dtype=config.d_type)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # NOTE: forward() does not read this mask (causality comes from
        # is_causal=True below). The buffer stays registered so that the
        # module's state_dict keys remain unchanged.
        self.register_buffer(
            "tril",
            torch.tril(torch.ones(config.block_size, config.block_size, dtype=torch.bool))
        )
        self.att_dropout = nn.Dropout(config.dropout)
        self.dropout = config.dropout

    def forward(self, x):
        """Apply causal self-attention to ``x`` and return a tensor of the same shape."""
        B, T, C = x.size()
        head_dim = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)
        # The functional attention kernel applies dropout whenever dropout_p > 0;
        # unlike nn.Dropout it does not know about train/eval mode, so the
        # probability must be zeroed explicitly outside of training.
        dropout_p = self.dropout if self.training else 0.0
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True)
        # (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        y = self.att_dropout(y)
        return y

In [None]:
class Block(nn.Module):
  """One transformer layer: RMSNorm -> attention and RMSNorm -> MLP, each added back residually."""

  def __init__(self, config):
    """Build the attention and feed-forward sub-layers plus one RMSNorm for each."""
    super().__init__()
    self.attention = Attention(config)
    self.feed_forward = MLP(config)
    norm_kwargs = dict(dtype=config.d_type)
    self.attention_norm = nn.RMSNorm(config.n_embd, **norm_kwargs)
    self.ffn_norm = nn.RMSNorm(config.n_embd, **norm_kwargs)

  def forward(self, x):
    """Return ``x`` after the attention and feed-forward residual branches."""
    attn_out = self.attention(self.attention_norm(x))
    x = x + attn_out
    ffn_out = self.feed_forward(self.ffn_norm(x))
    return x + ffn_out

In [None]:
class Neo(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd, dtype=config.d_type)
        self.position_embedding = nn.Embedding(config.block_size, config.n_embd, dtype=config.d_type)
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.RMSNorm(config.n_embd, dtype=config.d_type)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, dtype=config.d_type)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, index, targets=None):
        B, T = index.shape

        tok_emb = self.token_embedding(index)
        pos_emb = self.position_embedding(torch.arange(T, device=self.config.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    def generate(self, index, max_new_tokens):
        for _ in range(max_new_tokens):
            b_s = self.config.block_size
            index_cond = index[:, -b_s:]
            logits, loss = self.forward(index_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            index_next = torch.multinomial(probs, num_samples=1)
            index = torch.cat((index, index_next), dim=1)
        return index

In [None]:
@dataclass
class ModelConfig:
    """Hyperparameters read by the model classes and the training cells."""

    # Sequences per training/evaluation batch.
    batch_size: int = 8
    # Context length; also the size of the learned position-embedding table.
    block_size: int = 512
    # Number of iterations of the training loop.
    max_iters: int = 4_000
    # AdamW learning rate.
    learning_rate: float = 0.0001
    # Used by the training cells both as the evaluation interval and as the
    # number of batches averaged per split in estimate_loss().
    eval_iters: int = 500
    # Embedding / residual width.
    n_embd: int = 512
    # Attention heads; n_embd must be divisible by this.
    n_head: int = 8
    # Number of stacked transformer blocks.
    n_layer: int = 12
    # NOTE(review): not read by the model classes visible in this notebook
    # (Attention derives the per-head width as n_embd // n_head).
    head_size: int = 512
    # dtype passed to every layer constructor.
    d_type: torch.dtype = torch.float32
    # Size of the token vocabulary (matches the "gpt2" tiktoken encoding used below).
    vocab_size: int = 50_257
    # Dropout probability shared by attention and MLP.
    dropout: float = 0.1
    # Evaluated once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def _encode_split(split):
    """Tokenize every sample's 'text' in ``split``.

    Returns the per-sample token-id lists and a single 1-D ``torch.long``
    tensor holding all of them concatenated in dataset order.
    """
    chunks = [tokenizer.encode(sample['text']) for sample in split]
    flat_tokens = []
    for chunk in chunks:
        flat_tokens.extend(chunk)
    return chunks, torch.tensor(flat_tokens, dtype=torch.long)


# Download TinyStories and tokenize both splits with the GPT-2 BPE vocabulary.
dataset = load_dataset("roneneldan/TinyStories")
tokenizer = get_encoding("gpt2")
encoded_train_chunks, train_encoded = _encode_split(dataset['train'])
encoded_val_chunks, val_encoded = _encode_split(dataset['validation'])
# Report the total number of tokens in each split.
print(len(train_encoded))
print(len(val_encoded))

In [None]:
# Tensor.to() returns a new tensor instead of moving the receiver in place, so
# the result must be bound back to the name; otherwise the copy is made and
# immediately discarded while the original stays where it was.
train_encoded = train_encoded.to(device)
val_encoded = val_encoded.to(device)

In [None]:
# Instantiate the hyperparameters with their dataclass defaults; the training
# cell below reads its sizes and learning rate from this object.
config = ModelConfig()

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the train and validation splits.

    Reads the module-level ``model``, ``train_encoded``, ``val_encoded``,
    ``block_size``, ``batch_size``, ``eval_iters`` and ``device``. For each
    split it draws ``eval_iters`` random batches of ``batch_size`` windows of
    ``block_size`` tokens and averages the resulting losses.

    Returns:
        dict with keys ``'train'`` and ``'val'`` mapping to the mean loss (float).

    Raises:
        ValueError: if a split has no more than ``block_size`` tokens.
    """
    out = {}
    model.eval()
    try:
        for split in ['train', 'val']:
            data = train_encoded if split == 'train' else val_encoded

            if data.size(0) <= block_size:
                raise ValueError(f"{split.capitalize()} dataset size is too small for the requested block size.")

            losses = torch.zeros(eval_iters)

            for k in range(eval_iters):
                # Window starts are bounded so the shifted target window
                # (start+1 .. start+block_size) still fits inside the data.
                ix = torch.randint(0, data.size(0) - block_size, (batch_size,))
                x = torch.stack([data[i:i+block_size] for i in ix])
                y = torch.stack([data[i+1:i+block_size+1] for i in ix])

                x, y = x.to(device), y.to(device)
                logits, loss = model(x, y)
                # .mean() collapses a per-replica loss vector (if any) to a scalar.
                losses[k] = loss.mean().item()

            out[split] = losses.mean().item()
    finally:
        # Always hand the model back in training mode, even when evaluation
        # fails part-way, so a later training step does not silently run
        # with dropout disabled.
        model.train()
    return out

In [None]:
# Token counts of the two splits; train_len bounds the random window starts
# drawn in the training loop.
train_len = train_encoded.shape[0]
val_len = val_encoded.shape[0]

In [None]:
# Training hyperparameters. estimate_loss() reads eval_iters, block_size and
# batch_size as module-level globals, so they must be defined before the loop.
max_iters = config.max_iters
gradient_accumulation_steps = 8  # micro-batches accumulated per optimizer step
eval_iters = config.eval_iters
block_size = config.block_size
batch_size = config.batch_size
max_grad_norm = 1.0  # global gradient-norm clipping threshold

model = Neo(config)
# Leaving device_ids unset makes DataParallel use every visible GPU instead of
# assuming that exactly devices 0 and 1 exist.
model = nn.DataParallel(model)
model = model.to(device)
model = torch.compile(model)


print(f"Using devices: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}")
print(next(model.parameters()).device)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
# Loss scaling for the float16 autocast region below.
scaler = GradScaler()

train_losses = []  # one entry per optimizer step
val_losses = []    # one entry per evaluation

# The loop variable is named `step` so the builtin iter() is not shadowed.
for step in range(max_iters):
    print(step)

    if step % eval_iters == 0:
        losses = estimate_loss()
        val_losses.append(losses['val'])
        print(f"step: {step}, train loss: {losses['train']:.3f}, val loss: {losses['val']:.3f}")

    # Sample batch_size random windows; targets are the inputs shifted by one token.
    ix = torch.randint(train_len - block_size, (batch_size,))
    x = torch.stack([train_encoded[i:i+block_size] for i in ix])
    y = torch.stack([train_encoded[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)

    with autocast(device_type='cuda', dtype=torch.float16):
        logits, loss = model(x, y)
        # Scale down so the accumulated gradient is an average over micro-batches.
        loss = loss / gradient_accumulation_steps

    # DataParallel may return one loss per replica; reduce to a scalar.
    loss_mean = loss.mean()
    scaler.scale(loss_mean).backward()

    if (step + 1) % gradient_accumulation_steps == 0:
        # Gradients must be unscaled before clipping so the threshold applies
        # to their true magnitude.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        # Undo the accumulation scaling to report the last micro-batch's loss.
        step_loss = loss_mean.item() * gradient_accumulation_steps
        print(f"Loss at step {step + 1}: {step_loss:.3f}")
        train_losses.append(step_loss)


In [None]:
# Pickle the whole `model` object (not just a state_dict) into the working directory.
with open('model-dp-mp.pkl', mode='wb') as checkpoint_file:
    pickle.Pickler(checkpoint_file).dump(model)
print('model saved')

In [None]:
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load checkpoints that you produced yourself or otherwise trust.
with open("/kaggle/input/tiny_test_dp_mp/pytorch/default/1/model-dp-mp.pkl", "rb") as f:
    model = pickle.load(f)

model = model.to(device)
model.eval()

prompt = 'Hello! Can you see me?'
context = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)  # (1, T)

# generate() is defined on the underlying network; if the loaded object is a
# wrapper exposing it as `.module`, unwrap it, otherwise use the object itself.
model_to_use = getattr(model, 'module', model)

# Sampling never back-propagates, so skip autograd bookkeeping to save memory.
with torch.no_grad():
    generated = model_to_use.generate(context, max_new_tokens=100)

generated_chars = tokenizer.decode(generated[0].tolist())
print(generated_chars)


Hello! Can you see me?Tim and I'm sorry. supply you did not changing a key to takeack. He wants angrily. She does not like the one. He hopes the slide some more careful."

The net was the place, and low, feeling a picture of the blue and Tom and a big suspicion together. He realized that he is alone.

The unavoid, the bear was thoughtful and started to his favorite was so to believe she was not tricked her tasty. She was so junk and great game
