In [1]:
import torch

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

@dataclass
class ModelArgs :
    """Typed container for three integer model hyperparameters.

    None of the code visible in this notebook instantiates or reads this class,
    so the meaning of each field below is inferred from its name only.
    """
    # Number of layers -- presumably stacked transformer blocks; TODO confirm against the caller.
    n_layers : int
    # Hidden size -- presumably the model/embedding width; TODO confirm against the caller.
    hidden_dim : int
    # Per-head size -- presumably the attention head width; TODO confirm against the caller.
    head_dim : int
    

class SimpleSelfAttention(nn.Module):
    """Single-head, unmasked scaled dot-product self-attention.

    Args:
        hidden_dim: Size of the last dimension of the input; the query, key,
            value and output projections all map ``hidden_dim -> hidden_dim``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Created in q, k, v, o order so seeded weight initialisation is stable.
        self.wq = nn.Linear(hidden_dim, hidden_dim)
        self.wk = nn.Linear(hidden_dim, hidden_dim)
        self.wv = nn.Linear(hidden_dim, hidden_dim)
        self.wo = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Attend every position to every other position.

        Args:
            x: Tensor of shape ``(B, S, hidden_dim)``.

        Returns:
            Tensor of shape ``(B, S, hidden_dim)``.
        """
        queries = self.wq(x)  # (B, S, d)
        keys = self.wk(x)     # (B, S, d)
        values = self.wv(x)   # (B, S, d)

        # Pairwise similarities, scaled by sqrt(d) before normalisation.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.hidden_dim ** 0.5)
        weights = torch.softmax(scores, dim=-1)  # (B, S, S), rows sum to 1

        context = torch.matmul(weights, values)  # (B, S, d)
        return self.wo(context)

class MultiHeadSelfAttention(nn.Module) :
    """Unmasked multi-head scaled dot-product self-attention.

    Args:
        hidden_dim: Size of the last dimension of the input and of the output.
        n_heads: Number of attention heads.
        head_dim: Width of each head. ``n_heads * head_dim`` does not have to
            equal ``hidden_dim``; the output projection maps back to
            ``hidden_dim``.
    """

    def __init__(self,hidden_dim,n_heads, head_dim) :
        super().__init__()

        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = head_dim

        inner_dim = n_heads * head_dim
        self.wq = nn.Linear(hidden_dim, inner_dim)
        self.wk = nn.Linear(hidden_dim, inner_dim)
        self.wv = nn.Linear(hidden_dim, inner_dim)

        self.wo = nn.Linear(inner_dim, hidden_dim)

    def forward(self,x) :
        """Apply self-attention.

        Args:
            x: Tensor of shape ``(B, S, hidden_dim)``.

        Returns:
            Tensor of shape ``(B, S, hidden_dim)``.
        """
        B, S, _ = x.shape

        xq = self.wq(x).view(B, S, self.n_heads, self.head_dim).transpose(1,2) # (B, n_heads, S, head_dim)
        xk = self.wk(x).view(B, S, self.n_heads, self.head_dim).transpose(1,2)
        xv = self.wv(x).view(B, S, self.n_heads, self.head_dim).transpose(1,2)

        # Each dot product sums over head_dim terms, so the logits are scaled by
        # sqrt(head_dim) (not sqrt(hidden_dim)).
        attn = (xq @ xk.transpose(-1,-2))/(self.head_dim ** 0.5) # (B, n_heads, S, S)

        attn_softmax = F.softmax(attn,dim=-1) # (B, n_heads, S, S)

        out = attn_softmax @ xv # (B, n_heads, S, head_dim)

        # Merge heads into the width that `wo` consumes, which is
        # n_heads * head_dim and may differ from hidden_dim.
        out = out.transpose(1,2).contiguous().view(B, S, self.n_heads * self.head_dim)

        return self.wo(out)

class MaskedMultiHeadSelfAttention(nn.Module) :
    """Causal (masked) multi-head scaled dot-product self-attention.

    Position ``i`` may only attend to positions ``j <= i``.

    Args:
        hidden_dim: Size of the last dimension of the input and of the output.
        n_heads: Number of attention heads.
        head_dim: Width of each head. ``n_heads * head_dim`` does not have to
            equal ``hidden_dim``; the output projection maps back to
            ``hidden_dim``.
    """

    def __init__(self,hidden_dim,n_heads, head_dim) :
        super().__init__()

        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = head_dim

        inner_dim = n_heads * head_dim
        self.wq = nn.Linear(hidden_dim, inner_dim)
        self.wk = nn.Linear(hidden_dim, inner_dim)
        self.wv = nn.Linear(hidden_dim, inner_dim)

        self.wo = nn.Linear(inner_dim, hidden_dim)

    def forward(self,x) :
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(B, S, hidden_dim)``.

        Returns:
            Tensor of shape ``(B, S, hidden_dim)``.
        """
        B, S, _ = x.shape

        xq = self.wq(x).view(B, S, self.n_heads, self.head_dim).transpose(1,2) # (B, n_heads, S, head_dim)
        xk = self.wk(x).view(B, S, self.n_heads, self.head_dim).transpose(1,2)
        xv = self.wv(x).view(B, S, self.n_heads, self.head_dim).transpose(1,2)

        attn = (xq @ xk.transpose(-1,-2))/(self.head_dim ** 0.5) # (B, n_heads, S, S)

        # True strictly above the diagonal marks "future" keys; the (S, S) mask
        # broadcasts over the batch and head dimensions.
        future = torch.ones(S, S, dtype=torch.bool, device=x.device).triu(diagonal=1)
        attn = attn.masked_fill(future, float("-inf"))

        attn_softmax = F.softmax(attn,dim=-1) # (B, n_heads, S, S)

        out = attn_softmax @ xv # (B, n_heads, S, head_dim)

        # Merge heads into the width that `wo` consumes, which is
        # n_heads * head_dim and may differ from hidden_dim.
        out = out.transpose(1,2).contiguous().view(B, S, self.n_heads * self.head_dim)

        return self.wo(out)

class GroupedQueryAttention(nn.Module):
    """Causal grouped-query attention: several query heads share one K/V head.

    Args:
        hidden_dim: Size of the last dimension of the input and of the output.
        n_q_heads: Number of query heads.
        n_kv_heads: Number of key/value heads; must divide ``n_q_heads``.
        head_dim: Width of every head.
    """

    def __init__(self, hidden_dim, n_q_heads, n_kv_heads, head_dim):
        super().__init__()

        assert n_q_heads % n_kv_heads == 0
        self.hidden_dim = hidden_dim
        self.n_q_heads = n_q_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        # How many query heads read from each key/value head.
        self.group_size = n_q_heads // n_kv_heads

        q_width = n_q_heads * head_dim
        kv_width = n_kv_heads * head_dim
        self.wq = nn.Linear(hidden_dim, q_width)
        self.wk = nn.Linear(hidden_dim, kv_width)
        self.wv = nn.Linear(hidden_dim, kv_width)

        self.wo = nn.Linear(q_width, hidden_dim)

    def _split_heads(self, projected, n_heads):
        """Reshape ``(B, S, n_heads * head_dim)`` into ``(B, n_heads, S, head_dim)``."""
        batch, seq_len, _ = projected.shape
        return projected.view(batch, seq_len, n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal grouped-query attention.

        Args:
            x: Tensor of shape ``(B, S, hidden_dim)``.

        Returns:
            Tensor of shape ``(B, S, hidden_dim)``.
        """
        batch, seq_len, _ = x.shape

        queries = self._split_heads(self.wq(x), self.n_q_heads)
        keys = self._split_heads(self.wk(x), self.n_kv_heads)
        values = self._split_heads(self.wv(x), self.n_kv_heads)

        # Duplicate each K/V head so it lines up with its group of query heads.
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Entries strictly above the diagonal are future positions.
        future = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        scores = scores.masked_fill(future, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(weights, values)

        merged = context.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(merged)


class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``wo(silu(w1(x)) * w2(x))``.

    Args:
        hidden_dim: Size of the last dimension of the input and of the output.
        intermediate_dim: Width of the two parallel inner projections.
    """

    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()

        self.w1 = nn.Linear(hidden_dim, intermediate_dim)  # gate branch
        self.w2 = nn.Linear(hidden_dim, intermediate_dim)  # linear branch

        self.wo = nn.Linear(intermediate_dim, hidden_dim)

    def forward(self, x):
        """Map ``x`` of shape ``(B, S, hidden_dim)`` to a tensor of the same shape."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.wo(gate * value)

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Computes ``alpha * x / (rms(x) + epsilon)``, where ``rms`` is the square
    root of the mean of squares along the last dimension. Note that
    ``epsilon`` is added to the RMS itself, outside the square root.

    Args:
        epsilon: Constant added to the denominator.
        hidden_dim: Size of the last dimension; ``alpha`` has this many entries
            and starts at one.
    """

    def __init__(self, epsilon, hidden_dim):
        super().__init__()
        self.epsilon = epsilon
        self.alpha = nn.Parameter(torch.ones(hidden_dim))

    def forward(self, x):
        """Normalise ``x`` along its last dimension; the shape is unchanged."""
        mean_square = torch.square(x).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square)
        normalised = x / (rms + self.epsilon)
        return self.alpha * normalised

class TransformerBlock(nn.Module):
    """Pre-norm decoder block: causal attention then SwiGLU, each with a residual.

    Args:
        hidden_dim: Width of the residual stream.
        intermediate_dim: Inner width of the SwiGLU feed-forward layer.
        n_heads: Number of attention heads.
        head_dim: Width of each attention head.
        epsilon: Constant passed to both RMSNorm layers.
    """

    def __init__(self, hidden_dim, intermediate_dim, n_heads, head_dim, epsilon):
        super().__init__()

        # Sub-modules are registered in the same order as before so parameter
        # ordering and seeded initialisation do not change.
        self.masked_multihead_attention = MaskedMultiHeadSelfAttention(hidden_dim, n_heads, head_dim)
        self.swiglu = SwiGLU(hidden_dim, intermediate_dim)

        self.norm_1 = RMSNorm(epsilon, hidden_dim)
        self.norm_2 = RMSNorm(epsilon, hidden_dim)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.masked_multihead_attention(self.norm_1(x))
        after_attention = x + attended

        fed_forward = self.swiglu(self.norm_2(after_attention))
        return after_attention + fed_forward
    
class Transformer(nn.Module):
    """Token embedding, a stack of ``TransformerBlock`` layers, and a vocabulary head.

    Args:
        hidden_dim: Embedding width and residual-stream width.
        intermediate_dim: Inner width of each block's feed-forward layer.
        n_heads: Number of attention heads per block.
        head_dim: Width of each attention head.
        epsilon: Constant passed to the RMSNorm layers inside each block.
        n_layers: Number of stacked blocks.
        vocab_size: Number of embedding rows and of output logits per position.
    """

    def __init__(self, hidden_dim, intermediate_dim, n_heads, head_dim, epsilon, n_layers, vocab_size):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, hidden_dim)

        blocks = [
            TransformerBlock(hidden_dim, intermediate_dim, n_heads, head_dim, epsilon)
            for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(blocks)

        self.classifier = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map integer token ids to per-position logits over the vocabulary.

        Args:
            x: Integer tensor of token ids; the training code in this notebook
                passes shape ``(batch, sequence)``.

        Returns:
            Tensor shaped like ``x`` with a trailing ``vocab_size`` dimension.
        """
        hidden = self.embedding(x)

        for block in self.layers:
            hidden = block(hidden)

        return self.classifier(hidden)



# --- Smoke test: a couple of optimisation steps on random token ids ---------

# Hyperparameters for this run.
batch_size = 32
sequence_len = 24
epochs = 2
learning_rate = 1e-3  # identical value to the former spelling `10e-4`
vocab_size = 10000

model = Transformer(
    hidden_dim=128,
    intermediate_dim=256,
    n_heads=4,
    head_dim=32,
    epsilon=1e-8,
    n_layers=3,
    vocab_size=vocab_size,
)
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# Random inputs and unrelated random targets: this only checks that the
# forward/backward pass runs, not that the model can learn anything.
x = torch.randint(low=0, high=vocab_size, size=(batch_size, sequence_len))
y = torch.randint(low=0, high=vocab_size, size=(batch_size, sequence_len))

for epoch in range(epochs):
    optimizer.zero_grad()

    y_pred = model(x)
    # Flatten (batch, sequence) so every position is one classification example.
    logits = y_pred.view(-1, vocab_size)
    targets = y.view(-1).long()
    loss = loss_fn(logits, targets)

    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}/{epochs} - loss : {loss.item()}")


Epoch 0/2 - loss : 9.329459190368652
Epoch 1/2 - loss : 9.200955390930176


In [None]:
from torch.utils.data import DataLoader, Dataset
import tiktoken

# GPT-2 byte-pair encoder from tiktoken.
tokenizer = tiktoken.get_encoding("gpt2")

# Decode with an explicit encoding: without one, open() falls back to the
# platform locale, so the same file could decode differently (or fail) on
# another machine. NOTE(review): assumes the corpus is UTF-8 -- confirm.
with open("path.txt", "r", encoding="utf-8") as f :
    text = f.read()

# Whole corpus as one flat list of integer token ids.
tokens = tokenizer.encode(text)

class TrainingDataset(Dataset) :
    """Sliding-window next-token dataset over a flat sequence of token ids.

    Item ``idx`` is the pair ``(x, y)`` where ``x`` is
    ``tokens[idx : idx + sequence_len]`` and ``y`` is the same window shifted
    one position to the right, so ``y[i]`` is the token that follows ``x[i]``.

    Args:
        tokens: Sequence of integer token ids supporting ``len`` and slicing.
        sequence_len: Number of tokens in every ``x`` and ``y`` window.
    """

    def __init__(self,tokens,sequence_len) :
        super().__init__()
        self.tokens = tokens
        self.sequence_len = sequence_len

    def __len__(self) :
        # A window starting at idx needs tokens up to idx + sequence_len
        # (inclusive) for its shifted target, so only
        # len(tokens) - sequence_len start positions give full-length pairs.
        return max(0, len(self.tokens) - self.sequence_len)

    def __getitem__(self,idx) :
        """Return the ``(x, y)`` pair of int64 tensors for window ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``. Raising here
                also lets plain iteration over the dataset terminate.
        """
        if not 0 <= idx < len(self) :
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")

        end = idx + self.sequence_len
        # Read from the instance's own tokens rather than a module-level global.
        x = torch.tensor(self.tokens[idx : end], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1 : end + 1], dtype=torch.long)
        return x, y

# The GPT-2 tokenizer produces ids in [0, tokenizer.n_vocab), which is far more
# than the 10000 embedding rows of the model built for the random-data smoke
# test; feeding real token ids to that model fails with an out-of-range index.
# Rebuild the model with a matching vocabulary, and rebuild the optimizer so it
# tracks the new model's parameters.
vocab_size = tokenizer.n_vocab
model = Transformer(hidden_dim=128,intermediate_dim=256,n_heads=4, head_dim=32,epsilon=1e-8, n_layers=3, vocab_size=vocab_size)
model.train()
optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)

# Windows of 128 tokens, each paired with the same window shifted by one token.
dataset = TrainingDataset(tokens,128)

loader = DataLoader(dataset,batch_size=32,shuffle=True)

for epoch in range(epochs) :
    for x,y in loader :
        optimizer.zero_grad()
        y_pred = model(x)
        # Flatten (batch, sequence) so every position is one classification example.
        loss = loss_fn(y_pred.view(-1,vocab_size),y.view(-1).long())
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch}/{epochs} - loss : {loss.item()}")

In [None]:
class SimpleCNN(nn.Module):
    """Two conv + max-pool stages followed by a two-way linear classifier.

    The classifier takes exactly 32 features, so the feature map after the
    second pooling layer must be 32 channels of size 1x1 (a 3x32x32 input, for
    example, reduces to that).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.classifier = nn.Linear(32 * 1 * 1, 2)

    def forward(self, x):
        """Return two logits per sample for a batch of 3-channel images."""
        feature_stages = (self.conv1, self.maxpool1, self.conv2, self.maxpool2)
        for stage in feature_stages:
            x = stage(x)

        # Keep the batch dimension and collapse channels and spatial dims.
        flat = x.flatten(1)
        return self.classifier(flat)
        