In [50]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score, precision_score, recall_score
import os
import time
import copy

# ------------------------------------
# Configuration & Hyperparameters
# ------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# --- Seeds ---
SEEDS = [16, 24, 45, 54, 65]
GRAPH_SEED_BASE = 11

# --- Graph Learning Parameters ---
GRAPH_LEARN_EPOCHS = 5000
GRAPH_LEARN_ALPHA = 1.0
GRAPH_LEARN_BETA = 2.0
GRAPH_LEARN_ETA = 0.001

# --- GNN Model & Training Parameters ---
SAGE_DROPOUT = 0.2
LR = 4e-4
EPOCHS = 20000
PATIENCE = 1000

# --- HYPERPARAMETER ---
K_VAL = 0.00001

# ------------------------------------
# Dice Loss for Binary Classification
# ------------------------------------
class DiceLoss(nn.Module):
    """
    Dice loss, suitable for binary segmentation or multi-label classification tasks.
    """
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        inputs = torch.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice_coeff = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)
        return 1 - dice_coeff

# ------------------------------------
# Learn Graph Structure (GPU Accelerated)
# ------------------------------------
def learn_graph_gpu(X: torch.Tensor, seed_graph: int, pruning_k_percent: float) -> torch.Tensor:
    """
    Learns a graph structure from node features entirely on the GPU.
    """
    N = X.shape[0]
    # Set the seed for graph learning to make it reproducible for a given seed
    torch.manual_seed(seed_graph)
    i_idx, j_idx = torch.triu_indices(N, N, offset=1, device=DEVICE)
    w = torch.randn(len(i_idx), device=DEVICE) * 0.01 + 0.5
    w.requires_grad = True
    delL = torch.sum((X[i_idx] - X[j_idx]) ** 2, dim=1)
    optimizer = torch.optim.Adam([w], lr=GRAPH_LEARN_ETA)

    for epoch in range(GRAPH_LEARN_EPOCHS):
        optimizer.zero_grad()
        w_clipped = torch.clamp(w, min=0)
        degrees = torch.zeros(N, device=DEVICE)
        degrees.index_add_(0, i_idx, w_clipped)
        degrees.index_add_(0, j_idx, w_clipped)
        
        loss_laplacian = torch.sum(delL * w_clipped)
        loss_beta = (GRAPH_LEARN_BETA / 2.0) * torch.sum(w_clipped ** 2)
        loss_alpha = -GRAPH_LEARN_ALPHA * torch.sum(torch.log(degrees + 1e-12))
        loss = loss_laplacian + loss_beta + loss_alpha
        loss.backward()
        optimizer.step()

    best_w = torch.clamp(w.detach(), min=0)
    num_edges = len(best_w)
    k = max(1, int(pruning_k_percent * num_edges))
    
    if k < num_edges:
        threshold = torch.kthvalue(best_w, num_edges - k).values
        top_mask = best_w >= threshold
    else:
        top_mask = torch.ones_like(best_w, dtype=torch.bool)

    final_i = i_idx[top_mask]
    final_j = j_idx[top_mask]
    edge_index = torch.stack([torch.cat([final_i, final_j]), torch.cat([final_j, final_i])], dim=0)
    return edge_index

# ------------------------------------
# GNN Model
# ------------------------------------
class GraphSAGENet(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, 256)
        self.conv2 = SAGEConv(256, 512)
        self.conv3 = SAGEConv(512, 512)
        self.conv4 = SAGEConv(512, 256)
        self.out = torch.nn.Linear(256, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=SAGE_DROPOUT, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=SAGE_DROPOUT, training=self.training)
        x = F.relu(self.conv3(x, edge_index))
        x = F.dropout(x, p=SAGE_DROPOUT, training=self.training)
        x = F.relu(self.conv4(x, edge_index))
        x = F.dropout(x, p=SAGE_DROPOUT, training=self.training)
        return self.out(x)

# ------------------------------------
# Train and Evaluate for One Seed
# ------------------------------------
def run_single_seed(X_raw, Y_raw, seed_model, seed_graph, pruning_k_percent):
    """
    Trains and evaluates the model for a single random seed and a given k.
    """
    # Set seeds for reproducibility for this specific run
    torch.manual_seed(seed_model)
    np.random.seed(seed_model)
    if DEVICE.type == 'cuda':
        torch.cuda.manual_seed_all(seed_model)

    # Preprocessing
    scaler = StandardScaler()
    X_np_abs = np.abs(X_raw.numpy())
    X_scaled_np = scaler.fit_transform(X_np_abs)
    X = torch.tensor(X_scaled_np, dtype=torch.float32).to(DEVICE)
    Y = Y_raw.to(DEVICE)

    # Graph Learning
    start_time = time.time()
    edge_index = learn_graph_gpu(X, seed_graph, pruning_k_percent=pruning_k_percent)
    print(f"Graph learning finished in {time.time() - start_time:.2f}s.")

    data = Data(x=X, y=Y, edge_index=edge_index)

    # --- MODIFIED: 80/10/10 Random Data Split ---
    num_samples = X.shape[0]
    indices = np.arange(num_samples)
    np.random.shuffle(indices) # Shuffle indices randomly

    train_size = int(0.8 * num_samples)
    val_size = int(0.1 * num_samples)
    
    train_idx = indices[:train_size]
    val_idx = indices[train_size : train_size + val_size]
    test_idx = indices[train_size + val_size :]
    # --- END MODIFICATION ---

    data.train_mask = torch.zeros(X.shape[0], dtype=torch.bool, device=DEVICE)
    data.train_mask[train_idx] = True
    data.val_mask = torch.zeros(X.shape[0], dtype=torch.bool, device=DEVICE)
    data.val_mask[val_idx] = True
    data.test_mask = torch.zeros(X.shape[0], dtype=torch.bool, device=DEVICE)
    data.test_mask[test_idx] = True

    # Model Initialization and Training
    model = GraphSAGENet(X.size(1), Y.size(1)).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = DiceLoss()

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    print("Starting GNN training...")
    for epoch in range(EPOCHS):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            val_out = model(data.x, data.edge_index)
            val_loss = criterion(val_out[data.val_mask], data.y[data.val_mask])
        
        if (epoch + 1) % 1000 == 0:
            print(f"Epoch {epoch+1:05d}/{EPOCHS} | Train Loss: {loss.item():.6f} | Val Loss: {val_loss.item():.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        if patience_counter >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1}!")
            break

    if best_model_state:
        model.load_state_dict(best_model_state)
    else:
        print("Warning: Early stopping did not trigger. Using the last model state.")

    # Evaluation
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)

    test_logits = out[data.test_mask]
    test_true = data.y[data.test_mask]
    test_preds = (torch.sigmoid(test_logits) > 0.5).int()

    test_preds_np = test_preds.cpu().numpy()
    test_true_np = test_true.cpu().numpy()

    # ##################################################################
    # ########### START: MODIFIED EVALUATION BLOCK #####################
    # ##################################################################

    # --- JACCARD INDEX (IoU) CALCULATION (MICRO-AVERAGED) ---
    intersection = np.sum(test_true_np * test_preds_np)
    union = np.sum(np.logical_or(test_true_np, test_preds_np).astype(int))
    micro_jaccard = 1.0 if union == 0 else intersection / union

    # --- OTHER METRICS (MACRO-AVERAGED) ---
    # These metrics calculate the score for each label independently and then average them.
    recall = recall_score(test_true_np, test_preds_np, average='macro', zero_division=0)
    precision = precision_score(test_true_np, test_preds_np, average='macro', zero_division=0)
    f1 = f1_score(test_true_np, test_preds_np, average='macro', zero_division=0)

    return recall, f1, precision, micro_jaccard

    # ##################################################################
    # ############ END: MODIFIED EVALUATION BLOCK ######################
    # ##################################################################


# ------------------------------------
# Main Execution
# ------------------------------------
if __name__ == "__main__":
    try:
        df = pd.read_csv("2500_gravity_anomaly_irregular_shape.csv")
    except FileNotFoundError:
        print("Error: '2500_gravity_anomaly_irregular_shape.csv' not found.")
        print("Please make sure the CSV file is in the same directory as the script.")
        exit()

    required_samples = 2500
    if df.shape[0] < required_samples:
        print(f"Error: Dataset has only {df.shape[0]} samples, but at least {required_samples} are recommended.")
        exit()

    X_raw = torch.tensor(df.iloc[:, :22].values, dtype=torch.float32)
    Y_raw = torch.tensor(df.iloc[:, 22:274].values, dtype=torch.float32)

    all_results = {'rec': [], 'f1': [], 'prec': [], 'jaccard': []}

    for i, seed in enumerate(SEEDS):
        print(f"\n--- Running Seed: {seed} ({i+1}/{len(SEEDS)}) ---")
        
        # --- MODIFIED: Pass the changing seed to the graph learning function ---
        rec, f1, prec, jaccard = run_single_seed(
            X_raw, Y_raw,
            seed_model=seed,
            seed_graph=seed, # Use the main loop seed for graph randomization
            pruning_k_percent=K_VAL
        )
        # --- END MODIFICATION ---
        
        print(f"\n[RESULTS FOR SEED {seed}]")
        print(f"  Recall: {rec:.4f} | Precision: {prec:.4f} | F1: {f1:.4f} | Jaccard (IoU): {jaccard:.4f}")

        all_results['rec'].append(rec)
        all_results['f1'].append(f1)
        all_results['prec'].append(prec)
        all_results['jaccard'].append(jaccard)

    print("\n\n#############################################")
    print("###      EXPERIMENT FINAL SUMMARY         ###")
    print("#############################################")
    print(f"\n--- Aggregated Results (over {len(SEEDS)} seeds) ---\n")

    recall = np.array(all_results['rec'])
    precision = np.array(all_results['prec'])
    f1 = np.array(all_results['f1'])
    jaccard = np.array(all_results['jaccard'])

    print(f"Jaccard (IoU) (Micro):  Mean = {jaccard.mean():.4f}, Std = {jaccard.std():.4f}")
    print(f"Recall (Macro):         Mean = {recall.mean():.4f}, Std = {recall.std():.4f}")
    print(f"Precision (Macro):      Mean = {precision.mean():.4f}, Std = {precision.std():.4f}")
    print(f"F1-Score (Macro):       Mean = {f1.mean():.4f}, Std = {f1.std():.4f}")
    print("---------------------------------------------")

Using device: cuda

--- Running Seed: 16 (1/5) ---
Graph learning finished in 46.47s.
Starting GNN training...
Epoch 01000/20000 | Train Loss: 0.262025 | Val Loss: 0.261253
Epoch 02000/20000 | Train Loss: 0.241189 | Val Loss: 0.242564
Epoch 03000/20000 | Train Loss: 0.222384 | Val Loss: 0.226037
Epoch 04000/20000 | Train Loss: 0.194513 | Val Loss: 0.195792
Epoch 05000/20000 | Train Loss: 0.172639 | Val Loss: 0.180263
Epoch 06000/20000 | Train Loss: 0.150742 | Val Loss: 0.159126
Epoch 07000/20000 | Train Loss: 0.128686 | Val Loss: 0.141708
Epoch 08000/20000 | Train Loss: 0.104011 | Val Loss: 0.119849
Epoch 09000/20000 | Train Loss: 0.087086 | Val Loss: 0.109496
Epoch 10000/20000 | Train Loss: 0.078166 | Val Loss: 0.107794
Epoch 11000/20000 | Train Loss: 0.070574 | Val Loss: 0.104746
Epoch 12000/20000 | Train Loss: 0.067351 | Val Loss: 0.105331
Early stopping at epoch 12544!

[RESULTS FOR SEED 16]
  Recall: 0.6198 | Precision: 0.6122 | F1: 0.6146 | Jaccard (IoU): 0.8425

--- Running Seed