<a href="https://colab.research.google.com/github/adititadkod15-tech/HinglishLID1/blob/main/V100_charsiu_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# --- 1. SET LOCAL PATH ---
# Local copy of the pretrained Charsiu ByT5 checkpoint to fine-tune.
# NOTE(review): this is a cluster path while the data path below is a Colab
# `/content` path — confirm both exist in the runtime that executes this.
model_id = "/data1/users/aditi/charsiu_byt5"

# Source spreadsheet; later cells read its `hinglish` and `ipa` columns.
DATA_PATH = '/content/dakshina_dataset_with_gruut_aksharmukha_namesDateTime_clean.xlsx'
df = pd.read_excel(DATA_PATH)

def format_input(text):
    """Build the model input string for one Hinglish sentence.

    The value is coerced to ``str``, lower-cased and stripped of surrounding
    whitespace, then prefixed with the ``<hin>:`` language tag. The strip
    keeps training inputs identical to what ``phonemize_hinglish_precise``
    feeds the model at inference time (it also does ``.lower().strip()``).
    """
    return f"<hin>: {str(text).lower().strip()}"

# Drop rows missing either side of the pair: str(NaN) would otherwise turn
# them into literal "nan" inputs/targets. .copy() gives an independent frame
# so the column assignments below do not act on a filtered view.
df = df.dropna(subset=['hinglish', 'ipa']).copy()
df['input_text'] = df['hinglish'].apply(format_input)
df['target_text'] = df['ipa'].apply(str)

# Fixed seed so the 90/10 train/validation split is reproducible across runs.
train_df, val_df = train_test_split(
    df[['input_text', 'target_text']], test_size=0.1, random_state=42
)
# preserve_index=False keeps the shuffled pandas index from being stored as an
# extra `__index_level_0__` column in the datasets.
train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
val_dataset = Dataset.from_pandas(val_df, preserve_index=False)

# Load the pretrained seq2seq model and its matching tokenizer, both from the
# local checkpoint directory in `model_id`.
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def preprocess_function(examples):
    """Tokenize a batch of source/target pairs for seq2seq training.

    Args:
        examples: batched mapping with ``"input_text"`` and ``"target_text"``
            lists, as passed by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output for the inputs plus a ``"labels"`` entry. Inputs
        and labels are truncated/padded to 128 tokens. Label padding positions
        are set to -100 so the loss ignores them.
    """
    model_inputs = tokenizer(examples["input_text"], max_length=128, truncation=True, padding="max_length")
    labels = tokenizer(text_target=examples["target_text"], max_length=128, truncation=True, padding="max_length")
    pad_id = tokenizer.pad_token_id
    # With padding="max_length" every label row is already 128 long, so
    # DataCollatorForSeq2Seq never adds its own -100 label padding. Mask the
    # pad ids here; otherwise the model is trained and evaluated on predicting
    # padding tokens.
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs

# Tokenize the train and validation splits with the same batched preprocessing.
# load_from_cache_file=False makes `map` recompute instead of reusing a cached
# result. Note: Dataset.map takes no `disable_tqdm` option (that flag belongs
# to the training arguments), so map's own progress bars may still appear.
tokenized_train, tokenized_val = (
    split.map(preprocess_function, batched=True, load_from_cache_file=False)
    for split in (train_dataset, val_dataset)
)

In [None]:
import panphon
import numpy as np
import Levenshtein
import jiwer
import torch
import time
import os
import shutil
from transformers import (
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    TrainerCallback
)

# Checkpoint directory for this run. Any existing directory with this name is
# deleted first so the run starts clean (irreversible).
output_dir = "./byt5-hinglish-ipa-v4"
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)

# Untie embeddings for specialization
# Intent: give the output projection its own weights instead of sharing the
# input embedding matrix.
# NOTE(review): setting `tie_word_embeddings` on an already-loaded model only
# changes the config; it does not by itself duplicate weights tied at load
# time. Whether the resize below actually separates them depends on
# transformers internals (and on whether the size changes) — verify, e.g. by
# comparing model.get_input_embeddings().weight.data_ptr() with
# model.get_output_embeddings().weight.data_ptr().
model.config.tie_word_embeddings = False
# Resize the embedding matrices to len(tokenizer); presumably a no-op when the
# sizes already match — TODO confirm.
model.resize_token_embeddings(len(tokenizer))

# Panphon feature table used by get_phoneme_distance.
# NOTE(review): get_phoneme_distance is not called by compute_metrics.
ft = panphon.FeatureTable()
def get_phoneme_distance(p_char, r_char):
    """Return the normalized articulatory-feature distance of two IPA strings.

    Only the first segment panphon extracts from each string is compared. The
    result is the fraction of feature values that differ, in [0, 1]. 1.0 is
    returned when either string yields no segment.
    """
    p_vecs = ft.word_to_vector_list(p_char)
    r_vecs = ft.word_to_vector_list(r_char)
    if not (p_vecs and r_vecs):
        return 1.0
    p_feats, r_feats = p_vecs[0], r_vecs[0]
    # Normalize by the actual vector length instead of a hard-coded 22:
    # panphon's feature count is version-dependent, and a larger table made
    # the old /22 exceed 1.0.
    n_feats = len(p_feats)
    mismatches = sum(1 for pf, rf in zip(p_feats, r_feats) if pf != rf)
    return mismatches / n_feats

def compute_metrics(eval_preds):
    """Compute word and character error rates for generated predictions.

    Args:
        eval_preds: ``(predictions, label_ids)`` from the trainer with
            ``predict_with_generate=True``. The predictions may be wrapped in
            a tuple.

    Returns:
        dict with ``"wer"`` (jiwer word error rate) and ``"per"`` (total
        Levenshtein edits over the decoded strings divided by total
        reference characters, or 0 when there are no reference characters).
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    pad_id = tokenizer.pad_token_id
    # Generated sequences of different lengths can be concatenated across eval
    # batches with -100 as filler, and ignored label positions are -100 too.
    # Neither is a real token id, so map both to the pad token (removed by
    # skip_special_tokens) before decoding.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    wer = jiwer.wer(decoded_labels, decoded_preds)
    # "PER" is a character-level edit rate on the decoded IPA strings: each
    # code point, including diacritics, counts as one unit.
    total_edits = sum(Levenshtein.distance(p, r) for p, r in zip(decoded_preds, decoded_labels))
    total_chars = sum(len(r) for r in decoded_labels)
    per = total_edits / total_chars if total_chars > 0 else 0
    return {"wer": wer, "per": per}

# --- V100 COMPATIBLE ARGUMENTS ---
training_args = Seq2SeqTrainingArguments(
    label_smoothing_factor=0.1,
    output_dir=output_dir,
    overwrite_output_dir=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # V100 Hardware Specifics
    # NOTE(review): T5-family checkpoints are widely reported to overflow to
    # inf/NaN losses under fp16; watch the logged loss.
    fp16=True,                        # Standard for V100 (Replaces bf16)
    tf32=False,                       # V100 does not support TF32
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    gradient_accumulation_steps=2,    # Effective Batch Size = 256 per GPU
    gradient_checkpointing=True,      # Keep enabled for ByT5 memory efficiency

    # Learning Schedule
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=200,
    num_train_epochs=30,

    # Generation Efficiency
    predict_with_generate=True,       # Required for compute_metrics (WER, PER, etc.)
    # Must be at least the 128-token label length used in preprocessing;
    # a shorter cap truncates generated outputs and inflates WER/PER.
    generation_max_length=128,
    generation_num_beams=1,           # Fast evaluation

    # Logging & Performance
    logging_steps=50,
    save_total_limit=2,
    dataloader_num_workers=4,
    disable_tqdm=True,                # Disables all progress bars
    report_to="none"
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

trainer.train()

In [None]:
# Inference Function
def phonemize_hinglish_precise(text, model, tokenizer, device="cuda"):
    """Transcribe one Hinglish sentence to IPA with beam search.

    The sentence is lower-cased, stripped and prefixed with the ``<hin>:``
    tag, then decoded with 5 beams and at most 256 new tokens. Repetition
    penalties are explicitly disabled.

    Args:
        text: raw Hinglish sentence (must be a ``str``).
        model: fine-tuned seq2seq model. Only the inputs are moved to
            ``device``, so the model is expected to be there already.
        tokenizer: tokenizer matching ``model``.
        device: device the encoded inputs are moved to.

    Returns:
        The decoded string for the top beam, with special tokens removed.
    """
    prompt = f"<hin>: {text.lower().strip()}"
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    decode_cfg = dict(
        max_new_tokens=256,
        num_beams=5,
        repetition_penalty=1.0,
        no_repeat_ngram_size=0,
        early_stopping=True,
    )
    with torch.no_grad():
        generated = model.generate(**encoded, **decode_cfg)
    best_sequence = generated[0]
    return tokenizer.decode(best_sequence, skip_special_tokens=True)

# Example Usage: switch the model to eval mode, then print the IPA
# transcription for each sample sentence as "<sentence> -> <ipa>".
model.eval()
test_sentences = ["Meeting 09:03 AM pe hai", "Total amount Rs.500 hai"]
for sent in test_sentences:
    print(sent, "->", phonemize_hinglish_precise(sent, model, tokenizer))



In [None]:
import matplotlib.pyplot as plt

# Extract the loss curves from the trainer's log history. Train loss is logged
# every `logging_steps` steps and eval loss once per epoch, so each curve is
# plotted against its own global step. Plotting both against list position
# would misalign them on a shared x-axis.
history = trainer.state.log_history
train_steps = [x['step'] for x in history if 'loss' in x]
train_loss = [x['loss'] for x in history if 'loss' in x]
eval_steps = [x['step'] for x in history if 'eval_loss' in x]
eval_loss = [x['eval_loss'] for x in history if 'eval_loss' in x]

plt.plot(train_steps, train_loss, label='Train')
plt.plot(eval_steps, eval_loss, label='Eval')
plt.xlabel('Global step')
plt.ylabel('Loss')
plt.title('V100 Training Progress')
plt.legend()
plt.show()