### 1. Install Dependencies

In [1]:
# Use the %pip magic so the package lands in the running kernel's environment;
# pinned to the version recorded in this notebook's output for reproducibility.
%pip install -q configargparse==1.7

Collecting configargparse
  Downloading ConfigArgParse-1.7-py3-none-any.whl.metadata (23 kB)
Downloading ConfigArgParse-1.7-py3-none-any.whl (25 kB)
Installing collected packages: configargparse
Successfully installed configargparse-1.7


In [2]:
# Use the %pip magic so the packages land in the running kernel's environment;
# -q keeps the long dependency-download log out of the notebook output.
%pip install -q torch torchvision torchaudio

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

### 2. Imports

In [3]:
import os
import shutil
import time
from datetime import datetime
import pickle
from google.colab import drive

from abc import ABC, abstractmethod
from collections import OrderedDict

import math
import random
import numpy as np

import torch
from torch.utils.data import Dataset
from torch.autograd import grad
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch import nn

import wandb

import matplotlib
import matplotlib.pyplot as plt
import plotly.express as px
import scipy.io as spio

from tqdm.autonotebook import tqdm
from sklearn import svm

import configargparse

  from tqdm.autonotebook import tqdm


In [4]:
torch.cuda.is_available()

True

### 3. Connect to Google Drive

In [5]:
# Mount Google Drive at /content/drive (force_remount=True requests a fresh mount even if one exists).
drive.mount("/content/drive",force_remount=True)
# Switch the working directory to the Drive root so relative paths used later resolve there.
os.chdir("/content/drive/My Drive")

Mounted at /content/drive


### 4. Utils Code

In [6]:
# uses model input and real boundary fn
class ReachabilityDataset(Dataset):
    """Single-item dataset that draws a fresh random training batch on every access.

    Each ``__getitem__`` call samples ``numpoints`` normalized states uniformly in
    [-1, 1], attaches time samples (all at ``tMin`` while pretraining, otherwise a
    time horizon that grows with an internal counter) and evaluates the boundary
    (and, for ``'brat_hjivi'``, reach/avoid) functions of ``dynamics`` on them.

    Args:
        dynamics: object providing ``state_dim``, ``input_dim``, ``loss_type``,
            ``sample_target_state``, ``coord_to_input``, ``input_to_coord``,
            ``boundary_fn`` and, for ``'brat_hjivi'``, ``reach_fn``/``avoid_fn``.
        numpoints: number of coordinates sampled per item.
        pretrain: if True, start in pretraining mode (all samples at ``tMin``).
        pretrain_iters: number of items after which pretraining is switched off.
        tMin, tMax: time interval of the problem.
        counter_start, counter_end: curriculum counter; the sampled time range is
            scaled by ``counter / counter_end``.
        num_src_samples: number of trailing samples pinned to ``tMin``.
        num_target_samples: number of trailing samples replaced by target states.
    """

    def __init__(self, dynamics, numpoints, pretrain, pretrain_iters, tMin, tMax, counter_start, counter_end, num_src_samples, num_target_samples):
        self.dynamics = dynamics
        self.numpoints = numpoints
        self.pretrain = pretrain
        self.pretrain_counter = 0
        self.pretrain_iters = pretrain_iters
        self.tMin = tMin
        self.tMax = tMax
        self.counter = counter_start
        self.counter_end = counter_end
        self.num_src_samples = num_src_samples
        self.num_target_samples = num_target_samples

    def __len__(self):
        # One "item" is a whole freshly sampled batch.
        return 1

    def __getitem__(self, idx):
        """Return ``(model_input_dict, ground_truth_dict)`` for one random batch.

        Raises:
            NotImplementedError: if ``dynamics.loss_type`` is neither
                ``'brt_hjivi'`` nor ``'brat_hjivi'``.
        """
        loss_type = self.dynamics.loss_type
        if loss_type not in ('brt_hjivi', 'brat_hjivi'):
            # Fail before sampling so the curriculum counters are not advanced.
            raise NotImplementedError

        # uniformly sample domain and include coordinates where source is non-zero
        model_states = torch.zeros(self.numpoints, self.dynamics.state_dim).uniform_(-1, 1)
        if self.num_target_samples > 0:
            target_state_samples = self.dynamics.sample_target_state(self.num_target_samples)
            target_coords = torch.cat((torch.zeros(self.num_target_samples, 1), target_state_samples), dim=-1)
            model_states[-self.num_target_samples:] = self.dynamics.coord_to_input(target_coords)[:, 1:self.dynamics.state_dim+1]

        if self.pretrain:
            # only sample in time around the initial condition
            times = torch.full((self.numpoints, 1), self.tMin)
        else:
            # slowly grow time values from start time
            times = self.tMin + torch.zeros(self.numpoints, 1).uniform_(0, (self.tMax-self.tMin) * (self.counter / self.counter_end))
            # make sure we always have training samples at the initial time;
            # guard against 0 because the slice [-0:] would select *every* row
            if self.num_src_samples > 0:
                times[-self.num_src_samples:, 0] = self.tMin
        model_coords = torch.cat((times, model_states), dim=1)
        if self.dynamics.input_dim > self.dynamics.state_dim + 1: # temporary workaround for having to deal with dynamics classes for parametrized models with extra inputs
            model_coords = torch.cat((model_coords, torch.zeros(self.numpoints, self.dynamics.input_dim - self.dynamics.state_dim - 1)), dim=1)

        # Convert to real-unit coordinates once and drop the time column.
        real_states = self.dynamics.input_to_coord(model_coords)[..., 1:]
        boundary_values = self.dynamics.boundary_fn(real_states)
        if loss_type == 'brat_hjivi':
            reach_values = self.dynamics.reach_fn(real_states)
            avoid_values = self.dynamics.avoid_fn(real_states)

        if self.pretrain:
            dirichlet_masks = torch.ones(model_coords.shape[0]) > 0
        else:
            # only enforce initial conditions around self.tMin
            dirichlet_masks = (model_coords[:, 0] == self.tMin)

        # Advance the pretraining counter or, afterwards, the curriculum counter.
        if self.pretrain:
            self.pretrain_counter += 1
        elif self.counter < self.counter_end:
            self.counter += 1

        if self.pretrain and self.pretrain_counter == self.pretrain_iters:
            self.pretrain = False

        if loss_type == 'brt_hjivi':
            return {'model_coords': model_coords}, {'boundary_values': boundary_values, 'dirichlet_masks': dirichlet_masks}
        return {'model_coords': model_coords}, {'boundary_values': boundary_values, 'reach_values': reach_values, 'avoid_values': avoid_values, 'dirichlet_masks': dirichlet_masks}

In [7]:
# TODO: I don't think jacobian is needed here; torch.autograd.grad should be enough, to compute gradients of a scalar value function w.r.t. inputs

# batched jacobian
# y: [..., N], x: [..., M] -> [..., N, M]
def jacobian(y, x):
    ''' Batched jacobian of y wrt x.

    y has shape [..., N] and x has shape [..., M]; the result has shape [..., N, M].
    Returns a tuple (jac, status) where status is -1 if the jacobian contains
    any NaN entry and 0 otherwise.
    '''
    num_outputs = y.shape[-1]
    jac = torch.zeros(*y.shape, x.shape[-1], device=y.device)
    for out_idx in range(num_outputs):
        # differentiate one output feature at a time across the whole batch
        column = y[..., out_idx].view(-1, 1)
        (dcolumn_dx,) = grad(column, x, torch.ones_like(column), create_graph=True)
        jac[..., out_idx, :] = dcolumn_dx

    status = -1 if torch.any(torch.isnan(jac)) else 0
    return jac, status

In [8]:
class Validator(ABC):
    """Abstract interface for objects that decide which samples are valid."""

    @abstractmethod
    def validate(self, coords, values):
        """Return a validity mask for the given coordinates and their values.

        Subclasses must override this; the base implementation always raises.
        """
        raise NotImplementedError

class ValueThresholdValidator(Validator):
    """Marks a sample valid when its value lies in the closed range [v_min, v_max]."""

    def __init__(self, v_min, v_max):
        self.v_min, self.v_max = v_min, v_max

    def validate(self, coords, values):
        # coords is ignored; only the values are thresholded.
        at_least_min = values >= self.v_min
        at_most_max = values <= self.v_max
        return at_least_min * at_most_max

class MLPValidator(Validator):
    """Validates samples by thresholding the sigmoid output of an auxiliary MLP.

    The MLP receives the non-time coordinates concatenated with the value and a
    sample is valid when ``o_min <= sigmoid(mlp(...)) <= o_max``.

    Args:
        device: device on which the MLP is evaluated.
        mlp: callable mapping ``[N, state_dim + 1]`` inputs to one logit per row.
        o_min, o_max: inclusive bounds on the sigmoid output.
        model, dynamics: stored for interface compatibility; not used by
            ``validate``.
    """

    def __init__(self, device, mlp, o_min, o_max, model, dynamics):
        self.device = device
        self.mlp = mlp
        self.o_min = o_min
        self.o_max = o_max
        self.model = model
        self.dynamics = dynamics

    def validate(self, coords, values):
        """Return a boolean mask (on ``values.device``) of samples accepted by the MLP."""
        # The previous implementation also ran self.model on coords and discarded
        # the result; that dead forward pass has been removed.
        inputs = torch.cat((coords[..., 1:].to(self.device), values[:, None].to(self.device)), dim=-1)
        # squeeze only the trailing logit dimension so a batch of one keeps shape [1]
        outputs = torch.sigmoid(self.mlp(inputs).squeeze(dim=-1))
        return ((outputs >= self.o_min)*(outputs <=self.o_max)).to(device=values.device)

class MLPConditionedValidator(Validator):
    """Validates samples with value thresholds that depend on an MLP's output level.

    The sigmoid output of the MLP is binned by ``o_levels``; a sample falling in
    bin ``i`` (``o_levels[i] < output <= o_levels[i+1]``) is valid when its value
    lies in ``[v_levels[i][0], v_levels[i][1]]``.

    Args:
        device: device on which the MLP is evaluated.
        mlp: callable mapping ``[N, state_dim + 1]`` inputs to one logit per row.
        o_levels: bin edges on the sigmoid output.
        v_levels: one ``(min, max)`` value range per bin.
        model, dynamics: stored for interface compatibility; not used by
            ``validate``.
    """

    def __init__(self, device, mlp, o_levels, v_levels, model, dynamics):
        self.device = device
        self.mlp = mlp
        self.o_levels = o_levels
        self.v_levels = v_levels
        self.model = model
        self.dynamics = dynamics
        assert len(self.o_levels) == len(self.v_levels) + 1

    def validate(self, coords, values):
        """Return a boolean mask (on ``values.device``) of valid samples."""
        # The previous implementation also ran self.model on coords and discarded
        # the result; that dead forward pass has been removed.
        inputs = torch.cat((coords[..., 1:].to(self.device), values[:, None].to(self.device)), dim=-1)
        outputs = torch.sigmoid(self.mlp(inputs).squeeze(dim=-1)).to(device=values.device)
        # start from an all-False boolean mask so the result is boolean even with no bins
        valids = torch.zeros_like(outputs, dtype=torch.bool)
        for i in range(len(self.o_levels) - 1):
            valids = torch.logical_or(
                valids,
                (outputs > self.o_levels[i])*(outputs <= self.o_levels[i+1])*(values >= self.v_levels[i][0])*(values <= self.v_levels[i][1])
            )
        return valids

class MultiValidator(Validator):
    """Combines several validators; a sample is valid only if all of them accept it."""

    def __init__(self, validators):
        self.validators = validators

    def validate(self, coords, values):
        # Seed with the first validator's mask, then multiply in the remaining ones.
        combined = self.validators[0].validate(coords, values)
        for idx in range(1, len(self.validators)):
            combined = combined * self.validators[idx].validate(coords, values)
        return combined

class SampleGenerator(ABC):
    """Abstract interface for objects that produce batches of state samples."""

    @abstractmethod
    def sample(self, num_samples):
        """Return ``num_samples`` samples.

        Subclasses must override this; the base implementation always raises.
        """
        raise NotImplementedError

class SliceSampleGenerator(SampleGenerator):
    """Samples states on a slice: some dimensions are fixed, the rest are uniform.

    ``slices`` holds one entry per state dimension. ``None`` means "sample this
    dimension uniformly from ``dynamics.state_test_range()[dim]``"; any other
    entry is written to that dimension for every sample.
    """

    def __init__(self, dynamics, slices):
        self.dynamics = dynamics
        self.slices = slices
        assert self.dynamics.state_dim == len(slices)

    def sample(self, num_samples):
        """Return a ``[num_samples, state_dim]`` tensor of states on the slice."""
        samples = torch.zeros(num_samples, self.dynamics.state_dim)
        for dim in range(self.dynamics.state_dim):
            fixed_value = self.slices[dim]
            if fixed_value is not None:
                samples[:, dim] = fixed_value
                continue
            # free dimension: draw uniformly between the bounds of the test range
            bounds = self.dynamics.state_test_range()[dim]
            samples[:, dim].uniform_(*bounds)
        return samples

import torch
from tqdm import tqdm

# # get the tEarliest in [tMin:tMax:dt] at which the state is still valid
# def get_tEarliest(device, model, dynamics, state, tMin, tMax, dt, validator):
#     with torch.no_grad():
#         tEarliest = torch.full(state.shape[:-1], tMin - 1)
#         model_state = dynamics.normalize_state(state)

#         times_to_try = torch.arange(tMin, tMax + dt, dt)
#         for time_to_try in times_to_try:
#             blank_idx = (tEarliest < tMin)
#             time = torch.full((*state.shape[:-1], 1), time_to_try)
#             model_time = dynamics.normalize_time(time)
#             model_coord = torch.cat((model_time, model_state), dim=-1)[blank_idx]
#             model_result = model({'coords': model_coord.to(device)})
#             value = dynamics.output_to_value(output=model_result['model_out'][..., 0], state=state.to(device)).cpu()
#             valid_idx = validator.validate(torch.cat((time, state), dim=-1), value)
#             tMasked = tEarliest[blank_idx]
#             tMasked[valid_idx] = time_to_try
#             tEarliest[blank_idx] = tMasked
#             if torch.all(tEarliest >= tMin):
#                 break
#         blank_idx = (tEarliest < tMin)
#         if torch.any(blank_idx):
#             print(str(torch.sum(blank_idx)), 'invalid states')
#             tEarliest[blank_idx] = tMax
#         return tEarliest

def scenario_optimization(device, model, policy, dynamics, tMin, tMax, dt, set_type, control_type, scenario_batch_size, sample_batch_size, sample_generator, sample_validator, violation_validator, max_scenarios=None, max_samples=None, max_violations=None, tStart_generator=None):
    """Collect validated start states, roll them out, and count cost violations.

    Repeatedly (1) draws candidate states from ``sample_generator``, evaluates the
    learned value with ``model`` and keeps those accepted by ``sample_validator``
    until ``scenario_batch_size`` scenarios are collected, (2) propagates the
    scenarios with explicit Euler steps of size ``dt`` using the optimal control
    and disturbance derived from ``policy``, (3) computes the trajectory cost via
    ``dynamics.cost_fn`` and (4) counts violations with ``violation_validator``.

    Args:
        device: device used for model/policy/dynamics evaluation.
        model: network used to compute the value of candidate samples.
        policy: network whose value gradients drive the rollout.
        dynamics: dynamics object (state/control/disturbance dims, conversions,
            ``optimal_control``, ``optimal_disturbance``, ``hamiltonian``,
            ``dsdt``, ``cost_fn``, ...).
        tMin, tMax, dt: time interval and step; ``tMax - tMin`` must be a
            multiple of ``dt``.
        set_type: ``'BRT'`` (``'BRS'`` is not implemented).
        control_type: ``'value'`` (``'ttr'``/``'init_ttr'`` are not implemented).
        scenario_batch_size: scenarios propagated per outer iteration.
        sample_batch_size: candidate samples drawn per inner iteration.
        sample_generator: object with ``sample(n)`` returning candidate states.
        sample_validator: validator selecting which candidates become scenarios.
        violation_validator: validator flagging violating (state, cost) pairs.
        max_scenarios, max_samples, max_violations: termination conditions; at
            least one must be given.
        tStart_generator: optional callable returning per-sample start times;
            scenarios are frozen until the rollout time reaches their start time.

    Returns:
        dict with per-scenario tensors (``times``, ``states``, ``values``,
        ``costs``, Hamiltonian statistics, ``violations``), summary rates, flags
        telling which termination condition was hit, and ``batch_state_trajs``
        (the last propagated batch, or ``None`` if sampling ran out first).

    Raises:
        NotImplementedError: for ``set_type == 'BRS'`` or any ``control_type``
            other than ``'value'``.
    """
    rem = ((tMax-tMin) / dt)%1
    e_tol = 1e-12
    assert rem < e_tol or 1 - rem < e_tol, f'{tMax-tMin} is not divisible by {dt}'
    assert tMax > tMin
    assert set_type in ['BRS', 'BRT']
    if set_type == 'BRS':
        print('confirm correct calculation of true values of trajectories (batch_scenario_costs)')
        raise NotImplementedError
    assert control_type in ['value', 'ttr', 'init_ttr']
    if control_type != 'value':
        # Only the 'value' branch assigns traj_times below; fail early and clearly
        # instead of hitting an unbound local during propagation.
        raise NotImplementedError(f"control_type '{control_type}' is not implemented; only 'value' is supported")
    assert max_scenarios or max_samples or max_violations, 'one of the termination conditions must be used'
    if max_scenarios:
        assert (max_scenarios / scenario_batch_size)%1 == 0, 'max_scenarios is not divisible by scenario_batch_size'
    if max_samples:
        assert (max_samples / sample_batch_size)%1 == 0, 'max_samples is not divisible by sample_batch_size'

    # Number of integration steps. round() rather than truncation: the assert
    # above accepts quotients such as 2.9999999999999996, which int() would turn
    # into 2 and silently drop the final step.
    num_steps = int(round((tMax-tMin)/dt))

    # accumulate scenarios
    times = torch.zeros(0, )
    states = torch.zeros(0, dynamics.state_dim)
    values = torch.zeros(0, )
    costs = torch.zeros(0, )
    init_hams = torch.zeros(0, )
    mean_hams = torch.zeros(0, )
    mean_abs_hams = torch.zeros(0, )
    max_abs_hams = torch.zeros(0, )
    min_abs_hams = torch.zeros(0, )

    num_scenarios = 0
    num_samples = 0
    num_violations = 0

    # one progress bar per active termination condition
    pbar_pos = 0
    if max_scenarios:
        scenarios_pbar = tqdm(total=max_scenarios, desc='Scenarios', position=pbar_pos)
        pbar_pos += 1
    if max_samples:
        samples_pbar = tqdm(total=max_samples, desc='Samples', position=pbar_pos)
        pbar_pos += 1
    if max_violations:
        violations_pbar = tqdm(total=max_violations, desc='Violations', position=pbar_pos)
        pbar_pos += 1

    nums_valid_samples = []
    while True:
        if (max_scenarios and (num_scenarios >= max_scenarios)) or (max_violations and (num_violations >= max_violations)):
            break

        batch_scenario_times = torch.zeros(scenario_batch_size, )
        batch_scenario_states = torch.zeros(scenario_batch_size, dynamics.state_dim)
        batch_scenario_values = torch.zeros(scenario_batch_size, )

        num_collected_scenarios = 0
        while num_collected_scenarios < scenario_batch_size:
            if max_samples and (num_samples >= max_samples):
                break
            # sample batch
            if tStart_generator is not None:
                batch_sample_times = tStart_generator(sample_batch_size)
                # need to round to nearest dt
                batch_sample_times = torch.round(batch_sample_times/dt)*dt
            else:
                batch_sample_times = torch.full((sample_batch_size, ), tMax)
            batch_sample_states = dynamics.equivalent_wrapped_state(sample_generator.sample(sample_batch_size))
            batch_sample_coords = torch.cat((batch_sample_times.unsqueeze(-1), batch_sample_states), dim=-1)

            # validate batch
            with torch.no_grad():
                batch_sample_model_results = model({'coords': dynamics.coord_to_input(batch_sample_coords.to(device))})
                batch_sample_values = dynamics.io_to_value(batch_sample_model_results['model_in'].detach(), batch_sample_model_results['model_out'].squeeze(dim=-1).detach())
            batch_valid_sample_idxs = torch.where(sample_validator.validate(batch_sample_coords, batch_sample_values))[0].detach().cpu()

            # store valid samples (truncated so the scenario batch is never overfilled)
            num_valid_samples = len(batch_valid_sample_idxs)
            start_idx = num_collected_scenarios
            end_idx = min(start_idx + num_valid_samples, scenario_batch_size)
            batch_scenario_times[start_idx:end_idx] = batch_sample_times[batch_valid_sample_idxs][:end_idx-start_idx]
            batch_scenario_states[start_idx:end_idx] = batch_sample_states[batch_valid_sample_idxs][:end_idx-start_idx]
            batch_scenario_values[start_idx:end_idx] = batch_sample_values[batch_valid_sample_idxs][:end_idx-start_idx]

            # update counters
            num_samples += sample_batch_size
            if max_samples:
                samples_pbar.update(sample_batch_size)
            num_collected_scenarios += end_idx - start_idx
            nums_valid_samples.append(num_valid_samples)
        # sample budget exhausted: the partially filled batch is discarded
        if max_samples and (num_samples >= max_samples):
            break

        # propagate scenarios
        state_trajs = torch.zeros(scenario_batch_size, num_steps + 1, dynamics.state_dim)
        ctrl_trajs = torch.zeros(scenario_batch_size, num_steps, dynamics.control_dim)
        dstb_trajs = torch.zeros(scenario_batch_size, num_steps, dynamics.disturbance_dim)
        ham_trajs = torch.zeros(scenario_batch_size, num_steps)

        state_trajs[:, 0, :] = batch_scenario_states
        for k in tqdm(range(num_steps), desc='Trajectory Propagation', position=pbar_pos, leave=False):
            if control_type == 'value':
                traj_time = tMax - k*dt
                traj_times = torch.full((scenario_batch_size, ), traj_time)
            # elif control_type == 'ttr':
            #     traj_times = get_tEarliest(model=model, dynamics=dynamics, state=state_trajs[:, k], tMin=tMin, tMax=traj_time, dt=dt, validator=sample_validator)
            # elif control_type == 'init_ttr':
            #     if k == 0:
            #         init_traj_times = get_tEarliest(model=model, dynamics=dynamics, state=state_trajs[:, k], tMin=tMin, tMax=traj_time, dt=dt, validator=sample_validator)
            #     traj_times = torch.maximum(init_traj_times - k*dt, torch.tensor(tMin)) # check whether this is the best thing to do for init_ttr
            traj_coords = torch.cat((traj_times.unsqueeze(-1), state_trajs[:, k]), dim=-1)
            traj_policy_results = policy({'coords': dynamics.coord_to_input(traj_coords.to(device))})
            traj_dvs = dynamics.io_to_dv(traj_policy_results['model_in'], traj_policy_results['model_out'].squeeze(dim=-1)).detach()

            # TODO: I do not think there is actually any reason to store these trajs? Could save space by removing these.
            ctrl_trajs[:, k] = dynamics.optimal_control(traj_coords[:, 1:].to(device), traj_dvs[..., 1:].to(device))
            dstb_trajs[:, k] = dynamics.optimal_disturbance(traj_coords[:, 1:].to(device), traj_dvs[..., 1:].to(device))
            ham_trajs[:, k] = dynamics.hamiltonian(traj_coords[:, 1:].to(device), traj_dvs[..., 1:].to(device))

            if tStart_generator is not None: # freeze states whose start time has not been reached yet
                is_frozen = batch_scenario_times < traj_times
                is_unfrozen = torch.logical_not(is_frozen)
                state_trajs[is_frozen, k+1] = state_trajs[is_frozen, k]
                state_trajs[is_unfrozen, k+1] = dynamics.equivalent_wrapped_state(state_trajs[is_unfrozen, k].to(device) + dt*dynamics.dsdt(state_trajs[is_unfrozen, k].to(device), ctrl_trajs[is_unfrozen, k].to(device), dstb_trajs[is_unfrozen, k].to(device))).cpu()
            else:
                state_trajs[:, k+1] = dynamics.equivalent_wrapped_state(state_trajs[:, k].to(device) + dt*dynamics.dsdt(state_trajs[:, k].to(device), ctrl_trajs[:, k].to(device), dstb_trajs[:, k].to(device)))

        # compute batch_scenario_costs
        # TODO: need to handle the case of using tStart_generator when extending a trajectory by a frozen initial state will inadvertently affect cost computation (the min lx cost formulation is unaffected, but other cost formulations might care)
        if set_type == 'BRT':
            batch_scenario_costs = dynamics.cost_fn(state_trajs.to(device))
        elif set_type == 'BRS':
            if control_type == 'init_ttr': # is this correct for init_ttr?
                batch_scenario_costs =  dynamics.boundary_fn(state_trajs.to(device))[:, (init_traj_times - tMin) / dt]
            elif control_type == 'value':
                batch_scenario_costs =  dynamics.boundary_fn(state_trajs.to(device))[:, -1]
            else:
                raise NotImplementedError # what is the correct thing to do for ttr?

        # compute batch_scenario_init_hams, batch_scenario_mean_hams, batch_scenario_mean_abs_hams, batch_scenario_max_abs_hams, batch_scenario_min_abs_hams
        batch_scenario_init_hams = ham_trajs[:, 0]
        batch_scenario_mean_hams = torch.mean(ham_trajs, dim=-1)
        batch_scenario_mean_abs_hams = torch.mean(torch.abs(ham_trajs), dim=-1)
        batch_scenario_max_abs_hams = torch.max(torch.abs(ham_trajs), dim=-1).values
        batch_scenario_min_abs_hams = torch.min(torch.abs(ham_trajs), dim=-1).values

        # store scenarios
        times = torch.cat((times, batch_scenario_times.cpu()), dim=0)
        states = torch.cat((states, batch_scenario_states.cpu()), dim=0)
        values = torch.cat((values, batch_scenario_values.cpu()), dim=0)
        costs = torch.cat((costs, batch_scenario_costs.cpu()), dim=0)
        init_hams = torch.cat((init_hams, batch_scenario_init_hams.cpu()), dim=0)
        mean_hams = torch.cat((mean_hams, batch_scenario_mean_hams.cpu()), dim=0)
        mean_abs_hams = torch.cat((mean_abs_hams, batch_scenario_mean_abs_hams.cpu()), dim=0)
        max_abs_hams = torch.cat((max_abs_hams, batch_scenario_max_abs_hams.cpu()), dim=0)
        min_abs_hams = torch.cat((min_abs_hams, batch_scenario_min_abs_hams.cpu()), dim=0)

        # update counters
        num_scenarios += scenario_batch_size
        if max_scenarios:
            scenarios_pbar.update(scenario_batch_size)
        num_new_violations = int(torch.sum(violation_validator.validate(batch_scenario_states, batch_scenario_costs)))
        num_violations += num_new_violations
        if max_violations:
            violations_pbar.update(num_new_violations)

    if max_scenarios:
        scenarios_pbar.close()
    if max_samples:
        samples_pbar.close()
    if max_violations:
        violations_pbar.close()

    # re-evaluate violations over everything that was accumulated
    violations = violation_validator.validate(states, costs)

    return {
        'times': times,
        'states': states,
        'values': values,
        'costs': costs,
        'init_hams': init_hams,
        'init_abs_hams': torch.abs(init_hams),
        'mean_hams': mean_hams,
        'mean_abs_hams': mean_abs_hams,
        'max_abs_hams': max_abs_hams,
        'min_abs_hams': min_abs_hams,
        'violations': violations,
        'valid_sample_fraction': torch.mean(torch.tensor(nums_valid_samples, dtype=float))/sample_batch_size,
        'violation_rate': 0 if not num_scenarios else num_violations / num_scenarios,
        'maxed_scenarios': (max_scenarios is not None) and num_scenarios >= max_scenarios,
        'maxed_samples': (max_samples is not None) and num_samples >= max_samples,
        'maxed_violations': (max_violations is not None) and num_violations >= max_violations,
        'batch_state_trajs': None if (max_samples and (num_samples >= max_samples)) else state_trajs,
    }

def target_fraction(device, model, dynamics, t, sample_validator, target_validator, num_samples, batch_size):
    """Fraction of sample-validated states at time ``t`` that the target validator accepts.

    States are drawn uniformly from ``dynamics.state_test_range()`` in batches of
    ``batch_size``; only those accepted by ``sample_validator`` are kept, until
    ``num_samples`` have been gathered. The return value is the number of kept
    samples accepted by ``target_validator`` divided by ``num_samples``.
    """
    with torch.no_grad():
        kept_states = [torch.zeros(0, dynamics.state_dim)]
        kept_values = [torch.zeros(0, )]
        num_kept = 0

        while num_kept < num_samples:
            # draw one batch of candidate states at the fixed time t
            batch_times = torch.full((batch_size, 1), t)
            batch_states = torch.zeros(batch_size, dynamics.state_dim)
            for dim in range(dynamics.state_dim):
                batch_states[:, dim].uniform_(*dynamics.state_test_range()[dim])
            batch_states = dynamics.equivalent_wrapped_state(batch_states)
            batch_coords = torch.cat((batch_times, batch_states), dim=-1)

            # evaluate the value function and filter the batch
            batch_model_results = model({'coords': dynamics.coord_to_input(batch_coords.to(device))})
            batch_values = dynamics.io_to_value(batch_model_results['model_in'], batch_model_results['model_out'].squeeze(dim=-1)).detach()
            keep_mask = sample_validator.validate(batch_coords, batch_values).detach().cpu()

            # remember the accepted part of the batch
            accepted_states = batch_states[keep_mask].cpu()
            kept_states.append(accepted_states)
            kept_values.append(batch_values[keep_mask].cpu())
            num_kept += len(accepted_states)

        states = torch.cat(kept_states, dim=0)[:num_samples]
        values = torch.cat(kept_values, dim=0)[:num_samples]
        coords = torch.cat((torch.full((num_samples, 1), t), states), dim=-1)
        valids = target_validator.validate(coords.to(device), values.to(device))
    return torch.sum(valids) / num_samples

class MLP(torch.nn.Module):
    """Small ReLU network: input_size -> 2*input_size -> input_size -> input_size -> 1."""

    def __init__(self, input_size):
        super().__init__()

        # hidden widths, expressed relative to the input width
        wide, mid, narrow = int(2*input_size), int(input_size), int(input_size)
        self.l1 = torch.nn.Linear(input_size, wide)
        self.a1 = torch.nn.ReLU()
        self.l2 = torch.nn.Linear(wide, mid)
        self.a2 = torch.nn.ReLU()
        self.l3 = torch.nn.Linear(mid, narrow)
        self.a3 = torch.nn.ReLU()
        self.l4 = torch.nn.Linear(narrow, 1)

    def forward(self, x):
        """Apply the layers in order; the final linear layer has no activation."""
        pipeline = (self.l1, self.a1, self.l2, self.a2, self.l3, self.a3, self.l4)
        for layer in pipeline:
            x = layer(x)
        return x



def sample_values(device, model, dynamics, t, num_samples, batch_size):
    """Evaluate the learned value at ``num_samples`` uniformly sampled states at time ``t``.

    States are drawn from ``dynamics.state_test_range()`` in batches of
    ``batch_size`` and pushed through ``model``; the resulting values are
    converted with ``dynamics.io_to_value`` and returned on the CPU.

    Returns:
        1-D tensor holding the first ``num_samples`` sampled values.
    """
    with torch.no_grad():
        # Collect per-batch results and concatenate once at the end; the seed
        # entry keeps the result a float tensor of shape [0] when no batch runs.
        value_batches = [torch.zeros(0, )]
        num_collected = 0

        while num_collected < num_samples:
            # sample batch
            batch_times = torch.full((batch_size, 1), t)
            batch_states = torch.zeros(batch_size, dynamics.state_dim)
            for dim in range(dynamics.state_dim):
                batch_states[:, dim].uniform_(*dynamics.state_test_range()[dim])
            batch_states = dynamics.equivalent_wrapped_state(batch_states)
            batch_coords = torch.cat((batch_times, batch_states), dim=-1)

            batch_model_results = model({'coords': dynamics.coord_to_input(batch_coords.to(device))})
            batch_values = dynamics.io_to_value(batch_model_results['model_in'], batch_model_results['model_out'].squeeze(dim=-1)).detach()

            # store batch (only the values are needed by the caller)
            value_batches.append(batch_values.cpu())
            num_collected += batch_states.shape[0]

        values = torch.cat(value_batches, dim=0)[:num_samples]
    return values


In [9]:
# uses real units
def init_brt_hjivi_loss(dynamics, minWith, dirichlet_loss_divisor):
    """Build the BRT HJI-VI loss closure for the given dynamics.

    The returned function yields a dict with the summed absolute PDE residual
    (``'diff_constraint_hom'``) and, unless the 'exact' model is past
    pretraining, the summed absolute boundary residual divided by
    ``dirichlet_loss_divisor`` (``'dirichlet'``).
    """
    def brt_hjivi_loss(state, value, dvdt, dvds, boundary_value, dirichlet_mask, output):
        pretraining = bool(torch.all(dirichlet_mask))

        if pretraining:
            # pretraining loss: no PDE term
            pde_residual = torch.Tensor([0])
        else:
            ham = dynamics.hamiltonian(state, dvds)
            if minWith == 'zero':
                ham = torch.clamp(ham, max=0.0)

            pde_residual = dvdt - ham
            if minWith == 'target':
                pde_residual = torch.max(pde_residual, value - boundary_value)

        boundary_residual = value[dirichlet_mask] - boundary_value[dirichlet_mask]
        if dynamics.deepreach_model == 'exact':
            if not pretraining:
                # the 'exact' model only contributes the PDE term after pretraining
                return {'diff_constraint_hom': torch.abs(pde_residual).sum()}
            # pretraining: penalise the raw network output instead
            boundary_residual = output.squeeze(dim=-1)[dirichlet_mask]-0.0

        dirichlet_term = torch.abs(boundary_residual).sum() / dirichlet_loss_divisor
        pde_term = torch.abs(pde_residual).sum()
        return {'dirichlet': dirichlet_term,
                'diff_constraint_hom': pde_term}

    return brt_hjivi_loss
def init_brat_hjivi_loss(dynamics, minWith, dirichlet_loss_divisor):
    """Build the reach-avoid (BRAT) HJI-VI loss closure for the given dynamics.

    Same structure as the BRT loss, except that with ``minWith == 'target'`` the
    PDE residual is combined with both the reach and the avoid values.
    """
    def brat_hjivi_loss(state, value, dvdt, dvds, boundary_value, reach_value, avoid_value, dirichlet_mask, output):
        pretraining = bool(torch.all(dirichlet_mask))

        if pretraining:
            # pretraining loss: no PDE term
            pde_residual = torch.Tensor([0])
        else:
            ham = dynamics.hamiltonian(state, dvds)
            if minWith == 'zero':
                ham = torch.clamp(ham, max=0.0)

            pde_residual = dvdt - ham
            if minWith == 'target':
                reach_bounded = torch.max(pde_residual, value - reach_value)
                pde_residual = torch.min(reach_bounded, value + avoid_value)

        boundary_residual = value[dirichlet_mask] - boundary_value[dirichlet_mask]
        if dynamics.deepreach_model == 'exact':
            if not pretraining:
                # the 'exact' model only contributes the PDE term after pretraining
                return {'diff_constraint_hom': torch.abs(pde_residual).sum()}
            # pretraining: penalise the raw network output instead
            boundary_residual = output.squeeze(dim=-1)[dirichlet_mask]-0.0

        dirichlet_term = torch.abs(boundary_residual).sum() / dirichlet_loss_divisor
        pde_term = torch.abs(pde_residual).sum()
        return {'dirichlet': dirichlet_term,
                'diff_constraint_hom': pde_term}
    return brat_hjivi_loss

In [10]:
"""
MIT License

Copyright (c) 2020 Vincent Sitzmann

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

class BatchLinear(nn.Linear):
    '''A linear layer'''
    # Expose nn.Linear's documentation instead of the one-line docstring above.
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        # params optionally overrides the layer's own parameters; it is a mapping
        # with a 'weight' entry and an optional 'bias' entry. The weight may carry
        # leading batch dimensions, in which case matmul broadcasts over them.
        if params is None:
            params = OrderedDict(self.named_parameters())

        bias = params.get('bias', None)
        weight = params['weight']

        # swap the last two weight dimensions so that input @ weight^T is computed
        output = input.matmul(weight.transpose(-1, -2))
        # bias is absent when the layer was built with bias=False (or when the
        # supplied params omit it); skip the addition instead of crashing on None
        if bias is not None:
            output += bias.unsqueeze(-2)
        return output


class Sine(nn.Module):
    """Sine activation with a fixed frequency factor of 30."""

    def __init__(self):
        # Was misspelled `__init`, which only defined an unused name-mangled method.
        super().__init__()

    def forward(self, input):
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        return torch.sin(30 * input)


class FCBlock(nn.Module):
    '''A fully connected neural network.

    The network is an input layer, ``num_hidden_layers`` hidden layers of width
    ``hidden_features`` and an output layer, each a BatchLinear followed by the
    chosen nonlinearity (the output layer has none when ``outermost_linear``).
    '''

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
                 outermost_linear=False, nonlinearity='relu', weight_init=None):
        super().__init__()

        self.first_layer_init = None

        # nonlinearity name -> (activation module, weight initializer,
        # optional special initializer for the first layer)
        nls_and_inits = {'sine':(Sine(), sine_init, first_layer_sine_init),
                         'relu':(nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid':(nn.Sigmoid(), init_weights_xavier, None),
                         'tanh':(nn.Tanh(), init_weights_xavier, None),
                         'selu':(nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus':(nn.Softplus(), init_weights_normal, None),
                         'elu':(nn.ELU(inplace=True), init_weights_elu, None)}

        nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        # An explicitly passed weight_init takes precedence over the default
        # initializer associated with the nonlinearity.
        self.weight_init = nl_weight_init if weight_init is None else weight_init

        # input layer
        blocks = [nn.Sequential(BatchLinear(in_features, hidden_features), nl)]

        # hidden layers (the same activation module instance is reused throughout)
        for _ in range(num_hidden_layers):
            blocks.append(nn.Sequential(BatchLinear(hidden_features, hidden_features), nl))

        # output layer, with or without a trailing nonlinearity
        if outermost_linear:
            blocks.append(nn.Sequential(BatchLinear(hidden_features, out_features)))
        else:
            blocks.append(nn.Sequential(BatchLinear(hidden_features, out_features), nl))

        self.net = nn.Sequential(*blocks)
        if self.weight_init is not None:
            self.net.apply(self.weight_init)

        if first_layer_init is not None: # Apply special initialization to first layer, if applicable.
            self.net[0].apply(first_layer_init)

    def forward(self, coords, params=None, **kwargs):
        # params is accepted (and defaulted) but not forwarded to the layers;
        # the network always runs with its own parameters.
        if params is None:
            params = OrderedDict(self.named_parameters())

        return self.net(coords)


class SingleBVPNet(nn.Module):
    '''A canonical representation network for a BVP.

    Wraps one ``FCBlock`` whose last layer is linear (``outermost_linear=True``)
    and exposes a dict-in / dict-out ``forward``.
    '''

    def __init__(self, out_features=1, type='sine', in_features=2,
                 mode='mlp', hidden_features=256, num_hidden_layers=3, **kwargs):
        '''Build the underlying fully connected block.

        Args:
            out_features: Width of the network output.
            type: Nonlinearity name forwarded to ``FCBlock`` as ``nonlinearity``.
            in_features: Width of the network input.
            mode: Stored on ``self.mode``; not otherwise used by this class.
            hidden_features: Width of every hidden layer.
            num_hidden_layers: Forwarded to ``FCBlock``.
            **kwargs: Accepted and ignored.
        '''
        super().__init__()
        self.mode = mode
        self.net = FCBlock(in_features=in_features, out_features=out_features,
                           num_hidden_layers=num_hidden_layers,
                           hidden_features=hidden_features,
                           outermost_linear=True, nonlinearity=type)
        # Prints the module summary at construction time (kept from the original).
        print(self)

    def forward(self, model_input, params=None):
        '''Evaluate the network on ``model_input['coords']``.

        Args:
            model_input: Mapping with a ``'coords'`` tensor.
            params: Accepted only so existing callers keep working; never read.

        Returns:
            Dict with ``'model_in'`` (the detached, grad-enabled copy of the
            coordinates that was actually fed to the network) and
            ``'model_out'`` (the network output).
        '''
        # The original built an OrderedDict of all named parameters here when
        # ``params`` was None and never used it; that dead work is removed.

        # Enables us to compute gradients w.r.t. coordinates
        # TODO: should not need to .clone().detach().requires_grad_(True); instead, use .retain_grad() on input in calling script
        # otherwise, .detach() removes input from the graph so grad cannot propagate back end-to-end, e.g., percept -> NN -> state estimation (input)
        coords = model_input['coords'].clone().detach().requires_grad_(True)
        return {'model_in': coords, 'model_out': self.net(coords)}


########################
# Initialization methods
def init_weights_normal(m):
    '''Kaiming-normal (fan-in, ReLU gain) initialisation for linear layers.

    Modules whose exact type is neither ``BatchLinear`` nor ``nn.Linear`` are
    left untouched, so the function can be mapped over a whole network.
    '''
    if type(m) not in (BatchLinear, nn.Linear):
        return
    if hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')


def init_weights_selu(m):
    '''Zero-mean normal initialisation with std ``1 / sqrt(fan_in)``.

    Only modules whose exact type is ``BatchLinear`` or ``nn.Linear`` are
    touched; ``fan_in`` is the last dimension of the weight.
    '''
    is_linear = type(m) in (BatchLinear, nn.Linear)
    if is_linear and hasattr(m, 'weight'):
        fan_in = m.weight.size(-1)
        nn.init.normal_(m.weight, std=1 / math.sqrt(fan_in))


def init_weights_elu(m):
    '''Zero-mean normal initialisation with std ``sqrt(1.5505188080679277 / fan_in)``.

    Written as the ratio of two square roots, exactly as before. Only modules
    whose exact type is ``BatchLinear`` or ``nn.Linear`` are touched.
    '''
    if type(m) not in (BatchLinear, nn.Linear):
        return
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(fan_in))


def init_weights_xavier(m):
    '''Xavier-normal initialisation for linear layers.

    Applies only when the module's exact type is ``BatchLinear`` or
    ``nn.Linear`` and it exposes a ``weight`` attribute.
    '''
    linear_types = (BatchLinear, nn.Linear)
    if type(m) in linear_types and hasattr(m, 'weight'):
        nn.init.xavier_normal_(m.weight)


def sine_init(m):
    '''Uniform initialisation in ``[-sqrt(6 / fan_in) / 30, sqrt(6 / fan_in) / 30]``.

    Applies to any module that has a ``weight`` attribute; ``fan_in`` is the
    last dimension of that weight. Modules without a weight are skipped.
    '''
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    # See supplement Sec. 1.5 for discussion of factor 30
    limit = np.sqrt(6 / fan_in) / 30
    with torch.no_grad():
        m.weight.uniform_(-limit, limit)


def first_layer_sine_init(m):
    '''Uniform initialisation in ``[-1 / fan_in, 1 / fan_in]``.

    Applies to any module that has a ``weight`` attribute; ``fan_in`` is the
    last dimension of that weight. Modules without a weight are skipped.
    '''
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
    limit = 1 / fan_in
    with torch.no_grad():
        m.weight.uniform_(-limit, limit)


### 5. Dynamics Code

In [11]:
# during training, states will be sampled uniformly by each state dimension from the model-unit -1 to 1 range (for training stability),
# which may or may not correspond to proper test ranges
# note that coord refers to [time, *state], and input refers to whatever is fed directly to the model (often [time, *state, params])
# in the future, code will need to be fixed to correctly handle parameterized models
class Dynamics(ABC):
    '''Abstract dynamical system plus the unit conversions used around the network.

    The conversion helpers map between normalised model inputs/outputs and
    real-unit coordinates/values using the mean and scale handed to the
    constructor. Index 0 of the last axis is treated as time and left
    unscaled; the remaining entries are the state. Concrete systems provide
    the abstract hooks declared at the bottom of the class.
    '''

    def __init__(self,
                 loss_type: str, set_mode: str,
                 state_dim: int, input_dim: int,
                 control_dim: int, disturbance_dim: int,
                 state_mean: list, state_var: list,
                 value_mean: float, value_var: float, value_normto: float,
                 deepreach_model: str):
        '''Store the configuration and validate it.

        Args:
            loss_type: Either ``'brt_hjivi'`` or ``'brat_hjivi'``.
            set_mode: Either ``'reach'`` or ``'avoid'``.
            state_dim: Number of state components; must equal the length of
                ``state_mean`` and ``state_var``.
            input_dim: Stored as-is.
            control_dim: Stored as-is.
            disturbance_dim: Stored as-is.
            state_mean: Per-component offset; converted to a tensor.
            state_var: Per-component multiplicative scale; converted to a tensor.
            value_mean: Offset used by the default value conversion.
            value_var: Scale numerator for value conversions.
            value_normto: Scale denominator for value conversions.
            deepreach_model: ``"diff"`` and ``"exact"`` select the
                boundary-function-based conversions; anything else selects the
                plain affine one.

        Raises:
            AssertionError: On an unknown ``loss_type``/``set_mode``, a
                descriptor length mismatch, or (for ``'brat_hjivi'``)
                non-callable ``reach_fn``/``avoid_fn``.
        '''
        self.loss_type, self.set_mode = loss_type, set_mode
        self.state_dim, self.input_dim = state_dim, input_dim
        self.control_dim, self.disturbance_dim = control_dim, disturbance_dim
        self.state_mean = torch.tensor(state_mean)
        self.state_var = torch.tensor(state_var)
        self.value_mean, self.value_var = value_mean, value_var
        self.value_normto = value_normto
        self.deepreach_model = deepreach_model

        assert self.loss_type in ('brt_hjivi', 'brat_hjivi'), f'loss type {self.loss_type} not recognized'
        if self.loss_type == 'brat_hjivi':
            # Reach-avoid problems need both functions on the subclass.
            assert callable(self.reach_fn) and callable(self.avoid_fn)
        assert self.set_mode in ('reach', 'avoid'), f'set mode {self.set_mode} not recognized'
        for descriptor in (self.state_mean, self.state_var):
            assert len(descriptor) == self.state_dim, 'state descriptor dimension does not equal state dimension, ' + str(len(descriptor)) + ' != ' + str(self.state_dim)

    # ALL METHODS ARE BATCH COMPATIBLE

    # MODEL-UNIT CONVERSIONS (TODO: refactor into separate model-unit conversion class?)

    def input_to_coord(self, input):
        '''Convert a model input to a real coordinate (time entry untouched).'''
        scale = self.state_var.to(device=input.device)
        shift = self.state_mean.to(device=input.device)
        coord = input.clone()
        coord[..., 1:] = (input[..., 1:] * scale) + shift
        return coord

    def coord_to_input(self, coord):
        '''Convert a real coordinate to a model input (time entry untouched).'''
        shift = self.state_mean.to(device=coord.device)
        scale = self.state_var.to(device=coord.device)
        input = coord.clone()
        input[..., 1:] = (coord[..., 1:] - shift) / scale
        return input

    def io_to_value(self, input, output):
        '''Convert a model input/output pair to a real value.

        For ``"diff"`` the rescaled output is added to the boundary function;
        for ``"exact"`` the output is first multiplied by the time entry of
        ``input``; otherwise the rescaled output is offset by ``value_mean``.
        '''
        kind = self.deepreach_model
        if kind != "diff" and kind != "exact":
            return (output * self.value_var / self.value_normto) + self.value_mean

        raw = output if kind == "diff" else output * input[..., 0]
        boundary = self.boundary_fn(self.input_to_coord(input)[..., 1:])
        return (raw * self.value_var / self.value_normto) + boundary

    def io_to_dv(self, input, output):
        '''Convert a model input/output pair to real value derivatives.

        Returns a tensor whose last axis holds the time derivative followed by
        the state derivatives. Relies on the file-level ``jacobian`` helper.
        '''
        net_grad = jacobian(output.unsqueeze(dim=-1), input)[0].squeeze(dim=-2)
        value_scale = self.value_var / self.value_normto
        state_scale = self.state_var.to(device=net_grad.device)
        kind = self.deepreach_model

        if kind == "diff" or kind == "exact":
            if kind == "diff":
                dvdt = value_scale * net_grad[..., 0]
                net_part = (value_scale / state_scale) * net_grad[..., 1:]
            else:
                # Product rule: the value is (time * output) plus the boundary term.
                dvdt = value_scale * \
                    (input[..., 0]*net_grad[..., 0] + output)
                net_part = (value_scale / state_scale) * net_grad[..., 1:] * input[..., 0].unsqueeze(-1)
            # Both variants add the state gradient of the boundary function.
            state = self.input_to_coord(input)[..., 1:]
            boundary_part = jacobian(self.boundary_fn(state).unsqueeze(dim=-1), state)[0].squeeze(dim=-2)
            dvds = net_part + boundary_part
        else:
            dvdt = value_scale * net_grad[..., 0]
            dvds = (value_scale / state_scale) * net_grad[..., 1:]

        return torch.cat((dvdt.unsqueeze(dim=-1), dvds), dim=-1)

    # ALL FOLLOWING METHODS USE REAL UNITS

    @abstractmethod
    def state_test_range(self):
        '''Per-component state range; implemented by subclasses.'''
        raise NotImplementedError

    @abstractmethod
    def equivalent_wrapped_state(self, state):
        '''Equivalent representation of ``state``; implemented by subclasses.'''
        raise NotImplementedError

    @abstractmethod
    def dsdt(self, state, control, disturbance):
        '''State time derivative; implemented by subclasses.'''
        raise NotImplementedError

    @abstractmethod
    def boundary_fn(self, state):
        '''Boundary function used by the value conversions; implemented by subclasses.'''
        raise NotImplementedError

    @abstractmethod
    def sample_target_state(self, num_samples):
        '''Sample ``num_samples`` target states; implemented by subclasses.'''
        raise NotImplementedError

    @abstractmethod
    def cost_fn(self, state_traj):
        '''Cost of a state trajectory; implemented by subclasses.'''
        raise NotImplementedError

    @abstractmethod
    def hamiltonian(self, state, dvds):
        '''Hamiltonian at ``state`` given ``dvds``; implemented by subclasses.'''
        raise NotImplementedError

    @abstractmethod
    def optimal_control(self, state, dvds):
        '''Optimal control at ``state`` given ``dvds``; implemented by subclasses.'''
        raise NotImplementedError

    @abstractmethod
    def optimal_disturbance(self, state, dvds):
        '''Optimal disturbance at ``state`` given ``dvds``; implemented by subclasses.'''
        raise NotImplementedError

    @abstractmethod
    def plot_config(self):
        '''Plotting configuration; implemented by subclasses.'''
        raise NotImplementedError

class ParameterizedVertDrone2D(Dynamics):
    '''Vertical drone with state ``[v, z, k]`` and one control input.

    ``k`` is a constant input multiplier carried in the state; the normalised
    range for it is centred on half of ``input_multiplier_max``.
    '''

    def __init__(self, gravity:float, input_multiplier_max:float, input_magnitude_max:float):
        '''Store the physical constants and configure the base class.

        Args:
            gravity: g in the dynamics.
            input_multiplier_max: k_max, upper end of the tested ``k`` range.
            input_magnitude_max: u_max, used by the Hamiltonian.
        '''
        self.gravity = gravity                             # g
        self.input_multiplier_max = input_multiplier_max   # k_max
        self.input_magnitude_max = input_magnitude_max     # u_max
        half_k_max = self.input_multiplier_max/2
        super().__init__(
            loss_type='brt_hjivi', set_mode='avoid',
            state_dim=3, input_dim=4, control_dim=1, disturbance_dim=0,
            state_mean=[0, 1.5, half_k_max],   # v, z, k
            state_var=[4, 2, half_k_max],      # v, z, k
            value_mean=0.25,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        '''Return ``[low, high]`` pairs for v, z and k.'''
        v_range = [-4, 4]
        z_range = [-0.5, 3.5]
        k_range = [0, self.input_multiplier_max]
        return [v_range, z_range, k_range]

    def equivalent_wrapped_state(self, state):
        '''Return a copy of ``state``; no component is wrapped.'''
        return torch.clone(state)

    # ParameterizedVertDrone2D dynamics
    # \dot v = k*u - g
    # \dot z = v
    # \dot k = 0
    def dsdt(self, state, control, disturbance):
        '''State derivative; ``disturbance`` is unused.'''
        velocity, multiplier = state[..., 0], state[..., 2]
        derivative = torch.zeros_like(state)
        derivative[..., 0] = multiplier*control[..., 0] - self.gravity
        derivative[..., 1] = velocity
        # \dot k = 0 is already provided by zeros_like.
        return derivative

    def boundary_fn(self, state):
        '''Return ``1.5 - |z - 1.5|`` for the height component ``z``.'''
        height = state[..., 1]
        return 1.5 - torch.abs(height - 1.5)

    def sample_target_state(self, num_samples):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def cost_fn(self, state_traj):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def hamiltonian(self, state, dvds):
        '''Hamiltonian built from the control, gravity and drift terms.'''
        velocity, multiplier = state[..., 0], state[..., 2]
        control_term = multiplier*torch.abs(dvds[..., 0]*self.input_magnitude_max)
        gravity_term = dvds[..., 0]*self.gravity
        drift_term = dvds[..., 1]*velocity
        return control_term - gravity_term + drift_term

    def optimal_control(self, state, dvds):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def optimal_disturbance(self, state, dvds):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def plot_config(self):
        '''Return slice values, labels and axis indices for plotting.'''
        config = {}
        config['state_slices'] = [0, 1.5, self.input_multiplier_max/2]
        config['state_labels'] = ['v', 'z', 'k']
        config['x_axis_idx'] = 0
        config['y_axis_idx'] = 1
        config['z_axis_idx'] = 2
        return config

class Air3D(Dynamics):
    '''Three-state system ``[x, y, psi]`` with one control and one disturbance.

    The boundary function is the planar distance from the origin minus
    ``collisionR``.
    '''

    def __init__(self, collisionR:float, velocity:float, omega_max:float, angle_alpha_factor:float):
        '''Store the parameters and configure the base class.

        Args:
            collisionR: Radius subtracted in ``boundary_fn``.
            velocity: Speed ``v`` in the dynamics.
            omega_max: Bound used for both control and disturbance.
            angle_alpha_factor: Multiplies ``pi`` to give the angle scale.
        '''
        self.collisionR = collisionR
        self.velocity = velocity
        self.omega_max = omega_max
        self.angle_alpha_factor = angle_alpha_factor
        angle_scale = self.angle_alpha_factor*math.pi
        super().__init__(
            loss_type='brt_hjivi', set_mode='avoid',
            state_dim=3, input_dim=4, control_dim=1, disturbance_dim=1,
            state_mean=[0, 0, 0],
            state_var=[1, 1, angle_scale],
            value_mean=0.25,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        '''Return ``[low, high]`` pairs for x, y and the angle.'''
        position_range = (-1, 1)
        return [
            list(position_range),
            list(position_range),
            [-math.pi, math.pi],
        ]

    def equivalent_wrapped_state(self, state):
        '''Return a copy of ``state`` with the angle wrapped into ``[-pi, pi)``.'''
        wrapped = torch.clone(state)
        angle = wrapped[..., 2]
        wrapped[..., 2] = (angle + math.pi) % (2*math.pi) - math.pi
        return wrapped

    # Air3D dynamics
    # \dot x    = -v + v \cos \psi + u y
    # \dot y    = v \sin \psi - u x
    # \dot \psi = d - u
    def dsdt(self, state, control, disturbance):
        '''State derivative for the given control and disturbance.'''
        x, y, psi = state[..., 0], state[..., 1], state[..., 2]
        u, d = control[..., 0], disturbance[..., 0]
        derivative = torch.zeros_like(state)
        derivative[..., 0] = -self.velocity + self.velocity*torch.cos(psi) + u*y
        derivative[..., 1] = self.velocity*torch.sin(psi) - u*x
        derivative[..., 2] = d - u
        return derivative

    def boundary_fn(self, state):
        '''Planar distance from the origin minus the collision radius.'''
        planar = state[..., :2]
        return torch.norm(planar, dim=-1) - self.collisionR

    def sample_target_state(self, num_samples):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def cost_fn(self, state_traj):
        '''Minimum of the boundary function over the last remaining axis.'''
        boundary_values = self.boundary_fn(state_traj)
        return torch.min(boundary_values, dim=-1).values

    def hamiltonian(self, state, dvds):
        '''Hamiltonian as control, disturbance and constant components.'''
        x, y, psi = state[..., 0], state[..., 1], state[..., 2]
        p_x, p_y, p_psi = dvds[..., 0], dvds[..., 1], dvds[..., 2]
        control_part = self.omega_max * torch.abs(p_x * y - p_y * x - p_psi)
        disturbance_part = self.omega_max * torch.abs(p_psi)
        constant_x = self.velocity * (torch.cos(psi) - 1.0) * p_x
        constant_y = self.velocity * torch.sin(psi) * p_y
        return control_part - disturbance_part + constant_x + constant_y

    def optimal_control(self, state, dvds):
        '''Bang-bang control with a trailing singleton axis.'''
        switching = dvds[..., 0]*state[..., 1] - dvds[..., 1]*state[..., 0]-dvds[..., 2]
        return (self.omega_max * torch.sign(switching)).unsqueeze(-1)

    def optimal_disturbance(self, state, dvds):
        '''Bang-bang disturbance with a trailing singleton axis.'''
        return (-self.omega_max * torch.sign(dvds[..., 2])).unsqueeze(-1)

    def plot_config(self):
        '''Return slice values, labels and axis indices for plotting.'''
        config = {}
        config['state_slices'] = [0, 0, 0]
        config['state_labels'] = ['x', 'y', 'theta']
        config['x_axis_idx'] = 0
        config['y_axis_idx'] = 1
        config['z_axis_idx'] = 2
        return config

class Dubins3D(Dynamics):
    '''Three-state Dubins car ``[x, y, theta]`` with one control.

    The boundary function is the planar distance from the origin minus
    ``goalR``. ``set_mode`` decides the sign of the control term.
    '''

    def __init__(self, goalR:float, velocity:float, omega_max:float, angle_alpha_factor:float, set_mode:str, freeze_model: bool):
        '''Store the parameters and configure the base class.

        Args:
            goalR: Radius subtracted in ``boundary_fn``.
            velocity: Speed ``v`` in the dynamics.
            omega_max: Bound on the turn-rate control.
            angle_alpha_factor: Multiplies ``pi`` to give the angle scale.
            set_mode: ``'reach'`` or ``'avoid'`` (validated by the base class).
            freeze_model: When true, ``dsdt`` and ``hamiltonian`` raise
                ``NotImplementedError``.
        '''
        self.goalR = goalR
        self.velocity = velocity
        self.omega_max = omega_max
        self.angle_alpha_factor = angle_alpha_factor
        self.freeze_model = freeze_model
        angle_scale = self.angle_alpha_factor*math.pi
        super().__init__(
            loss_type='brt_hjivi', set_mode=set_mode,
            state_dim=3, input_dim=4, control_dim=1, disturbance_dim=0,
            state_mean=[0, 0, 0],
            state_var=[1, 1, angle_scale],
            value_mean=0.25,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact"
        )

    def state_test_range(self):
        '''Return ``[low, high]`` pairs for x, y and theta.'''
        position_range = (-1, 1)
        return [
            list(position_range),
            list(position_range),
            [-math.pi, math.pi],
        ]

    def equivalent_wrapped_state(self, state):
        '''Return a copy of ``state`` with theta wrapped into ``[-pi, pi)``.'''
        wrapped = torch.clone(state)
        theta = wrapped[..., 2]
        wrapped[..., 2] = (theta + math.pi) % (2*math.pi) - math.pi
        return wrapped

    # Dubins3D dynamics
    # \dot x    = v \cos \theta
    # \dot y    = v \sin \theta
    # \dot \theta = u
    def dsdt(self, state, control, disturbance):
        '''State derivative; ``disturbance`` is unused.'''
        if self.freeze_model:
            raise NotImplementedError
        theta = state[..., 2]
        derivative = torch.zeros_like(state)
        derivative[..., 0] = self.velocity*torch.cos(theta)
        derivative[..., 1] = self.velocity*torch.sin(theta)
        derivative[..., 2] = control[..., 0]
        return derivative

    def boundary_fn(self, state):
        '''Planar distance from the origin minus the goal radius.'''
        planar = state[..., :2]
        return torch.norm(planar, dim=-1) - self.goalR

    def sample_target_state(self, num_samples):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def cost_fn(self, state_traj):
        '''Minimum of the boundary function over the last remaining axis.'''
        boundary_values = self.boundary_fn(state_traj)
        return torch.min(boundary_values, dim=-1).values

    def hamiltonian(self, state, dvds):
        '''Drift term minus (reach) or plus (avoid) the control term.'''
        if self.freeze_model:
            raise NotImplementedError
        mode = self.set_mode
        if mode not in ('reach', 'avoid'):
            # Mirrors the original fall-through for any other mode.
            return None
        theta = state[..., 2]
        drift = self.velocity*(torch.cos(theta) * dvds[..., 0] + torch.sin(theta) * dvds[..., 1])
        turn = self.omega_max * torch.abs(dvds[..., 2])
        if mode == 'reach':
            return drift - turn
        return drift + turn

    def optimal_control(self, state, dvds):
        '''Bang-bang turn rate with a trailing singleton axis.'''
        mode = self.set_mode
        if mode not in ('reach', 'avoid'):
            # Mirrors the original fall-through for any other mode.
            return None
        gain = -self.omega_max if mode == 'reach' else self.omega_max
        return (gain*torch.sign(dvds[..., 2])).unsqueeze(-1)

    def optimal_disturbance(self, state, dvds):
        '''There is no disturbance; returns the integer 0.'''
        return 0

    def plot_config(self):
        '''Return slice values, labels and axis indices for plotting.'''
        config = {}
        config['state_slices'] = [0, 0, 0]
        config['state_labels'] = ['x', 'y', r'$\theta$']
        config['x_axis_idx'] = 0
        config['y_axis_idx'] = 1
        config['z_axis_idx'] = 2
        return config

class Dubins4D(Dynamics):
    '''Fourteen-state system whose boundary function is an inter-agent distance.

    Only the normalisation constants, the test range, angle wrapping and the
    boundary function are implemented; the remaining hooks raise
    ``NotImplementedError``.
    '''

    def __init__(self, bound_mode:str):
        '''Configure the normalisation constants.

        Args:
            bound_mode: ``'v1'`` or ``'v2'``; selects the scale used for the
                ``o`` components (``3*pi`` for ``'v1'``, ``2.0`` otherwise).

        Raises:
            AssertionError: If ``bound_mode`` is not ``'v1'`` or ``'v2'``.
        '''
        self.vMin = 0.2
        self.vMax = 14.8
        self.collisionR = 1.5
        self.bound_mode = bound_mode
        assert self.bound_mode in ['v1', 'v2']

        xMean = 0
        yMean = 0
        thetaMean = 0
        vMean = 7.5
        aMean = 0
        oMean = 0

        xVar = 10
        yVar = 10
        thetaVar = 1.2*math.pi
        vVar = 7.5
        aVar = 10
        oVar = 3*math.pi if self.bound_mode == 'v1' else 2.0

        super().__init__(
            # FIX: the base constructor requires ``set_mode``; it was omitted, so
            # building this class raised TypeError before any validation ran.
            # NOTE(review): 'avoid' is chosen because boundary_fn is a distance
            # minus collisionR, the same form Air3D pairs with 'avoid'; confirm
            # this is the intended mode.
            loss_type='brt_hjivi', set_mode='avoid',
            state_dim=14, input_dim=15,  control_dim=2, disturbance_dim=0,
            state_mean=[xMean, yMean, thetaMean, vMean, xMean, yMean, aMean, aMean, oMean, oMean, aMean, aMean, oMean, oMean],
            state_var=[xVar, yVar, thetaVar, vVar, xVar, yVar, aVar, aVar, oVar, oVar, aVar, aVar, oVar, oVar],
            value_mean=13,
            value_var=14,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        '''Return one ``[low, high]`` pair per state component (14 in total).'''
        leading = [
            [-1, 1],
            [-1, 1],
            [-math.pi, math.pi],
            [self.vMin, self.vMax],
        ]
        # The remaining ten components all use the unit range.
        return leading + [[-1, 1] for _ in range(10)]

    def equivalent_wrapped_state(self, state):
        '''Return a copy of ``state`` with component 2 wrapped into ``[-pi, pi)``.'''
        wrapped_state = torch.clone(state)
        wrapped_state[..., 2] = (wrapped_state[..., 2] + math.pi) % (2*math.pi) - math.pi
        return wrapped_state

    def boundary_fn(self, state):
        '''Distance between components 0:2 and 4:6 minus ``collisionR``.'''
        return torch.norm(state[..., 0:2] - state[..., 4:6], dim=-1) - self.collisionR

    def sample_target_state(self, num_samples):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def cost_fn(self, state_traj):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def dsdt(self, state, control, disturbance):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def hamiltonian(self, state, dvds):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def optimal_control(self, state, dvds):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def optimal_disturbance(self, state, dvds):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def plot_config(self):
        '''Not implemented for this system.'''
        raise NotImplementedError

class NarrowPassage(Dynamics):
    '''Two-vehicle system with state ``[x1, y1, th1, v1, phi1, x2, y2, th2, v2, phi2]``.

    Each vehicle occupies five consecutive state entries and two consecutive
    control entries. With ``avoid_only`` the problem is a pure avoid problem;
    otherwise it is a reach-avoid problem.
    '''

    def __init__(self, avoid_fn_weight:float, avoid_only:bool):
        '''Set the scenario constants and configure the base class.

        Args:
            avoid_fn_weight: Factor applied to the result of ``avoid_fn``.
            avoid_only: Selects ``'brt_hjivi'``/``'avoid'`` when true and
                ``'brat_hjivi'``/``'reach'`` otherwise.
        '''
        self.avoid_fn_weight = avoid_fn_weight
        self.avoid_only = avoid_only

        # Used as the divisor in the heading dynamics and as the clearance
        # subtracted in the reach/avoid distances.
        self.L = 2.0

        # Target positions (index 0: vehicle 1, index 1: vehicle 2)
        self.goalX = [6.0, -6.0]
        self.goalY = [-1.4, 1.4]

        # State bounds
        self.vMin = 0.001
        self.vMax = 6.50
        self.phiMin = -0.3*math.pi + 0.001
        self.phiMax = 0.3*math.pi - 0.001

        # Control bounds
        self.aMin = -4.0
        self.aMax = 2.0
        self.psiMin = -3.0*math.pi
        self.psiMax = 3.0*math.pi

        # Lower and upper curb positions (in the y direction)
        self.curb_positions = [-2.8, 2.8]

        # Stranded car position
        self.stranded_car_pos = [0.0, -1.8]

        if self.avoid_only:
            loss_type, set_mode = 'brt_hjivi', 'avoid'
        else:
            loss_type, set_mode = 'brat_hjivi', 'reach'

        # Per-vehicle [x, y, th, v, phi] normalisation, repeated for both vehicles.
        vehicle_mean = [0, 0, 0, 3, 0]
        vehicle_var = [8.0, 3.8, 1.2*math.pi, 4.0, 1.2*0.3*math.pi]

        super().__init__(
            loss_type=loss_type, set_mode=set_mode,
            state_dim=10, input_dim=11, control_dim=4, disturbance_dim=0,
            state_mean=vehicle_mean * 2,
            state_var=vehicle_var * 2,
            value_mean=0.25*8.0,
            value_var=0.5*8.0,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        '''Return one ``[low, high]`` pair per state component.'''
        vehicle_ranges = (
            (-8, 8),
            (-3.8, 3.8),
            (-math.pi, math.pi),
            (-1, 7),
            (-0.3*math.pi, 0.3*math.pi),
        )
        # Fresh inner lists for each vehicle so the halves never alias.
        return [list(bounds) for _ in range(2) for bounds in vehicle_ranges]

    def equivalent_wrapped_state(self, state):
        '''Return a copy with entries 2, 4, 7 and 9 wrapped into ``[-pi, pi)``.'''
        wrapped = torch.clone(state)
        for idx in (2, 4, 7, 9):
            wrapped[..., idx] = (wrapped[..., idx] + math.pi) % (2*math.pi) - math.pi
        return wrapped

    # NarrowPassage dynamics (identical for both vehicles)
    # \dot x   = v * cos(th)
    # \dot y   = v * sin(th)
    # \dot th  = v * tan(phi) / L
    # \dot v   = u1
    # \dot phi = u2
    def dsdt(self, state, control, disturbance):
        '''State derivative; ``disturbance`` is unused.'''
        derivative = torch.zeros_like(state)
        for vehicle in range(2):
            s, c = 5*vehicle, 2*vehicle
            heading, speed, steer = state[..., s+2], state[..., s+3], state[..., s+4]
            derivative[..., s] = speed*torch.cos(heading)
            derivative[..., s+1] = speed*torch.sin(heading)
            derivative[..., s+2] = speed*torch.tan(steer) / self.L
            derivative[..., s+3] = control[..., c]
            derivative[..., s+4] = control[..., c+1]
        return derivative

    def reach_fn(self, state):
        '''Larger of the two vehicles' goal distances, each reduced by ``L``.

        Raises:
            RuntimeError: When the instance is configured as ``avoid_only``.
        '''
        if self.avoid_only:
            raise RuntimeError
        goal_distances = []
        for vehicle in range(2):
            goal = torch.tensor([self.goalX[vehicle], self.goalY[vehicle]], device=state.device)
            position = state[..., 5*vehicle:5*vehicle+2]
            goal_distances.append(torch.norm(position - goal, dim=-1) - self.L)
        return torch.maximum(goal_distances[0], goal_distances[1])

    def avoid_fn(self, state):
        '''Weighted minimum clearance over curbs, stranded car and the other vehicle.'''
        y1, y2 = state[..., 1], state[..., 6]
        pos1, pos2 = state[..., 0:2], state[..., 5:7]
        lower_curb, upper_curb = self.curb_positions[0], self.curb_positions[1]

        # distance from lower curb
        dist_lc = torch.minimum(y1 - lower_curb - 0.5*self.L,
                                y2 - lower_curb - 0.5*self.L)

        # distance from upper curb
        dist_uc = torch.minimum(upper_curb - y1 - 0.5*self.L,
                                upper_curb - y2 - 0.5*self.L)

        # distance from the stranded car
        stranded = torch.tensor(self.stranded_car_pos, device=state.device)
        dist_stranded = torch.minimum(torch.norm(pos1 - stranded, dim=-1) - self.L,
                                      torch.norm(pos2 - stranded, dim=-1) - self.L)

        # distance between the vehicles themselves
        dist_between = torch.norm(pos1 - pos2, dim=-1) - self.L

        closest = dist_lc
        for clearance in (dist_uc, dist_stranded, dist_between):
            closest = torch.min(closest, clearance)
        return self.avoid_fn_weight * closest

    def boundary_fn(self, state):
        '''``avoid_fn`` when avoid-only, else ``max(reach_fn, -avoid_fn)``.'''
        if not self.avoid_only:
            return torch.maximum(self.reach_fn(state), -self.avoid_fn(state))
        return self.avoid_fn(state)

    def sample_target_state(self, num_samples):
        '''Not implemented for this system.'''
        raise NotImplementedError

    def cost_fn(self, state_traj):
        '''Trajectory cost, reduced over the last remaining axis.'''
        if self.avoid_only:
            return torch.min(self.avoid_fn(state_traj), dim=-1).values
        # return min_t max{l(x(t)), max_k_up_to_t{-g(x(k))}}, where l(x) is reach_fn, g(x) is avoid_fn
        reach_values = self.reach_fn(state_traj)
        avoid_values = self.avoid_fn(state_traj)
        worst_violation = torch.cummax(-avoid_values, dim=-1).values
        return torch.min(torch.maximum(reach_values, worst_violation), dim=-1).values

    def hamiltonian(self, state, dvds):
        '''Hamiltonian evaluated at the control from ``optimal_control``.'''
        best_control = self.optimal_control(state, dvds)
        terms = []
        for vehicle in range(2):
            s, c = 5*vehicle, 2*vehicle
            heading, speed, steer = state[..., s+2], state[..., s+3], state[..., s+4]
            terms.append(speed * torch.cos(heading) * dvds[..., s])
            terms.append(speed * torch.sin(heading) * dvds[..., s+1])
            terms.append(speed * torch.tan(steer) * dvds[..., s+2] / self.L)
            terms.append(best_control[..., c] * dvds[..., s+3])
            terms.append(best_control[..., c+1] * dvds[..., s+4])
        # Accumulate strictly left to right to keep the original summation order.
        total = terms[0]
        for term in terms[1:]:
            total = total + term
        return total

    def optimal_control(self, state, dvds):
        '''Bang-bang controls ``[a1, psi1, a2, psi2]`` stacked on the last axis.

        A bound is replaced by zero once the corresponding state has reached
        its own limit (speed for acceleration, steering angle for steering rate).
        '''
        commands = []
        for vehicle in range(2):
            s = 5*vehicle
            speed, steer = state[..., s+3], state[..., s+4]
            accel_low = self.aMin * (speed > self.vMin)
            accel_high = self.aMax * (speed < self.vMax)
            rate_low = self.psiMin * (steer > self.phiMin)
            rate_high = self.psiMax * (steer < self.phiMax)

            if self.avoid_only:
                use_low_accel = dvds[..., s+3] < 0
                use_low_rate = dvds[..., s+4] < 0
            else:
                use_low_accel = dvds[..., s+3] > 0
                use_low_rate = dvds[..., s+4] > 0

            commands.append(torch.where(use_low_accel, accel_low, accel_high))
            commands.append(torch.where(use_low_rate, rate_low, rate_high))

        return torch.cat([command[..., None] for command in commands], dim=-1)

    def optimal_disturbance(self, state, dvds):
        '''There is no disturbance; returns the integer 0.'''
        return 0

    def plot_config(self):
        '''Return slice values, labels and axis indices for plotting.'''
        config = {}
        config['state_slices'] = [
            -6.0, -1.4, 0.0, 6.5, 0.0,
            -6.0, 1.4, -math.pi, 0.0, 0.0
        ]
        config['state_labels'] = [
            r'$x_1$', r'$y_1$', r'$\theta_1$', r'$v_1$', r'$\phi_1$',
            r'$x_2$', r'$y_2$', r'$\theta_2$', r'$v_2$', r'$\phi_2$',
        ]
        config['x_axis_idx'] = 0
        config['y_axis_idx'] = 1
        config['z_axis_idx'] = 2
        return config

class ReachAvoidRocketLanding(Dynamics):
    '''Reach-avoid rocket landing with state ``[x, y, th, v_x, v_y, w]``.

    The target is a region in x and y only; the avoid set is the floor plus a
    rectangular wall. Two controls enter through a rotation by ``th``.
    '''

    def __init__(self):
        '''Configure the base class with this system's fixed constants.'''
        super().__init__(
            loss_type='brat_hjivi', set_mode='reach',
            state_dim=6, input_dim=7, control_dim=2, disturbance_dim=0,
            state_mean=[0.0, 80.0, 0.0, 0.0, 0.0, 0.0],
            state_var=[150.0, 70.0, 1.2*math.pi, 200.0, 200.0, 10.0],
            value_mean=0.0,
            value_var=1.0,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        '''Return one ``[low, high]`` pair per state component.'''
        return [
            [-150, 150],
            [10, 150],
            [-math.pi, math.pi],
            [-200, 200],
            [-200, 200],
            [-10, 10],
        ]

    def equivalent_wrapped_state(self, state):
        '''Return a copy of ``state`` with the angle wrapped into ``[-pi, pi)``.'''
        wrapped_state = torch.clone(state)
        wrapped_state[..., 2] = (wrapped_state[..., 2] + math.pi) % (2*math.pi) - math.pi
        return wrapped_state

    # \dot x = v_x
    # \dot y = v_y
    # \dot th = w
    # \dot v_x = u1 * cos(th) - u2 sin(th)
    # \dot v_y = u1 * sin(th) + u2 cos(th) - 9.81
    # \dot w = 0.3 * u1
    def dsdt(self, state, control, disturbance):
        '''State derivative; ``disturbance`` is unused.'''
        dsdt = torch.zeros_like(state)
        dsdt[..., 0] = state[..., 3]
        dsdt[..., 1] = state[..., 4]
        dsdt[..., 2] = state[..., 5]
        dsdt[..., 3] = control[..., 0]*torch.cos(state[..., 2]) - control[..., 1]*torch.sin(state[..., 2])
        dsdt[..., 4] = control[..., 0]*torch.sin(state[..., 2]) + control[..., 1]*torch.cos(state[..., 2]) - 9.81
        dsdt[..., 5] = 0.3*control[..., 0]
        return dsdt

    def reach_fn(self, state):
        '''Normalised distance to the target region (defined in x and y only).'''
        # Target set position in x direction
        dist_x = torch.abs(state[..., 0]) - 20.0 #[-20, 150] boundary_fn range

        # Target set position in y direction
        dist_y = state[..., 1] - 20.0  #[-10, 130] boundary_fn range

        # First compute the target function as you normally would but then normalize it later.
        # Non-negative and negative values are divided by different constants.
        max_dist = torch.max(dist_x, dist_y)
        return torch.where((max_dist >= 0), max_dist/150.0, max_dist/10.0)

    def avoid_fn(self, state):
        '''Minimum of the height above the floor and the distance to the wall.'''
        # distance to floor
        dist_y = state[..., 1]

        # distance to wall (axis-aligned rectangle)
        wall_left = -30
        wall_right = -20
        wall_bottom = 0
        wall_top = 100
        dist_left = wall_left - state[..., 0]
        dist_right = state[..., 0] - wall_right
        dist_bottom = wall_bottom - state[..., 1]
        dist_top = state[..., 1] - wall_top
        dist_wall_x = torch.max(dist_left, dist_right)
        dist_wall_y = torch.max(dist_bottom, dist_top)
        # Norm of the positive per-axis parts plus the non-positive part of the
        # larger per-axis value.
        zero = torch.tensor(0)
        outside_part = torch.norm(torch.cat((torch.max(zero, dist_wall_x).unsqueeze(-1), torch.max(zero, dist_wall_y).unsqueeze(-1)), dim=-1), dim=-1)
        inside_part = torch.min(zero, torch.max(dist_wall_x, dist_wall_y))
        dist_wall = outside_part + inside_part

        return torch.min(dist_y, dist_wall)

    def boundary_fn(self, state):
        '''Return ``max(reach_fn, -avoid_fn)``.'''
        return torch.maximum(self.reach_fn(state), -self.avoid_fn(state))

    def sample_target_state(self, num_samples):
        '''Sample ``num_samples`` states uniformly with x and y restricted.

        Components other than x and y use the ranges from ``state_test_range``.
        '''
        target_state_range = self.state_test_range()
        # FIX: the comments previously called these components y and z; index 0
        # is x and index 1 is y (see the labels in plot_config).
        target_state_range[0] = [-20, 20] # x in [-20, 20]
        target_state_range[1] = [10, 20]  # y in [10, 20]
        target_state_range = torch.tensor(target_state_range)
        return target_state_range[:, 0] + torch.rand(num_samples, self.state_dim)*(target_state_range[:, 1] - target_state_range[:, 0])

    def cost_fn(self, state_traj):
        '''Reach-avoid trajectory cost, reduced over the last remaining axis.'''
        # return min_t max{l(x(t)), max_k_up_to_t{-g(x(k))}}, where l(x) is reach_fn, g(x) is avoid_fn
        reach_values = self.reach_fn(state_traj)
        avoid_values = self.avoid_fn(state_traj)
        return torch.min(torch.maximum(reach_values, torch.cummax(-avoid_values, dim=-1).values), dim=-1).values

    def hamiltonian(self, state, dvds):
        '''Hamiltonian as a control part plus a control-independent part.'''
        # Control Hamiltonian
        u1_coeff = dvds[..., 3] * torch.cos(state[..., 2]) + dvds[..., 4] * torch.sin(state[..., 2]) + 0.3 * dvds[..., 5]
        u2_coeff = -dvds[..., 3] * torch.sin(state[..., 2]) + dvds[..., 4] * torch.cos(state[..., 2])
        ham_ctrl = -250.0 * torch.sqrt(u1_coeff * u1_coeff + u2_coeff * u2_coeff)
        # Constant Hamiltonian
        ham_constant = dvds[..., 0] * state[..., 3] + dvds[..., 1] * state[..., 4] + \
                      dvds[..., 2] * state[..., 5]  - dvds[..., 4] * 9.81
        # Compute the Hamiltonian
        ham_vehicle = ham_ctrl + ham_constant
        return ham_vehicle

    def optimal_control(self, state, dvds):
        '''Control of magnitude 250 pointing opposite the coefficient vector.'''
        u1_coeff = dvds[..., 3] * torch.cos(state[..., 2]) + dvds[..., 4] * torch.sin(state[..., 2]) + 0.3 * dvds[..., 5]
        u2_coeff = -dvds[..., 3] * torch.sin(state[..., 2]) + dvds[..., 4] * torch.cos(state[..., 2])
        # Adding pi flips the direction of (u1_coeff, u2_coeff).
        opt_angle = torch.atan2(u2_coeff, u1_coeff) + math.pi
        return torch.cat((250.0 * torch.cos(opt_angle)[..., None], 250.0 * torch.sin(opt_angle)[..., None]), dim=-1)

    def optimal_disturbance(self, state, dvds):
        '''There is no disturbance; returns the integer 0.'''
        return 0

    def plot_config(self):
        '''Return slice values, labels and axis indices for plotting.'''
        return {
            'state_slices': [-100, 120, 0, 150, -5, 0.0],
            # FIX: the last label was r'$\omega' without its closing '$', so it
            # was not a balanced math expression like the other labels.
            'state_labels': ['x', 'y', r'$\theta$', r'$v_x$', r'$v_y$', r'$\omega$'],
            'x_axis_idx': 0,
            'y_axis_idx': 1,
            'z_axis_idx': 4,
        }

class RocketLanding(Dynamics):
    """Planar rocket dynamics configured as a 'reach' problem with the 'brt_hjivi' loss.

    State layout (labels per ``plot_config``): [x, y, theta, v_x, v_y, omega].
    The control is two-dimensional and there is no disturbance.

    The network input carries one extra trailing channel on top of
    [t, state] (``input_dim=8`` for a 6-D state): ``coord_to_input`` appends
    a zero column and ``input_to_coord`` drops the last column.
    """

    # Constants shared by ``dsdt``, ``hamiltonian`` and ``optimal_control``.
    GRAVITY = 9.81        # subtracted from \dot v_y
    ANGULAR_GAIN = 0.3    # \dot omega = ANGULAR_GAIN * u1
    CONTROL_NORM = 250.0  # magnitude of the control returned by optimal_control

    def __init__(self):
        super().__init__(
            loss_type='brt_hjivi', set_mode='reach',
            state_dim=6, input_dim=8, control_dim=2, disturbance_dim=0,
            state_mean=[0.0, 80.0, 0.0, 0.0, 0.0, 0.0],
            state_var=[150.0, 70.0, 1.2*math.pi, 200.0, 200.0, 10.0],
            value_mean=0.0,
            value_var=1.0,
            value_normto=0.02,
            deepreach_model="exact",
        )

    # convert model input to real coord
    def input_to_coord(self, input):
        """Drop the trailing extra channel and un-normalize the state columns.

        Column 0 (time) is passed through unchanged; the remaining columns
        are mapped as ``x * state_var + state_mean``.
        """
        input = input[..., :-1]
        coord = input.clone()
        coord[..., 1:] = (input[..., 1:] * self.state_var.to(device=input.device)) + self.state_mean.to(device=input.device)
        return coord

    # convert real coord to model input
    def coord_to_input(self, coord):
        """Normalize the state columns and append one zero-valued trailing channel.

        Column 0 (time) is passed through unchanged; the remaining columns
        are mapped as ``(x - state_mean) / state_var``.
        """
        model_input = coord.clone()
        model_input[..., 1:] = (coord[..., 1:] - self.state_mean.to(device=coord.device)) / self.state_var.to(device=coord.device)
        padding = torch.zeros((*model_input.shape[:-1], 1), device=model_input.device)
        return torch.cat((model_input, padding), dim=-1)

    # convert model io to real value
    def io_to_value(self, input, output):
        """Map the raw network output to a value.

        The output is rescaled by ``value_var / value_normto``. In "diff" mode
        the boundary function of the un-normalized state is added; in every
        other mode ``value_mean`` is added.
        """
        scaled = output * self.value_var / self.value_normto
        if self.deepreach_model=="diff":
            return scaled + self.boundary_fn(self.input_to_coord(input)[..., 1:])
        return scaled + self.value_mean

    # convert model io to real dv
    def io_to_dv(self, input, output):
        """Return [dV/dt, dV/dstate] concatenated along the last dimension.

        ``jacobian`` is a helper defined elsewhere in this file; its first
        return value is used as d(output)/d(input). The derivative with
        respect to the extra trailing input channel is discarded.
        """
        dodi = jacobian(output.unsqueeze(dim=-1), input)[0].squeeze(dim=-2)[..., :-1]

        value_scale = self.value_var / self.value_normto
        dvdt = value_scale * dodi[..., 0]
        dvds = (value_scale / self.state_var.to(device=dodi.device)) * dodi[..., 1:]

        if self.deepreach_model=="diff":
            # In "diff" mode the value also contains boundary_fn(state), so
            # its state gradient is added.
            state = self.input_to_coord(input)[..., 1:]
            dvds = dvds + jacobian(self.boundary_fn(state).unsqueeze(dim=-1), state)[0].squeeze(dim=-2)

        return torch.cat((dvdt.unsqueeze(dim=-1), dvds), dim=-1)

    def state_test_range(self):
        """Return [min, max] for each of the six state dimensions."""
        return [
            [-150, 150],
            [10, 150],
            [-math.pi, math.pi],
            [-200, 200],
            [-200, 200],
            [-10, 10],
        ]

    def equivalent_wrapped_state(self, state):
        """Return a copy of ``state`` with the angle (index 2) wrapped into [-pi, pi)."""
        wrapped_state = torch.clone(state)
        wrapped_state[..., 2] = (wrapped_state[..., 2] + math.pi) % (2*math.pi) - math.pi
        return wrapped_state

    # \dot x = v_x
    # \dot y = v_y
    # \dot th = w
    # \dot v_x = u1 * cos(th) - u2 sin(th)
    # \dot v_y = u1 * sin(th) + u2 cos(th) - 9.81
    # \dot w = 0.3 * u1
    def dsdt(self, state, control, disturbance):
        """Time derivative of the state; ``disturbance`` is not used."""
        cos_th = torch.cos(state[..., 2])
        sin_th = torch.sin(state[..., 2])
        dsdt = torch.zeros_like(state)
        dsdt[..., 0] = state[..., 3]
        dsdt[..., 1] = state[..., 4]
        dsdt[..., 2] = state[..., 5]
        dsdt[..., 3] = control[..., 0]*cos_th - control[..., 1]*sin_th
        dsdt[..., 4] = control[..., 0]*sin_th + control[..., 1]*cos_th - self.GRAVITY
        dsdt[..., 5] = self.ANGULAR_GAIN*control[..., 0]
        return dsdt

    def boundary_fn(self, state):
        """Signed, piecewise-rescaled distance to the target box.

        Only the first two state dimensions enter: |state[0]| <= 20 and
        state[1] <= 20 describe the target. Non-negative values are divided
        by 150 and negative values by 10.
        """
        # Distance along state index 0, range [-20, 130] over the test range.
        dist_y = torch.abs(state[..., 0]) - 20.0

        # Distance along state index 1, range [-10, 130] over the test range.
        dist_z = state[..., 1] - 20.0

        # First compute l(x) as usual, then normalize it.
        lx = torch.max(dist_y, dist_z)
        return torch.where((lx >= 0), lx/150.0, lx/10.0)

    def sample_target_state(self, num_samples):
        """Sample ``num_samples`` states uniformly from the target box.

        The first two dimensions are restricted to the target
        ([-20, 20] and [10, 20]); the others span ``state_test_range``.
        """
        target_state_range = self.state_test_range()
        target_state_range[0] = [-20, 20]
        target_state_range[1] = [10, 20]
        target_state_range = torch.tensor(target_state_range)
        lower = target_state_range[:, 0]
        upper = target_state_range[:, 1]
        return lower + torch.rand(num_samples, self.state_dim)*(upper - lower)

    def cost_fn(self, state_traj):
        """Minimum of ``boundary_fn`` along the last dimension of its output."""
        return torch.min(self.boundary_fn(state_traj), dim=-1).values

    def _control_coeffs(self, state, dvds):
        """Return the factors multiplying u1 and u2 in dvds . dsdt."""
        cos_th = torch.cos(state[..., 2])
        sin_th = torch.sin(state[..., 2])
        u1_coeff = dvds[..., 3] * cos_th + dvds[..., 4] * sin_th + self.ANGULAR_GAIN * dvds[..., 5]
        u2_coeff = -dvds[..., 3] * sin_th + dvds[..., 4] * cos_th
        return u1_coeff, u2_coeff

    def hamiltonian(self, state, dvds):
        """Hamiltonian: control-dependent part plus control-independent drift."""
        # Control Hamiltonian
        u1_coeff, u2_coeff = self._control_coeffs(state, dvds)
        ham_ctrl = -self.CONTROL_NORM * torch.sqrt(u1_coeff * u1_coeff + u2_coeff * u2_coeff)
        # Constant Hamiltonian
        ham_constant = dvds[..., 0] * state[..., 3] + dvds[..., 1] * state[..., 4] + \
                      dvds[..., 2] * state[..., 5]  - dvds[..., 4] * self.GRAVITY
        # Compute the Hamiltonian
        return ham_ctrl + ham_constant

    def optimal_control(self, state, dvds):
        """Control of magnitude ``CONTROL_NORM`` opposite to the control coefficients."""
        u1_coeff, u2_coeff = self._control_coeffs(state, dvds)
        # Adding pi flips the direction of (u1_coeff, u2_coeff).
        opt_angle = torch.atan2(u2_coeff, u1_coeff) + math.pi
        return torch.cat((self.CONTROL_NORM * torch.cos(opt_angle)[..., None], self.CONTROL_NORM * torch.sin(opt_angle)[..., None]), dim=-1)

    def optimal_disturbance(self, state, dvds):
        """Return the scalar 0 (``disturbance_dim`` is 0 for this system)."""
        return 0

    def plot_config(self):
        """Return default slice values, axis labels and plot-axis indices."""
        return {
            'state_slices': [-100, 120, 0, 150, -5, 0.0],
            # The last label previously lacked its closing '$'.
            'state_labels': ['x', 'y', r'$\theta$', r'$v_x$', r'$v_y$', r'$\omega$'],
            'x_axis_idx': 0,
            'y_axis_idx': 1,
            'z_axis_idx': 4,
        }

class Quadrotor(Dynamics):
    """13-state quadrotor with four control inputs and no disturbance.

    State layout (labels per ``plot_config``):
    [x, y, z, qw, qx, qy, qz, vx, vy, vz, wx, wy, wz].

    Only ``set_mode='avoid'`` has ``hamiltonian`` / ``optimal_control``
    implemented; any other mode raises ``NotImplementedError``.
    """

    def __init__(self, collisionR:float, thrust_max:float, set_mode:str):
        """
        Args:
            collisionR: radius subtracted from the position norm in ``boundary_fn``.
            thrust_max: magnitude of each control input in ``optimal_control``.
            set_mode: forwarded to the base class; see the class docstring.
        """
        self.thrust_max = thrust_max
        self.collisionR = collisionR
        self.m=1 #mass
        self.arm_l=0.17  # presumably the arm length -- TODO confirm
        self.CT=1        # presumably a thrust coefficient -- TODO confirm
        self.CM=0.016    # presumably a moment coefficient -- TODO confirm
        self.Gz=-9.8     # constant added to \dot vz

        super().__init__(
            loss_type='brt_hjivi', set_mode=set_mode,
            state_dim=13, input_dim=14, control_dim=4, disturbance_dim=0,
            state_mean=[0] * 13,
            state_var=[1.5, 1.5, 1.5, 1, 1, 1, 1, 10, 10 ,10 ,10 ,10 ,10],
            value_mean=(math.sqrt(1.5**2+1.5**2+1.5**2)-2*self.collisionR)/2,
            value_var=math.sqrt(1.5**2+1.5**2+1.5**2),
            value_normto=0.02,
            deepreach_model="exact"
        )

    def state_test_range(self):
        """Return [min, max] for each of the 13 state dimensions."""
        return [
            [-1.5, 1.5],
            [-1.5, 1.5],
            [-1.5, 1.5],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-10, 10],
            [-10, 10],
            [-10, 10],
            [-10, 10],
            [-10, 10],
            [-10, 10],
        ]

    def equivalent_wrapped_state(self, state):
        """Return a copy of ``state``; no dimension is wrapped."""
        wrapped_state = torch.clone(state)
        return wrapped_state

    # Quadrotor dynamics, with U = u1 + u2 + u3 + u4:
    # \dot [x, y, z]  = [vx, vy, vz]
    # \dot q          = quaternion kinematics driven by [wx, wy, wz]
    # \dot vx         = 2 (qw qy + qx qz) CT/m U
    # \dot vy         = 2 (-qw qx + qy qz) CT/m U
    # \dot vz         = Gz + (1 - 2 qx^2 - 2 qy^2) CT/m U
    # \dot [wx,wy,wz] = signed combinations of u1..u4 plus gyroscopic terms
    def dsdt(self, state, control, disturbance):
        """Time derivative of the state; ``disturbance`` is not used."""
        qw = state[..., 3]
        qx = state[..., 4]
        qy = state[..., 5]
        qz = state[..., 6]
        vx = state[..., 7]
        vy = state[..., 8]
        vz = state[..., 9]
        wx = state[..., 10]
        wy = state[..., 11]
        wz = state[..., 12]
        u1 = control[..., 0]
        u2 = control[..., 1]
        u3 = control[..., 2]
        u4 = control[..., 3]

        dsdt = torch.zeros_like(state)
        dsdt[..., 0] = vx
        dsdt[..., 1] = vy
        dsdt[..., 2] = vz
        dsdt[..., 3] = -(wx*qx+wy*qy+wz*qz)/2.0
        dsdt[..., 4] = (wx*qw+wz*qy-wy*qz)/2.0
        dsdt[..., 5] = (wy*qw-wz*qx+wx*qz)/2.0
        dsdt[..., 6] = (wz*qw+wy*qx-wx*qy)/2.0
        dsdt[..., 7] = 2*(qw*qy+qx*qz)*self.CT/self.m*(u1+u2+u3+u4)
        dsdt[..., 8] = 2*(-qw*qx+qy*qz)*self.CT/self.m*(u1+u2+u3+u4)
        dsdt[..., 9] = self.Gz+(1-2*torch.pow(qx,2)-2*torch.pow(qy,2))*self.CT/self.m*(u1+u2+u3+u4)
        dsdt[..., 10] = 4*math.sqrt(2)*self.CT*(u1-u2-u3+u4)/(3*self.arm_l*self.m)-5*wy*wz/9.0
        dsdt[..., 11] = 4*math.sqrt(2)*self.CT*(-u1-u2+u3+u4)/(3*self.arm_l*self.m)+5*wx*wz/9.0
        dsdt[..., 12] = 12*self.CT*self.CM/(7*self.arm_l**2*self.m)*(u1-u2+u3-u4)
        return dsdt

    def boundary_fn(self, state):
        """Norm of the position (first three state dims) minus ``collisionR``."""
        return torch.norm(state[..., :3], dim=-1) - self.collisionR

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        """Minimum of ``boundary_fn`` along the last dimension of its output."""
        return torch.min(self.boundary_fn(state_traj), dim=-1).values

    def _require_avoid_mode(self, method_name):
        """Raise NotImplementedError unless ``set_mode`` is 'avoid'."""
        if self.set_mode != 'avoid':
            raise NotImplementedError(
                "Quadrotor.%s is only implemented for set_mode='avoid' (got %r)"
                % (method_name, self.set_mode)
            )

    def _control_coeffs(self, state, dvds):
        """Return, for u1..u4 in order, the factor multiplying each input in dvds . dsdt."""
        qw = state[..., 3]
        qx = state[..., 4]
        qy = state[..., 5]
        qz = state[..., 6]

        C1=2*(qw*qy+qx*qz)*self.CT/self.m
        C2=2*(-qw*qx+qy*qz)*self.CT/self.m
        C3=(1-2*torch.pow(qx,2)-2*torch.pow(qy,2))*self.CT/self.m
        C4=4*math.sqrt(2)*self.CT/(3*self.arm_l*self.m)
        C5=C4  # the wx and wy equations share the same input gain
        C6=12*self.CT*self.CM/(7*self.arm_l**2*self.m)

        # Every input enters the three velocity equations identically.
        common = dvds[..., 7]*C1+dvds[..., 8]*C2+dvds[..., 9]*C3
        wx_term = dvds[..., 10]*C4
        wy_term = dvds[..., 11]*C5
        wz_term = dvds[..., 12]*C6

        # Signs follow the (u1..u4) combinations used in dsdt[..., 10:13].
        return (
            common + wx_term - wy_term + wz_term,
            common - wx_term - wy_term - wz_term,
            common - wx_term + wy_term + wz_term,
            common + wx_term + wy_term - wz_term,
        )

    def hamiltonian(self, state, dvds):
        """Hamiltonian for 'avoid' mode.

        Raises:
            NotImplementedError: if ``set_mode`` is anything other than 'avoid'
                (previously an unsupported mode silently returned None).
        """
        self._require_avoid_mode('hamiltonian')

        qw = state[..., 3]
        qx = state[..., 4]
        qy = state[..., 5]
        qz = state[..., 6]
        vx = state[..., 7]
        vy = state[..., 8]
        vz = state[..., 9]
        wx = state[..., 10]
        wy = state[..., 11]
        wz = state[..., 12]

        # Control-independent part: dvds paired with the drift terms of dsdt.
        ham = dvds[..., 0]*vx + dvds[..., 1]*vy + dvds[..., 2]*vz
        ham += -dvds[..., 3]*(wx*qx+wy*qy+wz*qz)/2.0
        ham += dvds[..., 4]*(wx*qw+wz*qy-wy*qz)/2.0
        ham += dvds[..., 5]*(wy*qw-wz*qx+wx*qz)/2.0
        ham += dvds[..., 6]*(wz*qw+wy*qx-wx*qy)/2.0
        ham += dvds[..., 9]*self.Gz
        ham += -dvds[..., 10]*5*wy*wz/9.0 + dvds[..., 11]*5*wx*wz/9.0

        # Control part: each input contributes |coefficient| * thrust_max.
        for coeff in self._control_coeffs(state, dvds):
            ham += torch.abs(coeff)*self.thrust_max

        return ham

    def optimal_control(self, state, dvds):
        """Bang-bang control for 'avoid' mode: ``thrust_max * sign(coefficient)`` per input.

        Raises:
            NotImplementedError: if ``set_mode`` is anything other than 'avoid'
                (previously an unsupported mode failed on an unbound local).
        """
        self._require_avoid_mode('optimal_control')
        controls = [self.thrust_max*torch.sign(coeff) for coeff in self._control_coeffs(state, dvds)]
        return torch.cat([u[..., None] for u in controls], dim=-1)

    def optimal_disturbance(self, state, dvds):
        """Return the scalar 0 (``disturbance_dim`` is 0 for this system)."""
        return 0

    def plot_config(self):
        """Return default slice values, axis labels and plot-axis indices."""
        return {
            'state_slices': [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            'state_labels': ['x', 'y', 'z', 'qw', 'qx', 'qy', 'qz', 'vx', 'vy', 'vz', 'wx', 'wy', 'wz'],
            'x_axis_idx': 0,
            'y_axis_idx': 2,
            'z_axis_idx': 7,
        }

class MultiVehicleCollision(Dynamics):
    """Three constant-speed planar vehicles in an 'avoid' problem.

    State layout (labels per ``plot_config``):
    [x_1, y_1, x_2, y_2, x_3, y_3, theta_1, theta_2, theta_3].
    Each control component is written directly into the matching heading rate.
    """

    def __init__(self):
        self.angle_alpha_factor = 1.2
        self.velocity = 0.6
        self.omega_max = 1.1
        self.collisionR = 0.25
        heading_scale = self.angle_alpha_factor*math.pi
        super().__init__(
            loss_type='brt_hjivi', set_mode='avoid',
            state_dim=9, input_dim=10, control_dim=3, disturbance_dim=0,
            # six position entries followed by three heading entries
            state_mean=[0] * 9,
            state_var=[1] * 6 + [heading_scale] * 3,
            value_mean=0.25,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact"
        )

    def state_test_range(self):
        """Return [min, max] per state dimension: positions, then headings."""
        position_ranges = [[-1, 1] for _ in range(6)]
        heading_ranges = [[-math.pi, math.pi] for _ in range(3)]
        return position_ranges + heading_ranges

    def equivalent_wrapped_state(self, state):
        """Return a copy of ``state`` with the three headings wrapped into [-pi, pi)."""
        wrapped = state.clone()
        wrapped[..., 6:9] = (wrapped[..., 6:9] + math.pi) % (2*math.pi) - math.pi
        return wrapped

    # dynamics (per car)
    # \dot x    = v \cos \theta
    # \dot y    = v \sin \theta
    # \dot \theta = u
    def dsdt(self, state, control, disturbance):
        """Time derivative of the state; ``disturbance`` is not used."""
        derivative = torch.zeros_like(state)
        for car in range(3):
            heading = state[..., 6 + car]
            derivative[..., 2*car] = self.velocity*torch.cos(heading)
            derivative[..., 2*car + 1] = self.velocity*torch.sin(heading)
            derivative[..., 6 + car] = control[..., car]
        return derivative

    def boundary_fn(self, state):
        """Smallest pairwise distance between the three vehicles, minus ``collisionR``."""
        def clearance(first, second):
            offset = state[..., 2*first:2*first+2] - state[..., 2*second:2*second+2]
            return torch.norm(offset, dim=-1) - self.collisionR

        # Vehicle 1 against vehicles 2 and 3, then vehicles 2 and 3 against each other.
        closest = clearance(0, 1)
        for first, second in ((0, 2), (1, 2)):
            closest = torch.min(closest, clearance(first, second))
        return closest

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        """Minimum of ``boundary_fn`` along the last dimension of its output."""
        boundary_values = self.boundary_fn(state_traj)
        return torch.min(boundary_values, dim=-1).values

    def hamiltonian(self, state, dvds):
        """Sum of the per-vehicle Hamiltonian terms (drift plus ``omega_max * |dV/dtheta|``)."""
        ham = None
        for car in range(3):
            heading = state[..., 6 + car]
            car_term = self.velocity*(torch.cos(heading) * dvds[..., 2*car] + torch.sin(heading) * dvds[..., 2*car + 1]) + self.omega_max * torch.abs(dvds[..., 6 + car])
            ham = car_term if ham is None else ham + car_term
        return ham

    def optimal_control(self, state, dvds):
        """Bang-bang turn rates: ``omega_max * sign(dV/dtheta_i)`` for each vehicle."""
        heading_grads = dvds[..., 6:9]
        return self.omega_max*torch.sign(heading_grads)

    def optimal_disturbance(self, state, dvds):
        """Return the scalar 0 (``disturbance_dim`` is 0 for this system)."""
        return 0

    def plot_config(self):
        """Return default slice values, axis labels and plot-axis indices."""
        slices = [
            0, 0,
            -0.4, 0,
            0.4, 0,
            math.pi/2, math.pi/4, 3*math.pi/4,
        ]
        labels = [
            r'$x_1$', r'$y_1$',
            r'$x_2$', r'$y_2$',
            r'$x_3$', r'$y_3$',
            r'$\theta_1$', r'$\theta_2$', r'$\theta_3$',
        ]
        return {
            'state_slices': slices,
            'state_labels': labels,
            'x_axis_idx': 0,
            'y_axis_idx': 1,
            'z_axis_idx': 6,
        }

###############################################################
# Custom Dynamics for 2d Planar Robot
###############################################################
class PlanarRobot2D(Dynamics):
    """2D constant-speed point robot whose control is its heading angle."""

    def __init__(
        self,
        goalR: float,
        velocity: float,
        set_mode: str = 'avoid',
        freeze_model: bool = False
    ):
        """
        2D constant-speed robot dynamics for reach/avoid problems.

        Args:
            goalR: Radius of the obstacle (avoid) or goal (reach) circle.
            velocity: Constant forward speed (m/s).
            set_mode: 'reach' or 'avoid'.
            freeze_model: If True, dsdt/hamiltonian will raise NotImplementedError.
        """
        self.goalR = goalR
        self.velocity = velocity
        self.freeze_model = freeze_model

        super().__init__(
            loss_type='brt_hjivi',
            set_mode=set_mode,
            state_dim=2, # [p_x, p_y]
            input_dim=3, # [p_x, p_y, t]
            control_dim=1, # [theta]
            disturbance_dim=0,

            # Normalize x,y from [-2,2] to [-1,1]
            state_mean=[0.0, 0.0],
            state_var=[2.0, 2.0],

            value_mean=0.0, # value_mean is not used in 'exact' mode
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        """Return [min, max] for each of the two position dimensions."""
        return [
            [-2.0, 2.0], # p_x
            [-2.0, 2.0], # p_y
        ]

    def equivalent_wrapped_state(self, state):
        """Return a copy of ``state``; no dimension is periodic.

        A clone is returned (rather than ``state`` itself) so that, as with
        the other dynamics classes in this file, in-place edits of the result
        cannot modify the caller's tensor.
        """
        return torch.clone(state)

    def dsdt(self, state, control, disturbance):
        """Time derivative of the state; ``disturbance`` is not used.

        Raises:
            NotImplementedError: if ``freeze_model`` is True.
        """
        # \dot p_x = v \cos \theta
        # \dot p_y = v \sin \theta
        if self.freeze_model:
            raise NotImplementedError
        theta = control[..., 0]
        ds = torch.zeros_like(state)
        ds[..., 0] = self.velocity * torch.cos(theta)
        ds[..., 1] = self.velocity * torch.sin(theta)
        return ds

    def boundary_fn(self, state):
        """Distance from the origin minus ``goalR``."""
        return torch.norm(state, dim=-1) - self.goalR

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        raise NotImplementedError
        # return torch.min(self.boundary_fn(state_traj), dim=-1).values

    def hamiltonian(self, state, dvds):
        """Return -v*||dvds|| in 'reach' mode and +v*||dvds|| in 'avoid' mode.

        Raises:
            NotImplementedError: if ``freeze_model`` is True, or if
                ``set_mode`` is neither 'reach' nor 'avoid' (previously such
                a mode silently returned None).
        """
        if self.freeze_model:
            raise NotImplementedError
        norm = torch.sqrt(dvds[..., 0]**2 + dvds[..., 1]**2)
        if self.set_mode == 'reach':
            return - self.velocity * norm
        if self.set_mode == 'avoid':
            return self.velocity * norm
        raise NotImplementedError(
            "PlanarRobot2D.hamiltonian supports set_mode 'reach' or 'avoid' (got %r)" % (self.set_mode,)
        )

    def optimal_control(self, state, dvds):
        raise NotImplementedError

    def optimal_disturbance(self, state, dvds):
        raise NotImplementedError

    def plot_config(self):
        """Return plot settings; ``z_axis_idx`` of -1 marks that there is no third axis."""
        return {
            "state_slices": [0, 0],
            "state_labels": ["p_x", "p_y"],
            "x_axis_idx": 0,
            "y_axis_idx": 1,
            "z_axis_idx": -1,
        }

###############################################################


### 6. Experiments Code

In [12]:
class Experiment(ABC):
    def __init__(self, model, dataset, experiment_dir, use_wandb):
        """Store the collaborators used by the other methods of this class.

        Args:
            model: network that is put into train/eval mode, called with
                ``{'coords': ...}`` and loaded from checkpoints.
            dataset: object exposing ``dynamics`` and ``tMax`` (both read by
                ``validate``).
            experiment_dir: root directory; checkpoints are looked up under
                ``<experiment_dir>/training/checkpoints``.
            use_wandb: when truthy, ``validate`` also logs its figure to wandb.
        """
        self.model = model
        self.dataset = dataset
        self.experiment_dir = experiment_dir
        self.use_wandb = use_wandb

    @abstractmethod
    def init_special(self):
        """Abstract hook that concrete experiments must override; this base version only raises."""
        raise NotImplementedError

    def _load_checkpoint(self, epoch):
        if epoch == -1:
            model_path = os.path.join(self.experiment_dir, 'training', 'checkpoints', 'model_final.pth')
            self.model.load_state_dict(torch.load(model_path))
        else:
            model_path = os.path.join(self.experiment_dir, 'training', 'checkpoints', 'model_epoch_%04d.pth' % epoch)
            self.model.load_state_dict(torch.load(model_path)['model'])

    def validate(self, device, epoch, save_path, x_resolution, y_resolution, z_resolution, time_resolution):
        """Plot where the learned value is <= 0 on a grid of 2-D slices and save the figure.

        One subplot is drawn per (time, z-slice) pair. The x/y axes and the
        optional z axis are the state dimensions named by the dynamics'
        ``plot_config``; every other state dimension is held at its
        ``state_slices`` value. A ``z_axis_idx`` of -1 (or a missing key)
        means a single slice per time.

        Args:
            device: device the grid coordinates are moved to before the model call.
            epoch: logged as ``'step'`` when wandb logging is enabled.
            save_path: destination passed to ``fig.savefig``.
            x_resolution, y_resolution: number of grid points along the plot axes.
            z_resolution: number of z slices (ignored when there is no z axis).
            time_resolution: number of time samples in [0, ``dataset.tMax``].

        The model is switched to eval mode with gradients disabled for the
        duration. If it was in training mode on entry, training mode and
        ``requires_grad`` are re-enabled on exit. The figure is closed on
        exit, including when plotting raises.
        """
        was_training = self.model.training
        self.model.eval()
        self.model.requires_grad_(False)

        fig = None
        try:
            dynamics = self.dataset.dynamics
            plot_config = dynamics.plot_config()
            state_test_range = dynamics.state_test_range()

            x_idx = plot_config['x_axis_idx']
            y_idx = plot_config['y_axis_idx']
            z_idx = plot_config.get('z_axis_idx', -1)
            has_z_axis = z_idx != -1

            x_min, x_max = state_test_range[x_idx]
            y_min, y_max = state_test_range[y_idx]
            xs = torch.linspace(x_min, x_max, x_resolution)
            ys = torch.linspace(y_min, y_max, y_resolution)
            xys = torch.cartesian_prod(xs, ys)

            # Determine z slices: without a z axis, use a single dummy slice
            if has_z_axis:
                z_min, z_max = state_test_range[z_idx]
                zs = torch.linspace(z_min, z_max, z_resolution)
            else:
                zs = torch.tensor([0.0], dtype=torch.float32)

            times = torch.linspace(0, self.dataset.tMax, time_resolution)
            # Default values for the state dimensions that are not swept;
            # identical for every subplot, so built once.
            state_slices = torch.tensor(plot_config['state_slices'])

            fig = plt.figure(figsize=(5*len(times), 5*len(zs)))

            # Loop over time and z-slices
            for i, t in enumerate(times):
                for j, z_val in enumerate(zs):
                    coords = torch.zeros(x_resolution*y_resolution, dynamics.state_dim + 1)
                    coords[:, 0] = t
                    coords[:, 1:] = state_slices

                    coords[:, 1 + x_idx] = xys[:, 0]
                    coords[:, 1 + y_idx] = xys[:, 1]

                    # Only fill z if it's a real 3D axis
                    if has_z_axis:
                        coords[:, 1 + z_idx] = z_val

                    with torch.no_grad():
                        inp = dynamics.coord_to_input(coords.to(device))
                        results = self.model({'coords': inp})
                        values = dynamics.io_to_value(results['model_in'].detach(), results['model_out'].squeeze(dim=-1).detach())

                    ax = fig.add_subplot(len(times), len(zs), (j+1) + i*len(zs))
                    title_z = (f", {plot_config['state_labels'][z_idx]} = {z_val:.2f}"
                            if has_z_axis else "")
                    ax.set_title(f"t = {t:.2f}{title_z}")

                    # cartesian_prod varies y fastest, so transpose to put x on the horizontal axis
                    img = values.cpu().numpy().reshape(x_resolution, y_resolution).T
                    # binary decision boundary: <=0 in red/blue
                    mask = (img <= 0).astype(float)
                    s = ax.imshow(mask, cmap='bwr', origin='lower', extent=(x_min, x_max, y_min, y_max))
                    fig.colorbar(s, ax=ax)

            fig.savefig(save_path)
            if self.use_wandb:
                wandb.log({
                    'step': epoch,
                    'val_plot': wandb.Image(fig),
                })
        finally:
            # Close this specific figure (not whichever one happens to be
            # current) and restore the model even if plotting failed.
            if fig is not None:
                plt.close(fig)

            if was_training:
                self.model.train()
                self.model.requires_grad_(True)

    def train(
            self, device, batch_size, epochs, lr,
            steps_til_summary, epochs_til_checkpoint,
            loss_fn, clip_grad, use_lbfgs, adjust_relative_grads,
            val_x_resolution, val_y_resolution, val_z_resolution, val_time_resolution,
            use_CSL, CSL_lr, CSL_dt, epochs_til_CSL, num_CSL_samples, CSL_loss_frac_cutoff, max_CSL_epochs, CSL_loss_weight, CSL_batch_size,
        ):
        was_eval = not self.model.training
        self.model.train()
        self.model.requires_grad_(True)

        train_dataloader = DataLoader(self.dataset, shuffle=True, batch_size=batch_size, pin_memory=True, num_workers=0)

        optim = torch.optim.Adam(lr=lr, params=self.model.parameters())

        # copy settings from Raissi et al. (2019) and here
        # https://github.com/maziarraissi/PINNs
        if use_lbfgs:
            optim = torch.optim.LBFGS(lr=lr, params=self.model.parameters(), max_iter=50000, max_eval=50000,
                                    history_size=50, line_search_fn='strong_wolfe')

        training_dir = os.path.join(self.experiment_dir, 'training')

        summaries_dir = os.path.join(training_dir, 'summaries')
        if not os.path.exists(summaries_dir):
            os.makedirs(summaries_dir)

        checkpoints_dir = os.path.join(training_dir, 'checkpoints')
        if not os.path.exists(checkpoints_dir):
            os.makedirs(checkpoints_dir)

        writer = SummaryWriter(summaries_dir)

        total_steps = 0

        if adjust_relative_grads:
            new_weight = 1

        with tqdm(total=len(train_dataloader) * epochs) as pbar:
            train_losses = []
            last_CSL_epoch = -1
            for epoch in range(0, epochs):
                if self.dataset.pretrain: # skip CSL
                    last_CSL_epoch = epoch
                time_interval_length = (self.dataset.counter/self.dataset.counter_end)*(self.dataset.tMax-self.dataset.tMin)
                CSL_tMax = self.dataset.tMin + int(time_interval_length/CSL_dt)*CSL_dt

                # self-supervised learning
                for step, (model_input, gt) in enumerate(train_dataloader):
                    start_time = time.time()

                    model_input = {key: value.to(device) for key, value in model_input.items()}
                    gt = {key: value.to(device) for key, value in gt.items()}

                    model_results = self.model({'coords': model_input['model_coords']})

                    states = self.dataset.dynamics.input_to_coord(model_results['model_in'].detach())[..., 1:]
                    values = self.dataset.dynamics.io_to_value(model_results['model_in'].detach(), model_results['model_out'].squeeze(dim=-1))
                    dvs = self.dataset.dynamics.io_to_dv(model_results['model_in'], model_results['model_out'].squeeze(dim=-1))
                    boundary_values = gt['boundary_values']
                    if self.dataset.dynamics.loss_type == 'brat_hjivi':
                        reach_values = gt['reach_values']
                        avoid_values = gt['avoid_values']
                    dirichlet_masks = gt['dirichlet_masks']

                    if self.dataset.dynamics.loss_type == 'brt_hjivi':
                        losses = loss_fn(states, values, dvs[..., 0], dvs[..., 1:], boundary_values, dirichlet_masks, model_results['model_out'])
                    elif self.dataset.dynamics.loss_type == 'brat_hjivi':
                        losses = loss_fn(states, values, dvs[..., 0], dvs[..., 1:], boundary_values, reach_values, avoid_values, dirichlet_masks, model_results['model_out'])
                    else:
                        raise NotImplementedError

                    if use_lbfgs:
                        def closure():
                            optim.zero_grad()
                            train_loss = 0.
                            for loss_name, loss in losses.items():
                                train_loss += loss.mean()
                            train_loss.backward()
                            return train_loss
                        optim.step(closure)

                    # Adjust the relative magnitude of the losses if required
                    if self.dataset.dynamics.deepreach_model in ['vanilla', 'diff'] and adjust_relative_grads:
                        if losses['diff_constraint_hom'] > 0.01:
                            params = OrderedDict(self.model.named_parameters())
                            # Gradients with respect to the PDE loss
                            optim.zero_grad()
                            losses['diff_constraint_hom'].backward(retain_graph=True)
                            grads_PDE = []
                            for key, param in params.items():
                                grads_PDE.append(param.grad.view(-1))
                            grads_PDE = torch.cat(grads_PDE)

                            # Gradients with respect to the boundary loss
                            optim.zero_grad()
                            losses['dirichlet'].backward(retain_graph=True)
                            grads_dirichlet = []
                            for key, param in params.items():
                                grads_dirichlet.append(param.grad.view(-1))
                            grads_dirichlet = torch.cat(grads_dirichlet)

                            # # Plot the gradients
                            # import seaborn as sns
                            # import matplotlib.pyplot as plt
                            # fig = plt.figure(figsize=(5, 5))
                            # ax = fig.add_subplot(1, 1, 1)
                            # ax.set_yscale('symlog')
                            # sns.distplot(grads_PDE.cpu().numpy(), hist=False, kde_kws={"shade": False}, norm_hist=True)
                            # sns.distplot(grads_dirichlet.cpu().numpy(), hist=False, kde_kws={"shade": False}, norm_hist=True)
                            # fig.savefig('gradient_visualization.png')

                            # fig = plt.figure(figsize=(5, 5))
                            # ax = fig.add_subplot(1, 1, 1)
                            # ax.set_yscale('symlog')
                            # grads_dirichlet_normalized = grads_dirichlet * torch.mean(torch.abs(grads_PDE))/torch.mean(torch.abs(grads_dirichlet))
                            # sns.distplot(grads_PDE.cpu().numpy(), hist=False, kde_kws={"shade": False}, norm_hist=True)
                            # sns.distplot(grads_dirichlet_normalized.cpu().numpy(), hist=False, kde_kws={"shade": False}, norm_hist=True)
                            # ax.set_xlim([-1000.0, 1000.0])
                            # fig.savefig('gradient_visualization_normalized.png')

                            # Set the new weight according to the paper
                            # num = torch.max(torch.abs(grads_PDE))
                            num = torch.mean(torch.abs(grads_PDE))
                            den = torch.mean(torch.abs(grads_dirichlet))
                            new_weight = 0.9*new_weight + 0.1*num/den
                            losses['dirichlet'] = new_weight*losses['dirichlet']
                        writer.add_scalar('weight_scaling', new_weight, total_steps)

                    # import ipdb; ipdb.set_trace()

                    train_loss = 0.
                    for loss_name, loss in losses.items():
                        single_loss = loss.mean()

                        if loss_name == 'dirichlet':
                            writer.add_scalar(loss_name, single_loss/new_weight, total_steps)
                        else:
                            writer.add_scalar(loss_name, single_loss, total_steps)
                        train_loss += single_loss

                    train_losses.append(train_loss.item())
                    writer.add_scalar("total_train_loss", train_loss, total_steps)

                    if not total_steps % steps_til_summary:
                        torch.save(self.model.state_dict(),
                                os.path.join(checkpoints_dir, 'model_current.pth'))
                        # summary_fn(model, model_input, gt, model_output, writer, total_steps)

                    if not use_lbfgs:
                        optim.zero_grad()
                        train_loss.backward()

                        if clip_grad:
                            if isinstance(clip_grad, bool):
                                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
                            else:
                                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=clip_grad)

                        optim.step()

                    pbar.update(1)

                    if not total_steps % steps_til_summary:
                        tqdm.write("Epoch %d, Total loss %0.6f, iteration time %0.6f" % (epoch, train_loss, time.time() - start_time))
                        if self.use_wandb:
                            wandb.log({
                                'step': epoch,
                                'train_loss': train_loss,
                                'pde_loss': losses['diff_constraint_hom'],
                            })

                    total_steps += 1

                # cost-supervised learning (CSL) phase
                if use_CSL and not self.dataset.pretrain and (epoch-last_CSL_epoch) >= epochs_til_CSL:
                    last_CSL_epoch = epoch

                    # generate CSL datasets
                    self.model.eval()

                    CSL_dataset = scenario_optimization(
                        device=device, model=self.model, policy=self.model, dynamics=self.dataset.dynamics,
                        tMin=self.dataset.tMin, tMax=CSL_tMax, dt=CSL_dt,
                        set_type="BRT", control_type="value", # TODO: implement option for BRS too
                        scenario_batch_size=min(num_CSL_samples, 100000), sample_batch_size=min(10*num_CSL_samples, 1000000),
                        sample_generator=SliceSampleGenerator(dynamics=self.dataset.dynamics, slices=[None]*self.dataset.dynamics.state_dim),
                        sample_validator=ValueThresholdValidator(v_min=float('-inf'), v_max=float('inf')),
                        violation_validator=ValueThresholdValidator(v_min=0.0, v_max=0.0),
                        max_scenarios=num_CSL_samples, tStart_generator=lambda n : torch.zeros(n).uniform_(self.dataset.tMin, CSL_tMax)
                    )
                    CSL_coords = torch.cat((CSL_dataset['times'].unsqueeze(-1), CSL_dataset['states']), dim=-1)
                    CSL_costs = CSL_dataset['costs']

                    num_CSL_val_samples = int(0.1*num_CSL_samples)
                    CSL_val_dataset = scenario_optimization(
                        model=self.model, policy=self.model, dynamics=self.dataset.dynamics,
                        tMin=self.dataset.tMin, tMax=CSL_tMax, dt=CSL_dt,
                        set_type="BRT", control_type="value", # TODO: implement option for BRS too
                        scenario_batch_size=min(num_CSL_val_samples, 100000), sample_batch_size=min(10*num_CSL_val_samples, 1000000),
                        sample_generator=SliceSampleGenerator(dynamics=self.dataset.dynamics, slices=[None]*self.dataset.dynamics.state_dim),
                        sample_validator=ValueThresholdValidator(v_min=float('-inf'), v_max=float('inf')),
                        violation_validator=ValueThresholdValidator(v_min=0.0, v_max=0.0),
                        max_scenarios=num_CSL_val_samples, tStart_generator=lambda n : torch.zeros(n).uniform_(self.dataset.tMin, CSL_tMax)
                    )
                    CSL_val_coords = torch.cat((CSL_val_dataset['times'].unsqueeze(-1), CSL_val_dataset['states']), dim=-1)
                    CSL_val_costs = CSL_val_dataset['costs']

                    CSL_val_tMax_dataset = scenario_optimization(
                        model=self.model, policy=self.model, dynamics=self.dataset.dynamics,
                        tMin=self.dataset.tMin, tMax=self.dataset.tMax, dt=CSL_dt,
                        set_type="BRT", control_type="value", # TODO: implement option for BRS too
                        scenario_batch_size=min(num_CSL_val_samples, 100000), sample_batch_size=min(10*num_CSL_val_samples, 1000000),
                        sample_generator=SliceSampleGenerator(dynamics=self.dataset.dynamics, slices=[None]*self.dataset.dynamics.state_dim),
                        sample_validator=ValueThresholdValidator(v_min=float('-inf'), v_max=float('inf')),
                        violation_validator=ValueThresholdValidator(v_min=0.0, v_max=0.0),
                        max_scenarios=num_CSL_val_samples # no tStart_generator, since I want all tMax times
                    )
                    CSL_val_tMax_coords = torch.cat((CSL_val_tMax_dataset['times'].unsqueeze(-1), CSL_val_tMax_dataset['states']), dim=-1)
                    CSL_val_tMax_costs = CSL_val_tMax_dataset['costs']

                    self.model.train()

                    # CSL optimizer
                    CSL_optim = torch.optim.Adam(lr=CSL_lr, params=self.model.parameters())

                    # initial CSL val loss
                    CSL_val_results = self.model({'coords': self.dataset.dynamics.coord_to_input(CSL_val_coords.to(device))})
                    CSL_val_preds = self.dataset.dynamics.io_to_value(CSL_val_results['model_in'], CSL_val_results['model_out'].squeeze(dim=-1))
                    CSL_val_errors = CSL_val_preds - CSL_val_costs.to(device)
                    CSL_val_loss = torch.mean(torch.pow(CSL_val_errors, 2))
                    CSL_initial_val_loss = CSL_val_loss
                    if self.use_wandb:
                        wandb.log({
                            "step": epoch,
                            "CSL_val_loss": CSL_val_loss.item()
                        })

                    # initial self-supervised learning (SSL) val loss
                    # right now, just took code from dataio.py and the SSL training loop above; TODO: refactor all this for cleaner modular code
                    CSL_val_states = CSL_val_coords[..., 1:].to(device)
                    CSL_val_dvs = self.dataset.dynamics.io_to_dv(CSL_val_results['model_in'], CSL_val_results['model_out'].squeeze(dim=-1))
                    CSL_val_boundary_values = self.dataset.dynamics.boundary_fn(CSL_val_states)
                    if self.dataset.dynamics.loss_type == 'brat_hjivi':
                        CSL_val_reach_values = self.dataset.dynamics.reach_fn(CSL_val_states)
                        CSL_val_avoid_values = self.dataset.dynamics.avoid_fn(CSL_val_states)
                    CSL_val_dirichlet_masks = CSL_val_coords[:, 0].to(device) == self.dataset.tMin # assumes time unit in dataset (model) is same as real time units
                    if self.dataset.dynamics.loss_type == 'brt_hjivi':
                        SSL_val_losses = loss_fn(CSL_val_states, CSL_val_preds, CSL_val_dvs[..., 0], CSL_val_dvs[..., 1:], CSL_val_boundary_values, CSL_val_dirichlet_masks)
                    elif self.dataset.dynamics.loss_type == 'brat_hjivi':
                        SSL_val_losses = loss_fn(CSL_val_states, CSL_val_preds, CSL_val_dvs[..., 0], CSL_val_dvs[..., 1:], CSL_val_boundary_values, CSL_val_reach_values, CSL_val_avoid_values, CSL_val_dirichlet_masks)
                    else:
                        NotImplementedError
                    SSL_val_loss = SSL_val_losses['diff_constraint_hom'].mean() # I assume there is no dirichlet (boundary) loss here, because I do not ever explicitly generate source samples at tMin (i.e. torch.all(CSL_val_dirichlet_masks == False))
                    if self.use_wandb:
                        wandb.log({
                            "step": epoch,
                            "SSL_val_loss": SSL_val_loss.item()
                        })

                    # CSL training loop
                    for CSL_epoch in tqdm(range(max_CSL_epochs)):
                        CSL_idxs = torch.randperm(num_CSL_samples)
                        for CSL_batch in range(math.ceil(num_CSL_samples/CSL_batch_size)):
                            CSL_batch_idxs = CSL_idxs[CSL_batch*CSL_batch_size:(CSL_batch+1)*CSL_batch_size]
                            CSL_batch_coords = CSL_coords[CSL_batch_idxs]

                            CSL_batch_results = self.model({'coords': self.dataset.dynamics.coord_to_input(CSL_batch_coords.to(device))})
                            CSL_batch_preds = self.dataset.dynamics.io_to_value(CSL_batch_results['model_in'], CSL_batch_results['model_out'].squeeze(dim=-1))
                            CSL_batch_costs = CSL_costs[CSL_batch_idxs].to(device)
                            CSL_batch_errors = CSL_batch_preds - CSL_batch_costs
                            CSL_batch_loss = CSL_loss_weight*torch.mean(torch.pow(CSL_batch_errors, 2))

                            CSL_batch_states = CSL_batch_coords[..., 1:].to(device)
                            CSL_batch_dvs = self.dataset.dynamics.io_to_dv(CSL_batch_results['model_in'], CSL_batch_results['model_out'].squeeze(dim=-1))
                            CSL_batch_boundary_values = self.dataset.dynamics.boundary_fn(CSL_batch_states)
                            if self.dataset.dynamics.loss_type == 'brat_hjivi':
                                CSL_batch_reach_values = self.dataset.dynamics.reach_fn(CSL_batch_states)
                                CSL_batch_avoid_values = self.dataset.dynamics.avoid_fn(CSL_batch_states)
                            CSL_batch_dirichlet_masks = CSL_batch_coords[:, 0].to(device) == self.dataset.tMin # assumes time unit in dataset (model) is same as real time units
                            if self.dataset.dynamics.loss_type == 'brt_hjivi':
                                SSL_batch_losses = loss_fn(CSL_batch_states, CSL_batch_preds, CSL_batch_dvs[..., 0], CSL_batch_dvs[..., 1:], CSL_batch_boundary_values, CSL_batch_dirichlet_masks)
                            elif self.dataset.dynamics.loss_type == 'brat_hjivi':
                                SSL_batch_losses = loss_fn(CSL_batch_states, CSL_batch_preds, CSL_batch_dvs[..., 0], CSL_batch_dvs[..., 1:], CSL_batch_boundary_values, CSL_batch_reach_values, CSL_batch_avoid_values, CSL_batch_dirichlet_masks)
                            else:
                                NotImplementedError
                            SSL_batch_loss = SSL_batch_losses['diff_constraint_hom'].mean() # I assume there is no dirichlet (boundary) loss here, because I do not ever explicitly generate source samples at tMin (i.e. torch.all(CSL_batch_dirichlet_masks == False))

                            CSL_optim.zero_grad()
                            SSL_batch_loss.backward(retain_graph=True)
                            if (not use_lbfgs) and clip_grad: # no adjust_relative_grads, because I assume even with adjustment, the diff_constraint_hom remains unaffected and the only other loss (dirichlet) is zero
                                if isinstance(clip_grad, bool):
                                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
                                else:
                                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=clip_grad)
                            CSL_batch_loss.backward()
                            CSL_optim.step()

                        # evaluate on CSL_val_dataset
                        CSL_val_results = self.model({'coords': self.dataset.dynamics.coord_to_input(CSL_val_coords.to(device))})
                        CSL_val_preds = self.dataset.dynamics.io_to_value(CSL_val_results['model_in'], CSL_val_results['model_out'].squeeze(dim=-1))
                        CSL_val_errors = CSL_val_preds - CSL_val_costs.to(device)
                        CSL_val_loss = torch.mean(torch.pow(CSL_val_errors, 2))

                        CSL_val_states = CSL_val_coords[..., 1:].to(device)
                        CSL_val_dvs = self.dataset.dynamics.io_to_dv(CSL_val_results['model_in'], CSL_val_results['model_out'].squeeze(dim=-1))
                        CSL_val_boundary_values = self.dataset.dynamics.boundary_fn(CSL_val_states)
                        if self.dataset.dynamics.loss_type == 'brat_hjivi':
                            CSL_val_reach_values = self.dataset.dynamics.reach_fn(CSL_val_states)
                            CSL_val_avoid_values = self.dataset.dynamics.avoid_fn(CSL_val_states)
                        CSL_val_dirichlet_masks = CSL_val_coords[:, 0].to(device) == self.dataset.tMin # assumes time unit in dataset (model) is same as real time units
                        if self.dataset.dynamics.loss_type == 'brt_hjivi':
                            SSL_val_losses = loss_fn(CSL_val_states, CSL_val_preds, CSL_val_dvs[..., 0], CSL_val_dvs[..., 1:], CSL_val_boundary_values, CSL_val_dirichlet_masks)
                        elif self.dataset.dynamics.loss_type == 'brat_hjivi':
                            SSL_val_losses = loss_fn(CSL_val_states, CSL_val_preds, CSL_val_dvs[..., 0], CSL_val_dvs[..., 1:], CSL_val_boundary_values, CSL_val_reach_values, CSL_val_avoid_values, CSL_val_dirichlet_masks)
                        else:
                            raise NotImplementedError
                        SSL_val_loss = SSL_val_losses['diff_constraint_hom'].mean() # I assume there is no dirichlet (boundary) loss here, because I do not ever explicitly generate source samples at tMin (i.e. torch.all(CSL_val_dirichlet_masks == False))

                        CSL_val_tMax_results = self.model({'coords': self.dataset.dynamics.coord_to_input(CSL_val_tMax_coords.to(device))})
                        CSL_val_tMax_preds = self.dataset.dynamics.io_to_value(CSL_val_tMax_results['model_in'], CSL_val_tMax_results['model_out'].squeeze(dim=-1))
                        CSL_val_tMax_errors = CSL_val_tMax_preds - CSL_val_tMax_costs.to(device)
                        CSL_val_tMax_loss = torch.mean(torch.pow(CSL_val_tMax_errors, 2))

                        # log CSL losses, recovered_safe_set_fracs
                        if self.dataset.dynamics.set_mode == 'reach':
                            CSL_train_batch_theoretically_recoverable_safe_set_frac = torch.sum(CSL_batch_costs.to(device) < 0) / len(CSL_batch_preds)
                            CSL_train_batch_recovered_safe_set_frac = torch.sum(CSL_batch_preds < torch.min(CSL_batch_preds[CSL_batch_costs.to(device) > 0])) / len(CSL_batch_preds)
                            CSL_val_theoretically_recoverable_safe_set_frac = torch.sum(CSL_val_costs.to(device) < 0) / len(CSL_val_preds)
                            CSL_val_recovered_safe_set_frac = torch.sum(CSL_val_preds < torch.min(CSL_val_preds[CSL_val_costs.to(device) > 0])) / len(CSL_val_preds)
                            CSL_val_tMax_theoretically_recoverable_safe_set_frac = torch.sum(CSL_val_tMax_costs.to(device) < 0) / len(CSL_val_tMax_preds)
                            CSL_val_tMax_recovered_safe_set_frac = torch.sum(CSL_val_tMax_preds < torch.min(CSL_val_tMax_preds[CSL_val_tMax_costs.to(device) > 0])) / len(CSL_val_tMax_preds)
                        elif self.dataset.dynamics.set_mode == 'avoid':
                            CSL_train_batch_theoretically_recoverable_safe_set_frac = torch.sum(CSL_batch_costs.to(device) > 0) / len(CSL_batch_preds)
                            CSL_train_batch_recovered_safe_set_frac = torch.sum(CSL_batch_preds > torch.max(CSL_batch_preds[CSL_batch_costs.to(device) < 0])) / len(CSL_batch_preds)
                            CSL_val_theoretically_recoverable_safe_set_frac = torch.sum(CSL_val_costs.to(device) > 0) / len(CSL_val_preds)
                            CSL_val_recovered_safe_set_frac = torch.sum(CSL_val_preds > torch.max(CSL_val_preds[CSL_val_costs.to(device) < 0])) / len(CSL_val_preds)
                            CSL_val_tMax_theoretically_recoverable_safe_set_frac = torch.sum(CSL_val_tMax_costs.to(device) > 0) / len(CSL_val_tMax_preds)
                            CSL_val_tMax_recovered_safe_set_frac = torch.sum(CSL_val_tMax_preds > torch.max(CSL_val_tMax_preds[CSL_val_tMax_costs.to(device) < 0])) / len(CSL_val_tMax_preds)
                        else:
                            raise NotImplementedError
                        if self.use_wandb:
                            wandb.log({
                                "step": epoch+(CSL_epoch+1)*int(0.5*epochs_til_CSL/max_CSL_epochs),
                                "CSL_train_batch_loss": CSL_batch_loss.item(),
                                "SSL_train_batch_loss": SSL_batch_loss.item(),
                                "CSL_val_loss": CSL_val_loss.item(),
                                "SSL_val_loss": SSL_val_loss.item(),
                                "CSL_val_tMax_loss": CSL_val_tMax_loss.item(),
                                "CSL_train_batch_theoretically_recoverable_safe_set_frac": CSL_train_batch_theoretically_recoverable_safe_set_frac.item(),
                                "CSL_val_theoretically_recoverable_safe_set_frac": CSL_val_theoretically_recoverable_safe_set_frac.item(),
                                "CSL_val_tMax_theoretically_recoverable_safe_set_frac": CSL_val_tMax_theoretically_recoverable_safe_set_frac.item(),
                                "CSL_train_batch_recovered_safe_set_frac": CSL_train_batch_recovered_safe_set_frac.item(),
                                "CSL_val_recovered_safe_set_frac": CSL_val_recovered_safe_set_frac.item(),
                                "CSL_val_tMax_recovered_safe_set_frac": CSL_val_tMax_recovered_safe_set_frac.item(),
                            })

                        if CSL_val_loss < CSL_loss_frac_cutoff*CSL_initial_val_loss:
                            break

                if not (epoch+1) % epochs_til_checkpoint:
                    # Saving the optimizer state is important to produce consistent results
                    checkpoint = {
                        'epoch': epoch+1,
                        'model': self.model.state_dict(),
                        'optimizer': optim.state_dict()}
                    torch.save(checkpoint,
                        os.path.join(checkpoints_dir, 'model_epoch_%04d.pth' % (epoch+1)))
                    np.savetxt(os.path.join(checkpoints_dir, 'train_losses_epoch_%04d.txt' % (epoch+1)),
                        np.array(train_losses))
                    self.validate(
                        device=device, epoch=epoch+1, save_path=os.path.join(checkpoints_dir, 'BRS_validation_plot_epoch_%04d.png' % (epoch+1)),
                        x_resolution = val_x_resolution, y_resolution = val_y_resolution, z_resolution=val_z_resolution, time_resolution=val_time_resolution)

        if was_eval:
            self.model.eval()
            self.model.requires_grad_(False)

    def test(self, device, current_time, last_checkpoint, checkpoint_dt, dt, num_scenarios, num_violations, set_type, control_type, data_step, checkpoint_toload=None):
        was_training = self.model.training
        self.model.eval()
        self.model.requires_grad_(False)

        testing_dir = os.path.join(self.experiment_dir, 'testing_%s' % current_time.strftime('%m_%d_%Y_%H_%M'))
        if os.path.exists(testing_dir):
            overwrite = input("The testing directory %s already exists. Overwrite? (y/n)"%testing_dir)
            if not (overwrite == 'y'):
                print('Exiting.')
                quit()
            shutil.rmtree(testing_dir)
        os.makedirs(testing_dir)

        if checkpoint_toload is None:
            print('running cross-checkpoint testing')

            for i in tqdm(range(sidelen), desc='Checkpoint'):
                self._load_checkpoint(epoch=checkpoints[i])
                raise NotImplementedError

        else:
            print('running specific-checkpoint testing')
            self._load_checkpoint(checkpoint_toload)

            model = self.model
            dataset = self.dataset
            dynamics = dataset.dynamics
            raise NotImplementedError

        if was_training:
            self.model.train()
            self.model.requires_grad_(True)

class DeepReach(Experiment):
    """Experiment subclass that adds no behaviour of its own on top of ``Experiment``."""

    def init_special(self):
        """Experiment-specific initialisation hook; takes no extra arguments and does nothing."""
        return None

### 7. Run Experiment

In [13]:
# Log in to Weights & Biases. In the recorded output below this prompted for an
# API key and stored it in /root/.netrc.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33maanirudh[0m ([33maanirudh-n-a[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [18]:
import sys

# Command-line options for this run, written as (flag, value) pairs and then
# flattened into the argv list that the argument parser reads.
cli_options = (
    ("--mode", "train"),
    ("--experiment_class", "DeepReach"),
    ("--dynamics_class", "PlanarRobot2D"),
    ("--experiment_name", "brt_obstacle_05m"),
    ("--minWith", "target"),
    ("--goalR", "0.5"),
    ("--velocity", "1.0"),
    ("--set_mode", "avoid"),
    ("--num_epochs", "40000"),
    # Weights & Biases run metadata
    ("--wandb_project", "reachability-experiments"),
    ("--wandb_entity", "aanirudh-n-a"),
    ("--wandb_group", "Training"),
    ("--wandb_name", "brt_obstacle_05m_run_01"),
)

# argv[0] is only a placeholder for the script name (ignored by argparse).
sys.argv = ["script_name"] + [token for option in cli_options for token in option]

In [19]:
import inspect
from inspect import isclass

# The parser is built in stages: which options get registered depends on values
# of options registered earlier, so parse_known_args() is used to peek at the
# command line while the parser is still incomplete.
p = configargparse.ArgumentParser()

p.add_argument('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
p.add_argument('--mode', type=str, required=True, choices=['all', 'train', 'test'], help="Experiment mode to run (new experiments must choose 'all' or 'train').")

# save/load directory options
p.add_argument('--experiments_dir', type=str, default='./runs', help='Where to save the experiment subdirectory.')
p.add_argument('--experiment_name', type=str, required=True, help='Name of the experiment subdirectory.')
# NOTE: default=True combined with action='store_false' means wandb logging is ON
# unless this flag is passed; passing --use_wandb turns logging OFF. The help text
# says so explicitly because the flag name suggests the opposite.
p.add_argument('--use_wandb', default=True, action='store_false', help='Pass this flag to DISABLE wandb logging (logging is enabled by default).')

# The wandb options are only registered (and required) while wandb logging is enabled.
use_wandb = p.parse_known_args()[0].use_wandb
if use_wandb:
    p.add_argument('--wandb_project', type=str, required=True, help='wandb project')
    p.add_argument('--wandb_entity', type=str, required=True, help='wandb entity')
    p.add_argument('--wandb_group', type=str, required=True, help='wandb group')
    p.add_argument('--wandb_name', type=str, required=True, help='name of wandb run')

# Re-parse now that the wandb options (if any) are registered; the mode decides
# which further options are added below.
mode = p.parse_known_args()[0].mode

if (mode == 'all') or (mode == 'train'):
    p.add_argument('--seed', type=int, default=0, required=False, help='Seed for the experiment.')

    # load experiment_class choices dynamically from experiments module
    experiment_classes_dict = {
        name: clss for name, clss in globals().items()
        if isclass(clss) and issubclass(clss, Experiment) and clss is not Experiment
    }
    # experiment_classes_dict = {name: clss for name, clss in inspect.getmembers(experiments, inspect.isclass) if clss.__bases__[0] == experiments.Experiment}
    p.add_argument('--experiment_class', type=str, default='DeepReach', choices=experiment_classes_dict.keys(), help='Experiment class to use.')
    # load special experiment_class arguments dynamically from chosen experiment class
    experiment_class = DeepReach #experiment_classes_dict[p.parse_known_args()[0].experiment_class]
    experiment_params = {name: param for name, param in inspect.signature(experiment_class.init_special).parameters.items() if name != 'self'}
    for param in experiment_params.keys():
        p.add_argument('--' + param, type=experiment_params[param].annotation, required=True, help='special experiment_class argument')

    # simulation data source options
    p.add_argument('--device', type=str, default='cuda:0', required=False, help='CUDA Device to use.')
    p.add_argument('--numpoints', type=int, default=65000, help='Number of points in simulation data source __getitem__.')
    p.add_argument('--pretrain', action='store_true', default=False, required=False, help='Pretrain dirichlet conditions')
    p.add_argument('--pretrain_iters', type=int, default=2000, required=False, help='Number of pretrain iterations')
    p.add_argument('--tMin', type=float, default=0.0, required=False, help='Start time of the simulation')
    p.add_argument('--tMax', type=float, default=1.0, required=False, help='End time of the simulation')
    p.add_argument('--counter_start', type=int, default=0, required=False, help='Defines the initial time for the curriculum training')
    p.add_argument('--counter_end', type=int, default=-1, required=False, help='Defines the linear step for curriculum training starting from the initial time')
    p.add_argument('--num_src_samples', type=int, default=1000, required=False, help='Number of source samples (initial-time samples) at each time step')
    p.add_argument('--num_target_samples', type=int, default=0, required=False, help='Number of samples inside the target set')

    # model options
    p.add_argument('--model', type=str, default='sine', required=False, choices=['sine', 'tanh', 'sigmoid', 'relu'], help='Type of model to evaluate, default is sine.')
    p.add_argument('--model_mode', type=str, default='mlp', required=False, choices=['mlp', 'rbf', 'pinn'], help='Whether to use uniform velocity parameter')
    p.add_argument('--num_hl', type=int, default=3, required=False, help='The number of hidden layers')
    p.add_argument('--num_nl', type=int, default=512, required=False, help='Number of neurons per hidden layer.')
    p.add_argument('--deepreach_model', type=str, default='exact', required=False, choices=['exact', 'diff', 'vanilla'], help='deepreach model')

    # training options
    p.add_argument('--epochs_til_ckpt', type=int, default=1000, help='Time interval in seconds until checkpoint is saved.')
    p.add_argument('--steps_til_summary', type=int, default=100, help='Time interval in seconds until tensorboard summary is saved.')
    p.add_argument('--batch_size', type=int, default=1, help='Batch size used during training (irrelevant, since len(dataset) == 1).')
    p.add_argument('--lr', type=float, default=2e-5, help='learning rate. default=2e-5')
    p.add_argument('--num_epochs', type=int, default=100000, help='Number of epochs to train for.')
    p.add_argument('--clip_grad', default=0.0, type=float, help='Clip gradient.')
    p.add_argument('--use_lbfgs', default=False, type=bool, help='use L-BFGS.')
    p.add_argument('--adj_rel_grads', default=True, type=bool, help='adjust the relative magnitude of the losses')
    p.add_argument('--dirichlet_loss_divisor', default=1.0, required=False, type=float, help='What to divide the dirichlet loss by for loss reweighting')

    # cost-supervised learning (CSL) options
    p.add_argument('--use_CSL', default=False, action='store_true', help='use cost-supervised learning (CSL)')
    p.add_argument('--CSL_lr', type=float, default=2e-5, help='The learning rate used for CSL')
    p.add_argument('--CSL_dt', type=float, default=0.0025, help='The dt used in rolling out trajectories to get cost labels')
    p.add_argument('--epochs_til_CSL', type=int, default=10000, help='Number of epochs between CSL phases')
    p.add_argument('--num_CSL_samples', type=int, default=1000000, help='Number of cost samples in training dataset for CSL phases')
    p.add_argument('--CSL_loss_frac_cutoff', type=float, default=0.1, help='Fraction of initial cost loss on validation dataset to cutoff CSL phases')
    p.add_argument('--max_CSL_epochs', type=int, default=100, help='Max number of CSL epochs per phase')
    p.add_argument('--CSL_loss_weight', type=float, default=1.0, help='weight of cost loss (relative to PDE loss)')
    p.add_argument('--CSL_batch_size', type=int, default=1000, help='Batch size for training in CSL phases')

    # validation (during training) options
    p.add_argument('--val_x_resolution', type=int, default=200, help='x-axis resolution of validation plot during training')
    p.add_argument('--val_y_resolution', type=int, default=200, help='y-axis resolution of validation plot during training')
    p.add_argument('--val_z_resolution', type=int, default=5, help='z-axis resolution of validation plot during training')
    p.add_argument('--val_time_resolution', type=int, default=3, help='time-axis resolution of validation plot during training')

    # loss options
    p.add_argument('--minWith', type=str, required=True, choices=['none', 'zero', 'target'], help='BRS vs BRT computation (typically should be using target for BRT)')

    # # load dynamics_class choices dynamically from dynamics module
    # dynamics_classes_dict = {name: clss for name, clss in inspect.getmembers(dynamics, inspect.isclass) if clss.__bases__[0] == dynamics.Dynamics}
    p.add_argument('--dynamics_class', type=str, required=True, choices=["PlanarRobot2D"], help='Dynamics class to use.')
    # # load special dynamics_class arguments dynamically from chosen dynamics class
    # dynamics_class = dynamics_classes_dict[p.parse_known_args()[0].dynamics_class]
    # dynamics_params = {name: param for name, param in inspect.signature(dynamics_class).parameters.items() if name != 'self'}
    dynamics_class = PlanarRobot2D
    dynamics_params = {
        name: param for name, param in inspect.signature(dynamics_class).parameters.items() if name != 'self'
    }

    for param in dynamics_params.keys():
        if dynamics_params[param].annotation is bool:
            p.add_argument('--' + param, type=dynamics_params[param].annotation, default=False, help='special dynamics_class argument')
        else:
            p.add_argument('--' + param, type=dynamics_params[param].annotation, required=True, help='special dynamics_class argument')

if (mode == 'all') or (mode == 'test'):
    p.add_argument('--dt', type=float, default=0.0025, help='The dt used in testing simulations')
    p.add_argument('--checkpoint_toload', type=int, default=None, help="The checkpoint to load for testing (-1 for final training checkpoint, None for cross-checkpoint testing")
    p.add_argument('--num_scenarios', type=int, default=100000, help='The number of scenarios sampled in scenario optimization for testing')
    p.add_argument('--num_violations', type=int, default=1000, help='The number of violations to sample for in scenario optimization for testing')
    p.add_argument('--control_type', type=str, default='value', choices=['value', 'ttr', 'init_ttr'], help='The controller to use in scenario optimization for testing')
    p.add_argument('--data_step', type=str, default='run_basic_recovery', choices=['plot_violations', 'run_basic_recovery', 'plot_basic_recovery', 'collect_samples', 'train_binner', 'run_binned_recovery', 'plot_binned_recovery', 'plot_cost_function'], help='The data processing step to run')

opt = p.parse_args()

# Resolve and validate the experiment directory BEFORE starting a wandb run,
# so that declining the overwrite prompt (or a missing directory in test mode)
# does not leave an orphan run behind.
experiment_dir = os.path.join(opt.experiments_dir, opt.experiment_name)
if (mode == 'all') or (mode == 'train'):
    # create experiment dir
    if os.path.exists(experiment_dir):
        overwrite = input("The experiment directory %s already exists. Overwrite? (y/n)"%experiment_dir)
        if not (overwrite == 'y'):
            print('Exiting.')
            # quit() is meant for the interactive shell; in a notebook it can
            # shut the kernel down and lose every definition above.
            # SystemExit aborts only this cell (and still ends a plain script).
            raise SystemExit
        shutil.rmtree(experiment_dir)
    os.makedirs(experiment_dir)
elif mode == 'test':
    # confirm that experiment dir already exists
    if not os.path.exists(experiment_dir):
        raise RuntimeError('Cannot run test mode: experiment directory not found!')

# start wandb
if use_wandb:
    wandb.init(
        project = opt.wandb_project,
        entity = opt.wandb_entity,
        group = opt.wandb_group,
        name = opt.wandb_name,
    )
    # the parsed options are recorded as they were given, i.e. before
    # counter_end is resolved further down
    wandb.config.update(opt)

current_time = datetime.now()
# log current config: one "name = value" line per option, in a file stamped to the minute
with open(os.path.join(experiment_dir, 'config_%s.txt' % current_time.strftime('%m_%d_%Y_%H_%M')), 'w') as f:
    for arg, val in vars(opt).items():
        f.write(arg + ' = ' + str(val) + '\n')

if (mode == 'all') or (mode == 'train'):
    # counter_end == -1 is a sentinel for "use the total number of epochs"
    if opt.counter_end == -1:
        opt.counter_end = opt.num_epochs

    # log original options
    with open(os.path.join(experiment_dir, 'orig_opt.pickle'), 'wb') as opt_file:
        pickle.dump(opt, opt_file)

# load original experiment settings (in test mode these come from the training run)
# NOTE: pickle.load executes arbitrary code; only point this at experiment
# directories you created yourself.
with open(os.path.join(experiment_dir, 'orig_opt.pickle'), 'rb') as opt_file:
    orig_opt = pickle.load(opt_file)

# set the experiment seed (always the seed of the original training run)
torch.manual_seed(orig_opt.seed)
random.seed(orig_opt.seed)
np.random.seed(orig_opt.seed)

# `--device` is only registered for the 'train' and 'all' modes, so in 'test'
# mode `opt` has no `device` attribute. Fall back to the device stored with the
# original training options instead of failing with AttributeError.
run_device = opt.device if hasattr(opt, 'device') else orig_opt.device

# Build the dynamics from the original options: every constructor parameter is
# looked up by name on orig_opt.
dynamics_class = PlanarRobot2D
dynamics_kwargs = {
    argname: getattr(orig_opt, argname)
    for argname in inspect.signature(dynamics_class).parameters.keys()
    if argname != 'self'
}
dynamics = dynamics_class(**dynamics_kwargs)
dynamics.deepreach_model = orig_opt.deepreach_model

dataset = ReachabilityDataset(
    dynamics=dynamics, numpoints=orig_opt.numpoints,
    pretrain=orig_opt.pretrain, pretrain_iters=orig_opt.pretrain_iters,
    tMin=orig_opt.tMin, tMax=orig_opt.tMax,
    counter_start=orig_opt.counter_start, counter_end=orig_opt.counter_end,
    num_src_samples=orig_opt.num_src_samples, num_target_samples=orig_opt.num_target_samples)

model = SingleBVPNet(in_features=dynamics.input_dim, out_features=1, type=orig_opt.model, mode=orig_opt.model_mode,
                     final_layer_factor=1., hidden_features=orig_opt.num_nl, num_hidden_layers=orig_opt.num_hl)
model.to(run_device)

# The experiment class is fixed to DeepReach; its init_special() receives the
# original options named in its signature (none for DeepReach itself).
experiment_class = DeepReach
experiment = experiment_class(model=model, dataset=dataset, experiment_dir=experiment_dir, use_wandb=use_wandb)
special_kwargs = {
    argname: getattr(orig_opt, argname)
    for argname in inspect.signature(experiment_class.init_special).parameters.keys()
    if argname != 'self'
}
experiment.init_special(**special_kwargs)

if (mode == 'all') or (mode == 'train'):
    # pick the loss constructor that matches the dynamics' declared loss type
    if dynamics.loss_type == 'brt_hjivi':
        loss_fn = init_brt_hjivi_loss(dynamics, orig_opt.minWith, orig_opt.dirichlet_loss_divisor)
    elif dynamics.loss_type == 'brat_hjivi':
        loss_fn = init_brat_hjivi_loss(dynamics, orig_opt.minWith, orig_opt.dirichlet_loss_divisor)
    else:
        raise NotImplementedError('Unsupported dynamics.loss_type: %r' % (dynamics.loss_type,))
    experiment.train(
        device=run_device, batch_size=orig_opt.batch_size, epochs=orig_opt.num_epochs, lr=orig_opt.lr,
        steps_til_summary=orig_opt.steps_til_summary, epochs_til_checkpoint=orig_opt.epochs_til_ckpt,
        loss_fn=loss_fn, clip_grad=orig_opt.clip_grad, use_lbfgs=orig_opt.use_lbfgs,
        adjust_relative_grads=orig_opt.adj_rel_grads,
        val_x_resolution=orig_opt.val_x_resolution, val_y_resolution=orig_opt.val_y_resolution,
        val_z_resolution=orig_opt.val_z_resolution, val_time_resolution=orig_opt.val_time_resolution,
        use_CSL=orig_opt.use_CSL, CSL_lr=orig_opt.CSL_lr, CSL_dt=orig_opt.CSL_dt,
        epochs_til_CSL=orig_opt.epochs_til_CSL, num_CSL_samples=orig_opt.num_CSL_samples,
        CSL_loss_frac_cutoff=orig_opt.CSL_loss_frac_cutoff, max_CSL_epochs=orig_opt.max_CSL_epochs,
        CSL_loss_weight=orig_opt.CSL_loss_weight, CSL_batch_size=orig_opt.CSL_batch_size)

if (mode == 'all') or (mode == 'test'):
    # training-time settings come from orig_opt, test-time settings from opt
    experiment.test(
        device=run_device, current_time=current_time,
        last_checkpoint=orig_opt.num_epochs, checkpoint_dt=orig_opt.epochs_til_ckpt,
        checkpoint_toload=opt.checkpoint_toload, dt=opt.dt,
        num_scenarios=opt.num_scenarios, num_violations=opt.num_violations,
        set_type='BRT' if orig_opt.minWith in ['zero', 'target'] else 'BRS', control_type=opt.control_type, data_step=opt.data_step)

SingleBVPNet(
  (net): FCBlock(
    (net): Sequential(
      (0): Sequential(
        (0): BatchLinear(in_features=3, out_features=512, bias=True)
        (1): Sine()
      )
      (1): Sequential(
        (0): BatchLinear(in_features=512, out_features=512, bias=True)
        (1): Sine()
      )
      (2): Sequential(
        (0): BatchLinear(in_features=512, out_features=512, bias=True)
        (1): Sine()
      )
      (3): Sequential(
        (0): BatchLinear(in_features=512, out_features=512, bias=True)
        (1): Sine()
      )
      (4): Sequential(
        (0): BatchLinear(in_features=512, out_features=1, bias=True)
      )
    )
  )
)


  0%|          | 1/40000 [00:01<15:53:16,  1.43s/it]

Epoch 0, Total loss 2817.822021, iteration time 1.395628


  0%|          | 102/40000 [00:23<2:25:59,  4.55it/s]

Epoch 100, Total loss 219.990967, iteration time 0.221183


  1%|          | 202/40000 [00:44<2:26:28,  4.53it/s]

Epoch 200, Total loss 7.243275, iteration time 0.222728


  1%|          | 302/40000 [01:06<2:29:38,  4.42it/s]

Epoch 300, Total loss 4.085196, iteration time 0.231560


  1%|          | 402/40000 [01:29<2:29:50,  4.40it/s]

Epoch 400, Total loss 6.817874, iteration time 0.223333


  1%|▏         | 502/40000 [01:51<2:33:23,  4.29it/s]

Epoch 500, Total loss 10.895412, iteration time 0.235565


  2%|▏         | 602/40000 [02:14<2:35:37,  4.22it/s]

Epoch 600, Total loss 13.368431, iteration time 0.240972


  2%|▏         | 702/40000 [02:37<2:32:07,  4.31it/s]

Epoch 700, Total loss 15.966686, iteration time 0.234381


  2%|▏         | 802/40000 [03:00<2:31:40,  4.31it/s]

Epoch 800, Total loss 21.346594, iteration time 0.231023


  2%|▏         | 902/40000 [03:22<2:30:06,  4.34it/s]

Epoch 900, Total loss 23.050741, iteration time 0.235264


  3%|▎         | 1002/40000 [03:46<4:14:27,  2.55it/s]

Epoch 1000, Total loss 28.613506, iteration time 0.096495


  3%|▎         | 1102/40000 [04:09<2:33:22,  4.23it/s]

Epoch 1100, Total loss 28.930210, iteration time 0.239513


  3%|▎         | 1202/40000 [04:31<2:29:59,  4.31it/s]

Epoch 1200, Total loss 24.129002, iteration time 0.238394


  3%|▎         | 1302/40000 [04:54<2:30:46,  4.28it/s]

Epoch 1300, Total loss 55.264885, iteration time 0.233086


  4%|▎         | 1402/40000 [05:17<2:28:48,  4.32it/s]

Epoch 1400, Total loss 37.189682, iteration time 0.229247


  4%|▍         | 1502/40000 [05:40<2:29:35,  4.29it/s]

Epoch 1500, Total loss 61.090290, iteration time 0.245869


  4%|▍         | 1602/40000 [06:02<2:31:33,  4.22it/s]

Epoch 1600, Total loss 43.868122, iteration time 0.239611


  4%|▍         | 1702/40000 [06:25<2:29:02,  4.28it/s]

Epoch 1700, Total loss 54.180862, iteration time 0.231577


  5%|▍         | 1802/40000 [06:48<2:29:20,  4.26it/s]

Epoch 1800, Total loss 33.367905, iteration time 0.240482


  5%|▍         | 1902/40000 [07:11<2:28:36,  4.27it/s]

Epoch 1900, Total loss 57.654694, iteration time 0.226864


  5%|▌         | 2002/40000 [07:34<3:48:32,  2.77it/s]

Epoch 2000, Total loss 72.086365, iteration time 0.107411


  5%|▌         | 2102/40000 [07:57<2:27:42,  4.28it/s]

Epoch 2100, Total loss 56.818802, iteration time 0.234466


  6%|▌         | 2202/40000 [08:20<2:29:42,  4.21it/s]

Epoch 2200, Total loss 68.335808, iteration time 0.243784


  6%|▌         | 2302/40000 [08:42<2:26:58,  4.27it/s]

Epoch 2300, Total loss 67.170265, iteration time 0.235011


  6%|▌         | 2402/40000 [09:05<2:27:19,  4.25it/s]

Epoch 2400, Total loss 72.707733, iteration time 0.234453


  6%|▋         | 2502/40000 [09:28<2:27:21,  4.24it/s]

Epoch 2500, Total loss 52.796455, iteration time 0.244592


  7%|▋         | 2602/40000 [09:51<2:25:11,  4.29it/s]

Epoch 2600, Total loss 70.123520, iteration time 0.233879


  7%|▋         | 2702/40000 [10:14<2:24:48,  4.29it/s]

Epoch 2700, Total loss 49.686676, iteration time 0.234061


  7%|▋         | 2802/40000 [10:36<2:27:31,  4.20it/s]

Epoch 2800, Total loss 48.602005, iteration time 0.246720


  7%|▋         | 2902/40000 [10:59<2:23:00,  4.32it/s]

Epoch 2900, Total loss 60.628403, iteration time 0.231926


  8%|▊         | 3002/40000 [11:23<4:01:39,  2.55it/s]

Epoch 3000, Total loss 55.661774, iteration time 0.113386


  8%|▊         | 3102/40000 [11:45<2:23:41,  4.28it/s]

Epoch 3100, Total loss 58.595314, iteration time 0.234484


  8%|▊         | 3202/40000 [12:08<2:23:30,  4.27it/s]

Epoch 3200, Total loss 80.446144, iteration time 0.237259


  8%|▊         | 3302/40000 [12:31<2:22:42,  4.29it/s]

Epoch 3300, Total loss 75.966156, iteration time 0.238693


  9%|▊         | 3402/40000 [12:54<2:21:43,  4.30it/s]

Epoch 3400, Total loss 57.590050, iteration time 0.246374


  9%|▉         | 3502/40000 [13:16<2:22:02,  4.28it/s]

Epoch 3500, Total loss 64.050186, iteration time 0.237018


  9%|▉         | 3602/40000 [13:39<2:26:26,  4.14it/s]

Epoch 3600, Total loss 50.104610, iteration time 0.254594


  9%|▉         | 3702/40000 [14:02<2:19:50,  4.33it/s]

Epoch 3700, Total loss 50.366352, iteration time 0.233323


 10%|▉         | 3802/40000 [14:25<2:19:07,  4.34it/s]

Epoch 3800, Total loss 49.920654, iteration time 0.235912


 10%|▉         | 3902/40000 [14:47<2:19:42,  4.31it/s]

Epoch 3900, Total loss 41.173309, iteration time 0.238039


 10%|█         | 4002/40000 [15:11<3:25:21,  2.92it/s]

Epoch 4000, Total loss 58.176147, iteration time 0.094923


 10%|█         | 4102/40000 [15:33<2:18:34,  4.32it/s]

Epoch 4100, Total loss 54.093719, iteration time 0.236010


 11%|█         | 4202/40000 [15:56<2:22:04,  4.20it/s]

Epoch 4200, Total loss 64.750778, iteration time 0.249045


 11%|█         | 4302/40000 [16:19<2:17:33,  4.33it/s]

Epoch 4300, Total loss 59.316818, iteration time 0.235753


 11%|█         | 4402/40000 [16:42<2:21:15,  4.20it/s]

Epoch 4400, Total loss 85.767426, iteration time 0.254475


 11%|█▏        | 4502/40000 [17:05<2:17:10,  4.31it/s]

Epoch 4500, Total loss 72.771065, iteration time 0.237519


 12%|█▏        | 4602/40000 [17:27<2:16:15,  4.33it/s]

Epoch 4600, Total loss 57.073769, iteration time 0.228973


 12%|█▏        | 4702/40000 [17:50<2:18:41,  4.24it/s]

Epoch 4700, Total loss 55.400360, iteration time 0.242521


 12%|█▏        | 4802/40000 [18:13<2:16:46,  4.29it/s]

Epoch 4800, Total loss 72.964569, iteration time 0.235226


 12%|█▏        | 4902/40000 [18:36<2:17:50,  4.24it/s]

Epoch 4900, Total loss 81.723724, iteration time 0.239167


 13%|█▎        | 5002/40000 [18:59<3:49:49,  2.54it/s]

Epoch 5000, Total loss 52.496445, iteration time 0.103832


 13%|█▎        | 5102/40000 [19:22<2:15:52,  4.28it/s]

Epoch 5100, Total loss 56.614189, iteration time 0.241461


 13%|█▎        | 5202/40000 [19:45<2:18:10,  4.20it/s]

Epoch 5200, Total loss 93.523422, iteration time 0.251094


 13%|█▎        | 5302/40000 [20:07<2:15:15,  4.28it/s]

Epoch 5300, Total loss 67.786697, iteration time 0.244226


 14%|█▎        | 5402/40000 [20:30<2:14:35,  4.28it/s]

Epoch 5400, Total loss 51.263859, iteration time 0.236355


 14%|█▍        | 5502/40000 [20:53<2:13:03,  4.32it/s]

Epoch 5500, Total loss 81.297386, iteration time 0.237618


 14%|█▍        | 5602/40000 [21:16<2:14:33,  4.26it/s]

Epoch 5600, Total loss 77.485214, iteration time 0.259334


 14%|█▍        | 5702/40000 [21:38<2:12:43,  4.31it/s]

Epoch 5700, Total loss 65.399040, iteration time 0.238080


 15%|█▍        | 5802/40000 [22:01<2:13:45,  4.26it/s]

Epoch 5800, Total loss 77.356873, iteration time 0.238760


 15%|█▍        | 5902/40000 [22:24<2:11:52,  4.31it/s]

Epoch 5900, Total loss 79.610641, iteration time 0.235593


 15%|█▌        | 6002/40000 [22:47<3:14:39,  2.91it/s]

Epoch 6000, Total loss 85.361610, iteration time 0.095567


 15%|█▌        | 6102/40000 [23:10<2:12:29,  4.26it/s]

Epoch 6100, Total loss 72.542984, iteration time 0.239195


 16%|█▌        | 6202/40000 [23:33<2:10:11,  4.33it/s]

Epoch 6200, Total loss 99.153488, iteration time 0.234835


 16%|█▌        | 6302/40000 [23:55<2:10:17,  4.31it/s]

Epoch 6300, Total loss 71.612411, iteration time 0.235656


 16%|█▌        | 6402/40000 [24:18<2:11:15,  4.27it/s]

Epoch 6400, Total loss 68.526871, iteration time 0.245173


 16%|█▋        | 6502/40000 [24:41<2:10:04,  4.29it/s]

Epoch 6500, Total loss 88.659073, iteration time 0.245777


 17%|█▋        | 6602/40000 [25:04<2:09:35,  4.30it/s]

Epoch 6600, Total loss 56.340637, iteration time 0.238552


 17%|█▋        | 6702/40000 [25:26<2:10:52,  4.24it/s]

Epoch 6700, Total loss 111.383659, iteration time 0.248721


 17%|█▋        | 6802/40000 [25:49<2:07:42,  4.33it/s]

Epoch 6800, Total loss 92.984863, iteration time 0.238152


 17%|█▋        | 6902/40000 [26:12<2:07:35,  4.32it/s]

Epoch 6900, Total loss 95.064453, iteration time 0.237861


 18%|█▊        | 7002/40000 [26:35<3:39:44,  2.50it/s]

Epoch 7000, Total loss 106.084541, iteration time 0.103853


 18%|█▊        | 7102/40000 [26:58<2:06:52,  4.32it/s]

Epoch 7100, Total loss 117.329514, iteration time 0.229818


 18%|█▊        | 7202/40000 [27:21<2:07:38,  4.28it/s]

Epoch 7200, Total loss 99.308113, iteration time 0.237351


 18%|█▊        | 7302/40000 [27:43<2:08:00,  4.26it/s]

Epoch 7300, Total loss 110.602768, iteration time 0.240211


 19%|█▊        | 7402/40000 [28:06<2:07:02,  4.28it/s]

Epoch 7400, Total loss 64.019455, iteration time 0.239431


 19%|█▉        | 7502/40000 [28:29<2:05:59,  4.30it/s]

Epoch 7500, Total loss 51.818611, iteration time 0.240448


 19%|█▉        | 7602/40000 [28:52<2:07:29,  4.24it/s]

Epoch 7600, Total loss 66.712807, iteration time 0.248521


 19%|█▉        | 7702/40000 [29:14<2:05:13,  4.30it/s]

Epoch 7700, Total loss 119.937866, iteration time 0.236657


 20%|█▉        | 7802/40000 [29:37<2:04:03,  4.33it/s]

Epoch 7800, Total loss 61.171700, iteration time 0.237389


 20%|█▉        | 7902/40000 [30:00<2:05:15,  4.27it/s]

Epoch 7900, Total loss 52.633476, iteration time 0.238660


 20%|██        | 8002/40000 [30:23<3:04:25,  2.89it/s]

Epoch 8000, Total loss 66.848526, iteration time 0.099323


 20%|██        | 8102/40000 [30:46<2:04:10,  4.28it/s]

Epoch 8100, Total loss 107.534561, iteration time 0.235430


 21%|██        | 8202/40000 [31:08<2:03:14,  4.30it/s]

Epoch 8200, Total loss 100.663803, iteration time 0.236523


 21%|██        | 8302/40000 [31:31<2:02:18,  4.32it/s]

Epoch 8300, Total loss 84.260117, iteration time 0.237783


 21%|██        | 8402/40000 [31:54<2:04:03,  4.24it/s]

Epoch 8400, Total loss 117.388908, iteration time 0.246680


 21%|██▏       | 8502/40000 [32:17<2:01:47,  4.31it/s]

Epoch 8500, Total loss 71.769333, iteration time 0.236685


 22%|██▏       | 8602/40000 [32:39<2:01:11,  4.32it/s]

Epoch 8600, Total loss 107.827065, iteration time 0.237052


 22%|██▏       | 8702/40000 [33:02<2:03:23,  4.23it/s]

Epoch 8700, Total loss 152.617142, iteration time 0.242337


 22%|██▏       | 8802/40000 [33:25<2:00:11,  4.33it/s]

Epoch 8800, Total loss 89.044830, iteration time 0.238173


 22%|██▏       | 8902/40000 [33:48<2:01:22,  4.27it/s]

Epoch 8900, Total loss 66.372826, iteration time 0.236705


 23%|██▎       | 9002/40000 [34:11<3:28:19,  2.48it/s]

Epoch 9000, Total loss 132.541153, iteration time 0.115783


 23%|██▎       | 9102/40000 [34:34<1:59:25,  4.31it/s]

Epoch 9100, Total loss 118.179939, iteration time 0.237157


 23%|██▎       | 9202/40000 [34:57<1:58:47,  4.32it/s]

Epoch 9200, Total loss 106.515923, iteration time 0.241896


 23%|██▎       | 9302/40000 [35:19<2:00:28,  4.25it/s]

Epoch 9300, Total loss 180.582520, iteration time 0.235791


 24%|██▎       | 9402/40000 [35:42<1:58:24,  4.31it/s]

Epoch 9400, Total loss 47.404083, iteration time 0.236114


 24%|██▍       | 9502/40000 [36:05<1:57:27,  4.33it/s]

Epoch 9500, Total loss 116.499687, iteration time 0.236265


 24%|██▍       | 9602/40000 [36:28<1:58:49,  4.26it/s]

Epoch 9600, Total loss 90.575043, iteration time 0.240023


 24%|██▍       | 9702/40000 [36:50<1:56:39,  4.33it/s]

Epoch 9700, Total loss 54.128067, iteration time 0.241925


 25%|██▍       | 9802/40000 [37:13<1:56:13,  4.33it/s]

Epoch 9800, Total loss 138.051941, iteration time 0.235657


 25%|██▍       | 9902/40000 [37:36<1:57:35,  4.27it/s]

Epoch 9900, Total loss 151.827820, iteration time 0.239920


 25%|██▌       | 10002/40000 [37:59<3:36:24,  2.31it/s]

Epoch 10000, Total loss 92.714691, iteration time 0.108822


 25%|██▌       | 10102/40000 [38:22<1:56:16,  4.29it/s]

Epoch 10100, Total loss 148.415527, iteration time 0.231377


 26%|██▌       | 10202/40000 [38:45<1:54:40,  4.33it/s]

Epoch 10200, Total loss 73.214760, iteration time 0.235852


 26%|██▌       | 10302/40000 [39:08<1:54:46,  4.31it/s]

Epoch 10300, Total loss 66.505432, iteration time 0.235204


 26%|██▌       | 10402/40000 [39:30<1:56:39,  4.23it/s]

Epoch 10400, Total loss 103.732178, iteration time 0.230792


 26%|██▋       | 10502/40000 [39:53<1:54:22,  4.30it/s]

Epoch 10500, Total loss 90.789230, iteration time 0.236264


 27%|██▋       | 10602/40000 [40:16<1:53:59,  4.30it/s]

Epoch 10600, Total loss 49.457283, iteration time 0.241551


 27%|██▋       | 10702/40000 [40:38<1:54:24,  4.27it/s]

Epoch 10700, Total loss 144.809158, iteration time 0.244074


 27%|██▋       | 10802/40000 [41:01<1:53:40,  4.28it/s]

Epoch 10800, Total loss 68.124283, iteration time 0.235677


 27%|██▋       | 10902/40000 [41:24<1:53:26,  4.28it/s]

Epoch 10900, Total loss 53.681175, iteration time 0.243179


 28%|██▊       | 11002/40000 [41:48<3:14:27,  2.49it/s]

Epoch 11000, Total loss 73.592079, iteration time 0.101164


 28%|██▊       | 11102/40000 [42:10<1:51:20,  4.33it/s]

Epoch 11100, Total loss 96.817413, iteration time 0.240860


 28%|██▊       | 11202/40000 [42:33<1:50:51,  4.33it/s]

Epoch 11200, Total loss 86.221756, iteration time 0.234702


 28%|██▊       | 11302/40000 [42:56<1:52:51,  4.24it/s]

Epoch 11300, Total loss 132.267517, iteration time 0.239508


 29%|██▊       | 11402/40000 [43:19<1:50:36,  4.31it/s]

Epoch 11400, Total loss 92.217903, iteration time 0.236128


 29%|██▉       | 11502/40000 [43:41<1:51:31,  4.26it/s]

Epoch 11500, Total loss 73.588120, iteration time 0.239594


 29%|██▉       | 11602/40000 [44:04<1:51:57,  4.23it/s]

Epoch 11600, Total loss 105.795380, iteration time 0.247946


 29%|██▉       | 11702/40000 [44:27<1:51:06,  4.24it/s]

Epoch 11700, Total loss 157.014236, iteration time 0.241065


 30%|██▉       | 11802/40000 [44:50<1:49:15,  4.30it/s]

Epoch 11800, Total loss 170.414780, iteration time 0.243579


 30%|██▉       | 11902/40000 [45:12<1:49:37,  4.27it/s]

Epoch 11900, Total loss 162.678345, iteration time 0.234627


 30%|███       | 12002/40000 [45:36<2:40:16,  2.91it/s]

Epoch 12000, Total loss 154.559113, iteration time 0.095380


 30%|███       | 12102/40000 [45:58<1:48:47,  4.27it/s]

Epoch 12100, Total loss 137.836151, iteration time 0.239300


 31%|███       | 12202/40000 [46:21<1:49:19,  4.24it/s]

Epoch 12200, Total loss 136.593597, iteration time 0.233598


 31%|███       | 12302/40000 [46:44<1:47:46,  4.28it/s]

Epoch 12300, Total loss 147.896927, iteration time 0.238690


 31%|███       | 12402/40000 [47:06<1:46:55,  4.30it/s]

Epoch 12400, Total loss 149.679138, iteration time 0.236651


 31%|███▏      | 12502/40000 [47:29<1:49:03,  4.20it/s]

Epoch 12500, Total loss 167.267120, iteration time 0.245479


 32%|███▏      | 12602/40000 [47:52<1:45:27,  4.33it/s]

Epoch 12600, Total loss 151.684723, iteration time 0.240782


 32%|███▏      | 12702/40000 [48:15<1:45:41,  4.30it/s]

Epoch 12700, Total loss 149.770859, iteration time 0.239899


 32%|███▏      | 12802/40000 [48:37<1:47:18,  4.22it/s]

Epoch 12800, Total loss 144.624039, iteration time 0.242434


 32%|███▏      | 12902/40000 [49:00<1:44:19,  4.33it/s]

Epoch 12900, Total loss 153.646301, iteration time 0.235232


 33%|███▎      | 13002/40000 [49:23<2:35:23,  2.90it/s]

Epoch 13000, Total loss 160.098770, iteration time 0.098516


 33%|███▎      | 13102/40000 [49:46<1:43:57,  4.31it/s]

Epoch 13100, Total loss 150.226288, iteration time 0.237629


 33%|███▎      | 13202/40000 [50:09<1:43:10,  4.33it/s]

Epoch 13200, Total loss 157.488647, iteration time 0.235010


 33%|███▎      | 13302/40000 [50:32<1:44:10,  4.27it/s]

Epoch 13300, Total loss 153.030670, iteration time 0.237689


 34%|███▎      | 13402/40000 [50:54<1:43:25,  4.29it/s]

Epoch 13400, Total loss 156.222321, iteration time 0.242663


 34%|███▍      | 13502/40000 [51:17<1:42:23,  4.31it/s]

Epoch 13500, Total loss 150.037323, iteration time 0.234866


 34%|███▍      | 13602/40000 [51:40<1:42:18,  4.30it/s]

Epoch 13600, Total loss 155.878571, iteration time 0.239953


 34%|███▍      | 13702/40000 [52:02<1:41:48,  4.31it/s]

Epoch 13700, Total loss 151.892853, iteration time 0.238709


 35%|███▍      | 13802/40000 [52:25<1:41:03,  4.32it/s]

Epoch 13800, Total loss 170.864273, iteration time 0.234041


 35%|███▍      | 13902/40000 [52:48<1:42:10,  4.26it/s]

Epoch 13900, Total loss 183.908447, iteration time 0.241405


 35%|███▌      | 14002/40000 [53:11<2:30:03,  2.89it/s]

Epoch 14000, Total loss 154.846893, iteration time 0.096395


 35%|███▌      | 14102/40000 [53:34<1:41:31,  4.25it/s]

Epoch 14100, Total loss 165.542526, iteration time 0.243828


 36%|███▌      | 14202/40000 [53:57<1:42:21,  4.20it/s]

Epoch 14200, Total loss 157.581543, iteration time 0.245763


 36%|███▌      | 14302/40000 [54:19<1:39:27,  4.31it/s]

Epoch 14300, Total loss 166.836411, iteration time 0.237076


 36%|███▌      | 14402/40000 [54:42<1:38:50,  4.32it/s]

Epoch 14400, Total loss 158.443115, iteration time 0.241874


 36%|███▋      | 14502/40000 [55:05<1:40:32,  4.23it/s]

Epoch 14500, Total loss 165.637512, iteration time 0.250259


 37%|███▋      | 14602/40000 [55:28<1:38:51,  4.28it/s]

Epoch 14600, Total loss 163.113007, iteration time 0.236534


 37%|███▋      | 14702/40000 [55:50<1:37:28,  4.33it/s]

Epoch 14700, Total loss 215.083420, iteration time 0.239758


 37%|███▋      | 14802/40000 [56:13<1:39:39,  4.21it/s]

Epoch 14800, Total loss 71.082466, iteration time 0.242603


 37%|███▋      | 14902/40000 [56:36<1:36:58,  4.31it/s]

Epoch 14900, Total loss 60.343845, iteration time 0.235643


 38%|███▊      | 15002/40000 [56:59<2:23:20,  2.91it/s]

Epoch 15000, Total loss 84.234764, iteration time 0.095285


 38%|███▊      | 15102/40000 [57:22<1:38:17,  4.22it/s]

Epoch 15100, Total loss 91.620132, iteration time 0.251218


 38%|███▊      | 15202/40000 [57:45<1:36:40,  4.28it/s]

Epoch 15200, Total loss 109.592674, iteration time 0.236212


 38%|███▊      | 15302/40000 [58:08<1:36:24,  4.27it/s]

Epoch 15300, Total loss 64.137924, iteration time 0.250713


 39%|███▊      | 15402/40000 [58:30<1:36:20,  4.26it/s]

Epoch 15400, Total loss 162.861160, iteration time 0.237625


 39%|███▉      | 15502/40000 [58:53<1:34:51,  4.30it/s]

Epoch 15500, Total loss 126.797127, iteration time 0.236989


 39%|███▉      | 15602/40000 [59:16<1:35:21,  4.26it/s]

Epoch 15600, Total loss 70.961105, iteration time 0.240182


 39%|███▉      | 15702/40000 [59:39<1:36:00,  4.22it/s]

Epoch 15700, Total loss 1108.996826, iteration time 0.244985


 40%|███▉      | 15802/40000 [1:00:01<1:33:41,  4.30it/s]

Epoch 15800, Total loss 86.582527, iteration time 0.236611


 40%|███▉      | 15902/40000 [1:00:24<1:34:02,  4.27it/s]

Epoch 15900, Total loss 98.234512, iteration time 0.241433


 40%|████      | 16002/40000 [1:00:48<2:39:52,  2.50it/s]

Epoch 16000, Total loss 71.438629, iteration time 0.107660


 40%|████      | 16102/40000 [1:01:10<1:32:31,  4.31it/s]

Epoch 16100, Total loss 138.945648, iteration time 0.243863


 41%|████      | 16202/40000 [1:01:33<1:32:38,  4.28it/s]

Epoch 16200, Total loss 287.744995, iteration time 0.236689


 41%|████      | 16302/40000 [1:01:56<1:32:49,  4.26it/s]

Epoch 16300, Total loss 169.426224, iteration time 0.235868


 41%|████      | 16402/40000 [1:02:19<1:31:14,  4.31it/s]

Epoch 16400, Total loss 189.952194, iteration time 0.238681


 41%|████▏     | 16502/40000 [1:02:41<1:30:02,  4.35it/s]

Epoch 16500, Total loss 174.909271, iteration time 0.234065


 42%|████▏     | 16602/40000 [1:03:04<1:30:06,  4.33it/s]

Epoch 16600, Total loss 178.341003, iteration time 0.239259


 42%|████▏     | 16702/40000 [1:03:27<1:29:59,  4.31it/s]

Epoch 16700, Total loss 184.298798, iteration time 0.244132


 42%|████▏     | 16802/40000 [1:03:49<1:30:20,  4.28it/s]

Epoch 16800, Total loss 210.397552, iteration time 0.235781


 42%|████▏     | 16902/40000 [1:04:12<1:30:35,  4.25it/s]

Epoch 16900, Total loss 152.063309, iteration time 0.238446


 43%|████▎     | 17002/40000 [1:04:35<2:12:37,  2.89it/s]

Epoch 17000, Total loss 186.473145, iteration time 0.096053


 43%|████▎     | 17102/40000 [1:04:58<1:29:21,  4.27it/s]

Epoch 17100, Total loss 189.067841, iteration time 0.252898


 43%|████▎     | 17202/40000 [1:05:20<1:30:27,  4.20it/s]

Epoch 17200, Total loss 187.995834, iteration time 0.247228


 43%|████▎     | 17302/40000 [1:05:43<1:27:38,  4.32it/s]

Epoch 17300, Total loss 178.230606, iteration time 0.234570


 44%|████▎     | 17402/40000 [1:06:06<1:27:47,  4.29it/s]

Epoch 17400, Total loss 183.333282, iteration time 0.234650


 44%|████▍     | 17502/40000 [1:06:28<1:27:17,  4.30it/s]

Epoch 17500, Total loss 173.672882, iteration time 0.233167


 44%|████▍     | 17602/40000 [1:06:51<1:26:52,  4.30it/s]

Epoch 17600, Total loss 204.696869, iteration time 0.235510


 44%|████▍     | 17702/40000 [1:07:14<1:26:29,  4.30it/s]

Epoch 17700, Total loss 182.594788, iteration time 0.246576


 45%|████▍     | 17802/40000 [1:07:36<1:25:53,  4.31it/s]

Epoch 17800, Total loss 175.469482, iteration time 0.241266


 45%|████▍     | 17902/40000 [1:07:59<1:25:46,  4.29it/s]

Epoch 17900, Total loss 184.616577, iteration time 0.233874


 45%|████▌     | 18002/40000 [1:08:22<2:36:27,  2.34it/s]

Epoch 18000, Total loss 184.925949, iteration time 0.104036


 45%|████▌     | 18102/40000 [1:08:45<1:25:18,  4.28it/s]

Epoch 18100, Total loss 188.882019, iteration time 0.249321


 46%|████▌     | 18202/40000 [1:09:08<1:23:38,  4.34it/s]

Epoch 18200, Total loss 182.457886, iteration time 0.237893


 46%|████▌     | 18302/40000 [1:09:30<1:24:16,  4.29it/s]

Epoch 18300, Total loss 180.546051, iteration time 0.238225


 46%|████▌     | 18402/40000 [1:09:53<1:23:00,  4.34it/s]

Epoch 18400, Total loss 190.535294, iteration time 0.238404


 46%|████▋     | 18502/40000 [1:10:16<1:22:55,  4.32it/s]

Epoch 18500, Total loss 200.262848, iteration time 0.234261


 47%|████▋     | 18602/40000 [1:10:38<1:23:58,  4.25it/s]

Epoch 18600, Total loss 181.800156, iteration time 0.251006


 47%|████▋     | 18702/40000 [1:11:01<1:22:33,  4.30it/s]

Epoch 18700, Total loss 189.664185, iteration time 0.237022


 47%|████▋     | 18802/40000 [1:11:24<1:22:07,  4.30it/s]

Epoch 18800, Total loss 181.835205, iteration time 0.244052


 47%|████▋     | 18902/40000 [1:11:46<1:21:59,  4.29it/s]

Epoch 18900, Total loss 168.342575, iteration time 0.234261


 48%|████▊     | 19002/40000 [1:12:09<1:57:53,  2.97it/s]

Epoch 19000, Total loss 199.268570, iteration time 0.093941


 48%|████▊     | 19102/40000 [1:12:32<1:20:09,  4.35it/s]

Epoch 19100, Total loss 195.018250, iteration time 0.233350


 48%|████▊     | 19202/40000 [1:12:55<1:20:47,  4.29it/s]

Epoch 19200, Total loss 182.373276, iteration time 0.235614


 48%|████▊     | 19302/40000 [1:13:17<1:20:00,  4.31it/s]

Epoch 19300, Total loss 189.964157, iteration time 0.234356


 49%|████▊     | 19402/40000 [1:13:40<1:19:12,  4.33it/s]

Epoch 19400, Total loss 194.100082, iteration time 0.242520


 49%|████▉     | 19502/40000 [1:14:03<1:19:40,  4.29it/s]

Epoch 19500, Total loss 200.273148, iteration time 0.232102


 49%|████▉     | 19602/40000 [1:14:25<1:19:39,  4.27it/s]

Epoch 19600, Total loss 178.263794, iteration time 0.237522


 49%|████▉     | 19702/40000 [1:14:48<1:18:17,  4.32it/s]

Epoch 19700, Total loss 196.304413, iteration time 0.243052


 50%|████▉     | 19802/40000 [1:15:11<1:19:24,  4.24it/s]

Epoch 19800, Total loss 204.724167, iteration time 0.239987


 50%|████▉     | 19902/40000 [1:15:33<1:18:05,  4.29it/s]

Epoch 19900, Total loss 189.834412, iteration time 0.235184


 50%|█████     | 20002/40000 [1:15:57<1:54:16,  2.92it/s]

Epoch 20000, Total loss 207.295792, iteration time 0.095985


 50%|█████     | 20102/40000 [1:16:19<1:17:44,  4.27it/s]

Epoch 20100, Total loss 205.970703, iteration time 0.241436


 51%|█████     | 20202/40000 [1:16:42<1:16:41,  4.30it/s]

Epoch 20200, Total loss 201.280914, iteration time 0.234759


 51%|█████     | 20302/40000 [1:17:05<1:15:32,  4.35it/s]

Epoch 20300, Total loss 189.182587, iteration time 0.236543


 51%|█████     | 20402/40000 [1:17:27<1:17:38,  4.21it/s]

Epoch 20400, Total loss 188.131226, iteration time 0.242912


 51%|█████▏    | 20502/40000 [1:17:50<1:15:08,  4.33it/s]

Epoch 20500, Total loss 191.883759, iteration time 0.232786


 52%|█████▏    | 20602/40000 [1:18:13<1:15:19,  4.29it/s]

Epoch 20600, Total loss 180.347076, iteration time 0.238123


 52%|█████▏    | 20702/40000 [1:18:35<1:15:26,  4.26it/s]

Epoch 20700, Total loss 191.082031, iteration time 0.237346


 52%|█████▏    | 20802/40000 [1:18:58<1:13:56,  4.33it/s]

Epoch 20800, Total loss 234.632553, iteration time 0.233053


 52%|█████▏    | 20902/40000 [1:19:21<1:14:29,  4.27it/s]

Epoch 20900, Total loss 212.362442, iteration time 0.244482


 53%|█████▎    | 21002/40000 [1:19:44<2:08:28,  2.46it/s]

Epoch 21000, Total loss 186.148880, iteration time 0.099804


 53%|█████▎    | 21102/40000 [1:20:07<1:13:39,  4.28it/s]

Epoch 21100, Total loss 189.408661, iteration time 0.234604


 53%|█████▎    | 21202/40000 [1:20:29<1:12:54,  4.30it/s]

Epoch 21200, Total loss 193.686676, iteration time 0.237188


 53%|█████▎    | 21302/40000 [1:20:52<1:12:52,  4.28it/s]

Epoch 21300, Total loss 198.130737, iteration time 0.232297


 54%|█████▎    | 21402/40000 [1:21:15<1:12:08,  4.30it/s]

Epoch 21400, Total loss 184.925323, iteration time 0.240039


 54%|█████▍    | 21502/40000 [1:21:37<1:11:35,  4.31it/s]

Epoch 21500, Total loss 217.164825, iteration time 0.233472


 54%|█████▍    | 21602/40000 [1:22:00<1:11:34,  4.28it/s]

Epoch 21600, Total loss 207.304718, iteration time 0.234982


 54%|█████▍    | 21702/40000 [1:22:23<1:10:38,  4.32it/s]

Epoch 21700, Total loss 202.748856, iteration time 0.233222


 55%|█████▍    | 21802/40000 [1:22:45<1:10:32,  4.30it/s]

Epoch 21800, Total loss 186.944977, iteration time 0.234401


 55%|█████▍    | 21902/40000 [1:23:08<1:10:04,  4.30it/s]

Epoch 21900, Total loss 197.216934, iteration time 0.244787


 55%|█████▌    | 22002/40000 [1:23:31<2:00:37,  2.49it/s]

Epoch 22000, Total loss 205.556030, iteration time 0.102317


 55%|█████▌    | 22102/40000 [1:23:54<1:10:00,  4.26it/s]

Epoch 22100, Total loss 252.917709, iteration time 0.242299


 56%|█████▌    | 22202/40000 [1:24:17<1:09:46,  4.25it/s]

Epoch 22200, Total loss 203.243256, iteration time 0.244173


 56%|█████▌    | 22302/40000 [1:24:40<1:07:38,  4.36it/s]

Epoch 22300, Total loss 193.060272, iteration time 0.225783


 56%|█████▌    | 22402/40000 [1:25:02<1:07:29,  4.35it/s]

Epoch 22400, Total loss 186.297348, iteration time 0.239145


 56%|█████▋    | 22502/40000 [1:25:25<1:09:11,  4.22it/s]

Epoch 22500, Total loss 207.015686, iteration time 0.234853


 57%|█████▋    | 22602/40000 [1:25:47<1:07:41,  4.28it/s]

Epoch 22600, Total loss 213.496826, iteration time 0.234420


 57%|█████▋    | 22702/40000 [1:26:10<1:06:10,  4.36it/s]

Epoch 22700, Total loss 209.451538, iteration time 0.242728


 57%|█████▋    | 22802/40000 [1:26:33<1:07:46,  4.23it/s]

Epoch 22800, Total loss 199.462189, iteration time 0.234916


 57%|█████▋    | 22902/40000 [1:26:55<1:05:52,  4.33it/s]

Epoch 22900, Total loss 256.720795, iteration time 0.232318


 58%|█████▊    | 23002/40000 [1:27:19<1:53:49,  2.49it/s]

Epoch 23000, Total loss 185.759979, iteration time 0.104937


 58%|█████▊    | 23102/40000 [1:27:41<1:05:40,  4.29it/s]

Epoch 23100, Total loss 215.256821, iteration time 0.233466


 58%|█████▊    | 23202/40000 [1:28:04<1:04:56,  4.31it/s]

Epoch 23200, Total loss 214.501923, iteration time 0.234781


 58%|█████▊    | 23302/40000 [1:28:27<1:04:53,  4.29it/s]

Epoch 23300, Total loss 192.846466, iteration time 0.245230


 59%|█████▊    | 23402/40000 [1:28:49<1:04:13,  4.31it/s]

Epoch 23400, Total loss 210.062286, iteration time 0.234886


 59%|█████▉    | 23502/40000 [1:29:12<1:04:12,  4.28it/s]

Epoch 23500, Total loss 193.010483, iteration time 0.239156


 59%|█████▉    | 23602/40000 [1:29:35<1:04:29,  4.24it/s]

Epoch 23600, Total loss 227.351944, iteration time 0.246318


 59%|█████▉    | 23702/40000 [1:29:57<1:02:42,  4.33it/s]

Epoch 23700, Total loss 208.296021, iteration time 0.232617


 60%|█████▉    | 23802/40000 [1:30:20<1:02:09,  4.34it/s]

Epoch 23800, Total loss 225.530914, iteration time 0.232234


 60%|█████▉    | 23902/40000 [1:30:43<1:02:32,  4.29it/s]

Epoch 23900, Total loss 213.025696, iteration time 0.240066


 60%|██████    | 24002/40000 [1:31:06<1:31:52,  2.90it/s]

Epoch 24000, Total loss 202.020721, iteration time 0.092448


 60%|██████    | 24102/40000 [1:31:28<1:01:00,  4.34it/s]

Epoch 24100, Total loss 212.662994, iteration time 0.228059


 61%|██████    | 24202/40000 [1:31:51<1:01:08,  4.31it/s]

Epoch 24200, Total loss 212.702728, iteration time 0.239402


 61%|██████    | 24302/40000 [1:32:14<1:01:20,  4.27it/s]

Epoch 24300, Total loss 194.582855, iteration time 0.233862


 61%|██████    | 24402/40000 [1:32:36<1:01:25,  4.23it/s]

Epoch 24400, Total loss 224.886078, iteration time 0.242545


 61%|██████▏   | 24502/40000 [1:32:59<1:00:28,  4.27it/s]

Epoch 24500, Total loss 216.757263, iteration time 0.237986


 62%|██████▏   | 24602/40000 [1:33:22<59:32,  4.31it/s]

Epoch 24600, Total loss 204.083710, iteration time 0.240924


 62%|██████▏   | 24702/40000 [1:33:45<58:48,  4.34it/s]

Epoch 24700, Total loss 212.339844, iteration time 0.236716


 62%|██████▏   | 24802/40000 [1:34:07<59:15,  4.28it/s]

Epoch 24800, Total loss 204.325104, iteration time 0.244636


 62%|██████▏   | 24902/40000 [1:34:30<57:56,  4.34it/s]

Epoch 24900, Total loss 227.194931, iteration time 0.233760


 63%|██████▎   | 25002/40000 [1:34:53<1:27:08,  2.87it/s]

Epoch 25000, Total loss 216.102844, iteration time 0.095966


 63%|██████▎   | 25102/40000 [1:35:16<58:14,  4.26it/s]

Epoch 25100, Total loss 231.849823, iteration time 0.244421


 63%|██████▎   | 25202/40000 [1:35:38<58:09,  4.24it/s]

Epoch 25200, Total loss 197.189423, iteration time 0.242904


 63%|██████▎   | 25302/40000 [1:36:01<56:31,  4.33it/s]

Epoch 25300, Total loss 219.770905, iteration time 0.233084


 64%|██████▎   | 25402/40000 [1:36:24<56:27,  4.31it/s]

Epoch 25400, Total loss 220.165070, iteration time 0.229814


 64%|██████▍   | 25502/40000 [1:36:46<55:59,  4.32it/s]

Epoch 25500, Total loss 230.507965, iteration time 0.233922


 64%|██████▍   | 25602/40000 [1:37:09<55:47,  4.30it/s]

Epoch 25600, Total loss 207.535614, iteration time 0.235638


 64%|██████▍   | 25702/40000 [1:37:31<54:50,  4.34it/s]

Epoch 25700, Total loss 223.384094, iteration time 0.233905


 65%|██████▍   | 25802/40000 [1:37:54<55:01,  4.30it/s]

Epoch 25800, Total loss 219.375336, iteration time 0.241164


 65%|██████▍   | 25902/40000 [1:38:17<54:55,  4.28it/s]

Epoch 25900, Total loss 229.871704, iteration time 0.233901


 65%|██████▌   | 26002/40000 [1:38:40<1:33:14,  2.50it/s]

Epoch 26000, Total loss 226.501953, iteration time 0.100855


 65%|██████▌   | 26102/40000 [1:39:03<53:28,  4.33it/s]

Epoch 26100, Total loss 220.718994, iteration time 0.235428


 66%|██████▌   | 26202/40000 [1:39:25<54:00,  4.26it/s]

Epoch 26200, Total loss 205.127228, iteration time 0.240720


 66%|██████▌   | 26302/40000 [1:39:48<52:53,  4.32it/s]

Epoch 26300, Total loss 229.030548, iteration time 0.234219


 66%|██████▌   | 26402/40000 [1:40:11<52:33,  4.31it/s]

Epoch 26400, Total loss 217.772217, iteration time 0.233574


 66%|██████▋   | 26502/40000 [1:40:33<53:09,  4.23it/s]

Epoch 26500, Total loss 238.639420, iteration time 0.236840


 67%|██████▋   | 26602/40000 [1:40:56<51:38,  4.32it/s]

Epoch 26600, Total loss 240.850327, iteration time 0.232473


 67%|██████▋   | 26702/40000 [1:41:19<51:03,  4.34it/s]

Epoch 26700, Total loss 237.268631, iteration time 0.233142


 67%|██████▋   | 26802/40000 [1:41:41<51:33,  4.27it/s]

Epoch 26800, Total loss 242.544220, iteration time 0.246878


 67%|██████▋   | 26902/40000 [1:42:04<50:38,  4.31it/s]

Epoch 26900, Total loss 232.290070, iteration time 0.234416


 68%|██████▊   | 27002/40000 [1:42:27<1:15:30,  2.87it/s]

Epoch 27000, Total loss 237.468353, iteration time 0.093263


 68%|██████▊   | 27102/40000 [1:42:50<50:51,  4.23it/s]

Epoch 27100, Total loss 230.781952, iteration time 0.247886


 68%|██████▊   | 27202/40000 [1:43:12<49:09,  4.34it/s]

Epoch 27200, Total loss 218.560760, iteration time 0.241525


 68%|██████▊   | 27302/40000 [1:43:35<48:41,  4.35it/s]

Epoch 27300, Total loss 208.452942, iteration time 0.236084


 69%|██████▊   | 27402/40000 [1:43:58<49:09,  4.27it/s]

Epoch 27400, Total loss 234.246704, iteration time 0.232933


 69%|██████▉   | 27502/40000 [1:44:20<48:15,  4.32it/s]

Epoch 27500, Total loss 215.435394, iteration time 0.232116


 69%|██████▉   | 27602/40000 [1:44:43<48:33,  4.26it/s]

Epoch 27600, Total loss 246.519714, iteration time 0.236419


 69%|██████▉   | 27702/40000 [1:45:06<47:55,  4.28it/s]

Epoch 27700, Total loss 225.394302, iteration time 0.234792


 70%|██████▉   | 27802/40000 [1:45:28<47:35,  4.27it/s]

Epoch 27800, Total loss 230.803528, iteration time 0.246250


 70%|██████▉   | 27902/40000 [1:45:51<46:54,  4.30it/s]

Epoch 27900, Total loss 240.250290, iteration time 0.236202


 70%|███████   | 28002/40000 [1:46:15<1:45:37,  1.89it/s]

Epoch 28000, Total loss 296.839203, iteration time 0.102054


 70%|███████   | 28102/40000 [1:46:38<45:55,  4.32it/s]

Epoch 28100, Total loss 232.331253, iteration time 0.234660


 71%|███████   | 28202/40000 [1:47:00<46:03,  4.27it/s]

Epoch 28200, Total loss 238.893005, iteration time 0.234397


 71%|███████   | 28302/40000 [1:47:23<46:05,  4.23it/s]

Epoch 28300, Total loss 226.786255, iteration time 0.246163


 71%|███████   | 28402/40000 [1:47:46<45:11,  4.28it/s]

Epoch 28400, Total loss 226.559814, iteration time 0.240514


 71%|███████▏  | 28502/40000 [1:48:09<44:33,  4.30it/s]

Epoch 28500, Total loss 246.884888, iteration time 0.233507


 72%|███████▏  | 28602/40000 [1:48:31<44:08,  4.30it/s]

Epoch 28600, Total loss 224.903229, iteration time 0.238910


 72%|███████▏  | 28702/40000 [1:48:54<43:53,  4.29it/s]

Epoch 28700, Total loss 211.481964, iteration time 0.242214


 72%|███████▏  | 28802/40000 [1:49:17<43:12,  4.32it/s]

Epoch 28800, Total loss 235.511230, iteration time 0.232653


 72%|███████▏  | 28902/40000 [1:49:39<43:28,  4.25it/s]

Epoch 28900, Total loss 231.645142, iteration time 0.229316


 73%|███████▎  | 29002/40000 [1:50:03<1:05:17,  2.81it/s]

Epoch 29000, Total loss 207.296127, iteration time 0.096330


 73%|███████▎  | 29102/40000 [1:50:25<41:55,  4.33it/s]

Epoch 29100, Total loss 224.044830, iteration time 0.234816


 73%|███████▎  | 29202/40000 [1:50:48<42:25,  4.24it/s]

Epoch 29200, Total loss 278.278381, iteration time 0.250556


 73%|███████▎  | 29302/40000 [1:51:11<41:13,  4.32it/s]

Epoch 29300, Total loss 247.870255, iteration time 0.233745


 74%|███████▎  | 29402/40000 [1:51:33<41:47,  4.23it/s]

Epoch 29400, Total loss 247.829224, iteration time 0.246438


 74%|███████▍  | 29502/40000 [1:51:56<40:28,  4.32it/s]

Epoch 29500, Total loss 217.294662, iteration time 0.234661


 74%|███████▍  | 29602/40000 [1:52:19<40:17,  4.30it/s]

Epoch 29600, Total loss 265.881836, iteration time 0.239615


 74%|███████▍  | 29702/40000 [1:52:41<40:18,  4.26it/s]

Epoch 29700, Total loss 238.960175, iteration time 0.238176


 75%|███████▍  | 29802/40000 [1:53:04<39:34,  4.29it/s]

Epoch 29800, Total loss 223.887329, iteration time 0.233017


 75%|███████▍  | 29902/40000 [1:53:27<39:08,  4.30it/s]

Epoch 29900, Total loss 320.197327, iteration time 0.233182


 75%|███████▌  | 30002/40000 [1:53:50<1:14:15,  2.24it/s]

Epoch 30000, Total loss 224.568253, iteration time 0.117054


 75%|███████▌  | 30102/40000 [1:54:13<37:54,  4.35it/s]

Epoch 30100, Total loss 250.417603, iteration time 0.234518


 76%|███████▌  | 30202/40000 [1:54:36<38:06,  4.28it/s]

Epoch 30200, Total loss 232.508636, iteration time 0.238976


 76%|███████▌  | 30302/40000 [1:54:58<37:51,  4.27it/s]

Epoch 30300, Total loss 246.197067, iteration time 0.239120


 76%|███████▌  | 30402/40000 [1:55:21<36:52,  4.34it/s]

Epoch 30400, Total loss 239.802994, iteration time 0.240190


 76%|███████▋  | 30502/40000 [1:55:44<37:00,  4.28it/s]

Epoch 30500, Total loss 267.824402, iteration time 0.252737


 77%|███████▋  | 30602/40000 [1:56:06<36:22,  4.31it/s]

Epoch 30600, Total loss 220.364655, iteration time 0.232911


 77%|███████▋  | 30702/40000 [1:56:29<35:55,  4.31it/s]

Epoch 30700, Total loss 233.334808, iteration time 0.232348


 77%|███████▋  | 30802/40000 [1:56:52<36:17,  4.22it/s]

Epoch 30800, Total loss 231.686890, iteration time 0.239955


 77%|███████▋  | 30902/40000 [1:57:14<35:07,  4.32it/s]

Epoch 30900, Total loss 223.781799, iteration time 0.239735


 78%|███████▊  | 31002/40000 [1:57:38<52:15,  2.87it/s]  

Epoch 31000, Total loss 256.092224, iteration time 0.094450


 78%|███████▊  | 31102/40000 [1:58:00<34:48,  4.26it/s]

Epoch 31100, Total loss 258.097198, iteration time 0.241724


 78%|███████▊  | 31202/40000 [1:58:23<33:48,  4.34it/s]

Epoch 31200, Total loss 233.200012, iteration time 0.231320


 78%|███████▊  | 31302/40000 [1:58:46<33:25,  4.34it/s]

Epoch 31300, Total loss 258.188416, iteration time 0.232607


 79%|███████▊  | 31402/40000 [1:59:08<33:50,  4.23it/s]

Epoch 31400, Total loss 239.913513, iteration time 0.248321


 79%|███████▉  | 31502/40000 [1:59:31<32:32,  4.35it/s]

Epoch 31500, Total loss 243.624023, iteration time 0.232352


 79%|███████▉  | 31602/40000 [1:59:53<33:27,  4.18it/s]

Epoch 31600, Total loss 260.887512, iteration time 0.239347


 79%|███████▉  | 31702/40000 [2:00:16<32:15,  4.29it/s]

Epoch 31700, Total loss 231.124344, iteration time 0.234121


 80%|███████▉  | 31802/40000 [2:00:39<31:30,  4.34it/s]

Epoch 31800, Total loss 254.964447, iteration time 0.234541


 80%|███████▉  | 31902/40000 [2:01:01<31:31,  4.28it/s]

Epoch 31900, Total loss 234.838928, iteration time 0.231323


 80%|████████  | 32002/40000 [2:01:25<46:37,  2.86it/s]

Epoch 32000, Total loss 255.205872, iteration time 0.096014


 80%|████████  | 32102/40000 [2:01:47<30:21,  4.34it/s]

Epoch 32100, Total loss 237.945129, iteration time 0.226271


 81%|████████  | 32202/40000 [2:02:10<30:38,  4.24it/s]

Epoch 32200, Total loss 272.317993, iteration time 0.242346


 81%|████████  | 32302/40000 [2:02:33<29:52,  4.29it/s]

Epoch 32300, Total loss 264.194580, iteration time 0.235689


 81%|████████  | 32402/40000 [2:02:56<29:20,  4.32it/s]

Epoch 32400, Total loss 259.964355, iteration time 0.233918


 81%|████████▏ | 32502/40000 [2:03:18<29:01,  4.30it/s]

Epoch 32500, Total loss 226.015167, iteration time 0.234256


 82%|████████▏ | 32602/40000 [2:03:41<28:25,  4.34it/s]

Epoch 32600, Total loss 246.016174, iteration time 0.238197


 82%|████████▏ | 32702/40000 [2:04:04<28:24,  4.28it/s]

Epoch 32700, Total loss 299.166046, iteration time 0.236497


 82%|████████▏ | 32802/40000 [2:04:27<27:54,  4.30it/s]

Epoch 32800, Total loss 276.549042, iteration time 0.236871


 82%|████████▏ | 32902/40000 [2:04:49<27:33,  4.29it/s]

Epoch 32900, Total loss 243.312408, iteration time 0.235525


 83%|████████▎ | 33002/40000 [2:05:13<43:26,  2.68it/s]

Epoch 33000, Total loss 276.717010, iteration time 0.098107


 83%|████████▎ | 33102/40000 [2:05:35<26:46,  4.29it/s]

Epoch 33100, Total loss 257.021118, iteration time 0.233412


 83%|████████▎ | 33202/40000 [2:05:58<26:26,  4.28it/s]

Epoch 33200, Total loss 249.659164, iteration time 0.232153


 83%|████████▎ | 33302/40000 [2:06:21<26:01,  4.29it/s]

Epoch 33300, Total loss 241.097382, iteration time 0.237361


 84%|████████▎ | 33402/40000 [2:06:44<25:39,  4.28it/s]

Epoch 33400, Total loss 258.901367, iteration time 0.233152


 84%|████████▍ | 33502/40000 [2:07:07<25:03,  4.32it/s]

Epoch 33500, Total loss 247.298904, iteration time 0.236541


 84%|████████▍ | 33602/40000 [2:07:29<25:07,  4.25it/s]

Epoch 33600, Total loss 240.969543, iteration time 0.247370


 84%|████████▍ | 33702/40000 [2:07:52<24:16,  4.32it/s]

Epoch 33700, Total loss 243.893402, iteration time 0.235099


 85%|████████▍ | 33802/40000 [2:08:15<24:04,  4.29it/s]

Epoch 33800, Total loss 295.900146, iteration time 0.239902


 85%|████████▍ | 33902/40000 [2:08:37<23:49,  4.27it/s]

Epoch 33900, Total loss 248.842773, iteration time 0.237531


 85%|████████▌ | 34002/40000 [2:09:01<35:19,  2.83it/s]

Epoch 34000, Total loss 231.945572, iteration time 0.097545


 85%|████████▌ | 34102/40000 [2:09:24<23:09,  4.25it/s]

Epoch 34100, Total loss 254.183823, iteration time 0.243257


 86%|████████▌ | 34202/40000 [2:09:46<22:23,  4.32it/s]

Epoch 34200, Total loss 247.087891, iteration time 0.233589


 86%|████████▌ | 34302/40000 [2:10:09<22:24,  4.24it/s]

Epoch 34300, Total loss 246.667145, iteration time 0.243048


 86%|████████▌ | 34402/40000 [2:10:32<21:46,  4.28it/s]

Epoch 34400, Total loss 294.412842, iteration time 0.239132


 86%|████████▋ | 34502/40000 [2:10:55<21:16,  4.31it/s]

Epoch 34500, Total loss 283.438568, iteration time 0.235296


 87%|████████▋ | 34602/40000 [2:11:17<20:59,  4.29it/s]

Epoch 34600, Total loss 268.569794, iteration time 0.238469


 87%|████████▋ | 34702/40000 [2:11:40<20:56,  4.22it/s]

Epoch 34700, Total loss 266.711609, iteration time 0.243774


 87%|████████▋ | 34802/40000 [2:12:03<20:06,  4.31it/s]

Epoch 34800, Total loss 266.396881, iteration time 0.235360


 87%|████████▋ | 34902/40000 [2:12:26<19:55,  4.26it/s]

Epoch 34900, Total loss 265.401672, iteration time 0.231278


 88%|████████▊ | 35002/40000 [2:12:49<34:38,  2.41it/s]

Epoch 35000, Total loss 249.605927, iteration time 0.109180


 88%|████████▊ | 35102/40000 [2:13:12<18:54,  4.32it/s]

Epoch 35100, Total loss 257.985657, iteration time 0.240375


 88%|████████▊ | 35202/40000 [2:13:35<18:49,  4.25it/s]

Epoch 35200, Total loss 272.539856, iteration time 0.241008


 88%|████████▊ | 35302/40000 [2:13:57<18:20,  4.27it/s]

Epoch 35300, Total loss 244.720383, iteration time 0.232879


 89%|████████▊ | 35402/40000 [2:14:20<17:53,  4.28it/s]

Epoch 35400, Total loss 268.609558, iteration time 0.237635


 89%|████████▉ | 35502/40000 [2:14:43<17:29,  4.29it/s]

Epoch 35500, Total loss 307.687073, iteration time 0.233822


 89%|████████▉ | 35602/40000 [2:15:06<16:57,  4.32it/s]

Epoch 35600, Total loss 285.442352, iteration time 0.234483


 89%|████████▉ | 35702/40000 [2:15:28<16:39,  4.30it/s]

Epoch 35700, Total loss 259.795868, iteration time 0.232022


 90%|████████▉ | 35802/40000 [2:15:51<16:36,  4.21it/s]

Epoch 35800, Total loss 266.166199, iteration time 0.238419


 90%|████████▉ | 35902/40000 [2:16:14<15:49,  4.32it/s]

Epoch 35900, Total loss 307.476318, iteration time 0.232938


 90%|█████████ | 36002/40000 [2:16:37<23:48,  2.80it/s]

Epoch 36000, Total loss 261.181274, iteration time 0.095647


 90%|█████████ | 36102/40000 [2:17:00<15:10,  4.28it/s]

Epoch 36100, Total loss 243.629257, iteration time 0.236601


 91%|█████████ | 36202/40000 [2:17:23<14:37,  4.33it/s]

Epoch 36200, Total loss 231.180023, iteration time 0.235440


 91%|█████████ | 36302/40000 [2:17:45<14:21,  4.29it/s]

Epoch 36300, Total loss 282.683594, iteration time 0.230103


 91%|█████████ | 36402/40000 [2:18:08<13:55,  4.31it/s]

Epoch 36400, Total loss 263.492767, iteration time 0.235023


 91%|█████████▏| 36502/40000 [2:18:31<13:31,  4.31it/s]

Epoch 36500, Total loss 266.110413, iteration time 0.232779


 92%|█████████▏| 36602/40000 [2:18:54<13:09,  4.30it/s]

Epoch 36600, Total loss 332.828217, iteration time 0.242337


 92%|█████████▏| 36702/40000 [2:19:16<12:41,  4.33it/s]

Epoch 36700, Total loss 270.935425, iteration time 0.234577


 92%|█████████▏| 36802/40000 [2:19:39<12:20,  4.32it/s]

Epoch 36800, Total loss 330.950439, iteration time 0.233659


 92%|█████████▏| 36902/40000 [2:20:02<12:14,  4.22it/s]

Epoch 36900, Total loss 295.116852, iteration time 0.239424


 93%|█████████▎| 37002/40000 [2:20:25<17:44,  2.82it/s]

Epoch 37000, Total loss 297.224731, iteration time 0.096295


 93%|█████████▎| 37102/40000 [2:20:48<11:08,  4.33it/s]

Epoch 37100, Total loss 289.134460, iteration time 0.232198


 93%|█████████▎| 37202/40000 [2:21:10<10:57,  4.25it/s]

Epoch 37200, Total loss 276.678192, iteration time 0.243912


 93%|█████████▎| 37302/40000 [2:21:33<10:27,  4.30it/s]

Epoch 37300, Total loss 272.772614, iteration time 0.238134


 94%|█████████▎| 37402/40000 [2:21:55<10:02,  4.31it/s]

Epoch 37400, Total loss 275.001129, iteration time 0.233809


 94%|█████████▍| 37502/40000 [2:22:18<09:37,  4.32it/s]

Epoch 37500, Total loss 251.423386, iteration time 0.236192


 94%|█████████▍| 37602/40000 [2:22:41<09:24,  4.25it/s]

Epoch 37600, Total loss 333.483521, iteration time 0.234170


 94%|█████████▍| 37702/40000 [2:23:04<08:59,  4.26it/s]

Epoch 37700, Total loss 287.499176, iteration time 0.236069


 95%|█████████▍| 37802/40000 [2:23:26<08:41,  4.22it/s]

Epoch 37800, Total loss 300.497375, iteration time 0.245306


 95%|█████████▍| 37902/40000 [2:23:49<08:10,  4.28it/s]

Epoch 37900, Total loss 287.628845, iteration time 0.233866


 95%|█████████▌| 38002/40000 [2:24:12<11:57,  2.79it/s]

Epoch 38000, Total loss 260.342438, iteration time 0.094210


 95%|█████████▌| 38102/40000 [2:24:35<07:29,  4.22it/s]

Epoch 38100, Total loss 248.926041, iteration time 0.243790


 96%|█████████▌| 38202/40000 [2:24:58<06:58,  4.30it/s]

Epoch 38200, Total loss 294.290039, iteration time 0.239795


 96%|█████████▌| 38302/40000 [2:25:20<06:33,  4.32it/s]

Epoch 38300, Total loss 278.621643, iteration time 0.235316


 96%|█████████▌| 38402/40000 [2:25:43<06:13,  4.28it/s]

Epoch 38400, Total loss 268.309021, iteration time 0.234066


 96%|█████████▋| 38502/40000 [2:26:05<05:49,  4.29it/s]

Epoch 38500, Total loss 271.743469, iteration time 0.231547


 97%|█████████▋| 38602/40000 [2:26:28<05:22,  4.33it/s]

Epoch 38600, Total loss 275.250305, iteration time 0.232060


 97%|█████████▋| 38702/40000 [2:26:51<05:02,  4.30it/s]

Epoch 38700, Total loss 291.657104, iteration time 0.238033


 97%|█████████▋| 38802/40000 [2:27:13<04:37,  4.32it/s]

Epoch 38800, Total loss 289.956665, iteration time 0.237730


 97%|█████████▋| 38902/40000 [2:27:36<04:14,  4.31it/s]

Epoch 38900, Total loss 272.182556, iteration time 0.225770


 98%|█████████▊| 39002/40000 [2:27:59<05:53,  2.82it/s]

Epoch 39000, Total loss 291.813263, iteration time 0.096154


 98%|█████████▊| 39102/40000 [2:28:22<03:28,  4.31it/s]

Epoch 39100, Total loss 287.149078, iteration time 0.238046


 98%|█████████▊| 39202/40000 [2:28:44<03:05,  4.31it/s]

Epoch 39200, Total loss 276.416504, iteration time 0.232572


 98%|█████████▊| 39302/40000 [2:29:07<02:40,  4.34it/s]

Epoch 39300, Total loss 273.309509, iteration time 0.234007


 99%|█████████▊| 39402/40000 [2:29:30<02:19,  4.29it/s]

Epoch 39400, Total loss 298.372253, iteration time 0.238801


 99%|█████████▉| 39502/40000 [2:29:53<01:56,  4.26it/s]

Epoch 39500, Total loss 276.380249, iteration time 0.242182


 99%|█████████▉| 39602/40000 [2:30:15<01:33,  4.28it/s]

Epoch 39600, Total loss 296.511353, iteration time 0.233811


 99%|█████████▉| 39702/40000 [2:30:38<01:09,  4.32it/s]

Epoch 39700, Total loss 288.798615, iteration time 0.233227


100%|█████████▉| 39802/40000 [2:31:01<00:46,  4.27it/s]

Epoch 39800, Total loss 309.681152, iteration time 0.246232


100%|█████████▉| 39902/40000 [2:31:23<00:22,  4.32it/s]

Epoch 39900, Total loss 269.919373, iteration time 0.234306


100%|██████████| 40000/40000 [2:31:46<00:00,  4.39it/s]
