In [8]:
import torch
from torch import nn

In [9]:
# Select the compute device once: prefer the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [12]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The full embedding is linearly projected to values, keys and queries, then
    split into ``heads`` independent heads of size ``head_dim = embed_size // heads``.
    Each head attends separately. The concatenated head outputs are mixed by
    the output projection ``fc_out``.

    Args:
        embed_size: Size of each token embedding. Must be divisible by ``heads``.
        heads: Number of attention heads (positive).

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        # The heads must tile the embedding exactly, otherwise the per-head
        # reshape in forward() cannot work.
        if heads <= 0 or embed_size % heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by a positive "
                f"number of heads ({heads})"
            )
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads  # embedding size handled by each head

        # Trainable projections W_V, W_K, W_Q over the full embedding; their
        # outputs are split into heads in forward().
        self.values = nn.Linear(in_features=embed_size, out_features=embed_size, bias=False)
        self.keys = nn.Linear(in_features=embed_size, out_features=embed_size, bias=False)
        self.queries = nn.Linear(in_features=embed_size, out_features=embed_size, bias=False)
        # W_O: mixes the concatenation of all heads back into embed_size.
        self.fc_out = nn.Linear(in_features=embed_size, out_features=embed_size, bias=False)

    def forward(self, values, keys, query, mask=None):
        """Attend from ``query`` tokens to ``keys``/``values`` tokens.

        Args:
            values: Tensor of shape (N, value_len, embed_size).
            keys: Tensor of shape (N, key_len, embed_size); key_len must equal
                value_len because both index the same set of tokens.
            query: Tensor of shape (N, query_len, embed_size).
            mask: Optional tensor broadcastable to (N, heads, query_len, key_len).
                Positions where ``mask == 0`` receive no attention.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        n_batch = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the last dimension into (heads, head_dim).
        values = self.values(values).reshape(n_batch, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(n_batch, key_len, self.heads, self.head_dim)
        queries = self.queries(query).reshape(n_batch, query_len, self.heads, self.head_dim)

        # For every batch item n and head h, take the dot product of each query q
        # with each key k over head_dim d. The result is (N, heads, query_len, key_len).
        attn_scores = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            # Masked positions (e.g. padding) get the most negative finite value,
            # so their softmax weight is 0. A finite value, rather than -inf,
            # keeps fully-masked rows NaN-free and stays in range for float16.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

        # Scale by sqrt(d_k), then normalise over the key axis (dim 3 of nhqk).
        attention = torch.softmax(attn_scores / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of values per head. The shared index l is key_len == value_len.
        # The heads are then concatenated back into embed_size.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            n_batch, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)

In [None]:
class TransformerBlock(nn.Module):
    """One Transformer block: multi-head attention followed by a feed-forward net.

    Each sub-layer output is added to its input (residual connection), passed
    through LayerNorm, and then through dropout.

    Args:
        embed_size: Size of each token embedding.
        dropout: Dropout probability applied after each normalisation.
        forward_expansion: The feed-forward hidden layer has
            ``forward_expansion * embed_size`` units.
        heads: Number of attention heads. Must divide ``embed_size``.
    """

    def __init__(self, embed_size, dropout, forward_expansion, heads=8):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # Up-project to forward_expansion * embed_size, then down-project to embed_size.
        self.feed_forward = nn.Sequential(
            nn.Linear(in_features=embed_size, out_features=forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(in_features=forward_expansion * embed_size, out_features=embed_size),
        )

        self.dropout = nn.Dropout(dropout)  # regularisation against overfitting

    def forward(self, value, key, query, mask=None):
        """Apply attention and feed-forward sub-layers to ``query``.

        Args:
            value, key, query: Passed straight to ``SelfAttention.forward``.
            mask: Optional attention mask, passed to ``SelfAttention.forward``.

        Returns:
            Tensor with the same shape as ``query``.
        """
        attention = self.attention(value, key, query, mask)
        # Residual: add the input query to the attention output, then normalise.
        x = self.dropout(self.norm1(attention + query))
        ff_out = self.feed_forward(x)
        return self.dropout(self.norm2(ff_out + x))

In [None]:
class Encoder(nn.Module):
    def __init__(self, src_vocab_size,