In [57]:
%cd /home/yinj0b/repository/cfp-gen/

/home/yinj0b/repository/cfp-gen


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


# Create Functional Protein Dataset


In [58]:
import os
import glob
import re
import pickle
from collections import Counter, defaultdict
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
from Bio.Seq import Seq
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from subprocess import run
import shutil
import json
import gzip
import random
from Bio.PDB import PDBParser
from Bio.SeqUtils import seq1
from difflib import SequenceMatcher
import numpy as np
from multiprocessing import Pool, cpu_count
from src.byprot.utils.ontology import Ontology

In [59]:
def load_pkl_file(file_path):
    with open(file_path, 'rb') as f:
        return pickle.load(f)

In [60]:
def save_pkl_file(data, file_path):
    with open(file_path, 'wb') as f:
        pickle.dump(data, f)
    print(f"Updated data saved to {file_path}")

In [62]:
def count_go_ipr(new_dataset):
    final_ipr_counter = Counter()
    final_go_f_counter = Counter()

    for entry in new_dataset:
        ipr_numbers = entry.get('ipr_numbers', [])
        go_numbers_f = entry.get('go_numbers', {}).get('F', [])

        final_ipr_counter.update(ipr_numbers)
        final_go_f_counter.update(go_numbers_f)

    return final_ipr_counter, final_go_f_counter

In [63]:
def filter_by_min_max_count(counter, min_count, max_count):
    return set([item for item, count in counter.items() if min_count <= count <= max_count])

In [64]:
def create_filtered_datasets(data, selected_ipr, selected_go_f):

    new_dataset = []

    for entry in data:
        ipr_numbers = set(entry.get('ipr_numbers', []))
        go_numbers_f = set(entry.get('go_numbers', {}).get('F', []))

        if (not len(ipr_numbers)) or (not len(go_numbers_f)):
            continue

        if ipr_numbers.issubset(selected_ipr) and go_numbers_f.issubset(selected_go_f):
            new_dataset.append(entry)

    return new_dataset

In [65]:
def iterative_filtering(data, min_count, max_count, max_iterations=10):

    iteration = 0
    while iteration < max_iterations:
        print(f"Iteration {iteration + 1}")

        # Analyze the current dataset: count the frequency of IPR and GO(F) numbers
        ipr_counter, go_f_counter = count_go_ipr(data)

        # Select IPR and GO(F) numbers that appear at least min_count times but no more than max_count times
        selected_ipr = filter_by_min_max_count(ipr_counter, min_count, max_count)
        selected_go_f = filter_by_min_max_count(go_f_counter, min_count, max_count)

        # Generate a filtered dataset containing only the selected IPR and GO(F) entries
        new_dataset = create_filtered_datasets(data, selected_ipr, selected_go_f)

        # Count the attributes again in the new dataset
        final_ipr_counter, final_go_f_counter = count_go_ipr(new_dataset)

        if all(min_count <= count <= max_count for count in final_ipr_counter.values()) and \
                all(min_count <= count <= max_count for count in final_go_f_counter.values()):
            print(f"Converged at iteration {iteration + 1}")
            break
        data = new_dataset
        iteration += 1

    new_dataset = deduplicate_by_uniprot_id(new_dataset)
    final_ipr_counter, final_go_f_counter = count_go_ipr(new_dataset)

    return new_dataset, final_ipr_counter, final_go_f_counter

In [67]:
def update_motif_info(entries, ontology):
    for _, sample in tqdm(enumerate(entries), total=len(entries)):
        domain_sites = []
        seen = set()
        for item in sample['domain_sites']:
            # Use ('ipr_number', 'domain_id') as a unique identifier
            identifier = (item['ipr_number'], item['domain_id'])
            if identifier not in seen:
                seen.add(identifier)
                domain_sites.append(item)

        go_terms = sample['go_numbers']['F']
        term_desc = []
        for go_term in go_terms:
            ontology_terms = ontology.get_term(go_term)
            if ontology_terms is not None:
                term_desc.append(ontology_terms['name'])

        motif_segments = []
        for term in term_desc:
            # Find the closest matching ipr_description for each term_desc
            closest_matches = sorted(domain_sites, key=lambda item: SequenceMatcher(None, term, item['ipr_description']).ratio(), reverse=True)[:2]

            if not closest_matches:
                continue

            # Extract start_position and end_position of the matches
            start_positions = [item['start_position'] for item in closest_matches]
            end_positions = [item['end_position'] for item in closest_matches]

            # Calculate the overlapping region
            overlap_start = max(start_positions)
            overlap_end = min(end_positions)

            # If there is an overlap, retrieve the corresponding sequence segment
            if overlap_start < overlap_end:
                motif_segment = sample['sequence'][overlap_start:overlap_end]
                motif_segments.append({
                    'go_term': term,
                    'motif_segment': motif_segment,
                    'start': int(overlap_start),
                    'end': int(overlap_end)
                })

        # Add motif_segments into the sample as a new key-value pair
        sample['motif'] = motif_segments

    return entries

In [76]:
def update_domain_info(entries, ipr_map):
    for entry in entries:
        uniprot_id = entry.get('uniprot_id')
        ipr_numbers = entry.get('ipr_numbers', [])

        if uniprot_id in ipr_map:
            if 'domain_sites' not in entry:
                entry['domain_sites'] = []

            for ipr_number in ipr_numbers:
                domains = [domain for domain in ipr_map[uniprot_id] if domain['ipr_number'] == ipr_number]
                if domains:
                    entry['domain_sites'].extend(domains)
        else:
            if 'domain_sites' not in entry:
                entry['domain_sites'] = []
    return entries

In [69]:
def select_sequences_with_go_count(train_data, final_go_f_counter, min_count=20, target_per_label=20):
    """
    Select sequences from train_data such that each GO category appears at least target_per_label times.
    """
    # Only consider GO labels with counts greater than min_count
    valid_go_labels = {label for label, count in final_go_f_counter.items() if count > min_count}

    # Track the number of selected sequences for each GO label
    go_label_count = defaultdict(int)

    # List to store the final selected sequences
    selected_sequences = []

    # Traverse the training dataset
    for entry in tqdm(train_data):
        go_labels = entry['go_numbers']['F']  # GO annotations of the current sequence

        # Select valid labels (appear more than min_count and have not yet reached target_per_label)
        valid_labels_in_entry = [label for label in go_labels if label in valid_go_labels and go_label_count[label] < target_per_label]

        # If the current sequence contains any valid labels
        if valid_labels_in_entry:
            selected_sequences.append(entry)

            # Update the count for each GO label
            for label in valid_labels_in_entry:
                go_label_count[label] += 1

            # Remove labels that have already reached the target count
            valid_go_labels = {label for label in valid_go_labels if go_label_count[label] < target_per_label}

        # Early exit if all labels have reached the target
        if not valid_go_labels:
            break

    # Double check: remove sequences associated with labels that still do not meet the target count
    _, fina_go_count = count_go_ipr(selected_sequences)
    key_to_rm = [k for k, v in fina_go_count.items() if v < target_per_label]
    filtered_sequences = [
        entry for entry in selected_sequences
        if not any(go_label in key_to_rm for go_label in entry['go_numbers']['F'])
    ]

    return filtered_sequences

In [71]:
def count_ec(new_dataset):
    final_ec_counter = Counter()

    for entry in new_dataset:
        ec_numbers = entry.get('EC_number', [])

        final_ec_counter.update(ec_numbers)

    return final_ec_counter

In [72]:
def deduplicate_by_uniprot_id(data):
    seen = set()
    deduplicated_list = []

    for ele in data:
        uid = ele['uniprot_id']
        if uid not in seen:
            deduplicated_list.append(ele)
            seen.add(uid)

    return deduplicated_list

#### Create `cfpgen_general_dataset` based on `uniprot_swissprot_raw.pkl`

In [74]:
# Get the absolute path of the current script
base_dir = os.getcwd()
assert base_dir.endswith("cfp-gen"), "Need to run in cfp-gen root directory"

# Construct the work directory path
work_dir = os.path.join(base_dir, "data-bin/uniprotKB")

# Create the work directory if it does not exist
general_dataset_path = os.path.join(work_dir, 'cfpgen_general_dataset')
os.makedirs(general_dataset_path, exist_ok=True)

raw_pkl_path = os.path.join(work_dir, 'uniprot_swissprot_raw.pkl')
raw_data = load_pkl_file(raw_pkl_path)

# set minimum number for each class
min_count = 100
max_count = 20000
max_iterations = 10

final_dataset, final_ipr_counter, final_go_f_counter = iterative_filtering(raw_data, min_count, max_count, max_iterations)

print(f"Number of unique IPRs: {len(final_ipr_counter)}")
print(f"Number of unique GO(F): {len(final_go_f_counter)}")
print(f"Length of filtered Dataset: {len(final_dataset)}")

save_name = os.path.join(general_dataset_path, f'uniprot_swissprot_go-{len(final_go_f_counter)}_ipr-{len(final_ipr_counter)}.pkl')
save_pkl_file(final_dataset, save_name)

Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
Converged at iteration 5


Analyzing entries: 100%|██████████| 103936/103936 [00:00<00:00, 694173.99it/s]

Number of unique IPRs: 1154
Number of unique GO(F): 375
Length of filtered Dataset: 103936





Updated data saved to /home/yinj0b/repository/cfp-gen/data-bin/uniprotKB/cfpgen_general_dataset/uniprot_swissprot_go-375_ipr-1154.pkl


Add functional domain info add motif info, based on `protein2ipr`.

In [77]:
data = load_pkl_file(save_name)

protein2ipr_path = os.path.join(work_dir, 'protein2ipr_pkls')

# iterate pkls in protein2ipr
for pkl_file in tqdm(os.listdir(protein2ipr_path)):
    pkl_file_path = os.path.join(protein2ipr_path, pkl_file)

    with open(pkl_file_path, 'rb') as f:
        ipr_map = pickle.load(f)

        data = update_domain_info(data, ipr_map)

go_path = os.path.join(work_dir, 'go.obo')
ontology = Ontology(go_path, with_rels=True)

# make motif based on go main annotations
updated_uniprot_entries = update_motif_info(data, ontology)

# save updated data
save_pkl_file(updated_uniprot_entries, save_name)

100%|██████████| 342/342 [25:38<00:00,  4.50s/it]
100%|██████████| 103936/103936 [01:28<00:00, 1170.63it/s]


Updated data saved to /home/yinj0b/repository/cfp-gen/data-bin/uniprotKB/cfpgen_general_dataset/uniprot_swissprot_go-375_ipr-1154.pkl


Split `train/valid/test` pkls.

In [81]:
train_save_path = os.path.join(general_dataset_path, 'train.pkl')
val_save_path = os.path.join(general_dataset_path, 'valid.pkl')
test_save_path = os.path.join(general_dataset_path, 'test.pkl')

# Load the full dataset
data = load_pkl_file(save_name)

# Count IPR and GO(F) numbers
final_ipr_counter, final_go_f_counter = count_go_ipr(data)

# Select sequences for the test set
selected_go_test = select_sequences_with_go_count(data, final_go_f_counter, min_count=50, target_per_label=30)
test_ids = set([ele['uniprot_id'] for ele in selected_go_test])

# Remaining sequences are for training
selected_go_train = [ele for ele in data if ele['uniprot_id'] not in test_ids]
selected_go_train = deduplicate_by_uniprot_id(selected_go_train)

# Save the datasets
save_pkl_file(selected_go_train, train_save_path)
save_pkl_file(selected_go_test, test_save_path)
save_pkl_file(selected_go_test[::10], val_save_path)  # only used for checking loss

# Print final dataset sizes
print(f"Final dataset sizes:")
print(f"  Train set: {len(selected_go_train)} sequences")
print(f"  Validation set: {len(selected_go_test[::10])} sequences")
print(f"  Test set: {len(selected_go_test)} sequences")

 99%|█████████▉| 103243/103936 [00:00<00:00, 482393.97it/s]


Updated data saved to /home/yinj0b/repository/cfp-gen/data-bin/uniprotKB/cfpgen_general_dataset/train.pkl
Updated data saved to /home/yinj0b/repository/cfp-gen/data-bin/uniprotKB/cfpgen_general_dataset/test.pkl
Updated data saved to /home/yinj0b/repository/cfp-gen/data-bin/uniprotKB/cfpgen_general_dataset/valid.pkl
Final dataset sizes:
  Train set: 95627 sequences
  Validation set: 831 sequences
  Test set: 8309 sequences


Add backbone coordinate, seq, name info, based on `pdb` and `afdb` database.
We need `seq` since it's a little bit different from `sequence` for a given uniprotID, due to the structural constraint.
`name` is given based on pdb name.

In [94]:
def extract_seqres(pdb_file):
    chain_sequences = {}
    with gzip.open(pdb_file, "rt") if pdb_file.endswith(".gz") else open(pdb_file, "r") as f:
        for record in SeqIO.parse(f, "pdb-seqres"):
            chain_sequences[record.id.split(":")[1]] = str(record.seq)
    return chain_sequences

def calculate_similarity(seq1, seq2):
    matcher = SequenceMatcher(None, seq1, seq2)
    return matcher.ratio()

def process_pdb_with_sequence(entry, pdb_name, pdb_file):
    sequence = entry.get('sequence')

    # Step 1: Extract SEQRES sequences
    seqres_sequences = extract_seqres(pdb_file)
    num_chains = len(seqres_sequences)
    # print('debug', seqres_sequences)

    # Step 2: Find the best-matching chain
    best_chain = None
    best_similarity = 0
    for chain_id, seqres_seq in seqres_sequences.items():
        similarity = calculate_similarity(sequence, seqres_seq)
        if not best_chain or similarity > best_similarity:
            best_chain = chain_id
            best_similarity = similarity
    # print('debug', best_chain)

    if best_chain is None:
        raise ValueError("No matching chain found in PDB file.")

    # print(f"Best-matching chain: {best_chain} with similarity: {best_similarity:.2f}")

    # Step 3: Extract coordinates
    seqres_seq = seqres_sequences[best_chain]
    # backbone_coords = extract_backbone_coords(pdb_file, best_chain, seqres_seq)

    # Step 4: Add Entry
    entry['name'] = f'{pdb_name}.{best_chain}'
    entry['num_chains'] = num_chains
    entry['seq'] = seqres_seq
    # entry['coords'] = backbone_coords
    return entry

def update_pdb_info(data, pdb2file, afdb2file, pdb_path, afdb_path):
    not_found = 0
    new_data = []
    for entry in tqdm(data):
        # fetch pdb_file, pdb_name
        pdb_file, pdb_name = '', ''

        pdb_ids = entry.get('pdb_ids')
        afdb_id = entry.get('afdb')
        if pdb_ids:
            for pdb_id in pdb_ids:
                pdb_id = pdb_id.lower()
                if pdb_id in pdb2file:
                    pdb_file = pdb2file[pdb_id]
                    pdb_file = os.path.join(os.path.dirname(pdb_path), os.path.join(*pdb_file.strip(os.sep).split(os.sep)[-3:]))
                    pdb_name = pdb_id
                    break

        if not pdb_name: # afdb
            if afdb_id in afdb2file:
                pdb_file = afdb2file[afdb_id]
                pdb_file = os.path.join(os.path.dirname(afdb_path), os.path.join(*pdb_file.strip(os.sep).split(os.sep)[-2:]))
                pdb_name = afdb_id
            else:
                print('Warning: PDB/AFDB file not found! Skip: ', entry.get('uniprot_id'))
                not_found += 1
                continue

        try:
            entry = process_pdb_with_sequence(entry, pdb_name, pdb_file)
        except Exception as e:
            print(f"Error processing entry ({entry.get('uniprot_id')}): {e}")
            not_found += 1
            continue

        new_data.append(entry)

    return new_data

In [None]:
# Define filenames and paths
splits = ['train', 'valid', 'test']
original_paths = [os.path.join(general_dataset_path, f"{s}.pkl") for s in splits]
new_paths = [os.path.join(general_dataset_path, f"{s}_bb.pkl") for s in splits]
data_list = [load_pkl_file(p) for p in original_paths]

# Load pdb/afdb mappings
pdb_path = '/data/junbo/datasets/pdb/pdb_table.csv'
afdb_path = '/data/junbo/datasets/afdb_swissprot/af_swissprot_v4_table.csv'
pdb2file = pd.read_csv(pdb_path).set_index('PDB_id')['Path'].to_dict()
afdb2file = pd.read_csv(afdb_path).set_index('PDB_id')['Path'].to_dict()

# Process and save
for data, save_path in zip(data_list, new_paths):
    print(f"Processing {save_path} - original size: {len(data)}")
    updated = update_pdb_info(data, pdb2file, afdb2file, pdb_path, afdb_path)
    print(f"Updated size: {len(updated)}")
    save_pkl_file(updated, save_path)
    print(f"Saved to: {save_path}\n")
