In [None]:
#########------------------------------MAIN FUNCTIONS FOR GENERATIONS---------------------------#############

In [3]:
#--------------------------------------- Prompt Creators ---------------------------------------
# Please create your own prompts if you want to work with other models.
# See each model's prompting/input template and modify the zero-shot and few-shot functions accordingly.


In [1]:
def instruction_prompt(utterance, context = 'contextual scenario', qualifier = 'toxic'):
    '''
    Build the instruction sentence that asks the model to add a context to an utterance.

    Args:
        utterance :str   the utterance to wrap (quoted with single quotes in the prompt)
        context   :str   what should be added, e.g. 'contextual scenario'
        qualifier :str   the polarity the statement should take on, e.g. 'toxic'
    Returns:
        str of the form "Add <context> to the utterance '<utterance>', such that the statement becomes <qualifier>"
    '''
    pieces = [
        'Add ', context,
        ' to the utterance ', "'", utterance, "'",
        ', such that the statement becomes ', qualifier,
    ]
    return ''.join(pieces)

In [2]:
def prompt_zeroshot_llama2(instruction):
    '''
    Wrap a single instruction in the "### System / ### User / ### Assistant" chat layout (zero-shot).

    Returns the prompt string ending with "### Assistant:" so the model continues as the assistant.
    '''
    header = "### System:\nYou are an assistance which adds context to any given utterance.\n\n"
    return "{}### User: {}\n\n### Assistant:".format(header, instruction)

In [3]:
def prompt_fewshot_llama2(utterance, num_shot, examples,context = 'contextual scenario', qualifier= 'toxic'):
    '''
    Build a few-shot prompt in the "### System / ### User / ### Assistant" chat layout.

    Each example becomes a User turn (instruction_prompt of example[0]) followed by an
    Assistant turn containing example[1]; the prompt ends with the User turn for
    ``utterance`` and an open "### Assistant:" for the model to complete.

    Args:
        utterance :str    any utterance or hate speech
        num_shot  :int    number of shots; must equal len(examples)
        examples  :list   list of [utterance, context] pairs
        context, qualifier :str   forwarded to instruction_prompt
    Raises:
        AssertionError if len(examples) != num_shot
    '''
    assert len(examples) == num_shot, 'Mismatch num-shots and number of examples'
    turns = ["### System:\nYou are an assistance which adds context to any given utterance.\n\n"]
    turns.extend(
        f"### User: {instruction_prompt(shot[0], context=context, qualifier=qualifier)}\n\n### Assistant:{shot[1]}\n\n"
        for shot in examples
    )
    turns.append(f"### User: {instruction_prompt(utterance, context=context, qualifier=qualifier)}\n\n### Assistant:")
    return ''.join(turns)

In [4]:
def prompt_utterance(context):
    '''
    Instruction asking the model to write an utterance for ``context`` (multistage generation).

    Returns "Add utterance to the context '<context>'".
    '''
    return ''.join(("Add utterance to the context '", context, "'"))


def prompt_zeroshot_utterance(instruction):
    '''
    Wrap an utterance-writing instruction in the dialog-writer chat layout (zero-shot).

    Returns the prompt string ending with "### Assistant:".
    '''
    header = "### System:\nYou are a dialog writer, who can write dialog for any given contextual scenario such that the final sentence follows the context\n\n"
    return "{}### User: {}\n\n### Assistant:".format(header, instruction)

def prompt_fewshot_utternace(context, num_shot, examples):
    '''
    Build a few-shot dialog-writer prompt that asks for an utterance following ``context``.

    Args:
        context   :str    the context to write an utterance for
        num_shot  :int    number of shots; must equal len(examples)
        examples  :list   list of [context, utterance] pairs; example[0] goes through
                          prompt_utterance, example[1] is the Assistant answer
    Raises:
        AssertionError if len(examples) != num_shot
    '''
    assert len(examples) == num_shot, 'Mismatch num-shots and number of examples'
    turns = ["### System:\nYou are a dialog writer, who can write dialog for any given contextual scenario such that the final sentence follows the context\n\n"]
    for shot in examples:
        turns.append(f"### User: {prompt_utterance(shot[0])}\n\n### Assistant:{shot[1]}\n\n")
    turns.append(f"### User: {prompt_utterance(context)}\n\n### Assistant:")
    return ''.join(turns)

In [5]:
#--------------------------------------- Error Handling & In-context Example Functions ---------------------------------------
# These functions do some error handling for the generation functions.
# They also pull in-context examples from the Examples folder according to the given polarity.


In [6]:
def fewshot_assertion(num_shot_con1, num_shot_con2, num_shot_utt, in_p, tran_p, out_p, double_polarity=False ):
    '''
    Load the in-context examples for multistage generation and trim each list to its shot count.

    examples_pull returns (con1, utterance, con2), or with double_polarity
    (con1, utterance, double_toxic, double_benign).

    Returns:
        (examples_con1, examples_utt, examples_con2) or, with double_polarity,
        (examples_con1, examples_utt, examples_double_toxic, examples_double_benign)
    Raises:
        AssertionError if a requested shot count exceeds the number of available examples.
    '''
    examples = examples_pull(in_p=in_p, tran_p= tran_p, out_p= out_p, double_polarity=double_polarity)

    # Index 1 holds the utterance examples and index 2 (and 3 in double-polarity mode)
    # the second-stage context examples; the old checks compared con2/utt against the
    # swapped lists and never checked the benign list.
    assert num_shot_con1 <= len(examples[0]), 'AssertionError: num_shot_con1 greater than examples: decrease num_shot_con1 or increase examples in examples file'
    assert num_shot_utt <= len(examples[1]),  'AssertionError: num_shot_utt greater than examples: decrease num_shot_utt or increase examples in examples file'
    for second_stage in examples[2:]:
        assert num_shot_con2 <= len(second_stage), 'AssertionError: num_shot_con2 greater than examples: decrease num_shot_con2 or increase examples in examples file'

    examples_con1 = examples[0][:num_shot_con1]
    examples_utt = examples[1][:num_shot_utt]
    # One trimmed list (con2) or two (double_toxic, double_benign); the previous
    # double-polarity branch returned names that were never assigned (NameError).
    trimmed_second_stage = tuple(second_stage[:num_shot_con2] for second_stage in examples[2:])
    return (examples_con1, examples_utt) + trimmed_second_stage
    
polarity = ['benign', 'toxic', 'neutral']

def examples_pull(in_p, tran_p, out_p, double_polarity = False):
    '''
    Validate the polarities and load every JSON example file named by examples_file.

    Files are read from the relative "Examples/" folder.

    Returns:
        tuple of loaded JSON objects in examples_file order:
        (con1, utterance, con2) or, with double_polarity, (con1, utterance, double_toxic, double_benign)
    Raises:
        AssertionError if any polarity is not in ``polarity``.
    '''
    checks = (
        (in_p, 'Assertion error: Invalid polarity for in_p'),
        (tran_p, 'Assertion error: Invalid polarity for tran_p'),
        (out_p, 'Assertion error: Invalid polarity for out_p '),
    )
    for value, message in checks:
        assert value in polarity, message

    file_names = examples_file(in_p, tran_p, out_p, double_polarity = double_polarity)
    loaded = []
    for file_name in file_names:
        with open(f"Examples/{file_name}") as handle:
            loaded.append(json.load(handle))
    return tuple(loaded)
        
        

    

def examples_file(in_p, tran_p, out_p, double_polarity = False):
    '''
    Map polarities to the example file names used by multistage generation.

    Returns:
        ('<in_p>_<tran_p>_examples.json', 'examples_utterance.json', 'neutral_<out_p>_examples.json'),
        or with double_polarity the first two plus
        'neutral_toxic_examples.json' and 'neutral_benign_examples.json' (out_p is not used then).
    '''
    first_stage_file = '_'.join((in_p, tran_p, 'examples.json'))
    utterance_file = 'examples_utterance.json'
    if double_polarity == True:
        return (first_stage_file, utterance_file,
                'neutral_toxic_examples.json', 'neutral_benign_examples.json')
    return first_stage_file, utterance_file, '_'.join(('neutral', out_p, 'examples.json'))

    
   

In [7]:
def fewshot_assertion_direct(num_shot, in_p, out_p, double_polarity=False ):
    '''
    Load the Direct Augment examples and trim them to ``num_shot``.

    Returns:
        (toxic_examples, benign_examples) when double_polarity is True,
        otherwise the single list of in_p -> out_p examples.
    Raises:
        AssertionError if any example list is shorter than num_shot.
    '''
    # Forward the caller's double_polarity: it used to be hard-coded to False, so the
    # double-polarity branch indexed examples[1] of a one-element list (IndexError).
    examples = examples_pull_direct(in_p, out_p, double_polarity = double_polarity)

    for item in examples:
        assert num_shot <= len(item), 'AssertionError: num_shot greater than examples: decrease num_shot or increase examples in examples file'
    if double_polarity == True:
        return examples[0][:num_shot], examples[1][:num_shot]
    return examples[0][:num_shot]
    
polarity = ['benign', 'toxic', 'neutral']

def examples_pull_direct(in_p, out_p, double_polarity = False):
    '''
    Validate the polarities and load the Direct Augment example file(s) from "Examples/".

    Returns:
        (toxic_examples, benign_examples) tuple when double_polarity is True,
        otherwise a one-element list [in_p -> out_p examples].
    Raises:
        AssertionError if in_p or out_p is not in ``polarity``.
    '''
    assert in_p in polarity, 'Assertion error: Invalid polarity for in_p'
    assert out_p in polarity, 'Assertion error: Invalid polarity for out_p '
    file_names = examples_file_direct(in_p=in_p, out_p=out_p, double_polarity = double_polarity)

    def _load(file_name):
        # Read one JSON example file from the Examples folder.
        with open(f"Examples/{file_name}") as handle:
            return json.load(handle)

    if double_polarity == True:
        return _load(file_names[0]), _load(file_names[1])
    return [_load(file_names[0])]
        
        

    

def examples_file_direct(in_p, out_p, double_polarity = False):
    '''
    Map polarities to the Direct Augment example file names.

    Returns:
        ('<in_p>_toxic_examples.json', '<in_p>_benign_examples.json') tuple with double_polarity,
        otherwise the one-element list ['<in_p>_<out_p>_examples.json'].
    '''
    # Built unconditionally, as before, so a non-str out_p still fails even in double mode.
    single_file = '_'.join((in_p, out_p, 'examples.json'))
    if double_polarity == True:
        return tuple('_'.join((in_p, label, 'examples.json')) for label in ('toxic', 'benign'))
    return [single_file]

    
   

In [8]:
##  For editing utterance: Has not been used in the paper
polarity = ['benign', 'toxic', 'neutral']
def fewshot_assertion_edit(num_shot, in_p, out_p):
    '''
    Load the utterance-editing examples for in_p -> out_p and return the first ``num_shot``.

    Raises:
        AssertionError if fewer than num_shot examples are available.
    '''
    # examples_pull_edit is the loader defined in this section; the previous call to
    # the non-existent examples_pull_direct_edit raised NameError.
    examples = examples_pull_edit(in_p, out_p)

    for item in examples:
        assert num_shot <= len(item), 'AssertionError: num_shot greater than examples: decrease num_shot or increase examples in examples file'
    return examples[0][:num_shot]
        
    


def examples_pull_edit(in_p, out_p):
    '''
    Load the utterance-editing examples JSON for in_p -> out_p from "Examples/".

    Returns:
        one-element list wrapping the loaded JSON.
    '''
    # examples_file_edit is the name builder defined in this section; the previous call
    # to the non-existent examples_file_direct_edit raised NameError.
    file_name = examples_file_edit(in_p=in_p, out_p=out_p)

    with open(f"Examples/{file_name}") as fp:
        loaded = json.load(fp)

    return [loaded]
        
        

    

def examples_file_edit(in_p, out_p):
    '''
    Validate the polarities and return the editing-examples file name '<in_p>_<out_p>_edit_examples.json'.

    Raises:
        AssertionError if in_p or out_p is not in ``polarity``.
    '''
    for value, message in ((in_p, 'Assertion error: Invalid polarity for in_p'),
                           (out_p, 'Assertion error: Invalid polarity for out_p ')):
        assert value in polarity, message
    return '_'.join((in_p, out_p, 'edit_examples.json'))

    
   

In [9]:
#--------------------------------------- Generation Functions ---------------------------------------
# These are the base functions for the generation methods.
# Correspondence to the paper: augment_direct --> Direct Augment
# Correspondence to the paper: augment_multistage --> Multistage Augment
# Correspondence to the paper: augment_niter_multistage --> N-iter Multistage Augment

In [10]:
def direct_augment_generator(utterance,  model, tokenizer, num_shot, examples, bad_words_ids, context = 'contextual scenario', qualifier='toxic', do_sample=True, penalty_alpha=0.6, top_k=8, max_new_tokens=70, repetition_penalty=1.2, device="cuda"):
    '''
    Generate a context for ``utterance`` from a few-shot prompt built by prompt_fewshot_llama2.

    Args:
        utterance      :str    utterance to contextualise
        model          :       transformers causal LM
        tokenizer      :       matching tokenizer
        num_shot       :int    number of examples; must equal len(examples)
        examples       :list   [utterance, context] pairs used as shots
        bad_words_ids  :       forwarded to model.generate(bad_words_ids=...)
        context        :str    forwarded to the prompt builder ("Add <context> ...")
        qualifier      :str    target polarity word for the prompt, e.g. 'toxic' or 'benign'
        do_sample, penalty_alpha, top_k, max_new_tokens, repetition_penalty : forwarded to model.generate
        device         :str    device the tokenized prompt is moved to (default "cuda", as before)
    Returns:
        [utterance, decoded_text] where decoded_text is tokenizer.decode of the first generated
        sequence (for causal LMs this includes the prompt; extract_multistage pulls out the new turn).

    NOTE(review): in transformers, contrastive search (penalty_alpha/top_k) is only used when
    do_sample=False; with the default do_sample=True this is top-k sampling - confirm intent.
    '''
    # Forward the caller's context/qualifier; previously qualifier was hard-coded to 'toxic'
    # and context was dropped, so e.g. the 'benign' double-polarity branch still asked for toxic context.
    prompt = prompt_fewshot_llama2(utterance, num_shot, examples, context=context, qualifier=qualifier)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    tokens = model.generate(**inputs, do_sample=do_sample, penalty_alpha=penalty_alpha, top_k=top_k,
                            max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty,
                            bad_words_ids=bad_words_ids)
    return [utterance, tokenizer.decode(tokens[0])]

In [11]:
def extract_multistage(lis, assertion = True, num_shot = 6):
    '''
    Pull the newly generated assistant turn out of a decoded few-shot generation and clean it.

    The decoded text is split on "### Assistant:". A prompt with ``num_shot`` examples
    contains num_shot + 1 such markers, so the new turn is segment num_shot + 1; any later
    segments (the model continuing the dialogue) are ignored.

    Args:
        lis       :list   [input, decoded_text]; only lis[1] is read
        assertion :bool   if True, print a warning and return 'N/A' when the decoded text has
                          too few segments; if False, fall back to the last available segment
        num_shot  :int    number of in-context examples in the prompt (default 6)
    Returns:
        str: the new turn, cut at the first '<eos>', with '<s>', '</s>', standalone 'eos',
        '<' and '>' removed.
    '''
    import re

    generation = lis[1].split('### Assistant:')
    target = num_shot + 1
    if len(generation) > target:
        extract = generation[target]
    elif assertion == True:
        print('\n model generation was erroneous')
        extract = 'N/A'
    else:
        # Previously `extract` was left unbound here (UnboundLocalError).
        extract = generation[-1]

    text = extract.split('<eos>')[0]
    for special_token in ('<s>', '</s>'):
        text = text.replace(special_token, '')
    # Remove stray "eos" only as a whole word; str.replace('eos', '') also mangled
    # ordinary words such as "videos" -> "vids".
    text = re.sub(r'\beos\b', '', text)
    return text.replace('<', '').replace('>', '')


In [1]:
# Direct Augment Method
# inputs: utterance --> str, model --> transformers model, tokenizer --> transformers tokenizer
# num_shot --> int (number of in-context examples; 6 is used throughout, and changing it requires changing the number of examples in the Examples folder)
# in_p --> str (input polarity, e.g. 'toxic', 'benign'), out_p --> str (desired output polarity)
# bad_words_ids --> list of token ids to block during generation (provided in bag_of_words_ids.json)
# do_sample=True, penalty_alpha=0.6, top_k=8 --> decoding parameters (penalty_alpha/top_k are contrastive-search parameters; transformers only applies contrastive search when do_sample=False)
# double_polarity --> bool (if True, generates both a toxic and a benign context for the given utterance)


def augment_direct(utterance,  model, tokenizer, num_shot, in_p, out_p, bad_words_ids, context = 'contextual scenario', qualifier='toxic', do_sample=True, penalty_alpha=0.6, top_k=8, max_new_tokens=70, repetition_penalty=1.2, double_polarity=False):
    '''
    Direct Augment: generate a context for ``utterance`` in a single few-shot step.

    Args:
        utterance        :str    utterance to contextualise
        model, tokenizer :       transformers causal LM and its tokenizer
        num_shot         :int    number of in-context examples
        in_p, out_p      :str    input / desired output polarity (select the examples file)
        bad_words_ids    :       forwarded to direct_augment_generator
        context          :str    forwarded to direct_augment_generator
        qualifier        :str    polarity forwarded to direct_augment_generator (single-polarity mode)
        do_sample, penalty_alpha, top_k, max_new_tokens, repetition_penalty :
                                 forwarded to direct_augment_generator
        double_polarity  :bool   if True, generate both a 'toxic' and a 'benign' context
    Returns:
        [{'utterance': ..., 'context': ...}] or, with double_polarity,
        [{'utterance': ..., 'context_toxic': ...}, {'utterance': ..., 'context_benign': ...}]

    NOTE(review): extract_multistage locates the new turn assuming 6 in-context examples;
    other num_shot values may yield 'N/A'.
    '''
    examples = fewshot_assertion_direct(num_shot=num_shot, in_p=in_p, out_p=out_p, double_polarity=double_polarity)

    def _generate_context(target_qualifier, shots):
        # One generation + extraction. do_sample, repetition_penalty and context are now
        # forwarded; previously they were dropped and caller values were silently ignored.
        generation = direct_augment_generator(utterance, model=model, tokenizer=tokenizer,
                                              num_shot=num_shot, examples=shots,
                                              bad_words_ids=bad_words_ids, context=context,
                                              qualifier=target_qualifier, do_sample=do_sample,
                                              penalty_alpha=penalty_alpha, top_k=top_k,
                                              max_new_tokens=max_new_tokens,
                                              repetition_penalty=repetition_penalty)
        return extract_multistage(generation)

    if double_polarity == True:
        extract_context_toxic = _generate_context('toxic', examples[0])
        extract_context_benign = _generate_context('benign', examples[1])
        return [{'utterance':utterance,'context_toxic': extract_context_toxic}, {'utterance':utterance,'context_benign': extract_context_benign}]

    extract_context = _generate_context(qualifier, examples)
    return [{'utterance':utterance,'context': extract_context}]

In [29]:
#  Utterance generation functions for Multistage Generations

def generate_utterance_fewshot(context, utterance, model, tokenizer, num_shot, examples, do_sample=True, penalty_alpha=0.6, top_k=8, max_new_tokens=40, repetition_penalty=1.2):
    '''
    Generate a dialog utterance that follows ``context`` using a few-shot prompt.

    ``utterance`` is not used in the body; it is kept for the caller-facing signature.

    Returns:
        [context, decoded_text] where decoded_text is tokenizer.decode of the first generated sequence.
    '''
    encoded = tokenizer(prompt_fewshot_utternace(context, num_shot, examples), return_tensors="pt").to("cuda")
    output_ids = model.generate(**encoded, do_sample=do_sample, penalty_alpha=penalty_alpha, top_k=top_k,
                                max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty)
    return [context, tokenizer.decode(output_ids[0])]


def generate_utterance_zeroshot(context, utterance, model, tokenizer, do_sample=True, penalty_alpha=0.6, top_k=8, max_new_tokens=40, repetition_penalty=1.2):
    '''
    Zero-shot counterpart of generate_utterance_fewshot.

    ``utterance`` is not used in the body; it is kept for the caller-facing signature.

    Returns:
        [context, decoded_text] where decoded_text is tokenizer.decode of the first generated sequence.
    '''
    full_prompt = prompt_zeroshot_utterance(prompt_utterance(context))
    encoded = tokenizer(full_prompt, return_tensors="pt").to("cuda")
    output_ids = model.generate(**encoded, do_sample=do_sample, penalty_alpha=penalty_alpha, top_k=top_k,
                                max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty)
    return [context, tokenizer.decode(output_ids[0])]



In [None]:
# Multistage Augment Method
# inputs: utterance --> str; model_con1, model_con2, model_utt --> transformers models
# tokenizer_con1, tokenizer_con2, tokenizer_utt --> transformers tokenizers
# num_shot_con1, num_shot_con2, num_shot_utt --> int (number of in-context examples per stage, default 6; changing this requires changing the number of examples in the Examples folder)
# in_p --> str (input polarity, e.g. 'toxic', 'benign'), tran_p --> str (polarity of the first-stage context examples), out_p --> str (desired output polarity)
# bad_words_ids --> list of token ids to block during generation (provided in bag_of_words_ids.json)
# do_sample=True, pa_con1=0.6, topk_con1=8, pa_con2=0.6, topk_con2=8, pa_utt=0.6, topk_utt=8 --> decoding parameters (penalty_alpha/top_k are contrastive-search parameters; transformers only applies contrastive search when do_sample=False)
# double_polarity --> bool (if True, generates both a toxic and a benign context for the generated utterance)



def augment_multistage(utterance, model_con1, model_con2, model_utt,
                       tokenizer_con1, tokenizer_con2, tokenizer_utt, bad_words_ids,
                       in_p='neutral', tran_p='benign', out_p='toxic',
                       num_shot_con1=6, num_shot_con2=6, num_shot_utt=6,
                       context='contextual scenario',
                       qualifier_con1='highly positive', qualifier_con2='offensive',
                       do_sample=True, pa_con1=0.6, topk_con1=8, pa_con2=0.6, topk_con2=8, pa_utt=0.6, topk_utt=8,
                       maxtoken_con1=70, maxtoken_con2=70, maxtoken_utt=40, repetition_penalty=1.2,
                       double_polarity=False, multiple_utternace=False, multiple_utternace_no=None):
    '''
    Multistage Augment: context -> new utterance -> context for the new utterance.

    Stage 1: generate a context for ``utterance`` (model_con1, qualifier_con1).
    Stage 2: write a new utterance that follows that context (model_utt).
    Stage 3: generate a context for the new utterance (model_con2, qualifier_con2, or
             both 'toxic' and 'benign' when double_polarity is True).

    Returns:
        multiple_utternace False:
            [{'utterance': ..., 'context': ...}] or, with double_polarity,
            [{'utterance': ..., 'context_toxic': ...}, {'utterance': ..., 'context_benign': ...}]
        multiple_utternace True: dict keyed 'utterance_0', 'utterance_1', ... with values
            [utterance, context] or, with double_polarity,
            [(utterance, context_toxic), (utterance, context_benign)]
    Raises:
        AssertionError if multiple_utternace is set without multiple_utternace_no > 1,
        or if a shot count exceeds the available examples.

    NOTE(review): extract_multistage locates the new turn assuming 6 in-context examples;
    other num_shot_* values may yield 'N/A'.
    '''
    if multiple_utternace == True:
        # Guard None explicitly: `None > 1` raises TypeError instead of the intended message.
        assert multiple_utternace_no is not None and multiple_utternace_no > 1, 'Please add the number of utterance example with multiple_utternace_no '

    examples = fewshot_assertion(num_shot_con1=num_shot_con1, num_shot_con2=num_shot_con2,
                                 num_shot_utt=num_shot_utt, in_p=in_p, tran_p=tran_p,
                                 out_p=out_p, double_polarity=double_polarity)

    def _generate_context(text, model, tokenizer, qualifier, shots, num_shot, penalty_alpha, top_k, max_new_tokens):
        # Generate and extract one context. num_shot is the configured shot count: it was
        # hard-coded to 6, which tripped the prompt builder's shot-count assertion
        # whenever num_shot_* != 6. do_sample/repetition_penalty/context are now forwarded.
        generation = direct_augment_generator(text, model=model, tokenizer=tokenizer,
                                              qualifier=qualifier, context=context,
                                              bad_words_ids=bad_words_ids, num_shot=num_shot,
                                              examples=shots, do_sample=do_sample,
                                              penalty_alpha=penalty_alpha, top_k=top_k,
                                              max_new_tokens=max_new_tokens,
                                              repetition_penalty=repetition_penalty)
        return extract_multistage(generation)

    def _generate_utterance(stage1_context):
        # Stage 2: write an utterance that follows the stage-1 context.
        generation = generate_utterance_fewshot(context=stage1_context, utterance=utterance,
                                                model=model_utt, tokenizer=tokenizer_utt,
                                                examples=examples[1], num_shot=num_shot_utt,
                                                do_sample=do_sample, penalty_alpha=pa_utt,
                                                top_k=topk_utt, max_new_tokens=maxtoken_utt,
                                                repetition_penalty=repetition_penalty)
        return extract_multistage(generation, assertion=True)

    def _second_stage(text):
        # Stage 3 for one utterance: a single context, or (toxic, benign) in double-polarity
        # mode. Each context is extracted from its own generation (the toxic context used
        # to be extracted from the benign generation).
        def _one(qualifier, shots):
            return _generate_context(text, model_con2, tokenizer_con2, qualifier, shots,
                                     num_shot_con2, pa_con2, topk_con2, maxtoken_con2)
        if double_polarity == True:
            return _one('toxic', examples[2]), _one('benign', examples[3])
        return _one(qualifier_con2, examples[2])

    ##  Stage 1: context for the input utterance
    extract_context = _generate_context(utterance, model_con1, tokenizer_con1, qualifier_con1,
                                        examples[0], num_shot_con1, pa_con1, topk_con1, maxtoken_con1)

    if multiple_utternace == True:
        multiple_utt = [_generate_utterance(extract_context) for _ in range(multiple_utternace_no)]
        out = {}
        # enumerate gives a stable position (list.index picked the first duplicate) and the
        # key is built with str(); 'utterance_' + int raised TypeError. Each context is now
        # generated for its own utterance rather than the last one.
        for index, item in enumerate(multiple_utt):
            contexts = _second_stage(item)
            if double_polarity == True:
                out['utterance_' + str(index)] = [(item, contexts[0]), (item, contexts[1])]
            else:
                out['utterance_' + str(index)] = [item, contexts]
        return out

    ##  Stage 2 + 3: one new utterance and its context(s)
    utt_ex = _generate_utterance(extract_context)
    if double_polarity == True:
        extract_context_toxic, extract_context_benign = _second_stage(utt_ex)
        return [{'utterance':utt_ex,'context_toxic': extract_context_toxic}, {'utterance':utt_ex,'context_benign': extract_context_benign} ]

    return [{'utterance':utt_ex,'context': _second_stage(utt_ex)}]





            
            
    
    
    
    
    

In [33]:
# N-iteration Multistage Augment Method
# inputs: n --> int (number of iterations), utterance --> str, model_con1, model_con2, model_utt --> transformers models,
# tokenizer_con1, tokenizer_con2, tokenizer_utt --> transformers tokenizers
# stack_output --> bool (if True, the output contains the intermediate generations of every iteration)
# num_shot_con1, num_shot_con2, num_shot_utt --> int (number of in-context examples per stage, default 6; changing this requires changing the number of examples in the Examples folder)
# in_p --> str (input polarity, e.g. 'toxic', 'benign'), tran_p --> str (polarity of the first-stage context examples), out_p --> str (desired output polarity)
# bad_words_ids --> list of token ids to block during generation (provided in bag_of_words_ids.json)
# do_sample=True, pa_con1=0.6, topk_con1=8, pa_con2=0.6, topk_con2=8, pa_utt=0.6, topk_utt=8 --> decoding parameters (penalty_alpha/top_k are contrastive-search parameters; transformers only applies contrastive search when do_sample=False)
# double_polarity --> bool (if True, generates both a toxic and a benign context for the generated utterance)


def augment_niter_multistage(n, utterance, model_con1, model_con2, model_utt,
                             tokenizer_con1, tokenizer_con2, tokenizer_utt, bad_words_ids, stack_output=False,
                             in_p='neutral', tran_p='benign', out_p='toxic',
                             num_shot_con1=6, num_shot_con2=6, num_shot_utt=6,
                             context='contextual scenario',
                             qualifier_con1='highly positive', qualifier_con2='offensive',
                             do_sample=True, pa_con1=0.6, topk_con1=8, pa_con2=0.6, topk_con2=8, pa_utt=0.6, topk_utt=8,
                             maxtoken_con1=70, maxtoken_con2=70, maxtoken_utt=40, repetition_penalty=1.2, double_polarity=False):
    '''
    N-iter Multistage Augment: run augment_multistage ``n`` times, feeding each round's
    generated utterance into the next round.

    Returns:
        stack_output True: {'epoch1': result_1, ..., 'epoch<n>': result_n}
        otherwise: the augment_multistage result of the final round.
    Raises:
        AssertionError if n < 1 and stack_output is False (there is no round to return).
    '''
    assert n >= 1 or stack_output == True, 'n must be at least 1 when stack_output is False'

    out = {}
    for i in tqdm(range(n)):
        epoch = augment_multistage(utterance, model_con1=model_con1, model_con2=model_con2, model_utt=model_utt,
                                   tokenizer_con1=tokenizer_con1, tokenizer_con2=tokenizer_con2, tokenizer_utt=tokenizer_utt,
                                   bad_words_ids=bad_words_ids,
                                   in_p=in_p, tran_p=tran_p, out_p=out_p,
                                   num_shot_con1=num_shot_con1, num_shot_con2=num_shot_con2, num_shot_utt=num_shot_utt,
                                   context=context,
                                   qualifier_con1=qualifier_con1, qualifier_con2=qualifier_con2,
                                   do_sample=do_sample, pa_con1=pa_con1, topk_con1=topk_con1, pa_con2=pa_con2,
                                   topk_con2=topk_con2, pa_utt=pa_utt, topk_utt=topk_utt,
                                   maxtoken_con1=maxtoken_con1, maxtoken_con2=maxtoken_con2, maxtoken_utt=maxtoken_utt,
                                   repetition_penalty=repetition_penalty, double_polarity=double_polarity)
        # augment_multistage returns a list of dicts in both polarity modes, each carrying the
        # generated 'utterance'. The old code read epoch[0][0] (KeyError) in the double-polarity
        # branch and epoch[0] (a dict, not a str) when stack_output was True.
        utterance = epoch[0]['utterance']
        if stack_output == True:
            out['epoch' + str(i + 1)] = epoch

    return out if stack_output == True else epoch


In [15]:
#########------------------------------ Using the Generation Methods Directly ---------------------------#############

In [None]:
# Please install modules from requirements.txt before proceeding

In [16]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Hugging Face checkpoint used for every generation stage below.
MODEL_ID = "quantumaikr/llama-2-70b-fb16-orca-chat-10k"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# device_map="auto" lets transformers place the weights across the available devices.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [17]:
import json
from tqdm import tqdm

# Token ids passed to model.generate(bad_words_ids=...) by the generation functions.
# NOTE(review): transformers expects a list of token-id lists here - confirm the file's shape.
with open("bag_of_words_ids.json", "r", encoding="utf-8") as f:
    bag_of_words_ids = json.load(f)

# gen_pipeline_augment_direct (further down) uses the global ``bad_words_ids`` as a default
# argument, which is evaluated when that function is defined; without this alias the cell
# raises NameError.
bad_words_ids = bag_of_words_ids

In [20]:
# Please note: we use the same model for Multistage and N-iter Multistage.
# If you want to use multiple models, pass them in the correct order; see the functions augment_multistage and augment_niter_multistage.
# For the Multistage methods you can experiment with stack_output (augment_niter_multistage) or multiple_utternace (augment_multistage; note the argument's spelling). multiple_utternace was not used in the paper due to its high time complexity.
# stack_output --> stack the intermediate utterances and contexts and show them in the output
# multiple_utternace --> generates several alternate intermediate utterances instead of one, and a context for each branched utterance

In [24]:
# Seed utterance for the demos below (the trailing space is part of the string).
utterance = 'I think we should meet at night '

In [27]:
# Direct Augment
generation = augment_direct(
    utterance,
    model=model,
    tokenizer=tokenizer,
    num_shot=6,
    in_p='benign',
    out_p='toxic',
    bad_words_ids=bag_of_words_ids,
    context='contextual scenario',
    qualifier='toxic',
    do_sample=True,
    penalty_alpha=0.6,
    top_k=8,
    max_new_tokens=70,
    repetition_penalty=1.2,
    double_polarity=False,
)
print(generation)

[{'utterance': 'I think we should meet at night ', 'context': 'Imagine being in a workplace where sexual harassment is prevalent. A male coworker says, "I think we should meet at night," implying an unwanted romantic encounter instead of discussing work matters during daytime hours.'}]


In [30]:
# Multistage Augment 
generation = augment_multistage(
    utterance,
    model_con1=model, model_con2=model, model_utt=model,
    tokenizer_con1=tokenizer, tokenizer_con2=tokenizer, tokenizer_utt=tokenizer,
    bad_words_ids=bag_of_words_ids,
    in_p='benign', tran_p='toxic', out_p='benign',
    num_shot_con1=6, num_shot_con2=6, num_shot_utt=6,
    context='contextual scenario',
    qualifier_con1='highly positive', qualifier_con2='offensive',
    do_sample=True,
    pa_con1=0.6, topk_con1=8,
    pa_con2=0.6, topk_con2=8,
    pa_utt=0.6, topk_utt=8,
    maxtoken_con1=70, maxtoken_con2=70, maxtoken_utt=40,
    repetition_penalty=1.2,
    double_polarity=False,
    multiple_utternace=False,
    multiple_utternace_no=None,
)
print(generation)

[{'utterance': ' "Hey guys, why don\'t we plan something fun like a happy hour where we can relax and talk about the latest developments? What do you say?" ', 'context': 'Imagine a group project meeting where team members are feeling stressed out because of constant deadlines and conflicts arising among themselves. One member suggests having a casual gathering outside work hours to bond and share ideas informally. However, another colleague feels threatened by this suggestion since they believe it would lead to exclusion and prefer formal communication methods'}]


In [34]:
# N-iter Multistage Augment 
generation = augment_niter_multistage(
    2,
    utterance,
    model_con1=model, model_con2=model, model_utt=model,
    tokenizer_con1=tokenizer, tokenizer_con2=tokenizer, tokenizer_utt=tokenizer,
    bad_words_ids=bag_of_words_ids,
    stack_output=False,
    in_p='neutral', tran_p='benign', out_p='toxic',
    num_shot_con1=6, num_shot_con2=6, num_shot_utt=6,
    context='contextual scenario',
    qualifier_con1='highly positive', qualifier_con2='offensive',
    do_sample=True,
    pa_con1=0.6, topk_con1=8,
    pa_con2=0.6, topk_con2=8,
    pa_utt=0.6, topk_utt=8,
    maxtoken_con1=70, maxtoken_con2=70, maxtoken_utt=40,
    repetition_penalty=1.2,
    double_polarity=False,
)
print(generation)

100%|████████████████████████████████████████████| 2/2 [03:44<00:00, 112.32s/it]

[{'utterance': ' "Let\'s light up these firecrackers and see how crazy things get!" ', 'context': 'During a political rally, an influential figure encourages supporters to use violence and intimidation tactics against opposing parties while promoting their own agenda, creating chaos and division within the country.'}]





In [None]:
#########--------------------------- Code for using ToxiGen utterances with the generation methods ---------------------------#############

In [None]:
# Extract the indices of data points that pass a filter on toxicity, intent, or framing


In [None]:
def extract_toxic_index(dataset, toxic_lowerbound, toxic_upperbound):
    '''
    Return the positions of rows whose 'toxicity_human' score lies strictly
    between the two bounds (both bounds are exclusive).

    Args:
    dataset            :mapping  anything where dataset['toxicity_human'] is iterable
    toxic_lowerbound   :number   scores must be greater than this
    toxic_upperbound   :number   scores must be less than this
    Returns:
    list of int positions, in dataset order
    '''
    scores = dataset['toxicity_human']
    return [
        position
        for position, score in enumerate(scores)
        if score > toxic_lowerbound and score < toxic_upperbound
    ]



In [None]:
def extract_intent_index(dataset, intent_lowerbound, intent_upperbound):
    '''
    Return the positions of rows whose 'intent' score lies strictly between
    the two bounds (both bounds are exclusive).

    Args:
    dataset            :mapping  anything where dataset['intent'] is iterable
    intent_lowerbound  :number   scores must be greater than this
    intent_upperbound  :number   scores must be less than this
    Returns:
    list of int positions, in dataset order
    '''
    intents = dataset['intent']
    return [
        row
        for row, value in enumerate(intents)
        if value > intent_lowerbound and value < intent_upperbound
    ]

In [None]:
def extract_framing_index(dataset, frame = "disagreement" ):
    '''
    Return the positions of rows whose 'framing' label equals `frame`.

    Args:
    dataset  :mapping  anything where dataset['framing'] is iterable
    frame    :str      label to match exactly (default "disagreement")
    Returns:
    list of int positions, in dataset order
    '''
    labels = dataset['framing']
    return [row for row, label in enumerate(labels) if label == frame]

In [None]:
# Load the dataset
# ToxiGen requires a Hugging Face access token: log in with the next cell before loading

In [None]:
# Interactive Hugging Face login, so the authenticated `load_dataset(...)` call below
# can access the ToxiGen repository.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Load the ToxiGen training split from the Hugging Face Hub (requires the login above).
# Downstream cells read the 'text', 'toxicity_human', 'intent' and 'framing' columns.
# NOTE(review): newer `datasets` releases deprecate `use_auth_token` in favour of `token`;
# switch to `token=True` if the installed version warns about or rejects it.
from datasets import load_dataset
dataset_train = load_dataset('skg/toxigen-data', split='train', use_auth_token=True)

In [None]:
# Pipelines that run each generation method over every selected dataset index

In [None]:
def gen_pipeline_augment_direct(dataset_index, dataset, model, tokenizer, bad_words_ids= bad_words_ids, num_shot=6, 
                                in_p='neutral', out_p='toxic', 
                                context = 'contextual scenario', qualifier='toxic', 
                                do_sample=True, penalty_alpha=0.6, top_k=8, 
                                max_new_tokens=70, repetition_penalty=1.2, 
                                double_polarity=False):
    '''
    Apply `augment_direct` to every datapoint referenced by `dataset_index`.

    Args:
    dataset_index :list   positions into `dataset`, visited in order
    dataset               indexable collection; dataset[key]['text'] is the utterance
    model, tokenizer      forwarded unchanged to `augment_direct`
    bad_words_ids         defaults to the notebook-level `bad_words_ids` as it was
                          when this cell was executed (defaults bind at def time)
    other kwargs          forwarded unchanged to `augment_direct`
    Returns:
    dict  {dataset index: augment_direct output}
    '''
    generations = {}
    total = len(dataset_index)
    for position in tqdm(range(total)):
        index = dataset_index[position]
        generations[index] = augment_direct(
            utterance=dataset[index]['text'],
            model=model,
            tokenizer=tokenizer,
            num_shot=num_shot,
            in_p=in_p,
            out_p=out_p,
            bad_words_ids=bad_words_ids,
            context=context,
            qualifier=qualifier,
            do_sample=do_sample,
            penalty_alpha=penalty_alpha,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            double_polarity=double_polarity,
        )
    return generations

In [None]:
def gen_pipeline_augment_multistage(dataset_index, dataset, model_con1, model_con2, model_utt, 
                                 tokenizer_con1, tokenizer_con2, tokenizer_utt,
                                 in_p='neutral', tran_p ='benign', out_p= 'toxic',
                                 num_shot_con1= 6, num_shot_con2=6, num_shot_utt=6,
                                 context = 'contextual scenario',  
                                 qualifier_con1='highly positive', qualifier_con2='offensive',
                                 do_sample=True, pa_con1=0.6, topk_con1=8, 
                                 pa_con2=0.6, topk_con2=8, pa_utt=0.6, topk_utt=8, 
                                 maxtoken_con1=70, maxtoken_con2 = 70, maxtoken_utt=40, 
                                 repetition_penalty=1.2, double_polarity= False, 
                                 multiple_utternace = False, multiple_utternace_no = None):
    '''
    Run `augment_multistage` on every datapoint referenced by `dataset_index`.

    Args:
    dataset_index :list   positions into `dataset` (e.g. output of extract_toxic_index)
    dataset               indexable collection; dataset[key]['text'] is the utterance
    model_*/tokenizer_*   per-stage model and tokenizer, forwarded unchanged
    pa_*/topk_*/maxtoken_*  per-stage generation settings, forwarded unchanged
    other kwargs          forwarded unchanged to `augment_multistage`
    Returns:
    dict  {dataset index: augment_multistage output}
    '''
    out = {}
    for key in tqdm(dataset_index):
        utterance = dataset[key]['text']
        out[key] = augment_multistage(utterance, model_con1=model_con1, model_con2=model_con2, model_utt=model_utt, 
                                 tokenizer_con1=tokenizer_con1, tokenizer_con2=tokenizer_con2, tokenizer_utt=tokenizer_utt,
                                 in_p=in_p, tran_p =tran_p, out_p= out_p,
                                 num_shot_con1= num_shot_con1, num_shot_con2=num_shot_con2, num_shot_utt=num_shot_utt,
                                 context = context,  
                                 qualifier_con1=qualifier_con1, qualifier_con2=qualifier_con2,
                                 do_sample=do_sample, pa_con1=pa_con1, topk_con1=topk_con1, 
                                 # previously passed pa_con1 here, silently ignoring the caller's pa_con2
                                 pa_con2=pa_con2, topk_con2=topk_con2, pa_utt=pa_utt, topk_utt=topk_utt, 
                                 maxtoken_con1=maxtoken_con1, maxtoken_con2 = maxtoken_con2, maxtoken_utt=maxtoken_utt, 
                                 repetition_penalty=repetition_penalty, double_polarity=double_polarity, 
                                 multiple_utternace = multiple_utternace, multiple_utternace_no = multiple_utternace_no)
    return out

In [None]:
def gen_pipeline_augment_niter_multistage(
    num_iteration,
    dataset_index, 
    dataset, 
    model_con1, 
    model_con2, 
    model_utt, 
    tokenizer_con1, 
    tokenizer_con2, 
    tokenizer_utt, 
    stack_output = False,
    in_p='neutral', 
    tran_p ='benign', 
    out_p= 'toxic',
    num_shot_con1= 6,
    num_shot_con2=6, 
    num_shot_utt=6,
    context = 'contextual scenario',  
    qualifier_con1='highly positive', 
    qualifier_con2='offensive',
    do_sample=True, pa_con1=0.6, 
    topk_con1=8, 
    pa_con2=0.6, 
    topk_con2=8, 
    pa_utt=0.6, 
    topk_utt=8, 
    maxtoken_con1=70, 
    maxtoken_con2 = 70, 
    maxtoken_utt=40, 
    repetition_penalty=1.2, 
    double_polarity= False,
    bad_words_ids=None):
    '''
    Run `augment_niter_multistage` on every datapoint referenced by `dataset_index`.

    Args:
    num_iteration :int    forwarded as the first positional argument of augment_niter_multistage
    dataset_index :list   positions into `dataset` (e.g. output of extract_toxic_index)
    dataset               indexable collection; dataset[key]['text'] is the utterance
    model_*/tokenizer_*   per-stage model and tokenizer, forwarded unchanged
    bad_words_ids         optional; forwarded to augment_niter_multistage only when not None,
                          so omitting it keeps that function's own default
    other kwargs          forwarded unchanged to augment_niter_multistage
    Returns:
    dict  {dataset index: augment_niter_multistage output}
    '''
    # Only forward bad_words_ids when the caller supplied one, so the default call is unchanged.
    extra_kwargs = {} if bad_words_ids is None else {'bad_words_ids': bad_words_ids}
    out = {}
    for key in tqdm(dataset_index):
        utterance = dataset[key]['text']
        out[key] = augment_niter_multistage(num_iteration,utterance, model_con1=model_con1, model_con2=model_con2, model_utt=model_utt, 
                                 tokenizer_con1=tokenizer_con1, tokenizer_con2=tokenizer_con2, tokenizer_utt=tokenizer_utt, stack_output = stack_output,
                                 in_p=in_p, tran_p =tran_p, out_p= out_p,
                                 num_shot_con1= num_shot_con1, num_shot_con2=num_shot_con2, num_shot_utt=num_shot_utt,
                                 context = context,  
                                 qualifier_con1=qualifier_con1, qualifier_con2=qualifier_con2,
                                 do_sample=do_sample, pa_con1=pa_con1, topk_con1=topk_con1, pa_con2=pa_con2, topk_con2=topk_con2, pa_utt=pa_utt, topk_utt=topk_utt, 
                                 maxtoken_con1=maxtoken_con1, maxtoken_con2 = maxtoken_con2, maxtoken_utt=maxtoken_utt, repetition_penalty=repetition_penalty, double_polarity= double_polarity,
                                 **extra_kwargs)
    return out