# jTrans evaluation code

This notebook computes MRR and Recall using code adapted from the jTrans authors. The authors also provide their own evaluation scripts, [fasteval.py](https://github.com/vul337/jTrans/blob/main/fasteval.py) and [eval_save.py](https://github.com/vul337/jTrans/blob/main/eval_save.py). If you have a large GPU (VRAM >= 48 GB) and a slow disk, this notebook will make evaluation faster; otherwise, please use the authors' own code to evaluate the model.

In [None]:
import torch
torch.cuda.get_device_name(0)
from transformers import BertTokenizer, BertForMaskedLM, BertModel
from tokenizer import *
import pickle
from torch.utils.data import DataLoader
import os
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from data import help_tokenize, load_paired_data, FunctionDataset_CL, FunctionDataset_CL_Load
from transformers import AdamW
import torch.nn.functional as F
import argparse
import logging
import sys
import time
import data as data
import pickle
import sys
from datautils_windows.playdata import DatasetBase as DatasetBase
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import argparse

In [None]:
def eval(net, data_loader):
    """Rank each anchor's true match within its batch and tally retrieval hits.

    Every batch is treated as one retrieval pool: row ``k`` of the first
    sequence tensor is the query, and row ``k`` of the second sequence tensor
    is its ground-truth match among all rows of that second tensor. Candidates
    are ordered by cosine similarity of the models' ``pooler_output`` vectors.

    Args:
        net: Model called as ``net(input_ids=..., attention_mask=...)`` whose
            result exposes ``pooler_output`` (one embedding per input row).
            It must also provide ``eval()``.
        data_loader: Iterable yielding 6-tuples
            ``(seq1, seq2, _, mask1, mask2, _)``; the third and sixth items
            are ignored.

    Returns:
        A tuple ``(avg, total_true_positives, recall2, recall5, recall10,
        total_relevant)`` where ``avg`` is a list holding the mean reciprocal
        rank of each batch, ``total_true_positives`` counts queries whose match
        ranked first, ``recall2`` / ``recall5`` / ``recall10`` count queries
        whose match ranked within the top 2 / 5 / 10, and ``total_relevant``
        is the total number of queries seen.
    """
    net.eval()
    with torch.no_grad():
        avg = []
        total_true_positives = 0
        total_relevant = 0

        recall2 = 0
        recall5 = 0
        recall10 = 0

        for seq1, seq2, _, mask1, mask2, _ in tqdm(data_loader):
            anchor = net(input_ids=seq1, attention_mask=mask1).pooler_output
            pos = net(input_ids=seq2, attention_mask=mask2).pooler_output

            ans = 0
            for k in range(len(anchor)):
                # Score query k against the whole pool with one broadcast
                # call instead of one call (and one host sync) per candidate.
                sim = np.array(F.cosine_similarity(anchor[k:k + 1], pos).tolist())
                y = np.argsort(sim)[::-1]  # Candidate indices, most similar first
                posi = np.where(y == k)[0][0] + 1  # 1-based rank of the ground truth

                # Hit counters follow the original jTrans evaluation code.
                if posi == 1:
                    total_true_positives += 1
                if posi <= 2:
                    recall2 += 1
                if posi <= 5:
                    recall5 += 1
                if posi <= 10:
                    recall10 += 1

                ans += 1 / posi

            total_relevant += len(anchor)
            avg.append(ans / len(anchor))

        return avg, total_true_positives, recall2, recall5, recall10, total_relevant

class BinBertModel(BertModel):
    """BertModel whose position embeddings share the word-embedding table.

    After the standard ``BertModel`` construction, the position-embedding
    module is replaced by the word-embedding module itself, so both lookups
    use the same weights.

    Args:
        config: Model configuration passed to ``BertModel``.
        add_pooling_layer: Whether ``BertModel`` should build its pooler.
            Evaluation in this notebook reads ``pooler_output``, so keep the
            default ``True`` for that use.
    """

    def __init__(self, config, add_pooling_layer=True):
        # Forward the flag so it is honoured instead of silently ignored.
        super().__init__(config, add_pooling_layer=add_pooling_layer)
        self.config = config
        # Tie position embeddings to the word embeddings (same module object).
        self.embeddings.position_embeddings = self.embeddings.word_embeddings

In [None]:
model_path = "models/jTrans-finetune"  # Download from author's github
model = BinBertModel.from_pretrained(model_path)

eval_path = "/data/jTrans/some_extract"  # Generated from previous notebook

# nn.DataParallel requires the wrapped module's parameters and buffers to be
# on its first device before forward() runs, so move the model there before
# wrapping it (falls back to CPU when no GPU is visible).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model.to(device))

In [None]:
# Tokenizer loaded from the local jTrans tokenizer directory.
tokenizer = BertTokenizer.from_pretrained("./jtrans_tokenizer")

# NOTE: FunctionDataset_CL_Load might need to be modified to handle
# cross-compiler data or a changed file-name convention; alternatively, use
# the DataBaseCrossCompiler class provided by the author.
valid_set = FunctionDataset_CL_Load(
    tokenizer,
    eval_path,
    convert_jump_addr=True,
    opt=None,
)

# Large shuffled batches: the evaluation ranks each function against the
# other entries of its own batch.
valid_dataloader = DataLoader(
    valid_set,
    batch_size=10000,
    num_workers=128,
    shuffle=True,
)


### Warning

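If you hit a GPU out-of-memory (OOM) error, please use the authors' code to evaluate the model, or run on the CPU if you have enough RAM (this will be slower).

If the cell produces output, use the returned numbers to compute Recall@k and MRR. The returned tuple is `(avg, total_true_positives, recall2, recall5, recall10, total_relevant)`, so Recall@1 = `total_true_positives / total_relevant`, Recall@k = `recallk / total_relevant` for k = 2, 5, 10, and MRR is the mean of the per-batch values in `avg`.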

In [None]:
# Run the evaluation; the cell displays the returned tuple
# (avg, total_true_positives, recall2, recall5, recall10, total_relevant).
eval(net=model, data_loader=valid_dataloader)