In [None]:
using Pkg
Pkg.activate(".")
#Pkg.add(["Flux", "DLProteinFormats", "Onion", "RandomFeatureMaps", "StatsBase", "Plots"])

In [1]:
import Random
using Flux, DLProteinFormats, Onion, RandomFeatureMaps, StatsBase, Plots

In [4]:
# Load the dataset; records expose the `len`, `locs` and `AAs` fields used below.
dat = DLProteinFormats.load(PDBSimpleFlat500);

L = 30  # crop (window) length for every training example
# Indices of records strictly longer than the crop length. A record with len == L would
# also admit exactly one crop; the strict comparison is the original filter and is kept.
train_inds = findall(>(L), dat.len)
 
"""
    random_batch(dat, L, B, filt_inds; rng = Random.default_rng())

Draw a batch of `B` random length-`L` windows. Each window comes from a different
record, and the records are sampled without replacement from `filt_inds`.

Every sampled record must satisfy `len >= L`, because the window start is drawn
uniformly from `1:len - L + 1`. `B` must not exceed `length(filt_inds)`, since records
are sampled without replacement.

# Keywords
- `rng`: random number generator used for both record sampling and window placement.
  The default reproduces the original global-RNG behaviour; pass a seeded RNG for
  reproducible batches.

# Returns
A named tuple `(; locs, AAs)`:
- `locs`: `Float32` array of size `(3, L, B)`. Column block `i` is
  `dat[ind].locs[:, 1, window]`, so only index 1 of the record's second axis is kept
  (presumably a single backbone atom — confirm against DLProteinFormats).
- `AAs`: `Flux.onehotbatch` of the window's residue codes over the labels `1:20`.
  `onehotbatch` errors on codes outside `1:20`.
"""
function random_batch(dat, L, B, filt_inds; rng = Random.default_rng())
    locs = zeros(Float32, 3, L, B)
    AAs = zeros(Int, L, B)
    inds = sample(rng, filt_inds, B; replace = false)
    for (i, ind) in enumerate(inds)
        rec = dat[ind]  # fetch each record once instead of re-indexing `dat` per field
        start = rand(rng, 1:(rec.len - L + 1))
        window = start:(start + L - 1)
        # Views avoid materialising temporary slices before copying into the batch.
        @views locs[:, :, i] .= rec.locs[:, 1, window]
        @views AAs[:, i] .= rec.AAs[window]
    end
    return (; locs, AAs = Flux.onehotbatch(AAs, 1:20))
end

# Draw one example batch of 10 windows (e.g. to inspect shapes); the training loop below
# draws a fresh batch every step and reassigns `batch`.
batch = random_batch(dat, L, 10, train_inds);
 
"""
    Toy0()

Baseline model: a single bias-free `Dense(3 => 20)` that maps each coordinate column
of `locs` directly to 20 amino-acid logits. No information is exchanged between
positions.
"""
struct Toy0{L}
    layers::L
end
Flux.@layer Toy0

Toy0() = Toy0((; AA_decoder = Dense(3 => 20; bias = false)))

# Forward pass: a coordinate array with leading dimension 3 -> logits with leading dimension 20.
(m::Toy0)(locs) = m.layers.AA_decoder(locs)
 
"""
    Toy1(dim, depth)

Transformer model over raw coordinates, built from:
- a bias-free linear embedding `Dense(3 => dim)`,
- `depth` stacked `Onion.TransformerBlock(dim, 8)` layers (the 8 is presumably the
  head count — confirm against Onion),
- a bias-free `Dense(dim => 20)` read-out producing amino-acid logits.
"""
struct Toy1{L}
    layers::L
end
Flux.@layer Toy1

function Toy1(dim, depth)
    # Built in the same order as the named tuple so parameter initialisation draws match.
    encoder = Dense(3 => dim; bias = false)
    blocks = [Onion.TransformerBlock(dim, 8) for _ in 1:depth]
    decoder = Dense(dim => 20; bias = false)
    return Toy1((;
        loc_encoder = encoder,
        transformers = blocks,
        AA_decoder = decoder,
    ))
end

function (m::Toy1)(locs)
    (; loc_encoder, transformers, AA_decoder) = m.layers
    h = loc_encoder(locs)
    for blk in transformers
        # Extra positional args `0` and `nothing` are passed through unchanged; presumably
        # a start position and "no rotary embedding" — confirm against Onion's API.
        h = blk(h, 0, nothing)
    end
    return AA_decoder(h)
end
 
"""
    Toy2(dim, depth)

Transformer model over random Fourier features of the coordinates, built from:
- `RandomFourierFeatures(3 => 64, 0.1f0)` (0.1f0 is presumably the frequency scale —
  confirm against RandomFeatureMaps),
- a bias-free linear embedding `Dense(64 => dim)`,
- `depth` stacked `Onion.TransformerBlock(dim, 8)` layers,
- a bias-free `Dense(dim => 20)` read-out producing amino-acid logits.
"""
struct Toy2{L}
    layers::L
end
Flux.@layer Toy2

function Toy2(dim, depth)
    # Construction order matches the original so random initialisation is unchanged.
    rff = RandomFourierFeatures(3 => 64, 0.1f0)
    encoder = Dense(64 => dim; bias = false)
    blocks = map(_ -> Onion.TransformerBlock(dim, 8), 1:depth)
    decoder = Dense(dim => 20; bias = false)
    return Toy2((;
        loc_rff = rff,
        loc_encoder = encoder,
        transformers = blocks,
        AA_decoder = decoder,
    ))
end

function (m::Toy2)(locs)
    layers = m.layers
    features = layers.loc_rff(locs)
    hidden = layers.loc_encoder(features)
    for block in layers.transformers
        # Same extra arguments as the original call (`0`, `nothing`); see Onion for meaning.
        hidden = block(hidden, 0, nothing)
    end
    return layers.AA_decoder(hidden)
end
 
 
# Pick the model to train; the commented lines are the simpler baselines.
#model = Toy0()
#model = Toy1(64, 4)
model = Toy2(64, 4)
opt_state = Flux.setup(AdamW(eta = 0.001), model)

# Training schedule (a longer run used 100 epochs of 10_000 steps).
n_epochs = 20
steps_per_epoch = 1_000
batch_size = 10
log_every = 50    # print and record the mean loss over this many steps
plot_every = 500  # re-save the loss curve this often
# Name the plot after the model type (Toy2 -> "losses_toy2.pdf") so switching to another
# model does not overwrite a previous run's curve with a mislabeled one.
loss_plot_path = "losses_$(lowercase(string(nameof(typeof(model))))).pdf"

losses = Float32[]  # one entry per `log_every` steps: mean training loss over that span
for epoch in 1:n_epochs
    tot_loss = 0f0
    for i in 1:steps_per_epoch
        batch = random_batch(dat, L, batch_size, train_inds);
        step_loss, grad = Flux.withgradient(model) do m
            aalogits = m(batch.locs)
            Flux.logitcrossentropy(aalogits, batch.AAs)
        end
        Flux.update!(opt_state, model, grad[1])
        tot_loss += step_loss
        if mod(i, log_every) == 0
            mean_loss = tot_loss / log_every
            println(epoch, " ", i, " ", mean_loss)
            push!(losses, mean_loss)
            tot_loss = 0f0
        end
        (mod(i, plot_every) == 0) && savefig(plot(losses), loss_plot_path)
    end
end


1 50 3.4911306
1 100 3.0472863
1 150 3.01849
1 200 2.9722974
1 250 2.9606898
1 300 2.958691
1 350 2.9460843
1 400 2.939723
1 450 2.930393
1 500 2.931485
1 550 2.9275763
1 600 2.918871
1 650 2.909648
1 700 2.911985
1 750 2.901052
1 800 2.9014356
1 850 2.9049335
1 900 2.8998847
1 950 2.8961618
1 1000 2.899681
2 50 2.8922176
2 100 2.8984327
2 150 2.889498
2 200 2.898843
2 250 2.8897498
2 300 2.8883402
2 350 2.88637
2 400 2.8824177
2 450 2.876782
2 500 2.883402
2 550 2.8808005
2 600 2.8742821
2 650 2.8735802
2 700 2.870503
2 750 2.867412
2 800 2.8741212
2 850 2.87578
2 900 2.8682227
2 950 2.8738542
2 1000 2.8694134
3 50 2.8631284
3 100 2.8688242
3 150 2.8584378
3 200 2.8599274
3 250 2.8522646
3 300 2.8653042
3 350 2.8578463
3 400 2.8551874
3 450 2.8587842
3 500 2.8461316
3 550 2.8598113
3 600 2.8554175
3 650 2.8587558
3 700 2.8481321
3 750 2.8597116
3 800 2.8406458
3 850 2.8399584
3 900 2.846041
3 950 2.8295605
3 1000 2.8429916
4 50 2.8316183
4 100 2.8363495
4 150 2.8384233
4 200 2.8223884