In [1]:
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [2]:
# Enable hot-reloading so if you edit src/train.py, it updates here immediately
%load_ext autoreload
%autoreload 2

import torch
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from src import MBartNeutralizer, WNCDataset, WeightedSeq2SeqTrainer

print(f"GPU Available: {torch.cuda.is_available()}")

GPU Available: True


In [3]:
neutralizer = MBartNeutralizer(model_name="facebook/mbart-large-50")
model = neutralizer.get_model()
tokenizer = neutralizer.get_tokenizer()

Initializing mBART on cuda...


In [4]:
# Load the filtered "Complex" dataset created by preprocess.py
train_set = WNCDataset("data/processed/train_complex.csv", tokenizer)
val_set = WNCDataset("data/processed/val_complex.csv", tokenizer)

# Sanity Check: Print one example
sample = train_set[0]
print("Input Shape:", sample["input_ids"].shape)
print("Labels Shape:", sample["labels"].shape)


Input Shape: torch.Size([128])
Labels Shape: torch.Size([128])




In [5]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,  # Adjust based on your GPU VRAM
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-5,  # Lower LR for fine-tuning
    logging_steps=100,
    save_steps=1500,
    eval_strategy="steps",
    eval_steps=1500,
    fp16=True,  # Essential for mBART memory efficiency
    remove_unused_columns=False,  # IMPORTANT: Keep 'loss_weights' in the batch
)

trainer = WeightedSeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=val_set,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)

  trainer = WeightedSeq2SeqTrainer(


In [None]:
# This will output the live loss curve
# train_result = trainer.train()
train_result = trainer.train()

  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


Step,Training Loss,Validation Loss
1500,0.3707,0.353512
3000,0.3495,0.33919
4500,0.3403,0.328251
6000,0.2576,0.328729
7500,0.2624,0.32725
9000,0.269,0.322399
10500,0.2602,0.322384
12000,0.2154,0.331449




In [None]:
# Save the fine-tuned weights
neutralizer.save_model("models/mbart_neutralizer_en_v1")

# Quick Inference Test
input_text = "The radical regime failed to act."
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs)
print("Output:", tokenizer.decode(outputs[0], skip_special_tokens=True))

Model saved to models/mbart_neutralizer_en_v1


`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`...
Caching is incompatible with gradient checkpointing in MBartDecoderLayer. Setting `past_key_values=None`.


Output: it of ( ) 1. # the - course failed failed failed failed attempt attempt attempt attempt attempt at at the the


In [None]:
print("Loading fine-tuned model for final verification...")
# Make sure this matches the path you saved to
saved_path = "models/mbart_neutralizer_en_v1" 
neutralizer = MBartNeutralizer(model_name=saved_path)
model = neutralizer.get_model()
tokenizer = neutralizer.get_tokenizer()

# 2. Define "The Gauntlet" (Test Cases)
test_cases = [
    # Case 1: Subjective Intensifier (Easy)
    "The radical regime failed to act on the crisis.",
    
    # Case 2: Framing Bias (Harder - subtle verb change)
    "The controversial politician foolishly denied the allegations.",
    
    # Case 3: Presupposition (Hardest - implies guilt)
    "He exposed the senator's corruption." 
]

# 3. Run Robust Inference
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

print(f"\n{'='*20} PHASE 1 COMPLETE: ENGLISH BASELINE {'='*20}\n")

for text in test_cases:
    # A. Tokenize (Force English Source)
    tokenizer.src_lang = "en_XX"
    encoded = tokenizer(text, return_tensors="pt").to(device)
    
    # B. Generate (Prevent Repetition & Force English Output)
    generated_ids = model.generate(
        **encoded,
        forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"], 
        max_length=64,
        num_beams=5,             # Smarter search
        no_repeat_ngram_size=2,  # Prevents "same same" loops
        repetition_penalty=1.2,  # Soft penalty to encourage natural phrasing
        early_stopping=True
    )
    
    # C. Decode
    output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    # D. Display
    print(f"Original: {text}")
    print(f"Neutral:  {output}")
    print("-" * 50)

print("\nIf the 'Neutral' outputs removed the biased words (radical, foolishly, exposed)")
print("while keeping the facts, Phase 1 is SUCCESSFUL.")

Loading checkpoint from ./results/checkpoint-500...
Initializing mBART on cuda...


The tokenizer you are loading from './results/checkpoint-500' with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue.



--- STEP 500 CHECKPOINT RESULTS ---
Input:  The controversial leader foolishly denied the request.
Output: the controversial leader denied the request.
------------------------------
Input:  The radical regime failed to provide for its citizens.
Output: the regime failed to provide for its citizens.
------------------------------
