<a href="https://colab.research.google.com/github/Sergikavtaradze/2025-Summer-Research/blob/main/xQSM/python/eval/run_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# xQSM demo: brain masking, cropping, and reconstruction with pretrained xQSM / U-Net

# import necessary packages

In [40]:
import os
import nibabel as nib
import numpy as np

# Define paths (set the QSM_DATA_PATH environment variable to use a different data root)
data_path = os.environ.get("QSM_DATA_PATH", "/Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/")
brain_masks_dir = os.path.join(data_path, "Brain_Masks")

# Collect all brain mask files. The listing is sorted because os.listdir order is
# arbitrary, which would make the processing/log order differ between runs.
# Hidden files (macOS '.DS_Store', '._*' resource forks, ...) are skipped so they
# are never handed to nibabel as masks.
brain_masks = [
    os.path.join(brain_masks_dir, mask_file)
    for mask_file in sorted(os.listdir(brain_masks_dir))
    if not mask_file.startswith('.')
]

# For each brain mask, determine the subject/session and apply the mask to the two QSM files
for mask_path in brain_masks:
    mask_filename = os.path.basename(mask_path)
    # Example mask_filename: sub-01_ses-01_Brain_Mask.nii.gz
    # Extract subject and session from the '_'-separated filename parts
    parts = mask_filename.split('_')
    sub = None
    ses = None
    for part in parts:
        if part.startswith('sub-'):
            sub = part
        if part.startswith('ses-'):
            ses = part
    if not (sub and ses):
        print(f"Could not determine subject/session for {mask_filename}")
        continue

    qsm_dir = os.path.join(data_path, sub, ses, 'qsm')
    if not os.path.exists(qsm_dir):
        print(f"QSM directory does not exist for {sub} {ses}, skipping.")
        continue

    # Load brain mask; every voxel with a value > 0 is treated as brain
    mask_nii = nib.load(mask_path)
    mask_bool = mask_nii.get_fdata() > 0

    # Define the two files to mask
    localfield_filename = f"{sub}_{ses}_unwrapped-SEGUE_mask-nfe_bfr-PDF_localfield.nii.gz"
    chimap_filename = f"{sub}_{ses}_unwrapped-SEGUE_mask-nfe_bfr-PDF_susc-autoNDI_Chimap.nii.gz"

    for target_filename in [localfield_filename, chimap_filename]:
        target_path = os.path.join(qsm_dir, target_filename)
        if not os.path.exists(target_path):
            print(f"File {target_path} not found, skipping.")
            continue

        # Load target image
        img_nii = nib.load(target_path)
        img_data = img_nii.get_fdata()

        # The mask must be on the same voxel grid as the image. Otherwise np.where
        # raises a broadcasting error (aborting every remaining subject) or
        # broadcasts unexpectedly, so report the mismatch and skip this file.
        if mask_bool.shape != img_data.shape:
            print(f"Shape mismatch for {target_path}: mask {mask_bool.shape} vs image {img_data.shape}, skipping.")
            continue

        # Apply mask: keep voxels inside the brain, zero everything else
        masked_data = np.where(mask_bool, img_data, 0)

        # NOTE: the double '_masked_masked' suffix is kept deliberately. Earlier runs
        # produced it, and the cropping cell below looks for exactly this name.
        masked_filename = target_filename.replace('.nii.gz', '_masked_masked.nii.gz')
        masked_path = os.path.join(qsm_dir, masked_filename)
        masked_nii = nib.Nifti1Image(masked_data, img_nii.affine, img_nii.header)
        nib.save(masked_nii, masked_path)
        print(f"Masked file saved: {masked_path}")


Masked file saved: /Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-07/ses-03/qsm/sub-07_ses-03_unwrapped-SEGUE_mask-nfe_bfr-PDF_localfield_masked_masked.nii.gz
Masked file saved: /Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-07/ses-03/qsm/sub-07_ses-03_unwrapped-SEGUE_mask-nfe_bfr-PDF_susc-autoNDI_Chimap_masked_masked.nii.gz
Masked file saved: /Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-07/ses-02/qsm/sub-07_ses-02_unwrapped-SEGUE_mask-nfe_bfr-PDF_localfield_masked_masked.nii.gz
Masked file saved: /Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-07/ses-02/qsm/sub-07_ses-02_unwrapped-SEGUE_mask-nfe_bfr-PDF_susc-autoNDI_Chimap_masked_masked.nii.gz
Masked file saved: /Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-07/ses-01/qsm/sub-07_ses-01_unwrapped-SEGUE_mask-nfe_bfr-PDF_localfield_masked_masked.nii.gz
Masked file saved: /Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-07/ses-01/qsm

Masking worked well. Next, the masked images need to be cropped, because after masking many of the slices are empty. To do this, I need to find the index range along each axis that contains non-zero values; then I will decide on the crop size.

In [43]:
import os
import nibabel as nib
import numpy as np

# Define paths (set the QSM_DATA_PATH environment variable to use a different data root)
data_path = os.environ.get("QSM_DATA_PATH", "/Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/")
brain_masks_dir = os.path.join(data_path, "Brain_Masks")

# Collect all brain mask files. The listing is sorted because os.listdir order is
# arbitrary, which would make the processing/log order differ between runs.
# Hidden files (macOS '.DS_Store', '._*' resource forks, ...) are skipped so they
# are never treated as masks.
brain_masks = [
    os.path.join(brain_masks_dir, mask_file)
    for mask_file in sorted(os.listdir(brain_masks_dir))
    if not mask_file.startswith('.')
]

def find_nonzero_crop_bounds(data):
    """Return the inclusive bounding box of the non-zero voxels of a 3-D array.

    Args:
        data: array indexed as (x, y, z). NaN counts as non-zero (np.nonzero semantics).

    Returns:
        ((x_min, x_max), (y_min, y_max), (z_min, z_max)) as Python ints. The max
        indices are *inclusive*, so the crop is ``data[x_min:x_max + 1, ...]``.
        If ``data`` has no non-zero voxel, the full extent ``(0, shape[i] - 1)``
        is returned for each of the first three axes, using the same inclusive
        convention.
    """
    nonzero = np.nonzero(data)
    if len(nonzero[0]) == 0:
        # No non-zero voxel: fall back to the full inclusive extent of each axis
        return tuple((0, int(size) - 1) for size in np.shape(data)[:3])
    return tuple((int(idx.min()), int(idx.max())) for idx in nonzero[:3])

# Extra voxels kept around each non-zero bounding box (clamped to the volume edges)
margin = 0

# Store all crop box edges and dimensions for averaging and output
crop_boxes = []
crop_dims = []
crop_info_lines = []

# For each brain mask, determine the subject/session and crop the masked images
for mask_path in brain_masks:
    mask_filename = os.path.basename(mask_path)
    # Example mask_filename: sub-01_ses-01_Brain_Mask.nii.gz
    # Extract subject and session from the '_'-separated filename parts
    parts = mask_filename.split('_')
    sub = None
    ses = None
    for part in parts:
        if part.startswith('sub-'):
            sub = part
        if part.startswith('ses-'):
            ses = part
    if not (sub and ses):
        print(f"Could not determine subject/session for {mask_filename}")
        continue

    qsm_dir = os.path.join(data_path, sub, ses, 'qsm')
    if not os.path.exists(qsm_dir):
        print(f"QSM directory does not exist for {sub} {ses}, skipping.")
        continue

    # The masked files written by the masking cell (note the double '_masked' suffix)
    masked_localfield_filename = f"{sub}_{ses}_unwrapped-SEGUE_mask-nfe_bfr-PDF_localfield_masked_masked.nii.gz"
    masked_chimap_filename = f"{sub}_{ses}_unwrapped-SEGUE_mask-nfe_bfr-PDF_susc-autoNDI_Chimap_masked_masked.nii.gz"

    for masked_filename in [masked_localfield_filename, masked_chimap_filename]:
        masked_path = os.path.join(qsm_dir, masked_filename)
        if not os.path.exists(masked_path):
            print(f"Masked file {masked_path} not found, skipping.")
            continue

        # Load masked image
        img_nii = nib.load(masked_path)
        img_data = img_nii.get_fdata()

        # An all-zero image has no bounding box. "Cropping" it would keep the full
        # volume and skew the averages below, so report it and skip.
        if not np.any(img_data):
            print(f"Masked file {masked_path} contains no non-zero voxels, skipping.")
            continue

        # Find the inclusive bounding box of the non-zero voxels
        (x_min, x_max), (y_min, y_max), (z_min, z_max) = find_nonzero_crop_bounds(img_data)

        # Expand by `margin`, clamped to valid (inclusive) indices
        x_min_c = max(x_min - margin, 0)
        x_max_c = min(x_max + margin, img_data.shape[0] - 1)
        y_min_c = max(y_min - margin, 0)
        y_max_c = min(y_max + margin, img_data.shape[1] - 1)
        z_min_c = max(z_min - margin, 0)
        z_max_c = min(z_max + margin, img_data.shape[2] - 1)

        # Crop the image (+1 because the max indices are inclusive)
        cropped_data = img_data[x_min_c:x_max_c + 1, y_min_c:y_max_c + 1, z_min_c:z_max_c + 1]

        # Shift the affine's translation so that voxel (0, 0, 0) of the crop maps to
        # the same world position as voxel (x_min_c, y_min_c, z_min_c) of the original
        affine = img_nii.affine.copy()
        affine[:3, 3] += affine[:3, :3] @ np.array([x_min_c, y_min_c, z_min_c])

        # NOTE(review): the localfield and the Chimap are cropped with independently
        # computed boxes, so a pair can end up with different shapes. Confirm this
        # is intended if the cropped pairs are used together downstream.
        # '..._masked_masked.nii.gz' -> '..._masked_cropped.nii.gz'
        cropped_filename = masked_filename.replace('_masked.nii.gz', '_cropped.nii.gz')
        cropped_path = os.path.join(qsm_dir, cropped_filename)
        cropped_nii = nib.Nifti1Image(cropped_data, affine, img_nii.header)
        nib.save(cropped_nii, cropped_path)
        print(f"Cropped file saved: {cropped_path}")

        # Record crop box and dimensions for output
        crop_boxes.append([x_min_c, x_max_c, y_min_c, y_max_c, z_min_c, z_max_c])
        crop_dims.append([x_max_c - x_min_c + 1, y_max_c - y_min_c + 1, z_max_c - z_min_c + 1])
        crop_info_lines.append(
            f"{cropped_filename}\tX:({x_min_c},{x_max_c}) Y:({y_min_c},{y_max_c}) Z:({z_min_c},{z_max_c})\t"
            f"Shape: {cropped_data.shape}"
        )

# Build the crop report: one line per cropped file, then (if any) the averages
output_txt_path = os.path.join(data_path, "cropped_image_dimensions.txt")
report_lines = ["Filename\tCrop Box (x_min,x_max y_min,y_max z_min,z_max)\tShape"]
report_lines.extend(crop_info_lines)

# Compute and print average box edges and average dimensions
if crop_boxes:
    avg_edges = np.mean(np.array(crop_boxes), axis=0)
    avg_dims = np.mean(np.array(crop_dims), axis=0)
    print("\nAverage crop box edges (x_min, x_max, y_min, y_max, z_min, z_max):")
    print(avg_edges)
    print("Average cropped image dimensions (x, y, z):")
    print(avg_dims)
    report_lines += [
        "",
        "Average crop box edges (x_min, x_max, y_min, y_max, z_min, z_max):",
        str(avg_edges),
        "Average cropped image dimensions (x, y, z):",
        str(avg_dims),
    ]
else:
    print("No crop boxes found; nothing to average.")

# Write the report in one go
with open(output_txt_path, "w") as f:
    f.write("\n".join(report_lines) + "\n")


Cropped file saved: /Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-07/ses-03/qsm/sub-07_ses-03_unwrapped-SEGUE_mask-nfe_bfr-PDF_localfield_masked_cropped.nii.gz
Cropped file saved: /Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-07/ses-03/qsm/sub-07_ses-03_unwrapped-SEGUE_mask-nfe_bfr-PDF_susc-autoNDI_Chimap_masked_cropped.nii.gz
Cropped file saved: /Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-07/ses-02/qsm/sub-07_ses-02_unwrapped-SEGUE_mask-nfe_bfr-PDF_localfield_masked_cropped.nii.gz
Cropped file saved: /Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-07/ses-02/qsm/sub-07_ses-02_unwrapped-SEGUE_mask-nfe_bfr-PDF_susc-autoNDI_Chimap_masked_cropped.nii.gz
Cropped file saved: /Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-07/ses-01/qsm/sub-07_ses-01_unwrapped-SEGUE_mask-nfe_bfr-PDF_localfield_masked_cropped.nii.gz
Cropped file saved: /Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-07

In [16]:
import torch
import torch.nn as nn
import sys
sys.path.append('..')
import numpy as np
import nibabel as nib
import scipy.io as scio
from collections import OrderedDict
import time

import sys
sys.path.append('/Users/sirbucks/Documents/xQSM/2025-Summer-Research/xQSM/python')

from Unet_file import Unet
from xQSM_file import xQSM

sys.path.append('/content/2025-Summer-Research/xQSM/python/eval')
from utils import ssim, psnr


### PSNR = Peak Signal-to-Noise Ratio

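A numerical measure of how similar two images are, typically used to measure how well a compressed or reconstructed image matches the original.
Standard definition: $\mathrm{PSNR} = 10\log_{10}\left(\mathrm{MAX}^2 / \mathrm{MSE}\right)$, where MAX is the peak signal value and MSE is the mean squared error between the two images.
Higher PSNR → images are more similar.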

### SSIM = Structural Similarity Index

Unlike PSNR, SSIM aims to model perceived image quality by the human visual system.

How SSIM works

SSIM compares images based on:

1. Luminance similarity
2. Contrast similarity
3. Structure similarity

# define utility functions

Read a local field map from a NIfTI file; returns the data array and its affine.

In [17]:
def Read_nii(path):
    """Load a NIfTI file.

    Returns a ``(data, affine)`` pair. ``data`` is a NumPy copy of the voxel
    values from ``get_fdata()``, and ``affine`` is the image's affine matrix.
    """
    nii_image = nib.load(path)
    voxel_data = np.array(nii_image.get_fdata())
    return voxel_data, nii_image.affine

In [19]:
def ZeroPadding(Field, factor = 8):
    """Zero-pad ``Field`` so that every dimension is a multiple of ``factor``.

    The original data is placed roughly centred in the padded array. When the
    total padding along an axis is odd, the extra voxel goes *before* the data
    (ceil of half the padding).

    Args:
        Field: N-D array (the local field map; 3-D in this notebook).
        factor: positive int; each output dimension is rounded up to a multiple of it.

    Returns:
        (padded, pos):
          padded: float64 array of the rounded-up shape with ``Field`` copied in.
          pos: int array of shape (ndim, 2). ``pos[:, 0]`` is the start index and
               ``pos[:, 1]`` the *inclusive* end index of the original data along
               each axis. Pass it to ``ZeroRemoving`` to undo the padding.

    Raises:
        ValueError: if ``factor`` is not positive.
    """
    if factor <= 0:
        raise ValueError(f"factor must be positive, got {factor}")
    Field = np.asarray(Field)
    im_size = np.array(Field.shape)
    up_size = (np.ceil(im_size / factor) * factor).astype(int)  # padded shape
    pos_init = np.ceil((up_size - im_size) / 2).astype(int)
    pos_end = pos_init + im_size - 1  # inclusive end index
    padded = np.zeros(tuple(int(s) for s in up_size))
    # +1 because Python slices exclude the end while pos_end is inclusive
    inner = tuple(slice(start, end + 1) for start, end in zip(pos_init, pos_end))
    padded[inner] = Field
    pos = np.stack([pos_init, pos_end], axis=1)
    return padded, pos

# Field, aff = Read_nii('/Users/sirbucks/Documents/xQSM/2025-Summer-Research/xQSM/xQSM_Checkpoints/field_input.nii')
# imSize = np.shape(Field)
# print(imSize)
# print(np.mod(imSize,  8).any())
# print(np.mod())
# if np.mod(imSize,  8).any():
#     Field, pos = ZeroPadding(Field, 8)  # ZeroPadding
# Field = torch.from_numpy(Field)


In [13]:
# Scratch check: voxel count of a 192 x 256 x 176 volume (presumably the shape of the field map under inspection — confirm).
192*256*176

8650752

In [14]:
# Scratch check: the voxel count from the previous cell divided by 8.
8650752/8

1081344.0

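#### ZeroPadding
Pads the field so that every dimension is divisible by the designated factor. Returns:
- `Field`: the zero-padded local field map;
- `pos`: the start/end indices of the original data inside the padded array (needed by `ZeroRemoving`).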

In [20]:
def ZeroRemoving(Field, pos):
    """Inverse of ZeroPadding: crop a padded array back to the original extent.

    Args:
        Field: padded N-D array (e.g. a network output after removing the batch
            and channel dimensions).
        pos: array-like of shape (ndim, 2), as returned by ZeroPadding. Column 0
            holds the start index and column 1 the *inclusive* end index of the
            original data along each axis. Float entries are rounded to int.

    Returns:
        ``Field`` restricted to the original region (a view when ``Field`` is a NumPy array).
    """
    bounds = np.rint(np.asarray(pos)).astype(int)
    # +1 because Python slices exclude the end while the stored end is inclusive
    region = tuple(slice(start, end + 1) for start, end in bounds)
    return Field[region]

markdown: ZeroRemoving: inverse of ZeroPadding; crops the padded array back to its original size using `pos`.

In [21]:
def Save_nii(Recon, aff, path):
    """Write the array ``Recon`` to ``path`` as a NIfTI-1 image with affine ``aff``."""
    nib.save(nib.Nifti1Image(Recon, aff), path)

markdown: Save the results in NIfTI (.nii) format.

In [22]:
def Save_mat(Recon, path):
    """Write ``Recon`` to the MATLAB .mat file ``path`` under the variable name 'Recon'."""
    mat_contents = {'Recon': Recon}
    scio.savemat(path, mat_contents)

markdown: Save the results in MATLAB (.mat) format.

# define evaluation function for the networks

In [25]:
def Eval(Field, NetName):
    """Reconstruct a susceptibility map from a local field with a pretrained network.

    Args:
        Field: input tensor. The demo cell passes a float tensor of shape
            (1, 1, x, y, z), built by unsqueezing a 3-D field twice.
        NetName: checkpoint path without the '.pth' extension. The architecture
            is chosen by substring: names containing 'Unet' build ``Unet(2)``;
            otherwise names containing 'xQSM' build ``xQSM(2, initial_num_layers=64)``.
            NOTE: the test runs on the whole path, so directory names count too.

    Returns:
        NumPy array of the reconstruction. The two leading dimensions are squeezed
        (when they have size 1), so for a (1, 1, x, y, z) input this is (x, y, z).

    Raises:
        ValueError: if NetName matches neither architecture.
    """
    with torch.no_grad():
        ## Network Load;
        print('Load Pretrained Network')
        model_weights_path = NetName + '.pth'
        if 'Unet' in NetName:
            Net = Unet(2)
        elif 'xQSM' in NetName:
            Net = xQSM(2, initial_num_layers=64)
        else:
            # Fail fast: there is no network to load weights into
            raise ValueError('Network Type Invalid! Expected "Unet" or "xQSM" in NetName, got %r' % NetName)

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        weights = torch.load(model_weights_path, map_location=device, weights_only=True)
        # The networks were trained inside an nn.DataParallel wrapper, so the saved
        # keys carry a 'module.' prefix. Strip it (only where present) so the weights
        # load straight into the bare network on either CPU or GPU.
        prefix = 'module.'
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in weights.items()
        )
        Net.load_state_dict(state_dict)
        Net.to(device)
        Net.eval()  ## set the model to evaluation mode
        Field = Field.to(device)

        ################ Evaluation ##################
        time_start = time.time()
        Recon = Net(Field)
        if device.type == 'cuda':
            # CUDA kernels run asynchronously; wait for them so the timing is meaningful
            torch.cuda.synchronize()
        time_end = time.time()
        print('%f seconds elapsed!' % (time_end - time_start))
        Recon = torch.squeeze(Recon, 0)  # drop batch dim
        Recon = torch.squeeze(Recon, 0)  # drop channel dim
        Recon = Recon.to('cpu').numpy()  ## transfer to cpu for saving.
    return Recon


markdown: Eval(Field, NetName) returns the QSM reconstruction of the local field map (Field)
          using the designated network (NetName: checkpoint path without the .pth extension).

# Demonstration on sub-09 / ses-01 data (local field map as input, autoNDI χ map as reference)

In [None]:
# Demo inputs: local field map (network input) and autoNDI chi map (reference label)
DEMO_QSM_DIR = '/Users/sirbucks/Documents/xQSM/2025-Summer-Research/QSM_data/sub-09/ses-01/qsm'
FIELD_PATH = f'{DEMO_QSM_DIR}/sub-09_ses-01_unwrapped-SEGUE_mask-nfe_bfr-PDF_localfield.nii.gz'
LABEL_PATH = f'{DEMO_QSM_DIR}/sub-09_ses-01_unwrapped-SEGUE_mask-nfe_bfr-PDF_susc-autoNDI_Chimap.nii.gz'
CKPT_DIR = '/Users/sirbucks/Documents/xQSM/2025-Summer-Research/xQSM/Pretrained_Checkpoints'
PAD_FACTOR = 8  # network input dimensions must be divisible by this

with torch.no_grad():
    ## Data Load;
    print('Data Loading')
    Field, aff = Read_nii(FIELD_PATH)
    print('Loading Completed')
    # Voxels with a non-zero field; computed before padding so it has the
    # same shape as the (un-padded) reconstructions it is applied to
    mask = Field != 0

    ## The field map size must be divisible by PAD_FACTOR along every axis;
    ## otherwise zero-pad first and crop the reconstructions back afterwards.
    print('ZeroPadding')
    imSize = np.shape(Field)
    needs_padding = np.mod(imSize, PAD_FACTOR).any()
    if needs_padding:
        Field, pos = ZeroPadding(Field, PAD_FACTOR)
    Field = torch.from_numpy(Field)
    ## PyTorch networks take a mini-batch, so unsqueeze the 3-D field into a
    ## 5-D (batch=1, channel=1, x, y, z) tensor.
    Field = torch.unsqueeze(Field, 0)
    Field = torch.unsqueeze(Field, 0)
    Field = Field.float()

    ## QSM Reconstruction
    print('Reconstruction')
    # Recon_xQSM_invivo_TL = Eval(Field, '/Users/sirbucks/Documents/xQSM/2025-Summer-Research/xQSM/ckpt/Trial_1_Jun17/ckpt/xQSM_TransferLearning_Best')
    Recon_xQSM_invivo = Eval(Field, f'{CKPT_DIR}/xQSM_invivo')
    Recon_Unet_invivo = Eval(Field, f'{CKPT_DIR}/Unet_invivo')
    # Recon_xQSM_syn = Eval(Field, 'xQSM_syn')
    # Recon_Unet_syn = Eval(Field, 'Unet_syn')
    if needs_padding:
        # Undo the padding so the reconstructions match the original grid
        Recon_xQSM_invivo = ZeroRemoving(Recon_xQSM_invivo, pos)
        Recon_Unet_invivo = ZeroRemoving(Recon_Unet_invivo, pos)
    Recon_xQSM_invivo = Recon_xQSM_invivo * mask
    Recon_Unet_invivo = Recon_Unet_invivo * mask

    ## calculate PSNR against the reference chi map.
    # The label's affine is kept separate so `aff` still refers to the input field when saving.
    label, label_aff = Read_nii(LABEL_PATH)
    assert label.shape == Recon_xQSM_invivo.shape, (label.shape, Recon_xQSM_invivo.shape)
    print('PSNR of xQSM_invivo is %f'% (psnr(Recon_xQSM_invivo, label)))
    print('PSNR of Unet_invivo is %f'% (psnr(Recon_Unet_invivo, label)))

    ## Optional: save results in .mat ...
    # Save_mat(Recon_xQSM_invivo, './Chi_xQSM_invivo.mat')
    # Save_mat(Recon_Unet_invivo, './Chi_Unet_invivo.mat')
    # Save_mat(Recon_xQSM_syn, './Chi_xQSM_syn.mat')
    # Save_mat(Recon_Unet_syn, './Chi_Unet_syn.mat')
    ## ... or in .nii format
    # Save_nii(Recon_xQSM_invivo, aff, 'Chi_xQSM_invivo.nii')
    # Save_nii(Recon_Unet_invivo, aff, 'Chi_Unet_invivo.nii')
    # Save_nii(Recon_xQSM_syn, aff, 'Chi_xQSM_syn.nii')
    # Save_nii(Recon_Unet_syn, aff, 'Chi_Unet_syn.nii')

Data Loading
Loading Completed
ZeroPadding
Reconstruction
Load Pretrained Network
OrderedDict()
79.389482 seconds elapsed!
Load Pretrained Network
OrderedDict()
219.273347 seconds elapsed!
PSNR of xQSM_invivo is 4.248406


In [27]:
# U-Net PSNR, computed from Recon_Unet_invivo and label produced by the demo cell above.
print('PSNR of Unet_invivo is %f' % psnr(Recon_Unet_invivo, label))

PSNR of Unet_invivo is 5.524919
