# Import Libraries

In [26]:
import warnings
warnings.filterwarnings("ignore")
import os
import numpy as np
from tqdm import tqdm
import pandas as pd
import os 
from difflib import SequenceMatcher

# Define paths

Google drive = https://drive.google.com/drive/folders/1NcerEtJUn6eULDLdu2l-WPdzvTTw6mFE?usp=sharing

# Imort data

In [27]:
data_path = os.path.dirname(os.getcwd()) + '/data/'
expression_file=pd.read_csv(data_path + 'expression/expression.csv')
expression=dict(zip(list(expression_file.PDB_ID),list(expression_file.Expression)))
species=dict(zip(list(expression_file.PDB_ID),list(expression_file.Species)))
patches = os.listdir(data_path + 'patches/raw')

AA = ['A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y']
HYDR = ['A','C','F','I','L','M','V','W','Y']
SPEC = dict(zip(list(expression_file.Species.value_counts().index[:10]),range(0,10)))
EXP = np.percentile(expression_file[~expression_file.Expression.isnull()].Expression, np.arange(0,100,10)).tolist()

In [28]:
def calculate_match_percentage(list1, list2): 
    if len(list1) != len(list2):
        return 0   
    match_count = 0
    for aa1, aa2 in zip(list1, list2):
        if aa1 == aa2:
            match_count += 1
    
    match_percentage = (match_count / len(list1))
    return match_percentage

In [29]:
def get_ASA(data, list_data, agg_data=None):
    new_data=np.zeros((len(data),len(data[0]),95))
    count = 0
    print(new_data.shape, 'initial')
    for i in tqdm(range(len(list_data))):
        #Column len(AA)
        column=np.zeros((len(data[i]),27))
        mask = np.count_nonzero(data[i,:,50])
        
        #Column 0 equals TASA = sum(RSA*ASAmax)
        tasa=sum(data[i,:,53]*data[i,:,50])
        hydr=[bool(AA[k] in HYDR) for k in np.argmax(data[i,:,:20],axis=-1)]
        #Column 1 equals THSA = sum(RSA*ASAmax*hydr_mask)
        thsa=sum(data[i,:,53]*data[i,:,50]*hydr)
        #Column 2 equals RHSA = sum(RSA*ASAmax*hydr_mask)/sum(RSA*ASAmax)
        rhsa=thsa/tasa
        column[0,0]=round(tasa,1)
        column[0,1]=round(thsa,1)
        column[0,2]=round(rhsa,5)

        #Column 3 equals Patch Size
        id_patch=list_data[i].replace('-','').upper()
        if id_patch+'.csv' in patches: 
            patch_info = pd.read_csv(os.path.join(data_path + 'patches','raw',f'{id_patch}.csv'))
                        
            data_fasta = "".join([AA[k] for k in np.argmax(data[i,:mask,:20],axis=-1)])
            
            min_index = min(patch_info['pdb_index'])
            patch_info['pdb_index'] = patch_info['pdb_index'] - min_index
            max_index = max(patch_info['pdb_index'])
            pdb_id_AA_dict = dict(zip(patch_info['pdb_index'], patch_info['amino_acid']))
            
            patch_fasta = "".join([pdb_id_AA_dict.get(i, "_") for i in range(max_index)])
            
            s = SequenceMatcher(None, data_fasta, patch_fasta)
            matches = s.get_matching_blocks()
            if matches[0].a == 0:
                patch_fasta = patch_fasta[matches[0].b:]
                patch_info['pdb_index'] = patch_info['pdb_index'] + matches[0].b
            if matches[0].b == 0:
                patch_fasta = "_"*matches[0].a + patch_fasta
                patch_info['pdb_index'] = patch_info['pdb_index'] + matches[0].a

            pdb_id_AA_dict = dict(zip(patch_info['pdb_index'], patch_info['amino_acid']))
            patch_fasta = "".join([pdb_id_AA_dict.get(i, "_") for i in range(len(data_fasta))])

            skip_files = list_data[i] in ["1dts-A", "1rlr-A", "4pfk-A", "1u2z-B"]
            
            if calculate_match_percentage(data_fasta, patch_fasta) > 0.95 and not skip_files:                
                patch_place = np.zeros((len(data[i])))
                patch_place[patch_info['pdb_index']] = patch_info['patch_size']
                patch_place[np.isnan(patch_place)] = 0
                column[:,4]=np.where(patch_place>0,1,0)
                if not((patch_place==0).all()):
                    column[0,3]=round(max(patch_place),1)
                    column[:,5]=np.where(patch_place==max(patch_place),1,0)
                count += 1  
            # elif calculate_match_percentage(data_fasta, patch_fasta) > 0.9:      
            #     print(data_fasta) 
            #     print(patch_fasta)
            #     print(id_patch)
            #     print('*******')

        #Column 6 equals Species
        if list_data[i] in species.keys() and species[list_data[i]] in SPEC.keys():
            column[0,6+SPEC[species[list_data[i]]]]=1

        # Columns 17 equals expression
        if list_data[i] in expression.keys() and not(np.isnan(expression[list_data[i]])):
            expression_value=round(expression[list_data[i]],1)  
            column[0,16]=expression_value
            rk=0
            for k in EXP[1:]:
                if expression_value>=k:
                    rk+=1
            column[0,17+rk]=1    
        new_data[i]=np.c_[data[i],column]        
    
    print(new_data.shape, 'final')
    print(count)
    return new_data

In [30]:
if __name__=="__main__":
    name_train = np.load(f"{os.path.dirname(os.getcwd())}/data/source_dataset/Train_HHblits.npz")['pdbids']
    name_casp = np.load(f"{os.path.dirname(os.getcwd())}/data/source_dataset/CASP12_HHblits.npz")['pdbids']
    name_cb = np.load(f"{os.path.dirname(os.getcwd())}/data/source_dataset/CB513_HHblits.npz")['pdbids']
    name_ts = np.load(f"{os.path.dirname(os.getcwd())}/data/source_dataset/TS115_HHblits.npz")['pdbids']
    data_train = np.load(f"{os.path.dirname(os.getcwd())}/data/source_dataset/Train_HHblits.npz")['data']
    data_casp = np.load(f"{os.path.dirname(os.getcwd())}/data/source_dataset/CASP12_HHblits.npz")['data']
    data_cb = np.load(f"{os.path.dirname(os.getcwd())}/data/source_dataset/CB513_HHblits.npz")['data']
    data_ts = np.load(f"{os.path.dirname(os.getcwd())}/data/source_dataset/TS115_HHblits.npz")['data']
    list_train = get_ASA(data_train,name_train)
    list_casp = get_ASA(data_casp,name_casp)
    list_cb = get_ASA(data_cb,name_cb)
    list_ts = get_ASA(data_ts,name_ts)
    np.savez_compressed(f"{os.path.dirname(os.getcwd())}/data/extended/Train_HHblits_extended.npz",pdbids=name_train,data=list_train)
    np.savez_compressed(f"{os.path.dirname(os.getcwd())}/data/extended/CASP12_HHblits_extended.npz",pdbids=name_casp,data=list_casp)
    np.savez_compressed(f"{os.path.dirname(os.getcwd())}/data/extended/CB513_HHblits_extended.npz",pdbids=name_cb,data=list_cb)
    np.savez_compressed(f"{os.path.dirname(os.getcwd())}/data/extended/TS115_HHblits_extended.npz",pdbids=name_ts,data=list_ts)

(10848, 1632, 95) initial


100%|██████████| 10848/10848 [00:16<00:00, 643.71it/s]


(10848, 1632, 95) final
9990
(21, 1494, 95) initial


100%|██████████| 21/21 [00:00<00:00, 686.55it/s]


(21, 1494, 95) final
20
(513, 874, 95) initial


100%|██████████| 513/513 [00:00<00:00, 923.26it/s]


(513, 874, 95) final
470
(115, 1111, 95) initial


100%|██████████| 115/115 [00:00<00:00, 841.82it/s]


(115, 1111, 95) final
113


In [31]:
print(len(name_train), len(set(name_train)))
print(len(name_casp), len(set(name_casp)))
print(len(name_cb), len(set(name_cb)))
print(len(name_ts), len(set(name_ts)))

10848 10848
21 21
513 434
115 115


# Only keep LHP global and local features and remove the proteins that are not annotated

In [32]:
import numpy as np

def LHP_only(dataset_path, save_path):
    # Load dataset
    data = np.load(dataset_path, allow_pickle=True)
    pdbids = data['pdbids']
    dataset = data['data']
    
    # Define indices of columns with LHP annotations
    lhp_indices = [71, 72, 73]
    
    # Initialize a list to store filtered data
    filtered_data = []
    indices_with_lhp = []
    
    # Iterate through each entry in the dataset
    for i in range(len(dataset)):
        # Check if any of the LHP columns have annotations
        if any(dataset[i, :, idx].any() for idx in lhp_indices):
            # If LHP annotation is present, keep the row
            filtered_data.append(dataset[i])
            indices_with_lhp.append(i)
    
    # Stack filtered data
    filtered_data = np.stack(filtered_data)
    
    # Save filtered dataset
    np.savez_compressed(save_path, pdbids=pdbids[indices_with_lhp], data=filtered_data)

In [33]:
LHP_only(f"{os.path.dirname(os.getcwd())}/data/extended/Train_HHblits_extended.npz", 
                              f"{os.path.dirname(os.getcwd())}/data/extended/Train_LHP.npz")
LHP_only(f"{os.path.dirname(os.getcwd())}/data/extended/CASP12_HHblits_extended.npz", 
                              f"{os.path.dirname(os.getcwd())}/data/extended/CASP12_LHP.npz")
LHP_only(f"{os.path.dirname(os.getcwd())}/data/extended/CB513_HHblits_extended.npz", 
                              f"{os.path.dirname(os.getcwd())}/data/extended/CB513_LHP.npz")
LHP_only(f"{os.path.dirname(os.getcwd())}/data/extended/TS115_HHblits_extended.npz", 
                              f"{os.path.dirname(os.getcwd())}/data/extended/TS115_LHP.npz")

In [34]:
CASP12_HHblits_extended = np.load(f"{os.path.dirname(os.getcwd())}/data/extended/CASP12_HHblits_extended.npz")
Train_LHP = np.load(f"{os.path.dirname(os.getcwd())}/data/extended/Train_LHP.npz")
CASP12_LHP = np.load(f"{os.path.dirname(os.getcwd())}/data/extended/CASP12_LHP.npz")
CB513_LHP = np.load(f"{os.path.dirname(os.getcwd())}/data/extended/CB513_LHP.npz")
TS115_LHP = np.load(f"{os.path.dirname(os.getcwd())}/data/extended/TS115_LHP.npz")

In [35]:
print(Train_LHP['data'].shape)
print(CASP12_LHP['data'].shape)
print(CB513_LHP['data'].shape)
print(TS115_LHP['data'].shape)

(9990, 1632, 95)
(20, 1494, 95)
(470, 874, 95)
(113, 1111, 95)


In [36]:
from collections import Counter

data_test_hydr = list_cb

all_residues = []
for i in range(len(data_test_hydr)):
    mask = np.count_nonzero(data_test_hydr[i,:,50])
    data_fasta = [AA[k] for k in np.argmax(data_test_hydr[i,:mask,:20],axis=-1)]
    list_aa_in_hydr = pd.DataFrame({'AA': data_fasta, 'patch':list(data_test_hydr[i, :mask, 72])})
    all_residues += list(list_aa_in_hydr[list_aa_in_hydr["patch"] == 1]['AA'])
    if 'P' in list(list_aa_in_hydr[list_aa_in_hydr["patch"] == 1]['AA']):
        print(list_aa_in_hydr)
        print(name_train[i])
        
HYDR = ['A','C','F','I','L','M','V','W','Y']
for item, count in Counter(all_residues).items():
    print(f"{item}: {count}")

C: 1242
Y: 3898
I: 4745
A: 8527
L: 7372
V: 6105
W: 1513
F: 3549
M: 1947
