In [22]:
import asyncio
import html
import json
import re
import textwrap
from glob import glob
from pathlib import Path

import dataframe_image as dfi
import numpy as np
import pandas as pd
import textstat
from datasets import load_dataset, Dataset
from IPython.display import display, HTML
from ipywidgets import interact, IntSlider, Output, VBox, HTML as HTMLWidget, Button, HBox
#from gsm8k import SYSTEM_PROMPT

In [23]:
# System message prepended to every GSM8K question by get_gsm8k_questions.
# It asks for <reasoning>...</reasoning> followed by <answer>...</answer>; the metric
# functions below count and parse exactly these tags.
# NOTE(review): presumably mirrors the training prompt (cf. the commented-out
# `from gsm8k import SYSTEM_PROMPT` import) — keep it byte-identical; TODO confirm.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

def extract_hash_answer(text: str) -> str | None:
    return text.split("####")[-1].strip() if "####" in text else None

def get_gsm8k_questions(split="train") -> Dataset:
    """Load one split of ``openai/gsm8k`` (config ``main``) as chat-style examples.

    Each example gains a ``prompt`` column (system message ``SYSTEM_PROMPT`` followed by
    the user question) and its ``answer`` column is replaced by the final answer extracted
    after ``####`` (None when the marker is missing).
    """
    def to_chat_example(example):
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': example['question']},
        ]
        return {'prompt': messages, 'answer': extract_hash_answer(example['answer'])}

    all_splits = load_dataset('openai/gsm8k', 'main')
    return all_splits[split].map(to_chat_example)

In [24]:
# GSM8K test split with chat prompts and extracted final answers; responses below are
# compared against it by position.
dataset = get_gsm8k_questions("test")

In [None]:
# One JSON file per fine-tuned run: lora_checkpoints/<run_name>/test_examples_fixed.json.
# Each file presumably holds a list of per-checkpoint response lists (consumed that way
# by the metrics cell) — TODO confirm against the generation script.
output_files = sorted(glob('lora_checkpoints/*/test_examples_fixed.json'))
# Fail here rather than with a confusing NameError in the plotting cell.
assert output_files, "no lora_checkpoints/*/test_examples_fixed.json files found"
responses = {}
for output_path in output_files:
    with open(output_path) as fh:
        all_outputs = json.load(fh)
    # Every run is expected to have saved 5 or 6 checkpoints.
    assert len(all_outputs) in (5, 6), len(all_outputs)
    # Run name = parent directory; Path keeps this working with OS-specific separators.
    responses[Path(output_path).parent.name] = all_outputs
len(responses)

In [27]:
# Responses of the base (not fine-tuned) model; exactly one such file is expected.
pretrained_matches = glob('lora_checkpoints/pretrained_examples_fixed.json')
assert len(pretrained_matches) == 1, len(pretrained_matches)
pretrained_file = pretrained_matches[0]
with open(pretrained_file) as fh:
    pretrained_outputs = json.load(fh)

In [28]:
def get_answer(text: str) -> str:
    """Return the stripped text between the first ``<answer>`` and the next ``</answer>``.

    If ``</answer>`` is missing, everything after ``<answer>`` is returned.

    Raises:
        IndexError: if ``text`` contains no ``<answer>`` tag.
    """
    return text.split("<answer>")[1].split("</answer>")[0].strip()


def get_answer_rate(responses, answers=None) -> float:
    """Fraction of scoreable responses whose integer answer matches the reference.

    A response is scored only when its ``<answer>`` block parses as an ``int`` and the
    reference answer does too; everything else is skipped (hence the metric label
    "correct answer (when <answer> present)").

    Args:
        responses: response strings, aligned by position with the reference answers.
        answers: optional reference answers aligned with ``responses``. When omitted,
            ``dataset[i]['answer']`` from the notebook-level ``dataset`` is used.

    Returns:
        correct / scored, or ``nan`` when no response could be scored.
    """
    correct = wrong = 0
    for i, response in enumerate(responses):
        try:
            predicted = int(get_answer(response))
            expected = int(dataset[i]['answer'] if answers is None else answers[i])
        except (IndexError, ValueError, TypeError, AttributeError):
            # Missing <answer> tag, non-integer text, None reference answer,
            # non-string response, or no reference for this index: not scoreable.
            continue
        if predicted == expected:
            correct += 1
        else:
            wrong += 1
    scored = correct + wrong
    # Avoid ZeroDivisionError when nothing was scoreable; nan is skipped by the plots.
    return correct / scored if scored else float("nan")

def get_total_average_flesch_kincaid(responses) -> float:
    """Mean Flesch-Kincaid grade level computed over each full response string."""
    grades = list(map(textstat.flesch_kincaid_grade, responses))
    return sum(grades) / len(grades)

def flesch_kincaid_reward_func(responses) -> float:
    """Mean Flesch-Kincaid grade level of the reasoning section of each response.

    The reasoning section is the text after the last ``<reasoning>`` and before the next
    ``</reasoning>``; a missing tag leaves that side of the string untouched.
    """
    def reasoning_section(response):
        after_open = response.rpartition('<reasoning>')[2]
        return after_open.partition('</reasoning>')[0]

    grades = [textstat.flesch_kincaid_grade(reasoning_section(r)) for r in responses]
    return np.mean(grades)

def get_average_length(responses) -> float:
    """Mean response length in characters; ``nan`` for an empty list."""
    if not responses:
        return float("nan")
    return sum(len(r) for r in responses) / len(responses)

# <reasoning>...</reasoning>, optional whitespace, then <answer>...</answer>; DOTALL lets
# both sections span multiple lines.
_SOFT_FORMAT_PATTERN = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", flags=re.DOTALL)


def soft_format_reward_func(responses) -> int:
    """Count responses that follow the <reasoning>/<answer> format.

    ``re.match`` anchors at the start of the string, so any text (including whitespace)
    before ``<reasoning>`` makes a response fail; text after ``</answer>`` is allowed.
    """
    return sum(1 for r in responses if _SOFT_FORMAT_PATTERN.match(r))

def has_reasoning(responses) -> int:
    """Count responses containing both a ``<reasoning>`` and a ``</reasoning>`` tag (any order)."""
    return sum(1 for r in responses if "<reasoning>" in r and "</reasoning>" in r)

def has_answer(responses) -> int:
    """Count responses containing both an ``<answer>`` and an ``</answer>`` tag (any order)."""
    return sum(1 for r in responses if "<answer>" in r and "</answer>" in r)

In [29]:
def compute_metrics(outputs) -> dict:
    """Compute every evaluation metric for one list of response strings.

    The keys are used verbatim as subplot titles and as keys of the plotting cell's
    ``limits`` dict, so do not rename them casually.
    """
    return {
        "correct answer\n(when <answer> present)": get_answer_rate(outputs),
        "average_flesch_kincaid": get_total_average_flesch_kincaid(outputs),
        "reasoning flesch kincaid": flesch_kincaid_reward_func(outputs),
        "average_length": get_average_length(outputs),
        "soft_format_reward": soft_format_reward_func(outputs),
        "has reasoning tokens": has_reasoning(outputs),
        "has answer tokens": has_answer(outputs),
    }


# Per fine-tuned run: one metrics dict per saved checkpoint, in file order.
all_metrics = {}
for n, all_outputs in responses.items():
    metrics = [compute_metrics(outputs) for outputs in all_outputs]
    all_metrics[n] = metrics

# Same metrics for the base model, drawn as a dashed baseline in the plots.
pretrained_metrics = compute_metrics(pretrained_outputs)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Fixed y-ranges for bounded metrics. The count metrics are capped at 100, which
# presumably equals the number of evaluated test questions — TODO confirm.
limits = {
    "correct answer\n(when <answer> present)": [0, 1],
    "soft_format_reward": [0, 100],
    "has reasoning tokens": [0, 100],
    "has answer tokens": [0, 100],
}
# Training step of each saved checkpoint (x-axis); runs with fewer checkpoints use a prefix.
checkpoints = list(range(500, 3001, 500))
fig, axarr = plt.subplots(2, 4, figsize=(20, 10))
axarr = axarr.flatten()
# Metric names come from pretrained_metrics instead of the `metrics` loop variable leaked
# by the previous cell, so this cell does not depend on that loop's last iteration.
cats = list(pretrained_metrics.keys())
colors = ['r', 'b', 'g', 'orange']

# Line handles (one per run), collected from the first subplot for the shared legend.
method_handles = []

for ax_idx, (ax, cat) in enumerate(zip(axarr, cats)):
    for run_idx, (name, run_metrics) in enumerate(all_metrics.items()):
        y = [step_metrics[cat] for step_metrics in run_metrics]
        # Cycle colours so more runs than colours does not raise IndexError.
        line, = ax.plot(checkpoints[:len(run_metrics)], y, marker='o',
                        color=colors[run_idx % len(colors)], label=name)
        if ax_idx == 0:
            method_handles.append(line)
    if cat in limits:
        ax.set_ylim(limits[cat])
    ax.set_title(cat, fontsize=24)
    ax.tick_params(axis='both', labelsize=16)
    ax.set_xlabel('Num Steps', fontsize=16)

# The 2x4 grid has more panels than metrics; hide the unused ones.
for ax in axarr[len(cats):]:
    ax.set_visible(False)

labels = list(all_metrics.keys())
plot_pretrained = True
if plot_pretrained:
    # Pretrained baseline as a dashed horizontal line on every metric panel.
    for ax, cat in zip(axarr, cats):
        ax.axhline(pretrained_metrics[cat], color='black', linestyle='--')
    # Proxy artist so the baseline gets its own legend entry.
    pretrained_handle = Line2D([0], [0], color='black', linestyle='--', label='Pretrained')
    labels.append('Pretrained')
    method_handles.append(pretrained_handle)

# Global legend: one entry per run (+ the pretrained baseline when plotted).
fig.legend(handles=method_handles,
           labels=labels,
           loc='upper center', ncol=max(1, len(labels)), fontsize=16, bbox_to_anchor=(0.5, 1.05))

plt.tight_layout(rect=[0, 0, 1, 0.95])



In [None]:
class ResponseVisualizer:
    """Browse model responses question by question across training checkpoints.

    For a chosen question, responses are laid out as a table with one row per checkpoint
    name and one column per model. The row named ``'pretrained'`` is filled from
    ``pretrained_outputs``; any other row ``k`` reads ``model_responses[model][k - 1]``.
    """

    def __init__(self,
                 model_responses,
                 questions,
                 checkpoint_names,
                 answers,
                 pretrained_outputs):
        """
        Args:
            model_responses: dict mapping model_name -> list of per-checkpoint response lists
                             (EXCLUDING pretrained; these should correspond to checkpoint_names[1:]).
            questions:       list of N questions.
            checkpoint_names: list of len = 1 + num_checkpoints, e.g. ['pretrained', 'cp1', 'cp2', ...].
            answers:         list of N answers.
            pretrained_outputs: list of N strings, one "pretrained" response per question.
        """
        self.model_responses = model_responses
        self.questions = questions
        self.answers = answers
        self.pretrained_outputs = pretrained_outputs

        self.model_names = list(model_responses.keys())
        self.num_models = len(self.model_names)
        self.checkpoint_names = checkpoint_names
        # Counts every row of the table, including 'pretrained'.
        self.num_checkpoints = len(checkpoint_names)
        self.num_questions = len(questions)

    def wrap_text(self, text, width=80):
        """Hard-wrap ``text`` to ``width`` columns (textwrap also collapses whitespace)."""
        return '\n'.join(textwrap.wrap(text, width=width))

    def _response(self, model, cp_idx, question_idx):
        """Response of ``model`` at row ``cp_idx`` of ``checkpoint_names`` for one question.

        The 'pretrained' row reads ``pretrained_outputs``; other rows are offset by one
        because ``model_responses`` lists exclude pretrained. A missing model response
        yields ''.
        """
        if self.checkpoint_names[cp_idx] == 'pretrained':
            return self.pretrained_outputs[question_idx]
        try:
            return self.model_responses[model][cp_idx - 1][question_idx]
        except (IndexError, KeyError):
            return ''

    @staticmethod
    def _html_to_text(value):
        """Strip HTML tags and unescape entities; non-string cells (e.g. NaN) become ''."""
        if not isinstance(value, str):
            return ''
        return html.unescape(re.sub(r'<[^>]+>', '', value))

    def highlight_special_tokens(self, text):
        """HTML-escape ``text`` and wrap tag-like tokens (e.g. ``<answer>``) in a highlight span.

        Escaping with ``html.escape`` covers ``&``, so model text can never inject markup;
        only the highlight spans added here are real HTML.
        """
        escaped = html.escape(text, quote=False)
        token_pattern = r'(&lt;/?[a-zA-Z_]+&gt;)'
        return re.sub(
            token_pattern,
            lambda m: (f'<span style="background-color:#fffa8c;'
                       f'font-weight:bold; padding:2px 4px; border-radius:3px;">'
                       f'{m.group(1)}</span>'),
            escaped,
        )

    def display_responses_table(self, question_idx):
        """Build the question/answer header and the checkpoint x model response table.

        Returns:
            (IPython HTML object, pivoted DataFrame of highlighted HTML responses).

        Raises:
            IndexError: if ``question_idx`` is out of range.
        """
        if not (0 <= question_idx < self.num_questions):
            raise IndexError(f"Question index must be 0 ≤ idx < {self.num_questions}")

        records = [
            {
                'Model': model,
                'Checkpoint': cp_name,
                'Response': self.highlight_special_tokens(self._response(model, cp_idx, question_idx)),
            }
            for model in self.model_names
            for cp_idx, cp_name in enumerate(self.checkpoint_names)
        ]

        df = pd.DataFrame(records)
        df_pivot = df.pivot(index='Checkpoint', columns='Model', values='Response')
        # pivot sorts the index; restore the caller's checkpoint order.
        df_pivot = df_pivot.reindex(self.checkpoint_names)

        # Header with question and reference answer; escaped so their text renders literally.
        question_html = html.escape(str(self.questions[question_idx]), quote=False)
        answer_html = html.escape(str(self.answers[question_idx]), quote=False)
        q_html = f"""
        <div style="background:#eef5fa; padding:12px;
                    border-left:6px solid #1a73e8; margin-bottom:15px;border-radius:4px;">
          <h3 style="margin:0;">Q{question_idx}: {question_html}</h3>
        </div>
        <div style="background:#fff8e1; padding:10px;
                    border-left:6px solid #f39c12; margin-bottom:15px;border-radius:4px;">
          <strong>Answer:</strong> {answer_html}
        </div>
        """

        styled = (
            df_pivot.style
                    .set_properties(**{
                        'white-space': 'pre-wrap',
                        'text-align': 'left',
                        'vertical-align': 'top',
                        'padding': '8px',
                        'line-height': '1.4'
                    })
                    .set_table_styles([
                        {'selector': 'th',
                         'props': [
                            ('background-color', '#1a73e8'),
                            ('color', 'white'),
                            ('font-size', '13px'),
                            ('text-align', 'center'),
                            ('padding', '8px')
                         ]},
                        {'selector': 'tr:nth-child(even)',
                         'props': [('background-color', '#f7f7f7')]},
                        {'selector': 'td',
                         'props': [
                            ('border', '1px solid #ddd'),
                            ('max-width', '500px'),
                            ('font-size', '14px')
                         ]}
                    ])
                    .set_caption("<b>Responses (models × checkpoints)</b>")
                    .set_uuid("")  # stable HTML
                    .to_html(escape=False)
        )

        return HTML(q_html + styled), df_pivot

    def interactive_visualizer(self):
        """Build a question slider, a "Save as PNG" button and an output area.

        Returns:
            ipywidgets.VBox(header, controls, output); display it (e.g. as a cell's last expression).
        """
        slider = IntSlider(value=0, min=0, max=self.num_questions - 1,
                           step=1, description='Question:',
                           continuous_update=False,
                           style={'description_width': 'initial'})
        out = Output()
        header = HTMLWidget(value=f"<h2>Response Visualizer ({self.num_models} models, "
                                  f"{self.num_checkpoints} checkpoints)</h2>")
        save_btn = Button(description='Save as PNG')

        # State read by save_snapshot; initialised so a click can never hit AttributeError.
        self._last_df = None
        self._last_figs = []

        def update(idx):
            with out:
                out.clear_output(wait=True)
                table_html, df_pivot = self.display_responses_table(idx)
                display(table_html)
                # Stash for saving; no extra Matplotlib figures are produced for now.
                self._last_df = df_pivot
                self._last_figs = []

        def save_snapshot(_):
            if self._last_df is None:
                return
            idx = slider.value
            # 1) Render the table as a Matplotlib figure (plain text, tags stripped) and save it.
            fig_table, ax = plt.subplots(figsize=(12, 8))
            ax.axis('off')
            cell_text = [
                [textwrap.fill(self._html_to_text(self._last_df.loc[cp, model]), width=40)
                 for model in self._last_df.columns]
                for cp in self._last_df.index
            ]
            tbl = ax.table(
                cellText=cell_text,
                rowLabels=self._last_df.index.tolist(),
                colLabels=self._last_df.columns.tolist(),
                cellLoc='left',
                loc='center'
            )
            tbl.auto_set_font_size(False)
            tbl.set_fontsize(10)
            fig_table.tight_layout()
            fig_table.savefig(f"question_{idx}_table.png", dpi=300, bbox_inches='tight')
            plt.close(fig_table)

            # 2) Save any stored Matplotlib figures.
            for i, fig in enumerate(self._last_figs):
                fig.savefig(f"question_{idx}_fig{i}.png", dpi=300, bbox_inches='tight')
                plt.close(fig)

            # Print into the output widget; callback prints are otherwise not shown in the cell.
            with out:
                print(f"Saved: question_{idx}_table.png + {len(self._last_figs)} figure(s).")

        save_btn.on_click(save_snapshot)
        # Observe the slider directly; `interact` would render a second copy of it.
        slider.observe(lambda change: update(change['new']), names='value')
        update(slider.value)
        controls = HBox([slider, save_btn], layout={'align_items': 'center'})
        return VBox([header, controls, out])

    def count_special_tokens(self, text):
        """
        Count tag-like tokens (e.g. <reasoning>, </reasoning>) by bare name.

        Args:
            text: The response text that may contain special tokens.

        Returns:
            Dict mapping token name (without brackets/slash) -> occurrences, in first-seen order.
        """
        token_counts = {}
        for token in re.findall(r'</?([a-zA-Z_]+)>', text):
            token_counts[token] = token_counts.get(token, 0) + 1
        return token_counts

    def display_length_comparison(self, question_idx):
        """
        Bar chart of response length (characters) per checkpoint row and model.

        Args:
            question_idx: Index of the question to display response lengths for.

        Returns:
            The Matplotlib figure, or None (after printing a message) for an invalid index.
        """
        if not 0 <= question_idx < self.num_questions:
            print(f"Question index must be between 0 and {self.num_questions-1}")
            return

        fig, ax = plt.subplots(figsize=(12, 6))

        bar_width = 0.8 / self.num_models
        checkpoint_positions = np.arange(self.num_checkpoints)

        for i, model_name in enumerate(self.model_names):
            # _response applies the pretrained offset so bars line up with checkpoint_names.
            response_lengths = [len(self._response(model_name, cp, question_idx))
                                for cp in range(self.num_checkpoints)]
            ax.bar(checkpoint_positions + i * bar_width, response_lengths,
                   width=bar_width, label=model_name)

        ax.set_xlabel('Checkpoint')
        ax.set_ylabel('Response Length (characters)')
        ax.set_title(f'Response Length Comparison for Question {question_idx}')
        ax.set_xticks(checkpoint_positions + bar_width * (self.num_models - 1) / 2)
        ax.set_xticklabels(self.checkpoint_names, rotation=45, ha='right')
        ax.legend()

        fig.tight_layout()
        return fig

    def display_token_analysis(self, question_idx):
        """
        Per-token bar charts of special-token counts per checkpoint row and model.

        Args:
            question_idx: Index of the question to analyze tokens for.

        Returns:
            A Matplotlib figure, an IPython HTML notice when no tokens are found, or None
            (after printing a message) for an invalid index.
        """
        if not 0 <= question_idx < self.num_questions:
            print(f"Question index must be between 0 and {self.num_questions-1}")
            return

        all_tokens = set()
        token_data = {}
        for model_name in self.model_names:
            token_data[model_name] = []
            for cp in range(self.num_checkpoints):
                token_counts = self.count_special_tokens(self._response(model_name, cp, question_idx))
                token_data[model_name].append(token_counts)
                all_tokens.update(token_counts.keys())

        if not all_tokens:
            return HTML("<div style='padding: 10px; background-color: #f8f9fa; border-left: 5px solid #6c757d;'><p>No special tokens detected in the responses for this question.</p></div>")

        all_tokens = sorted(all_tokens)
        n_tokens = len(all_tokens)
        # squeeze=False keeps a 2-D array even for a single token; ravel gives one axis per token.
        fig, axes = plt.subplots(nrows=n_tokens, figsize=(14, 4 * n_tokens), squeeze=False)
        axes = axes.ravel()

        bar_width = 0.8 / self.num_models
        checkpoint_positions = np.arange(self.num_checkpoints)

        for token_idx, token in enumerate(all_tokens):
            ax = axes[token_idx]
            for i, model_name in enumerate(self.model_names):
                token_counts = [token_data[model_name][cp].get(token, 0) for cp in range(self.num_checkpoints)]
                ax.bar(checkpoint_positions + i * bar_width, token_counts, width=bar_width,
                       label=model_name if token_idx == 0 else "")

            ax.set_title(f'Token <{token}> Frequency')
            ax.set_xlabel('Checkpoint')
            ax.set_ylabel('Count')
            ax.set_xticks(checkpoint_positions + bar_width * (self.num_models - 1) / 2)
            ax.set_xticklabels(self.checkpoint_names, rotation=45, ha='right')
            ax.grid(axis='y', linestyle='--', alpha=0.7)

        # Only the first subplot carries labels, so the legend goes there.
        axes[0].legend(loc='upper right')

        fig.tight_layout()
        return fig
    

    
# Interactive per-question viewer; as the cell's last expression, the returned widget is displayed.
# Responses are aligned by position with the test split, so show as many questions as
# there are pretrained outputs (bounded by the dataset) instead of a hard-coded 100.
num_questions = min(len(pretrained_outputs), len(dataset))
questions = [dataset[i]['question'] for i in range(num_questions)]
answers = [dataset[i]['answer'] for i in range(num_questions)]
checkpoint_names = ['pretrained'] + checkpoints
visualizer = ResponseVisualizer(responses, questions, checkpoint_names, answers, pretrained_outputs)
visualizer.interactive_visualizer()