In [1]:
using Flux
using CUDA

In [2]:
using LinearAlgebra
using Distributions

In [7]:
"""
    Attention(d_in::Int, d_k::Int, d_out::Int)

Parameters of a single-head dot-product attention layer.

# Fields
- `d_in`, `d_k`, `d_out`: input, query/key and output feature sizes.
- `sqrt_d_k`: `sqrt(d_k)` stored as `Float32`; the attention scores are divided by it.
- `Wq`, `Wk`: query and key projections, each of size `d_k × d_in`.
- `Wv`: value projection of size `d_out × d_in`.
- `bv`: bias vector of length `d_out`.

The three-argument constructor draws every weight and the bias from a standard
normal distribution in `Float32`.
"""
struct Attention{I <: Integer, O, F1 <: AbstractArray, F2 <: AbstractArray}
    d_in::I
    d_k::I
    d_out::I
    sqrt_d_k::O

    Wq::F1
    Wk::F1
    Wv::F1
    bv::F2
end


function Attention(d_in::Int, d_k::Int, d_out::Int)
    sqrt_d = Float32(sqrt(d_k))
    # Queries and keys must share the d_k-dimensional space so that their
    # dot products are defined for any choice of d_out.
    Wq = randn(Float32, d_k, d_in)
    Wk = randn(Float32, d_k, d_in)
    # Values carry the output features, so their row count matches length(bv).
    Wv = randn(Float32, d_out, d_in)
    bv = randn(Float32, d_out)
    return Attention(d_in, d_k, d_out, sqrt_d, Wq, Wk, Wv, bv)
end

"""
    RMSLayerNorm(d_in::Int)

Layer holding a gain `g` with one entry per input feature.

`RMSLayerNorm(d_in)` initialises the gain to a `Float32` vector of `d_in` ones;
`RMSLayerNorm(g)` wraps an existing gain container unchanged.
"""
struct RMSLayerNorm{F}
    g::F
end

RMSLayerNorm(d_in::Int) = RMSLayerNorm(fill(1.0f0, d_in))

# Parameters the optimiser should update. The keys have to be field names of
# `Attention`, so the bias is exposed under its field name `bv` (not `b`);
# the integer sizes and `sqrt_d_k` are deliberately left out.
Flux.trainable(a::Attention) = (Wq = a.Wq, Wk = a.Wk, Wv = a.Wv, bv = a.bv)

# The gain vector `g` is the only trainable parameter of RMSLayerNorm.
Flux.trainable(layer::RMSLayerNorm) = (g = layer.g,)

# Forward pass of single-head scaled dot-product self-attention.
#
# `x` is treated as a matrix whose columns are the individual tokens
# (features × tokens). `Wq` and `Wk` must have the same number of rows, and
# `Wv` must have `length(bv)` rows. The result has one column per token and
# as many rows as `Wv`.
function (m::Attention)(x::AbstractArray)
    q = m.Wq * x
    k = m.Wk * x
    v = m.Wv * x
    # scores[i, j] = (q_i ⋅ k_j) / sqrt(d_k); dims=2 normalises over the keys j
    # for each query i, so every row of `a` sums to one.
    a = softmax(q' * k / m.sqrt_d_k, dims=2)
    # Output column i is Σ_j a[i, j] * v_j, which in matrix form is v * a'.
    # (a * v would require the token count to equal the feature count.)
    return v * a' .+ m.bv
end

# Root-mean-square normalisation along the first dimension, followed by the
# per-feature gain `g`.
#
# Each slice along dims=1 (each column of a matrix) is divided by the square
# root of its mean squared value, then row i is scaled by g[i]. Written with
# broadcasting rather than `Diagonal` products so it works for vectors and
# higher-rank arrays as well as matrices.
function (m::RMSLayerNorm)(x::AbstractArray)
    # Machine epsilon of the element type keeps the division (and its gradient)
    # finite when a slice is entirely zero; otherwise the result would be NaN.
    ϵ = eps(float(eltype(x)))
    rms = sqrt.(mean(x .^ 2; dims=1) .+ ϵ)
    return m.g .* x ./ rms
end

# Register both custom structs as Flux layers. The recorded outputs in this
# notebook show the effect for RMSLayerNorm: `|> gpu` turns `g` into a CuArray
# and the Chain display counts its 5 parameters.
Flux.@layer Attention
# Older registration via Flux.@functor, kept for reference; superseded by Flux.@layer.
#Flux.@functor RMSLayerNorm
Flux.@layer RMSLayerNorm

In [30]:
# Smoke-test model: Dense(5 => 5), the custom RMSLayerNorm, then tanh_fast, moved to the GPU.
model = Chain(Dense(5,5), RMSLayerNorm(5), tanh_fast) |> gpu

Chain(
  Dense(5 => 5),                        # 30 parameters
  RMSLayerNorm{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}(Float32[1.0, 1.0, 1.0, 1.0, 1.0]),  # 5 parameters
  NNlib.tanh_fast,
)                   # Total: 3 arrays, 35 parameters, 512 bytes.

In [31]:
# Uniform random 5×20 Float32 input, converted to a GPU array with `cu`.
test = cu(rand(Float32, 5, 20))

5×20 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 0.710282  0.468841  0.155123   0.984254  …  0.659424  0.70125    0.575374
 0.858692  0.511139  0.0676039  0.292649     0.177536  0.591967   0.928848
 0.169994  0.255369  0.972712   0.274113     0.556985  0.0961319  0.0833815
 0.489282  0.429942  0.206087   0.937918     0.233601  0.517326   0.701799
 0.583291  0.703098  0.819706   0.089069     0.506761  0.613294   0.445674

In [32]:
# Forward pass of the model on the GPU test input.
model(test)

5×20 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
  0.602715   0.457042   0.288227  …   0.70698    0.579766   0.41954
 -0.913179  -0.891117  -0.848338     -0.865343  -0.891189  -0.907478
  0.754435   0.761685   0.623224      0.521342   0.749178   0.818099
  0.360557   0.627924   0.574335      0.671513   0.678817   0.326798
  0.764369   0.794513   0.913013      0.841089   0.738366   0.77582

In [17]:
# Mean along the first dimension, leaving a single-row matrix (one value per column).
test2 = mean(test, dims=1)

1×10 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 0.690286  0.659953  0.444712  0.559699  …  0.48145  0.452321  0.443771

In [18]:
# Flatten the single-row matrix into a plain vector.
test3 = vec(test2)

10-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
 0.6902859
 0.6599535
 0.4447119
 0.5596994
 0.46951193
 0.4326908
 0.47565067
 0.48144966
 0.45232105
 0.44377148

In [6]:
# Width-20 variant without a final activation, moved to the GPU.
# NOTE(review): the scalar-indexing error recorded for this cell presumably comes from an
# earlier RMSLayerNorm definition — re-run the notebook top to bottom to confirm.
model = Chain(Dense(20,20), RMSLayerNorm(20)) |> gpu

ErrorException: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.

In [28]:
# Stand-alone RMSLayerNorm of width 20, moved to the GPU.
model2 = RMSLayerNorm(20) |> gpu

RMSLayerNorm{Diagonal{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}(Float32[1.0 0.0 … 0.0 0.0; 0.0 1.0 … 0.0 0.0; … ; 0.0 0.0 … 1.0 0.0; 0.0 0.0 … 0.0 1.0])

In [29]:
# Stand-alone Dense(20 => 20) layer, moved to the GPU.
model1 = Dense(20,20) |> gpu

Dense(20 => 20)     # 420 parameters

In [30]:
# Compose the two GPU layers into a single Chain.
# NOTE(review): the recorded scalar-indexing error looks like it is raised while
# displaying the Chain, with `g` stored as a Diagonal of a CuArray (as in the
# model2 output above) — confirm by re-running with the current definition.
model_tot = Chain(model1, model2)

ErrorException: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.

In [33]:
# Standard-normal 20×10 Float32 input, moved to the GPU.
test = randn(Float32, 20, 10) |> gpu

20×10 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
  0.112839    0.293395  -0.492453  …  -0.286902    0.352718    0.22967
  0.708213   -2.03752   -1.18512       1.55661     0.33228    -1.6682
  0.858796   -0.779169   0.393398      1.50495    -0.430251    0.71641
  0.516096   -0.788281  -1.11363      -0.405662   -0.272435   -0.237644
 -2.60874    -1.04056    0.432863     -1.47429    -1.09094    -0.832609
  1.10736    -1.40061   -0.140503  …  -0.575804   -0.535272   -1.04574
 -1.06602    -0.528929  -0.720663      0.319153   -0.460791   -1.03256
  1.08024     0.382457   1.46068      -1.42958    -0.670673    0.0583352
  0.483158   -1.34547    0.66802      -0.400428   -1.59719    -1.3818
  1.60576     0.216369   0.212982     -0.240003    0.336564   -0.347004
  1.89623     0.68115    1.09173   …  -0.902484    0.528603   -0.69045
  1.93298    -0.384858  -1.26968       0.602897   -1.32293    -0.51338
  1.29343    -0.140246  -0.851839      0.0135565  -0.890413    0.645362
  0.100368   -1.28824  

In [34]:
# Forward pass of the current `model` on the 20×10 GPU input.
model(test)

20×10 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
  0.101854    0.298945  -0.571124  …  -0.331914    0.351924    0.264362
  0.639269   -2.07606   -1.37445       1.80082     0.331532   -1.92019
  0.775193   -0.793907   0.456245      1.74106    -0.429282    0.824626
  0.465854   -0.803192  -1.29154      -0.469305   -0.271821   -0.273541
 -2.35478    -1.06024    0.502015     -1.70559    -1.08849    -0.958378
  0.999559   -1.42711   -0.162949  …  -0.66614    -0.534067   -1.2037
 -0.962245   -0.538934  -0.835792      0.369224   -0.459753   -1.18854
  0.975083    0.389692   1.69403      -1.65386    -0.669163    0.067147
  0.436123   -1.37092    0.774739     -0.46325    -1.59359    -1.59053
  1.44944     0.220462   0.247007     -0.277656    0.335806   -0.39942
  1.71163     0.694034   1.26614   …  -1.04407     0.527412   -0.794746
  1.7448     -0.392137  -1.47251       0.697484   -1.31995    -0.590929
  1.16751    -0.142899  -0.987924      0.0156834  -0.888408    0.742846
  0.090597   -1.3126