In [11]:
import numpy as np
import pandas as pd
import numpy as np
from scripts.vqvae import VQVAE
import torch
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import ast

import warnings
warnings.filterwarnings("ignore")

In [20]:
#---- Define age group assignment function ----
def assign_age_group(age):
    """Bucket an age value into a five-year group label.

    Groups are half-open intervals: '55-60' is [55, 60), '60-65' is [60, 65),
    '65-70' is [65, 70) and '70+' is age >= 70. Ages below 55 map to
    'Under 55'. Missing ages (None/NaN/pd.NA, e.g. a pid without a matching
    demographics row after a left merge) map to 'Unknown' instead of being
    silently placed in the lowest bucket.

    Parameters
    ----------
    age : scalar number or missing value

    Returns
    -------
    str
        One of 'Under 55', '55-60', '60-65', '65-70', '70+', 'Unknown'.
    """
    # NaN compares False against every threshold, so it must be caught before
    # the range checks or it would fall through to the lowest bucket.
    if pd.isna(age):
        return 'Unknown'
    if age < 55:
        return 'Under 55'
    if age < 60:
        return '55-60'
    if age < 65:
        return '60-65'
    if age < 70:
        return '65-70'
    return '70+'
    
def load_original_embeddings(
    train_path='/workspace/jiezy/CLIP-GCA/NLST/nlst_train_with_labels.npz',
    test_path='/workspace/jiezy/CLIP-GCA/NLST/nlst_tune_with_labels.npz',
    demographics_path="/workspace/jiezy/CLIP-GCA/NLST/nlst_780_prsn_idc_20210527.csv",
):
    """Load the original NLST train/tune embeddings and attach patient demographics.

    Each ``.npz`` file must contain a pickled dict under ``arr_0``. Every key
    becomes one row, and its value is expanded into columns via
    ``DataFrame.from_dict(orient='index')`` (presumably including an
    ``embedding`` entry — TODO confirm).

    The patient id (``pid``) is the second ``'/'``-separated component of each
    key. It is used to left-join these columns from the demographics CSV:
    ``gender`` (recoded 1 -> "M", 2 -> "F"), ``age``, ``race`` and ``can_scr``.
    An ``Age_group`` column is then derived from ``age`` with
    ``assign_age_group``.

    Parameters
    ----------
    train_path, test_path : str
        Paths to the training and tuning ``.npz`` files. The tuning split is
        returned as the "test" frame.
    demographics_path : str
        Path to the demographics CSV. It must have ``pid``, ``gender``,
        ``age``, ``race`` and ``can_scr`` columns.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(train_df, test_df)``. Each has a 0..n-1 index and rows in the key
        order of its source dict.

    Raises
    ------
    pandas.errors.MergeError
        If ``pid`` is not unique in the demographics CSV. Without this check, a
        duplicated pid would silently duplicate embedding rows.
    """
    #---- Load patient demographics ----
    demographics = pd.read_csv(demographics_path)
    demographics["gender"] = demographics["gender"].map({1: "M", 2: "F"})
    demographics["pid"] = demographics["pid"].astype(str)
    demographics = demographics[['pid', 'gender', "age", "race", "can_scr"]]

    def _load_split(npz_path):
        # One split: npz dict -> DataFrame with pid, demographics and Age_group.
        # allow_pickle is required because arr_0 stores a pickled object; only
        # use it on trusted files. The context manager closes the npz handle.
        with np.load(npz_path, allow_pickle=True) as npz:
            records = npz["arr_0"].item()
        split_df = pd.DataFrame.from_dict(records, orient='index')

        #---- Acquire unique identifiers ----
        split_df["pid"] = [k.split('/')[1] for k in records.keys()]
        # Drop the record-key index in favour of a plain 0..n-1 index.
        split_df = split_df.reset_index(drop=True)
        split_df['pid'] = split_df['pid'].astype(str)

        #---- add patient demographics (left join keeps every embedding row) ----
        split_df = split_df.merge(
            demographics, on='pid', how='left', validate='many_to_one'
        )

        #---- define age groups ----
        split_df['Age_group'] = split_df['age'].apply(assign_age_group)
        return split_df

    return _load_split(train_path), _load_split(test_path)

def load_ae_embeddings(dim, src_dir="scripts/Synth-NLST"):
    """Load train/test vanilla-autoencoder embeddings of size ``dim`` from CSV.

    Reads ``{src_dir}/train_nlst_{dim}.csv`` and ``{src_dir}/test_nlst_{dim}.csv``.
    The ``embedding`` column is stored as a Python-literal string (e.g.
    ``"[0.1, 0.2]"``) and is parsed back with ``ast.literal_eval``.

    Parameters
    ----------
    dim : int or str
        Embedding size; only used to build the file names.
    src_dir : str, optional
        Directory holding the CSVs. The default is relative to the current
        working directory.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(ae_train_df, ae_test_df)`` with ``embedding`` parsed.
    """
    ae_train_df = pd.read_csv(f"{src_dir}/train_nlst_{dim}.csv")
    ae_test_df = pd.read_csv(f"{src_dir}/test_nlst_{dim}.csv")
    # literal_eval only evaluates literals, so it is safe here (unlike eval).
    ae_train_df['embedding'] = ae_train_df['embedding'].apply(ast.literal_eval)
    ae_test_df['embedding'] = ae_test_df['embedding'].apply(ast.literal_eval)
    print(f'{dim} AE Embeddings loaded successfully...')
    return ae_train_df, ae_test_df

def load_vqvae_embeddings(n, dim, src_dir="scripts/Synth-NLST"):
    """Load train/test VQ-VAE embeddings for configuration ``n`` x ``dim`` from CSV.

    Reads ``{src_dir}/train_vq_nlst_{n}_{dim}.csv`` and
    ``{src_dir}/test_vq_nlst_{n}_{dim}.csv``. The ``embedding`` column is
    stored as a Python-literal string and is parsed back with
    ``ast.literal_eval``.

    Parameters
    ----------
    n : int or str
        First configuration value in the file name (presumably the codebook
        size — TODO confirm); only used to build the file names.
    dim : int or str
        Second configuration value in the file name (embedding size).
    src_dir : str, optional
        Directory holding the CSVs. The default is relative to the current
        working directory.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(vq_train_df, vq_test_df)`` with ``embedding`` parsed.
    """
    #---- Load low-dimensional embeddings ----
    vq_train_df = pd.read_csv(f"{src_dir}/train_vq_nlst_{n}_{dim}.csv")
    vq_test_df = pd.read_csv(f"{src_dir}/test_vq_nlst_{n}_{dim}.csv")
    # literal_eval only evaluates literals, so it is safe here (unlike eval).
    vq_train_df['embedding'] = vq_train_df['embedding'].apply(ast.literal_eval)
    vq_test_df['embedding'] = vq_test_df['embedding'].apply(ast.literal_eval)
    print(f'{n}x{dim} VQ Embeddings loaded successfully...')
    return vq_train_df, vq_test_df

## Load Embeddings

In [21]:
#---- Load every embedding family (train + test splits) ----
# Original embeddings, joined with patient demographics.
train_df, test_df = load_original_embeddings()
# Vanilla autoencoder embeddings (dim = 8).
ae_train_df, ae_test_df = load_ae_embeddings(8)
# VQ-VAE embeddings (n = 512, dim = 4).
vq_train_df, vq_test_df = load_vqvae_embeddings(512, 4)

8 AE Embeddings loaded successfully...
512x4 VQ Embeddings loaded successfully...


## Visualize Embeddings

In [None]:
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns


#---- Combine Train + Test ----
# ignore_index gives each combined frame a clean 0..n-1 index instead of the
# duplicated labels produced by stacking two independently indexed splits.
original_df = pd.concat([train_df, test_df], ignore_index=True)
ae_df = pd.concat([ae_train_df, ae_test_df], ignore_index=True)
vq_df = pd.concat([vq_train_df, vq_test_df], ignore_index=True)

#---- Helper: Flatten embeddings ----
def flatten_embeddings(df):
    """Stack ``df["embedding"]`` into a 2-D array with one row per DataFrame row.

    Every entry is expected to be a 1-D sequence of the same length.
    """
    return np.vstack(df["embedding"].values)

#---- Helper: Canonical gender labels ----
# load_original_embeddings recodes gender to "M"/"F", so the palette is keyed
# on those labels. seaborn raises if a dict palette lacks any hue level.
GENDER_PALETTE = {"M": "blue", "F": "red"}

def normalize_gender(values):
    """Fold common gender spellings ("m", "male", "F", "Female", ...) onto "M"/"F".

    Unrecognised values, including missing ones, are returned unchanged.
    """
    aliases = {"m": "M", "male": "M", "f": "F", "female": "F"}
    return np.array(
        [aliases.get(str(v).strip().lower(), v) for v in values], dtype=object
    )

# Prepare embeddings and labels
embeddings = {
    "Original": flatten_embeddings(original_df),
    "Autoencoder": flatten_embeddings(ae_df),
    "VQ-VAE": flatten_embeddings(vq_df),
}

# Gender labels come from the same frame as each embedding matrix, so rows align.
labels = {
    "Original": normalize_gender(original_df["gender"].values),
    "Autoencoder": normalize_gender(ae_df["gender"].values),
    "VQ-VAE": normalize_gender(vq_df["gender"].values),
}

#---- Plot TSNE for each embedding type ----
TSNE_PERPLEXITY = 30

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for ax, (name, X) in zip(axes, embeddings.items()):
    print(f"Running TSNE for {name}...")
    # scikit-learn requires perplexity < n_samples; clamp for small datasets.
    perplexity = min(TSNE_PERPLEXITY, len(X) - 1)
    X_tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity).fit_transform(X)

    df_plot = pd.DataFrame({
        "x": X_tsne[:, 0],
        "y": X_tsne[:, 1],
        "gender": labels[name],
    })

    # Use the fixed colours only when every observed label has one; otherwise
    # fall back to seaborn's default palette instead of raising.
    observed = set(df_plot["gender"].dropna())
    palette = GENDER_PALETTE if observed <= set(GENDER_PALETTE) else None

    sns.scatterplot(
        data=df_plot, x="x", y="y", hue="gender",
        palette=palette,
        alpha=0.6, s=20, ax=ax,
    )
    ax.set(title=f"{name} Embeddings (t-SNE)", xlabel="t-SNE 1", ylabel="t-SNE 2")

plt.tight_layout()
plt.show()


Running TSNE for Original...
