# Imports

### Make sure this notebook is in the "leam_us/active" directory.

In [1]:
import torch
from model.pytorch.dcrnn_supervisor import DCRNNSupervisor
from lib.utils import load_graph_data
import yaml
import numpy as np

### Model's weights path. Ensure all files I sent are in a path as listed below. There are other functions that have dependencies on this type of folder structure.  

In [2]:
model_folder_path = "seed1/itr11"

### Configuration pathway

In [3]:
config_path = "data/model/dcrnn_cov.yaml"

### Model tar file pathway

In [4]:
model_tar = torch.load(f"{model_folder_path}/model_epo101.tar")

### Loading the configuration file and the adjacency matrix

In [5]:
with open(config_path) as f:
        supervisor_config = yaml.safe_load(f)

        graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename')
        sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename)

### Necessary Arguments for creating a DCRNNSupervisor model

In [6]:
random_seed = 1
iteration = 11
max_iter = 12

### Model Creatiion

In [7]:
model = DCRNNSupervisor(random_seed, iteration, max_iter, adj_mx, **model_tar)

2025-02-07 18:28:45,497 - INFO - Log directory: data/model\dcrnn_DR_1_h_28_16_lr_0.0008_bs_128_0207182845_1_11/
2025-02-07 18:29:00,496 - INFO - Model created


### This is necessary for the load_model() function to know the model weight locations

In [8]:
model._epoch_num = 101

### Load the model

In [9]:
model.load_model()

2025-02-07 18:29:01,685 - INFO - Loaded model at 101


### There can be a little experimentation here of the 'dataset' argument

In [10]:
model.evaluate()

<generator object DataLoader.get_iterator.<locals>._wrapper at 0x000001AE70139AC0>
Iteration: 0
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 1
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 2
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 3
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 4
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 5
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 6
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 7
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 8
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 9
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 10
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 11
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 12
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 13
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 14
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 15
(128, 28, 58, 10)
(128, 28, 24)
(128, 24)
Iteration: 16
(128, 28, 58, 10)

(8227.506,
 53073.246,
 {'prediction': array([[[5.4823364e+01, 2.6583040e+01, 2.2328602e+01, ...,
           6.1966486e+00, 4.6376762e+00, 3.5907903e+00],
          [4.7840855e+01, 2.3697824e+01, 1.9783297e+01, ...,
           5.5793242e+00, 4.1731849e+00, 3.2527604e+00],
          [4.1564312e+01, 2.0729504e+01, 1.7351725e+01, ...,
           5.1131215e+00, 3.8518243e+00, 3.0763073e+00],
          ...,
          [1.1341774e+04, 4.6476455e+03, 3.2132300e+03, ...,
           2.4439966e+02, 4.6684369e+02, 1.7910605e+02],
          [1.1376192e+04, 4.6645449e+03, 3.2261362e+03, ...,
           2.4478119e+02, 4.6762067e+02, 1.7905856e+02],
          [1.1435600e+04, 4.6859688e+03, 3.2357930e+03, ...,
           2.4566048e+02, 4.6847903e+02, 1.7984157e+02]],
  
         [[1.6278187e+03, 5.5069238e+02, 4.2179947e+02, ...,
           3.6870502e+01, 2.2732197e+01, 1.2844892e+01],
          [1.6678132e+03, 5.6344183e+02, 4.3047198e+02, ...,
           3.7706295e+01, 2.2595886e+01, 1.2368067e+01],


In [11]:
inputer = 256

In [12]:
x = np.random.randn(inputer, 28, 58, 10)
y = np.random.randn(inputer, 28, 24)
x0 = np.random.randn(inputer, 24)
x, y, x0 = model._prepare_data(x, y, x0)

In [13]:
with torch.no_grad():
    result = model.dcrnn_model(x, y, x0, 0, True, model.z_mean_all, model.z_var_temp_all)[0].cpu()

result

tensor([[[ 5.4928,  4.8512,  4.6644,  ...,  3.1788,  3.1788,  2.5787],
         [ 0.8080,  0.5475,  0.5180,  ...,  0.3936,  0.2512,  0.2314],
         [ 0.7778,  0.4773,  0.4322,  ...,  0.3701,  0.2523,  0.3000],
         ...,
         [ 7.6129,  6.8809,  6.6911,  ...,  4.7545,  5.2214,  4.3421],
         [ 1.1565,  0.9009,  0.8988,  ...,  0.6683,  0.6104,  0.5220],
         [ 1.8824,  1.4922,  1.3958,  ...,  0.9083,  0.7647,  0.6270]],

        [[12.7179, 11.7588, 11.4001,  ...,  7.7819,  7.5892,  5.7611],
         [ 1.8296,  1.2124,  1.1228,  ...,  0.7693,  0.3165,  0.2801],
         [ 2.0524,  1.4918,  1.3115,  ...,  0.8339,  0.4482,  0.4670],
         ...,
         [ 9.7846,  8.9621,  8.6387,  ...,  5.8891,  6.5497,  5.2739],
         [ 3.0380,  2.4118,  2.2666,  ...,  1.3164,  1.0174,  0.8601],
         [ 2.0563,  1.4891,  1.3094,  ...,  0.7331,  0.3768,  0.4618]],

        [[13.2435, 12.3419, 11.9894,  ...,  8.1667,  7.9192,  5.7653],
         [ 2.5993,  1.7489,  1.6055,  ...,  1

In [14]:
z_var_all = 0.1 + 0.9 * torch.sigmoid(model.z_var_temp_all)
zs = model.dcrnn_model.sample_z(model.z_mean_all, z_var_all, inputer)
outputs_hidden = model.dcrnn_model.dcrnn_to_hidden(x)
# xz = torch.cat([h_outputs, zs], dim=-1)
output = model.dcrnn_model.decoder(x0, outputs_hidden, zs)
output

tensor([[[ 5.6938,  5.0375,  4.8464,  ...,  3.2998,  3.3240,  2.7109],
         [ 0.7466,  0.4870,  0.4623,  ...,  0.3540,  0.2142,  0.2286],
         [ 1.4672,  1.1062,  1.0242,  ...,  0.7036,  0.5476,  0.4714],
         ...,
         [ 6.7464,  6.0449,  5.9145,  ...,  4.3286,  4.8096,  4.0535],
         [ 0.9556,  0.7154,  0.7085,  ...,  0.5491,  0.4899,  0.4118],
         [ 1.5160,  1.1437,  1.0614,  ...,  0.6953,  0.5923,  0.5277]],

        [[12.7341, 11.8003, 11.4416,  ...,  7.7701,  7.6330,  5.7578],
         [ 1.6875,  1.0713,  0.9920,  ...,  0.6796,  0.2449,  0.2719],
         [ 2.4945,  1.9114,  1.6899,  ...,  0.9763,  0.5667,  0.5562],
         ...,
         [ 9.6717,  8.8435,  8.5238,  ...,  5.8222,  6.4910,  5.2422],
         [ 1.8774,  1.3694,  1.2432,  ...,  0.7071,  0.4249,  0.4519],
         [ 2.0040,  1.4389,  1.2535,  ...,  0.7056,  0.3671,  0.4684]],

        [[12.9775, 12.0838, 11.7292,  ...,  7.9726,  7.8438,  5.7815],
         [ 2.4178,  1.5635,  1.4260,  ...,  0