In [1]:
# imports
from pathlib import Path

import pandas as pd
from statsmodels.tsa.statespace.sarimax import SARIMAX

pd.set_option('display.max_columns', None)  # show every column of wide frames

In [2]:
# read data
# Column dtypes applied to both CSVs; the row ID is kept as text because later
# cells slice the location part out of it.
dtype = dict(
    ID_LAT_LON_YEAR_WEEK='string',
    latitude='string',
    longitude='string',
    year='int',
    week_no='int',
    emission='float',
)
df, df_test = (pd.read_csv(path, dtype=dtype) for path in ('files/train.csv', 'files/test.csv'))

In [3]:
# prepare data
def location_id(row_ids):
    """Strip the trailing '_YEAR_WEEK' from ID_LAT_LON_YEAR_WEEK values, leaving 'ID_LAT_LON'.

    Splitting on the last two underscores works for any latitude/longitude text
    width, unlike a fixed-width slice.
    """
    return row_ids.str.rsplit('_', n=2).str[0].astype('string')


df['id'] = location_id(df['ID_LAT_LON_YEAR_WEEK'])

# Map (year, week_no) to a calendar date: week 0 is 1 January and each later week is
# 7 days on. The previous '%Y-%W-%w' parse, with the weekday pinned per year to that of
# 1 January, produced exactly these dates, but only for the years in its lookup table.
jan_first = pd.to_datetime(df['year'].astype('string') + '-01-01', format='%Y-%m-%d')
# strftime-style %w weekday of 1 January (0 = Sunday ... 6 = Saturday); kept as a
# column because the next cell drops it explicitly.
df['day_of_week'] = (jan_first.dt.dayofweek + 1) % 7
# Plain column assignment (not df.loc[:, col] = ...) so the existing column is replaced
# and really becomes datetime64 instead of being overwritten in place.
df['date'] = jan_first + pd.to_timedelta(7 * df['week_no'], unit='D')

# Location key as the first test column; dropping any earlier copy lets this cell be re-run.
test_ids = location_id(df_test['ID_LAT_LON_YEAR_WEEK'])
df_test = df_test.drop(columns='id', errors='ignore')
df_test.insert(0, 'id', test_ids)

In [4]:
# Split training data into the target series (endog) and exogenous features (exog),
# and group train/test by location so each location gets its own model.
META_COLUMNS = ['ID_LAT_LON_YEAR_WEEK', 'latitude', 'longitude', 'year', 'week_no']

endog_groups = df[['id', 'date', 'emission']].groupby('id')

# Feature columns = everything that is not metadata, the target, or a helper column
# added while preparing the data. Selecting them by name (instead of assuming exactly
# 70 columns in a fixed position) keeps train and test aligned.
non_features = set(META_COLUMNS) | {'id', 'date', 'emission', 'day_of_week'}
feature_columns = [c for c in df.columns if c not in non_features]
missing_in_test = set(feature_columns) - set(df_test.columns)
assert not missing_in_test, f'test set lacks feature columns: {sorted(missing_in_test)}'

# Relabel features with integer positions 0..n-1 in both frames so they share labels.
feature_labels = list(range(len(feature_columns)))

exog = (df[['id', 'date'] + feature_columns]
        .set_axis(['id', 'date'] + feature_labels, axis=1))
exog_groups = exog.groupby('id')

# Test features ordered by location and week; a new name leaves df_test untouched so
# this cell can be re-run.
test_features = (df_test.sort_values(['id', 'year', 'week_no'])[['id'] + feature_columns]
                 .set_axis(['id'] + feature_labels, axis=1))
test_groups = test_features.groupby('id')

In [5]:
# function to create model and forecast
def sarimax(endog, exog, exog_test, steps=49,
            order=(1, 0, 0), seasonal_order=(1, 0, 0, 12)):
    """Fit a SARIMAX model to one location's weekly series and forecast ahead.

    Parameters
    ----------
    endog : DataFrame
        Target series with an 'emission' column; its index supplies the dates.
    exog : DataFrame or None
        Regressors aligned row-for-row with ``endog``. ``None`` or a frame with
        no columns fits a model without regressors.
    exog_test : DataFrame or None
        Regressor values for the forecast horizon, in the same column order as
        ``exog``. Ignored when no regressors are used.
    steps : int
        Number of periods to forecast.
    order, seasonal_order : tuple
        Passed straight to ``SARIMAX``. The defaults are the notebook's original model.
        NOTE(review): with freq='W', a seasonal period of 12 is a 12-week cycle rather
        than an annual one (~52 weeks); confirm this is intended.

    Returns
    -------
    The out-of-sample forecast produced by the fitted results' ``forecast`` method.

    Raises
    ------
    ValueError
        If ``exog`` has columns but ``exog_test`` is None.
    """
    if exog is None or exog.shape[1] == 0:
        exog_train = exog_future = None
    else:
        if exog_test is None:
            raise ValueError('exog_test is required when exog has columns')
        exog_train = exog.values
        exog_future = exog_test.values
    model = SARIMAX(endog=endog.loc[:, 'emission'].values,
                    exog=exog_train,
                    order=order,
                    seasonal_order=seasonal_order,
                    dates=endog.index.values,
                    freq='W')
    fitted = model.fit(full_output=False, disp=False)
    return fitted.forecast(steps=steps, exog=exog_future)

In [6]:
%%time
# run sarimax function for every location
# Exogenous columns missing in at least this share of a location's training weeks are dropped.
MAX_NA_FRACTION = .1

results = {}

for i, (name, endog_group) in enumerate(endog_groups):
    if i % 10 == 0:
        print(f'{int(100*i/endog_groups.ngroups)} %')
    # Put target and features on the same regular weekly grid; sarimax() fits with freq='W'.
    df_endog_id = endog_group.drop(columns='id').set_index('date', drop=True).resample('W').nearest()
    df_exog_id = exog_groups.get_group(name).drop(columns='id').set_index('date', drop=True).resample('W').nearest()
    df_exog_id_test = test_groups.get_group(name).drop(columns='id').interpolate().bfill().ffill()
    # Keep a feature only if it is mostly present in training AND has values in the test
    # rows: a test column that is still all-NaN after filling would put NaN into the forecast.
    na_share = df_exog_id.isna().mean()
    keep = [col for col, share in na_share.items()
            if share < MAX_NA_FRACTION and df_exog_id_test[col].notna().all()]
    df_exog_id = df_exog_id[keep].interpolate().bfill().ffill()
    df_exog_id_test = df_exog_id_test[keep]
    # Forecast exactly as many weeks as this location has test rows.
    results[name] = sarimax(df_endog_id, df_exog_id, df_exog_id_test, steps=len(df_exog_id_test))

# One row per location (key in the 'index' column), one column per forecast step.
df_results = pd.DataFrame.from_dict(results, orient='index').reset_index()
df_results

0 %
2 %
4 %
6 %
8 %
10 %
12 %
14 %
16 %
18 %
20 %
22 %
24 %
26 %
28 %
30 %
32 %
34 %
36 %
38 %
40 %
42 %
44 %
46 %
48 %
50 %
52 %
54 %
56 %
58 %
60 %
62 %
64 %
66 %
68 %
70 %
72 %
74 %
76 %
78 %
80 %
82 %
84 %
86 %
88 %
90 %
92 %
94 %
96 %
98 %
CPU times: total: 42min 36s
Wall time: 42min 43s


Unnamed: 0,index,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48
0,ID_-0.510_29.290,3.557931,4.021571,4.228842,4.180763,4.396011,4.442379,4.373865,4.149057,4.008408,4.334159,4.264896,4.291760,4.108385,3.854099,3.799127,4.284089,5.139032,3.646203,3.967296,4.325194,4.114568,4.049726,3.938128,4.444204,4.513852,4.606270,4.962422,4.980846,4.635346,4.979567,5.167398,4.699570,5.293910,5.135652,5.128925,4.749040,5.112742,4.990169,4.787485,5.120473,4.774270,4.963088,5.188333,5.451555,5.757078,5.462543,-2.998600,4.611002,4.720108
1,ID_-0.528_29.472,3.943257,4.139136,3.711173,4.411129,4.966229,4.271215,4.792955,4.452671,5.146561,4.823963,4.064205,4.839897,3.647800,4.691599,3.975714,4.251140,3.992921,4.498103,4.193973,3.635008,4.102537,3.812916,3.847900,3.600298,4.578827,3.948900,4.272450,4.549233,4.040513,3.710136,3.077747,4.523749,4.490182,4.771691,4.435225,4.787127,4.782205,4.563561,4.563128,5.284360,4.304702,4.820553,5.287671,4.846936,5.019038,4.928578,-1.015940,5.041694,5.299898
2,ID_-0.547_29.653,0.509983,0.521157,0.596430,0.612491,0.671968,0.663895,0.653887,0.628079,0.669266,0.676091,0.581665,0.664970,0.578227,0.587866,0.586644,0.648528,0.574033,0.530023,0.618882,0.591404,0.628000,0.612338,0.646111,0.665387,0.690106,0.724266,0.695627,0.664231,0.708567,0.683962,0.700591,0.668328,0.650919,0.693783,0.660005,0.708893,0.696209,0.739766,0.768852,0.699776,0.717497,0.734329,0.708342,0.755312,0.685725,0.683316,1.172501,0.700853,0.644634
3,ID_-0.569_30.031,111.141777,112.756023,118.340069,111.077076,117.723483,112.769327,111.307350,122.147791,105.381222,116.995858,113.736140,118.941581,112.430697,114.788888,116.411413,119.254418,110.306833,100.089171,115.166198,115.037615,117.460161,109.744354,127.469962,119.045996,127.192129,120.304693,122.026241,131.050275,123.854125,118.916423,130.667349,121.659269,133.195793,130.322208,126.953080,138.604792,140.360741,140.071611,143.225390,154.730976,148.420270,148.638498,145.598676,154.137009,151.185656,147.364189,136.830393,141.576879,146.711366
4,ID_-0.598_29.102,0.085040,0.087638,0.092494,0.081443,0.085905,0.080510,0.086999,0.084304,0.081653,0.098606,0.089413,0.098903,0.091384,0.078396,0.087805,0.090388,0.087671,0.083034,0.091211,0.091098,0.095121,0.096273,0.102155,0.104953,0.104791,0.102473,0.099889,0.108461,0.098990,0.106861,0.100158,0.107782,0.106366,0.101592,0.103495,0.110388,0.102520,0.105045,0.110088,0.104919,0.115314,0.120378,0.103372,0.117488,0.114192,0.115054,0.475771,0.113678,0.102089
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
492,ID_-3.153_30.347,14.546875,16.282235,16.333778,14.458778,15.116262,13.725311,12.673054,15.167554,13.504206,14.788981,13.703463,14.480126,13.764526,13.886384,13.531336,12.127219,12.115571,12.409697,12.825178,13.465983,12.395833,13.169396,13.706166,13.125586,13.685224,12.902508,16.085602,14.958720,15.352548,13.783130,14.795489,16.019856,13.973604,16.932860,14.948413,15.147626,15.832077,16.363519,17.658993,17.033076,18.463921,19.180388,17.239568,18.676504,17.123289,17.701537,-1984.973769,16.976332,16.524654
493,ID_-3.161_28.839,0.121577,0.122845,0.133933,0.124609,0.121229,0.119897,0.119291,0.120269,0.112213,0.126106,0.112756,0.116253,0.114023,0.107206,0.118116,0.122280,0.116294,0.107306,0.116031,0.106808,0.115541,0.123468,0.128867,0.136007,0.130077,0.129544,0.138716,0.145040,0.136543,0.138160,0.149880,0.144066,0.156082,0.142703,0.144954,0.145089,0.152331,0.151148,0.152117,0.151720,0.152095,0.156444,0.150383,0.163314,0.149860,0.148855,0.108589,0.147971,0.140025
494,ID_-3.174_29.926,41.413277,42.008218,46.044248,40.522666,43.309379,41.892084,44.236698,40.474114,43.871784,46.178871,43.959333,43.702541,39.990440,37.864844,42.980311,41.100008,41.250393,42.950987,46.467020,43.839055,44.301219,42.213471,44.613508,46.866734,45.124255,44.623655,36.357544,48.773070,46.670939,45.577948,48.679149,47.792531,46.757526,50.043649,48.028568,46.470145,50.133849,50.773354,53.334911,53.043871,54.310181,53.150985,54.403314,52.980639,51.286253,52.304010,57.560355,48.901656,48.464904
495,ID_-3.287_29.713,42.114827,42.303540,44.566933,40.956632,43.365380,42.764105,41.748787,42.460561,43.384202,41.530855,39.564700,40.334901,37.675009,35.802387,38.896775,36.638168,37.271127,35.194471,35.841954,38.845323,39.593388,39.601737,37.576370,42.128304,39.064763,39.302907,40.814982,43.461954,44.378062,39.659762,43.936934,44.014116,39.886173,45.300916,43.755224,44.844127,44.248590,45.312266,48.234027,49.900327,49.935275,55.847416,51.905009,52.490302,52.802887,48.535434,44.607366,48.793599,49.255706


In [7]:
# format output
# NOTE: this overwrites df_results (wide -> long), so re-run the forecasting cell
# before re-running this one.
FORECAST_YEAR = 2022  # year written into the output row IDs

# Every column except the location key holds one forecast step (0, 1, ...), so the
# number of steps follows the forecasts instead of a hard-coded 49.
step_columns = [c for c in df_results.columns if c != 'index']
df_results = df_results.melt(id_vars=['index'], value_vars=step_columns)
# Rebuild the row key: ID_LAT_LON + '_<year>_' + zero-padded step/week number.
df_results['id'] = (df_results['index'] + f'_{FORECAST_YEAR}_'
                    + df_results['variable'].astype('string').str.zfill(2))
df_results = (df_results[['id', 'value']]
              .sort_values('id')
              .reset_index(drop=True)
              .rename(columns={'id': 'ID_LAT_LON_YEAR_WEEK', 'value': 'emission'}))
df_results

Unnamed: 0,ID_LAT_LON_YEAR_WEEK,emission
0,ID_-0.510_29.290_2022_00,3.557931
1,ID_-0.510_29.290_2022_01,4.021571
2,ID_-0.510_29.290_2022_02,4.228842
3,ID_-0.510_29.290_2022_03,4.180763
4,ID_-0.510_29.290_2022_04,4.396011
...,...,...
24348,ID_-3.299_30.301_2022_44,33.842067
24349,ID_-3.299_30.301_2022_45,32.589068
24350,ID_-3.299_30.301_2022_46,-750.855654
24351,ID_-3.299_30.301_2022_47,32.777538


In [8]:
# save output to csv
OUTPUT_PATH = Path('output') / 'sarimax.csv'
# Create the folder first so a fresh checkout does not fail here after the long model run.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
df_results.to_csv(OUTPUT_PATH, index=False)