In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class SelfAttention(nn.Module) :
    """Single-head scaled dot-product self-attention.

    Args:
        hidden_dim: size of the last dimension of the input. The query, key,
            value and output projections all map hidden_dim -> hidden_dim.
    """
    def __init__(self,hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.wq = nn.Linear(hidden_dim,hidden_dim)
        self.wk = nn.Linear(hidden_dim,hidden_dim)
        self.wv = nn.Linear(hidden_dim,hidden_dim)
        
        self.wo = nn.Linear(hidden_dim,hidden_dim)
    
    def forward(self,x) :
        """Let every position of ``x`` attend to every position of ``x``.

        Args:
            x: tensor whose last two dims are (S, hidden_dim), e.g. (B, S, hidden_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        xq = self.wq(x)
        xk = self.wk(x)
        xv = self.wv(x)
        
        # (..., S, S) query-key scores, scaled by sqrt(hidden_dim).
        attention = (xq @ xk.transpose(-1,-2))/(self.hidden_dim **0.5)
        
        # Normalise over the key axis (last dim) so each query's weights sum to 1.
        # For 3-D input, dim=1 would be the query axis, which is wrong.
        attention_softmax = F.softmax(attention,dim=-1) @ xv
        
        return self.wo(attention_softmax)

In [None]:
class MultiheadAttention(nn.Module) :
    """Multi-head scaled dot-product self-attention (no masking).

    Args:
        hidden_dim: size of the last dimension of the input and of the output.
        n_head: number of attention heads.
        head_dim: per-head query/key/value size; the Q/K/V projections map
            hidden_dim -> n_head * head_dim.
    """
    def __init__(self,hidden_dim,n_head,head_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.head_dim = head_dim
        self.wq = nn.Linear(hidden_dim,n_head*head_dim)
        self.wk = nn.Linear(hidden_dim,n_head*head_dim)
        self.wv = nn.Linear(hidden_dim,n_head*head_dim)
        
        self.wo = nn.Linear(n_head*head_dim,hidden_dim)
    
    def forward(self,x) :
        """Apply multi-head self-attention.

        Args:
            x: tensor of shape (B, S, hidden_dim) (required by the ``view`` below).

        Returns:
            Tensor of shape (B, S, hidden_dim).
        """
        # x is 3-D, so only the first two dims are unpacked (B,S = x.shape would fail).
        B,S = x.shape[0],x.shape[1]
        # Split projections into heads: (B, n_head, S, head_dim).
        xq = self.wq(x).view(B,S,self.n_head,self.head_dim).transpose(1,2)
        xk = self.wk(x).view(B,S,self.n_head,self.head_dim).transpose(1,2)
        xv = self.wv(x).view(B,S,self.n_head,self.head_dim).transpose(1,2)
        
        # (B, n_head, S, S) scores; scale by the per-head key size, not hidden_dim.
        attention = (xq @ xk.transpose(-1,-2))/(self.head_dim **0.5)
        
        # Explicit dim: normalise over keys. The implicit default would pick dim=1 (heads) for 4-D input.
        attention_softmax = F.softmax(attention,dim=-1) @ xv
        
        # Merge heads back: (B, S, n_head * head_dim) -> output projection.
        return self.wo(attention_softmax.transpose(1,2).contiguous().view(B,S,self.n_head*self.head_dim))

In [None]:
class MaskedMultiheadAttention(nn.Module) :
    """Causal (masked) multi-head scaled dot-product self-attention.

    Position i may only attend to positions 0..i.

    Args:
        hidden_dim: size of the last dimension of the input and of the output.
        n_head: number of attention heads.
        head_dim: per-head query/key/value size.
    """
    def __init__(self,hidden_dim,n_head,head_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.head_dim = head_dim
        self.wq = nn.Linear(hidden_dim,n_head*head_dim)
        self.wk = nn.Linear(hidden_dim,n_head*head_dim)
        self.wv = nn.Linear(hidden_dim,n_head*head_dim)
        
        self.wo = nn.Linear(n_head*head_dim,hidden_dim)
    
    def forward(self,x) :
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape (B, S, hidden_dim) (required by the ``view`` below).

        Returns:
            Tensor of shape (B, S, hidden_dim).
        """
        B,S = x.shape[0],x.shape[1]
        # (B, n_head, S, head_dim)
        xq = self.wq(x).view(B,S,self.n_head,self.head_dim).transpose(1,2)
        xk = self.wk(x).view(B,S,self.n_head,self.head_dim).transpose(1,2)
        xv = self.wv(x).view(B,S,self.n_head,self.head_dim).transpose(1,2)
        
        # (B, n_head, S, S) scores scaled by the per-head key size.
        attention = (xq @ xk.transpose(-1,-2))/(self.head_dim **0.5)
        
        # Causal mask: True strictly above the diagonal marks future keys.
        # It is built on x's device and broadcasts over (B, n_head).
        mask = torch.triu(torch.ones(S,S,dtype=torch.bool,device=x.device),diagonal=1)
        attention = attention.masked_fill(mask,float("-inf"))
        
        # The diagonal is never masked, so no row is all -inf.
        attention_softmax = F.softmax(attention,dim=-1) @ xv
        
        return self.wo(attention_softmax.transpose(1,2).contiguous().view(B,S,self.n_head*self.head_dim))

In [None]:
class SelfAttention(nn.Module) :
    """Single-head scaled dot-product self-attention.

    Args:
        hidden_dim: size of the last dimension of the input. The query, key,
            value and output projections all map hidden_dim -> hidden_dim.
    """
    def __init__(self, hidden_dim) :
        super().__init__()
        self.hidden_dim = hidden_dim
        self.wq = nn.Linear(hidden_dim,hidden_dim)
        self.wk = nn.Linear(hidden_dim,hidden_dim)
        self.wv = nn.Linear(hidden_dim,hidden_dim)
        
        self.wo = nn.Linear(hidden_dim,hidden_dim)
    
    def forward(self,x) :
        """Let every position of ``x`` attend to every position of ``x``.

        Args:
            x: tensor whose last two dims are (S, hidden_dim), e.g. (B, S, hidden_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # No shape unpacking: `B,S = x.shape` was unused and raised on batched 3-D input.
        xq = self.wq(x)
        xk = self.wk(x)
        xv = self.wv(x)
        
        # (..., S, S) scores scaled by sqrt(hidden_dim).
        attention = (xq @ xk.transpose(-1,-2))/(self.hidden_dim**0.5)
        
        # Normalise over keys so each query's weights sum to 1.
        out = F.softmax(attention,dim=-1)
        out = out @ xv
        return self.wo(out)
    
    
class MaskedMultiheadAttention(nn.Module) :
    """Causal (masked) multi-head scaled dot-product self-attention.

    Position i may only attend to positions 0..i.

    Args:
        hidden_dim: size of the last dimension of the input and of the output.
        n_head: number of attention heads.
        head_dim: per-head query/key/value size.
    """
    def __init__(self, hidden_dim,n_head,head_dim) :
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.head_dim = head_dim
        
        self.wq = nn.Linear(hidden_dim,n_head*head_dim)
        self.wk = nn.Linear(hidden_dim,n_head*head_dim)
        self.wv = nn.Linear(hidden_dim,n_head*head_dim)
        
        self.wo = nn.Linear(n_head*head_dim,hidden_dim)
    
    def forward(self,x) :
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape (B, S, hidden_dim) (required by the ``view`` below).

        Returns:
            Tensor of shape (B, S, hidden_dim).
        """
        # x is 3-D; only the first two dims are needed (B,S = x.shape would raise).
        B,S = x.shape[0],x.shape[1]
        # (B, n_head, S, head_dim)
        xq = self.wq(x).view(B,S,self.n_head,self.head_dim).transpose(1,2)
        xk = self.wk(x).view(B,S,self.n_head,self.head_dim).transpose(1,2)
        xv = self.wv(x).view(B,S,self.n_head,self.head_dim).transpose(1,2)
        
        # (B, n_head, S, S) scores scaled by the per-head key size.
        attention = (xq @ xk.transpose(-1,-2))/(self.head_dim**0.5)
        
        # torch.triu's keyword is `diagonal` (not `diag`). Build the mask on x's device
        # so this also works when the module lives on GPU.
        mask = torch.triu(torch.ones(S,S,device=x.device),diagonal=1).unsqueeze(0).unsqueeze(0)
        attention = attention.masked_fill(mask==1,-float("inf"))
        
        out = F.softmax(attention,dim=-1)
        
        out = out @ xv
        
        return self.wo(out.transpose(1,2).contiguous().view(B,S,self.n_head*self.head_dim))

In [None]:
class Swiglu(nn.Module) :
    """SwiGLU feed-forward block: ``wo(silu(w1(x)) * w2(x))``.

    Args:
        hidden_dim: size of the last dimension of the input and of the output.
        intermediate_dim: width of the gated hidden layer.
    """
    def __init__(self,hidden_dim,intermediate_dim) :
        super().__init__()
        # Gate branch, then value branch; both project up to intermediate_dim.
        self.w1 = nn.Linear(hidden_dim,intermediate_dim)
        self.w2 = nn.Linear(hidden_dim,intermediate_dim)
        # Project the gated product back down to hidden_dim.
        self.wo = nn.Linear(intermediate_dim,hidden_dim)
        
    def forward(self,x) :
        """Apply the gated feed-forward transform; output has the same shape as ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        gated = gate * value
        return self.wo(gated)
    
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    Note: ``epsilon`` is added to the RMS itself (after the square root),
    not to the mean square.

    Args:
        epsilon: small constant that keeps the denominator away from zero.
        hidden_dim: size of the last dimension; length of the learned scale ``alpha``.
    """
    def __init__(self,epsilon,hidden_dim) :
        super().__init__()
        self.epsilon = epsilon
        # Learned per-feature gain, initialised to 1.
        self.alpha = nn.Parameter(torch.ones(hidden_dim))
    def forward(self,x) :
        """Return ``alpha * x / (rms(x) + epsilon)``, computed along the last dim."""
        mean_sq = torch.mean(x.pow(2), dim=-1, keepdim=True)
        rms = torch.sqrt(mean_sq)
        normalised = x / (rms + self.epsilon)
        return self.alpha * normalised

In [1]:
def counter(n):
    """Lazily yield the integers 0, 1, ..., n - 1."""
    yield from range(n)

# Drive the generator to completion, printing each value.
for x in counter(5):
    print(x)


0
1
2
3
4


In [3]:
generator = counter(n=5)  # unstarted generator; values are produced only when iterated

In [None]:
class MaskedMultiHeadAttention(nn.Module) :
    """Causal (masked) multi-head scaled dot-product self-attention.

    Position i may only attend to positions 0..i.

    Args:
        hidden_dim: size of the last dimension of the input and of the output.
        n_head: number of attention heads.
        head_dim: per-head query/key/value size.
    """
    def __init__(self,hidden_dim,n_head,head_dim) :
        super().__init__()
        
        self.hidden_dim = hidden_dim
        self.head_dim = head_dim
        self.n_head = n_head
        
        self.wq = nn.Linear(hidden_dim,n_head*head_dim)
        self.wk = nn.Linear(hidden_dim,n_head*head_dim)
        self.wv = nn.Linear(hidden_dim,n_head*head_dim)
        
        self.wo = nn.Linear(n_head*head_dim,hidden_dim)
        
    def forward(self,x) :
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape (B, S, hidden_dim) (required by the ``view`` below).

        Returns:
            Tensor of shape (B, S, hidden_dim).
        """
        B,S = x.shape[0],x.shape[1]
        
        xq = self.wq(x).view(B,S,self.n_head,self.head_dim).transpose(1,2) #B,n_head,S,head_dim
        xk = self.wk(x).view(B,S,self.n_head,self.head_dim).transpose(1,2) #B,n_head,S,head_dim
        xv = self.wv(x).view(B,S,self.n_head,self.head_dim).transpose(1,2) #B,n_head,S,head_dim
        
        attention = (xq @ xk.transpose(-1,-2)) / (self.head_dim**0.5) # B, n_head, S, S
        
        # 1s strictly above the diagonal mark future keys. torch.triu's keyword is
        # `diagonal`; `diag` raises TypeError.
        mask = torch.triu(torch.ones(S,S,device=x.device),diagonal=1).unsqueeze(0).unsqueeze(0)
        
        attention_masked = attention.masked_fill(mask==1,-float("inf"))
        out = F.softmax(attention_masked,dim=-1) @ xv
        
        # Merge heads: (B, S, n_head * head_dim) -> output projection.
        return self.wo(out.transpose(1,2).contiguous().view(B,S,self.n_head*self.head_dim))
        

class RMSNorm(nn.Module) :
    """Root-mean-square normalisation over the last dimension with a learned scale.

    Note: ``epsilon`` is added to the RMS itself (after the square root),
    not to the mean square.

    Args:
        epsilon: small constant that keeps the denominator away from zero.
        hidden_dim: size of the last dimension; length of the learned scale ``alpha``.
    """
    def __init__(self,epsilon, hidden_dim):
        # Must run before assigning any nn.Parameter, or registration fails.
        super().__init__()
        self.epsilon = epsilon
        self.hidden_dim = hidden_dim
        # nn.Parameter (torch.Parameter does not exist) so alpha is registered and trained.
        self.alpha = nn.Parameter(torch.ones(hidden_dim))
        
    def forward(self,x) :
        """Return ``alpha * x / (rms(x) + epsilon)``, computed along the last dim."""
        norm = x.pow(2).mean(dim=-1,keepdim=True).sqrt()
        return self.alpha * (x/(norm + self.epsilon))