### Load the model

In [1]:
import warnings
from collections import OrderedDict

import torch
from Custom_CNN import Simple3DRegressionCNN

# Inference device. Defined before loading so torch.load can remap tensors onto it:
# without map_location, a checkpoint saved from a GPU session cannot be loaded on CPU.
device = torch.device('cpu')

# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
# Assumes the checkpoint is a bare state_dict (it is iterated with .items() below) — confirm.
checkpoint = torch.load('models/best_model.pt', map_location=device)
model = Simple3DRegressionCNN()

# The checkpoint uses flat layer names (conv1, fc1, ...); they are renamed to the
# indexed sub-module names (group1.0, fc1.0, ...) that are loaded into the model below.
key_mapping = {
    "conv1.weight": "group1.0.weight",
    "conv1.bias": "group1.0.bias",
    "conv2.weight": "group2.0.weight",
    "conv2.bias": "group2.0.bias",
    "fc1.weight": "fc1.0.weight",
    "fc1.bias": "fc1.0.bias",
    "fc2.weight": "fc2.0.weight",
    "fc2.bias": "fc2.0.bias",
    "fc3.weight": "fc.0.weight",
    "fc3.bias": "fc.0.bias"
}

# Create a new state dictionary with the updated keys; keys not listed in
# key_mapping are passed through unchanged, in their original order.
new_state_dict = OrderedDict(
    (key_mapping.get(k, k), v) for k, v in checkpoint.items()
)

# Load the new state dictionary into the model. strict=False tolerates key mismatches,
# but any missing key means that parameter keeps its random initialisation, so report
# mismatches instead of silently ignoring them.
load_result = model.load_state_dict(new_state_dict, strict=False)
if load_result.missing_keys:
    warnings.warn(
        f'Model parameters not found in checkpoint (left at initial values): {load_result.missing_keys}'
    )
if load_result.unexpected_keys:
    warnings.warn(f'Checkpoint keys not used by the model: {load_result.unexpected_keys}')

model.eval()  # Set the model to evaluation mode
model = model.to(device)

In [2]:
import nibabel as nib
import numpy as np
from skimage.measure import block_reduce

def load_nifti(img_path):
    """Read a NIfTI volume and return it as a float32 tensor for the model.

    The voxel data is cast to float32, downsampled by taking the mean of each
    non-overlapping 2x2x2 block, and given two leading singleton axes.

    Args:
        img_path: path to the image file, passed directly to ``nib.load``.

    Returns:
        A ``torch.Tensor`` of shape ``(1, 1, *downsampled_shape)`` that shares
        memory with the downsampled NumPy array.
    """
    volume = nib.load(img_path).get_fdata().astype(np.float32)

    # NOTE(review): assumes a 3-D volume (the block size has three entries). When a
    # dimension is odd, block_reduce pads with cval (0 by default), which lowers the
    # edge-block means — presumably the same as training preprocessing; confirm.
    downsampled = block_reduce(volume, (2, 2, 2), np.mean)

    # Prepend two singleton axes (presumably batch and channel for the 3-D CNN).
    batched = downsampled.reshape((1, 1, *downsampled.shape))

    return torch.from_numpy(batched)

In [4]:
def predict(model, img_tensor):
    """Run a single forward pass and return the model's scalar prediction.

    Args:
        model: a ``torch.nn.Module`` that must produce exactly one output value
            for ``img_tensor`` (``Tensor.item()`` raises otherwise). Its train/eval
            mode is left untouched.
        img_tensor: input tensor for ``model``. It is first moved to the device
            of the model's parameters, so a CPU tensor (e.g. from ``load_nifti``)
            also works when the model has been moved to a GPU.

    Returns:
        The single output value as a Python number.
    """
    # Follow the model's device; a model without parameters leaves the tensor where it is.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        img_tensor = img_tensor.to(first_param.device)

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model(img_tensor)
    return outputs.item()

# Root of the anatomical pipeline outputs (directory name suggests dHCP release 3).
data_root = '../tor/rel3_dhcp_anat_pipeline'

# Subject / session to run the model on.
sub = 'sub-CC00113XX06'
ses = 'ses-37200'

# T2-weighted image for this subject/session.
img_path = f'{data_root}/{sub}/{ses}/anat/{sub}_{ses}_T2w.nii.gz'
img_tensor = load_nifti(img_path)

label = predict(model, img_tensor)
print(f'Predicted label: {label}')

Predicted label: 69.39547729492188
