In [2]:
using Flux, Onion, DLProteinFormats

In [40]:
using StatsBase # gives sample()
using Plots # gives savefig()

In [44]:
# Load the flat protein dataset and select every chain long enough to supply at least one
# training window of `L` residues.
dat = DLProteinFormats.load(PDBSimpleFlat500)

L = 30  # window length: residues per training example

# `>=` rather than `>`: a chain of exactly `L` residues still yields one valid window
# (random_batch draws the window start from `1:len - L + 1`, which is `1:1` here).
# NOTE(review): every qualifying chain is used for training; there is no held-out split.
train_inds = findall(dat.len .>= L)
 
"""
    random_batch(dat, L, B, filt_inds)

Draw `B` indices from `filt_inds` without replacement and, for each corresponding record
of `dat`, cut a uniformly random contiguous window of `L` residues.

Returns a named tuple `(; locs, AAs)`:
- `locs`: `Float32` array of size `(3, L, B)`; column block `i` is
  `rec.locs[:, 1, window]` of the `i`-th sampled record (only the first entry along the
  second axis of each record's `locs` is kept — presumably one backbone atom per
  residue, TODO confirm).
- `AAs`: the window's residue codes, one-hot encoded over the classes `1:20`.

Every index in `filt_inds` is expected to refer to a record with `len >= L`.

# Throws
- `ArgumentError` if a sampled record is shorter than `L`.
"""
function random_batch(dat, L, B, filt_inds)
    locs = zeros(Float32, 3, L, B)
    inds = sample(filt_inds, B, replace=false)
    AAs = zeros(Int, L, B)
    for (i, ind) in enumerate(inds)
        # Index once: `dat[ind]` may rebuild a record on every access, and it is read
        # several times below.
        rec = dat[ind]
        rec.len >= L || throw(ArgumentError(
            "record $ind has length $(rec.len), shorter than the window length $L"))
        start = rand(1:rec.len - L + 1)
        window = start:start + L - 1
        # Views avoid allocating temporary copies of the slices before writing them.
        @views locs[:, :, i] .= rec.locs[:, 1, window]
        @views AAs[:, i] .= rec.AAs[window]
    end
    return (; locs, AAs = Flux.onehotbatch(AAs, 1:20))
end

# Draw one example batch of windows (presumably to inspect shapes before training; the
# training loop below draws its own batches). The trailing `;` suppresses notebook display.
batch = let batch_size = 10
    random_batch(dat, L, batch_size, train_inds)
end;

In [67]:
"""
    SimpleModel{T}

Wrapper whose single field `layers` holds the model's sub-layers. The
`SimpleModel(dim, depth)` constructor stores a named tuple with `encoder`,
`hidden_layers` and `decoder` here. The type parameter keeps the field concretely typed.
"""
struct SimpleModel{T}
    layers::T
end
# Register the type with Flux so `Flux.setup`, `withgradient` and `update!` can walk into
# `layers` and reach the trainable parameters.
Flux.@layer SimpleModel
"""
    SimpleModel(dim::Int, depth::Int; nheads::Int = 8)

Build the toy model:
- a bias-free linear `encoder` mapping 3 input features to `dim`;
- `depth` `Onion.TransformerBlock(dim, nheads)` layers in `hidden_layers`;
- a bias-free linear `decoder` mapping `dim` to 20 outputs (one per class used by the
  one-hot targets).

`nheads` is passed as the second argument to `Onion.TransformerBlock`, presumably the
number of attention heads. Its default of 8 reproduces the previously hard-coded value.
NOTE(review): multi-head attention usually needs `dim` divisible by `nheads` — confirm
against Onion.
"""
function SimpleModel(dim::Int, depth::Int; nheads::Int = 8)
    layers = (;
        encoder = Dense(3 => dim, bias=false),
        hidden_layers = [Onion.TransformerBlock(dim, nheads) for _ in 1:depth],
        decoder = Dense(dim => 20, bias=false),
    )
    return SimpleModel(layers)
end

# Forward pass: encode `coords`, thread the hidden state through every block of
# `hidden_layers` in order, then decode to per-class logits.
# Each block is called as `block(h, 0, nothing)`. NOTE(review): presumably a start
# position of 0 and no rotary embedding — confirm against `Onion.TransformerBlock`.
function (m::SimpleModel)(coords)
    (; encoder, hidden_layers, decoder) = m.layers
    h = encoder(coords)
    for block in hidden_layers
        h = block(h, 0, nothing)
    end
    return decoder(h)
end

In [68]:
# Train the toy model with AdamW on freshly sampled random windows.
# Every `log_every` steps: print "epoch step mean_loss" and append that mean to `losses`.
# Every `plot_every` steps: re-plot the running loss curve to custom_toy.pdf.
model = SimpleModel(64, 4)
opt_state = Flux.setup(AdamW(eta=0.001), model)

n_epochs = 20             # the original notes 100 as an alternative
steps_per_epoch = 1_000   # the original notes 10_000 as an alternative
batch_size = 10
log_every = 50            # one name for the window used by the modulus and the average
plot_every = 500

losses = Float32[]
for epoch in 1:n_epochs
    tot_loss = 0f0
    for i in 1:steps_per_epoch
        batch = random_batch(dat, L, batch_size, train_inds);
        l, grad = Flux.withgradient(model) do m
            aalogits = m(batch.locs)
            Flux.logitcrossentropy(aalogits, batch.AAs)
        end
        # withgradient returns one gradient per argument; grad[1] belongs to `model`.
        Flux.update!(opt_state, model, grad[1])
        tot_loss += l
        if mod(i, log_every) == 0
            avg_loss = tot_loss / log_every
            println(epoch, " ", i, " ", avg_loss)
            push!(losses, avg_loss)
            tot_loss = 0f0
        end
        (mod(i, plot_every) == 0) && savefig(plot(losses), "custom_toy.pdf")
    end
end



1 50 3.4297
1 100 3.0634942
1 150 2.9956083
1 200 2.988072
1 250 2.9680076
1 300 2.9576998
1 350 2.952713
1 400 2.9547837
1 450 2.9388642
1 500 2.931306
1 550 2.931523
1 600 2.9228666
1 650 2.921827
1 700 2.915366
1 750 2.914993
1 800 2.918815
1 850 2.9043837
1 900 2.9029422
1 950 2.8820863
1 1000 2.8963957
2 50 2.905801
2 100 2.902774
2 150 2.9038727
2 200 2.8814166
2 250 2.8916647
2 300 2.894657
2 350 2.8968961
2 400 2.8816602
2 450 2.890688
2 500 2.8896322
2 550 2.8832023
2 600 2.8814785
2 650 2.8737054
2 700 2.894975
2 750 2.874523
2 800 2.8723621
2 850 2.8695636
2 900 2.8700352
2 950 2.8704917
2 1000 2.8554022
3 50 2.8715944
3 100 2.8681583
3 150 2.8546772
3 200 2.8627114
3 250 2.8590121
3 300 2.8450654
3 350 2.8441718
3 400 2.860211
3 450 2.8468268
3 500 2.8538435
3 550 2.8608017
3 600 2.848368
3 650 2.8359292
3 700 2.86243
3 750 2.8488085
3 800 2.8577569
3 850 2.8509333
3 900 2.8499305
3 950 2.8492718
3 1000 2.831536
4 50 2.8428333
4 100 2.8122647
4 150 2.83724
4 200 2.8214958
4