In [1]:
import torch
from llvm_ml.torch import BasicBlockDataset
# Empty list: no ids are banned for this run.
banned_ids = []
dataset = BasicBlockDataset(
    "./data/ryzen3600_v14.cbuf",
    masked=True,
    banned_ids=banned_ids,
    prefilter=True,
)
print(f"Training with {len(dataset)} samples")

Training with 465539 samples


In [2]:
import torch.nn as nn
import torch.nn.functional as F

class MCAttentionHead(nn.Module):
    """Multi-head self-attention whose logits are re-weighted by a dense adjacency matrix.

    Attention logits for node pairs marked in ``edge_index`` are multiplied by
    ``1 + edge_weight * edge_index`` before the softmax, so connected pairs get
    sharper attention than unconnected ones (which are multiplied by 1).

    Args:
        in_size: feature size C of the input nodes; must be divisible by ``heads``.
        heads: number of attention heads the C features are split into.
        edge_weight: extra logit multiplier for adjacent pairs. The default 3.0
            reproduces the previously hard-coded factor.

    Raises:
        ValueError: if ``in_size`` is not divisible by ``heads``.
    """

    def __init__(self, in_size, heads=4, edge_weight=3.0):
        super().__init__()
        if in_size % heads != 0:
            raise ValueError(f"in_size ({in_size}) must be divisible by heads ({heads})")

        self.num_heads = heads
        self.head_dim = in_size // heads
        self.edge_weight = edge_weight

        # Creation order kept as key, query, value, proj so seeded initialisation is unchanged.
        self.key = nn.Linear(in_size, in_size, bias=False)
        self.query = nn.Linear(in_size, in_size, bias=False)
        self.value = nn.Linear(in_size, in_size, bias=False)

        self.proj = nn.Linear(in_size, in_size)

    def _split_heads(self, x):
        """Reshape (B, T, C) into (B, heads, T, head_dim)."""
        B, T, _ = x.shape
        return x.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, nodes, edge_index, mask=None):
        """Attend over ``nodes`` using ``edge_index`` to bias the attention logits.

        Args:
            nodes: float tensor of shape (B, T, C) with C == in_size.
            edge_index: dense adjacency of shape (B, T, T) (e.g. the output of
                ``to_dense_adj``), or (T, T) shared by the whole batch.
            mask: optional (B, T) tensor, True/nonzero for real nodes and
                False/0 for padding. Padded nodes are excluded as attention keys.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = nodes.shape

        k = self._split_heads(self.key(nodes))
        q = self._split_heads(self.query(nodes))
        v = self._split_heads(self.value(nodes))

        # Scaled dot-product attention divides by sqrt(head_dim), not sqrt(num_heads).
        weight = torch.matmul(q, k.transpose(-2, -1)) * self.head_dim ** -0.5

        # Derived from edge_index itself so it stays on edge_index's device;
        # torch.ones(shape) would always be allocated on the CPU.
        adj_scale = 1.0 + self.edge_weight * edge_index
        if adj_scale.dim() == 3:
            # (B, T, T) -> (B, 1, T, T): broadcast over heads, not over the batch.
            adj_scale = adj_scale.unsqueeze(1)
        weight = weight * adj_scale

        if mask is not None:
            # Mask padded *keys*: (B, T) -> (B, 1, 1, T). Masking query rows instead
            # fills whole rows with -inf, which softmax turns into NaN.
            key_padding = torch.logical_not(mask.bool())[:, None, None, :]
            weight = weight.masked_fill(key_padding, float('-inf'))

        # TODO: a decoder variant would additionally need a causal
        # (lower-triangular) mask here.
        weight = F.softmax(weight, dim=-1)

        res = torch.matmul(weight, v)
        res = res.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(res)

In [3]:
import torch
import torch_geometric
from torch_geometric.utils import to_dense_adj, to_dense_batch

# Unshuffled loader yielding one graph per batch.
loader = torch_geometric.loader.DataLoader(dataset, batch_size=1, shuffle=False)

# Take the first item of the first dataset sample for inspection.
bb = dataset[0][0]
print(bb)

# Convert the sparse graph into dense (padded) node and adjacency tensors.
dense_nodes, mask = to_dense_batch(bb.x)
edge_index_long = bb.edge_index.to(dtype=torch.int64)
dense_adj = to_dense_adj(edge_index_long)

print(dense_nodes, dense_adj, mask, sep="\n")

Data(x=[11], edge_index=[2, 36], y=4.25)
tensor([[    0,  2838,   520, 21001,  2798,   420,   413,  1790,  1834,  1399,
          1835]], dtype=torch.int32)
tensor([[[0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.],
         [1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0.],
         [1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])
tensor([[True, True, True, True, True, True, True, True, True, True, True]])


In [4]:
# Embedding table size: 21002 rows x 64 features. 21002 is presumably the number
# of distinct node ids (the sample above contains id 21001) — confirm against the dataset.
VOCAB_SIZE = 21002
EMBED_DIM = 64
embedding = torch.nn.Embedding(VOCAB_SIZE, EMBED_DIM)
embedded_nodes = embedding(dense_nodes)

In [5]:
# Single attention module over 64-dim node features (default head count).
head = MCAttentionHead(in_size=64)
print(dense_adj)

tensor([[[0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.],
         [1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0.],
         [1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])


In [47]:
class MCAttention(nn.Module):
    """Run several independent MCAttentionHead modules and concatenate their outputs.

    Args:
        in_size: feature size of the input nodes.
        heads: number of MCAttentionHead modules. The same value is also passed
            to each module as its own internal head count, so ``in_size`` should
            be divisible by it.

    Each module returns ``in_size`` features per node, so the concatenated output
    has ``heads * in_size`` features along the last dimension (no projection back).
    """

    def __init__(self, in_size, heads=6):
        super().__init__()
        modules = []
        for _ in range(heads):
            modules.append(MCAttentionHead(in_size, heads))
        self.heads = nn.ModuleList(modules)

    def forward(self, nodes, edge_index, mask=None):
        """Apply every module to the same inputs and concatenate along the feature axis."""
        outputs = [module(nodes, edge_index, mask) for module in self.heads]
        return torch.cat(outputs, dim=-1)

In [49]:
# Four MCAttentionHead modules over 64-dim features.
attn = MCAttention(in_size=64, heads=4)
attn_out = attn(embedded_nodes, dense_adj, mask)
print(attn_out)

tensor([[[ 0.2542, -0.2073, -0.0994,  ...,  0.4446,  0.1520, -0.0060],
         [ 0.0114, -0.1102,  0.0033,  ...,  0.0111,  0.0599,  0.2841],
         [ 0.2915, -0.0176,  0.2914,  ...,  0.1328,  0.4376,  0.4120],
         ...,
         [-0.3651,  0.2059, -0.0525,  ..., -0.0204, -0.0605,  0.3920],
         [-0.0338,  0.0016,  0.0926,  ...,  0.0396, -0.1833,  0.2443],
         [ 0.1662, -0.0659,  0.1320,  ..., -0.0608, -0.1146,  0.1879]]],
       grad_fn=<CatBackward0>)
