In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import polars as po
import numpy as np
import pytorch_lightning as pl
import lightning

  from .autonotebook import tqdm as notebook_tqdm


In [26]:
from nltk.tokenize import word_tokenize
from nltk.probability import FreqDist
from transformers import GPT2Tokenizer
from dataclasses import dataclass
from torch.utils.data import DataChunk


In [27]:
class datasetobj:
    def __init__(self,df_path):
        self.df = po.read_csv(df_path, has_header=False, new_columns=["en", "fr"])
        self.vocab_en = FreqDist()
        self.vocab_fr = FreqDist()
        self.tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
    def __len__(self):
        return(len(self.df))
    def __getitem__(self,index):
        en = self.tokenizer(self.df[index,0])["input_ids"]
        fr = self.tokenizer(self.df[index,1])["input_ids"]
        return en,fr

In [28]:
import os
from torch.utils.data import random_split
class EnFinnishDataset(torch.utils.data.Dataset):
    def __init__(self,archive_path:str,context_len:int=512):
        super().__init__()
        with open(os.path.join(archive_path,"EUbookshop.en")) as fp:
            self.english_corpus = fp.readlines()
        with open(os.path.join(archive_path,"EUbookshop.fi")) as fp:
            self.finnish_corpus = fp.readlines()
        self.tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
        print(self.tokenizer.pad_token)      
        print(self.tokenizer.pad_token_id)     

    def __getitem__(self, index):
        en_tokens = self.tokenizer(self.english_corpus[index],
                                   padding="max_length",
                                   max_length=512)["input_ids"]
        finnish_tokens = self.tokenizer(self.finnish_corpus[index],padding="max_length",max_length=512)["input_ids"]
        return (en_tokens,finnish_tokens)

@dataclass
class DataModuleConfig:
    archive_path:str="EUbookshop-1"
    batch_size:int=32
    train_test:float=0.8
    train_val:float=0.8 
    context_len:int=512
     
class EnFinDataModule(lightning.LightningDataModule):
    def __init__(self,
                 config:DataModuleConfig):
        super().__init__()
        self.config = config
    def setup(self, stage):
        FullDataset = EnFinnishDataset(self.config.archive_path)
        self.train_ds,self.valds = random_split(FullDataset,lengths=[self.config.train_test*len(FullDataset),(1-self.config.train_test)*len(FullDataset)])
    def train_dataloader(self):
        return DataLoader(self.train_ds,batch_size=self.config.batch_size)
    def val_dataloader(self):
        return DataLoader(self.train_ds,batch_size=self.config.batch_size)

In [29]:
def testDataset(archive:str="EUbookshop-1"):
    ds = EnFinnishDataset(archive)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.add_special_tokens({'pad_token': '<pad>'})
    for i in range(5):
        en_tokens,fi_tokens  = ds[i]
        print(en_tokens)        
        print(f"Number of en tokens : {len(en_tokens)}")
        print(f"Number of fi tokens : {len(fi_tokens)}")
        print(f"English: {tokenizer.decode(en_tokens)}")
        print(f"Finnish: {tokenizer.decode(fi_tokens)}")
        
testDataset()


<pad>
50257
[18467, 1044, 198, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 502

### Dataset Download


In [45]:
## learned positional_encoding, like in GPT-2
class PositionalEncodings(nn.Module):
    def __init__(self,max_seq_len:int,hidden_size:int):
        super().__init__()
        self.pos_emb = nn.Embedding(num_embeddings=max_seq_len,embedding_dim=hidden_size)
    def forward(self,positions):
        ## positions would have to be a set of indices
        return(self.pos_emb(positions))

In [None]:

def rotation_matrix(m:int,context_len:int=512):
    thetas = torch.pow(1e4,torch.arange(start=0,end=-context_len+2,step=-2)/context_len)
    cosines = torch.cos_(m*thetas)
    sines = torch.cos_(m*thetas)
    ro_pe = torch.zeros(size=(context_len,context_len))
    
class RotaryPE(nn.Module):
    def __init__(self, ):
        super().__init__()
    def forward(self,positions):
        ## positions is an 

In [46]:
@dataclass
class attnconfig:
    query_dim:int
    key_dim:int
    value_dim:int
    model_dim:int
    n_heads:int
    causal_mask:bool=False
    
    
class MultiHeadedAttention(nn.Module):
    def __init__(self,config:attnconfig):
        super().__init__()
        self.Wq = nn.ModuleList([nn.Linear(config.query_dim,config.model_dim//config.n_heads) for _ in range(config.n_heads)])
        self.Wk = nn.ModuleList([nn.Linear(config.key_dim,config.model_dim//config.n_heads) for _ in range(config.n_heads)])
        self.Wv = nn.ModuleList([nn.Linear(config.value_dim,config.model_dim//config.n_heads) for _ in range(config.n_heads)])
        self.sf = nn.Softmax(dim=-1)
        self.config=config
    def forward(self,query_vector,key_vector,value_vector):
        output =[]
        for i in range(self.config.n_heads):
            q=self.Wq[i](query_vector)
            k=self.Wk[i](key_vector)
            v=self.Wv[i](value_vector)
            A = self.sf(torch.matmul(q,k.T)/(np.sqrt(1/q.shape[1])))
            output.append(A@v)
            # print(f"Shape of head {i} output; {output[-1].shape}")
        return(torch.cat(output,dim=-1))

class MaskedMultiHeadAttention(MultiHeadedAttention):
    def forward(self,query_vector,key_vector,value_vector):
        output =[]
        for i in range(self.config.n_heads):
            q=self.Wq[i](query_vector)
            k=self.Wk[i](key_vector)
            v=self.Wv[i](value_vector)
            A = self.sf(torch.matmul(q,k.T)/(np.sqrt(1/q.shape[1])))
            A= torch.tril(input=A)
            output.append(A@v)
            # print(f"Shape of head {i} output; {output[-1].shape}")
        return(torch.cat(output,dim=-1))

In [47]:
class ResMLP(nn.Module):
    def __init__(self, input_size:int,num_layers:int):
        super().__init__()
        self.Linears = nn.ModuleList([nn.Linear(input_size,input_size) for _ in range(num_layers)])
    def forward(self,x):
        res =x
        for i in range(len(self.Linears)):
            x = self.Linears[i](x)  
        return res+x
@dataclass
class EncoderConfig:
      num_heads:int=4
      vocab_size:int=50762
      embedding_size:int=768
      max_seq_len:int=200
      atn_cfg:attnconfig=attnconfig(query_dim=embedding_size,key_dim=embedding_size,value_dim=embedding_size,model_dim=embedding_size,n_heads=num_heads)
      pos_weight:int=0.2
      mlp_depth:int=1
      
      

class TransformerEncoderBlock(nn.Module):
    def __init__(self,config:EncoderConfig):
        super().__init__()
        self.Embedding = nn.Embedding(config.vocab_size,config.embedding_size)
        self.PositionalEncoding =  PositionalEncodings(config.max_seq_len,config.embedding_size)
        self.attn_head = MultiHeadedAttention(config=config.atn_cfg)
        self.layer_norm1 = nn.LayerNorm(config.embedding_size)
        self.res1 = ResMLP(input_size=config.embedding_size,num_layers=config.mlp_depth)
        self.layer_norm2 = nn.LayerNorm(config.embedding_size)
        self.encodercfg = config
    
    def forward(self,x):
        embs = self.Embedding(x)
        pos_embs = self.PositionalEncoding(torch.arange(0,end=x.shape[0]))
        embs = embs + self.encodercfg.pos_weight*pos_embs
        embs = self.layer_norm1(self.attn_head(embs,embs,embs) + embs)
        embs = self.layer_norm2(self.res1(embs))
        return embs
    
    
    

In [48]:
@dataclass 
class DecoderConfig:
    num_heads:int=4
    vocab_size:int=50762
    embedding_size:int=768
    max_seq_len:int=200
    atn_cfg:attnconfig=attnconfig(query_dim=embedding_size,key_dim=embedding_size,value_dim=embedding_size,model_dim=embedding_size,n_heads=num_heads)
    pos_weight:int=0.2
    mlp_depth:int=1
    
class TransformerDecoderBlock(nn.Module):
    def __init__(self,config:DecoderConfig):
        super().__init__()
        self.Embedding = nn.Embedding(config.vocab_size,config.embedding_size)
        self.PositionalEncoding =  PositionalEncodings(config.max_seq_len,config.embedding_size)
        self.attn_head = MaskedMultiHeadAttention(config.atn_cfg)
        self.res1 = nn.Sequential( ResMLP(input_size=config.embedding_size,num_layers=config.mlp_depth),
                                  nn.LayerNorm(config.embedding_size))
        self.CrossAttention = MultiHeadedAttention(config.atn_cfg)
        
        self.fc = nn.Sequential(nn.Linear(config.embedding_size,config.vocab_size),
                                nn.Softmax(-1))
        self.layer_norm1 = nn.LayerNorm(config.embedding_size)
        self.layer_norm2 = nn.LayerNorm(config.embedding_size)
        self.layer_norm3 = nn.LayerNorm(config.embedding_size)
        self.decodercfg = config
        
    def forward(self,tokens,encoder_output):
        token_embeddings =  self.Embedding(tokens)
        pos_embs = self.PositionalEncoding(torch.arange(0,end=tokens.shape[0]))
        embs = embs + self.decodercfg.pos_weight*pos_embs
        embs = self.layer_norm1(self.attn_head(embs,embs,embs) + embs)

        embs = self.layernorm2(embs +  self.CrossAttention(embs,encoder_output,encoder_output))
        embs  = self.layer_norm3(self.res1(embs))
        return self.fc(embs)
        



In [None]:
@dataclass 
class TransformerConfig:
    encoder_cfg:EncoderConfig
    decoder_cfg:DecoderConfig

class Transformer(lightning.LightningModule):
    def __init__(self, config:TransformerConfig):
        self.encoder =  TransformerEncoderBlock(config.encoder_cfg)
        self.decoder = TransformerDecoderBlock(config.decoder_cfg)
        self.loss =  torch.nn.CrossEntropyLoss()
    def forward(self,eng_tokens,fin_tokens):
        encoder_outputs = self.encoder(eng_tokens)
        ## need to right shift the tokens.
        
        decoder_output = self.decoder(fin_tokens,encoder_outputs)
        return decoder_output
    def training_step(self,batch,batch_idx):
        ## need to consider batches.
        ## ?
        en,fin = batch
        
        return super().training_step()
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
    
        
    

##### Test Functions

In [49]:

multihead_config = attnconfig(query_dim=20,key_dim=20,value_dim=20,model_dim=20,n_heads=2)
selfattn_config = attnconfig(query_dim=20,key_dim=20,value_dim=20,model_dim=20,n_heads=1)

def test_mha(config:attnconfig=selfattn_config):
    mha_layer = MultiHeadedAttention(config)
    torch_mha = nn.MultiheadAttention(config.model_dim,config.n_heads)
    N_vectors = 5
    query_vector = torch.randn([N_vectors,config.query_dim])
    key_vector = torch.randn([N_vectors,config.key_dim])
    value_vector = torch.randn([N_vectors,config.value_dim])
    with torch.no_grad():
        output = mha_layer(query_vector,key_vector,value_vector)
        torchput,attn_weights = torch_mha(query_vector,key_vector,value_vector)
    print(f"This is my mha op shape: {output.shape}") ## should be n x value vector size
    print(f"This is torch mha output shape {torchput.shape}")

test_mha(selfattn_config)
test_mha(multihead_config)

This is my mha op shape: torch.Size([5, 20])
This is torch mha output shape torch.Size([5, 20])
This is my mha op shape: torch.Size([5, 20])
This is torch mha output shape torch.Size([5, 20])


In [50]:
encfg=EncoderConfig(num_heads=4,vocab_size=50762,embedding_size=768,max_seq_len=2000,pos_weight=0.2)
def test_encoder(encodercfg:EncoderConfig):
    trans = TransformerEncoderBlock(encodercfg)
    batch_size=10
    input_tokens = torch.randint(low=0,high=encodercfg.vocab_size,size=(batch_size,))
    with torch.no_grad():
        output = trans(input_tokens)
    print(output.shape)
test_encoder(encfg)

torch.Size([10, 768])


In [15]:
import torch
from torch import nn
class RoPE(nn.Module):
    def __init__(self, embedding_dim:int,context_len:int):
        super().__init__( )
        self.d = embedding_dim
        thetas = torch.arange(start=0,end=context_len,step=1,dtype=torch.float).view(-1,1) @torch.pow(1e5,-2*torch.arange(start=0,end=self.d-1,step=2)/self.d).repeat_interleave(2).view(1,-1)
        ## this should be an context_len x d size matrix 
        print(f"Shape of theta matrix is : {thetas.shape}")
        self.costhetas  = torch.cos(thetas)
        self.sinethetas = torch.sin(thetas)
        self.even_idx = torch.arange(start=0,end=self.d,step=2,dtype=torch.int)
        self.odd_idx = torch.arange(start=1,end=self.d,step=2,dtype=torch.int)

    def interswap(self,token_embedding):
        odds =  token_embedding[...,self.odd_idx]
        evens = token_embedding[...,self.even_idx]
        token_embedding[...,self.odd_idx] =  -1*evens
        token_embedding[...,self.even_idx] = odds
        return token_embedding
    
    def forward(self,token_embeddings):
        output = token_embeddings*self.costhetas + self.interswap(token_embeddings)*self.sinethetas
        return output


class RelativePE(nn.Module):
    def __init__(self, embedding_dim:int,context_len:int):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=context_len,embedding_dim=embedding_dim)
        
        
def test_rope():
    batch_dim=32
    num_tokens=40
    embedding_dim=512
    token_embedding = torch.randn(size=(batch_dim,num_tokens,embedding_dim))
    rope = RoPE(embedding_dim=embedding_dim,context_len=num_tokens)
    output = rope(token_embedding)
     
test_rope()

Shape of theta matrix is : torch.Size([40, 512])


In [5]:
torch.arange(10).view(1,-1)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [27]:
import torch 
from transformers import GPT2Tokenizer
import os
import numpy as np 

ARCHIVE_PATH = "EUbookshop-1"
CONTEXT_LEN = 512


class EnFinnishDataset(torch.utils.data.Dataset):
    def __init__(self,archive_path:str,context_len:int=512):
        super().__init__()
        with open(os.path.join(archive_path,"EUbookshop.en")) as fp:
            self.english_corpus = fp.readlines()
        with open(os.path.join(archive_path,"EUbookshop.fi")) as fp:
            self.finnish_corpus = fp.readlines()
        self.tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
        print(f"Vocab size pre-addition: {len(self.tokenizer)}")
        self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
        self.context_len = context_len
        print(f"Vocab size pre-addition: {len(self.tokenizer)}")
        
        print(self.tokenizer.pad_token)      
        print(self.tokenizer.pad_token_id)     

    def return_masks(self,pad_idx:int):
        pad_masks = torch.zeros(size=(self.context_len,self.context_len))
        pad_masks[:,pad_idx:] = -torch.inf
        pad_masks[pad_idx:,:] = -torch.inf
        return pad_masks
    def __getitem__(self, index):
        en_tokens = torch.tensor(self.tokenizer(self.english_corpus[index],padding="max_length",max_length=self.context_len)["input_ids"])
        finnish_tokens = torch.tensor(self.tokenizer(self.finnish_corpus[index],padding="max_length",max_length=self.context_len)["input_ids"])
        en_pad_index = torch.where(en_tokens==self.tokenizer.pad_token_id)[0][0]
        fin_pad_index = torch.where(finnish_tokens == self.tokenizer.pad_token_id)[0][0]
        en_pad_masks = self.return_masks(en_pad_index)
        fin_pad_masks = self.return_masks(fin_pad_index)
        # en_pad_masks = torch.concat([torch.full((en_pad_index,),fill_value=0),torch.full((self.context_len-en_pad_index,),-torch.inf)])
        # fin_pad_masks = torch.concat([torch.full((fin_pad_index,),fill_value=0),torch.full((self.context_len-fin_pad_index,),-1*torch.inf)])
        # en_pad_masks = en_pad_masks.view(-1,1)@torch.concat([torch.zeros((fin_pad_index,)),torch.ones(self.context_len-fin_pad_index)])
        # print(en_pad_masks)
        return (en_tokens,en_pad_masks,finnish_tokens,fin_pad_masks)
    def __len__(self):
        return len(self.english_corpus)


def test_dataloading_and_item_shape():
    """
    Tests the instantiation of the EnFinnishDataset and verifies the structure,
    type, and shape of a single item fetched from it.
    """
    dataset = EnFinnishDataset(archive_path=ARCHIVE_PATH, context_len=CONTEXT_LEN)
    print(f"This is the vocabulary size: {dataset.tokenizer.vocab_size}")
    # 1. Assert that the dataset object is created and is not empty.
    assert dataset is not None, "Dataset object could not be instantiated."
    assert len(dataset) > 0, "Dataset is empty after loading."

    # 2. Retrieve a single sample to check its integrity.
    sample = dataset[0]
    assert isinstance(sample, tuple) and len(sample) == 4, f"Dataset sample should be a tuple of 4 elements, but got {type(sample)} of length {len(sample)}."
    
    en_tokens, en_mask, fin_tokens, fin_mask = sample
    
    # 3. Assert that all parts of the sample are tensors with the correct shape.
    expected_shape = [torch.Size([CONTEXT_LEN]),torch.Size([CONTEXT_LEN,CONTEXT_LEN])]*2
    checks=[("English tokens", en_tokens), ("English mask", en_mask), ("Finnish tokens", fin_tokens), ("Finnish mask", fin_mask)]
    for i,(name, tensor) in enumerate(checks):
        assert isinstance(tensor, torch.Tensor), f"{name} is not a torch.Tensor."
        assert tensor.shape == expected_shape[i], (
            f"{name} shape is incorrect.\n"
            f"Expected: {expected_shape}, Got: {tensor.shape}"
        )

test_dataloading_and_item_shape()

loading file vocab.json from cache at /home/saigum/.cache/huggingface/hub/models--openai-community--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/vocab.json
loading file merges.txt from cache at /home/saigum/.cache/huggingface/hub/models--openai-community--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/merges.txt
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at /home/saigum/.cache/huggingface/hub/models--openai-community--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/tokenizer_config.json
loading file tokenizer.json from cache at /home/saigum/.cache/huggingface/hub/models--openai-community--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e/tokenizer.json
loading file chat_template.jinja from cache at None
loading configuration file config.json from cache at /home/saigum/.cache/huggingface/hub/models--openai-community--gpt2/snapshots/607a30d783dfa6

Vocab size pre-addition: 50257
Vocab size pre-addition: 50258
<pad>
50257
This is the vocabulary size: 50257


In [None]:

class EnFinnishDataset(torch.utils.data.Dataset):
    def __init__(self,archive_path:str,context_len:int=512):
        super().__init__()
        with open(os.path.join(archive_path,"EUbookshop.en")) as fp:
            self.english_corpus = fp.readlines()
        with open(os.path.join(archive_path,"EUbookshop.fi")) as fp:
            self.finnish_corpus = fp.readlines()
        self.tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '<pad>',"bos_token": "<bos>"})        
        print("PAD token:", self.tokenizer.pad_token, self.tokenizer.pad_token_id)
        print("EOS token:", self.tokenizer.eos_token, self.tokenizer.eos_token_id)
        print("BOS token:", self.tokenizer.bos_token, self.tokenizer.bos_token_id)
        self.context_len = context_len
        print(self.tokenizer.pad_token)      
        print(self.tokenizer.pad_token_id)     

    def return_masks(self,pad_idx:int):
        pad_masks = torch.zeros(size=(self.context_len,self.context_len))
        pad_masks[:,pad_idx:] = -1e9
        pad_masks[pad_idx:,:] = -1e9
        return pad_masks
    def __getitem__(self, index):
        en_tokens = torch.tensor(self.tokenizer(self.english_corpus[index],padding="max_length",max_length=self.context_len)["input_ids"])
        finnish_tokens = torch.tensor(self.tokenizer(self.finnish_corpus[index],padding="max_length",max_length=self.context_len)["input_ids"])
        en_pad_index = torch.where(en_tokens==self.tokenizer.pad_token_id)[0][0]
        fin_pad_index = torch.where(finnish_tokens == self.tokenizer.pad_token_id)[0][0]
        en_pad_masks = self.return_masks(en_pad_index)
        fin_pad_masks = self.return_masks(fin_pad_index)
        # en_pad_masks = torch.concat([torch.full((en_pad_index,),fill_value=0),torch.full((self.context_len-en_pad_index,),-torch.inf)])
        # fin_pad_masks = torch.concat([torch.full((fin_pad_index,),fill_value=0),torch.full((self.context_len-fin_pad_index,),-1*torch.inf)])
        # en_pad_masks = en_pad_masks.view(-1,1)@torch.concat([torch.zeros((fin_pad_index,)),torch.ones(self.context_len-fin_pad_index)])
        # print(en_pad_masks)
        return (en_tokens,en_pad_masks,finnish_tokens,fin_pad_masks)
    def __len__(self):
        return len(self.english_corpus)

In [2]:
import torch 
from transformers import GPT2Tokenizer
import os
import numpy as np 
from collections import Counter
from tqdm import tqdm

ARCHIVE_PATH = "EUbookshop-1"
CONTEXT_LEN = 512


class EnFinnishDataset(torch.utils.data.Dataset):
    def __init__(self,archive_path:str,context_len:int=512):
        super().__init__()
        # --- Make sure the files exist before proceeding ---
        en_path = os.path.join(archive_path, "EUbookshop.en")
        fi_path = os.path.join(archive_path, "EUbookshop.fi")
        if not (os.path.exists(en_path) and os.path.exists(fi_path)):
            raise FileNotFoundError(
                f"Could not find dataset files at '{en_path}' and '{fi_path}'. "
                "Please ensure the archive is extracted correctly."
            )

        with open(en_path, 'r', encoding='utf-8') as fp:
            self.english_corpus = fp.readlines()
        with open(fi_path, 'r', encoding='utf-8') as fp:
            self.finnish_corpus = fp.readlines()
        
        self.tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '<pad>',"bos_token": "<bos>"})        
        self.context_len = context_len
        
        print(f"Tokenizer loaded. Pad token: '{self.tokenizer.pad_token}', ID: {self.tokenizer.pad_token_id}")
        print(f"Total vocabulary size: {len(self.tokenizer)}")

    def return_masks(self,pad_idx:int):
        # This function creates a square attention mask.
        # It's used to prevent the model from attending to padding tokens.
        pad_masks = torch.zeros(size=(self.context_len,self.context_len))
        if pad_idx < self.context_len: # Check if padding exists
            pad_masks[:,pad_idx:] = -torch.inf
            pad_masks[pad_idx:,:] = -torch.inf
        return pad_masks
        
    def __getitem__(self, index):
        # Tokenize with padding to max_length
        en_encoding = self.tokenizer(self.english_corpus[index], padding="max_length", max_length=self.context_len, truncation=True)
        fi_encoding = self.tokenizer(self.finnish_corpus[index], padding="max_length", max_length=self.context_len, truncation=True)

        en_tokens = torch.tensor(en_encoding["input_ids"])
        finnish_tokens = torch.tensor(fi_encoding["input_ids"])

        # Find the first occurrence of the pad token to determine sequence length
        en_pad_indices = torch.where(en_tokens == self.tokenizer.pad_token_id)[0]
        en_pad_index = en_pad_indices[0] if len(en_pad_indices) > 0 else self.context_len

        fin_pad_indices = torch.where(finnish_tokens == self.tokenizer.pad_token_id)[0]
        fin_pad_index = fin_pad_indices[0] if len(fin_pad_indices) > 0 else self.context_len

        en_pad_masks = self.return_masks(en_pad_index)
        fin_pad_masks = self.return_masks(fin_pad_index)
        
        return (en_tokens, en_pad_masks, finnish_tokens, fin_pad_masks)

    def __len__(self):
        return len(self.english_corpus)


def test_dataloading_and_item_shape():
    """
    Tests the instantiation of the EnFinnishDataset and verifies the structure,
    type, and shape of a single item fetched from it.
    """
    print("\n--- Running Dataloading and Shape Sanity Check ---")
    dataset = EnFinnishDataset(archive_path=ARCHIVE_PATH, context_len=CONTEXT_LEN)
    
    # 1. Assert that the dataset object is created and is not empty.
    assert dataset is not None, "Dataset object could not be instantiated."
    assert len(dataset) > 0, "Dataset is empty after loading."
    print("✅ Dataset instantiated and is not empty.")

    # 2. Retrieve a single sample to check its integrity.
    sample = dataset[0]
    assert isinstance(sample, tuple) and len(sample) == 4, f"Dataset sample should be a tuple of 4 elements, but got {type(sample)} of length {len(sample)}."
    print("✅ Sample is a tuple of 4 elements.")
    
    en_tokens, en_mask, fin_tokens, fin_mask = sample
    
    # 3. Assert that all parts of the sample are tensors with the correct shape.
    expected_shapes = [torch.Size([CONTEXT_LEN]), torch.Size([CONTEXT_LEN, CONTEXT_LEN]), torch.Size([CONTEXT_LEN]), torch.Size([CONTEXT_LEN, CONTEXT_LEN])]
    checks=[("English tokens", en_tokens), ("English mask", en_mask), ("Finnish tokens", fin_tokens), ("Finnish mask", fin_mask)]
    for i, (name, tensor) in enumerate(checks):
        assert isinstance(tensor, torch.Tensor), f"{name} is not a torch.Tensor."
        assert tensor.shape == expected_shapes[i], (
            f"{name} shape is incorrect.\n"
            f"Expected: {expected_shapes[i]}, Got: {tensor.shape}"
        )
    print("✅ All sample tensors have the correct shapes.")
    print("--- Sanity Check Passed ---\n")


def analyze_dataset(dataset: EnFinnishDataset):
    """
    Iterates through the entire dataset to collect token statistics and
    verify that the tokenization process is reversible.
    """
    print("\n--- Starting Full Dataset Analysis ---")
    en_token_counts = Counter()
    fi_token_counts = Counter()
    mismatched_decodings = []

    # Use tqdm for a progress bar
    for i in tqdm(range(len(dataset)), desc="Analyzing Dataset"):
        # 1. Get original text and tokenized tensors
        original_en = dataset.english_corpus[i].strip()
        original_fi = dataset.finnish_corpus[i].strip()
        en_tokens, _, fin_tokens, _ = dataset[i]

        # 2. Update token frequency counts
        # We only count non-padding tokens to get meaningful statistics
        en_non_pad_tokens = en_tokens[en_tokens != dataset.tokenizer.pad_token_id]
        fi_non_pad_tokens = fin_tokens[fin_tokens != dataset.tokenizer.pad_token_id]
        en_token_counts.update(en_non_pad_tokens.tolist())
        fi_token_counts.update(fi_non_pad_tokens.tolist())

        # 3. Check for information loss by decoding
        # skip_special_tokens=True removes <pad> tokens from the output
        decoded_en = dataset.tokenizer.decode(en_non_pad_tokens, skip_special_tokens=True).strip()
        decoded_fi = dataset.tokenizer.decode(fi_non_pad_tokens, skip_special_tokens=True).strip()

        # Compare original with decoded. Some minor differences are expected
        # due to tokenizer normalization, but major ones should be flagged.
        if original_en != decoded_en:
            mismatched_decodings.append(("en", i, original_en, decoded_en))
        if original_fi != decoded_fi:
            mismatched_decodings.append(("fi", i, original_fi, decoded_fi))
            
    print("\n--- Analysis Complete ---")

    # --- Report Token Statistics ---
    print("\n--- Token Distribution Statistics ---")
    print("\nTop 15 Most Common English Tokens:")
    for token_id, count in en_token_counts.most_common(15):
        token_str = dataset.tokenizer.decode([token_id])
        print(f"  - Token: '{token_str}' (ID: {token_id}) | Count: {count:,}")

    print("\nTop 15 Most Common Finnish Tokens:")
    for token_id, count in fi_token_counts.most_common(15):
        token_str = dataset.tokenizer.decode([token_id])
        print(f"  - Token: '{token_str}' (ID: {token_id}) | Count: {count:,}")
        
    # --- Report Information Integrity Check ---
    print("\n\n--- Tokenization Integrity Check ---")
    total_sentences = len(dataset) * 2
    num_mismatches = len(mismatched_decodings)
    print(f"✅ Found {num_mismatches} mismatched decodings out of {total_sentences:,} total sentences.")

    if num_mismatches > 0:
        print("\nDisplaying up to 5 examples of mismatches:")
        print("Note: Minor differences (e.g., whitespace, normalization) are common.")
        for lang, index, original, decoded in mismatched_decodings[:5]:
            print("-" * 20)
            print(f"Language: {lang}, Index: {index}")
            print(f"Original: '{original}'")
            print(f"Decoded:  '{decoded}'")
            print("-" * 20)


test_dataloading_and_item_shape()
full_dataset = EnFinnishDataset(archive_path=ARCHIVE_PATH, context_len=CONTEXT_LEN)
analyze_dataset(full_dataset)



--- Running Dataloading and Shape Sanity Check ---
Tokenizer loaded. Pad token: '<pad>', ID: 50257
Total vocabulary size: 50259
✅ Dataset instantiated and is not empty.
✅ Sample is a tuple of 4 elements.
✅ All sample tensors have the correct shapes.
--- Sanity Check Passed ---

Tokenizer loaded. Pad token: '<pad>', ID: 50257
Total vocabulary size: 50259

--- Starting Full Dataset Analysis ---


Analyzing Dataset: 100%|██████████| 100000/100000 [03:28<00:00, 478.55it/s]


--- Analysis Complete ---

--- Token Distribution Statistics ---

Top 15 Most Common English Tokens:
  - Token: ' the' (ID: 262) | Count: 166,399
  - Token: ',' (ID: 11) | Count: 108,079
  - Token: '
' (ID: 198) | Count: 100,000
  - Token: ' of' (ID: 286) | Count: 93,319
  - Token: '.' (ID: 13) | Count: 92,352
  - Token: ' to' (ID: 284) | Count: 71,732
  - Token: ' and' (ID: 290) | Count: 61,473
  - Token: ' in' (ID: 287) | Count: 48,820
  - Token: ' a' (ID: 257) | Count: 37,242
  - Token: ' is' (ID: 318) | Count: 35,106
  - Token: ' that' (ID: 326) | Count: 33,866
  - Token: ' for' (ID: 329) | Count: 25,346
  - Token: ' on' (ID: 319) | Count: 23,954
  - Token: ' be' (ID: 307) | Count: 21,304
  - Token: ' this' (ID: 428) | Count: 16,531

Top 15 Most Common Finnish Tokens:
  - Token: 'ä' (ID: 11033) | Count: 536,280
  - Token: ' k' (ID: 479) | Count: 130,446
  - Token: ',' (ID: 11) | Count: 125,960
  - Token: 'i' (ID: 72) | Count: 105,696
  - Token: '
' (ID: 198) | Count: 99,995
  - To


