In [1]:
import torch
import pickle
from sklearn.preprocessing import LabelEncoder
from sctokenizer import CTokenizer

from models import StackLSTM

In [2]:
from collections import defaultdict

In [3]:
from tqdm import tqdm

In [4]:
# Location of the pickled vocabulary that the next cell loads and uses to fit the label encoder.
vocab_path = "vocab.pkl"

In [5]:
# Load the pickled vocabulary and fit the label encoder that maps token strings to integer ids.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only load a
# vocabulary file that comes from a trusted source.
with open(vocab_path, "rb") as input_file:
    vocab = pickle.load(input_file)
le = LabelEncoder()
# NOTE(review): the second fit() below appears to replace whatever the first fit() learned,
# so "<SOC>"/"<EOC>" would only be encodable if they are already part of vocab — confirm
# against the training code before changing either call, since the saved model depends on
# this exact encoding.
le.fit(["<SOC>", "<EOC>"])
le.fit(list(vocab))


In [6]:
def process_and_encode(text):
    """Tokenize C source text and encode the in-vocabulary tokens as a tensor.

    Tokens whose value is not in the module-level ``vocab`` are dropped.

    Args:
        text: source code to tokenize with ``CTokenizer``.

    Returns:
        A pair ``(code, lines)``: ``code`` is a tensor built from the label
        encoder's ids for the kept token values, and ``lines`` is the list of
        ``token.line`` values for those same tokens, in the same order.
    """
    kept_values = []
    kept_lines = []
    # Single pass over the token stream, keeping values and line numbers aligned.
    for tok in CTokenizer().tokenize(text):
        if tok.token_value in vocab:
            kept_values.append(tok.token_value)
            kept_lines.append(tok.line)
    encoded = le.transform(kept_values)
    return torch.tensor(encoded), kept_lines

In [7]:
# Hyperparameters passed to the StackLSTM constructor in the next cell.
# NOTE(review): presumably these have to match the values the "model_7" checkpoint was
# trained with, otherwise loading its state dict would fail — confirm before tuning.
HIDDEN_SIZE_CONTROLLER = 8
EMBED_DIM = 164
HIDDEN_SIZE_STACK = 8

In [8]:
# Build the network with the configured hyperparameters and restore the trained weights.
model = StackLSTM(embedding_size=len(vocab),
                  embedding_dim=EMBED_DIM,
                  hidden_size_controller=HIDDEN_SIZE_CONTROLLER,
                  hidden_size_stack=HIDDEN_SIZE_STACK,
                  batch_size=1,
                  label_encoder=le)
# map_location="cpu" lets a checkpoint that was saved from a GPU be restored on a machine
# without CUDA; load_state_dict then copies the values into the model's own parameters.
model.load_state_dict(torch.load("model_7", map_location="cpu"))
# Switch to inference mode; as the last expression this also displays the module summary.
model.eval()

StackLSTM(
  (embedding): Embedding(10001, 164)
  (controller): LSTMCell(172, 8)
  (output_linear): Linear(in_features=8, out_features=10001, bias=True)
  (softmax): Softmax(dim=None)
  (push_fc): Linear(in_features=8, out_features=1, bias=True)
  (pop_fc): Linear(in_features=8, out_features=1, bias=True)
  (values_fc): Linear(in_features=8, out_features=8, bias=True)
  (classifier): Linear(in_features=10001, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [17]:
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score, balanced_accuracy_score

def compute_metrics(pred, true):
    """Binarize prediction scores at 0.5 and compute classification metrics.

    Args:
        pred: prediction scores, in any form accepted by ``torch.as_tensor``.
        true: ground-truth labels aligned with ``pred``.

    Returns:
        A dict with the keys ``"MCC"``, ``"F1"`` (macro-averaged), ``"Acc"``,
        ``"BAcc"`` and ``"Count"`` (the number of predictions).
    """
    scores = torch.as_tensor(pred)
    # A score strictly above 0.5 becomes class 1, everything else class 0.
    predicted = (scores > 0.5).long().tolist()

    metrics = {}
    metrics["MCC"] = matthews_corrcoef(true, predicted)
    metrics["F1"] = f1_score(true, predicted, average='macro')
    metrics["Acc"] = accuracy_score(true, predicted)
    metrics["BAcc"] = balanced_accuracy_score(true, predicted)
    metrics["Count"] = len(pred)
    return metrics

In [14]:
import os
import glob
from datasets import Dataset

In [28]:
@torch.no_grad()
def eval(dataset_path, dataset_name, max_chars=1_000):
    """Run the model on the short ``.c`` files of one dataset and save the predictions.

    NOTE: this name shadows the builtin ``eval``; it is kept because other cells
    call the function under this name.

    Files are taken from ``../data/test/{dataset_path}/*.c``. Each file name is
    parsed as ``<index>_..._<label>.c``: the text before the first underscore is
    the sample index, and the text between the last underscore and the following
    dot is the integer label.

    Args:
        dataset_path: directory name under ``../data/test`` to read files from.
        dataset_name: sub-path under ``report/prediction/stacklstm`` where the
            predictions are saved; also shown in the progress message.
        max_chars: files with more characters than this are skipped. Defaults
            to 1000, the limit that used to be hard-coded; ``None`` disables
            the limit.

    Side effects:
        Saves a ``Dataset`` with the columns ``index``, ``pred`` and ``true``
        and prints the result of ``compute_metrics`` for the evaluated files.
    """
    print(f"Evaluating {dataset_name}")
    preds = []
    trues = []
    index = []
    # glob returns paths in arbitrary order; sorting keeps the saved row order reproducible.
    paths = sorted(glob.glob(f"../data/test/{dataset_path}/*.c"))
    for path in tqdm(paths):
        # basename copes with the platform's path separator, unlike splitting on "/".
        name_parts = os.path.basename(path).split("_")
        idx = name_parts[0]
        label = int(name_parts[-1].split(".")[0])

        with open(path, "r") as f:
            text = f.read()
        if max_chars is not None and len(text) > max_chars:
            continue
        # The per-token line numbers returned alongside the encoding are not needed here.
        tokens, _ = process_and_encode(text)
        # unsqueeze(1) adds a second axis; presumably the batch axis of size 1 — TODO confirm.
        preds.append(model(tokens.unsqueeze(1)))
        trues.append(label)
        index.append(int(idx))
    predictions = Dataset.from_dict({
        "index": index,
        "pred": preds,
        "true": trues,
    })
    predictions.save_to_disk(os.path.join("report/prediction/stacklstm", dataset_name))
    print(compute_metrics(preds, trues))

In [25]:
from datasets import load_from_disk, concatenate_datasets

In [32]:
@torch.no_grad()
def eval_rest(dataset_path, dataset_name):
    """Evaluate the files that have no saved prediction yet and save the full set.

    Loads the predictions stored under ``report/prediction/stacklstm/{dataset_name}``,
    runs the model on every ``../data/test/{dataset_path}/*.c`` file whose index
    (the text before the first underscore of the file name) is not in that saved
    dataset, and writes the saved rows followed by the new rows to
    ``report/prediction/stacklstmfull/{dataset_name}``.

    Args:
        dataset_path: directory name under ``../data/test`` to read files from.
        dataset_name: sub-path of the saved predictions; also shown in the
            progress message.

    Side effects:
        Saves the concatenated ``Dataset`` and prints ``compute_metrics`` for
        the newly evaluated files only (not for the combined dataset).
    """
    print(f"Evaluating rest of {dataset_name}")
    dataset = load_from_disk(os.path.join("report/prediction/stacklstm", dataset_name))
    indices = set(dataset["index"])
    preds = []
    trues = []
    index = []
    # Parse each index once and keep it next to its path; glob order is arbitrary,
    # so sort to keep the order of the appended rows reproducible.
    remaining_files = []
    for path in sorted(glob.glob(f"../data/test/{dataset_path}/*.c")):
        # basename copes with the platform's path separator, unlike splitting on "/".
        idx = int(os.path.basename(path).split("_")[0])
        if idx not in indices:
            remaining_files.append((path, idx))
    for path, idx in tqdm(remaining_files):
        # The label is the text between the last underscore and the following dot.
        label = int(os.path.basename(path).split("_")[-1].split(".")[0])

        with open(path, "r") as f:
            text = f.read()
        # The per-token line numbers returned alongside the encoding are not needed here.
        tokens, _ = process_and_encode(text)
        # unsqueeze(1) adds a second axis; presumably the batch axis of size 1 — TODO confirm.
        preds.append(model(tokens.unsqueeze(1)))
        trues.append(label)
        index.append(idx)
    rest = Dataset.from_dict({
        "index": index,
        "pred": preds,
        "true": trues,
    })
    combined = concatenate_datasets([dataset, rest])
    combined.save_to_disk(os.path.join("report/prediction/stacklstmfull", dataset_name))
    print(compute_metrics(preds, trues))

In [20]:
# Datasets to evaluate. The evaluation loops unpack each entry as (dataset_path, dataset_name):
# the first element is the directory under ../data/test that is globbed for *.c files,
# the second is the sub-path used for the saved predictions and in the progress messages.
DATASETS = [
    ("test", "test"),
    ("apply_codestyle_Chromium", "perturbed-data/apply_codestyle_Chromium"),
    ("apply_codestyle_Google", "perturbed-data/apply_codestyle_Google"),
    ("apply_codestyle_LLVM", "perturbed-data/apply_codestyle_LLVM"),
    ("apply_codestyle_Mozilla", "perturbed-data/apply_codestyle_Mozilla"),
    ("apply_cobfuscate", "perturbed-data/apply_cobfuscate"),
    ("double_obfuscate", "perturbed-data/double_obfuscate"),
    ("obfuscate_then_style", "perturbed-data/obfuscate_then_style"),
    ("py_obfuscate_then_style", "perturbed-data/py_obfuscate_then_style"),
    ("apply_py_obfuscator", "perturbed-data/apply_py_obfuscator"),
]

In [23]:
# Evaluate every configured dataset in turn; each call prints its metrics and saves its predictions.
for dataset_path, dataset_name in DATASETS:
    eval(dataset_path, dataset_name)

Evaluating test


100%|██████████| 18864/18864 [28:39<00:00, 10.97it/s]


Saving the dataset (0/1 shards):   0%|          | 0/14758 [00:00<?, ? examples/s]

{'MCC': 0.13678984353107346, 'F1': 0.4969692685639592, 'Acc': 0.747594525003388, 'BAcc': 0.6579116650584862, 'Count': 14758}
Evaluating perturbed-data/apply_codestyle_Chromium


100%|██████████| 18864/18864 [18:02<00:00, 17.42it/s]


Saving the dataset (0/1 shards):   0%|          | 0/14578 [00:00<?, ? examples/s]

{'MCC': 0.13187702429382447, 'F1': 0.4947459742013428, 'Acc': 0.7496227191658664, 'BAcc': 0.6549341126529664, 'Count': 14578}
Evaluating perturbed-data/apply_codestyle_Google


100%|██████████| 18864/18864 [17:06<00:00, 18.37it/s]


Saving the dataset (0/1 shards):   0%|          | 0/14677 [00:00<?, ? examples/s]

{'MCC': 0.13144484754144345, 'F1': 0.49424892971011414, 'Acc': 0.7480411528241466, 'BAcc': 0.6542571256837912, 'Count': 14677}
Evaluating perturbed-data/apply_codestyle_LLVM


100%|██████████| 18864/18864 [17:14<00:00, 18.23it/s]


Saving the dataset (0/1 shards):   0%|          | 0/14667 [00:00<?, ? examples/s]

{'MCC': 0.13158947555513492, 'F1': 0.4943429304468957, 'Acc': 0.7484829890229767, 'BAcc': 0.6545723263956544, 'Count': 14667}
Evaluating perturbed-data/apply_codestyle_Mozilla


100%|██████████| 18864/18864 [16:09<00:00, 19.45it/s]


Saving the dataset (0/1 shards):   0%|          | 0/14573 [00:00<?, ? examples/s]

{'MCC': 0.1320639440556975, 'F1': 0.4940070033304318, 'Acc': 0.7489878542510121, 'BAcc': 0.6563161876155679, 'Count': 14573}
Evaluating perturbed-data/apply_cobfuscate


100%|██████████| 18864/18864 [17:11<00:00, 18.29it/s]


Saving the dataset (0/1 shards):   0%|          | 0/9229 [00:00<?, ? examples/s]

{'MCC': 0.004021791959466382, 'F1': 0.4651473778460304, 'Acc': 0.7731065120814823, 'BAcc': 0.5045750405425328, 'Count': 9229}
Evaluating perturbed-data/double_obfuscate


100%|██████████| 18864/18864 [13:15<00:00, 23.71it/s]


Saving the dataset (0/1 shards):   0%|          | 0/5981 [00:00<?, ? examples/s]

{'MCC': 0.046706556629016675, 'F1': 0.4876382601873284, 'Acc': 0.8060524995820098, 'BAcc': 0.551146970344182, 'Count': 5981}
Evaluating perturbed-data/obfuscate_then_style


100%|██████████| 18864/18864 [27:32<00:00, 11.42it/s] 


Saving the dataset (0/1 shards):   0%|          | 0/7664 [00:00<?, ? examples/s]

{'MCC': 0.02462802417191854, 'F1': 0.47221536897740957, 'Acc': 0.7796189979123174, 'BAcc': 0.5284406867231883, 'Count': 7664}
Evaluating perturbed-data/py_obfuscate_then_style


100%|██████████| 18864/18864 [17:35<00:00, 17.87it/s]


Saving the dataset (0/1 shards):   0%|          | 0/13891 [00:00<?, ? examples/s]

{'MCC': 0.10518368032199671, 'F1': 0.5087766763787119, 'Acc': 0.8062054567705709, 'BAcc': 0.6118289169281866, 'Count': 13891}
Evaluating perturbed-data/apply_py_obfuscator


100%|██████████| 18864/18864 [22:19<00:00, 14.09it/s] 


Saving the dataset (0/1 shards):   0%|          | 0/14955 [00:00<?, ? examples/s]

{'MCC': 0.10581656233300865, 'F1': 0.5070603088971157, 'Acc': 0.7924439986626546, 'BAcc': 0.6107366957539835, 'Count': 14955}


In [33]:
# For every configured dataset, evaluate the files that are missing from the saved
# predictions and store the combined result.
for dataset_path, dataset_name in DATASETS:
    eval_rest(dataset_path, dataset_name)

Evaluating rest of test


100%|██████████| 4106/4106 [43:49<00:00,  1.56it/s]   


Saving the dataset (0/1 shards):   0%|          | 0/18864 [00:00<?, ? examples/s]

{'MCC': 0.12369790735738138, 'F1': 0.4525352542622657, 'Acc': 0.5187530443253775, 'BAcc': 0.5939212779973649, 'Count': 4106}
Evaluating rest of perturbed-data/apply_codestyle_Chromium


100%|██████████| 4286/4286 [39:40<00:00,  1.80it/s]   


Saving the dataset (0/1 shards):   0%|          | 0/18864 [00:00<?, ? examples/s]

{'MCC': 0.1269531915829264, 'F1': 0.45607519259178475, 'Acc': 0.5221651889874008, 'BAcc': 0.5958149253731343, 'Count': 4286}
Evaluating rest of perturbed-data/apply_codestyle_Google


100%|██████████| 4187/4187 [38:32<00:00,  1.81it/s]   


Saving the dataset (0/1 shards):   0%|          | 0/18864 [00:00<?, ? examples/s]

{'MCC': 0.12837772211680537, 'F1': 0.4571350630330112, 'Acc': 0.5223310245999523, 'BAcc': 0.5964592198160078, 'Count': 4187}
Evaluating rest of perturbed-data/apply_codestyle_LLVM


100%|██████████| 4197/4197 [37:09<00:00,  1.88it/s]   


Saving the dataset (0/1 shards):   0%|          | 0/18864 [00:00<?, ? examples/s]

{'MCC': 0.12760975225592358, 'F1': 0.4567841689665079, 'Acc': 0.5218012866333095, 'BAcc': 0.5958209659075148, 'Count': 4197}
Evaluating rest of perturbed-data/apply_codestyle_Mozilla


100%|██████████| 4291/4291 [38:30<00:00,  1.86it/s]   


Saving the dataset (0/1 shards):   0%|          | 0/18864 [00:00<?, ? examples/s]

{'MCC': 0.1263577459831757, 'F1': 0.45826695061782285, 'Acc': 0.5248193894197157, 'BAcc': 0.5949166750198018, 'Count': 4291}
Evaluating rest of perturbed-data/apply_cobfuscate


100%|██████████| 9635/9635 [1:04:53<00:00,  2.47it/s] 


Saving the dataset (0/1 shards):   0%|          | 0/18864 [00:00<?, ? examples/s]

{'MCC': 0.012657069295031641, 'F1': 0.4840612633616579, 'Acc': 0.7334717176959004, 'BAcc': 0.5098851388903045, 'Count': 9635}
Evaluating rest of perturbed-data/double_obfuscate


100%|██████████| 12883/12883 [1:33:24<00:00,  2.30it/s]  


Saving the dataset (0/1 shards):   0%|          | 0/18864 [00:00<?, ? examples/s]

{'MCC': 0.03278747105526105, 'F1': 0.4900945580173892, 'Acc': 0.74928199953427, 'BAcc': 0.5271306431489566, 'Count': 12883}
Evaluating rest of perturbed-data/obfuscate_then_style


100%|██████████| 11200/11200 [1:20:47<00:00,  2.31it/s]  


Saving the dataset (0/1 shards):   0%|          | 0/18864 [00:00<?, ? examples/s]

{'MCC': 0.02516744739300999, 'F1': 0.4868057468109849, 'Acc': 0.7370535714285714, 'BAcc': 0.5203822997120893, 'Count': 11200}
Evaluating rest of perturbed-data/py_obfuscate_then_style


100%|██████████| 4973/4973 [40:58<00:00,  2.02it/s]   


Saving the dataset (0/1 shards):   0%|          | 0/18864 [00:00<?, ? examples/s]

{'MCC': 0.10387707808413571, 'F1': 0.4769141368466195, 'Acc': 0.5829479187613111, 'BAcc': 0.5807659072336966, 'Count': 4973}
Evaluating rest of perturbed-data/apply_py_obfuscator


100%|██████████| 3909/3909 [36:07<00:00,  1.80it/s]   


Saving the dataset (0/1 shards):   0%|          | 0/18864 [00:00<?, ? examples/s]

{'MCC': 0.11116461929010228, 'F1': 0.4800397287814373, 'Acc': 0.5748273215656178, 'BAcc': 0.5838610415261586, 'Count': 3909}
