In [1]:
import sys
import os

# Set the main path in the root folder of the project.
sys.path.append(os.path.join('..'))

In [2]:
# Settings for autoreloading.
%load_ext autoreload
%autoreload 2

In [3]:
from src.utils.seed import set_random_seed

# Set the random seed for deterministic operations.
SEED = 42
set_random_seed(SEED)

In [4]:
import torch

# Set the device for training and querying the model.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'The selected device is: "{DEVICE}"')

The selected device is: "cuda"


# Building the event set

In [5]:
import os

BASE_DATA_DIR = os.path.join('..', 'data', 'metr-la')

In [6]:
import pickle
with open(os.path.join(BASE_DATA_DIR, 'processed', 'scaler.pkl'), 'rb') as f:
    scaler = pickle.load(f)

In [7]:
from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.data.data_extraction import get_adjacency_matrix

# Get the adjacency matrix
adj_matrix_structure = get_adjacency_matrix(
    os.path.join(BASE_DATA_DIR, 'adj_mx_metr_la.pkl'))

# Get the header of the adjacency matrix and the matrix itself.
header, _, adj_matrix = adj_matrix_structure

spatial_temporal_gnn = SpatialTemporalGNN(9, 1, 12, 12, adj_matrix, DEVICE, 64)

stgnn_checkpoints_path = os.path.join('..', 'models', 'checkpoints',
                                      'st_gnn_metr_la.pth')

stgnn_checkpoints = torch.load(stgnn_checkpoints_path)
spatial_temporal_gnn.load_state_dict(stgnn_checkpoints['model_state_dict'])

spatial_temporal_gnn.eval()

SpatialTemporalGNN(
  (encoder): Linear(in_features=9, out_features=64, bias=False)
  (s_gnns): ModuleList(
    (0): S_GNN(
      (latent_encoder): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=False)
        (1): Linear(in_features=64, out_features=32, bias=False)
      )
      (linear): Linear(in_features=64, out_features=64, bias=False)
    )
    (1): S_GNN(
      (latent_encoder): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=False)
        (1): Linear(in_features=64, out_features=32, bias=False)
      )
      (linear): Linear(in_features=64, out_features=64, bias=False)
    )
    (2): S_GNN(
      (latent_encoder): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=False)
        (1): Linear(in_features=64, out_features=32, bias=False)
      )
      (linear): Linear(in_features=64, out_features=64, bias=False)
    )
    (3): S_GNN(
      (latent_encoder): Sequential(
        (0): Linear(in_features=64, out_features

In [8]:
import os
import numpy as np
from src.spatial_temporal_gnn.prediction import predict

x_train = np.load(os.path.join(BASE_DATA_DIR, 'processed', 'x_train.npy'))
y_train = predict(spatial_temporal_gnn, x_train, scaler, DEVICE)
x_val = np.load(os.path.join(BASE_DATA_DIR, 'processed', 'x_val.npy'))
y_val = predict(spatial_temporal_gnn, x_val, scaler, DEVICE)
x_test = np.load(os.path.join(BASE_DATA_DIR, 'processed', 'x_test.npy'))
y_test = predict(spatial_temporal_gnn, x_test, scaler, DEVICE)

In [9]:
from typing import List, Optional, Tuple

from src.data.data_analysis import days_encoder

def get_largest_event_set(data: np.ndarray
                          ) -> List[Tuple[str, Optional[int], Optional[int]]]:
    n_time_steps, n_nodes, _ = data.shape[-3:]
    
    # Get the largest event set related to the speed.
    speed_events = [
        (0, time_step, node) 
        for time_step in range(n_time_steps) 
        for node in range(n_nodes) 
        if data[..., time_step, node, 0] > 0]
    # Get the largest event set related to the time of day.
    time_of_day_events = [
        (1, time_step, None) for time_step in range(n_time_steps)]
    # Get the largest event set related to the day of week.
    day_of_week_events = [(2, None, None)]
    # Get the largest event set related to the kind of day.
    #kind_of_day_events = [(3, None, None)]
    
    return speed_events + time_of_day_events + day_of_week_events #+\
    #    kind_of_day_events

In [10]:
s = get_largest_event_set(x_test[11])

# Map the event set to the graph

In [11]:
from typing import List, Optional, Tuple, Union
import torch

def remove_features_by_events(
    data: Union[np.ndarray, torch.FloatTensor],
    events: List[Tuple[int, Optional[int], Optional[int]]]
    ) -> Union[np.ndarray, torch.FloatTensor]:
    if isinstance(data, torch.FloatTensor):
        filtered_data = data.clone()
    else:
        filtered_data = data.copy()
    n_time_steps, n_nodes, n_features = filtered_data.shape[-3:]
    
    speed_events = [tuple(event) for event in events if event[0] == 0]
    time_of_day_events = [tuple(event) for event in events if event[0] == 1]
    day_of_week_events = [tuple(event) for event in events if event[0] == 2]
    #kind_of_day_events = [event for event in events if event[0] == 3]
    
    if n_features > 1 and not len(day_of_week_events):
        filtered_data[..., -7:] = 0
    
    for time_step in range(n_time_steps):
        for node in range(n_nodes):
            if (0, time_step, node) not in speed_events:
                filtered_data[..., time_step, node, 0] = 0

            '''if n_features > 1 and len(kind_of_day_events):
                if 1 in data[..., -7:-2]:
                    filtered_data[..., time_step, node, -7:-2] = 1
                else:
                    filtered_data[..., time_step, node, -2:] = 1'''

        if n_features > 1 and (1, time_step, None) not in time_of_day_events:
            filtered_data[..., time_step, :, 1] = -1
    
    return filtered_data

In [12]:
def remove_single_event_from_data(
    data: Union[np.ndarray, torch.FloatTensor],
    event: Union[np.ndarray, torch.FloatTensor]
    ) -> Union[np.ndarray, torch.FloatTensor]:
    # n_time_steps, n_nodes, _ = data.shape[-3:]
    event_kind = event[0]
    time_step = int(event[1].item())
    node = int(event[2].item())
    if event_kind == 0:
        data[..., time_step, node, 0] = 0
    elif event_kind == 1:
        data[..., time_step, :, 1] = -1
    elif event_kind == 2:
        data[..., -7:] = 0
    #elif event[0] == 3:
    #    pass

    '''if event_kind in [0, 1]: 
        for time_step in range(n_time_steps):
            for node in range(n_nodes):
                if 1 in data[..., -7:-2]:
                    data[..., time_step, node, -7:-2] = 1
                else:
                    data[..., time_step, node, -2:] = 1'''

    return data

In [13]:
s = remove_features_by_events(x_test[0], [(1, 0, None)])

In [14]:
# (event_type, speed, time_of_day, ...day, speed_target, time_of_day_target, ...day_target)

# ->

# -> score

In [15]:
y_test[0].shape

(12, 207, 1)

In [16]:
import torch
from torch import nn
import numpy as np

def simulate_model(
    instance: torch.FloatTensor, events_scores: torch.FloatTensor) -> torch.FloatTensor:
    #eps = torch.rand_like(events_scores).float().to(instance.device)
    #eps_hat = (eps.log() - (1 - eps).log() + instance) / 1.0
    events_scores = events_scores.sigmoid()
    #eps = torch.rand_like(events_scores)
    eps = torch.rand(1).to(instance.device)
    events_scores = torch.sigmoid((eps.log() - (1 - eps).log() + events_scores) / 2.0)
    # TODO: Simulate all events, not just the speed events.
    #relaxed_bernoulli = torch.distributions.RelaxedBernoulli(2.0, events_scores)
    #e = torch.sigmoid(events_scores)
    #e = relaxed_bernoulli.sample()
    #e = torch.rand(1).float().to(instance.device)
    #e_hat = torch.sigmoid((torch.log(e) - torch.log(1 - e) + events_scores) / .05)
    result = events_scores >= .5
    #print(result)
    #result = torch.bernoulli(events_scores)
    instance = result * instance
    return instance

In [17]:
import torch
from torch import nn

class Navigator(nn.Module):
    def __init__(self, device: str, hidden_features: int = 64) -> None:
        super().__init__()
        self.graph_encoder = nn.Conv2d(12, 6, kernel_size=(207, 1))


        # Set the linear encoder.
        self.linear_encoder = nn.LazyLinear(hidden_features)
        # Set the linear decoder.
        self.linear_decoder = nn.Linear(hidden_features, 1)
        # Set the device that is used for training and querying the model.
        self.device = device
        self.to(device)

    def forward(self, candidate_event: torch.FloatTensor, target_events: torch.FloatTensor) -> torch.FloatTensor:
        # Flatten the target events.
        y = self.graph_encoder(target_events)
        y = y.flatten(start_dim=1)
        # print(y.shape, y)
        #target_events = target_events.flatten(start_dim=1)
        # Concatenate the candidate event and the target events.
        x = torch.cat((candidate_event, y), dim=1)
        # Encode the input.
        out = self.linear_encoder(x)
        # Decode the output to get the logits prediction.
        out = self.linear_decoder(out)
        return out
        

In [18]:
model = Navigator(DEVICE)



In [19]:
from typing import Tuple
from torch.utils.data.dataloader import DataLoader, Dataset
import numpy as np

from src.spatial_temporal_gnn.prediction import predict


class EventsDataset(Dataset):
    def __init__(self, x: np.ndarray, y: np.ndarray) -> None:
        """Initialize the dataset

        Parameters
        ----------
        x : ndarray
            The input values of the dataset.
        y : ndarray
            The ground truth of the dataset's input data.
        """
        # Filter out the instances that do not have speed values at all.
        self.x = x[np.any(x[..., 0] != 0, axis=(1, 2))]
        self.y = y[np.any(x[..., 0] != 0, axis=(1, 2))]
        self.len = self.x.shape[0]

    def _mask_random_instance_features(self, instance: np.ndarray) -> np.ndarray:
        """Mask random features of a Spatial-Temporal Graph instance.

        Parameters
        ----------
        instance : ndarray
            The Spatial-Temporal Graph instance to mask random features
            from.

        Returns
        -------
        ndarray
            The Spatial-Temporal Graph instance with masked random
            features.
        """
        # Get the largest event set of the input instance.
        events = np.array(get_largest_event_set(instance))

        # TODO: consider all events, not solely speed events.
        events = events[events[..., 0] == 0]

        # Get the possible subset sizes of the input and target events.
        events_subset_sizes = np.arange(1, len(events) + 1)
        
        # Compute the probabilities of selecting certain subset sizes.
        # Favor larger subset sizes.
        sizes_probs = np.array(
            events_subset_sizes / np.sum(events_subset_sizes))

        # Select the size of the input events subset.
        [size] = np.random.choice(range(1, len(events) + 1), 1, p=sizes_probs)

        # Randomly select the events subset based on the selected size.
        selected_events_idx = np.random.choice(
            np.arange(len(events)), size=size, replace=False)
        selected_events = events[selected_events_idx]
        # Remove the features corresponding to the selected events.
        instance = remove_features_by_events(instance, selected_events.tolist())
        return instance

    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        """Get a dataset 

        Parameters
        ----------
        index : int
            The index from where to extract the dataset instance.

        Returns
        -------
        ndarray
            The input data at the given index with masked random
            features.
        ndarray
            The ground truth with respect to the input data at the
            given index with masked random features.
        """
        x = self.x[index].copy()
        y = self.y[index].copy()
        
        events = np.array(get_largest_event_set(x))
        events = events[events[..., 0] == 0]
        #print(events)

        instances = []

        for event in events:
            instance = x[event[1], event[2], :]
            instance = [event[1], event[2], *instance]
            instances.append(instance)
        y = self._mask_random_instance_features(y)

        instances = np.array(instances)

        return x, instances, y #y_pred

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return self.len

def get_dataloader(x: np.ndarray, y: np.ndarray, 
                   batch_size: int, shuffle: bool) -> DataLoader:
    dataset = EventsDataset(x, y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


In [20]:
train_loader = get_dataloader(x_train, y_train, batch_size=None, shuffle=True)
val_loader = get_dataloader(x_val, y_val, batch_size=None, shuffle=False)

In [21]:
x  = next(iter(train_loader))


In [22]:
from src.spatial_temporal_gnn.training import Checkpoint

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0)

#lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#    optimizer, mode='min', factor=.1, patience=2, verbose=False,
#    threshold=.001, threshold_mode='rel', cooldown=0, min_lr=1e-5, eps=1e-08)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.94, verbose=False)

checkpoint_file_path = os.path.join('..', 'models', 'checkpoints',
                                    'navigator_metr_la.pth')
checkpoint = Checkpoint(checkpoint_file_path)

EPOCHS = 100

In [23]:
from time import time
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.utils.data import DataLoader

from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.spatial_temporal_gnn.metrics import MAE, RMSE, MAPE
from src.spatial_temporal_gnn.training import Checkpoint
from src.data.data_processing import Scaler

def train(
    model: Navigator, optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader, val_dataloader: DataLoader,
    spatial_temporal_gnn: SpatialTemporalGNN, scaler: Scaler,
    epochs: int, checkpoint: Optional[Checkpoint] = None,
    lr_scheduler: Optional[object] = None,
    reload_best_weights: bool = True) -> Dict[str, np.ndarray]:
    # Get the device that is used for training and querying the model.
    device = model.device

    # Initialize the training criterions.
    mae_criterion = MAE()
    rmse_criterion = RMSE()
    mape_criterion = MAPE()

    # Initialize the histories.
    metrics = ['train_mae', 'train_rmse', 'train_mape', 'val_mae', 'val_rmse',
               'val_mape']
    history = { m: [] for m in metrics }

    # Set model in training mode.
    model.train()

    # Iterate across the epochs.
    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')

        # Remove unused tensors from gpu memory.
        torch.cuda.empty_cache()

        # Initialize the running errors.
        running_train_mae = 0.
        running_train_rmse = 0.
        running_train_mape = 0.

        start_time = time()

        for batch_idx, (x, instances, y) in enumerate(train_dataloader):
            # Increment the number of batch steps.
            batch_steps = batch_idx + 1

            # Get the data.
            x = x.type(torch.float32).to(device=device)
            instances = instances.type(torch.float32).to(device=device)
            y = y.type(torch.float32).to(device=device)
            
            y_repeated = y.unsqueeze(0).repeat(instances.shape[0], 1, 1, 1)
            #print(y_repeated[0] == y)
            # Repeat y for each instance that has shape (Batch, Features)
            #t_repeated = t.unsqueeze(0).repeat(instances.shape[0], 3)

            event_scores = model(instances, y_repeated)

            ev_scores = torch.zeros((x.shape[0], x.shape[1], 1))
            
            for i, x_ in enumerate(instances):
                timestep = int(x_[0].item()); node = int(x_[1].item())
                ev_scores[timestep, node, 0] = event_scores[i]
                
            ev_scores = ev_scores.to(device=device)
            # print(ev_scores.shape, x.shape)
            x_sim = simulate_model(x, ev_scores)
            #print(x_sim.shape)
            x_sim = scaler.scale(x_sim)

            # Compute the Spatial-Temporal GNN model predictions.
            y_pred = spatial_temporal_gnn(x_sim.unsqueeze(0))

            # Un-scale the predictions.
            y_pred = scaler.un_scale(y_pred)

            loss = mae_criterion(y_pred, y.unsqueeze(0))
            
            #print(y)

            # Compute errors and update running errors.
            with torch.no_grad():
                rmse = rmse_criterion(y_pred, y.unsqueeze(0))
                mape = mape_criterion(y_pred, y.unsqueeze(0))

            running_train_mae += loss.item()
            running_train_rmse += rmse.item()
            running_train_mape += mape.item()

            # Zero the gradients.
            optimizer.zero_grad()

            # Use MAE as the loss function for backpropagation.
            loss.backward()
            
            for param in spatial_temporal_gnn.parameters():
                param.grad[:] = 0

            # Update the weights.
            optimizer.step()

            # Get the batch time.
            epoch_time = time() - start_time
            batch_time = epoch_time / batch_steps

            # Print the batch results.
            print(
                f'[{batch_steps}/{len(train_dataloader)}] -',
                f'{epoch_time:.0f}s {batch_time * 1e3:.0f}ms/step -',

                f'train {{ MAE (loss): {running_train_mae / batch_steps:.3g} -',
                f'RMSE: {running_train_rmse / batch_steps:.3g} -',
                f'MAPE: {running_train_mape * 100. / batch_steps:.3g}% }} -',

                f'lr: {optimizer.param_groups[0]["lr"]:.3g} -',
                f'weight decay: {optimizer.param_groups[0]["weight_decay"]}',
                '             ' if batch_steps < len(train_dataloader) else '',
                end='\r')

        # Set the model in evaluation mode.
        model.eval()

        # Get the average training errors and update the history.
        train_mae = running_train_mae / len(train_dataloader)
        train_rmse = running_train_rmse / len(train_dataloader)
        train_mape = running_train_mape / len(train_dataloader)

        history['train_mae'].append(train_mae)
        history['train_rmse'].append(train_rmse)
        history['train_mape'].append(train_mape)

        # Get the validation results and update the history.
        val_results = validate(model, val_dataloader, spatial_temporal_gnn,
                               scaler)
        val_mae, val_rmse, val_mape = val_results

        history['val_mae'].append(val_mae)
        history['val_rmse'].append(val_rmse)
        history['val_mape'].append(val_mape)

        # Save the checkpoints if demanded.
        if checkpoint is not None:
            err_sum = val_mae + val_rmse + val_mape
            checkpoint.save_best(model, optimizer, err_sum)

        # Print the epoch results.
        print(
            f'[{len(train_dataloader)}/{len(train_dataloader)}] -',
            f'{epoch_time:.0f}s -',

            f'train: {{ MAE (loss): {train_mae:.3g} -',
            f'RMSE: {train_rmse:.3g} -',
            f'MAPE: {train_mape * 100.:.3g}% }} -',

            f'val: {{ MAE: {val_mae:.3g} -',
            f'RMSE: {val_rmse:.3g} -',
            f'MAPE: {val_mape * 100.:.3g}% }} -',

            f'lr: {optimizer.param_groups[0]["lr"]:.3g} -',
            f'weight decay: {optimizer.param_groups[0]["weight_decay"]}')

        # Update the learning rate scheduler.
        #lr_scheduler.step(train_mae)

        # Set model in training mode.
        lr_scheduler.step()
        model.train()

    # Load the best weights of the model if demanded.
    if checkpoint is not None and reload_best_weights:
        checkpoint.load_best_weights(model)

    # Set the model in evaluation mode.
    model.eval()

    # Remove unused tensors from gpu memory.
    torch.cuda.empty_cache()

    # Turn the history to numpy arrays.
    for k, v in history.items():
        history[k] = np.array(v)

    return history

def validate(
    model: Navigator, val_dataloader: DataLoader,
    spatial_temporal_gnn: SpatialTemporalGNN, scaler: Scaler
    ) -> Tuple[float, float, float]:
    device = model.device
    torch.cuda.empty_cache()

    # Initialize the validation criterions.
    mae_criterion = MAE()
    rmse_criterion = RMSE()
    mape_criterion = MAPE()

    # Inizialize running errors.
    running_val_mae = 0.
    running_val_rmse = 0.
    running_val_mape = 0.

    with torch.no_grad():
        for x, instances, y in val_dataloader:
            # Get the data.
            x = x.type(torch.float32).to(device=device)
            instances = instances.type(torch.float32).to(device=device)
            y = y.type(torch.float32).to(device=device)
            
            y_repeated = y.unsqueeze(0).repeat(instances.shape[0], 1, 1, 1)

            event_scores = model(instances, y_repeated)

            ev_scores = torch.zeros((x.shape[0], x.shape[1], 1))
            
            for i, x_ in enumerate(instances):
                timestep = int(x_[0].item()); node = int(x_[1].item())
                ev_scores[timestep, node, 0] = event_scores[i]
                
            ev_scores = ev_scores.to(device=device)
            # print(ev_scores.shape, x.shape)
            x_sim = simulate_model(x, ev_scores)
            #print(x_sim.shape)
            x_sim = scaler.scale(x_sim)

            # Compute the Spatial-Temporal GNN model predictions.
            y_pred = spatial_temporal_gnn(x_sim.unsqueeze(0))

            # Un-scale the predictions.
            y_pred = scaler.un_scale(y_pred)

            # Get the prediction errors and update the running errors.
            mae = mae_criterion(y_pred, y.unsqueeze(0))
            rmse = rmse_criterion(y_pred, y.unsqueeze(0))
            mape = mape_criterion(y_pred, y.unsqueeze(0))

            running_val_mae += mae.item()
            running_val_rmse += rmse.item()
            running_val_mape += mape.item()

    # Remove unused tensors from gpu memory.
    torch.cuda.empty_cache()

    # Get the average MAE, RMSE and MAPE scores.
    val_mae = running_val_mae / len(val_dataloader)
    val_rmse = running_val_rmse / len(val_dataloader)
    val_mape = running_val_mape / len(val_dataloader)

    return val_mae, val_rmse, val_mape

In [24]:
history = train(
    model, optimizer, train_loader, val_loader, spatial_temporal_gnn, scaler,
    EPOCHS, checkpoint, lr_scheduler, reload_best_weights=True)

Epoch 1/100
[988/988] - 724s - train: { MAE (loss): 3.06 - RMSE: 5 - MAPE: 9.13% } - val: { MAE: 2.82 - RMSE: 4.72 - MAPE: 7.84% } - lr: 0.01 - weight decay: 0
Epoch 2/100
[988/988] - 722s - train: { MAE (loss): 3.47 - RMSE: 5.66 - MAPE: 10.6% } - val: { MAE: 3.48 - RMSE: 5.57 - MAPE: 10.2% } - lr: 0.0094 - weight decay: 0
Epoch 3/100
[988/988] - 726s - train: { MAE (loss): 3.47 - RMSE: 5.62 - MAPE: 10.6% } - val: { MAE: 3.19 - RMSE: 5.23 - MAPE: 8.95% } - lr: 0.00884 - weight decay: 0
Epoch 4/100
[276/988] - 203s 734ms/step - train { MAE (loss): 3.47 - RMSE: 5.56 - MAPE: 10.5% } - lr: 0.00831 - weight decay: 0              

KeyboardInterrupt: 

In [None]:
#for(i, j) in zip(spatial_temporal_gnn.state_dict(), stgnn_checkpoints['model_state_dict']):
#    if i != j: print(True)

In [None]:
x, ev, y = next(iter(train_loader))

In [None]:
res = model(ev.to(DEVICE).float(), y.unsqueeze(0).repeat(ev.shape[0], 1, 1, 1).to(DEVICE).float())
ev_scores = torch.zeros((x.shape[0], x.shape[1], 1))
            
for i, x_ in enumerate(ev):
    timestep = int(x_[0].item()); node = int(x_[1].item())
    ev_scores[timestep, node, 0] = res[i]

In [None]:
for ev in torch.unique(ev_scores.sigmoid()):
    print(ev)

tensor(2.4888e-05, grad_fn=<UnbindBackward0>)
tensor(2.5055e-05, grad_fn=<UnbindBackward0>)
tensor(2.6995e-05, grad_fn=<UnbindBackward0>)
tensor(2.8503e-05, grad_fn=<UnbindBackward0>)
tensor(2.8654e-05, grad_fn=<UnbindBackward0>)
tensor(2.9176e-05, grad_fn=<UnbindBackward0>)
tensor(3.2264e-05, grad_fn=<UnbindBackward0>)
tensor(3.2876e-05, grad_fn=<UnbindBackward0>)
tensor(3.3349e-05, grad_fn=<UnbindBackward0>)
tensor(3.3994e-05, grad_fn=<UnbindBackward0>)
tensor(3.4439e-05, grad_fn=<UnbindBackward0>)
tensor(3.5410e-05, grad_fn=<UnbindBackward0>)
tensor(3.5991e-05, grad_fn=<UnbindBackward0>)
tensor(3.6055e-05, grad_fn=<UnbindBackward0>)
tensor(3.7073e-05, grad_fn=<UnbindBackward0>)
tensor(3.8583e-05, grad_fn=<UnbindBackward0>)
tensor(4.0125e-05, grad_fn=<UnbindBackward0>)
tensor(4.0885e-05, grad_fn=<UnbindBackward0>)
tensor(4.1972e-05, grad_fn=<UnbindBackward0>)
tensor(4.1987e-05, grad_fn=<UnbindBackward0>)
tensor(4.2031e-05, grad_fn=<UnbindBackward0>)
tensor(4.2752e-05, grad_fn=<Unbind