# Add zeroes to embeddings where there is a gap in the alignment

This notebook ensures that the ESM and PS embeddings of each entry have equal dimensions by inserting all-zero rows at the positions where the pairwise alignment of the two sequences contains a gap.

### Import and load all the necessary modules and files

In [None]:
#Import
import sys, os
from Bio.Seq import Seq
from Bio import SeqIO
from Bio import pairwise2
from Bio.pairwise2 import format_alignment
from Bio import BiopythonWarning
import warnings
import pandas as pd
import torch

In [None]:
# --- Input locations ---------------------------------------------------
# Directory holding the PDB structure files
pdb_path = "../1_Structures/pdb-files"
# Directories holding the per-entry embedding tensors that are read in
ps_path = "PCA_reduced/PS_embeddings/"
esm_path = "PCA_reduced/ESM_embeddings/"

# --- Output locations --------------------------------------------------
# Directories the gap-inserted embedding tensors are written to
ps_out = "Gap_inserted/PS_embeddings/"
esm_out = "Gap_inserted/ESM_embeddings/"

In [None]:
# Collect the sequence found in every PDB file into ``pdb_fasta``.
# Biopython warnings are silenced while parsing, since these are predicted PDBs.
with warnings.catch_warnings():
    warnings.simplefilter('ignore', BiopythonWarning)

    # List the directory once and reuse the listing for the progress total
    pdb_files = os.listdir(pdb_path)
    total = len(pdb_files)
    count = 0

    # Maps entry ID -> sequence read from the structure file
    pdb_fasta = {}

    for count, pdb in enumerate(pdb_files, start=1):
        print(f"Working with pdb {count}/{total}", end = "\r")

        # The ID is the file name up to the first "." and then up to the first "_"
        ID = pdb.partition(".")[0].partition("_")[0]

        # Parse the residues from the atom records; when a file yields
        # several records, the last one is the one that is kept
        with open(f"{pdb_path}/{pdb}","r") as pdb_file:
            for record in SeqIO.parse(pdb_file, "pdb-atom"):
                pdb_fasta[ID] = record.seq
            

In [None]:
# Read the cleaned dataset that holds the original sequences
df = pd.read_csv("../0_DataPreprocessing/CleanedData.csv",sep=",")

# Entry identifiers, in file order
ID = df["ID"]

# Sequences, in the same order, with surrounding whitespace removed
seqs = [raw_seq.strip() for raw_seq in df["sequence"]]

In [None]:
# Map every entry ID to its original sequence, wrapped as a Biopython Seq
org_fasta = {entry_id: Seq(sequence) for entry_id, sequence in zip(ID, seqs)}

### Work with one set of files at a time and perform all necessary insertions

In [None]:
def _insert_gap_rows(embs, aligned_seq, gap="-"):
    """Return *embs* expanded with all-zero rows at the gap columns of *aligned_seq*.

    Row ``k`` of *embs* is taken to belong to the ``k``-th non-gap character
    of *aligned_seq*. The result has one row per alignment column: residue
    columns receive their embedding row, gap columns stay zero.

    The previous implementation compared un-gapped row indices with gap
    positions expressed in aligned coordinates, which misplaced the zero rows
    whenever an alignment had consecutive gaps or more than one gap region.

    Args:
        embs: tensor whose first dimension indexes residues.
        aligned_seq: aligned sequence (iterable of characters) with gaps.
        gap: the gap character used by the aligner.

    Returns:
        *embs* itself when the aligned sequence has no gaps, otherwise a new
        tensor with ``len(aligned_seq)`` rows and the dtype/device of *embs*.

    Raises:
        ValueError: if the number of non-gap characters differs from the
            number of rows in *embs*, since rows could then not be assigned
            to residues unambiguously.
    """
    # True for alignment columns that hold a residue, False for gap columns
    keep = torch.tensor(
        [char != gap for char in aligned_seq],
        dtype=torch.bool,
        device=embs.device,
    )
    n_residues = int(keep.sum())
    if n_residues != len(embs):
        raise ValueError(
            f"embedding has {len(embs)} rows but the aligned sequence "
            f"contains {n_residues} residues"
        )

    # Nothing to insert: hand back the original tensor untouched
    if n_residues == len(aligned_seq):
        return embs

    # Start from all zeroes and scatter the embedding rows into the residue
    # columns; the feature size is taken from the embedding itself
    out = torch.zeros(
        (len(aligned_seq), *embs.shape[1:]),
        dtype=embs.dtype,
        device=embs.device,
    )
    out[keep] = embs
    return out


#initialize
total_changes = 0
count = 0
gap = "-"
# Only entries that have a structure are aligned, so only those are counted
total = sum(1 for key in org_fasta if key in pdb_fasta)

# Make sure the output directories exist before anything is saved
os.makedirs(ps_out, exist_ok=True)
os.makedirs(esm_out, exist_ok=True)

#Iterate through all original fasta sequences
for i, s in org_fasta.items():
    org_esm = torch.load(esm_path + "ESM_" + i + ".pt")

    #Handle ESM embeddings with no structure (PS embedding):
    #the PS embedding becomes all zeroes with the shape of the ESM embedding
    if i not in pdb_fasta:
        emb_path = f'{ps_out}/PS_{i}.pt'
        torch.save(torch.zeros_like(org_esm), emb_path)
        emb_path = f'{esm_out}/ESM_{i}.pt'
        torch.save(org_esm, emb_path)
        continue

    count += 1
    print(f"Working with entry {count}/{total}", end = "\r")

    #Perform pairwise alignment and keep the first reported alignment
    alignment = pairwise2.align.globalxx(s, pdb_fasta[i])
    fasta_seq = alignment[0].seqA
    pdb_seq = alignment[0].seqB
    assert len(pdb_seq) == len(fasta_seq)

    org_ps = torch.load(ps_path + "PS_" + i + ".pt")

    #Count one change per embedding that needs at least one inserted row
    if any(char == gap for char in pdb_seq):
        total_changes += 1
    if any(char == gap for char in fasta_seq):
        total_changes += 1

    #Gaps in the pdb sequence become zero rows in the PS embedding,
    #gaps in the fasta sequence become zero rows in the ESM embedding
    out_embs_ps = _insert_gap_rows(org_ps, pdb_seq, gap)
    out_embs_esm = _insert_gap_rows(org_esm, fasta_seq, gap)

    #Double check that everything has the same dimensions
    assert out_embs_ps.shape == out_embs_esm.shape

    # Save the new embeddings with gaps
    emb_path = f'{ps_out}/PS_{i}.pt'
    torch.save(out_embs_ps, emb_path)
    emb_path = f'{esm_out}/ESM_{i}.pt'
    torch.save(out_embs_esm, emb_path)

print(f"\n\ntotal changes: {total_changes}")


        