<a href="https://colab.research.google.com/github/Jupiterian/ATP-Rankings-API/blob/main/Neurosemantics_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch

In [None]:
# Colab-only: mount Google Drive at /content/drive so the data folders
# referenced by the path constants below are readable from this runtime.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import pandas as pd
import os

# Both input folders sit under one ADNI root on the mounted Drive.
ADNI_ROOT = "/content/drive/MyDrive/ADNI"
REMOTE_DIR = f"{ADNI_ROOT}/Remotely_Collected_Data"  # loaded into df_remote (source of the STORY* columns)
MRI_DIR = f"{ADNI_ROOT}/MR_Image_Analysis"  # loaded into df_mri (source of the ST*CV / ST*SV columns)


In [None]:


# NOTE(review): these two assignments repeat, value for value, the constants
# already set in the previous cell. Keep a single definition so the two
# copies cannot drift apart.
REMOTE_DIR = "/content/drive/MyDrive/ADNI/Remotely_Collected_Data"
MRI_DIR = "/content/drive/MyDrive/ADNI/MR_Image_Analysis"

def load_csv_folder(folder):
    """Read every ``.csv`` file directly inside ``folder`` into one DataFrame.

    Files are read in sorted filename order so the row order of the result is
    the same on every run (``os.listdir`` returns entries in arbitrary order).

    Parameters
    ----------
    folder : str or path-like
        Directory to scan (not recursive). Only names ending in ``.csv``
        are read; everything else is ignored.

    Returns
    -------
    pandas.DataFrame
        Row-wise concatenation of all files, re-indexed from 0.

    Raises
    ------
    ValueError
        If ``folder`` contains no ``.csv`` file.
    """
    csv_names = sorted(f for f in os.listdir(folder) if f.endswith(".csv"))
    if not csv_names:
        # Same exception type pd.concat would raise on an empty list, but
        # with a message that names the offending folder.
        raise ValueError(f"No .csv files found in {folder!r}")

    dfs = [
        pd.read_csv(os.path.join(folder, name), low_memory=False)
        for name in csv_names
    ]
    return pd.concat(dfs, ignore_index=True)

# Load both folders (remote first, then MRI) and report each frame's shape.
df_remote, df_mri = (load_csv_folder(path) for path in (REMOTE_DIR, MRI_DIR))

print(df_remote.shape, df_mri.shape)


(14823, 219) (12228, 347)


In [None]:
# Cast the PTID join key to str in both frames so the later merge compares
# like with like.
for frame in (df_remote, df_mri):
    frame["PTID"] = frame["PTID"].astype(str)


In [None]:
# Identifier columns carried along with the speech features.
speech_id_cols = ["PTID", "EXAMDATE", "RMT_Timepoint"]

# Feature columns: every df_remote column whose name begins with STORY1 or STORY2.
story_prefixes = ("STORY1", "STORY2")
speech_feature_cols = [
    col for col in df_remote.columns if col.startswith(story_prefixes)
]

# Work on an independent copy so df_remote itself is left untouched.
selected_cols = speech_id_cols + speech_feature_cols
speech = df_remote[selected_cols].copy()
print("Speech shape:", speech.shape)


Speech shape: (14823, 211)


In [None]:
# Only numeric STORY columns can be averaged; drop the rest from the feature list.
numeric_speech_feature_cols = list(
    speech[speech_feature_cols].select_dtypes(include="number").columns
)

# One row per PTID: each numeric feature is averaged over all of that
# participant's rows (EXAMDATE and RMT_Timepoint are not used here).
per_ptid_speech = speech.groupby("PTID")[numeric_speech_feature_cols]
speech_agg = per_ptid_speech.mean().reset_index()

print("Speech aggregated:", speech_agg.shape)

Speech aggregated: (499, 206)


In [None]:
mri_id_cols = ["PTID", "EXAMDATE"]

# Feature columns: names that start with "ST" and end in "CV" or "SV"
# (presumably volume measures; confirm against the ADNI data dictionary).
mri_suffixes = ("CV", "SV")
mri_feature_cols = [
    col
    for col in df_mri.columns
    if col.startswith("ST") and col.endswith(mri_suffixes)
]

# Independent copy so df_mri itself is left untouched.
mri = df_mri[mri_id_cols + mri_feature_cols].copy()
print("MRI shape:", mri.shape)


MRI shape: (12228, 121)


In [None]:
# One row per PTID: each MRI feature is averaged over all of that
# participant's rows (EXAMDATE is not used here).
per_ptid_mri = mri.groupby("PTID")[mri_feature_cols]
mri_agg = per_ptid_mri.mean().reset_index()

print("MRI aggregated:", mri_agg.shape)


MRI aggregated: (3235, 120)


In [None]:
# df is the fused table: an inner join on PTID keeps only participants that
# appear in both the speech and the MRI aggregates.
df = pd.merge(speech_agg, mri_agg, how="inner", on="PTID")

print("Final fused dataset:", df.shape)
df.head()


Final fused dataset: (457, 325)


Unnamed: 0,PTID,STORY1IMM_PLAYBACK_COUNT,STORY1IMM_AUDIO_LENGTH,STORY1IMM_NUM_PRONOUNS,STORY1IMM_NUM_NOUNS,STORY1IMM_NUM_PROPER_NOUNS,STORY1IMM_NUM_VERBS,STORY1IMM_NUM_ADVERBS,STORY1IMM_NUM_ADJECTIVES,STORY1IMM_NUM_WORDS,...,ST9SV,ST147SV,ST148SV,ST149SV,ST150SV,ST151SV,ST152SV,ST153SV,ST154SV,ST155SV
0,002_S_0413,0.0,53.05,20.0,11.0,2.0,14.5,2.0,3.0,80.5,...,1714.983333,222624.943249,225106.405565,447731.348815,218727.833333,217941.0,436668.833333,48315.166667,588920.904371,973713.1
1,002_S_4213,0.5,47.62,18.5,14.5,2.5,12.5,5.0,3.0,92.5,...,1792.3875,209967.30571,207728.858633,417696.164343,190352.75,186986.25,377339.0,47928.75,556767.664343,886311.9
2,002_S_4799,0.0,41.08,6.0,7.0,2.0,3.0,0.0,1.0,32.0,...,2006.266667,250450.925246,249114.419315,499565.344561,243200.444444,239326.0,482526.444444,56701.0,665996.789005,1089487.0
3,002_S_6007,0.0,27.64,5.0,10.5,1.5,5.5,1.0,0.5,43.0,...,2506.725,220075.740693,223832.554947,443908.29564,213069.875,209948.125,423018.0,52733.5,602375.29564,951692.0
4,002_S_6053,0.0,54.57,14.0,17.5,4.0,13.5,1.5,6.5,87.5,...,2469.4,254526.508163,252594.472827,507120.980989,254505.5,255573.166667,510078.666667,59516.0,682455.980989,1132441.0


In [None]:
# Identifier names to exclude from the feature groups (names that are not
# present in df are simply never matched).
id_cols = ["PTID", "RID", "VISCODE", "EXAMDATE"]

# Speech features: columns whose name starts with STORY or RMT.
speech_cols = [c for c in df.columns if c.startswith(("STORY", "RMT"))]

# MRI features: everything that is neither a speech feature nor an identifier.
# The exclusion set is built once instead of concatenating and scanning two
# lists for every column.
_non_mri_cols = set(speech_cols) | set(id_cols)
mri_cols = [c for c in df.columns if c not in _non_mri_cols]

In [None]:
# Pull each modality out of the fused frame as a plain NumPy matrix
# (rows follow df's row order, columns follow the *_cols lists).
X_speech = df[speech_cols].to_numpy()
X_mri = df[mri_cols].to_numpy()

In [None]:
import numpy as np

def _as_float32_inf_to_nan(values):
    """Return a float32 copy of ``values`` with every +/-inf replaced by NaN."""
    arr = values.astype(np.float32)
    arr[np.isinf(arr)] = np.nan
    return arr


# Float conversion; infinities become NaN so they are imputed like missing values.
X_speech = _as_float32_inf_to_nan(X_speech)
X_mri = _as_float32_inf_to_nan(X_mri)

# Per-column means that ignore NaN.
# NOTE(review): these statistics are computed over every row, including rows
# that the later train/test split assigns to the test set.
speech_means = np.nanmean(X_speech, axis=0)
mri_means = np.nanmean(X_mri, axis=0)

# Fill each NaN with the mean of its own column (means broadcast across rows).
X_speech = np.where(np.isnan(X_speech), speech_means, X_speech)
X_mri = np.where(np.isnan(X_mri), mri_means, X_mri)


In [None]:
# Per-column variance of each modality.
speech_var, mri_var = X_speech.var(axis=0), X_mri.var(axis=0)

# Keep only columns whose variance exceeds 1e-6; (near-)constant columns are
# dropped before scaling.
keep_speech = speech_var > 1e-6
keep_mri = mri_var > 1e-6
X_speech, X_mri = X_speech[:, keep_speech], X_mri[:, keep_mri]

print("Speech dims after var filter:", X_speech.shape[1])
print("MRI dims after var filter:", X_mri.shape[1])


Speech dims after var filter: 199
MRI dims after var filter: 119


In [None]:
from sklearn.preprocessing import StandardScaler

# Standardise every column of each modality with its own scaler.
# NOTE(review): both scalers are fit on all rows, before the train/test split
# made later in the notebook, so test-row statistics leak into the scaling.
# Fitting on the training rows only would avoid that.
speech_scaler = StandardScaler().fit(X_speech)
mri_scaler = StandardScaler().fit(X_mri)

X_speech = speech_scaler.transform(X_speech)
X_mri = mri_scaler.transform(X_mri)


In [None]:
# float32 tensor copies of the scaled matrices, used for the sanity check below.
# NOTE(review): no later cell visible here reads these tensors; training and
# evaluation index the NumPy arrays directly.
speech_tensor = torch.tensor(X_speech, dtype=torch.float32)
mri_tensor = torch.tensor(X_mri, dtype=torch.float32)

def check_tensor(name, x):
    """Print a one-line health summary of tensor ``x`` prefixed by ``name``.

    The line reports the shape, whether any element is NaN, whether any
    element is infinite, and the overall mean and standard deviation.
    """
    has_nan = torch.isnan(x).any().item()
    has_inf = torch.isinf(x).any().item()
    summary = (
        f"{name} shape: {tuple(x.shape)} "
        f"nan: {has_nan} inf: {has_inf} "
        f"mean: {x.mean().item()} std: {x.std().item()}"
    )
    print(summary)

# Sanity-check both modalities: no NaN/inf, mean near 0 and std near 1 after scaling.
for label, tensor in (("speech", speech_tensor), ("mri", mri_tensor)):
    check_tensor(label, tensor)


speech shape: (457, 199) nan: False inf: False mean: 3.355681843508762e-10 std: 1.0000054836273193
mri shape: (457, 119) nan: False inf: False mean: -1.5782631868077601e-09 std: 1.0000091791152954


In [None]:
import torch.nn as nn
import torch.nn.functional as F

# Width of the shared embedding space: the final Linear layer of both encoders
# outputs this many features.
LATENT_DIM = 128

class _MLPEncoder(nn.Module):
    """Shared two-layer MLP used by both modality encoders.

    Architecture: ``Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d -> ReLU``.

    Parameters
    ----------
    input_dim : int
        Number of input features per sample.
    latent_dim : int, optional
        Size of the output embedding. Defaults to the module-level
        ``LATENT_DIM`` (looked up when the encoder is constructed).
    hidden_dim : int, optional
        Width of the hidden layer. Defaults to 256.
    """

    def __init__(self, input_dim, latent_dim=None, hidden_dim=256):
        super().__init__()
        if latent_dim is None:
            latent_dim = LATENT_DIM
        # NOTE(review): the trailing ReLU makes every embedding component
        # non-negative, so the dot product of any two embeddings is >= 0.
        # Consider dropping it if negative similarities are wanted.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
            nn.BatchNorm1d(latent_dim),
            nn.ReLU()
        )

    def forward(self, x):
        """Map a ``(batch, input_dim)`` tensor to ``(batch, latent_dim)`` embeddings."""
        return self.net(x)


class SpeechEncoder(_MLPEncoder):
    """Encoder for the speech feature vectors (see ``_MLPEncoder`` for the layout)."""


class MRIEncoder(_MLPEncoder):
    """Encoder for the MRI feature vectors (see ``_MLPEncoder`` for the layout)."""

# Seed PyTorch's global RNG before building the encoders so their random
# weight initialisation is identical on every Restart & Run All. The value
# matches the random_state used for the train/test split.
torch.manual_seed(42)

# One encoder per modality, sized to the feature count left after filtering.
speech_enc = SpeechEncoder(X_speech.shape[1])
mri_enc = MRIEncoder(X_mri.shape[1])

# Training mode: BatchNorm uses batch statistics and updates its running stats.
speech_enc.train()
mri_enc.train()

MRIEncoder(
  (net): Sequential(
    (0): Linear(in_features=119, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
)

In [None]:
import torch.optim as optim
# A single Adam optimizer updates both encoders jointly.
encoder_parameters = [*speech_enc.parameters(), *mri_enc.parameters()]
opt = torch.optim.Adam(encoder_parameters, lr=3e-4)  # was 1e-3


In [None]:
def clip_loss(z_speech, z_mri, temperature=0.2):
    """Symmetric CLIP-style contrastive loss over a batch of paired embeddings.

    Row ``i`` of ``z_speech`` and row ``i`` of ``z_mri`` are treated as the
    matching pair; every other row in the batch acts as a negative.

    Both inputs are L2-normalised first, so the logits are cosine similarities
    divided by ``temperature`` and are bounded by ``1 / temperature``. Without
    normalisation the logits scale with the embedding norm; the previous
    hard clamp to [-100, 100] then saturated them and zeroed their gradients.

    Parameters
    ----------
    z_speech, z_mri : torch.Tensor
        Embeddings of shape ``(batch, dim)`` with the same ``batch`` and ``dim``.
    temperature : float, optional
        Positive softmax temperature. Defaults to 0.2.

    Returns
    -------
    torch.Tensor
        Scalar: mean of the speech->MRI and MRI->speech cross-entropies.

    Raises
    ------
    ValueError
        If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    z_speech = F.normalize(z_speech, dim=1)
    z_mri = F.normalize(z_mri, dim=1)

    # (batch, batch) similarity matrix; entry [i, j] compares speech i with MRI j.
    logits = (z_speech @ z_mri.T) / temperature

    # The correct match for row i is column i.
    labels = torch.arange(len(z_speech), device=z_speech.device)

    loss_s = F.cross_entropy(logits, labels)    # speech -> MRI
    loss_m = F.cross_entropy(logits.T, labels)  # MRI -> speech
    return (loss_s + loss_m) / 2


In [None]:
from sklearn.model_selection import train_test_split

# Split at the participant level: 20% of the distinct PTIDs are held out.
ptids = df["PTID"].values
unique_ptids = np.unique(ptids)
train_ids, test_ids = train_test_split(
    unique_ptids, test_size=0.2, random_state=42
)

# Boolean row masks over df, and therefore over X_speech / X_mri, which share
# df's row order.
train_mask = np.isin(ptids, train_ids)
test_mask = np.isin(ptids, test_ids)

X_speech_train, X_speech_test = X_speech[train_mask], X_speech[test_mask]
X_mri_train, X_mri_test = X_mri[train_mask], X_mri[test_mask]


In [None]:
from torch.utils.data import Dataset, DataLoader

class PairedDataset(Dataset):
    """Index-aligned (speech, MRI) pairs for contrastive training.

    Item ``i`` is the tuple ``(speech_data[i], mri_data[i])``.

    Parameters
    ----------
    speech_data, mri_data : sequence
        Indexable containers with the same length; row ``i`` of each must
        belong to the same subject.

    Raises
    ------
    ValueError
        If the two containers differ in length, since the pairing would
        otherwise be silently wrong or fail later with an IndexError.
    """

    def __init__(self, speech_data, mri_data):
        if len(speech_data) != len(mri_data):
            raise ValueError(
                "speech_data and mri_data must have the same length, got "
                f"{len(speech_data)} and {len(mri_data)}"
            )
        self.speech_data = speech_data
        self.mri_data = mri_data

    def __len__(self):
        """Number of pairs."""
        return len(self.speech_data)

    def __getitem__(self, idx):
        """Return the ``idx``-th (speech, MRI) pair."""
        return self.speech_data[idx], self.mri_data[idx]

# Training pairs, served in reshuffled mini-batches of 32 each epoch.
train_dataset = PairedDataset(speech_data=X_speech_train, mri_data=X_mri_train)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)

In [None]:
speech_enc.train()
mri_enc.train()

# Built once: the same parameter list is gradient-clipped after every backward pass.
encoder_params = list(speech_enc.parameters()) + list(mri_enc.parameters())

# Mean loss per epoch, read by the plotting cell.
train_losses = []

for epoch in range(100):
    total_loss = 0.0
    # The batch names deliberately differ from the `speech` / `mri` DataFrames
    # created earlier in the notebook, so this loop does not clobber them and
    # the earlier cells stay re-runnable.
    for speech_batch, mri_batch in train_loader:
        # .float() is a no-op for float32 batches and guards against float64
        # input, which the float32 Linear layers would reject.
        z_s = speech_enc(speech_batch.float())
        z_m = mri_enc(mri_batch.float())

        loss = clip_loss(z_s, z_m)

        opt.zero_grad()
        loss.backward()
        # Cap the global gradient norm at 1.0 to keep updates bounded.
        torch.nn.utils.clip_grad_norm_(encoder_params, 1.0)
        opt.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)

    if epoch % 10 == 0:
        print(f"Epoch {epoch}: loss={avg_loss:.4f}")

Epoch 0: loss=0.0045
Epoch 10: loss=0.0460


KeyboardInterrupt: 

In [None]:
# Plot losses
import matplotlib.pyplot as plt

# One curve: mean contrastive loss per epoch.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(train_losses, label="Training Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("CLIP Loss")
ax.set_title("Contrastive Training Loss")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Held-out retrieval: for each test speech vector, rank all test MRI vectors
# and check whether the same participant's MRI comes first / within the top 5.
speech_enc.eval()
mri_enc.eval()

with torch.no_grad():
    # L2-normalise so ranking uses cosine similarity. With raw dot products an
    # MRI embedding with a large norm scores highly against every query.
    Zs = F.normalize(speech_enc(torch.tensor(X_speech_test, dtype=torch.float32)), dim=1)
    Zm = F.normalize(mri_enc(torch.tensor(X_mri_test, dtype=torch.float32)), dim=1)

    sim = Zs @ Zm.T

    # Row i's correct match is column i (both matrices share the row order).
    targets = torch.arange(len(sim))
    # topk raises if k exceeds the number of candidates, so cap it.
    k = min(5, sim.shape[1])

    top1 = (sim.argmax(dim=1) == targets).float().mean()

    top5 = (
        sim.topk(k, dim=1).indices == targets.unsqueeze(1)
    ).any(dim=1).float().mean()

print("TEST Speech → MRI Top‑1:", top1.item())
print("TEST Speech → MRI Top‑5:", top5.item())


TEST Speech → MRI Top‑1: 0.010869565419852734
TEST Speech → MRI Top‑5: 0.07608695328235626


In [None]:
# Retrieval over the full fused dataset.
# NOTE(review): X_speech / X_mri are the matrices the training split was taken
# from, so this score mixes training rows with held-out rows and is not a
# generalisation estimate.
speech_enc.eval()
mri_enc.eval()

with torch.no_grad():
    # L2-normalise so ranking uses cosine similarity. With raw dot products an
    # MRI embedding with a large norm scores highly against every query.
    Zs = F.normalize(speech_enc(torch.tensor(X_speech, dtype=torch.float32)), dim=1)
    Zm = F.normalize(mri_enc(torch.tensor(X_mri, dtype=torch.float32)), dim=1)

    sim = Zs @ Zm.T

    # Row i's correct match is column i (both matrices share the row order).
    targets = torch.arange(len(sim))
    # topk raises if k exceeds the number of candidates, so cap it.
    k = min(5, sim.shape[1])

    top1 = (sim.argmax(dim=1) == targets).float().mean()
    top5 = (
        sim.topk(k, dim=1).indices == targets.unsqueeze(1)
    ).any(dim=1).float().mean()

print("Speech → MRI Top‑1:", top1.item())
print("Speech → MRI Top‑5:", top5.item())


Speech → MRI Top‑1: 0.7768052220344543
Speech → MRI Top‑5: 0.8008752465248108
