In [2]:
import torch
import numpy as np
import torch.nn as nn
from transformer.Modules import ScaledDotProductAttention
from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward


In [None]:
class AttnBlock(nn.Module):
    """One attention layer: a multi-head attention sublayer followed by a
    position-wise feed-forward sublayer.

    NOTE(review): both sublayers are constructed without arguments and the
    attention sublayer is called with a single tensor. Their signatures live
    in ``transformer.SubLayers`` and are not visible here -- confirm that they
    accept this usage.
    """

    def __init__(self):
        super().__init__()
        self.attn = MultiHeadAttention()
        self.pff = PositionwiseFeedForward()

    def forward(self, X):
        """Run ``X`` through attention, then through the feed-forward sublayer.

        Args:
            X: input passed unchanged to the attention sublayer.

        Returns:
            A 2-tuple ``(hidden, weights)``: the feed-forward output and the
            second value returned by the attention sublayer (presumably the
            attention weights -- TODO confirm).
        """
        attended, weights = self.attn(X)
        hidden = self.pff(attended)
        return hidden, weights

In [None]:
class parameters(nn.Module):
    """Container module holding two learnable scalars, ``alpha`` and ``beta``.

    Attributes:
        alpha: scalar ``nn.Parameter`` initialised to 0.5.
        beta: scalar ``nn.Parameter`` initialised to 1.0.
    """

    def __init__(self):
        # nn.Module must be initialised before parameters can be registered
        # on the instance; without this call attribute assignment raises.
        super().__init__()
        # ``nn.Parameter`` is the class; ``nn.parameter`` is the submodule
        # that defines it and is not callable.
        self.alpha = nn.Parameter(torch.tensor(0.5))
        self.beta = nn.Parameter(torch.tensor(1.0))

In [None]:
class Model(nn.Module):
    """Stack of ``AttnBlock`` layers over summed time and event embeddings,
    with linear heads that predict a time value and event logits.

    Args:
        B: first dimension of the mask built by ``create_time_mask``
            (presumably the batch size -- confirm against the caller).
        T: second dimension of that mask (presumably the sequence length).
        D: number of rows in the event embedding table and width of the
            ``event_pred`` output.
        embed_dim: width of the event embedding and of the hidden state fed
            to both prediction heads.
        num_attention: number of ``AttnBlock`` layers to stack.
        time_embedding: optional callable/module applied to ``time`` in
            ``forward``; its output must be addable to the event embedding.
            Defaults to ``None``, in which case it must be assigned to
            ``self.time_embedding`` before ``forward`` is called.

    Attributes:
        alpha, beta: learnable scalars (initialised to 0.1 and 1.0). They are
            not used by ``forward`` in this class.
    """

    def __init__(self, B, T, D, embed_dim, num_attention, time_embedding=None):
        super().__init__()

        self.B = B
        self.T = T
        self.D = D
        self.embed_dim = embed_dim
        self.num_attention = num_attention

        # No time embedding is defined by this class itself; it stays None
        # unless one is supplied here or assigned later.
        self.time_embedding = time_embedding
        self.event_embedding = nn.Embedding(self.D, self.embed_dim)

        self.AttnBlocks = nn.ModuleList(
            [AttnBlock() for _ in range(num_attention)]
        )

        self.time_pred = nn.Linear(self.embed_dim, 1)
        self.event_pred = nn.Linear(self.embed_dim, self.D)

        ### Params
        # ``nn.Parameter`` is the class; ``nn.parameter`` is a (non-callable)
        # submodule.
        self.alpha = nn.Parameter(torch.tensor(0.1))
        self.beta = nn.Parameter(torch.tensor(1.0))

    def forward(self, time, event):
        """Embed the inputs, run the attention stack and apply both heads.

        Args:
            time: input handed to ``self.time_embedding``.
            event: index tensor handed to the event ``nn.Embedding``; its
                leading dimensions must match the ``(B, T)`` mask.

        Returns:
            ``(H, pred_t, pred_e)``: the final hidden state, the output of
            ``time_pred`` (last dimension 1) and the output of ``event_pred``
            (last dimension ``D``).

        Raises:
            RuntimeError: if no time embedding has been configured.
        """
        if self.time_embedding is None:
            raise RuntimeError(
                "Model.time_embedding is not set; pass `time_embedding=` to "
                "the constructor or assign `model.time_embedding` before "
                "calling forward()."
            )

        time_mask = self.create_time_mask()
        Z = self.time_embedding(time)
        UY = self.event_embedding(event) * time_mask
        H = Z + UY

        for block in self.AttnBlocks:
            # AttnBlock.forward returns (hidden, attention); only the hidden
            # state is carried to the next layer.
            H, _ = block(H)
            # Out-of-place add: avoids mutating a tensor that autograd may
            # have saved for the backward pass.
            H = H + Z

        # Mirrors the original arithmetic: Z is subtracted once after the
        # loop, cancelling the residual added after the last block.
        H = H - Z

        pred_t = self.time_pred(H)
        pred_e = self.event_pred(H)

        return H, pred_t, pred_e

    def create_time_mask(self):
        """Build the 0/1 mask that multiplies the event embedding.

        Returns:
            Tensor of shape ``(B, T, embed_dim)`` on the same device and with
            the same dtype as the event embedding weights. Entry ``[b, t, :]``
            is 1 where ``t > b`` and 0 otherwise.

        NOTE(review): the upper triangle is taken over the (B, T) axes, so
        which positions are zeroed depends on the row index ``b``. Confirm
        this is the intended mask.
        """
        weight = self.event_embedding.weight
        ones = torch.ones(
            (self.B, self.T), dtype=weight.dtype, device=weight.device
        )
        upper = torch.triu(ones, diagonal=1)
        # The mask multiplies an embedding of width embed_dim, so it is
        # expanded to that width (not D, the number of embedding rows).
        return upper.unsqueeze(2).expand(-1, -1, self.embed_dim)

        




