In [2]:
import pandas as pd
import numpy as np
import random
import sentencepiece as spm
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from collections import Counter  # ADD THIS IMPORT

# Set device and ensure reproducibility
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# --- Data Loading and Preprocessing ---
# Update the file path - make it more flexible
try:
    df = pd.read_csv('/kaggle/input/asm-eng/Filtered_data.tsv', sep='\t', names=['asm', 'eng'], on_bad_lines='skip')
except FileNotFoundError:
    print("Dataset not found at /content/Filtered_data.tsv")
    print("Please ensure the dataset file exists or update the path")
    # Create dummy data for testing if file not found
    df = pd.DataFrame({
        'asm': ['মই ঘৰত থাকোঁ।', 'তেওঁ আজি বিদ্যালয়লৈ গ\'ল।'],
        'eng': ['I am at home.', 'He went to school today.']
    })

train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)

with open('all_text_for_bpe.txt', 'w', encoding='utf-8') as f:
    for text in pd.concat([train_df['asm'], train_df['eng']]):
        f.write(str(text).strip().lower() + '\n')

spm.SentencePieceTrainer.Train(
    '--input=all_text_for_bpe.txt --model_prefix=spm_bpe --vocab_size=8000 '
    '--character_coverage=1.0 --model_type=bpe'
)
sp = spm.SentencePieceProcessor()
sp.Load('spm_bpe.model')

PAD_IDX = 0
SOS_IDX = sp.bos_id()
EOS_IDX = sp.eos_id()
VOCAB_SIZE = sp.GetPieceSize()

# --- Display Vocabulary Information ---
print(f"\n=== VOCABULARY INFORMATION ===")
print(f"Vocabulary size: {VOCAB_SIZE}")
print(f"PAD token: {PAD_IDX} -> '{sp.id_to_piece(PAD_IDX)}'")
print(f"SOS token: {SOS_IDX} -> '{sp.id_to_piece(SOS_IDX)}'")
print(f"EOS token: {EOS_IDX} -> '{sp.id_to_piece(EOS_IDX)}'")
print(f"UNK token: {sp.unk_id()} -> '{sp.id_to_piece(sp.unk_id())}'")

# Display sample vocabulary
print(f"\nSample vocabulary (first 50 tokens):")
for i in range(min(50, VOCAB_SIZE)):
    piece = sp.id_to_piece(i)
    print(f"{i:3d}: '{piece}'")

# Display high-frequency tokens
print(f"\nHigh-frequency tokens (500-600):")
for i in range(500, min(600, VOCAB_SIZE)):
    piece = sp.id_to_piece(i)
    print(f"{i:3d}: '{piece}'")

# Test tokenization with sample sentences
print(f"\n=== TOKENIZATION EXAMPLES ===")
test_sentences = [
    "মই ঘৰত থাকোঁ।",
    "তেওঁ আজি বিদ্যালয়লৈ গ'ল।",
    "বইখন টেবুলৰ ওপৰত আছে।",
    "I am at home.",
    "He went to school today.",
    "The book is on the table."
]

for sentence in test_sentences:
    tokens = sp.encode_as_pieces(sentence.lower().strip())
    token_ids = sp.encode_as_ids(sentence.lower().strip())
    print(f"\nSentence: {sentence}")
    print(f"Tokens: {tokens}")
    print(f"Token IDs: {token_ids}")
    print(f"Reconstructed: {sp.decode_pieces(tokens)}")

# --- Dataset and DataLoader ---
class TranslationDataset(Dataset):
    def __init__(self, df, sp_model, max_len=100):
        self.src_sents = df['asm'].astype(str).tolist()
        self.trg_sents = df['eng'].astype(str).tolist()
        self.sp_model = sp_model
        self.max_len = max_len

    def __len__(self):
        return len(self.src_sents)

    def __getitem__(self, idx):
        src_encoded = self.sp_model.EncodeAsIds(self.src_sents[idx].lower().strip())
        trg_encoded = self.sp_model.EncodeAsIds(self.trg_sents[idx].lower().strip())
        src_tensor = torch.LongTensor([SOS_IDX] + src_encoded[:self.max_len-2] + [EOS_IDX])
        trg_tensor = torch.LongTensor([SOS_IDX] + trg_encoded[:self.max_len-2] + [EOS_IDX])
        return src_tensor, trg_tensor

def collate_fn(batch):
    srcs, trgs = zip(*batch)
    src_padded = pad_sequence(srcs, batch_first=True, padding_value=PAD_IDX)
    trg_padded = pad_sequence(trgs, batch_first=True, padding_value=PAD_IDX)
    return src_padded, trg_padded

BATCH_SIZE = 32  # Increased for stability
train_loader = DataLoader(TranslationDataset(train_df, sp), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(TranslationDataset(val_df, sp), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(TranslationDataset(test_df, sp), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# Move analysis functions here and call them
def analyze_vocabulary(sp_model, train_df, val_df):
    """Analyze vocabulary coverage and statistics"""
    print(f"\n=== VOCABULARY ANALYSIS ===")
    
    # Collect all unique tokens
    all_tokens = set()
    asm_tokens = set()
    eng_tokens = set()
    
    # Analyze training data
    for _, row in train_df.iterrows():
        asm_toks = sp_model.encode_as_pieces(str(row['asm']).lower().strip())
        eng_toks = sp_model.encode_as_pieces(str(row['eng']).lower().strip())
        
        all_tokens.update(asm_toks)
        all_tokens.update(eng_toks)
        asm_tokens.update(asm_toks)
        eng_tokens.update(eng_toks)
    
    print(f"Total unique tokens in training: {len(all_tokens)}")
    print(f"Assamese tokens: {len(asm_tokens)}")
    print(f"English tokens: {len(eng_tokens)}")
    print(f"Vocabulary coverage: {len(all_tokens)/sp_model.GetPieceSize()*100:.1f}%")
    
    # Find most common tokens
    token_counts = Counter()
    
    for _, row in train_df.iterrows():
        tokens = sp_model.encode_as_pieces(str(row['asm']).lower().strip())
        tokens.extend(sp_model.encode_as_pieces(str(row['eng']).lower().strip()))
        token_counts.update(tokens)
    
    print(f"\nTop 20 most frequent tokens:")
    for token, count in token_counts.most_common(20):
        try:
            token_id = sp_model.piece_to_id(token)
            print(f"'{token}' (ID: {token_id}): {count} times")
        except:
            print(f"'{token}': {count} times (ID not found)")

def analyze_language_distribution(sp_model, train_df):
    """Analyze how vocabulary is distributed between languages"""
    print(f"\n=== LANGUAGE DISTRIBUTION IN VOCABULARY ===")
    
    # Count token frequency by language
    asm_token_counts = Counter()
    eng_token_counts = Counter()
    
    # Analyze Assamese sentences
    for text in train_df['asm'].astype(str):
        tokens = sp_model.encode_as_pieces(text.lower().strip())
        asm_token_counts.update(tokens)
    
    # Analyze English sentences  
    for text in train_df['eng'].astype(str):
        tokens = sp_model.encode_as_pieces(text.lower().strip())
        eng_token_counts.update(tokens)
    
    # Find shared vs unique tokens
    asm_tokens = set(asm_token_counts.keys())
    eng_tokens = set(eng_token_counts.keys())
    shared_tokens = asm_tokens & eng_tokens
    asm_only = asm_tokens - eng_tokens
    eng_only = eng_tokens - asm_tokens
    
    print(f"Total vocabulary size: {sp_model.GetPieceSize()}")
    print(f"Tokens used in Assamese: {len(asm_tokens)}")
    print(f"Tokens used in English: {len(eng_tokens)}")
    print(f"Shared tokens: {len(shared_tokens)}")
    print(f"Assamese-only tokens: {len(asm_only)}")
    print(f"English-only tokens: {len(eng_only)}")
    
    print(f"\nTop 10 shared tokens:")
    shared_counts = {token: asm_token_counts[token] + eng_token_counts[token] 
                    for token in shared_tokens}
    for token, count in sorted(shared_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
        print(f"  '{token}': {count} times (Asm: {asm_token_counts[token]}, Eng: {eng_token_counts[token]})")
    
    print(f"\nTop 10 Assamese-only tokens:")
    asm_only_sorted = [(token, asm_token_counts[token]) for token in asm_only]
    asm_only_sorted.sort(key=lambda x: x[1], reverse=True)
    for token, count in asm_only_sorted[:10]:
        print(f"  '{token}': {count} times")
    
    print(f"\nTop 10 English-only tokens:")
    eng_only_sorted = [(token, eng_token_counts[token]) for token in eng_only]
    eng_only_sorted.sort(key=lambda x: x[1], reverse=True)
    for token, count in eng_only_sorted[:10]:
        print(f"  '{token}': {count} times")

# Call analysis functions immediately after defining them
analyze_vocabulary(sp, train_df, val_df)
analyze_language_distribution(sp, train_df)

# --- Model Definition ---
# --- BiLSTM Encoder ---
class BiLSTMEncoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD_IDX)
        self.dropout = nn.Dropout(dropout)
        
        # Single bidirectional LSTM layer for simplicity
        self.lstm = nn.LSTM(emb_dim, hid_dim, num_layers=1, bidirectional=True, 
                           dropout=0, batch_first=True)
        
        # Layer normalization for stability
        self.layer_norm = nn.LayerNorm(hid_dim * 2)
        
        # Project to decoder size
        self.fc_hidden = nn.Linear(hid_dim * 2, hid_dim)
        self.fc_cell = nn.Linear(hid_dim * 2, hid_dim)

    def forward(self, src):
        # Get actual sequence lengths for packing
        src_lengths = (src != PAD_IDX).sum(dim=1).cpu()
        
        embedded = self.dropout(self.embedding(src))
        
        # Pack sequences for efficiency
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, src_lengths, batch_first=True, enforce_sorted=False)
        
        packed_outputs, (hidden, cell) = self.lstm(packed_embedded)
        
        # Unpack sequences
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)
        
        # Apply layer normalization
        outputs = self.layer_norm(outputs)
        
        # Process final states - hidden/cell are [2, batch, hid_dim] for bidirectional
        # Concatenate forward and backward final states
        final_hidden = torch.cat((hidden[0], hidden[1]), dim=1)  # [batch, hid_dim*2]
        final_cell = torch.cat((cell[0], cell[1]), dim=1)        # [batch, hid_dim*2]
        
        # Project to decoder dimensions
        decoder_hidden = torch.tanh(self.fc_hidden(final_hidden)).unsqueeze(0)  # [1, batch, hid_dim]
        decoder_cell = torch.tanh(self.fc_cell(final_cell)).unsqueeze(0)        # [1, batch, hid_dim]
        
        return outputs, (decoder_hidden, decoder_cell)

class LSTMAttention(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.hid_dim = hid_dim
        
        # Attention mechanism
        self.W_h = nn.Linear(hid_dim, hid_dim, bias=False)  # For hidden state
        self.W_s = nn.Linear(hid_dim * 2, hid_dim, bias=False)  # For encoder outputs
        self.v = nn.Linear(hid_dim, 1, bias=False)
        
        self.dropout = nn.Dropout(0.1)

    def forward(self, hidden_cell, encoder_outputs):
        # Extract hidden state
        if isinstance(hidden_cell, tuple):
            hidden = hidden_cell[0]  # [1, batch, hid_dim]
        else:
            hidden = hidden_cell
        
        # Remove layer dimension if present
        if hidden.dim() == 3:
            hidden = hidden.squeeze(0)  # [batch, hid_dim]
        
        batch_size, src_len, enc_hid_dim = encoder_outputs.size()
        
        # Prepare hidden state for attention computation
        hidden_transformed = self.W_h(hidden)  # [batch, hid_dim]
        hidden_expanded = hidden_transformed.unsqueeze(1).expand(-1, src_len, -1)  # [batch, src_len, hid_dim]
        
        # Transform encoder outputs
        encoder_transformed = self.W_s(encoder_outputs)  # [batch, src_len, hid_dim]
        
        # Compute attention scores
        energy = torch.tanh(hidden_expanded + encoder_transformed)  # [batch, src_len, hid_dim]
        attention_scores = self.v(energy).squeeze(2)  # [batch, src_len]
        
        # Create mask for padding tokens
        mask = (encoder_outputs.sum(dim=2) != 0)  # [batch, src_len]
        attention_scores = attention_scores.masked_fill(~mask, -1e10)
        
        # Apply softmax
        attention_weights = torch.softmax(attention_scores, dim=1)  # [batch, src_len]
        attention_weights = self.dropout(attention_weights)
        
        return attention_weights

class LSTMDecoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, dropout, attention):
        super().__init__()
        self.attention = attention
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD_IDX)
        self.dropout = nn.Dropout(dropout)
        
        # Single layer LSTM
        self.lstm = nn.LSTM(emb_dim + hid_dim * 2, hid_dim, num_layers=1, batch_first=True)
        
        # Output projection with residual connection
        self.fc_out = nn.Linear(hid_dim + hid_dim * 2 + emb_dim, vocab_size)
        
        # Layer normalization
        self.layer_norm = nn.LayerNorm(hid_dim)

    def forward(self, input, hidden_cell, encoder_outputs):
        input = input.unsqueeze(1)  # [batch, 1]
        embedded = self.dropout(self.embedding(input))  # [batch, 1, emb_dim]
        
        # Compute attention
        attention_weights = self.attention(hidden_cell, encoder_outputs)  # [batch, src_len]
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)  # [batch, 1, hid_dim*2]
        
        # Combine embedding and context
        rnn_input = torch.cat((embedded, context), dim=2)  # [batch, 1, emb_dim + hid_dim*2]
        
        # LSTM forward pass
        output, new_hidden_cell = self.lstm(rnn_input, hidden_cell)  # [batch, 1, hid_dim]
        
        # Apply layer normalization
        output = self.layer_norm(output)
        
        # Prepare for output projection
        embedded_flat = embedded.squeeze(1)  # [batch, emb_dim]
        output_flat = output.squeeze(1)      # [batch, hid_dim]
        context_flat = context.squeeze(1)    # [batch, hid_dim*2]
        
        # Final output projection
        prediction = self.fc_out(torch.cat((output_flat, context_flat, embedded_flat), dim=1))
        
        return prediction, new_hidden_cell

class BiLSTMSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.fc_out.out_features
        
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)
        encoder_outputs, hidden_cell = self.encoder(src)
        
        input_token = trg[:, 0]
        
        for t in range(1, trg_len):
            output, hidden_cell = self.decoder(input_token, hidden_cell, encoder_outputs)
            outputs[:, t] = output
            
            # Teacher forcing with scheduled sampling
            if random.random() < teacher_forcing_ratio:
                input_token = trg[:, t]
            else:
                input_token = output.argmax(1)
            
        return outputs

# --- Training and Evaluation Functions ---
def train_epoch(model, dataloader, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    total_norm = 0
    
    for src, trg in dataloader:
        src, trg = src.to(device), trg.to(device)
        optimizer.zero_grad()
        output = model(src, trg)
        output_dim = output.shape[-1]
        output = output[:, 1:].reshape(-1, output_dim)
        trg = trg[:, 1:].reshape(-1)
        loss = criterion(output, trg)
        loss.backward()
        
        # Monitor gradient norms
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        total_norm += grad_norm.item()
        
        optimizer.step()
        epoch_loss += loss.item()
    
    avg_grad_norm = total_norm / len(dataloader)
    if avg_grad_norm > 5.0:  # Warning threshold
        print(f"Warning: High gradient norm: {avg_grad_norm:.2f}")
    
    return epoch_loss / len(dataloader)

def evaluate(model, dataloader, criterion):
    model.eval()
    epoch_loss = 0
    all_refs, all_hyps = [], []
    smooth_fn = SmoothingFunction().method1
    with torch.no_grad():
        for src, trg in dataloader:
            src, trg = src.to(device), trg.to(device)
            output = model(src, trg, 0)
            output_dim = output.shape[-1]
            loss_output = output[:, 1:].reshape(-1, output_dim)
            loss_trg = trg[:, 1:].reshape(-1)
            loss = criterion(loss_output, loss_trg)
            epoch_loss += loss.item()
            hyp_tokens = output.argmax(2)
            for i in range(hyp_tokens.shape[0]):
                hyp_ids = hyp_tokens[i, 1:].tolist()
                ref_ids = trg[i, 1:].tolist()
                if EOS_IDX in hyp_ids:
                    hyp_ids = hyp_ids[:hyp_ids.index(EOS_IDX)]
                if EOS_IDX in ref_ids:
                    ref_ids = ref_ids[:ref_ids.index(EOS_IDX)]
                all_hyps.append(sp.decode_ids(hyp_ids).split())
                all_refs.append([sp.decode_ids(ref_ids).split()])
    bleu = corpus_bleu(all_refs, all_hyps, smoothing_function=smooth_fn)
    return epoch_loss / len(dataloader), bleu * 100

def train_epoch_with_tf(model, dataloader, optimizer, criterion, clip, tf_ratio):
    """Training function with dynamic teacher forcing ratio"""
    model.train()
    epoch_loss = 0
    total_norm = 0
    
    for src, trg in dataloader:
        src, trg = src.to(device), trg.to(device)
        optimizer.zero_grad()
        output = model(src, trg, teacher_forcing_ratio=tf_ratio)  # Use dynamic TF ratio
        output_dim = output.shape[-1]
        output = output[:, 1:].reshape(-1, output_dim)
        trg = trg[:, 1:].reshape(-1)
        loss = criterion(output, trg)
        loss.backward()
        
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        total_norm += grad_norm.item()
        
        optimizer.step()
        epoch_loss += loss.item()
    
    avg_grad_norm = total_norm / len(dataloader)
    if avg_grad_norm > 5.0:
        print(f"Warning: High gradient norm: {avg_grad_norm:.2f}")
    
    return epoch_loss / len(dataloader)

# --- Improved BiLSTM Model with Better Architecture ---
class ImprovedBiLSTMEncoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD_IDX)
        self.dropout = nn.Dropout(dropout)
        
        # Single bidirectional LSTM layer for simplicity
        self.lstm = nn.LSTM(emb_dim, hid_dim, num_layers=1, bidirectional=True, 
                           dropout=0, batch_first=True)
        
        # Layer normalization for stability
        self.layer_norm = nn.LayerNorm(hid_dim * 2)
        
        # Project to decoder size
        self.fc_hidden = nn.Linear(hid_dim * 2, hid_dim)
        self.fc_cell = nn.Linear(hid_dim * 2, hid_dim)

    def forward(self, src):
        # Get actual sequence lengths for packing
        src_lengths = (src != PAD_IDX).sum(dim=1).cpu()
        
        embedded = self.dropout(self.embedding(src))
        
        # Pack sequences for efficiency
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, src_lengths, batch_first=True, enforce_sorted=False)
        
        packed_outputs, (hidden, cell) = self.lstm(packed_embedded)
        
        # Unpack sequences
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)
        
        # Apply layer normalization
        outputs = self.layer_norm(outputs)
        
        # Process final states - hidden/cell are [2, batch, hid_dim] for bidirectional
        # Concatenate forward and backward final states
        final_hidden = torch.cat((hidden[0], hidden[1]), dim=1)  # [batch, hid_dim*2]
        final_cell = torch.cat((cell[0], cell[1]), dim=1)        # [batch, hid_dim*2]
        
        # Project to decoder dimensions
        decoder_hidden = torch.tanh(self.fc_hidden(final_hidden)).unsqueeze(0)  # [1, batch, hid_dim]
        decoder_cell = torch.tanh(self.fc_cell(final_cell)).unsqueeze(0)        # [1, batch, hid_dim]
        
        return outputs, (decoder_hidden, decoder_cell)

class ImprovedLSTMAttention(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.hid_dim = hid_dim
        
        # Attention mechanism
        self.W_h = nn.Linear(hid_dim, hid_dim, bias=False)  # For hidden state
        self.W_s = nn.Linear(hid_dim * 2, hid_dim, bias=False)  # For encoder outputs
        self.v = nn.Linear(hid_dim, 1, bias=False)
        
        self.dropout = nn.Dropout(0.1)

    def forward(self, hidden_cell, encoder_outputs):
        # Extract hidden state
        if isinstance(hidden_cell, tuple):
            hidden = hidden_cell[0]  # [1, batch, hid_dim]
        else:
            hidden = hidden_cell
        
        # Remove layer dimension if present
        if hidden.dim() == 3:
            hidden = hidden.squeeze(0)  # [batch, hid_dim]
        
        batch_size, src_len, enc_hid_dim = encoder_outputs.size()
        
        # Prepare hidden state for attention computation
        hidden_transformed = self.W_h(hidden)  # [batch, hid_dim]
        hidden_expanded = hidden_transformed.unsqueeze(1).expand(-1, src_len, -1)  # [batch, src_len, hid_dim]
        
        # Transform encoder outputs
        encoder_transformed = self.W_s(encoder_outputs)  # [batch, src_len, hid_dim]
        
        # Compute attention scores
        energy = torch.tanh(hidden_expanded + encoder_transformed)  # [batch, src_len, hid_dim]
        attention_scores = self.v(energy).squeeze(2)  # [batch, src_len]
        
        # Create mask for padding tokens
        mask = (encoder_outputs.sum(dim=2) != 0)  # [batch, src_len]
        attention_scores = attention_scores.masked_fill(~mask, -1e10)
        
        # Apply softmax
        attention_weights = torch.softmax(attention_scores, dim=1)  # [batch, src_len]
        attention_weights = self.dropout(attention_weights)
        
        return attention_weights

class ImprovedLSTMDecoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, dropout, attention):
        super().__init__()
        self.attention = attention
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD_IDX)
        self.dropout = nn.Dropout(dropout)
        
        # Single layer LSTM
        self.lstm = nn.LSTM(emb_dim + hid_dim * 2, hid_dim, num_layers=1, batch_first=True)
        
        # Output projection with residual connection
        self.fc_out = nn.Linear(hid_dim + hid_dim * 2 + emb_dim, vocab_size)
        
        # Layer normalization
        self.layer_norm = nn.LayerNorm(hid_dim)

    def forward(self, input, hidden_cell, encoder_outputs):
        input = input.unsqueeze(1)  # [batch, 1]
        embedded = self.dropout(self.embedding(input))  # [batch, 1, emb_dim]
        
        # Compute attention
        attention_weights = self.attention(hidden_cell, encoder_outputs)  # [batch, src_len]
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)  # [batch, 1, hid_dim*2]
        
        # Combine embedding and context
        rnn_input = torch.cat((embedded, context), dim=2)  # [batch, 1, emb_dim + hid_dim*2]
        
        # LSTM forward pass
        output, new_hidden_cell = self.lstm(rnn_input, hidden_cell)  # [batch, 1, hid_dim]
        
        # Apply layer normalization
        output = self.layer_norm(output)
        
        # Prepare for output projection
        embedded_flat = embedded.squeeze(1)  # [batch, emb_dim]
        output_flat = output.squeeze(1)      # [batch, hid_dim]
        context_flat = context.squeeze(1)    # [batch, hid_dim*2]
        
        # Final output projection
        prediction = self.fc_out(torch.cat((output_flat, context_flat, embedded_flat), dim=1))
        
        return prediction, new_hidden_cell

class ImprovedBiLSTMSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.fc_out.out_features
        
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)
        encoder_outputs, hidden_cell = self.encoder(src)
        
        input_token = trg[:, 0]
        
        for t in range(1, trg_len):
            output, hidden_cell = self.decoder(input_token, hidden_cell, encoder_outputs)
            outputs[:, t] = output
            
            # Teacher forcing with scheduled sampling
            if random.random() < teacher_forcing_ratio:
                input_token = trg[:, t]
            else:
                input_token = output.argmax(1)
            
        return outputs

# --- Hyperparameters and Model Instantiation ---
EMB_DIM = 256
HID_DIM = 256  # Increased hidden size
ENC_DROPOUT = 0.3
DEC_DROPOUT = 0.3
CLIP = 1.0
NUM_EPOCHS = 60
PATIENCE = 20

print("\n=== Creating Improved BiLSTM with Attention Model ===")
attn = ImprovedLSTMAttention(HID_DIM)
enc = ImprovedBiLSTMEncoder(VOCAB_SIZE, EMB_DIM, HID_DIM, ENC_DROPOUT)
dec = ImprovedLSTMDecoder(VOCAB_SIZE, EMB_DIM, HID_DIM, DEC_DROPOUT, attn)
model = ImprovedBiLSTMSeq2Seq(enc, dec, device).to(device)

# Better initialization
def improved_init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        nn.init.uniform_(m.weight, -0.1, 0.1)
        # Zero out padding token embedding
        with torch.no_grad():
            m.weight[PAD_IDX].fill_(0)
    elif isinstance(m, nn.LSTM):
        for name, param in m.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param.data)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param.data)
            elif 'bias' in name:
                nn.init.zeros_(param.data)
                # Set forget gate bias to 1
                n = param.size(0)
                param.data[n//4:n//2].fill_(1.)

model.apply(improved_init_weights)

# Better optimizer with lower learning rate
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-5, 
                       betas=(0.9, 0.98), eps=1e-9)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX, label_smoothing=0.1)

# Warmup + Cosine scheduler
from torch.optim.lr_scheduler import LambdaLR
import math

def get_lr_scheduler(optimizer, num_epochs, warmup_epochs=5):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return float(epoch) / float(max(1, warmup_epochs))
        else:
            progress = (epoch - warmup_epochs) / (num_epochs - warmup_epochs)
            return 0.5 * (1 + math.cos(math.pi * progress))
    return LambdaLR(optimizer, lr_lambda)

scheduler = get_lr_scheduler(optimizer, NUM_EPOCHS)

# --- Training Loop ---
best_bleu = -1
best_val_loss = float('inf')
epochs_no_improve = 0
epochs_no_loss_improve = 0

print("Starting improved BiLSTM training...")
print(f"Improved model has {count_parameters(model):,} trainable parameters")

for epoch in range(1, NUM_EPOCHS + 1):
    # More gradual teacher forcing schedule
    if epoch <= 15:
        tf_ratio = 0.9
    elif epoch <= 30:
        tf_ratio = 0.7
    elif epoch <= 45:
        tf_ratio = 0.5
    else:
        tf_ratio = 0.3
    
    train_loss = train_epoch_with_tf(model, train_loader, optimizer, criterion, CLIP, tf_ratio)
    valid_loss, valid_bleu = evaluate(model, val_loader, criterion)
    
    scheduler.step()
    
    # Track best model based on BLEU
    if valid_bleu > best_bleu:
        best_bleu = valid_bleu
        epochs_no_improve = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_bleu': best_bleu,
            'valid_loss': valid_loss
        }, 'improved-bilstm-model.pt')
        print(f"*** New best BLEU: {valid_bleu:.2f} - Model saved! ***")
    else:
        epochs_no_improve += 1
    
    # Track validation loss for learning rate adjustment
    if valid_loss < best_val_loss:
        best_val_loss = valid_loss
        epochs_no_loss_improve = 0
    else:
        epochs_no_loss_improve += 1
    
    # Reduce LR if validation loss plateaus
    if epochs_no_loss_improve >= 8:
        for param_group in optimizer.param_groups:
            old_lr = param_group['lr']
            param_group['lr'] *= 0.8
            print(f"Reduced LR from {old_lr:.6f} to {param_group['lr']:.6f}")
        epochs_no_loss_improve = 0
    
    current_lr = optimizer.param_groups[0]['lr']
    print(f'Epoch: {epoch:02} | Train Loss: {train_loss:.3f} | Val Loss: {valid_loss:.3f} | Val BLEU: {valid_bleu:.2f} | TF: {tf_ratio:.2f} | LR: {current_lr:.6f}')
    
    # Early stopping
    if epochs_no_improve >= PATIENCE:
        print(f'BLEU not improving for {PATIENCE} epochs - early stopping!')
        break
    
    # Stop if learning rate becomes too small
    if current_lr < 1e-7:
        print("Learning rate too small - stopping!")
        break

print(f"Training completed! Best BLEU score: {best_bleu:.2f}")

# Improved translation with better decoding
def improved_translate_sentence(model, sentence, max_len=40):
    model.eval()
    tokens = [SOS_IDX] + sp.encode_as_ids(sentence.lower().strip()) + [EOS_IDX]
    src_tensor = torch.LongTensor(tokens).unsqueeze(0).to(device)
    
    with torch.no_grad():
        encoder_outputs, hidden_cell = model.encoder(src_tensor)
    
    trg_indexes = [SOS_IDX]
    
    for i in range(max_len):
        trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)
        
        with torch.no_grad():
            output, hidden_cell = model.decoder(trg_tensor, hidden_cell, encoder_outputs)
        
        # Apply repetition penalty
        if len(trg_indexes) > 1:
            for prev_token in set(trg_indexes[1:]):  # Exclude SOS
                if prev_token < output.size(1):
                    count = trg_indexes.count(prev_token)
                    penalty = min(count * 1.2, 3.0)  # Cap penalty
                    output[0, prev_token] -= penalty
        
        # Use nucleus sampling (top-p)
        sorted_logits, sorted_indices = torch.sort(output, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        
        # Remove tokens with cumulative probability above the threshold (nucleus)
        sorted_indices_to_remove = cumulative_probs > 0.9
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        
        # Set logits to -inf for removed tokens
        sorted_logits[sorted_indices_to_remove] = float('-inf')
        
        # Sample from the filtered distribution
        probs = torch.softmax(sorted_logits / 0.8, dim=-1)  # Temperature scaling
        pred_token = torch.multinomial(probs, 1)[0, 0].item()
        pred_token = sorted_indices[0, pred_token].item()
        
        trg_indexes.append(pred_token)
        if pred_token == EOS_IDX:
            break
    
    return sp.decode(trg_indexes)

# Load best model
try:
    checkpoint = torch.load('improved-bilstm-model.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Improved BiLSTM model loaded! Best BLEU: {checkpoint['best_bleu']:.2f}")
except FileNotFoundError:
    print("Using current trained model state...")

sample_sentences = [
    "মই ঘৰত থাকোঁ।",
    "তেওঁ আজি বিদ্যালয়লৈ গ'ল।",
    "বইখন টেবুলৰ ওপৰত আছে।"
]

print("\n=== Improved BiLSTM Translation Examples ===")
for sentence in sample_sentences:
    greedy_translation = improved_translate_sentence(model, sentence)
    beam_translation = bilstm_beam_search(model, sentence, beam_width=3, max_len=40)
    print(f"SOURCE:     {sentence}")
    print(f"IMPROVED:   {greedy_translation}")
    print(f"BEAM (k=3): {beam_translation}")
    print("-" * 50)

print("\n=== Final Evaluation ===")
test_loss, test_bleu = evaluate(model, test_loader, criterion)
print(f'Test Loss: {test_loss:.3f} | Test BLEU: {test_bleu:.2f}')

# Additional evaluation with more samples
print("\n=== Additional Improved BiLSTM Test Sentences ===")
additional_test_sentences = [
    "আমি খাওয়া-দাওয়া কৰি আছো।",
    "তেওঁ কিতাপ পঢ়ি আছে।", 
    "আজি বৰষুণ হৈছে।",
    "স্কুলত পাঠদান চলি আছে।",
    "আমাৰ ঘৰখন ডাঙৰ।"
]

for sentence in additional_test_sentences:
    try:
        greedy_translation = improved_translate_sentence(model, sentence)
        beam_translation = bilstm_beam_search(model, sentence, beam_width=3, max_len=40)
        print(f"SOURCE:     {sentence}")
        print(f"IMPROVED:   {greedy_translation}")
        print(f"BEAM (k=3): {beam_translation}")
        print("-" * 30)
    except Exception as e:
        print(f"Error translating '{sentence}': {str(e)}")
        continue

Using device: cuda

=== VOCABULARY INFORMATION ===
Vocabulary size: 8000
PAD token: 0 -> '<unk>'
SOS token: 1 -> '<s>'
EOS token: 2 -> '</s>'
UNK token: 0 -> '<unk>'

Sample vocabulary (first 50 tokens):
  0: '<unk>'
  1: '<s>'
  2: '</s>'
  3: '▁t'
  4: 'he'
  5: '▁a'
  6: 'in'
  7: '▁the'
  8: '▁ক'
  9: 'য়'
 10: 'াৰ'
 11: '▁ব'
 12: '▁প'
 13: '▁o'
 14: '▁s'
 15: '▁স'
 16: 're'
 17: '্ৰ'
 18: '▁b'
 19: 'er'
 20: 'ha'
 21: '▁c'
 22: 'on'
 23: 'en'
 24: '▁আ'
 25: '▁p'
 26: '▁w'
 27: 'is'
 28: '▁in'
 29: '্য'
 30: 'ed'
 31: 'ৰা'
 32: '▁m'
 33: '▁of'
 34: 'ar'
 35: 'an'
 36: '্ত'
 37: 'at'
 38: 'ৰি'
 39: '▁ম'
 40: '▁f'
 41: '▁d'
 42: 'it'
 43: '▁অ'
 44: 'or'
 45: '▁ন'
 46: '▁হ'
 47: 'ান'
 48: '▁দ'
 49: 'al'

High-frequency tokens (500-600):
500: '▁গ্ৰ'
501: 'ুক'
502: '▁day'
503: '▁tak'
504: 'ord'
505: 'hile'
506: 'ight'
507: '▁বিজ'
508: 'ঘট'
509: '▁বৃ'
510: 'হণ'
511: 'ৎস'
512: 'og'
513: '▁ও'
514: '▁থকা'
515: '▁মন্ত'
516: 'ire'
517: '▁part'
518: '▁।'
519: '▁comm'
520: '▁ভাৰত'
521: 'হা'
522

In [6]:
# Add this function after the analyze_vocabulary function:
def analyze_language_distribution(sp_model, train_df):
    """Analyze how vocabulary is distributed between languages"""
    print(f"\n=== LANGUAGE DISTRIBUTION IN VOCABULARY ===")
    
    from collections import Counter, defaultdict
    
    # Count token frequency by language
    asm_token_counts = Counter()
    eng_token_counts = Counter()
    
    # Analyze Assamese sentences
    for text in train_df['asm'].astype(str):
        tokens = sp_model.encode_as_pieces(text.lower().strip())
        asm_token_counts.update(tokens)
    
    # Analyze English sentences  
    for text in train_df['eng'].astype(str):
        tokens = sp_model.encode_as_pieces(text.lower().strip())
        eng_token_counts.update(tokens)
    
    # Find shared vs unique tokens
    asm_tokens = set(asm_token_counts.keys())
    eng_tokens = set(eng_token_counts.keys())
    shared_tokens = asm_tokens & eng_tokens
    asm_only = asm_tokens - eng_tokens
    eng_only = eng_tokens - asm_tokens
    
    print(f"Total vocabulary size: {VOCAB_SIZE}")
    print(f"Tokens used in Assamese: {len(asm_tokens)}")
    print(f"Tokens used in English: {len(eng_tokens)}")
    print(f"Shared tokens: {len(shared_tokens)}")
    print(f"Assamese-only tokens: {len(asm_only)}")
    print(f"English-only tokens: {len(eng_only)}")
    
    print(f"\nTop 10 shared tokens:")
    shared_counts = {token: asm_token_counts[token] + eng_token_counts[token] 
                    for token in shared_tokens}
    for token, count in sorted(shared_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
        print(f"  '{token}': {count} times (Asm: {asm_token_counts[token]}, Eng: {eng_token_counts[token]})")
    
    print(f"\nTop 10 Assamese-only tokens:")
    for token, count in asm_token_counts.most_common(10):
        if token in asm_only:
            print(f"  '{token}': {count} times")
    
    print(f"\nTop 10 English-only tokens:")
    for token, count in eng_token_counts.most_common(10):
        if token in eng_only:
            print(f"  '{token}': {count} times")

# Call this function after the existing analyze_vocabulary call:
analyze_vocabulary(sp, train_df, val_df)
analyze_language_distribution(sp, train_df)  # Add this line


=== VOCABULARY ANALYSIS ===
Total unique tokens in training: 7793
Assamese tokens: 4548
English tokens: 7474
Vocabulary coverage: 97.4%

Top 20 most frequent tokens:
'▁the' (ID: 7): 10682 times
',' (ID: 7857): 8371 times
'.' (ID: 7861): 6697 times
'▁of' (ID: 33): 4551 times
'।' (ID: 7872): 4468 times
'▁in' (ID: 28): 3608 times
'▁to' (ID: 53): 2788 times
'▁a' (ID: 5): 2433 times
'-' (ID: 7884): 2026 times
's' (ID: 7826): 1958 times
'▁and' (ID: 94): 1840 times
'ৰ' (ID: 7822): 1757 times
'▁on' (ID: 99): 1392 times
'▁is' (ID: 123): 1327 times
'’' (ID: 7894): 1257 times
'ত' (ID: 7832): 1172 times
''' (ID: 7896): 1130 times
'▁আৰু' (ID: 150): 1101 times
'▁এই' (ID: 139): 1076 times
'ক' (ID: 7831): 1068 times

=== LANGUAGE DISTRIBUTION IN VOCABULARY ===
Total vocabulary size: 8000
Tokens used in Assamese: 4548
Tokens used in English: 7474
Shared tokens: 4229
Assamese-only tokens: 319
English-only tokens: 3245

Top 10 shared tokens:
  '▁the': 10682 times (Asm: 6, Eng: 10676)
  ',': 8371 times (