# -*- coding: utf-8 -*-
"""
Created on Mon Jul  6 16:00:55 2020

@author: haolinl
"""

# In [None]:
import os
import copy

import numpy as np
import torch
import scipy.io

from sklearn import linear_model
from torchvision.transforms import Compose, ToTensor

# In [None]:
# --- Experiment settings -----------------------------------------------------
# Dataset to benchmark; must be one of "head_and_neck", "kidney" or "aorta".
model_name = "head_and_neck"
# Weight of the regularization term (passed to sklearn Ridge as `alpha`).
alpha = 0.001
# Saved deformations are divided by this factor to scale back to real
# dimensions. Inherited from the main code; default 100, no need to change.
deformation_scalar = 100
# Groundtruth source: "reconstructed" (AE-decoded groundtruth latents) or
# "restored" (stored groundtruth deformations). No need to change.
gt_mode = "restored"

# In [None]:

# HELPER FUNCTIONS.

def get_non_zero_indices_list(fix_indices_list, nDOF):
    """Return the zero-based DOF indices of every node that is not fixed.

    Each node owns three consecutive DOFs, so node ``i`` (zero-based) maps to
    DOFs ``3*i``, ``3*i + 1`` and ``3*i + 2``.

    Args:
        fix_indices_list: Iterable of fixed node indices, indexed from 1.
        nDOF: Total number of DOFs; the node count is taken as ``nDOF // 3``.

    Returns:
        list of int: DOF indices (indexed from 0) of the free nodes, in
        ascending order. Dim: (nDOF-nFix*3) * 1 when nDOF is a multiple of 3
        and the fixed indices are distinct and in range.
    """
    # A set gives O(1) membership tests instead of scanning the list per node.
    fixed_nodes = {item - 1 for item in fix_indices_list}  # Return to zero-indexing.
    node_num = int(nDOF // 3)  # Integer floor division, no float round-trip.

    return [
        dof
        for node in range(node_num)
        if node not in fixed_nodes
        for dof in (node * 3, node * 3 + 1, node * 3 + 2)
    ]


def matrixExpand(data_matrix, nDOF, non_zero_indices_list):
    """Scatter reduced-DOF rows back into a full matrix with nDOF rows.

    Row ``i`` of ``data_matrix`` is written to row ``non_zero_indices_list[i]``
    of the output; all other rows (e.g. fixed DOFs) stay zero.

    Args:
        data_matrix: 2-D array, shape nFeature * nSample, with at least
            ``len(non_zero_indices_list)`` rows (any extra rows are ignored).
        nDOF: Number of rows of the expanded matrix.
        non_zero_indices_list: Zero-based target row indices, e.g. from
            ``get_non_zero_indices_list``.

    Returns:
        np.ndarray of float64, shape nDOF * nSample. Only the real part of a
        complex input is kept.

    Raises:
        ValueError: If ``data_matrix`` has fewer rows than there are indices.
    """
    target_rows = np.asarray(non_zero_indices_list, dtype=int).reshape(-1)
    n_rows = target_rows.shape[0]
    if data_matrix.shape[0] < n_rows:
        raise ValueError("data_matrix has {} rows but {} target indices were given."
                         .format(data_matrix.shape[0], n_rows))

    # Take the real part up front so a plain float64 buffer suffices; the
    # previous complex128 buffer used twice the memory for the same result.
    data_expanded = np.zeros(shape=(nDOF, data_matrix.shape[1]), dtype=float)
    data_expanded[target_rows, :] = np.real(data_matrix[:n_rows, :])

    return data_expanded


def dataReconstruction_Autoencoder(encoder_model, latent_mat, nDOF,
                                   non_zero_indices_list, device,
                                   dtype=torch.float, transform=Compose([ToTensor()])):
    """Decode latent vectors with the autoencoder and expand them to nDOF rows.

    Args:
        encoder_model: Model exposing a ``decoder`` attribute. It is switched to
            eval mode as a side effect.
        latent_mat: Latent codes, shape nSample * nFeature. Each row is decoded
            on its own.
        nDOF: Total number of DOFs of the expanded output.
        non_zero_indices_list: Zero-based DOF indices that receive the decoded
            values (see ``get_non_zero_indices_list``). Presumably the decoder
            output length equals ``len(non_zero_indices_list)``; verify for new models.
        device: torch device on which the decoder runs.
        dtype: torch dtype the decoder input is cast to.
        transform: Callable turning a (1, nFeature) ndarray into a tensor.

    Returns:
        np.ndarray of float64, shape nDOF * nSample. Rows not listed in
        ``non_zero_indices_list`` are zero. An empty ``latent_mat`` yields
        an nDOF * 0 array.
    """
    # Transform weights back to original vector space (decoding).
    encoder_model.eval()
    columns = []
    with torch.no_grad():
        for i in range(latent_mat.shape[0]):
            latent = latent_mat[i, :]
            decoded = encoder_model.decoder(transform(latent.reshape(1, -1)).to(dtype).to(device))
            # astype() already returns a fresh array, so no deepcopy is needed.
            columns.append(decoded.detach().cpu().numpy().astype(float).reshape(-1, 1))

    if columns:
        # One hstack at the end instead of growing the matrix per sample (was O(n^2)).
        data_reconstruct_mat = np.hstack(columns)  # Shape: (nDOF-nFix*3) * nSample.
    else:
        data_reconstruct_mat = np.zeros((len(non_zero_indices_list), 0))

    return matrixExpand(data_reconstruct_mat, nDOF, non_zero_indices_list)

# In [None]:
# Per-model constants (do not need to change):
#   (FM node indices [index from 0], nDOF, fixed node indices [index from 1]).
_MODEL_SETTINGS = {
    "head_and_neck": ([4, 96, 431, 752, 1144], 3474, [761, 1000, 1158]),
    "kidney": ([9, 235, 327, 350, 475], 3372, [2, 453, 745]),
    "aorta": ([129, 381, 429, 467, 475, 484, 662, 798, 1123, 1151], 3654, [1148, 1156, 1169]),
}

if model_name not in _MODEL_SETTINGS:
    raise ValueError("Illegal 'model_name' input. ")

_fm_node_indices, nDOF, fix_indices_list = _MODEL_SETTINGS[model_name]
working_directory = model_name  # Each model's data folder is named after the model.
FM_indices_array = np.array(_fm_node_indices).reshape(-1)  # Index from 0.
# Index from 0. Dim: (nDOF-nFix*3) * 1.
non_zero_indices_list = get_non_zero_indices_list(fix_indices_list, nDOF)

# Input files expected inside the model's working directory.
deform_dataset_filepath = os.path.join(working_directory, "groundtruths_list.npy")
latent_encoding_filepath = os.path.join(working_directory, "latent_list.npy")
train_set_indices_filepath = os.path.join(working_directory, "train_set_ind_array.npy")
test_set_indices_filepath = os.path.join(working_directory, "test_set_ind_array.npy")
AE_model_filepath = os.path.join(working_directory, "model_final.pth")

# In [None]:
# Data processing.

deform_matrix = np.load(deform_dataset_filepath)  # Dim: nSample * (nDOF-nFix*3).

# Gather the three displacement columns of every FM node: node k occupies
# columns 3k, 3k+1 and 3k+2 of deform_matrix.
# NOTE(review): FM indices are applied to deform_matrix's reduced column layout
# (fixed DOFs removed). Confirm they are numbered in that layout and not as
# full-mesh node ids.
_fm_columns = (3 * FM_indices_array.astype(int).reshape(-1, 1) + np.arange(3)).reshape(-1)
FM_disp_matrix = deform_matrix[:, _fm_columns].astype(float)  # Size: nSample * nFM_disp.

latent_matrix = np.load(latent_encoding_filepath)  # Dim: nSample * nFeature.
train_set_indices_arr, test_set_indices_arr = (
    np.load(split_path).astype(int).reshape(-1)
    for split_path in (train_set_indices_filepath, test_set_indices_filepath)
)

train_x_mat, train_y_mat = FM_disp_matrix[train_set_indices_arr], latent_matrix[train_set_indices_arr]
test_x_mat, test_y_mat = FM_disp_matrix[test_set_indices_arr], latent_matrix[test_set_indices_arr]

# In [None]:
# Ridge regression model optimizing.
# Fit a linear map from FM displacements to AE latent codes (Ridge.fit returns
# the fitted estimator), then predict latents for both splits.
rr_model = linear_model.Ridge(alpha=alpha).fit(train_x_mat, train_y_mat)
train_pred_mat, test_pred_mat = (rr_model.predict(x_mat) for x_mat in (train_x_mat, test_x_mat))

# In [None]:
# AE reconstruction.

# Validate before the (potentially slow) model load so a typo fails fast.
if gt_mode not in ("reconstructed", "restored"):
    raise ValueError("Illegal 'gt_mode' input. ")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint is loaded as a whole pickled model (it is used via .to/.eval/
# .decoder). map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE: unpickling can run arbitrary code -- only load trusted checkpoints.
try:
    # torch >= 2.6 defaults to weights_only=True, which rejects full-model pickles.
    ae_model = torch.load(AE_model_filepath, map_location=device, weights_only=False)
except TypeError:
    # torch < 1.13 does not accept the `weights_only` keyword.
    ae_model = torch.load(AE_model_filepath, map_location=device)
ae_model = ae_model.to(device)


def _decode_latents(latents):
    """Decode latent rows (nSample * nFeature) with the AE; returns nDOF * nSample."""
    return dataReconstruction_Autoencoder(ae_model, latents, nDOF, non_zero_indices_list, device)


if gt_mode == "reconstructed":
    # Groundtruth = AE reconstruction of the groundtruth latents.
    data_reconstruct_test_y = _decode_latents(test_y_mat)  # Shape: nDOF * nSample.
    data_reconstruct_train_y = _decode_latents(train_y_mat)  # Shape: nDOF * nSample.
else:  # "restored"
    # Groundtruth = stored groundtruth deformation, expanded back to nDOF rows.
    data_reconstruct_test_y = matrixExpand(deform_matrix[test_set_indices_arr].T,
                                           nDOF, non_zero_indices_list)  # Shape: nDOF * nSample.
    data_reconstruct_train_y = matrixExpand(deform_matrix[train_set_indices_arr].T,
                                            nDOF, non_zero_indices_list)  # Shape: nDOF * nSample.

# Ridge-regression predictions decoded through the AE. Shape: nDOF * nSample.
data_reconstruct_test_predict = _decode_latents(test_pred_mat)
data_reconstruct_train_predict = _decode_latents(train_pred_mat)

# In [None]:
# Save results as .mat files.

# Deformation fields, in the key order written to the .mat file. Each is divided
# by deformation_scalar to scale back to real dimensions.
_deformation_fields = (
    ("test_deformation_label", data_reconstruct_test_y),  # Groundtruth on test dataset.
    ("test_deformation_reconstruct", data_reconstruct_test_predict),  # RR reconstruction on test dataset.
    ("train_deformation_label", data_reconstruct_train_y),  # Groundtruth on training dataset.
    ("train_deformation_reconstruct", data_reconstruct_train_predict),  # RR reconstruction on training dataset.
)

mdict = {
    "model_name": model_name,
    "gt_mode": gt_mode,
    "FM_indices": FM_indices_array.astype(int).reshape(-1, 1) + 1,  # Saved indexed from 1.
    "fix_node_list": fix_indices_list,
    "nDOF": nDOF,
    "RR_alpha": alpha,
    "deformation_scalar": deformation_scalar,
}
mdict.update((field_name, field / deformation_scalar) for field_name, field in _deformation_fields)

mat_save_path = os.path.join(working_directory,
                             "RR_benchmark_result_alpha_{alpha}.mat".format(alpha=alpha))
scipy.io.savemat(mat_save_path, mdict)