# In [None]:
import os
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.manifold import TSNE
from datasets import load_dataset

# Random seeds; each one seeds NumPy's global RNG before chunk sampling in the
# main pipeline.
seed_list = [22, 32, 42, 52, 62]

# ============== Plot style + group colors ==============
# Global matplotlib defaults: serif font, larger titles/labels, thicker lines,
# no top/right spines, 150-dpi figures.
mpl.rc("font", family="serif", serif=["DejaVu Serif"], size=16)
mpl.rc("axes", titlesize=18, labelsize=16, linewidth=1.2)
mpl.rc("axes.spines", top=False, right=False)
mpl.rc(("xtick", "ytick"), labelsize=12)
mpl.rc("lines", linewidth=2)
mpl.rc("legend", fontsize=16)
mpl.rc("figure", dpi=150)

# Darker palette, one color per dataset group:
#   Original -> deep blue   (Harry Potter book)
#   items    -> deep green  (synthetic data)
#   Textbook -> deep purple (Textbook_HP)
GROUP_COLORS = dict(
    Original="#0b4a8a",
    items="#1e7a1e",
    Textbook="#5b2b8a",
)

def get_group_color(label):
    """Return the plot color for a dataset label.

    The label is lower-cased and searched for substrings in priority order:
    "forget" -> GROUP_COLORS["Original"], "synthetic" -> GROUP_COLORS["items"],
    "textbook" -> GROUP_COLORS["Textbook"]. The first match wins; a label
    matching none of them gets "gray".
    """
    lowered = label.lower()
    priority = (
        ("forget", "Original"),
        ("synthetic", "items"),
        ("textbook", "Textbook"),
    )
    for needle, group in priority:
        if needle in lowered:
            return GROUP_COLORS[group]
    return "gray"

# ============== Utility functions ==============
def ensure_dir(p):
    """Create directory *p*, including missing parents; an existing directory is fine."""
    os.makedirs(name=p, exist_ok=True)

def read_txt_file(path):
    """Read a UTF-8 text file and return its non-blank lines.

    Each line is stripped of leading/trailing whitespace. Lines that are
    empty after stripping are dropped.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        return [text for text in stripped if text]

def chunk_lines_to_sentences(lines, tokenizer, max_tokens=128):
    """Split text into consecutive windows of ``max_tokens`` tokens.

    Despite the name, no sentence segmentation is done. Each line is
    tokenized with ``tokenizer.tokenize``, and the token stream continues
    across line boundaries. A chunk is emitted every time ``max_tokens``
    tokens have accumulated. Each chunk is converted back to text with
    ``tokenizer.convert_tokens_to_string`` and stripped. Any leftover tokens
    form one final, shorter chunk.

    Args:
        lines: iterable of strings to tokenize.
        tokenizer: object exposing ``tokenize`` and ``convert_tokens_to_string``.
        max_tokens: tokens per chunk.

    Returns:
        List of chunk strings, in input order.
    """
    pieces = []
    buffer = []
    for line in lines:
        for tok in tokenizer.tokenize(line):
            buffer.append(tok)
            if len(buffer) >= max_tokens:
                pieces.append(tokenizer.convert_tokens_to_string(buffer).strip())
                buffer = []
    if buffer:
        # Flush the trailing partial window.
        pieces.append(tokenizer.convert_tokens_to_string(buffer).strip())
    return pieces

# ============== Main pipeline ==============
if __name__ == "__main__":
    import inspect  # local: only used to probe the installed TSNE signature

    dataset_paths = [
        "../Synthetic_data/HP/parameters/synthetic_Harry_Potter_both_itemsperround6_maxround4_mianum1000_jailbreaknum1000.txt",
        "../Scripts/HP/data/books/raw/forget.txt",
    ]
    # Keyed by file basename; the basenames are later searched for
    # "forget" / "synthetic" to identify each set.
    text_files = {os.path.basename(p): p for p in dataset_paths}
    outdir = "Figures/Relevance_Sentence_Centroid"
    ensure_dir(outdir)

    model_name = "meta-llama/Llama-2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print("\n=== Sentence-level chunking (≈128 tokens) ===")
    name2sentences = {}
    for name, path in text_files.items():
        lines = read_txt_file(path)
        chunks = chunk_lines_to_sentences(lines, tokenizer, max_tokens=128)
        name2sentences[name] = chunks
        print(f"{name:>40} | {len(chunks)} chunks")

    # ---- Load Textbook_HP from HuggingFace ----
    print("\n=== Loading Textbook_HP dataset from HuggingFace ===")
    ds = load_dataset("WhyTheMoon/textbook_hp", split="gpt_4o_mini")
    texts = [d["text"] for d in ds]
    textbook_chunks = chunk_lines_to_sentences(texts, tokenizer, max_tokens=128)
    name2sentences["Textbook_HP"] = textbook_chunks
    print(f"{'Textbook_HP':>40} | {len(textbook_chunks)} chunks")

    # ---- Model embeddings ----
    print("\n=== Extracting embeddings from model ===")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 halves GPU memory. On CPU, half-precision kernels may be slow or
    # unsupported depending on the torch version, so use float32 there.
    model_dtype = torch.float16 if device == "cuda" else torch.float32
    model = AutoModel.from_pretrained(
        model_name, torch_dtype=model_dtype, trust_remote_code=True
    ).to(device)
    model.eval()

    def get_embedding(text):
        """Embed *text* as a unit-length vector from the model's final layer.

        Steps:
        1. Tokenize *text*, truncated to 512 tokens.
        2. Mean-pool ``last_hidden_state`` over attention-mask positions.
        3. Standardize the pooled vector across its own dimensions (subtract
           its mean, divide by its std). This is a per-vector normalization,
           not dataset-level whitening.
        4. L2-normalize the result.

        Returns a 1-D float32 NumPy array.
        """
        with torch.no_grad():
            inputs = tokenizer(
                text, return_tensors="pt", truncation=True, max_length=512
            ).to(device)
            # Only the final layer is used, so intermediate hidden states
            # are not requested (saves memory per call).
            outputs = model(**inputs)
            hidden = outputs.last_hidden_state
            mask = inputs.attention_mask.unsqueeze(-1).expand(hidden.size()).float()
            mean = torch.sum(hidden * mask, dim=1) / torch.clamp(mask.sum(1), min=1e-9)
            emb = mean.squeeze(0).float().cpu().numpy()
            emb = (emb - emb.mean()) / (emb.std() + 1e-9)
            return emb / (np.linalg.norm(emb) + 1e-9)

    Domain = "Harry Potter"
    Domain_emb = get_embedding(Domain)

    # The set names do not change between seeds, so resolve them once. A
    # missing set then fails here, before the expensive embedding loop.
    forget_name = [n for n in name2sentences if "forget" in n.lower()][0]
    synthetic_name = [n for n in name2sentences if "synthetic" in n.lower()][-1]
    textbook_name = "Textbook_HP"

    N = 1000      # maximum number of chunks sampled (and embedded) per set
    top_k = 200   # points per set kept for t-SNE: those closest to Domain_emb

    # scikit-learn 1.5 deprecated TSNE's `n_iter` in favour of `max_iter`, and
    # later releases drop `n_iter`. Pick whichever keyword this install accepts.
    tsne_iter_kw = (
        "max_iter" if "max_iter" in inspect.signature(TSNE).parameters else "n_iter"
    )

    name2embs = {}
    # One sampling pass and one figure per seed, covering every entry of seed_list.
    for seed in seed_list:
        print(f"\n=== Seed {seed} ===")
        np.random.seed(seed)
        for name, sents in name2sentences.items():
            if not sents:
                name2embs[name] = np.zeros((0, Domain_emb.shape[0]))
                continue
            sampled_idx = np.random.choice(len(sents), min(N, len(sents)), replace=False)
            sents_sampled = [sents[idx] for idx in sampled_idx]
            embs = [get_embedding(s) for s in sents_sampled]
            name2embs[name] = np.stack(embs)
            print(f"{name:>40} | {len(embs)} embeddings extracted")

        # ---- Centroid analysis + visualization ----
        print("\n=== t-SNE visualization with 3 datasets and centroid distances ===")
        forget_embs = name2embs[forget_name]
        synth_embs = name2embs[synthetic_name]
        textb_embs = name2embs[textbook_name]

        # Keep the top_k points closest (Euclidean) to the domain embedding in
        # each set. All vectors are L2-normalized, so this ordering matches
        # cosine similarity.
        forget_dists = np.linalg.norm(forget_embs - Domain_emb, axis=1)
        synth_dists = np.linalg.norm(synth_embs - Domain_emb, axis=1)
        textb_dists = np.linalg.norm(textb_embs - Domain_emb, axis=1)

        forget_top_embs = forget_embs[np.argsort(forget_dists)[:top_k]]
        synth_top_embs = synth_embs[np.argsort(synth_dists)[:top_k]]
        textb_top_embs = textb_embs[np.argsort(textb_dists)[:top_k]]

        # The domain embedding is appended as the last row so it is projected
        # jointly with the data points.
        all_selected = np.vstack(
            [forget_top_embs, synth_top_embs, textb_top_embs, Domain_emb.reshape(1, -1)]
        )
        tsne = TSNE(
            n_components=2,
            perplexity=30,
            random_state=30,
            init="pca",
            learning_rate="auto",
            **{tsne_iter_kw: 2000},
        )
        emb2d = tsne.fit_transform(all_selected)

        fN, sN, tN = len(forget_top_embs), len(synth_top_embs), len(textb_top_embs)
        forget_2d = emb2d[:fN]
        synth_2d = emb2d[fN:fN + sN]
        textb_2d = emb2d[fN + sN:fN + sN + tN]
        Domain_2d = emb2d[-1]

        # Compute centroids and their distances to the projected domain point
        fcent2d = np.mean(forget_2d, axis=0)
        scent2d = np.mean(synth_2d, axis=0)
        tcent2d = np.mean(textb_2d, axis=0)

        dist_f2t = np.linalg.norm(fcent2d - Domain_2d)
        dist_s2t = np.linalg.norm(scent2d - Domain_2d)
        dist_t2t = np.linalg.norm(tcent2d - Domain_2d)
        print(
            f"Centroid Distances (t-SNE space): HP={dist_f2t:.2f}, "
            f"Synthetic={dist_s2t:.2f}, Textbook={dist_t2t:.2f}"
        )

        # Plot
        plt.figure(figsize=(9, 8))
        plt.scatter(
            forget_2d[:, 0], forget_2d[:, 1],
            c=GROUP_COLORS["Original"],
            alpha=0.7,
            s=85,
            label=f"HP book (top-{top_k})"
        )
        plt.scatter(
            synth_2d[:, 0], synth_2d[:, 1],
            c=GROUP_COLORS["items"],
            alpha=0.7,
            s=85,
            label=f"BiForget_HP (top-{top_k})"
        )
        plt.scatter(
            textb_2d[:, 0], textb_2d[:, 1],
            c=GROUP_COLORS["Textbook"],
            alpha=0.7,
            s=85,
            label=f"Textbook_HP (top-{top_k})"
        )
        plt.scatter(
            Domain_2d[0], Domain_2d[1],
            c="red",
            s=450,
            marker="*",
            label="Domain: Harry Potter"
        )

        # ---------- Centroids are drawn without per-class legend entries ----------
        for cent, color in [
            (fcent2d, GROUP_COLORS["Original"]),
            (scent2d, GROUP_COLORS["items"]),
            (tcent2d, GROUP_COLORS["Textbook"])
        ]:
            # Draw centroid point (no label)
            plt.scatter(
                cent[0], cent[1],
                c=color,
                s=450,
                marker="P",
                edgecolors="black",
                linewidths=1.0
            )

            # Draw connecting line from centroid to the domain point
            plt.plot(
                [cent[0], Domain_2d[0]],
                [cent[1], Domain_2d[1]],
                color=color,
                linestyle="--",
                lw=4,
                alpha=0.8
            )

            # Annotate distance at the midpoint of the connecting line
            midx, midy = (cent[0] + Domain_2d[0]) / 2, (cent[1] + Domain_2d[1]) / 2
            dist = np.linalg.norm(cent - Domain_2d)
            plt.text(
                midx, midy,
                f"{dist:.2f}",
                color=color,
                fontsize=20,
                fontweight="bold",
                ha="center",
                va="center",
                bbox=dict(
                    facecolor="white",
                    alpha=0.7,
                    edgecolor="none",
                    boxstyle="round,pad=0.2"
                )
            )

        # ---------- Add a single unified centroid legend entry ----------
        # Empty scatter: adds a legend handle without drawing any point.
        plt.scatter(
            [], [],
            c="gray",
            s=180,
            marker="P",
            edgecolors="black",
            linewidths=1.0,
            label="Centroid"
        )

        plt.title(
            f"t-SNE of Top-{top_k} Chunks per Set\nCentroid Distances to Domain",
            fontsize=20
        )
        plt.legend(loc="best")
        plt.xlabel("t-SNE Dimension 1")
        plt.ylabel("t-SNE Dimension 2")
        plt.tight_layout()

        savepath = os.path.join(
            outdir,
            f"Embedding_visualization_tsne_centroids_Textbook_top{top_k}_seed{seed}.pdf"
        )
        plt.savefig(savepath, bbox_inches="tight", dpi=300)
        plt.show()
        plt.close()
        print(
            f"Saved centroid-linked t-SNE figure with Textbook_HP: {savepath}"
        )
