In [2]:
import sys
import os

# Set the main path in the root folder of the project.
sys.path.append(os.path.join('..'))

In [3]:
# Settings for autoreloading.
%load_ext autoreload
%autoreload 2

In [4]:
from src.utils.seed import set_random_seed

# Set the random seed for deterministic operations.
SEED = 42
set_random_seed(SEED)

In [5]:
import torch

# Set the device for training and querying the model.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'The selected device is: "{DEVICE}"')

The selected device is: "cuda"


# Building the event set

In [6]:
import os

BASE_DATA_DIR = os.path.join('..', 'data', 'metr-la')

In [7]:
import pickle

# Get the data scaler.
with open(os.path.join(BASE_DATA_DIR, 'processed', 'scaler.pkl'), 'rb') as f:
    scaler = pickle.load(f)

In [8]:
from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.data.data_extraction import get_adjacency_matrix

# Get the adjacency matrix
adj_matrix_structure = get_adjacency_matrix(
    os.path.join(BASE_DATA_DIR, 'adj_mx_metr_la.pkl'))

# Get the header of the adjacency matrix and the matrix itself.
header, _, adj_matrix = adj_matrix_structure

# Get the STGNN and load the checkpoints.
spatial_temporal_gnn = SpatialTemporalGNN(9, 1, 12, 12, adj_matrix, DEVICE, 64)

stgnn_checkpoints_path = os.path.join('..', 'models', 'checkpoints',
                                      'st_gnn_metr_la.pth')

stgnn_checkpoints = torch.load(stgnn_checkpoints_path)
spatial_temporal_gnn.load_state_dict(stgnn_checkpoints['model_state_dict'])

# Set the model in evaluation mode.
spatial_temporal_gnn.eval();

In [9]:
import os
import numpy as np
from src.spatial_temporal_gnn.prediction import predict

# Get the data and the values predicted by the STGNN.
x_train = np.load(os.path.join(BASE_DATA_DIR, 'processed', 'x_train.npy'))
y_train = predict(spatial_temporal_gnn, x_train, scaler, DEVICE)
x_val = np.load(os.path.join(BASE_DATA_DIR, 'processed', 'x_val.npy'))
y_val = predict(spatial_temporal_gnn, x_val, scaler, DEVICE)
x_test = np.load(os.path.join(BASE_DATA_DIR, 'processed', 'x_test.npy'))
y_test = predict(spatial_temporal_gnn, x_test, scaler, DEVICE)

# Map the event set to the graph

In [137]:
import torch
from torch import nn
import numpy as np

def simulate_model(
    instance: torch.FloatTensor, events_scores: torch.FloatTensor) -> torch.FloatTensor:
    #eps = torch.rand_like(events_scores).float().to(instance.device)
    #eps_hat = (eps.log() - (1 - eps).log() + instance) / 1.0
    
    
    events_scores = events_scores.sigmoid()
    
    
    #eps = torch.rand_like(events_scores)
    eps = torch.rand(1).to(instance.device)
    events_scores = torch.sigmoid((eps.log() - (1 - eps).log() + events_scores) / 2.0)
    # TODO: Simulate all events, not just the speed events.
    #relaxed_bernoulli = torch.distributions.RelaxedBernoulli(2.0, events_scores)
    #e = torch.sigmoid(events_scores)
    #e = relaxed_bernoulli.sample()
    #e = torch.rand(1).float().to(instance.device)
    #e_hat = torch.sigmoid((torch.log(e) - torch.log(1 - e) + events_scores) / .05)
    result = events_scores >= .5
    #print(result)
    #result = torch.bernoulli(events_scores)
    instance = result * instance
    return instance

In [138]:
import torch
from torch import nn

class Navigator(nn.Module):
    def __init__(self, device: str, hidden_features: int = 64) -> None:
        super().__init__()
        # Set the linear encoder.
        self.linear_encoder = nn.LazyLinear(hidden_features)
        # Set the linear decoder.
        self.linear_decoder = nn.Linear(hidden_features, 1)
        # Set the device that is used for training and querying the model.
        self.device = device
        self.to(device)

    def forward(self, candidate_event: torch.FloatTensor, target_events: torch.FloatTensor) -> torch.FloatTensor:
        # Concatenate the candidate event and the target events.
        x = torch.cat((candidate_event, target_events), dim=1)
        # Encode the input.
        out = self.linear_encoder(x)
        # Decode the output to get the logits prediction.
        out = self.linear_decoder(out)
        return out
        

In [167]:
model = Navigator(DEVICE)



In [168]:
from src.explanation.navigator.dataloaders import get_dataloader

In [169]:
train_loader = get_dataloader(x_train, y_train, batch_size=None, shuffle=True)
val_loader = get_dataloader(x_val, y_val, batch_size=None, shuffle=False)

In [170]:
x  = next(iter(train_loader))


In [171]:
#torch.unique(x[3])

In [172]:
from src.spatial_temporal_gnn.training import Checkpoint

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=0)

#lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#    optimizer, mode='min', factor=.1, patience=2, verbose=False,
#    threshold=.001, threshold_mode='rel', cooldown=0, min_lr=1e-5, eps=1e-08)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.94, verbose=False)

checkpoint_file_path = os.path.join('..', 'models', 'checkpoints',
                                    'navigator_metr_la.pth')
checkpoint = Checkpoint(checkpoint_file_path)

EPOCHS = 5

In [173]:
from math import ceil
from time import time
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.utils.data import DataLoader

from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.spatial_temporal_gnn.metrics import MAE, RMSE, MAPE
from src.spatial_temporal_gnn.training import Checkpoint
from src.data.data_processing import Scaler

def train(
    model: Navigator, optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader, val_dataloader: DataLoader,
    spatial_temporal_gnn: SpatialTemporalGNN, scaler: Scaler,
    epochs: int, validations_per_batch: int = 1,
    checkpoint: Optional[Checkpoint] = None,
    lr_scheduler: Optional[object] = None,
    reload_best_weights: bool = True) -> Dict[str, np.ndarray]:
    # Get the device that is used for training and querying the model.
    device = model.device
    
    # Set the valdation step inside the batch
    assert validations_per_batch > 0, \
        'The number of validations per batch must be greater than zero.'
    val_step = ceil(len(train_dataloader) / validations_per_batch)

    # Initialize the training criterions.
    mae_criterion = MAE()
    rmse_criterion = RMSE()
    mape_criterion = MAPE()

    # Initialize the histories.
    metrics = ['train_mae', 'train_rmse', 'train_mape', 'val_mae', 'val_rmse',
               'val_mape']
    history = { m: [] for m in metrics }

    # Set model in training mode.
    model.train()

    # Iterate across the epochs.
    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')

        # Remove unused tensors from gpu memory.
        torch.cuda.empty_cache()

        # Initialize the running errors.
        running_train_mae = 0.
        running_train_rmse = 0.
        running_train_mape = 0.

        start_time = time()

        for batch_idx, (x, instances, t, y) in enumerate(train_dataloader):
            # Increment the number of batch steps.
            batch_steps = batch_idx + 1

            # Get the data.
            x = x.type(torch.float32).to(device=device)
            instances = instances.type(torch.float32).to(device=device)
            y = y.type(torch.float32).to(device=device)
            
            t_repeated = t.unsqueeze(0).repeat(instances.shape[0], 1).to(device=device).float()
            # Repeat y for each instance that has shape (Batch, Features)
            #t_repeated = t.unsqueeze(0).repeat(instances.shape[0], 3)

            event_scores = model(instances, t_repeated)

            ev_scores = torch.zeros((x.shape[0], x.shape[1], 1))
            
            for i, x_ in enumerate(instances):
                timestep = int(x_[0].item()); node = int(x_[1].item())
                ev_scores[timestep, node, 0] = event_scores[i]
                
            ev_scores = ev_scores.to(device=device)
            # print(ev_scores.shape, x.shape)
            x_sim = simulate_model(x, ev_scores)
            #print(x_sim.shape)
            x_sim = scaler.scale(x_sim)

            # Compute the Spatial-Temporal GNN model predictions.
            y_pred = spatial_temporal_gnn(x_sim.unsqueeze(0))

            # Un-scale the predictions.
            y_pred = scaler.un_scale(y_pred)

            loss = mae_criterion(y_pred, y.unsqueeze(0))
            
            #print(y)

            # Compute errors and update running errors.
            with torch.no_grad():
                rmse = rmse_criterion(y_pred, y.unsqueeze(0))
                mape = mape_criterion(y_pred, y.unsqueeze(0))

            running_train_mae += loss.item()
            running_train_rmse += rmse.item()
            running_train_mape += mape.item()

            # Zero the gradients.
            optimizer.zero_grad()

            # Use MAE as the loss function for backpropagation.
            loss.backward()

            for param in spatial_temporal_gnn.parameters():
                param.grad[:] = 0

            # Update the weights.
            optimizer.step()

            # Get the batch time.
            epoch_time = time() - start_time
            batch_time = epoch_time / batch_steps

            apply_validation = (batch_idx + 1) % val_step == 0
            
            # Print the batch results.
            print(
                f'[{batch_steps}/{len(train_dataloader)}] -',
                f'{epoch_time:.0f}s {batch_time * 1e3:.0f}ms/step -',

                f'train {{ MAE (loss): {running_train_mae / batch_steps:.3g} -',
                f'RMSE: {running_train_rmse / batch_steps:.3g} -',
                f'MAPE: {running_train_mape * 100. / batch_steps:.3g}% }} -',

                f'lr: {optimizer.param_groups[0]["lr"]:.3g} -',
                f'weight decay: {optimizer.param_groups[0]["weight_decay"]}',
                '             ' if batch_steps < len(train_dataloader) else '',
                end='\r' if not apply_validation else '\n')

            # Evaluate on validation set.
            if apply_validation:
                # Set the model in eval mode.
                model.eval()

                # Remove unused tensors from gpu memory.
                torch.cuda.empty_cache()

                # Compute the validation scores.
                val_results = validate(model, val_dataloader, spatial_temporal_gnn,
                                    scaler)
                val_mae, val_rmse, val_mape = val_results

                # Remove unused tensors from gpu memory.
                torch.cuda.empty_cache()
                
                # Print the validation step results.
                print(
                    '\t'
                    f'val step -',

                    f'val: {{ MAE: {val_mae:.3g} -',
                    f'RMSE: {val_rmse:.3g} -',
                    f'MAPE: {val_mape * 100.:.3g}% }} -',

                    f'lr: {optimizer.param_groups[0]["lr"]:.3g} -',
                    f'weight decay: {optimizer.param_groups[0]["weight_decay"]}'
                    )

                # Save the checpoints.
                if checkpoint is not None:
                    err_sum = val_mae + val_rmse + val_mape
                    checkpoint.save_best(model, optimizer, err_sum)

                # Set the model in train mode.
                model.train()

        # Set the model in evaluation mode.
        model.eval()

        # Get the average training errors and update the history.
        train_mae = running_train_mae / len(train_dataloader)
        train_rmse = running_train_rmse / len(train_dataloader)
        train_mape = running_train_mape / len(train_dataloader)

        history['train_mae'].append(train_mae)
        history['train_rmse'].append(train_rmse)
        history['train_mape'].append(train_mape)

        # Get the validation results and update the history.
        val_results = validate(model, val_dataloader, spatial_temporal_gnn,
                               scaler)
        val_mae, val_rmse, val_mape = val_results

        history['val_mae'].append(val_mae)
        history['val_rmse'].append(val_rmse)
        history['val_mape'].append(val_mape)

        # Save the checkpoints if demanded.
        if checkpoint is not None:
            err_sum = val_mae + val_rmse + val_mape
            checkpoint.save_best(model, optimizer, err_sum)

        # Print the epoch results.
        print(
            f'[{len(train_dataloader)}/{len(train_dataloader)}] -',
            f'{epoch_time:.0f}s -',

            f'train: {{ MAE (loss): {train_mae:.3g} -',
            f'RMSE: {train_rmse:.3g} -',
            f'MAPE: {train_mape * 100.:.3g}% }} -',

            f'val: {{ MAE: {val_mae:.3g} -',
            f'RMSE: {val_rmse:.3g} -',
            f'MAPE: {val_mape * 100.:.3g}% }} -',

            f'lr: {optimizer.param_groups[0]["lr"]:.3g} -',
            f'weight decay: {optimizer.param_groups[0]["weight_decay"]}')

        # Update the learning rate scheduler.
        #lr_scheduler.step(train_mae)

        # Set model in training mode.
        lr_scheduler.step()
        model.train()

    # Load the best weights of the model if demanded.
    if checkpoint is not None and reload_best_weights:
        checkpoint.load_best_weights(model)

    # Set the model in evaluation mode.
    model.eval()

    # Remove unused tensors from gpu memory.
    torch.cuda.empty_cache()

    # Turn the history to numpy arrays.
    for k, v in history.items():
        history[k] = np.array(v)

    return history

def validate(
    model: Navigator, val_dataloader: DataLoader,
    spatial_temporal_gnn: SpatialTemporalGNN, scaler: Scaler
    ) -> Tuple[float, float, float]:
    device = model.device
    torch.cuda.empty_cache()

    # Initialize the validation criterions.
    mae_criterion = MAE()
    rmse_criterion = RMSE()
    mape_criterion = MAPE()

    # Inizialize running errors.
    running_val_mae = 0.
    running_val_rmse = 0.
    running_val_mape = 0.

    with torch.no_grad():
        for x, instances, t, y in val_dataloader:
            # Get the data.
            x = x.type(torch.float32).to(device=device)
            instances = instances.type(torch.float32).to(device=device)
            y = y.type(torch.float32).to(device=device)
            
            t_repeated = t.unsqueeze(0).repeat(instances.shape[0], 1).to(device=device).float()
            # Repeat y for each instance that has shape (Batch, Features)
            #t_repeated = t.unsqueeze(0).repeat(instances.shape[0], 3)

            event_scores = model(instances, t_repeated)

            ev_scores = torch.zeros((x.shape[0], x.shape[1], 1))
            
            for i, x_ in enumerate(instances):
                timestep = int(x_[0].item()); node = int(x_[1].item())
                ev_scores[timestep, node, 0] = event_scores[i]
                
            ev_scores = ev_scores.to(device=device)
            # print(ev_scores.shape, x.shape)
            x_sim = simulate_model(x, ev_scores)
            #print(x_sim.shape)
            x_sim = scaler.scale(x_sim)

            # Compute the Spatial-Temporal GNN model predictions.
            y_pred = spatial_temporal_gnn(x_sim.unsqueeze(0))

            # Un-scale the predictions.
            y_pred = scaler.un_scale(y_pred)

            mae = mae_criterion(y_pred, y.unsqueeze(0))
            rmse = rmse_criterion(y_pred, y.unsqueeze(0))
            mape = mape_criterion(y_pred, y.unsqueeze(0))

            running_val_mae += mae.item()
            running_val_rmse += rmse.item()
            running_val_mape += mape.item()

    # Remove unused tensors from gpu memory.
    torch.cuda.empty_cache()

    # Get the average MAE, RMSE and MAPE scores.
    val_mae = running_val_mae / len(val_dataloader)
    val_rmse = running_val_rmse / len(val_dataloader)
    val_mape = running_val_mape / len(val_dataloader)

    return val_mae, val_rmse, val_mape

In [174]:
history = train(
    model, optimizer, train_loader, val_loader, spatial_temporal_gnn, scaler,
    EPOCHS, 4, checkpoint, lr_scheduler, reload_best_weights=True)

Epoch 1/5
[247/988] - 181s 734ms/step - train { MAE (loss): 2.18 - RMSE: 2.18 - MAPE: 6.91% } - lr: 1e-05 - weight decay: 0              
	val step - val: { MAE: 1.49 - RMSE: 1.49 - MAPE: 4.94% } - lr: 1e-05 - weight decay: 0
[404/988] - 346s 856ms/step - train { MAE (loss): 2.27 - RMSE: 2.27 - MAPE: 7.49% } - lr: 1e-05 - weight decay: 0              

KeyboardInterrupt: 

In [None]:
#next(iter(train_loader))[3].shape

In [175]:
x, ev, t, y = next(iter(train_loader))

In [176]:
t_repeated = t.unsqueeze(0).repeat(ev.shape[0], 1).to(device=DEVICE).float()
res = model(ev.to(DEVICE).float(), t_repeated)


ev_scores = torch.zeros((x.shape[0], x.shape[1], 1))
            
for i, x_ in enumerate(ev):
    timestep = int(x_[0].item()); node = int(x_[1].item())
    ev_scores[timestep, node, 0] = res[i]

In [179]:
len(torch.unique(ev_scores.sigmoid()))

1635

In [178]:
for ev in torch.unique(ev_scores.sigmoid()):
    print(ev)

tensor(0.1896, grad_fn=<UnbindBackward0>)
tensor(0.1937, grad_fn=<UnbindBackward0>)
tensor(0.1969, grad_fn=<UnbindBackward0>)
tensor(0.1996, grad_fn=<UnbindBackward0>)
tensor(0.2210, grad_fn=<UnbindBackward0>)
tensor(0.2268, grad_fn=<UnbindBackward0>)
tensor(0.2314, grad_fn=<UnbindBackward0>)
tensor(0.2339, grad_fn=<UnbindBackward0>)
tensor(0.2343, grad_fn=<UnbindBackward0>)
tensor(0.2447, grad_fn=<UnbindBackward0>)
tensor(0.2837, grad_fn=<UnbindBackward0>)
tensor(0.3119, grad_fn=<UnbindBackward0>)
tensor(0.4485, grad_fn=<UnbindBackward0>)
tensor(0.4574, grad_fn=<UnbindBackward0>)
tensor(0.4831, grad_fn=<UnbindBackward0>)
tensor(0.5000, grad_fn=<UnbindBackward0>)
tensor(0.5013, grad_fn=<UnbindBackward0>)
tensor(0.5120, grad_fn=<UnbindBackward0>)
tensor(0.5475, grad_fn=<UnbindBackward0>)
tensor(0.5703, grad_fn=<UnbindBackward0>)
tensor(0.5798, grad_fn=<UnbindBackward0>)
tensor(0.5834, grad_fn=<UnbindBackward0>)
tensor(0.5839, grad_fn=<UnbindBackward0>)
tensor(0.6013, grad_fn=<UnbindBack