# Analyze Model Weights

In [1]:
import pandas as pd
import numpy as np

%load_ext autoreload
%autoreload 2

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
# Notebook-wide plotting theme: seaborn "paper" context on a white grid with fonts
# scaled by 1.2; the rc overrides make every new figure 20x20 inches at 300 dpi.
sns.set_theme(
    style="whitegrid",
    context="paper",
    font_scale=1.2,
    rc={'figure.dpi': 300, 'figure.figsize': (20, 20)},
)

## Load Weights

In [18]:
from common.data import get_model_weight

# Selection shared by both model families.
population = 'adhd'
measure = 'WISC_PSI'
age_group = 'all'

# Load PLS first, then ridge, for the same (population, measure, age_group).
pls_weights, ridge_weights = (
    get_model_weight(model_name, population, measure, age_group)
    for model_name in ('pls', 'ridge')
)

print(pls_weights.shape, ridge_weights.shape)

(34716,) (34716,)


In [9]:
from os.path import join
from common.paths import RIDGE_WEIGHTS

# Two saved ridge weight vectors (s1, s2) for the same target and age group,
# read from RIDGE_WEIGHTS. Both populations and both splits are explicit knobs here
# rather than being baked into the filename strings.
weight_pop = 'adhd'      # population of the first weight set (s1)
ref_pop = 'healthy'      # population of the second weight set (s2)
weight_tar = 'WISC_FSIQ'
weight_age = 'all'
# NOTE(review): s1 and s2 come from different splits (set_2 vs set_1) — confirm this is intended.
# NOTE(review): this target (WISC_FSIQ) differs from `measure` in the "Load Weights" cell.
s1_set, s2_set = 2, 1


def ridge_weight_file(pop: str, target: str, age: str, split: int) -> str:
    """Return the filename (relative to RIDGE_WEIGHTS) of one saved ridge weight vector.

    Args:
        pop: population label embedded in the filename.
        target: target measure label, e.g. 'WISC_FSIQ'.
        age: age-group label, e.g. 'all'.
        split: number used in the `set_<split>` suffix.
    """
    return f'ridge_{pop}_{target}_{age}_set_{split}.npy'


s1_f = ridge_weight_file(weight_pop, weight_tar, weight_age, s1_set)
s2_f = ridge_weight_file(ref_pop, weight_tar, weight_age, s2_set)

s1 = np.load(join(RIDGE_WEIGHTS, s1_f))
s2 = np.load(join(RIDGE_WEIGHTS, s2_f))

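## Intraclass Correlation (ICC)

Agreement between the two weight vectors stacked in `coefs`. Each connection is a target, and each weight set is a rater.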

In [10]:
%%time
import pingouin as pg

# Agreement between two weight vectors over the same set of connections.
# To compare the PLS and ridge models from the "Load Weights" cell instead,
# stack `pls_weights` and `ridge_weights` here in place of `s1` and `s2`.
assert s1.ndim == 1 and s1.shape == s2.shape, (
    f'weight vectors must be 1-D with matching shapes, got {s1.shape} and {s2.shape}'
)
coefs = np.array([s1, s2])  # row 0 = s1, row 1 = s2; one column per connection

# Long format: one row per (weight set, connection). `ignore_index=False` keeps the
# original row label (0 or 1), which is copied into a column to act as the rater id.
# NOTE(review): despite its name, `cv_run_num` identifies the weight set (s1 / s2) here,
# not a cross-validation run.
icc_data = pd.DataFrame(coefs).melt(var_name='connection', value_name='weight', ignore_index=False)
icc_data['cv_run_num'] = icc_data.index
icc = pg.intraclass_corr(
    data=icc_data, targets='connection', raters='cv_run_num', ratings='weight'
).round(3)
icc.set_index("Type")

CPU times: user 29.7 s, sys: 196 ms, total: 29.9 s
Wall time: 29.9 s


Type,Description,ICC,F,df1,df2,pval,CI95%
ICC1,Single raters absolute,0.121,1.276,34715,34716,0.0,"[0.11, 0.13]"
ICC2,Single random raters,0.121,1.277,34715,34715,0.0,"[0.11, 0.13]"
ICC3,Single fixed raters,0.122,1.277,34715,34715,0.0,"[0.11, 0.13]"
ICC1k,Average raters absolute,0.216,1.276,34715,34716,0.0,"[0.2, 0.23]"
ICC2k,Average random raters,0.217,1.277,34715,34715,0.0,"[0.2, 0.23]"
ICC3k,Average fixed raters,0.217,1.277,34715,34715,0.0,"[0.2, 0.23]"
