In [None]:
from transformers import PreTrainedTokenizerFast 

tokenizer = PreTrainedTokenizerFast.from_pretrained("charanhu/kannada-tokenizer")

text = train_data['tgt']
src_text = train_data['src']

print(text[2])
print(src_text[2])

In [None]:
for i, s_txt in enumerate(text):
    encoding = tokenizer.encode(s_txt) 
    tokens = tokenizer.convert_ids_to_tokens(encoding) 
    decoded_text = tokenizer.decode(encoding) 

    # print(f"original text : {s_txt}") 
    # print(f"encoded text: {encoding}")
    # print(f"tokens: {tokens}") 
    # print(f"decoded_text: {decoded_text}") 

In [1]:
from transformers import AutoModel, AutoTokenizer
import itertools
import torch

# load model
model = AutoModel.from_pretrained("aneuraz/awesome-align-with-co")
tokenizer = AutoTokenizer.from_pretrained("aneuraz/awesome-align-with-co")

# model parameters
align_layer = 8
threshold = 1e-3

# define inputs
src = 'the universe is vast and who knows how many species are within it.'
tgt = 'ವಿಶ್ವವು ವಿಶಾಲವಾಗಿದೆ ಮತ್ತು ಅದರೊಳಗೆ ಎಷ್ಟು ಜಾತಿಗಳಿವೆ ಎಂದು ಯಾರಿಗೆ ತಿಳಿದಿದೆ.'

# pre-processing
sent_src, sent_tgt = src.strip().split(), tgt.strip().split()

print("="*50)
print("original words:")
print("="*50)
print("Source words:", sent_src)
print("Target words:", sent_tgt)
print()

token_src, token_tgt = [tokenizer.tokenize(word) for word in sent_src], [tokenizer.tokenize(word) for word in sent_tgt]

print("="*50)
print("tokenization (Word -> Subwords):")
print("="*50)
print("src tokenization:")
for i, (word, tokens) in enumerate(zip(sent_src, token_src)):
    print(f"Word {i}: '{word}' -> {tokens}")

print("\tgt tokenization:")
for i, (word, tokens) in enumerate(zip(sent_tgt, token_tgt)):
    print(f"Word {i}: '{word}' -> {tokens}")
print()

flat_tokens_src = list(itertools.chain(*token_src))
flat_tokens_tgt = list(itertools.chain(*token_tgt))

print("="*50)
print("="*50)
print("src subwords :")
for i, token in enumerate(flat_tokens_src):
    print(f"Subword {i}: '{token}'")

print("\n tgt subwords:")
for i, token in enumerate(flat_tokens_tgt):
    print(f"Subword {i}: '{token}'")
print()

# Convert to IDs
wid_src, wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_src], [tokenizer.convert_tokens_to_ids(x) for x in token_tgt]

print("="*50)
print("token ids (Word level):")
print("="*50)
print("src token ids:")
for i, (word, tokens, ids) in enumerate(zip(sent_src, token_src, wid_src)):
    print(f"Word {i}: '{word}' -> tokens: {tokens} -> ids: {ids}")

print("\tgt token ids:")
for i, (word, tokens, ids) in enumerate(zip(sent_tgt, token_tgt, wid_tgt)):
    print(f"Word {i}: '{word}' -> tokens: {tokens} -> ids: {ids}")
print()

ids_src = tokenizer.prepare_for_model(list(itertools.chain(*wid_src)), return_tensors='pt', model_max_length=tokenizer.model_max_length, truncation=True)['input_ids']
ids_tgt = tokenizer.prepare_for_model(list(itertools.chain(*wid_tgt)), return_tensors='pt', truncation=True, model_max_length=tokenizer.model_max_length)['input_ids']

print("="*50)
print("final model input (with special tokens):")
print("="*50)
print("SOURCE INPUT IDs:", ids_src.squeeze().tolist())
print("SOURCE TOKENS:", tokenizer.convert_ids_to_tokens(ids_src.squeeze().tolist()))

print("\target input ids:", ids_tgt.squeeze().tolist())
print("tgt tokens:", tokenizer.convert_ids_to_tokens(ids_tgt.squeeze().tolist()))
print()

sub2word_map_src = []
sub2word_map_tgt = []

print("="*50)
print("subword to word mapping")
print("="*50)
print("src mapping ")
subword_idx_src = 0
for word_idx, word_tokens in enumerate(token_src):
    print(f"Word {word_idx} ('{sent_src[word_idx]}'):")
    for token in word_tokens:
        print(f"  Subword {subword_idx_src}: '{token}' -> maps to word {word_idx}")
        sub2word_map_src.append(word_idx)
        subword_idx_src += 1

print(f"\nSource sub2word_map: {sub2word_map_src}")

subword_idx_tgt = 0
for word_idx, word_tokens in enumerate(token_tgt):
    print(f"Word {word_idx} ('{sent_tgt[word_idx]}'):")
    for token in word_tokens:
        print(f"  Subword {subword_idx_tgt}: '{token}' -> maps to word {word_idx}")
        sub2word_map_tgt.append(word_idx)
        subword_idx_tgt += 1


# alignment
model.eval()
with torch.no_grad():
    out_src = model(ids_src.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]
    out_tgt = model(ids_tgt.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]
    dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))
    softmax_srctgt = torch.nn.Softmax(dim=-1)(dot_prod)
    softmax_tgtsrc = torch.nn.Softmax(dim=-2)(dot_prod)
    softmax_inter = (softmax_srctgt > threshold)*(softmax_tgtsrc > threshold)

align_subwords = torch.nonzero(softmax_inter, as_tuple=False)
align_words = set()

print("="*50)
print("="*50)
for i, j in align_subwords:
    src_subword = flat_tokens_src[i.item()]
    tgt_subword = flat_tokens_tgt[j.item()]
    src_word_idx = sub2word_map_src[i.item()]
    tgt_word_idx = sub2word_map_tgt[j.item()]
    src_word = sent_src[src_word_idx]
    tgt_word = sent_tgt[tgt_word_idx]
    
    print(f"Subword {i.item()} '{src_subword}' (from word {src_word_idx} '{src_word}') <-> "
          f"Subword {j.item()} '{tgt_subword}' (from word {tgt_word_idx} '{tgt_word}')")
    
    align_words.add((src_word_idx, tgt_word_idx))

print(f"\nWORD-LEVEL ALIGNMENTS:")
for src_idx, tgt_idx in sorted(align_words):
    print(f"Word {src_idx} '{sent_src[src_idx]}' <-> Word {tgt_idx} '{sent_tgt[tgt_idx]}'")

print(f"\nFinal alignment set: {align_words}")

config.json: 0.00B [00:00, ?B/s]

2025-07-01 14:22:58.823366: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751379779.014786      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751379779.074126      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


pytorch_model.bin:   0%|          | 0.00/1.09G [00:00<?, ?B/s]

Some weights of BertModel were not initialized from the model checkpoint at aneuraz/awesome-align-with-co and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/1.09G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/40.0 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


ORIGINAL WORDS:
Source words: ['the', 'universe', 'is', 'vast', 'and', 'who', 'knows', 'how', 'many', 'species', 'are', 'within', 'it.']
Target words: ['ವಿಶ್ವವು', 'ವಿಶಾಲವಾಗಿದೆ', 'ಮತ್ತು', 'ಅದರೊಳಗೆ', 'ಎಷ್ಟು', 'ಜಾತಿಗಳಿವೆ', 'ಎಂದು', 'ಯಾರಿಗೆ', 'ತಿಳಿದಿದೆ.']

TOKENIZATION (Word -> Subwords):
SOURCE TOKENIZATION:
Word 0: 'the' -> ['the']
Word 1: 'universe' -> ['universe']
Word 2: 'is' -> ['is']
Word 3: 'vast' -> ['vast']
Word 4: 'and' -> ['and']
Word 5: 'who' -> ['who']
Word 6: 'knows' -> ['knows']
Word 7: 'how' -> ['how']
Word 8: 'many' -> ['many']
Word 9: 'species' -> ['species']
Word 10: 'are' -> ['are']
Word 11: 'within' -> ['within']
Word 12: 'it.' -> ['it', '.']

TARGET TOKENIZATION:
Word 0: 'ವಿಶ್ವವು' -> ['ವ', '##ಿ', '##ಶ್', '##ವ', '##ವು']
Word 1: 'ವಿಶಾಲವಾಗಿದೆ' -> ['ವ', '##ಿ', '##ಶ', '##ಾಲ', '##ವಾಗಿದೆ']
Word 2: 'ಮತ್ತು' -> ['ಮತ್ತು']
Word 3: 'ಅದರೊಳಗೆ' -> ['ಅದರ', '##ೊ', '##ಳ', '##ಗೆ']
Word 4: 'ಎಷ್ಟು' -> ['ಎ', '##ಷ್ಟು']
Word 5: 'ಜಾತಿಗಳಿವೆ' -> ['ಜ', '##ಾತ', '##ಿ', '##ಗಳ', '##ಿವೆ']
Word 6: 'ಎಂದ

In [1]:
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast
import itertools
import torch
from collections import defaultdict, Counter

align_model = AutoModel.from_pretrained("aneuraz/awesome-align-with-co")
align_tokenizer = AutoTokenizer.from_pretrained("aneuraz/awesome-align-with-co")

kn_tokenizer = PreTrainedTokenizerFast.from_pretrained("charanhu/kannada-tokenizer")

align_layer = 8
threshold = 1e-3

all_token_pairs = []  
pair_frequency = Counter() 

def get_alignments(src_text, tgt_text):
    """
    Get word alignments using awesome-align model
    Returns: set of (src_word_idx, tgt_word_idx) tuples
    """
    sent_src = src_text.strip().split()
    sent_tgt = tgt_text.strip().split()

    # awesome aligner for mapping src to tgt word 
    token_src = [align_tokenizer.tokenize(word) for word in sent_src]
    token_tgt = [align_tokenizer.tokenize(word) for word in sent_tgt]
    
    # tokens to IDs
    wid_src = [align_tokenizer.convert_tokens_to_ids(x) for x in token_src]
    wid_tgt = [align_tokenizer.convert_tokens_to_ids(x) for x in token_tgt]
    
    ids_src = align_tokenizer.prepare_for_model(
        list(itertools.chain(*wid_src)), 
        return_tensors='pt', 
        model_max_length=align_tokenizer.model_max_length, 
        truncation=True
    )['input_ids']
    
    ids_tgt = align_tokenizer.prepare_for_model(
        list(itertools.chain(*wid_tgt)), 
        return_tensors='pt', 
        truncation=True, 
        model_max_length=align_tokenizer.model_max_length
    )['input_ids']
    
    # Create subword to word mapping
    sub2word_map_src = []
    for i, word_list in enumerate(token_src):
        sub2word_map_src += [i for x in word_list]
    
    sub2word_map_tgt = []
    for i, word_list in enumerate(token_tgt):
        sub2word_map_tgt += [i for x in word_list]
    
    # Get alignments
    align_model.eval()
    with torch.no_grad():
        out_src = align_model(ids_src.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]
        out_tgt = align_model(ids_tgt.unsqueeze(0), output_hidden_states=True)[2][align_layer][0, 1:-1]
        
        dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))
        softmax_srctgt = torch.nn.Softmax(dim=-1)(dot_prod)
        softmax_tgtsrc = torch.nn.Softmax(dim=-2)(dot_prod)
        softmax_inter = (softmax_srctgt > threshold) * (softmax_tgtsrc > threshold)
    
    align_subwords = torch.nonzero(softmax_inter, as_tuple=False)
    
    # word level  alignments 
    align_words = set()
    for i, j in align_subwords:
        align_words.add((sub2word_map_src[i], sub2word_map_tgt[j]))
    
    return align_words, sent_src, sent_tgt

for idx, (src_text, tgt_text) in enumerate(zip(train_data["src"], train_data["tgt"])):
    if idx % 100 == 0: 
        print(f"Processed {idx} sentences...")
    
    try:
        alignments, src_words, tgt_words = get_alignments(src_text, tgt_text)
        
        en_tokens = align_tokenizer.tokenize(src_text)  
        
        kn_tokens = align_tokenizer.tokenize(tgt_text)    
        
        # create token pairs from alignments
        for src_idx, tgt_idx in alignments:
            if src_idx < len(en_tokens) and tgt_idx < len(kn_tokens):
                en_token = en_tokens[src_idx]
                kn_token = kn_tokens[tgt_idx] if tgt_idx < len(kn_tokens) else tgt_words[tgt_idx]
                
                # create token pair
                token_pair = (en_token, kn_token)
                all_token_pairs.append(token_pair)
                
                # ct frequency
                pair_frequency[token_pair] += 1
        
        if idx < 3:
            print(f"\nExample {idx + 1}:")
            print(f"English: {src_text}")
            print(f"Kannada: {tgt_text}")
            print(f"Alignments: {alignments}")
            print(f"Token pairs: {[(en_tokens[i], kn_tokens[j] if j < len(kn_tokens) else tgt_words[j]) for i, j in alignments if i < len(en_tokens)]}")
            
    except Exception as e:
        print(f"Error processing sentence {idx}: {e}")
        continue

print(f"\nAlignment processing complete!")
print(f"Total unique token pairs: {len(pair_frequency)}")
print(f"Total alignment instances: {sum(pair_frequency.values())}")

# Display most frequent token pairs
print("\nMost frequent token pairs:")
for (en_token, kn_token), freq in pair_frequency.most_common(20):
    print(f"'{en_token}' <-> '{kn_token}': {freq} times")

# Filter by minimum frequency threshold
min_frequency = 5
filtered_pairs = {pair: freq for pair, freq in pair_frequency.items() if freq >= min_frequency}
print(f"\nToken pairs with frequency >= {min_frequency}: {len(filtered_pairs)}")

torch.save(pair_frequency, 'kannada_english_token_pairs.pt')
print("tok pair saved")

config.json: 0.00B [00:00, ?B/s]

2025-07-03 14:00:47.202271: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751551247.394399      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751551247.452753      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


pytorch_model.bin:   0%|          | 0.00/1.09G [00:00<?, ?B/s]

Some weights of BertModel were not initialized from the model checkpoint at aneuraz/awesome-align-with-co and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/1.09G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/40.0 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Processing sentence pairs...


NameError: name 'train_data' is not defined

In [1]:
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
import itertools
import torch
import json
import pickle
from tqdm import tqdm
import os
from typing import List, Tuple, Set, Dict, Any
import gc

class WordAligner:
    def __init__(self, model_name="aneuraz/awesome-align-with-co", align_layer=8, threshold=1e-3, use_gpu=True):
        """Initialize the word aligner with model and parameters."""
        print("Loading model and tokenizer...")
        self.device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")
        print(f"Using device: {self.device}")
        
        self.model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.align_layer = align_layer
        self.threshold = threshold
        self.model.eval()
        
        
        self.use_amp = use_gpu and torch.cuda.is_available()
        if self.use_amp:
            print("mixed precision for faster inference")
        
        print("model loaded")
    
    def align_sentence_pair(self, src_text: str, tgt_text: str) -> Dict[str, Any]:
        """
            src_words: list of source words
            tgt_words: list of target words
            word_alignments: set of (src_idx, tgt_idx) tuples
            subword_alignments: list of detailed subword alignment info
        """
            sent_src = src_text.strip().split()
            sent_tgt = tgt_text.strip().split()
            
            if not sent_src or not sent_tgt:
                return None
            
            token_src = [self.tokenizer.tokenize(word) for word in sent_src]
            token_tgt = [self.tokenizer.tokenize(word) for word in sent_tgt]
            
            if not all(token_src) or not all(token_tgt):
                return None
            
            # flatten token lists 
            flat_tokens_src = list(itertools.chain(*token_src))
            flat_tokens_tgt = list(itertools.chain(*token_tgt))
            
            wid_src = [self.tokenizer.convert_tokens_to_ids(x) for x in token_src]
            wid_tgt = [self.tokenizer.convert_tokens_to_ids(x) for x in token_tgt]
            
            ids_src = self.tokenizer.prepare_for_model(
                list(itertools.chain(*wid_src)), 
                return_tensors='pt', 
                model_max_length=self.tokenizer.model_max_length, 
                truncation=True
            )['input_ids'].to(self.device)
            
            ids_tgt = self.tokenizer.prepare_for_model(
                list(itertools.chain(*wid_tgt)), 
                return_tensors='pt', 
                truncation=True, 
                model_max_length=self.tokenizer.model_max_length
            )['input_ids'].to(self.device)
            
            # subword to word mappiing 
            sub2word_map_src = []
            for word_idx, word_tokens in enumerate(token_src):
                sub2word_map_src.extend([word_idx] * len(word_tokens))
            
            sub2word_map_tgt = []
            for word_idx, word_tokens in enumerate(token_tgt):
                sub2word_map_tgt.extend([word_idx] * len(word_tokens))
            
            # mixed precision 
            with torch.no_grad():
                if self.use_amp:
                    with torch.cuda.amp.autocast():
                        out_src = self.model(ids_src.unsqueeze(0), output_hidden_states=True)[2][self.align_layer][0, 1:-1]
                        out_tgt = self.model(ids_tgt.unsqueeze(0), output_hidden_states=True)[2][self.align_layer][0, 1:-1]
                        dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))
                else:
                    out_src = self.model(ids_src.unsqueeze(0), output_hidden_states=True)[2][self.align_layer][0, 1:-1]
                    out_tgt = self.model(ids_tgt.unsqueeze(0), output_hidden_states=True)[2][self.align_layer][0, 1:-1]
                    dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))
                
                softmax_srctgt = torch.nn.Softmax(dim=-1)(dot_prod)
                softmax_tgtsrc = torch.nn.Softmax(dim=-2)(dot_prod)
                softmax_inter = (softmax_srctgt > self.threshold) * (softmax_tgtsrc > self.threshold)
            
            align_subwords = torch.nonzero(softmax_inter, as_tuple=False)
            align_words = set()
            subword_alignments = []
            
            for i, j in align_subwords:
                i_idx, j_idx = i.item(), j.item()
                
                if i_idx >= len(sub2word_map_src) or j_idx >= len(sub2word_map_tgt):
                    continue
                
                src_subword = flat_tokens_src[i_idx]
                tgt_subword = flat_tokens_tgt[j_idx]
                src_word_idx = sub2word_map_src[i_idx]
                tgt_word_idx = sub2word_map_tgt[j_idx]
                
                subword_alignments.append({
                    'src_subword_idx': i_idx,
                    'tgt_subword_idx': j_idx,
                    'src_subword': src_subword,
                    'tgt_subword': tgt_subword,
                    'src_word_idx': src_word_idx,
                    'tgt_word_idx': tgt_word_idx,
                    'src_word': sent_src[src_word_idx],
                    'tgt_word': sent_tgt[tgt_word_idx]
                })
                
                align_words.add((src_word_idx, tgt_word_idx))
            
            return {
                'src_text': src_text,
                'tgt_text': tgt_text,
                'src_words': sent_src,
                'tgt_words': sent_tgt,
                'word_alignments': list(align_words),
                'subword_alignments': subword_alignments,
                'alignment_count': len(align_words)
            }
            

def process_dataset_batch(dataset, aligner, output_dir="alignment_output", batch_size=4000, max_samples=None):
    """Process the entire dataset in batches and save alignments."""
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Determine total samples to process
    total_samples = len(dataset) if max_samples is None else min(max_samples, len(dataset))
    print(f"Processing {total_samples} sentence pairs...")
    
    # Process in batches
    batch_num = 0
    alignments = []
    successful_alignments = 0
    failed_alignments = 0
    
    for i in tqdm(range(0, total_samples, batch_size), desc="Processing batches"):
        batch_alignments = []
        
        # Process current batch
        end_idx = min(i + batch_size, total_samples)
        for j in range(i, end_idx):
            sample = dataset[j]
            src_text = sample['src']  # English
            tgt_text = sample['tgt']  # Kannada
            
            alignment_result = aligner.align_sentence_pair(src_text, tgt_text)
            
            if alignment_result is not None:
                alignment_result['sample_idx'] = j
                batch_alignments.append(alignment_result)
                successful_alignments += 1
            else:
                failed_alignments += 1
        
        # Save batch to disk
        if batch_alignments:
            batch_file = os.path.join(output_dir, f"alignments_batch_{batch_num:04d}.pkl")
            with open(batch_file, 'wb') as f:
                pickle.dump(batch_alignments, f)
            
            # Also save as JSON for human readability (smaller sample)
            if len(batch_alignments) > 0:
                json_sample = batch_alignments[:10]  # Save first 10 as JSON for inspection
                json_file = os.path.join(output_dir, f"alignments_batch_{batch_num:04d}_sample.json")
                with open(json_file, 'w', encoding='utf-8') as f:
                    json.dump(json_sample, f, ensure_ascii=False, indent=2)
        
        alignments.extend(batch_alignments)
        batch_num += 1
        
        # Clear memory periodically
        if batch_num % 10 == 0:
            gc.collect()
            print(f"Processed {batch_num} batches. Success: {successful_alignments}, Failed: {failed_alignments}")
    
    # Save final statistics
    stats = {
        'total_processed': total_samples,
        'successful_alignments': successful_alignments,
        'failed_alignments': failed_alignments,
        'success_rate': successful_alignments / total_samples * 100,
        'total_batches': batch_num
    }
    
    stats_file = os.path.join(output_dir, "alignment_stats.json")
    with open(stats_file, 'w') as f:
        json.dump(stats, f, indent=2)
    
    print(f"\nAlignment completed!")
    print(f"Successful alignments: {successful_alignments}")
    print(f"Failed alignments: {failed_alignments}")
    print(f"Success rate: {stats['success_rate']:.2f}%")
    print(f"Results saved in: {output_dir}")
    
    return alignments, stats

def load_alignments_from_batches(output_dir="alignment_output"):
    """Load all alignment batches from disk."""
    alignments = []
    batch_files = sorted([f for f in os.listdir(output_dir) if f.startswith("alignments_batch_") and f.endswith(".pkl")])
    
    print(f"Loading {len(batch_files)} batch files...")
    for batch_file in tqdm(batch_files):
        with open(os.path.join(output_dir, batch_file), 'rb') as f:
            batch_alignments = pickle.load(f)
            alignments.extend(batch_alignments)
    
    return alignments

def create_translation_training_data(alignments, output_file="translation_training_data.json"):
    """Create training data optimized for translation models."""
    training_data = []
    
    for alignment in alignments:
        # Basic sentence pair
        training_sample = {
            'src': alignment['src_text'],
            'tgt': alignment['tgt_text'],
            'word_alignments': alignment['word_alignments'],
            'alignment_count': alignment['alignment_count']
        }
        
        # Add aligned word pairs for additional supervision
        aligned_pairs = []
        for src_idx, tgt_idx in alignment['word_alignments']:
            aligned_pairs.append({
                'src_word': alignment['src_words'][src_idx],
                'tgt_word': alignment['tgt_words'][tgt_idx],
                'src_idx': src_idx,
                'tgt_idx': tgt_idx
            })
        
        training_sample['aligned_word_pairs'] = aligned_pairs
        training_data.append(training_sample)
    
    # Save training data
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(training_data, f, ensure_ascii=False, indent=2)
    
    print(f"Training data saved to: {output_file}")
    return training_data

# Main execution
if __name__ == "__main__":
    # Load dataset
    print("Loading Samanantar Kannada dataset...")
    ds = load_dataset("ai4bharat/samanantar", "kn")
    train_data = ds['train']
    print(f"Dataset loaded with {len(train_data)} samples")
    
    # Initialize aligner
    aligner = WordAligner()

    # temp testing purposes 
    max_samples = 400000 
    
    alignments, stats = process_dataset_batch(
        train_data, 
        aligner, 
        output_dir="/kaggle/working/samanantar_alignments",
        batch_size=2000,  # Increased batch size for GPU
        max_samples=max_samples
    )
    
    # Create translation training data
    training_data = create_translation_training_data(
        alignments, 
        "/kaggle/working/samanantar_kn_translation_training.json"
    )
    
    print(f"\nProcessing complete! Generated {len(training_data)} training samples.")
    
    # Example of how to use the alignments
    if alignments:
        print("\nExample alignment:")
        example = alignments[0]
        print(f"Source: {example['src_text']}")
        print(f"Target: {example['tgt_text']}")
        print(f"Word alignments: {example['word_alignments']}")
        print("Aligned word pairs:")
        for src_idx, tgt_idx in example['word_alignments']:
            print(f"  '{example['src_words'][src_idx]}' <-> '{example['tgt_words'][tgt_idx]}'")

Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)
  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)
  Downloading nvidia_nvjitlink_cu12-12.4.127-py3-n

README.md: 0.00B [00:00, ?B/s]

train-00000-of-00002.parquet:   0%|          | 0.00/242M [00:00<?, ?B/s]

train-00001-of-00002.parquet:   0%|          | 0.00/242M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4093524 [00:00<?, ? examples/s]

Dataset loaded with 4093524 samples
Loading model and tokenizer...
Using device: cuda


config.json: 0.00B [00:00, ?B/s]

2025-07-04 08:50:37.015669: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751619037.225846      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751619037.289241      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


pytorch_model.bin:   0%|          | 0.00/1.09G [00:00<?, ?B/s]

Some weights of BertModel were not initialized from the model checkpoint at aneuraz/awesome-align-with-co and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/1.09G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/40.0 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Using mixed precision for faster inference
Model loaded successfully!
Processing 400000 sentence pairs...



Processing batches:   0%|          | 0/200 [00:00<?, ?it/s][AYou're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  with torch.cuda.amp.autocast():

Processing batches:   0%|          | 1/200 [00:46<2:32:57, 46.12s/it][A
Processing batches:   1%|          | 2/200 [01:28<2:25:18, 44.03s/it][A
Processing batches:   2%|▏         | 3/200 [02:11<2:22:15, 43.33s/it][A
Processing batches:   2%|▏         | 4/200 [02:53<2:20:51, 43.12s/it][A
Processing batches:   2%|▎         | 5/200 [03:36<2:19:55, 43.05s/it][A
Processing batches:   3%|▎         | 6/200 [04:19<2:19:10, 43.05s/it][A
Processing batches:   4%|▎         | 7/200 [05:02<2:18:05, 42.93s/it][A
Processing batches:   4%|▍         | 8/200 [05:45<2:17:04, 42.84s/it][A
Processing batches:   4%|▍         | 9/200 [06:27<2:16:09, 42.77s/it][A
Processing batches:

Processed 10 batches. Success: 19940, Failed: 60



Processing batches:   6%|▌         | 11/200 [07:53<2:15:00, 42.86s/it][A
Processing batches:   6%|▌         | 12/200 [08:36<2:14:02, 42.78s/it][A
Processing batches:   6%|▋         | 13/200 [09:19<2:13:06, 42.71s/it][A
Processing batches:   7%|▋         | 14/200 [10:01<2:12:03, 42.60s/it][A
Processing batches:   8%|▊         | 15/200 [10:43<2:11:12, 42.56s/it][A
Processing batches:   8%|▊         | 16/200 [11:26<2:10:20, 42.50s/it][A
Processing batches:   8%|▊         | 17/200 [12:09<2:09:50, 42.57s/it][A
Processing batches:   9%|▉         | 18/200 [12:51<2:09:19, 42.63s/it][A
Processing batches:  10%|▉         | 19/200 [13:34<2:08:48, 42.70s/it][A
Processing batches:  10%|█         | 20/200 [14:18<2:08:48, 42.93s/it][A

Processed 20 batches. Success: 39873, Failed: 127



Processing batches:  10%|█         | 21/200 [15:01<2:08:11, 42.97s/it][A
Processing batches:  11%|█         | 22/200 [15:43<2:07:09, 42.86s/it][A
Processing batches:  12%|█▏        | 23/200 [16:26<2:06:20, 42.83s/it][A
Processing batches:  12%|█▏        | 24/200 [17:09<2:05:35, 42.81s/it][A
Processing batches:  12%|█▎        | 25/200 [17:52<2:05:22, 42.99s/it][A
Processing batches:  13%|█▎        | 26/200 [18:35<2:04:28, 42.92s/it][A
Processing batches:  14%|█▎        | 27/200 [19:17<2:03:19, 42.77s/it][A
Processing batches:  14%|█▍        | 28/200 [20:00<2:02:21, 42.68s/it][A
Processing batches:  14%|█▍        | 29/200 [20:43<2:02:08, 42.86s/it][A
Processing batches:  15%|█▌        | 30/200 [21:27<2:02:33, 43.25s/it][A

Processed 30 batches. Success: 59801, Failed: 199



Processing batches:  16%|█▌        | 31/200 [22:10<2:01:23, 43.10s/it][A
Processing batches:  16%|█▌        | 32/200 [22:53<2:00:16, 42.96s/it][A
Processing batches:  16%|█▋        | 33/200 [23:36<1:59:31, 42.94s/it][A
Processing batches:  17%|█▋        | 34/200 [24:18<1:58:43, 42.91s/it][A
Processing batches:  18%|█▊        | 35/200 [25:01<1:57:41, 42.80s/it][A
Processing batches:  18%|█▊        | 36/200 [25:44<1:56:46, 42.72s/it][A
Processing batches:  18%|█▊        | 37/200 [26:26<1:55:51, 42.65s/it][A
Processing batches:  19%|█▉        | 38/200 [27:08<1:54:56, 42.57s/it][A
Processing batches:  20%|█▉        | 39/200 [27:51<1:54:09, 42.54s/it][A
Processing batches:  20%|██        | 40/200 [28:34<1:54:02, 42.77s/it][A

Processed 40 batches. Success: 79730, Failed: 270



Processing batches:  20%|██        | 41/200 [29:17<1:53:08, 42.69s/it][A
Processing batches:  21%|██        | 42/200 [29:59<1:52:12, 42.61s/it][A
Processing batches:  22%|██▏       | 43/200 [30:41<1:51:09, 42.48s/it][A
Processing batches:  22%|██▏       | 44/200 [31:24<1:50:23, 42.46s/it][A
Processing batches:  22%|██▎       | 45/200 [32:06<1:49:40, 42.45s/it][A
Processing batches:  23%|██▎       | 46/200 [32:48<1:48:50, 42.41s/it][A
Processing batches:  24%|██▎       | 47/200 [33:31<1:47:59, 42.35s/it][A
Processing batches:  24%|██▍       | 48/200 [34:13<1:47:05, 42.27s/it][A
Processing batches:  24%|██▍       | 49/200 [34:55<1:46:25, 42.29s/it][A
Processing batches:  25%|██▌       | 50/200 [35:38<1:46:04, 42.43s/it][A

Processed 50 batches. Success: 99659, Failed: 341



Processing batches:  26%|██▌       | 51/200 [36:20<1:45:00, 42.29s/it][A
Processing batches:  26%|██▌       | 52/200 [37:02<1:44:11, 42.24s/it][A
Processing batches:  26%|██▋       | 53/200 [37:44<1:43:21, 42.18s/it][A
Processing batches:  27%|██▋       | 54/200 [38:26<1:42:51, 42.27s/it][A
Processing batches:  28%|██▊       | 55/200 [39:08<1:41:59, 42.20s/it][A
Processing batches:  28%|██▊       | 56/200 [39:51<1:41:16, 42.20s/it][A
Processing batches:  28%|██▊       | 57/200 [40:33<1:40:20, 42.10s/it][A
Processing batches:  29%|██▉       | 58/200 [41:14<1:39:28, 42.03s/it][A
Processing batches:  30%|██▉       | 59/200 [41:57<1:39:16, 42.24s/it][A
Processing batches:  30%|███       | 60/200 [42:40<1:39:09, 42.50s/it][A

Processed 60 batches. Success: 119594, Failed: 406



Processing batches:  30%|███       | 61/200 [43:22<1:38:07, 42.35s/it][A
Processing batches:  31%|███       | 62/200 [44:04<1:37:09, 42.24s/it][A
Processing batches:  32%|███▏      | 63/200 [44:47<1:36:40, 42.34s/it][A
Processing batches:  32%|███▏      | 64/200 [45:29<1:35:46, 42.26s/it][A
Processing batches:  32%|███▎      | 65/200 [46:11<1:34:48, 42.14s/it][A
Processing batches:  33%|███▎      | 66/200 [46:53<1:34:08, 42.15s/it][A
Processing batches:  34%|███▎      | 67/200 [47:35<1:33:08, 42.02s/it][A
Processing batches:  34%|███▍      | 68/200 [48:16<1:32:18, 41.96s/it][A
Processing batches:  34%|███▍      | 69/200 [48:58<1:31:35, 41.95s/it][A
Processing batches:  35%|███▌      | 70/200 [49:42<1:31:51, 42.40s/it][A

Processed 70 batches. Success: 139518, Failed: 482



Processing batches:  36%|███▌      | 71/200 [50:24<1:31:07, 42.38s/it][A
Processing batches:  36%|███▌      | 72/200 [51:06<1:30:23, 42.37s/it][A
Processing batches:  36%|███▋      | 73/200 [51:49<1:29:31, 42.29s/it][A
Processing batches:  37%|███▋      | 74/200 [52:31<1:28:48, 42.29s/it][A
Processing batches:  38%|███▊      | 75/200 [53:13<1:28:11, 42.33s/it][A
Processing batches:  38%|███▊      | 76/200 [53:56<1:27:38, 42.40s/it][A
Processing batches:  38%|███▊      | 77/200 [54:38<1:26:50, 42.36s/it][A
Processing batches:  39%|███▉      | 78/200 [55:21<1:26:09, 42.38s/it][A
Processing batches:  40%|███▉      | 79/200 [56:02<1:25:01, 42.16s/it][A
Processing batches:  40%|████      | 80/200 [56:46<1:25:04, 42.53s/it][A

Processed 80 batches. Success: 159445, Failed: 555



Processing batches:  40%|████      | 81/200 [57:28<1:24:08, 42.42s/it][A
Processing batches:  41%|████      | 82/200 [58:10<1:23:04, 42.24s/it][A
Processing batches:  42%|████▏     | 83/200 [58:51<1:22:07, 42.12s/it][A
Processing batches:  42%|████▏     | 84/200 [59:33<1:21:13, 42.01s/it][A
Processing batches:  42%|████▎     | 85/200 [1:00:15<1:20:23, 41.95s/it][A
Processing batches:  43%|████▎     | 86/200 [1:00:57<1:19:34, 41.88s/it][A
Processing batches:  44%|████▎     | 87/200 [1:01:38<1:18:47, 41.83s/it][A
Processing batches:  44%|████▍     | 88/200 [1:02:20<1:18:04, 41.82s/it][A
Processing batches:  44%|████▍     | 89/200 [1:03:02<1:17:13, 41.74s/it][A
Processing batches:  45%|████▌     | 90/200 [1:03:45<1:17:05, 42.05s/it][A

Processed 90 batches. Success: 179372, Failed: 628



Processing batches:  46%|████▌     | 91/200 [1:04:27<1:16:23, 42.05s/it][A
Processing batches:  46%|████▌     | 92/200 [1:05:09<1:15:40, 42.04s/it][A
Processing batches:  46%|████▋     | 93/200 [1:05:50<1:14:51, 41.98s/it][A
Processing batches:  47%|████▋     | 94/200 [1:06:32<1:14:02, 41.91s/it][A
Processing batches:  48%|████▊     | 95/200 [1:07:15<1:13:37, 42.08s/it][A
Processing batches:  48%|████▊     | 96/200 [1:07:57<1:13:10, 42.22s/it][A
Processing batches:  48%|████▊     | 97/200 [1:08:39<1:12:24, 42.18s/it][A
Processing batches:  49%|████▉     | 98/200 [1:09:22<1:11:47, 42.23s/it][A
Processing batches:  50%|████▉     | 99/200 [1:10:05<1:11:23, 42.41s/it][A
Processing batches:  50%|█████     | 100/200 [1:10:49<1:11:33, 42.93s/it][A

Processed 100 batches. Success: 199303, Failed: 697



Processing batches:  50%|█████     | 101/200 [1:11:31<1:10:47, 42.91s/it][A
Processing batches:  51%|█████     | 102/200 [1:12:15<1:10:08, 42.94s/it][A
Processing batches:  52%|█████▏    | 103/200 [1:12:57<1:09:25, 42.95s/it][A
Processing batches:  52%|█████▏    | 104/200 [1:13:40<1:08:19, 42.70s/it][A
Processing batches:  52%|█████▎    | 105/200 [1:14:22<1:07:32, 42.66s/it][A
Processing batches:  53%|█████▎    | 106/200 [1:15:05<1:06:48, 42.64s/it][A
Processing batches:  54%|█████▎    | 107/200 [1:15:48<1:06:18, 42.78s/it][A
Processing batches:  54%|█████▍    | 108/200 [1:16:30<1:05:30, 42.73s/it][A
Processing batches:  55%|█████▍    | 109/200 [1:17:13<1:04:50, 42.75s/it][A
Processing batches:  55%|█████▌    | 110/200 [1:17:57<1:04:32, 43.03s/it][A

Processed 110 batches. Success: 219235, Failed: 765



Processing batches:  56%|█████▌    | 111/200 [1:18:39<1:03:32, 42.84s/it][A
Processing batches:  56%|█████▌    | 112/200 [1:19:22<1:02:49, 42.84s/it][A
Processing batches:  56%|█████▋    | 113/200 [1:20:05<1:02:05, 42.82s/it][A
Processing batches:  57%|█████▋    | 114/200 [1:20:48<1:01:31, 42.92s/it][A
Processing batches:  57%|█████▊    | 115/200 [1:21:31<1:00:42, 42.85s/it][A
Processing batches:  58%|█████▊    | 116/200 [1:22:14<1:00:07, 42.94s/it][A
Processing batches:  58%|█████▊    | 117/200 [1:22:57<59:29, 43.01s/it]  [A
Processing batches:  59%|█████▉    | 118/200 [1:23:40<58:44, 42.98s/it][A
Processing batches:  60%|█████▉    | 119/200 [1:24:23<57:49, 42.83s/it][A
Processing batches:  60%|██████    | 120/200 [1:25:06<57:31, 43.14s/it][A

Processed 120 batches. Success: 239172, Failed: 828



Processing batches:  60%|██████    | 121/200 [1:25:49<56:32, 42.94s/it][A
Processing batches:  61%|██████    | 122/200 [1:26:31<55:33, 42.73s/it][A
Processing batches:  62%|██████▏   | 123/200 [1:27:13<54:34, 42.52s/it][A
Processing batches:  62%|██████▏   | 124/200 [1:27:55<53:44, 42.42s/it][A
Processing batches:  62%|██████▎   | 125/200 [1:28:37<52:55, 42.34s/it][A
Processing batches:  63%|██████▎   | 126/200 [1:29:20<52:08, 42.28s/it][A
Processing batches:  64%|██████▎   | 127/200 [1:30:02<51:23, 42.25s/it][A
Processing batches:  64%|██████▍   | 128/200 [1:30:44<50:41, 42.24s/it][A
Processing batches:  64%|██████▍   | 129/200 [1:31:26<50:02, 42.30s/it][A
Processing batches:  65%|██████▌   | 130/200 [1:32:10<49:53, 42.76s/it][A

Processed 130 batches. Success: 259110, Failed: 890



Processing batches:  66%|██████▌   | 131/200 [1:32:52<48:57, 42.57s/it][A
Processing batches:  66%|██████▌   | 132/200 [1:33:35<48:08, 42.47s/it][A
Processing batches:  66%|██████▋   | 133/200 [1:34:17<47:25, 42.48s/it][A
Processing batches:  67%|██████▋   | 134/200 [1:34:59<46:40, 42.43s/it][A
Processing batches:  68%|██████▊   | 135/200 [1:35:42<45:59, 42.46s/it][A
Processing batches:  68%|██████▊   | 136/200 [1:36:25<45:34, 42.73s/it][A
Processing batches:  68%|██████▊   | 137/200 [1:37:08<44:54, 42.76s/it][A
Processing batches:  69%|██████▉   | 138/200 [1:37:51<44:03, 42.64s/it][A
Processing batches:  70%|██████▉   | 139/200 [1:38:33<43:22, 42.66s/it][A
Processing batches:  70%|███████   | 140/200 [1:39:17<43:04, 43.08s/it][A

Processed 140 batches. Success: 279044, Failed: 956



Processing batches:  70%|███████   | 141/200 [1:40:00<42:14, 42.97s/it][A
Processing batches:  71%|███████   | 142/200 [1:40:43<41:29, 42.93s/it][A
Processing batches:  72%|███████▏  | 143/200 [1:41:26<40:48, 42.95s/it][A
Processing batches:  72%|███████▏  | 144/200 [1:42:09<40:06, 42.97s/it][A
Processing batches:  72%|███████▎  | 145/200 [1:42:52<39:23, 42.97s/it][A
Processing batches:  73%|███████▎  | 146/200 [1:43:34<38:33, 42.85s/it][A
Processing batches:  74%|███████▎  | 147/200 [1:44:17<37:52, 42.89s/it][A
Processing batches:  74%|███████▍  | 148/200 [1:45:01<37:15, 43.00s/it][A
Processing batches:  74%|███████▍  | 149/200 [1:45:43<36:28, 42.90s/it][A
Processing batches:  75%|███████▌  | 150/200 [1:46:28<36:15, 43.51s/it][A

Processed 150 batches. Success: 298970, Failed: 1030



Processing batches:  76%|███████▌  | 151/200 [1:47:11<35:23, 43.33s/it][A
Processing batches:  76%|███████▌  | 152/200 [1:47:54<34:33, 43.19s/it][A
Processing batches:  76%|███████▋  | 153/200 [1:48:37<33:41, 43.01s/it][A
Processing batches:  77%|███████▋  | 154/200 [1:49:20<32:57, 43.00s/it][A
Processing batches:  78%|███████▊  | 155/200 [1:50:02<32:10, 42.90s/it][A
Processing batches:  78%|███████▊  | 156/200 [1:50:45<31:24, 42.82s/it][A
Processing batches:  78%|███████▊  | 157/200 [1:51:27<30:37, 42.74s/it][A
Processing batches:  79%|███████▉  | 158/200 [1:52:11<29:59, 42.86s/it][A
Processing batches:  80%|███████▉  | 159/200 [1:52:54<29:23, 43.01s/it][A
Processing batches:  80%|████████  | 160/200 [1:53:39<29:03, 43.60s/it][A

Processed 160 batches. Success: 318897, Failed: 1103



Processing batches:  80%|████████  | 161/200 [1:54:22<28:15, 43.47s/it][A
Processing batches:  81%|████████  | 162/200 [1:55:05<27:29, 43.42s/it][A
Processing batches:  82%|████████▏ | 163/200 [1:55:49<26:46, 43.41s/it][A
Processing batches:  82%|████████▏ | 164/200 [1:56:32<26:01, 43.38s/it][A
Processing batches:  82%|████████▎ | 165/200 [1:57:16<25:19, 43.41s/it][A
Processing batches:  83%|████████▎ | 166/200 [1:57:59<24:32, 43.30s/it][A
Processing batches:  84%|████████▎ | 167/200 [1:58:42<23:48, 43.28s/it][A
Processing batches:  84%|████████▍ | 168/200 [1:59:25<23:01, 43.19s/it][A
Processing batches:  84%|████████▍ | 169/200 [2:00:08<22:15, 43.09s/it][A
Processing batches:  85%|████████▌ | 170/200 [2:00:53<21:49, 43.66s/it][A

Processed 170 batches. Success: 338823, Failed: 1177



Processing batches:  86%|████████▌ | 171/200 [2:01:36<20:59, 43.42s/it][A
Processing batches:  86%|████████▌ | 172/200 [2:02:19<20:13, 43.32s/it][A
Processing batches:  86%|████████▋ | 173/200 [2:03:02<19:26, 43.19s/it][A
Processing batches:  87%|████████▋ | 174/200 [2:03:44<18:39, 43.08s/it][A
Processing batches:  88%|████████▊ | 175/200 [2:04:27<17:53, 42.93s/it][A
Processing batches:  88%|████████▊ | 176/200 [2:05:09<17:06, 42.78s/it][A
Processing batches:  88%|████████▊ | 177/200 [2:05:52<16:22, 42.73s/it][A
Processing batches:  89%|████████▉ | 178/200 [2:06:35<15:40, 42.77s/it][A
Processing batches:  90%|████████▉ | 179/200 [2:07:18<14:59, 42.84s/it][A
Processing batches:  90%|█████████ | 180/200 [2:08:02<14:26, 43.34s/it][A

Processed 180 batches. Success: 358755, Failed: 1245



Processing batches:  90%|█████████ | 181/200 [2:08:45<13:36, 43.00s/it][A
Processing batches:  91%|█████████ | 182/200 [2:09:27<12:51, 42.89s/it][A
Processing batches:  92%|█████████▏| 183/200 [2:10:10<12:07, 42.81s/it][A
Processing batches:  92%|█████████▏| 184/200 [2:10:53<11:25, 42.86s/it][A
Processing batches:  92%|█████████▎| 185/200 [2:11:35<10:41, 42.74s/it][A
Processing batches:  93%|█████████▎| 186/200 [2:12:18<09:57, 42.69s/it][A
Processing batches:  94%|█████████▎| 187/200 [2:13:00<09:13, 42.61s/it][A
Processing batches:  94%|█████████▍| 188/200 [2:13:43<08:30, 42.56s/it][A
Processing batches:  94%|█████████▍| 189/200 [2:14:25<07:47, 42.46s/it][A
Processing batches:  95%|█████████▌| 190/200 [2:15:09<07:10, 43.02s/it][A

Processed 190 batches. Success: 378678, Failed: 1322



Processing batches:  96%|█████████▌| 191/200 [2:15:52<06:26, 42.90s/it][A
Processing batches:  96%|█████████▌| 192/200 [2:16:34<05:41, 42.73s/it][A
Processing batches:  96%|█████████▋| 193/200 [2:17:17<04:58, 42.63s/it][A
Processing batches:  97%|█████████▋| 194/200 [2:17:59<04:15, 42.57s/it][A
Processing batches:  98%|█████████▊| 195/200 [2:18:41<03:32, 42.52s/it][A
Processing batches:  98%|█████████▊| 196/200 [2:19:24<02:49, 42.45s/it][A
Processing batches:  98%|█████████▊| 197/200 [2:20:06<02:07, 42.39s/it][A
Processing batches:  99%|█████████▉| 198/200 [2:20:48<01:24, 42.37s/it][A
Processing batches: 100%|█████████▉| 199/200 [2:21:30<00:42, 42.29s/it][A
Processing batches: 100%|██████████| 200/200 [2:22:15<00:00, 42.68s/it][A

Processed 200 batches. Success: 398614, Failed: 1386

Alignment completed!
Successful alignments: 398614
Failed alignments: 1386
Success rate: 99.65%
Results saved in: /kaggle/working/samanantar_alignments





Training data saved to: /kaggle/working/samanantar_kn_translation_training.json

Processing complete! Generated 398614 training samples.

Example alignment:
Source: Hes a scientist.
Target: ಇವರು ಸಂಶೋಧಕ ಸ್ವಭಾವದವರು.
Word alignments: [(1, 2), (2, 2), (2, 1), (0, 0)]
Aligned word pairs:
  'a' <-> 'ಸ್ವಭಾವದವರು.'
  'scientist.' <-> 'ಸ್ವಭಾವದವರು.'
  'scientist.' <-> 'ಸಂಶೋಧಕ'
  'Hes' <-> 'ಇವರು'


In [None]:
import json 
with open ('/kaggle/working/samanantar_kn_translation_training.json', 'r') as file:
    data = json.load(file)
    print(data)
    print("\nPretty-printed JSON:")
    print(json.dumps(data, indent=4))


In [1]:
!pip install transformers torch
!pip install -U datasets

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=5000):
        super().__init__()
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 1:
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        
        # self.register_buffer('pe', pe.unsqueeze(0))
    
    def forward(self, x):
        # input shape: [batch_size, seq_len, d_model]
        # pe shape: [1, max_seq_len, d_model]
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        seq_len = query.size(1)
        
        Q = self.q_linear(query).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_linear(key).view(batch_size, key.size(1), self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_linear(value).view(batch_size, value.size(1), self.n_heads, self.head_dim).transpose(1, 2)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
            
        attention_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, V)
        
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        return self.out(context), attention_weights

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        
    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        attn_output, _ = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads)
        self.cross_attention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, encoder_output, self_mask=None, cross_mask=None):
        self_attn_output, _ = self.self_attention(x, x, x, self_mask)
        x = self.norm1(x + self.dropout(self_attn_output))
        
        cross_attn_output, cross_attention_weights = self.cross_attention(x, encoder_output, encoder_output, cross_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))
        
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x, cross_attention_weights

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_seq_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)
        self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src, src_mask=None):
        x = self.embedding(src) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        
        for layer in self.layers:
            x = layer(x, src_mask)
            
        return x

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_seq_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.output_projection = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, tgt, encoder_output, tgt_mask=None, src_mask=None):
        x = self.embedding(tgt) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        
        cross_attention_weights = []
        
        for layer in self.layers:
            x, cross_attn_weights = layer(x, encoder_output, tgt_mask, src_mask)
            cross_attention_weights.append(cross_attn_weights)
            
        output = self.output_projection(x)
        
        return output, cross_attention_weights

class EnglishKannadaTransformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_heads=8, 
                 n_encoder_layers=6, n_decoder_layers=6, d_ff=2048, max_seq_len=5000, dropout=0.1):
        super().__init__()
        
        self.encoder = TransformerEncoder(src_vocab_size, d_model, n_heads, n_encoder_layers, 
                                        d_ff, max_seq_len, dropout)
        self.decoder = TransformerDecoder(tgt_vocab_size, d_model, n_heads, n_decoder_layers, 
                                        d_ff, max_seq_len, dropout)
        
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        encoder_output = self.encoder(src, src_mask)
        decoder_output, cross_attention_weights = self.decoder(tgt, encoder_output, tgt_mask, src_mask)
        
        return decoder_output, cross_attention_weights
    
    def create_masks(self, src, tgt):
        device = src.device
        
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        
        # causal mask + padding mask
        tgt_len = tgt.size(1)
        
        tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len, device=device)).unsqueeze(0).unsqueeze(0).bool()
        
        tgt_padding_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)
        
        tgt_mask = tgt_mask & tgt_padding_mask
        
        return src_mask, tgt_mask
        
def calculate_alignment_loss(cross_attention_weights, alignment_matrix, lambda_align=0.1):
    
    last_layer_attention = cross_attention_weights[-1]  
    
    avg_attention = last_layer_attention.mean(dim=1) 
    
     # [batch, tgt_len, src_len]
    alignment_target = alignment_matrix.transpose(-2, -1).float() 
    
    # Calculate MSE loss
    alignment_loss = F.mse_loss(avg_attention, alignment_target)
    
    return lambda_align * alignment_loss

def calculate_total_loss(model_output, target, cross_attention_weights, alignment_matrix, lambda_align=0.1):
    # Translation loss (cross-entropy)
    vocab_size = model_output.size(-1)
    translation_loss = F.cross_entropy(model_output.view(-1, vocab_size), target.view(-1), ignore_index=0)
    
    # Alignment loss
    alignment_loss = calculate_alignment_loss(cross_attention_weights, alignment_matrix, lambda_align)
    
    total_loss = translation_loss + alignment_loss
    
    return total_loss, translation_loss, alignment_loss

Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)
  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)
  Downloading nvidia_nvjitlink_cu12-12.4.127-py3-n

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import json
from collections import Counter
import numpy as np
from tqdm import tqdm

class TranslationDataset(Dataset):
    def __init__(self, json_file_path, src_vocab=None, tgt_vocab=None, max_len=128):
        with open(json_file_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        
        self.max_len = max_len
        
        if src_vocab is None or tgt_vocab is None:
            self.src_vocab, self.tgt_vocab = self.build_vocabularies()
        else:
            self.src_vocab = src_vocab
            self.tgt_vocab = tgt_vocab
            
        self.src_word2idx = {word: idx for idx, word in enumerate(self.src_vocab)}
        self.tgt_word2idx = {word: idx for idx, word in enumerate(self.tgt_vocab)}
        
        self.processed_data = self.process_data()
        
    def build_vocabularies(self):
        src_words = []
        tgt_words = []
        
        for item in self.data:
            src_words.extend(item['src'].split())
            tgt_words.extend(item['tgt'].split())
            
        src_counter = Counter(src_words)
        tgt_counter = Counter(tgt_words)
        
        src_vocab = ['<pad>', '<unk>', '<sos>', '<eos>'] + [word for word, _ in src_counter.most_common()]
        tgt_vocab = ['<pad>', '<unk>', '<sos>', '<eos>'] + [word for word, _ in tgt_counter.most_common()]
        
        return src_vocab, tgt_vocab
    
    def text_to_indices(self, text, word2idx, is_target=False):
        words = text.split()
        indices = []
        
        if is_target:
            indices.append(word2idx.get('<sos>', 1))
            
        for word in words:
            indices.append(word2idx.get(word, word2idx.get('<unk>', 1)))
            
        if is_target:
            indices.append(word2idx.get('<eos>', 1))
            
        return indices
    
    def create_alignment_matrix(self, src_text, tgt_text, word_alignments):
        src_words = src_text.split()
        tgt_words = tgt_text.split()
        
        src_len = len(src_words)
        tgt_len = len(tgt_words)
        
        alignment_matrix = np.zeros((src_len, tgt_len), dtype=np.float32)
        
        for src_idx, tgt_idx in word_alignments:
            if src_idx < src_len and tgt_idx < tgt_len:
                alignment_matrix[src_idx, tgt_idx] = 1.0
                
        return alignment_matrix
    
    def process_data(self):
        processed = []
        
        for item in self.data:
            src_indices = self.text_to_indices(item['src'], self.src_word2idx)
            tgt_indices = self.text_to_indices(item['tgt'], self.tgt_word2idx, is_target=True)
            
            alignment_matrix = self.create_alignment_matrix(
                item['src'], item['tgt'], item['word_alignments']
            )
            
            processed.append({
                'src': src_indices,
                'tgt': tgt_indices,
                'alignment': alignment_matrix
            })
            
        return processed
    
    def __len__(self):
        return len(self.processed_data)
    
    def __getitem__(self, idx):
        return self.processed_data[idx]

def collate_fn(batch):
    src_seqs = [item['src'] for item in batch]
    tgt_seqs = [item['tgt'] for item in batch]
    alignments = [item['alignment'] for item in batch]
    
    # Pad sequences
    src_max_len = max(len(seq) for seq in src_seqs)
    tgt_max_len = max(len(seq) for seq in tgt_seqs)
    
    src_padded = []
    tgt_padded = []
    tgt_input_padded = []
    alignment_padded = []
    
    for i in range(len(batch)):
        src_seq = src_seqs[i] + [0] * (src_max_len - len(src_seqs[i]))
        tgt_seq = tgt_seqs[i] + [0] * (tgt_max_len - len(tgt_seqs[i]))
        tgt_input_seq = tgt_seq[:-1] + [0] * (tgt_max_len - len(tgt_seq))
        
        # Pad alignment matrix
        alignment = alignments[i]
        padded_alignment = np.zeros((src_max_len, tgt_max_len - 1), dtype=np.float32)
        padded_alignment[:alignment.shape[0], :alignment.shape[1]] = alignment
        
        src_padded.append(src_seq)
        tgt_padded.append(tgt_seq[1:])  # Remove <sos> for target
        tgt_input_padded.append(tgt_input_seq)  # Remove <eos> for input
        alignment_padded.append(padded_alignment)
    
    return {
        'src': torch.LongTensor(src_padded),
        'tgt_input': torch.LongTensor(tgt_input_padded),
        'tgt_output': torch.LongTensor(tgt_padded),
        'alignment': torch.FloatTensor(np.array(alignment_padded))
    }

def train_model(model, train_loader, val_loader, num_epochs, device, learning_rate=0.0001):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_trans_loss = 0
        total_align_loss = 0
        
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        
        for batch_idx, batch in enumerate(progress_bar):
            src = batch['src'].to(device)
            tgt_input = batch['tgt_input'].to(device)
            tgt_output = batch['tgt_output'].to(device)
            alignment = batch['alignment'].to(device)
            
            optimizer.zero_grad()
            
            src_mask, tgt_mask = model.create_masks(src, tgt_input)
            model_output, cross_attention_weights = model(src, tgt_input, src_mask, tgt_mask)
            
            total_loss_batch, trans_loss_batch, align_loss_batch = calculate_total_loss(
                model_output, tgt_output, cross_attention_weights, alignment
            )
            
            total_loss_batch.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            # Clear cache periodically
            if batch_idx % 100 == 0:
                torch.cuda.empty_cache()
            
            total_loss += total_loss_batch.item()
            total_trans_loss += trans_loss_batch.item()
            total_align_loss += align_loss_batch.item()
            
            progress_bar.set_postfix({
                'Loss': f'{total_loss_batch.item():.4f}',
                'Trans': f'{trans_loss_batch.item():.4f}',
                'Align': f'{align_loss_batch.item():.4f}'
            })
            
        avg_train_loss = total_loss / len(train_loader)
        avg_trans_loss = total_trans_loss / len(train_loader)
        avg_align_loss = total_align_loss / len(train_loader)
        
        print(f'Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Trans: {avg_trans_loss:.4f}, Align: {avg_align_loss:.4f}')
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                src = batch['src'].to(device)
                tgt_input = batch['tgt_input'].to(device)
                tgt_output = batch['tgt_output'].to(device)
                alignment = batch['alignment'].to(device)
                
                src_mask, tgt_mask = model.create_masks(src, tgt_input)
                model_output, cross_attention_weights = model(src, tgt_input, src_mask, tgt_mask)
                
                total_loss_batch, _, _ = calculate_total_loss(
                    model_output, tgt_output, cross_attention_weights, alignment
                )
                val_loss += total_loss_batch.item()
        
        avg_val_loss = val_loss / len(val_loader)
        print(f'Epoch {epoch+1} - Val Loss: {avg_val_loss:.4f}')
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), '/kaggle/working/best_model.pth')
            print(f'New best model saved with val loss: {best_val_loss:.4f}')
        
        scheduler.step()
        print()

def main():
    torch.cuda.empty_cache()
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')
    
    dataset = TranslationDataset('/kaggle/input/kn-en-samanantar/samanantar_kn_translation_training (3).json')
    
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)
    
    src_vocab_size = len(dataset.src_vocab)
    tgt_vocab_size = len(dataset.tgt_vocab)
    
    print(f'src vocab size: {src_vocab_size}')
    print(f'target vocab size: {tgt_vocab_size}')
    
    model = EnglishKannadaTransformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=256,
        n_heads=8,
        n_encoder_layers=3,
        n_decoder_layers=3,
        d_ff=1024,
        max_seq_len=512,
        dropout=0.1
    )
    
    print(f'model init w {sum(p.numel() for p in model.parameters())} params')
    
    train_model(model, train_loader, val_loader, num_epochs=5, device=device)
    
    torch.save(model.state_dict(), '/kaggle/working/final_model.pth')
    
    torch.save({
        'src_vocab': dataset.src_vocab,
        'tgt_vocab': dataset.tgt_vocab,
        'src_word2idx': dataset.src_word2idx,
        'tgt_word2idx': dataset.tgt_word2idx
    }, '/kaggle/working/vocabularies.pth')
    
    print('Training completed!')

if __name__ == '__main__':
    main()

Using device: cuda
Source vocabulary size: 174304
Target vocabulary size: 418295
Model initialized with 264736759 parameters


Epoch 1/5:  27%|██▋       | 10735/39862 [25:25<1:10:08,  6.92it/s, Loss=9.6147, Trans=9.6126, Align=0.0021]  

In [23]:
from IPython.display import FileLink
FileLink(r'best_model.pth.zip')


In [27]:
from IPython.display import FileLink
FileLink(r'final_model.pth')
