In [1]:
import sys
import os

# Set the main path in the root folder of the project.
sys.path.append(os.path.join('..'))

In [2]:
# Settings for autoreloading.
%load_ext autoreload
%autoreload 2

In [3]:
from src.utils.seed import set_random_seed

# Set the random seed for deterministic operations.
SEED = 42
set_random_seed(SEED)

In [4]:
import torch

# Set the device for training and querying the model.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'The selected device is: "{DEVICE}"')

The selected device is: "cuda"


# Building the event set

In [5]:
import os

BASE_DATA_DIR = os.path.join('..', 'data', 'metr-la')

In [6]:
import pickle
with open(os.path.join(BASE_DATA_DIR, 'processed', 'scaler.pkl'), 'rb') as f:
    scaler = pickle.load(f)

In [7]:
from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.data.data_extraction import get_adjacency_matrix

# Get the adjacency matrix
adj_matrix_structure = get_adjacency_matrix(
    os.path.join(BASE_DATA_DIR, 'adj_mx_metr_la.pkl'))

# Get the header of the adjacency matrix and the matrix itself.
header, _, adj_matrix = adj_matrix_structure

spatial_temporal_gnn = SpatialTemporalGNN(9, 1, 12, 12, adj_matrix, DEVICE, 64)

stgnn_checkpoints_path = os.path.join('..', 'models', 'checkpoints',
                                      'st_gnn_metr_la.pth')

stgnn_checkpoints = torch.load(stgnn_checkpoints_path)
spatial_temporal_gnn.load_state_dict(stgnn_checkpoints['model_state_dict'])

spatial_temporal_gnn.eval()

SpatialTemporalGNN(
  (encoder): Linear(in_features=9, out_features=64, bias=False)
  (s_gnns): ModuleList(
    (0-11): 12 x S_GNN(
      (latent_encoder): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=False)
        (1): Linear(in_features=64, out_features=32, bias=False)
      )
      (linear): Linear(in_features=64, out_features=64, bias=False)
    )
  )
  (hidden_s_gnns): ModuleList(
    (0-10): 11 x S_GNN(
      (latent_encoder): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=False)
        (1): Linear(in_features=64, out_features=32, bias=False)
      )
      (linear): Linear(in_features=64, out_features=64, bias=False)
    )
  )
  (grus): ModuleList(
    (0-11): 12 x GRU(
      (z_x_linear): Linear(in_features=64, out_features=64, bias=False)
      (z_h_linear): Linear(in_features=64, out_features=64, bias=False)
      (r_x_linear): Linear(in_features=64, out_features=64, bias=False)
      (r_h_linear): Linear(in_features=64, out_fe

In [8]:
import os
import numpy as np
from src.spatial_temporal_gnn.prediction import predict

x_train = np.load(os.path.join(BASE_DATA_DIR, 'processed', 'x_train.npy'))
y_train = predict(spatial_temporal_gnn, x_train, scaler, DEVICE)
x_val = np.load(os.path.join(BASE_DATA_DIR, 'processed', 'x_val.npy'))
y_val = predict(spatial_temporal_gnn, x_val, scaler, DEVICE)
x_test = np.load(os.path.join(BASE_DATA_DIR, 'processed', 'x_test.npy'))
y_test = predict(spatial_temporal_gnn, x_test, scaler, DEVICE)

In [9]:
from typing import List, Optional, Tuple

from src.data.data_analysis import days_encoder

def get_largest_event_set(data: np.ndarray
                          ) -> List[Tuple[str, Optional[int], Optional[int]]]:
    n_time_steps, n_nodes, _ = data.shape[-3:]

    # Get the largest event set related to the speed.
    speed_events = [
        (0, time_step, node) 
        for time_step in range(n_time_steps) 
        for node in range(n_nodes) 
        if data[..., time_step, node, 0] > 0]
    # Get the largest event set related to the time of day.
    time_of_day_events = [
        (1, time_step, None) for time_step in range(n_time_steps)]
    # Get the largest event set related to the day of week.
    day_of_week_events = [(2, None, None)]
    # Get the largest event set related to the kind of day.
    #kind_of_day_events = [(3, None, None)]

    return speed_events + time_of_day_events + day_of_week_events #+\
    #    kind_of_day_events

In [10]:
s = get_largest_event_set(x_test[11])

# Map the event set to the graph

In [11]:
from typing import List, Optional, Tuple, Union
import torch

def remove_features_by_events(
    data: Union[np.ndarray, torch.FloatTensor],
    events: List[Tuple[int, Optional[int], Optional[int]]]
    ) -> Union[np.ndarray, torch.FloatTensor]:
    if isinstance(data, torch.FloatTensor):
        filtered_data = data.clone()
    else:
        filtered_data = data.copy()
    n_time_steps, n_nodes, n_features = filtered_data.shape[-3:]

    speed_events = [tuple(event) for event in events if event[0] == 0]
    time_of_day_events = [tuple(event) for event in events if event[0] == 1]
    day_of_week_events = [tuple(event) for event in events if event[0] == 2]
    #kind_of_day_events = [event for event in events if event[0] == 3]

    # TODO: readd filtering day of week events.
    '''if n_features > 1 and not len(day_of_week_events):
        filtered_data[..., -7:] = 0'''

    for time_step in range(n_time_steps):
        for node in range(n_nodes):
            if (0, time_step, node) not in speed_events:
                filtered_data[..., time_step, node, 0] = 0

            '''if n_features > 1 and len(kind_of_day_events):
                if 1 in data[..., -7:-2]:
                    filtered_data[..., time_step, node, -7:-2] = 1
                else:
                    filtered_data[..., time_step, node, -2:] = 1''';
        # TODO: readd filtering day of week events.
        '''
        if n_features > 1 and (1, time_step, None) not in time_of_day_events:
            filtered_data[..., time_step, :, 1] = -1
        '''

    return filtered_data

In [12]:
def remove_single_event_from_data(
    data: Union[np.ndarray, torch.FloatTensor],
    event: Union[np.ndarray, torch.FloatTensor]
    ) -> Union[np.ndarray, torch.FloatTensor]:
    # n_time_steps, n_nodes, _ = data.shape[-3:]
    event_kind = event[0]
    time_step = int(event[1].item())
    node = int(event[2].item())
    if event_kind == 0:
        data[..., time_step, node, 0] = 0
    elif event_kind == 1:
        data[..., time_step, :, 1] = -1
    elif event_kind == 2:
        data[..., -7:] = 0
    #elif event[0] == 3:
    #    pass

    '''if event_kind in [0, 1]: 
        for time_step in range(n_time_steps):
            for node in range(n_nodes):
                if 1 in data[..., -7:-2]:
                    data[..., time_step, node, -7:-2] = 1
                else:
                    data[..., time_step, node, -2:] = 1'''

    return data

In [13]:
s = remove_features_by_events(x_test[0], [(1, 0, None)])

In [14]:
# (event_type, speed, time_of_day, ...day, speed_target, time_of_day_target, ...day_target)

# ->

# -> score

In [15]:
y_test[0].shape

(12, 207, 1)

In [44]:
import torch
from torch import nn
import numpy as np

def simulate_model(
    instance: torch.FloatTensor, events_scores: torch.FloatTensor) -> torch.FloatTensor:
    #eps = torch.rand_like(events_scores).float().to(instance.device)
    #eps_hat = (eps.log() - (1 - eps).log() + instance) / 1.0
    events_scores = events_scores.sigmoid()
    # TODO: Simulate all events, not just the speed events.
    relaxed_bernoulli = torch.distributions.RelaxedBernoulli(2.0, events_scores)
    #e = torch.sigmoid(events_scores)
    e = relaxed_bernoulli.sample()
    #e = torch.rand(1).float().to(instance.device)
    #e_hat = torch.sigmoid((torch.log(e) - torch.log(1 - e) + events_scores) / .05)
    result = e >= .5
    #print(result)
    #result = torch.bernoulli(events_scores)
    instance = result * instance
    return instance

In [45]:
import torch
from torch import nn

class Navigator(nn.Module):
    def __init__(self, device: str, hidden_features: int = 64) -> None:
        super().__init__()
        # Set the linear encoder.
        self.linear_encoder = nn.LazyLinear(hidden_features)
        # Set the linear decoder.
        self.linear_decoder = nn.Linear(hidden_features, 1)
        # Set the device that is used for training and querying the model.
        self.device = device
        self.to(device)
        
    def forward(self, x: torch.FloatTensor, y: torch.FloatTensor) -> torch.FloatTensor:
        # Concatenate the candidate event and the target events.
        #y = torch.flatten(y, start_dim=1)
        #print(x.shape, y.shape)
        input = torch.cat([x, y], dim=-1)
        # Encode the input.
        out = self.linear_encoder(input)
        out = torch.relu(out)
        # Decode the output to get the logits prediction.
        out = self.linear_decoder(out)
        #out = out.sigmoid()
        # Mask the output to only keep the non-zero speed events.
        #mask = x[..., 0:1] > 0
        #out = out * mask

        return out
        

In [46]:
model = Navigator(DEVICE)

In [47]:
from typing import Tuple
from torch.utils.data.dataloader import DataLoader, Dataset
import numpy as np

from src.spatial_temporal_gnn.prediction import predict


class EventsDataset(Dataset):
    def __init__(self, x: np.ndarray, y: np.ndarray) -> None:
        """Initialize the dataset

        Parameters
        ----------
        x : ndarray
            The input values of the dataset.
        y : ndarray
            The ground truth of the dataset's input data.
        """
        # Filter out the instances that do not have speed values at all.
        self.x = x[np.any(x[..., 0] != 0, axis=(1, 2))]
        self.y = y[np.any(x[..., 0] != 0, axis=(1, 2))]
        self.len = self.x.shape[0]

    def _mask_random_instance_features(self, instance: np.ndarray) -> np.ndarray:
        """Mask random features of a Spatial-Temporal Graph instance.

        Parameters
        ----------
        instance : ndarray
            The Spatial-Temporal Graph instance to mask random features
            from.

        Returns
        -------
        ndarray
            The Spatial-Temporal Graph instance with masked random
            features.
        """
        # Get the largest event set of the input instance.
        events = np.array(get_largest_event_set(instance))

        # TODO: consider all events, not solely speed events.
        events = events[events[..., 0] == 0]

        # Get the possible subset sizes of the input and target events.
        events_subset_sizes = np.arange(1, len(events) + 1)
        
        # Compute the probabilities of selecting certain subset sizes.
        # Favor larger subset sizes.
        sizes_probs = np.array(
            events_subset_sizes / np.sum(events_subset_sizes))

        # Select the size of the input events subset.
        [size] = np.random.choice(range(1, len(events) + 1), 1, p=sizes_probs)

        # Randomly select the events subset based on the selected size.
        selected_events_idx = np.random.choice(
            np.arange(len(events)), size=size, replace=False)
        selected_events = events[selected_events_idx]
        # Remove the features corresponding to the selected events.
        instance = remove_features_by_events(instance, selected_events.tolist())
        return instance

    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        """Get a dataset 

        Parameters
        ----------
        index : int
            The index from where to extract the dataset instance.

        Returns
        -------
        ndarray
            The input data at the given index with masked random
            features.
        ndarray
            The ground truth with respect to the input data at the
            given index with masked random features.
        """
        x = self.x[index].copy()
        y = self.y[index].copy()
        
        events = np.array(get_largest_event_set(x))
        events = events[events[..., 0] == 0]
        events_t = np.array(get_largest_event_set(y))
        events_t = events_t[events_t[..., 0] == 0]
        
        [selected_event_t_idx] = np.random.choice(
            np.arange(len(events_t)), size=1, replace=False)
        selected_event_t = events_t[selected_event_t_idx]
        instances = []

        for event in events:
            instance = x[event[1], event[2], :]
            instance = [event[1], event[2], *instance]
            instances.append(instance)

        #x = self._mask_random_instance_features(x)
        
        for t in range(y.shape[0]):
            for n in range(y.shape[1]):
                if t != selected_event_t[1] and n != selected_event_t[2]:
                    y[t, n] = 0.
        # y = self._mask_random_instance_features(y)
        
        # y_pred = predict(spatial_temporal_gnn, x, scaler, DEVICE)
        # y_pred[..., 0] = y_pred[..., 0] * y[..., 0] != 0
        instances = np.array(instances)
        t = np.array([selected_event_t[1], selected_event_t[2], x[selected_event_t[1], selected_event_t[2], 0]])
        return x, instances, t, y #y_pred

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return self.len

def get_dataloader(x: np.ndarray, y: np.ndarray, 
                   batch_size: int, shuffle: bool) -> DataLoader:
    dataset = EventsDataset(x, y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


In [48]:
train_loader = get_dataloader(x_train, y_train, batch_size=None, shuffle=True)
val_loader = get_dataloader(x_val, y_val, batch_size=None, shuffle=False)

In [49]:
instances = next(iter(train_loader))

In [50]:
print(instances[1].shape)

torch.Size([2277, 11])


In [51]:
from src.spatial_temporal_gnn.training import Checkpoint

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=2e-6)

lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=.1, patience=10, verbose=False,
    threshold=.001, threshold_mode='rel', cooldown=0, min_lr=1e-6, eps=1e-08)

checkpoint_file_path = os.path.join('..', 'models', 'checkpoints',
                                    'navigator_metr_la.pth')
checkpoint = Checkpoint(checkpoint_file_path)

EPOCHS = 1_000

In [52]:
from time import time
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.utils.data import DataLoader

from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.spatial_temporal_gnn.metrics import MAE, RMSE, MAPE
from src.spatial_temporal_gnn.training import Checkpoint
from src.data.data_processing import Scaler

def train(
    model: Navigator, optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader, val_dataloader: DataLoader,
    spatial_temporal_gnn: SpatialTemporalGNN, scaler: Scaler,
    epochs: int, checkpoint: Optional[Checkpoint] = None,
    lr_scheduler: Optional[object] = None,
    reload_best_weights: bool = True) -> Dict[str, np.ndarray]:
    # Get the device that is used for training and querying the model.
    device = model.device

    # Initialize the training criterions.
    mae_criterion = MAE()
    rmse_criterion = RMSE()
    mape_criterion = MAPE()

    # Initialize the histories.
    metrics = ['train_mae', 'train_rmse', 'train_mape', 'val_mae', 'val_rmse',
               'val_mape']
    history = { m: [] for m in metrics }

    # Set model in training mode.
    model.train()

    # Iterate across the epochs.
    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')

        # Remove unused tensors from gpu memory.
        torch.cuda.empty_cache()

        # Initialize the running errors.
        running_train_mae = 0.
        running_train_rmse = 0.
        running_train_mape = 0.

        start_time = time()

        for batch_idx, (x, instances, t, y) in enumerate(train_dataloader):
            # Increment the number of batch steps.
            batch_steps = batch_idx + 1

            # Get the data.
            x = x.type(torch.float32).to(device=device)
            instances = instances.type(torch.float32).to(device=device)
            t = t.float().to(device=device)
            y = y.type(torch.float32).to(device=device)
            
            # Repeat y for each instance that has shape (Batch, Features)
            t_repeated = t.unsqueeze(0).repeat(instances.shape[0], 3)

            event_scores = model(instances, t_repeated)

            ev_scores = torch.zeros((x.shape[0], x.shape[1], 1))
            
            for i, x_ in enumerate(instances):
                timestep = int(x_[0].item()); node = int(x_[1].item())
                ev_scores[timestep, node, 0] = event_scores[i]
                
            ev_scores = ev_scores.to(device=device)
            # print(ev_scores.shape, x.shape)
            x_sim = simulate_model(x, ev_scores)
            #print(x_sim.shape)
            x_sim = scaler.scale(x_sim)

            # Compute the Spatial-Temporal GNN model predictions.
            y_pred = spatial_temporal_gnn(x_sim.unsqueeze(0))

            # Un-scale the predictions.
            y_pred = scaler.un_scale(y_pred)

            loss = mae_criterion(y_pred, y.unsqueeze(0))
            
            #print(y)

            # Compute errors and update running errors.
            with torch.no_grad():
                rmse = rmse_criterion(y_pred, y.unsqueeze(0))
                mape = mape_criterion(y_pred, y.unsqueeze(0))

            running_train_mae += loss.item()
            running_train_rmse += rmse.item()
            running_train_mape += mape.item()

            # Zero the gradients.
            optimizer.zero_grad()

            # Use MAE as the loss function for backpropagation.
            loss.backward()

            # Update the weights.
            optimizer.step()

            # Get the batch time.
            epoch_time = time() - start_time
            batch_time = epoch_time / batch_steps

            # Print the batch results.
            print(
                f'[{batch_steps}/{len(train_dataloader)}] -',
                f'{epoch_time:.0f}s {batch_time * 1e3:.0f}ms/step -',

                f'train {{ MAE (loss): {running_train_mae / batch_steps:.3g} -',
                f'RMSE: {running_train_rmse / batch_steps:.3g} -',
                f'MAPE: {running_train_mape * 100. / batch_steps:.3g}% }} -',

                f'lr: {optimizer.param_groups[0]["lr"]:.3g} -',
                f'weight decay: {optimizer.param_groups[0]["weight_decay"]}',
                '             ' if batch_steps < len(train_dataloader) else '',
                end='\r')

        '''# Set the model in evaluation mode.
        model.eval()

        # Get the average training errors and update the history.
        train_mae = running_train_mae / len(train_dataloader)
        train_rmse = running_train_rmse / len(train_dataloader)
        train_mape = running_train_mape / len(train_dataloader)

        history['train_mae'].append(train_mae)
        history['train_rmse'].append(train_rmse)
        history['train_mape'].append(train_mape)

        # Get the validation results and update the history.
        val_results = validate(model, val_dataloader, spatial_temporal_gnn,
                               scaler)
        val_mae, val_rmse, val_mape = val_results

        history['val_mae'].append(val_mae)
        history['val_rmse'].append(val_rmse)
        history['val_mape'].append(val_mape)

        # Save the checkpoints if demanded.
        if checkpoint is not None:
            err_sum = val_mae + val_rmse + val_mape
            checkpoint.save_best(model, optimizer, err_sum)

        # Print the epoch results.
        print(
            f'[{len(train_dataloader)}/{len(train_dataloader)}] -',
            f'{epoch_time:.0f}s -',

            f'train: {{ MAE (loss): {train_mae:.3g} -',
            f'RMSE: {train_rmse:.3g} -',
            f'MAPE: {train_mape * 100.:.3g}% }} -',

            f'val: {{ MAE: {val_mae:.3g} -',
            f'RMSE: {val_rmse:.3g} -',
            f'MAPE: {val_mape * 100.:.3g}% }} -',

            f'lr: {optimizer.param_groups[0]["lr"]:.3g} -',
            f'weight decay: {optimizer.param_groups[0]["weight_decay"]}')''';

        # Update the learning rate scheduler.
        #lr_scheduler.step(train_mae)

        # Set model in training mode.
        model.train()

    # Load the best weights of the model if demanded.
    if checkpoint is not None and reload_best_weights:
        checkpoint.load_best_weights(model)

    # Set the model in evaluation mode.
    model.eval()

    # Remove unused tensors from gpu memory.
    torch.cuda.empty_cache()

    # Turn the history to numpy arrays.
    for k, v in history.items():
        history[k] = np.array(v)

    return history

def validate(
    model: Navigator, val_dataloader: DataLoader,
    spatial_temporal_gnn: SpatialTemporalGNN, scaler: Scaler
    ) -> Tuple[float, float, float]:
    device = model.device
    torch.cuda.empty_cache()

    # Initialize the validation criterions.
    mae_criterion = MAE()
    rmse_criterion = RMSE()
    mape_criterion = MAPE()

    # Inizialize running errors.
    running_val_mae = 0.
    running_val_rmse = 0.
    running_val_mape = 0.

    with torch.no_grad():
        for _, (x) in enumerate(val_dataloader):
            # Get the data.
            x = x.type(torch.float32).to(device=device)
            y = y.type(torch.float32).to(device=device)

            event_scores = model(x, y.repeat(events.shape[0]))
            
            x_sim = simulate_model(x, event_scores)

            x_sim = scaler.scale(x_sim)

            # Compute the Spatial-Temporal GNN model predictions.
            y_pred = spatial_temporal_gnn(x_sim.unsqueeze(0))

            # Un-scale the predictions.
            y_pred = scaler.un_scale(y_pred)

            # Get the prediction errors and update the running errors.
            mae = mae_criterion(y_pred, y.unsqueeze(0))
            rmse = rmse_criterion(y_pred, y.unsqueeze(0))
            mape = mape_criterion(y_pred, y.unsqueeze(0))

            running_val_mae += mae.item()
            running_val_rmse += rmse.item()
            running_val_mape += mape.item()

    # Remove unused tensors from gpu memory.
    torch.cuda.empty_cache()

    # Get the average MAE, RMSE and MAPE scores.
    val_mae = running_val_mae / len(val_dataloader)
    val_rmse = running_val_rmse / len(val_dataloader)
    val_mape = running_val_mape / len(val_dataloader)

    return val_mae, val_rmse, val_mape

In [53]:
history = train(
    model, optimizer, train_loader, val_loader, spatial_temporal_gnn, scaler,
    EPOCHS, checkpoint, lr_scheduler, reload_best_weights=True)

Epoch 1/1000
[706/988] - 293s 415ms/step - train { MAE (loss): 6.76 - RMSE: 10.9 - MAPE: 20.7% } - lr: 0.0001 - weight decay: 2e-06              

KeyboardInterrupt: 

In [54]:
x, ev, t, y = next(iter(train_loader))

In [55]:
 #print(ev)

In [63]:
res = model(ev.to(DEVICE).float(), t.unsqueeze(0).repeat(ev.shape[0], 3).to(DEVICE).float())
ev_scores = torch.zeros((x.shape[0], x.shape[1], 1))
            
for i, x_ in enumerate(ev):
    timestep = int(x_[0].item()); node = int(x_[1].item())
    ev_scores[timestep, node, 0] = res[i]

In [64]:
ev_scores.shape

torch.Size([12, 207, 1])

In [65]:
torch.unique(ev_scores)

tensor([-10.5852, -10.4704, -10.3164,  ...,  -1.4401,  -1.3886,   0.0000],
       grad_fn=<Unique2Backward0>)

In [66]:
ev_scores

tensor([[[ -1.3886],
         [ -1.7143],
         [ -1.6388],
         ...,
         [ -7.6415],
         [ -7.3941],
         [ -7.5763]],

        [[ -1.4813],
         [ -1.7702],
         [ -1.6947],
         ...,
         [ -7.6227],
         [ -8.8487],
         [ -7.5781]],

        [[ -1.5261],
         [ -1.8230],
         [ -1.7042],
         ...,
         [ -7.6943],
         [ -8.9877],
         [ -7.7397]],

        ...,

        [[ -2.0669],
         [ -2.0433],
         [ -2.0033],
         ...,
         [ -8.6425],
         [-10.5852],
         [ -9.7031]],

        [[ -2.1371],
         [ -2.1208],
         [ -2.0322],
         ...,
         [ -8.4683],
         [-10.1568],
         [ -9.3274]],

        [[ -2.1560],
         [ -2.1612],
         [ -2.0715],
         ...,
         [ -8.6440],
         [-10.3164],
         [ -9.6983]]], grad_fn=<CopySlices>)

In [None]:
def _mask_random_instance_features(instance: np.ndarray) -> np.ndarray:
    """Mask random features of a Spatial-Temporal Graph instance.

    Parameters
    ----------
    instance : ndarray
        The Spatial-Temporal Graph instance to mask random features
        from.

    Returns
    -------
    ndarray
        The Spatial-Temporal Graph instance with masked random
        features.
    """
    with torch.no_grad:
        # Get the largest event set of the input instance.
        instance = instance.cpu().numpy().copy()
        events = np.array(get_largest_event_set(instance))

        # TODO: consider all events, not solely speed events.
        events = events[events[..., 0] == 0]
        # Get randomly either 1 or 0
        integ = np.random.randint(2)
        
        if integ == 1:
            selected_events = events[events[..., 1] > 8]
        else:
            # Get the possible subset sizes of the input and target events.
            events_subset_sizes = np.arange(1, 10)
            
            # Compute the probabilities of selecting certain subset sizes.
            # Favor larger subset sizes.
            sizes_probs = np.array(
                events_subset_sizes / np.sum(events_subset_sizes))

            # Select the size of the input events subset.
            [size] = np.random.choice(range(1, 500), 1)

            # Randomly select the events subset based on the selected size.
            selected_events_idx = np.random.choice(
                np.arange(len(events)), size=size, replace=False)
            print(selected_events_idx)
            selected_events = events[selected_events_idx]
            # Remove the features corresponding to the selected events.
        instance = remove_features_by_events(instance, selected_events.tolist())
        return instance


NameError: name 'np' is not defined

In [None]:
x, ev, y = next(iter(train_loader))

In [None]:
x_ = torch.Tensor(_mask_random_instance_features(x)).float().to(DEVICE)

[1998 1863 1208 2366 2384  684  993  980 2369 1734  413 1335  750 2410
  286  946 1039 1646  131  934 1904 2335  203 2081 2152  504  891 2328
  161  600 2159 1324  189  524  456 1057 1845 1514  828 2471 1020 2256
  140 1821 2295 1086  512 1650  789  571 2368 1043 1544  372 1445  827
  156 1468 1932  412  329  511 2310 1436 1525 2262  176  660 1105  219
  695 2221  822 1440  762 2185  798 1081 1909 2172  229  184 2216 1402
  972 1521  705 1328  330 2004 2291 1228 1702  880  779 1960 2337  278
  830  114  536  507 1710  394  365  808 1307 1891 2027  726 2222 1910
 1286 1953  242 1716 2467 1037   59 1111  693 1908  367  260 1052 2462
 1857  815 2096 1993 1753  508 1087 1343 1378 1216 2245  685 1553  638
  111 1720 1728 1142 1741 1943 1368 1592  999 1465 1594  533  561  630
    2 1235   65 1956 1131 1053  382  225 1657  746  870 1570 2080 2342
 1666 1965  497 1874 1784  941 1170 1409  104 2138 1838 2461  856  947
  419  209 1022 1948 1192 2465 1472  631 2294 1358 1679 2383   25  349
 1498 

In [None]:
from src.spatial_temporal_gnn.metrics import MAE
l = MAE()
l(spatial_temporal_gnn(x_.unsqueeze(0)), y.unsqueeze(0).float().to(DEVICE))

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 6.00 GiB total capacity; 5.21 GiB already allocated; 0 bytes free; 5.28 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
spatial_temporal_gnn.s_gnns[0].linear.weight.shape

torch.Size([64, 64])

In [None]:
spatial_temporal_gnn.s_gnns[0].linear.weight.shape

torch.Size([64, 64])

In [None]:
#stgnn_checkpoints['model_state_dict']

OrderedDict([('encoder.weight',
              tensor([[ 2.5134e-01,  3.0450e-01,  7.5520e-02,  2.0865e-01,  2.8809e-01,
                        9.7634e-02, -2.9664e-01, -2.0466e-01, -3.4900e-01],
                      [-1.8688e-01, -2.3607e-01,  2.9868e-01,  1.5404e-01, -2.5910e-01,
                        2.4225e-01,  2.0629e-01, -3.3710e-01,  1.8940e-01],
                      [-1.9582e-01, -3.2470e-01, -2.3799e-01, -2.7147e-01, -2.7561e-01,
                        2.1004e-01,  1.2448e-01, -5.7050e-02,  2.0347e-01],
                      [-3.3320e-02, -3.4339e-01,  1.4654e-01, -2.7905e-01, -1.5994e-01,
                       -3.7389e-01, -3.1333e-01, -2.5504e-01, -3.4831e-01],
                      [-2.1683e-01, -3.1467e-01, -7.4819e-03,  2.4080e-01, -6.4404e-02,
                       -5.7576e-02,  1.4887e-01,  3.4442e-01, -1.9733e-01],
                      [ 3.1908e-01, -1.3602e-01,  2.5490e-01, -1.2949e-01, -2.0595e-01,
                        5.2199e-02, -3.0044e-01,  1.6180e-01

In [None]:
#spatial_temporal_gnn.state_dict()

OrderedDict([('encoder.weight',
              tensor([[ 2.5134e-01,  3.0450e-01,  7.5520e-02,  2.0865e-01,  2.8809e-01,
                        9.7634e-02, -2.9664e-01, -2.0466e-01, -3.4900e-01],
                      [-1.8688e-01, -2.3607e-01,  2.9868e-01,  1.5404e-01, -2.5910e-01,
                        2.4225e-01,  2.0629e-01, -3.3710e-01,  1.8940e-01],
                      [-1.9582e-01, -3.2470e-01, -2.3799e-01, -2.7147e-01, -2.7561e-01,
                        2.1004e-01,  1.2448e-01, -5.7050e-02,  2.0347e-01],
                      [-3.3320e-02, -3.4339e-01,  1.4654e-01, -2.7905e-01, -1.5994e-01,
                       -3.7389e-01, -3.1333e-01, -2.5504e-01, -3.4831e-01],
                      [-2.1683e-01, -3.1467e-01, -7.4819e-03,  2.4080e-01, -6.4404e-02,
                       -5.7576e-02,  1.4887e-01,  3.4442e-01, -1.9733e-01],
                      [ 3.1908e-01, -1.3602e-01,  2.5490e-01, -1.2949e-01, -2.0595e-01,
                        5.2199e-02, -3.0044e-01,  1.6180e-01

In [None]:
def validate_state_dicts(model_state_dict_1, model_state_dict_2):
    if len(model_state_dict_1) != len(model_state_dict_2):
        print(
            f"Length mismatch: {len(model_state_dict_1)}, {len(model_state_dict_2)}"
        )
        return False

    # Replicate modules have "module" attached to their keys, so strip these off when comparing to local model.
    if next(iter(model_state_dict_1.keys())).startswith("module"):
        model_state_dict_1 = {
            k[len("module") + 1 :]: v for k, v in model_state_dict_1.items()
        }

    if next(iter(model_state_dict_2.keys())).startswith("module"):
        model_state_dict_2 = {
            k[len("module") + 1 :]: v for k, v in model_state_dict_2.items()
        }

    for ((k_1, v_1), (k_2, v_2)) in zip(
        model_state_dict_1.items(), model_state_dict_2.items()
    ):
        if k_1 != k_2:
            print(f"Key mismatch: {k_1} vs {k_2}")
            return False
        # convert both to the same CUDA device
        if str(v_1.device) != "cuda:0":
            v_1 = v_1.to("cuda:0" if torch.cuda.is_available() else "cpu")
        if str(v_2.device) != "cuda:0":
            v_2 = v_2.to("cuda:0" if torch.cuda.is_available() else "cpu")

        if not torch.allclose(v_1, v_2):
            print(f"Tensor mismatch: {v_1} vs {v_2}")
            return False

validate_state_dicts(spatial_temporal_gnn.state_dict(), stgnn_checkpoints['model_state_dict'])