In [None]:
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

In [None]:
# Input locations: genotype embeddings (CSV, one row per subject) and the root
# folder of per-subject MRI embedding .npy files.
GENO_PATH = "/kaggle/input/genotype-embeddings/genotype_embeddings_64_EDdyg.csv"
MRI_DIR = "/kaggle/input/mri-embeddings-gmu/kaggle/working/MRI_Embeddings"

In [None]:
def prepare_data(geno_csv_path, mri_df):
    """Align genotype features with MRI embeddings by subject and encode labels.

    Args:
        geno_csv_path: path to a CSV whose first column is the subject ID and
            whose feature columns are named ``g_emb_<k>``.
        mri_df: DataFrame with ``subject_id``, ``embedding`` and ``label``
            columns. It is copied, never modified in place.

    Returns:
        ``(X_geno, X_mri, y, label_encoder)``:
        ``X_geno`` holds the ``g_emb_*`` columns in numeric-suffix order
        (float32); ``X_mri`` stacks one flattened float32 embedding per merged
        row; ``y`` is the integer-encoded ``label`` column; ``label_encoder`` is
        the fitted ``LabelEncoder``. Row i of every array is the same subject.

    Raises:
        ValueError: if the CSV has no ``g_emb_*`` columns, or the two sources
            share no subject IDs.
    """
    df_geno = pd.read_csv(geno_csv_path)
    # Strip a UTF-8 BOM and surrounding whitespace from the header names.
    df_geno.columns = (
        df_geno.columns
        .str.replace('\ufeff', '', regex=False)
        .str.strip()
    )
    df_geno = df_geno.rename(columns={df_geno.columns[0]: "subject_id"})
    # pandas parses purely numeric IDs as int64, while MRI IDs come from file
    # names (strings); merging int64 with object keys raises, so normalise the
    # key to stripped strings on both sides.
    df_geno["subject_id"] = df_geno["subject_id"].astype(str).str.strip()

    geno_feature_cols = sorted(
        [c for c in df_geno.columns if c.startswith("g_emb_")],
        key=lambda x: int(x.split("_")[-1])
    )
    if not geno_feature_cols:
        raise ValueError(f"No 'g_emb_*' feature columns found in {geno_csv_path}")

    mri_df = mri_df.copy()
    mri_df["subject_id"] = (
        mri_df["subject_id"]
        .astype(str)
        .str.replace("_embedding", "", regex=False)
        .str.strip()
    )
    # Flatten every embedding so each subject contributes exactly one row to
    # X_mri, whatever the saved array shape was (e.g. (D,), (1, D) or (D, 1)).
    mri_df["embedding"] = mri_df["embedding"].apply(
        lambda x: np.asarray(x, dtype=np.float32).ravel()
    )

    df_merged = pd.merge(df_geno, mri_df, on="subject_id", how="inner")
    if df_merged.empty:
        raise ValueError(
            "No overlapping subject IDs between the genotype CSV and the MRI embeddings"
        )

    X_geno = df_merged[geno_feature_cols].values.astype(np.float32)
    X_mri = np.vstack(df_merged["embedding"].values)

    label_encoder = LabelEncoder()
    y = label_encoder.fit_transform(df_merged["label"])

    return X_geno, X_mri, y, label_encoder


In [None]:
class MultimodalDataset(Dataset):
    """Map-style dataset pairing genotype features, MRI embeddings and labels by row.

    Args:
        x_geno: array-like of genotype features; row i belongs to sample i.
        x_mri: array-like of MRI embeddings; row i belongs to sample i.
        y: array-like of integer class labels.

    Items are ``(geno, mri, label)`` tuples of float32, float32 and int64 tensors.

    Raises:
        ValueError: if the three inputs do not have the same number of rows.
    """

    def __init__(self, x_geno, x_mri, y):
        self.x_geno = torch.tensor(x_geno, dtype=torch.float32)
        self.x_mri = torch.tensor(x_mri, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
        # __len__ is driven by y alone, so a mismatch would otherwise surface
        # only as an IndexError (or silently ignored rows) in the middle of an epoch.
        if not (len(self.x_geno) == len(self.x_mri) == len(self.y)):
            raise ValueError(
                f"Length mismatch: x_geno={len(self.x_geno)}, "
                f"x_mri={len(self.x_mri)}, y={len(self.y)}"
            )

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.x_geno[idx], self.x_mri[idx], self.y[idx]

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class GatedMultimodalUnit(nn.Module):
    """Gated Multimodal Unit fusing an MRI vector with a SNP vector.

    Each modality is projected to ``hidden_dim`` by a tanh-activated linear
    layer. A sigmoid gate computed from the raw concatenated inputs mixes the
    two projections element-wise:

        z = sigmoid(W_gate [x_mri; x_snp])
        h = z * tanh(W_mri x_mri) + (1 - z) * tanh(W_snp x_snp)

    Args:
        input_dim_a: feature size of the MRI input (first ``forward`` argument).
        input_dim_b: feature size of the SNP input (second ``forward`` argument).
        hidden_dim: size of the fused output.
    """

    def __init__(self, input_dim_a, input_dim_b, hidden_dim):
        super().__init__()
        # Attribute names and creation order are kept: they define the
        # state_dict keys and the order of random draws for initialisation.
        self.fc_mri = nn.Linear(input_dim_a, hidden_dim)
        self.fc_snp = nn.Linear(input_dim_b, hidden_dim)
        self.fc_gate = nn.Linear(input_dim_a + input_dim_b, hidden_dim)

    def forward(self, x_mri, x_snp):
        """Fuse the two modalities; the gate sees both inputs joined along dim 1."""
        gate = torch.sigmoid(self.fc_gate(torch.cat([x_mri, x_snp], dim=1)))
        mri_proj = torch.tanh(self.fc_mri(x_mri))
        snp_proj = torch.tanh(self.fc_snp(x_snp))
        return gate * mri_proj + (1 - gate) * snp_proj



In [None]:
class MedicalFusionClassifier(nn.Module):
    """GMU fusion of MRI and SNP inputs followed by a small MLP classification head.

    Args:
        mri_dim: feature size of the MRI input.
        snp_dim: feature size of the SNP/genotype input.
        hidden_dim: size of the fused GMU representation.
        num_classes: number of output logits.
        head_dim: width of the hidden layer in the classification head (default 64).
        dropout: dropout probability applied to the normalised fused vector
            (default 0.5).

    Note: the head starts with ``BatchNorm1d``, which in training mode needs
    more than one sample per batch.
    """

    def __init__(self, mri_dim, snp_dim, hidden_dim, num_classes, head_dim=64, dropout=0.5):
        super().__init__()

        self.gmu = GatedMultimodalUnit(mri_dim, snp_dim, hidden_dim)
        # Keep the layer order stable: it fixes the state_dict keys
        # (classifier.0 ... classifier.4), so existing checkpoints still load.
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, head_dim),
            nn.ReLU(),
            nn.Linear(head_dim, num_classes),
        )

    def forward(self, x_mri, x_snp):
        """Return unnormalised class logits (last dimension is ``num_classes``)."""
        fused_vector = self.gmu(x_mri, x_snp)
        return self.classifier(fused_vector)

In [None]:
# Seed every RNG the pipeline uses before any weights are initialised or any
# DataLoader shuffles, so Restart & Run All reproduces the same training run.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def build_mri_df(MRI_DIR):
    """Collect per-subject MRI embeddings stored as ``.npy`` files under ``MRI_DIR``.

    Every ``*.npy`` file found recursively becomes one row. The subject ID is
    the file name with ``"_embedding.npy"`` removed. The label is the name of
    the directory that directly contains the file.

    Directory and file names are visited in sorted order. ``os.walk`` otherwise
    yields entries in filesystem order, which can differ between machines and
    make the row order (and anything derived from it) non-reproducible.

    Args:
        MRI_DIR: root directory to search.

    Returns:
        pandas.DataFrame with columns ``subject_id``, ``label`` and
        ``embedding`` (the array returned by ``np.load``).

    Raises:
        ValueError: if no ``.npy`` file is found.
    """
    records = []

    for root, dirs, files in os.walk(MRI_DIR):
        # Sorting dirs in place steers os.walk's top-down descent order.
        dirs.sort()
        label = os.path.basename(root)
        for f in sorted(files):
            if not f.endswith(".npy"):
                continue
            records.append({
                "subject_id": f.replace("_embedding.npy", ""),
                "label": label,
                "embedding": np.load(os.path.join(root, f)),
            })

    if not records:
        raise ValueError("No MRI embedding files found")

    return pd.DataFrame(records)

In [None]:
# Scan the MRI folder once, then show a short preview and the overall size.
mri_df = build_mri_df(MRI_DIR)
for summary in (mri_df.head(), mri_df.shape):
    print(summary)


In [None]:
# Class balance of the MRI set (the row preview is printed in the cell above;
# dumping the full frame would render one embedding array per row).
mri_df["label"].value_counts()

In [None]:
# Preview the raw genotype CSV. X_geno itself only exists after prepare_data()
# runs in the next cell; referencing it here fails on a fresh kernel.
pd.read_csv(GENO_PATH, nrows=5)

In [None]:
X_geno, X_mri, y, label_encoder = prepare_data(GENO_PATH, mri_df)

# Every modality must contribute exactly one row per merged subject.
assert len(X_geno) == len(X_mri) == len(y) > 0, "Modalities are misaligned or empty after merge"


In [None]:
# Only a cell's last expression is displayed, so show both shapes as one tuple.
X_geno.shape, X_mri.shape

In [None]:
# Stratified 80/20 split so both sides keep the class proportions of y.
split = train_test_split(
    X_geno,
    X_mri,
    y,
    test_size=0.2,
    stratify=y,
    random_state=42,
)
Xg_train, Xg_test, Xm_train, Xm_test, y_train, y_test = split

In [None]:
BATCH_SIZE = 32

train_ds = MultimodalDataset(Xg_train, Xm_train, y_train)
test_ds = MultimodalDataset(Xg_test, Xm_test, y_test)

# The model starts its head with BatchNorm1d, which raises in train mode on a
# batch of one sample. Drop the trailing batch only when the training set
# size would leave exactly one sample over.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=len(train_ds) % BATCH_SIZE == 1,
)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Derive the input widths and class count from the prepared data instead of
# hard-coding them (previously 512 / 64 / 3), so a changed embedding size or
# label set cannot silently mis-configure the model.
model = MedicalFusionClassifier(
    mri_dim=X_mri.shape[1],
    snp_dim=X_geno.shape[1],
    hidden_dim=128,
    num_classes=len(label_encoder.classes_),
).to(device)

In [None]:
# Adam with an L2 penalty (weight_decay), optimising cross-entropy on raw logits.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()


In [None]:
# Training budget, and the best held-out accuracy (in percent) seen so far.
EPOCHS, best_acc = 500, 0.0

In [None]:
for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for geno, mri, labels in train_loader:
        geno, mri, labels = geno.to(device), mri.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(mri, geno)  # model takes (x_mri, x_snp)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        n_batch = labels.size(0)
        # criterion returns the batch mean; weight it by batch size so the
        # epoch loss is a true per-sample mean even when the last batch is smaller.
        running_loss += loss.item() * n_batch
        total += n_batch
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    train_loss = running_loss / total
    train_acc = 100 * correct / total

    # ---- evaluation pass ----
    # NOTE(review): the held-out test split also drives checkpoint selection,
    # so best_acc is an optimistic estimate; a separate validation split would fix that.
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for geno, mri, labels in test_loader:
            geno, mri, labels = geno.to(device), mri.to(device), labels.to(device)
            outputs = model(mri, geno)
            val_total += labels.size(0)
            val_correct += (outputs.argmax(dim=1) == labels).sum().item()

    val_acc = 100 * val_correct / val_total

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_gmu_model.pth")

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")
