In [None]:
import copy
import os
import sys

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from einops import rearrange

from Utilities import *

In [None]:
File_C, Target_C, FID_C  = 'Data/carevue.csv',    0, [211, 646, 618, 51, 8368]                  # File Path, Prediction ID, Feature IDs [for carevue]
File_M, Target_M, FID_M  = 'Data/metavision.csv', 0, [220045, 220210, 220277, 220181, 220179]   # File Path, Prediction ID, Feature IDs [for metavision]

W     = 3             # Window Size
UP    = 50            # Update Period (records per patient consumed in each online update step)
EPOC  = 50            # Number of epochs
alpha = 0.2           # Weight of heterogeneous transfer
SEED  = None          # Set to an int for reproducible weight initialisation; None keeps runs unseeded
IterID = globals().get('IterID', 0)   # Run index appended to saved-model paths; keeps a value already defined (e.g. by Utilities or a driver loop)

NF_C  = len(FID_C)-1  # Number of input features [for carevue]; presumably one ID is the prediction target — TODO confirm
NF_M  = len(FID_M)-1  # Number of input features [for metavision]
loss_mae = torch.nn.L1Loss()
loss_mse = torch.nn.MSELoss()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if SEED is not None:  # seed before any model is constructed so weight init is reproducible
    np.random.seed(SEED)
    torch.manual_seed(SEED)
print(device)

In [None]:
# Load both datasets once; the training and evaluation cells below rely on these tensors.
# Split fractions [0.6, 0.2, 0.2] — presumably train/valid/test, matching the _Tra/_Val/_Tes names.
XSC_Tra,XDC_Tra,YC_Tra,BC_Tra, XSC_Val,XDC_Val,YC_Val,BC_Val, XSC_Tes,XDC_Tes,YC_Tes,BC_Tes = load_data_SD(File_C, [0.6, 0.2, 0.2], FID_C, Target_C, W, device)
XSM_Tra,XDM_Tra,YM_Tra,BM_Tra, XSM_Val,XDM_Val,YM_Val,BM_Val, XSM_Tes,XDM_Tes,YM_Tes,BM_Tes = load_data_SD(File_M, [0.6, 0.2, 0.2], FID_M, Target_M, W, device)

In [None]:
class AE(torch.nn.Module):
    """Fully-connected autoencoder applied to one feature's input vector.

    The encoder maps a length-``n_time`` vector to an ``n_embv``-dim embedding through hidden widths
    32 -> 256 -> 24 (LeakyReLU between layers, none after the last); the decoder mirrors it back to
    length ``n_time``.
    """

    def __init__(self, n_time, n_embv):
        super().__init__()
        widths = [n_time, 32, 256, 24, n_embv]

        def mlp(dims):
            # Linear between consecutive widths with a LeakyReLU before every Linear except the first,
            # giving Linear modules at indices 0, 2, 4, 6 (the state_dict keys saved checkpoints use).
            layers = []
            for i in range(len(dims) - 1):
                if i:
                    layers.append(torch.nn.LeakyReLU())
                layers.append(torch.nn.Linear(dims[i], dims[i + 1]))
            return torch.nn.Sequential(*layers)

        self.encoder = mlp(widths)
        self.decoder = mlp(widths[::-1])

    def forward(self, x):
        """Return ``[embedding, reconstruction]`` for input ``x`` (last dimension ``n_time``)."""
        embedding = self.encoder(x)
        return [embedding, self.decoder(embedding)]

class PL(torch.nn.Module):
    """MLP regression head mapping ``n_fea * n_embv * 2`` concatenated embedding values to one scalar.

    Hidden widths are 32 -> 256 -> 16 with LeakyReLU between layers and a linear output.
    """

    def __init__(self, n_fea, n_embv):
        super().__init__()
        dims = [n_fea * n_embv * 2, 32, 256, 16, 1]
        layers = []
        for i, (d_in, d_out) in enumerate(zip(dims, dims[1:])):
            if i > 0:
                layers.append(torch.nn.LeakyReLU())
            layers.append(torch.nn.Linear(d_in, d_out))
        # Attribute name and module indices are kept so saved state_dicts still load.
        self.Linear = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the head; output has a trailing dimension of size 1."""
        return self.Linear(x)

In [None]:
# Trainable-parameter count of the Care model: one AE per feature plus the predictor head.
# The two instances below exist only to be counted.
def _n_trainable(module):
    """Total number of elements over the parameters of `module` that require gradients."""
    return sum([np.prod(p.size()) for p in module.parameters() if p.requires_grad])

Net_AE = AE(n_time=W, n_embv=1)
Net_PL = PL(n_fea=NF_C, n_embv=1)
NPara  = _n_trainable(Net_AE)*NF_C + _n_trainable(Net_PL)
print('Parameters:', NPara)

### Training

In [None]:
# Initialization
BL_C, PATH_C = np.inf, 'Model/3-TEDS/Model-care-'+str(Target_C)+'-'+str(Target_M)+'_'+str(IterID)  # Best Loss, Model Saved Path
BL_M, PATH_M = np.inf, 'Model/3-TEDS/Model-meta-'+str(Target_C)+'-'+str(Target_M)+'_'+str(IterID)
os.makedirs(os.path.dirname(PATH_C), exist_ok=True)     # torch.save does not create missing directories
os.makedirs(os.path.dirname(PATH_M), exist_ok=True)
Tnet_AEC = Tnet_PLC = Tnet_AEM = Tnet_PLM = None        # Snapshots taken at the start of each validation pass (saved when best)
net_AEC, net_PLC = [AE(n_time=W, n_embv=1).to(device) for F in range(NF_C)], PL(n_fea=NF_C, n_embv=1).to(device)   # Training Model
net_AEM, net_PLM = [AE(n_time=W, n_embv=1).to(device) for F in range(NF_M)], PL(n_fea=NF_M, n_embv=1).to(device)


def _chunk_rows(B, NR, b):
    """Row indices of chunk `b` (records UP*b .. UP*b+UP-1) of every patient with at least UP*(b+1) records.

    B[p] is the first row of patient p and NR[p] = B[p+1]-B[p] its number of records.
    """
    return [B[p]+UP*b+i for p in range(len(B)-1) if NR[p] >= UP*(b+1) for i in range(UP)]


def _transfer(dst_AEs, src_AEs, XD, sl):
    """Heterogeneous transfer: blend every destination AE with the source AE that best reconstructs its data.

    For destination feature F, the source AE with the lowest MAE reconstruction of XD[sl, F, :] is selected,
    and dst_AEs[F] <- (1-alpha)*dst_AEs[F] + alpha*selected, parameter by parameter (in place).
    Source AEs are only read.
    """
    if not sl:     # no rows -> every candidate loss would be NaN and argmin arbitrary; leave the AEs unchanged
        return
    with torch.no_grad():     # the similarity search needs loss values only, not an autograd graph
        best = []
        for F in range(len(dst_AEs)):
            x  = XD[sl, F, :]
            LL = [float(loss_mae(src(x)[1], x)) for src in src_AEs]     # Loss list
            best.append(int(np.argmin(LL)))
        for F, ae in enumerate(dst_AEs):
            sd_dst, sd_src = ae.state_dict(), src_AEs[best[F]].state_dict()
            for key in sd_dst: sd_dst[key] = (1-alpha)*sd_dst[key] + alpha*sd_src[key]
            ae.load_state_dict(sd_dst)


def _online_step(AEs, PL_net, opt_H, opt_R, XD, XS, Y, sl):
    """Predict Y[sl], then update the feature AEs and the predictor on that same chunk.

    Feature F is embedded by its own AE twice: from the dense input XD[sl, F, :] (whose reconstruction adds
    an MAE loss) and from the sparse input XS[sl, F, :]. All embeddings are concatenated and fed to PL_net,
    which is trained with MSE against Y[sl]. Returns the 1-D predictions made before the update (detached).
    """
    EV, RV = [], []
    for F, ae in enumerate(AEs):     # Embed Dense Feature
        ev, rv = ae(XD[sl, F, :])
        EV.append(ev)
        RV.append(rv)
    for F, ae in enumerate(AEs):     # Embed Sparse Feature (through the same per-feature AE)
        EV.append(ae(XS[sl, F, :])[0])
    pred = torch.reshape(PL_net(torch.cat(EV, dim=1)), (-1,))
    for F, rv in enumerate(RV):      # AE reconstruction losses; keep the graph for the prediction loss below
        loss_mae(rv, XD[sl, F, :]).backward(retain_graph=True)
    loss_mse(pred, Y[sl]).backward()
    for opt in opt_H: opt.step()
    for opt in opt_H: opt.zero_grad()
    opt_R.step()
    opt_R.zero_grad()
    return pred.detach()


for epoc in range(EPOC):
    for d in range(2):     # 0=Train, 1=Valid
        if d==0:     # Train
            XDC, XSC, YC, BC = XDC_Tra, XSC_Tra, YC_Tra, BC_Tra
            XDM, XSM, YM, BM = XDM_Tra, XSM_Tra, YM_Tra, BM_Tra
        else:        # Valid
            XDC, XSC, YC, BC = XDC_Val, XSC_Val, YC_Val, BC_Val
            XDM, XSM, YM, BM = XDM_Val, XSM_Val, YM_Val, BM_Val
            # Snapshot the trained models; these (not the copies adapted online below) are what gets saved.
            # NOTE(review): the validation pass keeps updating net_* and the next epoch trains from those
            # weights, so validation labels influence training — confirm this is intended.
            Tnet_AEC, Tnet_PLC = copy.deepcopy(net_AEC), copy.deepcopy(net_PLC)
            Tnet_AEM, Tnet_PLM = copy.deepcopy(net_AEM), copy.deepcopy(net_PLM)
        # Fresh optimizers every pass, so Adam moments are reset at each phase
        optimizer_HC, optimizer_RC = [torch.optim.Adam(ae.parameters(), lr=0.0025) for ae in net_AEC], torch.optim.Adam(net_PLC.parameters(), lr=0.01)
        optimizer_HM, optimizer_RM = [torch.optim.Adam(ae.parameters(), lr=0.0025) for ae in net_AEM], torch.optim.Adam(net_PLM.parameters(), lr=0.01)
        Pred_C, SL_C = torch.FloatTensor([]).to(device), []
        Pred_M, SL_M = torch.FloatTensor([]).to(device), []
        NRC = [BC[i+1]-BC[i] for i in range(len(BC)-1)]     # Number of records of each patient
        NRM = [BM[i+1]-BM[i] for i in range(len(BM)-1)]
        NB  = max(max(NRC), max(NRM))//UP                    # Number of chunks, set by the longest patient of either domain
        phase = 'Tra' if d==0 else 'Val'

        for b in range(NB):
            if b%5==4: print(epoc, phase, b+1, '/', NB, end='     \r')
            sl_C, sl_M = _chunk_rows(BC, NRC, b), _chunk_rows(BM, NRM, b)

            if d==0:     # Heterogeneous Transfer (Training Phase Only)
                _transfer(net_AEC, net_AEM, XDC, sl_C)     # Care <- most similar Meta AEs
                _transfer(net_AEM, net_AEC, XDM, sl_M)     # Meta <- most similar (just updated) Care AEs

            # One domain may have no patient long enough for chunk b; skip it rather than step on an empty batch
            if sl_C:     # Care
                Pred_C = torch.cat((Pred_C, _online_step(net_AEC, net_PLC, optimizer_HC, optimizer_RC, XDC, XSC, YC, sl_C)))
                SL_C  += sl_C
            if sl_M:     # Meta
                Pred_M = torch.cat((Pred_M, _online_step(net_AEM, net_PLM, optimizer_HM, optimizer_RM, XDM, XSM, YM, sl_M)))
                SL_M  += sl_M

        if d==1:     # Save Best (Validation Phase Only)
            loss_C, loss_M = loss_mse(Pred_C, YC[SL_C]), loss_mse(Pred_M, YM[SL_M])
            if loss_C<BL_C:
                BL_C = float(loss_C)
                torch.save([Tnet_PLC.state_dict()]+[ae.state_dict() for ae in Tnet_AEC], PATH_C+'.pt')
                print('Save Best Care at', epoc, 'with loss of',  round(float(loss_C),3))
            if loss_M<BL_M:
                BL_M = float(loss_M)
                torch.save([Tnet_PLM.state_dict()]+[ae.state_dict() for ae in Tnet_AEM], PATH_M+'.pt')
                print('Save Best Meta at', epoc, 'with loss of',  round(float(loss_M),3))
            if epoc%10==9: print(epoc+1, '/', EPOC, ' ', round(float(loss_C),3), round(float(loss_M),3), end='\r')

### Losses (evaluation of the best saved models on the train / validation / test splits)

In [None]:
# Initialization
TVTLoss_C, TVTLoss_M = np.zeros(3), np.zeros(3)     # Train/Valid/Test Loss
# Load the data only when it is not already in memory, so the models are scored on the split they were trained on.
# NOTE(review): if load_data_SD splits at random, an unconditional reload here would evaluate on a different
# (possibly overlapping) split — confirm whether the split is deterministic.
if 'XSC_Tes' not in globals():
    XSC_Tra,XDC_Tra,YC_Tra,BC_Tra, XSC_Val,XDC_Val,YC_Val,BC_Val, XSC_Tes,XDC_Tes,YC_Tes,BC_Tes = load_data_SD(File_C, [0.6, 0.2, 0.2], FID_C, Target_C, W, device)
if 'XSM_Tes' not in globals():
    XSM_Tra,XDM_Tra,YM_Tra,BM_Tra, XSM_Val,XDM_Val,YM_Val,BM_Val, XSM_Tes,XDM_Tes,YM_Tes,BM_Tes = load_data_SD(File_M, [0.6, 0.2, 0.2], FID_M, Target_M, W, device)
# Saved Models, built on the same device as the data (their deepcopies below inherit it)
Tnet_AEC, Tnet_PLC = [AE(n_time=W, n_embv=1).to(device) for F in range(NF_C)], PL(n_fea=NF_C, n_embv=1).to(device)     # for Care
Tnet_AEM, Tnet_PLM = [AE(n_time=W, n_embv=1).to(device) for F in range(NF_M)], PL(n_fea=NF_M, n_embv=1).to(device)     # for Meta

PATH_C = 'Model/3-TEDS/Model-care-'+str(Target_C)+'-'+str(Target_M)+'_'+str(IterID)     # Model Saved Path
PATH_M = 'Model/3-TEDS/Model-meta-'+str(Target_C)+'-'+str(Target_M)+'_'+str(IterID)
# Load Saved Weights; map_location lets a GPU-saved checkpoint load on a CPU-only machine (and vice versa)
Weights_C = torch.load(PATH_C+'.pt', map_location=device)
Weights_M = torch.load(PATH_M+'.pt', map_location=device)
Tnet_PLC.load_state_dict(Weights_C[0])     # Checkpoint layout: [PL state_dict] + [one AE state_dict per feature]
Tnet_PLM.load_state_dict(Weights_M[0])
for F in range(NF_C): Tnet_AEC[F].load_state_dict(Weights_C[1+F])
for F in range(NF_M): Tnet_AEM[F].load_state_dict(Weights_M[1+F])

for d in range(3):     # 0=Train, 1=Valid, 2=Test
    # Every split starts from the saved weights; the copies are then adapted online chunk by chunk
    net_PLC, net_PLM = copy.deepcopy(Tnet_PLC), copy.deepcopy(Tnet_PLM)
    net_AEC = [copy.deepcopy(ae) for ae in Tnet_AEC]
    net_AEM = [copy.deepcopy(ae) for ae in Tnet_AEM]
    if d==0: XDC,XSC,YC,BC, XDM,XSM,YM,BM = XDC_Tra,XSC_Tra,YC_Tra,BC_Tra, XDM_Tra,XSM_Tra,YM_Tra,BM_Tra     # Training
    if d==1: XDC,XSC,YC,BC, XDM,XSM,YM,BM = XDC_Val,XSC_Val,YC_Val,BC_Val, XDM_Val,XSM_Val,YM_Val,BM_Val     # Validation
    if d==2: XDC,XSC,YC,BC, XDM,XSM,YM,BM = XDC_Tes,XSC_Tes,YC_Tes,BC_Tes, XDM_Tes,XSM_Tes,YM_Tes,BM_Tes     # Testing
    optimizer_HC, optimizer_RC = [torch.optim.Adam(net_AEC[F].parameters(), lr=0.0025) for F in range(NF_C)], torch.optim.Adam(net_PLC.parameters(), lr=0.01)
    optimizer_HM, optimizer_RM = [torch.optim.Adam(net_AEM[F].parameters(), lr=0.0025) for F in range(NF_M)], torch.optim.Adam(net_PLM.parameters(), lr=0.01)
    Pred_C, SL_C = torch.FloatTensor([]).to(device), []
    Pred_M, SL_M = torch.FloatTensor([]).to(device), []
    NRC, NRM     = [BC[i+1]-BC[i] for i in range(len(BC)-1)], [BM[i+1]-BM[i] for i in range(len(BM)-1)]      # Number of records of each patient
    NB           = max(max(NRC), max(NRM))//UP     # Number of chunks, set by the longest patient of either domain

    for b in range(NB):
        if b%5==4: print(('Tra', 'Val', 'Tes')[d], b+1, '/', NB, end='     \r')
        sl_C, sl_M = [], []
        for p in range(len(BC)-1):
            if NRC[p]>=UP*(b+1): sl_C += [BC[p]+UP*b+i for i in range(UP)]
        for p in range(len(BM)-1):
            if NRM[p]>=UP*(b+1): sl_M += [BM[p]+UP*b+i for i in range(UP)]
        # Predictions are recorded before the update on the same chunk, so each chunk is scored by a model
        # that has not yet been trained on its labels.
        # Care
        if sl_C:     # skip chunks no Care patient reaches (an empty batch would still move Adam's momentum)
            EV, RV = net_AEC[0]( XDC[sl_C,0,:] )
            RV     = RV[:,None,:]
            for FC in range(1,NF_C):     # Embed Dense Feature
                ev, rv = net_AEC[FC]( XDC[sl_C,FC,:] )
                EV = torch.cat((EV,ev), dim=1)
                RV = torch.cat((RV,rv[:,None,:]), dim=1)
            for FC in range(NF_C):       # Embed Sparse Feature (each through its own AE)
                ev, rv = net_AEC[FC]( XSC[sl_C,FC,:] )
                EV = torch.cat((EV,ev), dim=1)
            pred = net_PLC(EV)
            for FC in range(NF_C):       # Update AEs
                loss = loss_mae(RV[:,FC], XDC[sl_C,FC,:])
                loss.backward(retain_graph=True)
            loss = loss_mse(torch.reshape(pred,(-1,)), YC[sl_C])
            loss.backward()
            for FC in range(NF_C): optimizer_HC[FC].step()
            for FC in range(NF_C): optimizer_HC[FC].zero_grad()
            optimizer_RC.step()
            optimizer_RC.zero_grad()
            Pred_C = torch.cat((Pred_C,torch.reshape(pred,(-1,)).detach()))
            SL_C  += sl_C

        # Meta
        if sl_M:     # skip chunks no Meta patient reaches
            EV, RV = net_AEM[0]( XDM[sl_M,0,:] )
            RV     = RV[:,None,:]
            for FM in range(1,NF_M):     # Embed Dense Feature
                ev, rv = net_AEM[FM]( XDM[sl_M,FM,:] )
                EV = torch.cat((EV,ev), dim=1)
                RV = torch.cat((RV,rv[:,None,:]), dim=1)
            for FM in range(NF_M):       # Embed Sparse Feature (each through its own AE)
                ev, rv = net_AEM[FM]( XSM[sl_M,FM,:] )
                EV = torch.cat((EV,ev), dim=1)
            pred = net_PLM(EV)
            for FM in range(NF_M):       # Update AEs
                loss = loss_mae(RV[:,FM], XDM[sl_M,FM,:])
                loss.backward(retain_graph=True)
            loss = loss_mse(torch.reshape(pred,(-1,)), YM[sl_M])
            loss.backward()
            for FM in range(NF_M): optimizer_HM[FM].step()
            for FM in range(NF_M): optimizer_HM[FM].zero_grad()
            optimizer_RM.step()
            optimizer_RM.zero_grad()
            Pred_M = torch.cat((Pred_M,torch.reshape(pred,(-1,)).detach()))
            SL_M  += sl_M

    loss_C, loss_M = loss_mse(Pred_C, YC[SL_C]), loss_mse(Pred_M, YM[SL_M])
    TVTLoss_C[d] = float(loss_C)
    TVTLoss_M[d] = float(loss_M)
print('Care')
print('Training   Loss', round(float(TVTLoss_C[0]),3))
print('Validation Loss', round(float(TVTLoss_C[1]),3))
print('Testing    Loss', round(float(TVTLoss_C[2]),3))
print('Meta')
print('Training   Loss', round(float(TVTLoss_M[0]),3))
print('Validation Loss', round(float(TVTLoss_M[1]),3))
print('Testing    Loss', round(float(TVTLoss_M[2]),3))