Skip to content

Commit

Permalink
debug usgan
Browse files Browse the repository at this point in the history
  • Loading branch information
AugustJW committed Apr 3, 2024
1 parent 590d7be commit 1079577
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 20 deletions.
9 changes: 7 additions & 2 deletions pypots/imputation/brits/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def impute(

estimations = torch.cat(estimations, dim=1)
imputed_data = masks * values + (1 - masks) * estimations
return imputed_data, hidden_states, reconstruction_loss
return imputed_data, estimations, hidden_states, reconstruction_loss

def forward(self, inputs: dict, direction: str = "forward") -> dict:
"""Forward processing of the NN module.
Expand All @@ -190,7 +190,7 @@ def forward(self, inputs: dict, direction: str = "forward") -> dict:
A dictionary includes all results.
"""
imputed_data, hidden_state, reconstruction_loss = self.impute(inputs, direction)
imputed_data, estimations, hidden_state, reconstruction_loss = self.impute(inputs, direction)
# for each iteration, reconstruction_loss increases its value for 3 times
reconstruction_loss /= self.n_steps * 3

Expand All @@ -200,6 +200,7 @@ def forward(self, inputs: dict, direction: str = "forward") -> dict:
), # single direction, has no consistency loss
"reconstruction_loss": reconstruction_loss,
"imputed_data": imputed_data,
"reconstructed_data": estimations,
"final_hidden_state": hidden_state,
}
return ret_dict
Expand Down Expand Up @@ -304,6 +305,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
ret_b = self._reverse(self.rits_b(inputs, "backward"))

imputed_data = (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2
reconstructed_data = (ret_f["reconstructed_data"] + ret_b["reconstructed_data"]) / 2

results = {
"imputed_data": imputed_data,
Expand All @@ -323,5 +325,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict:

# `loss` is always the item for backward propagating to update the model
results["loss"] = loss
results['reconstructed_data'] = reconstructed_data
results['f_reconstructed_data'] = ret_f['reconstructed_data']
results['b_reconstructed_data'] = ret_b['reconstructed_data']

return results
14 changes: 4 additions & 10 deletions pypots/imputation/usgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,14 @@ def _train_model(

try:
training_step = 0
epoch_train_loss_G_collector = []
epoch_train_loss_D_collector = []
for epoch in range(1, self.epochs + 1):
self.model.train()
step_train_loss_G_collector = []
step_train_loss_D_collector = []
for idx, data in enumerate(training_loader):
training_step += 1
inputs = self._assemble_input_for_training(data)

step_train_loss_G_collector = []
step_train_loss_D_collector = []

if idx % self.G_steps == 0:
self.G_optimizer.zero_grad()
results = self.model.forward(
Expand All @@ -278,9 +275,6 @@ def _train_model(
mean_step_train_D_loss = np.mean(step_train_loss_D_collector)
mean_step_train_G_loss = np.mean(step_train_loss_G_collector)

epoch_train_loss_D_collector.append(mean_step_train_D_loss)
epoch_train_loss_G_collector.append(mean_step_train_G_loss)

# save training loss logs into the tensorboard file for every step if in need
# Note: the `training_step` is not the actual number of steps that Discriminator and Generator get
# trained, the actual number should be D_steps*training_step and G_steps*training_step accordingly
Expand All @@ -292,8 +286,8 @@ def _train_model(
self._save_log_into_tb_file(
training_step, "training", loss_results
)
mean_epoch_train_D_loss = np.mean(epoch_train_loss_D_collector)
mean_epoch_train_G_loss = np.mean(epoch_train_loss_G_collector)
mean_epoch_train_D_loss = np.mean(step_train_loss_D_collector)
mean_epoch_train_G_loss = np.mean(step_train_loss_G_collector)

if val_loader is not None:
self.model.eval()
Expand Down
16 changes: 8 additions & 8 deletions pypots/imputation/usgan/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ....utils.metrics import calc_mse

from .submodules import Discriminator
from ...brits.modules import _BRITS
Expand Down Expand Up @@ -62,24 +63,23 @@ def forward(
if training:
forward_X = inputs["forward"]["X"]
forward_missing_mask = inputs["forward"]["missing_mask"]

inputs["discrimination"] = self.discriminator(
forward_X, forward_missing_mask
)
imputed_data = results['imputed_data']

if training_object == "discriminator":
inputs["discrimination"] = self.discriminator(imputed_data.detach(), forward_missing_mask)
l_D = F.binary_cross_entropy_with_logits(
inputs["discrimination"], forward_missing_mask
)
results["discrimination_loss"] = l_D
else:
inputs["discrimination"] = inputs["discrimination"].detach()
l_G = F.binary_cross_entropy_with_logits(
inputs["discrimination"] = self.discriminator(imputed_data, forward_missing_mask)
l_G = -F.binary_cross_entropy_with_logits(
inputs["discrimination"],
1 - forward_missing_mask,
forward_missing_mask,
weight=1 - forward_missing_mask,
)
loss_gene = l_G + self.lambda_mse * results["loss"]
reconstruction_loss = calc_mse(forward_X, results['reconstructed_data'], forward_missing_mask) + 0.1 * calc_mse(results['f_reconstructed_data'], results['b_reconstructed_data'])
loss_gene = l_G + self.lambda_mse * reconstruction_loss
results["generation_loss"] = loss_gene

return results

0 comments on commit 1079577

Please sign in to comment.