In [None]:
#IMPORTS

# The star import stays first so that the explicit imports below take
# precedence over any same-named objects it re-exports.
from scripts.func.functions import *

from pathlib import Path

import nibabel as nib
import torch
from huggingface_hub import hf_hub_download
from monai.networks.nets import UNet



In [None]:
# ---EDIT THIS PART---

# Paths to the two input images. By default the demo images are downloaded from
# the Hugging Face Hub; replace these with paths to your own NIfTI files to run
# on other data. Note that the sanity-check cell below requires both images to
# have shape (182, 218, 182) and 1 mm voxels.
t1 = hf_hub_download(repo_id="almalennartsson/baby-ai", filename="demo_t1w_input.nii.gz") # isotropic t1w image
t2_lr = hf_hub_download(repo_id="almalennartsson/baby-ai", filename="demo_t2w_input.nii.gz") # anisotropic t2w image

# Path to the repository root. The `...` below is a placeholder and MUST be
# replaced before the final cell runs: the reconstructed image is written to
# <REPO_ROOT>/results/demo.nii.gz, built with the `/` operator, so a
# pathlib.Path object is the appropriate value here.
REPO_ROOT = ...

In [None]:
# SPECIFY PARAMETERS

# --- data ---
# Patch geometry handed to the patch-extraction and reconstruction helpers.
# The stride is half the patch size along every axis.
patch_size = (32, 32, 32)
stride = (16, 16, 16)
target_shape = (192, 224, 192)

# --- network ---
# Keyword arguments forwarded unchanged to monai.networks.nets.UNet.
# in_channels is 2 because the T1w and T2w patches are stacked as channels.
spatial_dims = 3
in_channels = 2
out_channels = 1
channels = (32, 64, 128, 256, 512, 1024)
strides = (2, 2, 2, 2, 2)
num_res_units = 10
norm = None  # presumably matches how the pretrained weights were trained - confirm before changing

# --- model weights ---
# Pretrained weights are fetched from the Hugging Face Hub; the call returns
# the local path of the downloaded file.
model_weights = hf_hub_download(
    repo_id="almalennartsson/baby-ai",
    filename="pretrained_weights.pth",
)

In [None]:
# Ensure both inputs have the shape and voxel size this pipeline expects.
# Each image is loaded once (nib.load only reads the header eagerly), the
# failure message names the offending file, and voxel sizes are compared with
# a small tolerance because they are stored as floating-point header values.
EXPECTED_SHAPE = (182, 218, 182)
EXPECTED_ZOOMS = (1.0, 1.0, 1.0)
ZOOM_TOLERANCE = 1e-3  # mm

for input_name, input_path in (("t1", t1), ("t2_lr", t2_lr)):
    input_img = nib.load(input_path)
    input_zooms = tuple(float(z) for z in input_img.header.get_zooms())

    assert input_img.shape == EXPECTED_SHAPE, (
        f"{input_name} ({input_path}) has shape {input_img.shape}, "
        f"expected {EXPECTED_SHAPE}"
    )
    assert len(input_zooms) == len(EXPECTED_ZOOMS) and all(
        abs(actual - expected) <= ZOOM_TOLERANCE
        for actual, expected in zip(input_zooms, EXPECTED_ZOOMS)
    ), (
        f"{input_name} ({input_path}) has voxel size {input_zooms}, "
        f"expected {EXPECTED_ZOOMS}"
    )

In [29]:
# EXTRACT PATCHES
# Both images are cut into patches with identical patch_size / stride /
# target_shape so that patch i of the T1w image pairs with patch i of the T2w
# image in the inference loop.
# NOTE(review): get_patches_single_img comes from scripts.func.functions;
# presumably it loads the file, brings it to target_shape and returns
# (patches, affine) - confirm against the helper.
# `affine` is assigned by both calls, so the value kept (and later used when
# saving the result) is the one returned for the T2w image.

t1_patches, affine = get_patches_single_img(t1, patch_size, stride, target_shape)
t2_lr_patches, affine = get_patches_single_img(t2_lr, patch_size, stride, target_shape)

In [None]:
# LOAD NETWORK
# Build the U-Net from the parameters defined above and load the pretrained
# weights onto the CPU.

net = UNet(
    spatial_dims=spatial_dims,
    in_channels=in_channels,
    out_channels=out_channels,
    channels=channels,
    strides=strides,
    num_res_units=num_res_units,
    norm=norm,
)

# The checkpoint is a downloaded file, and torch.load unpickles it.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs, so a tampered file cannot execute arbitrary code.
state_dict = torch.load(model_weights, map_location="cpu", weights_only=True)

# load_state_dict is strict by default: it raises if any key is missing or
# unexpected, i.e. if the architecture above does not match the checkpoint.
net.load_state_dict(state_dict)

In [None]:
# TEST NETWORK
# Run the network patch by patch and stitch the predictions back together.

net.eval()  # inference behaviour for layers such as dropout
all_outputs = []
with torch.no_grad():  # no gradients are needed for inference
    for patch_idx, t1_patch in enumerate(t1_patches):
        t2_patch = t2_lr_patches[patch_idx]

        # Stack the two modalities along a new leading channel axis, then add
        # a batch axis of size 1 in front of it.
        stacked = torch.stack(
            [torch.tensor(t1_patch).float(), torch.tensor(t2_patch).float()],
            dim=0,
        )
        batch = stacked.unsqueeze(0)

        prediction = net(batch)

        # Drop the batch and (single) output-channel axes and keep the
        # predicted patch as a NumPy array.
        all_outputs.append(prediction.squeeze(0).squeeze(0).cpu().numpy())
    t2_reconstructed = reconstruct_from_patches(all_outputs, target_shape, stride)

In [None]:
# SAVE RECONSTRUCTED IMAGE
# Writes the reconstruction to <REPO_ROOT>/results/demo.nii.gz.

# Fail with a clear message if the configuration cell was not filled in.
if REPO_ROOT is Ellipsis:
    raise ValueError(
        "REPO_ROOT is still the `...` placeholder; set it to the repository "
        "root in the configuration cell before saving."
    )

# Path(...) accepts both str and Path, so REPO_ROOT may be given either way.
output_path = Path(REPO_ROOT) / "results" / "demo.nii.gz"

# nib.save does not create missing directories; make sure results/ exists.
output_path.parent.mkdir(parents=True, exist_ok=True)

nib.save(nib.Nifti1Image(t2_reconstructed, affine), output_path)