Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/ground_truth_testing' into groun…
Browse files Browse the repository at this point in the history
…d_truth_testing
  • Loading branch information
Ben Edidin committed Mar 30, 2024
2 parents 2a7adf7 + f29e9f9 commit 7646170
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/scripts/train_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def train(cfg: DictConfig) -> None:
wandb.init(project="oc-data-training", entity="atari-obj-pred", name=cfg.name + cfg.game,
config=typing.cast(Dict[Any, Any], OmegaConf.to_container(cfg)))
wandb.log({"batch_size": cfg.batch_size})
wandb.watch(feature_extract, log="gradients", log_freq=100, idx=1)
wandb.watch(predictor, log="gradients", log_freq=100, idx=2)
wandb.watch(feature_extract, log=None, log_freq=100, idx=1)
wandb.watch(predictor, log=None, log_freq=100, idx=2)

criterion = nn.MSELoss().to(device)
optimizer = torch.optim.Adam(list(feature_extract.parameters()) + list(predictor.parameters()), lr=1e-3)
Expand All @@ -46,7 +46,7 @@ def train(cfg: DictConfig) -> None:
if cfg.ground_truth_masks:
bbox_ints = bboxes * 128
bbox_ints = bbox_ints.int()
masks = torch.zeros(masks.size())
masks = torch.zeros(masks.size(), device=device)
for j in range(len(masks)):
for k in range(len(masks[j])):
masks[j, k, int(bbox_ints[j][0][k][0]): int(bbox_ints[j][0][k][0] + bbox_ints[j][0][k][2]), int(bbox_ints[j][0][k][1]): int(bbox_ints[j][0][k][1] + bbox_ints[j][0][k][3])] = 1
Expand Down

0 comments on commit 7646170

Please sign in to comment.