In [1]:
import torch
import torch.nn as nn
import pickle5 as pickle
import numpy as np
import copy
import math
import torch.nn.functional as F

from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from utils import metrics

In [2]:
# Cohort selector read by the later cells; they branch on 'TJ' and 'HM'.
target_dataset = "TJ"
# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

In [3]:
class SingleAttention(nn.Module):
    """Attention over the time axis that uses the last time step as the query.

    Three scoring functions are supported, selected by ``attention_type``:

    * ``'add'``    -- additive scoring ``tanh(q + k + b) @ Wa``
    * ``'mul'``    -- bilinear scoring ``x_last @ Wa @ x_t``
    * ``'concat'`` -- scoring of the concatenation ``[x_last, x_t]``

    When ``time_aware`` is true, the 'add' and 'concat' variants also receive the
    distance (in steps) of every position from the last step.

    Args:
        attention_input_dim: size of the last axis of the input.
        attention_hidden_dim: size of the hidden scoring space ('add' / 'concat').
        attention_type: one of ``'add'``, ``'mul'``, ``'concat'``.
        demographic_dim: size of the demographic vector projected by ``Wd``.
        time_aware: add the step distance to the scoring function.
        use_demographic: project ``demo`` through ``Wd`` in the 'add' variant.

    Raises:
        RuntimeError: if ``attention_type`` is not one of the supported values.
    """

    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', demographic_dim=12, time_aware=False, use_demographic=False):
        super(SingleAttention, self).__init__()

        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim
        self.use_demographic = use_demographic
        self.demographic_dim = demographic_dim
        self.time_aware = time_aware

        # Parameters are created in the same order as before so that state_dict
        # keys and seeded initialisation stay unchanged.
        if attention_type == 'add':
            self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            if self.time_aware:
                self.Wtime_aware = nn.Parameter(torch.randn(1, attention_hidden_dim))
                nn.init.kaiming_uniform_(self.Wtime_aware, a=math.sqrt(5))
            self.Wt = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wd = nn.Parameter(torch.randn(demographic_dim, attention_hidden_dim))
            self.bh = nn.Parameter(torch.zeros(attention_hidden_dim,))
            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1,))

            nn.init.kaiming_uniform_(self.Wd, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == 'mul':
            self.Wa = nn.Parameter(torch.randn(attention_input_dim, attention_input_dim))
            self.ba = nn.Parameter(torch.zeros(1,))

            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == 'concat':
            # One extra input column carries the time decay when time_aware is set.
            concat_dim = 2 * attention_input_dim + (1 if self.time_aware else 0)
            self.Wh = nn.Parameter(torch.randn(concat_dim, attention_hidden_dim))
            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1,))

            nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        else:
            raise RuntimeError('Wrong attention type.')

        self.tanh = nn.Tanh()
        # Scores are always (batch, time); normalise explicitly over the time axis
        # instead of relying on the deprecated implicit-dim behaviour.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input, demo=None):
        """Pool ``input`` over time.

        Args:
            input: tensor of shape (batch, time, attention_input_dim).
            demo: demographic tensor, only read when ``use_demographic`` is true
                and ``attention_type == 'add'``.

        Returns:
            ``(v, a)`` where ``v`` is the attention-weighted sum of ``input`` with
            shape (batch, attention_input_dim) and ``a`` the weights with shape
            (batch, time).
        """
        batch_size, time_step, input_dim = input.size()

        # Distance of each position from the last step: time_step-1, ..., 1, 0.
        # Built on the input's own device so the module does not depend on a
        # notebook-level global.
        time_decays = torch.arange(
            time_step - 1, -1, -1, dtype=torch.float32, device=input.device
        ).reshape(1, time_step, 1)  # 1 t 1
        b_time_decays = time_decays.expand(batch_size, -1, -1)  # b t 1

        if self.attention_type == 'add':
            q = torch.matmul(input[:, -1, :], self.Wt)  # b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim))  # b 1 h
            k = torch.matmul(input, self.Wx)  # b t h
            if self.use_demographic:
                # NOTE(review): the demographic projection is computed but never
                # added to `h` below, exactly as in the original code -- confirm
                # whether it was meant to contribute to the score.
                d = torch.matmul(demo, self.Wd)  # b h
                d = torch.reshape(d, (batch_size, 1, self.attention_hidden_dim))  # b 1 h
            h = q + k + self.bh  # b t h
            if self.time_aware:
                h = h + torch.matmul(b_time_decays, self.Wtime_aware)  # b t h
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba  # b t 1
            e = torch.reshape(e, (batch_size, time_step))  # b t
        elif self.attention_type == 'mul':
            e = torch.matmul(input[:, -1, :], self.Wa)  # b i
            # squeeze(1) removes only the singleton query axis; a bare squeeze()
            # would also drop the batch axis when batch_size == 1.
            e = torch.matmul(e.unsqueeze(1), input.permute(0, 2, 1)).squeeze(1) + self.ba  # b t
        elif self.attention_type == 'concat':
            q = input[:, -1, :].unsqueeze(1).repeat(1, time_step, 1)  # b t i
            c = torch.cat((q, input), dim=-1)  # b t 2i
            if self.time_aware:
                c = torch.cat((c, b_time_decays), dim=-1)  # b t 2i+1
            h = self.tanh(torch.matmul(c, self.Wh))
            e = torch.matmul(h, self.Wa) + self.ba  # b t 1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        a = self.softmax(e)  # b t
        v = torch.matmul(a.unsqueeze(1), input).squeeze(1)  # b i

        return v, a

class FinalAttentionQKV(nn.Module):
    """Attention pooling with learned query/key/value projections.

    The query is the projection of the mean of the input over axis 1; keys and
    values are per-position projections. ``attention_type`` selects the scoring
    function (``'add'``, ``'mul'`` or ``'concat'``).

    Args:
        attention_input_dim: size of the last axis of the input.
        attention_hidden_dim: size of the projected q/k/v space.
        attention_type: one of ``'add'``, ``'mul'``, ``'concat'``.
        dropout: dropout probability applied to the attention weights, or
            ``None`` to disable dropout.

    Raises:
        RuntimeError: if ``attention_type`` is not one of the supported values.
    """

    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', dropout=None):
        super(FinalAttentionQKV, self).__init__()

        if attention_type not in ('add', 'mul', 'concat'):
            raise RuntimeError('Wrong attention type.')

        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim

        self.W_q = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_k = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_v = nn.Linear(attention_input_dim, attention_hidden_dim)

        self.W_out = nn.Linear(attention_hidden_dim, 1)

        self.b_in = nn.Parameter(torch.zeros(1,))
        self.b_out = nn.Parameter(torch.zeros(1,))

        nn.init.kaiming_uniform_(self.W_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_v.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_out.weight, a=math.sqrt(5))

        # NOTE: the 'concat' branch concatenates two hidden-sized tensors, so it
        # only works when attention_input_dim == attention_hidden_dim. The shape
        # is left untouched so that existing checkpoints still load.
        self.Wh = nn.Parameter(torch.randn(2*attention_input_dim, attention_hidden_dim))
        self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
        self.ba = nn.Parameter(torch.zeros(1,))

        nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))

        # nn.Dropout(p=None) is invalid, so the documented default used to crash
        # here and the `is not None` guard in forward() could never be false.
        self.dropout = nn.Dropout(p=dropout) if dropout is not None else None
        self.tanh = nn.Tanh()
        # Scores are (batch, positions); normalise explicitly over positions.
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Pool ``input`` over axis 1.

        Args:
            input: tensor of shape (batch, positions, attention_input_dim).

        Returns:
            ``(v, a)`` where ``v`` has shape (batch, attention_hidden_dim) and
            ``a`` holds the attention weights with shape (batch, positions).
        """
        batch_size, time_step, input_dim = input.size()
        input_q = self.W_q(torch.mean(input, dim=1))  # b h
        input_k = self.W_k(input)  # b t h
        input_v = self.W_v(input)  # b t h

        if self.attention_type == 'add':
            q = torch.reshape(input_q, (batch_size, 1, self.attention_hidden_dim))  # b 1 h
            h = self.tanh(q + input_k + self.b_in)  # b t h
            e = self.W_out(h)  # b t 1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        elif self.attention_type == 'mul':
            q = torch.reshape(input_q, (batch_size, self.attention_hidden_dim, 1))  # b h 1
            # squeeze(-1) keeps the batch axis even when batch_size == 1.
            e = torch.matmul(input_k, q).squeeze(-1)  # b t

        elif self.attention_type == 'concat':
            q = input_q.unsqueeze(1).repeat(1, time_step, 1)  # b t h
            c = torch.cat((q, input_k), dim=-1)  # b t 2h
            h = self.tanh(torch.matmul(c, self.Wh))
            e = torch.matmul(h, self.Wa) + self.ba  # b t 1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        a = self.softmax(e)  # b t
        if self.dropout is not None:
            a = self.dropout(a)
        v = torch.matmul(a.unsqueeze(1), input_v).squeeze(1)  # b h

        return v, a

def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas

def tile(a, dim, n_tile):
    """Repeat every slice of ``a`` along ``dim`` ``n_tile`` times, keeping repeats adjacent.

    For example, rows ``[r0, r1]`` tiled twice along dim 0 become
    ``[r0, r0, r1, r1]``. The previous implementation produced this ordering with
    ``repeat`` followed by an ``index_select`` built in numpy and moved to the
    notebook-level ``device``; ``torch.repeat_interleave`` yields the same
    ordering directly.

    Args:
        a: input tensor.
        dim: axis along which to repeat.
        n_tile: number of consecutive copies of each slice.

    Returns:
        The tiled tensor, on the same device as ``a``.
    """
    return torch.repeat_interleave(a, n_tile, dim=dim)

class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward block: ``w_2(dropout(relu(w_1(x))))``.

    ``forward`` returns a ``(output, None)`` pair; the second slot is always
    ``None``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)   # expand d_model -> d_ff
        self.w_2 = nn.Linear(d_ff, d_model)   # project d_ff -> d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden), None

    
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Args:
        d_model: size of the last axis of the input.
        dropout: dropout probability applied after adding the encoding.
        max_len: maximum sequence length for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout, max_len=400):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the cosine table to fit (a no-op for even d_model).
        pe[:, 1::2] = torch.cos(angles)[:, :d_model // 2]
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])`` for ``x`` of shape (batch, seq, d_model)."""
        # The buffer never requires grad, so no wrapper is needed; the previous
        # code wrapped it in `Variable`, a name that is not imported in this file.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

def subsequent_mask(size):
    """Return a (1, size, size) boolean mask that is True on and below the diagonal.

    Entries strictly above the diagonal are False.
    """
    strictly_upper = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    future_positions = torch.from_numpy(strictly_upper)
    return future_positions == 0

def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Scores are ``query @ key^T / sqrt(d_k)`` with ``d_k = query.size(-1)``.
    Positions where ``mask == 0`` get a score of -1e9 before the softmax over
    the last axis. ``dropout``, when given, is applied to the attention weights.

    Returns:
        ``(context, weights)`` -- the weighted sum of ``value`` and the
        (possibly dropped-out) attention weights.
    """
    scale = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    context = torch.matmul(weights, value)
    return context, weights
    
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention that also returns a DeCov penalty.

    Args:
        h: number of heads; must divide ``d_model``.
        d_model: model width (input and output size of the layer).
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, h, d_model, dropout=0):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, self.d_k * self.h), 3)
        self.final_linear = nn.Linear(d_model, d_model)
        self.attn = None  # attention weights of the most recent forward pass
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Run attention and compute the DeCov regulariser.

        Args:
            query, key, value: tensors of shape (batch, positions, d_model).
            mask: optional mask; an axis is inserted at dim 1 so the same mask
                applies to every head.

        Returns:
            ``(output, DeCov_loss)`` where ``output`` has shape
            (batch, positions, d_model) and ``DeCov_loss`` is the sum, over all
            positions, of half the squared off-diagonal mass of the covariance
            (across the batch) of that position's context vector.
        """
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)

        nbatches = query.size(0)

        # Project, then split d_model into h heads of size d_k:
        # (b, positions, d_model) -> (b, h, positions, d_k)
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)  # b h positions d_k

        # Merge the heads back: (b, positions, h * d_k)
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)

        # DeCov: penalise off-diagonal covariance of each position's context.
        DeCov_contexts = x.transpose(0, 1).transpose(1, 2)  # positions, hidden, batch
        DeCov_loss = 0
        # Iterate over however many positions the input actually has; the loop
        # used to be hard-coded to exactly 11 positions.
        for i in range(DeCov_contexts.size(0)):
            Covs = cov(DeCov_contexts[i, :, :])
            DeCov_loss += 0.5 * (torch.norm(Covs, p='fro')**2 - torch.norm(torch.diag(Covs))**2)

        return self.final_linear(x), DeCov_loss

class LayerNorm(nn.Module):
    """Normalise the last axis to zero mean and unit std, then scale and shift.

    ``a_2`` (initialised to ones) is the learned scale and ``b_2`` (initialised
    to zeros) the learned shift; ``eps`` is added to the standard deviation
    before dividing.
    """

    def __init__(self, features, eps=1e-7):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2

def cov(m, y=None):
    """Covariance matrix of the rows of ``m``; columns are observations.

    When ``y`` is given its rows are appended to those of ``m`` first. The
    result is normalised by ``observations - 1``.
    """
    rows = m if y is None else torch.cat((m, y), dim=0)
    centered = rows - torch.mean(rows, dim=1)[:, None]
    scale = 1 / (centered.size(1) - 1)
    return scale * centered.mm(centered.t())

class SublayerConnection(nn.Module):
    """
    Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x))[0])``.

    The wrapped sublayer must return a pair; its second element is passed
    through unchanged as the second return value.
    """
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Normalise, run the sublayer, and add the result back onto ``x``."
        sublayer_out, extra = self._split(sublayer(self.norm(x)))
        return x + self.dropout(sublayer_out), extra

    @staticmethod
    def _split(returned_value):
        "Pick the (output, auxiliary) pair out of whatever the sublayer returned."
        return returned_value[0], returned_value[1]

In [4]:
class FeatureAttention(nn.Module):
    """Score each feature embedding and re-weight it by its standardised score."""

    def __init__(self, hidden_dim):
        super().__init__()
        bottleneck_dim = hidden_dim // 2
        # Scoring network: reduce the dimension, apply tanh, then emit one
        # attention score per feature.
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, bottleneck_dim),
            nn.Tanh(),
            nn.Linear(bottleneck_dim, 1),
        )

    def forward(self, x):
        """
        x: (batch_size, num_features, hidden_dim)
        Returns:
        - weighted_output: (batch_size, num_features, hidden_dim)
        - attention_weights: (batch_size, num_features)
        """
        # Softmax over the feature axis.
        scores = self.attention(x).squeeze(-1)   # (batch_size, num_features)
        probs = torch.softmax(scores, dim=1)

        # Standardise the weights per sample (zero mean, unit std across
        # features), so the returned weights can be negative.
        mean = probs.mean(dim=1).unsqueeze(-1)
        std = probs.std(dim=1).unsqueeze(-1)
        attention_weights = (probs - mean) / std

        # Scale every feature embedding by its weight (no summation happens here).
        weighted_output = attention_weights.unsqueeze(-1) * x
        return weighted_output, attention_weights

class distcare_target(nn.Module):
    """Per-feature GRU encoder with feature attention and two prediction heads.

    Every input feature is encoded by its own single-input GRU; the final hidden
    states are re-weighted by ``FeatureAttention``, reduced to one scalar per
    feature, and fed to two linear heads (``Linear_los`` and ``Linear_outcome``).

    Several sub-modules (``PositionalEncoding``, ``LastStepAttentions``,
    ``FinalAttentionQKV``, ``MultiHeadedAttention``, ``SublayerConnection``,
    ``PositionwiseFeedForward``, ``demo_proj*``, ``output``, ``FC_embed``) are
    constructed but not used by ``forward``. They are kept so that the module's
    state_dict keys do not change and saved checkpoints still load.

    Args:
        input_dim: number of input features (one GRU per feature).
        hidden_dim: GRU hidden size.
        d_model: width of the (currently unused) transformer sub-modules.
        MHD_num_head: number of heads of the (unused) multi-head attention.
        d_ff: inner width of the (unused) feed-forward block.
        output_dim: output size of both prediction heads.
        keep_prob: keep probability; dropout uses ``1 - keep_prob``.
    """

    def __init__(self, input_dim, hidden_dim, d_model,  MHD_num_head, d_ff, output_dim, keep_prob=0.7):
        super(distcare_target, self).__init__()

        # hyperparameters
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.d_model = d_model
        self.MHD_num_head = MHD_num_head
        self.d_ff = d_ff
        self.output_dim = output_dim
        self.keep_prob = keep_prob

        # layers
        self.PositionalEncoding = PositionalEncoding(self.d_model, dropout = 0, max_len = 400)

        self.GRUs = clones(nn.GRU(1, self.hidden_dim, batch_first = True), self.input_dim)
        self.FeatureAttention = FeatureAttention(self.hidden_dim)
        self.LastStepAttentions = clones(SingleAttention(self.hidden_dim, 16, attention_type='concat', demographic_dim=12, time_aware=True, use_demographic=False),self.input_dim)

        self.FinalAttentionQKV = FinalAttentionQKV(self.hidden_dim, self.hidden_dim, attention_type='mul',dropout = 1 - self.keep_prob)

        self.MultiHeadedAttention = MultiHeadedAttention(self.MHD_num_head, self.d_model,dropout = 1 - self.keep_prob)
        self.SublayerConnection = SublayerConnection(self.d_model, dropout = 1 - self.keep_prob)

        self.PositionwiseFeedForward = PositionwiseFeedForward(self.d_model, self.d_ff, dropout=0.1)

        self.demo_proj_main = nn.Linear(12, self.hidden_dim)
        self.demo_proj = nn.Linear(12, self.hidden_dim)
        self.output = nn.Linear(self.hidden_dim, self.output_dim)

        self.dropout = nn.Dropout(p = 1 - self.keep_prob)
        self.FC_embed = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.tanh=nn.Tanh()
        self.Linear = nn.Linear(self.hidden_dim, 1)
        self.Linear_los = nn.Linear(self.input_dim, self.output_dim)
        self.Linear_outcome = nn.Linear(self.input_dim, self.output_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.relu=nn.ReLU()

    def forward(self, input, lens):
        """Encode a padded batch and predict both targets.

        Args:
            input: tensor of shape (batch, time, input_dim), padded along time.
            lens: 1-D tensor of true sequence lengths. It is passed to
                ``pack_padded_sequence`` with the default ``enforce_sorted``,
                so the batch must be sorted by decreasing length.

        Returns:
            ``(output, None, None, outcome, feature_importance)`` where
            ``output`` is the raw ``Linear_los`` prediction (batch, output_dim),
            ``outcome`` the sigmoid of ``Linear_outcome`` (batch, output_dim) and
            ``feature_importance`` the per-feature attention weights
            (batch, input_dim). The two ``None`` slots are placeholders.
        """
        # pack_padded_sequence expects the lengths on the CPU.
        lens = lens.to('cpu')

        feature_dim = input.size(2)
        assert(feature_dim == self.input_dim)
        assert(self.d_model % self.MHD_num_head == 0)

        GRU_embeded_input = []
        for i in range(feature_dim):
            packed_feature = pack_padded_sequence(input[:, :, i].unsqueeze(-1), lens, batch_first=True)
            # Index 1 of the GRU result is the final hidden state, shaped
            # (1, batch, hidden) for this single-layer, unidirectional GRU.
            last_hidden = self.GRUs[i](packed_feature)[1]
            # Drop only the layer axis. A bare squeeze() would also remove the
            # batch axis when batch_size == 1 (or the hidden axis when
            # hidden_dim == 1).
            embeded_input = last_hidden.squeeze(0).unsqueeze(1)  # b 1 h
            GRU_embeded_input.append(embeded_input)

        GRU_embeded_input = torch.cat(GRU_embeded_input, 1)  # b i h
        weighted_input, feature_importance = self.FeatureAttention(GRU_embeded_input)
        posi_input = self.dropout(weighted_input)  # b i h

        # One scalar per feature; squeeze(-1) keeps the batch axis intact.
        contexts = self.Linear(posi_input).squeeze(-1)  # b i
        output = self.Linear_los(self.dropout(contexts))  # b output_dim
        outcome = self.Linear_outcome(self.dropout(contexts))  # b output_dim
        # torch.sigmoid replaces the deprecated F.sigmoid.
        outcome = torch.sigmoid(outcome)

        return output, None, None, outcome, feature_importance

In [5]:
def _load_pickle(path):
    """Load one pickle file and close the handle afterwards."""
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _load_target_data(data_path, num_features, subset_idx):
    """Load x/y from ``data_path`` and move the ``subset_idx`` columns to the front.

    Returns ``(all_x, all_y, all_time, all_x_len, other_idx)``. ``all_y`` keeps
    the first element of every per-visit label record and ``all_time`` the last
    one; ``other_idx`` lists the feature columns not in ``subset_idx``.
    """
    all_x = _load_pickle(data_path + 'x.pkl')
    all_y = _load_pickle(data_path + 'y.pkl')
    # y.pkl is read a second time on purpose: all_time and all_y must be
    # independent copies because each is overwritten in place below.
    all_time = _load_pickle(data_path + 'y.pkl')
    all_x_len = [len(i) for i in all_x]

    for i in range(len(all_time)):
        for j in range(len(all_time[i])):
            all_time[i][j] = all_time[i][j][-1]
            all_y[i][j] = all_y[i][j][0]

    other_idx = list(range(num_features))
    for i in subset_idx:
        other_idx.remove(i)

    # Subset columns first, then the remaining columns in their original order.
    column_order = list(subset_idx) + other_idx
    for i in range(len(all_x)):
        all_x[i] = np.array(all_x[i], dtype=float)[:, column_order].tolist()

    return all_x, all_y, all_time, all_x_len, other_idx


if target_dataset == 'TJ':
    data_path = './data/Tongji/'
    tar_subset_idx = [2, 3, 4, 9, 13, 14, 26, 27, 30, 32, 34, 38, 39, 41, 52, 53, 66, 74]
    all_x, all_y, all_time, all_x_len, tar_other_idx = _load_target_data(data_path, 75, tar_subset_idx)
elif target_dataset == 'HM':
    data_path = './data/CDSL/'
    tar_subset_idx = [5, 6, 4, 2, 3, 48, 79, 76, 87, 25, 30, 31, 18, 43, 58, 66, 40, 57, 23, 92, 50, 54, 91, 60, 39, 81]
    all_x, all_y, all_time, all_x_len, tar_other_idx = _load_target_data(data_path, 99, tar_subset_idx)
else:
    # Fail here with a clear message rather than with a NameError in a later cell.
    raise ValueError("Unsupported target_dataset: %r (expected 'TJ' or 'HM')" % (target_dataset,))

In [6]:
# Aliases for the loaded data; these are the same objects, not copies.
long_x, long_y, long_time = all_x, all_y, all_time
# Last label of every record in all_y.
long_y_kfold = [record[-1] for record in all_y]

In [7]:
def get_n2n_data(x, y, x_len, outcome=None):
    """Expand every sequence into one sample per prefix.

    For sequence ``i`` and step ``j`` the sample is the prefix ``x[i][:j+1]``
    paired with the labels ``y[i][j]`` and, when given, ``outcome[i][j]``.

    Args:
        x: list of sequences.
        y: per-sequence list of per-step labels.
        x_len: per-sequence lengths; only its length is checked.
        outcome: optional per-sequence list of per-step outcome labels.

    Returns:
        ``(new_x, new_y, new_x_len, new_outcome)``. ``new_outcome`` is ``None``
        when ``outcome`` is ``None``; previously the default argument raised a
        ``TypeError`` on ``len(None)``.
    """
    length = len(x)
    assert length == len(y)
    assert outcome is None or length == len(outcome)
    assert length == len(x_len)
    new_x = []
    new_y = []
    new_outcome = None if outcome is None else []
    new_x_len = []
    for i in range(length):
        for j in range(len(x[i])):
            new_x.append(x[i][:j+1])
            new_y.append(y[i][j])
            if outcome is not None:
                new_outcome.append(outcome[i][j])
            new_x_len.append(j+1)
    return new_x, new_y, new_x_len, new_outcome

def ckd_batch_iter(x, y, lens, batch_size, shuffle=False, outcome=None):
    """Yield mini-batches whose samples are sorted by decreasing sequence length.

    Args:
        x: list of sequences.
        y: per-sample labels.
        lens: per-sample lengths.
        batch_size: maximum number of samples per batch; the last batch may be
            smaller.
        shuffle: shuffle the sample order with ``np.random.shuffle`` before
            batching.
        outcome: optional per-sample outcome labels.

    Yields:
        ``(batch_x, batch_y, batch_lens, batch_outcome)`` lists. Within a batch,
        samples are ordered by ``len(x[idx])`` from longest to shortest and ties
        keep their incoming order. ``batch_outcome`` is ``None`` when
        ``outcome`` is ``None``; previously the default argument raised a
        ``TypeError``.
    """
    index_array = list(range(len(x)))

    if shuffle:
        np.random.shuffle(index_array)

    # Stepping by batch_size yields ceil(len(x) / batch_size) batches.
    for start in range(0, len(index_array), batch_size):
        indices = index_array[start:start + batch_size]
        # sorted() is stable, so equal-length samples keep their relative order.
        indices = sorted(indices, key=lambda idx: len(x[idx]), reverse=True)

        batch_x = [x[idx] for idx in indices]
        batch_y = [y[idx] for idx in indices]
        batch_lens = [lens[idx] for idx in indices]
        batch_outcome = None if outcome is None else [outcome[idx] for idx in indices]

        yield batch_x, batch_y, batch_lens, batch_outcome
        
def pad_sents(sents, pad_token):
    """Right-pad every sequence in ``sents`` with ``pad_token`` to the longest length.

    Returns a numpy array stacking the padded sequences.
    """
    longest = max(len(seq) for seq in sents)
    rows = [
        np.array(list(seq) + [pad_token] * (longest - len(seq)))
        for seq in sents
    ]
    return np.array(rows)

def length_to_mask(length, max_len=None, dtype=None):
    """Turn a 1-D tensor of lengths (B) into a B x max_len mask.

    Entry ``[b, t]`` is true when ``t < length[b]``. ``max_len`` defaults to the
    largest length; when ``dtype`` is given the mask is converted to it.
    """
    assert len(length.shape) == 1, 'Length shape should be 1 dimensional.'
    if not max_len:
        max_len = length.max().item()
    positions = torch.arange(max_len, device=length.device, dtype=length.dtype)
    mask = positions.expand(len(length), max_len) < length.unsqueeze(1)
    if dtype is None:
        return mask
    return torch.as_tensor(mask, dtype=dtype, device=length.device)

def reverse_los(y, los_info):
    """Undo standardisation: ``y * los_info["los_std"] + los_info["los_mean"]``."""
    std = los_info["los_std"]
    mean = los_info["los_mean"]
    return y * std + mean

class TargetMultitaskLoss(nn.Module):
    """Weighted sum of an MSE term and a BCE term.

    ``alpha`` holds one learnable weight per task, initialised to ones:
    ``alpha[0]`` scales the MSE loss and ``alpha[1]`` the BCE loss.
    """

    def __init__(self, task_num=2):
        super().__init__()
        self.task_num = task_num
        self.alpha = nn.Parameter(torch.ones((task_num)), requires_grad=True)
        self.mse = nn.MSELoss()
        self.bce = nn.BCELoss()

    def forward(self, opt_student, los, outcome, outcome_y):
        """Return ``alpha[0] * MSE(opt_student, los) + alpha[1] * BCE(outcome, outcome_y)``."""
        regression_term = self.mse(opt_student, los)
        classification_term = self.bce(outcome, outcome_y)
        weighted_regression = regression_term * self.alpha[0]
        weighted_classification = classification_term * self.alpha[1]
        return weighted_regression + weighted_classification

def get_target_multitask_loss(opt_student, los, outcome, outcome_y):
    """Evaluate a freshly constructed two-task ``TargetMultitaskLoss`` on the inputs.

    A new loss module is built on every call, so its task weights are always
    at their initial value of one.
    """
    criterion = TargetMultitaskLoss(task_num=2)
    return criterion(opt_student, los, outcome, outcome_y)

In [8]:
# Dataset-specific settings: number of input features and data directory.
if target_dataset == 'TJ':
    input_dim, data_path = 75, './data/Tongji/'
elif target_dataset == 'HM':
    input_dim, data_path = 99, './data/CDSL/'

# Model hyperparameters; these must match the checkpoint loaded further down.
cell = 'GRU'
hidden_dim = d_model = 32
MHD_num_head = 4
d_ff = 64
output_dim = 1
batch_size = 256

In [9]:
# Build the network from the hyperparameters above and move it to `device`.
model = distcare_target(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    d_model=d_model,
    MHD_num_head=MHD_num_head,
    d_ff=d_ff,
    output_dim=output_dim,
).to(device)
# Created so that the optimizer state stored in the checkpoint can be restored.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [10]:
# Checkpoint path per dataset. Only 'TJ' has one configured here; other values
# used to fall through and hit a NameError on `file_name`.
if target_dataset == 'TJ':
    file_name = './model/covid/distcare-trans-5-fold-LOS-regression-Attention1'
else:
    raise ValueError("No checkpoint configured for target_dataset: %r" % (target_dataset,))

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
# `device` is the same CUDA-or-CPU choice the old inline expression recomputed.
checkpoint = torch.load(file_name, map_location=device)
save_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
# Switch to inference mode (disables dropout); also displays the module summary.
model.eval()

distcare_target(
  (PositionalEncoding): PositionalEncoding(
    (dropout): Dropout(p=0, inplace=False)
  )
  (GRUs): ModuleList(
    (0): GRU(1, 32, batch_first=True)
    (1): GRU(1, 32, batch_first=True)
    (2): GRU(1, 32, batch_first=True)
    (3): GRU(1, 32, batch_first=True)
    (4): GRU(1, 32, batch_first=True)
    (5): GRU(1, 32, batch_first=True)
    (6): GRU(1, 32, batch_first=True)
    (7): GRU(1, 32, batch_first=True)
    (8): GRU(1, 32, batch_first=True)
    (9): GRU(1, 32, batch_first=True)
    (10): GRU(1, 32, batch_first=True)
    (11): GRU(1, 32, batch_first=True)
    (12): GRU(1, 32, batch_first=True)
    (13): GRU(1, 32, batch_first=True)
    (14): GRU(1, 32, batch_first=True)
    (15): GRU(1, 32, batch_first=True)
    (16): GRU(1, 32, batch_first=True)
    (17): GRU(1, 32, batch_first=True)
    (18): GRU(1, 32, batch_first=True)
    (19): GRU(1, 32, batch_first=True)
    (20): GRU(1, 32, batch_first=True)
    (21): GRU(1, 32, batch_first=True)
    (22): GRU(1, 32, b

In [11]:
# Expand every record into one sample per visit prefix.
x, y, length, outcome = get_n2n_data(long_x, long_time, all_x_len, outcome=long_y)

valid_loss = []
importance = []
y_pred_flatten = []
y_true_flatten = []
outcome_pred_flatten = []
outcome_true_flatten = []

# Mean/std used to map standardised predictions back to the original scale.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open(data_path + 'los_info.pkl', 'rb') as los_info_file:
    los_info = pickle.load(los_info_file)

# Padding row appended to sequences shorter than the longest one in a batch.
pad_token = np.zeros(input_dim)

with torch.no_grad():
    model.eval()
    for batch_x, batch_y, batch_lens, batch_outcome in ckd_batch_iter(x, y, length, batch_size, outcome=outcome):
        batch_x = torch.tensor(pad_sents(batch_x, pad_token), dtype=torch.float32).to(device)
        batch_y = torch.tensor(batch_y, dtype=torch.float32).to(device)
        batch_lens = torch.tensor(batch_lens, dtype=torch.float32).to(device).int()
        batch_outcome = torch.tensor(batch_outcome, dtype=torch.float32).to(device)

        # Distinct names for the model outputs: the loop used to rebind `outcome`
        # (the label list passed to the iterator above) and `attention` (the
        # module-level scaled-dot-product function used by MultiHeadedAttention).
        opt, decov_loss, emb, outcome_pred, batch_importance = model(batch_x, batch_lens)

        pred_loss = get_target_multitask_loss(opt, batch_y.unsqueeze(-1), outcome_pred, batch_outcome.unsqueeze(-1))

        valid_loss.append(pred_loss.cpu().detach().numpy())
        importance.append(batch_importance.cpu().detach().numpy())

        y_pred_flatten += [reverse_los(value, los_info) for value in list(opt.cpu().detach().numpy().flatten())]
        y_true_flatten += [reverse_los(value, los_info) for value in list(batch_y.cpu().numpy().flatten())]
        outcome_pred_flatten += list(outcome_pred.cpu().detach().numpy().flatten())
        outcome_true_flatten += list(batch_outcome.cpu().numpy().flatten())

    valid_loss = np.mean(valid_loss)
    ret = metrics.print_metrics_regression(y_true_flatten, y_pred_flatten, verbose=0)
    ret_outcome = metrics.print_metrics_binary(outcome_true_flatten, outcome_pred_flatten, verbose=0)
    # One row of feature-importance weights per evaluated sample.
    importance = np.concatenate(importance, axis=0)

ret, ret_outcome



({'mad': 3.032338399797889,
  'mse': 19.29585890315027,
  'mape': 268.5830375564429,
  'kappa': 0.6376578028471662},
 {'acc': 0.94953054,
  'prec0': 0.9581749,
  'prec1': 0.9355828,
  'rec0': 0.96,
  'rec1': 0.93272173,
  'auroc': 0.98885248288918,
  'auprc': 0.983666733753241,
  'minpse': 0.9327217125382263,
  'f1_score': 0.9341501144250984})

In [14]:
# Number of rows in the concatenated feature-importance array (one per evaluated sample).
len(importance)

1704