In [None]:
#IMPORTS

from pathlib import Path

import nibabel as nib
import torch
from monai.networks.nets import UNet

from functions import *  # provides get_patches_single_img, reconstruct_from_patches (used below)

In [None]:
# SPECIFY PARAMETERS

# data: patch extraction grid
patch_size = (32, 32, 32)
stride = (16, 16, 16)
target_shape = (192, 224, 192)

# network: must match the architecture of the saved weights loaded later
spatial_dims = 3
in_channels = 2  # T1w patch + low-resolution T2w patch, stacked as channels
out_channels = 1
channels = (32, 64, 128, 256, 512, 1024)
net_strides = (2, 2, 2, 2, 2)
res_units = 10
norm = None

# MONAI's UNet expects one stride per down-sampling step, i.e. len(channels) - 1 entries.
assert len(net_strides) == len(channels) - 1, "net_strides must have len(channels) - 1 entries"

# Path to this repo (holds weights/ and demo_images/) -- edit for your machine.
# Stored as a Path, not a str, so later cells can join sub-paths with `/`.
DATA_DIR = Path("/Users/al1612le/mri-sr-bob")

In [None]:
# LOAD DATA

# Paths to the input images -- replace the placeholders before running.
t1 = "your_file_path/T1w.nii.gz"  # isotropic T1w image
t2_lr = "your_file_path/T2w_LR.nii.gz"  # anisotropic (low-resolution) T2w image

# Ensure both images have the shape and voxel size the pipeline expects.
# Each file is loaded once and checked individually so a failure names the culprit.
EXPECTED_SHAPE = (182, 218, 182)
EXPECTED_ZOOMS = (1.0, 1.0, 1.0)
for img_path in (t1, t2_lr):
    img = nib.load(img_path)
    assert img.shape == EXPECTED_SHAPE, (
        f"{img_path}: expected shape {EXPECTED_SHAPE}, got {img.shape}"
    )
    assert img.header.get_zooms() == EXPECTED_ZOOMS, (
        f"{img_path}: expected voxel size {EXPECTED_ZOOMS}, got {img.header.get_zooms()}"
    )

In [None]:
# EXTRACT PATCHES

# get_patches_single_img returns a (patches, affine) pair, so unpack it for both images.
# Only the T2w affine is kept; it is used when saving the reconstructed image.
t1_patches, _t1_affine = get_patches_single_img(t1, patch_size, stride, target_shape)
t2_lr_patches, affine = get_patches_single_img(t2_lr, patch_size, stride, target_shape)

# Same input shape and same patching parameters, so the two patch lists must line up 1:1.
assert len(t1_patches) == len(t2_lr_patches), (
    f"patch count mismatch: {len(t1_patches)} T1w vs {len(t2_lr_patches)} T2w"
)

In [None]:
# LOAD NETWORK

# Build the UNet with keyword arguments. In MONAI's UNet the 6th/7th positional
# parameters are kernel_size/up_kernel_size, so passing res_units and norm
# positionally would silently misconfigure the network.
net = UNet(
    spatial_dims=spatial_dims,
    in_channels=in_channels,
    out_channels=out_channels,
    channels=channels,
    strides=net_strides,
    num_res_units=res_units,
    norm=norm,
)

# load_state_dict() returns a report of missing/unexpected keys, not the module,
# so it is called separately instead of being chained onto UNet(...).
weights_path = Path(DATA_DIR) / "weights" / "2025-12-10T15:25:46.850860_model_weights.pth"
# The file is a state dict, so weights_only=True avoids unpickling arbitrary objects.
state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
net.load_state_dict(state_dict)

In [None]:
# TEST NETWORK

# Run the network on every patch pair. Each input stacks the T1w and LR T2w
# patches as two channels and adds a leading batch dimension of 1.
assert len(t1_patches) == len(t2_lr_patches), "T1w and T2w patch counts differ"
all_outputs = []
net.eval()
with torch.no_grad():
    # Iterate over ALL patches (the previous range(len(...) - 1) dropped the last one).
    for t1_patch, t2_lr_patch in zip(t1_patches, t2_lr_patches):
        input1 = torch.tensor(t1_patch).float()
        input2 = torch.tensor(t2_lr_patch).float()
        inputs = torch.stack([input1, input2], dim=0).unsqueeze(0)
        output = net(inputs)
        # Remove the size-1 batch dim and (with out_channels == 1) the channel dim.
        all_outputs.append(output.squeeze(0).squeeze(0).cpu().numpy())

# Stitch the per-patch predictions back onto the target_shape grid.
t2_reconstructed = reconstruct_from_patches(all_outputs, target_shape, stride)

In [None]:
# SAVE RECONSTRUCTED IMAGE

# Path(...) makes the `/` joins work whether DATA_DIR is configured as a str or a Path;
# the output folder is created if it does not exist yet.
out_path = Path(DATA_DIR) / "demo_images" / "demo_reconstructed_t2.nii.gz"
out_path.parent.mkdir(parents=True, exist_ok=True)
# NOTE(review): the volume lives on the target_shape grid and uses the affine returned
# for the T2w input -- confirm get_patches_single_img adjusts that affine for the padding/crop.
nib.save(nib.Nifti1Image(t2_reconstructed, affine), out_path)