In [None]:
# load libraries
import os
import torch
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from torch import optim
from torch.utils.data import DataLoader

os.chdir("..")
from data.dataset import ConditionalDataset
from data.utils import to_one_hot, load_asd_score, load_dfnc, vector2matrix, compute_sub_per_state, compute_fnc_per_state, compute_dwell_state, compute_transition_matrix, find_unique_ind
from models.vae import VAE
from models.ivae import iVAE
from models.utils import EarlyStopper
from visualization.utils import plot_fnc

In [None]:
# ---- Experiment configuration ----
# `method` selects the model branch in the training cell ('vae' or 'ivae').
method = 'vae'
seed = 9
n_layer = 5
hidden_dim = [256, 128, 64, 32, 16]
latent_dim = 2
batch_size = 512
learning_rate = 0.001
n_epoch = 1000
# `cuda` only toggles the DataLoader worker/pin-memory options further down;
# the compute device itself is chosen independently by `device`.
cuda = False
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Every artefact of this notebook (figures, .npy files, checkpoint) goes here.
res_path = f'/data/users4/xli/interpolation/results/dfnc_asd/{method}/test'
# makedirs also creates missing parent directories (os.mkdir would raise
# FileNotFoundError on a fresh results tree) and is a no-op when the
# directory already exists, so the cell is safe to re-run.
os.makedirs(res_path, exist_ok=True)
ckpt_file = os.path.join(res_path, f'{method}_layer{n_layer}_dim{hidden_dim}_bs{batch_size}_lr{learning_rate}_seed{seed}.pt')

# Build the input paths first, then read the subject table and the dFNC data.
data_path = '/data/users2/zfu/Matlab/GSU/Neuromark/Results'
abide1_sub_path = os.path.join(data_path, 'Subject_selection', 'ABIDE1', 'sub_info_ABIDE1_TRall.mat')
abide1_dfnc_path = os.path.join(data_path, 'DFNC', 'ABIDE1', 'TRall', 'ABIDE1_dfnc_sub_*_sess_001_results.mat')

# load_asd_score returns three values; the third is forwarded to load_dfnc
# (presumably so the same subjects are excluded from both -- TODO confirm).
abide1_subject_info = load_asd_score(abide1_sub_path)
abide1_sub_data, abide1_site, invalid_sub_ind = abide1_subject_info
abide1_dfnc_data = load_dfnc(abide1_dfnc_path, invalid_sub_ind, dataset='ABIDE')

# Select up to n_sub subjects from three acquisition sites, patients first.
n_sub = 266
in_selected_site = (abide1_site=="NYU") | (abide1_site=="UCLA_1") | (abide1_site=="USM")
group_code = abide1_sub_data[:,0]  # 1 is treated as patient, 2 as control below
ind_pt = np.where(in_selected_site & (group_code==1))[0][:n_sub//2]
# Controls fill whatever is left of the n_sub budget after the patients.
ind_hc = np.where(in_selected_site & (group_code==2))[0][:n_sub-len(ind_pt)]
print(f"N patient = {ind_pt.shape[0]}; N control = {ind_hc.shape[0]}")

# Stack patients on top of controls (this row order is relied on downstream).
sfnc_data = np.concatenate([abide1_dfnc_data[ind_pt, :], abide1_dfnc_data[ind_hc, :]], axis=0)
sub_data = np.concatenate([abide1_sub_data[ind_pt, :], abide1_sub_data[ind_hc, :]], axis=0)

# The last two columns of the subject table are read as age and sex.
age_pt, sex_pt = abide1_sub_data[ind_pt, -2], abide1_sub_data[ind_pt, -1]
age_hc, sex_hc = abide1_sub_data[ind_hc, -2], abide1_sub_data[ind_hc, -1]
# print(f"Age patient mean={np.mean(age_pt):.2f}, std={np.std(age_pt):.2f}, max={max(age_pt):.0f}, min={min(age_pt):.0f}; control mean={np.mean(age_hc):.2f}, std={np.std(age_hc):.2f}, max={max(age_hc):.0f}, min={min(age_hc):.0f}")
# print(f"Sex patient N male = {np.sum(sex_pt==1)}, N female = {np.sum(sex_pt==2)}; control N male = {np.sum(sex_hc==1)}, N female = {np.sum(sex_hc==2)}")

# Build a held-out test set of n_test subjects, stratified by site within
# each diagnostic group. ind_pt_test / ind_hc_test collect subject indices
# into the full ABIDE1 arrays.
n_test = 41
n_pt_test = n_test//2
n_hc_test = n_test - n_pt_test
ind_pt_test = []
ind_hc_test = []
site_patient = abide1_site[ind_pt]
site_control = abide1_site[ind_hc]
site_patient_list = list(np.unique(site_patient))
site_control_list = list(np.unique(site_control))

# print("Patient")
for i, s in enumerate(site_patient_list):
    # positions (within ind_pt) of the patients scanned at site s
    ind = np.where(site_patient == s)[0]
    # stratify test set by site: share of the half-sized test budget that is
    # proportional to the site's share of the group
    n_end = int(np.round(len(ind)/len(ind_pt)*n_test*0.5))
    # Take the first n_end subjects *of this site*. Indexing through `ind`
    # (rather than slicing ind_pt from the site's first position) stays
    # correct even if a site's subjects are not stored contiguously.
    ind_pt_site_test = ind_pt[ind[:n_end]]
    ind_pt_test += list(ind_pt_site_test)
    # print(f"{s} & {len(ind)} & {np.sum(sex_pt[ind]==1)} & {np.sum(sex_pt[ind]==2)} & ${round(np.mean(age_pt[ind]), 2)}\\pm{round(np.std(age_pt[ind]),2)}$ & ${round(np.min(age_pt[ind]),2)}-{round(np.max(age_pt[ind]),2)}$")

# print("Control")
for i, s in enumerate(site_control_list):
    ind = np.where(site_control == s)[0]
    # stratify test set by site
    n_end = int(np.round(len(ind)/len(ind_hc)*n_test*0.5))
    if s=="NYU":
        # one extra NYU control (presumably to reach the odd n_test total
        # after per-site rounding -- TODO confirm)
        n_end += 1
    ind_hc_site_test = ind_hc[ind[:n_end]]
    ind_hc_test += list(ind_hc_site_test)
    # print(f"{s} & {len(ind)} & {np.sum(sex_hc[ind]==1)} & {np.sum(sex_hc[ind]==2)} & ${round(np.mean(age_hc[ind]), 2)}\\pm{round(np.std(age_hc[ind]),2)}$ & ${round(np.min(age_hc[ind]),2)}-{round(np.max(age_hc[ind]),2)}$")

In [None]:
# split data into training and test sets
# Training subjects are everything selected for the group that was not put in
# the test set. sorted() gives a stable, ascending subject order; iterating a
# raw set yields an order that depends on hash-table internals.
ind_pt_train = sorted(set(ind_pt) - set(ind_pt_test))
ind_hc_train = sorted(set(ind_hc) - set(ind_hc_test))
n_pt_train = len(ind_pt_train)
n_hc_train = len(ind_hc_train)
n_pt_test = len(ind_pt_test)
n_hc_test = len(ind_hc_test)

# Later cells reshape the test data with n_test, so the stratified selection
# must have produced exactly that many subjects. Fail here with a clear
# message rather than with a reshape error further down.
if n_pt_test + n_hc_test != n_test:
    raise ValueError(
        f"Stratified test set has {n_pt_test + n_hc_test} subjects "
        f"({n_pt_test} patients + {n_hc_test} controls) but n_test = {n_test}"
    )
# Derive n_train from the actual split so it stays correct even when fewer
# than n_sub subjects were available.
n_train = n_pt_train + n_hc_train
print(f"N patient test = {len(ind_pt_test)}; N control test = {len(ind_hc_test)}; N patient train = {len(ind_pt_train)}; N control train = {len(ind_hc_train)}")

In [None]:
# Assemble test and training arrays, always patients first and then controls.
dfnc_data_test = np.concatenate([abide1_dfnc_data[ind_pt_test], abide1_dfnc_data[ind_hc_test]], axis=0)
sub_data_test = np.concatenate([abide1_sub_data[ind_pt_test], abide1_sub_data[ind_hc_test]], axis=0)
dfnc_data_train = np.concatenate([abide1_dfnc_data[ind_pt_train], abide1_dfnc_data[ind_hc_train]], axis=0)
sub_data_train = np.concatenate([abide1_sub_data[ind_pt_train], abide1_sub_data[ind_hc_train]], axis=0)

# Axis 1 indexes windows and axis 2 indexes features of each dFNC sample.
n_window, n_feature = dfnc_data_test.shape[1:3]
print(f"n_window = {n_window}; n_feature = {n_feature}")

# Flatten (subject, window) into one sample axis: one row per window.
dfnc_data_test_2d = dfnc_data_test.reshape((n_test*n_window, n_feature))
dfnc_data_train_2d = dfnc_data_train.reshape((n_train*n_window, n_feature))

# Per-window group labels (1 = patient rows, 2 = control rows), one-hot
# encoded with the first column dropped.
group_code_test = np.repeat([1, 2], [n_pt_test*n_window, n_hc_test*n_window])
group_code_train = np.repeat([1, 2], [n_pt_train*n_window, n_hc_train*n_window])
y_test = to_one_hot(group_code_test)[0][:,1:]
y_train = to_one_hot(group_code_train)[0][:,1:]

In [None]:
# Sweep K-means over k = 2..9 on the training windows. Entry (k - 2) of each
# list holds the result for k clusters.
kmeans_list = []
loss_l1_list = []
loss_l2_list = []
cluster_center_list = []
label_list = []
all_state_list = []

for nc in range(2,10):
    kmeans = KMeans(n_clusters=nc, random_state=0).fit(dfnc_data_train_2d)
    kmeans_list.append(kmeans)
    cluster_center = kmeans.cluster_centers_
    cluster_center_list.append(cluster_center)
    label = kmeans.labels_
    label_list.append(label)

    # K-means states: each centroid vector converted with vector2matrix
    state_list = [vector2matrix(center) for center in cluster_center]
    all_state_list.append(state_list)

    # Mean per-window L1 / squared-L2 distance to the assigned centroid.
    # Accumulated one cluster at a time with array ops: same quantity as a
    # per-sample Python loop, far faster, and temporaries are bounded by the
    # largest cluster.
    loss_l1, loss_l2 = 0.0, 0.0
    for c in range(nc):
        residual = dfnc_data_train_2d[label == c] - cluster_center[c]
        loss_l1 += np.abs(residual).sum(dtype=np.float64)  # L1 norm
        loss_l2 += np.square(residual).sum(dtype=np.float64)  # L2 norm
    n_sample = dfnc_data_train_2d.shape[0]  # == n_train*n_window
    loss_l1 /= n_sample
    loss_l2 /= n_sample
    loss_l1_list.append(loss_l1)
    loss_l2_list.append(loss_l2)

In [None]:
# Elbow curves: mean L1 (left) and L2 (right) clustering loss against k.
k_values = range(2,10)
fig, axes = plt.subplots(1, 2, figsize=(10,3))
for ax, losses, ylabel in zip(axes, (loss_l1_list, loss_l2_list), ("L1 loss", "L2 loss")):
    ax.plot(k_values, losses)
    ax.set_xlabel("k", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
plt.tight_layout()
plt.savefig(os.path.join(res_path, 'dfnc_kmeans_loss.png'), dpi=500)

In [None]:
# Number of K-means states used for all remaining analyses.
n_state = 5
# The sweep started at k = 2, so the result for k clusters is at index k - 2.
state_slot = n_state - 2

# np.save(os.path.join(res_path, 'loss_l1_list.npy'), loss_l1_list)
# np.save(os.path.join(res_path, 'loss_l2_list.npy'), loss_l2_list)
# np.save(os.path.join(res_path, f'label_state{n_state}.npy'), label_list[n_state-2])
# np.save(os.path.join(res_path, f'cluster_center_state{n_state}.npy'), cluster_center_list[n_state-2])
# np.save(os.path.join(res_path, f'all_state_state{n_state}.npy'), all_state_list[n_state-2])

kmeans_state = all_state_list[state_slot]
kmeans_label = label_list[state_slot]
kmeans_cluster_center = cluster_center_list[state_slot]

# Training rows are patients first, then controls; split at that boundary.
n_pt_rows = n_pt_train*n_window
label_pt, label_hc = kmeans_label[:n_pt_rows], kmeans_label[n_pt_rows:]
dfnc_data_train_2d_pt = dfnc_data_train_2d[:n_pt_rows,:]
dfnc_data_train_2d_hc = dfnc_data_train_2d[n_pt_rows:,:]

num_sub_per_state, ratio_sub_per_state = compute_sub_per_state(kmeans_label=kmeans_label, n_pt=n_pt_train, n_window=n_window)
num_fnc_per_state, ratio_fnc_per_state = compute_fnc_per_state(kmeans_label=kmeans_label, n_pt=n_pt_train, n_window=n_window)
# Order states by ascending value of row 1 of the window-ratio table.
sorted_state_ind = np.argsort(ratio_fnc_per_state[1,:])
ratio_fnc_per_state_sorted = ratio_fnc_per_state[:, sorted_state_ind]

for summary in (num_sub_per_state, ratio_sub_per_state, num_fnc_per_state, ratio_fnc_per_state, sorted_state_ind):
    print(summary)

In [None]:
# Per-state median maps on the training set.
# dfnc_state[0] = patients, [1] = controls, [2] = patients - controls.
dfnc_state = np.zeros((3, n_state, 53, 53))

fig, axes = plt.subplots(nrows=3, ncols=n_state, figsize=(5*n_state, 5*3))

state_order = np.arange(n_state)[sorted_state_ind]
for i, j in enumerate(state_order):
    state_pt = dfnc_data_train_2d_pt[label_pt == j]
    state_hc = dfnc_data_train_2d_hc[label_hc == j]

    # Feature-wise two-sample t-test, thresholded at 0.05 / number of features.
    stat, pvalue = stats.ttest_ind(a=state_pt, b=state_hc, equal_var=True)
    pvalue_map = vector2matrix(pvalue)
    pvalue_mask = pvalue_map <= (0.05/len(pvalue))
    # The upper triangle (diagonal included) is always shown unmasked.
    upper_triangle_mask = np.triu(np.ones_like(pvalue_mask)).astype(bool)
    pvalue_mask[upper_triangle_mask] = 1

    cluster_median_1d_pt = np.median(state_pt, axis=0)
    cluster_median_1d_hc = np.median(state_hc, axis=0)
    dfnc_state[0,i] = vector2matrix(cluster_median_1d_pt)
    dfnc_state[1,i] = vector2matrix(cluster_median_1d_hc)
    dfnc_state[2,i] = dfnc_state[0,i] - dfnc_state[1,i]

    # Tick labels only on the first column.
    show_xticks = (i == 0)
    title = f"{round(ratio_fnc_per_state_sorted[0,i]*100,1)}% ASD\n{round(ratio_fnc_per_state_sorted[1,i]*100,1)}% CTR"
    plot_fnc(dfnc_state[0,i], axes[0,i], title, show_xticks=show_xticks)
    plot_fnc(dfnc_state[1,i], axes[1,i], show_xticks=show_xticks)
    plot_fnc(dfnc_state[2,i] * pvalue_mask * 2, axes[2,i], show_xticks=show_xticks)

plt.tight_layout()
plt.savefig(os.path.join(res_path, f'dfnc_kmeans_{n_state}states_pt_hc_pt-hc_sorted.png'), dpi=500)

In [None]:
# One panel per K-means state: median of all training windows (both groups)
# assigned to that state, in the sorted state order.
# ncols follows n_state so the grid matches the loop below and the figsize;
# a fixed column count would raise an IndexError for n_state > 5 and leave
# empty axes for n_state < 5.
fig, axes = plt.subplots(nrows=1, ncols=n_state, figsize=(5*n_state, 5))

for i, j in enumerate(np.arange(n_state)[sorted_state_ind]):
    state = dfnc_data_train_2d[np.where(kmeans_label == j)[0], :]
    state_median_1d = np.median(state, axis=0)
    state_median_2d = vector2matrix(state_median_1d)
    plot_fnc(state_median_2d, axes[i], f"State {i+1}")

plt.tight_layout()
plt.savefig(os.path.join(res_path, f'dfnc_kmeans_{n_state}states_pt_sorted.png'), bbox_inches='tight', dpi=500)

In [None]:
# Windows per subject in each state, columns reordered to the sorted states.
dwell_time = num_fnc_per_state / num_sub_per_state
dwell_time_sorted = dwell_time[:, sorted_state_ind]

# One row of state labels per training subject.
kmeans_label_2d = kmeans_label.reshape((n_train, n_window))
dwell_stats = compute_dwell_state(kmeans_label_2d, sorted_state_ind, n_pt_train, n_hc_train)
dwell_state_mean_pt, dwell_state_ste_pt, dwell_state_mean_hc, dwell_state_ste_hc, dwell_state_pvalue = dwell_stats
transition_stats = compute_transition_matrix(kmeans_label_2d, sorted_state_ind, n_pt_train)
transition_matrix_pt, transition_matrix_hc, transition_matrix, transition_matrix_pvalue = transition_stats

# Persist everything under res_path (file name -> array, saved in this order).
train_outputs = {
    'dwell_time.npy': dwell_time_sorted,
    'dwell_state_pvalue.npy': dwell_state_pvalue,
    'dwell_state_mean_pt.npy': dwell_state_mean_pt,
    'dwell_state_ste_pt.npy': dwell_state_ste_pt,
    'dwell_state_mean_hc.npy': dwell_state_mean_hc,
    'dwell_state_ste_hc.npy': dwell_state_ste_hc,
    'transition_matrix_pt.npy': transition_matrix_pt,
    'transition_matrix_hc.npy': transition_matrix_hc,
    'transition_matrix.npy': transition_matrix,
    'transition_matrix_pvalue.npy': transition_matrix_pvalue,
}
for fname, arr in train_outputs.items():
    np.save(os.path.join(res_path, fname), arr)

In [None]:
# Assign test windows to the states of the K-means model fitted on training data.
kmeans = kmeans_list[n_state-2]
kmeans_label_test = kmeans.predict(dfnc_data_test_2d)

# Test rows are patients first, then controls.
n_pt_rows_test = n_pt_test*n_window
label_pt_test, label_hc_test = kmeans_label_test[:n_pt_rows_test], kmeans_label_test[n_pt_rows_test:]
dfnc_data_test_2d_pt = dfnc_data_test_2d[:n_pt_rows_test,:]
dfnc_data_test_2d_hc = dfnc_data_test_2d[n_pt_rows_test:,:]

num_sub_per_state_test, ratio_sub_per_state_test = compute_sub_per_state(kmeans_label=kmeans_label_test, n_pt=n_pt_test, n_window=n_window)
num_fnc_per_state_test, ratio_fnc_per_state_test = compute_fnc_per_state(kmeans_label=kmeans_label_test, n_pt=n_pt_test, n_window=n_window)
# The test set gets its own state ordering, from row 1 of its ratio table.
sorted_state_ind_test = np.argsort(ratio_fnc_per_state_test[1,:])
ratio_fnc_per_state_sorted_test = ratio_fnc_per_state_test[:, sorted_state_ind_test]

for summary in (num_sub_per_state_test, ratio_sub_per_state_test, num_fnc_per_state_test, ratio_fnc_per_state_test, sorted_state_ind_test):
    print(summary)

In [None]:
# Per-state median maps on the test set.
# dfnc_state_test[0] = patients, [1] = controls, [2] = patients - controls.
dfnc_state_test = np.zeros((3, n_state, 53, 53))

fig, axes = plt.subplots(nrows=3, ncols=n_state, figsize=(5*n_state, 5*3))

state_order_test = np.arange(n_state)[sorted_state_ind_test]
for i, j in enumerate(state_order_test):
    state_pt = dfnc_data_test_2d_pt[label_pt_test == j]
    state_hc = dfnc_data_test_2d_hc[label_hc_test == j]

    # Feature-wise two-sample t-test, thresholded at 0.05 / number of features.
    stat, pvalue = stats.ttest_ind(a=state_pt, b=state_hc, equal_var=True)
    pvalue_map = vector2matrix(pvalue)
    pvalue_mask = pvalue_map <= (0.05/len(pvalue))
    # The upper triangle (diagonal included) is always shown unmasked.
    upper_triangle_mask = np.triu(np.ones_like(pvalue_mask)).astype(bool)
    pvalue_mask[upper_triangle_mask] = 1

    cluster_median_1d_pt = np.median(state_pt, axis=0)
    cluster_median_1d_hc = np.median(state_hc, axis=0)
    dfnc_state_test[0,i] = vector2matrix(cluster_median_1d_pt)
    dfnc_state_test[1,i] = vector2matrix(cluster_median_1d_hc)
    dfnc_state_test[2,i] = dfnc_state_test[0,i] - dfnc_state_test[1,i]

    # Tick labels only on the first column.
    show_xticks = (i == 0)
    title = f"{round(ratio_fnc_per_state_sorted_test[0,i]*100,1)}% ASD\n{round(ratio_fnc_per_state_sorted_test[1,i]*100,1)}% CTR"
    plot_fnc(dfnc_state_test[0,i], axes[0,i], title, show_xticks=show_xticks)
    plot_fnc(dfnc_state_test[1,i], axes[1,i], show_xticks=show_xticks)
    plot_fnc(dfnc_state_test[2,i] * pvalue_mask * 2, axes[2,i], show_xticks=show_xticks)

plt.tight_layout()
plt.savefig(os.path.join(res_path, f'dfnc_kmeans_{n_state}states_pt_hc_pt-hc_sorted_test.png'), dpi=500)

In [None]:
# Windows per subject in each state on the test set, in the test state order.
dwell_time_test = num_fnc_per_state_test / num_sub_per_state_test
dwell_time_sorted_test = dwell_time_test[:, sorted_state_ind_test]
# One row of state labels per test subject.
kmeans_label_2d_test = kmeans_label_test.reshape((n_test, n_window))

dwell_stats_test = compute_dwell_state(kmeans_label_2d_test, sorted_state_ind_test, n_pt_test, n_hc_test)
dwell_state_mean_pt_test, dwell_state_ste_pt_test, dwell_state_mean_hc_test, dwell_state_ste_hc_test, dwell_state_pvalue_test = dwell_stats_test
transition_stats_test = compute_transition_matrix(kmeans_label_2d_test, sorted_state_ind_test, n_pt_test)
transition_matrix_pt_test, transition_matrix_hc_test, transition_matrix_test, transition_matrix_pvalue_test = transition_stats_test

# Persist everything under res_path (file name -> array, saved in this order).
test_outputs = {
    'dwell_time_test.npy': dwell_time_sorted_test,
    'dwell_state_pvalue_test.npy': dwell_state_pvalue_test,
    'dwell_state_mean_pt_test.npy': dwell_state_mean_pt_test,
    'dwell_state_ste_pt_test.npy': dwell_state_ste_pt_test,
    'dwell_state_mean_hc_test.npy': dwell_state_mean_hc_test,
    'dwell_state_ste_hc_test.npy': dwell_state_ste_hc_test,
    'transition_matrix_pt_test.npy': transition_matrix_pt_test,
    'transition_matrix_hc_test.npy': transition_matrix_hc_test,
    'transition_matrix_test.npy': transition_matrix_test,
    'transition_matrix_pvalue_test.npy': transition_matrix_pvalue_test,
}
for fname, arr in test_outputs.items():
    np.save(os.path.join(res_path, fname), arr)

In [None]:
# Model input width and width of the auxiliary (one-hot group) variable.
data_dim = dfnc_data_train_2d.shape[1]
aux_dim = y_train.shape[1]

# Extra DataLoader options are only enabled through the `cuda` flag.
loader_params = {}
if cuda:
    loader_params = {'num_workers': 1, 'pin_memory': True}

# Both datasets are float32; both loaders keep the row order (shuffle=False).
# NOTE(review): the training loader is therefore not shuffled either, so
# batches follow the patients-then-controls row order -- confirm this is intended.
ds_train = ConditionalDataset(dfnc_data_train_2d.astype(np.float32), y_train.astype(np.float32), device)
ds_test = ConditionalDataset(dfnc_data_test_2d.astype(np.float32), y_test.astype(np.float32), device)
data_loader_train = DataLoader(ds_train, shuffle=False, batch_size=batch_size, **loader_params)
data_loader_test = DataLoader(ds_test, shuffle=False, batch_size=batch_size, **loader_params)

In [None]:
# Train the selected model, checkpoint it, then embed train/test data into the
# latent space (z_train / z_test as NumPy arrays).
if method == 'vae':
    # The mini-batches are moved to `device` inside the loop, so the model
    # has to live there too; otherwise the forward pass fails on a GPU machine.
    model = VAE(input_dim=data_dim,
                latent_dim=latent_dim,
                hidden_dims=hidden_dim,
                seed=seed).to(device)
    print(model)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10)
    early_stopper = EarlyStopper(patience=20, threshold=1e-2, min_delta=1e-3)

    # train model
    model.train()
    for it in range(n_epoch):
        loss_train = 0
        for x, _ in data_loader_train:
            x = x.view(x.size(0), -1).to(device)
            optimizer.zero_grad()
            x_rec, mean, logvar = model(x)
            loss = model.loss(x_rec, x, mean, logvar)
            loss.backward()
            optimizer.step()
            loss_train += loss.item()
        # mean batch loss of the epoch drives both the LR schedule and early stopping
        loss_train /= len(data_loader_train)
        scheduler.step(loss_train)
        print(f'Epoch: {it}; Loss: {loss_train:.5f}')
        if early_stopper.early_stop(loss_train):
            print('Early stopping triggered!')
            break

    # save model checkpoint after training
    torch.save(model.state_dict(), ckpt_file)

    x_train, u_train = ds_train.x, ds_train.y
    x_test, u_test = ds_test.x, ds_test.y
    # Inference: switch to eval mode and disable autograd so the full-dataset
    # forward pass is deterministic and does not build a graph.
    model.eval()
    with torch.no_grad():
        # the second output of the VAE forward pass (the posterior mean) is the embedding
        _, z_train, _ = model(x_train.to(device))
        _, z_test, _ = model(x_test.to(device))
    z_train = z_train.detach().cpu().numpy()
    z_test = z_test.detach().cpu().numpy()

elif method == 'ivae':
    model = iVAE(data_dim=data_dim,
                latent_dim=latent_dim,
                aux_dim=aux_dim,
                hidden_dim=hidden_dim,
                n_layer=n_layer,
                activation='xtanh',
                device=device,
                seed=seed)
    print(model)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10)
    early_stopper = EarlyStopper(patience=20, threshold=1e-2, min_delta=1e-3)

    # train model
    model.train()

    for it in range(n_epoch):
        loss_train = 0
        for x, u in data_loader_train:
            optimizer.zero_grad()
            x, u = x.to(device), u.to(device)
            loss, z_est = model.loss(x, u)
            # model.loss is maximised here: back-propagate and log its negative
            loss.mul(-1).backward()
            optimizer.step()
            loss_train += -loss.item()
        loss_train /= len(data_loader_train)
        scheduler.step(loss_train)
        print(f'Epoch: {it}; Loss: {loss_train:.3f}')
        if early_stopper.early_stop(loss_train):
            print('Early stopping triggered!')
            break

    # save model checkpoint after training
    torch.save(model.state_dict(), ckpt_file)

    x_train, u_train = ds_train.x, ds_train.y
    x_test, u_test = ds_test.x, ds_test.y
    # Same inference discipline as the VAE branch.
    model.eval()
    with torch.no_grad():
        _, _, z_train, _ = model(x_train.to(device), u_train.to(device))
        _, _, z_test, _ = model(x_test.to(device), u_test.to(device))
    z_train = z_train.detach().cpu().numpy()
    z_test = z_test.detach().cpu().numpy()

In [None]:
centroid_ms = 120
# One colour per training window: each patient gets an orange ramp over their
# n_window windows, each control a blue ramp. Rows follow z_train
# (patients first, then controls).
# np.tile builds the array in one shot; growing it with np.concatenate inside
# a loop re-copies everything on every iteration.
orange_ramp = plt.cm.Oranges(np.linspace(0,1,n_window))
blue_ramp = plt.cm.Blues(np.linspace(0,1,n_window))
cols = np.concatenate([np.tile(orange_ramp, (n_pt_train, 1)),
                       np.tile(blue_ramp, (n_hc_train, 1))])
plt.figure(figsize=(5, 5))
plt.xticks([], [])
plt.yticks([], [])
plt.scatter(z_train[:, 0], z_train[:, 1], marker='o', c=cols, s=1)
plt.tight_layout()
plt.savefig(os.path.join(res_path, 'vae_latent_space_train.png'), dpi=500)

In [None]:
# Patients (left) and controls (right) in the latent space, on one symmetric
# axis range so the two panels are directly comparable.
lim = np.max(np.abs(z_train))
n_pt_rows = n_pt_train*n_window

plt.figure(figsize=(10,5))

group_rows = (slice(None, n_pt_rows), slice(n_pt_rows, None))
for panel, rows in enumerate(group_rows, start=1):
    plt.subplot(1,2,panel)
    plt.xticks([], [])
    plt.yticks([], [])
    plt.scatter(z_train[rows, 0], z_train[rows, 1], marker='o', c=cols[rows], s=1)
    plt.xlim([-lim, lim])
    plt.ylim([-lim, lim])

plt.tight_layout()
plt.savefig(os.path.join(res_path, 'vae_latent_space_train_pt_hc.png'), dpi=500)

In [None]:
# Decode every training latent back to dFNC feature space in one batched
# call. This replaces a one-row-at-a-time loop, works for any latent
# dimensionality, and runs on whatever device the model currently lives on.
model.eval()
decode_device = next(model.parameters()).device
with torch.no_grad():
    z_train_tensor = torch.as_tensor(z_train, dtype=torch.float32, device=decode_device)
    decoded_train = model.decode(z_train_tensor).cpu().numpy()
# Same shape and float64 dtype as the zero-initialised buffer this replaces.
vae_dfnc_data_train = decoded_train.reshape(dfnc_data_train_2d.shape).astype(np.float64)

np.save(os.path.join(res_path, 'z_train.npy'), z_train)
np.save(os.path.join(res_path, 'z_test.npy'), z_test)
np.save(os.path.join(res_path, 'vae_dfnc_data_train.npy'), vae_dfnc_data_train)

# z_train = np.load(os.path.join(res_path, 'z_train.npy'))
# z_test = np.load(os.path.join(res_path, 'z_test.npy'))
# vae_dfnc_data_train = np.load(os.path.join(res_path, 'vae_dfnc_data_train.npy'))

In [None]:
# Rebuild the model on the CPU and restore the trained weights from ckpt_file.
# The architecture must match `method`, because the checkpoint was written by
# the branch of the training cell selected by it; the constructor arguments
# mirror that cell.
if method == 'vae':
    model = VAE(input_dim=data_dim,
                latent_dim=latent_dim,
                hidden_dims=hidden_dim,
                seed=seed)
elif method == 'ivae':
    model = iVAE(data_dim=data_dim,
                 latent_dim=latent_dim,
                 aux_dim=aux_dim,
                 hidden_dim=hidden_dim,
                 n_layer=n_layer,
                 activation='xtanh',
                 device=torch.device('cpu'),
                 seed=seed)
else:
    raise ValueError(f"Unknown method: {method!r} (expected 'vae' or 'ivae')")
# NOTE(review): weights_only=False unpickles arbitrary objects. It is only
# acceptable here because the file is written by this notebook; a plain
# state_dict should normally also load with weights_only=True -- confirm and tighten.
checkpoint = torch.load(ckpt_file, map_location=torch.device('cpu'), weights_only=False)
model.load_state_dict(checkpoint)
model.eval()

In [None]:
# Sweep K-means over k = 2..9 in the latent space. Entry (k - 2) of each list
# holds the result for k clusters.
vae_kmeans_list = []
vae_loss_l1_list = []
vae_loss_l2_list = []
vae_cluster_center_list = []
vae_label_list = []
vae_all_state_list = []

for nc in range(2,10):
    vae_kmeans = KMeans(n_clusters=nc, random_state=0).fit(z_train)
    vae_kmeans_list.append(vae_kmeans)
    vae_cluster_center = vae_kmeans.cluster_centers_
    vae_cluster_center_list.append(vae_cluster_center)
    vae_label = vae_kmeans.labels_
    vae_label_list.append(vae_label)

    # K-means states: decode all latent centroids in one batch (any latent
    # dimensionality, on the model's own device), then convert each decoded
    # vector with vector2matrix.
    with torch.no_grad():
        centers_tensor = torch.as_tensor(vae_cluster_center, dtype=torch.float32,
                                         device=next(model.parameters()).device)
        decoded_centers = model.decode(centers_tensor).cpu().numpy()
    state_list = [vector2matrix(np.squeeze(x_reconstructed)) for x_reconstructed in decoded_centers]
    vae_all_state_list.append(state_list)

    # Mean per-window L1 / squared-L2 distance to the assigned centroid,
    # computed with array ops instead of a per-sample Python loop and summed
    # in float64 for accuracy.
    residual = vae_cluster_center[vae_label] - z_train
    n_sample = z_train.shape[0]  # == n_train*n_window
    loss_l1 = np.abs(residual).sum(dtype=np.float64) / n_sample  # L1 norm
    loss_l2 = np.square(residual).sum(dtype=np.float64) / n_sample  # L2 norm
    vae_loss_l1_list.append(loss_l1)
    vae_loss_l2_list.append(loss_l2)

In [None]:
# Elbow curves for the latent-space K-means sweep: L1 (left) and L2 (right).
k_values = range(2,10)
fig, axes = plt.subplots(1,2,figsize=(10,3))
for ax, losses, ylabel in zip(axes, (vae_loss_l1_list, vae_loss_l2_list), ("L1 loss", "L2 loss")):
    ax.plot(k_values, losses)
    ax.set_xlabel("k", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
plt.tight_layout()
plt.savefig(os.path.join(res_path, 'vae_dfnc_kmeans_loss.png'), dpi=500)

In [None]:
# Persist the loss curves and the clustering chosen for n_state states
# (index n_state-2 of the sweep lists).
vae_sweep_outputs = [
    ('vae_loss_l1_list.npy', vae_loss_l1_list),
    ('vae_loss_l2_list.npy', vae_loss_l2_list),
    (f'vae_label_state{n_state}.npy', vae_label_list[n_state-2]),
    (f'vae_cluster_center_state{n_state}.npy', vae_cluster_center_list[n_state-2]),
    (f'vae_all_state_state{n_state}.npy', vae_all_state_list[n_state-2]),
]
for fname, payload in vae_sweep_outputs:
    np.save(os.path.join(res_path, fname), payload)

# vae_loss_l1_list = np.load(os.path.join(res_path, 'vae_loss_l1_list.npy'))
# vae_loss_l2_list = np.load(os.path.join(res_path, 'vae_loss_l2_list.npy'))
# vae_kmeans_label = np.load(os.path.join(res_path, 'vae_label_state5.npy'))
# vae_kmeans_cluster_center = np.load(os.path.join(res_path, 'vae_cluster_center_state5.npy'))
# vae_kmeans_state = np.load(os.path.join(res_path, 'vae_all_state_state5.npy'))

In [None]:
# Latent-space clustering chosen for n_state states.
vae_kmeans_state = vae_all_state_list[n_state-2]
vae_kmeans_label = vae_label_list[n_state-2]
vae_kmeans_cluster_center = vae_cluster_center_list[n_state-2]
vae_num_sub_per_state, vae_ratio_sub_per_state = compute_sub_per_state(kmeans_label=vae_kmeans_label, n_pt=n_pt_train, n_window=n_window)
vae_num_fnc_per_state, vae_ratio_fnc_per_state = compute_fnc_per_state(kmeans_label=vae_kmeans_label, n_pt=n_pt_train, n_window=n_window)

# Rows are patients first, then controls; split at that boundary.
n_pt_rows = n_pt_train*n_window
vae_label_pt, vae_label_hc = vae_kmeans_label[:n_pt_rows], vae_kmeans_label[n_pt_rows:]
vae_dfnc_data_train_pt = vae_dfnc_data_train[:n_pt_rows,:]
vae_dfnc_data_train_hc = vae_dfnc_data_train[n_pt_rows:,:]

# Median generated map per state: [0] patients, [1] controls, [2] difference.
vae_dfnc_state = np.zeros((3, n_state, 53, 53))
for i in range(n_state):
    cluster_median_1d_pt = np.median(vae_dfnc_data_train_pt[vae_label_pt == i], axis=0)
    cluster_median_1d_hc = np.median(vae_dfnc_data_train_hc[vae_label_hc == i], axis=0)
    vae_dfnc_state[0,i] = vector2matrix(cluster_median_1d_pt)
    vae_dfnc_state[1,i] = vector2matrix(cluster_median_1d_hc)
    vae_dfnc_state[2,i] = vae_dfnc_state[0,i] - vae_dfnc_state[1,i]

# sort dFNC states by similarity between original and generated dFNC states
# corr[k, i, j]: correlation of original state i with generated state j for
# group k (0 = patients, 1 = controls).
corr = np.zeros((2, n_state, n_state))
for k in range(2):
    for i in range(n_state):
        for j in range(n_state):
            corr[k,i,j] = np.corrcoef(dfnc_state[k,i].ravel(), vae_dfnc_state[k,j].ravel())[0,1]
# Best-matching generated state for each original state, from the patient maps;
# find_unique_ind then turns this into the final state order.
vae_sorted_state_ind = np.argmax(corr[0], axis=1)

vae_unique_sorted_state_ind = find_unique_ind(vae_sorted_state_ind, corr, vae_ratio_fnc_per_state)
vae_ratio_fnc_per_state_sorted = vae_ratio_fnc_per_state[:, vae_unique_sorted_state_ind]

for summary in (vae_num_sub_per_state, vae_ratio_sub_per_state, vae_num_fnc_per_state, vae_ratio_fnc_per_state, vae_unique_sorted_state_ind):
    print(summary)

In [None]:
# Generated (decoded) state maps on the training set, in the matched order,
# plus their correlation with the original maps at the same position.
corr_state = np.zeros((2, n_state))
vae_dfnc_state_sorted = np.zeros((3, n_state, 53, 53))

fig, axes = plt.subplots(nrows=3, ncols=n_state, figsize=(5*n_state, 5*3))

vae_state_order = np.arange(n_state)[vae_unique_sorted_state_ind]
for i, j in enumerate(vae_state_order):
    state_pt = vae_dfnc_data_train_pt[vae_label_pt == j]
    state_hc = vae_dfnc_data_train_hc[vae_label_hc == j]

    # Feature-wise two-sample t-test, thresholded at 0.05 / number of features.
    stat, pvalue = stats.ttest_ind(a=state_pt, b=state_hc, equal_var=True)
    pvalue_map = vector2matrix(pvalue)
    pvalue_mask = pvalue_map <= (0.05/len(pvalue))
    # The upper triangle (diagonal included) is always shown unmasked.
    upper_triangle_mask = np.triu(np.ones_like(pvalue_mask)).astype(bool)
    pvalue_mask[upper_triangle_mask] = 1

    cluster_median_1d_pt = np.median(state_pt, axis=0)
    cluster_median_1d_hc = np.median(state_hc, axis=0)
    vae_dfnc_state_sorted[0,i] = vector2matrix(cluster_median_1d_pt)
    vae_dfnc_state_sorted[1,i] = vector2matrix(cluster_median_1d_hc)
    vae_dfnc_state_sorted[2,i] = vae_dfnc_state_sorted[0,i] - vae_dfnc_state_sorted[1,i]

    # Tick labels only on the first column.
    show_xticks = (i == 0)

    title = f"{round(vae_ratio_fnc_per_state_sorted[0,i]*100,1)}% ASD\n{round(vae_ratio_fnc_per_state_sorted[1,i]*100,1)}% CTR"
    plot_fnc(vae_dfnc_state_sorted[0,i], axes[0,i], title, show_xticks=show_xticks)
    plot_fnc(vae_dfnc_state_sorted[1,i], axes[1,i], show_xticks=show_xticks)
    plot_fnc(vae_dfnc_state_sorted[2,i] * pvalue_mask * 2, axes[2,i], show_xticks=show_xticks)

    # Original vs. generated map, per group (0 = patients, 1 = controls).
    for k in range(2):
        corr_state[k,i] = np.corrcoef(dfnc_state[k,i].ravel(), vae_dfnc_state_sorted[k,i].ravel())[0,1]
    print(corr_state[0,i], corr_state[1,i])

plt.tight_layout()
plt.savefig(os.path.join(res_path, f'vae_dfnc_kmeans_{n_state}states_pt_hc_pt-hc_sorted.png'), dpi=500)

In [None]:
# One panel per latent-space state: median generated map over all training
# windows (both groups) assigned to that state.
fig, axes = plt.subplots(nrows=1, ncols=n_state, figsize=(5*n_state, 5))

vae_state_order = np.arange(n_state)[vae_unique_sorted_state_ind]
for i, j in enumerate(vae_state_order):
    state = vae_dfnc_data_train[vae_kmeans_label == j]
    state_median_1d = np.median(state, axis=0)
    state_median_2d = vector2matrix(state_median_1d)
    plot_fnc(state_median_2d, axes[i], f"State {i+1}")

plt.tight_layout()
plt.savefig(os.path.join(res_path, f'vae_dfnc_kmeans{n_state}states_pt_sorted.png'), bbox_inches='tight', dpi=500)

In [None]:
# Windows per subject in each latent-space state, in the matched state order.
vae_dwell_time = vae_num_fnc_per_state / vae_num_sub_per_state
vae_dwell_time_sorted = vae_dwell_time[:, vae_unique_sorted_state_ind]

# One row of state labels per training subject.
vae_kmeans_label_2d = vae_kmeans_label.reshape((n_train, n_window))
vae_dwell_stats = compute_dwell_state(vae_kmeans_label_2d, vae_unique_sorted_state_ind, n_pt_train, n_hc_train)
vae_dwell_state_mean_pt, vae_dwell_state_ste_pt, vae_dwell_state_mean_hc, vae_dwell_state_ste_hc, vae_dwell_state_pvalue = vae_dwell_stats
vae_transition_stats = compute_transition_matrix(vae_kmeans_label_2d, vae_unique_sorted_state_ind, n_pt_train)
vae_transition_matrix_pt, vae_transition_matrix_hc, vae_transition_matrix, vae_transition_matrix_pvalue = vae_transition_stats

# Persist everything under res_path (file name -> array, saved in this order).
vae_train_outputs = {
    'vae_dwell_time.npy': vae_dwell_time_sorted,
    'vae_dwell_state_pvalue.npy': vae_dwell_state_pvalue,
    'vae_dwell_state_mean_pt.npy': vae_dwell_state_mean_pt,
    'vae_dwell_state_ste_pt.npy': vae_dwell_state_ste_pt,
    'vae_dwell_state_mean_hc.npy': vae_dwell_state_mean_hc,
    'vae_dwell_state_ste_hc.npy': vae_dwell_state_ste_hc,
    'vae_transition_matrix_pt.npy': vae_transition_matrix_pt,
    'vae_transition_matrix_hc.npy': vae_transition_matrix_hc,
    'vae_sorted_state_ind.npy': vae_unique_sorted_state_ind,
    'vae_kmeans_cluster_center.npy': vae_kmeans_cluster_center,
    'vae_kmeans_label.npy': vae_kmeans_label,
    'vae_transition_matrix.npy': vae_transition_matrix,
    'vae_transition_matrix_pvalue.npy': vae_transition_matrix_pvalue,
}
for fname, arr in vae_train_outputs.items():
    np.save(os.path.join(res_path, fname), arr)

In [None]:
# Decode every test latent back to dFNC feature space in one batched call.
# This replaces a one-row-at-a-time loop, works for any latent
# dimensionality, and runs on whatever device the model currently lives on.
model.eval()
decode_device = next(model.parameters()).device
with torch.no_grad():
    z_test_tensor = torch.as_tensor(z_test, dtype=torch.float32, device=decode_device)
    decoded_test = model.decode(z_test_tensor).cpu().numpy()
# Same shape and float64 dtype as the zero-initialised buffer this replaces.
vae_dfnc_data_test = decoded_test.reshape(dfnc_data_test_2d.shape).astype(np.float64)
np.save(os.path.join(res_path, 'vae_dfnc_data_test.npy'), vae_dfnc_data_test)

# vae_dfnc_data_test = np.load(os.path.join(res_path, 'vae_dfnc_data_test.npy'))
# vae_kmeans_label_test = np.load(os.path.join(res_path, 'vae_kmeans_label_test.npy'))

In [None]:
# Assign test latents to the states of the latent-space K-means fitted on training data.
kmeans = vae_kmeans_list[n_state-2]
vae_kmeans_label_test = kmeans.predict(z_test)

vae_num_sub_per_state_test, vae_ratio_sub_per_state_test = compute_sub_per_state(kmeans_label=vae_kmeans_label_test, n_pt=n_pt_test, n_window=n_window)
vae_num_fnc_per_state_test, vae_ratio_fnc_per_state_test = compute_fnc_per_state(kmeans_label=vae_kmeans_label_test, n_pt=n_pt_test, n_window=n_window)

# Test rows are patients first, then controls.
n_pt_rows_test = n_pt_test*n_window
vae_label_pt_test, vae_label_hc_test = vae_kmeans_label_test[:n_pt_rows_test], vae_kmeans_label_test[n_pt_rows_test:]
vae_dfnc_data_test_pt = vae_dfnc_data_test[:n_pt_rows_test,:]
vae_dfnc_data_test_hc = vae_dfnc_data_test[n_pt_rows_test:,:]

# Median generated map per state: [0] patients, [1] controls, [2] difference.
vae_dfnc_state_test = np.zeros((3, n_state, 53, 53))
for i in range(n_state):
    cluster_median_1d_pt = np.median(vae_dfnc_data_test_pt[vae_label_pt_test == i], axis=0)
    cluster_median_1d_hc = np.median(vae_dfnc_data_test_hc[vae_label_hc_test == i], axis=0)
    vae_dfnc_state_test[0,i] = vector2matrix(cluster_median_1d_pt)
    vae_dfnc_state_test[1,i] = vector2matrix(cluster_median_1d_hc)
    vae_dfnc_state_test[2,i] = vae_dfnc_state_test[0,i] - vae_dfnc_state_test[1,i]

# corr_test[k, i, j]: correlation of original test state i with generated
# test state j for group k (0 = patients, 1 = controls).
corr_test = np.zeros((2, n_state, n_state))
for k in range(2):
    for i in range(n_state):
        for j in range(n_state):
            corr_test[k,i,j] = np.corrcoef(dfnc_state_test[k,i].ravel(), vae_dfnc_state_test[k,j].ravel())[0,1]

# Best match per original state from the patient maps, then made unique.
vae_sorted_state_ind_test = np.argmax(corr_test[0], axis=1)
vae_unique_sorted_state_ind_test = find_unique_ind(vae_sorted_state_ind_test, corr_test, vae_ratio_fnc_per_state_test)
vae_ratio_fnc_per_state_sorted_test = vae_ratio_fnc_per_state_test[:, vae_unique_sorted_state_ind_test]
print(vae_unique_sorted_state_ind_test)

In [None]:
# Generated (decoded) state maps on the test set, in the matched order, plus
# their correlation with the original test maps at the same position.
# NOTE: vae_dfnc_state_test is re-created here and refilled in sorted order,
# replacing the unsorted array of the same name from the matching cell.
vae_corr_state_test = np.zeros((2, n_state))
vae_dfnc_state_test = np.zeros((3, n_state, 53, 53))

fig, axes = plt.subplots(nrows=3, ncols=n_state, figsize=(5*n_state, 5*3))

vae_state_order_test = np.arange(n_state)[vae_unique_sorted_state_ind_test]
for i, j in enumerate(vae_state_order_test):
    state_pt = vae_dfnc_data_test_pt[vae_label_pt_test == j]
    state_hc = vae_dfnc_data_test_hc[vae_label_hc_test == j]

    # Feature-wise two-sample t-test, thresholded at 0.05 / number of features.
    stat, pvalue = stats.ttest_ind(a=state_pt, b=state_hc, equal_var=True)
    pvalue_map = vector2matrix(pvalue)
    pvalue_mask = pvalue_map <= (0.05/len(pvalue))
    # The upper triangle (diagonal included) is always shown unmasked.
    upper_triangle_mask = np.triu(np.ones_like(pvalue_mask)).astype(bool)
    pvalue_mask[upper_triangle_mask] = 1

    cluster_median_1d_pt = np.median(state_pt, axis=0)
    cluster_median_1d_hc = np.median(state_hc, axis=0)

    vae_dfnc_state_test[0,i] = vector2matrix(cluster_median_1d_pt)
    vae_dfnc_state_test[1,i] = vector2matrix(cluster_median_1d_hc)
    vae_dfnc_state_test[2,i] = vae_dfnc_state_test[0,i] - vae_dfnc_state_test[1,i]

    # Tick labels only on the first column.
    show_xticks = (i == 0)

    title = f"{round(vae_ratio_fnc_per_state_sorted_test[0,i]*100,1)}% ASD\n{round(vae_ratio_fnc_per_state_sorted_test[1,i]*100,1)}% CTR"
    plot_fnc(vae_dfnc_state_test[0,i], axes[0,i], title, show_xticks=show_xticks)
    plot_fnc(vae_dfnc_state_test[1,i], axes[1,i], show_xticks=show_xticks)
    plot_fnc(vae_dfnc_state_test[2,i] * pvalue_mask * 2, axes[2,i], show_xticks=show_xticks)

    # Original vs. generated map, per group (0 = patients, 1 = controls).
    for k in range(2):
        vae_corr_state_test[k,i] = np.corrcoef(dfnc_state_test[k,i].ravel(), vae_dfnc_state_test[k,i].ravel())[0,1]
    print(vae_corr_state_test[0,i], vae_corr_state_test[1,i])

plt.tight_layout()
plt.savefig(os.path.join(res_path, f'vae_dfnc_kmeans_{n_state}states_pt_hc_pt-hc_sorted_test.png'), dpi=500)

In [None]:
# Windows per subject in each latent-space state on the test set, in the matched order.
vae_dwell_time_test = vae_num_fnc_per_state_test / vae_num_sub_per_state_test
vae_dwell_time_sorted_test = vae_dwell_time_test[:, vae_unique_sorted_state_ind_test]

# One row of state labels per test subject.
vae_kmeans_label_2d_test = vae_kmeans_label_test.reshape((n_test, n_window))
vae_dwell_stats_test = compute_dwell_state(vae_kmeans_label_2d_test, vae_unique_sorted_state_ind_test, n_pt_test, n_hc_test)
vae_dwell_state_mean_pt_test, vae_dwell_state_ste_pt_test, vae_dwell_state_mean_hc_test, vae_dwell_state_ste_hc_test, vae_dwell_state_pvalue_test = vae_dwell_stats_test
vae_transition_stats_test = compute_transition_matrix(vae_kmeans_label_2d_test, vae_unique_sorted_state_ind_test, n_pt_test)
vae_transition_matrix_pt_test, vae_transition_matrix_hc_test, vae_transition_matrix_test, vae_transition_matrix_pvalue_test = vae_transition_stats_test

# Persist everything under res_path (file name -> array, saved in this order).
vae_test_outputs = {
    'vae_dwell_time_test.npy': vae_dwell_time_sorted_test,
    'vae_dwell_state_pvalue_test.npy': vae_dwell_state_pvalue_test,
    'vae_dwell_state_mean_pt_test.npy': vae_dwell_state_mean_pt_test,
    'vae_dwell_state_ste_pt_test.npy': vae_dwell_state_ste_pt_test,
    'vae_dwell_state_mean_hc_test.npy': vae_dwell_state_mean_hc_test,
    'vae_dwell_state_ste_hc_test.npy': vae_dwell_state_ste_hc_test,
    'vae_transition_matrix_pt_test.npy': vae_transition_matrix_pt_test,
    'vae_transition_matrix_hc_test.npy': vae_transition_matrix_hc_test,
    'vae_sorted_state_ind_test.npy': vae_unique_sorted_state_ind_test,
    'vae_kmeans_label_test.npy': vae_kmeans_label_test,
    'vae_transition_matrix_test.npy': vae_transition_matrix_test,
    'vae_transition_matrix_pvalue_test.npy': vae_transition_matrix_pvalue_test,
}
for fname, arr in vae_test_outputs.items():
    np.save(os.path.join(res_path, fname), arr)

In [None]:
def _rowwise_pearson(a, b):
    """Pearson correlation between matching rows of two equally shaped 2-D arrays.

    Returns a 1-D array whose i-th entry is corr(a[i], b[i]). These are the
    same values as the diagonal of the off-diagonal block of
    ``np.corrcoef(a, b)``, but computed in O(rows * cols) time and memory
    instead of materialising a (2*rows x 2*rows) correlation matrix.
    A constant row yields NaN, as it does with np.corrcoef.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    a_centered = a - a.mean(axis=1, keepdims=True)
    b_centered = b - b.mean(axis=1, keepdims=True)
    covariance = np.einsum('ij,ij->i', a_centered, b_centered)
    norm = np.sqrt(np.einsum('ij,ij->i', a_centered, a_centered)
                   * np.einsum('ij,ij->i', b_centered, b_centered))
    # np.corrcoef clips its result to [-1, 1]; do the same for parity.
    return np.clip(covariance / norm, -1.0, 1.0)


# Per-window correlation between each generated dFNC vector and the original
# vector it was reconstructed from. The full cross-correlation matrices are
# deliberately not built: only their matching-row entries are needed.
dfnc_corr_vector_train = _rowwise_pearson(dfnc_data_train_2d, vae_dfnc_data_train)
dfnc_corr_vector_test = _rowwise_pearson(dfnc_data_test_2d, vae_dfnc_data_test)

print(np.mean(dfnc_corr_vector_train))
print(np.mean(dfnc_corr_vector_test))

np.save(os.path.join(res_path, 'dfnc_corr_train.npy'), dfnc_corr_vector_train)
np.save(os.path.join(res_path, 'dfnc_corr_test.npy'), dfnc_corr_vector_test)