# Analysis and Creation of an Updating Formula for Graph Edges

## Install Required Packages

In [None]:
# Base scientific stack. numpy, pandas and networkx are imported in the
# imports cell of this notebook; scikit-learn and matplotlib are not imported
# here directly (presumably needed by the notebooks loaded with %run - confirm).
!pip install numpy
!pip install pandas
!pip install networkx
!pip install scikit-learn
!pip install matplotlib

In [None]:
!pip install polyjuice_nlp
!pip install torch
!pip install evaluate
!pip install bert_score
!python -m spacy download en_core_web

## Imports

In [115]:
# general imports
import numpy as np
import pandas as pd
import networkx as nx

# Metric-related imports
import torch
from transformers import OpenAIGPTTokenizer, OpenAIGPTLMHeadModel
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from evaluate import load
from pylev import levenshtein as lev_dist

In [116]:
# Execute the companion notebook so its definitions are available in this one.
# NOTE(review): presumably the source of model_init and sent_scoring, which are
# used below but not defined in this notebook - confirm.
%run GPT2_functions.ipynb

In [None]:
# Execute the companion notebook so its definitions are available in this one.
# NOTE(review): presumably the source of the graph/text helpers used below but
# not defined here (get_edits, get_synsets, all_combinations, bipartite_graph,
# minimum_match, external_swaps, ...) - confirm.
%run graph_functions.ipynb

In [118]:
# Load the BERTScore metric once at notebook level; the global name `bertscore`
# is read by generate_model_agnostic_metrics.
bertscore = load("bertscore")

## Edge Updating Function

In [119]:
def update_edges(edges, substitutions, lr, baseline_metric_value, current_metric_value):
    """
    Apply one update step to the weights of the given edges.

    Every weight is shifted by ``lr * (baseline_metric_value - current_metric_value)``
    divided by the number of times the edge was used as a substitution.

    :param edges: an iterable containing weighted edges as (u, v, weight) tuples
    :param substitutions: a dictionary with (u, v) edges as keys, and their substitution occurrence as values
    :param lr: float value, representing the learning rate for the weight updating
    :param baseline_metric_value: a float, representing the baseline evaluation metric value
    :param current_metric_value: a float, representing the current evaluation metric value
    :returns: a list of (u, v, new_weight) tuples. Edges that have no entry in
              ``substitutions`` are reported on stdout and left out of the result;
              edges whose occurrence count is 0 are returned with their weight unchanged.
    """

    # the numerator is the same for every edge, so compute it once
    # (subtracted from the weight: - for minimizing, + for maximizing)
    step = lr * (baseline_metric_value - current_metric_value)

    updated_edges = []
    for (u, v, w) in edges:
        if (u, v) not in substitutions:
            print(
                "Something went wrong during updating of edges' weights: "
                "no substitution count for edge ({}, {})".format(u, v)
            )
            continue

        edge_subs = substitutions[(u, v)]
        if edge_subs == 0:
            # an edge that was never substituted did not contribute to the metric;
            # keep its weight instead of dividing by zero
            updated_edges.append((u, v, w))
            continue

        updated_edges.append((u, v, w - step / edge_subs))

    return updated_edges

## Metric Functions

In [129]:
def generate_model_agnostic_metrics(data, counter_data):
    """
    Compute model-independent quality metrics for a set of counterfactual sentences.

    The first column of each dataframe is read as the list of sentences; both lists
    must have the same length (checked with an assert).

    :param data: dataframe whose first column holds the original sentences
    :param counter_data: dataframe whose first column holds the counterfactual sentences
    :returns: dictionary with the keys 'levenshtein', 'fluency' and 'bertscore'
    """

    originals = [row[0] for row in data.values.tolist()]
    counters = [row[0] for row in counter_data.values.tolist()]

    assert len(originals) == len(counters)

    n_pairs = len(originals)

    # closeness: mean Levenshtein distance between each original and its counterfactual
    lev_total = 0
    for src, cf in zip(originals, counters):
        lev_total += lev_dist(src, cf)
    avg_lev = lev_total / n_pairs

    # fluency: mean absolute gap between the scores of the original and the counterfactual
    # (the first element of sent_scoring's result is used as the sentence score)
    model, tokenizer = model_init()
    fluency_gaps = [
        abs(sent_scoring(model, tokenizer, src)[0] - sent_scoring(model, tokenizer, cf)[0])
        for src, cf in zip(originals, counters)
    ]
    avg_fluency = sum(fluency_gaps) / n_pairs

    # semantic distance: 1 - mean BERTScore F1, so that lower means closer
    bert_result = bertscore.compute(
        predictions=counters,
        references=originals,
        model_type="distilbert-base-uncased",
    )
    avg_bertscore = 1 - sum(bert_result['f1']) / n_pairs

    # all three values are "lower is better"
    return {
        'levenshtein': avg_lev,
        'fluency': avg_fluency,      # |original_fluency - counter_fluency|
        'bertscore': avg_bertscore,
    }

In [121]:
def generate_model_related_metrics(original_p, counter_p):
    """
    Compute metrics that depend on a model's predictions, currently only the flip-rate.

    :param original_p: sequence with the predictions for the original data
    :param counter_p: sequence with the predictions for the counter data
    :returns: dictionary with the key 'flip-rate': the fraction (0..1) of positions
              where the two predictions differ
    """

    # both prediction sequences must describe the same examples
    assert len(original_p) == len(counter_p)

    # every differing pair counts as one flip
    flips = sum(before != after for before, after in zip(original_p, counter_p))

    return {'flip-rate': flips / len(original_p)}

In [122]:
def get_counterfactual_metric(metrics):
    """
    Combine the evaluation metrics of a counterfactual set into a single score.

    The score is the harmonic mean of the 'bertscore' and 'fluency' entries; any
    other entry of the dictionary (e.g. 'levenshtein', 'flip-rate') is ignored.

    :param metrics: dictionary containing at least the keys 'bertscore' and 'fluency'
    :returns: float value, the harmonic mean of the two metrics; 0.0 when either of them is 0
    """

    bert_distance = metrics['bertscore']
    fluency_gap = metrics['fluency']

    # the harmonic mean tends to 0 as any term tends to 0; return that limit
    # instead of raising ZeroDivisionError on 1/0
    if bert_distance == 0 or fluency_gap == 0:
        return 0.0

    return 2 / (1 / bert_distance + 1 / fluency_gap)
    # alternative: harmonic mean over every metric in the dictionary
    #return len(metrics) / sum(1/v for v in metrics.values())

In [123]:
def get_baseline_metric(data, model_required=False, preprocessor=None, model=None, antonyms=False,
                        pos='noun', thresh=1):
    """
    Compute the baseline value of the counterfactual metric for the given data.

    Counterfactual sentences are produced with get_edits, scored with
    generate_model_agnostic_metrics (plus the model-related metrics when requested)
    and reduced to one number with get_counterfactual_metric.

    :param data: pd.DataFrame() containing one column with the textual data
    :param model_required: boolean value specifying whether a pretrained model is also required for the metric computation
    :param preprocessor: a custom class that implements the necessary preprocessing of the data
    :param model: a pretrained model on the dataset
    :param antonyms: boolean value forwarded to get_edits
    :param pos: part-of-speech forwarded to get_edits (default 'noun', the value that was previously hard-coded)
    :param thresh: threshold forwarded to get_edits (default 1, the value that was previously hard-coded)
    :returns: a float value representing the computed evaluation metric
    :raises ValueError: if model_required is True but preprocessor or model is missing
    """

    # fail early with a clear message instead of an AttributeError on None after
    # the counterfactuals have already been generated
    if model_required and (preprocessor is None or model is None):
        raise ValueError("model_required=True needs both a preprocessor and a model")

    sents = [elem[0] for elem in data.values.tolist()]
    counter_sents, _, _, _ = get_edits(sents, pos=pos, thresh=thresh, antonyms=antonyms)

    counter_data_df = pd.DataFrame({
        'counter_sents': counter_sents
    })
    metrics = generate_model_agnostic_metrics(data, counter_data_df)

    if model_required:
        # first process the original data and get model predictions
        processed_data = preprocessor.process(data)
        original_preds = model.predict(processed_data)

        # do the same but for the counterfactual-generated data
        processed_counter_data = preprocessor.process(counter_data_df)
        counter_preds = model.predict(processed_counter_data)

        # add model-related metrics to the metrics dictionary
        metrics.update(generate_model_related_metrics(original_preds, counter_preds))

    return get_counterfactual_metric(metrics)

## Graph-Related Functions

In [136]:
def create_graph(data, pos, antonyms=False):
    """
    Build the bipartite substitution graph for the words of one part-of-speech.

    The words are extracted from the first column of ``data``; one side of the graph
    holds their synsets, the other side holds the synsets of the same words or, when
    ``antonyms`` is set, of the list returned by get_antonym_list. Every edge starts
    with weight 1.

    :param data: pd.DataFrame() containing one column with the textual data
    :param pos: string that specifies which part-of-speech shall be considered for substitution (noun, verb, adv)
    :param antonyms: boolean value specifying whether or not to use antonyms in the candidate substitutions
    :returns: a dictionary containing the graph, along with other related features
    :raises AttributeError: if ``pos`` is not one of 'adv', 'verb', 'noun'
    """

    sentences = [row[0] for row in data.values.tolist()]

    # reject unknown tags before doing any extraction work
    if pos not in ('adv', 'verb', 'noun'):
        raise AttributeError("pos '{}' is not supported!".format(pos))

    # word extractor for each supported part-of-speech
    extractors = {
        'adv': create_attributes_list,
        'verb': create_verb_list,
        'noun': create_singular_list,
    }
    words = extractors[pos](sentences)

    source_words = list(words)
    if antonyms:
        target_words = list(get_antonym_list(words))
    else:
        target_words = list(words)

    source_synsets, d0, source_index = get_synsets(source_words, return_index=True)
    target_synsets, d1, target_index = get_synsets(target_words, return_index=True)

    # unique node names for the synsets of each side
    source_names = ['G0_{}'.format(i) for i in range(len(source_synsets))]
    target_names = ['G1_{}'.format(i) for i in range(len(target_synsets))]

    # word -> node name; a later pair overwrites an earlier one for the same word
    word_to_node0 = {source_words[idx]: name for name, idx in zip(source_names, source_index)}
    word_to_node1 = {target_words[idx]: name for name, idx in zip(target_names, target_index)}

    combinations_nodes = all_combinations(source_names, target_names)          # all combinations of names
    combinations_synsets = all_combinations(source_synsets, target_synsets)    # all combinations of synsets
    weights = [1] * len(combinations_nodes)

    G, min_list_nodes = bipartite_graph(source_names, target_names, combinations_nodes, weights)

    return {
        'graph': G,
        'min_list_nodes': min_list_nodes,
        'weights': weights,
        'd0': d0,
        'd1': d1,
        'comb_nodes': combinations_nodes,
        'comb_syn': combinations_synsets,
        'word_to_node0': word_to_node0,
        'word_to_node1': word_to_node1,
    }

In [125]:
def generate_graph_matching(graph_dict):
    """
    Run the minimum matching on the graph and translate the matched edges into synset pairs.

    :param graph_dict: a dictionary containing a bipartite graph and other related features
    :returns: a list of (weight, synset0, synset1) substitutions, the two mappings d0 and d1
              taken from graph_dict, and a tuple (graph, min_list_nodes, matching)
    """

    graph = graph_dict['graph']
    min_list_nodes = graph_dict['min_list_nodes']
    weights = graph_dict['weights']
    d0 = graph_dict['d0']
    d1 = graph_dict['d1']
    node_pairs = graph_dict['comb_nodes']
    synset_pairs = graph_dict['comb_syn']

    # find min weight match
    matching = dict_to_tuple(minimum_match(graph, min_list_nodes))

    # normalise every matched pair to sorted order and drop repeats
    # NOTE(review): remove_duplicates runs after every append; calling it once after
    # the loop would be cheaper, but only safe if it preserves order - confirm.
    unique_pairs = []
    for pair in matching:
        unique_pairs.append(tuple(sorted(pair)))
        unique_pairs = remove_duplicates(unique_pairs)

    # positions of the matched pairs inside the list of all node combinations
    matched_positions = pos_in_list(node_pairs, list(unique_pairs))

    substitution_synsets = [
        (weights[idx], synset_pairs[idx][0], synset_pairs[idx][1])
        for idx in matched_positions
    ]

    return substitution_synsets, d0, d1, (graph, min_list_nodes, unique_pairs)

In [126]:
def generate_counterfactuals(graph_dict, data, pos):
    """
    Generate counterfactual sentences from the current state of the graph.

    :param graph_dict: a dictionary containing a bipartite graph and other related features
    :param data: pd.DataFrame() containing one column with the textual data
    :param pos: string that specifies which part-of-speech shall be considered for substitution (noun, verb, adv)
    :returns: a dataframe with the generated counterfactual data (column 'counter_sents'),
              a list of the selected (u, v, weight) edges of the graph, and a dictionary
              mapping each selected (u, v) edge to its substitution occurrence
    """

    graph = graph_dict['graph']
    source_nodes = graph_dict['word_to_node0']
    target_nodes = graph_dict['word_to_node1']
    sentences = [row[0] for row in data.values.tolist()]

    # find best matching and generate edits; only the swapped sentences and the
    # substitution counts are needed from external_swaps
    substitution_synsets, d0, d1, _ = generate_graph_matching(graph_dict)
    swapped, _, _, substitutions = external_swaps(sentences, pos, substitution_synsets, d0, d1, thresh=4)

    counter_data = pd.DataFrame({'counter_sents': swapped})

    # re-key the substitution counts from word pairs to graph node pairs
    subs_as_nodes = {}
    for words, count in substitutions.items():
        subs_as_nodes[(source_nodes[words[0]], target_nodes[words[1]])] = count

    # NOTE(review): with default=0 a missing edge makes the ['weight'] subscript raise
    # TypeError rather than yield weight 0. train_graph's inner loop stops on any
    # exception raised here, so change this only together with that loop.
    selected_edges = [
        (u, v, graph.get_edge_data(u, v, default=0)['weight'])
        for (u, v) in subs_as_nodes
    ]

    return counter_data, selected_edges, subs_as_nodes

In [127]:
def train_graph(graph_dict, data, pos, preprocessor=None, model=None, learning_rate=0.1, th=0.005, max_iterations=100, model_required=False):
    """
    Training procedure for the graph edges.

    In every outer iteration the graph is repeatedly matched: each pass generates
    counterfactuals, scores them, computes updated weights for the edges that were
    used and removes those edges from the graph so the next pass picks different
    ones. The passes stop when a step raises an exception or no edge is selected
    any more; all updated edges are then added back to the graph.

    :param graph_dict: a dictionary containing the bipartite graph along with other variables and characteristics
    :param data: a dataframe containing the textual examples we will use to train the graph
    :param pos: a string specifying which part-of-speech shall be considered for substitutions (noun, verb, adv)
    :param preprocessor: a custom class that implements the necessary preprocessing of the data
    :param model: a pretrained model on the dataset
    :param learning_rate: float value defining how fast or slow the edge weights will be updated
    :param th: float value defining a threshold, where if the difference |baseline - current| gets smaller, the training stops
    :param max_iterations: integer value representing the maximum number of iterations for the training procedure
    :param model_required: boolean value for whether or not to compute model-related metrics
    :returns: the graph_dictionary with the fine-tuned (post-training) graph along with the rest of its features
              (the same object that was passed in; its graph is modified in place)
    """

    # generate counter data and compute the baseline metric
    baseline_metric = get_baseline_metric(data, model_required=model_required, preprocessor=preprocessor, model=model)
    current_metric = baseline_metric + 2 * th   # initialize current_metric so that the dif |baseline-current| is bigger than th

    # predictions on the original data are needed for the flip-rate; they do not
    # change during training, so compute them once up front
    original_preds = None
    if model_required:
        original_preds = model.predict(preprocessor.process(data))

    iterations = 0
    next_baseline_metric = baseline_metric
    while abs(current_metric - baseline_metric) >= th and iterations < max_iterations:
        print("ITERATION {}".format(iterations))

        updated_edges = []
        baseline_metric = next_baseline_metric

        # bind the graph before the loop so it is always defined afterwards
        g = graph_dict['graph']

        while nx.is_bipartite(g):
            try:
                # produce new counter_data and compute current_metric value
                counter_data, selected_edges, substitutions = generate_counterfactuals(graph_dict, data, pos)
                current_metrics_dict = generate_model_agnostic_metrics(data, counter_data)

                # if needed, compute model-related metrics as well
                if model_required:
                    processed_counter_data = preprocessor.process(counter_data)
                    counter_preds = model.predict(processed_counter_data)

                    # add model-related metrics to the current_metrics dictionary
                    current_metrics_dict.update(generate_model_related_metrics(original_preds, counter_preds))

                # compute the final metric as a combination of the previously computed metrics
                current_metric = get_counterfactual_metric(current_metrics_dict)

                # compute the new weights before touching the graph, so a failure
                # here cannot leave edges removed without a replacement
                new_edges = update_edges(selected_edges, substitutions, learning_rate, baseline_metric, current_metric)
            except Exception as exc:
                # a failing step (e.g. no further matching possible) ends the passes
                # of this iteration; report it instead of hiding it
                print("Stopping edge selection for this iteration: {!r}".format(exc))
                break

            if not selected_edges:
                # nothing was selected, so the graph would not change and the next
                # pass would repeat this one forever
                break

            g.remove_edges_from(selected_edges)
            updated_edges.extend(new_edges)

        # put the updated edges back into the graph
        g.add_weighted_edges_from(updated_edges)
        graph_dict['graph'] = g

        # update baseline_metric value and iterations
        next_baseline_metric = min(baseline_metric, current_metric)
        iterations += 1

    return graph_dict

## Testing

In [139]:
# Configuration of the demo run. POS must be one of the tags accepted by
# create_graph ('adv', 'verb', 'noun').
POS = 'adv'
MAX_ITER = 3
ANTONYMS = True

# Toy corpus: a single-column dataframe, as expected by the functions above.
df = pd.DataFrame({
    'sents': [
        'A great man was standing in a tall and magnificent hill, gazing upon the sad and destructive army',
        'The clever boy was wondering when the fat dog would return with the big stick',
        'A small town was standing next to the large river and the tall building'
    ]
})


# Build the graph and train it. train_graph returns the dictionary it was given,
# so trained_gd and gd refer to the same object.
gd = create_graph(data=df, pos=POS, antonyms=ANTONYMS)
trained_gd = train_graph(graph_dict=gd, data=df, pos=POS, max_iterations=MAX_ITER)

ITERATION 0


In [140]:
# Generate counterfactuals from the trained graph and show them next to the originals.
counter_data, selected_edges, subs = generate_counterfactuals(trained_gd, df, POS)
for row_idx in range(len(df)):
    print(
        "ORIGINAL:",
        df['sents'][row_idx],
        "COUNTER:",
        counter_data['counter_sents'][row_idx],
        "===============================================================================================================",
        sep="\n",
    )

ORIGINAL:
A great man was standing in a tall and magnificent hill, gazing upon the sad and destructive army
COUNTER:
a large man was standing in a glad and magnificent hill, gazing upon the sad and destructive army.
ORIGINAL:
The clever boy was wondering when the fat dog would return with the big stick
COUNTER:
the clever boy was wondering when the short dog would return with the big stick.
ORIGINAL:
A small town was standing next to the large river and the tall building
COUNTER:
a little town was standing next to the small river and the glad building.


In [143]:
# Score the counterfactuals of the trained graph and compare with the baseline.
tuned_metrics = generate_model_agnostic_metrics(df, counter_data)
final_metric = get_counterfactual_metric(tuned_metrics)
baseline_metric = get_baseline_metric(df)

report = (
    ("Baseline metric value: {}", baseline_metric),
    ("Fine-tuned metric value: {}", final_metric),
    ("Difference: {}", abs(baseline_metric - final_metric)),
)
for template, value in report:
    print(template.format(value))

Baseline metric value: 0.07030562703168462
Fine-tuned metric value: 0.06654568031053844
Difference: 0.0037599467211461846
