In [7]:
import torch
import torch.nn as nn
import numpy as np

# --- Device Setup ---
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# --- Model Definitions ---
class TransformerEncoder(nn.Module):
    """Encode a univariate sequence into a single fixed-size vector.

    Every scalar time step is projected to ``d_model`` features, a learned
    positional embedding is added, the sequence runs through a stack of
    ``nn.TransformerEncoderLayer`` blocks, and the result is averaged over
    the time axis.

    Args:
        seq_len: Number of positions held by the learned positional
            embedding, i.e. the longest sequence ``forward`` accepts.
        d_model: Feature width used inside the encoder and for the output.
        nhead: Number of attention heads in each encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, seq_len=512, d_model=128, nhead=4, num_layers=4):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.randn(1, seq_len, d_model))
        self.input_proj = nn.Linear(1, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        """Return one pooled embedding per sequence.

        Args:
            x: Tensor of shape ``(B, T, 1)`` with ``T <= seq_len``.

        Returns:
            Tensor of shape ``(B, d_model)``.

        Raises:
            ValueError: If ``T`` is larger than the positional embedding.
        """
        steps = x.size(1)
        max_steps = self.pos_embedding.size(1)
        if steps > max_steps:
            raise ValueError(
                f"sequence length {steps} exceeds the positional embedding "
                f"length {max_steps}"
            )

        # Only use the first `steps` positions. Adding the whole embedding
        # would fail for shorter inputs and would silently broadcast a
        # length-1 input up to `seq_len` tokens.
        x = self.input_proj(x) + self.pos_embedding[:, :steps]
        x = self.encoder(x)  # (B, T, d_model)
        x = x.permute(0, 2, 1)  # (B, d_model, T) -- pooling works on the last axis
        x = self.pool(x).squeeze(-1)  # (B, d_model)
        return x

class TabularEncoder(nn.Module):
    """Embed a tabular feature vector with a gated projection.

    The input is sent through two parallel linear layers: one followed by
    GELU, the other by a sigmoid. The two results are multiplied
    element-wise.

    Args:
        input_dim: Number of tabular features per sample.
        hidden_dim: Width of the produced embedding.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.gelu = nn.GELU()
        self.proj = nn.Linear(input_dim, hidden_dim)
        self.gate = nn.Linear(input_dim, hidden_dim)

    def forward(self, x):
        """Map ``x`` of shape ``(B, input_dim)`` to ``(B, hidden_dim)``."""
        activated = self.gelu(self.proj(x))
        # The sigmoid keeps each gate value in (0, 1), scaling every feature.
        gate_weights = self.gate(x).sigmoid()
        return activated * gate_weights

class FusionBlock(nn.Module):
    """Combine a flux embedding and a tabular embedding into one vector.

    The flux embedding supplies the query; the tabular embedding supplies
    the key and value. The weighted value is concatenated with the flux
    embedding and passed through a linear layer followed by ``tanh``.

    Args:
        d_model: Width of both input embeddings and of the fused output.
    """

    def __init__(self, d_model):
        super().__init__()
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.gate = nn.Linear(2 * d_model, d_model)

    def forward(self, flux_embed, tabular_embed):
        """Fuse two ``(B, d_model)`` embeddings into one ``(B, d_model)`` tensor."""
        query = self.query_proj(flux_embed)
        key = self.key_proj(tabular_embed)
        value = self.value_proj(tabular_embed)

        # Dot product of query and key, kept as a trailing singleton axis.
        similarity = torch.sum(query * key, dim=-1, keepdim=True)
        # NOTE(review): for (B, d_model) inputs `similarity` is (B, 1), so
        # this softmax normalises over a single element and every finite
        # score becomes a weight of exactly 1 -- `weighted_value` then equals
        # `value`. Confirm the intended attention formulation before relying
        # on the query/key projections.
        weights = similarity.softmax(dim=1)
        weighted_value = weights * value

        joined = torch.cat((flux_embed, weighted_value), dim=-1)
        return self.gate(joined).tanh()

class ClassifierHead(nn.Module):
    """Two-layer MLP that turns a feature vector into a single raw score.

    Layout: ``Linear(input_dim, 64) -> ReLU -> Dropout(0.3) -> Linear(64, 1)``.
    No sigmoid is applied here, so the output is an unnormalised value.

    Args:
        input_dim: Width of the incoming feature vector.
    """

    def __init__(self, input_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(64, 1),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape ``(B, input_dim)`` to a ``(B, 1)`` tensor."""
        scores = self.classifier(x)
        return scores

# --- Combined Model Wrapper ---
class MultimodalModel(nn.Module):
    """End-to-end classifier over a flux sequence plus tabular features.

    Sub-modules are registered as ``flux_encoder``, ``tabular_encoder``,
    ``fusion`` and ``classifier``; those attribute names are the prefixes of
    this model's state-dict keys.

    Args:
        flux_dim: Sequence length handed to the flux encoder as ``seq_len``.
        tabular_dim: Number of tabular features per sample.
        d_model: Embedding width shared by both encoders, the fusion block
            and the classifier input.
    """

    def __init__(self, flux_dim=512, tabular_dim=5, d_model=128):
        super().__init__()
        self.flux_encoder = TransformerEncoder(seq_len=flux_dim, d_model=d_model)
        self.tabular_encoder = TabularEncoder(input_dim=tabular_dim, hidden_dim=d_model)
        self.fusion = FusionBlock(d_model=d_model)
        self.classifier = ClassifierHead(input_dim=d_model)

    def forward(self, flux_input, tabular_input):
        """Encode both inputs, fuse them, and return the classifier output."""
        flux_vec = self.flux_encoder(flux_input)
        tabular_vec = self.tabular_encoder(tabular_input)
        return self.classifier(self.fusion(flux_vec, tabular_vec))

# --- Load Saved Model ---
# Build the network on the target device, restore the checkpoint weights
# into it, and switch to evaluation mode for inference.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
# NOTE(review): the recorded cell output shows load_state_dict raising
# RuntimeError here. The checkpoint's keys (input_proj.*,
# transformer.layers.*, tabular_mlp.*, cross_attention.*, gate.0.*,
# classifier.0/3.*) do not match this module tree (flux_encoder.*,
# tabular_encoder.*, fusion.*, classifier.classifier.*), so
# best_model_run0.pt appears to come from a different model definition --
# confirm which class produced it before using this cell.
model = MultimodalModel().to(device)
model.load_state_dict(torch.load("best_model_run0.pt", map_location=device))
model.eval()

# --- Inference Function ---
def predict_exoplanet(flux_array, tabular_array):
    """Classify a single sample with the module-level ``model``.

    Args:
        flux_array: Array-like flux values; a batch axis is prepended and a
            feature axis appended (a length-512 1-D input becomes
            ``(1, 512, 1)``).
        tabular_array: Array-like tabular features; a batch axis is
            prepended (a length-5 1-D input becomes ``(1, 5)``).

    Returns:
        A ``(label, prob)`` tuple. ``prob`` is the sigmoid of the model
        output as a Python float; ``label`` is 1 when ``prob > 0.5``,
        otherwise 0.
    """
    flux_batch = torch.tensor(flux_array, dtype=torch.float32)
    flux_batch = flux_batch.unsqueeze(0).unsqueeze(-1).to(device)
    tabular_batch = torch.tensor(tabular_array, dtype=torch.float32)
    tabular_batch = tabular_batch.unsqueeze(0).to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        prob = torch.sigmoid(model(flux_batch, tabular_batch)).item()

    label = 1 if prob > 0.5 else 0
    return label, prob

# --- Example Usage ---
# Replace with real inputs when available
# Placeholder light curve: 512 uniform random values in [0, 1).
example_flux = np.random.rand(512)
# Tabular features, in order: Teff, logg, Fe/H, Radius, Kmag.
example_tabular = np.array([5778, 4.44, 0.02, 1.0, 13.5])

label, probability = predict_exoplanet(example_flux, example_tabular)
verdict = 'Exoplanet' if label == 1 else 'Non-Exoplanet'
print(f"Prediction: {verdict} | Probability: {probability:.4f}")


RuntimeError: Error(s) in loading state_dict for MultimodalModel:
	Missing key(s) in state_dict: "flux_encoder.pos_embedding", "flux_encoder.input_proj.weight", "flux_encoder.input_proj.bias", "flux_encoder.encoder.layers.0.self_attn.in_proj_weight", "flux_encoder.encoder.layers.0.self_attn.in_proj_bias", "flux_encoder.encoder.layers.0.self_attn.out_proj.weight", "flux_encoder.encoder.layers.0.self_attn.out_proj.bias", "flux_encoder.encoder.layers.0.linear1.weight", "flux_encoder.encoder.layers.0.linear1.bias", "flux_encoder.encoder.layers.0.linear2.weight", "flux_encoder.encoder.layers.0.linear2.bias", "flux_encoder.encoder.layers.0.norm1.weight", "flux_encoder.encoder.layers.0.norm1.bias", "flux_encoder.encoder.layers.0.norm2.weight", "flux_encoder.encoder.layers.0.norm2.bias", "flux_encoder.encoder.layers.1.self_attn.in_proj_weight", "flux_encoder.encoder.layers.1.self_attn.in_proj_bias", "flux_encoder.encoder.layers.1.self_attn.out_proj.weight", "flux_encoder.encoder.layers.1.self_attn.out_proj.bias", "flux_encoder.encoder.layers.1.linear1.weight", "flux_encoder.encoder.layers.1.linear1.bias", "flux_encoder.encoder.layers.1.linear2.weight", "flux_encoder.encoder.layers.1.linear2.bias", "flux_encoder.encoder.layers.1.norm1.weight", "flux_encoder.encoder.layers.1.norm1.bias", "flux_encoder.encoder.layers.1.norm2.weight", "flux_encoder.encoder.layers.1.norm2.bias", "flux_encoder.encoder.layers.2.self_attn.in_proj_weight", "flux_encoder.encoder.layers.2.self_attn.in_proj_bias", "flux_encoder.encoder.layers.2.self_attn.out_proj.weight", "flux_encoder.encoder.layers.2.self_attn.out_proj.bias", "flux_encoder.encoder.layers.2.linear1.weight", "flux_encoder.encoder.layers.2.linear1.bias", "flux_encoder.encoder.layers.2.linear2.weight", "flux_encoder.encoder.layers.2.linear2.bias", "flux_encoder.encoder.layers.2.norm1.weight", "flux_encoder.encoder.layers.2.norm1.bias", "flux_encoder.encoder.layers.2.norm2.weight", "flux_encoder.encoder.layers.2.norm2.bias", "flux_encoder.encoder.layers.3.self_attn.in_proj_weight", 
"flux_encoder.encoder.layers.3.self_attn.in_proj_bias", "flux_encoder.encoder.layers.3.self_attn.out_proj.weight", "flux_encoder.encoder.layers.3.self_attn.out_proj.bias", "flux_encoder.encoder.layers.3.linear1.weight", "flux_encoder.encoder.layers.3.linear1.bias", "flux_encoder.encoder.layers.3.linear2.weight", "flux_encoder.encoder.layers.3.linear2.bias", "flux_encoder.encoder.layers.3.norm1.weight", "flux_encoder.encoder.layers.3.norm1.bias", "flux_encoder.encoder.layers.3.norm2.weight", "flux_encoder.encoder.layers.3.norm2.bias", "tabular_encoder.proj.weight", "tabular_encoder.proj.bias", "tabular_encoder.gate.weight", "tabular_encoder.gate.bias", "fusion.query_proj.weight", "fusion.query_proj.bias", "fusion.key_proj.weight", "fusion.key_proj.bias", "fusion.value_proj.weight", "fusion.value_proj.bias", "fusion.gate.weight", "fusion.gate.bias", "classifier.classifier.0.weight", "classifier.classifier.0.bias", "classifier.classifier.3.weight", "classifier.classifier.3.bias". 
	Unexpected key(s) in state_dict: "input_proj.weight", "input_proj.bias", "transformer.layers.0.self_attn.in_proj_weight", "transformer.layers.0.self_attn.in_proj_bias", "transformer.layers.0.self_attn.out_proj.weight", "transformer.layers.0.self_attn.out_proj.bias", "transformer.layers.0.linear1.weight", "transformer.layers.0.linear1.bias", "transformer.layers.0.linear2.weight", "transformer.layers.0.linear2.bias", "transformer.layers.0.norm1.weight", "transformer.layers.0.norm1.bias", "transformer.layers.0.norm2.weight", "transformer.layers.0.norm2.bias", "transformer.layers.1.self_attn.in_proj_weight", "transformer.layers.1.self_attn.in_proj_bias", "transformer.layers.1.self_attn.out_proj.weight", "transformer.layers.1.self_attn.out_proj.bias", "transformer.layers.1.linear1.weight", "transformer.layers.1.linear1.bias", "transformer.layers.1.linear2.weight", "transformer.layers.1.linear2.bias", "transformer.layers.1.norm1.weight", "transformer.layers.1.norm1.bias", "transformer.layers.1.norm2.weight", "transformer.layers.1.norm2.bias", "transformer.layers.2.self_attn.in_proj_weight", "transformer.layers.2.self_attn.in_proj_bias", "transformer.layers.2.self_attn.out_proj.weight", "transformer.layers.2.self_attn.out_proj.bias", "transformer.layers.2.linear1.weight", "transformer.layers.2.linear1.bias", "transformer.layers.2.linear2.weight", "transformer.layers.2.linear2.bias", "transformer.layers.2.norm1.weight", "transformer.layers.2.norm1.bias", "transformer.layers.2.norm2.weight", "transformer.layers.2.norm2.bias", "transformer.layers.3.self_attn.in_proj_weight", "transformer.layers.3.self_attn.in_proj_bias", "transformer.layers.3.self_attn.out_proj.weight", "transformer.layers.3.self_attn.out_proj.bias", "transformer.layers.3.linear1.weight", "transformer.layers.3.linear1.bias", "transformer.layers.3.linear2.weight", "transformer.layers.3.linear2.bias", "transformer.layers.3.norm1.weight", "transformer.layers.3.norm1.bias", "transformer.layers.3.norm2.weight", 
"transformer.layers.3.norm2.bias", "tabular_mlp.0.weight", "tabular_mlp.0.bias", "tabular_mlp.1.weight", "tabular_mlp.1.bias", "tabular_mlp.1.running_mean", "tabular_mlp.1.running_var", "tabular_mlp.1.num_batches_tracked", "tabular_mlp.4.weight", "tabular_mlp.4.bias", "cross_attention.in_proj_weight", "cross_attention.in_proj_bias", "cross_attention.out_proj.weight", "cross_attention.out_proj.bias", "gate.0.weight", "gate.0.bias", "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias". 