# MMR calculation on teacher completions 

In [1]:
import nltk 
import json 
import pandas as pd
import numpy as np
from rouge_score import rouge_scorer
import torch

from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertTokenizer, BertModel
from transformers import BertTokenizer, BertModel

import copy 

import time
from tqdm import tqdm

### Similarity Metrics

In [2]:
def calculate_rougeL(text1, text2):
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    scores = scorer.score(text1, text2)
    return scores

def calculate_rouge2(text1, text2):
    scorer = rouge_scorer.RougeScorer(['rouge2'], use_stemmer=True)
    scores = scorer.score(text1, text2)
    return scores

In [3]:
def calculate_bertscore(text1, text2, tokenizer, model):
    inputs = tokenizer([text1, text2], return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    
    embeddings = outputs.last_hidden_state
    # Calculate cosine similarity between the embeddings of text1 and text2
    similarity = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[1], dim=0).mean().item()
    return similarity

In [4]:
def read_json(file_path):
    with open(file_path, 'r') as file:
        data = json.load(file)
        return data  

### Format of teacher rationales 

{
    sample_index : int,
    question : "str",
    rationale_list : [
                        [ completion_index, rationale
                        ]
                     ]
}

In [5]:
# following function makes a list of list containing all the rationales with multiple reasonings

def correct_rationale_dict(ds_dict):
    correct_list = []
    ds_data = ds_dict['data']
    for key in ds_data.keys():
        if len(ds_data[key]) > 1 :
            l = []
            for i in ds_data[key]:
                # this is within the list corresponding to each index in data dictionary
                # i is a dictionary 
                if i['answer'] in i['completion']:
                    l.append(i)
            correct_list.append(l)
    return correct_list

def formatTeacherDataset(ds_list):
    for i in range(len(ds_list)):
        if len(ds_list[i]) == 0:
            continue
        d1 = dict()
        d1['sample_index'] = ds_list[i][0]['sample_index']
        d1['question'] = ds_list[i][0]['question']
        d1['rationale_list'] = list()
        for j in range(len(ds_list[i])):
            ci = ds_list[i][j]['completion_index']
            rationale = ds_list[i][j]['reasoning_completion']
            l = list()
            l.append(ci)
            l.append(rationale)
            d1['rationale_list'].append(l)
        ds_list[i] = d1

In [6]:
# make this code usable for all metrics (rouge2, rougeL, bert)

def mmr(question, rationale_list, lambda_const, metrics):
    ranking_set = []
    
    if( metrics == 'bert'):
        bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        bert_model = BertModel.from_pretrained('bert-base-uncased') 
    
    while len(rationale_list) > 0:
        score = 0
        list_to_add = None
        score_to_add = 0
        
        for i in rationale_list:
            if( metrics == 'bert'):
                first_part = calculate_bertscore(question, i[1], bert_tokenizer, bert_model)
            elif (metrics == 'rougeL'):
                first_part = calculate_rougeL(question, i[1])
                first_part = first_part['rougeL'].fmeasure
            elif (metrics == 'rouge2'):
                first_part = calculate_rouge2(question, i[1])
                first_part = first_part['rouge2'].fmeasure
            else:
                print("Incorrect metric selection")
                break
            second_part = 0
            
            for j in ranking_set:
                if( metrics == 'bert'):
                    sim = calculate_bertscore(j[1], i[1], bert_tokenizer, bert_model)
                elif (metrics == 'rougeL'):
                    sim = calculate_rougeL(j[1], i[1])
                    sim = sim['rougeL'].fmeasure
                elif (metrics == 'rouge2'):
                    sim = calculate_rouge2(j[1], i[1])
                    sim = sim['rouge2'].fmeasure
                else:
                    print("Incorrect metric selection")
                    break
                    
                if sim > second_part:
                    second_part = sim
                    
            equation_score = lambda_const*(first_part)-(1-lambda_const) * second_part
            if equation_score > score:
                score = equation_score
                list_to_add = i
                score_to_add = first_part
                
            
        if list_to_add is not None:
            # print(f"Removing rationale: {rationale_to_add}")
            rationale_list.remove(list_to_add)
            list_to_add.append(score_to_add)
            ranking_set.append(list_to_add)
        else:
            # print("No rationale to add found.")
            break
        
    return ranking_set

        

In [7]:
def MMRcaculation(data_list, lambda_const, metrics):
    mmr_list = []
    for i in range(len(data_list)):
        if len(data_list[i]) == 0:
            continue
        
        rat = copy.deepcopy(data_list[i]['rationale_list'])       
        s = dict()
        s['sample_index'] = data_list[i]['sample_index']
        s['question'] = data_list[i]['question']
        s['ranking'] = mmr(data_list[i]['question'], rat, lambda_const, metrics)
        mmr_list.append(s)
    return mmr_list

In [8]:
def makeJSON(data_list, file_name):
    with open(file_name, 'w') as json_file:
        json.dump(data_list, json_file, indent=4)
    print(f"Data has been written to {file_name}")

In [9]:
addsub70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_addsub.json'
coin70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_coin_flip.json'
du70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_date_understanding.json'
llconc70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_last_letter_concatenation.json'
mulAr70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_multiarith.json'
seq70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_single_eq.json'
strat70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_strategy_qa.json'
svamp70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_svamp.json'
tso70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_tracking_shuffled_objects.json'

# Reading JSON data. The read_json() function returns a dictionary
addsub70 = read_json(addsub70_path)
coin70 = read_json(coin70_path)
du70 = read_json(du70_path)
llconc70 = read_json(llconc70_path)
mulAr70 = read_json(mulAr70_path)
seq70 = read_json(seq70_path)
strat70 = read_json(strat70_path)
svamp70 = read_json(svamp70_path)
tso70 = read_json(tso70_path)

# making rationale dictionaries with CORRECT rationales

addsub70_Clist = correct_rationale_dict(addsub70)
coin70_Clist = correct_rationale_dict(coin70)
du70_Clist = correct_rationale_dict(du70)
llconc70_Clist = correct_rationale_dict(llconc70)
mulAr70_Clist = correct_rationale_dict(mulAr70)
seq70_Clist = correct_rationale_dict(seq70)
strat70_Clist = correct_rationale_dict(strat70)
#svamp70_Clist = correct_rationale_dict(svamp70)
tso70_Clist = correct_rationale_dict(tso70)


formatTeacherDataset(addsub70_Clist)
formatTeacherDataset(coin70_Clist)
formatTeacherDataset(du70_Clist)
formatTeacherDataset(llconc70_Clist)
formatTeacherDataset(mulAr70_Clist)
formatTeacherDataset(seq70_Clist)
formatTeacherDataset(strat70_Clist)
#formatTeacherDataset(svamp70_Clist)
formatTeacherDataset(tso70_Clist)

In [10]:
print(addsub70_Clist[0])

{'sample_index': 1, 'question': 'There were 28 bales of hay in the barn . Tim stacked bales in the barn today . There are now 54 bales of hay in the barn . How many bales did he store in the barn ?', 'rationale_list': [[0, ' \n\nThere were 28 bales of hay in the barn.\nTim stacked bales in the barn today.\n\nThis means that there must have been more bales of hay brought into the barn, because the number of bales increased from 28 to 54. So, Tim must have stored 26 bales of hay in the barn today.'], [1, ' \nWe know that there were 28 bales of hay in the barn. \nWe know that Tim stacked bales in the barn today. \nWe know that there are now 54 bales of hay in the barn. \nSo, how many bales did Tim store in the barn? \n\nWe can solve this problem by using basic algebra. \n\nFirst, we need to create a variable to represent the number of bales that Tim stored in the barn. We will use the letter "x" to represent this variable. \n\nNext, we need to create an equation that represents the inform

### Calculating MMR rankings 

In [11]:
start_time = time.time()
s_addsub = MMRcaculation(addsub70_Clist, 0.50, 'bert')
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

Elapsed time: 1176.74 seconds


In [12]:
makeJSON(s_addsub, "addsub_mmr_q_bert_L0.50.json")

Data has been written to addsub_mmr_q_bert_L0.50.json


In [13]:
start_time = time.time()
s_coin = MMRcaculation(coin70_Clist, 0.50, 'bert')
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

Elapsed time: 2425.02 seconds


In [14]:
makeJSON(s_coin, "coin_flip_mmr_q_bert_L0.50.json")

Data has been written to coin_flip_mmr_q_bert_L0.50.json


In [15]:
start_time = time.time()
s_du = MMRcaculation(du70_Clist, 0.50, 'bert')
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

Elapsed time: 2045.22 seconds


In [16]:
makeJSON(s_du, "date_understanding_mmr_q_bert_L0.50.json")

Data has been written to date_understanding_mmr_q_bert_L0.50.json


In [17]:
start_time = time.time()
s_llconc = MMRcaculation(llconc70_Clist, 0.50, 'bert')
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

Elapsed time: 1620.38 seconds


In [18]:
makeJSON(s_llconc, "Last_Letter_Concatenation_mmr_q_bert_L0.50.json")

Data has been written to Last_Letter_Concatenation_mmr_q_bert_L0.50.json


In [19]:
"""
start_time = time.time()
s_mulAr = MMRcaculation(mulAr70_Clist)
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

makeJSON(s_mulAr, "MultiArith_mmr_q_bert.json")
"""

'\nstart_time = time.time()\ns_mulAr = MMRcaculation(mulAr70_Clist)\nend_time = time.time()\nelapsed_time = end_time - start_time\nprint(f"Elapsed time: {elapsed_time:.2f} seconds")\n\nmakeJSON(s_mulAr, "MultiArith_mmr_q_bert.json")\n'

In [20]:
start_time = time.time()
s_seq = MMRcaculation(seq70_Clist, 0.50, 'bert')
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

Elapsed time: 972.21 seconds


In [21]:
makeJSON(s_seq, "Single_Equation_mmr_q_bert_L0.50.json")

Data has been written to Single_Equation_mmr_q_bert_L0.50.json


In [22]:
start_time = time.time()
s_strat = MMRcaculation(strat70_Clist, 0.50, 'bert')
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

makeJSON(s_strat, "StrategyQA_mmr_q_bert_L0.50.json")


Elapsed time: 119.35 seconds
Data has been written to StrategyQA_mmr_q_bert_L0.50.json


In [23]:
start_time = time.time()
s_tso = MMRcaculation(tso70_Clist, 0.50, 'bert')
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

Elapsed time: 2607.81 seconds


In [24]:
makeJSON(s_tso, "Tracking_Shuffled_Objects_mmr_q_bert_L0.50.json")

Data has been written to Tracking_Shuffled_Objects_mmr_q_bert_L0.50.json


In [25]:
#s_svamp = MMRcaculation(svamp70_Clist)

In [26]:
"""
def pairWiseScore(df, rationale_list):
    for i in range(len(rationale_list)):
        for j in range(len(rationale_list)):
            if i == j:
                df.iloc[i, j] = 1.0
            else:
                score = calculate_rouge(rationale_list[i], rationale_list[j])
                df.iloc[i, j] = score['rouge2'].fmeasure        # stores only ROUGE2 fmeasure (for now at least)
                
                
rationales = ['ci_0', 'ci_1', 'ci_2', 'ci_3', 'ci_4', 'ci_5', 'ci_6', 'ci_7']
df = pd.DataFrame(index=rationales, columns=rationales)
pairWiseScore(df, rationale_list)


def quesScore(question, rationale_list):
    d = dict()
    for i in range(len(rationale_list)):
        score = calculate_rouge(question, rationale_list[i])
        ci = "ci_" + str(i)
        d[ci] = score['rouge2'].fmeasure
    return d
    

ques_sim_list = quesScore(question, rationale_list)
print(ques_sim_list)


# S contains the rationale with the maximum similarity score with the question 
# R contains the rationales other than the ones in S
maxi = max(ques_sim_list.values())
S = [(key, value) for key, value in ques_sim_list.items() if value == maxi]
R = [(key, value) for key, value in ques_sim_list.items() if (key, value) not in S]
"""


'\ndef pairWiseScore(df, rationale_list):\n    for i in range(len(rationale_list)):\n        for j in range(len(rationale_list)):\n            if i == j:\n                df.iloc[i, j] = 1.0\n            else:\n                score = calculate_rouge(rationale_list[i], rationale_list[j])\n                df.iloc[i, j] = score[\'rouge2\'].fmeasure        # stores only ROUGE2 fmeasure (for now at least)\n                \n                \nrationales = [\'ci_0\', \'ci_1\', \'ci_2\', \'ci_3\', \'ci_4\', \'ci_5\', \'ci_6\', \'ci_7\']\ndf = pd.DataFrame(index=rationales, columns=rationales)\npairWiseScore(df, rationale_list)\n\n\ndef quesScore(question, rationale_list):\n    d = dict()\n    for i in range(len(rationale_list)):\n        score = calculate_rouge(question, rationale_list[i])\n        ci = "ci_" + str(i)\n        d[ci] = score[\'rouge2\'].fmeasure\n    return d\n    \n\nques_sim_list = quesScore(question, rationale_list)\nprint(ques_sim_list)\n\n\n# S contains the rationale with