# SymNCO Lightning

In [4]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('../../')

import math
from typing import List, Tuple, Optional, NamedTuple, Dict, Union, Any
from einops import rearrange, repeat
from hydra.utils import instantiate

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.nn import DataParallel
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import lightning as L

from torchrl.envs import EnvBase
from torchrl.envs.utils import step_mdp
from tensordict import TensorDict

from ncobench.envs.tsp import TSPEnv
from ncobench.models.rl.reinforce import *
from ncobench.models.co.am.context import env_context
from ncobench.models.co.am.embeddings import env_init_embedding, env_dynamic_embedding
from ncobench.models.co.am.encoder import GraphAttentionEncoder
from ncobench.models.co.am.decoder import Decoder, decode_probs, PrecomputedCache, LogitAttention
from ncobench.models.co.am.policy import get_log_likelihood
from ncobench.models.nn.attention import NativeFlashMHA, flash_attn_wrapper
from ncobench.utils.lightning import get_lightning_device

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Novelty compared to `POMO`

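During training, POMO only exploits *solution* symmetricity (the same tour can be decoded starting from any node). SymNCO additionally exploits *problem* symmetricity (rotating or reflecting an instance does not change its optimal solution) and trains with the total loss: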
$$\mathcal{L}_{total} = \mathcal{L}_{ps} + \beta \mathcal{L}_{ss} + \alpha \mathcal{L}_{inv}$$
where $\mathcal{L}_{ps}$ is the problem symmetricity loss, $\mathcal{L}_{ss}$ is the solution symmetricity loss, and $\mathcal{L}_{inv}$ is the invariant representation loss. $\beta$ and $\alpha$ are hyperparameters that weight the solution symmetricity loss and the invariant representation loss, respectively. A projection head (MLP) is introduced to process the embeddings and calculate $\mathcal{L}_{inv}$.

In [5]:
# For easier debugging: rich-formatted tracebacks.
# The trailing `;` suppresses the echo of install()'s return value (the previous excepthook).

from rich.traceback import install
install();

<bound method InteractiveShell.excepthook of <ipykernel.zmqshell.ZMQInteractiveShell object at 0x7f189dd97310>>

## Utilities: action selection, batching


In [6]:
# @torch.compile
def select_start_nodes(batch_size, num_nodes, device="cpu"):
    """Forced first action of every POMO rollout.

    Returns a long tensor with ``num_nodes * batch_size`` entries laid out
    repeat-major: the first ``batch_size`` entries are node 0, the next
    ``batch_size`` entries are node 1, and so on. This matches the ``(r b)``
    batch layout produced by ``repeat_batch``, so repeat ``r`` of every
    instance starts from node ``r``.
    """
    start_ids = torch.arange(num_nodes, device=device)
    return start_ids.repeat_interleave(batch_size, dim=0)


# @torch.compile
def repeat_batch(x, repeats):
    """Tile ``x`` ``repeats`` times along dim 0 (works for tensors and TensorDicts).

    Equivalent to ``einops.repeat(x, 'b ... -> (r b) ...', r=repeats)``: row
    ``r * b_size + i`` of the result is a copy of row ``i`` of ``x``.
    """
    batch, *rest = x.shape
    tiled = x.expand(repeats, batch, *rest).contiguous()
    return tiled.view(repeats * batch, *rest)


# @torch.compile
def undo_repeat_batch(x, repeats, dim=0):
    """Undo ``repeat_batch``: split a repeat-major dimension and move the repeats to the front.

    Equivalent to ``einops.rearrange(x, '... (r b) ... -> r b ...', r=repeats)`` where the
    ``(r b)`` axis is ``dim``; the remaining dimensions keep their relative order.

    Args:
        x: tensor (or TensorDict for ``dim=0``) whose ``dim`` has size ``repeats * b``.
        repeats: number of repeats ``r`` stored along ``dim``.
        dim: dimension holding the ``(r b)`` layout; negative values count from the end.

    Returns:
        View/tensor of shape ``(repeats, b, *other_dims)``.
    """
    s = x.shape
    dim = dim % len(s)
    if dim == 0:
        # Fast path: the repeats are already the outermost dimension, a view suffices
        return x.view(repeats, s[0] // repeats, *s[1:])
    # For an inner dim a plain view would reinterpret memory in the wrong order:
    # split the dim in place first, then move (r, b) to the front.
    split = x.reshape(*s[:dim], repeats, s[dim] // repeats, *s[dim + 1:])
    return split.movedim((dim, dim + 1), (0, 1))

# note that repeat is the first dimension!

In [4]:
# x.view(repeats, s[0] // repeats, *s[1:])
# same but with s[i] and [s[k] for k in len(s) if k != i]]]

In [7]:
from dataclasses import dataclass


@dataclass
class PrecomputedCache:
    """Attention inputs built once by ``Decoder._precompute`` and reused at every decoding step.

    In ``Decoder._precompute`` every field is repeated ``num_starts`` times along dim 0
    via ``repeat_batch``.
    """
    node_embeddings: torch.Tensor  # encoder node embeddings
    glimpse_key: torch.Tensor  # glimpse keys, passed through LogitAttention._make_heads
    glimpse_val: torch.Tensor  # glimpse values, passed through LogitAttention._make_heads
    logit_key: torch.Tensor  # keys for the final logits (not passed through _make_heads)


class Decoder(nn.Module):
    """Autoregressive decoder with POMO-style multi-start rollouts.

    When ``num_starts > 1`` the batch is repeated ``num_starts`` times and the first
    action of repeat ``r`` is forced to node ``r`` (see ``select_start_nodes``); the
    remaining actions are chosen by attention over cached node-embedding projections.
    """

    def __init__(self, env, embedding_dim, num_heads, num_starts=20, **logit_attn_kwargs):
        super(Decoder, self).__init__()

        self.env = env
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads

        assert embedding_dim % num_heads == 0

        self.context = env_context(self.env.name, {"embedding_dim": embedding_dim})
        self.dynamic_embedding = env_dynamic_embedding(
            self.env.name, {"embedding_dim": embedding_dim}
        )

        # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
        self.project_node_embeddings = nn.Linear(
            embedding_dim, 3 * embedding_dim, bias=False
        )
        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=False)

        # MHA
        self.logit_attention = LogitAttention(
            embedding_dim, num_heads, **logit_attn_kwargs
        )

        # POMO
        self.num_starts = max(num_starts, 1) # POMO = 1 is just normal REINFORCE

    def forward(self, td, embeddings, decode_type="sampling"):
        """Roll out the policy until the environment reports ``done``.

        Args:
            td: environment TensorDict for the un-repeated batch.
            embeddings: encoder node embeddings for the un-repeated batch.
            decode_type: forwarded to ``decode_probs`` (e.g. ``"sampling"``).

        Returns:
            ``(log_p, actions, td)``: per-step log-probabilities and actions stacked on
            dim 1, and the final TensorDict with ``"reward"`` set.
        """
        # Collect outputs
        outputs = []
        actions = []

        if self.num_starts > 1:
            # POMO: first action is decided via select_start_nodes
            action = select_start_nodes(batch_size=td.shape[0], num_nodes=self.num_starts, device=td.device)

            # Expand td to batch_size * num_starts
            td = repeat_batch(td, self.num_starts)

            td.set("action", action[:, None])
            td = self.env.step(td)["next"]
            # The forced first action has probability 1 (log_p = 0). Use a float dtype:
            # zeros_like on the boolean action mask would otherwise produce a bool tensor.
            log_p = torch.zeros_like(td["action_mask"], dtype=torch.float32)

            outputs.append(log_p.squeeze(1))
            actions.append(action)

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        cached_embeds = self._precompute(embeddings)

        # Here we suppose all the batch is done at the same time
        while not td["done"].any():
            # Compute the logits for the next node
            log_p, mask = self._get_log_p(cached_embeds, td)

            # Select the indices of the next nodes in the sequences, result (batch_size) long
            action = decode_probs(
                log_p.exp().squeeze(1), mask.squeeze(1), decode_type=decode_type
            )

            # Step the environment
            td.set("action", action[:, None])
            td = self.env.step(td)["next"]

            # Collect output of step
            outputs.append(log_p.squeeze(1))
            actions.append(action)

        outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1)
        td.set("reward", self.env.get_reward(td, actions))
        return outputs, actions, td

    def _precompute(self, embeddings):
        """Project node embeddings once and repeat them for every start."""
        # The projection of the node embeddings for the attention is calculated once up front
        (
            glimpse_key_fixed,
            glimpse_val_fixed,
            logit_key_fixed,
        ) = self.project_node_embeddings(embeddings[:, None, :, :]).chunk(3, dim=-1)

        # Organize in a dataclass for easy access
        cached_embeds = PrecomputedCache(
            node_embeddings=repeat_batch(embeddings, self.num_starts),
            glimpse_key=repeat_batch(self.logit_attention._make_heads(glimpse_key_fixed), self.num_starts),
            glimpse_val=repeat_batch(self.logit_attention._make_heads(glimpse_val_fixed), self.num_starts),
            logit_key=repeat_batch(logit_key_fixed, self.num_starts)
        )

        return cached_embeds

    def _get_log_p(self, cached, td):
        """Compute next-node log-probabilities and the infeasibility mask for the current state."""
        # Compute the query based on the context (computes automatically the first and last node context)
        step_context = self.context(cached.node_embeddings, td)
        query = step_context # in POMO, no graph context (trick for overfit) # [batch, 1, embed_dim]

        # Add the state-dependent parts to the cached keys and values
        glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.dynamic_embedding(td)
        glimpse_key = cached.glimpse_key + glimpse_key_dynamic
        glimpse_val = cached.glimpse_val + glimpse_val_dynamic
        logit_key = cached.logit_key + logit_key_dynamic

        # Mask is True for nodes that may NOT be selected
        mask = ~td["action_mask"]
        mask = mask.unsqueeze(1) if mask.dim() == 2 else mask

        # Compute logits: glimpse keys and glimpse values must be passed separately
        log_p = self.logit_attention(query, glimpse_key, glimpse_val, logit_key, mask)

        return log_p, mask

In [8]:
from torchrl.modules.models import MLP

embedding_dim = 128
# One hidden layer of `embedding_dim` ReLU units: the same MLP that
# SymNCOPolicy builds as its default projection head
a = MLP(
    in_features=embedding_dim,
    out_features=embedding_dim,
    depth=1,
    num_cells=embedding_dim,
    activation_class=nn.ReLU,
)
print(a)

MLP(
  (0): Linear(in_features=128, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=128, bias=True)
)


In [9]:
from torchrl.modules.models import MLP


class SymNCOPolicy(nn.Module):
    """AM/POMO-style policy with an additional SymNCO projection head.

    The instance is embedded, encoded with a graph attention encoder and rolled out by
    the decoder (POMO multi-start when ``num_starts > 1``). Compared to AM/POMO, the
    only addition is ``projection_head``: an MLP applied to the encoder output whose
    result is returned as ``proj_embeddings`` for the invariant-representation loss.
    """

    def __init__(self,
                 env: EnvBase,
                 embedding_dim: int,
                 hidden_dim: int,
                 encoder: Optional[nn.Module] = None,
                 decoder: Optional[nn.Module] = None,
                 projection_head: Optional[nn.Module] = None,
                 num_starts: int = 10,
                 num_encode_layers: int = 3,
                 normalization: str = 'batch',
                 num_heads: int = 8,
                 checkpoint_encoder: bool = False,
                 mask_inner: bool = True,
                 force_flash_attn: bool = False,
                 **kwargs
                 ):
        """
        Args:
            env: environment; ``env.name`` selects the init embedding and decoder context.
            embedding_dim: size of the node embeddings.
            hidden_dim: stored as ``self.hidden_dim`` (not used by the default modules).
            encoder: custom encoder; a ``GraphAttentionEncoder`` is built when None.
            decoder: custom decoder; a ``Decoder`` is built when None.
            projection_head: custom projection head; a one-hidden-layer ReLU MLP when None.
            num_starts: number of POMO starts (the default Decoder treats values < 1 as 1).
            num_encode_layers, normalization, num_heads: encoder settings.
            checkpoint_encoder: stored as ``self.checkpoint_encoder``.
            mask_inner, force_flash_attn: forwarded to the default decoder/encoder.
        """
        super(SymNCOPolicy, self).__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_encode_layers = num_encode_layers
        self.env = env

        self.num_heads = num_heads
        self.checkpoint_encoder = checkpoint_encoder
        self.num_starts = num_starts

        self.init_embedding = env_init_embedding(self.env.name, {"embedding_dim": embedding_dim})

        self.encoder = GraphAttentionEncoder(
            num_heads=num_heads,
            embed_dim=embedding_dim,
            num_layers=self.num_encode_layers,
            normalization=normalization,
            force_flash_attn=force_flash_attn,
        ) if encoder is None else encoder

        self.decoder = Decoder(env, embedding_dim, num_heads, num_starts=num_starts, mask_inner=mask_inner, force_flash_attn=force_flash_attn) if decoder is None else decoder
        self.projection_head = MLP(embedding_dim, embedding_dim, 1, embedding_dim, nn.ReLU) if projection_head is None else projection_head

    def forward(self, td: TensorDict, phase: str = "train", decode_type: str = "sampling", return_actions: bool = False) -> Dict[str, Any]:
        """Embed and encode the instance, then roll out the decoder.

        Args:
            td: environment TensorDict (after ``env.reset``).
            phase: accepted for API compatibility; unused.
            decode_type: decoding strategy forwarded to the decoder (e.g. ``"sampling"``).
            return_actions: whether to include the action sequences in the output.

        Returns:
            Dict with ``"reward"``, ``"log_likelihood"``, ``"proj_embeddings"`` and
            ``"actions"`` (``None`` unless ``return_actions``).
        """
        embeddings = self.init_embedding(td)
        encoded_inputs, _ = self.encoder(embeddings)
        # SymNCO imposes invariance on the projected *encoder* output so that the invariance
        # loss shapes the encoder's representation; projecting the raw init embeddings
        # would bypass the encoder entirely.
        proj_embeddings = self.projection_head(encoded_inputs)

        # Main rollout
        _log_p, actions, td = self.decoder(td, encoded_inputs, decode_type)

        # Log likelihood is calculated within the model rather than returned per action
        ll = get_log_likelihood(_log_p, actions, td.get('mask', None))
        out = {"reward": td["reward"], "log_likelihood": ll, "proj_embeddings": proj_embeddings, "actions": actions if return_actions else None}

        return out

## Test the Policy only

In [10]:
num_loc = 15
# Run the smoke test on the GPU when one is available, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

env = TSPEnv(num_loc=num_loc).transform()

dataset = env.dataset(batch_size=[10000])

dataloader = DataLoader(
                dataset,
                batch_size=32,
                shuffle=False, # no need to shuffle, we're resampling every epoch
                num_workers=0,
                collate_fn=torch.stack, # we need this to stack the batches in the dataset
            )

model = SymNCOPolicy(
    env,
    embedding_dim=128,
    hidden_dim=128,
    num_encode_layers=3,
    num_starts=num_loc,
    # force_flash_attn=True,
).to(device)

# model = torch.compile(model)

batch = next(iter(dataloader)).to(device)
x = env.reset(init_obs=batch)

out = model(x, decode_type="sampling")

## Create full model: `env` + `policy` + `baseline`

In [34]:
def SR_transform(x, y, idx):
    """Reference symmetric (rotation / reflection) transform from the original SymNCO code.

    Coordinates are transformed around the center (0.5, 0.5).

    Args:
        x, y: coordinate tensors with a trailing singleton dim; they are concatenated
            on dim 2, so presumably shaped (batch, n, 1).
        idx: scalar in [0, 1), as a Python float or 0-dim tensor. ``idx < 0.5`` rotates
            by ``idx * 4π``; ``idx >= 0.5`` rotates by ``(idx - 0.5) * 4π`` and then swaps
            the x and y axes (a reflection).

    Returns:
        Tensor of transformed coordinates with the x/y pair on dim 2.
    """
    if idx < 0.5:
        phi = idx * 4 * math.pi
    else:
        phi = (idx - 0.5) * 4 * math.pi
    # torch.cos / torch.sin need a tensor; this also accepts a plain Python float idx
    phi = torch.as_tensor(phi)

    x = x - 1 / 2
    y = y - 1 / 2

    x_prime = torch.cos(phi) * x - torch.sin(phi) * y
    y_prime = torch.sin(phi) * x + torch.cos(phi) * y

    if idx < 0.5:
        dat = torch.cat((x_prime + 1 / 2, y_prime + 1 / 2), dim=2)
    else:
        dat = torch.cat((y_prime + 1 / 2, x_prime + 1 / 2), dim=2)
    return dat


def augment_xy_data_by_N_fold(problems, N, depot=None):
    """Reference N-fold augmentation: keep the original batch and append N - 1 transformed copies.

    Args:
        problems: node coordinates, indexed as (batch, n, 2).
        N: total number of copies (the original counts as one).
        depot: optional depot coordinates, indexed as (batch, 1, 2); transformed with the
            same random ``idx`` as the nodes of each copy.

    Returns:
        ``problems`` stacked to (N * batch, n, 2), and when ``depot`` is given also the
        depots flattened to (N * batch, 2).
    """
    x = problems[:, :, [0]]
    y = problems[:, :, [1]]

    if depot is not None:
        x_depot = depot[:, :, [0]]
        y_depot = depot[:, :, [1]]
    idx = torch.rand(N - 1)

    # Collect all copies first and concatenate once (repeated cat in the loop is quadratic)
    problem_chunks = [problems]
    depot_chunks = [depot]
    for i in range(N - 1):
        problem_chunks.append(SR_transform(x, y, idx[i]))
        if depot is not None:
            depot_chunks.append(SR_transform(x_depot, y_depot, idx[i]))
    problems = torch.cat(problem_chunks, dim=0)

    if depot is not None:
        return problems, torch.cat(depot_chunks, dim=0).view(-1, 2)

    return problems


def augment(input, N, problem):
    """Reference N-fold symmetric augmentation, dispatched on ``problem.NAME``.

    For 'cvrp'/'sdvrp', 'op' and 'pctsp', ``input`` is a dict: its 'loc' and 'depot'
    coordinates are augmented together, the problem-specific features are tiled N times,
    and the dict is modified in place and returned. For any other problem (e.g. TSP)
    ``input`` is the coordinate tensor itself and the augmented tensor is returned.
    """
    tiled_features = {
        'cvrp': ('demand',),
        'sdvrp': ('demand',),
        'op': ('prize', 'max_length'),
        'pctsp': ('deterministic_prize', 'penalty'),
    }.get(problem.NAME)

    if tiled_features is None:
        # TSP: the input is just the coordinates
        return augment_xy_data_by_N_fold(input, N)

    input['loc'], input['depot'] = augment_xy_data_by_N_fold(
        input['loc'], N, depot=input['depot'].view(-1, 1, 2)
    )
    for feat in tiled_features:
        input[feat] = input[feat].repeat(N, 1)
    if problem.NAME == 'op':
        input['max_length'] = input['max_length'].view(-1)
    return input




# Test the above: a random batch of 10 instances with 10 nodes each
# (b would be the depot input for the VRP-style branch: {'loc': a, 'depot': b})
a, b = torch.rand(10, 10, 2), torch.rand(10, 10, 2)

@dataclass
class Problem:
    """Minimal stand-in for a problem object: ``augment`` only reads ``NAME``."""

    NAME: str = "tsp"

problem = Problem()
z = augment(a, 10, problem)
# The 10-fold augmentation stacks the original batch and 9 transformed copies along dim 0
assert z.shape == (10 * a.shape[0], *a.shape[1:])
print(z.mean(), z.std(), z.min(), z.max())
print(z.shape)




tensor(0.4930) tensor(0.2905) tensor(-0.1644) tensor(1.1442)
torch.Size([100, 10, 2])


In [35]:
# Time the loop-based reference augmentation (10 copies of the 10-instance batch `a`)
%timeit augment(a, 10, problem)

674 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [36]:
def env_aug_feats(env_name: str) -> Tuple[str, ...]:
    """Names of the coordinate features that StateAugmentation transforms for ``env_name``.

    NOTE(review): only "op" also augments 'depot' here, whereas the reference ``augment``
    in this notebook also augments the depot for cvrp/sdvrp/pctsp. Confirm against those
    envs' TensorDict keys before using them.
    """
    if env_name == "op":
        return ('observation', 'depot')
    return ('observation',)


def rotation_reflection_transform(x, y, phi, offset=0.5):
    """Rotate, and for half of the parameter range also reflect, points around (offset, offset).

    ``phi`` in [0, 2π) encodes the whole symmetry, like ``idx * 2π`` in the reference
    ``SR_transform``:
    - For phi in [0, π) the points are rotated by ``2 * phi``.
    - For phi in [π, 2π) they are rotated by ``2 * phi`` and then reflected by swapping
      the x and y axes.

    Drawing phi uniformly from [0, 2π) therefore yields rotations and reflections with
    equal probability, each with a uniform angle.

    Args:
        x, y: coordinate tensors with a trailing singleton dim, e.g. (batch, n, 1).
        phi: Python float or tensor broadcastable against ``x``, e.g. (batch, 1, 1).
        offset: center of the transform.

    Returns:
        Tensor with the transformed (x, y) pair concatenated on the last dim.
    """
    phi = torch.as_tensor(phi, dtype=x.dtype, device=x.device)
    x, y = x - offset, y - offset

    theta = 2 * phi
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    x_prime = cos_t * x - sin_t * y
    y_prime = sin_t * x + cos_t * y

    # Reflection = swap axes. Both outputs must read the *unswapped* rotated coordinates;
    # overwriting x_prime before computing y_prime would collapse both onto y_prime.
    reflect = phi >= math.pi
    x_out = torch.where(reflect, y_prime, x_prime)
    y_out = torch.where(reflect, x_prime, y_prime)
    return torch.cat((x_out + offset, y_out + offset), dim=-1)


def _sample_symmetry_params(num_rows, num_augment, device=None):
    """Draw one symmetry parameter phi in [0, 2π) per augmented instance (row).

    Rows are assumed to be laid out (augment, batch) as produced by ``repeat_batch``, so
    the first ``num_rows // num_augment`` rows are the original instances. Their phi is set
    to 0 (identity) so the un-augmented problem is always kept, as in the reference
    ``augment_xy_data_by_N_fold``.
    """
    phi = torch.rand(num_rows, device=device) * 2 * math.pi
    phi[: num_rows // max(num_augment, 1)] = 0.0
    return phi


def augment_xy_data_by_n_fold(x, y, num_augment: int = 8, phi=None):
    """Apply a random rotation/reflection to every instance (row) of ``x``, ``y``.

    Args:
        x, y: coordinates with the augmented instances on dim 0, e.g. (num_augment * batch, n, 1).
        num_augment: number of augmented copies stacked on dim 0.
        phi: optional per-row symmetry parameters of shape (x.shape[0],); sampled when None.
            Pass the same phi for all coordinate features of an instance to keep them consistent.

    Returns:
        Augmented coordinates with the (x, y) pair on the last dim.
    """
    if phi is None:
        phi = _sample_symmetry_params(x.shape[0], num_augment, device=x.device)
    # One phi per instance, broadcast over all of its nodes so each instance is
    # transformed rigidly (rather than rotating every node differently)
    phi = phi.view(-1, *([1] * (x.dim() - 1)))
    return rotation_reflection_transform(x, y, phi)


class StateAugmentation(nn.Module):
    def __init__(self, env_name, num_augment: int = 8):
        """Augment state by N times via symmetric rotation transform"""
        super(StateAugmentation, self).__init__()
        self.num_augment = num_augment
        self.augmentation = augment_xy_data_by_n_fold
        self.feats = env_aug_feats(env_name)

    def forward(self, td: TensorDict) -> TensorDict:
        """Repeat ``td`` ``num_augment`` times and transform its coordinate features."""
        td_aug = repeat_batch(td, self.num_augment)
        # Sample the symmetries once and share them across all coordinate features
        # (e.g. nodes and depot) so every instance is transformed consistently
        phi = None
        for feat in self.feats:
            coords = td_aug[feat]
            if phi is None:
                phi = _sample_symmetry_params(coords.shape[0], self.num_augment, device=coords.device)
            td_aug[feat] = self.augmentation(coords[..., [0]], coords[..., [1]], self.num_augment, phi=phi)
        return td_aug
    

augmentation = StateAugmentation("tsp", num_augment=10)

# 10 random TSP instances with 10 nodes each
a = TensorDict({'observation': torch.rand(10, 10, 2)}, batch_size=10)

z = augmentation(a)['observation']
print(
    z.mean(),
    z.std(),
    z.min(),
    z.max(),
)
print(z.shape)

tensor(0.5121) tensor(0.2761) tensor(-0.1003) tensor(1.0442)
torch.Size([100, 10, 2])


In [10]:
import torch.nn.functional as F

from ncobench.models.rl.reinforce import SharedBaseline


def get_best_actions(actions, max_idxs):
    """Select, for every instance, the actions of its best rollout.

    Args:
        actions: tensor of shape (repeats * batch, ...) laid out repeat-major as produced by
            ``repeat_batch`` (row ``r * batch + b`` is repeat ``r`` of instance ``b``).
        max_idxs: long tensor of shape (batch,) with the winning repeat per instance,
            e.g. the indices returned by ``reward.max(dim=0)``.

    Returns:
        Tensor of shape (batch, ...) holding the selected rollouts' actions.
    """
    batch_size = max_idxs.shape[0]
    # The repeat count is implied by the actions; max_idxs has one entry per instance
    repeats = actions.shape[0] // batch_size
    # Same split as undo_repeat_batch(actions, repeats): (repeats, batch, ...)
    per_repeat = actions.reshape(repeats, batch_size, *actions.shape[1:])
    # gather needs an index with the same rank, broadcast over the trailing dims
    index = max_idxs.reshape(1, batch_size, *([1] * (per_repeat.dim() - 2)))
    index = index.expand(1, *per_repeat.shape[1:])
    return per_repeat.gather(0, index).squeeze(0)


class SymNCO(nn.Module):
    """SymNCO training wrapper: symmetric augmentation plus POMO multi-start REINFORCE.

    The flat batch dimension seen by the policy outputs is laid out as
    (num_starts, num_augment, batch): ``self.augment`` repeats the batch ``num_augment``
    times and the policy's decoder then repeats it ``num_starts`` times.
    """

    def __init__(self, env, policy, baseline=None, num_augment=8, **kwargs):
        super().__init__()
        self.env = env
        self.policy = policy
        self.baseline = SharedBaseline() if baseline is None else baseline
        if not isinstance(self.baseline, SharedBaseline):
            print("Baseline is not SharedBaseline, but SymNCO is designed to use the shared (POMO-style) baseline")

        # POMO parameters (the default Decoder treats num_starts < 1 as a single rollout)
        self.num_starts = max(policy.num_starts, 1)
        self.num_augment = num_augment
        assert num_augment > 0, "Number of augmentations must be greater than 0 for SymNCO"
        self.augment = StateAugmentation(env.name, num_augment)

    @property
    def n_augment(self):
        """Alias of ``num_augment`` (NCOLightningModule reads ``model.n_augment``)."""
        return self.num_augment

    def forward(self, td: TensorDict, phase: str="train", decode_type: str="sampling", return_actions: bool=False) -> dict:
        """Evaluate the policy on augmented instances; in training also compute the loss.

        Returns the policy output extended with ``loss``/``reinforce_loss`` (train only),
        ``max_reward``/``best_actions`` (when num_starts > 1) and
        ``max_aug_reward``/``best_aug_actions`` (when not training).
        """
        # Augmentation is needed both for training (problem symmetricity) and for
        # evaluation (augmentation score), so apply it in every phase
        if self.num_augment > 1:
            td = self.augment(td)

        # Evaluate model, get costs and log probabilities
        out = self.policy(td, decode_type=decode_type, return_actions=return_actions)

        train_retvals = {}
        if phase == "train":
            # [(num_starts * num_augment), batch]: all rollouts of the same original
            # instance (over starts and augmentations) share one baseline
            num_samples = self.num_starts * self.num_augment
            costs = undo_repeat_batch(-out['reward'], num_samples)
            ll = undo_repeat_batch(out['log_likelihood'], num_samples)

            bl_val, bl_loss = self.baseline.eval(td, costs, on_dim=0)

            # REINFORCE with shared baseline
            advantage = costs - bl_val
            reinforce_loss = (advantage * ll).mean()
            # TODO: add beta * L_ss and alpha * L_inv (from out["proj_embeddings"])
            train_retvals = {"loss": reinforce_loss + bl_loss, "reinforce_loss": reinforce_loss}

        # Multi-start rollout as in POMO: best start per (augmented) instance
        max_reward, best_actions = out["reward"], out["actions"]
        pomo_retvals = {}
        if self.num_starts > 1:
            # [num_starts, num_augment * batch]
            reward = undo_repeat_batch(out["reward"], self.num_starts)
            max_reward, max_idxs = reward.max(dim=0)
            best_actions = get_best_actions(out["actions"], max_idxs) if return_actions else None
            pomo_retvals = {"max_reward": max_reward, "best_actions": best_actions}

        # Augmentation score (best over the augmented copies) only during inference
        aug_retvals = {}
        if phase != "train":
            # [num_augment, batch]
            aug_reward = undo_repeat_batch(max_reward, self.num_augment)
            max_aug_reward, max_aug_idxs = aug_reward.max(dim=0)
            best_aug_actions = get_best_actions(best_actions, max_aug_idxs) if return_actions else None
            aug_retvals = {"max_aug_reward": max_aug_reward, "best_aug_actions": best_aug_actions}

        return {**out, **train_retvals, **pomo_retvals, **aug_retvals}

    def setup(self, lit_module):
        # Make baseline taking model itself and train_dataloader from model as input
        if hasattr(self.baseline, "setup"):
            self.baseline.setup(self.policy, lit_module.train_dataloader(), self.env, device=get_lightning_device(lit_module))

    def on_train_epoch_end(self, lit_module):
        self.baseline.epoch_callback(self.policy, lit_module.val_dataloader(), lit_module.current_epoch, self.env, device=get_lightning_device(lit_module))

## Lightning Module

In [11]:
class NCOLightningModule(L.LightningModule):
    """Lightning wrapper around an RL model for neural combinatorial optimization.

    Builds train/val datasets from ``env`` (the train set is regenerated every epoch),
    delegates the forward pass and loss to ``model`` and logs costs (negative rewards).
    """

    def __init__(self, env, model, lr=1e-4, batch_size=128, train_size=1000, val_size=10000):
        super().__init__()

        # TODO: hydra instantiation
        self.env = env
        self.model = model
        self.lr = lr
        self.batch_size = batch_size
        self.train_size = train_size
        self.val_size = val_size

    def setup(self, stage="fit"):
        self.train_dataset = self.env.dataset(self.train_size)
        self.val_dataset = self.env.dataset(self.val_size)
        if hasattr(self.model, "setup"):
            self.model.setup(self)

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        td = self.env.reset(init_obs=batch)
        output = self.model(td, phase)

        self.log(f"{phase}/cost", -output["reward"].mean(), prog_bar=True)
        # Multi-start / augmentation metrics only exist when the model produced them
        if "max_reward" in output:
            self.log(f"{phase}/pomo_cost", -output["max_reward"].mean(), prog_bar=True)
        num_augment = getattr(self.model, "n_augment", getattr(self.model, "num_augment", 1))
        if phase != "train" and num_augment > 1 and "max_aug_reward" in output:
            self.log(f"{phase}/aug_cost", -output["max_aug_reward"].mean(), prog_bar=True)

        # Training requires a loss; models are not required to compute one for val/test
        loss = output["loss"] if phase == "train" else output.get("loss")
        return {"loss": loss}

    def training_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='train')

    def validation_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='val')

    def test_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='test')

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-6)
        # optim = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
        # TODO: scheduler
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, total_steps)
        return [optim] #, [scheduler]

    def train_dataloader(self):
        return self._dataloader(self.train_dataset)

    def val_dataloader(self):
        return self._dataloader(self.val_dataset)

    def on_train_epoch_end(self):
        if hasattr(self.model, "on_train_epoch_end"):
            self.model.on_train_epoch_end(self)
        # Fresh training instances every epoch
        self.train_dataset = self.env.dataset(self.train_size)

    def _dataloader(self, dataset):
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False, # no need to shuffle, we're resampling every epoch
            num_workers=0,
            collate_fn=torch.stack, # we need this to stack the batches in the dataset
            pin_memory=self.on_gpu,
        )

## Main training setup

In [12]:
# Hyperparameters
epochs = 1
batch_size = 64 #1024 #512
n_loc = 20
train_size = 1280000
lr = 1e-4
# POMO: one rollout per start node, so tie the number of starts to this cell's problem
# size (not to `num_loc` from the smoke-test cell above). Set to 1 for a plain AM rollout.
n_starts = n_loc
# num_pomo = 1 # set to 1: similar to simple AM

# Environment
env = TSPEnv(num_loc=n_loc).transform()

# Policy
policy = SymNCOPolicy(
    env,
    num_starts=n_starts,
    embedding_dim=128,
    hidden_dim=128,
    num_encode_layers=3,
    # force_flash_attn=True,
)

# Baseline
# baseline = WarmupBaseline(RolloutBaseline())
baseline = SharedBaseline() # TODO: uncomment

# Create RL model
# NOTE(review): this wraps the policy in `POMO` (presumably from the
# `ncobench.models.rl.reinforce` star import), not in the `SymNCO` class defined above
model = POMO(env, policy, baseline)

# Create Lightning module (for training)
lit_model = NCOLightningModule(env, model, batch_size=batch_size, train_size=train_size, lr=lr)

## Fit model

In [13]:
# Trick to make calculations faster
torch.set_float32_matmul_precision("medium")

# Use GPU index 1 when several GPUs exist (the original setup); otherwise fall back to a
# single GPU or the CPU so the notebook also runs on smaller machines
if torch.cuda.is_available():
    accelerator, devices = "gpu", ([1] if torch.cuda.device_count() > 1 else 1)
else:
    accelerator, devices = "cpu", 1

# Trainer
trainer = L.Trainer(
    max_epochs=epochs,
    accelerator=accelerator,
    devices=devices,
    logger=None, # can replace with WandbLogger, TensorBoardLogger, etc.
    # precision=16, # uncomment to make faster
    log_every_n_steps=100,
    gradient_clip_val=1.0, # clip gradients to avoid exploding gradients!
)

# Fit the model
trainer.fit(lit_model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
# Evaluate the trained module on its validation dataloader
trainer.validate(model=lit_model)