In [11]:
import os
# Limit JAX/XLA's GPU memory preallocation to 50% of device memory.
# XLA reads this variable when JAX initialises its GPU backend, so it must be
# set before jax is imported/used; that is why it sits between the imports.
# NOTE(review): presumably this leaves room for the PyTorch-based Chroma model,
# which is also placed on device="cuda" later in this notebook — confirm.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'

import jax.numpy as jnp
from chroma import Chroma, Protein



In [108]:
import torch

# Scratch check of PyTorch broadcasting/view semantics:
#   (10,) ones -> (10, 1, 1) -> expanded view (10, 4, 3) -> (1, 10, 4, 3).
# torch.ones replaces the legacy torch.Tensor(list) constructor; both yield a
# tensor of ones in the default floating dtype.
t = torch.ones(10)[:, None, None]
print(t.shape)
t = t.expand(-1, 4, 3)  # expand returns a view; -1 keeps dim 0 at size 10
t = t.unsqueeze(0)      # equivalent to t[None, :, :] — adds a leading axis
t.shape

torch.Size([10, 1, 1])


torch.Size([1, 10, 4, 3])

In [109]:
# A 100 x 100 array of zeros, used as the operand of the broadcasting check.
a = jnp.zeros(shape=(100,) * 2)

In [113]:
# For a 2-D `a` of shape (m, n): (1, m, n) + (m, 1, n) broadcasts to (m, m, n).
(a[jnp.newaxis, ...] + a[:, jnp.newaxis, ...]).shape

(100, 100, 100)

In [2]:
from colabdesign.af import mk_af_model, clear_mem

In [4]:
# Fixed-backbone ("fixbb") AlphaFold model with templates enabled.
# Later cells read af.aux["debug"], which presumably relies on debug=True.
# Multimer alternative (not used here):
#   mk_af_model(protocol="fixbb", use_templates=True, debug=True,
#               model_type="alphafold2_multimer_v3")
af = mk_af_model(
    protocol="fixbb",
    use_templates=True,
    debug=True,
)

In [14]:
import torch

# Seed PyTorch's RNG (all devices) so that re-running this cell regenerates the
# same sample_2.pdb, which every later cell depends on.
# NOTE(review): assumes Chroma's SDE integration and Potts/sequence sampling
# draw from torch's global RNG — confirm.
CHROMA_SEED = 0
torch.manual_seed(CHROMA_SEED)

chroma = Chroma(device="cuda")
# Structure passed to the sampler as protein_init.
protein_init = Protein.from_PDB("sample_1.pdb", device="cuda")
protein_ch = chroma.sample(
    # NOTE(review): presumably must agree with the length of sample_1.pdb — confirm.
    chain_lengths=[200],
    initialize_noise=True,
    protein_init=protein_init,
    steps=400,  # number of SDE integration steps (progress bar shows 0/400)
    # Optional trajectory output, currently disabled:
    # full_output=True,
    # trajectory_length=500,
)
protein_ch.to("sample_2.pdb")

Using cached data from /tmp/chroma_weights/90e339502ae6b372797414167ce5a632/weights.pt
Loaded from cache
Using cached data from /tmp/chroma_weights/03a3a9af343ae74998768a2711c8b7ce/weights.pt
Loaded from cache


Integrating SDE:   0%|          | 0/400 [00:00<?, ?it/s]

Potts Sampling:   0%|          | 0/500 [00:00<?, ?it/s]

Sequential decoding:   0%|          | 0/200 [00:00<?, ?it/s]

In [18]:
# Load chain A of the Chroma-generated structure as the fixed-backbone target.
af.prep_inputs("sample_2.pdb",chain="A")

In [19]:
# Integer residue-type codes of the loaded target, read through the private
# `_inputs` attribute. The output below shows 200 codes, all within 0-19.
af._inputs["batch"]["aatype"]

array([12,  2,  9, 16,  9,  9,  3,  9, 13, 19, 10, 14, 14, 10, 15,  1,  7,
        9, 16, 14,  6,  6, 10, 10,  3,  1, 10,  1,  6,  0,  7, 19, 13,  3,
       11,  9, 11,  6,  9,  9, 11,  6, 10, 18, 11, 11,  7, 19,  2,  9, 16,
        9,  9, 15,  9, 10, 19, 14,  6,  3, 14,  3,  0, 17,  1,  1, 14,  6,
        0,  1,  0,  0, 10,  6,  1,  0, 11,  6,  1, 10,  5, 15, 13, 18,  6,
        6,  9, 11, 11,  6, 19,  3, 14, 11, 11, 13, 11, 19, 11, 19,  7, 10,
        3, 14,  1,  0, 19, 14,  7, 15, 16,  6,  7,  6,  6,  0, 10, 11,  6,
        9,  6, 11,  2,  9, 11, 11,  7,  3, 11, 19,  9,  9,  9, 18, 15, 16,
       10,  3, 14,  6, 16, 10,  6,  1, 10, 11, 11, 13,  0, 11,  6, 10,  6,
       11, 11,  2, 11,  0,  9, 18,  9,  6, 15, 14, 10, 16, 19,  0, 10,  1,
        0, 10,  6,  6, 15, 14,  6, 16,  0,  5, 11,  1,  9,  3,  6,  0, 19,
        6, 11, 10, 11, 11, 11,  9,  3,  3,  9,  9,  2, 11])

In [20]:
# AlphaFold residue-type order: integer aatype code i corresponds to amino_acids[i].
amino_acids = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS",
               "ILE", "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP",
               "TYR", "VAL"]
# Three-letter -> one-letter residue codes.
aa_mapping = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V"
}


def aatype_to_sequence(aatype, unknown="X"):
    """Convert integer residue-type codes into a one-letter sequence string.

    Args:
        aatype: iterable of integer codes (Python or numpy ints) indexing
            ``amino_acids``.
        unknown: character emitted for any code outside
            ``0 .. len(amino_acids) - 1``, e.g. 20, the code passed to
            ``af.predict`` in this notebook. Without this, such codes raise
            IndexError, and negative codes silently map to the wrong residue.

    Returns:
        The one-letter sequence as a ``str``.
    """
    n_types = len(amino_acids)
    letters = []
    for code in aatype:
        idx = int(code)
        letters.append(aa_mapping[amino_acids[idx]] if 0 <= idx < n_types else unknown)
    return "".join(letters)


print(aatype_to_sequence(af._inputs["batch"]["aatype"]))

MNITIIDIFVLPPLSRGITPEELLDRLREAGVFDKIKEIIKELYKKGVNITIISILVPEDPDAWRRPEARAALERAKERLQSFYEEIKKEVDPKKFKVKVGLDPRAVPGSTEGEEALKEIEKNIKKGDKVIIIYSTLDPETLERLKKFAKELEKKNKAIYIESPLTVALRALEESPETAQKRIDEAVEKLKKKIDDIINK


In [66]:
# Predict with a placeholder sequence whose length matches the loaded target.
# The length is derived from the target instead of a hard-coded 200, so it cannot
# drift from the Chroma chain length.
# Code 20 lies outside the 20 standard residue types (0-19). Presumably colabdesign
# treats it as unknown/"X". However, af.aux["debug"]["inputs"]["aatype"] later
# reads back as all zeros, so confirm how this code is encoded.
UNKNOWN_AATYPE = 20
target_len = len(af._inputs["batch"]["aatype"])
af.predict([UNKNOWN_AATYPE] * target_len)

predict models [0] recycles 0 hard 1 soft 0 temp 1 seqid 0.20 loss 1.85 dgram_cce 1.85 plddt 0.59 ptm 0.41 rmsd 9.68


In [67]:
# Visualise the model's current structure (presumably the latest af.predict result).
af.plot_pdb()

In [38]:
# NOTE(review): this cell ran as In [38], before the af.predict cell (In [66]).
# The saved file therefore does not reflect that prediction; re-run this cell
# after predicting.
af.save_pdb("af_0abc.pdb")

In [58]:
# List the auxiliary outputs (cmap, pae, plddt, debug, ...) recorded by the last run.
af.aux.keys()

dict_keys(['aatype', 'atom_mask', 'atom_positions', 'cmap', 'debug', 'grad', 'i_cmap', 'i_ptm', 'loss', 'losses', 'num_recycles', 'pae', 'plddt', 'prev', 'ptm', 'residue_index', 'seq', 'all', 'log'])

In [95]:
# Mean of the predicted contact map along axis 1: one value per residue,
# assuming cmap is (L, L) — TODO confirm.
af.aux["cmap"].mean(axis=1)

array([0.09520771, 0.10413208, 0.13015707, 0.14818868, 0.13205108,
       0.15342768, 0.19593155, 0.18499982, 0.17071651, 0.20328529,
       0.2259857 , 0.18792678, 0.19169828, 0.22752655, 0.21541275,
       0.18698198, 0.20598026, 0.23275885, 0.19524941, 0.1870895 ,
       0.21546386, 0.20217678, 0.17353447, 0.18392883, 0.19619335,
       0.17023262, 0.16804837, 0.18372351, 0.17070991, 0.14846027,
       0.15982465, 0.16274098, 0.13696505, 0.1287658 , 0.1350261 ,
       0.11815249, 0.11412962, 0.11997661, 0.13816316, 0.14193565,
       0.13638265, 0.15049376, 0.16947132, 0.16050997, 0.15277334,
       0.17831312, 0.1869965 , 0.16280521, 0.16533059, 0.19248961,
       0.17994697, 0.15674967, 0.1741184 , 0.19362721, 0.16710553,
       0.16193683, 0.19774055, 0.20576867, 0.17695495, 0.19374467,
       0.22686827, 0.20626163, 0.18863364, 0.21542221, 0.23584831,
       0.19852205, 0.20078053, 0.23078503, 0.21796387, 0.18706872,
       0.19605778, 0.21815518, 0.18121628, 0.1615102 , 0.17076

In [71]:
# Raw model inputs captured in debug mode. This displays every array in full;
# the output shows aatype all 0 even though af.predict was given code 20.
af.aux["debug"]["inputs"]

{'aatype': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0], dtype=int32),
 'asym_id': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [13]:
# Names of the raw AlphaFold output heads captured in debug mode.
af.aux["debug"]["outputs"].keys()

dict_keys(['distogram', 'experimentally_resolved', 'masked_msa', 'predicted_aligned_error', 'predicted_lddt', 'prev', 'representations', 'structure_module'])

In [78]:
# Distogram head output: a dict with 'bin_edges' and the (large) 'logits' array.
af.aux["debug"]["outputs"]["distogram"]

{'bin_edges': array([ 2.3125  ,  2.625   ,  2.9375  ,  3.25    ,  3.5625  ,  3.875   ,
         4.1875  ,  4.5     ,  4.8125  ,  5.125   ,  5.4375  ,  5.75    ,
         6.0625  ,  6.375   ,  6.6875  ,  7.      ,  7.3125  ,  7.625   ,
         7.9375  ,  8.25    ,  8.5625  ,  8.875   ,  9.1875  ,  9.5     ,
         9.812499, 10.125   , 10.4375  , 10.75    , 11.0625  , 11.375   ,
        11.6875  , 12.      , 12.3125  , 12.625   , 12.9375  , 13.25    ,
        13.5625  , 13.875   , 14.1875  , 14.499999, 14.8125  , 15.125   ,
        15.4375  , 15.75    , 16.0625  , 16.375   , 16.6875  , 16.999998,
        17.312498, 17.625   , 17.9375  , 18.25    , 18.5625  , 18.875   ,
        19.1875  , 19.5     , 19.8125  , 20.125   , 20.437498, 20.75    ,
        21.0625  , 21.375   , 21.6875  ], dtype=float32),
 'logits': array([[[ 1.44074783e+02,  3.59244633e+00, -2.29589539e+01, ...,
          -6.53785553e+01, -6.75271835e+01, -1.02080292e+02],
         [ 1.98284626e-01, -1.21462083e+00, -3.4072

In [91]:
# 63 distance-bin edges (2.3125 to 21.6875, step 0.3125 per the output above);
# presumably these delimit 64 distogram bins — confirm against logits' last dim.
af.aux["debug"]["outputs"]["distogram"]["bin_edges"].shape

(63,)