In [34]:
# !pip install torchtune
# !pip install torchao
# !pip install wandb


import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from torchtune.modules import RMSNorm
from tokenizers import Tokenizer
from pathlib import Path
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler 
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import wandb

In [65]:
import os

def setup(rank=None, world_size=None):
    # os.environ['MASTER_ADDR'] = 'localhost' 
    # os.environ['MASTER_PORT'] = '12355'  
    init_process_group("nccl")

def cleanup():
    destroy_process_group()



In [35]:
#Collab setup
from pathlib import Path
data_path = Path('data')
data_path.mkdir(exist_ok=True)
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
!cp input.txt data/input.txt


--2025-04-25 16:25:53--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-04-25 16:25:54 (2.63 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [36]:

@dataclass
class ModelArgs:
    #Hyperparameters

    block_size = 256
    batch_size = 64
    embeddings_dims = 768
    attn_dropout = 0.1
    no_of_heads = 8 #IMP needs to be thoroughly calculated
    dropout = 0.1
    epochs = 100
    max_lr = 2.5e-4
    no_of_decoder_layers = 8 #IMP needs to be thoroughly calculated
    weight_decay_optim = 0.1
    beta_1 = 0.9
    beta_2 = 0.95
    device = 'cuda'
    no_kv_heads = 2
    scaling_factor = 0.5
    vocab_size = 2000
    local_block_size = 128
    base_freq=10000

In [37]:
#Datasets

# Using tinyshakespeare

with open('data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()


In [38]:
def save_checkpoint(model):
    ckp = model.module.state_dict()
    torch.save(ckp, "checkpoint.pt")
    print("Checkpoint saved")


In [39]:

#Subword level tokenization

#Loading custom trained BPE
# Load the tokenizer
# tokenizer = Tokenizer.from_file("data/bpe_tokenizer_tinyshakespeare_1k.json")
# vocab_size = tokenizer.get_vocab_size()
# Encode and decode functions
# encode = lambda s: tokenizer.encode(s).ids
# decode = lambda l: tokenizer.decode(l)





###############################################################################
#Character level tokenization

# # here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)


# create a mapping from characters to integers
stoi = { ch: i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string


# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - ModelArgs.block_size, (ModelArgs.batch_size,))
    x = torch.stack([data[i:i+ModelArgs.block_size] for i in ix])
    y = torch.stack([data[i+1:i+ModelArgs.block_size+1] for i in ix])
    x, y = x.to(ModelArgs.device), y.to(ModelArgs.device)
    return x, y

In [40]:
class Normalization(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):  
        super().__init__()
        self.rmsnorm_layer = RMSNorm(dim=embeddings_dims)
        
        
    def forward(self, x):
        
        x = self.rmsnorm_layer(x)
        return x
        

In [55]:
# import numpy as np
class RotaryEmbeddings(nn.Module):
    def __init__(
        self,
         device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        batch_size: int = ModelArgs.batch_size,
        scaling_factor: float = 0.5,
    ):
        super().__init__()

        self.embeddings_dims = embeddings_dims
        self.block_size = block_size
        self.batch_size = batch_size
        self.scaling_factor = scaling_factor
        self.theta = 0
        self.device=device

    def apply_rope(self, seq, base_freq):
        batch_size, seq_len, embeds_dims = seq.shape
        token_indices = torch.arange(0 , seq_len, dtype=torch.float32,  device = self.device).unsqueeze(1)
        positions = torch.arange(0 , self.embeddings_dims, 2, dtype=torch.float32,  device = self.device).unsqueeze(0)
        theta = base_freq ** (-2 * (positions * self.scaling_factor) / self.embeddings_dims) #Position Interpolation
        angles = token_indices * theta
        angles = angles.expand(seq_len, -1) # because this thing needs to be applied to every sequence in the batch but with embeds dims halved
        x_reshaped = seq.view(batch_size, seq_len, self.embeddings_dims // 2, 2)
        
        cos_angles = torch.cos(angles)
        sin_angles = torch.sin(angles)


        out = torch.stack([x_reshaped[..., 0]*cos_angles - (x_reshaped[...,1] * sin_angles), x_reshaped[...,1] * cos_angles + x_reshaped[..., 0] * sin_angles], dim=1)
        out = out.view(batch_size, seq_len, embeds_dims)
        return out

    def forward(self, x, base_freq):

        res = self.apply_rope(x,base_freq=base_freq)
        return res 
    
    


In [56]:
class MQA(nn.Module):
    def __init__(
        self,
        device,
        no_of_q_heads: int,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        

    ):
        super().__init__()


        # self.no_of_q_heads = no_of_heads // no_of_kv_heads
        # self.no_of_q_heads = no_of_q_heads
        self.no_of_kv_heads = 2 # I want to have a kv for each pair of query heads 
        self.head_size = embeddings_dims // no_of_q_heads
        # self.kv_head_size = (embeddings_dims // self.no_of_kv_heads) * 2
        self.rotary= RotaryEmbeddings(embeddings_dims=self.head_size,  device = device)
        # self.rotary_k = RotaryEmbeddings(embeddings_dims=self.kv_head_size,  device = device)
        # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size,  bias=False)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size,  dtype=torch.float32, bias=False,  device = device)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size,  dtype=torch.float32, bias=False,  device = device)
        self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=self.head_size * self.no_of_kv_heads, out_features=embeddings_dims,  dtype=torch.float32, bias=False,  device = device)
        self.device = device
        self.multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=self.head_size,  bias=False,  device = self.device) for _ in range(self.no_of_kv_heads)])

    def scaled_dot_product(self, q, k, v, block_size, base_freq):

            # masked = torch.tril(torch.ones((block_size, block_size),  requires_grad=False,  device = self.device))
            q = self.rotary(q, base_freq)
            masked_table = torch.tril(torch.ones((block_size, block_size),  requires_grad=False,  device = self.device))
            # rotary_query = matrix @ q.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
            # rotary_key = matrix @ k.permute(1,2,0)  #  (B,T, C,C  ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
            # print("Query: ", q.shape)
            # print("Keys: ", k.shape)
            # print(q.permute(2,0,1).shape)
            # print(k.permute(2,0,1).transpose(-2, -1).shape)
            # weights = q.permute(2,0,1) @ k.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
            # weights = q @ k.permute(2,1,0)
            # print(weights.shape)
            # print(masked.shape)
            weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5)
            masked_values = weights.masked_fill(masked_table[: block_size, : block_size] == 0, float('-inf'))
            weights_normalized = nn.functional.softmax(masked_values, dim=-1) #Normalize along the embeddings dimension for all the tokens
            weights_normalized = self.dropout(weights_normalized)
            out = weights_normalized @ v
            return out

    def forward(self,x, base_freq=10000):
        # print("MQA: ", x.shape)
        batch, block_size, embeddings_dims = x.shape

        # query = self.query(x)
        # matrix = self.rotary_matrix(block_size)


        key = self.key(x)
        values = self.value(x)
        # print("Keys: ", key.shape)
        # print("Values: ", values.shape)
        # rotary_value = self.rotary(values)
        rotary_key = self.rotary(key, base_freq)
        multi_query_concat = torch.cat([self.scaled_dot_product(query(x), rotary_key, values, block_size, base_freq) for query in self.multi_query], dim=-1)
        # print("Multi query: ", multi_query_concat.shape)

        linear_layer= self.linear_layer(multi_query_concat)
        # out = self.dropout(linear_layer)
        return linear_layer

In [57]:
class GQA(nn.Module):
    def __init__(
        self,
         device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        # no_of_q_heads: int = ModelArgs.no_of_heads,
        mqa_heads: int = ModelArgs.no_kv_heads
    ):
        super().__init__()

        # self.no_of_kv_heads = no_of_kv_heads
        self.no_of_q_heads = ModelArgs.no_of_heads // mqa_heads
        # self.head_dim = embeddings_dims // self.no_kv_heads
        self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_q_heads, out_features=embeddings_dims , dtype=torch.float32,  bias=False,  device = device)
        self.device = device
        self.mqa = nn.ModuleList([MQA(no_of_q_heads=self.no_of_q_heads, embeddings_dims=embeddings_dims, device = self.device, block_size=block_size) for _ in range(self.no_of_q_heads)])
        # self.mqa = MQA(no_of_q_heads=self.no_of_q_heads, device=self.device, embeddings_dims=embeddings_dims, block_size=block_size)
    def forward(self,x, base_freq):

        batch, block_size, embeddings_dims = x.shape

        # res = self.mqa(x)
        grouped_query_concat = torch.cat([group(x, base_freq) for group in self.mqa], dim=-1)

        linear_layer= self.linear_layer(grouped_query_concat) #Basically MQA is made into GQA with no_of_q_heads and this class right here is just to consolidate everything into one
        out = self.dropout(linear_layer)
        return out



In [58]:
class Swish(nn.Module):
    def __init__(
        self,
        device,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()

        self.sig = torch.nn.Sigmoid()


    def forward(self, x):
        swish = x * self.sig(x)

        return swish

In [59]:
class SWiGLU(nn.Module):
    def __init__(
        self,
        device,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        self.hidden_dims = int(2 * ( 4 * embeddings_dims) / 3)
        self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
        self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims,  bias=False, dtype=torch.float32,  device = device)
        self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims,  bias=False, dtype=torch.float32,  device = device)
        self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims,  bias=False, dtype=torch.float32,  device = device)




    def forward(self, x):
        swish_res = self.swish(self.linear_layer1(x))
        x_V = self.linear_layer2(x)
        res = torch.mul(swish_res, x_V)
        out = self.linear_layer3(res)
        return out

In [60]:

class FFN(nn.Module):
    def __init__(self,
                  device,
                  embeddings_dims: int = ModelArgs.embeddings_dims,
                  block_size: int = ModelArgs.block_size,
                  vocab_size: int = ModelArgs.vocab_size,
                   dropout = ModelArgs.dropout

                 ):
        super().__init__()

        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims,  dtype=torch.float32,  device = device)
        self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims,  device = device)
        # self.dropout = nn.Dropout(p = dropout)
    def forward(self, x):

        x = self.swiglue(x)
        x = self.linear_layer(x)
        # x = self.dropout(x)
        return x


In [61]:
class DecoderLayer(nn.Module):
    def __init__(self,
                device,
                embeddings_dims: int = ModelArgs.embeddings_dims,
                dropout = ModelArgs.dropout,
                block_size: int = ModelArgs.block_size,
                vocab_size: int = ModelArgs.vocab_size,

                 ) :
        super().__init__()

        # self.base_freq = ModelArgs.base_freq
        self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size,  device = device)
        self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, mqa_heads=2,  device = device)
        # self.norm = Normalization(embeddings_dims=embeddings_dims)
        self.norm1 = Normalization(embeddings_dims=embeddings_dims)
        self.norm2 = Normalization(embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p = dropout)
    def forward(self, x, base_freq):

        x = x + self.gqa(self.norm1(x), base_freq)
        x = x + self.feedforward_network(self.norm2(x))
        return x


In [62]:
class Gemma(nn.Module):
    def __init__(self,
                    device,
                  embeddings_dims: int = ModelArgs.embeddings_dims,
                  no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
                  block_size: int = ModelArgs.block_size,
                  vocab_size: int = ModelArgs.vocab_size,
                  dropout = ModelArgs.dropout

                 ) :
        super().__init__()
        self.base_freq = ModelArgs.base_freq
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims,  dtype=torch.float32,  device = device)
        self.decoder = nn.ModuleList(DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout,  device = device) for _ in range(no_of_decoder_layers))
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size,  dtype=torch.float32,  device = device)
        self.dropout = nn.Dropout(p = dropout)
        # self.norm = Normalization(embeddings_dims)
        
        
        #weight tying
        # self.embeddings.weight = self.linear_layer.weight
    
        self.apply(self._init_weights)

    def _init_weights(self, module):
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
               
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
               
                     
                    
    def forward(self, x):
        global_base_freq = 100000 
        local_base_freq = 10000
        index = 0
        no_heads = ModelArgs.no_of_heads
        x = self.embeddings(x)
        x = self.dropout(x)
        # x = self.decoder(x)
        for layer in self.decoder:
            if no_heads %5 == 0:
                x = layer(x, global_base_freq)
            else:
                x = x[:, : index + ModelArgs.local_block_size, :]
                x = layer(x, local_base_freq)
                index += ModelArgs.local_block_size
        x = self.linear_layer(x)
        # out = self.norm(x)
        return x

In [63]:
# Instantiating the model
# device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
# ModelArgs.device = device
model = Gemma(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=ModelArgs.device)
model = model.to(ModelArgs.device)

# model = DDP(model, device_ids=[gpu_ids])

In [64]:
#Printing a summary of the architecture
from torchinfo import summary
idx, targets = get_batch('test')
idx = idx.to(ModelArgs.device)
summary(model=model,
        input_data=idx,
        # input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
Gemma (Gemma)                                                [64, 256]            [64, 128, 2000]      --                   True
├─Embedding (embeddings)                                     [64, 256]            [64, 256, 768]       1,536,000            True
├─Dropout (dropout)                                          [64, 256, 768]       [64, 256, 768]       --                   --
├─ModuleList (decoder)                                       --                   --                   --                   True
│    └─DecoderLayer (0)                                      [64, 128, 768]       [64, 128, 768]       --                   True
│    │    └─Normalization (norm1)                            [64, 128, 768]       [64, 128, 768]       768                  True
│    │    └─GQA (gqa)                                        [64, 128, 768]       [64, 128, 76

In [None]:
def train():
    setup()
    device=torch.distributed.get_rank()
    
    
    # rank = torch.distributed.get_rank()
    print(f"Start running basic DDP example on rank {device}.")
    # # create model and move it to GPU with id rank
    # device_id = rank % torch.cuda.device_count()
    # CFG = ModelArgs()
    
    if(device == 0):
        # Initialise run
        wandb.init(
            # entity = 'rajceo2031',
                        project = 'Llama-DDP',
                        # config = CFG,
                        # save_code = True,
                        #group = 'ANN',
                        #job_type = 'train'
)
    
    model = Gemma(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=device)
    # Optimizer setup and scheduler steup
    torch.cuda.set_device(device)
    model = model.cuda()
        
    # Wrap model with DDP after moving to GPU
    model = DDP(model, find_unused_parameters=False)
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=ModelArgs.max_lr)
    val_loader = DataLoader(val_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler = DistributedSampler(val_dataset))
    train_loader = DataLoader(train_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler = DistributedSampler(train_dataset))
    
    
   
        
    save_chechpoint_iter = 100
    total_iters = 25000
    eval_iters = 100
    # for X,y in train_loader:
    #     print(X.shape)
    #     print(y.shape)

     # Only create progress bar for rank 0
    # eval_epoch_iterator = range(eval_iters)
    train_epoch_iterator = range(total_iters)
    if device == 0:
        train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training")
        # eval_epoch_iterator = tqdm(eval_epoch_iterator, desc='Validation')
   
    # lr_scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max= total_steps - initial_iters)
    val_iter = iter(loader)
    @torch.inference_mode()
    def estimate_loss():
        out = {}
        
        model.eval()
        loader = None
        # print("Starting the eval...")
        for split in ['train', 'val']:
            print(f"Starting with {split} evaluation...")
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                # idx, targets = get_batch(split=split)
                if(split == 'train'):
                    loader = train_loader
                else:
                    loader = val_loader
                    
                # for idx, targets in loader:
                idx, targets = next(val_iter)
                idx = idx.cuda(non_blocking=True)
                targets = targets.cuda(non_blocking=True)
                logits = model(idx)
                batch_size, block_size, embeddings_dims = logits.shape
                logits = logits.view(batch_size*block_size, embeddings_dims) # Total tokens(words) => batch_size * block_size
                targets = targets.view(batch_size * block_size)
                loss = nn.functional.cross_entropy(logits, targets)
                losses[k] = loss.item()

                # if device == 0:
                #     eval_epoch_iterator.set_postfix({"loss": f"{loss.item():.4f}"})
                    
            out[split] = losses.mean()
            

        model.train()
        return out

    # model = model.to(rank)
    model.train()
    iterator = iter(train_loader)
    # for step in tqdm(range(total_iters)):
    for step in train_epoch_iterator:
        train_loader.sampler.set_epoch(step)
        val_loader.sampler.set_epoch(step)

        # every once in a while evaluate the loss on train and val sets
        if (step  % eval_iters == 0 and step != 0) or step == total_iters - 1:
            losses = estimate_loss()
            # print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
            if device == 0:  # Only print on main process
                print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
                # Log training loss more frequently
        # if device == 0:
                wandb.log({
                    "training_step_loss": losses['train'],
                    "val_step_loss": losses['val'],
                    "step": step
                })
        if(step % save_chechpoint_iter == 0 and device == 0 and step != 0):
            print(f"Saving the model checkpoint for step: {step}")
            save_checkpoint(model)
        
        
       
        # idx, targets = get_batch(split='train')
        # print(f"Starting the train step: {step}...")
        # for idx, targets in train_loader:
        idx, targets = next(iterator)
        idx = idx.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        logits = model(idx)
        batch_size, block_size, embeddings_dims = logits.shape
        logits = logits.view(batch_size*block_size, embeddings_dims)
        targets = targets.view(batch_size * block_size)
        loss = nn.functional.cross_entropy(logits, targets)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        
        # if device == 0:

    # Cleanup
    if device == 0:
        wandb.finish()
    cleanup()

    
# %% [code] {"execution":{"iopub.execute_input":"2025-02-09T16:05:17.608500Z","iopub.status.busy":"2025-02-09T16:05:17.608220Z","iopub.status.idle":"2025-02-09T16:05:19.612034Z","shell.execute_reply":"2025-02-09T16:05:19.610871Z","shell.execute_reply.started":"2025-02-09T16:05:17.608481Z"}}
world_size = torch.cuda.device_count()
print(f"World size: {world_size}")
train()

