# In [None]:  (notebook cell marker left over from the Jupyter export)
"""
@author: borum

TCN code from: https://github.com/locuslab/TCN
Energy OOD code from: https://github.com/wetliu/energy_ood
ResNet code from: https://github.com/hsd1503/resnet1d
InceptionTime code from: https://github.com/TheMrGhostman/InceptionTime-Pytorch

"""

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.nn.utils import weight_norm
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torch.autograd import Variable
import time
import gc
import matplotlib.pyplot as plt
from local_attention import LocalAttention
from net1d import Net1D
from inception import Inception, InceptionBlock
from torchsummary import summary

"""
@author: borum

TCN code from: https://github.com/locuslab/TCN
Energy OOD code from: https://github.com/wetliu/energy_ood
ResNet code from: https://github.com/hsd1503/resnet1d
InceptionTime code from: https://github.com/TheMrGhostman/InceptionTime-Pytorch

"""

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.nn.utils import weight_norm
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torch.autograd import Variable
import time
import gc
import matplotlib.pyplot as plt
from local_attention import LocalAttention
from net1d import Net1D
from inception import Inception, InceptionBlock
from torchsummary import summary


# Index of the GPU to use when CUDA is available.
GPU_NUM = 0
device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # torch.cuda.set_device only accepts CUDA devices, so it must be skipped
    # when the script falls back to the CPU.
    torch.cuda.set_device(device) # change allocation of current GPU
    print ('Current cuda device ', torch.cuda.current_device()) # check
else:
    print ('CUDA is not available, running on', device)

# Locations of the preprocessed MESA PPG recordings and sleep annotations.
MESA_PPG_PATH = '/PPG_SleepStaging/MESA_PPG_preprocessed'
MESA_annot_PATH = '/PPG_SleepStaging/MESA_annot_preprocessed'


def to_np(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and moved to the CPU."""
    return x.detach().cpu().numpy()


#Training parameters
BATCH_SIZE = 2
LR = 1e-3
max_epoch = 100
# Number of folds. This used to be called ``k_fold``, but that name is rebound
# to the ``k_fold`` function further down the file, so the value was never
# readable under that name.
NUM_FOLDS = 3
fold_n = 1 ## hold-out --> fold_n = 1

# Labeling (str to int)
def sleep_label(annot_array):
    """Collapse raw stage codes into four classes, modifying ``annot_array`` in place.

    Mapping applied to each entry along the first axis:
    0 -> 0, 1 -> 1, 2 -> 1, 3 -> 2, 4 -> 2, 5 -> 3.
    Any other value is left untouched.

    Returns the same (mutated) object that was passed in.
    """
    # (raw code, merged class) pairs, checked in this order; first match wins.
    remap = ((0, 0), (1, 1), (2, 1), (3, 2), (4, 2), (5, 3))
    for position in range(np.shape(annot_array)[0]):
        for raw_code, merged_class in remap:
            if annot_array[position] == raw_code:
                annot_array[position] = merged_class
                break
    return annot_array

#k-fold
def k_fold(k,fold_num,file_path):
    """Split the entries of ``file_path`` into train and test subjects for one fold.

    The directory listing is sorted, the fold size is ``round(n / k)``, and
    when the listing is not a multiple of the fold size it is extended with
    its own leading entries until it is. ``fold_num`` is 1-based.

    Returns ``(train_subject_set, test_subject_set)`` as lists. The train list
    holds every entry of the (possibly extended) listing that is not in the
    test list, so it can contain repeated names.
    """
    subject_list = sorted(os.listdir(file_path))
    samples_in_fold = int(round(len(subject_list) / k))

    remainder = len(subject_list) % samples_in_fold
    if remainder != 0:
        # Wrap around: reuse the first subjects to fill up the last fold.
        supplement_num = samples_in_fold - remainder
        for position in range(supplement_num):
            subject_list.append(subject_list[position])

    fold_start = samples_in_fold * (fold_num - 1)
    fold_stop = samples_in_fold * fold_num
    test_subject_set = subject_list[fold_start:fold_stop]

    held_out = set(test_subject_set)
    train_subject_set = [name for name in subject_list if name not in held_out]
    return train_subject_set, test_subject_set

def truncateORpad(arr, annot=False):
    """Force the first axis of a 2-D array to a fixed length.

    With ``annot == False`` the target is 1228800 rows and short inputs are
    zero-padded at the end. With ``annot is True`` the target is 1200 rows and
    short inputs are padded at the end with the value 4. Longer inputs are cut
    to the target length; inputs already at the target length are returned
    unchanged. Any other ``annot`` value returns ``arr`` untouched.
    """
    if annot == False:
        target_len, fill_value = 1228800, 0
    elif annot is True:
        target_len, fill_value = 1200, 4
    else:
        return arr

    current_len = np.shape(arr)[0]
    if current_len < target_len:
        # Pad only at the end of the first axis; the second axis is untouched.
        return np.pad(arr, ((0, target_len - current_len), (0, 0)),
                      'constant', constant_values=fill_value)
    if current_len > target_len:
        return arr[:target_len]
    return arr

    
# DB load function
def _load_subject_labels(subject_list, ppg_data_path, annot_path):
    """Read and relabel the annotation file of every subject in ``subject_list``.

    Returns ``(labels_by_path, sample_paths)``: a dict mapping each PPG file
    path to its relabelled annotation array, and the list of those paths in
    subject order.
    """
    labels_by_path = {}
    sample_paths = []
    for subject in subject_list:
        # Annotation files share the PPG file's stem: "<stem>-profusion.csv".
        stem = subject.split('.')[0]
        annot_df = pd.read_csv(annot_path+'/'+stem+'-profusion.csv',header=None)
        sample_path = ppg_data_path+'/'+subject
        sample_paths.append(sample_path)
        labels_by_path[sample_path] = sleep_label(np.array(annot_df))
    return labels_by_path, sample_paths


def data_to_dict(ppg_data_path,annot_path, fold_num, k=3):
    """Build the train/test label dictionaries and sample path lists for one fold.

    Args:
        ppg_data_path: directory holding one PPG file per subject.
        annot_path: directory holding the matching ``-profusion.csv`` files.
        fold_num: 1-based index of the fold used as the test set.
        k: number of folds passed to ``k_fold`` (defaults to 3, the value that
            was previously hard-coded here).

    Returns:
        ``(train_class_dict, ppg_train_sample_list, test_class_dict,
        ppg_test_sample_list)``.
    """
    train_subject_list, test_subject_list = k_fold(k=k,fold_num=fold_num,file_path=ppg_data_path)

    train_class_dict, ppg_train_sample_list = _load_subject_labels(
        train_subject_list, ppg_data_path, annot_path)
    test_class_dict, ppg_test_sample_list = _load_subject_labels(
        test_subject_list, ppg_data_path, annot_path)
    return train_class_dict, ppg_train_sample_list, test_class_dict, ppg_test_sample_list


class PSG_DB:
    """Map-style dataset pairing each PPG file with its annotation array.

    ``category_dict`` maps a PPG file path to its label array and
    ``ppg_sample_dirs`` is the list of PPG file paths that defines the order
    and length of the dataset.
    """

    def __init__(self, category_dict, ppg_sample_dirs):
        self.category_dict = category_dict
        self.ppg_sample_dirs = ppg_sample_dirs

    def __len__(self):
        # One item per listed PPG file.
        return len(self.ppg_sample_dirs)

    def __getitem__(self, idx):
        """Load item ``idx`` from disk and return ``(signal, labels)`` at fixed lengths."""
        sample_path = self.ppg_sample_dirs[idx]
        # The signal is read from CSV on every access; nothing is cached.
        signal = truncateORpad(np.array(pd.read_csv(sample_path, header=None)))
        labels = truncateORpad(self.category_dict[sample_path], annot=True)
        return signal, labels
    
# Load Data
# Resolve the subject split for the selected fold and read its annotations.
(train_class, ppg_train_list,
 test_class, ppg_test_list) = data_to_dict(MESA_PPG_PATH, MESA_annot_PATH, fold_num=fold_n)

# Number of recordings on each side of the split.
train_total_data_num = len(ppg_train_list)
test_total_data_num = len(ppg_test_list)

# Datasets load the PPG signal lazily, one file per item.
train_dataset = PSG_DB(category_dict=train_class, ppg_sample_dirs=ppg_train_list)
test_dataset = PSG_DB(category_dict=test_class, ppg_sample_dirs=ppg_test_list)

# NOTE(review): neither loader shuffles, so training sees subjects in list order.
train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
test_data_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

class TimeDistributed(nn.Module):
    """Apply ``module`` to the last axis of an input, independently at every step.

    Inputs with at most two dimensions go straight through ``module``. Larger
    inputs are flattened to ``(-1, last_dim)``, passed through ``module`` and
    reshaped back according to ``batch_first``.
    """

    def __init__(self, module, batch_first=False):
        super().__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        if x.dim() <= 2:
            return self.module(x)

        # Merge all leading axes so the wrapped module sees a 2-D batch.
        flattened = x.contiguous().view(-1, x.size(-1))
        out = self.module(flattened)
        out_features = out.size(-1)

        if self.batch_first:
            # Restore the first axis of x; the middle axis is inferred.
            return out.contiguous().view(x.size(0), -1, out_features)
        # Restore the second axis of x; the first axis is inferred.
        return out.view(-1, x.size(1), out_features)
    
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` positions along the final axis of a 3-D input.

    Paired with a convolution padded by ``chomp_size`` on both sides, this
    removes the extra trailing outputs so the output length equals the input
    length.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self,x):
        if self.chomp_size == 0:
            # x[:, :, :-0] is x[:, :, :0], i.e. an empty tensor; with nothing
            # to trim the input must pass through unchanged.
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()
    
class TemporalBlock(nn.Module):
    """Residual block of two weight-normalised dilated convolutions.

    Each convolution is followed by a ``Chomp1d`` that removes ``padding``
    trailing steps, a ReLU and dropout. The input is added back to the branch
    output (through a 1x1 convolution when the channel counts differ) and the
    sum goes through a final ReLU.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.1):
        super().__init__()
        conv_options = dict(stride=stride, padding=padding, dilation=dilation)

        # First conv stage.
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_options))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Second conv stage, keeping the channel count.
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_options))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2,
        )

        # A 1x1 projection is only needed when the shortcut changes width.
        if n_inputs != n_outputs:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        else:
            self.downsample = None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Draw the convolution weights from N(0, 0.01)."""
        for conv in (self.conv1, self.conv2):
            conv.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        branch = self.net(x)
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)
        return self.relu(branch + shortcut)
    
class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock`` levels whose dilation doubles at each level.

    ``num_channels`` lists the output width of every level; level ``i`` uses
    dilation ``2 ** i`` and padding ``(kernel_size - 1) * 2 ** i``.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=20, dropout=0.1):
        super().__init__()
        blocks = []
        width_in = num_inputs
        for level, width_out in enumerate(num_channels):
            dilation = 2 ** level
            blocks.append(
                TemporalBlock(
                    n_inputs=width_in,
                    n_outputs=width_out,
                    kernel_size=kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )
            # The next level consumes what this level produces.
            width_in = width_out
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        return self.network(x)
    
class SleepStaging_NET(nn.Module):
    """Sleep-staging network: attention mask -> InceptionTime -> dense -> TCN.

    The forward pass returns three tensors: per-class log-probabilities, the
    sigmoid attention mask applied to the input, and the raw class scores
    before the log-softmax.

    ``input_size`` and ``output_size`` are accepted for interface
    compatibility but are not used; the layer sizes are fixed below.
    """

    def __init__(self,input_size,output_size):
        super(SleepStaging_NET, self).__init__()
        # Local attention mask
        self.causal_conv =  weight_norm(nn.Conv1d(1, 8, kernel_size=7168,
                                                  stride=1, padding=7167, dilation=1))
        self.chomp_att = Chomp1d(7167)
        self.downsample_att = nn.Conv1d(8, 1, 1)
        self.local_att = nn.Sequential(self.causal_conv, self.chomp_att, self.downsample_att)

        # InceptionTime
        self.conv1d = nn.Conv1d(1, 32, kernel_size=40, stride=20)
        self.relu = nn.ReLU()
        # (in_channels, n_filters, bottleneck_channels) for each block, in order.
        inception_specs = (
            (32, 8, 8),
            (32, 16, 16),
            (64, 16, 16),
            (64, 32, 16),
            (128, 64, 32),
            (256, 128, 32),
        )
        inception_blocks = [
            InceptionBlock(
                in_channels=in_channels,
                n_filters=n_filters,
                kernel_sizes=[5,11,23],
                bottleneck_channels=bottleneck_channels,
                use_residual=True,
                activation=nn.ReLU()
            )
            for in_channels, n_filters, bottleneck_channels in inception_specs
        ]
        self.InceptionTime = nn.Sequential(
            *inception_blocks,
            nn.AdaptiveAvgPool1d(output_size=1200)
        )
        self.conv1 = nn.Conv1d(512, 256, kernel_size=1)

        # Time-distributed dense layer
        self.tdd =TimeDistributed(nn.Linear(256,128),batch_first=True)
        # TCN
        self.tcn = TemporalConvNet(128, [64,64,64,64,64], kernel_size=8, dropout=0.2)
        self.downsample = nn.Conv1d(64, 4, 1)

    def forward(self, inputs):
        """Run the network on ``inputs`` of shape (batch_size, time_step, feature).

        Returns ``(log_probs, att, output)`` where ``output`` holds the raw
        class scores, ``log_probs`` is their log-softmax over the class axis
        and ``att`` is the attention mask that multiplied the input.
        """
        # input shape: (batch_size, time_step, feature)
        inputs = inputs.transpose(2, 1) # now input shape : (batch_size, feature, time_step)
        # One element-wise sigmoid over the whole mask. This replaces 1200
        # in-place sigmoid calls on consecutive 1024-sample windows, which
        # give the same values for inputs of up to 1200 * 1024 samples.
        att = torch.sigmoid(self.local_att(inputs)) # att shape : (batch_size, feature, time_step)
        inputs = inputs*att
        inputs = self.conv1d(inputs)
        inputs = self.relu(inputs)
        inception_output = self.InceptionTime(inputs)
        inception_output = self.conv1(inception_output)
        inception_output = inception_output.transpose(2, 1) # now inception_output shape : (batch_size, time_step, feature)
        tdd_output = self.tdd(inception_output) # tdd_output shape : (batch_size, time_step, feature)
        tdd_output = tdd_output.transpose(2, 1) # now tdd_output shape : (batch_size, feature, time_step)
        tcn_output = self.tcn(tdd_output) # tcn_output shape : (batch_size, feature, time_step)
        output = self.downsample(tcn_output) # output shape : (batch_size, feature, time_step)

        return F.log_softmax(output, dim=1), att, output

def get_energy(output_array):
    """Return the energy score ``-T * logsumexp(output_array / T)`` over axis 1, with T = 1.

    The result is a NumPy array with axis 1 of ``output_array`` reduced away.
    """
    temperature = 1
    scaled_logsumexp = temperature * torch.logsumexp(output_array / temperature, dim=1)
    return -to_np(scaled_logsumexp)
                                                                                                                     
def fit(batch_size,
        data,
        label,
        model,
        model_optimizer,
       inference=False):
    """Run one forward pass on a mini-batch and, unless ``inference``, one update.

    Args:
        batch_size: unused; kept so existing callers keep working.
        data: model input for the mini-batch.
        label: integer labels, one row per subject. The value 4 marks padding;
            everything from the first 4 onwards is excluded from the loss and
            accuracy. Presumably shaped (batch, time, 1) - confirm with caller.
        model: network returning ``(log_probs, attention, raw_scores)``.
        model_optimizer: optimizer stepping ``model``'s parameters.
        inference: when true, no backward pass or optimizer step is done and
            the model is put in evaluation mode.

    Returns:
        ``(label, pred_label, prediction_results, energy_results, attention,
        final_loss, corrects, L, final_feature)`` where ``final_loss`` is the
        summed per-subject NLL loss as a float, ``corrects`` the number of
        correctly classified non-padded steps and ``L`` the number of
        non-padded steps in the batch.
    """
    pad_label = 4

    # Dropout (and any normalisation layers) must behave differently for
    # training and evaluation, so select the mode on every call.
    model.train(not inference)
    # Always clear gradients, also for inference, so no stale gradient
    # survives for a later optimizer step.
    model_optimizer.zero_grad()

    # Call the module rather than .forward() so registered hooks run.
    prediction_results, attention, final_feature = model(data)

    gc.collect()
    torch.cuda.empty_cache()
    loss_function = nn.NLLLoss()

    mini_batch_size = label.shape[0]
    flat_label = label.reshape(mini_batch_size, -1)  # (batch, time)
    seq_len = flat_label.shape[1]

    # Number of leading non-padded steps per subject: the index of the first
    # padding label, or the full length when there is none.
    label_len = []
    for subject_labels in flat_label:
        pad_positions = torch.nonzero(subject_labels == pad_label)
        if pad_positions.numel() > 0:
            label_len.append(int(pad_positions[0]))
        else:
            label_len.append(seq_len)

    loss = 0
    corrects = 0
    for batch in range(mini_batch_size):
        valid_steps = label_len[batch]
        # shape: (timesteps, 4), 4 is class
        batch_prediction_results = prediction_results[batch, :, :valid_steps].transpose(1, 0)
        batch_label = flat_label[batch, :valid_steps]  # shape: (timesteps)
        loss = loss + loss_function(batch_prediction_results, batch_label)
        corrects = corrects + (torch.argmax(batch_prediction_results, 1) == batch_label).sum()
    L = sum(label_len)

    # Remove only the channel axis. A bare squeeze() would also drop the batch
    # axis for a batch holding a single subject.
    attention = attention.squeeze(1)
    if not inference:
        loss.backward()
        model_optimizer.step()

    final_loss = loss.item()
    # argmax over the class axis already yields (batch, time).
    pred_label = torch.argmax(prediction_results, dim=1)
    energy_results = get_energy(final_feature)

    gc.collect()
    torch.cuda.empty_cache()

    return label, pred_label, prediction_results, energy_results, attention, \
            final_loss, corrects, L, final_feature


# Run identifier: hyper-parameters, fold and a start timestamp.
log_name = 'InsightSleepNet(InceptionTime)MESA_batchsize_{0}_initlr_{1}_fold{2}'.format(
    str(BATCH_SIZE), str(LR), str(fold_n)) + time.strftime("_%b_%d_%H_%M", time.localtime())

# All outputs of this run are written below this folder.
saved_weights_folder = os.path.join('/PPG_SleepStaging/results/MESA',log_name)
# exist_ok avoids the check-then-create race of testing os.path.exists first.
os.makedirs(saved_weights_folder, exist_ok=True)

net = SleepStaging_NET(input_size=1, output_size=4).to(device)

optimizer = torch.optim.RMSprop(net.parameters(), lr=LR)
# The multiplier 1.0 ** epoch is always 1, so the learning rate stays at LR.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                  lr_lambda= lambda epoch: 1.0 ** epoch)

# Per-epoch metrics, appended to by the training loop.
train_loss_list = []
train_acc_list = []
test_loss_list = []
test_acc_list = []

# Loss normalisers: recordings per split divided by the batch size.
num_step_per_epoch_train = train_total_data_num/BATCH_SIZE
num_step_per_epoch_test = test_total_data_num/BATCH_SIZE

def _detach_to_numpy(tensor):
    """Return ``tensor`` as a NumPy array, detached from autograd and on the CPU."""
    return tensor.detach().cpu().numpy()


def _run_epoch(data_loader, split, save_dir, num_steps, inference):
    """Run one pass over ``data_loader`` and save the per-subject outputs.

    Args:
        data_loader: loader yielding ``(ppg_input, label)`` batches.
        split: file-name prefix for the saved arrays ('train' or 'test').
        save_dir: folder receiving ``<split>_att.npy``, ``<split>_softmax.npy``,
            ``<split>_feature.npy`` and ``<split>_pred_results.npy``.
        num_steps: divisor applied to the summed batch losses.
        inference: when true, gradients are disabled and no update is done.

    Returns:
        ``(epoch_loss, accuracy)`` as Python floats.
    """
    # Outputs are collected per batch and concatenated once at the end;
    # growing an array with np.concatenate inside the loop copies everything
    # gathered so far on every step.
    att_chunks = []
    softmax_chunks = []
    feature_chunks = []
    result_chunks = []
    loss_total = 0.0
    correct_total = 0
    scored_total = 0

    net.train(not inference)
    for ppg_input, label in data_loader:
        # Follow the configured device instead of hard-coding .cuda().
        ppg_input = ppg_input.to(device).float()
        label = label.to(device).long()

        with torch.set_grad_enabled(not inference):
            (label_result, pred_label, pred, energy, att,
             loss, batch_corrects, batch_L, feature) = fit(
                BATCH_SIZE, ppg_input, label, net, optimizer, inference=inference)

        # Reshape with an explicit batch axis so a final batch holding a
        # single subject stacks like any other batch.
        n_subjects = label_result.shape[0]
        att_chunks.append(_detach_to_numpy(att).reshape(n_subjects, -1))
        softmax_chunks.append(_detach_to_numpy(pred))
        feature_chunks.append(_detach_to_numpy(feature))
        # label, pred label, energy stacked on the last axis
        result_chunks.append(np.concatenate(
            (_detach_to_numpy(label_result).reshape(n_subjects, -1, 1),
             _detach_to_numpy(pred_label).reshape(n_subjects, -1, 1),
             np.asarray(energy).reshape(n_subjects, -1, 1)), axis=2))

        loss_total += loss
        correct_total += int(batch_corrects)
        scored_total += batch_L

        gc.collect()
        torch.cuda.empty_cache()

    if not att_chunks:
        raise ValueError("{} data loader produced no batches".format(split))

    os.makedirs(save_dir, exist_ok=True)
    # Arrays are saved as float64 to keep the dtype of previously written files.
    np.save(save_dir+"/"+split+"_att.npy",
            np.concatenate(att_chunks, axis=0).astype(np.float64))
    np.save(save_dir+"/"+split+"_softmax.npy",
            np.concatenate(softmax_chunks, axis=0).astype(np.float64))
    np.save(save_dir+"/"+split+"_feature.npy",
            np.concatenate(feature_chunks, axis=0).astype(np.float64))
    np.save(save_dir+"/"+split+"_pred_results.npy",
            np.concatenate(result_chunks, axis=0).astype(np.float64))

    # Plain floats, so the metric lists can later be turned into NumPy arrays
    # regardless of the device the model ran on.
    return loss_total/num_steps, float(correct_total)/float(scored_total)


for epoch_num in range(max_epoch):
    saved_by_epoch = os.path.join(saved_weights_folder,'epoch_{}'.format(epoch_num))

    # Training pass
    epoch_train_loss, avg_train_acc = _run_epoch(
        train_data_loader, 'train', saved_by_epoch, num_step_per_epoch_train, inference=False)
    print("epoch:{} ".format(epoch_num)+ " train loss: {}".format(epoch_train_loss))
    print("epoch:{} ".format(epoch_num)+ " train accuracy: {}".format(avg_train_acc))
    train_loss_list.append(epoch_train_loss)
    train_acc_list.append(avg_train_acc)

    # Evaluation pass
    epoch_test_loss, avg_test_acc = _run_epoch(
        test_data_loader, 'test', saved_by_epoch, num_step_per_epoch_test, inference=True)
    print("epoch:{} ".format(epoch_num)+ " test loss: {}".format(epoch_test_loss))
    print("epoch:{} ".format(epoch_num)+ " test accuracy: {}".format(avg_test_acc))
    test_loss_list.append(epoch_test_loss)
    test_acc_list.append(avg_test_acc)

    # The optimizer is stepped once per training batch inside fit(); no
    # further optimizer step belongs here, since no new gradients exist
    # after the evaluation pass.
    scheduler.step()

    print("learning rate is : {}".format(optimizer.param_groups[0]['lr']))

    torch.save(net.state_dict(), saved_by_epoch+'/model.pt')
    torch.save(net.state_dict(), saved_by_epoch+'/model.pth')


def _save_metric_csv(loss_list, acc_list, csv_path, header):
    """Write per-epoch loss and accuracy as a two-column CSV at ``csv_path``."""
    # float() also accepts 0-dim tensors on any device, which np.array cannot
    # convert directly when they live on the GPU.
    loss_np = np.array([float(value) for value in loss_list])
    acc_np = np.array([float(value) for value in acc_list])
    metric_results = np.concatenate((np.expand_dims(loss_np, axis=1),
                                     np.expand_dims(acc_np, axis=1)), axis=1)
    pd.DataFrame(metric_results).to_csv(csv_path, header=header, index=False)


# train, test set subject name save
# Take the names from the sample lists the loaders were built from, so the
# saved lists always describe the split that was actually trained on
# (recomputing the split with a different k would not).
train_subject_list = [os.path.basename(path) for path in ppg_train_list]
test_subject_list = [os.path.basename(path) for path in ppg_test_list]
train_subjects = np.array(train_subject_list)
test_subjects = np.array(test_subject_list)
np.save(saved_weights_folder+"/train_subjects.npy", train_subjects)
np.save(saved_weights_folder+"/test_subjects.npy", test_subjects)

# final loss, acc save
metric_header = ['loss', 'acc']
_save_metric_csv(train_loss_list, train_acc_list,
                 saved_weights_folder+'/'+'train_metric_results.csv', metric_header)
_save_metric_csv(test_loss_list, test_acc_list,
                 saved_weights_folder+'/'+'test_metric_results.csv', metric_header)

