In [47]:
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install jsonlines



In [48]:
from typing import Optional, Union
from datasets import load_dataset
from dataclasses import dataclass
import evaluate
import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForMultipleChoice, get_scheduler, AutoConfig, AutoModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from tqdm import tqdm
import argparse

import numpy as np
import scipy as sp

import torch.nn as nn
import torch.nn.functional as F
import argparse
import json
import os
import sys
import random
import pickle


In [49]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

SEED = 595
set_seed(595)

In [50]:
from google.colab import drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [51]:
GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = 'Colab Notebooks/EECS595/Project/Verifiable-Coherent-NLU-main'

In [52]:
DRIVE_PATH = os.path.join("drive", "My Drive", GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)
sys.path.append(DRIVE_PATH)
print(os.listdir(DRIVE_PATH))

['README.md', 'requirements.txt', 'www', 'all_data', 'cache', 'saved_models', 'Verifiable-Coherent-NLU.ipynb', 'BERT_CE.ipynb', 'withBRET_CE.ipynb', 'project_fineTunePIQA.ipynb', 'withBRET_PIQA.ipynb']


In [53]:
task_name = 'trip'

debug = False


config_batch_size = 1
config_lr = 1e-5 # Selected learning rate for best RoBERTa-based model in TRIP paper
config_epochs = 10


In [None]:
loss_weights = [0.0, 0.4, 0.4, 0.2, 0.0]

In [54]:
from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer, AlbertTokenizer, T5Tokenizer, GPT2Tokenizer

In [55]:
mode = 'bert'
tokenizer_class = BertTokenizer
model_name = 'bert-base-uncased'

In [56]:
tokenizer = tokenizer_class.from_pretrained('drive/My Drive/Colab Notebooks/EECS595/Project/Verifiable-Coherent-NLU-main/saved_models/finetunedBert_PIQA',
                                                do_lower_case = False,
                                                cache_dir=os.path.join(DRIVE_PATH, 'cache'))


In [57]:
from transformers import BertForSequenceClassification, RobertaForSequenceClassification, DebertaForSequenceClassification, AlbertForSequenceClassification, AdamW
from transformers import BertForMultipleChoice, RobertaForMultipleChoice, AlbertForMultipleChoice, DebertaModel
from transformers import BertModel, RobertaModel, AlbertModel, DebertaModel, T5Model, T5EncoderModel, GPT2Model
from transformers import RobertaForMaskedLM
from transformers import BertConfig, RobertaConfig, DebertaConfig, AlbertConfig, T5Config, GPT2Config
from www.model.transformers_ext import DebertaForMultipleChoice
from torch.optim import Adam

In [58]:
from www.utils import print_dict

partitions = ['train', 'dev', 'test']
subtasks = ['cloze', 'order']

# We can split the data into multiple json files later
#data_file = os.path.join(DRIVE_PATH, 'all_data/www.json')
with open(os.path.join(DRIVE_PATH, 'all_data/www.json'), 'r') as f:
  dataset = json.load(f)

print('Preprocessed examples:')
for ex_idx in [0,1,5,10]:
  ex = dataset['dev'][list(dataset['dev'].keys())[ex_idx]]
  print_dict(ex)

"""### Data Filtering and Sampling
Since there is a big imbalance between plausible/implausible class labels, we will upsample the plausible stories.

For now, we will also break the dataset into two sub-datasets: cloze and ordering.


"""

cloze_dataset = {p: [] for p in dataset}
order_dataset = {p: [] for p in dataset}

for p in dataset:
  for exid in dataset[p]:
    ex = dataset[p][exid]

    if ex['type'] == None:
      continue

    ex_plaus = dataset[p][str(ex['story_id'])]

    if ex['type'] == 'cloze':
      cloze_dataset[p].append(ex)
      cloze_dataset[p].append(ex_plaus) # For every implausible story, add a copy of its corresponding plausible story

    # Exclude augmented ordering examples from dev and test, since the breakpoints aren't always accurate in those
    elif ex['type'] == 'order' and not (p != 'train' and ex['aug']):
      order_dataset[p].append(ex)
      order_dataset[p].append(ex_plaus)

"""

### Convert TRIP to Two-Story Classification Task

Ready the TRIP dataset for two-story classification."""

from www.utils import print_dict
import json
from collections import Counter

#data_file = os.path.join(DRIVE_PATH, 'all_data/www_2s_new.json')
with open(os.path.join(DRIVE_PATH, 'all_data/www_2s_new.json'), 'r') as f:
  cloze_dataset_2s, order_dataset_2s = json.load(f)

for p in cloze_dataset_2s:
  label_dist = Counter([ex['label'] for ex in cloze_dataset_2s[p]])
  print('Cloze label distribution (%s):' % p)
  print(label_dist.most_common())
print_dict(cloze_dataset_2s['train'][0])


Preprocessed examples:
{
  story_id: 
    13,
  worker_id: 
    A32W24TWSWXW,
  type: 
    None,
  idx: 
    None,
  aug: 
    False,
  actor: 
    John,
  location: 
    kitchen,
  objects: 
    cabinet, counter, knife, pan, potato, pizza,
  sentences: 
    [
      John was getting the snacks ready for the party.
      John opened the cabinet, took out a pan and put it on the counter.
      John opened the fridge and got out the pizza.
      John put the pizza on the pan and put them into the oven.
      John took a knife and cut the hot pizza in eight slices.
    ],
  length: 
    5,
  example_id: 
    13,
  plausible: 
    True,
  breakpoint: 
    -1,
  confl_sents: 
    [],
  confl_pairs: 
    [],
  states: 
    [
      {'h_location': [['John', 0]], 'conscious': [['John', 2]], 'wearing': [['John', 0]], 'h_wet': [['John', 0]], 'hygiene': [['John', 0]], 'location': [['snacks', 0], ['party', 0]], 'exist': [['snacks', 4], ['party', 2]], 'clean': [['snacks', 0], ['party', 0]], 'power': 

In [59]:
cloze_dataset = {p: [] for p in dataset}
order_dataset = {p: [] for p in dataset}

for p in dataset:
  for exid in dataset[p]:
    ex = dataset[p][exid]

    if ex['type'] == None:
      continue

    ex_plaus = dataset[p][str(ex['story_id'])]

    if ex['type'] == 'cloze':
      cloze_dataset[p].append(ex)
      cloze_dataset[p].append(ex_plaus) # For every implausible story, add a copy of its corresponding plausible story

    # Exclude augmented ordering examples from dev and test, since the breakpoints aren't always accurate in those
    elif ex['type'] == 'order' and not (p != 'train' and ex['aug']):
      order_dataset[p].append(ex)
      order_dataset[p].append(ex_plaus)


In [60]:
### Convert TRIP to Two-Story Classification Task

##Ready the TRIP dataset for two-story classification."""

from www.utils import print_dict
import json
from collections import Counter

#data_file = os.path.join(DRIVE_PATH, 'all_data/www_2s_new.json')
with open(os.path.join(DRIVE_PATH, 'all_data/www_2s_new.json'), 'r') as f:
  cloze_dataset_2s, order_dataset_2s = json.load(f)

for p in cloze_dataset_2s:
  label_dist = Counter([ex['label'] for ex in cloze_dataset_2s[p]])
  print('Cloze label distribution (%s):' % p)
  print(label_dist.most_common())
print_dict(cloze_dataset_2s['train'][0])


Cloze label distribution (train):
[(1, 400), (0, 399)]
Cloze label distribution (dev):
[(0, 161), (1, 161)]
Cloze label distribution (test):
[(1, 176), (0, 175)]
{
  example_id: 
    0-C0,
  stories: 
    [
      {'story_id': 0, 'worker_id': 'A1F01FVEPYCPHO', 'type': 'cloze', 'idx': 0, 'aug': False, 'actor': 'Tom', 'location': 'kitchen', 'objects': 'dustbin, microwave, pan, plate, cereal, soup', 'sentences': ['Tom bought a new dustbin for the kitchen.', 'Tom threw a broken plate in the dustbin.', 'Tom got some soup from the fridge.', 'Tom put the soup in the microwave.', 'Tom ate the cold soup.'], 'length': 5, 'example_id': '0-C0', 'plausible': False, 'breakpoint': 4, 'confl_sents': [3], 'confl_pairs': [[3, 4]], 'states': [{'h_location': [['Tom', 0]], 'conscious': [['Tom', 2]], 'wearing': [['Tom', 0]], 'h_wet': [['Tom', 0]], 'hygiene': [['Tom', 0]], 'location': [['dustbin', 6]], 'exist': [['dustbin', 4]], 'clean': [['dustbin', 0]], 'power': [['dustbin', 0]], 'functional': [['dustbin', 

In [None]:
from www.dataset.prepro import get_tiered_data
from www.dataset.featurize import add_bert_features_tiered, get_tensor_dataset_tiered
from collections import Counter
import numpy as np
from www.dataset.ann import att_to_num_classes, idx_to_att
from sklearn.metrics import accuracy_score, f1_score
from www.utils import print_dict

tiered_dataset = cloze_dataset_2s

seq_length = 16 # Max sequence length to pad to

tiered_dataset = get_tiered_data(tiered_dataset)
tiered_dataset = add_bert_features_tiered(tiered_dataset, tokenizer, seq_length, add_segment_ids=True)

from www.dataset.prepro import get_tiered_data, balance_labels
from www.dataset.featurize import add_bert_features_tiered, get_tensor_dataset_tiered
from collections import Counter
import numpy as np
from www.dataset.ann import att_to_num_classes, idx_to_att, att_default_values
from sklearn.metrics import accuracy_score, f1_score
from www.utils import print_dict
import numpy as np


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens ar

In [61]:
# Have to add BERT input IDs and tensorize again
num_runs = 10
stories = []
pred_stories = []
conflicts = []
pred_conflicts = []
preconditions = []
pred_preconditions = []
effects = []
pred_effects = []
verifiability = []
consistency = []
for p in tiered_dataset:
  if p == 'train':
    continue
  metr_avg = {}
  print('starting %s...' % p)
  for r in range(num_runs):
    print('starting run %s...' % str(r))
    for ex in tiered_dataset[p]:
      verifiable = True
      consistent = True

      stories.append(ex['label'])
      pred_stories.append(np.random.randint(2))

      if stories[-1] != pred_stories[-1]:
        verifiable = False

      labels_ex_p = []
      preds_ex_p = []

      labels_ex_e = []
      preds_ex_e = []

      labels_ex_c = []
      preds_ex_c = []

      for si, story in enumerate(ex['stories']):
        labels_story_p = []
        preds_story_p = []

        labels_story_e = []
        preds_story_e = []

        for ent_ann in story['entities']:
          entity = ent_ann['entity']

          if si == 1 - ex['label']:
            labels_ex_c.append(ent_ann['conflict_span_onehot'])
            pred = np.zeros(ent_ann['conflict_span_onehot'].shape)
            for cs in np.random.choice(len(pred), size=2, replace=False):
              pred[cs] = 1
            preds_ex_c.append(pred)

          labels_ent = []
          preds_ent = []
          for s, sent_ann in enumerate(ent_ann['preconditions']):
            if s < len(story['sentences']):
              if entity in story['sentences'][s]:

                labels_ent.append(sent_ann)
                sent_ann_pred = []
                for i, l in enumerate(sent_ann):
                  pl = np.random.randint(att_to_num_classes[idx_to_att[i]])
                  if pl > 0 and pl != att_default_values[idx_to_att[i]]:
                    if pl != l:
                      verifiable = False
                  sent_ann_pred.append(pl)
                preds_ent.append(sent_ann_pred)

          labels_story_p.append(labels_ent)
          preds_story_p.append(preds_ent)

          labels_ent = []
          preds_ent = []
          for s, sent_ann in enumerate(ent_ann['effects']):
            if s < len(story['sentences']):
              if entity in story['sentences'][s]:

                labels_ent.append(sent_ann)
                sent_ann_pred = []
                for i, l in enumerate(sent_ann):
                  pl = np.random.randint(att_to_num_classes[idx_to_att[i]])
                  if pl > 0 and pl != att_default_values[idx_to_att[i]]:
                    if pl != l:
                      verifiable = False
                  sent_ann_pred.append(pl)
                preds_ent.append(sent_ann_pred)

          labels_story_e.append(labels_ent)
          preds_story_e.append(preds_ent)

        labels_ex_p.append(labels_story_p)
        preds_ex_p.append(preds_story_p)

        labels_ex_e.append(labels_story_e)
        preds_ex_e.append(preds_story_e)

      conflicts.append(labels_ex_c)
      pred_conflicts.append(preds_ex_c)

      preconditions.append(labels_ex_p)
      pred_preconditions.append(preds_ex_p)

      effects.append(labels_ex_e)
      pred_effects.append(preds_ex_e)

      p_confl = np.nonzero(np.sum(np.array(preds_ex_c), axis=0))[0]
      l_confl = np.nonzero(np.sum(np.array(labels_ex_c), axis=0))[0]
      assert len(l_confl) == 2, str(labels_ex_c)
      if not (p_confl[0] == l_confl[0] and p_confl[1] == l_confl[1]):
        verifiable = False
        consistent = False

      verifiability.append(1 if verifiable else 0)
      consistency.append(1 if consistent else 0)

    # Compute metrics
    metr = {}
    metr['story_accuracy'] = accuracy_score(stories, pred_stories)

    conflicts_flat = [c for c_ex in conflicts for c_ent in c_ex for c in c_ent]
    pred_conflicts_flat = [c for c_ex in pred_conflicts for c_ent in c_ex for c in c_ent]
    metr['confl_f1'] = f1_score(conflicts_flat, pred_conflicts_flat, average='macro')

    preconditions_flat = [p for p_ex in preconditions for p_story in p_ex for p_sent in p_story for p_ent in p_sent for p in p_ent]
    pred_preconditions_flat = [p for p_ex in pred_preconditions for p_story in p_ex for p_sent in p_story for p_ent in p_sent for p in p_ent]
    metr['precondition_f1'] = f1_score(preconditions_flat, pred_preconditions_flat, average='macro')

    effects_flat = [p for p_ex in effects for p_story in p_ex for p_sent in p_story for p_ent in p_sent for p in p_ent]
    pred_effects_flat = [p for p_ex in pred_effects for p_story in p_ex for p_sent in p_story for p_ent in p_sent for p in p_ent]
    metr['effect_f1'] = f1_score(effects_flat, pred_effects_flat, average='macro')

    metr['verifiability'] = np.mean(verifiability)
    metr['consistency'] = np.mean(consistency)

    for k in metr:
      if k not in metr_avg:
        metr_avg[k] = []
      metr_avg[k].append(metr[k])

  for k in metr_avg:
    metr_avg[k] = (np.mean(metr_avg[k]), np.var(metr_avg[k]) ** 0.5)
  print('RANDOM BASELINE (%s, %s runs)' % (str(p), str(num_runs)))
  print_dict(metr_avg)

starting dev...
starting run 0...
starting run 1...
starting run 2...


Exception ignored in: <function _xla_gc_callback at 0x7afb9aba17e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 96, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 


starting run 3...
starting run 4...
starting run 5...
starting run 6...
starting run 7...
starting run 8...
starting run 9...
RANDOM BASELINE (dev, 10 runs)
{
  story_accuracy: 
    (0.49097604259094946, 0.009189092067337335),
  confl_f1: 
    (0.486331025330813, 0.001570803373304592),
  precondition_f1: 
    (0.040369643573463114, 4.416452682427279e-05),
  effect_f1: 
    (0.04021298694377605, 9.283825334391873e-05),
  verifiability: 
    (0.0, 0.0),
  consistency: 
    (0.11711130829143253, 0.0017785358943156046),
}


starting test...
starting run 0...
starting run 1...
starting run 2...
starting run 3...
starting run 4...
starting run 5...


KeyboardInterrupt: 

In [None]:
from www.dataset.prepro import get_tiered_data, balance_labels
from www.dataset.featurize import add_bert_features_tiered, get_tensor_dataset_tiered
from collections import Counter

tiered_dataset = cloze_dataset_2s

# Debug the code on a small amount of data
if debug:
  for k in tiered_dataset:
    tiered_dataset[k] = tiered_dataset[k][:20]

# train_spans = True
train_spans = False
if train_spans:
  tiered_dataset = get_story_spans_2s(tiered_dataset, train_only=True)
  tiered_dataset['train'] = [ex for ex in tiered_dataset['train'] if ex['label'] != -1] # For now, ignore examples where both stories are plausible :(

seq_length = 16 # Max sequence length to pad to

tiered_dataset = get_tiered_data(tiered_dataset)
tiered_dataset = add_bert_features_tiered(tiered_dataset, tokenizer, seq_length, add_segment_ids=True)

tiered_tensor_dataset = {}
max_story_length = max([len(ex['stories'][0]['sentences']) for p in tiered_dataset for ex in tiered_dataset[p]])
for p in tiered_dataset:
  tiered_tensor_dataset[p] = get_tensor_dataset_tiered(tiered_dataset[p], max_story_length, add_segment_ids=True)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens ar

In [65]:
from www.dataset.ann import att_to_idx, att_to_num_classes, att_types

subtask = 'cloze'
batch_sizes = [config_batch_size]
learning_rates = [config_lr]
epochs = config_epochs
eval_batch_size = 16
generate_learning_curve = False # Generate data for training curve figure in TRIP paper

num_state_labels = {}
for att in att_to_idx:
  if att_types[att] == 'default':
    num_state_labels[att_to_idx[att]] = 3
  else:
    num_state_labels[att_to_idx[att]] = att_to_num_classes[att] # Location attributes fall into this since they don't have well-define pre- and post-condition yet

# Ablation options:
# - attributes: skip attribute prediction phase
# - embeddings: DON'T input contextual embeddings to conflict detector
# - states: DON'T input states to conflict detector
# - states-labels: in states input to conflict detector, include predicted labels
# - states-logits: in states input to conflict detector, include state logits (preferred)
# - states-teacher-forcing: train conflict detector on ground truth state labels (not predictions)
# - states-attention: re-weight input to conflict detector with weights conditioned on states representation
ablation = ['attributes', 'states-logits'] # This is the default mode presented in the paper


In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup
from www.model.train import train_epoch_tiered
from www.model.eval import evaluate_tiered, save_results, save_preds, add_entity_attribute_labels
from sklearn.metrics import accuracy_score, f1_score
from www.utils import print_dict, get_model_dir
from www.model.transformers_ext import TieredModelPipeline
from www.dataset.ann import att_to_num_classes
import shutil
import pandas as pd

seed_val = 22 # Save random seed for reproducibility
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# We'll keep the validation data here with a constant eval batch size
dev_sampler = SequentialSampler(tiered_tensor_dataset['dev'])
dev_dataloader = DataLoader(tiered_tensor_dataset['dev'], sampler=dev_sampler, batch_size=eval_batch_size)
dev_dataset_name = subtask + '_%s_dev'
dev_ids = [ex['example_id'] for ex in tiered_dataset['dev']]

all_losses = []
param_combos = []
combo_names = []
all_val_objs = []
output_dirs = []
best_obj = 0.0
best_model = '<none>'
best_dir = ''
best_obj2 = 0.0
best_model2 = '<none>'
best_dir2 = ''


In [None]:
config_class = BertConfig
model_name = 'finetunedBert_based_PIQA'
emb_class = BertModel
tokenizer_class = BertTokenizer

In [None]:
tokenizer = tokenizer_class.from_pretrained('drive/My Drive/Colab Notebooks/EECS595/Project/Verifiable-Coherent-NLU-main/saved_models/finetunedBert_based_PIQA/',
                                                do_lower_case = False,
                                                cache_dir=os.path.join(DRIVE_PATH, 'cache'))


In [None]:
print('Beginning grid search for the %s sub-task over %s parameter combination(s)!' % (subtask, str(len(batch_sizes) * len(learning_rates))))
for bs in batch_sizes:
  for lr in learning_rates:
    print('\nTRAINING MODEL: bs=%s, lr=%s' % (str(bs), str(lr)))

    loss_values = []
    obj_values = []

    # Set up training dataset with new batch size
    train_sampler = RandomSampler(tiered_tensor_dataset['train'])
    train_dataloader = DataLoader(tiered_tensor_dataset['train'], sampler=train_sampler, batch_size=bs)

    # Set up model
    config = config_class.from_pretrained('drive/My Drive/Colab Notebooks/EECS595/Project/Verifiable-Coherent-NLU-main/saved_models/finetunedBert_based_PIQA/')
    emb = emb_class.from_pretrained('drive/My Drive/Colab Notebooks/EECS595/Project/Verifiable-Coherent-NLU-main/saved_models/finetunedBert_based_PIQA/',
                                          config=config,
                                          cache_dir=os.path.join(DRIVE_PATH, 'cache'))
    if torch.cuda.is_available():
      emb.cuda()
    device = emb.device
    max_story_length = max([len(ex['stories'][0]['sentences']) for p in tiered_dataset for ex in tiered_dataset[p]])
    model = TieredModelPipeline(emb, max_story_length, len(att_to_num_classes), num_state_labels,
                                config_class, 'drive/My Drive/Colab Notebooks/EECS595/Project/Verifiable-Coherent-NLU-main/saved_models/finetunedBert_based_PIQA/', device,
                                ablation=ablation, loss_weights=loss_weights).to(device)

    # Set up optimizer
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps = total_steps)

    train_lc_data = []
    val_lc_data = []
    for epoch in range(epochs):
      # Train the model for one epoch
      print('[%s] Beginning epoch...' % str(epoch))

      epoch_loss, _ = train_epoch_tiered(model, optimizer, train_dataloader, device, seg_mode=False,
                                         build_learning_curves=generate_learning_curve, val_dataloader=dev_dataloader,
                                         train_lc_data=train_lc_data, val_lc_data=val_lc_data)

      # Save loss
      loss_values.append(epoch_loss)

      # Validate on dev set
      validation_results = evaluate_tiered(model, dev_dataloader, device, [(accuracy_score, 'accuracy'), (f1_score, 'f1')], seg_mode=False, return_explanations=True)
      metr_attr, all_pred_atts, all_atts, \
      metr_prec, all_pred_prec, all_prec, \
      metr_eff, all_pred_eff, all_eff, \
      metr_conflicts, all_pred_conflicts, all_conflicts, \
      metr_stories, all_pred_stories, all_stories, explanations = validation_results[:16]
      explanations = add_entity_attribute_labels(explanations, tiered_dataset['dev'], list(att_to_num_classes.keys()))

      print('[%s] Validation results:' % str(epoch))
      print('[%s] Preconditions:' % str(epoch))
      print_dict(metr_prec)
      print('[%s] Effects:' % str(epoch))
      print_dict(metr_eff)
      print('[%s] Conflicts:' % str(epoch))
      print_dict(metr_conflicts)
      print('[%s] Stories:' % str(epoch))
      print_dict(metr_stories)

      # Save accuracy - want to maximize verifiability of tiered predictions
      ver = metr_stories['verifiability']
      acc = metr_stories['accuracy']
      obj_values.append(ver)

      # Save model checkpoint
      print('[%s] Saving model checkpoint...' % str(epoch))
      model_param_str = get_model_dir(model_name.replace('/', '-'), subtask, bs, lr, epoch) + '_' +  '-'.join([str(lw) for lw in loss_weights]) +  '_tiered_pipeline_lc'
      if train_spans:
        model_param_str += 'spans'
      if len(model.ablation) > 0:
        model_param_str += '_ablate_'
        model_param_str += '_'.join(model.ablation)
      output_dir = os.path.join(DRIVE_PATH, 'saved_models', model_param_str)
      output_dirs.append(output_dir)
      if not os.path.exists(output_dir):
        os.makedirs(output_dir)

      save_results(metr_attr, output_dir, dev_dataset_name % 'attributes')
      save_results(metr_prec, output_dir, dev_dataset_name % 'preconditions')
      save_results(metr_eff, output_dir, dev_dataset_name % 'effects')
      save_results(metr_conflicts, output_dir, dev_dataset_name % 'conflicts')
      save_results(metr_stories, output_dir, dev_dataset_name % 'stories')
      save_results(explanations, output_dir, dev_dataset_name % 'explanations')

      # Just save story preds
      save_preds(dev_ids, all_stories, all_pred_stories, output_dir, dev_dataset_name % 'stories')

      emb = emb.module if hasattr(emb, 'module') else emb
      emb.save_pretrained(output_dir)
      torch.save(model, os.path.join(output_dir, 'classifiers.pth'))
      tokenizer.save_vocabulary(output_dir)

      if ver > best_obj:
        best_obj = ver
        best_model = model_param_str
        best_dir = output_dir
      if acc > best_obj2:
        best_obj2 = acc
        best_model2 = model_param_str
        best_dir2 = output_dir

      for od in output_dirs:
        if od != best_dir and od != best_dir2 and os.path.exists(od):
          shutil.rmtree(od)

      print('[%s] Finished epoch.' % str(epoch))

    all_losses.append(loss_values)
    all_val_objs.append(obj_values)
    param_combos.append((bs, lr))
    combo_names.append('bs=%s, lr=%s' % (str(bs), str(lr)))

print('Finished grid search! :)')
print('Best validation *verifiability* %s from model %s.' % (str(best_obj), best_model))
print('Best validation *accuracy* %s from model %s.' % (str(best_obj2), best_model2))

if generate_learning_curve:
  print('Saving learning curve data...')
  train_lc_data = [subrecord for record in train_lc_data for subrecord in record] # flatten
  val_lc_data = [subrecord for record in val_lc_data for subrecord in record] # flatten

  train_lc_data = pd.DataFrame(train_lc_data)
  print(os.path.join(best_dir if best_dir != '<none>' else best_dir2, 'learning_curve_data_train.csv'))
  train_lc_data.to_csv(os.path.join(best_dir if best_dir != '' else best_dir2, 'learning_curve_data_train.csv'), index=False)
  val_lc_data = pd.DataFrame(val_lc_data)
  val_lc_data.to_csv(os.path.join(best_dir if best_dir != '' else best_dir2, 'learning_curve_data_val.csv'), index=False)
  print('Learning curve data saved. %s rows saved for training, %s rows saved for validation.' % (str(len(train_lc_data.index)), str(len(val_lc_data.index))))

"""Delete all non-best model checkpoints:"""

Beginning grid search for the cloze sub-task over 1 parameter combination(s)!

TRAINING MODEL: bs=1, lr=1e-05


[                                                                        ] [38;2;255;0;0m  0%[39m

[0] Beginning epoch...


[########################################################################] [38;2;0;255;0m100%[39m
[                                                                        ] [38;2;255;0;0m  0%[39m

	Beginning evaluation...
		Running prediction...


[########################################################################] [38;2;0;255;0m100%[39m


		Computing metrics...
	Finished evaluation in 0:00:18s.
[0] Validation results:
[0] Preconditions:
{
  accuracy: 
    0.9928688180077523,
  f1: 
    0.19878184321809625,
  accuracy_0: 
    0.9935786671648065,
  f1_0: 
    0.33469717693475126,
  accuracy_1: 
    0.9981436510530985,
  f1_1: 
    0.6583577370153403,
  accuracy_2: 
    0.9988791855414935,
  f1_2: 
    0.33314642617946205,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9784009713725307,
  f1_5: 
    0.10989806476487228,
  accuracy_6: 
    0.9838999673095783,
  f1_6: 
    0.6262149592171432,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.9921075981880166,
  f1_8: 
    0.3320127216322448,
  accuracy_9: 
    0.9851959090272264,
  f1_9: 
    0.6291960609673763,
  accuracy_10: 
    0.9950614112922057,
  f1_10: 
    0.33250819771263823,
  accuracy_11: 
    0.997151263

[                                                                        ] [38;2;255;0;0m  0%[39m

[0] Finished epoch.
[1] Beginning epoch...


[########################################################################] [38;2;0;255;0m100%[39m
[                                                                        ] [38;2;255;0;0m  0%[39m

	Beginning evaluation...
		Running prediction...


[########################################################################] [38;2;0;255;0m100%[39m


		Computing metrics...
	Finished evaluation in 0:00:19s.
[1] Validation results:
[1] Preconditions:
{
  accuracy: 
    0.9941186428804931,
  f1: 
    0.25324120289756075,
  accuracy_0: 
    0.996310652407416,
  f1_0: 
    0.5622022776530269,
  accuracy_1: 
    0.997653294727502,
  f1_1: 
    0.65622431825718,
  accuracy_2: 
    0.9990192873488069,
  f1_2: 
    0.4581697781100805,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9805842245364965,
  f1_5: 
    0.1660504111454853,
  accuracy_6: 
    0.9859781441180592,
  f1_6: 
    0.6310173828967268,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.9958202960818194,
  f1_8: 
    0.5690753656672971,
  accuracy_9: 
    0.9871690094802223,
  f1_9: 
    0.6336547061403269,
  accuracy_10: 
    0.9950730864428151,
  f1_10: 
    0.33430708876370846,
  accuracy_11: 
    0.997151263251296,

[                                                                        ] [38;2;255;0;0m  0%[39m

[1] Finished epoch.
[2] Beginning epoch...


[########################################################################] [38;2;0;255;0m100%[39m
[                                                                        ] [38;2;255;0;0m  0%[39m

	Beginning evaluation...
		Running prediction...


[########################################################################] [38;2;0;255;0m100%[39m


		Computing metrics...
	Finished evaluation in 0:00:19s.
[2] Validation results:
[2] Preconditions:
{
  accuracy: 
    0.994885116518003,
  f1: 
    0.4414912397863652,
  accuracy_0: 
    0.9957735954793817,
  f1_0: 
    0.5327037792338554,
  accuracy_1: 
    0.9980152243963947,
  f1_1: 
    0.657775902920717,
  accuracy_2: 
    0.9990192873488069,
  f1_2: 
    0.4581697781100805,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9854761126418531,
  f1_5: 
    0.36682332242814647,
  accuracy_6: 
    0.9871456591790034,
  f1_6: 
    0.6339780828458589,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.996544155419605,
  f1_8: 
    0.5921777685003736,
  accuracy_9: 
    0.9883131742399477,
  f1_9: 
    0.6364922069193585,
  accuracy_10: 
    0.9963923784616822,
  f1_10: 
    0.5080386164155419,
  accuracy_11: 
    0.9982720777098024

[                                                                        ] [38;2;255;0;0m  0%[39m

[2] Finished epoch.
[3] Beginning epoch...


[########################################################################] [38;2;0;255;0m100%[39m
[                                                                        ] [38;2;255;0;0m  0%[39m

	Beginning evaluation...
		Running prediction...


[########################################################################] [38;2;0;255;0m100%[39m


		Computing metrics...
	Finished evaluation in 0:00:19s.
[3] Validation results:
[3] Preconditions:
{
  accuracy: 
    0.9951548124970813,
  f1: 
    0.48793826838032417,
  accuracy_0: 
    0.9977700462335964,
  f1_0: 
    0.6052897013833203,
  accuracy_1: 
    0.9981670013543175,
  f1_1: 
    0.6584361076058844,
  accuracy_2: 
    0.9991010134030729,
  f1_2: 
    0.5482371457517029,
  accuracy_3: 
    0.9985989819268668,
  f1_3: 
    0.3330996666355111,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9862583477326857,
  f1_5: 
    0.41247419561389076,
  accuracy_6: 
    0.9874492130948489,
  f1_6: 
    0.634792336312442,
  accuracy_7: 
    0.9982370522579741,
  f1_7: 
    0.3330392494824319,
  accuracy_8: 
    0.9967893335824032,
  f1_8: 
    0.5996815481628807,
  accuracy_9: 
    0.9888385560173726,
  f1_9: 
    0.6377912366935228,
  accuracy_10: 
    0.9964040536122916,
  f1_10: 
    0.5236865000521495,
  accuracy_11: 
    0.9978517722878

[                                                                        ] [38;2;255;0;0m  0%[39m

[3] Finished epoch.
[4] Beginning epoch...


[########################################################################] [38;2;0;255;0m100%[39m
[                                                                        ] [38;2;255;0;0m  0%[39m

	Beginning evaluation...
		Running prediction...


[########################################################################] [38;2;0;255;0m100%[39m


		Computing metrics...
	Finished evaluation in 0:00:19s.
[4] Validation results:
[4] Preconditions:
{
  accuracy: 
    0.9953871479942091,
  f1: 
    0.5434656893793073,
  accuracy_0: 
    0.9978167468360342,
  f1_0: 
    0.6051432189496594,
  accuracy_1: 
    0.9978050716854248,
  f1_1: 
    0.6568706004306241,
  accuracy_2: 
    0.9991944146079484,
  f1_2: 
    0.5878806095779295,
  accuracy_3: 
    0.9988441600896651,
  f1_3: 
    0.4943126935753761,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9864334749918274,
  f1_5: 
    0.46860784946248796,
  accuracy_6: 
    0.987741091860085,
  f1_6: 
    0.6446404388920648,
  accuracy_7: 
    0.9984238546677252,
  f1_7: 
    0.48273023161054046,
  accuracy_8: 
    0.996836034184841,
  f1_8: 
    0.5997911218132056,
  accuracy_9: 
    0.9888502311679821,
  f1_9: 
    0.6377235927172924,
  accuracy_10: 
    0.9966842572269182,
  f1_10: 
    0.5497252528188531,
  accuracy_11: 
    0.99778172138420

[                                                                        ] [38;2;255;0;0m  0%[39m

[4] Finished epoch.
[5] Beginning epoch...


[########################################################################] [38;2;0;255;0m100%[39m
[                                                                        ] [38;2;255;0;0m  0%[39m

	Beginning evaluation...
		Running prediction...


[########################################################################] [38;2;0;255;0m100%[39m


		Computing metrics...
	Finished evaluation in 0:00:19s.
[5] Validation results:
[5] Preconditions:
{
  accuracy: 
    0.9953445336944846,
  f1: 
    0.5734651061332562,
  accuracy_0: 
    0.9975131929201887,
  f1_0: 
    0.594182266780768,
  accuracy_1: 
    0.9980852753000514,
  f1_1: 
    0.6580775222125895,
  accuracy_2: 
    0.9987390837341802,
  f1_2: 
    0.5917220766456168,
  accuracy_3: 
    0.9991010134030729,
  f1_3: 
    0.6065646299531879,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9861882968290291,
  f1_5: 
    0.49146720467006577,
  accuracy_6: 
    0.9876710409564283,
  f1_6: 
    0.707937529915407,
  accuracy_7: 
    0.9984005043665063,
  f1_7: 
    0.4840137515973017,
  accuracy_8: 
    0.9965908560220427,
  f1_8: 
    0.5952339242787504,
  accuracy_9: 
    0.9885233269509177,
  f1_9: 
    0.6370008486713769,
  accuracy_10: 
    0.9964274039135105,
  f1_10: 
    0.5372544888536513,
  accuracy_11: 
    0.997478167468360

[                                                                        ] [38;2;255;0;0m  0%[39m

[5] Finished epoch.
[6] Beginning epoch...


[########################################################################] [38;2;0;255;0m100%[39m
[                                                                        ] [38;2;255;0;0m  0%[39m

	Beginning evaluation...
		Running prediction...


[########################################################################] [38;2;0;255;0m100%[39m


		Computing metrics...
	Finished evaluation in 0:00:19s.
[6] Validation results:
[6] Preconditions:
{
  accuracy: 
    0.9953906505393919,
  f1: 
    0.5882588490106025,
  accuracy_0: 
    0.9977700462335964,
  f1_0: 
    0.6041926375474993,
  accuracy_1: 
    0.9979918740951759,
  f1_1: 
    0.6576946296238325,
  accuracy_2: 
    0.9987974594872274,
  f1_2: 
    0.621614586404825,
  accuracy_3: 
    0.9991944146079484,
  f1_3: 
    0.6994927156808258,
  accuracy_4: 
    0.99978984728903,
  f1_4: 
    0.5110760787061365,
  accuracy_5: 
    0.9859898192686686,
  f1_5: 
    0.5087401375455765,
  accuracy_6: 
    0.987530939149115,
  f1_6: 
    0.709863127150177,
  accuracy_7: 
    0.9984005043665063,
  f1_7: 
    0.4840137515973017,
  accuracy_8: 
    0.9969878111427637,
  f1_8: 
    0.7583443913329954,
  accuracy_9: 
    0.9886167281557932,
  f1_9: 
    0.6372085320567987,
  accuracy_10: 
    0.9964624293653388,
  f1_10: 
    0.5353595055716113,
  accuracy_11: 
    0.9975015177695792,
 

[                                                                        ] [38;2;255;0;0m  0%[39m

[6] Finished epoch.
[7] Beginning epoch...


[########################################################################] [38;2;0;255;0m100%[39m
[                                                                        ] [38;2;255;0;0m  0%[39m

	Beginning evaluation...
		Running prediction...


[########################################################################] [38;2;0;255;0m100%[39m


		Computing metrics...
	Finished evaluation in 0:00:19s.
[7] Validation results:
[7] Preconditions:
{
  accuracy: 
    0.9953748890860692,
  f1: 
    0.5854689210162064,
  accuracy_0: 
    0.997548218372017,
  f1_0: 
    0.5952592581100821,
  accuracy_1: 
    0.9979101480409097,
  f1_1: 
    0.6573278084861895,
  accuracy_2: 
    0.9991944146079484,
  f1_2: 
    0.6787837802410128,
  accuracy_3: 
    0.9991010134030729,
  f1_3: 
    0.6065646299531879,
  accuracy_4: 
    0.9996964460841545,
  f1_4: 
    0.3332827333341118,
  accuracy_5: 
    0.9864451501424368,
  f1_5: 
    0.5074390728307672,
  accuracy_6: 
    0.9865502264979218,
  f1_6: 
    0.6980571708865586,
  accuracy_7: 
    0.9984238546677252,
  f1_7: 
    0.4797409270914294,
  accuracy_8: 
    0.9971045626488582,
  f1_8: 
    0.7601628827744259,
  accuracy_9: 
    0.9882197730350721,
  f1_9: 
    0.6995430535413649,
  accuracy_10: 
    0.9964857796665577,
  f1_10: 
    0.6130715347579127,
  accuracy_11: 
    0.997548218372017

[                                                                        ] [38;2;255;0;0m  0%[39m

[7] Finished epoch.
[8] Beginning epoch...


[########################################################################] [38;2;0;255;0m100%[39m
[                                                                        ] [38;2;255;0;0m  0%[39m

	Beginning evaluation...
		Running prediction...


[########################################################################] [38;2;0;255;0m100%[39m


		Computing metrics...
	Finished evaluation in 0:00:19s.
[8] Validation results:
[8] Preconditions:
{
  accuracy: 
    0.9954595339279877,
  f1: 
    0.6015819380873992,
  accuracy_0: 
    0.9978284219866437,
  f1_0: 
    0.6045633021109177,
  accuracy_1: 
    0.9980969504506608,
  f1_1: 
    0.6581137914061923,
  accuracy_2: 
    0.9994045673189185,
  f1_2: 
    0.8142945910025068,
  accuracy_3: 
    0.9992527903609957,
  f1_3: 
    0.709134575838258,
  accuracy_4: 
    0.99978984728903,
  f1_4: 
    0.5110760787061365,
  accuracy_5: 
    0.9863050483351236,
  f1_5: 
    0.5273318266708531,
  accuracy_6: 
    0.9876476906552094,
  f1_6: 
    0.7569316422695266,
  accuracy_7: 
    0.9984588801195535,
  f1_7: 
    0.49489239746970654,
  accuracy_8: 
    0.9967426329799655,
  f1_8: 
    0.7366543766648898,
  accuracy_9: 
    0.9886517536076216,
  f1_9: 
    0.699176732636861,
  accuracy_10: 
    0.9963690281604632,
  f1_10: 
    0.639983186353025,
  accuracy_11: 
    0.9977700462335964,


[                                                                        ] [38;2;255;0;0m  0%[39m

[8] Finished epoch.
[9] Beginning epoch...


[########################################################################] [38;2;0;255;0m100%[39m
[                                                                        ] [38;2;255;0;0m  0%[39m

	Beginning evaluation...
		Running prediction...


[########################################################################] [38;2;0;255;0m100%[39m


		Computing metrics...
	Finished evaluation in 0:00:19s.
[9] Validation results:
[9] Preconditions:
{
  accuracy: 
    0.9954245084761594,
  f1: 
    0.5901628155361703,
  accuracy_0: 
    0.9976883201793303,
  f1_0: 
    0.6905528400771864,
  accuracy_1: 
    0.9980619249988325,
  f1_1: 
    0.6580003696935944,
  accuracy_2: 
    0.9995096436744034,
  f1_2: 
    0.8492688363751114,
  accuracy_3: 
    0.9991593891561201,
  f1_3: 
    0.6202301080798872,
  accuracy_4: 
    0.99978984728903,
  f1_4: 
    0.5110760787061365,
  accuracy_5: 
    0.9866669780040163,
  f1_5: 
    0.5191523208251634,
  accuracy_6: 
    0.9874725633960678,
  f1_6: 
    0.761541741051832,
  accuracy_7: 
    0.9985522813244291,
  f1_7: 
    0.533685639953874,
  accuracy_8: 
    0.9967426329799655,
  f1_8: 
    0.7544212768117351,
  accuracy_9: 
    0.9885817027039648,
  f1_9: 
    0.7008996435834747,
  accuracy_10: 
    0.9964857796665577,
  f1_10: 
    0.674588793880202,
  accuracy_11: 
    0.9973847662634848,
 

'Delete all non-best model checkpoints:'

In [75]:
probe_model = "bert-base-uncased_cloze_1_1e-05_5_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits" #eval_model_dir "finetunedBert_based_PIQA" #
probe_model = os.path.join(DRIVE_PATH, 'saved_models', probe_model)


In [76]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from www.model.eval import evaluate_tiered, save_results, save_preds, list_comparison, add_entity_attribute_labels
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
metrics = [(accuracy_score, 'accuracy'), (precision_score, 'precision'), (recall_score, 'recall'), (f1_score, 'f1')]
import numpy as np
from www.utils import print_dict

print('Testing model: %s.' % probe_model)

# May alter this depending on which partition(s) you want to run inference on
for p in tiered_dataset:
  if p != 'test':
    continue

  p_dataset = tiered_dataset[p]
  p_tensor_dataset = tiered_tensor_dataset[p]
  p_sampler = SequentialSampler(p_tensor_dataset)
  p_dataloader = DataLoader(p_tensor_dataset, sampler=p_sampler, batch_size=16)
  dev_dataset_name = subtask + '_%s_' + p
  p_ids = [ex['example_id'] for ex in tiered_dataset[p]]

  # Get preds and metrics on this partition
  metr_attr, all_pred_atts, all_atts, \
  metr_prec, all_pred_prec, all_prec, \
  metr_eff, all_pred_eff, all_eff, \
  metr_conflicts, all_pred_conflicts, all_conflicts, \
  metr_stories, all_pred_stories, all_stories, explanations = evaluate_tiered(model, p_dataloader, device, [(accuracy_score, 'accuracy'), (f1_score, 'f1')], seg_mode=False, return_explanations=True)
  explanations = add_entity_attribute_labels(explanations, tiered_dataset[p], list(att_to_num_classes.keys()))

  save_results(metr_attr, probe_model, dev_dataset_name % 'attributes')
  save_results(metr_prec, probe_model, dev_dataset_name % 'preconditions')
  save_results(metr_eff, probe_model, dev_dataset_name % 'effects')
  save_results(metr_conflicts, probe_model, dev_dataset_name % 'conflicts')
  save_results(metr_stories, probe_model, dev_dataset_name % 'stories')
  save_results(explanations, probe_model, dev_dataset_name % 'explanations')

  print('\nPARTITION: %s' % p)
  print('Stories:')
  print_dict(metr_stories)
  print('Conflicts:')
  print_dict(metr_conflicts)
  print('Preconditions:')
  print_dict(metr_prec)
  print('Effects:')
  print_dict(metr_eff)

[                                                                        ] [38;2;255;0;0m  0%[39m

Testing model: drive/My Drive/Colab Notebooks/EECS595/Project/Verifiable-Coherent-NLU-main/saved_models/bert-base-uncased_cloze_1_1e-05_5_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits.
	Beginning evaluation...
		Running prediction...


[########################################################################] [38;2;0;255;0m100%[39m


		Computing metrics...
	Finished evaluation in 0:00:29s.

PARTITION: test
Stories:
{
  accuracy: 
    0.7037037037037037,
  f1: 
    0.7035858126542809,
  verifiability: 
    0.05413105413105413,
}


Conflicts:
{
  accuracy: 
    0.9694900435641176,
  f1: 
    0.6890908446386395,
}


Preconditions:
{
  accuracy: 
    0.996031746031746,
  f1: 
    0.5370166929866431,
  accuracy_0: 
    0.9980328313661647,
  f1_0: 
    0.5705295592145662,
  accuracy_1: 
    0.9964877372284779,
  f1_1: 
    0.6453102016601049,
  accuracy_2: 
    0.9993970364340735,
  f1_2: 
    0.7436091273253046,
  accuracy_3: 
    0.999449795746092,
  f1_3: 
    0.6614238134609711,
  accuracy_4: 
    0.9998869443313888,
  f1_4: 
    0.47220337783539973,
  accuracy_5: 
    0.989380304195119,
  f1_5: 
    0.47555458464995,
  accuracy_6: 
    0.9878578211911545,
  f1_6: 
    0.792595834390284,
  accuracy_7: 
    0.99852273926348,
  f1_7: 
    0.4243149872312167,
  accuracy_8: 
    0.9966987744765523,
  f1_8: 
    0.7552991

In [77]:
eval_model_dir = "finetunedBert_based_PIQA_cloze_1_1e-05_7_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits"#"finetunedBert_based_PIQA_cloze_1_1e-05_7_0.0-0.4-0.4-0.2-0.0_tiered_pipeline_lc_ablate_attributes_states-logits"

In [78]:
import json
import os

model_directories = [eval_model_dir]

partitions = ['dev', 'test']
expl_fname = 'results_cloze_explanations_%s.json'
endtask_fname = 'results_cloze_stories_%s.json'
endtask_fname_new = 'results_cloze_stories_final_%s.json'
for md in model_directories:
  for p in partitions:
    #explanations = json.load(open(os.path.join(DRIVE_PATH, 'saved_models', md, expl_fname % p), 'r'))
    #endtask_results = json.load(open(os.path.join(DRIVE_PATH, 'saved_models', md, endtask_fname % p), 'r'))

    consistent_preds = 0
    verifiable_preds = 0
    total = 0
    for expl in explanations:
      if expl['valid_explanation']:
        verifiable_preds += 1
      if expl['story_pred'] == expl['story_label']:
        if len(expl['conflict_pred']) == len(expl['conflict_label']) and expl['conflict_pred'][0] == expl['conflict_label'][0] and expl['conflict_pred'][1] == expl['conflict_label'][1]:
          expl['consistent'] = True
          consistent_preds += 1
        else:
          expl['consistent'] = False
      total += 1
    print(consistent_preds / total)
    print("total:",total)
    #endtask_results['consistency'] = float(consistent_preds) / total
    print('Found %s consistent preds in %s (versus %s verifiable)' % (str(consistent_preds), p, str(verifiable_preds)))
    json.dump(explanations, open(os.path.join(DRIVE_PATH, 'saved_models', md, (expl_fname % p).replace('explanations', 'explanations_consistency')), 'w'))
    json.dump(endtask_results, open(os.path.join(DRIVE_PATH, 'saved_models', md, endtask_fname_new % p), 'w'))

0.16809116809116809
total: 351
Found 59 consistent preds in dev (versus 19 verifiable)
0.16809116809116809
total: 351
Found 59 consistent preds in test (versus 19 verifiable)


In [74]:
len(explanations)

322

In [66]:
28/322

0.08695652173913043

In [79]:
19 / 351

0.05413105413105413