In [1]:
import sys; sys.path.append(2*'../')

import os
import glob
from pathlib import Path
from omegaconf import DictConfig
import yaml

import torch
import lightning as L

from rl4co.tasks.rl4co import RL4COLitModule


  warn(
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run to evaluate: <root>/<experiment>/<model>/
exp_name, model_name = 'tsp50', 'am-tsp50'
checkpoints_path = Path('../../saved_checkpoints/') / exp_name / model_name

# Config file stored inside that run directory
hydra_config_path = checkpoints_path / 'config.yaml'

# Parse the YAML into plain Python containers
with hydra_config_path.open('r') as stream:
    hydra_config_yaml = yaml.safe_load(stream)


# Rebuild a nested config from flat keys: 'a/b/c' becomes config['a']['b']['c']
def clean_hydra_config(config, keep_value_only=True):
    """Turn a flat mapping with '/'-separated keys into a nested ``DictConfig``.

    Args:
        config: mapping whose keys may contain '/' separators. With
            ``keep_value_only=True`` each entry is normally a dict that
            carries its payload under the ``'value'`` key.
        keep_value_only: if True, store ``entry['value']`` instead of the
            whole entry. Entries that are not dicts, or that have no
            ``'value'`` key, are stored unchanged instead of raising.

    Returns:
        ``DictConfig`` built from the nested dictionary.
    """
    def unwrap(entry):
        # Only unwrap entries that actually look like {'value': ...}
        if keep_value_only and isinstance(entry, dict) and 'value' in entry:
            return entry['value']
        return entry

    new_config = {}
    for key, value in config.items():
        # A key without '/' has no parents and is stored at the top level
        *parents, leaf = key.split('/')
        node = new_config
        for parent in parents:
            node = node.setdefault(parent, {})
        node[leaf] = unwrap(value)
    return DictConfig(new_config)


# Filter out every entry whose key mentions 'wandb'
def remove_wandb_keys(config):
    """Return a new dict holding only the entries whose key does not contain 'wandb'."""
    return {name: entry for name, entry in config.items() if 'wandb' not in name}

# Drop the 'wandb' entries first, then nest the remaining '/'-separated keys
hydra_config_yaml = remove_wandb_keys(hydra_config_yaml)

hydra_config = clean_hydra_config(hydra_config_yaml)
# Show the resulting config for a visual check
print(hydra_config)



In [3]:
# Instantiate RL4COLitModule from the reconstructed config
lit_module = RL4COLitModule(hydra_config)

Unused kwargs: {'params': {'total': 708608, 'trainable': 708608, 'non_trainable': 0}}


In [4]:
# Build a fresh module; the patches below are applied to this new instance
lit_module = RL4COLitModule(hydra_config)


# Replace model.setup / model.wrap_dataset (when present) with no-ops,
# so any later call to them does nothing
if hasattr(lit_module.model, 'setup'):
    print("No setup function for model required during testing!")
    lit_module.model.setup = lambda *args, **kwargs: None
if hasattr(lit_module.model, "wrap_dataset"):
    print("No wrap_dataset function for model required during testing!")
    lit_module.model.wrap_dataset = lambda *args, **kwargs: None

# Disabled alternative: load the whole checkpoint with strict=False so the
# baseline weights are not required. The policy weights are loaded separately instead.
# lit_module.load_from_checkpoint(checkpoints_path / 'epoch_099.ckpt', strict=False)

def load_policy_state_dict(lit_module, path, device='cpu'):
    """Load only the policy weights from a checkpoint into ``lit_module.model.policy``.

    Args:
        lit_module: object exposing ``model.policy`` (a module with ``load_state_dict``).
        path: checkpoint file readable by ``torch.load``; it must contain a
            ``'state_dict'`` entry whose policy keys start with ``'model.policy.'``.
        device: ``map_location`` used when reading the checkpoint.

    Returns:
        The same ``lit_module``; its policy weights are updated in place.
    """
    prefix = 'model.policy.'
    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources
    state_dict = torch.load(path, map_location=device)['state_dict']
    # Keep only entries that live under the policy and strip exactly the leading prefix.
    # A substring test on 'policy' would also keep unrelated keys that merely contain
    # the word, and str.replace would rewrite the prefix wherever it occurs in a key.
    policy_state_dict = {
        k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
    }

    lit_module.model.policy.load_state_dict(policy_state_dict)
    return lit_module


# Training data is not needed for evaluation: keep train_size small so that
# whatever setup generates for training stays cheap
lit_module.train_size = 100 
# Load only the policy weights from the epoch-99 checkpoint
lit_module = load_policy_state_dict(lit_module, checkpoints_path / 'epoch_099.ckpt')

# Run the module's setup for the 'test' stage
lit_module.setup('test')

Unused kwargs: {'params': {'total': 708608, 'trainable': 708608, 'non_trainable': 0}}


No setup function for model required during testing!
No wrap_dataset function for model required during testing!


In [5]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the policy to that device and put it in evaluation mode
policy = lit_module.model.policy.to(device)
policy.eval()
# Environment object used by the evaluation cells below
env = lit_module.model.env

# dataloader = lit_module.test_dataloader()

In [6]:
# Greedy evaluation over the test set; prints the mean reward
test_dataset = lit_module.test_dataset
dataloader = lit_module._dataloader(test_dataset, batch_size=512)

# Inference only: no gradients needed
with torch.no_grad():

    rewards = []

    for batch in dataloader:
        # Move the batch to the policy's device and work on a copy
        td = batch.to(device).clone()
        td = env.reset(td)
        td_out = policy(td, decode_type="greedy")
        rewards.append(td_out['reward'])

    # Concatenate the per-batch reward tensors before averaging
    rewards = torch.cat(rewards)
    print(rewards.mean())

tensor(-5.7785, device='cuda:0')


## Greedy multi-start

In [7]:
from rl4co.utils.ops import unbatchify, batchify, gather_by_index

In [8]:
# Greedy multi-start evaluation: decode num_starts candidates per instance
# and keep the one with the highest reward
test_dataset = lit_module.test_dataset
dataloader = lit_module._dataloader(test_dataset, batch_size=2048)


# Number of starts is taken from env.num_loc
num_starts = env.num_loc
# num_starts
with torch.no_grad():

    rewards_list = []
    actions_list = []

    for batch in dataloader:
        td = batch.to(device).clone()
        td = env.reset(td)
        td_out = policy(td, decode_type="greedy_multistart", 
                        num_starts=num_starts, return_actions=True)
        
        # [batch_size, num_starts]
        rewards = unbatchify(td_out['reward'], num_starts)
        actions = unbatchify(td_out['actions'], num_starts)

        # Best start per instance, and the actions that produced it
        max_rewards, max_idxs = rewards.max(dim=1)
        best_actions = gather_by_index(actions, max_idxs, dim=1)
  
        rewards_list.append(max_rewards)
        actions_list.append(best_actions)

    # Note: `rewards` and `actions` are rebound here to the concatenated results
    rewards = torch.cat(rewards_list)
    actions = torch.cat(actions_list)
    print(rewards.mean())

tensor(-5.7668, device='cuda:0')


## Symmetric augmentations

In [9]:
from rl4co.models.zoo.symnco.augmentations import StateAugmentation as StateAugmentationN
from rl4co.models.zoo.pomo.augmentations import StateAugmentation as StateAugmentation8

In [10]:
# Evaluation with state augmentation: decode each augmented copy of an instance
# and keep the copy with the highest reward
test_dataset = lit_module.test_dataset
dataloader = lit_module._dataloader(test_dataset, batch_size=128)

# POMO
num_augment = 8
augmentation = StateAugmentation8(env.name, num_augment=num_augment)

# SymNCO
# num_augment = 8
# augmentation = StateAugmentationN(env.name, num_augment=num_augment)


with torch.no_grad():

    rewards_list = []
    actions_list = []

    for batch in dataloader:
        td = batch.to(device)
        # Keep an un-augmented copy: rewards are scored against it below
        td_orig = td.clone()

        td = augmentation(td)

        td = env.reset(td).clone()
        # NOTE(review): no decode_type is passed here, unlike the greedy cells --
        # confirm which decoding the policy defaults to
        td_out = policy(td, return_actions=True)
        
        
        # Score the decoded actions on the original instances, repeated num_augment times
        rewards = env.get_reward(batchify(td_orig, num_augment), td_out['actions'])
        rewards = unbatchify(rewards, num_augment)
        actions = unbatchify(td_out['actions'], num_augment)
        
        # Best augmentation per instance, and the actions that produced it
        max_rewards, max_idxs = rewards.max(dim=1)
        best_actions = gather_by_index(actions, max_idxs, dim=1)
  
        rewards_list.append(max_rewards)
        actions_list.append(best_actions)

    # Note: `rewards` and `actions` are rebound here to the concatenated results
    rewards = torch.cat(rewards_list)
    actions = torch.cat(actions_list)
    print(rewards.mean())

tensor(-5.7185, device='cuda:0')


In [11]:
# Print the device that holds the policy's first parameter
print(next(policy.parameters()).device)

cuda:0


In [22]:

class EvalBase:
    """Base class for evaluation strategies run over a dataloader.

    Subclasses implement ``_inner``, which maps one reset batch to
    ``(actions, rewards)``. ``__call__`` applies it to every batch,
    concatenates the results and times the whole pass.
    """

    name = "base"

    def __init__(self, env, policy, **kwargs):
        # Extra keyword arguments are accepted and ignored
        self.env = env
        self.policy = policy

    def __call__(self, dataloader, **kwargs):
        """Evaluate the policy on every batch of ``dataloader``.

        Args:
            dataloader: iterable of batches exposing ``.to(device)``.
            **kwargs: forwarded unchanged to ``_inner``.

        Returns:
            dict with keys ``actions``, ``rewards``, ``inference_time``
            (milliseconds), ``name`` and ``avg_reward``.
        """
        import time  # local import: only needed for the CPU timing fallback

        # Use this evaluator's own policy/env, not whatever globals happen to exist
        device = next(self.policy.parameters()).device
        use_cuda_timer = device.type == "cuda"

        # CUDA events time GPU work accurately; they are unavailable on CPU,
        # where a wall-clock timer is used instead
        if use_cuda_timer:
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
        else:
            start_time = time.perf_counter()

        with torch.no_grad():
            rewards_list = []
            actions_list = []

            for batch in dataloader:
                td = batch.to(device)
                td = self.env.reset(td)
                actions, rewards = self._inner(td, **kwargs)
                rewards_list.append(rewards)
                actions_list.append(actions)

            rewards = torch.cat(rewards_list)
            actions = torch.cat(actions_list)

        if use_cuda_timer:
            end_event.record()
            # Wait for the GPU to finish so the elapsed time is complete
            torch.cuda.synchronize()
            inference_time = start_event.elapsed_time(end_event)
        else:
            # Convert seconds to milliseconds to match elapsed_time()
            inference_time = (time.perf_counter() - start_time) * 1000

        print("Mean reward for {}: {:.4f}".format(self.name, rewards.mean()))
        print("Time: {:.4f}s".format(inference_time/1000))

        return {"actions": actions, "rewards": rewards, "inference_time": inference_time,
                "name": self.name, "avg_reward": rewards.mean()}

    def _inner(self, td, **kwargs):
        """Evaluate one batch and return ``(actions, rewards)``; must be overridden."""
        raise NotImplementedError("Implement in subclass")


class GreedyEval(EvalBase):
    """Greedy decoding with one rollout per instance."""

    name = "greedy"

    def __init__(self, env, policy):
        super().__init__(env, policy)

    def _inner(self, td):
        """Decode ``td`` greedily and return ``(actions, rewards)``."""
        # The policy receives a copy; the reward is computed against ``td`` itself
        actions = self.policy(
            td.clone(),
            decode_type="greedy",
            num_starts=0,
            return_actions=True,
        )['actions']
        return actions, self.env.get_reward(td, actions)


class AugmentationEval(EvalBase):
    """Greedy decoding on N augmented copies of each instance; keeps the best copy."""

    name = "augmentation"

    def __init__(self, env, policy, num_augment=8):
        super().__init__(env, policy)
        # Exactly 8 -> StateAugmentation8; any other count -> StateAugmentationN (normalize=True)
        self.augmentation = (
            StateAugmentation8(env.name, num_augment=num_augment)
            if num_augment == 8
            else StateAugmentationN(env.name, num_augment=num_augment, normalize=True)
        )

    def _inner(self, td, num_augment=None):
        """Return the best ``(actions, rewards)`` over the augmented copies.

        ``num_augment`` defaults to the count stored on the augmentation object.
        """
        n_aug = self.augmentation.num_augment if num_augment is None else num_augment

        # Keep the un-augmented batch: rewards are scored against it
        td_init = td.clone()
        td_aug = self.augmentation(td, num_augment=n_aug)
        actions = self.policy(
            td_aug.clone(),
            decode_type="greedy",
            num_starts=0,
            return_actions=True,
        )['actions']

        rewards = unbatchify(self.env.get_reward(batchify(td_init, n_aug), actions), n_aug)
        actions = unbatchify(actions, n_aug)

        # Highest reward along dim 1, and the actions that achieved it
        best_rewards, best_idxs = rewards.max(dim=1)
        return gather_by_index(actions, best_idxs, dim=1), best_rewards
    

class SamplingEval(EvalBase):
    """Draws ``samples`` sampled rollouts per instance and keeps the best one."""

    name = "sampling"

    def __init__(self, env, policy, samples, softmax_temp=None):
        super().__init__(env, policy)
        self.softmax_temp = softmax_temp
        self.samples = samples

    def _inner(self, td):
        """Return the best ``(actions, rewards)`` over ``self.samples`` sampled rollouts."""
        # Repeat the batch once per requested sample
        td_rep = batchify(td, self.samples)
        actions = self.policy(
            td_rep.clone(),
            decode_type="sampling",
            num_starts=0,
            return_actions=True,
            softmax_temp=self.softmax_temp,
        )['actions']

        rewards = unbatchify(self.env.get_reward(td_rep, actions), self.samples)
        actions = unbatchify(actions, self.samples)

        # Highest reward along dim 1, and the actions that achieved it
        best_rewards, best_idxs = rewards.max(dim=1)
        return gather_by_index(actions, best_idxs, dim=1), best_rewards


class GreedyMultiStartEval(EvalBase):
    """Greedy multi-start decoding with ``num_starts`` starts; keeps the best start."""

    name = "greedy_multistart"

    def __init__(self, env, policy, num_starts):
        super().__init__(env, policy)
        self.num_starts = num_starts

    def _inner(self, td):
        """Return the best ``(actions, rewards)`` over ``self.num_starts`` starts."""
        n_starts = self.num_starts
        actions = self.policy(
            td.clone(),
            decode_type="greedy_multistart",
            num_starts=n_starts,
            return_actions=True,
        )['actions']

        rewards = unbatchify(self.env.get_reward(td, actions), n_starts)
        actions = unbatchify(actions, n_starts)

        # Highest reward along dim 1, and the actions that achieved it
        best_rewards, best_idxs = rewards.max(dim=1)
        return gather_by_index(actions, best_idxs, dim=1), best_rewards
    

class GreedyMultiStartAugmentEval(EvalBase):
    """Combines state augmentation with greedy multi-start decoding.

    Every instance is decoded ``num_augment * num_starts`` times and the
    highest-reward rollout is kept.
    """

    name = "greedy_multistart_augment"

    def __init__(self, env, policy, num_starts, num_augment=8):
        super().__init__(env, policy)
        self.num_starts = num_starts
        # Exactly 8 -> StateAugmentation8; any other count -> StateAugmentationN
        self.augmentation = (
            StateAugmentation8(env.name, num_augment=num_augment)
            if num_augment == 8
            else StateAugmentationN(env.name, num_augment=num_augment)
        )

    def _inner(self, td, num_augment=None):
        """Return the best ``(actions, rewards)`` over all augmentation/start combinations.

        ``num_augment`` defaults to the count stored on the augmentation object.
        """
        n_aug = self.augmentation.num_augment if num_augment is None else num_augment
        n_total = self.num_starts * n_aug

        # Keep the un-augmented batch: rewards are scored against it
        td_init = td.clone()

        td_aug = self.augmentation(td, num_augment=n_aug)
        actions = self.policy(
            td_aug.clone(),
            decode_type="greedy_multistart",
            num_starts=self.num_starts,
            return_actions=True,
        )['actions']

        # Repeat the original instances once per (augmentation, start) pair
        td_rep = batchify(td_init, (n_aug, self.num_starts))

        rewards = unbatchify(self.env.get_reward(td_rep, actions), n_total)
        actions = unbatchify(actions, n_total)

        # Highest reward along dim 1, and the actions that achieved it
        best_rewards, best_idxs = rewards.max(dim=1)
        return gather_by_index(actions, best_idxs, dim=1), best_rewards

In [23]:
# Run the greedy evaluator over the test set
test_dataset = lit_module.test_dataset
dataloader = lit_module._dataloader(test_dataset, batch_size=2048)

# NOTE(review): `num_augment` and `augmentation` are set here but not passed to GreedyEval below
# POMO
num_augment = 8
augmentation = StateAugmentation8(env.name, num_augment=num_augment)

# SymNCO
# num_augment = 8
# augmentation = StateAugmentationN(env.name, num_augment=num_augment)


greedy_eval = GreedyEval(env, policy)

# Result dict: actions, rewards, inference_time, name, avg_reward
ret_vals_greedy = greedy_eval(dataloader)

Mean reward for greedy: -5.7785
Time: 0.6495s


In [25]:
# Run the combined multi-start + augmentation evaluator over the test set
test_dataset = lit_module.test_dataset
dataloader = lit_module._dataloader(test_dataset, batch_size=16)

# Any value other than 8 selects StateAugmentationN inside the evaluator
num_augment = 50

# Note: the names `greedy_eval` / `ret_vals_greedy` are reused here for a different evaluator
greedy_eval = GreedyMultiStartAugmentEval(env, policy, num_starts=env.num_loc, num_augment=num_augment)

ret_vals_greedy = greedy_eval(dataloader)

Mean reward for greedy_multistart_augment: -5.7022
Time: 115.0232s


In [None]:
# # POMO 8
# 5.7184

# # 8
# 5.7242
# 5.7244

# # 20
# 5.7120
# 5.7118

# # 50
# 5.7050
# 5.7050

In [None]:
# Shape of the global `actions` tensor, i.e. whatever an earlier cell last assigned to it
actions.shape

torch.Size([10000, 50])

In [None]:
# actions.gather(dim=1, index=max_idxs[:, None])
