In [1]:
# load offline trained agent
# load expert agent
# write ensemble class for offline trained agents
# use distributional RL to detect risky states
# use ensembles to detect novel states
# if novelty is above a treshold give control to expert 
# if risk is above a treshold give control to expert
# can conformal prediction give us guaranties about the performance in this setu?p
# empirecally verify if we are able to get the desired performance

In [63]:
import re
import os
import torch
import numpy as np
import gymnasium as gym
from types import SimpleNamespace

from examples.offline.utils import load_buffer_d4rl
from tianshou.policy import DSACPolicy, BasePolicy
from tianshou.env import SubprocVectorEnv
from tianshou.data import Collector, Batch
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, QuantileMlp
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

In [3]:
device = "cuda"

In [4]:
def parse_value(value):
    # Convert simple types (int, float, bool, None)
    if value.isdigit():
        return int(value)
    elif re.match(r'^\d+\.\d+$', value):
        return float(value)
    elif value == "True":
        return True
    elif value == "False":
        return False
    elif value == "None":
        return None
    elif value.startswith("[") and value.endswith("]"):
        # Convert the list items
        items = re.split(r',(?=[^\]]*(?:\[|$))', value[1:-1])
        return [parse_value(item.strip()) for item in items]
    elif value.startswith("(") and value.endswith(")"):
        # Convert the tuple items
        items = re.split(r',(?=[^\)]*(?:\(|$))', value[1:-1])
        # Special case for single-item tuple
        if len(items) == 2 and items[0].strip() != '':
            return (parse_value(items[0].strip()),)
        return tuple(parse_value(item.strip()) for item in items)
    elif value.startswith("'") and value.endswith("'"):
        return value[1:-1]
    # Else, return the value as-is
    return value

def get_args(event_file):
    ea = EventAccumulator(event_file)
    ea.Reload()  # Load the file
    # Get the text data
    texts = ea.Tags()["tensors"]
    # Extract the actual text content
    text_data = {}
    for tag in texts:
        events = ea.Tensors(tag)
        for event in events:
            # You can extract the wall_time and step if needed
            # wall_time, step, value = event.wall_time, event.step, event.text
            text_data[tag] = event.tensor_proto.string_val
    data = text_data['args/text_summary'][0]
    # Convert bytes to string
    data_str = data.decode('utf-8')
    # Remove the "Namespace(" prefix and the trailing ")"
    data_str = data_str[len("Namespace("):-1]
    # Split into key-value pairs
    key_values = re.split(r',(?=\s\w+=)', data_str)
    # Parse each key-value pair
    args_dict = {}
    for kv in key_values:
        key, value = kv.split('=', 1)
        key = key.strip()
        args_dict[key] = parse_value(value)
    args = SimpleNamespace(**args_dict)
    return args

In [5]:
def load_policy(args, path):
    env = gym.make(args.task)
    # model
    net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=device)
    actor = ActorProb(
        net_a,
        args.action_shape,
        device=device,
        unbounded=True,
        conditioned_sigma=True,
    ).to(device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    critic1 = QuantileMlp(hidden_sizes=args.hidden_sizes, input_size=args.state_shape[0] + args.action_shape[0], device=device).to(device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = QuantileMlp(hidden_sizes=args.hidden_sizes, input_size=args.state_shape[0] + args.action_shape[0], device=device).to(device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
    policy = DSACPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        risk_type=args.risk_type,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        action_space=env.action_space,
        device=device,
    )
    dirname = os.path.dirname(path)
    if os.path.isfile(os.path.join(dirname, "actor.pth")):
        policy.actor.load_state_dict(torch.load(os.path.join(dirname, "actor.pth"), map_location=device))
        print("Loaded actor from: ", os.path.join(dirname, "actor.pth"))
    if os.path.isfile(os.path.join(dirname, "critic1.pth")):
        policy.critic1.load_state_dict(torch.load(os.path.join(dirname, "critic1.pth"), map_location=device))
        policy.critic1_old.load_state_dict(torch.load(os.path.join(dirname, "critic1.pth"), map_location=device))
        print("Loaded critic1 from: ", os.path.join(dirname, "critic1.pth"))
    if os.path.isfile(os.path.join(dirname, "critic2.pth")):
        policy.critic2.load_state_dict(torch.load(os.path.join(dirname, "critic2.pth"), map_location=device))
        policy.critic2_old.load_state_dict(torch.load(os.path.join(dirname, "critic2.pth"), map_location=device))
        print("Loaded critic2 from: ", os.path.join(dirname, "critic2.pth"))
    else:
        policy.load_state_dict(torch.load(path, map_location=device))
        print("Loaded agent from: ", path)
    return policy

def load_behavioral_crtitic(args, path):
    behavioral_critic = QuantileMlp(
        input_size=args.state_shape[0] + args.action_shape[0],
        hidden_sizes=args.hidden_sizes,
        device=device,
    ).to(device)
    behavioral_critic.load_state_dict(torch.load(path, map_location=device))
    return behavioral_critic

def get_model(log_path, type=None):
    files = os.listdir(log_path)
    event_file = [f for f in files if f.startswith('event')][0]
    full_path = os.path.join(log_path, event_file)
    args = get_args(full_path)
    if type == "behavioral":
        resume_path = os.path.join(log_path, 'model.pth')
        policy = load_behavioral_crtitic(args, resume_path)
    else:
        resume_path = os.path.join(log_path, 'policy.pth')
        policy = load_policy(args, resume_path)
    return policy

In [6]:
log_path = "/data/user/R901105/dev/log/Hopper-v4/qr/231102-133240"
behavioral_critic = get_model(log_path, "behavioral")

In [7]:
log_path1 = "/data/user/R901105/dev/log/Hopper-v2/codac_bc/neutral/0/231102-150037"
offline_policy1 = get_model(log_path1)

  logger.deprecation(
  logger.deprecation(


Loaded actor from:  /data/user/R901105/dev/log/Hopper-v2/codac_bc/neutral/0/231102-150037/actor.pth
Loaded critic1 from:  /data/user/R901105/dev/log/Hopper-v2/codac_bc/neutral/0/231102-150037/critic1.pth
Loaded critic2 from:  /data/user/R901105/dev/log/Hopper-v2/codac_bc/neutral/0/231102-150037/critic2.pth


In [8]:
log_path2 = "/data/user/R901105/dev/log/Hopper-v2/codac_bc/wang/0/231106-104348"
offline_policy2 = get_model(log_path2)

Loaded actor from:  /data/user/R901105/dev/log/Hopper-v2/codac_bc/wang/0/231106-104348/actor.pth
Loaded critic1 from:  /data/user/R901105/dev/log/Hopper-v2/codac_bc/wang/0/231106-104348/critic1.pth
Loaded critic2 from:  /data/user/R901105/dev/log/Hopper-v2/codac_bc/wang/0/231106-104348/critic2.pth


  logger.deprecation(
  logger.deprecation(


In [9]:
log_path3 = "/data/user/R901105/dev/log/Hopper-v2/codac_bc/cvar/0/231106-123337"
offline_policy3 = get_model(log_path3)

Loaded actor from:  /data/user/R901105/dev/log/Hopper-v2/codac_bc/cvar/0/231106-123337/actor.pth
Loaded critic1 from:  /data/user/R901105/dev/log/Hopper-v2/codac_bc/cvar/0/231106-123337/critic1.pth
Loaded critic2 from:  /data/user/R901105/dev/log/Hopper-v2/codac_bc/cvar/0/231106-123337/critic2.pth


In [10]:
log_path = "/data/user/R901105/dev/log/Hopper-v4/dsac/wang/0/230824-151635"
expert_policy = get_model(log_path) 

Loaded agent from:  /data/user/R901105/dev/log/Hopper-v4/dsac/wang/0/230824-151635/policy.pth


In [64]:
class EnsemblePolicy(BasePolicy):
    def __init__(self, policies, action_space):
        super().__init__(action_space=action_space)
        self.policies = policies

    def forward(self, batch, state=None, **kwargs):
        actions = np.stack([p(batch).act.detach().cpu() for p in self.policies])
        return Batch(**{'act': np.mean(actions, axis=0), 'state': None})
    
    def learn(self, batch, **kwargs):
        pass
    
    def get_qvalues(self, obs, act):
        q_values = np.stack([p.critic1(obs, act).detach().cpu() for p in self.policies])
        return q_values

In [56]:
# check the performance of the expert policy
task = "Hopper-v2"
env = gym.make(task)
envs = SubprocVectorEnv([lambda: gym.make("Hopper-v2") for _ in range(20)])

In [49]:
policies = [offline_policy1, offline_policy2, offline_policy3]

In [65]:
ensemble = EnsemblePolicy(policies, env.action_space)

In [66]:
collector = Collector(ensemble, envs)

In [67]:
result = collector.collect(n_episode=100)

In [69]:
result['rews']

array([1297.4056133 , 1551.75705752, 1525.07748803, 1543.24188126,
       1560.76635264, 1574.4339472 , 1569.98205767, 1561.72085199,
       1578.28293235, 1636.2818007 , 1645.35106032, 1774.9540754 ,
       1787.26778757, 1772.41089845, 1827.39825956, 1819.17287196,
       1996.27301949, 2267.44090986, 2377.69300312, 2413.05030823,
       1417.60868587, 1549.67812415, 1535.43799899, 1660.32682601,
       1762.30531851, 1722.73171207, 1526.29202326, 1807.54897506,
       2121.68204262, 1805.1412413 , 1957.47295122, 1722.55052329,
       1785.47652772, 2136.30970367, 1710.58438485, 2136.00623529,
       2428.17527494, 2172.07740722, 2119.91707466, 2117.84973968,
       1250.58047082, 1499.95465738, 1191.73644691, 1643.66161783,
       1698.83964616, 1794.35831886, 1842.91518837, 1622.4076671 ,
       1670.8247402 , 1826.863056  , 1465.48605068, 1728.47698258,
       1249.89990589, 2124.01860294, 1581.46086299, 1552.62223346,
       1804.64894351, 2611.88258341, 1595.56698652, 1588.84678

In [None]:
# establish baselines of the ensemble performance and the expert performance 
# based on epistemic and aleatoric uncertainties give control to the expert policy
# check the performance of the ensemble + expert in comparaison with only the expert
# perform continual learning and observe the evolution of numbers of calls to the expert