In [None]:
# ========== radial_tokenizer TEST ==========

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os

# ========== Configuration ==========
# Identify which tokenizer run's artifacts to inspect.
output_name = "sample1"
output_dir = "output"
# The original machine-specific path is kept as the default; set the
# RADIAL_TEST_IMG environment variable to run this cell on another machine.
img_path = os.environ.get("RADIAL_TEST_IMG", "C:/Users/denni/Downloads/test.png")  # Change if needed

# ========= Load Tokens ==========
tokens_192d_path = os.path.join(output_dir, "tokens_192D", f"{output_name}.pt")
tokens_48d_path = os.path.join(output_dir, "tokens_48D", f"{output_name}.pt")
weights_path = os.path.join(output_dir, "projection_weights", f"{output_name}.pt")

# map_location="cpu" lets tensors saved from a GPU run load on a CPU-only
# kernel, and the .numpy() conversions later in the notebook require CPU tensors.
tokens_192d = torch.load(tokens_192d_path, map_location="cpu")
print("✅ Loaded 192D Tokens:", tokens_192d.shape)

# Optional: load the 48D tokens (only present when the tokenizer saved them)
if os.path.exists(tokens_48d_path):
    tokens_48d = torch.load(tokens_48d_path, map_location="cpu")
    print("📎 Loaded 48D Tokens:", tokens_48d.shape)
    # .detach() mirrors the 192D path: .numpy() raises on tensors that require grad.
    tokens_48d_np = tokens_48d.squeeze(0).detach().numpy()
    # Label rows/columns from the loaded shape rather than assuming 4 rings x 48 features.
    n_rings_48d, n_feats_48d = tokens_48d_np.shape
    df_48d = pd.DataFrame(tokens_48d_np,
                          columns=[f"Feature_{i+1}" for i in range(n_feats_48d)],
                          index=[f"Ring_{i+1}" for i in range(n_rings_48d)])
    display(df_48d)
else:
    print("⚠️ No 48D tokens found (likely disabled via save_48d=False)")

# ========= Inspect 192D Tokens ==========
# squeeze(0) drops the leading batch axis; the result is used as a 2-D
# (rings x dims) table below. .cpu() guards against a tensor that was loaded
# onto a GPU, since .numpy() only works on CPU tensors.
tokens_192d_np = tokens_192d.squeeze(0).detach().cpu().numpy()
# Derive labels from the actual shape instead of hard-coding 4 rings x 192 dims.
n_rings, n_dims = tokens_192d_np.shape
df_192d = pd.DataFrame(tokens_192d_np,
                       columns=[f"Dim_{i+1}" for i in range(n_dims)],
                       index=[f"Ring_{i+1}" for i in range(n_rings)])
display(df_192d)

# ========= Plot Embeddings Per Ring ==========
# One line per ring; iterating the rows plots every ring that was saved.
fig, ax = plt.subplots(figsize=(10, 5))
for i, ring_embedding in enumerate(tokens_192d_np):
    ax.plot(ring_embedding, label=f"Ring {i+1}")
ax.set_title("Projected 192D Embeddings Per Ring")
ax.set_xlabel("Embedding Dimension")
ax.set_ylabel("Value")
ax.grid(True)
ax.legend()
fig.tight_layout()
plt.show()

# ========= Inspect Projection Weights ==========
# map_location="cpu" so a checkpoint written from a GPU run also loads on a
# CPU-only kernel.
proj_weights = torch.load(weights_path, map_location="cpu")
# The "proj.weight" entry is looked up by key; the message below assumes its
# rows are output dimensions and its columns the 48D input (Linear layout) — confirm.
first_weight = proj_weights["proj.weight"]
print("🔍 Projection weight shape:", first_weight.shape)
print("→ First 3 output dimensions from 48D input:\n", first_weight[:3])

# ========= Visualize Ring Overlay on Image ==========
def visualize_rings(image_path, center=(64, 64), rings=(20, 40, 60, 80), size=(128, 128)):
    """Draw concentric ring outlines over an image and display it with matplotlib.

    The image is read with OpenCV, resized to ``size`` and converted from BGR to
    RGB before drawing, so the colour tuples are interpreted as RGB.

    Args:
        image_path: Path to an image readable by ``cv2.imread``.
        center: (x, y) pixel centre of the rings in the resized image.
        rings: Circle radii in pixels, one outline per entry. A tuple default is
            used to avoid a shared mutable default; any sequence works.
            NOTE(review): presumably these should match the tokenizer's ring
            radii — confirm.
        size: (width, height) passed to ``cv2.resize`` before drawing.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f"Image not found: {image_path}")
    image = cv2.resize(image, size)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]

    overlay = image_rgb.copy()
    for i, r in enumerate(rings):
        # Cycle through the palette so more than four radii do not raise IndexError.
        cv2.circle(overlay, center, r, colors[i % len(colors)], thickness=2)

    plt.figure(figsize=(4, 4))
    plt.imshow(overlay)
    plt.title("Overlayed Radial Rings")
    plt.axis("off")
    plt.tight_layout()
    plt.show()

# Overlay the ring outlines on the configured test image, centred in the resized frame.
visualize_rings(image_path=img_path, center=(64, 64))

In [None]:
# ========== radial_positional_encoding TEST ==========
import torch
import os
import numpy as np
import matplotlib.pyplot as plt

# === Paths ===
input_192d_path = "output/tokens_192D/sample1.pt"
encoded_path = "output/encoded_tokens/encoded_radial_tokens.pt"

# === Load tensors ===
# map_location="cpu": the .numpy() conversions below only work on CPU tensors.
tokens_192d = torch.load(input_192d_path, map_location="cpu")      # expected shape: [1, 4, 192]
encoded_tokens = torch.load(encoded_path, map_location="cpu")      # expected shape: [1, 4, 192]

# === Sanity Check ===
# Include the actual shapes so a failure is diagnosable without re-running.
assert tokens_192d.shape == encoded_tokens.shape == (1, 4, 192), (
    f"Shape mismatch: tokens_192d {tuple(tokens_192d.shape)}, "
    f"encoded_tokens {tuple(encoded_tokens.shape)}, expected (1, 4, 192)"
)

# === Convert to NumPy ===
tokens_192d_np = tokens_192d.squeeze(0).detach().numpy()       # [4, 192]
encoded_np = encoded_tokens.squeeze(0).detach().numpy()        # [4, 192]


def plot_per_ring(ring_rows, title, ylabel="Value"):
    """Plot each row of ``ring_rows`` as one line labelled "Ring k" against embedding dimension.

    Args:
        ring_rows: 2-D array-like, one row per ring.
        title: Figure title.
        ylabel: Y-axis label.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    for i, row in enumerate(ring_rows):
        ax.plot(row, label=f"Ring {i+1}")
    ax.set_title(title)
    ax.set_xlabel("Embedding Dimension")
    ax.set_ylabel(ylabel)
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    plt.show()


# === Plot Original Projected Embeddings ===
plot_per_ring(tokens_192d_np, "Projected 192D Embeddings Per Ring (Before Encoding)")

# === Plot Encoded Embeddings ===
plot_per_ring(encoded_np, "Encoded 192D Embeddings Per Ring (After Positional Encoding)")

# === Plot Difference Per Ring ===
# Rows are aligned ring-by-ring, so the element-wise difference shows what the
# encoding step changed for each ring.
plot_per_ring(encoded_np - tokens_192d_np,
              "Difference: Encoded - Original Embeddings",
              ylabel="Delta Value")