# Notebook 4.1: Transformer Model - Results and Error Analysis

**Objective:** Deeply analyze the performance of the fine-tuned DistilBERT model.

This notebook loads the artifacts saved during training (model weights, per-label decision thresholds, and the fitted label binarizer) and evaluates the model on the validation split, looking both at overall metrics and at specific error cases.

In [None]:
"""
Results Analysis Notebook
-------------------------

Comprehensive analysis of model performance:
1. Overall metrics
2. Per-class analysis
3. Error cases study
4. Performance comparison
5. Visualization of results

Analysis Types:
- Confusion matrix per class
- F1-score distribution
- Sample predictions analysis
- Error case deep dives
"""

# --- 1. Setup, Imports, and Artifact Loading ---
import torch
import pandas as pd
import numpy as np
import os
import pickle
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from sklearn.metrics import classification_report

# --- Matplotlib Style ---
# Apply matplotlib's built-in 'ggplot' style sheet to every figure drawn after this point.
plt.style.use('ggplot')

# Re-define the model and dataset classes
class SciX_HF_Dataset(Dataset):
    """Map-style PyTorch dataset that tokenizes one text and binarizes its labels per item.

    Each item is a dict with:
      - 'input_ids' / 'attention_mask': 1-D tensors produced by the tokenizer,
        padded/truncated to `max_len`.
      - 'labels': float tensor holding the binarizer's vector for that example's labels.
    """

    def __init__(self, texts, labels, tokenizer, max_len, binarizer):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.binarizer = tokenizer, binarizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        raw_text = self.texts[item]
        # transform() expects a batch of label-sets, so wrap and unwrap a single one.
        multi_hot = self.binarizer.transform([self.labels[item]])[0]
        tokens = self.tokenizer.encode_plus(
            str(raw_text),
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # return_tensors='pt' yields a leading batch dim of 1; flatten it away.
        sample = {key: tokens[key].flatten() for key in ('input_ids', 'attention_mask')}
        sample['labels'] = torch.tensor(multi_hot, dtype=torch.float)
        return sample

class MultiLabelClassifier(nn.Module):
    """Multi-label classification head on top of a pretrained transformer encoder.

    The final hidden state of the first token is passed through dropout and a
    linear layer, giving one raw logit per label (this notebook applies a sigmoid
    afterwards to get probabilities). The submodule attribute names below are the
    state_dict keys, so they must not be renamed.
    """

    def __init__(self, model_name, n_labels):
        super().__init__()
        self.transformer_body = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.2)
        hidden_size = self.transformer_body.config.hidden_size
        self.classifier = nn.Linear(hidden_size, n_labels)

    def forward(self, input_ids, attention_mask):
        hidden_states = self.transformer_body(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        first_token = hidden_states[:, 0]
        return self.classifier(self.dropout(first_token))
        
# --- Configuration and Paths ---
MODEL_NAME = "distilbert-base-uncased"
DATASET_NAME = "adsabs/SciX_UAT_keywords"
TEXT_COLUMN = "text"
LABEL_COLUMN = "verified_uat_labels"
BATCH_SIZE = 16  # Can be larger for inference, e.g., 32
MAX_TOKEN_LEN = 512

MODEL_OUTPUT_DIR = '../models/transformer'
MODEL_PATH = os.path.join(MODEL_OUTPUT_DIR, 'best_model_state.pth')
THRESHOLDS_PATH = os.path.join(MODEL_OUTPUT_DIR, 'best_thresholds.npy')
BINARIZER_PATH = os.path.join(MODEL_OUTPUT_DIR, 'label_binarizer.pkl')

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- LOAD ALL ARTIFACTS ---
print("Loading all necessary artifacts...")
# SECURITY: pickle.load can execute arbitrary code; only load binarizer files
# produced by our own training run, never ones from an untrusted source.
with open(BINARIZER_PATH, 'rb') as f:
    mlb = pickle.load(f)
num_labels = len(mlb.classes_)

best_thresholds = np.load(THRESHOLDS_PATH)
# Fail fast if the artifacts come from different training runs: the thresholds are
# compared column-wise against the (n_examples, num_labels) probability matrix below
# and indexed per class in the error-analysis cell.
if best_thresholds.shape != (num_labels,):
    raise ValueError(
        f"Threshold array has shape {best_thresholds.shape}, expected ({num_labels},) "
        f"to match the {num_labels} classes of the label binarizer."
    )

final_model = MultiLabelClassifier(MODEL_NAME, n_labels=num_labels).to(DEVICE)
final_model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device(DEVICE)))
final_model.eval()  # disables dropout for deterministic inference

# --- LOAD DATA ---
dataset = load_dataset(DATASET_NAME)

def combine_text(examples):
    """Concatenate 'title' and 'abstract' into a 'text' field; a missing part becomes ''."""
    title = examples['title'] if examples['title'] is not None else ""
    abstract = examples['abstract'] if examples['abstract'] is not None else ""
    examples['text'] = title + " " + abstract
    return examples

dataset = dataset.map(combine_text)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
val_dataset = SciX_HF_Dataset(
    texts=dataset['val'][TEXT_COLUMN],
    labels=dataset['val'][LABEL_COLUMN],
    tokenizer=tokenizer,
    max_len=MAX_TOKEN_LEN,
    binarizer=mlb
)
print("Artifacts and data loaded successfully.")


# --- GENERATE PREDICTIONS AND PROBABILITIES ---
print("\nGenerating predictions on the validation set...")
prob_batches = []
label_batches = []

with torch.no_grad():
    for batch in tqdm(DataLoader(val_dataset, batch_size=BATCH_SIZE), desc="Generating Predictions"):
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)

        logits = final_model(input_ids=input_ids, attention_mask=attention_mask)
        # Keep whole (batch, num_labels) blocks; they are stacked once after the loop.
        prob_batches.append(torch.sigmoid(logits).cpu().numpy())
        label_batches.append(batch["labels"].cpu().numpy())

all_probs = np.concatenate(prob_batches)
all_labels = np.concatenate(label_batches)
# One row per validation example, one column per binarizer class.
assert all_probs.shape == all_labels.shape == (len(val_dataset), num_labels)

# Generate binary predictions using the optimal per-label thresholds
all_preds = (all_probs > best_thresholds).astype(int)

# Create a DataFrame from the final report for easy querying.
# Note: besides one row per class, the report contains summary rows
# ('micro avg', 'macro avg', 'weighted avg', 'samples avg').
transformer_report_df = pd.DataFrame(classification_report(
    all_labels,
    all_preds,
    target_names=mlb.classes_,
    output_dict=True,
    zero_division=0
)).transpose()

print("Analysis setup complete. You can now run the error analysis cells.")

Using device: cuda
Loading all necessary artifacts...
Artifacts and data loaded successfully.

Generating predictions on the validation set...


Generating Predictions:   0%|          | 0/190 [00:00<?, ?it/s]

Analysis setup complete. You can now run the error analysis cells.


In [10]:
# --- 1.2 Qualitative Error Analysis Function ---

def _sample_error_indices(mask, n_samples, rng):
    """Return up to `n_samples` distinct row indices where boolean `mask` is True (sampled without replacement)."""
    candidates = np.flatnonzero(mask)
    size = max(0, min(n_samples, len(candidates)))
    return rng.choice(candidates, size=size, replace=False)


def analyze_transformer_errors(df_original, true_labels, pred_labels, probs, mlb, class_name,
                               n_samples=3, thresholds=None, random_state=None):
    """
    Print sampled False Positives and False Negatives for one class of the Transformer model.

    Args:
        df_original: DataFrame whose rows are positionally aligned with the rows of
            `true_labels` / `pred_labels` / `probs`; must have a 'title' column.
        true_labels: 2-D 0/1 array of ground-truth labels, one column per class in `mlb.classes_`.
        pred_labels: 2-D 0/1 array of thresholded predictions, same layout as `true_labels`.
        probs: 2-D array of per-class probabilities, same layout as `true_labels`.
        mlb: fitted label binarizer exposing `classes_` and `inverse_transform`.
        class_name: class to analyze; must be one of `mlb.classes_`.
        n_samples: maximum number of FP examples and of FN examples to print.
        thresholds: per-class decision thresholds (indexable by class position). Defaults to
            the notebook-global `best_thresholds` loaded in the setup cell.
        random_state: optional int seed or `np.random.Generator` for reproducible sampling.
            When None, the global NumPy RNG is used (the original behaviour).

    Raises:
        ValueError: if `class_name` is not in `mlb.classes_`.
    """
    class_idx = list(mlb.classes_).index(class_name)
    if thresholds is None:
        thresholds = best_thresholds  # notebook global from the setup cell
    rng = np.random if random_state is None else np.random.default_rng(random_state)

    print("="*80)
    print(f"Error Analysis for Class: '{class_name}'")
    print(f"(Threshold for this class: {thresholds[class_idx]:.3f})")
    print("="*80)

    true_col = true_labels[:, class_idx]
    pred_col = pred_labels[:, class_idx]

    # --- False Positives (Model predicted it, but it was wrong) ---
    fp_mask = (pred_col == 1) & (true_col == 0)
    n_fp = int(fp_mask.sum())
    print(f"\nFound {n_fp} False Positives.")
    if n_fp > 0:
        for i, idx in enumerate(_sample_error_indices(fp_mask, n_samples, rng)):
            row = df_original.iloc[idx]
            print(f"\n--- FP Sample #{i+1} (Original Index: {idx}) ---")
            print(f"Title: {row['title']}")
            print(f"  > Model's Confidence for '{class_name}': {probs[idx, class_idx]:.3f}")
            # Show the ground truth so the reader can see what the example is actually about.
            print(f"  > TRUE labels: {mlb.inverse_transform(true_labels[idx].reshape(1,-1))[0]}")

    # --- False Negatives (Model missed it, but it was a true label) ---
    fn_mask = (pred_col == 0) & (true_col == 1)
    n_fn = int(fn_mask.sum())
    print(f"\nFound {n_fn} False Negatives.")
    if n_fn > 0:
        for i, idx in enumerate(_sample_error_indices(fn_mask, n_samples, rng)):
            row = df_original.iloc[idx]
            # All labels the model did predict for this example
            predicted_labels = [mlb.classes_[j] for j in np.flatnonzero(pred_labels[idx] == 1)]

            print(f"\n--- FN Sample #{i+1} (Original Index: {idx}) ---")
            print(f"Title: {row['title']}")
            print(f"  > Model's Confidence for '{class_name}': {probs[idx, class_idx]:.3f}")
            print(f"  > Model PREDICTED: {predicted_labels if predicted_labels else 'None'}")

In [None]:
# --- 1.3 Run Analysis on Specific Classes ---

# Selection criteria for the classes we inspect: a class needs enough validation
# support for its sampled errors to be representative.
MIN_SUPPORT = 50
HIGH_F1 = 0.7
LOW_F1 = 0.5
ERROR_SAMPLE_SEED = 42  # fixes which FP/FN examples get sampled, so re-runs show the same ones

# Get the original validation dataframe (rows aligned with all_labels / all_preds / all_probs)
val_original_df = pd.DataFrame(dataset['val'])

# classification_report also emits summary rows ('micro avg', 'macro avg', ...);
# keep only the rows that correspond to real labels before filtering.
per_class_report_df = transformer_report_df[transformer_report_df.index.isin(mlb.classes_)]
well_supported_df = per_class_report_df[per_class_report_df['support'] > MIN_SUPPORT]

# analyze_transformer_errors samples with NumPy's global RNG by default; seed it for reproducibility.
np.random.seed(ERROR_SAMPLE_SEED)

# --- Analyze a well-supported class the Transformer handles well (highest F1) ---
high_performing_classes = well_supported_df[well_supported_df['f1-score'] > HIGH_F1]
if not high_performing_classes.empty:
    good_class = high_performing_classes['f1-score'].idxmax()
    analyze_transformer_errors(val_original_df, all_labels, all_preds, all_probs, mlb, good_class)
else:
    print(f"\nCould not find a high-performing class with >{MIN_SUPPORT} support and >{HIGH_F1} F1 to analyze.")

# --- Analyze a well-supported class the Transformer still struggles with (lowest F1) ---
problematic_classes = well_supported_df[well_supported_df['f1-score'] < LOW_F1]
if not problematic_classes.empty:
    problem_class = problematic_classes['f1-score'].idxmin()
    analyze_transformer_errors(val_original_df, all_labels, all_preds, all_probs, mlb, problem_class)
else:
    print(f"\nNo highly problematic classes found with >{MIN_SUPPORT} support and <{LOW_F1} F1 to analyze.")

Error Analysis for Class: 'gamma-ray bursts'
(Threshold for this class: 0.560)

Found 4 False Positives.

--- FP Sample #1 (Original Index: 126) ---
Title: Magnetic Field Strength Effects on Nucleosynthesis from Neutron Star Merger Outflows
  > Model's Confidence for 'gamma-ray bursts': 0.674
  > TRUE labels: ('accretion', 'magnetic fields', 'magnetohydrodynamical simulations', 'magnetohydrodynamics', 'neutron stars', 'nuclear astrophysics', 'nucleosynthesis', 'r-process')

--- FP Sample #2 (Original Index: 2026) ---
Title: Late-time Radio and Millimeter Observations of Superluminous Supernovae and Long Gamma-Ray Bursts: Implications for Central Engines, Fast Radio Bursts, and Obscured Star Formation
  > Model's Confidence for 'gamma-ray bursts': 0.752
  > TRUE labels: ('core-collapse supernovae', 'extragalactic radio sources', 'magnetars', 'radio astrometry', 'radio transient sources', 'relativistic jets', 'star formation', 'stellar physics', 'supernova remnants')

--- FP Sample #3 (O