diff --git a/deeprank/learn/NeuralNet.py b/deeprank/learn/NeuralNet.py
index 64440fca..6f080635 100644
--- a/deeprank/learn/NeuralNet.py
+++ b/deeprank/learn/NeuralNet.py
@@ -993,7 +993,7 @@ def _plot_boxplot_class(self,figname):
for pts,t in zip(out,tar):
r = F.softmax(torch.FloatTensor(pts), dim=0).data.numpy()
data[t].append(r[1])
- confusion[t][r[1]>0.5] += 1
+ confusion[t][bool(r[1]>0.5)] += 1
#print(" {:5s}: {:s}".format(l,str(confusion)))
diff --git a/deeprank/learn/rankingMetrics.py b/deeprank/learn/rankingMetrics.py
index a60bf891..d783a76a 100644
--- a/deeprank/learn/rankingMetrics.py
+++ b/deeprank/learn/rankingMetrics.py
@@ -16,8 +16,8 @@ def hitrate(rs):
Example:
- >>> r = [0,1,1]
- >>> hit_rate(r,nr)
+ >>> rs = [0,1,1]
+ >>> hitrate(r)
Attributes:
@@ -27,14 +27,34 @@ def hitrate(rs):
Returns:
hirate (array): [recall@1,recall@2,...]
"""
- nr = np.max((1,np.sum(rs)))
+ nr = np.max((1, np.sum(rs)))
return np.cumsum(rs) / nr
+def success(rs):
+ """Success for positions ≤ k.
+
+ Example:
+ >>> rs = [0, 0, 1, 0, 1, 0]
+ >>> success(rs)
+ [0, 0, 1, 1, 1, 1]
+
+ Args:
+ rs (array): binary relevance array
+
+ Returns:
+ success (array): [success@≤1, success@≤2,...]
+ """
+ success = np.cumsum(rs) > 0
+
+ return success.astype(np.int)
+
+
def avprec(rs):
- return [average_precision(rs[:i]) for i in range(1,len(rs))]
+ return [average_precision(rs[:i]) for i in range(1, len(rs))]
+
-def recall(rs,nr):
+def recall(rs, nr):
"""recall rate
First element is rank 1, Relevance is binray
@@ -56,6 +76,7 @@ def recall(rs,nr):
return np.sum(rs)/nr
+
def mean_reciprocal_rank(rs):
"""Score is reciprocal of the rank of the first relevant item
@@ -272,4 +293,4 @@ def ndcg_at_k(r, k, method=0):
dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
if not dcg_max:
return 0.
- return dcg_at_k(r, k, method) / dcg_max
\ No newline at end of file
+ return dcg_at_k(r, k, method) / dcg_max
diff --git a/deeprank/utils/cal_hitrate_successrate.py b/deeprank/utils/cal_hitrate_successrate.py
new file mode 100644
index 00000000..74bd9683
--- /dev/null
+++ b/deeprank/utils/cal_hitrate_successrate.py
@@ -0,0 +1,169 @@
+import numpy as np
+import pandas as pd
+from deeprank.learn import rankingMetrics
+
+
+def evaluate(data):
+ '''
+ Calculate success rate and hit rate.
+
+
+ data: a data frame.
+
+ label caseID modelID target DR HS
+ Test 1AVX 1AVX_ranair-it0_5286 0 0.503823 6.980802
+ Test 1AVX 1AVX_ti5-itw_354w 1 0.502845 -95.158100
+ Test 1AVX 1AVX_ranair-it0_6223 0 0.511688 -11.961460
+
+