### 1. Fetch PDB sequences and generate blast DB <br>

Each PDB-formatted file includes "SEQRES records", which list the primary sequence of the polymeric molecules present in the entry. This sequence information is also available as a FASTA download.

In [2]:
import pickle
import numpy as np
from pathlib import Path
from biopandas.pdb import PandasPdb

In [3]:
import os
import subprocess
import pandas as pd
from Bio.Blast import NCBIXML
from pathlib import Path
import sys 

# Working directory that holds the sequence file, the BLAST database and
# every intermediate file written by the later cells.
if not os.path.exists("PDB_blast_db"):
    os.mkdir("PDB_blast_db")

# Download and unpack the wwPDB SEQRES FASTA only when it is not cached yet.
# NOTE(review): the exit codes of the shell commands are ignored, so a failed
# download only surfaces when a later step cannot find its input.
if not os.path.exists("PDB_blast_db/pdb_seqres.txt"):
    subprocess.call("wget https://files.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt.gz",shell=True)    
    subprocess.call("gzip -d pdb_seqres.txt.gz",shell=True)
    subprocess.call("mv pdb_seqres.txt ./PDB_blast_db/",shell=True)
    
# Build the protein BLAST database next to the FASTA file; the presence of the
# ".psq" file is used as the marker that the database already exists.
if not os.path.exists("PDB_blast_db/pdb_seqres.txt.psq"):
    subprocess.call("makeblastdb -in pdb_seqres.txt -dbtype prot -title pdb",shell=True, cwd="./PDB_blast_db")

--2024-02-14 13:22:32--  https://files.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt.gz
Resolving files.wwpdb.org (files.wwpdb.org)... 128.6.159.245
Connecting to files.wwpdb.org (files.wwpdb.org)|128.6.159.245|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 49537942 (47M) [application/x-gzip]
Saving to: ‘pdb_seqres.txt.gz’

     0K .......... .......... .......... .......... ..........  0%  202K 3m59s
    50K .......... .......... .......... .......... ..........  0%  405K 2m59s
   100K .......... .......... .......... .......... ..........  0% 92.7M 1m59s
   150K .......... .......... .......... .......... ..........  0%  160M 89s
   200K .......... .......... .......... .......... ..........  0%  405K 95s
   250K .......... .......... .......... .......... ..........  0%  171M 79s
   300K .......... .......... .......... .......... ..........  0%  153M 68s
   350K .......... .......... .......... .......... ..........  0%  409K 74s
   400K .......... ......



Building a new DB, current time: 02/14/2024 13:22:39
New DB name:   /mnt/10tb/home/shevtsov/SEMAi/epitopes_prediction/dataset_generation/PDB_blast_db/pdb_seqres.txt
New DB title:  pdb
Sequence type: Protein
Keep MBits: T
Maximum file size: 1000000000B
Adding sequences from FASTA; added 865773 sequences in 15.0551 seconds.


### 2. Preliminary screen for proteins in the PDB database with homology to fragment antigen-binding region 

Fetch PDB entries that contain Fabs

In [None]:
# Query sequence used to find light-chain-like entries; written as a
# single-record FASTA file inside the BLAST working directory.
light = "DILLTQSPVILSVSPGERVSFSCRASQSIGTNIHWYQQRTNGSPRLLIKYASESISGIPSRFSGSGSGTDFTLSINSVESEDIADYYCQQNNNWPTTFGAGTKLELK"
with open("PDB_blast_db/fab_light.fasta",'w') as fo:
    fo.write(">input_light\n")
    fo.write(light)

# Query sequence used to find heavy-chain-like entries.
heavy = "QVQLKQSGPGLVQPSQSLSITCTVSGFSLTNYGVHWVRQSPGKGLEWLGVIWSGGNTDYNTPFTSRLSINKDNSKSQVFFKMNSLQSNDTAIYYCARALTYYDYEFAYWGQGTLVTVSA"
with open("PDB_blast_db/fab_heavy.fasta",'w') as fo:
    fo.write(">input_heavy\n")
    fo.write(heavy)

# Run blastp against the SEQRES database once per query. The existing XML
# report acts as a cache: delete it to force the search to run again.
if not os.path.exists("PDB_blast_db/hits_fabs_light.xml"):
    subprocess.call("blastp -db pdb_seqres.txt -num_alignments 99999 -evalue 1e-9 -query fab_light.fasta -out hits_fabs_light.xml -outfmt 5",shell=True, cwd="./PDB_blast_db")

if not os.path.exists("PDB_blast_db/hits_fabs_heavy.xml"):
    subprocess.call("blastp -db pdb_seqres.txt -num_alignments 99999 -evalue 1e-9 -query fab_heavy.fasta -out hits_fabs_heavy.xml -outfmt 5",shell=True, cwd="./PDB_blast_db")



In [None]:
def parse_blast_output(input_path):
    """Collect the identifiers of all database entries hit by a BLAST query.

    Parameters
    ----------
    input_path : str or path-like
        BLAST report in XML format. Only the first query record of the
        report is read.

    Returns
    -------
    set of str
        The second whitespace-separated token of the title of every
        alignment that carries at least one HSP (presumably the
        ``<pdbid>_<chain>`` identifier of the SEQRES entry; verify against
        the database headers).

    Raises
    ------
    StopIteration
        If the report contains no query record.
    """
    # The context manager closes the report even if parsing fails.
    with open(input_path, "r") as handle:
        item = next(NCBIXML.parse(handle))
        return {
            alignment.title.split()[1]
            for alignment in item.alignments
            if alignment.hsps
        }

# Union of the entries hit by the light-chain query and by the heavy-chain
# query; the printed number is the size of that union.
pdb_fab_hits_1 =  parse_blast_output("PDB_blast_db/hits_fabs_light.xml")
pdb_fab_hits_2 =  parse_blast_output("PDB_blast_db/hits_fabs_heavy.xml")
pdb_fab_hits   = pdb_fab_hits_1|pdb_fab_hits_2

print(len(pdb_fab_hits))

### 3. Screen for heavy and light fab chains using ANARCI

http://opig.stats.ox.ac.uk/webapps/newsabdab/sabpred/anarci/ <br>
https://github.com/oxpig/ANARCI

Annotate all Fabs with ANARCI

In [None]:
def load_fasta(path):
    """Read a FASTA file into a list of ``[header, sequence]`` pairs.

    Parameters
    ----------
    path : str or path-like
        FASTA file to read.

    Returns
    -------
    list of [str, str]
        One entry per record, in file order. The header keeps its leading
        ``>`` and loses trailing whitespace; the sequence is the
        concatenation of all following lines up to the next header.

    Notes
    -----
    Lines that appear before the first header (for example a leading blank
    line) are ignored instead of raising ``IndexError``.
    """
    records = []
    with open(path) as f:
        for line in f:
            line = line.rstrip()
            if line.startswith(">"):
                records.append([line, []])
            elif records and line:
                # Sequence line belonging to the most recent header.
                records[-1][1].append(line)
    return [[header, "".join(chunks)] for header, chunks in records]

In [None]:
# Group the raw lines of the SEQRES FASTA by record: every header line
# (starting with ">") opens a new list that also receives the lines after it.
r = []
with open("./PDB_blast_db/pdb_seqres.txt") as f:
    for line in f:
        if line[0]==">":
            r.append([])
        r[-1].append(line)
        
# NOTE(review): this second parse of the same file is not used anywhere in
# this cell; confirm whether a later cell needs it before removing it.
pdb_seqres_fasta = load_fasta("./PDB_blast_db/pdb_seqres.txt")

# Keep only the records whose identifier (header text up to the first space,
# without the ">") was reported by one of the BLAST searches.
rfabs = []
for r_ in r:
    title = r_[0].split(" ")[0][1:]
    if title not in pdb_fab_hits:
        continue
    # Only the first line after the header is kept as the sequence.
    rfabs.append([r_[0].split(" ")[0][1:],r_[1]])
    
# r[1] still ends with its own newline, so each record is followed by an
# empty line in the output file. The loop variable also rebinds the name `r`.
with open("./PDB_blast_db/putative_fabs.fasta",'w') as fo:
    for r in rfabs:
        fo.write("".join([">"+r[0]+"\n",r[1]])+"\n")

# Number the putative Fab chains with ANARCI, once restricted to heavy chains
# and once to light chains. Existing output files are treated as a cache.
if not os.path.exists("./PDB_blast_db/all_fabs_heavy.anarci"):
    anarci_command = "ANARCI -i putative_fabs.fasta -o all_fabs_heavy.anarci -s chothia -r ig --ncpu 8 --bit_score_threshold 100 --restrict heavy"
    subprocess.call(anarci_command,shell=True,cwd="./PDB_blast_db")
    
if not os.path.exists("./PDB_blast_db/all_fabs_light.anarci"):
    anarci_command = "ANARCI -i putative_fabs.fasta -o all_fabs_light.anarci -s chothia -r ig --ncpu 8 --bit_score_threshold 100 --restrict light"
    subprocess.call(anarci_command,shell=True,cwd="./PDB_blast_db")


### 4. Parse ANARCI output and extract heavy and light Fab sequences 


In [None]:
def parse_anarci_annotation(path = "light.anarci",n=108):
    """Parse an ANARCI numbering file into per-position residue lists.

    Parameters
    ----------
    path : str
        ANARCI output file. Records are terminated by a line starting
        with "/"; the last token of a record's first line is its name.
    n : int
        Number of position slots allocated per record. A numbered position
        outside ``range(n)`` raises ``IndexError``.

    Returns
    -------
    dict
        ``name -> list of n lists``. Slot ``k`` holds the residue letters
        (character at column 10 of the line) numbered ``k``; gap
        characters ("-") are skipped. When a name occurs more than once only
        its first record is used, and records without any residue are left
        out.
    """
    # Split the file into records; a line starting with "/" closes a record.
    with open(path) as handle:
        records = [[]]
        for raw in handle.readlines():
            records[-1].append(raw)
            if raw[0] == "/":
                records.append([])

    numbered = {}
    for record in records:
        if not record:
            continue
        name = record[0].rstrip().split()[-1]
        if name in numbered:
            continue
        slots = [[] for _ in range(n)]
        numbered[name] = slots
        for row in record:
            # Comment lines and the record terminator carry no residue.
            if row[0] in ("#", "/"):
                continue
            position = int(row.split()[1])
            residue = row[10]
            if residue != "-":
                slots[position].append(residue)

    # Drop records in which every slot stayed empty.
    return {name: slots for name, slots in numbered.items() if any(slots)}

# Parsed ANARCI numbering per chain; heavy chains get 120 position slots and
# light chains 108. Both dictionaries are read as globals by later cells.
anarci_list_heavy = parse_anarci_annotation("./PDB_blast_db/all_fabs_heavy.anarci", n=120)
anarci_list_light = parse_anarci_annotation("./PDB_blast_db/all_fabs_light.anarci", n=108)


### 5. Fetch all PDB structures containing Light/Heavy chains 

In [None]:
# For every PDB id among the BLAST hits, collect the chains that ANARCI
# recognised as light or heavy (the id is the first four characters of a name).
pdb_3 = {r[:4]:{"light":[],"heavy":[]} for r in list(pdb_fab_hits)}
for h in anarci_list_light:
    h4 = h[:4]
    if h4 not in pdb_3:
        continue
    pdb_3[h4]["light"].append(h)
for h in anarci_list_heavy:
    h4 = h[:4]
    if h4 not in pdb_3:
        continue
    pdb_3[h4]["heavy"].append(h)
    
if not os.path.exists("PDB_blast_db/structs"):
    os.mkdir("PDB_blast_db/structs")

# Download the structure of every entry that has at least one annotated
# chain, skipping entries already present as .pdb.gz or as unpacked .pdb.
for pdb_ in pdb_3:
    if len(pdb_3[pdb_]["light"])+len(pdb_3[pdb_]["heavy"])==0:
        continue    
    pdb_name = pdb_.upper()+".pdb.gz"
    if os.path.exists("PDB_blast_db/structs/"+pdb_name):
        continue
    # NOTE(review): str.rstrip removes a set of characters, not a suffix; it
    # yields "<ID>.pdb" here only because "b" is not part of ".gz".
    if os.path.exists("PDB_blast_db/structs/"+pdb_name.rstrip(".gz")):
        continue
    subprocess.call(f"wget https://files.rcsb.org/download/{pdb_name}",shell=True, cwd="./PDB_blast_db/structs")


In [4]:
### Test that PDB IDS corresponding to old ids are not missing
# Reference chain names ("<pdbid>_<chain>", first six characters of the
# "pdb_id_chain" column) from the previous train and test sets.
train = pd.read_csv("../data/sema_1.0/train_set.csv")
test = pd.read_csv("../data/sema_1.0/test_set.csv")
ref_names = [d["pdb_id_chain"][:6] for d in train.iloc()]
ref_names+= [d["pdb_id_chain"][:6] for d in test.iloc()]
ref_names = set(ref_names)
ref_pdbs  = set([r.split("_")[0] for r in ref_names])

# PDB ids of the unpacked structure files found on disk.
# NOTE(review): the two sets are built but not compared in this cell, and
# only "*.pdb" files are matched although the archives are unpacked in the
# following cell - confirm the intended cell order.
downloaded_pdbs = set()
for p in Path("./PDB_blast_db/structs/").glob("*.pdb"):
    downloaded_pdbs.add(p.name[:4])



### Keep First model for PDBs with multiple models

In [None]:
from pathlib import Path
# Unpack every downloaded archive, then cut each structure file after its
# first ENDMDL record so that only the first model is kept.
subprocess.call(f"gzip -d *.gz", shell=True, cwd="./PDB_blast_db/structs")
pdbs = Path("./PDB_blast_db/structs/").glob("*.pdb")
for pdb in pdbs:
    with pdb.open('r') as fi:
        pdb_data = fi.readlines()

    # Everything up to and including the first ENDMDL line is kept.
    kept = []
    for l in pdb_data:
        kept.append(l)
        if l.startswith("ENDMDL"):
            break

    # Files without trailing content are left untouched, which makes
    # re-running the cell free of needless rewrites.
    if len(kept) == len(pdb_data):
        continue

    with open(str(pdb),'w') as fo:
        fo.writelines(kept)

### 6. Prepare PDB dataframes and align the full sequence (from the PDB SEQRES records) onto the sequence of the resolved protein (which may contain gaps)

In [None]:
def aa_3_to_1(resn):
    """Return the one-letter code for a three-letter residue name.

    Only the twenty standard amino acids are known; any other name raises
    ``KeyError``.
    """
    one_letter_codes = {
        'ALA': 'A',
        'ARG': 'R',
        'ASN': 'N',
        'ASP': 'D',
        'CYS': 'C',
        'GLN': 'Q',
        'GLU': 'E',
        'GLY': 'G',
        'HIS': 'H',
        'ILE': 'I',
        'LEU': 'L',
        'LYS': 'K',
        'MET': 'M',
        'PHE': 'F',
        'PRO': 'P',
        'SER': 'S',
        'THR': 'T',
        'TRP': 'W',
        'TYR': 'Y',
        'VAL': 'V',
    }
    return one_letter_codes[resn]

def kalign(seq1,seq2):
    """Align two sequences pairwise with the external ``kalign`` program.

    The sequences are written to ./PDB_blast_db/temp/input.fasta as records
    "1" and "2" and piped into ``kalign -f fasta``.

    Returns
    -------
    tuple of (str, str)
        The first and the second record of the kalign output, each joined
        into a single string.

    Raises
    ------
    subprocess.CalledProcessError
        If the shell pipeline exits with a non-zero status.
    IndexError
        If the output contains fewer than two FASTA records.
    """
    if not os.path.exists("./PDB_blast_db/temp"):
        os.mkdir("./PDB_blast_db/temp")
    with open("./PDB_blast_db/temp/input.fasta",'w') as handle:
        handle.write(f">1\n{seq1}\n>2\n{seq2}\n")

    raw = subprocess.check_output("cat ./PDB_blast_db/temp/input.fasta | kalign -f fasta",shell=True)

    # Group the output lines by record; anything printed before the first
    # header is ignored.
    records = []
    for row in raw.decode("UTF-8").split("\n"):
        if not row:
            continue
        if row.startswith(">"):
            records.append([])
        if records:
            records[-1].append(row.rstrip("\n"))

    first, second = records[0], records[1]
    return "".join(first[1:]), "".join(second[1:])
    
    
def remove_alternative_conformations(pdb_dataframe):
    """Keep only rows whose ``alt_loc`` is "A", a single space or empty."""
    primary_location = pdb_dataframe["alt_loc"].isin(["A", " ", ""])
    return pdb_dataframe[primary_location]
    
def remove_unk(pdb_dataframe):
    """Drop every row whose ``residue_name`` is "UNK"."""
    return pdb_dataframe[pdb_dataframe["residue_name"] != "UNK"]
    
def consider_insertions(pdb_dataframe):
    """Add a ``residue_key`` column that identifies each residue.

    The key is the tuple ``(residue_number, insertion, chain_id,
    residue_name)``. The column is added to the frame that was passed in,
    and that same frame is returned.
    """
    key_columns = ("residue_number", "insertion", "chain_id", "residue_name")
    pdb_dataframe["residue_key"] = list(
        zip(*(pdb_dataframe[column] for column in key_columns))
    )
    return pdb_dataframe
    
def put_full_sequence(pdb_dataframe, full_seq):
    """Align the resolved residues of one chain onto its full sequence.

    Parameters
    ----------
    pdb_dataframe : pandas.DataFrame
        ATOM records of a single chain (columns as read by ``PandasPdb``).
    full_seq : str-like
        Full sequence of the chain; it must survive the alignment unchanged
        (checked with an assertion).

    Returns
    -------
    pandas.DataFrame or None
        One group of rows per alignment column, in alignment order. A column
        that holds a resolved residue contributes all atoms of that residue
        with two extra columns: ``seqres`` (letter of the full sequence, or
        "-") and ``aa`` (letter of the resolved residue). A column without a
        resolved residue contributes a single all-NaN row whose ``atom_name``
        is "CA" and whose ``seqres`` is the full-sequence letter.
        ``None`` is returned when no atoms remain after filtering or when the
        resolved sequence has five residues or fewer.

    Raises
    ------
    KeyError
        If a residue name is not one of those known to ``aa_3_to_1``.
    """
    pdb_dataframe = remove_alternative_conformations(pdb_dataframe)
    pdb_dataframe = remove_unk(pdb_dataframe)
    pdb_dataframe = consider_insertions(pdb_dataframe)

    if len(pdb_dataframe) == 0:
        print("Empty...")
        return

    pdb_ca = pdb_dataframe[pdb_dataframe["atom_name"] == "CA"]

    # Sequence of resolved residues, one entry per distinct residue key.
    # The key is recorded in `used`, so a residue that appears twice
    # contributes a single letter and its atoms are emitted only once.
    residue_numbers = []
    residue_seq     = []
    used = set()
    for r in pdb_ca.iloc():
        residue_number = r["residue_key"]
        if residue_number in used:
            continue
        used.add(residue_number)
        residue_numbers.append(residue_number)
        residue_seq.append(aa_3_to_1(r["residue_name"]))

    pdb_seq = "".join(residue_seq)
    if len(pdb_seq) <= 5:
        print("PDB sequence is too short")
        return
    pdb_seq_aligned, full_seq_aligned = kalign(pdb_seq,full_seq)

    # The aligner may only insert gaps; the full sequence itself must be intact.
    assert full_seq_aligned.replace("-","") == full_seq

    print("pdb",pdb_seq_aligned)
    print("full",full_seq_aligned)

    assert len(pdb_seq_aligned) == len(full_seq_aligned)

    # Walk the alignment columns and attach the residue key of the resolved
    # residue (if any) to every column.
    n_pdb     = -1
    n_pdb_map = []
    for a_pdb, a_fullseq in zip(pdb_seq_aligned, full_seq_aligned):
        if a_pdb == "-":
            n_pdb_map.append({"resi":None, "a_pdb":None, "a_full":a_fullseq})
        else:
            n_pdb += 1
            n_pdb_map.append({"resi":residue_numbers[n_pdb], "a_pdb":a_pdb, "a_full":a_fullseq})

    # Kept from the original cell: silences pandas' chained-assignment
    # warning for the whole session.
    pd.options.mode.chained_assignment = None  # default='warn'

    full_df = []
    for r in n_pdb_map:
        if r["resi"] is None:
            # Placeholder row for a position without resolved coordinates.
            empty_   = pd.DataFrame(np.nan, index=[0],columns=pdb_ca.columns)
            empty_["atom_name"] = "CA"
            empty_["seqres"]    = r["a_full"]
            full_df.append(empty_)
            continue

        # Explicit copy so that the new columns never write into a view.
        pdb_residue = pdb_dataframe[pdb_dataframe["residue_key"] == r["resi"]].copy()
        pdb_residue["seqres"] = r["a_full"]
        pdb_residue["aa"]     = r["a_pdb"]
        full_df.append(pdb_residue)
    full_df = pd.concat(full_df,axis=0,ignore_index=True)
    print(full_df.shape)
    return full_df

def get_tasks():
    """List the chains that still have to be converted to per-chain pickles.

    On the first call every ``*.pdb`` file in ./PDB_blast_db/structs/ is
    scanned for ATOM records of CA atoms and the resulting
    ``<file stem>_<chain>`` names are cached in
    ./PDB_blast_db/all_pdbids_and_chains.txt. Later calls read the cache.

    Returns
    -------
    dict
        ``pdb id -> set of chain ids`` for every cached chain that has no
        ``<pdbid>_<chain>.pkl`` file in ./PDB_blast_db/structs_per_chain/.
    """
    chain_list_path = "./PDB_blast_db/all_pdbids_and_chains.txt"

    if not os.path.exists(chain_list_path):
        all_chains = set()
        for pdb_path in Path("./PDB_blast_db/structs/").glob("*.pdb"):
            with pdb_path.open('r') as handle:
                for line in handle:
                    # Column 21 of an ATOM record is read as the chain id.
                    if line.startswith("ATOM") and line[13:15] =="CA":
                        all_chains.add(pdb_path.stem+"_"+line[21])
        with open(chain_list_path,'w') as fo:
            fo.write("\n".join(list(all_chains)))

    with open(chain_list_path,'r') as handle:
        all_chains = [r.rstrip() for r in handle.readlines()]

    # Path.stem removes exactly the ".pkl" suffix. str.rstrip(".pkl") would
    # also eat a trailing chain id "p", "k" or "l" and the finished chain
    # would be scheduled again on every run.
    used = {p.stem for p in Path("./PDB_blast_db/structs_per_chain/").glob("*.pkl")}

    j = {}
    for u in all_chains:
        if u in used:
            continue
        pdbid,chain = u.split("_")
        j.setdefault(pdbid,set()).add(chain)

    return j



def get_PDBDataFrame(pdb_id = "1FGV",chains = None):   
    """Write one pickle per chain with the SEQRES-aligned atom table.

    Parameters
    ----------
    pdb_id : str
        Entry to process; the structure is read from
        ./PDB_blast_db/structs/<PDB_ID upper-cased>.pdb.
    chains : collection of str, optional
        When given, only these chain ids are processed.

    Returns
    -------
    None
        Results are written to
        ./PDB_blast_db/structs_per_chain/<pdb_id>_<chain>.pkl. Chains whose
        pickle already exists, or that have no ATOM records, are skipped.
        The pickled object is whatever ``put_full_sequence`` returns, which
        can be ``None``.
    """
    from Bio import SeqIO  # kept function-local, as in the original cell

    pdb_path           = f"./PDB_blast_db/structs/{pdb_id.upper()}.pdb"
    if not os.path.exists(pdb_path):
        print(pdb_id," not found")
        return
    pdb_structure      = PandasPdb().read_pdb(pdb_path).df["ATOM"]

    if not os.path.exists(f"./PDB_blast_db/structs_per_chain/"):
        os.mkdir(f"./PDB_blast_db/structs_per_chain/")

    for record in SeqIO.parse(pdb_path, "pdb-seqres"):
        # The chain id is taken as the last character of the record id.
        chain    = record.id[-1]
        out_path = f"./PDB_blast_db/structs_per_chain/{pdb_id}_{chain}.pkl"
        if os.path.exists(out_path):
            continue

        if chains is not None and chain not in chains:
            continue
        pdb_chain = pdb_structure[pdb_structure["chain_id"] == chain]
        if len(pdb_chain) == 0:
            continue

        print(pdb_id,chain)
        pdb_chain_fullseq = put_full_sequence(pdb_chain,record.seq)
        # The context manager flushes and closes the pickle right away.
        with open(out_path,'wb') as fo:
            pickle.dump(pdb_chain_fullseq, fo)
    
#df_jobs = {}
# Convert every chain that does not have a per-chain pickle yet.
# `jobs` maps a PDB id to the set of its pending chain ids.
jobs = get_tasks()
for pdb_id in jobs:#get_tasks():#pdb_3:
    print(pdb_id)
    chains = jobs[pdb_id]
    get_PDBDataFrame(pdb_id,chains)
    




### 7. Put ANARCI annotation into antibodies dataframes prepared in the previous step

In [None]:

def mafft_align(s1,s2,strict=True):
    """Align two sequences with the external ``mafft`` program.

    The sequences are written to ``m.fasta`` in the current working
    directory as records "1" and "2".

    Parameters
    ----------
    s1, s2 : str
        Sequences to align.
    strict : bool
        When true ``mafft --auto`` is used, otherwise ``mafft --op 0.1``.

    Returns
    -------
    list of str
        The aligned sequences in output order, one string per record.
    """
    with open("m.fasta",'w') as fo:
        fo.write(">1\n"+s1+"\n>2\n"+s2+"\n")

    if strict:
        command = "mafft --anysymbol --auto m.fasta "
    else:
        command = "mafft --anysymbol --op 0.1  m.fasta "
    output = subprocess.check_output(command,shell=True)

    aligned = []
    for row in output.decode("UTF-8").split("\n"):
        if not row:
            continue
        if row[0] == ">":
            # A header starts a new, still empty, aligned sequence.
            aligned.append("")
        else:
            aligned[-1] += row.rstrip()
    return aligned

def realign_sequences(pdb_seq,anarci_, firstLetterException = False):
    """Map every position of ``pdb_seq`` onto its alignment with an ANARCI sequence.

    Parameters
    ----------
    pdb_seq : str
        Sequence of the chain, one character per position.
    anarci_ : list of list of str
        Residue letters per numbered position, as produced by
        ``parse_anarci_annotation``.
    firstLetterException : bool
        When true, the first residue of the ANARCI sequence is exempt from
        the identity check if it is aligned to a residue of ``pdb_seq``.

    Returns
    -------
    list or None
        A list with one entry per position of ``pdb_seq``: the index of the
        alignment column in which that position is paired with an ANARCI
        residue, or ``None`` when it is not paired. ``None`` is returned
        instead of the list as soon as a paired column holds two different
        letters.
    """
    # Flatten the ANARCI slots into one sequence; seq_i records the slot
    # number each residue came from.
    seq_aa = []
    seq_i  = []
    for i,s_ in enumerate(anarci_):
        if len(s_)==0:
            continue
        seq_aa+=s_
        seq_i +=[i for i_ in range(len(s_))]
    al = kalign("".join(seq_aa),"".join(pdb_seq))
    
    seq_anarci_aligned = al[0]
    pdb_seq_aligned    = al[1]
    # Running counters of the residues consumed from each sequence.
    n_anarci = 0
    n_pdb    = 0
    pdb_anarci_map = [None for i in pdb_seq]
 
    for i,[a_anarci,a_pdb] in enumerate(zip(*al)):
        if a_anarci!="-" and a_pdb!="-":#i!=0:
            # NOTE(review): the stored value is the alignment column index i;
            # seq_i (the ANARCI slot numbers) is never read. Confirm that the
            # column index, not seq_i[n_anarci], is the intended label.
            pdb_anarci_map[n_pdb] = i
            if n_anarci == 0 and firstLetterException:
                n_pdb+=1
                n_anarci+=1
                continue
            if a_pdb!=a_anarci:
                return None
        if a_pdb!="-":
            n_pdb+=1
        if a_anarci!="-":            
            n_anarci+=1

    return pdb_anarci_map
    
def put_anarci_annotation(pdb_dataframe,fab_id, firstLetterException = False):
    """Add an ``anarci`` column with per-residue position labels to a chain.

    Parameters
    ----------
    pdb_dataframe : pandas.DataFrame
        Per-chain table with ``atom_name``, ``seqres`` and ``residue_key``
        columns; it is modified in place.
    fab_id : tuple
        ``(pdb_id, chain, fab_type)``; ``fab_type`` "light" selects the
        global ``anarci_list_light``, any other value ``anarci_list_heavy``.
    firstLetterException : bool
        Forwarded to ``realign_sequences``.

    Returns
    -------
    pandas.DataFrame or None
        The same frame with the new column, or ``None`` when
        ``realign_sequences`` rejects the alignment. Labels are the upper-case
        first letter of ``fab_type`` followed by the mapped index; residues
        without a mapping keep ``None``.
    """
    pdb_id, chain, fab_type = fab_id
    anarci_key = pdb_id.lower()+"_"+chain
    anarci_seq = anarci_list_light[anarci_key] if fab_type == "light" else anarci_list_heavy[anarci_key]

    ca_rows = pdb_dataframe[pdb_dataframe["atom_name"] == "CA"]
    pdb_anarci_map = realign_sequences("".join(ca_rows["seqres"]), anarci_seq, firstLetterException)
    if pdb_anarci_map is None:
        return None

    labels = [None if i is None else fab_type[0].upper()+str(i) for i in pdb_anarci_map]

    pdb_dataframe["anarci"] = None
    # Every atom of a residue receives the label of that residue's CA row.
    for label, residue_key in zip(labels, ca_rows["residue_key"]):
        same_residue = pdb_dataframe["residue_key"] == residue_key
        pdb_dataframe.loc[same_residue, "anarci"] = label

    return pdb_dataframe





In [None]:
# Output folder for the annotated antibody chains.
if not os.path.exists(f"./PDB_blast_db/structs_antibodies/"):
    os.mkdir(f"./PDB_blast_db/structs_antibodies/")

# One job per ANARCI-annotated chain: (PDB id, chain id, chain type).
jobs = []
for anarci_id,anarci_map in anarci_list_heavy.items():
    jobs.append((anarci_id[:4].upper(),anarci_id[-1],"heavy"))
for anarci_id,anarci_map in anarci_list_light.items():
    jobs.append((anarci_id[:4].upper(),anarci_id[-1],"light"))

# Jobs whose annotation failed. Entries are the job tuples themselves, so the
# membership test below uses the same tuple form.
strange_error_list = set()
for pdb_id,chain,fab_type in jobs:
    if (pdb_id,chain,fab_type) in strange_error_list:
        continue
    firstLetterException = True
    print(pdb_id,chain,fab_type)
    pdb_path           = f"./PDB_blast_db/structs_per_chain/{pdb_id}_{chain}.pkl"
    if not os.path.exists(pdb_path):
        continue
    out_path           = f"./PDB_blast_db/structs_antibodies/{pdb_id}_{chain}_{fab_type}.pkl"
    if os.path.exists(out_path):
        continue
    with open(pdb_path,'rb') as fi:
        fab            = pickle.load(fi)
    fab_annotated      = put_anarci_annotation(fab, (pdb_id,chain,fab_type),firstLetterException)

    ### if something went wrong, we skip this complex
    if fab_annotated is None:
        strange_error_list.add((pdb_id,chain,fab_type))
        continue
    with open(out_path,'wb') as fo:
        pickle.dump(fab_annotated, fo)


In [None]:
### Test that nothing is missing

# Reference chain names (first six characters of "pdb_id_chain") from the
# previous train and test sets.
train = pd.read_csv("../data/sema_1.0/train_set.csv")
test = pd.read_csv("../data/sema_1.0/test_set.csv")

ref_names = [d["pdb_id_chain"][:6] for d in train.iloc()]
ref_names+= [d["pdb_id_chain"][:6] for d in test.iloc()]
ref_names = set(ref_names)    

# Count the per-chain pickles whose name starts with a reference chain name.
n =0 
for p in Path("./PDB_blast_db/structs_per_chain/").glob(f"*.pkl"):
    if p.name[:6] in ref_names:
        n+=1
        
# Prints the number of matching pickles next to the number of reference names.
print(n, len(ref_names))


### 8. Find heavy/light chain fab pairs

In [None]:

def get_pdb_list():
    """Return the PDB ids that have at least one annotated antibody chain.

    The id is the first four characters of each ``*.pkl`` file name in
    ./PDB_blast_db/structs_antibodies/. Every id is listed once, in the
    order of first appearance, so callers do not repeat the same entry for
    each of its chains.
    """
    names = (p.name[:4] for p in Path("./PDB_blast_db/structs_antibodies/").glob(f"*.pkl"))
    return list(dict.fromkeys(names))
    
def get_fabs_pdbid(pdb_id = "1LK3"):
    """Collect the annotated antibody pickles of one PDB entry by chain type.

    Parameters
    ----------
    pdb_id : str
        Prefix of the file names to look for in
        ./PDB_blast_db/structs_antibodies/.

    Returns
    -------
    dict
        ``{"heavy": [Path, ...], "light": [Path, ...]}``. File names are
        expected to read ``<pdbid>_<chain>_<type>.pkl``; a type other than
        "heavy" or "light" raises ``KeyError``.
    """
    fab_ids  = {"heavy":[],"light":[]}
    for struct_path in Path("./PDB_blast_db/structs_antibodies/").glob(f"{pdb_id}*.pkl"):
        # Path.stem drops exactly the ".pkl" suffix; str.rstrip(".pkl") would
        # strip any trailing run of the characters ".", "p", "k" and "l".
        _, _, fab_type = struct_path.stem.split("_")
        fab_ids[fab_type].append(struct_path)
    return fab_ids


In [None]:
import scipy
from scipy.spatial import distance

def getxyz(df):
    """Return the coordinates of ``df`` as an array with one ``[x, y, z]`` row per record."""
    return np.column_stack((df["x_coord"], df["y_coord"], df["z_coord"]))
    
    
def get_pair_interface(path_light, path_heavy):
    """Count the atoms in contact between a light and a heavy chain.

    Parameters
    ----------
    path_light, path_heavy : str or path-like
        Pickled, annotated chain tables (with an ``anarci`` column).

    Returns
    -------
    int
        Number of distinct heavy-chain rows plus number of distinct
        light-chain rows that lie within 4.5 (coordinate units) of a row of
        the other chain. Only rows whose ``anarci`` label belongs to the
        position lists below are considered.
    """
    with open(path_light,'rb') as handle:
        pdb_light = pickle.load(handle)
    with open(path_heavy,'rb') as handle:
        pdb_heavy = pickle.load(handle)

    ### interface residues of heavy and light fab chains
    heavy_positions = list(range(32,39)) + list(range(44,50)) + list(range(85,95))
    light_positions = list(range(34,39)) + list(range(45,51)) + list(range(90,108))

    heavy_ids = ["H"+str(i) for i in heavy_positions]
    light_ids = ["L"+str(i) for i in light_positions]

    heavy_interface = pdb_heavy[pdb_heavy["anarci"].isin(heavy_ids)]
    light_interface = pdb_light[pdb_light["anarci"].isin(light_ids)]

    # Pairwise distances: rows are heavy-chain atoms, columns light-chain atoms.
    cd = distance.cdist(getxyz(heavy_interface), getxyz(light_interface))
    ids = np.where(cd<4.5)

    return len(set(ids[0]))+len(set(ids[1]))
    
def screen_fab_pairs(pdb_id):
    """Find the light/heavy chain pairs of one entry that touch each other.

    Every heavy chain of the entry is tested against every light chain with
    ``get_pair_interface``.

    Returns
    -------
    dict
        ``(light file stem, heavy file stem) -> contact count`` for all
        pairs whose count is larger than 3.
    """
    fab_path = get_fabs_pdbid(pdb_id)
    contacts = {}
    for heavy_path in fab_path["heavy"]:
        for light_path in fab_path["light"]:
            n = get_pair_interface(path_light = light_path, path_heavy = heavy_path)
            ### cut off to select interface that interact with each other
            if n > 3:# 10:
                # Path.stem removes exactly the ".pkl" suffix, unlike
                # str.rstrip(".pkl"), which strips a set of characters.
                contacts[(light_path.stem, heavy_path.stem)] = n
    return contacts


# Interacting light/heavy pairs for every entry with annotated chains; the
# result is also stored on disk for the later steps.
fab_contacts = {}
for pdb_id in get_pdb_list():
    fab_contacts[pdb_id] = screen_fab_pairs(pdb_id)
with open("./PDB_blast_db/fab_pairs.pkl",'wb') as fo:
    pickle.dump(fab_contacts, fo)

### 9. Find antigens and corresponding interacting antibodies

In [None]:

def get_all_fabs():
    """Return ``<id>.pkl`` for every entry of the global ``fab_ids``.

    NOTE(review): no module-level ``fab_ids`` is defined in the visible
    cells (the name only exists as a local of ``get_all_antigens_list``), so
    calling this function presumably raises ``NameError``. It is also not
    called anywhere in the visible code - confirm whether it can be removed.
    """
    return [l+".pkl" for l in fab_ids]

def get_all_antigens_list():
    """List the per-chain pickles that do not belong to an antibody chain.

    A chain counts as antibody when ANARCI annotated it as heavy or light
    (globals ``anarci_list_heavy`` / ``anarci_list_light``). The comparison
    uses the first six characters of the file name against
    ``<PDBID upper-cased>_<last character of the ANARCI name>``.

    Returns
    -------
    list of pathlib.Path
        Matching files from ./PDB_blast_db/structs_per_chain/.
    """
    fab_chain_ids = set()
    for anarci_name in set(anarci_list_heavy)|set(anarci_list_light):
        fab_chain_ids.add(anarci_name[:4].upper()+"_"+anarci_name[-1])

    antigen_paths = []
    for chain_path in Path("./PDB_blast_db/structs_per_chain/").glob("*.pkl"):
        if chain_path.name[:6] in fab_chain_ids:
            continue
        antigen_paths.append(chain_path)
    return antigen_paths

def get_antigens_PDBID(pdb_id = "1LK3"):
    """Return the antigen pickles whose file name starts with ``pdb_id``.

    Only the first four characters of each file name are compared. The full
    antigen list is rebuilt on every call.
    """
    matching = []
    for antigen_path in get_all_antigens_list():
        if antigen_path.name[:4] == pdb_id:
            matching.append(antigen_path)
    return matching


def screen_antigen_contacts(pdb_id = "1LK3"):
    """Find the antigen chains of an entry that touch one of its Fab pairs.

    Every antigen chain of the entry is tested against every light/heavy
    pair stored in the global ``fab_contacts``.

    Returns
    -------
    list of dict
        One record per antigen / Fab pair combination with at least one
        contact, holding the antigen path, the pair and the contact counts
        towards the light and the heavy chain. Empty when the entry has no
        registered Fab pair.
    """
    if pdb_id not in fab_contacts:
        return []

    fab_pairs = fab_contacts[pdb_id]
    hits = []
    for antigen in get_antigens_PDBID(pdb_id):
        for fab_pair in fab_pairs:
            print(antigen, fab_pair)
            light_name, heavy_name = fab_pair[0], fab_pair[1]
            n_light = test_contacts(antigen, Path("./PDB_blast_db/structs_antibodies/"+light_name+".pkl"))
            n_heavy = test_contacts(antigen, Path("./PDB_blast_db/structs_antibodies/"+heavy_name+".pkl"))
            if n_light+n_heavy != 0:
                hits.append({"antigen":antigen,
                             "fab_pair":fab_pair,
                             "n_contacts_light":n_light,
                             "n_contacts_heavy":n_heavy})
    return hits
        
def test_contacts(antigen_path, fab_path):
    """Count the antigen atoms that are close to the binding site of a chain.

    Parameters
    ----------
    antigen_path : pathlib.Path
        Pickled per-chain table of the antigen (may contain ``None``).
    fab_path : pathlib.Path
        Pickled, annotated antibody chain; the file name must end in
        ``_light.pkl`` or ``_heavy.pkl``.

    Returns
    -------
    int
        Number of distinct antigen rows with coordinates that lie within 4.5
        (coordinate units) of a row of the antibody chain whose ``anarci``
        label is in the position list of that chain type. 0 when the antigen
        pickle holds ``None``.

    Raises
    ------
    ValueError
        If the chain type cannot be derived from the file name.
    """
    with antigen_path.open('rb') as handle:
        antigen_df = pickle.load(handle)
    if antigen_df is None:
        return 0

    with fab_path.open('rb') as handle:
        fab_df = pickle.load(handle)

    ### CDR1-3 residues of light chain
    ### CDR1-3 residues of heavy chain
    suffix = fab_path.name.split("_")[-1]
    if suffix == "light.pkl":
        interface = ["L"+str(i) for i in list(range(23,35))+list(range(66,72))+list(range(89,98))]
    elif suffix == "heavy.pkl":
        interface = ["H"+str(i) for i in list(range(23,35))+list(range(51,57))+list(range(93,102))]
    else:
        # Fail with a clear message instead of an opaque error inside isin().
        raise ValueError(f"cannot tell the chain type from file name {fab_path.name!r}")

    fab_interface = fab_df[fab_df["anarci"].isin(interface)]
    xyz_fab       = getxyz(fab_interface)
    xyz_antigen   = getxyz(antigen_df)
    # Placeholder rows of unresolved residues carry NaN coordinates.
    xyz_antigen   = xyz_antigen[~np.isnan(xyz_antigen)[:,0]]
    cd  = distance.cdist(xyz_antigen, xyz_fab)
    return len(set(np.where(cd<4.5)[0]))

# Screen every entry that has at least one non-antibody chain and store the
# antigen / Fab pair contacts as a table for the following steps.
hits = []
for pdb_id in set([n.name[:4] for n in get_all_antigens_list()]):
    hits+=screen_antigen_contacts(pdb_id)

df = pd.DataFrame(hits)
with open("./PDB_blast_db/antigen_fab_list.pkl",'wb') as fo:
    pickle.dump(df, fo)


### 10. Calculate contact numbers

In [6]:
!pwd

/Users/ivanisenko/projects/SEMA_augmentation_lazy/PDB_epitope_dataset_generation


In [None]:
def calculate_contact_number(df,
                             antigen_chain_id,
                             light_id,
                             heavy_id,
                             R1=8.0,
                             R2=16.0):
    """
    Calculate contact numbers and store them in the ``b_factor`` column.

    ``df`` is modified in place and nothing is returned. After the call
    ``b_factor`` holds, for all atoms of a residue of ``antigen_chain_id``:

    * the number of distinct antibody residues (chains ``light_id`` and
      ``heavy_id``) with an atom closer than ``R1`` to the residue, when
      there is at least one;
    * otherwise 0 when some antibody atom is closer than ``R2``;
    * otherwise -1 (masked).

    Every other row, including the antibody chains, is set to -1.

    R1 - is used to calculate contact number values
    R>R2 - is masked
    """
    df_A = df[df["chain_id"]==antigen_chain_id]
    df_B = df[df["chain_id"].isin([light_id, heavy_id])]
    xyz_fab = getxyz(df_B)
    df.loc[:,"b_factor"] = -1
    resi_ids = ["residue_name", "chain_id", "residue_number", "insertion"]

    for resi, _df in df_A.groupby(resi_ids):
        # One distance matrix per residue serves both cut-offs.
        cd = distance.cdist(getxyz(_df), xyz_fab)

        if (cd<R2).any():
            df.loc[_df.index.values,"b_factor"] = 0

        ids = np.unique(np.where(cd<R1)[1])
        if len(ids) == 0:
            continue
        # Distinct antibody residues among the atoms within R1.
        n_fab_resi = len(df_B.iloc[ids].groupby(resi_ids))
        df.loc[_df.index.values,"b_factor"] = n_fab_resi

    return

def getxyz(df):
    """Stack the x, y and z coordinate columns of ``df`` into an (n, 3) array."""
    coordinate_columns = [df["x_coord"], df["y_coord"], df["z_coord"]]
    return np.stack(coordinate_columns, axis=1)

def load_old_ref_names():
    """
    Return the chain names of the old train and test sets.

    The names are the first six characters of the "pdb_id_chain" column of
    ../data/sema_1.0/train_set.csv and ../data/sema_1.0/test_set.csv,
    collected into one set.
    """
    ref_names = set()
    for csv_path in ("../data/sema_1.0/train_set.csv", "../data/sema_1.0/test_set.csv"):
        table = pd.read_csv(csv_path)
        ref_names.update(entry[:6] for entry in table["pdb_id_chain"])
    return ref_names

def get_clusters():
    """Read the cluster table and map every member to its representative.

    Each line of ./PDB_blast_db/clusters_095/results_cluster.tsv is split on
    whitespace; the first field is the cluster representative and the second
    a member of that cluster.

    Returns
    -------
    dict
        ``member -> representative``.
    """
    members_by_ref = {}
    with open("./PDB_blast_db/clusters_095/results_cluster.tsv") as f:
        for line in f:
            fields = line.rstrip().split()
            members_by_ref.setdefault(fields[0], set()).add(fields[1])

    # Invert the table: every member points at its representative.
    return {
        member: ref_id
        for ref_id, members in members_by_ref.items()
        for member in members
    }

def extract_pdbs(data,
                 output = "./dataset/",
                 R1 = 8.0,
                 R2 = 16.0):
    """Write one antigen table with contact numbers per antigen / Fab pair hit.

    Parameters
    ----------
    data : pandas.DataFrame
        Table with the columns "antigen" (path of the per-chain pickle) and
        "fab_pair" (names of the light and heavy chain files).
    output : str
        Root folder of the dataset; the sub-folders used here are created
        when missing.
    R1, R2 : float
        Cut-offs passed to ``calculate_contact_number``.

    Returns
    -------
    None
        For every hit whose antigen chain belongs to a cluster, the antigen
        table (with the contact number in ``b_factor``) is pickled to
        ``<output>/clusters/<representative>/<pdb>_<antigen>_<light>_<heavy>.pkl``
        and written next to it as a ``.pdb`` file. Existing results are
        skipped.
    """
    clusters = get_clusters()

    Path(output+"/").mkdir(exist_ok=True)
    Path(output+"/clusters/").mkdir(exist_ok=True)
    Path(output+"/antigen/").mkdir(exist_ok=True)
    Path(output+"/antigen_fab/").mkdir(exist_ok=True)
    Path(output+"/antigen_fab_raw/").mkdir(exist_ok=True)

    for d in data.iloc():
        antigen_chain_id = d["antigen"].name.split("_")[-1][0]
        pdb_id = d["antigen"].name.split("/")[-1][:4]

        light_id = d["fab_pair"][0].split("_")[1]
        heavy_id = d["fab_pair"][1].split("_")[1]

        base_name   = f"{pdb_id}_{antigen_chain_id}_{light_id}_{heavy_id}"
        output_name = base_name+".pkl"
        print(output_name)

        # NOTE(review): this looks for the file in the current working
        # directory, not in the output folder; kept as in the original cell.
        if os.path.exists(output_name):
            continue

        if pdb_id+"_"+antigen_chain_id not in clusters:
            continue

        ref_name = clusters[pdb_id+"_"+antigen_chain_id]
        out = f"./{output}/clusters/{ref_name}"

        if os.path.exists(out+f"/{base_name}.pkl"):
            continue

        Path(out).mkdir(exist_ok=True)
        # The antigen table is loaded once, only for hits that are processed.
        with open(d["antigen"],'rb') as fi:
            antigen_df = pickle.load(fi)

        # Keep the ATOM records of the three chains of this complex only.
        pdb_structure      = PandasPdb().read_pdb(f"./PDB_blast_db/structs/{pdb_id}.pdb")
        pdb_structure.df.pop("HETATM")
        pdb_structure.df["ATOM"] = pdb_structure.df["ATOM"][pdb_structure.df["ATOM"]["chain_id"].isin([antigen_chain_id,
                                                                                                       light_id,
                                                                                                       heavy_id])]
        pdb_structure.df["ATOM"] = pdb_structure.df["ATOM"][pdb_structure.df["ATOM"]['alt_loc'].isin(['', 'A'])]

        calculate_contact_number(pdb_structure.df["ATOM"],
                                 antigen_chain_id,
                                 light_id,
                                 heavy_id,
                                 R1 = R1,
                                 R2 = R2)

        # Copy the contact number of every residue from the structure into
        # the antigen table; residues without a match keep -1.
        atoms    = pdb_structure.df["ATOM"]
        resi_ids = ["residue_number","insertion", "chain_id", "residue_name"]
        antigen_df["b_factor"] = -1

        for r,residue_rows in antigen_df.groupby(resi_ids):
            df_ = atoms[(atoms["residue_number"] == int(r[0])) &
                        (atoms["insertion"] == r[1]) &
                        (atoms["chain_id"] == r[2]) &
                        (atoms["residue_name"] == r[3])]
            if df_.shape[0]==0:
                continue
            antigen_df.loc[residue_rows.index.values, "b_factor"] = df_.iloc()[0]["b_factor"]

        with open(out+f"/{base_name}.pkl",'wb') as fo:
            pickle.dump(antigen_df, fo)

        # The PDB file only receives rows with coordinates.
        antigen_df = antigen_df.dropna(subset=['residue_number'])        
        antigen_df['residue_number'] = antigen_df['residue_number'].astype(int)
        antigen_df['atom_number'] = antigen_df['atom_number'].astype(int)

        pdb_structure.df["ATOM"] = antigen_df
        pdb_structure.to_pdb(out+f"/{base_name}.pdb")

        print(out+f"/{base_name}.pdb")
        
# Build the dataset twice from the same hit table: once with a 4.5 contact
# radius (./dataset_4.5/) and once with an 8.0 contact radius (./dataset/).
# The masking radius R2 is 16.0 in both runs.
data = pickle.load(open("./PDB_blast_db/antigen_fab_list.pkl",'rb'))        
extract_pdbs(data, output="./dataset_4.5/", R1=4.5, R2=16.0)

data = pickle.load(open("./PDB_blast_db/antigen_fab_list.pkl",'rb'))
extract_pdbs(data, output="./dataset/",     R1=8.0, R2=16.0)


### 11. Collect dataset of homologous antigen clusters with calculated contact number values

In [3]:
def find_contacts(antigen_ids, fab_pair):
    """
    Function to extract multimers data

    Builds a contact graph between the given antigen chains, takes the
    connected component that contains the first chain and writes the chains
    of that component as one PDB file to ./dataset/multimers/.

    antigen_ids - list of paths to per-chain antigen pickles; the first entry
                  is the chain the multimer is grown from
    fab_pair    - names of the light and heavy chain files

    Returns None in every case; the result is the written file.

    NOTE(review): ``nx`` is used below but no networkx import is visible in
    this part of the notebook - confirm it is imported elsewhere.
    """
    
    multimer_list = []
    # Coordinates of all rows with a residue number, one array per antigen.
    coords = []
    for antigen_id in antigen_ids:
        prot = pickle.load(antigen_id.open('rb'))
        prot = prot.dropna(subset=['residue_number'])
        coords.append(prot[["x_coord","y_coord","z_coord"]].to_numpy())
    
    G = nx.Graph()
    
    # Two chains are connected when more than 10 rows of the first lie within
    # 4.5 of some row of the second.
    for i in range(len(coords)):
        for j in range(i+1, len(coords)):
            cd = distance.cdist(coords[i], coords[j])
            n  = len(set(np.where(cd<4.5)[0]))
            if n>10:
                G.add_edge(i,j)
    if len(G.edges())==0:
        return
    
    # Only components that contain node 0 (the first antigen) are kept.
    connected_components = nx.connected_components(G)
    r = list((comp for comp in connected_components if 0 in comp))
    
    if len(r) == 0:
        return
    
    longest_subgraph_nodes = max(r, key=len)
    longest_subgraph = G.subgraph(longest_subgraph_nodes)

    # File names without their last four characters (the ".pkl" suffix).
    antigen_multimer = [antigen_ids[i].name[:-4] for i in longest_subgraph]
    
    if len(antigen_multimer) <= 1:
        return
    
    fab_chains = [c.split("_")[1] for c in fab_pair]
    all_paths = []
    
    # Contact-number tables written earlier for these chains (any cluster).
    for a in antigen_multimer:
        all_paths += list(Path("./dataset/clusters/").glob("*/"+a+"*.pkl"))
    
    if len(all_paths) == 0:
        return
    
    # Group the tables by the first six characters of their file name.
    prot_per_chain = {}
    pid = all_paths[0].name[:4]
    for path in all_paths:
        pdb_id = path.name[:6]        
        prot_per_chain.setdefault(pdb_id, [])
        prot_per_chain[pdb_id].append(pickle.load(path.open('rb')))
    
    # Per chain, take the element-wise maximum of b_factor over its tables
    # (assumes all tables of a chain have the same number of rows - TODO confirm).
    multimer_prot = []
    for pdb_id in prot_per_chain:
        ref_pickle = prot_per_chain[pdb_id][0]
        bs = [p["b_factor"].to_numpy() for p in prot_per_chain[pdb_id]]
        bs = np.array(bs)
        bs = np.max(bs,axis=0)
        ref_pickle["b_factor"] = bs
        multimer_prot.append(ref_pickle)
    multimer_prot = pd.concat(multimer_prot)
    # NOTE(review): `bs` is the array of the last chain of the loop above, so
    # only that chain decides whether the multimer is dropped - confirm.
    if np.max(bs)<=0:
        return
    
    name = pid

    # Output name: PDB id followed by the sorted chain ids of the multimer.
    for c in sorted(list(set(list(multimer_prot.dropna(subset=["residue_number"])["chain_id"])))):
        name+="_"+c
    
    pdb = PandasPdb()
    pdb.df["ATOM"] = multimer_prot.dropna(subset=["residue_number"])

    for c in ["atom_number", "residue_number", "line_idx"]:
        pdb.df["ATOM"][c] = pdb.df["ATOM"][c].astype(int)

    Path("./dataset/multimers/").mkdir(exist_ok=True)
    pdb.to_pdb("./dataset/multimers/"+name+".pdb")
    return
    
def find_proteins():
    """
    Function to prepare multimers dataset

    Loads the antigen/Fab table from ./PDB_blast_db/antigen_fab_list.pkl, groups it by
    the "fab_pair" column and, for every pair that has more than one antigen row,
    calls find_contacts() with the antigens ordered by decreasing total number of
    contacts (n_contacts_heavy + n_contacts_light).

    Returns:
        None. A separator line is printed after every processed pair.
    """
    # NOTE: pickle executes arbitrary code on load; only read a file produced locally.
    with open("./PDB_blast_db/antigen_fab_list.pkl", 'rb') as handle:
        fab_list = pickle.load(handle)

    # Grouping by the bare column name yields the group value itself as the key on
    # every pandas version (a one-element list yields a 1-tuple only on newer ones).
    for fab_pair, group in fab_list.groupby("fab_pair"):
        if group.shape[0] == 1:
            # a single antigen chain cannot form a multimer
            continue

        antigen_ids = list(group["antigen"])

        # Out-of-place sum: does not fail when the two columns have different dtypes.
        n_contacts = np.array(group["n_contacts_heavy"]) + np.array(group["n_contacts_light"])

        # Stable ascending sort followed by a reversal: most contacts first, with the
        # same ordering of ties as the previous implementation.
        order = sorted(range(len(n_contacts)), key=n_contacts.__getitem__)
        order.reverse()
        antigen_ids = [antigen_ids[i] for i in order]

        find_contacts(antigen_ids, fab_pair)
        print("-------")


def cluster_data(name = "5I8H_C", ref_cluster_name = None, folder_name = "dataset"):
    """
    Calculate MSA for each cluster to calculate consensus contact number values

    Steps:
      1. Every ``*.pkl`` table in ``./{folder_name}/clusters/{name}/`` is loaded, its CA
         rows are kept and their "seqres" letters are written to ``msa.fasta``.
      2. The sequences are aligned with mafft (or simply copied when there is only one)
         into ``msa_al.fasta``.
      3. For every reference residue, the maximum "b_factor" over all aligned residues
         with the same "seqres" letter is stored as "consensus_contact_number"
         (presumably "b_factor" carries the per-structure contact number -- confirm).
      4. The value is copied to "b_factor" of all atoms of the reference structure and
         the result is written to ``./{folder_name}/consensus/`` as two pickles
         (``*_CA.pkl``, ``*_full.pkl``) and a PDB file.

    Args:
        name: cluster directory name.
        ref_cluster_name: prefix of the file used as reference; defaults to ``name``.
            When several files match, the last one in the alignment file is used.
        folder_name: dataset root folder (relative to the working directory).

    Returns:
        None. Returns early, writing no consensus files, when the ungapped reference
        sequence length differs from the number of CA rows of the reference table.

    Raises:
        KeyError: if no aligned sequence name starts with ``ref_cluster_name``.
    """
    cluster_dir = f"./{folder_name}/clusters/" + name
    consensus_dir = f"./{folder_name}/consensus/"

    n_sequences = 0
    dfs = {}
    dfs_full = {}

    with open(cluster_dir + "/msa.fasta", 'w') as fo:
        for path in Path(cluster_dir + "/").glob("*.pkl"):
            # NOTE: pickle executes arbitrary code on load; only read local pipeline output.
            with path.open('rb') as handle:
                data = pickle.load(handle)
            data_ca = data[data["atom_name"]=="CA"].reset_index(drop=True)
            key = path.name.split(".")[0]
            dfs[key] = data_ca
            dfs_full[key] = data
            fo.write(f">{path.name}\n")
            fo.write("".join(data_ca["seqres"]) + "\n")
            n_sequences += 1

    if n_sequences > 1:
        subprocess.call(f"mafft --anysymbol --auto {cluster_dir}/msa.fasta > {cluster_dir}/msa_al.fasta",shell=True)
    else:
        subprocess.call(f"cp {cluster_dir}/msa.fasta {cluster_dir}/msa_al.fasta",shell=True)

    # Parse the aligned FASTA: each entry is [header line, sequence line, ...].
    seqs = []
    Path(consensus_dir).mkdir(exist_ok=True)
    with open(cluster_dir + "/msa_al.fasta") as f:
        for line in f:
            if line[0]==">":
                seqs.append([])
            seqs[-1].append(line.rstrip())

    if ref_cluster_name is None:
        ref_cluster_name = name

    seqs_h = {}
    ref_name = None
    for seq in seqs:
        n = seq[0][1:]
        seqs_h[n] = "".join(seq[1:]).replace("\n","")
        if n.startswith(ref_cluster_name):
            ref_name = n

    if ref_name is None:
        # Same exception type as the former implicit ``seqs_h[None]`` lookup, but readable.
        raise KeyError(f"no sequence in {cluster_dir}/msa_al.fasta starts with {ref_cluster_name!r}")

    ref_seq = seqs_h[ref_name]
    ref_key = ref_name[:-4]          # strip the ".pkl" suffix
    ref_df = dfs[ref_key]

    if len(ref_seq.replace("-","")) != len(ref_df):
        return

    # contacts[i] -> [[sequence name, residue index in that sequence], ...] for every
    # sequence (the reference included) with a residue aligned to reference residue i.
    contacts = {}
    for n,s in seqs_h.items():
        n1 = 0
        n2 = 0
        for s1,s2 in zip(ref_seq,s):
            if s1!="-":
                n1+=1
            if s2!="-":
                n2+=1
            if s1!="-" and s2!="-":
                contacts.setdefault(n1-1, [])
                contacts[n1-1].append([n,n2-1])

    ref_df["consensus_contact_number"] = -1

    for i, aligned in contacts.items():
        ref_aa = ref_df.iloc[i]["seqres"]
        tar_b = []
        for c,ii in aligned:
            row = dfs[c.split(".pkl")[0]].iloc[ii]
            if row["seqres"] != ref_aa:
                # aligned residue differs from the reference one: ignore it
                print("Error!")
                continue
            tar_b.append(row["b_factor"])
        # the reference is always aligned with itself, so tar_b is never empty
        ref_df.loc[i, "consensus_contact_number"] = np.max(tar_b)

    # Spread the per-residue consensus value over all atoms of the reference structure.
    ref_df_full = dfs_full[ref_key]
    for k, _ in ref_df_full.groupby("residue_key"):
        cn = ref_df[ref_df["residue_key"] == k].iloc[0]["consensus_contact_number"]
        ref_df_full.loc[ref_df_full["residue_key"] == k, "b_factor"] = cn

    with open(consensus_dir + ref_name + "_CA.pkl", 'wb') as handle:
        pickle.dump(ref_df, handle)
    with open(consensus_dir + ref_name + "_full.pkl", 'wb') as handle:
        pickle.dump(ref_df_full, handle)

    prot = PandasPdb()
    prot.df["ATOM"] = ref_df_full.dropna(subset=['atom_number'])

    for c in ["atom_number", "residue_number", "line_idx"]:
        prot.df["ATOM"][c] = prot.df["ATOM"][c].astype(int)

    prot.to_pdb(consensus_dir + ref_name + ".pdb")
    

def load_old_ref_names():
    """
    Select compatible reference cluster names 

    Reads ../data/sema_1.0/train_set.csv and ../data/sema_1.0/test_set.csv and collects
    the first six characters of every "pdb_id_chain" value.

    Returns:
        set of str: the truncated "pdb_id_chain" values found in either file.
    """
    ref_names = set()
    for csv_path in ("../data/sema_1.0/train_set.csv", "../data/sema_1.0/test_set.csv"):
        # Iterate the single column that is needed instead of materialising every row.
        ref_names.update(chain_id[:6] for chain_id in pd.read_csv(csv_path)["pdb_id_chain"])
    return ref_names

def prepare_ds():
    """
    Function to calculate consensus contact number values aggregating precalculated antigen/fab structures above    

    For both dataset variants ("dataset_4.5" and "dataset") every cluster directory
    whose name is exactly six characters long is passed to cluster_data(). The
    reference is the cluster name itself unless the directory holds a ``*.pkl`` file
    whose first six characters are in the set returned by load_old_ref_names(); in
    that case the last such file found is used as reference.
    """
    # Loaded once (the antigen/Fab pickle read here previously was never used).
    old_ref_names = load_old_ref_names()

    for ds_name in ["_4.5", ""]:
        matched_old_refs = set()

        for path in Path(f"./dataset{ds_name}/clusters/").glob("*"):
            if len(path.name)!=6:
                continue

            ref_name = path.name
            for p in path.glob("*.pkl"):
                pot_ref_name = p.name[:6]
                if pot_ref_name in old_ref_names:
                    ref_name = pot_ref_name
                    matched_old_refs.add(ref_name)

            print(path.name, ref_name, path.name == ref_name, len(matched_old_refs), len(old_ref_names))
            cluster_data(path.name, ref_name, f"dataset{ds_name}")


In [None]:
# Run cluster_data() for every six-character cluster directory of both dataset
# variants ("dataset_4.5" and "dataset").
prepare_ds()


### 12. Prepare train and test csv files

In [5]:

def homology_filter_epi(df_train, df_test, p_cut): 
    """
    Function that filters out sequences from the train set that have homology to the test set

    Args:
        df_train: per-residue table with "pdb_id_chain" and "resi_aa" columns.
        df_test: per-residue table with the same two columns.
        p_cut: threshold forwarded to homology_filter().

    Returns:
        The rows of ``df_train`` whose "pdb_id_chain" survived homology_filter().
    """
    def _sequence_table(df):
        # One row per chain: its id and the concatenation of its residue letters.
        # Grouping by the bare column name gives scalar keys on every pandas version.
        chain_ids, sequences = [], []
        for chain_id, rows in df.groupby("pdb_id_chain"):
            chain_ids.append(chain_id)
            sequences.append("".join(rows["resi_aa"]))
        return pd.DataFrame({"pdb_id_chain": chain_ids, "wt_seq": sequences})

    hf = homology_filter(_sequence_table(df_train), _sequence_table(df_test), p_cut=p_cut)
    df_train_filtered = df_train[df_train["pdb_id_chain"].isin(hf["pdb_id_chain"])]
    print(df_train.shape, df_train_filtered.shape)
    return df_train_filtered
          
         

def _write_unique_fasta(df, fasta_path):
    # Write one FASTA record per distinct "pdb_id_chain" (first occurrence wins).
    seen = set()
    with open(fasta_path, 'w') as handle:
        for chain_id, sequence in zip(df["pdb_id_chain"], df["wt_seq"]):
            if chain_id in seen:
                continue
            handle.write(f">{chain_id}\n{sequence}\n")
            seen.add(chain_id)


def homology_filter(df_train, 
                    df_test,
                    p_cut):
    """
    Remove from ``df_train`` every sequence that BLAST reports as similar to a test sequence.

    The train sequences become a protein BLAST database in ./temp/ (the directory is
    wiped first), the test sequences are the blastp query, and the tabular hits
    (-outfmt 6) are filtered by column 2 > ``p_cut`` and second-to-last column < 0.05
    (percent identity and e-value in the default outfmt 6 layout).

    Args:
        df_train: table with "pdb_id_chain" and "wt_seq" columns.
        df_test: table with the same columns, used as query.
        p_cut: threshold applied to column 2 of the hits table.

    Returns:
        ``df_train`` without the rows whose "pdb_id_chain" appears in a retained hit.
        When BLAST reports no hits at all, every row is kept.

    Raises:
        FileNotFoundError: if no hits file was produced (e.g. BLAST is not installed).
    """
    import shutil

    name = "test_"
    shutil.rmtree("temp", ignore_errors=True)
    Path("./temp/").mkdir(exist_ok=True)
    _write_unique_fasta(df_train, f"./temp/all_sequences_{name}.fasta")
    _write_unique_fasta(df_test, f"./temp/test_sequences_{name}.fasta")

    subprocess.call(f"makeblastdb -in all_sequences_{name}.fasta -dbtype prot", shell=True, cwd="./temp/")
    subprocess.call(f"blastp      -db all_sequences_{name}.fasta -query  test_sequences_{name}.fasta -outfmt 6 -out hits_{name}.tsv -num_threads 4", shell=True, cwd="./temp/")

    try:
        hit_data = pd.read_csv(f"./temp/hits_{name}.tsv", delimiter='\t', header=None)
    except pd.errors.EmptyDataError:
        # An empty hits file means "no homologs found", not a failure.
        all_skip_nodes = set()
    else:
        hit_data = hit_data[(hit_data.iloc[:,2]>p_cut) & (hit_data.iloc[:,-2]<0.05)]
        # Every id taking part in a retained hit (query or subject) is excluded.
        all_skip_nodes = set(hit_data.iloc[:,0]) | set(hit_data.iloc[:,1])

    n_before = df_train.shape
    df_train = df_train[~df_train["pdb_id_chain"].isin(all_skip_nodes)]
    n_after  = df_train.shape

    print("Size before filtering:",  n_before)
    print("Size after filtering:",   n_after)

    return df_train


def collect_all_data():
    """
    Combine datasets calculated for R1=8.0/R2=16.0 and R1=4.5/R2=16.0

    Every ``*_CA.pkl`` file in ./dataset/consensus/ is merged (merge_big_small) with the
    first file in ./dataset_4.5/consensus/ that shares its six-character name prefix.
    The resulting ``{name: DataFrame}`` dictionary is pickled to ``all_dataset.pkl``.

    Raises:
        IndexError: if a consensus file has no counterpart in ./dataset_4.5/consensus/.
    """
    all_dataset = {}
    for path in Path("./dataset/consensus/").glob("*_CA.pkl"):
        name = path.name[:6]
        print(name)
        path_small = list(Path("./dataset_4.5/consensus/").glob(name+"*_CA.pkl"))[0]
        all_dataset[name] = merge_big_small(path, path_small)

    # Close the handle explicitly so the pickle is fully flushed before it is read back.
    with open("all_dataset.pkl", 'wb') as handle:
        pickle.dump(all_dataset, handle)
    
    
def merge_big_small(path_big, path_small):
    """
    Contact numbers with R1=8.0/R2=16.0 were calculated to train the model
    Contact numbers with R1=4.5/R2=16.0 were calculated for binary classification
    this function merges them     

    Args:
        path_big: ``pathlib.Path`` of the pickled table that is extended and returned.
        path_small: ``pathlib.Path`` of the pickled table providing
            "consensus_contact_number" for the binary label.

    Returns:
        The ``path_big`` table with two extra columns:
        "consensus_contact_number_4.5" (the small-radius values) and
        "contact_number_binary": 1 where the small-radius value is > 0, 0 where it is
        exactly 0, and -1 otherwise (the -1 marker, any other negative value, or NaN).

    Raises:
        AssertionError: if the two tables have a different number of rows.
    """
    # NOTE: pickle executes arbitrary code on load; only read local pipeline output.
    with path_big.open('rb') as handle:
        data_big = pickle.load(handle)
    with path_small.open('rb') as handle:
        data_small = pickle.load(handle)
    assert data_big.shape[0] == data_small.shape[0]

    cn_small = data_small["consensus_contact_number"]
    data_big["consensus_contact_number_4.5"] = cn_small

    # Every row gets a label; previously a negative value other than -1 (or NaN)
    # produced no label at all and the column assignment failed on the length mismatch.
    cn_values = cn_small.to_numpy(dtype=float)
    data_big["contact_number_binary"] = np.select([cn_values > 0, cn_values == 0], [1, 0], default=-1)
    return data_big

    
def reformat_dataframe(df, name):
    """
    Change dataset to the old format

    Each input row becomes one output row with the columns "pdb_id_chain" (``name``),
    "pdb_id" (first four characters of ``name``), "resi_pos", "resi_aa", "resi_name",
    "contact_number" and "contact_number_binary". Negative "b_factor" and
    "contact_number_binary" values are replaced by -100; positive "b_factor" values
    are stored as log(value + 1). "resi_pos" is returned as nullable "Int64".
    """
    column_names = ("pdb_id_chain", "pdb_id", "resi_pos", "resi_aa",
                    "resi_name", "contact_number", "contact_number_binary")
    records = {column: [] for column in column_names}

    for row_idx in range(len(df)):
        row = df.iloc[row_idx]
        records["pdb_id_chain"].append(name)
        records["pdb_id"].append(name[:4])

        # keep NaN positions as they are, cast everything else to a plain int
        position = row["residue_number"]
        records["resi_pos"].append(position if np.isnan(position) else int(position))

        records["resi_aa"].append(row["seqres"])
        records["resi_name"].append(row["residue_name"])

        contact = row["b_factor"]
        if contact < 0:
            contact = -100
        records["contact_number"].append(contact if contact <= 0 else np.log(contact+1))

        binary_label = row["contact_number_binary"]
        records["contact_number_binary"].append(-100 if binary_label < 0 else binary_label)

    reformatted = pd.DataFrame(records)
    reformatted["resi_pos"] = reformatted["resi_pos"].astype("Int64")
    return reformatted

def generate_csv():
    """
    Generates the updated test set and the homology-filtered train set

    * Test set: for every "pdb_id_chain" group of ``old_ds`` the chain is either
      rebuilt from ``all_dataset.pkl`` (reformat_dataframe on the rows that have a
      "record_name") or, when the chain is absent from ``all_dataset.pkl``, kept with
      "contact_number"/"contact_number_binary" updated from ``vals_upd``. The result is
      written to ../data/sema_1.0/test_set.csv.
    * Train set: every other chain of ``all_dataset.pkl``, reformatted, filtered with
      homology_filter_epi(..., 30.0) and written to ../data/sema_2.0/train_set.csv.

    NOTE(review): ``old_ds`` and ``vals_upd`` are module-level names that are not
    defined in this function; they must exist in the kernel before it is called.
    NOTE(review): the test set is written into the sema_1.0 folder, which
    load_old_ref_names() also reads -- confirm that overwriting it is intended.
    """
    # NOTE: pickle executes arbitrary code on load; only read local pipeline output.
    with open("all_dataset.pkl", 'rb') as handle:
        new_ds = pickle.load(handle)

    new_test = {}

    # Grouping by the bare column name gives scalar keys on every pandas version.
    for pid, d_ in old_ds.groupby("pdb_id_chain"):
        pid = pid[:6]
        if pid not in new_ds:
            new_cn = []
            new_bin = []
            for chain, pos, resi_name, cn in zip(d_["pdb_id_chain"], d_["resi_pos"],
                                                 d_["resi_name"], d_["contact_number"]):
                bin_val = vals_upd[(chain[:6], int(pos), resi_name)]
                if bin_val < 0:
                    # negative label: store -1 for both the label and the contact number
                    bin_val = -1
                    cn = -1
                # store the clamped label (the raw lookup value was appended before,
                # which made the clamp above dead code)
                new_bin.append(bin_val)
                new_cn.append(cn)
            # assign() returns a new frame instead of writing into the groupby slice
            new_test[pid] = d_.assign(contact_number=new_cn, contact_number_binary=new_bin)
            continue

        no_nans = new_ds[pid][~new_ds[pid]["record_name"].isnull()]
        new_test[pid] = reformat_dataframe(no_nans, pid)

    old_test_set = pd.concat(list(new_test.values())).reset_index(drop=True)
    old_test_set.to_csv("../data/sema_1.0/test_set.csv")

    full_train_set = [reformat_dataframe(new_ds[pid], pid) for pid in new_ds if pid not in new_test]

    full_train_set = pd.concat(full_train_set).reset_index(drop=True)
    full_train_set_filtered = homology_filter_epi(full_train_set, old_test_set, 30.0)
    full_train_set_filtered.to_csv("../data/sema_2.0/train_set.csv")
    
# Write ../data/sema_1.0/test_set.csv and ../data/sema_2.0/train_set.csv from all_dataset.pkl.
generate_csv()