Skip to content

Commit

Permalink
skip incomplete last batch to fix issue
Browse files Browse the repository at this point in the history
  • Loading branch information
quajak committed Apr 19, 2024
1 parent a61cd1d commit e40a4f1
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/scripts/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def eval(cfg: DictConfig) -> None:
device = "cuda" if torch.cuda.is_available() else "cpu"

data_loader = instantiate(cfg.data_loader, model=cfg.model, game=cfg.game, num_obj=cfg.num_objects, val_pct=0, test_pct=0.3, max_data=10000)
data_loader = instantiate(cfg.data_loader, model=cfg.model, game=cfg.game, num_obj=cfg.num_objects, val_pct=0, test_pct=0.3)

t = cfg.t
feature_extractor_state = torch.load(f"models/trained/{cfg.game}/{t}_feat_extract.pth", map_location=device)
Expand All @@ -36,7 +36,9 @@ def eval(cfg: DictConfig) -> None:
ninetieth = []

for i in tqdm(range(data_loader.num_train + data_loader.num_val, len(data_loader.frames), cfg.batch_size)):
images, bboxes, masks, actions = data_loader.sample_idxes(cfg.time_steps, device, range(i, min(i + cfg.batch_size, len(data_loader.frames)-cfg.time_steps)))
if i + cfg.batch_size > len(data_loader.frames) - cfg.time_steps:
break
images, bboxes, masks, actions = data_loader.sample_idxes(cfg.time_steps, device, range(i, i + cfg.batch_size))
if cfg.ground_truth_masks:
masks = get_ground_truth_masks(bboxes, masks.shape, device=device)

Expand Down

0 comments on commit e40a4f1

Please sign in to comment.