In [1]:
import numpy as np
import pandas as pd
from datetime import datetime, date
from dateutil.relativedelta import relativedelta
from xgboost import XGBRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error, accuracy_score

import sys
sys.path.append("..") # in order to be able to use modules in src package

from src.data.load import load_data
from src.data.format import undummify, train_test, output_format
from src.metrics.error import apply_metrics

import warnings
warnings.filterwarnings('ignore')

  from pandas import MultiIndex, Int64Index


## Data preprocessing

In [2]:
# Load the dataset through the project loader (src.data.load.load_data) and
# display it. The rendered frame has one row per year, with a Year column
# followed by Count, Age*, Sex* and one-hot Country* columns.
df = load_data()
df

Unnamed: 0,Year,Count,Age0,Age1,Age2,Age3,Age4,Sex0,Sex1,Sex2,...,Country4_Prussia (Poland),Country4_Russia,Country4_Russian Empire (Ukraine),Country4_South Africa,Country4_Spain,Country4_Sweden,Country4_Switzerland,Country4_United Kingdom,Country4_United States of America,Country4_W&uuml;rttemberg (Germany)
0,1901,1,47.0,47.0,47.0,47.0,47.0,0.0,0.0,0.0,...,1,0,0,0,0,0,0,0,0,0
1,1902,1,45.0,45.0,45.0,45.0,45.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
2,1903,1,42.0,42.0,42.0,42.0,42.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
3,1904,1,55.0,55.0,55.0,55.0,55.0,0.0,0.0,0.0,...,0,1,0,0,0,0,0,0,0,0
4,1905,1,61.0,61.0,61.0,61.0,61.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
102,2012,3,79.0,50.0,50.0,79.0,50.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
103,2013,5,63.0,64.0,64.0,57.0,57.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
104,2014,3,75.0,51.0,52.0,75.0,51.0,0.0,1.0,0.0,...,0,0,0,0,0,0,0,0,0,0
105,2015,3,85.0,80.0,84.0,85.0,80.0,0.0,0.0,1.0,...,0,0,0,0,0,0,0,0,0,0


## Model

In [3]:
# Split df into training and test inputs/targets using the project helper.
# NOTE(review): the split rule and how each input row is paired with its target
# are defined in src.data.format.train_test and are not visible here - confirm.
train_x, train_y, test_x, test_y = train_test(df)

In [4]:
# One gradient-boosted regressor per target column: MultiOutputRegressor wraps
# the XGBRegressor template configured below and is then fitted on the
# training split.
model = MultiOutputRegressor(
    XGBRegressor(
        objective='reg:squarederror',
        max_depth=5,
        learning_rate=0.05,
        n_estimators=500,
    )
)
model.fit(train_x, train_y)

In [5]:
# Predictions of the fitted model on the test inputs; compared with test_y below.
y_hat = model.predict(test_x)

## Errors

In [6]:
# Score the predictions against test_y with the project helper and display the
# result. df.columns is passed along - presumably so the helper can map the
# raw output columns back to the named fields (Count, Age*, Sex*, Country*);
# NOTE(review): which metric is applied to which field is defined in
# src.metrics.error.apply_metrics - confirm there.
errors = apply_metrics(y_hat, test_y, df.columns)
errors

Unnamed: 0,Count,Age0,Age1,Age2,Age3,Age4,Sex0,Sex1,Sex2,Sex3,Sex4,Country0,Country1,Country2,Country3,Country4
0,0.960227,217.022727,210.909091,230.920455,249.170455,150.329545,0.954545,0.767045,0.988636,0.943182,0.931818,0.443182,0.369318,0.426136,0.306818,0.482955


## Predictions

In [7]:
# Show the ten most recent rows of the dataset (all columns).
df.tail(10)

Unnamed: 0,Year,Count,Age0,Age1,Age2,Age3,Age4,Sex0,Sex1,Sex2,...,Country4_Prussia (Poland),Country4_Russia,Country4_Russian Empire (Ukraine),Country4_South Africa,Country4_Spain,Country4_Sweden,Country4_Switzerland,Country4_United Kingdom,Country4_United States of America,Country4_W&uuml;rttemberg (Germany)
97,2007,4,70.0,70.0,66.0,82.0,70.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
98,2008,3,72.0,61.0,76.0,72.0,61.0,0.0,1.0,0.0,...,0,0,0,0,0,0,0,0,0,0
99,2009,5,61.0,48.0,57.0,57.0,57.0,1.0,1.0,0.0,...,0,0,0,0,0,0,0,1,0,0
100,2010,1,85.0,85.0,85.0,85.0,85.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,1,0,0
101,2011,4,53.0,53.0,70.0,68.0,53.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,1,0
102,2012,3,79.0,50.0,50.0,79.0,50.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
103,2013,5,63.0,64.0,64.0,57.0,57.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
104,2014,3,75.0,51.0,52.0,75.0,51.0,0.0,1.0,0.0,...,0,0,0,0,0,0,0,0,0,0
105,2015,3,85.0,80.0,84.0,85.0,80.0,0.0,0.0,1.0,...,0,0,0,0,0,0,0,0,0,0
106,2016,1,71.0,71.0,71.0,71.0,71.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0


In [8]:
# Feature matrix for the ten most recent rows: take the tail of df, drop the
# first column (Year) by position, and convert to a NumPy array.
hist = df.tail(10).iloc[:, 1:].to_numpy()
hist

array([[ 4., 70., 70., ...,  0.,  0.,  0.],
       [ 3., 72., 61., ...,  0.,  0.,  0.],
       [ 5., 61., 48., ...,  1.,  0.,  0.],
       ...,
       [ 3., 75., 51., ...,  0.,  0.,  0.],
       [ 3., 85., 80., ...,  0.,  0.,  0.],
       [ 1., 71., 71., ...,  0.,  0.,  0.]])

In [9]:
# Run the fitted model on the most recent rows and convert the raw output into
# a readable table with the project helper output_format.
preds = output_format(model.predict(hist), df.columns)

# Label the rows with consecutive years following the last year present in df,
# rather than a hard-coded 2017-2026 range, so the labels stay correct when the
# dataset gains a year or the number of predicted rows changes.
# NOTE(review): model.predict(hist) returns one prediction per row of hist;
# whether these really correspond to successive future years depends on how
# train_test pairs inputs with targets - confirm before relying on the labels.
first_year = int(df['Year'].max()) + 1
preds.insert(0, 'Year', list(range(first_year, first_year + len(preds))))
preds

Unnamed: 0,Year,Count,Age0,Age1,Age2,Age3,Age4,Sex0,Sex1,Sex2,Sex3,Sex4,Country0,Country1,Country2,Country3,Country4
0,2017,1.0,59.0,53.0,52.0,73.0,47.0,Male,Male,Male,Male,Male,Denmark,Denmark,Denmark,Denmark,Denmark
1,2018,2.0,53.0,64.0,47.0,60.0,57.0,Male,Male,Male,Male,Male,Netherlands,Italy,Netherlands,United States of America,Italy
2,2019,2.0,76.0,72.0,65.0,80.0,62.0,Male,Male,Male,Male,Male,United States of America,United States of America,United States of America,United States of America,United States of America
3,2020,2.0,77.0,63.0,74.0,71.0,64.0,Female,Female,Female,Female,Female,Switzerland,Portugal,Portugal,United States of America,United Kingdom
4,2021,3.0,57.0,57.0,56.0,61.0,56.0,Male,Male,Male,Male,Male,Austria-Hungary (Czech Republic),Austria-Hungary (Czech Republic),Argentina,Austria-Hungary (Czech Republic),Austria-Hungary (Czech Republic)
5,2022,2.0,56.0,73.0,54.0,65.0,56.0,Male,Male,Male,Male,Male,United States of America,United States of America,United States of America,United States of America,United States of America
6,2023,2.0,60.0,58.0,64.0,53.0,57.0,Male,Male,Male,Male,Male,Belgium,Germany,United States of America,Switzerland,Belgium
7,2024,2.0,56.0,66.0,53.0,61.0,57.0,Male,Male,Male,Male,Male,United States of America,Italy,United States of America,United States of America,Germany
8,2025,2.0,81.0,68.0,78.0,72.0,72.0,Female,Female,Female,Female,Female,United States of America,United States of America,United Kingdom,United States of America,United Kingdom
9,2026,0.0,43.0,43.0,45.0,49.0,48.0,Male,Male,Male,Male,Male,United States of America,United States of America,United States of America,United States of America,United States of America
