In [47]:
import torch.nn as nn
import torch
import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc

In [48]:
def initialise_linear_layer(layer):
    if isinstance(layer, nn.Linear):
        torch.nn.init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            torch.nn.init.zeros_(layer.bias)
            
def initialise_parameters(layer, method):
    if isinstance(layer, nn.Parameter):
        if method=='glorot_uniform':
            torch.nn.init.xavier_uniform_(layer)
        elif method=='zeros':
            torch.nn.init.zeros_(layer)
        elif method=='ones':
            torch.nn.init.ones_(layer)

In [49]:
class CVE(nn.Module):
    def __init__(self, hid_dim, output_dim):
        super(CVE, self).__init__()
        self.stack = nn.Sequential(
            nn.Linear(in_features=1, out_features=hid_dim, bias=True),
            nn.Tanh(),
            nn.Linear(in_features=hid_dim, out_features=output_dim, bias=False)
        )
        self.stack.apply(initialise_linear_layer)
        
    def forward(self, X):
        X = X.unsqueeze(dim=-1)
        return self.stack(X)

In [50]:
class Attention(nn.Module):
    def __init__(self, d, hid_dim):
        super(Attention, self).__init__()
        self.stack = nn.Sequential(
            nn.Linear(in_features=d, out_features=hid_dim, bias=True),
            nn.Tanh(),
            nn.Linear(in_features=hid_dim, out_features=1, bias=False)
        )
        self.softmax = nn.Softmax(dim=-2)
        self.stack.apply(initialise_linear_layer)
    
    def forward(self, X, mask, mask_value=-1e30):
        attn_weights = self.stack(X)
        mask = torch.unsqueeze(mask, dim=-1)
        attn_weights = mask*attn_weights + (1-mask)*mask_value
        attn_weights = self.softmax(attn_weights)
        
        return attn_weights

In [51]:
class Transformer(nn.Module):
    def __init__(self, d, N=2, h=8, dk=None, dv=None, dff=None, dropout=0, epsilon=1e-07):
        super(Transformer, self).__init__()
        
        self.N, self.h, self.dk, self.dv, self.dff, self.dropout = N, h, dk, dv, dff, dropout
        self.epsilon = epsilon * epsilon
        if self.dk==None:
            self.dk = d // self.h
        if self.dv==None:
            self.dv = d//self.h
        if self.dff==None:
            self.dff = 2*d
        
        self.Wq = nn.Parameter(torch.empty(self.N, self.h, d, self.dk))
        initialise_parameters(self.Wq, 'glorot_uniform')
        
        self.Wk = nn.Parameter(torch.empty(self.N, self.h, d, self.dk))
        initialise_parameters(self.Wk, 'glorot_uniform')
        
        self.Wv = nn.Parameter(torch.empty(self.N, self.h, d, self.dk))
        initialise_parameters(self.Wv, 'glorot_uniform')
        
        self.Wo = nn.Parameter(torch.empty(self.N, self.dv*self.h, d))
        initialise_parameters(self.Wo, 'glorot_uniform')
        
        
        self.W1 = nn.Parameter(torch.empty(self.N, d, self.dff))
        initialise_parameters(self.W1, 'glorot_uniform')
        
        self.b1 = nn.Parameter(torch.empty(self.N, self.dff))
        initialise_parameters(self.b1, 'zeros')
        
        self.W2 = nn.Parameter(torch.empty(self.N, self.dff, d))
        initialise_parameters(self.W2, 'glorot_uniform')
        
        self.b2 = nn.Parameter(torch.empty(self.N, d))
        initialise_parameters(self.b2, 'zeros')
        
        
        self.gamma = nn.Parameter(torch.empty(2*self.N,))
        initialise_parameters(self.gamma, 'ones')
        
        self.beta = nn.Parameter(torch.empty(2*self.N,))
        initialise_parameters(self.beta, 'zeros')
        
        
        self.dropout_layer = nn.Dropout(p=self.dropout)
        self.identity = nn.Identity()
        
    def forward(self, X, mask, mask_value=-1e-30):
        mask = torch.unsqueeze(mask, dim=-2)
        
        for i in range(self.N):
            mha_ops = []
            for j in range(self.h):
                q = torch.matmul(X, self.Wq[i,j,:,:])
                k = torch.matmul(X, self.Wk[i,j,:,:]).permute(0,2,1)
                v = torch.matmul(X, self.Wv[i,j,:,:])
                A = torch.bmm(q, k)
                A = mask * A + (1-mask) * mask_value
                
                def dropped_A():
                    dp_mask = (torch.rand_like(A)>=self.dropout).type(dtype=torch.float32)
                    return A*dp_mask + (1-dp_mask)*mask_value
                # Dropout
                if self.training:
                    A = dropped_A()
                else:
                    A = self.identity(A)
                    
                A = nn.functional.softmax(A, dim=-1)
                
                mha_ops.append(torch.bmm(A, v))
            
            conc = torch.cat(mha_ops, dim=-1)
            proj = torch.matmul(conc, self.Wo[i,:,:])
            # Dropout
            if self.training:
                proj = self.identity(self.dropout_layer(proj))
            else:
                proj = self.identity(proj)
            
            # Add
            X = X + proj
            # Layer Normalisation
            mean = torch.mean(X, dim=-1, keepdim=True)
            variance = torch.mean(torch.square(X - mean), axis=-1 ,keepdims=True)
            std = torch.sqrt(variance + self.epsilon)
            X  = (X-mean)/std
            X = X * self.gamma[2*i] + self.beta[2*i]
            
            # FFN
            ffn_op = torch.add(torch.matmul(nn.functional.relu(torch.add(torch.matmul(X, self.W1[i,:,:]), self.b1[i,:])), self.W2[i,:,:]),self.b2[i,:])
            # FFN Dropout
            if self.training:
                ffn_op = self.dropout_layer(ffn_op)
            else:
                ffn_op = self.identity(ffn_op)
            
            # Add
            X = X + ffn_op
            # Layer Normalisation
            mean = torch.mean(X, dim=-1, keepdim=True)
            variance = torch.mean(torch.square(X - mean), axis=-1 ,keepdims=True)
            std = torch.sqrt(variance + self.epsilon)
            X = (X-mean)/std
            X = X*self.gamma[2*i+1] + self.beta[2*i+1]
        return X            

In [52]:
class STraTS(nn.Module):
    def __init__(self, D, V, d, N, he, dropout, forecast=False):
        super(STraTS, self).__init__()
        total_parameters = 0
        cve_units = int(np.sqrt(d))
        # Inputs: max_len * batch_size
        self.varis_stack = nn.Embedding(V+1, d)
        num_params = sum(p.numel() for p in self.varis_stack.parameters())
        print(f'varis_stack: {num_params}')
        total_parameters += num_params
        
        self.values_stack = CVE(
            hid_dim=cve_units, 
            output_dim=d
        )        
        num_params = sum(p.numel() for p in self.values_stack.parameters())
        print(f'values_stack: {num_params}')
        total_parameters += num_params
        
        self.times_stack = CVE(
            hid_dim=cve_units, 
            output_dim=d
        )        
        num_params = sum(p.numel() for p in self.times_stack.parameters())
        print(f'times_stack: {num_params}')
        total_parameters += num_params
        
        
        # Transformer Output = batch_size * max_len * d
        self.cont_stack = Transformer(
            d=d, 
            N=N, 
            h=he, 
            dk=None, 
            dv=None, 
            dff=None, 
            dropout=dropout, 
            epsilon=1e-07
        )
        num_params = sum(p.numel() for p in self.cont_stack.parameters())
        print(f'cont_stack: {num_params}')
        total_parameters += num_params
        
        # Attention Output = batch_size * max_len * 1 
        self.attn_stack = Attention(
            d=d,
            hid_dim=2*d
        )
        num_params = sum(p.numel() for p in self.attn_stack.parameters())
        print(f'attn_stack: {num_params}')
        total_parameters += num_params
        
        # Demographics Input : batch_size * D
        # Demographics Output: batch_size * d
        self.demo_stack = nn.Sequential(
            nn.Linear(in_features=D, out_features=2*d),
            nn.Tanh(),
            nn.Linear(in_features=2*d, out_features=d),
            nn.Tanh()
        )
        num_params = sum(p.numel() for p in self.demo_stack.parameters())
        print(f'demo_stack: {num_params}')
        total_parameters += num_params
        
        # Output Layer Inputs: Attention Weight * Time Series Embedding + Demographic Encoding = batch_size * (+d)
        if forecast:
            self.output_stack = nn.Linear(in_features=d+d, out_features=V)
        else:
            self.output_stack = nn.Sequential(
                nn.Linear(in_features=d+d, out_features=1),
                nn.Sigmoid()
            )
        num_params = sum(p.numel() for p in self.output_stack.parameters())
        print(f'output_stack: {num_params}')
        total_parameters += num_params
        
        print(f'Total Parameters: {total_parameters}')
    
    def forward(self, demo, times, values, varis):
        
        demo_enc = self.demo_stack(demo)
        varis_emb = self.varis_stack(varis)
        values_emb = self.values_stack(values)
        times_emb = self.times_stack(times)
        print(f'varis_emb: {varis_emb.shape}')
        print(f'values_emb: {values_emb.shape}')
        print(f'times_emb: {times_emb.shape}')
        
        comb_emb = varis_emb + values_emb + times_emb
        print(f'comb_emb: {comb_emb.shape}')
        
        mask = torch.clamp(varis, 0,1)
        print(f'Mask: {mask.shape}')
        
        cont_emb = self.cont_stack(comb_emb, mask)
        print(f'cont_emb: {cont_emb.shape}')
        
        # Calculating the weights for cont_emb
        attn_weights = self.attn_stack(cont_emb, mask)
        print(f'attn_weights: {attn_weights.shape}')
        
        # Getting the weighted avg from the embeddings
        fused_emb = torch.sum(cont_emb * attn_weights, dim=-2)
        print(f'fused_emb: {fused_emb.shape}')
        
        # Combining Time Series Embedding with Demographic Embeddings
        conc = torch.cat([fused_emb, demo_enc], dim=-1)
        print(f'conc: {conc.shape}')
        
        # Generating Output
        output = self.output_stack(conc)
        print(f'output: {output.shape}')
        
        return output
        

In [53]:
model = STraTS(D=2, V=129, d=50, N=2, he=4, dropout=0, forecast=True)

varis_stack: 6500
values_stack: 364
times_stack: 364
cont_stack: 39508
attn_stack: 5200
demo_stack: 5350
output_stack: 13029
Total Parameters: 70315


In [54]:
model

STraTS(
  (varis_stack): Embedding(130, 50)
  (values_stack): CVE(
    (stack): Sequential(
      (0): Linear(in_features=1, out_features=7, bias=True)
      (1): Tanh()
      (2): Linear(in_features=7, out_features=50, bias=False)
    )
  )
  (times_stack): CVE(
    (stack): Sequential(
      (0): Linear(in_features=1, out_features=7, bias=True)
      (1): Tanh()
      (2): Linear(in_features=7, out_features=50, bias=False)
    )
  )
  (cont_stack): Transformer(
    (dropout_layer): Dropout(p=0, inplace=False)
    (identity): Identity()
  )
  (attn_stack): Attention(
    (stack): Sequential(
      (0): Linear(in_features=50, out_features=100, bias=True)
      (1): Tanh()
      (2): Linear(in_features=100, out_features=1, bias=False)
    )
    (softmax): Softmax(dim=-2)
  )
  (demo_stack): Sequential(
    (0): Linear(in_features=2, out_features=100, bias=True)
    (1): Tanh()
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): Tanh()
  )
  (output_stack): Linear(in_fe

In [69]:
varis = torch.randint(high=130, size=(2,880))
values = torch.rand((2,880))
times = torch.rand((2,880))
demo = torch.rand((2,2))

In [70]:
y_pred = model(demo, times,values, varis)

varis_emb: torch.Size([2, 880, 50])
values_emb: torch.Size([2, 880, 50])
times_emb: torch.Size([2, 880, 50])
comb_emb: torch.Size([2, 880, 50])
Mask: torch.Size([2, 880])
cont_emb: torch.Size([2, 880, 50])
attn_weights: torch.Size([2, 880, 1])
fused_emb: torch.Size([2, 50])
conc: torch.Size([2, 100])
output: torch.Size([2, 129])


In [60]:
def forecast_loss(y_true, y_pred):
    return torch.sum(y_true[:,V:]*(y_true[:,:V]- y_pred)**2, dim=-1)


def get_results(y_true, y_pred):
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
    pr_auc = auc(recall, precision)
    minrp = np.minimum(precision, recall).max()
    roc_auc = roc_auc_score(y_true, y_pred)
    return [roc_auc, pr_auc, minrp]


def mortality_loss(y_true, y_pred):
    return nn.BCELoss(y_true, y_pred, reduction='mean')


class Evaluation_Callback():
    def __init__(self, val_dataloader):
        self.val_dataloader = val_dataloader
        self.logs = {'epoch': [], 'pr_auc': [], 'roc_auc': [], 'min_rp': [], 'loss': []}

    def on_epoch_end(self, model, epoch, loss_fn):
        model.eval()
        Y = []
        Y_pred = []
        loss
        with torch.no_grad():
            for X, y in self.val_dataloader:
                y_pred = model.predict(X)
                loss += loss_fn(y_pred, y).detach().cpu().item() * len(y)
                Y.append(y); Y_pred.append(y_pred)
                
                
        loss /= len(self.val_dataloader)
        precision, recall, thresholds = precision_recall_curve(Y, Y_pred)
        pr_auc = auc(recall, precision)
        roc_auc = roc_auc_score(Y, Y_pred)
        min_rp = np.minimum(precision, recall).max()
        self.logs['epoch'].append(epoch) 
        self.logs['pr_auc'].append(pr_auc); self.logs['roc_auc'].append(roc_auc); self['min_rp'].append(min_rp)
        self.logs['loss'].append(loss);
        print(f'Val Metrics: PR_AUC: {pr_auc:.6f} ROC_AUC: {roc_auc:.6f} MIN_RP: {min_rp:.6f} BCE_LOSS: {loss:.6f}')
    
        return pr_auc, roc_auc, min_rp, loss
                   
    def get_logs():
        return pd.DataFrame(self.logs)
              

class Early_Stopper():
    def __init__(self, patience=5, min_delta=0, mode='max', restore_best_weights=False):
        self.patience = patience
        self.wait = 0
        self.min_delta = min_delta
        self.mode = mode
        self.stopped_epoch = 0
        self.restore_best_weights = restore_best_weights
        self.best_weights = None
        self.best = np.inf if mode=='min' else -np.inf
        
        if mode=='min':
            self.monitor_op = np.less
        elif mode=='max':
            self.monitor_op = np.greater
              
    def on_epoch_end(self, model, loss, epoch):
        if self.monitor_op(loss - self.min_delta, self.best):
            self.best = loss
            self.wait = 0
            if self.restore_best_weights:
                self.best_weights = model.state_dict()
        else:
            self.wait += 1
            
            if self.wait>=self.patience:
                self.stopped_epoch = epoch
                print(f'Early Stopping at Epoch {epoch} with best loss of {self.best:.6f}')
                if self.restore_best_weights:
                      print(f'Restoring best weights at Epoch {self.stopped_epoch-self.wait}')
                      model.load_state_dict(self.best_weights)
        return

In [61]:
from torch.utils.data import Dataset, DataLoader

class MultipleInputsDataset(Dataset):
    def __init__(self, X, Y):
        super(MultipleInputsDataset, self).__init__()
        
        self.X = X
        self.Y = Y
    
    def __getitem__(self, idx):
        
        return *[x[idx] for x in self.X], self.Y[idx]
        
    def __len__(self):
        return len(self.Y)

In [62]:
import torch
import numpy as np
import gzip

In [77]:
V = 129
def forecast_loss(y_true, y_pred, V):
    return torch.sum(y_true[:,V:]*(y_true[:,:V]- y_pred)**2, dim=-1).mean()

In [78]:
y_pred.shape

torch.Size([2, 129])

In [79]:
f = gzip.GzipFile('/home/FYP/szhong005/fyp/multi_modal/STraTS_torch/forecast_datasets/val_y.npy.gz', "r")
val_y = np.load(f)

In [80]:
y_true = torch.tensor(val_y[0:2],dtype=torch.float32, requires_grad=True)

In [82]:
forecast_loss(y_true, y_pred, V).backward()

In [68]:
y_pred.grad

  y_pred.grad
