Skip to content

Commit

Permalink
Fix MGM3
Browse files Browse the repository at this point in the history
  • Loading branch information
ziao-guo committed Oct 15, 2023
1 parent 673ff87 commit d6674d4
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/dataset/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,25 @@ def get_multi_cluster(self, idx):
if self.cls is None or self.cls == 'none':
cls_iterator = random.choice(self.classes)
else:
cls_iterator = self.cls
if (self.test and cfg.EVAL.RAND_CLASS) or (not self.test and cfg.TRAIN.RAND_CLASS):
cls_iterator = random.sample(self.bm.classes, cfg.PROBLEM.NUM_CLUSTERS)
else:
cls_iterator = self.cls
for cls in cls_iterator:
dicts.append(self.get_multi(idx, cls))
ret_dict = {}
for key in dicts[0]:
ret_dict[key] = []
for dic in dicts:
ret_dict[key] += dic[key]
if key != 'gt_perm_mat':
ret_dict[key] = []
for dic in dicts:
ret_dict[key] += dic[key]
else:
ret_dict[key] = {}
for i, dic in enumerate(dicts):
for (idx1, idx2) in dic[key].keys():
new_idx1 = idx1 + i * cfg.PROBLEM.NUM_GRAPHS
new_idx2 = idx2 + i * cfg.PROBLEM.NUM_GRAPHS
ret_dict[key][(new_idx1, new_idx2)] = dic[key][(idx1, idx2)]
return ret_dict


Expand Down

0 comments on commit d6674d4

Please sign in to comment.