In [None]:
using Pkg
# Activate the project environment rooted at the current working directory.
# NOTE(review): assumes the notebook is started from the project root — confirm.
Pkg.activate(".")

In [33]:
using MOGMOG
using DLProteinFormats
using Flux
using Onion
using RandomFeatureMaps
using Onion.Einops

In [None]:
# Load the `PDBSimpleFlat500` dataset via DLProteinFormats and split out, for every
# entry, its residue location array and its amino-acid sequence.
data = DLProteinFormats.load(PDBSimpleFlat500)

locations = [entry.locs for entry in data]
sequences = [entry.AAs for entry in data]

In [None]:
# Flatten the first entry's location array into a vector for quick inspection.
locations[1] |> vec

```julia
"""
    Toy2(dim, depth; rff_dim = 64)

Toy model mapping residue locations to per-position amino-acid logits:
random Fourier features of the 3-D locations → linear projection to `dim` →
`depth` transformer blocks (8 heads each) → linear decoder to 20 logits.

`rff_dim` sets the Fourier-feature width (64 reproduces the original hard-coded value).
"""
struct Toy2{L}
    layers::L  # NamedTuple of sub-layers built by the `Toy2(dim, depth)` constructor
end

# Register with Flux so `Flux.setup` / gradients see the parameters inside `layers`.
Flux.@layer Toy2

function Toy2(dim, depth; rff_dim = 64)
    layers = (;
        loc_rff = RandomFourierFeatures(3 => rff_dim, 0.1f0),
        loc_encoder = Dense(rff_dim => dim, bias=false),
        transformers = [Onion.TransformerBlock(dim, 8) for _ in 1:depth],
        AA_decoder = Dense(dim => 20, bias=false),
    )
    return Toy2(layers)
end

function (m::Toy2)(locs)
    l = m.layers
    # NOTE(review): assumes `locs` carries the 3 coordinates along its first dimension — confirm.
    x = l.loc_encoder(l.loc_rff(locs))
    for layer in l.transformers
        # Arguments mirror the original call; verify against Onion's TransformerBlock signature.
        x = layer(x, 0, nothing)
    end
    return l.AA_decoder(x)  # 20 logits per position
end
```

In [40]:
using NNlib

In [73]:
# Dimensions for the coordinate → Fourier features → transformer pipeline.
rff_dim = 32
embedding_dim = 64
n_blocks = 10

# Input dimension 1: each scalar coordinate is featurised as its own token
# (the forward-pass cell reshapes coordinates to (1, K, L) before calling `rff`).
rff = RandomFourierFeatures(1 => rff_dim, 0.1f0)
pre_transformer_proj = Dense(rff_dim => embedding_dim, bias=false)
# Block width is tied to `embedding_dim` so it cannot drift out of sync with the projection.
transformer_blocks = [DART(TransformerBlock(embedding_dim, 1, 1)) for _ in 1:n_blocks]



10-element Vector{DART{TransformerBlock{Attention{Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}}, StarGLU{Dense{typeof(identity), Matrix{Float32}, Bool}, typeof(swish)}, RMSNorm{Float32, Vector{Float32}}, RMSNorm{Float32, Vector{Float32}}}}}:
 DART{TransformerBlock{Attention{Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}}, StarGLU{Dense{typeof(identity), Matrix{Float32}, Bool}, typeof(swish)}, RMSNorm{Float32, Vector{Float32}}, RMSNorm{Float32, Vector{Float32}}}}(TransformerBlock{Attention{Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}}, StarGLU{Dense{typeof(id

In [74]:
# Drop the singleton middle axis of the (K, 1, L) location array, transform the
# molecule, then view every scalar coordinate as a 1-D token of shape (1, K, L).
input_coordinates = rearrange(locations[1], (:K, 1, :L) --> (:K, :L))
rotated_coordinates = transform_molecule(input_coordinates)
coordinate_tokens = rearrange(rotated_coordinates, (:K, :L) --> (1, :K, :L))
clock_tokens = rff(coordinate_tokens)
embeddings = pre_transformer_proj(clock_tokens)
for block in transformer_blocks
    # Apply each block exactly once, in order.
    embeddings = block(embeddings)
end
embeddings

64×3×284 Array{Float32, 3}:
[:, :, 1] =
 -2407.14   -2562.78     -937.067
  1286.41    1602.41     1029.42
 -1107.37    -540.147   -1204.58
  -736.283   -315.754    -683.237
 -2334.51   -1678.19    -1276.43
  3111.12    3049.13     1381.32
  -273.528   -723.555    -542.492
  -748.07    -811.119    -315.506
   637.612    -45.3017   -680.749
   628.865    433.497    -279.523
     ⋮                  
  -441.058    626.187   -1134.05
   702.076    842.322    1670.47
  2093.78    1814.72     1226.65
 -2099.05   -2215.23    -2761.29
  -899.831   -763.283    -875.541
  -339.231   -159.238    -372.918
 -2202.51   -2224.38     -972.282
  -278.958    320.072     268.213
 -1407.58   -1092.7      -549.213

[:, :, 2] =
 -1920.8    -2014.27     -925.375
  1322.87    1584.71      970.839
  -600.665   -331.768   -1212.94
  -211.425    -80.7043   -705.439
 -1695.49   -1455.84    -1326.92
  2645.61    2690.94     1319.66
  -398.563   -688.619    -630.396
  -737.094   -720.233    -265.68
   486.908     4

In [None]:
# Display what Onion's `causal_mask` returns for a random 5×5 input.
rand(5, 5) |> Onion.causal_mask

In [None]:
###foot

# Parameters for the MOGMOG model.
# Plain (non-`const`) globals on purpose: an earlier cell already binds `rff_dim` as an
# ordinary global (so `const rff_dim` would fail), and the training cell reassigns
# `embed_dim` and `vocab_size`.
rff_dim = 32               # Random Fourier feature dimension
embed_dim = 64             # Final transformer input dim
vocab_size = 10            # Number of atom types (incl. STOP)
# NOTE(review): the training cell's `atom_dict` has 6 entries — confirm the intended vocabulary.

# -- Modules for the foot --
const RFF = RandomFourierFeatures(3 => rff_dim, 0.1f0)  # 3D input
const CoordProj = Dense(rff_dim => embed_dim, bias=false)
const AtomEmbed = Embedding(vocab_size => embed_dim)
###################################################################################

"""
    CoordEmbed(embed_dim::Int)

Embed 3-D coordinates with a two-layer MLP: `Dense(3 => embed_dim, relu)` followed by
`Dense(embed_dim => embed_dim)`.

Calling `ce(coords)` on a `3 × L` matrix returns an `embed_dim × L` matrix. Flux layers
treat dimension 1 as features and dimension 2 as the batch, so each column (one point)
is embedded independently.
"""
struct CoordEmbed{C<:Chain}
    mlp::C  # multi-layer perceptron (MLP), applied column-wise
end

# Register with Flux so `Flux.setup` / gradients see the MLP's parameters.
Flux.@layer CoordEmbed

function CoordEmbed(embed_dim::Int)
    return CoordEmbed(Chain(
        Dense(3 => embed_dim, relu),
        Dense(embed_dim => embed_dim),
    ))
end

function (ce::CoordEmbed)(coords::AbstractMatrix{<:Real})  # (3, L)
    # No transpose: Dense already expects (features, batch) = (3, L), giving (embed_dim, L).
    return ce.mlp(coords)
end
###################################################################################

"""
    encode_prefix(positions, atom_types, ce::CoordEmbed, ae::Embedding)

Embed a prefix of `L` atoms as the elementwise sum of a coordinate embedding and an
atom-type embedding.

# Arguments
- `positions`: `3 × L` matrix of coordinates, one column per atom.
- `atom_types`: length-`L` vector of integer atom-type indices, looked up in `ae`.
- `ce`: coordinate embedder, expected to map `3 × L` → `embed_dim × L`.
- `ae`: Flux `Embedding`, mapping the `L` indices to `embed_dim × L`.

# Returns
An `embed_dim × L` matrix, `ce(positions) .+ ae(atom_types)`.

# Throws
`DimensionMismatch` if `positions` does not have one column per entry of `atom_types`.
"""
function encode_prefix(positions::AbstractMatrix{<:Real}, atom_types::AbstractVector{<:Integer},
                       ce::CoordEmbed, ae::Embedding)
    # Checked explicitly: if either side has L == 1, `.+` would broadcast silently
    # instead of reporting the mismatch.
    size(positions, 2) == length(atom_types) || throw(DimensionMismatch(
        "positions has $(size(positions, 2)) columns but atom_types has $(length(atom_types)) entries"))

    coord_embed = ce(positions)         # (embed_dim, L)
    type_embed = ae(atom_types)         # (embed_dim, L)

    return coord_embed .+ type_embed    # Combine the two
end



In [None]:
###TRAINING loop

using Flux, JLD2, StatsBase, Random, Plots
losses_over_time = Float32[]

# Load processed molecules
@load "processed_molecules.jld2" result  # Loads `result::Vector{Molecule}`

# Atom dictionary: atom => integer index
atom_dict = Dict("C"=>1, "O"=>2, "N"=>3, "H"=>4, "F"=>5, "STOP"=>6)  # Customize this

# Model
embed_dim = 64
n_components = 5
vocab_size = length(atom_dict)
model = MOGMOGModel(embed_dim, n_components, vocab_size)

# Optimizer
opt_state = Flux.setup(AdamW(0.001), model)

# Hyperparameters
nepochs = 10
nbatches = 1000  # maximum batches per epoch; fewer when `result` holds < nbatches * batchsize molecules
batchsize = 16

"""
    batch_loss(m, mols, atom_dict)

Mean of `loss_fn(m, mol, atom_dict)` over the molecules in `mols`.

Deliberately free of `try`/`catch` and `push!`: Zygote cannot compile `try`/`catch`, and it
throws when differentiating through array mutation, so either would break the gradient.
"""
batch_loss(m, mols, atom_dict) = sum(map(mol -> loss_fn(m, mol, atom_dict), mols)) / length(mols)

# Training loop
for epoch in 1:nepochs  # one pass over the shuffled data per epoch
    @info "Epoch $epoch"
    shuffled = shuffle(result)
    total_loss = 0.0
    nsteps = 0  # batches that actually produced an update this epoch

    # `partition` yields only non-empty batches (the last one may be short), so we never
    # slice past the end of `shuffled` and never take the mean of an empty batch (NaN).
    batches = Iterators.take(Iterators.partition(shuffled, batchsize), nbatches)
    for (i, mols) in enumerate(batches)
        # Loss value and gradient w.r.t. the model in a single forward/backward pass.
        # Error handling stays outside the differentiated function; a failing batch is
        # skipped (best effort) rather than aborting the whole run.
        outcome = try
            Flux.withgradient(m -> batch_loss(m, mols, atom_dict), model)
        catch e
            @warn "Skipping batch $i: $e"
            nothing
        end
        outcome === nothing && continue
        loss_batch, grads = outcome

        Flux.update!(opt_state, model, grads[1])
        push!(losses_over_time, loss_batch)  # recorded outside AD, once per successful batch
        total_loss += loss_batch
        nsteps += 1

        if i % 50 == 0
            println("Batch $i, loss = $(round(loss_batch, digits=4))")
        end
    end

    # Average over batches actually trained on, not the nominal `nbatches`.
    println("Epoch $epoch done. Avg loss = $(round(total_loss / max(nsteps, 1), digits=4))")
end

plot(losses_over_time, xlabel="Batch", ylabel="Loss", label="Training Loss", title="Final Loss Curve")
savefig("final_loss_curve.png")

