# SymNCO Lightning

In [1]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('../../')

import math
from typing import List, Tuple, Optional, NamedTuple, Dict, Union, Any
from einops import rearrange, repeat
from hydra.utils import instantiate

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.nn import DataParallel
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import lightning as L

from torchrl.envs import EnvBase
from torchrl.envs.utils import step_mdp
from tensordict import TensorDict

from rl4co.envs.tsp import TSPEnv
from rl4co.models.rl.reinforce import *
from rl4co.models.zoo.am.context import env_context
from rl4co.models.zoo.am.embeddings import env_init_embedding, env_dynamic_embedding
from rl4co.models.zoo.am.encoder import GraphAttentionEncoder
from rl4co.models.zoo.am.decoder import Decoder, decode_probs, PrecomputedCache, LogitAttention
from rl4co.models.zoo.am.policy import get_log_likelihood
from rl4co.models.nn.attention import NativeFlashMHA, flash_attn_wrapper
from rl4co.utils.lightning import get_lightning_device
from rl4co.utils.ops import batchify, unbatchify

  warn(


## Novelty compared to `POMO`

Compared to the symmetricities in POMO, SymNCO introduces a new loss function:
$$\mathcal{L}_{total} = \mathcal{L}_{ps} + \beta \mathcal{L}_{ss} + \alpha \mathcal{L}_{inv}$$
where $\mathcal{L}_{ps}$ is the problem symmetricity loss, $\mathcal{L}_{ss}$ is the solution symmetricity loss, and $\mathcal{L}_{inv}$ is the invariant representation loss. The $\beta$ and $\alpha$ are hyperparameters that control the relative importance of the symmetricity and inverse losses. A projection head (MLP) is introduced to process the embeddings and calculate $\mathcal{L}_{inv}$.

In [2]:
# For easier debugging

from rich.traceback import install
install()

<bound method InteractiveShell.excepthook of <ipykernel.zmqshell.ZMQInteractiveShell object at 0x7f751fc4e980>>

## Utilities: action selection, batching


In [3]:
def select_start_nodes(batch_size, num_nodes, device="cpu"):
    """Node selection strategy for POMO
    Selects different start nodes for each batch element
    """
    selected = torch.arange(num_nodes, device=device).repeat(batch_size) # TODO: check
    return selected

In [4]:
# x.view(repeats, s[0] // repeats, *s[1:])
# same but with s[i] and [s[k] for k in len(s) if k != i]]]

In [5]:
from dataclasses import dataclass


@dataclass
class PrecomputedCache:
    node_embeddings: torch.Tensor
    graph_context: torch.Tensor
    glimpse_key: torch.Tensor
    glimpse_val: torch.Tensor
    logit_key: torch.Tensor



class Decoder(nn.Module):
    def __init__(self, env, embedding_dim, num_heads, num_starts=20, use_graph_context=True, **logit_attn_kwargs):
        super(Decoder, self).__init__()

        self.env = env
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads

        assert embedding_dim % num_heads == 0

        self.context = env_context(self.env.name, {"embedding_dim": embedding_dim})
        self.dynamic_embedding = env_dynamic_embedding(
            self.env.name, {"embedding_dim": embedding_dim}
        )

        # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
        self.project_node_embeddings = nn.Linear(
            embedding_dim, 3 * embedding_dim, bias=False
        )
        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=False)

        # MHA
        self.logit_attention = LogitAttention(
            embedding_dim, num_heads, **logit_attn_kwargs
        )

        # POMO
        self.num_starts = max(num_starts, 1) # POMO = 1 is just normal REINFORCE
        self.use_graph_context = use_graph_context # disabling makes it like in POMO

    def forward(self, td, embeddings, decode_type="sampling"):
        # Collect outputs
        outputs = []
        actions = []

        if self.num_starts > 1:
            # POMO: first action is decided via select_start_nodes
            action = select_start_nodes(batch_size=td.shape[0], num_nodes=self.num_starts, device=td.device)

            # # Expand td to batch_size * num_starts
            td = batchify(td, self.num_starts)

            td.set("action", action[:, None])
            td = self.env.step(td)["next"]
            log_p = torch.zeros_like(td['action_mask'], device=td.device) # first log_p is 0, so p = log_p.exp() = 1

            outputs.append(log_p.squeeze(1))
            actions.append(action)
        
        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        cached_embeds = self._precompute(embeddings)        

        while not td["done"].all():  
            # Compute the logits for the next node
            log_p, mask = self._get_log_p(cached_embeds, td)

            # Select the indices of the next nodes in the sequences, result (batch_size) long
            action = decode_probs(
                log_p.exp().squeeze(1), mask.squeeze(1), decode_type=decode_type
            )

            # Step the environment
            td.set("action", action[:, None])
            td = self.env.step(td)["next"]

            # Collect output of step
            outputs.append(log_p.squeeze(1))
            actions.append(action)

        outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1)
        td.set("reward", self.env.get_reward(td, actions))
        return outputs, actions, td
    
    def _precompute(self, embeddings):       
        # The projection of the node embeddings for the attention is calculated once up front
        (
            glimpse_key_fixed,
            glimpse_val_fixed,
            logit_key_fixed,
        ) = self.project_node_embeddings(embeddings[:, None, :, :]).chunk(3, dim=-1)


        # In POMO, no graph context (trick for overfit to single graph size) # [batch, 1, embed_dim]
        graph_context = batchify(self.project_fixed_context(embeddings.mean(1))[:, None, :], self.num_starts) if self.use_graph_context else 0
        
        # Organize in a dataclass for easy access
        cached_embeds = PrecomputedCache(
            node_embeddings=batchify(embeddings, self.num_starts),
            graph_context=graph_context,
            glimpse_key=batchify(self.logit_attention._make_heads(glimpse_key_fixed), self.num_starts),
            glimpse_val=batchify(self.logit_attention._make_heads(glimpse_val_fixed), self.num_starts),
            logit_key=batchify(logit_key_fixed, self.num_starts)
        )
        return cached_embeds

    def _get_log_p(self, cached, td):
        # Compute the query based on the context (computes automatically the first and last node context)
        step_context = self.context(cached.node_embeddings, td)
        query = step_context + cached.graph_context

        # Compute keys and values for the nodes
        glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.dynamic_embedding(td)
        glimpse_key = cached.glimpse_key + glimpse_key_dynamic
        glimpse_key = cached.glimpse_val + glimpse_val_dynamic
        logit_key = cached.logit_key + logit_key_dynamic

        # Get the mask
        mask = ~td["action_mask"]
        mask = mask.unsqueeze(1) if mask.dim() == 2 else mask

        # Compute logits
        log_p = self.logit_attention(query, glimpse_key, glimpse_key, logit_key, mask)

        return log_p, mask

In [6]:
from torchrl.modules.models import MLP


class SymNCOPolicy(nn.Module):

    def __init__(self,
                 env: EnvBase,
                 embedding_dim: int,
                 hidden_dim: int,
                 encoder: nn.Module = None,
                 decoder: nn.Module = None,
                 projection_head: nn.Module = None,
                 num_starts: int = 10,
                 num_encode_layers: int = 3,
                 normalization: str = 'batch',
                 num_heads: int = 8,
                 checkpoint_encoder: bool = False,
                 mask_inner: bool = True,
                 force_flash_attn: bool = False,
                 **kwargs
                 ):
        """
        Differences with AM and POMO: proj_head for the embeddings
        """
        super(SymNCOPolicy, self).__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_encode_layers = num_encode_layers
        self.env = env

        self.num_heads = num_heads
        self.checkpoint_encoder = checkpoint_encoder
        self.num_starts = num_starts

        self.init_embedding = env_init_embedding(self.env.name, {"embedding_dim": embedding_dim})

        self.encoder = GraphAttentionEncoder(
            num_heads=num_heads,
            embed_dim=embedding_dim,
            num_layers=self.num_encode_layers,
            normalization=normalization,
            force_flash_attn=force_flash_attn,
        ) if encoder is None else encoder
        
        self.decoder = Decoder(env, embedding_dim, num_heads, num_starts=num_starts, mask_inner=mask_inner, force_flash_attn=force_flash_attn) if decoder is None else decoder
        self.projection_head = MLP(embedding_dim, embedding_dim, 1, embedding_dim, nn.ReLU) if projection_head is None else projection_head

    def forward(self, td: TensorDict, phase: str = "train", decode_type: str = "sampling", return_actions: bool = False) -> TensorDict:
        """Given observation, precompute embeddings and rollout"""

        # Set decoding type for policy, can be also greedy
        embeddings = self.init_embedding(td)
        proj_embeddings = self.projection_head(embeddings)
        encoded_inputs, _ = self.encoder(embeddings)

        # Main rollout
        log_p, actions, td = self.decoder(td, encoded_inputs, decode_type)

        # Log likelyhood is calculated within the model since returning it per action does not work well with
        ll = get_log_likelihood(log_p, actions, td.get('mask', None))
        out = {"reward": td["reward"], "log_likelihood": ll, "proj_embeddings": proj_embeddings, "actions": actions if return_actions else None}

        return out

## Test the Policy only

In [7]:
num_loc = 15
env = TSPEnv(num_loc=num_loc).transform()

dataset = env.dataset(batch_size=[10000])

dataloader = DataLoader(
                dataset,
                batch_size=32,
                shuffle=False, # no need to shuffle, we're resampling every epoch
                num_workers=0,
                collate_fn=torch.stack, # we need this to stack the batches in the dataset
            )

model = SymNCOPolicy(
    env,
    embedding_dim=128,
    hidden_dim=128,
    num_encode_layers=3,
    num_starts=num_loc,
    # force_flash_attn=True,
).to("cuda")

# model = torch.compile(model)

x = next(iter(dataloader)).to("cuda")
x = env.reset(init_obs=x)

out = model(x, decode_type="sampling")

## Create full model: `env` + `policy` + `baseline`

In [8]:
import torch

# Create a random tensor of shape [64, 20, 2]
x = torch.randn(64, 20, 2)

# Reflect the tensor along the second dimension
reflected = torch.flip(x, dims=[2])
print(reflected.shape)

torch.Size([64, 20, 2])


In [9]:
def env_aug_feats(env_name: str) -> Tuple[str, ...]:
    return ('observation', 'depot') if env_name == "op" else ('observation',)


def rotation_reflection_transform(x, y, phi, offset=0.5):
    """SR group transform with rotation and reflection (~2x faster than original)"""
    x, y = x - offset, y - offset
    # random rotation
    x_prime = torch.cos(phi) * x - torch.sin(phi) * y
    y_prime = torch.sin(phi) * x + torch.cos(phi) * y
    # make random reflection if phi > 2*pi (i.e. 50% of the time)
    mask = phi > 2 * math.pi
    # vectorized random reflection: swap axes x and y if mask
    xy = torch.cat((x_prime, y_prime), dim=-1)
    xy = torch.where(mask, xy.flip(-1), xy)
    return xy + offset


def augment_xy_data_by_n_fold(xy, num_augment: int = 8):
    """Augment xy data by N times via symmetric rotation transform and concatenate to original data"""
    # create random rotation angles (4*pi for reflection, 2*pi for rotation)
    phi = torch.rand(xy.shape[0], device=xy.device) * 4 * math.pi 
    # set phi to 0 for first , i.e. no augmnetation as in original paper
    phi[:xy.shape[0]//num_augment] = 0.0
    x, y = xy[..., [0]], xy[..., [1]]
    return rotation_reflection_transform(x, y, phi[:, None, None])


class StateAugmentation(nn.Module):
    def __init__(self, env_name, num_augment: int = 8):
        """Augment state by N times via symmetric rotation transform"""
        super(StateAugmentation, self).__init__()
        self.num_augment = num_augment
        self.augmentation = augment_xy_data_by_n_fold
        self.feats = env_aug_feats(env_name)

    def forward(self, td: TensorDict) -> TensorDict:
        td_aug = batchify(td, self.num_augment)
        for feat in self.feats:
            aug_feat = self.augmentation(td_aug[feat], self.num_augment)
            td_aug[feat] = aug_feat
        return td_aug
    

augmentation = StateAugmentation("tsp", num_augment=10)

td = TensorDict({'observation': torch.rand(64, 20, 2)}, batch_size=64)

td_aug = augmentation(td)
z = td_aug['observation']
# print(b['observation'].shape)
a_ = td['observation']
print(a_.mean(), a_.std(), a_.min(), a_.max())
print(z.mean(), z.std(), z.min(), z.max())
print(z.shape)

tensor(0.5022) tensor(0.2890) tensor(0.0015) tensor(0.9999)
tensor(0.4991) tensor(0.2890) tensor(-0.1906) tensor(1.1769)
torch.Size([640, 20, 2])


In [10]:
# Sanity check

# randomly sample 20 different integers from 0 to 20 only once
a = torch.randperm(20)[:20]
td_unaug = unbatchify(td_aug, 10)

for idx in range(5):
    rew0 = env.get_reward(td_unaug[:,idx], a[None]).mean()
    rew1 = env.get_reward(td, a[None]).mean()
    print(r"aug {:.3f}".format(rew0.item()), r"unaug {:.3f}".format(rew1.item()))


aug -10.682 unaug -10.682
aug -10.682 unaug -10.682
aug -10.682 unaug -10.682
aug -10.682 unaug -10.682
aug -10.682 unaug -10.682


In [11]:
def SR_transform(x, y, idx):
    if idx < 0.5:
        phi = idx * 4 * math.pi
    else:
        phi = (idx - 0.5) * 4 * math.pi

    x = x - 1 / 2
    y = y - 1 / 2

    x_prime = torch.cos(phi) * x - torch.sin(phi) * y
    y_prime = torch.sin(phi) * x + torch.cos(phi) * y

    if idx < 0.5:
        dat = torch.cat((x_prime + 1 / 2, y_prime + 1 / 2), dim=2)
    else:
        dat = torch.cat((y_prime + 1 / 2, x_prime + 1 / 2), dim=2)
    return dat


def augment_xy_data_by_N_fold(problems, N, depot=None):
    x = problems[:, :, [0]]
    y = problems[:, :, [1]]

    if depot is not None:
        x_depot = depot[:, :, [0]]
        y_depot = depot[:, :, [1]]
    idx = torch.rand(N - 1)

    for i in range(N - 1):

        problems = torch.cat((problems, SR_transform(x, y, idx[i])), dim=0)
        if depot is not None:
            depot = torch.cat((depot, SR_transform(x_depot, y_depot, idx[i])), dim=0)

    if depot is not None:
        return problems, depot

    return problems

In [12]:
x = torch.rand(1024, 100, 2).cuda()
td = TensorDict({'observation': x}, batch_size=1024)

In [13]:
%timeit augment_xy_data_by_N_fold(x, 10)

579 µs ± 1.17 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [14]:
%timeit augmentation(td)

285 µs ± 404 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [12]:
import torch.nn.functional as F
from einops import rearrange

def problem_symmetricity_loss(reward, log_likelihood, dim=1):
    """
    REINFORCE loss for problem symmetricity
    Baseline is the average reward for all augmented problems
    """
    num_augment = reward.shape[dim]
    if num_augment < 2:
        return 0
    advantage = reward - reward.mean(dim=dim, keepdim=True)
    loss = -advantage * log_likelihood
    return loss.mean()


def solution_symmetricity_loss(reward, log_likelihood, dim=2):
    """
    REINFORCE loss for solution symmetricity
    Baseline is the average reward for all start nodes
    """    
    num_starts = reward.shape[dim]
    if num_starts < 2:
        return 0
    advantage = reward - reward.mean(dim=dim, keepdim=True)
    loss = -advantage * log_likelihood
    return loss.mean()


def invariance_loss(proj_embed, num_augment):
    """Loss for invariant representation on projected nodes"""
    pe = rearrange(proj_embed, '(b a) ... -> b a ...', a=num_augment)
    similarity = sum([F.cosine_similarity(pe[:, 0], pe[:, i], dim=-1) for i in range(1, num_augment)])
    return similarity.mean()

In [14]:
from rl4co.models.rl.reinforce import NoBaseline
from rl4co.models.zoo.pomo.utils import get_best_actions

# def get_best_actions(actions, max_idxs):
#     actions = undo_repeat_batch(actions, max_idxs.shape[0])
#     return actions.gather(0, max_idxs[..., None, None])


class SymNCO(nn.Module):
    def __init__(self, env, policy, baseline=None, num_augment=8, alpha=0.2, beta=1, augment_test=True, **kwargs):
        super().__init__()
        self.env = env
        self.policy = policy
        if baseline is not None:
            print("SymNCO uses baselines in the loss functions, so we do not set the baseline here.")
        self.baseline = NoBaseline() # done in loss function

        # Multi-start parameters
        self.num_starts = getattr(policy, "num_starts", 1)
        self.num_augment = num_augment 
        assert num_augment > 1, "Number of augmentations must be greater than 1 for SymNCO"
        self.augment = StateAugmentation(env.name, num_augment)
        self.augment_test = augment_test
        self.alpha = alpha # weight for invariance loss
        self.beta = beta # weight for solution symmetricity loss
    

    def forward(self, td: TensorDict, phase: str="train", **policy_kwargs) -> TensorDict:
        """Evaluate model, get costs and log probabilities and compare with baseline"""

        # Init vals
        loss_retvals, multi_start_retvals, aug_retvals = {}, {}, {}
        return_action = policy_kwargs.get("return_actions", False)

        # Augment data
        if phase == "train" or self.augment_test:
            td = self.augment(td)
            aug_size = self.num_augment # reward to [batch_size, num_augment, num_starts]
        else:
            aug_size = 1

        # Evaluate model, get costs and log probabilities and more
        out = self.policy(td, **policy_kwargs)
        reward = unbatchify(unbatchify(out["reward"], self.num_starts), aug_size)

        if phase == "train":
            # [batch_size, num_augment, num_starts]
            ll = unbatchify(unbatchify(out["log_likelihood"], self.num_starts), aug_size)
            loss_ps = problem_symmetricity_loss(reward, ll)
            loss_ss = solution_symmetricity_loss(reward, ll)
            loss_inv = invariance_loss(out['proj_embeddings'], self.num_augment)
            loss = loss_ps + self.beta * loss_ss + self.alpha * loss_inv
            loss_retvals = {"loss": loss, "loss_ss": loss_ss, "loss_ps": loss_ps, "loss_inv": loss_inv}

        else:
            # Get best actions for multi-start # [batch_size, num_augment, num_starts]
            max_reward, max_idxs = reward.max(dim=2)
            multi_start_retvals = {"max_reward": max_reward, "best_actions": get_best_actions(out["actions"], max_idxs) if return_action else None}
            # Get best out of augmented # [batch, num_augment]
            max_aug_reward, max_idxs = max_reward.max(dim=1)
            aug_retvals = {"max_aug_reward": max_aug_reward, "best_aug_actions": get_best_actions(out["actions"], max_idxs) if return_action else None}
 
        return { **out, **loss_retvals, **multi_start_retvals, **aug_retvals}
        
    def setup(self, *args, **kwargs):
        pass # no baseline
    
    def on_train_epoch_end(self, *args, **kwargs):
        pass # no baseline

In [15]:
# Simple test of full SymNCO

num_loc = 20
env = TSPEnv(num_loc=num_loc).transform()

dataset = env.dataset(batch_size=[10000])

dataloader = DataLoader(
                dataset,
                batch_size=32,
                shuffle=False, # no need to shuffle, we're resampling every epoch
                num_workers=0,
                collate_fn=torch.stack, # we need this to stack the batches in the dataset
            )

policy = SymNCOPolicy(
    env,
    embedding_dim=128,
    hidden_dim=128,
    num_encode_layers=3,
    num_starts=num_loc,
    # force_flash_attn=True,
).to("cuda")

model = SymNCO(env, policy, num_augment=8, alpha=1, beta=1).to("cuda")


x = next(iter(dataloader)).to("cuda")
x = env.reset(init_obs=x)

out = model(x, decode_type="sampling")

# out = model(x, decode_type="sampling")


## Lightning Module

In [16]:
class NCOLightningModule(L.LightningModule):
    def __init__(self, env, model, lr=1e-4, batch_size=128, train_size=1000, val_size=10000):
        super().__init__()

        # TODO: hydra instantiation
        self.env = env
        self.model = model
        self.lr = lr
        self.batch_size = batch_size
        self.train_size = train_size
        self.val_size = val_size
        self.train_log = ["reward", "loss", "loss_ss", "loss_ps", "loss_inv"]
        self.val_log = ["reward", "max_reward", "max_aug_reward"]
        self.test_log = self.val_log
        self.log_cost = True

    def setup(self, stage="fit"):
        self.train_dataset = self.env.dataset(self.train_size)
        self.val_dataset = self.env.dataset(self.val_size)
        if hasattr(self.model, "setup"):
            self.model.setup(self)

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        td = self.env.reset(init_obs=batch)
        out = self.model(td, phase)
        
        # Log metrics
        log_metrics = getattr(self, f"{phase}_log")
        metrics = {f"{phase}/{k}": v.mean() for k, v in out.items() if k in log_metrics}

        # If log_cost, replace all max -> min, reward -> cost and invert sign
        if self.log_cost:
            metrics = {k.replace("max", "min").replace("reward", "cost"): -v for k, v in metrics.items()}
        
        self.log_dict(metrics, prog_bar=True)
        
        return {"loss": out.get("loss", None)}

    def training_step(self, batch: Any, batch_idx: int):   
        return self.shared_step(batch, batch_idx, phase='train')

    def validation_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='val')

    def test_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='test')

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-6)
        # optim = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
        # TODO: scheduler
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, total_steps)
        return [optim] #, [scheduler]
    
    def train_dataloader(self):
        return self._dataloader(self.train_dataset)
    
    def val_dataloader(self):
        return self._dataloader(self.val_dataset)
    
    def on_train_epoch_end(self):
        if hasattr(self.model, "on_train_epoch_end"):
            self.model.on_train_epoch_end(self)
        self.train_dataset = self.env.dataset(self.train_size) 
       
    def _dataloader(self, dataset):
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False, # no need to shuffle, we're resampling every epoch
            num_workers=0,
            collate_fn=torch.stack, # we need this to stack the batches in the dataset
            # pin_memory=self.on_gpu,
        )

## Main training setup

In [17]:
# Hyperparameters
epochs = 1
batch_size = 64 #1024 #512
n_loc = 20
train_size = 1280000
lr = 1e-4
n_starts = num_loc # TODO: comment to try out = 1
# num_pomo = 1 # set to 1: similar to simple AM

# Environment
env = TSPEnv(num_loc=n_loc).transform()


dataset = env.dataset(batch_size=[10000])

dataloader = DataLoader(
                dataset,
                batch_size=32,
                shuffle=False, # no need to shuffle, we're resampling every epoch
                num_workers=0,
                collate_fn=torch.stack, # we need this to stack the batches in the dataset
            )

policy = SymNCOPolicy(
    env,
    embedding_dim=128,
    hidden_dim=128,
    num_encode_layers=3,
    num_starts=num_loc,
    # force_flash_attn=True,
).to("cuda")

model = SymNCO(env, policy, num_augment=8, alpha=0.2, beta=1).to("cuda")


x = next(iter(dataloader)).to("cuda")
x = env.reset(init_obs=x)

# Quick test
out = model(x, decode_type="sampling")


# Create Lightning module (for training)
lit_model = NCOLightningModule(env, model, batch_size=batch_size, train_size=train_size, lr=lr)

## Fit model

In [18]:
# Trick to make calculations faster
torch.set_float32_matmul_precision("medium")

# Trainer
trainer = L.Trainer(
    max_epochs=epochs,
    accelerator="gpu",
    # devices=[1],
    logger=None, # can replace with WandbLogger, TensorBoardLogger, etc.
    # precision=16, # uncomment to make faster
    log_every_n_steps=100,   
    gradient_clip_val=1.0, # clip gradients to avoid exploding gradients!
)

# Fit the model
trainer.fit(lit_model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | env   | TSPEnv | 0     
1 | model | SymNCO | 743 K 
---------------------------------
743 K     Trainable params
0         Non-trainable params
743 K     Total params
2.973     Total estimated model params size (MB)
2023-04-20 17:43:30.791044: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-04-20 17:43:30.810355: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:   4%|▍         | 813/20000 [01:36<37:59,  8.42it/s, v_num=4, train/cost=4.050, train/loss=0.507, train/loss_ss=0.268, train/loss_ps=0.255, train/loss_inv=-.0829]  

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [19]:
trainer.validate(lit_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation DataLoader 0: 100%|██████████| 157/157 [00:07<00:00, 21.82it/s]


[{'val/cost': 4.075383186340332,
  'val/min_cost': 3.876986026763916,
  'val/min_aug_cost': 3.843059778213501}]

## Evaluate generalization on larger problem

In [20]:
# Validating with more locations

env = lit_model.env
env.num_loc = 100
lit_model.val_dataset = env.dataset(lit_model.val_size)
trainer.validate(lit_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation DataLoader 0: 100%|██████████| 157/157 [02:03<00:00,  1.27it/s]


[{'val/cost': 12.835606575012207,
  'val/min_cost': 11.733050346374512,
  'val/min_aug_cost': 11.294989585876465}]