## Set up dependencies

In [25]:
using Pkg; Pkg.activate("..")

Pkg.add(["ProteinChains", "BSON", "Flux", "ConcreteStructs", "ChainRulesCore", "Einops", "RandomFeatureMaps"])

using Onion, ProteinChains, BSON, Flux, ConcreteStructs, ChainRulesCore, Einops, RandomFeatureMaps

[32m[1m  Activating[22m[39m project at `c:\Users\User\Desktop\SoFo\code\SoFo-2025-Translation-Equivariant-Transformer`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `C:\Users\User\Desktop\SoFo\code\SoFo-2025-Translation-Equivariant-Transformer\Project.toml`
[32m[1m  No Changes[22m[39m to `C:\Users\User\Desktop\SoFo\code\SoFo-2025-Translation-Equivariant-Transformer\Manifest.toml`


## Load Dataset

In [26]:
# All different types of atoms (orderd tuple per amino acid)
ATOM_TYPES = String["C", "CA", "CB", "CD", "CD1", "CD2", "CE", "CE1", "CE2", "CE3", "CG", "CG1", "CG2", "CH2", "CZ", "CZ2", "CZ3", "H", "H1", "H2", "H3", "HA", "HA2", "HA3", "HB", "HB1", "HB2", "HB3", "HD1", "HD11", "HD12", "HD13", "HD2", "HD21", "HD22", "HD23", "HD3", "HE", "HE1", "HE2", "HE21", "HE22", "HE3", "HG", "HG1", "HG11", "HG12", "HG13", "HG2", "HG21", "HG22", "HG23", "HG3", "HH", "HH11", "HH12", "HH2", "HH21", "HH22", "HZ", "HZ1", "HZ2", "HZ3", "N", "ND1", "ND2", "NE", "NE1", "NE2", "NH1", "NH2", "NZ", "O", "OD1", "OD2", "OE1", "OE2", "OG", "OG1", "OH", "OXT", "SD", "SG"]
AA_TO_ATOMS = Dict("Q" => ["C", "CA", "CB", "CD", "CG", "N", "NE2", "O", "OE1"], "W" => ["C", "CA", "CB", "CD1", "CD2", "CE2", "CE3", "CG", "CH2", "CZ2", "CZ3", "N", "NE1", "O"], "T" => ["C", "CA", "CB", "CG2", "N", "O", "OG1"], "P" => ["C", "CA", "CB", "CD", "CG", "N", "O"], "C" => ["C", "CA", "CB", "N", "O", "SG"], "V" => ["C", "CA", "CB", "CG1", "CG2", "N", "O"], "L" => ["C", "CA", "CB", "CD1", "CD2", "CG", "N", "O"], "M" => ["C", "CA", "CB", "CE", "CG", "N", "O", "SD"], "N" => ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"], "H" => ["C", "CA", "CB", "CD2", "CE1", "CG", "N", "ND1", "NE2", "O"], "A" => ["C", "CA", "CB", "N", "O"], "X" => ["C", "CA", "CB", "N", "O"], "D" => ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"], "G" => ["C", "CA", "N", "O"], "E" => ["C", "CA", "CB", "CD", "CG", "N", "O", "OE1", "OE2"], "Y" => ["C", "CA", "CB", "CD1", "CD2", "CE1", "CE2", "CG", "CZ", "N", "O", "OH"], "I" => ["C", "CA", "CB", "CD1", "CG1", "CG2", "N", "O"], "S" => ["C", "CA", "CB", "N", "O", "OG"], "K" => ["C", "CA", "CB", "CD", "CE", "CG", "N", "NZ", "O"], "R" => ["C", "CA", "CB", "CD", "CG", "CZ", "N", "NE", "NH1", "NH2", "O"], "F" => ["C", "CA", "CB", "CD1", "CD2", "CE1", "CE2", "CG", "CZ", "N", "O"])

BSON.@load "./pdb500_allatom.bson" allatom_dataset;

## Model

In [27]:
struct AllAtomModel{L}
    layers::L
end

Flux.@layer AllAtomModel

# Initialization
function AllAtomModel(embed_dim::Int, num_layers::Int, num_heads::Int)
    # allatom_dataset[1] |> propertynames === (:chainids, :AAs, :backbone_xyz, :atom_xyz, :atom_res, :atom_name)
        
    layers = (;
        # -------- Layers for both start --------
        AA_type_encoder = RandomFourierFeatures(21 => embed_dim, 1f0),
        
        cross_attention_layers = [Attention(embed_dim, num_heads) for _ in 1:num_layers],
        # -------- Layers for both end --------


        # -------- Atom stack start --------
        atom_type_encoder = RandomFourierFeatures(length(ATOM_TYPES) => embed_dim, 1f0),
        
        atom_transformers = [NaiveTransformerBlock(embed_dim, num_heads, 3) for _ in 1:num_layers],
        # -------- Atom stack end --------


        # -------- Backbone stack start --------
        
        backbone_transformers = [NaiveTransformerBlock(embed_dim, num_heads, 3) for _ in 1:num_layers],
        AA_idx_encoder = RandomFourierFeatures(n_AAs => embed_dim, 1f0),

        # -------- Backbone stack end --------


        # Seems fine atm
        output_layer = Dense(embed_dim => 3),

    )   
    return AllAtomModel(layers)
end

# Forward / backward pass
function (model::AllAtomModel)(AAs_res, AAs_atom, backbone_xyz, atom_xyz, atom_name)
    l = model.layers

    # Residue embeddings
    backbones_centroid = mean(backbone_xyz, dims=1)

    # RFF encode offset from backbones_centroid
    x_res = l.AA_type_encoder(AAs_res) + l.AA_idx_encoder(AA 

    # Sidechain atom embeddings
    x_atom = l.atomxyz_encoder(atom_xyz) + l.AA_type_encoder(AAs_atom) + l.atom_type_encoder(atom_name)

    
    for i in 1:length(l.atom_transformers)
        # Pass through backbone transformer stack
        x_res = l.backbone_transformers[i](x_res, backbone_xyz)
        
        # Pass through atom transformer stack
        x_atom = l.atom_transformers[i](x_atom, atom_xyz)
        
        # Cross attention from backbone to atom stack
        x_atom = l.cross_attention_layers[i](x_atom, x_res)
    end

    # Output (move atoms) -- dense layer
    y_atom = l.output_layer(x_atom)

    return y_atom
end

## Make model

In [28]:
using Flux: onehotbatch

In [29]:
p = allatom_dataset[1]

AAs_res = onehotbatch(p.AAs, ProteinChains.AMINOACIDS)
AAs_atom = onehotbatch(p.AAs[p.atom_res], ProteinChains.AMINOACIDS)
backbone_xyz = Float32.(p.backbone_xyz[:, 2, :])
atom_xyz = p.atom_xyz
atom_name = onehotbatch(p.atom_name, ATOM_TYPES)

AAs_res = Flux.batch([AAs_res])
AAs_atom = Flux.batch([AAs_atom])
backbone_xyz = Flux.batch([backbone_xyz])
atom_xyz = Flux.batch([atom_xyz])
atom_name = Flux.batch([atom_name])


embed_dim = 8 * 3
nheads = 4
model = AllAtomModel(embed_dim, 2, nheads)


AllAtomModel(
  RandomFourierFeatures{Float32, Matrix{Float32}}(Float32[-0.31337827 -7.3964167 … 2.9227424 0.90205204; -5.911216 10.204468 … -9.491757 -6.72175; … ; 3.2447293 8.515508 … 9.808949 5.132442; 0.49914864 8.798216 … 3.5014625 -0.21598397]),
  [
    Attention(
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      24,
      4,
      4,
      6,
    ),
    Attention(
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      Dense(24 => 24; bias=false),      [90m# 576 parameters[39m
      24,
      4,
      4,
      6,
    ),
  ],
  var"#28#32"(),
  RandomFourierFeatures{Float32, Matrix{Float32}}(Float32[2.5202048 

In [31]:
allatom_dataset[1] |> propertynames # length(allatom_dataset) = 500

(:chainids, :AAs, :backbone_xyz, :atom_xyz, :atom_res, :atom_name)

In [35]:
AAs_res = Flux.batch([AAs_res])
AAs_atom = Flux.batch([AAs_atom])
backbone_xyz = Flux.batch([backbone_xyz])
atom_xyz = Flux.batch([atom_xyz])
atom_name = Flux.batch([atom_name])

@show size(AAs_res), size(AAs_atom), size(backbone_xyz), size(atom_xyz), size(atom_name)

(size(AAs_res), size(AAs_atom), size(backbone_xyz), size(atom_xyz), size(atom_name)) = ((21, 284, 1, 1), (21, 4621, 1, 1), (3, 284, 1, 1), (3, 4621, 1, 1), (83, 4621, 1, 1))


((21, 284, 1, 1), (21, 4621, 1, 1), (3, 284, 1, 1), (3, 4621, 1, 1), (83, 4621, 1, 1))

In [36]:
@show typeof(AAs_res) typeof(AAs_atom) typeof(backbone_xyz) typeof(atom_xyz) typeof(atom_name)

typeof(AAs_res) = OneHotArrays.OneHotArray{UInt32, 3, 4, Array{UInt32, 3}}
typeof(AAs_atom) = OneHotArrays.OneHotArray{UInt32, 3, 4, Array{UInt32, 3}}
typeof(backbone_xyz) = Array{Float32, 4}
typeof(atom_xyz) = Array{Float32, 4}
typeof(atom_name) = OneHotArrays.OneHotArray{UInt32, 3, 4, Array{UInt32, 3}}


OneHotArrays.OneHotArray{UInt32, 3, 4, Array{UInt32, 3}}

## TODO


- Training loop
- Sampling code
- Calpha distance rff for embeddings
