diff --git a/easy_tpp/model/torch_model/torch_rmtpp.py b/easy_tpp/model/torch_model/torch_rmtpp.py index 0e9fc14..da327f5 100644 --- a/easy_tpp/model/torch_model/torch_rmtpp.py +++ b/easy_tpp/model/torch_model/torch_rmtpp.py @@ -24,8 +24,8 @@ def __init__(self, model_config): self.layer_hidden = nn.Linear(self.hidden_size, self.num_event_types) - self.factor_intensity_base = torch.empty([1, 1, self.num_event_types], device=self.device) - self.factor_intensity_current_influence = torch.empty([1, 1, self.num_event_types], device=self.device) + self.factor_intensity_base = torch.nn.Parameter(torch.empty([1, 1, self.num_event_types], device=self.device)) + self.factor_intensity_current_influence = torch.nn.Parameter(torch.empty([1, 1, self.num_event_types], device=self.device)) nn.init.xavier_normal_(self.factor_intensity_base) nn.init.xavier_normal_(self.factor_intensity_current_influence)