In [None]:
import os

# Restrict this process to GPU 0; set here, ahead of the cells that import torch.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [None]:
import gc

import torch

# Collect unreachable Python objects first, then ask PyTorch to release its
# cached CUDA memory (useful when re-running the notebook in a live kernel).
gc.collect()
torch.cuda.empty_cache()

In [None]:
from datasets import load_dataset

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import DataCollatorForSeq2Seq
from transformers import GenerationConfig, Seq2SeqTrainingArguments, Seq2SeqTrainer

from peft import LoraConfig, get_peft_model, TaskType
import numpy as np
import evaluate


In [None]:
# "kn" configuration of the Samanantar corpus; the bare name below displays
# the available splits and columns.
ds = load_dataset(
    "ai4bharat/samanantar",
    "kn",
)
ds

In [None]:
# Hold out 5,000 shuffled examples as the "test" split; the remainder stays in
# "train". The fixed seed keeps the split reproducible.
ds = ds["train"].train_test_split(test_size=5000, shuffle=True, seed=42)
ds

In [None]:
# Train on a fixed 25,000-example random subset; evaluate on the held-out split.
train_ds = (
    ds["train"]
    .shuffle(seed=42)
    .select(range(25_000))
)
test_ds = ds["test"]

print(f"Test Dataset: {test_ds}")
print(f"Train Dataset: {train_ds}")

In [None]:
# Inspect the first held-out record to check the raw column layout.
test_ds[0] # display one sample

In [None]:
model_id = "meta-llama/Llama-3.2-1B"

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    model_max_length=512,
)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
def format_text(src, trgt=None):
    """Build the translation prompt, optionally followed by the target text.

    Args:
        src: English source sentence.
        trgt: Kannada target. When ``None`` only the prompt is returned;
            otherwise the target is appended directly after ``Kannada:``.

    Returns:
        The formatted string (it starts with a single leading space).
    """
    prompt = f""" Translate English to Kannada: English: {src}, Kannada:"""
    return prompt if trgt is None else f"""{prompt}{trgt}"""

In [None]:
def tokenize_text(example, max_length=256):
    """Tokenize one translation pair into causal-LM training features.

    The full text (prompt + target) becomes ``input_ids``. ``labels`` is a copy
    of ``input_ids`` in which the prompt positions are replaced by -100, so only
    the target tokens contribute to the loss.

    Args:
        example: Mapping with the source sentence under ``"src"`` and the
            target sentence under ``"tgt"``.
        max_length: Upper bound on the number of tokens in the returned
            sequence (defaults to the previously hard-coded 256).

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``; all three
        are unpadded lists of the same length.
    """
    src = example["src"]
    trgt = example["tgt"]

    # Same tokenizer settings for both calls so the prompt length is comparable.
    encode_kwargs = {"truncation": True, "max_length": max_length, "padding": False}
    encoded_full = tokenizer(format_text(src, trgt), **encode_kwargs)
    encoded_prompt = tokenizer(format_text(src), **encode_kwargs)

    input_ids = list(encoded_full["input_ids"])
    attention_mask = list(encoded_full["attention_mask"])

    # Close the target with EOS so the model is trained to stop after the
    # translation. Skipped when the tokenizer already appended one, and when the
    # sequence already fills `max_length` (it may have been truncated, in which
    # case an EOS would mark a cut-off sentence as complete).
    eos_id = tokenizer.eos_token_id
    already_closed = bool(input_ids) and input_ids[-1] == eos_id
    if eos_id is not None and not already_closed and len(input_ids) < max_length:
        input_ids.append(eos_id)
        attention_mask.append(1)

    # Mask instruction + English tokens. The clamp keeps `labels` the same
    # length as `input_ids`: an unclamped slice assignment would grow the list
    # if the prompt ever tokenized to more tokens than the full text.
    labels = list(input_ids)
    prompt_len = min(len(encoded_prompt["input_ids"]), len(labels))
    labels[:prompt_len] = [-100] * prompt_len

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

In [None]:
# Tokenize every training pair and drop the raw columns so that only the
# features returned by `tokenize_text` remain.
train_tokenized_ds = train_ds.map(
    tokenize_text,
    remove_columns=train_ds.column_names,
)
train_tokenized_ds

In [None]:
# Same preprocessing for the held-out split, dropping that split's raw columns.
test_tokenized_ds = test_ds.map(
    tokenize_text,
    remove_columns=test_ds.column_names,
)
test_tokenized_ds

In [None]:
# Load the base causal LM; the bare name below displays its module structure.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
)
model

In [None]:
# The model config field for padding is `pad_token_id`; assigning to
# `pad_token` only created an unused attribute. Use the tokenizer's padding id
# (the tokenizer cell sets its pad token to EOS).
model.config.pad_token_id = tokenizer.pad_token_id

In [None]:
# LoRA adapters on the attention query/key/value projections only.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)


In [None]:
# Wrap the base model with the LoRA adapters, report how many parameters are
# trainable, then move the wrapped model to the GPU.
lora_model = get_peft_model(model=model, peft_config=lora_config)
lora_model.print_trainable_parameters()
lora_model.to("cuda")

In [None]:
# The tokenized examples are unpadded, so the collator pads each batch:
# inputs via the tokenizer, labels with -100.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=lora_model,
    label_pad_token_id=-100,
    padding=True,
)


In [None]:
# Smoke test: push one tokenized example through the model and list the fields
# of the returned output object.
batch = train_tokenized_ds[0]

# Only the output structure is inspected, so run without autograd; otherwise the
# graph for this forward pass stays alive on the GPU for as long as `outputs`
# is referenced.
with torch.no_grad():
    outputs = lora_model(
        input_ids=torch.tensor([batch["input_ids"]]).cuda(),
        labels=torch.tensor([batch["labels"]]).cuda(),
    )

print(outputs.keys())


In [None]:
metric = evaluate.load("sacrebleu")

# Tail of the prompt produced by `format_text`; the translation follows it.
_PROMPT_END = ", Kannada:"


def _extract_translation(text):
    """Return the part of a decoded prediction that follows the prompt."""
    _, marker, tail = text.partition(_PROMPT_END)
    if not marker:
        # Prompt tail not found verbatim: fall back to the last bare marker
        # (a string without any marker is returned whole).
        tail = text.split("Kannada:")[-1]
    return tail.strip()


def compute_metrics(eval_preds):
    """Compute corpus BLEU (sacreBLEU) for a batch of evaluation predictions.

    Args:
        eval_preds: ``(predictions, labels)`` pair of token-id arrays.
            ``predictions`` may also be a tuple whose first element holds the
            ids. Negative prediction ids and -100 label ids are treated as
            padding.

    Returns:
        dict with a single key ``"bleu"`` holding the sacreBLEU score.
    """
    preds, labels = eval_preds

    # In case the model returns more than just the token ids
    if isinstance(preds, tuple):
        preds = preds[0]

    # Negative ids (-100) cannot be decoded; map them to the pad token, which
    # `skip_special_tokens=True` then removes.
    preds = np.where(preds < 0, tokenizer.pad_token_id, preds)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Keep only the text generated after the prompt. Splitting on the *last*
    # "Kannada:" would discard the actual translation whenever the generated
    # text contains that marker again, so the first prompt tail is used.
    # Labels need no extraction: their prompt positions were masked with -100.
    decoded_preds = [_extract_translation(pred) for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)

    return {"bleu": result["score"]}

In [None]:
# A freshly constructed GenerationConfig does not inherit the special-token ids
# of the config it replaces, so pass them explicitly; without `eos_token_id`
# generation has no end-of-sequence stop condition and always runs for
# `max_new_tokens`.
lora_model.generation_config = GenerationConfig(
    max_new_tokens=128,
    do_sample=False,   # greedy decoding keeps evaluation deterministic
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)


In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./llama-kannada-lora",
    report_to="none",

    # Optimisation
    num_train_epochs=1,
    learning_rate=2e-4,
    per_device_train_batch_size=8,
    fp16=True,                   # mixed-precision training

    # Logging and checkpointing cadence
    logging_strategy="steps",
    logging_steps=100,
    save_steps=500,

    # Evaluation: every 250 steps, scoring generated text rather than logits
    eval_strategy="steps",
    eval_steps=250,
    per_device_eval_batch_size=8,
    predict_with_generate=True,  # predictions passed to compute_metrics come from generate()
)

# NOTE(review): the evaluation features come from `tokenize_text`, whose
# `input_ids` contain the prompt *and* the target text. With
# predict_with_generate=True generation is presumably conditioned on those
# ids, i.e. on the reference itself -- confirm the reported BLEU is not inflated.
trainer = Seq2SeqTrainer(
    model=lora_model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=train_tokenized_ds,
    eval_dataset=test_tokenized_ds.select(range(100)),  # first 100 held-out examples keep eval fast
)


In [None]:
# Launch fine-tuning; `result` keeps whatever `trainer.train()` returns.
result = trainer.train()