In [None]:
# Colab helpers: Drive (to persist results) and runtime control (to release the VM at the end).
from google.colab import drive
from google.colab import runtime
import pickle
import matplotlib.pyplot as plt
import numpy as np
# Mount Google Drive; later cells copy artefacts into /content/drive/MyDrive.
# NOTE(review): numpy is imported here, before the `pip install` cell. If those installs
# upgrade numpy, this kernel keeps using the old version until the runtime is restarted.
drive.mount('/content/drive')

In [None]:
# Font sizes shared by the figures in this notebook.
# NOTE(review): SIZE_TITLE is not referenced by any visible cell.
SIZE_TITLE = SIZE_LABELS = 24
SIZE_TICKS = SIZE_LEGEND = 18

In [None]:
!pip install gdown
!pip install -U pysal
!pip install geopandas
!pip install torch_geometric pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cpu.html
!pip install torch-geometric-temporal
!pip install --upgrade --force-reinstall git+https://github.com/FelipeSchreiber/COE770_Machine_Learning_on_Graphs.git

In [None]:
# NOTE(review): star import. Later cells use CovidBenchmark, CovidDatasetLoader, get_model,
# device, tqdm and product without defining them, presumably all provided by this module.
# Confirm against Tests/benchmark.py and import those names explicitly.
from Tests.benchmark import *
# Hyperparameter grid.
# gammas: 3 log-spaced values from 10**1 to 10**6, truncated to int -> [10, 3162, 1000000].
# filters: candidate numbers of filters (passed as num_filters to get_model).
gammas=np.logspace(1,6,num=3).astype(int)
filters = [2,4,8]

In [None]:
# Run the COVID benchmark over the (filter_size, gamma) grid defined above.
# The returned `stats` is later read via the keys 'filter_size', 'gamma' and 'MSE'.
# NOTE(review): what run_test trains/evaluates per grid point is defined in
# CovidBenchmark.run_test, which is not visible here.
benchmark = CovidBenchmark()
stats = benchmark.run_test(
    lags=4,
    filter_sizes=filters,
    train_model=True,
    gammas=gammas,
    num_epochs=100,
    warm_start=False,
)

In [None]:
# Display the raw benchmark results returned by run_test.
stats

In [None]:
with open('covid_test_0.pickle', 'wb') as handle:
    pickle.dump(stats, handle, protocol=pickle.HIGHEST_PROTOCOL)
!cp "/content/covid_test_0.pickle" "/content/drive/MyDrive"

In [None]:
# Reload the persisted benchmark results from Drive, so the plotting cells below can run
# without re-running the benchmark.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
with open("/content/drive/MyDrive/covid_test_0.pickle", 'rb') as handle:
    test_0 = pickle.load(handle)
test_0.keys()

In [None]:
def scatter_(dict_, x_name, y_name, z_name):
    """Scatter dict_[x_name] against log10(dict_[y_name]) and annotate each point.

    Each point is labelled with the matching value of dict_[z_name], formatted to two
    decimals.

    Args:
        dict_: mapping from column name to a sequence of numbers; the three columns used
            are presumably of equal length (zip stops at the shortest).
        x_name: key plotted on the x axis.
        y_name: key whose base-10 logarithm is plotted on the y axis (values should be
            positive, otherwise np.log10 yields -inf/nan).
        z_name: key whose values are written next to the points.

    Returns:
        The matplotlib Axes holding the plot.
    """
    fig, ax = plt.subplots()
    x = dict_[x_name]
    y = np.log10(dict_[y_name])
    ax.scatter(x, y)
    # zip instead of positional x[i]/y[i] indexing: works for any sequence type.
    for x_i, y_i, label in zip(x, y, dict_[z_name]):
        ax.annotate("{:.2f}".format(label), (x_i, y_i))
    ax.set_xlabel(x_name)
    # The y axis shows log10 of the values, so the label must say so.
    ax.set_ylabel(f"log10({y_name})")
    return ax
scatter_(test_0,'filter_size','gamma', 'MSE');

In [None]:
# MSE of every (filter_size, gamma) configuration, encoded as colour.
fig, ax = plt.subplots(figsize=(5, 4))
x, y, z = test_0['filter_size'], np.log10(test_0['gamma']), test_0["MSE"]
xlabel = "filter_size"
# The y data is log10(gamma), so label it accordingly.
ylabel = 'log10(gamma)'
# Five evenly spaced colour-bar ticks spanning the observed MSE range.
ticks = np.linspace(np.min(z), np.max(z), 5, endpoint=True)
C = ax.scatter(x=x, y=y, c=z, cmap="coolwarm")
# The label is set once via set_label; passing label= to colorbar as well was redundant.
cb = fig.colorbar(C, ax=ax, fraction=0.02, pad=0.1, ticks=ticks)
cb.set_label(label='MSE', size=SIZE_LEGEND)
cb.ax.tick_params(labelsize=SIZE_TICKS)
# Target the data axes explicitly rather than whatever pyplot considers "current".
ax.set_xlabel(xlabel, fontsize=SIZE_LABELS)
ax.set_ylabel(ylabel, fontsize=SIZE_LABELS)
ax.tick_params(axis='both', labelsize=SIZE_TICKS)
plt.show()

In [None]:
loader = CovidDatasetLoader(method="other")
dataset = loader.get_dataset(lags=4)
# train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)
i = 0
fig = plt.figure(1, figsize=(8, 14), frameon=False, dpi=100)
for filter_size, gamma in tqdm(product(filters,gammas)):
  model = get_model(False,num_features=35,num_filters=filter_size,gamma=gamma)
  model.to(device)
  model.eval()
  cost = 0
  preds = []
  y = []
  for time, snapshot in enumerate(dataset):
      snapshot.to(device)
      y_hat,_ = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
      preds.append(y_hat.sum().cpu().detach().numpy())
      y.append(snapshot.y.sum().cpu().detach().numpy())
      del snapshot
  if i ==0:
    plt.plot(y,label=f"Y")
  plt.plot(preds,label=f"ADCRNN_{int(filter_size)}_{np.log10(gamma):.2f}")
  i+=1
plt.legend()
plt.ylabel("Total de casos agregados")
plt.xticks(rotation=45)
plt.savefig("agg.jpg")
plt.show()
!cp "/content/agg.jpg" "/content/drive/MyDrive"

In [None]:
# Release (unassign) the Colab runtime when the run is finished.
# NOTE(review): presumably the VM and its local files under /content are discarded, so
# anything not already copied to Drive is lost after this call.
runtime.unassign()