In [None]:
# Use the %pip magic so the package is installed into the running kernel's environment.
%pip install datasets



In [None]:
from datasets import load_dataset
import pandas as pd
import re
import numpy as np

def load_trip(N):
    """Load the first ``N`` examples of the TRIP ``ClozeTest`` split into a DataFrame.

    Parameters
    ----------
    N : int
        Number of examples to read, starting at index 0.

    Returns
    -------
    pandas.DataFrame
        One row per example with the columns:

        * ``group_1`` / ``group_2`` -- the ``sentences`` of story 0 / story 1.
        * ``plausible`` -- 1 when story 0's ``plausible`` flag is truthy, else 2.
        * ``confl_pairs`` -- the ``confl_pairs`` field of the other story
          (story 1 when story 0 is plausible, otherwise story 0).
    """
    test_dataset = load_dataset("sled-umich/TRIP", split='ClozeTest')
    columns = ['group_1', 'group_2', 'plausible', 'confl_pairs']

    records = []
    for i in range(N):
        # Fetch the example once and reuse it for every field.
        stories = test_dataset[i]['stories']
        first_story, second_story = stories[0], stories[1]
        first_is_plausible = first_story['plausible']
        # The conflicting pairs are taken from the story that is not plausible.
        implausible_story = second_story if first_is_plausible else first_story
        records.append({
            'group_1': first_story['sentences'],
            'group_2': second_story['sentences'],
            'plausible': 1 if first_is_plausible else 2,
            'confl_pairs': implausible_story['confl_pairs'],
        })

    # Build the frame in one go instead of enlarging it row by row.
    return pd.DataFrame(records, columns=columns)

# Load the first 100 ClozeTest examples and print the frame's column summary.
trip_df = load_trip(100)
trip_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 100 entries, 0 to 99
Data columns (total 4 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   group_1      100 non-null    object
 1   group_2      100 non-null    object
 2   plausible    100 non-null    int64 
 3   confl_pairs  100 non-null    object
dtypes: int64(1), object(3)
memory usage: 3.9+ KB


In [None]:
# Read the saved results CSV and print its column summary.
df = pd.read_csv('clingo_few_shot_mistral_results.csv')
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100 entries, 0 to 99
Data columns (total 3 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   ground_truth  100 non-null    int64 
 1   predictions   100 non-null    object
 2   answers       100 non-null    object
dtypes: int64(1), object(2)
memory usage: 2.5+ KB


In [None]:
def _split_sequence_args(seq_body):
    """Split the text inside ``sequence(...)`` into its non-empty arguments."""
    # Accept commas and/or whitespace as separators so "1,1,event" and
    # "1, 1, event" parse the same way.
    return [arg for arg in re.split(r"[,\s]+", seq_body) if arg]


def extract_stories(text):
    """Parse the ``sequence(...)`` facts and ``implausible(C)`` rules in ``text``.

    Parameters
    ----------
    text : str
        Text to scan. Any non-string value (for example ``None`` or NaN) is
        treated as containing nothing to parse.

    Returns
    -------
    tuple[list[str], list[str], list[list[str]]]
        ``(story_1, story_2, conflict_pairs)`` where

        * ``story_1`` / ``story_2`` hold the last argument of every
          ``sequence(...)`` whose first argument is ``1`` / ``2``, in order of
          appearance;
        * ``conflict_pairs`` holds, for every ``implausible(C) :- ... .`` rule
          body that mentions at least two ``sequence(...)`` terms, the list of
          their last arguments.
    """
    if not isinstance(text, str):
        return [], [], []

    seq_pattern = r"sequence\((.*?)\)"
    story_1 = []
    story_2 = []
    for seq_match in re.findall(seq_pattern, text):
        args = _split_sequence_args(seq_match)
        if not args:
            # "sequence()" carries neither a story id nor an event.
            continue
        if args[0] == "1":
            story_1.append(args[-1])
        elif args[0] == "2":
            story_2.append(args[-1])

    conflict_pairs = []
    imp_pattern = r"implausible\(C\) :- (.*?)\."
    for rule_body in re.findall(imp_pattern, text):
        pair = []
        for seq_match in re.findall(seq_pattern, rule_body):
            args = _split_sequence_args(seq_match)
            if args:
                pair.append(args[-1])
        # A rule needs at least two events to describe a conflict.
        if len(pair) >= 2:
            conflict_pairs.append(pair)
    return story_1, story_2, conflict_pairs

def get_plausible_story(text):
    """Return the number of the story left plausible by the parsed conflict pairs.

    The conflict pairs from ``extract_stories(text)`` are checked in order.
    The first pair whose events all occur in one story decides the result:
    2 if they occur in story 1, 1 if they occur in story 2. Story 1 is tested
    first, so a pair contained in both stories yields 2. Returns ``None`` when
    no pair is fully contained in either story.
    """
    first_story, second_story, pairs = extract_stories(text)
    first_events = set(first_story)
    second_events = set(second_story)
    for pair in pairs:
        needed = set(pair)
        if needed <= first_events:
            return 2
        if needed <= second_events:
            return 1
    return None

def get_conflict_pairs_index(text):
    """Locate each parsed conflict pair inside the story that contains it.

    For every conflict pair from ``extract_stories(text)`` whose events all
    occur in story 1, or otherwise in story 2, the positions of the pair's
    first two events in that story are recorded as ``[smaller, larger]``.
    Pairs not fully contained in either story are skipped.

    Returns the list of index pairs, or ``None`` when no pair could be located.
    """
    first_story, second_story, pairs = extract_stories(text)
    located = []
    for pair in pairs:
        members = set(pair)
        # Story 1 is tried before story 2.
        host = next(
            (story for story in (first_story, second_story) if members.issubset(story)),
            None,
        )
        if host is None:
            continue
        positions = (host.index(pair[0]), host.index(pair[1]))
        located.append([min(positions), max(positions)])
    return located or None

def print_conflict(index, print_text=True, results_path='clingo_few_shot_mistral_results.csv'):
    """Print the parsed stories, conflict pairs and labels for one results row.

    Parameters
    ----------
    index : int
        Positional row number in the results CSV.
    print_text : bool, default True
        When true, print the row's raw ``answers`` text first.
    results_path : str, default 'clingo_few_shot_mistral_results.csv'
        CSV file to read; it must have the columns ``answers``,
        ``ground_truth`` and ``predictions``.
    """
    df = pd.read_csv(results_path)
    row = df.iloc[index]
    # Select by column label so the lookup does not depend on column order.
    text = row['answers']
    if print_text:
        print(text)
    story_1, story_2, conflict_pairs = extract_stories(text)
    print(story_1)
    print(story_2)
    print(conflict_pairs)
    print("Plausible Story:", get_plausible_story(text))
    print("Ground Truth:", row['ground_truth'])
    print("Prediction:", row['predictions'])
    print("Conflict Pairs Indexes:", get_conflict_pairs_index(text))

In [None]:
def get_eval_df(results_path='clingo_few_shot_mistral_results.csv'):
    """Pair the predictions parsed from the results CSV with the TRIP labels.

    Parameters
    ----------
    results_path : str, default 'clingo_few_shot_mistral_results.csv'
        CSV file whose ``answers`` column holds the text to parse.

    Returns
    -------
    pandas.DataFrame
        One row per results row with the columns ``pred_plausible`` and
        ``pred_confl_pairs`` (parsed from ``answers``; missing when nothing
        could be parsed) and ``plausible`` and ``confl_pairs`` (taken from the
        row with the same label in ``load_trip``'s frame).
    """
    df = pd.read_csv(results_path)
    # Load exactly as many TRIP examples as there are result rows.
    trip_df = load_trip(len(df))
    columns = ['pred_plausible', 'plausible', 'pred_confl_pairs', 'confl_pairs']

    records = []
    for k in range(len(df)):
        text = df.loc[k, 'answers']
        records.append({
            'pred_plausible': get_plausible_story(text),
            'plausible': trip_df.loc[k, 'plausible'],
            'pred_confl_pairs': get_conflict_pairs_index(text),
            'confl_pairs': trip_df.loc[k, 'confl_pairs'],
        })

    # Build the frame in one go instead of enlarging it row by row.
    return pd.DataFrame(records, columns=columns)

# Build the evaluation frame and preview its first rows.
eval_df = get_eval_df()
eval_df.head()


Unnamed: 0,pred_plausible,plausible,pred_confl_pairs,confl_pairs
0,2.0,2,"[[4, 5]]","[[3, 4]]"
1,1.0,1,"[[3, 4]]","[[3, 4]]"
2,,1,,"[[2, 3]]"
3,,1,,"[[2, 3]]"
4,2.0,2,"[[4, 5]]","[[2, 4], [3, 4]]"


In [None]:
def evaluate():
    """Print plausible-story accuracy and conflict-pair consistency.

    A row counts towards accuracy when ``pred_plausible`` equals ``plausible``.
    It also counts towards consistency when, in addition,
    ``pred_confl_pairs`` equals ``confl_pairs``. Both counts are divided by
    the total number of rows returned by ``get_eval_df()``.
    """
    results = get_eval_df()
    total = len(results)
    plausible_hits = 0
    consistent_hits = 0
    rows = zip(
        results['pred_plausible'],
        results['plausible'],
        results['pred_confl_pairs'],
        results['confl_pairs'],
    )
    for pred_story, true_story, pred_pairs, true_pairs in rows:
        if pred_story != true_story:
            continue
        plausible_hits += 1
        # Consistency is only credited when the plausible story was also right.
        if pred_pairs == true_pairs:
            consistent_hits += 1
    print("Plausible Accuracy:", plausible_hits / total)
    print("Conflict Pairs Consistency:", consistent_hits / total)

# Print plausible-story accuracy and conflict-pair consistency for the saved results.
evaluate()

Plausible Accuracy: 0.65
Conflict Pairs Consistency: 0.21


In [None]:
# Print the parsed stories, conflict pairs and labels for the first 10 result
# rows, without the raw answer text; a blank line separates the rows.
for i in range(10):
  print_conflict(i, print_text=False)
  print()

['walk_in_door', 'step_on_book', 'stomp_on_carpet', 'smash_radio', 'turn_on_radio', 'sit_down']
['walk_in_door', 'step_on_book', 'stomp_on_carpet', 'smash_radio', 'give_box', 'apologize']
[['turn_on_radio', 'sit_down']]
Plausible Story: 2
Ground Truth: 2
Prediction: 1
Conflict Pairs Indexes: [[4, 5]]

['walk_in_door', 'step_on_book', 'stomp_carpet', 'smash_radio', 'give_box', 'apologize']
['walk_in_door', 'step_on_book', 'stomp_carpet', 'smash_radio', 'switch_on_radio', 'hear_music']
[['smash_radio', 'switch_on_radio']]
Plausible Story: 1
Ground Truth: 1
Prediction: 1
Conflict Pairs Indexes: [[3, 4]]

['fill_sandbox', 'set_sandbox', 'take_off_shoes', 'put_shoes_in_duffle_bag', 'get_in_sandbox']
['fill_sandbox', 'set_sandbox', 'put_on_shoes', 'put_shoes_in_duffle_bag', 'get_in_sandbox']
[['take_off_shoes', 'put_on_shoes']]
Plausible Story: None
Ground Truth: 1
Prediction: 1
Conflict Pairs Indexes: None

['fill_sandbox', 'set_sandbox', 'take_off_shoes', 'put_shoes_in_duffle_bag', 'get_in