#### Import Libraries

In [None]:
import torch, os, glob, gc
os.environ['VXM_BACKEND'] = 'pytorch'
os.environ['NEURITE_BACKEND'] = 'pytorch'
from torch.autograd import Variable
from scipy import ndimage
import enum
import torchvision
import math, random
import nibabel as nib
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from voxelmorph import voxelmorph
import numpy as np
import time
import lagomorph as lm
from lagomorph import adjrep 
from lagomorph import deform 
import SimpleITK as sitk
from atlas_builder.atlas_trainer import main_atlas_train
from SADIR_forward import get_diffused_image, get_deformed_image

# Edge length of the cubic volumes handled throughout (IMAGE_SIZE^3 voxels).
IMAGE_SIZE = 64
# Handle for the CUDA device.
device = torch.device("cuda")


#### Data Loader

In [None]:
# Train atlas builder to obtain initial atlas.
# NOTE(review): this presumably writes ./final_atlas.nii, which SADIRData reads
# in __getitem__ — confirm. The boolean flag's meaning is defined in
# atlas_builder.atlas_trainer (the training loop later calls it with True).
main_atlas_train(False)

# Load Data
class SADIRData(Dataset):
    """Dataset pairing sparsely sliced volumes with the current binary atlas.

    Files are read from ``path + "/train_64/"`` (or ``path + "/test_64/"`` when
    ``test_flag`` is truthy); hidden files (leading ``.``) are ignored.

    Each item is ``(sample, target)`` where ``sample`` is a float tensor of shape
    ``(2, IMAGE_SIZE, IMAGE_SIZE, IMAGE_SIZE)``:

    * channel 0 holds the input volume with all but a sparse set of slices
      (along axis 1) zeroed out;
    * channel 1 holds the atlas thresholded at 0.5 to {0, 1}.

    ``target`` is the filename in test mode, otherwise the full input volume as
    a float tensor with a leading channel axis.
    """

    # Atlas file re-read on every __getitem__ call (see comment there).
    ATLAS_PATH = './final_atlas.nii'
    # Indices along axis 1 kept in the sparse input: 14 evenly spaced positions
    # in [0, 62], minus the first three and the last -> ten slices.
    _SPARSE_SLICES = np.linspace(0, 62, 14).astype(np.int16)[3:-1]

    def __init__(self, path, test_flag=False):
        self.test_flag = test_flag
        # Use the same truthiness test as __getitem__ so directory choice and
        # return type can never disagree.
        self.path = path + ("/test_64/" if self.test_flag else "/train_64/")
        # Skip hidden files (e.g. .DS_Store); sort because os.listdir order is
        # arbitrary, which would make test-mode output order nondeterministic.
        self.filenames = sorted(
            name for name in os.listdir(self.path) if not name.startswith('.')
        )
        # Report the count the dataset will actually serve (hidden files excluded).
        print("number of files in directory:", len(self.filenames))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        obj = np.array(nib.load(os.path.join(self.path, filename)).get_fdata())

        # Deliberately not cached: the training loop periodically re-runs the
        # atlas builder, which presumably rewrites this file — confirm.
        atlas = np.array(nib.load(self.ATLAS_PATH).get_fdata())
        atlas[atlas < 0.5] = 0
        atlas[atlas >= 0.5] = 1

        # Sparse input: copy only the selected slices along axis 1.
        # assumes obj is (IMAGE_SIZE, >=63, IMAGE_SIZE) — TODO confirm
        img = np.zeros((IMAGE_SIZE, IMAGE_SIZE, IMAGE_SIZE))
        img[:, self._SPARSE_SLICES, :] = obj[:, self._SPARSE_SLICES, :]

        mask_bf = np.zeros((2,) + img.shape)
        mask_bf[0] = img
        mask_bf[1] = atlas
        sample = torch.from_numpy(mask_bf).float()

        if self.test_flag:
            return sample, filename
        return sample, torch.from_numpy(obj[np.newaxis, :, :, :]).float()
        
# Data location and loader configuration.
args = {
    'data_dir': './',
    'batch_size': 2,
}

ds = SADIRData(args['data_dir'], test_flag=False)
datal = torch.utils.data.DataLoader(
    ds,
    batch_size=args['batch_size'],
    shuffle=True,
)
data = iter(datal)
# len(DataLoader) is the number of batches and does not load any samples;
# the previous len(list(datal)) read every file just to count, and the value
# is a batch count, not a file count.
print("number of batches: ", len(datal))
# Zero-filled float32 buffer on the GPU (torch.zeros is contiguous by default);
# replaces the deprecated torch.cuda.FloatTensor(...).fill_(0) constructor.
temp = torch.zeros(
    args['batch_size'], 3, IMAGE_SIZE, IMAGE_SIZE, IMAGE_SIZE,
    dtype=torch.float32, device='cuda',
)


#### Model Training

In [None]:
# Set training parameters
epochs = 2000               # total number of epochs; epochs are numbered 1..epochs
atlas_retrain_every = 200   # re-run the atlas builder every N epochs
model_dir = './trained_models'

model = voxelmorph.torch.networks.VxmDense(
    inshape=(IMAGE_SIZE, IMAGE_SIZE, IMAGE_SIZE),
    nb_unet_features=32,
    nb_unet_levels=2,
    use_attention_unet=True,
)

# prepare the model for training and send to device
model.cuda()
model.train()

# set optimizer and loss terms
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
grad_loss_func = voxelmorph.torch.losses.Grad('l2', loss_mult=1)
image_loss_func = voxelmorph.torch.losses.Dice()
v0_loss_func = voxelmorph.torch.losses.MSE()

atlas_images = []
deepatlas_losses = []

losses = [v0_loss_func, image_loss_func, grad_loss_func]
weights = [1, 0.03, 0.01]

# training loops
for epoch in range(1, epochs + 1):
    # Periodically rebuild the atlas. This must test `epoch`: testing the
    # constant `epochs` (2000 % 200 == 0) retrained the atlas every epoch.
    if epoch % atlas_retrain_every == 0:
        main_atlas_train(True)

    epoch_loss = []
    epoch_total_loss = []
    epoch_step_time = []

    # Iterate the loader directly; len(list(datal)) loaded every sample once
    # per epoch just to count the batches.
    for prior, x0_true in datal:
        step_start_time = time.time()

        # DataLoader batches live on the CPU while the model output is on the
        # GPU; move them once so torch.cat below never mixes devices.
        prior_gpu = prior.cuda()
        x0_true_gpu = x0_true.cuda()
        # Atlas channel of the first sample as (1, 1, D, H, W); invariant
        # across the reverse steps below.
        atlas_vol = prior_gpu[0][1].unsqueeze(0).unsqueeze(0)

        time_ = random.randint(10, 999)
        # generate image after forward diffusion
        x_t = get_diffused_image(x0_true, torch.tensor([time_], device='cuda')).cuda()

        # NOTE(review): x_t is never detached, so autograd keeps the graph of
        # all `time_` model calls alive until backward; memory grows with time_.
        # NOTE(review): only the first sample's atlas is deformed and the result
        # is squeezed to one volume; with batch_size > 1 this likely does not
        # line up with the batched prior — confirm.
        for t_ in range(time_):
            inputs = torch.cat([prior_gpu, x_t], dim=1)
            # run inputs through the model to produce a momentum field
            m0_pred = model(inputs, torch.tensor([t_], device='cuda'))
            # deform the atlas with the predicted momentum; the result is the
            # current estimate of x0 and the next step's x_t
            x0_pred_prc = get_deformed_image(m0_pred, atlas_vol).squeeze()
            x0_pred_prc = x0_pred_prc.unsqueeze(0).unsqueeze(0)
            x_t = x0_pred_prc

        # voxelmorph losses take (y_true, y_pred). MSE and Dice compare the
        # reconstruction against ground truth. Grad ignores y_true and
        # regularizes y_pred, so it must receive the predicted momentum field;
        # passing the ground-truth image made that term a gradient-free constant.
        loss_args = [
            (x0_true_gpu, x0_pred_prc),  # v0_loss_func (MSE)
            (x0_true_gpu, x0_pred_prc),  # image_loss_func (Dice)
            (x0_true_gpu, m0_pred),      # grad_loss_func (smoothness)
        ]

        # calculate total loss
        loss = 0
        loss_list = []
        for loss_function, weight, (y_true, y_pred) in zip(losses, weights, loss_args):
            curr_loss = loss_function.loss(y_true, y_pred) * weight
            loss_list.append(curr_loss.item())
            loss += curr_loss

        epoch_loss.append(loss_list)
        epoch_total_loss.append(loss.item())

        # backpropagate and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # get compute time
        epoch_step_time.append(time.time() - step_start_time)

        del inputs, m0_pred, x0_pred_prc, x_t, x0_true_gpu, prior_gpu, loss_args
        torch.cuda.empty_cache()

    # print epoch info
    epoch_info = 'Epoch %d/%d' % (epoch, epochs)
    time_info = '%.4f sec/step' % np.mean(epoch_step_time)
    losses_info = ', '.join(['%.4e' % f for f in np.mean(epoch_loss, axis=0)])
    loss_info = 'loss: %.4e  (%s)' % (np.mean(epoch_total_loss), losses_info)
    print(' - '.join((epoch_info, time_info, loss_info)), flush=True)

# final model save (create the output directory if it does not exist yet)
os.makedirs(model_dir, exist_ok=True)
model.save(os.path.join(model_dir, '%04d.pt' % epochs))