# In [1]:
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torchvision import transforms, models
from tqdm.auto import tqdm
import random
import torch
import torch.nn.functional as F
import torch.distributions as dist
from custom_data_loading import *
from density_estimation import *
from nonlinear_vae import *
from distance_correlation import *
from spearman_correlation import *
import os
import re


def stable_cor(z,eps=1e-6):
    """Return an eps-stabilised correlation matrix of the columns of ``z``.

    The covariance of ``z.T`` is divided element-wise by
    ``sqrt(var_i * var_j) + eps``; the added ``eps`` keeps the denominator
    strictly positive even when a column has zero variance.

    Args:
        z: 2-D tensor whose columns are treated as the variables
            (presumably shaped (batch, latent_dim) -- confirm with callers).
        eps: Small constant added to every entry of the denominator.

    Returns:
        Tensor with the stabilised pairwise correlation coefficients.
    """
    covariance=z.T.cov()
    variances=covariance.diag()
    # Entry (i, j) of the product below is var_i * var_j.
    denominator=(variances.unsqueeze(1)*variances.unsqueeze(0)).sqrt()+eps
    return covariance/denominator


def elbo_loss_SC(X,X_hat,mean,std,z,beta=0,l1=0,l2=0,l3=50):
    """VAE loss with a covariance penalty and a Spearman-correlation penalty.

    The returned total is ``recon_loss + beta*KLD + l2*reg + l3*reg2``.

    Args:
        X: Input batch; ``len(X)`` is used as the batch size.
        X_hat: Reconstruction of ``X``.
        mean: Posterior means produced by the encoder.
        std: Posterior standard deviations produced by the encoder.
        z: Latent samples; columns are treated as the latent variables
            (presumably shaped (batch, latent_dim) -- confirm).
        beta: Weight of the KL term.
        l1: Not used by this function; accepted so that every loss can be
            called with the same positional arguments by the trainer.
        l2: Weight of the covariance penalty ``reg``.
        l3: Weight of the Spearman penalty ``reg2``.

    Returns:
        Tuple ``(total_loss, recon_loss, KLD, reg, reg2)``.
    """
    batch_size=len(X)

    # I. Reconstruction: summed squared error divided by the batch size.
    sum_squared_error=nn.MSELoss(reduction='sum')
    recon_loss=sum_squared_error(X,X_hat)/batch_size

    # II. Closed-form KL divergence between N(mean, std^2) and N(0, 1),
    # summed over all entries and divided by the batch size.
    kl_terms=-1-std.pow(2).log()+mean.pow(2)+std.pow(2)
    KLD=(0.5*kl_terms.sum())/batch_size

    # III. Covariance penalty. NOTE(review): diagonal_decomposition is defined
    # elsewhere; its second output is presumably the off-diagonal part.
    latent_cov=z.T.cov()
    _,cov_off_diag=diagonal_decomposition(latent_cov)
    reg=cov_off_diag.abs().sum()

    # IV. Spearman penalty, built the same way from the matrix returned by
    # spearman_correlation_vectorized (project helper).
    spearman=spearman_correlation_vectorized(z)
    _,spearman_off_diag=diagonal_decomposition(spearman)
    reg2=spearman_off_diag.abs().sum()

    total_loss=recon_loss+beta*KLD+l2*reg+l3*reg2
    return total_loss,recon_loss,KLD,reg,reg2

def elbo_loss_DC(X,X_hat,mean,std,z,beta=0,l1=0,l2=0,l3=50):
    """VAE loss with a covariance penalty and a distance-correlation penalty.

    The returned total is ``recon_loss + beta*KLD + l2*reg + l3*reg2``.

    Args:
        X: Input batch; ``len(X)`` is used as the batch size.
        X_hat: Reconstruction of ``X``.
        mean: Posterior means produced by the encoder.
        std: Posterior standard deviations produced by the encoder.
        z: Latent samples; columns are treated as the latent variables
            (presumably shaped (batch, latent_dim) -- confirm).
        beta: Weight of the KL term.
        l1: Not used by this function; accepted so that every loss can be
            called with the same positional arguments by the trainer.
        l2: Weight of the covariance penalty ``reg``.
        l3: Weight of the distance-correlation penalty ``reg2``.

    Returns:
        Tuple ``(total_loss, recon_loss, KLD, reg, reg2)``.
    """
    batch_size=len(X)

    # I. Reconstruction: summed squared error divided by the batch size.
    sum_squared_error=nn.MSELoss(reduction='sum')
    recon_loss=sum_squared_error(X,X_hat)/batch_size

    # II. Closed-form KL divergence between N(mean, std^2) and N(0, 1),
    # summed over all entries and divided by the batch size.
    kl_terms=-1-std.pow(2).log()+mean.pow(2)+std.pow(2)
    KLD=(0.5*kl_terms.sum())/batch_size

    # III. Covariance penalty. NOTE(review): diagonal_decomposition is defined
    # elsewhere; its second output is presumably the off-diagonal part.
    latent_cov=z.T.cov()
    _,cov_off_diag=diagonal_decomposition(latent_cov)
    reg=cov_off_diag.abs().sum()

    # IV. Distance-correlation penalty: only the first output of the project
    # helper is used, requested with correlation=True.
    dist_corr,_=Distance_CorrCorv_vectorized(z,correlation=True)
    _,dist_corr_off_diag=diagonal_decomposition(dist_corr)
    reg2=dist_corr_off_diag.abs().sum()

    total_loss=recon_loss+beta*KLD+l2*reg+l3*reg2
    return total_loss,recon_loss,KLD,reg,reg2

def elbo_loss_TC(X,X_hat,mean,std,z,beta=0,l1=0,l2=5,l3=50):
    """VAE loss with a squared-covariance penalty and a total-correlation penalty.

    The returned total is ``recon_loss + beta*KLD + l2*reg + l3*reg2``.

    Args:
        X: Input batch; ``len(X)`` is used as the batch size.
        X_hat: Reconstruction of ``X``.
        mean: Posterior means produced by the encoder.
        std: Posterior standard deviations produced by the encoder.
        z: Latent samples; columns are treated as the latent variables
            (presumably shaped (batch, latent_dim) -- confirm).
        beta: Weight of the KL term.
        l1: Not used by this function; accepted so that every loss can be
            called with the same positional arguments by the trainer.
        l2: Weight of the covariance penalty ``reg``.
        l3: Weight of the total-correlation penalty ``reg2``.

    Returns:
        Tuple ``(total_loss, recon_loss, KLD, reg, reg2)``.
    """
    batch_size=len(X)

    # I. Reconstruction: summed squared error divided by the batch size.
    sum_squared_error=nn.MSELoss(reduction='sum')
    recon_loss=sum_squared_error(X,X_hat)/batch_size

    # II. Closed-form KL divergence between N(mean, std^2) and N(0, 1),
    # summed over all entries and divided by the batch size.
    kl_terms=-1-std.pow(2).log()+mean.pow(2)+std.pow(2)
    KLD=(0.5*kl_terms.sum())/batch_size

    # III. Covariance penalty: sum of squares of the second output of
    # diagonal_decomposition (presumably the off-diagonal part -- the helper
    # is defined elsewhere). No normalisation is applied here.
    latent_cov=z.T.cov()
    _,cov_off_diag=diagonal_decomposition(latent_cov)
    reg=(cov_off_diag**2).sum()

    # IV. Total-correlation penalty, assembled from project helpers:
    #   1. marginal kernel values from the skew-normal density estimate,
    #   2. joint kernel values for the same samples,
    #   3. the total-correlation estimate combining the two.
    _,marginal_kernel=multi_vars_SkewNormal_density_estimate(z,check=False)
    _,joint_kernel=estimate_total_joint_density_skew_normal_training(z)
    tc_values=total_correlation_estimation(marginal_kernel,joint_kernel)
    reg2=tc_values.abs().sum()

    total_loss=recon_loss+beta*KLD+l2*reg+l3*reg2
    return total_loss,recon_loss,KLD,reg,reg2

def elbo_loss_MI(X,X_hat,mean,std,z,beta=0,l1=0,l2=5,l3=50):
    """VAE loss with a squared-covariance penalty and a pairwise-MI penalty.

    The returned total is ``recon_loss + beta*KLD + l2*reg + l3*reg2``.

    Args:
        X: Input batch; ``len(X)`` is used as the batch size.
        X_hat: Reconstruction of ``X``.
        mean: Posterior means produced by the encoder.
        std: Posterior standard deviations produced by the encoder.
        z: Latent samples; columns are treated as the latent variables
            (presumably shaped (batch, latent_dim) -- confirm).
        beta: Weight of the KL term.
        l1: Not used by this function; accepted so that every loss can be
            called with the same positional arguments by the trainer.
        l2: Weight of the covariance penalty ``reg``.
        l3: Weight of the mutual-information penalty ``reg2``.

    Returns:
        Tuple ``(total_loss, recon_loss, KLD, reg, reg2)``.
    """
    batch_size=len(X)
    # Both penalties are divided by n^2 - n with n = len(X).
    # NOTE(review): n is the batch size here; if the goal was to average over
    # off-diagonal entries, the latent dimension would be the natural count --
    # confirm the intended normalisation.
    pair_count=batch_size**2-batch_size

    # I. Reconstruction: summed squared error divided by the batch size.
    sum_squared_error=nn.MSELoss(reduction='sum')
    recon_loss=sum_squared_error(X,X_hat)/batch_size

    # II. Closed-form KL divergence between N(mean, std^2) and N(0, 1),
    # summed over all entries and divided by the batch size.
    kl_terms=-1-std.pow(2).log()+mean.pow(2)+std.pow(2)
    KLD=(0.5*kl_terms.sum())/batch_size

    # III. Covariance penalty: sum of squares of the second output of
    # diagonal_decomposition (presumably the off-diagonal part -- the helper
    # is defined elsewhere).
    latent_cov=z.T.cov()
    _,cov_off_diag=diagonal_decomposition(latent_cov)
    reg=(cov_off_diag**2).sum()/pair_count

    # IV. Mutual-information penalty: second output of pair_wise_MI_upper
    # (project helper).
    _,mi_total=pair_wise_MI_upper(z)
    reg2=mi_total/pair_count

    total_loss=recon_loss+beta*KLD+l2*reg+l3*reg2
    return total_loss,recon_loss,KLD,reg,reg2

def VAE_list_trainer(model,X_tr_list,X_val_list,param):
    """Train ``model`` on list-backed batches and record per-epoch losses.

    Each epoch shuffles both lists, runs one optimisation step per training
    batch, then evaluates every validation batch under ``torch.no_grad()``.
    The metrics stored for an epoch are those of the LAST training batch and
    the LAST validation batch of that epoch (no averaging is done).

    Side effects:
        * ``X_tr_list`` and ``X_val_list`` are shuffled in place.
        * ``model.cost_dict`` is (re)assigned with the recorded histories.
        * Progress bars and one summary line per epoch are printed.
        * Batches are moved to CUDA when available; the model itself is not
          moved by this function.

    Args:
        model: Object exposing ``train()``, ``eval()``,
            ``forward_encoder(X)`` (returning four values, of which the first,
            second and fourth are used as mean, std and z) and
            ``forward_decoder(z)``.
        X_tr_list: Mutable sequence of training items understood by
            ``load_a_batch_tensor``.
        X_val_list: Mutable sequence of validation items.
        param: Dict with keys ``lr``, ``max_epoch``, ``batch_sz``, ``Type``,
            ``beta``, ``l1``, ``l2``, ``l3``, ``by_channel``, ``channel_idx``,
            ``normalize``, ``loss_fn`` and ``optimizor``. ``lr`` and ``Type``
            are read but not used here.

    Returns:
        The same ``model``, with ``cost_dict`` holding lists of CPU tensors
        under the keys ``ELBO_*``, ``BCE_*``, ``KLD_*``, ``Reg_*`` and
        ``Reg2_*`` (suffix ``tr`` for training, ``v`` for validation).

    Raises:
        ValueError: If either list is empty or ``batch_sz`` is below 1.
        KeyError: If ``param`` lacks one of the keys listed above.
    """
    lr=param['lr']  # not used here: the optimizer is supplied ready-made
    max_epoch=param['max_epoch']
    batch_sz=param['batch_sz']
    Type=param['Type']  # not used here
    beta=param['beta']
    l1=param['l1']
    l2=param['l2']
    l3=param['l3']
    by_channel=param['by_channel']
    channel_idx=param['channel_idx']
    normalize=param['normalize']
    loss=param['loss_fn']
    optimizor=param['optimizor']
    print('The adopted loss:',loss)

    # Explicit validation instead of `try: assert ... except:`; asserts are
    # stripped under `python -O` and a bare except hides unrelated errors.
    n_train,n_val=len(X_tr_list),len(X_val_list)
    if n_train==0 or n_val==0:
        raise ValueError('X_tr_list and X_val_list must both be non-empty '
                         '(got {0} and {1} items)'.format(n_train,n_val))
    if batch_sz<1:
        raise ValueError('batch_sz must be at least 1, got {0}'.format(batch_sz))
    if n_train<=batch_sz or n_val<=batch_sz:
        # Shrink to the smaller list so that BOTH loops get at least one
        # batch; otherwise the metrics recorded below would be undefined.
        batch_sz=min(n_train,n_val)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    costs,BCE,KLD,Reg,Reg2=[[],[],[],[],[]]
    costs_v,BCE_v,KLD_v,Reg_v,Reg2_v=[[],[],[],[],[]]

    # Trailing items that do not fill a whole batch are dropped each epoch.
    batch_num_tr=n_train//batch_sz
    batch_num_val=n_val//batch_sz
    for epoch in tqdm(range(max_epoch),desc='Epoch',position=0,leave=True):
        # Shuffle the lists (in place) so batches differ between epochs.
        random.shuffle(X_tr_list)
        random.shuffle(X_val_list)
        model.train()
        for b in tqdm(range(batch_num_tr),desc="Iteration in Epoch:{0}".format(epoch+1),position=0,leave=False):
            # Load the whole batch as one tensor.
            X_b_list=X_tr_list[b*batch_sz:(b+1)*batch_sz]
            X_b=load_a_batch_tensor(X_b_list,by_channel,channel_idx).to(device)
            if normalize:
                # Per-batch standardisation; the two extra outputs are unused.
                X_b,_,_=batch_tensor_standardization(X_b)
            mean,std,_,z=model.forward_encoder(X_b)
            X_hat=model.forward_decoder(z)
            cost,bce,kld,reg,reg2=loss(X_b,X_hat,mean,std,z,beta,l1,l2,l3)
            optimizor.zero_grad()
            cost.backward()
            optimizor.step()

        # Validation
        model.eval()
        with torch.no_grad():
            for v in range(batch_num_val):
                X_bv_list=X_val_list[v*batch_sz:(v+1)*batch_sz]
                X_bv=load_a_batch_tensor(X_bv_list,by_channel,channel_idx).to(device)
                if normalize:
                    X_bv,_,_=batch_tensor_standardization(X_bv)
                meanv,stdv,_,zv=model.forward_encoder(X_bv)
                X_hat_v=model.forward_decoder(zv)
                cost_v,bce_v,kld_v,reg_v,reg2_v=loss(X_bv,X_hat_v,meanv,stdv,zv,beta,l1,l2,l3)

        # Record the metrics of the last training batch of this epoch.
        for history,value in zip((costs,BCE,KLD,Reg,Reg2),(cost,bce,kld,reg,reg2)):
            history.append(value.detach().to('cpu'))
        print('Epoch loss:{0},MSE:{1},Reg1:{2},Reg2:{3}'.format(cost.detach().item(),bce.detach().item(),reg.detach().item(),reg2.detach().item()))
        # Record the metrics of the last validation batch of this epoch.
        for history,value in zip((costs_v,BCE_v,KLD_v,Reg_v,Reg2_v),(cost_v,bce_v,kld_v,reg_v,reg2_v)):
            history.append(value.detach().to('cpu'))

    model.cost_dict={'ELBO_tr':costs,'BCE_tr':BCE,'KLD_tr':KLD,'Reg_tr':Reg,'Reg2_tr':Reg2,
                     'ELBO_v':costs_v,'BCE_v':BCE_v,'KLD_v':KLD_v,'Reg_v':Reg_v,'Reg2_v':Reg2_v}
    return model

def plot_loss_reg_VAE(model,mode='Training_only'):
    """Plot the five loss histories stored in ``model.cost_dict``.

    One figure with a 1x5 grid is drawn: ELBO, reconstruction loss, KL
    divergence and the two regularisation terms.

    Args:
        model: Object whose ``cost_dict`` holds the ``*_tr`` histories (and
            the ``*_v`` histories when validation curves are requested).
        mode: ``'Training_only'`` draws only the training curves; any other
            value overlays the validation curves and adds a legend to each
            panel.
    """
    # (training key, validation key, panel title), in left-to-right order.
    panels=(
        ('ELBO_tr','ELBO_v','ELBO'),
        ('BCE_tr','BCE_v','Reconstruction loss'),
        ('KLD_tr','KLD_v','KL divergence'),
        ('Reg_tr','Reg_v','Regularization 1'),
        ('Reg2_tr','Reg2_v','Regularization 2'),
    )
    with_validation=mode!='Training_only'

    plt.figure(figsize=(20,5))
    for position,(train_key,val_key,title) in enumerate(panels,start=1):
        plt.subplot(1,5,position)
        plt.plot(list(model.cost_dict[train_key]),label='Train')
        if with_validation:
            plt.plot(list(model.cost_dict[val_key]),label='Validation')
            plt.legend()
        plt.title(title)
    plt.tight_layout()
    plt.show()