In [1]:
# !pip install torchtune
# !pip install torchao
# !pip install wandb


import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import tqdm 
from dataclasses import dataclass
from torchtune.modules import RMSNorm
from tokenizers import Tokenizer
from pathlib import Path
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler 
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import wandb
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets


In [None]:
# Log in to Weights & Biases with a key fetched at runtime from the
# kaggle_secrets client, so the key itself is not written in the notebook.
import wandb

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
# "API_KEY" is the label under which the secret is looked up.
secret_value_0 = user_secrets.get_secret("API_KEY")

wandb.login(key=secret_value_0)

In [2]:
import os

def setup(rank=None, world_size=None):
    """Initialise the default torch.distributed process group with the NCCL backend.

    Args:
        rank: Accepted for call-site compatibility; not used by this function.
        world_size: Accepted for call-site compatibility; not used by this function.
    """
    # Rendezvous address/port are left to the environment:
    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '12355'
    init_process_group(backend="nccl")

def cleanup():
    """Tear down the torch.distributed process group (counterpart of ``setup``)."""
    destroy_process_group()



In [3]:
# Colab setup: create ./data and fetch the tiny-shakespeare corpus into it.
from pathlib import Path
data_path = Path('data')
data_path.mkdir(exist_ok=True)  # no error if the directory already exists
# Download input.txt into the current working directory ...
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# ... then copy it to data/input.txt, the path that is opened later in the notebook.
!cp input.txt data/input.txt


--2025-04-28 16:41:18--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-04-28 16:41:18 (2.88 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [5]:
# Load the Llama-2 tokenizer from the Hugging Face Hub.
import os

from transformers import AutoTokenizer, AutoModelForCausalLM

# SECURITY: an access token used to be hard-coded on this line. A token that has
# been committed must be treated as leaked and revoked on the Hub. The token is
# now read from the HF_TOKEN environment variable; when it is unset, `None` is
# passed and transformers falls back to any locally cached login.
hf_token = os.environ.get("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf", token=hf_token)
# Register a dedicated padding token; `prepare_dataset` pads with
# padding='max_length', which needs a pad token to be defined.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

1

In [None]:


@dataclass
class ModelArgs:
    """Central hyperparameter/configuration container for the notebook.

    Every value is read through the class itself (``ModelArgs.block_size`` etc.),
    which keeps working unchanged. The fields are now annotated so that
    ``@dataclass`` actually registers them: without annotations the decorator
    sees zero fields, so ``ModelArgs()`` had an empty ``__init__``/``repr``/``eq``.
    """

    # --- data / batching ---
    block_size: int = 256            # sequence length: tokenizer max_length and get_batch window
    batch_size: int = 64

    # --- model shape ---
    embeddings_dims: int = 512
    no_of_heads: int = 8             # IMP needs to be thoroughly calculated
    no_of_decoder_layers: int = 6    # IMP needs to be thoroughly calculated
    no_kv_heads: int = 2             # not referenced by the cells visible in this notebook
    # Tokenizer vocabulary size plus 768 extra rows.
    # NOTE(review): the reason for the +768 headroom is not stated anywhere - confirm.
    vocab_size: int = len(tokenizer.get_vocab()) + 768

    # --- regularisation ---
    attn_dropout: float = 0.1
    dropout: float = 0.1

    # --- optimisation ---
    epochs: int = 100                # not referenced by the visible training loop
    max_lr: float = 2.5e-4
    weight_decay_optim: float = 0.1  # defined here but not passed to the optimizer in train()
    beta_1: float = 0.9              # defined here but not passed to the optimizer in train()
    beta_2: float = 0.95             # defined here but not passed to the optimizer in train()

    # --- runtime ---
    device: str = 'cuda:0'

    # --- rotary embeddings / attention scaling ---
    scaling_factor: float = 0.5      # RotaryEmbeddings carries its own 0.5 default; this one is not read
    base_freq: int = 10000
    s: float = 1.0                   # multiplier applied to the attention scores in AttentionHead

    # --- mixture of experts ---
    experts: int = 16                # number of routed experts per MoeLayer
    top_experts: int = 1             # k of the router's top-k selection
    noisy_topk: bool = True          # add learned, input-dependent noise to router logits
    use_checkpointing: bool = False  # when True the noisy router is disabled

In [None]:
#Datasets

# Using tinyshakespeare: read the whole file into memory as one UTF-8 string.
# `text` is consumed by the character-level tokenization cell below.

with open('data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()


In [None]:
def save_checkpoint(model, path="checkpoint.pt"):
    """Write the model's ``state_dict`` to disk.

    Works for both wrapped models (anything exposing the real network as
    ``.module``, e.g. DistributedDataParallel) and plain ``nn.Module``s. The
    previous version always dereferenced ``model.module`` and therefore raised
    ``AttributeError`` for an unwrapped model.

    Args:
        model: The (optionally wrapped) module to serialise.
        path: Destination file. Defaults to "checkpoint.pt", the location the
            original implementation always wrote to.
    """
    # Unwrap when a wrapper is present so the saved keys carry no "module." prefix.
    unwrapped = getattr(model, "module", model)
    torch.save(unwrapped.state_dict(), path)
    print("Checkpoint saved")


In [None]:

#Subword level tokenization

#Loading custom trained BPE
# Load the tokenizer
# tokenizer = Tokenizer.from_file("data/bpe_tokenizer_tinyshakespeare_1k.json")
# vocab_size = tokenizer.get_vocab_size()
# Encode and decode functions
# encode = lambda s: tokenizer.encode(s).ids
# decode = lambda l: tokenizer.decode(l)





###############################################################################
#Character level tokenization

# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)


# Lookup tables: character -> integer id, and integer id -> character.
stoi = dict(zip(chars, range(len(chars))))
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: take a string, return the list of integer ids of its characters."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integer ids, return the corresponding string."""
    return ''.join(itos[i] for i in l)


# Encode the full corpus once, then split 90% / 10% into train / validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

# data loading
def get_batch(split):
    """Sample one random batch of (input, target) windows from the character data.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.

    Returns:
        ``(x, y)``, each stacked from ``ModelArgs.batch_size`` windows of length
        ``ModelArgs.block_size`` and moved to ``ModelArgs.device``. ``y`` is
        ``x`` shifted one position to the right in the underlying data.
    """
    source = train_data if split == 'train' else val_data
    span = ModelArgs.block_size
    # Random start offsets, leaving room for a full window plus the shifted target.
    starts = torch.randint(len(source) - span, (ModelArgs.batch_size,))
    inputs, labels = [], []
    for i in starts:
        inputs.append(source[i:i + span])
        labels.append(source[i + 1:i + span + 1])
    x = torch.stack(inputs).to(ModelArgs.device)
    y = torch.stack(labels).to(ModelArgs.device)
    return x, y

In [None]:
# Dataset selection. Exactly one corpus is loaded; TinyStories wins when both
# flags are set, which matches the if/elif precedence used by prepare_dataset.
# (Previously both branches could run and the fineweb DatasetDict would
# overwrite the TinyStories training split that prepare_dataset then consumed.)
tinystories = True
fw = False
fw_train = None
fw_test = None
if tinystories:
    fw_train = load_dataset("roneneldan/TinyStories", split="train")
    fw_test = load_dataset("roneneldan/TinyStories", split="validation")
    print(fw_train)
    print(fw_test)
elif fw:
    fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False)
    # Hold out 1% for validation; fw_train becomes a dict with 'train' and 'test' splits,
    # and fw_test stays None in this branch.
    fw_train = fw_train.train_test_split(test_size=0.01)
    print(fw_train)


In [None]:


def prepare_dataset(split, device, batch_size):
    """Build a DataLoader over the currently selected text corpus.

    The corpus is chosen by the module-level flags ``tinystories`` / ``fw``
    (TinyStories takes precedence). Each batch is tokenized on the fly.

    Args:
        split: 'train' or 'val'.
        device: Only printed for information; batches are NOT moved to it here.
        batch_size: Number of documents per batch.

    Returns:
        A ``DataLoader`` whose batches are tokenizer encodings with an extra
        "labels" entry (``input_ids`` shifted left by one position).

    Raises:
        ValueError: If ``split`` is not 'train'/'val', or if neither dataset
            flag is enabled. (Previously these cases surfaced as an
            ``UnboundLocalError`` because ``dataloader`` was initialised but
            ``data_loader`` was returned.)
    """
    print("Device is: ", device)

    def collate_fn(batch):
        """Tokenize a list of {"text": ...} items and attach next-token labels."""
        texts = [item["text"] for item in batch]

        input_encodings = tokenizer(texts, max_length=ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")

        # labels[t] = input_ids[t + 1]; the final position is set to EOS.
        # NOTE(review): padded positions keep the pad id as their label, so they
        # contribute to the loss - confirm whether they should be masked out.
        input_encodings["labels"] = input_encodings["input_ids"].clone()
        input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:]
        input_encodings["labels"][:, -1] = tokenizer.eos_token_id

        return input_encodings

    if split not in ('train', 'val'):
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    sampler = None
    if tinystories:
        dataset = fw_train if split == 'train' else fw_test
    elif fw:
        dataset = fw_train['train'] if split == 'train' else fw_train['test']
        # DistributedSampler shuffles by default, so both splits are shuffled,
        # exactly as before. It needs an initialised process group.
        sampler = DistributedSampler(dataset, shuffle=True)
    else:
        raise ValueError("No dataset selected: enable `tinystories` or `fw`.")

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collate_fn,
        drop_last=True,
        shuffle=False,
    )





    

In [None]:

# Adapted from Andrej Karpathy's GitHub
def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0):
    """Autoregressively extend ``prompt`` using top-k sampling.

    Fixes relative to the previous version:
      * ``torch.autocast`` was given ``ModelArgs.device`` ('cuda:0'); autocast
        expects a bare device *type* such as 'cuda', so the type is now derived
        from the ``device`` argument.
      * ``temperature`` was applied to the probabilities after top-k and the
        result was never used; it now scales the logits before the softmax.
      * The model is switched to eval mode while sampling (dropout off) and its
        previous training mode is restored afterwards.
      * ``top_k`` is clamped to the vocabulary size.

    Args:
        model: Callable mapping token ids (batch, seq) to logits (batch, seq, vocab).
        prompt: Text to continue.
        device: Device the prompt ids are moved to (string or ``torch.device``).
        max_length: Number of new tokens to generate.
        top_k: Number of highest-scoring tokens to sample from at each step.
        temperature: Softmax temperature; must be > 0.

    Returns:
        The decoded prompt plus continuation, with special tokens stripped.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.

    Side effects:
        Sets ``ModelArgs.inference = True`` (kept from the original).
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    ModelArgs.inference = True

    device_type = torch.device(device).type  # 'cuda:0' -> 'cuda'
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            for _ in range(max_length):
                # Only the last position's logits are needed to pick the next token.
                logits = model(input_ids)[:, -1, :] / temperature

                k = min(top_k, logits.size(-1))
                top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)

                # Softmax restricted to the k candidates; float32 for a stable multinomial.
                top_k_probs = F.softmax(top_k_logits.float(), dim=-1)
                choice = torch.multinomial(top_k_probs, num_samples=1)

                # Map the position within the top-k list back to a vocabulary id.
                next_token = torch.gather(top_k_indices, -1, choice)
                input_ids = torch.cat([input_ids, next_token], dim=1)  # dim 1 is the sequence axis
    finally:
        model.train(was_training)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


In [None]:
class Normalization(nn.Module):
    """Wrapper module around ``RMSNorm(dim=embeddings_dims)``."""

    def __init__(self, embeddings_dims: int = ModelArgs.embeddings_dims):
        super().__init__()
        self.rmsnorm_layer = RMSNorm(dim=embeddings_dims)

    def forward(self, x):
        """Return ``x`` passed through the wrapped RMSNorm layer."""
        return self.rmsnorm_layer(x)
        

In [None]:

class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``.

    ``block_size``, ``embeddings_dims`` and ``device`` are accepted but not
    used; the module holds no parameters.
    """

    def __init__(
        self,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        device = ModelArgs.device
    ):
        super().__init__()
        self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        """Apply the activation elementwise."""
        gate = self.sig(x)
        return x * gate



class SWiGLUExpertMoE(nn.Module):
    """Single gated feed-forward expert: ``W3(swish(W1 x) * (W2 x))``.

    The hidden width is twice ``embeddings_dims``; none of the three linear
    layers has a bias.
    """

    def __init__(
        self,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        device = ModelArgs.device
    ):
        super().__init__()

        self.hidden_dims = embeddings_dims * 2  #Apply this when memory permits
        self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)

        proj_kwargs = dict(bias=False, device=device)
        # Gate branch, value branch, then projection back to the model width.
        self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, **proj_kwargs)
        self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, **proj_kwargs)
        self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims, **proj_kwargs)

    def forward(self, x):
        """Return the expert's output for ``x`` (same trailing width as the input)."""
        gated = self.swish(self.linear_layer1(x)) * self.linear_layer2(x)
        return self.linear_layer3(gated)


#MoE Layer

class MoeLayer(nn.Module):
    """Mixture-of-experts block: top-k routed experts plus one always-on shared expert.

    Fixes relative to the previous version:
      * ``forward`` did ``self.shared_expert_out += ...`` on an attribute that
        was initialised to ``None``, which raises ``TypeError``. The shared
        expert is now evaluated once on all tokens and added to the routed
        output. Because the routing weights of each token sum to 1, this equals
        the intended "expert output + shared output, weighted per expert".
      * Experts are built with the layer's ``embeddings_size`` and ``device``
        instead of silently falling back to the ModelArgs defaults.
      * Router noise is sampled on the input's device rather than a stored one.

    Args:
        dropout: Accepted for interface compatibility; not used by this layer.
        embeddings_size: Model width (input and output feature size).
        device: Device on which the layer's parameters are created.
    """

    def __init__(
        self,
        dropout = ModelArgs.dropout,
        embeddings_size = ModelArgs.embeddings_dims,
        device = ModelArgs.device,
        # inner_dimensional_states: int = 3072
    ):
        super().__init__()

        self.heads = nn.ModuleList(
            [SWiGLUExpertMoE(embeddings_dims=embeddings_size, device=device) for _ in range(ModelArgs.experts)]
        )
        self.gate = nn.Linear(in_features=embeddings_size, out_features=ModelArgs.experts, device=device, bias=False)
        self.shared_expert = SWiGLUExpertMoE(embeddings_dims=embeddings_size, device=device)
        if ModelArgs.noisy_topk is True and ModelArgs.use_checkpointing == False:
            # Produces a per-token, per-expert noise scale for the router.
            self.noise = nn.Linear(in_features=embeddings_size, out_features=ModelArgs.experts, device=device, bias=False)
        self.device = device
        # Kept so existing attribute access does not break; forward() no longer writes to it.
        self.shared_expert_out = None

    def forward(self, x):
        """Route each token to its top-k experts and add the shared expert's output.

        ``x`` is expected as [batch, seq, embeddings]; the result has the same shape.
        """
        self.gate_out = self.gate(x)  # [bs, seq, num_experts]
        if ModelArgs.noisy_topk == True and ModelArgs.use_checkpointing == False:
            noise_scale = F.softplus(self.noise(x))
            noisy_router = self.gate_out + noise_scale * torch.randn_like(self.gate_out)
        else:
            noisy_router = self.gate_out

        top_k_values, top_k_indices = torch.topk(noisy_router, k=ModelArgs.top_experts)  # [bs, seq, top_k]
        probs = F.softmax(top_k_values, dim=-1)  # [bs, seq, top_k]; sums to 1 per token

        shared_out = self.shared_expert(x)

        out = torch.zeros_like(x)
        for expert_idx, expert in enumerate(self.heads):
            # True where this expert appears among a token's top-k choices.
            expert_mask = (top_k_indices == expert_idx)
            selected = expert_mask.any(dim=-1)  # [bs, seq]
            if not selected.any():
                continue

            # Routing weight this expert received for every token.
            expert_weights = (probs * expert_mask).sum(dim=-1)  # [bs, seq]

            # Run only the routed tokens through the expert, then scatter back.
            expert_out = expert(x[selected])
            out[selected] += expert_out * expert_weights[selected].unsqueeze(-1)

        return out + shared_out

In [None]:
# import numpy as np
class RotaryEmbeddings(nn.Module):
    """Rotary position embedding applied to consecutive feature pairs.

    Fixes relative to the previous version:
      * The rotated halves were stacked on ``dim=1`` and then viewed back to
        ``(batch, seq, dim)``, which interleaves values from different sequence
        positions. They are now stacked on the last axis so that every
        (even, odd) feature pair is rotated in place.
      * ``forward`` required ``base_freq`` although callers invoke the module
        with the tensor only; it now defaults to ``ModelArgs.base_freq``.
      * Index tensors are created on the input's device.

    Args:
        device: Stored on the instance for compatibility.
        embeddings_dims: Size of the last axis of the tensors to rotate (must be even).
        block_size: Stored, not otherwise used.
        batch_size: Stored, not otherwise used.
        scaling_factor: Multiplier on the feature index inside the frequency exponent.
    """

    def __init__(
        self,
         device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        batch_size: int = ModelArgs.batch_size,
        scaling_factor: float = 0.5,
    ):
        super().__init__()

        self.embeddings_dims = embeddings_dims
        self.block_size = block_size
        self.batch_size = batch_size
        self.scaling_factor = scaling_factor
        self.theta = 0
        self.device = device

    def apply_rope(self, seq, base_freq):
        """Rotate ``seq`` of shape (batch, seq_len, embeddings_dims) by position."""
        batch_size, seq_len, embeds_dims = seq.shape
        token_indices = torch.arange(0, seq_len, dtype=torch.float32, device=seq.device).unsqueeze(1)  # (seq_len, 1)
        positions = torch.arange(0, self.embeddings_dims, 2, dtype=torch.float32, device=seq.device).unsqueeze(0)  # (1, dims/2)
        # `positions` already steps by 2, so with the default scaling_factor of 0.5
        # this exponent reduces to -positions / dims. Formula kept unchanged.
        theta = base_freq ** (-2 * (positions * self.scaling_factor) / self.embeddings_dims)
        angles = token_indices * theta  # (seq_len, dims/2), broadcasts over the batch

        cos_angles = torch.cos(angles)
        sin_angles = torch.sin(angles)

        # Split the feature axis into (dims/2) pairs and apply a 2-D rotation to each.
        x_pairs = seq.reshape(batch_size, seq_len, self.embeddings_dims // 2, 2)
        x_even, x_odd = x_pairs[..., 0], x_pairs[..., 1]
        rotated = torch.stack(
            [x_even * cos_angles - x_odd * sin_angles,
             x_odd * cos_angles + x_even * sin_angles],
            dim=-1,
        )  # (batch, seq_len, dims/2, 2)
        return rotated.reshape(batch_size, seq_len, embeds_dims)

    def forward(self, x, base_freq=ModelArgs.base_freq):
        """Return ``x`` with rotary position information applied."""
        return self.apply_rope(x, base_freq=base_freq)
    
    


In [None]:

class AttentionHead(nn.Module):
    """One causal self-attention head with optional rotary position embedding.

    Fixes relative to the previous version:
      * ``torch.log(q.shape[1])`` was called on a Python int, which raises
        ``TypeError``; the length-dependent factor now uses ``math.log``.
      * The rotary module was called without its ``base_freq`` argument; it is
        now passed ``ModelArgs.base_freq`` explicitly.
      * The ``device`` argument is honoured for the projections, and the causal
        mask is built on the input's device instead of ``ModelArgs.device``.

    Args:
        attn_dropout: Dropout probability applied to the attention weights.
        embeddings_dims: Width of the incoming token representations.
        no_of_heads: Number of heads in the enclosing attention block; the head
            size is ``embeddings_dims // no_of_heads``.
        device: Device on which the projections are created.
    """

    def __init__(
        self,
        attn_dropout = ModelArgs.attn_dropout,
        embeddings_dims = ModelArgs.embeddings_dims,
        no_of_heads = ModelArgs.no_of_heads,
        device = ModelArgs.device
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.no_of_heads = no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.keys = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.values = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)

        self.dropout = nn.Dropout(p = attn_dropout)
        self.device = device

        self.rotary = RotaryEmbeddings(embeddings_dims=self.head_size, device=device)

    def forward(self, x, rope=False):
        """Attend over ``x`` (batch, seq, embeddings) and return (batch, seq, head_size).

        Args:
            x: Input activations.
            rope: When True, queries and keys are rotated by position first.
        """
        batch_size, block_size, embd_dims = x.shape

        k = self.keys(x)
        q = self.query(x)
        v = self.values(x)
        if rope:
            q = self.rotary(q, ModelArgs.base_freq)
            k = self.rotary(k, ModelArgs.base_freq)

        # Lower-triangular mask: position t may only attend to positions <= t.
        causal_mask = torch.tril(torch.ones(block_size, block_size, device=x.device))

        # Scores are scaled by 1/sqrt(head_size) and by s * log(sequence length).
        scale = ModelArgs.s * math.log(q.shape[1]) * (k.shape[-1] ** -0.5)
        weights = (q @ k.transpose(-2, -1)) * scale

        masked_values = weights.masked_fill(causal_mask == 0, float('-inf'))
        weights_normalized = F.softmax(masked_values, dim=-1)  # normalise over the key positions
        weights_normalized = self.dropout(weights_normalized)
        return weights_normalized @ v
         
# MHA




class MHA(nn.Module):
    """Multi-head attention: independent heads, concatenated and projected.

    Fixes relative to the previous version:
      * The output projection was declared with
        ``in_features = no_of_heads * embeddings_dims``, but each head emits
        ``embeddings_dims // no_of_heads`` features, so the concatenation did
        not match and the projection raised a shape error. ``in_features`` is
        now the true concatenated width.
      * A leftover debug ``print`` of the concatenated shape on every forward
        pass was removed.

    Args:
        attn_dropout: Dropout probability (inside the heads and on the output).
        embeddings_dims: Model width.
        no_of_heads: Number of attention heads.
        device: Device on which parameters are created.
    """

    def __init__(
        self,
        attn_dropout = ModelArgs.attn_dropout,
        embeddings_dims = ModelArgs.embeddings_dims,
        no_of_heads = ModelArgs.no_of_heads,
        device = ModelArgs.device
    ):
        super().__init__()
        self.no_of_heads = no_of_heads
        self.heads = nn.ModuleList([AttentionHead(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads, device=device) for _ in range(no_of_heads)])
        self.dropout = nn.Dropout(p = attn_dropout)
        # Each head outputs embeddings_dims // no_of_heads features.
        head_size = embeddings_dims // no_of_heads
        self.linear = nn.Linear(in_features=self.no_of_heads * head_size, out_features=embeddings_dims, device=device, bias=False)

    def forward(self, x, rope):
        """Run every head on ``x``, concatenate on the feature axis, project, apply dropout."""
        concat = torch.cat([head(x, rope=rope) for head in self.heads], dim=-1)
        projected = self.linear(concat)
        return self.dropout(projected)



In [None]:

class FFN(nn.Module):
    """Dense feed-forward block: Linear -> GELU -> Linear -> GELU -> Dropout.

    Both linear layers map ``embeddings_dims`` to ``embeddings_dims``.
    ``block_size`` and ``vocab_size`` are accepted but not used.
    """

    def __init__(self,
                  device,
                  embeddings_dims: int = ModelArgs.embeddings_dims,
                  block_size: int = ModelArgs.block_size,
                  vocab_size: int = ModelArgs.vocab_size,
                   dropout = ModelArgs.dropout

                 ):
        super().__init__()

        layer_kwargs = dict(dtype=torch.float32, device=device)
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, **layer_kwargs)
        self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, **layer_kwargs)

        self.dropout = nn.Dropout(p = dropout)

    def forward(self, x):
        """Apply the block to ``x``; the trailing feature size is unchanged."""
        hidden = F.gelu(self.linear_layer(x))
        hidden = F.gelu(self.linear_layer2(hidden))
        return self.dropout(hidden)


In [None]:
class DecoderLayer(nn.Module):
    """Pre-norm transformer block: attention, then either a dense FFN or an MoE block.

    Both the dense FFN and the MoE block are instantiated; ``forward`` picks one
    per call through its ``ffn`` flag.

    Fixes relative to the previous version:
      * ``device`` is now forwarded to the attention and MoE sub-modules, which
        were previously always created on ``ModelArgs.device``.
      * ``dropout`` is forwarded to the dense FFN, which used its default.
      * ``attn_dropout`` is a probability and is annotated ``float``.

    Args:
        device: Device on which sub-module parameters are created.
        attn_dropout: Dropout probability inside attention.
        no_of_heads: Number of attention heads.
        embeddings_dims: Model width.
        dropout: Dropout probability for the feed-forward parts.
        block_size: Passed through to the FFN.
        vocab_size: Passed through to the FFN.
    """

    def __init__(self,
                device,
                attn_dropout: float = ModelArgs.attn_dropout,
                no_of_heads: int = ModelArgs.no_of_heads,
                embeddings_dims: int = ModelArgs.embeddings_dims,
                dropout = ModelArgs.dropout,
                block_size: int = ModelArgs.block_size,
                vocab_size: int = ModelArgs.vocab_size,

                 ) :
        super().__init__()

        # self.base_freq = ModelArgs.base_freq
        self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device=device)
        self.mha = MHA(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads, device=device)
        self.layer_norm1 = Normalization(embeddings_dims=embeddings_dims)
        self.layer_norm2 = Normalization(embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p = dropout)

        self.moe_block = MoeLayer(dropout=dropout, embeddings_size=embeddings_dims, device=device)

    def forward(self, x, rope, ffn):
        """Apply the block.

        Args:
            x: Input activations.
            rope: Forwarded to attention; enables rotary position embedding.
            ffn: True selects the dense FFN, False selects the MoE block.
        """
        # Pre-norm residual: normalise the input, transform it, add back to x.
        x = x + self.mha(self.layer_norm1(x), rope)
        if ffn:
            x = x + self.feedforward_network(self.layer_norm2(x))
        else:
            x = x + self.moe_block(self.layer_norm2(x))

        return x


In [None]:
class Llama4Scout(nn.Module):
    """Decoder-only language model: embedding, a stack of DecoderLayers, norm, LM head.

    Layers alternate by index: even layers run with rotary embeddings and the
    dense FFN, odd layers run without rotary embeddings and with the MoE block.
    """

    def __init__(self,
                    device,
                  embeddings_dims: int = ModelArgs.embeddings_dims,
                  no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
                  block_size: int = ModelArgs.block_size,
                  vocab_size: int = ModelArgs.vocab_size,
                  dropout = ModelArgs.dropout

                 ) :
        super().__init__()
        self.base_freq = ModelArgs.base_freq
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device=device)
        self.decoder = nn.ModuleList([
            DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device=device)
            for _ in range(no_of_decoder_layers)
        ])
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device=device)
        self.dropout = nn.Dropout(p = dropout)
        self.norm = Normalization(embeddings_dims)

        #weight tying
        # self.embeddings.weight = self.linear_layer.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02); zero any Linear bias."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Map token ids ``x`` to logits over the vocabulary."""
        x = self.dropout(self.embeddings(x))
        for layer_idx, layer in enumerate(self.decoder):
            # Even index: rope + dense FFN. Odd index: no rope + MoE.
            even_layer = layer_idx % 2 == 0
            x = layer(x, rope=even_layer, ffn=even_layer)
        x = self.norm(x)
        return self.linear_layer(x)

In [None]:
# Instantiating the model

# Llama4Scout.__init__ has no `ffn` parameter (the dense/MoE choice is made per
# layer inside forward), so the stray `ffn=True` keyword, which raised a
# TypeError, has been removed.
model = Llama4Scout(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=ModelArgs.device)
model = model.to(ModelArgs.device)


In [None]:
#Printing a summary of the architecture
from torchinfo import summary
# get_batch treats every split other than 'train' as validation data, so 'test'
# yields a batch drawn from val_data (character-level ids).
idx, targets = get_batch('test')
idx = idx.to(ModelArgs.device)  # get_batch already moves the batch to ModelArgs.device
# Trace one forward pass with a real batch and print a per-layer table.
summary(model=model,
        input_data=idx,
        # input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

In [None]:
# import tqdm 
def train():
    """Single-device training loop with periodic evaluation, checkpointing and sampling.

    Builds a fresh ``Llama4Scout`` on ``ModelArgs.device``, trains it with AdamW
    for a fixed number of iterations, logs metrics to Weights & Biases,
    evaluates on the validation loader and saves "checkpoint.pt" every 500
    steps, and prints a generated sample at the same cadence.

    Fixes relative to the previous version:
      * "Total Tokens Processed" added ``len(input_ids) * batch_size`` (i.e.
        batch_size squared) per step; it now adds the real number of tokens in
        the batch.
      * Validation perplexity was computed via ``torch.tensor(<tensor>)``
        (copy-construct warning) and a tensor was logged as "val_step_loss";
        both are now plain floats.
      * The model is put in eval mode while generating the sample text and
        switched back to train mode afterwards.
      * The wandb run is closed even if training raises.
    """
    device = ModelArgs.device
    print(f"Start running training on {device}.")

    # Initialize wandb for experiment tracking
    wandb.init(
        project = 'Gemma-Training',
        # config = ModelArgs, # you can uncomment this to log model config
    )

    try:
        # Create model and move to the target device
        model = Llama4Scout(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size,
                      vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=device)
        model = model.to(device)
        print("Model loaded")

        # NOTE(review): ModelArgs also defines weight_decay_optim / beta_1 / beta_2,
        # which are not passed here, so AdamW runs with its library defaults for
        # those. Left unchanged to keep training behaviour identical - confirm intent.
        optimizer = torch.optim.AdamW(params=model.parameters(), lr=ModelArgs.max_lr)

        # Training parameters
        save_checkpoint_iter = 500
        total_iters = 25000
        eval_iters = 500

        train_epoch_iterator = tqdm.tqdm(range(total_iters), desc="Training")
        val_dataloader = prepare_dataset('val', device, ModelArgs.batch_size)
        val_iterator = iter(val_dataloader)

        def compute_loss(batch):
            """Return (cross-entropy loss, number of input tokens) for one batch."""
            input_ids = batch["input_ids"].to(device)
            targets = batch["labels"].to(device)

            logits = model(input_ids)
            batch_size, block_size, vocab = logits.shape
            # Flatten batch and sequence so every position is one classification example.
            loss = nn.functional.cross_entropy(
                logits.view(batch_size * block_size, vocab),
                targets.view(batch_size * block_size),
            )
            return loss, input_ids.numel()

        @torch.inference_mode()
        def estimate_loss():
            """Average the loss over ``eval_iters`` validation batches."""
            nonlocal val_iterator
            out = {}
            model.eval()
            for split in ['val']:
                print(f"Starting with {split} evaluation...")
                losses = torch.zeros(eval_iters)
                for k in range(eval_iters):
                    try:
                        batch = next(val_iterator)
                    except StopIteration:
                        # Validation loader exhausted: restart it.
                        val_iterator = iter(val_dataloader)
                        batch = next(val_iterator)
                    loss, _ = compute_loss(batch)
                    losses[k] = loss.item()
                out[split] = losses.mean()

            model.train()
            return out

        token_count = 0
        # Start training loop
        model.train()
        print("Lessgoo...")
        dataloader = prepare_dataset('train', device, ModelArgs.batch_size)
        train_dataloader = iter(dataloader)
        accumulated_loss = 0.0  # loss of the most recent training step
        for step in train_epoch_iterator:
            # Periodically evaluate on the validation set (and on the final step)
            if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
                losses = estimate_loss()
                val_loss = losses['val'].item()
                print(f"step {step}: train loss {accumulated_loss:.4f}, val loss {val_loss:.4f}")
                val_perplexity = torch.exp(losses['val']).item()
                wandb.log({
                    "val_perplexity": val_perplexity,
                    "val_step_loss": val_loss,
                    "step": step
                })

            # Save checkpoint periodically
            if step % save_checkpoint_iter == 0 and step != 0:
                print(f"Saving the model checkpoint for step: {step}")
                torch.save(model.state_dict(), "checkpoint.pt")
                print("Checkpoint saved")

            # Get batch for training step, restarting the loader when it is exhausted
            try:
                batch = next(train_dataloader)
            except StopIteration:
                train_dataloader = iter(dataloader)
                batch = next(train_dataloader)

            # Forward pass
            loss, tokens_in_batch = compute_loss(batch)
            token_count += tokens_in_batch

            # Backward pass
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            accumulated_loss = loss.item()
            perplexity = torch.exp(torch.tensor(accumulated_loss)).item()
            wandb.log({
                "Train_Loss": accumulated_loss,
                "Train Perplexity": perplexity,
                "Total Tokens Processed": token_count,
                "Step": step,
            })

            if step % eval_iters == 0:
                prompt = "Once upon a time "
                # Sample with dropout disabled, then resume training mode.
                model.eval()
                generated_text = topk_sampling(model, prompt, max_length=ModelArgs.block_size, top_k=50, temperature=1.0, device=device)
                model.train()

                print(f" Step: {step} | Generated Text: {generated_text}")
    finally:
        # Finish wandb run
        wandb.finish()

# Report how many CUDA devices are visible. The count is informational only:
# train() takes no arguments and runs on ModelArgs.device (no DDP is used).
world_size = torch.cuda.device_count()
print(f"CUDA devices available: {world_size}")
# Kick off the training loop.
train()