In [None]:
import pandas as pd
import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt

import sys
sys.path.append('../')

from src.d04_modeling.knapsack_classifier import KnapsackClassifier
from src.d04_modeling.naive_classifier import NaiveClassifier
from src.d04_modeling.ctip_classifier import CtipClassifier

In [None]:
# Build the four classifiers that the rest of the notebook compares.
# NOTE: `frl_key` is user-specific -- change it to match your own setup.
frl_key = 'tk12'

# Keyword arguments shared by every classifier.
shared_kwargs = dict(positive_group='nFocal', key=frl_key)

model1 = NaiveClassifier(rate=False, **shared_kwargs)
model2 = NaiveClassifier(rate=True, **shared_kwargs)
model3 = CtipClassifier(**shared_kwargs)
# After the knapsack model has been run once, load=True reuses that run
# instead of repeating the whole optimization.
model4 = KnapsackClassifier(load=True, **shared_kwargs)

In [None]:
# ROC data for each classifier, in model order; the plotting cell reads the
# 'fpr' and 'tpr' entries of each result.
results1, results2, results3, results4 = [
    clf.get_roc() for clf in (model1, model2, model3, model4)
]

In [None]:
# Draw every classifier's ROC curve (TPR against FPR) on a single set of axes
# and save the figure as 'results.png'.
# `plt` is provided by the import cell at the top of the notebook, so it is
# not re-imported here.
plt.rcParams['font.size'] = 16

fig, ax = plt.subplots(figsize=(10, 10))

lw = 4
# One entry per curve, in plotting order: (ROC result, legend label, extra styling).
roc_curves = [
    (results1, 'Naive', {}),
    (results2, 'Naive(rate)', {}),
    # CTIP1 is drawn with point markers on top of the line.
    (results3, 'CTIP1', {'marker': '.', 'markersize': 12}),
    (results4, 'Knapsack', {}),
]
for roc, label, extra_style in roc_curves:
    ax.plot(roc['fpr'], roc['tpr'], label=label, linewidth=lw, **extra_style)

ax.set_xlabel('FPR')
ax.set_ylabel('TPR')
ax.legend()
plt.tight_layout()
# Write the figure to disk before displaying it.
plt.savefig('results.png')
plt.show()

In [None]:
# Value handed to every classifier as `params` below; later cells reuse it.
fpr = 0.08

# Print each classifier's confusion matrix for that value, in model order.
for clf in (model1, model2, model3, model4):
    print(clf.get_confusion_matrix(params=fpr))

In [None]:
# Compare the solution sets of the Knapsack model (model4) and the rate-based
# Naive model (model2) at the same false-positive rate. `fpr` is the value
# defined in the confusion-matrix cell, so both comparisons stay in sync.
knapsack_solution = model4.get_solution_set(fpr=fpr)
naive_solution = model2.get_solution_set(fpr=fpr)

# A notebook cell only echoes its final bare expression, so each difference is
# printed explicitly; otherwise the first heading would have no content under it.
print("Groups in the Knapsack solution and not in the Naive solution:")
print(knapsack_solution.difference(naive_solution))
print("Groups in the Naive solution and not in the Knapsack solution:")
print(naive_solution.difference(knapsack_solution))

In [None]:
# Call the CTIP classifier's plot_map with the same `fpr` value that was used
# for the confusion matrices.
# NOTE(review): what the map displays is defined in
# src.d04_modeling.ctip_classifier, not in this notebook -- confirm there.
model3.plot_map(params=fpr)