In [None]:
!pip install torch
!pip install transformers
!pip install --upgrade datasets
!pip install tqdm

import json
import math
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import HfApi, Repository, create_repo, login
from datasets import Dataset, DatasetDict, load_dataset
from tqdm import tqdm

# Do not hard-code the Hugging Face access token in the notebook: read it from
# the HF_TOKEN environment variable. When the variable is unset (or empty),
# token=None lets huggingface_hub prompt for the token interactively instead.
HF_L = os.environ.get("HF_TOKEN") or None
login(token=HF_L)
# Which checkpoint family the loading cell should use:
# one of 'llama', 'gemma', 'falcon' or 'mistral'.
model_type = 'falcon'

In [None]:
# Checkpoint id and the extra `from_pretrained` keyword arguments for each
# supported model family. Every model is loaded in float32 and moved to CUDA.
_MODEL_SPECS = {
    'llama': ("meta-llama/Llama-2-7b-hf", {"attn_implementation": "eager"}),
    'gemma': ("google/gemma-7b-it", {"attn_implementation": "eager"}),
    'falcon': ("tiiuae/falcon-7b-instruct", {"trust_remote_code": True}),
    'mistral': ("mistralai/Mistral-7B-Instruct-v0.2", {"attn_implementation": "eager"}),
}

if model_type not in _MODEL_SPECS:
    raise ValueError(f"Unknown model type: {model_type}")

model_name, _extra_load_kwargs = _MODEL_SPECS[model_type]
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    **_extra_load_kwargs
).to("cuda")

# Forward passes return attention maps but not per-layer hidden states by default.
model.config.output_attentions = True
model.config.output_hidden_states = False

In [None]:
# Download the aggregated-score dataset and turn its 'rawcases' split into a
# pandas DataFrame.
dataset = load_dataset("Ramitha/unique-records-aggregated-40-scores")
df = pd.DataFrame(dataset['rawcases'])

In [None]:
# For every source dataset keep `selection_limit` rows with the lowest, the
# middle and the highest "ILRAlign" values.
selected_rows = []
selection_limit = 6

for dataset_name, group in df.groupby("dataset"):
    group_sorted = group.sort_values("ILRAlign").reset_index(drop=True)
    # Lowest
    lowest = group_sorted.head(selection_limit)
    # Middle: a window centred on the median. The offset is derived from
    # selection_limit (instead of a hard-coded 3) and clamped at 0 so that a
    # small group never produces a negative start, which iloc would interpret
    # as "count from the end".
    mid_start = max(0, len(group_sorted) // 2 - selection_limit // 2)
    middle = group_sorted.iloc[mid_start:mid_start + selection_limit]
    # Highest
    highest = group_sorted.tail(selection_limit)
    selected = pd.concat([lowest, middle, highest])
    # In groups with fewer than 3 * selection_limit rows the three windows
    # overlap. After reset_index the index identifies a row within the group,
    # so dropping repeated index labels removes the duplicated records.
    selected = selected[~selected.index.duplicated(keep="first")]
    selected_rows.append(selected)

# pd.concat raises on an empty list; fall back to an empty frame with the
# same columns when there were no groups at all.
final_selection = (
    pd.concat(selected_rows, ignore_index=True) if selected_rows else df.iloc[0:0].copy()
)

In [None]:
# Default value of the `steps` argument of the integrated-gradients helpers in
# this cell; they evaluate steps + 1 interpolation points between the baseline
# and the input embeddings.
DEFAULT_STEPS = 10

def get_num_layers(model):
    """Return how many layers `model` has.

    The layer container is looked up, in this order, at `model.model.layers`,
    `model.transformer.h` and `model.layers`.

    Raises:
        RuntimeError: if none of these attributes exists.
    """
    nested_locations = (("model", "layers"), ("transformer", "h"))
    for outer_name, inner_name in nested_locations:
        if not hasattr(model, outer_name):
            continue
        holder = getattr(model, outer_name)
        if hasattr(holder, inner_name):
            return len(getattr(holder, inner_name))
    if hasattr(model, "layers"):
        return len(model.layers)
    raise RuntimeError("Unable to determine number of layers")

def get_layer_module(model, layer_idx):
    """Return the layer at position `layer_idx` of `model`.

    The layer container is looked up, in this order, at `model.model.layers`,
    `model.transformer.h` and `model.layers`, and is then indexed with
    `layer_idx`.

    Raises:
        RuntimeError: if none of these attributes exists.
    """
    for outer_name, inner_name in (("model", "layers"), ("transformer", "h")):
        if hasattr(model, outer_name):
            holder = getattr(model, outer_name)
            if hasattr(holder, inner_name):
                return getattr(holder, inner_name)[layer_idx]
    if not hasattr(model, "layers"):
        raise RuntimeError("Unable to access layers")
    return model.layers[layer_idx]

def capture_layer_output(model, layer_idx, input_embeds):
    """Run one forward pass and return what layer `layer_idx` produced.

    A forward hook is registered on the layer returned by
    `get_layer_module(model, layer_idx)`, the model is called with
    `inputs_embeds=input_embeds`, and the hook is removed again.

    Args:
        model: callable model accepting `inputs_embeds`, `output_hidden_states`
            and `return_dict` keyword arguments.
        layer_idx: index of the layer whose output is captured.
        input_embeds: embeddings fed to the model instead of token ids.

    Returns:
        The layer's output; if the layer returns a tuple, its first element.

    Raises:
        RuntimeError: if the forward pass finished without the hooked layer
            being executed.
    """
    hidden_acts = {}

    def hook_fn(module, inp, out):
        # Some layers return a tuple; only its first element is kept.
        hidden_acts["out"] = out[0] if isinstance(out, tuple) else out

    handle = get_layer_module(model, layer_idx).register_forward_hook(hook_fn)
    try:
        model(inputs_embeds=input_embeds, output_hidden_states=False, return_dict=True)
    finally:
        # Always detach the hook, even when the forward pass raises (e.g. on
        # out-of-memory); a leaked hook would stay active for all later calls.
        handle.remove()
    if "out" not in hidden_acts:
        raise RuntimeError(f"Layer {layer_idx} was not executed during the forward pass")
    return hidden_acts["out"]

def integrated_gradients_batched(prompt, layer_idx, tokenizer, model,
                                 steps=DEFAULT_STEPS, baseline="zero", max_length=128):
    """Approximate integrated-gradients token scores for one layer's output.

    The input embeddings are interpolated between a baseline and the real
    embeddings at `steps + 1` evenly spaced points, and all points are pushed
    through the model as a single batch. The scalar being differentiated is
    the sum of the L2 norms (over the last dimension) of the output of layer
    `layer_idx`. Its gradient with respect to the interpolated embeddings is
    averaged over the interpolation points, multiplied by
    (input - baseline), and averaged over the embedding dimension.

    Args:
        prompt: text to attribute.
        layer_idx: index of the layer whose output norm is attributed.
        tokenizer: tokenizer called with `return_tensors="pt"`.
        model: model exposing `get_input_embeddings()`; switched to eval mode.
        steps: number of interpolation intervals.
        baseline: "zero" for an all-zero baseline; any other value uses the
            embedding of the pad token (or the EOS token when no pad token
            is defined).
        max_length: truncation length passed to the tokenizer.

    Returns:
        (tokens, scores): the token strings of the prompt and a 1-D CPU tensor
        holding one score per token.
    """
    device = next(model.parameters()).device
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    input_ids = inputs["input_ids"]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    embedding_layer = model.get_input_embeddings()

    # The embeddings and the interpolation path are constants of the
    # attribution, so build them without recording an autograd graph.
    with torch.no_grad():
        input_embeds = embedding_layer(input_ids)
        if baseline == "zero":
            baseline_embeds = torch.zeros_like(input_embeds)
        else:
            # Compare against None explicitly: 0 is a valid pad token id and
            # must not be replaced by the EOS id as `pad or eos` would do.
            pad_id = tokenizer.pad_token_id
            if pad_id is None:
                pad_id = tokenizer.eos_token_id
            pad_ids = torch.full_like(input_ids, pad_id)
            baseline_embeds = embedding_layer(pad_ids)
        delta = input_embeds - baseline_embeds
        alphas = torch.linspace(0, 1, steps + 1, device=device).view(-1, 1, 1, 1)
        interpolated = baseline_embeds + alphas * delta

    s_steps, s_batch, s_seq, s_hidden = interpolated.shape
    # Fold the interpolation axis into the batch axis so a single forward
    # pass evaluates every point on the path.
    big_batch = (
        interpolated.reshape(s_steps * s_batch, s_seq, s_hidden)
        .clone().detach().requires_grad_(True)
    )
    # Capture layer output
    target_hidden = capture_layer_output(model, layer_idx, big_batch)
    scalar_score = target_hidden.norm(dim=-1).sum()
    # Release any parameter gradients left over from earlier backward passes.
    model.zero_grad(set_to_none=True)
    # Only the gradient w.r.t. the interpolated inputs is needed.
    # torch.autograd.grad returns it directly without populating `.grad` on
    # every model parameter, which `.backward()` would do.
    (batch_grads,) = torch.autograd.grad(scalar_score, big_batch)
    grads = batch_grads.view(s_steps, s_batch, s_seq, s_hidden).mean(dim=0)
    integrated_grads = delta * grads
    token_scores = integrated_grads.mean(dim=-1).detach().cpu()
    return tokens, token_scores[0]

def get_token_speciality_percentages(text, tokenizer, model, target_layer_ratio=0.4,
                                     steps=DEFAULT_STEPS, baseline="zero", eps=1e-8):
    """Share of each token's attribution that falls on the target layer.

    Integrated-gradients scores are computed for every layer of `model`.
    Scores of identical token strings are averaged within each layer. For each
    token the (signed) average at the target layer is divided by the sum of
    the absolute per-layer averages (plus `eps`) and expressed in percent.

    The target layer is `int((num_layers - 1) * target_layer_ratio)`.

    Returns:
        A list of `(token, percentage)` pairs in the order the tokens appear
        in the tokenized text.
    """
    layer_count = get_num_layers(model)
    target_layer = int((layer_count - 1) * target_layer_ratio)

    score_sums = {}   # token -> {layer -> summed score}
    occurrences = {}  # token -> {layer -> number of scores summed}
    ordered_tokens = None
    for layer in range(layer_count):
        layer_tokens, layer_scores = integrated_gradients_batched(
            text, layer, tokenizer, model, steps=steps, baseline=baseline)
        if ordered_tokens is None:
            # Token order is taken from the first layer's run.
            ordered_tokens = layer_tokens
        for token, value in zip(layer_tokens, layer_scores):
            sums_for_token = score_sums.setdefault(token, {})
            hits_for_token = occurrences.setdefault(token, {})
            sums_for_token[layer] = sums_for_token.get(layer, 0.0) + float(value.item())
            hits_for_token[layer] = hits_for_token.get(layer, 0) + 1

    percentages = {}
    for token, sums_for_token in score_sums.items():
        layer_means = {
            layer: summed / max(1, occurrences[token][layer])
            for layer, summed in sums_for_token.items()
        }
        magnitude = sum(abs(mean) for mean in layer_means.values())
        at_target = layer_means.get(target_layer, 0.0)
        percentages[token] = at_target / (magnitude + eps) * 100

    return [(token, float(percentages.get(token, 0.0))) for token in ordered_tokens]

def integrated_gradients_layer_percentages(prompt, layer_idx, tokenizer, model,
                                           steps=DEFAULT_STEPS, baseline="zero", 
                                           max_length=128, eps=1e-8):
    """Per-token share of the absolute integrated-gradients score at one layer.

    The raw per-token scores come from `integrated_gradients_batched`, which
    this function previously duplicated line for line. Delegating keeps the
    two computations from drifting apart. Each token's absolute score is then
    divided by the sum of all absolute scores (plus `eps`) and expressed in
    percent.

    Args:
        prompt: text to attribute.
        layer_idx: index of the layer whose output is attributed.
        tokenizer, model, steps, baseline, max_length: forwarded unchanged to
            `integrated_gradients_batched`.
        eps: small constant added to the normaliser to avoid division by zero.

    Returns:
        A list of `(token, percentage)` pairs in token order.
    """
    tokens, token_scores = integrated_gradients_batched(
        prompt, layer_idx, tokenizer, model,
        steps=steps, baseline=baseline, max_length=max_length)
    abs_scores = token_scores.abs()
    total = abs_scores.sum().item() + eps
    percentages = [
        (tok, (score / total) * 100)
        for tok, score in zip(tokens, abs_scores.tolist())
    ]

    return percentages

def get_token_layer_percentages(text, tokenizer, model, target_layer_ratio=0.4,
                                steps=DEFAULT_STEPS, baseline="zero", eps=1e-8):
    """Per-token attribution percentages at the layer chosen by a depth ratio.

    The layer index is `int((num_layers - 1) * target_layer_ratio)`; the
    percentages themselves are computed by
    `integrated_gradients_layer_percentages`.
    """
    last_layer_idx = get_num_layers(model) - 1
    chosen_layer = int(last_layer_idx * target_layer_ratio)
    return integrated_gradients_layer_percentages(
        text, chosen_layer, tokenizer, model,
        steps=steps, baseline=baseline, eps=eps,
    )

In [None]:
HF_DATASET = "Ramitha/unique-records-selected-integrated-gradients"
dataset = load_dataset(HF_DATASET)
df = pd.DataFrame(dataset['rawcases'])

# Input text columns and the model-specific output columns.
q_col, a_col = "question", "answer"
q_col_out = f"question_ig_tokens_{model_type}"
a_col_out = f"answer_ig_tokens_{model_type}"
q_raw_col_out = f"question_raw_ig_tokens_{model_type}"
a_raw_col_out = f"answer_raw_ig_tokens_{model_type}"
ig_output_columns = (q_col_out, a_col_out, q_raw_col_out, a_raw_col_out)

# Create the output columns with object dtype so each cell can hold a list.
for out_col in ig_output_columns:
    df[out_col] = pd.Series([None] * len(df), dtype="object")

for idx, row in tqdm(df.iterrows(), total=len(df), desc="Computing IG scores"):
    # All four results are computed before any of them is written back.
    q_scores = get_token_speciality_percentages(row[q_col], tokenizer, model)
    a_scores = get_token_speciality_percentages(row[a_col], tokenizer, model)
    q_raw_scores = get_token_layer_percentages(row[q_col], tokenizer, model)
    a_raw_scores = get_token_layer_percentages(row[a_col], tokenizer, model)
    row_results = (q_scores, a_scores, q_raw_scores, a_raw_scores)
    for out_col, scores in zip(ig_output_columns, row_results):
        df.at[idx, out_col] = scores

# Store every result as a JSON string; rows without a result become "[]".
for out_col in ig_output_columns:
    df[out_col] = df[out_col].apply(
        lambda scores: "[]" if scores is None else json.dumps(scores)
    )

In [None]:
# Wrap the DataFrame (now including the JSON-encoded IG columns) as the
# 'rawcases' split and upload it to the same Hub repository it was loaded from.
hf_dataset = DatasetDict({
    'rawcases': Dataset.from_pandas(df)
})
hf_dataset.push_to_hub(HF_DATASET)