Notebook for cleaning the peptide data

In [None]:
# Supress pytorch pickle load warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Logging
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle

# Library imports
import gdiffusion as gd
import util
import util.chem as chem
import util.visualization as vis
import util.stats as gdstats


import gdiffusion.bayesopt as bayesopt
from gdiffusion.classifier.logp_predictor import LogPPredictor

device = util.util.get_device()
print(f"device: {device}")

DIFFUSION_PATH = "saved_models/diffusion/molecule-diffusion-v1.pt"
SELFIES_VAE_PATH = "saved_models/selfies_vae/selfies-vae.ckpt"
PEPTIDE_VAE_PATH = "saved_models/peptide_vae/peptide-vae.ckpt"
LOGP_PREDICTOR_PATH = "saved_models/logp/model-logp"

import h5py
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


device: cuda


In [None]:
with open(file="data/raw_peptide/peptide_raw_10M.csv") as f:
    for i in range(100):
        print(f.readline())

# data is PEPTIDE, EXTINCT

text,labels

FLPQGTPSPLIPMLILIETISLFIQPMALAVRLTANITAGHLLIHL,1

VMATAFMGYVLPWGQMSFWGATVITNLLSAIPYIGPTLVEWIW,0

QDIRKMGGMMYTLPFTSSCLMIGTLALTGMPFMTGFYSKDHII,1

AFMGYVLPWGQMSFWGATVITNLLSAIPYIGTTLVEW,0

MLTMIPILMKTTNPRSTEAATKYFMTQATASMMLMMALTINLVYS,1

LLVLFIMFQLKVSNHMYPMNPELIKPKLKEQKTPWE,1

QCPKPTLQQISHIAQQLGLEKDVVRVWFCNRRQKGKRSSSDYSQREDF,0

LASATNTWEIQQL,0

IQQAFSHTQAPTLPLLGLILAATGKSAQ,0

MAIAMLSLLSLFFYLRLAYHSTIILPPNSSNH,1

DVIRESTFQGHHTTTVQKGLRYGMVLFIVSEVFFFLGFFW,1

MISHIVTYYSGKKEPFGYMGMVWAMVSIGFLGFIVWA,0

PILIAMAFLMLTERKILGYMQLRKGPNVVGPYGL,0

IPMITNSLT,1

PWASQTSKLPTMLITALL,0

PPLSGFLPKWMIIQEMTKNSLIIMPTMMAI,1

ALMVALAICSLVLYLLTLMLTEKLSS,0

ADAIKLFTKEPLKPSTS,0

LLILVLFLPDLLGDPDNYTPANPLN,1

IIMYNPTLMALNLIIYLLMT,1

IPGGPFENLEIRRFDRVKDTEWNDFEYRFIS,0

ASEPYTTKFFYYLLMFLI,1

NFTPANPLATPPHIKPEWYFLFAYAILRSIPNKLGG,0

SMLPIILLVFAAC,0

NVFGFKALRALRLEDLRIPTAYVKTFQGPPHGIQVERDKLNKYGRPLLGC,0

AKMPLYGLHLWLPKA,0

KDAIFSVSIAYFGIFIASF,0

NMSFWLLPPSFLLLLASSTVEAGAGTGWTVYPPLAGNMAH,0

KILGYMQLRKGPNIVGPLGLLQPMAD

In [None]:
with open("data/raw_peptide/peptide_raw_4p5.csv") as f:
    for _ in range(5):
        print(f.readline())

# all non-extinct (0)

sequence

QRSTPYCRQSIPKGTIV

STPYCRQSIPKGTIVPLKGP

GIISHIWALARHTLFTNTFQDDER

TGTGNALRRRATSVATSVGTD



In [9]:
# dataset lengths:
def get_num_lines(file_name):
    total_lines = 0
    with open(file_name, 'r') as f:
        for line in f:
            if line.strip():
                total_lines += 1
    return total_lines

peptide_10m_len = get_num_lines("data/raw_peptide/peptide_raw_10M.csv")
peptide_4p5_len = get_num_lines("data/raw_peptide/peptide_raw_4p5.csv")

In [None]:
print(f"10 Million Peptide Len: {peptide_10m_len}")
print(f"4.5 Million Peptide Len: {peptide_4p5_len}")

10 Million Peptide Len: 10274724
4.5 Million Peptide Len: 4500001


In [None]:
# subtract 2 for the header lines
total_len = peptide_10m_len + peptide_4p5_len - 2

In [16]:
peptide_latent_dim = 256

In [None]:
# # Create h5py file, do not run again!

# # data source is if its from the 10M dataset (0) or the 4.5M dataset (1)
# dataset_file = "data/peptide_dataset.h5"
# with h5py.File(dataset_file, 'w') as h5file:
#     peptide_dataset = h5file.create_dataset('PEPTIDES', (total_len), dtype=h5py.string_dtype())
#     extinct_dataset = h5file.create_dataset('EXTINCT', (total_len), dtype=bool)
#     data_source = h5file.create_dataset('DATA_SOURCE', (total_len), dtype=np.int8)
#     latents = h5file.create_dataset('LATENTS', (total_len, peptide_latent_dim), dtype=np.float32)
    

In [24]:
# Read from peptide latent data:
def read_peptide_dataset_raw(i: int, data_path="data/peptide_dataset.h5"):
    with h5py.File(data_path, 'r') as f:
        return f['PEPTIDES'][i], f['EXTINCT'][i], f['DATA_SOURCE'][i], f['LATENTS'][i]
    
def read_peptide_dataset(i: int, data_path="data/peptide_dataset.h5"):
    with h5py.File(data_path, 'r') as f:
        raw_peptide, raw_extinct, raw_datasource, raw_latent = f['PEPTIDES'][i], f['EXTINCT'][i], f['DATA_SOURCE'][i], f['LATENTS'][i]
        peptide = raw_peptide.decode('utf-8')
        extinct = bool(raw_extinct)
        datasource = 'peptide_10M' if raw_datasource == 0 else 'peptide_4.5M'
        latent = raw_latent

    return peptide, latent, extinct, datasource

In [28]:
RAW_DATA_PATH = "data/raw_peptide/peptide_raw_10M.csv"
PEPTIDE_DATASET_PATH = "data/peptide_dataset.h5"

In [58]:
RAW_DATA_PATH_4P5 = "data/raw_peptide/peptide_raw_4p5.csv"

In [42]:
def write_peptide_10M_dataset(start_idx: int = 0, start_line_num: int = 1):
    # line_num is 0-indexed, so line_num 0 is the first line!
    
    with open(RAW_DATA_PATH, 'r') as infile, h5py.File(PEPTIDE_DATASET_PATH, 'r+') as outfile:
        peptide_ds = outfile['PEPTIDES']
        extinct_ds = outfile['EXTINCT']
        data_source_ds = outfile['DATA_SOURCE']

        # skip first start_line_num lines in infile csv
        for _ in range(start_line_num):
            next(infile)

        idx = start_idx
        for line_num, raw_line in tqdm(enumerate(infile, start=start_line_num), total=peptide_10m_len, desc='Reading Peptide10M CSV'):
            try:
                raw_line = raw_line.strip()
                peptide, extinct = raw_line.split(',')

                if peptide is None or extinct is None:
                    raise ValueError(f"peptide, extinct is wrong: peptide={peptide} extinct={extinct}")

                if peptide_ds[idx] != b'':
                    print(f"Warning, overriding peptide data: {peptide_ds[idx]}! Aborting")
                    raise ValueError("See above.")
                
                peptide_ds[idx] = peptide
                extinct_ds[idx] = extinct
                data_source_ds[idx] = 0 # this coorosponds to the 10M dataset
                idx += 1

            except Exception as e:
                print(f"Encountered an error while processing line_num {line_num}")
                print(f"Line was {raw_line} ")
                print(f"peptide = {peptide} ")
                print(f"extinct = {extinct} ")
                print(f"Attempted to index into idx={idx}")
                print("Error MSG:")
                print(e)

                print(f"Removing idx: {idx} data")
                peptide_ds[idx] = ''
                extinct_ds[idx] = False
                data_source_ds[idx] = 0
                return

# write_peptide_10M_dataset()

Reading Peptide10M CSV: 100%|█████████▉| 10274723/10274724 [16:58<00:00, 10089.92it/s]


In [None]:
# TODO: Determine start_idx
def write_peptide_4p5M_dataset(start_idx: int = None, start_line_num: int = 1):
    # line_num is 0-indexed, so line_num 0 is the first line!
    
    with open(RAW_DATA_PATH_4P5, 'r') as infile, h5py.File(PEPTIDE_DATASET_PATH, 'r+') as outfile:
        peptide_ds = outfile['PEPTIDES']
        extinct_ds = outfile['EXTINCT']
        data_source_ds = outfile['DATA_SOURCE']

        # skip first start_line_num lines in infile csv
        for _ in range(start_line_num):
            next(infile)

        idx = start_idx
        for line_num, raw_line in tqdm(enumerate(infile, start=start_line_num), total=peptide_4p5_len, desc='Reading Peptide 4.5M CSV'):
            try:
                peptide = raw_line.strip()

                if peptide is None:
                    raise ValueError(f"peptide, extinct is wrong: peptide={peptide}")

                if peptide_ds[idx] != b'':
                    print(f"Warning, overriding peptide data: {peptide_ds[idx]}! Aborting")
                    raise ValueError("See above.")
                
                peptide_ds[idx] = peptide
                extinct_ds[idx] = False # these are all modern peptides, so they are all NOT extinct
                data_source_ds[idx] = 1 # this coorosponds to the 4.5 dataset
                idx += 1

            except Exception as e:
                print(f"Encountered an error while processing line_num {line_num}")
                print(f"Line was {raw_line} ")
                print(f"peptide = {peptide} ")
                print(f"Attempted to index into idx={idx}")
                print("Error MSG:")
                print(e)

                print(f"Removing idx: {idx} data")
                peptide_ds[idx] = ''
                extinct_ds[idx] = False
                data_source_ds[idx] = 0
                return

peptide_start_index = peptide_10m_len - 1
# write_peptide_4p5M_dataset(start_idx=peptide_start_index)

Reading Peptide 4.5M CSV: 100%|█████████▉| 4500000/4500001 [07:29<00:00, 10001.05it/s]


In [62]:
# Confirming data:
peptide_amino_acids = set('ACDEFGHIKLMNPQRSTVWY')
with h5py.File("data/peptide_dataset.h5", mode='r') as f:
    peptide_ds = f['PEPTIDES']
    extinct_ds = f['EXTINCT']
    data_source_ds = f['DATA_SOURCE']

    for idx in tqdm(range(len(peptide_ds))):
        peptide = peptide_ds[idx].decode('utf-8')
        extinct = extinct_ds[idx]
        data_source = extinct_ds[idx]

        if len(peptide) <= 4:
            print(f"Small Peptide peptide at index {idx}: {peptide}")

        if not peptide.isupper():
            print(f"Peptide at index {idx} is not uppercase: {peptide}")
        
        if not all(c in peptide_amino_acids for c in peptide):
            print(f"Peptide at index {idx} is not a valid peptide! {peptide}")

  5%|▍         | 684970/14774723 [00:32<11:14, 20879.79it/s]


KeyboardInterrupt: 

In [67]:
peptide, _, extinct, data_source = read_peptide_dataset(total_len-1)
string_print = f"""
    Peptide: {peptide}
    Extinct: {str(extinct)}
    Data Source: {data_source}
    """
print(string_print)


    Peptide: ALAPRHADVVAPRLMAITRAGVTALVLTAFLGVRGLNPGADLL
    Extinct: False
    Data Source: peptide_4.5M
    


In [None]:
# attatch latents to dataset (this will take a long long time)

vae = gd.load_vae_peptides()
gd.peptides_to_latent()

In [None]:
from difflib import SequenceMatcher

def peptides_similar(seq1, seq2, max_changes=5):
   matcher = SequenceMatcher(None, seq1, seq2)
   opcodes = matcher.get_opcodes()
   changes = sum(1 for tag, _, _, _, _ in opcodes if tag != 'equal')
   return changes <= max_changes