In [1]:
import copy
import logging
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns

import pcntoolkit.util.output
from pcntoolkit import (
    HBR,
    BsplineBasisFunction,
    NormalLikelihood,
    NormativeModel,
    NormData,
    load_fcon1000,
    make_prior,
)

# Global plot styling for every figure drawn in this notebook.
sns.set_style("darkgrid")

# Keep cell output readable: ask the toolkit's Output helper not to show
# messages, turn off pandas' chained-assignment warning, and ignore
# FutureWarnings.
pcntoolkit.util.output.Output.set_show_messages(False)
pd.options.mode.chained_assignment = None  # pandas default is 'warn'
warnings.simplefilter("ignore", FutureWarning)

# Only let WARNING-and-above records through the "pymc" logger, and stop
# them from propagating to ancestor loggers.
pymc_logger = logging.getLogger("pymc")
pymc_logger.propagate = False
pymc_logger.setLevel(logging.WARNING)

In [2]:
# Download an example dataset
norm_data: NormData = load_fcon1000()

# Select only a few features
features_to_model = [
    "WM-hypointensities",
    "Right-Lateral-Ventricle",
    # "Right-Amygdala",
    # "CortexVol",
]
norm_data = norm_data.sel({"response_vars": features_to_model})

all_sites = np.unique(norm_data.batch_effects.sel(batch_effect_dims="site").values)

# Split the sites into three random groups: the first 7, the next 7, and
# whatever remains. A seeded Generator is used instead of the global
# `np.random` state so that the split (and therefore every model fitted on
# it) is identical after Restart Kernel -> Run All.
SPLIT_SEED = 42
rng = np.random.default_rng(SPLIT_SEED)
rng.shuffle(all_sites)  # in place, so `all_sites` holds the shuffled order
group1 = all_sites[:7]
group2 = all_sites[7:14]
group3 = all_sites[14:]
print(f"Group 1: {group1}")
print(f"Group 2: {group2}")
print(f"Group 3: {group3}")

# First separate group 1 from the rest, then split the rest into groups 2 and 3.
data_group1, data_group23 = norm_data.batch_effects_split({"site": group1}, names=("group1", "group23"))
data_group2, data_group3 = data_group23.batch_effects_split({"site": group2}, names=("group2", "group3"))

Group 1: ['Bangor' 'Atlanta' 'Baltimore' 'ICBM' 'Pittsburgh' 'AnnArbor_b'
 'Berlin_Margulies']
Group 2: ['Cleveland' 'NewYork_a' 'Munchen' 'Newark' 'Leiden_2200' 'Queensland'
 'Beijing_Zang']
Group 3: ['Milwaukee_b' 'Oxford' 'Cambridge_Buckner' 'Leiden_2180' 'PaloAlto'
 'AnnArbor_a' 'Oulu' 'NewYork_a_ADHD' 'SaintLouis']


In [3]:
# --- Mean (mu) ---------------------------------------------------------
# Linear in a cubic B-spline basis on column 0, with a fixed Normal slope
# and an intercept declared with random=True whose own mu/sigma are priors.
mu_slope = make_prior(dist_name="Normal", dist_params=(0.0, 10.0))
mu_intercept_mu = make_prior(dist_name="Normal", dist_params=(0.0, 1.0))
mu_intercept_sigma = make_prior(
    dist_name="Normal",
    dist_params=(0.0, 1.0),
    mapping="softplus",
    mapping_params=(0.0, 3.0),
)
mu_intercept = make_prior(random=True, mu=mu_intercept_mu, sigma=mu_intercept_sigma)
mu_basis = BsplineBasisFunction(basis_column=0, nknots=5, degree=3)
mu = make_prior(linear=True, slope=mu_slope, intercept=mu_intercept, basis_function=mu_basis)

# --- Scale (sigma) -----------------------------------------------------
# Also linear in its own B-spline basis, passed through a softplus mapping.
sigma_slope = make_prior(dist_name="Normal", dist_params=(0.0, 2.0))
sigma_intercept = make_prior(dist_name="Normal", dist_params=(1.0, 1.0))
sigma_basis = BsplineBasisFunction(basis_column=0, nknots=5, degree=3)
sigma = make_prior(
    linear=True,
    slope=sigma_slope,
    intercept=sigma_intercept,
    basis_function=sigma_basis,
    mapping="softplus",
    mapping_params=(0.0, 3.0),
)

likelihood = NormalLikelihood(mu, sigma)

# Sampler settings for the template regression model.
sampler_settings = dict(
    cores=16,
    progressbar=True,
    draws=1500,
    tune=500,
    chains=4,
    nuts_sampler="nutpie",
)
template_hbr = HBR(name="template", likelihood=likelihood, **sampler_settings)

# Normative-model wrapper around the template; this `model` object is the
# one later cells copy.
normative_settings = dict(
    savemodel=True,
    evaluate_model=True,
    saveresults=True,
    saveplots=True,
    save_dir="resources/hbr_normal/save_dir",
    inscaler="standardize",
    outscaler="standardize",
)
model = NormativeModel(template_regression_model=template_hbr, **normative_settings)

In [4]:
# Make three independent deep copies of the unfitted `model`, each with its
# own save directory.
submodels = []
for target_dir in (
    "resources/hbr_merge/model1",
    "resources/hbr_merge/model2",
    "resources/hbr_merge/model3",
):
    clone = copy.deepcopy(model)
    clone.save_dir = target_dir
    submodels.append(clone)
model1, model2, model3 = submodels

# Fit each copy on its own group of sites, in order.
for submodel, group_data in zip(submodels, (data_group1, data_group2, data_group3)):
    submodel.fit(group_data)

Progress,Draws,Divergences,Step Size,Gradients/Draw
,2000,0,0.11,127
,2000,0,0.11,191
,2000,0,0.11,191
,2000,0,0.11,127


Progress,Draws,Divergences,Step Size,Gradients/Draw
,2000,0,0.18,63
,2000,0,0.16,95
,2000,0,0.18,63
,2000,0,0.16,127


Progress,Draws,Divergences,Step Size,Gradients/Draw
,2000,0,0.14,127
,2000,0,0.14,95
,2000,0,0.15,127
,2000,0,0.16,63


Progress,Draws,Divergences,Step Size,Gradients/Draw
,2000,0,0.17,63
,2000,0,0.16,95
,2000,0,0.17,63
,2000,0,0.16,31


Progress,Draws,Divergences,Step Size,Gradients/Draw
,2000,0,0.14,63
,2000,0,0.14,95
,2000,0,0.14,127
,2000,0,0.14,63


Progress,Draws,Divergences,Step Size,Gradients/Draw
,2000,0,0.16,63
,2000,0,0.16,31
,2000,0,0.15,63
,2000,0,0.16,127


In [5]:
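# Optional: uncomment to reload the three models from their save directories
# instead of using the in-memory objects from the fitting cell.
# model1 = NormativeModel.load(path="resources/hbr_merge/model1")
# model2 = NormativeModel.load(path="resources/hbr_merge/model2")
# model3 = NormativeModel.load(path="resources/hbr_merge/model3")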

In [6]:
# The merge function accepts a list whose entries are models or paths; here
# the first model is given by its save path and the other two as objects.
merge_inputs = ["resources/hbr_merge/model1", model2, model3]
merged_model = NormativeModel.merge(
    models=merge_inputs,
    save_dir="resources/hbr_merge/merged_model",
)
# merged_model = NormativeModel.load(path="resources/hbr_merge/merged_model")

Progress,Draws,Divergences,Step Size,Gradients/Draw
,2000,0,0.12,31
,2000,0,0.12,287
,2000,0,0.12,63
,2000,0,0.11,127


Progress,Draws,Divergences,Step Size,Gradients/Draw
,2000,0,0.08,511
,2000,0,0.09,127
,2000,0,0.09,63
,2000,0,0.09,191


In [7]:
# Call the merged model's predict on each of the three site groups, then on
# the full dataset. None of the return values are captured here; the last
# call is the cell's final expression.
# NOTE(review): predict presumably also writes results/plots to the merged
# model's save_dir — confirm against the NormativeModel documentation.
merged_model.predict(data_group1)
merged_model.predict(data_group2)
merged_model.predict(data_group3)
merged_model.predict(norm_data)