In [2]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report
from sklearn.metrics import roc_curve, precision_recall_curve
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Create directory for saving plots
os.makedirs("plots", exist_ok=True)

# ----------------------------------------
# Load preprocessed data
# ----------------------------------------
try:
    X_flux = np.load("X_flux_aligned.npy")
    X_tabular = np.load("X_tabular.npy")
    y = np.load("y.npy")
    flux_embeddings = np.load("flux_embeddings.npy")
except FileNotFoundError as e:
    print(f"Error: Missing file - {e}")
    raise

# Filter labeled samples
mask = y != -1
X_flux = X_flux[mask]
X_tabular = X_tabular[mask]
y = y[mask]
flux_embeddings = flux_embeddings[mask]

# Verify shapes
print("✅ Data shapes:")
print(f" - X_flux: {X_flux.shape}")
print(f" - X_tabular: {X_tabular.shape}")
print(f" - y: {y.shape}")
print(f" - flux_embeddings: {flux_embeddings.shape}")

if len(y) == 0:
    raise ValueError("Error: No labeled samples after filtering (y is empty)")

# Split
Xf_train, Xf_val, Xt_train, Xt_val, y_train, y_val = train_test_split(
    flux_embeddings, X_tabular, y, test_size=0.2, stratify=y, random_state=42
)

# Verify split shapes
print("✅ Split shapes:")
print(f" - Xf_train: {Xf_train.shape}, Xf_val: {Xf_val.shape}")
print(f" - Xt_train: {Xt_train.shape}, Xt_val: {Xt_val.shape}")
print(f" - y_train: {y_train.shape}, y_val: {y_val.shape}")

# Convert to tensors
Xf_train_tensor = torch.tensor(Xf_train, dtype=torch.float32)
Xt_train_tensor = torch.tensor(Xt_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)

Xf_val_tensor = torch.tensor(Xf_val, dtype=torch.float32)
Xt_val_tensor = torch.tensor(Xt_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32).unsqueeze(1)

# DataLoaders
train_loader = DataLoader(TensorDataset(Xf_train_tensor, Xt_train_tensor, y_train_tensor), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(Xf_val_tensor, Xt_val_tensor, y_val_tensor), batch_size=32, shuffle=False)

# ----------------------------------------
# Focal Loss implementation
# ----------------------------------------
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):  # Adjusted alpha for imbalance
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        bce_loss = self.bce(inputs, targets)
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        return focal_loss.mean()

# ----------------------------------------
# Gated Multimodal Fusion Model
# ----------------------------------------
class GatedMultimodalFusion(nn.Module):
    def __init__(self, flux_dim=16, tabular_dim=5):  # Fixed flux_dim to 16
        super().__init__()
        self.flux_branch = nn.Sequential(
            nn.Linear(flux_dim, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.5),  # Increased dropout
            nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.5)
        )
        self.tabular_branch = nn.Sequential(
            nn.Linear(tabular_dim, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.5)
        )
        self.gate = nn.Sequential(
            nn.Linear(64 * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
        self.classifier = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(32, 1)
        )

    def forward(self, flux_x, tabular_x):
        flux_feat = self.flux_branch(flux_x)   # (batch, 64)
        tab_feat = self.tabular_branch(tabular_x)  # (batch, 64)
        fused_input = torch.cat((flux_feat, tab_feat), dim=1)
        gate = self.gate(fused_input)          # (batch, 1)
        combined = gate * flux_feat + (1 - gate) * tab_feat
        return self.classifier(combined)

# ----------------------------------------
# Training Setup
# ----------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Using device: {device}")
model = GatedMultimodalFusion(flux_dim=16).to(device)
criterion = FocalLoss(alpha=0.25, gamma=2.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)  # Added weight decay
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

# Lists to store metrics
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

# Early stopping parameters
patience = 5
best_val_loss = float('inf')
epochs_no_improve = 0
early_stop = False

# ----------------------------------------
# Training Loop with Metrics
# ----------------------------------------
epochs = 30
threshold = 0.5  # Lowered threshold

for epoch in range(epochs):
    if early_stop:
        print("Early stopping triggered")
        break

    model.train()
    train_loss = 0
    train_preds = []
    train_labels = []

    for fx, tx, lbl in train_loader:
        fx, tx, lbl = fx.to(device), tx.to(device), lbl.to(device)
        pred = model(fx, tx)
        loss = criterion(pred, lbl)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Compute training accuracy
        probs = torch.sigmoid(pred).cpu().detach().numpy()
        preds = (probs > threshold).astype(int)
        train_preds.extend(preds.flatten())
        train_labels.extend(lbl.cpu().numpy().flatten())

    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = accuracy_score(train_labels, train_preds)
    train_losses.append(avg_train_loss)
    train_accuracies.append(train_accuracy)

    # Validation
    model.eval()
    val_loss = 0
    val_preds = []
    val_labels = []
    val_probs = []

    with torch.no_grad():
        for fx, tx, lbl in val_loader:
            fx, tx, lbl = fx.to(device), tx.to(device), lbl.to(device)
            pred = model(fx, tx)
            loss = criterion(pred, lbl)
            val_loss += loss.item()

            probs = torch.sigmoid(pred).cpu().numpy()
            preds = (probs > threshold).astype(int)
            val_preds.extend(preds.flatten())
            val_labels.extend(lbl.cpu().numpy().flatten())
            val_probs.extend(probs.flatten())

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = accuracy_score(val_labels, val_preds)
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_accuracy)

    print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, "
          f"Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")

    # Learning rate scheduling
    scheduler.step(avg_val_loss)

    # Early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), "checkpoints/best_model.pth")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping after {epoch+1} epochs")
            early_stop = True

# Load best model
model.load_state_dict(torch.load("checkpoints/best_model.pth"))

# ----------------------------------------
# Final Evaluation
# ----------------------------------------
model.eval()
with torch.no_grad():
    val_logits = model(Xf_val_tensor.to(device), Xt_val_tensor.to(device))
    val_probs = torch.sigmoid(val_logits).cpu().numpy()
    val_preds = (val_probs > threshold).astype(int)

# Debug shapes before confusion matrix
print("✅ Evaluation shapes:")
print(f" - y_val: {y_val.shape}")
print(f" - val_preds: {val_preds.shape}")
print(f" - val_probs: {val_probs.shape}")

# Compute metrics
acc = accuracy_score(y_val, val_preds)
auc = roc_auc_score(y_val, val_probs)
if len(y_val) > 0 and len(val_preds) > 0 and len(np.unique(y_val)) > 1:
    cm = confusion_matrix(y_val, val_preds)
    print("✅ Confusion Matrix:\n", cm)
else:
    print("⚠️ Warning: Cannot compute confusion matrix due to invalid y_val or val_preds")
    cm = None

report = classification_report(y_val, val_preds, target_names=["False Positive", "Confirmed"], zero_division=0)

print(f"\n✅ Accuracy: {acc:.4f}")
print(f"✅ AUC: {auc:.4f}")
print("📋 Classification Report:\n", report)

# ----------------------------------------
# Visualization Functions
# ----------------------------------------
def plot_loss_curve(train_losses, val_losses):
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', marker='o')
    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss Over Epochs')
    plt.legend()
    plt.grid(True)
    plt.savefig("plots/loss_curve.png")
    plt.show()

def plot_accuracy_curve(train_accuracies, val_accuracies):
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Training Accuracy', marker='o')
    plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training and Validation Accuracy Over Epochs')
    plt.legend()
    plt.grid(True)
    plt.savefig("plots/accuracy_curve.png")
    plt.show()

def plot_confusion_matrix(cm):
    if cm is None or cm.size == 0:
        print("⚠️ Warning: Cannot plot confusion matrix - invalid or empty matrix")
        return
    plt.figure(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=["False Positive", "Confirmed"],
                yticklabels=["False Positive", "Confirmed"])
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix - Gated Multimodal Fusion")
    plt.savefig("plots/confusion_matrix.png")
    plt.show()

def plot_roc_curve(y_true, y_probs):
    if len(y_true) == 0 or len(y_probs) == 0 or len(np.unique(y_true)) <= 1:
        print("⚠️ Warning: Cannot plot ROC curve - invalid or insufficient data")
        return
    fpr, tpr, _ = roc_curve(y_true, y_probs)
    auc_score = roc_auc_score(y_true, y_probs)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, label=f'ROC Curve (AUC = {auc_score:.4f})')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend()
    plt.grid(True)
    plt.savefig("plots/roc_curve.png")
    plt.show()

def plot_precision_recall_curve(y_true, y_probs):
    if len(y_true) == 0 or len(y_probs) == 0 or len(np.unique(y_true)) <= 1:
        print("⚠️ Warning: Cannot plot PR curve - invalid or insufficient data")
        return
    precision, recall, _ = precision_recall_curve(y_true, y_probs)
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, label='Precision-Recall Curve')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend()
    plt.grid(True)
    plt.savefig("plots/precision_recall_curve.png")
    plt.show()

# ----------------------------------------
# Generate Visualizations
# ----------------------------------------
plot_loss_curve(train_losses, val_losses)
plot_accuracy_curve(train_accuracies, val_accuracies)
plot_confusion_matrix(cm)
plot_roc_curve(y_val, val_probs)
plot_precision_recall_curve(y_val, val_probs)

print("✅ Visualizations saved in 'plots' directory:")
print(" - loss_curve.png")
print(" - accuracy_curve.png")
print(" - confusion_matrix.png")
print(" - roc_curve.png")
print(" - precision_recall_curve.png")

✅ Data shapes:
 - X_flux: (829, 2000)
 - X_tabular: (829, 5)
 - y: (829,)
 - flux_embeddings: (829, 64)
✅ Split shapes:
 - Xf_train: (663, 64), Xf_val: (166, 64)
 - Xt_train: (663, 5), Xt_val: (166, 5)
 - y_train: (663,), y_val: (166,)
✅ Using device: cpu




RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x64 and 16x128)