In [None]:
import numpy as np
import argparse
import os
import imp
import re
import pickle
import datetime
import random
import math
import logging
import copy
import matplotlib.pyplot as plt
import sklearn
from sklearn.cluster import KMeans
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import kneighbors_graph

from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils import data
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import Parameter

from utils import utils
from utils.readers import InHospitalMortalityReader
from utils.preprocessing import Discretizer, Normalizer
from utils import metrics
from utils import common_utils

In [None]:
# Select the target dataset: COVID-19 Dataset from TJ Hospital or HM Hospital
target_dataset = 'TJ' # 'TJ' or 'HM'
# Fail fast on an unsupported value: later cells branch on exactly these two names
# and would otherwise fail much later with an unrelated NameError.
if target_dataset not in ('TJ', 'HM'):
    raise ValueError("target_dataset must be 'TJ' or 'HM', got {!r}".format(target_dataset))

# Use CUDA if available
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
print("available device: {}".format(device))

### Transfer Source Data & Model: 
#### Teacher Model


In [None]:
# Directory holding the pickled dataset files that the next cell loads.
data_path = './data/Challenge/'
small_part = False
arg_timestep = 1.0
# NOTE: batch_size and epochs are assigned again in the hyperparameter cell
# just before the teacher model is built; those later values are the ones used for training.
batch_size = 256
epochs = 100

In [None]:
def _load_pickle(name):
    """Load one pickled artifact located under ``data_path`` and close the file afterwards."""
    # NOTE(security): unpickling can execute arbitrary code -- only load trusted files.
    with open(os.path.join(data_path, name), 'rb') as handle:
        return pickle.load(handle)


all_x = _load_pickle('new_x_front_fill.dat')
all_y = _load_pickle('new_y_front_fill.dat')
all_names = _load_pickle('new_name.dat')
static = _load_pickle('new_demo_front_fill.dat')
mask_x = _load_pickle('new_mask_x.dat')
mask_demo = _load_pickle('new_mask_demo.dat')
# Number of time steps recorded for every sample.
all_x_len = [len(i) for i in all_x]

# Show the first sample of each artifact as a quick sanity check.
print(all_x[0])
print(mask_x[0])
print(all_names[0])
print(static[0])

In [None]:
# Feature columns that are moved to the front of every time step; the list depends
# on the selected target dataset.
if target_dataset == 'TJ':
    subset_idx = [27, 29, 18, 16, 26, 33, 28, 31, 32, 15, 11, 25, 21, 20, 9, 17, 30, 19]
elif target_dataset == 'HM':
    subset_idx = [9, 11, 15, 16, 17, 18, 19, 20, 21, 25, 26, 27, 28, 29, 32, 33, 14, 22, 23]
else:
    # Without this branch subset_idx would be undefined and the cell would fail
    # a few lines later with a confusing NameError.
    raise ValueError("target_dataset must be 'TJ' or 'HM', got {!r}".format(target_dataset))

subset_cnt = len(subset_idx)
# Every remaining column keeps its original relative order behind the subset columns.
subset_lookup = set(subset_idx)
other_idx = [i for i in range(len(all_x[0][0])) if i not in subset_lookup]
feature_order = subset_idx + other_idx

# NOTE: all_x and mask_x are rewritten in place, so running this cell a second time
# permutes the already permuted columns again. Re-run the loading cell first.
for i in range(len(all_x)):
    all_x[i] = np.array(all_x[i], dtype=float)[:, feature_order].tolist()
    mask_x[i] = np.array(mask_x[i])[:, feature_order].tolist()
print(all_x[0])
print(mask_x[0])
print(len(all_x[0][0]))

In [None]:
# Sequential 80 / 10 / 10 split (no shuffling): train first, then dev, then test.
train_num = int(len(all_x) * 0.8) + 1
print(train_num)
dev_num = int(len(all_x) * 0.1) + 1
print(dev_num)
# The test split takes whatever is left. Computing it independently as
# int(len * 0.1) only sums up to len(all_x) for some dataset sizes.
test_num = len(all_x) - train_num - dev_num
print(test_num)
assert test_num >= 0, 'dataset is too small for an 80/10/10 split'
assert(train_num+dev_num+test_num == len(all_x))


def _take_split(start, stop):
    """Return (x, y, names, static, x_len, mask_x) lists for the samples in [start, stop).

    The label of a sample is the last entry of its label sequence, cast to int.
    """
    indices = range(start, stop)
    return (
        [all_x[idx] for idx in indices],
        [int(all_y[idx][-1]) for idx in indices],
        [all_names[idx] for idx in indices],
        [static[idx] for idx in indices],
        [all_x_len[idx] for idx in indices],
        [mask_x[idx] for idx in indices],
    )


(train_x, train_y, train_names,
 train_static, train_x_len, train_mask_x) = _take_split(0, train_num)

(dev_x, dev_y, dev_names,
 dev_static, dev_x_len, dev_mask_x) = _take_split(train_num, train_num + dev_num)

(test_x, test_y, test_names,
 test_static, test_x_len, test_mask_x) = _take_split(train_num + dev_num,
                                                     train_num + dev_num + test_num)

assert(len(train_x) == train_num)
assert(len(dev_x) == dev_num)
assert(len(test_x) == test_num)

In [None]:
# Un-split view of the data: every sequence together with the final entry of its label sequence.
# long_x is an alias of all_x, not a copy.
long_x = all_x
long_y = [label_seq[-1] for label_seq in all_y]


In [None]:
def get_loss(y_pred, y_true):
    """Mean binary cross-entropy between predicted probabilities and targets."""
    return torch.nn.functional.binary_cross_entropy(y_pred, y_true)

In [None]:
def get_re_loss(y_pred, y_true):
    """Mean squared error between ``y_pred`` and ``y_true``."""
    return torch.nn.functional.mse_loss(y_pred, y_true)

In [None]:
def get_kl_loss(x_pred, x_target):
    """KL divergence loss averaged over all elements.

    ``x_pred`` is expected to hold log-probabilities and ``x_target`` probabilities
    (the convention of ``torch.nn.KLDivLoss``).
    """
    # reduction='mean' is the documented replacement for the deprecated
    # reduce=True / size_average=True pair and yields the same value.
    loss = torch.nn.KLDivLoss(reduction='mean')
    return loss(x_pred, x_target)

In [None]:
def get_wass_dist(x_pred, x_target):
    """Distance built from per-column means and standard deviations of two batches.

    Returns sqrt(sum((mean_p - mean_t)^2) + sum((std_p - std_t)^2)), where the
    statistics are taken along dim 0.
    """
    mean_gap = x_pred.mean(dim=0) - x_target.mean(dim=0)
    std_gap = torch.pow(x_pred.var(dim=0), 1/2) - torch.pow(x_target.var(dim=0), 1/2)
    squared_total = torch.sum(torch.pow(mean_gap, 2)) + torch.sum(torch.pow(std_gap, 2))
    return torch.pow(squared_total, 1/2)

In [None]:
def pad_sents(sents, pad_token):
    """Right-pad every sequence in ``sents`` with ``pad_token`` up to the longest length.

    Returns a numpy array stacking the padded sequences.
    """
    longest = max(len(seq) for seq in sents)
    return np.array([
        np.array(list(seq) + [pad_token] * (longest - len(seq)))
        for seq in sents
    ])

In [None]:
def batch_iter(x, y, mask, lens, batch_size, shuffle=False):
    """ Yield mini-batches whose samples are ordered from the longest to the shortest sequence.
    @param x (list): input sequences; len(x[i]) is the sort key inside a batch
    @param y (list): labels aligned with x
    @param mask (list): masks aligned with x
    @param lens (list): sequence lengths aligned with x
    @param batch_size (int): batch size
    @param shuffle (boolean): whether to randomly shuffle the dataset
    @returns generator of (batch_x, batch_y, batch_mask_x, batch_lens), each a list
    """
    total = len(x)
    order = list(range(total))
    if shuffle:
        np.random.shuffle(order)

    # Round up so that a final, smaller batch is still produced.
    n_batches = math.ceil(total / batch_size)
    for batch_no in range(n_batches):
        chunk = order[batch_no * batch_size:(batch_no + 1) * batch_size]
        # Stable sort: samples of equal length keep their relative order.
        chunk = sorted(chunk, key=lambda idx: len(x[idx]), reverse=True)

        yield ([x[idx] for idx in chunk],
               [y[idx] for idx in chunk],
               [mask[idx] for idx in chunk],
               [lens[idx] for idx in chunk])

In [None]:
def length_to_mask(length, max_len=None, dtype=None):
    """Turn a 1-D tensor of lengths (B) into a B x max_len validity mask.

    Position j of row i is set when j < length[i]. If max_len is None (or 0),
    the maximum of ``length`` is used. The mask is cast to ``dtype`` when given.
    """
    assert len(length.shape) == 1, 'Length shape should be 1 dimensional.'
    if not max_len:
        max_len = length.max().item()
    positions = torch.arange(max_len, device=length.device, dtype=length.dtype)
    positions = positions.expand(len(length), max_len)
    mask = positions < length.unsqueeze(1)
    if dtype is None:
        return mask
    return torch.as_tensor(mask, dtype=dtype, device=length.device)

In [None]:
class SingleAttention(nn.Module):
    """Attention over time steps that uses the last time step as the query.

    Args:
        attention_input_dim: feature size I of the input (B x T x I).
        attention_hidden_dim: hidden size H of the scoring network.
        attention_type: 'add', 'mul' or 'concat'; anything else raises RuntimeError.
        demographic_dim: size of the optional demographic vector ('add' only).
        time_aware: when truthy, the distance (in steps) to the last time step is
            fed into the 'add' and 'concat' scoring functions.
        use_demographic: when truthy, 'add' projects ``demo`` (see the note in forward).

    forward returns ``(v, a)``: the attended vector (B x I) and the attention
    weights (B x T).
    """

    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', demographic_dim=12, time_aware=False, use_demographic=False):
        super(SingleAttention, self).__init__()

        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim
        self.use_demographic = use_demographic
        self.demographic_dim = demographic_dim
        self.time_aware = time_aware

        # Parameter creation order is kept as-is so seeded initialisation is unchanged.
        if attention_type == 'add':
            self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            if self.time_aware:
                self.Wtime_aware = nn.Parameter(torch.randn(1, attention_hidden_dim))
                nn.init.kaiming_uniform_(self.Wtime_aware, a=math.sqrt(5))
            self.Wt = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wd = nn.Parameter(torch.randn(demographic_dim, attention_hidden_dim))
            self.bh = nn.Parameter(torch.zeros(attention_hidden_dim,))
            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1,))

            nn.init.kaiming_uniform_(self.Wd, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == 'mul':
            self.Wa = nn.Parameter(torch.randn(attention_input_dim, attention_input_dim))
            self.ba = nn.Parameter(torch.zeros(1,))

            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == 'concat':
            # One extra input column carries the time decay when time_aware is on.
            if self.time_aware:
                self.Wh = nn.Parameter(torch.randn(2*attention_input_dim+1, attention_hidden_dim))
            else:
                self.Wh = nn.Parameter(torch.randn(2*attention_input_dim, attention_hidden_dim))

            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1,))

            nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        else:
            raise RuntimeError('Wrong attention type.')

        self.tanh = nn.Tanh()
        # Scores are always B x T, so normalise explicitly over the time axis.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input, demo=None):
        """Compute attention weights over the time axis of ``input`` (B x T x I)."""
        batch_size, time_step, input_dim = input.size()
        # Steps remaining until the last time step: T-1, ..., 0. Built on the input's
        # own device so the module does not depend on a global `device`.
        time_decays = torch.arange(time_step - 1, -1, -1, dtype=torch.float32,
                                   device=input.device).unsqueeze(-1).unsqueeze(0)  # 1 t 1
        b_time_decays = time_decays.repeat(batch_size, 1, 1)  # b t 1

        if self.attention_type == 'add':
            q = torch.matmul(input[:, -1, :], self.Wt)  # b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim))  # b 1 h
            k = torch.matmul(input, self.Wx)  # b t h
            if self.time_aware:
                time_hidden = torch.matmul(b_time_decays, self.Wtime_aware)  # b t h
            if self.use_demographic:
                # NOTE(review): the demographic projection is computed but never added
                # to the score below; kept as-is to preserve existing behaviour.
                d = torch.matmul(demo, self.Wd)  # b h
                d = torch.reshape(d, (batch_size, 1, self.attention_hidden_dim))  # b 1 h
            h = q + k + self.bh  # b t h
            if self.time_aware:
                h += time_hidden
            h = self.tanh(h)  # b t h
            e = torch.matmul(h, self.Wa) + self.ba  # b t 1
            e = torch.reshape(e, (batch_size, time_step))  # b t
        elif self.attention_type == 'mul':
            e = torch.matmul(input[:, -1, :], self.Wa)  # b i
            # squeeze(1) removes only the singleton query axis; a bare squeeze()
            # would also drop the batch axis when batch_size == 1.
            e = torch.matmul(e.unsqueeze(1), input.permute(0, 2, 1)).squeeze(1) + self.ba  # b t
        elif self.attention_type == 'concat':
            q = input[:, -1, :].unsqueeze(1).repeat(1, time_step, 1)  # b t i
            k = input
            c = torch.cat((q, k), dim=-1)  # b t 2i
            if self.time_aware:
                c = torch.cat((c, b_time_decays), dim=-1)  # b t 2i+1
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba  # b t 1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        a = self.softmax(e)  # b t
        v = torch.matmul(a.unsqueeze(1), input).squeeze(1)  # b i

        return v, a

class FinalAttentionQKV(nn.Module):
    """Query/key/value attention that pools a sequence into a single vector.

    The query is the projection of the mean over the sequence axis; keys and values
    are per-position projections of the input.

    Args:
        attention_input_dim: feature size of the input (B x T x I).
        attention_hidden_dim: size H of the query/key/value projections.
        attention_type: 'add', 'mul' or 'concat' scoring function.
        dropout: dropout probability applied to the attention weights, or None
            to disable dropout.

    forward returns ``(v, a)``: the pooled vector (B x H) and the weights (B x T).
    """

    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', dropout=None):
        super(FinalAttentionQKV, self).__init__()

        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim

        self.W_q = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_k = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_v = nn.Linear(attention_input_dim, attention_hidden_dim)

        self.W_out = nn.Linear(attention_hidden_dim, 1)

        self.b_in = nn.Parameter(torch.zeros(1,))
        self.b_out = nn.Parameter(torch.zeros(1,))

        nn.init.kaiming_uniform_(self.W_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_v.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_out.weight, a=math.sqrt(5))

        # NOTE(review): 'concat' concatenates two H-sized projections, so this shape
        # only fits when attention_input_dim == attention_hidden_dim. The shape is
        # left untouched so existing checkpoints keep loading.
        self.Wh = nn.Parameter(torch.randn(2*attention_input_dim, attention_hidden_dim))
        self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
        self.ba = nn.Parameter(torch.zeros(1,))

        nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))

        # nn.Dropout(p=None) raises, so the documented default must map to "no dropout".
        self.dropout = nn.Dropout(p=dropout) if dropout is not None else None
        self.tanh = nn.Tanh()
        # Scores are always B x T, so normalise explicitly over the last axis.
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Pool ``input`` (B x T x I) into a B x H vector with attention weights B x T."""
        batch_size, time_step, input_dim = input.size()
        input_q = self.W_q(torch.mean(input, dim=1))  # b h
        input_k = self.W_k(input)  # b t h
        input_v = self.W_v(input)  # b t h

        if self.attention_type == 'add':
            q = torch.reshape(input_q, (batch_size, 1, self.attention_hidden_dim))  # b 1 h
            h = q + input_k + self.b_in  # b t h
            h = self.tanh(h)  # b t h
            e = self.W_out(h)  # b t 1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        elif self.attention_type == 'mul':
            q = torch.reshape(input_q, (batch_size, self.attention_hidden_dim, 1))  # b h 1
            # squeeze(-1) drops only the trailing singleton; a bare squeeze() would also
            # drop the batch axis for a batch of one and break the matmul below.
            e = torch.matmul(input_k, q).squeeze(-1)  # b t

        elif self.attention_type == 'concat':
            q = input_q.unsqueeze(1).repeat(1, time_step, 1)  # b t h
            k = input_k
            c = torch.cat((q, k), dim=-1)  # b t 2h
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba  # b t 1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        a = self.softmax(e)  # b t
        if self.dropout is not None:
            a = self.dropout(a)
        v = torch.matmul(a.unsqueeze(1), input_v).squeeze(1)  # b h

        return v, a

def clones(module, N):
    "Build a ModuleList holding N independent deep copies of `module`."
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas

def tile(a, dim, n_tile):
    """Repeat every slice of ``a`` along ``dim`` ``n_tile`` times, keeping copies adjacent.

    Example: tiling rows [r0, r1] with n_tile=2 gives [r0, r0, r1, r1].
    The result lives on the same device as ``a`` (no dependency on a global device).
    """
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*repeat_idx)  # blocks laid out as [r0, r1, r0, r1, ...]
    # Gather index i, i + init_dim, i + 2 * init_dim, ... for each original slice i
    # so that the copies of one slice end up next to each other.
    order_index = torch.arange(init_dim * n_tile, device=a.device).view(n_tile, init_dim).t().reshape(-1)
    return torch.index_select(a, dim, order_index)

class PositionwiseFeedForward(nn.Module): # new added
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    forward returns a ``(output, None)`` pair so it matches the two-value
    return convention consumed by SublayerConnection.
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden), None

    
class PositionalEncoding(nn.Module): # new added / not use anymore
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Args:
        d_model: feature size of the input (even or odd).
        dropout: dropout probability applied after adding the encoding.
        max_len: largest sequence length that can be encoded.
    """
    def __init__(self, d_model, dropout, max_len=400):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns;
        # for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term)[:, :d_model // 2]
        pe = pe.unsqueeze(0)
        # A buffer moves with the module across devices and is never trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Buffers do not require grad, so the deprecated Variable wrapper is unnecessary.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

def subsequent_mask(size):
    "Boolean (1, size, size) mask that is False strictly above the diagonal, True elsewhere."
    future = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    return torch.from_numpy(future) == 0

def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Scores are query @ key^T / sqrt(d_k). Positions where ``mask == 0`` are set to
    -1e9 before the softmax over the last axis. ``dropout`` (a callable) is applied
    to the weights when given. Returns ``(weights @ value, weights)``.
    """
    scale = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value), weights
    
class MultiHeadedAttention(nn.Module):
    """Multi-head attention that additionally returns a DeCov regulariser.

    Args:
        h: number of heads; must divide ``d_model``.
        d_model: model (feature) size.
        dropout: dropout probability applied to the attention weights.
        decov_positions: how many leading sequence positions contribute to the
            DeCov loss. Defaults to 11 (the previously hard-coded value); ``None``
            uses every position. The value is clamped to the available positions,
            so shorter inputs no longer raise an IndexError.

    forward returns ``(output, DeCov_loss)``.
    """
    def __init__(self, h, d_model, dropout=0, decov_positions=11):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.decov_positions = decov_positions
        self.linears = clones(nn.Linear(d_model, self.d_k * self.h), 3)
        self.final_linear = nn.Linear(d_model, d_model)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)

        nbatches = query.size(0)

        # Project, then split the feature axis into h heads: b, h, seq, d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)  # b, h, seq, d_k

        # Merge the heads back: b, seq, h * d_k
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)

        # DeCov: penalise the off-diagonal entries of the covariance matrix that
        # `cov` computes for each position's (features x batch) slice.
        DeCov_contexts = x.transpose(0, 1).transpose(1, 2)  # seq, features, batch
        num_positions = DeCov_contexts.size(0)
        if self.decov_positions is None:
            limit = num_positions
        else:
            limit = min(self.decov_positions, num_positions)

        DeCov_loss = x.new_zeros(())
        for i in range(limit):
            Covs = cov(DeCov_contexts[i, :, :])
            DeCov_loss = DeCov_loss + 0.5 * (torch.norm(Covs, p = 'fro')**2 - torch.norm(torch.diag(Covs))**2 )

        return self.final_linear(x), DeCov_loss

class LayerNorm(nn.Module):
    """Normalise the last axis to zero mean / unit std, then apply a learned scale and shift.

    ``a_2`` (scale, initialised to ones) and ``b_2`` (shift, initialised to zeros) have
    ``features`` entries; ``eps`` is added to the standard deviation.
    """
    def __init__(self, features, eps=1e-7):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2

def cov(m, y=None):
    """Covariance matrix of the rows of ``m`` (rows are variables, columns observations).

    When ``y`` is given its rows are appended to ``m`` first. Uses the unbiased
    1 / (n - 1) normalisation, where n is the number of columns.
    """
    if y is not None:
        m = torch.cat((m, y), dim=0)
    centered = m - torch.mean(m, dim=1)[:, None]
    scale = 1 / (centered.size(1) - 1)
    return scale * centered.mm(centered.t())

class SublayerConnection(nn.Module):
    """
    Residual wrapper: the sublayer receives the layer-normalised input and its
    (dropped-out) first output is added back onto the un-normalised input.
    The sublayer must return an indexable pair; its second item is passed through.
    """
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Return (x + dropout(sublayer_output), sublayer_extra)."
        outcome = sublayer(self.norm(x))
        residual = x + self.dropout(outcome[0])
        return residual, outcome[1]

class distcare_teacher(nn.Module):
    """Teacher network: one GRU per input feature, self-attention across features,
    attention pooling and a sigmoid output head.

    Args:
        input_dim: number of input features (one GRU is created per feature).
        hidden_dim: GRU hidden size.
        d_model: model size of the attention / feed-forward blocks.
        MHD_num_head: number of attention heads; must divide ``d_model``.
        d_ff: inner size of the position-wise feed-forward block.
        output_dim: size of the output layer.
        keep_prob: keep probability; dropout layers use ``1 - keep_prob``.

    forward returns ``(output, DeCov_loss, weighted_contexts)``.
    """
    def __init__(self, input_dim, hidden_dim, d_model,  MHD_num_head, d_ff, output_dim, keep_prob=0.5):
        super(distcare_teacher, self).__init__()

        # hyperparameters
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim  # d_model
        self.d_model = d_model
        self.MHD_num_head = MHD_num_head
        self.d_ff = d_ff
        self.output_dim = output_dim
        self.keep_prob = keep_prob

        # layers (creation order and attribute names are kept so that seeded
        # initialisation and saved state_dict keys stay unchanged)
        self.PositionalEncoding = PositionalEncoding(self.d_model, dropout = 0, max_len = 400)

        self.GRUs = clones(nn.GRU(1, self.hidden_dim, batch_first = True), self.input_dim)
        self.LastStepAttentions = clones(SingleAttention(self.hidden_dim, 8, attention_type='concat', demographic_dim=12, time_aware=True, use_demographic=False),self.input_dim)

        self.FinalAttentionQKV = FinalAttentionQKV(self.hidden_dim, self.hidden_dim, attention_type='mul',dropout = 1 - self.keep_prob)

        self.MultiHeadedAttention = MultiHeadedAttention(self.MHD_num_head, self.d_model,dropout = 1 - self.keep_prob)
        self.SublayerConnection = SublayerConnection(self.d_model, dropout = 1 - self.keep_prob)

        self.PositionwiseFeedForward = PositionwiseFeedForward(self.d_model, self.d_ff, dropout=0.1)

        self.demo_proj_main = nn.Linear(12, self.hidden_dim)
        self.demo_proj = nn.Linear(12, self.hidden_dim)
        self.output = nn.Linear(self.hidden_dim, self.output_dim)

        self.dropout = nn.Dropout(p = 1 - self.keep_prob)
        self.tanh=nn.Tanh()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.relu=nn.ReLU()

    def forward(self, input, lens):
        """Run the teacher on a padded batch.

        Args:
            input: tensor of shape [batch_size, timestep, feature_dim].
            lens: true sequence lengths, sorted in decreasing order (a requirement of
                ``pack_padded_sequence`` with its default ``enforce_sorted=True``).
        """
        feature_dim = input.size(2)
        assert(feature_dim == self.input_dim)
        assert(self.d_model % self.MHD_num_head == 0)

        # pack_padded_sequence requires the lengths to live on the CPU.
        if torch.is_tensor(lens):
            lens = lens.cpu()

        # Encode every feature's time series with its own GRU and keep the final
        # hidden state. squeeze(0) removes only the layer axis; a bare squeeze()
        # would also drop the batch axis for a batch of one.
        per_feature = []
        for i in range(feature_dim):
            packed = pack_padded_sequence(input[:,:,i].unsqueeze(-1), lens, batch_first=True)
            last_hidden = self.GRUs[i](packed)[1]  # 1 b h
            per_feature.append(last_hidden.squeeze(0).unsqueeze(1))  # b 1 h
        GRU_embeded_input = torch.cat(per_feature, 1)  # b i h

        posi_input = self.dropout(GRU_embeded_input) # batch_size * d_input * hidden_dim

        # NOTE(review): both lambdas ignore their argument, so the layer-normalised
        # input computed inside SublayerConnection is discarded. Left unchanged
        # because altering it would change the behaviour of trained checkpoints.
        contexts = self.SublayerConnection(posi_input, lambda x: self.MultiHeadedAttention(posi_input, posi_input, posi_input, None))# # batch_size * d_input * hidden_dim

        DeCov_loss = contexts[1]
        contexts = contexts[0]

        contexts = self.SublayerConnection(contexts, lambda x: self.PositionwiseFeedForward(contexts))[0]# # batch_size * d_input * hidden_dim

        weighted_contexts = self.FinalAttentionQKV(contexts)[0]
        output = self.output(self.dropout(weighted_contexts))# b 1
        output = self.sigmoid(output)

        return output, DeCov_loss  ,weighted_contexts




In [None]:
# Seed every random number generator used below. This must run before the model is
# constructed, because parameter initialisation draws from these generators.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED) # numpy (also drives the shuffling in batch_iter)
random.seed(RANDOM_SEED) # Python's built-in random module
torch.manual_seed(RANDOM_SEED) # cpu
torch.cuda.manual_seed(RANDOM_SEED) # gpu
torch.backends.cudnn.deterministic=True # cudnn
    
# Training hyperparameters; epochs and batch_size replace the values set in the earlier config cell.
epochs = 150
batch_size = 256
# Model hyperparameters, passed by keyword to distcare_teacher below.
input_dim = 34
hidden_dim = 32
d_model = 32
MHD_num_head = 4
d_ff = 64
output_dim = 1

model = distcare_teacher(input_dim = input_dim, hidden_dim = hidden_dim, d_model=d_model, MHD_num_head=MHD_num_head, d_ff=d_ff, output_dim = output_dim).to(device)
# Adam over all model parameters with a fixed learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Training Teacher
# If you don't want to train the Teacher Model:
# - The pretrained teacher model is in the directory './model/' and can be loaded directly.
# - Simply skip this cell and load the model in the next cell.


total_train_loss = []
total_valid_loss = []
global_best = 0
auroc = []
auprc = []
minpse = []
history = []

# Vector used to pad shorter sequences up to the longest one in a batch.
pad_token = np.zeros(input_dim)
best_auroc = 0
best_auprc = 0
best_minpse = 0

# Checkpoint path of the teacher for the selected target dataset.
if target_dataset == 'TJ':
    file_name = './model/pretrained-challenge-front-fill-teacher-2covid'
elif target_dataset == 'HM':
    file_name = './model/pretrained-challenge-front-fill-teacher-2spain'
else:
    raise ValueError("target_dataset must be 'TJ' or 'HM', got {!r}".format(target_dataset))

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(file_name), exist_ok=True)

for each_epoch in range(epochs):

    epoch_loss = []
    counter_batch = 0
    model.train()

    for step, (batch_x, batch_y, batch_mask_x, batch_lens) in enumerate(batch_iter(train_x, train_y, train_mask_x, train_x_len, batch_size, shuffle=True)):
        optimizer.zero_grad()
        batch_x = torch.tensor(pad_sents(batch_x, pad_token), dtype=torch.float32).to(device)
        batch_y = torch.tensor(batch_y, dtype=torch.float32).to(device)
        batch_lens = torch.tensor(batch_lens, dtype=torch.float32).to(device).int()
        batch_mask_x = torch.tensor(pad_sents(batch_mask_x, pad_token), dtype=torch.float32).to(device)

        opt, decov_loss, emb = model(batch_x, batch_lens)

        BCE_Loss = get_loss(opt, batch_y.unsqueeze(-1))

        # The DeCov term is returned by the model but not part of the objective here.
        loss = BCE_Loss #+ 1000 * decov_loss

        epoch_loss.append(BCE_Loss.cpu().detach().numpy())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 20)
        optimizer.step()

        if step % 20 == 0:
            print('Epoch %d Batch %d: Train Loss = %.4f'%(each_epoch, step, loss.cpu().detach().numpy()))

    epoch_loss = np.mean(epoch_loss)
    total_train_loss.append(epoch_loss)

    #Validation
    y_true = []
    y_pred = []
    with torch.no_grad():
        model.eval()
        valid_loss = []
        valid_true = []
        valid_pred = []
        for batch_x, batch_y, batch_mask_x, batch_lens in batch_iter(dev_x, dev_y, dev_mask_x, dev_x_len, batch_size):
            batch_x = torch.tensor(pad_sents(batch_x, pad_token), dtype=torch.float32).to(device)
            batch_y = torch.tensor(batch_y, dtype=torch.float32).to(device)
            batch_lens = torch.tensor(batch_lens, dtype=torch.float32).to(device).int()
            batch_mask_x = torch.tensor(pad_sents(batch_mask_x, pad_token), dtype=torch.float32).to(device)

            opt, decov_loss, emb = model(batch_x, batch_lens)

            BCE_Loss = get_loss(opt, batch_y.unsqueeze(-1))

            valid_loss.append(BCE_Loss.cpu().detach().numpy())

            y_pred += list(opt.cpu().detach().numpy().flatten())
            y_true += list(batch_y.cpu().numpy().flatten())

        valid_loss = np.mean(valid_loss)
        total_valid_loss.append(valid_loss)
        ret = metrics.print_metrics_binary(y_true, y_pred,verbose = 0)
        history.append(ret)

        print('Epoch %d: Loss = %.4f Valid loss = %.4f roc = %.4f'%(each_epoch, total_train_loss[-1], total_valid_loss[-1], ret['auroc']))
        metrics.print_metrics_binary(y_true, y_pred)

        # Keep only the checkpoint with the best validation AUROC so far.
        cur_auroc = ret['auroc']
        if cur_auroc > best_auroc:
            best_auroc = cur_auroc
            best_auprc = ret['auprc']
            best_minpse = ret['minpse']
            state = {
                'net': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': each_epoch
            }
            torch.save(state, file_name)
            print('------------ Save best model - AUROC: %.4f ------------'%cur_auroc)

print('auroc %.4f'%(best_auroc))
print('auprc %.4f'%(best_auprc))
print('minpse %.4f'%(best_minpse))

In [None]:
# Checkpoint path of the teacher for the selected target dataset.
if target_dataset == 'TJ':
    file_name = './model/pretrained-challenge-front-fill-teacher-2covid'
elif target_dataset == 'HM':
    file_name = './model/pretrained-challenge-front-fill-teacher-2spain'
else:
    raise ValueError("target_dataset must be 'TJ' or 'HM', got {!r}".format(target_dataset))

# Map the stored tensors onto the device selected at the top of the notebook.
# NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(file_name, map_location=device)
save_epoch = checkpoint['epoch']
print("last saved model is in epoch {}".format(save_epoch))
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
model.eval()


In [None]:
# Evaluate the loaded teacher on the held-out test split.
batch_loss = []
y_true = []
y_pred = []
# Use the model's feature count instead of repeating the literal 34.
pad_token = np.zeros(input_dim)
with torch.no_grad():
    model.eval()
    # NOTE(review): shuffle=True is kept from the original; it is not needed for
    # evaluation and consumes numpy RNG state.
    for step, (batch_x, batch_y, batch_mask_x, batch_lens) in enumerate(batch_iter(test_x, test_y, test_mask_x, test_x_len, batch_size, shuffle=True)):
        batch_x = torch.tensor(pad_sents(batch_x, pad_token), dtype=torch.float32).to(device)
        batch_y = torch.tensor(batch_y, dtype=torch.float32).to(device)
        batch_lens = torch.tensor(batch_lens, dtype=torch.float32).to(device).int()
        batch_mask_x = torch.tensor(pad_sents(batch_mask_x, pad_token), dtype=torch.float32).to(device)

        opt, decov_loss, emb = model(batch_x, batch_lens)

        BCE_Loss = get_loss(opt, batch_y.unsqueeze(-1))

        model_loss =  BCE_Loss

        loss = model_loss
        batch_loss.append(loss.cpu().detach().numpy())
        if step % 20 == 0:
            print('Batch %d: Test Loss = %.4f'%(step, loss.cpu().detach().numpy()))
        y_pred += list(opt.cpu().detach().numpy().flatten())
        y_true += list(batch_y.cpu().numpy().flatten())

print("\n==>Predicting on test")
print('Test Loss = %.4f'%(np.mean(np.array(batch_loss))))
# Two-column form: column 0 is 1 - p, column 1 is the predicted probability p.
y_pred = np.array(y_pred)
y_pred = np.stack([1 - y_pred, y_pred], axis=1)
test_res = metrics.print_metrics_binary(y_true, y_pred)

#### Student Model

In [None]:
class distcare_student(nn.Module):
    """Student network: one GRU per input feature, then attention across the feature embeddings.

    Every feature channel is encoded by its own single-input GRU; the final hidden
    states are stacked into a (batch, feature, hidden) tensor, passed through
    multi-head self-attention and a position-wise feed-forward sublayer, pooled by
    ``FinalAttentionQKV`` and mapped to a sigmoid output.

    Several layers registered in ``__init__`` (``PositionalEncoding``,
    ``LastStepAttentions``, ``demo_proj_main``, ``demo_proj``) are not used by
    ``forward``. They are kept, in the original order, so that the parameter names
    in ``state_dict()`` keep matching previously saved checkpoints.

    Args:
        input_dim: number of input features (one GRU is created per feature).
        hidden_dim: hidden size of every per-feature GRU.
        d_model: model width used by the attention / feed-forward sublayers;
            must be divisible by ``MHD_num_head``.
        MHD_num_head: number of attention heads.
        d_ff: inner width of the position-wise feed-forward layer.
        output_dim: size of the final output layer.
        keep_prob: keep probability; dropout layers use ``p = 1 - keep_prob``.
    """
    def __init__(self, input_dim, hidden_dim, d_model,  MHD_num_head, d_ff, output_dim, keep_prob=0.5):
        super(distcare_student, self).__init__()

        # hyperparameters
        self.input_dim = input_dim  
        self.hidden_dim = hidden_dim  # d_model
        self.d_model = d_model
        self.MHD_num_head = MHD_num_head
        self.d_ff = d_ff
        self.output_dim = output_dim
        self.keep_prob = keep_prob

        # layers
        self.PositionalEncoding = PositionalEncoding(self.d_model, dropout = 0, max_len = 400)

        self.GRUs = clones(nn.GRU(1, self.hidden_dim, batch_first = True), self.input_dim)
        self.LastStepAttentions = clones(SingleAttention(self.hidden_dim, 8, attention_type='concat', demographic_dim=12, time_aware=True, use_demographic=False),self.input_dim)
        
        self.FinalAttentionQKV = FinalAttentionQKV(self.hidden_dim, self.hidden_dim, attention_type='mul',dropout = 1 - self.keep_prob)

        self.MultiHeadedAttention = MultiHeadedAttention(self.MHD_num_head, self.d_model,dropout = 1 - self.keep_prob)
        self.SublayerConnection = SublayerConnection(self.d_model, dropout = 1 - self.keep_prob)

        self.PositionwiseFeedForward = PositionwiseFeedForward(self.d_model, self.d_ff, dropout=0.1)

        self.demo_proj_main = nn.Linear(12, self.hidden_dim)
        self.demo_proj = nn.Linear(12, self.hidden_dim)
        self.output = nn.Linear(self.hidden_dim, self.output_dim)

        self.dropout = nn.Dropout(p = 1 - self.keep_prob)
        self.FC_embed = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.tanh=nn.Tanh()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.relu=nn.ReLU()

    def forward(self, input, lens):
        """Run the student on a padded batch.

        Args:
            input: tensor of shape [batch_size, timestep, feature_dim]; feature_dim
                must equal ``self.input_dim``.
            lens: true sequence length of every batch element (tensor or list).
                Passed to ``pack_padded_sequence`` with its default settings, so the
                batch must be ordered from longest to shortest sequence.

        Returns:
            Tuple ``(output, DeCov_loss, output_embed)``: the sigmoid prediction, the
            second value returned by the attention sublayer, and the ``FC_embed``
            projection of the pooled representation.
        """
        feature_dim = input.size(2)
        assert(feature_dim == self.input_dim)
        assert(self.d_model % self.MHD_num_head == 0)

        # pack_padded_sequence requires the lengths on the CPU when given as a tensor;
        # callers build them on `device`, so move them back here (no-op on CPU).
        if torch.is_tensor(lens):
            lens = lens.cpu()

        # Encode every feature channel with its own GRU and keep the last hidden state.
        feature_states = []
        for i in range(feature_dim):
            packed = pack_padded_sequence(input[:,:,i].unsqueeze(-1), lens, batch_first=True)
            # h_n has shape (1, batch, hidden). Remove only the leading layer axis:
            # a bare squeeze() would also drop the batch axis for a single-sample batch.
            feature_states.append(self.GRUs[i](packed)[1].squeeze(0))
        GRU_embeded_input = torch.stack(feature_states, dim=1) # batch_size * d_input * hidden_dim

        posi_input = self.dropout(GRU_embeded_input) # batch_size * d_input * hidden_dim

        # No attention mask is passed (N-to-1 task). The lambda ignores the argument
        # supplied by SublayerConnection and attends over posi_input directly.
        contexts = self.SublayerConnection(posi_input, lambda x: self.MultiHeadedAttention(posi_input, posi_input, posi_input, None)) # batch_size * d_input * hidden_dim
    
        DeCov_loss = contexts[1]
        contexts = contexts[0]

        contexts = self.SublayerConnection(contexts, lambda x: self.PositionwiseFeedForward(contexts))[0] # batch_size * d_input * hidden_dim

        # Pool across the feature axis, then produce the embedding and the prediction.
        weighted_contexts = self.FinalAttentionQKV(contexts)[0]
        output_embed = self.FC_embed(weighted_contexts)
        output = self.output(self.dropout(weighted_contexts))# b 1
        output = self.sigmoid(output)
          
        return output, DeCov_loss  ,output_embed




In [None]:
# Student hyper-parameters. The student only sees the first `subset_cnt` features.
epochs, batch_size = 150, 256
input_dim = subset_cnt
hidden_dim = d_model = 32
MHD_num_head, d_ff, output_dim = 4, 64, 1

# Build the student on the selected device and give it its own Adam optimizer.
model_student = distcare_student(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    d_model=d_model,
    MHD_num_head=MHD_num_head,
    d_ff=d_ff,
    output_dim=output_dim,
)
model_student = model_student.to(device)
optimizer_student = torch.optim.Adam(model_student.parameters(), lr=1e-3)



In [None]:
# Generate the teacher embedding of every training batch.
#
# The embeddings are stored batch by batch in `train_teacher_emb`, and the student
# training loop reads them back by batch index. Both passes therefore have to
# iterate the training split with shuffle=False and the same `batch_size`.
#
# NOTE: `checkpoint` must still hold the teacher checkpoint here. The name is
# reassigned later in the notebook when the student checkpoint is loaded, so
# re-running this cell after that point would load the wrong weights.
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
model.eval()

train_teacher_emb = []
pad_token = np.zeros(34)  # padding element handed to pad_sents for short sequences
with torch.no_grad():
    # Pure inference: no optimizer call is needed, and the observation masks
    # yielded by batch_iter are not used by the teacher.
    for step, (batch_x, batch_y, batch_mask_x, batch_lens) in enumerate(batch_iter(train_x, train_y, train_mask_x, train_x_len, batch_size, shuffle=False)):
        batch_x = torch.tensor(pad_sents(batch_x, pad_token), dtype=torch.float32).to(device)
        batch_y = torch.tensor(batch_y, dtype=torch.float32).to(device)
        batch_lens = torch.tensor(batch_lens, dtype=torch.float32).to(device).int()

        opt, decov_loss, emb = model(batch_x, batch_lens)
        train_teacher_emb.append(emb.cpu().detach().numpy())

        # Progress diagnostic only: this is the teacher's loss on *training* batches.
        if step % 20 == 0:
            BCE_Loss = get_loss(opt, batch_y.unsqueeze(-1))
            print('Batch %d: Train Loss = %.4f'%(step, BCE_Loss.cpu().detach().numpy()))

In [None]:
# Training Student
# If you don't want to train the Student Model:
# - The pretrained student model is in the directory './model/' and can be loaded directly.
# - Simply skip this cell and load the model to validate on the Dev Dataset.

# True: add the KL term against the cached teacher embeddings; False: BCE loss only.
teacher_flag = True
epochs = 200
total_train_loss = []
total_valid_loss = []
global_best = 0
auroc = []
auprc = []
minpse = []
history = []

pad_token = np.zeros(34)  # padding element handed to pad_sents for short sequences
# begin_time = time.time()
best_auroc = 0
best_auprc = 0
best_minpse = 0

# Tag used in the checkpoint file name for the selected target dataset.
if target_dataset == 'TJ':
    data_str = 'covid'
elif target_dataset == 'HM':
    data_str = 'spain'
else:
    # Fail with a clear message instead of a NameError further down.
    raise ValueError("target_dataset must be 'TJ' or 'HM', got {!r}".format(target_dataset))

# Where the best student checkpoint of this run is written.
if teacher_flag:
    file_name = './model/pretrained-challenge-front-fill-2'+ data_str
else: 
    file_name = './model/pretrained-challenge-front-fill-2'+ data_str + '-noteacher'

# Train the student for `epochs` epochs. After every epoch the student is scored on
# the dev split, and the checkpoint with the best dev AUROC so far is written to `file_name`.
for each_epoch in range(epochs):

    epoch_loss = []
    counter_batch = 0
    model_student.train()  
    model.eval()  # the teacher stays in inference mode; it is not called in this loop
    # shuffle=False so that batch `step` here is the batch whose teacher embedding was
    # cached at train_teacher_emb[step] (assumes batch_iter is deterministic without shuffling).
    for step, (batch_x, batch_y, batch_mask_x, batch_lens) in enumerate(batch_iter(train_x, train_y, train_mask_x, train_x_len, batch_size, shuffle=False)):  
        optimizer_student.zero_grad()
        batch_x = torch.tensor(pad_sents(batch_x, pad_token), dtype=torch.float32).to(device)
        batch_y = torch.tensor(batch_y, dtype=torch.float32).to(device)
        batch_lens = torch.tensor(batch_lens, dtype=torch.float32).to(device).int()
        batch_mask_x = torch.tensor(pad_sents(batch_mask_x, pad_token), dtype=torch.float32).to(device)

        # The student only receives the first `subset_cnt` feature columns.
        opt_student, decov_loss_student, emb_student = model_student(batch_x[:,:,:subset_cnt], batch_lens)
        # Cached teacher embedding for this batch (computed once, before training).
        emb_teacher = torch.tensor(train_teacher_emb[step], dtype=torch.float32).to(device)
        BCE_Loss = get_loss(opt_student, batch_y.unsqueeze(-1))
        if teacher_flag:
            # Distillation term: both embeddings are normalised along dim 1
            # (log-softmax for the student, softmax for the teacher) and compared by get_kl_loss.
            emb_student = F.log_softmax(emb_student, dim=1)
            emb_teacher = F.softmax(emb_teacher.detach(), dim=1)
            emb_Loss = get_kl_loss(emb_student, emb_teacher)
            loss = BCE_Loss + emb_Loss
        else:
            loss = BCE_Loss

        # Only the BCE part is tracked as the epoch's training loss; the printed
        # per-batch loss below is the full objective.
        epoch_loss.append(BCE_Loss.cpu().detach().numpy())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model_student.parameters(), 20)
        optimizer_student.step()

        if step % 20 == 0:
            print('Epoch %d Batch %d: Train Loss = %.4f'%(each_epoch, step, loss.cpu().detach().numpy()))

    epoch_loss = np.mean(epoch_loss)
    total_train_loss.append(epoch_loss)

    # Validation on the dev split
    y_true = []
    y_pred = []
    with torch.no_grad():
        model_student.eval()
        valid_loss = []
        valid_true = []
        valid_pred = []
        for batch_x, batch_y, batch_mask_x, batch_lens in batch_iter(dev_x, dev_y, dev_mask_x, dev_x_len, batch_size):
            batch_x = torch.tensor(pad_sents(batch_x, pad_token), dtype=torch.float32).to(device)
            batch_y = torch.tensor(batch_y, dtype=torch.float32).to(device)
            batch_lens = torch.tensor(batch_lens, dtype=torch.float32).to(device).int()
            batch_mask_x = torch.tensor(pad_sents(batch_mask_x, pad_token), dtype=torch.float32).to(device)


            opt_student, decov_loss_student, emb = model_student(batch_x[:,:,:subset_cnt], batch_lens)

            BCE_Loss = get_loss(opt_student, batch_y.unsqueeze(-1))

            valid_loss.append(BCE_Loss.cpu().detach().numpy())

            y_pred += list(opt_student.cpu().detach().numpy().flatten())
            y_true += list(batch_y.cpu().numpy().flatten())

        valid_loss = np.mean(valid_loss)
        total_valid_loss.append(valid_loss)
        # Silent call to obtain the metric values; the verbose call below prints them.
        ret = metrics.print_metrics_binary(y_true, y_pred,verbose = 0)
        history.append(ret)

        print('Epoch %d: Loss = %.4f Valid loss = %.4f roc = %.4f'%(each_epoch, total_train_loss[-1], total_valid_loss[-1], ret['auroc']))
        metrics.print_metrics_binary(y_true, y_pred)

        # Model selection: keep the checkpoint with the highest dev AUROC and
        # remember the AUPRC / min(P, Se) measured at that same epoch.
        cur_auroc = ret['auroc']
        if cur_auroc > best_auroc:
            best_auroc = cur_auroc
            best_auprc = ret['auprc']
            best_minpse = ret['minpse']
            state = {
                'net': model_student.state_dict(),
                'optimizer': optimizer_student.state_dict(),
                'epoch': each_epoch
            }
            torch.save(state, file_name)
            print('------------ Save best model - AUROC: %.4f ------------'%cur_auroc)       

# Dev-set metrics of the saved (best-AUROC) epoch.
print('auroc %.4f'%(best_auroc))
print('auprc %.4f'%(best_auprc))
print('minpse %.4f'%(best_minpse)) 

In [None]:
# Reload the saved student checkpoint for the selected target dataset.
teacher_flag = True

# Dataset tag used inside the checkpoint file name.
dataset_tags = {'TJ': 'covid', 'HM': 'spain'}
if target_dataset in dataset_tags:
    data_str = dataset_tags[target_dataset]

# The variant trained without the teacher carries a '-noteacher' suffix.
file_name = './model/pretrained-challenge-front-fill-2' + data_str
if not teacher_flag:
    file_name = file_name + '-noteacher'

load_device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load(file_name, map_location=load_device)

save_epoch = checkpoint['epoch']
print("last saved model is in epoch {}".format(save_epoch))

model_student.load_state_dict(checkpoint['net'])
optimizer_student.load_state_dict(checkpoint['optimizer'])
model_student.eval()

In [None]:
# Evaluate the loaded student model on the test split and print binary-classification metrics.
batch_loss = []
y_true = []
y_pred = []
# Defined here so the cell does not depend on an earlier cell having set it;
# same width as used for the teacher evaluation on this data.
pad_token = np.zeros(34)
model_student.eval()
with torch.no_grad():
    # Pure inference: no optimizer call is needed, and the observation masks are
    # not consumed by the model, so they are neither padded nor moved to the device.
    # NOTE(review): shuffle=True is kept from the original; it makes the mean of the
    # per-batch losses depend on the random batch composition.
    for step, (batch_x, batch_y, batch_mask_x, batch_lens) in enumerate(batch_iter(test_x, test_y, test_mask_x, test_x_len, batch_size, shuffle=True)):
        batch_x = torch.tensor(pad_sents(batch_x, pad_token), dtype=torch.float32).to(device)
        batch_y = torch.tensor(batch_y, dtype=torch.float32).to(device)
        batch_lens = torch.tensor(batch_lens, dtype=torch.float32).to(device).int()

        # The student only receives the first `subset_cnt` feature columns.
        opt, decov_loss, emb = model_student(batch_x[:,:,:subset_cnt], batch_lens)

        # Only the prediction loss is reported for the test split.
        loss = get_loss(opt, batch_y.unsqueeze(-1))
        loss_value = loss.cpu().detach().numpy()
        batch_loss.append(loss_value)
        if step % 20 == 0:
            print('Batch %d: Test Loss = %.4f'%(step, loss_value))
        y_pred += list(opt.cpu().detach().numpy().flatten())
        y_true += list(batch_y.cpu().numpy().flatten())

print("\n==>Predicting on test")
print('Test Loss = %.4f'%(np.mean(np.array(batch_loss))))
# Expand the positive-class score p into two columns [1 - p, p] before computing metrics.
y_pred = np.array(y_pred)
y_pred = np.stack([1 - y_pred, y_pred], axis=1)
test_res = metrics.print_metrics_binary(y_true, y_pred)

### Transfer Target Dataset & Model

In [None]:
def _load_pickle(path):
    """Load one pickled object from `path` and close the file afterwards.

    NOTE: pickle can execute arbitrary code while loading; only use data files
    that come from a trusted source.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# Per-dataset settings: data folder, the target columns that are moved to the
# front (`tar_subset_idx`), and the total number of feature columns.
if target_dataset == 'TJ':
    data_path = './data/Tongji/'
    tar_subset_idx = [0,1,2,7,11,12,24,25,28,30,32,36,37,39,50,51,65,73]
    n_target_features = 74
elif target_dataset == 'HM':
    data_path = './data/Spain/'
    tar_subset_idx = [39, 35, 23, 47, 55, 51, 22, 53, 25, 15, 43, 65, 1, 2, 48, 12, 26, 44, 49]
    n_target_features = 66
else:
    # Without this, the source-dataset variables would silently be reused below.
    raise ValueError("target_dataset must be 'TJ' or 'HM', got {!r}".format(target_dataset))

all_x = _load_pickle(data_path + 'x.dat')
all_y = _load_pickle(data_path + 'y.dat')
all_time = _load_pickle(data_path + 'time_all.dat')
all_x_len = [len(i) for i in all_x]

# Remaining columns keep their original relative order.
_subset_lookup = set(tar_subset_idx)
tar_other_idx = [i for i in range(n_target_features) if i not in _subset_lookup]

# Reorder every record's columns to [subset columns, other columns]; the number
# of rows per record is unchanged, so `all_x_len` stays valid.
_column_order = tar_subset_idx + tar_other_idx
for i in range(len(all_x)):
    cur = np.array(all_x[i], dtype=float)
    all_x[i] = cur[:, _column_order].tolist()

print(all_x[0])
print(len(all_x[0][0]))
if target_dataset == 'HM':
    print(all_y[0])

In [None]:
# Names used by the cross-validation cell below.
long_x = all_x
long_time = all_time
# For 'TJ' every entry of all_y is indexed with [-1] (its last element);
# for any other dataset all_y is used unchanged.
if target_dataset == 'TJ':
    long_y = [s[-1] for s in all_y]
else:
    long_y = all_y

In [None]:
def get_n2n_data(x, y, x_len):
    """Expand every sequence into all of its prefixes with one target per prefix.

    For sequence ``x[i]`` and every position ``j`` the prefix ``x[i][:j+1]`` is
    emitted together with the target ``y[i][j]`` and the prefix length ``j+1``.

    Args:
        x: list of sequences.
        y: list of per-step target sequences, aligned with ``x``.
        x_len: list with one entry per sequence; only its length is checked.

    Returns:
        Tuple ``(new_x, new_y, new_x_len)`` of flat lists, in sequence order and,
        within a sequence, from the shortest prefix to the full sequence.
    """
    n_sequences = len(x)
    assert n_sequences == len(y)
    assert n_sequences == len(x_len)

    prefixes, targets, prefix_lens = [], [], []
    for sequence, step_targets in zip(x, y):
        for end in range(1, len(sequence) + 1):
            prefixes.append(sequence[:end])
            targets.append(step_targets[end - 1])
            prefix_lens.append(end)
    return prefixes, targets, prefix_lens

In [None]:
class distcare_target(nn.Module):
    """Target-dataset network: one GRU per input feature, then attention across the feature embeddings.

    Every feature channel is encoded by its own single-input GRU; the final hidden
    states are stacked into a (batch, feature, hidden) tensor, passed through
    multi-head self-attention and a position-wise feed-forward sublayer, pooled by
    ``FinalAttentionQKV`` and mapped to the output by a small MLP.

    Several layers registered in ``__init__`` (``PositionalEncoding``,
    ``LastStepAttentions``, ``demo_proj_main``, ``demo_proj``, ``output``,
    ``FC_embed``) are not used by ``forward``. They are kept, in the original
    order, so that the parameter names in ``state_dict()`` keep matching
    previously saved checkpoints.

    Args:
        input_dim: number of input features (one GRU is created per feature).
        hidden_dim: hidden size of every per-feature GRU.
        d_model: model width used by the attention / feed-forward sublayers;
            must be divisible by ``MHD_num_head``.
        MHD_num_head: number of attention heads.
        d_ff: inner width of the position-wise feed-forward layer.
        output_dim: size of the MLP output; softmax is applied when it is not 1.
        keep_prob: keep probability; dropout layers use ``p = 1 - keep_prob``.
    """
    def __init__(self, input_dim, hidden_dim, d_model,  MHD_num_head, d_ff, output_dim, keep_prob=0.5):
        super(distcare_target, self).__init__()

        # hyperparameters
        self.input_dim = input_dim  
        self.hidden_dim = hidden_dim  # d_model
        self.d_model = d_model
        self.MHD_num_head = MHD_num_head
        self.d_ff = d_ff
        self.output_dim = output_dim
        self.keep_prob = keep_prob

        # layers
        self.PositionalEncoding = PositionalEncoding(self.d_model, dropout = 0, max_len = 400)

        self.GRUs = clones(nn.GRU(1, self.hidden_dim, batch_first = True), self.input_dim)
        self.LastStepAttentions = clones(SingleAttention(self.hidden_dim, 16, attention_type='concat', demographic_dim=12, time_aware=True, use_demographic=False),self.input_dim)
        
        self.FinalAttentionQKV = FinalAttentionQKV(self.hidden_dim, self.hidden_dim, attention_type='mul',dropout = 1 - self.keep_prob)

        self.MultiHeadedAttention = MultiHeadedAttention(self.MHD_num_head, self.d_model,dropout = 1 - self.keep_prob)
        self.SublayerConnection = SublayerConnection(self.d_model, dropout = 1 - self.keep_prob)

        self.PositionwiseFeedForward = PositionwiseFeedForward(self.d_model, self.d_ff, dropout=0.1)

        self.demo_proj_main = nn.Linear(12, self.hidden_dim)
        self.demo_proj = nn.Linear(12, self.hidden_dim)
        self.output = nn.Linear(self.hidden_dim, self.output_dim)

        self.dropout = nn.Dropout(p = 1 - self.keep_prob)
        self.FC_embed = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.tanh=nn.Tanh()
        self.MLP = nn.Sequential(
            nn.Linear(self.hidden_dim, 8),
            nn.ReLU(),
            nn.Linear(8, self.output_dim)
        )
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.relu=nn.ReLU()

    def forward(self, input, lens):
        """Run the target model on a padded batch.

        Args:
            input: tensor of shape [batch_size, timestep, feature_dim]; feature_dim
                must equal ``self.input_dim``.
            lens: true sequence length of every batch element (tensor or list).
                Passed to ``pack_padded_sequence`` with its default settings, so the
                batch must be ordered from longest to shortest sequence.

        Returns:
            Tuple ``(output, DeCov_loss, weighted_contexts)``: the MLP output
            (softmax-normalised along dim 1 when ``output_dim != 1``), the second
            value returned by the attention sublayer, and the pooled representation.
        """
        feature_dim = input.size(2)
        assert(feature_dim == self.input_dim)
        assert(self.d_model % self.MHD_num_head == 0)

        # pack_padded_sequence requires the lengths on the CPU when given as a tensor;
        # callers build them on `device`, so move them back here (no-op on CPU).
        if torch.is_tensor(lens):
            lens = lens.cpu()

        # Encode every feature channel with its own GRU and keep the last hidden state.
        feature_states = []
        for i in range(feature_dim):
            packed = pack_padded_sequence(input[:,:,i].unsqueeze(-1), lens, batch_first=True)
            # h_n has shape (1, batch, hidden). Remove only the leading layer axis:
            # a bare squeeze() would also drop the batch axis for a single-sample batch.
            feature_states.append(self.GRUs[i](packed)[1].squeeze(0))
        GRU_embeded_input = torch.stack(feature_states, dim=1) # batch_size * d_input * hidden_dim

        posi_input = self.dropout(GRU_embeded_input) # batch_size * d_input * hidden_dim

        # No attention mask is passed (N-to-1 task). The lambda ignores the argument
        # supplied by SublayerConnection and attends over posi_input directly.
        contexts = self.SublayerConnection(posi_input, lambda x: self.MultiHeadedAttention(posi_input, posi_input, posi_input, None)) # batch_size * d_input * hidden_dim
    
        DeCov_loss = contexts[1]
        contexts = contexts[0]

        contexts = self.SublayerConnection(contexts, lambda x: self.PositionwiseFeedForward(contexts))[0] # batch_size * d_input * hidden_dim

        # Pool across the feature axis and predict from the pooled representation.
        weighted_contexts = self.FinalAttentionQKV(contexts)[0]
        output = self.MLP(self.dropout(weighted_contexts))# b 1
        if self.output_dim != 1:
            output = F.softmax(output, dim=1)
          
        return output, DeCov_loss  ,weighted_contexts




In [None]:
def transfer_gru_dict(pretrain_dict, model_dict):
    """Select the entries of a pretrained state dict whose key contains 'GRUs'.

    Every key of ``pretrain_dict`` is reported on stdout, either as transferred
    or as skipped.

    Args:
        pretrain_dict: mapping of parameter name -> value from the pretrained model.
        model_dict: accepted for call compatibility; not read by this function.

    Returns:
        A new dict holding only the 'GRUs' entries of ``pretrain_dict``.
    """
    transferred = {}
    for name in pretrain_dict:
        if 'GRUs' not in name:
            print("Other weight in model_dict: {}".format(name))
            continue
        transferred[name] = pretrain_dict[name]
        print("transfered weight: {}".format(name))
    return transferred

In [None]:
# Number of feature columns of the selected target dataset.
target_feature_counts = {'TJ': 74, 'HM': 66}
if target_dataset in target_feature_counts:
    input_dim = target_feature_counts[target_dataset]

# Hyper-parameters of the target model.
cell = 'GRU'
hidden_dim = d_model = 32
MHD_num_head, d_ff, output_dim = 4, 64, 1
RANDOM_SEED = 42

# Seeded, shuffled 10-fold stratified splitter.
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=RANDOM_SEED)

In [None]:
def ckd_batch_iter(x, y, lens, batch_size, shuffle=False):
    """Yield mini-batches whose examples are ordered from longest to shortest.

    @param x (list): input sequences
    @param y (list): targets, aligned with x
    @param lens (list): per-example length values, aligned with x
    @param batch_size (int): maximum number of examples per batch
    @param shuffle (boolean): whether to randomly shuffle the dataset before batching
    @yield (batch_x, batch_y, batch_lens): three lists; within a batch the examples
        are sorted by len(x[idx]) in descending order (ties keep their relative order)
    """
    n_batches = math.ceil(len(x) / batch_size)  # rounded up: the last batch may be smaller
    order = list(range(len(x)))

    if shuffle:
        np.random.shuffle(order)

    for batch_no in range(n_batches):
        chunk = order[batch_no * batch_size: (batch_no + 1) * batch_size]
        # Longest sequence first, as needed for packing padded sequences.
        ranked = sorted(
            ((x[k], y[k], lens[k]) for k in chunk),
            key=lambda example: len(example[0]),
            reverse=True,
        )
        yield (
            [example[0] for example in ranked],
            [example[1] for example in ranked],
            [example[2] for example in ranked],
        )

In [None]:
# Settings for the k-fold training run on the target dataset.
teacher_flag = True     # selects which pretrained student checkpoint is used
transfer_flag = True    # copy the pretrained GRU weights into every fold's model
n_splits = 10
kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_SEED)

# Dataset tag used inside checkpoint file and folder names.
dataset_tags = {'TJ': 'covid', 'HM': 'spain'}
if target_dataset in dataset_tags:
    data_str = dataset_tags[target_dataset]

# Pretrained student checkpoint; the no-teacher variant has a '-noteacher' suffix.
file_name = './model/pretrained-challenge-front-fill-2' + data_str
if not teacher_flag:
    file_name = file_name + '-noteacher'

epochs, batch_size = 200, 256

# Bookkeeping shared by all folds.
fold_count = 0
total_train_loss, total_valid_loss = [], []
global_best = 10000
mse, mad, mape, kappa, history = [], [], [], [], []

pad_token = np.zeros(input_dim)  # padding element handed to pad_sents for short sequences
# begin_time = time.time()

# Make sure the checkpoint folder exists before any fold saves into it.
# makedirs(exist_ok=True) also creates a missing './model/' parent and is safe to re-run.
os.makedirs('./model/'+data_str, exist_ok=True)

# Stratified k-fold training of the target model (regression against `long_time`).
# `long_y` is only used to stratify the folds.
for train, test in kfold.split(long_x, long_y):
    
    # Fresh model for every fold.
    model = distcare_target(input_dim = input_dim,output_dim=output_dim, d_model=d_model, MHD_num_head=MHD_num_head, d_ff=d_ff, hidden_dim=hidden_dim).to(device)
    
    if transfer_flag:
        # Overwrite the new model's GRU weights with the pretrained ones of the
        # same name; all other parameters keep their fresh initialisation.
        checkpoint = torch.load(file_name, \
                        map_location=torch.device("cuda:0" if torch.cuda.is_available() == True else 'cpu'))
        pretrain_dict = checkpoint['net']
        model_dict = model.state_dict()
        pretrain_dict = transfer_gru_dict(pretrain_dict, model_dict)
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    fold_count += 1

    # Expand every record into all of its prefixes, one regression target per prefix.
    train_x = [long_x[i] for i in train]
    train_y = [long_time[i] for i in train]
    train_x_len = [all_x_len[i] for i in train]
    
    train_x, train_y, train_x_len = get_n2n_data(train_x, train_y, train_x_len)
    
    test_x = [long_x[i] for i in test]
    test_y = [long_time[i] for i in test]
    test_x_len = [all_x_len[i] for i in test]
    
    test_x, test_y, test_x_len = get_n2n_data(test_x, test_y, test_x_len)
        
    # NOTE(review): the fold number is already part of target_file_name, so the
    # fold-best file below gets it twice ('...regression3_3') and the "best model"
    # file is also per fold. Confirm this naming is intended.
    if transfer_flag:
        target_file_name = './model/'+data_str+'/distcare-trans-'+str(n_splits)+'-fold-LOS-regression' + str(fold_count)#4114
    else:
        target_file_name = './model/'+data_str+'/distcare-no-trans-'+str(n_splits)+'-fold-LOS-regression' + str(fold_count)#4114
    
    fold_train_loss = []
    fold_valid_loss = []
    best_mse = 10000
    best_mad = 0
    
    for each_epoch in range(epochs):
        
        epoch_loss = []
        model.train()  
        
        for step, (batch_x, batch_y, batch_lens) in enumerate(ckd_batch_iter(train_x, train_y, train_x_len, batch_size, shuffle=True)):  
            optimizer.zero_grad()
            batch_x = torch.tensor(pad_sents(batch_x, pad_token), dtype=torch.float32).to(device)
            batch_y = torch.tensor(batch_y, dtype=torch.float32).to(device)
            batch_lens = torch.tensor(batch_lens, dtype=torch.float32).to(device).int()

            opt, decov_loss, emb = model(batch_x, batch_lens)

            MSE_Loss = get_re_loss(opt, batch_y.unsqueeze(-1))

            # Objective: regression loss plus the weighted second output of the model.
            loss = MSE_Loss + 1e7*decov_loss

            # Only the regression part is tracked as the epoch's training loss.
            epoch_loss.append(MSE_Loss.cpu().detach().numpy())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 20)
            optimizer.step()
            
            if step % 50 == 0:
                print('Fold %d Epoch %d Batch %d: Train Loss = %.4f'%(fold_count,each_epoch, step, loss.cpu().detach().numpy()))
            
        epoch_loss = np.mean(epoch_loss)
        fold_train_loss.append(epoch_loss)

        # Validation on this fold's held-out records
        y_pred_flatten = []
        y_true_flatten = []
        with torch.no_grad():
            model.eval()
            valid_loss = []
            for batch_x, batch_y, batch_lens in ckd_batch_iter(test_x, test_y, test_x_len, batch_size):
                batch_x = torch.tensor(pad_sents(batch_x, pad_token), dtype=torch.float32).to(device)
                batch_y = torch.tensor(batch_y, dtype=torch.float32).to(device)
                batch_lens = torch.tensor(batch_lens, dtype=torch.float32).to(device).int()
               
                opt, decov_loss, emb = model(batch_x, batch_lens)
                
                MSE_Loss = get_re_loss(opt, batch_y.unsqueeze(-1))
                
                valid_loss.append(MSE_Loss.cpu().detach().numpy())

                y_pred_flatten += list(opt.cpu().detach().numpy().flatten())
                y_true_flatten += list(batch_y.cpu().numpy().flatten())
            

            valid_loss = np.mean(valid_loss)
            fold_valid_loss.append(valid_loss)
            ret = metrics.print_metrics_regression(y_true_flatten, y_pred_flatten, verbose=0)
            history.append(ret)

            if each_epoch % 10 == 0:
                print('Fold %d, epoch %d: Loss = %.4f Valid loss = %.4f MSE = %.4f' % (
                    fold_count, each_epoch, fold_train_loss[-1], fold_valid_loss[-1], ret['mse']), flush=True)
                
            # Model selection: keep the checkpoint with the lowest validation MSE of this fold.
            cur_mse = ret['mse']
            if cur_mse < best_mse:
                print('------------ Save FOLD-BEST model - MSE: %.4f ------------' % cur_mse, flush=True)
                metrics.print_metrics_regression(y_true_flatten, y_pred_flatten)
                best_mse = cur_mse
                best_mad = ret['mad']
                state = {
                    'net': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': each_epoch
                }
                torch.save(state, target_file_name + '_' + str(fold_count))

                # Additionally track the lowest MSE seen across all folds so far.
                if cur_mse < global_best:
                    global_best = cur_mse
                    state = {
                        'net': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': each_epoch
                    }
                    torch.save(state, target_file_name)
                    print('------------ Save best model - MSE: %.4f ------------' % cur_mse, flush=True)

        # NOTE(review): this runs once per epoch and prints the current epoch's
        # metrics, not the fold's best. Confirm whether a per-fold summary was intended.
        print('Fold %d, mse = %.4f, mad = %.4f' % (fold_count, ret['mse'], ret['mad']), flush=True)

    # Per-fold results: best validation MSE and the MAD measured at that same epoch.
    mse.append(best_mse)
    mad.append(best_mad)
    total_train_loss.append(fold_train_loss)
    total_valid_loss.append(fold_valid_loss)

# Mean (standard deviation) over the folds.
print('mse %.4f(%.4f)' % (np.mean(mse), np.std(mse)))
print('mad %.4f(%.4f)' % (np.mean(mad), np.std(mad)))

In [None]:
# Cross-validation summary: mean (standard deviation) of the per-fold best MSE and MAD.
mse_mean, mse_std = np.mean(mse), np.std(mse)
mad_mean, mad_std = np.mean(mad), np.std(mad)
print('mse %.4f(%.4f)' % (mse_mean, mse_std))
print('mad %.4f(%.4f)' % (mad_mean, mad_std))