In [1]:
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Dataset, Data, InMemoryDataset
import torch_geometric.nn as nn

In [9]:
import os
# Semantic node labels that may appear in the raw ``.ver`` attribute lists.  They are
# all lower-case, so they never collide with the upper-case opcode names below.
labels = [
    'addr_contract', 'caller', 'msgvalue', 'balance', 'call_data', 'blk', 'mdata',
    'sdata', 'create', 'call', 'callcode', 'delegatecall', 'create2', 'staticcall',
    'cal_res', 'comp_res', 'bit_res', 'size', 'code', 'gas', 'return', 'coinbase',
    'gasremain', 'revert', 'selfdestruct', 'memory', 'storage', 'flowcontrol',
]

# Node-type vocabulary (EVM-style opcode mnemonics).  Order matters: the position of a
# name fixes its one-hot column in the node feature matrix.
node_types = (
    [
        'ADDRESS', 'ORIGIN', 'CALLER', 'CALLVALUE', 'BALANCE', 'SELFBALANCE',
        'CALLDATALOAD', 'CALLDATACOPY', 'BLOCKHASH', 'TIMESTAMP', 'NUMBER',
        'DIFFICULTY', 'BASEFEE', 'MLOAD', 'SLOAD', 'CREATE', 'CALL', 'CALLCODE',
        'DELEGATECALL', 'CREATE2', 'STATICCALL', 'ADD', 'MUL', 'SUB', 'EXP', 'LT',
        'GT', 'SLT', 'SGT', 'EQ', 'ISZERO', 'AND', 'OR', 'XOR', 'NOT', 'SHL',
        'CALLDATASIZE', 'CODESIZE', 'EXTCODESIZE', 'RETURNDATASIZE', 'MSIZE',
        'CODECOPY', 'EXTCODECOPY', 'EXTCODEHASH', 'GASPRICE', 'GASLIMIT',
        'RETURNDATACOPY', 'RETURN', 'COINBASE', 'GAS', 'REVERT', 'SELFDESTRUCT',
        'MSTORE', 'MSTORE8', 'SSTORE', 'JUMP', 'JUMPI', 'JUMPDEST', 'STOP', 'DIV',
        'SDIV', 'MOD', 'SMOD', 'ADDMOD', 'SIGNEXTEND', 'BYTE', 'SHR', 'SAR', 'SHA3',
        'CHAINID', 'POP', 'PC',
    ]
    + [f'PUSH{n}' for n in range(1, 33)]
    + [f'DUP{n}' for n in range(1, 17)]
    + [f'SWAP{n}' for n in range(1, 17)]
    # NOTE(review): 'LOGO' (letter O) looks like a typo for the opcode LOG0.  It is kept
    # verbatim because changing it would change which raw names are recognised.
    + ['LOGO', 'LOG1', 'LOG2', 'LOG3', 'LOG4']
    # Generic, width-less forms.
    + ['PUSH', 'DUP', 'SWAP']
)

# Full one-hot vocabulary: opcode columns first, then label columns.
node_attrs = node_types + labels
class MyOwnDataset(Dataset):
    """PyG on-disk dataset of graphs built from ``.ver`` / ``.edg`` / ``.type`` raw files.

    For every graph index ``i`` in ``range(ngraph)`` the raw directory must contain:

    * ``{i}.edg`` -- one edge per line, ``src, dst, kind, value`` (spaces ignored).
      ``[src, dst]`` becomes a column of ``edge_index``; the matching ``edge_attr`` row is
      ``[0, value]`` when ``kind == 'exec'`` and ``[1, value]`` otherwise.
    * ``{i}.ver`` -- one node per line, ``<field0>, <nodeType>, [<attr>, <attr>, ...]``
      (spaces and single quotes ignored; ``field0`` is presumably the node id -- unused
      here).  The node type and each listed attribute are one-hot encoded against the
      module-level ``node_attrs`` list; rows of ``x`` follow file order.
    * ``{i}.type`` -- integer graph label(s); the last non-empty line is used as ``y``.

    Each graph is written with ``torch.save`` to ``{i}.grap`` under ``processed_dir``.
    Only files that are not already present there are built.

    Args:
        root: Dataset root; PyG derives ``raw_dir`` and ``processed_dir`` from it.
        transform: Optional callable that PyG applies on access through ``dataset[idx]``.
        pre_transform: Optional callable applied to each ``Data`` before it is saved.
        ngraph: Number of graphs (indices ``0 .. ngraph - 1``).  Defaults to 44.
    """

    def __init__(self, root, transform=None, pre_transform=None, ngraph=44):
        # Must be set before super().__init__(), which reads raw_file_names /
        # processed_file_names and may immediately run download() / process().
        self.Ngraph = ngraph
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """All required raw files: every ``.ver``, then every ``.edg``, then every ``.type``."""
        indices = range(self.Ngraph)
        return ([f'{i}.ver' for i in indices]
                + [f'{i}.edg' for i in indices]
                + [f'{i}.type' for i in indices])

    @property
    def processed_file_names(self):
        """Processed graph files ``0.grap`` .. ``{Ngraph-1}.grap`` in index order."""
        return [f'{i}.grap' for i in range(self.Ngraph)]

    def download(self):
        """PyG hook run when some raw file is missing.  Nothing can be fetched, so only log.

        This method deliberately does not raise.  When every processed ``.grap`` file
        already exists, the dataset remains usable without its raw files.
        """
        print("in download")

    def process(self):
        """Build every ``.grap`` file missing from ``processed_dir``; keep existing ones."""
        processed_dir = self.processed_dir
        existing = set()
        if os.path.isdir(processed_dir):
            existing = {name for name in os.listdir(processed_dir)
                        if not os.path.isdir(os.path.join(processed_dir, name))}
        # Kept as a public attribute for backward compatibility.
        self.exist_processed_file_names = sorted(existing)

        for f in self.processed_file_names:
            if f in existing:
                continue
            print(f"process new file {f}")
            data = self._process_per_graph(f)
            # PyG contract: pre_transform is applied once, before saving to disk.
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            torch.save(data, os.path.join(processed_dir, f))

    def _process_per_graph(self, f):
        """Parse the raw files behind processed name ``f`` (e.g. ``'3.grap'``) into a ``Data``.

        Returns:
            ``Data(x, edge_index, edge_attr, y)``.  ``x`` is a float tensor of shape
            ``(num_nodes, len(node_attrs))``.  ``edge_index`` is a long tensor of shape
            ``(2, num_edges)``.  ``edge_attr`` is a float tensor of shape
            ``(num_edges, 2)``.  ``y`` is a Python ``int``.

        Raises:
            ValueError: a node names an attribute missing from ``node_attrs``, a field is
                not an integer, or the ``.type`` file holds no label.
        """
        stem = os.path.splitext(f)[0]
        ver_path = os.path.join(self.raw_dir, stem + '.ver')
        edg_path = os.path.join(self.raw_dir, stem + '.edg')
        bug_path = os.path.join(self.raw_dir, stem + '.type')

        edge_index, edge_attr = self._read_edges(edg_path)
        x = self._read_nodes(ver_path)
        y = self._read_label(bug_path)
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

    @staticmethod
    def _data_lines(path):
        """Return the non-blank lines of ``path``, stripped of surrounding whitespace."""
        # strip() (not strip('\n')) also drops '\r', so CRLF files parse the same way.
        with open(path, "r") as fh:
            return [line for line in (raw.strip() for raw in fh) if line]

    def _read_edges(self, path):
        """Parse a ``.edg`` file into ``(edge_index, edge_attr)`` tensors."""
        pairs, attrs = [], []
        for line in self._data_lines(path):
            fields = line.replace(' ', '').split(',')
            pairs.append([int(fields[0]), int(fields[1])])
            # Edge kind: 'exec' -> 0, anything else -> 1; second column is the 4th field.
            attrs.append([0 if fields[2] == 'exec' else 1, int(fields[3])])
        # view(-1, 2) keeps the 2-D shapes even when the file lists no edges.
        edge_index = torch.tensor(pairs, dtype=torch.long).view(-1, 2).t().contiguous()
        edge_attr = torch.tensor(attrs, dtype=torch.float).view(-1, 2)
        return edge_index, edge_attr

    def _read_nodes(self, path):
        """Parse a ``.ver`` file into the one-hot node feature matrix ``x``."""
        width = len(node_attrs)
        attr_index = {}
        for pos, name in enumerate(node_attrs):
            attr_index.setdefault(name, pos)  # first occurrence wins, like list.index
        # NOTE(review): node_types spells the opcode LOG0 as 'LOGO'.  Also accept the
        # standard spelling in raw data; it maps onto the same feature column.
        if 'LOGO' in attr_index:
            attr_index.setdefault('LOG0', attr_index['LOGO'])

        rows = []
        for line in self._data_lines(path):
            line = line.replace(' ', '').replace('\'', '')
            attr_begin = line.index('[') + 1
            # The second comma-separated field before '[' is the node type.
            node_type = line[:attr_begin - 1].split(',')[1]
            # Text between '[' and the final character (expected to be ']'); drop empties
            # such as the single '' produced by an empty list "[]".
            names = [a for a in line[attr_begin:-1].split(',') if a]
            names.append(node_type)
            row = [0] * width
            for name in names:
                if name not in attr_index:
                    raise ValueError(f"{path}: unknown node attribute {name!r}")
                row[attr_index[name]] = 1
            rows.append(row)
        # view keeps shape (0, width) for a file without nodes.
        return torch.tensor(rows, dtype=torch.float).view(-1, width)

    def _read_label(self, path):
        """Return the integer on the last non-empty line of a ``.type`` file."""
        values = [int(line) for line in self._data_lines(path)]
        if not values:
            raise ValueError(f"{path}: no graph label found")
        return values[-1]

    def len(self):
        """Number of graphs (one per processed file)."""
        return len(self.processed_file_names)

    def get(self, idx):
        """PyG hook behind ``dataset[idx]``: load the ``Data`` saved for graph ``idx``."""
        import inspect  # local: only used to probe torch.load's signature

        path = os.path.join(self.processed_dir, self.processed_file_names[idx])
        kwargs = {}
        # torch>=2.6 defaults to weights_only=True, which cannot unpickle ``Data`` objects.
        # These files are written by process() itself; never point processed_dir at
        # untrusted files, since a full unpickle can execute code.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            kwargs['weights_only'] = False
        return torch.load(path, **kwargs)

    def getitem(self, idx):
        """Backward-compatible alias of ``get`` (does not apply ``transform``)."""
        return self.get(idx)

In [10]:
DATA_ROOT = "../data/smartbugs"  # dataset root; PyG derives raw_dir / processed_dir from it
dst = MyOwnDataset(DATA_ROOT)

process new file 1.grap
process new file 2.grap
process new file 3.grap
process new file 4.grap
process new file 5.grap
process new file 6.grap
process new file 7.grap
process new file 8.grap


Processing...


process new file 9.grap
process new file 10.grap
process new file 11.grap
process new file 12.grap
process new file 13.grap
process new file 14.grap
process new file 15.grap
process new file 16.grap
process new file 17.grap
process new file 18.grap
process new file 19.grap
process new file 20.grap
process new file 21.grap
process new file 22.grap
process new file 23.grap
process new file 24.grap
process new file 25.grap
process new file 26.grap
process new file 27.grap
process new file 28.grap
process new file 29.grap
process new file 30.grap
process new file 31.grap
process new file 32.grap
process new file 33.grap
process new file 34.grap
process new file 35.grap
process new file 36.grap
process new file 37.grap
process new file 38.grap
process new file 39.grap
process new file 40.grap
process new file 41.grap
process new file 42.grap
process new file 43.grap


Done!


In [15]:
last_graph = dst.getitem(43)  # index 43 is the final graph (Ngraph == 44)
last_graph

Data(x=[105, 172], edge_index=[2, 162], edge_attr=[162, 2], y=1)