In [1]:
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [2]:
# Enable hot-reloading so if you edit src/train.py, it updates here immediately
%load_ext autoreload
%autoreload 2

import torch
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from src import MBartNeutralizer, WNCDataset, WeightedSeq2SeqTrainer

print(f"GPU Available: {torch.cuda.is_available()}")

GPU Available: True


In [3]:
neutralizer = MBartNeutralizer(model_name="facebook/mbart-large-50")
model = neutralizer.get_model()
tokenizer = neutralizer.get_tokenizer()

Initializing mBART on cuda...


In [4]:
# Load the filtered "Complex" dataset created by preprocess.py
train_set = WNCDataset("data/processed/train_complex.csv", tokenizer)
val_set = WNCDataset("data/processed/val_complex.csv", tokenizer)

# Sanity Check: Print one example
sample = train_set[0]
print("Input Shape:", sample["input_ids"].shape)
print("Labels Shape:", sample["labels"].shape)


Input Shape: torch.Size([128])
Labels Shape: torch.Size([128])




In [5]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,  # Adjust based on your GPU VRAM
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-5,  # Lower LR for fine-tuning
    logging_steps=100,
    save_steps=1500,
    eval_strategy="steps",
    eval_steps=1500,
    fp16=True,  # Essential for mBART memory efficiency
    remove_unused_columns=False,  # IMPORTANT: Keep 'loss_weights' in the batch
)

trainer = WeightedSeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=val_set,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)

  trainer = WeightedSeq2SeqTrainer(


In [6]:
# This will output the live loss curve
# train_result = trainer.train()
train_result = trainer.train()

  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


Step,Training Loss,Validation Loss
1500,0.3707,0.353512
3000,0.3495,0.33919
4500,0.3403,0.328251
6000,0.2576,0.328729
7500,0.2624,0.32725
9000,0.269,0.322399
10500,0.2602,0.322384
12000,0.2154,0.331449
13500,0.2125,0.330756
15000,0.2062,0.330291




In [7]:
# Save the fine-tuned weights
neutralizer.save_model("models/mbart_neutralizer_en_v1")

# Quick Inference Test
input_text = "The radical regime failed to act."
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs)
print("Output:", tokenizer.decode(outputs[0], skip_special_tokens=True))

Model saved to models/mbart_neutralizer_en_v1
Output: the regime failed to act.


In [8]:
print("Loading fine-tuned model for final verification...")
# Make sure this matches the path you saved to
saved_path = "models/mbart_neutralizer_en_v1" 
neutralizer = MBartNeutralizer(model_name=saved_path)
model = neutralizer.get_model()
tokenizer = neutralizer.get_tokenizer()

# 2. Define "The Gauntlet" (Test Cases)
test_cases = [
    # Case 1: Subjective Intensifier (Easy)
    "The radical regime failed to act on the crisis.",
    
    # Case 2: Framing Bias (Harder - subtle verb change)
    "The controversial politician foolishly denied the allegations.",
    
    # Case 3: Presupposition (Hardest - implies guilt)
    "He exposed the senator's corruption." 
]

# 3. Run Robust Inference
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

print(f"\n{'='*20} PHASE 1 COMPLETE: ENGLISH BASELINE {'='*20}\n")

for text in test_cases:
    # A. Tokenize (Force English Source)
    tokenizer.src_lang = "en_XX"
    encoded = tokenizer(text, return_tensors="pt").to(device)
    
    # B. Generate (Prevent Repetition & Force English Output)
    generated_ids = model.generate(
        **encoded,
        forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"], 
        max_length=64,
        num_beams=5,             # Smarter search
        no_repeat_ngram_size=2,  # Prevents "same same" loops
        repetition_penalty=1.2,  # Soft penalty to encourage natural phrasing
        early_stopping=True
    )
    
    # C. Decode
    output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    # D. Display
    print(f"Original: {text}")
    print(f"Neutral:  {output}")
    print("-" * 50)

print("\nIf the 'Neutral' outputs removed the biased words (radical, foolishly, exposed)")
print("while keeping the facts, Phase 1 is SUCCESSFUL.")

Loading fine-tuned model for final verification...
Initializing mBART on cuda...


The tokenizer you are loading from 'models/mbart_neutralizer_en_v1' with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue.




Original: The radical regime failed to act on the crisis.
Neutral:  the iranian government failed to act on the crisis.
--------------------------------------------------
Original: The controversial politician foolishly denied the allegations.
Neutral:  the controversial politician denied the allegations.
--------------------------------------------------
Original: He exposed the senator's corruption.
Neutral:  he accused the senator's corruption.
--------------------------------------------------

If the 'Neutral' outputs removed the biased words (radical, foolishly, exposed)
while keeping the facts, Phase 1 is SUCCESSFUL.


In [None]:
print("--- INITIALIZING PHASE 3: SYNTHETIC CHINESE TRAINING ---")

# 1. Initialize mBART for Chinese-to-Chinese (zh_CN)
# We use a fresh model instance (not the English-finetuned one) for Model 3.
# If you wanted to do "Transfer Learning" (Model 2 -> 3), you would load "models/mbart_neutralizer_en_v1" instead.
# For now, let's train from scratch (Base mBART) on the synthetic data to compare fairly.
neutralizer = MBartNeutralizer(
    model_name="facebook/mbart-large-50", 
    src_lang="zh_CN", 
    tgt_lang="zh_CN"
)
model = neutralizer.get_model()
tokenizer = neutralizer.get_tokenizer()

# 2. Load the Synthetic Chinese Data
# Ensure your translation script has finished and these files exist!
train_path = "data/processed/train_chinese_synthetic.csv"
val_path = "data/processed/val_chinese_synthetic.csv"

print(f"Loading datasets from {train_path}...")
train_set = WNCDataset(train_path, tokenizer)
val_set = WNCDataset(val_path, tokenizer)

# Sanity Check: Verify the first sample looks like Chinese
print(f"Sample Input ID 0: {train_set[0]['input_ids'][0]}") 
print("Language Code ID for zh_CN:", tokenizer.lang_code_to_id["zh_CN"]) 
# The first token of input_ids SHOULD match the zh_CN ID.

# 3. Configure Training (Safe Settings for RTX 4090)
training_args = Seq2SeqTrainingArguments(
    output_dir="./results_zh",          # <--- NEW OUTPUT DIR
    per_device_train_batch_size=4,      # 4090 can handle 4 easily with fp16
    gradient_accumulation_steps=8,      # Effective Batch Size = 32
    gradient_checkpointing=True,        # Save VRAM
    num_train_epochs=3,
    learning_rate=2e-5,                 # Standard fine-tuning rate
    logging_steps=100,
    save_steps=1500,                     # Save checkpoint every 500 steps
    save_total_limit=2,                 # Keep disk clean
    evaluation_strategy="steps",
    eval_steps=1500,                    # Evaluate less often to save time
    fp16=True,
    remove_unused_columns=False
)

# 4. Initialize Trainer
trainer = WeightedSeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=val_set,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model)
)


In [None]:
# 5. Train
trainer.train()

# 6. Save Final Model
output_path = "models/mbart_neutralizer_zh_synthetic"
neutralizer.save_model(output_path)
print(f"Phase 3 Complete. Model saved to {output_path}")