# Importing libraries

In [1]:
import os
import sys
from dotenv import load_dotenv
from typing import Tuple
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torch.profiler import profile, record_function, ProfilerActivity
import wandb
from src.utils import set_seed, load_text, split_text
from src.config import ModelConfig, TrainConfig, GenerationConfig
from src.train import Trainer
from tokenizer.tokenizer import CharTokenizer
from models.GPT import GPT

In [2]:
PROJECT_ROOT = os.path.abspath(os.getcwd() + "/..")
sys.path.append(PROJECT_ROOT)
print(f"PROJECT_ROOT: {PROJECT_ROOT}")

PROJECT_ROOT: /home/pathfinder/projects/PathFinder


# Configuration

In [3]:
model_config = ModelConfig(
    vocab_size=-1,
    max_seq_len=128,
    flash=True,
    d_embed=256,
    n_layers=4,
    n_heads=4,
    d_head=64,
    rank=16,
    d_ff=1024,
    cla=False
)

train_config = TrainConfig(
    debug=False,
    wandb_project="nanoGPT",
    per_device_train_batch_size=512,
    per_device_eval_batch_size=1024,
    gradient_accumulation_steps=512 // 512,
    num_train_epochs=1,
    learning_rate=5e-4,
    weight_decay=0.01,
    attn_decay=0.5,
    eval_steps=100,
    mixed_precision=True,
    matmul_precision="high",
)

generation_config = GenerationConfig(
    use_cache=True,
    max_new_tokens=1000,
    temperature=1.0,
    top_k=50
)

In [4]:
load_dotenv()
wandb.login(key=os.environ.get("WANDB_API_KEY"))

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/pathfinder/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mpathfinderkr[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Utils

## Reproducibility

In [5]:
set_seed(train_config.seed)

Random seed set to 42


## Device

In [6]:
device = torch.device("cuda")
print(f"Device: {torch.cuda.get_device_name(device)}")
torch.set_float32_matmul_precision(train_config.matmul_precision)  # Tensor Cores
print(f"MatMul Precision: {train_config.matmul_precision}")

Device: NVIDIA GeForce RTX 4080 SUPER
MatMul Precision: high


# Dataset

In [7]:
dataset_path = os.path.join(PROJECT_ROOT, "datasets/Shakespeare/shakespeare.txt")
shakespeare_text = load_text(dataset_path)

Loaded text data from /home/pathfinder/projects/PathFinder/datasets/Shakespeare/shakespeare.txt (length: 1115394 characters).


In [8]:
if train_config.debug:
    subset_shakespeare_text = shakespeare_text[:10000]
    print(subset_shakespeare_text)
    shakespeare_text = subset_shakespeare_text

# Tokenizer

In [9]:
char_tokenizer = CharTokenizer()
char_tokenizer.build_vocab(text=shakespeare_text)
vocab_path = os.path.join(PROJECT_ROOT, "datasets/Shakespeare/vocab.json")
char_tokenizer.save_vocab(vocab_path)
model_config.vocab_size = char_tokenizer.vocab_size

Vocabulary size: 69
Vocabulary saved to /home/pathfinder/projects/PathFinder/datasets/Shakespeare/vocab.json.


In [10]:
if train_config.debug:
    print("Vocabulary:", char_tokenizer.char2idx)

# Preprocessing

In [11]:
train_text, val_text = split_text(shakespeare_text, val_size=0.1)
print(f"Training text length: {len(train_text)} characters")
print(f"Validation text length: {len(val_text)} characters")

Training text length: 1003854 characters
Validation text length: 111540 characters


In [12]:
class TextDataset(Dataset):
    def __init__(self, text: str, tokenizer: CharTokenizer, max_seq_len: int):
        self.encoded = tokenizer.encode(text)
        self.max_seq_len = max_seq_len

    def __len__(self) -> int:
        return len(self.encoded) - self.max_seq_len

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        input_ids = self.encoded[idx:idx + self.max_seq_len]
        target_ids = self.encoded[idx + 1:idx + self.max_seq_len + 1]
        return input_ids, target_ids

def collate_fn(batch):
    input_ids = torch.stack([item[0] for item in batch])
    target_ids = torch.stack([item[1] for item in batch])
    return {
        "input_ids": input_ids,
        #"attention_mask": attention_mask,
        "target_ids": target_ids
    }

train_dataset = TextDataset(train_text, char_tokenizer, model_config.max_seq_len)
val_dataset = TextDataset(val_text, char_tokenizer, model_config.max_seq_len)

train_loader = DataLoader(
    train_dataset,
    collate_fn=collate_fn,
    batch_size=train_config.per_device_eval_batch_size,
    shuffle=True,
    num_workers=4
)
val_loader = DataLoader(
    val_dataset,
    collate_fn=collate_fn,
    batch_size=train_config.per_device_eval_batch_size,
    shuffle=False,
    num_workers=4
)

In [13]:
if train_config.debug:
    sample_batch = next(iter(train_loader))
    print(f"Sample input IDs: {sample_batch['input_ids'][0]}")
    print(f"Sample target IDs: {sample_batch['target_ids'][0]}")

# Model

In [14]:
# Initialize the model
model = GPT(model_config).to(device)
model = torch.compile(model)
print(model)
print(f"Number of parameters: {model.get_num_params() / 1e6:.2f}M")

OptimizedModule(
  (_orig_mod): GPT(
    (token_embedding): Embedding(69, 256)
    (positional_encoding): Embedding(128, 256)
    (dropout): Dropout(p=0.01, inplace=False)
    (blocks): ModuleList(
      (0-3): 4 x Block(
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadAttention(
          (Wq): Linear(in_features=256, out_features=256, bias=False)
          (Wkv_down): Linear(in_features=256, out_features=16, bias=False)
          (Wk_up): Linear(in_features=16, out_features=256, bias=False)
          (Wv_up): Linear(in_features=16, out_features=256, bias=False)
          (out_proj): Linear(in_features=256, out_features=256, bias=False)
          (dropout): Dropout(p=0.01, inplace=False)
        )
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): FeedForward(
          (fc1): Linear(in_features=256, out_features=1024, bias=False)
          (fc2): Linear(in_features=1024, out_features=256, bias=False)
 

# Training

In [15]:
trainer = Trainer(
    model=model,
    train_config=train_config,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    master_process=True
)
trainer.train()

Training: 100%|██████████| 981/981 [02:15<00:00,  7.23it/s, epoch=1, grad_norm=0.5658, loss=2.0813, lr=0.000000]


0,1
Grad Norm,▄▃▂▂▂▁▁▁▁▂▂▁▂▁▂▄█▂▃▁▂▂▂▃▂▂▂▃▂▁▂▂▃▂▁▁▁▁▁▂
Learning Rate,▅▆▆▆█████▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁
Train Loss,██▇▇▄▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val Loss,█▆▅▄▂▂▁▁▁▁
Val Perplexity,█▆▅▃▂▂▁▁▁▁

0,1
Grad Norm,0.56576
Learning Rate,0.0
Train Loss,2.08134
Val Loss,2.13713
Val Perplexity,8.47509


## Save the model

In [16]:
if not train_config.debug:
    pass
    #output_dir = os.path.join(PROJECT_ROOT, "checkpoints", train_config.model_name, train_config.run_name)
    #os.makedirs(output_dir, exist_ok=True)
    #try:
    #    model.save_pretrained(
    #        output_dir,
    #        safe_serialization=True
    #    )
    #    print("Model saved successfully")
    #xcept Exception as e:
    #    print(f"Error saving model: {e}")
    # Push to Hugging Face Hub
    #model.push_to_hub(
    #    repo_id=f"PathFinderKR/{train_config.model_name}-{train_config.run_name}",
    #    private=True,
    #    use_auth_token=os.environ.get("HUGGINGFACE_TOKEN")
    #)
    #print(f"Model pushed to Hugging Face Hub: PathFinderKR/{train_config.model_name}-{train_config.run_name}")

In [17]:
# To load the model later, you can use:
# model = GPT.from_pretrained(output_dir).to(device)

# Inference

In [18]:
user_prompt = "To be, or not to be, that is the question"
input_ids = char_tokenizer.encode(user_prompt).unsqueeze(0).to(device)
output = model.generate(
    input_ids,
    use_cache=True,
    max_new_tokens=generation_config.max_new_tokens,
    temperature=generation_config.temperature,
    top_k=generation_config.top_k,
    tokenizer=char_tokenizer
)
response = char_tokenizer.decode(output[0].squeeze().cpu().numpy())

g utrngs.
Thasth high dare Seve's hear weis wellds,
Andve shous mele fat wor eal thingm
Resetting KV cache
Mame he you thour 'te bottlle! alle bay,
Ived Ran'sw I Happromes femartod be facambts
whend our levesens whre wilour worre the tResetting KV cache
, mort of aslt, at iny onges peor esm,
Id im twither chadere oul gows oulys hour
Whutind! jut bon pureding dvighm, theef.

BOMINResetting KV cache
lce, I'l way,
Thiper, fit balim kidrp, tay theiced ive thas tooty
He prauke tomey hean.

Sey:
Deer; the sonr thy alde my.

BEDY:Resetting KV cache
 Rit to, ktoy, a lofe wigh'd hour reaiag
Warsen stiser thates ime; hef heer, tity nidd ard halll dak
Clits youses my co deor hicResetting KV cache
ovick,
Thou sher the 'vis, meavers tith nouimes
I ave con tim. Whan boty you and thaks lorch.

KING VIMARK:
Mis hener'd, det balResetting KV cache
et the oowelg nise rutld
Lecart seamessog whas all hes budol nit,
O, es wourd you tit ep; balllk I'suckes, al sthe?

GLOUMES:
DuResetting KV cache
f mory.

DU

In [19]:
print("=" * 50)
print("User prompt: ")
print(user_prompt)
print("-" * 50)
print("🤖 Model Response:")
print(response)

User prompt: 
To be, or not to be, that is the question
--------------------------------------------------
🤖 Model Response:
To be, or not to be, that is the questiong utrngs.
Thasth high dare Seve's hear weis wellds,
Andve shous mele fat wor eal thingm
Mame he you thour 'te bottlle! alle bay,
Ived Ran'sw I Happromes femartod be facambts
whend our levesens whre wilour worre the t, mort of aslt, at iny onges peor esm,
Id im twither chadere oul gows oulys hour
Whutind! jut bon pureding dvighm, theef.

BOMINlce, I'l way,
Thiper, fit balim kidrp, tay theiced ive thas tooty
He prauke tomey hean.

Sey:
Deer; the sonr thy alde my.

BEDY: Rit to, ktoy, a lofe wigh'd hour reaiag
Warsen stiser thates ime; hef heer, tity nidd ard halll dak
Clits youses my co deor hicovick,
Thou sher the 'vis, meavers tith nouimes
I ave con tim. Whan boty you and thaks lorch.

KING VIMARK:
Mis hener'd, det balet the oowelg nise rutld
Lecart seamessog whas all hes budol nit,
O, es wourd you tit ep; balllk I'suckes,

In [20]:
asdf

NameError: name 'asdf' is not defined

# Speedometer

In [None]:
if train_config.debug:
    speedometer(
        model=model,
        input_ids=char_tokenizer.encode("a").unsqueeze(0).to(device),
        use_cache=False,
        warmup_tokens=100,
        timing_tokens=100,
        num_runs=5
    )

# Profiling

In [None]:
if train_config.debug:
    input_ids = torch.randint(0, model_config.vocab_size, (1, model_config.max_seq_len), device=device)
    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("model_inference"):
            model(input_ids)
    print(prof.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=20))

# Attention Scores

In [None]:
def predict_max_distribution(self, layer_idx=0, head_idx=0, seq_len=512, num_samples=1000, plot=True):
    """
    Predict and visualize the distribution of maximum attention scores

    Args:
        layer_idx: Layer to analyze
        head_idx: Head to analyze
        seq_len: Sequence length
        num_samples: Number of samples for distribution
        plot: Whether to plot the histogram

    Returns:
        Dictionary with max values and statistics
    """
    import numpy as np
    import torch
    import matplotlib.pyplot as plt

    device = next(self.parameters()).device
    config = self.config
    scale = config.scale if config.scale is not None else config.d_head ** -0.5

    if config.rank is not None:
        print("Multi Head Latent Attention not supported yet")
        return None

    # Get attention weights
    attn_layer = self.blocks[layer_idx].attn
    qkv_weight = attn_layer.qkv_proj.weight
    wq_full = qkv_weight[:config.d_embed, :]
    wk_full = qkv_weight[config.d_embed:2*config.d_embed, :]

    # Extract head-specific weights
    head_start = head_idx * config.d_head
    head_end = (head_idx + 1) * config.d_head
    wq_head = wq_full[head_start:head_end, :]  # [d_head, d_embed]
    wk_head = wk_full[head_start:head_end, :]  # [d_head, d_embed]

    print(f"Generating distribution for Layer {layer_idx}, Head {head_idx}")
    print(f"Sequence length: {seq_len}, Samples: {num_samples}")

    max_values = []

    with torch.no_grad():
        for i in range(num_samples):
            if i % 100 == 0:
                print(f"Progress: {i}/{num_samples}")

            # Sample random input (LayerNorm output assumption)
            x = torch.randn(seq_len, config.d_embed, device=device)

            # Compute Q, K
            q = x @ wq_head.T  # [seq_len, d_head]
            k = x @ wk_head.T  # [seq_len, d_head]

            # Compute attention scores
            attn_scores = q @ k.T * scale  # [seq_len, seq_len]

            # Apply causal mask
            causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
            attn_masked = attn_scores.masked_fill(causal_mask == 0, float('-inf'))

            # Get maximum value
            max_val = attn_masked.max().item()
            if max_val != float('-inf'):
                max_values.append(max_val)

    max_values = np.array(max_values)

    # Calculate statistics
    stats = {
        'max_values': max_values,
        'mean': np.mean(max_values),
        'std': np.std(max_values),
        'median': np.median(max_values),
        'min': np.min(max_values),
        'max': np.max(max_values),
        'percentiles': {
            '25': np.percentile(max_values, 25),
            '75': np.percentile(max_values, 75),
            '90': np.percentile(max_values, 90),
            '95': np.percentile(max_values, 95),
            '99': np.percentile(max_values, 99)
        },
        'layer_idx': layer_idx,
        'head_idx': head_idx,
        'seq_len': seq_len,
        'scale': scale
    }

    # Theoretical comparison
    wqk_product = wq_head @ wk_head.T
    qk_variance = torch.trace(wqk_product).item()
    scaled_std = np.sqrt(qk_variance) * scale
    theoretical_max = scaled_std * np.sqrt(2 * np.log(seq_len))

    stats['theoretical'] = {
        'qk_std': np.sqrt(qk_variance),
        'scaled_std': scaled_std,
        'predicted_max': theoretical_max
    }

    # Print statistics
    print(f"\n=== DISTRIBUTION STATISTICS ===")
    print(f"Sample size: {len(max_values)}")
    print(f"Mean: {stats['mean']:.4f}")
    print(f"Std: {stats['std']:.4f}")
    print(f"Median: {stats['median']:.4f}")
    print(f"Range: [{stats['min']:.4f}, {stats['max']:.4f}]")
    print(f"90th percentile: {stats['percentiles']['90']:.4f}")
    print(f"95th percentile: {stats['percentiles']['95']:.4f}")
    print(f"99th percentile: {stats['percentiles']['99']:.4f}")
    print(f"Theoretical prediction: {theoretical_max:.4f}")
    print(f"Empirical vs Theoretical: {stats['mean']/theoretical_max:.4f}")

    if plot:
        plot_max_distribution(stats)

    return stats

def plot_max_distribution(stats):
    """
    Plot histogram of maximum attention scores
    """
    import matplotlib.pyplot as plt

    max_values = stats['max_values']
    theoretical_max = stats['theoretical']['predicted_max']

    plt.figure(figsize=(12, 8))

    # Main histogram
    plt.subplot(2, 2, 1)
    n, bins, patches = plt.hist(max_values, bins=50, density=True, alpha=0.7,
                               color='lightcoral', edgecolor='black', linewidth=0.5)

    # Add vertical lines for key statistics
    plt.axvline(stats['mean'], color='red', linestyle='--', linewidth=2,
                label=f'Mean: {stats["mean"]:.3f}')
    plt.axvline(stats['median'], color='blue', linestyle='--', linewidth=2,
                label=f'Median: {stats["median"]:.3f}')
    plt.axvline(theoretical_max, color='green', linestyle='--', linewidth=2,
                label=f'Theoretical: {theoretical_max:.3f}')
    plt.axvline(stats['percentiles']['95'], color='orange', linestyle='--', linewidth=2,
                label=f'95th %ile: {stats["percentiles"]["95"]:.3f}')

    plt.xlabel('Maximum Attention Score (QK^T/√D)')
    plt.ylabel('Frequency Density')
    plt.title(f'Distribution of Max Attention Scores\n'
              f'Layer {stats["layer_idx"]}, Head {stats["head_idx"]}, T={stats["seq_len"]}')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Cumulative distribution
    plt.subplot(2, 2, 2)
    sorted_values = np.sort(max_values)
    cumulative = np.arange(1, len(sorted_values) + 1) / len(sorted_values)
    plt.plot(sorted_values, cumulative, 'b-', linewidth=2)
    plt.axvline(stats['mean'], color='red', linestyle='--', alpha=0.7)
    plt.axvline(theoretical_max, color='green', linestyle='--', alpha=0.7)
    plt.xlabel('Maximum Attention Score')
    plt.ylabel('Cumulative Probability')
    plt.title('Cumulative Distribution Function')
    plt.grid(True, alpha=0.3)

    # Box plot
    plt.subplot(2, 2, 3)
    box_data = [max_values]
    bp = plt.boxplot(box_data, patch_artist=True)
    bp['boxes'][0].set_facecolor('lightblue')
    plt.ylabel('Maximum Attention Score')
    plt.title('Box Plot')
    plt.grid(True, alpha=0.3)

    # Q-Q plot vs normal distribution
    plt.subplot(2, 2, 4)
    from scipy import stats as scipy_stats
    scipy_stats.probplot(max_values, dist="norm", plot=plt)
    plt.title('Q-Q Plot vs Normal Distribution')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Separate figure for detailed histogram
    plt.figure(figsize=(10, 6))

    # More detailed histogram
    plt.hist(max_values, bins=75, density=True, alpha=0.8,
             color='lightcoral', edgecolor='black', linewidth=0.3)

    # Add statistics lines
    plt.axvline(stats['mean'], color='red', linestyle='-', linewidth=3,
                label=f'Mean: {stats["mean"]:.4f}')
    plt.axvline(stats['median'], color='blue', linestyle='-', linewidth=2,
                label=f'Median: {stats["median"]:.4f}')
    plt.axvline(theoretical_max, color='green', linestyle='--', linewidth=2,
                label=f'Theoretical: {theoretical_max:.4f}')

    # Add percentile lines
    plt.axvline(stats['percentiles']['90'], color='orange', linestyle=':', linewidth=2,
                label=f'90th %ile: {stats["percentiles"]["90"]:.4f}')
    plt.axvline(stats['percentiles']['95'], color='purple', linestyle=':', linewidth=2,
                label=f'95th %ile: {stats["percentiles"]["95"]:.4f}')
    plt.axvline(stats['percentiles']['99'], color='brown', linestyle=':', linewidth=2,
                label=f'99th %ile: {stats["percentiles"]["99"]:.4f}')

    plt.xlabel('Maximum Attention Score (QK^T/√D)', fontsize=12)
    plt.ylabel('Frequency Density', fontsize=12)
    plt.title(f'Distribution of Maximum Attention Scores\n'
              f'Layer {stats["layer_idx"]}, Head {stats["head_idx"]}, '
              f'Sequence Length {stats["seq_len"]}, Scale {stats["scale"]:.6f}', fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)

    # Add text box with statistics
    textstr = f'''Statistics:
    Samples: {len(max_values)}
    Mean: {stats["mean"]:.4f}
    Std: {stats["std"]:.4f}
    Min: {stats["min"]:.4f}
    Max: {stats["max"]:.4f}'''

    props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
    plt.text(0.02, 0.98, textstr, transform=plt.gca().transAxes, fontsize=9,
             verticalalignment='top', bbox=props)

    plt.tight_layout()
    plt.show()

def compare_multiple_heads(self, seq_len=512, num_samples=500, max_heads=8):
    """
    Compare distributions across multiple heads
    """
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.flatten()

    all_stats = []

    for head_idx in range(min(max_heads, self.config.n_heads)):
        print(f"\nAnalyzing Head {head_idx}...")
        stats = self.predict_max_distribution(
            layer_idx=0, head_idx=head_idx,
            seq_len=seq_len, num_samples=num_samples, plot=False
        )
        all_stats.append(stats)

        ax = axes[head_idx]
        max_values = stats['max_values']

        ax.hist(max_values, bins=30, density=True, alpha=0.7,
                color=plt.cm.Set3(head_idx), edgecolor='black', linewidth=0.5)
        ax.axvline(stats['mean'], color='red', linestyle='--', linewidth=2)
        ax.set_title(f'Head {head_idx}\nMean: {stats["mean"]:.3f}')
        ax.set_xlabel('Max Score')
        ax.set_ylabel('Density')
        ax.grid(True, alpha=0.3)

    plt.suptitle(f'Distribution Comparison Across Heads (Layer 0, T={seq_len})', fontsize=16)
    plt.tight_layout()
    plt.show()

    return all_stats

# Usage examples:
stats = model.predict_max_distribution(layer_idx=0, head_idx=0, seq_len=512, num_samples=1000)
all_stats = model.compare_multiple_heads(seq_len=512, num_samples=500)

In [None]:
def analyze_qk_distribution(model, model_config, num_samples=1000):
    model.eval()
    results = {}

    # 입력 분포 (LayerNorm 후라고 가정: 평균=0, 분산=1)
    with torch.no_grad():
        x_samples = torch.randn(num_samples, model_config.d_embed, device=device)  # [num_samples, d_embed]

    scale = model_config.scale
    theoretical_max = np.sqrt(2 * np.log(model_config.max_seq_len))

    for layer_idx in range(model_config.n_layers):
        attn_layer = model.blocks[layer_idx].attn

        if model_config.rank is None:
            # Multi Head Attention
            qkv_weight = attn_layer.qkv_proj.weight  # [3*d_embed, d_embed]
            wq_full = qkv_weight[:model_config.d_embed, :]      # [d_embed, d_embed]
            wk_full = qkv_weight[model_config.d_embed:2*model_config.d_embed, :]  # [d_embed, d_embed]

            layer_results = {}

            for head_idx in range(model_config.n_heads):
                # 헤드별 가중치 추출
                head_start = head_idx * model_config.d_head
                head_end = (head_idx + 1) * model_config.d_head

                wq_head = wq_full[head_start:head_end, :]  # [d_head, d_embed]
                wk_head = wk_full[head_start:head_end, :]  # [d_head, d_embed]

                with torch.no_grad():
                    # Q, K 계산
                    q_samples = x_samples @ wq_head.T  # [num_samples, d_head]
                    k_samples = x_samples @ wk_head.T  # [num_samples, d_head]

                    # 실제 분산 계산 (CLT 기반)
                    # Var[QK^T] = Σ(d=1 to d_head) Var[Q_d * K_d]
                    # = Σ(d=1 to d_head) E[Q_d^2] * E[K_d^2] (독립성 가정)
                    q_var_per_dim = q_samples.var(dim=0)  # [d_head]
                    k_var_per_dim = k_samples.var(dim=0)  # [d_head]

                    # 실제 분산 (정확한 계산)
                    actual_variance = (q_var_per_dim * k_var_per_dim).sum().item()
                    actual_std = np.sqrt(actual_variance)

                    # 스케일링 후 분산
                    scaled_variance = actual_variance * (scale ** 2)
                    scaled_std = actual_std * scale

                    # 최댓값 분포 근사 (Gumbel 분포 사용)
                    # max(QK^T/√D) ≈ μ + σ * √(2 log(seq_len))
                    predicted_max = scaled_std * np.sqrt(2 * np.log(model_config.max_seq_len))

                    # 실제 계산으로 검증 (작은 샘플로)
                    if head_idx == 0:  # 첫 번째 헤드만 실제 계산
                        qk_scores = q_samples @ k_samples.T  # [num_samples, num_samples]
                        qk_scaled = qk_scores * scale

                        # Causal mask (상삼각 제거)
                        tril_mask = torch.tril(torch.ones(num_samples, num_samples, device=device))
                        qk_masked = qk_scaled.masked_fill(tril_mask == 0, float('-inf'))

                        max_values = qk_masked.max(dim=-1).values  # [num_samples]
                        max_values = max_values[max_values != float('-inf')]

                        empirical_max_mean = max_values.mean().item()
                        empirical_max_std = max_values.std().item()

                    layer_results[head_idx] = {
                        'actual_variance': actual_variance,
                        'actual_std': actual_std,
                        'scaled_variance': scaled_variance,
                        'scaled_std': scaled_std,
                        'predicted_max': predicted_max,
                        'q_var_mean': q_var_per_dim.mean().item(),
                        'k_var_mean': k_var_per_dim.mean().item(),
                    }

                    # 첫 번째 헤드는 실제 값도 저장
                    if head_idx == 0:
                        layer_results[head_idx].update({
                            'empirical_max_mean': empirical_max_mean,
                            'empirical_max_std': empirical_max_std,
                        })

            results[layer_idx] = layer_results

        else:
            # Multi Head Latent Attention (간단히 처리)
            print(f"Layer {layer_idx}: Multi Head Latent Attention (skipped)")
            continue

    # 결과 출력
    print_comprehensive_results(results, theoretical_max, scale)
    return results

def print_comprehensive_results(results, theoretical_max, scale):
    """결과를 보기 좋게 출력"""

    print(f"\n{'='*20} COMPREHENSIVE QK^T DISTRIBUTION ANALYSIS {'='*20}")
    print(f"Scale factor: {scale:.6f}")
    print(f"Theoretical max (√(2log T)): {theoretical_max:.4f}")
    print(f"{'='*80}")

    # 헤더
    print(f"{'Layer':<6} {'Head':<4} {'Actual_Std':<10} {'Scaled_Std':<10} {'Pred_Max':<9} {'Q_Var':<7} {'K_Var':<7} {'Status':<10}")
    print(f"{'-'*80}")

    all_scaled_stds = []
    all_pred_maxes = []

    for layer_idx, layer_data in results.items():
        for head_idx, head_data in layer_data.items():
            actual_std = head_data['actual_std']
            scaled_std = head_data['scaled_std']
            pred_max = head_data['predicted_max']
            q_var = head_data['q_var_mean']
            k_var = head_data['k_var_mean']

            all_scaled_stds.append(scaled_std)
            all_pred_maxes.append(pred_max)

            # 상태 판정
            if scaled_std > 1.5:
                status = "🔴HIGH"
            elif scaled_std > 1.1:
                status = "🟡MED"
            else:
                status = "🟢OK"

            print(f"{layer_idx:<6} {head_idx:<4} {actual_std:<10.4f} {scaled_std:<10.4f} {pred_max:<9.3f} {q_var:<7.3f} {k_var:<7.3f} {status:<10}")

    # 통계 요약
    print(f"{'-'*80}")
    print(f"SUMMARY STATISTICS:")
    print(f"  Scaled Std  - Mean: {np.mean(all_scaled_stds):.4f}, Std: {np.std(all_scaled_stds):.4f}")
    print(f"  Predicted Max - Mean: {np.mean(all_pred_maxes):.4f}, Std: {np.std(all_pred_maxes):.4f}")
    print(f"  Target: Scaled_Std ≈ 1.0, Predicted_Max ≈ {theoretical_max:.3f}")

    # 실제 검증 (첫 번째 레이어의 첫 번째 헤드)
    if 0 in results and 0 in results[0]:
        first_head = results[0][0]
        if 'empirical_max_mean' in first_head:
            emp_max = first_head['empirical_max_mean']
            pred_max = first_head['predicted_max']
            print(f"\nVERIFICATION (Layer 0, Head 0):")
            print(f"  Empirical max mean: {emp_max:.4f}")
            print(f"  Predicted max: {pred_max:.4f}")
            print(f"  Ratio (emp/pred): {emp_max/pred_max:.4f}")

    # 문제 있는 레이어/헤드 식별
    problematic = []
    for layer_idx, layer_data in results.items():
        for head_idx, head_data in layer_data.items():
            if head_data['scaled_std'] > 1.5:
                problematic.append((layer_idx, head_idx, head_data['scaled_std']))

    if problematic:
        print(f"\n⚠️  PROBLEMATIC HEADS (Scaled_Std > 1.5):")
        for layer_idx, head_idx, scaled_std in problematic:
            print(f"  Layer {layer_idx}, Head {head_idx}: {scaled_std:.4f}")
    else:
        print(f"\n✅ All heads have reasonable attention distributions!")

In [None]:
results = analyze_qk_distribution(model, model_config, num_samples=1000)

In [None]:
def plot_distribution_heatmap(results, config):
    """레이어/헤드별 분포를 히트맵으로 시각화"""

    scaled_stds = np.zeros((config.n_layers, config.n_heads))
    pred_maxes = np.zeros((config.n_layers, config.n_heads))

    for layer_idx, layer_data in results.items():
        for head_idx, head_data in layer_data.items():
            scaled_stds[layer_idx, head_idx] = head_data['scaled_std']
            pred_maxes[layer_idx, head_idx] = head_data['predicted_max']

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Scaled Standard Deviation
    im1 = ax1.imshow(scaled_stds, cmap='RdYlBu_r', aspect='auto')
    ax1.set_title('Scaled Standard Deviation (QK^T/√D)')
    ax1.set_xlabel('Head Index')
    ax1.set_ylabel('Layer Index')
    plt.colorbar(im1, ax=ax1)

    # Predicted Maximum
    im2 = ax2.imshow(pred_maxes, cmap='RdYlBu_r', aspect='auto')
    ax2.set_title('Predicted Maximum (QK^T/√D)')
    ax2.set_xlabel('Head Index')
    ax2.set_ylabel('Layer Index')
    plt.colorbar(im2, ax=ax2)

    plt.tight_layout()
    plt.show()

In [None]:
plot_distribution_heatmap(results, model.config)