In [None]:
# Enable hot-reloading: edits to imported modules (e.g. src/train.py, src/model.py) are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

import torch
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from src.model import MBartNeutralizer
from src.train import WNCDataset, WeightedSeq2SeqTrainer

# Report whether PyTorch can see a CUDA device.
has_gpu = torch.cuda.is_available()
print(f"GPU Available: {has_gpu}")

In [None]:
# Build the project wrapper around the pretrained checkpoint, then pull out
# the model and tokenizer objects that the remaining cells use.
BASE_CHECKPOINT = "facebook/mbart-large-50"

neutralizer = MBartNeutralizer(model_name=BASE_CHECKPOINT)
model, tokenizer = neutralizer.get_model(), neutralizer.get_tokenizer()

In [None]:
# Train/validation splits of the filtered "Complex" dataset written by preprocess.py
train_csv = "data/processed/train_complex.csv"
val_csv = "data/processed/val_complex.csv"

train_set = WNCDataset(train_csv, tokenizer)
val_set = WNCDataset(val_csv, tokenizer)

# Sanity check on the first training example: confirm the token ids and the
# per-token loss weights are both present.
sample = train_set[0]
input_shape = sample["input_ids"].shape
print("Input Shape:", input_shape)
weight_total = sample["loss_weights"].sum()
print("Weight Mask Sum:", weight_total)  # Should be < 128

In [None]:
# Mixed precision needs an accelerator: with fp16=True hard-coded, argument
# validation fails on a CPU-only machine. Enable it only when CUDA is present
# so the notebook still runs (slowly) for smoke tests without a GPU.
use_fp16 = torch.cuda.is_available()

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,  # Adjust based on your GPU VRAM
    gradient_accumulation_steps=4,  # 4 x 4 = 16 samples per optimizer step, per device
    num_train_epochs=3,
    learning_rate=3e-5,  # Lower LR for fine-tuning
    logging_steps=50,
    save_steps=500,
    # NOTE(review): recent transformers releases rename this keyword to
    # `eval_strategy`; switch if the installed version warns about or rejects it.
    evaluation_strategy="steps",
    eval_steps=500,
    fp16=use_fp16,  # Mixed precision for mBART memory efficiency (GPU only)
    remove_unused_columns=False,  # IMPORTANT: Keep 'loss_weights' in the batch
)

# Project trainer subclass; the collator pads each batch and receives the model
# so it can prepare decoder inputs from the labels.
trainer = WeightedSeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=val_set,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)

In [None]:
# Run fine-tuning. Progress and loss values are reported by the trainer while
# this cell executes; the returned result object is kept in `train_result`
# so later cells can inspect it.
train_result = trainer.train()

In [None]:
# Save the fine-tuned weights via the project wrapper
neutralizer.save_model("models/mbart_neutralizer_en_v1")

# Quick Inference Test
# The model can still be in training mode after trainer.train(), and generate()
# does not change that. Switch to eval mode so dropout is disabled and the
# sanity-check output is deterministic.
model.eval()

# NOTE(review): mBART-50 checkpoints normally need the target language forced as
# the first generated token (forced_bos_token_id). Confirm that MBartNeutralizer
# configures this; otherwise pass it to generate() explicitly.
input_text = "The radical regime failed to act."
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs)
print("Output:", tokenizer.decode(outputs[0], skip_special_tokens=True))