In [1]:
import os,json,random,torch,csv,nltk,math,string,ast
from tqdm import tqdm_notebook
import numpy as np 
from dotted_dict import DottedDict
nltk.data.path.append('D:\\python_pkg_data\\nltk_data')
from nltk.tokenize import word_tokenize
from nltk.tokenize.treebank import TreebankWordDetokenizer
from termcolor import colored
from transformers import AutoTokenizer,pipeline

os.environ['PYTHONHASHSEED'] = str(2020)
os.environ['TRANSFORMERS_CACHE'] = 'D:\\python_pkg_data\\huggingface\\transformers'
np.random.seed(2020)
random.seed(2020)
torch.manual_seed(2020)
torch.cuda.manual_seed_all(2020)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [8]:
## replace certian text with [mask] token
def Make_masked_text(token_text,replace_id,tokenizer):
    
    token_text[replace_id] = tokenizer.mask_token
    mask_text = TreebankWordDetokenizer().detokenize(token_text)
    return mask_text

## return all synonym candidates
def Get_synonyms(mask_text,unmasker,ban_list):
    
    synonym_list = []
    
    for item in unmasker(mask_text):
        synonym = item['token_str'].lower().strip()
        if synonym not in ban_list:
             synonym_list.append(item['token_str'].strip())
    

    return synonym_list

## generate several semi-facutal examples for each document
def Generate_semi_factual(text,unmasker,tokenizer,rationale_span,ban_list,percentage=0.15,num_candidates=5):

    candidates = []
    token_text = word_tokenize(text)
    all_ids = [i for i in range(len(token_text))]
    
    ## ensure the tokens being replaced are not punctuations and rationale tokens
    punc_ids = [idx for idx,token in enumerate(token_text) if token in string.punctuation]
    non_rationale_ids = list(set(all_ids)-set(punc_ids)-set(rationale_span))
    
    ## num of tokens being replaced
    num_replace = math.ceil(percentage*len(non_rationale_ids))
    
    ## synonyms of each position inferenced by language models, to reduce prediction time
    synonyms = {}
   
    
    for i in tqdm_notebook(range(num_candidates)):
        
        token_candidate = token_text.copy()
        
        ## random select non-rationale tokens
        replace_ids = random.sample(non_rationale_ids,num_replace)
           
                
        for replace_id in replace_ids:
            ## if the synonyms of this position have been predicted by the language models before, direct use
            if replace_id in synonyms.keys():
                
                ## random select one word from the synonym list
                if len(synonyms[replace_id])>=1:
                    synonym = random.choice(synonyms[replace_id])
                    token_candidate[replace_id] = synonym
                else:
                    pass
            else:
            ## predict the synonyms given this position
                mask_text = Make_masked_text(token_text.copy(),replace_id,tokenizer)
              
                ## get synonyms by mask-filling prediction, ensure the synonym words are not sentimental words or the original token
                synonyms[replace_id] = Get_synonyms(mask_text,unmasker,ban_list+[token_text[replace_id]])
                
                if len(synonyms[replace_id])>=1:
                
                    synonym = random.choice(synonyms[replace_id])
                      
                    token_candidate[replace_id] = synonym
                    
                else:
                    pass
                
        candidates.append(TreebankWordDetokenizer().detokenize(token_candidate))
        
        
    return candidates


def Random_replacement(text,unmasker,tokenizer,ban_list,percentage=0.15,num_candidates=5):

    candidates = []
    token_text = word_tokenize(text)
    
    punc_ids = [idx for idx,token in enumerate(token_text) if token in string.punctuation]
    non_punc_ids = list(set([i for i in range(len(token_text))])-set(punc_ids))
    
    
    num_replace = math.ceil(0.15*len(token_text))
    synonyms = {}
   
    
    for i in tqdm_notebook(range(num_candidates)):
        ## ensure the tokens being replaced are not punctuations
        token_candidate = token_text.copy()
        
        replace_ids = random.sample(non_punc_ids,num_replace)
           
                
        for replace_id in replace_ids:
            if replace_id in synonyms.keys():
                if len(synonyms[replace_id])>=1:
                    synonym = random.choice(synonyms[replace_id])
                    token_candidate[replace_id] = synonym
                else:
                    pass
            else:
                mask_text = Make_masked_text(token_text.copy(),replace_id,tokenizer)
              
                synonyms[replace_id] = Get_synonyms(mask_text,unmasker,ban_list+[token_text[replace_id]])
                
                if len(synonyms[replace_id])>=1:
                
                    synonym = random.choice(synonyms[replace_id])
                      
                    token_candidate[replace_id] = synonym
                    
                else:
                    pass
                
        candidates.append(TreebankWordDetokenizer().detokenize(token_candidate))
        
        
    return candidates


def Visualise_rationales(original,rationale_spans,rationale_pos,visualise_all=False):
    
    if visualise_all:
        highlighted = []
        for idx,term in enumerate(word_tokenize(original)):
            if idx in rationale_pos:
                highlighted.append(colored(term,'blue'))
            else:
                highlighted.append(term)
            
        return TreebankWordDetokenizer().detokenize(highlighted)
                
    else:
        highlights = []
        for span in rationale_spans:
            highlighted = []
            for idx,term in enumerate(word_tokenize(original)):
                if idx in span:
                    highlighted.append(colored(term,'blue'))
                else:
                    highlighted.append(term)
                    
            highlights.append(TreebankWordDetokenizer().detokenize(highlighted))
        
        return highlights
    


In [15]:
args = {
    'data_dir':'./datasets/IMDb/orig/train.tsv',
    'pos_dict_dir':'./datasets/positive-words.txt',
    'neg_dict_dir':'./datasets/negative-words.txt',
    'pos_labels_dir': './datasets/IMDb/human_labelled/positives.json',
    'neg_labels_dir':'./datasets/IMDb/human_labelled/negatives.json',
    'supplement':'./datasets/IMDb/human_labelled/supplement_rationales.tsv',
    'is_random_replace':False
}

args = DottedDict(args)

# load datasets, positive words and negative words dictionaries

In [10]:
## load original data

data = {} ##  description:{'doc_id': {'text', 'label '}}
          ##              e.g. {'4': {'text': 'Long, boring, blasphemous. Never have I been so glad to see ending credits roll.', 'label': 0}}
with open(args.data_dir,errors='ignore') as file:
    file = csv.reader(file, delimiter="\t")
    for idx,row in enumerate(file):
        if len(row)>0:
            if row[0] == 'Negative':
                data[row[2]] = {'text':row[1],'label':0}
            else:
                data[row[2]] = {'text':row[1],'label':1}


## a list of positive words and negative words uses as a ban list
positive_terms = []            
with open(args.pos_dict_dir,'r') as file:
    positive_terms=file.read().splitlines()
    
negative_terms = []            
with open(args.neg_dict_dir,'r') as file:
    negative_terms=file.read().splitlines()

# load rationale annotations

In [11]:
pos_rationales = json.load(open(args.pos_labels_dir, 'r'))
neg_rationales = json.load(open(args.neg_labels_dir, 'r'))


rationale_spans = {} # rationales spans of each document {'doc_id':[{'rationale1_id': 'start_position','end_position','text'},...]}
                     #                               e.g. {'1006': [{'token_id': 1, 'start_token': 0, 'end_token': 1, 'text': 'SUcks'},

for item in neg_rationales:
    key = list(item.keys())[1]
    index = list(item.keys())[1][9:-4]
    rationale_spans[index] = item[key]
    

for item in pos_rationales:
    key = list(item.keys())[1]
    index = list(item.keys())[1][9:-4]
    rationale_spans[index] = item[key]

rationale_positions = {} # rationales positions of each docuemnt {'doc_id':[pos1, pos2, pos3, ...],...]}
                        #  '1006': [0, 48, 49, 53, 70, 74, 75, 17, 19, 21, 34, 83, 85, 87],

doc_ids = list(rationale_spans.keys())

for doc_id in doc_ids:
    doc_positions = []
    positions = rationale_spans[doc_id]
    for span in positions:
        start = span['start_token']
        end = span['end_token']
        doc_positions = doc_positions +[i for i in range(start,end)]
        
    rationale_positions[doc_id] = doc_positions
    
## add supplement rationale positions
supplement_rationales = {}
with open(args['supplement'],'r') as file:
    file = csv.reader(file, delimiter='\t')
    for idx,row in enumerate(file):
        supplement_rationales[row[0]] = ast.literal_eval(row[1])
        
rationale_positions.update(supplement_rationales)

# visualise rationals of random examples

In [25]:
random_index = random.sample(doc_ids,5)


for doc_id in random_index[-10:]:

    label = 'negative' if data[doc_id] == 0 else 'positive'
    print(f'doc_id:{doc_id}')
    print(f'doc_label:{label}')

    ## select text
    original = data[doc_id]['text']
    token_original = word_tokenize(original)

    rationale_pos = rationale_positions[doc_id]

    highlighted = Visualise_rationales(original,_,rationale_pos,visualise_all=True)
    print(highlighted)
    print('*'*100)

doc_id:21915
doc_label:positive
I think this movie is a [34mvery[0m [34mfunny[0m film and one of the [34mbest[0m 'National Lampoon's' films, it also has a [34mvery[0m [34mcatchy[0m spoof title, which basically sums up what the whole movie is about .... Men In White!!!! The story is a spoof of many films including a Will Smith film, as you might have guessed, 'Men In Black' . I will not give the ending away but it has a [34mvery[0m [34mgood[0m ending in is very funny (Leslie Nielsen style humour) from start to finish, especially the bit near the beginning when thy are in the street collecting the dustbins (Garbage Cans). Also, they have a [34mpretty[0m [34mcool[0m dustbin lorry (Garbage Collecting Truck) in that scene too . The acting is not superb, actually, it is not very good, but that is what makes the film [34mfunny[0m, it is a comedy, loosen up!! I love the story line, partly because it is so far fetched and partly because it is interesting to see how subtle (O

# generate semi-factual examples based on non-rationales

In [13]:
## an example of generating semi-factual data off-line by RoBERTa
unmasker = pipeline('fill-mask', model='roberta-base',top_k=15,device=0)
tokenizer = AutoTokenizer.from_pretrained('roberta-base')

In [22]:
doc_idx = '16740'
ban_list = positive_terms + negative_terms
text = data[doc_idx]['text']
candidates = Generate_semi_factual(text,unmasker,tokenizer,rationale_positions[doc_idx],ban_list,num_candidates=16,percentage=0.05)

print(candidates[:3])

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/16 [00:00<?, ?it/s]

["I found this to be a profoundly amusing dark comedy . Brosnan is genius; as anyone will now testify, he is not to be seen in the bond role . Kinnear was as charismatic and as funny as someone could have been in the film . I don't know if I've laughed as hard during any movie! What an unexpected pleasure! My favourite line would be' I feel like a bangkok hooker on a Sunday - the navy left town' . Brosnan delivered this very un-bond line with such unexpected comedic finesse . I was also very impressed with Hope Davis's performance . It seems the everyone in this movie branched out from their previous work to such a degree that it actually improved the comedy . If you saw the dark and hilarious 'The Weather Man', you will definitely like this . I voted 10.", "I found this to be a profoundly amusing dark satire . Brosnan is genius; as anyone will now testify, he is not to be pigeonholed in the bond role . Kinnear was as charismatic and as funny as anyone could have been in the role . I d

In [16]:
## generating semi-factual/random-replacement data off-line by RoBERTa just 10 for saving time
augmented_data = {}
for doc_id in tqdm_notebook(list(rationale_positions.keys())[:10]):
    ori_text = data[doc_id]['text']
    ori_label = data[doc_id]['label']   
    
    ## using random replacement or not
    if not args.is_random_replace:
        candidates = Generate_semi_factual(ori_text,unmasker,tokenizer,rationale_positions[doc_id],ban_list,percentage=0.05,num_candidates=16)
    else:
        candidates = Random_replacement(ori_text,unmasker,tokenizer,ban_list,percentage=0.05,num_candidates=16)
        
    augmented_data[doc_id] = {'candidates':candidates,'label':ori_label}

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  This is separate from the ipykernel package so we can avoid doing imports until


  0%|          | 0/10 [00:00<?, ?it/s]

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

In [17]:
## save to local

if not args.is_random_replace:
    with open("./datasets/IMDb/semi_factual_augmented_examples.json", "w") as file_name:
        json.dump(augmented_data, file_name)    

else:
    with open("./datasets/IMDb/random_replace_examples.json", "w") as file_name:
        json.dump(augmented_data, file_name)    

In [3]:
## load from local to check
# with open("./datasets/IMDb/semi_factual_augmented_examples.json", "r") as file_name:
#     augmented_data = json.load(file_name)
    
with open("../over_generalisation/step0_LR_5e-6/IMDb_og_human_trainer_2019_25_7/augmented_step1.json", "r") as file_name:
    augmented_data = json.load(file_name)

In [13]:
len(augmented_data['8518']['candidates'])

16

In [16]:
7-8

-1

In [15]:
a[:-3]

[1]

In [2]:
import os
import shutil

for seed in range(2019,2029):

    source_folder = f"../over_generalisation/step0_LR_5e-6/IMDb_og_human_trainer_{seed}_25_7/"
    destination_folder = f"./AL_results/AL_step0_IMDb_trainer_{seed}_25/"

    # fetch all files
    for file_name in os.listdir(source_folder):
        print(file_name)
        if file_name.startswith('checkpoint'):
            pass
        else:
            source = source_folder + file_name
            destination = destination_folder + file_name

            if not os.path.exists(destination_folder):
                os.mkdir(destination_folder)

            if os.path.isfile(source):
                shutil.copy(source, destination)
                print('copied', file_name)
    # construct full file path
    source = source_folder + file_name
    destination = destination_folder + file_name
    # copy only files
    if os.path.isfile(source):
        shutil.copy(source, destination)
        print('copied', file_name)

.ipynb_checkpoints
augmented_step1.json
copied augmented_step1.json
checkpoint-700
false_rationale.txt
copied false_rationale.txt
false_rationles.tsv
copied false_rationles.tsv
keys.txt
copied keys.txt
loggings.json
copied loggings.json
missing_rationales_augmented_step1.json
copied missing_rationales_augmented_step1.json
new_keys.txt
copied new_keys.txt
process_output.txt
copied process_output.txt
Untitled.ipynb
copied Untitled.ipynb
copied Untitled.ipynb
augmented_step1.json
copied augmented_step1.json
checkpoint-100
false_rationale.txt
copied false_rationale.txt
false_rationles.tsv
copied false_rationles.tsv
keys.txt
copied keys.txt
loggings.json
copied loggings.json
missing_rationales_augmented_step1.json
copied missing_rationales_augmented_step1.json
missing_rationales_output.txt
copied missing_rationales_output.txt
new_keys.txt
copied new_keys.txt
process_output.txt
copied process_output.txt
copied process_output.txt
augmented_step1.json
copied augmented_step1.json
checkpoint-200