In [1]:
import numpy as np
import pandas as pd
from datetime import datetime, date
from dateutil.relativedelta import relativedelta
from xgboost import XGBRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error, accuracy_score

import sys

# Make the parent directory importable so that the `src` package resolves.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

from src.data.load import load_data
from src.data.format import undummify, train_test, output_format
from src.metrics.error import apply_metrics

import warnings
warnings.filterwarnings('ignore')

  from pandas import MultiIndex, Int64Index


## Data preprocessing

In [2]:
# Load the table for the 'Literature' category through the project loader.
# As displayed below, the frame has one row per year with a leading `Year`
# column, followed by `Count`, `Age0..4`, `Sex0..4` and one-hot
# `Country<k>_<name>` columns.
df = load_data(category='Literature')
# Bare last expression -> rich (truncated) DataFrame display.
df

Unnamed: 0,Year,Count,Age0,Age1,Age2,Age3,Age4,Sex0,Sex1,Sex2,...,Country4_Spain,Country4_Sweden,Country4_Switzerland,Country4_Trinidad,Country4_Turkey,Country4_Tuscany (Italy),Country4_Ukraine,Country4_Union of Soviet Socialist Republics (Russia),Country4_United Kingdom,Country4_United States of America
0,1901,1,62.0,62.0,62.0,62.0,62.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
1,1902,1,85.0,85.0,85.0,85.0,85.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
2,1903,1,71.0,71.0,71.0,71.0,71.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
3,1904,2,74.0,72.0,74.0,72.0,74.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
4,1905,1,59.0,59.0,59.0,59.0,59.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
104,2012,1,57.0,57.0,57.0,57.0,57.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
105,2013,1,82.0,82.0,82.0,82.0,82.0,1.0,1.0,1.0,...,0,0,0,0,0,0,0,0,0,0
106,2014,1,69.0,69.0,69.0,69.0,69.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
107,2015,1,67.0,67.0,67.0,67.0,67.0,1.0,1.0,1.0,...,0,0,0,0,0,0,1,0,0,0


## Model

In [3]:
# Split `df` into training/test features and targets using the project helper.
# The split rule itself lives in src.data.format.train_test and is not visible
# in this notebook.
train_x, train_y, test_x, test_y = train_test(df)

In [4]:
# Multi-target regressor: the XGBoost regressor configured below (squared-error
# objective) is wrapped by MultiOutputRegressor, which handles the multiple
# target columns in train_y.
model = MultiOutputRegressor(
    estimator=XGBRegressor(
        objective='reg:squarederror',
        max_depth=5,
        learning_rate=0.05,
        n_estimators=500,
    )
)
model.fit(train_x, train_y)

In [5]:
# Predict the targets for the held-out features; `y_hat` is scored against
# `test_y` in the "Errors" section below.
y_hat = model.predict(test_x)

## Errors

In [6]:
# Score the predictions against the held-out targets with the project metric
# helper. The result shown below is a single row with one value per output
# field (Count, Age0..4, Sex0..4, Country0..4).
# NOTE(review): `df.columns` is presumably passed so the helper can map the
# one-hot columns back to those fields, and the metric used per field is
# defined in src.metrics.error -- confirm there.
errors = apply_metrics(y_hat, test_y, df.columns)
errors

Unnamed: 0,Count,Age0,Age1,Age2,Age3,Age4,Sex0,Sex1,Sex2,Sex3,Sex4,Country0,Country1,Country2,Country3,Country4
0,0.055556,129.211111,135.8,129.211111,135.8,129.211111,0.855556,0.766667,0.855556,0.766667,0.855556,0.266667,0.266667,0.266667,0.266667,0.266667


## Predictions

In [7]:
# Show the ten most recent rows (all columns); these are the rows the
# prediction input is built from in the next cell.
df.tail(10)

Unnamed: 0,Year,Count,Age0,Age1,Age2,Age3,Age4,Sex0,Sex1,Sex2,...,Country4_Spain,Country4_Sweden,Country4_Switzerland,Country4_Trinidad,Country4_Turkey,Country4_Tuscany (Italy),Country4_Ukraine,Country4_Union of Soviet Socialist Republics (Russia),Country4_United Kingdom,Country4_United States of America
99,2007,1,88.0,88.0,88.0,88.0,88.0,1.0,1.0,1.0,...,0,0,0,0,0,0,0,0,0,0
100,2008,1,68.0,68.0,68.0,68.0,68.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
101,2009,1,56.0,56.0,56.0,56.0,56.0,1.0,1.0,1.0,...,0,0,0,0,0,0,0,0,0,0
102,2010,1,74.0,74.0,74.0,74.0,74.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
103,2011,1,80.0,80.0,80.0,80.0,80.0,0.0,0.0,0.0,...,0,1,0,0,0,0,0,0,0,0
104,2012,1,57.0,57.0,57.0,57.0,57.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
105,2013,1,82.0,82.0,82.0,82.0,82.0,1.0,1.0,1.0,...,0,0,0,0,0,0,0,0,0,0
106,2014,1,69.0,69.0,69.0,69.0,69.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
107,2015,1,67.0,67.0,67.0,67.0,67.0,1.0,1.0,1.0,...,0,0,0,0,0,0,1,0,0,0
108,2016,1,75.0,75.0,75.0,75.0,75.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,1


In [8]:
# Model input for the predictions: take the ten most recent rows, drop the
# leading `Year` column (position 0) and convert the rest to a NumPy array.
hist = df.tail(10).iloc[:, 1:].to_numpy()
hist

array([[ 1., 88., 88., ...,  0.,  0.,  0.],
       [ 1., 68., 68., ...,  0.,  0.,  0.],
       [ 1., 56., 56., ...,  0.,  0.,  0.],
       ...,
       [ 1., 69., 69., ...,  0.,  0.,  0.],
       [ 1., 67., 67., ...,  0.,  0.,  0.],
       [ 1., 75., 75., ...,  0.,  0.,  1.]])

In [9]:
# Run the model on the rows in `hist` and convert the raw output back into a
# readable table (as shown below, output_format yields named Sex/Country
# values rather than one-hot columns).
preds = output_format(model.predict(hist), df.columns)

# Label the rows with the years that follow the last year present in the data.
# Deriving the range from `df` and `len(preds)` (instead of a hard-coded
# 2017-2026) keeps the labels correct when the dataset gains a year or the
# number of rows in `hist` changes.
first_forecast_year = int(df['Year'].max()) + 1
preds.insert(0, 'Year', list(range(first_forecast_year, first_forecast_year + len(preds))))

# NOTE(review): model.predict(hist) produces one output row per input row, so
# these are ten independent single-row predictions labelled with consecutive
# years, not a recursive multi-step forecast -- confirm this matches how
# train_test pairs features with targets.
preds

Unnamed: 0,Year,Count,Age0,Age1,Age2,Age3,Age4,Sex0,Sex1,Sex2,Sex3,Sex4,Country0,Country1,Country2,Country3,Country4
0,2017,1.0,48.0,49.0,48.0,49.0,48.0,Male,Male,Male,Male,Male,Prussia (Germany),Prussia (Germany),Prussia (Germany),Prussia (Germany),Prussia (Germany)
1,2018,1.0,48.0,48.0,48.0,48.0,48.0,Female,Female,Female,Female,Female,United States of America,United States of America,United States of America,United States of America,United States of America
2,2019,1.0,60.0,60.0,60.0,60.0,60.0,Male,Male,Male,Male,Male,France,France,France,France,France
3,2020,1.0,71.0,71.0,71.0,71.0,71.0,Male,Male,Male,Male,Male,United States of America,United States of America,United States of America,United States of America,United States of America
4,2021,1.0,65.0,66.0,65.0,66.0,65.0,Male,Male,Male,Male,Male,Germany,Germany,Germany,Germany,Germany
5,2022,1.0,63.0,63.0,63.0,63.0,63.0,Male,Male,Male,Male,Male,France,France,France,France,France
6,2023,1.0,47.0,47.0,47.0,47.0,47.0,Male,Male,Male,Male,Male,Nigeria,Nigeria,Nigeria,Nigeria,Nigeria
7,2024,1.0,49.0,49.0,49.0,49.0,49.0,Female,Female,Female,Female,Female,United States of America,United States of America,United States of America,United States of America,United States of America
8,2025,1.0,62.0,62.0,62.0,62.0,62.0,Male,Male,Male,Male,Male,United States of America,United States of America,United States of America,United States of America,United States of America
9,2026,1.0,69.0,69.0,69.0,69.0,69.0,Male,Male,Male,Male,Male,Germany,Germany,Germany,Germany,Germany
