In [1]:

# Imports and device setup for the PM2.5 / BC seasonality comparison.
# NOTE(review): az, curve_fit, factor_cmap, factor_mark, Spectral, Slope, Div,
# curdoc, export_png, output_file, save, ColumnDataSource and plt are not used
# in the cells visible here — confirm they are unused before pruning.
import arviz as az
from scipy.optimize import curve_fit
from bokeh.plotting import figure, show,output_file, save
from bokeh.transform import factor_cmap, factor_mark
from bokeh.palettes import Spectral
from bokeh.models import Slope, Div
from bokeh.io import curdoc,output_notebook,export_png
from bokeh.layouts import column,gridplot
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
# Band draws the shaded intervals in the plotting cell.
from bokeh.models import Band, ColumnDataSource

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Project helpers; presumably the source of roll_week and gaussian_plot used below.
# NOTE(review): this star import runs after `device` is defined, so any `device`
# (or other same-named object) exported by Plotting would silently shadow it —
# prefer explicit `from Plotting import roll_week, gaussian_plot`.
from Plotting import *

In [2]:
# Read the train and test spreadsheets (in that order); filtering happens later.
train, test = (
    pd.read_excel(path)
    for path in (r"../data/chem_train_pm.xlsx", r"../data/chem_test_pm.xlsx")
)

In [3]:
# Load the fitted PM2.5 seasonality model onto the device chosen in the setup cell
# (CUDA when available, otherwise CPU) instead of unconditionally requiring CUDA.
# NOTE(review): weights_only=False unpickles arbitrary objects — only load trusted files.
gpr = torch.load("../models/pm25_seasonality", weights_only=False, map_location=device)

In [4]:
# Load the fitted BC seasonality model onto the device chosen in the setup cell
# (CUDA when available, otherwise CPU) instead of unconditionally requiring CUDA.
# NOTE(review): weights_only=False unpickles arbitrary objects — only load trusted files.
gpr_bc = torch.load("../models/seasonality_BC", weights_only=False, map_location=device)

In [5]:
def _within_week_bounds(frame):
    """Return the rows of `frame` whose corrected_week lies strictly between 0 and 53."""
    week = frame.corrected_week
    return frame.loc[(week > 0) & (week < 53), :]


# Keep unfiltered copies for the rolling-mean step; the working frames are restricted.
# NOTE(review): re-running this cell copies the already-filtered frames, so it is
# not idempotent — re-run the loading cell first.
train_predict, test_predict = train.copy(), test.copy()
train, test = _within_week_bounds(train), _within_week_bounds(test)

In [6]:
# Rolling weekly statistics for each (frame, particle column) pair, computed on the
# unfiltered copies; the positional 1000, 2 arguments are forwarded to roll_week as-is.
(
    (meanbc_train, data_bc_train),
    (meanpm_train, data_pm_train),
    (meanbc_test, data_bc_test),
    (meanpm_test, data_pm_test),
) = [
    roll_week(frame, 1000, 2, particle=[column_name])
    for frame, column_name in (
        (train_predict, "BC_Gaussion"),
        (train_predict, "pm25_Gaussion"),
        (test_predict, "BC_Gaussion"),
        (test_predict, "pm25_Gaussion"),
    )
]

In [7]:
# Evaluate both fitted GP models on a dense grid of 1000 points over weeks 0..52.5.
# A fresh grid array is built for each call, as before.
_WEEK_GRID_KW = dict(start=0, stop=52.5, num=1000)
plot_pm, model_pm = gaussian_plot(np.linspace(**_WEEK_GRID_KW), gpr)
plot_bc, model_bc = gaussian_plot(np.linspace(**_WEEK_GRID_KW), gpr_bc)

In [8]:
# R² of each GP model curve against the rolling-mean curve of each split.
# NOTE(review): assumes the rolling-mean and model curves are sampled on the same
# grid/length — confirm against roll_week / gaussian_plot.
r2_bc = r2_score(data_bc_train.y, model_bc.y)
r2_pm = r2_score(data_pm_train.y, model_pm.y)
r2_bc_test = r2_score(data_bc_test.y, model_bc.y)
r2_pm_test = r2_score(data_pm_test.y, model_pm.y)

In [9]:
output_notebook()
# NOTE(review): TOOLS is never passed to figure(); add tools=TOOLS to enable it
# (the "examine" tool requires a recent Bokeh release).
TOOLS="hover,crosshair,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,undo,redo,reset,tap,save,box_select,poly_select,lasso_select,examine,help"


def _seasonality_panel(title, y_max, roll_train, r2_train, roll_test, r2_test,
                       model_curve, raw_weeks, raw_values,
                       spread_source, model_band_source):
    """Build one seasonality panel: rolling means, GP curve, raw points and two bands.

    Parameters
    ----------
    title : str
        Panel title.
    y_max : float
        Upper bound of the fixed y range (the lower bound is 0; x spans weeks 0..53).
    roll_train, roll_test
        Rolling-mean curves (objects with ``index`` and ``y``) for train and test.
    r2_train, r2_test : float
        R² values shown, rounded to 2 decimals, in the legend labels.
    model_curve
        GP model curve (object with ``index`` and ``y``).
    raw_weeks, raw_values
        x / y of the raw observations drawn as translucent dots.
    spread_source, model_band_source
        Band sources providing ``index``, ``lower`` and ``upper`` columns: the
        rolling spread (red) and the model interval (blue).

    Returns
    -------
    The configured Bokeh figure.
    """
    fig = figure(x_range=(0, 53), y_range=(0, y_max))
    fig.title.text = title
    fig.xgrid.grid_line_color = None
    fig.ygrid.grid_line_alpha = 0.5
    fig.line(roll_train.index, roll_train.y, line_width=3, color="green",
             legend_label="Train roll mean R2 " + str(round(r2_train, 2)))
    fig.line(roll_test.index, roll_test.y, line_width=3, color="orange",
             legend_label="Test roll mean R2 " + str(round(r2_test, 2)))
    fig.line(model_curve.index, model_curve.y, line_width=3, color="red",
             legend_label="Gaussian model")
    fig.scatter(raw_weeks, y=raw_values, color="blue", marker="dot", size=20,
                alpha=0.4, legend_label="raw points")
    fig.add_layout(Band(base="index", lower="lower", upper="upper", source=spread_source,
                        fill_color="red", line_color="black", fill_alpha=0.2))
    fig.add_layout(Band(base="index", lower="lower", upper="upper", source=model_band_source,
                        fill_alpha=0.5, fill_color="blue", line_color="black"))
    # Horizontal y-axis label so the unit fraction reads naturally.
    fig.yaxis.axis_label_orientation = 0
    fig.xaxis.axis_label = r'$$Week \ of \ the \ year$$'
    fig.yaxis.axis_label = r'$$\frac{\mu g}{m^3} $$'
    return fig


p = _seasonality_panel(
    r"$$PM_{2.5}$$ seasonality", 60,
    data_pm_train, r2_pm, data_pm_test, r2_pm_test, model_pm,
    train.corrected_week, train.pm25_Gaussion,
    spread_source=meanpm_train, model_band_source=plot_pm,
)
# The red spread band uses the *train* rolling statistics for BC as well, matching
# the PM2.5 panel and the train raw points plotted on both panels.
p1 = _seasonality_panel(
    "BC seasonality", 2,
    data_bc_train, r2_bc, data_bc_test, r2_bc_test, model_bc,
    train.corrected_week, train.BC_Gaussion,
    spread_source=meanbc_train, model_band_source=plot_bc,
)
grid = gridplot([[p, p1]], width=500, height=500)

show(column(grid));