# Support / Contrastivity

In [None]:
# Load temp_analysis.xlsx
# Then do the simulatability analysis
# Load the xlsx sheet
import openpyxl
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import re
import ast

file_path = 'temp_analysis.xlsx'

# Load all sheets into a dictionary keyed by sheet name
sheets_dict = pd.read_excel(file_path, sheet_name=None)


def _parse_hypotheses(cell):
    """Turn one 'Alternative Hypotheses' cell into a Python list.

    Strings are parsed with ``ast.literal_eval``; values that are already a
    list/tuple are returned unchanged; missing cells (None/NaN) become an
    empty list. Any other value is handed to ``ast.literal_eval`` so it still
    fails loudly, as it did before.
    """
    if isinstance(cell, (list, tuple)):
        return cell
    if isinstance(cell, str):
        return ast.literal_eval(cell)
    if pd.isna(cell):
        # An empty Excel cell is read back as NaN, which literal_eval rejects
        # with ValueError; treat it as "no alternative hypotheses".
        return []
    return ast.literal_eval(cell)


# Normalise column types on every sheet
for sheet_name, df in sheets_dict.items():
    # Ensure specific columns are strings
    for col in ['question', 'predicted_answer', 'correct_answer', 'generated_rationale']:
        if col in df.columns:
            df[col] = df[col].astype(str)

    # Make Alternative Hypotheses a list (the sheet stores it as a string)
    if 'Alternative Hypotheses' in df.columns:
        df['Alternative Hypotheses'] = df['Alternative Hypotheses'].apply(_parse_hypotheses)

    # Print sample rows to verify
    print(f"Sheet Name: {sheet_name}")
    print(df.head())
    print()

# Extract a small sheet for testing. Copy it so that in-place column writes on
# the sample never touch (or warn about) the full sheet it was sliced from.
mini_sheet = sheets_dict['LLaVA-1.5 with image'].head().copy()
demo_sheets_dict = {
    "mini_sheet": mini_sheet,
}
mini_sheet


In [None]:
%pip install bitsandbytes

In [None]:
import os
import re
import ast
import nltk
import torch
import json
import hashlib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Checkpoints used for the entailment check.
NLI_TOKENIZER_CHECKPOINT = "google/flan-t5-xxl"
NLI_MODEL_CHECKPOINT = "soumyasanyal/nli-entailment-verifier-xxl"

# Each object is only created when its name is not yet defined at module
# level, so re-running this cell does not load the tokenizer/model again.
if 'nli_tokenizer' not in globals():
    nli_tokenizer = AutoTokenizer.from_pretrained(NLI_TOKENIZER_CHECKPOINT)
if 'nli_model' not in globals():
    nli_model = AutoModelForSeq2SeqLM.from_pretrained(
        NLI_MODEL_CHECKPOINT,
        load_in_8bit=True,
        device_map="auto",
    )

def get_longest_rationale(rationale_list):
    """Return the longest rationale in a list of rationales.

    Args:
        rationale_list: Either a list of strings or the string form of such a
            list (e.g. ``"['short', 'a longer one']"``).

    Returns:
        The longest element (by ``len``), or ``''`` when the parsed value is
        not a list or the list is empty.

    Raises:
        ValueError, SyntaxError: If a string argument is not a valid Python
            literal.
    """
    if isinstance(rationale_list, str):
        # literal_eval instead of eval: only literals are accepted, so text
        # coming from a data file cannot execute arbitrary code. This matches
        # how the 'Alternative Hypotheses' column is parsed elsewhere in this file.
        rationales = ast.literal_eval(rationale_list)
    else:
        rationales = rationale_list

    # max() raises on an empty sequence, so guard that case explicitly.
    if not isinstance(rationales, list) or not rationales:
        return ''
    return max(rationales, key=len)

def calc_support_prob(premise, hypothesis, nli_model=nli_model, nli_tokenizer=nli_tokenizer):
    """Return the probability that ``premise`` entails ``hypothesis``.

    The pair is formatted into a yes/no question, the model is run for a
    single decoder step, and a softmax is taken over just the logits of the
    first token of 'Yes' and the first token of 'No'.

    Args:
        premise: Text used as the premise.
        hypothesis: Text whose support by the premise is being scored.
        nli_model: Seq2seq model that is called on the encoded prompt.
        nli_tokenizer: Tokenizer used for the prompt and the 'Yes'/'No' tokens.

    Returns:
        The 'Yes' share of the two-way softmax, as a Python float.
    """
    prompt = f"Premise: {premise}\nHypothesis: {hypothesis}\nGiven the premise, is the hypothesis correct?\nAnswer:"
    encoded_prompt = nli_tokenizer(prompt, return_tensors='pt').input_ids

    # Only the first token id of each answer word is compared.
    yes_token = nli_tokenizer('Yes').input_ids[0]
    no_token = nli_tokenizer('No').input_ids[0]

    # One all-zero decoder token per prompt, so the model scores only the
    # first generated position (presumably 0 is the decoder start id — confirm).
    decoder_start = torch.zeros((encoded_prompt.size(0), 1), dtype=torch.long)

    with torch.no_grad():
        first_step_logits = nli_model(encoded_prompt, decoder_input_ids=decoder_start).logits[:, 0, :]
        yes_no_logits = torch.stack(
            [first_step_logits[:, yes_token], first_step_logits[:, no_token]],
            dim=1,
        )
        # Cast to float before applying softmax
        probabilities = torch.nn.functional.softmax(yes_no_logits.float(), dim=1)
        entail_score = probabilities[:, 0].item()

    return entail_score

def generate_mask(generated_rationale, predicted_answer):
    """Replace whole-word occurrences of the predicted answer with ``<mask>``.

    Matching is case-insensitive. An occurrence counts as a whole word when it
    is not directly preceded or followed by a word character.

    Args:
        generated_rationale: Text to mask.
        predicted_answer: Answer to hide; converted with ``str()`` first.

    Returns:
        The rationale with every whole-word occurrence replaced, or the
        rationale unchanged when the answer is empty or only whitespace.
    """
    predicted_answer = str(predicted_answer)
    if not predicted_answer.strip():
        # An empty pattern would match at every word boundary and scatter
        # "<mask>" through the whole rationale.
        return generated_rationale

    # Lookarounds instead of \b: \b needs a word character on one side, so it
    # never matches answers that start or end with punctuation (e.g. "$5").
    # For answers that start and end with word characters the two are equivalent.
    pattern = re.compile(
        r'(?<!\w)' + re.escape(predicted_answer) + r'(?!\w)',
        re.IGNORECASE,
    )
    return pattern.sub("<mask>", generated_rationale)

def evaluate_support(data, nli_model, nli_tokenizer, use_mask=True, use_pieces=False, hypothesis_col='hypothesis', threshold=0.5):
    """Score, row by row, whether the rationale supports the hypothesis.

    Args:
        data: DataFrame with one example per row.
        nli_model: Model forwarded to ``calc_support_prob``.
        nli_tokenizer: Tokenizer forwarded to ``calc_support_prob``.
        use_mask: Use the masked premise column instead of the unmasked one.
        use_pieces: Use the ``concat_rationale_pieces*`` columns instead of
            the full-rationale columns.
        hypothesis_col: Name of the column holding the hypothesis.
        threshold: A row counts as supported when its entailment probability
            is strictly greater than this value.

    Returns:
        A list with one dict per row: ``{'entail_prob': float, 'support': bool}``.
        Premise, hypothesis and probability are printed for every row that is
        not supported.
    """
    # The premise column does not depend on the row, so pick it once.
    if use_pieces:
        premise_col = 'concat_rationale_pieces_mask' if use_mask else 'concat_rationale_pieces'
    elif use_mask:
        premise_col = 'gen_rationale_mask'
        # The masking step in this file stores its result under
        # 'rationale_mask'; fall back to it when the older name is absent.
        if premise_col not in data.columns and 'rationale_mask' in data.columns:
            premise_col = 'rationale_mask'
    else:
        premise_col = 'generated_rationale'

    support_scores = []
    for _, row in data.iterrows():
        premise = row[premise_col]
        hypothesis = row[hypothesis_col]
        entail_prob = calc_support_prob(premise, hypothesis, nli_model=nli_model, nli_tokenizer=nli_tokenizer)
        support = entail_prob > threshold
        # Log exactly the rows reported as unsupported (including a
        # probability that sits right on the threshold).
        if not support:
            print(f"Premise: {premise}")
            print(f"Hypothesis: {hypothesis}")
            print(f"Probability: {entail_prob}")
        support_scores.append({
            'entail_prob': entail_prob,
            'support': support
        })
    return support_scores

def compute_file_hash(file_path, chunk_size=1 << 20):
    """Return the MD5 hex digest of a file's contents.

    The file is hashed in fixed-size chunks so that large files are never
    held in memory in full.

    Args:
        file_path: Path of the file to hash.
        chunk_size: Number of bytes read per step; must be at least 1.

    Returns:
        The hexadecimal MD5 digest as a string.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
        OSError: If the file cannot be opened or read.
    """
    if chunk_size < 1:
        # read(0) returns b'' immediately and would yield the empty-file hash.
        raise ValueError("chunk_size must be a positive integer")

    hasher = hashlib.md5()
    with open(file_path, 'rb') as f:
        # iter() with a sentinel stops at the first empty read (end of file).
        for chunk in iter(lambda: f.read(chunk_size), b''):
            hasher.update(chunk)
    return hasher.hexdigest()

In [None]:
# Generate the masked rationale for every row of every sheet.
# A list is built instead of df.apply(..., axis=1): on a sheet with no rows
# apply returns a DataFrame rather than a Series, which cannot be assigned to
# a single column.
for sheet_name, df in sheets_dict.items():
    df['rationale_mask'] = [
        generate_mask(rationale, answer)
        for rationale, answer in zip(df['generated_rationale'], df['predicted_answer'])
    ]

In [None]:
def evaluate_strict_support(data, nli_model, nli_tokenizer, use_mask=True, use_pieces=False, hypothesis_col='hypothesis', alt_hypotheses_col='alternative_hypotheses', threshold=0.5):
    """Score support and strict support for every row, writing results into ``data``.

    A row is *supported* when the premise entails its hypothesis with
    probability strictly above ``threshold``. It is *strictly supported* when,
    in addition, every alternative hypothesis scores strictly below
    ``threshold``.

    Args:
        data: DataFrame with one example per row; modified in place.
        nli_model: Model forwarded to ``calc_support_prob``.
        nli_tokenizer: Tokenizer forwarded to ``calc_support_prob``.
        use_mask: Use the masked premise column instead of the unmasked one.
        use_pieces: Use the ``concat_rationale_pieces*`` columns instead of
            the full-rationale columns.
        hypothesis_col: Column holding the hypothesis.
        alt_hypotheses_col: Column holding an iterable of alternative hypotheses.
        threshold: Decision threshold for both checks.

    Returns:
        None. The columns ``entail_prob``, ``alt_ent_prob`` (string form of
        the list of alternative probabilities), ``support`` and ``strict_sim``
        are set on ``data`` row by row, and the inputs and scores of every row
        are printed.
    """
    # Honour use_mask / use_pieces, with the same column naming that
    # evaluate_support uses for the rationale-pieces variants.
    if use_pieces:
        premise_col = 'concat_rationale_pieces_mask' if use_mask else 'concat_rationale_pieces'
    else:
        premise_col = 'rationale_mask' if use_mask else 'generated_rationale'
    # These flags used to be ignored and 'rationale_mask' was always read;
    # keep that as a fallback so existing calls on frames without the
    # requested column behave as before.
    if premise_col not in data.columns:
        premise_col = 'rationale_mask'

    for idx, row in data.iterrows():
        premise = row[premise_col]
        hypothesis = row[hypothesis_col]
        alt_hypotheses = row[alt_hypotheses_col]

        entail_prob = calc_support_prob(premise, hypothesis, nli_model=nli_model, nli_tokenizer=nli_tokenizer)
        alt_entail_probs = [
            calc_support_prob(premise, alt_hypothesis, nli_model=nli_model, nli_tokenizer=nli_tokenizer)
            for alt_hypothesis in alt_hypotheses
        ]

        support = entail_prob > threshold
        strict_support = support and all(alt_entail_prob < threshold for alt_entail_prob in alt_entail_probs)

        print(f"Premise: {premise}")
        print(f"Hypothesis: {hypothesis}")
        print(f"Probability: {entail_prob}")
        print(f"Alternative hypotheses: {alt_hypotheses}")
        print(f"Probability: {alt_entail_probs}")
        print("--------------------------------------------")

        # Written per row, so rows scored before an interruption keep their results.
        data.loc[idx, 'entail_prob'] = entail_prob
        data.loc[idx, 'alt_ent_prob'] = str(alt_entail_probs)
        data.loc[idx, 'support'] = support
        data.loc[idx, 'strict_sim'] = strict_support
    return
    
# Settings shared by every sheet: masked full rationales, scored against the
# 'Hypothesis' / 'Alternative Hypotheses' columns.
strict_support_options = {
    'use_mask': True,
    'use_pieces': False,
    'hypothesis_col': 'Hypothesis',
    'alt_hypotheses_col': 'Alternative Hypotheses',
    'threshold': 0.5,
}

# Each sheet is updated in place by evaluate_strict_support.
for sheet_name, df in sheets_dict.items():
    evaluate_strict_support(df, nli_model, nli_tokenizer, **strict_support_options)

In [None]:
# Checkpoint: write every updated sheet to a separate workbook so the results
# survive if a later step fails.
temp_file_path = 'temp_analysis_forked_v2.xlsx'
with pd.ExcelWriter(temp_file_path) as writer:
    for sheet_name, df in sheets_dict.items():
        df.to_excel(writer, sheet_name=sheet_name, index=False)

# Read the workbook back and show the first rows of each sheet to verify it
sheets_dict_temp = pd.read_excel(temp_file_path, sheet_name=None)
for sheet_name, df in sheets_dict_temp.items():
    print(f"Sheet Name: {sheet_name}", df.head(), "", sep="\n")