In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
import collections
import unicodedata
from typing import Dict, List, Tuple, Optional
import json
import logging
from indicnlp.normalize.indic_normalize import TamilNormalizer
from torch.utils.data import Dataset, DataLoader
import unicodedata
import tqdm
import dataclasses
import re

In [2]:
import multiprocessing

In [3]:
multiprocessing.cpu_count()

20

In [4]:
logger = logging.getLogger("TA-EN NMT")
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

In [5]:
def prepare_data(csv_path = "", train_size=0.9, no_of_samples = 20000, df = None, name = None, shuffle = True):
    try:
        if csv_path:
            df = pd.read_csv(csv_path)
        if not name:
            name = os.path.basename(csv_path).split(".")[0]
        
        if "Unnamed: 0" in df.columns:
            df = df.drop(columns=["Unnamed: 0"])
        print(len(df))
        df = df.dropna()
        df = df.drop_duplicates()
        df = df.rename(columns={"en": "English", "ta": "Tamil"})
        if shuffle:
            df = df.sample(frac=1).reset_index(drop=True)
        print(len(df))
        df = df.iloc[:no_of_samples]
        folder = f"./data_{name.strip()}/"
        print(folder)
        if not os.path.exists(folder):
            os.makedirs(folder)
        df["Tamil"].to_csv(f"{folder}//{name}_source_full.txt", index=False, header=False)
        df["English"].to_csv(f"{folder}//{name}_target_full.txt", index=False, header=False)
        
        train_size = int(train_size * len(df))
        train_data = df.iloc[:train_size]
        valid_data = df.iloc[train_size:]

        train_data["Tamil"].to_csv(f"{folder}//{name}_source_train.txt", index=False, header=False)
        train_data["English"].to_csv(f"{folder}//{name}_target_train.txt", index=False, header=False)
        valid_data["Tamil"].to_csv(f"{folder}//{name}_source_valid.txt", index=False, header=False)
        valid_data["English"].to_csv(f"{folder}//{name}_target_valid.txt", index=False, header=False)
        return True
    except Exception as e:
        logger.error(f"UNABLE TO PREPARE DATA:{e}")
        return False

In [6]:
def combine_datasets(root_dir,dataset_path = None, no_of_samples=2000,train_size = 0.9, excluded = ["corpus.bcn.test 2k.csv","corpus.bcn.dev 1k.csv","parallel 8k gloss.xlsx"], condense=False):
    try:
        files = [file for file in os.listdir(root_dir) if file not in excluded]
        if dataset_path:
            assert os.path.basename(dataset_path) in files, f"Dataset {dataset_path} not found in {root_dir}"
            if not prepare_data(os.path.join(root_dir,dataset_path),train_size=train_size,no_of_samples=no_of_samples):
                return False
        elif condense:
            base_file = pd.read_csv(os.path.join(root_dir,files[0]))
            for file in files[1:]:
                base_file = pd.concat([base_file,pd.read_csv(os.path.join(root_dir,file))],axis=0)
            base_file.dropna(inplace=True)
            base_file.drop(["Unnamed: 0"],axis=1,inplace=True)
            base_file.drop_duplicates(inplace=True)
            if not prepare_data(None,train_size,no_of_samples,base_file,name = "full_corpa"):
                return False
        else:
            for file in files:
                if not prepare_data(os.path.join(root_dir,file),train_size=train_size,no_of_samples=no_of_samples):
                    return False
        return True
    except Exception as e:
        logger.error(f"UNABLE TO COMBINE AND SAVE DATASETS:{e}")
        return False
        

        

In [6]:
combine_datasets("en-ta",condense=True)

1186199
1186199
./data_full_corpa/


True

In [None]:
@dataclasses.dataclass
class BPETokenizer:

    language : str
    vocab_size : int = 40000
    special_tokens : Dict[str, int] = dataclasses.field(default_factory=lambda: {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3})
    merges : Dict[Tuple[str, str], str] = dataclasses.field(default_factory=lambda: {})
    vocab_to_idx : Dict[str, int] = dataclasses.field(default_factory=lambda: {})
    idx_to_vocab : Dict[int, str] = dataclasses.field(default_factory=lambda: {})
    tamil_normalizer : TamilNormalizer = TamilNormalizer()

    def __post_init__(self):
        if self.language not in ["ta","en"]:
            return "Only English and Tamil is supported"

    def normalize_text(self, text):
        if self.language == "ta":
            text = unicodedata.normalize('NFC', text)
            text = re.sub(r'"{2,}', '"', text)
            text = re.sub(r"'{2,}", "'", text)
            text = re.sub(r"[a-zA-Z]",'',text)
            text = re.sub(r'([.;!:?"\'])', r' \1 ', text)
            text = re.sub(r'\s+',' ', text).strip()
            return self.tamil_normalizer.normalize(text)
        elif self.language == "en":
            text = unicodedata.normalize('NFC', text)
            text = text.lower()
            text = re.sub(r'([.;!:?"\'])', r' \1 ', text)
            text = re.sub(r'\s+',' ', text).strip()
            return text
        else:
            return "Unsupported language"
        
    def text_tokenize(self, text):
        text = self.normalize_text(text)
        text = text.split()
        
        # Use regex to split words into grapheme clusters instead of characters
        import regex
        words = []
        for word in text:
            # Split by grapheme clusters
            graphemes = regex.findall(r'\X', word)
            words.append(graphemes + ["</w>"])
        
        return words
    
    def _get_stats_chunk(self, words_chunk):
        pairs = {}
        for word in words_chunk:
            for i in range(len(word) - 1):
                pair = (word[i], word[i + 1])
                pairs[pair] = pairs.get(pair, 0) + 1
        return pairs

    def _merge_tokens(self, words, new_tokens, best_pair):
        new_words = []
        for word in words:
            new_word = []
            i = 0
            while i < len(word):
                if i<len(word)-1 and word[i] == best_pair[0] and word[i+1] == best_pair[1]:
                    new_word.append(new_tokens)
                    i+=2
                else:
                    new_word.append(word[i])
                    i+=1
            new_words.append(new_word)
        return new_words
    
    def _train_chunk(self, words, shared_merges, shared_vocab_to_idx, shared_vocab_size):
        local_merges = {pair : merged for pair, merged in shared_merges.items()}
        local_vocab_to_idx = {char : idx for char, idx in shared_vocab_to_idx.items()}
        local_vocab_size = shared_vocab_size.value
        
    
    def train_bpe(self, text, output_file, num_merges = None, is_save = True, num_workers = None) -> list:

        words = self.text_tokenize(text)

        chars = set()
        for word in words:
            chars.update(word)
        for key,val in self.special_tokens.items():
            self.vocab_to_idx[key] = val
        for char in chars:
            self.vocab_to_idx[char] = len(self.vocab_to_idx)
        

        if num_merges is None:
            num_merges = self.vocab_size - len(chars) - len(self.special_tokens)

        for i in tqdm.tqdm(range(num_merges)):

            cores = multiprocessing.cpu_count()

            chunk_size = len(words)//cores if num_workers == None else num_workers
            chunks = [words[i:i+chunk_size] for i in range(0, len(words), chunk_size)]
            logger.info("Starting get_stats on multiprocessing")
            with multiprocessing.Pool(processes=cores) as pool:
                results = pool.map(self._get_stats_chunk, chunks)
            logger.info("Ending multiprocessing")
            pairs = collections.Counter()
            for result in results:
                pairs.update(result)

            if not pairs:
                logger.info("All Pairs were merged, stopping training.")
                break

            best_pair = max(pairs.items(), key = lambda x: x[1])[0]

            new_token = best_pair[0] + best_pair[1]
            self.vocab_to_idx[new_token] = len(self.vocab_to_idx)
            self.idx_to_vocab[len(self.idx_to_vocab)] = new_token

            self.merges[best_pair] = new_token

            chunks = [(words[i:i+chunk_size], new_token, best_pair) for i in range(0, len(words), chunk_size)]
            logger.info("Starting merge_tokens on multiprocessing")
            with multiprocessing.Pool(processes=cores) as pool:
                results = pool.map(self._merge_tokens, chunks)
            logger.info("Ending multiprocessing")
            words = []
            for result in results:
                words.extend(result)

            if (i + 1) % 1000 == 0:
                logger.info(f"Completed {i + 1} merges. Current vocab size: {len(self.vocab_to_idx)}")

        logger.info(f"Done merges. Current vocab size: {len(self.idx_to_vocab)}")
        
        if is_save:
            self.save(output_file)
        return words
    
    def tokenize(self, text):

        words = self.text_tokenize(text)
        tokenized = []
        for word_tokens in words:

            while len(word_tokens) > 1:
                best_pair = None
                best_idx = -1
                
                for i in range(len(word_tokens) - 1):
                    pair = (word_tokens[i], word_tokens[i+1])
                    if pair in self.merges:
                        best_pair = pair
                        best_idx = i
                        break
                
                if best_pair is None:
                    break
                    
                word_tokens[best_idx] = self.merges[best_pair]
                del word_tokens[best_idx + 1]
            
            for subword in word_tokens:
                if subword in self.vocab_to_idx:
                    tokenized.append(self.vocab_to_idx[subword])
                else:
                    tokenized.append(self.vocab_to_idx["<UNK>"])
        return tokenized, words
    
    def decode(self, tokens):
        text = [self.idx_to_vocab.get(i,"<UNK>") for i in tokens]
        text = "".join(text).replace("</w>"," ")
        if text.endswith(" "):
            text = text[:-1]
        return text
            
            
    def save(self, output_path) -> None:

        try:

            with open(f"{output_path}.merges.json", "w", encoding = "utf-8") as f:
                serializable_merges = {f"{k[0]} {k[1]}": v for k, v in self.merges.items()}
                json.dump(serializable_merges,f, ensure_ascii = False,indent = 2)

            with open(f"{output_path}.vocab.json", 'w', encoding='utf-8') as f:
                json.dump(self.vocab_to_idx, f, ensure_ascii = False, indent = 2)

            logger.info("Successfully saved tokenizer")
        except Exception as e:
            logger.info(f"save failed: {e}")

    @classmethod
    def load(cls,output_path, language) -> 'BPETokenizer':

        tokenizer = cls(language)

        with open(f"{output_path}.vocab.json", 'r', encoding='utf-8') as f:
            tokenizer.vocab_to_idx = json.load(f)
            tokenizer.vocab_to_idx = {k: int(v) if isinstance(v, str) and v.isdigit() else v 
                              for k, v in tokenizer.vocab_to_idx.items()}
        with open(f"{output_path}.merges.json", 'r', encoding='utf-8') as f:
            serialized_merges = json.load(f)
            tokenizer.merges = {tuple(k.split()): v for k, v in serialized_merges.items()}
            
        tokenizer.idx_to_vocab = {v: k for k, v in tokenizer.vocab_to_idx.items()}
        
        return tokenizer


In [1]:
import pstats

# Load the profile output file
stats = pstats.Stats("profile.out")

# Print the profiling summary
stats.print_stats()


Sun Mar  2 00:11:02 2025    profile.out

         208347048 function calls (208346015 primitive calls) in 92.104 seconds

   Random listing order was used

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       31    0.000    0.000    0.000    0.000 C:\Users\ASUS\miniconda3\envs\xformers-env\Lib\re\_parser.py:86(opengroup)
       29    0.000    0.000    0.000    0.000 C:\Users\ASUS\miniconda3\envs\xformers-env\Lib\re\_parser.py:953(fix_flags)
        1    0.000    0.000    0.000    0.000 C:\Users\ASUS\miniconda3\envs\xformers-env\Lib\functools.py:396(__get__)
      864    0.000    0.000    0.000    0.000 C:\Users\ASUS\miniconda3\envs\xformers-env\Lib\re\_parser.py:240(__next)
      113    0.000    0.000    0.001    0.000 C:\Users\ASUS\miniconda3\envs\xformers-env\Lib\re\__init__.py:164(match)
        9    0.000    0.000    0.040    0.004 C:\Users\ASUS\miniconda3\envs\xformers-env\Lib\re\__init__.py:179(sub)
       42    0.000    0.000    0.000    0.000 C:\Users\

<pstats.Stats at 0x291c52a2150>

In [8]:
def learn_bpe(datasets = ["full_corpa"],vocab_size={"target_lang":30000, "source_lang": 30000}, languages = {"src":"ta","trg":"en"}, output_prefixes = {"src":"src_bpe1","trg":"trg_bpe1"}):
    full_source_text = ""
    full_target_text = ""
    for dataset in datasets:
        with open(f"data_{dataset}/{dataset}_source_full.txt", 'r', encoding='utf-8') as f:
            full_source_text += " " + f.read()
        with open(f"data_{dataset}/{dataset}_target_full.txt", 'r', encoding='utf-8') as f:
            full_target_text += " " + f.read()
    full_src_vocab_size = vocab_size.get("source_lang",20000)
    full_tgt_vocab_size = vocab_size.get("target_lang",20000)
    src_tokenizer = BPETokenizer(languages["src"],full_src_vocab_size)
    src_tokenizer.train_bpe(full_source_text,output_prefixes["src"],num_workers=5)
    tgt_tokenizer = BPETokenizer(languages["trg"],full_tgt_vocab_size,num_workers=5)
    tgt_tokenizer.train_bpe(full_target_text,output_prefixes["trg"])
    return src_tokenizer, tgt_tokenizer

In [None]:
src_tok, tgt_tok = learn_bpe(["general_en_ta 87k","pmindia"])

  0%|          | 0/29727 [00:00<?, ?it/s]2025-03-01 17:55:24,373 - TA-EN NMT - INFO - Starting get_stats on multiprocessing


In [23]:
tokenizer1 = BPETokenizer.load("87_tamil_bpe","ta")

In [7]:
tokenizer = BPETokenizer("ta",25000)

In [8]:
final = tokenizer.train_bpe("./data_general_en_ta 87k/general_en_ta 87k_source_full.txt","87_tamil_bpe")

  4%|▍         | 997/24765 [00:27<10:45, 36.80it/s]2025-03-01 14:33:40,706 - TA-EN NMT - INFO - Completed 1000 merges. Current vocab size: 1235
  8%|▊         | 1998/24765 [00:48<06:50, 55.45it/s]2025-03-01 14:34:02,205 - TA-EN NMT - INFO - Completed 2000 merges. Current vocab size: 2235
 12%|█▏        | 2998/24765 [01:07<05:35, 64.79it/s]2025-03-01 14:34:20,616 - TA-EN NMT - INFO - Completed 3000 merges. Current vocab size: 3235
 16%|█▌        | 3999/24765 [01:24<05:50, 59.20it/s]2025-03-01 14:34:37,269 - TA-EN NMT - INFO - Completed 4000 merges. Current vocab size: 4235
 20%|██        | 4993/24765 [01:38<04:59, 65.96it/s]2025-03-01 14:34:52,196 - TA-EN NMT - INFO - Completed 5000 merges. Current vocab size: 5235
 24%|██▍       | 5997/24765 [01:53<04:34, 68.43it/s]2025-03-01 14:35:06,773 - TA-EN NMT - INFO - Completed 6000 merges. Current vocab size: 6235
 26%|██▌       | 6394/24765 [01:58<04:32, 67.46it/s]2025-03-01 14:35:12,117 - TA-EN NMT - INFO - All Pairs were merged, stopping tr

In [32]:
['அந்நியன்</w>'] in final

False

In [11]:
with open("./data_general_en_ta 87k/general_en_ta 87k_source_full.txt", "r") as f:
    ta = f.readlines()

In [29]:
merges = tokenizer1.merges

In [31]:
merges[('அந்நி', 'யன்')]

KeyError: ('அந்நி', 'யன்')

In [13]:
tokenizer.text_tokenize(ta[0])

[['நா', 'ன்', '</w>'],
 ['இ', 'ப்', 'போ', 'து', '</w>'],
 ['வே', 'லை', 'யி', 'ல்', '</w>'],
 ['இ', 'ரு', 'க்', 'கி', 'றே', 'ன்', '.', '</w>']]

In [46]:
tokens = src_tok.tokenize("டாம் இரவு மிகவும் தாமதமாகத் திரும்பினார்.\n".strip())

In [47]:
tokens

([356, 1292, 1115, 8716, 41, 8959, 12521, 183, 419, 275],
 [['டாம்</w>'],
  ['இரவு</w>'],
  ['மிக', 'வும்', '</w>'],
  ['தாமதமாகத்</w>'],
  ['திரும்', 'பி', 'னார்</w>'],
  ['.</w>']])

In [35]:
for i in tokens[0]:
    print(f"{i} ---> {tokenizer.idx_to_vocab[i]}")

518 ---> மேரி</w>
268 ---> ஒரு</w>
1580 ---> வலி
1116 ---> மையான
38 ---> </w>
2533 ---> பெண்.</w>


In [49]:
merges

{('்', '</w>'): '்</w>',
 ('்', 'க'): '்க',
 ('.', '</w>'): '.</w>',
 ('ு', '</w>'): 'ு</w>',
 ('்', '.</w>'): '்.</w>',
 ('்', 'த'): '்த',
 ('க', '்க'): 'க்க',
 ('ம', '்</w>'): 'ம்</w>',
 ('ன', '்</w>'): 'ன்</w>',
 ('்', 'ப'): '்ப',
 ('ை', '</w>'): 'ை</w>',
 ('ர', 'ு'): 'ரு',
 ('்க', 'ள'): '்கள',
 ('ி', 'ற'): 'ிற',
 ('்', 'ட'): '்ட',
 ('ன', '்'): 'ன்',
 ('ி', 'ய'): 'ிய',
 ('ா', 'ர'): 'ார',
 ('ி', 'ல'): 'ில',
 ('ந', 'ா'): 'நா',
 ('ந', '்த'): 'ந்த',
 ('த', '்த'): 'த்த',
 ('ங', '்கள'): 'ங்கள',
 ('த', 'ு'): 'து',
 ('ப', '்ப'): 'ப்ப',
 ('ன', '்.</w>'): 'ன்.</w>',
 ('அ', 'வ'): 'அவ',
 ('நா', 'ன்</w>'): 'நான்</w>',
 ('த', 'ு</w>'): 'து</w>',
 ('ட', 'ா'): 'டா',
 ('்', 'ல'): '்ல',
 ('ட', '்ட'): 'ட்ட',
 ('வ', 'ி'): 'வி',
 ('ை', 'ய'): 'ைய',
 ('ா', 'க'): 'ாக',
 ('ே', 'ன்.</w>'): 'ேன்.</w>',
 ('எ', 'ன்'): 'என்',
 ('க்க', 'ு</w>'): 'க்கு</w>',
 ('ம', 'ு'): 'மு',
 ('ம', '்.</w>'): 'ம்.</w>',
 ('து', '.</w>'): 'து.</w>',
 ('டா', 'ம்</w>'): 'டாம்</w>',
 ('ச', 'ெ'): 'செ',
 ('?', '</w>'): '?</w>',
 ('ங்க

In [48]:
('தாமத', 'ம') in merges

False

In [28]:
ta[6]

'டாம் இரவு மிகவும் தாமதமாகத் திரும்பினார்.\n'

In [64]:
striped = []
for i in ta:
    striped.append(i.strip())

In [None]:
import re
import os
import tqdm

class TamilBPETokenizer:
    def __init__(self, vocab_size: int = 32000):
        self.vocab_size = vocab_size
        self.special_tokens = {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3}
        self.merges = {}  
        self.vocab = {}   
        self.inverse_vocab = {}  

    def _normalize_tamil_text(self, text: str) -> str:
        """Normalize Tamil text for consistent processing."""
        # Normalize to NFC form to ensure consistent representation of Tamil characters
        text = unicodedata.normalize('NFC', text)
        # You might need additional Tamil-specific normalization here
        return text

    def _get_stats(self, words: List[List[str]]) -> Dict[Tuple[str, str], int]:
        """Count frequency of adjacent symbol pairs in the training data."""
        pairs = collections.defaultdict(int)
        for word in words:
            for i in range(len(word) - 1):
                pairs[tuple(word[i:i+2])] += 1
        return pairs

    def _merge_pair(self, words: List[List[str]], pair: Tuple[str, str], new_token: str) -> List[List[str]]:
        """Apply a merge operation to all occurrences of the pair in the training data."""
        first, second = pair
        new_words = []
        for word in words:
            i = 0
            new_word = []
            while i < len(word):
                if i < len(word) - 1 and word[i] == first and word[i+1] == second:
                    new_word.append(new_token)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_words.append(new_word)
        return new_words

    def _tokenize_tamil_words(self, tamil_text: str) -> List[List[str]]:
        """Split Tamil text into words and initialize each word as a list of characters."""
        # Normalize the text for consistent processing
        normalized_text = self._normalize_tamil_text(tamil_text)
        
        # Split text into words
        words = normalized_text.split()
        
        # Initialize each word as a list of characters
        # We add a special end-of-word token to handle word boundaries during merges
        return [[c for c in word] + ['</w>'] for word in words]

    def train(self, file_path: str, num_merges: Optional[int] = None):
        """Train the BPE model on Tamil text data."""
        if num_merges is None:
            # If not specified, set num_merges to reach desired vocab size
            # accounting for special tokens and initial character vocabulary
            num_merges = self.vocab_size - len(self.special_tokens) - 500  # Rough estimate for initial char vocab
        
        print(f"Reading training data from {file_path}")
        # Read the training data
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        
        # Tokenize into words and characters
        print("Tokenizing Tamil text...")
        words = self._tokenize_tamil_words(text)
        
        # Initialize vocabulary with unique characters
        chars = set()
        for word in words:
            chars.update(word)
        
        print(f"Initial vocabulary size (characters): {len(chars)}")
        
        # Initialize the vocabulary with special tokens and characters
        vocab = {}
        for token, idx in self.special_tokens.items():
            vocab[token] = idx
        
        idx = len(vocab)
        for char in sorted(chars):
            vocab[char] = idx
            idx += 1
        
        # Perform BPE merges
        print(f"Performing {num_merges} BPE merges...")
        for i in tqdm.tqdm(range(num_merges)):
            # Get pair frequencies
            pairs = self._get_stats(words)
            if not pairs:
                break
                
            # Find the most frequent pair
            best_pair = max(pairs.items(), key=lambda x: x[1])[0]
            
            # Create a new token by joining the pair
            new_token = best_pair[0] + best_pair[1]
            
            # Add the new token to the vocabulary
            vocab[new_token] = len(vocab)
            
            # Update the merges dictionary
            self.merges[best_pair] = new_token
            
            # Apply the merge throughout the data
            words = self._merge_pair(words, best_pair, new_token)
            
            if (i + 1) % 1000 == 0:
                print(f"Completed {i + 1} merges. Current vocab size: {len(vocab)}")
        
        print(f"Final vocabulary size: {len(vocab)}")
        self.vocab = vocab
        self.inverse_vocab = {v: k for k, v in vocab.items()}
        
        return self

    def tokenize(self, text: str) -> List[int]:
        """Tokenize Tamil text using the learned BPE model."""
        # Normalize and split into words
        normalized_text = self._normalize_tamil_text(text)
        words = normalized_text.split()
        
        result = []
        for word in words:
            # Initialize as list of characters
            word_tokens = [c for c in word] + ['</w>']
            
            # Apply merges until no more can be applied
            while True:
                # Find the first applicable merge
                pairs = [(word_tokens[i], word_tokens[i+1]) 
                        for i in range(len(word_tokens)-1)]
                
                # Find the first merge rule that can be applied
                for pair in pairs:
                    if pair in self.merges:
                        new_token = self.merges[pair]
                        word_tokens = self._merge_pair([word_tokens], pair, new_token)[0]
                        break
                else:
                    # No merge rule can be applied, we're done with this word
                    break
                    
            # Convert tokens to IDs
            for token in word_tokens:
                if token in self.vocab:
                    result.append(self.vocab[token])
                else:
                    result.append(self.vocab["<UNK>"])
                    
        return result

    def decode(self, token_ids: List[int]) -> str:
        """Decode a list of token IDs back to Tamil text."""
        tokens = [self.inverse_vocab.get(idx, "<UNK>") for idx in token_ids]
        
        # Join the tokens and replace the end-of-word token with space
        text = ''.join(tokens).replace('</w>', ' ')
        
        # Remove trailing space if exists
        if text.endswith(' '):
            text = text[:-1]
            
        return text

    def save(self, prefix: str):
        """Save the tokenizer vocabulary and merges to files."""
        # Save vocabulary
        with open(f"{prefix}.vocab.json", 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=2)
            
        # Save merges
        with open(f"{prefix}.merges.json", 'w', encoding='utf-8') as f:
            # Convert tuple keys to strings for JSON serialization
            serializable_merges = {f"{k[0]} {k[1]}": v for k, v in self.merges.items()}
            json.dump(serializable_merges, f, ensure_ascii=False, indent=2)
            
        print(f"Tokenizer saved to {prefix}.vocab.json and {prefix}.merges.json")
        
    @classmethod
    def load(cls, prefix: str) -> 'TamilBPETokenizer':
        """Load a tokenizer from saved files."""
        tokenizer = cls()
        
        # Load vocabulary
        with open(f"{prefix}.vocab.json", 'r', encoding='utf-8') as f:
            tokenizer.vocab = json.load(f)
            # Convert string keys to integers where needed
            tokenizer.vocab = {k: int(v) if isinstance(v, str) and v.isdigit() else v 
                              for k, v in tokenizer.vocab.items()}
            
        # Load merges
        with open(f"{prefix}.merges.json", 'r', encoding='utf-8') as f:
            serialized_merges = json.load(f)
            # Convert string keys back to tuples
            tokenizer.merges = {tuple(k.split()): v for k, v in serialized_merges.items()}
            
        # Create inverse vocab
        tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
        
        return tokenizer


def train_tamil_bpe(input_file: str, output_prefix: str, vocab_size: int = 32000):
    """Train a Tamil BPE tokenizer and save it."""
    tokenizer = TamilBPETokenizer(vocab_size=vocab_size)
    tokenizer.train(input_file)
    tokenizer.save(output_prefix)
    return tokenizer

    
# tokenizer = train_tamil_bpe("C://Users//ASUS//Desktop//AI Projects//TamilNMT//data_general_en_ta 87k//general_en_ta 87k_source_full.txt", "bpee", 30000)

# # Test the tokenizer on a sample sentence
# sample = "தமிழ் மொழி மிகவும் பழமையான மொழிகளில் ஒன்றாகும்"
# tokens = tokenizer.tokenize(sample)
# decoded = tokenizer.decode(tokens)

# print("\nSample tokenization test:")
# print(f"Original: {sample}")
# print(f"Tokens: {tokens}")
# print(f"Decoded: {decoded}")

In [5]:
def _get_stats(words: List[List[str]]) -> Dict[Tuple[str, str], int]:
        """Count frequency of adjacent symbol pairs in the training data."""
        pairs = collections.defaultdict(int)
        for word in words:
            for i in range(len(word) - 1):
                pairs[tuple(word[i:i+2])] += 1
        return pairs

In [6]:
_get_stats(ta)

defaultdict(int, {})

In [8]:
bpe = TamilBPETokenizer(30000)
ta = bpe._tokenize_tamil_words(ta)

In [9]:
_get_stats(ta)

defaultdict(int,
            {('ந', 'ா'): 464,
             ('ா', 'ன'): 598,
             ('ன', '்'): 1652,
             ('்', '</w>'): 2929,
             ('இ', 'ப'): 27,
             ('ப', '்'): 527,
             ('்', 'ப'): 652,
             ('ப', 'ோ'): 174,
             ('ோ', 'த'): 102,
             ('த', 'ு'): 914,
             ('ு', '</w>'): 1316,
             ('வ', 'ே'): 207,
             ('ே', 'ல'): 58,
             ('ல', 'ை'): 328,
             ('ை', 'ய'): 305,
             ('ய', 'ி'): 145,
             ('ி', 'ல'): 475,
             ('ல', '்'): 790,
             ('இ', 'ர'): 256,
             ('ர', 'ு'): 798,
             ('ு', 'க'): 663,
             ('க', '்'): 1106,
             ('்', 'க'): 2063,
             ('க', 'ி'): 675,
             ('ி', 'ற'): 634,
             ('ற', 'ே'): 184,
             ('ே', 'ன'): 346,
             ('்', '.'): 1179,
             ('.', '</w>'): 1688,
             ('த', 'ே'): 119,
             ('ே', 'ர'): 123,
             ('ர', 'ை'): 94,
          

In [58]:
normalised = unicodedata.normalize('NFC',ta[0])

In [61]:
ta_norm = TamilNormalizer().normalize(ta)

In [None]:
striped = []
for i in ta:
    striped.append(i.strip())

['ந',
 'ா',
 'ன',
 '்',
 ' ',
 'இ',
 'ப',
 '்',
 'ப',
 'ோ',
 'த',
 'ு',
 ' ',
 'வ',
 'ே',
 'ல',
 'ை',
 'ய',
 'ி',
 'ல',
 '்',
 ' ',
 'இ',
 'ர',
 'ு',
 'க',
 '்',
 'க',
 'ி',
 'ற',
 'ே',
 'ன',
 '்',
 '.',
 '\n',
 'த',
 'ே',
 'ர',
 'ை',
 'ய',
 'ி',
 'ல',
 'ி',
 'ர',
 'ு',
 'ந',
 '்',
 'த',
 'ு',
 ' ',
 'த',
 'வ',
 'ள',
 'ை',
 'ய',
 'ை',
 ' ',
 'வ',
 'ே',
 'ற',
 'ு',
 'ப',
 'ட',
 'ு',
 'த',
 '்',
 'த',
 'ி',
 'ப',
 '்',
 ' ',
 'ப',
 'ா',
 'ர',
 '்',
 'க',
 '்',
 'க',
 ' ',
 'ம',
 'ு',
 'ட',
 'ி',
 'ய',
 'ா',
 'த',
 'ு',
 '.',
 '\n',
 'உ',
 'ங',
 '்',
 'க',
 'ள',
 'ு',
 'க',
 '்',
 'க',
 'ு',
 ' ',
 'ப',
 'க',
 'ு',
 'த',
 'ி',
 ' ',
 'ந',
 'ே',
 'ர',
 ' ',
 'வ',
 'ே',
 'ல',
 'ை',
 ' ',
 'இ',
 'ர',
 'ு',
 'க',
 '்',
 'க',
 'ி',
 'ற',
 'த',
 'ா',
 '?',
 '\n',
 'இ',
 'த',
 'ு',
 ' ',
 'ச',
 'ி',
 'க',
 '்',
 'க',
 'ல',
 'ா',
 'ன',
 'த',
 'ு',
 '.',
 '\n',
 'இ',
 'ந',
 '்',
 'த',
 ' ',
 'க',
 'ட',
 'ை',
 'ய',
 'ி',
 'ல',
 '்',
 ' ',
 'ந',
 'ல',
 '்',
 'ல',
 ' ',
 'க',
 'ா',
 'ல',
 'ண',
 'ி',


In [None]:
def clean_full_corpus(full_corpus:List[str]):
    cleaned = []
    for i in full_corpus:
        i = i.strip()
        if i.startswith('"') and len(i) > 1:
            i = i[1:]
        elif i.startswith('"'):
            del         

'"\n'

In [None]:
idx_to_word={}
word_to_idx = {}

for idx, word in enumerate(ta):
    

'"நாம் இந்தப் பணியை முன்னெடுத்துச் செல்வோம்.\n'

In [None]:

class TranslationDataset(Dataset):
    def __init__(self, tamil_lines, english_lines):
        self.tamil = tamil_lines
        self.english = english_lines

    def __len__(self): return len(self.tamil)
    def __getitem__(self, idx): return self.tamil[idx], self.english[idx]

tamil_bpe = open("train.tamil.bpe", "r", encoding="utf-8").readlines()
english_bpe = open("train.english.bpe", "r", encoding="utf-8").readlines()
dataset = TranslationDataset(tamil_bpe, english_bpe)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [16]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_heads=8, n_layers=6, d_ff=2048, dropout=0.1):
        super(Transformer, self).__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, n_heads, d_ff, dropout),
            num_layers=n_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, n_heads, d_ff, dropout),
            num_layers=n_layers
        )
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_enc = nn.Parameter(torch.zeros(10000, d_model))  # Positional encoding
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        src = self.src_emb(src) + self.pos_enc[:src.size(0)]
        tgt = self.tgt_emb(tgt) + self.pos_enc[:tgt.size(0)]
        memory = self.encoder(src, src_key_padding_mask=src_mask)
        output = self.decoder(tgt, memory, memory_key_padding_mask=src_mask, tgt_mask=tgt_mask)
        return self.fc_out(output)

# Example usage (later in training)
model = Transformer(src_vocab_size=32000, tgt_vocab_size=32000)  # Adjust vocab sizes
if torch.cuda.is_available():
    model = model.cuda()  # Use RTX 4060



In [None]:
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast

optimizer = optim.Adam(model.parameters(), lr=0.0001)
scaler = GradScaler()
criterion = nn.CrossEntropyLoss(ignore_index=0)  # PAD token

for epoch in range(10):
    model.train()
    for tamil_batch, english_batch in loader:
        tamil = torch.tensor(tamil_batch, dtype=torch.long).cuda()
        english = torch.tensor(english_batch, dtype=torch.long).cuda()

        optimizer.zero_grad()
        with autocast():
            output = model(tamil, english[:, :-1])  # Shift target for teacher forcing
            loss = criterion(output.view(-1, output.size(-1)), english[:, 1:].contiguous().view(-1))
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()