In [15]:
import torch
import arviz as az
from scipy.optimize import curve_fit
from bokeh.plotting import figure, show,output_file, save
from bokeh.transform import factor_cmap, factor_mark
from bokeh.palettes import Spectral
from bokeh.models import Slope, Div
from bokeh.io import curdoc,output_notebook,export_png
from bokeh.layouts import column,gridplot
#from print_versions import print_versions
from sklearn.metrics import r2_score
from seaborn import clustermap

from bokeh.models import Band, ColumnDataSource
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


In [16]:
# Read the training and test tables from the Excel workbooks under ../data.
train, test = [
    pd.read_excel(workbook)
    for workbook in (r"../data/chem_train.xlsx", r"../data/chem_test.xlsx")
]

In [17]:
# Keep unfiltered copies (later fed to roll_week), then restrict the working
# frames to rows whose corrected_week lies strictly between 0 and 53.
# NOTE(review): re-running this cell takes the *_predict copies from frames
# that were already filtered by the previous run.
train_predict, test_predict = train.copy(), test.copy()
train = train[train.corrected_week.gt(0) & train.corrected_week.lt(53)]
test = test[test.corrected_week.gt(0) & test.corrected_week.lt(53)]

In [18]:
# Load the saved NO2 seasonality model object.
# The storage location falls back to the CPU when CUDA is unavailable, so the
# notebook also runs on machines without a GPU.
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only
# load this file if it comes from a trusted source.
gpr = torch.load(
    "../models/no2_seasonality_synthertic",
    weights_only=False,
    map_location="cuda" if torch.cuda.is_available() else "cpu",
)

In [19]:
# Load the saved NOx seasonality model object.
# The storage location falls back to the CPU when CUDA is unavailable, so the
# notebook also runs on machines without a GPU.
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only
# load this file if it comes from a trusted source.
gpr_bc = torch.load(
    "../models/nox_seasonality_synthertic",
    weights_only=False,
    map_location="cuda" if torch.cuda.is_available() else "cpu",
)

In [20]:
# Compute device for tensors created in this notebook: GPU when available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [21]:
# Evaluate the `gpr` model on a dense grid of 800 week values in [0, 52].
# (The pm25_* names are historical; `gpr` is the object loaded from the
# no2_seasonality file.)
linmod = torch.linspace(0, 52, 800).to(device)
pm25_model, pm25_model_std = gpr(linmod, full_cov=True)

# Move everything to NumPy. The second model output is used as a covariance
# matrix: the square root of its diagonal is taken as the pointwise std.
linmod_np = linmod.cpu().detach().numpy()
pm25_model_np = pm25_model.cpu().detach().numpy().copy()
pm25_model_std_np = pm25_model_std.diag().sqrt().cpu().detach().numpy().copy()

# Mean +/- one std envelope, packaged as a Bokeh data source for the Band glyph.
lower1 = pm25_model_np - pm25_model_std_np
upper1 = pm25_model_np + pm25_model_std_np
data1 = ColumnDataSource(
    pd.DataFrame(
        [linmod_np, lower1, upper1],
        index=["corrected_week", "lower", "upper"],
    ).T.reset_index()
)

In [22]:

# Evaluate the `gpr_bc` model on the same week grid as above.
# (The bc_* names are historical; `gpr_bc` is the object loaded from the
# nox_seasonality file.)
bc_model, bc_model_std = gpr_bc(linmod, full_cov=True)
bc_model_np = bc_model.cpu().detach().numpy().copy()
# Pointwise std = square root of the diagonal of the second model output.
bc_model_std_np = bc_model_std.diag().sqrt().cpu().detach().numpy().copy()

# Mean +/- one std envelope, packaged as a Bokeh data source for the Band glyph.
lower1_bc = bc_model_np - bc_model_std_np
upper1_bc = bc_model_np + bc_model_std_np
data1_bc = ColumnDataSource(
    pd.DataFrame(
        [linmod_np, lower1_bc, upper1_bc],
        index=["corrected_week", "lower_bc", "upper_bc"],
    ).T.reset_index()
)

In [23]:
def roll_week(data, resolution, week, particle):
    """Centered rolling mean and standard deviation over ``corrected_week``.

    For each of ``resolution`` evenly spaced grid points in [0, 53], the rows
    of ``data`` whose ``corrected_week`` lies strictly within ``week`` of the
    grid point are selected, and the mean and sample standard deviation of
    the requested columns are computed over them.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a ``corrected_week`` column and the columns in ``particle``.
    resolution : int
        Number of grid points between 0 and 53 (both ends included).
    week : float
        Half-width of the window; bounds are exclusive.
    particle : str or list of str
        Column name(s) to aggregate. A single name is treated as a
        one-element list.

    Returns
    -------
    pandas.DataFrame
        Indexed by the grid points, with one ``<name>`` column (rolling mean)
        and one ``<name>_std`` column (rolling std) per requested column.
        Windows with no rows give NaN means; windows with fewer than two rows
        give NaN stds.
    """
    columns = [particle] if isinstance(particle, str) else list(particle)
    grid = np.linspace(0, 53, resolution, endpoint=True)
    weeks = data["corrected_week"]

    mean_rows = []
    std_rows = []
    for centre in grid:
        # Build the window mask once and reuse it for both statistics.
        in_window = (centre - week < weeks) & (centre + week > weeks)
        window = data.loc[in_window, columns]
        mean_rows.append(window.mean())
        std_rows.append(window.std())

    mean_frame = pd.DataFrame(mean_rows, index=grid)
    std_frame = pd.DataFrame(std_rows, index=grid).add_suffix("_std")
    return pd.concat([mean_frame, std_frame], axis=1)

In [24]:
# Rolling-window settings: number of grid points and window half-width in weeks.
resolution, week_averange = 500, 2

# Rolling mean/std of NO2 and NOx for the unfiltered train and test copies.
mean, mean_test = (
    roll_week(frame, resolution, week_averange, ["no2", "nox"])
    for frame in (train_predict, test_predict)
)

# Model predictions (first model output only) on the same week grid.
x = np.linspace(0, 53, resolution, endpoint=True)
nox, no2 = (
    model(torch.tensor(x).float().to(device))[0].cpu().detach().numpy()
    for model in (gpr_bc, gpr)
)

In [25]:
# Rolling mean +/- one rolling std, for NO2 and NOx.
lower1_std, upper1_std = mean.no2 - mean.no2_std, mean.no2 + mean.no2_std
lower1_bc_std, upper1_bc_std = mean.nox - mean.nox_std, mean.nox + mean.nox_std

# Attach the envelopes as columns, discard rows with any NaN, and wrap the
# result (with the week grid exposed as an "index" column) for Bokeh.
mean = mean.assign(
    lower1_std_pm=lower1_std,
    upper1_std_pm=upper1_std,
    lower1_std_bc=lower1_bc_std,
    upper1_std_bc=upper1_bc_std,
).dropna()
mean1 = ColumnDataSource(mean.reset_index())

In [26]:
def _r2_on_valid_rows(rolling, predicted):
    """R2 between a rolling-mean Series and model predictions on the grid ``x``.

    ``predicted`` holds one value per point of the full grid ``x``, while
    ``rolling`` may be missing grid points (rows removed by ``dropna``) or
    contain NaN (empty windows). Scoring only the grid points with a valid
    rolling mean keeps both inputs the same length and NaN-free; passing them
    through unaligned would make ``r2_score`` raise. When nothing is missing
    the result equals ``r2_score(rolling, predicted)``.
    """
    observed = rolling.dropna()
    aligned = pd.Series(predicted, index=x).loc[observed.index]
    return r2_score(observed, aligned)


# Agreement between the rolling means (test / train) and the model predictions.
r2_no2_test, r2_nox_test = (
    _r2_on_valid_rows(mean_test.no2, no2),
    _r2_on_valid_rows(mean_test.nox, nox),
)
r2_no2, r2_nox = (
    _r2_on_valid_rows(mean.no2, no2),
    _r2_on_valid_rows(mean.nox, nox),
)

In [27]:
# Comma-separated Bokeh tool names (not referenced by the figures below in this notebook view).
TOOLS = (
    "hover,crosshair,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,"
    "undo,redo,reset,tap,save,box_select,poly_select,lasso_select,"
    "examine,help"
)
# Render Bokeh output inline in the notebook.
output_notebook()

In [28]:
# Draw the NO2 and NOx seasonality panels side by side. Each panel shows the
# test and train rolling means (with their R2 against the model), the model
# mean, the filtered raw training points, a red band for rolling mean +/- std
# and a blue band for model mean +/- std.
# Also register an HTML output file, so show() below writes it as well.
output_file(filename="custom_filename.html", title="Static HTML file")

# ---- NO2 panel ----
p = figure(x_range=(0, 53), y_range=(5, 18))
p.title.text = r" $$NO_2$$ seasonality"
p.line(mean_test.index, mean_test.no2, line_width=3, color="orange",
       legend_label="test roll mean R2 " + str(round(r2_no2_test, 2)))
p.line(mean.index, mean.no2, line_width=3, color="green",
       legend_label="Train roll mean R2 " + str(round(r2_no2, 2)))
p.line(linmod_np, pm25_model_np, line_width=3, color="red",
       legend_label="Gaussian model")
p.scatter(train.corrected_week, y=train.no2, color="blue", marker="dot",
          size=20, alpha=0.4, legend_label="raw points")
band = Band(base="index", lower="lower1_std_pm", upper="upper1_std_pm",
            source=mean1, fill_color="red", line_color="black", fill_alpha=0.2)
band1 = Band(base="corrected_week", lower="lower", upper="upper",
             source=data1, fill_alpha=0.5, fill_color="blue", line_color="black")
p.add_layout(band)
p.add_layout(band1)
p.yaxis.axis_label = r'$$\frac{\mu g}{m^3} $$'

# ---- NOx panel ----
p1 = figure(x_range=(0, 53), y_range=(0, 40))
p1.title.text = r" $$NO_x$$ seasonality"
p1.line(mean.index, mean.nox, line_width=3, color="green",
        legend_label="train roll mean R2 " + str(round(r2_nox, 2)))
p1.line(mean_test.index, mean_test.nox, line_width=3, color="orange",
        legend_label="test roll mean R2 " + str(round(r2_nox_test, 2)))
p1.line(linmod_np, bc_model_np, line_width=3, color="red",
        legend_label="Gaussian model")
p1.scatter(train.corrected_week, y=train.nox, color="blue", marker="dot",
           size=20, alpha=0.8, legend_label="raw points")
band2 = Band(base="corrected_week", lower="lower_bc", upper="upper_bc",
             source=data1_bc, fill_color="blue", line_color="black", fill_alpha=0.5)
band3 = Band(base="index", lower="lower1_std_bc", upper="upper1_std_bc",
             source=mean1, fill_color="red", line_color="black", fill_alpha=0.2)
p1.add_layout(band3)
p1.add_layout(band2)
p1.yaxis.axis_label = r'$$ \frac{\mu g}{m^3} $$'

# ---- Styling shared by both panels ----
# Applied after the glyphs exist so that each figure already has its legend.
for _panel in (p, p1):
    _panel.xgrid.grid_line_color = None
    _panel.ygrid.grid_line_alpha = 0.5
    _panel.yaxis.axis_label_orientation = 0
    _panel.xaxis.axis_label = r'$$Week \ of \ the \ year$$'
    _panel.legend.title_text_font_size = "16px"
    _panel.xaxis.axis_label_text_font_size = "14px"
    _panel.yaxis.axis_label_text_font_size = "14px"
    _panel.xaxis.axis_label_text_font_style = "bold"
    _panel.yaxis.axis_label_text_font_style = "bold"

grid = gridplot([[p, p1]], width=500, height=500)

show(column(grid))