# Transfer
This notebook transfers a pre-fitted hierarchical Bayesian regression (HBR) normative model to a small dataset.

In [2]:
!pip install git+https://github.com/amarquand/PCNtoolkit.git@v1.alpha.6
!pip install numpy==1.26.4


In [3]:
import os
from pcntoolkit import (
    load_fcon1000,
    NormativeModel,
)
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image
from pcntoolkit.util.output import Output
from modelspec import shashb1  # model specification from the local modelspec module
Output.set_show_messages(False)  # silence PCNtoolkit's progress messages

## Why do we transfer?
First, we demonstrate why transfer is needed at all.

Suppose we want to build a normative model, but we only have access to a small dataset:

In [None]:
# Download the FCON1000 dataset and split off these two sites (the remaining sites are not used here)
transfer_sites = ["Milwaukee_b", "Oulu"]
transfer_data, _ = load_fcon1000(save_path="data").split_batch_effects(
    {"site": transfer_sites}, names=("transfer", "fit")
)
# Select only a few features
features_to_model = [
    "WM-hypointensities",
    "Right-Lateral-Ventricle",
    "Right-Amygdala",
    "CortexVol",
]
transfer_data = transfer_data.sel({"response_vars": features_to_model})
print(
    f"This dataset contains {len(transfer_data.observations)} samples of {len(transfer_data.response_vars)} response variables"
)

transfer_train, transfer_test = transfer_data.train_test_split()


### Inspecting the data
As the plot below shows, our data do not cover the entire age range: they are concentrated between ages 20 and 22 and between ages 44 and 67. A model fitted on this data has no observations between 22 and 44, so its predictions in that range are only a 'best guess' interpolated between the two clusters.
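To quantify the coverage problem rather than judging it by eye, a small helper can report the widest stretch of the covariate range that contains no observations. This is a minimal sketch in plain NumPy; `largest_gap` is a hypothetical helper, not part of PCNtoolkit. In the notebook you could pass it the age column of the dataframe built in the plotting cell (e.g. `df[("X", "age")]`, assuming the same `(group, name)` column tuples used there). The toy ages below are illustrative only, not the FCON1000 data.

```python
import numpy as np


def largest_gap(values):
    """Return (lower, upper) bounds of the widest empty interval between observed values.

    Hypothetical helper for illustration; not part of PCNtoolkit.
    """
    v = np.unique(np.asarray(values, dtype=float))  # sorted, duplicates removed
    if v.size < 2:
        raise ValueError("need at least two distinct values")
    gaps = np.diff(v)  # distance between consecutive observed values
    i = int(np.argmax(gaps))  # index of the widest gap
    return float(v[i]), float(v[i + 1])


# Toy ages (illustrative only, not the fcon1000 data)
ages = [20.5, 21.0, 21.7, 22.0, 44.0, 51.3, 67.0]
print(largest_gap(ages))  # (22.0, 44.0)
```

A wide gap like this is exactly where a model fitted only on this data has nothing to anchor its centiles.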

In [None]:
# Inspect the data
feature_to_plot = features_to_model[0]
df = transfer_data.to_dataframe()
fig, ax = plt.subplots(1, 2, figsize=(15, 5))

sns.countplot(
    data=df,
    y=("batch_effects", "site"),
    hue=("batch_effects", "sex"),
    ax=ax[0],
    orient="h",
    palette="Set1",
)
ax[0].legend(title="Sex")
ax[0].set_title("Samples per site")
# Horizontal countplot: counts run along x, sites along y
ax[0].set_xlabel("Count")
ax[0].set_ylabel("Site")


sns.scatterplot(
    data=df,
    x=("X", "age"),
    y=("Y", feature_to_plot),
    hue=("batch_effects", "site"),
    style=("batch_effects", "sex"),
    ax=ax[1],
    palette="Set2",
)
# ax[1].legend([], [])
ax[1].set_title(f"Scatter plot of age vs {feature_to_plot}")
ax[1].set_xlabel("Age")
ax[1].set_ylabel(feature_to_plot)

plt.show()

### Fitting a model to the small dataset
Let's fit a model to this dataset.

In [None]:
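new_model = shashb1  # model specification imported from the local modelspec module
new_model.set_save_dir("models/new_model")  # results and plots are written here
test = new_model.fit_predict(transfer_train, transfer_test)  # fit on train, predict on test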

### Inspecting the model output
Now open some of the centile plots in the `models/new_model/plots` directory. They do not follow a natural curve, especially between ages 22 and 44. This is what we expected given the gap in the data, and it is good to see it confirmed.


## Transferring a pre-fitted model to our small dataset

Now we will do it the right way. 

We will take a model that was previously fitted on a larger dataset (note: one collected at different sites) and transfer it to our small dataset. Transferring means using the original model as a starting point and adapting it slightly to the new data. As a result, most of the model's characteristics are retained, including its predictions in the 22 to 44 age range.
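The intuition behind "adapting slightly" can be illustrated with the simplest Bayesian update. This is a conceptual sketch, not PCNtoolkit's actual transfer procedure: it assumes a single site offset with a normal prior carried over from the original model, a normal likelihood, and a known noise variance. The prior pulls the estimate toward what the large dataset learned, while the few new observations shift it only as far as their combined precision justifies. `update_offset` is a hypothetical name.

```python
import numpy as np


def update_offset(prior_mean, prior_var, data, noise_var):
    """Normal-normal conjugate update of a single site offset.

    Conceptual illustration only; PCNtoolkit's HBR transfer is more involved.
    prior_mean, prior_var: belief about the offset carried over from the original model
    data: residuals observed at the new site
    noise_var: assumed known observation noise variance
    """
    data = np.asarray(data, dtype=float)
    n = data.size
    if n == 0:
        return prior_mean, prior_var  # no new data: keep the original model's belief
    prior_prec = 1.0 / prior_var  # precision of the prior
    data_prec = n / noise_var  # precision contributed by n observations
    post_var = 1.0 / (prior_prec + data_prec)
    post_mean = post_var * (prior_prec * prior_mean + data_prec * data.mean())
    return post_mean, post_var


# With only two observations the estimate moves toward the data but keeps the prior's pull
print(update_offset(prior_mean=0.0, prior_var=1.0, data=[2.0, 2.0], noise_var=1.0))
```

With two observations the posterior mean lands at 4/3, between the prior (0) and the data (2); with many observations it approaches the data mean, which is the behaviour a transferred model relies on.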

In [1]:
# Download model files
!curl -L -o model_to_transfer.zip https://raw.githubusercontent.com/AuguB/federated_learning_workshop/main/models/model_to_transfer.zip
!unzip -q model_to_transfer.zip -d models
!rm -rf models/__MACOSX
!rm model_to_transfer.zip
# Load the normative model
model_to_transfer = NormativeModel.load("models/model_to_transfer")
# Show the batch effects that this model was fitted on:
print("This model was trained on these batches:")
model_to_transfer.unique_batch_effects


In [None]:
# Sanity check: try to use the pre-fitted model to predict on the transfer data.
# This throws an error, because our transfer data only contains data from "Milwaukee_b" and "Oulu", which were not in the original training set (see list directly above this cell)
try:
    model_to_transfer.predict(transfer_test)
except Exception as e:
    print(f"This should throw an error!: {e}")
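The error comes from batch-effect levels that the pre-fitted model has never seen. One way to anticipate it is to compare the levels in the new data with those used for fitting. This is a plain-Python sketch; it assumes both are available as a mapping from batch-effect name to a list of levels (for instance, what `model_to_transfer.unique_batch_effects` printed above, if it is or can be converted to such a mapping). The helper `unseen_levels` and the site names `SiteA`/`SiteB` are hypothetical.

```python
def unseen_levels(fitted, new):
    """Return, per batch effect, the levels present in `new` but absent from `fitted`.

    Both arguments map a batch-effect name (e.g. "site") to an iterable of levels.
    Hypothetical helper for illustration; not part of PCNtoolkit.
    """
    missing = {}
    for name, levels in new.items():
        known = set(fitted.get(name, []))
        extra = sorted(set(levels) - known)
        if extra:
            missing[name] = extra
    return missing


# Illustrative levels, not the actual contents of the pre-fitted model
fitted = {"site": ["SiteA", "SiteB"], "sex": [0, 1]}
new = {"site": ["Milwaukee_b", "Oulu"], "sex": [0, 1]}
print(unseen_levels(fitted, new))  # {'site': ['Milwaukee_b', 'Oulu']}
```

Any non-empty result means a plain `predict` call cannot handle the data and the model has to be transferred first.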

In [None]:
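# Now we transfer the model: it is adapted to the new sites using transfer_train,
# then used to predict on transfer_test. Results are saved in save_dir.
transferred_model = model_to_transfer.transfer_predict(
    transfer_train, transfer_test, save_dir="models/transferred_model"
)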

### Inspecting the transferred model output

Let's compare the centiles from the new model and from the transferred model side by side.

In [None]:
plt.rcParams["axes.titlesize"] = 20  # set before any titles are created
fig, ax = plt.subplots(len(features_to_model), 2, figsize=(13, 20))
for i, f in enumerate(features_to_model):
    # Load the centile plots saved by fit_predict and transfer_predict for this feature
    centile_plot_new_model = np.asarray(
        Image.open(
            f"models/new_model/plots/centiles_{f}_transfer_test_harmonized.png"
        )
    )
    centile_plot_transferred_model = np.asarray(
        Image.open(
            f"models/transferred_model/plots/centiles_{f}_transfer_test_harmonized.png"
        )
    )

    # Left column: model fitted from scratch; right column: transferred model
    ax[i, 0].imshow(centile_plot_new_model)
    ax[i, 0].axis("off")
    ax[i, 0].set_title("New model")

    ax[i, 1].imshow(centile_plot_transferred_model)
    ax[i, 1].axis("off")
    ax[i, 1].set_title("Transferred model")

# Lay out the full grid once, after all panels are drawn
plt.tight_layout()
plt.show()
