# In [2]:
import torch
import torch.nn as nn
# Sanity-check the installed PyTorch build and its attention module.
# NOTE: the class is spelled ``nn.MultiheadAttention`` (lowercase "h").
# ``nn.MultiHeadAttention`` does not exist, so probing for that spelling
# always prints False and construction always raises AttributeError.
print(torch.__version__)
print(hasattr(nn, 'MultiheadAttention'))
try:
    attn = nn.MultiheadAttention(embed_dim=128, num_heads=8)
    print("Successfully created MultiHeadAttention")
except AttributeError as e:
    print(f"Error: {e}")
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report

# ----------------------------------------
# Load preprocessed data
# ----------------------------------------
# Load the preprocessed arrays. All four are indexed with the same boolean
# mask below, so they must share their first (sample) axis.
X_flux = np.load("X_flux_aligned.npy")
X_tabular = np.load("X_tabular.npy")
y = np.load("y.npy")
flux_embeddings = np.load("flux_embeddings.npy")

# Filter labeled samples (rows whose label is -1 are dropped)
mask = y != -1
X_flux = X_flux[mask]
X_tabular = X_tabular[mask]
y = y[mask]
flux_embeddings = flux_embeddings[mask]

# Split: stratified 80/20 with a fixed seed so the split is reproducible
Xf_train, Xf_val, Xt_train, Xt_val, y_train, y_val = train_test_split(
    flux_embeddings, X_tabular, y, test_size=0.2, stratify=y, random_state=42
)

# Convert to tensors; labels get a trailing axis so they match (batch, 1) logits
Xf_train_tensor = torch.tensor(Xf_train, dtype=torch.float32)
Xt_train_tensor = torch.tensor(Xt_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)

Xf_val_tensor = torch.tensor(Xf_val, dtype=torch.float32)
Xt_val_tensor = torch.tensor(Xt_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32).unsqueeze(1)

BATCH_SIZE = 32
train_dataset = TensorDataset(Xf_train_tensor, Xt_train_tensor, y_train_tensor)

# The model below uses BatchNorm1d, which raises in training mode when a
# batch contains a single sample. Drop the trailing batch only when it would
# hold exactly one sample; every other dataset size behaves as before.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=(len(train_dataset) % BATCH_SIZE == 1),
)

# ----------------------------------------
# Focal Loss implementation
# ----------------------------------------
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Each element's BCE-with-logits loss is scaled by
    ``alpha * (1 - pt) ** gamma`` with ``pt = exp(-bce)``; the mean over all
    elements is returned.
    """

    def __init__(self, alpha=1.0, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma
        # Unreduced, so every element can be re-weighted before averaging.
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        """Return the scalar mean focal loss of logits ``inputs`` vs ``targets``."""
        per_element = self.bce(inputs, targets)
        # exp(-bce) recovers the probability assigned to the true class.
        confidence = torch.exp(-per_element)
        modulator = torch.pow(1 - confidence, self.gamma)
        return torch.mean(self.alpha * modulator * per_element)

# ----------------------------------------
# Positional Encoding
# ----------------------------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    Args:
        d_model: Size of the last (feature) dimension of the input.
        max_len: Number of positions for which encodings are precomputed.

    The table is kept in a non-persistent buffer ``pe`` of shape
    ``(1, max_len, d_model)``. As a buffer it follows the module through
    ``.to()`` / ``.cuda()``; being non-persistent it adds no key to
    ``state_dict()``.
    """

    def __init__(self, d_model, max_len=2000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * float(-np.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer cosine column than sine columns, so
        # trim div_term to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed as ``(batch, seq, d_model)``.
        """
        # .to() is a no-op once the buffer already lives on x's device.
        return x + self.pe[:, : x.size(1), :].to(x.device)
# ----------------------------------------
# Gating Fusion Model
# ----------------------------------------
class GatedMultimodalFusion(nn.Module):
    """Fuse a transformer-encoded flux input with tabular features via a gate.

    Both branches produce a 64-d feature vector per sample. A sigmoid gate
    computed from their concatenation blends them, and a small MLP maps the
    blend to a single logit.

    Args:
        seq_len: Maximum flux sequence length covered by the positional
            encoding.
        flux_dim: Size of the last dimension of the flux input.
        tabular_dim: Number of tabular features per sample.
    """

    def __init__(self,seq_len=2000, flux_dim=64, tabular_dim=5):
        super().__init__()
        self.input_proj = nn.Linear(flux_dim, 128)
        self.pos_encoder = PositionalEncoding(128, seq_len)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=128, nhead=8, batch_first=True),
            num_layers=4
        )
        # Projects the pooled 128-d transformer output into the 64-d space
        # shared with the tabular branch, as the gate and classifier expect.
        self.flux_out = nn.Linear(128, 64)
        self.tabular_branch = nn.Sequential(
            nn.Linear(tabular_dim, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.3)
        )
        self.gate = nn.Sequential(
            nn.Linear(64 * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
        self.classifier = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(32, 1)
        )

    def flux_branch(self, flux_x):
        """Encode the flux input into a ``(batch, 64)`` feature tensor.

        Accepts ``(batch, seq, flux_dim)`` sequences, or ``(batch, flux_dim)``
        vectors, which are treated as sequences of length one.

        Raises:
            ValueError: If ``flux_x`` is neither 2-D nor 3-D.
        """
        if flux_x.dim() == 2:
            flux_x = flux_x.unsqueeze(1)
        elif flux_x.dim() != 3:
            raise ValueError(
                f"flux_x must be 2-D or 3-D, got {flux_x.dim()} dimensions"
            )
        hidden = self.input_proj(flux_x)       # (batch, seq, 128)
        hidden = self.pos_encoder(hidden)
        hidden = self.transformer(hidden)
        pooled = hidden.mean(dim=1)            # average over the sequence axis
        return self.flux_out(pooled)           # (batch, 64)

    def forward(self, flux_x, tabular_x):
        """Return one logit per sample, shape ``(batch, 1)``."""
        flux_feat = self.flux_branch(flux_x)   # (batch, 64)
        tab_feat = self.tabular_branch(tabular_x)  # (batch, 64)
        fused_input = torch.cat((flux_feat, tab_feat), dim=1)
        gate = self.gate(fused_input)          # (batch, 1)
        # gate -> 1 favors the flux features, gate -> 0 the tabular features.
        combined = gate * flux_feat + (1 - gate) * tab_feat
        return self.classifier(combined)

# ----------------------------------------
# Training Setup
# ----------------------------------------
# Single source of truth for the epoch count (used by the loop and the log line).
NUM_EPOCHS = 30

# Size the input layers from the training tensors instead of relying on the
# constructor defaults, so a change in feature width does not break the model.
model = GatedMultimodalFusion(
    flux_dim=Xf_train_tensor.shape[-1],
    tabular_dim=Xt_train_tensor.shape[-1],
)
criterion = FocalLoss(alpha=1.0, gamma=2.0)  # Replaces BCEWithLogitsLoss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ----------------------------------------
# Training Loop
# ----------------------------------------
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for fx, tx, lbl in train_loader:
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        pred = model(fx, tx)
        loss = criterion(pred, lbl)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the mean per-batch loss for this epoch.
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {total_loss / len(train_loader):.4f}")

# ----------------------------------------
# Evaluation
# ----------------------------------------
# Switch to inference behavior and score the whole validation set in one
# pass, without tracking gradients.
model.eval()
with torch.no_grad():
    val_logits = model(Xf_val_tensor, Xt_val_tensor)
    val_probs = torch.sigmoid(val_logits).numpy()

# 🎯 Adjust threshold for better precision/accuracy
threshold = 0.6
val_preds = np.greater(val_probs, threshold).astype(int)

# Metrics: convert the labels once and reuse them for every score.
y_val_np = y_val_tensor.numpy()
class_names = ["False Positive", "Confirmed"]

acc = accuracy_score(y_val_np, val_preds)
auc = roc_auc_score(y_val_np, val_probs)  # AUC uses probabilities, not thresholded labels
cm = confusion_matrix(y_val_np, val_preds)
report = classification_report(y_val_np, val_preds, target_names=class_names)

print(f"\n✅ Accuracy: {acc:.4f}")
print(f"✅ AUC: {auc:.4f}")
print("🧾 Confusion Matrix:\n", cm)
print("📋 Classification Report:\n", report)


# RuntimeError: Trying to override a python impl for DispatchKey.Meta on operator aten::broadcast_tensors