In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import datetime as dt
import matplotlib.style
import os
from dateutil.relativedelta import relativedelta
import pandas as pd

In [None]:
# Colab-only step: mount Google Drive at /content/drive so the data paths
# used by the later cells (all under /content/drive/MyDrive/) can be read.
from google.colab import drive
drive.mount('/content/drive')

### Read the true and predicted values

In [None]:
# ---- Configuration ----------------------------------------------------------
path = '/content/drive/MyDrive/news-based-forecasting/price_data/'
data_length = 2915       # number of trailing observations kept from every series
num_yrs_to_plot = 8      # number of yearly x-ticks generated after `start`
kgs_in_quintal = 1       # divisor applied to every series; 1 leaves values unscaled
start = dt.datetime(2012,12,31)
crops = ['Onion', 'Potato', 'Rice', 'Wheat']
plot_diff = 0            # differencing order handed to np.diff; 0 keeps the series as-is
plot_errors = True       # replace each forecast by (forecast - true value)
plot_returns = False     # read the *_returns files instead of price levels
raw_returns = False      # with plot_returns: convert values via exp(x) - 1

# One dict per model, keyed by crop name.
arima_preds = dict()
earima_preds = dict()
ren_preds = dict()
lstm_preds = dict()
true_prices = dict()


def _tail_series(values):
  """Flatten `values`, difference it `plot_diff` times, keep the last
  `data_length` entries and divide by `kgs_in_quintal`.

  The division always produces a new array, so callers may modify the result
  without touching the loaded data.
  """
  return np.diff(np.ravel(values), plot_diff)[-data_length:]/kgs_in_quintal


# The REN and LSTM prediction files carry a '_returns' suffix in returns mode.
# NOTE(review): the ARIMA / AR+events files are always read without that suffix,
# even when plot_returns is True -- confirm this is intended.
file_suffix = '_returns' if plot_returns else ''

for crop in crops:
  if plot_returns:
    true_source = np.load(f'{path}true_{crop}_returns.npy')
  else:
    # np.ravel flattens every non-index column of the CSV;
    # presumably the file holds a single price column -- TODO confirm.
    true_source = pd.read_csv(f'{path}{crop}.csv', parse_dates = ['date'], index_col = ['date'])
  true_prices[crop] = _tail_series(true_source)

  arima_preds[crop] = _tail_series(np.load(f'{path}arima_{crop}.npy'))
  earima_preds[crop] = _tail_series(np.load(f'{path}earima_{crop}.npy'))
  ren_preds[crop] = _tail_series(np.load(f'{path}ren_{crop}{file_suffix}.npy'))
  lstm_preds[crop] = _tail_series(np.load(f'{path}lstm_{crop}{file_suffix}.npy'))

  # exp(x) - 1 is only meaningful for the returns files; applying it to price
  # levels (plot_returns False) would overflow, so the flag is ignored there.
  if plot_returns and raw_returns:
    ren_preds[crop] = np.exp(ren_preds[crop]) - 1
    lstm_preds[crop] = np.exp(lstm_preds[crop]) - 1
    true_prices[crop] = np.exp(true_prices[crop]) - 1

  if plot_errors:
    # Forecast error = prediction - true value, element by element.
    arima_preds[crop] = arima_preds[crop] - true_prices[crop]
    earima_preds[crop] = earima_preds[crop] - true_prices[crop]
    ren_preds[crop] = ren_preds[crop] - true_prices[crop]
    lstm_preds[crop] = lstm_preds[crop] - true_prices[crop]



In [None]:
# Second data directory; only the onion price file is read from it here.
path_2 = '/content/drive/MyDrive/MSOM_price_data/'
df_o = pd.read_csv(
    f'{path_2}Onion.csv',
    index_col=['date'],
    parse_dates=['date'],
)

In [None]:
# Quick look: onion values from row 2557 onward, flattened to a 1-D array.
df_o.iloc[2557:].values.ravel()

array([1322.89976134, 1344.678487  , 1325.25295508, ..., 2226.94117647,
       2247.84313725, 2380.49019608])

### Plot daily forecast errors

In [None]:
#### Day offsets for x-ticks: start and mid-point (day 181) of three 365-day
#### years, plus the closing offset of the third year.
dateindexes = [year * 365 + half for year in range(3) for half in (0, 181)] + [3 * 365]

# Seed values for a leap-year-aware variant of the offsets; the loop that
# would extend `dateindexes2` is disabled, so these keep their initial values.
dateindexes2 = [0]
leap_years = [7, 15]
lpyr = 0

# Disabled variant (8 years, two steps per year):
# for i in range(1, 2*8 + 1):
#     if i in leap_years:
#         lpyr += 1
#     if i % 2 == 0:
#         dateindexes2.append(dateindexes2[i-1] + (365 + lpyr - 181))
#     else:
#         dateindexes2.append(dateindexes2[i-1] + 181 + lpyr)
# xdates = []


In [None]:
#### Yearly tick positions: the start day plus one tick per plotted year.
# Years are added one at a time with relativedelta.
# With the values configured in this notebook (start 2012-12-31, 8 years,
# data_length 2915) the final tick, 2020-12-31, lies after the last entry of
# `xxdates` built below.
tick_day = start.date()
xdates = [str(tick_day)]
for _ in range(num_yrs_to_plot):
    tick_day = tick_day + relativedelta(years=1)
    xdates.append(str(tick_day))

#### Tick labels shown on the axis, reformatted as dd-Mon-yy
dates = [
    dt.datetime.strptime(label, '%Y-%m-%d').date().strftime('%d-%b-%y')
    for label in xdates
]

#### One ISO date string per plotted observation, starting at `start`
first_day = start.date()
xxdates = [str(first_day + dt.timedelta(offset)) for offset in range(data_length)]

In [None]:
# Line styling for every series, grouped per series:
# colour, opacity, line width and z-order (a higher z-order is drawn on top).

# LSTM / RNN forecasts
rnn_color, rnn_alpha = 'aquamarine', 1.0
rnn_linewidth, rnn_order = 3.0, 1

# REN forecasts
ren_color, ren_alpha = 'red', 0.6
ren_linewidth, ren_order = 1.0, 2

# AR model with event features
earima_name = 'AR+events'
earima_color, earima_alpha = 'red', 1.0
earima_linewidth, earima_order = 1.0, 2

# Plain AR model
arima_name = 'AR'
arima_color, arima_alpha = 'aquamarine', 1.0
arima_linewidth, arima_order = 3.0, 1

# True series
true_color, true_alpha = 'royalblue', 0.7
true_linewidth, true_order = 3.0, 1

# Font sizes: in-panel crop names / legend, and tick labels
rmse_fontsize = 35
ytick_fontsize = 40

In [None]:
# Four stacked panels (one per crop) sharing the same date ticks; only the
# bottom panel shows the tick labels.
fig = plt.figure(figsize=(40,25))
gs1 = gridspec.GridSpec(4, 1)
gs1.update(wspace=0.04, hspace=0.2) # set the spacing between axes.

ax1 = fig.add_subplot(gs1[0])
ax2 = fig.add_subplot(gs1[1])
ax3 = fig.add_subplot(gs1[2])
ax4 = fig.add_subplot(gs1[3])

# x position (index into the date axis) of the in-panel crop name.
crop_name_xpos = 100

# (axis, crop, y position of the crop name in data units), top to bottom.
panels = [
  (ax1, 'Onion', 1500),
  (ax2, 'Potato', 350),
  (ax3, 'Rice', 510),
  (ax4, 'Wheat', 210),
]

for ax, crop, crop_name_ypos in panels:
  # The true series is only drawn when plotting levels; in error mode the
  # forecasts already hold (forecast - true).
  if not plot_errors:
    ax.plot(xxdates, true_prices[crop], linewidth=true_linewidth, alpha=true_alpha,
            zorder=true_order, label='Actual', color=true_color)
  # Only the AR+events and AR series are drawn; the REN and LSTM series are
  # not plotted in this figure.
  ax.plot(xxdates, earima_preds[crop], linewidth=earima_linewidth, zorder=earima_order,
          label=earima_name, alpha=earima_alpha, color=earima_color)
  ax.plot(xxdates, arima_preds[crop], linewidth=arima_linewidth, zorder=arima_order,
          label=arima_name, alpha=arima_alpha, color=arima_color)
  ax.text(crop_name_xpos, crop_name_ypos, crop, fontsize=rmse_fontsize)

  ax.set(xlabel='', ylabel='', xticks=xdates)
  ax.tick_params(axis='y', labelsize=ytick_fontsize)
  ax.grid(color='k', linewidth=.5, linestyle=':', axis='both')

# Hide the date labels on the upper three panels, show them on the bottom one.
for upper_ax in (ax1, ax2, ax3):
  upper_ax.set_xticklabels([])
ax4.set_xticklabels(dates, fontsize=ytick_fontsize, rotation=0)

# One shared legend, built from the bottom panel (all panels carry the same series).
handles, labels = ax4.get_legend_handles_labels()
leg = fig.legend(handles, labels, loc='upper center', fancybox=True, fontsize=rmse_fontsize, ncol=5)

# Shared y-axis caption; it now follows plot_errors so that a figure of price
# levels is no longer captioned as forecast errors.
if plot_returns:
  y_caption = 'returns' if raw_returns else 'log(1 + returns)'
elif plot_errors:
  y_caption = 'Forecast error (\u20B9/quintal)'
else:
  # Unit string mirrors the error caption -- adjust if kgs_in_quintal changes.
  y_caption = 'Price (\u20B9/quintal)'
fig.text(-0.02,0.5,y_caption,va='center',fontsize=50,rotation='vertical')

# Thicker legend lines so the entries stay readable at this figure size.
plt.setp(leg.get_lines(), linewidth=8)

gs1.tight_layout(fig)


Output hidden; open in https://colab.research.google.com to view.