In [None]:
import torch, os, glob, gc
os.environ['VXM_BACKEND'] = 'pytorch'
os.environ['NEURITE_BACKEND'] = 'pytorch'
from torch.autograd import Variable
from scipy import ndimage
import enum
import torchvision
import math, random
import nibabel as nib
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from voxelmorph import voxelmorph
import numpy as np
import time
import lagomorph as lm
from lagomorph import adjrep 
from lagomorph import deform 
import SimpleITK as sitk
# from atlas_builder.atlas_trainer import main_atlas_train
from SADIR_forward import get_diffused_image, get_deformed_image

# Volumes are cubes with IMAGE_SIZE voxels per side; computation targets the GPU.
IMAGE_SIZE = 64
device = torch.device("cuda")

### Data Loader

In [None]:
# Load Data
class SADIRData(Dataset):
    """Dataset of NIfTI volumes paired with a fixed, binarized atlas.

    Each item contains a 2-channel "prior" array of shape
    (2, IMAGE_SIZE, IMAGE_SIZE, IMAGE_SIZE):
      * channel 0 -- a sparse copy of the volume in which only the slices
        listed in ``SLICE_INDICES`` (along axis 1) are kept; all other
        voxels are zero;
      * channel 1 -- the atlas, thresholded at 0.5 to {0, 1}.

    Args:
        path: root directory containing ``test_64/`` and/or ``train_64/``.
        test_flag: if truthy, read ``<path>/test_64/`` and return
            ``(prior, filename)``; otherwise read ``<path>/train_64/`` and
            return ``(prior, volume)`` where volume has a leading channel axis.
        atlas_path: NIfTI file holding the atlas (default ``./final_atlas.nii``).
    """

    # Slice indices along axis 1 copied into the sparse prior: 14 evenly spaced
    # indices over [0, 62], dropping the first three and the last one.
    SLICE_INDICES = [int(i) for i in np.linspace(0, 62, 14).astype(np.int16)[3:-1]]

    def __init__(self, path, test_flag=False, atlas_path='./final_atlas.nii'):
        self.test_flag = test_flag
        self.path = path + ("/test_64/" if test_flag else "/train_64/")
        # Skip hidden files (e.g. .DS_Store); sort so that the idx -> file
        # mapping is deterministic (os.listdir order is arbitrary).
        self.filenames = sorted(f for f in os.listdir(self.path) if not f.startswith('.'))
        self.atlas_path = atlas_path
        self._atlas = None  # binarized atlas, loaded on first __getitem__ and reused
        print("number of files in directory:", len(self.filenames))

    def __len__(self):
        return len(self.filenames)

    def _load_atlas(self):
        """Return the atlas binarized at 0.5, reading the file only once."""
        if self._atlas is None:
            atlas = np.array(nib.load(self.atlas_path).get_fdata())
            atlas[atlas < 0.5] = 0
            atlas[atlas >= 0.5] = 1
            self._atlas = atlas
        return self._atlas

    def __getitem__(self, idx):
        fname = self.filenames[idx]
        obj = np.array(nib.load(os.path.join(self.path, fname)).get_fdata())
        atlas = self._load_atlas()

        # Sparse observation: keep only the selected slices of the volume.
        img = np.zeros((IMAGE_SIZE, IMAGE_SIZE, IMAGE_SIZE))
        for idi in self.SLICE_INDICES:
            img[:, idi, :] = obj[:, idi, :]

        # Assignment copies the atlas, so the cached array is never mutated.
        mask_bf = np.zeros((2,) + img.shape)
        mask_bf[0] = img
        mask_bf[1] = atlas

        if self.test_flag:
            return torch.from_numpy(mask_bf).float(), fname
        obj = obj[np.newaxis, :, :, :]
        return torch.from_numpy(mask_bf).float(), torch.from_numpy(obj).float()
        
# Test-time loader. Predictions are saved per file name, so sample order does
# not matter; shuffling is disabled to keep runs reproducible.
args = {'data_dir': './',
        'batch_size': 2}

ds = SADIRData(args['data_dir'], test_flag=True)
datal = torch.utils.data.DataLoader(
    ds,
    batch_size=args['batch_size'],
    shuffle=False)
# len(ds) counts samples without iterating (and therefore loading) the whole dataset.
print("number of files: ", len(ds))
temp = torch.zeros(args['batch_size'], 3, IMAGE_SIZE, IMAGE_SIZE, IMAGE_SIZE, device='cuda')


### Testing

In [None]:
device = torch.device("cuda:0")
model = voxelmorph.torch.networks.VxmDense.load('./trained_models/0510.pt', device)
model.to(device)
# Inference only: put layers such as dropout / batch-norm (if the network has
# any) into evaluation mode.
model.eval()

In [None]:
# Iterative prediction: for each test volume, start from Gaussian noise and
# repeatedly feed (prior, current estimate) through the model; the predicted
# field m0_pred deforms the atlas channel into the next estimate. The final
# estimate is binarized and saved with the ground-truth file's affine/header.
gt_path = './test_64/'
out_dir = './predictions_/'
os.makedirs(out_dir, exist_ok=True)

# no_grad: without it, x_t carries autograd history from step to step and the
# graph (and GPU memory) grows with every iteration.
with torch.no_grad():
    for prior_batch, fname_batch in datal:
        # Process every sample in the batch, not only index 0.
        for prior_sample, fname in zip(prior_batch, fname_batch):
            prior = prior_sample.unsqueeze(0).to(device)  # (1, 2, D, H, W)
            atlas = prior[:, 1:2]                         # (1, 1, D, H, W); loop-invariant
            time_ = random.randint(10, 999)
            x_t = torch.randn_like(atlas)
            # NOTE(review): t_ runs upward from 0; confirm this matches the
            # timestep convention the model was trained with.
            for t_ in range(time_):
                # Keep both operands on the same device (prior was CPU before).
                inputs = torch.cat([prior, x_t.to(device)], dim=1)
                m0_pred = model(inputs, torch.tensor([t_], device=device))
                x_t = get_deformed_image(m0_pred, atlas).squeeze()[None, None]

            yim = nib.load(os.path.join(gt_path, fname))
            x0_pred_prc = x_t.cpu().numpy().squeeze()
            # Binarize at the midpoint between the prediction's min and max.
            k = (np.amax(x0_pred_prc) + np.amin(x0_pred_prc)) / 2
            x0_pred_prc[x0_pred_prc >= k] = 1
            x0_pred_prc[x0_pred_prc < k] = 0
            nib.save(nib.Nifti1Image(x0_pred_prc, yim.affine, yim.header),
                     os.path.join(out_dir, str(fname)))
