# Latent Representation Learning

This notebook implements the latent learning fusion method introduced in https://pubmed.ncbi.nlm.nih.gov/31021792/

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter, defaultdict
from importlib import reload
import seaborn as sns
import os, sys
import torch
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler
from sklearn.model_selection import train_test_split

In [2]:
target_encoder = LabelEncoder()
onehot_encoder = OneHotEncoder(sparse=False)

def load_data(impute_method = 'Mean-Mode', target = 'NACCAD3'):
    uds = pd.read_csv("../data/data_imputed/{}/uds.csv".format(impute_method))
    uds['datetime'] = pd.to_datetime(uds['datetime'])
    uds = uds.dropna(subset=[target, 'EDUC'])
    print("UDS Shape:  ", uds.shape)
    print("Target Distribution: {}\n".format(Counter(uds[target])))
    uds[target] = target_encoder.fit_transform(uds[target])
    onehot_encoder.fit(uds[target].values.reshape(-1, 1))
    mri = pd.read_csv("../data/data_imputed/{}/mri.csv".format(impute_method))
    mri['datetime'] = pd.to_datetime(mri['datetime'])
    
    csf = pd.read_csv("../data/data_imputed/{}/csf.csv".format(impute_method))
    return uds, mri, csf

uds_dict = pd.read_csv("../data/data_dictionary/uds_feature_dictionary_cleaned.csv")
mri_dict = pd.read_csv("../data/data_dictionary/mri_feature_dictionary_cleaned.csv") 

uds_drop_columns = ['NACCID', 'NACCADC', 'NACCVNUM', 'datetime', 'SEX', 'NACCAGE','EDUC', 
                    'NACCUDSD', 'NACCALZP', 'NACCAD3', 'NACCAD5']
mri_drop_columns = ['NACCID', 'NACCVNUM', 'datetime', 'datetime_UDS', 'timediff', 'within-a-year']
csf_drop_columns = ['NACCID', 'CSFABMD', 'CSFTTMD', 'CSFPTMD']

In [3]:
uds, mri, csf = load_data(impute_method = 'Mean-Mode', target = 'NACCAD3')
print(uds.shape, mri.shape, csf.shape)

UDS Shape:   (34025, 89)
Target Distribution: Counter({'Healthy': 17673, 'Dementia-AD': 11882, 'MCI-AD': 4470})

(34025, 89) (2873, 161) (2180, 7)


# Latent Representation Learning 

In [4]:
common_pid = set(uds['NACCID']).intersection(mri['NACCID']).intersection(csf['NACCID'])
cpidTr, cpidTe = train_test_split(np.array(list(common_pid)), test_size = 0.3)

def train_test_split_modality(df, cpidTr, cpidTe):
    dfV = df[~df['NACCID'].isin(common_pid)]
    idx = dfV.index.to_frame()
    iTr, iTe = train_test_split(idx[0], test_size = 0.3)
    dfVTr, dfVTe = dfV.loc[iTr], dfV.loc[iTe]
    dfTr = pd.concat([df[df['NACCID'].isin(cpidTr)], dfVTr], axis=0).reset_index(drop=True)
    dfTe = pd.concat([df[df['NACCID'].isin(cpidTe)], dfVTe], axis=0).reset_index(drop=True)
    return dfTr, dfTe

mri, csf = mri[mri['NACCID'].isin(uds['NACCID'])], csf[csf['NACCID'].isin(uds['NACCID'])]
uds, mri, csf = uds.sort_values('NACCID'), mri.sort_values('NACCID'), csf.sort_values('NACCID')

udsTr, udsTe = train_test_split_modality(uds, cpidTr, cpidTe)
mriTr, mriTe = train_test_split_modality(mri, cpidTr, cpidTe)
csfTr, csfTe = train_test_split_modality(csf, cpidTr, cpidTe) 

In [5]:
def prepare_X_Y(udsT, mriT, csfT, scalerU=None, scalerM=None, scalerC=None):
    XU, XM, XC = udsT.drop(uds_drop_columns, axis=1), mriT.drop(mri_drop_columns, axis=1), csfT.drop(csf_drop_columns, axis=1)
    if scalerU is None:
        scalerU, scalerM, scalerC = StandardScaler(), StandardScaler(), StandardScaler()
        XU, XM, XC = scalerU.fit_transform(XU).T, scalerM.fit_transform(XM).T, scalerC.fit_transform(XC).T
    else:
        XU, XM, XC = scalerU.transform(XU).T, scalerM.transform(XM).T, scalerC.transform(XC).T
    XU, XM, XC = torch.tensor(XU).float(), torch.tensor(XM).float(), torch.tensor(XC).float()

    YU, YM, YC = udsT['NACCAD3'], uds.merge(mriT, on='NACCID')['NACCAD3'], uds.merge(csfT, on='NACCID')['NACCAD3']
    YU, YM, YC = onehot_encoder.transform(YU.values.reshape(-1, 1)),onehot_encoder.transform(YM.values.reshape(-1, 1)),onehot_encoder.transform(YC.values.reshape(-1, 1))
    YU, YM, YC = torch.tensor(YU.T).float(), torch.tensor(YM.T).float(), torch.tensor(YC.T).float()
    return XU, XM, XC, YU, YM, YC, scalerU, scalerM, scalerC

In [6]:
def train_weights(W, X, H, E, Q, mu, verbose = False):
    for i in range(100):
        G = torch.diag((W**2).sum(1))
        W_new = torch.inverse(X.matmul(X.T) + (2*beta) / mu * G).matmul(X).matmul((H + E - Q / mu).T)
        if ((W_new - W)**2).sum() < 1e-6:
            break
        W = W_new
    if verbose and i == 100:
        print("Not converged")
    return W

def update_E(E, W, X, H, Q, mu):
    E_new = torch.zeros_like(E)
    temp = W.T.matmul(X) - H + Q / mu
    E_new = E_new + (temp > gamma / mu).int() * temp - gamma/mu
    E_new = E_new + (temp < -gamma / mu).int() * temp + gamma/mu
    return E_new

def calculate_accuracy(XU, XM, XC, YU, YM, YC, WU, WM, WC, P, nH):
    HU, HM, HC = WU.T.matmul(XU), WM.T.matmul(XM), WC.T.matmul(XC)
    H = (HU[:,:nH] + HM[:,:nH] + HC[:,:nH]) / 3
    YU_hat = P.matmul(torch.concat([H, HU[:,nH:]], axis=1)).argmax(axis=0)
    YM_hat = P.matmul(torch.concat([H, HM[:,nH:]], axis=1)).argmax(axis=0)
    YC_hat = P.matmul(torch.concat([H, HC[:,nH:]], axis=1)).argmax(axis=0)
    return "U-acc: {:.2f}%\tM-acc: {:.2f}%\tC-acc: {:.2f}%".format(
        (YU_hat == YU.argmax(0)).float().mean() * 100,
        (YM_hat == YM.argmax(0)).float().mean() * 100,
        (YC_hat == YC.argmax(0)).float().mean() * 100)
    
# maxiter, stop, rho, max_mu =1000, 1, 1.5, 1e6
def train_latent_features(maxiter=1000, stop = 1, rho = 1.5, mu=1e-2, max_mu = 1e6):
#     Train individual weights
    for i in range(maxiter):
        H = torch.randn(kH, nH)
        P = torch.randn(3, kH)
        WU, WM, WC = torch.randn(kU, kH), torch.randn(kM, kH), torch.randn(kC, kH)
        HU, EU, QU = torch.randn(kH, XU.shape[1]-nH), torch.randn(kH, XU.shape[1]), torch.ones(kH, XU.shape[1])
        HM, EM, QM = torch.randn(kH, XM.shape[1]-nH), torch.randn(kH, XM.shape[1]), torch.ones(kH, XM.shape[1])
        HC, EC, QC = torch.randn(kH, XC.shape[1]-nH), torch.randn(kH, XC.shape[1]), torch.ones(kH, XC.shape[1])

        WU = train_weights(WU, XU, torch.concat([H, HU], axis=1), EU, QU, mu)
        WM = train_weights(WM, XM, torch.concat([H, HM], axis=1), EM, QM, mu)
        WC = train_weights(WC, XC, torch.concat([H, HC], axis=1), EC, QC, mu)

        EU_new = update_E(EU, WU, XU, torch.concat([H, HU], axis=1), QU, mu)
        EM_new = update_E(EM, WM, XM, torch.concat([H, HM], axis=1), QM, mu)
        EC_new = update_E(EC, WC, XC, torch.concat([H, HC], axis=1), QC, mu)

        HU = torch.inverse((P.T.matmul(P)) + mu * torch.eye(kH)).matmul(
            P.T.matmul(YU[:,nH:]) + mu * (WU.T.matmul(XU[:,nH:]) - EU[:,nH:] + QU[:,nH:] / mu))
        HM = torch.inverse((P.T.matmul(P)) + mu * torch.eye(kH)).matmul(
            P.T.matmul(YM[:,nH:]) + mu * (WM.T.matmul(XM[:,nH:]) - EM[:,nH:] + QM[:,nH:] / mu))
        HC = torch.inverse((P.T.matmul(P)) + mu * torch.eye(kH)).matmul(
            P.T.matmul(YC[:,nH:]) + mu * (WC.T.matmul(XC[:,nH:]) - EC[:,nH:] + QC[:,nH:] / mu))

        # train global parameters
        H = torch.inverse(P.T.matmul(P) + mu * torch.eye(kH) * 3).matmul(P.T.matmul(YU[:,:nH]) + mu * (
        (WU.T.matmul(XU[:,:nH]) - EU[:,:nH] + QU[:,:nH] / mu) + 
        (WM.T.matmul(XM[:,:nH]) - EM[:,:nH] + QM[:,:nH] / mu) + 
        (WC.T.matmul(XC[:,:nH]) - EC[:,:nH] + QC[:,:nH] / mu)
        ))

        H_all = torch.concat([H, HU, HM, HC], axis=1)
        Y_all = torch.concat([YU[:,:nH], YU[:,nH:], YM[:,nH:], YC[:,nH:]],axis=1)
        P = Y_all.matmul(H_all.T).matmul(torch.inverse(H_all.matmul(H_all.T) + eta * torch.eye(kH)))

        QU = QU + mu * (WU.T.matmul(XU) - torch.concat([H, HU], axis=1) - EU)
        QM = QM + mu * (WM.T.matmul(XM) - torch.concat([H, HM], axis=1) - EM)
        QC = QC + mu * (WC.T.matmul(XC) - torch.concat([H, HC], axis=1) - EC)

        mu = min(rho * mu, max_mu)

        errU = torch.abs(WU.T.matmul(XU) - torch.concat([H, HU], axis=1) - EU).max()
        errM = torch.abs(WM.T.matmul(XM) - torch.concat([H, HM], axis=1) - EM).max()
        errC = torch.abs(WC.T.matmul(XC) - torch.concat([H, HC], axis=1) - EC).max()
        error = max(errU, errM, errC)
        if error < stop:
            break
        if i % 20 == 0:
            print("Iter {}:  error-{:.3f}\t{}".format(i, error, 
                                                    calculate_accuracy(XU, XM, XC, YU, YM, YC, WU, WM, WC, P, len(cpidTr))))
    if i == maxiter:
        print("Not Converged")
    return WU, WM, WC, P

In [7]:
XU, XM, XC, YU, YM, YC, scalerU, scalerM, scalerC = prepare_X_Y(udsTr, mriTr, csfTr)
kU, kM, kC = XU.shape[0], XM.shape[0], XC.shape[0]

nH, kH = len(cpidTr), 20
lamb, beta, gamma, eta = 0.1, 0.1, 0.1, 0.1

WU, WM, WC, P = train_latent_features(maxiter=5000, stop = 1, rho = 1.5, mu=1e-2, max_mu = 1e6)

Iter 0:  error-117.747	U-acc: 23.23%	M-acc: 19.69%	C-acc: 27.54%
Iter 20:  error-3.490	U-acc: 57.17%	M-acc: 37.61%	C-acc: 39.89%
Iter 40:  error-3.638	U-acc: 75.57%	M-acc: 40.80%	C-acc: 37.84%
Iter 60:  error-2.931	U-acc: 71.25%	M-acc: 42.26%	C-acc: 42.25%
Iter 80:  error-3.047	U-acc: 73.98%	M-acc: 38.65%	C-acc: 46.18%
Iter 100:  error-3.479	U-acc: 70.40%	M-acc: 40.49%	C-acc: 29.66%
Iter 120:  error-3.489	U-acc: 71.62%	M-acc: 39.27%	C-acc: 41.54%
Iter 140:  error-3.264	U-acc: 72.99%	M-acc: 42.94%	C-acc: 43.67%
Iter 160:  error-3.265	U-acc: 70.50%	M-acc: 42.26%	C-acc: 39.58%
Iter 180:  error-3.247	U-acc: 68.08%	M-acc: 38.17%	C-acc: 40.52%
Iter 200:  error-3.028	U-acc: 73.96%	M-acc: 39.33%	C-acc: 42.25%
Iter 220:  error-3.467	U-acc: 64.07%	M-acc: 38.47%	C-acc: 40.91%
Iter 240:  error-3.036	U-acc: 73.54%	M-acc: 41.04%	C-acc: 41.62%
Iter 260:  error-3.049	U-acc: 74.76%	M-acc: 39.14%	C-acc: 42.41%
Iter 280:  error-3.064	U-acc: 67.06%	M-acc: 38.59%	C-acc: 39.89%
Iter 300:  error-3.145	U-acc:

Iter 2500:  error-3.149	U-acc: 72.30%	M-acc: 40.00%	C-acc: 38.87%
Iter 2520:  error-3.485	U-acc: 71.42%	M-acc: 40.43%	C-acc: 48.78%
Iter 2540:  error-3.072	U-acc: 73.08%	M-acc: 39.88%	C-acc: 38.79%
Iter 2560:  error-3.233	U-acc: 66.48%	M-acc: 40.92%	C-acc: 39.65%
Iter 2580:  error-3.249	U-acc: 75.02%	M-acc: 41.47%	C-acc: 42.41%
Iter 2600:  error-3.293	U-acc: 69.10%	M-acc: 41.96%	C-acc: 43.90%
Iter 2620:  error-3.259	U-acc: 74.96%	M-acc: 41.65%	C-acc: 47.21%
Iter 2640:  error-3.111	U-acc: 73.22%	M-acc: 41.77%	C-acc: 40.52%
Iter 2660:  error-3.034	U-acc: 74.26%	M-acc: 39.02%	C-acc: 42.80%
Iter 2680:  error-3.015	U-acc: 74.58%	M-acc: 39.82%	C-acc: 44.93%
Iter 2700:  error-3.314	U-acc: 73.34%	M-acc: 38.41%	C-acc: 45.16%
Iter 2720:  error-3.252	U-acc: 70.62%	M-acc: 41.28%	C-acc: 48.70%
Iter 2740:  error-3.073	U-acc: 70.93%	M-acc: 38.65%	C-acc: 44.93%
Iter 2760:  error-2.910	U-acc: 69.58%	M-acc: 40.98%	C-acc: 45.55%
Iter 2780:  error-3.249	U-acc: 57.57%	M-acc: 34.50%	C-acc: 36.90%
Iter 2800:

In [8]:
XUTe, XMTe, XCTe, YUTe, YMTe, YCTe, scalerU, scalerM, scalerC = prepare_X_Y(udsTe, mriTe, csfTe, scalerU, scalerM, scalerC)

In [10]:
print(calculate_accuracy(XUTe, XMTe, XCTe, YUTe, YMTe, YCTe, WU, WM, WC, P, len(cpidTr)))

U-acc: 73.23%	M-acc: 36.84%	C-acc: 39.31%
