In [4]:
# Install required libraries
!pip install -q transformers datasets accelerate scikit-learn pandas install hf_transfer
from transformers.trainer_utils import get_last_checkpoint
import os
import json
import pickle
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import (
    AutoTokenizer, 
    BartForSequenceClassification, 
    TrainingArguments, 
    Trainer
)
from datasets import Dataset
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score, accuracy_score
from torch import nn

# Hardware acceleration: allow TF32 tensor-core math for fp32 matmuls and
# cuDNN convolutions. TF32 requires an Ampere-or-newer GPU, such as the A40
# this notebook ran on.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

# Directory setup. BASE_DIR can be overridden with the WORKSPACE_DIR
# environment variable; it defaults to "/workspace".
BASE_DIR = os.environ.get("WORKSPACE_DIR", "/workspace")
DATA_DIR = os.path.join(BASE_DIR, "data")
OUTPUT_DIR = os.path.join(BASE_DIR, "experiments")
# Pickled retrieval results (loaded with pickle in the data-loading cell).
RETRIEVAL_CACHE = os.path.join(DATA_DIR, "retrieval_results.pkl")
VERIFIER_MODEL = "facebook/bart-large"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Using device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

[31mERROR: Could not find a version that satisfies the requirement install (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for install[0m[31m
[0mUsing device: NVIDIA A40


In [2]:
print("Loading data...")

def load_json(path, base_dir=None):
    """Load and parse a JSON file from the data directory.

    Args:
        path: file name (or relative path) joined onto ``base_dir``.
        base_dir: directory to read from; defaults to the notebook's DATA_DIR.

    Returns:
        The parsed JSON value.
    """
    directory = DATA_DIR if base_dir is None else base_dir
    # Decode explicitly as UTF-8 (the JSON standard encoding) instead of the
    # platform's locale default, which can mis-decode non-ASCII claim text.
    with open(os.path.join(directory, path), encoding="utf-8") as f:
        return json.load(f)

# Load the three claim splits (train, val, test), in that order.
train_claims, val_claims, test_claims = (
    load_json(f"{split}_claims_quantemp.json") for split in ("train", "val", "test")
)

# SECURITY: pickle.load can execute arbitrary code from a crafted file; only
# load a retrieval cache this project produced itself.
# NOTE(review): downstream cells index it as
# retrieval_results[<method>][<split>][<claim index>] -> list of evidence strings
# — structure inferred from usage, confirm against the retrieval step.
with open(RETRIEVAL_CACHE, "rb") as cache_file:
    retrieval_results = pickle.load(cache_file)

print(f"Loaded {len(train_claims)} training claims.")
print(f"Retrieval results keys: {retrieval_results.keys()}")

Loading data...
Loaded 9935 training claims.
Retrieval results keys: dict_keys(['baseline', 'decomposed', 'repo'])


In [6]:
def normalize_label(label):
    """Map a raw verdict label to a class id: 0 = support, 1 = refute, 2 = NEI/other.

    Matching is case-insensitive substring matching on ``str(label)``.
    Negated forms of the support keywords ("incorrect", "untrue", "not true",
    "not correct") are checked first. Otherwise the plain substring test would
    find "correct" inside "incorrect" or "true" inside "untrue" and map those
    labels to support.

    NOTE(review): "unsupported" / "not supported" still contain "support" and
    map to 0. Decide whether they should be refute or NEI for your label set.

    Args:
        label: raw label of any type (converted with ``str``).

    Returns:
        int in {0, 1, 2}.
    """
    text = str(label).lower()
    if any(neg in text for neg in ("incorrect", "untrue", "not true", "not correct")):
        return 1
    if any(x in text for x in ("support", "true", "correct")):
        return 0
    if any(x in text for x in ("refute", "false", "pants")):
        return 1
    return 2

def create_examples(claims, evidence_dict, top_k=3):
    """Expand claims into one (claim, evidence, label) training example per snippet.

    Only the first ``top_k`` snippets for each claim are considered. Among
    those, snippets shorter than 20 characters after stripping are dropped.
    The rest are cut to their first 1024 characters. Claims with no entry in
    ``evidence_dict`` contribute no examples.

    Args:
        claims: list of dicts with ``"claim"`` and ``"label"`` keys.
        evidence_dict: mapping from a claim's position in ``claims`` to a list
            of evidence strings.
        top_k: how many leading snippets to consider per claim.

    Returns:
        list of ``{"claim", "evidence", "label"}`` dicts with integer labels
        from ``normalize_label``.
    """
    examples = []
    for position, claim_obj in enumerate(claims):
        label = normalize_label(claim_obj["label"])
        candidate_snippets = evidence_dict.get(position, [])[:top_k]
        examples.extend(
            {"claim": claim_obj["claim"], "evidence": snippet[:1024], "label": label}
            for snippet in candidate_snippets
            if len(snippet.strip()) >= 20
        )
    return examples

# Tokenization: the verifier's tokenizer, loaded once and shared by
# process_ds and the test-time evaluation.
tokenizer = AutoTokenizer.from_pretrained(VERIFIER_MODEL)

def process_ds(examples):
    """Build a tokenized ``datasets.Dataset`` from claim/evidence example dicts.

    Each row is encoded as the text pair (evidence, claim), with evidence
    first, and is padded/truncated to exactly 512 tokens. The original
    ``claim``/``evidence``/``label`` columns are kept alongside the
    tokenizer outputs.
    """
    def _encode(rows):
        return tokenizer(
            rows["evidence"],
            rows["claim"],
            truncation=True,
            padding="max_length",
            max_length=512,
        )

    return Dataset.from_list(examples).map(_encode, batched=True)

# Build the train/val example lists once. The class-weight labels below reuse
# the same training list instead of rebuilding it a second time.
train_examples = create_examples(train_claims, retrieval_results["decomposed"]["train"])
val_examples = create_examples(val_claims, retrieval_results["decomposed"]["val"])

train_ds = process_ds(train_examples)
val_ds = process_ds(val_examples)

# "balanced" class weights for the weighted cross-entropy loss:
# n_samples / (n_classes * count_per_class) over classes 0/1/2, counted per
# (claim, evidence) example rather than per claim.
y_labels = [ex["label"] for ex in train_examples]
weights = compute_class_weight("balanced", classes=np.array([0, 1, 2]), y=y_labels)
print(f"Class weights: {weights}")

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Map:   0%|          | 0/29178 [00:00<?, ? examples/s]

Map:   0%|          | 0/9027 [00:00<?, ? examples/s]

Class weights: [1.80814278 0.57492463 1.41325196]


In [7]:
from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support

def compute_metrics(eval_pred):
    """Turn a Trainer evaluation result into accuracy, macro-F1 and per-class scores.

    NOTE(review): a later cell redefines ``compute_metrics`` with fewer
    per-class keys, and that later definition is the one the Trainer uses.
    Keep only one.

    Args:
        eval_pred: object exposing ``predictions`` (logits, or a tuple whose
            first element is the logits) and ``label_ids`` (gold class ids).

    Returns:
        dict with ``accuracy`` and ``macro_f1``, followed by, for each class
        support (0) / refute (1) / nei (2): ``<name>_precision``,
        ``<name>_recall``, ``<name>_f1`` and ``<name>_support`` (gold count).
    """
    raw_predictions = eval_pred.predictions
    gold = eval_pred.label_ids
    # Some models return a tuple of outputs; the logits come first.
    logits = raw_predictions[0] if isinstance(raw_predictions, tuple) else raw_predictions
    predicted = np.argmax(logits, axis=-1)

    # Four parallel arrays over labels [0, 1, 2]: precision, recall, f1, support.
    per_class = precision_recall_fscore_support(
        gold, predicted, labels=[0, 1, 2], zero_division=0
    )

    results = {
        "accuracy": accuracy_score(gold, predicted),
        "macro_f1": f1_score(gold, predicted, average="macro"),
    }
    for class_idx, class_name in enumerate(("support", "refute", "nei")):
        class_precision, class_recall, class_f1, class_count = (
            metric_array[class_idx] for metric_array in per_class
        )
        results.update({
            f"{class_name}_precision": class_precision,
            f"{class_name}_recall": class_recall,
            f"{class_name}_f1": class_f1,
            f"{class_name}_support": int(class_count),
        })
    return results

In [11]:
import os
import torch
import numpy as np
from torch import nn
from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support
from transformers import (
    BartForSequenceClassification, 
    TrainingArguments, 
    Trainer, 
    utils
)

# --- 1. torch.load SAFETY-CHECK BYPASS ---
# Replaces transformers' `check_torch_load_is_safe` with a function that always
# returns True, so the check no longer refuses torch.load on PyTorch < 2.6.
# SECURITY: torch.load unpickles data and can run arbitrary code from a crafted
# file; this check is presumably the guard against that. Only keep the bypass
# for checkpoints this notebook wrote itself. Upgrading to torch >= 2.6 makes
# it unnecessary.
# NOTE(review): rebinding the attribute on `import_utils` does not affect
# modules that already bound the name via `from ... import`. Verify the patch
# takes effect where checkpoints are actually loaded.
from transformers.utils import import_utils

import_utils.check_torch_load_is_safe = lambda: True

# --- 2. A40 HARDWARE ACCELERATION ---
# Re-applies the TF32 settings from the setup cell so this cell is self-contained.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

# --- 3. UPDATED METRICS (Per-Class + Global) ---
def compute_metrics(eval_pred):
    """Summarize a Trainer evaluation pass: accuracy, macro-F1, per-class F1/support.

    NOTE(review): this redefines the ``compute_metrics`` from an earlier cell,
    which also reported per-class precision/recall. Keep only one definition.

    Args:
        eval_pred: object exposing ``predictions`` (logits, or a tuple whose
            first element is the logits) and ``label_ids`` (gold class ids,
            0 = support, 1 = refute, 2 = nei).

    Returns:
        dict with ``accuracy`` and ``macro_f1``, then ``<class>_f1`` and
        ``<class>_support`` (number of gold examples) for each class.
    """
    raw_predictions = eval_pred.predictions
    gold = eval_pred.label_ids
    logits = raw_predictions[0] if isinstance(raw_predictions, tuple) else raw_predictions
    predicted = np.argmax(logits, axis=-1)

    _, _, f1_per_class, support_per_class = precision_recall_fscore_support(
        gold, predicted, labels=[0, 1, 2], zero_division=0
    )

    results = {
        "accuracy": accuracy_score(gold, predicted),
        "macro_f1": f1_score(gold, predicted, average="macro"),
    }
    for class_name, class_f1, class_count in zip(
        ("support", "refute", "nei"), f1_per_class, support_per_class
    ):
        results[f"{class_name}_f1"] = class_f1
        results[f"{class_name}_support"] = int(class_count)
    return results

# --- 4. CUSTOM TRAINER WITH CLASS WEIGHTS ---
class BartA40Trainer(Trainer):
    """Trainer whose training/eval loss is class-weighted cross-entropy.

    The weights come from the notebook-level ``weights`` array: one entry per
    label id 0/1/2, computed with ``compute_class_weight("balanced", ...)``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Weighted cross-entropy over the sequence-classification logits.

        Args:
            model: the model being trained/evaluated.
            inputs: batch dict; its ``"labels"`` entry is popped and the rest is
                passed to the model.
            return_outputs: if True, also return the model outputs.
            num_items_in_batch: accepted for compatibility with the Trainer's
                call signature; not used. The loss is the per-batch weighted
                mean.

        Returns:
            ``loss`` or ``(loss, outputs)``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Build the weight tensor on the logits' own device rather than the
        # global DEVICE, so the loss works wherever the Trainer placed the model.
        class_weights = torch.as_tensor(weights, dtype=torch.float32, device=logits.device)
        loss = nn.functional.cross_entropy(logits, labels, weight=class_weights)
        return (loss, outputs) if return_outputs else loss

# --- 5. OPTIMIZED TRAINING ARGUMENTS ---
training_args = TrainingArguments(
    output_dir=os.path.join(OUTPUT_DIR, "bart_a40_optimized"),

    # Speed & memory
    bf16=True,                          # bf16 mixed precision (supported on Ampere GPUs such as the A40)
    tf32=True,                          # TF32 matmuls, matching the backend flags set in this cell
    optim="adamw_torch_fused",
    torch_compile=True,
    # NOTE(review): "aot_eager" runs AOTAutograd tracing but executes eager
    # kernels, so it is mainly a debugging backend. "inductor" (the default)
    # is the backend that generates optimized kernels. Confirm this is intended.
    torch_compile_backend="aot_eager",
    gradient_checkpointing=True,        # recompute activations in backward to save memory

    # Evaluation
    eval_strategy="steps",
    eval_steps=200,                     # evaluate every 200 training steps
    per_device_eval_batch_size=16,
    eval_accumulation_steps=50,         # move accumulated eval outputs to CPU every 50 eval batches

    # Training batch config
    per_device_train_batch_size=14,
    gradient_accumulation_steps=2,      # effective batch per device = 14 * 2 = 28

    learning_rate=1e-5,
    num_train_epochs=3,
    seed=42,                            # TrainingArguments' default, stated explicitly for reproducibility

    # Checkpointing / model selection. load_best_model_at_end requires
    # checkpoints at evaluation steps: save_steps must be a multiple of eval_steps.
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="macro_f1",   # key returned by compute_metrics
    report_to="none",
)

# --- 6. INITIALIZE AND TRAIN ---
from transformers.trainer_utils import get_last_checkpoint

# Fresh 3-way classification head on top of pretrained BART-large.
model = BartForSequenceClassification.from_pretrained(
    VERIFIER_MODEL,
    num_labels=3,
    use_safetensors=True,
)

trainer = BartA40Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)

# Resume only when a checkpoint actually exists. resume_from_checkpoint=True
# raises ValueError when output_dir holds no checkpoint, which breaks
# Restart & Run All on a clean workspace. get_last_checkpoint itself lists the
# directory, so guard against it not existing yet.
output_dir = training_args.output_dir
last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
if last_checkpoint is not None:
    print(f"Resuming from {last_checkpoint}")
trainer.train(resume_from_checkpoint=last_checkpoint)
trainer.save_model(os.path.join(OUTPUT_DIR, "final_bart_a40_model"))

Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-large and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight'].
W0116 18:20:47.546000 337 torch/fx/experimental/symbolic_shapes.py:6823] [8/5] _maybe_guard_rel() was called on non-relation expression Eq(s33, s6) | Eq(s6, 1)
W0116 18:20:48.849000 337 torch/fx/experimental/symbolic_shapes.py:6823] [8/6] _maybe_guard_rel() was called on non-relation expression Eq(s33, s6) | Eq(s6, 1)


Step,Training Loss,Validation Loss,Accuracy,Macro F1,Support F1,Support Support,Refute F1,Refute Support,Nei F1,Nei Support
1400,0.8169,0.908592,0.605628,0.564822,0.503755,1803,0.734357,5238,0.456354,1986
1600,0.6615,0.924532,0.63875,0.583768,0.53787,1803,0.773115,5238,0.44032,1986
1800,0.6615,0.908773,0.632879,0.583829,0.527178,1803,0.763441,5238,0.460868,1986
2000,0.6054,0.938681,0.637532,0.582149,0.549989,1803,0.77266,5238,0.423799,1986
2200,0.6054,0.944447,0.638529,0.585705,0.548945,1803,0.769792,5238,0.438377,1986
2400,0.6054,0.988068,0.639304,0.591645,0.539968,1803,0.768435,5238,0.466531,1986
2600,0.5299,1.05044,0.653595,0.596005,0.536433,1803,0.78695,5238,0.464632,1986
2800,0.5299,1.085357,0.638197,0.58459,0.548479,1803,0.772072,5238,0.43322,1986
3000,0.4202,1.079976,0.626565,0.580038,0.547166,1803,0.757004,5238,0.435944,1986
3200,0.4202,1.140268,0.627894,0.578189,0.541831,1803,0.762463,5238,0.430273,1986


W0116 18:21:06.713000 337 torch/fx/experimental/symbolic_shapes.py:6823] [8/7] _maybe_guard_rel() was called on non-relation expression Eq(s33, s6) | Eq(s6, 1)
W0116 18:21:06.861000 337 torch/fx/experimental/symbolic_shapes.py:6823] [8/7] _maybe_guard_rel() was called on non-relation expression Eq(s21, s69) | Eq(s69, 1)
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0}
Non-default generation parameters

In [13]:
model = trainer.model.eval()

def test_eval(claims, evidence_dict, top_k=5, min_chars=0):
    """Claim-level evaluation: average per-snippet class probabilities, then argmax.

    For each claim, the leading ``top_k`` evidence snippets (each cut to 1024
    characters) are scored against the claim in one batch. The softmax
    probabilities are averaged over snippets, and the argmax is the claim's
    prediction. Claims with no usable evidence are predicted NEI (2).

    NOTE(review): training examples used ``create_examples(top_k=3)`` and
    dropped snippets shorter than 20 characters. The defaults here (5, 0) keep
    the original evaluation setting. Pass ``top_k=3, min_chars=20`` to match
    training.

    Args:
        claims: list of dicts with ``"claim"`` and ``"label"`` keys.
        evidence_dict: mapping from claim position to a list of evidence strings.
        top_k: number of leading snippets to score per claim.
        min_chars: snippets whose stripped length is below this are skipped
            (applied after the ``top_k`` cut, as in ``create_examples``).

    Returns:
        ``(macro_f1, weighted_f1)``
    """
    use_autocast = DEVICE.type == "cuda"
    preds, truths = [], []
    for idx, obj in enumerate(tqdm(claims)):
        truths.append(normalize_label(obj["label"]))
        evs = [e for e in evidence_dict.get(idx, [])[:top_k] if len(e.strip()) >= min_chars]

        if not evs:
            preds.append(2)  # Default to NEI if no evidence
            continue

        batch = tokenizer(
            [e[:1024] for e in evs],
            [obj["claim"]] * len(evs),
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(DEVICE)

        with torch.no_grad():
            # bf16 autocast only on CUDA; on CPU the model runs in full precision.
            with torch.amp.autocast(device_type=DEVICE.type, dtype=torch.bfloat16, enabled=use_autocast):
                logits = model(**batch).logits
            # Logits may come back as bf16 under autocast. Softmax and average
            # in float32 so low precision cannot flip close calls.
            avg_prob = torch.softmax(logits.float(), dim=-1).mean(dim=0)
        preds.append(int(torch.argmax(avg_prob).cpu()))

    # Calculate all three variations
    m_f1 = f1_score(truths, preds, average="macro")
    w_f1 = f1_score(truths, preds, average="weighted")
    acc = accuracy_score(truths, preds)

    print("-" * 30)
    print(f"Test Accuracy:    {acc:.4f}")
    print(f"Test Macro-F1:   {m_f1:.4f} (Treats all classes equal)")
    print(f"Test Weighted-F1: {w_f1:.4f} (Accounts for class size)")
    print("-" * 30)

    return m_f1, w_f1

# Run the evaluation
macro_f1, weighted_f1 = test_eval(test_claims, retrieval_results["decomposed"]["test"])

100%|██████████| 2495/2495 [00:57<00:00, 43.55it/s]

------------------------------
Test Accuracy:    0.6369
Test Macro-F1:   0.5889 (Treats all classes equal)
Test Weighted-F1: 0.6495 (Accounts for class size)
------------------------------



