We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2f22c74 commit 0f7e9e4Copy full SHA for 0f7e9e4
scripts/train.py
@@ -96,9 +96,15 @@ def evaluate(onlyImproved=False):
96
97
dists[i] = min(dist, dists[i]) # track the best distance
98
# filter the results by the distance, to ignore the outliers
99
- if dists[i] < 0.1:
+ maxValue = 0.1
100
+ if dists[i] < maxValue:
101
totalLoss.append(loss)
102
totalDist.append(dist)
103
+ else:
104
+ # prevent the big "jumps" in the loss and distance when the model is becoming better
105
+ # assuming that maxValue is bigger than the corresponding loss
106
+ totalLoss.append(maxValue)
107
+ totalDist.append(maxValue)
108
continue
109
if not onlyImproved:
110
print('Mean loss: %.5f | Mean distance: %.5f' % (
0 commit comments