# Attention Model Lightning

In [1]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('../../')

import math
from typing import List, Tuple, Optional, NamedTuple, Dict, Union, Any
from einops import rearrange, repeat
from hydra.utils import instantiate

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.nn import DataParallel
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import lightning as L

from torchrl.envs import EnvBase
from torchrl.envs.utils import step_mdp
from tensordict import TensorDict

from ncobench.envs.tsp import TSPEnv
from ncobench.models.rl.reinforce import *
from ncobench.models.co.am.context import env_context
from ncobench.models.co.am.embeddings import env_init_embedding, env_dynamic_embedding
from ncobench.models.co.am.encoder import GraphAttentionEncoder
from ncobench.models.co.am.decoder import Decoder, decode_probs, PrecomputedCache, LogitAttention
from ncobench.models.co.am.policy import get_log_likelihood
from ncobench.models.nn.attention import NativeFlashMHA, flash_attn_wrapper
from ncobench.utils.lightning import get_lightning_device

  warn(


## AttentionModelBase

Here we declare `AttentionModelBase`, the core policy implemented as an `nn.Module`:
- Given initial states, it returns the solutions (actions), their rewards and their log-likelihoods
- We then wrap this policy with a REINFORCE baseline and epoch callbacks to train it (the full `AttentionModel`)

In [3]:
class AttentionModelBase(nn.Module):
    """Attention-model policy: embed the observation, encode it, then roll out a solution.

    Args:
        env: environment; ``env.name`` selects the initial embedding and ``env`` itself is
            handed to the default decoder.
        embedding_dim: embedding width passed to the initial embedding, encoder and decoder.
        hidden_dim: stored on the module; not otherwise used inside this class.
        encoder: optional encoder module; a ``GraphAttentionEncoder`` is built when ``None``.
        decoder: optional decoder module; a ``Decoder`` is built when ``None``.
        n_encode_layers: number of layers of the default encoder.
        normalization: normalization name forwarded to the default encoder.
        n_heads: number of attention heads for the default encoder and decoder.
        checkpoint_encoder: if ``True``, the encoder pass is wrapped in
            ``torch.utils.checkpoint.checkpoint`` while the module is in training mode, so its
            activations are recomputed during backward instead of being stored.
        mask_inner: forwarded to the default decoder.
        force_flash_attn: forwarded to the default encoder and decoder.
        **kwargs: accepted for configuration compatibility and ignored.
    """

    def __init__(self,
                 env: EnvBase,
                 embedding_dim: int,
                 hidden_dim: int,
                 encoder: Optional[nn.Module] = None,
                 decoder: Optional[nn.Module] = None,
                 n_encode_layers: int = 3,
                 normalization: str = 'batch',
                 n_heads: int = 8,
                 checkpoint_encoder: bool = False,
                 mask_inner: bool = True,
                 force_flash_attn: bool = False,
                 **kwargs
                 ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_encode_layers = n_encode_layers
        self.env = env

        self.n_heads = n_heads
        self.checkpoint_encoder = checkpoint_encoder

        self.init_embedding = env_init_embedding(self.env.name, {"embedding_dim": embedding_dim})

        self.encoder = GraphAttentionEncoder(
            n_heads=n_heads,
            embed_dim=embedding_dim,
            n_layers=self.n_encode_layers,
            normalization=normalization,
            force_flash_attn=force_flash_attn,
        ) if encoder is None else encoder

        self.decoder = Decoder(env, embedding_dim, n_heads, mask_inner=mask_inner, force_flash_attn=force_flash_attn) if decoder is None else decoder

    def forward(self, td: TensorDict, phase: str = "train", decode_type: str = "sampling") -> dict:
        """Given an observation, precompute embeddings and roll out a solution.

        Args:
            td: state container handed to the initial embedding and to the decoder.
            phase: accepted for interface compatibility; not used by this method.
            decode_type: decoding strategy forwarded to the decoder (e.g. ``"sampling"``).

        Returns:
            A plain ``dict`` with keys ``"reward"`` (read from the decoder's output state),
            ``"log_likelihood"`` and ``"actions"``.
        """
        embedding = self.init_embedding(td)

        if self.checkpoint_encoder and self.training:
            # Trade compute for memory: encoder activations are recomputed during backward.
            # The non-reentrant variant also tolerates inputs that do not require grad.
            encoded_inputs, _ = checkpoint(self.encoder, embedding, use_reentrant=False)
        else:
            encoded_inputs, _ = self.encoder(embedding)

        # Main rollout
        _log_p, actions, td = self.decoder(td, encoded_inputs, decode_type)
        # reward = self.env.get_reward(td['observation'], actions)

        # The log-likelihood is reduced to one value per instance inside the model rather than
        # being returned per action.
        ll = self._calc_log_likelihood(_log_p, actions, td.get('mask', None))
        return {"reward": td["reward"], "log_likelihood": ll, "actions": actions}

    def _calc_log_likelihood(self, _log_p, a, mask):
        """Sum the log-probabilities of the selected actions over dim 1.

        Args:
            _log_p: log-probabilities, indexed along dim 2 by the chosen action
                (presumably ``(batch, steps, nodes)`` -- TODO confirm against the decoder).
            a: integer tensor of selected actions with one dim fewer than ``_log_p``.
            mask: optional boolean index; selected-action entries where it is true are
                zeroed so they do not contribute to the sum.

        Raises:
            AssertionError: if any selected log-probability is ``<= -1000``.
        """
        # Get log_p corresponding to selected actions
        log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)

        # Optional: mask out actions irrelevant to objective so they do not get reinforced
        if mask is not None:
            log_p[mask] = 0

        assert (log_p > -1000).all(), "Logprobs should not be -inf, check sampling procedure!"

        # Calculate log_likelihood
        return log_p.sum(1)

## Test `AttentionModelBase`

In [6]:
from ncobench.envs.dpp import DPPEnv

In [17]:
# NOTE(review): leftover debugging probe. In the visible notebook `x` is only assigned by a
# cell further down, so on a fresh kernel (Restart & Run All) this line raises NameError.
x.shape

torch.Size([])

In [15]:
# Smoke test: build a DPP environment, draw one batch and run a single sampled rollout
# through `AttentionModelBase`.
# num_loc = 15
# env = TSPEnv(num_loc=num_loc).transform()

base_folder = '../../'
chip_fpath = base_folder+"data/dpp/10x10_pkg_chip.npy"
decap_fpath = base_folder+"data/dpp/01nF_decap.npy"
freq_fpath = base_folder+"data/dpp/freq_201.npy"

# Use the GPU when one is present, otherwise fall back to CPU so the cell still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = DPPEnv(chip_fpath=chip_fpath, decap_fpath=decap_fpath, freq_fpath=freq_fpath)
dataset = env.dataset(batch_size=[1000])

dataloader = DataLoader(
                dataset,
                batch_size=32,
                shuffle=False, # no need to shuffle, we're resampling every epoch
                num_workers=0,
                collate_fn=torch.stack, # we need this to stack the batches in the dataset
            )

model = AttentionModelBase(
    env,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
    # num_pomo=num_loc,
    # force_flash_attn=True,
).to(device)

# model = torch.compile(model)

x = next(iter(dataloader)).to(device)
# NOTE(review): this cell passes `init_obs=` while `NCOLightningModule.shared_step` passes
# `init_observation=` to `env.reset` -- confirm which keyword the environment expects.
x = env.reset(init_obs=x)

# NOTE(review): the recorded output of this cell is
# "RuntimeError: mat1 and mat2 must have the same dtype". Presumably the DPP data is not
# float32 while the model weights are -- TODO confirm and cast where the data is produced.
out = model(x, decode_type="sampling")

RuntimeError: mat1 and mat2 must have the same dtype

In [6]:
def get_lightning_device(lit_module: L.LightningModule) -> torch.device:
    """Pick the device to use for a Lightning module.

    The trainer strategy's root device wins whenever it differs from the module's own
    ``device`` attribute; otherwise the module's device is returned.
    Background: https://github.com/Lightning-AI/lightning/issues/2638
    """
    strategy_device = lit_module.trainer.strategy.root_device
    module_device = lit_module.device
    return strategy_device if strategy_device != module_device else module_device


class AttentionModel(nn.Module):
    """REINFORCE wrapper that pairs a rollout policy with a baseline.

    ``self.baseline`` is only created by :meth:`setup`; calling :meth:`forward` before
    ``setup`` (or before assigning a baseline manually) raises ``AttributeError``.

    Args:
        env: environment handed to the baseline during setup and epoch callbacks.
        policy: callable module mapping a state to a dict holding at least
            ``"reward"`` and ``"log_likelihood"``.
    """

    def __init__(self, env, policy):
        super().__init__()
        self.env = env
        self.policy = policy

        # TODO: hydra instantiation
        # self.policy = instantiate(cfg.policy)
        # self.baseline = instantiate(cfg.baseline) TODO

    def forward(self, td: TensorDict, phase: str="train", decode_type: Optional[str]=None) -> dict:
        """Evaluate the policy, compare its cost with the baseline and build the loss.

        Args:
            td: state passed to the policy and to ``baseline.eval``.
            phase: accepted for interface compatibility; not used by this method.
            decode_type: when given, forwarded to the policy as ``decode_type``; when
                ``None`` the policy's own default decoding is used.

        Returns:
            A dict with ``loss``, ``reinforce_loss``, ``bl_loss``, ``bl_val`` plus every
            entry of the policy output.
        """
        # Only override the policy's decoding strategy when the caller asked for one.
        policy_kwargs = {} if decode_type is None else {"decode_type": decode_type}

        # Evaluate model, get costs and log probabilities
        out_policy = self.policy(td, **policy_kwargs)
        cost = -out_policy['reward']
        bl_val, bl_loss = self.baseline.eval(td, cost)

        # Calculate loss
        advantage = cost - bl_val
        reinforce_loss = (advantage * out_policy['log_likelihood']).mean()
        loss = reinforce_loss + bl_loss

        return {'loss': loss, 'reinforce_loss': reinforce_loss, 'bl_loss': bl_loss, 'bl_val': bl_val, **out_policy}

    def setup(self, lit_module):
        """Create the baseline and initialise it from the Lightning module's validation loader."""
        # Make baseline taking model itself and train_dataloader from model as input
        # TODO make this as taken from config
        # The targets resolve in `__main__`, i.e. the notebook namespace.
        self.baseline = instantiate({"_target_": "__main__.WarmupBaseline",
                                    "baseline": {"_target_": "__main__.RolloutBaseline",                                             }
                                    })

        self.baseline.setup(self.policy, lit_module.val_dataloader(), self.env, device=get_lightning_device(lit_module))
        # self.baseline = NoBaseline()

    def on_train_epoch_end(self, lit_module):
        """Run the baseline's epoch callback with the current policy and validation loader."""
        # self.baseline.epoch_callback(self.policy, self.env, pl_module)
        self.baseline.epoch_callback(self.policy, lit_module.val_dataloader(), lit_module.current_epoch, self.env, device=get_lightning_device(lit_module))

## Lightning Module

In [7]:
class NCOLightningModule(L.LightningModule):
    """Lightning harness that trains ``model`` on observations drawn from ``env``.

    Datasets are built from ``env.reset`` (see :meth:`get_observation_dataset`). Every step
    resets the environment on the incoming batch and delegates to ``model``, whose output
    must provide ``"reward"`` and ``"loss"``.
    """

    def __init__(self, env, model, lr=1e-4, batch_size=128, train_size=1000, val_size=10000):
        super().__init__()

        # TODO: hydra instantiation
        self.env, self.model = env, model
        self.lr, self.batch_size = lr, batch_size
        self.train_size, self.val_size = train_size, val_size

    def setup(self, stage="fit"):
        """Build the train/val datasets, then call the model's own ``setup`` hook if it has one."""
        self.train_dataset = self.get_observation_dataset(self.train_size)
        self.val_dataset = self.get_observation_dataset(self.val_size)
        if not hasattr(self.model, "setup"):
            return
        self.model.setup(self)

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        """Reset the env on ``batch``, run the model, log the mean cost and return the loss."""
        state = self.env.reset(init_observation=batch)
        result = self.model(state, phase)

        mean_cost = -result["reward"].mean()
        self.log(f"{phase}/cost", mean_cost, prog_bar=True)
        return {"loss": result['loss']}

    def training_step(self, batch: Any, batch_idx: int):
        """Training step: ``shared_step`` with ``phase='train'``."""
        return self.shared_step(batch, batch_idx, phase='train')

    def validation_step(self, batch: Any, batch_idx: int):
        """Validation step: ``shared_step`` with ``phase='val'``."""
        return self.shared_step(batch, batch_idx, phase='val')

    def test_step(self, batch: Any, batch_idx: int):
        """Test step: ``shared_step`` with ``phase='test'``."""
        return self.shared_step(batch, batch_idx, phase='test')

    def configure_optimizers(self):
        """Return a single Adam optimizer over the model parameters (no scheduler yet)."""
        # TODO: scheduler, e.g. torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-5)
        return [optimizer]

    def train_dataloader(self):
        """Loader over the current training dataset."""
        return self._dataloader(self.train_dataset)

    def val_dataloader(self):
        """Loader over the validation dataset."""
        return self._dataloader(self.val_dataset)

    def on_train_epoch_end(self):
        """Forward the epoch-end hook to the model (if present) and resample the training data."""
        if hasattr(self.model, "on_train_epoch_end"):
            self.model.on_train_epoch_end(self)
        # NOTE(review): whether the Trainer picks up this new dataset presumably depends on
        # its dataloader-reloading settings -- confirm.
        self.train_dataset = self.get_observation_dataset(self.train_size)

    def get_observation_dataset(self, size):
        """Generate ``size`` fresh observations online by resetting the environment."""
        fresh_state = self.env.reset(batch_size=[size])
        return TorchDictDataset(fresh_state['observation'])

    def _dataloader(self, dataset):
        """Wrap ``dataset`` in a non-shuffling, single-process ``DataLoader``."""
        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=False,  # data is resampled every epoch, so shuffling adds nothing
            num_workers=0,
            collate_fn=torch.stack,  # stacks the per-item entries of the dataset into a batch
            pin_memory=self.on_gpu,
        )
        return DataLoader(dataset, **loader_options)

In [8]:
# Switch off the TorchScript profiling executor and profiling mode (reduces memory use and
# increases speed). Builds of torch that lack these private hooks are left untouched.
try:
    for _jit_toggle in ("_jit_set_profiling_executor", "_jit_set_profiling_mode"):
        getattr(torch._C, _jit_toggle)(False)
except AttributeError:
    pass


In [13]:
# Experiment: build the TSP environment and policy, then create `AttentionModel` through
# hydra's `instantiate` instead of calling the class directly.
env = TSPEnv(num_loc=20).transform()

# env = env.transform()
policy = AttentionModelBase(
    env,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
    force_flash_attn=True,
)


# The target is looked up in `__main__`, i.e. the class defined in this notebook; `policy`
# is supplied as a keyword override rather than inside the config dict.
# NOTE(review): `model` is reassigned by the training cell further down -- run order matters.
model = instantiate({"_target_": "__main__.AttentionModel", "env": env}, policy=policy)



: 

In [6]:
# NOTE(review): scratch probe. `cfg` is not defined by any earlier cell shown in this
# notebook, and the recorded output is a ConfigAttributeError ("Missing key strategy").
cfg.strategy

ConfigAttributeError: Missing key strategy
    full_key: strategy
    object_type=dict

In [7]:
import lightning.strategies

ModuleNotFoundError: No module named 'lightning.strategies'

In [3]:
# Create omegaconf config
from omegaconf import OmegaConf

cfg = OmegaConf.create({"model": "ciao"})

# NOTE(review): this only *reads* the key `est`, which was never set above; it does not add
# a key to the config. Adding one would need an assignment -- TODO confirm the intent.
# The output recorded under this cell (a NameError about `model`) does not match this code
# and is presumably stale.
cfg.est


NameError: name 'model' is not defined

In [9]:
# End-to-end training run: TSP environment -> policy -> REINFORCE wrapper -> Lightning module.
env = TSPEnv(num_loc=20).transform()

# env = env.transform()
policy = AttentionModelBase(
    env,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
    force_flash_attn=True,
)

model_final = AttentionModel(env, policy)

# # TODO CHANGE THIS
batch_size = 512 #1024 #512

# 1,280,000 training observations at batch size 512 give 2,500 steps per epoch, which
# matches the progress bar in the recorded output.
# NOTE: this rebinds `model`, which an earlier cell used for a different object.
model = NCOLightningModule(env, model_final, batch_size=batch_size, train_size=1280000, lr=1e-4)

# Trick to make calculations faster
# (allows float32 matmuls to use lower-precision internals; see the PyTorch documentation)
torch.set_float32_matmul_precision("medium")

# Wandb Logger - we can use others as well as simply `None`
# logger = pl.loggers.WandbLogger(project="torchrl", name="pendulum")
# logger = L.loggers.CSVLogger("logs", name="tsp")

epochs = 1


# from lightning.pytorch.callbacks import DeviceStatsMonitor
# callbacks = [DeviceStatsMonitor()]

from lightning.pytorch.profilers import AdvancedProfiler

# NOTE(review): the profiler is constructed but never used, because `profiler=profiler` is
# commented out in the Trainer call below.
profiler = AdvancedProfiler(dirpath=".", filename="perf_logsv2")

# Trainer
# NOTE(review): `devices=[1]` hard-codes one GPU index (presumably the second GPU), so this
# cell needs a multi-GPU machine; `precision=16` is reported as 16-bit AMP in the recorded
# output.
trainer = L.Trainer(
    max_epochs=epochs,
    accelerator="gpu",
    devices=[1],
    # callbacks=callbacks,
    # profiler=profiler,
    # strategy="deepspeed_stage_3_offload",
    precision=16,
    log_every_n_steps=100,   
    gradient_clip_val=1.0, # clip gradients to avoid exploding gradients
)

# Fit the model
# No dataloaders are passed: they come from the module's `train_dataloader` /
# `val_dataloader` methods. The recorded run was interrupted manually at 92% of epoch 0.
trainer.fit(model)

  rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Evaluating baseline model on evaluation dataset


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type           | Params
-----------------------------------------
0 | env   | TSPEnv         | 0     
1 | model | AttentionModel | 1.4 M 
-----------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.681     Total estimated model params size (MB)


                                                                           

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:  92%|█████████▏| 2306/2500 [02:23<00:12, 16.06it/s, v_num=177, train/cost=4.020]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
