In [None]:
"""
RayRoPE Demo: Self-Attention for Multi-View Images
"""
import sys
sys.path.append('..')

import torch
import torch.nn as nn
from pos_enc.rayrope import RayRoPE_DotProductAttention
from pos_enc.utils.rayrope_mha import MultiheadAttention

In [None]:
class RayRoPETransformer(nn.Module):
    """A minimal attention-only, pre-norm transformer using RayRoPE self-attention.

    A single ``RayRoPE_DotProductAttention`` module is created and its
    ``forward`` is passed as ``sdpa_fn`` to every ``MultiheadAttention`` layer,
    so all layers go through the same RayRoPE module. Each layer applies
    ``LayerNorm -> multi-head attention -> residual add``; there is no MLP.

    Args:
        embed_dim: Width of the token embeddings; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        num_layers: Number of attention layers (and matching LayerNorms).
        patches_x: Forwarded to RayRoPE as ``patches_x``
            (presumably patches per image along x -- confirm in pos_enc.rayrope).
        patches_y: Forwarded to RayRoPE as ``patches_y``
            (presumably patches per image along y -- confirm in pos_enc.rayrope).
        img_w: Forwarded to RayRoPE as ``image_width``.
        img_h: Forwarded to RayRoPE as ``image_height``.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int, num_layers: int,
                 patches_x: int, patches_y: int, img_w: int, img_h: int):
        super().__init__()
        # Fail early with a clear message instead of silently flooring head_dim.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        head_dim = embed_dim // num_heads

        # Single RayRoPE module shared across all layers
        self.rayrope_attn = RayRoPE_DotProductAttention(
            head_dim=head_dim, patches_x=patches_x, patches_y=patches_y,
            image_width=img_w, image_height=img_h,
            pos_enc_type='d_pj+0_3d', num_rays_per_patch=3, depth_type='predict_dsig',
        )

        # MHA layers all use the same RayRoPE attention function
        self.mha_layers = nn.ModuleList([
            MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads,
                               predict_d='predict_dsig', sdpa_fn=self.rayrope_attn.forward)
            for _ in range(num_layers)
        ])
        # One LayerNorm per attention layer, applied before the attention (pre-norm).
        self.norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor, w2cs: torch.Tensor, Ks: torch.Tensor) -> torch.Tensor:
        """Run every attention layer over ``x`` with residual connections.

        Args:
            x: Token embeddings. The demo in this notebook passes a tensor of
                shape ``(B, num_cams * patches * patches, embed_dim)``.
            w2cs: Camera extrinsics handed to RayRoPE. The demo passes
                ``(B, num_cams, 4, 4)``.
            Ks: Camera intrinsics handed to RayRoPE. The demo passes
                ``(B, num_cams, 3, 3)``.

        Returns:
            The result of adding each layer's attention output back onto ``x``.
        """
        # Precompute RayRoPE encodings once from camera poses
        self.rayrope_attn._precompute_and_cache_apply_fns(w2cs, Ks)

        for mha, norm in zip(self.mha_layers, self.norms):
            # Normalise once and reuse the result as query, key and value
            # (self-attention) instead of running the LayerNorm three times.
            normed = norm(x)
            x = x + mha(normed, normed, normed)
        return x

In [None]:
# Create model (head_dim = 768 / 8 = 96, which must be a multiple of 24 for RayRoPE)
# Prefer the GPU, but fall back to CPU so this cell does not hard-fail without CUDA.
# NOTE(review): CPU execution of the RayRoPE layers is assumed to work -- confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RayRoPETransformer(
    embed_dim=768, num_heads=8, num_layers=2,
    patches_x=8, patches_y=8, img_w=128, img_h=128,
).to(device)

In [None]:
# Create dummy data: 2 batches, 3 cameras, 8x8 patches, 128x128 images
B, num_cams, patches, img_size, embed_dim = 2, 3, 8, 128, 768

# Resolve the device inside this cell so it is self-contained: CUDA when
# available, otherwise CPU. All tensors below are created directly on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed seed so Restart & Run All reproduces the same dummy inputs.
torch.manual_seed(0)

# Patch embeddings: one token per patch per camera
x = torch.randn(B, num_cams * patches * patches, embed_dim, device=device)

# Camera extrinsics: identity 4x4 per camera with a small random translation
w2cs = torch.eye(4, device=device).expand(B, num_cams, -1, -1).clone()
w2cs[:, :, :3, 3] = torch.randn(B, num_cams, 3, device=device) * 0.5

# Camera intrinsics: focal lengths on the diagonal, principal point at the image centre
Ks = torch.zeros(B, num_cams, 3, 3, device=device)
Ks[:, :, 0, 0] = Ks[:, :, 1, 1] = 500                               # fx, fy
Ks[:, :, 0, 2] = Ks[:, :, 1, 2] = img_size / 2                      # cx, cy
Ks[:, :, 2, 2] = 1

print(f"Input: x {x.shape}, w2cs {w2cs.shape}, Ks {Ks.shape}")

In [None]:
# Forward pass (inference only, so skip building the autograd graph)
with torch.no_grad():
    out = model(x, w2cs, Ks)

# The model only adds residuals onto x, so the output must keep the input's shape;
# a mismatch would mean the residual add broadcast silently.
assert out.shape == x.shape, f"expected output shape {tuple(x.shape)}, got {tuple(out.shape)}"

print(f"Output: {out.shape}")