In [7]:
#TEST ONE IMAGE
import pathlib as p
from functions import *
import nibabel as nib
import numpy as np
import random
import torch
from monai.networks.nets import UNet
from monai.networks.layers.factories import Norm
import matplotlib.pyplot as plt

# Parameters
# Cubic patches. The UNet below has five stride-2 levels, so patch edges presumably
# need to be divisible by 2**5 = 32.
patch_size = (32,) * 3
# Half-patch stride: neighbouring patches overlap by 50%.
stride = (16,) * 3
# Volume size the patch grid is laid over (the inputs themselves are 182 x 218 x 182,
# as asserted further down; presumably get_patches_single_img pads them to this shape).
# With the patch/stride above this gives 11 x 13 x 11 = 1573 patch positions.
target_shape = (192, 224, 192)
SAVE_PATH = p.Path.home().joinpath("save_path")

#Load files
DATA_DIR = p.Path.home()/"data"/"bobsrepository"

t1_files = sorted(DATA_DIR.rglob("*T1w.nii.gz"))
if not t1_files:
    raise FileNotFoundError(f"No '*T1w.nii.gz' files found under {DATA_DIR}")

# Pair the T2 / T2-LR images with the first T1 by filename prefix rather than by list
# position: the LR images live in a separate directory tree (LR_data/...), so index 0
# of independently sorted globs is not guaranteed to be the same subject/session.
# The lists below only contain files of that subject/session, so [0] is a matched set.
subject_prefix = t1_files[0].name[:-len("T1w.nii.gz")]
t2_files = sorted(DATA_DIR.rglob(f"{subject_prefix}T2w.nii.gz"))
t2_LR_files = sorted(DATA_DIR.rglob(f"{subject_prefix}T2w_LR.nii.gz"))
if not t2_files or not t2_LR_files:
    raise FileNotFoundError(
        f"No matching '{subject_prefix}T2w.nii.gz' / '{subject_prefix}T2w_LR.nii.gz' under {DATA_DIR}"
    )

# Load each image once and reuse it for the report and the checks below.
t1_img, t2_img, t2_lr_img = (nib.load(f) for f in (t1_files[0], t2_files[0], t2_LR_files[0]))

print(f"Using T1 file: {t1_files[0]} of size {t1_img.shape}")
print(f"Using T2 file: {t2_files[0]} of size {t2_img.shape}")
print(f"Using T2 LR file: {t2_LR_files[0]} of size {t2_lr_img.shape}")

# Ensure all three images share the expected grid and voxel size
assert t1_img.shape == t2_img.shape == t2_lr_img.shape == (182,218,182), (
    f"Unexpected shapes: {t1_img.shape}, {t2_img.shape}, {t2_lr_img.shape}"
)
assert t1_img.header.get_zooms() == t2_img.header.get_zooms() == t2_lr_img.header.get_zooms() == (1.0,1.0,1.0), (
    "Unexpected voxel sizes: "
    f"{t1_img.header.get_zooms()}, {t2_img.header.get_zooms()}, {t2_lr_img.header.get_zooms()}"
)

# Extract overlapping patches from both network inputs (T1 and low-resolution T2).
t1_patches, t1_affine = get_patches_single_img(t1_files[0], patch_size, stride, target_shape)
t2_lr_patches, affine = get_patches_single_img(t2_LR_files[0], patch_size, stride, target_shape)

# The reconstruction is saved with a single affine, so both inputs must share it.
# The patch lists are consumed pairwise, so their lengths must also match.
# NOTE(review): assumes the helper returns array-like affines (e.g. 4x4) — confirm.
assert np.allclose(t1_affine, affine), "T1 and T2 LR affines differ"
assert len(t1_patches) == len(t2_lr_patches), (
    f"Patch count mismatch: {len(t1_patches)} T1 vs {len(t2_lr_patches)} T2 LR"
)

print(f"Extracted {len(t1_patches)} patches of size {t1_patches[0].shape} from T1 image")
print(f"Extracted {len(t2_lr_patches)} patches of size {t2_lr_patches[0].shape} from T2 LR image")

#Load pretrained model weights
WEIGHTS_PATH = DATA_DIR/"outputs"/"2025-12-10T15:25:46.850860_model_weights.pth"

# The architecture must match the one the checkpoint was trained with, or
# load_state_dict (strict by default) will raise on missing/unexpected keys.
net = UNet(
    spatial_dims=3,
    in_channels=2,   # T1 and T2-LR patches stacked as channels in the inference loop
    out_channels=1,
    channels=(32, 64, 128, 256, 512, 1024),
    strides=(2, 2, 2, 2, 2),
    num_res_units=10,
    norm=None,
)
# weights_only=True restricts unpickling to tensors and plain containers, which is all
# a state_dict needs. This avoids executing arbitrary pickled code from the checkpoint
# and presumably silences the torch.load warning echoed in the cell output.
net.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True))

#Generate reconstructed image from patches
# Patches are paired by position. With unequal lengths, extra T2 LR patches would be
# silently ignored, or an IndexError raised, so check up front.
assert len(t1_patches) == len(t2_lr_patches), (
    f"Patch count mismatch: {len(t1_patches)} T1 vs {len(t2_lr_patches)} T2 LR"
)

BATCH_SIZE = 16  # patches per forward pass; lower it if memory is tight

all_outputs = []
net.eval()
with torch.no_grad():
    # Batched inference instead of one forward pass per patch. The UNet has norm=None,
    # so each patch's output does not depend on the other patches in the batch.
    for start in range(0, len(t1_patches), BATCH_SIZE):
        stop = start + BATCH_SIZE
        t1_batch = torch.as_tensor(np.stack(t1_patches[start:stop])).float()        # (B, *patch_size)
        t2_lr_batch = torch.as_tensor(np.stack(t2_lr_patches[start:stop])).float()  # (B, *patch_size)
        inputs = torch.stack([t1_batch, t2_lr_batch], dim=1)  # (B, 2, *patch_size): ch0 = T1, ch1 = T2 LR
        outputs = net(inputs)  # (B, 1, *patch_size)
        # Drop the channel axis; extend with one patch-shaped array per patch, in order.
        all_outputs.extend(outputs[:, 0].cpu().numpy())

# Stitching works on NumPy arrays, so it does not need to run under no_grad.
reconstructed_t2 = reconstruct_from_patches(all_outputs, target_shape, stride)

#Save NIfTI file
# nib.save does not create missing directories, so make sure the target folder exists.
SAVE_PATH.mkdir(parents=True, exist_ok=True)
# NOTE(review): reconstructed_t2 presumably has target_shape (192, 224, 192) while the
# inputs are (182, 218, 182). Saving it with the input affine only lines up with the
# original images if get_patches_single_img pads in a way consistent with that affine — confirm.
nib.save(nib.Nifti1Image(reconstructed_t2, affine), SAVE_PATH/"reconstructed_t2.nii.gz")



Using T1 file: /Users/al1612le/data/bobsrepository/sub-116056/ses-3mo/anat/sub-116056_ses-3mo_space-INFANTMNIacpc_T1w.nii.gz of size (182, 218, 182)
Using T2 file: /Users/al1612le/data/bobsrepository/sub-116056/ses-3mo/anat/sub-116056_ses-3mo_space-INFANTMNIacpc_T2w.nii.gz of size (182, 218, 182)
Using T2 LR file: /Users/al1612le/data/bobsrepository/LR_data/axial/even/LR2/sub-116056_ses-3mo_space-INFANTMNIacpc_T2w_LR.nii.gz of size (182, 218, 182)
Extracted 1573 patches of size (32, 32, 32) from T1 image
Extracted 1573 patches of size (32, 32, 32) from T2 LR image


  net.load_state_dict(torch.load(DATA_DIR/"outputs"/"2025-12-10T15:25:46.850860_model_weights.pth", map_location="cpu"))
