# In [None]:
#  1. IMPORTS AND SETUP 
import csv
import random
import time

import numpy as np
import pandas as pd
import sentencepiece as spm
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

# Choose the compute device once; every tensor below is moved to `device`.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Seed every RNG the pipeline draws from (Python, NumPy, PyTorch) for repeatable runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(SEED)
    # Give up cuDNN autotuning in exchange for deterministic kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


#  2. DATA LOADING & TOKENIZER TUNING 
print("\n--- Loading Data and Training Tokenizer ---")

# The corpus is raw tab-separated text, not quoted CSV. Disable quote handling
# so a sentence that happens to start with `"` is not parsed as a quoted field
# (which would swallow tabs/newlines and silently merge or mangle rows).
df = pd.read_csv('/kaggle/input/filtered/Filtered_data.tsv', sep='\t',
                 on_bad_lines='skip', quoting=csv.QUOTE_NONE,
                 names=["Assamese sentence", "English sentence"])
df.dropna(inplace=True)

# Roughly 72% train / 8% validation / 20% test, with a fixed seed for both splits.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=SEED)
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=SEED)


print(f"Data split: {len(train_df)} training, {len(val_df)} validation, {len(test_df)} test pairs.")


# Train one joint (Assamese + English) BPE model on the training split only, so
# validation/test text never influences the vocabulary. Text is stripped and
# lower-cased, the same normalisation TranslationDataset applies before encoding.
with open('all_text_for_bpe.txt', 'w', encoding='utf-8') as f:
    for text in pd.concat([train_df['Assamese sentence'], train_df['English sentence']]):
        f.write(str(text).strip().lower() + '\n')

# Reserve ids 0-3 for <pad>, <s>, </s>, <unk>; the ids are read back below.
spm.SentencePieceTrainer.Train(
    '--input=all_text_for_bpe.txt --model_prefix=spm_bpe --vocab_size=8000 '
    '--character_coverage=1.0 --model_type=bpe --pad_id=0 --pad_piece=<pad> '
    '--bos_id=1 --bos_piece=<s> --eos_id=2 --eos_piece=</s> --unk_id=3 --unk_piece=<unk>'
)

# Load the trained tokenizer
sp = spm.SentencePieceProcessor()
sp.Load('spm_bpe.model')

# Define special token indices
PAD_IDX, SOS_IDX, EOS_IDX = sp.pad_id(), sp.bos_id(), sp.eos_id()
VOCAB_SIZE = sp.GetPieceSize()
print(f"Joint Vocabulary Size: {VOCAB_SIZE}")


print("\n--- Sample Vocabulary Tokens ---")
# Ids 0-3 are the special tokens configured above; show the first learned pieces.
sample_tokens = [sp.IdToPiece(i) for i in range(4, 25)]
print(f"Sample tokens: {sample_tokens}")



# 3. DATASET AND DATALOADERS 
class TranslationDataset(Dataset):
    """Map-style dataset of (source, target) token-id tensors built from one DataFrame split.

    Sentences are read from the 'Assamese sentence' (source) and 'English sentence'
    (target) columns. They are lower-cased, stripped and SentencePiece-encoded lazily
    in ``__getitem__``. Each side is framed as [SOS_IDX] + ids + [EOS_IDX], with the
    inner ids truncated to ``max_len - 2`` tokens (so at most ``max_len`` ids in total
    for ``max_len >= 2``).
    """

    def __init__(self, df, sp_model, max_len=100):
        self.src_sents = df['Assamese sentence'].astype(str).tolist()
        self.trg_sents = df['English sentence'].astype(str).tolist()
        self.sp_model = sp_model
        self.max_len = max_len

    def __len__(self):
        return len(self.src_sents)

    def _frame(self, text):
        """Encode one sentence and wrap the truncated ids with SOS/EOS as a LongTensor."""
        ids = self.sp_model.EncodeAsIds(text.lower().strip())
        return torch.LongTensor([SOS_IDX, *ids[:self.max_len - 2], EOS_IDX])

    def __getitem__(self, idx):
        return self._frame(self.src_sents[idx]), self._frame(self.trg_sents[idx])

def collate_fn(batch):
    """Right-pad a list of (src, trg) 1-D tensors into two batch-first tensors.

    Shorter sequences are filled with PAD_IDX up to the longest sequence of the
    same side in the batch. Returns ``(src_padded, trg_padded)``.
    """
    sources, targets = zip(*batch)

    def _pad(seqs):
        return pad_sequence(seqs, batch_first=True, padding_value=PAD_IDX)

    return _pad(sources), _pad(targets)

BATCH_SIZE = 64
# One loader per split. Only the training split is shuffled; all share the same padding collate.
train_loader, val_loader, test_loader = (
    DataLoader(TranslationDataset(split_df, sp), batch_size=BATCH_SIZE,
               shuffle=is_train, collate_fn=collate_fn)
    for split_df, is_train in ((train_df, True), (val_df, False), (test_df, False))
)
print(f"\nDataLoaders created with batch size {BATCH_SIZE}.")

#  4. MODEL ARCHITECTURE 
class Encoder(nn.Module):
    """Bidirectional multi-layer GRU encoder over batch-first source token ids.

    Args:
        vocab_size: size of the (joint) input vocabulary.
        emb_dim: token-embedding width.
        hid_dim: GRU hidden width per direction.
        n_layers: number of stacked GRU layers.
        dropout: dropout on embeddings (and between GRU layers when n_layers > 1).

    ``forward(src)`` returns ``(outputs, decoder_hidden)``:
        outputs: (batch, src_len, 2 * hid_dim), forward/backward states concatenated.
            Positions past a row's real length are zero-filled.
        decoder_hidden: (n_layers, batch, hid_dim), each layer's final forward and
            backward states concatenated and projected via tanh(Linear). This tensor
            seeds the decoder GRU, which therefore needs the same number of layers.
    """

    def __init__(self, vocab_size, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD_IDX)
        self.rnn = nn.GRU(emb_dim, hid_dim, num_layers=n_layers,
                          bidirectional=True, dropout=dropout if n_layers > 1 else 0, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc_hidden = nn.Linear(hid_dim * 2, hid_dim)

    def forward(self, src):
        embedded = self.dropout(self.embedding(src))

        # Pack by real length so the GRU never steps over trailing padding.
        # Without this, the backward direction's final state (which seeds the
        # decoder) would start from PAD embeddings. pad_sequence pads only at
        # the end, so counting non-pad ids gives each row's length; clamp guards
        # an all-pad row, which pack_padded_sequence rejects.
        lengths = (src != self.embedding.padding_idx).sum(dim=1).clamp(min=1).cpu()
        packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
        packed_outputs, hidden = self.rnn(packed)
        # total_length keeps outputs aligned with src's padded length.
        outputs, _ = pad_packed_sequence(packed_outputs, batch_first=True,
                                         total_length=src.shape[1])

        # hidden is (n_layers * 2, batch, hid); PyTorch orders it layer-major, so
        # view as (n_layers, direction, batch, hid) and join the two directions.
        num_layers, batch, hid = self.rnn.num_layers, hidden.shape[1], hidden.shape[2]
        hidden = hidden.view(num_layers, 2, batch, hid)
        hidden_cat = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)
        decoder_hidden = torch.tanh(self.fc_hidden(hidden_cat))
        return outputs, decoder_hidden

class Attention(nn.Module):
    """Additive attention: score = v^T tanh(W [h_top; enc_out]), softmaxed over source positions.

    Args:
        hid_dim: decoder hidden width. Encoder outputs are expected to be
            ``2 * hid_dim`` wide, which the input width of ``self.attn`` implies.

    ``forward(hidden, encoder_outputs, mask=None)``:
        hidden: decoder GRU hidden, (n_layers, batch, hid_dim). Only the top layer is used.
        encoder_outputs: (batch, src_len, 2 * hid_dim).
        mask: optional (batch, src_len) tensor, truthy where the position is real.
            Falsy positions get ~0 weight. Omitting it attends to every position,
            padding included, as before.
        Returns attention weights of shape (batch, src_len) that sum to 1 per row.
    """

    def __init__(self, hid_dim):
        super().__init__()
        self.attn = nn.Linear((hid_dim * 2) + hid_dim, hid_dim)
        self.v = nn.Linear(hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask=None):
        src_len = encoder_outputs.shape[1]
        # expand() is a broadcast view; torch.cat materialises it, so the
        # separate copy that repeat() made is unnecessary.
        top_hidden = hidden[-1].unsqueeze(1).expand(-1, src_len, -1)
        energy = torch.tanh(self.attn(torch.cat((top_hidden, encoder_outputs), dim=2)))
        scores = self.v(energy).squeeze(2)
        if mask is not None:
            # Use the dtype's most negative finite value rather than -inf, so a
            # fully-masked row degrades to uniform weights instead of NaN.
            scores = scores.masked_fill(~mask.bool(), torch.finfo(scores.dtype).min)
        return torch.softmax(scores, dim=1)

class Decoder(nn.Module):
    """Single-step attentional GRU decoder.

    Each call consumes one target token per batch row. It attends over the encoder
    outputs with ``attention`` and returns vocabulary logits plus the updated GRU
    hidden state. The output layer sees the GRU output, the attention context and
    the (dropped-out) input embedding concatenated together.

    Args:
        vocab_size: size of the (joint) output vocabulary.
        emb_dim: token-embedding width.
        hid_dim: GRU hidden width. The context vector is expected to be
            ``2 * hid_dim`` wide, matching the GRU input size below.
        n_layers: GRU depth. It must match the layer count of ``hidden`` given to forward.
        dropout: dropout on embeddings (and between GRU layers when n_layers > 1).
        attention: module mapping (hidden, encoder_outputs) -> weights over source positions.
    """

    def __init__(self, vocab_size, emb_dim, hid_dim, n_layers, dropout, attention):
        super().__init__()
        # Sub-module creation order fixes how parameter initialisation consumes
        # the RNG; keep it stable for reproducibility.
        self.attention = attention
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD_IDX)
        self.rnn = nn.GRU(
            emb_dim + hid_dim * 2,
            hid_dim,
            num_layers=n_layers,
            dropout=0 if n_layers <= 1 else dropout,
            batch_first=True,
        )
        self.fc_out = nn.Linear(hid_dim * 2 + hid_dim + emb_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        # One token id per row -> a length-1 embedded sequence for the GRU.
        step_embedding = self.dropout(self.embedding(input.unsqueeze(1)))
        # Weights over source positions -> weighted sum of encoder outputs.
        weights = self.attention(hidden, encoder_outputs)
        context = weights.unsqueeze(1).bmm(encoder_outputs)
        gru_out, new_hidden = self.rnn(torch.cat((step_embedding, context), dim=2), hidden)
        features = torch.cat(
            (gru_out.squeeze(1), context.squeeze(1), step_embedding.squeeze(1)), dim=1
        )
        return self.fc_out(features), new_hidden

class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder over the target length.

    Args:
        encoder: module mapping ``src`` -> ``(encoder_outputs, hidden)``.
        decoder: module mapping ``(token_ids, hidden, encoder_outputs)`` ->
            ``(logits, hidden)``. It must expose ``fc_out.out_features`` (the vocab size).
        device: device on which the output buffer is allocated.

    ``forward(src, trg, teacher_forcing_ratio=0.5)`` returns logits of shape
    (batch, trg_len, vocab). Slot 0 (the SOS position) is left as zeros. Callers
    are expected to skip it when computing the loss.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder, self.decoder, self.device = encoder, decoder, device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size, trg_len = trg.shape
        trg_vocab_size = self.decoder.fc_out.out_features
        # Allocate straight on the target device. zeros(...).to(device) first
        # built a (batch x len x vocab) float buffer on the CPU and copied it
        # over, once per batch.
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size, device=self.device)
        encoder_outputs, hidden = self.encoder(src)
        dec_input = trg[:, 0]
        for t in range(1, trg_len):
            output, hidden = self.decoder(dec_input, hidden, encoder_outputs)
            outputs[:, t] = output
            # One coin flip per step, even when the ratio is 0 or 1, so the
            # Python RNG stream is the same as before.
            teacher_force = random.random() < teacher_forcing_ratio
            dec_input = trg[:, t] if teacher_force else output.argmax(1)
        return outputs


#  5. TRAINING & EVALUATION FUNCTIONS 

def train_epoch(model, dataloader, optimizer, criterion, clip):
    """Run one optimisation pass over ``dataloader`` and return the mean per-batch loss.

    The model is called with its default teacher-forcing ratio. Gradients are
    clipped to a global L2 norm of ``clip`` before each optimizer step. Batches
    are moved to the module-level ``device``.
    """
    model.train()
    running_loss = 0
    for src, trg in dataloader:
        src = src.to(device)
        trg = trg.to(device)
        optimizer.zero_grad()
        logits = model(src, trg)
        vocab = logits.shape[-1]
        # Position 0 is the SOS slot that the model never predicts, so drop it on both sides.
        loss = criterion(logits[:, 1:].reshape(-1, vocab), trg[:, 1:].reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(dataloader)

def evaluate(model, dataloader, criterion):
    """Return ``(mean per-batch loss, corpus BLEU x 100)`` for ``dataloader``.

    The model is run with teacher_forcing_ratio=0, so each step is fed its own
    argmax. Hypotheses and references are cut at the first EOS_IDX, decoded with
    the module-level ``sp`` tokenizer and whitespace-split for NLTK's
    ``corpus_bleu`` (smoothing method1).

    NOTE(review): decoding is unrolled for exactly ``trg.shape[1] - 1`` steps of
    the padded *reference* batch. Hypotheses therefore can never be longer than
    the batch's longest reference, so this BLEU is not a free-running decode
    score.
    """
    model.eval()
    total_loss = 0
    references, hypotheses = [], []
    smoothing = SmoothingFunction().method1

    def _until_eos(ids):
        return ids[:ids.index(EOS_IDX)] if EOS_IDX in ids else ids

    with torch.no_grad():
        for src, trg in dataloader:
            src, trg = src.to(device), trg.to(device)
            logits = model(src, trg, 0)  # No teacher forcing
            vocab = logits.shape[-1]
            total_loss += criterion(
                logits[:, 1:].reshape(-1, vocab), trg[:, 1:].reshape(-1)
            ).item()
            predicted = logits.argmax(2)
            for hyp_row, ref_row in zip(predicted[:, 1:].tolist(), trg[:, 1:].tolist()):
                hypotheses.append(sp.decode_ids(_until_eos(hyp_row)).split())
                references.append([sp.decode_ids(_until_eos(ref_row)).split()])
    bleu = corpus_bleu(references, hypotheses, smoothing_function=smoothing)
    return total_loss / len(dataloader), bleu * 100


#  6. TUNED HYPERPARAMETERS, INSTANTIATION & TRAINING 

print("\n--- Initializing Tuned Model and Training ---")

EMB_DIM, HID_DIM = 256, 512
ENC_LAYERS, DEC_LAYERS = 2, 2
ENC_DROPOUT, DEC_DROPOUT = 0.5, 0.5
CLIP, NUM_EPOCHS, PATIENCE = 1.0, 50, 7

# The encoder's projected hidden state (one slice per encoder layer) is passed
# unchanged as the decoder GRU's initial hidden, so the layer counts must agree.
if ENC_LAYERS != DEC_LAYERS:
    raise ValueError(f"ENC_LAYERS ({ENC_LAYERS}) must equal DEC_LAYERS ({DEC_LAYERS})")

# Instantiate model
attn = Attention(HID_DIM)
enc = Encoder(VOCAB_SIZE, EMB_DIM, HID_DIM, ENC_LAYERS, ENC_DROPOUT)
dec = Decoder(VOCAB_SIZE, EMB_DIM, HID_DIM, DEC_LAYERS, DEC_DROPOUT, attn)
model = Seq2Seq(enc, dec, device).to(device)
print(f"Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters.")

# Optimizer, Loss, and Scheduler with strong regularization
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX, label_smoothing=0.1)
# mode='max': the monitored metric is validation BLEU (higher is better).
# The scheduler's `verbose` flag is deprecated, so LR reductions are logged in the loop instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=2)

# Training loop with early stopping on validation BLEU; the best checkpoint is kept on disk.
best_bleu = -1.0
epochs_no_improve = 0
for epoch in range(1, NUM_EPOCHS + 1):
    start_time = time.time()
    train_loss = train_epoch(model, train_loader, optimizer, criterion, CLIP)
    valid_loss, valid_bleu = evaluate(model, val_loader, criterion)
    end_time = time.time()

    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(valid_bleu)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after < lr_before:
        print(f" Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

    if valid_bleu > best_bleu:
        best_bleu = valid_bleu
        epochs_no_improve = 0
        torch.save(model.state_dict(), 'best-model.pt')
        print(f" New best BLEU score: {valid_bleu:.2f}. Model saved.")
    else:
        epochs_no_improve += 1

    print(f'Epoch: {epoch:02} | Time: {end_time - start_time:.0f}s | Train Loss: {train_loss:.3f} | '
          f'Val. Loss: {valid_loss:.3f} | Val. BLEU: {valid_bleu:.2f} | Patience: {epochs_no_improve}/{PATIENCE}')

    if epochs_no_improve >= PATIENCE:
        print('Early stopping triggered!')
        break
print(f"\nTraining finished. Best validation BLEU: {best_bleu:.2f}")

#  7. INFERENCE AND FINAL TESTING 
print("\n--- Loading Best Model and Testing Translations ---")
# Restore the checkpoint with the best validation BLEU. map_location remaps the
# stored tensors onto the current device, so a checkpoint written on a GPU
# still loads when this cell runs on a CPU-only kernel.
model.load_state_dict(torch.load('best-model.pt', map_location=device))

def translate_sentence(sentence, model, max_len=100):
    """Greedily translate one Assamese sentence with ``model`` and return English text.

    The sentence is lower-cased, stripped and encoded with the module-level ``sp``
    tokenizer, then framed with SOS_IDX/EOS_IDX as in training. The decoder is fed
    its own argmax for up to ``max_len`` steps and stops early at EOS_IDX.

    Args:
        sentence: raw source sentence.
        model: a Seq2Seq-style object exposing ``encoder`` and ``decoder``. It is
            switched to eval mode as a side effect.
        max_len: maximum number of decoding steps (default 100, as before).

    Returns:
        The decoded translation, without SOS/EOS and stripped of outer whitespace.
    """
    model.eval()
    tokens = [SOS_IDX] + sp.encode_as_ids(sentence.lower().strip()) + [EOS_IDX]
    src_tensor = torch.LongTensor(tokens).unsqueeze(0).to(device)
    generated = []
    prev_token = SOS_IDX
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(src_tensor)
        for _ in range(max_len):
            trg_tensor = torch.LongTensor([prev_token]).to(device)
            output, hidden = model.decoder(trg_tensor, hidden, encoder_outputs)
            prev_token = output.argmax(1).item()
            if prev_token == EOS_IDX:
                break
            generated.append(prev_token)
    # Only content ids are decoded; SOS/EOS never reach the tokenizer.
    return sp.decode(generated).strip()

# Test on some examples
# Qualitative spot-check: a few hand-written Assamese sentences.
sample_sentences = [
    "তেওঁ আজি বিদ্যালয়লৈ গ'ল।",
    "বইখন টেবুলৰ ওপৰত আছে।",
    "মই তোমাক ভাল পাওঁ।",
]
for sentence in sample_sentences:
    translation = translate_sentence(sentence, model)
    print(f"Source:      {sentence}\nTranslation: {translation}")
    print("-" * 20)

# Quantitative score on the held-out test split, using the restored checkpoint.
test_loss, test_bleu = evaluate(model, test_loader, criterion)
print(f'\n Final Test BLEU on unseen data: {test_bleu:.2f} ')