In [None]:
import os
import nibabel as nib
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
%matplotlib inline

# Use a white background for matplotlib figures
matplotlib.rcParams['figure.facecolor'] = '#ffffff'

# Make CUDA float32 the default tensor type, so later factory calls that rely on the default
# (torch.rand, torch.zeros, torch.tensor, nn layer parameters, ...) allocate on the GPU.
# NOTE(review): this requires a GPU; `set_default_tensor_type` is deprecated in recent PyTorch in
# favour of `torch.set_default_dtype` + `torch.set_default_device` -- check the installed version.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Seed torch's global RNG, which the random initialisations of F, z^s and h_0 below draw from.
random_seed = 2022
torch.manual_seed(random_seed)

# Load data

In [None]:
def remove_norm(lst):
    '''Keep only the "main" image names from a directory listing.

    The listing is assumed to contain, for every scan, one main image file
    (e.g. "mainimage_*") and one "norm_fmri_img_*" file. Because the main
    names sort before "norm_fmri_img", the main images are exactly the first
    half of the sorted listing.

    Parameters
    ----------
    lst : list of str
        File names, e.g. the result of ``os.listdir``. It is NOT modified.

    Returns
    -------
    list of str or None
        The first half of the sorted names, or ``None`` (after printing a
        warning) when ``lst`` has an odd number of elements.
    '''
    # sorted() returns a new list, so the caller's list is left untouched.
    names = sorted(lst)
    n_names = len(names)
    if n_names % 2:
        print('The list has an odd number of elements!')
        return None
    take = n_names // 2
    print(f'The length of the original list is {n_names}, and the length of the new list is {take}')
    return names[:take]
    
# Every group directory holds both main and "norm_fmri_img" files; list each directory
# and keep only the main images with `remove_norm`.

### CN ###
dir_CN = '/blue/li.duan/zheng.yu/Research1/CN'
CN_img_names = remove_norm(os.listdir(dir_CN))  # Only keep main_images

### AD ###
dir_AD = '/blue/li.duan/zheng.yu/Research1/AD'
AD_img_names = remove_norm(os.listdir(dir_AD))  # Only keep main_images

### LMCI ###
# (disabled) dir_LMCI = '/blue/li.duan/zheng.yu/Research1/LMCI'
# (disabled) LMCI_img_names = remove_norm(os.listdir(dir_LMCI))  # Only keep main_images

### CN

In [None]:
# Load data in CN group.
# Keep every other main image starting at `CN_first_index`: 0 -> indices 0, 2, ..., 36;
# set it to 1 to use indices 1, 3, ..., 37 instead. Every image -- including the first
# one -- is truncated to its first 120 time points so all of them can be stacked.
CN_first_index = 0
CN_volumes = []
for i in range(CN_first_index, 38, 2):
    name = CN_img_names[i]
    # Truncate in NumPy before copying into torch, so only the kept frames are materialised.
    volume = torch.tensor(nib.load(os.path.join(dir_CN, name)).get_fdata()[..., :120])
    CN_volumes.append(volume.unsqueeze(0).unsqueeze(0))  # 1 x 1 x X x Y x Z x T
# Concatenate once instead of re-copying an ever-growing tensor on every iteration.
imgs_data = torch.cat(CN_volumes, dim=0)  # N x 1 x X x Y x Z x T
del CN_volumes
print(imgs_data.shape)

### AD

In [None]:
# Load data in AD group.
# Keep every other main image starting at `AD_first_index`: 0 -> indices 0, 2, ..., 24;
# set it to 1 to use indices 1, 3, ..., 25 instead. Every image -- including the first
# one -- is truncated to its first 120 time points so all of them can be stacked.
AD_first_index = 0
AD_volumes = []
for i in range(AD_first_index, 26, 2):
    name = AD_img_names[i]
    # Truncate in NumPy before copying into torch, so only the kept frames are materialised.
    volume = torch.tensor(nib.load(os.path.join(dir_AD, name)).get_fdata()[..., :120])
    AD_volumes.append(volume.unsqueeze(0).unsqueeze(0))  # 1 x 1 x X x Y x Z x T
# Concatenate once instead of re-copying an ever-growing tensor on every iteration.
imgs_data = torch.cat(AD_volumes, dim=0)  # N x 1 x X x Y x Z x T
del AD_volumes
print(imgs_data.shape)

### LMCI

In [None]:
# Load data in LMCI group (all main images, not every other one).
# The LMCI listing in the file-listing cell is commented out, so define it here to keep
# this cell runnable on a fresh kernel.
dir_LMCI = '/blue/li.duan/zheng.yu/Research1/LMCI'
LMCI_img_names = remove_norm(os.listdir(dir_LMCI))  # Only keep main_images
LMCI_volumes = []
for name in LMCI_img_names:
    # Every image -- including the first one -- is truncated to 120 time points so they stack.
    volume = torch.tensor(nib.load(os.path.join(dir_LMCI, name)).get_fdata()[..., :120])
    LMCI_volumes.append(volume.unsqueeze(0).unsqueeze(0))  # 1 x 1 x X x Y x Z x T
# Concatenate once instead of re-copying an ever-growing tensor on every iteration.
imgs_data = torch.cat(LMCI_volumes, dim=0)  # N x 1 x X x Y x Z x T
del LMCI_volumes
print(imgs_data.shape)

### CN + AD

In [None]:
# Load data in CN and AD groups together (CN subjects first, then AD subjects).
# From each group keep every other main image (indices 0, 2, 4, ...), truncate every
# image -- including the very first one -- to its first 120 time points, and record the
# file names in the same order as the rows of `imgs_data`.
names = []
volumes = []
for group, group_dir, group_img_names in (('CN', dir_CN, CN_img_names),
                                          ('AD', dir_AD, AD_img_names)):
    for i in range(len(group_img_names) // 2):
        name = group_img_names[2 * i]
        names.append(name)
        # Truncate in NumPy before copying into torch, so only the kept frames are materialised.
        volume = torch.tensor(nib.load(os.path.join(group_dir, name)).get_fdata()[..., :120])
        volumes.append(volume.unsqueeze(0).unsqueeze(0))  # 1 x 1 x X x Y x Z x T
    print(f'{group} finished!')
# Concatenate once instead of re-copying an ever-growing tensor on every iteration.
imgs_data = torch.cat(volumes, dim=0)  # N x 1 x X x Y x Z x T
del volumes
print(imgs_data.shape)
print(names)

# Some helper functions

In [None]:
# Data normalization
def Data_Normalization(imgs_data):
    '''Map every (subject, time point) volume linearly onto [-1, 1], in place.

    Parameters
    ----------
    imgs_data : torch.Tensor
        train_size x 1 x X x Y x Z x time. Modified in place.

    Returns
    -------
    imgs_data : torch.Tensor
        The same tensor object, scaled so that for every subject ``i`` and time
        ``t`` the volume ``imgs_data[i, ..., t]`` spans [-1, 1].
    max_values, min_values : torch.Tensor
        train_size x time per-subject, per-time maxima / minima used for the
        scaling (needed to map reconstructions back to the original scale).

    A constant volume (max == min) is mapped to -1 instead of NaN.
    '''
    max_values = torch.amax(imgs_data, dim=(1, 2, 3, 4))  # train_size x time
    min_values = torch.amin(imgs_data, dim=(1, 2, 3, 4))  # train_size x time
    value_range = max_values - min_values
    # A constant volume has numerator 0 too; dividing by 1 avoids 0/0 = NaN.
    value_range = value_range.masked_fill(value_range == 0, 1)
    # train_size x time -> train_size x 1 x 1 x 1 x 1 x time, to broadcast over each volume.
    stat_shape = (imgs_data.shape[0],) + (1,) * (imgs_data.dim() - 2) + (imgs_data.shape[-1],)
    # Same operation order as 2 * ((x - min) / (max - min) - 0.5), done in place to avoid
    # allocating a second full-size copy of the data.
    imgs_data.sub_(min_values.view(stat_shape)).div_(value_range.view(stat_shape))
    imgs_data.sub_(0.5).mul_(2)
    return imgs_data, max_values, min_values


# Data masking
def mask_img2(img, mask):
    '''Zero out, in place, every voxel of ``img`` where ``mask`` equals 0.

    The same mask is applied to each slice along the last (time) axis, and
    masked entries are overwritten -- so NaNs there are removed as well.
    Returns ``img`` itself.
    '''
    zero_where = mask == 0  # identical for every time point
    for t in range(img.shape[-1]):
        img[..., t][zero_where] = 0
    return img

def mask_img_list(lst, mask):
    '''Apply ``mask_img2`` with one shared ``mask`` to every image along dim 0 of ``lst``.

    ``lst`` is a tensor whose first dimension indexes images; it is updated in
    place and also returned.
    '''
    n_imgs = lst.shape[0]
    for idx in range(n_imgs):
        current = lst[idx]
        lst[idx] = mask_img2(current, mask)
    return lst


# # Replacing nan with 0
# def replace_nan_with_0(img):
#     '''replace nan entries of an image with 0'''
#     img[torch.isnan(img)] = 0
#     return img

# def replace_nan_with_0_forlist(lst):
#     '''apply replace_nan_with_0 to a list of images'''
#     for i in range(len(lst)):
#         lst[i] = replace_nan_with_0(lst[i])
#     return lst


# Truncate time
def truncate_time(imgs_data, max_time=120):
    '''Keep only the first ``max_time`` time points (the last dimension).

    Parameters
    ----------
    imgs_data : torch.Tensor or list/tuple of torch.Tensor
        Image stack(s) whose last axis is time.
    max_time : int
        Number of leading time points to keep.

    Returns
    -------
    torch.Tensor or list of torch.Tensor
        A view (no copy) with last dimension ``min(T, max_time)``; a list is
        returned element-wise for list/tuple input.
    '''
    # Slicing the last axis once is all that is needed; assigning row by row cannot
    # shrink a tensor's time dimension.
    if isinstance(imgs_data, (list, tuple)):
        return [item[..., :max_time] for item in imgs_data]
    return imgs_data[..., :max_time]



# Moving a list of tensors from GPU to CPU
def to_cpu(lst):
    '''Move a tensor -- or a (possibly nested) list/tuple of tensors -- to the CPU.

    Lists and tuples come back as lists, mirroring ``to_device``.
    '''
    if isinstance(lst, (list, tuple)):
        return [to_cpu(item) for item in lst]
    return lst.cpu()


# Helper functions for using GPU
def get_default_device():
    """Return the device to run on: CUDA when a GPU is visible, otherwise the CPU."""
    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_type)

device = get_default_device()

def to_device(data, device):
    """Move ``data`` to ``device``, casting every tensor to float32.

    Lists and tuples are handled recursively and always come back as lists.
    Every tensor leaf is converted with ``dtype=torch.float`` -- so integer
    tensors such as ID batches become float too -- using a non-blocking copy.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True, dtype=torch.float)
    return [to_device(element, device) for element in data]

class DeviceDataLoader():
    """Iterable wrapper that transfers each batch of ``dl`` to ``device`` on the fly."""

    def __init__(self, dl, device):
        self.dl = dl          # underlying loader (iterable, supports len())
        self.device = device  # target passed to to_device for every batch

    def __iter__(self):
        """Generate the wrapped loader's batches, each already moved by ``to_device``."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Batch count of the wrapped loader."""
        return len(self.dl)

# Some hyperparameters

In [None]:
# Number of subjects in whichever group(s) were loaded above (rows of imgs_data).
train_size = imgs_data.shape[0]
batch_size = 4
latent_dim = 10
# z^s is added element-wise to h inside RecVAEModel.VAE, so z_dim must equal latent_dim.
z_dim = latent_dim
print(f'train size:{train_size}, batch size:{batch_size}, latent dimension:{latent_dim}, z dimension:{z_dim}')

In [None]:
# NOTE(review): RecVAEModel stores sig_x / sig_h as `sqr_sig_x` / `sqr_sig_h` and divides the
# losses by them directly, while the evaluation cell divides by sig_x**2 / sig_h**2; the two agree
# only because both are 1. -- confirm which is intended before changing these values.
sig_z = 1.  # std of the initial subject-specific offsets z^s
sig_h = 1.
sig_x = 1.
rho = 0.1  # ridge penalty for the closed-form F update (also reported as loss_F)
# Weight of the L1 penalty on z^s.
lambda_z = 10.#e-1

In [None]:
# Number of time points per scan processed by the model (forward slices x[..., t] for
# t < tol_time); must not exceed the 120 frames kept by the loading cells.
tol_time = 120

# IDs & labels for images

In [None]:
# Row index of every subject in imgs_data / train_ds; used to look up each subject's z^s.
IDs = torch.arange(train_size, dtype=torch.long)
IDs

In [None]:
# Group label per subject in the row order of the CN + AD loading cell: 0 = CN, 1 = AD.
# Only needed when using a group-specific F but sharing the other parameters.
# The counts follow that cell's rule (every other main image, len // 2 per group) instead of
# hard-coding 19 and 13, so they cannot silently drift from the data.
n_CN = len(CN_img_names) // 2
n_AD = len(AD_img_names) // 2
Labels = torch.cat((torch.zeros(n_CN, dtype=torch.long), torch.ones(n_AD, dtype=torch.long)))
Labels

# Training pipeline

In [None]:
def training_pipeline(imgs_data,
                      data_normalization=True, 
                      include_IDs=True, 
                      masking=False,
                      log_transform=False, epsilon=1e-6,
                      replace_nan_with_0=False,
                      move_to_cpu=False,
                      trunc_time=False,
                      max_time=120):
    '''Preprocess the stacked images and package them as a training dataset.

    Parameters
    ----------
    imgs_data : torch.Tensor
        N x 1 x X x Y x Z x T image stack. The NaN-replacement, masking and
        normalization steps modify it in place.
    data_normalization : bool
        Scale each (subject, time point) volume to [-1, 1] with ``Data_Normalization``.
    include_IDs : bool
        Return a list of ``[image, ID]`` pairs (IDs from the global ``IDs``)
        instead of the bare tensor.
    masking : bool
        Zero voxels where ``mask.nii.gz`` (read from the working directory) is 0.
    log_transform : bool
        Replace the data by ``log(data + epsilon)``.
    epsilon : float
        Offset used by ``log_transform``.
    replace_nan_with_0 : bool
        Set NaN entries to 0.
    move_to_cpu : bool
        Move the result to CPU memory.
    trunc_time : bool
        Keep only the first ``max_time`` time points.
    max_time : int
        Number of time points kept when ``trunc_time`` is True.

    Returns
    -------
    (dataset, max_values, min_values)
        ``max_values`` / ``min_values`` are the N x T statistics from
        ``Data_Normalization``, or ``None`` when normalization is disabled.
    '''
    # Stay None when normalization is skipped, so the return statement always works.
    max_values, min_values = None, None
    if trunc_time:
        imgs_data = imgs_data[..., :max_time]
    if replace_nan_with_0:
        # In place, like the other preprocessing steps, to avoid a second full-size copy.
        imgs_data.masked_fill_(torch.isnan(imgs_data), 0)
    if masking:
        mask = torch.tensor(nib.load('mask.nii.gz').get_fdata()).unsqueeze(0)
        imgs_data = mask_img_list(imgs_data, mask)
    if log_transform:
        # Keep a tensor (not a list) so the later steps using .shape / torch.amax still work.
        imgs_data = torch.log(imgs_data + epsilon)
    if data_normalization:
        imgs_data, max_values, min_values = Data_Normalization(imgs_data)
    if move_to_cpu:
        imgs_data = imgs_data.cpu()
    if include_IDs:
        imgs_data = [[imgs_data[i], IDs[i]] for i in range(imgs_data.shape[0])]
    return imgs_data, max_values, min_values

In [None]:
# Preprocess the loaded images (normalization + [image, ID] pairs by default) and keep the
# per-subject, per-time max/min so reconstructions can be mapped back to the original scale.
train_ds, max_value, min_value = training_pipeline(imgs_data)
# NOTE(review): the images in train_ds are indexed views of the (normalized) imgs_data, so
# deleting the name alone does not release the underlying GPU storage.
del imgs_data
torch.cuda.empty_cache()

# Data loader

In [None]:
print(device) # Make sure we are using GPU

# Shuffled mini-batches for training.
# NOTE(review): the CUDA generator is presumably required because the default tensor type is
# CUDA; it is created separately from the global RNG seeded by torch.manual_seed, so consider
# seeding it from random_seed if the shuffling order must be reproducible -- confirm.
train_loader = DataLoader(train_ds, batch_size, shuffle=True, generator=torch.Generator(device='cuda'))
# Every batch (images and IDs alike) is moved to `device` and cast to float by to_device.
train_loader = DeviceDataLoader(train_loader, device)

# One subject at a time, in dataset order, for evaluation / reconstruction.
test_loader = DataLoader(train_ds, 1, shuffle=False)
test_loader = DeviceDataLoader(test_loader, device)

# Initializing $F$, $z^s$ and $h_0$

In [None]:
# F^c (c=0: CN, c=1: AD)
# num_group = 2
# F_mat = torch.rand(num_group, latent_dim, latent_dim)
# Latent transition matrix (RecVAEModel.g_transform computes h F^T); random start, then
# re-estimated in closed form after every training epoch. The order of the random draws in
# this cell (F, then z^s, then h_0) determines their values under the fixed seed.
F_mat = torch.rand(latent_dim, latent_dim)
F_mat_original = F_mat.clone().detach()
print(f'Shape of F:{F_mat.shape}, device:{F_mat.device}')

# z^s: one trainable offset per subject, drawn from N(0, sig_z^2) and optimised by fit.
z_vectors = torch.normal(mean=torch.zeros((train_size, z_dim)), std=sig_z).clone().detach().requires_grad_(True)
z_vectors_original = z_vectors.clone().detach()
#z_vectors = sig_z * torch.randn(train_size, z_dim, requires_grad=True)
print(f'Shape of z:{z_vectors.shape}, device:{z_vectors.device}')

# h_0: initial latent state, shared by all subjects (expanded to the batch size in fit).
h0 = torch.rand(1, latent_dim)
#h0 = sig_h * torch.randn(1, latent_dim)
print(f'Shape of h_0:{h0.shape}, device:{h0.device}')
print(f'h0:{h0}')

In [None]:
# Load in F from a saved file
F_mat = torch.load(os.path.join('Recorded', 'F_mat_AD.pt'))

# z^s
z_vectors = torch.normal(mean=torch.zeros((train_size, z_dim)), std=sig_z).clone().detach().requires_grad_(True)
z_vectors_original = z_vectors.clone().detach()
#z_vectors = sig_z * torch.randn(train_size, z_dim, requires_grad=True)
print(f'Shape of z:{z_vectors.shape}, device:{z_vectors.device}')

# h_0
h0 = torch.rand(1, latent_dim)
#h0 = sig_h * torch.randn(1, latent_dim)
print(f'Shape of h_0:{h0.shape}, device:{h0.device}')
print(f'h0:{h0}')

In [None]:
# For reproducibility
# Overwrite the randomly drawn h_0 with a fixed 1 x latent_dim value so runs are comparable.
h0 = torch.tensor([[0.1719, 0.1631, 0.8127, 0.3485, 0.5798, 0.9957, 0.4949, 0.0442, 0.1098, 0.7062]])
h0

# Create the model

In [None]:
# class RecVAEModel(nn.Module):
#     def __init__(self, enc_out_dim=100, latent_dim=latent_dim, sqr_sig_x=sig_x, sqr_sig_h=sig_h, 
#                  lambda_z=lambda_z, F_mat=F_mat):
#         super().__init__()
        
#         self.sqr_sig_x = sqr_sig_x
#         self.sqr_sig_h = sqr_sig_h
#         self.latent_dim = latent_dim
#         self.lambda_z = lambda_z
#         self.F_mat = F_mat
        
#         # Encoder: from input(x) to one of the inputs of the hidden layer (enc_x)
#         # input: 1 x 91 x 109 x 91
#         self.encoder1 = nn.Sequential(
#             nn.Conv3d(1, 4, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm3d(4),
#             nn.LeakyReLU(0.2, inplace = True)) # output: 4 x 45 x 54 x 45
        
#         self.encoder2 = nn.Sequential(
#             nn.Conv3d(4, 8, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm3d(8),
#             nn.LeakyReLU(0.2, inplace = True)) # output: 8 x 22 x 27 x 22
        
#         self.encoder3 = nn.Sequential(
#             nn.Conv3d(8, 16, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm3d(16),
#             nn.LeakyReLU(0.2, inplace = True)) # output: 16 x 11 x 13 x 11
        
#         self.encoder4 = nn.Sequential(
#             nn.Conv3d(16, 32, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm3d(32),
#             nn.LeakyReLU(0.2, inplace = True)) # output: 32 x 5 x 6 x 5
        
#         self.encoder5 = nn.Sequential(
#             nn.Flatten(), 
#             nn.Linear(32*5*6*5, enc_out_dim),
#             nn.Tanh()) # output: enc_out_dim
        
        
#         # Hidden: from (enc_x,h_{t-1}) to h_t
#         self.hidden2mu = nn.Linear(enc_out_dim+latent_dim, latent_dim)
#         self.hidden2log_var = nn.Linear(enc_out_dim+latent_dim, latent_dim)
        
#         # Decoder: from h_t to mu_t
#         self.decoder1 = nn.Sequential(
#             nn.Linear(latent_dim, 32*5*6*5),
#             nn.Unflatten(1, (32, 5, 6, 5)),
#             nn.BatchNorm3d(32),
#             nn.LeakyReLU(0.2, inplace = True))
          
#         self.decoder2 = nn.Sequential(
#             nn.ConvTranspose3d(32, 16, kernel_size=4, stride=2, padding=1, output_padding=1, bias=False), 
#             nn.BatchNorm3d(16),
#             nn.LeakyReLU(0.2, inplace = True))
        
#         self.decoder3 = nn.Sequential(
#             nn.ConvTranspose3d(16, 8, kernel_size=4, stride=2, padding=1, output_padding=(0,1,0), bias=False), 
#             nn.BatchNorm3d(8),
#             nn.LeakyReLU(0.2, inplace = True))
        
#         self.decoder4 = nn.Sequential(
#             nn.ConvTranspose3d(8, 4, kernel_size=4, stride=2, padding=1, output_padding=(1,0,1), bias=False), 
#             nn.BatchNorm3d(4),
#             nn.LeakyReLU(0.2, inplace = True))
        
#         self.decoder5 = nn.Sequential(
#             nn.ConvTranspose3d(4, 1, kernel_size=4, stride=2, padding=1, output_padding=1, bias=False), 
#             nn.Tanh())
        
#         # F
#         # self.F = nn.Linear(latent_dim, latent_dim, bias=False)
#         # self.F_AD = nn.Linear(latent_dim, latent_dim, bias=False)
#         # self.F_CN = nn.Linear(latent_dim, latent_dim, bias=False)
                   
        
       
#     def g_transform(self, h_old, batch_labels):
#         # h_old shape: batch_size(cunrrently) x latent_dim
#         # batch_labels: 0 -> CN, 1 -> AD
#         # h_new = torch.matmul(h_old, F_mat)
#         h_old = h_old.unsqueeze(1)
#         h_new = torch.zeros(len(batch_labels), latent_dim)
#         for i in range(len(batch_labels)):
#             h_new[i] = torch.matmul(self.F_mat[batch_labels[i]], h_old[i].transpose(0,1)).transpose(0,1) # 1 x latent_dim
#         return h_new
    
    
#     def encode(self, x):
#         enc_x = self.encoder1(x)
#         enc_x = self.encoder2(enc_x)
#         enc_x = self.encoder3(enc_x)
#         enc_x = self.encoder4(enc_x)
#         enc_x = self.encoder5(enc_x)
#         return enc_x
    
    
#     def decode(self, h):
#         dec_h = self.decoder1(h)
#         dec_h = self.decoder2(dec_h)
#         dec_h = self.decoder3(dec_h)
#         dec_h = self.decoder4(dec_h)
#         dec_h = self.decoder5(dec_h)
#         return dec_h
        
    
#     def reparametrize(self, mu_h,log_var_h):
#         # Reparametrization Trick to allow gradients to backpropagate from the stochastic part of the model
#         sigma_h = torch.exp(log_var_h / 2)
#         z = torch.randn(size = (mu_h.size(0),mu_h.size(1)))
#         # z = z.type_as(mu_h) # Setting z to be .cuda when using GPU training
#         return mu_h + sigma_h*z
   
  
#     def updating_F(self, h_history_history, h0, rho):
#         # h_history_history: train_size x tol_time x latent_dim
#         rhoI = 2*sig_h**2*rho * torch.eye(latent_dim)
        
#         for c in range(num_group):
#             h_his_his_for_single_group = h_history_history[Labels==c]
#             Y_tilde = h_his_his_for_single_group.view(-1, latent_dim) # output: train_size*tol_time x latent_dim
#             for i in range(torch.sum(Labels==c)):
#                 new_h_history = torch.cat((h0, h_his_his_for_single_group[i][0:(tol_time-1)]), 0)
#                 h_his_his_for_single_group[i] = new_h_history
#             X_tilde = h_his_his_for_single_group.view(-1, latent_dim) # output: train_size*tol_time x latent_dim
            
#             # updating F
#             XX_tilde = torch.matmul(X_tilde.transpose(0,1), X_tilde)
#             XY_tilde = torch.matmul(X_tilde.transpose(0,1), Y_tilde)
#             self.F_mat[c] = torch.linalg.solve(XX_tilde+rhoI, XY_tilde).transpose(0,1)
        
    
#     def VAE(self, x, h, which_ones):
#         # encode x and h to get the mu and variance parameters for the latent space
#         enc_x = self.encode(x)
#         combined = torch.cat((enc_x, h), 1)
#         mu_h, log_var_h = self.hidden2mu(combined), self.hidden2log_var(combined)
        
#         # sample h
#         h = self.reparametrize(mu_h, log_var_h)
        
#         # add z^s
#         h_tilde = h + z_vectors[which_ones,:]
        
#         # decode
#         mu = self.decode(h_tilde)
#         return mu, h
    
    
#     def training_step(self, batch, h_0, which_ones, batch_labels):
#         '''h_0: batch_size x latent_dim'''
#         x_list, mu_history, h_history, gh_history = self(batch, h_0, which_ones, batch_labels)
        
#         temp = 2 * batch.size(0) * len(h_history)
#         # calculate loss
#         loss1 = sum([torch.sum(torch.pow(x-mu, 2)) for x, mu in zip(x_list, mu_history)])
#         loss1 = loss1 / self.sqr_sig_x / temp
        
#         loss2 = sum([torch.sum(torch.pow(h-gh, 2)) for h, gh in zip(h_history, gh_history)])
#         loss2 = loss2 / self.sqr_sig_h / temp
        
#         loss_F = rho * torch.sum(torch.pow(self.F_mat, 2))
#         #loss_F = rho * loss_F / temp
        
#         #loss_z = torch.sum(torch.pow(z_vectors, 2)) ### L2->L1
#         loss_z = torch.sum(torch.abs(z_vectors))
#         loss_z = self.lambda_z * loss_z
#         #loss3 = F.l1_loss(z_vectors, torch.zeros_like(z_vectors), reduction='sum')
        
#         loss = loss1 + loss2 + loss_z
        
#         return loss, {'loss1':loss1, 'loss2':loss2, 'loss_F':loss_F, 'loss_z':loss_z}, h_history
    
    
#     def forward(self, x, h_0, which_ones, batch_labels):
#         # which_ones is a list containing the IDs of all subjects in the current batch
#         #tol_time = x.size(-1) # x is of size batch_size*channel*x1*x2*x3*tol_time
#         x_list = [x[:,:,:,:,:,t] for t in range(tol_time)]
#         del x
#         h = h_0
#         h_history = []
#         gh_history = []
#         mu_history = []
#         for t in range(tol_time):
#             gh_history.append(self.g_transform(h, batch_labels))
#             #print(f't={t}:{torch.max(self.g_transform(h, batch_labels))}')
#             mu, h = self.VAE(x_list[t], h, which_ones)
#             #print(f't={t}:{torch.max(h)}')
#             h_history.append(h)
#             mu_history.append(mu)
#         return x_list, mu_history, h_history, gh_history
    

# model = RecVAEModel()

In [None]:
class RecVAEModel(nn.Module):
    '''Recurrent VAE for 3-D + time images with linear latent dynamics.

    At every time step t the frame x_t and the previous latent state h_{t-1}
    are encoded into a Gaussian over h_t. A sample of h_t plus the
    subject-specific offset z^s (global ``z_vectors``) is decoded into the
    reconstruction mu_t. The latent transition is g(h) = h F^T. ``F_mat`` is a
    plain tensor attribute, not a trainable parameter: it is re-estimated in
    closed form by ``updating_F``.

    Parameters
    ----------
    enc_out_dim : int
        Size of the encoder's output feature vector.
    latent_dim : int
        Size of h_t.
    sqr_sig_x, sqr_sig_h : float
        Divisors of the reconstruction and transition losses.
    lambda_z : float
        Weight of the L1 penalty on ``z_vectors``.
    F_mat : torch.Tensor
        latent_dim x latent_dim initial transition matrix (stored by reference).

    NOTE(review): the methods also read the notebook globals ``z_vectors``,
    ``rho``, ``sig_h`` and ``tol_time``.
    '''
    def __init__(self, enc_out_dim=100, latent_dim=latent_dim, sqr_sig_x=sig_x, sqr_sig_h=sig_h, 
                 lambda_z=lambda_z, F_mat=F_mat):
        super().__init__()
        
        self.sqr_sig_x = sqr_sig_x
        self.sqr_sig_h = sqr_sig_h
        self.latent_dim = latent_dim
        self.lambda_z = lambda_z
        self.F_mat = F_mat
        
        # Encoder: from input(x) to one of the inputs of the hidden layer (enc_x)
        # input: 1 x 91 x 109 x 91
        self.encoder1 = nn.Sequential(
            nn.Conv3d(1, 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(4),
            nn.LeakyReLU(0.2, inplace = True)) # output: 4 x 45 x 54 x 45
        
        self.encoder2 = nn.Sequential(
            nn.Conv3d(4, 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(8),
            nn.LeakyReLU(0.2, inplace = True)) # output: 8 x 22 x 27 x 22
        
        self.encoder3 = nn.Sequential(
            nn.Conv3d(8, 16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(16),
            nn.LeakyReLU(0.2, inplace = True)) # output: 16 x 11 x 13 x 11
        
        self.encoder4 = nn.Sequential(
            nn.Conv3d(16, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(32),
            nn.LeakyReLU(0.2, inplace = True)) # output: 32 x 5 x 6 x 5
        
        self.encoder5 = nn.Sequential(
            nn.Flatten(), 
            nn.Linear(32*5*6*5, enc_out_dim),
            nn.Tanh()) # output: enc_out_dim
        
        
        # Hidden: from (enc_x,h_{t-1}) to h_t
        self.hidden2mu = nn.Linear(enc_out_dim+latent_dim, latent_dim)
        self.hidden2log_var = nn.Linear(enc_out_dim+latent_dim, latent_dim)
        
        # Decoder: from h_t to mu_t
        self.decoder1 = nn.Sequential(
            nn.Linear(latent_dim, 32*5*6*5),
            nn.Unflatten(1, (32, 5, 6, 5)),
            nn.BatchNorm3d(32),
            nn.LeakyReLU(0.2, inplace = True))
          
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose3d(32, 16, kernel_size=4, stride=2, padding=1, output_padding=1, bias=False), 
            nn.BatchNorm3d(16),
            nn.LeakyReLU(0.2, inplace = True))
        
        self.decoder3 = nn.Sequential(
            nn.ConvTranspose3d(16, 8, kernel_size=4, stride=2, padding=1, output_padding=(0,1,0), bias=False), 
            nn.BatchNorm3d(8),
            nn.LeakyReLU(0.2, inplace = True))
        
        self.decoder4 = nn.Sequential(
            nn.ConvTranspose3d(8, 4, kernel_size=4, stride=2, padding=1, output_padding=(1,0,1), bias=False), 
            nn.BatchNorm3d(4),
            nn.LeakyReLU(0.2, inplace = True))
        
        self.decoder5 = nn.Sequential(
            nn.ConvTranspose3d(4, 1, kernel_size=4, stride=2, padding=1, output_padding=1, bias=False), 
            nn.Tanh())
        
        # F
        # self.F = nn.Linear(latent_dim, latent_dim, bias=False)
        # self.F_AD = nn.Linear(latent_dim, latent_dim, bias=False)
        # self.F_CN = nn.Linear(latent_dim, latent_dim, bias=False)
       
    def g_transform(self, h_old):
        '''Latent transition g(h) = h F^T for h_old of shape batch_size x latent_dim.'''
        return torch.matmul(h_old, self.F_mat.transpose(0,1))
    
    
    def encode(self, x):
        '''Encode a batch of frames (batch x 1 x X x Y x Z) to batch x enc_out_dim features.'''
        enc_x = self.encoder1(x)
        enc_x = self.encoder2(enc_x)
        enc_x = self.encoder3(enc_x)
        enc_x = self.encoder4(enc_x)
        enc_x = self.encoder5(enc_x)
        return enc_x
    
    
    def decode(self, h):
        '''Decode batch x latent_dim states into batch x 1 x X x Y x Z frames in [-1, 1] (Tanh).'''
        dec_h = self.decoder1(h)
        dec_h = self.decoder2(dec_h)
        dec_h = self.decoder3(dec_h)
        dec_h = self.decoder4(dec_h)
        dec_h = self.decoder5(dec_h)
        return dec_h
        
    
    def reparametrize(self, mu_h,log_var_h):
        '''Sample h = mu + sigma * eps, eps ~ N(0, I) (reparametrization trick).

        The noise is drawn with ``randn_like`` so it always matches mu's device
        and dtype, instead of depending on the global default tensor type.
        '''
        sigma_h = torch.exp(log_var_h / 2)
        z = torch.randn_like(mu_h)
        return mu_h + sigma_h*z
   
  
    def updating_F(self, h_history_history, h0, rho):
        '''Re-estimate F in closed form from the stored latent trajectories.

        Solves the ridge regression h_t ~ F h_{t-1} (with h_{-1} = h0) over all
        subjects and time points:
            F^T = (X^T X + 2 sig_h^2 rho I)^{-1} X^T Y,
        where the rows of Y are every h_t and the rows of X the matching h_{t-1}.

        Parameters
        ----------
        h_history_history : torch.Tensor
            train_size x tol_time x latent_dim trajectories. Not modified.
        h0 : torch.Tensor
            1 x latent_dim initial state placed in front of every trajectory.
        rho : float
            Ridge penalty.
        '''
        n_subjects = h_history_history.size(0)
        # Targets: h_t for every subject and time point.
        Y_tilde = h_history_history.reshape(-1, self.latent_dim) # output: train_size*tol_time x latent_dim
        # Inputs: each trajectory shifted one step to the right with h0 in front. This is a new
        # tensor: shifting the argument in place would also change Y_tilde (which can share its
        # storage) and the caller's history.
        h_prev = torch.cat((h0.expand(n_subjects, 1, -1), h_history_history[:, :-1, :]), dim=1)
        X_tilde = h_prev.reshape(-1, self.latent_dim) # output: train_size*tol_time x latent_dim

        # updating F
        rhoI = 2*sig_h**2*rho * torch.eye(self.latent_dim, dtype=X_tilde.dtype, device=X_tilde.device)
        XX_tilde = torch.matmul(X_tilde.transpose(0,1), X_tilde)
        XY_tilde = torch.matmul(X_tilde.transpose(0,1), Y_tilde)
        self.F_mat = torch.linalg.solve(XX_tilde+rhoI, XY_tilde).transpose(0,1)
        
    
    def VAE(self, x, h, which_ones):
        '''One recurrent step: (x_t, h_{t-1}) -> (mu_t, h_t).

        ``which_ones`` are the subject IDs of the batch, used to add each
        subject's offset z^s before decoding; the returned h excludes z^s.
        '''
        # encode x and h to get the mu and variance parameters for the latent space
        enc_x = self.encode(x)
        combined = torch.cat((enc_x, h), 1)
        mu_h, log_var_h = self.hidden2mu(combined), self.hidden2log_var(combined)
        
        # sample h
        h = self.reparametrize(mu_h, log_var_h)
        
        # add z^s
        h_tilde = h + z_vectors[which_ones,:]
        
        # decode
        mu = self.decode(h_tilde)
        return mu, h
    
    
    def training_step(self, batch, h_0, which_ones):
        '''Run the model on a batch and compute its loss.

        h_0: batch_size x latent_dim.
        Returns (loss, dict of the loss terms, h_history). The optimised loss is
        loss1 + loss2 + loss_z; loss_F is only reported, because F is fitted in
        closed form by updating_F rather than by gradient descent.
        '''
        x_list, mu_history, h_history, gh_history = self(batch, h_0, which_ones)
        
        temp = 2 * batch.size(0) * len(h_history)
        # Reconstruction term.
        loss1 = sum([torch.sum(torch.pow(x-mu, 2)) for x, mu in zip(x_list, mu_history)])
        loss1 = loss1 / self.sqr_sig_x / temp
        
        # Transition term: h_t versus g(h_{t-1}).
        loss2 = sum([torch.sum(torch.pow(h-gh, 2)) for h, gh in zip(h_history, gh_history)])
        loss2 = loss2 / self.sqr_sig_h / temp
        
        loss_F = rho * torch.sum(torch.pow(self.F_mat, 2))
        #loss_F = rho * loss_F / temp
        
        # L1 penalty on all subject offsets.
        loss_z = torch.sum(torch.abs(z_vectors))
        loss_z = self.lambda_z * loss_z
        
        loss = loss1 + loss2 + loss_z
        
        return loss, {'loss1':loss1, 'loss2':loss2, 'loss_F':loss_F, 'loss_z':loss_z}, h_history
    
    
    def forward(self, x, h_0, which_ones):
        '''Unroll the model over the first tol_time frames.

        x: batch_size x channel x X x Y x Z x time; h_0: batch_size x latent_dim;
        which_ones: subject IDs of the batch.
        Returns (x_list, mu_history, h_history, gh_history), one entry per time step.
        '''
        x_list = [x[:,:,:,:,:,t] for t in range(tol_time)]
        del x
        h = h_0
        h_history = []
        gh_history = []
        mu_history = []
        for t in range(tol_time):
            # Prediction of h_t from h_{t-1}, compared with the sampled h_t in loss2.
            gh_history.append(self.g_transform(h))
            mu, h = self.VAE(x_list[t], h, which_ones)
            h_history.append(h)
            mu_history.append(mu)
        return x_list, mu_history, h_history, gh_history

In [None]:
# Model for the currently loaded group, starting from the F_mat defined above.
# NOTE(review): the model stores a reference to F_mat (no copy) until updating_F replaces it.
model_CN = RecVAEModel(F_mat=F_mat)
#model_AD = RecVAEModel(F_mat=F_mat)

In [None]:
# def fit(epochs, lr, h0, model, rho=rho, train_loader=train_loader, opt_func=torch.optim.SGD):
#     train_loss_history = []
#     #optimizer = opt_func(model.parameters(), lr)
#     optimizer = opt_func([{'params': model.parameters()}, {'params': z_vectors}], lr)
#     h_history_history = torch.zeros(train_size, tol_time, latent_dim)
#     for epoch in range(epochs):
#         # Training Phase 
#         model.train()
#         for batch, batch_index in train_loader:
#             # print(f'z: {z_vectors}')
#             h_0 = h0.expand(batch.size(0), -1) # h0: 1 x latent_dim, h_0: batch_size x latent_dim
#             which_ones = IDs[batch_index.to(int)]
#             batch_labels = Labels[batch_index.to(int)]
        
#             loss, loss_dic, h_history = model.training_step(batch, h_0, which_ones, batch_labels)
        
#             # Updating h_history_history
#             h_history_history[which_ones] = torch.stack(h_history).transpose(0,1) # output: batch_size x tol_time x latent_dim
            
#             loss.backward()
#             optimizer.step()
#             optimizer.zero_grad()
            
#         if True:#epoch%10==0:
#             train_loss_history.append(loss)
#             print("Epoch [{}]: train loss: {:.2f} with loss1: {:.2f}, loss2: {:.2f}, loss_z: {:.4f}, loss_F: {:.4f}"
#                   .format(epoch, loss, loss_dic['loss1'], loss_dic['loss2'], loss_dic['loss_z'], loss_dic['loss_F']))
        
#         # Updating F
#         with torch.no_grad():
#             model.updating_F(h_history_history, h0, rho)
#     return {'train_loss_history': train_loss_history, 'all_h_history': h_history_history, 'last_index': which_ones}

In [None]:
def fit(epochs, lr, h0, model, rho=rho, train_loader=train_loader, opt_func=torch.optim.SGD):
    '''Train ``model`` and the subject offsets ``z_vectors``; refit F after every epoch.

    Each epoch runs one optimizer step per mini-batch and stores every
    subject's latent trajectory. It then re-estimates the transition matrix
    in closed form with ``model.updating_F``.

    Parameters
    ----------
    epochs : int
        Number of passes over ``train_loader``.
    lr : float
        Learning rate passed to ``opt_func``.
    h0 : torch.Tensor
        1 x latent_dim initial latent state shared by all subjects.
    model : RecVAEModel
        Model to train.
    rho : float
        Ridge penalty for the F update.
    train_loader : iterable of (batch, batch_index)
        Mini-batches of images and subject indices.
    opt_func : type
        Optimizer class, called with the parameter groups and ``lr``.

    Returns
    -------
    dict
        'train_loss_history': last-batch loss of every epoch (detached tensors).
        'all_h_history': train_size x tol_time x latent_dim trajectories from the last epoch.
        'last_index': subject IDs of the final batch.
    '''
    train_loss_history = []
    optimizer = opt_func([{'params': model.parameters()}, {'params': z_vectors}], lr)
    h_history_history = torch.zeros(train_size, tol_time, latent_dim)
    for epoch in range(epochs):
        # Training Phase 
        model.train()
        for batch, batch_index in train_loader:
            h_0 = h0.expand(batch.size(0), -1) # h0: 1 x latent_dim, h_0: batch_size x latent_dim
            which_ones = IDs[batch_index.to(int)]
        
            loss, loss_dic, h_history = model.training_step(batch, h_0, which_ones)
        
            # Store this batch's trajectories (batch_size x tol_time x latent_dim). Detached, so the
            # stored history does not chain every batch's autograd graph together.
            h_history_history[which_ones] = torch.stack(h_history).transpose(0,1).detach()
            
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
        # Log every epoch. Keep only the value: a graph-carrying tensor keeps autograd nodes alive.
        train_loss_history.append(loss.detach())
        print("Epoch [{}]: train loss: {:.2f} with loss1: {:.2f}, loss2: {:.2f}, loss_z: {:.4f}, loss_F: {:.4f}"
              .format(epoch, loss, loss_dic['loss1'], loss_dic['loss2'], loss_dic['loss_z'], loss_dic['loss_F']))
        
        # Updating F. A copy is passed so the trajectories returned below stay the real h_t
        # values even if updating_F modifies its argument.
        with torch.no_grad():
            model.updating_F(h_history_history.clone(), h0, rho)
    return {'train_loss_history': train_loss_history, 'all_h_history': h_history_history, 'last_index': which_ones}

In [None]:
# Train the CN model: 500 epochs with the default optimizer (SGD), learning rate 1e-6.
history = fit(epochs=500, lr=1e-6, h0=h0, model=model_CN)

In [None]:
# Train the AD model (overwrites `history` from the CN run).
# NOTE(review): on a fresh kernel `model_AD` is not defined here -- its construction above is
# commented out and it is only loaded from disk in a later cell. Create it first (e.g.
# model_AD = RecVAEModel(F_mat=F_mat)) after loading the AD data.
history = fit(epochs=500, lr=1e-6, h0=h0, model=model_AD)

# Save and load model

In [None]:
# Persist the whole trained CN model (pickled module) and, separately, its estimated F.
# NOTE(review): F_mat is a plain tensor attribute, not a parameter/buffer, so it is not part of
# model_CN.state_dict(); saving it on its own lets it be reloaded without the model.
torch.save(model_CN, os.path.join('Recorded', 'model_CN_half.pt'))
torch.save(model_CN.F_mat, os.path.join('Recorded', 'F_mat_CN.pt'))

In [None]:
# Reload a previously saved AD model (the whole pickled module).
# NOTE(review): torch.load unpickles arbitrary objects -- only load files you produced yourself.
model_AD = torch.load(os.path.join('Recorded', 'model_AD_half.pt'))
# model_CN.eval() 
'''Remember that you must call model.eval() to set dropout and batch normalization layers 
to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.
** We will call .eval() in the function named evaluate.'''

# Save images

In [None]:
t_tocompare = 60 # the timestamp on which we are to do comparison

def save_image(mu_history, name, which_one, t=t_tocompare, out_dir='Generated'):
    '''Undo the [-1, 1] normalization of one reconstructed frame and save it as NIfTI.

    Parameters
    ----------
    mu_history : list of torch.Tensor
        Per-time reconstructions from the model, each batch x 1 x X x Y x Z.
    name : str
        Output file name inside ``out_dir``.
    which_one : int or 0-dim tensor
        Subject row used to look up the global ``max_value`` / ``min_value``
        statistics returned by ``training_pipeline``.
    t : int
        Time point to save.
    out_dir : str
        Output directory; created if it does not exist.

    The first batch element is saved with its channel axis kept (1 x X x Y x Z),
    which is the layout ``get_plot`` expects for generated images.
    NOTE(review): the affine is the identity, so the original scan's
    voxel-to-world mapping is not preserved -- TODO confirm this is acceptable.
    '''
    maxim = max_value[which_one, t].detach().cpu().numpy()
    minim = min_value[which_one, t].detach().cpu().numpy()
    one_mu = mu_history[t].detach().cpu().numpy()[0]  # first batch element: 1 x X x Y x Z
    # Inverse of Data_Normalization, which computed 2 * ((x - min) / (max - min) - 0.5).
    one_mu = (one_mu * .5 + .5) * (maxim - minim) + minim
    os.makedirs(out_dir, exist_ok=True)
    img_new = nib.Nifti1Image(one_mu, np.eye(4))
    nib.save(img_new, os.path.join(out_dir, name))

In [None]:
def get_name(file_position, data_dir = './Generated'):
    '''Return the ``file_position``-th file name in ``data_dir``, in sorted order.

    ``os.listdir`` returns entries in arbitrary order, so the listing is sorted
    to make a given position refer to the same file on every call.
    '''
    return sorted(os.listdir(data_dir))[file_position]


def show_slices(slices):
    """Display a row of 2-D image slices side by side in grayscale.

    Each slice is transposed and drawn with ``origin="lower"``. ``squeeze=False``
    keeps ``axes`` a 2-D array, so a single slice works as well.
    """
    fig, axes = plt.subplots(1, len(slices), squeeze=False)
    for ax, image_slice in zip(axes[0], slices):
        ax.imshow(image_slice.T, cmap="gray", origin="lower")


def get_plot(file_name, data_dir = './Generated', plot_name = 'Slices', x1=40, x2=40, x3=40, t=t_tocompare):
    '''Show three orthogonal slices through voxel (x1, x2, x3) of a NIfTI file.

    Files read from './Dataset' are sliced as X x Y x Z x time at time ``t``.
    Any other directory is treated as holding generated images whose first
    axis is the channel (1 x X x Y x Z); ``t`` is then unused.
    '''
    img_data = nib.load(os.path.join(data_dir, file_name)).get_fdata()
    if data_dir != './Dataset':
        planes = [img_data[0, x1, :, :], img_data[0, :, x2, :], img_data[0, :, :, x3]]
    else:
        planes = [img_data[x1, :, :, t], img_data[:, x2, :, t], img_data[:, :, x3, t]]
    show_slices(planes)
    plt.suptitle(plot_name)
        
def get_plot_from_blue(img_name, file_dir='CN', plot_name='Slices',
                       x1=40, x2=40, x3=40, t=60):
    '''Plot images from blue: three orthogonal slices of an original scan.

    Loads ``/blue/li.duan/zheng.yu/Research1/<file_dir>/<img_name>``, prints
    its shape, and shows slices through voxel (x1, x2, x3) at time ``t``.
    '''
    img_path = os.path.join('/blue/li.duan/zheng.yu/Research1/', file_dir, img_name)
    img_data = nib.load(img_path).get_fdata()
    print(f'Shape:{img_data.shape}')
    show_slices([
        img_data[x1, :, :, t],
        img_data[:, x2, :, t],
        img_data[:, :, x3, t],
    ])
    plt.suptitle(plot_name)

In [None]:
@torch.no_grad()
def evaluate(model, x, h0, which_ones):
    '''Run ``model`` in eval mode with autograd disabled and return its four outputs.

    Returns (x_list, mu_history, h_history, gh_history), as produced by
    ``model(x, h0, which_ones)``.
    '''
    model.eval()  # e.g. BatchNorm layers switch to their running statistics
    frames, recon, latents, predicted = model(x, h0, which_ones)
    return frames, recon, latents, predicted

In [None]:
import json
def save_dict_to_json(dict, name):
    '''Write ``dict`` as JSON to ``Recorded/<name>.json``, creating ``Recorded`` if needed.

    The parameter keeps its original name ``dict`` for keyword callers, even
    though it shadows the builtin inside this function.
    '''
    # Serialize first, so a non-serializable value fails before any file is created.
    j_dict = json.dumps(dict)
    os.makedirs('Recorded', exist_ok=True)
    # `with` closes the file even if the write fails.
    with open(os.path.join('Recorded', name + ".json"), "w") as f:
        f.write(j_dict)

In [None]:
# Side-by-side check: the reconstruction of one CN scan, then the original scan it came from.
get_plot(file_name='Reconstructedmainimage_002_S_4225_2012-10-11.nii.gz.nii.gz',
         plot_name='Reconstructed')
get_plot_from_blue(img_name='mainimage_002_S_4225_2012-10-11.nii.gz',
                   file_dir='CN')

# Evaluation

In [None]:
# Replace z_vectors with zeros for evaluation, keeping its column count.
# Row count: 38 for CN; use the commented line (26 rows) for AD.
z_vectors = torch.zeros((38, z_vectors.shape[1]))  # CN
#z_vectors = torch.zeros((26, z_vectors.shape[1])) # AD

In [None]:
# Evaluate model_CN on the batches of `test_loader`; save reconstructions named after CN images.
# NOTE(review): all four evaluation cells iterate the same `test_loader` / `IDs`.
# Make sure they hold the CN test split when this cell runs.
losses_F_CN_for_CN = {'loss1':[],'loss2':[], 'tol_loss':[]}
# Penalty rho * ||F||^2. `.item()` stores a plain Python float, so the dict is always
# JSON-serialisable (a NumPy float32 result from `.numpy()` is not).
losses_F_CN_for_CN['loss_F'] = rho * torch.sum(torch.pow(model_CN.F_mat, 2)).item()
#losses_F_CN_for_CN['loss_z'] = lambda_z * torch.sum(torch.abs(z_vectors)).item()
for test_batch, test_batch_index in test_loader:
    test_batch = test_batch.to(torch.float32)
    h_0 = h0.expand(test_batch.size(0), -1) # h0: 1 x latent_dim, h_0: batch_size x latent_dim
    which_ones = IDs[test_batch_index.to(int)]
    test_x_list, test_mu_history, test_h_history, test_gh_history = evaluate(model_CN, test_batch, h_0, which_ones)

    # Per-batch losses, both normalised by 2 * batch_size * tol_time:
    # loss1: squared error between x and mu, divided by sig_x^2.
    # loss2: squared error between h and gh, divided by sig_h^2.
    temp = 2 * test_batch.size(0) * tol_time
    loss1 = sum([torch.sum(torch.pow(x-mu, 2)) for x, mu in zip(test_x_list, test_mu_history)])
    loss1 = loss1 / (sig_x**2) / temp

    loss2 = sum([torch.sum(torch.pow(h-gh, 2)) for h, gh in zip(test_h_history, test_gh_history)])
    loss2 = loss2 / (sig_h**2) / temp

    losses_F_CN_for_CN['loss1'].append(loss1.item())
    losses_F_CN_for_CN['loss2'].append(loss2.item())
    losses_F_CN_for_CN['tol_loss'].append((loss1+loss2+losses_F_CN_for_CN['loss_F']).item())

    # Save the reconstruction at the default time index. Indexing the name list with
    # `which_ones` presumably requires a single-element index (batch size 1).
    save_image(mu_history=test_mu_history,
               name='CN_test_Reconstructed_by_model_CN_half' + CN_img_names[which_ones],
               which_one=which_ones)

save_dict_to_json(losses_F_CN_for_CN, 'losses_F_CN_for_CN_test')
    

In [None]:
# Evaluate model_CN on the batches of `test_loader`; save reconstructions named after AD images.
# NOTE(review): this cell iterates the same `test_loader` / `IDs` as the CN-for-CN cell.
# They must hold the AD test split when this runs, otherwise the file names are mislabeled.
losses_F_CN_for_AD = {'loss1':[],'loss2':[], 'tol_loss':[]}
# Penalty rho * ||F||^2 as a plain Python float (always JSON-serialisable).
losses_F_CN_for_AD['loss_F'] = rho * torch.sum(torch.pow(model_CN.F_mat, 2)).item()
for test_batch, test_batch_index in test_loader:
    test_batch = test_batch.to(torch.float32)
    h_0 = h0.expand(test_batch.size(0), -1) # h0: 1 x latent_dim, h_0: batch_size x latent_dim
    which_ones = IDs[test_batch_index.to(int)]
    test_x_list, test_mu_history, test_h_history, test_gh_history = evaluate(model_CN, test_batch, h_0, which_ones)

    # Per-batch losses, both normalised by 2 * batch_size * tol_time:
    # loss1: squared error between x and mu, divided by sig_x^2.
    # loss2: squared error between h and gh, divided by sig_h^2.
    temp = 2 * test_batch.size(0) * tol_time
    loss1 = sum([torch.sum(torch.pow(x-mu, 2)) for x, mu in zip(test_x_list, test_mu_history)])
    loss1 = loss1 / (sig_x**2) / temp

    loss2 = sum([torch.sum(torch.pow(h-gh, 2)) for h, gh in zip(test_h_history, test_gh_history)])
    loss2 = loss2 / (sig_h**2) / temp

    losses_F_CN_for_AD['loss1'].append(loss1.item())
    losses_F_CN_for_AD['loss2'].append(loss2.item())
    losses_F_CN_for_AD['tol_loss'].append((loss1+loss2+losses_F_CN_for_AD['loss_F']).item())

    # Save the reconstruction at the default time index. Indexing the name list with
    # `which_ones` presumably requires a single-element index (batch size 1).
    save_image(mu_history=test_mu_history,
               name='AD_test_Reconstructed_by_model_CN_half' + AD_img_names[which_ones],
               which_one=which_ones)

save_dict_to_json(losses_F_CN_for_AD, 'losses_F_CN_for_AD_test')

In [None]:
# Evaluate model_AD on the batches of `test_loader`; save reconstructions named after AD images.
# NOTE(review): `test_loader` / `IDs` must hold the AD test split when this cell runs.
losses_F_AD_for_AD = {'loss1':[],'loss2':[], 'tol_loss':[]}
# Penalty rho * ||F||^2 as a plain Python float (always JSON-serialisable).
losses_F_AD_for_AD['loss_F'] = rho * torch.sum(torch.pow(model_AD.F_mat, 2)).item()
for test_batch, test_batch_index in test_loader:
    test_batch = test_batch.to(torch.float32)
    h_0 = h0.expand(test_batch.size(0), -1) # h0: 1 x latent_dim, h_0: batch_size x latent_dim
    which_ones = IDs[test_batch_index.to(int)]
    test_x_list, test_mu_history, test_h_history, test_gh_history = evaluate(model_AD, test_batch, h_0, which_ones)

    # Per-batch losses, both normalised by 2 * batch_size * tol_time:
    # loss1: squared error between x and mu, divided by sig_x^2.
    # loss2: squared error between h and gh, divided by sig_h^2.
    temp = 2 * test_batch.size(0) * tol_time
    loss1 = sum([torch.sum(torch.pow(x-mu, 2)) for x, mu in zip(test_x_list, test_mu_history)])
    loss1 = loss1 / (sig_x**2) / temp

    loss2 = sum([torch.sum(torch.pow(h-gh, 2)) for h, gh in zip(test_h_history, test_gh_history)])
    loss2 = loss2 / (sig_h**2) / temp

    losses_F_AD_for_AD['loss1'].append(loss1.item())
    losses_F_AD_for_AD['loss2'].append(loss2.item())
    losses_F_AD_for_AD['tol_loss'].append((loss1+loss2+losses_F_AD_for_AD['loss_F']).item())

    # Save the reconstruction at the default time index. Indexing the name list with
    # `which_ones` presumably requires a single-element index (batch size 1).
    save_image(mu_history=test_mu_history,
               name='AD_test_Reconstructed_by_model_AD_half' + AD_img_names[which_ones],
               which_one=which_ones)

save_dict_to_json(losses_F_AD_for_AD, 'losses_F_AD_for_AD_test')

In [None]:
# Evaluate model_AD on the batches of `test_loader`; save reconstructions named after CN images.
# NOTE(review): this cell iterates the same `test_loader` / `IDs` as the AD-for-AD cell.
# They must hold the CN test split when this runs, otherwise the file names are mislabeled.
losses_F_AD_for_CN = {'loss1':[],'loss2':[], 'tol_loss':[]}
# Penalty rho * ||F||^2 as a plain Python float (always JSON-serialisable).
losses_F_AD_for_CN['loss_F'] = rho * torch.sum(torch.pow(model_AD.F_mat, 2)).item()
for test_batch, test_batch_index in test_loader:
    test_batch = test_batch.to(torch.float32)
    h_0 = h0.expand(test_batch.size(0), -1) # h0: 1 x latent_dim, h_0: batch_size x latent_dim
    which_ones = IDs[test_batch_index.to(int)]
    test_x_list, test_mu_history, test_h_history, test_gh_history = evaluate(model_AD, test_batch, h_0, which_ones)

    # Per-batch losses, both normalised by 2 * batch_size * tol_time:
    # loss1: squared error between x and mu, divided by sig_x^2.
    # loss2: squared error between h and gh, divided by sig_h^2.
    temp = 2 * test_batch.size(0) * tol_time
    loss1 = sum([torch.sum(torch.pow(x-mu, 2)) for x, mu in zip(test_x_list, test_mu_history)])
    loss1 = loss1 / (sig_x**2) / temp

    loss2 = sum([torch.sum(torch.pow(h-gh, 2)) for h, gh in zip(test_h_history, test_gh_history)])
    loss2 = loss2 / (sig_h**2) / temp

    losses_F_AD_for_CN['loss1'].append(loss1.item())
    losses_F_AD_for_CN['loss2'].append(loss2.item())
    losses_F_AD_for_CN['tol_loss'].append((loss1+loss2+losses_F_AD_for_CN['loss_F']).item())

    # Save the reconstruction at the default time index. Indexing the name list with
    # `which_ones` presumably requires a single-element index (batch size 1).
    save_image(mu_history=test_mu_history,
               name='CN_test_Reconstructed_by_model_AD_half' + CN_img_names[which_ones],
               which_one=which_ones)

save_dict_to_json(losses_F_AD_for_CN, 'losses_F_AD_for_CN_test')

# Extra diagnostics: learned z vectors, F matrix and latent-state comparisons

In [None]:
# Sum-of-absolute-differences (L1) comparisons of z and F with their `_original` copies.
# NOTE(review): this cell reads `F_mat`, which the next cell (re)assigns, and `z_vectors`,
# which the Evaluation section overwrote with zeros. Confirm which values are in scope here.
d_z = F.l1_loss(z_vectors, z_vectors_original, reduction='sum')
z_norm = F.l1_loss(z_vectors, torch.zeros_like(z_vectors), reduction='sum')
z_0_norm = F.l1_loss(z_vectors_original, torch.zeros_like(z_vectors_original), reduction='sum')
d_F = F.l1_loss(F_mat, F_mat_original, reduction='sum')
F_norm = F.l1_loss(F_mat, torch.zeros_like(F_mat), reduction='sum')
F_0_norm = F.l1_loss(F_mat_original, torch.zeros_like(F_mat_original), reduction='sum')
print(f'The l1 difference between z and z_0 is: {d_z}',
      f'The l1 norm of z is: {z_norm}',
      f'The l1 norm of z_0 is: {z_0_norm}',
      f'The l1 difference between F and F_0 is: {d_F}',
      f'The l1 norm of F is: {F_norm}',
      f'The l1 norm of F_0 is: {F_0_norm}',
      sep='\n')

In [None]:
# Pairwise L1 distances between rows of the learned F weight matrix.
F_mat = model.F.weight
with torch.no_grad():  # diagnostics only: do not build an autograd graph through F
    # Show at most 10 rows, but never more than F_mat has (avoids IndexError).
    num_ = min(10, F_mat.shape[0])
    diff_F = torch.zeros(num_, num_)
    for i in range(diff_F.shape[0]):
        for j in range(diff_F.shape[1]):
            diff_F[i,j]=F.l1_loss(F_mat[i], F_mat[j], reduction='sum')
    print(diff_F)
    # Rows (2i, 2i+1) are compared as "same subject" and (2i+1, 2i+2) as "different subject".
    # NOTE(review): assumes two consecutive rows per subject and at least train_size rows.
    diff_F_same_subject = [F.l1_loss(F_mat[2*i], F_mat[2*i+1], reduction='sum')
                           for i in range(train_size//2)]
    print(diff_F_same_subject)
    diff_F_diff_subject = [F.l1_loss(F_mat[2*i+1], F_mat[2*i+2], reduction='sum')
                           for i in range(train_size//2-1)]
    print(diff_F_diff_subject)

In [None]:
# For one sample of the batch stored in `history`, compare each latent state h_t with
# model.F applied to the previous state (h0 for the first step). Change index_for_plot
# to inspect another sample.
index_for_plot = 0
# Work on a copy. The pop/insert below used to mutate history['h_history'] in place,
# so re-running this cell shifted the stored history again each time.
h_history = list(history['h_history'])
print(h_history[0].shape)
print(h_history[1].shape)
h1_all = torch.stack(h_history).transpose(0,1) # output: batch_size x tol_time x latent_dim
h1 = h1_all[index_for_plot]
h_history.pop(-1)
# Expand h0 to the batch size of the stored states. This was previously hard-coded to 2
# (num_train mod batch_size) and had to be edited by hand.
h_history.insert(0, h0.expand(h1_all.size(0), -1))
print(h_history[0].shape)
h2_all = torch.stack(h_history).transpose(0,1) # output: batch_size x tol_time x latent_dim
h2 = h2_all[index_for_plot]

h2 = model.F(h2)

print(torch.dist(h1, h2))
# One scatter per time step; the line y = x marks a perfect match.
for i in range(h1.size(0)):
    plt.plot([-5,5],[-5,5])
    plt.scatter(h1[i].cpu().detach().numpy(),h2[i].cpu().detach().numpy())
    plt.show()

In [None]:
# Run the model over train_loader; afterwards x / mu_his / h_his / gh_his hold the
# last batch's outputs. Then scatter h against gh for one sample at each time step.
# NOTE(review): every batch is pushed through the model just to keep the last one's outputs.
with torch.no_grad():  # diagnostics only: avoid building an autograd graph per batch
    for batch, batch_index in train_loader:
        h_0 = h0.expand(batch.size(0), -1) # h0: 1 x latent_dim, h_0: batch_size x latent_dim
        which_ones = IDs[batch_index.to(int)]
        x, mu_his, h_his, gh_his = model(batch, h_0, which_ones)

index_for_plot = 0
h_all = torch.stack(h_his).transpose(0,1) # output: batch_size x tol_time x latent_dim
h = h_all[index_for_plot]

gh_all = torch.stack(gh_his).transpose(0,1) # output: batch_size x tol_time x latent_dim
gh = gh_all[index_for_plot]

print(torch.dist(h, gh))
# One scatter per time step; the line y = x marks a perfect match.
for i in range(h.size(0)):
    plt.plot([-5,5],[-5,5])
    plt.scatter(h[i].cpu().detach().numpy(),gh[i].cpu().detach().numpy())
    plt.show()

In [None]:
# NOTE(review): this cell repeats an earlier cell in this section; one copy can be removed.
# For one sample of the batch stored in `history`, compare each latent state h_t with
# model.F applied to the previous state (h0 for the first step).
index_for_plot = 0
# Work on a copy. In-place pop/insert on history['h_history'] made running the
# earlier identical cell and then this one shift the stored history twice.
h_history = list(history['h_history'])
print(h_history[0].shape)
print(h_history[1].shape)
h1_all = torch.stack(h_history).transpose(0,1) # output: batch_size x tol_time x latent_dim
h1 = h1_all[index_for_plot]
h_history.pop(-1)
# Expand h0 to the batch size of the stored states (previously hard-coded to 2).
h_history.insert(0, h0.expand(h1_all.size(0), -1))
print(h_history[0].shape)
h2_all = torch.stack(h_history).transpose(0,1) # output: batch_size x tol_time x latent_dim
h2 = h2_all[index_for_plot]

h2 = model.F(h2)

print(torch.dist(h1, h2))
# One scatter per time step; the line y = x marks a perfect match.
for i in range(h1.size(0)):
    plt.plot([-5,5],[-5,5])
    plt.scatter(h1[i].cpu().detach().numpy(),h2[i].cpu().detach().numpy())
    plt.show()