In [None]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

In [None]:
from pathlib import Path

from wikipedia_cleanup.data_filter import KeepAttributesDataFilter, generate_default_filters
from wikipedia_cleanup.predict import TrainAndPredictFramework
from wikipedia_cleanup.predictor import ZeroPredictor, OnePredictor, MeanPredictor, RandomPredictor
from wikipedia_cleanup.property_correlation import PropertyCorrelationPredictor
from wikipedia_cleanup.random_forest import RandomForestPredictor
from datetime import datetime

import pandas as pd
from tqdm import tqdm

In [None]:
# Random-forest predictor, created with its cache switched off.
model = RandomForestPredictor(use_cache=False)
# The framework wraps the predictor; rows are keyed by (infobox_key, property_name)
# and the test start date is set to 2017-09-01.
framework = TrainAndPredictFramework(
    model,
    group_key=['infobox_key', 'property_name'],
    test_start_date=datetime(2017, 9, 1),
)

In [None]:
import os
import glob
# Collect the pickled feature files.
# glob returns paths in arbitrary, filesystem-dependent order; sorting makes the
# row order of the concatenated frame reproducible across runs and machines.
# NOTE: despite the name these are pickle files, not CSVs; the name is kept
# because later cells refer to it.
csv_files = sorted(glob.glob(os.path.join("../../custom-format-default-filtered-features/", "*.pickle")))

In [None]:
# Read every feature file into its own frame, then stack them into one table.
# NOTE(review): read_pickle unpickles the file contents - only load files from a
# trusted source.
frames = [pd.read_pickle(path) for path in tqdm(csv_files)]
data_df = pd.concat(frames)

In [None]:
from datetime import datetime
# Number of non-null timestamps per (infobox_key, property_name) group, listed in
# order of first appearance (sort=False).
df_counts = data_df.groupby(['infobox_key', 'property_name'], sort=False)["timestamp"].count()

In [None]:
# Keep only the rows whose (infobox_key, property_name) group has at least 200
# non-null timestamps.
#
# The per-row group count comes from transform(), so the filter is correct even
# when the rows of a group are not stored contiguously or in groupby order, and
# when timestamps or keys are missing. Positional slicing by running offsets
# would silently select the wrong rows in those cases. It also avoids
# Series.iteritems(), which was removed in pandas 2.0, and yields an empty frame
# instead of an exception when no group reaches the limit.
rows_per_key = data_df.groupby(['infobox_key', 'property_name'], sort=False)["timestamp"].transform("count")
# to_numpy() applies the mask positionally (no index alignment, so duplicate index
# labels are harmless); copy() makes the subset an independent frame so that
# adding columns to it later does not raise SettingWithCopyWarning.
data_df_small = data_df[(rows_per_key >= 200).to_numpy()].copy()
data_df_small.shape

In [None]:
# number of (infobox_key, property_name) keys in the filtered frame
key_counts = data_df_small.groupby(['infobox_key', 'property_name'], sort=False)["timestamp"].count()
key_counts.shape

In [None]:
# Use the filtered frame as the framework's dataset.
# NOTE(review): unless the `data` attribute copies on assignment, framework.data
# and data_df_small are the same object, so columns added through framework.data
# in later cells also appear on data_df_small.
framework.data = data_df_small

In [None]:
# Parse the timestamp column and store it, without timezone information, as
# "value_valid_from".
parsed_timestamps = pd.to_datetime(framework.data["timestamp"])
framework.data["value_valid_from"] = parsed_timestamps.dt.tz_localize(None)

In [None]:
group_key = ['infobox_key', 'property_name']
# Build one tuple per row from the framework's group-key columns and store it in
# the "key" column. The columns come from framework.group_key, not from the
# module-level `group_key` list above.
key_columns = [framework.data[column] for column in framework.group_key]
framework.data["key"] = list(zip(*key_columns))

In [None]:
# Train the model through the framework.
# NOTE(review): presumably fits the predictor on the rows of framework.data that
# precede test_start_date - confirm against TrainAndPredictFramework.fit_model.
framework.fit_model()

In [None]:
# print(framework.test_model(randomize=False, predict_subset=1))

In [None]:
# framework.generate_plots()

In [None]:
import numpy as np
# Grid search over the predictor's `threshold` and `min_number_changes` settings.
# Every combination is evaluated with framework.test_model(), and one result row
# per timeframe is appended to `stats`.
timeframes = ["daily", "weekly", "monthly", "yearly"]
thresholds = np.linspace(0, 1, 11)
min_changes = np.linspace(200, 300, 11)
print(thresholds)
print(min_changes)

stats = []
for min_number_changes in min_changes:
    framework.predictor.min_number_changes = min_number_changes
    for threshold in thresholds:
        framework.predictor.threshold = threshold
        framework.test_model(randomize=False, predict_subset=1, save_results=False)
        # framework.pred_stats is paired positionally with `timeframes`.
        # NOTE(review): presumably it holds one entry per timeframe in exactly that
        # order, and ["prec_recall"][0][1] / [1][1] are precision / recall - confirm.
        stats.extend(
            [
                timeframe,
                min_number_changes,
                threshold,
                timeframe_stats["prec_recall"][0][1],
                timeframe_stats["prec_recall"][1][1],
                np.array(timeframe_stats["y_hat"]).sum(),
            ]
            for timeframe, timeframe_stats in zip(timeframes, framework.pred_stats)
        )

In [None]:
# Turn the collected grid-search rows into a table, write it to disk and show it.
columns = [
    "timeframe",
    "min_changes",
    "threshold",
    "precision",
    "recall",
    "number_pred_changes",
]
stats_df = pd.DataFrame(data=stats, columns=columns)
stats_df.to_csv("gridsearch.csv")
stats_df

In [None]:
# Preview of the grid-search results.
# `data` is first assigned in a later cell, so displaying it here raised NameError
# on a fresh Restart & Run All; show the head of the results table instead.
stats_df.head()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
# Precision and recall against the threshold for the daily timeframe.
timeframe="daily"
data=stats_df[stats_df["timeframe"]==timeframe]
title = timeframe+" precision and recall"
# `data` contains one threshold sweep per min_changes value. Plotting the columns
# in a single call would join the end of each sweep to the start of the next with
# a spurious line, so every sweep is drawn as its own curve. Only the first pair
# of curves gets a legend entry.
for sweep_idx, (_, sweep) in enumerate(data.groupby("min_changes", sort=False)):
    is_first = sweep_idx == 0
    plt.plot(sweep["threshold"], sweep["precision"], color="C0", alpha=0.6,
             label="precision" if is_first else "_nolegend_")
    plt.plot(sweep["threshold"], sweep["recall"], color="C1", alpha=0.6,
             label="recall" if is_first else "_nolegend_")
plt.xlabel("threshold")
# Put the title on the figure as well, not only in the file name.
plt.title(title)
plt.legend()
plt.savefig(title.replace(" ","_")+'.png')

In [None]:
# Number of predicted changes per threshold for the timeframe selected in the
# previous cell.
# `data` has one row per (min_changes, threshold) pair, so the x values are taken
# from the rows themselves; the bare `thresholds` grid is shorter than the y
# column whenever more than one min_changes value was searched. Each bar shows
# seaborn's default aggregate (the mean) over the min_changes values.
# sns.barplot returns a matplotlib Axes, hence the name `ax`.
ax=sns.barplot(x=data["threshold"].round(2),y=data["number_pred_changes"])
title = timeframe+" absolute predictions"
ax.set_title(title)
ax.set_xlabel("threshold")
plt.savefig(title.replace(" ","_")+'.png')

In [None]:
import seaborn as sns
# Precision and recall against the threshold for the weekly timeframe.
timeframe="weekly"
data=stats_df[stats_df["timeframe"]==timeframe]
title = timeframe+" precision and recall"
# `data` contains one threshold sweep per min_changes value. Plotting the columns
# in a single call would join the end of each sweep to the start of the next with
# a spurious line, so every sweep is drawn as its own curve. Only the first pair
# of curves gets a legend entry.
for sweep_idx, (_, sweep) in enumerate(data.groupby("min_changes", sort=False)):
    is_first = sweep_idx == 0
    plt.plot(sweep["threshold"], sweep["precision"], color="C0", alpha=0.6,
             label="precision" if is_first else "_nolegend_")
    plt.plot(sweep["threshold"], sweep["recall"], color="C1", alpha=0.6,
             label="recall" if is_first else "_nolegend_")
plt.xlabel("threshold")
# Put the title on the figure as well, not only in the file name.
plt.title(title)
plt.legend()
plt.savefig(title.replace(" ","_")+'.png')

In [None]:
# Number of predicted changes per threshold for the timeframe selected in the
# previous cell.
# `data` has one row per (min_changes, threshold) pair, so the x values are taken
# from the rows themselves; the bare `thresholds` grid is shorter than the y
# column whenever more than one min_changes value was searched. Each bar shows
# seaborn's default aggregate (the mean) over the min_changes values.
# sns.barplot returns a matplotlib Axes, hence the name `ax`.
ax=sns.barplot(x=data["threshold"].round(2),y=data["number_pred_changes"])
title = timeframe+" absolute predictions"
ax.set_title(title)
ax.set_xlabel("threshold")
plt.savefig(title.replace(" ","_")+'.png')

In [None]:
import seaborn as sns
# Precision and recall against the threshold for the monthly timeframe.
timeframe="monthly"

data=stats_df[stats_df["timeframe"]==timeframe]
title = timeframe+" precision and recall"
# `data` contains one threshold sweep per min_changes value. Plotting the columns
# in a single call would join the end of each sweep to the start of the next with
# a spurious line, so every sweep is drawn as its own curve. Only the first pair
# of curves gets a legend entry.
for sweep_idx, (_, sweep) in enumerate(data.groupby("min_changes", sort=False)):
    is_first = sweep_idx == 0
    plt.plot(sweep["threshold"], sweep["precision"], color="C0", alpha=0.6,
             label="precision" if is_first else "_nolegend_")
    plt.plot(sweep["threshold"], sweep["recall"], color="C1", alpha=0.6,
             label="recall" if is_first else "_nolegend_")
plt.xlabel("threshold")
# Put the title on the figure as well, not only in the file name.
plt.title(title)
plt.legend()
plt.savefig(title.replace(" ","_")+'.png')

In [None]:
# Number of predicted changes per threshold for the timeframe selected in the
# previous cell.
# `data` has one row per (min_changes, threshold) pair, so the x values are taken
# from the rows themselves; the bare `thresholds` grid is shorter than the y
# column whenever more than one min_changes value was searched. Each bar shows
# seaborn's default aggregate (the mean) over the min_changes values.
# sns.barplot returns a matplotlib Axes, hence the name `ax`.
ax=sns.barplot(x=data["threshold"].round(2),y=data["number_pred_changes"])
title = timeframe+" absolute predictions"
ax.set_title(title)
ax.set_xlabel("threshold")
plt.savefig(title.replace(" ","_")+'.png')

In [None]:
import seaborn as sns
# Precision and recall against the threshold for the yearly timeframe.
timeframe="yearly"

data=stats_df[stats_df["timeframe"]==timeframe]
title = timeframe+" precision and recall"
# `data` contains one threshold sweep per min_changes value. Plotting the columns
# in a single call would join the end of each sweep to the start of the next with
# a spurious line, so every sweep is drawn as its own curve. Only the first pair
# of curves gets a legend entry.
for sweep_idx, (_, sweep) in enumerate(data.groupby("min_changes", sort=False)):
    is_first = sweep_idx == 0
    plt.plot(sweep["threshold"], sweep["precision"], color="C0", alpha=0.6,
             label="precision" if is_first else "_nolegend_")
    plt.plot(sweep["threshold"], sweep["recall"], color="C1", alpha=0.6,
             label="recall" if is_first else "_nolegend_")
plt.xlabel("threshold")
# Put the title on the figure as well, not only in the file name.
plt.title(title)
plt.legend()
plt.savefig(title.replace(" ","_")+'.png')

In [None]:
# Number of predicted changes per threshold for the timeframe selected in the
# previous cell.
# `data` has one row per (min_changes, threshold) pair, so the x values are taken
# from the rows themselves; the bare `thresholds` grid is shorter than the y
# column whenever more than one min_changes value was searched. Each bar shows
# seaborn's default aggregate (the mean) over the min_changes values.
# sns.barplot returns a matplotlib Axes, hence the name `ax`.
ax=sns.barplot(x=data["threshold"].round(2),y=data["number_pred_changes"])
title = timeframe+" absolute predictions"
ax.set_title(title)
ax.set_xlabel("threshold")
plt.savefig(title.replace(" ","_")+'.png')

In [None]:
# Evaluate one fixed configuration (threshold 1.0, min_number_changes 300) and
# print whatever test_model returns.
framework.predictor.threshold = 1.0
framework.predictor.min_number_changes = 300
final_result = framework.test_model(randomize=False, predict_subset=1, save_results=False)
print(final_result)

In [None]:
# Weekly grid-search rows with precision of at least 0.85, listed with the
# largest number of predicted changes first.
weekly_high_precision = stats_df.query('timeframe=="weekly" and precision>=0.85')
weekly_high_precision.sort_values("number_pred_changes", ascending=False)