Merge pull request #420 from ExcitedStates/maxpool
Implement maxpool sampling
Stephanie (Mullane) Wankowicz committed Apr 29, 2024
2 parents f19f804 + 188b344 commit 37da920
Showing 1 changed file: src/qfit/qfit.py, with 137 additions and 24 deletions.
@@ -8,6 +8,7 @@
import subprocess
import numpy as np
import tqdm
+import timeit

from .backbone import NullSpaceOptimizer, adp_ellipsoid_axes
from .clash import ClashDetector
@@ -245,7 +246,7 @@ def _subtract_transformer(self, residue, structure):
# Subtract the density:
self.xmap.array -= self._subtransformer.xmap.array

-def _convert(self):
+def _convert(self, stride=1, pool_size=1):  # defaults of 1 leave the map values unpooled
"""Convert structures to densities and extract relevant values for (MI)QP."""
logger.info("Converting conformers to density")
logger.debug("Masking")
@@ -258,15 +259,36 @@ def _convert(self):

nvalues = mask.sum()
self._target = self.xmap.array[mask]

+# Max-pool the 1D masked density: take the maximum over each window of
+# pool_size values, advancing the window start by stride
+pooled_values = []
+for i in range(0, len(self._target), stride):
+    # Extract the current window for pooling
+    current_window = self._target[i:i + pool_size]
+    # Append the max of the current window, skipping an empty trailing window
+    if len(current_window) > 0:
+        pooled_values.append(np.max(current_window))
+
+# Convert pooled_values back to a numpy array
+self._target = np.array(pooled_values)

logger.debug("Density")
nmodels = len(self._coor_set)
-self._models = np.zeros((nmodels, nvalues), float)
+maxpool_size = len(range(0, nvalues, stride))
+self._models = np.zeros((nmodels, maxpool_size), float)
for n, coor in enumerate(self._coor_set):
self.conformer.coor = coor
self.conformer.b = self._bs[n]
self._transformer.density()
model = self._models[n]
-model[:] = self._transformer.xmap.array[mask]
+# Apply the same max pooling to the model map as was applied to self._target
+map_values = self._transformer.xmap.array[mask]
+pooled_map_values = []
+for i in range(0, len(map_values), stride):
+    current_window = map_values[i:i + pool_size]
+    if len(current_window) > 0:
+        pooled_map_values.append(np.max(current_window))
+model[:] = np.array(pooled_map_values)
np.maximum(model, self.options.bulk_solvent_level, out=model)
self._transformer.reset(full=True)
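The windowed loop above is standard 1D max pooling. For reference only (not part of this commit), a self-contained NumPy equivalent of the non-overlapping case used here, where stride equals pool_size:

import numpy as np

def maxpool_1d(values, stride=2, pool_size=2):
    # Mirrors the loop in _convert: max over each pool_size-wide window,
    # advancing by stride; a trailing partial window is kept.
    return np.array(
        [values[i:i + pool_size].max() for i in range(0, len(values), stride)]
    )

# maxpool_1d(np.array([1.0, 3.0, 2.0, 5.0, 4.0])) -> array([3., 5., 4.])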

@@ -641,14 +663,16 @@ def _setup_clash_detector(self):
)

def run(self):
+start_time = timeit.default_timer()
if self.options.sample_backbone:
self._sample_backbone()

if self.options.sample_angle:
self._sample_angle()

if self.residue.nchi >= 1 and self.options.sample_rotamers:
-self._sample_sidechain()
+self._sample_sidechain(version=0)
+self._sample_sidechain(version=1)

# Check that there are no self-clashes within a conformer
self.residue.active = True
@@ -669,21 +693,21 @@ def run(self):
self._bs = new_bs

# QP score conformer occupancy
-self._convert()
+self._convert(stride=1, pool_size=1)
self._solve_qp()
self._update_conformers()
if self.options.write_intermediate_conformers:
self._write_intermediate_conformers(prefix="qp_solution")
self._write_intermediate_conformers(prefix="qp_solution_residue")

# MIQP score conformer occupancy
self.sample_b()
-self._convert()
+self._convert(stride=1, pool_size=1)
self._solve_miqp(
threshold=self.options.threshold, cardinality=self.options.cardinality
)
self._update_conformers()
if self.options.write_intermediate_conformers:
self._write_intermediate_conformers(prefix="miqp_solution")
self._write_intermediate_conformers(prefix="miqp_solution_residue")

# Now that the conformers have been generated, the resulting
# conformations should be examined via GoodnessOfFit:
@@ -699,6 +723,11 @@ def run(self):
self.validation_metrics = validator.GoodnessOfFit(
self.conformer, self._coor_set, self._occupancies, cutoff
)
+# End of processing: report wall-clock time for this residue
+end_time = timeit.default_timer()
+logger.info(f"Processing time: {end_time - start_time:.2f} seconds")

def _sample_backbone(self):
# Check if residue has enough neighboring residues
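A small sanity check, illustrative and not from the commit, of why run() above can score with _convert(stride=1, pool_size=1): every window then holds a single value, so pooling is the identity and the final QP/MIQP scoring sees the unpooled density:

import numpy as np

values = np.array([0.2, 0.7, 0.1])
pooled = np.array([values[i:i + 1].max() for i in range(0, len(values), 1)])
assert np.array_equal(pooled, values)  # stride=1, pool_size=1 changes nothing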
@@ -894,17 +923,30 @@ def _sample_angle(self):
if self.options.write_intermediate_conformers:
self._write_intermediate_conformers(prefix=f"sample_angle")

-def _sample_sidechain(self):
+def _sample_sidechain(self, version=0):
opt = self.options
start_chi_index = 1
if self.residue.resn[0] != "PRO":
-sampling_window = np.arange(
-    -opt.rotamer_neighborhood,
-    opt.rotamer_neighborhood + opt.dihedral_stepsize,
-    opt.dihedral_stepsize,
-)
+if version == 0:
+    # Coarse pass: wide 24-degree steps, scored with 2x max pooling
+    stride_ = 2
+    pool_size_ = 2
+    sampling_window = np.arange(
+        -opt.rotamer_neighborhood,
+        opt.rotamer_neighborhood,
+        24,
+    )
+else:
+    # Fine pass: full dihedral_stepsize resolution, no pooling
+    stride_ = 1
+    pool_size_ = 1
+    sampling_window = np.arange(
+        -opt.rotamer_neighborhood,
+        opt.rotamer_neighborhood,
+        opt.dihedral_stepsize,
+    )
else:
    sampling_window = [0]
+    stride_ = 1
+    pool_size_ = 1

rotamers = self.residue.rotamers
rotamers.append(
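The coarse and fine passes above walk the same chi neighborhood at different resolutions. An illustrative comparison, assuming a 60-degree rotamer_neighborhood and a 6-degree dihedral_stepsize (example values, not taken from this commit):

import numpy as np

coarse = np.arange(-60, 60, 24)  # version 0: array([-60, -36, -12, 12, 36])
fine = np.arange(-60, 60, 6)     # version 1: 20 angles, 6 degrees apart
# The coarse pass probes about a quarter as many chi values, scored against a
# 2x max-pooled density; version 1 then refines the survivors at full resolution.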
@@ -943,6 +985,21 @@ def _sample_sidechain(self):
n = 0
ex = 0
# For each backbone conformation so far:
+if version == 1:
+    # Record the chi angles of every conformer kept by the coarse pass
+    sampled_rotamers = []
+    for coor in self._coor_set:
+        self.residue.coor = coor
+        rotamer = [
+            self.residue.get_chi(i) for i in range(1, self.residue.nchi + 1)
+        ]
+        sampled_rotamers.append(rotamer)
+    # Cross the sampled chi1 values with every chi2 from the rotamer library
+    if self.residue.nchi > 1:
+        new_rotamers = [
+            [sampled_rotamer[0], rotamer[1]]
+            for sampled_rotamer in sampled_rotamers
+            for rotamer in rotamers
+        ]
+    else:
+        new_rotamers = sampled_rotamers

for coor, b in zip(self._coor_set, self._bs):
self.residue.coor = coor
self.residue.b = b
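The version == 1 block above pairs each chi1 retained from the coarse pass with every chi2 from the rotamer library. A minimal sketch of that Cartesian product; the names sampled_rotamers and rotamers mirror the code, while the angle values are made up:

from itertools import product

sampled_rotamers = [[62.0, 181.0], [58.0, 174.0]]        # from the coarse pass
rotamers = [[60.0, 60.0], [60.0, 180.0], [60.0, -60.0]]  # rotamer library
new_rotamers = [[s[0], r[1]] for s, r in product(sampled_rotamers, rotamers)]
# 2 sampled chi1 values x 3 library chi2 values -> 6 candidate rotamers
assert len(new_rotamers) == 6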
@@ -1042,19 +1099,73 @@ def _sample_sidechain(self):
)
if self.options.write_intermediate_conformers:
self._write_intermediate_conformers(
prefix=f"sample_sidechain_iter{iteration}"
prefix=f"sample_sidechain_iter{version}_{iteration}"
)

-self._convert()
-self._solve_qp()
-self._update_conformers()
-if self.options.write_intermediate_conformers:
-    self._write_intermediate_conformers(
-        prefix=f"sample_sidechain_iter{iteration}_qp"
-    )
+if len(self._coor_set) <= 15000:
+    # If <=15000 conformers were generated, QP-score conformer occupancy normally
+    self._convert(stride_, pool_size_)
+    self._solve_qp()
+    self._update_conformers()  # TODO: should this be more lenient?
+    if self.options.write_intermediate_conformers:
+        self._write_intermediate_conformers(
+            prefix=f"sample_sidechain_iter{version}_{iteration}_qp"
+        )
+else:
+    # If >15000 conformers were generated, split the QP scoring into two
+    # halves to bound the size of each QP problem
+    temp_coor_set = self._coor_set
+    temp_bs = self._bs
+
+    # Split the conformers and B-factors into two sections
+    half_index = len(temp_coor_set) // 2  # integer division for splitting
+    section_1_coor = temp_coor_set[:half_index]
+    section_1_bs = temp_bs[:half_index]
+    section_2_coor = temp_coor_set[half_index:]
+    section_2_bs = temp_bs[half_index:]
+
+    # QP-score the first section
+    self._coor_set = section_1_coor
+    self._bs = section_1_bs
+    self._convert(stride_, pool_size_)
+    self._solve_qp()
+    self._update_conformers()
+    if self.options.write_intermediate_conformers:
+        self._write_intermediate_conformers(
+            prefix=f"sample_sidechain_iter{version}_{iteration}_qp_part1"
+        )
+
+    # Save the survivors of the first section
+    qp_temp_coor = self._coor_set
+    qp_temp_bs = self._bs
+
+    # QP-score the second section
+    self._coor_set = section_2_coor
+    self._bs = section_2_bs
+    self._convert(stride_, pool_size_)
+    self._solve_qp()
+    self._update_conformers()
+    if self.options.write_intermediate_conformers:
+        self._write_intermediate_conformers(
+            prefix=f"sample_sidechain_iter{version}_{iteration}_qp_part2"
+        )
+
+    # Save the survivors of the second section
+    qp_2_temp_coor = self._coor_set
+    qp_2_temp_bs = self._bs
+
+    # Concatenate the survivors of both sections
+    self._coor_set = np.concatenate((qp_temp_coor, qp_2_temp_coor), axis=0)
+    self._bs = np.concatenate((qp_temp_bs, qp_2_temp_bs), axis=0)

# MIQP score conformer occupancy
self.sample_b()
-self._convert()
+self._convert(stride_, pool_size_)
self._solve_miqp(
threshold=self.options.threshold,
cardinality=None, # don't enforce strict cardinality constraint, just less-than 1/threshold
@@ -1063,7 +1174,7 @@
self._update_conformers()
if self.options.write_intermediate_conformers:
self._write_intermediate_conformers(
prefix=f"sample_sidechain_iter{iteration}_miqp"
prefix=f"sample_sidechain_iter{version}_{iteration}_miqp"
)

# Check if we are done
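The two-way split above bounds the size of any single QP problem. The same pattern generalizes to n chunks; the sketch below is illustrative only, assumes self._coor_set and self._bs behave as lists, and the helper name _solve_qp_in_chunks is hypothetical:

def _solve_qp_in_chunks(self, n_chunks, stride, pool_size):
    # Score conformers chunk by chunk, keeping the QP survivors of each chunk.
    coor_set, bs = list(self._coor_set), list(self._bs)
    chunk = max(1, len(coor_set) // n_chunks)
    kept_coor, kept_bs = [], []
    for start in range(0, len(coor_set), chunk):
        self._coor_set = coor_set[start:start + chunk]
        self._bs = bs[start:start + chunk]
        self._convert(stride, pool_size)
        self._solve_qp()
        self._update_conformers()
        kept_coor.extend(self._coor_set)
        kept_bs.extend(self._bs)
    self._coor_set, self._bs = kept_coor, kept_bs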
@@ -1339,6 +1450,7 @@ def find_paths(self, segment_original):
threshold=self.options.threshold,
cardinality=self.options.cardinality,
segment=True,
+do_BIC_selection=False,
)
except SolverError:
# MIQP failed and we need to remove conformers that are close to each other
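The new do_BIC_selection=False argument presumably disables BIC-based selection of the number of conformers inside _solve_miqp; the flag's implementation lies outside this diff, so that reading is an assumption from the name. For reference, a generic Bayesian information criterion for a least-squares fit, not qFit's exact scoring:

import numpy as np

def bic(residuals, k, n):
    # BIC = n * ln(RSS / n) + k * ln(n); lower is better.
    rss = float(np.sum(np.square(residuals)))
    return n * np.log(rss / n) + k * np.log(n)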
@@ -2408,3 +2520,4 @@ def get_conformers_covalent(self):
conformer.b = b
conformers.append(conformer)
return conformers
