In [1]:
# inference.py

import torch
import torch.nn as nn
import numpy as np
import nibabel as nib
import os

from models.model import UNet2D  # Adjust if your model file/path is different

def load_nii_as_numpy(filepath):
    """
    Load a NIfTI file and return the NumPy array and affine.
    """
    nii_img = nib.load(filepath)
    data = nii_img.get_fdata()  # shape [D, H, W] (or [H, W, D], depending on orientation)
    affine = nii_img.affine     # to preserve the spatial orientation
    return data, affine

def predict_volume_2d(model, volume_3d):
    """
    Given a 3D volume (NumPy array) and a 2D model:
    1. Slice the volume along the first dimension [D, H, W].
    2. Pass each slice through the model.
    3. Return the predicted 3D volume.
    """

    model.eval()  # set model to eval mode

    depth = volume_3d.shape[0]
    preds_3d = []

    with torch.no_grad():
        for i in range(depth):
            # Extract a single slice: shape [H, W]
            slice_2d = volume_3d[i, :, :].astype(np.float32)

            # (Optional) Normalize or apply the same preprocessing as in training
            # e.g.: slice_2d = (slice_2d - slice_2d.mean()) / slice_2d.std()

            # Add batch dimension and channel dimension => shape [1, 1, H, W]
            input_tensor = torch.from_numpy(slice_2d).unsqueeze(0).unsqueeze(0)

            # Move to GPU if available
            device = next(model.parameters()).device
            input_tensor = input_tensor.to(device)

            # Forward pass
            output = model(input_tensor)  # shape [1, 1, H, W]

            # Convert back to NumPy
            output_slice = output.squeeze().cpu().numpy()

            preds_3d.append(output_slice)

    # Stack along the depth dimension => shape [D, H, W]
    preds_3d = np.stack(preds_3d, axis=0)
    return preds_3d

def main():
    # 1) Load your trained model
    model = UNet2D(in_channels=1, out_channels=1)
    # Adjust path to your saved weights
    model.load_state_dict(torch.load("superres_unet_v4.pth", map_location="cpu"))
    model.cuda() if torch.cuda.is_available() else None

    # 2) Load the LR volume you want to upsample
    #    Make sure this LR volume is pre-resampled to match the desired shape in 3D (if needed).
    lr_path = "data/sub-OAS30001_axial_upsampled.nii"
    lr_volume, lr_affine = load_nii_as_numpy(lr_path)  # shape [D, H, W]

    # 3) Predict the super-resolved volume (using 2D slices)
    sr_preds = predict_volume_2d(model, lr_volume)  # shape [D, H, W]

    # 4) Save the predicted volume as a NIfTI
    sr_nifti = nib.Nifti1Image(sr_preds, lr_affine)
    out_path = "data/sub-OAS30001_sr_prediction.nii"
    nib.save(sr_nifti, out_path)
    print(f"Saved super-resolved volume to: {out_path}")

if __name__ == "__main__":
    main()


  model.load_state_dict(torch.load("superres_unet_v4.pth", map_location="cpu"))


RuntimeError: Error(s) in loading state_dict for UNet2D:
	Missing key(s) in state_dict: "inc.double_conv.0.weight", "inc.double_conv.0.bias", "inc.double_conv.1.weight", "inc.double_conv.1.bias", "inc.double_conv.1.running_mean", "inc.double_conv.1.running_var", "inc.double_conv.3.weight", "inc.double_conv.3.bias", "inc.double_conv.4.weight", "inc.double_conv.4.bias", "inc.double_conv.4.running_mean", "inc.double_conv.4.running_var", "down1.1.double_conv.0.weight", "down1.1.double_conv.0.bias", "down1.1.double_conv.1.weight", "down1.1.double_conv.1.bias", "down1.1.double_conv.1.running_mean", "down1.1.double_conv.1.running_var", "down1.1.double_conv.3.weight", "down1.1.double_conv.3.bias", "down1.1.double_conv.4.weight", "down1.1.double_conv.4.bias", "down1.1.double_conv.4.running_mean", "down1.1.double_conv.4.running_var", "down2.1.double_conv.0.weight", "down2.1.double_conv.0.bias", "down2.1.double_conv.1.weight", "down2.1.double_conv.1.bias", "down2.1.double_conv.1.running_mean", "down2.1.double_conv.1.running_var", "down2.1.double_conv.3.weight", "down2.1.double_conv.3.bias", "down2.1.double_conv.4.weight", "down2.1.double_conv.4.bias", "down2.1.double_conv.4.running_mean", "down2.1.double_conv.4.running_var", "down3.1.double_conv.0.weight", "down3.1.double_conv.0.bias", "down3.1.double_conv.1.weight", "down3.1.double_conv.1.bias", "down3.1.double_conv.1.running_mean", "down3.1.double_conv.1.running_var", "down3.1.double_conv.3.weight", "down3.1.double_conv.3.bias", "down3.1.double_conv.4.weight", "down3.1.double_conv.4.bias", "down3.1.double_conv.4.running_mean", "down3.1.double_conv.4.running_var", "down4.1.double_conv.0.weight", "down4.1.double_conv.0.bias", "down4.1.double_conv.1.weight", "down4.1.double_conv.1.bias", "down4.1.double_conv.1.running_mean", "down4.1.double_conv.1.running_var", "down4.1.double_conv.3.weight", "down4.1.double_conv.3.bias", "down4.1.double_conv.5.weight", "down4.1.double_conv.5.bias", "down4.1.double_conv.5.running_mean", "down4.1.double_conv.5.running_var", "conv_up1.double_conv.1.weight", "conv_up1.double_conv.1.bias", "conv_up1.double_conv.1.running_mean", "conv_up1.double_conv.1.running_var", "conv_up1.double_conv.3.weight", "conv_up1.double_conv.3.bias", "conv_up1.double_conv.4.weight", "conv_up1.double_conv.4.bias", "conv_up1.double_conv.4.running_mean", "conv_up1.double_conv.4.running_var", "conv_up2.double_conv.1.weight", "conv_up2.double_conv.1.bias", "conv_up2.double_conv.1.running_mean", "conv_up2.double_conv.1.running_var", "conv_up2.double_conv.3.weight", "conv_up2.double_conv.3.bias", "conv_up2.double_conv.4.weight", "conv_up2.double_conv.4.bias", "conv_up2.double_conv.4.running_mean", "conv_up2.double_conv.4.running_var", "up3.weight", "up3.bias", "conv_up3.double_conv.0.weight", "conv_up3.double_conv.0.bias", "conv_up3.double_conv.1.weight", "conv_up3.double_conv.1.bias", "conv_up3.double_conv.1.running_mean", "conv_up3.double_conv.1.running_var", "conv_up3.double_conv.3.weight", "conv_up3.double_conv.3.bias", "conv_up3.double_conv.4.weight", "conv_up3.double_conv.4.bias", "conv_up3.double_conv.4.running_mean", "conv_up3.double_conv.4.running_var", "up4.weight", "up4.bias", "conv_up4.double_conv.0.weight", "conv_up4.double_conv.0.bias", "conv_up4.double_conv.1.weight", "conv_up4.double_conv.1.bias", "conv_up4.double_conv.1.running_mean", "conv_up4.double_conv.1.running_var", "conv_up4.double_conv.3.weight", "conv_up4.double_conv.3.bias", "conv_up4.double_conv.4.weight", "conv_up4.double_conv.4.bias", "conv_up4.double_conv.4.running_mean", "conv_up4.double_conv.4.running_var", "outc.weight", "outc.bias". 
	Unexpected key(s) in state_dict: "conv_down1.double_conv.0.weight", "conv_down1.double_conv.0.bias", "conv_down1.double_conv.2.weight", "conv_down1.double_conv.2.bias", "conv_down2.double_conv.0.weight", "conv_down2.double_conv.0.bias", "conv_down2.double_conv.2.weight", "conv_down2.double_conv.2.bias", "conv_bottom.double_conv.0.weight", "conv_bottom.double_conv.0.bias", "conv_bottom.double_conv.2.weight", "conv_bottom.double_conv.2.bias", "conv_out.weight", "conv_out.bias", "conv_up1.double_conv.2.weight", "conv_up1.double_conv.2.bias", "conv_up2.double_conv.2.weight", "conv_up2.double_conv.2.bias". 
	size mismatch for up1.weight: copying a param with shape torch.Size([128, 64, 2, 2]) from checkpoint, the shape in current model is torch.Size([1024, 512, 2, 2]).
	size mismatch for up1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for conv_up1.double_conv.0.weight: copying a param with shape torch.Size([64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 1024, 3, 3]).
	size mismatch for conv_up1.double_conv.0.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for up2.weight: copying a param with shape torch.Size([256, 128, 2, 2]) from checkpoint, the shape in current model is torch.Size([512, 256, 2, 2]).
	size mismatch for up2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for conv_up2.double_conv.0.weight: copying a param with shape torch.Size([128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).
	size mismatch for conv_up2.double_conv.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).

In [None]:
# inference.py

import torch
import torch.nn as nn
import numpy as np
import nibabel as nib
import os
import torchio as tio  # for resampling

from models.model import UNet2D  # Adjust if your model file/path is different

def load_and_resample_nii(filepath, target_spacing):
    """
    Load a NIfTI file using TorchIO and resample it to the target spacing.
    
    Parameters:
        filepath (str): Path to the original LR NIfTI file.
        target_spacing (tuple): Desired voxel spacing (e.g., (1.0, 1.0, 1.0)).
        
    Returns:
        data (numpy.ndarray): The resampled volume with shape [D, H, W].
        affine (numpy.ndarray): The affine matrix of the resampled image.
    """
    image = tio.ScalarImage(filepath)
    resample_transform = tio.Resample(target_spacing)
    resampled_image = resample_transform(image)
    # Convert the image to a NumPy array.
    # TorchIO returns a 4D array with shape [C, D, H, W]; for a single channel, take the first index.
    data = resampled_image.numpy()[0]  # shape: [D, H, W]
    affine = resampled_image.affine
    return data, affine

def predict_volume_2d(model, volume_3d):
    """
    Given a 3D volume (NumPy array) and a 2D model:
      1. Slice the volume along the first dimension [D, H, W].
      2. Pass each slice through the model.
      3. Return the predicted 3D volume.
    """
    model.eval()  # set model to eval mode

    depth = volume_3d.shape[0]
    preds_3d = []

    with torch.no_grad():
        for i in range(depth):
            # Extract a single slice: shape [H, W]
            slice_2d = volume_3d[i, :, :].astype(np.float32)

            # (Optional) Normalize or apply the same preprocessing as in training
            # e.g.: slice_2d = (slice_2d - slice_2d.mean()) / slice_2d.std()

            # Add batch dimension and channel dimension => shape [1, 1, H, W]
            input_tensor = torch.from_numpy(slice_2d).unsqueeze(0).unsqueeze(0)

            # Move to GPU if available
            device = next(model.parameters()).device
            input_tensor = input_tensor.to(device)

            # Forward pass
            output = model(input_tensor)  # shape [1, 1, H, W]

            # Convert back to NumPy
            output_slice = output.squeeze().cpu().numpy()

            preds_3d.append(output_slice)

    # Stack along the depth dimension => shape [D, H, W]
    preds_3d = np.stack(preds_3d, axis=0)
    return preds_3d

def main():
    # 1) Load your trained model
    model = UNet2D(in_channels=1, out_channels=1)
    # Adjust path to your saved weights
    model.load_state_dict(torch.load("superres_unet_v2.pth", map_location="cpu"))
    if torch.cuda.is_available():
        model.cuda()

    # 2) Load the original LR volume and resample on the fly.
    # Since you didn't pre-resample, use the original LR file.
    lr_path = "data/sub-OAS30001_ses-d0129_run-01_T1w_axial_LR.nii.gz"
    # Define the target spacing (adjust these values to match your training resolution)
    target_spacing = (1.0, 1.0, 1.0)
    lr_volume, lr_affine = load_and_resample_nii(lr_path, target_spacing)  # shape [D, H, W]

    # 3) Predict the super-resolved volume (using 2D slices)
    sr_preds = predict_volume_2d(model, lr_volume)  # shape [D, H, W]

    # 4) Save the predicted volume as a NIfTI file
    sr_nifti = nib.Nifti1Image(sr_preds, lr_affine)
    out_path = "data/sub-OAS30001_sr_prediction.nii.gz"
    nib.save(sr_nifti, out_path)
    print(f"Saved super-resolved volume to: {out_path}")

if __name__ == "__main__":
    main()
