### Load the model

In [8]:
import torch
import torch.nn as nn
import torchvision.models as models

checkpoint = torch.load('models/best_model.pt')
model = models.resnet18()
model.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(7, 7))
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=1)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval() # Set the model to evaluation mode

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [12]:
import nibabel as nib
import numpy as np
from skimage.measure import block_reduce

def load_nifti(img_path):
    img_nii = nib.load(img_path)
    img_np = img_nii.get_fdata()
    img_np = img_np.astype(np.float32)

    # Slice the 3D image along the depth dimension
    num_slices = img_np.shape[0]
    slices = []
    pool_size = 10 # For avg/max pooling to reduce the number of slices (GPU memory issue)
    for i in range(0, num_slices, pool_size):
        slice_3d = img_np[i:i+pool_size, :, :] # Get a 3D block
        slice_2d = block_reduce(slice_3d, (pool_size, 1, 1), np.mean) # Apply avg pooling
        slices.append(slice_2d)
        
    # Stack the slices along the batch dimension
    img_np_stacked = np.stack(slices, axis=0)

    # Convert to tensor
    img_tensor = torch.from_numpy(img_np_stacked)

    return img_tensor

In [13]:
def predict(model, img_tensor):
    with torch.no_grad():
        outputs = model(img_tensor)
        label = outputs.argmax().item()
    return label

img_path = '../tor/rel3_dhcp_anat_pipeline/sub-CC00251XX04/ses-83800/anat/sub-CC00251XX04_ses-83800_T2w.nii.gz' # Should be 19
img_tensor = load_nifti(img_path)
img_tensor = img_tensor.to(device)
label = predict(model, img_tensor)
print(f'Predicted label: {label}')

Predicted label: 32.233299255371094
