In [1]:
from model.utils.etc import WeightRepository

In [2]:
from model.sfcn_reg import RegressionSFCN

MIN_AGE = 3
MAX_AGE = 95
model = RegressionSFCN(prediction_range=(MIN_AGE, MAX_AGE))

In [34]:
# Suppose you already have the weight path
WEIGHT_PATH = '../.pyment/pyment/data/sfcn-regbrain-age-2022.h5'
torch_weights = WeightRepository.convert_to_torch(WEIGHT_PATH, model.state_dict())

In [4]:
# Check translation
for key in torch_weights:
    print(key)

fn1.sfcn-reg_block1_conv.weight
fn1.sfcn-reg_block1_conv.bias
fn1.sfcn-reg_block1_norm.weight
fn1.sfcn-reg_block1_norm.bias
fn1.sfcn-reg_block1_norm.running_mean
fn1.sfcn-reg_block1_norm.running_var
fn1.sfcn-reg_block2_conv.weight
fn1.sfcn-reg_block2_conv.bias
fn1.sfcn-reg_block2_norm.weight
fn1.sfcn-reg_block2_norm.bias
fn1.sfcn-reg_block2_norm.running_mean
fn1.sfcn-reg_block2_norm.running_var
fn1.sfcn-reg_block3_conv.weight
fn1.sfcn-reg_block3_conv.bias
fn1.sfcn-reg_block3_norm.weight
fn1.sfcn-reg_block3_norm.bias
fn1.sfcn-reg_block3_norm.running_mean
fn1.sfcn-reg_block3_norm.running_var
fn1.sfcn-reg_block4_conv.weight
fn1.sfcn-reg_block4_conv.bias
fn1.sfcn-reg_block4_norm.weight
fn1.sfcn-reg_block4_norm.bias
fn1.sfcn-reg_block4_norm.running_mean
fn1.sfcn-reg_block4_norm.running_var
fn1.sfcn-reg_block5_conv.weight
fn1.sfcn-reg_block5_conv.bias
fn1.sfcn-reg_block5_norm.weight
fn1.sfcn-reg_block5_norm.bias
fn1.sfcn-reg_block5_norm.running_mean
fn1.sfcn-reg_block5_norm.running_var
fn1.s

In [5]:
# Load into the model
model.load_state_dict(torch_weights)

<All keys matched successfully>

In [15]:
# Prepare leonardsen model
from pyment.models.sfcn_reg import RegressionSFCN

original_model = RegressionSFCN(prediction_range=(MAX_AGE, MIN_AGE), name="sfcn_reg")
original_model.weight_name = "brain-age-2022"

In [7]:
tf_weights = WeightRepository.load_tf_weights(weights_path=WEIGHT_PATH)

In [26]:
original_model.set_weights(list(tf_weights.values()))

In [28]:
# Prepare data as original
import os
import numpy as np
from tqdm import tqdm
import nibabel as nib
import pandas as pd
import torch

LABELS_FILE = "../test_data/IXI/IXI.xls"
labels = pd.read_excel(LABELS_FILE)

IMAGE_FOLDER = "../test_data/IXI"
predictions_leo = []
predictions_torch = []

# Set torch model to eval mode
model.eval()

for subject in tqdm(os.listdir(IMAGE_FOLDER)):#
    if "IXI" in subject and os.path.isdir(os.path.join(IMAGE_FOLDER, subject)):
        path = os.path.join(IMAGE_FOLDER, subject, 'cropped.nii.gz')
        subjectid = int(subject[3:])
        if not os.path.isfile(path):
            print(f'Skipping {subject}: Missing cropped.nii.gz')
            continue
        elif subjectid not in labels['IXI_ID'].values:
            print(f'Skipping {subject}: Missing labels')
            continue

        age = labels.loc[labels['IXI_ID'] == subjectid, 'AGE'].values[0]
        
        img = nib.load(path)
        img = img.get_fdata()

        img = np.expand_dims(img, 0)

        with torch.no_grad():
            prediction: torch.Tensor = model(torch.from_numpy(img).unsqueeze(0).float())[0]
            prediction = prediction.squeeze(0).item()
            predictions_torch.append({
                'subject': subject,
                'age': age,
                'prediction': prediction
            })

        prediction: np.ndarray = original_model.predict(np.expand_dims(img, -1), verbose=0)[0]
        prediction = original_model.postprocess(prediction)
        predictions_leo.append({
                'subject': subject,
                'age': age,
                'prediction': prediction
            })


 14%|█▍        | 12/83 [00:18<01:46,  1.49s/it]

Skipping IXI081: Missing labels


 20%|██        | 17/83 [00:24<01:29,  1.36s/it]

Skipping IXI088: Missing labels


100%|██████████| 83/83 [02:05<00:00,  1.51s/it]


In [30]:
# Taking a random prediction...

predictions_leo[20]['prediction'], predictions_torch[20]['prediction']

(array(27.239164, dtype=float32), 27.23916244506836)

In [33]:
# Using original pyment comparision
import plotly.graph_objs as go
from plotly.subplots import make_subplots

df_pred_leo = pd.DataFrame(predictions_leo)
df_pred_torch = pd.DataFrame(predictions_torch)

fig = make_subplots(rows=1, cols=2)

fig.add_trace(
    go.Scatter(
        x=[MIN_AGE, MAX_AGE],
        y=[MIN_AGE, MAX_AGE],
        showlegend=False
    ),
    row = 1, col = 1
)

fig.add_trace(
    go.Scatter(
        x=df_pred_leo['prediction'],
        y=df_pred_leo['age'],
        mode='markers',
        showlegend=False
    ),
    row = 1, col = 1
)

fig.add_trace(
    go.Scatter(
        x=[MIN_AGE, MAX_AGE],
        y=[MIN_AGE, MAX_AGE],
        showlegend=False
    ),
    row=1, col=2
)

fig.add_trace(
    go.Scatter(
        x=df_pred_torch['prediction'],
        y=df_pred_torch['age'],
        mode='markers',
        showlegend=False
    ),
    row=1, col=2
)


mae_leo = np.mean(np.abs(df_pred_leo['prediction'] - df_pred_leo['age']))
mae_torch = np.mean(np.abs(df_pred_torch['prediction'] - df_pred_torch['age']))

fig.update_layout(
    title={'x': 0.5, 'text': f'{model.__class__.__name__} with weights={original_model.weight_name}'},
    width=1024,
    height=512,
    annotations=[
        {'x': 70, 'y': 15, 'text': f'MAE (Leo)={mae_leo:.2f}', 'showarrow': False, 'xref': 'x1', 'yref': 'y1'},
        {'x': 70, 'y': 15, 'text': f'MAE (Torch)={mae_torch:.2f}', 'showarrow': False, 'xref': 'x2', 'yref': 'y2'}
    ],
    xaxis1={'range': [MIN_AGE, MAX_AGE], 'title': 'Prediction'},
    yaxis1={'range': [MIN_AGE, MAX_AGE], 'title': 'Age'},
    xaxis2={'range': [MIN_AGE, MAX_AGE], 'title': 'Prediction'},
    yaxis2={'range': [MIN_AGE, MAX_AGE], 'title': 'Age'}
)

fig.show()