In [1]:
print("hello")

hello


In [2]:
# %pip install seaborn tdqm

In [3]:
NUM_SAMPLES_PER_PROMPT = 30

In [4]:
SUBSPACE_DIM = "div_basis_ver2.pt"

In [5]:
# ============================================
# 1. セットアップ：ライブラリインストール & インポート
# ============================================

import os
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt


import pandas as pd
import seaborn as sns
from tqdm import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

Using device: cuda


In [6]:
# ============================================
# 2. モデル・トークナイザのロード
# ============================================

# ※環境に合わせてここを書き換えてください
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"  # 例：Qwen2.5-7B-Instruct

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)
model.eval()

print("Model loaded:", MODEL_NAME)

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

Model loaded: Qwen/Qwen2.5-7B-Instruct


In [7]:
# -----------------------------------------------------------------
# 2. 多様なプロンプトの用意 (カテゴリ分け)
# -----------------------------------------------------------------
# 特定の研究トピックだけでなく、言語モデルが扱う「多様な分布」を網羅します
prompt_categories = {
    "Research": [
        "Explain the concept of attention mechanisms in deep learning.",
        "Propose a method to reduce hallucination in LLMs.",
        "Describe the challenges of reinforcement learning from human feedback.",
    ],
    "Creative": [
        "Write a poem about a lonely satellite orbiting Mars.",
        "Describe a fantasy world where islands float in the sky.",
        "Draft a dialogue between a coffee mug and a tea cup.",
    ],
    "Logic_Math": [
        "Solve this logic puzzle: Three switches are outside a room...",
        "Explain the Pythagorean theorem to a 5-year-old.",
        "Write a step-by-step guide to debugging python code.",
    ],
    "Daily_Life": [
        "Give me a recipe for spicy pasta.",
        "How do I remove a red wine stain from a white shirt?",
        "Suggest an itinerary for a 3-day trip to Tokyo.",
    ],
    "Chat": [
        "Hello, how are you today?",
        "Tell me a joke about programming.",
        "What is your favorite color?",
    ]
}

# リストの平坦化
all_prompts = []
for cat, prompts in prompt_categories.items():
    for p in prompts:
        all_prompts.append({"category": cat, "prompt": p})

print(f"Total base prompts: {len(all_prompts)}")

Total base prompts: 15


In [None]:
# -----------------------------------------------------------------
# 3. 生成 & 全トークンHidden抽出関数 (修正版)
# -----------------------------------------------------------------

@torch.no_grad()
def generate_and_collect_tokens(prompt_data, num_samples=3):
    """
    1つのプロンプトから複数回生成し、
    「生成された全てのトークン」のHidden Stateを収集する
    """
    results = []
    prompt_text = prompt_data["prompt"]
    category = prompt_data["category"]
    
    inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)
    
    for _ in range(num_samples):
        # 生成
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=True,
            temperature=0.9,
            top_p=0.95,
            output_hidden_states=True,
            return_dict_in_generate=True,
            pad_token_id=tokenizer.pad_token_id
        )
        
        # テキストのデコード
        generated_ids = outputs.sequences[0]
        # full_text = tokenizer.decode(generated_ids, skip_special_tokens=True) # 未使用ならコメントアウト可
        new_text = tokenizer.decode(generated_ids[inputs.input_ids.shape[1]:], skip_special_tokens=True)

        token_embeddings = []
        
        # ★★★ 修正箇所: outputs.hidden_states[1:] としてプロンプト部分(index 0)をスキップする ★★★
        # index 0 は (Batch, Prompt_Len, Dim) なので shape が合いません。
        # index 1以降が生成トークン (Batch, 1, Dim) です。
        if len(outputs.hidden_states) > 1:
            for step_data in outputs.hidden_states[1:]:
                # step_data: tuple of layers. Take the last layer [-1]
                # shape: (Batch=1, 1, Hidden_Dim)
                last_layer = step_data[-1].squeeze(0).squeeze(0) # -> (Hidden_Dim)
                token_embeddings.append(last_layer.cpu())
            
        if len(token_embeddings) > 0:
            # (Seq_Len, Hidden_Dim)
            stacked_tokens = torch.stack(token_embeddings, dim=0)
            
            results.append({
                "category": category,
                "prompt": prompt_text,
                "response": new_text,
                "hidden_states": stacked_tokens # Tensor [T, D]
            })
            
    return results

In [9]:
# -----------------------------------------------------------------
# 4. データ収集実行
# -----------------------------------------------------------------
print("Collecting data (this may take a minute)...")
all_data = []
all_hidden_tensors = []

for p_data in tqdm(all_prompts):
    # 各プロンプトについて 5回 サンプリングして多様性を確保
    res = generate_and_collect_tokens(p_data, num_samples=5)
    all_data.extend(res)
    for r in res:
        all_hidden_tensors.append(r["hidden_states"])

# 全トークンを結合: [Total_Tokens, D]
# これがPCAの入力になる
X_all = torch.cat(all_hidden_tensors, dim=0)
print(f"\nCollected Total Tokens: {X_all.shape[0]}")
print(f"Hidden Dimension: {X_all.shape[1]}")

Collecting data (this may take a minute)...


  0%|          | 0/15 [00:01<?, ?it/s]


RuntimeError: stack expects each tensor to be equal size, but got [11, 3584] at entry 0 and [3584] at entry 1

In [None]:
# -----------------------------------------------------------------
# 5. PCAによるサブスペース構築 (Submodel作成)
# -----------------------------------------------------------------
print("\nRunning PCA to build diversity subspace...")

# float32に変換してPCA
X_np = X_all.float().numpy()

# 中心化
mean_vec = np.mean(X_np, axis=0)
X_centered = X_np - mean_vec

# PCA (sklearn)
pca = PCA(n_components=SUBSPACE_DIM)
pca.fit(X_centered)

# 基底ベクトル [k, D]
basis_np = pca.components_
basis_torch = torch.tensor(basis_np, dtype=torch.float32)

print(f"Basis shape: {basis_torch.shape}")
print(f"Explained variance ratio: {pca.explained_variance_ratio_}")
print(f"Total explained variance: {sum(pca.explained_variance_ratio_):.4f}")

# ★ 保存 ★
torch.save(basis_torch, "div_basis.pt")
print(">>> Saved 'div_basis.pt' successfully!")

In [None]:
# -----------------------------------------------------------------
# 6. 分析クラスの定義 (元のNotebookの再現)
# -----------------------------------------------------------------
class DiversitySubspaceModel:
    def __init__(self, basis: torch.Tensor):
        self.basis = basis # [k, D]
        self.k = basis.shape[0]
    
    def project(self, h: torch.Tensor) -> torch.Tensor:
        # h: [N, D] -> z: [N, k]
        if h.dtype != self.basis.dtype:
            h = h.to(self.basis.dtype)
        return h @ self.basis.T

    def token_diversity_score(self, h_seq: torch.Tensor) -> float:
        # 1つのシーケンス(T, D)に対する多様性スコア
        # 今回のRDRL実装に合わせて「原点からの距離(Norm)」の平均を計算してみる
        # または「分散」を見る
        z = self.project(h_seq) # [T, k]
        # PPOコードで使っているロジック: L2 Norm
        norms = torch.norm(z, dim=-1)
        return norms.mean().item()

div_model = DiversitySubspaceModel(basis_torch)

In [None]:
# -----------------------------------------------------------------
# 7. 分析: カテゴリごとの多様性比較
# -----------------------------------------------------------------
# カテゴリごとに、生成されたトークンがサブスペース上でどう分布しているか確認

category_scores = {}

for item in all_data:
    cat = item["category"]
    h = item["hidden_states"]
    
    # サブスペース上のスコア(PPOで報酬になる値)を計算
    score = div_model.token_diversity_score(h)
    
    if cat not in category_scores:
        category_scores[cat] = []
    category_scores[cat].append(score)

print("\n=== Subspace Score by Category (Proxy for Diversity Reward) ===")
rows = []
for cat, scores in category_scores.items():
    mean_score = np.mean(scores)
    std_score = np.std(scores)
    print(f"{cat:12s} | Mean: {mean_score:.4f} | Std: {std_score:.4f}")
    
    for s in scores:
        rows.append({"Category": cat, "Score": s})

df_scores = pd.DataFrame(rows)

In [None]:
# -----------------------------------------------------------------
# 8. テキスト多様性(Jaccard)との相関 (元のNotebookの再現)
# -----------------------------------------------------------------
# ※トークンレベルになったので厳密な比較は難しいですが、
#   カテゴリ単位で「テキストがバラついているか」vs「スコアが高いか」を見ます

def jaccard_similarity(str1, str2):
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))

# カテゴリごとのテキスト多様性(1 - 平均Jaccard)を計算
cat_text_diversity = {}

for cat in prompt_categories.keys():
    # そのカテゴリの生成テキストを集める
    texts = [d["response"] for d in all_data if d["category"] == cat]
    
    if len(texts) < 2:
        cat_text_diversity[cat] = 0
        continue
        
    sims = []
    for i, t1 in enumerate(texts):
        for j, t2 in enumerate(texts):
            if i < j:
                sims.append(jaccard_similarity(t1, t2))
    
    # Diversity = 1 - Similarity
    cat_text_diversity[cat] = 1.0 - np.mean(sims)

print("\n=== Correlation Check ===")
print(f"{'Category':12s} | Text Div (Jaccard) | Subspace Score (Mean)")
for cat in cat_text_diversity.keys():
    text_div = cat_text_diversity[cat]
    sub_score = np.mean(category_scores[cat])
    print(f"{cat:12s} | {text_div:.4f}               | {sub_score:.4f}")

In [None]:
# -----------------------------------------------------------------
# 9. 可視化 (Visualization)
# -----------------------------------------------------------------

# (1) Boxplot of Scores
plt.figure(figsize=(10, 6))
sns.boxplot(data=df_scores, x="Category", y="Score")
plt.title("Distribution of Diversity Scores by Category")
plt.ylabel("Subspace Projection Norm (Reward Signal)")
plt.show()

# (2) Scatter Plot of Tokens (PC1 vs PC2)
# 全トークンからランダムにサンプリングして描画
num_plot_points = 2000
indices = np.random.choice(X_all.shape[0], num_plot_points, replace=False)
X_sample = X_all[indices].float()
# 射影
z_sample = div_model.project(X_sample).numpy()

plt.figure(figsize=(8, 8))
plt.scatter(z_sample[:, 0], z_sample[:, 1], alpha=0.5, s=5, c='blue')
plt.title(f"Token Distribution in Subspace (PC1 vs PC2)\nSampled {num_plot_points} tokens")
plt.xlabel("Principal Component 1")
plt.ylabel("Principal Component 2")
plt.grid(True, alpha=0.3)
plt.show()

# (3) Explained Variance
plt.figure(figsize=(8, 4))
plt.bar(range(1, SUBSPACE_DIM + 1), pca.explained_variance_ratio_)
plt.xlabel("Principal Component Index")
plt.ylabel("Explained Variance Ratio")
plt.title("Importance of Each Dimension in Subspace")
plt.show()