In [1]:
import torch_geometric as tg
from llvm_ml.utils import load_dataset
import numpy as np
import torch


class MCMLMDataset(tg.data.Dataset):
    def __init__(self, path, context_length=2):
        super().__init__()
        dataset = load_dataset(path, True, False)
        
        self.context_length = 2
        
        self.graphs = []
        self.masked_ids = []
        
        for piece in dataset:
            if len(piece.nodes) < (context_length * 2 + 1):
                continue

            nodes = np.zeros(len(piece.nodes), dtype=np.int_)
            
            for idx, n in enumerate(piece.nodes):
                nodes[idx] = n.opcode

            edges = np.zeros((len(piece.edges), 2), dtype=np.int_)

            for idx, e in enumerate(piece.edges):
                edges[idx, 0] = e.from_node
                edges[idx, 1] = e.to_node

            for idx in range(context_length, len(piece.nodes) - context_length):
                self.masked_ids.append(idx)
                self.graphs.append(tg.data.Data(x=torch.from_numpy(nodes), edge_index=torch.from_numpy(np.transpose(edges)).contiguous(), y=torch.tensor(piece.nodes[idx].opcode)))

    def len(self) -> int:
        return len(self.graphs)

    def get(self, idx: int):
        return self.graphs[idx], self.masked_ids[idx]


In [2]:
dataset = MCMLMDataset("./data/ryzen3600_v8.cbuf")
print(f"Training with {len(dataset)} examples")

Training with 728026 examples


In [3]:
import pytorch_lightning as pl
from torch.nn import Embedding, Linear
from torch_geometric.nn import GraphConv
from torch_geometric.utils import to_dense_batch
import torch.nn.functional as F
from torchmetrics import Accuracy

class MCMLM(pl.LightningModule):
    def __init__(self, hidden_size, num_opcodes, batch_size, embedding_dim=32, context_size=2, learning_rate=0.002):
        super(MCMLM, self).__init__()
        
        self.batch_size = batch_size
        self.context_size = context_size
        self.num_opcodes = num_opcodes
        self.lr = learning_rate

        self.embedding = Embedding(num_embeddings=num_opcodes, embedding_dim=embedding_dim)
        self.conv = GraphConv(embedding_dim, hidden_size)
        
        self.fc = Linear(hidden_size, hidden_size)
        self.decode = Linear(hidden_size, num_opcodes)

        self.val_accuracy = Accuracy(task="binary")
        self.train_accuracy = Accuracy(task="binary")
        
        
    def forward(self, data, masked_id):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        
        x = self.embedding(x)
        
        x = self.conv(x, edge_index)
        nodes, _ = to_dense_batch(x, batch)

        x = torch.zeros(self.batch_size, self.context_size * 2 + 1, nodes.shape[2], device=x.device, dtype=x.dtype)
        
        for b in range(self.batch_size):
            for idx, val in enumerate(range(-self.context_size, self.context_size + 1)):
                if val == 0:
                    x[b, idx, :] = 0
                else:
                    x[b, idx, :] = nodes[b, masked_id[b] + val, :]
        
        x = F.gelu(x)
        x = self.fc(x)
        x = F.gelu(x)
        x = self.decode(x)
        x = x[:, self.context_size, :]
        x = F.sigmoid(x)
        
        x = x.reshape(self.batch_size, self.num_opcodes)
        
        return x
    
    def training_step(self, batch, batch_idx):
        data, masked_ids = batch
        
        y_hat = self(data, masked_ids)
        
        target = torch.zeros(self.batch_size, self.num_opcodes, dtype=y_hat.dtype)
        for idx in masked_ids:
            target[:, idx] = 1

        target = target.to(y_hat.device)
        
        loss = F.binary_cross_entropy(y_hat, target)

        self.log("train_loss", loss, on_epoch=True, batch_size=self.batch_size)
        self.train_accuracy(y_hat, target)
        self.log("train_accuracy", self.train_accuracy)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        data, masked_ids = batch

        y_hat = self(data, masked_ids)

        target = torch.zeros(self.batch_size, self.num_opcodes, dtype=y_hat.dtype)
        for idx in masked_ids:
            target[:, idx] = 1

        target = target.to(y_hat.device)
        loss = F.binary_cross_entropy(y_hat, target)

        self.log("val_loss", loss, on_epoch=True, batch_size=self.batch_size)
        self.val_accuracy(y_hat, target)
        self.log("val_accuracy", self.val_accuracy, on_epoch=True, batch_size=self.batch_size)

        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5, verbose=True, min_lr=1e-6, cooldown=5)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            }
        }

In [4]:
from torch_geometric.loader import DataLoader
import torch.utils.data
from lightning.pytorch.loggers import TensorBoardLogger

batch_size = 512
hidden_size = 128

num_training = int(0.7 * len(dataset))
num_val = len(dataset) - num_training

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [num_training, num_val])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=6, drop_last=True)

model = MCMLM(hidden_size, 21000, batch_size)

logger = TensorBoardLogger("runs", name="mcmlm")
logger.log_graph(model)
trainer = pl.Trainer(max_epochs=100, logger=logger, fast_dev_run=False)
trainer.fit(model, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2023-08-27 09:53:16.015845: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type           | Params
--------------------------------------------------
0 | embedding      | Embedding      | 672 K 
1 | conv           | GraphConv      | 8.3 K 
2 | fc             | Linear         | 16.5 K
3 | decode         | Linear         | 2.7 M 
4 | val_accuracy   | BinaryAccuracy | 0     
5 | train_accuracy | BinaryAccuracy | 0     
--------------------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.4 M     Total pa

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
