# MMR calculation on student completions 

For Flan T5 base model

In [1]:
import nltk 
import json 
import pandas as pd
import numpy as np
from rouge_score import rouge_scorer
import torch

from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertTokenizer, BertModel
from transformers import BertTokenizer, BertModel

import time
from tqdm import tqdm

import copy

In [2]:
def calculate_rougeL(text1, text2):
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    scores = scorer.score(text1, text2)
    return scores

def calculate_rouge2(text1, text2):
    scorer = rouge_scorer.RougeScorer(['rouge2'], use_stemmer=True)
    scores = scorer.score(text1, text2)
    return scores

In [3]:
def calculate_bertscore(text1, text2, tokenizer, model):
    inputs = tokenizer([text1, text2], return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    
    embeddings = outputs.last_hidden_state
    # Calculate cosine similarity between the embeddings of text1 and text2
    similarity = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[1], dim=0).mean().item()
    return similarity

In [4]:
def read_json(file_path):
    with open(file_path, 'r') as file:
        data = json.load(file)
        return data  

In [5]:
# following function makes a list of list containing all the rationales with multiple reasonings

def correct_rationale_dict(ds_dict):
    correct_list = []
    ds_data = ds_dict['data']
    for key in ds_data.keys():
        if len(ds_data[key]) > 1 :
            l = []
            for i in ds_data[key]:
                # this is within the list corresponding to each index in data dictionary
                # i is a dictionary 
                if i['answer'] in i['completion']:
                    l.append(i)
            correct_list.append(l)
    return correct_list


def formatTeacherDataset(teacher_list):
    for i in range(len(teacher_list)):
        if len(teacher_list[i]) == 0:
            continue
        d1 = dict()
        d1['sample_index'] = teacher_list[i][0]['sample_index']
        d1['question'] = teacher_list[i][0]['question']
        d1['rationale_dict'] = list()
        for j in range(len(teacher_list[i])):
            ci = teacher_list[i][j]['completion_index']
            rationale = teacher_list[i][j]['reasoning_completion']
            l = list()
            l.append(ci)
            l.append(rationale)
            d1['rationale_dict'].append(l)
        teacher_list[i] = d1

In [6]:
"""
The following modules are for preprocessing the student dataset in a way that it becomes useful 
"""

# make a function for removing all the student rationales that have been incorrectly generated 

def removeIncorrectStudentRat(student_dataset):
    i = 0
    while(i != len(student_dataset)):
        if( student_dataset[i]['correct'] == False ):
            d = student_dataset[i]
            student_dataset.remove(d)
        else:
            i += 1
             

# The followinf function compares all the student rationales and find their respective sample indices in the dataset 

def addSI(student_list, teacher_list):
    for i in range(len(teacher_list)):
        if len(teacher_list[i]) == 0:
            continue
        question = teacher_list[i]['question']
        si = teacher_list[i]['sample_index']     
        for j in range(len(student_list)):
            prompt = student_list[j]['prompt']
            if( question in prompt ):
                student_list[j]['sample_index'] = si
                break            
                
# removing all the rationales with no sample_indices in them from student dataset 

def helper(student_dataset):
    i = 0 
    while( i != len(student_dataset)):
        if ( 'sample_index' not in student_dataset[i] ):
            d = student_dataset[i]
            student_dataset.remove(d)
        else:
            i += 1
            

In [7]:
# the following function removes rationales from the teacher dataset 

def removeEmpty(teacher_list):
    i = 0
    while( i != len(teacher_list)):
        if (len(teacher_list[i]) != 0 ):
            i += 1
        else:
            d = teacher_list[i]
            teacher_list.remove(d)             
            

def removeRat(student_list, teacher_list):
    # Create a set of sample indices from student_list for fast lookup
    student_indices = {student['sample_index'] for student in student_list}
    
    # Use list comprehension to create a new list with only the desired elements
    teacher_list = [d for d in teacher_list if 'sample_index' in d and (len(d) != 0) and (d['sample_index'] in student_indices)]
    
    return teacher_list

In [8]:
def mmr(student_prompt, teacher_rationale_list, lambda_const, metrics):
    ranking_set = []
    
    if( metrics == 'bert'):
        bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        bert_model = BertModel.from_pretrained('bert-base-uncased') 
    
    while len(teacher_rationale_list) > 0:
        rationale_to_add = None
        score = 0
        rouge_score_to_add = 0
        
        for rationale in teacher_rationale_list:
            
            if( metrics == 'bert'):
                first_part = calculate_bertscore(student_prompt, rationale[1], bert_tokenizer, bert_model)
            elif (metrics == 'rougeL'):
                first_part = calculate_rougeL(student_prompt, rationale[1])
                first_part = first_part['rougeL'].fmeasure
            elif (metrics == 'rouge2'):
                first_part = calculate_rouge2(student_prompt, rationale[1])
                first_part = first_part['rouge2'].fmeasure
            else:
                print("Incorrect metric selection")
                break
                
            second_part = 0
            
            for di in ranking_set:
                
                if( metrics == 'bert'):
                    sim = calculate_bertscore(di[1], rationale[1], bert_tokenizer, bert_model)
                elif (metrics == 'rougeL'):
                    sim = calculate_rougeL(di[1], rationale[1])
                    sim = sim['rougeL'].fmeasure
                elif (metrics == 'rouge2'):
                    sim = calculate_rouge2(di[1], rationale[1])
                    sim = sim['rouge2'].fmeasure
                else:
                    print("Incorrect metric selection")
                    break
                
                if sim > second_part:
                    second_part = sim
            
            mmr_score = lambda_const*(first_part)-(1-lambda_const) * second_part
            if (mmr_score > score):
                score = mmr_score
                rationale_to_add = rationale
                rouge_score_to_add = first_part
        
        if rationale_to_add is not None:
            # print(f"Removing rationale: {rationale_to_add}")
            teacher_rationale_list.remove(rationale_to_add)
            rationale_to_add.append(rouge_score_to_add)
            ranking_set.append(rationale_to_add)
        else:
            # print("No rationale to add found.")
            break
                
        #teacher_rationale_list.remove(rationale_to_add)
        #rationale_to_add.append(rouge_score_to_add)
        #ranking_set.append(rationale_to_add)
    
    return ranking_set
        



# understand your mistakes from the following 
"""
def mmr1(student_prompt, teacher_rationale_dict, lambda_const):

    ranking_set = []
    rationales = copy.deepcopy(teacher_rationale_dict)
    while len(r) > 0:
    
        rationale_to_add = list()
        score = 0
        rouge_score_to_add = 0
        
        for rationale in teacher_rationale_dict:
        
            first_part = calculate_rouge(student_prompt, rationale[1])
            first_part = first_part['rouge2'].fmeasure
            second_part = 0
            
            for di in ranking_set:
            
                sim = calculate_rouge(di[1], rationale[1])
                sim = sim['rouge2'].fmeasure
                
                if sim > second_part:
                    second_part = sim
                    
            mmr_score = lambda_const*(first_part)-(1-lambda_const) * second_part
            
            if (mmr_score > score):
                score = mmr_score
                rationale_to_add = rationale
                rouge_score_to_add = first_part
                
        teacher_rationale_dict.remove(rationale_to_add)
        rationale_to_add.append(rouge_score_to_add)
        ranking_set.append(rationale_to_add)
        
    return ranking_set
"""

"\ndef mmr1(student_prompt, teacher_rationale_dict, lambda_const):\n\n    ranking_set = []\n    rationales = copy.deepcopy(teacher_rationale_dict)\n    while len(r) > 0:\n    \n        rationale_to_add = list()\n        score = 0\n        rouge_score_to_add = 0\n        \n        for rationale in teacher_rationale_dict:\n        \n            first_part = calculate_rouge(student_prompt, rationale[1])\n            first_part = first_part['rouge2'].fmeasure\n            second_part = 0\n            \n            for di in ranking_set:\n            \n                sim = calculate_rouge(di[1], rationale[1])\n                sim = sim['rouge2'].fmeasure\n                \n                if sim > second_part:\n                    second_part = sim\n                    \n            mmr_score = lambda_const*(first_part)-(1-lambda_const) * second_part\n            \n            if (mmr_score > score):\n                score = mmr_score\n                rationale_to_add = rationale\n        

In [9]:
def MMRcaculation(teacher_list, student_list, lambda_const, metrics):
    # this function passes the rationale's dictionary from teacher_dataset and prompt to calculate MMR
    mmr_list = []
    for i in range(len(teacher_list)):
        if len(teacher_list[i]) == 0:
            continue
            
        teacher_rationale_list = copy.deepcopy(teacher_list[i]['rationale_dict'])
        
        s = dict()
        prompt = teacher_list[i]['question'] + student_list[i]['model_inference']
        s['sample_index'] = teacher_list[i]['sample_index']
        s['prompt'] = prompt
        s['ranking'] = mmr(prompt, teacher_rationale_list, lambda_const, metrics)
        mmr_list.append(s)
    return mmr_list

In [10]:
def makeJSON(data_list, file_name):
    with open(file_name, 'w') as json_file:
        json.dump(data_list, json_file, indent=4)
    print(f"Data has been written to {file_name}")

In [11]:
addsub70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_addsub.json'
coin70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_coin_flip.json'
du70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_date_understanding.json'
llconc70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_last_letter_concatenation.json'
mulAr70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_multiarith.json'
seq70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_single_eq.json'
strat70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_strategy_qa.json'
svamp70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_svamp.json'
tso70_path = '/Users/shiprasingh/IIT KGP internship /reasoning-teacher/saved/teacher_completion_data/B_text-davinci-002__C_zs_cot_t70/D_tracking_shuffled_objects.json'

# Reading JSON data. The read_json() function returns a dictionary
addsub70 = read_json(addsub70_path)
coin70 = read_json(coin70_path)
du70 = read_json(du70_path)
llconc70 = read_json(llconc70_path)
mulAr70 = read_json(mulAr70_path)
seq70 = read_json(seq70_path)
strat70 = read_json(strat70_path)
svamp70 = read_json(svamp70_path)
tso70 = read_json(tso70_path)

# making rationale dictionaries with CORRECT rationales

addsub70_Clist = correct_rationale_dict(addsub70)
coin70_Clist = correct_rationale_dict(coin70)
du70_Clist = correct_rationale_dict(du70)
llconc70_Clist = correct_rationale_dict(llconc70)
mulAr70_Clist = correct_rationale_dict(mulAr70)
seq70_Clist = correct_rationale_dict(seq70)
strat70_Clist = correct_rationale_dict(strat70)
#svamp70_Clist = correct_rationale_dict(svamp70)
tso70_Clist = correct_rationale_dict(tso70)


formatTeacherDataset(addsub70_Clist)
formatTeacherDataset(coin70_Clist)
formatTeacherDataset(du70_Clist)
formatTeacherDataset(llconc70_Clist)
formatTeacherDataset(mulAr70_Clist)
formatTeacherDataset(seq70_Clist)
formatTeacherDataset(strat70_Clist)
#formatTeacherDataset(svamp70_Clist)
formatTeacherDataset(tso70_Clist)

In [12]:
# importing all the student generated rationale JSON

add_sub_out = "/Users/shiprasingh/IIT KGP internship /FLAN Inferences Dataset/addsub_out.json"
coin_flip_out = "/Users/shiprasingh/IIT KGP internship /FLAN Inferences Dataset/coin_flip_out.json"
date_understanding_out = "/Users/shiprasingh/IIT KGP internship /FLAN Inferences Dataset/date_understanding_out.json"
last_letter_concatenation_out = "/Users/shiprasingh/IIT KGP internship /FLAN Inferences Dataset/last_letter_concatenation_out.json"

addsub = read_json(add_sub_out)
coin = read_json(coin_flip_out)
du = read_json(date_understanding_out)
llconc = read_json(last_letter_concatenation_out)

### Reformating student completion dataset 

In [13]:
removeIncorrectStudentRat(addsub)
removeIncorrectStudentRat(coin)
removeIncorrectStudentRat(du)
removeIncorrectStudentRat(llconc)

addSI(addsub, addsub70_Clist)
addSI(coin, coin70_Clist)
addSI(du, du70_Clist)
addSI(llconc, llconc70_Clist)

helper(addsub)
helper(coin)
helper(du)
helper(llconc)

### Reformating teacher completion dataset 

In [14]:
removeEmpty(addsub70_Clist)
removeEmpty(coin70_Clist)
removeEmpty(du70_Clist)
removeEmpty(llconc70_Clist)

addsub70_Clist_new = removeRat(addsub, addsub70_Clist)
coin70_Clist_new = removeRat(coin, coin70_Clist)
du70_Clist_new = removeRat(du, du70_Clist)
llconc70_Clist_new = removeRat(llconc, llconc70_Clist)

In [15]:
print(addsub70_Clist_new[0])

{'sample_index': 2, 'question': 'Mary is baking a cake . The recipe wants 8 cups of flour . She already put in 2 cups . How many cups does she need to add ?', 'rationale_dict': [[0, ' \n\nThe recipe wants 8 cups of flour.\nShe already put in 2 cups.\n\nThat means she needs to add 8 - 2 = 6 more cups of flour.'], [1, ' \n\nMary is baking a cake. \nThe recipe wants 8 cups of flour. \nShe already put in 2 cups. \n\nThat means she needs to add 6 more cups of flour to the recipe.'], [2, ' \n\nMary is baking a cake. \nThe recipe wants 8 cups of flour. \nShe already put in 2 cups. \n\nSo, she needs to add 6 cups of flour.'], [3, " \n\nFirst, let's figure out how many cups of flour Mary has used so far. We know that she started with 2 cups of flour, and the recipe wants 8 cups of flour. So, we can subtract 2 cups from 8 cups to find out how many more cups Mary needs to add. \n\n8 cups - 2 cups = 6 cups \n\nSo, Mary needs to add 6 cups of flour to the cake."], [4, ' \nThe recipe wants 8 cups of

In [16]:
print(type(addsub70_Clist_new[0]['rationale_dict']))

<class 'list'>


In [17]:
start_time = time.time()
s_addsub = MMRcaculation(addsub70_Clist_new, addsub, 0.50, 'bert')
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

Elapsed time: 548.53 seconds


In [18]:
makeJSON(s_addsub, "addsub_mmr_flanT5base_bert_L0.50.json")

Data has been written to addsub_mmr_flanT5base_bert_L0.50.json


In [19]:
start_time = time.time()
s_coin = MMRcaculation(coin70_Clist_new, coin, 0.50, 'bert')
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

Elapsed time: 1892.91 seconds


In [20]:
makeJSON(s_coin, "coin_flip_mmr_flanT5base_bert_L0.50.json")

Data has been written to coin_flip_mmr_flanT5base_bert_L0.50.json


In [21]:
start_time = time.time()
s_du = MMRcaculation(du70_Clist_new, du, 0.50, 'bert')
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

Elapsed time: 2123.96 seconds


In [22]:
makeJSON(s_du, "date_understanding_mmr_flanT5base_bert_L0.50.json")

Data has been written to date_understanding_mmr_flanT5base_bert_L0.50.json


In [23]:
start_time = time.time()
s_llconc = MMRcaculation(llconc70_Clist_new, llconc, 0.50, 'bert')
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

Elapsed time: 0.00 seconds


In [24]:
makeJSON(s_llconc, "Last_Letter_Concatenation_mmr_flanT5base_bert_L0.50.json")

Data has been written to Last_Letter_Concatenation_mmr_flanT5base_bert_L0.50.json


In [25]:
"""
start_time = time.time()
s_mulAr = MMRcaculation(mulAr70_Clist)
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

makeJSON(s_mulAr, "MultiArith_mmr_q_rL.json")
"""

'\nstart_time = time.time()\ns_mulAr = MMRcaculation(mulAr70_Clist)\nend_time = time.time()\nelapsed_time = end_time - start_time\nprint(f"Elapsed time: {elapsed_time:.2f} seconds")\n\nmakeJSON(s_mulAr, "MultiArith_mmr_q_rL.json")\n'

In [26]:
"""
def pairWiseScore(df, rationale_list):
    for i in range(len(rationale_list)):
        for j in range(len(rationale_list)):
            if i == j:
                df.iloc[i, j] = 1.0
            else:
                score = calculate_rouge(rationale_list[i], rationale_list[j])
                df.iloc[i, j] = score['rouge2'].fmeasure        # stores only ROUGE2 fmeasure (for now at least)
                
                
rationales = ['ci_0', 'ci_1', 'ci_2', 'ci_3', 'ci_4', 'ci_5', 'ci_6', 'ci_7']
df = pd.DataFrame(index=rationales, columns=rationales)
pairWiseScore(df, rationale_list)


def quesScore(question, rationale_list):
    d = dict()
    for i in range(len(rationale_list)):
        score = calculate_rouge(question, rationale_list[i])
        ci = "ci_" + str(i)
        d[ci] = score['rouge2'].fmeasure
    return d
    

ques_sim_list = quesScore(question, rationale_list)
print(ques_sim_list)


# S contains the rationale with the maximum similarity score with the question 
# R contains the rationales other than the ones in S
maxi = max(ques_sim_list.values())
S = [(key, value) for key, value in ques_sim_list.items() if value == maxi]
R = [(key, value) for key, value in ques_sim_list.items() if (key, value) not in S]
"""


'\ndef pairWiseScore(df, rationale_list):\n    for i in range(len(rationale_list)):\n        for j in range(len(rationale_list)):\n            if i == j:\n                df.iloc[i, j] = 1.0\n            else:\n                score = calculate_rouge(rationale_list[i], rationale_list[j])\n                df.iloc[i, j] = score[\'rouge2\'].fmeasure        # stores only ROUGE2 fmeasure (for now at least)\n                \n                \nrationales = [\'ci_0\', \'ci_1\', \'ci_2\', \'ci_3\', \'ci_4\', \'ci_5\', \'ci_6\', \'ci_7\']\ndf = pd.DataFrame(index=rationales, columns=rationales)\npairWiseScore(df, rationale_list)\n\n\ndef quesScore(question, rationale_list):\n    d = dict()\n    for i in range(len(rationale_list)):\n        score = calculate_rouge(question, rationale_list[i])\n        ci = "ci_" + str(i)\n        d[ci] = score[\'rouge2\'].fmeasure\n    return d\n    \n\nques_sim_list = quesScore(question, rationale_list)\nprint(ques_sim_list)\n\n\n# S contains the rationale with