In [1]:
from dem.energies.tblite_energy import TBLiteEnergy
import numpy as np
import pickle
import torch 
from dem.energies.base_energy_function import BaseEnergyFunction
from dem.energies.lennardjones_energy import LennardJonesEnergy 
from dem.models.components.clipper import Clipper
from dem.models.components.noise_schedules import BaseNoiseSchedule, GeometricNoiseSchedule
from dem.models.components.score_estimator import get_logreward_noised_samples

In [5]:
# TBLiteEnergy instance for the 69-particle TIfV system. The dimensionality is
# written as 3 * n_particles, and the validation file doubles as `data_path`.
energy_function = TBLiteEnergy(
    dimensionality=3 * 69,
    n_particles=69,
    is_molecule=True,
    data_path='/home/mila/l/lena-nehale.ezzine/Amgen/DEM/data/TIfV_val.npy',
    data_path_train='/home/mila/l/lena-nehale.ezzine/Amgen/DEM/data/TIfV_train.npy',
    data_path_val='/home/mila/l/lena-nehale.ezzine/Amgen/DEM/data/TIfV_val.npy',
    T=0.1,
)

In [6]:
# Saved TIfV conformations as a float32 tensor (the dataset-creation cell in
# this notebook reports the saved array as shape (283, 207)).
coords = torch.tensor(np.load('/home/mila/l/lena-nehale.ezzine/Amgen/DEM/data/TIfV_val.npy'), dtype=torch.float32)
# Open the pickle through a context manager so the file handle is closed as
# soon as loading finishes instead of leaking until garbage collection.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('/home/mila/l/lena-nehale.ezzine/Amgen/DEM/data/atom_ids.pkl', 'rb') as atom_ids_file:
    atom_ids = pickle.load(atom_ids_file)

In [9]:
# Smoke test: evaluate the energy function on uniformly random coordinates of
# shape (69, 3). No seed is set, so the value changes from run to run.
energy_function(torch.rand(69,3))

------------------------------------------------------------
  cycle        total energy    energy error   density error
------------------------------------------------------------
      1      11696.82296596   6.3654270E+03   2.1577181E+00
      2      11694.84449147  -1.9784745E+00   1.2755375E+00
      3      11693.83864785  -1.0058436E+00   9.1754637E-02
      4      11693.75435246  -8.4295385E-02   5.3377027E-02
      5      11693.85161613   9.7263673E-02   3.9444354E-02
      6      11693.68669610  -1.6492004E-01   2.1767680E-02
      7      11693.65849108  -2.8205019E-02   1.5797328E-02
      8      11693.65527664  -3.2144329E-03   1.3629944E-02
------------------------------------------------------------

 total:                                   1.623 sec
log rew -116936.5527664498


tensor(-116936.5547)

In [4]:
# Shape of the interatomic-distance features for all loaded frames.
# The recorded output is (283, 2346); note that 2346 == 69 * 68 / 2.
energy_function.interatomic_dist(coords).shape

torch.Size([283, 2346])

In [7]:
# Evaluate the first two frames with two extra leading singleton axes to check
# batched evaluation; the recorded result has shape (1, 1, 2).
# NOTE(review): despite the name `grads`, this is a forward call of
# energy_function, not get_gradient — consider renaming the variable.
grads = energy_function(coords[:2].unsqueeze(0).unsqueeze(0))
grads.shape

------------------------------------------------------------
  cycle        total energy    energy error   density error
------------------------------------------------------------
      1     -101.2379595160  -1.0254221E+02   1.6242103E-01
      2     -101.8354361927  -5.9747668E-01   6.4933516E-02
      3     -101.8713076978  -3.5871505E-02   3.7344083E-02
      4     -101.8826051226  -1.1297425E-02   1.0278015E-02
      5     -101.8844306159  -1.8254933E-03   4.4548929E-03
------------------------------------------------------------

 total:                                   2.213 sec
log rew 101.88443061590056
------------------------------------------------------------
  cycle        total energy    energy error   density error
------------------------------------------------------------
      1     -101.2374035518  -1.0254175E+02   1.6240706E-01
      2     -101.8352837297  -5.9788018E-01   6.4946729E-02
      3     -101.8712113826  -3.5927653E-02   3.7351489E-02
      4     -10

torch.Size([1, 1, 2])

In [10]:
# Compare the energy-function value at the first frame (e0) with its value
# after displacing that frame along the output of get_gradient (e1).
e0 = energy_function(coords[0])
grads = energy_function.get_gradient(coords[0])
# 1e0 is the step size applied to the gradient displacement.
e1 = energy_function(coords[0] + 1e0 * grads)
# Recorded output: (101.8849, 101.8843).
e0, e1

------------------------------------------------------------
  cycle        total energy    energy error   density error
------------------------------------------------------------
      1     -101.2379595160  -1.0254221E+02   1.6242103E-01
      2     -101.8354361927  -5.9747668E-01   6.4933516E-02
      3     -101.8713076978  -3.5871505E-02   3.7344083E-02
      4     -101.8826051226  -1.1297425E-02   1.0278015E-02
      5     -101.8844306159  -1.8254933E-03   4.4548929E-03
      6     -101.8848423873  -4.1177144E-04   1.6148188E-03
      7     -101.8848654878  -2.3100480E-05   6.8722718E-04
      8     -101.8848755082  -1.0020362E-05   3.6223155E-04
      9     -101.8848774218  -1.9136250E-06   8.8704758E-05
------------------------------------------------------------

 total:                                   2.065 sec
tensor([101.8849]) torch.Size([1])
------------------------------------------------------------
  cycle        total energy    energy error   density error
--------

(tensor(101.8849), tensor(101.8843))

In [5]:

# Rebinds `energy_function` (previously the TBLiteEnergy instance) to the
# Lennard-Jones system with 13 particles and dimensionality 39 (= 13 * 3).
# The test split is used both as `data_path` and as the validation path.
energy_function = LennardJonesEnergy(dimensionality = 39,
        n_particles = 13,
        data_path = 'data/test_split_LJ13-1000.npy',
        data_path_train = "data/train_split_LJ13-1000.npy",
        data_path_val =  "data/test_split_LJ13-1000.npy")

# Inputs for get_logreward_noised_samples further down in the notebook.
noise_schedule = GeometricNoiseSchedule(sigma_min=0.01, sigma_max=2)
num_mc_samples = 10
# Score clipping is enabled (max norm 20); log-reward clipping is disabled.
clipper = Clipper(should_clip_scores=True,
                should_clip_log_rewards= False,
                max_score_norm= 20,
                min_log_reward= None)

In [6]:
# Toy batch of random 39-dimensional points for the LJ13 energy.
bs = 5
# NOTE(review): torch.Tensor([bs]) is a one-element float tensor holding 5.0,
# not `bs` separate time values — confirm this is the intended `t`.
t = torch.Tensor([bs])
# No random seed is set, so `x` differs between runs.
x = torch.rand(bs,39)


In [72]:
# Draw noised samples around `x` and evaluate their log-rewards; the call
# returns a pair that is unpacked as (samples, log_rewards).
samples, log_rewards = get_logreward_noised_samples(t, x, energy_function, noise_schedule, num_mc_samples, clipper)

In [73]:
# Recorded shapes: samples (10, 5, 39) and log_rewards (10, 1, 5); the leading
# axis has length num_mc_samples = 10.
samples.shape, log_rewards.shape

(torch.Size([10, 5, 39]), torch.Size([10, 1, 5]))

In [27]:
# Self-normalised weights over the Monte-Carlo samples. The recorded shapes of
# log_rewards ((10, 1, 5) and (10, 1)) both put the num_mc_samples = 10 axis
# first, so the softmax must run over dim 0. With dim=-1 it would normalise
# over the batch axis, or over a size-1 axis and return all ones. dim=0 also
# matches the exp/sum cross-check and the dim=0 softmax evaluated later in
# this notebook.
weights = torch.softmax(log_rewards, dim=0).unsqueeze(-1)

In [32]:
# Manual normalisation exp(r) / sum(exp(r)), shown next to the raw log-rewards
# as a cross-check of torch.softmax. torch.sum without `dim` sums every
# element, so this equals a dim-0 softmax only when log_rewards has a single
# column, as in the recorded (10, 1) output.
log_rewards, torch.exp(log_rewards)/torch.sum(torch.exp(log_rewards))

(tensor([[-7.0318e+01],
         [-6.5862e+01],
         [-3.2641e+04],
         [-3.0963e+02],
         [-9.7251e+01],
         [-6.8172e+01],
         [-6.4121e+01],
         [-2.0791e+07],
         [-4.7405e+05],
         [-7.2687e+01]]),
 tensor([[1.7046e-03],
         [1.4672e-01],
         [0.0000e+00],
         [0.0000e+00],
         [3.4255e-15],
         [1.4573e-02],
         [8.3685e-01],
         [0.0000e+00],
         [0.0000e+00],
         [1.5948e-04]]))

In [35]:
# Softmax over axis 0; the recorded values agree with the manual exp/sum
# normalisation shown in the previous output.
torch.softmax(log_rewards, dim=0)

tensor([[1.7046e-03],
        [1.4672e-01],
        [0.0000e+00],
        [0.0000e+00],
        [3.4248e-15],
        [1.4573e-02],
        [8.3685e-01],
        [0.0000e+00],
        [0.0000e+00],
        [1.5948e-04]])

In [3]:
def f(x):
    """Identity function, used as a minimal target for torch.func.grad."""
    return x

# Function transform returning the derivative of f with respect to its argument.
gradf = torch.func.grad(f)

In [6]:
# The derivative of the identity is 1; the recorded output is tensor(1.).
gradf(torch.tensor(1.0))

tensor(1.)

### Make datasets for cyclic peptides, in the same format as the LJ13 data

In [1]:
import numpy as np

# Load the LJ13 test split to inspect its layout; the recorded shape is
# (10000, 39), i.e. one flattened row per sample.
data = np.load('data/test_split_LJ13-1000.npy')
data.shape

(10000, 39)

In [4]:
from Bio import PDB
import numpy as np
import periodictable

def get_atomicnumber(max_Z = 50):
    '''Build lookup tables between element symbols and atomic numbers.

    Covers atomic numbers 1 through ``max_Z`` inclusive, taking the symbols
    from ``periodictable.elements``.

    Returns:
        A pair ``(symbol -> atomic number, atomic number -> symbol)`` of dicts.
    '''
    symbol_to_z = {}
    z_to_symbol = {}
    for z in range(1, max_Z + 1):
        symbol = periodictable.elements[z].symbol
        symbol_to_z[symbol] = z
        z_to_symbol[z] = symbol
    return symbol_to_z, z_to_symbol


def extract_column_from_pdb(pdb_path, start, end):
    '''Collect one fixed-width column from the ATOM records of a PDB file.

    Reading stops at the first line that starts with ``ENDMDL``, so only the
    records before it are used. Lines that do not start with ``ATOM`` are
    skipped.

    Args:
        pdb_path: path of the PDB file to read.
        start, end: Python slice bounds (0-based, end exclusive) of the column.

    Returns:
        A list with the whitespace-stripped column text of each ATOM line.
    '''
    values = []
    with open(pdb_path, 'r') as handle:
        for record in handle:
            if record.startswith('ENDMDL'):
                break
            if not record.startswith('ATOM'):
                continue
            values.append(record[start:end].strip())
    return values

def get_coord(model):
    '''Return all atom coordinates of a model as one flat array.

    Walks the model chain by chain, residue by residue and atom by atom,
    unpacks each ``atom.get_coord()`` into x, y, z, and concatenates the
    triples in traversal order.
    '''
    xyz_rows = [
        [x, y, z]
        for chain in model
        for residue in chain
        for atom in residue
        # Single-item loop keeps the strict three-way unpacking of the coordinate.
        for x, y, z in [atom.get_coord()]
    ]
    return np.array(xyz_rows).flatten()


def get_models_atoms(pdb_path ):
    '''Extract the models and the atomic numbers of the atoms from a PDB file.

    Args:
        pdb_path: path of the PDB file.

    Returns:
        models: list of the models yielded by the parsed structure.
        atom_ids: numpy array of atomic numbers, one per ATOM record returned
            by ``extract_column_from_pdb`` (which stops at the first ENDMDL).

    Raises:
        KeyError: if an element field is blank or is not a symbol known to
            ``get_atomicnumber()`` (default range only).
    '''
    path = f'{pdb_path}'
    atom2atomicnumber, _ = get_atomicnumber()
    struct = PDB.PDBParser().get_structure('structure_id', path)
    models = list(struct.get_models())
    # The PDB element symbol occupies columns 77-78 (1-based, right-justified),
    # i.e. the slice [76:78]. Reading only [77:78] dropped the first letter of
    # two-letter elements (e.g. ZN was read as N and mapped to nitrogen).
    element_fields = extract_column_from_pdb(path, 76, 78)
    # PDB files write symbols in upper case (CL) while the lookup keys are
    # capitalised (Cl); single-letter symbols are unaffected by capitalize().
    atom_ids = np.array([atom2atomicnumber[field.capitalize()] for field in element_fields])
    return models, atom_ids

def get_coords(pdb_path):
    '''Return the flattened coordinates of every model in a PDB file.

    Each row of the result is ``get_coord`` applied to one model, in the order
    returned by ``get_models_atoms``. The atom ids returned alongside the
    models are not used here.
    '''
    models, _atom_ids = get_models_atoms(pdb_path)
    return np.array([get_coord(model) for model in models])


In [5]:
# Convert the T.I.f.V PDB file into an array with one flattened row per model.
# The recorded shape is (283, 207), and 207 == 69 * 3.
pdb_path = '/home/mila/l/lena-nehale.ezzine/Amgen/DEM/data/T.I.f.V.pdb'
coords = get_coords(pdb_path)
coords.shape

(283, 207)

In [6]:
# Every extracted conformation goes to the validation split.
# NOTE(review): the train split is saved as an empty array — confirm that an
# empty training file is intended for the TBLite energy setup.
coords_train, coords_val = np.array([]), coords
# Overwrites the .npy files that the TBLiteEnergy cell reads.
np.save('/home/mila/l/lena-nehale.ezzine/Amgen/DEM/data/TIfV_train.npy', coords_train)
np.save('/home/mila/l/lena-nehale.ezzine/Amgen/DEM/data/TIfV_val.npy', coords_val)
