In [None]:
from transformers import AutoModelForMaskedLM
from attack_classification import USE
import transformers
import math
import language_tool_python


In [None]:
import os

# Blank both proxy variables for this process (and anything it spawns).
# NOTE(review): presumably so model / LanguageTool connections bypass a
# configured proxy — confirm.
os.environ.update(http_proxy='', https_proxy='')

class GPT2LM:
    """Sentence-level GPT-2 perplexity scorer.

    An instance is callable: ``lm(sent)`` returns ``exp(loss)``, where
    ``loss`` is the language-modelling loss the model reports when the
    sentence's own token ids are passed as ``labels``.
    """

    def __init__(self, cuda=-1, model_resolution = './GPT2'):
        """
        :param int cuda: CUDA device index to move the model to; a negative
            value leaves the model where ``from_pretrained`` loaded it.
        :param str model_resolution: Path (or model id) handed to
            ``from_pretrained`` for both the tokenizer and the model.
        """
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(model_resolution)
        self.lm = transformers.GPT2LMHeadModel.from_pretrained(model_resolution)
        self.cuda = cuda
        if self.cuda >= 0:
            self.lm.cuda(self.cuda)

    def __call__(self, sent):
        """
        :param str sent: A sentence; it is lower-cased before tokenization.
        :return: Fluency (ppl), i.e. ``exp`` of the model's LM loss.
        :rtype: float
        :raises ValueError: if the sentence yields fewer than two tokens, so
            there is no next-token target to score.
        """
        import torch  # local import: the file does not import torch at module level

        ipt = self.tokenizer(sent.lower(), return_tensors="pt", verbose=False)
        # A causal-LM loss needs at least one (context, next token) pair;
        # reject instead of returning a meaningless perplexity.
        if ipt["input_ids"].shape[-1] < 2:
            raise ValueError(f"need at least 2 tokens to compute perplexity: {sent!r}")

        if self.cuda >= 0:
            ipt = {k: v.cuda(self.cuda) for k, v in ipt.items()}

        # Inference only: don't build an autograd graph for every sentence.
        with torch.no_grad():
            loss = self.lm(**ipt, labels=ipt["input_ids"])[0]
        return math.exp(loss.item())
    
class GrammarChecker:
    """Callable that counts LanguageTool (en-US) issues in a sentence."""

    def __init__(self):
        # Alternative backends previously tried (kept for reference):
        #   language_tool_python.LanguageToolPublicAPI('en-US')
        #   language_tool_python.LanguageTool('en-US', remote_server="http://localhost:8081")
        self.lang_tool = language_tool_python.LanguageTool('en-US')

    def __call__(self, sentence):
        """Return the number of matches LanguageTool reports for the lower-cased sentence."""
        issues = self.lang_tool.check(sentence.lower())
        return len(issues)
    

# Shared scorers used by the evaluation loop further down.
# GPT-2 perplexity scorer on CUDA device 0, loaded from the default './GPT2' path.
gpt = GPT2LM(cuda=0)
# LanguageTool (en-US) grammar-error counter.
grammarchecker = GrammarChecker()
# Sentence-pair similarity scorer, later called as use(ori, adv).
# NOTE(review): the [0] argument is presumably a list of GPU ids — confirm
# against USE's signature in attack_classification.
use = USE([0])



In [None]:
from tqdm import tqdm
def read_log(path):
    ori, adv = [], []
    with open(path) as f:
        lines = f.readlines()
        for line in lines:
            line = ' '.join(line.split()[:512])
            if line.startswith('orig sent '):
                ori.append(line[15:].replace("- ", "").replace("n \' t","n\'t").replace("' d", "'d").replace("\' s","\'s"))
            elif line.startswith("adv sent "):
                adv.append(line[14:].replace("- ", "").replace("n \' t","n\'t").replace("' d", "'d").replace("\' s","\'s"))
    print(f"ORI{len(ori)}, ADV{len(adv)}")
    return ori, adv

def eval_PPL(oris, advs, metric, uses=None, limit=0.0):
    """Score each original/adversarial sentence with *metric* and summarise.

    :param oris: original sentences.
    :param advs: adversarial sentences, aligned with ``oris`` by index.
    :param metric: callable mapping a single sentence to a number (e.g. a
        perplexity scorer or a grammar-error counter).
    :param uses: optional per-pair scores aligned with ``oris``; pairs whose
        score is below ``limit`` are skipped.
    :param limit: threshold applied to ``uses``.
    :return: dict with the mean original score ``'ori'``, the mean
        adversarial score ``'adv'``, the mean difference (adv - ori)
        ``'delta'``, and ``'delta%'``, that difference relative to ``'ori'``
        as a fixed-point percentage string. Means over zero scored pairs are
        NaN, and ``'delta%'`` is ``'nan%'`` when the ``'ori'`` mean is 0 or NaN.
    """
    print("Eval with" + str(metric))
    ori_PPLs, adv_PPLs, delta_PPLs = [], [], []

    def mean(L):
        # NaN instead of ZeroDivisionError when every pair was filtered/failed.
        return sum(L) / len(L) if L else float('nan')

    for i in tqdm(range(len(oris))):
        if uses is not None and uses[i] < limit:
            continue
        ori, adv = oris[i], advs[i]
        try:
            ori_PPL = metric(ori)
            adv_PPL = metric(adv)
        except Exception as result:
            # Best effort: report and skip pairs the metric cannot score.
            print(result, ori, adv)
            continue
        ori_PPLs.append(ori_PPL)
        adv_PPLs.append(adv_PPL)
        delta_PPLs.append(adv_PPL - ori_PPL)

    ori_mean = mean(ori_PPLs)
    delta_mean = mean(delta_PPLs)
    # A zero baseline (e.g. no grammar errors in any original) would
    # otherwise raise ZeroDivisionError.
    pct = delta_mean / ori_mean * 100 if ori_mean != 0 else float('nan')

    return {
        'ori': ori_mean,
        'adv': mean(adv_PPLs),
        'delta': delta_mean,
        # Fixed-point: slicing str(float) mangled scientific-notation values.
        'delta%': f"{pct:.4f}%",
        }

def eval_score(oris, advs, metric):
    """Apply a pairwise *metric* to each (original, adversarial) sentence pair.

    :param oris: original sentences.
    :param advs: adversarial sentences, aligned with ``oris`` by index.
    :param metric: callable ``metric(ori, adv)`` returning a number.
    :return: ``(mean, scores)``. ``scores`` holds the per-pair results in
        input order, and ``mean`` is their average (NaN when ``oris`` is
        empty).
    """
    print("Eval with" + str(metric))
    scores = [metric(oris[i], advs[i]) for i in tqdm(range(len(oris)))]
    # NaN instead of ZeroDivisionError for an empty input.
    avg = sum(scores) / len(scores) if scores else float('nan')
    return avg, scores

def dict_putout(d):
    """Flatten nested dicts into a single level, joining keys with ``'_'``.

    Non-dict values are kept as-is; nested dict values are flattened
    recursively and their keys prefixed with the parent key. Insertion order
    is preserved, and a new dict is always returned.

    >>> dict_putout({'a': 1, 'b': {'c': 2, 'd': {'e': 3}}})
    {'a': 1, 'b_c': 2, 'b_d_e': 3}
    """
    flat = {}
    for key, value in d.items():
        if isinstance(value, dict):
            for sub_key, sub_value in dict_putout(value).items():
                # f-string so non-str keys at nested levels don't raise TypeError.
                flat[f"{key}_{sub_key}"] = sub_value
        else:
            flat[key] = value
    return flat

# Per-log results: log path -> flat dict of metric summaries.
eval_datas = {}


# Attack logs to evaluate.
# NOTE(review): '....' appears to be a placeholder — fill in real result files.
filenames = [
        '/home/ubuntu/pth/adv_results/....',
    ]

for file in filenames:
    # Only paths containing '.' or 'adversaries' are evaluated.
    # NOTE(review): intent of this filter isn't evident here — confirm.
    if  ('.' in file) or ('adversaries' in file):
        ori, adv = read_log(file)
        eval_datas[file] = {}
        
        # Grammar-error counts over all pairs (no similarity filter).
        eval_datas[file]['GrammarError'] = eval_PPL(ori, adv, grammarchecker)
        # Mean pair similarity, plus per-pair scores reused as the PPL filter.
        eval_datas[file]['USE'], uses = eval_score(ori, adv, use)
        # Perplexity only over pairs whose similarity score is >= 0.7.
        eval_datas[file]['PPL'] = eval_PPL(ori, adv, gpt, uses, 0.7)
        
        # Flatten e.g. {'PPL': {'ori': ...}} into {'PPL_ori': ...}.
        eval_datas[file] = dict_putout(eval_datas[file])
        print(file, eval_datas[file])
