In [1]:
# Created by Yaru Niu

import gym
import numpy as np
import copy
import argparse
import random
import os
import torch
from icct.rl_helpers import ddt_sac_policy
from icct.core.icct_helpers import convert_to_crisp
from icct.rl_helpers.eval_callback import EpCheckPointCallback
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor
)

from stable_baselines3 import SAC
import highway_env
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.evaluation import evaluate_policy

# os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import env

# ignore gym/stablebaseline warnings
import warnings
warnings.filterwarnings('ignore', category=UserWarning)


def make_env(env_name, seed):
    """Create and seed the environment selected by a short alias.

    Args:
        env_name: One of 'lunar', 'walker', 'cart', 'lane_keeping',
            'ring_accel', 'highway', 'intersection' or 'racetrack'.
        seed: Seed forwarded to ``set_random_seed`` and to ``env.seed``.

    Returns:
        A tuple ``(env, name)`` holding the seeded environment and the
        Gym id it was created from.

    Raises:
        ValueError: If ``env_name`` is not a supported alias.
        NameError: If 'ring_accel' is requested while ``make_create_env`` or
            ``ring_accel_params`` are not available in this module/notebook.
    """
    # Alias -> registered Gym id for every environment built with gym.make.
    gym_ids = {
        'lunar': 'LunarLanderContinuous-v2',
        'walker': 'BipedalWalker-v3',
        'cart': 'InvertedPendulum-v2',
        'lane_keeping': 'lane-keeping-v0',
        'highway': 'highway-v0',
        'intersection': 'intersection-v0',
        'racetrack': 'racetrack-v0',
    }
    # Disabled variants kept for reference (same construction as 'ring_accel'):
    #   'ring_lane_changing' -> make_create_env(params=ring_accel_lc_params, version=0)
    #   'figure8'            -> make_create_env(params=fig8_params, version=0)

    # Seed the global RNGs before the environment is constructed.
    set_random_seed(seed)

    if env_name in gym_ids:
        name = gym_ids[env_name]
        env = gym.make(name)
    elif env_name == 'ring_accel':
        # These helpers are not imported anywhere in the visible notebook, so
        # fail with an explicit message instead of an opaque NameError.
        missing = [helper for helper in ('make_create_env', 'ring_accel_params')
                   if helper not in globals()]
        if missing:
            raise NameError(
                "'ring_accel' requires names that are not defined: "
                + ", ".join(missing)
            )
        create_env, name = make_create_env(params=ring_accel_params, version=0)
        env = create_env()
    else:
        supported = sorted(list(gym_ids) + ['ring_accel'])
        raise ValueError(
            'No valid environment selected: {!r} (expected one of {})'.format(
                env_name, ', '.join(supported))
        )

    # Every supported environment is seeded the same way.
    env.seed(seed)
    return env, name
# Ask PyTorch to use deterministic algorithms so repeated runs are comparable.
# NOTE(review): the CUBLAS_WORKSPACE_CONFIG line near the imports is commented out;
# presumably it is only needed for CUDA runs — confirm before evaluating on GPU.
torch.use_deterministic_algorithms(True)

  logger.warn("Overriding environment {}".format(id))
  logger.warn("Overriding environment {}".format(id))
  logger.warn("Overriding environment {}".format(id))


In [2]:

# Load the two trained SAC checkpoints on CPU.
# The run directories and seed suffix are named once instead of being embedded
# as string literals inside f-string placeholders.
MODEL_SEED = "11"
DRNET_RUN_DIR = "../trained_models/walker_cpu/drnet/512leaves_04-19-01-Dec/"
MLP_RUN_DIR = "../trained_models/walker_cpu/mlp/512leaves_56-18-01-Dec/"

model_drnet = SAC.load(f"{DRNET_RUN_DIR}best_model_seed{MODEL_SEED}.zip", device='cpu')
model_nn = SAC.load(f"{MLP_RUN_DIR}best_model_seed{MODEL_SEED}.zip", device='cpu')

DTSemNetReg(
  (linear1): Linear(in_features=24, out_features=511, bias=True)
  (reluP): ReLU()
  (reluM): ReLU()
  (linear2): Linear(in_features=1022, out_features=512, bias=False)
  (mpool): MaxPoolLayer()
  (softmax): Softmax(dim=-1)
  (regression_layer_action1): Sequential(
    (0): Linear(in_features=24, out_features=512, bias=True)
  )
  (regression_layer_action2): Sequential(
    (0): Linear(in_features=24, out_features=512, bias=True)
  )
  (regression_layer_action3): Sequential(
    (0): Linear(in_features=24, out_features=512, bias=True)
  )
  (regression_layer_action4): Sequential(
    (0): Linear(in_features=24, out_features=512, bias=True)
  )
)


In [3]:
# Display the actor module of the loaded `model_nn` policy (its repr is the cell output).
model_nn.policy.actor

Actor(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (latent_pi): Sequential(
    (0): Linear(in_features=24, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
  )
  (mu): Sequential(
    (0): Linear(in_features=512, out_features=4, bias=True)
    (1): Hardtanh(min_val=-2.0, max_val=2.0)
  )
)

In [4]:
# # # #===== Torchstat =====#
# # stat(model, dummy_input)
# from torchstat import stat
# from torchinfo import summary


# # Define the input size as (batch_size, observation_dim)
# input_size = (1, 24)  

# # Print model summary
# summary(model_nn.policy.actor, input_size=input_size)

In [7]:
# Count only parameters that receive gradients, so the number matches the
# "tunable" label printed below (frozen parameters are excluded).
num_params = sum(
    p.numel()
    for p in model_drnet.policy.actor.parameters()
    if p.requires_grad
)
print("Number of tunable parameters:", num_params)


Number of tunable parameters: 589483


In [11]:
# ========== Seed for Test Envs [500, 501 .... 599] ==========
# Apply the first test seed to both loaded policies, then build the walker
# environment with the same seed.
# NOTE(review): assigning to `env` rebinds the name introduced by `import env`
# in the first cell.
test_seed = 500
for policy_model in (model_drnet, model_nn):
    policy_model.set_random_seed(test_seed)
env, env_n = make_env('walker', seed=test_seed)

In [14]:
import time

# Set the batch size
batch_size = 1

# Draw the observations up front so that the cost of sampling the observation
# space is not counted as inference time.
samples = [env.observation_space.sample() for _ in range(batch_size)]

# Start the clock immediately before inference. perf_counter is the
# high-resolution clock intended for measuring short durations.
start_time = time.perf_counter()

# Pass the samples through the policy network, one prediction per sample
for sample in samples:
    model_nn.predict(sample)

# Elapsed time in seconds for all predictions
inference_time = time.perf_counter() - start_time

print("Inference time:", inference_time)


Inference time: 0.002820730209350586


DTRegNet inference time: 0.04 s

NN inference time: 0.003 s

In [None]:


# `model` is not defined in any visible cell (only model_drnet and model_nn are
# loaded), so the policy to evaluate is named explicitly here.
# NOTE(review): model_drnet is assumed to be the intended target — switch to
# model_nn to evaluate the MLP baseline instead.
eval_model = model_drnet

# evaluate_policy returns (mean_reward, std_reward); the std is discarded.
reward, _ = evaluate_policy(eval_model,
                            env,
                            n_eval_episodes=1,
                            deterministic=True)