Skip to content

Commit

Permalink
fix for metric mcov in trainig
Browse files Browse the repository at this point in the history
  • Loading branch information
RRobert92 committed Apr 28, 2023
1 parent 568c783 commit e0e7e6f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
9 changes: 6 additions & 3 deletions tardis_pytorch/dist_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,8 @@ def _validate(self):
input0_1 = self.Graph0_25.patch_to_segment(
graph=edge_cpu, coord=coord, idx=out_cpu, prune=5, sort=False
)
mcov0_25.append(mcov(input0_1, target))
mcov_m, _ = mcov(input0_1, target)
mcov0_25.append(mcov_m)
except:
mcov0_25.append(0.0)

Expand All @@ -335,7 +336,8 @@ def _validate(self):
input0_5 = self.Graph0_5.patch_to_segment(
graph=edge_cpu, coord=coord, idx=out_cpu, prune=5, sort=False
)
mcov0_5.append(mcov(input0_5, target))
mcov_m, _ = mcov(input0_5, target)
mcov0_5.append(mcov_m)
except:
mcov0_5.append(0.0)

Expand All @@ -344,7 +346,8 @@ def _validate(self):
input0_9 = self.Graph0_9.patch_to_segment(
graph=edge_cpu, coord=coord, idx=out_cpu, prune=5, sort=False
)
mcov0_9.append(mcov(input0_9, target))
mcov_m = mcov(input0_9, target)
mcov0_9.append(mcov_m)
except:
mcov0_9.append(0.0)

Expand Down
2 changes: 0 additions & 2 deletions tardis_pytorch/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,12 @@ def mcov(

# Get GT instances, compute IoU for best mach between GT and input
for j in unique_target:
print(j)
g = targets[targets[:, 0] == j, 1:] # Pick GT instance
w_g = g.shape[0] / targets.shape[0] # ratio of instance to whole PC
iou = []

# Select max IoU (the best mach)
for i in unique_input:
print(i)
p = input[input[:, 0] == i, 1:] # Pick input instance

# Intersection of coordinates between GT and input instances
Expand Down

0 comments on commit e0e7e6f

Please sign in to comment.