## Test and explanation of the data reader for the meteorological dataset covering 3 different locations (APEX, Paranal and La Silla)

In [1]:
# packages 
import torch
from pymongo import MongoClient
from AGG.extended_typing import ContinuousTimeGraphSample
import yaml

First, we connect to the MongoDB database.

In [2]:
# Connect to the MongoDB server on localhost and select the database;
# the main collection carries the same name as the database itself.
db_name = "Meteo_test"
client = MongoClient("mongodb://localhost:27017/")
db = client.get_database(db_name)
main_collection = db.get_collection(db_name)

Now let's look at a single sample from our main collection, which holds all the samples from the three locations.

In [3]:
# No filter: return the first document the collection yields, to inspect the raw schema.
main_collection.find_one()

{'_id': ObjectId('65fb39460630f2ec7f2790cb'),
 'idx': 0,
 'time': '2023-01-01T00:00:58',
 'node_features': 11.03,
 'type_index': 0,
 'spatial_index': 0}

Let's create a sorted index list

In [4]:
from create_train_test_collection import get_sorted_idx_list

# Parse the experiment settings from YAML (safe_load: plain data only).
with open('config.yaml') as config_file:
    config = yaml.safe_load(config_file)

# Build the index list from the config; the ordering is defined by the helper
# (not visible in this notebook).
ids = get_sorted_idx_list(config)

for this example we will use the first 100 samples

In [5]:
# Restrict the demonstration to the leading entries of the index list.
N_EXAMPLE_IDS = 100
ids = ids[:N_EXAMPLE_IDS]
ids

[257895,
 257896,
 257897,
 257898,
 257899,
 0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 257900,
 257901,
 257902,
 257903,
 257904,
 218880,
 218881,
 218882,
 218883,
 218884,
 218885,
 218886,
 218887,
 218888,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 257905,
 257906,
 257907,
 257908,
 257909]

These will be the parameters I will use:
* remove: 0.3
* stride: 5
* context_len: 10

In [6]:
from create_train_test_collection import get_graph_sample_idx

# Build the graph-sample index from the selected ids; context length, stride
# and selection size are all read from the loaded config.
samples: dict = get_graph_sample_idx(
    context_len=config['context_len'],
    stride=config['stride'],
    idx_list=ids,
    selection_size=config['remove'],
)

Total targets:  30
Total train nodes:  70
Creating indexes of nodes...


100%|██████████| 14/14 [00:00<00:00, 57625.37it/s]


The resulting `samples` dictionary maps each graph-sample id (the key) to a list of two elements (the value): the first element is a list containing the ids of the documents in the MongoDB main collection that act as nodes, and the second element is also an id in the MongoDB collection, namely that of the target of that graph sample.

In [7]:
# Display the mapping: graph-sample id -> [list of node document ids, target document id].
samples

{0: [[257895, 257896, 257898, 257899, 0, 2, 3, 5, 6, 8], 1],
 1: [[257895, 257896, 257898, 257899, 0, 2, 3, 5, 6, 8], 4],
 2: [[257895, 257896, 257898, 257899, 0, 2, 3, 5, 6, 8], 7],
 3: [[257895, 257896, 257898, 257899, 0, 2, 3, 5, 6, 8], 257897],
 4: [[2, 3, 5, 6, 8, 9, 10, 11, 12, 14], 4],
 5: [[2, 3, 5, 6, 8, 9, 10, 11, 12, 14], 7],
 6: [[2, 3, 5, 6, 8, 9, 10, 11, 12, 14], 13],
 7: [[9, 10, 11, 12, 14, 17, 18, 20, 21, 22], 13],
 8: [[9, 10, 11, 12, 14, 17, 18, 20, 21, 22], 15],
 9: [[9, 10, 11, 12, 14, 17, 18, 20, 21, 22], 16],
 10: [[9, 10, 11, 12, 14, 17, 18, 20, 21, 22], 19],
 11: [[17, 18, 20, 21, 22, 23, 24, 25, 26, 27], 19],
 12: [[23, 24, 25, 26, 27, 28, 29, 30, 31, 33], 32],
 13: [[28, 29, 30, 31, 33, 34, 35, 37, 257900, 257902], 32],
 14: [[28, 29, 30, 31, 33, 34, 35, 37, 257900, 257902], 36],
 15: [[28, 29, 30, 31, 33, 34, 35, 37, 257900, 257902], 257901],
 16: [[34, 35, 37, 257900, 257902, 218880, 218881, 218882, 218884, 218885],
  218883],
 17: [[34, 35, 37, 257900, 257

The targets are:

In [18]:
# Each value is a [node ids, target id] pair; keep only the target id, in dict order.
[target_id for _node_ids, target_id in samples.values()]

[1,
 4,
 7,
 257897,
 4,
 7,
 13,
 13,
 15,
 16,
 19,
 19,
 32,
 32,
 36,
 257901,
 218883,
 36,
 257901,
 257903,
 257904,
 218883,
 218887,
 218888,
 39,
 218887,
 218888,
 39,
 43,
 50,
 50,
 60,
 63,
 66,
 63,
 66,
 70,
 71,
 72,
 73,
 74,
 257906]

Then we need to create a train-test split; for that, we create one collection of graph samples for train and one for test.

In [8]:
from create_train_test_collection import create_train_test_db

# Load the int_name_normal_coef mapping from YAML
# (presumably per-variable normalisation coefficients; confirm against the helper).
with open('utils/yaml/int_name_normal_coef.yaml') as coef_file:
    int_name_normal_coef = yaml.safe_load(coef_file)

# Build and insert the train/test graph samples; the helper returns two values,
# unpacked here as the train and test lengths.
len_train, len_test = create_train_test_db(
    samples=samples,
    config=config,
    int_name_normal_coef=int_name_normal_coef,
    test_size=0.3,
)

30 samples expeted to be inserted in train collection, and 12 in test collection.
Starting to build graph samples...(train)


100%|██████████| 30/30 [00:06<00:00,  4.76it/s]


Inserting samples in train collection...
30  Samples inserted in collection "train" of database "Meteo_test"
Starting to build graph samples...(test)


100%|██████████| 12/12 [00:02<00:00,  4.98it/s]


Inserting samples in test collection...
12  Samples inserted in collection "train" of database "Meteo_test"


Let's see what a sample in the train collection looks like.

In [9]:
# Fetch one stored graph sample from the 'train' collection (no filter) and show it.
train_collection = db.get_collection('train')
sample = train_collection.find_one()
sample

{'_id': ObjectId('66173ab48df25b180e166778'),
 'time': [-0.00022376543209876544,
  -0.00022376543209876544,
  -0.00022376543209876544,
  -0.00022376543209876544,
  -0.00022376543209876544,
  -0.00022376543209876544,
  0.0,
  0.0,
  0.0,
  0.0],
 'node_features': [-0.61361513304406,
  0.12777952322139885,
  -0.5516757888136073,
  -0.6367982181805324,
  -0.6615026139340858,
  -1.028877768986193,
  0.6086707889238913,
  0.37964792241395373,
  -0.27859289479827803,
  -0.41379477941577364],
 'type_index': [0, 2, 3, 5, 6, 8, 8, 14, 26, 34],
 'spatial_index': [0, 0, 0, 0, 0, 0, 2, 2, 2, 2],
 'key_padding_mask': [False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False],
 'kaboom': 'kaboom',
 'target': {'time': [-0.00022376543209876544],
  'features': [-0.5368557226697135],
  'type_index': [4],
  'spatial_index': [0],
  'dummy': None},
 'id': 0}

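The `id` of the last document in the train collection can be checked against the output of the function create_train_test_db(): ids start at 0, so it should equal the returned train length minus one.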

In [10]:
# Sorting on `_id` in descending order makes find_one return the document
# with the largest `_id` in the train collection.
last_document = train_collection.find_one(sort=[('_id', -1)])
last_document

{'_id': ObjectId('66173ab48df25b180e166795'),
 'time': [7.71604938271605e-06,
  7.71604938271605e-06,
  7.71604938271605e-06,
  7.71604938271605e-06,
  7.71604938271605e-06,
  7.71604938271605e-06,
  7.71604938271605e-06,
  7.71604938271605e-06,
  0.0,
  0.0],
 'node_features': [-0.19145693532453092,
  1.2882563653766084,
  1.2486282703581366,
  1.224780851849927,
  -0.09010486727815423,
  -0.22469878366739834,
  -0.19967327609728663,
  -0.8137351769795824,
  0.6186699380881273,
  0.12479250750206229],
 'type_index': [23, 24, 26, 27, 29, 30, 31, 37, 8, 0],
 'spatial_index': [0, 0, 0, 0, 0, 0, 0, 0, 2, 2],
 'key_padding_mask': [False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False,
  False],
 'kaboom': 'kaboom',
 'target': {'time': [7.71604938271605e-06],
  'features': [-0.8592934922919462],
  'type_index': [36],
  'spatial_index': [0],
  'dummy': None},
 'id': 29}

Then we can create the attention_mask for the sample in order to give it as input to the class ContinuousTimeGraphSample

In [11]:
# Only build the mask when the stored sample does not already carry a non-empty one.
has_mask = "attention_mask" in sample and len(sample["attention_mask"]) != 0
if not has_mask:
    # NOTE: this mutates `sample` in place (time becomes a float tensor).
    timestamps = torch.tensor(sample["time"], dtype=torch.float)
    sample["time"] = timestamps
    time_column = timestamps.unsqueeze(-1)
    # Entry [i, j] is True when time[j] is strictly earlier than time[i].
    sample["attention_mask"] = time_column.T < time_column
ContinuousTimeGraphSample(**sample)

ContinuousTimeGraphSample(node_features=tensor([-0.6136,  0.1278, -0.5517, -0.6368, -0.6615, -1.0289,  0.6087,  0.3796,
        -0.2786, -0.4138]), key_padding_mask=tensor([False, False, False, False, False, False, False, False, False, False]), edge_index=None, time=tensor([-0.0002, -0.0002, -0.0002, -0.0002, -0.0002, -0.0002,  0.0000,  0.0000,
         0.0000,  0.0000]), attention_mask=tensor([[False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,