In [60]:
import pickle

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.functional as F

import os
import glob
import multiprocessing as mp
import pefile

from tqdm.notebook import tqdm
from sklearn.ensemble import RandomForestClassifier
from pathlib import Path

from sklearn.feature_extraction import FeatureHasher
from sklearn.model_selection import StratifiedKFold as KFold
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score

In [33]:
def read_json(path):
    with open(path, "r") as f:
        return json.load(f)
    
def load_pickle(path):
    with open(path, "rb") as f:
        return pickle.load(f)
    
def dump_pickle(vector, path):
    with open(path, "wb") as f:
        pickle.dump(vector, f)

def get_label_table(path):
    table = dict()
    with open(path, "r") as f:
        for line in f.readlines()[1:]:
            md5, label = line.strip().split(",")
            table[md5] = int(label)
    return table

def load_vectors(base_path):
    vectors, labels = [], []
    for path in tqdm(glob.glob(base_path)):
        md5 = os.path.basename(path)[:-4]
        vectors.append(load_pickle(path))
        labels.append(label_table[md5+'.vir'])
    return np.array(vectors), np.array(labels)

word_to_index = load_pickle("word_to_index.pkl")
index_to_word = load_pickle("index_to_word.pkl")

In [95]:
FOLD = 5
SEED = 41

def cross_validation(X, y):
    kf = KFold(n_splits=FOLD, shuffle=True, random_state=SEED)

    accs, precs, recs, f1s = [], [], [], []
    predicts = []
    targets = []
    target_hashs = []
    target_X = []
    for i, (train_idx, test_idx) in enumerate(kf.split(X, y)):
        X_train, X_valid = X[train_idx], X[test_idx]
        y_train, y_valid = y[train_idx], y[test_idx]

        clf = RandomForestClassifier(random_state=SEED, n_jobs=-1)
        clf.fit(X_train, y_train)
        predict = clf.predict(X_valid)

        accs.append(accuracy_score(y_valid, predict))
        precs.append(precision_score(y_valid, predict))
        recs.append(recall_score(y_valid, predict))
        f1s.append(f1_score(y_valid, predict))
    return np.average(accs), np.average(precs), np.average(recs),np.average(f1s)

unknown_function = dict()
def iat_embedding(functions, how):
    n_features = len(embedding_table[0])
    embedded_vector = np.zeros(n_features)
    for function in functions:
        if function in word_to_index:
            index = word_to_index[function]
        else:
            index = word_to_index["<unk>"]
            
        if how == "max":
            embedded_vector = np.maximum(embedded_vector, embedding_table[index])
        else:
            embedded_vector = np.minimum(embedded_vector, embedding_table[index])
        unknown_function[function] = unknown_function.get(function, 0) + 1
    return embedded_vector.tolist()


def iat_feature_hashing(functions, n_features):
    feature_vector = [0] * n_features
    for impstr in functions:
        hash_value = int(hashlib.sha256(impstr.encode()).hexdigest(), 16)
        feature_vector[hash_value & (SIZE_OF_IMPORT_EXPORT - 1)] += 1
    return feature_vector
#     if functions:
#         return FeatureHasher(n_features=n_features, input_type="string").fit_transform(functions).toarray()[0]
#     else:
#         return np.zeros(n_features).tolist()


def load_iat_with_processing(base_path, processing, n_features):
    vectors, labels = [], []
    for path in tqdm(glob.glob(base_path)):
        md5 = os.path.basename(path)[:-4]
        if processing == "embedding":
            vectors.append(iat_embedding(load_pickle(path), how="max"))
        else:
            vectors.append(iat_feature_hashing(load_pickle(path), n_features))
        labels.append(label_table[md5+'.vir'])
    return np.array(vectors), np.array(labels)


def load_header_iat_with_processing(base_path, processing, n_features):
    vectors, labels = [], []
    for header_path, iat_path in tqdm(base_path):
        md5 = os.path.basename(header_path)[:-4]
        header_vector = load_pickle(header_path)
        
        if processing == "embedding":
            iat_vector = iat_embedding(load_pickle(iat_path), how="max")
        else:
            iat_vector = iat_feature_hashing(load_pickle(iat_path), n_features)
        
        vectors.append(header_vector + iat_vector)
        labels.append(label_table[md5+'.vir'])
    return np.array(vectors), np.array(labels)




## Only IAT

In [88]:
label_table = get_label_table("label.csv")
n_features = 200
X_fh, y_fh = load_iat_with_processing("data/FeatureVector/iat_vector/*", "fh", n_features)

print("FH", cross_validation(X_fh, y_fh))

HBox(children=(IntProgress(value=0, max=39998), HTML(value='')))


FH (0.9232210182522815, 0.9383312276829499, 0.9554195804195805, 0.9467962588553792)


In [85]:
embedding_table = torch.load("Pretrained_Apicall_Vector_200.pkl")
n_features = len(embedding_table[0])
X_emb, y_emb = load_iat_with_processing("data/FeatureVector/iat_vector/*", "embedding", n_features)


print(f"Embedding {n_features} maximum", cross_validation(X_emb, y_emb))

HBox(children=(IntProgress(value=0, max=39998), HTML(value='')))


Embedding 200 maximum (0.8828940461307664, 0.8875011041854528, 0.9576223776223778, 0.9212266820076284)


In [86]:
embedding_table = torch.load("Pretrained_Apicall_Vector_100.pkl")
n_features = len(embedding_table[0])
X_emb, y_emb = load_iat_with_processing("data/FeatureVector/iat_vector/*", "embedding", n_features)

print(f"Embedding {n_features} maximum", cross_validation(X_emb, y_emb))

HBox(children=(IntProgress(value=0, max=39998), HTML(value='')))


Embedding 100 maximum (0.8832940930116265, 0.888086084990231, 0.9574475524475524, 0.9214606194066342)


## Header + IAT

In [96]:
label_table = get_label_table("label.csv")
n_features = 200
X_fh, y_fh = load_header_iat_with_processing(list(zip(glob.glob(r"data/FeatureVector/header_vector/*"), glob.glob(r"data/FeatureVector/iat_vector/*"))), "fh", n_features)

print("FH", cross_validation(X_fh, y_fh))

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


FH (0.9644981278909863, 0.9689449359326974, 0.9818181818181818, 0.9753373162568144)


In [100]:
embedding_table = torch.load("Pretrained_Apicall_Vector_200.pkl")
n_features = len(embedding_table[0])
X_emb, y_emb = load_header_iat_with_processing(list(zip(glob.glob(r"data/FeatureVector/header_vector/*"), glob.glob(r"data/FeatureVector/iat_vector/*"))), "embedding", n_features)


print(f"Embedding {n_features} maximum", cross_validation(X_emb, y_emb))

HBox(children=(IntProgress(value=0, max=39998), HTML(value='')))


Embedding 200 maximum (0.9638981403925492, 0.9673699243959369, 0.9826573426573427, 0.9749532462059266)


In [101]:
embedding_table = torch.load("Pretrained_Apicall_Vector_100.pkl")
n_features = len(embedding_table[0])
X_emb, y_emb = load_header_iat_with_processing(list(zip(glob.glob(r"data/FeatureVector/header_vector/*"), glob.glob(r"data/FeatureVector/iat_vector/*"))), "embedding", n_features)


print(f"Embedding {n_features} maximum", cross_validation(X_emb, y_emb))

HBox(children=(IntProgress(value=0, max=39998), HTML(value='')))


Embedding 100 maximum (0.9645481153894236, 0.9683356165706545, 0.9825524475524474, 0.9753905839567109)


In [102]:
unknown_function

{'_cordllmain': 806,
 '??2@yapaxi@z': 2064,
 '_except_handler3': 3336,
 '_c_exit': 1068,
 '_exit': 4278,
 '_xcptfilter': 5202,
 '_cexit': 3372,
 'exit': 5164,
 '_acmdln': 2312,
 '__getmainargs': 3870,
 '_initterm': 9132,
 '__setusermatherr': 4286,
 '_adjust_fdiv': 5018,
 '__p__commode': 3676,
 '_vsnwprintf': 1618,
 '__p__fmode': 4124,
 '__set_app_type': 4732,
 '_controlfp': 3092,
 '??3@yaxpax@z': 2322,
 'entercriticalsection': 27406,
 'loadlibraryexa': 14638,
 'setevent': 12260,
 'gettempfilenamew': 4306,
 'gettickcount': 35996,
 'gettemppathw': 6264,
 'unmapviewoffile': 7016,
 'getmodulehandlew': 16866,
 'closehandle': 44522,
 'createprocessw': 6008,
 'getmodulefilenamew': 14778,
 'getlasterror': 45826,
 'setlasterror': 18466,
 'deletecriticalsection': 26180,
 'lstrcpynw': 4956,
 'createfilew': 13514,
 'getcommandlinew': 11508,
 'setunhandledexceptionfilter': 16964,
 'initializecriticalsection': 21034,
 'createfilemappingw': 2710,
 'queryperformancecounter': 20618,
 'getcurrentthreadi

In [103]:
len(unknown_function)

95633