# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import nltk
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
import random
import matplotlib.pyplot as plt
import spacy
from torchtext.data import Field, TabularDataset, BucketIterator
from torchtext.data.metrics import bleu_score

# English tokenization is done with spaCy's small English pipeline (see
# tokenize_eng); the NLTK punkt download stays disabled.
# nltk.download('punkt')
spacy_eng = spacy.load(name="en_core_web_sm")

def tokenize_eng(text):
    """Return the surface strings of the spaCy tokens in English *text*.

    Only ``spacy_eng.tokenizer`` is invoked (not the full pipeline), and the
    original casing of every token is kept.
    """
    doc = spacy_eng.tokenizer(text)
    tokens = [token.text for token in doc]
    return tokens

# indic_tokenizer = AutoTokenizer.from_pretrained("ai4bharat/indic-bert")
def tokenize_hi(text):
    """Split Hindi *text* into words on whitespace (a deliberately crude tokenizer).

    ``split()`` without a separator already drops leading/trailing whitespace
    and collapses runs of spaces, tabs and newlines.
    """
    words = text.split()
    return words


# Source side (English): spaCy tokens via tokenize_eng, lower-cased, wrapped in <sos> ... <eos>.
english = Field(tokenize=tokenize_eng, lower=True, init_token="<sos>", eos_token="<eos>")
# Target side (Hindi): whitespace tokens via tokenize_hi, wrapped in <sos> ... <eos>.
# lower=True leaves Devanagari unchanged; it only affects cased scripts such as Latin.
hindi = Field(tokenize=tokenize_hi, lower=True, init_token="<sos>", eos_token="<eos>")


# Each CSV has a header row (skipped) and two columns that are read positionally:
# English source, then Hindi target.
# TabularDataset.splits joins `path` with each split's file name, so the directory
# lives in `path` and only bare file names are given per split.
DATA_DIR = "/team7_hi"
train_data, valid_data = TabularDataset.splits(
    path=DATA_DIR,
    train="team7_hi_train.csv",
    validation="team7_hi_valid.csv",
    format="csv",
    fields=[("source", english), ("target", hindi)],
    skip_header=True,
)
test_data = TabularDataset(
    path=f"{DATA_DIR}/team7_hi_test.csv",
    format="csv",
    fields=[("source", english), ("target", hindi)],
    skip_header=True,
)

# for i in range(50):
#     print(vars(train_data.examples[i]))
# Build both vocabularies from the training split only, so nothing leaks from
# valid/test; tokens seen fewer than min_freq times are left out of the vocab.
english.build_vocab(train_data, min_freq=2)
hindi.build_vocab(train_data, min_freq=2)
# with open("hindi_vocab_full.txt", "w", encoding="utf-8") as f:
#     f.write("Index\tToken\n")
#     for idx, token in enumerate(hindi.vocab.itos):
#         f.write(f"{idx}\t{token}\n")

#     f.write("\nToken\tIndex\n")
#     for token, idx in hindi.vocab.stoi.items():
#         f.write(f"{token}\t{idx}\n")


def load_glove_embeddings(glove_path, word_to_idx, embedding_dim=200):
    """Build an embedding matrix for a vocabulary from a GloVe text file.

    Rows for vocabulary words present in the file receive their GloVe vector;
    every other row (special tokens, out-of-GloVe words) keeps a random
    U(-0.25, 0.25) initialisation.

    Args:
        glove_path: path to a GloVe ``.txt`` file with one ``word v1 ... vD``
            entry per line.
        word_to_idx: dict mapping token -> row index (e.g. ``english.vocab.stoi``);
            indices are assumed to lie in ``[0, len(word_to_idx))``.
        embedding_dim: expected vector width D; must match the file.

    Returns:
        float32 numpy array of shape ``(len(word_to_idx), embedding_dim)``.

    Raises:
        ValueError: if the file's vector for a vocabulary word does not have
            exactly ``embedding_dim`` components.
    """
    # Cast to float32 (torch's default dtype); np.random.uniform yields float64.
    embeddings = np.random.uniform(
        -0.25, 0.25, (len(word_to_idx), embedding_dim)
    ).astype(np.float32)
    found = 0

    with open(glove_path, 'r', encoding='utf8') as f:
        for line_no, line in enumerate(f, start=1):
            values = line.split()
            if not values:
                continue  # blank line
            word = values[0]
            if word not in word_to_idx:
                continue  # only parse vectors the vocabulary actually needs
            vector = np.asarray(values[1:], dtype=np.float32)
            # A wrong width would otherwise silently broadcast (width 1) or
            # fail with an opaque numpy broadcasting error.
            if vector.shape[0] != embedding_dim:
                raise ValueError(
                    f"{glove_path}:{line_no}: vector for {word!r} has "
                    f"{vector.shape[0]} components, expected embedding_dim={embedding_dim}"
                )
            embeddings[word_to_idx[word]] = vector
            found += 1
    print(f"Found embeddings for {found} out of {len(word_to_idx)} words.")
    return embeddings

# print(len(hindi.vocab.stoi))
class Encoder(nn.Module):
    """LSTM encoder on top of a pretrained source-side embedding table.

    Args:
        embedding_matrix: numpy array of shape (vocab_size, embedding_dim); its
            values are copied into the embedding layer's weight.
        hidden_size: size of the LSTM hidden and cell states.
        num_layers: number of stacked LSTM layers.
        p: dropout probability, used on the embeddings and between LSTM layers.
        trainable: when False (default) the embedding weight gets no gradient.
    """

    def __init__(self, embedding_matrix, hidden_size, num_layers, p, trainable=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = nn.Dropout(p)

        vocab_size, emb_dim = embedding_matrix.shape
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        # copy_ casts into the layer's own dtype, so float64 matrices work too.
        self.embedding.weight.data.copy_(torch.from_numpy(embedding_matrix))
        self.embedding.weight.requires_grad = trainable

        self.rnn = nn.LSTM(emb_dim, hidden_size, num_layers, dropout=p)

    def forward(self, x):
        """Encode token indices *x* and return the final ``(hidden, cell)`` states.

        The LSTM is built without ``batch_first``, so *x* is expected
        time-major, i.e. (seq_len, batch).
        """
        embedded = self.dropout(self.embedding(x))
        _, (hidden_state, cell_state) = self.rnn(embedded)
        return hidden_state, cell_state

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_indicbert_embedding_matrix(vocab_stoi, model_name="ai4bharat/indic-bert"):
    """
    Build an embedding matrix for a vocabulary from a pretrained model's input embeddings.

    Each vocabulary token is split with the model's subword tokenizer, and its
    row is the mean of those subwords' input-embedding vectors. A token that
    tokenizes to nothing falls back to the tokenizer's unknown token.

    Args:
        vocab_stoi: dict mapping token string to row index (e.g., hindi.vocab.stoi);
            indices are assumed to lie in [0, len(vocab_stoi)).
        model_name: pretrained Hugging Face model name (IndicBERT by default).

    Returns:
        embedding_matrix: float32 numpy array of shape (len(vocab_stoi), embedding_dim)
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    # Only the input-embedding table is needed. Reading it on the CPU avoids
    # moving the whole transformer to the GPU and a host<->device round trip
    # for every vocabulary entry.
    embedding_weight = model.get_input_embeddings().weight.detach().cpu()
    del model  # the table tensor is kept; the rest of the model is no longer needed

    # Use the table's own width, not model.config.hidden_size: the two can
    # differ (presumably ALBERT-style factorised embeddings for IndicBERT).
    embedding_dim = embedding_weight.shape[1]
    print("Using embedding_dim =", embedding_dim)

    vocab_size = len(vocab_stoi)
    embedding_matrix = np.zeros((vocab_size, embedding_dim), dtype=np.float32)
    print("Embedding matrix shape:", embedding_matrix.shape)

    for token, idx in vocab_stoi.items():
        subwords = tokenizer.tokenize(token) or [tokenizer.unk_token]
        subword_ids = tokenizer.convert_tokens_to_ids(subwords)
        # (subword_count, embedding_dim) -> (embedding_dim,)
        embedding_matrix[idx] = embedding_weight[subword_ids].mean(dim=0).numpy()

    return embedding_matrix
class Decoder(nn.Module):
    """Single-step LSTM decoder on top of a pretrained target-side embedding table.

    Args:
        embedding_matrix: array-like of shape (vocab_size, embedding_size); it is
            copied into the embedding layer as float32.
        hidden_size: size of the LSTM hidden and cell states.
        output_size: number of output classes (target vocabulary size).
        num_layers: number of stacked LSTM layers.
        p: dropout probability, used on the embeddings and between LSTM layers.
        trainable: when False (default) the embedding table is frozen.
    """

    def __init__(
        self, embedding_matrix, hidden_size, output_size, num_layers, p, trainable=False
    ):
        super(Decoder, self).__init__()
        self.dropout = nn.Dropout(p)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        _, embedding_size = embedding_matrix.shape

        # Force float32. A float64 matrix (e.g. np.random.uniform output) would
        # otherwise give a float64 embedding that the float32 LSTM rejects.
        # torch.tensor always copies, so training never writes back into the
        # caller's numpy array.
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(embedding_matrix, dtype=torch.float32), freeze=not trainable
        )

        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, dropout=p)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden, cell):
        """Run one decoding step.

        Args:
            x: (N,) batch of target token indices for the current step.
            hidden, cell: LSTM states of shape (num_layers, N, hidden_size).

        Returns:
            (predictions, hidden, cell), with predictions of shape (N, output_size).
        """
        x = x.unsqueeze(0)  # (1, N): a single time step

        embedding = self.dropout(self.embedding(x))  # (1, N, embedding_size)

        outputs, (hidden, cell) = self.rnn(embedding, (hidden, cell))  # (1, N, hidden_size)

        predictions = self.fc(outputs.squeeze(0))  # (N, output_size)

        return predictions, hidden, cell


# Pretrained embedding tables whose rows follow the Field vocabularies:
# Hindi from IndicBERT's input embeddings, English from 200-d GloVe vectors.
embedding_matrix_hindi = get_indicbert_embedding_matrix(hindi.vocab.stoi, model_name="ai4bharat/indic-bert")
embedding_matrix_english = load_glove_embeddings('glove.6B.200d.txt', english.vocab.stoi, 200)

# The encoder and decoder are constructed further down, once the model
# hyperparameters (HIDDEN_SIZE, NUM_LAYERS, dropouts) are defined.

print(embedding_matrix_english.shape)
print(embedding_matrix_hindi.shape)

class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder over the target sequence.

    ``decoder`` must expose its output projection as ``fc`` (an ``nn.Linear``);
    its ``out_features`` gives the target vocabulary size.
    """

    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, target, teacher_force_ratio=0.5):
        """Decode *target*'s length worth of steps from *source*.

        Args:
            source: (src_len, batch) source token indices.
            target: (trg_len, batch) target token indices; ``target[0]`` (the
                start row) is the first decoder input.
            teacher_force_ratio: probability, drawn once per time step for the
                whole batch, of feeding the ground-truth token as the next input
                instead of the decoder's own argmax prediction.

        Returns:
            (trg_len, batch, vocab) logits on ``source``'s device; row 0 is left
            as zeros because no prediction is made for the start position.
        """
        batch_size = source.shape[1]
        target_len = target.shape[0]
        # Read the vocab size from the decoder rather than a global Field.
        target_vocab_size = self.decoder.fc.out_features

        outputs = torch.zeros(target_len, batch_size, target_vocab_size, device=source.device)

        hidden, cell = self.encoder(source)

        x = target[0]

        for t in range(1, target_len):
            output, hidden, cell = self.decoder(x, hidden, cell)
            outputs[t] = output

            best_guess = output.argmax(1)

            # Teacher forcing: with probability teacher_force_ratio feed the true
            # next token; otherwise feed the model's own prediction. The latter
            # exposes the decoder during training to the kind of (possibly wrong)
            # inputs it will receive at inference time.
            x = target[t] if random.random() < teacher_force_ratio else best_guess

        return outputs
# Training hyperparameters (num_epochs and learning_rate are consumed by the
# training loop below, which is currently commented out).
num_epochs = 100
learning_rate = 0.001
batch_size = 64

# Device setup (same choice as the earlier `device` definition).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Vocabulary sizes
input_size_encoder = len(english.vocab)
output_size_decoder = len(hindi.vocab)

# Pretrained embedding tables built earlier:
#   embedding_matrix_english: numpy array (len(english.vocab.stoi), 200) of GloVe vectors
#   embedding_matrix_hindi:   numpy array (len(hindi.vocab.stoi), D) of IndicBERT input embeddings
# Encoder/Decoder convert these numpy arrays to tensors themselves.

# Model hyperparameters. They must match the configuration best_model.pt was
# saved with, otherwise load_state_dict() further down fails on shape mismatches.
HIDDEN_SIZE = 1024
NUM_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

# Length-bucketed iterators over the three splits; examples inside each batch
# are sorted by source length.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=batch_size,
    sort_within_batch=True,
    sort_key=lambda x: len(x.source),
    device=device,
)

# Encoder and decoder on top of the pretrained tables; both tables stay frozen.
encoder = Encoder(
    embedding_matrix=embedding_matrix_english,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    p=ENC_DROPOUT,
    trainable=False
).to(device)

decoder = Decoder(
    embedding_matrix=embedding_matrix_hindi,
    hidden_size=HIDDEN_SIZE,
    output_size=output_size_decoder,
    num_layers=NUM_LAYERS,
    p=DEC_DROPOUT,
    trainable=False
).to(device)

# Seq2Seq Model
model = Seq2Seq(encoder, decoder).to(device)

# # Optimizer and Loss
# optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# pad_idx = hindi.vocab.stoi["<pad>"]
# criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)


# # Early stopping params
# patience = 5
# best_valid_loss = float('inf')
# early_stop_counter = 0

# train_losses = []
# valid_losses = []

# # Fixed sentence to translate
# fixed_sentence = "the weather is nice today"
# translation_log_file = "translation_progress.txt"

# with open(translation_log_file, "w", encoding="utf-8") as f_log:
#     f_log.write("Translation Progress Over Epochs\n\n")

# for epoch in range(num_epochs):
#     print(f"[Epoch {epoch+1}/{num_epochs}]")

#     model.train()
#     epoch_loss = 0

#     for batch in train_iterator:
#         src = batch.source.to(device)
#         trg = batch.target.to(device)

#         output = model(src, trg)

#         output = output[1:].reshape(-1, output.shape[2])
#         trg = trg[1:].reshape(-1)

#         optimizer.zero_grad()
#         loss = criterion(output, trg)
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
#         optimizer.step()

#         epoch_loss += loss.item()

#     train_loss = epoch_loss / len(train_iterator)
#     train_losses.append(train_loss)

#     # Validation loss
#     model.eval()
#     val_loss = 0
#     with torch.no_grad():
#         for batch in valid_iterator:
#             src = batch.source.to(device)
#             trg = batch.target.to(device)
#             output = model(src, trg, 0)  # No teacher forcing in validation

#             output = output[1:].reshape(-1, output.shape[2])
#             trg = trg[1:].reshape(-1)

#             loss = criterion(output, trg)
#             val_loss += loss.item()

#     valid_loss = val_loss / len(valid_iterator)
#     valid_losses.append(valid_loss)

#     print(f"Train Loss: {train_loss:.4f} | Val Loss: {valid_loss:.4f}")

#     # Early stopping logic
#     if valid_loss < best_valid_loss:
#         best_valid_loss = valid_loss
#         early_stop_counter = 0
#         torch.save(model.state_dict(), "best_model.pt")  # optional: save best model
#     else:
#         early_stop_counter += 1
#         if early_stop_counter >= patience:
#             print("Early stopping triggered.")
#             break

#     # Translate fixed sentence after each epoch
#     model.eval()
#     tokens = [tok.text.lower() for tok in spacy_eng.tokenizer(fixed_sentence)]
#     tokens = ["<sos>"] + tokens + ["<eos>"]
#     indices = [english.vocab.stoi[token] if token in english.vocab.stoi else english.vocab.stoi["<unk>"] for token in tokens]
#     src_tensor = torch.LongTensor(indices).unsqueeze(1).to(device)  # shape: (src_len, 1)

#     with torch.no_grad():
#         hidden, cell = model.encoder(src_tensor)
#         x = torch.LongTensor([hindi.vocab.stoi["<sos>"]]).to(device)
#         translated_sentence = []

#         for _ in range(50):  # max length
#             output, hidden, cell = model.decoder(x, hidden, cell)
#             best_guess = output.argmax(1).item()
#             predicted_word = hindi.vocab.itos[best_guess]

#             if predicted_word == "<eos>":
#                 break
#             translated_sentence.append(predicted_word)
#             x = torch.LongTensor([best_guess]).to(device)

#     translated_text = " ".join(translated_sentence)
#     with open(translation_log_file, "a", encoding="utf-8") as f_log:
#         f_log.write(f"Epoch {epoch+1}: {translated_text}\n")

# # Plot training and validation loss
# plt.figure(figsize=(10, 6))
# plt.plot(train_losses, label="Training Loss")
# plt.plot(valid_losses, label="Validation Loss")
# plt.xlabel("Epochs")
# plt.ylabel("Loss")
# plt.title("Training vs Validation Loss")
# plt.legend()
# plt.savefig("loss_curve.png")
# plt.show()



# Restore the weights saved as best_model.pt by the (currently commented-out)
# training loop above. map_location places every tensor on `device`, so a
# GPU-trained checkpoint also loads on a CPU-only machine.
model.load_state_dict(
    torch.load("best_model.pt", map_location=device)
)
# No explicit model.eval() here: translate_sentence() switches to eval mode itself.


def translate_sentence(sentence, src_field, trg_field, model, device, max_len=50):
    """Greedily translate a single source sentence.

    Args:
        sentence: raw string, tokenized with ``src_field.tokenize`` (the same
            tokenizer used for the training data), or an already-tokenized
            sequence of strings.
        src_field: source-side Field with a built vocabulary.
        trg_field: target-side Field with a built vocabulary.
        model: module exposing ``encoder(src)`` and ``decoder(x, hidden, cell)``.
        device: torch device the model lives on.
        max_len: maximum number of target tokens to generate.

    Returns:
        List of predicted target tokens without the leading init token. The
        end-of-sentence token is included as the last element when it was
        generated within ``max_len`` steps.
    """
    model.eval()

    tokens = src_field.tokenize(sentence) if isinstance(sentence, str) else list(sentence)
    # Lower-case only if the Field does (defaults to True, the previous behaviour).
    if getattr(src_field, "lower", True):
        tokens = [tok.lower() for tok in tokens]
    tokens = [src_field.init_token] + tokens + [src_field.eos_token]

    src_stoi = src_field.vocab.stoi
    unk_idx = src_stoi[src_field.unk_token]
    src_indices = [src_stoi.get(tok, unk_idx) for tok in tokens]
    src_tensor = torch.LongTensor(src_indices).unsqueeze(1).to(device)  # (src_len, 1)

    trg_stoi = trg_field.vocab.stoi
    eos_idx = trg_stoi[trg_field.eos_token]
    trg_indices = [trg_stoi[trg_field.init_token]]

    with torch.no_grad():
        hidden, cell = model.encoder(src_tensor)

        for _ in range(max_len):
            trg_tensor = torch.LongTensor([trg_indices[-1]]).to(device)
            output, hidden, cell = model.decoder(trg_tensor, hidden, cell)

            best_guess = output.argmax(1).item()
            trg_indices.append(best_guess)

            if best_guess == eos_idx:
                break

    return [trg_field.vocab.itos[idx] for idx in trg_indices[1:]]  # drop init token


# # Evaluate BLEU on a dataset iterator
# def evaluate_bleu(model, iterator, src_field, trg_field, device):
#     trgs = []
#     preds = []

#     for batch in iterator:
#         src = getattr(batch, 'src', getattr(batch, 'source', None))
#         trg = getattr(batch, 'trg', getattr(batch, 'target', None))

#         src = src.transpose(0, 1)  # shape: (batch_size, src_len)
#         trg = trg.transpose(0, 1)  # shape: (batch_size, trg_len)

#         for i in range(src.shape[0]):
#             src_indices = src[i].tolist()
#             trg_indices = trg[i].tolist()

#             src_tokens = [src_field.vocab.itos[idx] for idx in src_indices if idx not in
#                           [src_field.vocab.stoi[tok] for tok in [src_field.pad_token, src_field.init_token, src_field.eos_token]]]

#             trg_tokens = [trg_field.vocab.itos[idx] for idx in trg_indices if idx not in
#                           [trg_field.vocab.stoi[tok] for tok in [trg_field.pad_token, trg_field.init_token, trg_field.eos_token]]]

#             pred_tokens = translate_sentence(src_tokens, src_field, trg_field, model, device)

#             preds.append(pred_tokens)
#             trgs.append([trg_tokens])  # wrap target in a list for BLEU format

#     bleu1 = bleu_score(preds, trgs, max_n=1, weights=[1.0])
#     bleu2 = bleu_score(preds, trgs, max_n=2, weights=[0.5, 0.5])
#     bleu3 = bleu_score(preds, trgs, max_n=3, weights=[1/3, 1/3, 1/3])
#     bleu4 = bleu_score(preds, trgs, max_n=4, weights=[0.25, 0.25, 0.25, 0.25])

#     print(f"BLEU-1: {bleu1 * 100:.2f}")
#     print(f"BLEU-2: {bleu2 * 100:.2f}")
#     print(f"BLEU-3: {bleu3 * 100:.2f}")
#     print(f"BLEU-4: {bleu4 * 100:.2f}")

#     return bleu1, bleu2, bleu3, bleu4

import random

def print_random_translations_to_file(dataset, src_field, trg_field, model, device, output_path, n=5, name="Dataset"):
    """Write randomly sampled source / reference / model-translation triples to a file.

    Args:
        dataset: torchtext dataset whose ``examples`` carry ``src``/``source`` and
            ``trg``/``target`` attributes.
        src_field: source-side Field (vocabulary used to decode index sequences).
        trg_field: target-side Field.
        model: seq2seq model passed through to ``translate_sentence``.
        device: torch device passed through to ``translate_sentence``.
        output_path: file to (over)write, UTF-8 encoded.
        n: number of examples to sample; capped at the dataset size.
        name: label written in the file header.
    """
    def _to_tokens(value, field):
        # Examples from a sequential Field hold token lists; anything else is
        # treated as an iterable of vocabulary indices.
        if isinstance(value, list):
            return value
        return [field.vocab.itos[idx] for idx in value]

    # random.sample raises ValueError when asked for more items than exist.
    examples = random.sample(dataset.examples, min(n, len(dataset.examples)))

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(f"Translations from {name}:\n")

        for i, example in enumerate(examples):
            src_tokens = _to_tokens(getattr(example, 'src', getattr(example, 'source', None)), src_field)
            trg_tokens = _to_tokens(getattr(example, 'trg', getattr(example, 'target', None)), trg_field)

            src_text = ' '.join(src_tokens)
            trg_text = ' '.join(trg_tokens)

            # Hand the already-tokenized source straight to the model: joining
            # and re-tokenizing could split words differently from training.
            translation = translate_sentence(src_tokens, src_field, trg_field, model, device)
            translation_text = ' '.join(tok for tok in translation if tok != trg_field.eos_token)

            f.write(f"\nExample {i+1}:\n")
            f.write(f"  Source:      {src_text}\n")
            f.write(f"  Target:      {trg_text}\n")
            f.write(f"  Translation: {translation_text}\n")

# Dump five random sample translations per split (train first, then test) for inspection.
print_random_translations_to_file(
    train_data, english, hindi, model, device,
    output_path="train_translations.txt", n=5, name="Train Data",
)
print_random_translations_to_file(
    test_data, english, hindi, model, device,
    output_path="test_translations.txt", n=5, name="Test Data",
)




# # Run BLEU evaluation
# print("Evaluating on TRAIN set...")
# evaluate_bleu(model, train_iterator, english, hindi, device)

# print("\nEvaluating on TEST set...")
# evaluate_bleu(model, test_iterator, english, hindi, device)

