In [30]:
import json
from statistics import mode

In [85]:
from sklearn.metrics import accuracy_score, f1_score

In [126]:
# Load each model's explanation results from its *_test_*.json results file.
def _load_results(path):
    """Read one results JSON file and return its parsed contents."""
    with open(path) as results_file:
        return json.load(results_file)

xlnet_explanations = _load_results('results_cloze_explanations_softmax_test_xlnet.json')
roberta_explanations = _load_results('results_cloze_explanations_softmax_test_roberta.json')
bert_explanations = _load_results('results_cloze_explanations_softmax_test_bert.json')

In [139]:
# Keys to keep from each per-model result entry; all other keys are discarded when filtering.
desired_attributes = [
    'story_pred',
    'story_label',
    'conflict_label',
    'conflict_pred',
    'valid_explanation',
    'story_prob',
]

In [140]:
def _keep_desired(entries):
    """Return a new list where each entry dict keeps only the keys listed in `desired_attributes`.

    Key order within each dict follows the original entry.
    """
    filtered = []
    for entry in entries:
        filtered.append({k: v for k, v in entry.items() if k in desired_attributes})
    return filtered

xlnet_filtered_explanations = _keep_desired(xlnet_explanations)
roberta_filtered_explanations = _keep_desired(roberta_explanations)
bert_filtered_explanations = _keep_desired(bert_explanations)


In [129]:
# Seed one record per example with the label fields taken from RoBERTa's filtered results;
# the ensemble cells fill in story_pred, valid_explanation and conflict_pred.
ensemble_model_predictions = [
    {key: entry[key] for key in entry if key in ('story_label', 'conflict_label')}
    for entry in roberta_filtered_explanations
]

In [141]:
# Inspect the first filtered RoBERTa entry to check which keys survived filtering.
roberta_filtered_explanations[0]

{'story_label': 1,
 'story_prob': [0.959881067276001, 0.040118925273418427],
 'story_pred': 0,
 'conflict_label': [3, 4],
 'conflict_pred': [1, 2, 3, 4],
 'valid_explanation': False}

In [134]:
# NOTE: this re-binds `bert_explanations` to a different results file (softmax_1).
# `bert_filtered_explanations` is not recomputed here, so the ensemble cells keep using
# whichever BERT file was loaded when the filtering cell last ran.
bert_softmax_1_path = 'results_cloze_explanations_softmax_1_test_bert.json'
with open(bert_softmax_1_path) as bert_file:
    bert_explanations = json.load(bert_file)

In [135]:
# Inspect the first raw (unfiltered) entry currently bound to `bert_explanations`.
bert_explanations[0]

{'story_label': 1,
 'story_prob': [0.4892828166484833, 0.5107171535491943],
 'story_pred': 1,
 'conflict_label': [3, 4],
 'conflict_pred': [3, 4],
 'preconditions_label': {'radio': {'0': {'h_location': 0,
    'conscious': 0,
    'wearing': 0,
    'h_wet': 0,
    'hygiene': 0,
    'location': 0,
    'exist': 0,
    'clean': 0,
    'power': 0,
    'functional': 0,
    'pieces': 0,
    'wet': 0,
    'open': 0,
    'temperature': 0,
    'solid': 0,
    'contain': 0,
    'running': 0,
    'moveable': 0,
    'mixed': 0,
    'edible': 0},
   '1': {'h_location': 0,
    'conscious': 0,
    'wearing': 0,
    'h_wet': 0,
    'hygiene': 0,
    'location': 0,
    'exist': 0,
    'clean': 0,
    'power': 0,
    'functional': 0,
    'pieces': 0,
    'wet': 0,
    'open': 0,
    'temperature': 0,
    'solid': 0,
    'contain': 0,
    'running': 0,
    'moveable': 0,
    'mixed': 0,
    'edible': 0},
   '2': {'h_location': 0,
    'conscious': 0,
    'wearing': 0,
    'h_wet': 0,
    'hygiene': 0,
    '

In [137]:
# Re-read the RoBERTa results (same file as the initial load) to inspect the raw, unfiltered record.
roberta_results_path = 'results_cloze_explanations_softmax_test_roberta.json'
with open(roberta_results_path) as roberta_file:
    roberta_explanations = json.load(roberta_file)

In [138]:
# Inspect the first raw (unfiltered) RoBERTa entry, including its precondition annotations.
roberta_explanations[0]

{'story_label': 1,
 'story_prob': [0.959881067276001, 0.040118925273418427],
 'story_pred': 0,
 'conflict_label': [3, 4],
 'conflict_pred': [1, 2, 3, 4],
 'preconditions_label': {'anger': {'0': {'h_location': 0,
    'conscious': 0,
    'wearing': 0,
    'h_wet': 0,
    'hygiene': 0,
    'location': 0,
    'exist': 0,
    'clean': 0,
    'power': 0,
    'functional': 0,
    'pieces': 0,
    'wet': 0,
    'open': 0,
    'temperature': 0,
    'solid': 0,
    'contain': 0,
    'running': 0,
    'moveable': 0,
    'mixed': 0,
    'edible': 0},
   '1': {'h_location': 0,
    'conscious': 0,
    'wearing': 0,
    'h_wet': 0,
    'hygiene': 0,
    'location': 0,
    'exist': 0,
    'clean': 0,
    'power': 0,
    'functional': 0,
    'pieces': 0,
    'wet': 0,
    'open': 0,
    'temperature': 0,
    'solid': 0,
    'contain': 0,
    'running': 0,
    'moveable': 0,
    'mixed': 0,
    'edible': 0},
   '2': {'h_location': 0,
    'conscious': 0,
    'wearing': 0,
    'h_wet': 0,
    'hygiene': 0

In [94]:
# Number of examples the ensemble cells iterate over. Derive it from the ensemble records
# rather than from `bert_explanations`, which another cell re-binds to a different file.
num_examples = len(ensemble_model_predictions)

# The ensemble cells index every model's results by position, so all lists must align.
assert all(
    len(model_results) == num_examples
    for model_results in (
        xlnet_filtered_explanations,
        roberta_filtered_explanations,
        bert_filtered_explanations,
    )
), "per-model result lists have different lengths"

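Maximum Voting Ensemble
- Run either the following cell or the Maximum Confidence Ensemble cell (not both) to generate the ensemble's predictions. Both write to the same `ensemble_model_predictions` records, so whichever runs last determines the results.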

In [96]:
# Maximum-voting ensemble over the three models, per example:
# - story_pred / valid_explanation: statistics.mode of the three votes (xlnet, roberta, bert).
# - conflict_pred: the conflict predicted by at least two models; if all three disagree,
#   fall back to RoBERTa's prediction.
# conflict_pred is always stored as a list so it has the same type as conflict_label.
voting_models = (
    xlnet_filtered_explanations,
    roberta_filtered_explanations,
    bert_filtered_explanations,
)

for i in range(num_examples):
    entries = [model_results[i] for model_results in voting_models]

    ensemble_model_predictions[i]['story_pred'] = mode([e['story_pred'] for e in entries])
    ensemble_model_predictions[i]['valid_explanation'] = mode([e['valid_explanation'] for e in entries])

    # Tuples are hashable, so each model's conflict prediction can be counted as one vote.
    conflict_votes = [tuple(e['conflict_pred']) for e in entries]
    top_conflict = max(set(conflict_votes), key=conflict_votes.count)
    if conflict_votes.count(top_conflict) >= 2:
        ensemble_model_predictions[i]['conflict_pred'] = list(top_conflict)
    else:
        ensemble_model_predictions[i]['conflict_pred'] = list(roberta_filtered_explanations[i]['conflict_pred'])
    

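Maximum Confidence Ensemble
- For each example, use the predictions of the model whose highest story-class probability is largest.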

In [None]:
# Maximum-confidence ensemble: for each example, copy story_pred, valid_explanation and
# conflict_pred from the model whose highest story-class probability is largest.
# Ties go to the earlier model in `confidence_models` (xlnet, then roberta, then bert),
# so every example receives a prediction even when probabilities are equal.
confidence_models = (
    xlnet_filtered_explanations,
    roberta_filtered_explanations,
    bert_filtered_explanations,
)

for i in range(num_examples):
    # max() returns the first maximal item, which implements the tie-break above.
    best_entry = max(
        (model_results[i] for model_results in confidence_models),
        key=lambda entry: max(entry['story_prob']),
    )
    ensemble_model_predictions[i]['story_pred'] = best_entry['story_pred']
    ensemble_model_predictions[i]['valid_explanation'] = best_entry['valid_explanation']
    # Stored as a list so its type matches conflict_label.
    ensemble_model_predictions[i]['conflict_pred'] = list(best_entry['conflict_pred'])

In [75]:
def _column(field):
    """Return the value of `field` from every ensemble record, in record order."""
    return [record[field] for record in ensemble_model_predictions]

# Parallel lists of ensemble predictions and labels for scoring.
ensemble_model_predictions_story = _column('story_pred')
ensemble_model_actual_story = _column('story_label')
ensemble_model_predictions_conflict = _column('conflict_pred')
ensemble_model_actual_conflict = _column('conflict_label')

In [69]:
# Story-level accuracy of the ensemble. sklearn's signature is (y_true, y_pred):
# labels first, predictions second.
accuracy_score(ensemble_model_actual_story, ensemble_model_predictions_story)

0.7720797720797721

In [87]:
# Binary F1 (sklearn default pos_label=1) of the ensemble's story predictions.
# sklearn's signature is (y_true, y_pred): labels first, predictions second.
f1_score(ensemble_model_actual_story, ensemble_model_predictions_story)

0.7687861271676302

In [102]:
# Counters for the explanation metrics; both start at zero.
verifiable_preds, consistent_preds = 0, 0

In [105]:
# Count verifiable and consistent ensemble predictions.
# The counters are reset here so re-running this cell does not double-count.
verifiable_preds = 0
consistent_preds = 0
for expl in ensemble_model_predictions:
    if not expl['valid_explanation']:
        continue
    verifiable_preds += 1
    # NOTE(review): verifiable_preds does not require a correct story prediction, while
    # consistent_preds does - confirm these match the intended metric definitions.
    # Conflicts are compared as whole lists: conflict_pred may be a tuple or a list, and
    # whole-sequence equality cannot raise IndexError on predictions shorter than two.
    if (expl['story_pred'] == expl['story_label']
            and list(expl['conflict_pred']) == list(expl['conflict_label'])):
        consistent_preds += 1

In [107]:
# Fraction of all examples counted as verifiable.
verifiable_preds/num_examples

0.05982905982905983

In [116]:
# Fraction of all examples counted as consistent.
consistent_preds/num_examples

0.05982905982905983

In [117]:
# Number of ensemble records with a valid explanation. Uses the same truthiness test as the
# metric loop, so it matches the records that loop treats as verifiable.
sum(1 for record in ensemble_model_predictions if record['valid_explanation'])

21

In [118]:
# Inspect the first filtered XLNet entry.
xlnet_filtered_explanations[0]

{'story_label': 1,
 'story_pred': 0,
 'conflict_label': [3, 4],
 'conflict_pred': [2, 3],
 'valid_explanation': False}