In [10]:
# to start, let's try to implement a REINFORCE algorithm (policy gradient)
# we only need some sort of featurizer, 
# a policy graph network, 
# and the reinforcement learning loop

In [11]:
# Note: code borrowed from Gaeun.
import os 
import yaml

# Root of the shared antigen-prediction project tree. Defaults to the original
# cluster location; set the ANTIGEN_PREDICTION_ROOT environment variable to run
# the notebook against a copy of the data somewhere else.
_antigen_prediction_root = os.environ.get(
    "ANTIGEN_PREDICTION_ROOT", "/project/liulab/gkim/antigen_prediction"
).rstrip("/")

# Directory holding one YAML file per structure.
yamls_path = f"{_antigen_prediction_root}/eval_boltz_on_sabdab/all_yaml_outdir"
# Directory holding the matching PDB files.
pdbs_path = f"{_antigen_prediction_root}/data/renumbered_sabdab_pdb_files/pdb_files"

def get_chain_info_from_pdb(pdb_path, yaml_path):
    """Read heavy/light chain IDs and sequences, plus the antigen sequence, from a YAML file.

    Despite the name, nothing is parsed from the PDB file: ``pdb_path`` is only
    checked for existence so entries without a structure file are skipped.

    Layout assumed for ``yaml_data['sequences']`` (a list of
    ``{'protein': {'id': ..., 'sequence': ...}}`` entries):

    * exactly 2 entries   -> ``[heavy, light]``
    * more than 2 entries -> ``[antigen, heavy, light, ...]``
    * any other length    -> chain fields stay ``None``

    The antigen sequence is taken from the first entry of a non-empty top-level
    ``'antigen'`` list when there is one; otherwise, for the more-than-2 layout,
    from the first entry of ``'sequences'``; otherwise it is ``None``.

    Args:
        pdb_path: Path of the PDB file that must exist for the entry to be used.
        yaml_path: Path of the YAML file to read.

    Returns:
        ``(h_chain, l_chain, h_seq_yaml, l_seq_yaml, antigen)``. All five are
        ``None`` when either file is missing or the YAML cannot be read or does
        not have the expected structure (a message is printed in the YAML cases).
    """
    nothing = (None, None, None, None, None)

    # No structure file: skip quietly.
    if not os.path.exists(pdb_path):
        return nothing

    if not os.path.exists(yaml_path):
        print(f"No YAML file found at {yaml_path}")
        return nothing

    try:
        with open(yaml_path, 'r') as f:
            yaml_data = yaml.safe_load(f)

        h_chain = None
        l_chain = None
        h_seq_yaml = None
        l_seq_yaml = None
        antigen = None

        heavy_entry = None
        light_entry = None
        if 'sequences' in yaml_data and isinstance(yaml_data['sequences'], list):
            sequences = yaml_data['sequences']
            if len(sequences) == 2:
                # Antibody-only file: first is heavy, second is light.
                heavy_entry, light_entry = sequences
            elif len(sequences) > 2:
                # Multimer file: antigen first, then heavy, then light.
                heavy_entry, light_entry = sequences[1], sequences[2]
                # Read the antigen defensively so an unexpected first entry
                # does not throw away otherwise valid chain information.
                first = sequences[0]
                if isinstance(first, dict) and isinstance(first.get('protein'), dict):
                    antigen = first['protein'].get('sequence')

        if heavy_entry is not None:
            h_chain = heavy_entry['protein']['id']
            l_chain = light_entry['protein']['id']
            h_seq_yaml = heavy_entry['protein']['sequence']
            l_seq_yaml = light_entry['protein']['sequence']

        # An explicit, non-empty 'antigen' list takes precedence. The emptiness
        # check avoids an IndexError that would otherwise discard the chain info.
        if (
            'antigen' in yaml_data
            and isinstance(yaml_data['antigen'], list)
            and yaml_data['antigen']
        ):
            antigen = yaml_data['antigen'][0]['protein']['sequence']

        return h_chain, l_chain, h_seq_yaml, l_seq_yaml, antigen

    except Exception as e:
        # Best effort by design: one malformed file must not stop a batch run.
        print(f"Error reading YAML file for {yaml_path}: {e}")
        return nothing

In [12]:
import pandas as pd
# Summary table, tab-separated, read from the notebook's working directory.
# read_table uses a tab delimiter by default.
sd_pd = pd.read_table("sabdab_summary_all.tsv")

In [214]:
from Bio.PDB.Polypeptide import three_to_index

def featurizer(heavy_chain, ag_chain, residues, dist_matrix):
    """Build node and edge tensors for the heavy-chain/antigen contact graph.

    Background on the graph representation:
    https://towardsdatascience.com/graph-convolutional-networks-introduction-to-gnns-24b3f60d6c95/

    Args:
        heavy_chain: Iterable of heavy-chain residues (e.g. a Bio.PDB chain).
            Only residues containing a 'CA' atom are counted, in iteration
            order; that ordering must be the row ordering of ``dist_matrix``.
        ag_chain: Same, for the antigen chain / columns of ``dist_matrix``.
        residues: Pair ``(heavy_positions, ag_positions)`` of equal-length
            integer sequences, such as the output of ``np.unravel_index``. The
            j-th entries are the row and column of ``dist_matrix`` for edge j.
            Positions may repeat.
        dist_matrix: 2-D array indexed as ``[heavy_position, ag_position]``.

    Returns:
        node_features: Float tensor of shape (num_nodes, 2). Column 0 is 0 for
            heavy-chain nodes and 1 for antigen nodes; column 1 is the
            ``three_to_index`` code of the residue name. Heavy nodes come
            first, then antigen nodes, each in order of first appearance.
        edge_connections: Integer tensor of shape (num_edges, 2), one
            ``(heavy_node, antigen_node)`` row per edge. This is the transpose
            of the (2, num_edges) layout described in the article above.
        edge_features: Float tensor of shape (num_edges, 1) holding the
            ``dist_matrix`` entry of each edge's residue pair.

    Raises:
        ValueError: If the number of CA-bearing residues in either chain does
            not match the corresponding dimension of ``dist_matrix``.
    """
    # Matrix positions count only residues that have a CA atom, so rebuild
    # that same ordering here. Indexing a Bio.PDB chain directly (chain[pos])
    # would look the value up as a residue id, not as a matrix position.
    heavy_by_pos = [res for res in heavy_chain if 'CA' in res]
    ag_by_pos = [res for res in ag_chain if 'CA' in res]
    if (len(heavy_by_pos), len(ag_by_pos)) != tuple(dist_matrix.shape):
        raise ValueError(
            f"dist_matrix has shape {tuple(dist_matrix.shape)} but the chains have "
            f"{len(heavy_by_pos)} heavy and {len(ag_by_pos)} antigen residues with a CA atom"
        )

    # Plain Python ints hash and index uniformly regardless of the input type.
    heavy_positions = [int(p) for p in residues[0]]
    ag_positions = [int(p) for p in residues[1]]

    # Deduplicate while maintaining order; each distinct residue becomes one node.
    heavy_unique = list(dict.fromkeys(heavy_positions))
    ag_unique = list(dict.fromkeys(ag_positions))
    heavy_pos_to_node = {pos: node for node, pos in enumerate(heavy_unique)}
    ag_pos_to_node = {pos: len(heavy_unique) + node for node, pos in enumerate(ag_unique)}

    node_features = torch.zeros(len(heavy_unique) + len(ag_unique), 2)
    # Heavy-chain nodes are flagged 0 and antigen nodes are flagged 1.
    # NOTE(review): three_to_index presumably fails on non-standard residue
    # names; confirm how those should be encoded.
    for pos, node in heavy_pos_to_node.items():
        node_features[node, 0] = 0
        node_features[node, 1] = three_to_index(heavy_by_pos[pos].get_resname())
    for pos, node in ag_pos_to_node.items():
        node_features[node, 0] = 1
        node_features[node, 1] = three_to_index(ag_by_pos[pos].get_resname())

    # One edge per (heavy, antigen) pair, in the order the pairs were given.
    edge_pairs = list(zip(heavy_positions, ag_positions))
    edge_connections = torch.tensor(
        [[heavy_pos_to_node[hp], ag_pos_to_node[ap]] for hp, ap in edge_pairs],
        dtype=torch.long,
    ).reshape(-1, 2)

    # The distance is looked up with the pair's own matrix positions; node
    # indices are a different index space and must not be used here.
    edge_features = torch.tensor(
        [[float(dist_matrix[hp, ap])] for hp, ap in edge_pairs],
        dtype=torch.float32,
    ).reshape(-1, 1)

    return node_features, edge_connections, edge_features

In [215]:
# i: counter initialised to zero (not used by the loop in this cell).
# k: number of closest heavy-chain/antigen residue pairs kept per structure.
i, k = 0, 12

from Bio.PDB import PDBList, PDBParser, Select, PDBIO
import numpy as np
from tqdm import tqdm 
import torch

parser = PDBParser(QUIET=True)

# Featurized graph per structure name: (node_features, edge_connections, edge_features).
# Collected in a dict so earlier structures are not lost when the loop moves on.
graphs = {}
# Structure name -> reason it was skipped, for inspection after the run.
skipped = {}

for yaml_file in tqdm(os.listdir(yamls_path), desc="Processing YAML files"):
    yaml_path = os.path.join(yamls_path, yaml_file)
    name = yaml_file.split('.')[0]
    pdb_file = name + '.pdb'
    pdb_path = os.path.join(pdbs_path, pdb_file)
    # Reuse Gaeun's helper: it already maps a name to the files downloaded on the server.
    h, l, _, _, _ = get_chain_info_from_pdb(pdb_path, yaml_path)
    if h is None or l is None:
        # The helper returns None when a file is missing or unreadable, or when
        # the YAML does not list both chains.
        skipped[name] = "no heavy/light chain info"
        continue

    row = sd_pd[(sd_pd["pdb"] == name) & (sd_pd["Hchain"] == h) & (sd_pd["Lchain"] == l)]
    # .item() only works on exactly one match; zero or several would raise.
    if len(row) != 1:
        skipped[name] = f"{len(row)} matching rows in the summary table"
        continue
    ag = row["antigen_chain"].item()  # the antigen chain ID for this antibody
    if not isinstance(ag, str):
        # A missing value in the table does not come back as a string.
        skipped[name] = "no antigen chain listed"
        continue

    structure = parser.get_structure(name, pdb_path)
    model = structure[0]
    if h not in model or ag not in model:
        # NOTE(review): this also skips rows whose antigen_chain cell is not a
        # single chain ID; confirm the table format and split if needed.
        skipped[name] = f"chain {h!r} or {ag!r} not found in the structure"
        continue
    heavy_chain = model[h]
    antigen_chain = model[ag]

    # C-alpha coordinates; residues without a CA atom are left out, so matrix
    # rows/columns are positions within these filtered lists.
    heavy_coords = np.array([res['CA'].coord for res in heavy_chain if 'CA' in res])
    antigen_coords = np.array([res['CA'].coord for res in antigen_chain if 'CA' in res])
    if heavy_coords.size == 0 or antigen_coords.size == 0:
        skipped[name] = "a chain has no CA atoms"
        continue

    # Pairwise heavy-vs-antigen CA distances, shape (n_heavy, n_antigen).
    dist_matrix = np.linalg.norm(
        heavy_coords[:, np.newaxis, :] - antigen_coords[np.newaxis, :, :],
        axis=-1
    )

    # https://numpy.org/devdocs/reference/generated/numpy.argpartition.html
    # Only partition out the k smallest distances instead of sorting everything.
    flat_distances = dist_matrix.flatten()
    if k < flat_distances.size:
        bottom_k = np.argpartition(flat_distances, k)[:k]
    else:
        # argpartition rejects kth >= size; with this few pairs, keep them all.
        bottom_k = np.arange(flat_distances.size)
    # Turn flat indices back into (heavy_position, antigen_position) index arrays.
    bottom_k_indices = np.unravel_index(bottom_k, dist_matrix.shape)

    node_features, edge_connections, edge_features = featurizer(
        heavy_chain, antigen_chain, bottom_k_indices, dist_matrix
    )
    graphs[name] = (node_features, edge_connections, edge_features)


Processing YAML files:   0%|          | 0/6920 [00:00<?, ?it/s]

(217, 6)
tensor([[ 6.2841],
        [ 6.6791],
        [12.4140],
        [ 9.9115],
        [ 8.2325],
        [14.7043],
        [10.7663],
        [ 7.4657],
        [ 6.9076],
        [ 8.6688],
        [ 9.2591],
        [ 8.5089]])





Exception: 