In [1]:
using Pkg
Pkg.activate("..")

[32m[1m  Activating[22m[39m project at `~/.julia/dev/MOGMOG.jl`


In [3]:
using MOGMOG
using DLProteinFormats
using Flux
using Onion
using RandomFeatureMaps
using Onion.Einops

In [4]:
using Flux, JLD2, Random

struct Molecule
    atoms::Vector{String}
    positions::Matrix{Float64}
end

Base.length(mol::Molecule) = length(mol.atoms)

# Load processed molecules
# @load expanduser("~/processed_molecules.jld2") result  # Loads `result::Vector{Molecule}`


In [6]:
atom_dict = Dict(name => i for (i, name) in enumerate(["C", "F", "H", "N", "O", "STOP"]))
PAD = atom_dict["STOP"]
vocab_size = length(atom_dict)

6

In [29]:

# Model
embed_dim = 128
n_components = 5
vocab_size = length(atom_dict)
model = MOGMOGModel(embed_dim, n_components, vocab_size, depth=4)

# Optimizer
opt_state = Flux.setup(AdamW(0.001f0), model)


(foot = (current_coord_embed = (layers = ((W = (),), (weight = [32mLeaf(AdamW(eta=0.001, beta=(0.9, 0.999), lambda=0.0, epsilon=1.0e-8, couple=true), [39m(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))[32m)[39m, bias = [32mLeaf(AdamW(eta=0.001, beta=(0.9, 0.999), lambda=0.0, epsilon=1.0e-8, couple=true), [39m(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))[32m)[39m, σ = ())),), atom_embed = (weight = [32mLeaf(AdamW(eta=0.001, beta=(0.9, 0.999), lambda=0.0, epsilon=1.0e-8, couple=true), [39m(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0

In [30]:
sum(length.(Flux.trainables(model)))

1119119

In [None]:


# Pad function
function pad_batch(mols::Vector{Molecule})
    max_len = maximum(length.(mols)) + 1  # Find the maximum length of molecules
    B = length(mols)
    atom_types = fill(PAD, max_len, B)
    coordinates = zeros(Float32, 3, max_len, B) .+ 0
    atom_type_mask = zeros(Float32, max_len - 1, B)
    coordinate_mask = zeros(Float32, max_len - 1, B)

    for (i, mol) in enumerate(mols)
        L = length(mol)
        for j in 1:L
            atom_types[j, i] = get(atom_dict, mol.atoms[j], PAD)
        end
        coordinates[:, 1:L, i] = mol.positions[1:L, :]'
        coordinate_mask[1:L-1, i] .= 1.0
        atom_type_mask[1:L, i] .= 1.0
    end
    return atom_types, coordinates, atom_type_mask, coordinate_mask
end

pad_batch (generic function with 2 methods)

In [39]:
result = [
    Molecule(["C", "H", "H", "H", "H"], randn(5, 3)),
    Molecule(["C", "F", "C"], randn(3, 3)),
]

pad_batch(result) .|> display;

6×2 Matrix{Int64}:
 1  1
 3  2
 3  1
 3  6
 3  6
 6  6

3×6×2 Array{Float32, 3}:
[:, :, 1] =
  0.158273   1.62107   1.8172    1.27361   0.828848  0.0
  1.18803   -1.0081    0.363998  1.05924  -0.846691  0.0
 -0.484876  -1.13739  -1.07275   1.26865   1.18528   0.0

[:, :, 2] =
 -1.48635   -0.443743  1.24576   0.0  0.0  0.0
 -0.355486  -0.140105  0.325522  0.0  0.0  0.0
 -0.925501  -0.544627  0.782694  0.0  0.0  0.0

5×2 Matrix{Float32}:
 1.0  1.0
 1.0  1.0
 1.0  1.0
 1.0  0.0
 1.0  0.0

5×2 Matrix{Float32}:
 1.0  1.0
 1.0  1.0
 1.0  0.0
 1.0  0.0
 0.0  0.0

In [69]:
loss(model, pad_batch(result)...)

coordinates = zeros(Float32, 3, max_len, B) .+ 0 = Float32[0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0;;; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0]


15.972348943862457

In [71]:
loss(model, pad_batch(result)...)

coordinates = zeros(Float32, 3, max_len, B) .+ 0 = Float32[0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0;;; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0]


15.86637760749771

In [None]:
nepochs = 1
nbatches = 100
batchsize = 16
max_len = 32
all_losses = []

# Training loop
for epoch in 1:nepochs
    @info "Epoch $epoch"
    shuffled = shuffle(result)
    epoch_losses = []

    for i in 1:nbatches
        mols = shuffled[i:i+batchsize-1]
        atom_ids, positions, atom_mask, coordinate_mask = pad_batch(mols, atom_dict, max_len)

        loss_batch, (grad_model,) = Flux.withgradient(model) do m
            loss(m, atom_ids, positions, atom_mask, coordinate_mask)
        end

        Flux.update!(opt_state, model, grad_model)
        push!(all_losses, loss_batch)
        push!(epoch_losses, loss_batch)

        println("Batch $i, loss = $(round(loss_batch, digits=4))")
    end

    println("Epoch $epoch done. Avg loss = $(round(mean(epoch_losses), digits=4))")
end


In [None]:
data = DLProteinFormats.load(PDBSimpleFlat500)

locations = [d.locs for d in data]
sequences = [d.AAs for d in data]

In [None]:
vec(locations[1])

In [73]:
rff_dim = 32
embedding_dim = 64

rff = RandomFourierFeatures(1 => rff_dim, 0.1f0)
pre_transformer_proj = Dense(rff_dim => embedding_dim, bias=false)
transformer_blocks = [DART(TransformerBlock(64, 1, 1)) for i in 1:10]



10-element Vector{DART{TransformerBlock{Attention{Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}}, StarGLU{Dense{typeof(identity), Matrix{Float32}, Bool}, typeof(swish)}, RMSNorm{Float32, Vector{Float32}}, RMSNorm{Float32, Vector{Float32}}}}}:
 DART{TransformerBlock{Attention{Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}}, StarGLU{Dense{typeof(identity), Matrix{Float32}, Bool}, typeof(swish)}, RMSNorm{Float32, Vector{Float32}}, RMSNorm{Float32, Vector{Float32}}}}(TransformerBlock{Attention{Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}, Dense{typeof(identity), Matrix{Float32}, Bool}}, StarGLU{Dense{typeof(id

In [74]:
input_coordinates = rearrange(locations[1], (:K, 1, :L) --> (:K, :L))
rotated_coordinates = transform_molecule(input_coordinates)
coordinate_tokens = rearrange(rotated_coordinates, (:K, :L) --> (1, :K, :L))
clock_tokens = rff(coordinate_tokens)
embeddings = pre_transformer_proj(clock_tokens)
for block in transformer_blocks
    embeddings = transformer_block(embeddings)
end
embeddings

64×3×284 Array{Float32, 3}:
[:, :, 1] =
 -2407.14   -2562.78     -937.067
  1286.41    1602.41     1029.42
 -1107.37    -540.147   -1204.58
  -736.283   -315.754    -683.237
 -2334.51   -1678.19    -1276.43
  3111.12    3049.13     1381.32
  -273.528   -723.555    -542.492
  -748.07    -811.119    -315.506
   637.612    -45.3017   -680.749
   628.865    433.497    -279.523
     ⋮                  
  -441.058    626.187   -1134.05
   702.076    842.322    1670.47
  2093.78    1814.72     1226.65
 -2099.05   -2215.23    -2761.29
  -899.831   -763.283    -875.541
  -339.231   -159.238    -372.918
 -2202.51   -2224.38     -972.282
  -278.958    320.072     268.213
 -1407.58   -1092.7      -549.213

[:, :, 2] =
 -1920.8    -2014.27     -925.375
  1322.87    1584.71      970.839
  -600.665   -331.768   -1212.94
  -211.425    -80.7043   -705.439
 -1695.49   -1455.84    -1326.92
  2645.61    2690.94     1319.66
  -398.563   -688.619    -630.396
  -737.094   -720.233    -265.68
   486.908     4

In [None]:
Onion.causal_mask(rand(5,5))

In [None]:
###foot

# Parameters for the MOGMOG model
const rff_dim = 32               # Random Fourier feature dimension
const embed_dim = 64             # Final transformer input dim
const vocab_size = 10            # Number of atom types (incl. STOP)

# -- Modules for the foot --
const RFF = RandomFourierFeatures(3 => rff_dim, 0.1f0)  # 3D input
const CoordProj = Dense(rff_dim => embed_dim, bias=false)
const AtomEmbed = Embedding(vocab_size, embed_dim)
###################################################################################3

struct CoordEmbed 
    mlp::Chain #this is a multi-layer perceptron (MLP)
end

function CoordEmbed(embed_dim::Int)
    return CoordEmbed(Chain(
        Dense(3, embed_dim, relu),
        Dense(embed_dim, embed_dim)
    ))
end

function (ce::CoordEmbed)(coords::Matrix{Float64})  # (3, L)
    return ce.mlp(coords')'  # Transpose for batch → return (embed_dim, L)
end 
###################################################################################3

function encode_prefix(positions::Matrix{Float64}, atom_types::Vector{Int}, ce::CoordEmbed, ae::Embedding)
    # positions: (3, L)
    # atom_types: (L,)

    coord_embed = ce(positions)         # (embed_dim, L)
    type_embed = ae(atom_types)         # (embed_dim, L)

    return coord_embed .+ type_embed    # Combine the two
end



In [9]:
using Pkg
Pkg.activate("..")

using Flux, JLD2, Random, Statistics, MOGMOG, Onion, RandomFeatureMaps, Onion.Einops

include("loss.jl")
include("MOGhead.jl")
include("MOGfoot.jl")
include("model.jl")
include("utilities.jl")

export logpdf_MOG, MoGAxisHead, MOGMOGModel, transform_molecule, loss

# Load molecules
@load "processed_molecules.jld2" result

# Atom dictionary
atom_dict = Dict(name => i for (i, name) in enumerate(["C", "F", "H", "N", "O", "STOP"]))
PAD = atom_dict["STOP"]

# Hyperparameters
embed_dim = 64
n_components = 5
vocab_size = length(atom_dict)
depth = 4
max_len = 32
batchsize = 4
nbatches = 1000
nepochs = 100

# Initialize model and optimizer
model = MOGMOGModel(embed_dim, n_components, vocab_size, depth=depth)
opt_state = Flux.setup(AdamW(0.001f0), model)

# Padding function
function pad_batch(mols::Vector{Molecule})
    B = length(mols)
    atom_ids = fill(PAD, max_len, B)
    positions = zeros(Float32, 3, max_len, B)
    atom_mask = zeros(Float32, max_len - 1, B)
    coord_mask = zeros(Float32, max_len - 1, B)

    for (i, mol) in enumerate(mols)
        L = min(length(mol.atoms), max_len - 1)
        for j in 1:L
            atom_ids[j, i] = atom_dict[mol.atoms[j]]
            positions[:, j, i] = Float32.(mol.positions[j, :])
            atom_mask[j, i] = 1.0
            coord_mask[j, i] = 1.0
        end
        atom_ids[L+1, i] = PAD
    end
    return atom_ids, positions, atom_mask, coord_mask
end

# Training loop
all_losses = Float32[]
for epoch in 1:nepochs
    println("Epoch $epoch")
    shuffled = shuffle(result)

    for i in 1:nbatches
        mols = shuffled[i:i+batchsize-1]
        atom_ids, pos, atom_mask, coord_mask = pad_batch(mols)

        loss_val, (grad,) = Flux.withgradient(model) do m
            loss(m, atom_ids, pos, atom_mask, coord_mask)
        end

        Flux.update!(opt_state, model, grad)
        push!(all_losses, loss_val)

        println("Batch $i, loss = $(round(loss_val, digits=4))")
    end
end

# Save loss plot
using Plots
plot(all_losses, title="Training Loss", xlabel="Batch", ylabel="Loss")
savefig("training_loss.pdf")

# Save checkpoint
@save "mogmodel_checkpoint.jld2" model all_losses


[32m[1m  Activating[22m[39m project at `~/MOGMOG.jl-1`


SystemError: SystemError: opening file "/home/star/MOGMOG.jl-1/Notebook/loss.jl": No such file or directory