In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from pathlib import Path
from sklearn.model_selection import KFold

import multiprocessing
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
import warnings

from monai.networks import nets

import os
from glob import glob
from tqdm.notebook import tqdm
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import enum
from IPython import display
from tqdm.notebook import tqdm

from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.patches as mpatches

import natsort

from scipy import io
import glob

import nibabel as nib

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
def make_colorbar_with_padding(ax):
    """Return a new axis docked on the right-hand side of ``ax`` for a colorbar.

    The new axis is 5% of the parent axes' size and padded by 0.1 from it, so
    the colorbar lines up with the plotted image. Technique described here:
    http://chris35wills.github.io/matplotlib_axis/

    Args:
        ax: the matplotlib Axes the colorbar should sit next to.

    Returns:
        The newly appended Axes, intended to be passed as ``cax`` to ``colorbar``.
    """
    return make_axes_locatable(ax).append_axes("right", size="5%", pad=0.1)

In [None]:
# Fixed seed, actually applied to the RNGs so any stochastic step (e.g. the
# random weight initialisation done before checkpoints are loaded) is reproducible.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# Paths for this run -- edit them to point at your own data and checkpoints.
# data_root holds one sub-directory per subject, each with T1w/ and mask/ folders.
data_root = '/hdd/share/test_github_upload/data'

# Only keep subject directories (stray files such as .DS_Store would otherwise
# break the per-subject loop), in natural sort order (sub2 before sub10).
subjects_list = natsort.natsorted(
    entry for entry in os.listdir(data_root)
    if os.path.isdir(os.path.join(data_root, entry))
)

# Directory holding the MapG_T1.pth / MapG_M0.pth generator checkpoints.
state_dir = '/hdd/share/2023_New_Harmonization/src/PhyCHarm_share/CheckPoint/Quantitative_Maps_Generator'
# Predictions are written to {save_inference_dir}/{subject}/pred_T1.nii.gz and pred_M0.nii.gz.
save_inference_dir = '/hdd/share/test_github_upload/inference_QM'

In [None]:
class CustomDataset(Dataset):
    """Slice-wise dataset over one subject's T1w volume and mask.

    Item ``idx`` is the pair ``(T1w, mask)`` for slice ``z = idx`` along the
    last axis. For a 3-D volume each element has shape ``(1, X, Y)``; T1w is
    float32 and mask is uint8.

    Files are looked up as ``{data_root}/{sub_num}/T1w/*.nii.gz`` and
    ``{data_root}/{sub_num}/mask/*.nii.gz``; the first glob match is used.

    Args:
        data_root: directory containing one sub-directory per subject.
        sub_num: name of the subject's sub-directory.
        num_z_slice: number of slices to expose (normally the last dim of the volume).
        train: accepted for API compatibility; not used by this class.
    """

    def __init__(self, data_root, sub_num, num_z_slice, train):
        self.data_root = data_root
        self.sub_num = sub_num
        # One entry per slice index: item idx -> slice z = idx.
        self.dataset = list(range(num_z_slice))
        # Whole volumes, filled lazily on the first __getitem__ and reused for
        # every later slice instead of re-reading the full NIfTI file per slice.
        self._T1w_volume = None
        self._mask_volume = None

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _first_match(pattern):
        """Return the first path ``glob.glob`` yields for ``pattern``.

        Raises:
            FileNotFoundError: when nothing matches (clearer than a bare IndexError).
        """
        matches = glob.glob(pattern)
        if not matches:
            raise FileNotFoundError(f'No file matches {pattern!r}')
        return matches[0]

    def _volumes(self):
        """Return the cached (T1w float32, mask uint8) volumes, loading them on first use."""
        if self._T1w_volume is None:
            subject_dir = f'{self.data_root}/{self.sub_num}'
            T1w_path = self._first_match(f'{subject_dir}/T1w/*.nii.gz')
            mask_path = self._first_match(f'{subject_dir}/mask/*.nii.gz')
            T1w_volume = nib.load(T1w_path).get_fdata().astype(np.float32)
            mask_volume = nib.load(mask_path).get_fdata().astype(np.uint8)
            # Assign both together so a failed load never leaves a half-filled cache.
            self._T1w_volume, self._mask_volume = T1w_volume, mask_volume
        return self._T1w_volume, self._mask_volume

    def __getitem__(self, idx):
        z = self.dataset[idx]
        T1w_volume, mask_volume = self._volumes()

        # [None, ..., z] takes slice z of the last axis and prepends a channel axis.
        # .copy() detaches the slice from the cached volume so callers cannot mutate it.
        T1w = torch.from_numpy(T1w_volume[None, ..., z].copy())
        mask = torch.from_numpy(mask_volume[None, ..., z].copy())

        return T1w, mask

In [None]:
class Action(enum.Enum):
    TRAIN = True
    VALIDATE = False

def run_epoch(action, loader, generator_T1map, generator_M0map, xs, ys, zs):

    if action:
        generator_T1map.train()
    else:
        generator_T1map.eval()
        generator_M0map.eval()

    with torch.no_grad():
        volume_T1 = torch.zeros((xs,ys,zs)).to(device) #trained shape: xs=240, ys=240, zs= 224
        volume_M0 = torch.zeros((xs,ys,zs)).to(device)
        
        for batch_idx, batch in enumerate(tqdm(loader)):
            
            """generative network (mapping)"""
            # T1w to T1map
            print(batch_idx)

            input_T1w = batch[0].to(device)
            mask = batch[1].to(device)
            input_T1w = input_T1w/4095

            print(input_T1w.shape)

            """generative network (mapping)"""

            # T1w to T1map

            gen_t1map=generator_T1map(input_T1w) # T1w to T1map
            gen_m0map=generator_M0map(input_T1w) # T1w to M0map

            gen_t1map = mask*(gen_t1map-gen_t1map.min())*4095
            gen_m0map = mask*(gen_m0map-gen_m0map.min())*500

            volume_T1[:,:,batch_idx] = gen_t1map[0,0,:,:]          
            volume_M0[:,:,batch_idx] = gen_m0map[0,0,:,:]

        volume_T1=volume_T1.detach().cpu().numpy()
        volume_M0=volume_M0.detach().cpu().numpy()
        
        return volume_T1,volume_M0

In [None]:
def _load_generator(checkpoint_name, state_key):
    """Build the 2-D BasicUNet generator and load weights stored under ``state_key``."""
    generator = nets.BasicUnet(
        in_channels=1,
        out_channels=1,
        spatial_dims=2,
        features=(16, 32, 64, 128, 256, 32),
    ).to(device)
    # map_location lets checkpoints saved on a GPU load on the CPU fallback device.
    checkpoint = torch.load(f'{state_dir}/{checkpoint_name}', map_location=device)
    generator.load_state_dict(checkpoint[state_key])
    return generator


def _save_like(volume, reference_img, path):
    """Save ``volume`` as a float32 NIfTI reusing ``reference_img``'s affine and header."""
    header = reference_img.header.copy()
    # When a header is passed, nibabel keeps the header's on-disk dtype; force
    # float32 so the continuous predicted maps are not quantised to the input's dtype.
    header.set_data_dtype(np.float32)
    nib.save(nib.Nifti1Image(volume, reference_img.affine, header), path)


# The generators are loop-invariant: build and load them once for all subjects.
generator_T1map = _load_generator('MapG_T1.pth', 'T1map')
generator_M0map = _load_generator('MapG_M0.pth', 'M0map')

batch_size = 1

for sub in subjects_list:
    header_data = nib.load(glob.glob(f'{data_root}/{sub}/T1w/*.nii.gz')[0])
    # The shape comes from the header; no need to decode the voxel data here.
    [xs, ys, zs] = header_data.shape

    inference_dataset = CustomDataset(
        data_root=data_root, sub_num=sub, num_z_slice=zs, train=False,
    )
    inference_dl = DataLoader(inference_dataset, batch_size=batch_size, num_workers=0, shuffle=False)

    # Plain False flag (Action.VALIDATE's value) selects evaluation mode.
    volume_T1map, volume_M0map = run_epoch(
        Action.VALIDATE.value, inference_dl, generator_T1map, generator_M0map, xs, ys, zs,
    )

    # Save results next to each other under the subject's output folder.
    os.makedirs(f'{save_inference_dir}/{sub}', exist_ok=True)
    _save_like(volume_T1map, header_data, f'{save_inference_dir}/{sub}/pred_T1.nii.gz')
    _save_like(volume_M0map, header_data, f'{save_inference_dir}/{sub}/pred_M0.nii.gz')