In [1]:
from utils.parse_pdb import align_pdb, open_pdb, PDBError, get_pdb_file
import os
import boto3
import pickle
from tqdm import tqdm
from p_tqdm import p_map
import sidechainnet as scn
import numpy as np
from rcsbsearch import TextQuery, Attr
import subprocess

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def visualize(id):
    """
    Build a 3Dmol view of a pickled entry stored at `./data/pdb/{id}.pickle`.

    For every chain in the entry, backbone (`crd_bb`) and side-chain (`crd_sc`)
    coordinates are joined along axis 1 and flattened to rows of 3 values; the
    chains are then stacked in iteration order together with their sequences.

    Parameters
    ----------
    id : str
        name of the pickle file (without extension) inside `./data/pdb`

    Returns
    -------
    the object produced by `scn.StructureBuilder(...).to_3Dmol()`
    """
    # NOTE(review): pickle.load executes arbitrary code; only open trusted files.
    with open(f"./data/pdb/{id}.pickle", "rb") as handle:
        entry = pickle.load(handle)
    chain_records = [entry[chain_key] for chain_key in entry]
    flattened = [
        np.concatenate([record["crd_bb"], record["crd_sc"]], axis=1).reshape((-1, 3))
        for record in chain_records
    ]
    full_sequence = ""
    for record in chain_records:
        full_sequence += record["seq"]
    all_coordinates = np.concatenate(flattened, 0)
    builder = scn.StructureBuilder(full_sequence, all_coordinates)
    return builder.to_3Dmol()

In [3]:
from collections import defaultdict

def get_log_stats(log_file):
    """
    Print how often each `<<<`-prefixed tag occurs in a log file.

    A tag is the part of a line before its first `:`; only lines starting
    with `<<<` are counted. Tags are printed as `tag: count`, most frequent
    first (ties keep the order of first appearance).

    Parameters
    ----------
    log_file : str
        path to the log file to summarize
    """
    with open(log_file, "r") as handle:
        tags = [row.split(':')[0] for row in handle.readlines() if row.startswith("<<<")]
    counts = defaultdict(int)
    for tag in tags:
        counts[tag] = counts[tag] + 1
    # sorted() is stable, so equal counts stay in insertion order
    ranked = sorted(counts.items(), key=lambda pair: pair[1], reverse=True)
    for tag, occurrences in ranked:
        print(f'{tag}: {occurrences}')

In [4]:
def get_unknown_stats(log_file):
    """
    Group the ids of `<<< Unknown` log entries by the error text that follows them.

    An entry starts at a line beginning with `<<< Unknown`; its id is the text
    after the last `:` on that line. Every following line up to the next line
    starting with `<<<` (or the end of the file) is the entry's error text.
    Error texts starting with "Could not download" are collapsed into the
    single key "Could not download PDB".

    Results are printed as `error: [ids]`, largest group first.

    Parameters
    ----------
    log_file : str
        path to the log file to summarize
    """
    stats = {}

    def _record(error_text, entry_id):
        # Store one finished entry under its (normalized) error text.
        if error_text.startswith("Could not download"):
            error_text = "Could not download PDB"
        stats.setdefault(error_text, []).append(entry_id)

    error = None  # error text collected so far; None when not inside an Unknown entry
    pdb_id = None
    with open(log_file, "r") as f:
        for line in f:
            if line.startswith("<<<"):
                # Any `<<<` line closes the entry in progress, including another
                # `<<< Unknown` line (previously that entry was silently dropped).
                if error is not None:
                    _record(error, pdb_id)
                    error = None
                if line.startswith("<<< Unknown"):
                    error = ""
                    pdb_id = line.split(":")[-1].strip()
            elif error is not None:
                error += line
    # An entry that runs to the end of the file has no closing `<<<` line.
    if error is not None:
        _record(error, pdb_id)

    for key in sorted(stats, key=lambda k: len(stats[k]), reverse=True):
        print(f'{key}: {stats[key]}')

In [5]:
# NOTE(review): `get_pdb_file` is already imported in the first cell; this re-import is redundant.
from utils.parse_pdb import get_pdb_file
# Resource-level handle to the S3 bucket named "pdbsnapshots".
bucket = boto3.resource('s3').Bucket("pdbsnapshots")

In [62]:
from torch.utils.data import Dataset
import torch
import random

class ProteinDataset(Dataset):
    """
    Dataset to load BestProt data

    Saves the model input tensors as pickle files in `features_folder`. When `clustering_dict_path` is provided,
    at each iteration a random biounit from a cluster is sampled.

    Returns dictionaries with the following keys and values (all values except `'chain_id'` are `torch` tensors):
    - `'X'`: 3D coordinates of N, C, Ca, O (shape `(total_L, 4, 3)`),
    - `'S'`: sequence indices (shape `(total_L)`),
    - `'mask'`: residue mask (0 where coordinates are missing, 1 otherwise) (shape `(total_L)`),
    - `'residue_idx'`: residue indices (from 0 to length of sequence, +100 where chains change) (shape `(total_L)`),
    - `'chain_encoding_all'`: chain indices (shape `(total_L)`),
    - `'chain_id'`: the chain to mask (a random chain index without clustering, otherwise the chain id
      stored in the sampled cluster entry).
    """

    def __init__(
        self,
        dataset_folder,
        features_folder,
        clustering_dict_path=None,
        max_length=100,
        rewrite=False,
    ):
        """
        Parameters
        ----------
        dataset_folder : str
            the path to the folder with BestProt format input files (assumes that files are named {biounit_id}.pickle)
        features_folder : str
            the path to the folder where the ProteinMPNN features will be saved (created if it does not exist)
        clustering_dict_path : str, optional
            path to the pickled clustering dictionary (keys are cluster ids, values are sequences of
            (biounit id, chain id) tuples)
        max_length : int, default 100
            entries with total length of chains larger than `max_length` are meant to be disregarded
            (NOTE(review): this filter is not applied anywhere in this class yet)
        rewrite : bool, default False
            if `True`, feature files are regenerated even when they already exist in `features_folder`
        """

        alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
        self.alphabet_dict = {letter: i for i, letter in enumerate(alphabet)}
        self.num_chains = {}  # number of chains by biounit id
        self.files = {}  # feature file path by biounit id
        self.dataset_folder = dataset_folder
        self.features_folder = features_folder
        # _process writes into features_folder, so make sure it exists first
        os.makedirs(features_folder, exist_ok=True)
        # Every input file is processed exactly once (in parallel); the returned
        # tuples already carry everything needed to index the dataset.
        output_tuples = p_map(lambda x: self._process(x, rewrite=rewrite), os.listdir(dataset_folder))
        for biounit_id, filename, num_chains in output_tuples:
            self.files[biounit_id] = filename
            self.num_chains[biounit_id] = num_chains
        if clustering_dict_path is not None:
            # NOTE(review): pickle.load executes arbitrary code; only load trusted files.
            with open(clustering_dict_path, "rb") as f:
                self.clusters = pickle.load(f)  # list of (biounit id, chain id) by cluster id
            self.data = list(self.clusters.keys())
        else:
            self.clusters = None
            self.data = list(self.files.keys())

    def _process(self, filename, rewrite=False):
        """
        Process a BestProt file and save it as ProteinMPNN features

        Parameters
        ----------
        filename : str
            name of the input file inside `dataset_folder`; the output file in `features_folder` has the same name
        rewrite : bool, default False
            if `False`, an existing output file is kept as is

        Returns
        -------
        biounit_id : str
            the part of `filename` before the first dot
        output_file : str
            path to the feature file
        num_chains : int
            number of chains in the input file
        """

        input_file = os.path.join(self.dataset_folder, filename)
        output_file = os.path.join(self.features_folder, filename)
        # The input is always read because the chain count is needed even when
        # the feature file already exists.
        with open(input_file, "rb") as f:
            data = pickle.load(f)
        chains = sorted(data.keys())
        if rewrite or not os.path.exists(output_file):
            X = []
            S = []
            mask = []
            chain_encoding_all = []
            residue_idx = []
            last_idx = 0
            for chain_i, chain in enumerate(chains):
                chain_length = len(data[chain]["seq"])
                X.append(data[chain]["crd_bb"])
                S += [self.alphabet_dict[x] for x in data[chain]["seq"]]
                mask.append(data[chain]["msk"])
                residue_idx.append(torch.arange(chain_length) + last_idx)
                # leave a gap in the residue numbering between consecutive chains
                last_idx = residue_idx[-1][-1] + 100
                chain_encoding_all.append(torch.ones(chain_length) * chain_i)
            out = {}
            out["X"] = torch.from_numpy(np.concatenate(X, 0))
            out["S"] = torch.tensor(S)
            out["mask"] = torch.from_numpy(np.concatenate(mask))
            out["chain_encoding_all"] = torch.cat(chain_encoding_all)
            out["residue_idx"] = torch.cat(residue_idx)
            with open(output_file, "wb") as f:
                pickle.dump(out, f)
        return os.path.basename(filename).split('.')[0], output_file, len(chains)

    def __len__(self):
        """Number of entries: clusters when a clustering dict is used, biounits otherwise."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Load the features of one entry and attach the chain to mask as `'chain_id'`.

        Without clustering, `idx` selects a biounit and a random chain index is drawn.
        With clustering, `idx` selects a cluster and one (biounit id, chain id) pair is drawn from it.
        """
        if self.clusters is None:
            biounit_id = self.data[idx]
            chain_id = random.randint(0, self.num_chains[biounit_id] - 1)
        else:
            cluster = self.data[idx]
            # random.sample requires a sample size; a single member is wanted
            # here, so random.choice is the right call (needs a sequence).
            biounit_id, chain_id = random.choice(self.clusters[cluster])
        file = self.files[biounit_id]
        with open(file, "rb") as f:
            data = pickle.load(f)
        data["chain_id"] = chain_id
        return data

In [63]:
# Delete every file in `folder` in which at least one chain sequence contains "-".
# WARNING: this cell removes files from disk and is therefore not idempotent
# with respect to the folder contents.
folder = "data/subset"
cnt = 0  # number of files removed
total = 0  # number of chains inspected (stops at the first offending chain of a file)
for file in tqdm(os.listdir(folder)):
    # NOTE(review): pickle.load executes arbitrary code; only open trusted files.
    with open(os.path.join(folder, file), "rb") as f:
        data = pickle.load(f)
    for chain in data:
        total += 1
        if "-" in data[chain]["seq"]:
            os.remove(os.path.join(folder, file))
            # count before leaving the loop; after `break` this line was unreachable
            cnt += 1
            break

100%|██████████| 2131/2131 [00:00<00:00, 23187.14it/s]


In [64]:
# Build the dataset from data/subset; feature pickles are written to data/tmp_features.
# No clustering dict is passed, so entries are indexed by biounit file.
dataset = ProteinDataset(dataset_folder='data/subset', features_folder="data/tmp_features")

100%|██████████| 2131/2131 [00:09<00:00, 217.91it/s]


In [65]:
# Inspect the first entry: the saved feature dict plus a randomly drawn `chain_id`.
dataset[0]

{'X': tensor([[[18.1920, 25.9740, 29.7160],
          [16.1630, 26.8260, 28.5070],
          [16.7100, 26.1890, 29.7970],
          [16.3770, 28.0150, 28.2540]],
 
         [[15.4450, 26.0300, 27.7090],
          [13.7730, 27.5300, 26.5110],
          [14.8740, 26.4880, 26.4230],
          [12.8350, 27.4050, 27.3130]],
 
         [[13.8790, 28.5420, 25.6590],
          [12.0540, 29.2960, 24.3770],
          [12.8780, 29.5940, 25.6150],
          [12.5720, 29.3250, 23.2550]],
 
         ...,
 
         [[21.4890, 17.9400, 29.5560],
          [19.5650, 19.3690, 29.9680],
          [20.4220, 18.6940, 28.8980],
          [19.5650, 18.9460, 31.1240]],
 
         [[18.8370, 20.4370, 29.6040],
          [16.9570, 20.2990, 31.2810],
          [17.9880, 21.1590, 30.5600],
          [16.4870, 19.2980, 30.7420]],
 
         [[16.6160, 20.7010, 32.5020],
          [14.2280, 20.0150, 32.8040],
          [15.6530, 19.9810, 33.3330],
          [14.0100, 20.5590, 31.7000]]], dtype=torch.float64),
 'S'