# Make ProteinSolver embeddings for multiple proteins

#### Import packages

In [None]:
# import stuff
import os
import sys
import pandas as pd
import proteinsolver
import torch
import torch_geometric
import matplotlib.pyplot as plt
import kmbio
from kmbio import PDB
from Bio.PDB import *
from Bio import SeqIO
from PS_model import *

#### Prepare paths

In [None]:
# Run configuration: compute device, trained-weights file, input and output folders
device = "cpu"  # torch device string used when loading and placing the model
statefile = "./e53-s1952148-d93703104.state"  # saved state dict loaded into the network
pdb_path = "./pdb-files/"  # folder holding the input structure files
embs_path = "./PS_embeddings"  # folder receiving one embedding file per structure

#### Define merging function

In [None]:
#Define merging function
def merge_chains(structure, merged_chain_name="A"):
    """Merge every chain of ``structure`` into a single chain of a new structure.

    Chains are ordered by the position of their first residue. Residues of every
    chain after the first are renumbered so their positions lie past the ones
    already merged, and all residues are then added to one new chain.

    Args:
        structure: Object exposing ``id`` and ``chains``; each chain exposes
            ``residues`` whose ``id`` is indexable with the residue position at
            index 1. The residue ids of the input are modified in place.
        merged_chain_name: Id given to the merged chain.

    Returns:
        A new ``kmbio.PDB.Structure`` with the same id as ``structure``,
        containing one model (id 0) with one chain that holds all residues.
    """
    # Build the empty structure -> model -> chain hierarchy
    new_structure = kmbio.PDB.Structure(structure.id)
    new_model = kmbio.PDB.Model(0)
    new_structure.add(new_model)
    new_chain = kmbio.PDB.Chain(merged_chain_name)
    new_model.add(new_chain)

    # Materialise the residues once per chain. Chains without residues have
    # nothing to contribute and would otherwise fail on the [0] / [-1] lookups.
    chains_with_residues = [
        (chain, list(chain.residues)) for chain in structure.chains
    ]
    chains_with_residues = [
        (chain, res_list) for chain, res_list in chains_with_residues if res_list
    ]

    # Sort on the start position only (idx 1 of the residue id is the position).
    # Using a key avoids comparing chain objects when two chains start at the
    # same position; the sort is stable, so such chains keep their input order.
    chains_with_residues.sort(key=lambda pair: pair[1][0].id[1])

    chain_len = 1  # running offset applied to residue positions of later chains
    for i, (chain, res_list) in enumerate(chains_with_residues):
        if i > 0:  # the first chain keeps its original numbering
            # iterate in reverse to prevent duplicate idxs while renumbering
            for j in reversed(range(len(res_list))):
                res = res_list[j]
                res.id = (res.id[0], j + chain_len + 1, res.id[2])
        # NOTE(review): for chains after the first this adds the already
        # renumbered last position, so the offset grows faster than the number
        # of residues. Kept as-is to preserve the existing numbering.
        chain_len += res_list[-1].id[1]
        new_chain.add(chain.residues)
    return new_structure

## Make PS embeddings

In [None]:
# Network dimensions handed to the model constructor
hidden_size = 128  # width of the hidden representation
adj_input_size = 2  # size of the adjacency (edge) input
num_features = 20  # size of the output layer; the node input uses num_features + 1

In [None]:
# Build the network, place it on the target device, restore the trained
# weights from the state file, and switch it to evaluation mode.
gnn = Net(
    hidden_size=hidden_size,
    adj_input_size=adj_input_size,
    x_input_size=num_features + 1,
    output_size=num_features,
)
gnn = gnn.to(device)
gnn.load_state_dict(torch.load(statefile, map_location=device))
gnn = gnn.eval()  # eval() returns the module itself; assigning keeps the cell output empty

In [None]:
#Get all filenames for new structures
# os.listdir also returns sub-directories and hidden files (e.g. notebook
# checkpoints or OS metadata) in arbitrary order. Keep only regular,
# non-hidden files and sort them so every run processes the same list
# in the same order.
pdbs = sorted(
    entry
    for entry in os.listdir(pdb_path)
    if not entry.startswith(".") and os.path.isfile(os.path.join(pdb_path, entry))
)
print("There are {} pdbs".format(len(pdbs)), end = "\n\n")

In [None]:
# Compute one ProteinSolver embedding per structure file and save it to disk.
count = 0

# torch.save does not create missing folders, so make sure the target exists
os.makedirs(embs_path, exist_ok=True)

#Iterate through all the pdb files
for count, pdb in enumerate(pdbs, start=1):
    error_flag = False  # never set to True in this cell; kept so the name stays defined

    #Screen output (carriage return so each file overwrites the previous line)
    screen = f"Working with file {count} of {len(pdbs)}. Name: {pdb}"
    screen = screen.ljust(60 , " ")
    print(screen, end = "\r")

    #Get structure, with all chains merged into chain "A"
    structure = merge_chains(PDB.load(os.path.join(pdb_path, pdb)))

    #Get name id (file name up to the first dot)
    name = pdb.split(".")[0]

    #Extract sequences and adjacency matrix
    protein_data = proteinsolver.utils.extract_seq_and_adj(structure, "A")

    #Preprocess data
    data1 = proteinsolver.datasets.protein.row_to_data(protein_data)
    data2 = proteinsolver.datasets.protein.transform_edge_attr(data1)
    # Put the inputs on the same device as the model
    data2 = data2.to(device)

    #Make embeddings; inference only, so skip building the autograd graph
    with torch.no_grad():
        PS_embed = gnn.forward_without_last_layer(data2.x, data2.edge_index, data2.edge_attr)

    #Save embeddings (moved to CPU so the file loads on machines without the device)
    # assumes forward_without_last_layer returns a single tensor - TODO confirm
    emb_path = os.path.join(embs_path, f"PS_{name}.pt")
    torch.save(PS_embed.cpu(), emb_path)