In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from torchtune.modules import RMSNorm
from tokenizers import Tokenizer
from pathlib import Path

In [None]:
!nvidia-smi

In [None]:

@dataclass
class ModelArgs:
    #Hyperparameters

    block_size = 128
    batch_size = 16
    embeddings_dims = 256
    attn_dropout = 0.1
    no_of_heads = 4 #IMP needs to be thoroughly calculated
    dropout = 0.1
    epochs = 100
    max_lr = 1e-4
    no_of_decoder_layers = 6 #IMP needs to be thoroughly calculated
    weight_decay_optim = 0.1
    beta_1 = 0.9
    beta_2 = 0.95
    clip = 1.0
    device = 'cuda'
    no_kv_heads = 2
    vocab_size = 10000

In [None]:
#Collab setup
data_path = Path('data')
data_path.mkdir(exist_ok=True)
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
!cp input.txt data/input.txt


In [None]:
#Datasets

# Using tinyshakespeare

with open('data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()


In [None]:

#Subword level tokenization

#Loading custom trained BPE
# Load the tokenizer
tokenizer = Tokenizer.from_file("data/bpe_tokenizer_tinyshakespeare_1k.json")
vocab_size = tokenizer.get_vocab_size()
# Encode and decode functions
encode = lambda s: tokenizer.encode(s).ids
decode = lambda l: tokenizer.decode(l)

In [None]:
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - ModelArgs.block_size, (ModelArgs.batch_size,))
    x = torch.stack([data[i:i+ModelArgs.block_size] for i in ix])
    y = torch.stack([data[i+1:i+ModelArgs.block_size+1] for i in ix])
    x, y = x.to(ModelArgs.device), y.to(ModelArgs.device)
    return x, y

In [None]:
class Normalization(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):  
        super().__init__()
        self.rmsnorm_layer = RMSNorm(dim=embeddings_dims)
        
        
    def forward(self, x):
        
        x = self.rmsnorm_layer(x)
        return x
        

In [None]:
random_data.shape

In [None]:
rh = Normalization(embeddings_dims=ModelArgs.embeddings_dims)
random_data = torch.randn((ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims))
res = rh(random_data)
res.shape

In [202]:
import numpy as np
class RotaryEmbeddings(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        batch_size: int = ModelArgs.batch_size
    ):
        super().__init__()
        
        self.embeddings_dims = embeddings_dims
        self.block_size = block_size
        self.batch_size = batch_size
        self.theta = 0  

        
    # def init_matrix(self, seq_len):
    #         self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, device=ModelArgs.device, requires_grad=False)
    #         for pos in range(seq_len):
    #             for j in range(1, self.embeddings_dims // 2):
    #                 self.theta = 10000 ** (-2*(pos-1) / self.embeddings_dims)
    #                 self.matrix[pos, 2*j + 1, 2*j + 1] = np.cos((pos*self.theta))
    #                 self.matrix[pos, 2*j + 1, j + 1] = -np.sin((pos* self.theta))
    #                 self.matrix[pos, 2*j , 2*j ] = -np.cos((pos* self.theta))
    #                 self.matrix[pos, 2*j + 1, 2*j + 1] = np.sin((pos* self.theta))
    #         return self.matrix
    
    def init_matrix(self, seq_len):
        self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, device=ModelArgs.device, requires_grad=False)
        
        positions = torch.arange(seq_len, device=ModelArgs.device, dtype=torch.float32).unsqueeze(1)
        # dims = torch.arange(1, self.embeddings_dims // 2, device=ModelArgs.device, dtype=torch.float32)
        theta = 10000 ** (-2 * (positions - 1) / self.embeddings_dims)
        angles = positions * theta
        
        cos_angles = torch.cos(angles)
        sin_angles = torch.sin(angles)
        
        indices = torch.arange(self.embeddings_dims, device=ModelArgs.device, dtype=torch.int64)
        print(indices)
        # print(indices.shape)
        print(indices[::2])
        even_indices = indices[::2]
        odd_indices = indices[1::2]
        
        self.matrix[:, even_indices, even_indices] = cos_angles
        self.matrix[:, odd_indices, odd_indices] = sin_angles
        self.matrix[:, odd_indices, even_indices] = -sin_angles
        self.matrix[:, even_indices, odd_indices] = cos_angles
        
        return self.matrix

    def forward(self, x):
        # B,T,C = x.shape
        # print("MATRIX:",x)
        if(x > self.block_size):
            matrix = self.init_matrix(x)
            return matrix
        else:
            matrix = self.init_matrix(self.block_size)
            
            return matrix

In [203]:
rot = RotaryEmbeddings()
res = rot(128)
res.shape

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
        168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 1

torch.Size([128, 256, 256])

In [None]:
class RotaryAttentionHead(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        no_of_heads: int = ModelArgs.no_of_heads,
        attn_dropout: int = ModelArgs.attn_dropout
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=ModelArgs.device, bias=False, dtype=torch.float32)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=ModelArgs.device, bias=False, dtype=torch.float32)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=ModelArgs.device, bias=False, dtype=torch.float32)
        self.rotary_matrix = RotaryEmbeddings(embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p = attn_dropout)
        
    def forward(self,x):
        # print(x.shape)
        batch, block_size, embeddings_dims = x.shape
        query = self.query(x)
        # print(query)
        key = self.key(x)
        values = self.value(x)
        matrix = self.rotary_matrix(block_size)
        
        # print(matrix.shape)
        # print(query.shape)
        masked = torch.tril(torch.ones((block_size, block_size), device=ModelArgs.device, requires_grad=False))
        rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
        rotary_key = matrix @ key.permute(1,2,0)  #  (B,T, C,C  ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
        weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
        weights_masked = weights.masked_fill(masked == 0, float('-inf'))
        scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
        scaled_weights = F.softmax(scaled_weights, dim=-1)
        value = scaled_weights @ values
        out = self.dropout(value)
        return out

In [None]:
rh = RotaryAttentionHead()
random_data = torch.randn((ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims))
res = rh(random_data)
res.shape

In [None]:
class MQA(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        no_of_kv_heads: int = ModelArgs.no_of_heads,
        no_of_heads: int = ModelArgs.no_of_heads
    ):
        super().__init__()
        
        self.no_of_kv_heads = no_of_kv_heads
        self.no_of_q_heads = no_of_heads // no_of_kv_heads
        self.head_size = embeddings_dims // self.no_of_q_heads
        self.rotary_matrix = RotaryEmbeddings(embeddings_dims=embeddings_dims)
        # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=ModelArgs.device, dtype=torch.float32, bias=False)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=ModelArgs.device, dtype=torch.float32, bias=False)
        self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=ModelArgs.device, dtype=torch.float32, bias=False)
        
        
        
    def scaled_dot_product(self, q, k, v, block_size, matrix):
            
            masked = torch.tril(torch.ones((block_size, block_size), device=ModelArgs.device, requires_grad=False))
            # print("Before: ")
            # print(q.shape)
            # print(torch.transpose(q, dim0=-2, dim1=-1).shape)
            # print(matrix.shape)
            # print(k.shape)
            # print(torch.transpose(k, dim0=-2, dim1=-1).shape)
            # rotary_query = matrix @ torch.transpose(q, dim0=-2, dim1=-1)
            # rotary_key = matrix @ torch.transpose(k, dim0=-2, dim1=-1)
            # print("After: ")
            # print(q.shape)
            # print(matrix.shape)
            # print(k.shape)
            masked = torch.tril(torch.ones((block_size, block_size), device=ModelArgs.device, requires_grad=False))
            rotary_query = matrix @ q.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
            rotary_key = matrix @ k.permute(1,2,0)  #  (B,T, C,C  ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
            weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
            weights_masked = weights.masked_fill(masked == 0, float('-inf'))
            scaled_weights = weights_masked / (torch.sqrt(torch.tensor(k.shape[-1])))
            scaled_weights = F.softmax(scaled_weights, dim=-1)
            value = scaled_weights @ v
            out = self.dropout(value)
            return value
    
    def forward(self,x):
        # print("MQA: ", x.shape)
        batch, block_size, embeddings_dims = x.shape
        multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=ModelArgs.device, bias=False) for _ in range(self.no_of_q_heads)])
        # query = self.query(x)
        matrix = self.rotary_matrix(block_size)
            

        key = self.key(x)
        values = self.value(x)
        # rotary_query = matrix @ torch.transpose(, dim0=1, dim1=0)
        # rotary_key = matrix @ torch.transpose(key, dim0=1, dim1=0)
        # matrix = self.rotary_matrix(block_size)
        # self.mqa = nn.ModuleList([
           
        # ])
        multi_query_concat = torch.cat([self.scaled_dot_product(query(x), key, values, block_size, matrix) for query in multi_query], dim=-1)
        # linear_layer_query = self.linear_layer(multi_query_concat)
        # masked = torch.tril(torch.ones((block_size, block_size), device=ModelArgs.device, requires_grad=False))
        # rotary_query = matrix @ torch.transpose(query, dim0=1, dim1=0) # (B,T,C ) @ (B,T,C,C) -> (B,C,T)
        # rotary_key = matrix @ torch.transpose(key, dim0=1, dim1=0) # (B,T,C ) @ (B,T,C,C) -> (B,C,T)
        # print(multi_query_concat.shape)
        # print(key.shape)
        # print(linear_layer_query.shape)
        # weights = linear_layer_query @ (torch.transpose(key, dim0=-2, dim1=-1))
        # weights_masked = weights.masked_fill(masked == 0, float('-inf'))
        # scaled_weights = weights_masked / (key.shape[-1] ** -0.5)
        # scaled_weights = F.softmax(scaled_weights, dim=-1)
        # value = scaled_weights @ values
        
        linear_layer= self.linear_layer(multi_query_concat)
        out = self.dropout(linear_layer)
        return out

In [None]:
class GQA(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        no_of_q_heads: int = ModelArgs.no_of_heads,
        no_of_kv_heads: int = ModelArgs.no_kv_heads
    ):
        super().__init__()
        
        # self.head_size = embeddings_dims // no_of_q_heads
        # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.no_of_kv_heads = no_of_kv_heads
        self.no_of_q_heads = no_of_q_heads
        # self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        # self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_kv_heads, out_features=embeddings_dims , dtype=torch.float32, device=ModelArgs.device, bias=False)
        
    # def scaled_dot_product(self, q, k, v, block_size):
            
    #         masked = torch.tril(torch.ones((block_size, block_size), device=ModelArgs.device, requires_grad=False))
    #         weights = q @ (torch.transpose(k, dim0=-2, dim1=-1))
    #         weights_masked = weights.masked_fill(masked == 0, float('-inf'))
    #         scaled_weights = weights_masked / (k.shape[-1] ** -0.5)
    #         scaled_weights = F.softmax(scaled_weights, dim=-1)
    #         value = scaled_weights @ v
    #         return value
        
        
    def forward(self,x):
        
        batch, block_size, embeddings_dims = x.shape
        mqa = nn.ModuleList([MQA(embeddings_dims=embeddings_dims, block_size=block_size) for _ in range(self.no_of_kv_heads)])
        # query = self.query(x)
        # key = self.key(x)
        # values = self.value(x)
        # matrix = self.rotary_matrix(block_size)
        grouped_query_concat = torch.cat([group(x) for group in mqa], dim=-1)
        # linear_layer_query = self.linear_layer(multi_query_concat)
        # masked = torch.tril(torch.ones((block_size, block_size), device=ModelArgs.device, requires_grad=False))
        # rotary_query = matrix @ torch.transpose(query, dim0=1, dim1=0) # (B,T,C ) @ (B,T,C,C) -> (B,C,T)
        # rotary_key = matrix @ torch.transpose(key, dim0=1, dim1=0) # (B,T,C ) @ (B,T,C,C) -> (B,C,T)
        # print(multi_query_concat.shape)
        # print(key.shape)
        # print(linear_layer_query.shape)
        # print(grouped_query_concat.shape)     
        linear_layer= self.linear_layer(grouped_query_concat)
        out = self.dropout(linear_layer)
        return out

In [None]:

random_data = torch.randn((ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims))
gqa = GQA()
# input_data = torch.tensor()
res = gqa(random_data)
res.shape

In [None]:
masked = torch.tril(torch.ones((ModelArgs.block_size, ModelArgs.block_size), device=ModelArgs.device, requires_grad=False))
masked.shape

In [None]:
# class KVCache:
#     def __init__(
#         self,
#         embeddings_dims: int =  ModelArgs.embeddings_dims,
#         block_size: int  = ModelArgs.block_size,
#         no_of_decoder_layers: int =ModelArgs.no_of_decoder_layers
#     ):
#         super().__init__()
#         self.head_size = embeddings_dims / no_of_decoder_layers
#         self.k_cache = torch.ones((block_size, embeddings_dims, self.head_size), device=ModelArgs.device, requires_grad=False)
#         self.v_cache = torch.ones((block_size, embeddings_dims, self.head_size), device=ModelArgs.device, requires_grad=False)
#         self.block_size = block_size,
#         self.embeddings_dims = embeddings_dims
#     def update(
#         self,
#         k: torch.tensor,
#         v: torch.tensor
#     ):
#         self.k_cache[:self.block_size, :self.block_size] = k
#         self.v_cache = v
        
#     def get(self):
        

In [None]:
class Swish(nn.Module):
    def __init__(
        self,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        
        self.sig = torch.nn.Sigmoid()
        
        
    def forward(self, x):
        swish = x * self.sig(x)
        
        return swish
         

In [None]:
class SWiGLU(nn.Module):
    def __init__(
        self,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        
        self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims)
        # self.gated_layer = nn.Linear(in_features=block_size, out_features=embeddings_dims, device=ModelArgs.device, bias=False)
        self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=ModelArgs.device, bias=False, dtype=torch.float32)
        self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=ModelArgs.device, bias=False, dtype=torch.float32)
        self.linear_layer3 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=ModelArgs.device, bias=False, dtype=torch.float32)
        # self.gamma = nn.Parameter(torch.ones((block_size, embeddings_dims), device=ModelArgs.device), requires_grad=True)
        # self.beta = nn.Parameter(torch.ones((block_size, embeddings_dims), device=ModelArgs.device), requires_grad=True)
        
        
        
    def forward(self, x):
        swish_res = self.swish(self.linear_layer1(x))
        x_V = self.linear_layer2(x)
        res = torch.mul(swish_res, x_V)
        out = self.linear_layer3(res)
        return out
         

In [None]:
swiglue = SWiGLU()
res = swiglue(random_data)
res.shape

In [None]:
class FFN(nn.Module):
    def __init__(self,
                  embeddings_dims: int = ModelArgs.embeddings_dims,
                  block_size: int = ModelArgs.block_size,
                  vocab_size: int = ModelArgs.vocab_size,
                   dropout = ModelArgs.dropout
                 
                 ):
        super().__init__()
        
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=ModelArgs.device, dtype=torch.float32)
        self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p = dropout)
    def forward(self, x):
        
        x = self.swiglue(x)
        x = self.linear_layer(x)
        x = self.dropout(x)
        return x

In [None]:
class DecoderLayer(nn.Module):
    def __init__(self, 
                embeddings_dims: int = ModelArgs.embeddings_dims,
                dropout = ModelArgs.dropout,
                block_size: int = ModelArgs.block_size,
                vocab_size: int = ModelArgs.vocab_size,
                 
                 ) :
        super().__init__()
        
        
        self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size)
        self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, no_of_kv_heads=ModelArgs.no_kv_heads, no_of_q_heads=ModelArgs.no_of_heads)
        # self.norm = Normalization(embeddings_dims=embeddings_dims)
        self.norm = Normalization(embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p = dropout)
    def forward(self, x):
        
        x = self.norm(x + self.gqa(x))
        x = self.norm(x + self.feedforward_network(x))
        return x

In [None]:
class Llama(nn.Module):
    def __init__(self, 
                  embeddings_dims: int = ModelArgs.embeddings_dims,
                  no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
                  block_size: int = ModelArgs.block_size,
                  vocab_size: int = ModelArgs.vocab_size,
                  dropout = ModelArgs.dropout
                 
                 ) :
        super().__init__()
        
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, device=ModelArgs.device, dtype=torch.float32)
        self.decoder = nn.Sequential(*[DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout) for _ in range(no_of_decoder_layers)])
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, device=ModelArgs.device, dtype=torch.float32)
        self.dropout = nn.Dropout(p = dropout)
        self.norm = Normalization(embeddings_dims)
    def forward(self, x):
        x = self.embeddings(x)
        x = self.dropout(x)
        x = self.decoder(x)
        # x = self.norm(x)
        x = self.linear_layer(x)
        # out = self.norm(x)
        return x

In [None]:
# Instantiating the model
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
ModelArgs.device = device
model = Llama(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
model = model.to(ModelArgs.device)

In [None]:
device

In [None]:
idx, targets = get_batch('test')
idx.shape

In [None]:
res = model(idx)
res

In [194]:
#Printing a summary of the architecture
from torchinfo import summary
idx, targets = get_batch('test')
# idx = idx.to(device)
summary(model=model,
        input_data=idx,
        # input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
Llama (Llama)                                      [16, 128]            [16, 128, 10000]     --                   True
├─Embedding (embeddings)                           [16, 128]            [16, 128, 256]       2,560,000            True
├─Dropout (dropout)                                [16, 128, 256]       [16, 128, 256]       --                   --
├─Sequential (decoder)                             [16, 128, 256]       [16, 128, 256]       --                   True
│    └─DecoderLayer (0)                            [16, 128, 256]       [16, 128, 256]       --                   True
│    │    └─GQA (gqa)                              [16, 128, 256]       [16, 128, 256]       131,072              True
│    │    └─Normalization (norm)                   [16, 128, 256]       [16, 128, 256]       256                  True
│    │    └─FFN (feedforward_network)        

In [195]:
# Optimizer setup and scheduler steup

optimizer = torch.optim.AdamW(weight_decay=ModelArgs.weight_decay_optim, params=model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2))
# optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, weight_decay=weight_decay_optim)
# initial_iters = 2000
total_steps = 5000
eval_iters = 100
# warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=2000)
# lr_scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max= total_steps - initial_iters)
# lr_scheduler_linear = torch.optim.lr_scheduler.LinearLR(optimizer=optimizer, total_iters=initial_iters)

# @torch.inference_mode()
# def estimate_loss():
#     out = {}
#     model.eval()
#     for split in ['val']:
#         # losses = torch.zeros(eval_iters)
#         # for k in range(eval_iters):
#         idx, targets = get_batch(split=split)
#         logits = model(idx)
#         batch_size, block_size, embeddings_dims = logits.shape
#         logits = logits.view(batch_size*block_size, embeddings_dims) # Total tokens(words) => batch_size * block_size
#         targets = targets.view(batch_size * block_size)
#         loss = nn.functional.cross_entropy(logits, targets)
#         # losses[k] = loss.item()
#       # out[split] = losses.mean()
#         out[split] = loss.item()
#     model.train()
#     return out
@torch.inference_mode()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            idx, targets = get_batch(split=split)
            logits = model(idx)
            batch_size, block_size, embeddings_dims = logits.shape
            logits = logits.view(batch_size*block_size, embeddings_dims) # Total tokens(words) => batch_size * block_size
            targets = targets.view(batch_size * block_size)
            loss = nn.functional.cross_entropy(logits, targets)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [196]:
#Train the  model
from tqdm import tqdm

model.train()
for step in tqdm(range(total_steps)):

    # every once in a while evaluate the loss on train and val sets
    if (step  % eval_iters == 0 and step != 0) or step == total_steps - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        torch.save(model.state_dict(), 'weights/Llama7M_steps_%d.pth' % (step))

    idx, targets = get_batch(split='train')
    logits = model(idx)
    batch_size, block_size, embeddings_dims = logits.shape
    logits = logits.view(batch_size*block_size, embeddings_dims)
    targets = targets.view(batch_size * block_size)
    loss = nn.functional.cross_entropy(logits, targets)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # print(loss.item())
    # break

    # if step != 0 and (step % eval_iters == 0 or step == total_steps -1) :
    #     loss_values = estimate_loss()
    #     print("Train Loss at {} steps : {}".format(step, loss.item()), "Val Loss at {} steps : {}".format(step, loss_values['val']))

  2%|▏         | 99/5000 [00:05<04:17, 19.02it/s]

step 100: train loss 7.1559, val loss 7.1906


  4%|▍         | 200/5000 [00:14<04:05, 19.57it/s]

step 200: train loss 6.2226, val loss 6.3335


  6%|▌         | 298/5000 [00:24<03:53, 20.10it/s]

step 300: train loss 5.9537, val loss 6.1816


  8%|▊         | 400/5000 [00:33<03:53, 19.71it/s]

step 400: train loss 5.7894, val loss 6.1026


 10%|▉         | 498/5000 [00:43<03:52, 19.40it/s]

step 500: train loss 5.6573, val loss 5.9967


 12%|█▏        | 598/5000 [00:52<03:47, 19.32it/s]

step 600: train loss 5.5737, val loss 5.9265


 14%|█▍        | 701/5000 [01:06<47:41,  1.50it/s]

step 700: train loss 5.5019, val loss 5.9163


 16%|█▌        | 801/5000 [01:16<36:43,  1.91it/s]

step 800: train loss 5.4324, val loss 5.8586


 18%|█▊        | 900/5000 [01:21<03:30, 19.48it/s]

step 900: train loss 5.4018, val loss 5.8174


 20%|█▉        | 999/5000 [01:30<03:21, 19.87it/s]

step 1000: train loss 5.3419, val loss 5.8585


 22%|██▏       | 1100/5000 [01:40<03:18, 19.64it/s]

step 1100: train loss 5.3198, val loss 5.8444


 24%|██▍       | 1199/5000 [01:49<03:11, 19.84it/s]

step 1200: train loss 5.2736, val loss 5.7947


 26%|██▌       | 1300/5000 [01:58<03:06, 19.79it/s]

step 1300: train loss 5.2451, val loss 5.7825


 28%|██▊       | 1400/5000 [02:08<03:01, 19.87it/s]

step 1400: train loss 5.2201, val loss 5.7784


 30%|██▉       | 1499/5000 [02:17<02:58, 19.59it/s]

step 1500: train loss 5.1817, val loss 5.7966


 32%|███▏      | 1600/5000 [02:27<02:53, 19.63it/s]

step 1600: train loss 5.1709, val loss 5.7729


 34%|███▍      | 1700/5000 [02:36<02:45, 19.93it/s]

step 1700: train loss 5.1261, val loss 5.7502


 36%|███▌      | 1799/5000 [02:45<02:40, 19.90it/s]

step 1800: train loss 5.1008, val loss 5.7725


 38%|███▊      | 1900/5000 [02:56<04:47, 10.79it/s]


KeyboardInterrupt: 