In [None]:
from pathlib import Path
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn import datasets
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import LinearRegression
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer, SimpleImputer
import pandas as pd
import pickle
from scipy.spatial.distance import pdist, cdist
from scipy.spatial import procrustes
from scipy.linalg import orthogonal_procrustes
from scipy.linalg import svd
%matplotlib inline

In [None]:
def stable_component_marriage(Vm, Vw):
    """Pair each row of ``Vw`` with a distinct row of ``Vm`` (Gale-Shapley).

    Rows of ``Vm`` ("men") propose to rows of ``Vw`` ("women") in order of
    decreasing absolute Pearson correlation, and each woman keeps whichever
    proposer she correlates with most strongly.  This is the proposer-optimal
    stable matching: no (m, w) pair exists where both would rather be with each
    other than with their assigned partners.  Because that matching is unique,
    the result does not depend on the order in which free rows propose.

    Parameters
    ----------
    Vm, Vw : array-like, shape (n, p)
        Two sets of ``n`` component vectors over the same ``p`` features
        (e.g. rows of an SVD ``Vt``).

    Returns
    -------
    list of int
        ``w_partner[w]`` is the row of ``Vm`` matched to row ``w`` of ``Vw``,
        so ``Vm[w_partner]`` reorders ``Vm`` to line up with ``Vw``.

    Raises
    ------
    ValueError
        If ``Vm`` and ``Vw`` do not have the same number of rows.
    """
    n = len(Vm)
    if len(Vw) != n:
        raise ValueError(
            f"Vm and Vw must have the same number of rows, got {n} and {len(Vw)}")

    # m_prefs[m] lists women from most to least preferred by man m;
    # w_order[w] lists men from most to least preferred by woman w.
    m_prefs, w_order = _calc_prefs(Vm, Vw)
    # Invert each woman's ordering into a rank table: w_rank[w, m] is the
    # position of man m in woman w's list (0 = most preferred).  Comparing
    # entries of the ordering itself would compare the identities of the m-th
    # and m1-th choices, not how much she prefers m over m1.
    w_rank = np.argsort(w_order, axis=1)

    # plain lists are a little faster than arrays for this scalar bookkeeping
    w_partner = [-1] * n
    next_proposal = [0] * n                 # position in m_prefs[m] to try next
    free_men = list(range(n - 1, -1, -1))   # stack; pop() yields man 0 first

    while free_men:
        m = free_men.pop()
        w = m_prefs[m, next_proposal[m]]
        next_proposal[m] += 1
        current = w_partner[w]
        if current == -1:
            # w is unengaged: accept
            w_partner[w] = m
        elif w_rank[w, m] < w_rank[w, current]:
            # w prefers m: swap, and her previous partner becomes free again
            w_partner[w] = m
            free_men.append(current)
        else:
            # rejected: m stays free and proposes to his next choice
            free_men.append(m)
    return w_partner


def _calc_prefs(Vm, Vw):
    """Preference orderings used by ``stable_component_marriage``.

    Preference strength is the absolute Pearson correlation between rows
    (``1 - correlation distance``), so a sign flip of a component does not
    change how strongly it is preferred.

    Returns
    -------
    Pm : ndarray of int, shape (len(Vm), len(Vw))
        ``Pm[m]`` lists the rows of ``Vw`` from most to least correlated with
        ``Vm[m]``.
    Pw : ndarray of int, shape (len(Vw), len(Vm))
        ``Pw[w]`` lists the rows of ``Vm`` from most to least correlated with
        ``Vw[w]``.  Note these are orderings, not ranks.
    """
    corrs = np.abs(1 - cdist(Vm, Vw, metric='correlation'))
    Pm = np.argsort(-corrs)
    Pw = np.argsort(-corrs.T)
    return Pm, Pw

def clean_phenotype_data(dat, cols, age_sex_regress=True, impute=True, censor=True,
                         random_state=None):
    """Residualize, z-score, censor and impute phenotype columns.

    Steps, in order:

    1. Add covariate columns ``age_z`` (z-scored ``age_at_scan``), ``age_z_2``,
       ``age_z_by_sex`` and ``age_z_2_by_sex``.
    2. Cast every column in ``cols`` to float and, if ``age_sex_regress``,
       replace its observed values with residuals from a linear regression on
       ``age_z``, ``sex_at_birth`` and the three derived terms.  Missing values
       stay missing; an entirely missing column is left untouched.
    3. z-score every column in ``cols`` (sample std, ddof=1).
    4. If ``censor``, set values with |z| > 3 to NaN.
    5. If ``impute``, fill NaNs with ``IterativeImputer``.  Otherwise drop every
       participant that has any NaN in ``cols``.
    6. z-score ``cols`` again, since step 5 shifts the distributions.

    Parameters
    ----------
    dat : pandas.DataFrame
        Must contain ``age_at_scan``, ``sex_at_birth``, every column in ``cols``
        and, when ``impute`` is False, ``participant``.  Not modified.
    cols : sequence of str
        Phenotype columns to clean.
    age_sex_regress : bool
        Enable step 2.
    impute : bool
        Impute (True) or listwise-delete (False) in step 5.
    censor : bool
        Enable step 4.
    random_state : int or None
        Seed passed to ``IterativeImputer``.  ``None`` falls back to the
        notebook-level ``seed`` variable, which is what was always used before.

    Returns
    -------
    pandas.DataFrame
        A copy of ``dat`` (possibly with rows removed) with cleaned ``cols`` and
        the added covariate columns.
    """
    cols = list(cols)
    dat = dat.copy()

    def _zscore(frame):
        # column-wise z-score with the pandas default (ddof=1) std
        return frame.apply(lambda x: (x - x.mean()) / x.std(), axis=0)

    # age / sex regressors; age is z-scored so the quadratic term is well scaled
    dat['age_z'] = (dat.age_at_scan - dat.age_at_scan.mean()) / dat.age_at_scan.std(ddof=1)
    dat['age_z_2'] = dat.age_z ** 2
    dat['age_z_by_sex'] = dat.age_z * dat.sex_at_birth
    dat['age_z_2_by_sex'] = dat.age_z_2 * dat.sex_at_birth
    covariates = ['age_z', 'sex_at_birth', 'age_z_2', 'age_z_by_sex', 'age_z_2_by_sex']

    for col in cols:
        dat[col] = dat[col].astype(float)
        if not age_sex_regress:
            continue
        observed = dat[col].notnull()
        if not observed.any():
            # nothing to fit against; leave the all-missing column as is
            continue
        X = dat.loc[observed, covariates]
        y = dat.loc[observed, col]
        lreg = LinearRegression().fit(X, y)
        residuals = pd.Series(np.nan, index=dat.index)
        # positional assignment: y and X are in the same row order as `observed`
        residuals.loc[observed] = (y - lreg.predict(X)).to_numpy()
        dat[col] = residuals

    dat[cols] = _zscore(dat[cols])

    if censor:
        # treat |z| > 3 as outliers and blank them out
        dat[cols] = dat[cols].mask(dat[cols].abs() > 3)

    if impute:
        if random_state is None:
            random_state = seed  # notebook-level global, see docstring
        imputer = IterativeImputer(random_state=random_state, max_iter=1000, skip_complete=True)
        dat[cols] = imputer.fit_transform(dat[cols])
    else:
        # listwise deletion: keep only participants with every phenotype present
        complete = dat[cols].notnull().all(axis=1)
        clean_participants = dat.loc[complete, 'participant'].values
        dat = dat.loc[dat.participant.isin(clean_participants)].copy()

    dat[cols] = _zscore(dat[cols])
    return dat

# def sm_corr(Vm, Vw):
#     w_partner = stable_component_marriage(Vm, Vw)
#     return (1-cdist(V[0], V[1], metric='correlation'))
    

In [None]:
seed = 13231123
rng = np.random.default_rng(seed)

In [None]:
data_dir = Path('../data')

In [None]:
pnc_pkl = data_dir / 'latent_reliability_PNC.pickle'
pnc_dat = pickle.loads(pnc_pkl.read_bytes())

In [None]:
for ix, vals in pnc_dat.items():
    break

# Create new phenotypic dataset

In [None]:
pnc_dir = data_dir / 'pnc'
test_dat = pd.read_csv(pnc_dir / 'phenotype_measures_testing.txt')
train_dat = pd.read_csv(pnc_dir / 'phenotype_measures_training.txt')
all_dat = pd.concat([test_dat, train_dat]).reset_index(drop=True)

In [None]:
phenotype_columns = pnc_dat[0]['before_svd_split1'].columns

In [None]:
# make sure there isn't any all nan participant
assert (all_dat.loc[:, phenotype_columns].isnull().mean()!=1).all()

In [None]:
all_pheno_dat = all_dat.loc[:, ['participant', 'family_ID', 'sex_at_birth', 'age_at_scan'] + list(phenotype_columns)]
all_pheno_dat['sex_at_birth'] = all_pheno_dat.sex_at_birth - 1

In [None]:
n_splits = 1000
split_participants = []
split_stats = []
for spix in range(n_splits):
    fids1 = all_pheno_dat.groupby('family_ID').participant.first().reset_index().family_ID.sample(frac=0.5, replace=False).values
    pids1 = all_pheno_dat.loc[all_pheno_dat.family_ID.isin(fids1), 'participant'].values
    pids2 = all_pheno_dat.loc[~all_pheno_dat.family_ID.isin(fids1), 'participant'].values
    split_participants.append((pids1, pids2))
    row = dict(
        split=spix,
        n_1=all_pheno_dat.family_ID.isin(fids1).sum(),
        n_fam_1=len(fids1),
        n_male_1=(all_pheno_dat.loc[all_pheno_dat.family_ID.isin(fids1), 'sex_at_birth']).sum(),
        frac_male_1=(all_pheno_dat.loc[all_pheno_dat.family_ID.isin(fids1), 'sex_at_birth']).mean(),
        mean_age_1=all_pheno_dat.loc[all_pheno_dat.family_ID.isin(fids1), 'age_at_scan'].mean(),
        std_age_1=all_pheno_dat.loc[all_pheno_dat.family_ID.isin(fids1), 'age_at_scan'].std(),

        n_2 = (~all_pheno_dat.family_ID.isin(fids1)).sum(),
        n_fam_2=all_pheno_dat.loc[~all_pheno_dat.family_ID.isin(fids1), 'family_ID'].nunique(),
        n_male_2=(all_pheno_dat.loc[~all_pheno_dat.family_ID.isin(fids1), 'sex_at_birth']).sum(),
        frac_male_2=(all_pheno_dat.loc[~all_pheno_dat.family_ID.isin(fids1), 'sex_at_birth']).mean(),
        mean_age_2=all_pheno_dat.loc[~all_pheno_dat.family_ID.isin(fids1), 'age_at_scan'].mean(),
        std_age_2=all_pheno_dat.loc[~all_pheno_dat.family_ID.isin(fids1), 'age_at_scan'].std()
    )
    split_stats.append(row)
split_stats=pd.DataFrame(split_stats)

In [None]:
split_stats.describe()

# See what we can do with split halves

In [None]:
phenotypes = clean_phenotype_data(all_dat, phenotype_columns).loc[:, phenotype_columns]


In [None]:
V1s = []
V2s = []
pct_var1s = []
pct_var2s = []
match_corrs = []
sp1dats = []
sp2dats = []
age_sex_regress = True

oU, os, oV = svd(phenotypes, full_matrices=False)

for iiii, (pids1, pids2) in enumerate(split_participants):
    print(iiii, end=',')
    sp1dat = clean_phenotype_data(all_dat.loc[all_dat.participant.isin(pids1)].copy(),
                                  phenotype_columns,
                                  age_sex_regress=age_sex_regress).loc[:, phenotype_columns]
    sp2dat = clean_phenotype_data(all_dat.loc[all_dat.participant.isin(pids2)].copy(),
                                  phenotype_columns,
                                  age_sex_regress=age_sex_regress).loc[:, phenotype_columns]
    sp1dats.append(sp1dat)
    sp2dats.append(sp2dat)
    U1, s1, V1 = svd(sp1dat, full_matrices=False)
    U2, s2, V2 = svd(sp2dat, full_matrices=False)

    component_mapping1 = stable_component_marriage(V1, oV)
    mapped_V1 = V1[component_mapping1]
    V1_signs = np.sign([np.corrcoef(oVc, mVc)[0,1] for oVc, mVc in zip(oV, mapped_V1)]).reshape(-1,1)
    mapped_V1 *= V1_signs
    component_mapping2 = stable_component_marriage(V2, oV)
    mapped_V2 = V2[component_mapping2]
    V2_signs = np.sign([np.corrcoef(oVc, mVc)[0,1] for oVc, mVc in zip(oV, mapped_V2)]).reshape(-1,1)
    mapped_V2 *= V2_signs
    
    mc = [np.corrcoef(oVc, mVc)[0,1] for oVc, mVc in zip(mapped_V1, mapped_V2)]
    match_corrs.append(mc)
    V1s.append(mapped_V1)
    V2s.append(mapped_V2)
    
    mapped_s1 = s1[component_mapping1]
    pct_var1 = s1 **2 / (s1 **2).sum()
    pct_var1s.append(pct_var1)
    mapped_s2 = s2[component_mapping2]
    pct_var2 = mapped_s2 **2 / (mapped_s2 **2).sum()
    pct_var2s.append(pct_var2)

In [None]:
alldats = zip(V1s, V2s, pct_var1s, pct_var2s, match_corrs, sp1dats, sp2dats, split_participants)

In [None]:
alldat_lod = []
for adix, (V1, V2, pct_var1, pct_var2, match_corr, sp1dat, sp2dat, (pids1, pids2)) in enumerate(alldats):
    alldat_lod.append(
        dict(
            V1=V1,
            V2=V2,
            pct_var1=pct_var1,
            pct_var2=pct_var2,
            match_corr=match_corr,
            sp1dat=sp1dat,
            sp2dat=sp2dat,
            pids1=pids1, 
            pids2=pids2
        )
    )

In [None]:
alldat_file = pnc_dir / 'alldat_arczi.pkl'
alldat_file.write_bytes(pickle.dumps(alldat_lod))

In [None]:
tad = pickle.loads(alldat_file.read_bytes())

In [None]:
np.all(tad[0]['V1'] == V1s[0])

In [None]:
spdats = list(zip(sp1dats, sp2dats))


In [None]:
spdat_file = pnc_dir / 'phenotype_split_arczi.pkl'
spdat_file.write_bytes(pickle.dumps(spdats))


In [None]:
sns.heatmap((1 - cdist(V1s[0], V2s[1], metric='correlation')), vmin=-1, vmax=1, cmap='RdBu')

In [None]:
mean_match_corr = np.tanh(np.mean(np.abs(np.arctanh(match_corrs)), 0))
mean_match_r2 = mean_match_corr ** 2
lci_match_corr = np.percentile(np.abs(match_corrs), 2.5, axis=0)
lci_match_r2 = lci_match_corr ** 2
uci_match_corr= np.percentile(np.abs(match_corrs), 97.5, axis=0)
uci_match_r2 = uci_match_corr ** 2

In [None]:
fig, ax = plt.subplots(1)
ax.fill_between(range(len(mean_match_corr)), lci_match_corr, uci_match_corr, alpha = 0.5)
ax.plot(mean_match_corr)
ax.set_ylim(0,1)
ax.set_ylabel('Inter-split correlation')
ax.set_xlabel('Latent number')

In [None]:
fig, ax = plt.subplots(1)
ax.fill_between(range(len(mean_match_r2)), lci_match_r2, uci_match_r2, alpha = 0.5)
ax.plot(mean_match_r2)
ax.set_ylim(0,1)
ax.set_ylabel('Inter-split $r^2$')
ax.set_xlabel('Latent number')

In [None]:
pct_vars = np.vstack([pct_var1s, pct_var2s])
cum_vars = np.cumsum(pct_vars, axis=1)

In [None]:
mean_pct_var = pct_vars.mean(0)
lci_pct_var = np.quantile(pct_vars, 0.025, axis=0)
uci_pct_var = np.quantile(pct_vars, 0.975, axis=0)

mean_cum_var = cum_vars.mean(0)
lci_cum_var = np.quantile(cum_vars, 0.025, axis=0)
uci_cum_var = np.quantile(cum_vars, 0.975, axis=0)

In [None]:
fig, ax = plt.subplots(1)
ax.fill_between(np.arange(len(lci_pct_var)), lci_pct_var, uci_pct_var, alpha = 0.5)
ax.plot(mean_pct_var)

ax.fill_between(np.arange(len(lci_cum_var)), lci_cum_var, uci_cum_var, alpha = 0.5)
ax.plot(mean_cum_var)

ymin, ymax = ax.get_ylim()
xmin, xmax = ax.get_xlim()
ax.hlines(0.95, xmin, (lci_cum_var > 0.95).nonzero()[0].min()-0.5, ls='--', color='black')
ax.vlines((lci_cum_var > 0.95).nonzero()[0].min()-0.5, ymin, 0.95, ls='--', color='black')
ax.text

In [None]:

age_sex_regress = True

oU, os, oV = svd(phenotypes, full_matrices=False)

for iiii, (pids1, pids2) in enumerate(split_participants):

In [None]:
prepro_options = [
    dict(age_sex_regress=False, impute=False, censor=False),
    dict(age_sex_regress=False, impute=True, censor=True),
    dict(age_sex_regress=True, impute=False, censor=False),
]

In [None]:


# if impute:
#     # impute
#     imputer = IterativeImputer(random_state=seed, max_iter=1000, skip_complete=True)
#     dat.loc[:, cols] = imputer.fit_transform(dat.loc[:, cols])
# else:
#     clean_participants = dat.loc[dat.loc[:, cols].isnull().sum(1)== 0, 'participant'].values
#     dat = dat.loc[dat.participant.isin(clean_participants)].copy()
# #    return dat

In [None]:
hcp_categories = [
    ('ASR_Aggr_T','asr', 'adj'),
    ('ASR_Anxd_Pct','asr', 'adj'),
    ('ASR_Attn_T','asr', 'adj'),
    ('ASR_Extn_T','asr', 'adj'),
    ('ASR_Intn_T','asr', 'adj'),
    ('ASR_Intr_T','asr', 'adj'),
    ('ASR_Rule_T','asr', 'adj'),
    ('ASR_Soma_T','asr', 'adj'),
    ('ASR_TAO_Sum','asr', 'unadj'),
    ('ASR_Thot_T','asr', 'adj'),
    ('ASR_Totp_T','asr', 'adj'),
    ('ASR_Witd_T','asr', 'adj'),
    ('AngAffect_Unadj','emo', 'unadj'),
    ('AngAggr_Unadj','emo', 'unadj'),
    ('AngHostil_Unadj','emo', 'unadj'),
    ('CardSort_AgeAdj','task', 'adj'),
    ('CardSort_Unadj','task', 'unadj'),
    ('CogCrystalComp_AgeAdj','int', 'adj'),
    ('CogEarlyComp_AgeAdj','int', 'adj'),
    ('CogFluidComp_AgeAdj','int', 'adj'),
    ('CogTotalComp_AgeAdj','int', 'adj'),
    ('DDisc_AUC_200','task', 'unadj'),
    ('DDisc_AUC_40K','task', 'unadj'),
    ('Dexterity_Unadj','motor', 'unadj'),
    ('ER40ANG','penn_ER', 'unadj'),
    ('ER40FEAR','penn_ER', 'unadj'),
    ('ER40NOE','penn_ER', 'unadj'),
    ('ER40SAD','penn_ER', 'unadj'),
    ('ER40_CR','penn_ER', 'unadj'),
    ('ER40_CRT','penn_ER', 'unadj'),
    ('EmotSupp_Unadj','emo', 'unadj'),
    ('Emotion_Task_Face_Acc','task', 'unadj'),
    ('Endurance_Unadj','motor', 'unadj'),
    ('FearAffect_Unadj','emo', 'unadj'),
    ('FearSomat_Unadj','emo', 'unadj'),
    ('Flanker_AgeAdj','task', 'adj'),
    ('Flanker_Unadj','task', 'unadj'),
    ('Friendship_Unadj','emo', 'unadj'),
    ('GaitSpeed_Comp','motor', 'unadj'),
    ('Handedness','emo', 'unadj'),
    ('IWRD_TOT','task', 'unadj'),
    ('InstruSupp_Unadj','emo', 'unadj'),
    ('Language_Task_Math_Avg_Difficulty_Level','language', 'unadj'),
    ('Language_Task_Story_Avg_Difficulty_Level','language', 'unadj'),
    ('LifeSatisf_Unadj','emo', 'unadj'),
    ('ListSort_AgeAdj','task', 'adj'),
    ('ListSort_Unadj','task', 'unadj'),
    ('Loneliness_Unadj','emo', 'unadj'),
    ('MMSE_Score','alertness', 'uadj'),
    ('Mars_Final','sensory', 'uadj'),
    ('MeanPurp_Unadj','emo', 'unadj'),
    ('NEOFAC_A','ffi', 'unadj'),
    ('NEOFAC_C','ffi', 'unadj'),
    ('NEOFAC_E','ffi', 'unadj'),
    ('NEOFAC_N','ffi', 'unadj'),
    ('NEOFAC_O','ffi', 'unadj'),
    ('Odor_Unadj','sensory', 'unadj'),
    ('PMAT24_A_CR','task', 'unadj'),
    ('PSQI_Score','alertness', 'unadj'),
    ('PainInterf_Tscore','emo', 'unadj'),
    ('PercHostil_Unadj','emo', 'unadj'),
    ('PercReject_Unadj','emo', 'unadj'),
    ('PercStress_Unadj','emo', 'unadj'),
    ('PicSeq_AgeAdj','task', 'adj'),
    ('PicSeq_Unadj','task', 'unadj'),
    ('PicVocab_AgeAdj','language', 'adj'),
    ('PicVocab_Unadj','language', 'unadj'),
    ('PosAffect_Unadj','emo', 'unadj'),
    ('ProcSpeed_AgeAdj','task', 'adj'),
    ('ProcSpeed_Unadj','task', 'unadj'),
    ('ReadEng_AgeAdj','language', 'adj'),
    ('ReadEng_Unadj','language', 'unadj'),
    ('Relational_Task_Acc','task', 'unadj'),
    ('SCPT_SEN','task', 'unadj'),
    ('SCPT_SPEC','task', 'unadj'),
    ('Sadness_Unadj','emo', 'unadj'),
    ('SelfEff_Unadj','emo', 'unadj'),
    ('Social_Task_Perc_Random','task', 'unadj'),
    ('Social_Task_Perc_TOM','task', 'unadj'),
    ('Strength_Unadj','motor', 'unadj'),
    ('Taste_Unadj','sensory', 'unadj'),
    ('VSPLOT_TC','task', 'unadj'),
    ('WM_Task_Acc','task', 'unadj'),
]
hcp_categories = pd.DataFrame(hcp_categories, columns=['name', 'category', 'adj'])

In [None]:
hcp_categories.groupby("category").count()

In [None]:
emo_cats = hcp_categories.query("category == 'emo'").name.values
task_cats = hcp_categories.query("category == 'task'").name.values
sorted_cats = hcp_categories.sort_values(['adj', 'category']).name.values
asr_cats = hcp_categories.query("category == 'asr'").name.values
ffi_cats = hcp_categories.query("category == 'ffi'").name.values
motor_cats = hcp_categories.query("category == 'motor'").name.values
lang_cats = hcp_categories.query("category == 'language'").name.values

In [None]:
ppos = prepro_options[0]
use_cols = lang_cats

# NOTE(review): `hcp_phenotype_columns` is not assigned in any visible cell; it has to
# exist in the kernel already.  The reference sample is cleaned with all of those
# columns, while each half below is cleaned with `use_cols` only.  With impute=False
# the two can therefore drop different participants -- confirm this is intended.
ppo_phenotypes = clean_phenotype_data(all_dat,
                                      hcp_phenotype_columns,
                                      **ppos,
                                     ).loc[:, use_cols]
oU, os, oV = svd(ppo_phenotypes, full_matrices=False)

# Inspect a single split only: take the first one directly instead of looping and
# breaking after the first pass.
pids1, pids2 = split_participants[0]
sp1dat = clean_phenotype_data(all_dat.loc[all_dat.participant.isin(pids1)].copy(),
                              use_cols,
                              **ppos).loc[:, use_cols]
sp2dat = clean_phenotype_data(all_dat.loc[all_dat.participant.isin(pids2)].copy(),
                              use_cols,
                              **ppos).loc[:, use_cols]

U1, s1, V1 = svd(sp1dat, full_matrices=False)
U2, s2, V2 = svd(sp2dat, full_matrices=False)

# match each half's components to the reference and flip signs to agree with it
component_mapping1 = stable_component_marriage(V1, oV)
mapped_V1 = V1[component_mapping1]
V1_signs = np.sign([np.corrcoef(oVc, mVc)[0, 1] for oVc, mVc in zip(oV, mapped_V1)]).reshape(-1, 1)
mapped_V1 *= V1_signs
component_mapping2 = stable_component_marriage(V2, oV)
mapped_V2 = V2[component_mapping2]
V2_signs = np.sign([np.corrcoef(oVc, mVc)[0, 1] for oVc, mVc in zip(oV, mapped_V2)]).reshape(-1, 1)
mapped_V2 *= V2_signs

# loading correlation between the two halves, component by component
mc = [np.corrcoef(v1c, v2c)[0, 1] for v1c, v2c in zip(mapped_V1, mapped_V2)]

# proportion of variance per component, both halves in the matched order
mapped_s1 = s1[component_mapping1]
pct_var1 = mapped_s1 ** 2 / (mapped_s1 ** 2).sum()
mapped_s2 = s2[component_mapping2]
pct_var2 = mapped_s2 ** 2 / (mapped_s2 ** 2).sum()

In [None]:
# loading correlations between the two aligned halves from the single-split analysis above
sns.heatmap((1 - cdist(mapped_V1, mapped_V2, metric='correlation')), vmin=-1, vmax=1, cmap='RdBu')

In [None]:
# Per-component inter-split correlations for each phenotype subset.
# NOTE(review): task_mc, emo_mc, asr_mc, ffi_mc and motor_mc are not assigned in any
# visible cell (their copies are commented out below), so this cell only runs if they
# survive from earlier runs of the analysis cell with other `use_cols` values.
#task_mc = mc.copy()
#sorted_mc = mc.copy()
#emo_mc = mc.copy()
lang_mc = mc.copy()
fig, ax = plt.subplots(1)
ax.plot(task_mc, label="task")
#ax.plot(sorted_mc)
ax.plot(emo_mc, label="emotion")
ax.plot(asr_mc, label="asr")
ax.plot(ffi_mc, label="ffi")
ax.plot(motor_mc, label="motor")
ax.plot(lang_mc, label="language")
ax.legend()
ax.set_ylabel("Inter-split loading correlation")
ax.set_xlabel("Component number")

In [None]:
# emo
plt.plot(mc)

In [None]:
# Aligned loadings, one row per matched component, with that component's inter-split
# correlation ('mc') and, for split 1, its singular value ('s') appended as columns.
# NOTE(review): columns=hcp_phenotype_columns assumes mapped_V1 has one column per entry
# of hcp_phenotype_columns, i.e. that use_cols covered all of them -- confirm.
mv1df = pd.DataFrame(data=mapped_V1, columns=hcp_phenotype_columns)
mv1df['mc'] = mc
mv1df['s'] = mapped_s1

mv2df = pd.DataFrame(data=mapped_V2, columns=hcp_phenotype_columns)
mv2df['mc'] = mc

In [None]:
# percent of variance, among components with inter-split r > 0.5, carried by each one
mv1df.loc[mv1df.mc > 0.5].s ** 2 / (mv1df.loc[mv1df.mc > 0.5].s ** 2 ).sum() * 100

In [None]:
# Split-1 left vectors and singular values in the matched component order, with the
# singular values on the diagonal of a square matrix.  mapped_V1 was multiplied by
# V1_signs when it was aligned, so the same signs are applied to the columns of
# mapped_U1.  Without this, U @ S @ V reconstructions come out with the wrong sign
# for every flipped component.
mapped_U1 = U1[:, component_mapping1] * V1_signs.reshape(1, -1)
mapped_s1 = s1[component_mapping1]
mapped_S1 = np.zeros((len(s1), len(s1)))
np.fill_diagonal(mapped_S1, mapped_s1)

In [None]:
# NOTE(review): recon6 is not assigned in any visible cell; this relies on kernel state.
sns.heatmap(recon6)

In [None]:
# singular value at index 65 of s1 (plain SVD order, not the matched order)
s1[65]

In [None]:
# percent of variance stored at index 65 of pct_var1
pct_var1[65] * 100

In [None]:
# rank-1 reconstruction of split 1 from matched component 64 alone
recon64 = mapped_U1[:, [64]] @ mapped_S1[[64], :][:, [64]] @ mapped_V1[[64], :]
sns.heatmap(recon64)

In [None]:
# reconstructions from the first 64 and the first 65 matched components, and the
# residuals each leaves in split 1
recona = mapped_U1[:, :64] @ mapped_S1[:64, :][:, :64] @ mapped_V1[:64, :]
resida = sp1dat - recona
reconb = mapped_U1[:, :65] @ mapped_S1[:65, :][:, :65] @ mapped_V1[:65, :]
residb = sp1dat - reconb
sns.heatmap(resida)

In [None]:
sns.heatmap(residb)

In [None]:
# change in the residual from adding matched component 64
sns.heatmap(residb - resida)

In [None]:
# shape check for a reconstruction from matched components 0-3 and 6
mapped_U1[:, [0,1,2,3,6]].shape, mapped_S1[[0,1,2,3,6], :][:, [0,1,2,3,6]].shape, mapped_V1[[0,1,2,3,6], :].shape

In [None]:
# components whose loadings correlate above 0.5 between the halves
mv1df.loc[mv1df.mc > 0.5]

In [None]:
# loadings of matched component 67 in each half (67 is hard-coded)
plt.plot(mv1df.loc[67, hcp_phenotype_columns])
plt.plot(mv2df.loc[67, hcp_phenotype_columns])

In [None]:
# NOTE(review): the 'mc' and 's' columns are plotted alongside the loadings here,
# and 's' will fall far outside vmin/vmax.
sns.heatmap(mv1df.loc[mv1df.mc > 0.5], cmap='RdBu', vmin=-1, vmax=1)

In [None]:
# split-2 loadings for the same components (selected with split 1's mask; both
# frames carry the same 'mc' column)
sns.heatmap(mv2df.loc[mv1df.mc > 0.5], cmap='RdBu', vmin=-1, vmax=1)

In [None]:
# NOTE(review): this and the next two cells are identical.
plt.plot(mc)

In [None]:
plt.plot(mc)

In [None]:
plt.plot(mc)

In [None]:
# first latent whose lower-CI cumulative variance exceeds 0.95
(lci_cum_var > 0.95).nonzero()[0].min()

In [None]:
# NOTE(review): 34 is hard-coded, presumably the index printed above -- confirm.
mean_cum_var[34]

In [None]:
lci_cum_var[34]

In [None]:
V1s = []
V2s = []
pct_var1s = []
pct_var2s = []
match_corrs = []

# NOTE(review): `hcp_dat` is not assigned in any visible cell -- presumably the HCP
# counterpart of `pnc_dat`; confirm.  This cell rebinds `phenotypes` and the result
# lists above, which previously held the PNC analysis.
vals = hcp_dat[0]
phenotypes = pd.concat([vals['before_svd_split1'], vals['before_svd_split2']])
# Null model: shuffle every column independently, which breaks the correlations
# between measures.  Drawing from `rng` keeps the shuffle reproducible.
perm_phenotypes = phenotypes.apply(lambda x: x.sample(frac=1, replace=False, random_state=rng).values)
oU, os, oV = svd(perm_phenotypes, full_matrices=False)

for ix, vals in hcp_dat.items():
    sp1dat = perm_phenotypes.loc[perm_phenotypes.index.isin(vals['split_1_subject_id'].values)]
    sp2dat = perm_phenotypes.loc[perm_phenotypes.index.isin(vals['split_2_subject_id'].values)]

    # jV1 = vals['Vt_split1']
    # jV2 = vals['Vt_split2']
    U1, s1, V1 = svd(sp1dat, full_matrices=False)
    U2, s2, V2 = svd(sp2dat, full_matrices=False)

    # match each half's components to the reference and flip signs to agree with it
    component_mapping1 = stable_component_marriage(V1, oV)
    mapped_V1 = V1[component_mapping1]
    V1_signs = np.sign([np.corrcoef(oVc, mVc)[0, 1] for oVc, mVc in zip(oV, mapped_V1)]).reshape(-1, 1)
    mapped_V1 *= V1_signs
    component_mapping2 = stable_component_marriage(V2, oV)
    mapped_V2 = V2[component_mapping2]
    V2_signs = np.sign([np.corrcoef(oVc, mVc)[0, 1] for oVc, mVc in zip(oV, mapped_V2)]).reshape(-1, 1)
    mapped_V2 *= V2_signs

    # loading correlation between the two halves, component by component
    mc = [np.corrcoef(v1c, v2c)[0, 1] for v1c, v2c in zip(mapped_V1, mapped_V2)]
    match_corrs.append(mc)
    V1s.append(mapped_V1)
    V2s.append(mapped_V2)

    # proportion of variance per component, both halves in the matched order
    mapped_s1 = s1[component_mapping1]
    pct_var1 = mapped_s1 ** 2 / (mapped_s1 ** 2).sum()
    pct_var1s.append(pct_var1)
    mapped_s2 = s2[component_mapping2]
    pct_var2 = mapped_s2 ** 2 / (mapped_s2 ** 2).sum()
    pct_var2s.append(pct_var2)

In [None]:
mean_match_corr = np.tanh(np.mean(np.abs(np.arctanh(match_corrs)), 0))
mean_match_r2 = mean_match_corr ** 2
lci_match_corr = np.percentile(np.abs(match_corrs), 2.5, axis=0)
lci_match_r2 = lci_match_corr ** 2
uci_match_corr= np.percentile(np.abs(match_corrs), 97.5, axis=0)
uci_match_r2 = uci_match_corr ** 2

In [None]:
for ix, vals in hcp_dat.items():
    assert (vals['split_1_subject_id'] == vals['split_2_subject_id']).all()

In [None]:
vals['split_1_subject_id']

In [None]:
vals['split_2_subject_id']

In [None]:
fig, ax = plt.subplots(1)
ax.fill_between(range(len(mean_match_corr)), lci_match_corr, uci_match_corr, alpha = 0.5)
ax.plot(mean_match_corr)
ax.set_ylim(0,1)
ax.set_ylabel('Inter-split correlation')
ax.set_xlabel('Latent number')

In [None]:
pct_vars = np.vstack([pct_var1s, pct_var2s])

In [None]:
pct_var_df = pd.DataFrame(pct_vars)

In [None]:
sns.barplot(data=pct_var_df, 

In [None]:
plt.plot(np.array(pct_vars).mean(0))


In [None]:
# Cross-check scipy's SVD of split 1 against sklearn's TruncatedSVD with one
# component per column of sp1dat.
# NOTE(review): jV1 is only assigned in a commented-out line of the permutation loop,
# so it must come from earlier kernel state.  Exact `==` between floating-point,
# sign-ambiguous loadings is unlikely to be informative; the correlation heatmaps
# below are the meaningful comparison.
svder = TruncatedSVD(n_components=sp1dat.shape[-1])
_  = svder.fit(sp1dat)
svder.components_ == jV1

In [None]:
sns.heatmap(svder.components_)

In [None]:
sns.heatmap(jV1)

In [None]:
sns.heatmap(V1)

In [None]:
sns.heatmap((1 - cdist(svder.components_, jV1, metric='correlation')), vmin=-1, vmax=1, cmap='RdBu')

In [None]:
sns.heatmap((1 - cdist(V1, jV1, metric='correlation')), vmin=-1, vmax=1, cmap='RdBu')

In [None]:
# NOTE(review): V_signs is not assigned in any visible cell (the loops define
# V1_signs / V2_signs), and mapped_V2 is already sign-aligned.
sns.heatmap((1 - cdist(V1, mapped_V2 * V_signs, metric='correlation')), vmin=-1, vmax=1, cmap='RdBu')

In [None]:
U, s, V = svd(phenotypes, full_matrices=False)
S = np.zeros((s.shape[0], s.shape[0]))
np.fill_diagonal(S, s)

In [None]:
perm_phenotypes = phenotypes.apply(lambda x: x.sample(frac=1, replace=False).values)

In [None]:
perm_phenotypes = 

In [None]:
U, s, V = svd(phenotypes, full_matrices=False)
S = np.zeros((s.shape[0], s.shape[0]))
np.fill_diagonal(S, s)
pU, ps, pV = svd(perm_phenotypes, full_matrices=False)
pS = np.zeros((ps.shape[0], ps.shape[0]))
np.fill_diagonal(pS, ps)


# dropped_phenotypes = phenotypes.copy()
# for var in dropped_phenotypes:
#     varval = dropped_phenotypes[var]
#     lr = LinearRegression(fit_intercept=False)
#     _ = lr.fit(U[:, [cx]], varval)
#     dropped_varval = varval - lr.predict(U[:, [cx]])
#     dropped_phenotypes[var] = dropped_varval

In [None]:
# proportion of variance per component of the permuted data
ps **2 / (ps **2).sum()

In [None]:
# NOTE(review): dropped_phenotypes is only built by commented-out code in an earlier
# cell, so this fails on a fresh kernel.
_, _, dropped_V = svd(dropped_phenotypes, full_matrices=False)


In [None]:
# keep only component cx's singular value
# NOTE(review): cx is first assigned further down (cx = 0), so on a fresh kernel
# this cell and the two after it raise NameError.
only_S = np.zeros((s.shape[0], s.shape[0]))
only_S[cx, cx] = s[cx]

In [None]:
# rank-1 reconstruction from component cx alone
only_recon = U @ only_S @ V

In [None]:
# |loadings| of component cx, scaled so the largest is 1
plt.plot(np.abs(V[cx]) / np.abs(V[cx]).max())

In [None]:
sns.heatmap(only_recon, cmap='RdBu', vmin=-3, vmax=3)

In [None]:
sns.heatmap(phenotypes, cmap='RdBu', vmin=-3, vmax=3)

In [None]:
# Zero out component cx in U, S and V, reconstruct, add Gaussian noise (sd 0.1)
# and decompose the result again.
cx = 0 
# recon without component
dropped_V = V.copy()
dropped_V[cx] = 0
dropped_U = U.copy()
dropped_U[:,cx] = 0
dropped_S = S.copy()
dropped_S[cx, cx] = 0
dropped_recon = dropped_U @ dropped_S @ dropped_V
noise = rng.normal(loc=0, scale=0.1, size=dropped_recon.shape)
dropped_recon = dropped_recon + noise
dd_U, dd_s, dd_V  = svd(dropped_recon, full_matrices=False)


In [None]:
# Same removal, but by deleting row/column cx instead of zeroing it (no noise).
dropped_V = V.copy()
dropped_V = np.vstack([dropped_V[:cx, :], dropped_V[(cx + 1):, :]])
dropped_U = U.copy()
dropped_U = np.hstack([dropped_U[:, :cx], dropped_U[:, (cx + 1):]])
dropped_S = np.zeros([s.shape[0] - 1, s.shape[0] - 1])
dropped_s = np.hstack([s[:cx], s[(cx+1):]])
np.fill_diagonal(dropped_S, dropped_s)
dropped_recon = dropped_U @ dropped_S @ dropped_V
dd_U, dd_s, dd_V = svd(dropped_recon, full_matrices=False)

In [None]:
# total squared error of the reduced reconstruction
((phenotypes - dropped_recon)**2).sum().sum()

In [None]:
# proportion of variance per component of the reduced reconstruction
dd_s **2 / (dd_s **2).sum()

In [None]:
dropped_V.shape

In [None]:
# correlation of component 0's loadings with each remaining component
# (with cx = 0, component 0 is the one that was removed)
null_corrs = 1 - cdist(V[[0]], dropped_V, metric='correlation')

In [None]:
null_corrs

In [None]:
# same comparison against the full V, which includes component 0 itself
1 - cdist(V[[0]], V, metric='correlation')

In [None]:
np.abs(null_corrs).max()

In [None]:
# NOTE(review): duplicate of the previous cell
np.abs(null_corrs).max()

In [None]:
sns.heatmap(phenotypes - dropped_recon)

In [None]:
sns.heatmap(phenotypes)

In [None]:
V[0]

In [None]:
dropped_V[-1]

In [None]:
# loading correlations between the two halves' Vt for the current `vals` entry
sns.heatmap((1 - cdist(vals['Vt_split1'], vals['Vt_split2'], metric='correlation')), vmin=-1, vmax=1, cmap='RdBu')



In [None]:
m1, m2, disparity = procrustes(vals['Vt_split1'], vals['Vt_split2'])

In [None]:
foo = vals['Vt_split1'].copy()
foo[0:40] = vals['Vt_split2'][0:40]

In [None]:
R, scale = orthogonal_procrustes(vals['Vt_split2'].T, vals['Vt_split1'].T)

In [None]:
m2 = vals['Vt_split2'] @ R

In [None]:
plt.plot(R[0])

In [None]:
sns.heatmap(R, vmin=-1, vmax=1, cmap='RdBu')

In [None]:
sns.heatmap((1 - cdist(m1, m2, metric='correlation')), vmin=-1, vmax=1, cmap='RdBu')

In [None]:
disparity

In [None]:
plt.scatter(vals['Vt_split1'][0], vals['Vt_split2'][0])

In [None]:
plt.scatter(vals['Vt_split1'][:, 0], vals['Vt_split2'][:, 0])