# Transformer Architecture
Follows the "Attention is all you need" paper architecture

In [49]:
import torch
from torch import nn
import math
from torch.utils.data import Dataset, DataLoader
from torchtext import transforms

# Input

## Token Embeddings

In [6]:
class EmeddingsLayer(nn.Module):
    def __init__(self, d_model:int, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.embedding = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.d_model)

    def forward(self, X):
        return self.embedding(X) * math.sqrt(self.d_model) # "In the embedding layers, we multiply those weights by sqrt(d_model)"


## Positional Ecodings

In [7]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model:int, context_size:int):
        super().__init__()
        self.d_model = d_model
        self.context_size = context_size

        self.pe = torch.zeros(self.context_size, self.d_model,requires_grad=False)
        for pos in range(self.context_size):
            for i in range(0, self.d_model, 2):
                self.pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/self.d_model)))
                self.pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/self.d_model)))
    
    def forward(self):
        return self.pe.unsqueeze(0)

## Test embedding and position encodings

In [8]:

embed_test = EmeddingsLayer(d_model=512,vocab_size=50000)
pencoding = PositionalEncoding(d_model=512,context_size=1024)
example_data = torch.randint(1,50000,(64,1024))
print(example_data.shape)
embed_output = embed_test(example_data)
print(embed_output.shape,embed_output[0][0][:10])
pe_output = pencoding()
print(pe_output.shape,pe_output[0][0][:10])
embed_pos_output = embed_output + pe_output
print(embed_pos_output.shape,embed_pos_output[0][0][:10])

torch.Size([64, 1024])
torch.Size([64, 1024, 512]) tensor([ 37.6751,  27.8283,  -0.8374,  -6.8180,  -5.5830,  52.5307, -13.6899,
          5.2778,  -3.5278,   6.6493], grad_fn=<SliceBackward0>)
torch.Size([1, 1024, 512]) tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
torch.Size([64, 1024, 512]) tensor([ 37.6751,  28.8283,  -0.8374,  -5.8180,  -5.5830,  53.5307, -13.6899,
          6.2778,  -3.5278,   7.6493], grad_fn=<SliceBackward0>)


# Attention

In [10]:
class AttentionHead(nn.Module):
    
    def __init__(self, head_dim:int,p_drop:float,masked:bool = False) -> None:
        super().__init__()
        self.masked = masked
        self.queries= nn.Linear(in_features=head_dim,out_features=head_dim) # kaparthy set bias=False why?
        self.keys = nn.Linear(in_features=head_dim,out_features=head_dim) # kaparthy set bias=False why?
        self.values = nn.Linear(in_features=head_dim,out_features=head_dim) # kaparthy set bias=False why?
        self.dropout = nn.Dropout(p=p_drop)


    def forward(self,Q:torch.Tensor,K:torch.Tensor,V:torch.Tensor) -> torch.Tensor:
        B,T,C = K.shape
        Q = self.dropout(self.queries(Q))
        K = self.dropout(self.keys(K))
        V = self.dropout(self.values(V))

        scaled_dot_product_attention = (Q @ K.transpose(2,1))/torch.sqrt(torch.tensor(C))

        if self.masked:
            mask = torch.tril(torch.ones(T,T)) == 0
            scaled_dot_product_attention = scaled_dot_product_attention.masked_fill(mask,-float("inf"))

        dot_product_softened = torch.softmax(scaled_dot_product_attention,dim=-1)
        return dot_product_softened @ V


In [11]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self,d_model:int, p_drop:float, num_heads:int = 8, masked = False) -> None:
        super().__init__()
        self.head_dim = math.floor(d_model/num_heads)
        self.layer_norm = nn.LayerNorm(d_model)
        self.heads = [AttentionHead(head_dim=self.head_dim,p_drop=p_drop, masked=masked) for h in range(num_heads)]
        self.linear = nn.Linear(d_model,d_model)
        self.dropout = nn.Dropout(p=p_drop)
    
    def forward(self,X:torch.Tensor,Q:tuple,K:tuple,V:tuple) ->torch.Tensor:
        heads_output = []
        for head_index,head in enumerate(self.heads):
                queries = Q[head_index]
                keys = K[head_index]
                values = V[head_index]
                v = head(queries,keys,values) # this could be distributed to multiple devices for // processing
                heads_output.append(v) # accumulate result
        
        o = torch.cat(heads_output,dim=-1)
        linear_output = self.linear(o)
        dropped_output = self.dropout(linear_output)
        return self.layer_norm(X+dropped_output)


mhsa = MultiHeadSelfAttention(d_model=512,p_drop=0.1,num_heads=8,masked=True)
sample_data = torch.randn((5,10,512))
splits = torch.split(sample_data,64,dim=2)
mhsa(sample_data,splits,splits,splits).shape

torch.Size([5, 10, 512])

# Position-wise Feedforward Network

In [12]:
class FeedForward(nn.Module):
    def __init__(self,d_model:int,p_drop:float,d_ff:int) -> None:
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(in_features=d_model,out_features=d_ff),
            nn.ReLU(),
            nn.Dropout(p=p_drop),
            nn.Linear(in_features=d_ff,out_features=d_model)
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.layer_norm(X+self.ffn(X))



# Encorder

In [13]:
class EncoderLayer(nn.Module):
    def __init__(self,d_model:int,p_drop:float, d_ff:int,num_heads:int,**kwargs) -> None:
        super().__init__()
        self.head_dim = d_model // num_heads
        self.multihead_self_attention = MultiHeadSelfAttention(d_model=d_model,p_drop=p_drop,num_heads=num_heads)
        self.feedforward = FeedForward(d_model=d_model,p_drop=p_drop,d_ff=d_ff)

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        splits = torch.split(X,self.head_dim,dim=2)
        return self.feedforward(self.multihead_self_attention(X,splits,splits,splits))
    
class Encoder(nn.Module):
    def __init__(self, number_of_encoder_blocks:int=6,**kwargs) -> None:
        super().__init__()
        self.encode = nn.Sequential(*[EncoderLayer(**kwargs) for n in range(number_of_encoder_blocks)])

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.encode(X)


# Decoder

In [14]:
class DecoderLayer(nn.Module):
    def __init__(self,d_model:int,p_drop:float,d_ff:int,num_heads:int,**kwargs) -> None:
        super().__init__()
        self.head_dim = d_model // num_heads
        self.masked_multi_head_self_attention = MultiHeadSelfAttention(d_model=d_model,p_drop=p_drop,num_heads=num_heads,masked=True)
        self.multi_head_self_attention = MultiHeadSelfAttention(d_model=d_model,p_drop=p_drop,num_heads=num_heads,masked=False)
        self.feedforward = FeedForward(d_model=d_model,p_drop=p_drop,d_ff=d_ff)

    def forward(self,outputs:torch.Tensor,encoded_sequence:torch.Tensor) -> torch.Tensor:
        output_splits = torch.split(outputs,self.head_dim,dim=2)
        encoded_sequence_splits = torch.split(encoded_sequence,self.head_dim,dim=2)
        masked_output = self.masked_multi_head_self_attention(outputs,output_splits,output_splits,output_splits)
        mhsa_output = self.multi_head_self_attention(masked_output,Q=masked_output,K=encoded_sequence_splits,V=encoded_sequence_splits)
        
        return self.feedforward(mhsa_output)


class Decoder(nn.Module):
    def __init__(self, number_of_decoder_blocks:int, **kwargs) -> None:
        super().__init__()
        self.decoder_layers = nn.ModuleList([DecoderLayer(**kwargs) for n in range(number_of_decoder_blocks)])
        
    def forward(self,outputs:torch.Tensor,encoded_sequence:torch.Tensor) -> torch.Tensor:
        for decoder_layer in self.decoder_layers:
            outputs = decoder_layer(outputs,encoded_sequence)
        return outputs



# Transformer

In [15]:
class Transformer(nn.Module):
    def __init__(self,
                 vocab_size:int,
                 batch_size:int,
                 context_size:int,
                 d_model:int,
                 d_ff:int,
                 num_heads:int,
                 number_of_encoder_blocks:int,
                 number_of_decoder_blocks:int,
                 p_drop:float):
        
        super().__init__()
        self.context_size = context_size

        self.embedding = EmeddingsLayer(d_model=d_model,vocab_size=vocab_size)
        self.positional_encoding = PositionalEncoding(d_model=d_model,context_size=context_size)
        self.dropout = nn.Dropout(p=p_drop)
        self.encoder = Encoder(
                            vocab_size=batch_size,
                            batch_size=batch_size,
                            context_size=context_size,
                            d_model=d_model,
                            d_ff=d_ff,
                            num_heads=num_heads,
                            number_of_encoder_blocks=number_of_encoder_blocks,
                            p_drop=p_drop)
        
        self.decoder = Decoder(
                            vocab_size=batch_size,
                            batch_size=batch_size,
                            context_size=context_size,
                            d_model=d_model,
                            p_drop=p_drop,
                            d_ff=d_ff,
                            num_heads=num_heads,
                            number_of_decoder_blocks=number_of_decoder_blocks)
        
        self.linear = nn.Linear(in_features=d_model,out_features=vocab_size)


    def forward(self,X:torch.Tensor,y:torch.Tensor) -> torch.Tensor:
        pos_encoding = self.positional_encoding()

        input_embeddings = self.embedding(X) 
        inputs = self.dropout(input_embeddings+pos_encoding) # B*T*C
        
        encoded_sequence = self.encoder(inputs)
        output_embedding = self.embedding(y)
        outputs = self.dropout(output_embedding+pos_encoding) # B*T*C

        decoder_output = self.decoder(outputs,encoded_sequence)

        output_logits = self.linear(decoder_output)
        output_probs = torch.softmax(output_logits,dim=-1)


        return output_probs

In [77]:
config = {
    "vocab_size":37000,
    "batch_size":64,
    "context_size":10,
    "d_model":512,
    "num_heads":1,
    "d_ff":2048,
    "number_of_encoder_blocks": 6,
    "number_of_decoder_blocks": 6,
    "p_drop":0.1
}
device = "cuda" if torch.cuda.is_available() else "cpu"

model = Transformer(**config)

X = torch.randint(1,config["vocab_size"],(config["batch_size"],config["context_size"]))
y = torch.randint(1,config["vocab_size"],(config["batch_size"],config["context_size"]))

o = model(X,y)
print(o.argmax(dim=-1)) # will substitute for sampling for now

tensor([[30908, 25044, 10012, 22277, 33097, 26681, 15800,   328,  1235, 17931],
        [11908, 22090,  2252, 21624, 17150,  8975, 15391, 36947, 10469,  7199],
        [23676, 36830, 29603, 15104, 17268, 21152, 31797, 27589, 18469, 12114],
        [33004,  3447, 18985, 23113,  4303, 21491, 29902, 24992, 34827, 14415],
        [31494, 26420, 19301, 27433, 34343, 31011, 29458, 16520,  3884, 24261],
        [35821, 36820, 34717,  9134,  9943, 20816, 10174, 12720, 32344, 30117],
        [ 6371, 12510, 26485, 29164, 24968, 16980,  4347, 14114, 13843, 16131],
        [ 5055, 13493, 36506,  7870,  3387,  4982,  3830, 14695, 13812,  2677],
        [10376, 27180,  3044, 17918, 30099, 31488,  1198, 14573, 34837,  7715],
        [17361, 10007, 14353,  7483, 13443, 24416, 21664, 36106, 36532, 14707],
        [13519, 23756,  8392, 25098, 26130, 25333, 35014, 26205, 17080, 25167],
        [ 3796,  6077,  2738,  8693,  4961, 11041, 24364, 30279, 27484, 11578],
        [ 2599,  5368, 33832, 27697, 299

# Data Preparation

In [22]:
swa_sentences = []
with open("./data/translate/gamayun_kit5k.swa","r") as f:
    swa_sentences = f.readlines()
eng_sentences = []
with open("./data/translate/gamayun_kit5k.eng","r") as f:
    eng_sentences = f.readlines()

swa_sentences = [s.rstrip("\n") for s in swa_sentences]
eng_sentences = [s.rstrip("\n") for s in eng_sentences]

print(f"Size of swahili dataset: {len(swa_sentences)} ")
print(f"Size of english dataset: {len(eng_sentences)} ")
print(f"Max swahili sentence: {max([len(s) for s in swa_sentences])} ")
print(f"Max english sentence: {max([len(s) for s in eng_sentences])} ")

print(swa_sentences[:5])
print(eng_sentences[:5])

Size of swahili dataset: 5000 
Size of english dataset: 5000 
Max swahili sentence: 249 
Max english sentence: 233 
['Huyo ni rafiki yako mpya?', 'Job hana hamu ya mpira wa vikapu.', 'Adam aliniambia kuwa Alice alikuwa na mpenzi mpya wa kiume', 'Radio haikutanga kuhusu ajali hiyo.', 'Adamu ana wasiwasi tutapotea.']
['Is that your new friend?', "Jacob wasn't interested in baseball.", 'Adam told me that Alice had a new boyfriend.', "The radio didn't inform about the accident.", "Adam is worried we'll get lost."]


In [42]:
max_sentence_length = 250

In [33]:
START_TOKEN = '>'
PADDING_TOKEN = '<'
END_TOKEN = '--'

swa_vocab = list(set(''.join(swa_sentences)))
swa_vocab.insert(0,START_TOKEN)
swa_vocab.append(PADDING_TOKEN)
swa_vocab.append(END_TOKEN)
eng_vocab = list(set(''.join(eng_sentences)))
eng_vocab.insert(0,START_TOKEN)
eng_vocab.append(PADDING_TOKEN)
eng_vocab.append(END_TOKEN)

print(swa_vocab)
print(eng_vocab)

print(f"Eng vocab_size :{len(swa_vocab)}")
print(f"Swa vocab_size :{len(eng_vocab)}")

swa_token_to_index = {t:i for i,t in enumerate(swa_vocab)}
print(swa_token_to_index)
swa_index_to_token = {i:t for i,t in enumerate(swa_vocab)}
eng_token_to_index = {t:i for i,t in enumerate(eng_vocab)}
print(eng_token_to_index)
eng_index_to_token = {i:t for i,t in enumerate(eng_vocab)}

['>', 'n', '3', '0', 'o', '7', 'I', 'J', '6', 'l', 'p', 'y', ':', 'f', 'e', 'u', '(', 'm', 'P', 's', 'k', 'U', 'a', 'z', 'F', '/', 'v', 'c', '8', 'D', 'Z', '5', '$', 'O', '&', 'g', 'V', 'Y', ',', '.', 't', 'G', 'B', '1', '?', ';', "'", ')', 'S', '-', 'R', 'W', 'A', '!', 'T', 'j', 'i', 'q', 'r', '”', '"', 'w', '4', 'H', 'C', 'L', 'M', ' ', 'd', 'x', '\u200b', 'Q', 'b', 'N', 'E', 'K', '9', '+', '—', '2', 'h', '<', '--']
['>', 'n', '3', '0', 'o', '’', '7', 'I', 'J', '6', 'l', 'p', 'é', 'y', ':', 'f', 'e', 'u', '(', 'm', 'P', 's', 'k', 'U', 'a', 'z', 'F', 'v', 'c', '8', 'D', '5', 'Z', '$', 'O', '&', 'g', 'V', 'Y', ',', 't', '.', 'G', 'B', '1', '?', ';', "'", ')', 'S', '-', 'W', 'R', 'à', 'A', '!', 'T', 'j', 'i', '_', 'q', 'r', '”', '"', 'w', '4', 'H', 'C', 'L', 'M', ' ', 'd', '°', 'x', 'Q', 'b', '“', 'N', 'E', '9', 'K', '—', '2', 'h', '<', '--']
Eng vocab_size :83
Swa vocab_size :86
{'>': 0, 'n': 1, '3': 2, '0': 3, 'o': 4, '7': 5, 'I': 6, 'J': 7, '6': 8, 'l': 9, 'p': 10, 'y': 11, ':': 12, 

In [34]:
swahili_sentences_tokenized = [[swa_token_to_index[t] for t in s] for s in swa_sentences]
english_sentences_tokenized = [[eng_token_to_index[t] for t in s] for s in eng_sentences]

In [35]:
swahili_sentences_tokenized_train = swahili_sentences_tokenized[:4500]
swahili_sentences_tokenized_test = swahili_sentences_tokenized[4500:]
english_sentences_tokenized_train = english_sentences_tokenized[:4500]
english_sentences_tokenized_test = english_sentences_tokenized[4500:]

In [67]:
class TranslationDataset(Dataset):

    def __init__(self, swahili_sentences, english_sentences,transforms=None):
        self.swahili_sentences = swahili_sentences
        self.english_sentences = english_sentences
        self.transforms = transforms

    def __len__(self):
        return len(self.english_sentences)

    def __getitem__(self, idx):
        eng_sentence = self.english_sentences[idx]
        swa_sentence = self.swahili_sentences[idx]

        for _ in range(len(eng_sentence), max_sentence_length):
            eng_sentence.append(eng_token_to_index[PADDING_TOKEN])
        for _ in range(len(swa_sentence), max_sentence_length):
            swa_sentence.append(swa_token_to_index[PADDING_TOKEN])

        if self.transforms:
            swa_sentence = self.transforms(swa_sentence)
            eng_sentence = self.transforms(eng_sentence)

        # print(eng_sentence.shape,swa_sentence.shape)
        return eng_sentence, swa_sentence

In [68]:
training_dataset = TranslationDataset(swahili_sentences=swahili_sentences_tokenized_train,english_sentences=english_sentences_tokenized_train,transforms=transforms.ToTensor())
testing_dataset = TranslationDataset(swahili_sentences=swahili_sentences_tokenized_test,english_sentences=english_sentences_tokenized_test,transforms=transforms.ToTensor())
training_dataloader = DataLoader(training_dataset,batch_size=10,shuffle=True)
testing_dataloader = DataLoader(testing_dataset,batch_size=10,shuffle=False)

eng_sentence,swa_sentence = next(iter(training_dataloader))
print(eng_sentence.shape,swa_sentence.shape)
print(eng_sentence[0])

torch.Size([10, 250]) torch.Size([10, 250])
tensor([69, 13, 70, 15, 61, 58, 16,  1, 71, 21, 70, 64, 24,  1, 40, 16, 71, 70,
        71, 16, 40, 24, 58, 10, 21, 41, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84,
        84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84,
        84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84,
        84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84,
        84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84,
        84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84,
        84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84,
        84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84,
        84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84,
        84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84,
        84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 

# Train Transformer


In [None]:
config = {
    "vocab_size":37000,
    "batch_size":64,
    "context_size":10,
    "d_model":512,
    "num_heads":1,
    "d_ff":2048,
    "number_of_encoder_blocks": 6,
    "number_of_decoder_blocks": 6,
    "p_drop":0.1
}

EPOCHS = 1





In [None]:
def create_masks(eng_batch, kn_batch):
    num_sentences = len(eng_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      eng_sentence_length, kn_sentence_length = len(eng_batch[idx]), len(kn_batch[idx])
      eng_chars_to_padding_mask = np.arange(eng_sentence_length + 1, max_sequence_length)
      kn_chars_to_padding_mask = np.arange(kn_sentence_length + 1, max_sequence_length)
      encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
      encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, kn_chars_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, kn_chars_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, kn_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    print(f"encoder_self_attention_mask {encoder_self_attention_mask.size()}: {encoder_self_attention_mask[0, :10, :10]}")
    print(f"decoder_self_attention_mask {decoder_self_attention_mask.size()}: {decoder_self_attention_mask[0, :10, :10]}")
    print(f"decoder_cross_attention_mask {decoder_cross_attention_mask.size()}: {decoder_cross_attention_mask[0, :10, :10]}")
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask
