In [1]:
!pip install outlines
!pip install context_cite

Collecting context_cite
  Downloading context_cite-0.0.4-py3-none-any.whl.metadata (7.7 kB)
Collecting nltk>=3.8.2 (from context_cite)
  Using cached nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Collecting spacy (from context_cite)
  Using cached spacy-3.8.4-cp311-cp311-win_amd64.whl.metadata (27 kB)
Collecting click (from nltk>=3.8.2->context_cite)
  Using cached click-8.1.8-py3-none-any.whl.metadata (2.3 kB)
Collecting spacy-legacy<3.1.0,>=3.0.11 (from spacy->context_cite)
  Using cached spacy_legacy-3.0.12-py2.py3-none-any.whl.metadata (2.8 kB)
Collecting spacy-loggers<2.0.0,>=1.0.0 (from spacy->context_cite)
  Using cached spacy_loggers-1.0.5-py3-none-any.whl.metadata (23 kB)
Collecting murmurhash<1.1.0,>=0.28.0 (from spacy->context_cite)
  Using cached murmurhash-1.0.12-cp311-cp311-win_amd64.whl.metadata (2.2 kB)
Collecting cymem<2.1.0,>=2.0.2 (from spacy->context_cite)
  Using cached cymem-2.0.11-cp311-cp311-win_amd64.whl.metadata (8.8 kB)
Collecting preshed<3.1.0,>=3.0.2 (from 

## Imports

In [1]:
%load_ext autoreload
%autoreload 2

# Add the path to the parent directory to sys
import sys, os

# If the kernel was started inside the 'notebooks' directory, switch to its parent so
# that the relative paths used in this notebook ('attribution', 'results/...') resolve.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('../')
    
# Make the modules inside the 'attribution' directory importable
# (presumably where constants, model_utils and dataset_utils live — TODO confirm).
sys.path.append('attribution')

from torch.utils.data import DataLoader

import pandas as pd
from constants import ModelNames, DatasetNames, LANGUAGE_MAPPING
from model_utils import Model 
from dataset_utils import GSMDataset, PaddingCollator, is_correct_gsm, extract_answer_gsm
from context_cite import ContextCiter
from tqdm.notebook import tqdm

from contextlib import contextmanager

import warnings

# Filter specific warning categories for the rest of the kernel session.
# NOTE(review): these filters are global, so warnings raised by the model libraries are
# hidden as well — consider narrowing them with `module=` / `message=` if that matters.
warnings.filterwarnings("ignore", category=UserWarning)  # For general user warnings
warnings.filterwarnings("ignore", category=FutureWarning)  # For warnings about upcoming behaviour changes

# Definitions
# CSV written by ResponseProcessing.process_model_responses_for_analysis and read back
# by the inferencing section.
processed_data_path = "results/analysis_mgsm_en_Qwen2.5-1.5B-Instruct_results.csv"
# NOTE(review): not referenced by the cells visible here, which use
# ModelNames.QwenInstruct directly — confirm whether it is still needed.
model_name = ModelNames.QwenInstruct

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\User\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


## Analysis: Processing Responses

In [30]:
def load_model_dataset():
    """Create the model wrapper and pair it with the dataset identifier.

    Returns:
        tuple: ``(Model(ModelNames.QwenInstruct), DatasetNames.MGSM)``.
    """
    return Model(ModelNames.QwenInstruct), DatasetNames.MGSM

class ResponseProcessing():
    """Reshape raw model generations into the per-question table used for analysis.

    The generations are read from a CSV with a ``response`` column. Each response is
    split into renumbered reasoning steps plus its last line, and the result is written
    to a CSV together with the test questions and their reference answers.
    """

    # Raw generations CSV read by `process_model_responses_for_analysis` by default.
    # Written with a forward slash: the former 'results\mgsm_...' literal relied on the
    # invalid escape sequence '\m' and only resolved as a path on Windows.
    DEFAULT_GENERATIONS_PATH = 'results/mgsm_en_Qwen2.5-1.5B-Instruct_results.csv'

    def __init__(self, model, dataset, config='en', is_cot=True):
        """Store the collaborators and options used during processing.

        Args:
            model: Object exposing a ``tokenizer`` attribute (handed to ``GSMDataset``).
            dataset: Dataset identifier passed to ``GSMDataset`` as its first argument.
            config: Value forwarded to ``GSMDataset`` as ``config``.
            is_cot: When true, the train split's instructions are prepended to every
                question; otherwise questions are used as they are.
        """
        self.df_column_names = ["question", "actual_answer", "model_gen_steps", "model_gen_answer", 'model_answer_str']
        self.model = model
        self.dataset = dataset
        self.config = config
        self.is_cot = is_cot

    def convert_dashes_incremental_steps_list(self, steps):
        """Number the non-empty steps 1, 2, 3, ... and return them as a list of strings.

        Args:
            steps: Pieces of a response that was split on a newline followed by a dash.
                The first piece (the text before the first dash) is dropped.

        Returns:
            list[str]: ``"<n> <step>"`` for every non-empty piece after the first.
        """
        # Skip empty parts (if any) so that the numbering stays contiguous.
        non_empty_steps = (step for step in steps[1:] if step)

        # No full stop after the number: contextcite would otherwise treat the step
        # number itself as a new sentence.
        return [str(number) + " " + step for number, step in enumerate(non_empty_steps, start=1)]

    def convert_dashes_incremental_steps(self, step):
        """Return the numbered steps as one string under a fixed header line.

        Args:
            step: Same list of pieces that ``convert_dashes_incremental_steps_list`` takes.

        Returns:
            str: The header "Step-by-Step Answer:" followed by one numbered step per line.
        """
        furnished_steps = self.convert_dashes_incremental_steps_list(step)

        # One step per line to better separate the steps
        return "Step-by-Step Answer:\n" + "\n".join(furnished_steps)

    def _split_response(self, response):
        """Split one raw response into ``(numbered_steps_str, last_line)``."""
        lines = response.strip().split('\n')

        # The last line is kept verbatim as the answer string
        answer_string = lines[-1]

        # Everything before the last line is split into dash-prefixed steps
        steps = '\n'.join(lines[:-1]).split("\n-")

        return self.convert_dashes_incremental_steps(steps), answer_string

    def process_model_responses_for_analysis(self, generations_path=None, output_path=None):
        """Build the analysis table from the raw generations and write it to CSV.

        Args:
            generations_path: CSV with a ``response`` column. Defaults to
                ``DEFAULT_GENERATIONS_PATH``.
            output_path: Destination CSV. Defaults to the notebook-level
                ``processed_data_path`` constant.
        """
        if generations_path is None:
            generations_path = self.DEFAULT_GENERATIONS_PATH
        if output_path is None:
            output_path = processed_data_path

        # Load train for instructions
        mgsm_train = GSMDataset(self.dataset, self.model.tokenizer, split='train', config=self.config)

        mgsm_test = GSMDataset(self.dataset, self.model.tokenizer, instructions='', split='test', config=self.config)

        mgsm_generations = pd.read_csv(generations_path)['response'].tolist()

        all_steps = []
        all_gen_final_ans = []
        all_answer_strings = []  # For storing the last line

        for response in mgsm_generations:
            steps_str, answer_string = self._split_response(response)
            all_steps.append(steps_str)
            all_answer_strings.append(answer_string)

            # Extract numerical answer from the full, unsplit response
            all_gen_final_ans.append(extract_answer_gsm(response))

        # Combine each question with mgsm_train.instructions
        instructions = mgsm_train.instructions + '\n\n' if self.is_cot else ''

        # Create a list of questions with instructions prepended to each
        question_list = [instructions + q for q in mgsm_test.dataset['question']]

        actual_answer = mgsm_test.dataset['answer_number']

        # Rows are paired by position; zip stops at the shortest input, so the
        # generations are assumed to be in the same order as the test split.
        analysis_df = pd.DataFrame(
            data=zip(question_list, actual_answer, all_steps, all_gen_final_ans, all_answer_strings),
            columns=self.df_column_names
        )

        analysis_df.to_csv(output_path, index=False)



## Main

In [31]:
# Run the preprocessing only when this code is executed directly. Inside a Jupyter
# kernel `__name__` is '__main__', so the guard passes there as well. The previous
# `if '__main__':` tested a non-empty string literal and was therefore always true.
if __name__ == '__main__':
    context_model, dataset = load_model_dataset()

    responseProcessing = ResponseProcessing(context_model, dataset)
    responseProcessing.process_model_responses_for_analysis()

    

Some parameters are on the meta device because they were offloaded to the cpu.
Device set to use cuda:0


## Inferencing

### Steps:
 1. Read the processed responses from "analysis_{model_name}".
 2. Pass in the model-generated steps (`model_gen_steps`) as the context and the question as the query.
 3. Check whether the model's answer matches the reference answer (it might also be worthwhile to analyse the wrong answers).
 4. If yes, use `cc.get_attributions()` to compute the attributions with [ContextCite](https://github.com/MadryLab/context-cite).
 5. Save the np.array to the respective row of the "analysis_{model_name}" set.

In [2]:
# Model wrapper whose `.model` and `.tokenizer` are handed to ContextCiter below.
# NOTE(review): this loads the model again rather than reusing the instance created in
# the preprocessing section — presumably so this section runs on a fresh kernel; confirm.
context_model = Model(ModelNames.QwenInstruct)

# Unlike RAG, the context follows the query
prompt_template = '{query}\n{context}'

# Table written by ResponseProcessing.process_model_responses_for_analysis
model_responses = pd.read_csv(processed_data_path)
# Last expression of the cell: shows the first rows as a sanity check
model_responses.head()

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Device set to use cuda:0


Unnamed: 0,question,actual_answer,model_gen_steps,model_gen_answer,model_answer_str
0,Question: Roger has 5 tennis balls. He buys 2 ...,18,"Step-by-Step Answer:\n1 First, calculate the ...",18.0,The answer is 18.<
1,Question: Roger has 5 tennis balls. He buys 2 ...,3,Step-by-Step Answer:\n1 The robe requires 2 b...,3.0,The answer is 3.<
2,Question: Roger has 5 tennis balls. He buys 2 ...,70000,Step-by-Step Answer:\n1 The original price of...,170000.0,The answer is 170000.<
3,Question: Roger has 5 tennis balls. He buys 2 ...,540,Step-by-Step Answer:\n1 James runs 3 sprints ...,3.0,"- Since he runs 3 times a week, he runs a tota..."
4,Question: Roger has 5 tennis balls. He buys 2 ...,20,Step-by-Step Answer:\n1 The total amount of f...,20.0,The answer is 20.<


In [3]:
# Per-sample attribution frames are collected here and concatenated once after the loop;
# calling pd.concat on a growing DataFrame in every iteration re-copies all earlier rows.
cite_frames = []

# Get length of model_responses
len_responses = len(model_responses)

error_counter = 0

# Iterate over the rows of the DataFrame. Wrapping the iterator lets tqdm advance once per
# row (skipped rows included) and close the bar even if an iteration raises.
for index, row in tqdm(model_responses.iterrows(), total=len_responses):
    context = row['model_gen_steps']
    query = row['question']
    answer_string = row['model_answer_str']

    # Abstain from pre-train because it creates a new model each time
    # Constructor is needed due to processing during initialization
    cc = ContextCiter(context_model.model, context_model.tokenizer, context, query, prompt_template=prompt_template)

    # We want to use precomputed answers, so the cached output is filled in by hand
    # instead of letting the citer generate a response.
    # See https://github.com/MadryLab/context-cite/issues/4
    _, prompt = cc._get_prompt_ids(return_prompt=True)
    cc._cache["output"] = prompt + answer_string

    # One importance value per source the citer found in the context.
    # verbose=False because the inner progress bar is annoying.
    line_importance = cc.get_attributions(as_dataframe=False, verbose=False)

    # Get each line so that it can be paired with its importance
    lines = context.split('\n')

    # If number of lines and importance values do not match, report and skip the example.
    # NOTE(review): in the recorded output every skipped example has MORE importance values
    # than lines, which suggests the citer partitions the context into finer units than
    # newline-separated lines (presumably sentences) — confirm, and consider pairing the
    # scores with the citer's own sources instead of dropping these examples.
    if len(lines) != len(line_importance):
        print(f"Number of lines ({len(lines)}) and importance values ({len(line_importance)}) do not match in example {index} Skipping...")
        error_counter += 1
        continue

    # sample_index identifies which example each line belongs to
    cite_frames.append(pd.DataFrame({
        'sample_index': index,  # Use the DataFrame index as sample index
        'line': lines,
        'importance': line_importance
    }))

# pd.concat raises on an empty list, so fall back to an empty frame that still has the
# expected columns when every example was skipped.
if cite_frames:
    cite_df = pd.concat(cite_frames, ignore_index=True)
else:
    cite_df = pd.DataFrame(columns=['sample_index', 'line', 'importance'])

print(f"Number of errors: {error_counter}")

  0%|          | 0/250 [00:00<?, ?it/s]

Number of lines (4) and importance values (6) do not match in example 6 Skipping...
Number of lines (6) and importance values (11) do not match in example 7 Skipping...
Number of lines (5) and importance values (7) do not match in example 8 Skipping...
Number of lines (5) and importance values (7) do not match in example 10 Skipping...
Number of lines (4) and importance values (7) do not match in example 11 Skipping...
Number of lines (9) and importance values (11) do not match in example 14 Skipping...
Number of lines (4) and importance values (7) do not match in example 34 Skipping...
Number of lines (9) and importance values (10) do not match in example 38 Skipping...
Number of lines (6) and importance values (7) do not match in example 39 Skipping...
Number of lines (3) and importance values (5) do not match in example 40 Skipping...
Number of lines (4) and importance values (5) do not match in example 42 Skipping...
Number of lines (9) and importance values (11) do not match in ex

In [None]:
# Store results as JSON
# NOTE(review): these imports sit in the middle of the notebook; consider moving them to
# the imports cell so that all dependencies are visible in one place.
import json
import numpy as np

# One dictionary per sample present in cite_df; filled by the loop further down
result_list = []

# JSON encoder that understands NumPy scalars and arrays
class NumpyEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that converts NumPy values to built-in Python types."""

    def default(self, obj):
        """Return a JSON-serialisable equivalent of ``obj``.

        Arrays become lists, NumPy integers become ``int`` and NumPy floats become
        ``float``; anything else is delegated to the base class.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)

def _to_json_number(value):
    """Convert a numeric cell loaded from CSV into a JSON-friendly number.

    Returns ``None`` for missing values (``int(nan)`` would raise), an ``int`` for
    integral values and a ``float`` otherwise, so non-integral answers are no longer
    silently truncated.
    """
    if pd.isna(value):
        return None
    if isinstance(value, (int, np.integer)):
        return int(value)
    number = float(value)
    return int(number) if number.is_integer() else number


for sample_index, group in cite_df.groupby('sample_index'):
    # sample_index holds the row label taken from model_responses.iterrows(),
    # so the row is looked up by label rather than by position.
    original_row = model_responses.loc[sample_index]

    # Create a dictionary for this sample; NumPy scalars are converted by NumpyEncoder
    sample_dict = {
        'sample_index': sample_index,
        'question': original_row['question'],
        'actual_answer': _to_json_number(original_row['actual_answer']),
        'model_gen_answer': _to_json_number(original_row['model_gen_answer']),
        'model_answer_str': original_row['model_answer_str'],
        'lines_and_importance': [
            {'text': text, 'importance': importance}
            for text, importance in zip(group['line'], group['importance'])
        ]
    }

    # Add this dictionary to our results list
    result_list.append(sample_dict)

# Save as JSON file with proper formatting and custom encoder
with open('results/contextcite_en_QwenInstruct_COT.json', 'w', encoding='utf-8') as f:
    json.dump(result_list, f, ensure_ascii=False, indent=2, cls=NumpyEncoder)