In [1]:
import numpy as np
import pandas as pd
from transformers import AutoTokenizer
import torch
import json
from model.longformer_tfidf import LongformerTFIDFForSequenceClassification

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_path = "models/tfidf/longformer-large-seed100" # longformer T

max_length = 4096
truncation_side = 'left'
lower = True

def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=max_length)

device = torch.device('cuda:1')

tokenizer = AutoTokenizer.from_pretrained(model_path, truncation_side=truncation_side)
model = LongformerTFIDFForSequenceClassification.from_pretrained(model_path, 
                                                            num_labels=2).to(device)
model.eval()

EnsembleLongformerForSequenceClassification(
  (longformer): LongformerModel(
    (embeddings): LongformerEmbeddings(
      (word_embeddings): Embedding(50265, 1024, padding_idx=1)
      (position_embeddings): Embedding(4098, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): LongformerEncoder(
      (layer): ModuleList(
        (0): LongformerLayer(
          (attention): LongformerAttention(
            (self): LongformerSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (query_global): Linear(in_features=1024, out_features=1024, bias=True)
              (key_global): Linear(in_features=1024, out_features=1024, bias=Tr

In [3]:
# Load ILDC expert
import glob
import re
files = sorted(glob.glob('data/ILDC/ILDC_expert/source/*.txt'))
len(files)

ori_texts = [open(file).read() for file in files]
texts = [' '.join(text.split('\n')) for text in ori_texts]
texts = [re.sub('\s+', ' ', text).strip() for text in texts]

if lower:
    texts = [text.lower() for text in texts]

In [6]:
from tqdm import tqdm
import numpy as np
import pickle as pkl
from data_collator.data_collator_tfidf import DataCollatorTFIDF

vectorizer = pkl.load(open('tfidf_vectorizer-threshold350.pkl', 'rb'))
data_collator = DataCollatorTFIDF(tokenizer=tokenizer)

def process_input(texts):
    features = []
    for text in texts:
        feature = tokenizer(text, truncation=True, max_length=max_length)
        tfidf_vector = np.array(vectorizer.transform([text]).todense())[0]
        feature['tfidf_feature'] = tfidf_vector
        features.append(feature)
    batch = data_collator(features)
    return batch

bs = 2

predicts = []
for i in tqdm(range(0, len(texts), bs)):
    batch = texts[i:i+bs]
    batch = process_input(batch)

    for k in batch:
        batch[k] = batch[k].to(device)

    with torch.no_grad():
        outputs = model(**batch)
        logits = outputs.logits.cpu().numpy()
        predict = logits.argmax(axis=1)
        predicts.extend(predict)

  tensor = as_tensor(value)
100%|█████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:07<00:00,  3.50it/s]


In [9]:
def get_prediction_probability(texts):
    batch = process_input(texts)
    for k in batch:
        batch[k] = batch[k].to(device)
    with torch.no_grad():
        outputs = model(**batch)
        logits = outputs.logits.cpu().numpy()
        
    probs = (np.exp(logits).T / np.exp(logits).sum(-1)).T
    return probs

import re

from nltk.tokenize.punkt import PunktSentenceTokenizer

def custom_tokenizer_nltk(text, return_offsets_mapping=True):
    sentences = []
    offset_mapping = []
    
    for start, end in PunktSentenceTokenizer().span_tokenize(text):
        length = end - start
        sentences.append(text[start:end])
        offset_mapping.append((start, end))
        
    out = {
        "input_ids": sentences,
        "offset_mapping": offset_mapping
    }
    
    return out

def get_end_of_text(text):
    text = re.sub(" decision \?\?$", "", text)
    last_tokens = tokenizer.encode(text, padding=True, max_length=4096, truncation=True)
    end_of_text = tokenizer.decode(last_tokens)[3:-4] # remove <s> </s>
    return end_of_text



In [10]:
def get_explanation(text, k='40%'):
    end_of_text = get_end_of_text(text)
    sentence_score = []

    masks = custom_tokenizer_nltk(text)
    sentences = masks['input_ids']
    offset_mappings = masks['offset_mapping']

    original_doc_logits = get_prediction_probability([end_of_text])[0]
    predict_label = original_doc_logits.argmax(axis=-1)
    original_doc_score = original_doc_logits[predict_label]

    for sentence, offset in zip(sentences, offset_mappings):
        masked_doc = end_of_text[:offset[0]] + end_of_text[offset[1]:]
        masked_doc_score = get_prediction_probability([masked_doc])[0, predict_label]

        important_score = original_doc_score - masked_doc_score

        sentence_score.append((sentence, important_score))
    
    sentence_score = sorted(sentence_score, key=lambda x: x[1], reverse=True)
    sorted_sentences = [x[0] for x in sentence_score]
    
    if k == '40%':
        topk = int(len(sentences) * 0.4)
    elif k == '10':
        topk = 10
        
    explain = ' '.join(sorted_sentences[:topk]).strip()
    explain = re.sub('\s+', ' ', explain)
    return explain

explanations = []

for text in tqdm(texts):
    explanation = get_explanation(text, k='40%')
    explanations.append(explanation)

100%|█████████████████████████████████████████████████████████████████████████████████████| 56/56 [15:44<00:00, 16.87s/it]


In [11]:
occ_exp = {}
for explain, file in zip(explanations, files):
    filename = file.split('/')[-1]
    occ_exp[filename] = explain

In [12]:
# This code is inherited from https://github.com/Exploration-Lab/CJPE

import nltk
from nltk.tokenize import word_tokenize 
from rouge import Rouge 
import nltk.translate
from tqdm import tqdm
import numpy as np

def get_BLEU_score(ref_text, machine_text):
    tok_ref_text = word_tokenize(ref_text)
    tok_machine_text = word_tokenize(machine_text)
    sc = nltk.translate.bleu_score.sentence_bleu([tok_ref_text], tok_machine_text, weights = (0.5,0.5))
    return sc

def jaccard_similarity(query, document):
    query = word_tokenize(query)
    document = word_tokenize(document)
    intersection = set(query).intersection(set(document))
    union = set(query).union(set(document))
    if(len(union)==0):
        return 0
    return len(intersection)/len(union)

def overlap_coefficient_min(query, document):
    query = word_tokenize(query)
    document = word_tokenize(document)
    intersection = set(query).intersection(set(document))
    den = min(len(set(query)),len(set(document)))
    if(den==0):
        return 0
    return len(intersection)/den

def overlap_coefficient_max(query, document):
    query = word_tokenize(query)
    document = word_tokenize(document)
    intersection = set(query).intersection(set(document))
    den = max(len(set(query)),len(set(document)))
    if(den==0):
        return 0
    return len(intersection)/den

def occ_result_maker(Rank_initial, Rank_final, occ_exp, gold_exp):
    rouge1 = []
    rouge2 = []
    rougel = []
    jaccard = []
    bleu = []
    meteor = []
    overlap_min = []
    overlap_max = []
    
    files = list(gold_exp.keys())
    
    for u in range(5):
        user = "User " + str(u+1)
        r1 = []
        r2 = []
        rl = []
        jacc = []
        bl = []
        met = []
        omin = []
        omax = []
        
        for i in tqdm(range(len(files))):
            f = files[i]
            ref_text = ""
            for rank in range(Rank_initial, Rank_final+1, 1):
                if(gold_exp[f][user]["exp"]["Rank" + str(rank)]!=""):
                    ref_text += gold_exp[f][user]["exp"]["Rank" + str(rank)] + " "
                
            machine_text = occ_exp[f]
            machine_text = machine_text.lower()
            ref_text = ref_text.lower()
            
            if(ref_text == ""):
                continue
            rouge = Rouge()
            score = rouge.get_scores(machine_text, ref_text)
            r1.append(score[0]['rouge-1']['f'])
            r2.append(score[0]['rouge-2']['f'])
            rl.append(score[0]['rouge-l']['f'])
            jacc.append(jaccard_similarity(ref_text, machine_text))
            omin.append(overlap_coefficient_min(ref_text, machine_text))
            omax.append(overlap_coefficient_max(ref_text, machine_text))
            bl.append(get_BLEU_score(ref_text, machine_text))
            
#             print('===', ref_text)
#             print('===', machine_text)
            met.append(nltk.translate.meteor_score.meteor_score([ref_text.split()], machine_text.split()))
            
        rouge1.append(np.mean(r1))
        rouge2.append(np.mean(r2))
        rougel.append(np.mean(rl))
        jaccard.append(np.mean(jacc))
        overlap_min.append(np.mean(omin))
        overlap_max.append(np.mean(omax))
        bleu.append(np.mean(bl))
        meteor.append(np.mean(met))
        
    print("ROUGE-1 : {:}".format(rouge1) + "\n\n")
    print("ROUGE-2 : {:}".format(rouge2) + "\n\n")
    print("ROUGE-L : {:}".format(rougel)+ "\n\n")
    print("Jaccard : {:}".format(jaccard)+ "\n\n")
    print("Overmin : {:}".format(overlap_min)+ "\n\n")
    print("Overmax : {:}".format(overlap_max)+ "\n\n")
    print("BLEU    : {:}".format(bleu)+ "\n\n")
    print("METEOR  : {:}".format(meteor)+ "\n\n") 
            

In [13]:
import json
gold_exp = json.load(open('data/ILDC/gold_explanations_ranked.json'))
occ_result_maker(1, 10, occ_exp, gold_exp)

100%|█████████████████████████████████████████████████████████████████████████████████████| 56/56 [01:57<00:00,  2.10s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 56/56 [01:23<00:00,  1.49s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 56/56 [02:42<00:00,  2.91s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 56/56 [03:10<00:00,  3.39s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 56/56 [01:38<00:00,  1.75s/it]

ROUGE-1 : [0.6728244041642905, 0.5804562352169429, 0.7065788679140699, 0.7206439637166981, 0.5968851828412346]


ROUGE-2 : [0.5537667911485151, 0.4391302967808964, 0.6015698579312018, 0.6145012886825254, 0.460719620769056]


ROUGE-L : [0.6601222702297712, 0.5541566417051895,  0.6975489384977261, 0.7152230833905514, 0.575994705052335]


Jaccard : [0.5223500810569234, 0.4242793494297573, 0.55669154901328, 0.5445209117877944, 0.4379944088056159]


Overmin : [0.7922623850463363, 0.6996248986608471, 0.8773273640989261, 0.8887593940818513, 0.6909408185732746]


Overmax : [0.6059111123933351, 0.5219418803133252, 0.6062007716768416, 0.58740837277975, 0.5413169060827409]


BLEU    : [0.469689580174378, 0.4612571751883142, 0.4155969855778494, 0.394920598035904, 0.4673805104372826]


METEOR  : [0.39135295051274294, 0.45025931411401393, 0.3510144751694292, 0.3392467454474154, 0.435055087932992]





