In [35]:
import torch as t
from math import sqrt
import torch.nn as nn
from torchinfo import summary

In [31]:
#!pip install spacy

Collecting spacy
  Downloading spacy-3.5.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.6 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.6/6.6 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m0:01[0m:01[0m0m
[?25hCollecting spacy-legacy<3.1.0,>=3.0.11
  Downloading spacy_legacy-3.0.12-py2.py3-none-any.whl (29 kB)
Collecting spacy-loggers<2.0.0,>=1.0.0
  Downloading spacy_loggers-1.0.4-py3-none-any.whl (11 kB)
Collecting murmurhash<1.1.0,>=0.28.0
  Downloading murmurhash-1.0.9-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (21 kB)
Collecting cymem<2.1.0,>=2.0.2
  Downloading cymem-2.0.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34 kB)
Collecting preshed<3.1.0,>=3.0.2
  Downloading preshed-3.0.8-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (122 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━

In [34]:
# EMBED_SIZE is the length of the vector that represents each word (token).
# HIDDEN_SIZE is the length of each attention head's query, key and value vectors.

In [3]:
# Model-wide hyperparameters read by the attention classes defined below.
nHEADS = 8  # number of attention heads (leading dimension of Wq / Wk / Wv)
EMBED_SIZE = 512  # representation size of each word
HIDDEN_SIZE = 64  # size of the query, key and value vectors

In [4]:
# Show which PyTorch build executed this notebook.
print(t.__version__)

2.1.0.dev20230510+cu118


In [5]:
# Sample input for the cells below: 7 rows of 512 values drawn uniformly from [0, 1).
# The 512 columns match EMBED_SIZE; no seed is set, so the values differ on every run.
X = t.rand(7,512)
X

tensor([[0.8026, 0.8250, 0.4240,  ..., 0.1833, 0.9176, 0.2783],
        [0.6592, 0.1254, 0.0827,  ..., 0.8300, 0.7167, 0.4494],
        [0.4486, 0.2490, 0.2301,  ..., 0.5284, 0.8057, 0.7191],
        ...,
        [0.9032, 0.9598, 0.5788,  ..., 0.3619, 0.3243, 0.9281],
        [0.4655, 0.7555, 0.4076,  ..., 0.7516, 0.3761, 0.7108],
        [0.8270, 0.4645, 0.5747,  ..., 0.8043, 0.3223, 0.7804]])

In [6]:
# Confirm the sample input has 7 rows and 512 columns.
X.shape

torch.Size([7, 512])

In [7]:
class EncAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a residual connection and LayerNorm.

    Args:
        n_heads: number of attention heads; defaults to the notebook constant ``nHEADS``.
        embed_size: size of each input/output row; defaults to ``EMBED_SIZE``.
        hidden_size: size of each head's query/key/value vectors; defaults to ``HIDDEN_SIZE``.

    ``forward`` takes a 2-D tensor ``X`` of shape (seq_len, embed_size) and returns a
    tensor of the same shape.
    """

    def __init__(self, n_heads=None, embed_size=None, hidden_size=None):
        super(EncAttention, self).__init__()
        # The notebook-level constants are only looked up when a size is not given explicitly.
        self.n_heads = nHEADS if n_heads is None else n_heads
        self.embed_size = EMBED_SIZE if embed_size is None else embed_size
        self.hidden_size = HIDDEN_SIZE if hidden_size is None else hidden_size

        # NOTE(review): t.rand draws from U[0, 1), so every weight starts positive and
        # unscaled; kept as in the original, but a scaled init is likely needed for training.
        self.Wq = nn.Parameter(t.rand(self.n_heads, self.embed_size, self.hidden_size))
        self.Wk = nn.Parameter(t.rand(self.n_heads, self.embed_size, self.hidden_size))
        self.Wv = nn.Parameter(t.rand(self.n_heads, self.embed_size, self.hidden_size))
        self.Wo = nn.Parameter(t.rand(self.n_heads * self.hidden_size, self.embed_size))
        # Registered once so its affine parameters are trained, saved and moved with the
        # module; it normalises each row (token) over the embedding dimension.
        self.norm = nn.LayerNorm(self.embed_size)

    def forward(self, X):
        """Return LayerNorm(attention(X) @ Wo + X) for ``X`` of shape (seq_len, embed_size)."""
        # (seq, embed) @ (heads, embed, hidden) broadcasts to (heads, seq, hidden).
        Q = X @ self.Wq
        K = X @ self.Wk
        V = X @ self.Wv
        # (heads, seq, seq): similarity of every query position with every key position.
        scores = t.bmm(Q, K.transpose(1, 2)) / sqrt(self.hidden_size)
        weights = t.softmax(scores, dim=-1)
        # Weighted sum over ALL value vectors, not just the token's own value.
        context = t.bmm(weights, V)
        # Bring the sequence axis first so each row is the concatenation of that token's heads.
        context = context.permute(1, 0, 2).reshape(X.shape[0], -1)
        Z = context @ self.Wo
        return self.norm(Z + X)

In [8]:
class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear, residual add, LayerNorm.

    Args:
        embed_size: size of the last dimension of the input and output (default 512).
        ff_size: width of the inner linear layer (default 2048).

    ``forward`` returns a tensor with the same shape as its input.
    """

    def __init__(self, embed_size=512, ff_size=2048):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(embed_size, ff_size)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(ff_size, embed_size)
        # Registered once so its affine parameters are trained and saved with the module;
        # it normalises over the last dimension only, i.e. each row independently.
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, X):
        """Return LayerNorm(linear2(relu(linear1(X))) + X)."""
        Z = self.linear1(X)
        Z = self.relu1(Z)
        # No activation after the second projection: a ReLU here would restrict the
        # residual update to non-negative values.
        Z = self.linear2(Z)
        return self.norm(Z + X)

In [9]:
class Encoder(nn.Module):
    """One encoder layer: an EncAttention sub-layer followed by a FeedForward sub-layer."""

    def __init__(self):
        super().__init__()
        self.attention = EncAttention()
        self.feedforward = FeedForward()

    def forward(self, X):
        """Pass ``X`` through the attention sub-layer, then the feed-forward sub-layer."""
        attended = self.attention(X)
        return self.feedforward(attended)

In [10]:
# Instantiate a single encoder layer to try out on the sample input.
model1 = Encoder()

In [11]:
#Masked Attention

In [12]:
# Display the sample input again before passing it through the encoder.
X

tensor([[0.8026, 0.8250, 0.4240,  ..., 0.1833, 0.9176, 0.2783],
        [0.6592, 0.1254, 0.0827,  ..., 0.8300, 0.7167, 0.4494],
        [0.4486, 0.2490, 0.2301,  ..., 0.5284, 0.8057, 0.7191],
        ...,
        [0.9032, 0.9598, 0.5788,  ..., 0.3619, 0.3243, 0.9281],
        [0.4655, 0.7555, 0.4076,  ..., 0.7516, 0.3761, 0.7108],
        [0.8270, 0.4645, 0.5747,  ..., 0.8043, 0.3223, 0.7804]])

In [13]:
# Run the single encoder layer on the sample input.
Z = model1(X)

In [14]:
# Check the shape of the encoder output.
Z.shape

torch.Size([7, 512])

In [18]:
class MaskedAttention(nn.Module):
    """Causally masked multi-head self-attention with a residual connection and LayerNorm.

    Position ``i`` may only attend to positions ``<= i``.

    Args:
        n_heads: number of attention heads; defaults to the notebook constant ``nHEADS``.
        embed_size: size of each input/output row; defaults to ``EMBED_SIZE``.
        hidden_size: size of each head's query/key/value vectors; defaults to ``HIDDEN_SIZE``.

    ``forward`` takes a 2-D tensor ``X`` of shape (seq_len, embed_size) and returns
    ``(Z, Q)``: ``Z`` has the shape of ``X`` and ``Q`` is the per-head query projection of
    shape (n_heads, seq_len, hidden_size).
    """

    def __init__(self, n_heads=None, embed_size=None, hidden_size=None):
        super(MaskedAttention, self).__init__()
        # The notebook-level constants are only looked up when a size is not given explicitly.
        self.n_heads = nHEADS if n_heads is None else n_heads
        self.embed_size = EMBED_SIZE if embed_size is None else embed_size
        self.hidden_size = HIDDEN_SIZE if hidden_size is None else hidden_size

        # NOTE(review): t.rand draws from U[0, 1), so every weight starts positive and
        # unscaled; kept as in the original, but a scaled init is likely needed for training.
        self.Wq = nn.Parameter(t.rand(self.n_heads, self.embed_size, self.hidden_size))
        self.Wk = nn.Parameter(t.rand(self.n_heads, self.embed_size, self.hidden_size))
        self.Wv = nn.Parameter(t.rand(self.n_heads, self.embed_size, self.hidden_size))
        self.Wo = nn.Parameter(t.rand(self.n_heads * self.hidden_size, self.embed_size))
        # Registered once so its parameters are trained and saved. It normalises each row
        # on its own, which keeps later positions from influencing earlier ones.
        self.norm = nn.LayerNorm(self.embed_size)

    def forward(self, X):
        """Return ``(LayerNorm(masked_attention(X) @ Wo + X), Q)``."""
        # (seq, embed) @ (heads, embed, hidden) broadcasts to (heads, seq, hidden).
        Q = X @ self.Wq
        K = X @ self.Wk
        V = X @ self.Wv
        scores = t.bmm(Q, K.transpose(1, 2)) / sqrt(self.hidden_size)
        seq_len = scores.shape[-1]
        # True strictly above the diagonal, i.e. where the key position lies in the future.
        # Built on the scores' device so the module also works off-CPU.
        future = t.triu(t.ones(seq_len, seq_len, device=scores.device), diagonal=1).bool()
        # The diagonal is never masked, so every row keeps at least one finite score.
        scores = scores.masked_fill(future, float('-inf'))
        weights = t.softmax(scores, dim=-1)
        # Weighted sum over the visible value vectors.
        context = t.bmm(weights, V)
        # Bring the sequence axis first so each row is the concatenation of that token's heads.
        context = context.permute(1, 0, 2).reshape(seq_len, -1)
        Z = context @ self.Wo
        Z = self.norm(Z + X)
        return Z, Q

In [19]:
class EncDecAttention(nn.Module):
    """Multi-head encoder-decoder (cross) attention.

    Queries are supplied by the caller; keys and values are computed from the encoder
    output by two small MLPs and split into heads.

    Args:
        n_heads: number of attention heads; defaults to the notebook constant ``nHEADS``.
        embed_size: size of each encoder-output row and of each output row; defaults to
            ``EMBED_SIZE``.
        hidden_size: size of each head's query/key/value vectors; defaults to
            ``HIDDEN_SIZE``.

    With the notebook defaults (8, 512, 64) the two MLPs are 512 -> 1024 -> 512, as before.
    """

    def __init__(self, n_heads=None, embed_size=None, hidden_size=None):
        super(EncDecAttention,self).__init__()
        # The notebook-level constants are only looked up when a size is not given explicitly.
        self.n_heads = nHEADS if n_heads is None else n_heads
        self.embed_size = EMBED_SIZE if embed_size is None else embed_size
        self.hidden_size = HIDDEN_SIZE if hidden_size is None else hidden_size
        model_size = self.n_heads * self.hidden_size

        # Key projection. The trailing ReLU (kept from the original) makes keys non-negative.
        self.getFromEncOutput1 = nn.Sequential(
            nn.Linear(self.embed_size, 2 * self.embed_size),
            nn.ReLU(),
            nn.Linear(2 * self.embed_size, model_size),
            nn.ReLU()
        )
        # Value projection, same layout as the key projection.
        self.getFromEncOutput2 = nn.Sequential(
            nn.Linear(self.embed_size, 2 * self.embed_size),
            nn.ReLU(),
            nn.Linear(2 * self.embed_size, model_size),
            nn.ReLU()
        )
        # Output projection from the concatenated heads back to embed_size. forward() used
        # self.Wo without it ever being created.
        self.Wo = nn.Parameter(t.rand(model_size, self.embed_size))
        # Registered once so its parameters are trained and saved; normalises each row.
        self.norm = nn.LayerNorm(self.embed_size)

    def forward(self,maskquery,enc_output,residual=None):
        """Attend from decoder queries to the encoder output.

        Args:
            maskquery: per-head queries of shape (n_heads, dec_len, hidden_size).
            enc_output: encoder output of shape (enc_len, embed_size).
            residual: optional tensor of shape (dec_len, embed_size) added before the
                LayerNorm. The original read a notebook-global ``X`` here; the residual
                must now be passed explicitly and is skipped when omitted.

        Returns:
            Tensor of shape (dec_len, embed_size).
        """
        Q = maskquery
        enc_len = enc_output.shape[0]
        # (enc_len, heads*hidden) -> (heads, enc_len, hidden) so keys/values line up with Q.
        K = self.getFromEncOutput1(enc_output).reshape(enc_len, self.n_heads, self.hidden_size).permute(1, 0, 2)
        V = self.getFromEncOutput2(enc_output).reshape(enc_len, self.n_heads, self.hidden_size).permute(1, 0, 2)
        # (heads, dec_len, enc_len). No causal mask: every decoder position may look at the
        # whole encoder output, and dec_len need not equal enc_len.
        scores = t.bmm(Q, K.transpose(1, 2)) / sqrt(self.hidden_size)
        weights = t.softmax(scores, dim=-1)
        context = t.bmm(weights, V)
        # Bring the sequence axis first so each row is the concatenation of that token's heads.
        context = context.permute(1, 0, 2).reshape(Q.shape[1], -1)
        Z = context @ self.Wo
        if residual is not None:
            Z = Z + residual
        return self.norm(Z)

In [29]:
class Decoder(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, feed-forward."""

    def __init__(self):
        super(Decoder,self).__init__()
        self.masked = MaskedAttention()
        self.encdec = EncDecAttention()
        self.feedforward = FeedForward()

    def forward(self,X,enc_output):
        """Run the three sub-layers and return the result.

        Args:
            X: decoder input for this layer.
            enc_output: encoder output passed to the encoder-decoder attention.

        Returns:
            The output of the feed-forward sub-layer.
        """
        self_attended, Q = self.masked(X)
        cross = self.encdec(Q,enc_output)
        # Q is a projection of X only, so without this sum the masked self-attention
        # output would be thrown away and that sub-layer would never influence the result.
        Z = self.feedforward(self_attended + cross)
        # Previously nothing was returned, so callers received None.
        return Z

In [26]:
class EncoderStack(nn.Module):
    """Six Encoder layers, registered as ``enc1`` ... ``enc6`` and applied in that order."""

    def __init__(self):
        super().__init__()
        # Create the layers in numeric order so registration order matches the names.
        for layer_no in range(1, 7):
            setattr(self, f"enc{layer_no}", Encoder())

    def forward(self, X):
        """Feed ``X`` through enc1, then each following layer, and return the last output."""
        hidden = X
        for layer_no in range(1, 7):
            hidden = getattr(self, f"enc{layer_no}")(hidden)
        return hidden

In [30]:
class DecoderStack(nn.Module):
    """Six Decoder layers, registered as ``dec1`` ... ``dec6`` and applied in that order."""

    def __init__(self):
        super().__init__()
        # Create the layers in numeric order so registration order matches the names.
        for layer_no in range(1, 7):
            setattr(self, f"dec{layer_no}", Decoder())

    def forward(self, X, enc_output):
        """Feed ``X`` through dec1 ... dec6; every layer receives the same ``enc_output``."""
        hidden = X
        for layer_no in range(1, 7):
            hidden = getattr(self, f"dec{layer_no}")(hidden, enc_output)
        return hidden