<a href="https://colab.research.google.com/github/INVISIBLE-SAM/Synergizing-Contextual-Semantics-and-Moral-Knowledge-Graphs-Moral-Foundation-Prediction/blob/main/MOTIVE_fusion_update_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m30.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


In [None]:
# prompt: download this file from gdrive https://drive.google.com/file/d/1DmKLanqTJSe1yGJ4FgJWJEfRwq26qgZd/view?usp=sharing

!gdown https://drive.google.com/uc?id=1DmKLanqTJSe1yGJ4FgJWJEfRwq26qgZd

Downloading...
From (original): https://drive.google.com/uc?id=1DmKLanqTJSe1yGJ4FgJWJEfRwq26qgZd
From (redirected): https://drive.google.com/uc?id=1DmKLanqTJSe1yGJ4FgJWJEfRwq26qgZd&confirm=t&uuid=2c1f87f1-15e9-42b5-be78-648c9f116974
To: /content/county_data_boarders.json
100% 201M/201M [00:03<00:00, 66.0MB/s]


In [None]:
import torch
import torch.nn as nn
from torch.nn import Sequential, Linear, ReLU, Dropout, Sigmoid, ModuleDict, LayerNorm, LSTM, GELU, MultiheadAttention
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score
from sklearn.metrics import classification_report
import pandas as pd
import numpy as np
import torch.nn.functional as F
import json
import math

# Updated moral foundations with virtue/vice combinations + Non-Moral.
# The order of this list is significant: prepare_labels() stacks the target
# columns in this order, compute_loss()/evaluate_model() index targets by
# position in this list, and every model builds one prediction head per entry.
moral_foundations = ['Care', 'Fairness', 'Loyalty', 'Authority', 'Purity', 'Non_Moral']

class TextModelRoBERTaFrozen(nn.Module):
    """Text branch: frozen ``roberta-base`` encoder plus trainable layers.

    The first-token (CLS position) embedding of the encoder is passed through
    a small trainable MLP (``processor``) that yields 256-d features, and one
    sigmoid head per entry of ``moral_foundations`` turns those features into
    a probability.
    """

    def __init__(self):
        super().__init__()
        # Load RoBERTa model
        self.tokenizer = AutoTokenizer.from_pretrained('roberta-base')
        self.roberta = AutoModel.from_pretrained('roberta-base')

        # Freeze RoBERTa completely
        for param in self.roberta.parameters():
            param.requires_grad = False
        self.roberta.eval()

        # Trainable processing layers
        self.processor = Sequential(
            Linear(768, 512), LayerNorm(512), GELU(), Dropout(0.1),
            Linear(512, 256), LayerNorm(256), GELU(), Dropout(0.1)
        )

        # Prediction heads
        self.heads = ModuleDict({
            foundation: Sequential(Linear(256, 128), ReLU(), Dropout(0.2),
                                 Linear(128, 1), Sigmoid())
            for foundation in moral_foundations
        })

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen encoder in eval.

        ``nn.Module.train`` recursively flips every sub-module, which would
        re-enable dropout inside the frozen encoder and make its embeddings
        stochastic during training. The encoder is therefore forced back to
        eval mode after the normal mode switch.
        """
        super().train(mode)
        self.roberta.eval()
        return self

    def _encode(self, texts):
        """Tokenize ``texts`` and return the encoder's first-token embedding."""
        inputs = self.tokenizer(texts, padding=True, truncation=True,
                                max_length=512, return_tensors='pt')
        # Tokenizer output lives on CPU; move it to wherever the model is.
        device = next(self.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # No gradients are needed through the frozen encoder.
        with torch.no_grad():
            outputs = self.roberta(**inputs)
        return outputs.last_hidden_state[:, 0]  # CLS token

    def get_features(self, texts):
        """Return the 256-d processed text features for a batch of strings."""
        return self.processor(self._encode(texts))

    def forward(self, texts):
        """Return ``{foundation: probability tensor}`` for a batch of strings."""
        features = self.get_features(texts)
        return {f: self.heads[f](features) for f in moral_foundations}

class SpatialModel(nn.Module):
    """Spatial branch: a two-layer MLP encoder with one sigmoid head per foundation."""

    def __init__(self, input_dim):
        super().__init__()
        # input_dim -> 128 -> 256 feature vector
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 256),
        )

        # One small classifier per moral foundation, keyed by its name.
        per_foundation = {}
        for name in moral_foundations:
            per_foundation[name] = nn.Sequential(
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, 1),
                nn.Sigmoid(),
            )
        self.heads = nn.ModuleDict(per_foundation)

    def get_features(self, x):
        """Return the 256-d encoding of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return ``{foundation: probability tensor}`` for the input batch."""
        encoded = self.get_features(x)
        return {name: self.heads[name](encoded) for name in moral_foundations}

class TemporalModelLSTM(nn.Module):
    """Temporal branch: linear projection, 2-layer bidirectional LSTM, MLP.

    Args:
        input_dim: number of features per timestep.
        seq_len: accepted for interface compatibility; not used by this class.
        hidden_size: LSTM hidden size per direction.
    """

    def __init__(self, input_dim, seq_len=10, hidden_size=128):
        super().__init__()
        self.projection = Linear(input_dim, hidden_size)
        self.lstm = LSTM(hidden_size, hidden_size, 2, batch_first=True,
                        dropout=0.2, bidirectional=True)
        self.encoder = Sequential(
            Linear(hidden_size * 2, 256), LayerNorm(256),
            ReLU(), Dropout(0.2)
        )
        self.heads = ModuleDict({
            f: Sequential(Linear(256, 64), ReLU(), Linear(64, 1), Sigmoid())
            for f in moral_foundations
        })

    def get_features(self, x):
        """Return 256-d sequence features.

        Expects a batched input laid out as (batch, time, input_dim), since
        the LSTM is built with ``batch_first=True``.
        """
        proj = self.projection(x)
        _, (h_n, _) = self.lstm(proj)
        # Summarise the sequence with the final hidden state of each direction
        # of the top layer. Indexing the output at the last timestep instead
        # would give the backward direction after it has processed only a
        # single step. h_n is (num_layers * 2, batch, hidden); the last two
        # entries are the top layer's forward and backward states.
        summary = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.encoder(summary)

    def forward(self, x):
        """Return ``{foundation: probability tensor}`` for the input batch."""
        features = self.get_features(x)
        return {f: self.heads[f](features) for f in moral_foundations}

class BehavioralModel(nn.Module):
    """Behavioral branch: normalised two-layer MLP encoder with per-foundation heads."""

    def __init__(self, input_dim):
        super().__init__()
        # Two Linear -> LayerNorm -> ReLU -> Dropout stages: input_dim -> 128 -> 256
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

        # One small sigmoid classifier per moral foundation.
        per_foundation = {}
        for name in moral_foundations:
            per_foundation[name] = nn.Sequential(
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, 1),
                nn.Sigmoid(),
            )
        self.heads = nn.ModuleDict(per_foundation)

    def get_features(self, x):
        """Return the 256-d encoding of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return ``{foundation: probability tensor}`` for the input batch."""
        encoded = self.get_features(x)
        return {name: self.heads[name](encoded) for name in moral_foundations}

class HeterogeneousFusion(nn.Module):
    """Fuse four modality feature vectors with attention and a learned gate.

    Each modality gets its own learned query and cross-attention layer; the
    four attended tokens are mixed by one self-attention layer, weighted by a
    sigmoid gate computed from all four tokens, summed, and classified by one
    head per moral foundation.
    """

    def __init__(self, d_model=256, num_heads=8):
        super().__init__()
        self.d_model = d_model
        # One learned query vector per modality.
        self.modality_queries = nn.Parameter(torch.randn(4, d_model))

        cross_layers = []
        for _ in range(4):
            cross_layers.append(nn.MultiheadAttention(d_model, num_heads, batch_first=True))
        self.cross_attention = nn.ModuleList(cross_layers)

        self.self_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        # Norms 0-3 follow the per-modality cross-attention, norm 4 the self-attention.
        self.layer_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(5)])

        # Maps the four concatenated tokens to four gate values in (0, 1).
        self.gate = nn.Sequential(
            nn.Linear(d_model * 4, d_model),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model, 4),
            nn.Sigmoid(),
        )

        per_foundation = {}
        for name in moral_foundations:
            per_foundation[name] = nn.Sequential(
                nn.Linear(d_model, 128),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(128, 1),
                nn.Sigmoid(),
            )
        self.heads = nn.ModuleDict(per_foundation)

    def forward(self, text_feat, spatial_feat, temporal_feat, behavioral_feat):
        """Return ``{foundation: probability tensor}`` from the four feature tensors.

        Inputs are presumably (batch, d_model) each -- the code adds a length-1
        sequence axis to every one before attention.
        """
        n = text_feat.size(0)
        streams = (text_feat, spatial_feat, temporal_feat, behavioral_feat)

        # Cross-attention: modality query attends to that modality's single token,
        # followed by a residual connection and layer norm.
        tokens = []
        for idx, stream in enumerate(streams):
            key_value = stream.unsqueeze(1)
            query = self.modality_queries[idx:idx + 1].unsqueeze(0).expand(n, -1, -1)
            context, _ = self.cross_attention[idx](query, key_value, key_value)
            tokens.append(self.layer_norms[idx](context + query))

        # Self-attention across the four modality tokens.
        stacked = torch.cat(tokens, dim=1)
        mixed, _ = self.self_attention(stacked, stacked, stacked)
        stacked = self.layer_norms[4](stacked + mixed)

        # Dynamic gating: one scalar weight per token, then a weighted sum.
        weights = self.gate(stacked.view(n, -1))
        gated = (stacked * weights.unsqueeze(-1)).sum(dim=1)

        return {name: self.heads[name](gated) for name in moral_foundations}

def prepare_labels(df):
    """Build the multi-label target matrix from virtue/vice annotation columns.

    A foundation is positive when either its virtue column or its vice column
    equals 1. ``Non_Moral`` is positive when none of the five foundations is.
    The Fairness vice column is looked up as ``Cheating``, then ``cheating``,
    and treated as all-negative when neither exists.

    Returns a float32 tensor with one row per input row and one column per
    entry of ``moral_foundations``, in that order.
    """
    df = df.reset_index(drop=True)

    # Resolve the optional Fairness vice column, with an all-False fallback.
    if 'Cheating' in df.columns:
        cheating = df['Cheating']
    elif 'cheating' in df.columns:
        cheating = df['cheating']
    else:
        cheating = pd.Series([False] * len(df), index=df.index)

    vice_column = {
        'Care': 'Harm',
        'Loyalty': 'Betrayal',
        'Authority': 'Subversion',
        'Purity': 'Degradation',
    }

    flags = {}
    for foundation in ('Care', 'Fairness', 'Loyalty', 'Authority', 'Purity'):
        virtue_hit = df[foundation] == 1
        if foundation == 'Fairness':
            vice_values = cheating
        else:
            vice_values = df[vice_column[foundation]]
        flags[foundation] = virtue_hit | (vice_values == 1)

    # Non_Moral is the complement of "any foundation present" on each row.
    any_moral = pd.DataFrame(flags).any(axis=1)
    flags['Non_Moral'] = (~any_moral).astype(int)

    columns = [flags[name].astype(int).to_numpy() for name in moral_foundations]
    return torch.tensor(np.column_stack(columns), dtype=torch.float32)



def create_temporal_sequences(df, seq_len=10):
    """Create sliding-window temporal sequences per county.

    Rows are sorted by GEOID and date, grouped by GEOID, and every run of
    ``seq_len`` consecutive rows inside a group becomes one sample. Groups
    with fewer than ``seq_len`` rows contribute nothing.

    Args:
        df: DataFrame holding the temporal, behavioral, label, ``text`` and
            ``GEOID`` columns referenced below.
        seq_len: number of rows per window.

    Returns:
        A ``(sequences, targets, other_features)`` tuple of equal-length lists:
        ``sequences[i]`` is a list of ``seq_len`` per-row feature lists,
        ``targets[i]`` is the label tensor of the window's last row, and
        ``other_features[i]`` is a dict with that row's ``text``, zero-padded
        ``geoid`` string and ``behavioral`` feature list.
    """
    temporal_cols = [
        'month', 'day', 'year', 'after_05/26/2020',
        'female_pct', 'lesscollege_pct', 'cvap',
        'clf_unemploy_pct', 'lesshs_pct', 'median_hh_inc',
        'urm_pct', 'ruralurban_cc', 'demgov', 'repgov',
        'net_dem_president_votes', 'net_dem_gov_votes', 'is_blue',
        'deaths', 'cases', 'cases_per_capita',
        'deaths_per_capita', 'cases_per_capita_discrete', 'mask_score',
    ]
    behavioral_cols = ['retweet_count', 'rt_discrete', 'for_sah',
                       'is_vivid', 'sentiment_score']

    # Reset the index so each group's index values are positions into the
    # label matrix computed from the same sorted frame.
    df_sorted = df.sort_values(['GEOID', 'year', 'month', 'day']).reset_index(drop=True)
    sequences, targets, other_features = [], [], []

    # Labels are computed once for the whole frame, and only when at least one
    # group is long enough to need them.
    all_labels = None

    for geoid, group in df_sorted.groupby('GEOID'):
        n_rows = len(group)
        if n_rows < seq_len:
            continue
        if all_labels is None:
            all_labels = prepare_labels(df_sorted)

        # Pull each group's columns out once instead of re-reading every row
        # for every window that contains it.
        positions = group.index.to_numpy()
        temporal_rows = group[temporal_cols].to_numpy().tolist()
        behavioral_rows = group[behavioral_cols].to_numpy().tolist()
        texts = group['text'].tolist()
        geoid_str = str(geoid).zfill(5)

        for start in range(n_rows - seq_len + 1):
            end = start + seq_len
            last = end - 1

            # Copy the row lists so windows do not alias each other's rows.
            sequences.append([list(row) for row in temporal_rows[start:end]])
            # Label of the last timestep in the window.
            targets.append(all_labels[int(positions[last])].clone())
            other_features.append({
                'text': texts[last],
                'geoid': geoid_str,
                'behavioral': behavioral_rows[last],
            })

    return sequences, targets, other_features

def prepare_dataset(csv_path, geojson_path, county_centroids_path, seq_len=10):
    """Load the raw files and build train/validation/test tensors.

    Args:
        csv_path: CSV file passed to ``create_temporal_sequences``.
        geojson_path: JSON file whose top-level keys are checked for GEOID
            membership.
        county_centroids_path: JSON file mapping GEOID to a centroid pair.
        seq_len: window length for the temporal sequences.

    Returns:
        ``(datasets, feature_info)``. ``datasets`` maps ``'train'``,
        ``'validation'`` and ``'test'`` (70/15/15 split, ``random_state=42``)
        to dicts with ``text_data``, ``spatial_features``,
        ``temporal_sequences``, ``behavioral_features`` and ``moral_targets``.
        ``feature_info`` holds the feature dimensions and sequence length.

    Raises:
        ValueError: if no temporal sequence could be built from the CSV.
    """
    # Load data
    df = pd.read_csv(csv_path)
    with open(geojson_path, encoding='utf-8') as f:
        boundaries = json.load(f)
    with open(county_centroids_path, encoding='utf-8') as f:
        centroids = json.load(f)

    # Create sequences
    sequences, targets, other_features = create_temporal_sequences(df, seq_len)
    if not sequences:
        raise ValueError(
            f"No temporal sequences of length {seq_len} could be built from {csv_path}"
        )

    # Bounding box used to normalise centroids to roughly [0, 1].
    lat_min, lat_max = 24, 49
    lon_min, lon_max = -125, -66

    # Process features
    data = {'text': [], 'spatial': [], 'temporal': [], 'behavioral': [], 'targets': []}

    for seq, target, other in zip(sequences, targets, other_features):
        data['text'].append(other['text'])
        data['temporal'].append(seq)
        data['behavioral'].append(other['behavioral'])
        data['targets'].append(target)

        # Extract spatial features
        geoid = other['geoid']
        if geoid in boundaries and geoid in centroids:
            # Simplified spatial feature extraction.
            # NOTE(review): assumes centroid is [lat, lon] -- confirm against the file.
            centroid = centroids[geoid]
            lat_norm = (centroid[0] - lat_min) / (lat_max - lat_min)
            # The longitude span is lon_max - lon_min = 59; dividing by
            # (66 - 125) instead would flip the sign of the feature.
            lon_norm = (centroid[1] - lon_min) / (lon_max - lon_min)
            spatial_feat = [lat_norm, lon_norm, 0.5, 0.5, 0.5, 0.5]  # Simplified features
        else:
            spatial_feat = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]  # Default

        data['spatial'].append(spatial_feat)

    # Convert to tensors
    spatial_tensor = torch.tensor(data['spatial'], dtype=torch.float32)
    temporal_tensor = torch.tensor(data['temporal'], dtype=torch.float32)
    behavioral_tensor = torch.tensor(data['behavioral'], dtype=torch.float32)
    targets_tensor = torch.stack(data['targets'])

    # Create splits before scaling so normalisation statistics come from the
    # training portion only.
    # NOTE(review): windows from the same county overlap, so a random split
    # can still place near-identical samples in different splits.
    indices = list(range(len(data['text'])))
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

    # Normalize features: fit on train rows, transform everything.
    spatial_np = spatial_tensor.numpy()
    scaler_spatial = StandardScaler().fit(spatial_np[train_idx])
    spatial_tensor = torch.tensor(scaler_spatial.transform(spatial_np), dtype=torch.float32)

    behavioral_np = behavioral_tensor.numpy()
    scaler_behavioral = StandardScaler().fit(behavioral_np[train_idx])
    behavioral_tensor = torch.tensor(scaler_behavioral.transform(behavioral_np), dtype=torch.float32)

    # Normalize temporal: flatten (sample, time) into rows so each feature
    # column is scaled across all timesteps, then restore the 3-D shape.
    temp_shape = tuple(temporal_tensor.shape)
    n_temporal_features = temp_shape[2]
    temporal_np = temporal_tensor.numpy()
    scaler_temporal = StandardScaler().fit(
        temporal_np[train_idx].reshape(-1, n_temporal_features)
    )
    temp_scaled = scaler_temporal.transform(temporal_np.reshape(-1, n_temporal_features))
    temporal_tensor = torch.tensor(temp_scaled.reshape(temp_shape), dtype=torch.float32)

    datasets = {}
    for split, idx in [('train', train_idx), ('validation', val_idx), ('test', test_idx)]:
        datasets[split] = {
            'text_data': [data['text'][i] for i in idx],
            'spatial_features': spatial_tensor[idx],
            'temporal_sequences': temporal_tensor[idx],
            'behavioral_features': behavioral_tensor[idx],
            'moral_targets': targets_tensor[idx]
        }

    feature_info = {
        'spatial_feature_dim': spatial_tensor.shape[1],
        'temporal_feature_dim': temporal_tensor.shape[2],
        'temporal_sequence_length': seq_len,
        'behavioral_feature_dim': behavioral_tensor.shape[1]
    }

    return datasets, feature_info

def create_dataloader(dataset_dict, batch_size=16, shuffle=True):
    """Wrap one split of the prepared dataset in a ``DataLoader``.

    Each sample is a ``(text, spatial, temporal, behavioral, target)`` tuple.
    Batches are dicts with keys ``text_data`` (list of str),
    ``spatial_features``, ``temporal_sequences``, ``behavioral_features`` and
    ``targets`` (stacked tensors).
    """
    field_order = (
        'text_data',
        'spatial_features',
        'temporal_sequences',
        'behavioral_features',
        'moral_targets',
    )

    def collate_fn(batch):
        # Transpose the list of sample tuples into one tuple per field.
        texts, spatial, temporal, behavioral, targets = zip(*batch)
        return {
            'text_data': list(texts),
            'spatial_features': torch.stack(spatial),
            'temporal_sequences': torch.stack(temporal),
            'behavioral_features': torch.stack(behavioral),
            'targets': torch.stack(targets),
        }

    # The text list defines how many samples there are.
    n_samples = len(dataset_dict['text_data'])
    samples = [
        tuple(dataset_dict[key][i] for key in field_order)
        for i in range(n_samples)
    ]

    return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

def compute_loss(predictions, targets):
    """Mean binary cross-entropy over all moral foundations.

    ``predictions`` maps each foundation name to a probability tensor;
    column ``i`` of ``targets`` (kept 2-D by slicing) is the label for the
    ``i``-th entry of ``moral_foundations``.
    """
    per_foundation = [
        F.binary_cross_entropy(predictions[name], targets[:, col:col + 1])
        for col, name in enumerate(moral_foundations)
    ]
    return sum(per_foundation) / len(moral_foundations)

def train_models(datasets, feature_info, device, epochs=30):
    """Train the four single-modality models, then the fusion model jointly.

    Stage 1 trains each modality model alone on its own prediction heads.
    Stage 2 optimises all five models together through the fusion heads,
    using each modality model's ``get_features`` output as fusion input.

    Returns ``(text_model, spatial_model, temporal_model, behavioral_model,
    fusion_model)``.
    """
    train_loader = create_dataloader(datasets['train'], batch_size=16, shuffle=True)

    def on_device(raw_batch):
        # Move tensors to the target device; leave the text list untouched.
        moved = {}
        for key, value in raw_batch.items():
            moved[key] = value.to(device) if isinstance(value, torch.Tensor) else value
        return moved

    # Initialize models
    text_model = TextModelRoBERTaFrozen().to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim']).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length'],
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim']).to(device)
    fusion_model = HeterogeneousFusion().to(device)

    models = [text_model, spatial_model, temporal_model, behavioral_model, fusion_model]

    # Which batch entry feeds each single-modality model, in the same order
    # as the first four entries of `models`.
    batch_key = {
        'text': 'text_data',
        'spatial': 'spatial_features',
        'temporal': 'temporal_sequences',
        'behavioral': 'behavioral_features',
    }

    # Train individual models first
    for name, model in zip(batch_key, models):
        optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

        for epoch in range(epochs):
            model.train()
            total_loss = 0

            for batch in map(on_device, train_loader):
                preds = model(batch[batch_key[name]])
                loss = compute_loss(preds, batch['targets'])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if epoch % 10 == 0:
                print(f"{name.capitalize()} model epoch {epoch}, loss: {total_loss/len(train_loader):.4f}")

    # Train fusion model: a single optimizer over every model's parameters.
    all_params = [param for model in models for param in model.parameters()]
    optimizer_fusion = AdamW(all_params, lr=5e-5, weight_decay=1e-4)

    for epoch in range(epochs):
        for model in models:
            model.train()

        total_loss = 0
        for batch in map(on_device, train_loader):
            # Extract features
            text_feat = text_model.get_features(batch['text_data'])
            spatial_feat = spatial_model.get_features(batch['spatial_features'])
            temporal_feat = temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = behavioral_model.get_features(batch['behavioral_features'])

            # Fusion
            fusion_preds = fusion_model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
            loss = compute_loss(fusion_preds, batch['targets'])

            optimizer_fusion.zero_grad()
            loss.backward()
            optimizer_fusion.step()
            total_loss += loss.item()

        if epoch % 5 == 0:
            print(f"Fusion epoch {epoch}, loss: {total_loss/len(train_loader):.4f}")

    return text_model, spatial_model, temporal_model, behavioral_model, fusion_model

def evaluate_model(model, data_loader, device, model_type=None, feature_models=None):
    """Evaluate model performance with classification report.

    Args:
        model: the model to evaluate.
        data_loader: iterable of batch dicts as produced by
            ``create_dataloader``.
        device: device tensors are moved to before the forward pass.
        model_type: ``'text'``, ``'spatial'``, ``'temporal'`` or ``'fusion'``;
            any other value (including ``None``) feeds
            ``batch['behavioral_features']`` to the model.
        feature_models: optional ``(text, spatial, temporal, behavioral)``
            models, used only when ``model_type == 'fusion'``. When given,
            fusion inputs are computed with each model's ``get_features`` from
            the raw batch entries. When omitted, the batch must already
            contain ``text_features`` and ``temporal_features`` entries, which
            ``create_dataloader`` does not produce.

    Returns:
        ``(metrics, reports)``: per-foundation dicts of accuracy, F1,
        precision, recall and AUC, and per-foundation classification-report
        strings. Predictions are thresholded at 0.5; AUC is reported as 0.0
        when the targets contain a single class.
    """
    model.eval()
    if feature_models is not None:
        for extractor in feature_models:
            extractor.eval()

    # Batch entry consumed by each single-modality model.
    input_keys = {
        'text': 'text_data',
        'spatial': 'spatial_features',
        'temporal': 'temporal_sequences',
    }

    all_preds = {f: [] for f in moral_foundations}
    all_targets = {f: [] for f in moral_foundations}

    with torch.no_grad():
        for batch in data_loader:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            if model_type == 'fusion':
                if feature_models is not None:
                    # Derive fusion inputs from the raw batch.
                    text_model, spatial_model, temporal_model, behavioral_model = feature_models
                    text_feat = text_model.get_features(batch['text_data'])
                    spatial_feat = spatial_model.get_features(batch['spatial_features'])
                    temporal_feat = temporal_model.get_features(batch['temporal_sequences'])
                    behavioral_feat = behavioral_model.get_features(batch['behavioral_features'])
                else:
                    text_feat = batch['text_features']  # Pre-extracted
                    spatial_feat = batch['spatial_features']
                    temporal_feat = batch['temporal_features']
                    behavioral_feat = batch['behavioral_features']
                preds = model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
            else:
                # Unknown or missing model types fall back to behavioral input.
                preds = model(batch[input_keys.get(model_type, 'behavioral_features')])

            targets = batch['targets']

            for i, foundation in enumerate(moral_foundations):
                pred = preds[foundation].cpu().numpy().flatten()
                target = targets[:, i].cpu().numpy()
                all_preds[foundation].extend(pred)
                all_targets[foundation].extend(target)

    # Compute metrics and classification reports
    metrics = {}
    reports = {}
    for foundation in moral_foundations:
        pred_binary = (np.array(all_preds[foundation]) > 0.5).astype(int)
        target_binary = np.array(all_targets[foundation]).astype(int)

        metrics[foundation] = {
            'accuracy': accuracy_score(target_binary, pred_binary),
            'f1': f1_score(target_binary, pred_binary, zero_division=0),
            'precision': precision_score(target_binary, pred_binary, zero_division=0),
            'recall': recall_score(target_binary, pred_binary, zero_division=0)
        }

        # ROC AUC is undefined when only one class is present in the targets.
        if len(np.unique(target_binary)) > 1:
            metrics[foundation]['auc'] = roc_auc_score(target_binary, all_preds[foundation])
        else:
            metrics[foundation]['auc'] = 0.0

        reports[foundation] = classification_report(target_binary, pred_binary, zero_division=0)

    return metrics, reports

# Main execution: build the dataset, train every model, then report test-set
# metrics for the four single-modality models.
if __name__ == "__main__":
    # Prepare dataset
    input_paths = {
        'csv_path': '/content/augmented_tweets.csv',
        'geojson_path': '/content/county_data_boarders.json',
        'county_centroids_path': '/content/county_centroids.json',
    }
    datasets, feature_info = prepare_dataset(seq_len=10, **input_paths)

    # Train models
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trained_models = train_models(datasets, feature_info, device)

    # Evaluate
    test_loader = create_dataloader(datasets['test'], batch_size=16, shuffle=False)

    # The last entry of trained_models is the fusion model, which is not
    # evaluated here.
    model_names = ['text', 'spatial', 'temporal', 'behavioral']
    for name, model in zip(model_names, trained_models[:-1]):
        metrics, reports = evaluate_model(model, test_loader, device, name)

        print(f"\n{name.upper()} MODEL RESULTS:")
        for foundation in moral_foundations:
            scores = metrics[foundation]
            print(f"  {foundation}: F1={scores['f1']:.3f}, Acc={scores['accuracy']:.3f}")

        print(f"\n{name.upper()} MODEL CLASSIFICATION REPORTS:")
        for foundation in moral_foundations:
            print(f"\n{foundation}:")
            print(reports[foundation])

    print("\n✅ Training and evaluation completed!")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Text model epoch 0, loss: 0.5471
Text model epoch 10, loss: 0.2996
Text model epoch 20, loss: 0.2851
Spatial model epoch 0, loss: 0.6682
Spatial model epoch 10, loss: 0.2920
Spatial model epoch 20, loss: 0.2825
Temporal model epoch 0, loss: 0.6218
Temporal model epoch 10, loss: 0.2749
Temporal model epoch 20, loss: 0.2463
Behavioral model epoch 0, loss: 0.6082
Behavioral model epoch 10, loss: 0.2519
Behavioral model epoch 20, loss: 0.2357
Fusion epoch 0, loss: 0.5124
Fusion epoch 5, loss: 0.2383
Fusion epoch 10, loss: 0.2005
Fusion epoch 15, loss: 0.1973
Fusion epoch 20, loss: 0.1638
Fusion epoch 25, loss: 0.1623

TEXT MODEL RESULTS:
  Care: F1=0.870, Acc=0.786
  Fairness: F1=0.000, Acc=1.000
  Loyalty: F1=0.000, Acc=0.839
  Authority: F1=0.000, Acc=0.929
  Purity: F1=0.000, Acc=0.982
  Non_Moral: F1=0.000, Acc=0.911

TEXT MODEL CLASSIFICATION REPORTS:

Care:
              precision    recall  f1-score   support

           0       0.50      0.33      0.40        12
           1       

# Motiv

In [None]:
import torch
import torch.nn as nn
from torch.nn import Sequential, Linear, ReLU, Dropout, Sigmoid, ModuleDict, LayerNorm, LSTM, GELU, MultiheadAttention
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score
from sklearn.metrics import classification_report
import pandas as pd
import numpy as np
import torch.nn.functional as F
import json
import math

# Updated moral foundations with virtue/vice combinations + Non-Moral.
# The order of this list is significant: prepare_labels() stacks the target
# columns in this order, and every model builds one prediction head per entry.
moral_foundations = ['Care', 'Fairness', 'Loyalty', 'Authority', 'Purity', 'Non_Moral']

class TextModelRoBERTaFrozen(nn.Module):
    """Text branch: frozen ``roberta-base`` encoder plus trainable layers.

    The first-token (CLS position) embedding of the encoder is passed through
    a small trainable MLP (``processor``) that yields 256-d features, and one
    sigmoid head per entry of ``moral_foundations`` turns those features into
    a probability.
    """

    def __init__(self):
        super().__init__()
        # Load RoBERTa model
        self.tokenizer = AutoTokenizer.from_pretrained('roberta-base')
        self.roberta = AutoModel.from_pretrained('roberta-base')

        # Freeze RoBERTa completely
        for param in self.roberta.parameters():
            param.requires_grad = False
        self.roberta.eval()

        # Trainable processing layers
        self.processor = Sequential(
            Linear(768, 512), LayerNorm(512), GELU(), Dropout(0.1),
            Linear(512, 256), LayerNorm(256), GELU(), Dropout(0.1)
        )

        # Prediction heads
        self.heads = ModuleDict({
            foundation: Sequential(Linear(256, 128), ReLU(), Dropout(0.2),
                                 Linear(128, 1), Sigmoid())
            for foundation in moral_foundations
        })

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen encoder in eval.

        ``nn.Module.train`` recursively flips every sub-module, which would
        re-enable dropout inside the frozen encoder and make its embeddings
        stochastic during training. The encoder is therefore forced back to
        eval mode after the normal mode switch.
        """
        super().train(mode)
        self.roberta.eval()
        return self

    def _encode(self, texts):
        """Tokenize ``texts`` and return the encoder's first-token embedding."""
        inputs = self.tokenizer(texts, padding=True, truncation=True,
                                max_length=512, return_tensors='pt')
        # Tokenizer output lives on CPU; move it to wherever the model is.
        device = next(self.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # No gradients are needed through the frozen encoder.
        with torch.no_grad():
            outputs = self.roberta(**inputs)
        return outputs.last_hidden_state[:, 0]  # CLS token

    def get_features(self, texts):
        """Return the 256-d processed text features for a batch of strings."""
        return self.processor(self._encode(texts))

    def forward(self, texts):
        """Return ``{foundation: probability tensor}`` for a batch of strings."""
        features = self.get_features(texts)
        return {f: self.heads[f](features) for f in moral_foundations}

class SpatialModel(nn.Module):
    """Spatial branch: a two-layer MLP encoder with one sigmoid head per foundation."""

    def __init__(self, input_dim):
        super().__init__()
        # input_dim -> 128 -> 256 feature vector
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 256),
        )

        # One small classifier per moral foundation, keyed by its name.
        per_foundation = {}
        for name in moral_foundations:
            per_foundation[name] = nn.Sequential(
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, 1),
                nn.Sigmoid(),
            )
        self.heads = nn.ModuleDict(per_foundation)

    def get_features(self, x):
        """Return the 256-d encoding of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return ``{foundation: probability tensor}`` for the input batch."""
        encoded = self.get_features(x)
        return {name: self.heads[name](encoded) for name in moral_foundations}

class TemporalModelLSTM(nn.Module):
    """Temporal branch: linear projection, 2-layer bidirectional LSTM, MLP.

    Args:
        input_dim: number of features per timestep.
        seq_len: accepted for interface compatibility; not used by this class.
        hidden_size: LSTM hidden size per direction.
    """

    def __init__(self, input_dim, seq_len=10, hidden_size=128):
        super().__init__()
        self.projection = Linear(input_dim, hidden_size)
        self.lstm = LSTM(hidden_size, hidden_size, 2, batch_first=True,
                        dropout=0.2, bidirectional=True)
        self.encoder = Sequential(
            Linear(hidden_size * 2, 256), LayerNorm(256),
            ReLU(), Dropout(0.2)
        )
        self.heads = ModuleDict({
            f: Sequential(Linear(256, 64), ReLU(), Linear(64, 1), Sigmoid())
            for f in moral_foundations
        })

    def get_features(self, x):
        """Return 256-d sequence features.

        Expects a batched input laid out as (batch, time, input_dim), since
        the LSTM is built with ``batch_first=True``.
        """
        proj = self.projection(x)
        _, (h_n, _) = self.lstm(proj)
        # Summarise the sequence with the final hidden state of each direction
        # of the top layer. Indexing the output at the last timestep instead
        # would give the backward direction after it has processed only a
        # single step. h_n is (num_layers * 2, batch, hidden); the last two
        # entries are the top layer's forward and backward states.
        summary = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.encoder(summary)

    def forward(self, x):
        """Return ``{foundation: probability tensor}`` for the input batch."""
        features = self.get_features(x)
        return {f: self.heads[f](features) for f in moral_foundations}

class BehavioralModel(nn.Module):
    """Behavioral branch: normalised two-layer MLP encoder with per-foundation heads."""

    def __init__(self, input_dim):
        super().__init__()
        # Two Linear -> LayerNorm -> ReLU -> Dropout stages: input_dim -> 128 -> 256
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

        # One small sigmoid classifier per moral foundation.
        per_foundation = {}
        for name in moral_foundations:
            per_foundation[name] = nn.Sequential(
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, 1),
                nn.Sigmoid(),
            )
        self.heads = nn.ModuleDict(per_foundation)

    def get_features(self, x):
        """Return the 256-d encoding of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return ``{foundation: probability tensor}`` for the input batch."""
        encoded = self.get_features(x)
        return {name: self.heads[name](encoded) for name in moral_foundations}

class HeterogeneousFusion(nn.Module):
    """Fuse four modality feature vectors with attention and a learned gate.

    Each modality gets its own learned query and cross-attention layer; the
    four attended tokens are mixed by one self-attention layer, weighted by a
    sigmoid gate computed from all four tokens, summed, and classified by one
    head per moral foundation.
    """

    def __init__(self, d_model=256, num_heads=8):
        super().__init__()
        self.d_model = d_model
        # One learned query vector per modality.
        self.modality_queries = nn.Parameter(torch.randn(4, d_model))

        cross_layers = []
        for _ in range(4):
            cross_layers.append(nn.MultiheadAttention(d_model, num_heads, batch_first=True))
        self.cross_attention = nn.ModuleList(cross_layers)

        self.self_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        # Norms 0-3 follow the per-modality cross-attention, norm 4 the self-attention.
        self.layer_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(5)])

        # Maps the four concatenated tokens to four gate values in (0, 1).
        self.gate = nn.Sequential(
            nn.Linear(d_model * 4, d_model),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model, 4),
            nn.Sigmoid(),
        )

        per_foundation = {}
        for name in moral_foundations:
            per_foundation[name] = nn.Sequential(
                nn.Linear(d_model, 128),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(128, 1),
                nn.Sigmoid(),
            )
        self.heads = nn.ModuleDict(per_foundation)

    def forward(self, text_feat, spatial_feat, temporal_feat, behavioral_feat):
        """Return ``{foundation: probability tensor}`` from the four feature tensors.

        Inputs are presumably (batch, d_model) each -- the code adds a length-1
        sequence axis to every one before attention.
        """
        n = text_feat.size(0)
        streams = (text_feat, spatial_feat, temporal_feat, behavioral_feat)

        # Cross-attention: modality query attends to that modality's single token,
        # followed by a residual connection and layer norm.
        tokens = []
        for idx, stream in enumerate(streams):
            key_value = stream.unsqueeze(1)
            query = self.modality_queries[idx:idx + 1].unsqueeze(0).expand(n, -1, -1)
            context, _ = self.cross_attention[idx](query, key_value, key_value)
            tokens.append(self.layer_norms[idx](context + query))

        # Self-attention across the four modality tokens.
        stacked = torch.cat(tokens, dim=1)
        mixed, _ = self.self_attention(stacked, stacked, stacked)
        stacked = self.layer_norms[4](stacked + mixed)

        # Dynamic gating: one scalar weight per token, then a weighted sum.
        weights = self.gate(stacked.view(n, -1))
        gated = (stacked * weights.unsqueeze(-1)).sum(dim=1)

        return {name: self.heads[name](gated) for name in moral_foundations}

def prepare_labels(df):
    """Build the multi-label target tensor from virtue/vice annotation columns.

    Each foundation is positive when either its virtue or its vice column
    equals 1. ``Non_Moral`` is positive when none of the five foundations is.
    The ``Cheating`` column is also looked up as ``cheating`` and treated as
    all-False when neither spelling exists.

    Returns:
        A float32 tensor with one row per input row and one column per entry
        of ``moral_foundations``, in that order.
    """
    frame = df.reset_index(drop=True)

    def is_one(series):
        return series == 1

    care = is_one(frame['Care']) | is_one(frame['Harm'])

    fairness_virtue = is_one(frame['Fairness'])
    no_cheating = pd.Series([False] * len(frame), index=frame.index)
    cheating = frame.get('Cheating', frame.get('cheating', no_cheating))
    fairness = fairness_virtue | is_one(cheating)

    loyalty = is_one(frame['Loyalty']) | is_one(frame['Betrayal'])
    authority = is_one(frame['Authority']) | is_one(frame['Subversion'])
    purity = is_one(frame['Purity']) | is_one(frame['Degradation'])

    labels = {
        'Care': care,
        'Fairness': fairness,
        'Loyalty': loyalty,
        'Authority': authority,
        'Purity': purity,
    }

    # A row is non-moral exactly when no foundation fired for it.
    any_moral = pd.DataFrame(labels).any(axis=1)
    labels['Non_Moral'] = (~any_moral).astype(int)

    columns = [labels[name].astype(int).values for name in moral_foundations]
    return torch.tensor(np.column_stack(columns), dtype=torch.float32)

# Columns that make up one timestep of a temporal sequence, in feature order.
_TEMPORAL_FEATURE_COLUMNS = (
    'month', 'day', 'year', 'after_05/26/2020',
    'female_pct', 'lesscollege_pct', 'cvap',
    'clf_unemploy_pct', 'lesshs_pct', 'median_hh_inc',
    'urm_pct', 'ruralurban_cc', 'demgov', 'repgov',
    'net_dem_president_votes', 'net_dem_gov_votes', 'is_blue',
    'deaths', 'cases', 'cases_per_capita',
    'deaths_per_capita', 'cases_per_capita_discrete', 'mask_score',
)

# Columns describing the final row of each window, in feature order.
_BEHAVIORAL_FEATURE_COLUMNS = (
    'retweet_count', 'rt_discrete', 'for_sah', 'is_vivid', 'sentiment_score',
)


def create_temporal_sequences(df, seq_len=10):
    """Create sliding-window temporal sequences per county.

    Rows are sorted by ``GEOID`` and date. For every county with at least
    ``seq_len`` rows, each run of ``seq_len`` consecutive rows becomes one
    sample: the temporal features of all rows in the window form the
    sequence, and the label, text, GEOID and behavioral features come from
    the window's last row.

    Per-county columns and labels are extracted once instead of re-reading
    every row of every window with ``iterrows``; the produced samples are the
    same, but the work no longer grows with ``seq_len`` Series constructions
    per window.

    Args:
        df: DataFrame holding the temporal, behavioral, label, ``text`` and
            ``GEOID`` columns.
        seq_len: Number of consecutive rows per window (must be >= 1).

    Returns:
        ``(sequences, targets, other_features)`` where ``sequences`` is a list
        of ``seq_len`` x 23 nested lists, ``targets`` a list of label tensors
        and ``other_features`` a list of dicts with keys ``text``, ``geoid``
        (zero-padded to 5 characters) and ``behavioral``.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError("seq_len must be at least 1")

    df_sorted = df.sort_values(['GEOID', 'year', 'month', 'day'])
    temporal_cols = list(_TEMPORAL_FEATURE_COLUMNS)
    behavioral_cols = list(_BEHAVIORAL_FEATURE_COLUMNS)
    sequences, targets, other_features = [], [], []

    for _, group in df_sorted.groupby('GEOID'):
        n_rows = len(group)
        if n_rows < seq_len:
            continue

        # Pull everything out of pandas once per county.
        temporal_rows = group[temporal_cols].to_numpy().tolist()
        behavioral_rows = group[behavioral_cols].to_numpy().tolist()
        texts = group['text'].tolist()
        geoids = group['GEOID'].tolist()
        group_labels = prepare_labels(group)

        for start in range(n_rows - seq_len + 1):
            last = start + seq_len - 1

            # Copy the rows so overlapping windows do not share list objects.
            sequences.append([list(row) for row in temporal_rows[start:last + 1]])
            targets.append(group_labels[last].clone())
            other_features.append({
                'text': texts[last],
                'geoid': str(geoids[last]).zfill(5),
                'behavioral': list(behavioral_rows[last]),
            })

    return sequences, targets, other_features

def _county_spatial_features(geoid, boundaries, centroids):
    """Return the 6-value spatial feature vector for one county."""
    # presumably both JSON files are dicts keyed by 5-character GEOID;
    # verify against the files, since any miss falls back to the default.
    if geoid in boundaries and geoid in centroids:
        centroid = centroids[geoid]
        # (x - min) / (max - min) with latitude in [24, 49] and longitude in
        # [-125, -66]. The original divided the longitude by (66 - 125),
        # which flipped the sign of the result.
        lat_norm = (centroid[0] - 24) / (49 - 24)
        lon_norm = (centroid[1] + 125) / (125 - 66)
        return [lat_norm, lon_norm, 0.5, 0.5, 0.5, 0.5]  # Simplified features
    return [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]  # Default


def _standardize(values, train_idx):
    """Fit a StandardScaler on the training rows only and scale every row."""
    scaler = StandardScaler().fit(values[train_idx])
    return scaler.transform(values)


def prepare_dataset(csv_path, geojson_path, county_centroids_path, seq_len=10):
    """Load the raw files and build train/validation/test feature tensors.

    The samples are split 70/15/15 with a fixed seed. Feature scalers are
    fitted on the training split only and then applied to all splits, so
    validation/test statistics do not leak into preprocessing.

    Args:
        csv_path: CSV file read into the main DataFrame.
        geojson_path: JSON file used for the county boundary lookup.
        county_centroids_path: JSON file used for the county centroid lookup.
        seq_len: Window length passed to ``create_temporal_sequences``.

    Returns:
        ``(datasets, feature_info)``. ``datasets`` maps ``'train'``,
        ``'validation'`` and ``'test'`` to dicts with keys ``text_data``,
        ``spatial_features``, ``temporal_sequences``, ``behavioral_features``
        and ``moral_targets``. ``feature_info`` holds the feature dimensions
        and the sequence length.

    Raises:
        ValueError: If no county has at least ``seq_len`` rows.
    """
    df = pd.read_csv(csv_path)
    with open(geojson_path) as f:
        boundaries = json.load(f)
    with open(county_centroids_path) as f:
        centroids = json.load(f)

    sequences, targets, other_features = create_temporal_sequences(df, seq_len)
    if not sequences:
        raise ValueError(
            f"No temporal sequences could be built: no county has at least {seq_len} rows"
        )

    texts = [other['text'] for other in other_features]
    spatial = np.asarray(
        [_county_spatial_features(other['geoid'], boundaries, centroids)
         for other in other_features],
        dtype=np.float32,
    )
    behavioral = np.asarray(
        [other['behavioral'] for other in other_features], dtype=np.float32
    )
    temporal = np.asarray(sequences, dtype=np.float32)
    targets_tensor = torch.stack(targets)

    # Split before scaling so the scalers only ever see training data.
    # NOTE(review): windows from the same county overlap, so this random split
    # can place near-identical sequences in different splits; consider
    # splitting by county if that matters for the reported metrics.
    indices = list(range(len(texts)))
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

    spatial = _standardize(spatial, train_idx)
    behavioral = _standardize(behavioral, train_idx)

    # Temporal features are scaled per feature across all timesteps.
    temporal_shape = temporal.shape
    n_temporal_features = temporal_shape[2]
    scaler_temporal = StandardScaler().fit(
        temporal[train_idx].reshape(-1, n_temporal_features)
    )
    temporal = scaler_temporal.transform(
        temporal.reshape(-1, n_temporal_features)
    ).reshape(temporal_shape)

    spatial_tensor = torch.tensor(spatial, dtype=torch.float32)
    behavioral_tensor = torch.tensor(behavioral, dtype=torch.float32)
    temporal_tensor = torch.tensor(temporal, dtype=torch.float32)

    datasets = {}
    for split, idx in [('train', train_idx), ('validation', val_idx), ('test', test_idx)]:
        datasets[split] = {
            'text_data': [texts[i] for i in idx],
            'spatial_features': spatial_tensor[idx],
            'temporal_sequences': temporal_tensor[idx],
            'behavioral_features': behavioral_tensor[idx],
            'moral_targets': targets_tensor[idx]
        }

    feature_info = {
        'spatial_feature_dim': spatial_tensor.shape[1],
        'temporal_feature_dim': temporal_tensor.shape[2],
        'temporal_sequence_length': seq_len,
        'behavioral_feature_dim': behavioral_tensor.shape[1]
    }

    return datasets, feature_info

def create_dataloader(dataset_dict, batch_size=16, shuffle=True):
    """Wrap one split dict in a DataLoader that yields batch dicts.

    Each batch is a dict with ``text_data`` (a list of the raw text items) and
    the stacked tensors ``spatial_features``, ``temporal_sequences``,
    ``behavioral_features`` and ``targets``.
    """
    source_keys = (
        'text_data',
        'spatial_features',
        'temporal_sequences',
        'behavioral_features',
        'moral_targets',
    )

    def collate_fn(batch):
        # Transpose the list of per-sample tuples into per-field tuples.
        texts, spatial, temporal, behavioral, targets = zip(*batch)
        return {
            'text_data': list(texts),
            'spatial_features': torch.stack(list(spatial)),
            'temporal_sequences': torch.stack(list(temporal)),
            'behavioral_features': torch.stack(list(behavioral)),
            'targets': torch.stack(list(targets)),
        }

    n_samples = len(dataset_dict['text_data'])
    samples = []
    for position in range(n_samples):
        samples.append(tuple(dataset_dict[key][position] for key in source_keys))

    return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

def compute_loss(predictions, targets):
    """Return the binary cross-entropy averaged over all moral foundations.

    ``predictions`` maps each foundation name to its predicted probabilities;
    column ``i`` of ``targets`` holds the labels for the ``i``-th entry of
    ``moral_foundations``.
    """
    per_foundation = [
        F.binary_cross_entropy(predictions[name], targets[:, column:column + 1])
        for column, name in enumerate(moral_foundations)
    ]
    return sum(per_foundation) / len(moral_foundations)

def train_models(datasets, feature_info, device, epochs=30):
    """Train the four unimodal models, then the fusion model on top of them.

    Stage one trains each unimodal model on its own input with its own
    optimizer. Stage two optimises the parameters of all five models together
    on the fusion loss, using the unimodal models only through
    ``get_features``.

    NOTE(review): the stage-two loss does not pass through the unimodal
    prediction heads, so the encoders keep changing while those heads do not;
    confirm this is intended before comparing unimodal results obtained after
    stage two.

    Returns:
        ``(text_model, spatial_model, temporal_model, behavioral_model,
        fusion_model)``.
    """
    def to_device(raw_batch):
        # Tensors move to the target device; the list of texts is left as is.
        return {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in raw_batch.items()
        }

    train_loader = create_dataloader(datasets['train'], batch_size=16, shuffle=True)

    # Initialize models
    text_model = TextModelRoBERTaFrozen().to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim']).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length'],
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim']).to(device)
    fusion_model = HeterogeneousFusion().to(device)

    models = [text_model, spatial_model, temporal_model, behavioral_model, fusion_model]

    # Stage 1: each unimodal model with the batch field it consumes.
    unimodal_plan = (
        ('text', text_model, 'text_data'),
        ('spatial', spatial_model, 'spatial_features'),
        ('temporal', temporal_model, 'temporal_sequences'),
        ('behavioral', behavioral_model, 'behavioral_features'),
    )
    for name, model, input_key in unimodal_plan:
        optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

        for epoch in range(epochs):
            model.train()
            total_loss = 0

            for raw_batch in train_loader:
                batch = to_device(raw_batch)
                preds = model(batch[input_key])
                loss = compute_loss(preds, batch['targets'])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if epoch % 10 == 0:
                print(f"{name.capitalize()} model epoch {epoch}, loss: {total_loss/len(train_loader):.4f}")

    # Stage 2: joint optimisation of every model's parameters on the fusion loss.
    all_params = [param for model in models for param in model.parameters()]
    optimizer_fusion = AdamW(all_params, lr=5e-5, weight_decay=1e-4)

    for epoch in range(epochs):
        for model in models:
            model.train()

        total_loss = 0
        for raw_batch in train_loader:
            batch = to_device(raw_batch)

            fusion_preds = fusion_model(
                text_model.get_features(batch['text_data']),
                spatial_model.get_features(batch['spatial_features']),
                temporal_model.get_features(batch['temporal_sequences']),
                behavioral_model.get_features(batch['behavioral_features']),
            )
            loss = compute_loss(fusion_preds, batch['targets'])

            optimizer_fusion.zero_grad()
            loss.backward()
            optimizer_fusion.step()
            total_loss += loss.item()

        if epoch % 5 == 0:
            print(f"Fusion epoch {epoch}, loss: {total_loss/len(train_loader):.4f}")

    return text_model, spatial_model, temporal_model, behavioral_model, fusion_model

def evaluate_model_comprehensive(model, data_loader, device, model_type=None, text_model=None, spatial_model=None, temporal_model=None, behavioral_model=None):
    """Evaluate a unimodal or fusion model and collect per-foundation metrics.

    Args:
        model: The model to evaluate. With ``model_type='fusion'`` this is the
            fusion module and the four encoder models must be supplied.
        data_loader: Loader yielding the batch dicts built by
            ``create_dataloader``.
        device: Device the batch tensors are moved to.
        model_type: ``'text'``, ``'spatial'``, ``'temporal'`` or ``'fusion'``;
            any other value (including ``None``) is evaluated on the
            behavioral features.
        text_model, spatial_model, temporal_model, behavioral_model: Encoder
            models used to extract features in fusion mode.

    Returns:
        ``(metrics, reports, avg_metrics)``: per-foundation metric dicts,
        per-foundation classification reports, and the metrics averaged over
        foundations plus the mean batch loss under ``'loss'``.

    Raises:
        ValueError: If fusion mode is requested without all four encoders, or
            if ``data_loader`` yields no batches.
    """
    is_fusion = model_type == 'fusion'
    model.eval()

    if is_fusion:
        encoders = (text_model, spatial_model, temporal_model, behavioral_model)
        if any(encoder is None for encoder in encoders):
            raise ValueError(
                "model_type='fusion' requires text_model, spatial_model, "
                "temporal_model and behavioral_model"
            )
        # The encoders feed the fusion module, so they must not run with
        # dropout active either; previously only `model` was set to eval.
        for encoder in encoders:
            encoder.eval()

    # Batch field consumed by each unimodal model; anything else falls back to
    # the behavioral features, as before.
    input_keys = {
        'text': 'text_data',
        'spatial': 'spatial_features',
        'temporal': 'temporal_sequences',
    }

    all_preds = {f: [] for f in moral_foundations}
    all_targets = {f: [] for f in moral_foundations}
    total_loss = 0
    num_batches = 0

    with torch.no_grad():
        for batch in data_loader:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}

            if is_fusion:
                preds = model(
                    text_model.get_features(batch['text_data']),
                    spatial_model.get_features(batch['spatial_features']),
                    temporal_model.get_features(batch['temporal_sequences']),
                    behavioral_model.get_features(batch['behavioral_features']),
                )
            else:
                preds = model(batch[input_keys.get(model_type, 'behavioral_features')])

            targets = batch['targets']

            total_loss += compute_loss(preds, targets).item()
            num_batches += 1

            for i, foundation in enumerate(moral_foundations):
                all_preds[foundation].extend(preds[foundation].cpu().numpy().flatten())
                all_targets[foundation].extend(targets[:, i].cpu().numpy())

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches; nothing to evaluate")

    # Compute metrics and classification reports
    metrics = {}
    reports = {}
    avg_metrics = {'accuracy': 0, 'f1': 0, 'precision': 0, 'recall': 0, 'auc': 0}

    for foundation in moral_foundations:
        scores = np.array(all_preds[foundation])
        pred_binary = (scores > 0.5).astype(int)
        target_binary = np.array(all_targets[foundation]).astype(int)

        metrics[foundation] = {
            'accuracy': accuracy_score(target_binary, pred_binary),
            'f1': f1_score(target_binary, pred_binary, zero_division=0),
            'precision': precision_score(target_binary, pred_binary, zero_division=0),
            'recall': recall_score(target_binary, pred_binary, zero_division=0)
        }

        # ROC-AUC is undefined when only one class is present; 0.0 is recorded
        # in that case and does enter the average below.
        if len(np.unique(target_binary)) > 1:
            metrics[foundation]['auc'] = roc_auc_score(target_binary, scores)
        else:
            metrics[foundation]['auc'] = 0.0

        reports[foundation] = classification_report(target_binary, pred_binary, zero_division=0)

        for metric in avg_metrics:
            avg_metrics[metric] += metrics[foundation][metric]

    for metric in avg_metrics:
        avg_metrics[metric] /= len(moral_foundations)

    avg_metrics['loss'] = total_loss / num_batches

    return metrics, reports, avg_metrics

def extract_fusion_weights(fusion_model):
    """Return the fusion gate's output for a random probe input.

    The gate is evaluated once on a standard-normal vector of size
    ``d_model * 4``, so the result reflects the gate's response to that probe
    rather than to real data. The gate is run in eval mode (its previous
    train/eval state is restored afterwards) and under ``torch.no_grad()``.
    The original converted a gradient-tracking tensor with ``.numpy()``, which
    raises, and a bare ``except`` then silently returned the equal-weight
    defaults on every call.

    Returns:
        Dict with keys ``'Text'``, ``'Spatial'``, ``'Temporal'`` and
        ``'Behavioral'`` mapped to floats. Equal weights of 0.25 are returned,
        with a warning, if the gate cannot be evaluated.
    """
    import warnings

    fallback = {'Text': 0.25, 'Spatial': 0.25, 'Temporal': 0.25, 'Behavioral': 0.25}

    try:
        gate = fusion_model.gate
        device = next(fusion_model.parameters()).device
        dummy_input = torch.randn(1, fusion_model.d_model * 4).to(device)
    except (AttributeError, StopIteration) as exc:
        warnings.warn(f"Could not probe fusion gate ({exc!r}); using equal weights")
        return fallback

    was_training = gate.training
    gate.eval()  # keep dropout from perturbing the probe
    try:
        with torch.no_grad():
            weights = gate(dummy_input).squeeze().cpu().numpy()
    except RuntimeError as exc:
        warnings.warn(f"Fusion gate evaluation failed ({exc!r}); using equal weights")
        return fallback
    finally:
        gate.train(was_training)

    return {
        'Text': float(weights[0]),
        'Spatial': float(weights[1]),
        'Temporal': float(weights[2]),
        'Behavioral': float(weights[3])
    }

def print_comprehensive_results(all_metrics, all_avg_metrics, fusion_weights):
    """Print averaged metrics per model, gate weights and a per-foundation F1 table.

    ``all_avg_metrics`` and ``all_metrics`` are keyed by model name
    (``'text'``, ``'spatial'``, ``'temporal'``, ``'behavioral'``,
    ``'fusion'``); models missing from either dict are skipped.
    """
    unimodal_keys = ('text', 'spatial', 'temporal', 'behavioral')
    titled_models = (
        ('TEXT', 'text'),
        ('SPATIAL', 'spatial'),
        ('TEMPORAL', 'temporal'),
        ('BEHAVIORAL', 'behavioral'),
        ('FUSION', 'fusion'),
    )
    # Line prefixes carry their own padding so the values line up.
    summary_lines = (
        ("  Average Loss:      ", 'loss'),
        ("  Average Accuracy:  ", 'accuracy'),
        ("  Average F1 Score:  ", 'f1'),
        ("  Average Precision: ", 'precision'),
        ("  Average Recall:    ", 'recall'),
        ("  Average AUC-ROC:   ", 'auc'),
    )

    print("COMPREHENSIVE MULTIMODAL EVALUATION RESULTS")
    print("=" * 80)
    print()

    # Averaged metrics, one section per evaluated model
    for title, key in titled_models:
        if key not in all_avg_metrics:
            continue
        averages = all_avg_metrics[key]
        print(f"{title} MODEL RESULTS:")
        print("-" * 50)
        for prefix, metric_name in summary_lines:
            print(f"{prefix}{averages[metric_name]:.4f}")
        print()

    # Gate weights are only shown when the fusion model was evaluated
    if 'fusion' in all_avg_metrics and fusion_weights:
        print("  Learned Modality Weights:")
        for modality, weight in fusion_weights.items():
            print(f"    {modality:<11}: {weight:.3f}")
        print()

    # Per-foundation F1 table
    has_fusion = 'fusion' in all_metrics
    print("FOUNDATION-SPECIFIC COMPARISON:")
    print("-" * 80)
    header = "Foundation   Text F1  Spatial F1 Temporal F1 Behavioral F1"
    if has_fusion:
        header += " Fusion F1"
    print(header)
    print("-" * 80)

    for foundation in moral_foundations:
        cells = [f"{foundation:<12}"]
        for key in unimodal_keys:
            if key in all_metrics:
                f1_value = all_metrics[key][foundation]['f1']
                cells.append(f" {f1_value:.3f}     ")
        if has_fusion:
            f1_value = all_metrics['fusion'][foundation]['f1']
            cells.append(f" {f1_value:.3f}    ")
        print("".join(cells))

# Main execution
if __name__ == "__main__":
    # Build the train/validation/test splits from the raw files
    datasets, feature_info = prepare_dataset(
        '/content/augmented_tweets.csv',
        '/content/county_data_boarders.json',
        '/content/county_centroids.json',
        seq_len=10,
    )

    # Train on GPU when one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    (text_model,
     spatial_model,
     temporal_model,
     behavioral_model,
     fusion_model) = train_models(datasets, feature_info, device)

    # Evaluation always runs on the held-out test split, in a fixed order
    test_loader = create_dataloader(datasets['test'], batch_size=16, shuffle=False)

    all_metrics, all_avg_metrics, all_reports = {}, {}, {}

    model_list = [
        (text_model, 'text'),
        (spatial_model, 'spatial'),
        (temporal_model, 'temporal'),
        (behavioral_model, 'behavioral'),
    ]

    # Unimodal models first
    for model, name in model_list:
        metrics, reports, avg_metrics = evaluate_model_comprehensive(
            model, test_loader, device, name
        )
        all_metrics[name] = metrics
        all_avg_metrics[name] = avg_metrics
        all_reports[name] = reports

    # The fusion model needs the four encoders to extract its inputs
    metrics, reports, avg_metrics = evaluate_model_comprehensive(
        fusion_model, test_loader, device, 'fusion',
        text_model=text_model,
        spatial_model=spatial_model,
        temporal_model=temporal_model,
        behavioral_model=behavioral_model,
    )
    all_metrics['fusion'] = metrics
    all_avg_metrics['fusion'] = avg_metrics
    all_reports['fusion'] = reports

    # Gate weights shown next to the fusion results
    fusion_weights = extract_fusion_weights(fusion_model)

    print_comprehensive_results(all_metrics, all_avg_metrics, fusion_weights)

    print("\n✅ Training and evaluation completed!")


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Text model epoch 0, loss: 0.5489
Text model epoch 10, loss: 0.2857
Text model epoch 20, loss: 0.3004
Spatial model epoch 0, loss: 0.6586
Spatial model epoch 10, loss: 0.2995
Spatial model epoch 20, loss: 0.3076
Temporal model epoch 0, loss: 0.6251
Temporal model epoch 10, loss: 0.2768
Temporal model epoch 20, loss: 0.2417
Behavioral model epoch 0, loss: 0.6007
Behavioral model epoch 10, loss: 0.2623
Behavioral model epoch 20, loss: 0.2466
Fusion epoch 0, loss: 0.5482
Fusion epoch 5, loss: 0.2658
Fusion epoch 10, loss: 0.1994
Fusion epoch 15, loss: 0.1912
Fusion epoch 20, loss: 0.1634
Fusion epoch 25, loss: 0.1457
COMPREHENSIVE MULTIMODAL EVALUATION RESULTS

TEXT MODEL RESULTS:
--------------------------------------------------
  Average Loss:      0.3725
  Average Accuracy:  0.8780
  Average F1 Score:  0.1364
  Average Precision: 0.1364
  Average Recall:    0.1364
  Average AUC-ROC:   0.6079

SPATIAL MODEL RESULTS:
--------------------------------------------------
  Average Loss:     

# MLSMOTE

In [None]:
import torch
import torch.nn as nn
from torch.nn import Sequential, Linear, ReLU, Dropout, Sigmoid, ModuleDict, LayerNorm, LSTM, GELU, MultiheadAttention
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, classification_report
import pandas as pd
import numpy as np
import torch.nn.functional as F
import json
import math
from collections import Counter
import random

# Names of the prediction targets. prepare_labels stacks the label columns in
# this order, and every model builds one sigmoid head per entry.
moral_foundations = ['Care', 'Fairness', 'Loyalty', 'Authority', 'Purity', 'Non_Moral']

class MLSMOTE:
    """Multi-label Synthetic Minority Over-sampling Technique.

    Synthetic samples are created by interpolating between a minority sample
    and one of its nearest minority neighbours. In ``fit_resample`` the
    synthetic label set is the union of the labels of that same
    seed/neighbour pair, so features and labels of a synthetic sample come
    from the same two originals.

    Args:
        k: Number of nearest neighbours considered for each seed sample.
        n_synthetic: Number of synthetic samples to request, as a ratio of
            the number of minority samples.
    """

    def __init__(self, k=5, n_synthetic=1.0):
        self.k = k  # Number of nearest neighbors
        self.n_synthetic = n_synthetic  # Synthetic sample ratio
        # k + 1 because the nearest neighbour of a fitted point is usually itself.
        self.knn = NearestNeighbors(n_neighbors=k+1, metric='euclidean')

    def calculate_imbalance_ratio(self, labels):
        """Calculate the imbalance ratio per label (IRPL).

        Returns a dict mapping each foundation to negatives / positives, or
        ``inf`` when the label has no positive sample.
        """
        n_samples = len(labels)
        irpl = {}

        for i, foundation in enumerate(moral_foundations):
            positive_count = np.sum(labels[:, i])
            if positive_count > 0:
                irpl[foundation] = (n_samples - positive_count) / positive_count
            else:
                irpl[foundation] = float('inf')

        return irpl

    def identify_minority_labels(self, labels):
        """Identify minority labels based on the Mean Imbalance Ratio (MIR).

        Returns ``(minority_labels, irpl)`` where a label is a minority label
        when its IRPL exceeds the mean of all finite IRPL values.
        """
        irpl = self.calculate_imbalance_ratio(labels)

        valid_ratios = [ratio for ratio in irpl.values() if ratio != float('inf')]
        mir = np.mean(valid_ratios) if valid_ratios else 1.0

        minority_labels = [label for label, ratio in irpl.items() if ratio > mir]

        return minority_labels, irpl

    def _sample_pairs(self, minority_features, n_pairs):
        """Draw ``(seed, neighbour, alpha)`` triples over the minority subset.

        ``seed`` and ``neighbour`` are positions within ``minority_features``;
        ``alpha`` is the interpolation factor in [0, 1).
        """
        n_minority = len(minority_features)
        if n_minority < 2 or n_pairs <= 0:
            return []

        self.knn.fit(minority_features)
        # kneighbors raises when asked for more neighbours than fitted samples.
        n_neighbors = min(self.k + 1, n_minority)

        pairs = []
        for _ in range(n_pairs):
            seed = random.randrange(n_minority)
            _, found = self.knn.kneighbors(
                minority_features[seed:seed + 1], n_neighbors=n_neighbors
            )
            # Drop the seed by position rather than assuming it is listed first.
            candidates = [int(j) for j in found[0] if j != seed][:self.k]
            if candidates:
                pairs.append((seed, random.choice(candidates), random.random()))
        return pairs

    @staticmethod
    def _interpolate(minority_features, pairs):
        """Return one synthetic feature vector per ``(seed, neighbour, alpha)``."""
        return np.array([
            minority_features[seed] + alpha * (minority_features[neighbor] - minority_features[seed])
            for seed, neighbor, alpha in pairs
        ])

    def generate_synthetic_features(self, features, minority_indices):
        """Generate synthetic feature vectors using interpolation.

        Returns an empty array when fewer than two minority samples exist or
        no sample could be generated.
        """
        if len(minority_indices) < 2:
            return np.array([])

        minority_features = features[minority_indices]
        n_synthetic_samples = int(len(minority_indices) * self.n_synthetic)
        pairs = self._sample_pairs(minority_features, n_synthetic_samples)
        return self._interpolate(minority_features, pairs)

    def generate_synthetic_labels(self, labels, minority_indices, n_synthetic, pairs=None):
        """Generate synthetic label sets for new samples.

        Args:
            labels: Label matrix of the full dataset.
            minority_indices: Row indices of the minority samples.
            n_synthetic: Number of label sets to generate when ``pairs`` is
                not given.
            pairs: Optional sequence whose items start with
                ``(seed, neighbour)`` positions within the minority subset.
                When given, one label set is produced per item as the union
                of those two samples' labels.

        Without ``pairs`` each label set is the union of two randomly chosen
        minority samples (or the single minority sample if there is only one).
        """
        minority_labels = labels[minority_indices]

        if pairs is not None:
            return np.array([
                np.logical_or(minority_labels[pair[0]], minority_labels[pair[1]]).astype(int)
                for pair in pairs
            ])

        synthetic_labels = []
        for _ in range(n_synthetic):
            if len(minority_indices) >= 2:
                idx1, idx2 = random.sample(range(len(minority_indices)), 2)
                synthetic_label = np.logical_or(
                    minority_labels[idx1], minority_labels[idx2]
                ).astype(int)
            else:
                synthetic_label = minority_labels[0]

            synthetic_labels.append(synthetic_label)

        return np.array(synthetic_labels)

    def fit_resample(self, features, labels):
        """Apply MLSMOTE and return the data with synthetic samples appended.

        Returns ``(features, labels)`` unchanged when no minority label is
        found, fewer than two minority samples exist, or no synthetic sample
        could be generated.
        """
        minority_labels, irpl = self.identify_minority_labels(labels)

        if not minority_labels:
            print("No minority labels identified. Returning original data.")
            return features, labels

        print(f"Identified minority labels: {minority_labels}")
        print(f"Imbalance ratios: {irpl}")

        # A sample is a minority instance if it is positive on any minority label.
        minority_mask = np.zeros(len(labels), dtype=bool)
        for i, foundation in enumerate(moral_foundations):
            if foundation in minority_labels:
                minority_mask |= (labels[:, i] == 1)

        minority_indices = np.where(minority_mask)[0]

        if len(minority_indices) < 2:
            print("Insufficient minority samples for MLSMOTE. Returning original data.")
            return features, labels

        # Draw the seed/neighbour pairs once so features and labels of each
        # synthetic sample are derived from the same two originals.
        minority_features = features[minority_indices]
        n_requested = int(len(minority_indices) * self.n_synthetic)
        pairs = self._sample_pairs(minority_features, n_requested)

        if not pairs:
            print("No synthetic features generated. Returning original data.")
            return features, labels

        synthetic_features = self._interpolate(minority_features, pairs)
        synthetic_labels = self.generate_synthetic_labels(
            labels, minority_indices, len(pairs), pairs=pairs
        )

        combined_features = np.vstack([features, synthetic_features])
        combined_labels = np.vstack([labels, synthetic_labels])

        print(f"Generated {len(synthetic_features)} synthetic samples")
        print(f"Original dataset size: {len(features)}, Augmented dataset size: {len(combined_features)}")

        return combined_features, combined_labels

# Neural Network Models (same as before but optimized)
class TextModelRoBERTaFrozen(nn.Module):
    """Text classifier on top of a frozen ``roberta-base`` encoder.

    The first-token embedding of RoBERTa's last hidden state is passed through
    a trainable two-layer processor; one sigmoid head per moral foundation
    produces the predictions.

    The RoBERTa encoder is frozen and kept in eval mode even when the wrapper
    is switched to training mode; otherwise ``model.train()`` would enable the
    encoder's dropout and add noise to embeddings that are never trained.
    """

    def __init__(self, dropout_rate=0.1):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained('roberta-base')
        self.roberta = AutoModel.from_pretrained('roberta-base')

        # Freeze RoBERTa parameters
        for param in self.roberta.parameters():
            param.requires_grad = False
        self.roberta.eval()

        self.processor = Sequential(
            Linear(768, 512), LayerNorm(512), GELU(), Dropout(dropout_rate),
            Linear(512, 256), LayerNorm(256), GELU(), Dropout(dropout_rate)
        )

        self.heads = ModuleDict({
            foundation: Sequential(
                Linear(256, 128), ReLU(), Dropout(dropout_rate*2),
                Linear(128, 1), Sigmoid()
            ) for foundation in moral_foundations
        })

    def train(self, mode=True):
        """Set the training mode of the trainable layers only.

        The frozen encoder always stays in eval mode.
        """
        super().train(mode)
        self.roberta.eval()
        return self

    def _embed(self, texts):
        """Tokenize ``texts`` and return the frozen first-token embeddings."""
        inputs = self.tokenizer(texts, padding=True, truncation=True,
                                max_length=512, return_tensors='pt')
        device = next(self.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.roberta(**inputs)
            return outputs.last_hidden_state[:, 0]

    def get_features(self, texts):
        """Return the 256-dimensional processed text features."""
        return self.processor(self._embed(texts))

    def forward(self, texts):
        """Return a dict mapping each moral foundation to its sigmoid output."""
        features = self.get_features(texts)
        return {f: self.heads[f](features) for f in moral_foundations}

class SpatialModel(nn.Module):
    """MLP over spatial features with one sigmoid head per moral foundation."""

    def __init__(self, input_dim, dropout_rate=0.2):
        super().__init__()
        encoder_layers = [
            Linear(input_dim, 128),
            ReLU(),
            Dropout(dropout_rate),
            Linear(128, 256),
            ReLU(),
            Dropout(dropout_rate),
        ]
        self.encoder = Sequential(*encoder_layers)

        self.heads = ModuleDict()
        for foundation in moral_foundations:
            self.heads[foundation] = Sequential(
                Linear(256, 64), ReLU(), Linear(64, 1), Sigmoid()
            )

    def get_features(self, x):
        """Return the 256-dimensional encoder output for ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return a dict mapping each moral foundation to its sigmoid output."""
        shared = self.encoder(x)
        return {name: self.heads[name](shared) for name in moral_foundations}

class TemporalModelLSTM(nn.Module):
    """Bidirectional LSTM over feature sequences with per-foundation heads.

    Inputs are projected to ``hidden_size``, run through a 2-layer
    bidirectional LSTM, and the output at the last timestep is encoded into a
    256-dimensional feature vector. ``seq_len`` is accepted for interface
    compatibility but not used by the layers.
    """

    def __init__(self, input_dim, seq_len=10, hidden_size=128, dropout_rate=0.2):
        super().__init__()
        self.projection = Linear(input_dim, hidden_size)
        self.lstm = LSTM(
            hidden_size,
            hidden_size,
            2,
            batch_first=True,
            dropout=dropout_rate,
            bidirectional=True,
        )
        # Both directions are concatenated, hence hidden_size * 2 inputs.
        self.encoder = Sequential(
            Linear(hidden_size * 2, 256),
            LayerNorm(256),
            ReLU(),
            Dropout(dropout_rate),
        )

        self.heads = ModuleDict()
        for foundation in moral_foundations:
            self.heads[foundation] = Sequential(
                Linear(256, 64), ReLU(), Linear(64, 1), Sigmoid()
            )

    def get_features(self, x):
        """Return the encoded LSTM output taken at the last timestep."""
        projected = self.projection(x)
        outputs, _ = self.lstm(projected)
        last_step = outputs[:, -1]
        return self.encoder(last_step)

    def forward(self, x):
        """Return a dict mapping each moral foundation to its sigmoid output."""
        projected = self.projection(x)
        outputs, _ = self.lstm(projected)
        shared = self.encoder(outputs[:, -1])
        return {name: self.heads[name](shared) for name in moral_foundations}

class BehavioralModel(nn.Module):
    """Layer-normalised MLP over behavioral features with per-foundation heads."""

    def __init__(self, input_dim, dropout_rate=0.1):
        super().__init__()
        encoder_layers = [
            Linear(input_dim, 128),
            LayerNorm(128),
            ReLU(),
            Dropout(dropout_rate),
            Linear(128, 256),
            LayerNorm(256),
            ReLU(),
            Dropout(dropout_rate),
        ]
        self.encoder = Sequential(*encoder_layers)

        self.heads = ModuleDict()
        for foundation in moral_foundations:
            self.heads[foundation] = Sequential(
                Linear(256, 64), ReLU(), Linear(64, 1), Sigmoid()
            )

    def get_features(self, x):
        """Return the 256-dimensional encoder output for ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return a dict mapping each moral foundation to its sigmoid output."""
        shared = self.encoder(x)
        return {name: self.heads[name](shared) for name in moral_foundations}

class HeterogeneousFusion(nn.Module):
    """Attention-based fusion of four modality embeddings.

    A learned query per modality attends to that modality's embedding, the
    four resulting tokens are mixed with self-attention, and a sigmoid gate
    computed from the concatenated tokens weights them before they are summed
    and passed to one sigmoid head per moral foundation.
    """

    def __init__(self, d_model=256, num_heads=8, dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model
        # One learned query vector per modality (text, spatial, temporal, behavioral).
        self.modality_queries = nn.Parameter(torch.randn(4, d_model))

        per_modality_attention = []
        for _ in range(4):
            per_modality_attention.append(
                MultiheadAttention(d_model, num_heads, dropout=dropout_rate, batch_first=True)
            )
        self.cross_attention = nn.ModuleList(per_modality_attention)

        self.self_attention = MultiheadAttention(
            d_model, num_heads, dropout=dropout_rate, batch_first=True
        )
        # Norms 0-3 follow the cross-attention blocks, norm 4 follows self-attention.
        self.layer_norms = nn.ModuleList([LayerNorm(d_model) for _ in range(5)])

        self.gate = Sequential(
            Linear(d_model * 4, d_model),
            ReLU(),
            Dropout(dropout_rate),
            Linear(d_model, 4),
            Sigmoid(),
        )

        self.heads = ModuleDict()
        for foundation in moral_foundations:
            self.heads[foundation] = Sequential(
                Linear(d_model, 128),
                ReLU(),
                Dropout(dropout_rate*2),
                Linear(128, 1),
                Sigmoid(),
            )

    def forward(self, text_feat, spatial_feat, temporal_feat, behavioral_feat):
        """Return a dict mapping each moral foundation to its sigmoid output."""
        n_batch = text_feat.size(0)
        inputs = (text_feat, spatial_feat, temporal_feat, behavioral_feat)

        # Cross-attention: one token per modality, with a residual on the query.
        tokens = []
        for idx, modality in enumerate(inputs):
            key_value = modality.unsqueeze(1)
            query = self.modality_queries[idx:idx + 1].unsqueeze(0).expand(n_batch, -1, -1)
            context, _ = self.cross_attention[idx](query, key_value, key_value)
            tokens.append(self.layer_norms[idx](context + query))

        # Self-attention across the four modality tokens.
        sequence = torch.cat(tokens, dim=1)
        mixed, _ = self.self_attention(sequence, sequence, sequence)
        sequence = self.layer_norms[4](sequence + mixed)

        # Dynamic gating: one sigmoid weight per token, then a weighted sum.
        weights = self.gate(sequence.view(n_batch, -1))
        pooled = (sequence * weights.unsqueeze(-1)).sum(dim=1)

        return {name: self.heads[name](pooled) for name in moral_foundations}

# Data preparation functions
def prepare_labels(df):
    """Collapse virtue/vice columns into one label per foundation.

    A foundation is positive when either its virtue column or its vice column
    equals 1. ``Non_Moral`` is positive when none of the five foundations is.
    The vice column for Fairness is looked up as ``Cheating``, then
    ``cheating``, and is treated as all-False when neither exists.

    Returns:
        float32 tensor with one row per row of ``df`` and one column per entry
        of ``moral_foundations`` (in that order).
    """
    df = df.reset_index(drop=True)

    # (virtue column, vice column); None marks the tolerant Fairness lookup.
    pairs = (
        ('Care', 'Harm'),
        ('Fairness', None),
        ('Loyalty', 'Betrayal'),
        ('Authority', 'Subversion'),
        ('Purity', 'Degradation'),
    )

    labels = {}
    for virtue, vice in pairs:
        virtue_hit = df[virtue] == 1
        if vice is None:
            absent = pd.Series([False] * len(df), index=df.index)
            vice_column = df.get('Cheating', df.get('cheating', absent))
        else:
            vice_column = df[vice]
        labels[virtue] = virtue_hit | (vice_column == 1)

    any_moral = pd.DataFrame(dict(labels)).any(axis=1)
    labels['Non_Moral'] = (~any_moral).astype(int)

    columns = []
    for foundation in moral_foundations:
        columns.append(labels[foundation].astype(int).values)
    return torch.tensor(np.column_stack(columns), dtype=torch.float32)

def create_temporal_sequences(df, seq_len=10):
    """Slide a window of ``seq_len`` rows over each county's sorted rows.

    Rows are sorted by GEOID and date; every county with at least ``seq_len``
    rows contributes one sample per window position. The label, text, GEOID
    and behavioural values of a sample come from the last row of its window.

    Returns:
        (sequences, targets, other_features): per-sample lists holding the
        window's temporal feature rows, the label tensor produced by
        ``prepare_labels``, and a dict with ``text``, ``geoid`` (zero-padded
        to five characters) and ``behavioral`` values.
    """
    temporal_columns = (
        'month', 'day', 'year', 'after_05/26/2020',
        'female_pct', 'lesscollege_pct', 'cvap',
        'clf_unemploy_pct', 'lesshs_pct', 'median_hh_inc',
        'urm_pct', 'ruralurban_cc', 'demgov', 'repgov',
        'net_dem_president_votes', 'net_dem_gov_votes', 'is_blue',
        'deaths', 'cases', 'cases_per_capita',
        'deaths_per_capita', 'cases_per_capita_discrete', 'mask_score',
    )
    behavioral_columns = (
        'retweet_count', 'rt_discrete', 'for_sah', 'is_vivid', 'sentiment_score',
    )

    ordered = df.sort_values(['GEOID', 'year', 'month', 'day'])
    sequences = []
    targets = []
    other_features = []

    for _, county_rows in ordered.groupby('GEOID'):
        # The range is empty for counties with fewer than seq_len rows.
        for start in range(len(county_rows) - seq_len + 1):
            window = county_rows.iloc[start:start + seq_len]

            sequences.append([
                [row[column] for column in temporal_columns]
                for _, row in window.iterrows()
            ])

            final_row = window.iloc[-1]
            targets.append(prepare_labels(pd.DataFrame([final_row]))[0])

            other_features.append({
                'text': final_row['text'],
                'geoid': str(final_row['GEOID']).zfill(5),
                'behavioral': [final_row[column] for column in behavioral_columns],
            })

    return sequences, targets, other_features

def apply_mlsmote_to_dataset(dataset_dict, feature_info):
    """Augment a training split with MLSMOTE-generated minority samples.

    Each sample is flattened into one feature row: two word-count features
    from the text, then the spatial, behavioural and time-averaged temporal
    features. ``MLSMOTE.fit_resample`` is run on that matrix and the labels.

    Like the rest of this function, the split-back step assumes that
    ``fit_resample`` returns the original rows first and the synthetic rows
    after them (TODO confirm against the MLSMOTE implementation).

    Original samples keep their real temporal sequences. Only synthetic
    samples receive a sequence built by repeating their interpolated mean
    feature vector ``feature_info['temporal_sequence_length']`` times.
    (Previously every sequence, including the originals, was replaced by its
    repeated mean, which erased all temporal variation from the training set.)

    NOTE(review): synthetic samples are paired with texts taken cyclically
    from the start of the split, not with the text of the sample they were
    interpolated from; the text therefore does not match the synthetic label.

    Args:
        dataset_dict: Split dict with ``text_data`` (list of str) and the
            tensors ``spatial_features``, ``behavioral_features``,
            ``temporal_sequences`` and ``moral_targets``. Updated in place.
        feature_info: Must provide ``temporal_sequence_length``.

    Returns:
        The same ``dataset_dict``, extended when synthetic samples were made.
    """
    print("\n=== Applying MLSMOTE to Training Data ===")

    mlsmote = MLSMOTE(k=5, n_synthetic=1.0)

    text_data = dataset_dict['text_data']

    # Very coarse text descriptors (token count, distinct-token count); they
    # only influence neighbour search and are discarded afterwards.
    text_features = []
    for text in text_data:
        words = text.lower().split()
        text_features.append([len(words), len(set(words))])
    text_features = np.array(text_features)

    spatial_features = dataset_dict['spatial_features'].numpy()
    behavioral_features = dataset_dict['behavioral_features'].numpy()
    original_sequences = dataset_dict['temporal_sequences']
    temporal_features = original_sequences.numpy().mean(axis=1)  # average over time steps

    combined_features = np.hstack([
        text_features,
        spatial_features,
        behavioral_features,
        temporal_features
    ])

    labels = dataset_dict['moral_targets'].numpy()
    augmented_features, augmented_labels = mlsmote.fit_resample(combined_features, labels)

    n_original = len(text_data)
    n_synthetic = len(augmented_features) - n_original
    if n_synthetic <= 0:
        return dataset_dict

    # Column layout of the combined matrix.
    spatial_start = text_features.shape[1]
    behavioral_start = spatial_start + spatial_features.shape[1]
    temporal_start = behavioral_start + behavioral_features.shape[1]
    temporal_end = temporal_start + temporal_features.shape[1]

    dataset_dict['text_data'] = list(text_data) + [
        text_data[i % n_original] for i in range(n_synthetic)
    ]
    dataset_dict['spatial_features'] = torch.tensor(
        augmented_features[:, spatial_start:behavioral_start], dtype=torch.float32
    )
    dataset_dict['behavioral_features'] = torch.tensor(
        augmented_features[:, behavioral_start:temporal_start], dtype=torch.float32
    )

    # Only the synthetic rows need a fabricated sequence.
    synthetic_means = torch.tensor(
        augmented_features[n_original:, temporal_start:temporal_end], dtype=torch.float32
    )
    synthetic_sequences = synthetic_means.unsqueeze(1).repeat(
        1, feature_info['temporal_sequence_length'], 1
    )
    dataset_dict['temporal_sequences'] = torch.cat(
        [original_sequences.to(torch.float32), synthetic_sequences], dim=0
    )

    dataset_dict['moral_targets'] = torch.tensor(augmented_labels, dtype=torch.float32)

    return dataset_dict


def prepare_dataset(csv_path, geojson_path, county_centroids_path, seq_len=10, apply_mlsmote=True):
    """Load the raw files and build scaled train/validation/test splits.

    Steps: read the CSV and the two JSON files, build sliding-window samples
    with ``create_temporal_sequences``, derive six spatial features per
    sample, split 70/15/15 with a fixed seed, standardise every feature group,
    and optionally augment the training split with MLSMOTE.

    Two fixes relative to the previous version:
      * Longitude is normalised over the -125..-66 span. The old denominator
        ``(66 - 125)`` was negative and flipped the sign of the feature.
      * The scalers are fitted on the training rows only and then applied to
        all rows, so validation/test statistics no longer leak into training.

    Args:
        csv_path: CSV with the columns read by ``create_temporal_sequences``.
        geojson_path: JSON object keyed by GEOID; only key membership is used.
        county_centroids_path: JSON object mapping GEOID to a two-element
            centroid (presumably ``[lat, lon]`` - TODO confirm the order).
        seq_len: Window length passed to ``create_temporal_sequences``.
        apply_mlsmote: When true, the training split is augmented.

    Returns:
        (datasets, feature_info): ``datasets`` maps ``train`` / ``validation``
        / ``test`` to split dicts; ``feature_info`` holds the feature widths
        and the sequence length.
    """
    df = pd.read_csv(csv_path)
    with open(geojson_path) as f:
        boundaries = json.load(f)
    with open(county_centroids_path) as f:
        centroids = json.load(f)

    sequences, targets, other_features = create_temporal_sequences(df, seq_len)

    # Bounding box used to map centroids into [0, 1].
    lat_min, lat_max = 24, 49
    lon_min, lon_max = -125, -66
    neutral = 0.5  # placeholder for unknown counties and the four unused slots

    texts, spatial_rows, behavioral_rows = [], [], []
    for other in other_features:
        texts.append(other['text'])
        behavioral_rows.append(other['behavioral'])

        geoid = other['geoid']
        if geoid in boundaries and geoid in centroids:
            centroid = centroids[geoid]
            lat_norm = (centroid[0] - lat_min) / (lat_max - lat_min)
            lon_norm = (centroid[1] - lon_min) / (lon_max - lon_min)
            spatial_rows.append([lat_norm, lon_norm, neutral, neutral, neutral, neutral])
        else:
            spatial_rows.append([neutral] * 6)

    spatial_tensor = torch.tensor(spatial_rows, dtype=torch.float32)
    temporal_tensor = torch.tensor(sequences, dtype=torch.float32)
    behavioral_tensor = torch.tensor(behavioral_rows, dtype=torch.float32)
    targets_tensor = torch.stack(targets)

    # Split before scaling so the scalers can be fitted on training rows only.
    indices = list(range(len(texts)))
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

    def _scale_with_train_statistics(matrix, train_rows):
        """Fit a StandardScaler on ``train_rows`` and transform all of ``matrix``."""
        scaler = StandardScaler().fit(train_rows.numpy())
        return scaler.transform(matrix.numpy())

    spatial_tensor = torch.tensor(
        _scale_with_train_statistics(spatial_tensor, spatial_tensor[train_idx]),
        dtype=torch.float32,
    )
    behavioral_tensor = torch.tensor(
        _scale_with_train_statistics(behavioral_tensor, behavioral_tensor[train_idx]),
        dtype=torch.float32,
    )

    # Temporal features are scaled per feature across all time steps.
    temp_shape = tuple(temporal_tensor.shape)
    n_temporal = temp_shape[2]
    temp_scaled = _scale_with_train_statistics(
        temporal_tensor.reshape(-1, n_temporal),
        temporal_tensor[train_idx].reshape(-1, n_temporal),
    )
    temporal_tensor = torch.tensor(temp_scaled.reshape(temp_shape), dtype=torch.float32)

    datasets = {}
    for split, idx in [('train', train_idx), ('validation', val_idx), ('test', test_idx)]:
        datasets[split] = {
            'text_data': [texts[i] for i in idx],
            'spatial_features': spatial_tensor[idx],
            'temporal_sequences': temporal_tensor[idx],
            'behavioral_features': behavioral_tensor[idx],
            'moral_targets': targets_tensor[idx]
        }

    feature_info = {
        'spatial_feature_dim': spatial_tensor.shape[1],
        'temporal_feature_dim': temporal_tensor.shape[2],
        'temporal_sequence_length': seq_len,
        'behavioral_feature_dim': behavioral_tensor.shape[1]
    }

    if apply_mlsmote:
        datasets['train'] = apply_mlsmote_to_dataset(datasets['train'], feature_info)

    return datasets, feature_info

def create_dataloader(dataset_dict, batch_size=16, shuffle=True):
    """Wrap a split dict in a DataLoader that yields dict batches.

    Every batch is a dict with ``text_data`` (list of the raw texts) and the
    stacked tensors ``spatial_features``, ``temporal_sequences``,
    ``behavioral_features`` and ``targets`` (taken from ``moral_targets``).
    """
    source_keys = (
        'text_data',
        'spatial_features',
        'temporal_sequences',
        'behavioral_features',
        'moral_targets',
    )

    def collate_fn(batch):
        # Transpose the list of per-sample tuples into per-field tuples.
        texts, spatial, temporal, behavioral, targets = zip(*batch)
        return {
            'text_data': list(texts),
            'spatial_features': torch.stack(spatial),
            'temporal_sequences': torch.stack(temporal),
            'behavioral_features': torch.stack(behavioral),
            'targets': torch.stack(targets),
        }

    samples = []
    for position in range(len(dataset_dict['text_data'])):
        samples.append(tuple(dataset_dict[key][position] for key in source_keys))

    return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

def compute_loss(predictions, targets, class_weights=None):
    """Mean binary cross-entropy over all moral foundations.

    ``class_weights`` holds one negative/positive ratio per foundation (see
    ``calculate_class_weights``). It is applied as a positive-sample weight:
    elements whose target is 1 are weighted by the ratio and elements whose
    target is 0 keep weight 1. Multiplying a foundation's whole loss by the
    ratio, as was done before, scales positives and negatives alike and so
    does not counter the imbalance inside that foundation.

    Args:
        predictions: dict mapping each foundation name to probabilities with
            the same shape as the matching target column slice.
        targets: 2-D tensor with one column per entry of ``moral_foundations``.
        class_weights: optional per-foundation ratios; ``None`` gives the
            plain unweighted BCE.

    Returns:
        Scalar tensor: the per-foundation losses averaged over foundations.
    """
    total = 0
    for i, foundation in enumerate(moral_foundations):
        target = targets[:, i:i+1]
        element_weight = None
        if class_weights is not None:
            pos_weight = float(class_weights[i])
            # pos_weight where target == 1, 1.0 where target == 0.
            element_weight = 1.0 + (pos_weight - 1.0) * target
        total = total + F.binary_cross_entropy(
            predictions[foundation], target, weight=element_weight
        )
    return total / len(moral_foundations)

def calculate_class_weights(targets):
    """Return one negative/positive count ratio per label column.

    A column without any positive entry gets the neutral weight 1.0.
    """
    n_samples = len(targets)

    def ratio_for(column):
        positives = torch.sum(targets[:, column]).item()
        if positives > 0:
            return (n_samples - positives) / positives
        return 1.0

    return [ratio_for(column) for column in range(targets.shape[1])]

def _train_batch_to_device(batch, device):
    """Move every tensor of a collated batch to ``device``; other values pass through."""
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()}


def _train_single_predict(model, name, batch):
    """Run a single-modality model on the batch entry that belongs to ``name``."""
    if name == 'text':
        return model(batch['text_data'])
    if name == 'spatial':
        return model(batch['spatial_features'])
    if name == 'temporal':
        return model(batch['temporal_sequences'])
    return model(batch['behavioral_features'])  # behavioral


def _train_fusion_predict(encoders, fusion_model, batch):
    """Extract the four modality features and run the fusion model on them."""
    text_model, spatial_model, temporal_model, behavioral_model = encoders
    return fusion_model(
        text_model.get_features(batch['text_data']),
        spatial_model.get_features(batch['spatial_features']),
        temporal_model.get_features(batch['temporal_sequences']),
        behavioral_model.get_features(batch['behavioral_features']),
    )


def _train_validation_loss(predict, val_loader, device, class_weights):
    """Average ``compute_loss`` over ``val_loader`` without tracking gradients."""
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = _train_batch_to_device(batch, device)
            val_loss += compute_loss(predict(batch), batch['targets'], class_weights).item()
    return val_loss / len(val_loader)


def _train_snapshot(models):
    """Copy each model's state dict to CPU so later updates cannot alter it."""
    return [
        {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}
        for model in models
    ]


def _train_restore(models, snapshots):
    """Load the state dicts captured by ``_train_snapshot`` back into ``models``."""
    for model, snapshot in zip(models, snapshots):
        model.load_state_dict(snapshot)


def train_models(datasets, feature_info, device, epochs=30):
    """Train the four single-modality models, then fine-tune them with the fusion model.

    Stage 1 trains each single-modality model on its own input with AdamW,
    ReduceLROnPlateau and gradient clipping. Stage 2 optimises the parameters
    of all five models jointly through the fusion model's predictions. Both
    stages stop early after 10 epochs without a new best validation loss.

    Unlike the previous version, the weights that achieved the best validation
    loss are restored at the end of each stage. Before, training simply
    stopped and the models kept the weights from up to 10 non-improving
    epochs after their best point.

    Args:
        datasets: dict with ``train`` and ``validation`` split dicts.
        feature_info: feature widths and sequence length used to size models.
        device: torch device on which models and batches are placed.
        epochs: maximum number of epochs for every stage.

    Returns:
        (text_model, spatial_model, temporal_model, behavioral_model,
        fusion_model), all left in eval mode.
    """
    patience_limit = 10

    train_loader = create_dataloader(datasets['train'], batch_size=16, shuffle=True)
    val_loader = create_dataloader(datasets['validation'], batch_size=16, shuffle=False)

    class_weights = calculate_class_weights(datasets['train']['moral_targets'])
    print(f"Class weights: {dict(zip(moral_foundations, class_weights))}")

    text_model = TextModelRoBERTaFrozen(dropout_rate=0.15).to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim'], dropout_rate=0.25).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length'],
        dropout_rate=0.25
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim'], dropout_rate=0.15).to(device)
    fusion_model = HeterogeneousFusion(dropout_rate=0.15).to(device)

    encoders = [text_model, spatial_model, temporal_model, behavioral_model]
    models = encoders + [fusion_model]

    # ---- Stage 1: each modality on its own -------------------------------
    for model, name in zip(encoders, ['text', 'spatial', 'temporal', 'behavioral']):
        optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

        best_val_loss = float('inf')
        best_weights = None
        patience_counter = 0

        for epoch in range(epochs):
            model.train()
            total_loss = 0

            for batch in train_loader:
                batch = _train_batch_to_device(batch, device)
                preds = _train_single_predict(model, name, batch)
                loss = compute_loss(preds, batch['targets'], class_weights)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                total_loss += loss.item()

            model.eval()
            avg_val_loss = _train_validation_loss(
                lambda b, m=model, n=name: _train_single_predict(m, n, b),
                val_loader, device, class_weights,
            )
            scheduler.step(avg_val_loss)

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_weights = _train_snapshot([model])
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience_limit:
                    print(f"Early stopping for {name} model at epoch {epoch}")
                    break

            if epoch % 10 == 0:
                print(f"{name.capitalize()} model epoch {epoch}, train loss: {total_loss/len(train_loader):.4f}, val loss: {avg_val_loss:.4f}")

        # Roll back to the best validation checkpoint of this modality.
        if best_weights is not None:
            _train_restore([model], best_weights)

    # ---- Stage 2: joint training through the fusion model ----------------
    all_params = []
    for model in models:
        all_params.extend(model.parameters())

    optimizer_fusion = AdamW(all_params, lr=5e-5, weight_decay=1e-4)
    scheduler_fusion = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_fusion, patience=5, factor=0.5)

    best_fusion_loss = float('inf')
    best_fusion_weights = None
    patience_counter = 0

    for epoch in range(epochs):
        for model in models:
            model.train()

        total_loss = 0
        for batch in train_loader:
            batch = _train_batch_to_device(batch, device)
            fusion_preds = _train_fusion_predict(encoders, fusion_model, batch)
            loss = compute_loss(fusion_preds, batch['targets'], class_weights)

            optimizer_fusion.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
            optimizer_fusion.step()
            total_loss += loss.item()

        for model in models:
            model.eval()

        avg_val_loss = _train_validation_loss(
            lambda b: _train_fusion_predict(encoders, fusion_model, b),
            val_loader, device, class_weights,
        )
        scheduler_fusion.step(avg_val_loss)

        if avg_val_loss < best_fusion_loss:
            best_fusion_loss = avg_val_loss
            # All five models are updated jointly, so snapshot all of them.
            best_fusion_weights = _train_snapshot(models)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience_limit:
                print(f"Early stopping for fusion model at epoch {epoch}")
                break

        if epoch % 5 == 0:
            print(f"Fusion epoch {epoch}, train loss: {total_loss/len(train_loader):.4f}, val loss: {avg_val_loss:.4f}")

    if best_fusion_weights is not None:
        _train_restore(models, best_fusion_weights)

    return text_model, spatial_model, temporal_model, behavioral_model, fusion_model

def evaluate_model_comprehensive(model, data_loader, device, model_type=None, text_model=None, spatial_model=None, temporal_model=None, behavioral_model=None):
    """Evaluate one model on ``data_loader`` and collect per-foundation metrics.

    Args:
        model: The model to evaluate (a single-modality model or the fusion model).
        data_loader: Loader yielding the dict batches of ``create_dataloader``.
        device: torch device the batches are moved to.
        model_type: ``'text'``, ``'spatial'``, ``'temporal'`` or ``'fusion'``;
            any other value (including ``None``) selects the behavioural input.
        text_model, spatial_model, temporal_model, behavioral_model: Feature
            extractors, required when ``model_type == 'fusion'``.

    Returns:
        (metrics, reports, avg_metrics): per-foundation metric dicts,
        per-foundation ``classification_report`` strings, and the metrics
        averaged over foundations plus the mean unweighted loss under ``loss``.
        A foundation whose targets contain a single class gets ``auc = 0.0``,
        and that zero is included in the averaged AUC.

    Raises:
        ValueError: if ``model_type == 'fusion'`` and a feature extractor is missing.
    """
    is_fusion = model_type == 'fusion'
    encoders = (text_model, spatial_model, temporal_model, behavioral_model)
    if is_fusion:
        if any(encoder is None for encoder in encoders):
            raise ValueError(
                "model_type='fusion' requires text_model, spatial_model, "
                "temporal_model and behavioral_model"
            )
        # The feature extractors take part in the forward pass, so they must
        # be in eval mode as well; otherwise their dropout stays active.
        for encoder in encoders:
            encoder.eval()

    model.eval()
    all_preds = {f: [] for f in moral_foundations}
    all_targets = {f: [] for f in moral_foundations}
    total_loss = 0
    num_batches = 0

    with torch.no_grad():
        for batch in data_loader:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            if is_fusion:
                text_feat = text_model.get_features(batch['text_data'])
                spatial_feat = spatial_model.get_features(batch['spatial_features'])
                temporal_feat = temporal_model.get_features(batch['temporal_sequences'])
                behavioral_feat = behavioral_model.get_features(batch['behavioral_features'])
                preds = model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
            elif model_type == 'text':
                preds = model(batch['text_data'])
            elif model_type == 'spatial':
                preds = model(batch['spatial_features'])
            elif model_type == 'temporal':
                preds = model(batch['temporal_sequences'])
            else:  # behavioral
                preds = model(batch['behavioral_features'])

            targets = batch['targets']
            loss = compute_loss(preds, targets)
            total_loss += loss.item()
            num_batches += 1

            for i, foundation in enumerate(moral_foundations):
                pred = preds[foundation].cpu().numpy().flatten()
                target = targets[:, i].cpu().numpy()
                all_preds[foundation].extend(pred)
                all_targets[foundation].extend(target)

    metrics = {}
    reports = {}
    avg_metrics = {'accuracy': 0, 'f1': 0, 'precision': 0, 'recall': 0, 'auc': 0}

    for foundation in moral_foundations:
        # Probabilities are thresholded at 0.5 for the hard-label metrics.
        pred_binary = (np.array(all_preds[foundation]) > 0.5).astype(int)
        target_binary = np.array(all_targets[foundation]).astype(int)

        metrics[foundation] = {
            'accuracy': accuracy_score(target_binary, pred_binary),
            'f1': f1_score(target_binary, pred_binary, zero_division=0),
            'precision': precision_score(target_binary, pred_binary, zero_division=0),
            'recall': recall_score(target_binary, pred_binary, zero_division=0)
        }

        # ROC-AUC is undefined when only one class is present.
        if len(np.unique(target_binary)) > 1:
            metrics[foundation]['auc'] = roc_auc_score(target_binary, all_preds[foundation])
        else:
            metrics[foundation]['auc'] = 0.0

        reports[foundation] = classification_report(target_binary, pred_binary, zero_division=0)

        for metric in avg_metrics:
            avg_metrics[metric] += metrics[foundation][metric]

    for metric in avg_metrics:
        avg_metrics[metric] /= len(moral_foundations)

    avg_metrics['loss'] = total_loss / num_batches

    return metrics, reports, avg_metrics

def extract_fusion_weights(fusion_model):
    """Probe the fusion gate and report one weight per modality.

    The gate is evaluated on a single random input of width ``d_model * 4``
    with gradients disabled and dropout switched off. The gate is
    input-dependent, so the result is its response to that probe, not an
    average over a dataset.

    The previous version called ``.numpy()`` on a tensor that still tracked
    gradients. That raises ``RuntimeError``, which the bare ``except`` hid, so
    the uniform fallback was returned every time.

    Args:
        fusion_model: Module exposing ``d_model``, ``gate`` and at least one
            parameter.

    Returns:
        dict with the keys ``Text``, ``Spatial``, ``Temporal`` and
        ``Behavioral``. If the gate cannot be probed, every value is 0.25 and
        a message is printed.
    """
    fallback = {'Text': 0.25, 'Spatial': 0.25, 'Temporal': 0.25, 'Behavioral': 0.25}
    try:
        gate = fusion_model.gate
        device = next(fusion_model.parameters()).device
        dummy_input = torch.randn(1, fusion_model.d_model * 4).to(device)

        was_training = gate.training
        gate.eval()  # disable dropout while probing
        try:
            with torch.no_grad():
                weights = gate(dummy_input).squeeze().cpu().numpy()
        finally:
            gate.train(was_training)

        return {
            'Text': float(weights[0]),
            'Spatial': float(weights[1]),
            'Temporal': float(weights[2]),
            'Behavioral': float(weights[3])
        }
    except (AttributeError, StopIteration, RuntimeError, TypeError, IndexError) as err:
        # Best effort: reporting must not abort the run, but say what happened.
        print(f"Could not probe fusion gate ({type(err).__name__}: {err}); reporting uniform weights")
        return fallback

def print_comprehensive_results(all_metrics, all_avg_metrics, fusion_weights):
    """Print the averaged metrics per model, the gate weights and an F1 table.

    Args:
        all_metrics: model key -> foundation -> metric dict (``f1`` is read).
        all_avg_metrics: model key -> averaged metric dict.
        fusion_weights: modality name -> weight; printed only when the fusion
            averages are present and the mapping is non-empty.
    """
    wide_rule = "=" * 80
    thin_rule = "-" * 80
    single_modalities = ('text', 'spatial', 'temporal', 'behavioral')
    summary_lines = (
        ("  Average Loss:      ", 'loss'),
        ("  Average Accuracy:  ", 'accuracy'),
        ("  Average F1 Score:  ", 'f1'),
        ("  Average Precision: ", 'precision'),
        ("  Average Recall:    ", 'recall'),
        ("  Average AUC-ROC:   ", 'auc'),
    )

    print("\n" + wide_rule)
    print("COMPREHENSIVE MULTIMODAL EVALUATION RESULTS WITH MLSMOTE")
    print(wide_rule)
    print()

    # Averaged metrics, one section per model that was evaluated.
    for key in single_modalities + ('fusion',):
        if key not in all_avg_metrics:
            continue
        averages = all_avg_metrics[key]
        print(f"{key.upper()} MODEL RESULTS:")
        print("-" * 50)
        for label, metric in summary_lines:
            print(f"{label}{averages[metric]:.4f}")
        print()

    # Gate weights of the fusion model.
    if 'fusion' in all_avg_metrics and fusion_weights:
        print("  Learned Modality Weights:")
        for modality, weight in fusion_weights.items():
            print(f"    {modality:<11}: {weight:.3f}")
        print()

    # Per-foundation F1 table.
    has_fusion = 'fusion' in all_metrics
    header = "Foundation   Text F1  Spatial F1 Temporal F1 Behavioral F1"
    if has_fusion:
        header = header + " Fusion F1"

    print("FOUNDATION-SPECIFIC COMPARISON:")
    print(thin_rule)
    print(header)
    print(thin_rule)

    for foundation in moral_foundations:
        cells = [f"{foundation:<12}"]
        for key in single_modalities:
            if key in all_metrics:
                score = all_metrics[key][foundation]['f1']
                cells.append(f" {score:.3f}     ")
        if has_fusion:
            score = all_metrics['fusion'][foundation]['f1']
            cells.append(f" {score:.3f}    ")
        print("".join(cells))

    print("\n" + wide_rule)

# Main execution
if __name__ == "__main__":
    # End-to-end run: build the splits, train the four single-modality models
    # and the fusion model, evaluate all five on the test split, print a report.

    # Seed torch, NumPy and Python's random module for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    print("🚀 Starting MLSMOTE-Enhanced Multimodal Training...")

    # Build train/validation/test splits; MLSMOTE is applied to the training split only
    datasets, feature_info = prepare_dataset(
        csv_path='/content/augmented_tweets.csv',
        geojson_path='/content/county_data_boarders.json',
        county_centroids_path='/content/county_centroids.json',
        seq_len=10,
        apply_mlsmote=True
    )

    # Train on the GPU when one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    text_model, spatial_model, temporal_model, behavioral_model, fusion_model = train_models(
        datasets, feature_info, device, epochs=30
    )

    # Evaluate all models on the held-out test split (fixed order, no shuffling)
    test_loader = create_dataloader(datasets['test'], batch_size=16, shuffle=False)

    # Results keyed by model name: 'text', 'spatial', 'temporal', 'behavioral', 'fusion'
    all_metrics = {}
    all_avg_metrics = {}
    all_reports = {}

    model_list = [
        (text_model, 'text'),
        (spatial_model, 'spatial'),
        (temporal_model, 'temporal'),
        (behavioral_model, 'behavioral')
    ]

    # Evaluate each single-modality model on its own input
    for model, name in model_list:
        print(f"Evaluating {name} model...")
        metrics, reports, avg_metrics = evaluate_model_comprehensive(
            model, test_loader, device, name
        )
        all_metrics[name] = metrics
        all_avg_metrics[name] = avg_metrics
        all_reports[name] = reports

    # Evaluate the fusion model; it needs the four models as feature extractors
    print("Evaluating fusion model...")
    metrics, reports, avg_metrics = evaluate_model_comprehensive(
        fusion_model, test_loader, device, 'fusion',
        text_model, spatial_model, temporal_model, behavioral_model
    )
    all_metrics['fusion'] = metrics
    all_avg_metrics['fusion'] = avg_metrics
    all_reports['fusion'] = reports

    # Per-modality gate weights reported alongside the fusion results
    fusion_weights = extract_fusion_weights(fusion_model)

    # Print the summary report (all_reports is collected but not printed here)
    print_comprehensive_results(all_metrics, all_avg_metrics, fusion_weights)

    print("\n✅ MLSMOTE-Enhanced Training and Evaluation Completed!")
    print("📊 Class imbalance has been addressed using MLSMOTE technique.")
    print("🎯 Models trained with synthetic minority samples for better performance.")


🚀 Starting MLSMOTE-Enhanced Multimodal Training...

=== Applying MLSMOTE to Training Data ===
Identified minority labels: ['Fairness', 'Purity']
Imbalance ratios: {'Care': np.float32(0.32307693), 'Fairness': np.float32(85.0), 'Loyalty': np.float32(5.0), 'Authority': np.float32(10.217391), 'Purity': np.float32(50.6), 'Non_Moral': np.float32(10.727273)}
Generated 8 synthetic samples
Original dataset size: 258, Augmented dataset size: 266
Using device: cuda
Class weights: {'Care': 0.3103448275862069, 'Fairness': 32.25, 'Loyalty': 4.541666666666667, 'Authority': 9.64, 'Purity': 23.181818181818183, 'Non_Moral': 11.090909090909092}


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Text model epoch 0, train loss: 6.6656, val loss: 4.5863
Text model epoch 10, train loss: 2.9066, val loss: 3.2185
Early stopping for text model at epoch 12
Spatial model epoch 0, train loss: 8.8653, val loss: 8.6464
Spatial model epoch 10, train loss: 2.8667, val loss: 3.0613
Spatial model epoch 20, train loss: 2.7000, val loss: 3.0766
Early stopping for spatial model at epoch 22
Temporal model epoch 0, train loss: 7.9999, val loss: 6.5587
Temporal model epoch 10, train loss: 2.2110, val loss: 2.8096
Early stopping for temporal model at epoch 17
Behavioral model epoch 0, train loss: 8.8591, val loss: 7.5827
Behavioral model epoch 10, train loss: 2.4963, val loss: 2.9961
Early stopping for behavioral model at epoch 18
Fusion epoch 0, train loss: 6.4322, val loss: 3.7532
Fusion epoch 5, train loss: 2.3577, val loss: 2.8669
Fusion epoch 10, train loss: 1.7760, val loss: 3.0121
Fusion epoch 15, train loss: 1.6808, val loss: 3.1852
Early stopping for fusion model at epoch 16
Evaluating tex

# GAT + fusion + focal loss

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, Dropout, Sigmoid, ModuleDict, LayerNorm, LSTM, GELU, MultiheadAttention ,Softmax
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, classification_report
import pandas as pd
import numpy as np
import json
import math
from tqdm.autonotebook import tqdm

# PyTorch Geometric imports for GAT
try:
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
except ImportError:
    print("PyTorch Geometric not found. Please install: pip install torch_geometric")
    GATConv = None
    Data = None

# Moral foundation label names. Functions in this file iterate this list to
# order label columns and to key the per-foundation prediction heads, so the
# order here is the column order of the target tensors.
moral_foundations = ['Care', 'Fairness', 'Loyalty', 'Authority', 'Purity', 'Non_Moral']

# ===================================================================
# 1. Multi-Label Focal Loss (from MFTC)
# ===================================================================
class MultiLabelFocalLoss(nn.Module):
    """Focal loss for multi-label classification on raw logits.

    Args:
        gamma: Focusing exponent applied to ``1 - p_t``.
        alpha: ``None`` (no class balancing), a scalar applied to every label,
            or a tensor of per-label values. The positive term is weighted by
            ``alpha`` and the negative term by ``1 - alpha``.
    """

    def __init__(self, gamma=2.0, alpha=None):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def _balance_factor(self, targets):
        """Return the alpha weight for every element of ``targets``."""
        weight = self.alpha
        if weight is None:
            return 1.0
        if not isinstance(weight, (float, int)):
            # Per-label tensor: match the device, broadcast a 1-D vector over the batch.
            weight = weight.to(targets.device)
            if weight.dim() == 1:
                weight = weight.unsqueeze(0).expand_as(targets)
        return targets * weight + (1 - targets) * (1 - weight)

    def forward(self, inputs, targets):
        """Return the mean focal loss of logits ``inputs`` against ``targets``."""
        prob = torch.sigmoid(inputs)
        # Probability assigned to the true outcome of every label.
        p_t = targets * prob + (1 - targets) * (1 - prob)
        alpha_t = self._balance_factor(targets)

        # The small constant keeps the logarithm finite when p_t is zero.
        focal_loss = -alpha_t * (1 - p_t) ** self.gamma * torch.log(p_t + 1e-8)
        return focal_loss.mean()




# ===================================================================
# 2. eMFD Processing and Graph Construction (from MFTC)
# ===================================================================
class eMFDProcessor:
    """Look up eMFD foundation probabilities for the words of a text.

    The lexicon CSV must have a ``word`` column; the probability columns
    ``care_p``, ``fairness_p``, ``loyalty_p``, ``authority_p`` and
    ``purity_p`` are read when present and default to 0.0 otherwise. Every
    word maps to a float32 vector of six values: the five probabilities
    followed by a constant 0.0 for the non-moral slot.
    """

    def __init__(self, emfd_csv_path):
        """Load the lexicon.

        Raises:
            FileNotFoundError: if ``emfd_csv_path`` does not exist.
        """
        # Imported here because the surrounding cell's import list does not
        # include ``os``; this keeps the class usable on its own.
        import os

        self.moral_foundations_emfd = ['care', 'fairness', 'loyalty', 'authority', 'purity', 'non-moral']
        if not os.path.exists(emfd_csv_path):
            raise FileNotFoundError(f"eMFD file not found at {emfd_csv_path}")
        df = pd.read_csv(emfd_csv_path)
        prob_cols = [f'{f}_p' for f in self.moral_foundations_emfd if f != 'non-moral']
        self.emfd_data = {
            row['word']: np.append([row.get(col, 0.0) for col in prob_cols], 0.0).astype(np.float32)
            for _, row in df.iterrows()
        }

    def extract_moral_concepts(self, text):
        """Return one entry per whitespace token of ``text`` found in the lexicon.

        Tokens are lower-cased before the lookup. When nothing matches, or
        ``text`` is not a string (for example a NaN from a DataFrame), a
        single ``neutral`` entry with all-zero probabilities is returned so
        that callers always receive at least one concept.
        """
        concepts = []
        if isinstance(text, str):
            concepts = [
                {'word': w, 'probabilities': self.emfd_data[w]}
                for w in text.lower().split() if w in self.emfd_data
            ]
        return concepts if concepts else [{'word': 'neutral', 'probabilities': np.zeros(6, dtype=np.float32)}]

class MoralGraphConstructor:
    """Turn the moral concepts of one text into a fully connected graph."""

    def create_moral_graph(self, moral_concepts):
        """Build a graph with one node per concept.

        Node features are the concept's probability vector followed by 256
        zeros. A single concept gets one self-loop; otherwise every ordered
        pair of distinct nodes is connected.

        Returns:
            ``Data(x=..., edge_index=...)``, or ``None`` when ``Data`` is
            falsy (the PyTorch Geometric import failed).
        """
        padding = np.zeros(256)
        feature_rows = [
            np.concatenate([concept['probabilities'], padding])
            for concept in moral_concepts
        ]
        node_features = torch.FloatTensor(np.array(feature_rows))

        n_nodes = len(moral_concepts)
        if n_nodes == 1:
            edge_index = torch.tensor([[0], [0]], dtype=torch.long)
        else:
            pairs = []
            for src in range(n_nodes):
                for dst in range(n_nodes):
                    if src != dst:
                        pairs.append([src, dst])
            # Transpose the list of pairs into one row of sources, one of targets.
            edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()

        if Data:
            return Data(x=node_features, edge_index=edge_index)
        return None

# ===================================================================
# 3. GAT eMFD Module (adapted from MFTC)
# ===================================================================
class GATeMFDModule(nn.Module):
    """Graph-attention encoder that pools a moral-concept graph into one vector.

    Node inputs are 6 probability features followed by a 256-wide learned
    embedding that is looked up by node *position* in the graph (not by word
    identity).  When ``GATConv`` is unavailable the module returns zeros.
    """

    def __init__(self, input_dim=262, output_dim=256, num_heads=4, dropout=0.1):
        super(GATeMFDModule, self).__init__()
        self.concept_embeddings = nn.Embedding(1000, 256)
        self.gat_layer = GATConv(input_dim, output_dim, heads=num_heads, dropout=dropout, concat=False) if GATConv else nn.Identity()
        self.attention_pooling = nn.Linear(output_dim, 1)

    def forward(self, graph_data, device):
        """Encode ``graph_data`` and return a ``(1, output_dim)`` pooled tensor.

        Args:
            graph_data: Object exposing ``x`` (node features) and
                ``edge_index``.
            device: Device on which to run the computation.
        """
        if not GATConv:
            return torch.zeros((1, 256)).to(device)
        x, edge_index = graph_data.x.to(device), graph_data.edge_index.to(device)

        # Clamp so graphs with more nodes than embedding rows reuse the last
        # row instead of raising an index error.
        last_row = self.concept_embeddings.num_embeddings - 1
        positions = torch.arange(x.size(0), device=device).clamp(max=last_row)

        # Build a new tensor rather than assigning into x[:, 6:]: `.to(device)`
        # returns the same tensor when it is already on `device`, so in-place
        # assignment would overwrite the caller's graph features.
        x = torch.cat([x[:, :6], self.concept_embeddings(positions)], dim=1)

        gat_output = F.elu(self.gat_layer(x, edge_index))
        # Softmax over the node axis yields per-node pooling weights.
        attn_weights = F.softmax(self.attention_pooling(gat_output), dim=0)
        return torch.sum(attn_weights * gat_output, dim=0, keepdim=True)

class TextModelGATeMFD(nn.Module):
    """Text branch: eMFD lexicon lookup -> concept graph -> GAT -> per-foundation heads."""

    def __init__(self, emfd_csv_path):
        """Create the model.

        Args:
            emfd_csv_path: Path of the eMFD lexicon CSV handed to
                ``eMFDProcessor``.
        """
        super().__init__()
        self.emfd_processor = eMFDProcessor(emfd_csv_path)
        self.graph_constructor = MoralGraphConstructor()
        self.gat_module = GATeMFDModule()

        # Additional processing layers
        self.processor = Sequential(
            Linear(256, 512), LayerNorm(512), GELU(), Dropout(0.1),
            Linear(512, 256), LayerNorm(256), GELU(), Dropout(0.1)
        )

        # Prediction heads
        self.heads = ModuleDict({
            foundation: Sequential(Linear(256, 128), ReLU(), Dropout(0.2),
                                 Linear(128, 1), Sigmoid())
            for foundation in moral_foundations
        })

    def forward(self, texts):
        """Return ``{foundation: (batch, 1) probabilities}`` for ``texts``."""
        processed_features = self.get_features(texts)
        return {f: self.heads[f](processed_features) for f in moral_foundations}

    def get_features(self, texts):
        """Return the ``(batch, 256)`` processed text features for ``texts``.

        Each text is encoded independently: its lexicon hits form a graph that
        the GAT module pools into one row.
        """
        device = next(self.parameters()).device
        batch_features = []

        for text in texts:
            moral_concepts = self.emfd_processor.extract_moral_concepts(text)
            graph_data = self.graph_constructor.create_moral_graph(moral_concepts)
            # The constructor returns None when no graph could be built; test
            # for that explicitly instead of relying on the object's truthiness.
            if graph_data is not None:
                gat_features = self.gat_module(graph_data, device)
            else:
                gat_features = torch.zeros((1, 256)).to(device)
            batch_features.append(gat_features)

        features = torch.cat(batch_features, dim=0)
        return self.processor(features)

# ===================================================================
# 4. Other Modality Models (from MOTIV)
# ===================================================================
class SpatialModel(nn.Module):
    """MLP encoder over spatial features with one sigmoid head per foundation."""

    def __init__(self, input_dim):
        super().__init__()
        encoder_layers = [
            Linear(input_dim, 128),
            ReLU(),
            Dropout(0.2),
            Linear(128, 256),
        ]
        self.encoder = Sequential(*encoder_layers)

        # One small classifier per moral foundation, created in list order.
        head_modules = {}
        for name in moral_foundations:
            head_modules[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1), Sigmoid())
        self.heads = ModuleDict(head_modules)

    def forward(self, x):
        """Return ``{foundation: head output}`` for input ``x``."""
        encoded = self.get_features(x)
        outputs = {}
        for name in moral_foundations:
            outputs[name] = self.heads[name](encoded)
        return outputs

    def get_features(self, x):
        """Return the 256-wide encoder output for ``x``."""
        return self.encoder(x)

class TemporalModelLSTM(nn.Module):
    """Two-layer bidirectional LSTM encoder with one sigmoid head per foundation.

    ``seq_len`` is accepted for interface compatibility but is not used by the
    model.
    """

    def __init__(self, input_dim, seq_len=10, hidden_size=128):
        super().__init__()
        self.projection = Linear(input_dim, hidden_size)
        self.lstm = LSTM(hidden_size, hidden_size, 2, batch_first=True,
                        dropout=0.2, bidirectional=True)
        self.encoder = Sequential(
            Linear(hidden_size * 2, 256), LayerNorm(256),
            ReLU(), Dropout(0.2)
        )
        self.heads = ModuleDict({
            f: Sequential(Linear(256, 64), ReLU(), Linear(64, 1), Sigmoid())
            for f in moral_foundations
        })

    def forward(self, x):
        """Return ``{foundation: (batch, 1) probabilities}`` for sequences ``x``."""
        features = self.get_features(x)
        return {f: self.heads[f](features) for f in moral_foundations}

    def get_features(self, x):
        """Return the ``(batch, 256)`` sequence summary for ``x``.

        ``x`` is fed to a ``batch_first`` LSTM, i.e. laid out as
        (batch, time, features).
        """
        proj = self.projection(x)
        _, (h_n, _) = self.lstm(proj)
        # Summarise with the final hidden state of each direction of the top
        # layer.  Taking lstm_out[:, -1] would use the backward direction's
        # state after it has processed only one timestep; h_n[-1] is the
        # backward state after the full sequence, h_n[-2] the forward one.
        summary = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.encoder(summary)

class BehavioralModel(nn.Module):
    """Layer-normalised MLP over behavioral features with per-foundation heads."""

    def __init__(self, input_dim):
        super().__init__()
        first_stage = [Linear(input_dim, 128), LayerNorm(128), ReLU(), Dropout(0.1)]
        second_stage = [Linear(128, 256), LayerNorm(256), ReLU(), Dropout(0.1)]
        self.encoder = Sequential(*first_stage, *second_stage)

        # One small classifier per moral foundation, created in list order.
        head_modules = {}
        for name in moral_foundations:
            head_modules[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1), Sigmoid())
        self.heads = ModuleDict(head_modules)

    def forward(self, x):
        """Return ``{foundation: head output}`` for input ``x``."""
        encoded = self.get_features(x)
        outputs = {}
        for name in moral_foundations:
            outputs[name] = self.heads[name](encoded)
        return outputs

    def get_features(self, x):
        """Return the 256-wide encoder output for ``x``."""
        return self.encoder(x)

# ===================================================================
# 5. Heterogeneous Fusion (from MOTIV)
# ===================================================================
class HeterogeneousFusion(nn.Module):
    """Attention-based fusion of four modality feature vectors.

    Pipeline: per-modality linear projection to ``d_model`` -> per-modality
    cross-attention with a learned query -> self-attention across the four
    modalities -> softmax gate producing one weight per modality -> weighted
    sum -> per-foundation sigmoid heads.

    Args:
        input_dims: Optional feature widths, one per modality in the order
            (text, spatial, temporal, behavioral).  Defaults to four inputs
            of width 256.
        d_model: Shared width used inside the fusion block.
        num_heads: Number of attention heads.
    """

    def __init__(self, input_dims=None, d_model=256, num_heads=8):
        super().__init__()
        self.d_model = d_model

        # Feature projection layers to align dimensions
        if input_dims:
            self.projections = nn.ModuleList([
                nn.Linear(dim, d_model) for dim in input_dims
            ])
        else:
            self.projections = nn.ModuleList([
                nn.Linear(256, d_model) for _ in range(4)
            ])

        self.modality_queries = nn.Parameter(torch.randn(4, d_model))

        self.cross_attention = nn.ModuleList([
            MultiheadAttention(d_model, num_heads, batch_first=True)
            for _ in range(4)
        ])

        self.self_attention = MultiheadAttention(d_model, num_heads, batch_first=True)
        # Norms 0-3 follow the cross-attention residuals, norm 4 the self-attention one.
        self.layer_norms = nn.ModuleList([LayerNorm(d_model) for _ in range(5)])

        # Gate: concatenated modality states -> one softmax weight per modality
        self.modality_gate = nn.Sequential(
            Linear(d_model * 4, d_model), ReLU(), Dropout(0.1),
            Linear(d_model, 4), Softmax(dim=-1)
        )

        self.final_projection = Linear(d_model, d_model)

        self.heads = ModuleDict({
            f: Sequential(Linear(d_model, 128), ReLU(), Dropout(0.2),
                         Linear(128, 1), Sigmoid())
            for f in moral_foundations
        })

    def forward(self, text_feat, spatial_feat, temporal_feat, behavioral_feat):
        """Fuse the four feature tensors and return ``{foundation: (batch, 1)}``."""
        batch_size = text_feat.size(0)

        # Project all features to same dimension
        features = [
            self.projections[i](feat)
            for i, feat in enumerate([text_feat, spatial_feat, temporal_feat, behavioral_feat])
        ]

        # Add a length-1 sequence dimension for attention
        modality_feats = [f.unsqueeze(1) for f in features]

        # Cross-attention with residual connections.
        # NOTE(review): each modality supplies a single key/value token, so the
        # attention weights are trivially 1 and the learned query cannot change
        # the output -- confirm this is intended.
        attended = []
        for i, (feat, attn, norm) in enumerate(zip(modality_feats, self.cross_attention, self.layer_norms[:4])):
            query = self.modality_queries[i:i+1].unsqueeze(0).expand(batch_size, -1, -1)
            att_out, _ = attn(query, feat, feat)
            attended.append(norm(att_out.squeeze(1) + features[i]))

        # Stack for self-attention: (batch_size, 4, d_model)
        stacked = torch.stack(attended, dim=1)

        # Self-attention across modalities with residual + norm
        self_att, _ = self.self_attention(stacked, stacked, stacked)
        fused = self.layer_norms[4](stacked + self_att)

        # Gates are computed from the concatenation of all modality states.
        # reshape() is used because, unlike view(), it does not require the
        # tensor to be contiguous.
        gates = self.modality_gate(fused.reshape(batch_size, -1))  # (batch_size, 4)

        # Weighted sum of the modality representations
        gated = torch.sum(fused * gates.unsqueeze(-1), dim=1)

        # Final projection
        final_features = self.final_projection(gated)

        return {f: self.heads[f](final_features) for f in moral_foundations}


# ===================================================================
# 6. Data Processing Functions (from MOTIV)
# ===================================================================
def prepare_labels(df):
    """Collapse virtue/vice columns into one binary label per foundation.

    A foundation is positive when either its virtue or its vice column equals
    1.  The vice column for Fairness is looked up as 'Cheating', then
    'cheating', and treated as all-False when neither exists.  'Non_Moral' is
    1 exactly when none of the five foundations is positive.

    Returns:
        A float32 tensor with one row per input row and one column per entry
        of ``moral_foundations``.
    """
    df = df.reset_index(drop=True)

    # Optional vice column for Fairness, with an all-False fallback.
    all_false = pd.Series([False]*len(df), index=df.index)
    cheating = df.get('Cheating', df.get('cheating', all_false))

    labels = {
        'Care': (df['Care'] == 1) | (df['Harm'] == 1),
        'Fairness': (df['Fairness'] == 1) | (cheating == 1),
        'Loyalty': (df['Loyalty'] == 1) | (df['Betrayal'] == 1),
        'Authority': (df['Authority'] == 1) | (df['Subversion'] == 1),
        'Purity': (df['Purity'] == 1) | (df['Degradation'] == 1),
    }

    # Non_Moral = 1 when every moral foundation above is False.
    moral_names = ('Care', 'Fairness', 'Loyalty', 'Authority', 'Purity')
    moral_df = pd.DataFrame({name: labels[name] for name in moral_names})
    has_any_moral = moral_df.any(axis=1)
    labels['Non_Moral'] = (~has_any_moral).astype(int)

    # Column order follows moral_foundations.
    columns = []
    for name in moral_foundations:
        columns.append(labels[name].astype(int).values)
    return torch.tensor(np.column_stack(columns), dtype=torch.float32)

def create_temporal_sequences(df, seq_len=10):
    """Slice every county's rows into overlapping windows of ``seq_len`` rows.

    Rows are ordered by GEOID, year, month and day.  For each GEOID with at
    least ``seq_len`` rows, every run of ``seq_len`` consecutive rows becomes
    one sample whose label, text and behavioral features come from the last
    row of the window.

    Args:
        df: Input frame containing the temporal, behavioral, label, ``text``
            and ``GEOID`` columns referenced below.
        seq_len: Window length; must be at least 1.

    Returns:
        ``(sequences, targets, other_features)`` -- three parallel lists:
        ``sequences[k]`` is a list of ``seq_len`` feature rows, ``targets[k]``
        the label tensor of the window's last row, and ``other_features[k]``
        a dict with keys ``'text'``, ``'geoid'`` (zero-padded to 5 characters)
        and ``'behavioral'``.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    temporal_cols = [
        'month', 'day', 'year', 'after_05/26/2020',
        'female_pct', 'lesscollege_pct', 'cvap',
        'clf_unemploy_pct', 'lesshs_pct', 'median_hh_inc',
        'urm_pct', 'ruralurban_cc', 'demgov', 'repgov',
        'net_dem_president_votes', 'net_dem_gov_votes', 'is_blue',
        'deaths', 'cases', 'cases_per_capita',
        'deaths_per_capita', 'cases_per_capita_discrete', 'mask_score'
    ]
    behavioral_cols = ['retweet_count', 'rt_discrete', 'for_sah', 'is_vivid', 'sentiment_score']

    df_sorted = df.sort_values(['GEOID', 'year', 'month', 'day'])
    sequences, targets, other_features = [], [], []

    for geoid, group in df_sorted.groupby('GEOID'):
        n_rows = len(group)
        if n_rows < seq_len:
            continue

        # Extract everything once per group; the windows below only slice.
        # This replaces a per-window iterrows() pass and a per-window
        # single-row DataFrame + prepare_labels() call.
        temporal_rows = group[temporal_cols].values.tolist()
        behavioral_rows = group[behavioral_cols].values.tolist()
        texts = group['text'].tolist()
        group_labels = prepare_labels(group)  # row k <-> group.iloc[k]
        geoid_str = str(geoid).zfill(5)

        for start in range(n_rows - seq_len + 1):
            last = start + seq_len - 1

            # Copy the rows so overlapping windows do not share list objects.
            sequences.append([list(row) for row in temporal_rows[start:last + 1]])
            targets.append(group_labels[last])
            other_features.append({
                'text': texts[last],
                'geoid': geoid_str,
                'behavioral': list(behavioral_rows[last])
            })

    return sequences, targets, other_features

def prepare_dataset(csv_path, geojson_path, county_centroids_path, seq_len=10):
    """Load the raw files and build scaled train/validation/test splits.

    Args:
        csv_path: CSV consumed by ``create_temporal_sequences``.
        geojson_path: JSON file whose top-level keys are tested for GEOID
            membership.
        county_centroids_path: JSON file mapping GEOID to a centroid pair.
        seq_len: Temporal window length.

    Returns:
        ``(datasets, feature_info)``.  ``datasets`` maps 'train',
        'validation' and 'test' (70/15/15 split, ``random_state=42``) to dicts
        with keys ``text_data``, ``spatial_features``, ``temporal_sequences``,
        ``behavioral_features`` and ``moral_targets``.  ``feature_info``
        reports the feature widths and the sequence length.

    Raises:
        ValueError: If no sequence could be built from the CSV.
    """
    # Load data
    df = pd.read_csv(csv_path)
    with open(geojson_path) as f:
        boundaries = json.load(f)
    with open(county_centroids_path) as f:
        centroids = json.load(f)

    # Create sequences
    sequences, targets, other_features = create_temporal_sequences(df, seq_len)
    if not sequences:
        raise ValueError(
            f"No temporal sequences of length {seq_len} could be built from {csv_path}"
        )

    texts = [other['text'] for other in other_features]
    behavioral_rows = [other['behavioral'] for other in other_features]

    # Spatial features: normalised centroid plus four placeholder values.
    spatial_rows = []
    for other in other_features:
        geoid = other['geoid']
        if geoid in boundaries and geoid in centroids:
            centroid = centroids[geoid]
            # NOTE(review): assumes centroid is (lat, lon); GeoJSON-style data
            # is usually (lon, lat) -- confirm against the centroid file.
            lat_norm = (centroid[0] - 24) / (49 - 24)
            # Longitude span is [-125, -66]; the previous divisor (66 - 125)
            # was negative and flipped the sign of the normalised value.
            lon_norm = (centroid[1] + 125) / (-66 + 125)
            spatial_rows.append([lat_norm, lon_norm, 0.5, 0.5, 0.5, 0.5])
        else:
            spatial_rows.append([0.5, 0.5, 0.5, 0.5, 0.5, 0.5])  # Default

    # Convert to tensors
    spatial_tensor = torch.tensor(spatial_rows, dtype=torch.float32)
    temporal_tensor = torch.tensor(sequences, dtype=torch.float32)
    behavioral_tensor = torch.tensor(behavioral_rows, dtype=torch.float32)
    targets_tensor = torch.stack(targets)

    # Create splits first so the scalers can be fit on training rows only.
    indices = list(range(len(texts)))
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

    # Fit each scaler on the training split and apply it to all rows; fitting
    # on every row would leak validation/test statistics into training.
    # NOTE(review): windows overlap, so the random split can still place
    # near-duplicate windows in different splits.
    scaler_spatial = StandardScaler().fit(spatial_tensor[train_idx].numpy())
    spatial_tensor = torch.tensor(
        scaler_spatial.transform(spatial_tensor.numpy()), dtype=torch.float32
    )

    scaler_behavioral = StandardScaler().fit(behavioral_tensor[train_idx].numpy())
    behavioral_tensor = torch.tensor(
        scaler_behavioral.transform(behavioral_tensor.numpy()), dtype=torch.float32
    )

    # Temporal features are scaled per feature across all timesteps, so the
    # (sample, time) axes are flattened for the scaler and restored afterwards.
    temp_shape = tuple(temporal_tensor.shape)
    n_temporal = temp_shape[2]
    scaler_temporal = StandardScaler().fit(
        temporal_tensor[train_idx].reshape(-1, n_temporal).numpy()
    )
    temp_scaled = scaler_temporal.transform(temporal_tensor.reshape(-1, n_temporal).numpy())
    temporal_tensor = torch.tensor(temp_scaled.reshape(temp_shape), dtype=torch.float32)

    datasets = {}
    for split, idx in [('train', train_idx), ('validation', val_idx), ('test', test_idx)]:
        datasets[split] = {
            'text_data': [texts[i] for i in idx],
            'spatial_features': spatial_tensor[idx],
            'temporal_sequences': temporal_tensor[idx],
            'behavioral_features': behavioral_tensor[idx],
            'moral_targets': targets_tensor[idx]
        }

    feature_info = {
        'spatial_feature_dim': spatial_tensor.shape[1],
        'temporal_feature_dim': temporal_tensor.shape[2],
        'temporal_sequence_length': seq_len,
        'behavioral_feature_dim': behavioral_tensor.shape[1]
    }

    return datasets, feature_info

def create_dataloader(dataset_dict, batch_size=16, shuffle=True):
    """Wrap one split dict in a ``DataLoader`` that yields dict batches.

    Every batch is a dict with the keys ``text_data`` (list),
    ``spatial_features``, ``temporal_sequences``, ``behavioral_features`` and
    ``targets`` (stacked tensors).
    """
    source_keys = (
        'text_data',
        'spatial_features',
        'temporal_sequences',
        'behavioral_features',
        'moral_targets',
    )

    def collate_fn(batch):
        # Transpose the list of per-sample tuples into per-field tuples.
        texts, spatial, temporal, behavioral, targets = zip(*batch)
        return {
            'text_data': list(texts),
            'spatial_features': torch.stack(spatial),
            'temporal_sequences': torch.stack(temporal),
            'behavioral_features': torch.stack(behavioral),
            'targets': torch.stack(targets)
        }

    # The number of samples is defined by the text list.
    sample_count = len(dataset_dict['text_data'])
    samples = []
    for i in range(sample_count):
        samples.append(tuple(dataset_dict[key][i] for key in source_keys))

    return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

# ===================================================================
# 7. Training and Evaluation Functions
# ===================================================================
def compute_focal_loss(predictions, targets, criterion):
    """Apply ``criterion`` to the per-foundation predictions.

    The entries of ``predictions`` are concatenated along dim 1 in
    ``moral_foundations`` order before being passed to ``criterion`` together
    with ``targets``.
    """
    per_foundation = []
    for name in moral_foundations:
        per_foundation.append(predictions[name])
    stacked = torch.cat(per_foundation, dim=1)
    return criterion(stacked, targets)

def train_models(datasets, feature_info, device, emfd_csv_path, epochs=30):
    """Train the four single-modality models, then the fusion model.

    Stage 1 trains each modality model on its own input with the focal loss.
    Stage 2 optimises the parameters of all five models jointly through the
    fusion head.

    Returns:
        ``(text_model, spatial_model, temporal_model, behavioral_model,
        fusion_model)``.
    """
    def to_device(raw_batch):
        # Tensors move to `device`; non-tensor entries (the text list) pass through.
        moved = {}
        for key, value in raw_batch.items():
            moved[key] = value.to(device) if isinstance(value, torch.Tensor) else value
        return moved

    train_loader = create_dataloader(datasets['train'], batch_size=16, shuffle=True)

    # Per-class alpha = share of negative samples in the training targets.
    print("⚖️ Calculating per-class alpha weights for focal loss...")
    all_targets = datasets['train']['moral_targets']
    num_positives = torch.sum(all_targets, dim=0)
    num_negatives = len(all_targets) - num_positives
    alpha_per_class = num_negatives / (num_positives + num_negatives + 1e-8)
    alpha_tensor = alpha_per_class.to(device)

    print(f"✅ Per-class alpha weights: {alpha_tensor.cpu().numpy().round(3)}")

    focal_criterion = MultiLabelFocalLoss(gamma=2.0, alpha=alpha_tensor)

    # Model construction order is kept fixed.
    text_model = TextModelGATeMFD(emfd_csv_path).to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim']).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim']).to(device)
    fusion_model = HeterogeneousFusion().to(device)

    models = [text_model, spatial_model, temporal_model, behavioral_model, fusion_model]

    # ---- Stage 1: individual models --------------------------------------
    # Batch key that feeds each single-modality model.
    batch_key_for = {
        'text': 'text_data',
        'spatial': 'spatial_features',
        'temporal': 'temporal_sequences',
        'behavioral': 'behavioral_features',
    }

    for name, model in zip(batch_key_for, models[:-1]):
        optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
        input_key = batch_key_for[name]

        for epoch in range(epochs):
            model.train()
            running_loss = 0

            for raw_batch in train_loader:
                batch = to_device(raw_batch)
                preds = model(batch[input_key])
                loss = compute_focal_loss(preds, batch['targets'], focal_criterion)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            if epoch % 10 == 0:
                print(f"{name.capitalize()} model epoch {epoch}, loss: {running_loss/len(train_loader):.4f}")

    # ---- Stage 2: joint training through the fusion model ----------------
    all_params = [param for model in models for param in model.parameters()]
    optimizer_fusion = optim.AdamW(all_params, lr=5e-5, weight_decay=1e-4)

    for epoch in range(epochs):
        for model in models:
            model.train()

        running_loss = 0
        for raw_batch in train_loader:
            batch = to_device(raw_batch)

            # Extract features from each modality model
            text_feat = text_model.get_features(batch['text_data'])
            spatial_feat = spatial_model.get_features(batch['spatial_features'])
            temporal_feat = temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = behavioral_model.get_features(batch['behavioral_features'])

            fusion_preds = fusion_model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
            loss = compute_focal_loss(fusion_preds, batch['targets'], focal_criterion)

            optimizer_fusion.zero_grad()
            loss.backward()
            optimizer_fusion.step()
            running_loss += loss.item()

        if epoch % 5 == 0:
            print(f"Fusion epoch {epoch}, loss: {running_loss/len(train_loader):.4f}")

    return text_model, spatial_model, temporal_model, behavioral_model, fusion_model

def evaluate_model_comprehensive(model, data_loader, device, model_type=None, text_model=None, spatial_model=None, temporal_model=None, behavioral_model=None):
    """Evaluate ``model`` on ``data_loader`` and compute per-foundation metrics.

    Args:
        model: Model to evaluate.
        data_loader: Loader yielding the dict batches built by
            ``create_dataloader``.
        device: Device the batch tensors are moved to.
        model_type: 'fusion', 'text', 'spatial' or 'temporal'; any other
            value (including ``None``) feeds the behavioral features.
        text_model, spatial_model, temporal_model, behavioral_model: Feature
            extractors; required when ``model_type == 'fusion'``.

    Returns:
        ``(metrics, avg_metrics)``: ``metrics[foundation]`` holds accuracy,
        f1, precision, recall and auc at a 0.5 threshold; ``avg_metrics`` is
        their unweighted mean over ``moral_foundations``.  AUC is reported as
        0.0 for a foundation whose targets contain a single class, and that
        0.0 is included in the average.
    """
    model.eval()
    if model_type == 'fusion':
        # The fusion head consumes features from the modality models, so they
        # must be in eval mode too; otherwise their dropout layers stay
        # active whenever they were last left in training mode.
        for extractor in (text_model, spatial_model, temporal_model, behavioral_model):
            if extractor is not None:
                extractor.eval()

    # Batch key that feeds each single-modality model.
    single_input_key = {
        'text': 'text_data',
        'spatial': 'spatial_features',
        'temporal': 'temporal_sequences',
    }

    all_preds = {f: [] for f in moral_foundations}
    all_targets = {f: [] for f in moral_foundations}

    with torch.no_grad():
        for batch in data_loader:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            if model_type == 'fusion':
                # Extract features for fusion
                text_feat = text_model.get_features(batch['text_data'])
                spatial_feat = spatial_model.get_features(batch['spatial_features'])
                temporal_feat = temporal_model.get_features(batch['temporal_sequences'])
                behavioral_feat = behavioral_model.get_features(batch['behavioral_features'])
                preds = model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
            else:
                # Unknown model types fall back to the behavioral input.
                input_key = single_input_key.get(model_type, 'behavioral_features')
                preds = model(batch[input_key])

            targets = batch['targets']

            for i, foundation in enumerate(moral_foundations):
                pred = preds[foundation].cpu().numpy().flatten()
                target = targets[:, i].cpu().numpy()
                all_preds[foundation].extend(pred)
                all_targets[foundation].extend(target)

    # Compute metrics
    metrics = {}
    avg_metrics = {'accuracy': 0, 'f1': 0, 'precision': 0, 'recall': 0, 'auc': 0}

    for foundation in moral_foundations:
        pred_binary = (np.array(all_preds[foundation]) > 0.5).astype(int)
        target_binary = np.array(all_targets[foundation]).astype(int)

        metrics[foundation] = {
            'accuracy': accuracy_score(target_binary, pred_binary),
            'f1': f1_score(target_binary, pred_binary, zero_division=0),
            'precision': precision_score(target_binary, pred_binary, zero_division=0),
            'recall': recall_score(target_binary, pred_binary, zero_division=0)
        }

        # ROC-AUC needs both classes among the targets.
        if len(np.unique(target_binary)) > 1:
            metrics[foundation]['auc'] = roc_auc_score(target_binary, all_preds[foundation])
        else:
            metrics[foundation]['auc'] = 0.0

        # Add to averages
        for metric in avg_metrics:
            avg_metrics[metric] += metrics[foundation][metric]

    # Calculate averages
    for metric in avg_metrics:
        avg_metrics[metric] /= len(moral_foundations)

    return metrics, avg_metrics

def print_comprehensive_results(all_metrics, all_avg_metrics):
    """Print the per-model averages and a per-foundation F1 comparison table.

    Args:
        all_metrics: ``{model_key: {foundation: {metric: value}}}``.
        all_avg_metrics: ``{model_key: {metric: value}}``.
    Model keys that are absent are left out of the report.
    """
    print("COMPREHENSIVE MULTIMODAL GAT-eMFD EVALUATION RESULTS")
    print("=" * 80)
    print()

    display_names = [
        ('GAT-eMFD TEXT', 'text'),
        ('SPATIAL', 'spatial'),
        ('TEMPORAL', 'temporal'),
        ('BEHAVIORAL', 'behavioral'),
        ('FUSION', 'fusion'),
    ]

    # Individual model results
    for title, model_key in display_names:
        if model_key not in all_avg_metrics:
            continue
        averages = all_avg_metrics[model_key]
        print(f"{title} MODEL RESULTS:")
        print("-" * 50)
        print(f"  Average Accuracy:  {averages['accuracy']:.4f}")
        print(f"  Average F1 Score:  {averages['f1']:.4f}")
        print(f"  Average Precision: {averages['precision']:.4f}")
        print(f"  Average Recall:    {averages['recall']:.4f}")
        print(f"  Average AUC-ROC:   {averages['auc']:.4f}")
        print()

    # Foundation-specific comparison
    has_fusion = 'fusion' in all_metrics
    print("FOUNDATION-SPECIFIC COMPARISON:")
    print("-" * 80)
    header = "Foundation   GAT-Text F1 Spatial F1 Temporal F1 Behavioral F1"
    if has_fusion:
        header += " Fusion F1"
    print(header)
    print("-" * 80)

    single_keys = ('text', 'spatial', 'temporal', 'behavioral')
    for foundation in moral_foundations:
        cells = [f"{foundation:<12}"]
        for model_key in single_keys:
            if model_key in all_metrics:
                f1_value = all_metrics[model_key][foundation]['f1']
                cells.append(f" {f1_value:.3f}      ")
        if has_fusion:
            f1_value = all_metrics['fusion'][foundation]['f1']
            cells.append(f" {f1_value:.3f}    ")
        print("".join(cells))

# ===================================================================
# 8. Main Execution
# ===================================================================
if __name__ == "__main__":
    # Script entry point: build the dataset, train all models, evaluate each
    # of them on the test split and print the comparison report.
    import os

    # File paths - UPDATE THESE
    CSV_PATH = '/content/augmented_tweets.csv'
    GEOJSON_PATH = '/content/county_data_boarders.json'
    CENTROIDS_PATH = '/content/county_centroids.json'
    EMFD_CSV_PATH = '/content/eMFD_wordlist.csv'  # UPDATE THIS PATH

    if not GATConv:
        print("❌ PyTorch Geometric not available. Please install it.")
        # Exit with a non-zero status: the bare exit() reported success and
        # depends on the `site`-provided builtin being present.
        raise SystemExit(1)

    # Prepare dataset
    datasets, feature_info = prepare_dataset(
        csv_path=CSV_PATH,
        geojson_path=GEOJSON_PATH,
        county_centroids_path=CENTROIDS_PATH,
        seq_len=10
    )

    # Train models
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    text_model, spatial_model, temporal_model, behavioral_model, fusion_model = train_models(
        datasets, feature_info, device, EMFD_CSV_PATH
    )

    # Evaluate all models on the held-out test split (fixed order, no shuffling)
    test_loader = create_dataloader(datasets['test'], batch_size=16, shuffle=False)

    all_metrics = {}
    all_avg_metrics = {}

    model_list = [
        (text_model, 'text'),
        (spatial_model, 'spatial'),
        (temporal_model, 'temporal'),
        (behavioral_model, 'behavioral')
    ]

    # Evaluate individual models
    for model, name in model_list:
        metrics, avg_metrics = evaluate_model_comprehensive(
            model, test_loader, device, name
        )
        all_metrics[name] = metrics
        all_avg_metrics[name] = avg_metrics

    # Evaluate fusion model (needs the four feature extractors)
    metrics, avg_metrics = evaluate_model_comprehensive(
        fusion_model, test_loader, device, 'fusion',
        text_model, spatial_model, temporal_model, behavioral_model
    )
    all_metrics['fusion'] = metrics
    all_avg_metrics['fusion'] = avg_metrics

    # Print comprehensive results
    print_comprehensive_results(all_metrics, all_avg_metrics)

    print("\n✅ Training and evaluation completed!")


Using device: cuda
⚖️ Calculating per-class alpha weights for focal loss...
✅ Per-class alpha weights: [0.244 0.988 0.833 0.911 0.981 0.915]
Text model epoch 0, loss: 0.0341
Text model epoch 10, loss: 0.0300
Text model epoch 20, loss: 0.0298
Spatial model epoch 0, loss: 0.0374
Spatial model epoch 10, loss: 0.0306
Spatial model epoch 20, loss: 0.0291
Temporal model epoch 0, loss: 0.0365
Temporal model epoch 10, loss: 0.0293
Temporal model epoch 20, loss: 0.0272
Behavioral model epoch 0, loss: 0.0374
Behavioral model epoch 10, loss: 0.0283
Behavioral model epoch 20, loss: 0.0281
Fusion epoch 0, loss: 0.0369
Fusion epoch 5, loss: 0.0287
Fusion epoch 10, loss: 0.0275
Fusion epoch 15, loss: 0.0248
Fusion epoch 20, loss: 0.0241
Fusion epoch 25, loss: 0.0241
COMPREHENSIVE MULTIMODAL GAT-eMFD EVALUATION RESULTS

GAT-eMFD TEXT MODEL RESULTS:
--------------------------------------------------
  Average Accuracy:  0.8185
  Average F1 Score:  0.0145
  Average Precision: 0.1667
  Average Recall:   

# without focal loss

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, Dropout, Sigmoid, ModuleDict, LayerNorm, LSTM, GELU, MultiheadAttention, Softmax
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, classification_report
import pandas as pd
import numpy as np
import json
import math
import os
from tqdm.autonotebook import tqdm

# PyTorch Geometric imports for GAT
# The package is optional: when it is missing, GATConv and Data are bound to
# None so that the classes below can test them for truthiness and fall back.
try:
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
except ImportError:
    print("PyTorch Geometric not found. Please install: pip install torch_geometric")
    GATConv = None
    Data = None

# Moral foundations
# Names of the prediction targets; the per-foundation heads defined below are
# keyed by these strings and iterate them in this order.
moral_foundations = ['Care', 'Fairness', 'Loyalty', 'Authority', 'Purity', 'Non_Moral']

# ===================================================================
# 1. eMFD Processing and Graph Construction
# ===================================================================
class eMFDProcessor:
    """Look up per-word moral-foundation probabilities in an eMFD lexicon.

    The CSV must provide a ``word`` column; for each of the five moral
    foundations a ``<foundation>_p`` column is read when present and treated
    as 0.0 when absent.  Every word maps to a length-6 ``float32`` vector
    ordered like ``moral_foundations_emfd``; the trailing 'non-moral' slot is
    always 0.0.
    """

    def __init__(self, emfd_csv_path):
        """Load the lexicon from ``emfd_csv_path``.

        Raises:
            FileNotFoundError: If ``emfd_csv_path`` does not exist.
        """
        self.moral_foundations_emfd = ['care', 'fairness', 'loyalty', 'authority', 'purity', 'non-moral']
        if not os.path.exists(emfd_csv_path):
            raise FileNotFoundError(f"eMFD file not found at {emfd_csv_path}")
        df = pd.read_csv(emfd_csv_path)
        prob_cols = [f'{f}_p' for f in self.moral_foundations_emfd if f != 'non-moral']

        # Build the whole table in one shot instead of iterating rows;
        # reindex() supplies 0.0 for probability columns missing from the CSV.
        probs = df.reindex(columns=prob_cols, fill_value=0.0).to_numpy(dtype=np.float32)
        non_moral = np.zeros((len(df), 1), dtype=np.float32)
        vectors = np.hstack([probs, non_moral])
        # copy() gives every word its own array rather than a view of `vectors`.
        self.emfd_data = {word: vec.copy() for word, vec in zip(df['word'], vectors)}

    def extract_moral_concepts(self, text):
        """Return the lexicon entries found in ``text``.

        The text is lower-cased and split on whitespace.  A token that is not
        in the lexicon is retried with leading/trailing punctuation removed,
        so "care," or "#justice!" still match.

        Returns:
            A list of ``{'word', 'probabilities'}`` dicts in token order, or a
            single all-zero 'neutral' entry when nothing matches.
        """
        import string  # local import: only needed here

        concepts = []
        for token in text.lower().split():
            if token not in self.emfd_data:
                # Exact matches win; only unmatched tokens are stripped.
                token = token.strip(string.punctuation)
            if token in self.emfd_data:
                concepts.append({'word': token, 'probabilities': self.emfd_data[token]})

        if concepts:
            return concepts
        return [{'word': 'neutral', 'probabilities': np.zeros(6, dtype=np.float32)}]

class MoralGraphConstructor:
    """Turn a list of moral-concept dicts into a fully connected graph."""

    def create_moral_graph(self, moral_concepts):
        """Build a graph with one node per concept.

        Each node feature row is the concept's ``'probabilities'`` vector
        followed by 256 zeros.  A single concept gets one self-loop; otherwise
        every ordered pair of distinct nodes is connected.

        Returns:
            A ``Data`` object holding ``x`` and ``edge_index``, or ``None``
            when ``Data`` is falsy (PyTorch Geometric unavailable).
        """
        node_count = len(moral_concepts)

        feature_rows = []
        for concept in moral_concepts:
            feature_rows.append(np.concatenate([concept['probabilities'], np.zeros(256)]))
        node_features = torch.FloatTensor(np.array(feature_rows))

        if node_count == 1:
            # Lone node: a self-loop keeps the graph non-empty.
            edge_index = torch.tensor([[0], [0]], dtype=torch.long)
        else:
            pairs = []
            for src in range(node_count):
                for dst in range(node_count):
                    if src != dst:
                        pairs.append([src, dst])
            # Transpose the (num_edges, 2) pair list into (2, num_edges).
            edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()

        if not Data:
            return None
        return Data(x=node_features, edge_index=edge_index)

# ===================================================================
# 2. GAT eMFD Module
# ===================================================================
class GATeMFDModule(nn.Module):
    """Graph-attention encoder that pools a moral-concept graph into one vector.

    Node inputs are 6 probability features followed by a 256-wide learned
    embedding that is looked up by node *position* in the graph (not by word
    identity).  When ``GATConv`` is unavailable the module returns zeros.
    """

    def __init__(self, input_dim=262, output_dim=256, num_heads=4, dropout=0.1):
        super(GATeMFDModule, self).__init__()
        self.concept_embeddings = nn.Embedding(1000, 256)
        self.gat_layer = GATConv(input_dim, output_dim, heads=num_heads, dropout=dropout, concat=False) if GATConv else nn.Identity()
        self.attention_pooling = nn.Linear(output_dim, 1)

    def forward(self, graph_data, device):
        """Encode ``graph_data`` and return a ``(1, output_dim)`` pooled tensor.

        Args:
            graph_data: Object exposing ``x`` (node features) and
                ``edge_index``.
            device: Device on which to run the computation.
        """
        if not GATConv:
            return torch.zeros((1, 256)).to(device)
        x, edge_index = graph_data.x.to(device), graph_data.edge_index.to(device)

        # Clamp so graphs with more nodes than embedding rows reuse the last
        # row instead of raising an index error.
        last_row = self.concept_embeddings.num_embeddings - 1
        positions = torch.arange(x.size(0), device=device).clamp(max=last_row)

        # Build a new tensor rather than assigning into x[:, 6:]: `.to(device)`
        # returns the same tensor when it is already on `device`, so in-place
        # assignment would overwrite the caller's graph features.
        x = torch.cat([x[:, :6], self.concept_embeddings(positions)], dim=1)

        gat_output = F.elu(self.gat_layer(x, edge_index))
        # Softmax over the node axis yields per-node pooling weights.
        attn_weights = F.softmax(self.attention_pooling(gat_output), dim=0)
        return torch.sum(attn_weights * gat_output, dim=0, keepdim=True)

class TextModelGATeMFD(nn.Module):
    """Text classifier built on eMFD lexicon matches and a graph attention encoder.

    Each text is turned into a graph of matched eMFD concepts, pooled to a
    256-wide vector by ``GATeMFDModule``, refined by an MLP, and scored by one
    head per entry of ``moral_foundations``. Heads return raw logits (no
    sigmoid), suitable for ``BCEWithLogitsLoss``.

    Args:
        emfd_csv_path: Path to the eMFD word list CSV, passed to
            ``eMFDProcessor``.
    """

    def __init__(self, emfd_csv_path):
        super().__init__()
        self.emfd_processor = eMFDProcessor(emfd_csv_path)
        self.graph_constructor = MoralGraphConstructor()
        self.gat_module = GATeMFDModule()

        # Additional processing layers
        self.processor = Sequential(
            Linear(256, 512), LayerNorm(512), GELU(), Dropout(0.1),
            Linear(512, 256), LayerNorm(256), GELU(), Dropout(0.1)
        )

        # Prediction heads output logits; the sigmoid lives in the loss.
        self.heads = ModuleDict({
            foundation: Sequential(Linear(256, 128), ReLU(), Dropout(0.2),
                                 Linear(128, 1))
            for foundation in moral_foundations
        })

    def get_features(self, texts):
        """Return the ``(len(texts), 256)`` processed features for a batch of texts.

        Args:
            texts: Iterable of raw strings; one graph is built and encoded per
                text.
        """
        device = next(self.parameters()).device
        batch_features = []

        for text in texts:
            moral_concepts = self.emfd_processor.extract_moral_concepts(text)
            graph_data = self.graph_constructor.create_moral_graph(moral_concepts)
            if graph_data is not None:
                gat_features = self.gat_module(graph_data, device)
            else:
                # No graph could be built (PyTorch Geometric missing).
                gat_features = torch.zeros((1, 256)).to(device)
            batch_features.append(gat_features)

        features = torch.cat(batch_features, dim=0)
        return self.processor(features)

    def forward(self, texts):
        """Return a dict mapping each foundation name to ``(len(texts), 1)`` logits."""
        # Reuse get_features so the two entry points share one encoding path.
        processed_features = self.get_features(texts)
        return {f: self.heads[f](processed_features) for f in moral_foundations}

# ===================================================================
# 3. Other Modality Models
# ===================================================================
class SpatialModel(nn.Module):
    """MLP encoder for spatial features with one logit head per moral foundation.

    Args:
        input_dim: Number of spatial features per sample.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.encoder = Sequential(
            Linear(input_dim, 128),
            ReLU(),
            Dropout(0.2),
            Linear(128, 256),
        )
        # Heads emit raw logits; BCEWithLogitsLoss applies the sigmoid.
        heads = ModuleDict()
        for foundation in moral_foundations:
            heads[foundation] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
        self.heads = heads

    def get_features(self, x):
        """Return the 256-wide encoding of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return a dict mapping each foundation name to its logits for ``x``."""
        encoded = self.encoder(x)
        logits = {}
        for foundation in moral_foundations:
            logits[foundation] = self.heads[foundation](encoded)
        return logits

class TemporalModelLSTM(nn.Module):
    """Bidirectional LSTM encoder over feature sequences, one logit head per foundation.

    The LSTM is built with ``batch_first=True``, so inputs are laid out as
    ``(batch, time, input_dim)``.

    Args:
        input_dim: Number of features per time step.
        seq_len: Accepted for interface compatibility; it is not used by the
            model.
        hidden_size: Width of the input projection and of each LSTM direction.
    """

    def __init__(self, input_dim, seq_len=10, hidden_size=128):
        super().__init__()
        self.projection = Linear(input_dim, hidden_size)
        self.lstm = LSTM(hidden_size, hidden_size, 2, batch_first=True,
                        dropout=0.2, bidirectional=True)
        self.encoder = Sequential(
            Linear(hidden_size * 2, 256), LayerNorm(256),
            ReLU(), Dropout(0.2)
        )
        # Heads emit raw logits; BCEWithLogitsLoss applies the sigmoid.
        self.heads = ModuleDict({
            f: Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
            for f in moral_foundations
        })

    def get_features(self, x):
        """Return the 256-wide encoding of each sequence in ``x``."""
        proj = self.projection(x)
        lstm_out, _ = self.lstm(proj)

        # lstm_out concatenates [forward, backward] along the last axis. The
        # forward direction has read the whole sequence at the last step, but
        # the backward direction has only read it at step 0 (at step -1 it
        # has seen a single element). Take each direction where it is complete.
        hidden = self.lstm.hidden_size
        forward_summary = lstm_out[:, -1, :hidden]
        backward_summary = lstm_out[:, 0, hidden:]
        summary = torch.cat([forward_summary, backward_summary], dim=-1)
        return self.encoder(summary)

    def forward(self, x):
        """Return a dict mapping each foundation name to its logits for ``x``."""
        features = self.get_features(x)
        return {f: self.heads[f](features) for f in moral_foundations}

class BehavioralModel(nn.Module):
    """MLP encoder for behavioral features with one logit head per moral foundation.

    Args:
        input_dim: Number of behavioral features per sample.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.encoder = Sequential(
            Linear(input_dim, 128),
            LayerNorm(128),
            ReLU(),
            Dropout(0.1),
            Linear(128, 256),
            LayerNorm(256),
            ReLU(),
            Dropout(0.1),
        )
        # Heads emit raw logits; BCEWithLogitsLoss applies the sigmoid.
        heads = ModuleDict()
        for foundation in moral_foundations:
            heads[foundation] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
        self.heads = heads

    def get_features(self, x):
        """Return the 256-wide encoding of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return a dict mapping each foundation name to its logits for ``x``."""
        encoded = self.encoder(x)
        logits = {}
        for foundation in moral_foundations:
            logits[foundation] = self.heads[foundation](encoded)
        return logits

# ===================================================================
# 4. Heterogeneous Fusion
# ===================================================================
class HeterogeneousFusion(nn.Module):
    """Fuse text, spatial, temporal and behavioral features into foundation logits.

    Pipeline: per-modality linear projection to ``d_model``; per-modality
    cross-attention with a residual connection and layer norm; self-attention
    across the four modality tokens; a softmax gate that weights the tokens;
    a final projection; one logit head per entry of ``moral_foundations``.

    Args:
        input_dims: Optional iterable with the feature width of each modality,
            in the order text, spatial, temporal, behavioral. When omitted,
            all four are assumed to be 256 wide.
        d_model: Shared width after projection.
        num_heads: Number of heads in every attention layer.
    """

    def __init__(self, input_dims=None, d_model=256, num_heads=8):
        super().__init__()
        self.d_model = d_model

        # Feature projection layers to align dimensions
        if input_dims:
            self.projections = nn.ModuleList([
                nn.Linear(dim, d_model) for dim in input_dims
            ])
        else:
            self.projections = nn.ModuleList([
                nn.Linear(256, d_model) for _ in range(4)
            ])

        # One learned query vector per modality.
        self.modality_queries = nn.Parameter(torch.randn(4, d_model))

        self.cross_attention = nn.ModuleList([
            MultiheadAttention(d_model, num_heads, batch_first=True)
            for _ in range(4)
        ])

        self.self_attention = MultiheadAttention(d_model, num_heads, batch_first=True)
        # Norms 0-3 follow the cross-attention residuals, norm 4 the self-attention one.
        self.layer_norms = nn.ModuleList([LayerNorm(d_model) for _ in range(5)])

        # Gate: maps the four concatenated modality tokens to four softmax weights.
        self.modality_gate = nn.Sequential(
            Linear(d_model * 4, d_model), ReLU(), Dropout(0.1),
            Linear(d_model, 4), Softmax(dim=-1)
        )

        self.final_projection = Linear(d_model, d_model)

        # Heads emit raw logits; BCEWithLogitsLoss applies the sigmoid.
        self.heads = ModuleDict({
            f: Sequential(Linear(d_model, 128), ReLU(), Dropout(0.2),
                         Linear(128, 1))
            for f in moral_foundations
        })

    def forward(self, text_feat, spatial_feat, temporal_feat, behavioral_feat):
        """Return a dict mapping each foundation name to ``(batch, 1)`` logits.

        Args:
            text_feat, spatial_feat, temporal_feat, behavioral_feat: One
                feature matrix per modality, all sharing the same batch size
                and each matching the input width of its projection layer.
        """
        batch_size = text_feat.size(0)

        # Project all features to same dimension
        features = [
            self.projections[i](feat)
            for i, feat in enumerate([text_feat, spatial_feat, temporal_feat, behavioral_feat])
        ]

        # Add a length-1 sequence axis so each modality is a single attention token.
        modality_feats = [f.unsqueeze(1) for f in features]

        # Cross-attention with residual connections.
        # NOTE(review): each call attends over exactly one key/value token, so
        # the softmax weight is always 1 and the output does not depend on the
        # query. As written, modality_queries cannot influence the result.
        attended = []
        for i, (feat, attn, norm) in enumerate(zip(modality_feats, self.cross_attention, self.layer_norms[:4])):
            query = self.modality_queries[i:i+1].unsqueeze(0).expand(batch_size, -1, -1)
            att_out, _ = attn(query, feat, feat)
            attended.append(norm(att_out.squeeze(1) + features[i]))

        # Stack into (batch, 4, d_model) so the modalities can attend to each other.
        stacked = torch.stack(attended, dim=1)

        self_att, _ = self.self_attention(stacked, stacked, stacked)
        fused = self.layer_norms[4](stacked + self_att)

        # Gates are computed from all four fused tokens concatenated.
        # reshape (not view) so the flatten also works on non-contiguous input.
        gates = self.modality_gate(fused.reshape(batch_size, -1))

        # Weighted sum of the modality tokens.
        gated = torch.sum(fused * gates.unsqueeze(-1), dim=1)

        final_features = self.final_projection(gated)

        return {f: self.heads[f](final_features) for f in moral_foundations}

# ===================================================================
# 5. Data Processing Functions
# ===================================================================
def prepare_labels(df):
    """Collapse virtue/vice columns into one label per foundation plus Non_Moral.

    A foundation is positive for a row when either its virtue column or its
    vice column equals 1. The vice column for Fairness is looked up as
    ``'Cheating'``, then ``'cheating'``, and treated as all-False when neither
    exists. ``Non_Moral`` is 1 exactly when all five foundations are negative.

    Args:
        df: DataFrame with the virtue/vice indicator columns.

    Returns:
        A float32 tensor with one row per row of ``df`` and one column per
        entry of ``moral_foundations``, in that order.
    """
    df = df.reset_index(drop=True)

    # Fallback used when the frame has no cheating column under either spelling.
    no_cheating = pd.Series([False] * len(df), index=df.index)
    cheating = df.get('Cheating', df.get('cheating', no_cheating))

    labels = {
        'Care': (df['Care'] == 1) | (df['Harm'] == 1),
        'Fairness': (df['Fairness'] == 1) | (cheating == 1),
        'Loyalty': (df['Loyalty'] == 1) | (df['Betrayal'] == 1),
        'Authority': (df['Authority'] == 1) | (df['Subversion'] == 1),
        'Purity': (df['Purity'] == 1) | (df['Degradation'] == 1),
    }

    # Non_Moral is the complement of "any foundation present".
    any_moral = pd.DataFrame(labels).any(axis=1)
    labels['Non_Moral'] = (~any_moral).astype(int)

    columns = [labels[name].astype(int).values for name in moral_foundations]
    return torch.tensor(np.column_stack(columns), dtype=torch.float32)

def create_temporal_sequences(df, seq_len=10):
    """Build sliding windows of ``seq_len`` consecutive rows for every county.

    Rows are sorted by GEOID and date. For each GEOID with at least
    ``seq_len`` rows, every window of ``seq_len`` consecutive rows (stride 1)
    produces one sample. Labels, text and behavioral features come from the
    last row of the window.

    Args:
        df: DataFrame holding the temporal, behavioral, label, ``text`` and
            ``GEOID`` columns referenced below.
        seq_len: Window length; must be at least 1.

    Returns:
        A tuple ``(sequences, targets, other_features)`` of equally long lists:
        ``sequences[i]`` is a list of ``seq_len`` rows of 23 temporal values,
        ``targets[i]`` is the label vector from ``prepare_labels`` for the
        window's last row, and ``other_features[i]`` is a dict with ``'text'``,
        ``'geoid'`` (zero-padded to 5 characters) and ``'behavioral'``.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    temporal_columns = [
        'month', 'day', 'year', 'after_05/26/2020',
        'female_pct', 'lesscollege_pct', 'cvap',
        'clf_unemploy_pct', 'lesshs_pct', 'median_hh_inc',
        'urm_pct', 'ruralurban_cc', 'demgov', 'repgov',
        'net_dem_president_votes', 'net_dem_gov_votes', 'is_blue',
        'deaths', 'cases', 'cases_per_capita',
        'deaths_per_capita', 'cases_per_capita_discrete', 'mask_score',
    ]
    behavioral_columns = [
        'retweet_count', 'rt_discrete', 'for_sah', 'is_vivid', 'sentiment_score',
    ]

    df_sorted = df.sort_values(['GEOID', 'year', 'month', 'day'])
    sequences, targets, other_features = [], [], []

    for geoid, group in df_sorted.groupby('GEOID'):
        n_rows = len(group)
        if n_rows < seq_len:
            continue

        # Pull everything out of the group once; the windows below are then
        # plain slices instead of per-row Series and per-window DataFrames.
        temporal_values = group[temporal_columns].to_numpy()
        behavioral_values = group[behavioral_columns].to_numpy()
        texts = group['text'].tolist()
        group_labels = prepare_labels(group)
        geoid_str = str(geoid).zfill(5)

        for start in range(n_rows - seq_len + 1):
            last = start + seq_len - 1
            sequences.append(temporal_values[start:last + 1].tolist())
            targets.append(group_labels[last])
            other_features.append({
                'text': texts[last],
                'geoid': geoid_str,
                'behavioral': behavioral_values[last].tolist(),
            })

    return sequences, targets, other_features

def prepare_dataset(csv_path, geojson_path, county_centroids_path, seq_len=10):
    """Load the raw files and build scaled train/validation/test splits.

    Steps: read the CSV and the two JSON files, build sliding-window samples
    with ``create_temporal_sequences``, derive spatial features from county
    centroids, split 70/15/15 with a fixed seed, and standardize spatial,
    temporal and behavioral features with statistics from the training split
    only.

    Args:
        csv_path: CSV with one row per observation.
        geojson_path: JSON keyed by 5-character GEOID; only used to check that
            a county is known.
        county_centroids_path: JSON mapping GEOID to a centroid pair.
        seq_len: Window length passed to ``create_temporal_sequences``.

    Returns:
        ``(datasets, feature_info)``. ``datasets`` maps ``'train'``,
        ``'validation'`` and ``'test'`` to dicts with ``text_data``,
        ``spatial_features``, ``temporal_sequences``, ``behavioral_features``
        and ``moral_targets``. ``feature_info`` holds the feature widths and
        the sequence length.

    Raises:
        ValueError: If no county has enough rows to form a single window.
    """
    df = pd.read_csv(csv_path)
    with open(geojson_path) as f:
        boundaries = json.load(f)
    with open(county_centroids_path) as f:
        centroids = json.load(f)

    sequences, targets, other_features = create_temporal_sequences(df, seq_len)
    if not sequences:
        raise ValueError(
            f"No temporal sequences could be built: no GEOID has at least {seq_len} rows"
        )

    texts = [other['text'] for other in other_features]

    spatial_rows = []
    for other in other_features:
        geoid = other['geoid']
        if geoid in boundaries and geoid in centroids:
            # NOTE(review): assumes the centroid is stored as [lat, lon] - confirm
            # against the centroid file.
            centroid = centroids[geoid]
            lat_norm = (centroid[0] - 24) / (49 - 24)
            # Longitude range [-125, -66] mapped to [0, 1]. The divisor is the
            # positive span; (66 - 125) would flip the sign.
            lon_norm = (centroid[1] + 125) / (125 - 66)
            spatial_rows.append([lat_norm, lon_norm, 0.5, 0.5, 0.5, 0.5])
        else:
            # Unknown county: neutral placeholder.
            spatial_rows.append([0.5, 0.5, 0.5, 0.5, 0.5, 0.5])

    spatial_raw = np.asarray(spatial_rows, dtype=np.float32)
    temporal_raw = np.asarray(sequences, dtype=np.float32)
    behavioral_raw = np.asarray(
        [other['behavioral'] for other in other_features], dtype=np.float32
    )
    targets_tensor = torch.stack(targets)

    # Split before scaling so the scalers never see validation or test rows.
    # NOTE(review): windows of one county overlap (stride 1), so a random
    # split over windows still puts near-identical samples in different splits.
    indices = list(range(len(texts)))
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

    scaler_spatial = StandardScaler().fit(spatial_raw[train_idx])
    scaler_behavioral = StandardScaler().fit(behavioral_raw[train_idx])
    # Temporal data is scaled per feature, so time steps are flattened into rows.
    n_temporal_features = temporal_raw.shape[2]
    scaler_temporal = StandardScaler().fit(
        temporal_raw[train_idx].reshape(-1, n_temporal_features)
    )

    spatial_tensor = torch.tensor(scaler_spatial.transform(spatial_raw), dtype=torch.float32)
    behavioral_tensor = torch.tensor(scaler_behavioral.transform(behavioral_raw), dtype=torch.float32)
    temporal_scaled = scaler_temporal.transform(temporal_raw.reshape(-1, n_temporal_features))
    temporal_tensor = torch.tensor(
        temporal_scaled.reshape(temporal_raw.shape), dtype=torch.float32
    )

    datasets = {}
    for split, idx in [('train', train_idx), ('validation', val_idx), ('test', test_idx)]:
        datasets[split] = {
            'text_data': [texts[i] for i in idx],
            'spatial_features': spatial_tensor[idx],
            'temporal_sequences': temporal_tensor[idx],
            'behavioral_features': behavioral_tensor[idx],
            'moral_targets': targets_tensor[idx]
        }

    feature_info = {
        'spatial_feature_dim': spatial_tensor.shape[1],
        'temporal_feature_dim': temporal_tensor.shape[2],
        'temporal_sequence_length': seq_len,
        'behavioral_feature_dim': behavioral_tensor.shape[1]
    }

    return datasets, feature_info

def create_dataloader(dataset_dict, batch_size=16, shuffle=True):
    """Wrap one dataset split in a ``DataLoader`` that yields dict batches.

    Args:
        dataset_dict: Split dict with ``text_data`` (list) and the indexable
            ``spatial_features``, ``temporal_sequences``,
            ``behavioral_features`` and ``moral_targets``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles samples every epoch.

    Returns:
        A ``DataLoader`` whose batches are dicts with ``text_data`` (list of
        texts) and the stacked tensors ``spatial_features``,
        ``temporal_sequences``, ``behavioral_features`` and ``targets``.
    """
    field_order = (
        'text_data',
        'spatial_features',
        'temporal_sequences',
        'behavioral_features',
        'moral_targets',
    )

    def collate_fn(batch):
        # Turn a list of per-sample tuples into one tuple per field.
        texts, spatial, temporal, behavioral, targets = zip(*batch)
        return {
            'text_data': list(texts),
            'spatial_features': torch.stack(spatial),
            'temporal_sequences': torch.stack(temporal),
            'behavioral_features': torch.stack(behavioral),
            'targets': torch.stack(targets),
        }

    # The number of samples is defined by the text list.
    sample_count = len(dataset_dict['text_data'])
    samples = [
        tuple(dataset_dict[field][idx] for field in field_order)
        for idx in range(sample_count)
    ]

    return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

# ===================================================================
# 6. Training and Evaluation Functions
# ===================================================================
def compute_bce_loss(predictions, targets, criterion):
    """Apply ``criterion`` to all foundation predictions at once.

    Args:
        predictions: Dict mapping each entry of ``moral_foundations`` to a
            prediction tensor; the tensors are joined along dimension 1.
        targets: Target tensor whose columns follow ``moral_foundations``.
        criterion: Loss callable taking ``(predictions, targets)``.

    Returns:
        Whatever ``criterion`` returns for the joined predictions.
    """
    per_foundation = [predictions[name] for name in moral_foundations]
    joined = torch.cat(per_foundation, dim=1)
    return criterion(joined, targets)

def train_models(datasets, feature_info, device, emfd_csv_path, epochs=30):
    """Train the four single-modality models, then the fusion model end to end.

    Stage 1 trains the text, spatial, temporal and behavioral models one after
    another, each with its own AdamW optimizer. Stage 2 trains the fusion
    model on the features of the four models with a single optimizer over the
    parameters of all five models. Both stages use ``BCEWithLogitsLoss`` with
    per-label positive weights computed from the training targets.

    Args:
        datasets: Split dict from ``prepare_dataset``; only ``'train'`` is used.
        feature_info: Feature widths and sequence length from ``prepare_dataset``.
        device: Device on which models and batches are placed.
        emfd_csv_path: Path to the eMFD word list for the text model.
        epochs: Number of epochs for every stage.

    Returns:
        ``(text_model, spatial_model, temporal_model, behavioral_model, fusion_model)``.
    """
    train_loader = create_dataloader(datasets['train'], batch_size=16, shuffle=True)

    # Per-label weight = negatives / positives, to counter class imbalance.
    print("⚖️ Calculating positive weights for BCEWithLogitsLoss...")
    train_targets = datasets['train']['moral_targets']
    positives = torch.sum(train_targets, dim=0)
    negatives = len(train_targets) - positives
    # The epsilon avoids division by zero for labels without positives.
    pos_weight = (negatives / (positives + 1e-8)).to(device)

    print(f"✅ Positive weights: {pos_weight.cpu().numpy().round(3)}")

    bce_criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    text_model = TextModelGATeMFD(emfd_csv_path).to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim']).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim']).to(device)
    fusion_model = HeterogeneousFusion().to(device)

    models = [text_model, spatial_model, temporal_model, behavioral_model, fusion_model]

    def _to_device(raw_batch):
        # Tensors move to the target device; the list of texts is passed through.
        return {key: value.to(device) if isinstance(value, torch.Tensor) else value
                for key, value in raw_batch.items()}

    # Which batch entry feeds which single-modality model.
    input_keys = {
        'text': 'text_data',
        'spatial': 'spatial_features',
        'temporal': 'temporal_sequences',
        'behavioral': 'behavioral_features',
    }
    unimodal = [
        ('text', text_model),
        ('spatial', spatial_model),
        ('temporal', temporal_model),
        ('behavioral', behavioral_model),
    ]

    # Stage 1: each modality is trained on its own.
    for name, model in unimodal:
        optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
        input_key = input_keys[name]

        for epoch in range(epochs):
            model.train()
            running_loss = 0

            for raw_batch in train_loader:
                batch = _to_device(raw_batch)
                preds = model(batch[input_key])
                loss = compute_bce_loss(preds, batch['targets'], bce_criterion)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            if epoch % 10 == 0:
                print(f"{name.capitalize()} model epoch {epoch}, loss: {running_loss/len(train_loader):.4f}")

    # Stage 2: one optimizer over all five models, so the encoders keep
    # training together with the fusion network.
    all_params = [param for model in models for param in model.parameters()]
    optimizer_fusion = optim.AdamW(all_params, lr=5e-5, weight_decay=1e-4)

    for epoch in range(epochs):
        for model in models:
            model.train()

        running_loss = 0
        for raw_batch in train_loader:
            batch = _to_device(raw_batch)

            text_feat = text_model.get_features(batch['text_data'])
            spatial_feat = spatial_model.get_features(batch['spatial_features'])
            temporal_feat = temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = behavioral_model.get_features(batch['behavioral_features'])

            fusion_preds = fusion_model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
            loss = compute_bce_loss(fusion_preds, batch['targets'], bce_criterion)

            optimizer_fusion.zero_grad()
            loss.backward()
            optimizer_fusion.step()
            running_loss += loss.item()

        if epoch % 5 == 0:
            print(f"Fusion epoch {epoch}, loss: {running_loss/len(train_loader):.4f}")

    return text_model, spatial_model, temporal_model, behavioral_model, fusion_model

def evaluate_model_comprehensive(model, data_loader, device, model_type=None, text_model=None, spatial_model=None, temporal_model=None, behavioral_model=None):
    """Compute per-foundation and averaged classification metrics for one model.

    Args:
        model: Model to evaluate. It is switched to eval mode.
        data_loader: Loader yielding the dict batches of ``create_dataloader``.
        device: Device to which batch tensors are moved.
        model_type: ``'text'``, ``'spatial'``, ``'temporal'`` or ``'fusion'``
            selects the input for ``model``; any other value (including
            ``None``) feeds the behavioral features.
        text_model, spatial_model, temporal_model, behavioral_model: Feature
            extractors, required when ``model_type == 'fusion'``. They are
            switched to eval mode as well, so dropout does not affect the
            evaluation.

    Returns:
        ``(metrics, avg_metrics)``. ``metrics`` maps each foundation to its
        accuracy, F1, precision, recall and AUC at a 0.5 threshold;
        ``avg_metrics`` is the unweighted mean of each metric. AUC is reported
        as 0.0 for a foundation whose targets contain a single class.

    Raises:
        ValueError: If ``model_type == 'fusion'`` and a feature extractor is
            missing.
    """
    is_fusion = model_type == 'fusion'

    model.eval()
    if is_fusion:
        extractors = (text_model, spatial_model, temporal_model, behavioral_model)
        if any(extractor is None for extractor in extractors):
            raise ValueError(
                "model_type='fusion' requires text_model, spatial_model, "
                "temporal_model and behavioral_model"
            )
        # Without this the extractors keep whatever mode the caller left them
        # in, and active dropout would make the fusion metrics noisy.
        for extractor in extractors:
            extractor.eval()

    all_preds = {f: [] for f in moral_foundations}
    all_targets = {f: [] for f in moral_foundations}

    with torch.no_grad():
        for batch in data_loader:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            if is_fusion:
                text_feat = text_model.get_features(batch['text_data'])
                spatial_feat = spatial_model.get_features(batch['spatial_features'])
                temporal_feat = temporal_model.get_features(batch['temporal_sequences'])
                behavioral_feat = behavioral_model.get_features(batch['behavioral_features'])
                preds = model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
            elif model_type == 'text':
                preds = model(batch['text_data'])
            elif model_type == 'spatial':
                preds = model(batch['spatial_features'])
            elif model_type == 'temporal':
                preds = model(batch['temporal_sequences'])
            else:  # behavioral
                preds = model(batch['behavioral_features'])

            targets = batch['targets']

            for i, foundation in enumerate(moral_foundations):
                # Models output logits, so apply the sigmoid here.
                pred = torch.sigmoid(preds[foundation]).cpu().numpy().flatten()
                target = targets[:, i].cpu().numpy()
                all_preds[foundation].extend(pred)
                all_targets[foundation].extend(target)

    metrics = {}
    avg_metrics = {'accuracy': 0, 'f1': 0, 'precision': 0, 'recall': 0, 'auc': 0}

    for foundation in moral_foundations:
        pred_binary = (np.array(all_preds[foundation]) > 0.5).astype(int)
        target_binary = np.array(all_targets[foundation]).astype(int)

        metrics[foundation] = {
            'accuracy': accuracy_score(target_binary, pred_binary),
            'f1': f1_score(target_binary, pred_binary, zero_division=0),
            'precision': precision_score(target_binary, pred_binary, zero_division=0),
            'recall': recall_score(target_binary, pred_binary, zero_division=0)
        }

        # ROC AUC is undefined when only one class is present.
        if len(np.unique(target_binary)) > 1:
            metrics[foundation]['auc'] = roc_auc_score(target_binary, all_preds[foundation])
        else:
            metrics[foundation]['auc'] = 0.0

        for metric in avg_metrics:
            avg_metrics[metric] += metrics[foundation][metric]

    for metric in avg_metrics:
        avg_metrics[metric] /= len(moral_foundations)

    return metrics, avg_metrics

def print_comprehensive_results(all_metrics, all_avg_metrics):
    """Print averaged metrics per model and a per-foundation F1 comparison.

    Args:
        all_metrics: Dict mapping a model key (``'text'``, ``'spatial'``,
            ``'temporal'``, ``'behavioral'``, ``'fusion'``) to its
            per-foundation metrics.
        all_avg_metrics: Dict mapping the same keys to averaged metrics.

    Models missing from a dict are left out of the corresponding section.
    """
    print("COMPREHENSIVE MULTIMODAL GAT-eMFD EVALUATION RESULTS")
    print("=" * 80)
    print()

    sections = (
        ('GAT-eMFD TEXT', 'text'),
        ('SPATIAL', 'spatial'),
        ('TEMPORAL', 'temporal'),
        ('BEHAVIORAL', 'behavioral'),
        ('FUSION', 'fusion'),
    )

    # Averaged metrics, one section per available model.
    for title, key in sections:
        if key not in all_avg_metrics:
            continue
        averages = all_avg_metrics[key]
        print(f"{title} MODEL RESULTS:")
        print("-" * 50)
        print(f"  Average Accuracy:  {averages['accuracy']:.4f}")
        print(f"  Average F1 Score:  {averages['f1']:.4f}")
        print(f"  Average Precision: {averages['precision']:.4f}")
        print(f"  Average Recall:    {averages['recall']:.4f}")
        print(f"  Average AUC-ROC:   {averages['auc']:.4f}")
        print()

    # F1 table: one row per foundation, one column per available model.
    has_fusion = 'fusion' in all_metrics
    print("FOUNDATION-SPECIFIC COMPARISON:")
    print("-" * 80)
    header = "Foundation   GAT-Text F1 Spatial F1 Temporal F1 Behavioral F1"
    if has_fusion:
        header += " Fusion F1"
    print(header)
    print("-" * 80)

    unimodal_keys = [
        key for key in ('text', 'spatial', 'temporal', 'behavioral')
        if key in all_metrics
    ]
    for foundation in moral_foundations:
        cells = [f"{foundation:<12}"]
        for key in unimodal_keys:
            score = all_metrics[key][foundation]['f1']
            cells.append(f" {score:.3f}      ")
        if has_fusion:
            score = all_metrics['fusion'][foundation]['f1']
            cells.append(f" {score:.3f}    ")
        print("".join(cells))

# ===================================================================
# 7. Main Execution
# ===================================================================
if __name__ == "__main__":
    # Input locations (Colab defaults) - UPDATE THESE
    CSV_PATH = '/content/augmented_tweets.csv'
    GEOJSON_PATH = '/content/county_data_boarders.json'
    CENTROIDS_PATH = '/content/county_centroids.json'
    EMFD_CSV_PATH = '/content/eMFD_wordlist.csv'  # UPDATE THIS PATH

    # The script needs the GAT layer, so stop early when it is missing.
    if not GATConv:
        print("❌ PyTorch Geometric not available. Please install it.")
        exit()

    # Build the train/validation/test splits.
    datasets, feature_info = prepare_dataset(
        CSV_PATH, GEOJSON_PATH, CENTROIDS_PATH, seq_len=10
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Train the four single-modality models and the fusion model.
    text_model, spatial_model, temporal_model, behavioral_model, fusion_model = train_models(
        datasets, feature_info, device, EMFD_CSV_PATH
    )

    # Evaluation runs on the held-out test split, in a fixed order.
    test_loader = create_dataloader(datasets['test'], batch_size=16, shuffle=False)

    all_metrics = {}
    all_avg_metrics = {}

    model_list = [
        (text_model, 'text'),
        (spatial_model, 'spatial'),
        (temporal_model, 'temporal'),
        (behavioral_model, 'behavioral')
    ]

    # Single-modality models first.
    for model, name in model_list:
        metrics, avg_metrics = evaluate_model_comprehensive(
            model, test_loader, device, model_type=name
        )
        all_metrics[name] = metrics
        all_avg_metrics[name] = avg_metrics

    # The fusion model needs the four encoders to produce its inputs.
    metrics, avg_metrics = evaluate_model_comprehensive(
        fusion_model, test_loader, device, model_type='fusion',
        text_model=text_model, spatial_model=spatial_model,
        temporal_model=temporal_model, behavioral_model=behavioral_model
    )
    all_metrics['fusion'] = metrics
    all_avg_metrics['fusion'] = avg_metrics

    print_comprehensive_results(all_metrics, all_avg_metrics)

    print("\n✅ Training and evaluation completed!")


Using device: cuda
⚖️ Calculating positive weights for BCEWithLogitsLoss...
✅ Positive weights: [ 0.323 85.     5.    10.217 50.6   10.727]
Text model epoch 0, loss: 1.5254
Text model epoch 10, loss: 1.0255
Text model epoch 20, loss: 0.9054
Spatial model epoch 0, loss: 1.1312
Spatial model epoch 10, loss: 1.0447
Spatial model epoch 20, loss: 0.9919
Temporal model epoch 0, loss: 1.1423
Temporal model epoch 10, loss: 0.8371
Temporal model epoch 20, loss: 0.7096
Behavioral model epoch 0, loss: 1.0925
Behavioral model epoch 10, loss: 0.8702
Behavioral model epoch 20, loss: 0.8175
Fusion epoch 0, loss: 1.1300
Fusion epoch 5, loss: 0.7517
Fusion epoch 10, loss: 0.5219
Fusion epoch 15, loss: 0.4178
Fusion epoch 20, loss: 0.3479
Fusion epoch 25, loss: 0.2874
COMPREHENSIVE MULTIMODAL GAT-eMFD EVALUATION RESULTS

GAT-eMFD TEXT MODEL RESULTS:
--------------------------------------------------
  Average Accuracy:  0.5744
  Average F1 Score:  0.1255
  Average Precision: 0.1478
  Average Recall:    

# GAT training with focal loss

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, Dropout, Sigmoid, ModuleDict, LayerNorm, LSTM, GELU, MultiheadAttention, Softmax
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, classification_report
import pandas as pd
import numpy as np
import json
import math
import os
from tqdm.autonotebook import tqdm

# PyTorch Geometric is optional: it provides the GAT layer (GATConv) and the
# graph container (Data). When it cannot be imported, both names are bound to
# None so later code can check `if GATConv` / `if Data` instead of failing here.
try:
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
except ImportError:
    print("PyTorch Geometric not found. Please install: pip install torch_geometric")
    GATConv = None
    Data = None

# Names of the six prediction targets. The models in this file create one
# prediction head per entry and key their output dicts by these names.
moral_foundations = ['Care', 'Fairness', 'Loyalty', 'Authority', 'Purity', 'Non_Moral']

# ===================================================================
# 1. eMFD Processing and Graph Construction
# ===================================================================
class eMFDProcessor:
    """Loads the eMFD word list and matches text tokens against it.

    Attributes set by ``__init__``:
        moral_foundations_emfd: Lower-case foundation names used to derive the
            probability column names.
        emfd_data: Dict mapping each word to a float32 vector of six values:
            the five ``<foundation>_p`` probabilities followed by 0.0 for
            ``non-moral``.
    """

    def __init__(self, emfd_csv_path):
        """Read the word list.

        Args:
            emfd_csv_path: CSV with a ``word`` column and the columns
                ``care_p``, ``fairness_p``, ``loyalty_p``, ``authority_p`` and
                ``purity_p``. A missing probability column counts as 0.0.

        Raises:
            FileNotFoundError: If ``emfd_csv_path`` does not exist.
        """
        self.moral_foundations_emfd = ['care', 'fairness', 'loyalty', 'authority', 'purity', 'non-moral']
        if not os.path.exists(emfd_csv_path):
            raise FileNotFoundError(f"eMFD file not found at {emfd_csv_path}")
        df = pd.read_csv(emfd_csv_path)
        prob_cols = [f'{f}_p' for f in self.moral_foundations_emfd if f != 'non-moral']

        # Build the whole table at once instead of iterating over rows.
        # reindex supplies 0.0 for any probability column the file lacks.
        probabilities = df.reindex(columns=prob_cols, fill_value=0.0).to_numpy(dtype=np.float32)
        # The word list has no non-moral probability, so that slot is always 0.
        non_moral = np.zeros((len(df), 1), dtype=np.float32)
        table = np.hstack([probabilities, non_moral])
        # If a word occurs more than once, the last row wins.
        self.emfd_data = dict(zip(df['word'], table))

    def extract_moral_concepts(self, text):
        """Look up every whitespace-separated token of ``text`` in the lexicon.

        Args:
            text: Raw text. It is lower-cased and split on whitespace only, so
                a token with punctuation attached will not match. Non-string
                input (``None``, or a NaN from an empty CSV cell) is treated
                as empty text instead of raising.

        Returns:
            A list of ``{'word', 'probabilities'}`` dicts, one per matching
            token in order of appearance. When nothing matches, a single
            placeholder concept named ``'neutral'`` with six zero
            probabilities is returned, so the list is never empty.
        """
        tokens = text.lower().split() if isinstance(text, str) else []
        lexicon = self.emfd_data
        concepts = [
            {'word': token, 'probabilities': lexicon[token]}
            for token in tokens if token in lexicon
        ]
        if concepts:
            return concepts
        return [{'word': 'neutral', 'probabilities': np.zeros(6, dtype=np.float32)}]

class MoralGraphConstructor:
    """Builds a fully connected directed graph from a list of moral concepts."""

    def create_moral_graph(self, moral_concepts):
        """Create the graph for one text.

        Args:
            moral_concepts: List of dicts, each with a ``'probabilities'``
                vector (as produced by ``extract_moral_concepts``).

        Returns:
            A ``Data`` object whose ``x`` holds one row per concept (the
            probability vector followed by 256 zeros) and whose ``edge_index``
            connects every ordered pair of distinct nodes. Returns ``None``
            when ``Data`` is unavailable (PyTorch Geometric missing).
        """
        # The 256 trailing zeros are a placeholder; GATeMFDModule.forward
        # overwrites columns 6 onward with learned embeddings.
        padding = np.zeros(256)
        feature_rows = [
            np.concatenate([concept['probabilities'], padding])
            for concept in moral_concepts
        ]
        node_features = torch.FloatTensor(np.array(feature_rows))

        node_count = len(moral_concepts)
        if node_count == 1:
            # A lone node has no neighbours, so give it a self-loop.
            edge_index = torch.tensor([[0], [0]], dtype=torch.long)
        else:
            directed_pairs = []
            for src in range(node_count):
                for dst in range(node_count):
                    if src != dst:
                        directed_pairs.append([src, dst])
            # Transpose the (num_edges, 2) pair list into the (2, num_edges) layout.
            edge_index = torch.tensor(directed_pairs, dtype=torch.long).t().contiguous()

        if not Data:
            return None
        return Data(x=node_features, edge_index=edge_index)

# ===================================================================
# 2. GAT eMFD Module
# ===================================================================
class GATeMFDModule(nn.Module):
    """Encode one moral-concept graph into a single pooled feature vector.

    Node features are the concept probabilities (first 6 columns) joined with
    a learned embedding. The embedding is looked up by the node's position in
    the graph, not by the word itself. One GAT layer is applied and the node
    outputs are combined by a softmax-weighted sum.

    Args:
        input_dim: Width of the node features fed to the GAT layer. The
            features built in ``forward`` are 6 + 256 wide, so the default of
            262 is the value that matches.
        output_dim: Width of the GAT output and of the pooled vector.
        num_heads: Number of GAT attention heads (averaged, ``concat=False``).
        dropout: Dropout passed to the GAT layer.
    """

    def __init__(self, input_dim=262, output_dim=256, num_heads=4, dropout=0.1):
        super(GATeMFDModule, self).__init__()
        self.output_dim = output_dim
        self.concept_embeddings = nn.Embedding(1000, 256)
        self.gat_layer = GATConv(input_dim, output_dim, heads=num_heads, dropout=dropout, concat=False) if GATConv else nn.Identity()
        self.attention_pooling = nn.Linear(output_dim, 1)

    def forward(self, graph_data, device):
        """Return a ``(1, output_dim)`` summary of ``graph_data``.

        Args:
            graph_data: Graph with ``x`` (one row per node, probabilities in
                the first 6 columns) and ``edge_index``.
            device: Device on which to run the computation.

        Returns:
            The attention-pooled node representation, or a zero vector of the
            same shape when PyTorch Geometric is unavailable.
        """
        if not GATConv:
            return torch.zeros((1, self.output_dim)).to(device)

        x = graph_data.x.to(device)
        edge_index = graph_data.edge_index.to(device)

        # Clamp so graphs with more nodes than embedding rows reuse the last
        # row instead of raising an index error.
        last_row = self.concept_embeddings.num_embeddings - 1
        positions = torch.arange(x.size(0), device=device).clamp(max=last_row)

        # Build a new tensor rather than assigning into x: on CPU, .to(device)
        # returns the same storage, so in-place writes would modify the
        # caller's graph_data.x.
        x = torch.cat([x[:, :6], self.concept_embeddings(positions)], dim=1)

        gat_output = F.elu(self.gat_layer(x, edge_index))
        # Softmax over the node axis turns the per-node scores into pooling weights.
        attn_weights = F.softmax(self.attention_pooling(gat_output), dim=0)
        return torch.sum(attn_weights * gat_output, dim=0, keepdim=True)

class TextModelGATeMFD(nn.Module):
    """Text branch: eMFD lookup -> concept graph -> GAT pooling -> MLP -> heads.

    Args:
        emfd_csv_path: Path handed to ``eMFDProcessor`` to load the lexicon.
    """

    def __init__(self, emfd_csv_path):
        super().__init__()
        self.emfd_processor = eMFDProcessor(emfd_csv_path)
        self.graph_constructor = MoralGraphConstructor()
        self.gat_module = GATeMFDModule()

        # Additional processing layers
        self.processor = Sequential(
            Linear(256, 512), LayerNorm(512), GELU(), Dropout(0.1),
            Linear(512, 256), LayerNorm(256), GELU(), Dropout(0.1)
        )

        # One head per foundation. Heads output raw logits (no Sigmoid); the
        # sigmoid is applied by the logits-based loss and at evaluation time.
        self.heads = ModuleDict({
            foundation: Sequential(Linear(256, 128), ReLU(), Dropout(0.2),
                                 Linear(128, 1))
            for foundation in moral_foundations
        })

    def _encode_text(self, text, device):
        """Return the ``(1, 256)`` graph representation for a single text."""
        moral_concepts = self.emfd_processor.extract_moral_concepts(text)
        graph_data = self.graph_constructor.create_moral_graph(moral_concepts)
        # The constructor returns None when no graph can be built; test for
        # that explicitly rather than relying on the object's truthiness.
        if graph_data is None:
            return torch.zeros((1, 256), device=device)
        return self.gat_module(graph_data, device)

    def get_features(self, texts):
        """Return the processed ``(len(texts), 256)`` feature matrix.

        Each text is encoded independently (one graph per text), then the rows
        are concatenated and passed through ``self.processor``.
        """
        device = next(self.parameters()).device
        features = torch.cat([self._encode_text(text, device) for text in texts], dim=0)
        return self.processor(features)

    def forward(self, texts):
        """Return ``{foundation: logits}`` with logits of shape ``(len(texts), 1)``."""
        processed_features = self.get_features(texts)
        return {f: self.heads[f](processed_features) for f in moral_foundations}

# ===================================================================
# 3. Other Modality Models
# ===================================================================
class SpatialModel(nn.Module):
    """MLP encoder for spatial features with one logit head per moral foundation."""

    def __init__(self, input_dim):
        super().__init__()
        # Two-layer MLP lifting the raw spatial vector to a 256-wide embedding.
        encoder_layers = [
            Linear(input_dim, 128),
            ReLU(),
            Dropout(0.2),
            Linear(128, 256),
        ]
        self.encoder = Sequential(*encoder_layers)

        # Heads end in a plain Linear layer: they emit logits, not probabilities.
        per_foundation = {}
        for name in moral_foundations:
            per_foundation[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
        self.heads = ModuleDict(per_foundation)

    def get_features(self, x):
        """Return the 256-wide embedding produced by the encoder."""
        return self.encoder(x)

    def forward(self, x):
        """Return ``{foundation: logits}`` computed from the encoded input."""
        embedding = self.encoder(x)
        outputs = {}
        for name in moral_foundations:
            outputs[name] = self.heads[name](embedding)
        return outputs

class TemporalModelLSTM(nn.Module):
    """Bidirectional LSTM encoder for temporal sequences with per-foundation heads.

    Args:
        input_dim: Number of features per timestep.
        seq_len: Accepted for interface compatibility; the LSTM handles any
            sequence length, so the value is not used.
        hidden_size: Width of the input projection and of each LSTM direction.
    """

    def __init__(self, input_dim, seq_len=10, hidden_size=128):
        super().__init__()
        self.projection = Linear(input_dim, hidden_size)
        self.lstm = LSTM(hidden_size, hidden_size, 2, batch_first=True,
                        dropout=0.2, bidirectional=True)
        self.encoder = Sequential(
            Linear(hidden_size * 2, 256), LayerNorm(256),
            ReLU(), Dropout(0.2)
        )
        # Heads end in a plain Linear layer: they emit logits, not probabilities.
        self.heads = ModuleDict({
            f: Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
            for f in moral_foundations
        })

    def get_features(self, x):
        """Return the 256-wide sequence embedding.

        ``x`` is batch-first (``batch_first=True`` on the LSTM).
        """
        proj = self.projection(x)
        _, (h_n, _) = self.lstm(proj)
        # Summarise the sequence with the top layer's final hidden state of each
        # direction. Taking ``lstm_out[:, -1]`` instead would pair the forward
        # state after the full sequence with the backward state after a single
        # step, because the backward pass starts at the last timestep.
        # h_n is ordered (layer0-fwd, layer0-bwd, layer1-fwd, layer1-bwd).
        summary = torch.cat([h_n[-2], h_n[-1]], dim=-1)
        return self.encoder(summary)

    def forward(self, x):
        """Return ``{foundation: logits}`` for a batch of sequences."""
        features = self.get_features(x)
        return {f: self.heads[f](features) for f in moral_foundations}

class BehavioralModel(nn.Module):
    """MLP encoder for behavioral features with one logit head per moral foundation."""

    def __init__(self, input_dim):
        super().__init__()
        # Two Linear -> LayerNorm -> ReLU -> Dropout stages: input_dim -> 128 -> 256.
        encoder_layers = [
            Linear(input_dim, 128), LayerNorm(128), ReLU(), Dropout(0.1),
            Linear(128, 256), LayerNorm(256), ReLU(), Dropout(0.1),
        ]
        self.encoder = Sequential(*encoder_layers)

        # Heads end in a plain Linear layer: they emit logits, not probabilities.
        per_foundation = {}
        for name in moral_foundations:
            per_foundation[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
        self.heads = ModuleDict(per_foundation)

    def get_features(self, x):
        """Return the 256-wide embedding produced by the encoder."""
        return self.encoder(x)

    def forward(self, x):
        """Return ``{foundation: logits}`` computed from the encoded input."""
        embedding = self.encoder(x)
        outputs = {}
        for name in moral_foundations:
            outputs[name] = self.heads[name](embedding)
        return outputs

# ===================================================================
# 4. Heterogeneous Fusion
# ===================================================================
class HeterogeneousFusion(nn.Module):
    """Fuses text, spatial, temporal and behavioral embeddings via attention and gating.

    Pipeline: per-modality projection -> per-modality cross-attention with a
    learned query -> self-attention across the four modality tokens -> softmax
    gate that mixes the four tokens into one vector -> per-foundation heads.

    Args:
        input_dims: Optional iterable with the feature width of each modality,
            in the order (text, spatial, temporal, behavioral). Defaults to
            four inputs of width 256.
        d_model: Shared width used after projection.
        num_heads: Number of heads in every attention layer.
    """

    def __init__(self, input_dims=None, d_model=256, num_heads=8):
        super().__init__()
        self.d_model = d_model

        # Feature projection layers to align dimensions
        if input_dims:
            self.projections = nn.ModuleList([
                nn.Linear(dim, d_model) for dim in input_dims
            ])
        else:
            self.projections = nn.ModuleList([
                nn.Linear(256, d_model) for _ in range(4)
            ])

        # One learned query vector per modality.
        self.modality_queries = nn.Parameter(torch.randn(4, d_model))

        self.cross_attention = nn.ModuleList([
            MultiheadAttention(d_model, num_heads, batch_first=True)
            for _ in range(4)
        ])

        self.self_attention = MultiheadAttention(d_model, num_heads, batch_first=True)
        # Norms 0-3 follow the cross-attention blocks, norm 4 the self-attention.
        self.layer_norms = nn.ModuleList([LayerNorm(d_model) for _ in range(5)])

        # Gate: concatenated modality tokens -> 4 mixing weights that sum to 1.
        # ``nn.Softmax`` is referenced through ``nn`` so the block does not
        # depend on a separate bare ``Softmax`` import.
        self.modality_gate = nn.Sequential(
            Linear(d_model * 4, d_model), ReLU(), Dropout(0.1),
            Linear(d_model, 4), nn.Softmax(dim=-1)
        )

        self.final_projection = Linear(d_model, d_model)

        # Heads end in a plain Linear layer: they emit logits, not probabilities.
        self.heads = ModuleDict({
            f: Sequential(Linear(d_model, 128), ReLU(), Dropout(0.2),
                         Linear(128, 1))
            for f in moral_foundations
        })

    def forward(self, text_feat, spatial_feat, temporal_feat, behavioral_feat):
        """Return ``{foundation: logits}`` for a batch of per-modality embeddings.

        Each argument is a 2-D ``(batch, features)`` tensor; all four must share
        the same batch size.
        """
        batch_size = text_feat.size(0)

        # Project all features to same dimension
        features = [
            self.projections[i](feat)
            for i, feat in enumerate([text_feat, spatial_feat, temporal_feat, behavioral_feat])
        ]

        # Add sequence dimension for attention
        modality_feats = [f.unsqueeze(1) for f in features]

        # Cross-attention with residual connections
        attended = []
        for i, (feat, attn, norm) in enumerate(zip(modality_feats, self.cross_attention, self.layer_norms[:4])):
            query = self.modality_queries[i:i+1].unsqueeze(0).expand(batch_size, -1, -1)
            att_out, _ = attn(query, feat, feat)
            attended.append(norm(att_out.squeeze(1) + features[i]))

        # Stack into (batch, 4, d_model) so the modalities can attend to each other.
        stacked = torch.stack(attended, dim=1)

        # Self-attention
        self_att, _ = self.self_attention(stacked, stacked, stacked)
        fused = self.layer_norms[4](stacked + self_att)

        # Gates are computed from the concatenation of the four fused tokens.
        gates = self.modality_gate(fused.reshape(batch_size, -1))

        # Weighted sum of the modality tokens using the gate weights.
        gated = torch.sum(fused * gates.unsqueeze(-1), dim=1)

        # Final projection
        final_features = self.final_projection(gated)

        return {f: self.heads[f](final_features) for f in moral_foundations}

# ===================================================================
# 5. Custom Loss Function
# ===================================================================
class MultiLabelFocalLoss(nn.Module):
    """Focal loss for multi-label classification on raw logits.

    Args:
        gamma: Focusing exponent; ``0`` reduces the loss to (weighted) BCE.
        alpha: Optional positive-class weight, forwarded as ``pos_weight`` to
            ``binary_cross_entropy_with_logits``. May be a tensor or a number.
    """

    def __init__(self, gamma=2.0, alpha=None):
        super(MultiLabelFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        """Return the mean focal loss over all elements of ``inputs``.

        Args:
            inputs: Logits.
            targets: Float targets with the same shape as ``inputs``.
        """
        # The modulating factor needs the model's probability of the true
        # class, so p_t is derived from the *unweighted* BCE. Deriving it from
        # the pos_weight-scaled loss would give p_t ** weight instead.
        raw_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_t = torch.exp(-raw_bce)

        if self.alpha is None:
            weighted_bce = raw_bce
        else:
            # Match dtype/device of the logits so a CPU-built alpha works on GPU.
            pos_weight = torch.as_tensor(self.alpha, dtype=inputs.dtype, device=inputs.device)
            weighted_bce = F.binary_cross_entropy_with_logits(
                inputs, targets, reduction='none', pos_weight=pos_weight
            )

        focal_loss = ((1 - p_t) ** self.gamma) * weighted_bce
        return focal_loss.mean()

# ===================================================================
# 6. Data Processing Functions
# ===================================================================
def prepare_labels(df):
    """Collapse virtue/vice columns into per-foundation labels plus Non_Moral.

    A foundation is positive when either its virtue or its vice column equals 1.
    ``Non_Moral`` is 1 exactly when all five foundations are negative. The
    result is a float32 tensor with one column per entry of
    ``moral_foundations``, in that order.
    """
    frame = df.reset_index(drop=True)

    # The vice column for Fairness may be spelled 'Cheating' or 'cheating';
    # when neither exists an all-False series stands in for it.
    no_cheating = pd.Series([False]*len(frame), index=frame.index)
    cheating = frame.get('Cheating', frame.get('cheating', no_cheating))

    labels = {
        'Care': (frame['Care'] == 1) | (frame['Harm'] == 1),
        'Fairness': (frame['Fairness'] == 1) | (cheating == 1),
        'Loyalty': (frame['Loyalty'] == 1) | (frame['Betrayal'] == 1),
        'Authority': (frame['Authority'] == 1) | (frame['Subversion'] == 1),
        'Purity': (frame['Purity'] == 1) | (frame['Degradation'] == 1),
    }

    # Rows where none of the five foundations fired are labelled Non_Moral.
    any_moral = pd.DataFrame(labels).any(axis=1)
    labels['Non_Moral'] = (~any_moral).astype(int)

    columns = []
    for name in moral_foundations:
        columns.append(labels[name].astype(int).values)
    return torch.tensor(np.column_stack(columns), dtype=torch.float32)

def create_temporal_sequences(df, seq_len=10):
    """Build sliding windows of ``seq_len`` consecutive rows per GEOID.

    Rows are sorted by GEOID, year, month and day. Groups with fewer than
    ``seq_len`` rows are skipped. For every window the label, text, GEOID and
    behavioral features come from the window's last row.

    Per-group work (column extraction and labelling) is done once per GEOID
    instead of once per window, which avoids re-iterating the same rows with
    ``iterrows`` and building a one-row DataFrame for each window.

    Args:
        df: DataFrame holding the temporal, behavioral, label, ``text`` and
            ``GEOID`` columns referenced below.
        seq_len: Number of consecutive rows per window.

    Returns:
        ``(sequences, targets, other_features)`` where ``sequences[i]`` is a
        ``seq_len`` x 23 nested list, ``targets[i]`` is the label tensor of the
        last row, and ``other_features[i]`` is a dict with ``'text'``,
        ``'geoid'`` (zero-padded to 5 characters) and ``'behavioral'``.
    """
    temporal_columns = [
        'month', 'day', 'year', 'after_05/26/2020',
        'female_pct', 'lesscollege_pct', 'cvap',
        'clf_unemploy_pct', 'lesshs_pct', 'median_hh_inc',
        'urm_pct', 'ruralurban_cc', 'demgov', 'repgov',
        'net_dem_president_votes', 'net_dem_gov_votes', 'is_blue',
        'deaths', 'cases', 'cases_per_capita',
        'deaths_per_capita', 'cases_per_capita_discrete', 'mask_score',
    ]
    behavioral_columns = [
        'retweet_count', 'rt_discrete', 'for_sah', 'is_vivid', 'sentiment_score',
    ]

    df_sorted = df.sort_values(['GEOID', 'year', 'month', 'day'])
    sequences, targets, other_features = [], [], []

    for _, group in df_sorted.groupby('GEOID'):
        num_rows = len(group)
        if num_rows < seq_len:
            continue

        # Extract everything the windows need once per group.
        temporal_rows = group[temporal_columns].to_numpy().tolist()
        behavioral_rows = group[behavioral_columns].to_numpy().tolist()
        texts = group['text'].tolist()
        geoids = group['GEOID'].tolist()
        group_labels = prepare_labels(group)  # one label row per group row

        for start in range(num_rows - seq_len + 1):
            stop = start + seq_len
            last = stop - 1

            # Copy the inner rows so overlapping windows do not share lists.
            sequences.append([list(row) for row in temporal_rows[start:stop]])
            targets.append(group_labels[last])
            other_features.append({
                'text': texts[last],
                'geoid': str(geoids[last]).zfill(5),
                'behavioral': behavioral_rows[last],
            })

    return sequences, targets, other_features

def _spatial_features_for_county(geoid, boundaries, centroids):
    """Return the 6-element spatial feature list for one county GEOID."""
    if geoid in boundaries and geoid in centroids:
        centroid = centroids[geoid]
        # NOTE(review): index 0 is treated as latitude and index 1 as
        # longitude; GeoJSON-style coordinates are usually [lon, lat] --
        # confirm against the centroid file.
        lat_norm = (centroid[0] - 24) / (49 - 24)      # 24..49   -> 0..1
        # Longitudes run from -125 to -66, so the span is 125 - 66 = 59.
        # Dividing by (66 - 125) would flip the sign and yield -1..0.
        lon_norm = (centroid[1] + 125) / (125 - 66)    # -125..-66 -> 0..1
        return [lat_norm, lon_norm, 0.5, 0.5, 0.5, 0.5]
    # County missing from either lookup: neutral placeholder.
    return [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]


def prepare_dataset(csv_path, geojson_path, county_centroids_path, seq_len=10):
    """Load the CSV and build scaled train/validation/test splits.

    Args:
        csv_path: CSV with one row per observation.
        geojson_path: JSON file used only for membership tests on the GEOID.
        county_centroids_path: JSON file mapping GEOID to a centroid pair.
        seq_len: Window length passed to ``create_temporal_sequences``.

    Returns:
        ``(datasets, feature_info)``. ``datasets`` has the keys ``'train'``,
        ``'validation'`` and ``'test'`` (70/15/15 split, ``random_state=42``),
        each holding ``text_data``, ``spatial_features``,
        ``temporal_sequences``, ``behavioral_features`` and ``moral_targets``.
        ``feature_info`` reports the feature widths and the sequence length.

    Raises:
        ValueError: If no GEOID has at least ``seq_len`` rows, so no window
            could be built.

    Note:
        The scalers are fitted on the training rows only and then applied to
        every split, so validation/test statistics do not leak into the
        features. Sliding windows from the same GEOID can still overlap across
        splits.
    """
    # Load data
    df = pd.read_csv(csv_path)
    with open(geojson_path, encoding='utf-8') as f:
        boundaries = json.load(f)
    with open(county_centroids_path, encoding='utf-8') as f:
        centroids = json.load(f)

    # Create sequences
    sequences, targets, other_features = create_temporal_sequences(df, seq_len)
    if not sequences:
        raise ValueError(
            f"No temporal sequences could be built: no GEOID has at least {seq_len} rows."
        )

    texts = [other['text'] for other in other_features]
    spatial = np.asarray(
        [_spatial_features_for_county(other['geoid'], boundaries, centroids)
         for other in other_features],
        dtype=np.float32,
    )
    temporal = np.asarray(sequences, dtype=np.float32)
    behavioral = np.asarray([other['behavioral'] for other in other_features], dtype=np.float32)
    targets_tensor = torch.stack(list(targets))

    # Create splits first so the scalers can be fitted on training rows only.
    indices = list(range(len(texts)))
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

    # Normalize features (fit on train, transform everything).
    scaler_spatial = StandardScaler().fit(spatial[train_idx])
    scaler_behavioral = StandardScaler().fit(behavioral[train_idx])
    spatial_tensor = torch.tensor(scaler_spatial.transform(spatial), dtype=torch.float32)
    behavioral_tensor = torch.tensor(scaler_behavioral.transform(behavioral), dtype=torch.float32)

    # Temporal data is flattened to (windows * timesteps, features) for scaling.
    temp_shape = temporal.shape
    num_temporal_features = temp_shape[2]
    scaler_temporal = StandardScaler().fit(
        temporal[train_idx].reshape(-1, num_temporal_features)
    )
    temp_scaled = scaler_temporal.transform(temporal.reshape(-1, num_temporal_features))
    temporal_tensor = torch.tensor(temp_scaled.reshape(temp_shape), dtype=torch.float32)

    datasets = {}
    for split, idx in [('train', train_idx), ('validation', val_idx), ('test', test_idx)]:
        datasets[split] = {
            'text_data': [texts[i] for i in idx],
            'spatial_features': spatial_tensor[idx],
            'temporal_sequences': temporal_tensor[idx],
            'behavioral_features': behavioral_tensor[idx],
            'moral_targets': targets_tensor[idx]
        }

    feature_info = {
        'spatial_feature_dim': spatial_tensor.shape[1],
        'temporal_feature_dim': temporal_tensor.shape[2],
        'temporal_sequence_length': seq_len,
        'behavioral_feature_dim': behavioral_tensor.shape[1]
    }

    return datasets, feature_info

def create_dataloader(dataset_dict, batch_size=16, shuffle=True):
    """Wrap one split dictionary in a ``DataLoader`` that yields batch dicts.

    Every batch is a dict with ``text_data`` (list of strings) and the stacked
    tensors ``spatial_features``, ``temporal_sequences``,
    ``behavioral_features`` and ``targets``.
    """
    field_names = (
        'text_data',
        'spatial_features',
        'temporal_sequences',
        'behavioral_features',
        'moral_targets',
    )
    num_samples = len(dataset_dict['text_data'])

    # One (text, spatial, temporal, behavioral, target) tuple per sample.
    samples = []
    for idx in range(num_samples):
        samples.append(tuple(dataset_dict[name][idx] for name in field_names))

    def collate_fn(batch):
        # Transpose the list of sample tuples into one tuple per field.
        texts, spatial, temporal, behavioral, targets = zip(*batch)
        return {
            'text_data': list(texts),
            'spatial_features': torch.stack(list(spatial)),
            'temporal_sequences': torch.stack(list(temporal)),
            'behavioral_features': torch.stack(list(behavioral)),
            'targets': torch.stack(list(targets)),
        }

    return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

# ===================================================================
# 7. Multi-Stage Training System
# ===================================================================
class MultiStageTrainer:
    """Runs the four-stage schedule: text model, other modalities, fusion, fine-tuning.

    Args:
        models: 5-tuple ``(text, spatial, temporal, behavioral, fusion)``.
        datasets: Dict with ``'train'`` and ``'validation'`` split dictionaries
            accepted by ``create_dataloader``.
        feature_info: Accepted for interface compatibility; not used here.
        device: Device all models and batches are moved to.
        emfd_csv_path: Accepted for interface compatibility; not used here.
        focal_loss_alpha: Optional positive-class weight for the focal loss.
        focal_loss_gamma: Focusing exponent for the focal loss.
    """

    def __init__(self, models, datasets, feature_info, device, emfd_csv_path, focal_loss_alpha=None, focal_loss_gamma=2.0):
        self.text_model, self.spatial_model, self.temporal_model, self.behavioral_model, self.fusion_model = models
        self.datasets = datasets
        self.device = device

        # Move models to device
        self.text_model = self.text_model.to(device)
        self.spatial_model = self.spatial_model.to(device)
        self.temporal_model = self.temporal_model.to(device)
        self.behavioral_model = self.behavioral_model.to(device)
        self.fusion_model = self.fusion_model.to(device)

        # Create data loaders
        self.train_loader = create_dataloader(datasets['train'], batch_size=16, shuffle=True)
        self.val_loader = create_dataloader(datasets['validation'], batch_size=16, shuffle=False)

        # Loss function with focal loss
        print(f"🔥 Using MultiLabelFocalLoss with gamma={focal_loss_gamma}")
        self.criterion = MultiLabelFocalLoss(gamma=focal_loss_gamma, alpha=focal_loss_alpha)

    def _freeze(self, model):
        """Disable gradients for every parameter of *model*."""
        for p in model.parameters():
            p.requires_grad_(False)

    def _unfreeze(self, model):
        """Enable gradients for every parameter of *model*."""
        for p in model.parameters():
            p.requires_grad_(True)

    def _to_device(self, batch):
        """Move every tensor value of *batch* to ``self.device``; leave the rest as-is."""
        return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}

    def _set_fusion_modes(self, train_text):
        """Put fusion in train mode and the feature extractors in the right mode.

        The spatial, temporal and behavioral extractors are frozen during
        fusion training, so they run in eval mode (no dropout noise). The text
        model trains only when *train_text* is true.
        """
        self.fusion_model.train()
        self.text_model.train(train_text)
        self.spatial_model.eval()
        self.temporal_model.eval()
        self.behavioral_model.eval()

    def _extract_features(self, batch, text_grad=False):
        """Return ``(text, spatial, temporal, behavioral)`` features for *batch*.

        Gradients flow through the text model only when *text_grad* is true;
        the other three extractors always run under ``torch.no_grad()``.
        """
        if text_grad:
            text_feat = self.text_model.get_features(batch['text_data'])
        else:
            with torch.no_grad():
                text_feat = self.text_model.get_features(batch['text_data'])
        with torch.no_grad():
            spatial_feat = self.spatial_model.get_features(batch['spatial_features'])
            temporal_feat = self.temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = self.behavioral_model.get_features(batch['behavioral_features'])
        return text_feat, spatial_feat, temporal_feat, behavioral_feat

    def compute_loss(self, predictions, targets):
        """Compute loss for moral foundation predictions"""
        pred_stack = torch.cat([predictions[f] for f in moral_foundations], dim=1)
        return self.criterion(pred_stack, targets)

    def train_stage1_text_gat_emfd(self, epochs=15, lr=2e-5):
        """Stage 1: Train GAT eMFD Text Model"""
        print("\n" + "="*20 + " Stage 1: Training GAT eMFD Text Model " + "="*20)

        # Freeze other models
        self._freeze(self.spatial_model)
        self._freeze(self.temporal_model)
        self._freeze(self.behavioral_model)
        self._freeze(self.fusion_model)
        self._unfreeze(self.text_model)

        optimizer = optim.AdamW(self.text_model.parameters(), lr=lr, weight_decay=1e-5)

        for epoch in range(epochs):
            self.text_model.train()
            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"Stage 1 - Epoch {epoch+1}/{epochs}"):
                batch = self._to_device(batch)

                optimizer.zero_grad()
                preds = self.text_model(batch['text_data'])
                loss = self.compute_loss(preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if epoch % 3 == 0:
                print(f"GAT eMFD Text model epoch {epoch+1}, loss: {total_loss/len(self.train_loader):.4f}")

        # Save model
        torch.save(self.text_model.state_dict(), "gat_emfd_text_stage1.pth")
        print("✅ Stage 1 Complete. GAT eMFD Text model saved.")

    def train_stage2_other_modalities(self, epochs=10, lr=5e-4):
        """Stage 2: Train other modality models"""
        print("\n" + "="*20 + " Stage 2: Training Other Modality Models " + "="*20)

        # Freeze text and fusion models
        self._freeze(self.text_model)
        self._freeze(self.fusion_model)

        # Each model is paired with the batch key that holds its input.
        models_to_train = [
            (self.spatial_model, 'spatial', 'spatial_features'),
            (self.temporal_model, 'temporal', 'temporal_sequences'),
            (self.behavioral_model, 'behavioral', 'behavioral_features')
        ]

        for model, name, input_key in models_to_train:
            print(f"\n--- Training {name.capitalize()} Model ---")
            self._unfreeze(model)
            optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

            for epoch in range(epochs):
                model.train()
                total_loss = 0

                for batch in tqdm(self.train_loader, desc=f"{name.capitalize()} - Epoch {epoch+1}/{epochs}"):
                    batch = self._to_device(batch)

                    optimizer.zero_grad()
                    preds = model(batch[input_key])
                    loss = self.compute_loss(preds, batch['targets'])
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()

                if epoch % 3 == 0:
                    print(f"{name.capitalize()} model epoch {epoch+1}, loss: {total_loss/len(self.train_loader):.4f}")

            # Save model
            torch.save(model.state_dict(), f"{name}_model_stage2.pth")
            self._freeze(model)

        print("✅ Stage 2 Complete. All modality models saved.")

    def train_stage3_fusion_integration(self, epochs=8, lr=1e-3):
        """Stage 3: Fusion Integration"""
        print("\n" + "="*15 + " Stage 3: Fusion Integration " + "="*15)

        # Freeze all individual models, unfreeze fusion
        self._freeze(self.text_model)
        self._freeze(self.spatial_model)
        self._freeze(self.temporal_model)
        self._freeze(self.behavioral_model)
        self._unfreeze(self.fusion_model)

        optimizer = optim.AdamW(self.fusion_model.parameters(), lr=lr, weight_decay=1e-4)

        for epoch in range(epochs):
            # The extractors are left in train mode by stages 1-2; switch them
            # to eval every epoch so their dropout does not perturb the frozen
            # features (otherwise only epochs after a validation run are clean).
            self._set_fusion_modes(train_text=False)
            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"Stage 3 - Epoch {epoch+1}/{epochs}"):
                batch = self._to_device(batch)

                optimizer.zero_grad()

                # Extract features with no gradients for individual models
                text_feat, spatial_feat, temporal_feat, behavioral_feat = self._extract_features(batch)

                # Fusion with gradients
                fusion_preds = self.fusion_model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
                loss = self.compute_loss(fusion_preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if epoch % 2 == 0:
                val_f1 = self._validate_fusion()
                print(f"Stage 3 - Epoch {epoch+1}: Loss: {total_loss/len(self.train_loader):.4f}, Val F1: {val_f1:.4f}")

        torch.save(self.fusion_model.state_dict(), "fusion_model_stage3.pth")
        print("✅ Stage 3 Complete. Fusion model saved.")

    def train_stage4_end_to_end_finetuning(self, epochs=5, best_model_path="best_fusion_model.pth"):
        """Stage 4: End-to-end fine-tuning"""
        print("\n" + "="*15 + " Stage 4: End-to-End Fine-tuning " + "="*15)

        # Unfreeze GAT eMFD and fusion, keep others frozen
        self._unfreeze(self.text_model)
        self._freeze(self.spatial_model)
        self._freeze(self.temporal_model)
        self._freeze(self.behavioral_model)
        self._unfreeze(self.fusion_model)

        optimizer = optim.AdamW([
            {'params': self.text_model.parameters(), 'lr': 5e-6},
            {'params': self.fusion_model.parameters(), 'lr': 1e-4}
        ], weight_decay=1e-4)

        # Start below any reachable F1 so the first epoch always writes a
        # checkpoint, even when validation F1 is exactly 0.
        best_val_f1 = float('-inf')
        patience_counter = 0

        for epoch in range(epochs):
            # Validation switches everything to eval, so reset modes each epoch.
            self._set_fusion_modes(train_text=True)

            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"Stage 4 - Epoch {epoch+1}/{epochs}"):
                batch = self._to_device(batch)

                optimizer.zero_grad()

                # Extract features (GAT eMFD with gradients, others frozen)
                text_feat, spatial_feat, temporal_feat, behavioral_feat = self._extract_features(
                    batch, text_grad=True
                )

                # Fusion
                fusion_preds = self.fusion_model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
                loss = self.compute_loss(fusion_preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            val_f1 = self._validate_fusion()
            print(f"Stage 4 - Epoch {epoch+1}: Loss: {total_loss/len(self.train_loader):.4f}, Val F1: {val_f1:.4f}")

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                patience_counter = 0
                torch.save({
                    'text_model': self.text_model.state_dict(),
                    'spatial_model': self.spatial_model.state_dict(),
                    'temporal_model': self.temporal_model.state_dict(),
                    'behavioral_model': self.behavioral_model.state_dict(),
                    'fusion_model': self.fusion_model.state_dict()
                }, best_model_path)
                print(f"🏆 New best model saved with F1: {val_f1:.4f}")
            else:
                # Stop after two consecutive epochs without improvement.
                patience_counter += 1
                if patience_counter >= 2:
                    print("🛑 Early stopping triggered.")
                    break

        print("✅ Stage 4 Complete. End-to-end training finished.")

    def _validate_fusion(self):
        """Return the macro F1 of the fusion model on the validation split.

        All five models are switched to eval mode and left there; callers that
        continue training must restore train mode themselves.
        """
        self.text_model.eval()
        self.spatial_model.eval()
        self.temporal_model.eval()
        self.behavioral_model.eval()
        self.fusion_model.eval()

        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in self.val_loader:
                batch = self._to_device(batch)

                text_feat, spatial_feat, temporal_feat, behavioral_feat = self._extract_features(batch)

                fusion_preds = self.fusion_model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
                pred_stack = torch.cat([fusion_preds[f] for f in moral_foundations], dim=1)

                # Logits -> probabilities -> hard labels at the 0.5 threshold.
                all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
                all_labels.append(batch['targets'].cpu().numpy())

        y_pred = np.vstack(all_preds)
        y_true = np.vstack(all_labels)
        return f1_score(y_true, y_pred, average='macro', zero_division=0)

# ===================================================================
# 8. Evaluation Functions
# ===================================================================
def evaluate_gat_emfd_only(test_loader, device, emfd_csv_path):
    """Evaluate the standalone GAT eMFD text model on *test_loader*.

    Loads the Stage 1 weights from ``gat_emfd_text_stage1.pth``, prints a
    per-foundation classification report and returns the macro F1 score.

    Args:
        test_loader: Iterable of batch dicts with ``text_data`` and ``targets``.
        device: Device used for the model and the batches.
        emfd_csv_path: Lexicon path passed to ``TextModelGATeMFD``.
    """
    print("\n" + "="*20 + " ABLATION: GAT eMFD TEXT MODEL ONLY " + "="*20)

    text_model = TextModelGATeMFD(emfd_csv_path).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state_dict = torch.load("gat_emfd_text_stage1.pth", map_location=device)
    text_model.load_state_dict(state_dict)
    text_model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing GAT eMFD"):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            preds = text_model(batch['text_data'])
            pred_stack = torch.cat([preds[f] for f in moral_foundations], dim=1)

            # Logits -> probabilities -> hard labels at the 0.5 threshold.
            all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
            all_labels.append(batch['targets'].cpu().numpy())

    y_pred = np.vstack(all_preds)
    y_true = np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print(classification_report(y_true, y_pred, target_names=moral_foundations, zero_division=0))
    print(f"** GAT eMFD Macro F1-Score: {macro_f1:.4f} **")
    return macro_f1

def evaluate_individual_models(test_loader, device, feature_info, emfd_csv_path):
    """Evaluate each single-modality model on *test_loader*.

    Weights come from the Stage 1 / Stage 2 checkpoint files written by the
    trainer.

    Args:
        test_loader: Iterable of batch dicts as produced by ``create_dataloader``.
        device: Device used for the models and the batches.
        feature_info: Dict with the feature widths used to size the models.
        emfd_csv_path: Lexicon path passed to ``TextModelGATeMFD``.

    Returns:
        ``{model_name: {'macro_f1': float, 'foundation_f1s': list}}`` where
        ``foundation_f1s`` follows the order of ``moral_foundations``.
    """
    print("\n" + "="*20 + " INDIVIDUAL MODEL EVALUATION " + "="*20)

    # Load models
    text_model = TextModelGATeMFD(emfd_csv_path).to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim']).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim']).to(device)

    # Load weights; map_location lets GPU-saved checkpoints load on CPU hosts.
    text_model.load_state_dict(torch.load("gat_emfd_text_stage1.pth", map_location=device))
    spatial_model.load_state_dict(torch.load("spatial_model_stage2.pth", map_location=device))
    temporal_model.load_state_dict(torch.load("temporal_model_stage2.pth", map_location=device))
    behavioral_model.load_state_dict(torch.load("behavioral_model_stage2.pth", map_location=device))

    # (model, display name, batch key holding that model's input)
    models = [
        (text_model, 'GAT-eMFD Text', 'text_data'),
        (spatial_model, 'Spatial', 'spatial_features'),
        (temporal_model, 'Temporal', 'temporal_sequences'),
        (behavioral_model, 'Behavioral', 'behavioral_features')
    ]

    results = {}

    for model, name, input_key in models:
        model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc=f"Testing {name}"):
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batch.items()}

                preds = model(batch[input_key])

                pred_stack = torch.cat([preds[f] for f in moral_foundations], dim=1)
                # Logits -> probabilities -> hard labels at the 0.5 threshold.
                all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
                all_labels.append(batch['targets'].cpu().numpy())

        y_pred = np.vstack(all_preds)
        y_true = np.vstack(all_labels)
        macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

        print(f"\n--- {name} Model Results ---")
        print(f"Macro F1-Score: {macro_f1:.4f}")

        # Foundation-specific results (one binary F1 per label column).
        foundation_f1s = [
            f1_score(y_true[:, i], y_pred[:, i], zero_division=0)
            for i, _ in enumerate(moral_foundations)
        ]

        results[name] = {
            'macro_f1': macro_f1,
            'foundation_f1s': foundation_f1s
        }

    return results

def evaluate_fusion_model(model_path, test_loader, device, feature_info, emfd_csv_path):
    """Evaluate the final fusion model on a held-out loader.

    Rebuilds the four modality models and the fusion model, restores their
    weights from the checkpoint at ``model_path`` and scores the thresholded
    fusion predictions against the batch targets.

    Args:
        model_path: Path to a checkpoint dict with the keys ``text_model``,
            ``spatial_model``, ``temporal_model``, ``behavioral_model`` and
            ``fusion_model`` (one state dict each).
        test_loader: Iterable of batch dicts providing ``text_data``,
            ``spatial_features``, ``temporal_sequences``,
            ``behavioral_features`` and ``targets``.
        device: Device the models and tensor batch entries are moved to.
        feature_info: Dict providing ``spatial_feature_dim``,
            ``temporal_feature_dim``, ``temporal_sequence_length`` and
            ``behavioral_feature_dim``.
        emfd_csv_path: Path handed to ``TextModelGATeMFD``.

    Returns:
        The macro-averaged F1 score over all moral foundations.
    """
    print("\n" + "="*20 + " FINAL FUSION MODEL EVALUATION " + "="*20)

    # Load models
    text_model = TextModelGATeMFD(emfd_csv_path).to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim']).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim']).to(device)
    fusion_model = HeterogeneousFusion().to(device)

    # Load best model state. map_location remaps the stored tensors onto the
    # evaluation device, so a checkpoint written on a GPU can still be
    # restored on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=device)
    text_model.load_state_dict(checkpoint['text_model'])
    spatial_model.load_state_dict(checkpoint['spatial_model'])
    temporal_model.load_state_dict(checkpoint['temporal_model'])
    behavioral_model.load_state_dict(checkpoint['behavioral_model'])
    fusion_model.load_state_dict(checkpoint['fusion_model'])

    # Set to eval mode
    for module in (text_model, spatial_model, temporal_model, behavioral_model, fusion_model):
        module.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing Fusion Model"):
            # Only tensors are moved; the raw text list stays as it is.
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            # Extract features
            text_feat = text_model.get_features(batch['text_data'])
            spatial_feat = spatial_model.get_features(batch['spatial_features'])
            temporal_feat = temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = behavioral_model.get_features(batch['behavioral_features'])

            # Fusion
            fusion_preds = fusion_model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
            pred_stack = torch.cat([fusion_preds[f] for f in moral_foundations], dim=1)

            # The heads emit logits: squash to probabilities, then threshold at 0.5.
            all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
            all_labels.append(batch['targets'].cpu().numpy())

    y_pred = np.vstack(all_preds)
    y_true = np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print(classification_report(y_true, y_pred, target_names=moral_foundations, zero_division=0))
    print(f"** FUSION MODEL Macro F1-Score: {macro_f1:.4f} **")

    return macro_f1

def print_comprehensive_results(individual_results, fusion_f1):
    """Print a side-by-side summary of the single-modality and fusion scores.

    Args:
        individual_results: Dict mapping a model name to a dict with
            ``macro_f1`` (float) and ``foundation_f1s`` (one F1 per entry of
            ``moral_foundations``, in that order).
        fusion_f1: Macro F1 of the fusion model.
    """
    print("\n" + "="*80)
    print("COMPREHENSIVE MULTIMODAL GAT-eMFD EVALUATION RESULTS")
    print("=" * 80)

    # Individual model results
    for model_name, results in individual_results.items():
        print(f"\n{model_name} MODEL:")
        print("-" * 50)
        print(f"  Macro F1-Score: {results['macro_f1']:.4f}")
        for i, foundation in enumerate(moral_foundations):
            print(f"  {foundation} F1: {results['foundation_f1s'][i]:.4f}")

    print(f"\nFUSION MODEL:")
    print("-" * 50)
    print(f"  Macro F1-Score: {fusion_f1:.4f}")

    # Foundation-specific comparison
    print("\n" + "="*80)
    print("FOUNDATION-SPECIFIC F1 COMPARISON:")
    print("=" * 80)
    header = "Foundation     GAT-eMFD  Spatial   Temporal  Behavioral  Fusion"
    print(header)
    print("-" * 70)

    column_models = ['GAT-eMFD Text', 'Spatial', 'Temporal', 'Behavioral']
    for i, foundation in enumerate(moral_foundations):
        row = f"{foundation:<12}"
        for model_name in column_models:
            if model_name in individual_results:
                # Named model_f1 so the sklearn f1_score function is not shadowed.
                model_f1 = individual_results[model_name]['foundation_f1s'][i]
                row += f"  {model_f1:.3f}    "
            else:
                # Keep the column (same width as a score cell) so the values
                # that follow stay under their own headers.
                row += f"  {'n/a':<5}    "
        # Only the macro score is passed in for the fusion model, so the same
        # value is repeated on every foundation row.
        row += f"  {fusion_f1:.3f}"
        print(row)

# ===================================================================
# 9. Main Execution
# ===================================================================
if __name__ == "__main__":
    # File paths - UPDATE THESE
    CSV_PATH = '/content/augmented_tweets.csv'
    GEOJSON_PATH = '/content/county_data_boarders.json'
    CENTROIDS_PATH = '/content/county_centroids.json'
    EMFD_CSV_PATH = '/content/eMFD_wordlist.csv'
    # Checkpoint written by stage 4 and read back by the fusion evaluation.
    BEST_MODEL_PATH = "best_multimodal_gat_emfd.pth"

    if not GATConv:
        print("❌ PyTorch Geometric not available. Please install it.")
        # exit() is the interactive-shell helper; raising SystemExit stops the
        # script (or just the current notebook cell) with a failure status.
        raise SystemExit(1)

    # Prepare dataset
    print("🔄 Preparing dataset...")
    datasets, feature_info = prepare_dataset(
        csv_path=CSV_PATH,
        geojson_path=GEOJSON_PATH,
        county_centroids_path=CENTROIDS_PATH,
        seq_len=10
    )

    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Calculate class weights for focal loss: negatives / positives per label,
    # falling back to 1.0 for labels that never occur in the training split.
    print("⚖️ Calculating class weights for imbalanced data...")
    train_labels = datasets['train']['moral_targets'].numpy()
    num_positives = np.sum(train_labels, axis=0)
    num_negatives = len(train_labels) - num_positives
    has_positives = num_positives > 0
    # Divide by a zero-free denominator: np.where evaluates both branches, so
    # dividing by the raw counts would warn about division by zero.
    safe_positives = np.where(has_positives, num_positives, 1.0)
    pos_weight = np.where(has_positives, num_negatives / safe_positives, 1.0)
    pos_weight_tensor = torch.tensor(pos_weight, dtype=torch.float32).to(device)
    print(f"✅ Calculated pos_weight: {pos_weight_tensor.cpu().numpy().round(2)}")

    # Initialize models
    text_model = TextModelGATeMFD(EMFD_CSV_PATH)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim'])
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    )
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim'])
    fusion_model = HeterogeneousFusion()

    models = [text_model, spatial_model, temporal_model, behavioral_model, fusion_model]

    # Initialize trainer
    trainer = MultiStageTrainer(
        models=models,
        datasets=datasets,
        feature_info=feature_info,
        device=device,
        emfd_csv_path=EMFD_CSV_PATH,
        focal_loss_alpha=pos_weight_tensor,
        focal_loss_gamma=2.0
    )

    # Execute multi-stage training
    print("\n🚀 Starting Multi-Stage Training Pipeline...")
    print("=" * 80)

    trainer.train_stage1_text_gat_emfd(epochs=5, lr=2e-5)
    trainer.train_stage2_other_modalities(epochs=10, lr=5e-4)
    trainer.train_stage3_fusion_integration(epochs=2, lr=1e-3)
    trainer.train_stage4_end_to_end_finetuning(epochs=2, best_model_path=BEST_MODEL_PATH)

    # Final evaluation
    print("\n🔍 Starting Comprehensive Evaluation...")
    print("=" * 80)

    test_loader = create_dataloader(datasets['test'], batch_size=16, shuffle=False)

    # Evaluate individual models
    individual_results = evaluate_individual_models(test_loader, device, feature_info, EMFD_CSV_PATH)

    # Evaluate fusion model
    fusion_f1 = evaluate_fusion_model(BEST_MODEL_PATH, test_loader, device, feature_info, EMFD_CSV_PATH)

    # Print comprehensive results
    print_comprehensive_results(individual_results, fusion_f1)

    print("\n✅ Multi-stage training and comprehensive evaluation completed!")
    print("🏆 Best models saved as 'best_multimodal_gat_emfd.pth'")


🔄 Preparing dataset...
Using device: cuda
⚖️ Calculating class weights for imbalanced data...
✅ Calculated pos_weight: [ 0.32 85.    5.   10.22 50.6  10.73]
🔥 Using MultiLabelFocalLoss with gamma=2.0

🚀 Starting Multi-Stage Training Pipeline...



Stage 1 - Epoch 1/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT eMFD Text model epoch 1, loss: 0.6764


Stage 1 - Epoch 2/5:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 1 - Epoch 3/5:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 1 - Epoch 4/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT eMFD Text model epoch 4, loss: 0.6516


Stage 1 - Epoch 5/5:   0%|          | 0/17 [00:00<?, ?it/s]

✅ Stage 1 Complete. GAT eMFD Text model saved.


--- Training Spatial Model ---


Spatial - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 1, loss: 0.6802


Spatial - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 4, loss: 0.5838


Spatial - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 7, loss: 0.5502


Spatial - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 10, loss: 0.5353

--- Training Temporal Model ---


Temporal - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 1, loss: 0.7251


Temporal - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 4, loss: 0.5153


Temporal - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 7, loss: 0.4198


Temporal - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 10, loss: 0.3754

--- Training Behavioral Model ---


Behavioral - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 1, loss: 1.1067


Behavioral - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 4, loss: 0.4986


Behavioral - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 7, loss: 0.4779


Behavioral - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 10, loss: 0.4511
✅ Stage 2 Complete. All modality models saved.



Stage 3 - Epoch 1/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 3 - Epoch 1: Loss: 0.5599, Val F1: 0.1991


Stage 3 - Epoch 2/2:   0%|          | 0/17 [00:00<?, ?it/s]

✅ Stage 3 Complete. Fusion model saved.



Stage 4 - Epoch 1/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 4 - Epoch 1: Loss: 0.3059, Val F1: 0.2352
🏆 New best model saved with F1: 0.2352


Stage 4 - Epoch 2/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 4 - Epoch 2: Loss: 0.2597, Val F1: 0.2482
🏆 New best model saved with F1: 0.2482
✅ Stage 4 Complete. End-to-end training finished.

🔍 Starting Comprehensive Evaluation...



Testing GAT-eMFD Text:   0%|          | 0/4 [00:00<?, ?it/s]


--- GAT-eMFD Text Model Results ---
Macro F1-Score: 0.0926


Testing Spatial:   0%|          | 0/4 [00:00<?, ?it/s]


--- Spatial Model Results ---
Macro F1-Score: 0.0905


Testing Temporal:   0%|          | 0/4 [00:00<?, ?it/s]


--- Temporal Model Results ---
Macro F1-Score: 0.1549


Testing Behavioral:   0%|          | 0/4 [00:00<?, ?it/s]


--- Behavioral Model Results ---
Macro F1-Score: 0.1958



Testing Fusion Model:   0%|          | 0/4 [00:00<?, ?it/s]

              precision    recall  f1-score   support

        Care       1.00      0.16      0.27        44
    Fairness       0.00      0.00      0.00         0
     Loyalty       0.20      0.89      0.33         9
   Authority       0.09      0.75      0.17         4
      Purity       0.00      0.00      0.00         1
   Non_Moral       0.20      0.33      0.25         3

   micro avg       0.20      0.31      0.24        61
   macro avg       0.25      0.36      0.17        61
weighted avg       0.77      0.31      0.27        61
 samples avg       0.23      0.30      0.25        61

** FUSION MODEL Macro F1-Score: 0.1696 **

COMPREHENSIVE MULTIMODAL GAT-eMFD EVALUATION RESULTS

GAT-eMFD Text MODEL:
--------------------------------------------------
  Macro F1-Score: 0.0926
  Care F1: 0.0000
  Fairness F1: 0.0000
  Loyalty F1: 0.2812
  Authority F1: 0.1111
  Purity F1: 0.0500
  Non_Moral F1: 0.1132

Spatial MODEL:
--------------------------------------------------
  Macro F1-Scor

# Training variant with BCE loss (BCEWithLogitsLoss)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, Dropout, Sigmoid, ModuleDict, LayerNorm, LSTM, GELU, MultiheadAttention, Softmax
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, classification_report
import pandas as pd
import numpy as np
import json
import math
import os
from tqdm.autonotebook import tqdm

# PyTorch Geometric imports for GAT.
# The dependency is optional: when it cannot be imported, both names are bound
# to None and the code below tests them for truthiness before using them.
try:
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
except ImportError:
    print("PyTorch Geometric not found. Please install: pip install torch_geometric")
    GATConv = None
    Data = None

# Moral foundations.
# The order of this list fixes the column order used when per-foundation
# labels and predictions are stacked into a single tensor.
moral_foundations = ['Care', 'Fairness', 'Loyalty', 'Authority', 'Purity', 'Non_Moral']

# ===================================================================
# 1. eMFD Processing and Graph Construction
# ===================================================================
class eMFDProcessor:
    """Look up eMFD word probabilities for the tokens of a text.

    The word list is read once from a CSV file that must contain a ``word``
    column; the columns ``care_p``, ``fairness_p``, ``loyalty_p``,
    ``authority_p`` and ``purity_p`` are used when present and default to 0.0
    otherwise.
    """

    def __init__(self, emfd_csv_path):
        """Load the eMFD word list.

        Args:
            emfd_csv_path: Path to the eMFD CSV file.

        Raises:
            FileNotFoundError: If ``emfd_csv_path`` does not exist.
        """
        self.moral_foundations_emfd = ['care', 'fairness', 'loyalty', 'authority', 'purity', 'non-moral']
        if not os.path.exists(emfd_csv_path):
            raise FileNotFoundError(f"eMFD file not found at {emfd_csv_path}")
        df = pd.read_csv(emfd_csv_path)
        prob_cols = [f'{f}_p' for f in self.moral_foundations_emfd if f != 'non-moral']
        # word -> float32 vector of length 6: the five foundation
        # probabilities followed by a constant 0.0 in the 'non-moral' slot.
        self.emfd_data = {
            row['word']: np.append([row.get(col, 0.0) for col in prob_cols], 0.0).astype(np.float32)
            for _, row in df.iterrows()
        }

    def extract_moral_concepts(self, text):
        """Return the eMFD entries matched by the whitespace tokens of ``text``.

        Args:
            text: Input text. Anything that is not a string (for example the
                NaN pandas yields for a missing cell, or ``None``) is treated
                as an empty text instead of raising.

        Returns:
            A non-empty list of ``{'word', 'probabilities'}`` dicts in token
            order. When nothing matches, a single ``'neutral'`` entry with an
            all-zero probability vector is returned.
        """
        if not isinstance(text, str):
            text = ""
        concepts = [
            {'word': w, 'probabilities': self.emfd_data[w]}
            for w in text.lower().split() if w in self.emfd_data
        ]
        return concepts if concepts else [{'word': 'neutral', 'probabilities': np.zeros(6, dtype=np.float32)}]

class MoralGraphConstructor:
    """Turn a list of moral concepts into a fully connected graph."""

    def create_moral_graph(self, moral_concepts):
        """Build the graph for one text.

        Each concept becomes a node whose feature vector is its probability
        vector followed by 256 zeros. A single concept gets one self-loop;
        otherwise every ordered pair of distinct nodes is connected.

        Returns:
            A ``Data`` object, or ``None`` when ``Data`` is unavailable.
        """
        zero_tail = np.zeros(256)
        feature_rows = [
            np.concatenate([concept['probabilities'], zero_tail])
            for concept in moral_concepts
        ]
        node_features = torch.FloatTensor(np.array(feature_rows))

        node_count = len(moral_concepts)
        if node_count == 1:
            # Lone node: a single self-loop, already laid out as (2, 1).
            edge_index = torch.tensor([[0], [0]], dtype=torch.long)
        else:
            # (source, target) pairs, transposed to the (2, num_edges) layout.
            pairs = [
                [src, dst]
                for src in range(node_count)
                for dst in range(node_count)
                if src != dst
            ]
            edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()

        if not Data:
            return None
        return Data(x=node_features, edge_index=edge_index)

# ===================================================================
# 2. GAT eMFD Module
# ===================================================================
class GATeMFDModule(nn.Module):
    """Encode one moral-concept graph into a single pooled feature vector.

    The zero-filled tail of every node feature is replaced by a learned
    embedding, a single GAT layer is applied and the node outputs are combined
    by attention pooling.
    """

    def __init__(self, input_dim=262, output_dim=256, num_heads=4, dropout=0.1):
        super(GATeMFDModule, self).__init__()
        # Looked up by node position inside the graph (see forward), so graphs
        # with more than 1000 nodes are not supported.
        self.concept_embeddings = nn.Embedding(1000, 256)
        self.gat_layer = GATConv(input_dim, output_dim, heads=num_heads, dropout=dropout, concat=False) if GATConv else nn.Identity()
        self.attention_pooling = nn.Linear(output_dim, 1)

    def forward(self, graph_data, device):
        """Return the pooled graph representation.

        Args:
            graph_data: Graph object exposing ``x`` (node features) and
                ``edge_index``.
            device: Device the computation runs on.

        Returns:
            A tensor with one row: the attention-weighted sum of the node
            outputs, or a ``(1, 256)`` zero tensor when ``GATConv`` is
            unavailable.
        """
        if not GATConv:
            return torch.zeros((1, 256)).to(device)
        x, edge_index = graph_data.x.to(device), graph_data.edge_index.to(device)
        node_ids = torch.arange(x.size(0), device=device)
        # Build a new tensor instead of assigning into x[:, 6:]: when the data
        # already lives on `device`, .to() hands back the caller's own tensor
        # and an in-place write would modify graph_data.x.
        x = torch.cat([x[:, :6], self.concept_embeddings(node_ids)], dim=1)
        gat_output = F.elu(self.gat_layer(x, edge_index))
        # Softmax over the node axis gives each node's weight in the pooled sum.
        attn_weights = F.softmax(self.attention_pooling(gat_output), dim=0)
        return torch.sum(attn_weights * gat_output, dim=0, keepdim=True)

class TextModelGATeMFD(nn.Module):
    """Text branch: eMFD lookup -> concept graph -> GAT encoder -> per-foundation heads."""

    def __init__(self, emfd_csv_path):
        """Create the text model.

        Args:
            emfd_csv_path: Path to the eMFD word list handed to ``eMFDProcessor``.
        """
        super().__init__()
        self.emfd_processor = eMFDProcessor(emfd_csv_path)
        self.graph_constructor = MoralGraphConstructor()
        self.gat_module = GATeMFDModule()

        # Additional processing layers
        self.processor = Sequential(
            Linear(256, 512), LayerNorm(512), GELU(), Dropout(0.1),
            Linear(512, 256), LayerNorm(256), GELU(), Dropout(0.1)
        )

        # Prediction heads emit raw logits (no Sigmoid), as expected by
        # BCEWithLogitsLoss.
        self.heads = ModuleDict({
            foundation: Sequential(Linear(256, 128), ReLU(), Dropout(0.2),
                                 Linear(128, 1))
            for foundation in moral_foundations
        })

    def _encode_texts(self, texts):
        """Return one pooled GAT feature row per text, concatenated along dim 0."""
        device = next(self.parameters()).device
        batch_features = []

        for text in texts:
            moral_concepts = self.emfd_processor.extract_moral_concepts(text)
            graph_data = self.graph_constructor.create_moral_graph(moral_concepts)
            # Explicit None check: the constructor returns None when no graph
            # could be built, and the decision should not depend on how the
            # graph object defines truthiness.
            if graph_data is not None:
                gat_features = self.gat_module(graph_data, device)
            else:
                gat_features = torch.zeros((1, 256)).to(device)
            batch_features.append(gat_features)

        return torch.cat(batch_features, dim=0)

    def forward(self, texts):
        """Return a dict mapping each moral foundation to its logits.

        Args:
            texts: Iterable of raw text strings, one per sample.
        """
        processed_features = self.get_features(texts)
        return {f: self.heads[f](processed_features) for f in moral_foundations}

    def get_features(self, texts):
        """Return the processed text features that feed the heads and the fusion model.

        Args:
            texts: Iterable of raw text strings, one per sample.
        """
        return self.processor(self._encode_texts(texts))

# ===================================================================
# 3. Other Modality Models
# ===================================================================
class SpatialModel(nn.Module):
    """MLP over spatial features with one logit head per moral foundation."""

    def __init__(self, input_dim):
        super().__init__()
        self.encoder = Sequential(
            Linear(input_dim, 128),
            ReLU(),
            Dropout(0.2),
            Linear(128, 256),
        )
        # Heads end in a plain Linear layer: the loss applies the sigmoid.
        head_modules = {}
        for name in moral_foundations:
            head_modules[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
        self.heads = ModuleDict(head_modules)

    def get_features(self, x):
        """Return the 256-dimensional encoding of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return a dict mapping each moral foundation to its logits."""
        encoded = self.get_features(x)
        logits = {}
        for name in moral_foundations:
            logits[name] = self.heads[name](encoded)
        return logits

class TemporalModelLSTM(nn.Module):
    """Bidirectional LSTM over temporal sequences with one logit head per foundation.

    ``seq_len`` is accepted for interface compatibility but is not used by
    this class.
    """

    def __init__(self, input_dim, seq_len=10, hidden_size=128):
        super().__init__()
        self.projection = Linear(input_dim, hidden_size)
        self.lstm = LSTM(
            hidden_size,
            hidden_size,
            2,
            batch_first=True,
            dropout=0.2,
            bidirectional=True,
        )
        # The two directions are concatenated, hence hidden_size * 2 inputs.
        self.encoder = Sequential(
            Linear(hidden_size * 2, 256),
            LayerNorm(256),
            ReLU(),
            Dropout(0.2),
        )
        # Heads end in a plain Linear layer: the loss applies the sigmoid.
        head_modules = {}
        for name in moral_foundations:
            head_modules[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
        self.heads = ModuleDict(head_modules)

    def get_features(self, x):
        """Return the encoded LSTM output taken at the last timestep."""
        projected = self.projection(x)
        outputs, _ = self.lstm(projected)
        last_step = outputs[:, -1]
        return self.encoder(last_step)

    def forward(self, x):
        """Return a dict mapping each moral foundation to its logits."""
        encoded = self.get_features(x)
        logits = {}
        for name in moral_foundations:
            logits[name] = self.heads[name](encoded)
        return logits

class BehavioralModel(nn.Module):
    """Normalised MLP over behavioral features with one logit head per foundation."""

    def __init__(self, input_dim):
        super().__init__()
        self.encoder = Sequential(
            Linear(input_dim, 128),
            LayerNorm(128),
            ReLU(),
            Dropout(0.1),
            Linear(128, 256),
            LayerNorm(256),
            ReLU(),
            Dropout(0.1),
        )
        # Heads end in a plain Linear layer: the loss applies the sigmoid.
        head_modules = {}
        for name in moral_foundations:
            head_modules[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
        self.heads = ModuleDict(head_modules)

    def get_features(self, x):
        """Return the 256-dimensional encoding of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return a dict mapping each moral foundation to its logits."""
        encoded = self.get_features(x)
        logits = {}
        for name in moral_foundations:
            logits[name] = self.heads[name](encoded)
        return logits

# ===================================================================
# 4. Heterogeneous Fusion
# ===================================================================
class HeterogeneousFusion(nn.Module):
    """Fuse text, spatial, temporal and behavioral features with attention and gating.

    Each modality is projected to ``d_model``, passed through its own
    cross-attention block, mixed with the other modalities by self-attention
    and finally combined by a learned softmax gate.
    """

    def __init__(self, input_dims=None, d_model=256, num_heads=8):
        """Create the fusion network.

        Args:
            input_dims: Optional iterable with the feature size of each
                modality, in the order text, spatial, temporal, behavioral.
                When omitted, four 256-dimensional inputs are assumed.
            d_model: Shared width used by the attention layers and heads.
            num_heads: Number of heads of every attention layer.
        """
        super().__init__()
        self.d_model = d_model

        # Feature projection layers to align dimensions
        if input_dims:
            self.projections = nn.ModuleList([
                nn.Linear(dim, d_model) for dim in input_dims
            ])
        else:
            self.projections = nn.ModuleList([
                nn.Linear(256, d_model) for _ in range(4)
            ])

        # One learned query vector per modality.
        self.modality_queries = nn.Parameter(torch.randn(4, d_model))

        self.cross_attention = nn.ModuleList([
            MultiheadAttention(d_model, num_heads, batch_first=True)
            for _ in range(4)
        ])

        self.self_attention = MultiheadAttention(d_model, num_heads, batch_first=True)
        # Four norms for the cross-attention residuals, one for self-attention.
        self.layer_norms = nn.ModuleList([LayerNorm(d_model) for _ in range(5)])

        # Gate over the four modalities; Softmax makes the weights sum to 1.
        self.modality_gate = nn.Sequential(
            Linear(d_model * 4, d_model), ReLU(), Dropout(0.1),
            Linear(d_model, 4), Softmax(dim=-1)
        )

        self.final_projection = Linear(d_model, d_model)

        # Heads emit raw logits (no Sigmoid), as expected by BCEWithLogitsLoss.
        self.heads = ModuleDict({
            f: Sequential(Linear(d_model, 128), ReLU(), Dropout(0.2),
                         Linear(128, 1))
            for f in moral_foundations
        })

    def forward(self, text_feat, spatial_feat, temporal_feat, behavioral_feat):
        """Return a dict mapping each moral foundation to its logits.

        Every argument holds one feature row per sample for its modality; the
        batch size is read from ``text_feat``.
        """
        batch_size = text_feat.size(0)

        # Project all features to same dimension
        features = [
            self.projections[i](feat)
            for i, feat in enumerate([text_feat, spatial_feat, temporal_feat, behavioral_feat])
        ]

        # Add sequence dimension for attention
        modality_feats = [f.unsqueeze(1) for f in features]

        # Cross-attention with residual connections.
        # NOTE(review): each modality supplies a single key/value token here,
        # so the attention weights over it are always 1 and the learned query
        # cannot change the result - confirm this is intended.
        attended = []
        for i, (feat, attn, norm) in enumerate(zip(modality_feats, self.cross_attention, self.layer_norms[:4])):
            query = self.modality_queries[i:i+1].unsqueeze(0).expand(batch_size, -1, -1)
            att_out, _ = attn(query, feat, feat)
            attended.append(norm(att_out.squeeze(1) + features[i]))

        # Stack for self-attention: one token per modality.
        stacked = torch.stack(attended, dim=1)

        # Self-attention
        self_att, _ = self.self_attention(stacked, stacked, stacked)
        fused = self.layer_norms[4](stacked + self_att)

        # Gating: the weights are computed from the concatenation of all four
        # fused modality vectors.
        gates = self.modality_gate(fused.view(batch_size, -1))

        # Apply gates to individual modality representations
        gated = torch.sum(fused * gates.unsqueeze(-1), dim=1)

        # Final projection
        final_features = self.final_projection(gated)

        return {f: self.heads[f](final_features) for f in moral_foundations}

# ===================================================================
# 5. Data Processing Functions
# ===================================================================
def prepare_labels(df):
    """Build the multi-label target tensor for the rows of ``df``.

    Every foundation is positive when its virtue column or its vice column
    equals 1; ``Non_Moral`` is positive when none of the five foundations is.
    Columns follow the order of ``moral_foundations``.
    """
    frame = df.reset_index(drop=True)

    def is_set(column):
        return frame[column] == 1

    labels = {'Care': is_set('Care') | is_set('Harm')}

    # The vice column of Fairness may be spelled 'Cheating' or 'cheating';
    # when neither exists an all-False column stands in for it.
    absent = pd.Series([False]*len(frame), index=frame.index)
    cheating = frame.get('Cheating', frame.get('cheating', absent))
    labels['Fairness'] = is_set('Fairness') | (cheating == 1)

    for virtue, vice in (('Loyalty', 'Betrayal'),
                         ('Authority', 'Subversion'),
                         ('Purity', 'Degradation')):
        labels[virtue] = is_set(virtue) | is_set(vice)

    # Non_Moral = 1 exactly when no moral foundation is flagged for the row.
    moral_flags = pd.DataFrame(
        {name: labels[name] for name in ('Care', 'Fairness', 'Loyalty', 'Authority', 'Purity')}
    )
    labels['Non_Moral'] = (~moral_flags.any(axis=1)).astype(int)

    # One integer column per foundation, stacked side by side.
    columns = [labels[name].astype(int).values for name in moral_foundations]
    return torch.tensor(np.column_stack(columns), dtype=torch.float32)

def create_temporal_sequences(df, seq_len=10):
    """Slice every county's rows into overlapping windows of ``seq_len`` rows.

    Rows are ordered by GEOID and date; counties with fewer than ``seq_len``
    rows contribute nothing. Labels, text, GEOID and behavioral features are
    taken from the last row of each window.

    Returns:
        A ``(sequences, targets, other_features)`` tuple of equally long lists.
    """
    temporal_columns = [
        'month', 'day', 'year', 'after_05/26/2020',
        'female_pct', 'lesscollege_pct', 'cvap',
        'clf_unemploy_pct', 'lesshs_pct', 'median_hh_inc',
        'urm_pct', 'ruralurban_cc', 'demgov', 'repgov',
        'net_dem_president_votes', 'net_dem_gov_votes', 'is_blue',
        'deaths', 'cases', 'cases_per_capita',
        'deaths_per_capita', 'cases_per_capita_discrete', 'mask_score',
    ]
    behavioral_columns = ['retweet_count', 'rt_discrete', 'for_sah', 'is_vivid', 'sentiment_score']

    ordered = df.sort_values(['GEOID', 'year', 'month', 'day'])
    sequences, targets, other_features = [], [], []

    for _, county_rows in ordered.groupby('GEOID'):
        row_count = len(county_rows)
        if row_count < seq_len:
            continue

        for start in range(row_count - seq_len + 1):
            window = county_rows.iloc[start:start + seq_len]

            # Temporal features: one list of values per row of the window.
            sequences.append([
                [row[column] for column in temporal_columns]
                for _, row in window.iterrows()
            ])

            # Labels come from the final timestep only.
            final_row = window.iloc[-1]
            targets.append(prepare_labels(pd.DataFrame([final_row]))[0])

            other_features.append({
                'text': final_row['text'],
                # GEOID rendered as a string, left-padded with zeros to 5 characters.
                'geoid': str(final_row['GEOID']).zfill(5),
                'behavioral': [final_row[column] for column in behavioral_columns]
            })

    return sequences, targets, other_features

def prepare_dataset(csv_path, geojson_path, county_centroids_path, seq_len=10):
    """Load the raw files and build the train/validation/test tensors.

    Args:
        csv_path: CSV file with the tweet-level rows.
        geojson_path: JSON file with county boundaries.
        county_centroids_path: JSON file with county centroids.
        seq_len: Number of rows in every temporal window.

    Returns:
        A ``(datasets, feature_info)`` tuple. ``datasets`` maps ``'train'``,
        ``'validation'`` and ``'test'`` to dicts with ``text_data``,
        ``spatial_features``, ``temporal_sequences``, ``behavioral_features``
        and ``moral_targets``; ``feature_info`` holds the feature dimensions
        and the sequence length.

    Raises:
        ValueError: If no county has at least ``seq_len`` rows, so that no
            sample can be built.
    """
    # Load data
    df = pd.read_csv(csv_path)
    with open(geojson_path) as f:
        boundaries = json.load(f)
    with open(county_centroids_path) as f:
        centroids = json.load(f)

    # Create sequences
    sequences, targets, other_features = create_temporal_sequences(df, seq_len)
    if not sequences:
        raise ValueError(
            f"No temporal sequences could be built: no GEOID has at least {seq_len} rows"
        )

    # Process features
    data = {'text': [], 'spatial': [], 'temporal': [], 'behavioral': [], 'targets': []}

    for seq, target, other in zip(sequences, targets, other_features):
        data['text'].append(other['text'])
        data['temporal'].append(seq)
        data['behavioral'].append(other['behavioral'])
        data['targets'].append(target)

        # Extract spatial features.
        # NOTE(review): assumes both JSON files are dicts keyed by the 5-digit
        # GEOID and that a centroid is [lat, lon] - confirm against the files.
        geoid = other['geoid']
        if geoid in boundaries and geoid in centroids:
            # Simplified spatial feature extraction
            centroid = centroids[geoid]
            lat_norm = (centroid[0] - 24) / (49 - 24)  # Normalize lat from [24, 49]
            # Normalize lon from [-125, -66]. The divisor is the positive span
            # 125 - 66 = 59; the previous (66 - 125) flipped the sign.
            lon_norm = (centroid[1] + 125) / (125 - 66)
            spatial_feat = [lat_norm, lon_norm, 0.5, 0.5, 0.5, 0.5]  # Simplified features
        else:
            spatial_feat = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]  # Default

        data['spatial'].append(spatial_feat)

    # Convert to tensors
    spatial_tensor = torch.tensor(data['spatial'], dtype=torch.float32)
    temporal_tensor = torch.tensor(data['temporal'], dtype=torch.float32)
    behavioral_tensor = torch.tensor(data['behavioral'], dtype=torch.float32)
    targets_tensor = torch.stack(data['targets'])

    # Create splits first, so the scalers below can be fitted on training
    # rows only and no validation/test statistics leak into normalization.
    indices = list(range(len(data['text'])))
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

    # Normalize features: fit on the training rows, transform every row.
    scaler_spatial = StandardScaler()
    scaler_behavioral = StandardScaler()
    scaler_temporal = StandardScaler()

    scaler_spatial.fit(spatial_tensor[train_idx].numpy())
    spatial_tensor = torch.tensor(
        scaler_spatial.transform(spatial_tensor.numpy()), dtype=torch.float32
    )

    scaler_behavioral.fit(behavioral_tensor[train_idx].numpy())
    behavioral_tensor = torch.tensor(
        scaler_behavioral.transform(behavioral_tensor.numpy()), dtype=torch.float32
    )

    # Normalize temporal: flatten (sample, step) into rows so every feature
    # column is scaled across all timesteps, then restore the 3-D shape.
    temp_shape = tuple(temporal_tensor.shape)
    num_temporal_features = temp_shape[2]
    scaler_temporal.fit(
        temporal_tensor[train_idx].reshape(-1, num_temporal_features).numpy()
    )
    temp_scaled = scaler_temporal.transform(
        temporal_tensor.reshape(-1, num_temporal_features).numpy()
    )
    temporal_tensor = torch.tensor(temp_scaled.reshape(temp_shape), dtype=torch.float32)

    datasets = {}
    for split, idx in [('train', train_idx), ('validation', val_idx), ('test', test_idx)]:
        datasets[split] = {
            'text_data': [data['text'][i] for i in idx],
            'spatial_features': spatial_tensor[idx],
            'temporal_sequences': temporal_tensor[idx],
            'behavioral_features': behavioral_tensor[idx],
            'moral_targets': targets_tensor[idx]
        }

    feature_info = {
        'spatial_feature_dim': spatial_tensor.shape[1],
        'temporal_feature_dim': temporal_tensor.shape[2],
        'temporal_sequence_length': seq_len,
        'behavioral_feature_dim': behavioral_tensor.shape[1]
    }

    return datasets, feature_info

def create_dataloader(dataset_dict, batch_size=16, shuffle=True):
    """Wrap one dataset split in a ``DataLoader`` that yields batch dicts.

    Each batch carries the texts as a plain list and the remaining fields as
    stacked tensors; the split's ``moral_targets`` appear under ``targets``.
    """
    source_keys = (
        'text_data',
        'spatial_features',
        'temporal_sequences',
        'behavioral_features',
        'moral_targets',
    )
    # Batch key -> position of the field inside each sample tuple.
    tensor_layout = (
        ('spatial_features', 1),
        ('temporal_sequences', 2),
        ('behavioral_features', 3),
        ('targets', 4),
    )

    def collate_fn(batch):
        collated = {'text_data': [sample[0] for sample in batch]}
        for key, position in tensor_layout:
            collated[key] = torch.stack([sample[position] for sample in batch])
        return collated

    # One tuple per sample; the number of texts defines the dataset size.
    sample_count = len(dataset_dict['text_data'])
    samples = [
        tuple(dataset_dict[key][row] for key in source_keys)
        for row in range(sample_count)
    ]

    return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

# ===================================================================
# 6. Multi-Stage Training System
# ===================================================================
class MultiStageTrainer:
    """Run the four-stage training schedule for the multimodal model.

    Stage 1 trains the text model, stage 2 trains the spatial, temporal and
    behavioral models one after another, stage 3 trains only the fusion model
    on features from the frozen modality models, and stage 4 fine-tunes the
    text and fusion models jointly with early stopping on validation macro F1.
    """

    def __init__(self, models, datasets, feature_info, device, emfd_csv_path):
        """Store the models, move them to ``device`` and build the loaders.

        Args:
            models: Sequence ``(text, spatial, temporal, behavioral, fusion)``.
            datasets: Mapping with ``'train'`` and ``'validation'`` entries in
                the format accepted by ``create_dataloader``.
            feature_info: Accepted for interface compatibility; not used here.
            device: Device every model and tensor batch is moved to.
            emfd_csv_path: Accepted for interface compatibility; not used here.
        """
        (self.text_model, self.spatial_model, self.temporal_model,
         self.behavioral_model, self.fusion_model) = models
        self.datasets = datasets
        self.device = device

        # Move models to device
        self.text_model = self.text_model.to(device)
        self.spatial_model = self.spatial_model.to(device)
        self.temporal_model = self.temporal_model.to(device)
        self.behavioral_model = self.behavioral_model.to(device)
        self.fusion_model = self.fusion_model.to(device)

        # Create data loaders
        self.train_loader = create_dataloader(datasets['train'], batch_size=16, shuffle=True)
        self.val_loader = create_dataloader(datasets['validation'], batch_size=16, shuffle=False)

        # The models emit raw logits, so the sigmoid lives inside the loss.
        self.criterion = nn.BCEWithLogitsLoss()

    def _freeze(self, model):
        """Disable gradients for every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad_(False)

    def _unfreeze(self, model):
        """Enable gradients for every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad_(True)

    @staticmethod
    def _set_eval(*models):
        """Put each given model into evaluation mode."""
        for model in models:
            model.eval()

    def _to_device(self, batch):
        """Return a copy of ``batch`` with tensor values moved to the device.

        Non-tensor values (the list of raw texts) are passed through unchanged.
        """
        return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}

    def _frozen_modality_features(self, batch):
        """Extract spatial, temporal and behavioral features without gradients."""
        with torch.no_grad():
            spatial_feat = self.spatial_model.get_features(batch['spatial_features'])
            temporal_feat = self.temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = self.behavioral_model.get_features(batch['behavioral_features'])
        return spatial_feat, temporal_feat, behavioral_feat

    def compute_loss(self, predictions, targets):
        """Compute loss for moral foundation predictions.

        The per-foundation outputs in ``predictions`` are concatenated along
        dim 1 in ``moral_foundations`` order before the loss is applied.
        """
        pred_stack = torch.cat([predictions[f] for f in moral_foundations], dim=1)
        return self.criterion(pred_stack, targets)

    def train_stage1_text_gat_emfd(self, epochs=15, lr=2e-5):
        """Stage 1: train the GAT eMFD text model and save its weights.

        Args:
            epochs: Number of passes over the training loader.
            lr: AdamW learning rate for the text model.
        """
        print("\n" + "="*20 + " Stage 1: Training GAT eMFD Text Model " + "="*20)

        # Freeze other models
        self._freeze(self.spatial_model)
        self._freeze(self.temporal_model)
        self._freeze(self.behavioral_model)
        self._freeze(self.fusion_model)
        self._unfreeze(self.text_model)

        optimizer = optim.AdamW(self.text_model.parameters(), lr=lr, weight_decay=1e-5)

        for epoch in range(epochs):
            self.text_model.train()
            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"Stage 1 - Epoch {epoch+1}/{epochs}"):
                batch = self._to_device(batch)

                optimizer.zero_grad()
                preds = self.text_model(batch['text_data'])
                loss = self.compute_loss(preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if epoch % 3 == 0:
                print(f"GAT eMFD Text model epoch {epoch+1}, loss: {total_loss/len(self.train_loader):.4f}")

        # Save model
        torch.save(self.text_model.state_dict(), "gat_emfd_text_stage1.pth")
        print("✅ Stage 1 Complete. GAT eMFD Text model saved.")

    def train_stage2_other_modalities(self, epochs=10, lr=5e-4):
        """Stage 2: train the spatial, temporal and behavioral models in turn.

        Each model is unfrozen, trained on its own input, saved to
        ``<name>_model_stage2.pth`` and frozen again.

        Args:
            epochs: Number of passes over the training loader per model.
            lr: AdamW learning rate used for each model.
        """
        print("\n" + "="*20 + " Stage 2: Training Other Modality Models " + "="*20)

        # Freeze text and fusion models
        self._freeze(self.text_model)
        self._freeze(self.fusion_model)

        # (model, name used in logs and file names, batch key fed to the model)
        models_to_train = [
            (self.spatial_model, 'spatial', 'spatial_features'),
            (self.temporal_model, 'temporal', 'temporal_sequences'),
            (self.behavioral_model, 'behavioral', 'behavioral_features'),
        ]

        for model, name, input_key in models_to_train:
            print(f"\n--- Training {name.capitalize()} Model ---")
            self._unfreeze(model)
            optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

            for epoch in range(epochs):
                model.train()
                total_loss = 0

                for batch in tqdm(self.train_loader, desc=f"{name.capitalize()} - Epoch {epoch+1}/{epochs}"):
                    batch = self._to_device(batch)

                    optimizer.zero_grad()
                    preds = model(batch[input_key])
                    loss = self.compute_loss(preds, batch['targets'])
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()

                if epoch % 3 == 0:
                    print(f"{name.capitalize()} model epoch {epoch+1}, loss: {total_loss/len(self.train_loader):.4f}")

            # Save model
            torch.save(model.state_dict(), f"{name}_model_stage2.pth")
            self._freeze(model)

        print("✅ Stage 2 Complete. All modality models saved.")

    def train_stage3_fusion_integration(self, epochs=8, lr=1e-3):
        """Stage 3: train only the fusion model on frozen modality features.

        Args:
            epochs: Number of passes over the training loader.
            lr: AdamW learning rate for the fusion model.
        """
        print("\n" + "="*15 + " Stage 3: Fusion Integration " + "="*15)

        # Freeze all individual models, unfreeze fusion
        self._freeze(self.text_model)
        self._freeze(self.spatial_model)
        self._freeze(self.temporal_model)
        self._freeze(self.behavioral_model)
        self._unfreeze(self.fusion_model)

        optimizer = optim.AdamW(self.fusion_model.parameters(), lr=lr, weight_decay=1e-4)

        for epoch in range(epochs):
            self.fusion_model.train()
            # The frozen feature extractors must run in eval mode; earlier
            # stages leave them in train mode, which would keep dropout active
            # and feed noisy features to the fusion model.
            self._set_eval(self.text_model, self.spatial_model,
                           self.temporal_model, self.behavioral_model)
            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"Stage 3 - Epoch {epoch+1}/{epochs}"):
                batch = self._to_device(batch)

                optimizer.zero_grad()

                # Extract features with no gradients for individual models
                with torch.no_grad():
                    text_feat = self.text_model.get_features(batch['text_data'])
                spatial_feat, temporal_feat, behavioral_feat = self._frozen_modality_features(batch)

                # Fusion with gradients
                fusion_preds = self.fusion_model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
                loss = self.compute_loss(fusion_preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if epoch % 2 == 0:
                val_f1 = self._validate_fusion()
                print(f"Stage 3 - Epoch {epoch+1}: Loss: {total_loss/len(self.train_loader):.4f}, Val F1: {val_f1:.4f}")

        torch.save(self.fusion_model.state_dict(), "fusion_model_stage3.pth")
        print("✅ Stage 3 Complete. Fusion model saved.")

    def train_stage4_end_to_end_finetuning(self, epochs=5, best_model_path="best_fusion_model.pth", patience=2):
        """Stage 4: fine-tune the text and fusion models jointly.

        After every epoch the validation macro F1 is computed; the state dicts
        of all five models are written to ``best_model_path`` whenever it
        improves.

        Args:
            epochs: Maximum number of passes over the training loader.
            best_model_path: Destination of the best checkpoint.
            patience: Number of consecutive non-improving epochs that triggers
                early stopping.
        """
        print("\n" + "="*15 + " Stage 4: End-to-End Fine-tuning " + "="*15)

        # Unfreeze GAT eMFD and fusion, keep others frozen
        self._unfreeze(self.text_model)
        self._freeze(self.spatial_model)
        self._freeze(self.temporal_model)
        self._freeze(self.behavioral_model)
        self._unfreeze(self.fusion_model)

        optimizer = optim.AdamW([
            {'params': self.text_model.parameters(), 'lr': 5e-6},
            {'params': self.fusion_model.parameters(), 'lr': 1e-4}
        ], weight_decay=1e-4)

        # Start below any reachable F1 so the first epoch always writes a
        # checkpoint, even when the validation F1 is exactly 0.
        best_val_f1 = float('-inf')
        patience_counter = 0

        for epoch in range(epochs):
            # Validation switches everything to eval, so modes are reset here.
            self.text_model.train()
            self.fusion_model.train()
            self._set_eval(self.spatial_model, self.temporal_model, self.behavioral_model)

            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"Stage 4 - Epoch {epoch+1}/{epochs}"):
                batch = self._to_device(batch)

                optimizer.zero_grad()

                # Extract features (GAT eMFD with gradients, others frozen)
                text_feat = self.text_model.get_features(batch['text_data'])
                spatial_feat, temporal_feat, behavioral_feat = self._frozen_modality_features(batch)

                # Fusion
                fusion_preds = self.fusion_model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
                loss = self.compute_loss(fusion_preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            val_f1 = self._validate_fusion()
            print(f"Stage 4 - Epoch {epoch+1}: Loss: {total_loss/len(self.train_loader):.4f}, Val F1: {val_f1:.4f}")

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                patience_counter = 0
                torch.save({
                    'text_model': self.text_model.state_dict(),
                    'spatial_model': self.spatial_model.state_dict(),
                    'temporal_model': self.temporal_model.state_dict(),
                    'behavioral_model': self.behavioral_model.state_dict(),
                    'fusion_model': self.fusion_model.state_dict()
                }, best_model_path)
                print(f"🏆 New best model saved with F1: {val_f1:.4f}")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print("🛑 Early stopping triggered.")
                    break

        print("✅ Stage 4 Complete. End-to-end training finished.")

    def _validate_fusion(self):
        """Return the macro F1 of the fusion pipeline on the validation loader.

        All five models are switched to eval mode and left there; predictions
        are the sigmoid of the fused logits thresholded at 0.5.
        """
        self._set_eval(self.text_model, self.spatial_model, self.temporal_model,
                       self.behavioral_model, self.fusion_model)

        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in self.val_loader:
                batch = self._to_device(batch)

                text_feat = self.text_model.get_features(batch['text_data'])
                spatial_feat, temporal_feat, behavioral_feat = self._frozen_modality_features(batch)

                fusion_preds = self.fusion_model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
                pred_stack = torch.cat([fusion_preds[f] for f in moral_foundations], dim=1)

                all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
                all_labels.append(batch['targets'].cpu().numpy())

        y_pred = np.vstack(all_preds)
        y_true = np.vstack(all_labels)
        return f1_score(y_true, y_pred, average='macro', zero_division=0)

# ===================================================================
# 7. Evaluation Functions
# ===================================================================
def evaluate_gat_emfd_only(test_loader, device, emfd_csv_path):
    """Evaluate the standalone GAT eMFD text model on ``test_loader``.

    The weights are read from the stage-1 checkpoint
    ``gat_emfd_text_stage1.pth``. Predictions are the sigmoid of the logits
    thresholded at 0.5.

    Args:
        test_loader: Loader yielding dict batches with ``'text_data'`` and
            ``'targets'``.
        device: Device the model and tensor batches are placed on.
        emfd_csv_path: Path passed to ``TextModelGATeMFD``.

    Returns:
        The macro-averaged F1 score over all moral foundations.
    """
    print("\n" + "="*20 + " ABLATION: GAT eMFD TEXT MODEL ONLY " + "="*20)

    text_model = TextModelGATeMFD(emfd_csv_path).to(device)
    # map_location lets a checkpoint written on a GPU be restored on a
    # CPU-only host instead of failing during deserialisation.
    text_model.load_state_dict(torch.load("gat_emfd_text_stage1.pth", map_location=device))
    text_model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing GAT eMFD"):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            preds = text_model(batch['text_data'])
            pred_stack = torch.cat([preds[f] for f in moral_foundations], dim=1)

            all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
            all_labels.append(batch['targets'].cpu().numpy())

    y_pred = np.vstack(all_preds)
    y_true = np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print(classification_report(y_true, y_pred, target_names=moral_foundations, zero_division=0))
    print(f"** GAT eMFD Macro F1-Score: {macro_f1:.4f} **")
    return macro_f1

def evaluate_individual_models(test_loader, device, feature_info, emfd_csv_path):
    """Evaluate each single-modality model on ``test_loader``.

    The text model is restored from the stage-1 checkpoint and the spatial,
    temporal and behavioral models from their stage-2 checkpoints.
    Predictions are the sigmoid of the logits thresholded at 0.5.

    Args:
        test_loader: Loader yielding dict batches as built by
            ``create_dataloader``.
        device: Device the models and tensor batches are placed on.
        feature_info: Mapping providing ``'spatial_feature_dim'``,
            ``'temporal_feature_dim'``, ``'temporal_sequence_length'`` and
            ``'behavioral_feature_dim'``.
        emfd_csv_path: Path passed to ``TextModelGATeMFD``.

    Returns:
        Dict mapping each model's display name to
        ``{'macro_f1': float, 'foundation_f1s': list}``, where the list
        follows ``moral_foundations`` order.
    """
    print("\n" + "="*20 + " INDIVIDUAL MODEL EVALUATION " + "="*20)

    # Load models
    text_model = TextModelGATeMFD(emfd_csv_path).to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim']).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim']).to(device)

    # Load weights; map_location lets GPU-written checkpoints be restored on
    # a CPU-only host.
    text_model.load_state_dict(torch.load("gat_emfd_text_stage1.pth", map_location=device))
    spatial_model.load_state_dict(torch.load("spatial_model_stage2.pth", map_location=device))
    temporal_model.load_state_dict(torch.load("temporal_model_stage2.pth", map_location=device))
    behavioral_model.load_state_dict(torch.load("behavioral_model_stage2.pth", map_location=device))

    # (model, display name, batch key fed to the model)
    models = [
        (text_model, 'GAT-eMFD Text', 'text_data'),
        (spatial_model, 'Spatial', 'spatial_features'),
        (temporal_model, 'Temporal', 'temporal_sequences'),
        (behavioral_model, 'Behavioral', 'behavioral_features'),
    ]

    results = {}

    for model, name, input_key in models:
        model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc=f"Testing {name}"):
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batch.items()}

                preds = model(batch[input_key])

                pred_stack = torch.cat([preds[f] for f in moral_foundations], dim=1)
                all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
                all_labels.append(batch['targets'].cpu().numpy())

        y_pred = np.vstack(all_preds)
        y_true = np.vstack(all_labels)
        macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

        print(f"\n--- {name} Model Results ---")
        print(f"Macro F1-Score: {macro_f1:.4f}")

        # Foundation-specific results, one binary F1 per label column
        foundation_f1s = [
            f1_score(y_true[:, i], y_pred[:, i], zero_division=0)
            for i in range(len(moral_foundations))
        ]

        results[name] = {
            'macro_f1': macro_f1,
            'foundation_f1s': foundation_f1s
        }

    return results

def evaluate_fusion_model(model_path, test_loader, device, feature_info, emfd_csv_path):
    """Evaluate the full fusion pipeline restored from ``model_path``.

    The checkpoint is expected to hold the state dicts under the keys
    ``'text_model'``, ``'spatial_model'``, ``'temporal_model'``,
    ``'behavioral_model'`` and ``'fusion_model'``. Predictions are the
    sigmoid of the fused logits thresholded at 0.5.

    Args:
        model_path: Path of the combined checkpoint.
        test_loader: Loader yielding dict batches as built by
            ``create_dataloader``.
        device: Device the models and tensor batches are placed on.
        feature_info: Mapping providing ``'spatial_feature_dim'``,
            ``'temporal_feature_dim'``, ``'temporal_sequence_length'`` and
            ``'behavioral_feature_dim'``.
        emfd_csv_path: Path passed to ``TextModelGATeMFD``.

    Returns:
        The macro-averaged F1 score over all moral foundations.
    """
    print("\n" + "="*20 + " FINAL FUSION MODEL EVALUATION " + "="*20)

    # Load models
    text_model = TextModelGATeMFD(emfd_csv_path).to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim']).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim']).to(device)
    fusion_model = HeterogeneousFusion().to(device)

    # Load best model state; map_location lets a GPU-written checkpoint be
    # restored on a CPU-only host.
    checkpoint = torch.load(model_path, map_location=device)
    text_model.load_state_dict(checkpoint['text_model'])
    spatial_model.load_state_dict(checkpoint['spatial_model'])
    temporal_model.load_state_dict(checkpoint['temporal_model'])
    behavioral_model.load_state_dict(checkpoint['behavioral_model'])
    fusion_model.load_state_dict(checkpoint['fusion_model'])

    # Set to eval mode
    for model in (text_model, spatial_model, temporal_model, behavioral_model, fusion_model):
        model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing Fusion Model"):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            # Extract features
            text_feat = text_model.get_features(batch['text_data'])
            spatial_feat = spatial_model.get_features(batch['spatial_features'])
            temporal_feat = temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = behavioral_model.get_features(batch['behavioral_features'])

            # Fusion
            fusion_preds = fusion_model(text_feat, spatial_feat, temporal_feat, behavioral_feat)
            pred_stack = torch.cat([fusion_preds[f] for f in moral_foundations], dim=1)

            all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
            all_labels.append(batch['targets'].cpu().numpy())

    y_pred = np.vstack(all_preds)
    y_true = np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print(classification_report(y_true, y_pred, target_names=moral_foundations, zero_division=0))
    print(f"** FUSION MODEL Macro F1-Score: {macro_f1:.4f} **")

    return macro_f1

def print_comprehensive_results(individual_results, fusion_f1):
    """Print per-model summaries and a per-foundation F1 comparison table.

    Args:
        individual_results: Dict mapping a model display name to
            ``{'macro_f1': float, 'foundation_f1s': sequence}``, with the
            sequence in ``moral_foundations`` order.
        fusion_f1: Macro F1 of the fusion model.
    """
    print("\n" + "="*80)
    print("COMPREHENSIVE MULTIMODAL GAT-eMFD EVALUATION RESULTS")
    print("=" * 80)

    # Individual model results
    for model_name, results in individual_results.items():
        print(f"\n{model_name} MODEL:")
        print("-" * 50)
        print(f"  Macro F1-Score: {results['macro_f1']:.4f}")
        for i, foundation in enumerate(moral_foundations):
            print(f"  {foundation} F1: {results['foundation_f1s'][i]:.4f}")

    print("\nFUSION MODEL:")
    print("-" * 50)
    print(f"  Macro F1-Score: {fusion_f1:.4f}")

    # Foundation-specific comparison
    print("\n" + "="*80)
    print("FOUNDATION-SPECIFIC F1 COMPARISON:")
    print("=" * 80)
    header = "Foundation     GAT-eMFD  Spatial   Temporal  Behavioral  Fusion"
    print(header)
    print("-" * 70)

    table_models = ['GAT-eMFD Text', 'Spatial', 'Temporal', 'Behavioral']
    for i, foundation in enumerate(moral_foundations):
        row = f"{foundation:<12}"
        for model_name in table_models:
            model_result = individual_results.get(model_name)
            if model_result is None:
                # Keep the column so later values stay under their header
                # when a model was not evaluated.
                cell = f"{'N/A':<5}"
            else:
                cell = f"{model_result['foundation_f1s'][i]:.3f}"
            row += f"  {cell}    "
        # Only the macro F1 is available for the fusion model, so the same
        # value is shown on every row of the Fusion column.
        row += f"  {fusion_f1:.3f}"
        print(row)

# ===================================================================
# 8. Main Execution
# ===================================================================
# Script entry point: prepare the data, run the four training stages, then
# evaluate the individual models and the fused model on the test split.
if __name__ == "__main__":
    # File paths - UPDATE THESE
    CSV_PATH = '/content/augmented_tweets.csv'
    GEOJSON_PATH = '/content/county_data_boarders.json'
    CENTROIDS_PATH = '/content/county_centroids.json'
    EMFD_CSV_PATH = '/content/eMFD_wordlist.csv'
    # Single source of truth for the stage-4 checkpoint written and read below.
    BEST_MODEL_PATH = "best_multimodal_gat_emfd.pth"

    if not GATConv:
        print("❌ PyTorch Geometric not available. Please install it.")
        # exit() is the interactive-shell helper from the site module and does
        # not reliably stop a notebook cell; raising SystemExit aborts the run.
        raise SystemExit(1)

    # Prepare dataset
    print("🔄 Preparing dataset...")
    datasets, feature_info = prepare_dataset(
        csv_path=CSV_PATH,
        geojson_path=GEOJSON_PATH,
        county_centroids_path=CENTROIDS_PATH,
        seq_len=10
    )

    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize models
    text_model = TextModelGATeMFD(EMFD_CSV_PATH)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim'])
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    )
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim'])
    fusion_model = HeterogeneousFusion()

    # Order matters: the trainer unpacks text, spatial, temporal, behavioral, fusion.
    models = [text_model, spatial_model, temporal_model, behavioral_model, fusion_model]

    # Initialize trainer
    trainer = MultiStageTrainer(
        models=models,
        datasets=datasets,
        feature_info=feature_info,
        device=device,
        emfd_csv_path=EMFD_CSV_PATH
    )

    # Execute multi-stage training
    print("\n🚀 Starting Multi-Stage Training Pipeline...")
    print("=" * 80)

    trainer.train_stage1_text_gat_emfd(epochs=5, lr=2e-5)
    trainer.train_stage2_other_modalities(epochs=10, lr=5e-4)
    trainer.train_stage3_fusion_integration(epochs=2, lr=1e-3)
    trainer.train_stage4_end_to_end_finetuning(epochs=2, best_model_path=BEST_MODEL_PATH)

    # Final evaluation
    print("\n🔍 Starting Comprehensive Evaluation...")
    print("=" * 80)

    test_loader = create_dataloader(datasets['test'], batch_size=16, shuffle=False)

    # Evaluate individual models
    individual_results = evaluate_individual_models(test_loader, device, feature_info, EMFD_CSV_PATH)

    # Evaluate fusion model
    fusion_f1 = evaluate_fusion_model(BEST_MODEL_PATH, test_loader, device, feature_info, EMFD_CSV_PATH)

    # Print comprehensive results
    print_comprehensive_results(individual_results, fusion_f1)

    print("\n✅ Multi-stage training and comprehensive evaluation completed!")
    print(f"🏆 Best models saved as '{BEST_MODEL_PATH}'")


🔄 Preparing dataset...
Using device: cuda

🚀 Starting Multi-Stage Training Pipeline...



Stage 1 - Epoch 1/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT eMFD Text model epoch 1, loss: 0.6801


Stage 1 - Epoch 2/5:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 1 - Epoch 3/5:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 1 - Epoch 4/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT eMFD Text model epoch 4, loss: 0.4894


Stage 1 - Epoch 5/5:   0%|          | 0/17 [00:00<?, ?it/s]

✅ Stage 1 Complete. GAT eMFD Text model saved.


--- Training Spatial Model ---


Spatial - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 1, loss: 0.5676


Spatial - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 4, loss: 0.3177


Spatial - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 7, loss: 0.2991


Spatial - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 10, loss: 0.2822

--- Training Temporal Model ---


Temporal - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 1, loss: 0.4340


Temporal - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 4, loss: 0.2666


Temporal - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 7, loss: 0.2722


Temporal - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 10, loss: 0.2521

--- Training Behavioral Model ---


Behavioral - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 1, loss: 0.4927


Behavioral - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 4, loss: 0.2428


Behavioral - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 7, loss: 0.2368


Behavioral - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 10, loss: 0.2340
✅ Stage 2 Complete. All modality models saved.



Stage 3 - Epoch 1/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 3 - Epoch 1: Loss: 0.3248, Val F1: 0.1364


Stage 3 - Epoch 2/2:   0%|          | 0/17 [00:00<?, ?it/s]

✅ Stage 3 Complete. Fusion model saved.



Stage 4 - Epoch 1/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 4 - Epoch 1: Loss: 0.2353, Val F1: 0.1586
🏆 New best model saved with F1: 0.1586


Stage 4 - Epoch 2/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 4 - Epoch 2: Loss: 0.2170, Val F1: 0.1549
✅ Stage 4 Complete. End-to-end training finished.

🔍 Starting Comprehensive Evaluation...



Testing GAT-eMFD Text:   0%|          | 0/4 [00:00<?, ?it/s]


--- GAT-eMFD Text Model Results ---
Macro F1-Score: 0.1467


Testing Spatial:   0%|          | 0/4 [00:00<?, ?it/s]


--- Spatial Model Results ---
Macro F1-Score: 0.1467


Testing Temporal:   0%|          | 0/4 [00:00<?, ?it/s]


--- Temporal Model Results ---
Macro F1-Score: 0.1424


Testing Behavioral:   0%|          | 0/4 [00:00<?, ?it/s]


--- Behavioral Model Results ---
Macro F1-Score: 0.2600



Testing Fusion Model:   0%|          | 0/4 [00:00<?, ?it/s]

              precision    recall  f1-score   support

        Care       0.84      0.95      0.89        44
    Fairness       0.00      0.00      0.00         0
     Loyalty       0.00      0.00      0.00         9
   Authority       0.00      0.00      0.00         4
      Purity       0.00      0.00      0.00         1
   Non_Moral       0.33      0.33      0.33         3

   micro avg       0.78      0.70      0.74        61
   macro avg       0.20      0.21      0.20        61
weighted avg       0.62      0.70      0.66        61
 samples avg       0.75      0.73      0.73        61

** FUSION MODEL Macro F1-Score: 0.2045 **

COMPREHENSIVE MULTIMODAL GAT-eMFD EVALUATION RESULTS

GAT-eMFD Text MODEL:
--------------------------------------------------
  Macro F1-Score: 0.1467
  Care F1: 0.8800
  Fairness F1: 0.0000
  Loyalty F1: 0.0000
  Authority F1: 0.0000
  Purity F1: 0.0000
  Non_Moral F1: 0.0000

Spatial MODEL:
--------------------------------------------------
  Macro F1-Scor

# with mlp text + gat

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, Dropout, Sigmoid, ModuleDict, LayerNorm, LSTM, GELU, MultiheadAttention, Softmax
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, classification_report
import pandas as pd
import numpy as np
import json
import math
import os
from tqdm.autonotebook import tqdm

# PyTorch Geometric imports for GAT
try:
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
except ImportError:
    print("PyTorch Geometric not found. Please install: pip install torch_geometric")
    GATConv = None
    Data = None

# Moral foundation labels. The text models defined below create one prediction
# head per entry and key their output dicts by these names.
# NOTE(review): presumably the target tensor's columns follow this same order —
# confirm against the dataset preparation code.
moral_foundations = ['Care', 'Fairness', 'Loyalty', 'Authority', 'Purity', 'Non_Moral']

# ===================================================================
# 1. GAT-eMFD Text Model (from paste.txt)
# ===================================================================
class eMFDProcessor:
    """Look up per-word moral-foundation probabilities from an eMFD word list.

    The CSV is read once at construction. Each word maps to a float32 vector
    of six values: the ``care_p``, ``fairness_p``, ``loyalty_p``,
    ``authority_p`` and ``purity_p`` columns (0.0 where a column is absent)
    followed by a constant 0.0 in the non-moral slot.
    """

    def __init__(self, emfd_csv_path):
        """Load the lexicon.

        Args:
            emfd_csv_path: Path to a CSV with a ``word`` column and the
                ``<foundation>_p`` probability columns.

        Raises:
            FileNotFoundError: If ``emfd_csv_path`` does not exist.
        """
        self.moral_foundations_emfd = ['care', 'fairness', 'loyalty', 'authority', 'purity', 'non-moral']
        if not os.path.exists(emfd_csv_path):
            raise FileNotFoundError(f"eMFD file not found at {emfd_csv_path}")
        df = pd.read_csv(emfd_csv_path)
        prob_cols = [f'{f}_p' for f in self.moral_foundations_emfd if f != 'non-moral']
        self.emfd_data = {
            row['word']: np.append([row.get(col, 0.0) for col in prob_cols], 0.0).astype(np.float32)
            for _, row in df.iterrows()
        }

    def extract_moral_concepts(self, text):
        """Return the lexicon entries found in ``text``, in order of occurrence.

        The text is lower-cased and split on whitespace. A token is matched
        as-is first; if that fails it is retried with leading and trailing
        punctuation removed, so "care!" or "#care" still hit the entry "care".

        Args:
            text: Raw input string.

        Returns:
            A list of ``{'word': str, 'probabilities': np.ndarray}`` dicts.
            When nothing matches, a single ``'neutral'`` entry with an
            all-zero vector is returned so callers always get one concept.
        """
        import string  # local import: the file's import block does not provide it

        lexicon = self.emfd_data
        concepts = []
        for token in text.lower().split():
            if token not in lexicon:
                token = token.strip(string.punctuation)
            if token in lexicon:
                concepts.append({'word': token, 'probabilities': lexicon[token]})

        if concepts:
            return concepts
        return [{'word': 'neutral', 'probabilities': np.zeros(6, dtype=np.float32)}]

class MoralGraphConstructor:
    """Turn a list of moral concepts into a fully connected graph."""

    def create_moral_graph(self, moral_concepts):
        """Build the graph for one text.

        Each node feature row is the concept's probability vector followed by
        256 zeros. A single concept gets one self-loop; otherwise every
        ordered pair of distinct nodes is connected.

        Args:
            moral_concepts: List of dicts carrying a ``'probabilities'`` array.

        Returns:
            A ``Data`` object with ``x`` and ``edge_index``, or ``None`` when
            ``Data`` is unavailable.
        """
        node_count = len(moral_concepts)

        padded_rows = [
            np.concatenate([concept['probabilities'], np.zeros(256)])
            for concept in moral_concepts
        ]
        node_features = torch.FloatTensor(np.array(padded_rows))

        if node_count == 1:
            # Lone node: a self-loop is the only possible edge.
            edge_index = torch.tensor([[0], [0]], dtype=torch.long)
        else:
            edge_pairs = []
            for src in range(node_count):
                for dst in range(node_count):
                    if src != dst:
                        edge_pairs.append([src, dst])
            # Transpose the list of (src, dst) pairs into a 2-row layout.
            edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()

        if not Data:
            return None
        return Data(x=node_features, edge_index=edge_index)

class GATeMFDModule(nn.Module):
    """Graph-attention encoder that pools a moral-concept graph into one vector.

    The first six node-feature columns (the eMFD probabilities) are kept and
    the remaining columns are replaced by a learned embedding indexed by the
    node's position in the graph. One GAT layer is applied and the node
    outputs are combined by a learned softmax attention over nodes.
    """

    def __init__(self, input_dim=262, output_dim=256, num_heads=4, dropout=0.1):
        """Create the embedding table, the GAT layer and the pooling scorer.

        Args:
            input_dim: Node feature width fed to the GAT layer.
            output_dim: Width of the GAT output and of the pooled vector.
            num_heads: Number of attention heads (averaged, ``concat=False``).
            dropout: Dropout passed to the GAT layer.
        """
        super(GATeMFDModule, self).__init__()
        self.concept_embeddings = nn.Embedding(1000, 256)
        self.gat_layer = GATConv(input_dim, output_dim, heads=num_heads, dropout=dropout, concat=False) if GATConv else nn.Identity()
        self.attention_pooling = nn.Linear(output_dim, 1)

    def forward(self, graph_data, device):
        """Encode one graph.

        Args:
            graph_data: Object exposing ``x`` (node features) and
                ``edge_index``.
            device: Device the computation runs on.

        Returns:
            A tensor with a single row holding the attention-pooled graph
            representation; all zeros of shape ``(1, 256)`` when ``GATConv``
            is unavailable.
        """
        if not GATConv:
            return torch.zeros((1, 256)).to(device)
        x, edge_index = graph_data.x.to(device), graph_data.edge_index.to(device)

        # Clamp positions so graphs with more nodes than embedding rows reuse
        # the last row instead of raising an index error.
        last_row = self.concept_embeddings.num_embeddings - 1
        positions = torch.arange(x.size(0), device=device).clamp(max=last_row)

        # Build a new tensor rather than writing into x, which on CPU is the
        # caller's own graph tensor.
        x = torch.cat([x[:, :6], self.concept_embeddings(positions)], dim=1)

        gat_output = F.elu(self.gat_layer(x, edge_index))
        # Softmax over the node dimension gives one weight per node.
        attn_weights = F.softmax(self.attention_pooling(gat_output), dim=0)
        return torch.sum(attn_weights * gat_output, dim=0, keepdim=True)

class TextModelGATeMFD(nn.Module):
    """Text classifier built on eMFD lexicon matches and a GAT encoder.

    Each text is turned into a graph of matched moral concepts, encoded by
    ``GATeMFDModule``, passed through an MLP and then through one
    single-logit head per moral foundation.
    """

    def __init__(self, emfd_csv_path):
        """Build the lexicon processor, graph encoder, MLP and heads.

        Args:
            emfd_csv_path: Path of the eMFD word list passed to
                ``eMFDProcessor``.
        """
        super().__init__()
        self.emfd_processor = eMFDProcessor(emfd_csv_path)
        self.graph_constructor = MoralGraphConstructor()
        self.gat_module = GATeMFDModule()

        # Additional processing layers
        self.processor = Sequential(
            Linear(256, 512), LayerNorm(512), GELU(), Dropout(0.1),
            Linear(512, 256), LayerNorm(256), GELU(), Dropout(0.1)
        )

        # Prediction heads emit raw logits; the sigmoid is left to the loss
        # (BCEWithLogitsLoss) or to the caller.
        self.heads = ModuleDict({
            foundation: Sequential(Linear(256, 128), ReLU(), Dropout(0.2),
                                 Linear(128, 1))
            for foundation in moral_foundations
        })

    def _encode_texts(self, texts):
        """Return one pooled GAT vector per text, concatenated along dim 0."""
        device = next(self.parameters()).device
        batch_features = []

        for text in texts:
            moral_concepts = self.emfd_processor.extract_moral_concepts(text)
            graph_data = self.graph_constructor.create_moral_graph(moral_concepts)
            # The constructor returns None when no graph could be built; test
            # for that explicitly instead of relying on object truthiness.
            if graph_data is not None:
                gat_features = self.gat_module(graph_data, device)
            else:
                gat_features = torch.zeros((1, 256)).to(device)
            batch_features.append(gat_features)

        return torch.cat(batch_features, dim=0)

    def forward(self, texts):
        """Return a dict mapping each moral foundation to its logits.

        Args:
            texts: Iterable of raw text strings.
        """
        processed_features = self.get_features(texts)
        return {f: self.heads[f](processed_features) for f in moral_foundations}

    def get_features(self, texts):
        """Return the MLP-processed feature matrix for ``texts`` (no heads).

        Args:
            texts: Iterable of raw text strings.
        """
        return self.processor(self._encode_texts(texts))

# ===================================================================
# 2. MLP Text Model (from paste-2.txt)
# ===================================================================
class TextModelRoBERTaFrozen(nn.Module):
    """Frozen ``roberta-base`` encoder with a trainable MLP and per-foundation heads.

    The encoder is only used as a fixed feature extractor: its parameters have
    ``requires_grad=False``, it always runs under ``torch.no_grad()``, and it
    is kept in eval mode even while the rest of the module is training.
    """

    def __init__(self):
        super().__init__()
        # Load RoBERTa model
        self.tokenizer = AutoTokenizer.from_pretrained('roberta-base')
        self.roberta = AutoModel.from_pretrained('roberta-base')

        # Freeze RoBERTa completely
        for param in self.roberta.parameters():
            param.requires_grad = False
        self.roberta.eval()

        # Trainable processing layers
        self.processor = Sequential(
            Linear(768, 512), LayerNorm(512), GELU(), Dropout(0.1),
            Linear(512, 256), LayerNorm(256), GELU(), Dropout(0.1)
        )

        # Prediction heads emit raw logits; BCEWithLogitsLoss applies the sigmoid.
        self.heads = ModuleDict({
            foundation: Sequential(Linear(256, 128), ReLU(), Dropout(0.2),
                                 Linear(128, 1))
            for foundation in moral_foundations
        })

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen encoder in eval mode.

        ``nn.Module.train`` recurses into every child, which would turn on
        dropout inside the frozen encoder and make the "fixed" embeddings
        stochastic during training. Re-applying ``eval()`` to the encoder
        keeps its output deterministic for a given input.
        """
        super().train(mode)
        self.roberta.eval()
        return self

    def _encode(self, texts):
        """Tokenize ``texts`` and return the first-token (CLS) embedding per text."""
        inputs = self.tokenizer(texts, padding=True, truncation=True,
                               max_length=512, return_tensors='pt')
        device = next(self.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # The encoder is frozen, so no autograd graph is needed through it.
        with torch.no_grad():
            outputs = self.roberta(**inputs)
            embeddings = outputs.last_hidden_state[:, 0]  # CLS token
        return embeddings

    def get_features(self, texts):
        """Return the output of ``self.processor`` on the frozen CLS embeddings."""
        return self.processor(self._encode(texts))

    def forward(self, texts):
        """Return a dict mapping each name in ``moral_foundations`` to its head's logits."""
        features = self.get_features(texts)
        return {f: self.heads[f](features) for f in moral_foundations}

# ===================================================================
# 3. Other Modality Models
# ===================================================================
class SpatialModel(nn.Module):
    """Two-layer MLP encoder over spatial features with one logit head per foundation."""

    def __init__(self, input_dim):
        super().__init__()
        encoder_layers = [
            Linear(input_dim, 128),
            ReLU(),
            Dropout(0.2),
            Linear(128, 256),
        ]
        self.encoder = Sequential(*encoder_layers)

        # Heads emit raw logits; BCEWithLogitsLoss applies the sigmoid.
        head_modules = {}
        for name in moral_foundations:
            head_modules[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
        self.heads = ModuleDict(head_modules)

    def get_features(self, x):
        """Return the 256-wide encoder output for ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return a dict of per-foundation logits computed from the encoded ``x``."""
        embedding = self.encoder(x)
        logits = {}
        for name in moral_foundations:
            logits[name] = self.heads[name](embedding)
        return logits

class TemporalModelLSTM(nn.Module):
    """Bidirectional 2-layer LSTM over a feature sequence with per-foundation heads.

    Args:
        input_dim: number of features per timestep.
        seq_len: accepted for call compatibility; not used by the model.
        hidden_size: width of the input projection and of each LSTM direction.
    """

    def __init__(self, input_dim, seq_len=10, hidden_size=128):
        super().__init__()
        self.projection = Linear(input_dim, hidden_size)
        self.lstm = LSTM(hidden_size, hidden_size, 2, batch_first=True,
                        dropout=0.2, bidirectional=True)
        self.encoder = Sequential(
            Linear(hidden_size * 2, 256), LayerNorm(256),
            ReLU(), Dropout(0.2)
        )
        # Heads emit raw logits; BCEWithLogitsLoss applies the sigmoid.
        self.heads = ModuleDict({
            f: Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
            for f in moral_foundations
        })

    def get_features(self, x):
        """Encode a batch-first sequence tensor into a 256-wide summary vector.

        The summary concatenates the forward direction's output at the last
        timestep with the backward direction's output at the first timestep,
        so both halves have consumed the whole sequence. Taking
        ``lstm_out[:, -1]`` for both halves would give a backward state that
        has only seen the final timestep.
        """
        proj = self.projection(x)
        lstm_out, _ = self.lstm(proj)
        hidden = self.lstm.hidden_size
        # Last dim of a bidirectional output is [forward | backward].
        forward_final = lstm_out[:, -1, :hidden]
        backward_final = lstm_out[:, 0, hidden:]
        summary = torch.cat([forward_final, backward_final], dim=-1)
        return self.encoder(summary)

    def forward(self, x):
        """Return a dict of per-foundation logits for the sequence batch ``x``."""
        features = self.get_features(x)
        return {f: self.heads[f](features) for f in moral_foundations}

class BehavioralModel(nn.Module):
    """Normalised two-stage MLP over behavioral features with per-foundation heads."""

    def __init__(self, input_dim):
        super().__init__()
        # Each stage is Linear -> LayerNorm -> ReLU -> Dropout(0.1).
        widths = (input_dim, 128, 256)
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend([Linear(fan_in, fan_out), LayerNorm(fan_out), ReLU(), Dropout(0.1)])
        self.encoder = Sequential(*stages)

        # Heads emit raw logits; BCEWithLogitsLoss applies the sigmoid.
        head_modules = {}
        for name in moral_foundations:
            head_modules[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
        self.heads = ModuleDict(head_modules)

    def get_features(self, x):
        """Return the 256-wide encoder output for ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return a dict of per-foundation logits computed from the encoded ``x``."""
        embedding = self.encoder(x)
        logits = {}
        for name in moral_foundations:
            logits[name] = self.heads[name](embedding)
        return logits

# ===================================================================
# 4. Enhanced Heterogeneous Fusion (5 modalities)
# ===================================================================
class HeterogeneousFusion(nn.Module):
    """Attention-based fusion of five modality feature vectors.

    The five inputs (GAT text, MLP text, spatial, temporal, behavioral) are
    projected to ``d_model``, passed through a per-modality cross-attention
    block, mixed with self-attention across modalities, combined with a
    softmax gate and fed to one logit head per moral foundation.

    Args:
        input_dims: optional sequence of five input widths, in the modality
            order above. When falsy, every input is assumed to be 256 wide.
        d_model: common width used inside the fusion block.
        num_heads: number of attention heads for every attention layer.

    Raises:
        ValueError: if ``input_dims`` is given but does not have five entries.
    """

    def __init__(self, input_dims=None, d_model=256, num_heads=8):
        super().__init__()
        self.d_model = d_model

        # Feature projection layers to align dimensions (5 modalities).
        if input_dims:
            if len(input_dims) != 5:
                raise ValueError(
                    f"input_dims must have 5 entries (one per modality), got {len(input_dims)}"
                )
            dims = list(input_dims)
        else:
            # GAT text, MLP text, spatial, temporal, behavioral
            dims = [256] * 5
        self.projections = nn.ModuleList([nn.Linear(dim, d_model) for dim in dims])

        self.modality_queries = nn.Parameter(torch.randn(5, d_model))

        self.cross_attention = nn.ModuleList([
            MultiheadAttention(d_model, num_heads, batch_first=True)
            for _ in range(5)
        ])

        self.self_attention = MultiheadAttention(d_model, num_heads, batch_first=True)
        # Five norms for the cross-attention residuals, one for self-attention.
        self.layer_norms = nn.ModuleList([LayerNorm(d_model) for _ in range(6)])

        # Gate over the 5 modalities. nn.Softmax is referenced through ``nn``
        # so the class does not depend on a separate ``Softmax`` import.
        self.modality_gate = nn.Sequential(
            Linear(d_model * 5, d_model), ReLU(), Dropout(0.1),
            Linear(d_model, 5), nn.Softmax(dim=-1)
        )

        self.final_projection = Linear(d_model, d_model)

        # Heads emit raw logits; BCEWithLogitsLoss applies the sigmoid.
        self.heads = ModuleDict({
            f: Sequential(Linear(d_model, 128), ReLU(), Dropout(0.2),
                         Linear(128, 1))
            for f in moral_foundations
        })

    def forward(self, gat_text_feat, mlp_text_feat, spatial_feat, temporal_feat, behavioral_feat):
        """Fuse the five modality features and return per-foundation logits.

        Each argument is expected to be a 2-D ``(batch, dim)`` tensor whose
        width matches the corresponding projection; all five must share the
        same batch size.

        Returns:
            dict mapping each name in ``moral_foundations`` to its head output.
        """
        batch_size = gat_text_feat.size(0)
        modality_inputs = (gat_text_feat, mlp_text_feat, spatial_feat,
                           temporal_feat, behavioral_feat)

        # Project all features to the same dimension.
        features = [proj(feat) for proj, feat in zip(self.projections, modality_inputs)]

        # Cross-attention with residual connections.
        # NOTE(review): every key/value sequence has length 1, so the softmax
        # over keys is always 1 and the learned query cannot change att_out.
        # Confirm this is intended before relying on modality_queries.
        attended = []
        for i, (feat, attn, norm) in enumerate(
            zip(features, self.cross_attention, self.layer_norms[:5])
        ):
            query = self.modality_queries[i:i+1].unsqueeze(0).expand(batch_size, -1, -1)
            key_value = feat.unsqueeze(1)  # add a length-1 sequence dimension
            att_out, _ = attn(query, key_value, key_value)
            attended.append(norm(att_out.squeeze(1) + feat))

        # Stack to (batch, 5, d_model) and mix information across modalities.
        stacked = torch.stack(attended, dim=1)
        self_att, _ = self.self_attention(stacked, stacked, stacked)
        fused = self.layer_norms[5](stacked + self_att)

        # One softmax weight per modality, computed from all fused vectors.
        gates = self.modality_gate(fused.reshape(batch_size, -1))

        # Weighted sum over the modality axis.
        gated = torch.sum(fused * gates.unsqueeze(-1), dim=1)

        final_features = self.final_projection(gated)

        return {f: self.heads[f](final_features) for f in moral_foundations}

# ===================================================================
# 5. Data Processing Functions (same as paste.txt)
# ===================================================================
def prepare_labels(df):
    """Build the multi-label target matrix from virtue/vice annotation columns.

    Each foundation is positive when either its virtue or its vice column
    equals 1. ``Non_Moral`` is positive when none of the five foundations is.
    The result has one column per entry of ``moral_foundations``, in that
    order, as a float32 tensor.
    """
    frame = df.reset_index(drop=True)

    def is_set(column):
        return column == 1

    # The cheating column may be capitalised, lower-case, or absent.
    all_false = pd.Series([False] * len(frame), index=frame.index)
    cheating = frame.get('Cheating', frame.get('cheating', all_false))

    combined = {
        'Care': is_set(frame['Care']) | is_set(frame['Harm']),
        'Fairness': is_set(frame['Fairness']) | is_set(cheating),
        'Loyalty': is_set(frame['Loyalty']) | is_set(frame['Betrayal']),
        'Authority': is_set(frame['Authority']) | is_set(frame['Subversion']),
        'Purity': is_set(frame['Purity']) | is_set(frame['Degradation']),
    }

    # Non_Moral = 1 exactly when all five moral foundations are False.
    any_moral = pd.DataFrame(combined).any(axis=1)
    combined['Non_Moral'] = (~any_moral).astype(int)

    columns = [combined[name].astype(int).values for name in moral_foundations]
    return torch.tensor(np.column_stack(columns), dtype=torch.float32)

def _format_geoid(value):
    """Return ``value`` as a zero-padded, 5-character GEOID string."""
    # A numeric GEOID column that pandas parsed as float yields values such as
    # 1001.0; str() of that is "1001.0", which never matches a "01001" key.
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    return str(value).zfill(5)


def create_temporal_sequences(df, seq_len=10):
    """Create sliding-window temporal sequences per county.

    Rows are sorted by ``GEOID, year, month, day`` and grouped by ``GEOID``.
    Every run of ``seq_len`` consecutive rows inside a group becomes one
    sample (stride 1); groups with fewer than ``seq_len`` rows are skipped.

    Args:
        df: DataFrame holding the temporal, behavioral, ``text`` and label
            columns referenced below.
        seq_len: number of consecutive rows per sequence; must be >= 1.

    Returns:
        ``(sequences, targets, other_features)`` as three parallel lists:
        * sequences: ``seq_len`` rows of temporal feature values per sample,
        * targets: label tensor of the window's last row (from prepare_labels),
        * other_features: dict with the last row's ``text``, zero-padded
          ``geoid`` and ``behavioral`` feature list.

    Raises:
        ValueError: if ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    temporal_columns = [
        'month', 'day', 'year', 'after_05/26/2020',
        'female_pct', 'lesscollege_pct', 'cvap',
        'clf_unemploy_pct', 'lesshs_pct', 'median_hh_inc',
        'urm_pct', 'ruralurban_cc', 'demgov', 'repgov',
        'net_dem_president_votes', 'net_dem_gov_votes', 'is_blue',
        'deaths', 'cases', 'cases_per_capita',
        'deaths_per_capita', 'cases_per_capita_discrete', 'mask_score',
    ]
    behavioral_columns = [
        'retweet_count', 'rt_discrete', 'for_sah', 'is_vivid', 'sentiment_score',
    ]

    df_sorted = df.sort_values(['GEOID', 'year', 'month', 'day'])
    sequences, targets, other_features = [], [], []

    for _, group in df_sorted.groupby('GEOID'):
        n_windows = len(group) - seq_len + 1
        if n_windows <= 0:
            continue

        # Pull everything out of pandas once per county. Row-wise iteration
        # and a one-row prepare_labels call per window cost far more for the
        # same values.
        temporal_values = group[temporal_columns].to_numpy()
        behavioral_values = group[behavioral_columns].to_numpy()
        texts = group['text'].tolist()
        geoids = group['GEOID'].tolist()
        group_labels = prepare_labels(group)

        for start in range(n_windows):
            stop = start + seq_len
            last = stop - 1  # labels and side features come from the final row

            sequences.append(temporal_values[start:stop].tolist())
            targets.append(group_labels[last])
            other_features.append({
                'text': texts[last],
                'geoid': _format_geoid(geoids[last]),
                'behavioral': behavioral_values[last].tolist(),
            })

    return sequences, targets, other_features

def prepare_dataset(csv_path, geojson_path, county_centroids_path, seq_len=10):
    """Load the raw files and build standardised train/validation/test splits.

    Args:
        csv_path: CSV consumed by ``create_temporal_sequences``.
        geojson_path: JSON file; only used to test whether a GEOID is a key.
        county_centroids_path: JSON file mapping GEOID to a 2-element centroid.
        seq_len: window length forwarded to ``create_temporal_sequences``.

    Returns:
        ``(datasets, feature_info)``. ``datasets`` maps ``'train'``,
        ``'validation'`` and ``'test'`` (70/15/15, ``random_state=42``) to a
        dict with ``text_data``, ``spatial_features``, ``temporal_sequences``,
        ``behavioral_features`` and ``moral_targets``. ``feature_info`` holds
        the feature widths and the sequence length.

    Raises:
        ValueError: if no group in the CSV has at least ``seq_len`` rows.

    NOTE(review): windows are generated with stride 1 and split at random, so
    overlapping windows of one county can land in different splits.
    """
    # Load data
    df = pd.read_csv(csv_path)
    with open(geojson_path) as f:
        boundaries = json.load(f)
    with open(county_centroids_path) as f:
        centroids = json.load(f)

    # Create sequences
    sequences, targets, other_features = create_temporal_sequences(df, seq_len)
    if not sequences:
        raise ValueError(
            f"No temporal sequences could be built: no GEOID group has at least {seq_len} rows"
        )

    # Approximate bounding box of the contiguous US used for normalisation.
    lat_min, lat_max = 24, 49
    lon_min, lon_max = -125, -66

    data = {'text': [], 'spatial': [], 'temporal': [], 'behavioral': [], 'targets': []}

    for seq, target, other in zip(sequences, targets, other_features):
        data['text'].append(other['text'])
        data['temporal'].append(seq)
        data['behavioral'].append(other['behavioral'])
        data['targets'].append(target)

        # Extract spatial features
        geoid = other['geoid']
        if geoid in boundaries and geoid in centroids:
            # presumably centroid is (lat, lon); verify against the centroid file
            centroid = centroids[geoid]
            lat_norm = (centroid[0] - lat_min) / (lat_max - lat_min)
            # The range must be (max - min); the earlier (66 - 125) was
            # negative and flipped the sign of the feature.
            lon_norm = (centroid[1] - lon_min) / (lon_max - lon_min)
            spatial_feat = [lat_norm, lon_norm, 0.5, 0.5, 0.5, 0.5]  # Simplified features
        else:
            spatial_feat = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]  # Default

        data['spatial'].append(spatial_feat)

    # Convert to tensors
    spatial_tensor = torch.tensor(data['spatial'], dtype=torch.float32)
    temporal_tensor = torch.tensor(data['temporal'], dtype=torch.float32)
    behavioral_tensor = torch.tensor(data['behavioral'], dtype=torch.float32)
    targets_tensor = torch.stack(data['targets'])

    # Create splits first so the scalers can be fitted on training rows only.
    indices = list(range(len(data['text'])))
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

    def standardize(tensor):
        """Standardise the last axis using statistics of the training rows."""
        width = tensor.shape[-1]
        scaler = StandardScaler()
        # Fitting on all rows would leak validation/test statistics into training.
        scaler.fit(tensor[train_idx].reshape(-1, width).numpy())
        scaled = scaler.transform(tensor.reshape(-1, width).numpy())
        return torch.tensor(scaled.reshape(tuple(tensor.shape)), dtype=torch.float32)

    spatial_tensor = standardize(spatial_tensor)
    behavioral_tensor = standardize(behavioral_tensor)
    # Temporal data is flattened to (samples * timesteps, features) for scaling.
    temporal_tensor = standardize(temporal_tensor)

    datasets = {}
    for split, idx in [('train', train_idx), ('validation', val_idx), ('test', test_idx)]:
        datasets[split] = {
            'text_data': [data['text'][i] for i in idx],
            'spatial_features': spatial_tensor[idx],
            'temporal_sequences': temporal_tensor[idx],
            'behavioral_features': behavioral_tensor[idx],
            'moral_targets': targets_tensor[idx]
        }

    feature_info = {
        'spatial_feature_dim': spatial_tensor.shape[1],
        'temporal_feature_dim': temporal_tensor.shape[2],
        'temporal_sequence_length': seq_len,
        'behavioral_feature_dim': behavioral_tensor.shape[1]
    }

    return datasets, feature_info

def create_dataloader(dataset_dict, batch_size=16, shuffle=True):
    """Wrap one dataset split in a DataLoader that yields dict batches.

    Each sample is the 5-tuple (text, spatial, temporal, behavioral, target)
    taken at the same index from ``dataset_dict``. Batches keep texts as a
    plain list and stack every tensor field along a new leading dimension;
    the target field is exposed under the key ``'targets'``.
    """
    sample_fields = (
        'text_data',
        'spatial_features',
        'temporal_sequences',
        'behavioral_features',
        'moral_targets',
    )
    sample_count = len(dataset_dict['text_data'])
    samples = [
        tuple(dataset_dict[field][row] for field in sample_fields)
        for row in range(sample_count)
    ]

    def collate_fn(batch):
        def column(position):
            return [sample[position] for sample in batch]

        return {
            'text_data': column(0),
            'spatial_features': torch.stack(column(1)),
            'temporal_sequences': torch.stack(column(2)),
            'behavioral_features': torch.stack(column(3)),
            'targets': torch.stack(column(4)),
        }

    return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

# ===================================================================
# 6. Enhanced Multi-Stage Training System (5 modalities)
# ===================================================================
class MultiStageTrainer:
    """Four-stage training driver for five modality models and a fusion model.

    Stage 1 trains the two text models, stage 2 the spatial/temporal/behavioral
    models, stage 3 the fusion model on features of the frozen modality models,
    and stage 4 fine-tunes both text models together with the fusion model.
    """

    def __init__(self, models, datasets, feature_info, device, emfd_csv_path):
        """Store the models on ``device`` and build the train/validation loaders.

        Args:
            models: six models ordered as (GAT text, MLP text, spatial,
                temporal, behavioral, fusion).
            datasets: dict with ``'train'`` and ``'validation'`` splits in the
                format accepted by ``create_dataloader``.
            feature_info: not used by the trainer; kept for call compatibility.
            device: device every model and batch tensor is moved to.
            emfd_csv_path: not used by the trainer; kept for call compatibility.
        """
        (gat_text_model, mlp_text_model, spatial_model,
         temporal_model, behavioral_model, fusion_model) = models
        self.datasets = datasets
        self.device = device

        # Move models to device
        self.gat_text_model = gat_text_model.to(device)
        self.mlp_text_model = mlp_text_model.to(device)
        self.spatial_model = spatial_model.to(device)
        self.temporal_model = temporal_model.to(device)
        self.behavioral_model = behavioral_model.to(device)
        self.fusion_model = fusion_model.to(device)

        # Create data loaders
        self.train_loader = create_dataloader(datasets['train'], batch_size=16, shuffle=True)
        self.val_loader = create_dataloader(datasets['validation'], batch_size=16, shuffle=False)

        # Models emit raw logits, so the sigmoid lives inside the loss.
        self.criterion = nn.BCEWithLogitsLoss()

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    def _freeze(self, model):
        """Disable gradients for every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad_(False)

    def _unfreeze(self, model):
        """Enable gradients for every parameter of ``model``.

        This also flips parameters a model froze on its own in ``__init__``.
        Sub-modules that are executed under ``torch.no_grad()`` inside the
        model still receive no gradient.
        """
        for p in model.parameters():
            p.requires_grad_(True)

    def _modality_models(self):
        """Return the five per-modality models in fusion-argument order."""
        return (self.gat_text_model, self.mlp_text_model, self.spatial_model,
                self.temporal_model, self.behavioral_model)

    def _to_device(self, batch):
        """Return ``batch`` with every tensor value moved to ``self.device``."""
        return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}

    def _extract_features(self, batch, text_grad=False):
        """Run the five modality encoders and return their features as a tuple.

        Non-text encoders always run without gradients. Text encoders run
        with gradients only when ``text_grad`` is true (stage 4).
        """
        texts = batch['text_data']
        if text_grad:
            gat_text_feat = self.gat_text_model.get_features(texts)
            mlp_text_feat = self.mlp_text_model.get_features(texts)
        else:
            with torch.no_grad():
                gat_text_feat = self.gat_text_model.get_features(texts)
                mlp_text_feat = self.mlp_text_model.get_features(texts)

        with torch.no_grad():
            spatial_feat = self.spatial_model.get_features(batch['spatial_features'])
            temporal_feat = self.temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = self.behavioral_model.get_features(batch['behavioral_features'])

        return gat_text_feat, mlp_text_feat, spatial_feat, temporal_feat, behavioral_feat

    def _fit_single_model(self, model, input_key, label, epochs, lr):
        """Train one modality model on its own predictions.

        Args:
            model: model whose ``forward`` takes ``batch[input_key]``.
            input_key: batch key holding this model's input.
            label: prefix used in the progress bar and the loss printout.
            epochs: number of passes over the training loader.
            lr: AdamW learning rate.
        """
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

        for epoch in range(epochs):
            model.train()
            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"{label} - Epoch {epoch+1}/{epochs}"):
                batch = self._to_device(batch)

                optimizer.zero_grad()
                preds = model(batch[input_key])
                loss = self.compute_loss(preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if epoch % 3 == 0:
                print(f"{label} model epoch {epoch+1}, loss: {total_loss/len(self.train_loader):.4f}")

    def compute_loss(self, predictions, targets):
        """Compute loss for moral foundation predictions"""
        # Concatenate per-foundation logits in moral_foundations order so the
        # columns line up with the target matrix.
        pred_stack = torch.cat([predictions[f] for f in moral_foundations], dim=1)
        return self.criterion(pred_stack, targets)

    # ------------------------------------------------------------------
    # Training stages
    # ------------------------------------------------------------------
    def train_stage1_text_models(self, epochs=15, lr=2e-5):
        """Stage 1: Train both GAT eMFD and MLP Text Models"""
        print("\n" + "="*20 + " Stage 1: Training Both Text Models " + "="*20)

        # Freeze other models
        self._freeze(self.spatial_model)
        self._freeze(self.temporal_model)
        self._freeze(self.behavioral_model)
        self._freeze(self.fusion_model)

        # Train GAT Text Model
        print("Training GAT-eMFD Text Model...")
        self._unfreeze(self.gat_text_model)
        self._freeze(self.mlp_text_model)
        self._fit_single_model(self.gat_text_model, 'text_data', "GAT Text", epochs, lr)

        # Train MLP Text Model
        print("Training MLP-RoBERTa Text Model...")
        self._freeze(self.gat_text_model)
        self._unfreeze(self.mlp_text_model)
        self._fit_single_model(self.mlp_text_model, 'text_data', "MLP Text", epochs, lr)

        # Save models
        torch.save(self.gat_text_model.state_dict(), "gat_text_stage1.pth")
        torch.save(self.mlp_text_model.state_dict(), "mlp_text_stage1.pth")
        print("✅ Stage 1 Complete. Both text models saved.")

    def train_stage2_other_modalities(self, epochs=10, lr=5e-4):
        """Stage 2: Train other modality models"""
        print("\n" + "="*20 + " Stage 2: Training Other Modality Models " + "="*20)

        # Freeze text and fusion models
        self._freeze(self.gat_text_model)
        self._freeze(self.mlp_text_model)
        self._freeze(self.fusion_model)

        models_to_train = [
            (self.spatial_model, 'spatial', 'spatial_features'),
            (self.temporal_model, 'temporal', 'temporal_sequences'),
            (self.behavioral_model, 'behavioral', 'behavioral_features'),
        ]

        for model, name, input_key in models_to_train:
            print(f"\n--- Training {name.capitalize()} Model ---")
            self._unfreeze(model)
            self._fit_single_model(model, input_key, name.capitalize(), epochs, lr)

            # Save model
            torch.save(model.state_dict(), f"{name}_model_stage2.pth")
            self._freeze(model)

        print("✅ Stage 2 Complete. All modality models saved.")

    def train_stage3_fusion_integration(self, epochs=8, lr=1e-3):
        """Stage 3: Fusion Integration with 5 modalities"""
        print("\n" + "="*15 + " Stage 3: Fusion Integration (5 modalities) " + "="*15)

        # Freeze all individual models, unfreeze fusion
        for model in self._modality_models():
            self._freeze(model)
        self._unfreeze(self.fusion_model)

        optimizer = torch.optim.AdamW(self.fusion_model.parameters(), lr=lr, weight_decay=1e-4)

        for epoch in range(epochs):
            self.fusion_model.train()
            # The modality models are frozen feature extractors here. Stages 1
            # and 2 leave them in train mode, so without this their dropout
            # would stay active until the first validation pass.
            for model in self._modality_models():
                model.eval()
            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"Stage 3 - Epoch {epoch+1}/{epochs}"):
                batch = self._to_device(batch)

                optimizer.zero_grad()

                # Features carry no gradient; only the fusion model is trained.
                features = self._extract_features(batch, text_grad=False)

                fusion_preds = self.fusion_model(*features)
                loss = self.compute_loss(fusion_preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if epoch % 2 == 0:
                val_f1 = self._validate_fusion()
                print(f"Stage 3 - Epoch {epoch+1}: Loss: {total_loss/len(self.train_loader):.4f}, Val F1: {val_f1:.4f}")

        torch.save(self.fusion_model.state_dict(), "fusion_model_stage3.pth")
        print("✅ Stage 3 Complete. Fusion model saved.")

    def train_stage4_end_to_end_finetuning(self, epochs=5, best_model_path="best_fusion_model.pth"):
        """Stage 4: End-to-end fine-tuning with both text models

        The checkpoint with the best validation macro-F1 is written to
        ``best_model_path``; training stops after two epochs without
        improvement.
        """
        print("\n" + "="*15 + " Stage 4: End-to-End Fine-tuning " + "="*15)

        # Unfreeze both text models and fusion, keep others frozen
        self._unfreeze(self.gat_text_model)
        self._unfreeze(self.mlp_text_model)
        self._freeze(self.spatial_model)
        self._freeze(self.temporal_model)
        self._freeze(self.behavioral_model)
        self._unfreeze(self.fusion_model)

        optimizer = torch.optim.AdamW([
            {'params': self.gat_text_model.parameters(), 'lr': 5e-6},
            {'params': self.mlp_text_model.parameters(), 'lr': 5e-6},
            {'params': self.fusion_model.parameters(), 'lr': 1e-4}
        ], weight_decay=1e-4)

        # Start below any reachable F1 so the first epoch always produces a
        # checkpoint, even if validation F1 is exactly 0.
        best_val_f1 = float("-inf")
        patience_counter = 0

        for epoch in range(epochs):
            # Set training modes (validation switches everything to eval).
            self.gat_text_model.train()
            self.mlp_text_model.train()
            self.fusion_model.train()
            self.spatial_model.eval()
            self.temporal_model.eval()
            self.behavioral_model.eval()

            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"Stage 4 - Epoch {epoch+1}/{epochs}"):
                batch = self._to_device(batch)

                optimizer.zero_grad()

                # Both text models with gradients, other modalities frozen.
                features = self._extract_features(batch, text_grad=True)

                fusion_preds = self.fusion_model(*features)
                loss = self.compute_loss(fusion_preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            val_f1 = self._validate_fusion()
            print(f"Stage 4 - Epoch {epoch+1}: Loss: {total_loss/len(self.train_loader):.4f}, Val F1: {val_f1:.4f}")

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                patience_counter = 0
                torch.save({
                    'gat_text_model': self.gat_text_model.state_dict(),
                    'mlp_text_model': self.mlp_text_model.state_dict(),
                    'spatial_model': self.spatial_model.state_dict(),
                    'temporal_model': self.temporal_model.state_dict(),
                    'behavioral_model': self.behavioral_model.state_dict(),
                    'fusion_model': self.fusion_model.state_dict()
                }, best_model_path)
                print(f"🏆 New best model saved with F1: {val_f1:.4f}")
            else:
                patience_counter += 1
                if patience_counter >= 2:
                    print("🛑 Early stopping triggered.")
                    break

        print("✅ Stage 4 Complete. End-to-end training finished.")

    def _validate_fusion(self):
        """Validate fusion model with 5 modalities

        Puts all six models in eval mode and returns the macro-averaged F1 of
        the fusion predictions (sigmoid > 0.5) over the validation loader.
        Callers must restore train mode themselves afterwards.
        """
        for model in self._modality_models():
            model.eval()
        self.fusion_model.eval()

        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in self.val_loader:
                batch = self._to_device(batch)

                features = self._extract_features(batch, text_grad=False)
                fusion_preds = self.fusion_model(*features)
                pred_stack = torch.cat([fusion_preds[f] for f in moral_foundations], dim=1)

                all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
                all_labels.append(batch['targets'].cpu().numpy())

        y_pred = np.vstack(all_preds)
        y_true = np.vstack(all_labels)
        return f1_score(y_true, y_pred, average='macro', zero_division=0)

# ===================================================================
# 7. Enhanced Evaluation Functions
# ===================================================================
def evaluate_individual_models(test_loader, device, feature_info, emfd_csv_path):
    """Evaluate all individual models including both text models

    Rebuilds the five modality models, loads their stage-1/stage-2 weights
    from the working directory and scores each one on ``test_loader``.

    Args:
        test_loader: loader yielding the dict batches made by create_dataloader.
        device: device the models and batch tensors are placed on.
        feature_info: dict providing the spatial/temporal/behavioral input
            widths and the temporal sequence length.
        emfd_csv_path: path handed to the GAT text model constructor.

    Returns:
        dict mapping each model's display name to ``{'macro_f1': float,
        'foundation_f1s': list}``, the list following ``moral_foundations``.
    """
    print("\n" + "="*20 + " INDIVIDUAL MODEL EVALUATION " + "="*20)

    # Load models
    gat_text_model = TextModelGATeMFD(emfd_csv_path).to(device)
    mlp_text_model = TextModelRoBERTaFrozen().to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim']).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim']).to(device)

    def load_weights(model, path):
        # map_location lets checkpoints saved on a GPU load on a CPU-only
        # machine (and the other way round).
        model.load_state_dict(torch.load(path, map_location=device))

    # Load weights
    load_weights(gat_text_model, "gat_text_stage1.pth")
    load_weights(mlp_text_model, "mlp_text_stage1.pth")
    load_weights(spatial_model, "spatial_model_stage2.pth")
    load_weights(temporal_model, "temporal_model_stage2.pth")
    load_weights(behavioral_model, "behavioral_model_stage2.pth")

    # (model, display name, batch key holding that model's input)
    models = [
        (gat_text_model, 'GAT-eMFD Text', 'text_data'),
        (mlp_text_model, 'MLP-RoBERTa Text', 'text_data'),
        (spatial_model, 'Spatial', 'spatial_features'),
        (temporal_model, 'Temporal', 'temporal_sequences'),
        (behavioral_model, 'Behavioral', 'behavioral_features'),
    ]

    results = {}

    for model, name, input_key in models:
        model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc=f"Testing {name}"):
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batch.items()}

                preds = model(batch[input_key])

                pred_stack = torch.cat([preds[f] for f in moral_foundations], dim=1)
                # Threshold the sigmoid of the logits at 0.5.
                all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
                all_labels.append(batch['targets'].cpu().numpy())

        y_pred = np.vstack(all_preds)
        y_true = np.vstack(all_labels)
        macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

        print(f"\n--- {name} Model Results ---")
        print(f"Macro F1-Score: {macro_f1:.4f}")

        # Foundation-specific results
        foundation_f1s = [
            f1_score(y_true[:, i], y_pred[:, i], zero_division=0)
            for i in range(len(moral_foundations))
        ]

        results[name] = {
            'macro_f1': macro_f1,
            'foundation_f1s': foundation_f1s
        }

    return results

def evaluate_fusion_model(model_path, test_loader, device, feature_info, emfd_csv_path):
    """Evaluate the final fusion model with 5 modalities.

    Rebuilds the five modality encoders and the fusion network, restores their
    weights from the checkpoint at ``model_path``, runs them over
    ``test_loader`` without gradients, and prints a per-foundation
    classification report followed by the macro F1.

    Args:
        model_path: Path of a checkpoint dict holding one state dict under
            each of the keys 'gat_text_model', 'mlp_text_model',
            'spatial_model', 'temporal_model', 'behavioral_model' and
            'fusion_model'.
        test_loader: Iterable of batch dicts with the keys 'text_data',
            'spatial_features', 'temporal_sequences', 'behavioral_features'
            and 'targets'.
        device: Device the models and tensor batch entries are moved to.
        feature_info: Dict with 'spatial_feature_dim', 'temporal_feature_dim',
            'temporal_sequence_length' and 'behavioral_feature_dim'.
        emfd_csv_path: Path of the eMFD word list passed to the GAT text model.

    Returns:
        Macro-averaged F1 over all moral foundations on the test set.
    """
    print("\n" + "="*20 + " FINAL FUSION MODEL EVALUATION (5 MODALITIES) " + "="*20)

    # Rebuild every network; the keys match the entries of the saved checkpoint.
    networks = {
        'gat_text_model': TextModelGATeMFD(emfd_csv_path).to(device),
        'mlp_text_model': TextModelRoBERTaFrozen().to(device),
        'spatial_model': SpatialModel(feature_info['spatial_feature_dim']).to(device),
        'temporal_model': TemporalModelLSTM(
            feature_info['temporal_feature_dim'],
            feature_info['temporal_sequence_length']
        ).to(device),
        'behavioral_model': BehavioralModel(feature_info['behavioral_feature_dim']).to(device),
        'fusion_model': HeterogeneousFusion().to(device),
    }

    # map_location lets a checkpoint written on a GPU be restored on whatever
    # device this evaluation runs on (for example a CPU-only session).
    checkpoint = torch.load(model_path, map_location=device)
    for key, network in networks.items():
        network.load_state_dict(checkpoint[key])
        network.eval()  # disable dropout for deterministic inference

    gat_text_model = networks['gat_text_model']
    mlp_text_model = networks['mlp_text_model']
    spatial_model = networks['spatial_model']
    temporal_model = networks['temporal_model']
    behavioral_model = networks['behavioral_model']
    fusion_model = networks['fusion_model']

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing 5-Modality Fusion Model"):
            # Only tensors are moved; non-tensor entries (e.g. raw text) stay as they are.
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}

            # Per-modality feature vectors
            gat_text_feat = gat_text_model.get_features(batch['text_data'])
            mlp_text_feat = mlp_text_model.get_features(batch['text_data'])
            spatial_feat = spatial_model.get_features(batch['spatial_features'])
            temporal_feat = temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = behavioral_model.get_features(batch['behavioral_features'])

            # Fusion returns one logit column per foundation; stack them in
            # moral_foundations order so columns line up with the targets.
            fusion_preds = fusion_model(gat_text_feat, mlp_text_feat, spatial_feat, temporal_feat, behavioral_feat)
            pred_stack = torch.cat([fusion_preds[f] for f in moral_foundations], dim=1)

            # Logits -> probabilities -> hard multi-label decisions at 0.5
            all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
            all_labels.append(batch['targets'].cpu().numpy())

    y_pred = np.vstack(all_preds)
    y_true = np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print(classification_report(y_true, y_pred, target_names=moral_foundations, zero_division=0))
    print(f"** 5-MODALITY FUSION MODEL Macro F1-Score: {macro_f1:.4f} **")

    return macro_f1

def print_comprehensive_results(individual_results, fusion_f1):
    """Print comprehensive comparison results for 5 modalities.

    Args:
        individual_results: Dict mapping a model name to a dict with
            'macro_f1' (float) and 'foundation_f1s' (sequence of per-foundation
            F1 values indexed in ``moral_foundations`` order).
        fusion_f1: Macro F1 of the fusion model.

    The comparison table prints ``N/A`` for any of the five expected models
    that is missing from ``individual_results`` so the remaining columns stay
    under their headers.
    """
    print("\n" + "="*80)
    print("COMPREHENSIVE 5-MODALITY EVALUATION RESULTS")
    print("=" * 80)

    # Individual model results
    for model_name, model_results in individual_results.items():
        print(f"\n{model_name} MODEL:")
        print("-" * 50)
        print(f"  Macro F1-Score: {model_results['macro_f1']:.4f}")
        for i, foundation in enumerate(moral_foundations):
            print(f"  {foundation} F1: {model_results['foundation_f1s'][i]:.4f}")

    print("\n5-MODALITY FUSION MODEL:")
    print("-" * 50)
    print(f"  Macro F1-Score: {fusion_f1:.4f}")

    # Foundation-specific comparison
    print("\n" + "="*80)
    print("FOUNDATION-SPECIFIC F1 COMPARISON:")
    print("=" * 80)
    header = "Foundation     GAT-Text  MLP-Text  Spatial   Temporal  Behavioral  Fusion"
    print(header)
    print("-" * 80)

    # (result key, column width); the widths mirror the spacing of `header`
    # above, whose first column is 15 characters wide.
    columns = (
        ('GAT-eMFD Text', 10),
        ('MLP-RoBERTa Text', 10),
        ('Spatial', 10),
        ('Temporal', 10),
        ('Behavioral', 12),
    )

    for i, foundation in enumerate(moral_foundations):
        cells = [f"{foundation:<15}"]
        for model_name, width in columns:
            if model_name in individual_results:
                # Named `value`, not `f1_score`, to avoid shadowing the sklearn
                # metric function of that name inside this function.
                value = individual_results[model_name]['foundation_f1s'][i]
                cells.append(f"{value:<{width}.3f}")
            else:
                # Placeholder keeps later columns aligned with the header.
                cells.append(f"{'N/A':<{width}}")
        # NOTE(review): only the fusion model's macro F1 is available here, so
        # the same value is repeated on every foundation row.
        cells.append(f"{fusion_f1:.3f}")
        print("".join(cells))

# ===================================================================
# 8. Main Execution
# ===================================================================
# Script entry point: prepares the dataset, runs the four training stages,
# then evaluates the individual modality models and the fused model on the
# test split and prints a side-by-side comparison.
if __name__ == "__main__":
    # File paths - UPDATE THESE
    CSV_PATH = '/content/augmented_tweets.csv'
    GEOJSON_PATH = '/content/county_data_boarders.json'
    CENTROIDS_PATH = '/content/county_centroids.json'
    EMFD_CSV_PATH = '/content/eMFD_wordlist.csv'

    # GATConv is expected to be falsy when torch_geometric could not be
    # imported; the GAT text model cannot run without it, so stop early.
    if not GATConv:
        print("❌ PyTorch Geometric not available. Please install it.")
        exit()

    # Prepare dataset (sliding windows of 10 time steps)
    print("🔄 Preparing dataset...")
    datasets, feature_info = prepare_dataset(
        csv_path=CSV_PATH,
        geojson_path=GEOJSON_PATH,
        county_centroids_path=CENTROIDS_PATH,
        seq_len=10
    )

    # Device setup: use the GPU when one is visible, otherwise fall back to CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize models (now 6 models: 2 text + 3 others + 1 fusion)
    # NOTE(review): the models are not moved to `device` here — presumably
    # MultiStageTrainer does that; confirm against the trainer.
    gat_text_model = TextModelGATeMFD(EMFD_CSV_PATH)
    mlp_text_model = TextModelRoBERTaFrozen()
    spatial_model = SpatialModel(feature_info['spatial_feature_dim'])
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    )
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim'])
    fusion_model = HeterogeneousFusion()  # Now handles 5 modalities

    # Order of this list: two text models, spatial, temporal, behavioral, fusion
    models = [gat_text_model, mlp_text_model, spatial_model, temporal_model, behavioral_model, fusion_model]

    # Initialize trainer
    trainer = MultiStageTrainer(
        models=models,
        datasets=datasets,
        feature_info=feature_info,
        device=device,
        emfd_csv_path=EMFD_CSV_PATH
    )

    # Execute multi-stage training
    print("\n🚀 Starting 5-Modality Multi-Stage Training Pipeline...")
    print("=" * 80)

    # Stages run in this fixed order; each uses its own epoch count and
    # learning rate. The stage-4 checkpoint path is the same file that
    # evaluate_fusion_model reads back below.
    trainer.train_stage1_text_models(epochs=5, lr=2e-5)
    trainer.train_stage2_other_modalities(epochs=10, lr=5e-4)
    trainer.train_stage3_fusion_integration(epochs=2, lr=1e-3)
    trainer.train_stage4_end_to_end_finetuning(epochs=2, best_model_path="best_5modality_fusion.pth")

    # Final evaluation
    print("\n🔍 Starting Comprehensive 5-Modality Evaluation...")
    print("=" * 80)

    # shuffle=False keeps the test batches in dataset order
    test_loader = create_dataloader(datasets['test'], batch_size=16, shuffle=False)

    # Evaluate individual models (including both text models)
    individual_results = evaluate_individual_models(test_loader, device, feature_info, EMFD_CSV_PATH)

    # Evaluate 5-modality fusion model
    fusion_f1 = evaluate_fusion_model("best_5modality_fusion.pth", test_loader, device, feature_info, EMFD_CSV_PATH)

    # Print comprehensive results
    print_comprehensive_results(individual_results, fusion_f1)

    print("\n✅ 5-modality training and comprehensive evaluation completed!")
    print("🏆 Best models saved as 'best_5modality_fusion.pth'")


🔄 Preparing dataset...
Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



🚀 Starting 5-Modality Multi-Stage Training Pipeline...

Training GAT-eMFD Text Model...


GAT Text - Epoch 1/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT Text model epoch 1, loss: 0.6706


GAT Text - Epoch 2/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT Text - Epoch 3/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT Text - Epoch 4/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT Text model epoch 4, loss: 0.4601


GAT Text - Epoch 5/5:   0%|          | 0/17 [00:00<?, ?it/s]

Training MLP-RoBERTa Text Model...


MLP Text - Epoch 1/5:   0%|          | 0/17 [00:00<?, ?it/s]

MLP Text model epoch 1, loss: 0.6792


MLP Text - Epoch 2/5:   0%|          | 0/17 [00:00<?, ?it/s]

MLP Text - Epoch 3/5:   0%|          | 0/17 [00:00<?, ?it/s]

MLP Text - Epoch 4/5:   0%|          | 0/17 [00:00<?, ?it/s]

MLP Text model epoch 4, loss: 0.4778


MLP Text - Epoch 5/5:   0%|          | 0/17 [00:00<?, ?it/s]

✅ Stage 1 Complete. Both text models saved.


--- Training Spatial Model ---


Spatial - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 1, loss: 0.5597


Spatial - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 4, loss: 0.3045


Spatial - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 7, loss: 0.3032


Spatial - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 10, loss: 0.2838

--- Training Temporal Model ---


Temporal - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 1, loss: 0.4454


Temporal - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 4, loss: 0.2675


Temporal - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 7, loss: 0.2616


Temporal - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 10, loss: 0.2453

--- Training Behavioral Model ---


Behavioral - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 1, loss: 0.4884


Behavioral - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 4, loss: 0.2493


Behavioral - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 7, loss: 0.2497


Behavioral - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 10, loss: 0.2511
✅ Stage 2 Complete. All modality models saved.



Stage 3 - Epoch 1/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 3 - Epoch 1: Loss: 0.3538, Val F1: 0.1424


Stage 3 - Epoch 2/2:   0%|          | 0/17 [00:00<?, ?it/s]

✅ Stage 3 Complete. Fusion model saved.



Stage 4 - Epoch 1/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 4 - Epoch 1: Loss: 0.2253, Val F1: 0.2210
🏆 New best model saved with F1: 0.2210


Stage 4 - Epoch 2/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 4 - Epoch 2: Loss: 0.2392, Val F1: 0.2231
🏆 New best model saved with F1: 0.2231
✅ Stage 4 Complete. End-to-end training finished.

🔍 Starting Comprehensive 5-Modality Evaluation...



Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Testing GAT-eMFD Text:   0%|          | 0/4 [00:00<?, ?it/s]


--- GAT-eMFD Text Model Results ---
Macro F1-Score: 0.1467


Testing MLP-RoBERTa Text:   0%|          | 0/4 [00:00<?, ?it/s]


--- MLP-RoBERTa Text Model Results ---
Macro F1-Score: 0.1467


Testing Spatial:   0%|          | 0/4 [00:00<?, ?it/s]


--- Spatial Model Results ---
Macro F1-Score: 0.1467


Testing Temporal:   0%|          | 0/4 [00:00<?, ?it/s]


--- Temporal Model Results ---
Macro F1-Score: 0.1389


Testing Behavioral:   0%|          | 0/4 [00:00<?, ?it/s]


--- Behavioral Model Results ---
Macro F1-Score: 0.2156



Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Testing 5-Modality Fusion Model:   0%|          | 0/4 [00:00<?, ?it/s]

              precision    recall  f1-score   support

        Care       0.81      0.95      0.88        44
    Fairness       0.00      0.00      0.00         0
     Loyalty       0.00      0.00      0.00         9
   Authority       0.00      0.00      0.00         4
      Purity       0.00      0.00      0.00         1
   Non_Moral       0.33      0.33      0.33         3

   micro avg       0.78      0.70      0.74        61
   macro avg       0.19      0.21      0.20        61
weighted avg       0.60      0.70      0.65        61
 samples avg       0.77      0.73      0.74        61

** 5-MODALITY FUSION MODEL Macro F1-Score: 0.2014 **

COMPREHENSIVE 5-MODALITY EVALUATION RESULTS

GAT-eMFD Text MODEL:
--------------------------------------------------
  Macro F1-Score: 0.1467
  Care F1: 0.8800
  Fairness F1: 0.0000
  Loyalty F1: 0.0000
  Authority F1: 0.0000
  Purity F1: 0.0000
  Non_Moral F1: 0.0000

MLP-RoBERTa Text MODEL:
--------------------------------------------------
  Ma

# GAT + MLP, no focal loss, proper fine-tuning, softmax (correct version)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, Dropout, Sigmoid, ModuleDict, LayerNorm, LSTM, GELU, MultiheadAttention, Softmax
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, classification_report
import pandas as pd
import numpy as np
import json
import math
import os
from tqdm.autonotebook import tqdm

# PyTorch Geometric imports for GAT
try:
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
except ImportError:
    print("PyTorch Geometric not found. Please install: pip install torch_geometric")
    GATConv = None
    Data = None

# Moral foundations. The order is significant: prepare_labels stacks the label
# columns in this order, and every model in this cell builds one prediction
# head per entry, keyed by these names.
moral_foundations = ['Care', 'Fairness', 'Loyalty', 'Authority', 'Purity', 'Non_Moral']

# ===================================================================
# 1. GAT-eMFD Text Model (from paste.txt)
# ===================================================================
class eMFDProcessor:
    """Look up eMFD moral-foundation probabilities for the words of a text.

    The word list is read once from a CSV with a ``word`` column and one
    ``<foundation>_p`` probability column per moral foundation. Each word maps
    to a float32 vector of length 6: the five foundation probabilities in the
    order care, fairness, loyalty, authority, purity, followed by a constant
    0.0 in the non-moral slot.
    """

    def __init__(self, emfd_csv_path):
        """Load the eMFD word list.

        Args:
            emfd_csv_path: Path to the eMFD CSV file.

        Raises:
            FileNotFoundError: If ``emfd_csv_path`` does not exist.
        """
        self.moral_foundations_emfd = ['care', 'fairness', 'loyalty', 'authority', 'purity', 'non-moral']
        if not os.path.exists(emfd_csv_path):
            raise FileNotFoundError(f"eMFD file not found at {emfd_csv_path}")
        df = pd.read_csv(emfd_csv_path)
        prob_cols = [f'{f}_p' for f in self.moral_foundations_emfd if f != 'non-moral']

        # Build the whole table in one vectorized pass instead of row by row.
        # A probability column absent from the CSV is filled with 0.0.
        probabilities = df.reindex(columns=prob_cols, fill_value=0.0).to_numpy(dtype=np.float32)
        non_moral = np.zeros((len(df), 1), dtype=np.float32)
        table = np.hstack([probabilities, non_moral])
        self.emfd_data = dict(zip(df['word'].tolist(), table))

    @staticmethod
    def _strip_edges(token):
        """Return ``token`` without leading/trailing non-alphanumeric characters."""
        start, end = 0, len(token)
        while start < end and not token[start].isalnum():
            start += 1
        while end > start and not token[end - 1].isalnum():
            end -= 1
        return token[start:end]

    def extract_moral_concepts(self, text):
        """Return the eMFD entries for the words of ``text``.

        The text is lower-cased and split on whitespace. A token is looked up
        as-is first; if that misses, it is retried with surrounding
        punctuation removed so that "harm," or "#care" still match.

        Args:
            text: Input text. A non-string value (e.g. a missing value read
                from a CSV) is treated as empty text.

        Returns:
            A list of ``{'word': str, 'probabilities': np.ndarray}`` dicts in
            order of appearance. When nothing matches, a single 'neutral'
            concept with an all-zero probability vector is returned, so the
            list is never empty.
        """
        tokens = text.lower().split() if isinstance(text, str) else []
        concepts = []
        for token in tokens:
            # Exact match first keeps lexicon entries that themselves contain
            # punctuation working exactly as before.
            word = token if token in self.emfd_data else self._strip_edges(token)
            if word in self.emfd_data:
                concepts.append({'word': word, 'probabilities': self.emfd_data[word]})
        return concepts if concepts else [{'word': 'neutral', 'probabilities': np.zeros(6, dtype=np.float32)}]

class MoralGraphConstructor:
    """Builds a fully connected graph from a list of moral-concept dicts."""

    def create_moral_graph(self, moral_concepts):
        """Create the graph for one text.

        Each node's feature row is the concept's 'probabilities' array
        followed by 256 zeros. Every ordered pair of distinct nodes is
        connected; a graph with a single node gets one self-loop instead.

        Args:
            moral_concepts: List of dicts, each with a 'probabilities' array.

        Returns:
            A ``Data`` object with ``x`` and ``edge_index``, or ``None`` when
            ``Data`` is not available.
        """
        zero_padding = np.zeros(256)
        feature_rows = []
        for concept in moral_concepts:
            feature_rows.append(np.concatenate([concept['probabilities'], zero_padding]))
        node_features = torch.FloatTensor(np.array(feature_rows))

        node_count = len(moral_concepts)
        if node_count != 1:
            # All ordered (source, target) pairs without self-loops, source-major.
            pairs = []
            for src in range(node_count):
                for dst in range(node_count):
                    if src != dst:
                        pairs.append([src, dst])
            edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
        else:
            # A lone node is linked to itself so the graph has an edge.
            edge_index = torch.tensor([[0], [0]], dtype=torch.long)

        if not Data:
            return None
        return Data(x=node_features, edge_index=edge_index)

class GATeMFDModule(nn.Module):
    """Graph-attention encoder that pools a moral-concept graph into one vector.

    Node features are the 6 eMFD probabilities followed by a 256-dim learned
    embedding selected by the node's position in the graph. A single GAT layer
    is applied and the nodes are combined by a learned softmax-weighted sum.
    """

    def __init__(self, input_dim=262, output_dim=256, num_heads=4, dropout=0.1):
        """
        Args:
            input_dim: Width of a node feature row (6 probabilities + 256).
            output_dim: Width of the GAT layer output.
            num_heads: Number of attention heads of the GAT layer.
            dropout: Dropout passed to the GAT layer.
        """
        super().__init__()
        self.concept_embeddings = nn.Embedding(1000, 256)
        # Falls back to an identity layer when torch_geometric is unavailable;
        # forward() then short-circuits and never calls it.
        self.gat_layer = GATConv(input_dim, output_dim, heads=num_heads, dropout=dropout, concat=False) if GATConv else nn.Identity()
        self.attention_pooling = nn.Linear(output_dim, 1)

    def forward(self, graph_data, device):
        """Encode one graph.

        Args:
            graph_data: Object exposing ``x`` (node features) and
                ``edge_index``.
            device: Device the computation runs on.

        Returns:
            A tensor with a single row: the attention-pooled node
            representation, or zeros of shape (1, 256) when GATConv is
            unavailable.
        """
        if not GATConv:
            return torch.zeros((1, 256)).to(device)
        x, edge_index = graph_data.x.to(device), graph_data.edge_index.to(device)

        # Position-based embedding slots. Clamping keeps graphs with more
        # nodes than embedding rows from raising an out-of-range index error;
        # the surplus nodes share the last slot.
        slots = torch.arange(x.size(0), device=device).clamp(
            max=self.concept_embeddings.num_embeddings - 1
        )
        # Build a new tensor instead of assigning into x[:, 6:]: when
        # graph_data.x already lives on `device`, `.to(device)` returns the
        # same tensor and an in-place write would modify the caller's graph.
        x = torch.cat([x[:, :6], self.concept_embeddings(slots)], dim=1)

        gat_output = F.elu(self.gat_layer(x, edge_index))
        # Softmax over the node axis, then a weighted sum of node outputs.
        attn_weights = F.softmax(self.attention_pooling(gat_output), dim=0)
        return torch.sum(attn_weights * gat_output, dim=0, keepdim=True)

class TextModelGATeMFD(nn.Module):
    """Text model that classifies moral foundations from eMFD concept graphs.

    For every text the eMFD words are extracted, turned into a graph and
    encoded by ``GATeMFDModule``; the pooled vector goes through an MLP and
    then one single-logit head per moral foundation.
    """

    def __init__(self, emfd_csv_path):
        """
        Args:
            emfd_csv_path: Path of the eMFD word list CSV.
        """
        super().__init__()
        self.emfd_processor = eMFDProcessor(emfd_csv_path)
        self.graph_constructor = MoralGraphConstructor()
        self.gat_module = GATeMFDModule()

        # Additional processing layers
        self.processor = Sequential(
            Linear(256, 512), LayerNorm(512), GELU(), Dropout(0.1),
            Linear(512, 256), LayerNorm(256), GELU(), Dropout(0.1)
        )

        # Prediction heads emit raw logits (no Sigmoid), since
        # BCEWithLogitsLoss applies the sigmoid itself.
        self.heads = ModuleDict({
            foundation: Sequential(Linear(256, 128), ReLU(), Dropout(0.2),
                                 Linear(128, 1))
            for foundation in moral_foundations
        })

    def get_features(self, texts):
        """Return the processed 256-dim feature row of every text.

        Args:
            texts: Iterable of strings.

        Returns:
            Tensor with one row per text, on the module's device.
        """
        device = next(self.parameters()).device
        batch_features = []

        for text in texts:
            moral_concepts = self.emfd_processor.extract_moral_concepts(text)
            graph_data = self.graph_constructor.create_moral_graph(moral_concepts)
            # Explicit None check: the constructor returns None when no graph
            # can be built, and the decision should not depend on how a graph
            # object evaluates in a boolean context.
            if graph_data is not None:
                gat_features = self.gat_module(graph_data, device)
            else:
                gat_features = torch.zeros((1, 256), device=device)
            batch_features.append(gat_features)

        features = torch.cat(batch_features, dim=0)
        return self.processor(features)

    def forward(self, texts):
        """Return ``{foundation: logits}`` with one logit column per text row."""
        # Reuse get_features so training and feature extraction share one
        # code path instead of two copies that can drift apart.
        processed_features = self.get_features(texts)
        return {f: self.heads[f](processed_features) for f in moral_foundations}

# ===================================================================
# 2. MLP Text Model (from paste-2.txt)
# ===================================================================
class TextModelRoBERTaFrozen(nn.Module):
    """Frozen ``roberta-base`` encoder with a trainable MLP and per-foundation heads.

    NOTE(review): only ``requires_grad`` is switched off for the encoder;
    calling ``.train()`` on this module still puts the encoder's dropout into
    training mode — confirm whether that is intended.
    """

    def __init__(self):
        super().__init__()
        # Pretrained tokenizer and encoder
        self.tokenizer = AutoTokenizer.from_pretrained('roberta-base')
        self.roberta = AutoModel.from_pretrained('roberta-base')

        # The encoder weights receive no gradient updates
        for weight in self.roberta.parameters():
            weight.requires_grad = False

        # Trainable layers mapping the 768-dim embedding to 256 dims
        self.processor = Sequential(
            Linear(768, 512),
            LayerNorm(512),
            GELU(),
            Dropout(0.1),
            Linear(512, 256),
            LayerNorm(256),
            GELU(),
            Dropout(0.1),
        )

        # One raw-logit head per foundation (no Sigmoid; BCEWithLogitsLoss
        # includes it)
        self.heads = ModuleDict()
        for name in moral_foundations:
            self.heads[name] = Sequential(
                Linear(256, 128), ReLU(), Dropout(0.2), Linear(128, 1)
            )

    def _encode(self, texts):
        """Tokenize ``texts`` and return the processed first-token embeddings."""
        target_device = next(self.parameters()).device
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        )
        encoded = {key: tensor.to(target_device) for key, tensor in encoded.items()}

        # The frozen encoder runs without building an autograd graph
        with torch.no_grad():
            encoder_out = self.roberta(**encoded)
            first_token = encoder_out.last_hidden_state[:, 0]  # CLS token

        return self.processor(first_token)

    def forward(self, texts):
        """Return ``{foundation: logits}`` for a batch of texts."""
        shared = self._encode(texts)
        return {name: self.heads[name](shared) for name in moral_foundations}

    def get_features(self, texts):
        """Return the processed 256-dim features for a batch of texts."""
        return self._encode(texts)

# ===================================================================
# 3. Other Modality Models
# ===================================================================
class SpatialModel(nn.Module):
    """MLP over spatial feature vectors with one logit head per foundation."""

    def __init__(self, input_dim):
        """
        Args:
            input_dim: Width of a spatial feature vector.
        """
        super().__init__()
        self.encoder = Sequential(
            Linear(input_dim, 128),
            ReLU(),
            Dropout(0.2),
            Linear(128, 256),
        )
        # Heads output raw logits (no Sigmoid; BCEWithLogitsLoss includes it)
        self.heads = ModuleDict()
        for name in moral_foundations:
            self.heads[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))

    def get_features(self, x):
        """Return the 256-dim encoding of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return ``{foundation: logits}`` for the batch ``x``."""
        encoded = self.encoder(x)
        return {name: self.heads[name](encoded) for name in moral_foundations}

class TemporalModelLSTM(nn.Module):
    """Bidirectional two-layer LSTM over feature sequences.

    Inputs are projected to ``hidden_size``, run through the LSTM
    (``batch_first=True``), and the output at the last time step is encoded to
    256 dims and fed to one logit head per foundation.
    """

    def __init__(self, input_dim, seq_len=10, hidden_size=128):
        """
        Args:
            input_dim: Number of features per time step.
            seq_len: Accepted for interface compatibility; not used by any
                layer of this class.
            hidden_size: Width of the projection and of each LSTM direction.
        """
        super().__init__()
        self.projection = Linear(input_dim, hidden_size)
        self.lstm = LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
            bidirectional=True,
        )
        # Both directions are concatenated, hence the doubled input width
        self.encoder = Sequential(
            Linear(2 * hidden_size, 256),
            LayerNorm(256),
            ReLU(),
            Dropout(0.2),
        )
        # Heads output raw logits (no Sigmoid; BCEWithLogitsLoss includes it)
        self.heads = ModuleDict()
        for name in moral_foundations:
            self.heads[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))

    def get_features(self, x):
        """Return the 256-dim encoding of the last LSTM time step."""
        step_outputs, _state = self.lstm(self.projection(x))
        last_step = step_outputs[:, -1]
        return self.encoder(last_step)

    def forward(self, x):
        """Return ``{foundation: logits}`` for the batch of sequences ``x``."""
        step_outputs, _state = self.lstm(self.projection(x))
        encoded = self.encoder(step_outputs[:, -1])
        return {name: self.heads[name](encoded) for name in moral_foundations}

class BehavioralModel(nn.Module):
    """Layer-normalized MLP over behavioral features with one logit head per foundation."""

    def __init__(self, input_dim):
        """
        Args:
            input_dim: Width of a behavioral feature vector.
        """
        super().__init__()
        self.encoder = Sequential(
            Linear(input_dim, 128),
            LayerNorm(128),
            ReLU(),
            Dropout(0.1),
            Linear(128, 256),
            LayerNorm(256),
            ReLU(),
            Dropout(0.1),
        )
        # Heads output raw logits (no Sigmoid; BCEWithLogitsLoss includes it)
        self.heads = ModuleDict()
        for name in moral_foundations:
            self.heads[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))

    def get_features(self, x):
        """Return the 256-dim encoding of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return ``{foundation: logits}`` for the batch ``x``."""
        encoded = self.encoder(x)
        return {name: self.heads[name](encoded) for name in moral_foundations}

# ===================================================================
# 4. Enhanced Heterogeneous Fusion (5 modalities)
# ===================================================================
class HeterogeneousFusion(nn.Module):
    """Attention-based fusion of five modality feature vectors.

    Pipeline: per-modality linear projection to ``d_model``; per-modality
    cross-attention with a learned query plus residual and LayerNorm;
    self-attention across the five modality tokens; softmax gating that
    weights the tokens into one vector; a final projection; and one logit head
    per moral foundation.
    """

    def __init__(self, input_dims=None, d_model=256, num_heads=8):
        """
        Args:
            input_dims: Optional iterable with the feature width of each
                modality, in the argument order of ``forward``. When omitted,
                all five modalities are taken to be 256 wide.
            d_model: Common width all modalities are projected to.
            num_heads: Number of heads of every attention layer.
        """
        super().__init__()
        self.d_model = d_model

        # Feature projection layers to align dimensions (5 modalities now)
        if input_dims:
            self.projections = nn.ModuleList([
                nn.Linear(dim, d_model) for dim in input_dims
            ])
        else:
            self.projections = nn.ModuleList([
                nn.Linear(256, d_model) for _ in range(5)  # GAT text, MLP text, spatial, temporal, behavioral
            ])

        # One learned query vector per modality
        self.modality_queries = nn.Parameter(torch.randn(5, d_model))

        self.cross_attention = nn.ModuleList([
            MultiheadAttention(d_model, num_heads, batch_first=True)
            for _ in range(5)
        ])

        self.self_attention = MultiheadAttention(d_model, num_heads, batch_first=True)
        # Norms 0-4 follow the cross-attention blocks, norm 5 the self-attention
        self.layer_norms = nn.ModuleList([LayerNorm(d_model) for _ in range(6)])

        # Gating network: flattened modality tokens -> 5 weights summing to 1
        self.modality_gate = nn.Sequential(
            Linear(d_model * 5, d_model), ReLU(), Dropout(0.1),
            Linear(d_model, 5), Softmax(dim=-1)
        )

        self.final_projection = Linear(d_model, d_model)

        # Heads output raw logits (no Sigmoid; BCEWithLogitsLoss includes it)
        self.heads = ModuleDict({
            f: Sequential(Linear(d_model, 128), ReLU(), Dropout(0.2),
                         Linear(128, 1))
            for f in moral_foundations
        })

    def forward(self, gat_text_feat, mlp_text_feat, spatial_feat, temporal_feat, behavioral_feat):
        """Fuse the five modality features and predict every foundation.

        Each argument is a 2-D feature tensor with one row per sample; all
        five must have the same number of rows.

        Returns:
            Dict mapping each moral foundation to a column of logits.
        """
        batch_size = gat_text_feat.size(0)

        # Project all features to same dimension
        features = [
            self.projections[i](feat)
            for i, feat in enumerate([gat_text_feat, mlp_text_feat, spatial_feat, temporal_feat, behavioral_feat])
        ]

        # Add sequence dimension for attention
        modality_feats = [f.unsqueeze(1) for f in features]

        # Cross-attention with residual connections.
        # NOTE(review): each modality supplies a single key/value token, so
        # the attention weights are trivially 1 and the learned query does not
        # influence the output — confirm this is the intended design.
        attended = []
        for i, (feat, attn, norm) in enumerate(zip(modality_feats, self.cross_attention, self.layer_norms[:5])):
            query = self.modality_queries[i:i+1].unsqueeze(0).expand(batch_size, -1, -1)
            att_out, _ = attn(query, feat, feat)
            attended.append(norm(att_out.squeeze(1) + features[i]))

        # Stack into (batch, 5, d_model) tokens for self-attention
        stacked = torch.stack(attended, dim=1)

        # Self-attention across modalities, with residual + norm
        self_att, _ = self.self_attention(stacked, stacked, stacked)
        fused = self.layer_norms[5](stacked + self_att)

        # Gates are computed from the flattened (concatenated) modality tokens.
        # reshape (not view) also works should `fused` ever be non-contiguous.
        gates = self.modality_gate(fused.reshape(batch_size, -1))

        # Gate-weighted sum over the modality axis
        gated = torch.sum(fused * gates.unsqueeze(-1), dim=1)

        # Final projection
        final_features = self.final_projection(gated)

        return {f: self.heads[f](final_features) for f in moral_foundations}

# ===================================================================
# 5. Data Processing Functions (same as paste.txt)
# ===================================================================
def prepare_labels(df):
    """Build the multi-label target tensor from virtue/vice annotation columns.

    A foundation is positive when either its virtue or its vice column equals
    1. 'Non_Moral' is positive exactly when none of the five foundations is.
    The vice column for Fairness is looked up as 'Cheating', then 'cheating',
    and is treated as all-negative when neither column exists.

    Args:
        df: DataFrame with the columns Care/Harm, Fairness, Loyalty/Betrayal,
            Authority/Subversion and Purity/Degradation.

    Returns:
        float32 tensor with one row per DataFrame row and one column per entry
        of ``moral_foundations``, in that order.
    """
    frame = df.reset_index(drop=True)

    # Resolve the optional vice column for Fairness
    if 'Cheating' in frame:
        cheating = frame['Cheating']
    elif 'cheating' in frame:
        cheating = frame['cheating']
    else:
        cheating = pd.Series([False] * len(frame), index=frame.index)

    # Foundation -> boolean Series (virtue OR vice)
    flags = {
        'Care': (frame['Care'] == 1) | (frame['Harm'] == 1),
        'Fairness': (frame['Fairness'] == 1) | (cheating == 1),
        'Loyalty': (frame['Loyalty'] == 1) | (frame['Betrayal'] == 1),
        'Authority': (frame['Authority'] == 1) | (frame['Subversion'] == 1),
        'Purity': (frame['Purity'] == 1) | (frame['Degradation'] == 1),
    }

    # Non_Moral = 1 when every moral foundation is False
    any_moral = pd.DataFrame(flags).any(axis=1)
    flags['Non_Moral'] = (~any_moral).astype(int)

    # Assemble the columns in the canonical foundation order
    columns = [flags[name].astype(int).values for name in moral_foundations]
    return torch.tensor(np.column_stack(columns), dtype=torch.float32)

# Per-time-step columns that make up one row of a temporal sequence.
_TEMPORAL_FEATURE_COLUMNS = (
    'month', 'day', 'year', 'after_05/26/2020',
    'female_pct', 'lesscollege_pct', 'cvap',
    'clf_unemploy_pct', 'lesshs_pct', 'median_hh_inc',
    'urm_pct', 'ruralurban_cc', 'demgov', 'repgov',
    'net_dem_president_votes', 'net_dem_gov_votes', 'is_blue',
    'deaths', 'cases', 'cases_per_capita',
    'deaths_per_capita', 'cases_per_capita_discrete', 'mask_score',
)

# Columns describing the tweet at the last time step of a window.
_BEHAVIORAL_FEATURE_COLUMNS = (
    'retweet_count', 'rt_discrete', 'for_sah', 'is_vivid', 'sentiment_score',
)


def create_temporal_sequences(df, seq_len=10):
    """Create sliding-window temporal sequences per county.

    Rows are sorted by GEOID and date and grouped by GEOID. Every group with
    at least ``seq_len`` rows yields one window per start position; the labels,
    text and behavioral features are taken from the window's last row.

    Feature matrices and labels are computed once per group and then sliced,
    rather than rebuilt row by row for every window.

    Args:
        df: DataFrame with GEOID, year/month/day, the temporal and behavioral
            feature columns, 'text', and the label columns read by
            ``prepare_labels``.
        seq_len: Number of consecutive rows per window (must be >= 1).

    Returns:
        Tuple ``(sequences, targets, other_features)`` of equal-length lists:
        ``sequences[i]`` is a ``seq_len`` x 23 nested list of floats,
        ``targets[i]`` the label tensor of the window's last row, and
        ``other_features[i]`` a dict with 'text', 'geoid' (zero-padded to five
        characters) and 'behavioral' (list of 5 floats).

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    df_sorted = df.sort_values(['GEOID', 'year', 'month', 'day'])
    sequences, targets, other_features = [], [], []

    temporal_cols = list(_TEMPORAL_FEATURE_COLUMNS)
    behavioral_cols = list(_BEHAVIORAL_FEATURE_COLUMNS)

    for geoid, group in df_sorted.groupby('GEOID'):
        n_rows = len(group)
        if n_rows < seq_len:
            continue

        # Loop-invariant per-group work, done once instead of once per window
        temporal = group[temporal_cols].to_numpy(dtype=np.float64)
        behavioral = group[behavioral_cols].to_numpy(dtype=np.float64)
        texts = group['text'].tolist()
        labels = prepare_labels(group)  # one row of labels per group row
        geoid_str = str(geoid).zfill(5)  # every row of the group shares this GEOID

        for start in range(n_rows - seq_len + 1):
            last = start + seq_len - 1
            sequences.append(temporal[start:start + seq_len].tolist())
            # clone() gives each target its own storage, independent of the
            # group-wide label tensor
            targets.append(labels[last].clone())
            other_features.append({
                'text': texts[last],
                'geoid': geoid_str,
                'behavioral': behavioral[last].tolist()
            })

    return sequences, targets, other_features

def prepare_dataset(csv_path, geojson_path, county_centroids_path, seq_len=10):
    """Load the raw files and build scaled train/validation/test splits.

    The CSV is turned into sliding-window samples by
    ``create_temporal_sequences``; each sample gets a 6-value spatial vector
    derived from its county centroid. Samples are split 70/15/15 with a fixed
    random seed, and the spatial, behavioral and temporal features are
    standardised with statistics fitted on the training split only, so no
    validation/test information leaks into training.

    Args:
        csv_path: Path of the tweet CSV read with ``pd.read_csv``.
        geojson_path: Path of a JSON file; only membership of the county id
            among its top-level keys is used.
        county_centroids_path: Path of a JSON file mapping county id to a
            centroid sequence.
        seq_len: Window length forwarded to ``create_temporal_sequences``.

    Returns:
        ``(datasets, feature_info)``. ``datasets`` maps ``'train'``,
        ``'validation'`` and ``'test'`` to dicts with the keys ``'text_data'``,
        ``'spatial_features'``, ``'temporal_sequences'``,
        ``'behavioral_features'`` and ``'moral_targets'``. ``feature_info``
        holds the feature dimensions and the sequence length.

    Raises:
        ValueError: If no temporal sequence could be built from the CSV.
    """
    # Load data
    df = pd.read_csv(csv_path)
    with open(geojson_path) as f:
        boundaries = json.load(f)
    with open(county_centroids_path) as f:
        centroids = json.load(f)

    # Create sequences
    sequences, targets, other_features = create_temporal_sequences(df, seq_len)
    if not sequences:
        raise ValueError(
            f"No temporal sequences of length {seq_len} could be built from {csv_path}"
        )

    texts = [other['text'] for other in other_features]
    behavioral = [other['behavioral'] for other in other_features]

    # Spatial features: normalised centroid plus four placeholder values.
    spatial = []
    for other in other_features:
        geoid = other['geoid']
        if geoid in boundaries and geoid in centroids:
            centroid = centroids[geoid]
            # NOTE(review): index 0 is treated as latitude and index 1 as
            # longitude, as before -- confirm against the centroid file.
            lat_norm = (centroid[0] - 24) / (49 - 24)
            # Longitudes span [-125, -66]; the range width is 125 - 66 = 59.
            # (Dividing by 66 - 125 flipped the sign of the feature.)
            lon_norm = (centroid[1] + 125) / (125 - 66)
            spatial.append([lat_norm, lon_norm, 0.5, 0.5, 0.5, 0.5])
        else:
            spatial.append([0.5, 0.5, 0.5, 0.5, 0.5, 0.5])  # Default

    # Convert to tensors
    spatial_tensor = torch.tensor(spatial, dtype=torch.float32)
    temporal_tensor = torch.tensor(sequences, dtype=torch.float32)
    behavioral_tensor = torch.tensor(behavioral, dtype=torch.float32)
    targets_tensor = torch.stack(targets)

    # Create splits before scaling so the scalers only ever see training rows.
    indices = list(range(len(texts)))
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

    def _scale_with_train_stats(all_rows, train_rows):
        """Fit a StandardScaler on ``train_rows`` and apply it to ``all_rows``."""
        scaler = StandardScaler().fit(train_rows.numpy())
        return torch.tensor(scaler.transform(all_rows.numpy()), dtype=torch.float32)

    spatial_tensor = _scale_with_train_stats(spatial_tensor, spatial_tensor[train_idx])
    behavioral_tensor = _scale_with_train_stats(behavioral_tensor, behavioral_tensor[train_idx])

    # Temporal features are scaled per feature column: flatten the
    # (sample, timestep) axes, scale, then restore the original shape.
    temp_shape = tuple(temporal_tensor.shape)
    n_temporal_features = temp_shape[2]
    temporal_tensor = _scale_with_train_stats(
        temporal_tensor.reshape(-1, n_temporal_features),
        temporal_tensor[train_idx].reshape(-1, n_temporal_features),
    ).reshape(temp_shape)

    datasets = {}
    for split, idx in [('train', train_idx), ('validation', val_idx), ('test', test_idx)]:
        datasets[split] = {
            'text_data': [texts[i] for i in idx],
            'spatial_features': spatial_tensor[idx],
            'temporal_sequences': temporal_tensor[idx],
            'behavioral_features': behavioral_tensor[idx],
            'moral_targets': targets_tensor[idx]
        }

    feature_info = {
        'spatial_feature_dim': spatial_tensor.shape[1],
        'temporal_feature_dim': temporal_tensor.shape[2],
        'temporal_sequence_length': seq_len,
        'behavioral_feature_dim': behavioral_tensor.shape[1]
    }

    return datasets, feature_info

def create_dataloader(dataset_dict, batch_size=16, shuffle=True):
    """Wrap one dataset split in a ``DataLoader`` that yields dict batches.

    Args:
        dataset_dict: Split dict with the keys ``'text_data'``,
            ``'spatial_features'``, ``'temporal_sequences'``,
            ``'behavioral_features'`` and ``'moral_targets'``; the number of
            samples is taken from ``len(dataset_dict['text_data'])``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles samples every epoch.

    Returns:
        A ``DataLoader`` whose batches are dicts with the keys ``'text_data'``
        (list), ``'spatial_features'``, ``'temporal_sequences'``,
        ``'behavioral_features'`` and ``'targets'`` (stacked tensors).
    """
    sample_fields = (
        'text_data',
        'spatial_features',
        'temporal_sequences',
        'behavioral_features',
        'moral_targets',
    )

    def collate_fn(batch):
        # Transpose the list of per-sample tuples into per-field tuples.
        texts, spatial, temporal, behavioral, targets = zip(*batch)
        return {
            'text_data': list(texts),
            'spatial_features': torch.stack(spatial),
            'temporal_sequences': torch.stack(temporal),
            'behavioral_features': torch.stack(behavioral),
            'targets': torch.stack(targets),
        }

    n_samples = len(dataset_dict['text_data'])
    samples = [
        tuple(dataset_dict[field][position] for field in sample_fields)
        for position in range(n_samples)
    ]

    return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

# ===================================================================
# 6. Enhanced Multi-Stage Training System (5 modalities)
# ===================================================================
class MultiStageTrainer:
    """Staged training driver for five modality encoders and a fusion model.

    The stages are meant to be run in order:

    1. ``train_stage1_text_models`` trains the two text models one at a time.
    2. ``train_stage2_other_modalities`` trains the spatial, temporal and
       behavioral models one at a time.
    3. ``train_stage3_fusion_integration`` trains only the fusion model on
       features extracted from the frozen encoders.
    4. ``train_stage4_end_to_end_finetuning`` fine-tunes everything jointly
       with early stopping on validation macro F1.

    Every model's prediction output is indexed by the names in the module-level
    ``moral_foundations`` and concatenated along dim 1 into logits for
    ``BCEWithLogitsLoss``; encoders additionally expose ``get_features``.
    """

    def __init__(self, models, datasets, feature_info, device, emfd_csv_path):
        """Move the models to ``device`` and build loaders and the loss.

        Args:
            models: Six models ordered as (GAT text, MLP text, spatial,
                temporal, behavioral, fusion).
            datasets: Mapping with ``'train'`` and ``'validation'`` splits in
                the format accepted by ``create_dataloader``.
            feature_info: Not used by the trainer; kept for call-site
                compatibility.
            device: Device the models and batches are moved to.
            emfd_csv_path: Not used by the trainer; kept for call-site
                compatibility.
        """
        (self.gat_text_model, self.mlp_text_model, self.spatial_model,
         self.temporal_model, self.behavioral_model, self.fusion_model) = models
        self.datasets = datasets
        self.device = device

        # Move models to device
        self.gat_text_model = self.gat_text_model.to(device)
        self.mlp_text_model = self.mlp_text_model.to(device)
        self.spatial_model = self.spatial_model.to(device)
        self.temporal_model = self.temporal_model.to(device)
        self.behavioral_model = self.behavioral_model.to(device)
        self.fusion_model = self.fusion_model.to(device)

        # Create data loaders
        self.train_loader = create_dataloader(datasets['train'], batch_size=16, shuffle=True)
        self.val_loader = create_dataloader(datasets['validation'], batch_size=16, shuffle=False)

        # Per-class positive weight = N_negative / N_positive on the training
        # split; the epsilon avoids division by zero for classes that have no
        # positive sample.
        train_targets = self.datasets['train']['moral_targets']
        pos_counts = train_targets.sum(dim=0)
        neg_counts = len(train_targets) - pos_counts
        pos_weight = neg_counts / (pos_counts + 1e-6)

        self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(self.device))

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------
    def _freeze(self, model):
        """Disable gradients for every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad_(False)

    def _unfreeze(self, model):
        """Enable gradients for every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad_(True)

    def _encoders(self):
        """Return the five modality encoders in fusion-argument order."""
        return (self.gat_text_model, self.mlp_text_model, self.spatial_model,
                self.temporal_model, self.behavioral_model)

    def _batch_to_device(self, batch):
        """Return ``batch`` with every tensor value moved to ``self.device``."""
        return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}

    def _extract_features(self, batch):
        """Run ``get_features`` of every encoder on its part of ``batch``.

        The returned tuple matches the positional order the fusion model is
        called with.
        """
        return (
            self.gat_text_model.get_features(batch['text_data']),
            self.mlp_text_model.get_features(batch['text_data']),
            self.spatial_model.get_features(batch['spatial_features']),
            self.temporal_model.get_features(batch['temporal_sequences']),
            self.behavioral_model.get_features(batch['behavioral_features']),
        )

    def _fit_single_model(self, model, input_key, label, epochs, lr):
        """Train one stand-alone model on ``batch[input_key]``.

        ``label`` is used verbatim in the progress-bar and log messages.
        """
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

        for epoch in range(epochs):
            model.train()
            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"{label} - Epoch {epoch+1}/{epochs}"):
                batch = self._batch_to_device(batch)

                optimizer.zero_grad()
                preds = model(batch[input_key])
                loss = self.compute_loss(preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if epoch % 3 == 0:
                print(f"{label} model epoch {epoch+1}, loss: {total_loss/len(self.train_loader):.4f}")

    def compute_loss(self, predictions, targets):
        """Compute the weighted BCE loss for moral foundation predictions.

        Args:
            predictions: Mapping indexed by each name in ``moral_foundations``
                whose values are concatenated along dim 1.
            targets: Target tensor passed to the criterion alongside the
                concatenated logits.
        """
        pred_stack = torch.cat([predictions[f] for f in moral_foundations], dim=1)
        return self.criterion(pred_stack, targets)

    # ------------------------------------------------------------------
    # Training stages
    # ------------------------------------------------------------------
    def train_stage1_text_models(self, epochs=15, lr=2e-5):
        """Stage 1: Train both GAT eMFD and MLP Text Models, one after the other.

        Saves ``gat_text_stage1.pth`` and ``mlp_text_stage1.pth``.
        """
        print("\n" + "="*20 + " Stage 1: Training Both Text Models " + "="*20)

        # Freeze other models
        self._freeze(self.spatial_model)
        self._freeze(self.temporal_model)
        self._freeze(self.behavioral_model)
        self._freeze(self.fusion_model)

        # Train GAT Text Model
        print("Training GAT-eMFD Text Model...")
        self._unfreeze(self.gat_text_model)
        self._freeze(self.mlp_text_model)
        self._fit_single_model(self.gat_text_model, 'text_data', "GAT Text", epochs, lr)

        # Train MLP Text Model
        print("Training MLP-RoBERTa Text Model...")
        self._freeze(self.gat_text_model)
        self._unfreeze(self.mlp_text_model)
        self._fit_single_model(self.mlp_text_model, 'text_data', "MLP Text", epochs, lr)

        # Save models
        torch.save(self.gat_text_model.state_dict(), "gat_text_stage1.pth")
        torch.save(self.mlp_text_model.state_dict(), "mlp_text_stage1.pth")
        print("✅ Stage 1 Complete. Both text models saved.")

    def train_stage2_other_modalities(self, epochs=10, lr=5e-4):
        """Stage 2: Train the spatial, temporal and behavioral models in turn.

        Each model is unfrozen, trained, saved as ``{name}_model_stage2.pth``
        and frozen again.
        """
        print("\n" + "="*20 + " Stage 2: Training Other Modality Models " + "="*20)

        # Freeze text and fusion models
        self._freeze(self.gat_text_model)
        self._freeze(self.mlp_text_model)
        self._freeze(self.fusion_model)

        # (model, name used in messages and file names, batch key it consumes)
        models_to_train = [
            (self.spatial_model, 'spatial', 'spatial_features'),
            (self.temporal_model, 'temporal', 'temporal_sequences'),
            (self.behavioral_model, 'behavioral', 'behavioral_features'),
        ]

        for model, name, input_key in models_to_train:
            print(f"\n--- Training {name.capitalize()} Model ---")
            self._unfreeze(model)
            self._fit_single_model(model, input_key, name.capitalize(), epochs, lr)

            # Save model
            torch.save(model.state_dict(), f"{name}_model_stage2.pth")
            self._freeze(model)

        print("✅ Stage 2 Complete. All modality models saved.")

    def train_stage3_fusion_integration(self, epochs=8, lr=1e-3):
        """Stage 3: Train only the fusion model on frozen encoder features.

        Saves ``fusion_model_stage3.pth``.
        """
        print("\n" + "="*15 + " Stage 3: Fusion Integration (5 modalities) " + "="*15)

        # Freeze all individual models, unfreeze fusion
        for encoder in self._encoders():
            self._freeze(encoder)
            # Frozen encoders act as fixed feature extractors. Stages 1-2 leave
            # them in train mode, so without this the first epoch would see
            # features computed in a different mode than later epochs (which
            # run after _validate_fusion has switched everything to eval).
            encoder.eval()
        self._unfreeze(self.fusion_model)

        optimizer = optim.AdamW(self.fusion_model.parameters(), lr=lr, weight_decay=1e-4)

        for epoch in range(epochs):
            # _validate_fusion leaves the fusion model in eval mode.
            self.fusion_model.train()
            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"Stage 3 - Epoch {epoch+1}/{epochs}"):
                batch = self._batch_to_device(batch)

                optimizer.zero_grad()

                # Extract features with no gradients for individual models
                with torch.no_grad():
                    features = self._extract_features(batch)

                # Fusion with gradients (5 modalities)
                fusion_preds = self.fusion_model(*features)
                loss = self.compute_loss(fusion_preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if epoch % 2 == 0:
                val_f1 = self._validate_fusion()
                print(f"Stage 3 - Epoch {epoch+1}: Loss: {total_loss/len(self.train_loader):.4f}, Val F1: {val_f1:.4f}")

        torch.save(self.fusion_model.state_dict(), "fusion_model_stage3.pth")
        print("✅ Stage 3 Complete. Fusion model saved.")

    def train_stage4_end_to_end_finetuning(self, epochs=5, best_model_path="best_fusion_model.pth"):
        """Stage 4: End-to-end fine-tuning of ALL models (GAT text, MLP text, spatial, temporal, behavioral, fusion).

        Validates after every epoch, writes the state dicts of all six models
        to ``best_model_path`` whenever validation macro F1 improves, and stops
        after two consecutive epochs without improvement.
        """
        print("\n" + "="*15 + " Stage 4: End-to-End Fine-tuning (ALL ENCODERS) " + "="*15)

        all_models = (*self._encoders(), self.fusion_model)

        # Unfreeze ALL models for end-to-end training
        for model in all_models:
            self._unfreeze(model)

        # Create optimizer with different learning rates for different model components
        optimizer = optim.AdamW([
            {'params': self.gat_text_model.parameters(), 'lr': 5e-6},
            {'params': self.mlp_text_model.parameters(), 'lr': 5e-6},
            {'params': self.spatial_model.parameters(), 'lr': 1e-5},
            {'params': self.temporal_model.parameters(), 'lr': 1e-5},
            {'params': self.behavioral_model.parameters(), 'lr': 1e-5},
            {'params': self.fusion_model.parameters(), 'lr': 1e-4}
        ], weight_decay=1e-4)

        # Start below any reachable F1 so the first validated epoch always
        # writes a checkpoint; starting at 0.0 left no file behind when the
        # validation F1 never rose above zero.
        best_val_f1 = float('-inf')
        patience_counter = 0

        for epoch in range(epochs):
            # Set all models to training mode (validation switched them to eval)
            for model in all_models:
                model.train()

            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"Stage 4 (ALL ENCODERS) - Epoch {epoch+1}/{epochs}"):
                batch = self._batch_to_device(batch)

                optimizer.zero_grad()

                # Extract features from ALL models with gradients enabled
                features = self._extract_features(batch)

                # Fusion with all gradients flowing
                fusion_preds = self.fusion_model(*features)
                loss = self.compute_loss(fusion_preds, batch['targets'])
                loss.backward()

                # Gradient clipping to prevent exploding gradients; each model
                # is clipped to its own norm budget.
                for model in all_models:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()
                total_loss += loss.item()

            val_f1 = self._validate_fusion()
            print(f"Stage 4 (ALL ENCODERS) - Epoch {epoch+1}: Loss: {total_loss/len(self.train_loader):.4f}, Val F1: {val_f1:.4f}")

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                patience_counter = 0
                torch.save({
                    'gat_text_model': self.gat_text_model.state_dict(),
                    'mlp_text_model': self.mlp_text_model.state_dict(),
                    'spatial_model': self.spatial_model.state_dict(),
                    'temporal_model': self.temporal_model.state_dict(),
                    'behavioral_model': self.behavioral_model.state_dict(),
                    'fusion_model': self.fusion_model.state_dict()
                }, best_model_path)
                print(f"🏆 New best model saved with F1: {val_f1:.4f} (ALL ENCODERS fine-tuned)")
            else:
                patience_counter += 1
                if patience_counter >= 2:
                    print("🛑 Early stopping triggered.")
                    break

        print("✅ Stage 4 Complete. End-to-end training of ALL ENCODERS finished.")

    def _validate_fusion(self):
        """Validate fusion model with 5 modalities - ALL models in eval mode.

        Returns:
            Macro-averaged F1 over the validation loader, using a 0.5
            threshold on the sigmoid of the fusion logits.

        Note:
            All six models are left in eval mode; callers that continue
            training must switch the relevant models back to train mode.
        """
        for model in (*self._encoders(), self.fusion_model):
            model.eval()

        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in self.val_loader:
                batch = self._batch_to_device(batch)

                fusion_preds = self.fusion_model(*self._extract_features(batch))
                pred_stack = torch.cat([fusion_preds[f] for f in moral_foundations], dim=1)

                all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
                all_labels.append(batch['targets'].cpu().numpy())

        y_pred = np.vstack(all_preds)
        y_true = np.vstack(all_labels)
        return f1_score(y_true, y_pred, average='macro', zero_division=0)



def train_stage4_end_to_end_finetuning(self, epochs=5, best_model_path="best_fusion_model.pth"):
    """Stage 4: End-to-end fine-tuning of ALL models (GAT text, MLP text, spatial, temporal, behavioral, fusion).

    Module-level function that takes the trainer object explicitly as
    ``self``; it mirrors the ``MultiStageTrainer`` method of the same name.

    Args:
        self: Trainer-like object providing the six ``*_model`` attributes,
            ``train_loader``, ``device``, ``_unfreeze``, ``compute_loss`` and
            ``_validate_fusion``.
        epochs: Maximum number of fine-tuning epochs.
        best_model_path: File the best checkpoint (state dicts of all six
            models, keyed by attribute name) is written to.

    Training stops early after two consecutive epochs without an improvement
    in validation macro F1.
    """
    print("\n" + "="*15 + " Stage 4: End-to-End Fine-tuning (ALL ENCODERS) " + "="*15)

    # Checkpoint key -> (module, learning rate). Insertion order is the order
    # used for unfreezing, optimizer param groups and gradient clipping.
    components = {
        'gat_text_model': (self.gat_text_model, 5e-6),
        'mlp_text_model': (self.mlp_text_model, 5e-6),
        'spatial_model': (self.spatial_model, 1e-5),
        'temporal_model': (self.temporal_model, 1e-5),
        'behavioral_model': (self.behavioral_model, 1e-5),
        'fusion_model': (self.fusion_model, 1e-4),
    }

    # Unfreeze ALL models for end-to-end training
    for module, _ in components.values():
        self._unfreeze(module)

    # Create optimizer with different learning rates for different model components
    optimizer = optim.AdamW(
        [{'params': module.parameters(), 'lr': lr} for module, lr in components.values()],
        weight_decay=1e-4,
    )

    # Start below any reachable F1 so the first validated epoch always writes
    # a checkpoint; starting at 0.0 left no file behind when the validation F1
    # never rose above zero.
    best_val_f1 = float('-inf')
    patience_counter = 0

    for epoch in range(epochs):
        # Set all models to training mode (validation switched them to eval)
        for module, _ in components.values():
            module.train()

        total_loss = 0

        for batch in tqdm(self.train_loader, desc=f"Stage 4 (ALL ENCODERS) - Epoch {epoch+1}/{epochs}"):
            batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            optimizer.zero_grad()

            # Extract features from ALL models with gradients enabled
            gat_text_feat = self.gat_text_model.get_features(batch['text_data'])
            mlp_text_feat = self.mlp_text_model.get_features(batch['text_data'])
            spatial_feat = self.spatial_model.get_features(batch['spatial_features'])
            temporal_feat = self.temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = self.behavioral_model.get_features(batch['behavioral_features'])

            # Fusion with all gradients flowing
            fusion_preds = self.fusion_model(gat_text_feat, mlp_text_feat, spatial_feat, temporal_feat, behavioral_feat)
            loss = self.compute_loss(fusion_preds, batch['targets'])
            loss.backward()

            # Gradient clipping to prevent exploding gradients; each model is
            # clipped to its own norm budget.
            for module, _ in components.values():
                torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm=1.0)

            optimizer.step()
            total_loss += loss.item()

        val_f1 = self._validate_fusion()
        print(f"Stage 4 (ALL ENCODERS) - Epoch {epoch+1}: Loss: {total_loss/len(self.train_loader):.4f}, Val F1: {val_f1:.4f}")

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            patience_counter = 0
            torch.save(
                {key: module.state_dict() for key, (module, _) in components.items()},
                best_model_path,
            )
            print(f"🏆 New best model saved with F1: {val_f1:.4f} (ALL ENCODERS fine-tuned)")
        else:
            patience_counter += 1
            if patience_counter >= 2:
                print("🛑 Early stopping triggered.")
                break

    print("✅ Stage 4 Complete. End-to-end training of ALL ENCODERS finished.")
def _validate_fusion(self):
    """Score the fusion pipeline on the validation loader.

    Module-level function that takes the trainer object explicitly as
    ``self``. All six models are switched to eval mode (and left there), the
    validation batches are pushed through the encoders and the fusion model
    without gradients, and the sigmoid of the concatenated logits is
    thresholded at 0.5.

    Returns:
        Macro-averaged F1 score over all validation samples.
    """
    for module in (
        self.gat_text_model,
        self.mlp_text_model,
        self.spatial_model,
        self.temporal_model,
        self.behavioral_model,
        self.fusion_model,
    ):
        module.eval()

    predicted_chunks, label_chunks = [], []

    with torch.no_grad():
        for raw_batch in self.val_loader:
            # Only tensors are moved; the raw text list stays as it is.
            batch = {}
            for key, value in raw_batch.items():
                batch[key] = value.to(self.device) if isinstance(value, torch.Tensor) else value

            encoded = (
                self.gat_text_model.get_features(batch['text_data']),
                self.mlp_text_model.get_features(batch['text_data']),
                self.spatial_model.get_features(batch['spatial_features']),
                self.temporal_model.get_features(batch['temporal_sequences']),
                self.behavioral_model.get_features(batch['behavioral_features']),
            )
            outputs = self.fusion_model(*encoded)

            logits = torch.cat([outputs[name] for name in moral_foundations], dim=1)
            hard_preds = torch.sigmoid(logits) > 0.5

            predicted_chunks.append(hard_preds.cpu().numpy())
            label_chunks.append(batch['targets'].cpu().numpy())

    y_true = np.vstack(label_chunks)
    y_pred = np.vstack(predicted_chunks)
    return f1_score(y_true, y_pred, average='macro', zero_division=0)

# ===================================================================
# 7. Enhanced Evaluation Functions
# ===================================================================
def evaluate_individual_models(test_loader, device, feature_info, emfd_csv_path):
    """Evaluate all individual models including both text models.

    Rebuilds the five single-modality models, loads the weights saved by
    training stages 1 and 2 from the current working directory, and scores
    each model on ``test_loader`` with a 0.5 threshold on the sigmoid outputs.

    Args:
        test_loader: Iterable of batch dicts with the keys ``'text_data'``,
            ``'spatial_features'``, ``'temporal_sequences'``,
            ``'behavioral_features'`` and ``'targets'``.
        device: Device the models and batches are moved to.
        feature_info: Dict providing ``'spatial_feature_dim'``,
            ``'temporal_feature_dim'``, ``'temporal_sequence_length'`` and
            ``'behavioral_feature_dim'``.
        emfd_csv_path: Path passed to the ``TextModelGATeMFD`` constructor.

    Returns:
        Dict mapping each model's display name to
        ``{'macro_f1': float, 'foundation_f1s': [float, ...]}`` with one F1
        per entry of ``moral_foundations``.
    """
    print("\n" + "="*20 + " INDIVIDUAL MODEL EVALUATION " + "="*20)

    # Load models
    gat_text_model = TextModelGATeMFD(emfd_csv_path).to(device)
    mlp_text_model = TextModelRoBERTaFrozen().to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim']).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim']).to(device)

    # (model, display name, checkpoint file, batch key fed to the model)
    model_specs = [
        (gat_text_model, 'GAT-eMFD Text', "gat_text_stage1.pth", 'text_data'),
        (mlp_text_model, 'MLP-RoBERTa Text', "mlp_text_stage1.pth", 'text_data'),
        (spatial_model, 'Spatial', "spatial_model_stage2.pth", 'spatial_features'),
        (temporal_model, 'Temporal', "temporal_model_stage2.pth", 'temporal_sequences'),
        (behavioral_model, 'Behavioral', "behavioral_model_stage2.pth", 'behavioral_features'),
    ]

    # Load weights. map_location remaps tensors saved on another device (for
    # example a GPU checkpoint opened on a CPU-only machine) onto `device`.
    for model, _, weights_path, _ in model_specs:
        model.load_state_dict(torch.load(weights_path, map_location=device))

    results = {}

    for model, name, _, input_key in model_specs:
        model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc=f"Testing {name}"):
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batch.items()}

                preds = model(batch[input_key])

                pred_stack = torch.cat([preds[f] for f in moral_foundations], dim=1)
                all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
                all_labels.append(batch['targets'].cpu().numpy())

        y_pred = np.vstack(all_preds)
        y_true = np.vstack(all_labels)
        macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

        print(f"\n--- {name} Model Results ---")
        print(f"Macro F1-Score: {macro_f1:.4f}")

        # Foundation-specific results: one binary F1 per label column.
        foundation_f1s = [
            f1_score(y_true[:, i], y_pred[:, i], zero_division=0)
            for i in range(len(moral_foundations))
        ]

        results[name] = {
            'macro_f1': macro_f1,
            'foundation_f1s': foundation_f1s
        }

    return results

def evaluate_fusion_model(model_path, test_loader, device, feature_info, emfd_csv_path):
    """Evaluate the final fusion model with 5 modalities.

    Rebuilds the five encoders and the fusion model, restores their weights
    from the combined checkpoint at ``model_path``, and scores the fused
    predictions on ``test_loader`` with a 0.5 threshold on the sigmoid
    outputs. A per-foundation classification report is printed.

    Args:
        model_path: Checkpoint file holding a dict with the keys
            ``'gat_text_model'``, ``'mlp_text_model'``, ``'spatial_model'``,
            ``'temporal_model'``, ``'behavioral_model'`` and
            ``'fusion_model'``.
        test_loader: Iterable of batch dicts with the keys ``'text_data'``,
            ``'spatial_features'``, ``'temporal_sequences'``,
            ``'behavioral_features'`` and ``'targets'``.
        device: Device the models and batches are moved to.
        feature_info: Dict providing ``'spatial_feature_dim'``,
            ``'temporal_feature_dim'``, ``'temporal_sequence_length'`` and
            ``'behavioral_feature_dim'``.
        emfd_csv_path: Path passed to the ``TextModelGATeMFD`` constructor.

    Returns:
        Macro-averaged F1 score of the fusion model on the test data.
    """
    print("\n" + "="*20 + " FINAL FUSION MODEL EVALUATION (5 MODALITIES) " + "="*20)

    # Load models, keyed by the name each one has inside the checkpoint.
    models = {
        'gat_text_model': TextModelGATeMFD(emfd_csv_path).to(device),
        'mlp_text_model': TextModelRoBERTaFrozen().to(device),
        'spatial_model': SpatialModel(feature_info['spatial_feature_dim']).to(device),
        'temporal_model': TemporalModelLSTM(
            feature_info['temporal_feature_dim'],
            feature_info['temporal_sequence_length']
        ).to(device),
        'behavioral_model': BehavioralModel(feature_info['behavioral_feature_dim']).to(device),
        'fusion_model': HeterogeneousFusion().to(device),
    }

    # Load best model state. map_location remaps tensors saved on another
    # device (for example a GPU checkpoint opened on a CPU-only machine).
    checkpoint = torch.load(model_path, map_location=device)
    for key, model in models.items():
        model.load_state_dict(checkpoint[key])

    # Set to eval mode
    for model in models.values():
        model.eval()

    gat_text_model = models['gat_text_model']
    mlp_text_model = models['mlp_text_model']
    spatial_model = models['spatial_model']
    temporal_model = models['temporal_model']
    behavioral_model = models['behavioral_model']
    fusion_model = models['fusion_model']

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing 5-Modality Fusion Model"):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            # Extract features
            gat_text_feat = gat_text_model.get_features(batch['text_data'])
            mlp_text_feat = mlp_text_model.get_features(batch['text_data'])
            spatial_feat = spatial_model.get_features(batch['spatial_features'])
            temporal_feat = temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = behavioral_model.get_features(batch['behavioral_features'])

            # Fusion
            fusion_preds = fusion_model(gat_text_feat, mlp_text_feat, spatial_feat, temporal_feat, behavioral_feat)
            pred_stack = torch.cat([fusion_preds[f] for f in moral_foundations], dim=1)

            all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
            all_labels.append(batch['targets'].cpu().numpy())

    y_pred = np.vstack(all_preds)
    y_true = np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print(classification_report(y_true, y_pred, target_names=moral_foundations, zero_division=0))
    print(f"** 5-MODALITY FUSION MODEL Macro F1-Score: {macro_f1:.4f} **")

    return macro_f1

def print_comprehensive_results(individual_results, fusion_f1):
    """Print per-model and per-foundation F1 scores for the 5-modality setup.

    Args:
        individual_results: Dict mapping a model display name to a dict with
            ``'macro_f1'`` and ``'foundation_f1s'`` (a sequence indexed in
            ``moral_foundations`` order).
        fusion_f1: Macro F1-score of the fusion model.
    """
    rule = "=" * 80

    print("\n" + rule)
    print("COMPREHENSIVE 5-MODALITY EVALUATION RESULTS")
    print(rule)

    # Individual model results
    for model_name, results in individual_results.items():
        print(f"\n{model_name} MODEL:")
        print("-" * 50)
        print(f"  Macro F1-Score: {results['macro_f1']:.4f}")
        for i, foundation in enumerate(moral_foundations):
            print(f"  {foundation} F1: {results['foundation_f1s'][i]:.4f}")

    print("\n5-MODALITY FUSION MODEL:")
    print("-" * 50)
    print(f"  Macro F1-Score: {fusion_f1:.4f}")

    # Foundation-specific comparison table.
    # (column label, key in individual_results) in display order.
    columns = [
        ("GAT-Text", "GAT-eMFD Text"),
        ("MLP-Text", "MLP-RoBERTa Text"),
        ("Spatial", "Spatial"),
        ("Temporal", "Temporal"),
        ("Behavioral", "Behavioral"),
    ]
    # Header and rows share the same widths so the columns line up.
    name_width, cell_width = 12, 11

    print("\n" + rule)
    print("FOUNDATION-SPECIFIC F1 COMPARISON:")
    print(rule)
    header = "Foundation".ljust(name_width)
    header += "".join(label.ljust(cell_width) for label, _ in columns)
    # Only the macro score is available for the fusion model, so the last
    # column repeats it on every row; the label says so explicitly.
    header += "Fusion(macro)"
    print(header)
    print("-" * 80)

    for i, foundation in enumerate(moral_foundations):
        row = foundation.ljust(name_width)
        for _, model_name in columns:
            results = individual_results.get(model_name)
            # Keep a placeholder for a missing model so later columns do not
            # shift left under the wrong header.
            if results is None:
                cell = "N/A"
            else:
                cell = f"{results['foundation_f1s'][i]:.3f}"
            row += cell.ljust(cell_width)
        row += f"{fusion_f1:.3f}"
        print(row)

# ===================================================================
# 8. Main Execution
# ===================================================================
if __name__ == "__main__":
    # Input locations (Colab defaults) - update these for your environment.
    CSV_PATH = '/content/augmented_tweets.csv'
    GEOJSON_PATH = '/content/county_data_boarders.json'
    CENTROIDS_PATH = '/content/county_centroids.json'
    EMFD_CSV_PATH = '/content/eMFD_wordlist.csv'
    # One name for the checkpoint handed to stage 4 and read back for the
    # final fusion evaluation, so the two can never drift apart.
    BEST_MODEL_PATH = "best_5modality_fusion.pth"

    if not GATConv:
        print("❌ PyTorch Geometric not available. Please install it.")
        # SystemExit instead of the interactive `exit()` helper: it does not
        # depend on the `site` module and reports a non-zero status.
        raise SystemExit(1)

    # Prepare dataset
    print("🔄 Preparing dataset...")
    datasets, feature_info = prepare_dataset(
        csv_path=CSV_PATH,
        geojson_path=GEOJSON_PATH,
        county_centroids_path=CENTROIDS_PATH,
        seq_len=10
    )

    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize models: 2 text encoders + 3 other modalities + 1 fusion model
    gat_text_model = TextModelGATeMFD(EMFD_CSV_PATH)
    mlp_text_model = TextModelRoBERTaFrozen()
    spatial_model = SpatialModel(feature_info['spatial_feature_dim'])
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    )
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim'])
    fusion_model = HeterogeneousFusion()  # Handles 5 modalities

    # Order matters: MultiStageTrainer is given the models as this list.
    models = [gat_text_model, mlp_text_model, spatial_model, temporal_model, behavioral_model, fusion_model]

    # Initialize trainer
    trainer = MultiStageTrainer(
        models=models,
        datasets=datasets,
        feature_info=feature_info,
        device=device,
        emfd_csv_path=EMFD_CSV_PATH
    )

    # Execute multi-stage training
    print("\n🚀 Starting 5-Modality Multi-Stage Training Pipeline...")
    print("=" * 80)

    trainer.train_stage1_text_models(epochs=5, lr=2e-5)
    trainer.train_stage2_other_modalities(epochs=10, lr=5e-4)
    trainer.train_stage3_fusion_integration(epochs=2, lr=1e-3)
    trainer.train_stage4_end_to_end_finetuning(epochs=2, best_model_path=BEST_MODEL_PATH)

    # Final evaluation
    print("\n🔍 Starting Comprehensive 5-Modality Evaluation...")
    print("=" * 80)

    test_loader = create_dataloader(datasets['test'], batch_size=16, shuffle=False)

    # Evaluate individual models (including both text models)
    individual_results = evaluate_individual_models(test_loader, device, feature_info, EMFD_CSV_PATH)

    # Evaluate 5-modality fusion model
    fusion_f1 = evaluate_fusion_model(BEST_MODEL_PATH, test_loader, device, feature_info, EMFD_CSV_PATH)

    # Print comprehensive results
    print_comprehensive_results(individual_results, fusion_f1)

    print("\n✅ 5-modality training and comprehensive evaluation completed!")
    print(f"🏆 Best models saved as '{BEST_MODEL_PATH}'")


🔄 Preparing dataset...
Using device: cuda


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



🚀 Starting 5-Modality Multi-Stage Training Pipeline...

Training GAT-eMFD Text Model...


GAT Text - Epoch 1/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT Text model epoch 1, loss: 1.1199


GAT Text - Epoch 2/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT Text - Epoch 3/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT Text - Epoch 4/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT Text model epoch 4, loss: 1.1297


GAT Text - Epoch 5/5:   0%|          | 0/17 [00:00<?, ?it/s]

Training MLP-RoBERTa Text Model...


MLP Text - Epoch 1/5:   0%|          | 0/17 [00:00<?, ?it/s]

MLP Text model epoch 1, loss: 1.1198


MLP Text - Epoch 2/5:   0%|          | 0/17 [00:00<?, ?it/s]

MLP Text - Epoch 3/5:   0%|          | 0/17 [00:00<?, ?it/s]

MLP Text - Epoch 4/5:   0%|          | 0/17 [00:00<?, ?it/s]

MLP Text model epoch 4, loss: 1.0994


MLP Text - Epoch 5/5:   0%|          | 0/17 [00:00<?, ?it/s]

✅ Stage 1 Complete. Both text models saved.


--- Training Spatial Model ---


Spatial - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 1, loss: 1.1432


Spatial - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 4, loss: 1.0064


Spatial - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 7, loss: 0.9652


Spatial - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 10, loss: 0.9725

--- Training Temporal Model ---


Temporal - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 1, loss: 1.1793


Temporal - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 4, loss: 0.8522


Temporal - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 7, loss: 0.7401


Temporal - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 10, loss: 0.7228

--- Training Behavioral Model ---


Behavioral - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 1, loss: 1.0674


Behavioral - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 4, loss: 0.8291


Behavioral - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 7, loss: 0.7893


Behavioral - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 10, loss: 0.8114
✅ Stage 2 Complete. All modality models saved.



Stage 3 - Epoch 1/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 3 - Epoch 1: Loss: 1.1737, Val F1: 0.1974


Stage 3 - Epoch 2/2:   0%|          | 0/17 [00:00<?, ?it/s]

✅ Stage 3 Complete. Fusion model saved.



Stage 4 (ALL ENCODERS) - Epoch 1/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 4 (ALL ENCODERS) - Epoch 1: Loss: 0.5593, Val F1: 0.3481
🏆 New best model saved with F1: 0.3481 (ALL ENCODERS fine-tuned)


Stage 4 (ALL ENCODERS) - Epoch 2/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 4 (ALL ENCODERS) - Epoch 2: Loss: 0.5541, Val F1: 0.3437
✅ Stage 4 Complete. End-to-end training of ALL ENCODERS finished.

🔍 Starting Comprehensive 5-Modality Evaluation...



Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Testing GAT-eMFD Text:   0%|          | 0/4 [00:00<?, ?it/s]


--- GAT-eMFD Text Model Results ---
Macro F1-Score: 0.1097


Testing MLP-RoBERTa Text:   0%|          | 0/4 [00:00<?, ?it/s]


--- MLP-RoBERTa Text Model Results ---
Macro F1-Score: 0.1636


Testing Spatial:   0%|          | 0/4 [00:00<?, ?it/s]


--- Spatial Model Results ---
Macro F1-Score: 0.1858


Testing Temporal:   0%|          | 0/4 [00:00<?, ?it/s]


--- Temporal Model Results ---
Macro F1-Score: 0.1845


Testing Behavioral:   0%|          | 0/4 [00:00<?, ?it/s]


--- Behavioral Model Results ---
Macro F1-Score: 0.4021



Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Testing 5-Modality Fusion Model:   0%|          | 0/4 [00:00<?, ?it/s]

              precision    recall  f1-score   support

        Care       0.83      0.86      0.84        44
    Fairness       0.00      0.00      0.00         0
     Loyalty       0.16      0.56      0.24         9
   Authority       0.12      0.50      0.20         4
      Purity       0.00      0.00      0.00         1
   Non_Moral       0.20      0.33      0.25         3

   micro avg       0.44      0.75      0.56        61
   macro avg       0.22      0.38      0.26        61
weighted avg       0.64      0.75      0.67        61
 samples avg       0.48      0.76      0.56        61

** 5-MODALITY FUSION MODEL Macro F1-Score: 0.2564 **

COMPREHENSIVE 5-MODALITY EVALUATION RESULTS

GAT-eMFD Text MODEL:
--------------------------------------------------
  Macro F1-Score: 0.1097
  Care F1: 0.0000
  Fairness F1: 0.0000
  Loyalty F1: 0.2222
  Authority F1: 0.1111
  Purity F1: 0.2000
  Non_Moral F1: 0.1250

MLP-RoBERTa Text MODEL:
--------------------------------------------------
  Ma

# with focal loss

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, Dropout, Sigmoid, ModuleDict, LayerNorm, LSTM, GELU, MultiheadAttention, Softmax
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, classification_report
import pandas as pd
import numpy as np
import json
import math
import os
from tqdm.autonotebook import tqdm

# PyTorch Geometric imports for GAT
try:
    from torch_geometric.nn import GATConv
    from torch_geometric.data import Data
except ImportError:
    print("PyTorch Geometric not found. Please install: pip install torch_geometric")
    GATConv = None
    Data = None

# Moral foundation label names. The list order defines the column order of the
# target tensor built by prepare_labels and the key order used when the models
# build their per-foundation output dicts.
moral_foundations = ['Care', 'Fairness', 'Loyalty', 'Authority', 'Purity', 'Non_Moral']

# ===================================================================
# 1. GAT-eMFD Text Model (from paste.txt)
# ===================================================================
class eMFDProcessor:
    """Look up per-word moral-foundation probabilities from an eMFD word list.

    The CSV must have a ``word`` column; probability columns are named
    ``<foundation>_p`` for care, fairness, loyalty, authority and purity.
    A missing probability column is treated as 0.0. Every word maps to a
    float32 vector of length 6 whose last entry (non-moral) is always 0.0.

    Args:
        emfd_csv_path: Path to the eMFD word-list CSV.

    Raises:
        FileNotFoundError: If ``emfd_csv_path`` does not exist.
    """

    def __init__(self, emfd_csv_path):
        self.moral_foundations_emfd = ['care', 'fairness', 'loyalty', 'authority', 'purity', 'non-moral']
        if not os.path.exists(emfd_csv_path):
            raise FileNotFoundError(f"eMFD file not found at {emfd_csv_path}")
        df = pd.read_csv(emfd_csv_path)
        prob_cols = [f'{f}_p' for f in self.moral_foundations_emfd if f != 'non-moral']

        # Build the whole table at once instead of iterating rows; reindex
        # supplies a 0.0-filled column for any probability column not in the CSV.
        probs = df.reindex(columns=prob_cols, fill_value=0.0).to_numpy(dtype=np.float32)
        non_moral = np.zeros((len(df), 1), dtype=np.float32)
        vectors = np.hstack([probs, non_moral])
        # As with the row-wise build, a later duplicate word overrides an earlier one.
        self.emfd_data = dict(zip(df['word'], vectors))

    def extract_moral_concepts(self, text):
        """Return the eMFD entries for the words found in ``text``.

        Tokens are lower-cased and split on whitespace. A token that is not in
        the lexicon as-is is retried with surrounding punctuation stripped, so
        ``"care,"`` or ``"#harm"`` still match. Non-string input (e.g. a NaN
        produced by pandas for a missing text) is treated as empty text.

        Returns:
            A list of ``{'word', 'probabilities'}`` dicts in token order. If
            nothing matches, a single ``'neutral'`` concept with an all-zero
            vector, so the result is never empty.
        """
        import string  # local import: only needed for the punctuation set

        if not isinstance(text, str):
            text = ""

        lexicon = self.emfd_data
        punctuation = string.punctuation
        concepts = []
        for token in text.lower().split():
            word = token if token in lexicon else token.strip(punctuation)
            if word in lexicon:
                concepts.append({'word': word, 'probabilities': lexicon[word]})

        if not concepts:
            concepts = [{'word': 'neutral', 'probabilities': np.zeros(6, dtype=np.float32)}]
        return concepts

class MoralGraphConstructor:
    """Build a fully connected graph from a list of moral concepts."""

    def create_moral_graph(self, moral_concepts):
        """Return a ``Data`` graph with one node per concept.

        Each node feature is the concept's probability vector followed by 256
        zeros. A single concept gets one self-loop; otherwise every ordered
        pair of distinct nodes is connected. Returns ``None`` when ``Data``
        (PyTorch Geometric) is unavailable.
        """
        n_nodes = len(moral_concepts)

        # Placeholder slots appended after each probability vector.
        padding = np.zeros(256)
        rows = [
            np.concatenate([concept['probabilities'], padding])
            for concept in moral_concepts
        ]
        node_features = torch.FloatTensor(np.array(rows))

        if n_nodes == 1:
            # A lone node still needs an edge, so give it a self-loop.
            edge_index = torch.tensor([[0], [0]], dtype=torch.long)
        else:
            pairs = [
                [src, dst]
                for src in range(n_nodes)
                for dst in range(n_nodes)
                if src != dst
            ]
            # (num_edges, 2) -> (2, num_edges)
            edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()

        if not Data:
            return None
        return Data(x=node_features, edge_index=edge_index)

class GATeMFDModule(nn.Module):
    """Encode one moral-concept graph into a single pooled feature row.

    Args:
        input_dim: Node feature width fed to the GAT layer.
        output_dim: Width of the GAT output and of the pooled vector.
        num_heads: Number of GAT attention heads (averaged, ``concat=False``).
        dropout: Dropout passed to the GAT layer.
    """

    def __init__(self, input_dim=262, output_dim=256, num_heads=4, dropout=0.1):
        super().__init__()
        self.concept_embeddings = nn.Embedding(1000, 256)
        if GATConv:
            self.gat_layer = GATConv(input_dim, output_dim, heads=num_heads, dropout=dropout, concat=False)
        else:
            # Placeholder only; forward() returns zeros without calling it.
            self.gat_layer = nn.Identity()
        self.attention_pooling = nn.Linear(output_dim, 1)

    def forward(self, graph_data, device):
        """Return a ``(1, output_dim)`` tensor summarising ``graph_data``.

        Without PyTorch Geometric a ``(1, 256)`` zero tensor is returned.
        """
        if not GATConv:
            return torch.zeros((1, 256)).to(device)

        node_feats = graph_data.x.to(device)
        edges = graph_data.edge_index.to(device)

        # Columns 6+ are overwritten in place with embeddings looked up by node
        # position (0..n-1), not by word identity.
        # NOTE(review): `.to(device)` can return the same tensor, in which case
        # graph_data.x itself is modified - confirm that is acceptable.
        positions = torch.arange(node_feats.size(0)).long().to(device)
        node_feats[:, 6:] = self.concept_embeddings(positions)

        hidden = F.elu(self.gat_layer(node_feats, edges))

        # Softmax over the node axis gives one pooling weight per node.
        scores = self.attention_pooling(hidden)
        weights = F.softmax(scores, dim=0)
        return torch.sum(weights * hidden, dim=0, keepdim=True)

class TextModelGATeMFD(nn.Module):
    """Text classifier built on eMFD moral-concept graphs and a GAT encoder.

    Each text is turned into a moral-concept graph, pooled to one feature row
    by ``GATeMFDModule``, refined by an MLP, and scored by one logit head per
    entry of ``moral_foundations``.

    Args:
        emfd_csv_path: Path to the eMFD word-list CSV given to ``eMFDProcessor``.
    """

    def __init__(self, emfd_csv_path):
        super().__init__()
        self.emfd_processor = eMFDProcessor(emfd_csv_path)
        self.graph_constructor = MoralGraphConstructor()
        self.gat_module = GATeMFDModule()

        # Additional processing layers
        self.processor = Sequential(
            Linear(256, 512), LayerNorm(512), GELU(), Dropout(0.1),
            Linear(512, 256), LayerNorm(256), GELU(), Dropout(0.1)
        )

        # Heads emit raw logits (no Sigmoid); meant for a logits-based loss
        # such as BCEWithLogitsLoss.
        self.heads = ModuleDict({
            foundation: Sequential(Linear(256, 128), ReLU(), Dropout(0.2),
                                 Linear(128, 1))
            for foundation in moral_foundations
        })

    def forward(self, texts):
        """Return ``{foundation: (batch, 1) logits}`` for an iterable of texts."""
        # Single shared encoding path: forward is get_features plus the heads.
        processed_features = self.get_features(texts)
        return {f: self.heads[f](processed_features) for f in moral_foundations}

    def get_features(self, texts):
        """Return the processed ``(batch, 256)`` features for an iterable of texts."""
        device = next(self.parameters()).device
        batch_features = []

        # Graphs differ in size, so each text is encoded on its own.
        for text in texts:
            moral_concepts = self.emfd_processor.extract_moral_concepts(text)
            graph_data = self.graph_constructor.create_moral_graph(moral_concepts)
            # Explicit None check: the constructor returns None when PyTorch
            # Geometric is missing; a graph object's truthiness is not a
            # reliable signal.
            if graph_data is not None:
                gat_features = self.gat_module(graph_data, device)
            else:
                gat_features = torch.zeros((1, 256), device=device)
            batch_features.append(gat_features)

        features = torch.cat(batch_features, dim=0)
        return self.processor(features)
class FocalLoss(nn.Module):
    """
    Implementation of Focal Loss for multi-label classification.
    This loss was introduced in 'Focal Loss for Dense Object Detection'.

    Per element the loss is ``alpha_t * (1 - p_t) ** gamma * bce`` where
    ``p_t`` is the probability the model assigns to the true label and ``bce``
    is the (optionally ``pos_weight``-scaled) binary cross-entropy. The mean
    over all elements is returned.

    Args:
        alpha (float, optional): A balancing factor for classes. Defaults to 0.25.
        gamma (float, optional): A focusing parameter to down-weight easy examples. Defaults to 2.0.
        pos_weight (Tensor, optional): A weight of positive examples for each class. Defaults to None.
    """
    def __init__(self, alpha=0.25, gamma=2.0, pos_weight=None):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        # Use BCEWithLogitsLoss for numerical stability, with reduction='none' to apply focal term
        self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight)

    def forward(self, inputs, targets):
        """Return the mean focal loss of logits ``inputs`` against 0/1 ``targets``."""
        # Per-element BCE, scaled by pos_weight on positive targets if given.
        weighted_bce = self.bce_with_logits(inputs, targets)

        # p_t must come from the *unweighted* BCE: exp(-w * bce) equals p_t ** w,
        # which would distort the focusing term whenever pos_weight != 1.
        if self.bce_with_logits.pos_weight is None:
            plain_bce = weighted_bce
        else:
            plain_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_t = torch.exp(-plain_bce)

        # The alpha term balances positive and negative examples
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        focal_loss = alpha_t * (1 - p_t) ** self.gamma * weighted_bce

        return focal_loss.mean()
# ===================================================================
# 2. MLP Text Model (from paste-2.txt)
# ===================================================================
class TextModelRoBERTaFrozen(nn.Module):
    """Text classifier: frozen ``roberta-base`` encoder plus a trainable MLP.

    The embedding of the first token of each text is fed through ``processor``
    to get 256-d features, then through one logit head per entry of
    ``moral_foundations``. RoBERTa's parameters have ``requires_grad=False``
    and the encoder always runs under ``torch.no_grad()``.
    """

    def __init__(self):
        super().__init__()
        # Load RoBERTa model
        self.tokenizer = AutoTokenizer.from_pretrained('roberta-base')
        self.roberta = AutoModel.from_pretrained('roberta-base')

        # Freeze RoBERTa completely
        for param in self.roberta.parameters():
            param.requires_grad = False

        # Trainable processing layers
        self.processor = Sequential(
            Linear(768, 512), LayerNorm(512), GELU(), Dropout(0.1),
            Linear(512, 256), LayerNorm(256), GELU(), Dropout(0.1)
        )

        # Heads emit raw logits (no Sigmoid); meant for a logits-based loss
        # such as BCEWithLogitsLoss.
        self.heads = ModuleDict({
            foundation: Sequential(Linear(256, 128), ReLU(), Dropout(0.2),
                                 Linear(128, 1))
            for foundation in moral_foundations
        })

    def train(self, mode=True):
        """Set the train/eval mode of the trainable layers only.

        ``nn.Module.train`` switches every submodule, which would turn on
        dropout inside the frozen encoder and make its embeddings noisy during
        training but not during evaluation. The encoder is therefore put back
        into eval mode after every mode change.
        """
        super().train(mode)
        self.roberta.eval()
        return self

    def _encode(self, texts):
        """Tokenize ``texts`` and return the frozen first-token embeddings."""
        inputs = self.tokenizer(texts, padding=True, truncation=True,
                               max_length=512, return_tensors='pt')
        device = next(self.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # The encoder is frozen, so no graph is built for it.
        with torch.no_grad():
            outputs = self.roberta(**inputs)
        return outputs.last_hidden_state[:, 0]  # CLS token

    def forward(self, texts):
        """Return ``{foundation: (batch, 1) logits}`` for a list of texts."""
        features = self.get_features(texts)
        return {f: self.heads[f](features) for f in moral_foundations}

    def get_features(self, texts):
        """Return the processed ``(batch, 256)`` features for a list of texts."""
        return self.processor(self._encode(texts))

# ===================================================================
# 3. Other Modality Models
# ===================================================================
class SpatialModel(nn.Module):
    """MLP encoder over spatial features with one logit head per foundation.

    Args:
        input_dim: Width of the spatial feature vector.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.encoder = Sequential(
            Linear(input_dim, 128),
            ReLU(),
            Dropout(0.2),
            Linear(128, 256),
        )
        # Heads emit raw logits (no Sigmoid); meant for a logits-based loss
        # such as BCEWithLogitsLoss.
        heads = {}
        for name in moral_foundations:
            heads[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
        self.heads = ModuleDict(heads)

    def get_features(self, x):
        """Return the 256-d encoding of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return ``{foundation: logits}`` computed from the encoding of ``x``."""
        encoded = self.get_features(x)
        outputs = {}
        for name in moral_foundations:
            outputs[name] = self.heads[name](encoded)
        return outputs

class TemporalModelLSTM(nn.Module):
    """Bidirectional LSTM encoder over temporal sequences.

    Inputs are projected to ``hidden_size``, run through a 2-layer
    bidirectional LSTM, and the output at the last time step is encoded to 256
    features that feed one logit head per entry of ``moral_foundations``.

    Args:
        input_dim: Number of features per time step.
        seq_len: Accepted for interface compatibility; not used by the model.
        hidden_size: Width of the projection and of each LSTM direction.
    """

    def __init__(self, input_dim, seq_len=10, hidden_size=128):
        super().__init__()
        self.projection = Linear(input_dim, hidden_size)
        self.lstm = LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
            bidirectional=True,
        )
        # Both directions are concatenated, hence hidden_size * 2 inputs.
        self.encoder = Sequential(
            Linear(hidden_size * 2, 256),
            LayerNorm(256),
            ReLU(),
            Dropout(0.2),
        )
        # Heads emit raw logits (no Sigmoid); meant for a logits-based loss
        # such as BCEWithLogitsLoss.
        heads = {}
        for name in moral_foundations:
            heads[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
        self.heads = ModuleDict(heads)

    def get_features(self, x):
        """Return the 256-d encoding of a batch-first sequence tensor ``x``."""
        projected = self.projection(x)
        sequence_out, _ = self.lstm(projected)
        # NOTE(review): at the last time step the backward direction has only
        # seen one element; the final hidden states may be a better summary.
        last_step = sequence_out[:, -1]
        return self.encoder(last_step)

    def forward(self, x):
        """Return ``{foundation: logits}`` for a batch-first sequence tensor ``x``."""
        encoded = self.get_features(x)
        outputs = {}
        for name in moral_foundations:
            outputs[name] = self.heads[name](encoded)
        return outputs

class BehavioralModel(nn.Module):
    """Layer-normalised MLP encoder over behavioral features.

    Args:
        input_dim: Width of the behavioral feature vector.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.encoder = Sequential(
            Linear(input_dim, 128),
            LayerNorm(128),
            ReLU(),
            Dropout(0.1),
            Linear(128, 256),
            LayerNorm(256),
            ReLU(),
            Dropout(0.1),
        )
        # Heads emit raw logits (no Sigmoid); meant for a logits-based loss
        # such as BCEWithLogitsLoss.
        heads = {}
        for name in moral_foundations:
            heads[name] = Sequential(Linear(256, 64), ReLU(), Linear(64, 1))
        self.heads = ModuleDict(heads)

    def get_features(self, x):
        """Return the 256-d encoding of ``x``."""
        return self.encoder(x)

    def forward(self, x):
        """Return ``{foundation: logits}`` computed from the encoding of ``x``."""
        encoded = self.get_features(x)
        outputs = {}
        for name in moral_foundations:
            outputs[name] = self.heads[name](encoded)
        return outputs

# ===================================================================
# 4. Enhanced Heterogeneous Fusion (5 modalities)
# ===================================================================
class HeterogeneousFusion(nn.Module):
    """Attention-based fusion of five modality feature vectors.

    Each modality is projected to ``d_model``, passed through its own
    cross-attention block, mixed with the other modalities by self-attention,
    combined with softmax gate weights, and scored by one logit head per entry
    of ``moral_foundations``.

    Args:
        input_dims: Optional sequence of 5 input widths, in the order
            GAT text, MLP text, spatial, temporal, behavioral. When omitted
            (or empty) every modality is assumed to be 256 wide.
        d_model: Shared width used inside the fusion block.
        num_heads: Number of heads in every attention layer.

    Raises:
        ValueError: If ``input_dims`` is given but does not have 5 entries.
    """

    def __init__(self, input_dims=None, d_model=256, num_heads=8):
        super().__init__()
        # forward() takes exactly five feature tensors, so reject a mismatched
        # configuration here rather than failing later inside forward().
        if input_dims and len(input_dims) != 5:
            raise ValueError(
                f"input_dims must have exactly 5 entries, got {len(input_dims)}"
            )
        self.d_model = d_model

        # Feature projection layers to align dimensions (5 modalities)
        if input_dims:
            self.projections = nn.ModuleList([
                nn.Linear(dim, d_model) for dim in input_dims
            ])
        else:
            self.projections = nn.ModuleList([
                nn.Linear(256, d_model) for _ in range(5)  # GAT text, MLP text, spatial, temporal, behavioral
            ])

        self.modality_queries = nn.Parameter(torch.randn(5, d_model))

        self.cross_attention = nn.ModuleList([
            MultiheadAttention(d_model, num_heads, batch_first=True)
            for _ in range(5)
        ])

        self.self_attention = MultiheadAttention(d_model, num_heads, batch_first=True)
        # Norms 0-4 follow the per-modality cross-attention; norm 5 follows
        # the self-attention.
        self.layer_norms = nn.ModuleList([LayerNorm(d_model) for _ in range(6)])

        # Gate: all five fused modality vectors concatenated -> 5 softmax weights
        self.modality_gate = nn.Sequential(
            Linear(d_model * 5, d_model), ReLU(), Dropout(0.1),
            Linear(d_model, 5), Softmax(dim=-1)
        )

        self.final_projection = Linear(d_model, d_model)

        # Heads emit raw logits (no Sigmoid); meant for a logits-based loss
        # such as BCEWithLogitsLoss.
        self.heads = ModuleDict({
            f: Sequential(Linear(d_model, 128), ReLU(), Dropout(0.2),
                         Linear(128, 1))
            for f in moral_foundations
        })

    def forward(self, gat_text_feat, mlp_text_feat, spatial_feat, temporal_feat, behavioral_feat):
        """Fuse the five feature batches and return ``{foundation: logits}``.

        All inputs share the same leading batch dimension; their last
        dimension must match the corresponding projection's input width.
        """
        batch_size = gat_text_feat.size(0)

        # Project all features to same dimension
        features = [
            self.projections[i](feat)
            for i, feat in enumerate([gat_text_feat, mlp_text_feat, spatial_feat, temporal_feat, behavioral_feat])
        ]

        # Add sequence dimension for attention: each modality is one token
        modality_feats = [f.unsqueeze(1) for f in features]

        # Cross-attention with residual connections.
        # NOTE(review): each modality supplies a single key/value token, so
        # the attention softmax has only one entry and the learned query
        # presumably has no influence on att_out - confirm this is intended.
        attended = []
        for i, (feat, attn, norm) in enumerate(zip(modality_feats, self.cross_attention, self.layer_norms[:5])):
            query = self.modality_queries[i:i+1].unsqueeze(0).expand(batch_size, -1, -1)
            att_out, _ = attn(query, feat, feat)
            attended.append(norm(att_out.squeeze(1) + features[i]))

        # Stack for self-attention: (batch, 5, d_model)
        stacked = torch.stack(attended, dim=1)

        # Self-attention across modalities with a residual connection
        self_att, _ = self.self_attention(stacked, stacked, stacked)
        fused = self.layer_norms[5](stacked + self_att)

        # Gates are computed from the concatenation of all fused modality
        # vectors (not from a pooled average).
        gates = self.modality_gate(fused.reshape(batch_size, -1))

        # Gate-weighted sum over the modality axis
        gated = torch.sum(fused * gates.unsqueeze(-1), dim=1)

        # Final projection
        final_features = self.final_projection(gated)

        return {f: self.heads[f](final_features) for f in moral_foundations}

# ===================================================================
# 5. Data Processing Functions (same as paste.txt)
# ===================================================================
def prepare_labels(df):
    """Build the multi-label target tensor from virtue/vice annotation columns.

    A foundation is positive when either its virtue or its vice column equals
    1. ``Non_Moral`` is positive only when none of the five foundations is.
    The cheating column may be named ``Cheating`` or ``cheating``; when both
    are absent it is treated as all-False.

    Returns:
        A float32 tensor with one row per row of ``df`` and one column per
        entry of ``moral_foundations``.
    """
    df = df.reset_index(drop=True)

    def either(virtue, vice):
        # Logical OR of a virtue column and its vice counterpart.
        return (df[virtue] == 1) | (df[vice] == 1)

    no_cheating = pd.Series([False] * len(df), index=df.index)
    cheating = df.get('Cheating', df.get('cheating', no_cheating))

    labels = {
        'Care': either('Care', 'Harm'),
        'Fairness': (df['Fairness'] == 1) | (cheating == 1),
        'Loyalty': either('Loyalty', 'Betrayal'),
        'Authority': either('Authority', 'Subversion'),
        'Purity': either('Purity', 'Degradation'),
    }

    # Non_Moral = 1 exactly when every moral foundation above is False.
    moral_names = ('Care', 'Fairness', 'Loyalty', 'Authority', 'Purity')
    moral_df = pd.DataFrame({name: labels[name] for name in moral_names})
    labels['Non_Moral'] = (~moral_df.any(axis=1)).astype(int)

    # Columns follow the global moral_foundations order.
    columns = [labels[f].astype(int).values for f in moral_foundations]
    return torch.tensor(np.column_stack(columns), dtype=torch.float32)

def create_temporal_sequences(df, seq_len=10):
    """Create sliding-window temporal sequences per county (GEOID).

    Rows are sorted by GEOID and date. Every GEOID with at least ``seq_len``
    rows yields one window per start position. The labels, text and
    behavioral features of a window are those of its last row.

    Args:
        df: DataFrame holding the temporal, behavioral, label and ``text``
            columns referenced below.
        seq_len: Number of consecutive rows per window; must be >= 1.

    Returns:
        ``(sequences, targets, other_features)``: parallel lists holding, per
        window, a ``seq_len``-long list of feature rows, a label tensor, and a
        dict with ``'text'``, ``'geoid'`` (zero-padded to 5 characters) and
        ``'behavioral'``.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    temporal_columns = [
        'month', 'day', 'year', 'after_05/26/2020',
        'female_pct', 'lesscollege_pct', 'cvap',
        'clf_unemploy_pct', 'lesshs_pct', 'median_hh_inc',
        'urm_pct', 'ruralurban_cc', 'demgov', 'repgov',
        'net_dem_president_votes', 'net_dem_gov_votes', 'is_blue',
        'deaths', 'cases', 'cases_per_capita',
        'deaths_per_capita', 'cases_per_capita_discrete', 'mask_score'
    ]

    df_sorted = df.sort_values(['GEOID', 'year', 'month', 'day'])
    sequences, targets, other_features = [], [], []

    for _, group in df_sorted.groupby('GEOID'):
        n_rows = len(group)
        if n_rows < seq_len:
            continue

        # Extract features and labels once per group; overlapping windows then
        # reuse them instead of re-reading every row seq_len times.
        feature_rows = group[temporal_columns].values.tolist()
        group_labels = prepare_labels(group)

        for start in range(n_rows - seq_len + 1):
            end = start + seq_len
            # Copy the rows so windows do not share mutable lists.
            sequences.append([row[:] for row in feature_rows[start:end]])

            # Labels and side features come from the window's last timestep.
            targets.append(group_labels[end - 1])
            last_row = group.iloc[end - 1]
            other_features.append({
                'text': last_row['text'],
                'geoid': str(last_row['GEOID']).zfill(5),
                'behavioral': [last_row['retweet_count'], last_row['rt_discrete'],
                             last_row['for_sah'], last_row['is_vivid'], last_row['sentiment_score']]
            })

    return sequences, targets, other_features

def prepare_dataset(csv_path, geojson_path, county_centroids_path, seq_len=10):
    """Load the raw files and build train/validation/test feature dicts.

    Steps: build temporal windows, derive spatial features from county
    centroids, split 70/15/15 with a fixed seed, then standardise spatial,
    behavioral and temporal features. Scalers are fit on the training rows
    only so validation/test statistics do not leak into preprocessing.

    NOTE(review): windows from the same county overlap, and the split is
    random per window, so neighbouring windows can still land in different
    splits.

    Args:
        csv_path: CSV file read with pandas.
        geojson_path: JSON file of county boundaries.
        county_centroids_path: JSON file of county centroids.
        seq_len: Window length passed to ``create_temporal_sequences``.

    Returns:
        ``(datasets, feature_info)`` where ``datasets`` maps ``'train'``,
        ``'validation'`` and ``'test'`` to dicts of texts and tensors, and
        ``feature_info`` holds the feature dimensions and ``seq_len``.

    Raises:
        ValueError: If no temporal sequence could be built from the CSV.
    """
    # Load data
    df = pd.read_csv(csv_path)
    with open(geojson_path) as f:
        boundaries = json.load(f)
    with open(county_centroids_path) as f:
        centroids = json.load(f)

    # Create sequences
    sequences, targets, other_features = create_temporal_sequences(df, seq_len)
    if not sequences:
        # Fail with a clear message instead of an opaque torch.stack error.
        raise ValueError(
            f"No temporal sequences could be built: no GEOID in {csv_path} "
            f"has at least {seq_len} rows"
        )

    texts, spatial_rows, behavioral_rows = [], [], []
    for other in other_features:
        texts.append(other['text'])
        behavioral_rows.append(other['behavioral'])

        # Extract spatial features
        # presumably both JSON files are dicts keyed by 5-digit GEOID and a
        # centroid is [lat, lon] - TODO confirm against the data files.
        geoid = other['geoid']
        if geoid in boundaries and geoid in centroids:
            centroid = centroids[geoid]
            lat_norm = (centroid[0] - 24) / (49 - 24)  # lat 24..49 -> 0..1
            # lon -125..-66 -> 0..1 (the former divisor 66 - 125 flipped the sign)
            lon_norm = (centroid[1] + 125) / (125 - 66)
            spatial_rows.append([lat_norm, lon_norm, 0.5, 0.5, 0.5, 0.5])  # Simplified features
        else:
            spatial_rows.append([0.5, 0.5, 0.5, 0.5, 0.5, 0.5])  # Default

    # Convert to tensors
    spatial_tensor = torch.tensor(spatial_rows, dtype=torch.float32)
    temporal_tensor = torch.tensor(sequences, dtype=torch.float32)
    behavioral_tensor = torch.tensor(behavioral_rows, dtype=torch.float32)
    targets_tensor = torch.stack(targets)

    # Create splits first so the scalers can be fit on training rows only
    indices = list(range(len(texts)))
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

    def standardize(tensor):
        # Fit per-feature statistics on the training rows, apply to all rows.
        # Sequence tensors are flattened over time for scaling, then restored.
        n_features = tensor.shape[-1]
        scaler = StandardScaler()
        scaler.fit(tensor[train_idx].reshape(-1, n_features).numpy())
        scaled = scaler.transform(tensor.reshape(-1, n_features).numpy())
        return torch.tensor(scaled.reshape(tuple(tensor.shape)), dtype=torch.float32)

    spatial_tensor = standardize(spatial_tensor)
    behavioral_tensor = standardize(behavioral_tensor)
    temporal_tensor = standardize(temporal_tensor)

    datasets = {}
    for split, idx in [('train', train_idx), ('validation', val_idx), ('test', test_idx)]:
        datasets[split] = {
            'text_data': [texts[i] for i in idx],
            'spatial_features': spatial_tensor[idx],
            'temporal_sequences': temporal_tensor[idx],
            'behavioral_features': behavioral_tensor[idx],
            'moral_targets': targets_tensor[idx]
        }

    feature_info = {
        'spatial_feature_dim': spatial_tensor.shape[1],
        'temporal_feature_dim': temporal_tensor.shape[2],
        'temporal_sequence_length': seq_len,
        'behavioral_feature_dim': behavioral_tensor.shape[1]
    }

    return datasets, feature_info

def create_dataloader(dataset_dict, batch_size=16, shuffle=True):
    """Wrap one dataset split in a ``DataLoader`` that yields dict batches.

    Args:
        dataset_dict: Split dictionary with the keys ``'text_data'``,
            ``'spatial_features'``, ``'temporal_sequences'``,
            ``'behavioral_features'`` and ``'moral_targets'``. Every entry is
            indexed by sample position; the number of samples is taken from
            ``'text_data'``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples.

    Returns:
        A ``DataLoader`` whose batches are dicts holding the texts as a list
        under ``'text_data'`` and the stacked per-sample tensors under
        ``'spatial_features'``, ``'temporal_sequences'``,
        ``'behavioral_features'`` and ``'targets'``.
    """
    # Key in ``dataset_dict`` that supplies each position of a sample tuple.
    sample_keys = ('text_data', 'spatial_features', 'temporal_sequences',
                   'behavioral_features', 'moral_targets')
    # Batch key for tuple positions 1..4 (note 'moral_targets' -> 'targets').
    stacked_fields = ('spatial_features', 'temporal_sequences',
                      'behavioral_features', 'targets')

    def collate_fn(batch):
        # Texts stay a plain Python list; everything else is stacked.
        collated = {'text_data': [sample[0] for sample in batch]}
        for position, field in enumerate(stacked_fields, start=1):
            collated[field] = torch.stack([sample[position] for sample in batch])
        return collated

    num_samples = len(dataset_dict['text_data'])
    samples = [tuple(dataset_dict[key][row] for key in sample_keys)
               for row in range(num_samples)]

    return DataLoader(samples, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

# ===================================================================
# 6. Enhanced Multi-Stage Training System (5 modalities)
# ===================================================================
class MultiStageTrainer:
    """Staged training driver for five modality encoders plus a fusion model.

    Stage 1 trains the two text models one after the other, stage 2 trains the
    spatial, temporal and behavioral models, stage 3 trains only the fusion
    model on features produced by the frozen encoders, and stage 4 fine-tunes
    all six models jointly with early stopping on validation macro F1.
    """

    def __init__(self, models, datasets, feature_info, device, emfd_csv_path, batch_size=16):
        """Move the models to ``device`` and build loaders and the loss.

        Args:
            models: Six models, unpacked in the order (GAT text, MLP text,
                spatial, temporal, behavioral, fusion).
            datasets: Mapping with ``'train'`` and ``'validation'`` split
                dicts in the layout consumed by ``create_dataloader``.
            feature_info: Accepted for call-site compatibility; not used here.
            device: Device all models and batch tensors are moved to.
            emfd_csv_path: Accepted for call-site compatibility; not used here.
            batch_size: Batch size of the training and validation loaders.
        """
        (self.gat_text_model, self.mlp_text_model, self.spatial_model,
         self.temporal_model, self.behavioral_model, self.fusion_model) = models
        self.datasets = datasets
        self.device = device

        # Move models to device
        self.gat_text_model = self.gat_text_model.to(device)
        self.mlp_text_model = self.mlp_text_model.to(device)
        self.spatial_model = self.spatial_model.to(device)
        self.temporal_model = self.temporal_model.to(device)
        self.behavioral_model = self.behavioral_model.to(device)
        self.fusion_model = self.fusion_model.to(device)

        # Create data loaders
        self.train_loader = create_dataloader(datasets['train'], batch_size=batch_size, shuffle=True)
        self.val_loader = create_dataloader(datasets['validation'], batch_size=batch_size, shuffle=False)

        # Per-class weight = N_negative / N_positive over the training targets
        # (assumes targets are 0/1 multi-hot rows -- TODO confirm).
        train_targets = self.datasets['train']['moral_targets']
        pos_counts = train_targets.sum(dim=0)
        neg_counts = len(train_targets) - pos_counts
        # Epsilon avoids division by zero when a class has no positive samples.
        pos_weight = neg_counts / (pos_counts + 1e-6)

        self.criterion = FocalLoss(alpha=0.25, gamma=2.0, pos_weight=pos_weight.to(self.device))

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------
    def _freeze(self, model):
        """Disable gradients for every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad_(False)

    def _unfreeze(self, model):
        """Enable gradients for every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad_(True)

    def _encoders(self):
        """Return the five modality encoders in fusion-argument order."""
        return (self.gat_text_model, self.mlp_text_model, self.spatial_model,
                self.temporal_model, self.behavioral_model)

    def _all_models(self):
        """Return the five encoders followed by the fusion model."""
        return self._encoders() + (self.fusion_model,)

    def _to_device(self, batch):
        """Return a copy of ``batch`` with every tensor moved to ``self.device``.

        Non-tensor values (the list of texts) are passed through unchanged.
        """
        return {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}

    def _extract_features(self, batch):
        """Run ``get_features`` of each encoder on its own slice of ``batch``.

        Returns:
            A 5-tuple (GAT text, MLP text, spatial, temporal, behavioral) in
            the positional order the fusion model is called with.
        """
        return (
            self.gat_text_model.get_features(batch['text_data']),
            self.mlp_text_model.get_features(batch['text_data']),
            self.spatial_model.get_features(batch['spatial_features']),
            self.temporal_model.get_features(batch['temporal_sequences']),
            self.behavioral_model.get_features(batch['behavioral_features']),
        )

    def compute_loss(self, predictions, targets):
        """Compute loss for moral foundation predictions.

        Args:
            predictions: Mapping from foundation name to that foundation's
                output tensor; the tensors are concatenated along dim 1 in
                ``moral_foundations`` order.
            targets: Target tensor passed to the criterion as-is.
        """
        pred_stack = torch.cat([predictions[f] for f in moral_foundations], dim=1)
        return self.criterion(pred_stack, targets)

    def _train_single_model(self, model, label, input_key, epochs, lr):
        """Train one stand-alone model on a single batch field.

        Args:
            model: Model to optimise; called as ``model(batch[input_key])``.
            label: Name used in the progress bar and the loss log line.
            input_key: Batch key whose value is fed to ``model``.
            epochs: Number of passes over the training loader.
            lr: AdamW learning rate.
        """
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

        for epoch in range(epochs):
            model.train()
            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"{label} - Epoch {epoch+1}/{epochs}"):
                batch = self._to_device(batch)

                optimizer.zero_grad()
                preds = model(batch[input_key])
                loss = self.compute_loss(preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if epoch % 3 == 0:
                print(f"{label} model epoch {epoch+1}, loss: {total_loss/len(self.train_loader):.4f}")

    # ------------------------------------------------------------------
    # Training stages
    # ------------------------------------------------------------------
    def train_stage1_text_models(self, epochs=15, lr=2e-5):
        """Stage 1: Train both GAT eMFD and MLP Text Models.

        The two text models are trained one after the other (each for
        ``epochs`` epochs with learning rate ``lr``) while every other model
        is frozen; both state dicts are then written to
        ``gat_text_stage1.pth`` and ``mlp_text_stage1.pth``.
        """
        print("\n" + "="*20 + " Stage 1: Training Both Text Models " + "="*20)

        # Freeze other models
        self._freeze(self.spatial_model)
        self._freeze(self.temporal_model)
        self._freeze(self.behavioral_model)
        self._freeze(self.fusion_model)

        # Train GAT Text Model
        print("Training GAT-eMFD Text Model...")
        self._unfreeze(self.gat_text_model)
        self._freeze(self.mlp_text_model)
        self._train_single_model(self.gat_text_model, "GAT Text", 'text_data', epochs, lr)

        # Train MLP Text Model
        print("Training MLP-RoBERTa Text Model...")
        self._freeze(self.gat_text_model)
        self._unfreeze(self.mlp_text_model)
        self._train_single_model(self.mlp_text_model, "MLP Text", 'text_data', epochs, lr)

        # Save models
        torch.save(self.gat_text_model.state_dict(), "gat_text_stage1.pth")
        torch.save(self.mlp_text_model.state_dict(), "mlp_text_stage1.pth")
        print("✅ Stage 1 Complete. Both text models saved.")

    def train_stage2_other_modalities(self, epochs=10, lr=5e-4):
        """Stage 2: Train other modality models.

        The spatial, temporal and behavioral models are trained in turn, each
        on its own batch field, saved to ``<name>_model_stage2.pth`` and
        frozen again afterwards.
        """
        print("\n" + "="*20 + " Stage 2: Training Other Modality Models " + "="*20)

        # Freeze text and fusion models
        self._freeze(self.gat_text_model)
        self._freeze(self.mlp_text_model)
        self._freeze(self.fusion_model)

        # (model, short name, batch key fed to the model)
        models_to_train = [
            (self.spatial_model, 'spatial', 'spatial_features'),
            (self.temporal_model, 'temporal', 'temporal_sequences'),
            (self.behavioral_model, 'behavioral', 'behavioral_features'),
        ]

        for model, name, input_key in models_to_train:
            print(f"\n--- Training {name.capitalize()} Model ---")
            self._unfreeze(model)
            self._train_single_model(model, name.capitalize(), input_key, epochs, lr)

            # Save model
            torch.save(model.state_dict(), f"{name}_model_stage2.pth")
            self._freeze(model)

        print("✅ Stage 2 Complete. All modality models saved.")

    def train_stage3_fusion_integration(self, epochs=8, lr=1e-3):
        """Stage 3: Fusion Integration with 5 modalities.

        Only the fusion model is optimised; the five encoders are frozen and
        used as feature extractors under ``torch.no_grad()``. Validation
        macro F1 is reported every second epoch and the fusion weights are
        saved to ``fusion_model_stage3.pth`` at the end.
        """
        print("\n" + "="*15 + " Stage 3: Fusion Integration (5 modalities) " + "="*15)

        # Freeze all individual models, unfreeze fusion
        for encoder in self._encoders():
            self._freeze(encoder)
        self._unfreeze(self.fusion_model)

        optimizer = optim.AdamW(self.fusion_model.parameters(), lr=lr, weight_decay=1e-4)

        for epoch in range(epochs):
            # The encoders are fixed feature extractors in this stage, so keep
            # them in eval mode. Earlier stages leave them in train mode and
            # _validate_fusion() switches them to eval as a side effect, which
            # previously made the first epoch see differently-behaving
            # encoders (e.g. any dropout layers active) than later epochs.
            for encoder in self._encoders():
                encoder.eval()
            self.fusion_model.train()
            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"Stage 3 - Epoch {epoch+1}/{epochs}"):
                batch = self._to_device(batch)

                optimizer.zero_grad()

                # Extract features with no gradients for individual models
                with torch.no_grad():
                    features = self._extract_features(batch)

                # Fusion with gradients (5 modalities)
                fusion_preds = self.fusion_model(*features)
                loss = self.compute_loss(fusion_preds, batch['targets'])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if epoch % 2 == 0:
                val_f1 = self._validate_fusion()
                print(f"Stage 3 - Epoch {epoch+1}: Loss: {total_loss/len(self.train_loader):.4f}, Val F1: {val_f1:.4f}")

        torch.save(self.fusion_model.state_dict(), "fusion_model_stage3.pth")
        print("✅ Stage 3 Complete. Fusion model saved.")

    def train_stage4_end_to_end_finetuning(self, epochs=5, best_model_path="best_fusion_model.pth", patience=2):
        """Stage 4: End-to-end fine-tuning of ALL models (GAT text, MLP text, spatial, temporal, behavioral, fusion).

        Args:
            epochs: Maximum number of fine-tuning epochs.
            best_model_path: File that receives a dict of all six state dicts
                whenever validation macro F1 improves.
            patience: Number of consecutive non-improving epochs after which
                training stops early.
        """
        print("\n" + "="*15 + " Stage 4: End-to-End Fine-tuning (ALL ENCODERS) " + "="*15)

        # Unfreeze ALL models for end-to-end training
        for model in self._all_models():
            self._unfreeze(model)

        # Create optimizer with different learning rates for different model components
        optimizer = optim.AdamW([
            {'params': self.gat_text_model.parameters(), 'lr': 5e-6},
            {'params': self.mlp_text_model.parameters(), 'lr': 5e-6},
            {'params': self.spatial_model.parameters(), 'lr': 1e-5},
            {'params': self.temporal_model.parameters(), 'lr': 1e-5},
            {'params': self.behavioral_model.parameters(), 'lr': 1e-5},
            {'params': self.fusion_model.parameters(), 'lr': 1e-4}
        ], weight_decay=1e-4)

        # Start below any reachable F1 so the first epoch always writes a
        # checkpoint; with 0.0 and a strict '>' a run whose validation F1
        # stayed at 0 would finish without ever creating best_model_path.
        best_val_f1 = float('-inf')
        patience_counter = 0

        for epoch in range(epochs):
            # Set all models to training mode (validation below flips them to eval)
            for model in self._all_models():
                model.train()

            total_loss = 0

            for batch in tqdm(self.train_loader, desc=f"Stage 4 (ALL ENCODERS) - Epoch {epoch+1}/{epochs}"):
                batch = self._to_device(batch)

                optimizer.zero_grad()

                # Extract features from ALL models with gradients enabled
                features = self._extract_features(batch)

                # Fusion with all gradients flowing
                fusion_preds = self.fusion_model(*features)
                loss = self.compute_loss(fusion_preds, batch['targets'])
                loss.backward()

                # Gradient clipping to prevent exploding gradients; the norm
                # is clipped separately for each model.
                for model in self._all_models():
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()
                total_loss += loss.item()

            val_f1 = self._validate_fusion()
            print(f"Stage 4 (ALL ENCODERS) - Epoch {epoch+1}: Loss: {total_loss/len(self.train_loader):.4f}, Val F1: {val_f1:.4f}")

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                patience_counter = 0
                torch.save({
                    'gat_text_model': self.gat_text_model.state_dict(),
                    'mlp_text_model': self.mlp_text_model.state_dict(),
                    'spatial_model': self.spatial_model.state_dict(),
                    'temporal_model': self.temporal_model.state_dict(),
                    'behavioral_model': self.behavioral_model.state_dict(),
                    'fusion_model': self.fusion_model.state_dict()
                }, best_model_path)
                print(f"🏆 New best model saved with F1: {val_f1:.4f} (ALL ENCODERS fine-tuned)")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print("🛑 Early stopping triggered.")
                    break

        print("✅ Stage 4 Complete. End-to-end training of ALL ENCODERS finished.")

    def _validate_fusion(self):
        """Validate fusion model with 5 modalities - ALL models in eval mode.

        Side effect: all six models are left in eval mode.

        Returns:
            Macro-averaged F1 over the validation loader, with predictions
            obtained by thresholding the sigmoid of the fusion outputs at 0.5.
        """
        for model in self._all_models():
            model.eval()

        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in self.val_loader:
                batch = self._to_device(batch)

                fusion_preds = self.fusion_model(*self._extract_features(batch))
                pred_stack = torch.cat([fusion_preds[f] for f in moral_foundations], dim=1)

                all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
                all_labels.append(batch['targets'].cpu().numpy())

        y_pred = np.vstack(all_preds)
        y_true = np.vstack(all_labels)
        return f1_score(y_true, y_pred, average='macro', zero_division=0)



def train_stage4_end_to_end_finetuning(self, epochs=5, best_model_path="best_fusion_model.pth"):
    """Module-level entry point for stage-4 end-to-end fine-tuning.

    This function used to be a dedented, line-for-line copy of
    ``MultiStageTrainer.train_stage4_end_to_end_finetuning``. It now forwards
    to the method of the same name on ``self`` so that there is a single
    implementation to maintain.

    NOTE(review): do not assign this function onto the trainer class as its
    own ``train_stage4_end_to_end_finetuning`` -- it would then call itself.

    Args:
        self: Trainer instance providing ``train_stage4_end_to_end_finetuning``.
        epochs: Maximum number of fine-tuning epochs.
        best_model_path: File that receives the best checkpoint.

    Returns:
        Whatever the trainer method returns.
    """
    return self.train_stage4_end_to_end_finetuning(epochs=epochs, best_model_path=best_model_path)
def _validate_fusion(self):
    """Validate fusion model with 5 modalities - ALL models in eval mode"""
    self.gat_text_model.eval()
    self.mlp_text_model.eval()
    self.spatial_model.eval()
    self.temporal_model.eval()
    self.behavioral_model.eval()
    self.fusion_model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in self.val_loader:
            batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            gat_text_feat = self.gat_text_model.get_features(batch['text_data'])
            mlp_text_feat = self.mlp_text_model.get_features(batch['text_data'])
            spatial_feat = self.spatial_model.get_features(batch['spatial_features'])
            temporal_feat = self.temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = self.behavioral_model.get_features(batch['behavioral_features'])

            fusion_preds = self.fusion_model(gat_text_feat, mlp_text_feat, spatial_feat, temporal_feat, behavioral_feat)
            pred_stack = torch.cat([fusion_preds[f] for f in moral_foundations], dim=1)

            all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
            all_labels.append(batch['targets'].cpu().numpy())

    y_pred = np.vstack(all_preds)
    y_true = np.vstack(all_labels)
    return f1_score(y_true, y_pred, average='macro', zero_division=0)

# ===================================================================
# 7. Enhanced Evaluation Functions
# ===================================================================
def evaluate_individual_models(test_loader, device, feature_info, emfd_csv_path):
    """Evaluate all individual models including both text models.

    Rebuilds the five single-modality models, restores the weights written by
    training stages 1 and 2 (``gat_text_stage1.pth``, ``mlp_text_stage1.pth``
    and ``<name>_model_stage2.pth`` in the working directory) and scores each
    model on ``test_loader``. Predictions are the sigmoid of the model
    outputs thresholded at 0.5.

    Args:
        test_loader: Loader yielding dict batches with ``'text_data'``,
            ``'spatial_features'``, ``'temporal_sequences'``,
            ``'behavioral_features'`` and ``'targets'``.
        device: Device the models and batch tensors are placed on.
        feature_info: Mapping providing ``'spatial_feature_dim'``,
            ``'temporal_feature_dim'``, ``'temporal_sequence_length'`` and
            ``'behavioral_feature_dim'``.
        emfd_csv_path: Path handed to ``TextModelGATeMFD``.

    Returns:
        Dict mapping each model's display name to
        ``{'macro_f1': float, 'foundation_f1s': list}``, the list being in
        ``moral_foundations`` order.
    """
    print("\n" + "="*20 + " INDIVIDUAL MODEL EVALUATION " + "="*20)

    # Load models
    gat_text_model = TextModelGATeMFD(emfd_csv_path).to(device)
    mlp_text_model = TextModelRoBERTaFrozen().to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim']).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim']).to(device)

    # (model, display name, batch key fed to the model, checkpoint file)
    specs = [
        (gat_text_model, 'GAT-eMFD Text', 'text_data', "gat_text_stage1.pth"),
        (mlp_text_model, 'MLP-RoBERTa Text', 'text_data', "mlp_text_stage1.pth"),
        (spatial_model, 'Spatial', 'spatial_features', "spatial_model_stage2.pth"),
        (temporal_model, 'Temporal', 'temporal_sequences', "temporal_model_stage2.pth"),
        (behavioral_model, 'Behavioral', 'behavioral_features', "behavioral_model_stage2.pth"),
    ]

    # Load weights. map_location remaps the stored tensors onto ``device`` so
    # checkpoints written on a GPU can also be evaluated on a CPU-only host.
    for model, _, _, checkpoint_path in specs:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    results = {}

    for model, name, input_key, _ in specs:
        model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc=f"Testing {name}"):
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batch.items()}

                preds = model(batch[input_key])

                pred_stack = torch.cat([preds[f] for f in moral_foundations], dim=1)
                all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
                all_labels.append(batch['targets'].cpu().numpy())

        y_pred = np.vstack(all_preds)
        y_true = np.vstack(all_labels)
        macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

        print(f"\n--- {name} Model Results ---")
        print(f"Macro F1-Score: {macro_f1:.4f}")

        # Foundation-specific results: one binary F1 per target column.
        foundation_f1s = [
            f1_score(y_true[:, i], y_pred[:, i], zero_division=0)
            for i in range(len(moral_foundations))
        ]

        results[name] = {
            'macro_f1': macro_f1,
            'foundation_f1s': foundation_f1s
        }

    return results

def evaluate_fusion_model(model_path, test_loader, device, feature_info, emfd_csv_path):
    """Evaluate the final fusion model with 5 modalities.

    Rebuilds the five encoders and the fusion model, restores all six state
    dicts from the checkpoint at ``model_path`` and scores the fused
    predictions (sigmoid thresholded at 0.5) on ``test_loader``. A
    classification report and the macro F1 are printed.

    Args:
        model_path: Checkpoint holding a dict with the keys
            ``'gat_text_model'``, ``'mlp_text_model'``, ``'spatial_model'``,
            ``'temporal_model'``, ``'behavioral_model'`` and
            ``'fusion_model'``.
        test_loader: Loader yielding dict batches with ``'text_data'``,
            ``'spatial_features'``, ``'temporal_sequences'``,
            ``'behavioral_features'`` and ``'targets'``.
        device: Device the models and batch tensors are placed on.
        feature_info: Mapping providing ``'spatial_feature_dim'``,
            ``'temporal_feature_dim'``, ``'temporal_sequence_length'`` and
            ``'behavioral_feature_dim'``.
        emfd_csv_path: Path handed to ``TextModelGATeMFD``.

    Returns:
        Macro-averaged F1 of the fusion model on the test set.
    """
    print("\n" + "="*20 + " FINAL FUSION MODEL EVALUATION (5 MODALITIES) " + "="*20)

    # Load models
    gat_text_model = TextModelGATeMFD(emfd_csv_path).to(device)
    mlp_text_model = TextModelRoBERTaFrozen().to(device)
    spatial_model = SpatialModel(feature_info['spatial_feature_dim']).to(device)
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length']
    ).to(device)
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim']).to(device)
    fusion_model = HeterogeneousFusion().to(device)

    # Checkpoint key -> model that receives that state dict.
    named_models = {
        'gat_text_model': gat_text_model,
        'mlp_text_model': mlp_text_model,
        'spatial_model': spatial_model,
        'temporal_model': temporal_model,
        'behavioral_model': behavioral_model,
        'fusion_model': fusion_model,
    }

    # Load best model state. map_location remaps the stored tensors onto
    # ``device`` so a checkpoint written on a GPU also loads on a CPU-only host.
    checkpoint = torch.load(model_path, map_location=device)
    for key, model in named_models.items():
        model.load_state_dict(checkpoint[key])

    # Set to eval mode
    for model in named_models.values():
        model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing 5-Modality Fusion Model"):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            # Extract features
            gat_text_feat = gat_text_model.get_features(batch['text_data'])
            mlp_text_feat = mlp_text_model.get_features(batch['text_data'])
            spatial_feat = spatial_model.get_features(batch['spatial_features'])
            temporal_feat = temporal_model.get_features(batch['temporal_sequences'])
            behavioral_feat = behavioral_model.get_features(batch['behavioral_features'])

            # Fusion
            fusion_preds = fusion_model(gat_text_feat, mlp_text_feat, spatial_feat, temporal_feat, behavioral_feat)
            pred_stack = torch.cat([fusion_preds[f] for f in moral_foundations], dim=1)

            all_preds.append((torch.sigmoid(pred_stack) > 0.5).cpu().numpy())
            all_labels.append(batch['targets'].cpu().numpy())

    y_pred = np.vstack(all_preds)
    y_true = np.vstack(all_labels)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print(classification_report(y_true, y_pred, target_names=moral_foundations, zero_division=0))
    print(f"** 5-MODALITY FUSION MODEL Macro F1-Score: {macro_f1:.4f} **")

    return macro_f1

def print_comprehensive_results(individual_results, fusion_f1, foundations=None):
    """Print a side-by-side F1 report for the single-modality models and the fusion model.

    Two sections are written to stdout: a per-model listing (macro F1 plus one
    F1 per foundation), then an aligned comparison table with one row per
    foundation and one column per model.

    Args:
        individual_results: Mapping from model display name to a dict holding
            ``'macro_f1'`` (a number) and ``'foundation_f1s'`` (a sequence
            indexed in the same order as ``foundations``).
        fusion_f1: Macro F1 of the fusion model. Only this aggregate value is
            available here, so the table's last column repeats it on every row
            and is labelled ``Fusion(macro)`` to make that explicit.
        foundations: Optional sequence of foundation names. Defaults to the
            module-level ``moral_foundations`` list.

    Raises:
        KeyError: If a result dict lacks ``'macro_f1'`` or ``'foundation_f1s'``.
        IndexError: If a ``'foundation_f1s'`` sequence is shorter than
            ``foundations``.
    """
    if foundations is None:
        foundations = moral_foundations
    foundations = list(foundations)

    banner = "=" * 80
    print("\n" + banner)
    print("COMPREHENSIVE 5-MODALITY EVALUATION RESULTS")
    print(banner)

    # Individual model results
    for model_name, results in individual_results.items():
        print(f"\n{model_name} MODEL:")
        print("-" * 50)
        print(f"  Macro F1-Score: {results['macro_f1']:.4f}")
        for i, foundation in enumerate(foundations):
            print(f"  {foundation} F1: {results['foundation_f1s'][i]:.4f}")

    print("\n5-MODALITY FUSION MODEL:")
    print("-" * 50)
    print(f"  Macro F1-Score: {fusion_f1:.4f}")

    # Foundation-specific comparison
    print("\n" + banner)
    print("FOUNDATION-SPECIFIC F1 COMPARISON:")
    print(banner)

    # (key in individual_results, column label). Header and rows are rendered
    # from the same widths so every value sits under its own label.
    columns = (
        ('GAT-eMFD Text', 'GAT-Text'),
        ('MLP-RoBERTa Text', 'MLP-Text'),
        ('Spatial', 'Spatial'),
        ('Temporal', 'Temporal'),
        ('Behavioral', 'Behavioral'),
    )
    name_width = max(len(name) for name in ['Foundation', *foundations]) + 2
    value_width = 11

    header = f"{'Foundation':<{name_width}}"
    header += "".join(f"{label:<{value_width}}" for _, label in columns)
    header += "Fusion(macro)"
    print(header)
    print("-" * 80)

    for i, foundation in enumerate(foundations):
        row = f"{foundation:<{name_width}}"
        for model_name, _ in columns:
            if model_name in individual_results:
                score = individual_results[model_name]['foundation_f1s'][i]
                cell = f"{score:.3f}"
            else:
                # Keep the column occupied so later values do not shift left.
                cell = "n/a"
            row += f"{cell:<{value_width}}"
        row += f"{fusion_f1:.3f}"
        print(row)

# ===================================================================
# 8. Main Execution
# ===================================================================
if __name__ == "__main__":
    # Driver: prepare the dataset, build the five modality models plus the
    # fusion model, run the four training stages, then evaluate on the test
    # split and print a comparison report.

    # File paths - UPDATE THESE
    CSV_PATH = '/content/augmented_tweets.csv'
    GEOJSON_PATH = '/content/county_data_boarders.json'
    CENTROIDS_PATH = '/content/county_centroids.json'
    EMFD_CSV_PATH = '/content/eMFD_wordlist.csv'

    # One name for the stage-4 checkpoint so the path handed to training and
    # the path read back for evaluation cannot drift apart.
    BEST_MODEL_PATH = "best_5modality_fusion.pth"

    if not GATConv:
        print("❌ PyTorch Geometric not available. Please install it.")
        # The site-provided exit() is intended for interactive shells and
        # reports success; raise SystemExit with a non-zero status instead.
        raise SystemExit(1)

    # Prepare dataset
    print("🔄 Preparing dataset...")
    datasets, feature_info = prepare_dataset(
        csv_path=CSV_PATH,
        geojson_path=GEOJSON_PATH,
        county_centroids_path=CENTROIDS_PATH,
        seq_len=10,
    )

    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize models (now 6 models: 2 text + 3 others + 1 fusion)
    gat_text_model = TextModelGATeMFD(EMFD_CSV_PATH)
    mlp_text_model = TextModelRoBERTaFrozen()
    spatial_model = SpatialModel(feature_info['spatial_feature_dim'])
    temporal_model = TemporalModelLSTM(
        feature_info['temporal_feature_dim'],
        feature_info['temporal_sequence_length'],
    )
    behavioral_model = BehavioralModel(feature_info['behavioral_feature_dim'])
    fusion_model = HeterogeneousFusion()  # Now handles 5 modalities

    # Order as passed to the trainer: the five encoders, then the fusion model.
    models = [
        gat_text_model,
        mlp_text_model,
        spatial_model,
        temporal_model,
        behavioral_model,
        fusion_model,
    ]

    # Initialize trainer
    trainer = MultiStageTrainer(
        models=models,
        datasets=datasets,
        feature_info=feature_info,
        device=device,
        emfd_csv_path=EMFD_CSV_PATH,
    )

    # Execute multi-stage training
    print("\n🚀 Starting 5-Modality Multi-Stage Training Pipeline...")
    print("=" * 80)

    trainer.train_stage1_text_models(epochs=5, lr=2e-5)
    trainer.train_stage2_other_modalities(epochs=10, lr=5e-4)
    trainer.train_stage3_fusion_integration(epochs=2, lr=1e-3)
    trainer.train_stage4_end_to_end_finetuning(epochs=2, best_model_path=BEST_MODEL_PATH)

    # Final evaluation
    print("\n🔍 Starting Comprehensive 5-Modality Evaluation...")
    print("=" * 80)

    test_loader = create_dataloader(datasets['test'], batch_size=16, shuffle=False)

    # Evaluate individual models (including both text models)
    individual_results = evaluate_individual_models(test_loader, device, feature_info, EMFD_CSV_PATH)

    # Evaluate 5-modality fusion model from the checkpoint named above
    fusion_f1 = evaluate_fusion_model(BEST_MODEL_PATH, test_loader, device, feature_info, EMFD_CSV_PATH)

    # Print comprehensive results
    print_comprehensive_results(individual_results, fusion_f1)

    print("\n✅ 5-modality training and comprehensive evaluation completed!")
    print(f"🏆 Best models saved as '{BEST_MODEL_PATH}'")


🔄 Preparing dataset...
Using device: cuda


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



🚀 Starting 5-Modality Multi-Stage Training Pipeline...

Training GAT-eMFD Text Model...


GAT Text - Epoch 1/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT Text model epoch 1, loss: 0.2321


GAT Text - Epoch 2/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT Text - Epoch 3/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT Text - Epoch 4/5:   0%|          | 0/17 [00:00<?, ?it/s]

GAT Text model epoch 4, loss: 0.2186


GAT Text - Epoch 5/5:   0%|          | 0/17 [00:00<?, ?it/s]

Training MLP-RoBERTa Text Model...


MLP Text - Epoch 1/5:   0%|          | 0/17 [00:00<?, ?it/s]

MLP Text model epoch 1, loss: 0.2309


MLP Text - Epoch 2/5:   0%|          | 0/17 [00:00<?, ?it/s]

MLP Text - Epoch 3/5:   0%|          | 0/17 [00:00<?, ?it/s]

MLP Text - Epoch 4/5:   0%|          | 0/17 [00:00<?, ?it/s]

MLP Text model epoch 4, loss: 0.2352


MLP Text - Epoch 5/5:   0%|          | 0/17 [00:00<?, ?it/s]

✅ Stage 1 Complete. Both text models saved.


--- Training Spatial Model ---


Spatial - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 1, loss: 0.2222


Spatial - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 4, loss: 0.2030


Spatial - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 7, loss: 0.2609


Spatial - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Spatial model epoch 10, loss: 0.1925

--- Training Temporal Model ---


Temporal - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 1, loss: 0.2232


Temporal - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 4, loss: 0.1774


Temporal - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 7, loss: 0.1534


Temporal - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Temporal model epoch 10, loss: 0.1479

--- Training Behavioral Model ---


Behavioral - Epoch 1/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 1, loss: 0.2166


Behavioral - Epoch 2/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 3/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 4/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 4, loss: 0.1817


Behavioral - Epoch 5/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 6/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 7/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 7, loss: 0.1760


Behavioral - Epoch 8/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 9/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral - Epoch 10/10:   0%|          | 0/17 [00:00<?, ?it/s]

Behavioral model epoch 10, loss: 0.1746
✅ Stage 2 Complete. All modality models saved.



Stage 3 - Epoch 1/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 3 - Epoch 1: Loss: 0.2489, Val F1: 0.1689


Stage 3 - Epoch 2/2:   0%|          | 0/17 [00:00<?, ?it/s]

✅ Stage 3 Complete. Fusion model saved.



Stage 4 (ALL ENCODERS) - Epoch 1/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 4 (ALL ENCODERS) - Epoch 1: Loss: 0.1378, Val F1: 0.1399
🏆 New best model saved with F1: 0.1399 (ALL ENCODERS fine-tuned)


Stage 4 (ALL ENCODERS) - Epoch 2/2:   0%|          | 0/17 [00:00<?, ?it/s]

Stage 4 (ALL ENCODERS) - Epoch 2: Loss: 0.1172, Val F1: 0.1678
🏆 New best model saved with F1: 0.1678 (ALL ENCODERS fine-tuned)
✅ Stage 4 Complete. End-to-end training of ALL ENCODERS finished.

🔍 Starting Comprehensive 5-Modality Evaluation...



Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Testing GAT-eMFD Text:   0%|          | 0/4 [00:00<?, ?it/s]


--- GAT-eMFD Text Model Results ---
Macro F1-Score: 0.0000


Testing MLP-RoBERTa Text:   0%|          | 0/4 [00:00<?, ?it/s]


--- MLP-RoBERTa Text Model Results ---
Macro F1-Score: 0.0000


Testing Spatial:   0%|          | 0/4 [00:00<?, ?it/s]


--- Spatial Model Results ---
Macro F1-Score: 0.0000


Testing Temporal:   0%|          | 0/4 [00:00<?, ?it/s]


--- Temporal Model Results ---
Macro F1-Score: 0.0937


Testing Behavioral:   0%|          | 0/4 [00:00<?, ?it/s]


--- Behavioral Model Results ---
Macro F1-Score: 0.2361



Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Testing 5-Modality Fusion Model:   0%|          | 0/4 [00:00<?, ?it/s]

              precision    recall  f1-score   support

        Care       0.00      0.00      0.00        44
    Fairness       0.00      0.00      0.00         0
     Loyalty       0.25      0.56      0.34         9
   Authority       0.00      0.00      0.00         4
      Purity       0.00      0.00      0.00         1
   Non_Moral       0.33      0.67      0.44         3

   micro avg       0.20      0.11      0.15        61
   macro avg       0.10      0.20      0.13        61
weighted avg       0.05      0.11      0.07        61
 samples avg       0.11      0.12      0.11        61

** 5-MODALITY FUSION MODEL Macro F1-Score: 0.1315 **

COMPREHENSIVE 5-MODALITY EVALUATION RESULTS

GAT-eMFD Text MODEL:
--------------------------------------------------
  Macro F1-Score: 0.0000
  Care F1: 0.0000
  Fairness F1: 0.0000
  Loyalty F1: 0.0000
  Authority F1: 0.0000
  Purity F1: 0.0000
  Non_Moral F1: 0.0000

MLP-RoBERTa Text MODEL:
--------------------------------------------------
  Ma