### Imports

In [None]:
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import re
import reciprocalspaceship as rs
import os
import gemmi
import math
import shutil
from sklearn import metrics
import matplotlib.pyplot as plt

In [None]:
import valdo

##### Define Paths

In [None]:
# Folder holding the original input MTZ files (globbed with "*mtz" in Step 2)
original_data_path = '../../pipeline/data/original_data/'

In [None]:
# Root of the pipeline working tree; the derived paths below are built from it
basepath = '../../pipeline/'
# Output folder for the reindexed MTZ files (Step 2)
reindexed_path = basepath + 'data/reindexed/'
# Output folder for the scaled MTZ files (Step 2); read back in Step 3
scaled_path = basepath + 'data/scaled/'

In [None]:
# Pickle files written by find_intersection() / find_union() in Step 3
intersection_path = scaled_path + 'intersection.pkl'
union_path = scaled_path + 'union.pkl'

In [None]:
# Folder for the VAE input/output arrays and the trained model
vae_folder = basepath + 'vae/'
# Folder for the raw reconstruction (`recons.npy`) and the rescaled MTZ files
vae_reconstructed_folder = vae_folder + 'reconstructed/'
# Folder for the reconstructed MTZ files after refinement phases were added
vae_reconstructed_with_phases_folder = vae_folder + 'reconstructed-phases/'

In [None]:
# Output folder passed to generate_blobs() in Steps 7 & 8
blob_folder = vae_folder + 'blobs/'

In [None]:
# Column names used in the MTZ files: amplitudes before and after scaling
amplitude_col = 'F-obs'
amplitude_scaled_col = 'F-obs-scaled'

# Column passed together with `amplitude_col` to reindex_files() as the error column
error_col = 'SIGF-obs'

In [None]:
# Names of the reconstruction / difference columns described in Steps 5 & 6
recons_col = 'recons'
diff_col = 'diff'
# Names under which the refinement phases are stored in the reconstructed MTZ files
phase_2FOFC_col = 'refine_PH2FOFCWT'
phase_FOFC_col = 'refine_PHFOFCWT'

### Step 1: Diffraction Data

The first step involves acquiring diffraction datasets in the `mtz` format. These datasets should follow a specific naming convention, where each file is named with a number followed by the `.mtz` extension (e.g., `01.mtz`, `02.mtz`, etc.).

#### Usage

1. Ensure that you have collected diffraction datasets in the `mtz` format.

2. Organize the datasets with sequential numerical names (e.g., `01.mtz`, `02.mtz`, etc.).

Following this naming convention will allow datasets to be ready for further processing.

#### Template Code

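The following cell is a template for renaming files to the correct naming convention. As written, it copies PDB files whose names contain a four-digit ID to `<ID>.pdb`; change `source_folder`, `destination_folder`, the filename pattern, and the extension (e.g. to `.mtz`) as necessary.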

In [None]:
# Define the source and destination folders
source_folder = "../../PTP1B_DK/all_bound_models_reindexed/"
destination_folder = "../../pipeline/data/bound_models/"

# shutil.copy does not create missing folders, so make sure the target exists
os.makedirs(destination_folder, exist_ok=True)

# Get a list of all files in the source folder
file_list = os.listdir(source_folder)

# Filenames that contain a four-digit ID and end in ".pdb".
# The dot is escaped and the pattern is anchored with "$" so that names such as
# "1234xpdb" or "1234.pdb.bak" are no longer picked up by accident.
pattern = r".*(\d{4}).*\.pdb$"

# Iterate over each file in the source folder
for filename in file_list:

    # Skip anything that does not follow the expected naming scheme
    match = re.match(pattern, filename)
    if match is None:
        continue

    # Extract the ID from the filename (named `sample_id` to avoid shadowing builtin `id`)
    sample_id = match.group(1)

    # Define the new filename
    new_filename = sample_id + ".pdb"

    # Construct the full source and destination paths
    source_path = os.path.join(source_folder, filename)
    destination_path = os.path.join(destination_folder, new_filename)

    # Copy the file to the destination folder with the new name
    shutil.copy(source_path, destination_path)

### Step 2: Reindexing & Scaling

This step focuses on reindexing and scaling a list of input MTZ files to a reference MTZ file using gemmi. 

**Reindexing:** The datasets provided may include samples from different space groups that describe the same physical crystal structure. To ensure comparability, we reindex each sample to a common indexing scheme by applying reindexing operators. 

**Scaling:** The samples are scaled to a reference dataset using a global anisotropic scaling factor by an analytical scaling method that determines the Debye-Waller Factor. The scaling process ensures that structure factor amplitudes are comparable across different datasets, accounting for variabilities such as differences in lattice orientations.

#### Usage

1. Import the required library, `valdo`.

2. Call the `reindex_files()` function from `valdo.reindex`. The `reindex_files()` function will enumerate possible reindexing operations for any space group and apply them to each input MTZ file. It will select the operation with the highest correlation with the reference dataset. The reindexed files will be saved in the specified output folder, following the same `##.mtz` naming convention.

    This function can be called with the following parameters:
    - `input_files`: List of paths to input MTZ files to be reindexed.
    - `reference_file`: Path to the reference MTZ file.
    - `output_folder`: Path to the folder where the reindexed MTZ files will be saved.
    - `columns`: A list containing the names of the columns in the dataset that represent the amplitude and the error column.

3. Create a `Scaler` object by providing the path to the reference MTZ file.

4. Call the `batch_scaling()` method of the `Scaler` object. The `batch_scaling()` method will apply the scaling process to each input MTZ file and save the scaled MTZ files in the specified output folder. Scaling metrics, such as least squares values and correlations, will be saved in the report file.

    This function can be called with the following parameters:
    - `mtz_path_list`: List of paths to input MTZ files to be scaled.
    - `outputmtz_path`: Path to the folder where the scaled MTZ files will be saved (optional, default is `./scaled_mtzs/`).
    - `reportfile`: Path to the file where scaling metrics will be saved (optional, default is `./scaling_data.json`).
    - `verbose`: Whether to display verbose information during scaling (optional, default is `True`).
    - `n_iter`: Number of iterations for the analytical scaling method (optional, default is `5`).


#### Reindexing Code

In [None]:
# Collect the original MTZ files in sorted order; these are the reindexing inputs

file_list = sorted(glob.glob(original_data_path + "*mtz"))

In [None]:
# Reindexes a list of input MTZ files to a reference MTZ file using gemmi
# The first file of the sorted list serves as the reference; results go to `reindexed_path`

valdo.reindex.reindex_files(input_files=file_list,
              reference_file=file_list[0],
              output_folder=reindexed_path,
              columns=[amplitude_col, error_col])

#### Scaling Code

In [None]:
# Collect the reindexed MTZ files in sorted order; these are the scaling inputs
file_list = sorted(glob.glob(reindexed_path + "*mtz"))

In [None]:
# Initiate the Scaler, file_list[0] serves as the reference
# (file_list is sorted, so this is the first reindexed file in lexicographic order)

scaler = valdo.Scaler(reference_mtz=file_list[0])

In [None]:
# Scales all datasets to the previously provided reference, writes a `metrics.pkl`
# NOTE: the return value is deliberately not bound to the name `metrics`, because
# that would shadow the `from sklearn import metrics` module imported at the top.

scaling_metrics = scaler.batch_scaling(mtz_path_list=file_list,
                                       outputmtz_path=scaled_path,
                                       verbose=False)

### Step 3: Normalization

This step involves normalizing the scaled structure factor amplitudes obtained in the previous step. The input is restricted to only those Miller indices present in the intersection of all datasets, and the VAE predicts structure factor amplitudes for all Miller indices in the union of all datasets.

Additionally, we standardize all the input data, such that the structure factor amplitudes for each Miller index in the union of all datasets have a mean of zero and a unit variance across datasets. 

#### Usage

1. Import the required library, `valdo.preprocessing`.

2. Find the intersection and union of the scaled datasets using the following functions:

   - `find_intersection()`: Finds the intersection of `amplitude_col` from multiple input MTZ files and saves the result to the specified output pickle file. Arguments include the following:

      - `input_files`: List of input MTZ file paths.
      - `output_path`: Path to save the output pickle file containing the intersection data.
      - `amplitude_col`: Name of the column in the dataset that represents the scaled amplitude (default is 'F-obs-scaled').

   - `find_union()`: Finds the union of `amplitude_col` from multiple input MTZ files and saves the result to the specified output pickle file. Arguments are the same as `find_intersection()`.

3. Generate the VAE input and output data using the `generate_vae_io()` function. This standardizes the intersection dataset using mean and standard deviation calculated from the union dataset. The standardized intersection becomes the VAE input, while the standardized union becomes the VAE output. Both the VAE input and output are saved to the specified folder. 

    This function can be called with the following parameters:

    - `intersection_path`: Path to the intersection dataset pickle file.
    - `union_path`: Path to the union dataset pickle file.
    - `io_folder`: Path to the output folder where the VAE input and output will be saved. Mean and standard deviation data calculated from the union dataset will also be saved in this folder as `union_mean.pkl` and `union_sd.pkl`.

#### Code

In this example, we remove samples with low `end_corr`. This ensures that our VAE is trained with high quality samples.

In [None]:
# Gather every scaled MTZ file, sorted, as candidate input and output for the VAE

file_list = sorted(glob.glob(scaled_path + "*mtz"))

The following cell removes samples whose `end_corr` is below `0.6` or missing (`NA`).

In [None]:
# Load the scaling metrics written next to the scaled files and name the columns
metrics_df = pd.read_pickle(scaled_path + "metrics.pkl")
metrics_df.columns = ['file', 'start_LS', 'start_corr', 'end_LS', 'end_corr']

# Rows whose final correlation is missing or below the 0.6 threshold
rejected = metrics_df['end_corr'].isnull() | (metrics_df['end_corr'] < 0.6)

# Rebuild the full paths so they can be compared against the globbed `file_list`
low_corr_files = [scaled_path + x + '.mtz' for x in metrics_df.loc[rejected, 'file']]

# Use a set for the membership test so filtering stays linear in the number of files
low_corr_lookup = set(low_corr_files)
file_list = [file for file in file_list if file not in low_corr_lookup]

The following cells generate the VAE input and output. 

In [None]:
# Creates the intersection pickle file at `intersection_path`
# This is the intersection of all the scaled files provided

valdo.preprocessing.find_intersection(input_files=file_list,
                  output_path=intersection_path,
                  amplitude_col=amplitude_scaled_col)

In [None]:
# Creates the union pickle file at `union_path`
# This is the union of all the scaled files provided

valdo.preprocessing.find_union(input_files=file_list,
           output_path=union_path,
           amplitude_col=amplitude_scaled_col)

In [None]:
# Generates VAE input and output data from the intersection and union datasets
# The resulting files are written into `vae_folder`

valdo.preprocessing.generate_vae_io(intersection_path=intersection_path,
                union_path=union_path,
                io_folder=vae_folder)

### Step 4: VAE Training

In this step, we train the VAE model using the provided VAE class.

#### Usage

1. Load the VAE input and output data that was generated in the previous step.

2. Initialize the VAE model with the desired hyperparameters. Tune-able hyperparameters include the following:
    - `n_dim_latent`: Dimensionality of the latent space (optional, default `1`).

    - `n_hidden_layers`: Number of hidden layers in the encoder and decoder. If an int is given, it will be applied to both the encoder and the decoder; if a list of length 2 is given, the first int is used for the encoder and the second for the decoder.

    - `n_hidden_size`: Number of units in hidden layers. If an int is given, it will be applied to all hidden layers in both encoder and decoder; otherwise, an array with length equal to the number of hidden layers can be given, the number of units will be assigned accordingly.

    - `activation` : Activation function for the hidden layers (optional, default `tanh`)

3. Split the data into training and validation sets. Randomly select a subset of indices for training and use the rest for validation.

4. Convert the data into PyTorch tensors.

5. Set up the optimizer for training.

6. Train the VAE model using the `train()` method. The training process involves minimizing the ELBO (Evidence Lower Bound) loss function, which consists of a Negative Log-Likelihood (NLL) term and a Kullback-Leibler (KL) divergence term. Arguments used in this function include:

    - `x_train`: Input data for training the VAE, a PyTorch tensor representing the VAE input data. 

    - `y_train`: Output data for training the VAE, a PyTorch tensor representing the VAE output data. 

    - `optim`: The optimizer used for training the VAE, a PyTorch optimizer object, such as `torch.optim.Adam`, that specifies the optimization algorithm and its hyperparameters, including the learning rate (`lr`).

    - `x_val`: Input data for validation during training. (optional, default is `None`).

    - `y_val`: Output data for validation during training. (optional, default is `None`).

    - `epochs`: The number of training epochs (epoch: a single pass through the data).

    - `batch_size`: The batch size used during training. If an integer is provided, the same batch size will be used for all epochs. If a list of integers is provided, it should have the same length as the number of epochs, and each value in the list will be used as the batch size for the corresponding epoch. Default is `256`.

    - `w_kl`: The weight of the Kullback-Leibler (KL) divergence term in the ELBO loss function. The KL divergence term encourages the latent distribution to be close to a prior distribution (usually a standard normal distribution). A higher value of `w_kl` will increase the regularization strength on the latent space. Default is `1.0`.

    **Note:** The VAE class internally keeps track of the training loss (`loss_train`) and its components (NLL and KL divergence) during each batch of training. These values can be accessed after training to monitor the training progress and performance. The `loss_train` attribute of the VAE object will be a list containing the training loss values for each batch during training. The `loss_names` attribute contains the names of the loss components: "Loss", "NLL", and "KL_div". These attributes are updated during training and can be used for analysis or visualization.

7. Save the trained VAE model for future use (optional).

#### Code

In [None]:
# Load the VAE I/O Files Generated
# Both arrays are read from `vae_folder`, where the previous step saved them

vae_input = np.load(vae_folder + 'vae_input.npy')
vae_output = np.load(vae_folder + 'vae_output.npy')

In [None]:
# Specify VAE Parameters
latent_dimension = 7

# Input/output widths are taken from the second axis of the loaded arrays
vae = valdo.VAE(n_dim_i = vae_input.shape[1],
      n_dim_o = vae_output.shape[1],
      n_dim_latent = latent_dimension,
      n_hidden_layers = [3, 6],
      n_hidden_size = 100,
      activation = torch.relu)

# Randomly select 1300 indices for training
# NOTE(review): 1300 is hard-coded; np.random.choice(..., replace=False) raises if
# there are fewer than 1300 samples, and no seed is set, so the split is not reproducible.
choice = np.random.choice(vae_input.shape[0], 1300, replace=False)
# Boolean mask over samples: True = training, False = validation
train_ind = np.zeros(vae_input.shape[0], dtype=bool)
train_ind[choice] = True
test_ind = ~train_ind

# Split the input and output data into training and validation sets
x_train, x_val = vae_input[train_ind], vae_input[test_ind]
y_train, y_val = vae_output[train_ind], vae_output[test_ind]

# Convert the data to torch tensors
x_train, x_val, y_train, y_val = torch.tensor(x_train), torch.tensor(x_val), torch.tensor(y_train), torch.tensor(y_val)

# Set up the optimizer and train the VAE
optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)
vae.train(x_train, y_train, optimizer, x_val, y_val, epochs=300, batch_size=100, w_kl=1.0)

# Save the trained VAE model
vae.save(vae_folder + 'trained_vae.pkl')

The following cells allow us to visualize the loss over epochs.

In [None]:
# Plot the loss over time
# Reload the saved model and turn its recorded training losses into an array

vae = valdo.VAE.load(vae_folder + 'trained_vae.pkl')
loss_array = np.array(vae.loss_train)

In [None]:
# One stacked panel per loss component; columns 0-2 of `loss_array` are plotted
# as the training curves and columns 3-5 as the matching validation curves.
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=[10,12])
ax = axs.reshape(-1)

components = ['Total Loss', 'Negative Log Likelihood', 'KL Divergence']

for column, (panel, component) in enumerate(zip(ax, components)):
    panel.plot(loss_array[:, column], label=component + ', Training')
    panel.plot(loss_array[:, column + 3], label=component + ', Validation')
    panel.set_xlabel("Steps")
    panel.legend()

plt.show()

### Steps 5 & 6: Reconstruction of "Apo" Data & Calculating Difference Maps

In this step, VAE outputs are re-scaled accordingly to recover the original scale, and differences in amplitudes between the original and reconstructed data are calculated. A `recons` and a `diff` column will be created for all datasets.

#### Usage

To perform the reconstruction, or re-scaling, the `rescale()` function can be called, providing the necessary arguments:

- `recons_path`: Path to the reconstructed output of the VAE in NumPy format.
- `intersection_path`: Path to the pickle file containing the intersection of all scaled datasets.
- `union_path`: Path to the pickle file containing the union data of all scaled datasets.
- `input_files`: List of input file paths. This list must be in the same order as the samples in `vae_input.npy` (i.e., the order used when building `intersection.pkl`).
- `info_folder`: Path to the folder containing files with the mean and SD used for standardization previously.
- `output_folder`: Path to the folder where the reconstructed data will be saved.
- `amplitude_col`: Column in the MTZ file that contains structure factor amplitudes to calculate the difference column.

#### Code

In [None]:
# Load trained VAE
# Reads the model saved at the end of Step 4

vae = valdo.VAE.load(vae_folder + 'trained_vae.pkl')

In [None]:
# Load input file and create a tensor

vae_input = np.load(vae_folder + 'vae_input.npy')
vae_input_tensor = torch.tensor(vae_input)

# Use the first GPU when one is available; otherwise stay on the CPU instead of
# failing on machines without CUDA.
# NOTE(review): the loaded VAE must live on the same device as its input - confirm
# how valdo.VAE.load places the model.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
vae_input_tensor = vae_input_tensor.to(device)

In [None]:
# Reconstruct the input file via VAE, convert to numpy, and save

# `vae_input_tensor` is already a tensor; re-wrapping it in torch.tensor() would
# only make an extra copy and trigger a UserWarning.
recons = vae.reconstruct(vae_input_tensor)
recons = recons.detach().cpu().numpy()

# np.save does not create missing folders (it appends ".npy" to the name itself)
os.makedirs(vae_reconstructed_folder, exist_ok=True)
np.save(vae_reconstructed_folder + 'recons', recons)

In [None]:
# Re-scale the reconstructed files accordingly and create the `diff` column
# `file_list` must still be the list (and order) used to build the VAE input.
# The amplitude column comes from the shared `amplitude_scaled_col` constant so it
# cannot drift from the name used when the intersection/union were generated.

valdo.preprocessing.rescale(recons_path=vae_reconstructed_folder + 'recons.npy',
            intersection_path=intersection_path,
            union_path=union_path,
            input_files=file_list,
            info_folder=vae_folder,
            output_folder=vae_reconstructed_folder,
            amplitude_col=amplitude_scaled_col)

### Steps 7 & 8: Gaussian Blurring & Searching for Blobs

**Note Regarding Phases:** In this section, phases are required for each dataset. You can obtain phases by completing refinement via PHENIX for each dataset, and utilizing those phases.

**Note Regarding Models:** In this section, models are also required for each dataset. These can also be obtained by refinement via PHENIX for each dataset, and they should be stored in a single folder, following the same numbering convention as the datasets (i.e. `##.pdb`).

We offer a command-line tool for automatic refinement using PHENIX. Based on our tests, starting with a single apo model yields satisfactory phases and models for the following real-space maps. You can find an example refine_drug.eff file in the notebook/ directory.

*Code Example:* `valdo.refine --pdbpath "xxx/xxx_apo.pdb" --mtzpath "xxx/*.mtz" --output "yyy/" --eff "xxx/refine_drug.eff"`

In this step, we aim to identify significant changes in electron density caused by ligand binding to a protein. By taking the absolute value of the electron density difference maps and applying Gaussian blurring, a new map is created with merged positive electron density blobs. The blurring process attempts to reduce noise. Blobs are then identified and characterized above a specified contour level and volume threshold.

#### Usage

To generate blobs from electron density maps, call the `generate_blobs()` function, which takes electron density map files and corresponding refined protein models as inputs. The function preprocesses the maps and identifies blobs above a specified contour level and volume threshold (the volume threshold is the default set by `gemmi`). The output is a DataFrame containing statistics for each identified blob, including peak value, score, centroid coordinates, volume, and radius. 

This function can be called with the following arguments:

- `input_files`: List of input file paths.
- `model_folder`: Path to the folder containing the refined models for each dataset (pdb format).
- `diff_col`: Name of the column holding the difference values (the `diff` column created in the previous step) in the input MTZ files.
- `phase_col`: Name of the column representing phase values in the input MTZ files.
- `output_folder`: Path to the output folder where the blob statistics DataFrame will be saved.
- `cutoff`: Blob cutoff value. Blobs with values below this cutoff will be ignored (optional, default is `5`).
- `negate`: Whether to negate the blob statistics (optional, default is `False`). Use True if there is interest in both positive and negative peaks, which is not typically of interest here due to the absolute value function applied to the map.
- `sample_rate`: Sample rate for generating the grid in the FFT process (optional, default is `3`).

#### Code

The following cell adds phases to our newly reconstructed datasets. These phases are copied from `../../pipeline/data/refined_mtzs/` which were generated via PHENIX.

In [None]:
# List of reconstructed mtz files without phases to add phases to

file_list = glob.glob(vae_reconstructed_folder + "*mtz")

# Files for which no readable refined MTZ (and therefore no phases) was found
no_phases_files = []

# write_mtz needs the target folder to exist
os.makedirs(vae_reconstructed_with_phases_folder, exist_ok=True)

# Phases here are copied from refinement

for file in tqdm(file_list):

    current = rs.read_mtz(file)

    try:
        phases_df = rs.read_mtz('../../pipeline/data/refined_mtzs/'+os.path.basename(file))

    # `Exception` rather than a bare `except`, so that interrupting the kernel
    # (KeyboardInterrupt) still stops the loop instead of being recorded as a
    # missing-phases file.
    except Exception:
        no_phases_files.append(file)
        continue

    # Copy both phase columns under the `refine_`-prefixed names
    # NOTE(review): presumably aligned on the shared Miller-index row labels - confirm
    current[phase_2FOFC_col] = phases_df['PH2FOFCWT']
    current[phase_FOFC_col] = phases_df['PHFOFCWT']

    current.write_mtz(vae_reconstructed_with_phases_folder + os.path.basename(file))

The following two cells perform Gaussian blurring and blob searching. For the blurring, the radius is set to 5 Å with `sigma = 5/3`.

In [None]:
# List of reconstructed mtz files (with phases) to identify blobs in
# (glob order is not sorted here)

file_list = glob.glob(vae_reconstructed_with_phases_folder + "*mtz")

In [None]:
# Function in valdo.blobs that generates a list of blobs
# Column names come from the shared constants defined at the top of the notebook
# (`diff_col` is 'diff', `phase_2FOFC_col` is 'refine_PH2FOFCWT'), so they stay in
# sync with the columns written in the previous steps.

valdo.blobs.generate_blobs(input_files=file_list,
               model_folder='../../pipeline/data/refined_models/',
               diff_col=diff_col,
               phase_col=phase_2FOFC_col,
               output_folder=blob_folder,
               cutoff=3.5)

### Step 9: Identifying Events

In this final step, the highest scoring blobs returned in the previous step can be analyzed individually. If the blob is plausibly a ligand, refinement with a ligand may be completed to determine whether or not the blob can be considered a "hit."

Blobs that are returned can be related to various other events, not just ligand binding. Examples may include ligand-induced conformational change (which would still indicate the presence of a ligand) or various other unrelated conformational changes, such as radiation damage or cysteine oxidation (as is seen in `pipeline.ipynb`).

In the following example, we have also included the evaluation of our method, via AUC, in this section.

#### Helper Functions

These functions will help us do various tasks. 

For example, we may want to remove blobs that are associated with `cys215` oxidation, as in PTP1B, the oxidation of `cys215` modulates the protein's activity.

Additionally, to evaluate our method, we have a function that tags each blob to identify whether or not it is a blob associated with a ligand. We can then calculate the AUC. 

Various other helpful functions are also included.

##### Finds Nearby Atoms

In [None]:
def find_nearby_atoms(centroid_dict, structure_path, sample_no=None, radius=3):

    """
    Collect the atoms of a PDB structure that lie near a given position.

    The structure is read with gemmi, a neighbor search built with `radius` is
    populated from its first model, and every atom mark returned for the centroid
    is turned into one row of the result.

    Args:
        centroid_dict (dict): Position to search around, with keys 'x', 'y' and 'z'.
        structure_path (str): Path to the structure file in PDB format.
        sample_no (str, optional): Identifier copied into the 'sample' column.
        radius (float, optional): Radius handed to the neighbor search. Default is 3.

    Returns:
        pandas.DataFrame: One row per atom found, with columns 'sample', 'chain',
        'seqid', 'residue', 'atom', 'element', 'coordx', 'coordy', 'coordz' and
        'dist' (distance between the centroid and the nearest periodic image of
        the atom). The frame is empty, without columns, when nothing is found.

    Example:
        centroid = {'x': 10.0, 'y': 20.0, 'z': 30.0}
        nearby_atoms = find_nearby_atoms(centroid, './data/structure.pdb', 'S001', radius=5)
    """

    structure = gemmi.read_pdb(structure_path)
    model = structure[0]
    cell = structure.cell

    search = gemmi.NeighborSearch(model, cell, radius).populate()
    centroid = gemmi.Position(centroid_dict["x"], centroid_dict["y"], centroid_dict["z"])

    def describe(mark):
        # Resolve the mark to its chain/residue/atom and measure how far the
        # nearest periodic image of that atom is from the centroid.
        cra = mark.to_cra(model)
        atom = cra.atom
        nearest = cell.find_nearest_pbc_image(centroid, atom.pos, mark.image_idx)
        return {
            "sample": sample_no,
            "chain": cra.chain.name,
            "seqid": cra.residue.seqid.num,
            "residue": cra.residue.name,
            "atom": atom.name,
            "element": atom.element.name,
            "coordx": atom.pos.x,
            "coordy": atom.pos.y,
            "coordz": atom.pos.z,
            "dist": nearest.dist(),
        }

    return pd.DataFrame([describe(mark) for mark in search.find_atoms(centroid)])

##### Tag Blobs near Cys215

In [None]:
def tag_cys_215_blobs(df, structure_path, radius=3):

    """
    Flag every blob that has an atom of residue number 215 close to its centroid.

    For each row, the model `structure_path + sample + '.pdb'` is searched around
    the blob centroid ('cenx', 'ceny', 'cenz'). Only the residue sequence number
    is compared; the residue name is not checked.

    Args:
        df (pandas.DataFrame): Blob table with 'sample', 'cenx', 'ceny' and 'cenz' columns.
        structure_path (str): Prefix (folder path including the trailing separator)
            to which '<sample>.pdb' is appended.
        radius (int, optional): Search radius in Angstroms. Default is 3.

    Returns:
        pandas.DataFrame: The same DataFrame object, modified in place, with a new
        'cys215' column holding 1 (residue 215 nearby) or 0.
    """

    def check_blob_for_cys(row):

        """
        Return 1 if any atom found near the blob belongs to residue 215, else 0.
        """

        sample = row["sample"]
        centroid = {"x": row['cenx'], "y": row['ceny'], "z": row['cenz']}
        neighbours = find_nearby_atoms(centroid, structure_path + sample + '.pdb', sample, radius)

        # An empty result has no 'seqid' column, so the length test must come first
        if len(neighbours) >= 1 and 215 in set(neighbours['seqid']):
            return 1
        return 0

    tqdm.pandas()
    df['cys215'] = df.progress_apply(check_blob_for_cys, axis=1)

    return df

##### Tag Blobs near LIG Atoms

In [None]:
def tag_lig_blobs(df, structure_path):

    """
    Flag every blob that has an atom of a residue named 'LIG' within its own radius.

    Rows whose 'bound' value is 0 are tagged 0 without reading any model. For the
    others, the model `structure_path + sample + '.pdb'` is searched around the
    blob centroid using the blob's 'radius' value as search radius.

    Args:
        df (pandas.DataFrame): Blob table with 'bound', 'sample', 'cenx', 'ceny',
            'cenz' and 'radius' columns.
        structure_path (str): Prefix (folder path including the trailing separator)
            to which '<sample>.pdb' is appended.

    Returns:
        pandas.DataFrame: The same DataFrame object, modified in place, with a new
        'ligand' column holding 1 (ligand atom nearby) or 0.

    """
    def check_blob_for_lig(row):

        """
        Return 1 if a 'LIG' residue atom lies within the blob radius, else 0.
        """

        # Samples without a bound ligand cannot yield a ligand blob
        if row["bound"] == 0:
            return 0

        sample = row["sample"]
        centroid = {"x": row['cenx'], "y": row['ceny'], "z": row['cenz']}
        neighbours = find_nearby_atoms(centroid, structure_path + sample + '.pdb', sample, row['radius'])

        # An empty result has no 'residue' column, so the length test must come first
        if len(neighbours) >= 1 and 'LIG' in set(neighbours['residue']):
            return 1
        return 0

    tqdm.pandas()
    df['ligand'] = df.progress_apply(check_blob_for_lig, axis=1)

    return df

##### Validate Fractional Coordinates

In [None]:
def valid_fractional_coords(coords):

    """
    Wrap fractional coordinates into the half-open range [0, 1).

    Whole unit-cell translations are removed from every component, so that e.g.
    1.0 maps to 0.0 and -0.25 maps to 0.75. The input is not modified.

    Args:
        coords (list or numpy.ndarray): The input fractional coordinates.

    Returns:
        numpy.ndarray: Float array of the wrapped coordinates, each in [0, 1).

    """

    valid_coords = np.mod(np.asarray(coords, dtype=float), 1.0)

    # For tiny negative inputs the wrapped value rounds up to exactly 1.0 in
    # floating point; fold that back to 0.0 so the upper bound stays exclusive.
    valid_coords[valid_coords >= 1.0] = 0.0

    return valid_coords

##### Determine Locations

In [None]:
# Function to fractionalize coordinates and find all symmetry-related cartesian points

def determine_locations(row, folder):

    """
    Converts coordinates to fractional form and determines all symmetry-related cartesian points for the given row.

    Args:
        row (pandas.Series): A row of the DataFrame representing a blob; must provide
            'sample', 'cenx', 'ceny' and 'cenz'.
        folder (str): The path to the folder containing the mtz files, including the
            trailing separator ('<sample>.mtz' is appended directly).

    Returns:
        pandas.Series: A pandas Series with the entries 'fractional' (in-cell fractional
        coordinates of the blob), 'all_possible_frac' (symmetry-related fractional
        coordinates, sorted by their first component) and 'all_possible_cart' (the
        same points as cartesian numpy arrays). All three entries are NaN when the
        sample's mtz file does not exist.

    """

    # find mtz file for sample number
    mtz_file = folder + row['sample'] + '.mtz'

    # A concatenated path is never None, so test for the file itself instead
    if not os.path.isfile(mtz_file):
        return pd.Series({'fractional': np.nan, 'all_possible_frac': np.nan, 'all_possible_cart': np.nan})

    # read in mtz file
    sample_file = rs.read_mtz(mtz_file)
    cell = sample_file.cell

    # fractionalize coordinates using move2cell
    frac_coords = move2cell([row['cenx'], row['ceny'], row['cenz']], cell)

    # apply every symmetry operation of the space group and wrap the result into the cell
    all_ops = list(sample_file.spacegroup.operations().sym_ops)
    all_possible_frac = sorted(
        (valid_fractional_coords(op.apply_to_xyz(frac_coords)) for op in all_ops),
        key=lambda x: x[0],
    )

    # orthogonalize fractional coordinates
    all_possible_cart = []
    for elt in all_possible_frac:
        position = cell.orthogonalize(gemmi.Fractional(*elt))
        all_possible_cart.append(np.array([position.x, position.y, position.z]))

    return pd.Series({'fractional': frac_coords, 'all_possible_frac': all_possible_frac, 'all_possible_cart': all_possible_cart})

##### Move Points to the Unit Cell

In [None]:
def move2cell(cartesian_coordinates, unit_cell, fractionalize=True):

    '''
    Move your points into a unitcell with translational vectors

    Parameters
    ----------
    cartesian_coordinates: array-like
        [3] or [N_points, 3], cartesian positions of points you want to move

    unit_cell, gemmi.UnitCell
        A gemmi unitcell instance; its `fractionalization_matrix` and
        `orthogonalization_matrix` are used

    fractionalize: boolean, default True
        If True, output coordinates will be fractional; Or will be cartesians

    Returns
    -------
    numpy.ndarray, coordinates inside the unitcell, same shape as the input
    '''
    o2f_matrix = np.array(unit_cell.fractionalization_matrix)
    frac_pos = np.dot(cartesian_coordinates, o2f_matrix.T)

    # `% 1` already yields non-negative values for any input shape, so no
    # per-element sign fix-up is needed (the old scalar loop failed on 2-D input).
    frac_pos_incell = frac_pos % 1

    # Tiny negative inputs wrap to exactly 1.0 in floating point; fold them to 0.0
    # so every component stays strictly inside the cell.
    frac_pos_incell[frac_pos_incell >= 1.0] = 0.0

    if fractionalize:
        return frac_pos_incell

    f2o_matrix = np.array(unit_cell.orthogonalization_matrix)
    return np.dot(frac_pos_incell, f2o_matrix.T)

##### Remove Duplicate Blobs

In [None]:
def mark_duplicates(blobs_df):

    """
    Flags blobs that coincide with the next-stronger blob of the same sample.

    Blobs are ordered by descending 'peak'. Within each sample, every blob is
    compared only with the blob ranked directly above it; it is flagged when any
    pair of their candidate cartesian positions is closer than 1.

    Args:
        blobs_df (pandas.DataFrame): Blob table with 'sample', 'peak', and
            'all_possible_cart' columns. It is not modified.

    Returns:
        pandas.DataFrame: A copy sorted by descending 'peak' with a 'duplicate'
            column (1 for flagged blobs, 0 otherwise).

    """

    def check_euclidean_distance(list1, list2):

        """
        Reports whether any point of list1 lies closer than 1 to any point of list2.

        Args:
            list1 (list): Cartesian coordinates of the first blob's candidate positions.
            list2 (list): Cartesian coordinates of the second blob's candidate positions.

        Returns:
            bool: True if at least one pairwise distance is below 1, else False.

        """

        return any(
            math.sqrt((b[0] - a[0])**2 + (b[1] - a[1])**2 + (b[2] - a[2])**2) < 1
            for a in list1
            for b in list2
        )

    blobs_df = blobs_df.sort_values(by='peak', ascending=False)
    blobs_df['duplicate'] = 0  # every blob starts out as non-duplicate

    for _, sample_blobs in blobs_df.groupby('sample'):
        candidates = sample_blobs['all_possible_cart'].tolist()
        labels = sample_blobs.index
        # Walk consecutive (stronger, weaker) pairs; the weaker blob gets the flag.
        # Single-blob samples produce no pairs and are left untouched.
        for label, stronger, weaker in zip(labels[1:], candidates, candidates[1:]):
            if check_euclidean_distance(stronger, weaker):
                blobs_df.at[label, 'duplicate'] = 1

    return blobs_df

#### Tagging Blobs

In this section, we tag and filter the blobs. We remove...

- blobs that are duplicates (we occasionally have duplicate blobs due to an issue with gemmi's ASU mask)
- blobs associated with the oxidation of `cys215`
- blobs that belong to low quality samples (high r factors in refinement)
- blobs that belong to samples with inconsistent data (in particular, Helen Ginn lists a few samples as hits without including a ligand in their bound state).

In [None]:
# Load the per-blob statistics table from the blob folder
blob_df = pd.read_pickle(blob_folder + 'blob_stats.pkl')

Tag Samples that are Bound (1 if bound, 0 otherwise)

In [None]:
# Read the bound sample IDs, one per line, stripping surrounding whitespace
with open("../../bound_sample_ids.txt") as f:
    bound_samples = {line.strip() for line in f}

# Flag each blob with 1 when its sample ID is listed as bound, 0 otherwise
blob_df["bound"] = blob_df["sample"].apply(lambda sample_id: int(sample_id in bound_samples))

Tag Samples within 3A of a Cys215 Atom (1 if within, 0 otherwise)

In [None]:
# tag_cys_215_blobs is defined elsewhere in this notebook; presumably it adds the
# 'cys215' column read by the filtering step below (confirm against its definition)
blob_df = tag_cys_215_blobs(blob_df, '../../pipeline/data/refined_models/')

Tag Samples within `r` of a known LIG atom (1 if yes, 0 otherwise)

In [None]:
# tag_lig_blobs is defined elsewhere in this notebook; presumably it adds the
# 'ligand' label that plot_roc_blob_stats uses as ground truth (confirm against its definition)
blob_df = tag_lig_blobs(blob_df, '../../pipeline/data/bound_models/')

Tag Blobs that are Duplicates of Other Blobs (Patch for Gemmi's ASU Mask Issue)

In [None]:
# For every blob (row), compute its in-cell fractional coordinates plus the fractional and
# cartesian positions produced by each space-group symmetry operation; determine_locations
# returns them as a Series whose three keys match the columns assigned here
blob_df[['fractional', 'all_possible_frac', 'all_possible_cart']] = blob_df.apply(determine_locations, args=('../../pipeline/vae/reconstructed-phases/',), axis=1)

In [None]:
# Marks a blob as a duplicate when any of its symmetry-equivalent positions lies within 1A
# of the blob ranked immediately above it by peak in the same sample.
# The returned table is sorted by descending peak.
blob_df = mark_duplicates(blob_df)

Tag Blobs Belonging to Samples with High R Factors

In [None]:
# In this case, no blobs are removed

# Collect the zero-padded (4-character) IDs of datasets whose final Rfree exceeds 0.4,
# then flag every blob whose sample is among them
r_factors = pd.read_csv('../../pipeline/data/refine_stats.csv')[['data_id', 'Rfree_final']]
high_r_factors = r_factors[r_factors['Rfree_final'] > 0.4]['data_id'].astype(str).str.zfill(4)
blob_df['high_r_factor'] = blob_df['sample'].isin(high_r_factors.tolist())

In [None]:
# Save the fully tagged blob table so that filtering can start from this file
blob_df.to_pickle('../../pipeline/vae/blobs/blob_stats_tagged.pkl')

Filter Blobs (Remove Cys215, High R Factor, Duplicates, and more)

In [None]:
# Reload the tagged blob table written by the tagging section above
blob_df = pd.read_pickle('../../pipeline/vae/blobs/blob_stats_tagged.pkl')

In [None]:
# Drop every blob whose sample is on the list of hits for which Helen Ginn
# does not include a bound state model
# In this case, there are no blobs.

hg_no_lig = ['0060', '1429', '1733', '1791', '0225', '0432', '0710']
blob_df = blob_df.loc[~blob_df['sample'].isin(hg_no_lig)]

In [None]:
# Keep only blobs that are not Cys215 related, not from a high R factor sample,
# and not flagged as duplicates

blob_df = blob_df.loc[
    blob_df['cys215'].eq(0)
    & blob_df['high_r_factor'].eq(0)
    & blob_df['duplicate'].eq(0)
]

In [None]:
# Save the filtered blob table; the ROC section below reads this file
blob_df.to_pickle('../../pipeline/vae/blobs/filtered_blob_stats_tagged.pkl')

#### Generate AUC Curve

In this section, we take the list of filtered blobs and 1) determine the AUC and 2) plot the ROC curve. We use `score` as the metric by which we classify blobs – a higher blob score means a higher likelihood that the blob represents ligand-binding.

In [None]:
def plot_roc_blob_stats(path, name=''):
    
    """
    Plots the ROC curve of blob 'score' against the 'ligand' label and saves it as a PDF.

    Args:
        path (str): Path to a pickled blob DataFrame with 'ligand', 'score', and 'sample' columns.
        name (str): Passed to the ROC display as its estimator name. Defaults to ''.

    Returns:
        None. Prints the number of blobs and of unique samples, and writes
        'roc_curve.pdf' into the directory that contains `path`.

    """
    
    blob_df = pd.read_pickle(path)
    
    # create ROC curve
    fpr, tpr, _ = metrics.roc_curve(blob_df["ligand"], blob_df["score"], pos_label=1)
    roc_auc = metrics.auc(fpr, tpr)
    # named roc_display so the notebook's own `display` helper is not shadowed
    roc_display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name)
    roc_display.plot()
    
    print("Total Number of Blobs:", len(blob_df))
    print("Total Number of Unique Samples:", len(blob_df.drop_duplicates(subset='sample')))
    
    # os.path.join keeps the output relative when `path` has no directory part;
    # plain string concatenation would have produced '/roc_curve.pdf' (filesystem root)
    plt.savefig(os.path.join(os.path.dirname(path), 'roc_curve.pdf'))

In [None]:
# Plot the ROC curve for the filtered blobs; roc_curve.pdf is written next to the pickle
plot_roc_blob_stats('../../pipeline/vae/blobs/filtered_blob_stats_tagged.pkl')

In [None]:
blob_df = pd.read_pickle('../../pipeline/vae/blobs/filtered_blob_stats_tagged.pkl')