In [92]:
import sys
sys.path.append('../')

import glob

from typing import Tuple, Union, List
import pandas as pd
import numpy as np
import networkx as nx
from tqdm.notebook import tqdm as tqdm

import torch
import torch.nn as nn

import torch_geometric as pyg
from torch_geometric.utils.convert import from_networkx

from src.utils import *
from dataset import *
from src.train import train, test
from src.dataloaders import make_data_loaders_from_dataset
from src.model import KnnEstimator

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('darkgrid')

from IPython.display import clear_output

%matplotlib inline
%load_ext autoreload
%autoreload 2

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [93]:
G = nx.readwrite.read_gpickle('data/network.gpickle')

In [94]:
# Pytorch geometric Data object. For now used only for storing node embedding. 
# Supposed to be used in the future for obtaining node embeddings.
pyg_graph = from_networkx(G)
pyg_graph.to(device)

Data(edge_index=[2, 990], id=[374], lat=[374], lon=[374], dist=[990], num_nodes=374)

$$T(a) = \sum_{b \neq a,\space b\subset Train } T(b)w(a, b), \textrm{where summation is calculated for the {\bf k} nearest neighbors.}$$

$$ w(a,b)= \frac{u(a,b)}{\sum_{b \neq a} u(a, b)};$$

$u(a, b) = exp(-\lambda_1 d(a, b));$

So for this model $\lambda_1$ is optimized.

In [144]:
def weight_fn(dists, lamb):
    return torch.exp(-lamb * dists)


class Estimator(KnnEstimator):
    def __init__(self, pyg_graph: pyg.data.Data, obs_nodes, obs_targets) -> None:
        super().__init__(pyg_graph, obs_nodes, obs_targets)

        # self.k = torch.tensor([1.0]).to(device)
        self.lambda_1 = nn.Parameter(torch.rand(1))
        self.lambda_2 = nn.Parameter(torch.rand(1))

    def forward(self, X):
        # getting nearest observed nodes
        X_indices = torch.as_tensor(self.node_to_idx(X))
        dists, indices = self.get_kneighbors(X_indices)
        
        dists = dists.to(device)
        indices = indices.to(device)


        dist_weights = weight_fn(dists, self.lambda_1)

        # sum normalizization
        dist_weights = nn.functional.normalize(dist_weights, p=1)

        att_weights = dist_weights
        targets = self.obs_targets[indices].to(device)

        # interpolation 
        result = torch.sum(att_weights.mul(targets), dim=-1)

        return result

In [145]:
results = dict()

In [153]:
loss_fn = nn.HuberLoss(delta=20).to(device)
model = None

for path in tqdm(glob.glob('datasets/*')[:20]):
    day = path.split('/')[1].split('.')[0]

    try:
        ds = torch.load(path)
    except:
        continue

    train_loader, val_loader, test_loader = make_data_loaders_from_dataset(ds, batch_size=16)
    train_batches = [batch for batch in train_loader]
    train_nodes = [n for batch in train_batches for n in batch[0]]
    train_targets = torch.cat([batch[1] for batch in train_batches])


    if model is None:
        model = Estimator(pyg_graph, train_nodes, train_targets).to(device)
    else:
        model.set_observations(train_nodes, train_targets)
        
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20, gamma=0.9)

    best_model = train(model, train_loader, val_loader, loss_fn, optimizer, device, num_epochs=10, plotting=False)
    test_loss, test_score = test(best_model, test_loader, loss_fn, device)

    # for name, param in best_model.named_parameters():
        # print(name, param)

    results[f'{day}'] = test_score 
    print(f'{day}, Test loss: {test_loss:.4f}, test score: {test_score:.4f}')

  0%|          | 0/20 [00:00<?, ?it/s]

2021-01-15, Test loss: 8825.0638, test score: -2.1456
2021-01-08, Test loss: 8637.1915, test score: -1.8071
2021-01-17, Test loss: 4927.3298, test score: -1.4336
2020-12-28, Test loss: 9726.3830, test score: -1.7430
2020-12-30, Test loss: 9664.2553, test score: -2.0520
2021-01-23, Test loss: 5524.1809, test score: -1.6672
2020-12-27, Test loss: 3448.2553, test score: -2.0251
2021-01-05, Test loss: 10563.8298, test score: -2.6386
2021-01-21, Test loss: 10371.4894, test score: -2.1178
2021-01-22, Test loss: 9481.7021, test score: -2.3296
2021-01-18, Test loss: 7512.7660, test score: -2.1822
2020-12-31, Test loss: 6922.1277, test score: -1.6886
2021-01-06, Test loss: 8902.5532, test score: -2.3122
2021-01-11, Test loss: 8913.6170, test score: -1.5141
2021-01-01, Test loss: 2700.8617, test score: -1.3528
2020-12-26, Test loss: 7697.7553, test score: -0.1433
2021-01-20, Test loss: 8659.5745, test score: -2.1994
2021-01-04, Test loss: 10564.5213, test score: -1.7385
2020-12-29, Test loss: 83

In [154]:
import json

with open('results/baseline.json', mode='w') as f:
    json.dump(results, f, indent=4)