# This is a demo for xQSM;

# import necessary packages

In [9]:
import torch 
import torch.nn as nn
import sys
sys.path.append('..')
import numpy as np
import nibabel as nib
import scipy.io as scio
from collections import OrderedDict
import time
from xQSM import *
from Unet import *
from utils import ssim, psnr

# define utility functions

In [10]:
def ZeroPadding(Field, factor = 8):
    ImSize = np.shape(Field)
    UpSize = np.ceil(ImSize / factor) * factor  # calculate the padding size; 
    pos_init = np.ceil((UpSize - ImSize) / 2) 
    pos_end = pos_init + ImSize - 1
    tmp_Field = np.zeros(UpSize)
    tmp_Field[pos_init[1]:pos_end[1], pos_init[2]:pos_end[2], pos_init[3]:pos_end[3]] = Field
    Field = tmp_Field
    pos = np.zeros([3, 2])
    pos[:,0] = pos_init
    pos[:,1] = pos_end
    return Field, pos

markdown: ZeroPadding to make the size of the field divisible by the designated factor; 
          Field: local field map; 
          pos: the position information of padding; 

In [11]:
def ZeroRemoving(Field, pos):
    Field = Field[pos_init[1]:pos_end[1], pos_init[2]:pos_end[2], pos_init[3]:pos_end[3]]
    return Field

markdown: ZeroRemoving: inverse function of ZeroPadding; 

In [12]:
def Read_nii(path):
    nibField = nib.load(path)
    Field = nibField.get_fdata() 
    aff = nibField.affine
    Field = np.array(Field)
    return Field, aff

markdown: read local field map from nifti fils; 

In [13]:
def Save_nii(Recon, aff, path):
    nibRecon = nib.Nifti1Image(Recon,aff)
    nib.save(nibRecon, path) 

markdown: save the results in nii format; 

In [14]:
def Save_mat(Recon, path):
    """
    save the results in mat format; 
    """
    scio.savemat(path, {'Recon':Recon})   

markdown: save results in .mat format;

# define evaluation function for the networks

In [17]:
def Eval(Field, NetName):
    with torch.no_grad(): 
        ## Network Load; 
        print('Load Pretrained Network')
        model_weights_path = NetName + '.pth'
        if 'xQSM' in NetName:
            Net = xQSM(2)
        elif 'Unet' in NetName:
            Net = Unet(2)
        else:
            sys.stderr.write('Network Type Invalid!\n')
        if torch.cuda.is_available():  ## if GPU is available; 
            Net = nn.DataParallel(Net) ## our network is trained with dataparallel wrapper;
            Net.load_state_dict(torch.load(model_weights_path))
            Net = Net.module
            device = torch.device("cuda:0")
            Net.to(device)
            Net.eval()  ## set the model to evaluation mode
            Field = Field.to(device)
        else:
            weights = torch.load(model_weights_path, map_location='cpu')
            new_state_dict = OrderedDict()
            print(new_state_dict)
            for k, v in weights.items():
                ## remove the first 7 charecters  "module." of the network weights 
                ## files to load the net into cpu, because our network is saved 
                ## as with dataparallel wrapper. 
                name = k[7:]  
                new_state_dict[name] = v
            Net.load_state_dict(new_state_dict)
            Net.eval()  ## set the model to evaluation mode
        ################ Evaluation ##################
        time_start = time.time()
        Recon = Net(Field)
        time_end = time.time()
        print('%f seconds elapsed!' % (time_end - time_start))
        Recon = torch.squeeze(Recon, 0)
        Recon = torch.squeeze(Recon, 0)
        Recon = Recon.to('cpu')  ## transfer to cpu for saving. 
        Recon = Recon.numpy()
    return Recon


markdown: Eval(Field, Netype, Env) retunrs the QSM reconstruction of the local field map (Field)
          using a designated Network (NetName); 

# Demonstration on a simulated COSMOS data;

In [19]:
with torch.no_grad(): 
    ## Data Load;        
    print('Data Loading')   
    Field, aff = Read_nii('field_input.nii')
    print('Loading Completed')
    mask = Field != 0 
    ## note the size of the field map input needs to be divisibel by the factor
    ## otherwise 0 padding should be done first
    print('ZeroPadding')
    imSize = np.shape(Field)
    if np.mod(imSize,  8).any():
        Field, pos = ZeroPadding(Field, 8)  # ZeroPadding
    Field = torch.from_numpy(Field) 
    ## The networks in pytorch only supports inputs that are a mini-batch of samples,
    ## and not a single sample. Therefore we need  to squeeze the 3D tensor to be 
    ## a 5D tesor for model evaluation.  
    Field = torch.unsqueeze(Field, 0)
    Field = torch.unsqueeze(Field, 0)
    Field = Field.float()
    ## QSM Reconstruction 
    print('Reconstruction')
    # Recon_xQSM_invivo = Eval(Field, 'xQSM_invivo')
    Recon_Unet_invivo = Eval(Field, 'Unet_invivo')
    #Recon_xQSM_syn = Eval(Field, 'xQSM_syn')
    #Recon_Unet_syn = Eval(Field, 'Unet_syn')
    if np.mod(imSize,  8).any():
        Recon_xQSM_invivo  = ZeroRemoving(Recon_xQSM_invivo , pos) # ZeroRemoving if zeropadding were performed; 
        Recon_Unet_invivo  = ZeroRemoving(Recon_Unet_invivo , pos) 
    Recon_xQSM_invivo = Recon_xQSM_invivo * mask
    Recon_Unet_invivo = Recon_Unet_invivo * mask
    ## calculate PSNR and SSIM
    label, aff = Read_nii('cosmos_label.nii')  # read label; 
    print('PSNR of xQSM_invivo is %f'% (psnr(Recon_xQSM_invivo, label)))
    print('PSNR of Unet_invivo is %f'% (psnr(Recon_Unet_invivo, label)))
    ## Saving Results (in .mat)
    print('saving reconstructions')
    path = './Chi_xQSM_invivo.mat' 
    Save_mat(Recon_xQSM_invivo, path)
    path = './Chi_Unet_invivo.mat' 
    Save_mat(Recon_Unet_invivo, path)
    #path = './Chi_xQSM_syn.mat' 
    #Save_mat(Recon_xQSM_syn, path)
    #path = './Chi_Unet_syn.mat' 
    #Save_mat(Recon_Unet_syn, path)
    ## or can be stored in .nii format; 
    path = 'Chi_xQSM_invivo.nii'
    Save_nii(Recon_xQSM_invivo, aff, path)
    path = 'Chi_Unet_invivo.nii'
    Save_nii(Recon_Unet_invivo, aff, path)
    #path = 'Chi_xQSM_syn.nii'
    #Save_nii(Recon_xQSM_syn, aff, path)
    #path = 'Chi_Unet_syn.nii'
    #Save_nii(Recon_xQSM_invivo, aff, path)

Data Loading
Loading Completed
ZeroPadding
Reconstruction
Load Pretrained Network
torch.Size([1, 1, 192, 256, 176])


RuntimeError: CUDA out of memory. Tried to allocate 2.06 GiB (GPU 0; 6.00 GiB total capacity; 3.25 GiB already allocated; 26.66 MiB free; 4.33 GiB reserved in total by PyTorch)