## Set up dependencies

In [11]:
using Pkg;
Pkg.activate(".")

Pkg.add(["ProteinChains", "BSON", "Flux", "ConcreteStructs", "ChainRulesCore", "Einops", "RandomFeatureMaps", "Onion"])

using ProteinChains, BSON, Flux, ConcreteStructs, ChainRulesCore, Einops, RandomFeatureMaps
using Flux: onehotbatch, mse, AdamW
includet("../src/naive.jl")

[32m[1m  Activating[22m[39m project at `~/Desktop/SOFO/TranslationalEquivariance/SideChainTransformer/scripts`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/Desktop/SOFO/TranslationalEquivariance/SideChainTransformer/scripts/Project.toml`
[32m[1m  No Changes[22m[39m to `~/Desktop/SOFO/TranslationalEquivariance/SideChainTransformer/scripts/Manifest.toml`


## Data

In [2]:
ATOMNAMES = String["C", "CA", "CB", "CD", "CD1", "CD2", "CE", "CE1", "CE2", "CE3", "CG", "CG1", "CG2", "CH2", "CZ", "CZ2", "CZ3", "H", "H1", "H2", "H3", "HA", "HA2", "HA3", "HB", "HB1", "HB2", "HB3", "HD1", "HD11", "HD12", "HD13", "HD2", "HD21", "HD22", "HD23", "HD3", "HE", "HE1", "HE2", "HE21", "HE22", "HE3", "HG", "HG1", "HG11", "HG12", "HG13", "HG2", "HG21", "HG22", "HG23", "HG3", "HH", "HH11", "HH12", "HH2", "HH21", "HH22", "HZ", "HZ1", "HZ2", "HZ3", "N", "ND1", "ND2", "NE", "NE1", "NE2", "NH1", "NH2", "NZ", "O", "OD1", "OD2", "OE1", "OE2", "OG", "OG1", "OH", "OXT", "SD", "SG"]
AA_TO_ATOMS = Dict("Q" => ["C", "CA", "CB", "CD", "CG", "N", "NE2", "O", "OE1"], "W" => ["C", "CA", "CB", "CD1", "CD2", "CE2", "CE3", "CG", "CH2", "CZ2", "CZ3", "N", "NE1", "O"], "T" => ["C", "CA", "CB", "CG2", "N", "O", "OG1"], "P" => ["C", "CA", "CB", "CD", "CG", "N", "O"], "C" => ["C", "CA", "CB", "N", "O", "SG"], "V" => ["C", "CA", "CB", "CG1", "CG2", "N", "O"], "L" => ["C", "CA", "CB", "CD1", "CD2", "CG", "N", "O"], "M" => ["C", "CA", "CB", "CE", "CG", "N", "O", "SD"], "N" => ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"], "H" => ["C", "CA", "CB", "CD2", "CE1", "CG", "N", "ND1", "NE2", "O"], "A" => ["C", "CA", "CB", "N", "O"], "X" => ["C", "CA", "CB", "N", "O"], "D" => ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"], "G" => ["C", "CA", "N", "O"], "E" => ["C", "CA", "CB", "CD", "CG", "N", "O", "OE1", "OE2"], "Y" => ["C", "CA", "CB", "CD1", "CD2", "CE1", "CE2", "CG", "CZ", "N", "O", "OH"], "I" => ["C", "CA", "CB", "CD1", "CG1", "CG2", "N", "O"], "S" => ["C", "CA", "CB", "N", "O", "OG"], "K" => ["C", "CA", "CB", "CD", "CE", "CG", "N", "NZ", "O"], "R" => ["C", "CA", "CB", "CD", "CG", "CZ", "N", "NE", "NH1", "NH2", "O"], "F" => ["C", "CA", "CB", "CD1", "CD2", "CE1", "CE2", "CG", "CZ", "N", "O"])

BSON.@load "../data/pdb500_allatom.bson" allatom_dataset;

In [3]:
allatom_dataset[1] |> propertynames

(:chainids, :AAs, :backbone_xyz, :atom_xyz, :atom_res, :atom_name)

In [3]:
calc_dist(p1, p2) = Float32(sqrt(sum((p1 .- p2).^2)))

function prep_data(protein)
    maskcalpha = protein.atom_name .== "CA"

    atom_xyz = Float32.(protein.atom_xyz[:, maskcalpha])
    atom_res = protein.atom_res[maskcalpha]

    atom_name = Float32.(onehotbatch(protein.atom_name[maskcalpha], ATOMNAMES))
    AAs_res = Float32.(onehotbatch(protein.AAs, ProteinChains.AMINOACIDS))
    AAs_atom = Float32.(onehotbatch(protein.AAs[atom_res], ProteinChains.AMINOACIDS))
    
    backbone_xyz = Float32.(protein.backbone_xyz[:, 2, :])

    backbone_distanceincrement = map(calc_dist, eachcol(backbone_xyz[:, 2:end]), eachcol(backbone_xyz[:, 1:end-1]))
    backbone_distanceincrement = [backbone_distanceincrement; zero(Float32)]
    backbone_distanceincrement = reshape(backbone_distanceincrement, 1, size(backbone_distanceincrement)...)

    atom_calpha_distance = map(calc_dist, eachcol(atom_xyz), eachcol(backbone_xyz[:, atom_res]))
    atom_calpha_distance = reshape(atom_calpha_distance, 1, size(atom_calpha_distance)...)

    return (; AAs_res, AAs_atom, backbone_distanceincrement, atom_calpha_distance, atom_xyz, backbone_xyz, atom_name)
end

prep_data (generic function with 1 method)

In [4]:
function prep_batch(data)
    data = map(prep_data, data)
    
    AAs_res = Flux.batch(map(x -> x.AAs_res, data))
    AAs_atom = Flux.batch(map(x -> x.AAs_atom, data))
    backbone_distanceincrement = Flux.batch(map(x -> x.backbone_distanceincrement, data))
    atom_calpha_distance = Flux.batch(map(x -> x.atom_calpha_distance, data))
    atom_xyz = Flux.batch(map(x -> x.atom_xyz, data))
    backbone_xyz = Flux.batch(map(x -> x.backbone_xyz, data))
    atom_name = Flux.batch(map(x -> x.atom_name, data))
 
    return (; AAs_res, AAs_atom, backbone_distanceincrement, atom_calpha_distance, atom_xyz, backbone_xyz, atom_name) 
end

prep_batch (generic function with 1 method)

In [5]:
function chop_proteins(proteins, max_back = 100)
    chopped_proteins = []

    for protein in proteins
        back_len = length(protein.AAs)
        if back_len < max_back
            push!(chopped_proteins, protein)
        else
            for start in 1:max_back:back_len
                stop = start + max_back - 1
                stop <= back_len || break

                atom_mask = findall(x -> start <= x <= stop, protein.atom_res)

                # (:chainids, :AAs, :backbone_xyz, :atom_xyz, :atom_res, :atom_name)
                chopped_prot = (; 
                    chainids = protein.chainids[start:stop],
                    AAs = protein.AAs[start:stop],
                    backbone_xyz = protein.backbone_xyz[:, start:stop],
                    atom_xyz = protein.atom_xyz[:, atom_mask],
                    atom_res = protein.atom_res[atom_mask],
                    atom_name = protein.atom_name[atom_mask]
                )
                push!(chopped_proteins, chopped_prot)
            end
        end
    end

    return chopped_proteins
end

chop_proteins (generic function with 2 methods)

In [6]:
prep_data(allatom_dataset[1]);
prep_batch(allatom_dataset[1:1]);

UndefVarError: UndefVarError: `onehotbatch` not defined in `Main`
Suggestion: check for spelling errors or missing imports.
Hint: a global variable of this name may be made accessible by importing OneHotArrays in the current active module Main

## Model

In [7]:
struct AllAtomModel{L}
    layers::L
end

Flux.@layer AllAtomModel

function AllAtomModel(embed_dim, num_layers, num_heads)
    layers = (;
        # Layers for both
        AA_encoder = Dense(21 => embed_dim),
        cross_attention_layers = [Attention(embed_dim, num_heads) for _ in 1:num_layers],

        # Layers for atoms
        atomname_encoder = Dense(length(ATOMNAMES) => embed_dim),
        atom_calpha_distance_encoder = RandomFourierFeatures(1 => embed_dim, one(Float32)),
        atom_transformers = [NaiveTransformerBlock(embed_dim, num_heads, 3) for _ in 1:num_layers],
        output_layer = Dense(embed_dim => 3),

        # Layers for backbone
        #locdiff_rff = RandomFourierFeatures(3 => embed_dim),
        backbone_distanceincrement_encoder = RandomFourierFeatures(1 => embed_dim, one(Float32)),
        backbone_transformers = [NaiveTransformerBlock(embed_dim, num_heads, 3) for _ in 1:num_layers],
    )
    return AllAtomModel(layers)
end

function (m::AllAtomModel)(AAs_res, AAs_atom, backbone_distanceincrement, atom_calpha_distance, atom_xyz, backbone_xyz, atom_name)
    l = m.layers

    x_res = l.backbone_distanceincrement_encoder(backbone_distanceincrement) + l.AA_encoder(AAs_res)
    x_atom = l.atom_calpha_distance_encoder(atom_calpha_distance) + l.AA_encoder(AAs_atom) + l.atomname_encoder(atom_name)

    for i in 1:length(l.atom_transformers)
        x_atom = l.atom_transformers[i](x_atom, atom_xyz)
        x_res = l.backbone_transformers[i](x_res, backbone_xyz)
        x_atom = l.cross_attention_layers[i](x_atom, x_res)
    end

    y_atom = l.output_layer(x_atom)

    return y_atom
end

(m::AllAtomModel)(prepped_batch) = m(prepped_batch.AAs_res, prepped_batch.AAs_atom, prepped_batch.backbone_distanceincrement, prepped_batch.atom_calpha_distance, prepped_batch.atom_xyz, prepped_batch.backbone_xyz, prepped_batch.atom_name)

## Make model

In [8]:
embed_dim = 8 * 3
nheads = 4
model = AllAtomModel(embed_dim, 2, nheads)

AllAtomModel(
  Dense(21 => 24),                      [90m# 528 parameters[39m
  [
    Attention(
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      24,
      4,
      4,
      6,
    ),
    Attention(
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      24,
      4,
      4,
      6,
    ),
  ],
  Dense(83 => 24),                      [90m# 2_016 parameters[39m
  RandomFourierFeatures{Float32, Matrix{Float32}}(Float32[-4.0496616 10.943888 … 10.990238 -1.3580636]),
  [
    NaiveTransformerBlock(
      Attention(
        Dense(24 => 24; bias=false

In [12]:
b = prep_batch(allatom_dataset[1:1]);

## TODO


- Training loop
- Sampling code
- Calpha distance rff for embeddings


In [None]:
losses = Float32[]
 
embed_dim = 8 * 3
nheads = 4
model = AllAtomModel(embed_dim, 2, nheads)
 
opt_state = Flux.setup(AdamW(eta = 0.001), model)
 
println("Starting training...")
for epoch in 1:1
    tot_loss = 0f0
    for i in 1:10
        b = prep_batch(allatom_dataset[i:i])
        l, g = Flux.withgradient(model) do m
            y_pred = m(b)
 
            y_true = b.atom_xyz
            
            loss = mse(y_pred, y_true)
            return loss
        end
        
        if !isnothing(g)
            tot_loss += l
            Flux.update!(opt_state, model, g[1])
            push!(losses, l)
        end
        @info "Epoch: $epoch; Loss: $(round(l, digits=4))"
    end
    
    avg_loss = tot_loss / 100
    println("Epoch: $epoch; Average Loss: $(round(avg_loss, digits=4))")
end
println("Training finished.")