### 1. Install Dependencies

In [1]:
! pip install configargparse

Collecting configargparse
  Downloading ConfigArgParse-1.7-py3-none-any.whl.metadata (23 kB)
Downloading ConfigArgParse-1.7-py3-none-any.whl (25 kB)
Installing collected packages: configargparse
Successfully installed configargparse-1.7


In [2]:
! pip install torch torchvision torchaudio

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

### 2. Imports

In [3]:
import os
import shutil
import time
from datetime import datetime
import pickle
from google.colab import drive

from abc import ABC, abstractmethod
from collections import OrderedDict

import math
import random
import numpy as np

import torch
from torch.utils.data import Dataset
from torch.autograd import grad
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch import nn

import wandb

import matplotlib
import matplotlib.pyplot as plt
import plotly.express as px
import scipy.io as spio

from tqdm.autonotebook import tqdm
from sklearn import svm

import configargparse

  from tqdm.autonotebook import tqdm


In [4]:
# Report whether PyTorch can see a CUDA device in this runtime (the cell output shows the result).
torch.cuda.is_available()

True

### 3. Connect to Google Drive

In [5]:
# Mount Google Drive at /content/drive; force_remount=True re-mounts even if a mount already exists.
drive.mount("/content/drive",force_remount=True)
# Switch the working directory to the Drive root so later relative paths resolve there.
os.chdir("/content/drive/My Drive")

Mounted at /content/drive


### 4. Utils Code

In [6]:
# uses model input and real boundary fn
class ReachabilityDataset(Dataset):
    """Length-1 dataset whose single item is a freshly sampled training batch.

    Every ``__getitem__`` call draws ``numpoints`` random model-space states in
    ``[-1, 1]`` together with time values, evaluates the boundary (and, for
    ``'brat_hjivi'``, reach/avoid) functions of ``dynamics`` on them, and advances
    the internal pretraining / curriculum counters.

    Args:
        dynamics: object providing ``state_dim``, ``input_dim``, ``loss_type``,
            ``coord_to_input``, ``input_to_coord``, ``boundary_fn`` and, depending on
            the configuration, ``sample_target_state``, ``reach_fn``, ``avoid_fn``.
        numpoints: number of samples per batch.
        pretrain: if truthy, start in pretraining mode (all times equal ``tMin``).
        pretrain_iters: number of pretraining batches before switching to the curriculum.
        tMin, tMax: time interval bounds.
        counter_start, counter_end: the sampled time horizon grows proportionally to
            ``counter / counter_end``; ``counter`` starts at ``counter_start``.
        num_src_samples: number of trailing samples pinned to ``tMin`` outside pretraining.
        num_target_samples: number of trailing states replaced by target-set samples.
    """

    def __init__(self, dynamics, numpoints, pretrain, pretrain_iters, tMin, tMax, counter_start, counter_end, num_src_samples, num_target_samples):
        self.dynamics = dynamics
        self.numpoints = numpoints
        self.pretrain = pretrain
        self.pretrain_counter = 0
        self.pretrain_iters = pretrain_iters
        self.tMin = tMin
        self.tMax = tMax
        self.counter = counter_start
        self.counter_end = counter_end
        self.num_src_samples = num_src_samples
        self.num_target_samples = num_target_samples

    def __len__(self):
        # One item per "epoch": the item itself is a full random batch.
        return 1

    def __getitem__(self, idx):
        """Sample one batch.

        Returns:
            ``({'model_coords': coords}, gt)`` where ``gt`` holds ``'boundary_values'``
            and ``'dirichlet_masks'`` (plus ``'reach_values'`` / ``'avoid_values'`` for
            the ``'brat_hjivi'`` loss type).

        Raises:
            NotImplementedError: if ``dynamics.loss_type`` is neither ``'brt_hjivi'``
                nor ``'brat_hjivi'``.
        """
        # uniformly sample domain and include coordinates where source is non-zero
        model_states = torch.zeros(self.numpoints, self.dynamics.state_dim).uniform_(-1, 1)
        if self.num_target_samples > 0:
            target_state_samples = self.dynamics.sample_target_state(self.num_target_samples)
            model_states[-self.num_target_samples:] = self.dynamics.coord_to_input(torch.cat((torch.zeros(self.num_target_samples, 1), target_state_samples), dim=-1))[:, 1:self.dynamics.state_dim+1]

        if self.pretrain:
            # only sample in time around the initial condition
            times = torch.full((self.numpoints, 1), self.tMin)
        else:
            # slowly grow time values from start time
            times = self.tMin + torch.zeros(self.numpoints, 1).uniform_(0, (self.tMax-self.tMin) * (self.counter / self.counter_end))
            # make sure we always have training samples at the initial time;
            # guard against num_src_samples == 0, where the slice [-0:] would
            # select (and overwrite) every row instead of none
            if self.num_src_samples > 0:
                times[-self.num_src_samples:, 0] = self.tMin
        model_coords = torch.cat((times, model_states), dim=1)
        if self.dynamics.input_dim > self.dynamics.state_dim + 1: # temporary workaround for having to deal with dynamics classes for parametrized models with extra inputs
            model_coords = torch.cat((model_coords, torch.zeros(self.numpoints, self.dynamics.input_dim - self.dynamics.state_dim - 1)), dim=1)

        # convert back once and drop the leading time column; reused by every target function below
        real_states = self.dynamics.input_to_coord(model_coords)[..., 1:]
        boundary_values = self.dynamics.boundary_fn(real_states)
        if self.dynamics.loss_type == 'brat_hjivi':
            reach_values = self.dynamics.reach_fn(real_states)
            avoid_values = self.dynamics.avoid_fn(real_states)

        if self.pretrain:
            dirichlet_masks = torch.ones(model_coords.shape[0]) > 0
        else:
            # only enforce initial conditions around self.tMin
            dirichlet_masks = (model_coords[:, 0] == self.tMin)

        # advance whichever schedule is currently active
        if self.pretrain:
            self.pretrain_counter += 1
        elif self.counter < self.counter_end:
            self.counter += 1

        if self.pretrain and self.pretrain_counter == self.pretrain_iters:
            self.pretrain = False

        if self.dynamics.loss_type == 'brt_hjivi':
            return {'model_coords': model_coords}, {'boundary_values': boundary_values, 'dirichlet_masks': dirichlet_masks}
        elif self.dynamics.loss_type == 'brat_hjivi':
            return {'model_coords': model_coords}, {'boundary_values': boundary_values, 'reach_values': reach_values, 'avoid_values': avoid_values, 'dirichlet_masks': dirichlet_masks}
        else:
            raise NotImplementedError

In [7]:
# TODO: I don't think jacobian is needed here; torch.autograd.grad should be enough, to compute gradients of a scalar value function w.r.t. inputs

# batched jacobian
# y: [..., N], x: [..., M] -> [..., N, M]
def jacobian(y, x):
    """Batched Jacobian of ``y`` with respect to ``x``.

    Args:
        y: tensor of shape ``[..., N]`` computed from ``x``.
        x: tensor of shape ``[..., M]`` that ``y`` was differentiated through.

    Returns:
        ``(jac, status)`` where ``jac`` has shape ``[..., N, M]`` and ``status`` is
        ``-1`` if any entry of ``jac`` is NaN, otherwise ``0``.
    """
    num_outputs = y.shape[-1]
    jac = torch.zeros(*y.shape, x.shape[-1]).to(y.device)

    for out_idx in range(num_outputs):
        # one backward pass per output feature, covering the whole batch at once
        column = y[..., out_idx].view(-1, 1)
        seed = torch.ones_like(column)
        jac[..., out_idx, :] = grad(column, x, seed, create_graph=True)[0]

    status = -1 if torch.any(torch.isnan(jac)) else 0
    return jac, status

In [8]:
class Validator(ABC):
    """Interface for objects that judge ``(coords, values)`` samples."""

    @abstractmethod
    def validate(self, coords, values):
        """Return the validity verdict for the given samples; subclasses must override."""
        raise NotImplementedError

class ValueThresholdValidator(Validator):
    """Accepts samples whose value lies in the closed interval ``[v_min, v_max]``."""

    def __init__(self, v_min, v_max):
        self.v_min, self.v_max = v_min, v_max

    def validate(self, coords, values):
        """Return the elementwise product of the two bound checks; ``coords`` is ignored."""
        above_floor = values >= self.v_min
        below_ceiling = values <= self.v_max
        return above_floor * below_ceiling

class MLPValidator(Validator):
    """Accepts samples whose sigmoid-squashed MLP score lies in ``[o_min, o_max]``.

    Args:
        device: device on which the MLP is evaluated.
        mlp: callable mapping ``[state..., value]`` rows to one logit per row.
        o_min, o_max: inclusive bounds on the sigmoid of the MLP output.
        model, dynamics: stored on the instance; ``validate`` does not consult them.
    """

    def __init__(self, device, mlp, o_min, o_max, model, dynamics):
        self.device = device
        self.mlp = mlp
        self.o_min = o_min
        self.o_max = o_max
        self.model = model
        self.dynamics = dynamics

    def validate(self, coords, values):
        """Score each sample with the MLP and threshold the result.

        Args:
            coords: tensor whose last dimension is ``[time, state...]``; the leading
                time column is dropped before it is fed to the MLP.
            values: 1-D tensor of values, one per row of ``coords``.

        Returns:
            Boolean tensor on ``values.device``.
        """
        # The former forward pass through self.model produced a result that was never
        # read, so it has been dropped.
        inputs = torch.cat((coords[..., 1:].to(self.device), values[:, None].to(self.device)), dim=-1)
        # squeeze only the trailing logit dimension so a one-sample batch keeps shape [1]
        outputs = torch.sigmoid(self.mlp(inputs).squeeze(dim=-1))
        return ((outputs >= self.o_min)*(outputs <=self.o_max)).to(device=values.device)

class MLPConditionedValidator(Validator):
    """Accepts samples by pairing MLP-score bands with per-band value intervals.

    Band ``i`` covers scores in ``(o_levels[i], o_levels[i+1]]``; a sample in that band
    is valid when its value lies in ``[v_levels[i][0], v_levels[i][1]]``.

    Args:
        device: device on which the MLP is evaluated.
        mlp: callable mapping ``[state..., value]`` rows to one logit per row.
        o_levels: band edges for the sigmoid of the MLP output.
        v_levels: one ``(low, high)`` pair per band.
        model, dynamics: stored on the instance; ``validate`` does not consult them.

    Raises:
        AssertionError: if ``len(o_levels) != len(v_levels) + 1``.
    """

    def __init__(self, device, mlp, o_levels, v_levels, model, dynamics):
        self.device = device
        self.mlp = mlp
        self.o_levels = o_levels
        self.v_levels = v_levels
        self.model = model
        self.dynamics = dynamics
        assert len(self.o_levels) == len(self.v_levels) + 1

    def validate(self, coords, values):
        """Return a boolean tensor (on ``values.device``) marking the valid samples."""
        # The former forward pass through self.model produced a result that was never
        # read, so it has been dropped.
        inputs = torch.cat((coords[..., 1:].to(self.device), values[:, None].to(self.device)), dim=-1)
        outputs = torch.sigmoid(self.mlp(inputs).squeeze(dim=-1)).to(device=values.device)
        # start from an all-False boolean mask so the result is boolean even with zero bands
        valids = torch.zeros_like(outputs, dtype=torch.bool)
        for i in range(len(self.o_levels) - 1):
            valids = torch.logical_or(
                valids,
                (outputs > self.o_levels[i])*(outputs <= self.o_levels[i+1])*(values >= self.v_levels[i][0])*(values <= self.v_levels[i][1])
            )
        return valids

class MultiValidator(Validator):
    """Combines several validators by multiplying their verdicts together."""

    def __init__(self, validators):
        self.validators = validators

    def validate(self, coords, values):
        """Return the product of every wrapped validator's result, evaluated in order."""
        combined = self.validators[0].validate(coords, values)
        for validator in self.validators[1:]:
            combined = combined * validator.validate(coords, values)
        return combined

class SampleGenerator(ABC):
    """Interface for objects that produce batches of samples."""

    @abstractmethod
    def sample(self, num_samples):
        """Return ``num_samples`` samples; subclasses must override."""
        raise NotImplementedError

class SliceSampleGenerator(SampleGenerator):
    """Samples states on a slice: fixed dimensions take a constant, free ones are uniform.

    ``slices`` holds one entry per state dimension. ``None`` marks a free dimension that is
    drawn uniformly from ``dynamics.state_test_range()[dim]``; any other entry is written
    into that column unchanged.
    """

    def __init__(self, dynamics, slices):
        self.dynamics = dynamics
        self.slices = slices
        assert self.dynamics.state_dim == len(slices)

    def sample(self, num_samples):
        """Return a ``[num_samples, state_dim]`` tensor of sampled states."""
        samples = torch.zeros(num_samples, self.dynamics.state_dim)
        for dim in range(self.dynamics.state_dim):
            fixed_value = self.slices[dim]
            if fixed_value is not None:
                samples[:, dim] = fixed_value
                continue
            bounds = self.dynamics.state_test_range()[dim]
            samples[:, dim].uniform_(*bounds)
        return samples

import torch
from tqdm import tqdm

# # get the tEarliest in [tMin:tMax:dt] at which the state is still valid
# def get_tEarliest(device, model, dynamics, state, tMin, tMax, dt, validator):
#     with torch.no_grad():
#         tEarliest = torch.full(state.shape[:-1], tMin - 1)
#         model_state = dynamics.normalize_state(state)

#         times_to_try = torch.arange(tMin, tMax + dt, dt)
#         for time_to_try in times_to_try:
#             blank_idx = (tEarliest < tMin)
#             time = torch.full((*state.shape[:-1], 1), time_to_try)
#             model_time = dynamics.normalize_time(time)
#             model_coord = torch.cat((model_time, model_state), dim=-1)[blank_idx]
#             model_result = model({'coords': model_coord.to(device)})
#             value = dynamics.output_to_value(output=model_result['model_out'][..., 0], state=state.to(device)).cpu()
#             valid_idx = validator.validate(torch.cat((time, state), dim=-1), value)
#             tMasked = tEarliest[blank_idx]
#             tMasked[valid_idx] = time_to_try
#             tEarliest[blank_idx] = tMasked
#             if torch.all(tEarliest >= tMin):
#                 break
#         blank_idx = (tEarliest < tMin)
#         if torch.any(blank_idx):
#             print(str(torch.sum(blank_idx)), 'invalid states')
#             tEarliest[blank_idx] = tMax
#         return tEarliest

def scenario_optimization(device, model, policy, dynamics, tMin, tMax, dt, set_type, control_type, scenario_batch_size, sample_batch_size, sample_generator, sample_validator, violation_validator, max_scenarios=None, max_samples=None, max_violations=None, tStart_generator=None):
    """Collect validated start states in batches, roll them out under ``policy`` and score them.

    Candidates come from ``sample_generator``; a candidate is kept when ``sample_validator``
    accepts the value computed from ``model`` for it. Once ``scenario_batch_size`` candidates
    are kept, they are integrated with explicit Euler steps of size ``dt`` using
    ``dynamics.optimal_control`` / ``dynamics.optimal_disturbance`` evaluated on the output of
    ``policy``, and the trajectories are scored with ``dynamics.cost_fn``. Batches are
    accumulated until one of the ``max_*`` limits is hit.

    Args:
        device: device on which ``model``, ``policy`` and the dynamics are evaluated.
        model: callable taking ``{'coords': ...}`` and returning a dict with ``'model_in'``
            and ``'model_out'``; used to value candidate start states.
        policy: callable with the same interface; used during the rollout.
        dynamics: object providing the state/control/disturbance dimensions and the
            dynamics helper methods called below.
        tMin, tMax, dt: time interval and integration step; ``tMax - tMin`` must be a
            multiple of ``dt`` up to a tolerance of 1e-12 on the quotient.
        set_type: ``'BRT'``; ``'BRS'`` is recognised but not implemented.
        control_type: ``'value'``; ``'ttr'`` and ``'init_ttr'`` are recognised but not implemented.
        scenario_batch_size: number of scenarios propagated per batch.
        sample_batch_size: number of candidates drawn per sampling round.
        sample_generator: object with ``sample(n)`` returning ``n`` candidate states.
        sample_validator: validator deciding which candidates become scenarios.
        violation_validator: validator flagging violating ``(states, costs)`` pairs.
        max_scenarios, max_samples, max_violations: termination limits; at least one is required.
        tStart_generator: optional callable returning per-sample start times, which are
            rounded to a multiple of ``dt``; without it every scenario starts at ``tMax``.

    Returns:
        dict with the accumulated ``'times'``, ``'states'``, ``'values'``, ``'costs'``,
        Hamiltonian statistics, ``'violations'``, ``'valid_sample_fraction'``,
        ``'violation_rate'``, the ``'maxed_*'`` flags and ``'batch_state_trajs'`` (the state
        trajectories of the last propagated batch, or ``None`` if the sample limit was hit).

    Raises:
        AssertionError: if an argument check fails.
        NotImplementedError: for ``set_type == 'BRS'`` or a ``control_type`` other than ``'value'``.
    """
    rem = ((tMax-tMin) / dt)%1
    e_tol = 1e-12
    assert rem < e_tol or 1 - rem < e_tol, f'{tMax-tMin} is not divisible by {dt}'
    assert tMax > tMin
    assert set_type in ['BRS', 'BRT']
    if set_type == 'BRS':
        print('confirm correct calculation of true values of trajectories (batch_scenario_costs)')
        raise NotImplementedError
    assert control_type in ['value', 'ttr', 'init_ttr']
    if control_type != 'value':
        # the ttr / init_ttr branches below are commented out, so the rollout would
        # otherwise fail later with a NameError on traj_times
        raise NotImplementedError(f"control_type '{control_type}' is not implemented")
    assert max_scenarios or max_samples or max_violations, 'one of the termination conditions must be used'
    if max_scenarios:
        assert (max_scenarios / scenario_batch_size)%1 == 0, 'max_scenarios is not divisible by scenario_batch_size'
    if max_samples:
        assert (max_samples / sample_batch_size)%1 == 0, 'max_samples is not divisible by sample_batch_size'

    # Number of Euler steps. The divisibility assert above accepts quotients that sit just
    # below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996), so round to the nearest
    # integer instead of truncating, which would silently drop the last step.
    num_steps = int(round((tMax-tMin) / dt))

    # accumulate scenarios
    times = torch.zeros(0, )
    states = torch.zeros(0, dynamics.state_dim)
    values = torch.zeros(0, )
    costs = torch.zeros(0, )
    init_hams = torch.zeros(0, )
    mean_hams = torch.zeros(0, )
    mean_abs_hams = torch.zeros(0, )
    max_abs_hams = torch.zeros(0, )
    min_abs_hams = torch.zeros(0, )

    num_scenarios = 0
    num_samples = 0
    num_violations = 0

    # one progress bar per active termination condition, stacked by position
    pbar_pos = 0
    if max_scenarios:
        scenarios_pbar = tqdm(total=max_scenarios, desc='Scenarios', position=pbar_pos)
        pbar_pos += 1
    if max_samples:
        samples_pbar = tqdm(total=max_samples, desc='Samples', position=pbar_pos)
        pbar_pos += 1
    if max_violations:
        violations_pbar = tqdm(total=max_violations, desc='Violations', position=pbar_pos)
        pbar_pos += 1

    nums_valid_samples = []
    while True:
        if (max_scenarios and (num_scenarios >= max_scenarios)) or (max_violations and (num_violations >= max_violations)):
            break

        batch_scenario_times = torch.zeros(scenario_batch_size, )
        batch_scenario_states = torch.zeros(scenario_batch_size, dynamics.state_dim)
        batch_scenario_values = torch.zeros(scenario_batch_size, )

        num_collected_scenarios = 0
        while num_collected_scenarios < scenario_batch_size:
            if max_samples and (num_samples >= max_samples):
                break
            # sample batch
            if tStart_generator is not None:
                batch_sample_times = tStart_generator(sample_batch_size)
                # need to round to nearest dt
                batch_sample_times = torch.round(batch_sample_times/dt)*dt
            else:
                batch_sample_times = torch.full((sample_batch_size, ), tMax)
            batch_sample_states = dynamics.equivalent_wrapped_state(sample_generator.sample(sample_batch_size))
            batch_sample_coords = torch.cat((batch_sample_times.unsqueeze(-1), batch_sample_states), dim=-1)

            # validate batch
            with torch.no_grad():
                batch_sample_model_results = model({'coords': dynamics.coord_to_input(batch_sample_coords.to(device))})
                batch_sample_values = dynamics.io_to_value(batch_sample_model_results['model_in'].detach(), batch_sample_model_results['model_out'].squeeze(dim=-1).detach())
            batch_valid_sample_idxs = torch.where(sample_validator.validate(batch_sample_coords, batch_sample_values))[0].detach().cpu()

            # store valid samples (only as many as still fit into the scenario batch)
            num_valid_samples = len(batch_valid_sample_idxs)
            start_idx = num_collected_scenarios
            end_idx = min(start_idx + num_valid_samples, scenario_batch_size)
            batch_scenario_times[start_idx:end_idx] = batch_sample_times[batch_valid_sample_idxs][:end_idx-start_idx]
            batch_scenario_states[start_idx:end_idx] = batch_sample_states[batch_valid_sample_idxs][:end_idx-start_idx]
            batch_scenario_values[start_idx:end_idx] = batch_sample_values[batch_valid_sample_idxs][:end_idx-start_idx]

            # update counters
            num_samples += sample_batch_size
            if max_samples:
                samples_pbar.update(sample_batch_size)
            num_collected_scenarios += end_idx - start_idx
            nums_valid_samples.append(num_valid_samples)
        if max_samples and (num_samples >= max_samples):
            # sample budget exhausted: the current (possibly partial) batch is discarded
            break

        # propagate scenarios
        state_trajs = torch.zeros(scenario_batch_size, num_steps + 1, dynamics.state_dim)
        ctrl_trajs = torch.zeros(scenario_batch_size, num_steps, dynamics.control_dim)
        dstb_trajs = torch.zeros(scenario_batch_size, num_steps, dynamics.disturbance_dim)
        ham_trajs = torch.zeros(scenario_batch_size, num_steps)

        state_trajs[:, 0, :] = batch_scenario_states
        for k in tqdm(range(num_steps), desc='Trajectory Propagation', position=pbar_pos, leave=False):
            if control_type == 'value':
                # the policy is queried at the time counted down from tMax
                traj_time = tMax - k*dt
                traj_times = torch.full((scenario_batch_size, ), traj_time)
            # elif control_type == 'ttr':
            #     traj_times = get_tEarliest(model=model, dynamics=dynamics, state=state_trajs[:, k], tMin=tMin, tMax=traj_time, dt=dt, validator=sample_validator)
            # elif control_type == 'init_ttr':
            #     if k == 0:
            #         init_traj_times = get_tEarliest(model=model, dynamics=dynamics, state=state_trajs[:, k], tMin=tMin, tMax=traj_time, dt=dt, validator=sample_validator)
            #     traj_times = torch.maximum(init_traj_times - k*dt, torch.tensor(tMin)) # check whether this is the best thing to do for init_ttr
            traj_coords = torch.cat((traj_times.unsqueeze(-1), state_trajs[:, k]), dim=-1)
            traj_policy_results = policy({'coords': dynamics.coord_to_input(traj_coords.to(device))})
            traj_dvs = dynamics.io_to_dv(traj_policy_results['model_in'], traj_policy_results['model_out'].squeeze(dim=-1)).detach()

            # TODO: I do not think there is actually any reason to store these trajs? Could save space by removing these.
            ctrl_trajs[:, k] = dynamics.optimal_control(traj_coords[:, 1:].to(device), traj_dvs[..., 1:].to(device))
            dstb_trajs[:, k] = dynamics.optimal_disturbance(traj_coords[:, 1:].to(device), traj_dvs[..., 1:].to(device))
            ham_trajs[:, k] = dynamics.hamiltonian(traj_coords[:, 1:].to(device), traj_dvs[..., 1:].to(device))

            if tStart_generator is not None: # freeze states whose start time has not been reached yet
                is_frozen = batch_scenario_times < traj_times
                is_unfrozen = torch.logical_not(is_frozen)
                state_trajs[is_frozen, k+1] = state_trajs[is_frozen, k]
                state_trajs[is_unfrozen, k+1] = dynamics.equivalent_wrapped_state(state_trajs[is_unfrozen, k].to(device) + dt*dynamics.dsdt(state_trajs[is_unfrozen, k].to(device), ctrl_trajs[is_unfrozen, k].to(device), dstb_trajs[is_unfrozen, k].to(device))).cpu()
            else:
                state_trajs[:, k+1] = dynamics.equivalent_wrapped_state(state_trajs[:, k].to(device) + dt*dynamics.dsdt(state_trajs[:, k].to(device), ctrl_trajs[:, k].to(device), dstb_trajs[:, k].to(device)))

        # compute batch_scenario_costs
        # TODO: need to handle the case of using tStart_generator when extending a trajectory by a frozen initial state will inadvertently affect cost computation (the min lx cost formulation is unaffected, but other cost formulations might care)
        if set_type == 'BRT':
            batch_scenario_costs = dynamics.cost_fn(state_trajs.to(device))
        elif set_type == 'BRS':
            # currently unreachable: 'BRS' raises NotImplementedError at the top of the function
            if control_type == 'init_ttr': # is this correct for init_ttr?
                batch_scenario_costs =  dynamics.boundary_fn(state_trajs.to(device))[:, (init_traj_times - tMin) / dt]
            elif control_type == 'value':
                batch_scenario_costs =  dynamics.boundary_fn(state_trajs.to(device))[:, -1]
            else:
                raise NotImplementedError # what is the correct thing to do for ttr?

        # compute batch_scenario_init_hams, batch_scenario_mean_hams, batch_scenario_mean_abs_hams, batch_scenario_max_abs_hams, batch_scenario_min_abs_hams
        batch_scenario_init_hams = ham_trajs[:, 0]
        batch_scenario_mean_hams = torch.mean(ham_trajs, dim=-1)
        batch_scenario_mean_abs_hams = torch.mean(torch.abs(ham_trajs), dim=-1)
        batch_scenario_max_abs_hams = torch.max(torch.abs(ham_trajs), dim=-1).values
        batch_scenario_min_abs_hams = torch.min(torch.abs(ham_trajs), dim=-1).values

        # store scenarios
        times = torch.cat((times, batch_scenario_times.cpu()), dim=0)
        states = torch.cat((states, batch_scenario_states.cpu()), dim=0)
        values = torch.cat((values, batch_scenario_values.cpu()), dim=0)
        costs = torch.cat((costs, batch_scenario_costs.cpu()), dim=0)
        init_hams = torch.cat((init_hams, batch_scenario_init_hams.cpu()), dim=0)
        mean_hams = torch.cat((mean_hams, batch_scenario_mean_hams.cpu()), dim=0)
        mean_abs_hams = torch.cat((mean_abs_hams, batch_scenario_mean_abs_hams.cpu()), dim=0)
        max_abs_hams = torch.cat((max_abs_hams, batch_scenario_max_abs_hams.cpu()), dim=0)
        min_abs_hams = torch.cat((min_abs_hams, batch_scenario_min_abs_hams.cpu()), dim=0)

        # update counters
        num_scenarios += scenario_batch_size
        if max_scenarios:
            scenarios_pbar.update(scenario_batch_size)
        num_new_violations = int(torch.sum(violation_validator.validate(batch_scenario_states, batch_scenario_costs)))
        num_violations += num_new_violations
        if max_violations:
            violations_pbar.update(num_new_violations)

    if max_scenarios:
        scenarios_pbar.close()
    if max_samples:
        samples_pbar.close()
    if max_violations:
        violations_pbar.close()

    violations = violation_validator.validate(states, costs)

    return {
        'times': times,
        'states': states,
        'values': values,
        'costs': costs,
        'init_hams': init_hams,
        'init_abs_hams': torch.abs(init_hams),
        'mean_hams': mean_hams,
        'mean_abs_hams': mean_abs_hams,
        'max_abs_hams': max_abs_hams,
        'min_abs_hams': min_abs_hams,
        'violations': violations,
        'valid_sample_fraction': torch.mean(torch.tensor(nums_valid_samples, dtype=float))/sample_batch_size,
        'violation_rate': 0 if not num_scenarios else num_violations / num_scenarios,
        'maxed_scenarios': (max_scenarios is not None) and num_scenarios >= max_scenarios,
        'maxed_samples': (max_samples is not None) and num_samples >= max_samples,
        'maxed_violations': (max_violations is not None) and num_violations >= max_violations,
        'batch_state_trajs': None if (max_samples and (num_samples >= max_samples)) else state_trajs,
    }

def target_fraction(device, model, dynamics, t, sample_validator, target_validator, num_samples, batch_size):
    """Fraction of ``sample_validator``-accepted states that ``target_validator`` also accepts.

    States are drawn uniformly from ``dynamics.state_test_range()`` at time ``t`` in batches
    of ``batch_size`` until at least ``num_samples`` of them pass ``sample_validator``; the
    first ``num_samples`` accepted states are then judged by ``target_validator``.

    Returns:
        ``torch.sum(valids) / num_samples`` for the target validator's verdicts.
    """
    with torch.no_grad():
        # seed the accumulators with empty tensors so concatenation works from the first batch
        kept_states = torch.zeros(0, dynamics.state_dim)
        kept_values = torch.zeros(0, )

        while len(kept_states) < num_samples:
            # draw a candidate batch at time t
            time_column = torch.full((batch_size, 1), t)
            candidates = torch.zeros(batch_size, dynamics.state_dim)
            for axis in range(dynamics.state_dim):
                candidates[:, axis].uniform_(*dynamics.state_test_range()[axis])
            candidates = dynamics.equivalent_wrapped_state(candidates)
            candidate_coords = torch.cat((time_column, candidates), dim=-1)

            # value the candidates and ask the sample validator which ones to keep
            results = model({'coords': dynamics.coord_to_input(candidate_coords.to(device))})
            model_out = results['model_out'].squeeze(dim=-1)
            candidate_values = dynamics.io_to_value(results['model_in'], model_out).detach()
            keep_mask = sample_validator.validate(candidate_coords, candidate_values).detach().cpu()

            # append only the accepted rows
            kept_states = torch.cat((kept_states, candidates[keep_mask].cpu()), dim=0)
            kept_values = torch.cat((kept_values, candidate_values[keep_mask].cpu()), dim=0)

        # the last batch may overshoot; trim to exactly num_samples
        kept_states = kept_states[:num_samples]
        kept_values = kept_values[:num_samples]
        kept_coords = torch.cat((torch.full((num_samples, 1), t), kept_states), dim=-1)
        verdicts = target_validator.validate(kept_coords.to(device), kept_values.to(device))
    return torch.sum(verdicts) / num_samples

class MLP(torch.nn.Module):
    """Four-layer ReLU perceptron mapping ``input_size`` features to a single output.

    Layer widths are ``input_size -> 2*input_size -> input_size -> input_size -> 1``; no
    activation is applied to the final output.
    """

    def __init__(self, input_size):
        super(MLP, self).__init__()

        wide = int(2*input_size)
        narrow = int(input_size)

        # layers are created in forward order; attribute names are kept as l1..l4 / a1..a3
        self.l1 = torch.nn.Linear(input_size, wide)
        self.a1 = torch.nn.ReLU()
        self.l2 = torch.nn.Linear(wide, narrow)
        self.a2 = torch.nn.ReLU()
        self.l3 = torch.nn.Linear(narrow, narrow)
        self.a3 = torch.nn.ReLU()
        self.l4 = torch.nn.Linear(narrow, 1)

    def forward(self, x):
        """Apply the three hidden Linear+ReLU stages followed by the output layer."""
        hidden = self.a1(self.l1(x))
        hidden = self.a2(self.l2(hidden))
        hidden = self.a3(self.l3(hidden))
        return self.l4(hidden)



def sample_values(device, model, dynamics, t, num_samples, batch_size):
    """Evaluate the learned value at ``num_samples`` uniformly sampled states at time ``t``.

    States are drawn from ``dynamics.state_test_range()`` in batches of ``batch_size``,
    passed through ``model`` and converted with ``dynamics.io_to_value``.

    Args:
        device: device on which ``model`` is evaluated.
        model: callable taking ``{'coords': ...}`` and returning ``'model_in'`` / ``'model_out'``.
        dynamics: object providing ``state_dim``, ``state_test_range``,
            ``equivalent_wrapped_state``, ``coord_to_input`` and ``io_to_value``.
        t: time at which every state is evaluated.
        num_samples: number of values to return.
        batch_size: number of states evaluated per model call.

    Returns:
        CPU tensor holding the first ``num_samples`` sampled values.

    Raises:
        ValueError: if samples are requested with a non-positive ``batch_size``
            (the sampling loop could never finish).
    """
    if num_samples > 0 and batch_size <= 0:
        raise ValueError('batch_size must be positive when num_samples > 0')

    with torch.no_grad():
        values = torch.zeros(0, )
        # Only the number of drawn states is needed; the states themselves (and the
        # coordinate tensor that used to be rebuilt from them) were never returned.
        num_drawn = 0

        while num_drawn < num_samples:
            # sample batch
            batch_times = torch.full((batch_size, 1), t)
            batch_states = torch.zeros(batch_size, dynamics.state_dim)
            for dim in range(dynamics.state_dim):
                batch_states[:, dim].uniform_(*dynamics.state_test_range()[dim])
            batch_states = dynamics.equivalent_wrapped_state(batch_states)
            batch_coords = torch.cat((batch_times, batch_states), dim=-1)

            batch_model_results = model({'coords': dynamics.coord_to_input(batch_coords.to(device))})
            batch_values = dynamics.io_to_value(batch_model_results['model_in'], batch_model_results['model_out'].squeeze(dim=-1)).detach()

            # store batch
            values = torch.cat((values, batch_values.cpu()), dim=0)
            num_drawn += len(batch_states)

        # the last batch may overshoot; trim to exactly num_samples
        values = values[:num_samples]
    return values


In [9]:
# uses real units
def init_brt_hjivi_loss(dynamics, minWith, dirichlet_loss_divisor):
    """Build the loss closure for the ``'brt_hjivi'`` loss type.

    Args:
        dynamics: object providing ``hamiltonian(state, dvds)`` and ``deepreach_model``.
        minWith: ``'zero'`` clamps the Hamiltonian to be non-positive; ``'target'`` takes the
            elementwise maximum of the PDE residual and ``value - boundary_value``; any
            other setting leaves the residual unchanged.
        dirichlet_loss_divisor: divisor applied to the summed Dirichlet term.

    Returns:
        ``brt_hjivi_loss(state, value, dvdt, dvds, boundary_value, dirichlet_mask, output)``,
        which returns a dict with ``'diff_constraint_hom'`` and, except for the ``'exact'``
        model outside pretraining, ``'dirichlet'``.
    """
    def brt_hjivi_loss(state, value, dvdt, dvds, boundary_value, dirichlet_mask, output):
        # an all-True Dirichlet mask is what marks a pretraining batch
        is_pretraining = bool(torch.all(dirichlet_mask))
        if is_pretraining:
            # pretraining loss: no PDE residual; keep the placeholder on the same
            # device/dtype as the other loss terms instead of a CPU float32 tensor
            diff_constraint_hom = torch.zeros(1, dtype=value.dtype, device=value.device)
        else:
            ham = dynamics.hamiltonian(state, dvds)
            if minWith == 'zero':
                ham = torch.clamp(ham, max=0.0)

            diff_constraint_hom = dvdt - ham
            if minWith == 'target':
                diff_constraint_hom = torch.max(
                    diff_constraint_hom, value - boundary_value)
        dirichlet = value[dirichlet_mask] - boundary_value[dirichlet_mask]
        if dynamics.deepreach_model == 'exact':
            if is_pretraining:
                # pretraining: the raw network output itself is driven towards zero
                dirichlet = output.squeeze(dim=-1)[dirichlet_mask]-0.0
            else:
                # no Dirichlet term is reported for the 'exact' model outside pretraining
                return {'diff_constraint_hom': torch.abs(diff_constraint_hom).sum()}

        return {'dirichlet': torch.abs(dirichlet).sum() / dirichlet_loss_divisor,
                'diff_constraint_hom': torch.abs(diff_constraint_hom).sum()}

    return brt_hjivi_loss
def init_brat_hjivi_loss(dynamics, minWith, dirichlet_loss_divisor):
    """Build the loss closure for the ``'brat_hjivi'`` (reach-avoid) loss type.

    Args:
        dynamics: object providing ``hamiltonian(state, dvds)`` and ``deepreach_model``.
        minWith: ``'zero'`` clamps the Hamiltonian to be non-positive; ``'target'`` applies
            ``min(max(residual, value - reach_value), value + avoid_value)`` elementwise;
            any other setting leaves the residual unchanged.
        dirichlet_loss_divisor: divisor applied to the summed Dirichlet term.

    Returns:
        ``brat_hjivi_loss(state, value, dvdt, dvds, boundary_value, reach_value, avoid_value,
        dirichlet_mask, output)``, which returns a dict with ``'diff_constraint_hom'`` and,
        except for the ``'exact'`` model outside pretraining, ``'dirichlet'``.
    """
    def brat_hjivi_loss(state, value, dvdt, dvds, boundary_value, reach_value, avoid_value, dirichlet_mask, output):
        # an all-True Dirichlet mask is what marks a pretraining batch
        is_pretraining = bool(torch.all(dirichlet_mask))
        if is_pretraining:
            # pretraining loss: no PDE residual; keep the placeholder on the same
            # device/dtype as the other loss terms instead of a CPU float32 tensor
            diff_constraint_hom = torch.zeros(1, dtype=value.dtype, device=value.device)
        else:
            ham = dynamics.hamiltonian(state, dvds)
            if minWith == 'zero':
                ham = torch.clamp(ham, max=0.0)

            diff_constraint_hom = dvdt - ham
            if minWith == 'target':
                diff_constraint_hom = torch.min(
                    torch.max(diff_constraint_hom, value - reach_value), value + avoid_value)

        dirichlet = value[dirichlet_mask] - boundary_value[dirichlet_mask]
        if dynamics.deepreach_model == 'exact':
            if is_pretraining:
                # pretraining: the raw network output itself is driven towards zero
                dirichlet = output.squeeze(dim=-1)[dirichlet_mask]-0.0
            else:
                # no Dirichlet term is reported for the 'exact' model outside pretraining
                return {'diff_constraint_hom': torch.abs(diff_constraint_hom).sum()}
        return {'dirichlet': torch.abs(dirichlet).sum() / dirichlet_loss_divisor,
                'diff_constraint_hom': torch.abs(diff_constraint_hom).sum()}
    return brat_hjivi_loss

In [10]:
"""
MIT License

Copyright (c) 2020 Vincent Sitzmann

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

class BatchLinear(nn.Linear):
    '''A linear layer'''
    # The class docstring is replaced by nn.Linear's so help() shows the full parameter docs.
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        # params: optional mapping with a 'weight' entry and an optional 'bias' entry that
        # overrides the layer's own parameters; the weight may carry leading batch dimensions.
        if params is None:
            params = OrderedDict(self.named_parameters())

        bias = params.get('bias', None)
        weight = params['weight']

        # swap the last two weight dimensions, leaving any leading batch dimensions in place
        output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2))
        # a layer built with bias=False (or params without 'bias') has nothing to add
        if bias is not None:
            output += bias.unsqueeze(-2)
        return output


class Sine(nn.Module):
    """Sine activation with a fixed frequency factor of 30: ``sin(30 * input)``."""

    # This was previously spelled `__init`, which Python name-mangles into an unused
    # private method instead of treating it as the constructor.
    def __init__(self):
        super().__init__()

    def forward(self, input):
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        return torch.sin(30 * input)


class FCBlock(nn.Module):
    '''Stack of BatchLinear layers sharing one nonlinearity.

    The network is an input layer, ``num_hidden_layers`` hidden layers of width
    ``hidden_features`` and an output layer; the output layer omits the nonlinearity
    when ``outermost_linear`` is true. The weight initialiser is chosen from the
    nonlinearity name unless ``weight_init`` is given.
    '''

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
                 outermost_linear=False, nonlinearity='relu', weight_init=None):
        super().__init__()

        self.first_layer_init = None

        # name -> (activation module, default weight init, optional first-layer-only init)
        nls_and_inits = {'sine':(Sine(), sine_init, first_layer_sine_init),
                         'relu':(nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid':(nn.Sigmoid(), init_weights_xavier, None),
                         'tanh':(nn.Tanh(), init_weights_xavier, None),
                         'selu':(nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus':(nn.Softplus(), init_weights_normal, None),
                         'elu':(nn.ELU(inplace=True), init_weights_elu, None)}

        nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        # an explicitly passed initialiser takes precedence over the per-nonlinearity default
        self.weight_init = nl_weight_init if weight_init is None else weight_init

        # input layer, then the hidden layers, then the output layer (created in that order)
        layers = [nn.Sequential(BatchLinear(in_features, hidden_features), nl)]
        for _ in range(num_hidden_layers):
            layers.append(nn.Sequential(BatchLinear(hidden_features, hidden_features), nl))

        output_linear = BatchLinear(hidden_features, out_features)
        if outermost_linear:
            layers.append(nn.Sequential(output_linear))
        else:
            layers.append(nn.Sequential(output_linear, nl))

        self.net = nn.Sequential(*layers)
        if self.weight_init is not None:
            self.net.apply(self.weight_init)

        # some nonlinearities use a dedicated scheme for the first layer only
        if first_layer_init is not None:
            self.net[0].apply(first_layer_init)

    def forward(self, coords, params=None, **kwargs):
        # params is accepted for interface compatibility; the layers use their own parameters
        if params is None:
            params = OrderedDict(self.named_parameters())

        return self.net(coords)


class SingleBVPNet(nn.Module):
    '''A canonical representation network for a BVP.

    Wraps a single ``FCBlock`` whose outermost layer is linear and exposes it through a
    dictionary-in / dictionary-out ``forward``.
    '''

    def __init__(self, out_features=1, type='sine', in_features=2,
                 mode='mlp', hidden_features=256, num_hidden_layers=3, **kwargs):
        '''Build the underlying ``FCBlock``.

        Args:
            out_features: width of the network output.
            type: nonlinearity name forwarded to ``FCBlock`` as ``nonlinearity``.
            in_features: width of the network input.
            mode: stored on the instance as ``self.mode``; not otherwise used by this class.
            hidden_features: width of every hidden layer.
            num_hidden_layers: number of hidden layers forwarded to ``FCBlock``.
            **kwargs: accepted and ignored.
        '''
        super().__init__()
        self.mode = mode
        self.net = FCBlock(in_features=in_features, out_features=out_features, num_hidden_layers=num_hidden_layers,
                           hidden_features=hidden_features, outermost_linear=True, nonlinearity=type)
        # Prints the module structure once at construction time.
        print(self)

    def forward(self, model_input, params=None):
        '''Evaluate the network on ``model_input['coords']``.

        Args:
            model_input: mapping with a ``'coords'`` tensor.
            params: accepted for interface compatibility only; it is not consulted.

        Returns:
            dict with ``'model_in'`` (the detached, gradient-enabled copy of the coordinates
            that was actually fed to the network) and ``'model_out'`` (the network output).
        '''
        # Enables us to compute gradients w.r.t. coordinates
        # TODO: should not need to .clone().detach().requires_grad_(True); instead, use .retain_grad() on input in calling script
        # otherwise, .detach() removes input from the graph so grad cannot propagate back end-to-end, e.g., percept -> NN -> state estimation (input)
        coords_org = model_input['coords'].clone().detach().requires_grad_(True)

        output = self.net(coords_org)
        return {'model_in': coords_org, 'model_out': output}


########################
# Initialization methods
def init_weights_normal(m):
    '''Kaiming-normal (fan-in, ReLU) weight initialisation for a single module.

    Only modules whose exact type is ``BatchLinear`` or ``nn.Linear`` and that carry a
    ``weight`` attribute are touched; anything else is left alone, so the function can be
    handed to ``nn.Module.apply``.
    '''
    is_linear_layer = type(m) in (BatchLinear, nn.Linear)
    if is_linear_layer and hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')


def init_weights_selu(m):
    '''Normal weight initialisation with std ``1 / sqrt(fan_in)`` for a single module.

    Only modules whose exact type is ``BatchLinear`` or ``nn.Linear`` and that carry a
    ``weight`` attribute are touched.
    '''
    is_linear_layer = type(m) in (BatchLinear, nn.Linear)
    if not (is_linear_layer and hasattr(m, 'weight')):
        return
    # fan-in is read from the last dimension of the weight tensor
    fan_in = m.weight.size(-1)
    nn.init.normal_(m.weight, std=1 / math.sqrt(fan_in))


def init_weights_elu(m):
    '''Normal weight initialisation with std ``sqrt(1.5505188080679277) / sqrt(fan_in)``.

    Only modules whose exact type is ``BatchLinear`` or ``nn.Linear`` and that carry a
    ``weight`` attribute are touched.
    '''
    is_linear_layer = type(m) in (BatchLinear, nn.Linear)
    if not (is_linear_layer and hasattr(m, 'weight')):
        return
    # fan-in is read from the last dimension of the weight tensor
    fan_in = m.weight.size(-1)
    # NOTE(review): the numerator constant is presumably an ELU-specific variance gain - confirm its source
    nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(fan_in))


def init_weights_xavier(m):
    '''Xavier-normal weight initialisation for a single module.

    Only modules whose exact type is ``BatchLinear`` or ``nn.Linear`` and that carry a
    ``weight`` attribute are touched.
    '''
    is_linear_layer = type(m) in (BatchLinear, nn.Linear)
    if is_linear_layer and hasattr(m, 'weight'):
        nn.init.xavier_normal_(m.weight)


def sine_init(m):
    '''Uniform weight initialisation for sine layers.

    Any module with a ``weight`` attribute gets its weights redrawn in place from
    ``U(-sqrt(6 / fan_in) / 30, sqrt(6 / fan_in) / 30)``; modules without weights are
    skipped. No gradient is recorded for the in-place update.
    '''
    if not hasattr(m, 'weight'):
        return
    with torch.no_grad():
        # fan-in is read from the last dimension of the weight tensor
        fan_in = m.weight.size(-1)
        # See supplement Sec. 1.5 for discussion of factor 30
        limit = np.sqrt(6 / fan_in) / 30
        m.weight.uniform_(-limit, limit)


def first_layer_sine_init(m):
    '''Uniform weight initialisation for the first sine layer.

    Any module with a ``weight`` attribute gets its weights redrawn in place from
    ``U(-1 / fan_in, 1 / fan_in)``; modules without weights are skipped. No gradient is
    recorded for the in-place update.
    '''
    if not hasattr(m, 'weight'):
        return
    with torch.no_grad():
        # fan-in is read from the last dimension of the weight tensor
        fan_in = m.weight.size(-1)
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        # (the factor itself is not applied in this function)
        limit = 1 / fan_in
        m.weight.uniform_(-limit, limit)


### 5. Dynamics Code

In [11]:
# during training, states will be sampled uniformly by each state dimension from the model-unit -1 to 1 range (for training stability),
# which may or may not correspond to proper test ranges
# note that coord refers to [time, *state], and input refers to whatever is fed directly to the model (often [time, *state, params])
# in the future, code will need to be fixed to correctly handle parameterized models
class Dynamics(ABC):
    '''Abstract description of a dynamical system together with its model-unit scaling.

    The constructor stores the problem type, the dimensions and the affine normalisation
    constants; the concrete methods convert between model units (what the network sees and
    emits) and real units. Subclasses supply the system itself through the abstract methods.

    Conventions used by the conversions below: index 0 of the last axis of a coord/input is
    time, and the following ``state_dim`` entries are the state.
    '''

    def __init__(self,
                 loss_type:str, set_mode:str,
                 state_dim:int, input_dim:int,
                 control_dim:int, disturbance_dim:int,
                 state_mean:list, state_var:list,
                 value_mean:float, value_var:float, value_normto:float,
                 deepreach_model:str):
        '''Store the configuration and validate it.

        Args:
            loss_type: ``'brt_hjivi'`` or ``'brat_hjivi'``; the latter requires the subclass
                to define callable ``reach_fn`` and ``avoid_fn``.
            set_mode: ``'reach'`` or ``'avoid'``.
            state_dim, input_dim, control_dim, disturbance_dim: stored as given.
            state_mean, state_var: per-state offset and scale (length ``state_dim``) used as
                ``real = model * state_var + state_mean``; kept as tensors.
            value_mean, value_var, value_normto: scalars used to rescale the model output.
            deepreach_model: ``"diff"`` and ``"exact"`` select boundary-function-based value
                parameterisations; any other string selects the plain affine rescaling.

        Raises:
            AssertionError: on an unknown ``loss_type``/``set_mode``, on non-callable
                reach/avoid functions for ``'brat_hjivi'``, or when a state descriptor does
                not have ``state_dim`` entries.
        '''
        # Problem definition.
        self.loss_type, self.set_mode = loss_type, set_mode
        # Dimensions.
        self.state_dim, self.input_dim = state_dim, input_dim
        self.control_dim, self.disturbance_dim = control_dim, disturbance_dim
        # State normalisation, kept as tensors so they broadcast against batched states.
        self.state_mean = torch.tensor(state_mean)
        self.state_var = torch.tensor(state_var)
        # Value normalisation.
        self.value_mean, self.value_var, self.value_normto = value_mean, value_var, value_normto
        self.deepreach_model = deepreach_model

        assert self.loss_type in ['brt_hjivi', 'brat_hjivi'], f'loss type {self.loss_type} not recognized'
        if self.loss_type == 'brat_hjivi':
            assert callable(self.reach_fn) and callable(self.avoid_fn)
        assert self.set_mode in ['reach', 'avoid'], f'set mode {self.set_mode} not recognized'
        for state_descriptor in (self.state_mean, self.state_var):
            assert len(state_descriptor) == self.state_dim, 'state descriptor dimension does not equal state dimension, ' + str(len(state_descriptor)) + ' != ' + str(self.state_dim)

    # ALL METHODS ARE BATCH COMPATIBLE

    # MODEL-UNIT CONVERSIONS (TODO: refactor into separate model-unit conversion class?)

    def input_to_coord(self, input):
        '''Convert a model input to a real coord: time is copied, states are de-normalised.'''
        device = input.device
        scale = self.state_var.to(device=device)
        offset = self.state_mean.to(device=device)
        real = input.clone()
        real[..., 1:] = (input[..., 1:] * scale) + offset
        return real

    def coord_to_input(self, coord):
        '''Convert a real coord to a model input: time is copied, states are normalised.'''
        device = coord.device
        offset = self.state_mean.to(device=device)
        scale = self.state_var.to(device=device)
        normalised = coord.clone()
        normalised[..., 1:] = (coord[..., 1:] - offset) / scale
        return normalised

    def io_to_value(self, input, output):
        '''Convert a model input/output pair to the real value.

        ``"diff"`` and ``"exact"`` add the boundary function evaluated at the real state to
        the rescaled output (``"exact"`` additionally multiplies the output by time, index 0
        of ``input``); every other model adds ``value_mean`` instead.
        '''
        model = self.deepreach_model
        if model == "diff":
            learned = output * self.value_var / self.value_normto
        elif model == "exact":
            learned = output * input[..., 0] * self.value_var / self.value_normto
        else:
            return (output * self.value_var / self.value_normto) + self.value_mean
        return learned + self.boundary_fn(self.input_to_coord(input)[..., 1:])

    def io_to_dv(self, input, output):
        '''Convert a model input/output pair to the real value gradient.

        Returns:
            Tensor whose last axis is ``[dV/dt, dV/dstate...]`` in real units.
        '''
        # Gradient of the raw model output w.r.t. the raw model input.
        grad_io = jacobian(output.unsqueeze(dim=-1), input)[0].squeeze(dim=-2)
        grad_time, grad_state = grad_io[..., 0], grad_io[..., 1:]

        # Chain-rule factors undoing the output and state normalisation.
        value_scale = self.value_var / self.value_normto
        state_scale = self.value_var / self.value_normto / self.state_var.to(device=grad_io.device)

        model = self.deepreach_model
        if model == "diff" or model == "exact":
            if model == "diff":
                dvdt = value_scale * grad_time
                learned_dvds = state_scale * grad_state
            else:
                # value = t * output * scale + boundary, hence the product rule in time
                dvdt = value_scale * \
                    (input[..., 0]*grad_time + output)
                learned_dvds = state_scale * grad_state * input[..., 0].unsqueeze(-1)
            # Both parameterisations add the boundary function, so its state gradient is added too.
            state = self.input_to_coord(input)[..., 1:]
            boundary_dvds = jacobian(self.boundary_fn(state).unsqueeze(dim=-1), state)[0].squeeze(dim=-2)
            dvds = learned_dvds + boundary_dvds
        else:
            dvdt = value_scale * grad_time
            dvds = state_scale * grad_state

        return torch.cat((dvdt.unsqueeze(dim=-1), dvds), dim=-1)

    # ALL FOLLOWING METHODS USE REAL UNITS

    @abstractmethod
    def state_test_range(self):
        '''Return the per-state test range in real units.'''
        raise NotImplementedError

    @abstractmethod
    def equivalent_wrapped_state(self, state):
        '''Return a state equivalent to ``state`` after any wrapping the system defines.'''
        raise NotImplementedError

    @abstractmethod
    def dsdt(self, state, control, disturbance):
        '''Return the state time derivative for the given control and disturbance.'''
        raise NotImplementedError

    @abstractmethod
    def boundary_fn(self, state):
        '''Return the boundary function evaluated at ``state``.'''
        raise NotImplementedError

    @abstractmethod
    def sample_target_state(self, num_samples):
        '''Return ``num_samples`` sampled target states.'''
        raise NotImplementedError

    @abstractmethod
    def cost_fn(self, state_traj):
        '''Return the cost of a state trajectory.'''
        raise NotImplementedError

    @abstractmethod
    def hamiltonian(self, state, dvds):
        '''Return the Hamiltonian at ``state`` for the value gradient ``dvds``.'''
        raise NotImplementedError

    @abstractmethod
    def optimal_control(self, state, dvds):
        '''Return the optimal control at ``state`` for the value gradient ``dvds``.'''
        raise NotImplementedError

    @abstractmethod
    def optimal_disturbance(self, state, dvds):
        '''Return the optimal disturbance at ``state`` for the value gradient ``dvds``.'''
        raise NotImplementedError

    @abstractmethod
    def plot_config(self):
        '''Return the plotting configuration for this system.'''
        raise NotImplementedError

class ParameterizedVertDrone2D(Dynamics):
    '''Vertical drone with state ``[v, z, k]`` where ``k`` is a constant input multiplier.

    Configured as an ``'avoid'`` problem with the ``"exact"`` value parameterisation.
    '''

    def __init__(self, gravity:float, input_multiplier_max:float, input_magnitude_max:float):
        '''Store ``g``, ``k_max`` and ``u_max`` and set up the normalisation constants.'''
        self.gravity = gravity                             # g
        self.input_multiplier_max = input_multiplier_max   # k_max
        self.input_magnitude_max = input_magnitude_max     # u_max
        # k is normalised around the midpoint of [0, k_max] with half-width k_max/2
        k_mid = self.input_multiplier_max/2
        super().__init__(
            loss_type='brt_hjivi', set_mode='avoid',
            state_dim=3, input_dim=4, control_dim=1, disturbance_dim=0,
            state_mean=[0, 1.5, k_mid],  # v, z, k
            state_var=[4, 2, k_mid],     # v, z, k
            value_mean=0.25,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        '''Return ``[lo, hi]`` test bounds for v, z and k.'''
        v_range = [-4, 4]
        z_range = [-0.5, 3.5]
        k_range = [0, self.input_multiplier_max]
        return [v_range, z_range, k_range]

    def equivalent_wrapped_state(self, state):
        '''No state is periodic, so this returns a plain copy.'''
        return torch.clone(state)

    # ParameterizedVertDrone2D dynamics
    # \dot v = k*u - g
    # \dot z = v
    # \dot k = 0
    def dsdt(self, state, control, disturbance):
        '''Return ``[k*u - g, v, 0]``; ``disturbance`` is not used.'''
        v, k = state[..., 0], state[..., 2]
        rates = torch.zeros_like(state)
        rates[..., 0] = k*control[..., 0] - self.gravity
        rates[..., 1] = v
        # entry 2 (\dot k) keeps its zero fill
        return rates

    def boundary_fn(self, state):
        '''Return ``1.5 - |z - 1.5|``.'''
        z = state[..., 1]
        return 1.5 - torch.abs(z - 1.5)

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        raise NotImplementedError

    def hamiltonian(self, state, dvds):
        '''Return ``k*|dV/dv * u_max| - dV/dv * g + dV/dz * v``.'''
        v, k = state[..., 0], state[..., 2]
        dv_dv, dv_dz = dvds[..., 0], dvds[..., 1]
        control_term = k*torch.abs(dv_dv*self.input_magnitude_max)
        gravity_term = dv_dv*self.gravity
        drift_term = dv_dz*v
        return control_term - gravity_term + drift_term

    def optimal_control(self, state, dvds):
        raise NotImplementedError

    def optimal_disturbance(self, state, dvds):
        raise NotImplementedError

    def plot_config(self):
        '''Return slice values, labels and axis indices for plotting.'''
        config = {}
        config['state_slices'] = [0, 1.5, self.input_multiplier_max/2]
        config['state_labels'] = ['v', 'z', 'k']
        config['x_axis_idx'] = 0
        config['y_axis_idx'] = 1
        config['z_axis_idx'] = 2
        return config

class Air3D(Dynamics):
    '''Three-state system ``[x, y, psi]`` with one control and one disturbance.

    Configured as an ``'avoid'`` problem around a disc of radius ``collisionR`` in the
    (x, y) plane, using the ``"exact"`` value parameterisation.
    '''

    def __init__(self, collisionR:float, velocity:float, omega_max:float, angle_alpha_factor:float):
        '''Store the parameters; the angle is normalised by ``angle_alpha_factor * pi``.'''
        self.collisionR = collisionR
        self.velocity = velocity
        self.omega_max = omega_max
        self.angle_alpha_factor = angle_alpha_factor
        angle_scale = self.angle_alpha_factor*math.pi
        super().__init__(
            loss_type='brt_hjivi', set_mode='avoid',
            state_dim=3, input_dim=4, control_dim=1, disturbance_dim=1,
            state_mean=[0, 0, 0],
            state_var=[1, 1, angle_scale],
            value_mean=0.25,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        '''Return ``[lo, hi]`` test bounds for x, y and psi.'''
        x_range = [-1, 1]
        y_range = [-1, 1]
        psi_range = [-math.pi, math.pi]
        return [x_range, y_range, psi_range]

    def equivalent_wrapped_state(self, state):
        '''Return a copy of ``state`` with the angle (index 2) wrapped into ``[-pi, pi)``.'''
        wrapped = state.clone()
        angle = wrapped[..., 2]
        wrapped[..., 2] = (angle + math.pi) % (2*math.pi) - math.pi
        return wrapped

    # Air3D dynamics
    # \dot x    = -v + v \cos \psi + u y
    # \dot y    = v \sin \psi - u x
    # \dot \psi = d - u
    def dsdt(self, state, control, disturbance):
        '''Return the state derivative for the dynamics written above.'''
        x, y, psi = state[..., 0], state[..., 1], state[..., 2]
        u, d = control[..., 0], disturbance[..., 0]
        rates = torch.zeros_like(state)
        rates[..., 0] = -self.velocity + self.velocity*torch.cos(psi) + u*y
        rates[..., 1] = self.velocity*torch.sin(psi) - u*x
        rates[..., 2] = d - u
        return rates

    def boundary_fn(self, state):
        '''Return the planar distance from the origin minus ``collisionR``.'''
        planar = state[..., :2]
        return torch.norm(planar, dim=-1) - self.collisionR

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        '''Return the minimum of the boundary function over the last axis of its values.'''
        boundary_values = self.boundary_fn(state_traj)
        return boundary_values.min(dim=-1).values

    def hamiltonian(self, state, dvds):
        '''Return the sum of the control, disturbance and constant Hamiltonian components.'''
        x, y, psi = state[..., 0], state[..., 1], state[..., 2]
        p_x, p_y, p_psi = dvds[..., 0], dvds[..., 1], dvds[..., 2]
        control_part = self.omega_max * torch.abs(p_x * y - p_y * x - p_psi)  # Control component
        disturbance_part = self.omega_max * torch.abs(p_psi)  # Disturbance component
        # Constant component, split into its x and y contributions
        constant_x = self.velocity * (torch.cos(psi) - 1.0) * p_x
        constant_y = self.velocity * torch.sin(psi) * p_y
        return control_part - disturbance_part + constant_x + constant_y

    def optimal_control(self, state, dvds):
        '''Return ``omega_max * sign(dV/dx * y - dV/dy * x - dV/dpsi)`` with a trailing axis of size 1.'''
        switching = dvds[..., 0]*state[..., 1] - dvds[..., 1]*state[..., 0]-dvds[..., 2]
        return (self.omega_max * torch.sign(switching)).unsqueeze(-1)

    def optimal_disturbance(self, state, dvds):
        '''Return ``-omega_max * sign(dV/dpsi)`` with a trailing axis of size 1.'''
        return (-self.omega_max * torch.sign(dvds[..., 2])).unsqueeze(-1)

    def plot_config(self):
        '''Return slice values, labels and axis indices for plotting.'''
        config = {}
        config['state_slices'] = [0, 0, 0]
        config['state_labels'] = ['x', 'y', 'theta']
        config['x_axis_idx'] = 0
        config['y_axis_idx'] = 1
        config['z_axis_idx'] = 2
        return config

class Dubins3D(Dynamics):
    '''Dubins car with state ``[x, y, theta]`` and a single turn-rate control.

    The target is a disc of radius ``goalR`` around the origin of the (x, y) plane; whether
    it is reached or avoided is chosen through ``set_mode``. When ``freeze_model`` is true,
    ``dsdt`` and ``hamiltonian`` are unavailable and raise ``NotImplementedError``.
    '''

    def __init__(self, goalR:float, velocity:float, omega_max:float, angle_alpha_factor:float, set_mode:str, freeze_model: bool):
        '''Store the parameters; the angle is normalised by ``angle_alpha_factor * pi``.'''
        self.goalR = goalR
        self.velocity = velocity
        self.omega_max = omega_max
        self.angle_alpha_factor = angle_alpha_factor
        self.freeze_model = freeze_model
        angle_scale = self.angle_alpha_factor*math.pi
        super().__init__(
            loss_type='brt_hjivi', set_mode=set_mode,
            state_dim=3, input_dim=4, control_dim=1, disturbance_dim=0,
            state_mean=[0, 0, 0],
            state_var=[1, 1, angle_scale],
            value_mean=0.25,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact"
        )

    def state_test_range(self):
        '''Return ``[lo, hi]`` test bounds for x, y and theta.'''
        x_range = [-1, 1]
        y_range = [-1, 1]
        theta_range = [-math.pi, math.pi]
        return [x_range, y_range, theta_range]

    def equivalent_wrapped_state(self, state):
        '''Return a copy of ``state`` with the heading (index 2) wrapped into ``[-pi, pi)``.'''
        wrapped = state.clone()
        heading = wrapped[..., 2]
        wrapped[..., 2] = (heading + math.pi) % (2*math.pi) - math.pi
        return wrapped

    # Dubins3D dynamics
    # \dot x    = v \cos \theta
    # \dot y    = v \sin \theta
    # \dot \theta = u
    def dsdt(self, state, control, disturbance):
        '''Return the state derivative; ``disturbance`` is not used.

        Raises:
            NotImplementedError: when ``freeze_model`` is set.
        '''
        if self.freeze_model:
            raise NotImplementedError
        heading = state[..., 2]
        rates = torch.zeros_like(state)
        rates[..., 0] = self.velocity*torch.cos(heading)
        rates[..., 1] = self.velocity*torch.sin(heading)
        rates[..., 2] = control[..., 0]
        return rates

    def boundary_fn(self, state):
        '''Return the planar distance from the origin minus ``goalR``.'''
        planar = state[..., :2]
        return torch.norm(planar, dim=-1) - self.goalR

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        '''Return the minimum of the boundary function over the last axis of its values.'''
        boundary_values = self.boundary_fn(state_traj)
        return boundary_values.min(dim=-1).values

    def hamiltonian(self, state, dvds):
        '''Return drift minus (reach) or plus (avoid) ``omega_max * |dV/dtheta|``.

        Raises:
            NotImplementedError: when ``freeze_model`` is set.
        '''
        if self.freeze_model:
            raise NotImplementedError
        heading = state[..., 2]
        drift = self.velocity*(torch.cos(heading) * dvds[..., 0] + torch.sin(heading) * dvds[..., 1])
        turn = self.omega_max * torch.abs(dvds[..., 2])
        if self.set_mode == 'reach':
            return drift - turn
        elif self.set_mode == 'avoid':
            return drift + turn

    def optimal_control(self, state, dvds):
        '''Return the bang-bang turn rate (trailing axis of size 1) for the current ``set_mode``.'''
        direction = torch.sign(dvds[..., 2])
        if self.set_mode == 'reach':
            return (-self.omega_max*direction)[..., None]
        elif self.set_mode == 'avoid':
            return (self.omega_max*direction)[..., None]

    def optimal_disturbance(self, state, dvds):
        '''There is no disturbance input; always returns the scalar 0.'''
        return 0

    def plot_config(self):
        '''Return slice values, labels and axis indices for plotting.'''
        config = {}
        config['state_slices'] = [0, 0, 0]
        config['state_labels'] = ['x', 'y', r'$\theta$']
        config['x_axis_idx'] = 0
        config['y_axis_idx'] = 1
        config['z_axis_idx'] = 2
        return config

class Dubins4D(Dynamics):
    '''14-state system whose boundary function is the separation of two planar positions.

    Only the normalisation constants, the test range, the angle wrapping and the boundary
    function are implemented; dynamics, Hamiltonian, optimal inputs, cost and plotting all
    raise ``NotImplementedError``.
    '''

    def __init__(self, bound_mode:str, set_mode:str='avoid'):
        '''Set up the normalisation constants.

        Args:
            bound_mode: ``'v1'`` or ``'v2'``; selects the scale used for the ``o`` states
                (``3*pi`` for ``'v1'``, ``2.0`` otherwise).
            set_mode: ``'reach'`` or ``'avoid'``, forwarded to ``Dynamics.__init__``. The base
                constructor requires it; it was previously not passed at all, which made the
                class impossible to instantiate.
                NOTE(review): the ``'avoid'`` default is an assumption based on the
                collision-radius boundary function - confirm.

        Raises:
            AssertionError: on an unknown ``bound_mode`` (or ``set_mode``, via the base class).
        '''
        self.vMin = 0.2
        self.vMax = 14.8
        self.collisionR = 1.5
        self.bound_mode = bound_mode
        assert self.bound_mode in ['v1', 'v2']

        # Normalisation offsets.
        xMean = 0
        yMean = 0
        thetaMean = 0
        vMean = 7.5
        aMean = 0
        oMean = 0

        # Normalisation scales.
        xVar = 10
        yVar = 10
        thetaVar = 1.2*math.pi
        vVar = 7.5
        aVar = 10
        oVar = 3*math.pi if self.bound_mode == 'v1' else 2.0

        super().__init__(
            loss_type='brt_hjivi', set_mode=set_mode,
            state_dim=14, input_dim=15,  control_dim=2, disturbance_dim=0,
            state_mean=[xMean, yMean, thetaMean, vMean, xMean, yMean, aMean, aMean, oMean, oMean, aMean, aMean, oMean, oMean],
            state_var=[xVar, yVar, thetaVar, vVar, xVar, yVar, aVar, aVar, oVar, oVar, aVar, aVar, oVar, oVar],
            value_mean=13,
            value_var=14,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        '''Return ``[lo, hi]`` test bounds for each of the 14 states.'''
        return [
            [-1, 1],
            [-1, 1],
            [-math.pi, math.pi],
            [self.vMin, self.vMax],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
        ]

    def equivalent_wrapped_state(self, state):
        '''Return a copy of ``state`` with the angle (index 2) wrapped into ``[-pi, pi)``.'''
        wrapped_state = torch.clone(state)
        wrapped_state[..., 2] = (wrapped_state[..., 2] + math.pi) % (2*math.pi) - math.pi
        return wrapped_state

    def boundary_fn(self, state):
        '''Return the distance between positions ``state[..., 0:2]`` and ``state[..., 4:6]`` minus ``collisionR``.'''
        return torch.norm(state[..., 0:2] - state[..., 4:6], dim=-1) - self.collisionR

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        raise NotImplementedError

    def dsdt(self, state, control, disturbance):
        raise NotImplementedError

    def hamiltonian(self, state, dvds):
        raise NotImplementedError

    def optimal_control(self, state, dvds):
        raise NotImplementedError

    def optimal_disturbance(self, state, dvds):
        raise NotImplementedError

    def plot_config(self):
        raise NotImplementedError

class NarrowPassage(Dynamics):
    '''Two car-like vehicles on a road bounded by curbs, with a stranded car in the way.

    State layout: ``[x1, y1, th1, v1, phi1, x2, y2, th2, v2, phi2]``; controls are
    ``[a1, psi1, a2, psi2]``. With ``avoid_only`` the problem is a pure avoid problem,
    otherwise a reach-avoid problem in which each vehicle has its own goal.
    '''

    def __init__(self, avoid_fn_weight:float, avoid_only:bool):
        '''Store the scenario constants and set up the normalisation.

        Args:
            avoid_fn_weight: multiplier applied to the avoid function.
            avoid_only: selects ``'brt_hjivi'``/``'avoid'`` when true and
                ``'brat_hjivi'``/``'reach'`` otherwise.
        '''
        # Length used both in the heading dynamics and as clearance in reach_fn / avoid_fn
        self.L = 2.0

        # Target positions, one entry per vehicle
        self.goalX = [6.0, -6.0]
        self.goalY = [-1.4, 1.4]

        # State bounds
        self.vMin = 0.001
        self.vMax = 6.50
        self.phiMin = -0.3*math.pi + 0.001
        self.phiMax = 0.3*math.pi - 0.001

        # Control bounds
        self.aMin = -4.0
        self.aMax = 2.0
        self.psiMin = -3.0*math.pi
        self.psiMax = 3.0*math.pi

        # Lower and upper curb positions (in the y direction)
        self.curb_positions = [-2.8, 2.8]

        # Stranded car position
        self.stranded_car_pos = [0.0, -1.8]

        self.avoid_fn_weight = avoid_fn_weight

        self.avoid_only = avoid_only

        # Both vehicles share the same normalisation: [x, y, th, v, phi]
        vehicle_mean = [0, 0, 0, 3, 0]
        vehicle_var = [8.0, 3.8, 1.2*math.pi, 4.0, 1.2*0.3*math.pi]

        super().__init__(
            loss_type='brt_hjivi' if self.avoid_only else 'brat_hjivi', set_mode='avoid' if self.avoid_only else 'reach',
            state_dim=10, input_dim=11, control_dim=4, disturbance_dim=0,
            # state = [x1, y1, th1, v1, phi1, x2, y2, th2, v2, phi2]
            state_mean=vehicle_mean + vehicle_mean,
            state_var=vehicle_var + vehicle_var,
            value_mean=0.25*8.0,
            value_var=0.5*8.0,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        '''Return ``[lo, hi]`` test bounds for all ten states (same bounds for both vehicles).'''
        vehicle_bounds = (
            (-8, 8),
            (-3.8, 3.8),
            (-math.pi, math.pi),
            (-1, 7),
            (-0.3*math.pi, 0.3*math.pi),
        )
        # fresh inner lists on every call, one set per vehicle
        return [list(bounds) for _ in range(2) for bounds in vehicle_bounds]

    def equivalent_wrapped_state(self, state):
        '''Return a copy of ``state`` with th1, phi1, th2, phi2 wrapped into ``[-pi, pi)``.'''
        wrapped = state.clone()
        for angle_idx in (2, 4, 7, 9):
            wrapped[..., angle_idx] = (wrapped[..., angle_idx] + math.pi) % (2*math.pi) - math.pi
        return wrapped

    # NarrowPassage dynamics
    # \dot x   = v * cos(th)
    # \dot y   = v * sin(th)
    # \dot th  = v * tan(phi) / L
    # \dot v   = u1
    # \dot phi = u2
    # \dot x   = ...
    # \dot y   = ...
    # \dot th  = ...
    # \dot v   = ...
    # \dot phi = ...
    def dsdt(self, state, control, disturbance):
        '''Return the state derivative of both vehicles; ``disturbance`` is not used.'''
        rates = torch.zeros_like(state)
        for vehicle in range(2):
            s = 5*vehicle   # offset of this vehicle's states
            c = 2*vehicle   # offset of this vehicle's controls
            heading, speed, steer = state[..., s + 2], state[..., s + 3], state[..., s + 4]
            rates[..., s] = speed*torch.cos(heading)
            rates[..., s + 1] = speed*torch.sin(heading)
            rates[..., s + 2] = speed*torch.tan(steer) / self.L
            rates[..., s + 3] = control[..., c]
            rates[..., s + 4] = control[..., c + 1]
        return rates

    def reach_fn(self, state):
        '''Return the larger of the two vehicles' distances to their goals, each minus ``L``.

        Raises:
            RuntimeError: when called in ``avoid_only`` mode.
        '''
        if self.avoid_only:
            raise RuntimeError
        # vehicle 1
        goal_1 = torch.tensor([self.goalX[0], self.goalY[0]], device=state.device)
        gap_1 = torch.norm(state[..., 0:2] - goal_1, dim=-1) - self.L
        # vehicle 2
        goal_2 = torch.tensor([self.goalX[1], self.goalY[1]], device=state.device)
        gap_2 = torch.norm(state[..., 5:7] - goal_2, dim=-1) - self.L
        return torch.maximum(gap_1, gap_2)

    def avoid_fn(self, state):
        '''Return ``avoid_fn_weight`` times the smallest clearance to any obstacle.

        Obstacles are the lower curb, the upper curb, the stranded car and the other vehicle.
        '''
        pos_1, pos_2 = state[..., 0:2], state[..., 5:7]
        y_1, y_2 = state[..., 1], state[..., 6]
        lower_curb, upper_curb = self.curb_positions[0], self.curb_positions[1]

        # distance from lower curb
        to_lower = torch.minimum(y_1 - lower_curb - 0.5*self.L,
                                 y_2 - lower_curb - 0.5*self.L)

        # distance from upper curb
        to_upper = torch.minimum(upper_curb - y_1 - 0.5*self.L,
                                 upper_curb - y_2 - 0.5*self.L)

        # distance from the stranded car
        stranded = torch.tensor(self.stranded_car_pos, device=state.device)
        to_stranded = torch.minimum(torch.norm(pos_1 - stranded, dim=-1) - self.L,
                                    torch.norm(pos_2 - stranded, dim=-1) - self.L)

        # distance between the vehicles themselves
        between = torch.norm(pos_1 - pos_2, dim=-1) - self.L

        clearance = torch.min(to_lower, to_upper)
        clearance = torch.min(clearance, to_stranded)
        clearance = torch.min(clearance, between)
        return self.avoid_fn_weight * clearance

    def boundary_fn(self, state):
        '''Return the avoid function (avoid-only) or ``max(reach_fn, -avoid_fn)``.'''
        if not self.avoid_only:
            return torch.maximum(self.reach_fn(state), -self.avoid_fn(state))
        return self.avoid_fn(state)

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        '''Return the trajectory cost, reduced over the last axis of the reach/avoid values.'''
        avoid_values = self.avoid_fn(state_traj)
        if self.avoid_only:
            return torch.min(avoid_values, dim=-1).values
        # return min_t max{l(x(t)), max_k_up_to_t{-g(x(k))}}, where l(x) is reach_fn, g(x) is avoid_fn
        reach_values = self.reach_fn(state_traj)
        worst_violation_so_far = torch.cummax(-avoid_values, dim=-1).values
        return torch.min(torch.maximum(reach_values, worst_violation_so_far), dim=-1).values

    def hamiltonian(self, state, dvds):
        '''Return ``dvds . dsdt`` evaluated at the optimal control.'''
        u = self.optimal_control(state, dvds)
        # One term per state, in state order; summed left to right below.
        terms = (
            state[..., 3] * torch.cos(state[..., 2]) * dvds[..., 0],
            state[..., 3] * torch.sin(state[..., 2]) * dvds[..., 1],
            state[..., 3] * torch.tan(state[..., 4]) * dvds[..., 2] / self.L,
            u[..., 0] * dvds[..., 3],
            u[..., 1] * dvds[..., 4],
            state[..., 8] * torch.cos(state[..., 7]) * dvds[..., 5],
            state[..., 8] * torch.sin(state[..., 7]) * dvds[..., 6],
            state[..., 8] * torch.tan(state[..., 9]) * dvds[..., 7] / self.L,
            u[..., 2] * dvds[..., 8],
            u[..., 3] * dvds[..., 9],
        )
        total = terms[0]
        for term in terms[1:]:
            total = total + term
        return total

    def optimal_control(self, state, dvds):
        '''Return the bang-bang control ``[a1, psi1, a2, psi2]`` along the last axis.

        Each input takes its lower or upper limit depending on the sign of the matching
        value-gradient entry; a limit is replaced by 0 once the state it drives has reached
        the corresponding state bound.
        '''
        if self.avoid_only:
            def takes_lower(gradient):
                return gradient < 0
        else:
            def takes_lower(gradient):
                return gradient > 0

        def bang_bang(idx, u_lo, u_hi, x_lo, x_hi):
            # idx addresses both the driven state and its value-gradient entry
            x = state[..., idx]
            lower = u_lo * (x > x_lo)
            upper = u_hi * (x < x_hi)
            return torch.where(takes_lower(dvds[..., idx]), lower, upper)

        a1 = bang_bang(3, self.aMin, self.aMax, self.vMin, self.vMax)
        psi1 = bang_bang(4, self.psiMin, self.psiMax, self.phiMin, self.phiMax)
        a2 = bang_bang(8, self.aMin, self.aMax, self.vMin, self.vMax)
        psi2 = bang_bang(9, self.psiMin, self.psiMax, self.phiMin, self.phiMax)

        return torch.stack((a1, psi1, a2, psi2), dim=-1)

    def optimal_disturbance(self, state, dvds):
        '''There is no disturbance input; always returns the scalar 0.'''
        return 0

    def plot_config(self):
        '''Return slice values, labels and axis indices for plotting.'''
        config = {}
        config['state_slices'] = [
            -6.0, -1.4, 0.0, 6.5, 0.0,
            -6.0, 1.4, -math.pi, 0.0, 0.0
        ]
        config['state_labels'] = [
            r'$x_1$', r'$y_1$', r'$\theta_1$', r'$v_1$', r'$\phi_1$',
            r'$x_2$', r'$y_2$', r'$\theta_2$', r'$v_2$', r'$\phi_2$',
        ]
        config['x_axis_idx'] = 0
        config['y_axis_idx'] = 1
        config['z_axis_idx'] = 2
        return config

class ReachAvoidRocketLanding(Dynamics):
    '''Planar rocket with state ``[x, y, th, v_x, v_y, w]`` and two controls.

    Reach-avoid problem: the target is a region in (x, y), the obstacles are the floor
    (``y = 0``) and a rectangular wall.
    '''

    def __init__(self):
        '''Set up the fixed problem type, dimensions and normalisation constants.'''
        super().__init__(
            loss_type='brat_hjivi', set_mode='reach',
            state_dim=6, input_dim=7, control_dim=2, disturbance_dim=0,
            state_mean=[0.0, 80.0, 0.0, 0.0, 0.0, 0.0],
            state_var=[150.0, 70.0, 1.2*math.pi, 200.0, 200.0, 10.0],
            value_mean=0.0,
            value_var=1.0,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        '''Return ``[lo, hi]`` test bounds for x, y, th, v_x, v_y and w.'''
        return [
            [-150, 150],
            [10, 150],
            [-math.pi, math.pi],
            [-200, 200],
            [-200, 200],
            [-10, 10],
        ]

    def equivalent_wrapped_state(self, state):
        '''Return a copy of ``state`` with the angle (index 2) wrapped into ``[-pi, pi)``.'''
        wrapped_state = torch.clone(state)
        wrapped_state[..., 2] = (wrapped_state[..., 2] + math.pi) % (2*math.pi) - math.pi
        return wrapped_state

    # \dot x = v_x
    # \dot y = v_y
    # \dot th = w
    # \dot v_x = u1 * cos(th) - u2 sin(th)
    # \dot v_y = u1 * sin(th) + u2 cos(th) - 9.81
    # \dot w = 0.3 * u1
    def dsdt(self, state, control, disturbance):
        '''Return the state derivative for the dynamics written above; ``disturbance`` is not used.'''
        dsdt = torch.zeros_like(state)
        dsdt[..., 0] = state[..., 3]
        dsdt[..., 1] = state[..., 4]
        dsdt[..., 2] = state[..., 5]
        dsdt[..., 3] = control[..., 0]*torch.cos(state[..., 2]) - control[..., 1]*torch.sin(state[..., 2])
        dsdt[..., 4] = control[..., 0]*torch.sin(state[..., 2]) + control[..., 1]*torch.cos(state[..., 2]) - 9.81
        dsdt[..., 5] = 0.3*control[..., 0]
        return dsdt

    def reach_fn(self, state):
        '''Return the normalised target function; it is negative when ``|x| < 20`` and ``y < 20``.'''
        # Only target set in the xy direction
        # Target set position in x direction; with |x| <= 150 (state_mean/state_var) this spans [-20, 130]
        dist_x = torch.abs(state[..., 0]) - 20.0

        # Target set position in y direction; with y in [10, 150] this spans [-10, 130]
        dist_y = state[..., 1] - 20.0

        # First compute the target function as you normally would but then normalize it later:
        # non-negative values are divided by 150, negative values by 10.
        max_dist = torch.max(dist_x, dist_y)
        return torch.where((max_dist >= 0), max_dist/150.0, max_dist/10.0)

    def avoid_fn(self, state):
        '''Return the smaller of the height above the floor and the signed distance to the wall.

        The wall is the axis-aligned box ``x in [-30, -20]``, ``y in [0, 100]``; its signed
        distance is positive outside and negative inside.
        '''
        # distance to floor
        dist_y = state[..., 1]

        # distance to wall
        wall_left = -30
        wall_right = -20
        wall_bottom = 0
        wall_top = 100
        dist_left = wall_left - state[..., 0]
        dist_right = state[..., 0] - wall_right
        dist_bottom = wall_bottom - state[..., 1]
        dist_top = state[..., 1] - wall_top
        dist_wall_x = torch.max(dist_left, dist_right)
        dist_wall_y = torch.max(dist_bottom, dist_top)

        # Box signed distance: Euclidean norm of the positive parts outside the box, plus the
        # (non-positive) larger axis distance inside it. clamp is used instead of comparing
        # against a freshly built torch.tensor(0), so no extra scalar tensor is created per
        # call and every operand stays on the device of `state`.
        outside = torch.stack((torch.clamp(dist_wall_x, min=0), torch.clamp(dist_wall_y, min=0)), dim=-1)
        inside = torch.clamp(torch.maximum(dist_wall_x, dist_wall_y), max=0)
        dist_wall = torch.norm(outside, dim=-1) + inside

        return torch.min(dist_y, dist_wall)

    def boundary_fn(self, state):
        '''Return ``max(reach_fn, -avoid_fn)``.'''
        return torch.maximum(self.reach_fn(state), -self.avoid_fn(state))

    def sample_target_state(self, num_samples):
        '''Return ``num_samples`` states drawn uniformly from the test range restricted to the target in x and y.'''
        target_state_range = self.state_test_range()
        target_state_range[0] = [-20, 20] # x in [-20, 20]
        target_state_range[1] = [10, 20]  # y in [10, 20]
        target_state_range = torch.tensor(target_state_range)
        return target_state_range[:, 0] + torch.rand(num_samples, self.state_dim)*(target_state_range[:, 1] - target_state_range[:, 0])

    def cost_fn(self, state_traj):
        '''Return the reach-avoid trajectory cost, reduced over the last axis of the reach/avoid values.'''
        # return min_t max{l(x(t)), max_k_up_to_t{-g(x(k))}}, where l(x) is reach_fn, g(x) is avoid_fn
        reach_values = self.reach_fn(state_traj)
        avoid_values = self.avoid_fn(state_traj)
        return torch.min(torch.maximum(reach_values, torch.cummax(-avoid_values, dim=-1).values), dim=-1).values

    def hamiltonian(self, state, dvds):
        '''Return the Hamiltonian: control part ``-250 * ||(u1_coeff, u2_coeff)||`` plus the drift part.'''
        # Control Hamiltonian
        u1_coeff = dvds[..., 3] * torch.cos(state[..., 2]) + dvds[..., 4] * torch.sin(state[..., 2]) + 0.3 * dvds[..., 5]
        u2_coeff = -dvds[..., 3] * torch.sin(state[..., 2]) + dvds[..., 4] * torch.cos(state[..., 2])
        ham_ctrl = -250.0 * torch.sqrt(u1_coeff * u1_coeff + u2_coeff * u2_coeff)
        # Constant Hamiltonian
        ham_constant = dvds[..., 0] * state[..., 3] + dvds[..., 1] * state[..., 4] + \
                      dvds[..., 2] * state[..., 5]  - dvds[..., 4] * 9.81
        # Compute the Hamiltonian
        ham_vehicle = ham_ctrl + ham_constant
        return ham_vehicle

    def optimal_control(self, state, dvds):
        '''Return the control of magnitude 250 pointing opposite to ``(u1_coeff, u2_coeff)``.'''
        u1_coeff = dvds[..., 3] * torch.cos(state[..., 2]) + dvds[..., 4] * torch.sin(state[..., 2]) + 0.3 * dvds[..., 5]
        u2_coeff = -dvds[..., 3] * torch.sin(state[..., 2]) + dvds[..., 4] * torch.cos(state[..., 2])
        # adding pi flips the direction of the coefficient vector
        opt_angle = torch.atan2(u2_coeff, u1_coeff) + math.pi
        return torch.cat((250.0 * torch.cos(opt_angle)[..., None], 250.0 * torch.sin(opt_angle)[..., None]), dim=-1)

    def optimal_disturbance(self, state, dvds):
        '''There is no disturbance input; always returns the scalar 0.'''
        return 0

    def plot_config(self):
        '''Return slice values, labels and axis indices for plotting.'''
        return {
            'state_slices': [-100, 120, 0, 150, -5, 0.0],
            # the last label previously lacked its closing '$', so it was not valid math text
            'state_labels': ['x', 'y', r'$\theta$', r'$v_x$', r'$v_y$', r'$\omega$'],
            'x_axis_idx': 0,
            'y_axis_idx': 1,
            'z_axis_idx': 4,
        }

class RocketLanding(Dynamics):
    """6D planar rocket-landing problem, configured with ``set_mode='reach'``.

    State (last axis): ``[x, y, theta, v_x, v_y, omega]``.
    Control: ``(u1, u2)``; ``optimal_control`` always returns a vector of
    magnitude 250. No disturbance is modelled (``disturbance_dim=0``).

    The network input has one more column than ``[t, state]``
    (``input_dim=8``): ``coord_to_input`` appends a zero column and
    ``input_to_coord`` drops the last column again.
    """

    def __init__(self):
        super().__init__(
            loss_type='brt_hjivi', set_mode='reach',
            state_dim=6, input_dim=8, control_dim=2, disturbance_dim=0,
            state_mean=[0.0, 80.0, 0.0, 0.0, 0.0, 0.0],
            state_var=[150.0, 70.0, 1.2*math.pi, 200.0, 200.0, 10.0],
            value_mean=0.0,
            value_var=1.0,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def input_to_coord(self, input):
        """Convert a model input to real coordinates ``[t, state]``.

        The trailing padding column is dropped, column 0 is passed through
        unchanged, and the remaining columns are de-normalized with
        ``state_var`` / ``state_mean``.
        """
        input = input[..., :-1]
        coord = input.clone()
        coord[..., 1:] = (input[..., 1:] * self.state_var.to(device=input.device)) + self.state_mean.to(device=input.device)
        return coord

    def coord_to_input(self, coord):
        """Convert real coordinates ``[t, state]`` to a model input.

        Columns 1: are normalized with ``state_mean`` / ``state_var`` and a
        zero column is appended as the last feature.
        """
        input = coord.clone()
        input[..., 1:] = (coord[..., 1:] - self.state_mean.to(device=coord.device)) / self.state_var.to(device=coord.device)
        input = torch.cat((input, torch.zeros((*input.shape[:-1], 1), device=input.device)), dim=-1)
        return input

    def io_to_value(self, input, output):
        """Convert a raw model output to a real value.

        In ``"diff"`` mode the rescaled output is an offset added to the
        boundary function of the state; in every other mode it is shifted by
        ``value_mean``.
        """
        scaled_output = output * self.value_var / self.value_normto
        if self.deepreach_model == "diff":
            return scaled_output + self.boundary_fn(self.input_to_coord(input)[..., 1:])
        return scaled_output + self.value_mean

    def io_to_dv(self, input, output):
        """Convert model input/output to real derivatives ``[dV/dt, dV/ds]``.

        Relies on the file-level ``jacobian`` helper (defined outside this
        class). The derivative with respect to the trailing padding column is
        discarded.
        """
        dodi = jacobian(output.unsqueeze(dim=-1), input)[0].squeeze(dim=-2)[..., :-1]

        # The time derivative is rescaled the same way in every mode.
        dvdt = (self.value_var / self.value_normto) * dodi[..., 0]
        dvds = (self.value_var / self.value_normto / self.state_var.to(device=dodi.device)) * dodi[..., 1:]

        if self.deepreach_model == "diff":
            # The value is boundary_fn(state) + network offset, so add the
            # gradient of the boundary function itself.
            state = self.input_to_coord(input)[..., 1:]
            dvds_boundary = jacobian(self.boundary_fn(state).unsqueeze(dim=-1), state)[0].squeeze(dim=-2)
            dvds = dvds + dvds_boundary

        return torch.cat((dvdt.unsqueeze(dim=-1), dvds), dim=-1)

    def state_test_range(self):
        """Return ``[min, max]`` per state dimension."""
        return [
            [-150, 150],
            [10, 150],
            [-math.pi, math.pi],
            [-200, 200],
            [-200, 200],
            [-10, 10],
        ]

    def equivalent_wrapped_state(self, state):
        """Return a copy of ``state`` with ``theta`` (index 2) wrapped to [-pi, pi)."""
        wrapped_state = torch.clone(state)
        wrapped_state[..., 2] = (wrapped_state[..., 2] + math.pi) % (2*math.pi) - math.pi
        return wrapped_state

    # \dot x = v_x
    # \dot y = v_y
    # \dot th = w
    # \dot v_x = u1 * cos(th) - u2 sin(th)
    # \dot v_y = u1 * sin(th) + u2 cos(th) - 9.81
    # \dot w = 0.3 * u1
    def dsdt(self, state, control, disturbance):
        """Return the state derivative; ``disturbance`` is not used."""
        dsdt = torch.zeros_like(state)
        dsdt[..., 0] = state[..., 3]
        dsdt[..., 1] = state[..., 4]
        dsdt[..., 2] = state[..., 5]
        dsdt[..., 3] = control[..., 0]*torch.cos(state[..., 2]) - control[..., 1]*torch.sin(state[..., 2])
        dsdt[..., 4] = control[..., 0]*torch.sin(state[..., 2]) + control[..., 1]*torch.cos(state[..., 2]) - 9.81
        dsdt[..., 5] = 0.3*control[..., 0]
        return dsdt

    def boundary_fn(self, state):
        """Normalized signed distance to the landing zone.

        The target only constrains position: ``|x| <= 20`` and ``y <= 20``.
        The raw distance is scaled by 1/150 where it is non-negative and by
        1/10 where it is negative.
        """
        # Horizontal distance to the pad: |x| - 20.
        dist_x = torch.abs(state[..., 0]) - 20.0

        # Vertical distance to the pad ceiling: y - 20.
        dist_y = state[..., 1] - 20.0

        # Compute l(x) as usual first, then normalize each sign separately.
        lx = torch.max(dist_x, dist_y)
        return torch.where((lx >= 0), lx/150.0, lx/10.0)

    def sample_target_state(self, num_samples):
        """Sample ``num_samples`` states uniformly inside the target set.

        Position is restricted to the landing zone; every other dimension
        spans its full ``state_test_range``.
        """
        target_state_range = self.state_test_range()
        target_state_range[0] = [-20, 20] # x in [-20, 20]
        target_state_range[1] = [10, 20]  # y in [10, 20]
        target_state_range = torch.tensor(target_state_range)
        return target_state_range[:, 0] + torch.rand(num_samples, self.state_dim)*(target_state_range[:, 1] - target_state_range[:, 0])

    def cost_fn(self, state_traj):
        """Return the minimum of ``boundary_fn`` over the last axis of its values."""
        return torch.min(self.boundary_fn(state_traj), dim=-1).values

    def _thrust_coeffs(self, state, dvds):
        """Return the coefficients multiplying ``u1`` and ``u2`` in the Hamiltonian."""
        u1_coeff = dvds[..., 3] * torch.cos(state[..., 2]) + dvds[..., 4] * torch.sin(state[..., 2]) + 0.3 * dvds[..., 5]
        u2_coeff = -dvds[..., 3] * torch.sin(state[..., 2]) + dvds[..., 4] * torch.cos(state[..., 2])
        return u1_coeff, u2_coeff

    def hamiltonian(self, state, dvds):
        """Evaluate the Hamiltonian: control term plus control-independent drift."""
        # Control Hamiltonian: -250 * norm of the control coefficient vector.
        u1_coeff, u2_coeff = self._thrust_coeffs(state, dvds)
        ham_ctrl = -250.0 * torch.sqrt(u1_coeff * u1_coeff + u2_coeff * u2_coeff)
        # Constant Hamiltonian
        ham_constant = dvds[..., 0] * state[..., 3] + dvds[..., 1] * state[..., 4] + \
                      dvds[..., 2] * state[..., 5]  - dvds[..., 4] * 9.81
        # Compute the Hamiltonian
        ham_vehicle = ham_ctrl + ham_constant
        return ham_vehicle

    def optimal_control(self, state, dvds):
        """Return ``(u1, u2)`` of magnitude 250 opposing the coefficient vector."""
        u1_coeff, u2_coeff = self._thrust_coeffs(state, dvds)
        # + pi flips the direction so the control opposes (u1_coeff, u2_coeff).
        opt_angle = torch.atan2(u2_coeff, u1_coeff) + math.pi
        return torch.cat((250.0 * torch.cos(opt_angle)[..., None], 250.0 * torch.sin(opt_angle)[..., None]), dim=-1)

    def optimal_disturbance(self, state, dvds):
        """Always return 0 (``disturbance_dim`` is 0 for this system)."""
        return 0

    def plot_config(self):
        """Return slice defaults, axis labels and axis indices for validation plots."""
        return {
            'state_slices': [-100, 120, 0, 150, -5, 0.0],
            # The last label previously lacked its closing '$'.
            'state_labels': ['x', 'y', r'$\theta$', r'$v_x$', r'$v_y$', r'$\omega$'],
            'x_axis_idx': 0,
            'y_axis_idx': 1,
            'z_axis_idx': 4,
        }

class Quadrotor(Dynamics):
    """13D quadrotor with a spherical set centred on the position origin.

    State (last axis):
    ``[x, y, z, qw, qx, qy, qz, vx, vy, vz, wx, wy, wz]``.
    Control: four inputs ``u1..u4``; ``optimal_control`` returns each as
    ``+/- thrust_max``. No disturbance is modelled.

    Only ``set_mode='avoid'`` is implemented for ``hamiltonian`` and
    ``optimal_control``; any other mode raises ``NotImplementedError``.

    Args:
        collisionR: radius subtracted from the position norm in ``boundary_fn``.
        thrust_max: bound used for each control input.
        set_mode: forwarded to the base class ('reach' or 'avoid').
    """

    def __init__(self, collisionR:float, thrust_max:float, set_mode:str):
        self.m=1 #mass
        self.arm_l=0.17
        self.CT=1
        self.CM=0.016
        self.Gz=-9.8

        self.thrust_max = thrust_max
        self.collisionR = collisionR

        super().__init__(
            loss_type='brt_hjivi', set_mode=set_mode,
            state_dim=13, input_dim=14, control_dim=4, disturbance_dim=0,
            state_mean=[0 for i in range(13)],
            state_var=[1.5, 1.5, 1.5, 1, 1, 1, 1, 10, 10 ,10 ,10 ,10 ,10],
            value_mean=(math.sqrt(1.5**2+1.5**2+1.5**2)-2*self.collisionR)/2,
            value_var=math.sqrt(1.5**2+1.5**2+1.5**2),
            value_normto=0.02,
            deepreach_model="exact"
        )

    def state_test_range(self):
        """Return ``[min, max]`` per state dimension."""
        return [
            [-1.5, 1.5],
            [-1.5, 1.5],
            [-1.5, 1.5],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-10, 10],
            [-10, 10],
            [-10, 10],
            [-10, 10],
            [-10, 10],
            [-10, 10],
        ]

    def equivalent_wrapped_state(self, state):
        """Return an unmodified copy of ``state`` (no dimension is wrapped)."""
        wrapped_state = torch.clone(state)
        return wrapped_state

    # Quadrotor dynamics (T = u1 + u2 + u3 + u4):
    # \dot p      = v
    # \dot q      = quaternion kinematics driven by (wx, wy, wz)
    # \dot vx     = 2 (qw qy + qx qz) CT/m T
    # \dot vy     = 2 (-qw qx + qy qz) CT/m T
    # \dot vz     = Gz + (1 - 2 qx^2 - 2 qy^2) CT/m T
    # \dot wx     = 4 sqrt(2) CT ( u1 - u2 - u3 + u4) / (3 arm_l m) - 5 wy wz / 9
    # \dot wy     = 4 sqrt(2) CT (-u1 - u2 + u3 + u4) / (3 arm_l m) + 5 wx wz / 9
    # \dot wz     = 12 CT CM (u1 - u2 + u3 - u4) / (7 arm_l^2 m)
    def dsdt(self, state, control, disturbance):
        """Return the state derivative; ``disturbance`` is not used."""
        qw = state[..., 3] * 1.0
        qx = state[..., 4] * 1.0
        qy = state[..., 5] * 1.0
        qz = state[..., 6] * 1.0
        vx = state[..., 7] * 1.0
        vy = state[..., 8] * 1.0
        vz = state[..., 9] * 1.0
        wx = state[..., 10] * 1.0
        wy = state[..., 11] * 1.0
        wz = state[..., 12] * 1.0
        u1 = control[...,0] * 1.0
        u2 = control[...,1] * 1.0
        u3 = control[...,2] * 1.0
        u4 = control[...,3] * 1.0

        dsdt = torch.zeros_like(state)
        dsdt[..., 0] = vx
        dsdt[..., 1] = vy
        dsdt[..., 2] = vz
        dsdt[..., 3] = -(wx*qx+wy*qy+wz*qz)/2.0
        dsdt[..., 4] =  (wx*qw+wz*qy-wy*qz)/2.0
        dsdt[..., 5] = (wy*qw-wz*qx+wx*qz)/2.0
        dsdt[..., 6] = (wz*qw+wy*qx-wx*qy)/2.0
        dsdt[..., 7] = 2*(qw*qy+qx*qz)*self.CT/self.m*(u1+u2+u3+u4)
        dsdt[..., 8] =2*(-qw*qx+qy*qz)*self.CT/self.m*(u1+u2+u3+u4)
        dsdt[..., 9] =self.Gz+(1-2*torch.pow(qx,2)-2*torch.pow(qy,2))*self.CT/self.m*(u1+u2+u3+u4)
        dsdt[..., 10] = 4*math.sqrt(2)*self.CT*(u1-u2-u3+u4)/(3*self.arm_l*self.m)-5*wy*wz/9.0
        dsdt[..., 11] = 4*math.sqrt(2)*self.CT*(-u1-u2+u3+u4)/(3*self.arm_l*self.m)+5*wx*wz/9.0
        dsdt[..., 12] =12*self.CT*self.CM/(7*self.arm_l**2*self.m)*(u1-u2+u3-u4)
        return dsdt

    def boundary_fn(self, state):
        """Distance of the position ``state[..., :3]`` from the origin, minus ``collisionR``."""
        return torch.norm(state[..., :3], dim=-1) - self.collisionR

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        """Return the minimum of ``boundary_fn`` over the last axis of its values."""
        return torch.min(self.boundary_fn(state_traj), dim=-1).values

    def _require_avoid_mode(self):
        """Raise ``NotImplementedError`` unless ``set_mode`` is 'avoid'."""
        if self.set_mode != 'avoid':
            raise NotImplementedError(
                "Quadrotor only implements set_mode='avoid', got %r" % (self.set_mode,)
            )

    def _control_coeffs(self, state, dvds):
        """Return the coefficients of ``u1, u2, u3, u4`` in the Hamiltonian.

        Each coefficient is ``dvds . d(dsdt)/du_i``; the sign pattern on the
        three angular-rate terms mirrors the sign of ``u_i`` in ``dsdt``.
        """
        qw = state[..., 3] * 1.0
        qx = state[..., 4] * 1.0
        qy = state[..., 5] * 1.0
        qz = state[..., 6] * 1.0

        C1=2*(qw*qy+qx*qz)*self.CT/self.m
        C2=2*(-qw*qx+qy*qz)*self.CT/self.m
        C3=(1-2*torch.pow(qx,2)-2*torch.pow(qy,2))*self.CT/self.m
        C4=4*math.sqrt(2)*self.CT/(3*self.arm_l*self.m)
        C5=4*math.sqrt(2)*self.CT/(3*self.arm_l*self.m)
        C6=12*self.CT*self.CM/(7*self.arm_l**2*self.m)

        # Shared by all four inputs: every u_i enters the velocity rows equally.
        common = dvds[..., 7]*C1+dvds[..., 8]*C2+dvds[..., 9]*C3
        wx_term = dvds[..., 10]*C4
        wy_term = dvds[..., 11]*C5
        wz_term = dvds[..., 12]*C6

        return (
            common+wx_term-wy_term+wz_term,
            common-wx_term-wy_term-wz_term,
            common-wx_term+wy_term+wz_term,
            common+wx_term+wy_term-wz_term,
        )

    def hamiltonian(self, state, dvds):
        """Evaluate the 'avoid' Hamiltonian.

        Raises:
            NotImplementedError: if ``set_mode`` is not 'avoid'. Previously an
                unrecognised mode silently returned ``None``.
        """
        self._require_avoid_mode()

        qw = state[..., 3] * 1.0
        qx = state[..., 4] * 1.0
        qy = state[..., 5] * 1.0
        qz = state[..., 6] * 1.0
        vx = state[..., 7] * 1.0
        vy = state[..., 8] * 1.0
        vz = state[..., 9] * 1.0
        wx = state[..., 10] * 1.0
        wy = state[..., 11] * 1.0
        wz = state[..., 12] * 1.0

        # Control-independent part of the Hamiltonian.
        ham= dvds[..., 0]*vx + dvds[..., 1]*vy+ dvds[..., 2]*vz
        ham+= -dvds[..., 3]* (wx*qx+wy*qy+wz*qz)/2.0
        ham+= dvds[..., 4]*(wx*qw+wz*qy-wy*qz)/2.0
        ham+= dvds[..., 5]*(wy*qw-wz*qx+wx*qz)/2.0
        ham+= dvds[..., 6]*(wz*qw+wy*qx-wx*qy)/2.0
        # Same gravity constant that dsdt uses.
        ham+= dvds[..., 9]*self.Gz
        ham+= -dvds[..., 10]*5*wy*wz/9.0+ dvds[..., 11]*5*wx*wz/9.0

        # Each input contributes |coefficient| * thrust_max.
        for coeff in self._control_coeffs(state, dvds):
            ham+=torch.abs(coeff)*self.thrust_max

        return ham

    def optimal_control(self, state, dvds):
        """Return the bang-bang control ``thrust_max * sign(coefficient)`` per input.

        Raises:
            NotImplementedError: if ``set_mode`` is not 'avoid'. Previously an
                unrecognised mode failed with ``UnboundLocalError``.
        """
        self._require_avoid_mode()

        controls = [
            self.thrust_max*torch.sign(coeff)
            for coeff in self._control_coeffs(state, dvds)
        ]
        return torch.stack(controls, dim=-1)

    def optimal_disturbance(self, state, dvds):
        """Always return 0 (``disturbance_dim`` is 0 for this system)."""
        return 0

    def plot_config(self):
        """Return slice defaults, axis labels and axis indices for validation plots."""
        return {
            'state_slices': [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            'state_labels': ['x', 'y', 'z', 'qw', 'qx', 'qy', 'qz', 'vx', 'vy', 'vz', 'wx', 'wy', 'wz'],
            'x_axis_idx': 0,
            'y_axis_idx': 2,
            'z_axis_idx': 7,
        }

class MultiVehicleCollision(Dynamics):
    """Three constant-speed vehicles that must avoid pairwise collisions.

    State (last axis):
    ``[x1, y1, x2, y2, x3, y3, theta1, theta2, theta3]``.
    Control: one turn rate per vehicle; ``optimal_control`` returns each as
    ``+/- omega_max``. No disturbance is modelled.
    """

    def __init__(self):
        self.angle_alpha_factor = 1.2
        self.velocity = 0.6
        self.omega_max = 1.1
        self.collisionR = 0.25

        position_scale = [1] * 6
        heading_scale = [self.angle_alpha_factor*math.pi] * 3

        super().__init__(
            loss_type='brt_hjivi', set_mode='avoid',
            state_dim=9, input_dim=10, control_dim=3, disturbance_dim=0,
            state_mean=[0] * 9,
            state_var=position_scale + heading_scale,
            value_mean=0.25,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact"
        )

    def state_test_range(self):
        """Return ``[min, max]`` per state dimension."""
        position_ranges = [[-1, 1] for _ in range(6)]
        heading_ranges = [[-math.pi, math.pi] for _ in range(3)]
        return position_ranges + heading_ranges

    def equivalent_wrapped_state(self, state):
        """Return a copy of ``state`` with the three headings wrapped to [-pi, pi)."""
        wrapped_state = torch.clone(state)
        for heading_idx in (6, 7, 8):
            wrapped_state[..., heading_idx] = (wrapped_state[..., heading_idx] + math.pi) % (2*math.pi) - math.pi
        return wrapped_state

    # dynamics (per car)
    # \dot x    = v \cos \theta
    # \dot y    = v \sin \theta
    # \dot \theta = u
    def dsdt(self, state, control, disturbance):
        """Return the state derivative; ``disturbance`` is not used."""
        state_dot = torch.zeros_like(state)
        for car in range(3):
            heading = state[..., 6 + car]
            state_dot[..., 2*car] = self.velocity*torch.cos(heading)
            state_dot[..., 2*car + 1] = self.velocity*torch.sin(heading)
        for car in range(3):
            state_dot[..., 6 + car] = control[..., car]
        return state_dot

    def boundary_fn(self, state):
        """Smallest pairwise centre distance among the three vehicles, minus ``collisionR``."""
        first_xy = state[..., 0:2]
        second_xy = state[..., 2:4]
        third_xy = state[..., 4:6]

        # First vehicle against each of the other two.
        closest = torch.norm(first_xy - second_xy, dim=-1) - self.collisionR
        closest = torch.min(closest, torch.norm(first_xy - third_xy, dim=-1) - self.collisionR)
        # The other two vehicles against each other.
        closest = torch.min(closest, torch.norm(second_xy - third_xy, dim=-1) - self.collisionR)
        return closest

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        """Return the minimum of ``boundary_fn`` over the last axis of its values."""
        margins = self.boundary_fn(state_traj)
        return torch.min(margins, dim=-1).values

    def hamiltonian(self, state, dvds):
        """Sum, over the three vehicles, of the drift term plus ``omega_max * |dV/dtheta|``."""
        def vehicle_term(car):
            heading = state[..., 6 + car]
            drift = self.velocity*(torch.cos(heading) * dvds[..., 2*car] + torch.sin(heading) * dvds[..., 2*car + 1])
            return drift + self.omega_max * torch.abs(dvds[..., 6 + car])

        ham = vehicle_term(0)
        ham += vehicle_term(1)
        ham += vehicle_term(2)
        return ham

    def optimal_control(self, state, dvds):
        """Return ``omega_max * sign(dV/dtheta_i)`` for each of the three headings."""
        heading_grads = dvds[..., 6:9]
        return self.omega_max*torch.sign(heading_grads)

    def optimal_disturbance(self, state, dvds):
        """Always return 0 (``disturbance_dim`` is 0 for this system)."""
        return 0

    def plot_config(self):
        """Return slice defaults, axis labels and axis indices for validation plots."""
        slices = [
            0, 0,
            -0.4, 0,
            0.4, 0,
            math.pi/2, math.pi/4, 3*math.pi/4,
        ]
        labels = [
            r'$x_1$', r'$y_1$',
            r'$x_2$', r'$y_2$',
            r'$x_3$', r'$y_3$',
            r'$\theta_1$', r'$\theta_2$', r'$\theta_3$',
        ]
        return {
            'state_slices': slices,
            'state_labels': labels,
            'x_axis_idx': 0,
            'y_axis_idx': 1,
            'z_axis_idx': 6,
        }

###############################################################
# Custom Dynamics for 2d Planar Robot
###############################################################
class PlanarRobot2D(Dynamics):
    """2D constant-speed point robot whose control is its heading angle.

    State (last axis): ``[p_x, p_y]``. Control: ``[theta]``. No disturbance
    is modelled. The boundary function is the distance to the origin minus
    ``goalR``.
    """

    def __init__(
        self,
        goalR: float,
        velocity: float,
        set_mode: str = 'avoid',
        freeze_model: bool = False
    ):
        """
        2D constant-speed robot dynamics for reach/avoid problems.

        Args:
            goalR: Radius of the obstacle (avoid) or goal (reach) circle.
            velocity: Constant forward speed (m/s).
            set_mode: 'reach' or 'avoid'.
            freeze_model: If True, dsdt/hamiltonian will raise NotImplementedError.
        """
        self.goalR = goalR
        self.velocity = velocity
        self.freeze_model = freeze_model

        super().__init__(
            loss_type='brt_hjivi',
            set_mode=set_mode,
            state_dim=2, # [p_x, p_y]
            input_dim=3, # [p_x, p_y, t]
            control_dim=1, # [theta]
            disturbance_dim=0,

            # Normalize x,y from [-2,2] to [-1,1]
            state_mean=[0.0, 0.0],
            state_var=[2.0, 2.0],

            # Original author's note: value_mean is not used in 'exact' mode
            # (depends on the base class, which is not visible here).
            value_mean=0.0,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        """Return ``[min, max]`` per state dimension."""
        return [
            [-2.0, 2.0], # p_x
            [-2.0, 2.0], # p_y
        ]

    def equivalent_wrapped_state(self, state):
        """Return ``state`` itself; there is no periodic dimension to wrap."""
        return state

    def dsdt(self, state, control, disturbance):
        """Return the state derivative; ``disturbance`` is not used.

        Raises:
            NotImplementedError: if the instance was built with ``freeze_model=True``.
        """
        # \dot p_x = v \cos \theta
        # \dot p_y = v \sin \theta
        if self.freeze_model:
            raise NotImplementedError
        theta = control[..., 0]
        ds = torch.zeros_like(state)
        ds[..., 0] = self.velocity * torch.cos(theta)
        ds[..., 1] = self.velocity * torch.sin(theta)
        return ds

    def boundary_fn(self, state):
        """Distance of ``state`` from the origin, minus ``goalR``."""
        return torch.norm(state, dim=-1) - self.goalR

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        # Deliberately unimplemented; the other dynamics in this file use
        # torch.min(self.boundary_fn(state_traj), dim=-1).values here.
        raise NotImplementedError

    def hamiltonian(self, state, dvds):
        """Evaluate the Hamiltonian ``-/+ velocity * ||dvds||`` for reach/avoid.

        Raises:
            NotImplementedError: if ``freeze_model`` is True, or if
                ``set_mode`` is neither 'reach' nor 'avoid'. Previously an
                unrecognised mode silently returned ``None``.
        """
        if self.freeze_model:
            raise NotImplementedError
        norm = torch.sqrt(dvds[..., 0]**2 + dvds[..., 1]**2)
        if self.set_mode == 'reach':
            return  - self.velocity * norm
        if self.set_mode == 'avoid':
            return  self.velocity * norm
        raise NotImplementedError(
            "PlanarRobot2D.hamiltonian supports set_mode 'reach' or 'avoid', got %r" % (self.set_mode,)
        )

    def optimal_control(self, state, dvds):
        raise NotImplementedError

    def optimal_disturbance(self, state, dvds):
        raise NotImplementedError

    def plot_config(self):
        """Return slice defaults, axis labels and axis indices for validation plots.

        ``z_axis_idx`` is -1 because there is no third state dimension to sweep.
        """
        return {
            "state_slices": [0, 0],
            "state_labels": ["p_x", "p_y"],
            "x_axis_idx": 0,
            "y_axis_idx": 1,
            "z_axis_idx": -1,
        }

###############################################################


### 6. Experiments Code

In [12]:
class Experiment(ABC):
    def __init__(self, model, dataset, experiment_dir, use_wandb):
        self.model = model
        self.dataset = dataset
        self.experiment_dir = experiment_dir
        self.use_wandb = use_wandb

    @abstractmethod
    def init_special(self):
        raise NotImplementedError

    def _load_checkpoint(self, epoch):
        if epoch == -1:
            model_path = os.path.join(self.experiment_dir, 'training', 'checkpoints', 'model_final.pth')
            self.model.load_state_dict(torch.load(model_path))
        else:
            model_path = os.path.join(self.experiment_dir, 'training', 'checkpoints', 'model_epoch_%04d.pth' % epoch)
            self.model.load_state_dict(torch.load(model_path)['model'])

    def validate(self, device, epoch, save_path, x_resolution, y_resolution, z_resolution, time_resolution):
        """Plot the sign of the learned value function on 2D slices and save the figure.

        One sub-plot is drawn per (time, z-slice) pair: rows are times in
        ``[0, dataset.tMax]``, columns are values of the dynamics' z axis.
        When ``plot_config()['z_axis_idx']`` is -1 (or absent) a single column
        is drawn and ``z_resolution`` is ignored.

        Args:
            device: device the query coordinates are moved to before the forward pass.
            epoch: logged as ``step`` when ``use_wandb`` is set.
            save_path: where the figure is written.
            x_resolution, y_resolution: grid size along the plotted x / y state axes.
            z_resolution: number of z slices (columns).
            time_resolution: number of time slices (rows).

        The model is put in eval mode with gradients disabled while plotting.
        If it was in training mode on entry, training mode and gradients are
        re-enabled on exit, including when plotting raises.
        """
        was_training = self.model.training
        self.model.eval()
        self.model.requires_grad_(False)

        try:
            dynamics = self.dataset.dynamics
            plot_config = dynamics.plot_config()
            state_test_range = dynamics.state_test_range()

            x_idx = plot_config['x_axis_idx']
            y_idx = plot_config['y_axis_idx']
            z_idx = plot_config.get('z_axis_idx', -1)

            x_min, x_max = state_test_range[x_idx]
            y_min, y_max = state_test_range[y_idx]
            xs = torch.linspace(x_min, x_max, x_resolution)
            ys = torch.linspace(y_min, y_max, y_resolution)
            xys = torch.cartesian_prod(xs, ys)

            # Determine z slices: if z_idx == -1, use a single dummy slice
            if z_idx == -1:
                zs = torch.tensor([0.0], dtype=torch.float32)
            else:
                z_min, z_max = state_test_range[z_idx]
                zs = torch.linspace(z_min, z_max, z_resolution)

            times = torch.linspace(0, self.dataset.tMax, time_resolution)

            # figsize is (width, height). The grid has len(times) rows and
            # len(zs) columns, so width scales with zs and height with times.
            fig = plt.figure(figsize=(5*len(zs), 5*len(times)))
            try:
                # Loop over time and z-slices
                for i, t in enumerate(times):
                    for j, z_val in enumerate(zs):
                        # Every row starts from the default slice, then the
                        # plotted axes are overwritten.
                        coords = torch.zeros(x_resolution*y_resolution, dynamics.state_dim + 1)
                        coords[:, 0] = t
                        coords[:, 1:] = torch.tensor(plot_config['state_slices'])

                        coords[:, 1 + x_idx] = xys[:, 0]
                        coords[:, 1 + y_idx] = xys[:, 1]

                        # Only fill z if it's a real 3D axis
                        if z_idx != -1:
                            coords[:, 1 + z_idx] = z_val

                        with torch.no_grad():
                            inp = dynamics.coord_to_input(coords.to(device))
                            results = self.model({'coords': inp})
                            values = dynamics.io_to_value(results['model_in'].detach(), results['model_out'].squeeze(dim=-1).detach())

                        ax = fig.add_subplot(len(times), len(zs), (j+1) + i*len(zs))
                        title_z = (f", {plot_config['state_labels'][z_idx]} = {z_val:.2f}"
                                if z_idx != -1 else "")
                        ax.set_title(f"t = {t:.2f}{title_z}")

                        # cartesian_prod varies y fastest, so reshape gives
                        # [x, y]; transpose to put y on image rows.
                        img = values.cpu().numpy().reshape(x_resolution, y_resolution).T
                        # binary decision boundary: <=0 in red/blue
                        mask = (img <= 0).astype(float)
                        s = ax.imshow(mask, cmap='bwr', origin='lower', extent=(x_min, x_max, y_min, y_max))
                        fig.colorbar(s, ax=ax)

                fig.savefig(save_path)
                if self.use_wandb:
                    wandb.log({
                        'step': epoch,
                        'val_plot': wandb.Image(fig),
                    })
            finally:
                # Close this specific figure rather than whichever is current.
                plt.close(fig)
        finally:
            if was_training:
                self.model.train()
                self.model.requires_grad_(True)

    def train(
            self, device, batch_size, epochs, lr,
            steps_til_summary, epochs_til_checkpoint,
            loss_fn, clip_grad, use_lbfgs, adjust_relative_grads,
            val_x_resolution, val_y_resolution, val_z_resolution, val_time_resolution,
            use_CSL, CSL_lr, CSL_dt, epochs_til_CSL, num_CSL_samples, CSL_loss_frac_cutoff, max_CSL_epochs, CSL_loss_weight, CSL_batch_size,
        ):
        was_eval = not self.model.training
        self.model.train()
        self.model.requires_grad_(True)

        train_dataloader = DataLoader(self.dataset, shuffle=True, batch_size=batch_size, pin_memory=True, num_workers=0)

        optim = torch.optim.Adam(lr=lr, params=self.model.parameters())

        # copy settings from Raissi et al. (2019) and here
        # https://github.com/maziarraissi/PINNs
        if use_lbfgs:
            optim = torch.optim.LBFGS(lr=lr, params=self.model.parameters(), max_iter=50000, max_eval=50000,
                                    history_size=50, line_search_fn='strong_wolfe')

        training_dir = os.path.join(self.experiment_dir, 'training')

        summaries_dir = os.path.join(training_dir, 'summaries')
        if not os.path.exists(summaries_dir):
            os.makedirs(summaries_dir)

        checkpoints_dir = os.path.join(training_dir, 'checkpoints')
        if not os.path.exists(checkpoints_dir):
            os.makedirs(checkpoints_dir)

        writer = SummaryWriter(summaries_dir)

        total_steps = 0

        if adjust_relative_grads:
            new_weight = 1

        with tqdm(total=len(train_dataloader) * epochs) as pbar:
            train_losses = []
            last_CSL_epoch = -1
            for epoch in range(0, epochs):
                if self.dataset.pretrain: # skip CSL
                    last_CSL_epoch = epoch
                time_interval_length = (self.dataset.counter/self.dataset.counter_end)*(self.dataset.tMax-self.dataset.tMin)
                CSL_tMax = self.dataset.tMin + int(time_interval_length/CSL_dt)*CSL_dt

                # self-supervised learning
                for step, (model_input, gt) in enumerate(train_dataloader):
                    start_time = time.time()

                    model_input = {key: value.to(device) for key, value in model_input.items()}
                    gt = {key: value.to(device) for key, value in gt.items()}

                    model_results = self.model({'coords': model_input['model_coords']})

                    states = self.dataset.dynamics.input_to_coord(model_results['model_in'].detach())[..., 1:]
                    values = self.dataset.dynamics.io_to_value(model_results['model_in'].detach(), model_results['model_out'].squeeze(dim=-1))
                    dvs = self.dataset.dynamics.io_to_dv(model_results['model_in'], model_results['model_out'].squeeze(dim=-1))
                    boundary_values = gt['boundary_values']
                    if self.dataset.dynamics.loss_type == 'brat_hjivi':
                        reach_values = gt['reach_values']
                        avoid_values = gt['avoid_values']
                    dirichlet_masks = gt['dirichlet_masks']

                    if self.dataset.dynamics.loss_type == 'brt_hjivi':
                        losses = loss_fn(states, values, dvs[..., 0], dvs[..., 1:], boundary_values, dirichlet_masks, model_results['model_out'])
                    elif self.dataset.dynamics.loss_type == 'brat_hjivi':
                        losses = loss_fn(states, values, dvs[..., 0], dvs[..., 1:], boundary_values, reach_values, avoid_values, dirichlet_masks, model_results['model_out'])
                    else:
                        raise NotImplementedError

                    if use_lbfgs:
                        def closure():
                            optim.zero_grad()
                            train_loss = 0.
                            for loss_name, loss in losses.items():
                                train_loss += loss.mean()
                            train_loss.backward()
                            return train_loss
                        optim.step(closure)

                    # Adjust the relative magnitude of the losses if required
                    if self.dataset.dynamics.deepreach_model in ['vanilla', 'diff'] and adjust_relative_grads:
                        if losses['diff_constraint_hom'] > 0.01:
                            params = OrderedDict(self.model.named_parameters())
                            # Gradients with respect to the PDE loss
                            optim.zero_grad()
                            losses['diff_constraint_hom'].backward(retain_graph=True)
                            grads_PDE = []
                            for key, param in params.items():
                                grads_PDE.append(param.grad.view(-1))
                            grads_PDE = torch.cat(grads_PDE)

                            # Gradients with respect to the boundary loss
                            optim.zero_grad()
                            losses['dirichlet'].backward(retain_graph=True)
                            grads_dirichlet = []
                            for key, param in params.items():
                                grads_dirichlet.append(param.grad.view(-1))
                            grads_dirichlet = torch.cat(grads_dirichlet)

                            # # Plot the gradients
                            # import seaborn as sns
                            # import matplotlib.pyplot as plt
                            # fig = plt.figure(figsize=(5, 5))
                            # ax = fig.add_subplot(1, 1, 1)
                            # ax.set_yscale('symlog')
                            # sns.distplot(grads_PDE.cpu().numpy(), hist=False, kde_kws={"shade": False}, norm_hist=True)
                            # sns.distplot(grads_dirichlet.cpu().numpy(), hist=False, kde_kws={"shade": False}, norm_hist=True)
                            # fig.savefig('gradient_visualization.png')

                            # fig = plt.figure(figsize=(5, 5))
                            # ax = fig.add_subplot(1, 1, 1)
                            # ax.set_yscale('symlog')
                            # grads_dirichlet_normalized = grads_dirichlet * torch.mean(torch.abs(grads_PDE))/torch.mean(torch.abs(grads_dirichlet))
                            # sns.distplot(grads_PDE.cpu().numpy(), hist=False, kde_kws={"shade": False}, norm_hist=True)
                            # sns.distplot(grads_dirichlet_normalized.cpu().numpy(), hist=False, kde_kws={"shade": False}, norm_hist=True)
                            # ax.set_xlim([-1000.0, 1000.0])
                            # fig.savefig('gradient_visualization_normalized.png')

                            # Set the new weight according to the paper
                            # num = torch.max(torch.abs(grads_PDE))
                            num = torch.mean(torch.abs(grads_PDE))
                            den = torch.mean(torch.abs(grads_dirichlet))
                            new_weight = 0.9*new_weight + 0.1*num/den
                            losses['dirichlet'] = new_weight*losses['dirichlet']
                        writer.add_scalar('weight_scaling', new_weight, total_steps)

                    # import ipdb; ipdb.set_trace()

                    train_loss = 0.
                    for loss_name, loss in losses.items():
                        single_loss = loss.mean()

                        if loss_name == 'dirichlet':
                            writer.add_scalar(loss_name, single_loss/new_weight, total_steps)
                        else:
                            writer.add_scalar(loss_name, single_loss, total_steps)
                        train_loss += single_loss

                    train_losses.append(train_loss.item())
                    writer.add_scalar("total_train_loss", train_loss, total_steps)

                    if not total_steps % steps_til_summary:
                        torch.save(self.model.state_dict(),
                                os.path.join(checkpoints_dir, 'model_current.pth'))
                        # summary_fn(model, model_input, gt, model_output, writer, total_steps)

                    if not use_lbfgs:
                        optim.zero_grad()
                        train_loss.backward()

                        if clip_grad:
                            if isinstance(clip_grad, bool):
                                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
                            else:
                                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=clip_grad)

                        optim.step()

                    pbar.update(1)

                    if not total_steps % steps_til_summary:
                        tqdm.write("Epoch %d, Total loss %0.6f, iteration time %0.6f" % (epoch, train_loss, time.time() - start_time))
                        if self.use_wandb:
                            wandb.log({
                                'step': epoch,
                                'train_loss': train_loss,
                                'pde_loss': losses['diff_constraint_hom'],
                            })

                    total_steps += 1

                # cost-supervised learning (CSL) phase
                if use_CSL and not self.dataset.pretrain and (epoch-last_CSL_epoch) >= epochs_til_CSL:
                    last_CSL_epoch = epoch

                    # generate CSL datasets
                    self.model.eval()

                    CSL_dataset = scenario_optimization(
                        device=device, model=self.model, policy=self.model, dynamics=self.dataset.dynamics,
                        tMin=self.dataset.tMin, tMax=CSL_tMax, dt=CSL_dt,
                        set_type="BRT", control_type="value", # TODO: implement option for BRS too
                        scenario_batch_size=min(num_CSL_samples, 100000), sample_batch_size=min(10*num_CSL_samples, 1000000),
                        sample_generator=SliceSampleGenerator(dynamics=self.dataset.dynamics, slices=[None]*self.dataset.dynamics.state_dim),
                        sample_validator=ValueThresholdValidator(v_min=float('-inf'), v_max=float('inf')),
                        violation_validator=ValueThresholdValidator(v_min=0.0, v_max=0.0),
                        max_scenarios=num_CSL_samples, tStart_generator=lambda n : torch.zeros(n).uniform_(self.dataset.tMin, CSL_tMax)
                    )
                    CSL_coords = torch.cat((CSL_dataset['times'].unsqueeze(-1), CSL_dataset['states']), dim=-1)
                    CSL_costs = CSL_dataset['costs']

                    num_CSL_val_samples = int(0.1*num_CSL_samples)
                    CSL_val_dataset = scenario_optimization(
                        model=self.model, policy=self.model, dynamics=self.dataset.dynamics,
                        tMin=self.dataset.tMin, tMax=CSL_tMax, dt=CSL_dt,
                        set_type="BRT", control_type="value", # TODO: implement option for BRS too
                        scenario_batch_size=min(num_CSL_val_samples, 100000), sample_batch_size=min(10*num_CSL_val_samples, 1000000),
                        sample_generator=SliceSampleGenerator(dynamics=self.dataset.dynamics, slices=[None]*self.dataset.dynamics.state_dim),
                        sample_validator=ValueThresholdValidator(v_min=float('-inf'), v_max=float('inf')),
                        violation_validator=ValueThresholdValidator(v_min=0.0, v_max=0.0),
                        max_scenarios=num_CSL_val_samples, tStart_generator=lambda n : torch.zeros(n).uniform_(self.dataset.tMin, CSL_tMax)
                    )
                    CSL_val_coords = torch.cat((CSL_val_dataset['times'].unsqueeze(-1), CSL_val_dataset['states']), dim=-1)
                    CSL_val_costs = CSL_val_dataset['costs']

                    CSL_val_tMax_dataset = scenario_optimization(
                        model=self.model, policy=self.model, dynamics=self.dataset.dynamics,
                        tMin=self.dataset.tMin, tMax=self.dataset.tMax, dt=CSL_dt,
                        set_type="BRT", control_type="value", # TODO: implement option for BRS too
                        scenario_batch_size=min(num_CSL_val_samples, 100000), sample_batch_size=min(10*num_CSL_val_samples, 1000000),
                        sample_generator=SliceSampleGenerator(dynamics=self.dataset.dynamics, slices=[None]*self.dataset.dynamics.state_dim),
                        sample_validator=ValueThresholdValidator(v_min=float('-inf'), v_max=float('inf')),
                        violation_validator=ValueThresholdValidator(v_min=0.0, v_max=0.0),
                        max_scenarios=num_CSL_val_samples # no tStart_generator, since I want all tMax times
                    )
                    CSL_val_tMax_coords = torch.cat((CSL_val_tMax_dataset['times'].unsqueeze(-1), CSL_val_tMax_dataset['states']), dim=-1)
                    CSL_val_tMax_costs = CSL_val_tMax_dataset['costs']

                    self.model.train()

                    # CSL optimizer
                    CSL_optim = torch.optim.Adam(lr=CSL_lr, params=self.model.parameters())

                    # initial CSL val loss
                    CSL_val_results = self.model({'coords': self.dataset.dynamics.coord_to_input(CSL_val_coords.to(device))})
                    CSL_val_preds = self.dataset.dynamics.io_to_value(CSL_val_results['model_in'], CSL_val_results['model_out'].squeeze(dim=-1))
                    CSL_val_errors = CSL_val_preds - CSL_val_costs.to(device)
                    CSL_val_loss = torch.mean(torch.pow(CSL_val_errors, 2))
                    CSL_initial_val_loss = CSL_val_loss
                    if self.use_wandb:
                        wandb.log({
                            "step": epoch,
                            "CSL_val_loss": CSL_val_loss.item()
                        })

                    # initial self-supervised learning (SSL) val loss
                    # right now, just took code from dataio.py and the SSL training loop above; TODO: refactor all this for cleaner modular code
                    CSL_val_states = CSL_val_coords[..., 1:].to(device)
                    CSL_val_dvs = self.dataset.dynamics.io_to_dv(CSL_val_results['model_in'], CSL_val_results['model_out'].squeeze(dim=-1))
                    CSL_val_boundary_values = self.dataset.dynamics.boundary_fn(CSL_val_states)
                    if self.dataset.dynamics.loss_type == 'brat_hjivi':
                        CSL_val_reach_values = self.dataset.dynamics.reach_fn(CSL_val_states)
                        CSL_val_avoid_values = self.dataset.dynamics.avoid_fn(CSL_val_states)
                    CSL_val_dirichlet_masks = CSL_val_coords[:, 0].to(device) == self.dataset.tMin # assumes time unit in dataset (model) is same as real time units
                    if self.dataset.dynamics.loss_type == 'brt_hjivi':
                        SSL_val_losses = loss_fn(CSL_val_states, CSL_val_preds, CSL_val_dvs[..., 0], CSL_val_dvs[..., 1:], CSL_val_boundary_values, CSL_val_dirichlet_masks)
                    elif self.dataset.dynamics.loss_type == 'brat_hjivi':
                        SSL_val_losses = loss_fn(CSL_val_states, CSL_val_preds, CSL_val_dvs[..., 0], CSL_val_dvs[..., 1:], CSL_val_boundary_values, CSL_val_reach_values, CSL_val_avoid_values, CSL_val_dirichlet_masks)
                    else:
                        NotImplementedError
                    SSL_val_loss = SSL_val_losses['diff_constraint_hom'].mean() # I assume there is no dirichlet (boundary) loss here, because I do not ever explicitly generate source samples at tMin (i.e. torch.all(CSL_val_dirichlet_masks == False))
                    if self.use_wandb:
                        wandb.log({
                            "step": epoch,
                            "SSL_val_loss": SSL_val_loss.item()
                        })

                    # CSL training loop
                    for CSL_epoch in tqdm(range(max_CSL_epochs)):
                        CSL_idxs = torch.randperm(num_CSL_samples)
                        for CSL_batch in range(math.ceil(num_CSL_samples/CSL_batch_size)):
                            CSL_batch_idxs = CSL_idxs[CSL_batch*CSL_batch_size:(CSL_batch+1)*CSL_batch_size]
                            CSL_batch_coords = CSL_coords[CSL_batch_idxs]

                            CSL_batch_results = self.model({'coords': self.dataset.dynamics.coord_to_input(CSL_batch_coords.to(device))})
                            CSL_batch_preds = self.dataset.dynamics.io_to_value(CSL_batch_results['model_in'], CSL_batch_results['model_out'].squeeze(dim=-1))
                            CSL_batch_costs = CSL_costs[CSL_batch_idxs].to(device)
                            CSL_batch_errors = CSL_batch_preds - CSL_batch_costs
                            CSL_batch_loss = CSL_loss_weight*torch.mean(torch.pow(CSL_batch_errors, 2))

                            CSL_batch_states = CSL_batch_coords[..., 1:].to(device)
                            CSL_batch_dvs = self.dataset.dynamics.io_to_dv(CSL_batch_results['model_in'], CSL_batch_results['model_out'].squeeze(dim=-1))
                            CSL_batch_boundary_values = self.dataset.dynamics.boundary_fn(CSL_batch_states)
                            if self.dataset.dynamics.loss_type == 'brat_hjivi':
                                CSL_batch_reach_values = self.dataset.dynamics.reach_fn(CSL_batch_states)
                                CSL_batch_avoid_values = self.dataset.dynamics.avoid_fn(CSL_batch_states)
                            CSL_batch_dirichlet_masks = CSL_batch_coords[:, 0].to(device) == self.dataset.tMin # assumes time unit in dataset (model) is same as real time units
                            if self.dataset.dynamics.loss_type == 'brt_hjivi':
                                SSL_batch_losses = loss_fn(CSL_batch_states, CSL_batch_preds, CSL_batch_dvs[..., 0], CSL_batch_dvs[..., 1:], CSL_batch_boundary_values, CSL_batch_dirichlet_masks)
                            elif self.dataset.dynamics.loss_type == 'brat_hjivi':
                                SSL_batch_losses = loss_fn(CSL_batch_states, CSL_batch_preds, CSL_batch_dvs[..., 0], CSL_batch_dvs[..., 1:], CSL_batch_boundary_values, CSL_batch_reach_values, CSL_batch_avoid_values, CSL_batch_dirichlet_masks)
                            else:
                                NotImplementedError
                            SSL_batch_loss = SSL_batch_losses['diff_constraint_hom'].mean() # I assume there is no dirichlet (boundary) loss here, because I do not ever explicitly generate source samples at tMin (i.e. torch.all(CSL_batch_dirichlet_masks == False))

                            CSL_optim.zero_grad()
                            SSL_batch_loss.backward(retain_graph=True)
                            if (not use_lbfgs) and clip_grad: # no adjust_relative_grads, because I assume even with adjustment, the diff_constraint_hom remains unaffected and the only other loss (dirichlet) is zero
                                if isinstance(clip_grad, bool):
                                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
                                else:
                                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=clip_grad)
                            CSL_batch_loss.backward()
                            CSL_optim.step()

                        # evaluate on CSL_val_dataset
                        CSL_val_results = self.model({'coords': self.dataset.dynamics.coord_to_input(CSL_val_coords.to(device))})
                        CSL_val_preds = self.dataset.dynamics.io_to_value(CSL_val_results['model_in'], CSL_val_results['model_out'].squeeze(dim=-1))
                        CSL_val_errors = CSL_val_preds - CSL_val_costs.to(device)
                        CSL_val_loss = torch.mean(torch.pow(CSL_val_errors, 2))

                        CSL_val_states = CSL_val_coords[..., 1:].to(device)
                        CSL_val_dvs = self.dataset.dynamics.io_to_dv(CSL_val_results['model_in'], CSL_val_results['model_out'].squeeze(dim=-1))
                        CSL_val_boundary_values = self.dataset.dynamics.boundary_fn(CSL_val_states)
                        if self.dataset.dynamics.loss_type == 'brat_hjivi':
                            CSL_val_reach_values = self.dataset.dynamics.reach_fn(CSL_val_states)
                            CSL_val_avoid_values = self.dataset.dynamics.avoid_fn(CSL_val_states)
                        CSL_val_dirichlet_masks = CSL_val_coords[:, 0].to(device) == self.dataset.tMin # assumes time unit in dataset (model) is same as real time units
                        if self.dataset.dynamics.loss_type == 'brt_hjivi':
                            SSL_val_losses = loss_fn(CSL_val_states, CSL_val_preds, CSL_val_dvs[..., 0], CSL_val_dvs[..., 1:], CSL_val_boundary_values, CSL_val_dirichlet_masks)
                        elif self.dataset.dynamics.loss_type == 'brat_hjivi':
                            SSL_val_losses = loss_fn(CSL_val_states, CSL_val_preds, CSL_val_dvs[..., 0], CSL_val_dvs[..., 1:], CSL_val_boundary_values, CSL_val_reach_values, CSL_val_avoid_values, CSL_val_dirichlet_masks)
                        else:
                            raise NotImplementedError
                        SSL_val_loss = SSL_val_losses['diff_constraint_hom'].mean() # I assume there is no dirichlet (boundary) loss here, because I do not ever explicitly generate source samples at tMin (i.e. torch.all(CSL_val_dirichlet_masks == False))

                        CSL_val_tMax_results = self.model({'coords': self.dataset.dynamics.coord_to_input(CSL_val_tMax_coords.to(device))})
                        CSL_val_tMax_preds = self.dataset.dynamics.io_to_value(CSL_val_tMax_results['model_in'], CSL_val_tMax_results['model_out'].squeeze(dim=-1))
                        CSL_val_tMax_errors = CSL_val_tMax_preds - CSL_val_tMax_costs.to(device)
                        CSL_val_tMax_loss = torch.mean(torch.pow(CSL_val_tMax_errors, 2))

                        # log CSL losses, recovered_safe_set_fracs
                        if self.dataset.dynamics.set_mode == 'reach':
                            CSL_train_batch_theoretically_recoverable_safe_set_frac = torch.sum(CSL_batch_costs.to(device) < 0) / len(CSL_batch_preds)
                            CSL_train_batch_recovered_safe_set_frac = torch.sum(CSL_batch_preds < torch.min(CSL_batch_preds[CSL_batch_costs.to(device) > 0])) / len(CSL_batch_preds)
                            CSL_val_theoretically_recoverable_safe_set_frac = torch.sum(CSL_val_costs.to(device) < 0) / len(CSL_val_preds)
                            CSL_val_recovered_safe_set_frac = torch.sum(CSL_val_preds < torch.min(CSL_val_preds[CSL_val_costs.to(device) > 0])) / len(CSL_val_preds)
                            CSL_val_tMax_theoretically_recoverable_safe_set_frac = torch.sum(CSL_val_tMax_costs.to(device) < 0) / len(CSL_val_tMax_preds)
                            CSL_val_tMax_recovered_safe_set_frac = torch.sum(CSL_val_tMax_preds < torch.min(CSL_val_tMax_preds[CSL_val_tMax_costs.to(device) > 0])) / len(CSL_val_tMax_preds)
                        elif self.dataset.dynamics.set_mode == 'avoid':
                            CSL_train_batch_theoretically_recoverable_safe_set_frac = torch.sum(CSL_batch_costs.to(device) > 0) / len(CSL_batch_preds)
                            CSL_train_batch_recovered_safe_set_frac = torch.sum(CSL_batch_preds > torch.max(CSL_batch_preds[CSL_batch_costs.to(device) < 0])) / len(CSL_batch_preds)
                            CSL_val_theoretically_recoverable_safe_set_frac = torch.sum(CSL_val_costs.to(device) > 0) / len(CSL_val_preds)
                            CSL_val_recovered_safe_set_frac = torch.sum(CSL_val_preds > torch.max(CSL_val_preds[CSL_val_costs.to(device) < 0])) / len(CSL_val_preds)
                            CSL_val_tMax_theoretically_recoverable_safe_set_frac = torch.sum(CSL_val_tMax_costs.to(device) > 0) / len(CSL_val_tMax_preds)
                            CSL_val_tMax_recovered_safe_set_frac = torch.sum(CSL_val_tMax_preds > torch.max(CSL_val_tMax_preds[CSL_val_tMax_costs.to(device) < 0])) / len(CSL_val_tMax_preds)
                        else:
                            raise NotImplementedError
                        if self.use_wandb:
                            wandb.log({
                                "step": epoch+(CSL_epoch+1)*int(0.5*epochs_til_CSL/max_CSL_epochs),
                                "CSL_train_batch_loss": CSL_batch_loss.item(),
                                "SSL_train_batch_loss": SSL_batch_loss.item(),
                                "CSL_val_loss": CSL_val_loss.item(),
                                "SSL_val_loss": SSL_val_loss.item(),
                                "CSL_val_tMax_loss": CSL_val_tMax_loss.item(),
                                "CSL_train_batch_theoretically_recoverable_safe_set_frac": CSL_train_batch_theoretically_recoverable_safe_set_frac.item(),
                                "CSL_val_theoretically_recoverable_safe_set_frac": CSL_val_theoretically_recoverable_safe_set_frac.item(),
                                "CSL_val_tMax_theoretically_recoverable_safe_set_frac": CSL_val_tMax_theoretically_recoverable_safe_set_frac.item(),
                                "CSL_train_batch_recovered_safe_set_frac": CSL_train_batch_recovered_safe_set_frac.item(),
                                "CSL_val_recovered_safe_set_frac": CSL_val_recovered_safe_set_frac.item(),
                                "CSL_val_tMax_recovered_safe_set_frac": CSL_val_tMax_recovered_safe_set_frac.item(),
                            })

                        if CSL_val_loss < CSL_loss_frac_cutoff*CSL_initial_val_loss:
                            break

                if not (epoch+1) % epochs_til_checkpoint:
                    # Saving the optimizer state is important to produce consistent results
                    checkpoint = {
                        'epoch': epoch+1,
                        'model': self.model.state_dict(),
                        'optimizer': optim.state_dict()}
                    torch.save(checkpoint,
                        os.path.join(checkpoints_dir, 'model_epoch_%04d.pth' % (epoch+1)))
                    np.savetxt(os.path.join(checkpoints_dir, 'train_losses_epoch_%04d.txt' % (epoch+1)),
                        np.array(train_losses))
                    self.validate(
                        device=device, epoch=epoch+1, save_path=os.path.join(checkpoints_dir, 'BRS_validation_plot_epoch_%04d.png' % (epoch+1)),
                        x_resolution = val_x_resolution, y_resolution = val_y_resolution, z_resolution=val_z_resolution, time_resolution=val_time_resolution)

        if was_eval:
            self.model.eval()
            self.model.requires_grad_(False)

    def test(self, device, current_time, last_checkpoint, checkpoint_dt, dt, num_scenarios, num_violations, set_type, control_type, data_step, checkpoint_toload=None):
        was_training = self.model.training
        self.model.eval()
        self.model.requires_grad_(False)

        testing_dir = os.path.join(self.experiment_dir, 'testing_%s' % current_time.strftime('%m_%d_%Y_%H_%M'))
        if os.path.exists(testing_dir):
            overwrite = input("The testing directory %s already exists. Overwrite? (y/n)"%testing_dir)
            if not (overwrite == 'y'):
                print('Exiting.')
                quit()
            shutil.rmtree(testing_dir)
        os.makedirs(testing_dir)

        if checkpoint_toload is None:
            print('running cross-checkpoint testing')

            for i in tqdm(range(sidelen), desc='Checkpoint'):
                self._load_checkpoint(epoch=checkpoints[i])
                raise NotImplementedError

        else:
            print('running specific-checkpoint testing')
            self._load_checkpoint(checkpoint_toload)

            model = self.model
            dataset = self.dataset
            dynamics = dataset.dynamics
            raise NotImplementedError

        if was_training:
            self.model.train()
            self.model.requires_grad_(True)

class DeepReach(Experiment):
    """Experiment subclass that adds no special initialisation of its own."""

    def init_special(self):
        """Do nothing; this experiment takes no extra initialisation arguments."""

### 7. Run Experiment

In [13]:
# Log in to Weights & Biases; the call is left as the cell's last expression so its result is displayed.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33maanirudh[0m ([33maanirudh-n-a[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [16]:
import sys

# Emulate a command-line invocation by overwriting sys.argv with the options
# for this run. The first entry stands in for the script name.
sys.argv = (
    ["script_name"]
    # run mode and experiment / dynamics selection
    + ["--mode", "train"]
    + ["--experiment_class", "DeepReach"]
    + ["--dynamics_class", "PlanarRobot2D"]
    + ["--experiment_name", "brt_obstacle_05m"]
    # problem-specific options
    + ["--minWith", "target"]
    + ["--goalR", "0.5"]
    + ["--velocity", "1.0"]
    + ["--set_mode", "reach"]
    + ["--num_epochs", "40000"]
    # Weights & Biases logging options
    + ["--wandb_project", "reachability-experiments"]
    + ["--wandb_entity", "aanirudh-n-a"]
    + ["--wandb_group", "Training"]
    + ["--wandb_name", "brt_safeset_05m_run_02"]
)

In [17]:
import inspect
from inspect import isclass


def str2bool(value):
    """Convert a command-line / config-file token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so such an option could never be
    switched off from the command line.

    Args:
        value: A bool (returned unchanged) or a string such as 'true'/'false',
            'yes'/'no', '1'/'0' (case-insensitive). The empty string maps to
            False, matching ``bool('')``.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If the token is not a recognised boolean spelling; argparse
            turns this into an "invalid value" usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError('expected a boolean value, got %r' % (value,))


# The parser is built incrementally: a first partial parse reads --use_wandb and
# --mode, and those decide which further options get registered before the final
# strict parse at the bottom.
p = configargparse.ArgumentParser()

p.add_argument('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
p.add_argument('--mode', type=str, required=True, choices=['all', 'train', 'test'], help="Experiment mode to run (new experiments must choose 'all' or 'train').")

# save/load directory options
p.add_argument('--experiments_dir', type=str, default='./runs', help='Where to save the experiment subdirectory.')
p.add_argument('--experiment_name', type=str, required=True, help='Name of the experiment subdirectory.')
# NOTE: store_false with default=True means wandb is ON unless this flag is given.
p.add_argument('--use_wandb', default=True, action='store_false', help='wandb logging is enabled by default; pass this flag to DISABLE it')

# The device is read as opt.device in every mode (model placement, train, test),
# so it has to be registered regardless of --mode.
p.add_argument('--device', type=str, default='cuda:0', required=False, help='CUDA Device to use.')

use_wandb = p.parse_known_args()[0].use_wandb
if use_wandb:
    p.add_argument('--wandb_project', type=str, required=True, help='wandb project')
    p.add_argument('--wandb_entity', type=str, required=True, help='wandb entity')
    p.add_argument('--wandb_group', type=str, required=True, help='wandb group')
    p.add_argument('--wandb_name', type=str, required=True, help='name of wandb run')

mode = p.parse_known_args()[0].mode

if (mode == 'all') or (mode == 'train'):
    p.add_argument('--seed', type=int, default=0, required=False, help='Seed for the experiment.')

    # Collect every Experiment subclass currently defined in the notebook
    # namespace; these names are the accepted values of --experiment_class.
    experiment_classes_dict = {
        name: clss for name, clss in globals().items()
        if isclass(clss) and issubclass(clss, Experiment) and clss is not Experiment
    }
    p.add_argument('--experiment_class', type=str, default='DeepReach', choices=experiment_classes_dict.keys(), help='Experiment class to use.')
    # NOTE: the class is hard-coded to DeepReach here; --experiment_class is
    # validated against the choices above but does not select the class.
    experiment_class = DeepReach
    # Every parameter of init_special (other than self) becomes a required option.
    experiment_params = {name: param for name, param in inspect.signature(experiment_class.init_special).parameters.items() if name != 'self'}
    for param in experiment_params.keys():
        p.add_argument('--' + param, type=experiment_params[param].annotation, required=True, help='special experiment_class argument')

    # simulation data source options
    p.add_argument('--numpoints', type=int, default=65000, help='Number of points in simulation data source __getitem__.')
    p.add_argument('--pretrain', action='store_true', default=False, required=False, help='Pretrain dirichlet conditions')
    p.add_argument('--pretrain_iters', type=int, default=2000, required=False, help='Number of pretrain iterations')
    p.add_argument('--tMin', type=float, default=0.0, required=False, help='Start time of the simulation')
    p.add_argument('--tMax', type=float, default=1.0, required=False, help='End time of the simulation')
    p.add_argument('--counter_start', type=int, default=0, required=False, help='Defines the initial time for the curriculum training')
    p.add_argument('--counter_end', type=int, default=-1, required=False, help='Defines the linear step for curriculum training starting from the initial time')
    p.add_argument('--num_src_samples', type=int, default=1000, required=False, help='Number of source samples (initial-time samples) at each time step')
    p.add_argument('--num_target_samples', type=int, default=0, required=False, help='Number of samples inside the target set')

    # model options
    p.add_argument('--model', type=str, default='sine', required=False, choices=['sine', 'tanh', 'sigmoid', 'relu'], help='Type of model to evaluate, default is sine.')
    p.add_argument('--model_mode', type=str, default='mlp', required=False, choices=['mlp', 'rbf', 'pinn'], help='Network mode passed to SingleBVPNet, default is mlp.')
    p.add_argument('--num_hl', type=int, default=3, required=False, help='The number of hidden layers')
    p.add_argument('--num_nl', type=int, default=512, required=False, help='Number of neurons per hidden layer.')
    p.add_argument('--deepreach_model', type=str, default='exact', required=False, choices=['exact', 'diff', 'vanilla'], help='deepreach model')

    # training options
    p.add_argument('--epochs_til_ckpt', type=int, default=1000, help='Number of epochs until checkpoint is saved.')
    p.add_argument('--steps_til_summary', type=int, default=100, help='Number of steps until tensorboard summary is saved.')
    p.add_argument('--batch_size', type=int, default=1, help='Batch size used during training (irrelevant, since len(dataset) == 1).')
    p.add_argument('--lr', type=float, default=2e-5, help='learning rate. default=2e-5')
    p.add_argument('--num_epochs', type=int, default=100000, help='Number of epochs to train for.')
    p.add_argument('--clip_grad', default=0.0, type=float, help='Clip gradient.')
    # str2bool (not bool) so that "--use_lbfgs False" really yields False.
    p.add_argument('--use_lbfgs', default=False, type=str2bool, help='use L-BFGS.')
    p.add_argument('--adj_rel_grads', default=True, type=str2bool, help='adjust the relative magnitude of the losses')
    p.add_argument('--dirichlet_loss_divisor', default=1.0, required=False, type=float, help='What to divide the dirichlet loss by for loss reweighting')

    # cost-supervised learning (CSL) options
    p.add_argument('--use_CSL', default=False, action='store_true', help='use cost-supervised learning (CSL)')
    p.add_argument('--CSL_lr', type=float, default=2e-5, help='The learning rate used for CSL')
    p.add_argument('--CSL_dt', type=float, default=0.0025, help='The dt used in rolling out trajectories to get cost labels')
    p.add_argument('--epochs_til_CSL', type=int, default=10000, help='Number of epochs between CSL phases')
    p.add_argument('--num_CSL_samples', type=int, default=1000000, help='Number of cost samples in training dataset for CSL phases')
    p.add_argument('--CSL_loss_frac_cutoff', type=float, default=0.1, help='Fraction of initial cost loss on validation dataset to cutoff CSL phases')
    p.add_argument('--max_CSL_epochs', type=int, default=100, help='Max number of CSL epochs per phase')
    p.add_argument('--CSL_loss_weight', type=float, default=1.0, help='weight of cost loss (relative to PDE loss)')
    p.add_argument('--CSL_batch_size', type=int, default=1000, help='Batch size for training in CSL phases')

    # validation (during training) options
    p.add_argument('--val_x_resolution', type=int, default=200, help='x-axis resolution of validation plot during training')
    p.add_argument('--val_y_resolution', type=int, default=200, help='y-axis resolution of validation plot during training')
    p.add_argument('--val_z_resolution', type=int, default=5, help='z-axis resolution of validation plot during training')
    p.add_argument('--val_time_resolution', type=int, default=3, help='time-axis resolution of validation plot during training')

    # loss options
    p.add_argument('--minWith', type=str, required=True, choices=['none', 'zero', 'target'], help='BRS vs BRT computation (typically should be using target for BRT)')

    # dynamics options: only PlanarRobot2D is accepted, and the class is
    # hard-coded rather than looked up from --dynamics_class.
    p.add_argument('--dynamics_class', type=str, required=True, choices=["PlanarRobot2D"], help='Dynamics class to use.')
    dynamics_class = PlanarRobot2D
    dynamics_params = {
        name: param for name, param in inspect.signature(dynamics_class).parameters.items() if name != 'self'
    }

    # Each constructor parameter becomes an option typed by its annotation.
    # Boolean parameters are optional (default False); all others are required.
    for param in dynamics_params.keys():
        if dynamics_params[param].annotation is bool:
            p.add_argument('--' + param, type=str2bool, default=False, help='special dynamics_class argument')
        else:
            p.add_argument('--' + param, type=dynamics_params[param].annotation, required=True, help='special dynamics_class argument')

if (mode == 'all') or (mode == 'test'):
    p.add_argument('--dt', type=float, default=0.0025, help='The dt used in testing simulations')
    p.add_argument('--checkpoint_toload', type=int, default=None, help="The checkpoint to load for testing (-1 for final training checkpoint, None for cross-checkpoint testing)")
    p.add_argument('--num_scenarios', type=int, default=100000, help='The number of scenarios sampled in scenario optimization for testing')
    p.add_argument('--num_violations', type=int, default=1000, help='The number of violations to sample for in scenario optimization for testing')
    p.add_argument('--control_type', type=str, default='value', choices=['value', 'ttr', 'init_ttr'], help='The controller to use in scenario optimization for testing')
    p.add_argument('--data_step', type=str, default='run_basic_recovery', choices=['plot_violations', 'run_basic_recovery', 'plot_basic_recovery', 'collect_samples', 'train_binner', 'run_binned_recovery', 'plot_binned_recovery', 'plot_cost_function'], help='The data processing step to run')

# Final strict parse: unknown options are now an error.
opt = p.parse_args()

# Resolve and validate the experiment directory BEFORE opening a wandb run, so
# that a declined overwrite or a missing test directory does not leave an
# orphan run behind.
experiment_dir = os.path.join(opt.experiments_dir, opt.experiment_name)
if (mode == 'all') or (mode == 'train'):
    # create experiment dir (asking before wiping an existing one)
    if os.path.exists(experiment_dir):
        overwrite = input("The experiment directory %s already exists. Overwrite? (y/n)"%experiment_dir)
        if not (overwrite == 'y'):
            print('Exiting.')
            # quit() is an interactive-shell helper and tears down the notebook
            # kernel; SystemExit just stops this cell and keeps all definitions.
            raise SystemExit(0)
        shutil.rmtree(experiment_dir)
    os.makedirs(experiment_dir)
elif mode == 'test':
    # confirm that experiment dir already exists
    if not os.path.exists(experiment_dir):
        raise RuntimeError('Cannot run test mode: experiment directory not found!')

# start wandb
if use_wandb:
    wandb.init(
        project = opt.wandb_project,
        entity = opt.wandb_entity,
        group = opt.wandb_group,
        name = opt.wandb_name,
    )
    # record every parsed option on the wandb run
    wandb.config.update(opt)

current_time = datetime.now()
# log current config as "name = value" lines in a timestamped text file
with open(os.path.join(experiment_dir, 'config_%s.txt' % current_time.strftime('%m_%d_%Y_%H_%M')), 'w') as f:
    for arg, val in vars(opt).items():
        f.write(arg + ' = ' + str(val) + '\n')

# Location of the pickled options written at training time.
orig_opt_path = os.path.join(experiment_dir, 'orig_opt.pickle')

if mode in ('all', 'train'):
    # -1 is the "not set" sentinel for counter_end: fall back to num_epochs.
    if opt.counter_end == -1:
        opt.counter_end = opt.num_epochs

    # Persist the (possibly adjusted) options for later runs to reload.
    with open(orig_opt_path, 'wb') as pickle_out:
        pickle_out.write(pickle.dumps(opt))

# Always read the options back from disk; in 'test' mode this file comes from
# an earlier training run.
# NOTE: unpickling executes arbitrary code -- only load experiment directories
# you created yourself.
with open(orig_opt_path, 'rb') as pickle_in:
    orig_opt = pickle.loads(pickle_in.read())

# Seed every random number generator used here from the saved seed.
experiment_seed = orig_opt.seed
np.random.seed(experiment_seed)
random.seed(experiment_seed)
torch.manual_seed(experiment_seed)

# --- dynamics -----------------------------------------------------------------
# The dynamics class is fixed to PlanarRobot2D; its constructor arguments are
# read from the saved options by parameter name.
dynamics_class = PlanarRobot2D
dynamics_kwargs = {
    argname: getattr(orig_opt, argname)
    for argname in inspect.signature(dynamics_class).parameters
    if argname != 'self'
}
dynamics = dynamics_class(**dynamics_kwargs)
dynamics.deepreach_model = orig_opt.deepreach_model

# --- dataset ------------------------------------------------------------------
dataset = ReachabilityDataset(
    dynamics=dynamics,
    numpoints=orig_opt.numpoints,
    pretrain=orig_opt.pretrain,
    pretrain_iters=orig_opt.pretrain_iters,
    tMin=orig_opt.tMin,
    tMax=orig_opt.tMax,
    counter_start=orig_opt.counter_start,
    counter_end=orig_opt.counter_end,
    num_src_samples=orig_opt.num_src_samples,
    num_target_samples=orig_opt.num_target_samples,
)

# --- model --------------------------------------------------------------------
model = SingleBVPNet(
    in_features=dynamics.input_dim,
    out_features=1,
    type=orig_opt.model,
    mode=orig_opt.model_mode,
    final_layer_factor=1.,
    hidden_features=orig_opt.num_nl,
    num_hidden_layers=orig_opt.num_hl,
)
model.to(opt.device)

# --- experiment ---------------------------------------------------------------
# The experiment class is fixed to DeepReach; init_special receives whatever
# parameters its signature declares, looked up on the saved options.
experiment_class = DeepReach
experiment = experiment_class(
    model=model,
    dataset=dataset,
    experiment_dir=experiment_dir,
    use_wandb=use_wandb,
)
special_kwargs = {
    argname: getattr(orig_opt, argname)
    for argname in inspect.signature(experiment_class.init_special).parameters
    if argname != 'self'
}
experiment.init_special(**special_kwargs)

# --- training -----------------------------------------------------------------
if mode in ('all', 'train'):
    # Choose the loss constructor matching the dynamics' declared loss type.
    if dynamics.loss_type == 'brt_hjivi':
        loss_fn = init_brt_hjivi_loss(dynamics, orig_opt.minWith, orig_opt.dirichlet_loss_divisor)
    elif dynamics.loss_type == 'brat_hjivi':
        loss_fn = init_brat_hjivi_loss(dynamics, orig_opt.minWith, orig_opt.dirichlet_loss_divisor)
    else:
        raise NotImplementedError

    experiment.train(
        device=opt.device,
        batch_size=orig_opt.batch_size,
        epochs=orig_opt.num_epochs,
        lr=orig_opt.lr,
        steps_til_summary=orig_opt.steps_til_summary,
        epochs_til_checkpoint=orig_opt.epochs_til_ckpt,
        loss_fn=loss_fn,
        clip_grad=orig_opt.clip_grad,
        use_lbfgs=orig_opt.use_lbfgs,
        adjust_relative_grads=orig_opt.adj_rel_grads,
        val_x_resolution=orig_opt.val_x_resolution,
        val_y_resolution=orig_opt.val_y_resolution,
        val_z_resolution=orig_opt.val_z_resolution,
        val_time_resolution=orig_opt.val_time_resolution,
        use_CSL=orig_opt.use_CSL,
        CSL_lr=orig_opt.CSL_lr,
        CSL_dt=orig_opt.CSL_dt,
        epochs_til_CSL=orig_opt.epochs_til_CSL,
        num_CSL_samples=orig_opt.num_CSL_samples,
        CSL_loss_frac_cutoff=orig_opt.CSL_loss_frac_cutoff,
        max_CSL_epochs=orig_opt.max_CSL_epochs,
        CSL_loss_weight=orig_opt.CSL_loss_weight,
        CSL_batch_size=orig_opt.CSL_batch_size,
    )

# --- testing ------------------------------------------------------------------
if mode in ('all', 'test'):
    # Checkpoint schedule comes from the saved training options; everything
    # else comes from the options of the current invocation.
    experiment.test(
        device=opt.device,
        current_time=current_time,
        last_checkpoint=orig_opt.num_epochs,
        checkpoint_dt=orig_opt.epochs_til_ckpt,
        checkpoint_toload=opt.checkpoint_toload,
        dt=opt.dt,
        num_scenarios=opt.num_scenarios,
        num_violations=opt.num_violations,
        set_type='BRT' if orig_opt.minWith in ['zero', 'target'] else 'BRS',
        control_type=opt.control_type,
        data_step=opt.data_step,
    )

SingleBVPNet(
  (net): FCBlock(
    (net): Sequential(
      (0): Sequential(
        (0): BatchLinear(in_features=3, out_features=512, bias=True)
        (1): Sine()
      )
      (1): Sequential(
        (0): BatchLinear(in_features=512, out_features=512, bias=True)
        (1): Sine()
      )
      (2): Sequential(
        (0): BatchLinear(in_features=512, out_features=512, bias=True)
        (1): Sine()
      )
      (3): Sequential(
        (0): BatchLinear(in_features=512, out_features=512, bias=True)
        (1): Sine()
      )
      (4): Sequential(
        (0): BatchLinear(in_features=512, out_features=1, bias=True)
      )
    )
  )
)


  0%|          | 1/40000 [00:01<17:23:09,  1.56s/it]

Epoch 0, Total loss 2817.822021, iteration time 1.532481


  0%|          | 102/40000 [00:22<2:19:49,  4.76it/s]

Epoch 100, Total loss 543.033325, iteration time 0.211500


  1%|          | 202/40000 [00:43<2:32:16,  4.36it/s]

Epoch 200, Total loss 207.233109, iteration time 0.243977


  1%|          | 302/40000 [01:04<2:24:08,  4.59it/s]

Epoch 300, Total loss 349.681946, iteration time 0.223981


  1%|          | 402/40000 [01:26<2:23:52,  4.59it/s]

Epoch 400, Total loss 434.128448, iteration time 0.219529


  1%|▏         | 502/40000 [01:47<2:24:39,  4.55it/s]

Epoch 500, Total loss 755.522095, iteration time 0.221378


  2%|▏         | 602/40000 [02:09<2:27:15,  4.46it/s]

Epoch 600, Total loss 706.524048, iteration time 0.229115


  2%|▏         | 702/40000 [02:31<2:32:07,  4.31it/s]

Epoch 700, Total loss 737.521729, iteration time 0.234577


  2%|▏         | 802/40000 [02:53<2:27:48,  4.42it/s]

Epoch 800, Total loss 860.764526, iteration time 0.230392


  2%|▏         | 902/40000 [03:16<2:28:00,  4.40it/s]

Epoch 900, Total loss 1008.637634, iteration time 0.235151


  3%|▎         | 1002/40000 [03:39<4:32:33,  2.38it/s]

Epoch 1000, Total loss 1008.707703, iteration time 0.106822


  3%|▎         | 1102/40000 [04:01<2:27:08,  4.41it/s]

Epoch 1100, Total loss 1103.592773, iteration time 0.230654


  3%|▎         | 1202/40000 [04:23<2:27:55,  4.37it/s]

Epoch 1200, Total loss 1262.625488, iteration time 0.231828


  3%|▎         | 1302/40000 [04:46<2:27:42,  4.37it/s]

Epoch 1300, Total loss 1768.937012, iteration time 0.233292


  4%|▎         | 1402/40000 [05:08<2:28:12,  4.34it/s]

Epoch 1400, Total loss 1438.722656, iteration time 0.240109


  4%|▍         | 1502/40000 [05:31<2:31:15,  4.24it/s]

Epoch 1500, Total loss 1488.441406, iteration time 0.234791


  4%|▍         | 1602/40000 [05:53<2:26:56,  4.36it/s]

Epoch 1600, Total loss 1638.483765, iteration time 0.230489


  4%|▍         | 1702/40000 [06:16<2:27:06,  4.34it/s]

Epoch 1700, Total loss 1719.870605, iteration time 0.230929


  5%|▍         | 1802/40000 [06:38<2:27:27,  4.32it/s]

Epoch 1800, Total loss 1774.698730, iteration time 0.247271


  5%|▍         | 1902/40000 [07:01<2:27:51,  4.29it/s]

Epoch 1900, Total loss 1747.013184, iteration time 0.238963


  5%|▌         | 2002/40000 [07:24<3:37:25,  2.91it/s]

Epoch 2000, Total loss 2048.428711, iteration time 0.105961


  5%|▌         | 2102/40000 [07:47<2:29:23,  4.23it/s]

Epoch 2100, Total loss 2458.502930, iteration time 0.245700


  6%|▌         | 2202/40000 [08:09<2:27:13,  4.28it/s]

Epoch 2200, Total loss 2117.459229, iteration time 0.255143


  6%|▌         | 2302/40000 [08:32<2:27:12,  4.27it/s]

Epoch 2300, Total loss 2288.759033, iteration time 0.248941


  6%|▌         | 2402/40000 [08:55<2:24:17,  4.34it/s]

Epoch 2400, Total loss 2406.125000, iteration time 0.232164


  6%|▋         | 2502/40000 [09:18<2:25:37,  4.29it/s]

Epoch 2500, Total loss 2094.540283, iteration time 0.250179


  7%|▋         | 2602/40000 [09:40<2:24:48,  4.30it/s]

Epoch 2600, Total loss 2459.079590, iteration time 0.237214


  7%|▋         | 2702/40000 [10:03<2:24:24,  4.30it/s]

Epoch 2700, Total loss 2318.969482, iteration time 0.231741


  7%|▋         | 2802/40000 [10:26<2:24:50,  4.28it/s]

Epoch 2800, Total loss 2184.707275, iteration time 0.234876


  7%|▋         | 2902/40000 [10:49<2:25:24,  4.25it/s]

Epoch 2900, Total loss 2781.162598, iteration time 0.231577


  8%|▊         | 3002/40000 [11:12<3:31:53,  2.91it/s]

Epoch 3000, Total loss 2816.377441, iteration time 0.099703


  8%|▊         | 3102/40000 [11:35<2:23:16,  4.29it/s]

Epoch 3100, Total loss 2344.909424, iteration time 0.243244


  8%|▊         | 3202/40000 [11:58<2:25:23,  4.22it/s]

Epoch 3200, Total loss 1522.130615, iteration time 0.242423


  8%|▊         | 3302/40000 [12:21<2:22:49,  4.28it/s]

Epoch 3300, Total loss 2952.756348, iteration time 0.239432


  9%|▊         | 3402/40000 [12:44<2:24:21,  4.23it/s]

Epoch 3400, Total loss 1314.705444, iteration time 0.241125


  9%|▉         | 3502/40000 [13:08<2:23:52,  4.23it/s]

Epoch 3500, Total loss 982.790161, iteration time 0.239906


  9%|▉         | 3602/40000 [13:31<2:24:58,  4.18it/s]

Epoch 3600, Total loss 1119.353149, iteration time 0.243370


  9%|▉         | 3702/40000 [13:54<2:22:53,  4.23it/s]

Epoch 3700, Total loss 853.034180, iteration time 0.240279


 10%|▉         | 3802/40000 [14:18<2:23:53,  4.19it/s]

Epoch 3800, Total loss 1214.907104, iteration time 0.252574


 10%|▉         | 3902/40000 [14:41<2:24:03,  4.18it/s]

Epoch 3900, Total loss 830.736755, iteration time 0.244334


 10%|█         | 4002/40000 [15:05<3:28:06,  2.88it/s]

Epoch 4000, Total loss 736.334717, iteration time 0.092295


 10%|█         | 4102/40000 [15:28<2:23:04,  4.18it/s]

Epoch 4100, Total loss 764.887329, iteration time 0.246836


 11%|█         | 4202/40000 [15:52<2:23:22,  4.16it/s]

Epoch 4200, Total loss 707.307983, iteration time 0.251731


 11%|█         | 4302/40000 [16:15<2:23:59,  4.13it/s]

Epoch 4300, Total loss 536.188232, iteration time 0.245511


 11%|█         | 4402/40000 [16:38<2:16:27,  4.35it/s]

Epoch 4400, Total loss 7548.442383, iteration time 0.235291


 11%|█▏        | 4502/40000 [17:01<2:21:40,  4.18it/s]

Epoch 4500, Total loss 1025.094238, iteration time 0.246990


 12%|█▏        | 4602/40000 [17:25<2:21:39,  4.16it/s]

Epoch 4600, Total loss 726.453369, iteration time 0.247771


 12%|█▏        | 4702/40000 [17:48<2:20:13,  4.20it/s]

Epoch 4700, Total loss 863.457458, iteration time 0.243232


 12%|█▏        | 4802/40000 [18:11<2:18:44,  4.23it/s]

Epoch 4800, Total loss 617.711792, iteration time 0.240338


 12%|█▏        | 4902/40000 [18:35<2:20:20,  4.17it/s]

Epoch 4900, Total loss 699.970154, iteration time 0.248942


 13%|█▎        | 5002/40000 [18:59<3:21:00,  2.90it/s]

Epoch 5000, Total loss 596.586121, iteration time 0.095725


 13%|█▎        | 5102/40000 [19:22<2:18:38,  4.20it/s]

Epoch 5100, Total loss 411.359253, iteration time 0.247063


 13%|█▎        | 5202/40000 [19:45<2:17:07,  4.23it/s]

Epoch 5200, Total loss 1019.944397, iteration time 0.239650


 13%|█▎        | 5302/40000 [20:09<2:21:30,  4.09it/s]

Epoch 5300, Total loss 825.198914, iteration time 0.266782


 14%|█▎        | 5402/40000 [20:32<2:17:37,  4.19it/s]

Epoch 5400, Total loss 760.449402, iteration time 0.246534


 14%|█▍        | 5502/40000 [20:56<2:17:14,  4.19it/s]

Epoch 5500, Total loss 628.372681, iteration time 0.244378


 14%|█▍        | 5602/40000 [21:19<2:19:37,  4.11it/s]

Epoch 5600, Total loss 369.514496, iteration time 0.249926


 14%|█▍        | 5702/40000 [21:42<2:17:24,  4.16it/s]

Epoch 5700, Total loss 567.088623, iteration time 0.244318


 15%|█▍        | 5802/40000 [22:06<2:16:29,  4.18it/s]

Epoch 5800, Total loss 1089.916748, iteration time 0.248149


 15%|█▍        | 5902/40000 [22:29<2:17:25,  4.14it/s]

Epoch 5900, Total loss 780.965027, iteration time 0.258882


 15%|█▌        | 6002/40000 [22:53<3:23:48,  2.78it/s]

Epoch 6000, Total loss 1108.689453, iteration time 0.106810


 15%|█▌        | 6102/40000 [23:17<2:13:44,  4.22it/s]

Epoch 6100, Total loss 374.050659, iteration time 0.243344


 16%|█▌        | 6202/40000 [23:40<2:14:32,  4.19it/s]

Epoch 6200, Total loss 369.966187, iteration time 0.245828


 16%|█▌        | 6302/40000 [24:03<2:14:43,  4.17it/s]

Epoch 6300, Total loss 282.901855, iteration time 0.257543


 16%|█▌        | 6402/40000 [24:27<2:11:15,  4.27it/s]

Epoch 6400, Total loss 3660.099365, iteration time 0.242975


 16%|█▋        | 6502/40000 [24:50<2:13:12,  4.19it/s]

Epoch 6500, Total loss 613.898315, iteration time 0.249559


 17%|█▋        | 6602/40000 [25:13<2:13:13,  4.18it/s]

Epoch 6600, Total loss 464.426758, iteration time 0.246235


 17%|█▋        | 6702/40000 [25:37<2:12:22,  4.19it/s]

Epoch 6700, Total loss 431.247742, iteration time 0.241383


 17%|█▋        | 6802/40000 [26:00<2:12:13,  4.18it/s]

Epoch 6800, Total loss 292.845734, iteration time 0.246955


 17%|█▋        | 6902/40000 [26:24<2:12:01,  4.18it/s]

Epoch 6900, Total loss 1341.505615, iteration time 0.244329


 18%|█▊        | 7002/40000 [26:48<3:31:58,  2.59it/s]

Epoch 7000, Total loss 905.884155, iteration time 0.100921


 18%|█▊        | 7102/40000 [27:11<2:09:34,  4.23it/s]

Epoch 7100, Total loss 561.234619, iteration time 0.238796


 18%|█▊        | 7202/40000 [27:34<2:11:25,  4.16it/s]

Epoch 7200, Total loss 992.368713, iteration time 0.243438


 18%|█▊        | 7302/40000 [27:58<2:11:28,  4.15it/s]

Epoch 7300, Total loss 341.541626, iteration time 0.242078


 19%|█▊        | 7402/40000 [28:21<2:09:44,  4.19it/s]

Epoch 7400, Total loss 525.022461, iteration time 0.243796


 19%|█▉        | 7502/40000 [28:45<2:09:24,  4.19it/s]

Epoch 7500, Total loss 355.985168, iteration time 0.240278


 19%|█▉        | 7602/40000 [29:08<2:08:07,  4.21it/s]

Epoch 7600, Total loss 345.864075, iteration time 0.242305


 19%|█▉        | 7702/40000 [29:31<2:11:29,  4.09it/s]

Epoch 7700, Total loss 665.942505, iteration time 0.252546


 20%|█▉        | 7802/40000 [29:55<2:07:12,  4.22it/s]

Epoch 7800, Total loss 1063.120972, iteration time 0.240523


 20%|█▉        | 7902/40000 [30:18<2:07:29,  4.20it/s]

Epoch 7900, Total loss 400.617737, iteration time 0.245160


 20%|██        | 8002/40000 [30:42<3:03:50,  2.90it/s]

Epoch 8000, Total loss 694.469055, iteration time 0.096003


 20%|██        | 8102/40000 [31:05<2:07:41,  4.16it/s]

Epoch 8100, Total loss 788.925537, iteration time 0.245875


 21%|██        | 8202/40000 [31:29<2:06:27,  4.19it/s]

Epoch 8200, Total loss 848.739990, iteration time 0.249058


 21%|██        | 8302/40000 [31:52<2:04:51,  4.23it/s]

Epoch 8300, Total loss 721.820740, iteration time 0.238730


 21%|██        | 8402/40000 [32:15<2:05:22,  4.20it/s]

Epoch 8400, Total loss 820.735596, iteration time 0.251357


 21%|██▏       | 8502/40000 [32:39<2:05:55,  4.17it/s]

Epoch 8500, Total loss 530.088135, iteration time 0.244434


 22%|██▏       | 8602/40000 [33:02<2:03:49,  4.23it/s]

Epoch 8600, Total loss 823.406372, iteration time 0.242123


 22%|██▏       | 8702/40000 [33:26<2:04:21,  4.19it/s]

Epoch 8700, Total loss 703.479736, iteration time 0.247056


 22%|██▏       | 8802/40000 [33:49<2:05:42,  4.14it/s]

Epoch 8800, Total loss 612.226562, iteration time 0.240869


 22%|██▏       | 8902/40000 [34:12<2:02:41,  4.22it/s]

Epoch 8900, Total loss 714.815125, iteration time 0.250800


 23%|██▎       | 9002/40000 [34:36<2:58:56,  2.89it/s]

Epoch 9000, Total loss 757.507202, iteration time 0.093251


 23%|██▎       | 9102/40000 [35:00<2:02:53,  4.19it/s]

Epoch 9100, Total loss 693.531372, iteration time 0.242183


 23%|██▎       | 9202/40000 [35:23<2:02:51,  4.18it/s]

Epoch 9200, Total loss 582.004089, iteration time 0.243770


 23%|██▎       | 9302/40000 [35:47<2:01:35,  4.21it/s]

Epoch 9300, Total loss 511.233643, iteration time 0.243709


 24%|██▎       | 9402/40000 [36:10<2:02:40,  4.16it/s]

Epoch 9400, Total loss 612.090454, iteration time 0.246358


 24%|██▍       | 9502/40000 [36:33<2:03:06,  4.13it/s]

Epoch 9500, Total loss 596.433044, iteration time 0.249171


 24%|██▍       | 9602/40000 [36:57<2:01:55,  4.16it/s]

Epoch 9600, Total loss 615.186829, iteration time 0.244467


 24%|██▍       | 9702/40000 [37:20<2:00:48,  4.18it/s]

Epoch 9700, Total loss 626.716309, iteration time 0.241450


 25%|██▍       | 9802/40000 [37:44<1:59:49,  4.20it/s]

Epoch 9800, Total loss 518.137207, iteration time 0.243076


 25%|██▍       | 9902/40000 [38:07<2:00:14,  4.17it/s]

Epoch 9900, Total loss 558.393555, iteration time 0.249344


 25%|██▌       | 10002/40000 [38:31<2:51:41,  2.91it/s]

Epoch 10000, Total loss 547.971558, iteration time 0.092049


 25%|██▌       | 10102/40000 [38:54<1:57:44,  4.23it/s]

Epoch 10100, Total loss 693.929749, iteration time 0.241414


 26%|██▌       | 10202/40000 [39:18<1:59:35,  4.15it/s]

Epoch 10200, Total loss 575.881592, iteration time 0.243731


 26%|██▌       | 10302/40000 [39:41<1:58:28,  4.18it/s]

Epoch 10300, Total loss 565.439575, iteration time 0.245344


 26%|██▌       | 10402/40000 [40:04<1:56:42,  4.23it/s]

Epoch 10400, Total loss 554.311951, iteration time 0.241686


 26%|██▋       | 10502/40000 [40:28<1:56:26,  4.22it/s]

Epoch 10500, Total loss 558.100464, iteration time 0.239206


 27%|██▋       | 10602/40000 [40:51<1:57:29,  4.17it/s]

Epoch 10600, Total loss 544.312866, iteration time 0.241156


 27%|██▋       | 10702/40000 [41:14<1:56:10,  4.20it/s]

Epoch 10700, Total loss 523.103027, iteration time 0.244684


 27%|██▋       | 10802/40000 [41:38<1:55:26,  4.22it/s]

Epoch 10800, Total loss 598.562683, iteration time 0.239030


 27%|██▋       | 10902/40000 [42:01<1:55:50,  4.19it/s]

Epoch 10900, Total loss 542.726868, iteration time 0.244166


 28%|██▊       | 11002/40000 [42:26<4:03:33,  1.98it/s]

Epoch 11000, Total loss 564.881836, iteration time 0.104460


 28%|██▊       | 11102/40000 [42:49<1:53:59,  4.23it/s]

Epoch 11100, Total loss 498.631134, iteration time 0.239745


 28%|██▊       | 11202/40000 [43:13<1:54:55,  4.18it/s]

Epoch 11200, Total loss 592.672119, iteration time 0.248423


 28%|██▊       | 11302/40000 [43:36<1:54:28,  4.18it/s]

Epoch 11300, Total loss 654.904053, iteration time 0.235932


 29%|██▊       | 11402/40000 [44:00<1:54:53,  4.15it/s]

Epoch 11400, Total loss 550.705688, iteration time 0.245918


 29%|██▉       | 11502/40000 [44:23<1:52:27,  4.22it/s]

Epoch 11500, Total loss 626.817444, iteration time 0.240731


 29%|██▉       | 11602/40000 [44:47<1:53:55,  4.15it/s]

Epoch 11600, Total loss 528.440186, iteration time 0.247972


 29%|██▉       | 11702/40000 [45:10<1:52:50,  4.18it/s]

Epoch 11700, Total loss 654.898926, iteration time 0.246135


 30%|██▉       | 11802/40000 [45:33<1:51:00,  4.23it/s]

Epoch 11800, Total loss 575.802368, iteration time 0.243019


 30%|██▉       | 11902/40000 [45:57<1:50:34,  4.24it/s]

Epoch 11900, Total loss 582.583313, iteration time 0.238122


 30%|███       | 12002/40000 [46:21<2:45:37,  2.82it/s]

Epoch 12000, Total loss 485.281769, iteration time 0.093069


 30%|███       | 12102/40000 [46:44<1:52:17,  4.14it/s]

Epoch 12100, Total loss 468.600311, iteration time 0.247799


 31%|███       | 12202/40000 [47:07<1:49:38,  4.23it/s]

Epoch 12200, Total loss 597.945435, iteration time 0.239515


 31%|███       | 12302/40000 [47:31<1:49:39,  4.21it/s]

Epoch 12300, Total loss 217.601273, iteration time 0.240728


 31%|███       | 12402/40000 [47:54<1:49:47,  4.19it/s]

Epoch 12400, Total loss 208.685196, iteration time 0.246101


 31%|███▏      | 12502/40000 [48:18<1:49:30,  4.19it/s]

Epoch 12500, Total loss 246.673248, iteration time 0.246906


 32%|███▏      | 12602/40000 [48:41<1:50:17,  4.14it/s]

Epoch 12600, Total loss 392.238037, iteration time 0.248466


 32%|███▏      | 12702/40000 [49:04<1:48:51,  4.18it/s]

Epoch 12700, Total loss 251.863373, iteration time 0.242762


 32%|███▏      | 12802/40000 [49:28<1:47:55,  4.20it/s]

Epoch 12800, Total loss 2782.437988, iteration time 0.241900


 32%|███▏      | 12902/40000 [49:51<1:49:05,  4.14it/s]

Epoch 12900, Total loss 278.657349, iteration time 0.252502


 33%|███▎      | 13002/40000 [50:15<2:35:49,  2.89it/s]

Epoch 13000, Total loss 235.580383, iteration time 0.094024


 33%|███▎      | 13102/40000 [50:39<1:46:56,  4.19it/s]

Epoch 13100, Total loss 896.579468, iteration time 0.244854


 33%|███▎      | 13202/40000 [51:02<1:47:58,  4.14it/s]

Epoch 13200, Total loss 632.976318, iteration time 0.249270


 33%|███▎      | 13302/40000 [51:25<1:47:24,  4.14it/s]

Epoch 13300, Total loss 600.548462, iteration time 0.243074


 34%|███▎      | 13402/40000 [51:49<1:46:16,  4.17it/s]

Epoch 13400, Total loss 658.841370, iteration time 0.245179


 34%|███▍      | 13502/40000 [52:12<1:45:42,  4.18it/s]

Epoch 13500, Total loss 644.577393, iteration time 0.242806


 34%|███▍      | 13602/40000 [52:36<1:45:17,  4.18it/s]

Epoch 13600, Total loss 621.312256, iteration time 0.252243


 34%|███▍      | 13702/40000 [52:59<1:44:25,  4.20it/s]

Epoch 13700, Total loss 617.646362, iteration time 0.244962


 35%|███▍      | 13802/40000 [53:23<1:43:38,  4.21it/s]

Epoch 13800, Total loss 614.697144, iteration time 0.240588


 35%|███▍      | 13902/40000 [53:46<1:45:10,  4.14it/s]

Epoch 13900, Total loss 596.563599, iteration time 0.253218


 35%|███▌      | 14002/40000 [54:10<2:44:04,  2.64it/s]

Epoch 14000, Total loss 615.174683, iteration time 0.093306


 35%|███▌      | 14102/40000 [54:34<1:44:21,  4.14it/s]

Epoch 14100, Total loss 484.917267, iteration time 0.243459


 36%|███▌      | 14202/40000 [54:57<1:42:46,  4.18it/s]

Epoch 14200, Total loss 648.265625, iteration time 0.244787


 36%|███▌      | 14302/40000 [55:20<1:41:49,  4.21it/s]

Epoch 14300, Total loss 563.394165, iteration time 0.239789


 36%|███▌      | 14402/40000 [55:44<1:42:45,  4.15it/s]

Epoch 14400, Total loss 571.571228, iteration time 0.242887


 36%|███▋      | 14502/40000 [56:07<1:41:55,  4.17it/s]

Epoch 14500, Total loss 578.693848, iteration time 0.241483


 37%|███▋      | 14602/40000 [56:31<1:41:33,  4.17it/s]

Epoch 14600, Total loss 651.547974, iteration time 0.245599


 37%|███▋      | 14702/40000 [56:54<1:40:51,  4.18it/s]

Epoch 14700, Total loss 582.900879, iteration time 0.251040


 37%|███▋      | 14802/40000 [57:18<1:41:08,  4.15it/s]

Epoch 14800, Total loss 558.496948, iteration time 0.240441


 37%|███▋      | 14902/40000 [57:41<1:39:51,  4.19it/s]

Epoch 14900, Total loss 558.425415, iteration time 0.246728


 38%|███▊      | 15002/40000 [58:05<2:26:41,  2.84it/s]

Epoch 15000, Total loss 529.649170, iteration time 0.095997


 38%|███▊      | 15102/40000 [58:28<1:40:08,  4.14it/s]

Epoch 15100, Total loss 606.969849, iteration time 0.250762


 38%|███▊      | 15202/40000 [58:52<1:39:09,  4.17it/s]

Epoch 15200, Total loss 598.981995, iteration time 0.240863


 38%|███▊      | 15302/40000 [59:15<1:39:22,  4.14it/s]

Epoch 15300, Total loss 595.150146, iteration time 0.246732


 39%|███▊      | 15402/40000 [59:39<1:38:06,  4.18it/s]

Epoch 15400, Total loss 688.350525, iteration time 0.243413


 39%|███▉      | 15502/40000 [1:00:02<1:37:58,  4.17it/s]

Epoch 15500, Total loss 597.613159, iteration time 0.243446


 39%|███▉      | 15602/40000 [1:00:26<1:38:44,  4.12it/s]

Epoch 15600, Total loss 557.971008, iteration time 0.255164


 39%|███▉      | 15702/40000 [1:00:49<1:36:10,  4.21it/s]

Epoch 15700, Total loss 530.264282, iteration time 0.242475


 40%|███▉      | 15802/40000 [1:01:12<1:36:35,  4.18it/s]

Epoch 15800, Total loss 511.793457, iteration time 0.251156


 40%|███▉      | 15902/40000 [1:01:36<1:35:41,  4.20it/s]

Epoch 15900, Total loss 509.597595, iteration time 0.243642


 40%|████      | 16002/40000 [1:02:00<2:34:24,  2.59it/s]

Epoch 16000, Total loss 549.928101, iteration time 0.092887


 40%|████      | 16102/40000 [1:02:24<1:34:58,  4.19it/s]

Epoch 16100, Total loss 636.059631, iteration time 0.252554


 41%|████      | 16202/40000 [1:02:47<1:35:07,  4.17it/s]

Epoch 16200, Total loss 607.866211, iteration time 0.242193


 41%|████      | 16302/40000 [1:03:10<1:33:35,  4.22it/s]

Epoch 16300, Total loss 629.630371, iteration time 0.239242


 41%|████      | 16402/40000 [1:03:34<1:34:58,  4.14it/s]

Epoch 16400, Total loss 639.408203, iteration time 0.251069


 41%|████▏     | 16502/40000 [1:03:57<1:35:09,  4.12it/s]

Epoch 16500, Total loss 495.090515, iteration time 0.240506


 42%|████▏     | 16602/40000 [1:04:21<1:33:15,  4.18it/s]

Epoch 16600, Total loss 581.882935, iteration time 0.245788


 42%|████▏     | 16702/40000 [1:04:44<1:31:58,  4.22it/s]

Epoch 16700, Total loss 580.559143, iteration time 0.238554


 42%|████▏     | 16802/40000 [1:05:08<1:32:14,  4.19it/s]

Epoch 16800, Total loss 592.962708, iteration time 0.251080


 42%|████▏     | 16902/40000 [1:05:31<1:32:58,  4.14it/s]

Epoch 16900, Total loss 523.750916, iteration time 0.242577


 43%|████▎     | 17002/40000 [1:05:55<2:14:30,  2.85it/s]

Epoch 17000, Total loss 510.405884, iteration time 0.092773


 43%|████▎     | 17102/40000 [1:06:19<1:31:08,  4.19it/s]

Epoch 17100, Total loss 666.712036, iteration time 0.243933


 43%|████▎     | 17202/40000 [1:06:42<1:31:05,  4.17it/s]

Epoch 17200, Total loss 559.844604, iteration time 0.233110


 43%|████▎     | 17302/40000 [1:07:06<1:30:12,  4.19it/s]

Epoch 17300, Total loss 497.534607, iteration time 0.252139


 44%|████▎     | 17402/40000 [1:07:29<1:30:20,  4.17it/s]

Epoch 17400, Total loss 548.133057, iteration time 0.243289


 44%|████▍     | 17502/40000 [1:07:53<1:28:40,  4.23it/s]

Epoch 17500, Total loss 588.786255, iteration time 0.240018


 44%|████▍     | 17602/40000 [1:08:16<1:29:09,  4.19it/s]

Epoch 17600, Total loss 602.265808, iteration time 0.237186


 44%|████▍     | 17702/40000 [1:08:40<1:30:32,  4.10it/s]

Epoch 17700, Total loss 589.891602, iteration time 0.250956


 45%|████▍     | 17802/40000 [1:09:03<1:28:01,  4.20it/s]

Epoch 17800, Total loss 579.781128, iteration time 0.240164


 45%|████▍     | 17902/40000 [1:09:27<1:28:22,  4.17it/s]

Epoch 17900, Total loss 814.264282, iteration time 0.244857


 45%|████▌     | 18002/40000 [1:09:51<2:16:30,  2.69it/s]

Epoch 18000, Total loss 197.457199, iteration time 0.093436


 45%|████▌     | 18102/40000 [1:10:14<1:29:06,  4.10it/s]

Epoch 18100, Total loss 205.766571, iteration time 0.249758


 46%|████▌     | 18202/40000 [1:10:38<1:26:50,  4.18it/s]

Epoch 18200, Total loss 4290.125000, iteration time 0.242817


 46%|████▌     | 18302/40000 [1:11:01<1:26:36,  4.18it/s]

Epoch 18300, Total loss 316.878357, iteration time 0.243290


 46%|████▌     | 18402/40000 [1:11:25<1:26:07,  4.18it/s]

Epoch 18400, Total loss 214.454987, iteration time 0.243569


 46%|████▋     | 18502/40000 [1:11:48<1:25:52,  4.17it/s]

Epoch 18500, Total loss 682.499634, iteration time 0.245731


 47%|████▋     | 18602/40000 [1:12:12<1:26:03,  4.14it/s]

Epoch 18600, Total loss 518.121399, iteration time 0.244192


 47%|████▋     | 18702/40000 [1:12:35<1:24:55,  4.18it/s]

Epoch 18700, Total loss 529.423828, iteration time 0.244456


 47%|████▋     | 18802/40000 [1:12:59<1:25:16,  4.14it/s]

Epoch 18800, Total loss 524.152832, iteration time 0.246278


 47%|████▋     | 18902/40000 [1:13:22<1:23:49,  4.19it/s]

Epoch 18900, Total loss 503.304382, iteration time 0.242436


 48%|████▊     | 19002/40000 [1:13:46<2:02:39,  2.85it/s]

Epoch 19000, Total loss 582.576538, iteration time 0.096407


 48%|████▊     | 19102/40000 [1:14:10<1:23:28,  4.17it/s]

Epoch 19100, Total loss 538.259399, iteration time 0.244365


 48%|████▊     | 19202/40000 [1:14:33<1:22:58,  4.18it/s]

Epoch 19200, Total loss 570.691956, iteration time 0.247544


 48%|████▊     | 19302/40000 [1:14:56<1:22:16,  4.19it/s]

Epoch 19300, Total loss 493.375580, iteration time 0.249035


 49%|████▊     | 19402/40000 [1:15:20<1:22:17,  4.17it/s]

Epoch 19400, Total loss 556.032043, iteration time 0.243623


 49%|████▉     | 19502/40000 [1:15:43<1:21:38,  4.18it/s]

Epoch 19500, Total loss 528.700378, iteration time 0.243789


 49%|████▉     | 19602/40000 [1:16:07<1:21:28,  4.17it/s]

Epoch 19600, Total loss 517.594055, iteration time 0.238095


 49%|████▉     | 19702/40000 [1:16:30<1:21:02,  4.17it/s]

Epoch 19700, Total loss 671.692993, iteration time 0.250655


 50%|████▉     | 19802/40000 [1:16:54<1:21:03,  4.15it/s]

Epoch 19800, Total loss 177.630524, iteration time 0.252547


 50%|████▉     | 19902/40000 [1:17:17<1:20:02,  4.18it/s]

Epoch 19900, Total loss 198.791870, iteration time 0.244203


 50%|█████     | 20002/40000 [1:17:41<2:14:07,  2.48it/s]

Epoch 20000, Total loss 174.376282, iteration time 0.092865


 50%|█████     | 20102/40000 [1:18:05<1:19:53,  4.15it/s]

Epoch 20100, Total loss 400.465912, iteration time 0.257162


 51%|█████     | 20202/40000 [1:18:28<1:19:12,  4.17it/s]

Epoch 20200, Total loss 605.471985, iteration time 0.244887


 51%|█████     | 20302/40000 [1:18:52<1:18:17,  4.19it/s]

Epoch 20300, Total loss 551.810364, iteration time 0.243040


 51%|█████     | 20402/40000 [1:19:15<1:18:44,  4.15it/s]

Epoch 20400, Total loss 524.939575, iteration time 0.252326


 51%|█████▏    | 20502/40000 [1:19:39<1:16:49,  4.23it/s]

Epoch 20500, Total loss 544.621704, iteration time 0.237328


 52%|█████▏    | 20602/40000 [1:20:02<1:17:42,  4.16it/s]

Epoch 20600, Total loss 548.968628, iteration time 0.250146


 52%|█████▏    | 20702/40000 [1:20:26<1:17:19,  4.16it/s]

Epoch 20700, Total loss 544.256531, iteration time 0.244438


 52%|█████▏    | 20802/40000 [1:20:49<1:16:39,  4.17it/s]

Epoch 20800, Total loss 500.895996, iteration time 0.247863


 52%|█████▏    | 20902/40000 [1:21:13<1:16:36,  4.15it/s]

Epoch 20900, Total loss 561.456970, iteration time 0.248109


 53%|█████▎    | 21002/40000 [1:21:37<1:51:37,  2.84it/s]

Epoch 21000, Total loss 511.203003, iteration time 0.095021


 53%|█████▎    | 21102/40000 [1:22:00<1:15:03,  4.20it/s]

Epoch 21100, Total loss 518.528442, iteration time 0.248781


 53%|█████▎    | 21202/40000 [1:22:24<1:15:43,  4.14it/s]

Epoch 21200, Total loss 537.436890, iteration time 0.246232


 53%|█████▎    | 21302/40000 [1:22:47<1:14:03,  4.21it/s]

Epoch 21300, Total loss 610.180603, iteration time 0.236055


 54%|█████▎    | 21402/40000 [1:23:11<1:13:56,  4.19it/s]

Epoch 21400, Total loss 484.803558, iteration time 0.246293


 54%|█████▍    | 21502/40000 [1:23:34<1:14:35,  4.13it/s]

Epoch 21500, Total loss 577.321777, iteration time 0.259761


 54%|█████▍    | 21602/40000 [1:23:58<1:14:29,  4.12it/s]

Epoch 21600, Total loss 501.328918, iteration time 0.246308


 54%|█████▍    | 21702/40000 [1:24:21<1:13:06,  4.17it/s]

Epoch 21700, Total loss 492.526581, iteration time 0.243681


 55%|█████▍    | 21802/40000 [1:24:44<1:12:28,  4.18it/s]

Epoch 21800, Total loss 571.573669, iteration time 0.241535


 55%|█████▍    | 21902/40000 [1:25:08<1:12:42,  4.15it/s]

Epoch 21900, Total loss 1325.765991, iteration time 0.245919


 55%|█████▌    | 22002/40000 [1:25:32<1:47:48,  2.78it/s]

Epoch 22000, Total loss 178.725586, iteration time 0.101687


 55%|█████▌    | 22102/40000 [1:25:55<1:11:01,  4.20it/s]

Epoch 22100, Total loss 241.591187, iteration time 0.243425


 56%|█████▌    | 22202/40000 [1:26:19<1:11:27,  4.15it/s]

Epoch 22200, Total loss 438.474060, iteration time 0.247274


 56%|█████▌    | 22302/40000 [1:26:42<1:10:42,  4.17it/s]

Epoch 22300, Total loss 450.188171, iteration time 0.248680


 56%|█████▌    | 22402/40000 [1:27:06<1:11:03,  4.13it/s]

Epoch 22400, Total loss 305.404083, iteration time 0.249388


 56%|█████▋    | 22502/40000 [1:27:29<1:09:29,  4.20it/s]

Epoch 22500, Total loss 251.711121, iteration time 0.239964


 57%|█████▋    | 22602/40000 [1:27:53<1:08:46,  4.22it/s]

Epoch 22600, Total loss 507.790771, iteration time 0.237572


 57%|█████▋    | 22702/40000 [1:28:16<1:08:54,  4.18it/s]

Epoch 22700, Total loss 568.646606, iteration time 0.236361


 57%|█████▋    | 22802/40000 [1:28:39<1:09:04,  4.15it/s]

Epoch 22800, Total loss 561.156128, iteration time 0.252100


 57%|█████▋    | 22902/40000 [1:29:03<1:08:20,  4.17it/s]

Epoch 22900, Total loss 538.609375, iteration time 0.248028


 58%|█████▊    | 23002/40000 [1:29:27<1:39:25,  2.85it/s]

Epoch 23000, Total loss 555.405396, iteration time 0.093079


 58%|█████▊    | 23102/40000 [1:29:50<1:07:29,  4.17it/s]

Epoch 23100, Total loss 548.420288, iteration time 0.240757


 58%|█████▊    | 23202/40000 [1:30:14<1:07:41,  4.14it/s]

Epoch 23200, Total loss 544.267212, iteration time 0.249819


 58%|█████▊    | 23302/40000 [1:30:37<1:06:44,  4.17it/s]

Epoch 23300, Total loss 539.702332, iteration time 0.251601


 59%|█████▊    | 23402/40000 [1:31:01<1:06:01,  4.19it/s]

Epoch 23400, Total loss 544.488159, iteration time 0.242916


 59%|█████▉    | 23502/40000 [1:31:24<1:05:16,  4.21it/s]

Epoch 23500, Total loss 519.515625, iteration time 0.233167


 59%|█████▉    | 23602/40000 [1:31:48<1:06:07,  4.13it/s]

Epoch 23600, Total loss 518.490112, iteration time 0.250763


 59%|█████▉    | 23702/40000 [1:32:11<1:05:08,  4.17it/s]

Epoch 23700, Total loss 523.599731, iteration time 0.245186


 60%|█████▉    | 23802/40000 [1:32:34<1:05:01,  4.15it/s]

Epoch 23800, Total loss 466.543091, iteration time 0.246413


 60%|█████▉    | 23902/40000 [1:32:58<1:04:33,  4.16it/s]

Epoch 23900, Total loss 500.963562, iteration time 0.243737


 60%|██████    | 24002/40000 [1:33:22<1:49:09,  2.44it/s]

Epoch 24000, Total loss 672.492432, iteration time 0.101024


 60%|██████    | 24102/40000 [1:33:46<1:03:09,  4.20it/s]

Epoch 24100, Total loss 484.507904, iteration time 0.245170


 61%|██████    | 24202/40000 [1:34:09<1:03:03,  4.18it/s]

Epoch 24200, Total loss 461.438538, iteration time 0.237658


 61%|██████    | 24302/40000 [1:34:32<1:02:47,  4.17it/s]

Epoch 24300, Total loss 553.023315, iteration time 0.254263


 61%|██████    | 24402/40000 [1:34:56<1:02:48,  4.14it/s]

Epoch 24400, Total loss 560.203125, iteration time 0.248797


 61%|██████▏   | 24502/40000 [1:35:19<1:01:42,  4.19it/s]

Epoch 24500, Total loss 488.477966, iteration time 0.242945


 62%|██████▏   | 24602/40000 [1:35:43<1:00:46,  4.22it/s]

Epoch 24600, Total loss 821.957214, iteration time 0.239082


 62%|██████▏   | 24702/40000 [1:36:06<1:00:53,  4.19it/s]

Epoch 24700, Total loss 246.042389, iteration time 0.243340


 62%|██████▏   | 24802/40000 [1:36:30<1:00:56,  4.16it/s]

Epoch 24800, Total loss 332.321259, iteration time 0.248830


 62%|██████▏   | 24902/40000 [1:36:53<59:32,  4.23it/s]  

Epoch 24900, Total loss 1464.121216, iteration time 0.242175


 63%|██████▎   | 25002/40000 [1:37:17<1:27:23,  2.86it/s]

Epoch 25000, Total loss 408.221375, iteration time 0.091961


 63%|██████▎   | 25102/40000 [1:37:41<59:01,  4.21it/s]

Epoch 25100, Total loss 513.859985, iteration time 0.240092


 63%|██████▎   | 25202/40000 [1:38:04<59:50,  4.12it/s]

Epoch 25200, Total loss 484.859131, iteration time 0.245244


 63%|██████▎   | 25302/40000 [1:38:27<58:26,  4.19it/s]

Epoch 25300, Total loss 482.922668, iteration time 0.244726


 64%|██████▎   | 25402/40000 [1:38:51<58:43,  4.14it/s]

Epoch 25400, Total loss 509.076660, iteration time 0.243801


 64%|██████▍   | 25502/40000 [1:39:14<57:26,  4.21it/s]

Epoch 25500, Total loss 485.373352, iteration time 0.242460


 64%|██████▍   | 25602/40000 [1:39:38<58:01,  4.14it/s]

Epoch 25600, Total loss 534.294800, iteration time 0.245763


 64%|██████▍   | 25702/40000 [1:40:01<56:52,  4.19it/s]

Epoch 25700, Total loss 467.956299, iteration time 0.244859


 65%|██████▍   | 25802/40000 [1:40:25<56:37,  4.18it/s]

Epoch 25800, Total loss 483.986511, iteration time 0.247030


 65%|██████▍   | 25902/40000 [1:40:48<56:07,  4.19it/s]

Epoch 25900, Total loss 545.933228, iteration time 0.243875


 65%|██████▌   | 26002/40000 [1:41:12<1:35:42,  2.44it/s]

Epoch 26000, Total loss 474.673706, iteration time 0.102212


 65%|██████▌   | 26102/40000 [1:41:36<56:11,  4.12it/s]

Epoch 26100, Total loss 426.317749, iteration time 0.251546


 66%|██████▌   | 26202/40000 [1:41:59<55:03,  4.18it/s]

Epoch 26200, Total loss 568.379395, iteration time 0.242995


 66%|██████▌   | 26302/40000 [1:42:23<54:36,  4.18it/s]

Epoch 26300, Total loss 490.181213, iteration time 0.244258


 66%|██████▌   | 26402/40000 [1:42:46<54:32,  4.15it/s]

Epoch 26400, Total loss 515.134155, iteration time 0.247297


 66%|██████▋   | 26502/40000 [1:43:09<53:30,  4.20it/s]

Epoch 26500, Total loss 465.459320, iteration time 0.243319


 67%|██████▋   | 26602/40000 [1:43:33<52:47,  4.23it/s]

Epoch 26600, Total loss 519.951416, iteration time 0.240226


 67%|██████▋   | 26702/40000 [1:43:56<53:03,  4.18it/s]

Epoch 26700, Total loss 429.879822, iteration time 0.241453


 67%|██████▋   | 26802/40000 [1:44:20<53:10,  4.14it/s]

Epoch 26800, Total loss 562.453003, iteration time 0.245807


 67%|██████▋   | 26902/40000 [1:44:43<52:18,  4.17it/s]

Epoch 26900, Total loss 492.486145, iteration time 0.239920


 68%|██████▊   | 27002/40000 [1:45:07<1:16:08,  2.85it/s]

Epoch 27000, Total loss 481.634094, iteration time 0.092960


 68%|██████▊   | 27102/40000 [1:45:31<51:42,  4.16it/s]

Epoch 27100, Total loss 541.981506, iteration time 0.246311


 68%|██████▊   | 27202/40000 [1:45:54<51:53,  4.11it/s]

Epoch 27200, Total loss 187.656738, iteration time 0.244402


 68%|██████▊   | 27302/40000 [1:46:17<50:36,  4.18it/s]

Epoch 27300, Total loss 304.222839, iteration time 0.242687


 69%|██████▊   | 27402/40000 [1:46:41<49:57,  4.20it/s]

Epoch 27400, Total loss 658.416504, iteration time 0.240838


 69%|██████▉   | 27502/40000 [1:47:04<49:54,  4.17it/s]

Epoch 27500, Total loss 347.028198, iteration time 0.242465


 69%|██████▉   | 27602/40000 [1:47:28<49:45,  4.15it/s]

Epoch 27600, Total loss 160.058990, iteration time 0.237122


 69%|██████▉   | 27702/40000 [1:47:51<49:12,  4.16it/s]

Epoch 27700, Total loss 485.539917, iteration time 0.244833


 70%|██████▉   | 27802/40000 [1:48:14<48:33,  4.19it/s]

Epoch 27800, Total loss 527.325439, iteration time 0.240606


 70%|██████▉   | 27902/40000 [1:48:38<48:07,  4.19it/s]

Epoch 27900, Total loss 517.143433, iteration time 0.239691


 70%|███████   | 28002/40000 [1:49:02<1:21:57,  2.44it/s]

Epoch 28000, Total loss 507.579437, iteration time 0.103396


 70%|███████   | 28102/40000 [1:49:25<47:09,  4.20it/s]

Epoch 28100, Total loss 508.421814, iteration time 0.243866


 71%|███████   | 28202/40000 [1:49:49<46:53,  4.19it/s]

Epoch 28200, Total loss 495.727570, iteration time 0.236062


 71%|███████   | 28302/40000 [1:50:12<46:14,  4.22it/s]

Epoch 28300, Total loss 484.665924, iteration time 0.238193


 71%|███████   | 28402/40000 [1:50:36<46:56,  4.12it/s]

Epoch 28400, Total loss 526.116150, iteration time 0.247706


 71%|███████▏  | 28502/40000 [1:50:59<45:33,  4.21it/s]

Epoch 28500, Total loss 511.308655, iteration time 0.238250


 72%|███████▏  | 28602/40000 [1:51:23<45:41,  4.16it/s]

Epoch 28600, Total loss 573.583008, iteration time 0.243902


 72%|███████▏  | 28702/40000 [1:51:46<44:45,  4.21it/s]

Epoch 28700, Total loss 504.897461, iteration time 0.242583


 72%|███████▏  | 28802/40000 [1:52:10<45:24,  4.11it/s]

Epoch 28800, Total loss 464.985657, iteration time 0.242134


 72%|███████▏  | 28902/40000 [1:52:33<44:30,  4.16it/s]

Epoch 28900, Total loss 436.941223, iteration time 0.244227


 73%|███████▎  | 29002/40000 [1:52:58<1:23:31,  2.19it/s]

Epoch 29000, Total loss 500.136780, iteration time 0.092869


 73%|███████▎  | 29102/40000 [1:53:21<43:26,  4.18it/s]

Epoch 29100, Total loss 478.707001, iteration time 0.243261


 73%|███████▎  | 29202/40000 [1:53:44<43:02,  4.18it/s]

Epoch 29200, Total loss 477.501953, iteration time 0.256267


 73%|███████▎  | 29302/40000 [1:54:08<42:39,  4.18it/s]

Epoch 29300, Total loss 493.229248, iteration time 0.242642


 74%|███████▎  | 29402/40000 [1:54:31<42:09,  4.19it/s]

Epoch 29400, Total loss 496.976105, iteration time 0.249751


 74%|███████▍  | 29502/40000 [1:54:55<42:00,  4.16it/s]

Epoch 29500, Total loss 496.328217, iteration time 0.244827


 74%|███████▍  | 29602/40000 [1:55:18<42:01,  4.12it/s]

Epoch 29600, Total loss 474.879639, iteration time 0.249975


 74%|███████▍  | 29702/40000 [1:55:41<41:02,  4.18it/s]

Epoch 29700, Total loss 462.889709, iteration time 0.243989


 75%|███████▍  | 29802/40000 [1:56:05<40:35,  4.19it/s]

Epoch 29800, Total loss 438.168274, iteration time 0.244568


 75%|███████▍  | 29902/40000 [1:56:28<40:14,  4.18it/s]

Epoch 29900, Total loss 598.472412, iteration time 0.243218


 75%|███████▌  | 30002/40000 [1:56:53<1:09:01,  2.41it/s]

Epoch 30000, Total loss 451.791504, iteration time 0.099072


 75%|███████▌  | 30102/40000 [1:57:16<39:22,  4.19it/s]

Epoch 30100, Total loss 510.560181, iteration time 0.242938


 76%|███████▌  | 30202/40000 [1:57:39<39:07,  4.17it/s]

Epoch 30200, Total loss 475.483459, iteration time 0.243894


 76%|███████▌  | 30302/40000 [1:58:03<38:39,  4.18it/s]

Epoch 30300, Total loss 534.634338, iteration time 0.243479


 76%|███████▌  | 30402/40000 [1:58:26<38:31,  4.15it/s]

Epoch 30400, Total loss 458.962433, iteration time 0.246177


 76%|███████▋  | 30502/40000 [1:58:50<37:56,  4.17it/s]

Epoch 30500, Total loss 260.322174, iteration time 0.240134


 77%|███████▋  | 30602/40000 [1:59:13<37:25,  4.18it/s]

Epoch 30600, Total loss 492.898193, iteration time 0.245084


 77%|███████▋  | 30702/40000 [1:59:37<37:04,  4.18it/s]

Epoch 30700, Total loss 383.597351, iteration time 0.245837


 77%|███████▋  | 30802/40000 [2:00:00<36:25,  4.21it/s]

Epoch 30800, Total loss 454.555420, iteration time 0.239425


 77%|███████▋  | 30902/40000 [2:00:24<36:14,  4.18it/s]

Epoch 30900, Total loss 380.426636, iteration time 0.244053


 78%|███████▊  | 31002/40000 [2:00:48<53:23,  2.81it/s]  

Epoch 31000, Total loss 431.993927, iteration time 0.092505


 78%|███████▊  | 31102/40000 [2:01:11<35:49,  4.14it/s]

Epoch 31100, Total loss 454.136414, iteration time 0.255805


 78%|███████▊  | 31202/40000 [2:01:35<34:43,  4.22it/s]

Epoch 31200, Total loss 407.521057, iteration time 0.238463


 78%|███████▊  | 31302/40000 [2:01:58<34:31,  4.20it/s]

Epoch 31300, Total loss 411.213379, iteration time 0.244727


 79%|███████▊  | 31402/40000 [2:02:21<34:22,  4.17it/s]

Epoch 31400, Total loss 263.692627, iteration time 0.243173


 79%|███████▉  | 31502/40000 [2:02:45<34:18,  4.13it/s]

Epoch 31500, Total loss 382.521698, iteration time 0.245213


 79%|███████▉  | 31602/40000 [2:03:08<33:11,  4.22it/s]

Epoch 31600, Total loss 474.584900, iteration time 0.240556


 79%|███████▉  | 31702/40000 [2:03:32<33:12,  4.17it/s]

Epoch 31700, Total loss 380.307434, iteration time 0.243022


 80%|███████▉  | 31802/40000 [2:03:55<32:54,  4.15it/s]

Epoch 31800, Total loss 489.128540, iteration time 0.244985


 80%|███████▉  | 31902/40000 [2:04:19<32:48,  4.11it/s]

Epoch 31900, Total loss 477.199371, iteration time 0.245281


 80%|████████  | 32002/40000 [2:04:43<47:00,  2.84it/s]

Epoch 32000, Total loss 278.679657, iteration time 0.093136


 80%|████████  | 32102/40000 [2:05:06<31:40,  4.16it/s]

Epoch 32100, Total loss 226.108902, iteration time 0.247724


 81%|████████  | 32202/40000 [2:05:30<30:48,  4.22it/s]

Epoch 32200, Total loss 224.067017, iteration time 0.238999


 81%|████████  | 32302/40000 [2:05:53<31:20,  4.09it/s]

Epoch 32300, Total loss 308.012146, iteration time 0.253401


 81%|████████  | 32402/40000 [2:06:16<30:06,  4.21it/s]

Epoch 32400, Total loss 1063.679932, iteration time 0.240959


 81%|████████▏ | 32502/40000 [2:06:40<29:47,  4.19it/s]

Epoch 32500, Total loss 302.794312, iteration time 0.252759


 82%|████████▏ | 32602/40000 [2:07:03<29:30,  4.18it/s]

Epoch 32600, Total loss 347.695068, iteration time 0.245585


 82%|████████▏ | 32702/40000 [2:07:27<29:17,  4.15it/s]

Epoch 32700, Total loss 435.773529, iteration time 0.247917


 82%|████████▏ | 32802/40000 [2:07:50<28:48,  4.16it/s]

Epoch 32800, Total loss 437.850494, iteration time 0.245473


 82%|████████▏ | 32902/40000 [2:08:14<27:58,  4.23it/s]

Epoch 32900, Total loss 428.835144, iteration time 0.236330


 83%|████████▎ | 33002/40000 [2:08:38<41:29,  2.81it/s]

Epoch 33000, Total loss 435.668152, iteration time 0.092236


 83%|████████▎ | 33102/40000 [2:09:01<28:04,  4.10it/s]

Epoch 33100, Total loss 421.665527, iteration time 0.255794


 83%|████████▎ | 33202/40000 [2:09:24<27:01,  4.19it/s]

Epoch 33200, Total loss 439.984406, iteration time 0.243960


 83%|████████▎ | 33302/40000 [2:09:48<26:31,  4.21it/s]

Epoch 33300, Total loss 416.827393, iteration time 0.243904


 84%|████████▎ | 33402/40000 [2:10:11<26:18,  4.18it/s]

Epoch 33400, Total loss 442.000397, iteration time 0.243221


 84%|████████▍ | 33502/40000 [2:10:35<26:20,  4.11it/s]

Epoch 33500, Total loss 441.509644, iteration time 0.253180


 84%|████████▍ | 33602/40000 [2:10:58<25:24,  4.20it/s]

Epoch 33600, Total loss 477.274536, iteration time 0.241374


 84%|████████▍ | 33702/40000 [2:11:22<25:01,  4.20it/s]

Epoch 33700, Total loss 426.845337, iteration time 0.244392


 85%|████████▍ | 33802/40000 [2:11:45<24:39,  4.19it/s]

Epoch 33800, Total loss 386.580017, iteration time 0.241664


 85%|████████▍ | 33902/40000 [2:12:08<24:28,  4.15it/s]

Epoch 33900, Total loss 463.196838, iteration time 0.241163


 85%|████████▌ | 34002/40000 [2:12:32<36:00,  2.78it/s]

Epoch 34000, Total loss 425.439941, iteration time 0.093425


 85%|████████▌ | 34102/40000 [2:12:56<23:31,  4.18it/s]

Epoch 34100, Total loss 379.378601, iteration time 0.243897


 86%|████████▌ | 34202/40000 [2:13:19<22:55,  4.22it/s]

Epoch 34200, Total loss 389.064941, iteration time 0.240201


 86%|████████▌ | 34302/40000 [2:13:43<22:52,  4.15it/s]

Epoch 34300, Total loss 359.122589, iteration time 0.246621


 86%|████████▌ | 34402/40000 [2:14:06<22:07,  4.22it/s]

Epoch 34400, Total loss 437.067841, iteration time 0.241765


 86%|████████▋ | 34502/40000 [2:14:29<21:49,  4.20it/s]

Epoch 34500, Total loss 379.861816, iteration time 0.239036


 87%|████████▋ | 34602/40000 [2:14:53<21:33,  4.17it/s]

Epoch 34600, Total loss 365.931702, iteration time 0.243958


 87%|████████▋ | 34702/40000 [2:15:16<21:15,  4.15it/s]

Epoch 34700, Total loss 386.305908, iteration time 0.247406


 87%|████████▋ | 34802/40000 [2:15:40<20:43,  4.18it/s]

Epoch 34800, Total loss 361.080933, iteration time 0.245261


 87%|████████▋ | 34902/40000 [2:16:03<20:08,  4.22it/s]

Epoch 34900, Total loss 377.300293, iteration time 0.240273


 88%|████████▊ | 35002/40000 [2:16:27<30:26,  2.74it/s]

Epoch 35000, Total loss 360.674744, iteration time 0.094873


 88%|████████▊ | 35102/40000 [2:16:50<19:42,  4.14it/s]

Epoch 35100, Total loss 357.089111, iteration time 0.246594


 88%|████████▊ | 35202/40000 [2:17:14<19:02,  4.20it/s]

Epoch 35200, Total loss 341.824219, iteration time 0.239502


 88%|████████▊ | 35302/40000 [2:17:37<18:43,  4.18it/s]

Epoch 35300, Total loss 315.348816, iteration time 0.243734


 89%|████████▊ | 35402/40000 [2:18:01<18:34,  4.13it/s]

Epoch 35400, Total loss 432.351074, iteration time 0.250851


 89%|████████▉ | 35502/40000 [2:18:24<17:50,  4.20it/s]

Epoch 35500, Total loss 406.705322, iteration time 0.237564


 89%|████████▉ | 35602/40000 [2:18:48<17:33,  4.18it/s]

Epoch 35600, Total loss 388.358215, iteration time 0.244992


 89%|████████▉ | 35702/40000 [2:19:11<17:01,  4.21it/s]

Epoch 35700, Total loss 383.387939, iteration time 0.232706


 90%|████████▉ | 35802/40000 [2:19:34<16:49,  4.16it/s]

Epoch 35800, Total loss 449.870728, iteration time 0.244193


 90%|████████▉ | 35902/40000 [2:19:58<16:16,  4.19it/s]

Epoch 35900, Total loss 281.358978, iteration time 0.242185


 90%|█████████ | 36002/40000 [2:20:22<24:16,  2.74it/s]

Epoch 36000, Total loss 536.776001, iteration time 0.090310


 90%|█████████ | 36102/40000 [2:20:45<15:31,  4.19it/s]

Epoch 36100, Total loss 443.017548, iteration time 0.245613


 91%|█████████ | 36202/40000 [2:21:09<15:15,  4.15it/s]

Epoch 36200, Total loss 206.193420, iteration time 0.246871


 91%|█████████ | 36302/40000 [2:21:32<14:45,  4.18it/s]

Epoch 36300, Total loss 643.917480, iteration time 0.252588


 91%|█████████ | 36402/40000 [2:21:55<14:09,  4.23it/s]

Epoch 36400, Total loss 397.563904, iteration time 0.240332


 91%|█████████▏| 36502/40000 [2:22:19<13:55,  4.19it/s]

Epoch 36500, Total loss 411.329956, iteration time 0.254695


 92%|█████████▏| 36602/40000 [2:22:42<13:38,  4.15it/s]

Epoch 36600, Total loss 394.244141, iteration time 0.250904


 92%|█████████▏| 36702/40000 [2:23:06<13:06,  4.19it/s]

Epoch 36700, Total loss 415.426514, iteration time 0.233114


 92%|█████████▏| 36802/40000 [2:23:29<12:38,  4.22it/s]

Epoch 36800, Total loss 412.967743, iteration time 0.241508


 92%|█████████▏| 36902/40000 [2:23:53<12:22,  4.17it/s]

Epoch 36900, Total loss 430.351379, iteration time 0.244242


 93%|█████████▎| 37002/40000 [2:24:17<21:11,  2.36it/s]

Epoch 37000, Total loss 401.406311, iteration time 0.102705


 93%|█████████▎| 37102/40000 [2:24:40<11:32,  4.19it/s]

Epoch 37100, Total loss 383.896118, iteration time 0.243062


 93%|█████████▎| 37202/40000 [2:25:04<11:02,  4.22it/s]

Epoch 37200, Total loss 414.867432, iteration time 0.241920


 93%|█████████▎| 37302/40000 [2:25:27<10:54,  4.12it/s]

Epoch 37300, Total loss 405.148438, iteration time 0.255134


 94%|█████████▎| 37402/40000 [2:25:51<10:16,  4.21it/s]

Epoch 37400, Total loss 421.759186, iteration time 0.237938


 94%|█████████▍| 37502/40000 [2:26:14<10:01,  4.15it/s]

Epoch 37500, Total loss 464.946899, iteration time 0.252167


 94%|█████████▍| 37602/40000 [2:26:37<09:32,  4.19it/s]

Epoch 37600, Total loss 528.738220, iteration time 0.242125


 94%|█████████▍| 37702/40000 [2:27:01<09:13,  4.15it/s]

Epoch 37700, Total loss 201.979141, iteration time 0.244483


 95%|█████████▍| 37802/40000 [2:27:24<08:42,  4.21it/s]

Epoch 37800, Total loss 262.944519, iteration time 0.242427


 95%|█████████▍| 37902/40000 [2:27:48<08:21,  4.18it/s]

Epoch 37900, Total loss 395.559265, iteration time 0.243284


 95%|█████████▌| 38002/40000 [2:28:11<12:00,  2.77it/s]

Epoch 38000, Total loss 521.619873, iteration time 0.092991


 95%|█████████▌| 38102/40000 [2:28:35<07:33,  4.18it/s]

Epoch 38100, Total loss 476.417419, iteration time 0.241493


 96%|█████████▌| 38202/40000 [2:28:58<07:06,  4.22it/s]

Epoch 38200, Total loss 473.034851, iteration time 0.235818


 96%|█████████▌| 38302/40000 [2:29:21<06:43,  4.21it/s]

Epoch 38300, Total loss 456.007141, iteration time 0.238471


 96%|█████████▌| 38402/40000 [2:29:45<06:18,  4.22it/s]

Epoch 38400, Total loss 459.701172, iteration time 0.238943


 96%|█████████▋| 38502/40000 [2:30:08<06:03,  4.13it/s]

Epoch 38500, Total loss 501.497559, iteration time 0.257255


 97%|█████████▋| 38602/40000 [2:30:32<05:34,  4.18it/s]

Epoch 38600, Total loss 455.524628, iteration time 0.241751


 97%|█████████▋| 38702/40000 [2:30:55<05:07,  4.22it/s]

Epoch 38700, Total loss 446.019531, iteration time 0.240095


 97%|█████████▋| 38802/40000 [2:31:18<04:46,  4.18it/s]

Epoch 38800, Total loss 481.140259, iteration time 0.245196


 97%|█████████▋| 38902/40000 [2:31:42<04:24,  4.16it/s]

Epoch 38900, Total loss 416.637024, iteration time 0.241077


 98%|█████████▊| 39002/40000 [2:32:06<05:51,  2.84it/s]

Epoch 39000, Total loss 446.501923, iteration time 0.091053


 98%|█████████▊| 39102/40000 [2:32:29<03:34,  4.19it/s]

Epoch 39100, Total loss 512.541565, iteration time 0.248194


 98%|█████████▊| 39202/40000 [2:32:53<03:11,  4.18it/s]

Epoch 39200, Total loss 219.574158, iteration time 0.243312


 98%|█████████▊| 39302/40000 [2:33:16<02:50,  4.08it/s]

Epoch 39300, Total loss 179.639191, iteration time 0.249124


 99%|█████████▊| 39402/40000 [2:33:39<02:21,  4.22it/s]

Epoch 39400, Total loss 1134.073608, iteration time 0.239046


 99%|█████████▉| 39502/40000 [2:34:03<01:58,  4.19it/s]

Epoch 39500, Total loss 241.002975, iteration time 0.242445


 99%|█████████▉| 39602/40000 [2:34:26<01:35,  4.17it/s]

Epoch 39600, Total loss 303.573975, iteration time 0.244257


 99%|█████████▉| 39702/40000 [2:34:50<01:11,  4.19it/s]

Epoch 39700, Total loss 572.782227, iteration time 0.239148


100%|█████████▉| 39802/40000 [2:35:13<00:47,  4.20it/s]

Epoch 39800, Total loss 514.175415, iteration time 0.242476


100%|█████████▉| 39902/40000 [2:35:37<00:23,  4.19it/s]

Epoch 39900, Total loss 512.697815, iteration time 0.242875


100%|██████████| 40000/40000 [2:36:00<00:00,  4.27it/s]
