In [5]:
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import inspect

In [None]:
#Layer norm
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional bias.

    Args:
        ndim: size of the normalized (last) dimension.
        bias: when True (the default, matching the original single-argument
            form) a learnable additive bias is created; when False no bias is
            used. ``Block`` passes ``config.bias`` as this argument.
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, ndim, bias=True, eps=1e-5):
        super().__init__()
        # The weight is a vector, not a matrix: it scales each feature
        # element-wise. A full weight matrix would defeat the purpose of the
        # norm and turn it into just another linear layer.
        self.weight = nn.Parameter(torch.ones(ndim))
        # F.layer_norm accepts bias=None, so "no bias" is represented as None.
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
        self.eps = eps

    def forward(self, x):
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, self.eps)
    


#causal/masked self attention head
class Head(nn.Module):
    """One causal (masked) self-attention head.

    Projects the input to queries, keys and values of width
    ``config.n_embed // config.n_head``, applies scaled dot-product attention
    under a lower-triangular mask (so position t only sees positions <= t),
    applies dropout to the attention weights, and returns the weighted values.
    """

    def __init__(self, config):
        super().__init__()
        self.head_size = config.n_embed // config.n_head
        width = config.n_embed // config.n_head
        # Creation order (k, q, v) is kept so initialization under a fixed seed is unchanged.
        self.k = nn.Linear(config.n_embed, width, bias=config.bias)
        self.q = nn.Linear(config.n_embed, width, bias=config.bias)
        self.v = nn.Linear(config.n_embed, width, bias=config.bias)
        # Causal mask: ones on and below the diagonal, sized for the longest allowed sequence.
        causal_mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("tril", causal_mask)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Unpacking a 3-tuple enforces a (B, T, C) input.
        _, seq_len, _ = x.size()
        query = self.q(x)
        key = self.k(x)
        value = self.v(x)  # (B, T, head_size)
        scale = 1.0 / math.sqrt(key.size(-1))
        scores = (query @ key.transpose(-2, -1)) * scale  # (B, T, T)
        blocked = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(blocked, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        return attn @ value  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)

class MaskedMultiAtt(nn.Module):
    """Multi-head causal self-attention.

    Runs ``config.n_head`` independent ``Head`` modules on the same input,
    concatenates their outputs along the last axis, then applies a square
    ``n_embed -> n_embed`` projection followed by dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.heads = nn.ModuleList(Head(config) for _ in range(config.n_head))
        self.proj = nn.Linear(config.n_embed, config.n_embed)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]
        # n_head outputs of width n_embed // n_head each, joined on the feature axis.
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))



#Multi-Layer-Perceptron
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Expands ``config.n_embed`` features to ``4 * config.n_embed`` hidden
    units and projects back to ``config.n_embed``.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE: 4x expansion; a 3x or smaller ratio may suit a small model better.
        hidden = 4 * config.n_embed
        self.c_fc = nn.Linear(config.n_embed, hidden)
        self.gelu = nn.GELU()
        # Must consume exactly the hidden width produced by c_fc.
        self.proj = nn.Linear(hidden, config.n_embed)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.proj(x)
        x = self.dropout(x)
        return x

#Block
class Block(nn.Module):
    """Pre-norm transformer block.

    Computes ``x + attn(ln1(x))`` and then adds ``mlp(ln2(...))`` of that
    result back onto the residual stream.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE(review): config.bias is passed as LayerNorm's second argument,
        # so LayerNorm's constructor must accept a bias flag.
        self.ln1 = LayerNorm(config.n_embed, config.bias)
        self.attn = MaskedMultiAtt(config)
        self.ln2 = LayerNorm(config.n_embed, config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = self.attn(self.ln1(x))
        residual = x + attended
        fed_forward = self.mlp(self.ln2(residual))
        return residual + fed_forward

#defining size of model and structure
@dataclass
class SLMconfig:
    """Hyper-parameters defining the size and structure of an ``SLM``.

    Fields:
        n_embed: width of the token/positional embeddings and residual stream.
        vocab_size: number of token ids (embedding rows and lm_head outputs).
        block_size: maximum context length (positional table and causal mask size).
        n_layer: number of transformer blocks.
        n_head: number of attention heads; each head has width n_embed // n_head.
        bias: whether the attention key/query/value projections and the
            LayerNorms inside ``Block`` use bias terms.
        dropout: dropout probability used throughout the model.

    Raises:
        ValueError: if n_head is not positive or does not divide n_embed
            (the concatenated heads must be exactly n_embed wide).
    """

    n_embed: int = 384
    vocab_size: int = 65
    block_size: int = 64
    n_layer: int = 4
    n_head: int = 6
    bias: bool = False
    dropout: float = 0.0

    def __post_init__(self):
        if self.n_head <= 0 or self.n_embed % self.n_head != 0:
            raise ValueError(
                f"n_embed ({self.n_embed}) must be divisible by a positive "
                f"n_head ({self.n_head})"
            )


#final architecture
class SLM(nn.Module):
    """Small decoder-only language model.

    Token embeddings plus learned positional embeddings feed
    ``config.n_layer`` ``Block``s, followed by a final ``LayerNorm`` and a
    bias-free ``lm_head`` whose weight is tied to the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wt_tok_em=nn.Embedding(config.vocab_size, config.n_embed),
            wt_pos_emb=nn.Embedding(config.block_size, config.n_embed),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            # The final norm acts on the residual stream, which is n_embed wide.
            ln_f=LayerNorm(config.n_embed),
        ))
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)
        # Weight tying: the embedding table and lm_head weight are both (vocab_size, n_embed).
        self.transformer.wt_tok_em.weight = self.lm_head.weight
        self.apply(self._init_weights)
        # Scaled init for the residual output projections (parameters named '*proj.weight').
        for pn, p in self.named_parameters():
            if pn.endswith('proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """Return the parameter count.

        With ``non_embedding=True`` the positional-embedding parameters are
        excluded. The token embedding is shared with ``lm_head`` and so stays counted.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wt_pos_emb.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token ids ``idx`` of shape (b, t).

        Returns ``(logits, loss)``. When ``targets`` is given, logits cover every
        position and ``loss`` is the cross-entropy (target -1 is ignored).
        Otherwise only the last position's logits, shape (b, 1, vocab_size),
        are computed and ``loss`` is None.
        """
        device = idx.device
        _, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device)
        tok_em = self.transformer.wt_tok_em(idx)   # (b, t, n_embed)
        pos_em = self.transformer.wt_pos_emb(pos)  # (t, n_embed), broadcast over b
        x = self.transformer.drop(tok_em + pos_em)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            logits = self.lm_head(x)  # (b, t, vocab_size)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # During inference no loss is computed; the list index [-1] keeps the time dim.
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss

    def crop_blocksize(self, block_size):
        """Shrink the maximum context length to ``block_size`` in place.

        Truncates the positional-embedding table and every head's causal mask
        (the ``tril`` buffer), and updates ``self.config.block_size``.
        """
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        pos_table = self.transformer.wt_pos_emb
        pos_table.weight = nn.Parameter(pos_table.weight[:block_size])
        for block in self.transformer.h:
            for head in block.attn.heads:
                head.tril = head.tril[:block_size, :block_size]

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx`` (b, t).

        Logits are divided by ``temperature``; when ``top_k`` is set, sampling is
        restricted to the ``top_k`` most likely tokens. Dropout stays active
        unless the caller has put the model in eval mode.
        """
        for _ in range(max_new_tokens):
            # Feed at most the last block_size tokens as context.
            if idx.size(1) <= self.config.block_size:
                idx_cond = idx
            else:
                idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            # Keep only the final time step: (b, vocab_size).
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


