In [1]:
import numpy as np
import pickle
import pandas as pd
import spacy
import time
import datetime
import openai
import random
import os
from dotenv import load_dotenv

load_dotenv()
nlp = spacy.load("en_core_web_lg")

In [2]:
directory = 'vers31/MLP_Classifiers_480k_training_15_iter_NN_size_200_50_30'
use_z_values = tuple(range(25))
#use_z_values = (0,6,12,18,24)
#use_z_values = (0,3,4,6,8,9,12,13,14,15,16,17,18,19,20,21,22,23,24)
model_names = ["S1-S2",
               "S1-V2",
               "S1-O2",
               "S1-L2",
               "S1-C2",
               "V1-S2",
               "V1-V2",
               "V1-O2",
               "V1-L2",
               "V1-C2",
               "O1-S2",
               "O1-V2",
               "O1-O2",
               "O1-L2",
               "O1-C2",
               "L1-S2",
               "L1-V2",
               "L1-O2",
               "L1-L2",
               "L1-C2",
               "C1-S2",
               "C1-V2",
               "C1-O2",
               "C1-L2",
               "C1-C2",]
subphrase_indices_x1 = ["string_subj_s1", ] * 5 + ["string_verb_s1", ] * 5 + ["string_obj_s1", ] * 5 + ["string_loc_s1", ] * 5 + ["string_clo_s1", ] * 5
subphrase_indices_x2 = ["string_subj_s2", "string_verb_s2", "string_obj_s2", "string_loc_s2", "string_clo_s2"] * 5

model_names = [name for i,name in enumerate(model_names) if i in use_z_values]
subphrase_indices_x1 = [i for j,i in enumerate(subphrase_indices_x1) if j in use_z_values]
subphrase_indices_x2 = [i for j,i in enumerate(subphrase_indices_x2) if j in use_z_values]

clf = list()
for i in range(len(use_z_values)):
    with open("../03_Bayesian_Network/" + directory + "/MLP_Classifier" + str(i) + ".pkl", "rb") as f:
        clf += [pickle.load(f), ]

In [3]:
# Indices for all z variables influenced by Subject2 (Verb2, Object2 etc. respectively) (e.g. Subject1-Subject2)
Subj2_indices = [i for i in range(len(use_z_values)) if use_z_values[i] in (0, 5, 10, 15, 20)]
Verb2_indices = [i for i in range(len(use_z_values)) if use_z_values[i] in (1, 6, 11, 16, 21)]
Obj2_indices = [i for i in range(len(use_z_values)) if use_z_values[i] in (2, 7, 12, 17, 22)]
Loc2_indices = [i for i in range(len(use_z_values)) if use_z_values[i] in (3, 8, 13, 18, 23)]
Clo2_indices = [i for i in range(len(use_z_values)) if use_z_values[i] in (4, 9, 14, 19, 24)]
Sentence2_indices = [Subj2_indices, Verb2_indices, Obj2_indices, Loc2_indices, Clo2_indices]

In [4]:
train1 = pd.read_csv('../Input_Data/e-SNLI/dataset/esnli_train_1.csv')
train2 = pd.read_csv('../Input_Data/e-SNLI/dataset/esnli_train_2.csv')
train = pd.concat([train1, train2])
train = train[train.notnull().apply(all, axis=1)]
dev = pd.read_csv('../Input_Data/e-SNLI/dataset/esnli_dev.csv')
dev = dev[dev.notnull().apply(all, axis=1)]
test = pd.read_csv('../Input_Data/e-SNLI/dataset/esnli_test.csv')
test = test[test.notnull().apply(all, axis=1)]

dev_prepared = pd.read_csv('../02_Extract_Subphrases/prepared_data/subphrase_vectors_dev.csv', sep=';')
dev_prepared = dev_prepared.drop(columns='Unnamed: 0')
dev = dev.set_index('pairID')
rel_pairIDs = dev_prepared.iloc[:,0]
dev = dev.loc[rel_pairIDs]
y_hat = dev.loc[rel_pairIDs].gold_label
dev_prepared = dev_prepared.iloc[:,1:].to_numpy()

dev_subphrases = pd.read_csv('../02_Extract_Subphrases/prepared_data/subphrases_dev.csv', sep=',')
dev_subphrases = dev_subphrases.set_index('pairID')
dev_subphrases = dev_subphrases.loc[rel_pairIDs]

In [5]:
# Prepare colum indices
indices = np.array([[0,1500], [0,1800], [0,2100], [0,2400], [0,2700],
                    [300,1500], [300,1800], [300,2100], [300,2400], [300,2700],
                    [600,1500], [600,1800], [600,2100], [600,2400], [600,2700],
                    [900,1500], [900,1800], [900,2100], [900,2400], [900,2700],
                    [1200,1500], [1200,1800], [1200,2100], [1200,2400], [1200,2700]])
indices = indices[use_z_values,:].tolist()
# Initialise colulmn indices and "nan" values if information (e.g. location of sentence) is not detected
not_nan = [None, ] * len(use_z_values)
cols = [None, ] * len(use_z_values)
for i in range(len(use_z_values)):
    cols[i] = list(range(indices[i][0], indices[i][0]+300)) + list(range(indices[i][1],indices[i][1]+300))
    not_nan[i] = pd.Series([not x for x in pd.DataFrame(np.isnan(dev_prepared[:,cols[i]])).apply(any, axis=1)])
not_nan = np.array(not_nan).T

In [6]:
def predict_y_from_z(z):

    # Iterate through each row if z is a matrix
    if len(z.shape) > 1:
        z = pd.DataFrame(z)
        res = z.apply(predict_y_from_z, axis=1)
        return res.to_numpy()

    # For each single line perform the following:
    else:
        # If any z is 'contradiction' -> output class 'contradiction'
        if any(z == 'contradiction'):
            return 'contradiction'
        # Else if all subphrases of sentence 2 are entailed by any subphrase of sentence 1 -> output class 'entailment'
        elif all([any([z[i] == 'entailment' for i in subphrase_indices]) or all([z[i] == 'nan' or pd.isnull(z[i]) for i in subphrase_indices]) for subphrase_indices in Sentence2_indices]):
            return 'entailment'
        # Else output class 'neutral'
        else:
            return 'neutral'

In [7]:
def softmax(x):
    """
    Compute softmax values for each set of scores in x.
    """
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

def relu(x):
    return np.maximum(x, 0)

def custom_sign(x):
    if len(x.shape) == 2 and x.shape[1] == 1:
        x = x.reshape((-1,))
    else:
        raise ValueError(f"x must be of shape (n,1), not {x.shape}")
    return np.diag((np.sign(x) + 1) / 2) # np.sign returns 1 and -1, but we want 1 and 0

def get_gradient(x1, x2, net, j):
    x = np.vstack((x1, x2))
    n_hidden = net.n_layers_ - 2 # number of hidden layers (subtract output and input layer)
    n = x.shape[0]

    # Get weights and biases from net
    W = net.coefs_
    W = [w.T for w in W]
    b = net.intercepts_
    b = [bt.reshape((-1,1)) for bt in b]

    # Calculate the partial derivatives of each factor in the chain for the final gradient
    # Forward propagation
    h = [relu(W[0] @ x + b[0])]
    for i in range(1, n_hidden):
        h += [relu(W[i] @ h[i-1] + b[i])]
    h += [softmax(W[n_hidden] @ h[-1] + b[-1])]

    # Backward propagation
    delta = np.zeros((net.n_outputs_, 1))
    delta[j] = 1
    gradients = [custom_sign(W[0] @ x + b[0]) @ W[0][:,int(n/2):], ]
    for i in range(1, n_hidden):
        gradients += [custom_sign(W[i] @ h[i-1] + b[i]) @ W[i], ]
    gradients += [(softmax(W[-1] @ h[-2] + b[-1]) * (delta - softmax(W[-1] @ h[-2] + b[-1]))).T @ W[-1], ]

    gradient = gradients[-1]
    for i in range(n_hidden-1, -1, -1):
        gradient = gradient @ gradients[i]
    return gradient.T

In [8]:
def perturb_input(x1, x2, net, learning_rate=0.5):
    x2_star = x2
    z_hat_original = np.argmax(net.predict_proba(np.vstack((x1, x2)).T))
    z_hat = z_hat_original
    epsilon = 0
    iterations = 0
    while z_hat == z_hat_original:
        iterations += 1
        gradient = get_gradient(x1, x2_star, net, z_hat_original)
        epsilon += learning_rate * -gradient
        x2_star = x2 + epsilon
        z_hat = np.argmax(net.predict_proba(np.vstack((x1, x2_star)).T))
    return x2_star, epsilon, iterations

In [9]:
# Imports
from scipy.spatial import distance

# Format the vocabulary for use in the distance function
ids = [x for x in nlp.vocab.vectors.keys()]
vectors = [nlp.vocab.vectors[x] for x in ids]
vectors = np.array(vectors)

# *** Find the closest word below ***
def find_closest_term(embedding_vector, n=1, distance_fun="euclidean"):
    closest_index = distance.cdist(embedding_vector, vectors, metric=distance_fun).argsort()[0][:n]
    word_id = [ids[i] for i in closest_index]
    output_word = [nlp.vocab[i].text for i in word_id]
    output_vec = [nlp.vocab[i].vector for i in word_id]
    return output_word, output_vec

(['frog'],
 [array([ 1.0258  ,  1.4698  ,  1.2748  , -6.01    ,  0.46523 , -0.74243 ,
         -3.7125  , -0.27701 ,  0.6955  ,  5.2781  ,  2.831   , -2.3511  ,
          3.7929  ,  3.1249  ,  2.3439  , -4.2838  , -0.10871 ,  1.2361  ,
          3.8146  , -4.0382  ,  1.7993  ,  1.6935  ,  0.45288 , -3.9446  ,
          2.7754  , -2.4043  , -0.95622 , -0.63344 ,  4.1845  ,  3.3493  ,
         -4.9989  , -0.96089 ,  1.8568  ,  0.34936 ,  1.6861  , -2.0498  ,
          1.3089  , -0.94749 ,  4.2096  , -5.226   ,  1.1977  , -0.062547,
          1.356   ,  1.4038  ,  2.6408  ,  1.5928  ,  2.6488  ,  0.58255 ,
         -4.7756  , -0.32723 ,  2.0029  ,  7.6926  , -1.4262  ,  1.3991  ,
         -0.97878 , -1.1169  ,  0.35985 ,  0.96784 , -2.0557  ,  0.41393 ,
          1.1737  , -0.90461 ,  3.1241  ,  3.4651  , -2.0677  ,  0.15348 ,
         -1.4759  ,  1.2573  ,  0.22136 , -0.56222 ,  0.64254 , -0.69267 ,
          1.5075  ,  0.41354 ,  0.66168 , -3.125   , -3.1246  ,  0.98155 ,
          1.85

In [10]:
def generate_perturbed_inputs_cont(n_examples):
    pairIDs = list()
    original_sentence1 = list()
    original_sentence2 = list()
    x2_star_string = list()
    perturbed_sentence2 = list()
    perturbed_x_list = list()
    epsilons = list()
    sample_indices = list()
    for counter in range(n_examples):
        print(datetime.datetime.now())
        print(counter)
        i = np.random.choice([k for k in range(dev_prepared.shape[0]) if k not in sample_indices])
        model_index = np.random.choice(np.where(not_nan[i])[0])
        original_x2_string = dev_subphrases[subphrase_indices_x2[model_index]].iloc[i]
        while len(original_x2_string.split()) > 1:
            i = np.random.choice([k for k in range(dev_prepared.shape[0]) if k not in sample_indices])
            model_index = np.random.choice(np.where(not_nan[i])[0])
            original_x2_string = dev_subphrases[subphrase_indices_x2[model_index]].iloc[i]
        cur_x1 = dev_prepared[i,cols[model_index][:300]].reshape((-1,1))
        cur_x2 = dev_prepared[i,cols[model_index][300:]].reshape((-1,1))
        x2_star, epsilon, iterations = perturb_input(cur_x1, cur_x2, clf[model_index])

        ms_x2_star = find_closest_term(x2_star.T, n=1)[0][0]
        epsilon = x2_star - cur_x2
        sample_indices += [i, ]
        perturbed_x = dev_prepared[i,:].tolist()
        for counter, k in enumerate(cols[model_index][300:]):
            perturbed_x[k] = x2_star[counter,0]

        pairIDs += [dev.index[i], ]
        original_sentence1 += [dev.iloc[i].Sentence1, ]
        original_sentence2 += [dev.iloc[i].Sentence2, ]
        perturbed_x_list += [perturbed_x, ]
        x2_star_string += [ms_x2_star, ]
        perturbed_sentence2 += [dev.Sentence2.iloc[i].replace(original_x2_string, ms_x2_star)]
        epsilons += [epsilon, ]

    res = pd.DataFrame({"i": sample_indices, "pairID": pairIDs, "Sentence1": original_sentence1, "Sentence2": original_sentence2, "x2_star_string": x2_star_string, "perturbed_sentence2": perturbed_sentence2, "epsilons": epsilons, "perturbed_x": perturbed_x_list})
    return res

In [11]:
def generate_perturbed_inputs_discr(n_examples):
    pairIDs = list()
    original_sentence1 = list()
    original_sentence2 = list()
    x2_star_string_list = list()
    perturbed_sentence2 = list()
    perturbed_x_list = list()
    epsilons = list()
    sample_indices = list()
    for counter in range(n_examples):
        print(datetime.datetime.now())
        print(counter)
        i = np.random.choice([k for k in range(dev_prepared.shape[0]) if k not in sample_indices])
        model_index = np.random.choice(np.where(not_nan[i])[0])
        original_x2_string = dev_subphrases[subphrase_indices_x2[model_index]].iloc[i]
        while len(original_x2_string.split()) > 1:
            i = np.random.choice([k for k in range(dev_prepared.shape[0]) if k not in sample_indices])
            model_index = np.random.choice(np.where(not_nan[i])[0])
            original_x2_string = dev_subphrases[subphrase_indices_x2[model_index]].iloc[i]
        cur_x1 = dev_prepared[i,cols[model_index][:300]].reshape((-1,1))
        cur_x2 = dev_prepared[i,cols[model_index][300:]].reshape((-1,1))
        original_pred = clf[model_index].predict(np.vstack((cur_x1, cur_x2)).T)

        ms_x2 = find_closest_term(cur_x2.T, n=1)[0][0]
        pos_x2 = nlp(ms_x2)[0].pos_
        closest_terms, closest_vecs = find_closest_term(cur_x2.T, n=20000)
        for term, vec in zip(closest_terms, closest_vecs):
            if clf[model_index].predict(np.vstack((cur_x1, vec.reshape((-1,1)))).T) != original_pred and nlp(term)[0].pos_ == pos_x2:
                x2_star = vec
                x2_star_string = term
                break


        epsilon = x2_star - cur_x2
        sample_indices += [i, ]
        perturbed_x = dev_prepared[i,:].tolist()
        for counter, k in enumerate(cols[model_index][300:]):
            perturbed_x[k] = x2_star[counter]

        pairIDs += [dev.index[i], ]
        original_sentence1 += [dev.iloc[i].Sentence1, ]
        original_sentence2 += [dev.iloc[i].Sentence2, ]
        perturbed_x_list += [perturbed_x, ]
        x2_star_string_list += [x2_star_string, ]
        perturbed_sentence2 += [dev.Sentence2.iloc[i].replace(original_x2_string, x2_star_string)]
        epsilons += [epsilon, ]

    res = pd.DataFrame({"i": sample_indices, "pairID": pairIDs, "Sentence1": original_sentence1, "Sentence2": original_sentence2, "x2_star_string": x2_star_string_list, "perturbed_sentence2": perturbed_sentence2, "epsilons": epsilons, "perturbed_x": perturbed_x_list})
    return res

In [12]:
np.random.seed(12345)
perturbed_inputs_cont = generate_perturbed_inputs_cont(150)

2023-04-25 09:45:18.246759
0
2023-04-25 09:45:19.723954
1
2023-04-25 09:45:20.611317
2
2023-04-25 09:45:22.123186
3
2023-04-25 09:45:23.276859
4
2023-04-25 09:45:25.172327
5
2023-04-25 09:45:27.173485
6
2023-04-25 09:45:29.027111
7
2023-04-25 09:45:32.210715
8
2023-04-25 09:45:33.963651
9
2023-04-25 09:45:37.566252
10
2023-04-25 09:45:38.552901
11
2023-04-25 09:45:39.508924
12
2023-04-25 09:45:40.873168
13
2023-04-25 09:45:44.218056
14
2023-04-25 09:45:45.159248
15
2023-04-25 09:45:46.239921
16
2023-04-25 09:45:47.219864
17
2023-04-25 09:45:48.161692
18
2023-04-25 09:45:49.315064
19
2023-04-25 09:45:50.548785
20
2023-04-25 09:45:51.610544
21
2023-04-25 09:45:52.528700
22
2023-04-25 09:45:53.478136
23
2023-04-25 09:45:54.899068
24
2023-04-25 09:45:55.799653
25
2023-04-25 09:45:56.744008
26
2023-04-25 09:45:57.661247
27
2023-04-25 09:45:58.625845
28
2023-04-25 09:46:00.555773
29
2023-04-25 09:46:01.567465
30
2023-04-25 09:46:02.528644
31
2023-04-25 09:46:03.438805
32
2023-04-25 09:46:04.

In [13]:
np.random.seed(12345)
perturbed_inputs_discr = generate_perturbed_inputs_discr(150)

2023-04-25 09:50:24.041402
0
2023-04-25 09:50:26.452460
1
2023-04-25 09:50:28.465677
2
2023-04-25 09:50:30.269716
3
2023-04-25 09:50:32.131182
4
2023-04-25 09:50:34.032687
5
2023-04-25 09:50:35.899917
6
2023-04-25 09:50:37.696941
7
2023-04-25 09:50:39.693418
8
2023-04-25 09:50:41.699654
9
2023-04-25 09:50:46.731427
10
2023-04-25 09:50:48.499971
11
2023-04-25 09:50:50.249795
12
2023-04-25 09:50:52.155381
13
2023-04-25 09:50:53.870267
14
2023-04-25 09:50:55.646006
15
2023-04-25 09:50:57.543426
16
2023-04-25 09:50:59.386896
17
2023-04-25 09:51:01.106189
18
2023-04-25 09:51:02.882350
19
2023-04-25 09:51:04.571901
20
2023-04-25 09:51:07.995852
21
2023-04-25 09:51:09.733714
22
2023-04-25 09:51:11.418441
23
2023-04-25 09:51:13.257726
24
2023-04-25 09:51:15.013335
25
2023-04-25 09:51:16.816243
26
2023-04-25 09:51:18.588371
27
2023-04-25 09:51:20.299756
28
2023-04-25 09:51:22.902235
29
2023-04-25 09:51:24.572693
30
2023-04-25 09:51:26.282963
31
2023-04-25 09:51:28.072958
32
2023-04-25 09:51:29.

In [14]:
perturbed_inputs_array_discr = np.array([x for x in perturbed_inputs_discr.perturbed_x])
perturbed_inputs_array_cont = np.array([x for x in perturbed_inputs_cont.perturbed_x])

In [15]:
z = np.empty((perturbed_inputs_discr.shape[0], len(use_z_values)), dtype=np.dtype('U25'))

z[:,:] = np.nan

for i in range(len(use_z_values)):
    z[not_nan[perturbed_inputs_discr["i"],i], i] = clf[i].predict(dev_prepared[perturbed_inputs_discr.i][not_nan[perturbed_inputs_discr["i"],i],:][:, cols[i]])

y_hat_pred_original = predict_y_from_z(z)

In [16]:
z = np.empty((perturbed_inputs_discr.shape[0], len(use_z_values)), dtype=np.dtype('U25'))

z[:,:] = np.nan

for i in range(len(use_z_values)):
    z[not_nan[perturbed_inputs_discr["i"],i], i] = clf[i].predict(perturbed_inputs_array_discr[not_nan[perturbed_inputs_discr["i"],i],:][:, cols[i]])

y_hat_pred_discr = predict_y_from_z(z)

In [17]:
z[:,:] = np.nan

for i in range(len(use_z_values)):
    z[not_nan[perturbed_inputs_cont["i"],i], i] = clf[i].predict(perturbed_inputs_array_cont[not_nan[perturbed_inputs_cont["i"],i],:][:, cols[i]])

y_hat_pred_cont = predict_y_from_z(z)

In [18]:
np.mean(y_hat_pred_discr == y_hat_pred_original)

0.5266666666666666

In [19]:
np.mean(y_hat_pred_cont == y_hat_pred_original)

0.6866666666666666

In [20]:
perturbed_inputs_discr

Unnamed: 0,i,pairID,Sentence1,Sentence2,x2_star_string,perturbed_sentence2,epsilons,perturbed_x
0,482,3037060954.jpg#0r1e,A couple is a standing together the woman is a...,People are going to a wedding.,crying,People are crying to a wedding.,"[[-1.4193000476837159, -2.944400040435791, -6....","[-11.4831, 7.1157002, -12.8034, 5.9567, 13.625..."
1,547,2169709244.jpg#1r1c,A little girl in a blue dress is sitting on a ...,The girl is riding on a truck.,griding,The girl is griding on a truck.,"[[-0.4973999744415283, 1.2526050006836653, -1....","[4.5049, 6.245, -7.6763, -10.4062, -4.9654, -2..."
2,2704,2320420111.jpg#1r1n,A man in a bicycle is passing through a house ...,A man is biking to work.,gliding,A man is gliding to work.,"[[-0.9467999955177306, -1.3172199746131896, -4...","[-1.2867, -0.7992, -2.092, -0.77679, -2.5057, ..."
3,884,2881441125.jpg#3r1c,Two small children are gathering water from a ...,Two kids are collecting water from a stream.,reallocating,Two kids are reallocating water from a stream.,"[[1.4034999839782714, 2.7470100156784056, 1.48...","[-10.267729, -2.4905999, -5.9831, 5.208599, 8...."
4,2314,3859014834.jpg#0r1n,Several people with parachutes are overlooking...,Paragliders survey their flight options.,Survey,Paragliders Survey their flight options.,"[[-1.3825599917984008, -3.7612598905181884, 4....","[-15.422, -0.6748997, -8.120831, 6.93622, 15.3..."
...,...,...,...,...,...,...,...,...
145,1565,468688517.jpg#1r1n,A man in a brightly colored vest performs on s...,The singer is at the concert.,Was,The singer Was at the concert.,"[[-4.505699968338012, -1.5487430006265641, -1....","[-1.2867, -0.7992, -2.092, -0.77679, -2.5057, ..."
146,2029,5806030158.jpg#3r1c,Children are playing a game outside.,Children are not a game,Childrens,Childrens are not a game,"[[-0.41727997509002684, -0.5193600222206116, -...","[-0.12984, 0.2113, -3.1924, -1.1077, 1.3806, -..."
147,641,2855910826.jpg#0r1c,A football player in a blue and yellow uniform...,The football player is drinking water on the b...,groundwater,The football player is drinking groundwater on...,"[[0.14056997203826904, -0.26873999631404877, -...","[-3.3224, 6.8543997, 2.17405, 3.6950998, 10.00..."
148,404,3167168073.jpg#1r1c,A man in a blue robe walking a herd of camels ...,A man just smuggle a herd of camels.,kidnaping,A man just kidnaping a herd of camels.,"[[2.9030999542236327, 5.686899960136413, 0.944...","[-1.2867, -0.7992, -2.092, -0.77679, -2.5057, ..."


In [21]:
perturbed_inputs_cont

Unnamed: 0,i,pairID,Sentence1,Sentence2,x2_star_string,perturbed_sentence2,epsilons,perturbed_x
0,482,3037060954.jpg#0r1e,A couple is a standing together the woman is a...,People are going to a wedding.,going,People are going to a wedding.,"[[-0.3598499551859904], [-0.05925772112492034]...","[-11.4831, 7.1157002, -12.8034, 5.9567, 13.625..."
1,547,2169709244.jpg#1r1c,A little girl in a blue dress is sitting on a ...,The girl is riding on a truck.,riding,The girl is riding on a truck.,"[[0.015719028233404764], [-0.02779284053834163...","[4.5049, 6.245, -7.6763, -10.4062, -4.9654, -2..."
2,2704,2320420111.jpg#1r1n,A man in a bicycle is passing through a house ...,A man is biking to work.,biking,A man is biking to work.,"[[-0.5273899911587951], [0.04608260788717358],...","[-1.2867, -0.7992, -2.092, -0.77679, -2.5057, ..."
3,884,2881441125.jpg#3r1c,Two small children are gathering water from a ...,Two kids are collecting water from a stream.,collecting,Two kids are collecting water from a stream.,"[[0.13855687178334808], [0.24085338364798137],...","[-10.267729, -2.4905999, -5.9831, 5.208599, 8...."
4,2314,3859014834.jpg#0r1n,Several people with parachutes are overlooking...,Paragliders survey their flight options.,survey,Paragliders survey their flight options.,"[[0.6060190406680043], [0.1822634969993011], [...","[-15.422, -0.6748997, -8.120831, 6.93622, 15.3..."
...,...,...,...,...,...,...,...,...
145,1565,468688517.jpg#1r1n,A man in a brightly colored vest performs on s...,The singer is at the concert.,is,The singer is at the concert.,"[[1.3397752498563107], [1.506416084146296], [-...","[-1.2867, -0.7992, -2.092, -0.77679, -2.5057, ..."
146,2029,5806030158.jpg#3r1c,Children are playing a game outside.,Children are not a game,Children,Children are not a game,"[[-0.0570142794113766], [0.06489870345457899],...","[-0.12984, 0.2113, -3.1924, -1.1077, 1.3806, -..."
147,641,2855910826.jpg#0r1c,A football player in a blue and yellow uniform...,The football player is drinking water on the b...,water,The football player is drinking water on the b...,"[[0.00752735811690447], [0.003784208204943962]...","[-3.3224, 6.8543997, 2.17405, 3.6950998, 10.00..."
148,404,3167168073.jpg#1r1c,A man in a blue robe walking a herd of camels ...,A man just smuggle a herd of camels.,smuggle,A man just smuggle a herd of camels.,"[[0.4532723982181177], [0.0324076792813765], [...","[-1.2867, -0.7992, -2.092, -0.77679, -2.5057, ..."


In [22]:
def create_query(s1, s2):
    return ('Premise: ' + s1 +
            '\nHypothesis: ' + s2 +
            '\nLabel: ')

In [23]:
def single_row_to_string(row):
    s1 = row.Sentence1
    s2 = row.Sentence2
    expl = row.Explanation_1
    label= row.gold_label
    label_map = {"entailment": "Yes",
                 "neutral": "Maybe",
                 "contradiction": "No"}
    return ('Premise: ' + s1 +
            '\nHypothesis: ' + s2 +
            '\nLabel: ' + label +
            '\nExplanation: ' + expl + '\n###\n')

In [24]:
def prepare_examples(data, size_per_class=4):
    example_indices = list()
    for cat in ['neutral', 'contradiction', 'entailment']:
        example_indices += list(np.random.choice(data[data.gold_label == cat].index.values, size=size_per_class, replace=False))
        random.shuffle(example_indices)
    data = data.loc[example_indices]
    res = list()
    for row in data.itertuples():
        res += [single_row_to_string(row)]
    return 'Classify into entailment, neutral, or contradiction and justify the decision.\n\n' + ''.join(res)

In [25]:
labels_discr = list()
explanations_discr = list()
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.organization = "org-zHIEQdY05F58L0NE73T46O4K"

for i in range(perturbed_inputs_discr.shape[0]):
    prompt = prepare_examples(train, size_per_class=2) + create_query(perturbed_inputs_discr.iloc[i].Sentence1, perturbed_inputs_discr.iloc[i].perturbed_sentence2)
    response = openai.Completion.create(
        engine='text-davinci-003',
        prompt = prompt,
        temperature=0,
        max_tokens=58,
        top_p=1,
        )
    labels_discr += [response.choices[0].text.split("\nExplanation: ")[0].strip(),]
    explanations_discr += [response.choices[0].text.split("\nExplanation: ")[1], ]

In [26]:
labels_cont = list()
explanations_cont = list()
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.organization = "org-zHIEQdY05F58L0NE73T46O4K"

for i in range(perturbed_inputs_cont.shape[0]):
    prompt = prepare_examples(train, size_per_class=2) + create_query(perturbed_inputs_cont.iloc[i].Sentence1, perturbed_inputs_cont.iloc[i].perturbed_sentence2)
    response = openai.Completion.create(
        engine='text-davinci-003',
        prompt = prompt,
        temperature=0,
        max_tokens=58,
        top_p=1,
    )
    labels_cont += [response.choices[0].text.split("\nExplanation: ")[0].strip(),]
    explanations_cont += [response.choices[0].text.split("\nExplanation: ")[1], ]

In [28]:
labels_discr = np.array(labels_discr)
print(f"Change in overall predicted label: {np.mean(y_hat_pred_discr != y_hat_pred_original)}")
print(f"Portion of same predictions between GPT-3 and SSM: {np.mean(y_hat_pred_discr == labels_discr)}")
print(f"Portion of same predictions between GPT-3 and SSM when prediction did not change: {np.mean(y_hat_pred_discr[np.where(y_hat_pred_discr == y_hat_pred_original)] == labels_discr[np.where(y_hat_pred_discr == y_hat_pred_original)])}")
print(f"Portion of same predictions between GPT-3 and SSM when prediction changed: {np.mean(y_hat_pred_discr[np.where(y_hat_pred_discr != y_hat_pred_original)] == labels_discr[np.where(y_hat_pred_discr != y_hat_pred_original)])}")

Change in overall predicted label: 0.47333333333333333
Portion of same predictions between GPT-3 and SSM: 0.4
Portion of same predictions between GPT-3 and SSM when prediction did not change: 0.5316455696202531
Portion of same predictions between GPT-3 and SSM when prediction changed: 0.2535211267605634


In [29]:
labels_cont = np.array(labels_cont)
print(f"Change in overall predicted label: {np.mean(y_hat_pred_cont != y_hat_pred_original)}")
print(f"Portion of same predictions between GPT-3 and SSM: {np.mean(y_hat_pred_cont == labels_cont)}")
print(f"Portion of same predictions between GPT-3 and SSM when prediction did not change: {np.mean(y_hat_pred_cont[np.where(y_hat_pred_cont == y_hat_pred_original)] == labels_cont[np.where(y_hat_pred_cont == y_hat_pred_original)])}")
print(f"Portion of same predictions between GPT-3 and SSM when prediction changed: {np.mean(y_hat_pred_cont[np.where(y_hat_pred_cont != y_hat_pred_original)] == labels_cont[np.where(y_hat_pred_cont != y_hat_pred_original)])}")

Change in overall predicted label: 0.31333333333333335
Portion of same predictions between GPT-3 and SSM: 0.44
Portion of same predictions between GPT-3 and SSM when prediction did not change: 0.5728155339805825
Portion of same predictions between GPT-3 and SSM when prediction changed: 0.14893617021276595
