In [51]:
import torch
from torch.nn import Module
from torch import nn

class MCLatencyDecoder(Module):
    """One post-norm decoder layer: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is followed by a residual connection and a
    LayerNorm. All tensors are batch-first, i.e. ``(batch, seq, emb_size)``.

    Args:
        emb_size: Feature size of the decoder input and of the encoder output.
        nheads: Number of heads for both attention modules.
        dropout: Dropout probability passed to both attention modules.
    """

    def __init__(self, emb_size, nheads=4, dropout=0.1):
        super().__init__()

        self.attn1 = nn.MultiheadAttention(emb_size, num_heads=nheads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(emb_size)
        self.attn2 = nn.MultiheadAttention(emb_size, num_heads=nheads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(emb_size)

        # Position-wise MLP with a 4x expansion.
        self.feed_forward = nn.Sequential(
            nn.Linear(emb_size, emb_size * 4),
            nn.GELU(),
            nn.Linear(emb_size * 4, emb_size)
        )

        self.norm3 = nn.LayerNorm(emb_size)

    def forward(self, inp, encoded, mask=None):
        """Run the layer on the decoder state ``inp`` and encoder output ``encoded``.

        Args:
            inp: Decoder input, ``(batch, seq, emb_size)``.
            encoded: Encoder output; only its first ``inp.shape[1]`` positions
                are attended to, so the same ``mask`` fits both attentions.
            mask: Optional boolean key padding mask of shape ``(batch, seq)``
                where ``True`` marks positions to ignore.

        Returns:
            Tensor with the same shape as ``inp``.
        """
        # TODO figure out mask (no causal mask is applied yet)
        self_attended, _ = self.attn1(inp, inp, inp, key_padding_mask=mask, need_weights=False)
        x1 = self.norm1(self_attended + inp)

        memory = encoded[:, :x1.shape[1], :]
        # nn.MultiheadAttention.forward takes (query, key, value): the decoder
        # state queries the encoder memory, which supplies keys and values.
        cross_attended, _ = self.attn2(x1, memory, memory, key_padding_mask=mask, need_weights=False)
        x2 = self.norm2(cross_attended + x1)

        x3 = self.norm3(self.feed_forward(x2) + x2)

        return x3

In [36]:
class MCNNConfig:
    """Plain container for the model and training hyper-parameters.

    Every constructor argument is stored unchanged as an attribute of the same
    name, so values can also be overridden after construction.

    Args:
        num_opcodes: Opcode vocabulary size handed to the embedding layer.
        batch_size: Number of samples per batch (also used when logging).
        embedding_size: Feature size used by the embedding, decoders and heads.
        hidden_size: Hidden size handed to each graph encoder.
        num_heads_encoder: Attention heads per encoder layer.
        num_heads_decoder: Attention heads per decoder layer.
        num_encoders: Number of stacked encoder layers.
        num_decoders: Number of stacked decoder layers.
        dropout: Dropout probability shared by encoders and decoders.
        learning_rate: Initial optimizer learning rate.
    """

    def __init__(
        self,
        num_opcodes,
        batch_size=64,
        embedding_size=128,
        hidden_size=256,
        num_heads_encoder=4,
        num_heads_decoder=4,
        num_encoders=4,
        num_decoders=4,
        dropout=0.1,
        learning_rate=1e-3,
    ):
        # Data-related settings.
        self.num_opcodes, self.batch_size = num_opcodes, batch_size

        # Layer widths.
        self.embedding_size, self.hidden_size = embedding_size, hidden_size

        # Attention heads and stack depths.
        self.num_heads_encoder, self.num_heads_decoder = num_heads_encoder, num_heads_decoder
        self.num_encoders, self.num_decoders = num_encoders, num_decoders

        # Regularisation and optimisation.
        self.dropout, self.learning_rate = dropout, learning_rate

In [52]:
import lightning.pytorch as pl
from llvm_ml.torch.nn import MCEmbedding, MCGraphEncoder
from torch_geometric.utils import to_dense_batch, to_dense_adj
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim import lr_scheduler
import math

class MCLatencyTransformer(pl.LightningModule):
    """Encoder/decoder model that predicts one latency value per graph.

    Node opcodes are embedded, densified per graph, passed through a stack of
    ``MCGraphEncoder`` layers and then through a stack of ``MCLatencyDecoder``
    layers. A linear regression head produces one scalar per node slot; the
    training target is compared against the sum of those scalars.

    Args:
        config: Hyper-parameters; see ``MCNNConfig``.
    """

    def __init__(self, config: MCNNConfig):
        super().__init__()

        self.config = config

        self.embedding = MCEmbedding(config.num_opcodes, config.embedding_size)
        self.encoders = nn.ModuleList(
            [
                MCGraphEncoder(config.embedding_size, config.hidden_size, config.num_heads_encoder, config.dropout)
                for _ in range(config.num_encoders)
            ]
        )

        # TODO add positional encoding?

        self.decoders = nn.ModuleList(
            [
                MCLatencyDecoder(config.embedding_size, config.num_heads_decoder, config.dropout)
                for _ in range(config.num_decoders)
            ]
        )

        self.fc = nn.Linear(config.embedding_size, config.embedding_size)
        # Not used by forward() at the moment; kept so the module tree is unchanged.
        self.softmax = nn.Softmax(dim=1)
        # The activations between these layers are currently disabled, so the
        # head is a composition of three affine maps.
        self.regression = nn.Sequential(
            nn.Linear(config.embedding_size, config.embedding_size * 4),
            #nn.ReLU(),
            nn.Linear(config.embedding_size * 4, config.embedding_size),
            #nn.ReLU(),
            nn.Linear(config.embedding_size, 1)
        )

        self.apply(self._init_weights)
        # Scaled init for residual output projections named ``c_proj``.
        # MCNNConfig has no ``n_layer`` attribute, so the depth is taken from
        # the encoder and decoder counts instead (previously this line would
        # raise AttributeError as soon as a matching parameter existed).
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                num_layers = max(1, config.num_encoders + config.num_decoders)
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * num_layers))

    def forward(self, nodes, edge_index, batch):
        """Compute per-node-slot regression outputs for a batch of graphs.

        Args:
            nodes: Node inputs passed to the embedding layer.
            edge_index: Edge list in torch_geometric format.
            batch: Graph id of every node (torch_geometric batch vector).

        Returns:
            Tensor of shape ``(num_graphs, max_nodes, 1)``. Slots that are only
            padding are set to zero.
        """
        # The embedding also returns a positional encoding, which is not used here.
        embedded, _ = self.embedding(nodes)

        # mask is True for real nodes and False for padding.
        encoded, mask = to_dense_batch(embedded, batch)

        dense_edges = to_dense_adj(edge_index, batch)
        dense_edges = dense_edges.view(encoded.shape[0], encoded.shape[1], encoded.shape[1])

        # NOTE(review): the encoder receives the un-inverted mask while the
        # decoders receive its negation — confirm MCGraphEncoder expects this.
        for encoder in self.encoders:
            encoded = encoder(encoded, dense_edges, mask)

        # key_padding_mask convention: True marks positions to ignore.
        padding_mask = torch.logical_not(mask)

        # Seed the decoder with zeros in the encoder's own dtype/device rather
        # than a hard-coded float16, so the model also runs without AMP.
        decoded_seq = torch.zeros_like(encoded)

        # Refine the whole sequence (max_nodes - 1) times through the decoder stack.
        for _ in range(1, encoded.shape[1]):
            decoded = decoded_seq
            for decoder in self.decoders:
                decoded = decoder(decoded, encoded, padding_mask)

            decoded_seq = self.fc(decoded)

        out = self.regression(decoded_seq)

        # Padded slots carry no instruction; zero them so a graph's summed
        # prediction does not depend on the longest graph in the batch.
        return out.masked_fill(padding_mask.unsqueeze(-1), 0.0)

    def _init_weights(self, module: nn.Module):
        """Initialise Linear and Embedding weights from N(0, 0.02); zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _step(self, batch, stage: str):
        """Shared train/val step: predict, compute the MSE loss and log it.

        Args:
            batch: Tuple ``(bb, raw, mask_id, original_token)``; only ``bb``
                (with ``x``, ``edge_index``, ``batch`` and ``y``) is used for
                the loss.
            stage: ``'train'`` logs as ``train_loss``; anything else as ``val_loss``.

        Returns:
            ``(loss, bb, raw)``.
        """
        bb, raw, mask_id, original_token = batch

        latencies = self.forward(bb.x, bb.edge_index, bb.batch)

        # Drop the trailing singleton dimension: (graphs, slots, 1) -> (graphs, slots).
        latencies = torch.reshape(latencies, shape=(latencies.shape[0], latencies.shape[1]))

        # Slot 0 is excluded from the sum — presumably a start/entry token, TODO confirm.
        y_hat = torch.sum(latencies[:, 1:], dim=1)

        loss = F.mse_loss(y_hat, bb.y)

        log_prefix = "train" if stage == 'train' else "val"

        self.log(f"{log_prefix}_loss", loss, on_epoch=True, batch_size=self.config.batch_size)

        return loss, bb, raw

    def training_step(self, batch, batch_idx):
        """Lightning hook: return the training loss for one batch."""
        loss, _, _ = self._step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        """Lightning hook: return the validation loss for one batch."""
        loss, _, _ = self._step(batch, 'val')
        return loss

    def configure_optimizers(self):
        """Adam with weight decay, plus ReduceLROnPlateau driven by ``val_loss``."""
        optimizer = Adam(self.parameters(), lr=self.config.learning_rate, weight_decay=1e-3)
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.1)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            }
        }

In [38]:
from llvm_ml.torch import BasicBlockDataset
# No sample ids are excluded for this run.
banned_ids = []

# Load the unmasked basic-block dataset from the measurement file.
# NOTE(review): the exact effect of prefilter=True is defined by
# BasicBlockDataset — confirm before relying on the sample count below.
dataset = BasicBlockDataset(
    "./data/ryzen3600_v16.cbuf",
    masked=False,
    banned_ids=banned_ids,
    prefilter=True,
)
print(f"Training with {len(dataset)} samples")


Training with 347988 samples


In [53]:
from torch_geometric.loader import DataLoader
from lightning.pytorch.callbacks import ModelSummary, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
import torch

import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

# Fixed seed for the train/val split so that the two subsets contain the same
# samples after every kernel restart; otherwise validation numbers from
# different runs are not comparable.
SPLIT_SEED = 42

# Overrides of the MCNNConfig defaults for this experiment.
config = MCNNConfig(
    dataset.num_opcodes,
    batch_size=64,
    embedding_size=64,
    hidden_size=128,
    learning_rate=1e-5,
)

# 70 % of the samples for training, the remainder for validation.
num_training = int(0.7 * len(dataset))
num_val = len(dataset) - num_training

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [num_training, num_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=6, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=6, drop_last=True)

model = MCLatencyTransformer(config)

logger = TensorBoardLogger("runs", name="transformer")
# NOTE(review): no example input is supplied here — confirm a graph is actually written.
logger.log_graph(model)
callbacks = [
    ModelSummary(max_depth=-1),
    LearningRateMonitor(),
]
# overfit_batches=1 trains on a single batch, which is a debugging setup.
trainer = pl.Trainer(max_epochs=50,
                     logger=logger,
                     precision='16-mixed',
                     callbacks=callbacks,
                     #fast_dev_run=True,
                     overfit_batches=1,
                     log_every_n_steps=1,
                     )
trainer.fit(model, train_loader, val_loader)


Using 16bit Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

    | Name                             | Type                            | Params
---------------------------------------------------------------------------------------
0   | embedding                        | MCEmbedding                     | 1.3 M 
1   | embedding.embedding              | Embedding                       | 1.3 M 
2   | embedding.pos_encoding           | PositionalEncoding              | 0     
3   | embedding.norm                   | LayerNorm                       | 128   
4   | encoders   

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/alex/anaconda3/envs/cpu-uarch-prediction-py11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:268: You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
