In [2]:
import os
import subprocess
import torch
from src.utils import *
from src.model_all import *
from scipy.ndimage import gaussian_filter
from nilearn.image import resample_img

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Resample the input volume to x_res x y_res x z_res voxels with mri_convert
# (`-vs` sets the voxel size, `-rt cubic` selects cubic interpolation).
input_file = '/nfs/masi/kanakap/projects/DeepN4/data/IXI015-HH-1258-T1.nii.gz'
output_file = '/nfs/masi/kanakap/projects/DeepN4/data/resampled_IXI015-HH-1258-T1.nii.gz'
x_res, y_res, z_res = 2, 2, 2
# Pass the arguments as a list so paths never go through shell quoting, and
# raise on a non-zero exit so later cells cannot silently read a stale or
# missing output_file.
subprocess.run(
    ['mri_convert', input_file, output_file,
     '-vs', str(x_res), str(y_res), str(z_res),
     '-rt', 'cubic'],
    check=True,
)

In [4]:
def load( subj ):
    """Read a NIfTI volume and convert it to a normalized network input tensor.

    Parameters
    ----------
    subj : str
        Path of the NIfTI file to read.

    Returns
    -------
    tuple
        ``(tensor, lx, lX, ly, lY, lz, lZ, rx, rX, ry, rY, rz, rZ, in_max)``.
        ``tensor`` is a float32 tensor of shape (1, 1, 128, 128, 128).
        The twelve bounds come from ``pad``. ``pred_model`` slices the 128^3
        grid with the ``l*`` bounds and the original-shaped grid with the
        ``r*`` bounds. ``in_max`` is the 99.99th percentile of the nonzero
        voxels, used for normalization.
    """
    vol = nib.load(subj).get_fdata()
    vol, (lx, lX, ly, lY, lz, lZ, rx, rX, ry, rY, rz, rZ) = pad(vol, 128)

    # Robust intensity maximum computed over nonzero voxels only.
    nonzero_vals = vol[np.nonzero(vol)]
    in_max = np.percentile(nonzero_vals, 99.99)
    vol = np.squeeze(normalize_img(vol, in_max, 0, 1, 0))

    # Prepend batch and channel axes by copying into a (1, 1, 128, 128, 128) buffer.
    batch = np.zeros((1, 1) + (128,) * 3)
    batch[0, 0] = vol

    tensor = torch.from_numpy(batch).float()
    return (tensor, lx, lX, ly, lY, lz, lZ, rx, rX, ry, rY, rz, rZ, in_max)

In [5]:
# Inference function
def pred_model( input_path, checkpoint_file, sigma=3 ):
    """Predict a multiplicative field with Synbo_UNet3D and divide it out of the input.

    Parameters
    ----------
    input_path : str
        Path of the NIfTI volume to correct (read through ``load``).
    checkpoint_file : str
        Checkpoint path passed to ``load_model`` for the Synbo_UNet3D model.
    sigma : float, optional
        Standard deviation, in voxels, of the Gaussian that smooths the
        predicted field before it is applied. Defaults to 3, the value that
        was previously hard-coded.

    Returns
    -------
    final_corrected : numpy.ndarray
        Corrected image on the grid of ``input_path``, mapped back with
        ``unnormalize_img(..., in_max, 0, 1, 0)``. Voxels where the smoothed
        field is zero are set to 0 instead of NaN.
    final_field : numpy.ndarray
        Smoothed field on the same grid.
    """
    use_cuda = torch.cuda.is_available()
    torch.manual_seed(1)
    device = torch.device("cuda" if use_cuda else "cpu")

    # Build the network and restore the checkpoint.
    model = Synbo_UNet3D(1, 1).to(device)
    model = load_model(model, checkpoint_file)
    model.eval()
    in_features, lx,lX,ly,lY,lz,lZ,rx,rX,ry,rY,rz,rZ, in_max = load(input_path)

    # Pure inference: disable autograd so no graph is kept for the 128^3 volume.
    with torch.no_grad():
        logfield = model(in_features.to(device), device)
        # The model output is exponentiated, i.e. treated as a log-field.
        field_ny = torch.exp(logfield).squeeze().cpu().numpy()
    input_ny = in_features.squeeze().numpy()

    # Map the 128^3 grid back onto an array with the original shape. Only the
    # header shape is needed here, so the voxel data is not read again.
    org_shape = nib.load(input_path).shape[:3]
    final_field = np.zeros(org_shape)
    final_field[rx:rX,ry:rY,rz:rZ] = field_ny[lx:lX,ly:lY,lz:lZ]

    final_input = np.zeros(org_shape)
    final_input[rx:rX,ry:rY,rz:rZ] = input_ny[lx:lX,ly:lY,lz:lZ]

    # Compute corrected image
    final_field = gaussian_filter(final_field, sigma=sigma)
    # Voxels outside the copied region keep a zero field (and zero input).
    # Divide only where the field is positive so they become 0 rather than
    # NaN from 0/0.
    final_corrected = np.divide(
        final_input, final_field,
        out=np.zeros_like(final_input), where=final_field > 0,
    )
    final_corrected = unnormalize_img(final_corrected, in_max, 0, 1, 0)

    return final_corrected, final_field

In [6]:
# Run inference on the resampled volume.
# The trained checkpoint can be downloaded from
# https://drive.google.com/drive/folders/1mdBsV0kHRRV_Alu1QJrTT7N0GGNJDuiu?usp=sharing
# TODO: add an option to choose the filter used to smooth the predicted field.
final_corrected, final_field = pred_model(
    output_file,
    checkpoint_file='/nfs/masi/kanakap/projects/DeepN4/src/trained_model_Synbo_UNet3D/checkpoint_epoch_264',
)

In [None]:
# Save the corrected image and the predicted field on the grid of the
# resampled volume they were computed on.
ref = nib.load(output_file)
# Derive output names from input_file so changing the input never mislabels results.
data_dir = os.path.dirname(input_file)
subject_file = os.path.basename(input_file)
for prefix, volume in (('corrected_', final_corrected), ('predicted_field_', final_field)):
    nii = nib.Nifti1Image(volume, affine=ref.affine, header=ref.header)
    # The reference header may declare an integer on-disk type; store the
    # floating-point results as float32 so they are not quantized.
    nii.set_data_dtype(np.float32)
    nib.save(nii, os.path.join(data_dir, prefix + subject_file))

In [None]:
# Resample the corrected image back to the original resolution and grid.
ref = nib.load(input_file)
output_img = resample_img(
    nib.Nifti1Image(final_corrected, nib.load(output_file).affine),
    target_affine=ref.affine,
    # resample_img takes a 3-D target shape; ignore any trailing axes of the reference.
    target_shape=ref.shape[:3],
)
# Derive the output name from input_file so changing the input never mislabels results.
nib.save(
    output_img,
    os.path.join(os.path.dirname(input_file), 'corrected_upsampled_' + os.path.basename(input_file)),
)