# Diversity metric #
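This notebook measures the diversity of predicted article sets, either within one solution (every row against every other row) or between two solutions (the same customer in both).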

In [19]:
import pandas as pd
# Load the two prediction files that are compared throughout this notebook.
# Later cells read a "prediction" column from both frames and join them on "customer_id".
baseline = pd.read_csv(filepath_or_buffer="predictions_baseline")
R1 = pd.read_csv(filepath_or_buffer="../Data/sub1.csv.gz")

def overlap(pred1, pred2):
    """Return the Jaccard similarity |pred1 ∩ pred2| / |pred1 ∪ pred2| of two sets.

    Parameters
    ----------
    pred1 : set
        First prediction set.
    pred2 : set
        Second prediction set.

    Returns
    -------
    float
        A value in [0, 1]; 1.0 means the sets are identical, 0.0 means they
        are disjoint. Two empty sets are treated as identical (1.0).
    """
    union = len(pred1.union(pred2))
    if union == 0:
        # Both sets are empty: the ratio is 0/0. Treat them as identical
        # instead of raising ZeroDivisionError in the middle of a long run.
        return 1.0
    intersection = len(pred1.intersection(pred2))
    return intersection / union

In [20]:
# Draw a small random subset of each solution for the all-pairs comparison.
# random_state pins the draws so the reported numbers are reproducible on re-run;
# the two seeds differ so the draws stay independent of each other.
sample_base = baseline.sample(n=1000, random_state=42).reset_index()
sample_R1 = R1.sample(n=1000, random_state=43).reset_index()

In [21]:
def diversity_cross(predictions1, predictions2):
    """Average Jaccard distance (1 - overlap) over all row pairs of two frames.

    Parameters
    ----------
    predictions1, predictions2 : pandas.DataFrame
        Frames with a "prediction" column holding whitespace-separated
        integer article ids.

    Returns
    -------
    float
        Mean of ``1 - overlap(row1, row2)`` over every compared pair.

    Raises
    ------
    AssertionError
        If there is no pair to compare (e.g. an empty frame, or a single-row
        frame compared with itself).

    Notes
    -----
    A row is only excluded from being compared with itself when both
    arguments are the very same DataFrame object. For two different frames
    every pair is counted; equal index labels do not mean equal rows there.
    """
    def _article_sets(predictions):
        # Parse every prediction string once instead of once per pair.
        return [
            {int(article) for article in text.split()}
            for text in predictions["prediction"]
        ]

    same_frame = predictions1 is predictions2
    sets1 = _article_sets(predictions1)
    sets2 = sets1 if same_frame else _article_sets(predictions2)

    diversity = 0
    comparisons = 0
    for i, pred1 in enumerate(sets1):
        for j, pred2 in enumerate(sets2):
            if same_frame and i == j:
                continue  # Avoid comparing a row to itself
            # overlap indicates similarity, so complement it
            diversity += 1 - overlap(pred1, pred2)
            comparisons += 1
    assert comparisons > 0, "no row pairs to compare"
    return diversity / comparisons


# Diversity inside each sample: every row is compared with every other row of the same frame.
print(f"Within base {diversity_cross(sample_base,sample_base)}")
# NOTE(review): the label below says "R3" but the frame passed in is sample_R1 — confirm which name is intended.
print(f"Within R3 {diversity_cross(sample_R1,sample_R1)}")
# Comparisons between the two samples, currently disabled:
# print(f"Across {diversity_cross(sample_base,sample_R1)}")
# print(f"Across {diversity_cross(sample_R1, sample_base)}")

Within base 0.20636648257926843
Within R3 0.3890300100815666


In [31]:
# The matched comparison looks at one pair per customer (no cross product),
# so the sample can be much larger than for the all-pairs comparison.
MATCH_SAMPLE_SIZE = 300000

# random_state pins the draw so the preview and the metric are reproducible on re-run.
sample_base_match = baseline.sample(n=MATCH_SAMPLE_SIZE, random_state=42).reset_index()

# Join each sampled baseline row with the R1 row of the same customer; the
# clashing column names get the suffixes "_base" and "_R1".
combined = sample_base_match.merge(R1, on="customer_id", suffixes=("_base", "_R1"))
print(len(combined))
# Every sampled customer must match exactly one R1 row; otherwise the join
# dropped or duplicated rows.
assert len(combined) == MATCH_SAMPLE_SIZE, (
    f"expected {MATCH_SAMPLE_SIZE} matched rows, got {len(combined)}"
)

300000


In [32]:
# Preview the first rows of the merged frame to check the suffixed prediction columns.
combined.head()

Unnamed: 0,index,customer_id,prediction_base,prediction_R1
0,1276835,ee4b1d5147c5aa33157c80207cde43875c2cd07386ae28...,0738567003 0798773001 0924243001 0924243002 09...,0738567003 0798773001 0918522001 0915529003 09...
1,471097,57fe4fe2181fc1ea2506595fc4ead6a025cdd898f51887...,0842004003 0924243001 0924243002 0918522001 09...,0842004003 0918522001 0915529003 0915529005 09...
2,1075347,c8ad587fe30a1cca6f6ea07dee541c2e5526172157ff25...,0924243001 0924243002 0918522001 0923758001 08...,0924243001 0924243002 0918522001 0923758001 08...
3,181544,21cf0d41af3a46280c7b1e42fbb52490758cd9b9e2e3d7...,0924243001 0924243002 0918522001 0923758001 08...,0924243001 0924243002 0918522001 0923758001 08...
4,312421,3a5fd26e6063efaf3725b9824bf97c4abb94171048e327...,0893691001 0887714003 0896409001 0859101009 09...,0893691001 0887714003 0896409001 0859101009 09...


In [33]:
def diversity_match(combined):
    """Average Jaccard distance (1 - overlap) between the two predictions of each row.

    Parameters
    ----------
    combined : pandas.DataFrame
        Frame with "prediction_base" and "prediction_R1" columns, each holding
        whitespace-separated integer article ids.

    Returns
    -------
    float
        Mean of ``1 - overlap(base, R1)`` over all rows.

    Raises
    ------
    AssertionError
        If ``combined`` has no rows.
    """
    diversity = 0
    comparisons = 0

    # Iterate over the two columns directly; iterrows() would build a Series
    # for every row, which is needlessly slow on a large frame.
    for base_text, r1_text in zip(combined["prediction_base"], combined["prediction_R1"]):
        pred_base = {int(article) for article in base_text.split()}
        pred_r1 = {int(article) for article in r1_text.split()}
        # overlap indicates similarity, so complement it
        diversity += 1 - overlap(pred_base, pred_r1)
        comparisons += 1

    assert comparisons > 0, "no rows to compare"
    return diversity / comparisons

# Mean per-customer Jaccard distance between the baseline and R1 predictions.
print(f"Across {diversity_match(combined)}")

Across 0.1616508812605254
