In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd

# === Load Data ===
# Load the positional encoding saved in two formats (NumPy and PyTorch) so the
# two copies can be compared side by side in the cells below. Later indexing
# treats both as 3-D arrays laid out as (batch, ring, embedding_dim).
npy_data = np.load("radial_pos_enc.npy")
# map_location="cpu": torch.load otherwise restores tensors onto the device
# they were saved from, which fails on a machine lacking that device (e.g. a
# CUDA-saved tensor opened on a CPU-only host). CPU is enough for inspection.
pt_data = torch.load("radial_pos_enc.pt", map_location="cpu")

# === Print Shapes ===
# The two shapes should agree if both files hold the same encoding.
print("NPY shape:", npy_data.shape)
print("PT shape:", pt_data.shape)

# === Display First Ring Vector from Batch 1 ===
# For each source (NumPy first, then PyTorch), print a header followed by the
# first 10 values of batch element 0, ring 0, so the two copies can be
# eyeballed for agreement.
for source_name, arr in (("NPY", npy_data), ("PT", pt_data)):
    print(f"\n{source_name} - Batch 1, Ring 1 vector (first 10 dims):")
    print(arr[0, 0][:10])

# === Visualize Every Ring Embedding (Batch 1) ===
# Overlay one line per ring for batch element 0. The x-axis is the embedding
# dimension index and the y-axis is the stored value. The number of lines
# follows the data (npy_data's second axis) rather than a fixed count.
fig, ax = plt.subplots(figsize=(10, 4))
for ring_no, ring_vec in enumerate(npy_data[0], start=1):
    ax.plot(ring_vec, label=f"Ring {ring_no}")
ax.set_title("Learnable Positional Embeddings (Batch 1)")
ax.set_xlabel("Embedding Dimension")
ax.set_ylabel("Value")
ax.legend()
fig.tight_layout()
plt.show()

# === Display Full Ring Vectors in Table ===
# Tabulate batch element 0: one row per ring ("Ring 1", "Ring 2", ...) and one
# column per embedding dimension ("D0", "D1", ...). Both label lists are sized
# from the array shape, matching the plot loop above, so the table works for
# any number of rings or dimensions.
df = pd.DataFrame(
    npy_data[0],
    index=[f"Ring {i + 1}" for i in range(npy_data.shape[1])],
    columns=[f"D{i}" for i in range(npy_data.shape[2])],
)
# Bare expression as the cell's last line so the notebook renders the table.
df