### 1. Install Dependencies

In [1]:
! pip install configargparse

Collecting configargparse
  Downloading ConfigArgParse-1.7-py3-none-any.whl.metadata (23 kB)
Downloading ConfigArgParse-1.7-py3-none-any.whl (25 kB)
Installing collected packages: configargparse
Successfully installed configargparse-1.7


In [2]:
! pip install torch torchvision torchaudio

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

### 2. Imports

In [3]:
import os
import shutil
import time
from datetime import datetime
import pickle
from google.colab import drive

from abc import ABC, abstractmethod
from collections import OrderedDict

import math
import random
import numpy as np

import torch
from torch.utils.data import Dataset
from torch.autograd import grad
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch import nn

import wandb

import matplotlib
import matplotlib.pyplot as plt
import plotly.express as px
import scipy.io as spio

from tqdm.autonotebook import tqdm
from sklearn import svm

import configargparse

  from tqdm.autonotebook import tqdm


In [4]:
torch.cuda.is_available()

True

### 3. Connect to Google Drive

In [5]:
# Mount Google Drive at /content/drive; force_remount=True asks Colab to remount
# even if a mount is already present (per the google.colab API — confirm if this changes).
drive.mount("/content/drive",force_remount=True)
# Make the Drive root the working directory so relative paths used later resolve there.
os.chdir("/content/drive/My Drive")

Mounted at /content/drive


### 4. Utils Code

In [6]:
# uses model input and real boundary fn
class ReachabilityDataset(Dataset):
    """Single-item dataset that draws a fresh batch of training points on every access.

    Each ``__getitem__`` call samples ``numpoints`` (time, state) coordinates in
    model-input units, evaluates the boundary function (and, for the
    ``'brat_hjivi'`` loss, the reach/avoid functions) in real units, and advances
    the internal pretraining / curriculum counters.

    Args:
        dynamics: object exposing ``state_dim``, ``input_dim``, ``loss_type``,
            ``sample_target_state``, ``coord_to_input``, ``input_to_coord`` and
            ``boundary_fn`` (plus ``reach_fn``/``avoid_fn`` for ``'brat_hjivi'``).
        numpoints: number of points sampled per item.
        pretrain: if True, start in pretraining mode (all times equal ``tMin``).
        pretrain_iters: number of ``__getitem__`` calls spent in pretraining.
        tMin, tMax: time horizon bounds.
        counter_start, counter_end: curriculum counter; the sampled time window is
            ``(tMax - tMin) * counter / counter_end`` wide.
        num_src_samples: number of trailing points pinned to ``tMin`` outside
            pretraining (0 disables pinning).
        num_target_samples: number of trailing states replaced by samples from
            ``dynamics.sample_target_state`` (0 disables this).
    """

    def __init__(self, dynamics, numpoints, pretrain, pretrain_iters, tMin, tMax, counter_start, counter_end, num_src_samples, num_target_samples):
        self.dynamics = dynamics
        self.numpoints = numpoints
        self.pretrain = pretrain
        self.pretrain_counter = 0
        self.pretrain_iters = pretrain_iters
        self.tMin = tMin
        self.tMax = tMax
        self.counter = counter_start
        self.counter_end = counter_end
        self.num_src_samples = num_src_samples
        self.num_target_samples = num_target_samples

    def __len__(self):
        # every item is a freshly sampled batch, so one item per epoch is enough
        return 1

    def __getitem__(self, idx):
        """Sample one batch; ``idx`` is ignored.

        Returns:
            A pair ``({'model_coords': ...}, gt)`` where ``gt`` holds
            ``'boundary_values'`` and ``'dirichlet_masks'`` (plus
            ``'reach_values'`` and ``'avoid_values'`` for ``'brat_hjivi'``).

        Raises:
            NotImplementedError: if ``dynamics.loss_type`` is neither
                ``'brt_hjivi'`` nor ``'brat_hjivi'``.
        """
        # uniformly sample domain and include coordinates where source is non-zero
        model_states = torch.zeros(self.numpoints, self.dynamics.state_dim).uniform_(-1, 1)
        if self.num_target_samples > 0:
            target_state_samples = self.dynamics.sample_target_state(self.num_target_samples)
            target_coords = torch.cat((torch.zeros(self.num_target_samples, 1), target_state_samples), dim=-1)
            model_states[-self.num_target_samples:] = self.dynamics.coord_to_input(target_coords)[:, 1:self.dynamics.state_dim+1]

        if self.pretrain:
            # only sample in time around the initial condition
            times = torch.full((self.numpoints, 1), self.tMin)
        else:
            # slowly grow time values from start time
            times = self.tMin + torch.zeros(self.numpoints, 1).uniform_(0, (self.tMax-self.tMin) * (self.counter / self.counter_end))
            # make sure we always have training samples at the initial time;
            # the guard matters because `times[-0:, 0]` selects EVERY row, which
            # would silently pin the whole batch to tMin when num_src_samples == 0
            if self.num_src_samples > 0:
                times[-self.num_src_samples:, 0] = self.tMin
        model_coords = torch.cat((times, model_states), dim=1)
        if self.dynamics.input_dim > self.dynamics.state_dim + 1: # temporary workaround for having to deal with dynamics classes for parametrized models with extra inputs
            model_coords = torch.cat((model_coords, torch.zeros(self.numpoints, self.dynamics.input_dim - self.dynamics.state_dim - 1)), dim=1)

        # convert once to real units and drop the time column
        real_states = self.dynamics.input_to_coord(model_coords)[..., 1:]
        boundary_values = self.dynamics.boundary_fn(real_states)
        if self.dynamics.loss_type == 'brat_hjivi':
            reach_values = self.dynamics.reach_fn(real_states)
            avoid_values = self.dynamics.avoid_fn(real_states)

        if self.pretrain:
            dirichlet_masks = torch.ones(model_coords.shape[0]) > 0
        else:
            # only enforce initial conditions around self.tMin
            dirichlet_masks = (model_coords[:, 0] == self.tMin)

        # advance either the pretraining counter or the curriculum counter
        if self.pretrain:
            self.pretrain_counter += 1
        elif self.counter < self.counter_end:
            self.counter += 1

        if self.pretrain and self.pretrain_counter == self.pretrain_iters:
            self.pretrain = False

        if self.dynamics.loss_type == 'brt_hjivi':
            return {'model_coords': model_coords}, {'boundary_values': boundary_values, 'dirichlet_masks': dirichlet_masks}
        elif self.dynamics.loss_type == 'brat_hjivi':
            return {'model_coords': model_coords}, {'boundary_values': boundary_values, 'reach_values': reach_values, 'avoid_values': avoid_values, 'dirichlet_masks': dirichlet_masks}
        else:
            raise NotImplementedError

In [7]:
# TODO: I don't think jacobian is needed here; torch.autograd.grad should be enough, to compute gradients of a scalar value function w.r.t. inputs

# batched jacobian
# y: [..., N], x: [..., M] -> [..., N, M]
def jacobian(y, x):
    """Batched Jacobian of ``y`` with respect to ``x``.

    ``y`` has shape ``[..., N]`` and ``x`` has shape ``[..., M]``; the result has
    shape ``[..., N, M]``. One backward pass is run per output component, with
    ``create_graph=True`` so the result can itself be differentiated.

    Returns:
        ``(jac, status)`` where ``status`` is ``-1`` if any entry of ``jac`` is
        NaN and ``0`` otherwise.
    """
    num_outputs = y.shape[-1]
    jac = torch.zeros(*y.shape, x.shape[-1]).to(y.device)

    for out_idx in range(num_outputs):
        # gradient of the out_idx-th output, summed over the batch via a ones seed
        component = y[..., out_idx].view(-1, 1)
        seed = torch.ones_like(component)
        component_grads = grad(component, x, seed, create_graph=True)
        jac[..., out_idx, :] = component_grads[0]

    has_nan = torch.any(torch.isnan(jac))
    status = -1 if has_nan else 0
    return jac, status

In [8]:
class Validator(ABC):
    """Abstract interface for objects that judge (coords, values) pairs."""

    @abstractmethod
    def validate(self, coords, values):
        """Return the validity of each entry; concrete subclasses must override."""
        raise NotImplementedError

class ValueThresholdValidator(Validator):
    """Accepts entries whose value lies in the closed interval [v_min, v_max]."""

    def __init__(self, v_min, v_max):
        self.v_min, self.v_max = v_min, v_max

    def validate(self, coords, values):
        """Elementwise ``v_min <= values <= v_max``; ``coords`` is not consulted."""
        at_or_above_min = values >= self.v_min
        at_or_below_max = values <= self.v_max
        return at_or_above_min * at_or_below_max

class MLPValidator(Validator):
    """Accepts entries whose MLP score, after a sigmoid, lies in [o_min, o_max].

    Args:
        device: device on which the MLP is evaluated.
        mlp: callable mapping ``[batch, state_dim + 1]`` inputs (state and value
            concatenated) to one logit per row.
        o_min, o_max: inclusive bounds on the sigmoid output.
        model, dynamics: stored for interface compatibility; ``validate`` does
            not use them.
    """

    def __init__(self, device, mlp, o_min, o_max, model, dynamics):
        self.device = device
        self.mlp = mlp
        self.o_min = o_min
        self.o_max = o_max
        self.model = model
        self.dynamics = dynamics

    def validate(self, coords, values):
        """Return a boolean mask on ``values.device``.

        ``coords[..., 1:]`` (the time column is dropped) is concatenated with
        ``values`` to form the MLP input.
        """
        # NOTE: the value network is intentionally not evaluated here; its output
        # was never used, so the forward pass was pure overhead.
        inputs = torch.cat((coords[..., 1:].to(self.device), values[:, None].to(self.device)), dim=-1)
        # squeeze only the trailing logit dimension so a batch of one stays 1-D
        outputs = torch.sigmoid(self.mlp(inputs).squeeze(dim=-1))
        return ((outputs >= self.o_min)*(outputs <= self.o_max)).to(device=values.device)

class MLPConditionedValidator(Validator):
    """Accepts entries whose value lies in a range selected by the MLP score band.

    ``o_levels`` partitions the sigmoid output into ``len(o_levels) - 1`` bands
    ``(o_levels[i], o_levels[i+1]]``; band ``i`` accepts values in the closed
    interval ``[v_levels[i][0], v_levels[i][1]]``.

    Args:
        device: device on which the MLP is evaluated.
        mlp: callable mapping ``[batch, state_dim + 1]`` inputs to one logit per row.
        o_levels: band edges for the sigmoid output.
        v_levels: per-band ``(low, high)`` value bounds; must have one fewer
            entry than ``o_levels``.
        model, dynamics: stored for interface compatibility; ``validate`` does
            not use them.
    """

    def __init__(self, device, mlp, o_levels, v_levels, model, dynamics):
        self.device = device
        self.mlp = mlp
        self.o_levels = o_levels
        self.v_levels = v_levels
        self.model = model
        self.dynamics = dynamics
        assert len(self.o_levels) == len(self.v_levels) + 1

    def validate(self, coords, values):
        """Return a boolean mask on ``values.device``."""
        # NOTE: the value network is intentionally not evaluated here; its output
        # was never used, so the forward pass was pure overhead.
        inputs = torch.cat((coords[..., 1:].to(self.device), values[:, None].to(self.device)), dim=-1)
        outputs = torch.sigmoid(self.mlp(inputs).squeeze(dim=-1)).to(device=values.device)
        # start from an all-False boolean mask so the return dtype is bool even
        # when there are no bands to test
        valids = torch.zeros_like(outputs, dtype=torch.bool)
        for i in range(len(self.o_levels) - 1):
            in_score_band = (outputs > self.o_levels[i])*(outputs <= self.o_levels[i+1])
            in_value_range = (values >= self.v_levels[i][0])*(values <= self.v_levels[i][1])
            valids = torch.logical_or(valids, in_score_band*in_value_range)
        return valids

class MultiValidator(Validator):
    """Conjunction of several validators: an entry passes only if all accept it."""

    def __init__(self, validators):
        self.validators = validators

    def validate(self, coords, values):
        """Multiply the results of every validator, in order, starting from the first."""
        combined = self.validators[0].validate(coords, values)
        for validator in self.validators[1:]:
            combined = combined * validator.validate(coords, values)
        return combined

class SampleGenerator(ABC):
    """Abstract interface for objects that produce batches of state samples."""

    @abstractmethod
    def sample(self, num_samples):
        """Return ``num_samples`` samples; concrete subclasses must override."""
        raise NotImplementedError

class SliceSampleGenerator(SampleGenerator):
    """Samples states on an axis-aligned slice of the state space.

    ``slices`` has one entry per state dimension: ``None`` means "sample this
    dimension uniformly from ``dynamics.state_test_range()``", any other value
    fixes the dimension to that value.
    """

    def __init__(self, dynamics, slices):
        self.dynamics = dynamics
        self.slices = slices
        assert self.dynamics.state_dim == len(slices)

    def sample(self, num_samples):
        """Return a ``[num_samples, state_dim]`` tensor of sampled states."""
        num_dims = self.dynamics.state_dim
        samples = torch.zeros(num_samples, num_dims)
        for dim in range(num_dims):
            fixed_value = self.slices[dim]
            if fixed_value is not None:
                # pinned dimension: broadcast the fixed value down the column
                samples[:, dim] = fixed_value
            else:
                # free dimension: fill the column in place from its test range
                low_high = self.dynamics.state_test_range()[dim]
                samples[:, dim].uniform_(*low_high)
        return samples

import torch
from tqdm import tqdm

# # get the tEarliest in [tMin:tMax:dt] at which the state is still valid
# def get_tEarliest(device, model, dynamics, state, tMin, tMax, dt, validator):
#     with torch.no_grad():
#         tEarliest = torch.full(state.shape[:-1], tMin - 1)
#         model_state = dynamics.normalize_state(state)

#         times_to_try = torch.arange(tMin, tMax + dt, dt)
#         for time_to_try in times_to_try:
#             blank_idx = (tEarliest < tMin)
#             time = torch.full((*state.shape[:-1], 1), time_to_try)
#             model_time = dynamics.normalize_time(time)
#             model_coord = torch.cat((model_time, model_state), dim=-1)[blank_idx]
#             model_result = model({'coords': model_coord.to(device)})
#             value = dynamics.output_to_value(output=model_result['model_out'][..., 0], state=state.to(device)).cpu()
#             valid_idx = validator.validate(torch.cat((time, state), dim=-1), value)
#             tMasked = tEarliest[blank_idx]
#             tMasked[valid_idx] = time_to_try
#             tEarliest[blank_idx] = tMasked
#             if torch.all(tEarliest >= tMin):
#                 break
#         blank_idx = (tEarliest < tMin)
#         if torch.any(blank_idx):
#             print(str(torch.sum(blank_idx)), 'invalid states')
#             tEarliest[blank_idx] = tMax
#         return tEarliest

def scenario_optimization(device, model, policy, dynamics, tMin, tMax, dt, set_type, control_type, scenario_batch_size, sample_batch_size, sample_generator, sample_validator, violation_validator, max_scenarios=None, max_samples=None, max_violations=None, tStart_generator=None):
    """Collect validated initial states, roll them out under `policy`, and count violations.

    The function repeatedly (1) draws state samples from `sample_generator`,
    evaluates `model` on them and keeps those accepted by `sample_validator`
    until `scenario_batch_size` scenarios are collected, (2) propagates each
    scenario for `int((tMax-tMin)/dt)` explicit-Euler steps using the controls
    and disturbances returned by `dynamics.optimal_control` /
    `dynamics.optimal_disturbance`, (3) scores each trajectory with
    `dynamics.cost_fn`, and (4) counts the scenarios flagged by
    `violation_validator`. It stops once any enabled limit
    (`max_scenarios`, `max_samples`, `max_violations`) is reached.

    Args:
        device: device that model/policy inputs and dynamics calls are moved to.
        model: callable taking `{'coords': ...}` and returning a dict with
            `'model_in'` and `'model_out'`; used to compute sample values.
        policy: callable with the same interface as `model`; used during rollout
            to obtain value gradients via `dynamics.io_to_dv`.
        dynamics: object providing `state_dim`, `control_dim`, `disturbance_dim`,
            `coord_to_input`, `io_to_value`, `io_to_dv`, `optimal_control`,
            `optimal_disturbance`, `hamiltonian`, `dsdt`,
            `equivalent_wrapped_state` and `cost_fn`.
        tMin, tMax: time horizon; `tMax > tMin` is asserted.
        dt: integration step; `(tMax - tMin)` must be a multiple of it.
        set_type: `'BRS'` or `'BRT'`; `'BRS'` currently raises NotImplementedError.
        control_type: one of `'value'`, `'ttr'`, `'init_ttr'` (see the review
            note in the rollout loop: only `'value'` is actually implemented).
        scenario_batch_size: number of scenarios propagated per outer iteration.
        sample_batch_size: number of candidate samples drawn per inner iteration.
        sample_generator: object with `sample(n)` returning candidate states.
        sample_validator: object with `validate(coords, values)` selecting
            which candidates become scenarios.
        violation_validator: object with `validate(states, costs)` flagging
            violating scenarios.
        max_scenarios, max_samples, max_violations: optional termination limits;
            at least one must be truthy. `max_scenarios` / `max_samples` must be
            multiples of the corresponding batch size.
        tStart_generator: optional callable taking `sample_batch_size` and
            returning per-sample start times (rounded here to the nearest `dt`);
            when absent every scenario starts at `tMax`.

    Returns:
        dict with the accumulated `'times'`, `'states'`, `'values'`, `'costs'`,
        Hamiltonian statistics (`'init_hams'`, `'init_abs_hams'`, `'mean_hams'`,
        `'mean_abs_hams'`, `'max_abs_hams'`, `'min_abs_hams'`), `'violations'`,
        `'valid_sample_fraction'`, `'violation_rate'`, the `'maxed_*'` flags, and
        `'batch_state_trajs'` (the last propagated batch, or None when the run
        ended on the sample limit).

    Raises:
        AssertionError: if an argument check fails.
        NotImplementedError: if `set_type == 'BRS'`.
    """
    # check that (tMax - tMin) is an integer multiple of dt, up to float tolerance
    rem = ((tMax-tMin) / dt)%1
    e_tol = 1e-12
    assert rem < e_tol or 1 - rem < e_tol, f'{tMax-tMin} is not divisible by {dt}'
    assert tMax > tMin
    assert set_type in ['BRS', 'BRT']
    if set_type == 'BRS':
        print('confirm correct calculation of true values of trajectories (batch_scenario_costs)')
        raise NotImplementedError
    assert control_type in ['value', 'ttr', 'init_ttr']
    assert max_scenarios or max_samples or max_violations, 'one of the termination conditions must be used'
    if max_scenarios:
        assert (max_scenarios / scenario_batch_size)%1 == 0, 'max_scenarios is not divisible by scenario_batch_size'
    if max_samples:
        assert (max_samples / sample_batch_size)%1 == 0, 'max_samples is not divisible by sample_batch_size'

    # accumulate scenarios
    times = torch.zeros(0, )
    states = torch.zeros(0, dynamics.state_dim)
    values = torch.zeros(0, )
    costs = torch.zeros(0, )
    init_hams = torch.zeros(0, )
    mean_hams = torch.zeros(0, )
    mean_abs_hams = torch.zeros(0, )
    max_abs_hams = torch.zeros(0, )
    min_abs_hams = torch.zeros(0, )

    num_scenarios = 0
    num_samples = 0
    num_violations = 0

    # one progress bar per enabled termination condition, stacked by position
    pbar_pos = 0
    if max_scenarios:
        scenarios_pbar = tqdm(total=max_scenarios, desc='Scenarios', position=pbar_pos)
        pbar_pos += 1
    if max_samples:
        samples_pbar = tqdm(total=max_samples, desc='Samples', position=pbar_pos)
        pbar_pos += 1
    if max_violations:
        violations_pbar = tqdm(total=max_violations, desc='Violations', position=pbar_pos)
        pbar_pos += 1

    nums_valid_samples = []
    while True:
        if (max_scenarios and (num_scenarios >= max_scenarios)) or (max_violations and (num_violations >= max_violations)):
            break

        batch_scenario_times = torch.zeros(scenario_batch_size, )
        batch_scenario_states = torch.zeros(scenario_batch_size, dynamics.state_dim)
        batch_scenario_values = torch.zeros(scenario_batch_size, )

        num_collected_scenarios = 0
        while num_collected_scenarios < scenario_batch_size:
            if max_samples and (num_samples >= max_samples):
                break
            # sample batch
            if tStart_generator is not None:
                batch_sample_times = tStart_generator(sample_batch_size)
                # need to round to nearest dt
                batch_sample_times = torch.round(batch_sample_times/dt)*dt
            else:
                batch_sample_times = torch.full((sample_batch_size, ), tMax)
            batch_sample_states = dynamics.equivalent_wrapped_state(sample_generator.sample(sample_batch_size))
            batch_sample_coords = torch.cat((batch_sample_times.unsqueeze(-1), batch_sample_states), dim=-1)

            # validate batch
            with torch.no_grad():
                batch_sample_model_results = model({'coords': dynamics.coord_to_input(batch_sample_coords.to(device))})
                batch_sample_values = dynamics.io_to_value(batch_sample_model_results['model_in'].detach(), batch_sample_model_results['model_out'].squeeze(dim=-1).detach())
            batch_valid_sample_idxs = torch.where(sample_validator.validate(batch_sample_coords, batch_sample_values))[0].detach().cpu()

            # store valid samples
            # (only as many as still fit in the scenario batch; surplus valid samples are dropped)
            num_valid_samples = len(batch_valid_sample_idxs)
            start_idx = num_collected_scenarios
            end_idx = min(start_idx + num_valid_samples, scenario_batch_size)
            batch_scenario_times[start_idx:end_idx] = batch_sample_times[batch_valid_sample_idxs][:end_idx-start_idx]
            batch_scenario_states[start_idx:end_idx] = batch_sample_states[batch_valid_sample_idxs][:end_idx-start_idx]
            batch_scenario_values[start_idx:end_idx] = batch_sample_values[batch_valid_sample_idxs][:end_idx-start_idx]

            # update counters
            num_samples += sample_batch_size
            if max_samples:
                samples_pbar.update(sample_batch_size)
            num_collected_scenarios += end_idx - start_idx
            nums_valid_samples.append(num_valid_samples)
        # the sample budget ran out: the (possibly partial) scenario batch is discarded
        if max_samples and (num_samples >= max_samples):
            break

        # propagate scenarios
        state_trajs = torch.zeros(scenario_batch_size, int((tMax-tMin)/dt + 1), dynamics.state_dim)
        ctrl_trajs = torch.zeros(scenario_batch_size, int((tMax-tMin)/dt), dynamics.control_dim)
        dstb_trajs = torch.zeros(scenario_batch_size, int((tMax-tMin)/dt), dynamics.disturbance_dim)
        ham_trajs = torch.zeros(scenario_batch_size, int((tMax-tMin)/dt))

        state_trajs[:, 0, :] = batch_scenario_states
        for k in tqdm(range(int((tMax-tMin)/dt)), desc='Trajectory Propagation', position=pbar_pos, leave=False):
            # NOTE(review): only control_type == 'value' defines traj_times; 'ttr' and
            # 'init_ttr' pass the assertion above but their branches are commented out,
            # so they would fail below on the undefined name — confirm intended handling.
            if control_type == 'value':
                traj_time = tMax - k*dt
                traj_times = torch.full((scenario_batch_size, ), traj_time)
            # elif control_type == 'ttr':
            #     traj_times = get_tEarliest(model=model, dynamics=dynamics, state=state_trajs[:, k], tMin=tMin, tMax=traj_time, dt=dt, validator=sample_validator)
            # elif control_type == 'init_ttr':
            #     if k == 0:
            #         init_traj_times = get_tEarliest(model=model, dynamics=dynamics, state=state_trajs[:, k], tMin=tMin, tMax=traj_time, dt=dt, validator=sample_validator)
            #     traj_times = torch.maximum(init_traj_times - k*dt, torch.tensor(tMin)) # check whether this is the best thing to do for init_ttr
            traj_coords = torch.cat((traj_times.unsqueeze(-1), state_trajs[:, k]), dim=-1)
            traj_policy_results = policy({'coords': dynamics.coord_to_input(traj_coords.to(device))})
            traj_dvs = dynamics.io_to_dv(traj_policy_results['model_in'], traj_policy_results['model_out'].squeeze(dim=-1)).detach()

            # TODO: I do not think there is actually any reason to store these trajs? Could save space by removing these.
            # traj_dvs[..., 1:] drops the first (time) component and keeps the state gradient
            ctrl_trajs[:, k] = dynamics.optimal_control(traj_coords[:, 1:].to(device), traj_dvs[..., 1:].to(device))
            dstb_trajs[:, k] = dynamics.optimal_disturbance(traj_coords[:, 1:].to(device), traj_dvs[..., 1:].to(device))
            ham_trajs[:, k] = dynamics.hamiltonian(traj_coords[:, 1:].to(device), traj_dvs[..., 1:].to(device))

            if tStart_generator is not None: # freeze states whose start time has not been reached yet
                is_frozen = batch_scenario_times < traj_times
                is_unfrozen = torch.logical_not(is_frozen)
                state_trajs[is_frozen, k+1] = state_trajs[is_frozen, k]
                # explicit Euler step for the active scenarios; result is moved back to the CPU buffer
                state_trajs[is_unfrozen, k+1] = dynamics.equivalent_wrapped_state(state_trajs[is_unfrozen, k].to(device) + dt*dynamics.dsdt(state_trajs[is_unfrozen, k].to(device), ctrl_trajs[is_unfrozen, k].to(device), dstb_trajs[is_unfrozen, k].to(device))).cpu()
            else:
                # explicit Euler step for every scenario
                state_trajs[:, k+1] = dynamics.equivalent_wrapped_state(state_trajs[:, k].to(device) + dt*dynamics.dsdt(state_trajs[:, k].to(device), ctrl_trajs[:, k].to(device), dstb_trajs[:, k].to(device)))

        # compute batch_scenario_costs
        # TODO: need to handle the case of using tStart_generator when extending a trajectory by a frozen initial state will inadvertently affect cost computation (the min lx cost formulation is unaffected, but other cost formulations might care)
        if set_type == 'BRT':
            batch_scenario_costs = dynamics.cost_fn(state_trajs.to(device))
        # NOTE(review): the 'BRS' branch below is unreachable while the early NotImplementedError above is in place
        elif set_type == 'BRS':
            if control_type == 'init_ttr': # is this correct for init_ttr?
                batch_scenario_costs =  dynamics.boundary_fn(state_trajs.to(device))[:, (init_traj_times - tMin) / dt]
            elif control_type == 'value':
                batch_scenario_costs =  dynamics.boundary_fn(state_trajs.to(device))[:, -1]
            else:
                raise NotImplementedError # what is the correct thing to do for ttr?

        # compute batch_scenario_init_hams, batch_scenario_mean_hams, batch_scenario_mean_abs_hams, batch_scenario_max_abs_hams, batch_scenario_min_abs_hams
        batch_scenario_init_hams = ham_trajs[:, 0]
        batch_scenario_mean_hams = torch.mean(ham_trajs, dim=-1)
        batch_scenario_mean_abs_hams = torch.mean(torch.abs(ham_trajs), dim=-1)
        batch_scenario_max_abs_hams = torch.max(torch.abs(ham_trajs), dim=-1).values
        batch_scenario_min_abs_hams = torch.min(torch.abs(ham_trajs), dim=-1).values

        # store scenarios
        times = torch.cat((times, batch_scenario_times.cpu()), dim=0)
        states = torch.cat((states, batch_scenario_states.cpu()), dim=0)
        values = torch.cat((values, batch_scenario_values.cpu()), dim=0)
        costs = torch.cat((costs, batch_scenario_costs.cpu()), dim=0)
        init_hams = torch.cat((init_hams, batch_scenario_init_hams.cpu()), dim=0)
        mean_hams = torch.cat((mean_hams, batch_scenario_mean_hams.cpu()), dim=0)
        mean_abs_hams = torch.cat((mean_abs_hams, batch_scenario_mean_abs_hams.cpu()), dim=0)
        max_abs_hams = torch.cat((max_abs_hams, batch_scenario_max_abs_hams.cpu()), dim=0)
        min_abs_hams = torch.cat((min_abs_hams, batch_scenario_min_abs_hams.cpu()), dim=0)

        # update counters
        num_scenarios += scenario_batch_size
        if max_scenarios:
            scenarios_pbar.update(scenario_batch_size)
        num_new_violations = int(torch.sum(violation_validator.validate(batch_scenario_states, batch_scenario_costs)))
        num_violations += num_new_violations
        if max_violations:
            violations_pbar.update(num_new_violations)

    if max_scenarios:
        scenarios_pbar.close()
    if max_samples:
        samples_pbar.close()
    if max_violations:
        violations_pbar.close()

    # re-evaluate violations over everything that was accumulated
    violations = violation_validator.validate(states, costs)

    return {
        'times': times,
        'states': states,
        'values': values,
        'costs': costs,
        'init_hams': init_hams,
        'init_abs_hams': torch.abs(init_hams),
        'mean_hams': mean_hams,
        'mean_abs_hams': mean_abs_hams,
        'max_abs_hams': max_abs_hams,
        'min_abs_hams': min_abs_hams,
        'violations': violations,
        'valid_sample_fraction': torch.mean(torch.tensor(nums_valid_samples, dtype=float))/sample_batch_size,
        'violation_rate': 0 if not num_scenarios else num_violations / num_scenarios,
        'maxed_scenarios': (max_scenarios is not None) and num_scenarios >= max_scenarios,
        'maxed_samples': (max_samples is not None) and num_samples >= max_samples,
        'maxed_violations': (max_violations is not None) and num_violations >= max_violations,
        # state_trajs may not exist (or be stale) when the run stopped on the sample limit
        'batch_state_trajs': None if (max_samples and (num_samples >= max_samples)) else state_trajs,
    }

def target_fraction(device, model, dynamics, t, sample_validator, target_validator, num_samples, batch_size):
    """Fraction of sample-valid states at time ``t`` that ``target_validator`` accepts.

    States are drawn uniformly from ``dynamics.state_test_range()`` in batches of
    ``batch_size``; only those accepted by ``sample_validator`` are kept, and
    sampling continues until ``num_samples`` have been collected. The first
    ``num_samples`` kept states are then judged by ``target_validator``.

    Returns:
        ``torch.sum(valids) / num_samples`` as a tensor.
    """
    with torch.no_grad():
        # seed both accumulators with empty tensors so the final cat is always defined
        kept_state_chunks = [torch.zeros(0, dynamics.state_dim)]
        kept_value_chunks = [torch.zeros(0, )]
        num_kept = 0

        while num_kept < num_samples:
            # sample batch
            batch_states = torch.zeros(batch_size, dynamics.state_dim)
            for dim in range(dynamics.state_dim):
                batch_states[:, dim].uniform_(*dynamics.state_test_range()[dim])
            batch_states = dynamics.equivalent_wrapped_state(batch_states)
            batch_coords = torch.cat((torch.full((batch_size, 1), t), batch_states), dim=-1)

            # validate batch
            model_results = model({'coords': dynamics.coord_to_input(batch_coords.to(device))})
            model_out = model_results['model_out'].squeeze(dim=-1)
            batch_values = dynamics.io_to_value(model_results['model_in'], model_out).detach()
            keep_mask = sample_validator.validate(batch_coords, batch_values).detach().cpu()

            # store valid portion of batch
            kept_states = batch_states[keep_mask].cpu()
            kept_state_chunks.append(kept_states)
            kept_value_chunks.append(batch_values[keep_mask].cpu())
            num_kept += len(kept_states)

        states = torch.cat(kept_state_chunks, dim=0)[:num_samples]
        values = torch.cat(kept_value_chunks, dim=0)[:num_samples]
        time_column = torch.full((num_samples, 1), t)
        coords = torch.cat((time_column, states), dim=-1)
        valids = target_validator.validate(coords.to(device), values.to(device))
    return torch.sum(valids) / num_samples

class MLP(torch.nn.Module):
    """Small ReLU network mapping ``input_size`` features to a single output.

    Layer widths are ``input_size -> 2*input_size -> input_size -> input_size -> 1``.
    No activation is applied to the final output.
    """

    def __init__(self, input_size):
        super().__init__()

        wide = int(2 * input_size)
        narrow = int(input_size)

        # attribute names and creation order are kept so state_dict keys and
        # seeded initialisation stay the same
        self.l1 = torch.nn.Linear(input_size, wide)
        self.a1 = torch.nn.ReLU()
        self.l2 = torch.nn.Linear(wide, narrow)
        self.a2 = torch.nn.ReLU()
        self.l3 = torch.nn.Linear(narrow, narrow)
        self.a3 = torch.nn.ReLU()
        self.l4 = torch.nn.Linear(narrow, 1)

    def forward(self, x):
        """Apply the layers in order and return the ``[..., 1]`` output."""
        stages = (self.l1, self.a1, self.l2, self.a2, self.l3, self.a3, self.l4)
        hidden = x
        for stage in stages:
            hidden = stage(hidden)
        return hidden



def sample_values(device, model, dynamics, t, num_samples, batch_size):
    """Evaluate the learned value at ``num_samples`` uniformly sampled states at time ``t``.

    States are drawn uniformly from ``dynamics.state_test_range()`` in batches of
    ``batch_size`` and passed through ``model``; values are obtained with
    ``dynamics.io_to_value``.

    Args:
        device: device the model inputs are moved to.
        model: callable taking ``{'coords': ...}`` and returning a dict with
            ``'model_in'`` and ``'model_out'``.
        dynamics: object providing ``state_dim``, ``state_test_range``,
            ``equivalent_wrapped_state``, ``coord_to_input`` and ``io_to_value``.
        t: time at which the value is evaluated.
        num_samples: number of values to return.
        batch_size: number of states evaluated per model call.

    Returns:
        1-D CPU tensor holding the first ``num_samples`` sampled values.
    """
    with torch.no_grad():
        # Collect batches in a list and concatenate once. The sampled states are
        # not part of the result, so only their count is tracked.
        value_chunks = [torch.zeros(0, )]
        num_collected = 0

        while num_collected < num_samples:
            # sample batch
            batch_times = torch.full((batch_size, 1), t)
            batch_states = torch.zeros(batch_size, dynamics.state_dim)
            for dim in range(dynamics.state_dim):
                batch_states[:, dim].uniform_(*dynamics.state_test_range()[dim])
            batch_states = dynamics.equivalent_wrapped_state(batch_states)
            batch_coords = torch.cat((batch_times, batch_states), dim=-1)

            batch_model_results = model({'coords': dynamics.coord_to_input(batch_coords.to(device))})
            batch_values = dynamics.io_to_value(batch_model_results['model_in'], batch_model_results['model_out'].squeeze(dim=-1)).detach()

            # store batch
            value_chunks.append(batch_values.cpu())
            num_collected += batch_states.shape[0]

        values = torch.cat(value_chunks, dim=0)[:num_samples]
    return values


In [9]:
# uses real units
def init_brt_hjivi_loss(dynamics, minWith, dirichlet_loss_divisor):
    """Build the BRT HJI-VI loss closure.

    Args:
        dynamics: object providing ``hamiltonian(state, dvds)`` and the
            ``deepreach_model`` attribute.
        minWith: ``'zero'`` clamps the Hamiltonian at 0 from above; ``'target'``
            takes the max of the PDE residual and ``value - boundary_value``;
            any other value applies neither.
        dirichlet_loss_divisor: divisor applied to the summed Dirichlet loss.

    Returns:
        ``brt_hjivi_loss(state, value, dvdt, dvds, boundary_value,
        dirichlet_mask, output)`` returning a dict of summed absolute losses.
    """
    def brt_hjivi_loss(state, value, dvdt, dvds, boundary_value, dirichlet_mask, output):
        # an all-True mask signals the pretraining phase
        pretraining = torch.all(dirichlet_mask)

        if pretraining:
            # pretraining loss
            pde_residual = torch.Tensor([0])
        else:
            hamiltonian = dynamics.hamiltonian(state, dvds)
            if minWith == 'zero':
                hamiltonian = torch.clamp(hamiltonian, max=0.0)

            pde_residual = dvdt - hamiltonian
            if minWith == 'target':
                pde_residual = torch.max(pde_residual, value - boundary_value)

        boundary_residual = value[dirichlet_mask] - boundary_value[dirichlet_mask]
        pde_loss = torch.abs(pde_residual).sum()

        if dynamics.deepreach_model == 'exact':
            if not pretraining:
                # exact model outside pretraining: only the PDE term is reported
                return {'diff_constraint_hom': pde_loss}
            # pretraining
            boundary_residual = output.squeeze(dim=-1)[dirichlet_mask]-0.0

        dirichlet_loss = torch.abs(boundary_residual).sum() / dirichlet_loss_divisor
        return {'dirichlet': dirichlet_loss,
                'diff_constraint_hom': pde_loss}

    return brt_hjivi_loss
def init_brat_hjivi_loss(dynamics, minWith, dirichlet_loss_divisor):
    """Build the reach-avoid (BRAT) HJI-VI loss closure.

    Args:
        dynamics: object providing ``hamiltonian(state, dvds)`` and the
            ``deepreach_model`` attribute.
        minWith: ``'zero'`` clamps the Hamiltonian at 0 from above; ``'target'``
            applies ``min(max(residual, value - reach_value), value + avoid_value)``;
            any other value applies neither.
        dirichlet_loss_divisor: divisor applied to the summed Dirichlet loss.

    Returns:
        ``brat_hjivi_loss(state, value, dvdt, dvds, boundary_value, reach_value,
        avoid_value, dirichlet_mask, output)`` returning a dict of summed
        absolute losses.
    """
    def brat_hjivi_loss(state, value, dvdt, dvds, boundary_value, reach_value, avoid_value, dirichlet_mask, output):
        # an all-True mask signals the pretraining phase
        pretraining = torch.all(dirichlet_mask)

        if pretraining:
            # pretraining loss
            pde_residual = torch.Tensor([0])
        else:
            hamiltonian = dynamics.hamiltonian(state, dvds)
            if minWith == 'zero':
                hamiltonian = torch.clamp(hamiltonian, max=0.0)

            pde_residual = dvdt - hamiltonian
            if minWith == 'target':
                reach_bounded = torch.max(pde_residual, value - reach_value)
                pde_residual = torch.min(reach_bounded, value + avoid_value)

        boundary_residual = value[dirichlet_mask] - boundary_value[dirichlet_mask]
        pde_loss = torch.abs(pde_residual).sum()

        if dynamics.deepreach_model == 'exact':
            if not pretraining:
                # exact model outside pretraining: only the PDE term is reported
                return {'diff_constraint_hom': pde_loss}
            boundary_residual = output.squeeze(dim=-1)[dirichlet_mask]-0.0

        dirichlet_loss = torch.abs(boundary_residual).sum() / dirichlet_loss_divisor
        return {'dirichlet': dirichlet_loss,
                'diff_constraint_hom': pde_loss}
    return brat_hjivi_loss

In [10]:
"""
MIT License

Copyright (c) 2020 Vincent Sitzmann

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

class BatchLinear(nn.Linear):
    '''A linear layer'''
    # expose the parent's documentation instead of the one-liner above
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        """Apply the affine map, optionally with externally supplied parameters.

        Args:
            input: tensor whose last dimension equals ``in_features``.
            params: optional mapping with a ``'weight'`` entry and an optional
                ``'bias'`` entry; defaults to the layer's own parameters. The
                weight may carry leading batch dimensions.

        Returns:
            ``input @ weight^T`` plus the bias when one is present.
        """
        if params is None:
            params = OrderedDict(self.named_parameters())

        bias = params.get('bias', None)
        weight = params['weight']

        # swap the last two weight dimensions while keeping any leading batch dims
        output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2))
        # layers built with bias=False (or params without a 'bias' entry) have no bias to add
        if bias is not None:
            output = output + bias.unsqueeze(-2)
        return output


class Sine(nn.Module):
    """Sine activation computing ``sin(30 * input)``."""

    def __init__(self):
        # previously misspelled as `__init`, which made it a dead, name-mangled method
        super().__init__()

    def forward(self, input):
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        return torch.sin(30 * input)


class FCBlock(nn.Module):
    '''A fully connected neural network.

    The network is ``in_features -> hidden_features`` followed by
    ``num_hidden_layers`` hidden-to-hidden layers and a final
    ``hidden_features -> out_features`` layer. Every layer is a ``BatchLinear``
    followed by the chosen nonlinearity, except that the final layer has no
    nonlinearity when ``outermost_linear`` is True.

    Args:
        in_features: size of the input's last dimension.
        out_features: size of the output's last dimension.
        num_hidden_layers: number of hidden-to-hidden layers.
        hidden_features: width of the hidden layers.
        outermost_linear: if True, the final layer is purely linear.
        nonlinearity: one of 'sine', 'relu', 'sigmoid', 'tanh', 'selu',
            'softplus', 'elu'; any other key raises KeyError.
        weight_init: optional callable applied to every submodule, replacing the
            nonlinearity's default initialization.
    '''

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
                 outermost_linear=False, nonlinearity='relu', weight_init=None):
        super().__init__()

        self.first_layer_init = None

        # Dictionary that maps nonlinearity name to the respective function, initialization, and, if applicable,
        # special first-layer initialization scheme
        nls_and_inits = {'sine':(Sine(), sine_init, first_layer_sine_init),
                         'relu':(nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid':(nn.Sigmoid(), init_weights_xavier, None),
                         'tanh':(nn.Tanh(), init_weights_xavier, None),
                         'selu':(nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus':(nn.Softplus(), init_weights_normal, None),
                         'elu':(nn.ELU(inplace=True), init_weights_elu, None)}

        nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        # Overwrite weight init if passed
        self.weight_init = weight_init if weight_init is not None else nl_weight_init

        # Build the layer list locally; the same activation instance `nl` is shared by all layers.
        layers = [nn.Sequential(BatchLinear(in_features, hidden_features), nl)]
        for _ in range(num_hidden_layers):
            layers.append(nn.Sequential(BatchLinear(hidden_features, hidden_features), nl))

        if outermost_linear:
            layers.append(nn.Sequential(BatchLinear(hidden_features, out_features)))
        else:
            layers.append(nn.Sequential(BatchLinear(hidden_features, out_features), nl))

        self.net = nn.Sequential(*layers)
        if self.weight_init is not None:
            self.net.apply(self.weight_init)

        if first_layer_init is not None: # Apply special initialization to first layer, if applicable.
            self.net[0].apply(first_layer_init)

    def forward(self, coords, params=None, **kwargs):
        """Run ``coords`` through the network.

        ``params`` and ``**kwargs`` are accepted for interface compatibility but
        are ignored: the layers always use their own parameters.
        """
        return self.net(coords)


class SingleBVPNet(nn.Module):
    '''A canonical representation network for a BVP.

    Thin wrapper around one `FCBlock` (always built with `outermost_linear=True`)
    that takes a dict holding the coordinates and returns a dict holding both
    the differentiable copy of the coordinates and the network output.
    '''

    def __init__(self, out_features=1, type='sine', in_features=2,
                 mode='mlp', hidden_features=256, num_hidden_layers=3, **kwargs):
        '''Build the underlying fully connected block.

        Args:
            out_features: forwarded to `FCBlock` as `out_features`.
            type: nonlinearity name, forwarded to `FCBlock` as `nonlinearity`.
            in_features: forwarded to `FCBlock` as `in_features`.
            mode: stored on `self.mode`; not otherwise used in this class.
            hidden_features: forwarded to `FCBlock` as `hidden_features`.
            num_hidden_layers: forwarded to `FCBlock` as `num_hidden_layers`.
            **kwargs: accepted and ignored.
        '''
        super().__init__()
        self.mode = mode
        self.net = FCBlock(in_features=in_features, out_features=out_features, num_hidden_layers=num_hidden_layers,
                           hidden_features=hidden_features, outermost_linear=True, nonlinearity=type)
        # Echo the module structure once at construction time.
        print(self)

    def forward(self, model_input, params=None):
        '''Evaluate the network on `model_input['coords']`.

        Args:
            model_input: mapping with a `'coords'` tensor.
            params: accepted for interface compatibility but not used.

        Returns:
            dict with `'model_in'` (a detached clone of the coordinates with
            `requires_grad` enabled) and `'model_out'` (the network output
            computed from that clone).
        '''
        # The unused OrderedDict(self.named_parameters()) that used to be built
        # here on every call has been removed; it never influenced the output.

        # Enables us to compute gradients w.r.t. coordinates
        # TODO: should not need to .clone().detach().requires_grad_(True); instead, use .retain_grad() on input in calling script
        # otherwise, .detach() removes input from the graph so grad cannot propagate back end-to-end, e.g., percept -> NN -> state estimation (input)
        coords_org = model_input['coords'].clone().detach().requires_grad_(True)

        output = self.net(coords_org)
        return {'model_in': coords_org, 'model_out': output}


########################
# Initialization methods
def init_weights_normal(m):
    '''Kaiming-normal (fan-in, ReLU gain) initialisation of `m.weight`.

    Only modules whose exact type is `BatchLinear` or `nn.Linear` are touched;
    anything else is left as it is. Intended to be used with `Module.apply`.
    '''
    if type(m) not in (BatchLinear, nn.Linear):
        return
    if not hasattr(m, 'weight'):
        return
    nn.init.kaiming_normal_(m.weight, a=0.0, mode='fan_in', nonlinearity='relu')


def init_weights_selu(m):
    '''Normal initialisation of `m.weight` with std = 1 / sqrt(fan_in).

    Applies only when the exact type of `m` is `BatchLinear` or `nn.Linear`;
    fan_in is taken from the last dimension of the weight.
    '''
    if type(m) not in (BatchLinear, nn.Linear):
        return
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    nn.init.normal_(m.weight, std=1 / math.sqrt(fan_in))


def init_weights_elu(m):
    '''Normal initialisation of `m.weight` with std = sqrt(1.5505188080679277 / fan_in).

    Applies only when the exact type of `m` is `BatchLinear` or `nn.Linear`;
    fan_in is taken from the last dimension of the weight.
    '''
    if type(m) not in (BatchLinear, nn.Linear):
        return
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    # Kept as a ratio of two square roots so the value matches bit for bit.
    elu_std = math.sqrt(1.5505188080679277) / math.sqrt(fan_in)
    nn.init.normal_(m.weight, std=elu_std)


def init_weights_xavier(m):
    '''Xavier-normal initialisation of `m.weight`.

    Applies only when the exact type of `m` is `BatchLinear` or `nn.Linear`.
    '''
    is_linear = type(m) in (BatchLinear, nn.Linear)
    if is_linear and hasattr(m, 'weight'):
        nn.init.xavier_normal_(m.weight)


def sine_init(m):
    '''Uniformly re-initialise `m.weight` in [-sqrt(6 / fan_in) / 30, sqrt(6 / fan_in) / 30].

    Objects without a `weight` attribute are left untouched; fan_in is the
    size of the weight's last dimension. The update happens in place under
    `torch.no_grad()`.
    '''
    if not hasattr(m, 'weight'):
        return
    with torch.no_grad():
        fan_in = m.weight.size(-1)
        # See supplement Sec. 1.5 for discussion of factor 30
        bound = np.sqrt(6 / fan_in) / 30
        m.weight.uniform_(-bound, bound)


def first_layer_sine_init(m):
    '''Uniformly re-initialise `m.weight` in [-1 / fan_in, 1 / fan_in].

    Objects without a `weight` attribute are left untouched; fan_in is the
    size of the weight's last dimension. The update happens in place under
    `torch.no_grad()`.
    '''
    if not hasattr(m, 'weight'):
        return
    with torch.no_grad():
        fan_in = m.weight.size(-1)
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        limit = 1 / fan_in
        m.weight.uniform_(-limit, limit)


### 5. Dynamics Code

In [11]:
# during training, states will be sampled uniformly by each state dimension from the model-unit -1 to 1 range (for training stability),
# which may or may not correspond to proper test ranges
# note that coord refers to [time, *state], and input refers to whatever is fed directly to the model (often [time, *state, params])
# in the future, code will need to be fixed to correctly handle parameterized models
class Dynamics(ABC):
    """Abstract base class for a dynamical system used with the value network.

    Holds the problem description (loss type, set mode, dimensions) and the
    normalisation constants, and provides the conversions between model units
    (what the network sees) and real units (what the dynamics use).
    Subclasses supply the system-specific pieces through the abstract methods.

    Note: `state_var` and `value_var` are applied as multiplicative scales
    (see `input_to_coord` / `io_to_value`), not as statistical variances.
    """

    def __init__(self,
                 loss_type:str, set_mode:str,
                 state_dim:int, input_dim:int,
                 control_dim:int, disturbance_dim:int,
                 state_mean:list, state_var:list,
                 value_mean:float, value_var:float, value_normto:float,
                 deepreach_model:str):
        """Store the configuration and validate it.

        Raises:
            AssertionError: if `loss_type` or `set_mode` is not recognised, if
                `loss_type == 'brat_hjivi'` and `reach_fn` / `avoid_fn` are not
                callable, or if `state_mean` / `state_var` do not have
                `state_dim` entries.
        """
        # problem description
        self.loss_type, self.set_mode = loss_type, set_mode
        self.state_dim, self.input_dim = state_dim, input_dim
        self.control_dim, self.disturbance_dim = control_dim, disturbance_dim
        # normalisation constants
        self.state_mean = torch.tensor(state_mean)
        self.state_var = torch.tensor(state_var)
        self.value_mean, self.value_var, self.value_normto = value_mean, value_var, value_normto
        self.deepreach_model = deepreach_model

        assert self.loss_type in ['brt_hjivi', 'brat_hjivi'], f'loss type {self.loss_type} not recognized'
        if self.loss_type == 'brat_hjivi':
            # reach-avoid problems need both functions from the subclass
            assert callable(self.reach_fn) and callable(self.avoid_fn)
        assert self.set_mode in ['reach', 'avoid'], f'set mode {self.set_mode} not recognized'
        for descriptor in (self.state_mean, self.state_var):
            assert len(descriptor) == self.state_dim, 'state descriptor dimension does not equal state dimension, ' + str(len(descriptor)) + ' != ' + str(self.state_dim)

    # ALL METHODS ARE BATCH COMPATIBLE

    # MODEL-UNIT CONVERSIONS (TODO: refactor into separate model-unit conversion class?)

    def input_to_coord(self, input):
        """Convert model input to real coord: entry 0 is copied, the rest is scaled by `state_var` and shifted by `state_mean`."""
        scale = self.state_var.to(device=input.device)
        offset = self.state_mean.to(device=input.device)
        coord = input.clone()
        coord[..., 1:] = (input[..., 1:] * scale) + offset
        return coord

    def coord_to_input(self, coord):
        """Convert real coord to model input (inverse of `input_to_coord`)."""
        offset = self.state_mean.to(device=coord.device)
        scale = self.state_var.to(device=coord.device)
        input = coord.clone()
        input[..., 1:] = (coord[..., 1:] - offset) / scale
        return input

    def io_to_value(self, input, output):
        """Convert model io to real value.

        - `deepreach_model == "diff"`:  output * value_var / value_normto + boundary_fn(state)
        - `deepreach_model == "exact"`: output * input[..., 0] * value_var / value_normto + boundary_fn(state)
        - anything else:                output * value_var / value_normto + value_mean
        """
        mode = self.deepreach_model
        if mode not in ("diff", "exact"):
            return (output * self.value_var / self.value_normto) + self.value_mean

        boundary = self.boundary_fn(self.input_to_coord(input)[..., 1:])
        # In "exact" mode the network output is additionally multiplied by input[..., 0].
        net_term = output if mode == "diff" else output * input[..., 0]
        return (net_term * self.value_var / self.value_normto) + boundary

    def io_to_dv(self, input, output):
        """Convert model io to real dv.

        Returns a tensor whose last dimension is [dvdt, *dvds], obtained from
        the jacobian of `output` w.r.t. `input` by applying the chain rule for
        whichever parameterisation `deepreach_model` selects (see `io_to_value`).
        """
        dodi = jacobian(output.unsqueeze(dim=-1), input)[0].squeeze(dim=-2)
        value_scale = self.value_var / self.value_normto
        state_scale = value_scale / self.state_var.to(device=dodi.device)
        mode = self.deepreach_model

        if mode == "exact":
            # value = output * input[..., 0] * scale + boundary  ->  product rule on the first term
            first = input[..., 0]
            dvdt = value_scale * (first*dodi[..., 0] + output)
            dvds = state_scale * dodi[..., 1:] * first.unsqueeze(-1)
        else:
            dvdt = value_scale * dodi[..., 0]
            dvds = state_scale * dodi[..., 1:]

        if mode in ("diff", "exact"):
            # both parameterisations add boundary_fn(state), so add its state gradient
            state = self.input_to_coord(input)[..., 1:]
            boundary_grad = jacobian(self.boundary_fn(state).unsqueeze(dim=-1), state)[0].squeeze(dim=-2)
            dvds = dvds + boundary_grad

        return torch.cat((dvdt.unsqueeze(dim=-1), dvds), dim=-1)

    # ALL FOLLOWING METHODS USE REAL UNITS

    @abstractmethod
    def state_test_range(self):
        """Per-dimension state ranges; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def equivalent_wrapped_state(self, state):
        """Return a state equivalent to `state`; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def dsdt(self, state, control, disturbance):
        """State time derivative; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def boundary_fn(self, state):
        """Boundary function evaluated on real-unit states; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def sample_target_state(self, num_samples):
        """Sample `num_samples` target states; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def cost_fn(self, state_traj):
        """Cost of a state trajectory; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def hamiltonian(self, state, dvds):
        """Hamiltonian for `state` and value gradient `dvds`; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def optimal_control(self, state, dvds):
        """Optimal control for `state` and `dvds`; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def optimal_disturbance(self, state, dvds):
        """Optimal disturbance for `state` and `dvds`; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def plot_config(self):
        """Plotting configuration; must be implemented by subclasses."""
        raise NotImplementedError

class ParameterizedVertDrone2D(Dynamics):
    """Vertical drone with state [v, z, k].

    k multiplies the control in the velocity dynamics and is itself constant
    over time (see `dsdt`).
    """

    def __init__(self, gravity:float, input_multiplier_max:float, input_magnitude_max:float):
        self.gravity = gravity                             # g
        self.input_multiplier_max = input_multiplier_max   # k_max
        self.input_magnitude_max = input_magnitude_max     # u_max
        # k is centred on k_max/2 and scaled by the same amount
        k_half = self.input_multiplier_max/2
        super().__init__(
            loss_type='brt_hjivi', set_mode='avoid',
            state_dim=3, input_dim=4, control_dim=1, disturbance_dim=0,
            state_mean=[0, 1.5, k_half], # v, z, k
            state_var=[4, 2, k_half],    # v, z, k
            value_mean=0.25,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        """Return [low, high] for each of v, z, k."""
        v_range = [-4, 4]
        z_range = [-0.5, 3.5]
        k_range = [0, self.input_multiplier_max]
        return [v_range, z_range, k_range]

    def equivalent_wrapped_state(self, state):
        """No dimension is wrapped; return a copy of `state`."""
        return torch.clone(state)

    # ParameterizedVertDrone2D dynamics
    # \dot v = k*u - g
    # \dot z = v
    # \dot k = 0
    def dsdt(self, state, control, disturbance):
        """Time derivative of [v, z, k]; `disturbance` is not used."""
        v, k = state[..., 0], state[..., 2]
        u = control[..., 0]
        deriv = torch.zeros_like(state)
        deriv[..., 0] = k*u - self.gravity
        deriv[..., 1] = v
        deriv[..., 2] = 0
        return deriv

    def boundary_fn(self, state):
        """Return 1.5 - |z - 1.5|."""
        return -torch.abs(state[..., 1] - 1.5) + 1.5

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        raise NotImplementedError

    def hamiltonian(self, state, dvds):
        """Return k*|dV/dv * u_max| - dV/dv * g + dV/dz * v."""
        p_v, p_z = dvds[..., 0], dvds[..., 1]
        control_term = state[..., 2]*torch.abs(p_v*self.input_magnitude_max)
        return control_term \
                - p_v*self.gravity \
                + p_z*state[..., 0]

    def optimal_control(self, state, dvds):
        raise NotImplementedError

    def optimal_disturbance(self, state, dvds):
        raise NotImplementedError

    def plot_config(self):
        """Plot slices, labels and axis indices for the three state dimensions."""
        config = {}
        config['state_slices'] = [0, 1.5, self.input_multiplier_max/2]
        config['state_labels'] = ['v', 'z', 'k']
        config['x_axis_idx'] = 0
        config['y_axis_idx'] = 1
        config['z_axis_idx'] = 2
        return config

class Air3D(Dynamics):
    """Three-state system [x, y, psi] with one control and one disturbance.

    The boundary function is the planar distance from the origin minus
    `collisionR`.
    """

    def __init__(self, collisionR:float, velocity:float, omega_max:float, angle_alpha_factor:float):
        self.collisionR = collisionR
        self.velocity = velocity
        self.omega_max = omega_max
        self.angle_alpha_factor = angle_alpha_factor
        # the angle is scaled by alpha*pi; positions keep unit scale
        angle_scale = self.angle_alpha_factor*math.pi
        super().__init__(
            loss_type='brt_hjivi', set_mode='avoid',
            state_dim=3, input_dim=4, control_dim=1, disturbance_dim=1,
            state_mean=[0, 0, 0],
            state_var=[1, 1, angle_scale],
            value_mean=0.25,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        """Return [low, high] for x, y and the angle."""
        position_range = (-1, 1)
        return [
            list(position_range),
            list(position_range),
            [-math.pi, math.pi],
        ]

    def equivalent_wrapped_state(self, state):
        """Return a copy of `state` with the angle (index 2) wrapped into [-pi, pi)."""
        wrapped = torch.clone(state)
        two_pi = 2*math.pi
        wrapped[..., 2] = (wrapped[..., 2] + math.pi) % two_pi - math.pi
        return wrapped

    # Air3D dynamics
    # \dot x    = -v + v \cos \psi + u y
    # \dot y    = v \sin \psi - u x
    # \dot \psi = d - u
    def dsdt(self, state, control, disturbance):
        """Time derivative of [x, y, psi] for control u and disturbance d."""
        x, y, psi = state[..., 0], state[..., 1], state[..., 2]
        u = control[..., 0]
        d = disturbance[..., 0]
        deriv = torch.zeros_like(state)
        deriv[..., 0] = -self.velocity + self.velocity*torch.cos(psi) + u*y
        deriv[..., 1] = self.velocity*torch.sin(psi) - u*x
        deriv[..., 2] = d - u
        return deriv

    def boundary_fn(self, state):
        """Distance of (x, y) from the origin minus `collisionR`."""
        planar = state[..., :2]
        return torch.norm(planar, dim=-1) - self.collisionR

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        """Minimum of `boundary_fn` over the last dimension of its output."""
        boundary_values = self.boundary_fn(state_traj)
        return torch.min(boundary_values, dim=-1).values

    def hamiltonian(self, state, dvds):
        """Sum of the control, disturbance and drift parts of the Hamiltonian."""
        x, y, psi = state[..., 0], state[..., 1], state[..., 2]
        p_x, p_y, p_psi = dvds[..., 0], dvds[..., 1], dvds[..., 2]

        control_part = self.omega_max * torch.abs(p_x * y - p_y * x - p_psi)  # Control component
        disturbance_part = self.omega_max * torch.abs(p_psi)  # Disturbance component
        drift_x = self.velocity * (torch.cos(psi) - 1.0) * p_x  # Constant component
        drift_y = self.velocity * torch.sin(psi) * p_y

        ham = control_part - disturbance_part
        ham = ham + drift_x + drift_y
        return ham

    def optimal_control(self, state, dvds):
        """Return omega_max * sign(dV/dx * y - dV/dy * x - dV/dpsi) with a trailing unit dimension."""
        switching = dvds[..., 0]*state[..., 1] - dvds[..., 1]*state[..., 0]-dvds[..., 2]
        u = self.omega_max * torch.sign(switching)
        return u[..., None]

    def optimal_disturbance(self, state, dvds):
        """Return -omega_max * sign(dV/dpsi) with a trailing unit dimension."""
        d = -self.omega_max * torch.sign(dvds[..., 2])
        return d[..., None]

    def plot_config(self):
        """Plot slices, labels and axis indices for the three state dimensions."""
        config = {}
        config['state_slices'] = [0, 0, 0]
        config['state_labels'] = ['x', 'y', 'theta']
        config['x_axis_idx'] = 0
        config['y_axis_idx'] = 1
        config['z_axis_idx'] = 2
        return config

class Dubins3D(Dynamics):
    """Dubins car with state [x, y, theta] and a single turn-rate control.

    `set_mode` ('reach' or 'avoid') picks the sign of the control term in the
    Hamiltonian and of the optimal control. When `freeze_model` is truthy,
    `dsdt` and `hamiltonian` raise `NotImplementedError`.
    """

    def __init__(self, goalR:float, velocity:float, omega_max:float, angle_alpha_factor:float, set_mode:str, freeze_model: bool):
        self.goalR = goalR
        self.velocity = velocity
        self.omega_max = omega_max
        self.angle_alpha_factor = angle_alpha_factor
        self.freeze_model = freeze_model
        # the heading is scaled by alpha*pi; positions keep unit scale
        angle_scale = self.angle_alpha_factor*math.pi
        super().__init__(
            loss_type='brt_hjivi', set_mode=set_mode,
            state_dim=3, input_dim=4, control_dim=1, disturbance_dim=0,
            state_mean=[0, 0, 0],
            state_var=[1, 1, angle_scale],
            value_mean=0.25,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact"
        )

    def state_test_range(self):
        """Return [low, high] for x, y and theta."""
        position_range = (-1, 1)
        return [
            list(position_range),
            list(position_range),
            [-math.pi, math.pi],
        ]

    def equivalent_wrapped_state(self, state):
        """Return a copy of `state` with theta (index 2) wrapped into [-pi, pi)."""
        wrapped = torch.clone(state)
        two_pi = 2*math.pi
        wrapped[..., 2] = (wrapped[..., 2] + math.pi) % two_pi - math.pi
        return wrapped

    # Dubins3D dynamics
    # \dot x    = v \cos \theta
    # \dot y    = v \sin \theta
    # \dot \theta = u
    def dsdt(self, state, control, disturbance):
        """Time derivative of [x, y, theta]; `disturbance` is not used."""
        if self.freeze_model:
            raise NotImplementedError
        theta = state[..., 2]
        deriv = torch.zeros_like(state)
        deriv[..., 0] = self.velocity*torch.cos(theta)
        deriv[..., 1] = self.velocity*torch.sin(theta)
        deriv[..., 2] = control[..., 0]
        return deriv

    def boundary_fn(self, state):
        """Distance of (x, y) from the origin minus `goalR`."""
        planar = state[..., :2]
        return torch.norm(planar, dim=-1) - self.goalR

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        """Minimum of `boundary_fn` over the last dimension of its output."""
        boundary_values = self.boundary_fn(state_traj)
        return torch.min(boundary_values, dim=-1).values

    def hamiltonian(self, state, dvds):
        """Drift term minus (reach) or plus (avoid) omega_max * |dV/dtheta|.

        Returns None for any other `set_mode`.
        """
        if self.freeze_model:
            raise NotImplementedError
        if self.set_mode not in ('reach', 'avoid'):
            return None
        theta = state[..., 2]
        drift = self.velocity*(torch.cos(theta) * dvds[..., 0] + torch.sin(theta) * dvds[..., 1])
        turning = self.omega_max * torch.abs(dvds[..., 2])
        if self.set_mode == 'reach':
            return drift - turning
        return drift + turning

    def optimal_control(self, state, dvds):
        """Bang-bang turn rate: -omega_max*sign(dV/dtheta) for reach, +omega_max*sign(dV/dtheta) for avoid.

        Returns None for any other `set_mode`.
        """
        if self.set_mode == 'reach':
            gain = -self.omega_max
        elif self.set_mode == 'avoid':
            gain = self.omega_max
        else:
            return None
        return (gain*torch.sign(dvds[..., 2]))[..., None]

    def optimal_disturbance(self, state, dvds):
        """There is no disturbance (disturbance_dim=0); return the scalar 0."""
        return 0

    def plot_config(self):
        """Plot slices, labels and axis indices for the three state dimensions."""
        config = {}
        config['state_slices'] = [0, 0, 0]
        config['state_labels'] = ['x', 'y', r'$\theta$']
        config['x_axis_idx'] = 0
        config['y_axis_idx'] = 1
        config['z_axis_idx'] = 2
        return config

class Dubins4D(Dynamics):
    """14-dimensional system whose boundary is the distance between two planar points minus `collisionR`.

    Only the normalisation constants, the test ranges, the angle wrapping and
    the boundary function are implemented; the dynamics, Hamiltonian, optimal
    inputs, cost and plotting hooks all raise `NotImplementedError`.
    """

    def __init__(self, bound_mode:str, set_mode:str='avoid'):
        """Configure the normalisation constants and register them with `Dynamics`.

        Args:
            bound_mode: 'v1' or 'v2'; selects the scale of the `o*` state entries
                (3*pi for 'v1', 2.0 for 'v2').
            set_mode: 'reach' or 'avoid', forwarded to `Dynamics.__init__`.
                Previously nothing was forwarded, so `Dynamics.__init__` raised
                `TypeError` for the missing required argument and the class
                could never be instantiated.
                NOTE(review): 'avoid' is chosen as the default because the
                boundary function is a collision-radius distance, matching
                `Air3D`; confirm this is the intended mode.

        Raises:
            AssertionError: if `bound_mode` is not 'v1' or 'v2', or if the
                base-class validation fails.
        """
        self.vMin = 0.2
        self.vMax = 14.8
        self.collisionR = 1.5
        self.bound_mode = bound_mode
        assert self.bound_mode in ['v1', 'v2']

        xMean = 0
        yMean = 0
        thetaMean = 0
        vMean = 7.5
        aMean = 0
        oMean = 0

        xVar = 10
        yVar = 10
        thetaVar = 1.2*math.pi
        vVar = 7.5
        aVar = 10
        oVar = 3*math.pi if self.bound_mode == 'v1' else 2.0

        super().__init__(
            loss_type='brt_hjivi', set_mode=set_mode,
            state_dim=14, input_dim=15,  control_dim=2, disturbance_dim=0,
            state_mean=[xMean, yMean, thetaMean, vMean, xMean, yMean, aMean, aMean, oMean, oMean, aMean, aMean, oMean, oMean],
            state_var=[xVar, yVar, thetaVar, vVar, xVar, yVar, aVar, aVar, oVar, oVar, aVar, aVar, oVar, oVar],
            value_mean=13,
            value_var=14,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        """Return [low, high] for each of the 14 state dimensions."""
        return [
            [-1, 1],
            [-1, 1],
            [-math.pi, math.pi],
            [self.vMin, self.vMax],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
        ]

    def equivalent_wrapped_state(self, state):
        """Return a copy of `state` with the angle at index 2 wrapped into [-pi, pi)."""
        wrapped_state = torch.clone(state)
        wrapped_state[..., 2] = (wrapped_state[..., 2] + math.pi) % (2*math.pi) - math.pi
        return wrapped_state

    def boundary_fn(self, state):
        """Distance between the points state[..., 0:2] and state[..., 4:6], minus `collisionR`."""
        return torch.norm(state[..., 0:2] - state[..., 4:6], dim=-1) - self.collisionR

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        raise NotImplementedError

    def dsdt(self, state, control, disturbance):
        raise NotImplementedError

    def hamiltonian(self, state, dvds):
        raise NotImplementedError

    def optimal_control(self, state, dvds):
        raise NotImplementedError

    def optimal_disturbance(self, state, dvds):
        raise NotImplementedError

    def plot_config(self):
        raise NotImplementedError

class NarrowPassage(Dynamics):
    """Two car-like vehicles between two curbs with a stranded car in the way.

    state = [x1, y1, th1, v1, phi1, x2, y2, th2, v2, phi2]; the four controls
    are [a1, psi1, a2, psi2] (see `dsdt`). With `avoid_only` the problem uses
    only `avoid_fn` ('brt_hjivi' / 'avoid'); otherwise it is a reach-avoid
    problem ('brat_hjivi' / 'reach').
    """

    def __init__(self, avoid_fn_weight:float, avoid_only:bool):
        self.L = 2.0

        # Target positions (index 0 is used for vehicle 1, index 1 for vehicle 2 in reach_fn)
        self.goalX = [6.0, -6.0]
        self.goalY = [-1.4, 1.4]

        # State bounds
        self.vMin, self.vMax = 0.001, 6.50
        self.phiMin = -0.3*math.pi + 0.001
        self.phiMax = 0.3*math.pi - 0.001

        # Control bounds
        self.aMin, self.aMax = -4.0, 2.0
        self.psiMin = -3.0*math.pi
        self.psiMax = 3.0*math.pi

        # Lower and upper curb positions (in the y direction)
        self.curb_positions = [-2.8, 2.8]

        # Stranded car position
        self.stranded_car_pos = [0.0, -1.8]

        self.avoid_fn_weight = avoid_fn_weight
        self.avoid_only = avoid_only

        # both vehicles share the same normalisation: [x, y, th, v, phi]
        vehicle_mean = [0, 0, 0, 3, 0]
        vehicle_var = [8.0, 3.8, 1.2*math.pi, 4.0, 1.2*0.3*math.pi]

        super().__init__(
            loss_type='brt_hjivi' if self.avoid_only else 'brat_hjivi', set_mode='avoid' if self.avoid_only else 'reach',
            state_dim=10, input_dim=11, control_dim=4, disturbance_dim=0,
            # state = [x1, y1, th1, v1, phi1, x2, y2, th2, v2, phi2]
            state_mean=vehicle_mean + vehicle_mean,
            state_var=vehicle_var + vehicle_var,
            value_mean=0.25*8.0,
            value_var=0.5*8.0,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        """Return [low, high] for each of the ten state dimensions."""
        vehicle_ranges = [
            [-8, 8],
            [-3.8, 3.8],
            [-math.pi, math.pi],
            [-1, 7],
            [-0.3*math.pi, 0.3*math.pi],
        ]
        # fresh inner lists for each vehicle so the entries are independent
        return [list(bounds) for bounds in vehicle_ranges + vehicle_ranges]

    def equivalent_wrapped_state(self, state):
        """Return a copy of `state` with the angular entries (indices 2, 4, 7, 9) wrapped into [-pi, pi)."""
        wrapped = torch.clone(state)
        for angle_idx in (2, 4, 7, 9):
            wrapped[..., angle_idx] = (wrapped[..., angle_idx] + math.pi) % (2*math.pi) - math.pi
        return wrapped

    # NarrowPassage dynamics (identical for both vehicles)
    # \dot x   = v * cos(th)
    # \dot y   = v * sin(th)
    # \dot th  = v * tan(phi) / L
    # \dot v   = u1
    # \dot phi = u2
    def dsdt(self, state, control, disturbance):
        """Time derivative of the ten-dimensional state; `disturbance` is not used."""
        deriv = torch.zeros_like(state)
        # (state offset, control offset) per vehicle
        for base, ctrl in ((0, 0), (5, 2)):
            heading = state[..., base + 2]
            speed = state[..., base + 3]
            steer = state[..., base + 4]
            deriv[..., base] = speed*torch.cos(heading)
            deriv[..., base + 1] = speed*torch.sin(heading)
            deriv[..., base + 2] = speed*torch.tan(steer) / self.L
            deriv[..., base + 3] = control[..., ctrl]
            deriv[..., base + 4] = control[..., ctrl + 1]
        return deriv

    def reach_fn(self, state):
        """Larger of the two vehicles' distances to their goals (each minus L).

        Raises:
            RuntimeError: when the instance was built with `avoid_only`.
        """
        if self.avoid_only:
            raise RuntimeError
        goal_distances = []
        # vehicle 1 uses state[..., 0:2], vehicle 2 uses state[..., 5:7]
        for vehicle, pos in enumerate((slice(0, 2), slice(5, 7))):
            goal = torch.tensor([self.goalX[vehicle], self.goalY[vehicle]], device=state.device)
            goal_distances.append(torch.norm(state[..., pos] - goal, dim=-1) - self.L)
        return torch.maximum(goal_distances[0], goal_distances[1])

    def avoid_fn(self, state):
        """Weighted minimum clearance to the curbs, the stranded car and the other vehicle."""
        y1, y2 = state[..., 1], state[..., 6]
        pos1, pos2 = state[..., 0:2], state[..., 5:7]
        lower_curb, upper_curb = self.curb_positions[0], self.curb_positions[1]

        # distance from lower curb
        dist_lc = torch.minimum(y1 - lower_curb - 0.5*self.L, y2 - lower_curb - 0.5*self.L)

        # distance from upper curb
        dist_uc = torch.minimum(upper_curb - y1 - 0.5*self.L, upper_curb - y2 - 0.5*self.L)

        # distance from the stranded car
        stranded = torch.tensor(self.stranded_car_pos, device=state.device)
        dist_stranded = torch.minimum(
            torch.norm(pos1 - stranded, dim=-1) - self.L,
            torch.norm(pos2 - stranded, dim=-1) - self.L,
        )

        # distance between the vehicles themselves
        dist_between = torch.norm(pos1 - pos2, dim=-1) - self.L

        clearance = torch.min(dist_lc, dist_uc)
        clearance = torch.min(clearance, dist_stranded)
        clearance = torch.min(clearance, dist_between)
        return self.avoid_fn_weight * clearance

    def boundary_fn(self, state):
        """`avoid_fn` when `avoid_only`, otherwise max(reach_fn, -avoid_fn)."""
        if self.avoid_only:
            return self.avoid_fn(state)
        return torch.maximum(self.reach_fn(state), -self.avoid_fn(state))

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        """Trajectory cost, reduced over the last dimension of the reach/avoid values."""
        if self.avoid_only:
            return torch.min(self.avoid_fn(state_traj), dim=-1).values
        # return min_t max{l(x(t)), max_k_up_to_t{-g(x(k))}}, where l(x) is reach_fn, g(x) is avoid_fn
        reach_values = self.reach_fn(state_traj)
        avoid_values = self.avoid_fn(state_traj)
        worst_violation = torch.cummax(-avoid_values, dim=-1).values
        return torch.min(torch.maximum(reach_values, worst_violation), dim=-1).values

    def hamiltonian(self, state, dvds):
        """Hamiltonian evaluated at the control returned by `optimal_control`."""
        u_opt = self.optimal_control(state, dvds)
        # terms are listed, and later summed, in the order of the state dimensions
        terms = [
            state[..., 3] * torch.cos(state[..., 2]) * dvds[..., 0],
            state[..., 3] * torch.sin(state[..., 2]) * dvds[..., 1],
            state[..., 3] * torch.tan(state[..., 4]) * dvds[..., 2] / self.L,
            u_opt[..., 0] * dvds[..., 3],
            u_opt[..., 1] * dvds[..., 4],
            state[..., 8] * torch.cos(state[..., 7]) * dvds[..., 5],
            state[..., 8] * torch.sin(state[..., 7]) * dvds[..., 6],
            state[..., 8] * torch.tan(state[..., 9]) * dvds[..., 7] / self.L,
            u_opt[..., 2] * dvds[..., 8],
            u_opt[..., 3] * dvds[..., 9],
        ]
        ham = terms[0]
        for term in terms[1:]:
            ham = ham + term
        return ham

    def optimal_control(self, state, dvds):
        """Bang-bang control [a1, psi1, a2, psi2] chosen from the sign of the matching `dvds` entry.

        Each bound is multiplied by a boolean mask, so it becomes zero once the
        corresponding state has reached its limit (vMin/vMax, phiMin/phiMax).
        """
        # (state index, control lower/upper bound, state lower/upper limit)
        channels = (
            (3, self.aMin, self.aMax, self.vMin, self.vMax),
            (4, self.psiMin, self.psiMax, self.phiMin, self.phiMax),
            (8, self.aMin, self.aMax, self.vMin, self.vMax),
            (9, self.psiMin, self.psiMax, self.phiMin, self.phiMax),
        )
        controls = []
        for idx, u_low, u_high, s_low, s_high in channels:
            lower = u_low * (state[..., idx] > s_low)
            upper = u_high * (state[..., idx] < s_high)
            costate = dvds[..., idx]
            if self.avoid_only:
                use_lower = costate < 0
            else:
                use_lower = costate > 0
            controls.append(torch.where(use_lower, lower, upper)[..., None])
        return torch.cat(controls, dim=-1)

    def optimal_disturbance(self, state, dvds):
        """There is no disturbance (disturbance_dim=0); return the scalar 0."""
        return 0

    def plot_config(self):
        """Plot slices, labels and axis indices for the ten state dimensions."""
        config = {}
        config['state_slices'] = [
            -6.0, -1.4, 0.0, 6.5, 0.0,
            -6.0, 1.4, -math.pi, 0.0, 0.0
        ]
        config['state_labels'] = [
            r'$x_1$', r'$y_1$', r'$\theta_1$', r'$v_1$', r'$\phi_1$',
            r'$x_2$', r'$y_2$', r'$\theta_2$', r'$v_2$', r'$\phi_2$',
        ]
        config['x_axis_idx'] = 0
        config['y_axis_idx'] = 1
        config['z_axis_idx'] = 2
        return config

class ReachAvoidRocketLanding(Dynamics):
    """Reach-avoid rocket landing problem with state [x, y, th, v_x, v_y, w].

    The target is a box around the origin (|x| <= 20, y <= 20, see `reach_fn`);
    the obstacles are the floor y = 0 and a rectangular wall spanning
    x in [-30, -20], y in [0, 100] (see `avoid_fn`).
    """

    def __init__(self):
        super().__init__(
            loss_type='brat_hjivi', set_mode='reach',
            state_dim=6, input_dim=7, control_dim=2, disturbance_dim=0,
            state_mean=[0.0, 80.0, 0.0, 0.0, 0.0, 0.0],
            state_var=[150.0, 70.0, 1.2*math.pi, 200.0, 200.0, 10.0],
            value_mean=0.0,
            value_var=1.0,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        """Return [low, high] for each of x, y, th, v_x, v_y, w."""
        return [
            [-150, 150],
            [10, 150],
            [-math.pi, math.pi],
            [-200, 200],
            [-200, 200],
            [-10, 10],
        ]

    def equivalent_wrapped_state(self, state):
        """Return a copy of `state` with th (index 2) wrapped into [-pi, pi)."""
        wrapped_state = torch.clone(state)
        wrapped_state[..., 2] = (wrapped_state[..., 2] + math.pi) % (2*math.pi) - math.pi
        return wrapped_state

    # \dot x = v_x
    # \dot y = v_y
    # \dot th = w
    # \dot v_x = u1 * cos(th) - u2 sin(th)
    # \dot v_y = u1 * sin(th) + u2 cos(th) - 9.81
    # \dot w = 0.3 * u1
    def dsdt(self, state, control, disturbance):
        """Time derivative of the six-dimensional state; `disturbance` is not used."""
        dsdt = torch.zeros_like(state)
        dsdt[..., 0] = state[..., 3]
        dsdt[..., 1] = state[..., 4]
        dsdt[..., 2] = state[..., 5]
        dsdt[..., 3] = control[..., 0]*torch.cos(state[..., 2]) - control[..., 1]*torch.sin(state[..., 2])
        dsdt[..., 4] = control[..., 0]*torch.sin(state[..., 2]) + control[..., 1]*torch.cos(state[..., 2]) - 9.81
        dsdt[..., 5] = 0.3*control[..., 0]
        return dsdt

    def reach_fn(self, state):
        """Normalised signed distance to the target box (negative inside)."""
        # Only target set in the xy direction
        # Target set position in x direction
        dist_x = torch.abs(state[..., 0]) - 20.0 #[-20, 150] boundary_fn range

        # Target set position in y direction
        dist_y = state[..., 1] - 20.0  #[-10, 130] boundary_fn range

        # First compute the target function as you normally would but then normalize it later.
        max_dist = torch.max(dist_x, dist_y)
        # Outside the target divide by 150, inside by 10, so the two sides are scaled separately.
        return torch.where((max_dist >= 0), max_dist/150.0, max_dist/10.0)

    def avoid_fn(self, state):
        """Minimum of the height above the floor and the signed distance to the wall rectangle."""
        # distance to floor
        dist_y = state[..., 1]

        # distance to wall
        wall_left = -30
        wall_right = -20
        wall_bottom = 0
        wall_top = 100
        dist_left = wall_left - state[..., 0]
        dist_right = state[..., 0] - wall_right
        dist_bottom = wall_bottom - state[..., 1]
        dist_top = state[..., 1] - wall_top
        dist_wall_x = torch.max(dist_left, dist_right)
        dist_wall_y = torch.max(dist_bottom, dist_top)

        # One zero scalar on the state's device/dtype, so no separate CPU tensor
        # is allocated for every comparison below.
        zero = state.new_zeros(())
        # Outside the rectangle: Euclidean distance to it.
        outside = torch.norm(torch.cat((torch.max(zero, dist_wall_x).unsqueeze(-1), torch.max(zero, dist_wall_y).unsqueeze(-1)), dim=-1), dim=-1)
        # Inside the rectangle: (non-positive) distance to the nearest edge.
        inside = torch.min(zero, torch.max(dist_wall_x, dist_wall_y))
        dist_wall = outside + inside

        return torch.min(dist_y, dist_wall)

    def boundary_fn(self, state):
        """Return max(reach_fn, -avoid_fn)."""
        return torch.maximum(self.reach_fn(state), -self.avoid_fn(state))

    def sample_target_state(self, num_samples):
        """Draw `num_samples` states uniformly, with x and y restricted to the target region."""
        target_state_range = self.state_test_range()
        target_state_range[0] = [-20, 20] # x in [-20, 20]
        target_state_range[1] = [10, 20]  # y in [10, 20]
        target_state_range = torch.tensor(target_state_range)
        return target_state_range[:, 0] + torch.rand(num_samples, self.state_dim)*(target_state_range[:, 1] - target_state_range[:, 0])

    def cost_fn(self, state_traj):
        """Reach-avoid trajectory cost, reduced over the last dimension of the reach/avoid values."""
        # return min_t max{l(x(t)), max_k_up_to_t{-g(x(k))}}, where l(x) is reach_fn, g(x) is avoid_fn
        reach_values = self.reach_fn(state_traj)
        avoid_values = self.avoid_fn(state_traj)
        return torch.min(torch.maximum(reach_values, torch.cummax(-avoid_values, dim=-1).values), dim=-1).values

    def hamiltonian(self, state, dvds):
        """Hamiltonian: control part -250 * ||(u1_coeff, u2_coeff)|| plus the control-independent part."""
        # Control Hamiltonian
        u1_coeff = dvds[..., 3] * torch.cos(state[..., 2]) + dvds[..., 4] * torch.sin(state[..., 2]) + 0.3 * dvds[..., 5]
        u2_coeff = -dvds[..., 3] * torch.sin(state[..., 2]) + dvds[..., 4] * torch.cos(state[..., 2])
        ham_ctrl = -250.0 * torch.sqrt(u1_coeff * u1_coeff + u2_coeff * u2_coeff)
        # Constant Hamiltonian
        ham_constant = dvds[..., 0] * state[..., 3] + dvds[..., 1] * state[..., 4] + \
                      dvds[..., 2] * state[..., 5]  - dvds[..., 4] * 9.81
        # Compute the Hamiltonian
        ham_vehicle = ham_ctrl + ham_constant
        return ham_vehicle

    def optimal_control(self, state, dvds):
        """Control of magnitude 250 pointing opposite to the (u1_coeff, u2_coeff) vector."""
        u1_coeff = dvds[..., 3] * torch.cos(state[..., 2]) + dvds[..., 4] * torch.sin(state[..., 2]) + 0.3 * dvds[..., 5]
        u2_coeff = -dvds[..., 3] * torch.sin(state[..., 2]) + dvds[..., 4] * torch.cos(state[..., 2])
        # Adding pi flips the direction of the coefficient vector.
        opt_angle = torch.atan2(u2_coeff, u1_coeff) + math.pi
        return torch.cat((250.0 * torch.cos(opt_angle)[..., None], 250.0 * torch.sin(opt_angle)[..., None]), dim=-1)

    def optimal_disturbance(self, state, dvds):
        """There is no disturbance (disturbance_dim=0); return the scalar 0."""
        return 0

    def plot_config(self):
        """Plot slices, labels and axis indices for the six state dimensions."""
        return {
            'state_slices': [-100, 120, 0, 150, -5, 0.0],
            # The last label was r'$\omega' (missing its closing '$'), so it
            # was not a valid math-text label like the others.
            'state_labels': ['x', 'y', r'$\theta$', r'$v_x$', r'$v_y$', r'$\omega$'],
            'x_axis_idx': 0,
            'y_axis_idx': 1,
            'z_axis_idx': 4,
        }

class RocketLanding(Dynamics):
    """Planar rocket-landing dynamics set up as a 'reach' problem ('brt_hjivi' loss).

    State (6 entries): [x, y, theta, v_x, v_y, omega]; control (2 entries): [u1, u2].
    The model input has 8 entries: time, the 6 normalized states, and one trailing
    column that ``coord_to_input`` fills with zeros and ``input_to_coord`` drops.
    """

    def __init__(self):
        super().__init__(
            loss_type='brt_hjivi', set_mode='reach',
            state_dim=6, input_dim=8, control_dim=2, disturbance_dim=0,
            state_mean=[0.0, 80.0, 0.0, 0.0, 0.0, 0.0],
            state_var=[150.0, 70.0, 1.2*math.pi, 200.0, 200.0, 10.0],
            value_mean=0.0,
            value_var=1.0,
            value_normto=0.02,
            deepreach_model="exact",
        )

    # convert model input to real coord
    def input_to_coord(self, input):
        """Drop the trailing padding column and un-normalize the state columns.

        Column 0 (time) is passed through unchanged; columns 1: are mapped with
        ``input * state_var + state_mean``.
        """
        input = input[..., :-1]
        coord = input.clone()
        coord[..., 1:] = (input[..., 1:] * self.state_var.to(device=input.device)) + self.state_mean.to(device=input.device)
        return coord

    # convert real coord to model input
    def coord_to_input(self, coord):
        """Normalize the state columns and append one all-zero trailing column."""
        input = coord.clone()
        input[..., 1:] = (coord[..., 1:] - self.state_mean.to(device=coord.device)) / self.state_var.to(device=coord.device)
        input = torch.cat((input, torch.zeros((*input.shape[:-1], 1), device=input.device)), dim=-1)
        return input

    # convert model io to real value
    def io_to_value(self, input, output):
        """Rescale the raw network output to a real value.

        In "diff" mode the boundary function of the decoded state is added;
        in every other mode ``value_mean`` is added instead.
        """
        scaled_output = output * self.value_var / self.value_normto
        if self.deepreach_model=="diff":
            return scaled_output + self.boundary_fn(self.input_to_coord(input)[..., 1:])
        else:
            return scaled_output + self.value_mean

    # convert model io to real dv
    def io_to_dv(self, input, output):
        """Convert the network's input-output Jacobian to real [dV/dt, dV/ds].

        The derivative with respect to the trailing padding column is discarded.
        Returns a tensor whose last axis is [dvdt, dvds...].
        """
        dodi = jacobian(output.unsqueeze(dim=-1), input)[0].squeeze(dim=-2)[..., :-1]

        # The time derivative is scaled the same way in every mode.
        dvdt = (self.value_var / self.value_normto) * dodi[..., 0]
        dvds = (self.value_var / self.value_normto / self.state_var.to(device=dodi.device)) * dodi[..., 1:]

        if self.deepreach_model=="diff":
            # The value is (scaled output + boundary_fn), so add the boundary gradient.
            state = self.input_to_coord(input)[..., 1:]
            dvds_boundary = jacobian(self.boundary_fn(state).unsqueeze(dim=-1), state)[0].squeeze(dim=-2)
            dvds = dvds + dvds_boundary

        return torch.cat((dvdt.unsqueeze(dim=-1), dvds), dim=-1)

    def state_test_range(self):
        """Per-state [min, max] ranges, in state order."""
        return [
            [-150, 150],
            [10, 150],
            [-math.pi, math.pi],
            [-200, 200],
            [-200, 200],
            [-10, 10],
        ]

    def equivalent_wrapped_state(self, state):
        """Return a copy of ``state`` with the angle (index 2) wrapped into [-pi, pi)."""
        wrapped_state = torch.clone(state)
        wrapped_state[..., 2] = (wrapped_state[..., 2] + math.pi) % (2*math.pi) - math.pi
        return wrapped_state

    # \dot x = v_x
    # \dot y = v_y
    # \dot th = w
    # \dot v_x = u1 * cos(th) - u2 sin(th)
    # \dot v_y = u1 * sin(th) + u2 cos(th) - 9.81
    # \dot w = 0.3 * u1
    def dsdt(self, state, control, disturbance):
        """State derivative for the dynamics above; ``disturbance`` is unused."""
        dsdt = torch.zeros_like(state)
        dsdt[..., 0] = state[..., 3]
        dsdt[..., 1] = state[..., 4]
        dsdt[..., 2] = state[..., 5]
        dsdt[..., 3] = control[..., 0]*torch.cos(state[..., 2]) - control[..., 1]*torch.sin(state[..., 2])
        dsdt[..., 4] = control[..., 0]*torch.sin(state[..., 2]) + control[..., 1]*torch.cos(state[..., 2]) - 9.81
        dsdt[..., 5] = 0.3*control[..., 0]
        return dsdt

    def boundary_fn(self, state):
        """Scaled boundary value l(x) for the target box |state[0]| <= 20, state[1] <= 20.

        Only the two position states (indices 0 and 1, labelled 'x' and 'y' in
        plot_config) enter the target set.
        """
        # Offset of state 0 from the target band; over state_test_range this
        # lies in [-20, 130].
        dist_y = torch.abs(state[..., 0]) - 20.0

        # Offset of state 1 above the target ceiling; over state_test_range this
        # lies in [-10, 130].
        dist_z = state[..., 1] - 20.0

        # First compute l(x) as usual, then scale the positive and negative
        # parts by different constants.
        lx = torch.max(dist_y, dist_z)
        return torch.where((lx >= 0), lx/150.0, lx/10.0)

    def sample_target_state(self, num_samples):
        """Uniformly sample ``num_samples`` states from the target box."""
        target_state_range = self.state_test_range()
        target_state_range[0] = [-20, 20] # state 0 ('x') in [-20, 20]
        target_state_range[1] = [10, 20]  # state 1 ('y') in [10, 20]
        target_state_range = torch.tensor(target_state_range)
        return target_state_range[:, 0] + torch.rand(num_samples, self.state_dim)*(target_state_range[:, 1] - target_state_range[:, 0])

    def cost_fn(self, state_traj):
        """Minimum of the boundary value over the last axis of its output."""
        return torch.min(self.boundary_fn(state_traj), dim=-1).values

    @staticmethod
    def _control_coeffs(state, dvds):
        """Coefficients multiplying u1 and u2 in <dvds, dsdt>."""
        cos_th = torch.cos(state[..., 2])
        sin_th = torch.sin(state[..., 2])
        u1_coeff = dvds[..., 3] * cos_th + dvds[..., 4] * sin_th + 0.3 * dvds[..., 5]
        u2_coeff = -dvds[..., 3] * sin_th + dvds[..., 4] * cos_th
        return u1_coeff, u2_coeff

    def hamiltonian(self, state, dvds):
        """Hamiltonian: control term (-250 * coefficient norm) plus the drift term."""
        # Control Hamiltonian
        u1_coeff, u2_coeff = self._control_coeffs(state, dvds)
        ham_ctrl = -250.0 * torch.sqrt(u1_coeff * u1_coeff + u2_coeff * u2_coeff)
        # Constant Hamiltonian
        ham_constant = dvds[..., 0] * state[..., 3] + dvds[..., 1] * state[..., 4] + \
                      dvds[..., 2] * state[..., 5]  - dvds[..., 4] * 9.81
        # Compute the Hamiltonian
        ham_vehicle = ham_ctrl + ham_constant
        return ham_vehicle

    def optimal_control(self, state, dvds):
        """Control of magnitude 250 pointing opposite to the coefficient vector."""
        u1_coeff, u2_coeff = self._control_coeffs(state, dvds)
        # Adding pi flips the direction of (u1_coeff, u2_coeff).
        opt_angle = torch.atan2(u2_coeff, u1_coeff) + math.pi
        return torch.cat((250.0 * torch.cos(opt_angle)[..., None], 250.0 * torch.sin(opt_angle)[..., None]), dim=-1)

    def optimal_disturbance(self, state, dvds):
        """This system has disturbance_dim=0; always returns 0."""
        return 0

    def plot_config(self):
        """Plotting configuration: default state slice, labels and plotted axes."""
        return {
            'state_slices': [-100, 120, 0, 150, -5, 0.0],
            # Mathtext labels need balanced '$' delimiters; the omega label
            # previously lacked its closing '$'.
            'state_labels': ['x', 'y', r'$\theta$', r'$v_x$', r'$v_y$', r'$\omega$'],
            'x_axis_idx': 0,
            'y_axis_idx': 1,
            'z_axis_idx': 4,
        }

class Quadrotor(Dynamics):
    """13-state quadrotor with four bounded control inputs.

    State order (see plot_config):
    [x, y, z, qw, qx, qy, qz, vx, vy, vz, wx, wy, wz].
    The Hamiltonian and optimal control saturate each control entry at
    +/- ``thrust_max``. Only set_mode='avoid' is implemented for them.
    """

    def __init__(self, collisionR:float, thrust_max:float, set_mode:str):
        """
        Args:
            collisionR: radius subtracted from the position norm in boundary_fn.
            thrust_max: magnitude bound applied to each of the four controls.
            set_mode: forwarded to Dynamics; hamiltonian/optimal_control only
                support 'avoid'.
        """
        self.m = 1         # mass
        self.arm_l = 0.17  # presumably the arm length -- TODO confirm
        self.CT = 1        # presumably the thrust coefficient -- TODO confirm
        self.CM = 0.016    # presumably the moment coefficient -- TODO confirm
        self.Gz = -9.8     # constant acceleration added to the vz rate

        self.thrust_max = thrust_max
        self.collisionR = collisionR

        super().__init__(
            loss_type='brt_hjivi', set_mode=set_mode,
            state_dim=13, input_dim=14, control_dim=4, disturbance_dim=0,
            state_mean=[0 for i in range(13)],
            state_var=[1.5, 1.5, 1.5, 1, 1, 1, 1, 10, 10 ,10 ,10 ,10 ,10],
            value_mean=(math.sqrt(1.5**2+1.5**2+1.5**2)-2*self.collisionR)/2,
            value_var=math.sqrt(1.5**2+1.5**2+1.5**2),
            value_normto=0.02,
            deepreach_model="exact"
        )

    def state_test_range(self):
        """Per-state [min, max] ranges, in state order."""
        return [
            [-1.5, 1.5],
            [-1.5, 1.5],
            [-1.5, 1.5],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-1, 1],
            [-10, 10],
            [-10, 10],
            [-10, 10],
            [-10, 10],
            [-10, 10],
            [-10, 10],
        ]

    def equivalent_wrapped_state(self, state):
        """Return a copy of ``state``; no component is wrapped."""
        wrapped_state = torch.clone(state)
        return wrapped_state

    # Quadrotor dynamics, with s = u1 + u2 + u3 + u4:
    # \dot [x, y, z]        = [vx, vy, vz]
    # \dot [qw, qx, qy, qz] = quaternion rates driven by [wx, wy, wz]
    # \dot [vx, vy, vz]     = [C1, C2, C3] * s + [0, 0, Gz]
    # \dot [wx, wy, wz]     = signed combinations of the u_i plus wy*wz / wx*wz coupling
    def dsdt(self, state, control, disturbance):
        """State derivative for the dynamics above; ``disturbance`` is unused."""
        qw = state[..., 3] * 1.0
        qx = state[..., 4] * 1.0
        qy = state[..., 5] * 1.0
        qz = state[..., 6] * 1.0
        vx = state[..., 7] * 1.0
        vy = state[..., 8] * 1.0
        vz = state[..., 9] * 1.0
        wx = state[..., 10] * 1.0
        wy = state[..., 11] * 1.0
        wz = state[..., 12] * 1.0
        u1 = control[...,0] * 1.0
        u2 = control[...,1] * 1.0
        u3 = control[...,2] * 1.0
        u4 = control[...,3] * 1.0

        dsdt = torch.zeros_like(state)
        dsdt[..., 0] = vx
        dsdt[..., 1] = vy
        dsdt[..., 2] = vz
        dsdt[..., 3] = -(wx*qx+wy*qy+wz*qz)/2.0
        dsdt[..., 4] =  (wx*qw+wz*qy-wy*qz)/2.0
        dsdt[..., 5] = (wy*qw-wz*qx+wx*qz)/2.0
        dsdt[..., 6] = (wz*qw+wy*qx-wx*qy)/2.0
        dsdt[..., 7] = 2*(qw*qy+qx*qz)*self.CT/self.m*(u1+u2+u3+u4)
        dsdt[..., 8] =2*(-qw*qx+qy*qz)*self.CT/self.m*(u1+u2+u3+u4)
        dsdt[..., 9] =self.Gz+(1-2*torch.pow(qx,2)-2*torch.pow(qy,2))*self.CT/self.m*(u1+u2+u3+u4)
        dsdt[..., 10] = 4*math.sqrt(2)*self.CT*(u1-u2-u3+u4)/(3*self.arm_l*self.m)-5*wy*wz/9.0
        dsdt[..., 11] = 4*math.sqrt(2)*self.CT*(-u1-u2+u3+u4)/(3*self.arm_l*self.m)+5*wx*wz/9.0
        dsdt[..., 12] =12*self.CT*self.CM/(7*self.arm_l**2*self.m)*(u1-u2+u3-u4)
        return dsdt

    def boundary_fn(self, state):
        """Norm of the position (first three states) minus the collision radius."""
        return torch.norm(state[..., :3], dim=-1) - self.collisionR

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        """Minimum of the boundary value over the last axis of its output."""
        return torch.min(self.boundary_fn(state_traj), dim=-1).values

    def _control_coeffs(self, state, dvds):
        """Coefficients multiplying u1..u4 in <dvds, dsdt>, as a 4-tuple."""
        qw = state[..., 3] * 1.0
        qx = state[..., 4] * 1.0
        qy = state[..., 5] * 1.0
        qz = state[..., 6] * 1.0

        C1=2*(qw*qy+qx*qz)*self.CT/self.m
        C2=2*(-qw*qx+qy*qz)*self.CT/self.m
        C3=(1-2*torch.pow(qx,2)-2*torch.pow(qy,2))*self.CT/self.m
        C4=4*math.sqrt(2)*self.CT/(3*self.arm_l*self.m)
        C5=4*math.sqrt(2)*self.CT/(3*self.arm_l*self.m)
        C6=12*self.CT*self.CM/(7*self.arm_l**2*self.m)

        # Every control enters the velocity rates identically (through the sum of u_i).
        common = dvds[..., 7]*C1+dvds[..., 8]*C2+dvds[..., 9]*C3
        # The signs below mirror how each u_i enters the wx, wy, wz rates in dsdt.
        coeff_u1 = common + dvds[..., 10]*C4 - dvds[..., 11]*C5 + dvds[..., 12]*C6
        coeff_u2 = common - dvds[..., 10]*C4 - dvds[..., 11]*C5 - dvds[..., 12]*C6
        coeff_u3 = common - dvds[..., 10]*C4 + dvds[..., 11]*C5 + dvds[..., 12]*C6
        coeff_u4 = common + dvds[..., 10]*C4 + dvds[..., 11]*C5 - dvds[..., 12]*C6
        return coeff_u1, coeff_u2, coeff_u3, coeff_u4

    def hamiltonian(self, state, dvds):
        """Hamiltonian for set_mode='avoid'.

        Raises:
            NotImplementedError: for any other set_mode (previously an unknown
                mode silently returned None).
        """
        if self.set_mode != 'avoid':
            raise NotImplementedError(
                "Quadrotor.hamiltonian only supports set_mode='avoid' (got %r)" % (self.set_mode,))

        qw = state[..., 3] * 1.0
        qx = state[..., 4] * 1.0
        qy = state[..., 5] * 1.0
        qz = state[..., 6] * 1.0
        vx = state[..., 7] * 1.0
        vy = state[..., 8] * 1.0
        vz = state[..., 9] * 1.0
        wx = state[..., 10] * 1.0
        wy = state[..., 11] * 1.0
        wz = state[..., 12] * 1.0

        # Control-independent part of <dvds, dsdt>.
        ham= dvds[..., 0]*vx + dvds[..., 1]*vy+ dvds[..., 2]*vz
        ham+= -dvds[..., 3]* (wx*qx+wy*qy+wz*qz)/2.0
        ham+= dvds[..., 4]*(wx*qw+wz*qy-wy*qz)/2.0
        ham+= dvds[..., 5]*(wy*qw-wz*qx+wx*qz)/2.0
        ham+= dvds[..., 6]*(wz*qw+wy*qx-wx*qy)/2.0
        ham+= dvds[..., 9]*self.Gz
        ham+= -dvds[..., 10]*5*wy*wz/9.0+ dvds[..., 11]*5*wx*wz/9.0

        # Each control contributes |coefficient| * thrust_max.
        for coeff in self._control_coeffs(state, dvds):
            ham+=torch.abs(coeff)*self.thrust_max

        return ham

    def optimal_control(self, state, dvds):
        """Bang-bang control for set_mode='avoid': thrust_max * sign(coefficient).

        Raises:
            NotImplementedError: for any other set_mode (previously an unknown
                mode failed with an unbound-local error).
        """
        if self.set_mode != 'avoid':
            raise NotImplementedError(
                "Quadrotor.optimal_control only supports set_mode='avoid' (got %r)" % (self.set_mode,))

        u1, u2, u3, u4 = (
            self.thrust_max*torch.sign(coeff)
            for coeff in self._control_coeffs(state, dvds)
        )
        return torch.cat((u1[..., None], u2[..., None], u3[..., None], u4[..., None]), dim=-1)

    def optimal_disturbance(self, state, dvds):
        """This system has disturbance_dim=0; always returns 0."""
        return 0

    def plot_config(self):
        """Plotting configuration: default state slice, labels and plotted axes."""
        return {
            'state_slices': [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            'state_labels': ['x', 'y', 'z', 'qw', 'qx', 'qy', 'qz', 'vx', 'vy', 'vz', 'wx', 'wy', 'wz'],
            'x_axis_idx': 0,
            'y_axis_idx': 2,
            'z_axis_idx': 7,
        }

class MultiVehicleCollision(Dynamics):
    """Three constant-speed cars that must avoid pairwise collisions.

    State layout: [x1, y1, x2, y2, x3, y3, theta1, theta2, theta3].
    Control entry k is the heading rate of car k.
    """

    def __init__(self):
        self.angle_alpha_factor = 1.2
        self.velocity = 0.6
        self.omega_max = 1.1
        self.collisionR = 0.25
        heading_scale = self.angle_alpha_factor*math.pi
        super().__init__(
            loss_type='brt_hjivi', set_mode='avoid',
            state_dim=9, input_dim=10, control_dim=3, disturbance_dim=0,
            # Six position entries followed by three heading entries.
            state_mean=[0 for _ in range(9)],
            state_var=[1 for _ in range(6)] + [heading_scale for _ in range(3)],
            value_mean=0.25,
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact"
        )

    def state_test_range(self):
        """Per-state [min, max] ranges: six positions, then three headings."""
        position_ranges = [[-1, 1] for _ in range(6)]
        heading_ranges = [[-math.pi, math.pi] for _ in range(3)]
        return position_ranges + heading_ranges

    def equivalent_wrapped_state(self, state):
        """Return a copy of ``state`` with the three headings wrapped into [-pi, pi)."""
        wrapped_state = torch.clone(state)
        wrapped_state[..., 6:9] = (wrapped_state[..., 6:9] + math.pi) % (2*math.pi) - math.pi
        return wrapped_state

    # dynamics (per car)
    # \dot x    = v \cos \theta
    # \dot y    = v \sin \theta
    # \dot \theta = u
    def dsdt(self, state, control, disturbance):
        """State derivative for all three cars; ``disturbance`` is unused."""
        rates = torch.zeros_like(state)
        for car in range(3):
            heading = state[..., 6 + car]
            rates[..., 2*car] = self.velocity*torch.cos(heading)
            rates[..., 2*car + 1] = self.velocity*torch.sin(heading)
            rates[..., 6 + car] = control[..., car]
        return rates

    def boundary_fn(self, state):
        """Smallest pairwise distance between cars, minus the collision radius."""
        positions = [state[..., 2*car:2*car + 2] for car in range(3)]
        # Car 0 against each other car first, then the remaining pair.
        boundary_values = None
        for first, second in ((0, 1), (0, 2), (1, 2)):
            clearance = torch.norm(positions[first] - positions[second], dim=-1) - self.collisionR
            if boundary_values is None:
                boundary_values = clearance
            else:
                boundary_values = torch.min(boundary_values, clearance)
        return boundary_values

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        """Minimum of the boundary value over the last axis of its output."""
        per_step = self.boundary_fn(state_traj)
        return per_step.min(dim=-1).values

    def hamiltonian(self, state, dvds):
        """Sum over cars of the drift term plus omega_max * |dV/dtheta|."""
        ham = None
        for car in range(3):
            heading = state[..., 6 + car]
            car_term = self.velocity*(torch.cos(heading) * dvds[..., 2*car] + torch.sin(heading) * dvds[..., 2*car + 1]) + self.omega_max * torch.abs(dvds[..., 6 + car])
            ham = car_term if ham is None else ham + car_term
        return ham

    def optimal_control(self, state, dvds):
        """Bang-bang heading rates: omega_max * sign(dV/dtheta_k) for each car."""
        heading_grads = dvds[..., 6:9]
        return self.omega_max*torch.sign(heading_grads)

    def optimal_disturbance(self, state, dvds):
        """This system has disturbance_dim=0; always returns 0."""
        return 0

    def plot_config(self):
        """Plotting configuration: default state slice, labels and plotted axes."""
        state_slices = [
            0, 0,
            -0.4, 0,
            0.4, 0,
            math.pi/2, math.pi/4, 3*math.pi/4,
        ]
        state_labels = [
            r'$x_1$', r'$y_1$',
            r'$x_2$', r'$y_2$',
            r'$x_3$', r'$y_3$',
            r'$\theta_1$', r'$\theta_2$', r'$\theta_3$',
        ]
        return {
            'state_slices': state_slices,
            'state_labels': state_labels,
            'x_axis_idx': 0,
            'y_axis_idx': 1,
            'z_axis_idx': 6,
        }

###############################################################
# Custom Dynamics for 2d Planar Robot
###############################################################
class PlanarRobot2D(Dynamics):
    """2D constant-speed robot whose single control is its heading angle.

    State: [p_x, p_y]. The boundary function is the distance from the origin
    minus ``goalR``.
    """

    def __init__(
        self,
        goalR: float,
        velocity: float,
        set_mode: str = 'avoid',
        freeze_model: bool = False
    ):
        """
        2D constant-speed robot dynamics for reach/avoid problems.

        Args:
            goalR: Radius of the obstacle (avoid) or goal (reach) circle.
            velocity: Constant forward speed (m/s).
            set_mode: 'reach' or 'avoid'.
            freeze_model: If True, dsdt/hamiltonian will raise NotImplementedError.
        """
        self.goalR = goalR
        self.velocity = velocity
        self.freeze_model = freeze_model

        super().__init__(
            loss_type='brt_hjivi',
            set_mode=set_mode,
            state_dim=2, # [p_x, p_y]
            input_dim=3, # [p_x, p_y, t]
            control_dim=1, # [theta]
            disturbance_dim=0,

            # Normalize x,y from [-2,2] to [-1,1]
            state_mean=[0.0, 0.0],
            state_var=[2.0, 2.0],

            value_mean=0.0, # value_mean is not used in 'exact' mode
            value_var=0.5,
            value_normto=0.02,
            deepreach_model="exact",
        )

    def state_test_range(self):
        """Per-state [min, max] ranges, in state order."""
        return [
            [-2.0, 2.0], # p_x
            [-2.0, 2.0], # p_y
        ]

    def equivalent_wrapped_state(self, state):
        """No state is periodic; the input tensor itself is returned (not a copy)."""
        return state

    def dsdt(self, state, control, disturbance):
        """State derivative; ``disturbance`` is unused.

        Raises:
            NotImplementedError: if ``freeze_model`` is set.
        """
        # \dot p_x = v \cos \theta
        # \dot p_y = v \sin \theta
        if self.freeze_model:
            raise NotImplementedError
        theta = control[..., 0]
        ds = torch.zeros_like(state)
        ds[..., 0] = self.velocity * torch.cos(theta)
        ds[..., 1] = self.velocity * torch.sin(theta)
        return ds

    def boundary_fn(self, state):
        """Distance of the state from the origin minus ``goalR``."""
        return torch.norm(state, dim=-1) - self.goalR

    def sample_target_state(self, num_samples):
        raise NotImplementedError

    def cost_fn(self, state_traj):
        raise NotImplementedError
        # return torch.min(self.boundary_fn(state_traj), dim=-1).values

    def hamiltonian(self, state, dvds):
        """Hamiltonian: -v*|dvds| for 'reach', +v*|dvds| for 'avoid'.

        Raises:
            NotImplementedError: if ``freeze_model`` is set, or if ``set_mode``
                is neither 'reach' nor 'avoid' (previously this case silently
                returned None).
        """
        if self.freeze_model:
            raise NotImplementedError
        norm = torch.sqrt(dvds[..., 0]**2 + dvds[..., 1]**2)
        if self.set_mode == 'reach':
            return  - self.velocity * norm
        if self.set_mode == 'avoid':
            return  self.velocity * norm
        raise NotImplementedError(
            "PlanarRobot2D.hamiltonian only supports set_mode 'reach' or 'avoid' (got %r)" % (self.set_mode,))

    def optimal_control(self, state, dvds):
        raise NotImplementedError

    def optimal_disturbance(self, state, dvds):
        raise NotImplementedError

    def plot_config(self):
        """Plotting configuration; 'z_axis_idx' of -1 marks that there is no third axis."""
        return {
            "state_slices": [0, 0],
            "state_labels": ["p_x", "p_y"],
            "x_axis_idx": 0,
            "y_axis_idx": 1,
            "z_axis_idx": -1,
        }

###############################################################


### 6. Experiments Code

In [12]:
class Experiment(ABC):
    def __init__(self, model, dataset, experiment_dir, use_wandb):
        self.model = model
        self.dataset = dataset
        self.experiment_dir = experiment_dir
        self.use_wandb = use_wandb

    @abstractmethod
    def init_special(self):
        raise NotImplementedError

    def _load_checkpoint(self, epoch):
        if epoch == -1:
            model_path = os.path.join(self.experiment_dir, 'training', 'checkpoints', 'model_final.pth')
            self.model.load_state_dict(torch.load(model_path))
        else:
            model_path = os.path.join(self.experiment_dir, 'training', 'checkpoints', 'model_epoch_%04d.pth' % epoch)
            self.model.load_state_dict(torch.load(model_path)['model'])

    def validate(self, device, epoch, save_path, x_resolution, y_resolution, z_resolution, time_resolution):
        """Plot the sign of the learned value function on 2-D slices and save the figure.

        One subplot is drawn per (time, z-slice) pair: rows index the time
        samples in [0, dataset.tMax], columns index the samples of the optional
        third axis. When the dynamics' plot_config has 'z_axis_idx' == -1 (or no
        such key) a single column is drawn and ``z_resolution`` is ignored.

        Args:
            device: device the query coordinates are moved to before the forward pass.
            epoch: logged as 'step' when wandb logging is enabled.
            save_path: where the figure is written.
            x_resolution, y_resolution: grid size along the two plotted axes.
            z_resolution: number of slices along the third axis.
            time_resolution: number of time samples.

        The model's train/eval mode and every parameter's ``requires_grad``
        flag are restored to their pre-call values on exit, including when an
        exception is raised.
        """
        was_training = self.model.training
        # Remember each flag individually so the exact prior state comes back,
        # also when the model was already in eval mode or had frozen parameters.
        prior_requires_grad = [param.requires_grad for param in self.model.parameters()]
        self.model.eval()
        self.model.requires_grad_(False)

        fig = None
        try:
            dynamics = self.dataset.dynamics
            plot_config = dynamics.plot_config()
            state_test_range = dynamics.state_test_range()

            x_idx = plot_config['x_axis_idx']
            y_idx = plot_config['y_axis_idx']
            z_idx = plot_config.get('z_axis_idx', -1)

            x_min, x_max = state_test_range[x_idx]
            y_min, y_max = state_test_range[y_idx]
            xs = torch.linspace(x_min, x_max, x_resolution)
            ys = torch.linspace(y_min, y_max, y_resolution)
            xys = torch.cartesian_prod(xs, ys)

            # Determine z slices: if z_idx == -1, use a single dummy slice
            if z_idx == -1:
                zs = torch.tensor([0.0], dtype=torch.float32)
            else:
                z_min, z_max = state_test_range[z_idx]
                zs = torch.linspace(z_min, z_max, z_resolution)

            times = torch.linspace(0, self.dataset.tMax, time_resolution)
            state_slices = torch.tensor(plot_config['state_slices'])

            # The grid has len(times) rows and len(zs) columns, so the width
            # scales with the number of z slices and the height with the times.
            fig = plt.figure(figsize=(5*len(zs), 5*len(times)))

            # Loop over time and z-slices
            for i, t in enumerate(times):
                for j, z_val in enumerate(zs):
                    coords = torch.zeros(x_resolution*y_resolution, dynamics.state_dim + 1)
                    coords[:, 0] = t
                    coords[:, 1:] = state_slices

                    coords[:, 1 + x_idx] = xys[:, 0]
                    coords[:, 1 + y_idx] = xys[:, 1]

                    # Only fill z if it's a real 3D axis
                    if z_idx != -1:
                        coords[:, 1 + z_idx] = z_val

                    with torch.no_grad():
                        inp = dynamics.coord_to_input(coords.to(device))
                        results = self.model({'coords': inp})
                        values = dynamics.io_to_value(results['model_in'].detach(), results['model_out'].squeeze(dim=-1).detach())

                    ax = fig.add_subplot(len(times), len(zs), (j+1) + i*len(zs))
                    title_z = (f", {plot_config['state_labels'][z_idx]} = {z_val:.2f}"
                            if z_idx != -1 else "")
                    ax.set_title(f"t = {t:.2f}{title_z}")

                    # cartesian_prod is x-major, so reshape to (x, y) and
                    # transpose to get rows = y for imshow.
                    img = values.cpu().numpy().reshape(x_resolution, y_resolution).T
                    # binary decision boundary: <=0 in red/blue
                    mask = (img <= 0).astype(float)
                    s = ax.imshow(mask, cmap='bwr', origin='lower', extent=(x_min, x_max, y_min, y_max))
                    fig.colorbar(s, ax=ax)

            fig.savefig(save_path)
            if self.use_wandb:
                wandb.log({
                    'step': epoch,
                    'val_plot': wandb.Image(fig),
                })
        finally:
            # Close this specific figure rather than whichever one is current.
            if fig is not None:
                plt.close(fig)
            if was_training:
                self.model.train()
            for param, flag in zip(self.model.parameters(), prior_requires_grad):
                param.requires_grad_(flag)

    def train(
            self, device, batch_size, epochs, lr,
            steps_til_summary, epochs_til_checkpoint,
            loss_fn, clip_grad, use_lbfgs, adjust_relative_grads,
            val_x_resolution, val_y_resolution, val_z_resolution, val_time_resolution,
            use_CSL, CSL_lr, CSL_dt, epochs_til_CSL, num_CSL_samples, CSL_loss_frac_cutoff, max_CSL_epochs, CSL_loss_weight, CSL_batch_size,
        ):
        was_eval = not self.model.training
        self.model.train()
        self.model.requires_grad_(True)

        train_dataloader = DataLoader(self.dataset, shuffle=True, batch_size=batch_size, pin_memory=True, num_workers=0)

        optim = torch.optim.Adam(lr=lr, params=self.model.parameters())

        # copy settings from Raissi et al. (2019) and here
        # https://github.com/maziarraissi/PINNs
        if use_lbfgs:
            optim = torch.optim.LBFGS(lr=lr, params=self.model.parameters(), max_iter=50000, max_eval=50000,
                                    history_size=50, line_search_fn='strong_wolfe')

        training_dir = os.path.join(self.experiment_dir, 'training')

        summaries_dir = os.path.join(training_dir, 'summaries')
        if not os.path.exists(summaries_dir):
            os.makedirs(summaries_dir)

        checkpoints_dir = os.path.join(training_dir, 'checkpoints')
        if not os.path.exists(checkpoints_dir):
            os.makedirs(checkpoints_dir)

        writer = SummaryWriter(summaries_dir)

        total_steps = 0

        if adjust_relative_grads:
            new_weight = 1

        with tqdm(total=len(train_dataloader) * epochs) as pbar:
            train_losses = []
            last_CSL_epoch = -1
            for epoch in range(0, epochs):
                if self.dataset.pretrain: # skip CSL
                    last_CSL_epoch = epoch
                time_interval_length = (self.dataset.counter/self.dataset.counter_end)*(self.dataset.tMax-self.dataset.tMin)
                CSL_tMax = self.dataset.tMin + int(time_interval_length/CSL_dt)*CSL_dt

                # self-supervised learning
                for step, (model_input, gt) in enumerate(train_dataloader):
                    start_time = time.time()

                    model_input = {key: value.to(device) for key, value in model_input.items()}
                    gt = {key: value.to(device) for key, value in gt.items()}

                    model_results = self.model({'coords': model_input['model_coords']})

                    states = self.dataset.dynamics.input_to_coord(model_results['model_in'].detach())[..., 1:]
                    values = self.dataset.dynamics.io_to_value(model_results['model_in'].detach(), model_results['model_out'].squeeze(dim=-1))
                    dvs = self.dataset.dynamics.io_to_dv(model_results['model_in'], model_results['model_out'].squeeze(dim=-1))
                    boundary_values = gt['boundary_values']
                    if self.dataset.dynamics.loss_type == 'brat_hjivi':
                        reach_values = gt['reach_values']
                        avoid_values = gt['avoid_values']
                    dirichlet_masks = gt['dirichlet_masks']

                    if self.dataset.dynamics.loss_type == 'brt_hjivi':
                        losses = loss_fn(states, values, dvs[..., 0], dvs[..., 1:], boundary_values, dirichlet_masks, model_results['model_out'])
                    elif self.dataset.dynamics.loss_type == 'brat_hjivi':
                        losses = loss_fn(states, values, dvs[..., 0], dvs[..., 1:], boundary_values, reach_values, avoid_values, dirichlet_masks, model_results['model_out'])
                    else:
                        raise NotImplementedError

                    if use_lbfgs:
                        def closure():
                            optim.zero_grad()
                            train_loss = 0.
                            for loss_name, loss in losses.items():
                                train_loss += loss.mean()
                            train_loss.backward()
                            return train_loss
                        optim.step(closure)

                    # Adjust the relative magnitude of the losses if required
                    if self.dataset.dynamics.deepreach_model in ['vanilla', 'diff'] and adjust_relative_grads:
                        if losses['diff_constraint_hom'] > 0.01:
                            params = OrderedDict(self.model.named_parameters())
                            # Gradients with respect to the PDE loss
                            optim.zero_grad()
                            losses['diff_constraint_hom'].backward(retain_graph=True)
                            grads_PDE = []
                            for key, param in params.items():
                                grads_PDE.append(param.grad.view(-1))
                            grads_PDE = torch.cat(grads_PDE)

                            # Gradients with respect to the boundary loss
                            optim.zero_grad()
                            losses['dirichlet'].backward(retain_graph=True)
                            grads_dirichlet = []
                            for key, param in params.items():
                                grads_dirichlet.append(param.grad.view(-1))
                            grads_dirichlet = torch.cat(grads_dirichlet)

                            # # Plot the gradients
                            # import seaborn as sns
                            # import matplotlib.pyplot as plt
                            # fig = plt.figure(figsize=(5, 5))
                            # ax = fig.add_subplot(1, 1, 1)
                            # ax.set_yscale('symlog')
                            # sns.distplot(grads_PDE.cpu().numpy(), hist=False, kde_kws={"shade": False}, norm_hist=True)
                            # sns.distplot(grads_dirichlet.cpu().numpy(), hist=False, kde_kws={"shade": False}, norm_hist=True)
                            # fig.savefig('gradient_visualization.png')

                            # fig = plt.figure(figsize=(5, 5))
                            # ax = fig.add_subplot(1, 1, 1)
                            # ax.set_yscale('symlog')
                            # grads_dirichlet_normalized = grads_dirichlet * torch.mean(torch.abs(grads_PDE))/torch.mean(torch.abs(grads_dirichlet))
                            # sns.distplot(grads_PDE.cpu().numpy(), hist=False, kde_kws={"shade": False}, norm_hist=True)
                            # sns.distplot(grads_dirichlet_normalized.cpu().numpy(), hist=False, kde_kws={"shade": False}, norm_hist=True)
                            # ax.set_xlim([-1000.0, 1000.0])
                            # fig.savefig('gradient_visualization_normalized.png')

                            # Set the new weight according to the paper
                            # num = torch.max(torch.abs(grads_PDE))
                            num = torch.mean(torch.abs(grads_PDE))
                            den = torch.mean(torch.abs(grads_dirichlet))
                            new_weight = 0.9*new_weight + 0.1*num/den
                            losses['dirichlet'] = new_weight*losses['dirichlet']
                        writer.add_scalar('weight_scaling', new_weight, total_steps)

                    # import ipdb; ipdb.set_trace()

                    train_loss = 0.
                    for loss_name, loss in losses.items():
                        single_loss = loss.mean()

                        if loss_name == 'dirichlet':
                            writer.add_scalar(loss_name, single_loss/new_weight, total_steps)
                        else:
                            writer.add_scalar(loss_name, single_loss, total_steps)
                        train_loss += single_loss

                    train_losses.append(train_loss.item())
                    writer.add_scalar("total_train_loss", train_loss, total_steps)

                    if not total_steps % steps_til_summary:
                        torch.save(self.model.state_dict(),
                                os.path.join(checkpoints_dir, 'model_current.pth'))
                        # summary_fn(model, model_input, gt, model_output, writer, total_steps)

                    if not use_lbfgs:
                        optim.zero_grad()
                        train_loss.backward()

                        if clip_grad:
                            if isinstance(clip_grad, bool):
                                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
                            else:
                                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=clip_grad)

                        optim.step()

                    pbar.update(1)

                    if not total_steps % steps_til_summary:
                        tqdm.write("Epoch %d, Total loss %0.6f, iteration time %0.6f" % (epoch, train_loss, time.time() - start_time))
                        if self.use_wandb:
                            wandb.log({
                                'step': epoch,
                                'train_loss': train_loss,
                                'pde_loss': losses['diff_constraint_hom'],
                            })

                    total_steps += 1

                # cost-supervised learning (CSL) phase
                if use_CSL and not self.dataset.pretrain and (epoch-last_CSL_epoch) >= epochs_til_CSL:
                    last_CSL_epoch = epoch

                    # generate CSL datasets
                    self.model.eval()

                    CSL_dataset = scenario_optimization(
                        device=device, model=self.model, policy=self.model, dynamics=self.dataset.dynamics,
                        tMin=self.dataset.tMin, tMax=CSL_tMax, dt=CSL_dt,
                        set_type="BRT", control_type="value", # TODO: implement option for BRS too
                        scenario_batch_size=min(num_CSL_samples, 100000), sample_batch_size=min(10*num_CSL_samples, 1000000),
                        sample_generator=SliceSampleGenerator(dynamics=self.dataset.dynamics, slices=[None]*self.dataset.dynamics.state_dim),
                        sample_validator=ValueThresholdValidator(v_min=float('-inf'), v_max=float('inf')),
                        violation_validator=ValueThresholdValidator(v_min=0.0, v_max=0.0),
                        max_scenarios=num_CSL_samples, tStart_generator=lambda n : torch.zeros(n).uniform_(self.dataset.tMin, CSL_tMax)
                    )
                    CSL_coords = torch.cat((CSL_dataset['times'].unsqueeze(-1), CSL_dataset['states']), dim=-1)
                    CSL_costs = CSL_dataset['costs']

                    num_CSL_val_samples = int(0.1*num_CSL_samples)
                    CSL_val_dataset = scenario_optimization(
                        model=self.model, policy=self.model, dynamics=self.dataset.dynamics,
                        tMin=self.dataset.tMin, tMax=CSL_tMax, dt=CSL_dt,
                        set_type="BRT", control_type="value", # TODO: implement option for BRS too
                        scenario_batch_size=min(num_CSL_val_samples, 100000), sample_batch_size=min(10*num_CSL_val_samples, 1000000),
                        sample_generator=SliceSampleGenerator(dynamics=self.dataset.dynamics, slices=[None]*self.dataset.dynamics.state_dim),
                        sample_validator=ValueThresholdValidator(v_min=float('-inf'), v_max=float('inf')),
                        violation_validator=ValueThresholdValidator(v_min=0.0, v_max=0.0),
                        max_scenarios=num_CSL_val_samples, tStart_generator=lambda n : torch.zeros(n).uniform_(self.dataset.tMin, CSL_tMax)
                    )
                    CSL_val_coords = torch.cat((CSL_val_dataset['times'].unsqueeze(-1), CSL_val_dataset['states']), dim=-1)
                    CSL_val_costs = CSL_val_dataset['costs']

                    CSL_val_tMax_dataset = scenario_optimization(
                        model=self.model, policy=self.model, dynamics=self.dataset.dynamics,
                        tMin=self.dataset.tMin, tMax=self.dataset.tMax, dt=CSL_dt,
                        set_type="BRT", control_type="value", # TODO: implement option for BRS too
                        scenario_batch_size=min(num_CSL_val_samples, 100000), sample_batch_size=min(10*num_CSL_val_samples, 1000000),
                        sample_generator=SliceSampleGenerator(dynamics=self.dataset.dynamics, slices=[None]*self.dataset.dynamics.state_dim),
                        sample_validator=ValueThresholdValidator(v_min=float('-inf'), v_max=float('inf')),
                        violation_validator=ValueThresholdValidator(v_min=0.0, v_max=0.0),
                        max_scenarios=num_CSL_val_samples # no tStart_generator, since I want all tMax times
                    )
                    CSL_val_tMax_coords = torch.cat((CSL_val_tMax_dataset['times'].unsqueeze(-1), CSL_val_tMax_dataset['states']), dim=-1)
                    CSL_val_tMax_costs = CSL_val_tMax_dataset['costs']

                    self.model.train()

                    # CSL optimizer
                    CSL_optim = torch.optim.Adam(lr=CSL_lr, params=self.model.parameters())

                    # initial CSL val loss
                    CSL_val_results = self.model({'coords': self.dataset.dynamics.coord_to_input(CSL_val_coords.to(device))})
                    CSL_val_preds = self.dataset.dynamics.io_to_value(CSL_val_results['model_in'], CSL_val_results['model_out'].squeeze(dim=-1))
                    CSL_val_errors = CSL_val_preds - CSL_val_costs.to(device)
                    CSL_val_loss = torch.mean(torch.pow(CSL_val_errors, 2))
                    CSL_initial_val_loss = CSL_val_loss
                    if self.use_wandb:
                        wandb.log({
                            "step": epoch,
                            "CSL_val_loss": CSL_val_loss.item()
                        })

                    # initial self-supervised learning (SSL) val loss
                    # right now, just took code from dataio.py and the SSL training loop above; TODO: refactor all this for cleaner modular code
                    CSL_val_states = CSL_val_coords[..., 1:].to(device)
                    CSL_val_dvs = self.dataset.dynamics.io_to_dv(CSL_val_results['model_in'], CSL_val_results['model_out'].squeeze(dim=-1))
                    CSL_val_boundary_values = self.dataset.dynamics.boundary_fn(CSL_val_states)
                    if self.dataset.dynamics.loss_type == 'brat_hjivi':
                        CSL_val_reach_values = self.dataset.dynamics.reach_fn(CSL_val_states)
                        CSL_val_avoid_values = self.dataset.dynamics.avoid_fn(CSL_val_states)
                    CSL_val_dirichlet_masks = CSL_val_coords[:, 0].to(device) == self.dataset.tMin # assumes time unit in dataset (model) is same as real time units
                    if self.dataset.dynamics.loss_type == 'brt_hjivi':
                        SSL_val_losses = loss_fn(CSL_val_states, CSL_val_preds, CSL_val_dvs[..., 0], CSL_val_dvs[..., 1:], CSL_val_boundary_values, CSL_val_dirichlet_masks)
                    elif self.dataset.dynamics.loss_type == 'brat_hjivi':
                        SSL_val_losses = loss_fn(CSL_val_states, CSL_val_preds, CSL_val_dvs[..., 0], CSL_val_dvs[..., 1:], CSL_val_boundary_values, CSL_val_reach_values, CSL_val_avoid_values, CSL_val_dirichlet_masks)
                    else:
                        NotImplementedError
                    SSL_val_loss = SSL_val_losses['diff_constraint_hom'].mean() # I assume there is no dirichlet (boundary) loss here, because I do not ever explicitly generate source samples at tMin (i.e. torch.all(CSL_val_dirichlet_masks == False))
                    if self.use_wandb:
                        wandb.log({
                            "step": epoch,
                            "SSL_val_loss": SSL_val_loss.item()
                        })

                    # CSL training loop
                    for CSL_epoch in tqdm(range(max_CSL_epochs)):
                        CSL_idxs = torch.randperm(num_CSL_samples)
                        for CSL_batch in range(math.ceil(num_CSL_samples/CSL_batch_size)):
                            CSL_batch_idxs = CSL_idxs[CSL_batch*CSL_batch_size:(CSL_batch+1)*CSL_batch_size]
                            CSL_batch_coords = CSL_coords[CSL_batch_idxs]

                            CSL_batch_results = self.model({'coords': self.dataset.dynamics.coord_to_input(CSL_batch_coords.to(device))})
                            CSL_batch_preds = self.dataset.dynamics.io_to_value(CSL_batch_results['model_in'], CSL_batch_results['model_out'].squeeze(dim=-1))
                            CSL_batch_costs = CSL_costs[CSL_batch_idxs].to(device)
                            CSL_batch_errors = CSL_batch_preds - CSL_batch_costs
                            CSL_batch_loss = CSL_loss_weight*torch.mean(torch.pow(CSL_batch_errors, 2))

                            CSL_batch_states = CSL_batch_coords[..., 1:].to(device)
                            CSL_batch_dvs = self.dataset.dynamics.io_to_dv(CSL_batch_results['model_in'], CSL_batch_results['model_out'].squeeze(dim=-1))
                            CSL_batch_boundary_values = self.dataset.dynamics.boundary_fn(CSL_batch_states)
                            if self.dataset.dynamics.loss_type == 'brat_hjivi':
                                CSL_batch_reach_values = self.dataset.dynamics.reach_fn(CSL_batch_states)
                                CSL_batch_avoid_values = self.dataset.dynamics.avoid_fn(CSL_batch_states)
                            CSL_batch_dirichlet_masks = CSL_batch_coords[:, 0].to(device) == self.dataset.tMin # assumes time unit in dataset (model) is same as real time units
                            if self.dataset.dynamics.loss_type == 'brt_hjivi':
                                SSL_batch_losses = loss_fn(CSL_batch_states, CSL_batch_preds, CSL_batch_dvs[..., 0], CSL_batch_dvs[..., 1:], CSL_batch_boundary_values, CSL_batch_dirichlet_masks)
                            elif self.dataset.dynamics.loss_type == 'brat_hjivi':
                                SSL_batch_losses = loss_fn(CSL_batch_states, CSL_batch_preds, CSL_batch_dvs[..., 0], CSL_batch_dvs[..., 1:], CSL_batch_boundary_values, CSL_batch_reach_values, CSL_batch_avoid_values, CSL_batch_dirichlet_masks)
                            else:
                                NotImplementedError
                            SSL_batch_loss = SSL_batch_losses['diff_constraint_hom'].mean() # I assume there is no dirichlet (boundary) loss here, because I do not ever explicitly generate source samples at tMin (i.e. torch.all(CSL_batch_dirichlet_masks == False))

                            CSL_optim.zero_grad()
                            SSL_batch_loss.backward(retain_graph=True)
                            if (not use_lbfgs) and clip_grad: # no adjust_relative_grads, because I assume even with adjustment, the diff_constraint_hom remains unaffected and the only other loss (dirichlet) is zero
                                if isinstance(clip_grad, bool):
                                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
                                else:
                                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=clip_grad)
                            CSL_batch_loss.backward()
                            CSL_optim.step()

                        # evaluate on CSL_val_dataset
                        CSL_val_results = self.model({'coords': self.dataset.dynamics.coord_to_input(CSL_val_coords.to(device))})
                        CSL_val_preds = self.dataset.dynamics.io_to_value(CSL_val_results['model_in'], CSL_val_results['model_out'].squeeze(dim=-1))
                        CSL_val_errors = CSL_val_preds - CSL_val_costs.to(device)
                        CSL_val_loss = torch.mean(torch.pow(CSL_val_errors, 2))

                        CSL_val_states = CSL_val_coords[..., 1:].to(device)
                        CSL_val_dvs = self.dataset.dynamics.io_to_dv(CSL_val_results['model_in'], CSL_val_results['model_out'].squeeze(dim=-1))
                        CSL_val_boundary_values = self.dataset.dynamics.boundary_fn(CSL_val_states)
                        if self.dataset.dynamics.loss_type == 'brat_hjivi':
                            CSL_val_reach_values = self.dataset.dynamics.reach_fn(CSL_val_states)
                            CSL_val_avoid_values = self.dataset.dynamics.avoid_fn(CSL_val_states)
                        CSL_val_dirichlet_masks = CSL_val_coords[:, 0].to(device) == self.dataset.tMin # assumes time unit in dataset (model) is same as real time units
                        if self.dataset.dynamics.loss_type == 'brt_hjivi':
                            SSL_val_losses = loss_fn(CSL_val_states, CSL_val_preds, CSL_val_dvs[..., 0], CSL_val_dvs[..., 1:], CSL_val_boundary_values, CSL_val_dirichlet_masks)
                        elif self.dataset.dynamics.loss_type == 'brat_hjivi':
                            SSL_val_losses = loss_fn(CSL_val_states, CSL_val_preds, CSL_val_dvs[..., 0], CSL_val_dvs[..., 1:], CSL_val_boundary_values, CSL_val_reach_values, CSL_val_avoid_values, CSL_val_dirichlet_masks)
                        else:
                            raise NotImplementedError
                        SSL_val_loss = SSL_val_losses['diff_constraint_hom'].mean() # I assume there is no dirichlet (boundary) loss here, because I do not ever explicitly generate source samples at tMin (i.e. torch.all(CSL_val_dirichlet_masks == False))

                        CSL_val_tMax_results = self.model({'coords': self.dataset.dynamics.coord_to_input(CSL_val_tMax_coords.to(device))})
                        CSL_val_tMax_preds = self.dataset.dynamics.io_to_value(CSL_val_tMax_results['model_in'], CSL_val_tMax_results['model_out'].squeeze(dim=-1))
                        CSL_val_tMax_errors = CSL_val_tMax_preds - CSL_val_tMax_costs.to(device)
                        CSL_val_tMax_loss = torch.mean(torch.pow(CSL_val_tMax_errors, 2))

                        # log CSL losses, recovered_safe_set_fracs
                        if self.dataset.dynamics.set_mode == 'reach':
                            CSL_train_batch_theoretically_recoverable_safe_set_frac = torch.sum(CSL_batch_costs.to(device) < 0) / len(CSL_batch_preds)
                            CSL_train_batch_recovered_safe_set_frac = torch.sum(CSL_batch_preds < torch.min(CSL_batch_preds[CSL_batch_costs.to(device) > 0])) / len(CSL_batch_preds)
                            CSL_val_theoretically_recoverable_safe_set_frac = torch.sum(CSL_val_costs.to(device) < 0) / len(CSL_val_preds)
                            CSL_val_recovered_safe_set_frac = torch.sum(CSL_val_preds < torch.min(CSL_val_preds[CSL_val_costs.to(device) > 0])) / len(CSL_val_preds)
                            CSL_val_tMax_theoretically_recoverable_safe_set_frac = torch.sum(CSL_val_tMax_costs.to(device) < 0) / len(CSL_val_tMax_preds)
                            CSL_val_tMax_recovered_safe_set_frac = torch.sum(CSL_val_tMax_preds < torch.min(CSL_val_tMax_preds[CSL_val_tMax_costs.to(device) > 0])) / len(CSL_val_tMax_preds)
                        elif self.dataset.dynamics.set_mode == 'avoid':
                            CSL_train_batch_theoretically_recoverable_safe_set_frac = torch.sum(CSL_batch_costs.to(device) > 0) / len(CSL_batch_preds)
                            CSL_train_batch_recovered_safe_set_frac = torch.sum(CSL_batch_preds > torch.max(CSL_batch_preds[CSL_batch_costs.to(device) < 0])) / len(CSL_batch_preds)
                            CSL_val_theoretically_recoverable_safe_set_frac = torch.sum(CSL_val_costs.to(device) > 0) / len(CSL_val_preds)
                            CSL_val_recovered_safe_set_frac = torch.sum(CSL_val_preds > torch.max(CSL_val_preds[CSL_val_costs.to(device) < 0])) / len(CSL_val_preds)
                            CSL_val_tMax_theoretically_recoverable_safe_set_frac = torch.sum(CSL_val_tMax_costs.to(device) > 0) / len(CSL_val_tMax_preds)
                            CSL_val_tMax_recovered_safe_set_frac = torch.sum(CSL_val_tMax_preds > torch.max(CSL_val_tMax_preds[CSL_val_tMax_costs.to(device) < 0])) / len(CSL_val_tMax_preds)
                        else:
                            raise NotImplementedError
                        if self.use_wandb:
                            wandb.log({
                                "step": epoch+(CSL_epoch+1)*int(0.5*epochs_til_CSL/max_CSL_epochs),
                                "CSL_train_batch_loss": CSL_batch_loss.item(),
                                "SSL_train_batch_loss": SSL_batch_loss.item(),
                                "CSL_val_loss": CSL_val_loss.item(),
                                "SSL_val_loss": SSL_val_loss.item(),
                                "CSL_val_tMax_loss": CSL_val_tMax_loss.item(),
                                "CSL_train_batch_theoretically_recoverable_safe_set_frac": CSL_train_batch_theoretically_recoverable_safe_set_frac.item(),
                                "CSL_val_theoretically_recoverable_safe_set_frac": CSL_val_theoretically_recoverable_safe_set_frac.item(),
                                "CSL_val_tMax_theoretically_recoverable_safe_set_frac": CSL_val_tMax_theoretically_recoverable_safe_set_frac.item(),
                                "CSL_train_batch_recovered_safe_set_frac": CSL_train_batch_recovered_safe_set_frac.item(),
                                "CSL_val_recovered_safe_set_frac": CSL_val_recovered_safe_set_frac.item(),
                                "CSL_val_tMax_recovered_safe_set_frac": CSL_val_tMax_recovered_safe_set_frac.item(),
                            })

                        if CSL_val_loss < CSL_loss_frac_cutoff*CSL_initial_val_loss:
                            break

                if not (epoch+1) % epochs_til_checkpoint:
                    # Saving the optimizer state is important to produce consistent results
                    checkpoint = {
                        'epoch': epoch+1,
                        'model': self.model.state_dict(),
                        'optimizer': optim.state_dict()}
                    torch.save(checkpoint,
                        os.path.join(checkpoints_dir, 'model_epoch_%04d.pth' % (epoch+1)))
                    np.savetxt(os.path.join(checkpoints_dir, 'train_losses_epoch_%04d.txt' % (epoch+1)),
                        np.array(train_losses))
                    self.validate(
                        device=device, epoch=epoch+1, save_path=os.path.join(checkpoints_dir, 'BRS_validation_plot_epoch_%04d.png' % (epoch+1)),
                        x_resolution = val_x_resolution, y_resolution = val_y_resolution, z_resolution=val_z_resolution, time_resolution=val_time_resolution)

        if was_eval:
            self.model.eval()
            self.model.requires_grad_(False)

    def test(self, device, current_time, last_checkpoint, checkpoint_dt, dt, num_scenarios, num_violations, set_type, control_type, data_step, checkpoint_toload=None):
        was_training = self.model.training
        self.model.eval()
        self.model.requires_grad_(False)

        testing_dir = os.path.join(self.experiment_dir, 'testing_%s' % current_time.strftime('%m_%d_%Y_%H_%M'))
        if os.path.exists(testing_dir):
            overwrite = input("The testing directory %s already exists. Overwrite? (y/n)"%testing_dir)
            if not (overwrite == 'y'):
                print('Exiting.')
                quit()
            shutil.rmtree(testing_dir)
        os.makedirs(testing_dir)

        if checkpoint_toload is None:
            print('running cross-checkpoint testing')

            for i in tqdm(range(sidelen), desc='Checkpoint'):
                self._load_checkpoint(epoch=checkpoints[i])
                raise NotImplementedError

        else:
            print('running specific-checkpoint testing')
            self._load_checkpoint(checkpoint_toload)

            model = self.model
            dataset = self.dataset
            dynamics = dataset.dynamics
            raise NotImplementedError

        if was_training:
            self.model.train()
            self.model.requires_grad_(True)

class DeepReach(Experiment):
    """Experiment subclass that adds no special initialization of its own."""

    def init_special(self):
        """Do nothing; this class takes no arguments beyond ``self`` here."""
        return None

### 7. Run Experiment

In [13]:
# Log in to Weights & Biases; as the captured output shows, this prompts for an
# API key when no stored credentials are found.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33maanirudh[0m ([33maanirudh-n-a[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [14]:
import sys

# A notebook has no real command line, so fake one for the argument parser that
# is built in the next cell. Flags are grouped as (flag, value) pairs.
sys.argv = (
    ["script_name"]  # argv[0] placeholder (ignored by argparse)
    # experiment definition
    + ["--mode", "train"]
    + ["--experiment_class", "DeepReach"]
    + ["--dynamics_class", "PlanarRobot2D"]
    + ["--experiment_name", "brt_goal_025m"]
    + ["--minWith", "target"]
    + ["--goalR", "0.25"]
    + ["--velocity", "1.0"]
    + ["--set_mode", "reach"]
    + ["--num_epochs", "45000"]
    # Weights & Biases run metadata
    + ["--wandb_project", "reachability-experiments"]
    + ["--wandb_entity", "aanirudh-n-a"]
    + ["--wandb_group", "Training"]
    + ["--wandb_name", "brt_goal_025m_run_01"]
)

In [15]:
import argparse
import inspect
from inspect import isclass


def _str2bool(value):
    """Convert a command-line / config-file string into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so ``--use_lbfgs False`` would
    silently enable the option.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept falsy, as it was under ``bool('')``.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean (true/false), got %r' % (value,))


def _cli_type(annotation):
    """Map a signature annotation to a callable usable as argparse ``type=``.

    Unannotated parameters fall back to ``str`` (``inspect.Parameter.empty``
    cannot be called with a string) and ``bool`` is routed through
    ``_str2bool``; every other annotation is used as-is.
    """
    if annotation is inspect.Parameter.empty:
        return str
    if annotation is bool:
        return _str2bool
    return annotation


# The parser is built incrementally: ``parse_known_args`` is called on the
# partially-built parser to read ``--use_wandb`` and ``--mode``, which decide
# which further options get registered before the final ``parse_args``.
p = configargparse.ArgumentParser()

p.add_argument('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
p.add_argument('--mode', type=str, required=True, choices=['all', 'train', 'test'], help="Experiment mode to run (new experiments must choose 'all' or 'train').")

# save/load directory options
p.add_argument('--experiments_dir', type=str, default='./runs', help='Where to save the experiment subdirectory.')
p.add_argument('--experiment_name', type=str, required=True, help='Name of the experiment subdirectory.')
# NOTE: store_false with default=True means wandb is ON unless this flag is passed.
p.add_argument('--use_wandb', default=True, action='store_false', help='Pass this flag to DISABLE wandb logging (wandb is enabled by default).')

use_wandb = p.parse_known_args()[0].use_wandb
if use_wandb:
    p.add_argument('--wandb_project', type=str, required=True, help='wandb project')
    p.add_argument('--wandb_entity', type=str, required=True, help='wandb entity')
    p.add_argument('--wandb_group', type=str, required=True, help='wandb group')
    p.add_argument('--wandb_name', type=str, required=True, help='name of wandb run')

mode = p.parse_known_args()[0].mode

if (mode == 'all') or (mode == 'train'):
    p.add_argument('--seed', type=int, default=0, required=False, help='Seed for the experiment.')

    # Collect the Experiment subclasses defined in this notebook as the valid
    # --experiment_class choices. A snapshot of globals() is iterated so the
    # namespace cannot change size underneath the comprehension.
    experiment_classes_dict = {
        name: clss for name, clss in tuple(globals().items())
        if isclass(clss) and issubclass(clss, Experiment) and clss is not Experiment
    }
    p.add_argument('--experiment_class', type=str, default='DeepReach', choices=experiment_classes_dict.keys(), help='Experiment class to use.')
    # NOTE(review): the experiment class is pinned to DeepReach here (and again
    # where the experiment is instantiated), so --experiment_class is validated
    # but not used to select the class.
    experiment_class = DeepReach
    # Register one required option per init_special parameter of the experiment class.
    experiment_params = {name: param for name, param in inspect.signature(experiment_class.init_special).parameters.items() if name != 'self'}
    for param in experiment_params.keys():
        p.add_argument('--' + param, type=_cli_type(experiment_params[param].annotation), required=True, help='special experiment_class argument')

    # simulation data source options
    p.add_argument('--device', type=str, default='cuda:0', required=False, help='CUDA Device to use.')
    p.add_argument('--numpoints', type=int, default=65000, help='Number of points in simulation data source __getitem__.')
    p.add_argument('--pretrain', action='store_true', default=False, required=False, help='Pretrain dirichlet conditions')
    p.add_argument('--pretrain_iters', type=int, default=2000, required=False, help='Number of pretrain iterations')
    p.add_argument('--tMin', type=float, default=0.0, required=False, help='Start time of the simulation')
    p.add_argument('--tMax', type=float, default=1.0, required=False, help='End time of the simulation')
    p.add_argument('--counter_start', type=int, default=0, required=False, help='Defines the initial time for the curriculum training')
    p.add_argument('--counter_end', type=int, default=-1, required=False, help='Defines the linear step for curriculum training starting from the initial time')
    p.add_argument('--num_src_samples', type=int, default=1000, required=False, help='Number of source samples (initial-time samples) at each time step')
    p.add_argument('--num_target_samples', type=int, default=0, required=False, help='Number of samples inside the target set')

    # model options
    p.add_argument('--model', type=str, default='sine', required=False, choices=['sine', 'tanh', 'sigmoid', 'relu'], help='Type of model to evaluate, default is sine.')
    p.add_argument('--model_mode', type=str, default='mlp', required=False, choices=['mlp', 'rbf', 'pinn'], help='Network mode passed to SingleBVPNet, default is mlp.')
    p.add_argument('--num_hl', type=int, default=3, required=False, help='The number of hidden layers')
    p.add_argument('--num_nl', type=int, default=512, required=False, help='Number of neurons per hidden layer.')
    p.add_argument('--deepreach_model', type=str, default='exact', required=False, choices=['exact', 'diff', 'vanilla'], help='deepreach model')

    # training options
    p.add_argument('--epochs_til_ckpt', type=int, default=1000, help='Number of epochs between saved checkpoints.')
    p.add_argument('--steps_til_summary', type=int, default=100, help='Number of training steps between logged summaries.')
    p.add_argument('--batch_size', type=int, default=1, help='Batch size used during training (irrelevant, since len(dataset) == 1).')
    p.add_argument('--lr', type=float, default=2e-5, help='learning rate. default=2e-5')
    p.add_argument('--num_epochs', type=int, default=100000, help='Number of epochs to train for.')
    p.add_argument('--clip_grad', default=0.0, type=float, help='Clip gradient.')
    # _str2bool instead of bool: with type=bool the string 'False' parsed as True.
    p.add_argument('--use_lbfgs', default=False, type=_str2bool, help='use L-BFGS.')
    p.add_argument('--adj_rel_grads', default=True, type=_str2bool, help='adjust the relative magnitude of the losses')
    p.add_argument('--dirichlet_loss_divisor', default=1.0, required=False, type=float, help='What to divide the dirichlet loss by for loss reweighting')

    # cost-supervised learning (CSL) options
    p.add_argument('--use_CSL', default=False, action='store_true', help='use cost-supervised learning (CSL)')
    p.add_argument('--CSL_lr', type=float, default=2e-5, help='The learning rate used for CSL')
    p.add_argument('--CSL_dt', type=float, default=0.0025, help='The dt used in rolling out trajectories to get cost labels')
    p.add_argument('--epochs_til_CSL', type=int, default=10000, help='Number of epochs between CSL phases')
    p.add_argument('--num_CSL_samples', type=int, default=1000000, help='Number of cost samples in training dataset for CSL phases')
    p.add_argument('--CSL_loss_frac_cutoff', type=float, default=0.1, help='Fraction of initial cost loss on validation dataset to cutoff CSL phases')
    p.add_argument('--max_CSL_epochs', type=int, default=100, help='Max number of CSL epochs per phase')
    p.add_argument('--CSL_loss_weight', type=float, default=1.0, help='weight of cost loss (relative to PDE loss)')
    p.add_argument('--CSL_batch_size', type=int, default=1000, help='Batch size for training in CSL phases')

    # validation (during training) options
    p.add_argument('--val_x_resolution', type=int, default=200, help='x-axis resolution of validation plot during training')
    p.add_argument('--val_y_resolution', type=int, default=200, help='y-axis resolution of validation plot during training')
    p.add_argument('--val_z_resolution', type=int, default=5, help='z-axis resolution of validation plot during training')
    p.add_argument('--val_time_resolution', type=int, default=3, help='time-axis resolution of validation plot during training')

    # loss options
    p.add_argument('--minWith', type=str, required=True, choices=['none', 'zero', 'target'], help='BRS vs BRT computation (typically should be using target for BRT)')

    # The dynamics class is pinned to PlanarRobot2D, the only accepted --dynamics_class choice.
    p.add_argument('--dynamics_class', type=str, required=True, choices=["PlanarRobot2D"], help='Dynamics class to use.')
    dynamics_class = PlanarRobot2D
    # Register one option per constructor parameter of the dynamics class.
    dynamics_params = {
        name: param for name, param in inspect.signature(dynamics_class).parameters.items() if name != 'self'
    }

    for param in dynamics_params.keys():
        if dynamics_params[param].annotation is bool:
            # Boolean constructor parameters are optional and default to False.
            p.add_argument('--' + param, type=_str2bool, default=False, help='special dynamics_class argument')
        else:
            p.add_argument('--' + param, type=_cli_type(dynamics_params[param].annotation), required=True, help='special dynamics_class argument')

if (mode == 'all') or (mode == 'test'):
    p.add_argument('--dt', type=float, default=0.0025, help='The dt used in testing simulations')
    p.add_argument('--checkpoint_toload', type=int, default=None, help="The checkpoint to load for testing (-1 for final training checkpoint, None for cross-checkpoint testing)")
    p.add_argument('--num_scenarios', type=int, default=100000, help='The number of scenarios sampled in scenario optimization for testing')
    p.add_argument('--num_violations', type=int, default=1000, help='The number of violations to sample for in scenario optimization for testing')
    p.add_argument('--control_type', type=str, default='value', choices=['value', 'ttr', 'init_ttr'], help='The controller to use in scenario optimization for testing')
    p.add_argument('--data_step', type=str, default='run_basic_recovery', choices=['plot_violations', 'run_basic_recovery', 'plot_basic_recovery', 'collect_samples', 'train_binner', 'run_binned_recovery', 'plot_binned_recovery', 'plot_cost_function'], help='The data processing step to run')

opt = p.parse_args()

# Start the wandb run (when enabled) and record the parsed options as its config.
if use_wandb:
    wandb.init(
        project = opt.wandb_project,
        entity = opt.wandb_entity,
        group = opt.wandb_group,
        name = opt.wandb_name,
    )
    wandb.config.update(opt)

experiment_dir = os.path.join(opt.experiments_dir, opt.experiment_name)
if (mode == 'all') or (mode == 'train'):
    # create experiment dir
    if os.path.exists(experiment_dir):
        overwrite = input("The experiment directory %s already exists. Overwrite? (y/n)"%experiment_dir)
        if not (overwrite == 'y'):
            print('Exiting.')
            # Close the wandb run opened above so it is not left dangling.
            if use_wandb:
                wandb.finish()
            # Raise instead of calling quit(): inside a Jupyter kernel quit()
            # only requests a shutdown and lets the cell keep running, which
            # would delete the directory the user just declined to overwrite.
            raise SystemExit
        shutil.rmtree(experiment_dir)
    os.makedirs(experiment_dir)
elif mode == 'test':
    # confirm that experiment dir already exists
    if not os.path.exists(experiment_dir):
        raise RuntimeError('Cannot run test mode: experiment directory not found!')

current_time = datetime.now()
# log current config
with open(os.path.join(experiment_dir, 'config_%s.txt' % current_time.strftime('%m_%d_%Y_%H_%M')), 'w') as f:
    for arg, val in vars(opt).items():
        f.write(arg + ' = ' + str(val) + '\n')

if (mode == 'all') or (mode == 'train'):
    # set counter_end appropriately if needed
    if opt.counter_end == -1:
        opt.counter_end = opt.num_epochs

    # log original options
    with open(os.path.join(experiment_dir, 'orig_opt.pickle'), 'wb') as opt_file:
        pickle.dump(opt, opt_file)

# load original experiment settings
# SECURITY NOTE: pickle.load can execute arbitrary code; only point
# --experiments_dir / --experiment_name at directories written by this notebook.
with open(os.path.join(experiment_dir, 'orig_opt.pickle'), 'rb') as opt_file:
    orig_opt = pickle.load(opt_file)

# '--device' is only registered in train/all mode, so in test mode `opt` has no
# `device` attribute; fall back to the device stored with the original options.
run_device = opt.device if hasattr(opt, 'device') else orig_opt.device

# set the experiment seed
torch.manual_seed(orig_opt.seed)
random.seed(orig_opt.seed)
np.random.seed(orig_opt.seed)

# Build the dynamics from the stored options: every constructor parameter of the
# dynamics class is looked up by name on orig_opt.
dynamics_class = PlanarRobot2D #getattr(dynamics, orig_opt.dynamics_class)
dynamics = dynamics_class(**{argname: getattr(orig_opt, argname) for argname in inspect.signature(dynamics_class).parameters.keys() if argname != 'self'})
dynamics.deepreach_model=orig_opt.deepreach_model
dataset = ReachabilityDataset(
    dynamics=dynamics, numpoints=orig_opt.numpoints,
    pretrain=orig_opt.pretrain, pretrain_iters=orig_opt.pretrain_iters,
    tMin=orig_opt.tMin, tMax=orig_opt.tMax,
    counter_start=orig_opt.counter_start, counter_end=orig_opt.counter_end,
    num_src_samples=orig_opt.num_src_samples, num_target_samples=orig_opt.num_target_samples)

model = SingleBVPNet(in_features=dynamics.input_dim, out_features=1, type=orig_opt.model, mode=orig_opt.model_mode,
                             final_layer_factor=1., hidden_features=orig_opt.num_nl, num_hidden_layers=orig_opt.num_hl)
model.to(run_device)

# Instantiate the experiment and pass it any init_special arguments found on orig_opt.
experiment_class = DeepReach #getattr(experiments, orig_opt.experiment_class)
experiment = experiment_class(model=model, dataset=dataset, experiment_dir=experiment_dir, use_wandb=use_wandb)
experiment.init_special(**{argname: getattr(orig_opt, argname) for argname in inspect.signature(experiment_class.init_special).parameters.keys() if argname != 'self'})

if (mode == 'all') or (mode == 'train'):
    # The loss function is selected by the loss type the dynamics object declares.
    if dynamics.loss_type == 'brt_hjivi':
        loss_fn = init_brt_hjivi_loss(dynamics, orig_opt.minWith, orig_opt.dirichlet_loss_divisor)
    elif dynamics.loss_type == 'brat_hjivi':
        loss_fn = init_brat_hjivi_loss(dynamics, orig_opt.minWith, orig_opt.dirichlet_loss_divisor)
    else:
        raise NotImplementedError
    experiment.train(
        device=run_device, batch_size=orig_opt.batch_size, epochs=orig_opt.num_epochs, lr=orig_opt.lr,
        steps_til_summary=orig_opt.steps_til_summary, epochs_til_checkpoint=orig_opt.epochs_til_ckpt,
        loss_fn=loss_fn, clip_grad=orig_opt.clip_grad, use_lbfgs=orig_opt.use_lbfgs, adjust_relative_grads=orig_opt.adj_rel_grads,
        val_x_resolution=orig_opt.val_x_resolution, val_y_resolution=orig_opt.val_y_resolution, val_z_resolution=orig_opt.val_z_resolution, val_time_resolution=orig_opt.val_time_resolution,
        use_CSL=orig_opt.use_CSL, CSL_lr=orig_opt.CSL_lr, CSL_dt=orig_opt.CSL_dt, epochs_til_CSL=orig_opt.epochs_til_CSL, num_CSL_samples=orig_opt.num_CSL_samples, CSL_loss_frac_cutoff=orig_opt.CSL_loss_frac_cutoff, max_CSL_epochs=orig_opt.max_CSL_epochs, CSL_loss_weight=orig_opt.CSL_loss_weight, CSL_batch_size=orig_opt.CSL_batch_size)

if (mode == 'all') or (mode == 'test'):
    experiment.test(
        device=run_device, current_time=current_time,
        last_checkpoint=orig_opt.num_epochs, checkpoint_dt=orig_opt.epochs_til_ckpt,
        checkpoint_toload=opt.checkpoint_toload, dt=opt.dt,
        num_scenarios=opt.num_scenarios, num_violations=opt.num_violations,
        set_type='BRT' if orig_opt.minWith in ['zero', 'target'] else 'BRS', control_type=opt.control_type, data_step=opt.data_step)

SingleBVPNet(
  (net): FCBlock(
    (net): Sequential(
      (0): Sequential(
        (0): BatchLinear(in_features=3, out_features=512, bias=True)
        (1): Sine()
      )
      (1): Sequential(
        (0): BatchLinear(in_features=512, out_features=512, bias=True)
        (1): Sine()
      )
      (2): Sequential(
        (0): BatchLinear(in_features=512, out_features=512, bias=True)
        (1): Sine()
      )
      (3): Sequential(
        (0): BatchLinear(in_features=512, out_features=512, bias=True)
        (1): Sine()
      )
      (4): Sequential(
        (0): BatchLinear(in_features=512, out_features=1, bias=True)
      )
    )
  )
)


  0%|          | 1/45000 [00:01<17:38:44,  1.41s/it]

Epoch 0, Total loss 2817.822021, iteration time 1.388211


  0%|          | 102/45000 [00:22<2:43:27,  4.58it/s]

Epoch 100, Total loss 507.876862, iteration time 0.223210


  0%|          | 202/45000 [00:44<2:52:59,  4.32it/s]

Epoch 200, Total loss 185.247498, iteration time 0.244905


  1%|          | 302/45000 [01:06<2:51:27,  4.34it/s]

Epoch 300, Total loss 298.409607, iteration time 0.235880


  1%|          | 402/45000 [01:30<2:57:58,  4.18it/s]

Epoch 400, Total loss 383.542908, iteration time 0.246486


  1%|          | 502/45000 [01:53<2:54:47,  4.24it/s]

Epoch 500, Total loss 711.284119, iteration time 0.229558


  1%|▏         | 602/45000 [02:15<2:51:32,  4.31it/s]

Epoch 600, Total loss 602.290283, iteration time 0.234004


  2%|▏         | 702/45000 [02:38<2:49:41,  4.35it/s]

Epoch 700, Total loss 735.211182, iteration time 0.236336


  2%|▏         | 802/45000 [03:00<2:50:35,  4.32it/s]

Epoch 800, Total loss 851.230713, iteration time 0.245440


  2%|▏         | 902/45000 [03:23<2:49:38,  4.33it/s]

Epoch 900, Total loss 885.404602, iteration time 0.237290


  2%|▏         | 1002/45000 [03:46<4:35:01,  2.67it/s]

Epoch 1000, Total loss 2202.255859, iteration time 0.096160


  2%|▏         | 1102/45000 [04:09<2:52:32,  4.24it/s]

Epoch 1100, Total loss 1013.609497, iteration time 0.256852


  3%|▎         | 1202/45000 [04:32<2:50:40,  4.28it/s]

Epoch 1200, Total loss 1166.095459, iteration time 0.238873


  3%|▎         | 1302/45000 [04:55<2:48:35,  4.32it/s]

Epoch 1300, Total loss 1409.507080, iteration time 0.232023


  3%|▎         | 1402/45000 [05:18<2:53:49,  4.18it/s]

Epoch 1400, Total loss 1322.895874, iteration time 0.245464


  3%|▎         | 1502/45000 [05:40<2:47:48,  4.32it/s]

Epoch 1500, Total loss 1595.369141, iteration time 0.236311


  4%|▎         | 1602/45000 [06:03<2:51:22,  4.22it/s]

Epoch 1600, Total loss 1407.250732, iteration time 0.248707


  4%|▍         | 1702/45000 [06:26<2:48:33,  4.28it/s]

Epoch 1700, Total loss 1598.338867, iteration time 0.240109


  4%|▍         | 1802/45000 [06:49<2:48:23,  4.28it/s]

Epoch 1800, Total loss 1573.339478, iteration time 0.237429


  4%|▍         | 1902/45000 [07:12<2:48:41,  4.26it/s]

Epoch 1900, Total loss 1565.680298, iteration time 0.237058


  4%|▍         | 2002/45000 [07:35<4:00:32,  2.98it/s]

Epoch 2000, Total loss 2026.873535, iteration time 0.094524


  5%|▍         | 2102/45000 [07:58<2:46:21,  4.30it/s]

Epoch 2100, Total loss 1904.078857, iteration time 0.234542


  5%|▍         | 2202/45000 [08:21<2:46:52,  4.27it/s]

Epoch 2200, Total loss 2016.972412, iteration time 0.235783


  5%|▌         | 2302/45000 [08:44<2:47:41,  4.24it/s]

Epoch 2300, Total loss 1864.442871, iteration time 0.229340


  5%|▌         | 2402/45000 [09:07<2:44:15,  4.32it/s]

Epoch 2400, Total loss 4212.590820, iteration time 0.233178


  6%|▌         | 2502/45000 [09:29<2:48:22,  4.21it/s]

Epoch 2500, Total loss 2308.892090, iteration time 0.254353


  6%|▌         | 2602/45000 [09:52<2:46:49,  4.24it/s]

Epoch 2600, Total loss 2102.637695, iteration time 0.250816


  6%|▌         | 2702/45000 [10:15<2:45:57,  4.25it/s]

Epoch 2700, Total loss 1931.611816, iteration time 0.241401


  6%|▌         | 2802/45000 [10:38<2:50:24,  4.13it/s]

Epoch 2800, Total loss 2337.413818, iteration time 0.232911


  6%|▋         | 2902/45000 [11:01<2:43:34,  4.29it/s]

Epoch 2900, Total loss 2459.489990, iteration time 0.236241


  7%|▋         | 3002/45000 [11:25<3:56:48,  2.96it/s]

Epoch 3000, Total loss 1823.517456, iteration time 0.092428


  7%|▋         | 3102/45000 [11:48<2:43:47,  4.26it/s]

Epoch 3100, Total loss 2808.748047, iteration time 0.252592


  7%|▋         | 3202/45000 [12:11<2:45:22,  4.21it/s]

Epoch 3200, Total loss 3031.921631, iteration time 0.257138


  7%|▋         | 3302/45000 [12:34<2:42:20,  4.28it/s]

Epoch 3300, Total loss 2615.085938, iteration time 0.241251


  8%|▊         | 3402/45000 [12:57<2:42:24,  4.27it/s]

Epoch 3400, Total loss 2117.352295, iteration time 0.241671


  8%|▊         | 3502/45000 [13:20<2:44:30,  4.20it/s]

Epoch 3500, Total loss 2350.847900, iteration time 0.245438


  8%|▊         | 3602/45000 [13:43<2:42:54,  4.24it/s]

Epoch 3600, Total loss 1544.615601, iteration time 0.244481


  8%|▊         | 3702/45000 [14:07<2:43:45,  4.20it/s]

Epoch 3700, Total loss 1368.176270, iteration time 0.244787


  8%|▊         | 3802/45000 [14:30<2:46:04,  4.13it/s]

Epoch 3800, Total loss 2216.078613, iteration time 0.249014


  9%|▊         | 3902/45000 [14:54<2:43:52,  4.18it/s]

Epoch 3900, Total loss 987.698547, iteration time 0.244878


  9%|▉         | 4002/45000 [15:18<3:56:30,  2.89it/s]

Epoch 4000, Total loss 826.748291, iteration time 0.093683


  9%|▉         | 4102/45000 [15:41<2:43:08,  4.18it/s]

Epoch 4100, Total loss 873.955078, iteration time 0.246229


  9%|▉         | 4202/45000 [16:05<2:46:08,  4.09it/s]

Epoch 4200, Total loss 683.044739, iteration time 0.247651


 10%|▉         | 4302/45000 [16:29<2:43:19,  4.15it/s]

Epoch 4300, Total loss 1033.001221, iteration time 0.251342


 10%|▉         | 4402/45000 [16:52<2:41:20,  4.19it/s]

Epoch 4400, Total loss 3393.323730, iteration time 0.243174


 10%|█         | 4502/45000 [17:16<2:42:02,  4.17it/s]

Epoch 4500, Total loss 921.447632, iteration time 0.245787


 10%|█         | 4602/45000 [17:39<2:43:35,  4.12it/s]

Epoch 4600, Total loss 548.739136, iteration time 0.245742


 10%|█         | 4702/45000 [18:03<2:40:06,  4.19it/s]

Epoch 4700, Total loss 1090.715698, iteration time 0.246055


 11%|█         | 4802/45000 [18:27<2:40:45,  4.17it/s]

Epoch 4800, Total loss 917.823730, iteration time 0.248492


 11%|█         | 4902/45000 [18:50<2:40:18,  4.17it/s]

Epoch 4900, Total loss 342.309692, iteration time 0.251320


 11%|█         | 5002/45000 [19:14<4:06:49,  2.70it/s]

Epoch 5000, Total loss 463.711426, iteration time 0.101013


 11%|█▏        | 5102/45000 [19:38<2:38:09,  4.20it/s]

Epoch 5100, Total loss 799.005859, iteration time 0.241951


 12%|█▏        | 5202/45000 [20:01<2:37:37,  4.21it/s]

Epoch 5200, Total loss 911.266602, iteration time 0.243294


 12%|█▏        | 5302/45000 [20:25<2:40:34,  4.12it/s]

Epoch 5300, Total loss 1170.175537, iteration time 0.261164


 12%|█▏        | 5402/45000 [20:49<2:37:11,  4.20it/s]

Epoch 5400, Total loss 1037.520020, iteration time 0.245930


 12%|█▏        | 5502/45000 [21:12<2:37:10,  4.19it/s]

Epoch 5500, Total loss 844.266479, iteration time 0.252408


 12%|█▏        | 5602/45000 [21:36<2:38:14,  4.15it/s]

Epoch 5600, Total loss 1098.082031, iteration time 0.247720


 13%|█▎        | 5702/45000 [22:00<2:38:40,  4.13it/s]

Epoch 5700, Total loss 528.162109, iteration time 0.248439


 13%|█▎        | 5802/45000 [22:23<2:35:20,  4.21it/s]

Epoch 5800, Total loss 332.950043, iteration time 0.244660


 13%|█▎        | 5902/45000 [22:47<2:34:57,  4.21it/s]

Epoch 5900, Total loss 652.235535, iteration time 0.244873


 13%|█▎        | 6002/45000 [23:11<3:41:27,  2.94it/s]

Epoch 6000, Total loss 530.814453, iteration time 0.095436


 14%|█▎        | 6102/45000 [23:34<2:35:49,  4.16it/s]

Epoch 6100, Total loss 747.520630, iteration time 0.244804


 14%|█▍        | 6202/45000 [23:58<2:35:14,  4.17it/s]

Epoch 6200, Total loss 796.821777, iteration time 0.247775


 14%|█▍        | 6302/45000 [24:21<2:33:12,  4.21it/s]

Epoch 6300, Total loss 487.250427, iteration time 0.243612


 14%|█▍        | 6402/45000 [24:45<2:33:34,  4.19it/s]

Epoch 6400, Total loss 449.239838, iteration time 0.245441


 14%|█▍        | 6502/45000 [25:09<2:34:59,  4.14it/s]

Epoch 6500, Total loss 374.452881, iteration time 0.250111


 15%|█▍        | 6602/45000 [25:32<2:33:39,  4.16it/s]

Epoch 6600, Total loss 733.980286, iteration time 0.244485


 15%|█▍        | 6702/45000 [25:56<2:32:08,  4.20it/s]

Epoch 6700, Total loss 556.958130, iteration time 0.243507


 15%|█▌        | 6802/45000 [26:19<2:31:38,  4.20it/s]

Epoch 6800, Total loss 772.640869, iteration time 0.242119


 15%|█▌        | 6902/45000 [26:43<2:32:50,  4.15it/s]

Epoch 6900, Total loss 573.966431, iteration time 0.247012


 16%|█▌        | 7002/45000 [27:07<3:34:27,  2.95it/s]

Epoch 7000, Total loss 997.005737, iteration time 0.093116


 16%|█▌        | 7102/45000 [27:31<2:30:53,  4.19it/s]

Epoch 7100, Total loss 941.904480, iteration time 0.247044


 16%|█▌        | 7202/45000 [27:54<2:33:48,  4.10it/s]

Epoch 7200, Total loss 606.532227, iteration time 0.250294


 16%|█▌        | 7302/45000 [28:18<2:31:20,  4.15it/s]

Epoch 7300, Total loss 731.556030, iteration time 0.247191


 16%|█▋        | 7402/45000 [28:42<2:32:01,  4.12it/s]

Epoch 7400, Total loss 551.788208, iteration time 0.253848


 17%|█▋        | 7502/45000 [29:05<2:29:42,  4.17it/s]

Epoch 7500, Total loss 715.811829, iteration time 0.249096


 17%|█▋        | 7602/45000 [29:29<2:30:44,  4.13it/s]

Epoch 7600, Total loss 596.209595, iteration time 0.248981


 17%|█▋        | 7702/45000 [29:53<2:28:56,  4.17it/s]

Epoch 7700, Total loss 802.969421, iteration time 0.245103


 17%|█▋        | 7802/45000 [30:16<2:28:58,  4.16it/s]

Epoch 7800, Total loss 664.283936, iteration time 0.247367


 18%|█▊        | 7902/45000 [30:40<2:30:52,  4.10it/s]

Epoch 7900, Total loss 631.057739, iteration time 0.252086


 18%|█▊        | 8002/45000 [31:04<3:33:34,  2.89it/s]

Epoch 8000, Total loss 817.972046, iteration time 0.096501


 18%|█▊        | 8102/45000 [31:28<2:28:21,  4.15it/s]

Epoch 8100, Total loss 779.764099, iteration time 0.247394


 18%|█▊        | 8202/45000 [31:51<2:26:41,  4.18it/s]

Epoch 8200, Total loss 765.567383, iteration time 0.245070


 18%|█▊        | 8302/45000 [32:15<2:27:42,  4.14it/s]

Epoch 8300, Total loss 791.049744, iteration time 0.247134


 19%|█▊        | 8402/45000 [32:39<2:27:18,  4.14it/s]

Epoch 8400, Total loss 520.236633, iteration time 0.250886


 19%|█▉        | 8502/45000 [33:02<2:26:23,  4.16it/s]

Epoch 8500, Total loss 699.758118, iteration time 0.245311


 19%|█▉        | 8602/45000 [33:26<2:24:57,  4.18it/s]

Epoch 8600, Total loss 681.669678, iteration time 0.248459


 19%|█▉        | 8702/45000 [33:50<2:26:36,  4.13it/s]

Epoch 8700, Total loss 673.999756, iteration time 0.241456


 20%|█▉        | 8802/45000 [34:13<2:24:55,  4.16it/s]

Epoch 8800, Total loss 764.884460, iteration time 0.248406


 20%|█▉        | 8902/45000 [34:37<2:24:34,  4.16it/s]

Epoch 8900, Total loss 673.497803, iteration time 0.247563


 20%|██        | 9002/45000 [35:01<3:26:16,  2.91it/s]

Epoch 9000, Total loss 658.550537, iteration time 0.093107


 20%|██        | 9102/45000 [35:25<2:24:57,  4.13it/s]

Epoch 9100, Total loss 688.969116, iteration time 0.246886


 20%|██        | 9202/45000 [35:48<2:22:42,  4.18it/s]

Epoch 9200, Total loss 652.657288, iteration time 0.246350


 21%|██        | 9302/45000 [36:12<2:24:20,  4.12it/s]

Epoch 9300, Total loss 742.021851, iteration time 0.249128


 21%|██        | 9402/45000 [36:36<2:21:22,  4.20it/s]

Epoch 9400, Total loss 669.725403, iteration time 0.242592


 21%|██        | 9502/45000 [36:59<2:24:58,  4.08it/s]

Epoch 9500, Total loss 637.526306, iteration time 0.249325


 21%|██▏       | 9602/45000 [37:23<2:23:56,  4.10it/s]

Epoch 9600, Total loss 655.939575, iteration time 0.264344


 22%|██▏       | 9702/45000 [37:47<2:21:47,  4.15it/s]

Epoch 9700, Total loss 824.938721, iteration time 0.248415


 22%|██▏       | 9802/45000 [38:10<2:22:13,  4.12it/s]

Epoch 9800, Total loss 645.577759, iteration time 0.247441


 22%|██▏       | 9902/45000 [38:34<2:20:13,  4.17it/s]

Epoch 9900, Total loss 692.014404, iteration time 0.248538


 22%|██▏       | 10002/45000 [38:58<3:22:04,  2.89it/s]

Epoch 10000, Total loss 679.061462, iteration time 0.099638


 22%|██▏       | 10102/45000 [39:22<2:20:19,  4.14it/s]

Epoch 10100, Total loss 711.914856, iteration time 0.246793


 23%|██▎       | 10202/45000 [39:46<2:20:58,  4.11it/s]

Epoch 10200, Total loss 621.690308, iteration time 0.248328


 23%|██▎       | 10302/45000 [40:09<2:19:35,  4.14it/s]

Epoch 10300, Total loss 622.829956, iteration time 0.243291


 23%|██▎       | 10402/45000 [40:33<2:17:44,  4.19it/s]

Epoch 10400, Total loss 631.420410, iteration time 0.243499


 23%|██▎       | 10502/45000 [40:57<2:18:34,  4.15it/s]

Epoch 10500, Total loss 716.406799, iteration time 0.248992


 24%|██▎       | 10602/45000 [41:20<2:18:22,  4.14it/s]

Epoch 10600, Total loss 705.950134, iteration time 0.251254


 24%|██▍       | 10702/45000 [41:44<2:17:38,  4.15it/s]

Epoch 10700, Total loss 607.900269, iteration time 0.247649


 24%|██▍       | 10802/45000 [42:08<2:16:39,  4.17it/s]

Epoch 10800, Total loss 560.507080, iteration time 0.247642


 24%|██▍       | 10902/45000 [42:31<2:15:34,  4.19it/s]

Epoch 10900, Total loss 247.174072, iteration time 0.244803


 24%|██▍       | 11002/45000 [42:56<4:43:08,  2.00it/s]

Epoch 11000, Total loss 442.213989, iteration time 0.099860


 25%|██▍       | 11102/45000 [43:20<2:16:01,  4.15it/s]

Epoch 11100, Total loss 183.924011, iteration time 0.248208


 25%|██▍       | 11202/45000 [43:43<2:14:47,  4.18it/s]

Epoch 11200, Total loss 300.887634, iteration time 0.244558


 25%|██▌       | 11302/45000 [44:07<2:13:50,  4.20it/s]

Epoch 11300, Total loss 292.148010, iteration time 0.241862


 25%|██▌       | 11402/45000 [44:31<2:14:22,  4.17it/s]

Epoch 11400, Total loss 2019.467407, iteration time 0.254914


 26%|██▌       | 11502/45000 [44:54<2:13:13,  4.19it/s]

Epoch 11500, Total loss 331.886475, iteration time 0.248223


 26%|██▌       | 11602/45000 [45:18<2:13:29,  4.17it/s]

Epoch 11600, Total loss 352.261047, iteration time 0.245939


 26%|██▌       | 11702/45000 [45:41<2:14:11,  4.14it/s]

Epoch 11700, Total loss 397.842468, iteration time 0.244022


 26%|██▌       | 11802/45000 [46:05<2:11:13,  4.22it/s]

Epoch 11800, Total loss 793.717896, iteration time 0.241121


 26%|██▋       | 11902/45000 [46:28<2:12:40,  4.16it/s]

Epoch 11900, Total loss 684.421021, iteration time 0.242883


 27%|██▋       | 12002/45000 [46:52<3:09:30,  2.90it/s]

Epoch 12000, Total loss 630.299194, iteration time 0.105369


 27%|██▋       | 12102/45000 [47:16<2:13:17,  4.11it/s]

Epoch 12100, Total loss 717.731995, iteration time 0.247389


 27%|██▋       | 12202/45000 [47:40<2:14:15,  4.07it/s]

Epoch 12200, Total loss 611.852051, iteration time 0.254199


 27%|██▋       | 12302/45000 [48:03<2:09:46,  4.20it/s]

Epoch 12300, Total loss 621.479980, iteration time 0.245750


 28%|██▊       | 12402/45000 [48:27<2:09:46,  4.19it/s]

Epoch 12400, Total loss 688.050781, iteration time 0.233485


 28%|██▊       | 12502/45000 [48:50<2:12:09,  4.10it/s]

Epoch 12500, Total loss 584.067505, iteration time 0.254250


 28%|██▊       | 12602/45000 [49:14<2:11:34,  4.10it/s]

Epoch 12600, Total loss 619.130249, iteration time 0.256915


 28%|██▊       | 12702/45000 [49:38<2:08:45,  4.18it/s]

Epoch 12700, Total loss 603.801270, iteration time 0.246539


 28%|██▊       | 12802/45000 [50:01<2:10:48,  4.10it/s]

Epoch 12800, Total loss 561.277588, iteration time 0.248618


 29%|██▊       | 12902/45000 [50:25<2:09:00,  4.15it/s]

Epoch 12900, Total loss 605.083130, iteration time 0.247309


 29%|██▉       | 13002/45000 [50:49<3:03:51,  2.90it/s]

Epoch 13000, Total loss 622.548462, iteration time 0.092146


 29%|██▉       | 13102/45000 [51:13<2:06:48,  4.19it/s]

Epoch 13100, Total loss 569.824951, iteration time 0.242949


 29%|██▉       | 13202/45000 [51:37<2:08:18,  4.13it/s]

Epoch 13200, Total loss 450.020966, iteration time 0.246177


 30%|██▉       | 13302/45000 [52:00<2:05:40,  4.20it/s]

Epoch 13300, Total loss 646.838257, iteration time 0.241806


 30%|██▉       | 13402/45000 [52:24<2:07:00,  4.15it/s]

Epoch 13400, Total loss 599.326355, iteration time 0.246715


 30%|███       | 13502/45000 [52:48<2:06:38,  4.15it/s]

Epoch 13500, Total loss 650.514648, iteration time 0.247274


 30%|███       | 13602/45000 [53:11<2:04:51,  4.19it/s]

Epoch 13600, Total loss 584.735107, iteration time 0.240973


 30%|███       | 13702/45000 [53:35<2:04:49,  4.18it/s]

Epoch 13700, Total loss 633.565430, iteration time 0.242883


 31%|███       | 13802/45000 [53:58<2:05:09,  4.15it/s]

Epoch 13800, Total loss 641.470886, iteration time 0.246871


 31%|███       | 13902/45000 [54:22<2:05:35,  4.13it/s]

Epoch 13900, Total loss 656.487671, iteration time 0.249929


 31%|███       | 14002/45000 [54:46<2:59:18,  2.88it/s]

Epoch 14000, Total loss 651.697754, iteration time 0.095711


 31%|███▏      | 14102/45000 [55:10<2:04:41,  4.13it/s]

Epoch 14100, Total loss 564.052734, iteration time 0.246474


 32%|███▏      | 14202/45000 [55:34<2:03:22,  4.16it/s]

Epoch 14200, Total loss 663.197937, iteration time 0.250516


 32%|███▏      | 14302/45000 [55:57<2:04:59,  4.09it/s]

Epoch 14300, Total loss 613.274902, iteration time 0.252460


 32%|███▏      | 14402/45000 [56:21<2:02:32,  4.16it/s]

Epoch 14400, Total loss 562.850952, iteration time 0.246956


 32%|███▏      | 14502/45000 [56:45<2:01:04,  4.20it/s]

Epoch 14500, Total loss 535.503113, iteration time 0.249681


 32%|███▏      | 14602/45000 [57:08<2:03:18,  4.11it/s]

Epoch 14600, Total loss 814.091980, iteration time 0.243283


 33%|███▎      | 14702/45000 [57:32<2:01:51,  4.14it/s]

Epoch 14700, Total loss 463.645508, iteration time 0.251951


 33%|███▎      | 14802/45000 [57:56<2:00:31,  4.18it/s]

Epoch 14800, Total loss 586.682129, iteration time 0.249735


 33%|███▎      | 14902/45000 [58:19<2:00:23,  4.17it/s]

Epoch 14900, Total loss 578.967041, iteration time 0.246873


 33%|███▎      | 15002/45000 [58:44<3:20:27,  2.49it/s]

Epoch 15000, Total loss 668.875671, iteration time 0.107566


 34%|███▎      | 15102/45000 [59:07<2:00:14,  4.14it/s]

Epoch 15100, Total loss 596.176758, iteration time 0.256258


 34%|███▍      | 15202/45000 [59:31<1:58:54,  4.18it/s]

Epoch 15200, Total loss 586.781555, iteration time 0.249658


 34%|███▍      | 15302/45000 [59:55<1:57:56,  4.20it/s]

Epoch 15300, Total loss 581.680359, iteration time 0.243531


 34%|███▍      | 15402/45000 [1:00:18<1:59:38,  4.12it/s]

Epoch 15400, Total loss 571.798645, iteration time 0.254205


 34%|███▍      | 15502/45000 [1:00:42<1:57:16,  4.19it/s]

Epoch 15500, Total loss 616.792480, iteration time 0.242976


 35%|███▍      | 15602/45000 [1:01:06<1:58:30,  4.13it/s]

Epoch 15600, Total loss 479.115540, iteration time 0.246304


 35%|███▍      | 15702/45000 [1:01:29<1:58:21,  4.13it/s]

Epoch 15700, Total loss 683.122803, iteration time 0.249635


 35%|███▌      | 15802/45000 [1:01:53<1:56:40,  4.17it/s]

Epoch 15800, Total loss 503.144012, iteration time 0.253201


 35%|███▌      | 15902/45000 [1:02:17<1:57:10,  4.14it/s]

Epoch 15900, Total loss 499.883606, iteration time 0.246731


 36%|███▌      | 16002/45000 [1:02:41<2:46:21,  2.91it/s]

Epoch 16000, Total loss 494.253937, iteration time 0.094003


 36%|███▌      | 16102/45000 [1:03:05<1:55:44,  4.16it/s]

Epoch 16100, Total loss 496.409546, iteration time 0.246126


 36%|███▌      | 16202/45000 [1:03:28<1:56:28,  4.12it/s]

Epoch 16200, Total loss 578.552856, iteration time 0.247154


 36%|███▌      | 16302/45000 [1:03:52<1:55:10,  4.15it/s]

Epoch 16300, Total loss 621.868286, iteration time 0.248517


 36%|███▋      | 16402/45000 [1:04:16<1:54:02,  4.18it/s]

Epoch 16400, Total loss 659.369629, iteration time 0.244613


 37%|███▋      | 16502/45000 [1:04:39<1:55:26,  4.11it/s]

Epoch 16500, Total loss 582.930420, iteration time 0.252254


 37%|███▋      | 16602/45000 [1:05:03<1:53:49,  4.16it/s]

Epoch 16600, Total loss 716.932495, iteration time 0.248102


 37%|███▋      | 16702/45000 [1:05:27<1:52:59,  4.17it/s]

Epoch 16700, Total loss 544.915405, iteration time 0.242747


 37%|███▋      | 16802/45000 [1:05:50<1:53:05,  4.16it/s]

Epoch 16800, Total loss 599.034180, iteration time 0.254563


 38%|███▊      | 16902/45000 [1:06:14<1:53:52,  4.11it/s]

Epoch 16900, Total loss 203.700958, iteration time 0.257776


 38%|███▊      | 17002/45000 [1:06:38<2:41:19,  2.89it/s]

Epoch 17000, Total loss 192.980835, iteration time 0.095970


 38%|███▊      | 17102/45000 [1:07:02<1:51:40,  4.16it/s]

Epoch 17100, Total loss 155.259003, iteration time 0.246319


 38%|███▊      | 17202/45000 [1:07:25<1:51:37,  4.15it/s]

Epoch 17200, Total loss 225.070679, iteration time 0.256849


 38%|███▊      | 17302/45000 [1:07:49<1:52:14,  4.11it/s]

Epoch 17300, Total loss 235.705856, iteration time 0.253724


 39%|███▊      | 17402/45000 [1:08:13<1:50:32,  4.16it/s]

Epoch 17400, Total loss 328.401917, iteration time 0.248754


 39%|███▉      | 17502/45000 [1:08:36<1:51:19,  4.12it/s]

Epoch 17500, Total loss 508.364471, iteration time 0.252879


 39%|███▉      | 17602/45000 [1:09:00<1:49:11,  4.18it/s]

Epoch 17600, Total loss 483.108490, iteration time 0.242688


 39%|███▉      | 17702/45000 [1:09:24<1:51:18,  4.09it/s]

Epoch 17700, Total loss 487.615875, iteration time 0.259940


 40%|███▉      | 17802/45000 [1:09:47<1:49:28,  4.14it/s]

Epoch 17800, Total loss 457.366455, iteration time 0.247124


 40%|███▉      | 17902/45000 [1:10:11<1:47:47,  4.19it/s]

Epoch 17900, Total loss 496.132446, iteration time 0.243291


 40%|████      | 18002/45000 [1:10:35<2:35:14,  2.90it/s]

Epoch 18000, Total loss 481.852051, iteration time 0.092623


 40%|████      | 18102/45000 [1:10:59<1:48:29,  4.13it/s]

Epoch 18100, Total loss 475.728027, iteration time 0.237637


 40%|████      | 18202/45000 [1:11:22<1:48:34,  4.11it/s]

Epoch 18200, Total loss 505.100159, iteration time 0.252031


 41%|████      | 18302/45000 [1:11:46<1:46:18,  4.19it/s]

Epoch 18300, Total loss 470.606384, iteration time 0.245239


 41%|████      | 18402/45000 [1:12:10<1:45:43,  4.19it/s]

Epoch 18400, Total loss 514.543091, iteration time 0.246127


 41%|████      | 18502/45000 [1:12:33<1:47:13,  4.12it/s]

Epoch 18500, Total loss 524.690674, iteration time 0.242842


 41%|████▏     | 18602/45000 [1:12:57<1:46:11,  4.14it/s]

Epoch 18600, Total loss 494.978027, iteration time 0.255145


 42%|████▏     | 18702/45000 [1:13:21<1:45:29,  4.15it/s]

Epoch 18700, Total loss 475.444031, iteration time 0.245775


 42%|████▏     | 18802/45000 [1:13:44<1:44:39,  4.17it/s]

Epoch 18800, Total loss 593.666565, iteration time 0.249936


 42%|████▏     | 18902/45000 [1:14:08<1:45:57,  4.11it/s]

Epoch 18900, Total loss 363.055542, iteration time 0.257469


 42%|████▏     | 19002/45000 [1:14:32<2:34:45,  2.80it/s]

Epoch 19000, Total loss 558.592163, iteration time 0.104513


 42%|████▏     | 19102/45000 [1:14:56<1:42:50,  4.20it/s]

Epoch 19100, Total loss 234.927429, iteration time 0.246388


 43%|████▎     | 19202/45000 [1:15:19<1:42:37,  4.19it/s]

Epoch 19200, Total loss 181.253769, iteration time 0.244411


 43%|████▎     | 19302/45000 [1:15:43<1:43:28,  4.14it/s]

Epoch 19300, Total loss 403.786926, iteration time 0.261271


 43%|████▎     | 19402/45000 [1:16:07<1:43:41,  4.11it/s]

Epoch 19400, Total loss 551.350342, iteration time 0.249036


 43%|████▎     | 19502/45000 [1:16:30<1:42:40,  4.14it/s]

Epoch 19500, Total loss 689.323242, iteration time 0.248596


 44%|████▎     | 19602/45000 [1:16:54<1:40:46,  4.20it/s]

Epoch 19600, Total loss 515.068726, iteration time 0.247313


 44%|████▍     | 19702/45000 [1:17:17<1:41:52,  4.14it/s]

Epoch 19700, Total loss 501.433105, iteration time 0.247279


 44%|████▍     | 19802/45000 [1:17:41<1:42:41,  4.09it/s]

Epoch 19800, Total loss 579.688904, iteration time 0.249735


 44%|████▍     | 19902/45000 [1:18:05<1:41:10,  4.13it/s]

Epoch 19900, Total loss 483.616913, iteration time 0.258048


 44%|████▍     | 20002/45000 [1:18:29<2:24:45,  2.88it/s]

Epoch 20000, Total loss 522.194214, iteration time 0.093621


 45%|████▍     | 20102/45000 [1:18:53<1:38:35,  4.21it/s]

Epoch 20100, Total loss 529.729553, iteration time 0.243263


 45%|████▍     | 20202/45000 [1:19:16<1:40:01,  4.13it/s]

Epoch 20200, Total loss 526.812012, iteration time 0.245445


 45%|████▌     | 20302/45000 [1:19:40<1:39:20,  4.14it/s]

Epoch 20300, Total loss 481.207825, iteration time 0.249904


 45%|████▌     | 20402/45000 [1:20:03<1:38:24,  4.17it/s]

Epoch 20400, Total loss 503.411926, iteration time 0.243479


 46%|████▌     | 20502/45000 [1:20:27<1:38:31,  4.14it/s]

Epoch 20500, Total loss 570.267334, iteration time 0.252218


 46%|████▌     | 20602/45000 [1:20:50<1:38:13,  4.14it/s]

Epoch 20600, Total loss 562.867554, iteration time 0.243407


 46%|████▌     | 20702/45000 [1:21:14<1:37:07,  4.17it/s]

Epoch 20700, Total loss 481.237946, iteration time 0.247145


 46%|████▌     | 20802/45000 [1:21:38<1:37:04,  4.15it/s]

Epoch 20800, Total loss 490.613342, iteration time 0.246947


 46%|████▋     | 20902/45000 [1:22:02<1:35:52,  4.19it/s]

Epoch 20900, Total loss 507.805481, iteration time 0.249704


 47%|████▋     | 21002/45000 [1:22:26<2:38:05,  2.53it/s]

Epoch 21000, Total loss 873.780334, iteration time 0.101306


 47%|████▋     | 21102/45000 [1:22:49<1:35:03,  4.19it/s]

Epoch 21100, Total loss 486.255371, iteration time 0.244384


 47%|████▋     | 21202/45000 [1:23:13<1:35:06,  4.17it/s]

Epoch 21200, Total loss 449.067322, iteration time 0.248771


 47%|████▋     | 21302/45000 [1:23:37<1:34:50,  4.16it/s]

Epoch 21300, Total loss 453.397797, iteration time 0.233396


 48%|████▊     | 21402/45000 [1:24:00<1:36:29,  4.08it/s]

Epoch 21400, Total loss 453.796631, iteration time 0.264968


 48%|████▊     | 21502/45000 [1:24:24<1:34:02,  4.16it/s]

Epoch 21500, Total loss 478.295288, iteration time 0.246232


 48%|████▊     | 21602/45000 [1:24:48<1:34:25,  4.13it/s]

Epoch 21600, Total loss 463.898438, iteration time 0.246248


 48%|████▊     | 21702/45000 [1:25:11<1:33:12,  4.17it/s]

Epoch 21700, Total loss 466.306030, iteration time 0.241885


 48%|████▊     | 21802/45000 [1:25:35<1:34:16,  4.10it/s]

Epoch 21800, Total loss 429.825348, iteration time 0.254178


 49%|████▊     | 21902/45000 [1:25:59<1:32:41,  4.15it/s]

Epoch 21900, Total loss 410.986877, iteration time 0.247926


 49%|████▉     | 22002/45000 [1:26:23<2:13:01,  2.88it/s]

Epoch 22000, Total loss 447.205383, iteration time 0.093291


 49%|████▉     | 22102/45000 [1:26:47<1:31:39,  4.16it/s]

Epoch 22100, Total loss 444.178711, iteration time 0.247706


 49%|████▉     | 22202/45000 [1:27:10<1:32:27,  4.11it/s]

Epoch 22200, Total loss 948.309814, iteration time 0.253468


 50%|████▉     | 22302/45000 [1:27:34<1:31:11,  4.15it/s]

Epoch 22300, Total loss 168.726562, iteration time 0.249382


 50%|████▉     | 22402/45000 [1:27:58<1:30:46,  4.15it/s]

Epoch 22400, Total loss 238.853317, iteration time 0.248334


 50%|█████     | 22502/45000 [1:28:21<1:30:00,  4.17it/s]

Epoch 22500, Total loss 222.483261, iteration time 0.243845


 50%|█████     | 22602/45000 [1:28:45<1:29:42,  4.16it/s]

Epoch 22600, Total loss 200.720245, iteration time 0.248298


 50%|█████     | 22702/45000 [1:29:09<1:30:21,  4.11it/s]

Epoch 22700, Total loss 577.773438, iteration time 0.247090


 51%|█████     | 22802/45000 [1:29:32<1:28:51,  4.16it/s]

Epoch 22800, Total loss 527.036560, iteration time 0.246248


 51%|█████     | 22902/45000 [1:29:56<1:28:33,  4.16it/s]

Epoch 22900, Total loss 522.438110, iteration time 0.247957


 51%|█████     | 23002/45000 [1:30:20<2:18:06,  2.65it/s]

Epoch 23000, Total loss 508.199066, iteration time 0.098222


 51%|█████▏    | 23102/45000 [1:30:44<1:28:04,  4.14it/s]

Epoch 23100, Total loss 523.229004, iteration time 0.243368


 52%|█████▏    | 23202/45000 [1:31:07<1:27:43,  4.14it/s]

Epoch 23200, Total loss 538.641602, iteration time 0.250944


 52%|█████▏    | 23302/45000 [1:31:31<1:26:46,  4.17it/s]

Epoch 23300, Total loss 554.481079, iteration time 0.244087


 52%|█████▏    | 23402/45000 [1:31:55<1:28:03,  4.09it/s]

Epoch 23400, Total loss 492.602173, iteration time 0.255279


 52%|█████▏    | 23502/45000 [1:32:18<1:25:26,  4.19it/s]

Epoch 23500, Total loss 513.410400, iteration time 0.244305


 52%|█████▏    | 23602/45000 [1:32:42<1:25:26,  4.17it/s]

Epoch 23600, Total loss 545.403992, iteration time 0.244679


 53%|█████▎    | 23702/45000 [1:33:06<1:25:30,  4.15it/s]

Epoch 23700, Total loss 518.591614, iteration time 0.246735


 53%|█████▎    | 23802/45000 [1:33:29<1:24:48,  4.17it/s]

Epoch 23800, Total loss 526.743774, iteration time 0.245797


 53%|█████▎    | 23902/45000 [1:33:53<1:24:19,  4.17it/s]

Epoch 23900, Total loss 478.955444, iteration time 0.242659


 53%|█████▎    | 24002/45000 [1:34:17<2:02:43,  2.85it/s]

Epoch 24000, Total loss 490.560272, iteration time 0.099997


 54%|█████▎    | 24102/45000 [1:34:41<1:25:00,  4.10it/s]

Epoch 24100, Total loss 474.649475, iteration time 0.252955


 54%|█████▍    | 24202/45000 [1:35:04<1:22:36,  4.20it/s]

Epoch 24200, Total loss 503.550629, iteration time 0.242425


 54%|█████▍    | 24302/45000 [1:35:28<1:23:03,  4.15it/s]

Epoch 24300, Total loss 503.586304, iteration time 0.247189


 54%|█████▍    | 24402/45000 [1:35:52<1:22:41,  4.15it/s]

Epoch 24400, Total loss 479.870056, iteration time 0.244761


 54%|█████▍    | 24502/45000 [1:36:15<1:22:40,  4.13it/s]

Epoch 24500, Total loss 512.388550, iteration time 0.239679


 55%|█████▍    | 24602/45000 [1:36:39<1:22:42,  4.11it/s]

Epoch 24600, Total loss 593.772339, iteration time 0.247602


 55%|█████▍    | 24702/45000 [1:37:03<1:21:42,  4.14it/s]

Epoch 24700, Total loss 183.992157, iteration time 0.250158


 55%|█████▌    | 24802/45000 [1:37:26<1:21:43,  4.12it/s]

Epoch 24800, Total loss 295.411621, iteration time 0.248201


 55%|█████▌    | 24902/45000 [1:37:50<1:20:19,  4.17it/s]

Epoch 24900, Total loss 253.006149, iteration time 0.244679


 56%|█████▌    | 25002/45000 [1:38:15<2:14:27,  2.48it/s]

Epoch 25000, Total loss 318.211426, iteration time 0.107419


 56%|█████▌    | 25102/45000 [1:38:38<1:19:05,  4.19it/s]

Epoch 25100, Total loss 336.772827, iteration time 0.242243


 56%|█████▌    | 25202/45000 [1:39:02<1:19:18,  4.16it/s]

Epoch 25200, Total loss 346.231476, iteration time 0.243474


 56%|█████▌    | 25302/45000 [1:39:25<1:18:11,  4.20it/s]

Epoch 25300, Total loss 455.447388, iteration time 0.242535


 56%|█████▋    | 25402/45000 [1:39:49<1:19:03,  4.13it/s]

Epoch 25400, Total loss 443.293579, iteration time 0.244083


 57%|█████▋    | 25502/45000 [1:40:13<1:18:38,  4.13it/s]

Epoch 25500, Total loss 423.952698, iteration time 0.246617


 57%|█████▋    | 25602/45000 [1:40:37<1:17:57,  4.15it/s]

Epoch 25600, Total loss 449.983795, iteration time 0.247382


 57%|█████▋    | 25702/45000 [1:41:00<1:17:42,  4.14it/s]

Epoch 25700, Total loss 432.825562, iteration time 0.248362


 57%|█████▋    | 25802/45000 [1:41:24<1:17:45,  4.11it/s]

Epoch 25800, Total loss 434.915344, iteration time 0.250797


 58%|█████▊    | 25902/45000 [1:41:48<1:16:26,  4.16it/s]

Epoch 25900, Total loss 438.868835, iteration time 0.248456


 58%|█████▊    | 26002/45000 [1:42:12<1:52:39,  2.81it/s]

Epoch 26000, Total loss 439.222290, iteration time 0.092817


 58%|█████▊    | 26102/45000 [1:42:36<1:16:01,  4.14it/s]

Epoch 26100, Total loss 411.525146, iteration time 0.257856


 58%|█████▊    | 26202/45000 [1:42:59<1:15:52,  4.13it/s]

Epoch 26200, Total loss 405.269043, iteration time 0.252639


 58%|█████▊    | 26302/45000 [1:43:23<1:15:23,  4.13it/s]

Epoch 26300, Total loss 465.931152, iteration time 0.247216


 59%|█████▊    | 26402/45000 [1:43:47<1:14:40,  4.15it/s]

Epoch 26400, Total loss 329.595398, iteration time 0.248625


 59%|█████▉    | 26502/45000 [1:44:10<1:13:44,  4.18it/s]

Epoch 26500, Total loss 351.427002, iteration time 0.242983


 59%|█████▉    | 26602/45000 [1:44:34<1:13:51,  4.15it/s]

Epoch 26600, Total loss 436.481567, iteration time 0.236830


 59%|█████▉    | 26702/45000 [1:44:58<1:12:51,  4.19it/s]

Epoch 26700, Total loss 437.927063, iteration time 0.244919


 60%|█████▉    | 26802/45000 [1:45:21<1:12:32,  4.18it/s]

Epoch 26800, Total loss 442.446716, iteration time 0.245184


 60%|█████▉    | 26902/45000 [1:45:45<1:12:58,  4.13it/s]

Epoch 26900, Total loss 453.370361, iteration time 0.247235


 60%|██████    | 27002/45000 [1:46:09<2:02:05,  2.46it/s]

Epoch 27000, Total loss 471.048157, iteration time 0.103575


 60%|██████    | 27102/45000 [1:46:33<1:11:54,  4.15it/s]

Epoch 27100, Total loss 453.768677, iteration time 0.248115


 60%|██████    | 27202/45000 [1:46:57<1:11:53,  4.13it/s]

Epoch 27200, Total loss 441.378601, iteration time 0.247754


 61%|██████    | 27302/45000 [1:47:20<1:11:27,  4.13it/s]

Epoch 27300, Total loss 407.575165, iteration time 0.251972


 61%|██████    | 27402/45000 [1:47:44<1:10:09,  4.18it/s]

Epoch 27400, Total loss 406.625305, iteration time 0.236029


 61%|██████    | 27502/45000 [1:48:08<1:11:16,  4.09it/s]

Epoch 27500, Total loss 438.160461, iteration time 0.262004


 61%|██████▏   | 27602/45000 [1:48:31<1:09:50,  4.15it/s]

Epoch 27600, Total loss 413.508362, iteration time 0.253469


 62%|██████▏   | 27702/45000 [1:48:55<1:09:28,  4.15it/s]

Epoch 27700, Total loss 380.327393, iteration time 0.247455


 62%|██████▏   | 27802/45000 [1:49:18<1:09:08,  4.15it/s]

Epoch 27800, Total loss 420.996582, iteration time 0.237811


 62%|██████▏   | 27902/45000 [1:49:42<1:09:45,  4.08it/s]

Epoch 27900, Total loss 415.880402, iteration time 0.254940


 62%|██████▏   | 28002/45000 [1:50:06<1:41:48,  2.78it/s]

Epoch 28000, Total loss 411.182953, iteration time 0.092796


 62%|██████▏   | 28102/45000 [1:50:30<1:07:29,  4.17it/s]

Epoch 28100, Total loss 476.390503, iteration time 0.250924


 63%|██████▎   | 28202/45000 [1:50:53<1:07:16,  4.16it/s]

Epoch 28200, Total loss 455.864746, iteration time 0.245898


 63%|██████▎   | 28302/45000 [1:51:17<1:07:29,  4.12it/s]

Epoch 28300, Total loss 414.841827, iteration time 0.244287


 63%|██████▎   | 28402/45000 [1:51:41<1:06:23,  4.17it/s]

Epoch 28400, Total loss 228.422195, iteration time 0.242428


 63%|██████▎   | 28502/45000 [1:52:05<1:07:07,  4.10it/s]

Epoch 28500, Total loss 299.716064, iteration time 0.264052


 64%|██████▎   | 28602/45000 [1:52:28<1:05:21,  4.18it/s]

Epoch 28600, Total loss 600.516602, iteration time 0.242872


 64%|██████▍   | 28702/45000 [1:52:52<1:06:14,  4.10it/s]

Epoch 28700, Total loss 420.786560, iteration time 0.243416


 64%|██████▍   | 28802/45000 [1:53:15<1:05:10,  4.14it/s]

Epoch 28800, Total loss 460.772827, iteration time 0.245183


 64%|██████▍   | 28902/45000 [1:53:39<1:03:54,  4.20it/s]

Epoch 28900, Total loss 426.469513, iteration time 0.250477


 64%|██████▍   | 29002/45000 [1:54:04<2:01:35,  2.19it/s]

Epoch 29000, Total loss 451.387451, iteration time 0.092912


 65%|██████▍   | 29102/45000 [1:54:27<1:04:04,  4.13it/s]

Epoch 29100, Total loss 408.600403, iteration time 0.249754


 65%|██████▍   | 29202/45000 [1:54:51<1:03:45,  4.13it/s]

Epoch 29200, Total loss 413.584106, iteration time 0.244377


 65%|██████▌   | 29302/45000 [1:55:15<1:02:28,  4.19it/s]

Epoch 29300, Total loss 399.904358, iteration time 0.241774


 65%|██████▌   | 29402/45000 [1:55:38<1:02:16,  4.17it/s]

Epoch 29400, Total loss 419.112549, iteration time 0.242942


 66%|██████▌   | 29502/45000 [1:56:02<1:01:46,  4.18it/s]

Epoch 29500, Total loss 480.751617, iteration time 0.232753


 66%|██████▌   | 29602/45000 [1:56:26<1:01:58,  4.14it/s]

Epoch 29600, Total loss 412.607361, iteration time 0.252778


 66%|██████▌   | 29702/45000 [1:56:49<1:01:03,  4.18it/s]

Epoch 29700, Total loss 420.419037, iteration time 0.245374


 66%|██████▌   | 29802/45000 [1:57:13<1:00:47,  4.17it/s]

Epoch 29800, Total loss 392.354919, iteration time 0.244047


 66%|██████▋   | 29902/45000 [1:57:36<1:00:09,  4.18it/s]

Epoch 29900, Total loss 444.012299, iteration time 0.244888


 67%|██████▋   | 30002/45000 [1:58:01<1:41:55,  2.45it/s]

Epoch 30000, Total loss 373.166077, iteration time 0.100893


 67%|██████▋   | 30102/45000 [1:58:25<59:19,  4.19it/s]

Epoch 30100, Total loss 260.860107, iteration time 0.243793


 67%|██████▋   | 30202/45000 [1:58:48<59:13,  4.16it/s]

Epoch 30200, Total loss 234.421097, iteration time 0.245490


 67%|██████▋   | 30302/45000 [1:59:12<58:34,  4.18it/s]

Epoch 30300, Total loss 489.282135, iteration time 0.242178


 68%|██████▊   | 30402/45000 [1:59:36<59:06,  4.12it/s]

Epoch 30400, Total loss 471.390991, iteration time 0.244390


 68%|██████▊   | 30502/45000 [1:59:59<57:21,  4.21it/s]

Epoch 30500, Total loss 464.022644, iteration time 0.239233


 68%|██████▊   | 30602/45000 [2:00:23<58:35,  4.10it/s]

Epoch 30600, Total loss 478.479370, iteration time 0.245042


 68%|██████▊   | 30702/45000 [2:00:46<57:58,  4.11it/s]

Epoch 30700, Total loss 485.228668, iteration time 0.248374


 68%|██████▊   | 30802/45000 [2:01:10<57:59,  4.08it/s]

Epoch 30800, Total loss 473.680206, iteration time 0.245936


 69%|██████▊   | 30902/45000 [2:01:34<56:03,  4.19it/s]

Epoch 30900, Total loss 494.761963, iteration time 0.241891


 69%|██████▉   | 31002/45000 [2:01:58<1:23:38,  2.79it/s]

Epoch 31000, Total loss 471.269836, iteration time 0.108810


 69%|██████▉   | 31102/45000 [2:02:22<56:32,  4.10it/s]

Epoch 31100, Total loss 498.266724, iteration time 0.242995


 69%|██████▉   | 31202/45000 [2:02:45<54:52,  4.19it/s]

Epoch 31200, Total loss 505.332336, iteration time 0.240871


 70%|██████▉   | 31302/45000 [2:03:09<54:59,  4.15it/s]

Epoch 31300, Total loss 485.466858, iteration time 0.247277


 70%|██████▉   | 31402/45000 [2:03:32<54:28,  4.16it/s]

Epoch 31400, Total loss 446.224060, iteration time 0.242143


 70%|███████   | 31502/45000 [2:03:56<54:20,  4.14it/s]

Epoch 31500, Total loss 441.807800, iteration time 0.237066


 70%|███████   | 31602/45000 [2:04:20<54:05,  4.13it/s]

Epoch 31600, Total loss 484.072754, iteration time 0.258353


 70%|███████   | 31702/45000 [2:04:43<53:04,  4.18it/s]

Epoch 31700, Total loss 535.323486, iteration time 0.242464


 71%|███████   | 31802/45000 [2:05:07<52:45,  4.17it/s]

Epoch 31800, Total loss 523.009277, iteration time 0.242211


 71%|███████   | 31902/45000 [2:05:31<53:09,  4.11it/s]

Epoch 31900, Total loss 467.636292, iteration time 0.255939


 71%|███████   | 32002/45000 [2:05:55<1:18:15,  2.77it/s]

Epoch 32000, Total loss 441.333588, iteration time 0.095571


 71%|███████▏  | 32102/45000 [2:06:18<51:42,  4.16it/s]

Epoch 32100, Total loss 470.744629, iteration time 0.250365


 72%|███████▏  | 32202/45000 [2:06:42<50:59,  4.18it/s]

Epoch 32200, Total loss 460.910278, iteration time 0.253635


 72%|███████▏  | 32302/45000 [2:07:06<51:05,  4.14it/s]

Epoch 32300, Total loss 490.825623, iteration time 0.247113


 72%|███████▏  | 32402/45000 [2:07:29<50:23,  4.17it/s]

Epoch 32400, Total loss 479.397644, iteration time 0.242129


 72%|███████▏  | 32502/45000 [2:07:53<50:18,  4.14it/s]

Epoch 32500, Total loss 445.228058, iteration time 0.245793


 72%|███████▏  | 32602/45000 [2:08:17<49:11,  4.20it/s]

Epoch 32600, Total loss 521.301392, iteration time 0.242704


 73%|███████▎  | 32702/45000 [2:08:40<49:40,  4.13it/s]

Epoch 32700, Total loss 479.869019, iteration time 0.250871


 73%|███████▎  | 32802/45000 [2:09:04<48:43,  4.17it/s]

Epoch 32800, Total loss 466.484100, iteration time 0.245856


 73%|███████▎  | 32902/45000 [2:09:27<48:45,  4.14it/s]

Epoch 32900, Total loss 455.205017, iteration time 0.245970


 73%|███████▎  | 33002/45000 [2:09:52<1:11:16,  2.81it/s]

Epoch 33000, Total loss 466.448730, iteration time 0.093352


 74%|███████▎  | 33102/45000 [2:10:15<47:58,  4.13it/s]

Epoch 33100, Total loss 453.718018, iteration time 0.249635


 74%|███████▍  | 33202/45000 [2:10:39<47:20,  4.15it/s]

Epoch 33200, Total loss 411.220917, iteration time 0.247168


 74%|███████▍  | 33302/45000 [2:11:03<46:58,  4.15it/s]

Epoch 33300, Total loss 516.014282, iteration time 0.244847


 74%|███████▍  | 33402/45000 [2:11:26<46:11,  4.19it/s]

Epoch 33400, Total loss 163.436752, iteration time 0.241115


 74%|███████▍  | 33502/45000 [2:11:50<45:44,  4.19it/s]

Epoch 33500, Total loss 254.785767, iteration time 0.245167


 75%|███████▍  | 33602/45000 [2:12:14<45:51,  4.14it/s]

Epoch 33600, Total loss 387.635315, iteration time 0.243074


 75%|███████▍  | 33702/45000 [2:12:37<45:25,  4.15it/s]

Epoch 33700, Total loss 411.912018, iteration time 0.247466


 75%|███████▌  | 33802/45000 [2:13:01<44:53,  4.16it/s]

Epoch 33800, Total loss 401.922913, iteration time 0.246063


 75%|███████▌  | 33902/45000 [2:13:24<44:33,  4.15it/s]

Epoch 33900, Total loss 411.756165, iteration time 0.247951


 76%|███████▌  | 34002/45000 [2:13:49<1:16:18,  2.40it/s]

Epoch 34000, Total loss 396.402832, iteration time 0.106783


 76%|███████▌  | 34102/45000 [2:14:13<43:39,  4.16it/s]

Epoch 34100, Total loss 387.693390, iteration time 0.243092


 76%|███████▌  | 34202/45000 [2:14:36<43:22,  4.15it/s]

Epoch 34200, Total loss 398.186981, iteration time 0.246324


 76%|███████▌  | 34302/45000 [2:15:00<43:00,  4.14it/s]

Epoch 34300, Total loss 403.981873, iteration time 0.247188


 76%|███████▋  | 34402/45000 [2:15:24<43:03,  4.10it/s]

Epoch 34400, Total loss 431.140717, iteration time 0.250230


 77%|███████▋  | 34502/45000 [2:15:47<41:45,  4.19it/s]

Epoch 34500, Total loss 406.160889, iteration time 0.242843


 77%|███████▋  | 34602/45000 [2:16:11<41:50,  4.14it/s]

Epoch 34600, Total loss 365.365479, iteration time 0.245662


 77%|███████▋  | 34702/45000 [2:16:35<41:21,  4.15it/s]

Epoch 34700, Total loss 469.955780, iteration time 0.250480


 77%|███████▋  | 34802/45000 [2:16:58<41:15,  4.12it/s]

Epoch 34800, Total loss 204.191254, iteration time 0.245414


 78%|███████▊  | 34902/45000 [2:17:22<40:17,  4.18it/s]

Epoch 34900, Total loss 740.657288, iteration time 0.242816


 78%|███████▊  | 35002/45000 [2:17:46<58:41,  2.84it/s]  

Epoch 35000, Total loss 450.522797, iteration time 0.093115


 78%|███████▊  | 35102/45000 [2:18:09<39:38,  4.16it/s]

Epoch 35100, Total loss 453.302521, iteration time 0.257058


 78%|███████▊  | 35202/45000 [2:18:33<39:38,  4.12it/s]

Epoch 35200, Total loss 461.051056, iteration time 0.249982


 78%|███████▊  | 35302/45000 [2:18:57<38:28,  4.20it/s]

Epoch 35300, Total loss 451.662048, iteration time 0.245641


 79%|███████▊  | 35402/45000 [2:19:20<38:28,  4.16it/s]

Epoch 35400, Total loss 437.698242, iteration time 0.245164


 79%|███████▉  | 35502/45000 [2:19:44<37:55,  4.17it/s]

Epoch 35500, Total loss 452.067810, iteration time 0.245004


 79%|███████▉  | 35602/45000 [2:20:08<38:00,  4.12it/s]

Epoch 35600, Total loss 457.133087, iteration time 0.244407


 79%|███████▉  | 35702/45000 [2:20:31<37:23,  4.14it/s]

Epoch 35700, Total loss 449.436401, iteration time 0.246321


 80%|███████▉  | 35802/45000 [2:20:55<36:36,  4.19it/s]

Epoch 35800, Total loss 454.781250, iteration time 0.247014


 80%|███████▉  | 35902/45000 [2:21:19<36:38,  4.14it/s]

Epoch 35900, Total loss 505.169617, iteration time 0.246192


 80%|████████  | 36002/45000 [2:21:43<1:01:10,  2.45it/s]

Epoch 36000, Total loss 435.384399, iteration time 0.107068


 80%|████████  | 36102/45000 [2:22:07<35:28,  4.18it/s]

Epoch 36100, Total loss 433.494873, iteration time 0.247452


 80%|████████  | 36202/45000 [2:22:30<35:11,  4.17it/s]

Epoch 36200, Total loss 442.311737, iteration time 0.248441


 81%|████████  | 36302/45000 [2:22:54<34:40,  4.18it/s]

Epoch 36300, Total loss 554.903320, iteration time 0.247913


 81%|████████  | 36402/45000 [2:23:17<34:52,  4.11it/s]

Epoch 36400, Total loss 433.424805, iteration time 0.253719


 81%|████████  | 36502/45000 [2:23:41<34:16,  4.13it/s]

Epoch 36500, Total loss 402.965698, iteration time 0.245027


 81%|████████▏ | 36602/45000 [2:24:05<33:20,  4.20it/s]

Epoch 36600, Total loss 329.353546, iteration time 0.241132


 82%|████████▏ | 36702/45000 [2:24:28<32:53,  4.20it/s]

Epoch 36700, Total loss 184.040268, iteration time 0.242404


 82%|████████▏ | 36802/45000 [2:24:52<32:50,  4.16it/s]

Epoch 36800, Total loss 440.786774, iteration time 0.240581


 82%|████████▏ | 36902/45000 [2:25:16<32:44,  4.12it/s]

Epoch 36900, Total loss 396.969910, iteration time 0.250759


 82%|████████▏ | 37002/45000 [2:25:40<47:24,  2.81it/s]

Epoch 37000, Total loss 418.453400, iteration time 0.093390


 82%|████████▏ | 37102/45000 [2:26:03<31:51,  4.13it/s]

Epoch 37100, Total loss 399.090729, iteration time 0.246268


 83%|████████▎ | 37202/45000 [2:26:27<31:23,  4.14it/s]

Epoch 37200, Total loss 401.158722, iteration time 0.247261


 83%|████████▎ | 37302/45000 [2:26:51<30:59,  4.14it/s]

Epoch 37300, Total loss 414.842682, iteration time 0.241402


 83%|████████▎ | 37402/45000 [2:27:14<30:30,  4.15it/s]

Epoch 37400, Total loss 409.925781, iteration time 0.248221


 83%|████████▎ | 37502/45000 [2:27:38<29:46,  4.20it/s]

Epoch 37500, Total loss 424.156128, iteration time 0.241678


 84%|████████▎ | 37602/45000 [2:28:02<29:36,  4.16it/s]

Epoch 37600, Total loss 381.114258, iteration time 0.246607


 84%|████████▍ | 37702/45000 [2:28:25<29:06,  4.18it/s]

Epoch 37700, Total loss 432.525421, iteration time 0.250954


 84%|████████▍ | 37802/45000 [2:28:49<28:47,  4.17it/s]

Epoch 37800, Total loss 397.698853, iteration time 0.249999


 84%|████████▍ | 37902/45000 [2:29:13<28:21,  4.17it/s]

Epoch 37900, Total loss 436.424988, iteration time 0.243091


 84%|████████▍ | 38002/45000 [2:29:37<41:18,  2.82it/s]

Epoch 38000, Total loss 393.928284, iteration time 0.093378


 85%|████████▍ | 38102/45000 [2:30:00<28:21,  4.05it/s]

Epoch 38100, Total loss 450.897369, iteration time 0.255862


 85%|████████▍ | 38202/45000 [2:30:24<27:22,  4.14it/s]

Epoch 38200, Total loss 417.966736, iteration time 0.256317


 85%|████████▌ | 38302/45000 [2:30:48<26:53,  4.15it/s]

Epoch 38300, Total loss 464.954712, iteration time 0.246804


 85%|████████▌ | 38402/45000 [2:31:11<26:30,  4.15it/s]

Epoch 38400, Total loss 168.218674, iteration time 0.248977


 86%|████████▌ | 38502/45000 [2:31:35<26:18,  4.12it/s]

Epoch 38500, Total loss 313.700073, iteration time 0.245253


 86%|████████▌ | 38602/45000 [2:31:59<25:32,  4.17it/s]

Epoch 38600, Total loss 230.784851, iteration time 0.252563


 86%|████████▌ | 38702/45000 [2:32:22<25:14,  4.16it/s]

Epoch 38700, Total loss 452.790344, iteration time 0.243861


 86%|████████▌ | 38802/45000 [2:32:46<24:56,  4.14it/s]

Epoch 38800, Total loss 468.272522, iteration time 0.250388


 86%|████████▋ | 38902/45000 [2:33:10<24:37,  4.13it/s]

Epoch 38900, Total loss 443.292633, iteration time 0.247332


 87%|████████▋ | 39002/45000 [2:33:34<35:39,  2.80it/s]

Epoch 39000, Total loss 462.203644, iteration time 0.095246


 87%|████████▋ | 39102/45000 [2:33:57<23:23,  4.20it/s]

Epoch 39100, Total loss 467.748566, iteration time 0.243303


 87%|████████▋ | 39202/45000 [2:34:21<23:21,  4.14it/s]

Epoch 39200, Total loss 449.729065, iteration time 0.257641


 87%|████████▋ | 39302/45000 [2:34:45<23:02,  4.12it/s]

Epoch 39300, Total loss 459.246735, iteration time 0.251251


 88%|████████▊ | 39402/45000 [2:35:08<22:30,  4.15it/s]

Epoch 39400, Total loss 443.918671, iteration time 0.248670


 88%|████████▊ | 39502/45000 [2:35:32<22:02,  4.16it/s]

Epoch 39500, Total loss 436.403198, iteration time 0.246497


 88%|████████▊ | 39602/45000 [2:35:56<21:25,  4.20it/s]

Epoch 39600, Total loss 418.368347, iteration time 0.242565


 88%|████████▊ | 39702/45000 [2:36:19<21:20,  4.14it/s]

Epoch 39700, Total loss 450.118896, iteration time 0.250029


 88%|████████▊ | 39802/45000 [2:36:43<20:49,  4.16it/s]

Epoch 39800, Total loss 435.392181, iteration time 0.252770


 89%|████████▊ | 39902/45000 [2:37:06<20:39,  4.11it/s]

Epoch 39900, Total loss 414.777008, iteration time 0.247712


 89%|████████▉ | 40002/45000 [2:37:31<29:41,  2.81it/s]

Epoch 40000, Total loss 446.970764, iteration time 0.094941


 89%|████████▉ | 40102/45000 [2:37:54<19:50,  4.11it/s]

Epoch 40100, Total loss 354.984070, iteration time 0.243916


 89%|████████▉ | 40202/45000 [2:38:18<19:12,  4.16it/s]

Epoch 40200, Total loss 429.324219, iteration time 0.244259


 90%|████████▉ | 40302/45000 [2:38:42<18:46,  4.17it/s]

Epoch 40300, Total loss 446.099731, iteration time 0.245935


 90%|████████▉ | 40402/45000 [2:39:05<18:29,  4.15it/s]

Epoch 40400, Total loss 443.471985, iteration time 0.247182


 90%|█████████ | 40502/45000 [2:39:29<18:08,  4.13it/s]

Epoch 40500, Total loss 394.865479, iteration time 0.245659


 90%|█████████ | 40602/45000 [2:39:53<17:39,  4.15it/s]

Epoch 40600, Total loss 498.453613, iteration time 0.245209


 90%|█████████ | 40702/45000 [2:40:16<17:11,  4.17it/s]

Epoch 40700, Total loss 541.781616, iteration time 0.246539


 91%|█████████ | 40802/45000 [2:40:40<16:45,  4.18it/s]

Epoch 40800, Total loss 555.965027, iteration time 0.243673


 91%|█████████ | 40902/45000 [2:41:03<16:33,  4.12it/s]

Epoch 40900, Total loss 681.187500, iteration time 0.251090


 91%|█████████ | 41002/45000 [2:41:28<23:52,  2.79it/s]

Epoch 41000, Total loss 244.129211, iteration time 0.092761


 91%|█████████▏| 41102/45000 [2:41:51<15:41,  4.14it/s]

Epoch 41100, Total loss 261.176697, iteration time 0.248595


 92%|█████████▏| 41202/45000 [2:42:15<15:18,  4.13it/s]

Epoch 41200, Total loss 451.173370, iteration time 0.254289


 92%|█████████▏| 41302/45000 [2:42:39<14:57,  4.12it/s]

Epoch 41300, Total loss 432.475952, iteration time 0.245869


 92%|█████████▏| 41402/45000 [2:43:02<14:23,  4.17it/s]

Epoch 41400, Total loss 470.025391, iteration time 0.241507


 92%|█████████▏| 41502/45000 [2:43:26<14:05,  4.14it/s]

Epoch 41500, Total loss 495.757080, iteration time 0.246083


 92%|█████████▏| 41602/45000 [2:43:49<13:29,  4.20it/s]

Epoch 41600, Total loss 460.266357, iteration time 0.238457


 93%|█████████▎| 41702/45000 [2:44:13<13:09,  4.18it/s]

Epoch 41700, Total loss 470.053589, iteration time 0.243715


 93%|█████████▎| 41802/45000 [2:44:37<12:58,  4.11it/s]

Epoch 41800, Total loss 448.055298, iteration time 0.264660


 93%|█████████▎| 41902/45000 [2:45:00<12:18,  4.19it/s]

Epoch 41900, Total loss 452.630035, iteration time 0.242374


 93%|█████████▎| 42002/45000 [2:45:24<17:42,  2.82it/s]

Epoch 42000, Total loss 439.442535, iteration time 0.091402


 94%|█████████▎| 42102/45000 [2:45:48<11:39,  4.14it/s]

Epoch 42100, Total loss 457.396515, iteration time 0.246744


 94%|█████████▍| 42202/45000 [2:46:12<11:30,  4.05it/s]

Epoch 42200, Total loss 404.833618, iteration time 0.265207


 94%|█████████▍| 42302/45000 [2:46:35<10:42,  4.20it/s]

Epoch 42300, Total loss 423.963745, iteration time 0.242618


 94%|█████████▍| 42402/45000 [2:46:59<10:23,  4.17it/s]

Epoch 42400, Total loss 450.493744, iteration time 0.244730


 94%|█████████▍| 42502/45000 [2:47:22<10:05,  4.13it/s]

Epoch 42500, Total loss 428.768372, iteration time 0.247081


 95%|█████████▍| 42602/45000 [2:47:46<09:40,  4.13it/s]

Epoch 42600, Total loss 470.765503, iteration time 0.244800


 95%|█████████▍| 42702/45000 [2:48:10<09:10,  4.18it/s]

Epoch 42700, Total loss 341.230103, iteration time 0.246096


 95%|█████████▌| 42802/45000 [2:48:33<08:51,  4.13it/s]

Epoch 42800, Total loss 263.724945, iteration time 0.244093


 95%|█████████▌| 42902/45000 [2:48:57<08:24,  4.16it/s]

Epoch 42900, Total loss 210.982574, iteration time 0.243805


 96%|█████████▌| 43002/45000 [2:49:22<14:00,  2.38it/s]

Epoch 43000, Total loss 352.012573, iteration time 0.103587


 96%|█████████▌| 43102/45000 [2:49:45<07:36,  4.16it/s]

Epoch 43100, Total loss 435.273560, iteration time 0.247775


 96%|█████████▌| 43202/45000 [2:50:09<07:09,  4.18it/s]

Epoch 43200, Total loss 440.343475, iteration time 0.245254


 96%|█████████▌| 43302/45000 [2:50:32<06:45,  4.18it/s]

Epoch 43300, Total loss 452.701538, iteration time 0.245808


 96%|█████████▋| 43402/45000 [2:50:56<06:28,  4.12it/s]

Epoch 43400, Total loss 448.390564, iteration time 0.249555


 97%|█████████▋| 43502/45000 [2:51:20<05:59,  4.17it/s]

Epoch 43500, Total loss 454.249573, iteration time 0.242372


 97%|█████████▋| 43602/45000 [2:51:43<05:35,  4.17it/s]

Epoch 43600, Total loss 434.853210, iteration time 0.245641


 97%|█████████▋| 43702/45000 [2:52:07<05:13,  4.14it/s]

Epoch 43700, Total loss 451.464935, iteration time 0.247300


 97%|█████████▋| 43802/45000 [2:52:31<04:50,  4.13it/s]

Epoch 43800, Total loss 457.556152, iteration time 0.245598


 98%|█████████▊| 43902/45000 [2:52:54<04:23,  4.16it/s]

Epoch 43900, Total loss 469.504578, iteration time 0.252566


 98%|█████████▊| 44002/45000 [2:53:18<05:58,  2.78it/s]

Epoch 44000, Total loss 446.835052, iteration time 0.091733


 98%|█████████▊| 44102/45000 [2:53:42<03:34,  4.18it/s]

Epoch 44100, Total loss 440.485535, iteration time 0.238811


 98%|█████████▊| 44202/45000 [2:54:06<03:13,  4.11it/s]

Epoch 44200, Total loss 414.898560, iteration time 0.250425


 98%|█████████▊| 44302/45000 [2:54:29<02:46,  4.19it/s]

Epoch 44300, Total loss 476.491669, iteration time 0.240371


 99%|█████████▊| 44402/45000 [2:54:53<02:23,  4.17it/s]

Epoch 44400, Total loss 489.915527, iteration time 0.241995


 99%|█████████▉| 44502/45000 [2:55:16<01:58,  4.19it/s]

Epoch 44500, Total loss 401.431946, iteration time 0.238953


 99%|█████████▉| 44602/45000 [2:55:40<01:36,  4.12it/s]

Epoch 44600, Total loss 454.419617, iteration time 0.251618


 99%|█████████▉| 44702/45000 [2:56:04<01:11,  4.18it/s]

Epoch 44700, Total loss 413.478546, iteration time 0.242973


100%|█████████▉| 44802/45000 [2:56:27<00:47,  4.17it/s]

Epoch 44800, Total loss 423.471527, iteration time 0.252428


100%|█████████▉| 44902/45000 [2:56:51<00:23,  4.17it/s]

Epoch 44900, Total loss 498.365753, iteration time 0.241048


100%|██████████| 45000/45000 [2:57:15<00:00,  4.23it/s]
