In [None]:
import sys
import os
import nltk
import json
import nest_asyncio
import matplotlib.pyplot as plt
import numpy as np
import torch
import time        
import warnings

from IPython.display import display
from huggingface_hub import login
from llama_cloud_services import LlamaParse
from dotenv import load_dotenv
from context_cite import ContextCiter
from context_cite.utils import aggregate_logit_probs
from context_cite.context_partitioner import SentencePeriodPartitioner
from transformers import AutoTokenizer, AutoModelForCausalLM
from scipy.stats import spearmanr

In [None]:
# Runtime setup for the notebook: environment variables, event loop, log noise, and the PDF parser.
# Load variables from a local .env file (if present) so the os.getenv call below can see them.
load_dotenv()
# Patch asyncio to allow nested event loops; presumably needed because LlamaParse
# runs async code inside the notebook's already-running loop — TODO confirm.
nest_asyncio.apply()
# NOTE(review): this silences every Python warning for the whole kernel,
# including ones that could point at real problems.
warnings.filterwarnings("ignore")
# Ask transformers to log errors only.
# NOTE(review): transformers is already imported by the time this runs — confirm the setting still takes effect.
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
# One-time setup steps, left disabled (Hugging Face login and NLTK tokenizer data download).
# login(token=os.getenv("HF_TOKEN"))
# nltk.download('punkt_tab')
# Module-level LlamaParse client; input_handler uses it for .pdf inputs.
# os.getenv returns None when LLAMA_CLOUD_API_TOKEN is not set.
parser = LlamaParse(api_key=os.getenv("LLAMA_CLOUD_API_TOKEN"))

In [None]:
# Path of the document to analyse; input_handler accepts paths ending in .pdf or .txt.
document = 'documents/sample.txt'
# Short model label without the organisation prefix.
# NOTE(review): not referenced in the cells visible here — confirm it is still needed.
MODEL_NAME = "Llama-3.2-1B-Instruct"
# Model identifier passed to ContextCiter.from_pretrained.
model_name = "meta-llama/Llama-3.2-1B-Instruct" # 3.2 1B Instruct for faster inference, 3.1 8B for better performance

In [None]:
# Disabled scratch code: inline PDF loading, equivalent to the .pdf branch of input_handler.
# docs = parser.load_data(document)
# data = ""
# for doc in docs:
#     if len(doc.text) >= 32:
#         data += doc.text + " "
# len(data)

In [None]:
def plot(cc: ContextCiter, path: str = None):
    """Scatter predicted against actual log-probabilities for a ContextCiter.

    The x values are ``aggregate_logit_probs(cc._logit_probs)`` and the y
    values are ``cc._actual_logit_probs``, both flattened. A dashed ``y = x``
    reference line is drawn and the Spearman correlation between the two
    series is shown in the title.

    Args:
        cc: The ContextCiter whose ``_logit_probs`` and
            ``_actual_logit_probs`` attributes are plotted.
        path: Where to save the figure. Nothing is saved when this is falsy
            (``None`` or an empty string).

    Raises:
        AssertionError: If the two flattened series differ in length.
    """
    predicted = aggregate_logit_probs(cc._logit_probs).flatten()
    actual = cc._actual_logit_probs.flatten()
    assert len(predicted) == len(actual), f"{len(predicted)} != {len(actual)}"

    # Rank correlation computed on the raw values; the p-value is not used.
    rho, _pvalue = spearmanr(predicted, actual)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(predicted, actual, alpha=0.3, label="Context ablations")

    # Diagonal spanning the combined range of both series.
    lower = min(predicted.min(), actual.min())
    upper = max(predicted.max(), actual.max())
    diagonal = np.linspace(lower, upper, 100)
    ax.plot(diagonal, diagonal, '--', color='gray', label="y = x")

    ax.set_xlabel("Predicted log-probability")
    ax.set_ylabel("Actual log-probability")
    ax.set_title(f"Predicted vs. Actual log-probability\nSpearman correlation: {rho:.2f}")
    ax.legend()
    ax.grid(True)

    if path:
        fig.savefig(path)
    plt.show()

In [None]:
def input_handler(path: str, min_chars: int = 32) -> str:
    """Load the text of a document from disk.

    Args:
        path: Path to a ``.pdf`` or ``.txt`` file. The extension is matched
            case-insensitively, so ``.PDF`` and ``.TXT`` are accepted too.
        min_chars: For PDFs only, parsed documents whose text is shorter than
            this many characters are skipped. Defaults to 32.

    Returns:
        For ``.txt`` files, the file contents decoded as UTF-8.
        For ``.pdf`` files, the text of every kept parsed document, each
        followed by a single space, concatenated in order.

    Raises:
        ValueError: If ``path`` ends in neither ``.pdf`` nor ``.txt``.
    """
    lowered = path.lower()

    if lowered.endswith(".pdf"):
        # `parser` is the module-level LlamaParse client.
        parsed_docs = parser.load_data(path)
        return "".join(
            doc.text + " " for doc in parsed_docs if len(doc.text) >= min_chars
        )

    if lowered.endswith(".txt"):
        # Explicit encoding: without it, decoding depends on the platform's
        # locale and the same file can read differently across machines.
        with open(path, "r", encoding="utf-8") as file:
            return file.read()

    raise ValueError("Invalid file format")

In [None]:
# Build the ContextCiter for the configured document and query.
# Use the GPU when one is available and fall back to CPU otherwise, so the
# notebook still runs (more slowly) on machines without CUDA.
cc = ContextCiter.from_pretrained(
    model_name,
    context=input_handler(document),
    query="What is Transformer?",
    device="cuda" if torch.cuda.is_available() else "cpu",
    num_ablations=64,
)

In [None]:
# Request attributions with as_dataframe=True and top_k=5
# (exact return type is decided by context_cite — TODO confirm).
res = cc.get_attributions(as_dataframe=True, top_k=5)
# Bare expression so the notebook renders the result as the cell output.
res