<a href="https://colab.research.google.com/github/Khesorw/AshtraMind/blob/main/mt5_comp_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import transformers

In [None]:
!git clone https://github.com/rahular/itihasa.git


Cloning into 'itihasa'...
remote: Enumerating objects: 5656, done.[K
remote: Counting objects: 100% (35/35), done.[K
remote: Compressing objects: 100% (30/30), done.[K
remote: Total 5656 (delta 11), reused 19 (delta 3), pack-reused 5621 (from 1)[K
Receiving objects: 100% (5656/5656), 42.02 MiB | 10.89 MiB/s, done.
Resolving deltas: 100% (2785/2785), done.


In [None]:
from datasets import Dataset, DatasetDict

# Helper function to read parallel files
def load_translation_split(source_path, target_path, source_lang="sn", target_lang="en"):
    """Read two line-aligned text files into a `Dataset` of translation pairs.

    Args:
        source_path: Path of the UTF-8 file with one source sentence per line.
        target_path: Path of the UTF-8 file with the aligned target sentences.
        source_lang: Key under which each source sentence is stored.
        target_lang: Key under which each target sentence is stored.

    Returns:
        A `Dataset` with a single "translation" column whose rows are
        `{source_lang: ..., target_lang: ...}` dicts.

    Raises:
        ValueError: If the two files do not contain the same number of lines.
    """
    with open(source_path, encoding="utf-8") as src_file, open(target_path, encoding="utf-8") as tgt_file:
        sources = [line.strip() for line in src_file]
        targets = [line.strip() for line in tgt_file]

    # zip() stops at the shorter list, so a length mismatch would silently
    # drop sentences; fail loudly instead.
    if len(sources) != len(targets):
        raise ValueError(
            f"Parallel files differ in length: {source_path} has {len(sources)} lines, "
            f"{target_path} has {len(targets)} lines"
        )

    # Combine into "translation" field
    return Dataset.from_dict({
        "translation": [
            {source_lang: s, target_lang: t}
            for s, t in zip(sources, targets)
        ]
    })

# Directory holding the <split>.sn / <split>.en file pairs of the cloned corpus
base_path = "itihasa/data"

# Load the three splits; the validation split is stored under the "dev" file stem
train_dataset, val_dataset, test_dataset = (
    load_translation_split(f"{base_path}/{stem}.sn", f"{base_path}/{stem}.en")
    for stem in ("train", "dev", "test")
)

# Bundle the splits under the names used by the rest of the notebook
full_dataset = DatasetDict({"train": train_dataset, "validation": val_dataset, "test": test_dataset})

# Show the split sizes
print(full_dataset)


DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 75161
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 6148
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 11721
    })
})


In [None]:
full_dataset["train"]['translation'][0]

{'en': 'The ascetic VƒÅlmƒ´ki asked NƒÅrada, the best of sages and foremost of those conversant with words, ever engaged in austerities and Vedic studies.',
 'sn': '‡•ê ‡§§‡§™‡§É ‡§∏‡•ç‡§µ‡§æ‡§ß‡•ç‡§Ø‡§æ‡§Ø‡§®‡§ø‡§∞‡§§‡§Ç ‡§§‡§™‡§∏‡•ç‡§µ‡•Ä ‡§µ‡§æ‡§ó‡•ç‡§µ‡§ø‡§¶‡§æ‡§Ç ‡§µ‡§∞‡§Æ‡•ç‡•§ ‡§®‡§æ‡§∞‡§¶‡§Ç ‡§™‡§∞‡§ø‡§™‡§™‡•ç‡§∞‡§ö‡•ç‡§õ ‡§µ‡§æ‡§≤‡•ç‡§Æ‡•Ä‡§ï‡§ø‡§∞‡•ç‡§Æ‡•Å‡§®‡§ø‡§™‡•Å‡§ô‡•ç‡§ó‡§µ‡§Æ‡•ç‡••'}

In [None]:
tokenized_datasets["validation"]

Dataset({
    features: ['translation', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 6148
})

In [None]:
tokenized_datasets['train']

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 150322
})

In [None]:
# ================================
# GOOGLE COLAB SETUP FOR ENGLISH TO SANSKRIT TRANSLATION WITH mT5
# ================================

# Step 1: Install required packages
# !pip install transformers datasets evaluate accelerate sentencepiece

# Step 2: Import libraries
import torch
from transformers import (
    MT5ForConditionalGeneration,
    T5Tokenizer,  # Use T5Tokenizer instead of MT5Tokenizer
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from datasets import Dataset, DatasetDict
import numpy as np
from evaluate import load
import gc

# Step 3: Check GPU availability
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

# Step 4: Load your dataset (replace this with your actual dataset loading)
# Assuming you have your dataset as 'full_dataset'
# full_dataset = your_dataset_here





CUDA available: True
GPU device: NVIDIA A100-SXM4-40GB


In [None]:
# Step 5: Load the pretrained mT5 checkpoint together with its tokenizer
model_name = "google/mt5-small"
print("Loading tokenizer and model...")

# Choose the accelerator when one is present, otherwise stay on CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The checkpoint's tokenizer is loaded through the T5Tokenizer class
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name).to(device)

print(f"Model loaded on {device}")

Loading tokenizer and model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Model loaded on cuda


In [None]:
# Step 6: Preprocessing function optimized for Colab
def preprocess_function(examples, max_input_length=256, max_target_length=256):
    """Tokenize one batch of translation pairs for English -> Sanskrit training.

    Args:
        examples: Batch mapping with a "translation" list whose items carry
            'en' and 'sn' entries.
        max_input_length: Truncation length for the prefixed English inputs.
        max_target_length: Truncation length for the Sanskrit targets.

    Returns:
        The tokenizer output for the inputs, with the target token ids stored
        under "labels". Uses the module-level `tokenizer`.
    """
    pairs = examples['translation']

    # Each source sentence is prefixed with the task instruction
    task_prefix = "translate English to Sanskrit: "
    source_texts = [task_prefix + pair['en'] for pair in pairs]
    target_texts = [pair['sn'] for pair in pairs]

    # Truncate but do not pad here; padding is left to the data collator
    common = dict(truncation=True, padding=False)
    model_inputs = tokenizer(source_texts, max_length=max_input_length, **common)
    target_encoding = tokenizer(target_texts, max_length=max_target_length, **common)

    model_inputs["labels"] = target_encoding["input_ids"]
    return model_inputs



# Step 7: Apply preprocessing
def prepare_dataset(full_dataset):
    """Tokenize every split, drop over-long rows and print the resulting sizes.

    Args:
        full_dataset: Dataset collection with "train", "validation" and
            "test" splits.

    Returns:
        The tokenized and length-filtered dataset collection.
    """
    print("Preprocessing dataset...")

    # The raw columns are replaced by the tokenizer output
    raw_columns = full_dataset["train"].column_names
    encoded = full_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_columns,
        desc="Tokenizing datasets",
    )

    # Keep only rows whose inputs and labels both fit within the cap.
    # NOTE(review): preprocess_function truncates to 256 by default, so this
    # filter presumably never removes a row — confirm whether dropping or
    # truncating long examples is intended.
    length_cap = 256
    encoded = encoded.filter(
        lambda row: max(len(row["input_ids"]), len(row["labels"])) <= length_cap
    )

    print("Dataset sizes after filtering:")
    for label, split_name in (("Train", "train"), ("Validation", "validation"), ("Test", "test")):
        print(f"{label}: {len(encoded[split_name])}")

    return encoded


tokenized_datasets = prepare_dataset(full_dataset)


Preprocessing dataset...


Tokenizing datasets:   0%|          | 0/75161 [00:00<?, ? examples/s]

Tokenizing datasets:   0%|          | 0/6148 [00:00<?, ? examples/s]

Tokenizing datasets:   0%|          | 0/11721 [00:00<?, ? examples/s]

Filter:   0%|          | 0/75161 [00:00<?, ? examples/s]

Filter:   0%|          | 0/6148 [00:00<?, ? examples/s]

Filter:   0%|          | 0/11721 [00:00<?, ? examples/s]

Dataset sizes after filtering:
Train: 75161
Validation: 6148
Test: 11721


In [None]:
import torch
from transformers import T5Tokenizer

def verify_sanskrit_tokenization(tokenizer, full_dataset, num_samples=5):
    """
    Print an encode/decode round-trip report for the first training pairs.

    For each of the first `num_samples` entries of
    full_dataset["train"]["translation"], the prefixed English input and the
    Sanskrit text are encoded, decoded again with special tokens skipped, and
    compared with the text that went in.
    """
    print("=== SANSKRIT TOKENIZATION VERIFICATION ===\n")

    def round_trip(text):
        # Encode, then decode the resulting ids without special tokens
        ids = tokenizer.encode(text)
        return ids, tokenizer.decode(ids, skip_special_tokens=True)

    pairs = full_dataset["train"]["translation"][:num_samples]

    for number, pair in enumerate(pairs, start=1):
        source, target = pair['en'], pair['sn']

        print(f"--- Sample {number} ---")
        print(f"Original English: {source}")
        print(f"Original Sanskrit: {target}")

        # English side, with the same task prefix used for training inputs
        prompt = f"translate English to Sanskrit: {source}"
        prompt_ids, prompt_back = round_trip(prompt)

        print(f"English tokenized length: {len(prompt_ids)} tokens")
        print(f"English decoded: {prompt_back}")
        print(f"English matches: {prompt == prompt_back}")

        # Sanskrit side, encoded as-is
        target_ids, target_back = round_trip(target)

        print(f"Sanskrit tokenized length: {len(target_ids)} tokens")
        print(f"Sanskrit decoded: {target_back}")
        print(f"Sanskrit matches: {target == target_back}")

        # Leading ids, followed by each of those ids decoded on its own
        leading_ids = target_ids[:10]
        print(f"Sanskrit tokens (first 10): {leading_ids}")
        print(f"Sanskrit token strings: {[tokenizer.decode([token_id]) for token_id in leading_ids]}")

        print("-" * 80)
        print()

def check_tokenizer_capabilities(tokenizer):
    """
    Print round-trip results for Devanagari text plus basic tokenizer facts.

    A string of Devanagari characters and a list of Sanskrit words are each
    encoded and decoded (special tokens skipped) and compared with the input;
    the tokenizer length and its special-token map are printed last.
    """
    print("=== TOKENIZER CAPABILITIES ===\n")

    def round_trip(text):
        # Encode, then decode the resulting ids without special tokens
        ids = tokenizer.encode(text)
        return ids, tokenizer.decode(ids, skip_special_tokens=True)

    # Character-level check
    devanagari_chars = "‡§Ö‡§Ü‡§á‡§à‡§â‡§ä‡§è‡§ê‡§ì‡§î‡§ï‡§ñ‡§ó‡§ò‡§ô‡§ö‡§õ‡§ú‡§ù‡§û‡§ü‡§†‡§°‡§¢‡§£‡§§‡§•‡§¶‡§ß‡§®‡§™‡§´‡§¨‡§≠‡§Æ‡§Ø‡§∞‡§≤‡§µ‡§∂‡§∑‡§∏‡§π"
    char_ids, chars_back = round_trip(devanagari_chars)

    print(f"Devanagari test string: {devanagari_chars}")
    print(f"Tokenized to {len(char_ids)} tokens")
    print(f"Decoded back: {chars_back}")
    print(f"Perfect round-trip: {devanagari_chars == chars_back}")
    print()

    # Word-level check
    sanskrit_words = ["‡§ß‡§∞‡•ç‡§Æ", "‡§Ö‡§∞‡•ç‡§•", "‡§ï‡§æ‡§Æ", "‡§Æ‡•ã‡§ï‡•ç‡§∑", "‡§Ø‡•ã‡§ó", "‡§µ‡•á‡§¶", "‡§â‡§™‡§®‡§ø‡§∑‡§¶‡•ç"]
    for word in sanskrit_words:
        word_ids, word_back = round_trip(word)
        print(f"'{word}' -> {len(word_ids)} tokens -> '{word_back}' (match: {word == word_back})")

    print()

    # Vocabulary summary
    print(f"Tokenizer vocabulary size: {len(tokenizer)}")
    print(f"Special tokens: {tokenizer.special_tokens_map}")

def analyze_tokenized_dataset(tokenized_datasets):
    """
    Print length statistics and a few decoded examples for the train split.

    Args:
        tokenized_datasets: Mapping with a "train" split whose rows provide
            "input_ids" and "labels" sequences.

    Returns:
        None. All results are printed. An empty train split is reported and
        skipped instead of raising.

    Note:
        Example decoding uses the module-level `tokenizer`; it is only touched
        when the train split holds at least one row.
    """
    print("=== TOKENIZED DATASET ANALYSIS ===\n")

    train_data = tokenized_datasets["train"]

    # Collect both length lists in a single pass over the split
    input_lengths = []
    label_lengths = []
    for example in train_data:
        input_lengths.append(len(example["input_ids"]))
        label_lengths.append(len(example["labels"]))

    if not input_lengths:
        # min()/max() and the average below would raise on an empty split
        print("Train split is empty; nothing to analyze.")
        return

    def _report(title, lengths):
        # The median is the upper-middle element for even-sized lists
        print(title)
        print(f"  Min: {min(lengths)}")
        print(f"  Max: {max(lengths)}")
        print(f"  Average: {sum(lengths)/len(lengths):.1f}")
        print(f"  Median: {sorted(lengths)[len(lengths)//2]}")

    _report("Input (English) sequence length statistics:", input_lengths)
    _report("\nTarget (Sanskrit) sequence length statistics:", label_lengths)

    # Show a few tokenized examples (fewer when the split is smaller than 3)
    print("\n=== SAMPLE TOKENIZED EXAMPLES ===")
    for i in range(min(3, len(input_lengths))):
        example = train_data[i]
        print(f"\nExample {i+1}:")
        print(f"Input IDs length: {len(example['input_ids'])}")
        print(f"Input IDs (first 15): {example['input_ids'][:15]}")
        print(f"Label IDs length: {len(example['labels'])}")
        print(f"Label IDs (first 15): {example['labels'][:15]}")

        # Decode to verify
        input_text = tokenizer.decode(example['input_ids'], skip_special_tokens=True)
        # Labels might have -100 tokens, so filter them before decoding
        label_ids = [token_id for token_id in example['labels'] if token_id != -100]
        label_text = tokenizer.decode(label_ids, skip_special_tokens=True)

        print(f"Decoded input: {input_text}")
        print(f"Decoded label: {label_text}")

# Run all verification tests
def run_complete_verification(tokenizer, full_dataset, tokenized_datasets):
    """
    Run the three verification reports one after another.
    """
    print("Starting complete tokenization verification...\n")

    # Ordered steps: sample round-trips, tokenizer checks, dataset statistics
    steps = (
        lambda: verify_sanskrit_tokenization(tokenizer, full_dataset, num_samples=3),
        lambda: check_tokenizer_capabilities(tokenizer),
        lambda: analyze_tokenized_dataset(tokenized_datasets),
    )
    for step in steps:
        step()


üîç Starting complete tokenization verification...

=== SANSKRIT TOKENIZATION VERIFICATION ===

--- Sample 1 ---
Original English: The ascetic VƒÅlmƒ´ki asked NƒÅrada, the best of sages and foremost of those conversant with words, ever engaged in austerities and Vedic studies.
Original Sanskrit: ‡•ê ‡§§‡§™‡§É ‡§∏‡•ç‡§µ‡§æ‡§ß‡•ç‡§Ø‡§æ‡§Ø‡§®‡§ø‡§∞‡§§‡§Ç ‡§§‡§™‡§∏‡•ç‡§µ‡•Ä ‡§µ‡§æ‡§ó‡•ç‡§µ‡§ø‡§¶‡§æ‡§Ç ‡§µ‡§∞‡§Æ‡•ç‡•§ ‡§®‡§æ‡§∞‡§¶‡§Ç ‡§™‡§∞‡§ø‡§™‡§™‡•ç‡§∞‡§ö‡•ç‡§õ ‡§µ‡§æ‡§≤‡•ç‡§Æ‡•Ä‡§ï‡§ø‡§∞‡•ç‡§Æ‡•Å‡§®‡§ø‡§™‡•Å‡§ô‡•ç‡§ó‡§µ‡§Æ‡•ç‡••
English tokenized length: 51 tokens
English decoded: translate English to Sanskrit: The ascetic VƒÅlmƒ´ki asked NƒÅrada, the best of sages and foremost of those conversant with words, ever engaged in austerities and Vedic studies.
English matches: True
Sanskrit tokenized length: 42 tokens
Sanskrit decoded: ‡•ê ‡§§‡§™‡§É ‡§∏‡•ç‡§µ‡§æ‡§ß‡•ç‡§Ø‡§æ‡§Ø‡§®‡§ø‡§∞‡§§‡§Ç ‡§§‡§™‡§∏‡•ç‡§µ‡•Ä ‡§µ‡§æ‡§ó‡•ç‡§µ‡§ø‡§¶‡§æ‡§Ç ‡§µ‡§∞‡§Æ‡•ç‡•§ ‡§®‡§æ‡§∞‡§¶‡§Ç ‡§™‡§∞‡§ø‡§™‡§™‡•ç‡§∞

In [None]:
# Encode one short Sanskrit sentence and print its ids and the matching token strings
sanskrit_text = "‡§∞‡§æ‡§Æ‡§É ‡§µ‡§®‡§Ç ‡§ó‡§ö‡•ç‡§õ‡§§‡§ø"
encoded = tokenizer(sanskrit_text, return_tensors="pt")
print("Input IDs:", encoded["input_ids"])
# One token list per row of the returned id tensor
print("Tokens:", [tokenizer.convert_ids_to_tokens(i) for i in encoded["input_ids"]])


Input IDs: tensor([[25709, 15052,  3778, 64700,  1462, 25844,  6491,     1]])
Tokens: [['‚ñÅ‡§∞‡§æ‡§Æ', '‡§É', '‚ñÅ‡§µ', '‡§®‡§Ç', '‚ñÅ‡§ó', '‡§ö‡•ç‡§õ', '‡§§‡§ø', '</s>']]


In [None]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 75161
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 6148
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 11721
    })
})

In [None]:
# Step 8: Training setup optimized for Colab
def setup_training(tokenized_datasets, output_dir="/content/drive/MyDrive/mt5-sanskrit-translator"):
    """
    Build a `Trainer` around the module-level `model` and `tokenizer`.

    Args:
        tokenized_datasets: Mapping providing "train" and "validation" splits.
        output_dir: Directory handed to `TrainingArguments`; the default is a
            path under /content/drive/MyDrive.

    Returns:
        The configured `Trainer`; training is not started here.
    """
    # Batches are padded by the collator at collation time
    collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        padding=True,
        return_tensors="pt",
    )

    args = TrainingArguments(
        # Run identity and reporting
        output_dir=output_dir,
        run_name="mt5-sanskrit-improved",
        # NOTE(review): report_to=None may not switch reporting integrations
        # off in every transformers release — confirm ("none" is the explicit value).
        report_to=None,
        # Batching: 16 per device x 8 accumulation steps = 128 sequences per optimizer step
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=8,
        # Optimisation schedule
        learning_rate=3e-5,
        num_train_epochs=30,
        warmup_steps=500,
        weight_decay=0.01,
        optim="adamw_torch_fused",
        lr_scheduler_type="linear",
        # Precision: bf16 on, fp16 off
        bf16=True,
        fp16=False,
        # Logging, evaluation and checkpoint cadence (all step-based)
        logging_steps=50,
        eval_strategy="steps",
        eval_steps=150,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=2,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        # Data loading
        dataloader_pin_memory=True,
        dataloader_num_workers=2,
        remove_unused_columns=False,
    )

    return Trainer(
        model=model,
        args=args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        data_collator=collator,
        tokenizer=tokenizer,
    )

In [None]:
trainer = setup_training(tokenized_datasets)

  trainer = Trainer(


In [None]:
trainer.train()

Step,Training Loss,Validation Loss
150,4.9066,4.497838
300,4.8878,4.476573
450,4.8485,4.445612
600,4.7294,4.406802
750,4.7644,4.372947
900,4.7231,4.336829
1050,4.6827,4.307849
1200,4.5823,4.281235
1350,4.624,4.250913
1500,4.5819,4.226796


TrainOutput(global_step=9996, training_loss=4.298274987790526, metrics={'train_runtime': 10382.9306, 'train_samples_per_second': 123.061, 'train_steps_per_second': 0.963, 'total_flos': 1.58531653168128e+17, 'train_loss': 4.298274987790526, 'epoch': 17.0})

In [None]:
# SANSKRIT TRANSLATION INFERENCE SCRIPT
# Load and test your trained English-Sanskrit translator

import torch
from transformers import MT5ForConditionalGeneration, T5Tokenizer

class SanskritTranslator:
    """English -> Sanskrit translator wrapping a saved mT5 checkpoint."""

    def __init__(self, model_path="/content/drive/MyDrive/mt5-sanskrit-translator/checkpoint-9600"):
        """
        Load the tokenizer and model stored at `model_path` for inference.

        Args:
            model_path: Directory passed to `from_pretrained` for both the
                tokenizer and the model.
        """
        print("üîÑ Loading Sanskrit translator...")

        # Load the trained model and tokenizer from the same directory
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.model = MT5ForConditionalGeneration.from_pretrained(model_path)

        # Move to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()  # Set to evaluation mode

        print(f"‚úÖ Model loaded successfully on {self.device}")
        print(f"üìä Model size: {sum(p.numel() for p in self.model.parameters())/1e6:.1f}M parameters")

    def translate(self, english_text, max_length=256, num_beams=4, temperature=1.0, max_output_length=128):
        """
        Translate English text to Sanskrit.

        Args:
            english_text: The English sentence to translate.
            max_length: Truncation length applied to the tokenized input.
            num_beams: Beam count forwarded to `generate`.
            temperature: Sampling temperature forwarded to `generate`.
            max_output_length: `max_length` forwarded to `generate`.

        Returns:
            The decoded first generated sequence, with special tokens removed.
        """
        # Prepare input with the same task prefix used for training inputs
        input_text = f"translate English to Sanskrit: {english_text}"

        # Tokenize input
        input_ids = self.tokenizer.encode(
            input_text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True
        ).to(self.device)

        # Generate translation. num_beams and temperature come from the caller;
        # they were previously hard-coded, which silently ignored the arguments.
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_length=max_output_length,
                min_length=5,
                num_beams=num_beams,
                no_repeat_ngram_size=3,
                repetition_penalty=1.2,
                length_penalty=1.0,
                early_stopping=True,
                do_sample=True,
                temperature=temperature,
                top_p=0.9,
            )

        # Decode output
        sanskrit_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return sanskrit_text

    def batch_translate(self, english_texts, max_length=256, num_beams=4):
        """
        Translate several texts, one `translate` call per text.

        Args:
            english_texts: Iterable of English sentences.
            max_length: Forwarded to `translate` as its input truncation length.
            num_beams: Forwarded to `translate`.

        Returns:
            List of translations in input order.
        """
        return [
            self.translate(text, max_length=max_length, num_beams=num_beams)
            for text in english_texts
        ]

    def interactive_translate(self):
        """
        Read English text from standard input in a loop and print translations.

        The loop ends when the user types 'quit', 'exit' or 'q'; blank input is
        ignored, and translation errors are printed without leaving the loop.
        """
        print("\nüéØ INTERACTIVE SANSKRIT TRANSLATOR")
        print("Enter English text to translate (type 'quit' to exit)")
        print("-" * 50)

        while True:
            english_text = input("\nüìù English: ").strip()

            if english_text.lower() in ['quit', 'exit', 'q']:
                print("üëã Goodbye!")
                break

            if not english_text:
                continue

            try:
                print("üîÑ Translating...")
                sanskrit_text = self.translate(english_text)
                print(f"üïâÔ∏è  Sanskrit: {sanskrit_text}")

            except Exception as e:
                # Keep the session alive; report the failure and prompt again
                print(f"‚ùå Error: {e}")

# Test sentences for evaluation
TEST_SENTENCES = [
    "The sun rises in the east.",
    "Knowledge is the greatest treasure.",
    "Truth always prevails in the end.",
    "The wise man learns from experience.",
    "Meditation brings peace to the mind.",
    "Water flows down the mountain.",
    "The teacher guides the student.",
    "Love conquers all obstacles.",
    "Time heals all wounds.",
    "Practice makes perfect."
]

def run_test_translations(translator):
    """
    Translate every entry of TEST_SENTENCES and print each result.

    A failure on one sentence is printed and the loop moves on to the next.
    """
    rule = "=" * 60
    print("\nüß™ TESTING TRANSLATIONS")
    print(rule)

    position = 0
    for sentence in TEST_SENTENCES:
        position += 1
        print(f"\n{position}. Testing: '{sentence}'")
        try:
            sanskrit = translator.translate(sentence)
            print(f"   Sanskrit: {sanskrit}")
            print(f"   Length: {len(sentence)} chars ‚Üí {len(sanskrit)} chars")
        except Exception as e:
            print(f"Error: {e}")

    print("\n" + rule)

def compare_with_original_dataset(translator, full_dataset, num_samples=5):
    """
    Translate the first test-split pairs and print them next to the references.

    For each of the first `num_samples` entries of
    full_dataset["test"]["translation"], the English side is translated and a
    coarse verdict is printed: identical text, at least one shared
    whitespace-separated word, or neither.
    """
    print(f"\nüìä COMPARING WITH ORIGINAL DATASET ({num_samples} samples)")
    print("=" * 70)

    references = full_dataset["test"]["translation"][:num_samples]

    for number, pair in enumerate(references, start=1):
        source = pair['en']
        expected = pair['sn']

        print(f"\n--- Sample {number} ---")
        print(f"üìù Original English: {source}")
        print(f"üéØ Expected Sanskrit: {expected}")

        try:
            predicted = translator.translate(source)
            print(f"ü§ñ Model Sanskrit:   {predicted}")

            # Simple similarity check
            if expected == predicted:
                verdict = "‚úÖ PERFECT MATCH!"
            elif not set(expected.split()).isdisjoint(set(predicted.split())):
                verdict = "üü° Some words match"
            else:
                verdict = "üîç Different translation"
            print(verdict)

        except Exception as e:
            print(f"‚ùå Translation error: {e}")

        print("-" * 50)

def main():
    """
    Load the translator, run the sample sentences, then serve a command prompt.

    At the prompt, 'quit' leaves, 'test' re-runs the sample sentences,
    'interactive' starts the translator's own loop, and any other non-empty
    input is translated directly.
    """
    print("üïâÔ∏è  SANSKRIT TRANSLATOR - INFERENCE SCRIPT")
    print("=" * 50)

    # Without a loaded model there is nothing to do, so report and return
    try:
        translator = SanskritTranslator()
    except Exception as e:
        print(f"‚ùå Failed to load model: {e}")
        print("Make sure your model is saved in the correct path!")
        return

    # Run test translations
    run_test_translations(translator)

    # Compare with dataset if available
    # Uncomment the line below if you want to compare with original dataset
    # compare_with_original_dataset(translator, full_dataset)

    # Usage banner for the prompt below
    for banner_line in (
        "\nReady for interactive translation!",
        "Options:",
        "1. Type 'test' to run more test sentences",
        "2. Type 'interactive' to start interactive mode",
        "3. Type specific English text to translate",
    ):
        print(banner_line)

    while True:
        user_input = input("\n> ").strip()
        command = user_input.lower()

        if command == 'quit':
            break
        if command == 'test':
            run_test_translations(translator)
            continue
        if command == 'interactive':
            translator.interactive_translate()
            continue
        if not user_input:
            # Blank line: prompt again
            continue

        try:
            result = translator.translate(user_input)
            print(f"Sanskrit: {result}")
        except Exception as e:
            print(f"Error: {e}")

# Run the script


In [None]:
if __name__ == "__main__":
    main()


üïâÔ∏è  SANSKRIT TRANSLATOR - INFERENCE SCRIPT
üîÑ Loading Sanskrit translator...
‚úÖ Model loaded successfully on cuda
üìä Model size: 300.2M parameters

üß™ TESTING TRANSLATIONS

1. Testing: 'The sun rises in the east.'
   Sanskrit: ‡§∏‡•Ç‡§∞‡•ç‡§Ø‡§∏‡•ç‡§Ø ‡§∏‡•Ç‡§∞‡•ç‡§Ø‡§∏‡•ç‡§Ø ‡§∏‡•Ç‡§∞‡•ç‡§Ø‡§∏‡•ç‡§Ø ‡§∏‡•Ç‡§∞‡•ç‡§Ø‡§∏‡•ç‡§Ø ‡§∏‡•Ç‡§∞‡•ç‡§Ø‡§∏‡•ç‡§Ø ‡§ö‡•§ ‡§∏‡•Ç‡§∞‡•ç‡§Ø‡§∏‡•ç‡§Ø ‡§∏‡•Ç‡§∞‡•ç‡§Ø‡§∏‡•ç‡§Ø ‡§∏‡•Ç‡§∞‡•ç‡§Ø‡§∏‡•ç‡§Ø ‡§∏‡•Ç‡§∞‡•ç‡§Ø‡§∏‡•ç‡§Ø ‡§∏‡•Ç‡§∞‡•ç‡§Ø‡§∏‡•ç‡§Ø ‡§ö‡••
   Length: 26 chars ‚Üí 95 chars

2. Testing: 'Knowledge is the greatest treasure.'
   Sanskrit: ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ‡•§ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ ‡§Ø‡§•‡§æ‡••
   Length: 35 chars ‚Üí 77 chars

3. Testing: 'Truth always prevails in the end.'
   Sanskrit: ‡§∏‡§§‡•ç‡§Ø‡§Ç ‡§∏‡§§‡•ç‡§Ø‡§Ç ‡§∏‡§§‡•ç‡§Ø‡§Ç ‡§∏‡§§‡•ç‡§Ø‡§Ç ‡§∏‡§§‡•ç‡§Ø‡§Ç ‡§∏‡§§‡•ç‡§Ø‡§Ç ‡§∏‡§§‡•ç‡§Ø‡

# FINAL EVALUATION ON TEST SET