In [None]:
!pip install --upgrade datasets

import re
import html
import ast
import json
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import ast
import matplotlib.cm as cm
import nltk

from huggingface_hub import login
from datasets import Dataset, DatasetDict, load_dataset
from IPython.display import display, HTML
from nltk.corpus import stopwords
from collections import defaultdict

# Make the NLTK stop-word corpus available, then keep the English entries in a
# set so later membership checks (see top_k_words) are cheap.
nltk.download('stopwords')
stop_words = {word for word in stopwords.words('english')}

In [None]:
# Load the dataset by its repository id and turn the 'rawcases' split into a
# pandas DataFrame; get_row_by_dataset_and_index reads it through `df`.
dataset = load_dataset("Ramitha/alqa-slsc-newsqa-40-ig")
df = pd.DataFrame(data=dataset['rawcases'])

In [None]:
def get_row_by_dataset_and_index(dataset_name, idx, frame=None):
    """Return the ``idx``-th row (0-based) among the rows of one source dataset.

    Parameters
    ----------
    dataset_name : str
        Value matched against the ``'dataset'`` column.
    idx : int
        Zero-based position within the matching rows, counted in the order
        they appear in the table.
    frame : pandas.DataFrame, optional
        Table to search. When omitted, the notebook-level ``df`` is used, so
        existing two-argument calls behave exactly as before.

    Returns
    -------
    pandas.Series
        The selected row. Because the matching rows are re-indexed from 0,
        the Series' ``name`` equals ``idx``.

    Raises
    ------
    IndexError
        If ``idx`` is negative or not smaller than the number of matching
        rows (this includes the case where no row matches ``dataset_name``).
    """
    if frame is None:
        # Fall back to the global table loaded earlier in the notebook.
        frame = df
    matches = frame.loc[frame['dataset'] == dataset_name].reset_index(drop=True)
    if not 0 <= idx < len(matches):
        raise IndexError(f"Index {idx} out of range for dataset '{dataset_name}'")
    return matches.loc[idx]

In [None]:
def visualize_multiple_token_weights_blocks(rows, labels=None, cmap="RdYlGn"):
    """Display rows of (token, weight) pairs as colour-coded HTML chips.

    Every row is normalised on its own: when its weights span both signs a
    diverging scale centred on 0 is used, otherwise a plain min-max scale.
    A small colour legend is drawn above each row.

    Parameters
    ----------
    rows : iterable
        Each item is an iterable of ``(token, weight)`` pairs.
    labels : sequence of str, optional
        One caption per row. Rows without a caption (``labels`` missing or
        shorter than ``rows``) are captioned ``"Row N"``.
    cmap : str or matplotlib.colors.Colormap, optional
        Colormap used to colour the weights.

    Returns
    -------
    None
        The HTML is shown through ``IPython.display.display``.
    """
    # `matplotlib.cm.get_cmap` is deprecated (and removed in Matplotlib 3.9);
    # `pyplot.get_cmap` accepts both names and Colormap instances.
    colormap = plt.get_cmap(cmap)
    html_output = ""
    for i, row in enumerate(rows):
        if labels and i < len(labels):
            label = labels[i]
        else:
            label = f"Row {i+1}"
        safe_label = html.escape(str(label))

        pairs = list(row)
        if not pairs:
            # zip(*[]) cannot be unpacked into two names; show a placeholder
            # instead of aborting the whole figure.
            html_output += f"""
            <div style='margin:12px 0; font-family:sans-serif; font-size:12px;'>
                <b>{safe_label}</b>: (no tokens)
            </div>
        """
            continue

        tokens, weights = zip(*pairs)
        vmin, vmax = min(weights), max(weights)

        # Normalisation: diverging around 0 only when both signs are present.
        if vmin < 0 < vmax:
            norm = mcolors.TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)
            guide_text = "Negative (red) → Neutral (yellow) → Positive (green)"
        else:
            norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
            guide_text = "Low (red) → High (green)"

        # Legend: 11 evenly spaced samples including both end points.
        # linspace avoids the extra/missing last sample that arange can
        # produce with a floating-point step.
        if vmax > vmin:
            values = np.linspace(vmin, vmax, 11)
        else:
            values = np.array([vmin])
        gradient_html = "".join(
            f"<span style='display:inline-block; width:20px; height:16px; background:{mcolors.to_hex(colormap(norm(v)))}'></span>"
            for v in values
        )
        legend_html = f"""
            <div style="margin-bottom:6px; font-family:sans-serif; font-size:12px;">
                <b>{safe_label} color guide:</b><br>
                {guide_text}<br>
                {gradient_html}<br>
                <small>{' '.join([f'{v:.2f}' for v in values])}</small>
            </div>
        """

        # Tokens: one chip per token, coloured by its normalised weight.
        token_html = "<div style='flex:1; display:flex; flex-wrap:wrap; gap:4px;'>"
        for token, w in zip(tokens, weights):
            hex_color = mcolors.to_hex(colormap(norm(w)))
            # Escape the token text and show the "▁" marker as a visible space symbol.
            safe_tok = html.escape(str(token)).replace("▁", "␣")
            token_html += f"""
                <span style='background:{hex_color};
                             padding:4px 6px;
                             border-radius:6px;
                             margin:2px;
                             display:inline-block;
                             color:black;
                             font-family:monospace;'>
                    {safe_tok}
                </span>
            """
        token_html += "</div>"

        html_output += f"""
            <div style='margin:12px 0;'>
                {legend_html}
                <div style='display:flex; align-items:center;'>
                    <div style='width:80px; font-weight:bold;'>{safe_label}</div>
                    {token_html}
                </div>
            </div>
        """
    display(HTML(html_output))

token_cut_off = 4  # default number of words returned by top_k_words


def top_k_words(token_probs, original_text, k=None, ignore_words=None):
    """Merge sub-word token scores into words and return the top-scoring ones.

    Tokens are glued into words: a token starting with one of the word-start
    markers (U+2581 "▁" or U+0120 "Ġ" — presumably the tokenizers'
    leading-space markers) begins a new word, any other token continues the
    current one. Only ASCII letters and digits of a token are kept, and a
    token contributes its score only if it has at least one such character.
    Words are lower-cased, and scores of repeated words are summed.

    Parameters
    ----------
    token_probs : str or iterable
        ``(token, score)`` pairs, or the string form of such a list (parsed
        with ``ast.literal_eval``). Items that are not 2-element
        lists/tuples, or whose token is not a string, are ignored.
    original_text : str
        Source text; only words that occur in it (``\\w+`` matches,
        case-insensitive) are reported.
    k : int, optional
        Maximum number of words to return. Defaults to the module-level
        ``token_cut_off``.
    ignore_words : collection of str, optional
        Lower-case words to exclude. Defaults to the module-level
        ``stop_words``.

    Returns
    -------
    list of (str, float)
        Up to ``k`` ``(word, summed score)`` pairs, highest score first.
    """
    if isinstance(token_probs, str):
        import ast  # local import keeps the function usable on its own
        token_probs = ast.literal_eval(token_probs)
    if k is None:
        k = token_cut_off
    if ignore_words is None:
        ignore_words = stop_words

    word_start_markers = ('\u2581', '\u0120')
    word_scores = defaultdict(float)
    current_word = ''
    current_sum = 0.0

    for pair in token_probs:
        if not isinstance(pair, (list, tuple)) or len(pair) != 2:
            continue
        token, score = pair
        if not isinstance(token, str):
            continue
        cleaned = re.sub(r'[^a-zA-Z0-9]', '', token)
        if token.startswith(word_start_markers):
            # A marker always closes the running word, even when the token
            # itself carries no letters/digits (e.g. "▁["). Skipping such
            # tokens before this check would glue the next pieces onto the
            # previous word ("ltd" + "2006" -> "ltd2006").
            if current_word:
                word_scores[current_word.lower()] += current_sum
            current_word = cleaned
            # Punctuation-only tokens do not contribute any score.
            current_sum = score if cleaned else 0.0
        elif cleaned:
            current_word += cleaned
            current_sum += score
    if current_word:
        word_scores[current_word.lower()] += current_sum

    text_words = {w.lower() for w in re.findall(r'\w+', original_text)}
    kept = {
        word: total
        for word, total in word_scores.items()
        if word in text_words and word not in ignore_words
    }
    ranked = sorted(kept.items(), key=lambda item: item[1], reverse=True)
    return ranked[:k]

DATASET SELECTION

In [None]:
# Choose which source dataset to inspect (position 0 -> 'alqa') and which of
# its rows to pull out for the cells below.
dataset_select = ('alqa', 'sl', 'newsqa')[0]
index_select = 27
data_row = get_row_by_dataset_and_index(
    dataset_name=dataset_select,
    idx=index_select,
)

RAW IG FOCUS

In [None]:
print("0.4 LAYER RAW FOCUS PERCENTAGE\n")

# (column-name suffix, display label) per model, in drawing order.
ig_models = [("llama", "LLaMA"), ("falcon", "Falcon"), ("mistral", "Mistral"), ("gemma", "Gemma")]

# Same rendering for the question and then the answer: print the text, then
# draw one coloured token row per model from the stored string columns.
for field, heading in (("question", 'Question: '), ("answer", '\nAnswer: ')):
    print(heading, data_row[field])
    visualize_multiple_token_weights_blocks(
        [
            ast.literal_eval(data_row[f"{field}_raw_ig_tokens_{suffix}"])
            for suffix, _ in ig_models
        ],
        labels=[label for _, label in ig_models]
    )

0.4 LAYER RAW FOCUS PERCENTAGE

Question:  Who was the applicant in the case of Sen v Doors Plus Pty Ltd [2006] NSWIRComm 85?


  cmap = cm.get_cmap(cmap)



Answer:  In the case of Sen v Doors Plus Pty Ltd [2006] NSWIRComm 85, the applicant was Orkun James Sen.


In [None]:
# For the question and then the answer, print the text followed by each
# model's highest-scoring words as computed by top_k_words.
for field, heading in (("question", 'Question: '), ("answer", '\nAnswer: ')):
    print(heading, data_row[field])
    for model_name in ("Llama", "Falcon", "Mistral", "Gemma"):
        column = f"{field}_raw_ig_tokens_{model_name.lower()}"
        print(f'{model_name} most probable tokens: ', top_k_words(data_row[column], data_row[field]))

Question:  Who was the applicant in the case of Sen v Doors Plus Pty Ltd [2006] NSWIRComm 85?
Llama most probable tokens:  [('applicant', 26.120073531585263), ('plus', 10.879095074169316), ('sen', 7.732038423359137), ('pty', 6.064627691369757)]
Falcon most probable tokens:  [('pty', 6.578295202982819), ('v', 4.853455266614696), ('doors', 3.0955045080856345), ('case', 2.6791865284552916)]
Mistral most probable tokens:  [('applicant', 2.2859280316844393), ('sen', 0.5241122307948896), ('doors', 0.30150741735271513), ('v', 0.1812600871215764)]
Gemma most probable tokens:  [('doors', 5.282667465549959), ('plus', 4.5656199015034264), ('pty', 4.365287236925262), ('applicant', 4.265848272099198)]

Answer:  In the case of Sen v Doors Plus Pty Ltd [2006] NSWIRComm 85, the applicant was Orkun James Sen.
Llama most probable tokens:  [('plus', 21.23848561630692), ('sen', 19.63964152555453), ('case', 10.821846229194556), ('doors', 10.821343730696107)]
Falcon most probable tokens:  [('orkun', 8.32146