In [3]:
import os
import random
import torch
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from sklearn.metrics import pairwise_distances
from sklearn.manifold import TSNE

from deit.models_v2_rope import compute_mixed_cis, init_t_xy
from models import vit_rope

# Load the two pretrained RoPE ViT variants under comparison and switch both to
# inference mode (nn.Module.eval() returns the module itself).
# NOTE(review): the mixed model is the *base* variant while the axial model is
# the *small* variant; confirm the size mismatch is intended before comparing them.
mix_model = vit_rope.rope_mixed_deit_base_patch16_LS(pretrained=True).eval()

axial_model = vit_rope.rope_axial_deit_small_patch16_LS(pretrained=True)
axial_model.eval()

  from .autonotebook import tqdm as notebook_tqdm
  def rope_axial_deit_small_patch16_LS(pretrained=False, img_size=224,  **kwargs):
  def rope_axial_deit_base_patch16_LS(pretrained=False, img_size=224,  **kwargs):
  def rope_axial_deit_large_patch16_LS(pretrained=False, img_size=224,  **kwargs):
  def rope_mixed_deit_small_patch16_LS(pretrained=False, img_size=224,  **kwargs):
  def rope_mixed_deit_base_patch16_LS(pretrained=False, img_size=224,  **kwargs):
  def rope_mixed_deit_large_patch16_LS(pretrained=False, img_size=224,  **kwargs):
  def rope_axial_ape_deit_small_patch16_LS(pretrained=False, img_size=224,  **kwargs):
  def rope_axial_ape_deit_base_patch16_LS(pretrained=False, img_size=224,  **kwargs):
  def rope_axial_ape_deit_large_patch16_LS(pretrained=False, img_size=224,  **kwargs):
  def rope_mixed_ape_deit_small_patch16_LS(pretrained=False, img_size=224,  **kwargs):
  def rope_mixed_ape_deit_base_patch16_LS(pretrained=False, img_size=224,  **kwargs):
  def rope_mixed_ape_

Removing key freqs_t_x from pretrained checkpoint
Removing key freqs_t_y from pretrained checkpoint


rope_vit_models(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x RoPE_Layer_scale_init_Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): RoPEAttention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): L

In [10]:
# List every parameter of the mixed model together with its shape and trainability.
for param_name, tensor in mix_model.named_parameters():
    print(f"{param_name} | shape: {tensor.shape} | requires_grad: {tensor.requires_grad}")


cls_token | shape: torch.Size([1, 1, 768]) | requires_grad: True
freqs | shape: torch.Size([2, 12, 384]) | requires_grad: True
patch_embed.proj.weight | shape: torch.Size([768, 3, 16, 16]) | requires_grad: True
patch_embed.proj.bias | shape: torch.Size([768]) | requires_grad: True
blocks.0.gamma_1 | shape: torch.Size([768]) | requires_grad: True
blocks.0.gamma_2 | shape: torch.Size([768]) | requires_grad: True
blocks.0.norm1.weight | shape: torch.Size([768]) | requires_grad: True
blocks.0.norm1.bias | shape: torch.Size([768]) | requires_grad: True
blocks.0.attn.qkv.weight | shape: torch.Size([2304, 768]) | requires_grad: True
blocks.0.attn.qkv.bias | shape: torch.Size([2304]) | requires_grad: True
blocks.0.attn.proj.weight | shape: torch.Size([768, 768]) | requires_grad: True
blocks.0.attn.proj.bias | shape: torch.Size([768]) | requires_grad: True
blocks.0.norm2.weight | shape: torch.Size([768]) | requires_grad: True
blocks.0.norm2.bias | shape: torch.Size([768]) | requires_grad: True


In [68]:
def compute_mixed_cis(freqs, t_x, t_y, num_heads):
    """Compute RoPE-Mixed rotary codes (unit complex numbers) for every position.

    Note: this notebook-local definition shadows the ``compute_mixed_cis``
    imported from ``deit.models_v2_rope`` in the first cell.

    args:
        freqs: [2, num_heads, freq_dim] frequencies; freqs[0] multiplies t_x and
            freqs[1] multiplies t_y.
        t_x: [N] first coordinate of each position
        t_y: [N] second coordinate of each position
        num_heads: int, expected size of freqs' second dimension
    returns:
        freqs_cis: [N, num_heads, freq_dim] complex tensor equal to
            exp(i * (t_x * freqs[0] + t_y * freqs[1]))
    raises:
        ValueError: if freqs is not [2, num_heads, freq_dim] or t_x / t_y are not
            1-D tensors of the same length.
    """
    if freqs.dim() != 3 or freqs.shape[0] != 2:
        raise ValueError(
            f"freqs must have shape [2, num_heads, freq_dim], got {tuple(freqs.shape)}"
        )
    if freqs.shape[1] != num_heads:
        raise ValueError(f"freqs has {freqs.shape[1]} heads but num_heads={num_heads}")
    if t_x.dim() != 1 or t_x.shape != t_y.shape:
        raise ValueError(
            f"t_x and t_y must be 1-D with equal length, got {tuple(t_x.shape)} and {tuple(t_y.shape)}"
        )

    # Keep the angle computation in full precision even when the caller runs under
    # CUDA mixed-precision autocast. torch.autocast(device_type="cuda", ...) is the
    # non-deprecated spelling of torch.cuda.amp.autocast(...).
    with torch.autocast(device_type="cuda", enabled=False):
        angles = (
            torch.einsum("n,hf->nhf", t_x, freqs[0])    # [N, H, F]
            + torch.einsum("n,hf->nhf", t_y, freqs[1])  # [N, H, F]
        )
        freqs_cis = torch.polar(torch.ones_like(angles), angles)  # [N, H, F]
    return freqs_cis

def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 100.0):
    """Build RoPE-Axial rotary codes for an ``end_x`` x ``end_y`` grid.

    A single frequency band ``1 / theta ** (k / dim)`` for k = 0, 4, 8, ...
    (``dim // 4`` values) is used for both axes. The x codes rotate with the
    ``t_x`` coordinates returned by ``init_t_xy`` and the y codes with ``t_y``.

    Returns:
        Complex tensor with one row per position from ``init_t_xy``; each row holds
        the x codes followed by the y codes (``2 * (dim // 4)`` entries).
    """
    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    band = 1.0 / (theta ** exponents)

    pos_x, pos_y = init_t_xy(end_x, end_y)
    angles_x = torch.outer(pos_x, band)
    angles_y = torch.outer(pos_y, band)

    codes_x = torch.polar(torch.ones_like(angles_x), angles_x)
    codes_y = torch.polar(torch.ones_like(angles_y), angles_y)
    return torch.cat([codes_x, codes_y], dim=-1)


def plot_tsne_subplot(freqs_cis, layer_idx, ax, N, width=14):
    """Draw a t-SNE embedding of per-position rotary codes on ``ax``.

    Points are coloured by position index. The two positions whose codes are
    closest in the original (un-embedded) feature space are joined by a red line,
    and that distance is shown in the title.

    args:
        freqs_cis: complex torch tensor whose elements reshape to [N, -1]
            (e.g. [N, H, F] from compute_mixed_cis).
        layer_idx: layer number shown in the title.
        ax: matplotlib Axes to draw on.
        N: number of positions.
        width: unused; kept so existing callers keep working.
    raises:
        ValueError: if N < 2 (t-SNE and a closest pair need at least two points).
    """
    if N < 2:
        raise ValueError(f"need at least 2 positions for t-SNE, got N={N}")

    # Flatten each position's complex codes and split them into real/imaginary
    # parts so the feature vectors are real-valued. detach/cpu lets callers pass
    # tensors that still require grad or live on an accelerator.
    codes = freqs_cis.detach().cpu().reshape(N, -1)
    rotary_codes = np.concatenate(
        [codes.real.numpy(), codes.imag.numpy()], axis=-1
    ).astype(np.float32)

    # sklearn's TSNE requires perplexity < n_samples, so cap it for small grids.
    tsne = TSNE(n_components=2, perplexity=min(30, N - 1), random_state=42)
    tsne_result = tsne.fit_transform(rotary_codes)

    # Closest pair of distinct positions in the original feature space.
    dists = pairwise_distances(rotary_codes)
    np.fill_diagonal(dists, np.inf)
    i, j = np.unravel_index(np.argmin(dists), dists.shape)
    min_dist = dists[i, j]

    ax.scatter(tsne_result[:, 0], tsne_result[:, 1], c=np.arange(N), cmap='viridis', s=20)
    ax.plot(
        [tsne_result[i, 0], tsne_result[j, 0]],
        [tsne_result[i, 1], tsne_result[j, 1]],
        'r-', linewidth=1.5,
    )
    ax.scatter(tsne_result[[i, j], 0], tsne_result[[i, j], 1], color='red', s=30)
    ax.set_title(f'Layer {layer_idx} (min dist={min_dist:.2f})', fontsize=9)
    ax.axis('off')

def plot_tsne(freqs_cis, N, width=14):
    """Show a standalone t-SNE plot of per-position RoPE-Axial rotary codes.

    Points are coloured by position index. The closest pair of positions in the
    original feature space is highlighted in red, and its distance is reported in
    the title.

    args:
        freqs_cis: complex torch tensor whose elements reshape to [N, -1].
        N: number of positions.
        width: unused; kept so existing callers keep working.
    raises:
        ValueError: if N < 2 (t-SNE and a closest pair need at least two points).
    """
    if N < 2:
        raise ValueError(f"need at least 2 positions for t-SNE, got N={N}")

    # Real-valued features: real parts followed by imaginary parts per position.
    codes = freqs_cis.detach().cpu().reshape(N, -1)
    rotary_codes = np.concatenate(
        [codes.real.numpy(), codes.imag.numpy()], axis=-1
    ).astype(np.float32)

    # sklearn's TSNE requires perplexity < n_samples, so cap it for small grids.
    tsne = TSNE(n_components=2, perplexity=min(30, N - 1), random_state=42)
    tsne_result = tsne.fit_transform(rotary_codes)

    # Closest pair of distinct positions in the original feature space.
    dists = pairwise_distances(rotary_codes)
    np.fill_diagonal(dists, np.inf)
    i, j = np.unravel_index(np.argmin(dists), dists.shape)
    min_dist = dists[i, j]

    fig, ax = plt.subplots()
    points = ax.scatter(tsne_result[:, 0], tsne_result[:, 1], c=np.arange(N), cmap='viridis', s=20)
    fig.colorbar(points, ax=ax, label="Position index")
    ax.plot(
        [tsne_result[i, 0], tsne_result[j, 0]],
        [tsne_result[i, 1], tsne_result[j, 1]],
        'r-', linewidth=1.5,
    )
    ax.scatter(tsne_result[[i, j], 0], tsne_result[[i, j], 1], color='red', s=30)
    ax.set_title(f"t-SNE of RoPE-Axial Encoding (Fixed) (min dist={min_dist:.2f})")
    plt.show()

import seaborn as sns
def plot_heatmap(freqs_cis, layer_idx=None, title_prefix=""):
    """Plot the pairwise Euclidean distance matrix between per-position codes.

    args:
        freqs_cis: rotary codes with one position per entry along the first
            dimension (e.g. a complex torch tensor or numpy array of shape
            [N, H, D]). Complex values are split into real and imaginary parts
            before measuring distances; real-valued input is used as-is.
        layer_idx: optional layer number shown in the title.
        title_prefix: optional text shown at the start of the title.
    """
    if isinstance(freqs_cis, torch.Tensor):
        freqs_cis = freqs_cis.detach().cpu().numpy()
    codes = np.asarray(freqs_cis)
    # The number of positions comes from the input itself rather than from a
    # global `N` defined in another cell.
    codes = codes.reshape(codes.shape[0], -1)
    if np.iscomplexobj(codes):
        codes = np.concatenate([codes.real, codes.imag], axis=-1)
    rotary_codes = codes.astype(np.float32)

    dists = pairwise_distances(rotary_codes)
    # Mask self-distances (NaN cells are left blank by seaborn) so the zero
    # diagonal does not dominate the colour scale.
    np.fill_diagonal(dists, np.nan)

    title_parts = []
    if title_prefix:
        title_parts.append(title_prefix)
    if layer_idx is not None:
        title_parts.append(f"layer {layer_idx}")
    title_parts.append("pairwise code distance")

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(dists, cmap='viridis', square=True, cbar=True, ax=ax)
    ax.set_xlabel("Position Index")
    ax.set_ylabel("Position Index")
    ax.set_title(" - ".join(title_parts))
    fig.tight_layout()
    plt.show()

def analyze_mix_model(model, freqs, t_x, t_y, layers=tuple(range(12)), width=14):
    """Plot per-layer t-SNE injectivity diagnostics for RoPE-Mixed codes.

    args:
        model: RoPE-Mixed model; only ``num_heads`` and ``embed_dim`` are read.
        freqs: learned frequencies indexed as ``freqs[:, layer]``; each layer slice
            must reshape to [2, num_heads, embed_dim // num_heads // 2]
            (assumes a CPU tensor, as produced by ``.detach().cpu()``).
        t_x, t_y: [N] per-position coordinates.
        layers: iterable of layer indices; one subplot is drawn per entry.
        width: forwarded to ``plot_tsne_subplot``.
    raises:
        ValueError: if ``layers`` is empty.
    """
    num_heads = model.num_heads
    freq_dim_per_head = model.embed_dim // num_heads // 2
    N = len(t_x)

    layers = list(layers)
    if not layers:
        raise ValueError("layers must contain at least one layer index")

    # Size the grid from the number of layers instead of a fixed 4x3 layout, so
    # more than 12 layers don't run out of axes and fewer don't leave blank panels.
    ncols = 3
    nrows = (len(layers) + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 2.5 * nrows), squeeze=False)
    axes = axes.flatten()

    for ax, layer_idx in zip(axes, layers):
        freqs_layer = freqs[:, layer_idx].reshape(2, num_heads, freq_dim_per_head)
        cis = compute_mixed_cis(freqs_layer, t_x, t_y, num_heads)
        plot_tsne_subplot(cis, layer_idx, ax, N, width)
    for unused_ax in axes[len(layers):]:
        unused_ax.axis('off')

    plt.tight_layout()
    fig.suptitle("RoPE Injectivity Check across Layers", fontsize=14, y=1.02)
    plt.subplots_adjust(top=0.92)
    plt.show()

def analyze_axial_model(model, end_x, end_y, width=14):
    """Plot a t-SNE of the RoPE-Axial codes for an ``end_x`` x ``end_y`` grid.

    args:
        model: RoPE-Axial model; only ``embed_dim`` and ``num_heads`` are read.
        end_x, end_y: grid size forwarded to ``compute_axial_cis``.
        width: forwarded to ``plot_tsne``.
    """
    # Rotary codes are applied per attention head, so the frequency dimension must
    # be the head dimension, not the full embedding width. Passing embed_dim plots a
    # wider encoding (embed_dim // 2 complex entries) than the per-head codes.
    # NOTE(review): matches dim=embed_dim // num_heads in deit/models_v2_rope.py — confirm.
    head_dim = model.embed_dim // model.num_heads
    cis = compute_axial_cis(head_dim, end_x, end_y)
    N = end_x * end_y
    plot_tsne(cis, N, width)
    
def idx_to_coord(idx, width=14):
    """Map a row-major flat index to an ``(x, y)`` pair on a grid ``width`` wide.

    ``x`` is the column (``idx % width``) and ``y`` the row (``idx // width``),
    both returned as plain ints.
    """
    col, row = idx % width, idx // width
    return int(col), int(row)


In [69]:
# 224x224 input images split into 16x16 patches -> a 14x14 patch grid.
end_x = end_y = 224 // 16
t_x, t_y = init_t_xy(end_x, end_y)
N = end_x * end_y  # total number of patch positions

In [None]:
# t-SNE of the axial model's rotary codes over the full patch grid.
analyze_axial_model(axial_model, end_x=end_x, end_y=end_y)

In [4]:
# Learned RoPE-Mixed frequencies, detached from autograd and copied to the CPU.
freqs = mix_model.freqs.detach().to("cpu")
freqs.shape

torch.Size([2, 12, 384])

In [None]:
freqs = mix_model.freqs.detach().cpu()

# plot_heatmap expects per-position rotary codes (one row per position), not the
# raw per-layer frequencies: freqs[:, 0] is only [2, 384], so reshaping it to N=196
# rows fails. Build layer 0's codes for the patch grid first.
layer0_freqs = freqs[:, 0].reshape(2, mix_model.num_heads, -1)
layer0_cis = compute_mixed_cis(layer0_freqs, t_x, t_y, mix_model.num_heads)
plot_heatmap(layer0_cis, layer_idx=0, title_prefix="RoPE-Mixed")

analyze_mix_model(mix_model, freqs, t_x, t_y, layers=np.arange(12))

In [16]:
import h5py
import os

root_dir = './data/depth_data'
train_file = os.path.join(root_dir, "depth_train.h5")

# Peek at the training file's layout; indexing image[0] reads a single sample.
with h5py.File(train_file, "r") as f:
    image, depth = f["image"], f["depth"]
    print(image.shape)
    print(type(image[0]))
# NOTE: `image` and `depth` are h5py dataset handles tied to `f`; they cannot be
# read after the `with` block closes the file.


(27260, 128, 160, 3)
<class 'numpy.ndarray'>
