In [None]:
# 导入一些有用的模块
import gc
import shutil
import os
from argparse import ArgumentParser

from gensim.models.word2vec import Word2Vec

import configs
import src.data as data
import src.prepare as prepare
import src.process as process
import src.utils.functions.cpg as cpg

In [None]:
# Load the configuration objects that every later cell reads from.
PATHS = configs.Paths()  # directory locations (e.g. PATHS.raw, PATHS.cpg, PATHS.input are used below)
FILES = configs.Files()  # file names used inside those directories (e.g. FILES.raw, FILES.cpg)
DEVICE = FILES.get_device()  # device handle, later passed to the model and the loader steps

In [None]:
# The raw code dataset lives in data/raw (Paths.raw).
# Each record holds: project name, commit_id, target (whether the code is
# vulnerable) and func (the source text of the function).
# There are 27318 records in total; to save time this experiment uses only 200.
# Raise DATA_SELETED to work with more data.
# The selected code is then converted into CPG graphs with the joern tool
# (see https://docs.joern.io/home if you want to learn more about joern).
# The graph data is stored in data/cpg (Paths.cpg), 100 records per pkl file.
DATA_SELETED = 200


def select(dataset):
    """Pick the subset of the raw dataset used in this experiment.

    Keeps the rows whose ``project`` is "FFmpeg" and whose ``func`` text is
    shorter than 1200 characters, then returns at most the first
    DATA_SELETED of them (original order and index preserved).
    """
    ffmpeg_rows = dataset[dataset["project"] == "FFmpeg"]
    short_enough = ffmpeg_rows["func"].str.len() < 1200
    # For now only the first DATA_SELETED records are used.
    return ffmpeg_rows[short_enough].head(DATA_SELETED)

# Build the CPG (code property graph) dataset with joern.
# Start from a clean joern working directory; errors (e.g. the directory not
# existing yet) are ignored.
shutil.rmtree(PATHS.joern, ignore_errors=True)
context = configs.Create()
raw = data.read(PATHS.raw, FILES.raw)
filtered = data.apply_filter(raw, select)
filtered = data.clean(filtered)
# Fail early, with a readable message, when fewer records than requested are left.
assert DATA_SELETED <= len(filtered), (
    f"Expected at least {DATA_SELETED} records after filtering and cleaning, got {len(filtered)}."
)
# Return value is not used — presumably the columns are dropped in place; TODO confirm.
data.drop(filtered, ["commit_id", "project"])
slices = data.slice_frame(filtered, context.slice_size)
# NOTE(review): identity apply; presumably used to get an independent copy of each chunk — confirm.
slices = [(chunk_id, chunk.apply(lambda x: x)) for chunk_id, chunk in slices]

cpg_files = []
# Step 1: write each chunk to source files and let joern parse them into a CPG binary.
# The loop variable is called `chunk` (not `slice`) so the built-in `slice` is not
# overwritten in the kernel namespace.
for chunk_id, chunk in slices:
    data.to_files(chunk, PATHS.joern)
    cpg_file = prepare.joern_parse(context.joern_cli_dir, PATHS.joern, PATHS.cpg, f"{chunk_id}_{FILES.cpg}")
    cpg_files.append(cpg_file)
    print(f"Dataset {chunk_id} to cpg.")
    # Remove the working directory before the next chunk is written.
    shutil.rmtree(PATHS.joern)
# Step 2: create the graph JSON files from the CPG binaries.
json_files = prepare.joern_create(context.joern_cli_dir, PATHS.cpg, PATHS.cpg, cpg_files)
# Step 3: join the parsed graphs with their source rows and store them chunk by chunk.
for (chunk_id, chunk), json_file in zip(slices, json_files):
    graphs = prepare.json_process(PATHS.cpg, json_file)
    if graphs is None:
        print(f"Dataset chunk {chunk_id} not processed.")
        continue
    dataset = data.create_with_index(graphs, ["Index", "cpg"])
    dataset = data.inner_join_by_index(chunk, dataset)
    print(f"Writing cpg dataset chunk {chunk_id}.")
    data.write(dataset, PATHS.cpg, f"{chunk_id}_{FILES.cpg}.pkl")
    # Free the joined chunk before handling the next one.
    del dataset
    gc.collect()

In [None]:
# Train the word2vec model that encodes the text of the graph nodes, and build
# the model input for every stored CPG chunk.

context = configs.Embed()
# Tokenise the source code of each CPG chunk found in PATHS.cpg.
dataset_files = data.get_directory_files(PATHS.cpg)
w2vmodel = Word2Vec(**context.w2v_args)
w2v_init = True  # True only until the first chunk has built the vocabulary
for pkl_file in dataset_files:
    file_name = pkl_file.split(".")[0]
    cpg_dataset = data.load(PATHS.cpg, pkl_file)
    tokens_dataset = data.tokenize(cpg_dataset)
    data.write(tokens_dataset, PATHS.tokens, f"{file_name}_{FILES.tokens}")
    # Let word2vec learn the initial encoding of the node text: the first chunk
    # builds the vocabulary, every later chunk is passed with update=True.
    w2vmodel.build_vocab(sentences=tokens_dataset.tokens, update=not w2v_init)
    w2vmodel.train(tokens_dataset.tokens, total_examples=w2vmodel.corpus_count, epochs=1)
    w2v_init = False
    # Parse the nodes of every CPG graph and keep them in a temporary column.
    cpg_dataset["nodes"] = cpg_dataset.apply(lambda row: cpg.parse_to_nodes(row.cpg, context.nodes_dim), axis=1)
    # Drop the samples without nodes. The explicit copy makes the filtered frame
    # an independent object, so the column assignment below is not a write onto
    # a slice of the loaded frame.
    cpg_dataset = cpg_dataset.loc[cpg_dataset.nodes.map(len) > 0].copy()
    # Encode the nodes with the current word2vec vectors to get the model input.
    cpg_dataset["input"] = cpg_dataset.apply(
        lambda row: prepare.nodes_to_input(row.nodes, row.target, context.nodes_dim,
                                           w2vmodel.wv, context.edge_type),
        axis=1,
    )
    data.drop(cpg_dataset, ["nodes"])
    print(f"Saving input dataset {file_name} with size {len(cpg_dataset)}.")
    data.write(cpg_dataset[["input", "target"]], PATHS.input, f"{file_name}_{FILES.input}")
    # Free the chunk before loading the next one.
    del cpg_dataset
    gc.collect()
print("Saving w2vmodel.")
# Make sure the target directory exists so the save cannot fail on a missing
# folder after all the work above.
if PATHS.w2v:
    os.makedirs(PATHS.w2v, exist_ok=True)
w2vmodel.save(f"{PATHS.w2v}/{FILES.w2v}")

In [None]:
# Train and validate the model, then run prediction on the test split.
context = configs.Process()
devign = configs.Devign()
model_path = PATHS.model + FILES.model
model = process.Devign(
    path=model_path,
    device=DEVICE,
    model=devign.model,
    learning_rate=devign.learning_rate,
    weight_decay=devign.weight_decay,
    loss_lambda=devign.loss_lambda,
)
train = process.Train(model, context.epochs)
input_dataset = data.loads(PATHS.input)
# Split the dataset into train / validation / test parts and wrap each part in a data loader.
train_loader, val_loader, test_loader = [
    part.get_loader(context.batch_size, shuffle=context.shuffle)
    for part in data.train_val_test_split(input_dataset, shuffle=context.shuffle)
]
train_loader_step = process.LoaderStep("Train", train_loader, DEVICE)
val_loader_step = process.LoaderStep("Validation", val_loader, DEVICE)
test_loader_step = process.LoaderStep("Test", test_loader, DEVICE)

# Fit on the training loader while validating on the validation loader, then save the model.
train(train_loader_step, val_loader_step)
model.save()

# Run prediction on the test loader.
process.predict(model, test_loader_step)