In [1]:
using Flux, Onion

In [None]:
"""
    RoPE{A<:AbstractArray}

Precomputed cosine/sine tables for rotary position embeddings.

The `RoPE(dim, end_pos; theta, start_pos)` constructor fills both fields with arrays of
size `(dim ÷ 2, end_pos - start_pos, 1, 1)`; dimension 2 indexes positions, which is the
dimension `getindex` slices.
"""
struct RoPE{A<:AbstractArray}
    cos::A  # cosine of each (frequency, position) angle
    sin::A  # sine of each (frequency, position) angle
end

# Register with Flux as a layer; per Flux's `@layer` docs, `trainable=()` exposes no
# trainable fields, so an optimiser never updates the precomputed tables.
Flux.@layer RoPE trainable=()

# Select positions `i` (dimension 2) from both tables, keeping every other dimension.
Base.getindex(rope::RoPE, i) = RoPE(rope.cos[:, i, :, :], rope.sin[:, i, :, :])

# A scalar index would drop dimension 2 and leave 3-D tables; routing it through a
# one-element range keeps the same 4-D layout that range indexing returns.
Base.getindex(rope::RoPE, i::Integer) = rope[i:i]

"""
    RoPE(dim::Int, end_pos::Int; theta=10000f0, start_pos=0)

Precompute rotary position embedding tables for positions `start_pos:end_pos-1`.

Frequencies are `1 / theta^(2k / dim)` for `k = 0:dim÷2-1`. The tables' element type is
derived from `theta` by normal promotion (the default `10000f0` gives `Float32`).

# Arguments
- `dim`: rotary dimension. Only `dim ÷ 2` frequencies are produced, so for odd `dim`
  the tables cover `dim - 1` components.
- `end_pos`: one past the last position.

# Keywords
- `theta`: frequency base.
- `start_pos`: first position (default `0`).

# Returns
A `RoPE` whose `cos` and `sin` arrays have size `(dim ÷ 2, end_pos - start_pos, 1, 1)`.

# Throws
- `ArgumentError` if `end_pos < start_pos`.
"""
function RoPE(dim::Int, end_pos::Int; theta::T=10000f0, start_pos=0) where {T<:Real}
    seq_len = end_pos - start_pos
    seq_len >= 0 ||
        throw(ArgumentError("end_pos ($end_pos) must be ≥ start_pos ($start_pos)"))
    # Even exponents 0, 2, …, 2(dim÷2 - 1). `one(T)` ties the element type to `theta`
    # instead of forcing a promotion to at least Float32.
    freqs = one(T) ./ (theta .^ (T.(0:2:dim-2) ./ dim))
    # Angle table laid out directly as (dim÷2, seq_len): angles[k, p] = freqs[k] * pos[p].
    angles = freqs .* T.(start_pos:end_pos-1)'
    rotations = cis.(angles)
    cos_table = reshape(real(rotations), (dim ÷ 2, seq_len, 1, 1))
    sin_table = reshape(imag(rotations), (dim ÷ 2, seq_len, 1, 1))
    return RoPE(cos_table, sin_table)
end

#----------------------------------

"""
    AxisAlignedRoPE{A<:AbstractArray}

Cosine/sine tables for an axis-aligned rotary position embedding.

NOTE(review): the exact layout is not established by the visible code. `getindex`
indexes both fields with four indices and slices dimension 2, so they are presumably
4-D with positions along dimension 2 — confirm once the constructor is finished.
"""
struct AxisAlignedRoPE{A<:AbstractArray}
    # Presumably one RoPE table per coordinate axis, stacked into a single tensor —
    # TODO confirm against the (unfinished) constructor.
    cos::A
    sin::A
end

# Register with Flux as a layer; per Flux's `@layer` docs, `trainable=()` exposes no
# trainable fields, so an optimiser never updates the precomputed tables.
Flux.@layer AxisAlignedRoPE trainable=()

# Select entries `i` along dimension 2 of both tables, keeping every other dimension.
Base.getindex(rope::AxisAlignedRoPE, i) =
    AxisAlignedRoPE(rope.cos[:, i, :, :], rope.sin[:, i, :, :])

# A scalar index would drop dimension 2 and change the tables' rank; routing it through
# a one-element range keeps the same 4-D layout that range indexing returns.
Base.getindex(rope::AxisAlignedRoPE, i::Integer) = rope[i:i]

function AxisAlignedRoPE(dim::Int, coords::{R<:AbstractArray} ; theta::T=10000f0, start_pos=0) where T
    freqs = 1f0 ./ (theta .^ (T.(0:2:dim-1)[1:dim÷2] ./ dim))
    freqs_complex = cis.(T.(coords) * freqs')
    
