# Load Dataset

In [15]:
from datasets import load_dataset

# Load the "ml" configuration of the IndicHeadlineGeneration dataset.
# NOTE: despite its name, `df` is the object returned by `load_dataset` and is
# indexed by split name ("train" / "test") further down, not a pandas DataFrame.
df = load_dataset("ai4bharat/IndicHeadlineGeneration", "ml")

In [16]:
# Shuffle with a fixed seed and keep 1000 samples from the training set
train_dataset = df["train"].shuffle(seed=42).select(range(1000))

# Shuffle with the same fixed seed and keep 300 samples from the test set, so the
# evaluation subset is a reproducible random sample rather than the first 300 rows
test_dataset = df["test"].shuffle(seed=42).select(range(300))

In [17]:
def custom_prompt(example):
    """Convert one raw dataset row into a prompt/completion pair.

    Args:
        example: Mapping with an "input" entry (the article text) and a
            "target" entry (the reference headline).

    Returns:
        A dict with two keys: "prompt", the instruction text with the article
        embedded and ending in "Headline:", and "completion", the reference
        headline prefixed with a single space.
    """
    article = example["input"]
    instruction = f"""
Generate a **clear and concise news headline in Malayalam only** based on the following text.
Text (Malayalam): {article}
Important:
- The output must be **only a headline in Malayalam**.
- Do **not** use any other language or script.
- Do **not** include any extra commentary or formatting.
- Do **not** copy the text word-for-word.
- Start your output with: Headline:"""

    formatted = {}
    formatted["prompt"] = instruction
    # The leading space separates the headline from the trailing "Headline:" cue.
    formatted["completion"] = " " + example["target"]
    return formatted

# Rewrite every row as a prompt/completion pair. The raw columns are dropped
# because their content now lives inside the prompt and completion fields.
raw_columns = ["input", "target"]
train_dataset = train_dataset.map(custom_prompt, remove_columns=raw_columns)
test_dataset = test_dataset.map(custom_prompt, remove_columns=raw_columns)


# Load Model

In [18]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "google/gemma-3-4b-it"

# 4-bit NF4 quantization with double quantization; bfloat16 is used both as
# the compute dtype and as the quantized-weight storage dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)

# device_map="auto" lets the loader decide placement; to pin everything to one
# specific GPU instead, pass e.g. device_map={"": "cuda:1"}.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Fall back to EOS for padding only when the tokenizer has no pad token of its
# own. Overwriting an existing pad token with EOS makes padding and
# end-of-sequence indistinguishable, so anything that masks padding also masks
# the real EOS and the model is never trained to stop.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [19]:
def tokenize_function(examples, max_length=512):
    """Tokenize prompt and completion together as one training sequence.

    Tokenizing only the prompt would leave the reference headline out of
    `input_ids`, so the model would never be trained on the text it is
    supposed to generate. The completion and an EOS marker are therefore
    appended to the prompt before tokenization.

    Args:
        examples: A single example or a batch (as passed by
            `Dataset.map(..., batched=True)`) with "prompt" and "completion"
            entries; each is a string or a list of strings.
        max_length: Maximum number of tokens kept per sequence. Longer
            sequences are truncated from the right, which can cut off the
            completion for very long articles. Defaults to 512.

    Returns:
        The tokenizer output for the concatenated texts. No static padding is
        applied here; padding or packing is left to the trainer.
    """
    prompts = examples["prompt"]
    completions = examples["completion"]
    eos = tokenizer.eos_token or ""

    if isinstance(prompts, str):
        # Unbatched call: keep the flat (non-nested) output shape.
        texts = prompts + completions + eos
    else:
        texts = [p + c + eos for p, c in zip(prompts, completions)]

    # An explicit max_length avoids depending on tokenizer.model_max_length,
    # which is what padding/truncation fall back to when none is given.
    return tokenizer(texts, truncation=True, max_length=max_length)

train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)


In [20]:
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

# LoRA adapters on every linear layer; the embedding and output head are
# listed in modules_to_save alongside the adapters.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["embed_tokens","lm_head"],
)

# NOTE(review): newer TRL releases rename `max_seq_length` to `max_length`;
# confirm the keyword against the installed TRL version.
sft_config = SFTConfig(
    output_dir="gemma3_malayalam_headlines",
    max_seq_length=512,
    packing=True,
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch of 4 sequences per device
    gradient_checkpointing=True,
    learning_rate=2e-4,
    fp16=False,
    bf16=True,  # since we're using bfloat16
    logging_steps=10,
    save_strategy="epoch",
    optim="paged_adamw_8bit",
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    # The plain "constant" schedule has no warmup phase, so warmup_ratio would
    # be ignored; "constant_with_warmup" makes the configured warmup take effect.
    lr_scheduler_type="constant_with_warmup",
    push_to_hub=False,
)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


# Trainer Setup

In [13]:
!pip install -U trl

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [1]:
# This cell relies on names defined by earlier cells (SFTTrainer, model,
# sft_config, peft_config, tokenizer and both datasets); run those first.
trainer_kwargs = {
    "model": model,
    "args": sft_config,
    "train_dataset": train_dataset,
    "eval_dataset": test_dataset,
    "peft_config": peft_config,
    "processing_class": tokenizer,  # Use processing_class instead of tokenizer
}
trainer = SFTTrainer(**trainer_kwargs)

# Fine-tune, then persist the trained model.
trainer.train()
trainer.save_model()


NameError: name 'SFTTrainer' is not defined

In [None]:
# Evaluate the trainer on the eval_dataset it was constructed with and show
# the returned metrics.
eval_results = trainer.evaluate()
print(eval_results)