In [1]:
import torch
import torch.nn as nn

In [183]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last dimension of every input is split into ``n_heads`` chunks of
    ``head_dims`` features. The value/key/query projections are
    ``head_dims -> head_dims`` linear maps applied to that last dimension, so the
    same projection weights are used for every head. The per-head results are
    concatenated and passed through ``fc_out``.

    Args:
        embed_size: Size of the model embedding; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.

    Raises:
        ValueError: If ``embed_size`` is not divisible by ``n_heads``.
    """

    def __init__(self,embed_size,n_heads):
        super(SelfAttention,self).__init__()
        # Fail at construction time instead of with an opaque reshape error in forward().
        if embed_size % n_heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by n_heads ({n_heads})"
            )
        self.embed_size = embed_size
        self.n_heads = n_heads
        self.head_dims = embed_size//n_heads

        self.values = nn.Linear(self.head_dims,self.head_dims,bias = False)
        self.keys = nn.Linear(self.head_dims,self.head_dims,bias = False)
        self.query = nn.Linear(self.head_dims,self.head_dims,bias = False)
        self.fc_out = nn.Linear(self.embed_size,self.embed_size,bias = False)

    def forward(self,values,keys,query,mask):
        """Attend from ``query`` positions to ``keys``/``values`` positions.

        Args:
            values: Tensor reshaped here to ``(N, value_len, n_heads, head_dims)``,
                i.e. expected as ``(N, value_len, embed_size)``.
            keys: Tensor expected as ``(N, key_len, embed_size)``.
            query: Tensor expected as ``(N, query_len, embed_size)``.
            mask: ``None`` or a tensor broadcastable to
                ``(N, n_heads, query_len, key_len)``; positions where it equals 0
                receive (effectively) zero attention weight.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len,key_len,query_len = values.shape[1],keys.shape[1],query.shape[1]

        # Split the embedding into heads, then project each head's features.
        values = self.values(values.reshape(N,value_len,self.n_heads,self.head_dims))
        keys = self.keys(keys.reshape(N,key_len,self.n_heads,self.head_dims))
        query = self.query(query.reshape(N,query_len,self.n_heads,self.head_dims))

        # Dot product of every query with every key, per head: (N, heads, query_len, key_len).
        # Scale by sqrt of the per-head key dimension (the size of the summed axis `d`).
        scores = torch.einsum('nqhd,nkhd->nhqk',query,keys)/(self.head_dims**0.5)

        if mask is not None:
            # Masked positions get the most negative representable value so softmax
            # assigns them ~0 weight; finfo keeps this valid for half precision too.
            scores = scores.masked_fill(mask == 0,torch.finfo(scores.dtype).min)
        attention = torch.softmax(scores,dim = 3)

        # Weighted sum of values, then concatenate the heads back to embed_size.
        out = torch.einsum('nhql,nlhd->nqhd',attention,values).reshape(N,query_len,self.embed_size)
        return self.fc_out(out)

In [184]:
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a two-layer feed-forward sub-layer.

    After each sub-layer the sub-layer input is added back, the sum is
    layer-normalised, and dropout is applied to the normalised result.

    Args:
        embed_size: Feature size of the inputs and outputs.
        n_heads: Number of heads handed to ``SelfAttention``.
        dropout: Dropout probability used after both normalisations.
        forward_expansion: The feed-forward hidden width is
            ``forward_expansion * embed_size``.
    """

    def __init__(self,embed_size,n_heads,dropout,forward_expansion):
        super().__init__()
        hidden_size = forward_expansion*embed_size

        self.attention = SelfAttention(embed_size,n_heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size,hidden_size,bias=False),
            nn.ReLU(),
            nn.Linear(hidden_size,embed_size,bias=False),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self,values,keys,query,mask):
        """Run attention and the feed-forward network, each with a residual connection.

        ``query`` is the tensor added back after attention, so the result has the
        same shape as ``query``.
        """
        attended = self.attention(values,keys,query,mask)
        # Residual around attention -> LayerNorm -> dropout.
        x = self.norm1(attended + query)
        x = self.dropout(x)
        # Residual around the feed-forward network -> LayerNorm -> dropout.
        expanded = self.feed_forward(x)
        return self.dropout(self.norm2(expanded + x))

In [185]:
class Encoder(nn.Module):
    """Token + learned-position embedding followed by a stack of ``TransformerBlock`` layers.

    Args:
        source_vocab_size: Number of rows in the token embedding table.
        embed_size: Embedding / model feature size.
        num_layers: Number of ``TransformerBlock`` layers to stack.
        n_heads: Attention heads per layer.
        device: Device the position indices are moved to in ``forward``.
        forward_expansion: Feed-forward width multiplier passed to each layer.
        dropout: Dropout probability (embeddings and every layer).
        max_length: Number of rows in the position embedding table, i.e. the
            longest sequence that can be embedded.
    """

    def __init__(self,source_vocab_size,embed_size,num_layers,n_heads,device,forward_expansion,dropout,max_length):
        super(Encoder,self).__init__()
        self.embed_size = embed_size
        self.device = device
        # Use the constructor argument; relying on a same-named notebook global
        # raises NameError on a fresh kernel (or silently picks up a stale value).
        self.word_embedding = nn.Embedding(source_vocab_size,embed_size)
        self.positional_encoding = nn.Embedding(max_length,embed_size)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size,n_heads,dropout,forward_expansion)
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,mask):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of shape ``(N, seq_length)``.
            mask: Attention mask forwarded unchanged to every layer (may be ``None``).

        Returns:
            Tensor of shape ``(N, seq_length, embed_size)``.
        """
        N,seq_length = x.shape
        # Position ids 0..seq_length-1, repeated for every sequence in the batch.
        positions = torch.arange(0,seq_length).expand(N,seq_length).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.positional_encoding(positions))
        for layer in self.layers:
            # Self-attention: values, keys and queries are all the running output.
            out = layer(out,out,out,mask)
        return out

In [186]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, then a ``TransformerBlock``.

    The self-attention output (residual + LayerNorm + dropout) becomes the query
    of the inner ``TransformerBlock``, whose values and keys are supplied by the
    caller.

    Args:
        embed_size: Feature size of the inputs and outputs.
        n_heads: Number of attention heads.
        forward_expansion: Feed-forward width multiplier for the inner block.
        dropout: Dropout probability.
    """

    def __init__(self,embed_size,n_heads,forward_expansion,dropout):
        super().__init__()
        self.attention = SelfAttention(embed_size,n_heads)
        self.norm = nn.LayerNorm(embed_size)
        self.embed_size = embed_size
        self.transformer_block = TransformerBlock(
            embed_size=embed_size,
            n_heads=n_heads,
            dropout=dropout,
            forward_expansion=forward_expansion,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,value,key,src_mask,tgt_mask):
        """Apply the layer to ``x``.

        Args:
            x: Decoder-side input; used as values, keys and queries of the first attention.
            value: Values for the inner ``TransformerBlock`` attention.
            key: Keys for the inner ``TransformerBlock`` attention.
            src_mask: Mask passed to the inner ``TransformerBlock``.
            tgt_mask: Mask passed to the first (self-)attention.
        """
        self_attended = self.attention(x,x,x,tgt_mask)
        normed = self.norm(self_attended + x)
        query = self.dropout(normed)
        return self.transformer_block(value,key,query,src_mask)

In [217]:
class Decoder(nn.Module):
    """Token + learned-position embedding, a stack of ``DecoderBlock`` layers, and a vocabulary projection.

    Args:
        tgt_vocab_size: Rows of the token embedding table and width of the output logits.
        embed_size: Embedding / model feature size.
        num_layers: Number of ``DecoderBlock`` layers.
        n_heads: Attention heads per layer.
        forward_expansion: Feed-forward width multiplier passed to each layer.
        dropout: Dropout probability.
        device: Device the position indices are moved to in ``forward``.
        max_length: Rows of the position embedding table.
    """

    def __init__(self,tgt_vocab_size,embed_size,num_layers,n_heads,forward_expansion,dropout,device,max_length):
        super().__init__()
        self.device = device
        self.word_embedding = nn.Embedding(tgt_vocab_size,embed_size)
        self.position_encoding = nn.Embedding(max_length,embed_size)

        blocks = [
            DecoderBlock(embed_size,n_heads,forward_expansion,dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

        self.fc_out = nn.Linear(embed_size,tgt_vocab_size,bias = False)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,enc_out,src_mask,tgt_mask):
        """Decode token ids ``x`` of shape ``(N, seq_length)`` against ``enc_out``.

        ``enc_out`` is handed to every layer as both its value and its key input.

        Returns:
            Tensor of shape ``(N, seq_length, tgt_vocab_size)`` produced by ``fc_out``.
        """
        batch_size,seq_length = x.shape
        # Position ids 0..seq_length-1 for every sequence in the batch.
        position_ids = torch.arange(0,seq_length).expand(batch_size,seq_length).to(self.device)
        embedded = self.word_embedding(x) + self.position_encoding(position_ids)
        hidden = self.dropout(embedded)

        for block in self.layers:
            hidden = block(hidden,enc_out,enc_out,src_mask,tgt_mask)
        return self.fc_out(hidden)
        
        

In [218]:
class Transformer(nn.Module):
    """Encoder-decoder model that builds its own attention masks.

    Args:
        src_vocab_size: Source vocabulary size (encoder embedding rows).
        tgt_vocab_size: Target vocabulary size (decoder embedding rows / output width).
        src_pad_idx: Source token id treated as padding by ``make_src_mask``.
        tgt_pad_idx: Stored on the instance; not read by any method of this class.
        device: Device the masks (and position indices) are moved to.
        embed_size: Model feature size.
        num_layers: Number of layers in both encoder and decoder.
        forward_expansion: Feed-forward width multiplier.
        n_heads: Attention heads per layer.
        dropout: Dropout probability.
        max_length: Longest sequence the position embeddings support.
    """

    def __init__(self,src_vocab_size,tgt_vocab_size,src_pad_idx,tgt_pad_idx,device,embed_size = 256,num_layers = 6,
                forward_expansion=4,n_heads = 8, dropout = 0.1,
                max_length = 100):
        super().__init__()
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            n_heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )
        self.decoder = Decoder(
            tgt_vocab_size,
            embed_size,
            num_layers,
            n_heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )
        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx
        self.device = device

    def make_src_mask(self,src):
        """Return a boolean mask that is False wherever ``src`` equals ``src_pad_idx``.

        Two singleton axes are inserted after the batch axis, so a ``(N, src_len)``
        input yields a ``(N, 1, 1, src_len)`` mask.
        """
        not_padding = src != self.src_pad_idx
        return not_padding[:, None, None].to(self.device)

    def make_target_mask(self,tgt):
        """Return a lower-triangular (causal) float mask of shape ``(N, 1, trg_len, trg_len)``."""
        N,trg_len = tgt.shape
        # Row i has ones in columns 0..i: position i may look at itself and earlier positions.
        causal = torch.tril(torch.ones(trg_len,trg_len))
        return causal.expand(N,1,trg_len,trg_len).to(self.device)

    def forward(self,src,tgt):
        """Encode ``src``, then decode ``tgt`` against the encoder output."""
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_target_mask(tgt)
        memory = self.encoder(src,src_mask)
        return self.decoder(tgt,memory,src_mask,tgt_mask)

In [219]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


In [220]:
# Toy source batch: two sequences of nine token ids each, created directly on `device`.
x = torch.tensor(
    [
        [1,5,6,4,3,9,5,2,0],
        [1,8,7,3,4,5,6,7,2],
    ],
    device=device,
)

In [221]:
# Toy target batch: two sequences of eight token ids each, created directly on `device`.
y = torch.tensor(
    [
        [1,7,4,3,5,9,2,0],
        [1,5,6,2,4,7,6,2],
    ],
    device=device,
)

In [222]:
# Token id 0 is the padding id on both the source and the target side.
src_pad_idx, tgt_pad_idx = 0, 0
# Both vocabularies have ten entries (token ids 0-9).
src_vocab_size = tgt_vocab_size = 10

In [223]:
# Build the model with its default hyper-parameters (embed_size, num_layers, ...).
model = Transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    src_pad_idx=src_pad_idx,
    tgt_pad_idx=tgt_pad_idx,
    device=device,
)

In [224]:
# Feed the target with its last token dropped as the decoder input.
out = model(src=x, tgt=y[:, :-1])

In [226]:
# One row of tgt_vocab_size logits per decoder input position: (batch, tgt_len - 1, tgt_vocab_size).
out.shape

torch.Size([2, 7, 10])

In [228]:
# Full target batch for comparison: one position longer than the decoder input used above.
y.shape

torch.Size([2, 8])