Skip to content

Commit

Permalink
Added median and percentile loss metrics for each time step
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Edidin committed Apr 17, 2024
1 parent 62bd8b2 commit 837cd15
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/scripts/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,19 @@ def eval_metrics(
/ torch.sum(movement_mask)
).item()

l1_error = torch.abs(target[mask] - output[mask])[:, -1]
log_dict["avg_l1"] = torch.sum(l1_error) / torch.sum(mask)
log_dict["med_l1"] = torch.median(l1_error)
log_dict["ninetieth_l1"] = torch.quantile(l1_error, 0.9)
log_dict |= {"std_mean": std.mean(), "std_std": std.std(), "corr_mean": corr.mean(), "corr_std": corr.std()}
log_dict |= {"max_loss": max_loss, "average_movement": average_movement}

for t in range(cfg.time_steps):
if t != 0:
movement_mask = target[:, t, :, :] - target[:, 0, :, :] != 0
total_movement = torch.sum(torch.abs((target[:, t, :, :] - target[:, 0, :, :])))
log_dict[f"average_movement/time_{t}"] = total_movement / torch.sum(movement_mask)
log_dict[f"l1_movement_average/time_{t}"] = torch.sum(
torch.abs(torch.squeeze(target[:, t, :, :], dim=1)[movement_mask] - torch.squeeze(output[:, t, :, :], dim=1)[movement_mask])
) / torch.sum(movement_mask)
l1 = torch.abs(target[:, t, :, :][movement_mask] - output[:, t, :, :][movement_mask])
log_dict[f"l1_movement_average/time_{t}"] = torch.mean(l1)
log_dict[f"l1_movement_median/time_{t}"] = torch.median(l1)
log_dict[f"l1_movement_90th_percentile/time_{t}"] = torch.quantiel(l1, 0.9)

log_dict[f"error/time_{t}"] = diff[:, t, :, :].mean()

log_dict = {f"{prefix}/{key}": value for key, value in log_dict.items()}
Expand Down

0 comments on commit 837cd15

Please sign in to comment.