In [41]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import polars as po
import numpy as np
import pytorch_lightning as pl

In [42]:
from nltk.tokenize import word_tokenize
from nltk.probability import FreqDist
from transformers import GPT2Tokenizer
from torch.utils.data import DataChunk
from dataclasses import dataclass


In [43]:

class datasetobj:
    """Tokenised English/French sentence pairs read from a headerless CSV.

    The CSV is loaded with polars and its two columns are named ``"en"`` and
    ``"fr"`` (in that order). Sentences are tokenised lazily, one row at a
    time, with the pretrained GPT-2 tokenizer.
    """

    def __init__(self, df_path):
        """Load the CSV at ``df_path`` and the GPT-2 tokenizer.

        Args:
            df_path: location of the two-column CSV, passed to ``polars.read_csv``.
        """
        self.df = po.read_csv(df_path, has_header=False, new_columns=["en", "fr"])
        # Frequency tables are created empty; nothing in this class fills them.
        self.vocab_en = FreqDist()
        self.vocab_fr = FreqDist()
        self.tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")

    def __len__(self):
        """Return the number of sentence pairs (rows of the frame)."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(en_ids, fr_ids)`` token-id sequences for row ``index``."""
        en = self.tokenizer(self.df[index, 0])["input_ids"]
        # Column 1 holds the French sentence; the previous code read column 0
        # twice and therefore returned the English tokens for both languages.
        fr = self.tokenizer(self.df[index, 1])["input_ids"]
        return en, fr

### Dataset Download


In [44]:
class EnFrDataset(Dataset):
    """``torch.utils.data.Dataset`` adapter around a sentence-pair container.

    The wrapped object only needs to support ``len()`` and integer indexing;
    items are returned exactly as the wrapped object produces them.
    """

    def __init__(self, ds: "datasetobj"):
        """Store the wrapped dataset.

        Args:
            ds: indexable, sized container of sentence pairs.
        """
        super().__init__()
        self.dataset = ds

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        # ``len()`` works for ``datasetobj`` (which has no ``height`` attribute)
        # and for any other sized container.
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the item at ``index`` from the wrapped dataset."""
        # Previously this returned ``None`` for every index.
        return self.dataset[index]

In [45]:
## learned positional_encoding, like in GPT-2
class PositionalEncodings(nn.Module):
    """Trainable lookup table that maps a position index to a vector."""

    def __init__(self, max_seq_len: int, hidden_size: int):
        """Create a table with ``max_seq_len`` rows of width ``hidden_size``."""
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, hidden_size)

    def forward(self, positions):
        """Look up the embedding of each index in ``positions``.

        ``positions`` must hold integer indices accepted by ``nn.Embedding``.
        """
        position_vectors = self.pos_emb(positions)
        return position_vectors

In [46]:
@dataclass
class attnconfig:
    """Size settings consumed by the attention layers in this notebook."""

    # Input feature widths of the per-head query / key / value projections.
    query_dim: int
    key_dim: int
    value_dim: int
    # Total output width; each head projects to ``model_dim // n_heads``.
    model_dim: int
    # Number of attention heads.
    n_heads: int
    # Declared flag; the attention classes visible here do not read it.
    causal_mask: bool = False
    
    
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention with one ``Linear`` triple per head.

    Each head projects queries, keys and values to ``model_dim // n_heads``
    features, attends, and the head outputs are concatenated on the last axis.
    No output projection is applied after the concatenation.
    """

    def __init__(self, config: "attnconfig"):
        """Build the per-head projections described by ``config``.

        Args:
            config: object exposing ``query_dim``, ``key_dim``, ``value_dim``,
                ``model_dim`` and ``n_heads``.
        """
        super().__init__()
        head_dim = config.model_dim // config.n_heads
        self.Wq = nn.ModuleList([nn.Linear(config.query_dim, head_dim) for _ in range(config.n_heads)])
        self.Wk = nn.ModuleList([nn.Linear(config.key_dim, head_dim) for _ in range(config.n_heads)])
        self.Wv = nn.ModuleList([nn.Linear(config.value_dim, head_dim) for _ in range(config.n_heads)])
        self.sf = nn.Softmax(dim=-1)
        self.config = config

    def _attend(self, query_vector, key_vector, value_vector, causal):
        """Run every head and concatenate the results.

        Args:
            query_vector: tensor of shape ``(..., n_queries, query_dim)``.
            key_vector: tensor of shape ``(..., n_keys, key_dim)``.
            value_vector: tensor of shape ``(..., n_keys, value_dim)``.
            causal: when true, query ``i`` may only attend to keys ``0..i``.

        Returns:
            Tensor of shape ``(..., n_queries, n_heads * head_dim)``.
        """
        head_outputs = []
        for proj_q, proj_k, proj_v in zip(self.Wq, self.Wk, self.Wv):
            q = proj_q(query_vector)
            k = proj_k(key_vector)
            v = proj_v(value_vector)
            # Scale by 1/sqrt(head_dim). The previous code divided by
            # sqrt(1/d), i.e. multiplied the scores by sqrt(d). Transposing the
            # last two axes (instead of ``.T``) keeps leading batch axes intact.
            scores = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
            if causal:
                # Mask future positions *before* the softmax so every row of
                # the attention matrix is still a probability distribution.
                future = torch.ones(
                    scores.shape[-2:], dtype=torch.bool, device=scores.device
                ).triu(diagonal=1)
                scores = scores.masked_fill(future, float("-inf"))
            head_outputs.append(torch.matmul(self.sf(scores), v))
        return torch.cat(head_outputs, dim=-1)

    def forward(self, query_vector, key_vector, value_vector):
        """Unmasked attention; see ``_attend`` for shapes."""
        return self._attend(query_vector, key_vector, value_vector, causal=False)


class MaskedMultiHeadAttention(MultiHeadedAttention):
    """Causal variant: each query position ignores later key positions."""

    def forward(self, query_vector, key_vector, value_vector):
        """Causally masked attention; see ``MultiHeadedAttention._attend``."""
        return self._attend(query_vector, key_vector, value_vector, causal=True)

In [47]:
class ResMLP(nn.Module):
    """Residual block: ``x + f(x)`` where ``f`` is a stack of square ``Linear`` layers.

    A non-linearity is applied between consecutive layers. Without it the
    whole stack collapses to a single affine map regardless of ``num_layers``.
    With ``num_layers`` of 0 or 1 no activation is ever applied, so those
    configurations behave exactly as before.
    """

    def __init__(self, input_size: int, num_layers: int, activation=None):
        """Build the layer stack.

        Args:
            input_size: feature width of the input and of every layer.
            num_layers: number of ``Linear(input_size, input_size)`` layers.
            activation: optional module applied between layers; defaults to
                ``nn.ReLU()``.
        """
        super().__init__()
        self.Linears = nn.ModuleList([nn.Linear(input_size, input_size) for _ in range(num_layers)])
        self.activation = nn.ReLU() if activation is None else activation

    def forward(self, x):
        """Return ``x`` plus the output of the layer stack applied to ``x``."""
        res = x
        for position, layer in enumerate(self.Linears):
            if position > 0:
                # Only between layers: the last layer's output stays linear so
                # the residual sum is not clipped.
                x = self.activation(x)
            x = layer(x)
        return res + x
@dataclass
class EncoderConfig:
    """Hyper-parameters for ``TransformerEncoder``.

    ``atn_cfg`` is derived from ``embedding_size`` and ``num_heads`` of *this
    instance* when it is not supplied. The previous class-level default was
    built once from the class-body constants, so it ignored constructor
    overrides, was shared by every instance, and is rejected as a mutable
    (unhashable) dataclass default on Python 3.11+.
    """

    num_heads: int = 4
    vocab_size: int = 50762
    embedding_size: int = 768
    max_seq_len: int = 200
    # ``None`` means "derive from the fields above" (see ``__post_init__``).
    atn_cfg: "attnconfig" = None
    # Multiplier applied to the positional embeddings before they are added.
    pos_weight: float = 0.2
    mlp_depth: int = 1

    def __post_init__(self):
        """Fill in ``atn_cfg`` from this instance's sizes when it was omitted."""
        if self.atn_cfg is None:
            self.atn_cfg = attnconfig(
                query_dim=self.embedding_size,
                key_dim=self.embedding_size,
                value_dim=self.embedding_size,
                model_dim=self.embedding_size,
                n_heads=self.num_heads,
            )
      
      
class TransformerEncoder(nn.Module):
    """Single encoder layer: embeddings, self-attention, residual MLP.

    Token embeddings are summed with learned positional embeddings scaled by
    ``config.pos_weight``, passed through self-attention with a residual
    connection and layer norm, then through a residual MLP and a second
    layer norm.
    """

    def __init__(self, config: EncoderConfig):
        """Build the sub-modules from ``config``."""
        super().__init__()
        self.Embedding = nn.Embedding(config.vocab_size, config.embedding_size)
        self.PositionalEncoding = PositionalEncodings(config.max_seq_len, config.embedding_size)
        self.attn_head = MultiHeadedAttention(config=config.atn_cfg)
        self.layer_norm1 = nn.LayerNorm(config.embedding_size)
        self.res1 = ResMLP(input_size=config.embedding_size, num_layers=config.mlp_depth)
        self.layer_norm2 = nn.LayerNorm(config.embedding_size)
        self.encodercfg = config

    def forward(self, x):
        """Encode a sequence of token ids.

        Args:
            x: integer tensor of token ids whose last axis is the sequence
                axis, e.g. shape ``(seq_len,)``. Whether a leading batch axis
                works end-to-end depends on the attention layer accepting it.

        Returns:
            Tensor with the shape of ``x`` plus a trailing ``embedding_size`` axis.
        """
        embs = self.Embedding(x)
        # Build the position indices on the same device as the input so the
        # lookup also works when the module lives on an accelerator, and take
        # the length from the sequence (last) axis rather than axis 0.
        positions = torch.arange(x.shape[-1], device=x.device)
        pos_embs = self.PositionalEncoding(positions)
        embs = embs + self.encodercfg.pos_weight * pos_embs
        embs = self.layer_norm1(self.attn_head(embs, embs, embs) + embs)
        embs = self.layer_norm2(self.res1(embs))
        return embs
    
    
    

In [48]:
@dataclass
class DecoderConfig:
    """Hyper-parameters for ``TransformerDecoder``.

    ``atn_cfg`` is derived from ``embedding_size`` and ``num_heads`` of *this
    instance* when it is not supplied. The previous class-level default was
    built once from the class-body constants, so it ignored constructor
    overrides, was shared by every instance, and is rejected as a mutable
    (unhashable) dataclass default on Python 3.11+.
    """

    num_heads: int = 4
    vocab_size: int = 50762
    embedding_size: int = 768
    max_seq_len: int = 200
    # ``None`` means "derive from the fields above" (see ``__post_init__``).
    atn_cfg: "attnconfig" = None
    # Positional-embedding multiplier; stored here, not read by this class.
    pos_weight: float = 0.2
    mlp_depth: int = 1

    def __post_init__(self):
        """Fill in ``atn_cfg`` from this instance's sizes when it was omitted."""
        if self.atn_cfg is None:
            self.atn_cfg = attnconfig(
                query_dim=self.embedding_size,
                key_dim=self.embedding_size,
                value_dim=self.embedding_size,
                model_dim=self.embedding_size,
                n_heads=self.num_heads,
            )
        
    
class TransformerDecoder(nn.Module):
    """Sub-modules of a decoder layer.

    Only the constructor is defined here; there is no ``forward`` yet, so
    calling an instance falls through to ``nn.Module``'s default.
    """

    def __init__(self, config: DecoderConfig):
        """Build the decoder sub-modules from ``config``."""
        super().__init__()
        self.Embedding = nn.Embedding(config.vocab_size, config.embedding_size)
        self.PositionalEncoding = PositionalEncodings(config.max_seq_len, config.embedding_size)
        # Causally masked self-attention over the decoder's own tokens.
        self.attn_head = MaskedMultiHeadAttention(config.atn_cfg)
        self.res1 = nn.Sequential(
            ResMLP(input_size=config.embedding_size, num_layers=config.mlp_depth),
            nn.LayerNorm(config.embedding_size),
        )
        # Unmasked attention; by its name, intended for the encoder outputs.
        self.encoder_inputs = MultiHeadedAttention(config.atn_cfg)
        self.res2 = nn.Sequential(
            ResMLP(input_size=config.embedding_size, num_layers=config.mlp_depth),
            nn.LayerNorm(config.embedding_size),
        )
        # ``res3`` gets its own module. The previous chained assignment
        # (``self.res3 = self.res2 = ...``) bound both names to one object,
        # tying their weights and discarding the first ``res2``.
        self.res3 = nn.Sequential(
            ResMLP(input_size=config.embedding_size, num_layers=config.mlp_depth),
            nn.LayerNorm(config.embedding_size),
        )
        # NOTE(review): this head emits probabilities. If training uses
        # ``nn.CrossEntropyLoss`` it expects raw logits instead - confirm.
        self.fc = nn.Sequential(
            nn.Linear(config.embedding_size, config.vocab_size),
            nn.Softmax(-1),
        )

        self.decodercfg = config

##### Test Functions

In [49]:

# Two small configurations: a two-head layer and a single-head layer.
multihead_config = attnconfig(query_dim=20, key_dim=20, value_dim=20, model_dim=20, n_heads=2)
selfattn_config = attnconfig(query_dim=20, key_dim=20, value_dim=20, model_dim=20, n_heads=1)


def test_mha(config: attnconfig = selfattn_config):
    """Check that ``MultiHeadedAttention`` matches ``nn.MultiheadAttention`` in output shape.

    Both layers are fed the same random, unbatched query/key/value matrices.
    The shapes are printed and then asserted, so a regression fails the cell
    instead of only changing the printed text.
    """
    mha_layer = MultiHeadedAttention(config)
    torch_mha = nn.MultiheadAttention(config.model_dim, config.n_heads)
    N_vectors = 5
    query_vector = torch.randn([N_vectors, config.query_dim])
    key_vector = torch.randn([N_vectors, config.key_dim])
    value_vector = torch.randn([N_vectors, config.value_dim])
    with torch.no_grad():
        output = mha_layer(query_vector, key_vector, value_vector)
        # The reference layer also returns attention weights; they are unused.
        torchput, _ = torch_mha(query_vector, key_vector, value_vector)
    print(f"This is my mha op shape: {output.shape}") ## should be n x value vector size
    print(f"This is torch mha output shape {torchput.shape}")
    expected_shape = (N_vectors, config.model_dim)
    assert output.shape == expected_shape, f"unexpected output shape {tuple(output.shape)}"
    assert torchput.shape == expected_shape, f"unexpected reference shape {tuple(torchput.shape)}"


test_mha(selfattn_config)
test_mha(multihead_config)

This is my mha op shape: torch.Size([5, 20])
This is torch mha output shape torch.Size([5, 20])
This is my mha op shape: torch.Size([5, 20])
This is torch mha output shape torch.Size([5, 20])


In [50]:
encfg = EncoderConfig(num_heads=4, vocab_size=50762, embedding_size=768, max_seq_len=2000, pos_weight=0.2)


def test_encoder(encodercfg: EncoderConfig):
    """Run one random token sequence through ``TransformerEncoder`` and check its shape."""
    trans = TransformerEncoder(encodercfg)
    # The encoder treats axis 0 of a 1-D input as sequence positions, so this
    # is a sequence length rather than a batch size.
    seq_len = 10
    input_tokens = torch.randint(low=0, high=encodercfg.vocab_size, size=(seq_len,))
    with torch.no_grad():
        output = trans(input_tokens)
    print(output.shape)
    expected_shape = (seq_len, encodercfg.embedding_size)
    assert output.shape == expected_shape, f"unexpected encoder output shape {tuple(output.shape)}"


test_encoder(encfg)

torch.Size([10, 768])


In [None]:
class Transformer(pl.LightningModule):
    """Placeholder Lightning module; only the constructor is defined so far."""
    def __init__(self, *args, **kwargs):
        """Forward all positional and keyword arguments to ``pl.LightningModule``."""
        super().__init__(*args, **kwargs)

In [None]:
class EnglishFrenchDataModule(pl.LightningDataModule):
    """Placeholder Lightning data module; only the constructor is defined so far."""
    def __init__(self):
        """Initialise the base ``pl.LightningDataModule`` with no extra state."""
        super().__init__()