In [4]:
import numpy as np
import torch
from datetime import datetime, timedelta
from torch.nn import Linear, ModuleList, Parameter

In [127]:
class DyRepNode(torch.nn.Module):
    def __init__(self, num_nodes, hidden_dim, random_state, first_date, end_datetime, num_neg_samples= 5, num_time_samples = 10,
                 device='cpu', all_comms=False, train_td_max=None):
        super(DyRepNode, self).__init__()
        self.w_t = Parameter(0.5*torch.ones(1))
        self.alpha = Parameter(0.5*torch.ones(1))
        self.psi = Parameter(0.5*torch.ones(1))
        self.omega = Linear(in_features=hidden_dim, out_features=1)
        

        self.W_h = Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.W_event_to_neigh = Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.W_rec_event = Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.W_rec_neigh = Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.W_t = Linear(1,hidden_dim)

    def forward(self, data):
        u, time_delta, time_bar, time_cur,significance,magnitudo = data[:6]

        batch_size = len(u)
        u_all = u.data.cpu().numpy()
        
        # testing일때, A_pred, surv를 초기화 세팅
        if not self.training:
            A_pred, surv, lambda_pred = None, None, None
            A_pred = self.A.new_zeros((batch_size, self.num_nodes, self.num_nodes))
            surv = self.A.new_zeros((batch_size, self.num_nodes, self.num_nodes))

        # *** time의 shape 알고 수정 필요
        # *** time을 normalize할 바에, fixed encoding을 하는 방식을 어떨까?
        time_mean = torch.from_numpy(np.array([0, 0, 0, 0])).float().to(self.device).view(1, 1, 4)
        time_sd = torch.from_numpy(np.array([50, 7, 15, 15])).float().to(self.device).view(1, 1, 4)
        time_diff = (time_diff - time_mean) / time_sd
        return time_diff
    def update_node_embedding_without_attention(self, prev_embedding, u_event,u_neighborhood, time_delta_it):
        """
        주어진 node embedding과 시간 차이를 사용하여 node embedding을 업데이트합니다.

        Args:
        prev_embedding (torch.Tensor): 이전 embedding (shape: [num_nodes, hidden_dim])
        u_event: event가 발생한 노드
        u_neighborhood: event가 발생 노드의 이웃
        time_delta_it (torch.Tensor): 시간 차이 (shape: [batch_size, 1])

        Returns:
        torch.Tensor: 업데이트된 node embedding
        """
        # *** attention 사용하지 않은 버전

        # 이전 embedding을 복제하여 새로운 embedding 생성
        z_new = prev_embedding.clone()

        #neighborhood node에 대한 업데이트
        z_new[u_neighborhood] = torch.sigmoid(self.W_event_to_neigh(prev_embedding[u_event]) + \
                                  self.W_rec_neigh(prev_embedding[u_neighborhood]) + \
                                  self.W_t(time_delta_it[u_neighborhood].view(len(u_neighborhood),3)))
        
        #event node에 대한 update 
        z_new[u_event] = torch.sigmoid(self.W_rec_event(prev_embedding[u_event]) + \
                                  self.W_t(time_delta_it[u_event]))
        return z_new

In [128]:
rnd = np.random.RandomState(42)
model = DyRepNode(num_nodes=4,
                  hidden_dim=2,
                  random_state= rnd,
                  first_date=1,
                  end_datetime=10,
                  num_neg_samples=10,
                  num_time_samples=5,
                  device='cpu',
                  train_td_max=5)


In [129]:
for name, param in model.named_parameters():
    print(name,param.shape,param.data)

w_t torch.Size([1]) tensor([0.5000])
alpha torch.Size([1]) tensor([0.5000])
psi torch.Size([1]) tensor([0.5000])
omega.weight torch.Size([1, 2]) tensor([[-0.5500, -0.6468]])
omega.bias torch.Size([1]) tensor([-0.4260])
W_h.weight torch.Size([2, 2]) tensor([[-0.0780,  0.5686],
        [-0.6080,  0.2216]])
W_h.bias torch.Size([2]) tensor([0.0022, 0.4817])
W_event_to_neigh.weight torch.Size([2, 2]) tensor([[-0.1437,  0.2169],
        [ 0.1113,  0.0093]])
W_event_to_neigh.bias torch.Size([2]) tensor([0.1255, 0.5880])
W_rec_event.weight torch.Size([2, 2]) tensor([[-0.2672, -0.6409],
        [ 0.4996, -0.2756]])
W_rec_event.bias torch.Size([2]) tensor([0.6108, 0.2024])
W_rec_neigh.weight torch.Size([2, 2]) tensor([[ 0.0626, -0.2621],
        [-0.4605, -0.6585]])
W_rec_neigh.bias torch.Size([2]) tensor([0.6138, 0.6745])
W_t.weight torch.Size([2, 1]) tensor([[ 0.0634],
        [-0.6298]])
W_t.bias torch.Size([2]) tensor([-0.7038, -0.1371])


In [131]:
prev_embedding = [[1,1],[2,2],[3,3],[4,4]]
prev_embedding=torch.Tensor(prev_embedding)
time_delta_it =torch.Tensor([[2,3,1],[2,3,1],[2,3,1],[2,3,1]])
u_neighborhood = [1,2]
u_event = 0

In [132]:
model.update_node_embedding_without_attention(prev_embedding, u_event=u_event, u_neighborhood=u_neighborhood, time_delta_it=time_delta_it)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 1x2)