In [1]:
using Pkg
Pkg.activate(".")
#Pkg.add(["Flux", "DLProteinFormats", "Onion", "RandomFeatureMaps", "StatsBase", "Plots"])

[32m[1m  Activating[22m[39m project at `c:\Users\User\Desktop\SoFo\code\SoFo-2025-Translation-Equivariant-Transformer`


In [2]:
using Flux, DLProteinFormats, Onion, RandomFeatureMaps, StatsBase, Plots

In [3]:
# Load the `PDBSimpleFlat500` dataset through DLProteinFormats.
dat = DLProteinFormats.load(PDBSimpleFlat500);

# Window length: every training example is a contiguous stretch of `L` positions.
L = 30

# Indices of the entries whose length strictly exceeds one window.
# NOTE(review): entries of length exactly `L` are excluded here even though
# `random_batch` could still draw a (single) window from them — confirm intent.
train_inds = findall(>(L), dat.len)
 
"""
    random_batch(dat, L, B, filt_inds)

Draw `B` distinct entries of `dat` (indices taken from `filt_inds`, sampled without
replacement) and cut one randomly placed window of `L` consecutive positions from each.

Returns a named tuple with

- `locs`: a `3 × L × B` `Float32` array filled from `dat[i].locs[:, 1, window]`
  (the `1` selects the first slot of the middle axis — presumably one atom per
  residue; confirm against the dataset layout);
- `AAs`: `Flux.onehotbatch` of the `L × B` integer labels over the classes `1:20`.

Every entry referenced by `filt_inds` must have `len ≥ L`, otherwise the range of
admissible window starts is empty and `rand` throws.
"""
function random_batch(dat, L, B, filt_inds)
    chosen = sample(filt_inds, B, replace=false)
    locs = zeros(Float32, 3, L, B)
    AAs = zeros(Int, L, B)
    for (col, idx) in enumerate(chosen)
        protein = dat[idx]
        # Uniformly pick where the window starts so that it fits inside the entry.
        start = rand(1:(protein.len - L + 1))
        window = start:(start + L - 1)
        locs[:, :, col] = protein.locs[:, 1, window]
        AAs[:, col] = protein.AAs[window]
    end
    return (; locs, AAs = Flux.onehotbatch(AAs, 1:20))
end

# Draw one example batch of 10 windows; the trailing `;` suppresses the cell output.
batch = random_batch(dat, L, 10, train_inds);
 
"""
    Toy0()

Simplest baseline: one bias-free `Dense(3 => 20)` layer applied directly to the
input coordinates, producing 20 logits per position. Calling the model as
`m(locs)` returns those logits.
"""
struct Toy0{L}
    layers::L
end
Flux.@layer Toy0

# Zero-argument constructor: builds the single decoder layer and wraps it.
Toy0() = Toy0((; AA_decoder = Dense(3 => 20, bias=false)))

# Forward pass: coordinates in, logits out.
(m::Toy0)(locs) = m.layers.AA_decoder(locs)
 
"""
    Toy1(dim, depth; nheads=8)

Transformer model whose blocks receive the raw coordinates alongside the hidden
state.

The coordinates are embedded by a bias-free `Dense(3 => dim)`, passed through
`depth` `Onion.TransformerBlock`s built with `rope=Onion.MultiDimRoPE`, and decoded
to 20 logits per position by a bias-free `Dense(dim => 20)`.

# Arguments
- `dim`: width of the hidden representation.
- `depth`: number of transformer blocks.

# Keywords
- `nheads=8`: number of attention heads handed to every `Onion.TransformerBlock`
  (previously hard-coded to 8).
"""
struct Toy1{L}
    layers::L
end
Flux.@layer Toy1
function Toy1(dim, depth; nheads=8)
    layers = (;
        loc_encoder = Dense(3 => dim, bias=false),
        transformers = [
            Onion.TransformerBlock(dim, nheads, rope=Onion.MultiDimRoPE) for _ in 1:depth
        ],
        AA_decoder = Dense(dim => 20, bias=false),
    )
    return Toy1(layers)
end
function (m::Toy1)(locs)
    l = m.layers
    x = l.loc_encoder(locs)
    for block in l.transformers
        # The same, unmodified `locs` is handed to every block together with `x`.
        # NOTE(review): the `0` and `nothing` arguments are defined by
        # `Onion.TransformerBlock`'s call signature — confirm their meaning there.
        x = block(x, 0, nothing, locs)
        # TODO: optionally refine `locs` between blocks; an `updatelocs(x, locs)`
        # hook was sketched here but no such function appears in this notebook.
    end
    return l.AA_decoder(x)
end
 
"""
    Toy2(dim, depth; nheads=8, rff_dim=64, rff_sigma=0.1f0)

Transformer model that sees the coordinates only through random Fourier features.

The coordinates are lifted by `RandomFourierFeatures(3 => rff_dim, rff_sigma)`,
embedded by a bias-free `Dense(rff_dim => dim)`, passed through `depth`
`Onion.TransformerBlock`s (constructed without a `rope` argument), and decoded to
20 logits per position by a bias-free `Dense(dim => 20)`.

# Arguments
- `dim`: width of the hidden representation.
- `depth`: number of transformer blocks.

# Keywords
- `nheads=8`: number of attention heads for every block (previously hard-coded).
- `rff_dim=64`: number of random Fourier features (previously hard-coded).
  NOTE(review): presumably has to be even — confirm with RandomFeatureMaps.
- `rff_sigma=0.1f0`: scale argument forwarded to `RandomFourierFeatures`
  (previously hard-coded).
"""
struct Toy2{L}
    layers::L
end
Flux.@layer Toy2
function Toy2(dim, depth; nheads=8, rff_dim=64, rff_sigma=0.1f0)
    layers = (;
        loc_rff = RandomFourierFeatures(3 => rff_dim, rff_sigma),
        loc_encoder = Dense(rff_dim => dim, bias=false),
        transformers = [Onion.TransformerBlock(dim, nheads) for _ in 1:depth],
        AA_decoder = Dense(dim => 20, bias=false),
    )
    return Toy2(layers)
end
function (m::Toy2)(locs)
    l = m.layers
    # Feature map first, then the learned embedding.
    x = l.loc_encoder(l.loc_rff(locs))
    for block in l.transformers
        # Unlike `Toy1`, the blocks are not given the coordinates.
        x = block(x, 0, nothing)
    end
    return l.AA_decoder(x)
end
 

In [4]:
# Choose the architecture to train by leaving exactly one of these lines active.
#model = Toy0()
model = Toy1(64, 4)
#model = Toy2(64, 4)
# Optimiser state for `model`, using AdamW with learning rate 0.001.
opt_state = Flux.setup(AdamW(eta = 0.001), model)

# Averaged training losses; the training loop appends to this vector.
losses = Float32[]

Float32[]

In [5]:
# Display the first transformer block (valid for models that have a `transformers` layer).
model.layers.transformers[1]

TransformerBlock(
  Attention(
    Dense(64 => 64; bias=false),        # 4_096 parameters
    Dense(64 => 64; bias=false),        # 4_096 parameters
    Dense(64 => 64; bias=false),        # 4_096 parameters
    Dense(64 => 64; bias=false),        # 4_096 parameters
    64,
    8,
    8,
    8,
  ),
  StarGLU(
    Dense(64 => 256; bias=false),       # 16_384 parameters
    Dense(256 => 64; bias=false),       # 16_384 parameters
    Dense(64 => 256; bias=false),       # 16_384 parameters
    NNlib.swish,
  ),
  RMSNorm{Float32, Vector{Float32}}(Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 1.0f-5),  # 64 parameters
  RMSNorm{Float32, Vector{Float32}}(Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 1.0f-5),  # 64 parameters
  Onion.MultiDimRoPE,
)                   #

In [None]:
"""
    train!(model, opt_state, losses, dat, window, inds; kwargs...)

Run the training loop and return `losses`.

Each step draws a fresh batch with `random_batch(dat, window, batch_size, inds)`,
evaluates `Flux.logitcrossentropy` between `model(batch.locs)` and `batch.AAs`,
and applies the gradient with `Flux.update!`. Every `log_every` steps the mean
loss over those steps is printed and appended to `losses`; every `plot_every`
steps the loss curve is written to `plot_path`.

Keeping the loop inside a function avoids running the hot path over non-constant
globals.

# Keywords
- `epochs=20`, `steps_per_epoch=1_000`: loop extents (the original notes suggest
  100 and 10_000 for a long run).
- `batch_size=10`: number of windows per batch.
- `log_every=50`: averaging / reporting interval, in steps.
- `plot_every=500`: interval, in steps, at which the loss plot is saved.
- `plot_path="losses_toy_MultiDimRoPE.pdf"`: destination of the loss plot.

# Throws
- `ArgumentError` if `log_every` or `plot_every` is not positive.
"""
function train!(model, opt_state, losses, dat, window, inds;
                epochs=20, steps_per_epoch=1_000, batch_size=10,
                log_every=50, plot_every=500,
                plot_path="losses_toy_MultiDimRoPE.pdf")
    log_every > 0 || throw(ArgumentError("log_every must be positive, got $log_every"))
    plot_every > 0 || throw(ArgumentError("plot_every must be positive, got $plot_every"))
    for epoch in 1:epochs
        running_loss = 0f0
        for step in 1:steps_per_epoch
            batch = random_batch(dat, window, batch_size, inds)
            loss, grads = Flux.withgradient(model) do m
                Flux.logitcrossentropy(m(batch.locs), batch.AAs)
            end
            Flux.update!(opt_state, model, grads[1])
            running_loss += loss
            if mod(step, log_every) == 0
                mean_loss = running_loss / log_every
                println(epoch, " ", step, " ", mean_loss)
                push!(losses, mean_loss)
                # Start a fresh averaging window.
                running_loss = 0f0
            end
            mod(step, plot_every) == 0 && savefig(plot(losses), plot_path)
        end
    end
    return losses
end

train!(model, opt_state, losses, dat, L, train_inds);


1 50 3.372465
1 100 3.0511887
1 150 2.995998
1 200 2.98264
1 250 2.9905777
1 300 2.9700623
1 350 2.9567733
1 400 2.945958
1 450 2.956957
1 500 2.9402626
1 550 2.934617
1 600 2.9285443
1 650 2.9337845
1 700 2.9316099
1 750 2.9213576
1 800 2.9174562
1 850 2.921789
1 900 2.9093676
1 950 2.915376
1 1000 2.9206727
2 50 2.9170485
2 100 2.899255
2 150 2.9003658
2 200 2.9075286
2 250 2.8878295
2 300 2.9025567
2 350 2.8923025
2 400 2.904101
2 450 2.8907034
2 500 2.8906848
2 550 2.8873181
2 600 2.8805869
2 650 2.8794146
2 700 2.8790338
2 750 2.8782744
2 800 2.878731
2 850 2.8831854
2 900 2.8802595
2 950 2.8888102
2 1000 2.8699756
3 50 2.8699083
3 100 2.8646445
3 150 2.8553307
3 200 2.863233
3 250 2.8625739
3 300 2.8719995
3 350 2.861191
3 400 2.8535023
3 450 2.851653
3 500 2.864762
3 550 2.8573709
3 600 2.845413
3 650 2.845547
3 700 2.8695822
3 750 2.8478754
3 800 2.851848
3 850 2.8292806
3 900 2.8435218
3 950 2.8360178
3 1000 2.8423874
4 50 2.8462903
4 100 2.8544028
4 150 2.8419824
4 200 2.8480