# Attention Model Lightning

In [1]:
# # rich tracebacks
# import rich
# import rich.traceback

# rich.traceback.install()

In [2]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('../../')

import math
from typing import List, Tuple, Optional, NamedTuple, Dict, Union, Any
from einops import rearrange, repeat
from hydra.utils import instantiate

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.nn import DataParallel
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import lightning as L

from torchrl.envs import EnvBase
from torchrl.envs.utils import step_mdp
from tensordict import TensorDict


from ncobench.models.am import AttentionModel
from ncobench.models.common.am_base import AttentionModelBase
from ncobench.models.rl.reinforce import *
# from ncobench.envs.tsp import TSPEnv
# from ncobench.models.nn.graph import GraphAttentionEncoder
from ncobench.models.nn.attention import CrossAttention

from cleaner_v2.tsp_refactor import TSPEnv
from ncobench.data.dataset import TorchDictDataset

  warn(


In [3]:
from am_src.context import env_context
from am_src.embeddings import env_init_embedding, env_dynamic_embedding
from am_src.encoder import GraphAttentionEncoder

## AttentionModelBase

Here we declare the `AttentionModelBase`, which is the `nn.Module`:
- Given initial states, it returns the solutions and rewards for them
- We then wrap the main model with REINFORCE baselines and epoch callbacks to train it (full `AttentionModel`)

In [4]:
from dataclasses import dataclass
import torch.nn.functional as F


@dataclass
class PrecomputedCache:
    node_embeddings: torch.Tensor
    graph_context: torch.Tensor
    glimpse_key: torch.Tensor
    glimpse_val: torch.Tensor
    logit_key: torch.Tensor


class Decoder(nn.Module):
    def __init__(self, env, embedding_dim, n_heads, **logit_attn_kwargs):
        super(Decoder, self).__init__()

        self.env = env
        self.embedding_dim = embedding_dim
        self.n_heads = n_heads

        assert embedding_dim % n_heads == 0

        step_context_dim = 2 * embedding_dim  # Embedding of first and last node
        self.context = env_context(self.env.name, {"context_dim": step_context_dim})
        self.dynamic_embedding = env_dynamic_embedding(
            self.env.name, {"embedding_dim": embedding_dim}
        )
        
        # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
        self.project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=False)
        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.project_step_context = nn.Linear(step_context_dim, embedding_dim, bias=False)

        # MHA
        self.logit_attention = LogitAttention(embedding_dim, n_heads, **logit_attn_kwargs)
        

    def forward(self, td, embeddings, decode_type="sampling"):

        outputs = []
        actions = []

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        cached_embeds = self._precompute(embeddings)

        while not td["done"].any(): # NOTE: here we suppose all the batch is done at the same time
            
            log_p, mask = self._get_log_p(cached_embeds, td)

            # Select the indices of the next nodes in the sequences, result (batch_size) long
            action = self.decode(log_p.exp().squeeze(1), mask.squeeze(1), decode_type=decode_type)
           
            td.set("action", action[:, None])
            td = self.env.step(td)['next']

            # Collect output of step
            outputs.append(log_p.squeeze(1))
            actions.append(action)

        outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1)
        td.set("reward", self.env.get_reward(td['observation'], actions))
        return outputs, actions, td
    
    def decode(self, probs, mask, decode_type="sampling"):

        assert (probs == probs).all(), "Probs should not contain any nans"

        if decode_type == "greedy":
            _, selected = probs.max(1)
            assert not mask.gather(1, selected.unsqueeze(
                -1)).data.any(), "Decode greedy: infeasible action has maximum probability"

        elif decode_type == "sampling":
            selected = probs.multinomial(1).squeeze(1)

            while mask.gather(1, selected.unsqueeze(-1)).data.any():
                print('Sampled bad values, resampling!')
                selected = probs.multinomial(1).squeeze(1)

        else:
            assert False, "Unknown decode type"
        return selected
    
    def _precompute(self, embeddings):
        # The fixed context projection of the graph embedding is calculated only once for efficiency
        graph_embed = embeddings.mean(1)
        
        # The projection of the node embeddings for the attention is calculated once up front
        glimpse_key_fixed, glimpse_val_fixed, logit_key_fixed = \
            self.project_node_embeddings(embeddings[:, None, :, :]).chunk(3, dim=-1)
        
        # Organize in a TensorDict for easy access
        cached_embeds = PrecomputedCache(
            node_embeddings=embeddings,
            graph_context=self.project_fixed_context(graph_embed)[:, None, :],
            glimpse_key=self.logit_attention._make_heads(glimpse_key_fixed),
            glimpse_val=self.logit_attention._make_heads(glimpse_val_fixed),
            logit_key=logit_key_fixed
        )

        return cached_embeds
        
    def _get_log_p(self, cached, td, normalize=True):
        
        context = self.context(cached.node_embeddings, td)
        step_context = self.project_step_context(context)  # [batch, 1, embed_dim]

        query = cached.graph_context + step_context  # [batch, 1, embed_dim]

        # Compute keys and values for the nodes
        # glimpse_K, glimpse_V, logit_K = self._get_attention_node_data(cached, td['observation'])
        glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.dynamic_embedding(td['observation'])
        glimpse_key = cached.glimpse_key + glimpse_key_dynamic
        glimpse_key = cached.glimpse_val + glimpse_val_dynamic
        logit_key = cached.logit_key + logit_key_dynamic

        # Compute the mask
        mask = ~td['action_mask']

        # Compute logits
        log_p = self.logit_attention(query, glimpse_key, glimpse_key, logit_key, mask)

        return log_p, mask   
    

cross_attention = CrossAttention()


class LogitAttention(nn.Module):
    """Calculate logits given query, key and value and logit key 
    If we use Flash Attention, then we automatically move to fp16 for inner computations
    Note: with Flash Attention, masking is not supported

    Perform the following:
        1. Apply cross attention to get the heads
        2. Project heads to get glimpse
        3. Compute attention score between glimpse and logit key
        4. Normalize and mask
    """
    def __init__(self, embed_dim, n_heads, tanh_clipping=10.0, mask_inner=True, mask_logits=True, normalize=True, force_flash_attn=False):
        super(LogitAttention, self).__init__()
        self.n_heads = n_heads
        self.mask_logits = mask_logits
        self.mask_inner = mask_inner
        self.tanh_clipping = tanh_clipping
        self.temp = 1.0
        self.normalize = normalize
        self.force_flash_attn = force_flash_attn

        if force_flash_attn and mask_inner:
            print("WARNING: Flash Attention does not support masking, setting force_flash_attn to False")
            self.force_flash_attn = False

        # Projection - query, key, value already include projections
        self.project_out = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, query, key, value, logit_key, mask):

        # Compute inner multi-head attention with no projections
        heads = self._inner_mha(query, key, value, mask)
        glimpse = self.project_out(heads)

        # Batch matrix multiplication to compute logits (batch_size, num_steps, graph_size)
        # bmm is slightly faster than einsum and matmul       
        logits = torch.bmm(glimpse.squeeze(1), logit_key.squeeze(1).transpose(-2, -1))/ math.sqrt(glimpse.size(-1))

        # From the logits compute the probabilities by clipping, masking and softmax
        if self.tanh_clipping > 0:
            logits = torch.tanh(logits) * self.tanh_clipping

        if self.mask_logits:
            logits[mask] = float('-inf')

        if self.normalize:
            logits = torch.log_softmax(logits / self.temp, dim=-1)

        assert not torch.isnan(logits).any()

        return logits
    
    def _inner_mha(self, query, key, value, mask):

        # Flash Attention: move to fp16 for inner computations
        if self.force_flash_attn:
            src_dtype = query.dtype
            query = rearrange(query, 'b 1 (h s) -> b h 1 s', h=self.n_heads)
            query, key, value = query.half(), key.half(), value.half()
            heads = F.scaled_dot_product_attention(query, key, value)
            heads = rearrange(heads, 'b h 1 g -> b 1 1 (h g)', h=self.n_heads).to(src_dtype)

        # Otherwise, get mask and use cross attention (faster than even original to converge)
        else:
            kv = torch.cat([key, value], dim=2)
            q = rearrange(query, 'b 1 (h s) -> b 1 h s', h=self.n_heads)
            key_padding_mask = ~mask.squeeze() if self.mask_inner else None
            heads = cross_attention(q, kv, key_padding_mask=key_padding_mask)
            heads = rearrange(heads, 'b 1 h g -> b 1 1 (h g)', h=self.n_heads)

        #### NEW
        # query = rearrange(query, 'b 1 (h s) -> b h 1 s', h=self.n_heads)
        # src_dtype = query.dtype
        # query, key, value = query.half(), key.half(), value.half()

        # mask = ~mask.unsqueeze(1) if self.mask_inner else None

        # heads = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
        # heads = rearrange(heads, 'b h 1 g -> b 1 1 (h g)', h=self.n_heads).to(src_dtype)

        return heads

    def _make_heads(self, v):
        if self.force_flash_attn:
            v = rearrange(v, 'b 1 g (h s) -> b h g s', h=self.n_heads)
        else:
            v = rearrange(v, 'b 1 g (h s) -> b g 1 h s', h=self.n_heads)
        return v

In [5]:
# torch.Size([128, 1, 128])
# torch.Size([128, 10, 1, 8, 16])
# torch.Size([128, 10, 1, 8, 16])
import torch.nn.functional as F

query = torch.rand(128, 1, 128, dtype=torch.float16, device="cuda")
key = torch.rand(128, 10, 1, 8, 16, dtype=torch.float16, device="cuda")
value = torch.rand(128, 10, 1, 8, 16, dtype=torch.float16, device="cuda")


n_heads = 8
query = rearrange(query, 'b 1 (h s) -> b h 1 s', h=n_heads)
key = rearrange(key, 'b g 1 h s -> b h g s', h=n_heads)
value = rearrange(value, 'b g 1 h s -> b h g s', h=n_heads)
# mask 
mask = torch.rand(128, 1, 1, 10, dtype=torch.float16, device="cuda") > 0.5

heads = F.scaled_dot_product_attention(query, key, value, attn_mask=mask) #, attn_mask=mask)

# heads = rearrange(heads, 'b h 1 g -> b 1 1 (h g)', h=n_heads)

print(heads.shape)

torch.Size([128, 8, 1, 16])


In [6]:
class AttentionModelBase(nn.Module):

    def __init__(self,
                 env: EnvBase,
                 embedding_dim: int,
                 hidden_dim: int,
                 *,
                 n_encode_layers: int = 2,
                 normalization: str = 'batch',
                 n_heads: int = 8,
                 checkpoint_encoder: bool = False,
                 mask_inner: bool = True,
                 force_flash_attn: bool = False,
                 **kwargs
                 ):
        super(AttentionModelBase, self).__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_encode_layers = n_encode_layers
        self.env = env

        self.n_heads = n_heads
        self.checkpoint_encoder = checkpoint_encoder

        self.init_embedding = env_init_embedding(self.env.name, {"embedding_dim": embedding_dim})

        self.encoder = GraphAttentionEncoder(
            n_heads=n_heads,
            embed_dim=embedding_dim,
            n_layers=self.n_encode_layers,
            normalization=normalization,
            force_flash_attn=force_flash_attn,
        )
        
        self.decoder = Decoder(env, embedding_dim, n_heads, mask_inner=mask_inner, force_flash_attn=force_flash_attn)


    def forward(self, td: TensorDict, phase: str = "train", decode_type: str = "sampling") -> TensorDict:
        """Given observation, precompute embeddings and rollout"""

        # Set decoding type for policy, can be also greedy
        embedding = self.init_embedding(td)
        encoded_inputs, _ = self.encoder(embedding)

        # Main rollout
        _log_p, actions, td = self.decoder(td, encoded_inputs, decode_type)
        # reward = self.env.get_reward(td['observation'], actions)

        # Log likelyhood is calculated within the model since returning it per action does not work well with
        ll = self._calc_log_likelihood(_log_p, actions, td.get('mask', None))
        out = {"reward": td["reward"], "log_likelihood": ll, "actions": actions}

        return out

    def _calc_log_likelihood(self, _log_p, a, mask):

        # Get log_p corresponding to selected actions
        log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)

        # Optional: mask out actions irrelevant to objective so they do not get reinforced
        if mask is not None:
            log_p[mask] = 0

        assert (log_p > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!"

        # Calculate log_likelihood
        return log_p.sum(1)

## Test `AttentionModelBase`

In [7]:
env = TSPEnv(num_loc=10).transform()

# data = env.gen_params(batch_size=[10000]) # NOTE: need to put batch_size in a list!!
# init_td = env.reset(data)
# env.batch_size = [10000]
init_td = env.reset(batch_size=[10000])
dataset = TorchDictDataset(init_td)


dataloader = DataLoader(
                dataset,
                batch_size=128,
                shuffle=False, # no need to shuffle, we're resampling every epoch
                num_workers=0,
                collate_fn=torch.stack, # we need this to stack the batches in the dataset
            )

model = AttentionModelBase(
    env,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
    # force_flash_attn=True,
).to("cuda")

# model = torch.compile(model, backend="cuda")

x = next(iter(dataloader)).to("cuda")

out = model(x, decode_type="sampling")

In [8]:
def get_lightning_device(lit_module: L.LightningModule) -> torch.device:
    """Get the device of the lightning module
    See device setting issue in setup https://github.com/Lightning-AI/lightning/issues/2638
    """
    if lit_module.trainer.strategy.root_device != lit_module.device:
        return lit_module.trainer.strategy.root_device
    return lit_module.device


class AttentionModel(nn.Module):
    def __init__(self, env, policy):
        super().__init__()
        self.env = env
        self.policy = policy

        # TODO: hydra instantiation
        # self.policy = instantiate(cfg.policy)
        # self.baseline = instantiate(cfg.baseline) TODO

    def forward(self, td: TensorDict, phase: str="train", decode_type: str=None) -> TensorDict:
        """Evaluate model, get costs and log probabilities and compare with baseline"""

        # Evaluate model, get costs and log probabilities
        out_policy = self.policy(td)
        bl_val, bl_loss = self.baseline.eval(td, -out_policy['reward'])

        # print(bl_val, bl_loss)
        # Calculate loss
        advantage = -out_policy['reward'] - bl_val
        reinforce_loss = (advantage * out_policy['log_likelihood']).mean()
        loss = reinforce_loss + bl_loss

        return {'loss': loss, 'reinforce_loss': reinforce_loss, 'bl_loss': bl_loss, 'bl_val': bl_val, **out_policy}
    
    def setup(self, lit_module):
        # Make baseline taking model itself and train_dataloader from model as input
        # TODO make this as taken from config
        self.baseline = instantiate({"_target_": "__main__.WarmupBaseline",
                                    "baseline": {"_target_": "__main__.RolloutBaseline",                                             }
                                    })  

        self.baseline.setup(self.policy, lit_module.val_dataloader(), self.env, device=get_lightning_device(lit_module))         
        # self.baseline = NoBaseline()

    def on_train_epoch_end(self, lit_module):
        # self.baseline.epoch_callback(self.policy, self.env, pl_module)
        self.baseline.epoch_callback(self.policy, lit_module.val_dataloader(), lit_module.current_epoch, self.env, device=get_lightning_device(lit_module))

## Lightning Module

In [9]:
class NCOLightningModule(L.LightningModule):
    def __init__(self, env, model, lr=1e-4, batch_size=128, train_size=1000, val_size=10000):
        super().__init__()

        # TODO: hydra instantiation
        self.env = env
        self.model = model
        self.lr = lr
        self.batch_size = batch_size
        self.train_size = train_size
        self.val_size = val_size

    def setup(self, stage="fit"):
        self.train_dataset = self.get_observation_dataset(self.train_size)
        self.val_dataset = self.get_observation_dataset(self.val_size)
        if hasattr(self.model, "setup"):
            self.model.setup(self)

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        td = self.env.reset(init_observation=batch)
        output = self.model(td, phase)
        
        # output = self.model(batch, phase)
        self.log(f"{phase}/cost", -output["reward"].mean(), prog_bar=True)
        return {"loss": output['loss']}

    def training_step(self, batch: Any, batch_idx: int):   
        return self.shared_step(batch, batch_idx, phase='train')

    def validation_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='val')

    def test_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='test')

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-5)
        # TODO: scheduler
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, total_steps)
        return [optim] #, [scheduler]
    
    def train_dataloader(self):
        return self._dataloader(self.train_dataset)
    
    def val_dataloader(self):
        return self._dataloader(self.val_dataset)
    
    def on_train_epoch_end(self):
        if hasattr(self.model, "on_train_epoch_end"):
            self.model.on_train_epoch_end(self)
        self.train_dataset = self.get_observation_dataset(self.train_size) 

    # def get_observation_dataset(self, size):
    #     # online data generation: we generate a new batch online
    #     data = self.env.gen_params(batch_size=size)
    #     return TorchDictDataset(self.env.reset(data))

    def get_observation_dataset(self, size):
        # online data generation: we generate a new batch online
        # data = self.env.gen_params(batch_size=size)
        return TorchDictDataset(self.env.reset(batch_size=[size])['observation'])
       
    def _dataloader(self, dataset):
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False, # no need to shuffle, we're resampling every epoch
            num_workers=0,
            collate_fn=torch.stack, # we need this to stack the batches in the dataset
            pin_memory=self.on_gpu,
        )

In [10]:
# Disable profiling executor. This reduces memory and increases speed.
try:
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)
except AttributeError:
    pass


In [11]:
env = TSPEnv(num_loc=20).transform()

# env = env.transform()
policy = AttentionModelBase(
    env,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
    force_flash_attn=True,
)

model_final = AttentionModel(env, policy)

# # TODO CHANGE THIS
batch_size = 512 #1024 #512

model = NCOLightningModule(env, model_final, batch_size=batch_size, train_size=1280000, lr=1e-4)

# Trick to make calculations faster
torch.set_float32_matmul_precision("medium")

# Wandb Logger - we can use others as well as simply `None`
# logger = pl.loggers.WandbLogger(project="torchrl", name="pendulum")
# logger = L.loggers.CSVLogger("logs", name="tsp")

epochs = 1

# from lightning.pytorch.callbacks import DeviceStatsMonitor
# callbacks = [DeviceStatsMonitor()]

from lightning.pytorch.profilers import AdvancedProfiler

profiler = AdvancedProfiler(dirpath=".", filename="perf_logsv2")

# Trainer
trainer = L.Trainer(
    max_epochs=epochs,
    accelerator="gpu",
    devices=[1],
    # callbacks=callbacks,
    # profiler=profiler,
    # strategy="deepspeed_stage_3_offload",
    # precision=16,
    log_every_n_steps=100,   
    gradient_clip_val=1.0, # clip gradients to avoid exploding gradients
)

# Fit the model
trainer.fit(model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Evaluating baseline model on evaluation dataset


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type           | Params
-----------------------------------------
0 | env   | TSPEnv         | 0     
1 | model | AttentionModel | 1.4 M 
-----------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.681     Total estimated model params size (MB)


                                                                           

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:  16%|█▌        | 398/2500 [00:27<02:23, 14.67it/s, v_num=164, train/cost=4.120]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
