In [None]:
# Notebook configuration (IPython line magics).
# Load the autoreload extension and select mode 2, which re-imports all
# modules before each cell execution so edits to the project package are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Switch off Jedi in the IPython completer.
%config Completer.use_jedi = False

In [None]:
from pathlib import Path

from wikipedia_cleanup.data_filter import KeepAttributesDataFilter, generate_default_filters
from wikipedia_cleanup.predict import TrainAndPredictFramework
from wikipedia_cleanup.predictor import ZeroPredictor, OnePredictor, MeanPredictor, RandomPredictor
from wikipedia_cleanup.property_correlation import PropertyCorrelationPredictor
from wikipedia_cleanup.random_forest import RandomForestPredictor
from datetime import datetime

import pandas as pd
from tqdm import tqdm

In [None]:
# Predictor under evaluation: a random forest constructed with use_cache=False.
model = RandomForestPredictor(use_cache=False)

# Wrap the predictor in the train/predict framework. Rows are grouped per
# (infobox_key, property_name) pair and test_start_date is 1 September 2017.
framework = TrainAndPredictFramework(
    model,
    group_key=["infobox_key", "property_name"],
    test_start_date=datetime(2017, 9, 1),
)

In [None]:
import os
import glob

# Collect the path of every pickled feature frame in the feature directory.
# glob.glob() yields matches in arbitrary, filesystem-dependent order, so the
# list is sorted to make the row order of the concatenated frame reproducible
# from run to run.
# NOTE: despite its name, `csv_files` holds *.pickle paths; the name is kept
# because later cells read it.
csv_files = sorted(
    glob.glob(os.path.join("../../custom-format-default-filtered-features/", "*.pickle"))
)

In [None]:
# Read every pickled frame listed in `csv_files` (with a progress bar) and
# stack them, in list order, into one DataFrame.
# NOTE(review): unpickling can execute arbitrary code, so only load files
# from a trusted source.
data_df = pd.concat([pd.read_pickle(pickle_path) for pickle_path in tqdm(csv_files)])

In [None]:
# Number of non-null `timestamp` entries for every (infobox_key, property_name)
# pair. sort=False keeps the groups in order of first appearance in `data_df`.
# The column is selected *before* aggregating so that only `timestamp` is
# counted instead of counting every column and discarding all but one.
# (`datetime` is already imported in the import cell, so the unused re-import
# that used to live here was dropped.)
df_counts = (
    data_df
    .groupby(["infobox_key", "property_name"], sort=False)["timestamp"]
    .count()
)

In [None]:
# Keep only (infobox_key, property_name) groups that have at least this many
# non-null timestamps.
MIN_ROWS_PER_KEY = 250

# For every row, the non-null timestamp count of the group it belongs to.
# This replaces the former positional slicing by a running offset, which was
# only correct when each group's rows were contiguous, in first-appearance
# order, and free of null timestamps; it also relied on Series.iteritems(),
# which no longer exists in pandas 2.x.
rows_in_group = (
    data_df
    .groupby(["infobox_key", "property_name"], sort=False)["timestamp"]
    .transform("count")
)

# The mask is applied positionally (to_numpy) so that duplicate index labels
# left over from concatenating several files cannot misalign it. The explicit
# copy gives later cells an independent frame to add columns to.
data_df_small = data_df[(rows_in_group >= MIN_ROWS_PER_KEY).to_numpy()].copy()
data_df_small.shape

In [None]:
# Number of distinct (infobox_key, property_name) keys left after filtering,
# displayed as the shape of the per-key timestamp counts.
counts_per_key = data_df_small.groupby(["infobox_key", "property_name"], sort=False).count()
counts_per_key["timestamp"].shape

In [None]:
# Hand the filtered frame to the framework through its `data` attribute.
# NOTE(review): a plain attribute assignment does not copy, so columns added to
# framework.data in later cells would also appear on data_df_small — confirm
# `data` is not a copying property.
framework.data = data_df_small

In [None]:
# Derive `value_valid_from` from `timestamp`: parse the values as datetimes,
# then strip the timezone information so the column is tz-naive.
parsed_timestamps = pd.to_datetime(framework.data["timestamp"])
framework.data["value_valid_from"] = parsed_timestamps.dt.tz_localize(None)

In [None]:
group_key = ['infobox_key', 'property_name']

# Build one tuple per row from the framework's grouping columns and store it
# in a `key` column. The columns come from framework.group_key; the
# `group_key` list above is not read here.
key_columns = [framework.data[column] for column in framework.group_key]
framework.data["key"] = list(zip(*key_columns))

In [None]:
# Fit the framework's predictor.
# NOTE(review): presumably trains on framework.data as prepared in the cells
# above — confirm against TrainAndPredictFramework.fit_model.
framework.fit_model()

In [None]:
# print(framework.test_model(randomize=False, predict_subset=1))

In [None]:
# framework.generate_plots()

In [None]:
import numpy as np

# Thresholds to sweep: 0.0, 0.1, ..., 1.0.
thresholds = np.linspace(0, 1, 11)
print(thresholds)

# One record per prediction granularity; every list receives one entry per
# threshold, in sweep order.
stats = {
    granularity: {"prec": [], "rec": [], "y_hat_sum": []}
    for granularity in ("daily", "weekly", "monthly", "yearly")
}

for threshold in thresholds:
    # Re-evaluate the already fitted model with the new threshold.
    framework.predictor.threshold = threshold
    framework.test_model(randomize=False, predict_subset=1, save_results=False)

    # NOTE(review): the pairing assumes framework.pred_stats is ordered
    # daily, weekly, monthly, yearly — confirm against the framework.
    for granularity_stats, pred_stat in zip(stats.values(), framework.pred_stats):
        # NOTE(review): [0][1] / [1][1] are presumably precision and recall
        # of the positive class — confirm the layout of "prec_recall".
        precision_recall = pred_stat["prec_recall"]
        granularity_stats["prec"].append(precision_recall[0][1])
        granularity_stats["rec"].append(precision_recall[1][1])
        # Sum over all predicted values for this threshold.
        granularity_stats["y_hat_sum"].append(np.array(pred_stat["y_hat"]).sum())

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

key = "daily"
title = f"{key} precision and recall"

# Precision and recall per threshold for the selected granularity; the
# prediction totals are left out of this line plot.
prec_rec_df = pd.DataFrame(stats[key], index=thresholds).drop(columns=["y_hat_sum"])

ax = sns.lineplot(data=prec_rec_df)
ax.set_title(title)
ax.set_xlabel("threshold")
# Save the figure next to the notebook, named after the title.
plt.savefig(f"{title.replace(' ', '_')}.png")

In [None]:
# Bar chart of the summed predictions per threshold for the granularity
# currently stored in `key` (set by the preceding cell).
title = f"{key} absolute predictions"
rounded_thresholds = np.around(thresholds, 2)

ax = sns.barplot(x=rounded_thresholds, y=stats[key]["y_hat_sum"])
ax.set_title(title)
ax.set_xlabel("threshold")
plt.savefig(f"{title.replace(' ', '_')}.png")

In [None]:
import seaborn as sns

key = "weekly"
title = f"{key} precision and recall"

# Precision and recall per threshold for the selected granularity; the
# prediction totals are left out of this line plot.
prec_rec_df = pd.DataFrame(stats[key], index=thresholds).drop(columns=["y_hat_sum"])

ax = sns.lineplot(data=prec_rec_df)
ax.set_title(title)
ax.set_xlabel("threshold")
# Save the figure next to the notebook, named after the title.
plt.savefig(f"{title.replace(' ', '_')}.png")

In [None]:
# Bar chart of the summed predictions per threshold for the granularity
# currently stored in `key` (set by the preceding cell).
title = f"{key} absolute predictions"
rounded_thresholds = np.around(thresholds, 2)

ax = sns.barplot(x=rounded_thresholds, y=stats[key]["y_hat_sum"])
ax.set_title(title)
ax.set_xlabel("threshold")
plt.savefig(f"{title.replace(' ', '_')}.png")

In [None]:
import seaborn as sns

key = "monthly"
title = f"{key} precision and recall"

# Precision and recall per threshold for the selected granularity; the
# prediction totals are left out of this line plot.
prec_rec_df = pd.DataFrame(stats[key], index=thresholds).drop(columns=["y_hat_sum"])

ax = sns.lineplot(data=prec_rec_df)
ax.set_title(title)
ax.set_xlabel("threshold")
# Save the figure next to the notebook, named after the title.
plt.savefig(f"{title.replace(' ', '_')}.png")

In [None]:
# Bar chart of the summed predictions per threshold for the granularity
# currently stored in `key` (set by the preceding cell).
title = f"{key} absolute predictions"
rounded_thresholds = np.around(thresholds, 2)

ax = sns.barplot(x=rounded_thresholds, y=stats[key]["y_hat_sum"])
ax.set_title(title)
ax.set_xlabel("threshold")
plt.savefig(f"{title.replace(' ', '_')}.png")

In [None]:
import seaborn as sns

key = "yearly"
title = f"{key} precision and recall"

# Precision and recall per threshold for the selected granularity; the
# prediction totals are left out of this line plot.
prec_rec_df = pd.DataFrame(stats[key], index=thresholds).drop(columns=["y_hat_sum"])

ax = sns.lineplot(data=prec_rec_df)
ax.set_title(title)
ax.set_xlabel("threshold")
# Save the figure next to the notebook, named after the title.
plt.savefig(f"{title.replace(' ', '_')}.png")

In [None]:
# Bar chart of the summed predictions per threshold for the granularity
# currently stored in `key` (set by the preceding cell).
title = f"{key} absolute predictions"
rounded_thresholds = np.around(thresholds, 2)

ax = sns.barplot(x=rounded_thresholds, y=stats[key]["y_hat_sum"])
ax.set_title(title)
ax.set_xlabel("threshold")
plt.savefig(f"{title.replace(' ', '_')}.png")

In [None]:
# Final evaluation run with a fixed predictor configuration: threshold 1.0
# and min_number_changes 300. The result of test_model is printed.
predictor = framework.predictor
predictor.threshold = 1.0
predictor.min_number_changes = 300

final_result = framework.test_model(randomize=False, predict_subset=1, save_results=False)
print(final_result)