<a href="https://colab.research.google.com/github/INVISIBLE-SAM/Synergizing-Contextual-Semantics-and-Moral-Knowledge-Graphs-Moral-Foundation-Prediction/blob/main/roberta_base_MFTC_and_MFRC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# Log in to Kaggle so the `kagglehub.dataset_download` calls in the next
# cell can access the (possibly private) psamsahil/* datasets.
import kagglehub
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
#
# Each call stores the dataset location returned by kagglehub (presumably a
# local cache directory when not running on Kaggle — confirm).
# NOTE(review): confirm the CSV paths used by the MFRC/MFTC cells below
# resolve inside these directories in the current runtime.

psamsahil_mfrc_cleaned1_path = kagglehub.dataset_download('psamsahil/mfrc-cleaned1')
psamsahil_emfd_1_path = kagglehub.dataset_download('psamsahil/emfd-1')
psamsahil_mftc_1_path = kagglehub.dataset_download('psamsahil/mftc-1')

print('Data source import complete.')


In [None]:
!pip install torch-geometric

# MFRC

In [None]:
# ===================================================================
# Full Code for RoBERTa+MLP and GAT Fusion Model for MFRC
#
# This script is adapted for the MFRC (Moral Foundation Reddit Corpus).
# It implements and evaluates a dual-path architecture:
# - Path 1: Fine-tunes 'roberta-base' and passes its output through an MLP.
# - Path 2: Uses a GAT over moral concepts from the eMFD lexicon.
# - Fusion: Combines path outputs using cross-attention.
# - Training: Follows a multi-stage training strategy.
# - Evaluation: Includes separate test set evaluations for each path and
#   the final fused model to serve as an ablation study.
# ===================================================================

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaModel, RobertaTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report
import pandas as pd
import numpy as np
from tqdm.autonotebook import tqdm
import os

# PyTorch Geometric imports for GAT
# Note: You may need to install torch_geometric
# pip install torch_geometric
try:
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
except ImportError:
    print("PyTorch Geometric not found. Please install it: `pip install torch_geometric`")
    GATConv = None # Placeholder
    Data = None    # Placeholder

# ===================================================================
# 1. Configuration and Device Setup
# ===================================================================
def get_device():
    """Pick the training device: the GPU when CUDA is available, else the CPU."""
    has_gpu = torch.cuda.is_available()
    if not has_gpu:
        print("⚠️ Using CPU for training. This will be very slow.")
        return torch.device('cpu')
    print("✅ Using CUDA (GPU) for training.")
    return torch.device('cuda')

DEVICE = get_device()

# --- IMPORTANT: PLEASE SET YOUR FILE PATHS HERE ---
# Prefer the directories returned by `kagglehub.dataset_download` in the data
# import cell (outside Kaggle, e.g. on Colab, data is not under /kaggle/input);
# fall back to the Kaggle mount point when that cell has not been run.
_MFRC_DIR = globals().get('psamsahil_mfrc_cleaned1_path', '/kaggle/input/mfrc-cleaned1')
_EMFD_DIR = globals().get('psamsahil_emfd_1_path', '/kaggle/input/emfd-1')
# UPDATE this path to your MFRC csv file.
MFRC_CSV_PATH = os.path.join(_MFRC_DIR, "MFRC onehot_cleaned.csv")
# UPDATE this path to your eMFD lexicon file.
EMFD_CSV_PATH = os.path.join(_EMFD_DIR, "eMFD_wordlist.csv")
# --------------------------------------------------

# ===================================================================
# 2. Data Loading and Preprocessing for MFRC (UPDATED)
# ===================================================================
def convert_mfrc_to_6_foundations(df: pd.DataFrame) -> "tuple[pd.DataFrame, list[str]]":
    """
    Converts MFRC label columns to a 6-label format.
    Merges 'Equality' and 'Proportionality' into 'Fairness' (1 if either is 1).

    Expected input columns: 'text', 'Equality', 'Proportionality', plus any of
    'Care', 'Loyalty', 'Authority', 'Purity', 'Non-Moral' (absent ones are
    created as all-zero). Every other column (e.g. an MFRC 'Thin Morality'
    column, if present) is dropped.

    Returns:
        (final_df, target_columns): a new DataFrame holding 'text' followed by
        the 6 target columns (missing label cells filled with 0), and the list
        of target column names in that order.

    Raises:
        ValueError: if 'Equality'/'Proportionality' or 'text' is missing.
    """
    print("🔄 Converting MFRC labels from 8 to 6 foundations...")

    # Validate every required column before copying or deriving anything.
    if 'Equality' not in df.columns or 'Proportionality' not in df.columns:
        raise ValueError("Input CSV must contain 'Equality' and 'Proportionality' columns for conversion.")
    if 'text' not in df.columns:
        raise ValueError("Input CSV must contain a 'text' column.")

    converted_df = df.copy()
    # NaN compares unequal to 1, so unlabeled cells count as "not Fairness".
    converted_df['Fairness'] = ((df['Equality'] == 1) | (df['Proportionality'] == 1)).astype(int)

    target_columns = ['Care', 'Fairness', 'Loyalty', 'Authority', 'Purity', 'Non-Moral']

    # Ensure all target columns are present
    for col in target_columns:
        if col not in converted_df.columns:
            converted_df[col] = 0

    # Keep only the text column and the final 6 target columns
    final_df = converted_df[['text'] + target_columns].copy()
    # NaN labels would become NaN float targets and turn the loss into NaN.
    final_df[target_columns] = final_df[target_columns].fillna(0)

    print(f"✅ Label conversion complete. Target columns: {target_columns}")
    return final_df, target_columns

class MFRCDataset(Dataset):
    """Map-style dataset that pairs each MFRC text with its multi-hot label row.

    ``texts`` and ``labels`` are indexed positionally and must be aligned.
    Each item is ``{'text': str, 'labels': float32 tensor}``.
    """

    def __init__(self, texts, labels):
        self.texts, self.labels = texts, labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])  # coerce non-string cells (e.g. NaN) to text
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return {'text': text, 'labels': target}

def create_data_loaders(csv_path, batch_size=16, max_length=128, random_state=42):
    """Load, convert, split, and wrap the MFRC data in DataLoaders.

    Splits: 15% test, then 10% of the remainder as validation. Both splits are
    stratified on the first target column ('Care').

    Args:
        csv_path: path to the MFRC CSV.
        batch_size: batch size for all three loaders.
        max_length: RoBERTa tokenizer truncation length. Reddit comments can
            be long; raise this if truncation hurts.
        random_state: seed for both ``train_test_split`` calls.

    Returns:
        (train_loader, val_loader, test_loader, target_names). Each batch is
        the tokenizer output (``input_ids``, ``attention_mask``) plus
        ``texts`` (raw strings for the GAT path) and ``labels`` (stacked
        float32 label rows).

    Raises:
        FileNotFoundError: if ``csv_path`` does not exist.
    """
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"Error: The file '{csv_path}' was not found.")
    print(f"📂 Loading MFRC data from: {csv_path}")
    df = pd.read_csv(csv_path)
    processed_df, target_names = convert_mfrc_to_6_foundations(df)

    texts = processed_df['text'].tolist()
    labels = processed_df[target_names].values

    # Stratify based on the 'Care' foundation as a proxy for moral content distribution
    X_train_val, X_test, y_train_val, y_test = train_test_split(
        texts, labels, test_size=0.15, random_state=random_state, stratify=labels[:, 0]
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=0.1, random_state=random_state, stratify=y_train_val[:, 0]
    )
    print(f"📊 Data splits - Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")

    train_dataset = MFRCDataset(X_train, y_train)
    val_dataset = MFRCDataset(X_val, y_val)
    test_dataset = MFRCDataset(X_test, y_test)

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    def collate_fn(batch):
        """Tokenize a list of dataset items and attach raw texts and labels."""
        batch_texts = [item['text'] for item in batch]
        # Tokenize texts for RoBERTa
        inputs = tokenizer(
            batch_texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        # Pass original texts for GAT path
        inputs['texts'] = batch_texts
        # Stack labels
        inputs['labels'] = torch.stack([item['labels'] for item in batch])
        return inputs

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    return train_loader, val_loader, test_loader, target_names

# ===================================================================
# 3. Model Architecture Definition
# ===================================================================
class RobertaMLPPath(nn.Module):
    """Path 1: Fine-tunes RoBERTa and passes its pooled output through an MLP.

    Produces a feature vector of width ``output_dim`` (default 256) per input.
    It is used with a temporary linear head in Stage 1 and as the first input
    of the fusion model in Stages 3-4.

    NOTE(review): ``forward`` uses ``pooler_output`` (Hugging Face's pooled
    first-token representation), not the raw [CLS] hidden state. Confirm this
    is intended.
    NOTE(review): the MLP contains ``BatchNorm1d`` layers. In training mode,
    PyTorch rejects a batch of a single sample, so a final training batch of
    size 1 would fail. Consider ``drop_last=True`` on the train loader.
    """
    def __init__(self, output_dim=256, dropout=0.1):
        super(RobertaMLPPath, self).__init__()
        print("🔄 Initializing Path 1: RoBERTa+MLP")
        # Pretrained 'roberta-base' encoder; its 768-wide output feeds the MLP.
        self.roberta = RobertaModel.from_pretrained('roberta-base')
        # 768 -> 512 -> output_dim. Do not reorder: the Sequential indices are
        # the state_dict keys stored in the saved .pth checkpoints.
        self.mlp = nn.Sequential(
            nn.Linear(768, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(512),
            nn.Linear(512, output_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(output_dim)
        )
        print("✅ Path 1 Initialized.")
    def forward(self, input_ids, attention_mask):
        """Encode tokenized text and return the MLP-projected pooled features."""
        roberta_output = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = roberta_output.pooler_output
        return self.mlp(pooled_output)

class eMFDProcessor:
    """Extracts moral concepts from text using the eMFD lexicon.

    Every lexicon word maps to a 6-dim float32 vector. The first five values
    are the ``<foundation>_p`` probabilities (care, fairness, loyalty,
    authority, purity). The sixth is a 0.0 placeholder for the 'Non-Moral'
    class, so the layout lines up with the 6 target labels.
    """
    def __init__(self, emfd_csv_path):
        """Load the lexicon from ``emfd_csv_path``.

        Absent ``<foundation>_p`` columns and NaN probabilities become 0.0.
        Words are stored lowercased, because input text is lowercased before
        lookup.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the CSV has no ``word`` column.
        """
        import string  # local import keeps this class self-contained

        if not os.path.exists(emfd_csv_path):
            raise FileNotFoundError(f"eMFD lexicon file not found at {emfd_csv_path}")
        df = pd.read_csv(emfd_csv_path)
        if 'word' not in df.columns:
            raise ValueError("eMFD lexicon CSV must contain a 'word' column.")
        self.moral_foundations = ['care', 'fairness', 'loyalty', 'authority', 'purity']
        prob_cols = [f'{f}_p' for f in self.moral_foundations]
        # Trimmed from token edges so that e.g. "care," or "(harm)" still match.
        self._strip_chars = string.punctuation

        # reindex() adds any missing probability column as NaN. NaNs are zeroed
        # below so a malformed lexicon row cannot inject NaN into the GAT path.
        lexicon = df.reindex(columns=['word'] + prob_cols)
        self.emfd_data = {}
        for word, *probs in lexicon.itertuples(index=False, name=None):
            if not isinstance(word, str) or not word.strip():
                continue  # skip blank / NaN lexicon rows
            vector = np.nan_to_num(np.asarray(probs, dtype=np.float32), nan=0.0)
            # 5 foundation probabilities + 1 Non-Moral placeholder = 6 dims.
            self.emfd_data[word.strip().lower()] = np.append(vector, np.float32(0.0)).astype(np.float32)

    def extract_moral_concepts(self, text):
        """Return ``[{'word': str, 'probabilities': np.ndarray}, ...]`` for lexicon words in ``text``.

        Tokens come from a lowercased whitespace split. A token that does not
        match as-is is retried with leading/trailing punctuation removed.
        Repeated words yield repeated entries. If nothing matches, a single
        all-zero 'neutral' concept is returned, so graph construction always
        gets at least one node.
        """
        concepts = []
        for token in text.lower().split():
            word = token if token in self.emfd_data else token.strip(self._strip_chars)
            if word in self.emfd_data:
                concepts.append({'word': word, 'probabilities': self.emfd_data[word]})
        # If no concepts are found, return a single 'neutral' concept to avoid errors
        if not concepts:
            return [{'word': 'neutral', 'probabilities': np.zeros(6, dtype=np.float32)}]
        return concepts

class MoralGraphConstructor:
    """Turn a list of eMFD concept dicts into a PyTorch Geometric ``Data`` graph.

    Each concept becomes one node, and distinct nodes are wired all-to-all.
    """

    def create_moral_graph(self, moral_concepts):
        """Build the graph for ``moral_concepts``, or return None without PyG.

        Node features are 262-wide: the concept's 6 probabilities followed by
        256 zero slots, which GATeMFDModule later replaces with learned
        embeddings. A single node gets a self-loop so ``edge_index`` is never
        empty.
        """
        if not Data:
            return None  # PyTorch Geometric unavailable

        node_count = len(moral_concepts)
        feature_rows = [
            np.concatenate([concept['probabilities'], np.zeros(256)])
            for concept in moral_concepts
        ]
        features = torch.FloatTensor(np.array(feature_rows))

        if node_count == 1:
            edges = torch.tensor([[0], [0]], dtype=torch.long)
        else:
            # Every ordered pair (src, dst) with src != dst, in row-major order.
            ordered_pairs = [
                (src, dst)
                for src in range(node_count)
                for dst in range(node_count)
                if src != dst
            ]
            edges = torch.tensor(ordered_pairs, dtype=torch.long).t().contiguous()

        return Data(x=features, edge_index=edges)

class GATeMFDModule(nn.Module):
    """Processes one moral graph with a Graph Attention Network (GAT).

    Each node has 262 input features:
    - the first 6 are eMFD probabilities;
    - the remaining 256 are a learned embedding chosen by the node's *position*
      in the graph (node i -> row i), not by the concept word.

    A single GAT layer (``concat=False``) is followed by attention pooling over
    nodes, which yields a ``(1, output_dim)`` graph vector.
    """
    def __init__(self, input_dim=262, output_dim=256, num_heads=4, dropout=0.1):
        super(GATeMFDModule, self).__init__()
        # Positional embedding table: one 256-dim row per node index (1000 rows).
        self.concept_embeddings = nn.Embedding(1000, 256)
        self.gat_layer = GATConv(input_dim, output_dim, heads=num_heads, dropout=dropout, concat=False) if GATConv else nn.Identity()
        self.attention_pooling = nn.Linear(output_dim, 1)

    def forward(self, graph_data):
        """Return a ``(1, output_dim)`` pooled representation of ``graph_data``.

        ``graph_data.x`` is expected to be ``(num_nodes, 262)``, as built by
        MoralGraphConstructor. It is not modified.
        """
        if not GATConv: return torch.zeros((1, 256)).to(DEVICE)
        x, edge_index = graph_data.x.to(DEVICE), graph_data.edge_index.to(DEVICE)

        # Node i uses embedding row i. Graphs with more nodes than the table has
        # rows reuse the last row instead of raising an IndexError.
        positions = torch.arange(x.size(0), device=x.device).clamp(
            max=self.concept_embeddings.num_embeddings - 1
        )
        # Built out-of-place so the caller's graph tensor is never mutated
        # (.to(DEVICE) returns the same tensor when it is already on DEVICE).
        x = torch.cat([x[:, :6], self.concept_embeddings(positions)], dim=1)

        gat_output = F.elu(self.gat_layer(x, edge_index))
        # Use attention to pool node features into a single graph-level representation
        attn_weights = F.softmax(self.attention_pooling(gat_output), dim=0)
        pooled_output = torch.sum(attn_weights * gat_output, dim=0, keepdim=True)
        return pooled_output

class GATeMFDPath(nn.Module):
    """Path 2: text -> eMFD concepts -> moral graph -> GAT -> one vector per text.

    ``forward`` returns a ``(len(texts), 256)`` tensor, one row per input text.
    """
    def __init__(self, emfd_csv_path):
        super(GATeMFDPath, self).__init__()
        print("🔄 Initializing Path 2: GAT eMFD")
        self.emfd_processor = eMFDProcessor(emfd_csv_path)
        self.graph_constructor = MoralGraphConstructor()
        self.gat_module = GATeMFDModule()
        print("✅ Path 2 Initialized.")

    def forward(self, texts):
        """Encode each text independently (one graph at a time, not batched)."""
        batch_graphs_out = []
        for text in texts:
            concepts = self.emfd_processor.extract_moral_concepts(text)
            graph = self.graph_constructor.create_moral_graph(concepts)
            # Explicit None check: create_moral_graph signals "unavailable" with
            # None, and a PyG Data object's truthiness may depend on its __len__.
            if graph is not None:
                batch_graphs_out.append(self.gat_module(graph))
            else:  # Fallback if graph creation fails
                batch_graphs_out.append(torch.zeros((1, 256), device=DEVICE))
        return torch.cat(batch_graphs_out, dim=0)

class CrossAttentionFusionLayer(nn.Module):
    """Fuses features from the two paths using cross-attention.

    ``p1`` (RoBERTa+MLP features) and ``p2`` (GAT features) are each
    unsqueezed to a length-1 sequence. Each one attends to the other, gets a
    residual add and LayerNorm, and the two results are concatenated and
    projected back to ``feature_dim``.

    NOTE(review): every key/value sequence has length 1, so the attention
    softmax is over a single key and equals 1 before attention dropout. Each
    attention output therefore depends only on the *other* path's vector
    (through the value/output projections), not on the query.
    NOTE(review): ``fusion_proj`` ends in ``BatchNorm1d``, which rejects a
    batch of size 1 in training mode.
    """
    def __init__(self, feature_dim=256, num_heads=8, dropout=0.1):
        super(CrossAttentionFusionLayer, self).__init__()
        self.cross_attn_1_to_2 = nn.MultiheadAttention(feature_dim, num_heads, dropout=dropout, batch_first=True)
        self.cross_attn_2_to_1 = nn.MultiheadAttention(feature_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(feature_dim)
        self.norm2 = nn.LayerNorm(feature_dim)
        # Concatenated (2 * feature_dim) -> feature_dim.
        self.fusion_proj = nn.Sequential(
            nn.Linear(feature_dim * 2, feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(feature_dim)
        )
    def forward(self, p1, p2):
        """Return fused features of width ``feature_dim`` for each row of ``p1``/``p2``."""
        # Add a sequence axis of length 1 (batch_first=True layout).
        p1_res, p2_res = p1.unsqueeze(1), p2.unsqueeze(1)
        # p1 attends to p2
        attn_out1, _ = self.cross_attn_1_to_2(query=p1_res, key=p2_res, value=p2_res)
        p1_fused = self.norm1(attn_out1 + p1_res)
        # p2 attends to p1
        attn_out2, _ = self.cross_attn_2_to_1(query=p2_res, key=p1_res, value=p1_res)
        p2_fused = self.norm2(attn_out2 + p2_res)
        # Concatenate and project
        fused_vector = torch.cat([p1_fused.squeeze(1), p2_fused.squeeze(1)], dim=1)
        return self.fusion_proj(fused_vector)

class EnhancedClassifier(nn.Module):
    """Final classification head for the fused features.

    MLP ``input_dim -> 512 -> 256 -> num_classes``. Returns raw logits: there
    is no final activation, and callers apply ``torch.sigmoid``.
    """
    def __init__(self, input_dim=256, num_classes=6, dropout=0.1):
        super(EnhancedClassifier, self).__init__()
        # Layer order defines the state_dict keys of saved checkpoints.
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512), nn.GELU(), nn.Dropout(dropout), nn.BatchNorm1d(512),
            nn.Linear(512, 256), nn.GELU(), nn.Dropout(dropout), nn.BatchNorm1d(256),
            nn.Linear(256, num_classes)
        )
    def forward(self, x):
        """Map fused features ``x`` to per-class logits."""
        return self.classifier(x)

class CompleteFusionModel(nn.Module):
    """Wraps the fusion layer and the final classifier.

    Takes the two path feature tensors and returns per-class logits.
    """
    def __init__(self, num_classes=6):
        super(CompleteFusionModel, self).__init__()
        self.fusion_layer = CrossAttentionFusionLayer()
        self.classifier = EnhancedClassifier(num_classes=num_classes)
    def forward(self, p1_features, p2_features):
        """Fuse RoBERTa+MLP (``p1``) and GAT (``p2``) features, then classify."""
        fused_features = self.fusion_layer(p1_features, p2_features)
        return self.classifier(fused_features)

# ===================================================================
# 3.5. Custom Loss Function Definition (NEW SECTION)
# ===================================================================
class MultiLabelFocalLoss(nn.Module):
    """
    Focal Loss for multi-label classification.
    Adapted to work like BCEWithLogitsLoss, taking logits as input.

    Per element: FL = (1 - p_t)^gamma * BCE.
    - p_t is the model's probability for the true label.
    - BCE is positive-weighted by ``alpha`` when it is given.
    Returns the mean over all elements.
    """
    def __init__(self, gamma=2.0, alpha=None):
        """
        Args:
            gamma (float): The focusing parameter. Higher values give more weight to hard examples.
            alpha (Tensor, optional): A manual rescaling weight given to each class.
                                      If given, it will be used as the `pos_weight` in BCE.
                                      Should be a 1D tensor of size (num_classes,).
        """
        super(MultiLabelFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        """
        Args:
            inputs: (Tensor) The model's raw output logits of shape (batch_size, num_classes).
            targets: (Tensor) The ground truth labels of shape (batch_size, num_classes).

        Returns:
            Scalar tensor: mean focal loss over all elements.
        """
        # Unweighted BCE equals -log(p_t) for 0/1 targets, so exp(-bce) gives p_t
        # in a numerically stable way.
        bce_plain = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_t = torch.exp(-bce_plain)

        # pos_weight multiplies the positive term of BCE. Deriving p_t from that
        # weighted loss would give p**alpha instead of p, which would disable the
        # focal down-weighting for exactly the up-weighted classes.
        if self.alpha is None:
            bce_loss = bce_plain
        else:
            bce_loss = F.binary_cross_entropy_with_logits(
                inputs, targets, reduction='none', pos_weight=self.alpha
            )

        # The core formula is: FL = (1 - p_t)^gamma * BCE_loss
        focal_loss = ((1 - p_t) ** self.gamma) * bce_loss

        # Return the mean loss over the batch
        return focal_loss.mean()

# ===================================================================
# 4. Multi-Stage Training
# ===================================================================
class MultiStageTrainer:
    """Manages the four-stage training process and saves component models.

    - Stage 1: RoBERTa+MLP path with a temporary linear head.
    - Stage 2: GAT path with a temporary linear head.
    - Stages 3-4: GAT path + fusion model on frozen RoBERTa+MLP features. The
      checkpoint is selected on validation macro-F1. The best score is carried
      from Stage 3 into Stage 4, so Stage 4 overwrites the checkpoint only when
      it beats Stage 3.

    All checkpoints are written to the current working directory.
    """
    def __init__(self, roberta_mlp_path, gat_path, fusion_model, train_loader, val_loader, num_classes=6, focal_loss_alpha=None, focal_loss_gamma=2.0):
        self.roberta_mlp_path = roberta_mlp_path.to(DEVICE)
        self.gat_path = gat_path.to(DEVICE)
        self.fusion_model = fusion_model.to(DEVICE)
        self.train_loader = train_loader
        self.val_loader = val_loader
        # focal_loss_alpha is forwarded as BCE pos_weight (one weight per class).
        print(f"🔥 Using MultiLabelFocalLoss with gamma={focal_loss_gamma} and calculated alpha weights.")
        self.criterion = MultiLabelFocalLoss(gamma=focal_loss_gamma, alpha=focal_loss_alpha)
        self.num_classes = num_classes

    def _freeze(self, model):
        """Disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

    def _unfreeze(self, model):
        """Enable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = True

    def train_stage1(self, epochs=9, lr=2e-5):
        """Train the RoBERTa+MLP path with a temporary linear head (no validation)."""
        print("\n" + "="*20 + " Stage 1: Training RoBERTa+MLP Path " + "="*20)
        # Freeze GAT and fusion model, unfreeze RoBERTa+MLP path
        self._freeze(self.gat_path)
        self._freeze(self.fusion_model)
        self._unfreeze(self.roberta_mlp_path)

        # Temporary classifier for stage 1
        stage1_classifier = nn.Linear(256, self.num_classes).to(DEVICE)
        optimizer = optim.AdamW(list(self.roberta_mlp_path.parameters()) + list(stage1_classifier.parameters()), lr=lr)

        for epoch in range(epochs):
            self.roberta_mlp_path.train()
            stage1_classifier.train()
            for batch in tqdm(self.train_loader, desc=f"Stage 1 - Epoch {epoch+1}/{epochs}"):
                optimizer.zero_grad()
                features = self.roberta_mlp_path(batch['input_ids'].to(DEVICE), batch['attention_mask'].to(DEVICE))
                loss = self.criterion(stage1_classifier(features), batch['labels'].to(DEVICE))
                loss.backward()
                optimizer.step()

        # Save fine-tuned RoBERTa+MLP path and classifier
        torch.save(self.roberta_mlp_path.state_dict(), "roberta_mlp_path_stage1.pth")
        torch.save(stage1_classifier.state_dict(), "roberta_mlp_classifier_stage1.pth")
        print("✅ Stage 1 Complete. Models saved.")

    def train_stage2(self, epochs=5, lr=5e-4):
        """Train the GAT path with a temporary linear head (no validation)."""
        print("\n" + "="*20 + " Stage 2: Training GAT Path " + "="*20)
        # Freeze RoBERTa+MLP path entirely, freeze fusion model
        self._freeze(self.roberta_mlp_path)
        self._freeze(self.fusion_model)
        self._unfreeze(self.gat_path)

        # Temporary classifier for stage 2
        stage2_classifier = nn.Linear(256, self.num_classes).to(DEVICE)
        optimizer = optim.AdamW(list(self.gat_path.parameters()) + list(stage2_classifier.parameters()), lr=lr)

        for epoch in range(epochs):
            self.gat_path.train()
            stage2_classifier.train()
            for batch in tqdm(self.train_loader, desc=f"Stage 2 - Epoch {epoch+1}/{epochs}"):
                optimizer.zero_grad()
                features = self.gat_path(batch['texts'])
                loss = self.criterion(stage2_classifier(features), batch['labels'].to(DEVICE))
                loss.backward()
                optimizer.step()

        torch.save(self.gat_path.state_dict(), "gat_path_stage2.pth")
        torch.save(stage2_classifier.state_dict(), "gat_classifier_stage2.pth")
        print("✅ Stage 2 Complete. Models saved.")

    def train_stage3_and_4(self, best_model_path="best_fusion_model.pth"):
        """Jointly train GAT + fusion on frozen RoBERTa features, keeping the best checkpoint."""
        print("\n" + "="*15 + " Stage 3 & 4: Joint Training & Fine-tuning " + "="*15)
        # Freeze RoBERTa+MLP path entirely to use fine-tuned embeddings
        self._freeze(self.roberta_mlp_path)
        self._unfreeze(self.gat_path)
        self._unfreeze(self.fusion_model)

        print("--- Stage 3: Fusion Integration (Frozen RoBERTa) ---")
        optimizer = optim.AdamW([
            {'params': self.gat_path.parameters(), 'lr': 5e-4},
            {'params': self.fusion_model.parameters(), 'lr': 1e-3}
        ])
        best_val_f1 = self._run_joint_training(5, optimizer, best_model_path, "Stage 3")

        print("\n--- Stage 4: Fusion Fine-tuning (RoBERTa Frozen) ---")
        # Keep RoBERTa frozen, continue fine-tuning GAT and fusion model at lower LRs.
        optimizer = optim.AdamW([
            {'params': self.gat_path.parameters(), 'lr': 5e-5},
            {'params': self.fusion_model.parameters(), 'lr': 1e-4}
        ])
        # Carry Stage 3's best score forward so a worse Stage 4 epoch cannot
        # overwrite the best checkpoint.
        self._run_joint_training(5, optimizer, best_model_path, "Stage 4",
                                 use_early_stopping=True, best_val_f1=best_val_f1)

        print("✅ Stage 3 & 4 Complete.")

    def _run_joint_training(self, epochs, optimizer, best_model_path, stage_desc, use_early_stopping=False, best_val_f1=float('-inf')):
        """Train GAT + fusion for ``epochs`` epochs, checkpointing on val macro-F1.

        Args:
            best_val_f1: score a new epoch must beat to be saved. The default
                (-inf) guarantees the first epoch always writes a checkpoint.

        Returns:
            float: the best validation macro-F1 seen, including ``best_val_f1``.
        """
        patience_counter = 0
        for epoch in range(epochs):
            self.roberta_mlp_path.eval()  # RoBERTa path frozen, so eval mode
            self.gat_path.train()
            self.fusion_model.train()
            for batch in tqdm(self.train_loader, desc=f"{stage_desc} - Epoch {epoch+1}/{epochs}"):
                optimizer.zero_grad()
                with torch.no_grad():
                    p1_features = self.roberta_mlp_path(batch['input_ids'].to(DEVICE), batch['attention_mask'].to(DEVICE))
                p2_features = self.gat_path(batch['texts'])
                logits = self.fusion_model(p1_features, p2_features)
                loss = self.criterion(logits, batch['labels'].to(DEVICE))
                loss.backward()
                optimizer.step()

            val_f1 = self._validate()
            print(f"{stage_desc} - Epoch {epoch+1}: Val F1-Macro: {val_f1:.4f}")

            if val_f1 > best_val_f1:
                best_val_f1, patience_counter = val_f1, 0
                torch.save(self.fusion_model.state_dict(), best_model_path)
                torch.save(self.roberta_mlp_path.state_dict(), "roberta_mlp_path_final.pth")
                torch.save(self.gat_path.state_dict(), "gat_path_final.pth")
                print(f"🏆 New best model saved with F1-Macro: {val_f1:.4f}")
            elif use_early_stopping:
                patience_counter += 1
                if patience_counter >= 2:
                    print("🛑 Early stopping triggered.")
                    break
        return best_val_f1

    def _validate(self):
        """Return macro-F1 on the validation loader (sigmoid threshold 0.5)."""
        self.roberta_mlp_path.eval()
        self.gat_path.eval()
        self.fusion_model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in self.val_loader:
                p1 = self.roberta_mlp_path(batch['input_ids'].to(DEVICE), batch['attention_mask'].to(DEVICE))
                p2 = self.gat_path(batch['texts'])
                logits = self.fusion_model(p1, p2)
                all_preds.append((torch.sigmoid(logits) > 0.5).cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())
        y_pred, y_true = np.vstack(all_preds), np.vstack(all_labels)
        return f1_score(y_true, y_pred, average='macro', zero_division=0)

# ===================================================================
# 5. Final Evaluation for Ablation Study
# ===================================================================
def evaluate_roberta_mlp_only(test_loader, target_names):
    """Evaluates the standalone RoBERTa+MLP path (Stage 1 checkpoint) on the test set.

    Loads ``roberta_mlp_path_stage1.pth`` and ``roberta_mlp_classifier_stage1.pth``
    from the working directory and thresholds sigmoid outputs at 0.5. It then
    prints a per-class report.

    Returns:
        float: macro-averaged F1 over ``target_names``.
    """
    print("\n" + "="*20 + " ABLATION: RoBERTa+MLP PATH ONLY " + "="*20)
    roberta_mlp_path = RobertaMLPPath().to(DEVICE)
    roberta_mlp_classifier = nn.Linear(256, len(target_names)).to(DEVICE)
    # map_location lets checkpoints saved on a GPU load on a CPU-only runtime.
    roberta_mlp_path.load_state_dict(torch.load("roberta_mlp_path_stage1.pth", map_location=DEVICE))
    roberta_mlp_classifier.load_state_dict(torch.load("roberta_mlp_classifier_stage1.pth", map_location=DEVICE))
    roberta_mlp_path.eval(); roberta_mlp_classifier.eval()

    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing RoBERTa+MLP"):
            features = roberta_mlp_path(batch['input_ids'].to(DEVICE), batch['attention_mask'].to(DEVICE))
            logits = roberta_mlp_classifier(features)
            all_preds.append((torch.sigmoid(logits) > 0.5).cpu().numpy())
            all_labels.append(batch['labels'].numpy())

    y_pred, y_true = np.vstack(all_preds), np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    print(classification_report(y_true, y_pred, target_names=target_names, zero_division=0))
    print(f"** Macro F1-Score: {macro_f1:.4f} **")
    return macro_f1

def evaluate_gat_only(test_loader, target_names):
    """Evaluates the standalone GAT path (Stage 2 checkpoint) on the test set.

    Loads ``gat_path_stage2.pth`` and ``gat_classifier_stage2.pth`` from the
    working directory and thresholds sigmoid outputs at 0.5. It then prints a
    per-class report.

    Returns:
        float: macro-averaged F1 over ``target_names``.
    """
    print("\n" + "="*20 + " ABLATION: GAT PATH ONLY " + "="*20)
    gat_path = GATeMFDPath(EMFD_CSV_PATH).to(DEVICE)
    gat_classifier = nn.Linear(256, len(target_names)).to(DEVICE)
    # map_location lets checkpoints saved on a GPU load on a CPU-only runtime.
    gat_path.load_state_dict(torch.load("gat_path_stage2.pth", map_location=DEVICE))
    gat_classifier.load_state_dict(torch.load("gat_classifier_stage2.pth", map_location=DEVICE))
    gat_path.eval(); gat_classifier.eval()

    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing GAT"):
            features = gat_path(batch['texts'])
            logits = gat_classifier(features)
            all_preds.append((torch.sigmoid(logits) > 0.5).cpu().numpy())
            all_labels.append(batch['labels'].numpy())

    y_pred, y_true = np.vstack(all_preds), np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    print(classification_report(y_true, y_pred, target_names=target_names, zero_division=0))
    print(f"** Macro F1-Score: {macro_f1:.4f} **")
    return macro_f1

def evaluate_full_fusion_model(model_path, test_loader, target_names):
    """Evaluates the final fused model on the test set.

    Loads ``roberta_mlp_path_final.pth``, ``gat_path_final.pth`` and the fusion
    checkpoint at ``model_path``, thresholds sigmoid outputs at 0.5, and prints
    a per-class report.

    Returns:
        float: macro-averaged F1 over ``target_names``.
    """
    print("\n" + "="*20 + " FINAL MODEL: FUSED (RoBERTa+GAT) " + "="*20)
    # Re-initialize the component paths and load their final states.
    # map_location lets checkpoints saved on a GPU load on a CPU-only runtime.
    roberta_mlp_path = RobertaMLPPath().to(DEVICE)
    gat_path = GATeMFDPath(EMFD_CSV_PATH).to(DEVICE)
    roberta_mlp_path.load_state_dict(torch.load("roberta_mlp_path_final.pth", map_location=DEVICE))
    gat_path.load_state_dict(torch.load("gat_path_final.pth", map_location=DEVICE))

    fusion_model = CompleteFusionModel(num_classes=len(target_names)).to(DEVICE)
    fusion_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    roberta_mlp_path.eval(); gat_path.eval(); fusion_model.eval()

    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing Full Model"):
            p1 = roberta_mlp_path(batch['input_ids'].to(DEVICE), batch['attention_mask'].to(DEVICE))
            p2 = gat_path(batch['texts'])
            logits = fusion_model(p1, p2)
            all_preds.append((torch.sigmoid(logits) > 0.5).cpu().numpy())
            all_labels.append(batch['labels'].numpy())

    y_pred, y_true = np.vstack(all_preds), np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    print(classification_report(y_true, y_pred, target_names=target_names, zero_division=0))
    print(f"** Macro F1-Score: {macro_f1:.4f} **")
    return macro_f1


# ===================================================================
# 6. Main Execution Block (UPDATED)
# ===================================================================
if __name__ == "__main__":
    # Train all stages on MFRC, then run the three-way ablation on the test split.
    if not all([GATConv, Data]):
        print("❌ Critical PyTorch Geometric components not found. Aborting.")
    else:
        try:
            train_loader, val_loader, test_loader, target_names = create_data_loaders(
                csv_path=MFRC_CSV_PATH, batch_size=70
            )

            # --- Class weights for Focal Loss: pos_weight = #negatives / #positives ---
            print("⚖️ Calculating class weights for imbalanced data...")
            # Extract all labels from the training set
            train_labels = np.array([data['labels'].numpy() for data in train_loader.dataset])
            num_positives = np.sum(train_labels, axis=0)
            num_negatives = len(train_labels) - num_positives
            # Classes with no positives keep weight 1.0; `where` skips them
            # instead of dividing by zero (no RuntimeWarning / inf).
            pos_weight = np.divide(
                num_negatives, num_positives,
                out=np.ones_like(num_negatives), where=num_positives > 0,
            )
            pos_weight_tensor = torch.tensor(pos_weight, dtype=torch.float32).to(DEVICE)
            print(f"✅ Calculated pos_weight for {len(target_names)} classes: {pos_weight_tensor.cpu().numpy().round(2)}")

            roberta_mlp_path = RobertaMLPPath()
            gat_path = GATeMFDPath(emfd_csv_path=EMFD_CSV_PATH)
            fusion_model = CompleteFusionModel(num_classes=len(target_names))

            # pos_weight_tensor is used as the focal loss's BCE pos_weight.
            trainer = MultiStageTrainer(
                roberta_mlp_path=roberta_mlp_path,
                gat_path=gat_path,
                fusion_model=fusion_model,
                train_loader=train_loader,
                val_loader=val_loader,
                num_classes=len(target_names),
                focal_loss_alpha=pos_weight_tensor  # Pass the calculated weights
            )

            # --- Execute Training Pipeline ---
            trainer.train_stage1()
            trainer.train_stage2()
            trainer.train_stage3_and_4(best_model_path="best_mfrc_fusion_model.pth")

            # --- Execute Ablation Study Evaluation on Test Set ---
            print("\n" + "#"*70 + "\n# FINAL EVALUATION & ABLATION STUDY RESULTS (MFRC DATASET)\n" + "#"*70)
            evaluate_roberta_mlp_only(test_loader, target_names)
            evaluate_gat_only(test_loader, target_names)
            evaluate_full_fusion_model("best_mfrc_fusion_model.pth", test_loader, target_names)

        except FileNotFoundError as e:
            print(f"❌ {e}")
            print("👉 Please update MFRC_CSV_PATH and EMFD_CSV_PATH at the top of the script.")
        except Exception as e:
            import traceback
            print(f"An unexpected error occurred: {e}")
            # Keep the stack trace: the one-line message alone hides where
            # training/evaluation actually failed.
            traceback.print_exc()


# MFTC

In [None]:
# ===================================================================
# Full Code for RoBERTa+MLP and GAT Fusion Model for MFTC
# WITH FOCAL LOSS AND UPDATED FINE-TUNING STRATEGY
#
# This script implements and evaluates a dual-path architecture.
# - Path 1: Fine-tunes 'roberta-base' and passes its output through an MLP.
#   This path is then FROZEN after its initial training stage.
# - Path 2: Uses a GAT over moral concepts from the eMFD lexicon.
# - Fusion: Combines path outputs using cross-attention.
# - Training: Follows a multi-stage training strategy with Focal Loss.
# - Dataset: Adapted for the MFTC (Moral Foundation Twitter Corpus).
# - Evaluation: Includes separate test set evaluations for each path and
#   the final fused model to serve as an ablation study.
# ===================================================================

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaModel, RobertaTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report
import pandas as pd
import numpy as np
from tqdm.autonotebook import tqdm
import os

# PyTorch Geometric imports for GAT
# Note: You may need to install torch_geometric
# pip install torch_geometric
try:
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
except ImportError:
    print("PyTorch Geometric not found. Please install it: `pip install torch_geometric`")
    GATConv = None # Placeholder
    Data = None    # Placeholder

# ===================================================================
# 1. Configuration and Device Setup
# ===================================================================
def get_device():
    """Choose the torch device used for training and evaluation.

    Returns ``torch.device('cuda')`` when ``torch.cuda.is_available()`` is
    true, otherwise ``torch.device('cpu')``. Prints which one was chosen.
    """
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        print("⚠️ Using CPU for training. This will be very slow.")
        return torch.device('cpu')
    print("✅ Using CUDA (GPU) for training.")
    return torch.device('cuda')


# Module-wide device; every model and batch below is moved here.
DEVICE = get_device()

# --- IMPORTANT: PLEASE SET YOUR FILE PATHS HERE ---
MFTC_CSV_PATH = "/kaggle/input/mftc-1/merged_dataset_cleaned.csv"      # <--- UPDATE THIS
EMFD_CSV_PATH = "/kaggle/input/emfd-1/eMFD_wordlist.csv" # <--- UPDATE THIS
# --------------------------------------------------

# ===================================================================
# 2. Data Loading and Preprocessing for MFTC
# ===================================================================
def convert_mftc_to_6_foundations(df: pd.DataFrame) -> "tuple[pd.DataFrame, list[str]]":
    """Collapse MFTC's 11 moral labels into 6 broader MFT foundation columns.

    Each virtue/vice pair (care/harm, fairness/cheating, loyalty/betrayal,
    authority/subversion, purity/degradation) becomes one binary column that is
    1 when either side equals 1; 'non-moral' maps to 'Non-Moral'. A label
    column absent from ``df`` is treated as all-zero instead of crashing.

    Args:
        df: Raw MFTC frame with a 'text' (or 'tweet_text') column plus the
            lowercase label columns listed above.

    Returns:
        ``(converted_df, target_columns)``: ``converted_df`` holds a 'text'
        column followed by one int (0/1) column per name in ``target_columns``.

    Raises:
        ValueError: If neither a 'text' nor a 'tweet_text' column is present.
    """
    print("🔄 Converting MFTC labels from 11 to 6 foundations...")

    text_col = 'text' if 'text' in df.columns else 'tweet_text'
    if text_col not in df.columns:
        raise ValueError("Input CSV must contain a 'text' or 'tweet_text' column.")

    def is_positive(column):
        # Return a boolean Series even for a missing column; the former scalar
        # default from df.get() broke `|` / `.astype` when a column was absent.
        if column not in df.columns:
            return pd.Series(False, index=df.index)
        return df[column] == 1

    converted_df = pd.DataFrame({'text': df[text_col]})

    foundation_pairs = [
        ('Care', 'care', 'harm'),
        ('Fairness', 'fairness', 'cheating'),
        ('Loyalty', 'loyalty', 'betrayal'),
        ('Authority', 'authority', 'subversion'),
        ('Purity', 'purity', 'degradation'),
    ]
    for target, virtue, vice in foundation_pairs:
        converted_df[target] = (is_positive(virtue) | is_positive(vice)).astype(int)
    converted_df['Non-Moral'] = is_positive('non-moral').astype(int)

    target_columns = ['Care', 'Fairness', 'Loyalty', 'Authority', 'Purity', 'Non-Moral']
    print(f"✅ Label conversion complete. Target columns: {target_columns}")
    return converted_df, target_columns

class MoralFoundationDataset(Dataset):
    """Map-style PyTorch dataset pairing each MFTC text with its label row.

    Items are dicts ``{'text': str, 'labels': float32 tensor}``; ``texts`` and
    ``labels`` are indexed positionally and are assumed to be the same length.
    """

    def __init__(self, texts, labels):
        self.texts, self.labels = texts, labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Coerce to str so non-string cells (e.g. numbers) still tokenize.
        text = str(self.texts[idx])
        label_vector = torch.tensor(self.labels[idx], dtype=torch.float32)
        return {'text': text, 'labels': label_vector}

def create_data_loaders(csv_path, batch_size=16):
    """Load MFTC from CSV and return train/val/test DataLoaders plus label names.

    15% of the rows are held out for test, then 10% of the remainder for
    validation (``random_state=42``). Both splits stratify on the first label
    column only. Batches are tokenized with the ``roberta-base`` tokenizer
    (padding, truncation to 128 tokens). Each batch dict also carries the raw
    strings under 'texts' and a stacked float label tensor under 'labels'.

    Args:
        csv_path: Path to the MFTC CSV.
        batch_size: Batch size used for all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader, target_names)``; only the
        training loader shuffles.

    Raises:
        FileNotFoundError: If ``csv_path`` does not exist.
    """
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"Error: The file '{csv_path}' was not found.")
    print(f"📂 Loading MFTC data from: {csv_path}")
    frame, target_names = convert_mftc_to_6_foundations(pd.read_csv(csv_path))

    all_texts = frame['text'].tolist()
    all_labels = frame[target_names].values

    X_train_val, X_test, y_train_val, y_test = train_test_split(
        all_texts, all_labels, test_size=0.15, random_state=42, stratify=all_labels[:, 0])
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=0.1, random_state=42, stratify=y_train_val[:, 0])
    print(f"📊 Data splits - Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    def collate_fn(batch):
        batch_texts = [sample['text'] for sample in batch]
        encoded = tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True, max_length=128)
        # Keep the raw strings: the GAT path builds its graphs from text, not token ids.
        encoded['texts'] = batch_texts
        encoded['labels'] = torch.stack([sample['labels'] for sample in batch])
        return encoded

    loader_specs = (
        (X_train, y_train, True),
        (X_val, y_val, False),
        (X_test, y_test, False),
    )
    train_loader, val_loader, test_loader = (
        DataLoader(MoralFoundationDataset(split_texts, split_labels),
                   batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
        for split_texts, split_labels, shuffle in loader_specs
    )
    return train_loader, val_loader, test_loader, target_names

# ===================================================================
# 3. Model Architecture Definition
# ===================================================================
class RobertaMLPPath(nn.Module):
    """Path 1: encode text with ``roberta-base`` and project its pooled output.

    The MLP is two blocks of Linear -> GELU -> Dropout -> BatchNorm1d
    (768 -> 512 -> ``output_dim``). Note that nn.BatchNorm1d raises in train
    mode when a batch holds a single example.
    """

    def __init__(self, output_dim=256, dropout=0.1):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained('roberta-base')
        layers = []
        in_features = 768
        for width in (512, output_dim):
            layers += [nn.Linear(in_features, width), nn.GELU(), nn.Dropout(dropout), nn.BatchNorm1d(width)]
            in_features = width
        # Single Sequential named `mlp` so checkpoint keys stay 'mlp.0' ... 'mlp.7'.
        self.mlp = nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask):
        encoded = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        return self.mlp(encoded.pooler_output)

class eMFDProcessor:
    """Look up eMFD moral-foundation probabilities for the words of a text.

    ``emfd_data`` maps each lexicon word to a float32 vector of length 6,
    ordered like ``moral_foundations`` (care, fairness, loyalty, authority,
    purity, non-moral). The non-moral slot is always 0.0.
    """

    # Candidate CSV columns per foundation, tried in order. The eMFD wordlist
    # release labels the purity foundation "sanctity" (``sanctity_p``); reading
    # only ``purity_p`` silently gave 0.0 for every word.
    # NOTE(review): confirm against the CSV actually in use.
    _PROBABILITY_COLUMNS = {
        'care': ('care_p',),
        'fairness': ('fairness_p',),
        'loyalty': ('loyalty_p',),
        'authority': ('authority_p',),
        'purity': ('purity_p', 'sanctity_p'),
    }

    def __init__(self, emfd_csv_path):
        """Load the eMFD wordlist CSV at ``emfd_csv_path``.

        Raises:
            FileNotFoundError: If the file does not exist.
            KeyError: If the CSV has no 'word' column.
        """
        self.moral_foundations = ['care', 'fairness', 'loyalty', 'authority', 'purity', 'non-moral']
        if not os.path.exists(emfd_csv_path):
            raise FileNotFoundError(f"eMFD file not found at {emfd_csv_path}")
        df = pd.read_csv(emfd_csv_path)
        num_words = len(df)

        columns = []
        for foundation in self.moral_foundations:
            if foundation == 'non-moral':
                continue
            candidates = self._PROBABILITY_COLUMNS.get(foundation, (f'{foundation}_p',))
            source = next((name for name in candidates if name in df.columns), None)
            if source is None:
                # Absent column: same 0.0 default the previous per-row lookup used.
                columns.append(np.zeros(num_words, dtype=np.float32))
            else:
                # Blank / non-numeric cells become 0.0 rather than NaN, which
                # would otherwise propagate through the GAT into the loss.
                values = pd.to_numeric(df[source], errors='coerce').fillna(0.0)
                columns.append(values.to_numpy(dtype=np.float32))
        # Trailing non-moral slot is always 0.0.
        columns.append(np.zeros(num_words, dtype=np.float32))

        probabilities = np.stack(columns, axis=1)  # (num_words, 6)
        # Later duplicates of a word overwrite earlier ones, as before.
        self.emfd_data = dict(zip(df['word'], probabilities))

    def extract_moral_concepts(self, text):
        """Return ``[{'word', 'probabilities'}, ...]`` for lexicon words in ``text``.

        Matching is on lowercased whitespace-split tokens, so tokens with
        attached punctuation do not match. When nothing matches, a single
        'neutral' concept with an all-zero vector is returned, so the result is
        never empty.
        """
        concepts = [{'word': w, 'probabilities': self.emfd_data[w]} for w in text.lower().split() if w in self.emfd_data]
        return concepts if concepts else [{'word': 'neutral', 'probabilities': np.zeros(6, dtype=np.float32)}]

class MoralGraphConstructor:
    """Build a fully connected PyTorch Geometric graph from extracted moral concepts."""

    def create_moral_graph(self, moral_concepts):
        """Return a ``Data`` graph with one node per concept, or None if PyG is missing.

        Node features are the concept's probability vector followed by 256
        zeros (placeholder columns for the GAT module's learned embeddings),
        stored as float32. Every ordered pair of distinct nodes gets an edge; a
        single concept instead gets one self-loop.
        """
        node_count = len(moral_concepts)

        feature_rows = []
        for concept in moral_concepts:
            feature_rows.append(np.concatenate([concept['probabilities'], np.zeros(256)]))
        node_features = torch.FloatTensor(np.array(feature_rows))

        if node_count == 1:
            edge_index = torch.tensor([[0], [0]], dtype=torch.long)
        else:
            edge_pairs = []
            for source in range(node_count):
                for target in range(node_count):
                    if source != target:
                        edge_pairs.append([source, target])
            # (num_edges, 2) -> (2, num_edges): the COO layout PyG expects.
            edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()

        if not Data:
            return None
        return Data(x=node_features, edge_index=edge_index)

class GATeMFDModule(nn.Module):
    """Encode one moral-concept graph into a single pooled feature vector.

    The first 6 columns of the incoming node features are kept (presumably
    the eMFD probabilities from MoralGraphConstructor). The remaining 256 are
    replaced by a learned embedding. One GAT layer (heads averaged,
    ``concat=False``) with ELU is then applied, and attention pooling over the
    nodes yields a ``(1, output_dim)`` tensor.
    """

    def __init__(self, input_dim=262, output_dim=256, num_heads=4, dropout=0.1):
        super(GATeMFDModule, self).__init__()
        # NOTE(review): forward() indexes this by node *position* (0, 1, 2, ...),
        # not by word identity, so it acts as a positional embedding over concepts.
        self.concept_embeddings = nn.Embedding(1000, 256)
        self.gat_layer = GATConv(input_dim, output_dim, heads=num_heads, dropout=dropout, concat=False) if GATConv else nn.Identity()
        self.attention_pooling = nn.Linear(output_dim, 1)
        # Plain int attribute (not a parameter/buffer): saved state_dicts are unchanged.
        self.output_dim = output_dim

    def forward(self, graph_data):
        """Return the attention-pooled ``(1, output_dim)`` embedding of ``graph_data``."""
        if not GATConv:
            # PyTorch Geometric unavailable: emit a zero vector of the expected width.
            return torch.zeros((1, self.output_dim), device=DEVICE)
        x = graph_data.x.to(DEVICE)
        edge_index = graph_data.edge_index.to(DEVICE)
        # Clamp positions so graphs with more nodes than embedding rows reuse the
        # last row instead of raising IndexError.
        positions = torch.arange(x.size(0), device=DEVICE).clamp(max=self.concept_embeddings.num_embeddings - 1)
        # Build a new tensor instead of writing into x in place: when x is already
        # on DEVICE, `.to` returns the caller's tensor, which would be mutated.
        x = torch.cat([x[:, :6], self.concept_embeddings(positions)], dim=1)
        gat_output = F.elu(self.gat_layer(x, edge_index))
        # Softmax over nodes (dim=0) -> one weight per node for pooling.
        attn_weights = F.softmax(self.attention_pooling(gat_output), dim=0)
        return torch.sum(attn_weights * gat_output, dim=0, keepdim=True)

class GATeMFDPath(nn.Module):
    """Path 2: turn each text into an eMFD concept graph and encode it with a GAT.

    ``forward(texts)`` returns one 256-wide row per text, stacked to
    ``(len(texts), 256)``.
    """

    def __init__(self, emfd_csv_path):
        super(GATeMFDPath, self).__init__()
        print("🔄 Initializing Path 2: GAT eMFD")
        self.emfd_processor = eMFDProcessor(emfd_csv_path)
        self.graph_constructor = MoralGraphConstructor()
        self.gat_module = GATeMFDModule()
        print("✅ Path 2 Initialized.")

    def forward(self, texts):
        """Encode each string in ``texts`` and stack the results along dim 0."""
        rows = []
        for text in texts:
            concepts = self.emfd_processor.extract_moral_concepts(text)
            graph = self.graph_constructor.create_moral_graph(concepts)
            if graph is None:
                # No PyG Data class: use a zero row. The old `graph or zeros`
                # expression passed a tensor into the GAT module, which expects
                # a graph with .x / .edge_index.
                rows.append(torch.zeros((1, 256), device=DEVICE))
            else:
                rows.append(self.gat_module(graph))
        if not rows:
            # torch.cat rejects an empty list; return an empty batch instead.
            return torch.zeros((0, 256), device=DEVICE)
        return torch.cat(rows, dim=0)

class CrossAttentionFusionLayer(nn.Module):
    """Fuse two per-example feature vectors with bidirectional cross-attention.

    Each input is treated as a length-1 sequence. Path 1 attends to path 2 and
    vice versa, each result gets a residual + LayerNorm, and the two are
    concatenated and projected back to ``feature_dim``
    (Linear -> GELU -> Dropout -> BatchNorm1d).
    """

    def __init__(self, feature_dim=256, num_heads=8, dropout=0.1):
        super().__init__()
        attention_kwargs = dict(dropout=dropout, batch_first=True)
        self.cross_attn_1_to_2 = nn.MultiheadAttention(feature_dim, num_heads, **attention_kwargs)
        self.cross_attn_2_to_1 = nn.MultiheadAttention(feature_dim, num_heads, **attention_kwargs)
        self.norm1 = nn.LayerNorm(feature_dim)
        self.norm2 = nn.LayerNorm(feature_dim)
        self.fusion_proj = nn.Sequential(
            nn.Linear(2 * feature_dim, feature_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(feature_dim),
        )

    def forward(self, p1, p2):
        # (batch, dim) -> (batch, 1, dim) so MultiheadAttention sees one token each.
        seq1 = p1.unsqueeze(1)
        seq2 = p2.unsqueeze(1)
        attended_1, _ = self.cross_attn_1_to_2(query=seq1, key=seq2, value=seq2)
        attended_2, _ = self.cross_attn_2_to_1(query=seq2, key=seq1, value=seq1)
        fused_1 = self.norm1(attended_1 + seq1).squeeze(1)
        fused_2 = self.norm2(attended_2 + seq2).squeeze(1)
        return self.fusion_proj(torch.cat((fused_1, fused_2), dim=1))

class EnhancedClassifier(nn.Module):
    """MLP head mapping fused features to ``num_classes`` logits.

    Two hidden blocks (Linear -> GELU -> Dropout -> BatchNorm1d, widths 512 and
    256), followed by a final Linear layer.
    """

    def __init__(self, input_dim=256, num_classes=6, dropout=0.1):
        super().__init__()
        blocks = []
        width_in = input_dim
        for width_out in (512, 256):
            blocks.extend((nn.Linear(width_in, width_out), nn.GELU(), nn.Dropout(dropout), nn.BatchNorm1d(width_out)))
            width_in = width_out
        blocks.append(nn.Linear(width_in, num_classes))
        self.classifier = nn.Sequential(*blocks)

    def forward(self, x):
        return self.classifier(x)

class CompleteFusionModel(nn.Module):
    """Cross-attention fusion of the two path outputs followed by the MLP head."""

    def __init__(self, num_classes=6):
        super().__init__()
        self.fusion_layer = CrossAttentionFusionLayer()
        self.classifier = EnhancedClassifier(num_classes=num_classes)

    def forward(self, p1, p2):
        fused = self.fusion_layer(p1, p2)
        return self.classifier(fused)

# ===================================================================
# 3.5. Custom Loss Function Definition (NEW)
# ===================================================================
class MultiLabelFocalLoss(nn.Module):
    """Focal loss for multi-label classification on raw logits.

    Per element the loss is ``(1 - p_t) ** gamma * bce``, averaged over all
    elements, where ``p_t`` is the predicted probability of the true label.
    ``alpha`` is optional. When given, it is applied as BCE ``pos_weight``,
    i.e. a per-class multiplier on the loss of positive targets.
    """

    def __init__(self, gamma=2.0, alpha=None):
        super(MultiLabelFocalLoss, self).__init__()
        self.gamma = gamma
        # Per-class positive weight (used as `pos_weight`), or None.
        self.alpha = alpha

    def forward(self, inputs, targets):
        # p_t must come from the *unweighted* BCE: exp(-bce) is the true-label
        # probability only when pos_weight is 1. With pos_weight w the old code got
        # exp(-w * bce) = p**w, which inflated the focal factor for positives.
        plain_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_t = torch.exp(-plain_bce)
        if self.alpha is None:
            weighted_bce = plain_bce
        else:
            weighted_bce = F.binary_cross_entropy_with_logits(
                inputs, targets, reduction='none', pos_weight=self.alpha)
        focal_loss = ((1 - p_t) ** self.gamma) * weighted_bce
        return focal_loss.mean()

# ===================================================================
# 4. Multi-Stage Training (UPDATED)
# ===================================================================
class MultiStageTrainer:
    """Run the four-stage training schedule and checkpoint each component.

    - Stage 1 trains the RoBERTa+MLP path with a temporary linear head.
    - Stage 2 trains the GAT path with its own temporary linear head.
    - Stages 3-4 train the GAT path and fusion model jointly on top of the
      frozen RoBERTa+MLP path. They keep the checkpoint with the best
      validation macro-F1 seen across *both* joint stages.

    Checkpoints are written to the working directory:
    ``roberta_mlp_path_stage1.pth``, ``roberta_mlp_classifier_stage1.pth``,
    ``gat_path_stage2.pth``, ``gat_classifier_stage2.pth``,
    ``roberta_mlp_path_final.pth``, ``gat_path_final.pth``, and the fusion
    model at ``best_model_path``.
    """

    def __init__(self, roberta_mlp_path, gat_path, fusion_model, train_loader, val_loader, num_classes=6, focal_loss_alpha=None, focal_loss_gamma=2.0):
        """Move the three models to DEVICE and build the focal-loss criterion.

        Args:
            roberta_mlp_path: Path-1 module called as ``(input_ids, attention_mask)``.
            gat_path: Path-2 module called with a list of raw texts.
            fusion_model: Module called as ``(p1_features, p2_features)``.
            train_loader, val_loader: Loaders yielding dicts with 'input_ids',
                'attention_mask', 'texts' and 'labels'.
            num_classes: Width of the temporary Stage-1/Stage-2 heads.
            focal_loss_alpha: Optional per-class tensor handed to
                MultiLabelFocalLoss as ``alpha``.
            focal_loss_gamma: Focusing parameter of the focal loss.
        """
        self.roberta_mlp_path = roberta_mlp_path.to(DEVICE)
        self.gat_path = gat_path.to(DEVICE)
        self.fusion_model = fusion_model.to(DEVICE)
        self.train_loader, self.val_loader = train_loader, val_loader
        print(f"🔥 Using MultiLabelFocalLoss with gamma={focal_loss_gamma} and calculated alpha weights.")
        self.criterion = MultiLabelFocalLoss(gamma=focal_loss_gamma, alpha=focal_loss_alpha)
        self.num_classes = num_classes
        # Best validation macro-F1 of the current joint run (Stages 3-4). Starting
        # below 0 guarantees the first validated epoch is checkpointed even at F1 = 0,
        # so the final-evaluation checkpoints always exist.
        self.best_val_f1 = -1.0

    def _freeze(self, model):
        """Disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad_(False)

    def _unfreeze(self, model):
        """Enable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad_(True)

    def train_stage1(self, epochs=9, lr=2e-5):
        """Stage 1: train the RoBERTa+MLP path alone through a temporary linear head."""
        print("\n" + "="*20 + " Stage 1: Training RoBERTa+MLP Path " + "="*20)
        self._freeze(self.gat_path)
        self._freeze(self.fusion_model)
        self._unfreeze(self.roberta_mlp_path)

        stage1_classifier = nn.Linear(256, self.num_classes).to(DEVICE)
        optimizer = optim.AdamW(list(self.roberta_mlp_path.parameters()) + list(stage1_classifier.parameters()), lr=lr)

        for epoch in range(epochs):
            self.roberta_mlp_path.train()
            stage1_classifier.train()
            for batch in tqdm(self.train_loader, desc=f"Stage 1 - Epoch {epoch+1}/{epochs}"):
                optimizer.zero_grad()
                features = self.roberta_mlp_path(batch['input_ids'].to(DEVICE), batch['attention_mask'].to(DEVICE))
                loss = self.criterion(stage1_classifier(features), batch['labels'].to(DEVICE))
                loss.backward()
                optimizer.step()

        # The last-epoch weights are saved; no validation-based selection here.
        torch.save(self.roberta_mlp_path.state_dict(), "roberta_mlp_path_stage1.pth")
        torch.save(stage1_classifier.state_dict(), "roberta_mlp_classifier_stage1.pth")
        print("✅ Stage 1 Complete. Models saved.")

    def train_stage2(self, epochs=5, lr=5e-4):
        """Stage 2: train the GAT path alone through a temporary linear head."""
        print("\n" + "="*20 + " Stage 2: Training GAT Path " + "="*20)
        self._freeze(self.roberta_mlp_path)
        self._freeze(self.fusion_model)
        self._unfreeze(self.gat_path)

        stage2_classifier = nn.Linear(256, self.num_classes).to(DEVICE)
        optimizer = optim.AdamW(list(self.gat_path.parameters()) + list(stage2_classifier.parameters()), lr=lr)

        for epoch in range(epochs):
            self.gat_path.train()
            stage2_classifier.train()
            for batch in tqdm(self.train_loader, desc=f"Stage 2 - Epoch {epoch+1}/{epochs}"):
                optimizer.zero_grad()
                features = self.gat_path(batch['texts'])
                loss = self.criterion(stage2_classifier(features), batch['labels'].to(DEVICE))
                loss.backward()
                optimizer.step()

        torch.save(self.gat_path.state_dict(), "gat_path_stage2.pth")
        torch.save(stage2_classifier.state_dict(), "gat_classifier_stage2.pth")
        print("✅ Stage 2 Complete. Models saved.")

    def train_stage3_and_4(self, best_model_path="best_fusion_model.pth"):
        """Stages 3-4: train GAT + fusion on the frozen RoBERTa+MLP path.

        Stage 3 uses higher learning rates for 5 epochs. Stage 4 uses 10x lower
        rates for up to 5 epochs with early stopping (patience 2).
        """
        print("\n" + "="*15 + " Stage 3 & 4: Joint Training & Fine-tuning " + "="*15)
        self._freeze(self.roberta_mlp_path)  # RoBERTa+MLP path stays frozen throughout.
        self._unfreeze(self.gat_path)
        self._unfreeze(self.fusion_model)
        # One best score spans both stages because they share best_model_path:
        # Stage 4 must beat Stage 3's best before it may overwrite the checkpoint.
        self.best_val_f1 = -1.0

        print("--- Stage 3: Fusion Integration (Frozen RoBERTa) ---")
        optimizer = optim.AdamW([{'params': self.gat_path.parameters(), 'lr': 5e-4}, {'params': self.fusion_model.parameters(), 'lr': 1e-3}])
        self._run_joint_training(5, optimizer, best_model_path, "Stage 3")

        print("\n--- Stage 4: Fusion Fine-tuning (RoBERTa Frozen) ---")
        optimizer = optim.AdamW([{'params': self.gat_path.parameters(), 'lr': 5e-5}, {'params': self.fusion_model.parameters(), 'lr': 1e-4}])
        self._run_joint_training(5, optimizer, best_model_path, "Stage 4", use_early_stopping=True)
        print("✅ Stage 3 & 4 Complete.")

    def _run_joint_training(self, epochs, optimizer, best_model_path, stage_desc, use_early_stopping=False):
        """Train GAT + fusion for ``epochs`` and checkpoint on validation improvements.

        Improvement is measured against ``self.best_val_f1``, which persists across
        calls. With ``use_early_stopping``, training stops after 2 consecutive
        epochs without improvement.
        """
        patience_counter = 0
        for epoch in range(epochs):
            self.roberta_mlp_path.eval()  # Frozen path: eval mode keeps BatchNorm/Dropout fixed.
            self.gat_path.train()
            self.fusion_model.train()
            for batch in tqdm(self.train_loader, desc=f"{stage_desc} - Epoch {epoch+1}/{epochs}"):
                optimizer.zero_grad()
                with torch.no_grad():
                    p1_features = self.roberta_mlp_path(batch['input_ids'].to(DEVICE), batch['attention_mask'].to(DEVICE))
                p2_features = self.gat_path(batch['texts'])
                logits = self.fusion_model(p1_features, p2_features)
                loss = self.criterion(logits, batch['labels'].to(DEVICE))
                loss.backward()
                optimizer.step()

            val_f1 = self._validate()
            print(f"{stage_desc} - Epoch {epoch+1}: Val F1-Macro: {val_f1:.4f}")

            if val_f1 > self.best_val_f1:
                self.best_val_f1, patience_counter = val_f1, 0
                torch.save(self.fusion_model.state_dict(), best_model_path)
                torch.save(self.roberta_mlp_path.state_dict(), "roberta_mlp_path_final.pth")
                torch.save(self.gat_path.state_dict(), "gat_path_final.pth")
                print(f"🏆 New best model saved with F1-Macro: {val_f1:.4f}")
            elif use_early_stopping:
                patience_counter += 1
                if patience_counter >= 2:
                    print("🛑 Early stopping triggered.")
                    break

    def _validate(self):
        """Return the macro-F1 of the full model on the validation set (threshold 0.5)."""
        self.roberta_mlp_path.eval()
        self.gat_path.eval()
        self.fusion_model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in self.val_loader:
                p1 = self.roberta_mlp_path(batch['input_ids'].to(DEVICE), batch['attention_mask'].to(DEVICE))
                p2 = self.gat_path(batch['texts'])
                logits = self.fusion_model(p1, p2)
                all_preds.append((torch.sigmoid(logits) > 0.5).cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())
        y_pred, y_true = np.vstack(all_preds), np.vstack(all_labels)
        return f1_score(y_true, y_pred, average='macro', zero_division=0)


# ===================================================================
# 5. Final Evaluation for Ablation Study (UPDATED)
# ===================================================================
def evaluate_roberta_mlp_only(test_loader, target_names):
    """Evaluate the standalone RoBERTa+MLP path on the test set.

    Rebuilds the path and its Stage-1 linear head and loads
    ``roberta_mlp_path_stage1.pth`` / ``roberta_mlp_classifier_stage1.pth`` from
    the working directory. Sigmoid outputs are thresholded at 0.5, and a
    per-class classification report plus the macro F1 are printed.

    Args:
        test_loader: Loader yielding dicts with 'input_ids', 'attention_mask', 'labels'.
        target_names: Class names in label-column order.

    Returns:
        The macro F1-score (also printed).
    """
    print("\n" + "="*20 + " ABLATION: RoBERTa+MLP PATH ONLY " + "="*20)
    roberta_mlp_path = RobertaMLPPath().to(DEVICE)
    roberta_mlp_classifier = nn.Linear(256, len(target_names)).to(DEVICE)
    # map_location lets checkpoints written on a GPU load on a CPU-only runtime.
    roberta_mlp_path.load_state_dict(torch.load("roberta_mlp_path_stage1.pth", map_location=DEVICE))
    roberta_mlp_classifier.load_state_dict(torch.load("roberta_mlp_classifier_stage1.pth", map_location=DEVICE))
    roberta_mlp_path.eval()
    roberta_mlp_classifier.eval()

    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing RoBERTa+MLP"):
            features = roberta_mlp_path(batch['input_ids'].to(DEVICE), batch['attention_mask'].to(DEVICE))
            logits = roberta_mlp_classifier(features)
            all_preds.append((torch.sigmoid(logits) > 0.5).cpu().numpy())
            all_labels.append(batch['labels'].numpy())

    y_pred, y_true = np.vstack(all_preds), np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    print(classification_report(y_true, y_pred, target_names=target_names, zero_division=0))
    print(f"** Macro F1-Score: {macro_f1:.4f} **")
    return macro_f1

def evaluate_gat_only(test_loader, target_names):
    """Evaluate the standalone GAT path on the test set.

    Rebuilds the GAT path (from EMFD_CSV_PATH) and its Stage-2 linear head and
    loads ``gat_path_stage2.pth`` / ``gat_classifier_stage2.pth`` from the
    working directory. Sigmoid outputs are thresholded at 0.5, and a per-class
    classification report plus the macro F1 are printed.

    Args:
        test_loader: Loader yielding dicts with 'texts' and 'labels'.
        target_names: Class names in label-column order.

    Returns:
        The macro F1-score (also printed).
    """
    print("\n" + "="*20 + " ABLATION: GAT PATH ONLY " + "="*20)
    gat_path = GATeMFDPath(EMFD_CSV_PATH).to(DEVICE)
    gat_classifier = nn.Linear(256, len(target_names)).to(DEVICE)
    # map_location lets checkpoints written on a GPU load on a CPU-only runtime.
    gat_path.load_state_dict(torch.load("gat_path_stage2.pth", map_location=DEVICE))
    gat_classifier.load_state_dict(torch.load("gat_classifier_stage2.pth", map_location=DEVICE))
    gat_path.eval()
    gat_classifier.eval()

    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing GAT"):
            features = gat_path(batch['texts'])
            logits = gat_classifier(features)
            all_preds.append((torch.sigmoid(logits) > 0.5).cpu().numpy())
            all_labels.append(batch['labels'].numpy())

    y_pred, y_true = np.vstack(all_preds), np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    print(classification_report(y_true, y_pred, target_names=target_names, zero_division=0))
    print(f"** Macro F1-Score: {macro_f1:.4f} **")
    return macro_f1

def evaluate_full_fusion_model(model_path, test_loader, target_names):
    """Evaluate the final fused (RoBERTa+GAT) model on the test set.

    Loads ``roberta_mlp_path_final.pth`` and ``gat_path_final.pth`` from the
    working directory, plus the fusion model from ``model_path``. Sigmoid
    outputs are thresholded at 0.5, and a per-class classification report plus
    the macro F1 are printed.

    Args:
        model_path: Path of the saved CompleteFusionModel state_dict.
        test_loader: Loader yielding dicts with 'input_ids', 'attention_mask',
            'texts' and 'labels'.
        target_names: Class names in label-column order.

    Returns:
        The macro F1-score (also printed).
    """
    print("\n" + "="*20 + " FINAL MODEL: FUSED (RoBERTa+GAT) " + "="*20)
    # Re-initialize the component paths and load their final states.
    roberta_mlp_path = RobertaMLPPath().to(DEVICE)
    gat_path = GATeMFDPath(EMFD_CSV_PATH).to(DEVICE)
    # map_location lets checkpoints written on a GPU load on a CPU-only runtime.
    roberta_mlp_path.load_state_dict(torch.load("roberta_mlp_path_final.pth", map_location=DEVICE))
    gat_path.load_state_dict(torch.load("gat_path_final.pth", map_location=DEVICE))

    fusion_model = CompleteFusionModel(num_classes=len(target_names)).to(DEVICE)
    fusion_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    roberta_mlp_path.eval()
    gat_path.eval()
    fusion_model.eval()

    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing Full Model"):
            p1 = roberta_mlp_path(batch['input_ids'].to(DEVICE), batch['attention_mask'].to(DEVICE))
            p2 = gat_path(batch['texts'])
            logits = fusion_model(p1, p2)
            all_preds.append((torch.sigmoid(logits) > 0.5).cpu().numpy())
            all_labels.append(batch['labels'].numpy())

    y_pred, y_true = np.vstack(all_preds), np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    print(classification_report(y_true, y_pred, target_names=target_names, zero_division=0))
    print(f"** Macro F1-Score: {macro_f1:.4f} **")
    return macro_f1

# ===================================================================
# 6. Main Execution Block (UPDATED)
# ===================================================================
def compute_pos_weight(labels):
    """Return per-class ``num_negatives / num_positives`` ratios for BCE ``pos_weight``.

    Args:
        labels: 2-D array-like of 0/1 labels, one row per training example and
            one column per class.

    Returns:
        float64 NumPy array with one weight per column. Columns with no
        positive examples get 1.0 (no re-weighting) instead of a division by zero.
    """
    labels = np.asarray(labels, dtype=np.float64)
    num_positives = labels.sum(axis=0)
    num_negatives = labels.shape[0] - num_positives
    weights = np.ones_like(num_positives)
    # Divide only where positives exist; np.where would still evaluate x/0 and warn.
    np.divide(num_negatives, num_positives, out=weights, where=num_positives > 0)
    return weights


if __name__ == "__main__":
    import traceback

    if not all([GATConv, Data]):
        print("❌ Critical PyTorch Geometric components not found. Aborting.")
    else:
        try:
            train_loader, val_loader, test_loader, target_names = create_data_loaders(
                csv_path=MFTC_CSV_PATH, batch_size=32)

            # --- Class weights for the focal loss (used there as BCE pos_weight) ---
            print("⚖️ Calculating class weights for imbalanced data...")
            # The training Dataset stores the label matrix directly; read it instead
            # of materialising one tensor per example.
            pos_weight = compute_pos_weight(train_loader.dataset.labels)
            pos_weight_tensor = torch.tensor(pos_weight, dtype=torch.float32).to(DEVICE)
            print(f"✅ Calculated pos_weight for {len(target_names)} classes: {pos_weight_tensor.cpu().numpy().round(2)}")

            roberta_mlp_path = RobertaMLPPath()
            gat_path = GATeMFDPath(emfd_csv_path=EMFD_CSV_PATH)
            fusion_model = CompleteFusionModel(num_classes=len(target_names))

            trainer = MultiStageTrainer(
                roberta_mlp_path=roberta_mlp_path,
                gat_path=gat_path,
                fusion_model=fusion_model,
                train_loader=train_loader,
                val_loader=val_loader,
                num_classes=len(target_names),
                focal_loss_alpha=pos_weight_tensor,  # Pass the calculated weights
            )

            # --- Execute Training Pipeline ---
            trainer.train_stage1()
            trainer.train_stage2()
            trainer.train_stage3_and_4(best_model_path="best_mftc_fusion_model.pth")

            # --- Execute Ablation Study Evaluation on Test Set ---
            print("\n" + "#"*70 + "\n# FINAL EVALUATION & ABLATION STUDY RESULTS\n" + "#"*70)

            evaluate_roberta_mlp_only(test_loader, target_names)
            evaluate_gat_only(test_loader, target_names)
            evaluate_full_fusion_model("best_mftc_fusion_model.pth", test_loader, target_names)

        except FileNotFoundError as e:
            print(f"❌ {e}")
            print("👉 Please update MFTC_CSV_PATH and EMFD_CSV_PATH at the top of the script.")
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            # Keep the notebook cell alive but show where it failed instead of hiding it.
            traceback.print_exc()


In [None]:
# Scratch cell: only prints a fixed string; it reads or changes nothing defined above.
print("helloe")