In [1]:
import torch as t
from math import sqrt
import torch.nn as nn
from torchinfo import summary

In [2]:
# EMBED_SIZE is the size of the vector that represents each word.
# HIDDEN_SIZE is the size of each query, key, and value vector.

In [3]:
# Hyperparameters used by the cells below:
#   nHEADS      - leading dimension of the Wq/Wk/Wv weight stacks
#   EMBED_SIZE  - size of each word's representation
#   HIDDEN_SIZE - size of each query, key and value vector
nHEADS, EMBED_SIZE, HIDDEN_SIZE = 8, 512, 64

In [4]:
# Record the PyTorch build this notebook was executed with.
print(f"{t.__version__}")

2.1.0.dev20230510+cu118


In [5]:
# Two random input vectors. Their length is taken from EMBED_SIZE (rather than
# a repeated literal 512) because the later `X @ Wq` products require the input
# width to equal the EMBED_SIZE dimension of the weight tensors.
X1 = t.rand(EMBED_SIZE)
X2 = t.rand(EMBED_SIZE)

In [6]:
# One (EMBED_SIZE, HIDDEN_SIZE) weight matrix per head for each of Wq, Wk, Wv.
# The generator is consumed in Wq, Wk, Wv order, so the random draws happen in
# the same sequence as three separate t.rand calls.
Wq, Wk, Wv = (
    t.rand(nHEADS, EMBED_SIZE, HIDDEN_SIZE, requires_grad=True) for _ in range(3)
)
# Wo has nHEADS*HIDDEN_SIZE rows and EMBED_SIZE columns.
Wo = t.rand(nHEADS * HIDDEN_SIZE, EMBED_SIZE, requires_grad=True)

In [7]:
# Sanity check: expect (nHEADS, EMBED_SIZE, HIDDEN_SIZE).
Wq.size()

torch.Size([8, 512, 64])

In [8]:
# Stack the two 1-D input vectors as the rows of a single 2-D matrix.
X = t.stack([X1, X2], dim=0)

In [9]:
# Multiply the inputs by every head's weights in one go: the 2-D X broadcasts
# against each 3-D weight stack, producing one result matrix per head.
Q, K, V = (X @ W for W in (Wq, Wk, Wv))

In [10]:
# Sanity check: one (rows of X, HIDDEN_SIZE) matrix per head.
Q.size()

torch.Size([8, 2, 64])

In [11]:
# Shape of K with its last two axes swapped (the head axis stays first).
K.transpose(-2, -1).size()

torch.Size([8, 64, 2])

In [12]:
# Sanity check: one row per stacked input vector.
X.size()

torch.Size([2, 512])

In [14]:
class Encoder(nn.Module):
    """One Transformer-style encoder layer.

    The layer applies multi-head scaled dot-product self-attention followed by
    a two-layer feed-forward network. Each of the two sub-layers is wrapped as
    ``LayerNorm(sublayer(x) + x)``.

    Args:
        n_heads: Number of attention heads. Falls back to the notebook
            constant ``nHEADS`` when omitted.
        embed_size: Size of each token's representation. Falls back to
            ``EMBED_SIZE`` when omitted.
        hidden_size: Size of each head's query/key/value vectors. Falls back
            to ``HIDDEN_SIZE`` when omitted.
        ff_size: Width of the feed-forward inner layer (default 2048).

    ``forward`` takes one unbatched sequence of shape ``(seq_len, embed_size)``
    and returns a tensor of the same shape.
    """

    def __init__(self, n_heads=None, embed_size=None, hidden_size=None, ff_size=2048):
        super().__init__()
        # The notebook-level constants are only looked up when an argument is
        # omitted, so `Encoder()` keeps working exactly as before.
        n_heads = nHEADS if n_heads is None else n_heads
        embed_size = EMBED_SIZE if embed_size is None else embed_size
        hidden_size = HIDDEN_SIZE if hidden_size is None else hidden_size

        # NOTE(review): uniform [0, 1) initialisation is kept from the original;
        # with large embed sizes it yields very large attention scores, so a
        # scaled initialisation is worth considering.
        # Random tensors are created in the original order (Wq, Wk, Wv, Wo,
        # linear1, linear2) so a fixed seed gives the same initial weights.
        self.Wq = nn.Parameter(t.rand(n_heads, embed_size, hidden_size))
        self.Wk = nn.Parameter(t.rand(n_heads, embed_size, hidden_size))
        self.Wv = nn.Parameter(t.rand(n_heads, embed_size, hidden_size))
        self.Wo = nn.Parameter(t.rand(n_heads * hidden_size, embed_size))
        self.linear1 = nn.Linear(embed_size, ff_size)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(ff_size, embed_size)
        self.relu2 = nn.ReLU()

        self.softmax = nn.Softmax(dim=-1)
        # The LayerNorms are registered here (not built inside forward) so that
        # their weight/bias are trained, saved in the state_dict and moved by
        # `.to(device)`. They normalise each token over its embedding features.
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

    def forward(self, X: t.Tensor) -> t.Tensor:
        """Run the encoder layer on ``X`` of shape ``(seq_len, embed_size)``."""
        # The 2-D input broadcasts against the per-head weights:
        # (seq, embed) @ (heads, embed, hidden) -> (heads, seq, hidden).
        Q = X @ self.Wq
        K = X @ self.Wk
        V = X @ self.Wv

        # Scaled dot-product attention. Every row of `weights` sums to 1 over
        # the tokens being attended to.
        scale = sqrt(self.Wq.shape[-1])
        weights = self.softmax(t.bmm(Q, K.transpose(1, 2)) / scale)

        # Each token's output is the attention-weighted sum of ALL value
        # vectors, not only its own value scaled by the diagonal weight.
        heads = t.bmm(weights, V)  # (heads, seq, hidden)

        # Bring the token axis to the front before flattening, so each row
        # holds the concatenated head outputs of one single token.
        concat = heads.transpose(0, 1).reshape(X.shape[0], -1)

        Z = self.norm1(concat @ self.Wo + X)
        residual = Z

        Z = self.relu1(self.linear1(Z))
        # NOTE(review): the ReLU after the second linear layer is kept from the
        # original code; the reference Transformer feed-forward block applies
        # no activation at this point.
        Z = self.relu2(self.linear2(Z))
        return self.norm2(Z + residual)

In [15]:
# Instantiate the encoder layer; called without arguments, its sizes come from
# the notebook-level constants nHEADS, EMBED_SIZE and HIDDEN_SIZE.
model1 = Encoder()