In [None]:
import pandas as pd

#get known complexes

#embed structures


#use decoder and generate z vecs

#use decoder sigmoid to get contact proba

--2024-04-29 15:56:12--  https://shmoo.weizmann.ac.il/elevy/3dcomplexV6/dataV6/BU_all_renum.tar.gz
Resolving shmoo.weizmann.ac.il (shmoo.weizmann.ac.il)... 132.77.150.157
Connecting to shmoo.weizmann.ac.il (shmoo.weizmann.ac.il)|132.77.150.157|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 18052384562 (17G) [application/x-gzip]
Saving to: ‘BU_all_renum.tar.gz’


In [None]:
import gzip
import tarfile
#iterate over all files in the tar.gz
def extract_files(tar_gz_file, n):
    """Return the raw bytes of the ``n``-th member of a gzip-compressed tar archive.

    Args:
        tar_gz_file: Path to the ``.tar.gz`` archive.
        n: Index of the member to read. Non-negative indices are resolved by
            streaming through the archive and stopping at the requested member,
            so reading an early member does not index the whole archive.
            Negative indices count from the end and require a full scan.

    Returns:
        The member's content as ``bytes``, or ``None`` when the member is not a
        regular file (``TarFile.extractfile`` returns ``None`` in that case).

    Raises:
        IndexError: If the archive has no member at index ``n``.
    """
    with tarfile.open(tar_gz_file, "r:gz") as tar:
        if n < 0:
            # Counting from the end needs the complete member list.
            member = tar.getmembers()[n]
        else:
            # Iterating the TarFile yields members lazily, in archive order.
            for index, candidate in enumerate(tar):
                if index == n:
                    member = candidate
                    break
            else:
                raise IndexError(f"archive has no member at index {n}")

        f = tar.extractfile(member)
        if f is None:
            return None
        return f.read()

#files are pdbs. transform content into a pdb file
def write_pdb(content, path="temp.pdb"):
    """Write PDB content to ``path`` (default ``"temp.pdb"``).

    Args:
        content: The file content. ``bytes``/``bytearray`` (as returned by
            ``extract_files``) are written in binary mode; ``str`` is written
            in text mode.
        path: Destination file; defaults to the original hard-coded name.
    """
    # Archive members are read as bytes; opening in text mode would make
    # f.write() raise TypeError for them.
    mode = "wb" if isinstance(content, (bytes, bytearray)) else "w"
    with open(path, mode) as f:
        f.write(content)

# Extract the first file (index 0) from the downloaded archive and write it
# out through write_pdb so it can be handled as a PDB file on disk.
content = extract_files("BU_all_renum.tar.gz",0)
write_pdb(content)

In [None]:
class VectorQuantizer(nn.Module):
    """Vector-quantisation layer with a learnable codebook.

    Inputs are reshaped to rows of ``embedding_dim`` values and each row is
    snapped to its nearest codebook vector (squared Euclidean distance).

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Size of each codebook vector.
        commitment_cost: Weight applied to the commitment term of the loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)

    def _nearest_code_indices(self, x):
        """Return, for every row of ``x``, the index of the closest codebook vector."""
        # reshape (rather than view) so non-contiguous inputs are accepted too.
        flat_x = x.reshape(-1, self.embedding_dim)
        codebook = self.embeddings.weight
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, computed for all pairs at once.
        distances = (torch.sum(flat_x ** 2, dim=1, keepdim=True)
                     + torch.sum(codebook ** 2, dim=1)
                     - 2 * torch.matmul(flat_x, codebook.t()))
        return torch.argmin(distances, dim=1)

    def forward(self, x):
        """Quantise ``x``.

        Returns:
            A tuple ``(quantized, loss)``. ``quantized`` has the shape of ``x``
            and carries the codebook values while passing gradients straight
            through to ``x``. ``loss`` is the codebook loss plus
            ``commitment_cost`` times the commitment loss.
        """
        encoding_indices = self._nearest_code_indices(x)

        # Direct lookup selects the same rows as a one-hot matmul would,
        # without materialising a (rows x num_embeddings) matrix.
        quantized = self.embeddings(encoding_indices).view_as(x)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), x)
        q_latent_loss = F.mse_loss(quantized, x.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the codebook vector,
        # backward gradient flows to x unchanged.
        quantized = x + (quantized - x).detach()
        return quantized, loss

    def discretize_z(self, x):
        """Map ``x`` to codebook indices and to one character per index.

        Returns:
            ``(closest_indices, char_list)`` where ``char_list[i]`` is
            ``chr(closest_indices[i])``. Low indices therefore map to control
            characters; the mapping is only meant to be inverted by
            ``string_to_embedding``.
        """
        closest_indices = self._nearest_code_indices(x)

        # Convert indices to characters
        char_list = [chr(idx.item()) for idx in closest_indices]
        return closest_indices, char_list

    def string_to_embedding(self, s):
        """Look up the codebook vectors for a string produced by ``discretize_z``.

        Each character is converted back to an index with ``ord`` and used to
        index the codebook; the result has one row per character.
        """
        # Convert characters back to indices
        indices = torch.tensor([ord(c) for c in s], dtype=torch.long, device=self.embeddings.weight.device)

        # Retrieve embeddings from the codebook
        return self.embeddings(indices)

class HeteroGAE_Encoder(torch.nn.Module):
    """Graph encoder: stacked per-edge-type SAGEConv layers, a linear head and a vector quantizer.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Sequence with the output width of each conv layer.
        out_channels: Width of the latent vectors fed to the quantizer.
        num_embeddings: Codebook size passed to ``VectorQuantizer``.
        commitment_cost: Commitment weight passed to ``VectorQuantizer``.
        metadata: Mapping that must contain ``'edge_types'``, an iterable of
            edge-type tuples. The default is never mutated here.

    Raises:
        KeyError: If ``metadata`` has no ``'edge_types'`` entry.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_embeddings, commitment_cost, metadata={}):
        super().__init__()
        if 'edge_types' not in metadata:
            raise KeyError("metadata must contain an 'edge_types' entry listing the edge-type tuples")
        self.convs = torch.nn.ModuleList()
        self.metadata = metadata
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.in_channels = in_channels

        # ModuleDict keys must be strings, so edge types are stored under
        # '_'.join(edge_type). Keep the reverse mapping explicitly: splitting
        # the key on '_' would corrupt edge types whose names contain '_'.
        self._edge_type_by_key = {'_'.join(edge_type): tuple(edge_type) for edge_type in metadata['edge_types']}

        for i in range(len(hidden_channels)):
            layer_in = in_channels if i == 0 else hidden_channels[i - 1]
            self.convs.append(
                torch.nn.ModuleDict({
                    key: SAGEConv(layer_in, hidden_channels[i])
                    for key in self._edge_type_by_key
                })
            )
        self.lin = Linear(hidden_channels[-1], out_channels)
        self.vector_quantizer = VectorQuantizer(num_embeddings, out_channels, commitment_cost)

    def forward(self, x, edge_index_dict):
        """Encode node features.

        Args:
            x: Node feature tensor shared by every edge type.
            edge_index_dict: Mapping from edge-type tuple to its edge index.

        Returns:
            ``(z_quantized, vq_loss)`` as produced by the vector quantizer.
        """
        last_layer = len(self.hidden_channels) - 1
        for i, convs in enumerate(self.convs):
            # Apply the graph convolutions and average over all edge types
            per_type = [conv(x, edge_index_dict[self._edge_type_by_key[key]]) for key, conv in convs.items()]
            x = torch.stack(per_type, dim=0).mean(dim=0)
            # No non-linearity after the final conv layer.
            if i < last_layer:
                x = F.relu(x)
        x = self.lin(x)
        z_quantized, vq_loss = self.vector_quantizer(x)
        return z_quantized, vq_loss

class HeteroGAE_VariationalQuantizedEncoder(torch.nn.Module):
    """Variational graph encoder whose sampled latents are vector-quantized.

    Stacked per-edge-type SAGEConv layers feed two linear heads (mean and
    log-variance); the reparameterized sample is passed to ``VectorQuantizer``.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Sequence with the output width of each conv layer.
        out_channels: Width of the latent space (also stored as ``latent_dim``).
        num_embeddings: Codebook size passed to ``VectorQuantizer``.
        commitment_cost: Commitment weight passed to ``VectorQuantizer``.
        metadata: Mapping that must contain ``'edge_types'``, an iterable of
            edge-type tuples. The default is never mutated here.

    Raises:
        KeyError: If ``metadata`` has no ``'edge_types'`` entry.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_embeddings, commitment_cost, metadata={}):
        super().__init__()
        if 'edge_types' not in metadata:
            raise KeyError("metadata must contain an 'edge_types' entry listing the edge-type tuples")
        self.convs = torch.nn.ModuleList()
        self.metadata = metadata
        self.hidden_channels = hidden_channels
        self.latent_dim = out_channels
        self.out_channels = out_channels
        self.in_channels = in_channels

        # ModuleDict keys must be strings, so edge types are stored under
        # '_'.join(edge_type). Keep the reverse mapping explicitly: splitting
        # the key on '_' would corrupt edge types whose names contain '_'.
        self._edge_type_by_key = {'_'.join(edge_type): tuple(edge_type) for edge_type in metadata['edge_types']}

        for i in range(len(hidden_channels)):
            layer_in = in_channels if i == 0 else hidden_channels[i - 1]
            self.convs.append(
                torch.nn.ModuleDict({
                    key: SAGEConv(layer_in, hidden_channels[i])
                    for key in self._edge_type_by_key
                })
            )
        self.fc_mu = Linear(hidden_channels[-1], self.latent_dim)
        self.fc_logvar = Linear(hidden_channels[-1], self.latent_dim)
        self.vector_quantizer = VectorQuantizer(num_embeddings, self.latent_dim, commitment_cost)

    def forward(self, x, edge_index_dict):
        """Encode node features.

        Args:
            x: Node feature tensor shared by every edge type.
            edge_index_dict: Mapping from edge-type tuple to its edge index.

        Returns:
            ``(z_quantized, vq_loss, mu, logvar)``.
        """
        last_layer = len(self.hidden_channels) - 1
        for i, convs in enumerate(self.convs):
            # Apply the graph convolutions and average over all edge types
            per_type = [conv(x, edge_index_dict[self._edge_type_by_key[key]]) for key, conv in convs.items()]
            x = torch.stack(per_type, dim=0).mean(dim=0)
            # No non-linearity after the final conv layer.
            if i < last_layer:
                x = F.relu(x)

        # Obtain the mean and log variance for the latent variables
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)

        # Reparameterization trick
        z = self.reparameterize(mu, logvar)

        # Vector quantization
        z_quantized, vq_loss = self.vector_quantizer(z)

        return z_quantized, vq_loss, mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` while training; return ``mu`` unchanged in eval mode."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu


In [None]:

class HeteroGAE_Pairwise_Decoder(torch.nn.Module):
    """Decoder that scores residue contacts between two chains and whether the chains interact.

    Each chain's latent vectors are refined by the same stack of SAGEConv
    layers over the ``('res', 'backbone', 'res')`` edges and projected by a
    shared linear layer. Contact probabilities come from dot products between
    the two chains' projected vectors; a small MLP (``detecter``) produces a
    single interaction probability for the pair.

    Args:
        encoder_out_channels: Width of the incoming latent vectors.
        xdim: Unused here; kept for signature compatibility.
        hidden_channels: Mapping from edge type to the list of conv widths.
            Keys may be tuples or ``'_'``-joined strings (as in the default).
        out_channels_hidden: Width of the projected vectors.
        Xdecoder_hidden: Hidden width of the interaction MLP.
        metadata: Stored on the instance; not otherwise used here.
        amino_mapper: Stored on the instance; not otherwise used here.
            Accepted so the instantiation in this notebook, which passes it,
            does not fail.
    """

    BACKBONE_EDGE = ('res', 'backbone', 'res')

    def __init__(self, encoder_out_channels, xdim=20, hidden_channels={'res_backbone_res': [20, 20, 20]}, out_channels_hidden=20, Xdecoder_hidden=100, metadata={}, amino_mapper=None):
        super().__init__()
        self.convs1 = torch.nn.ModuleList()
        self.convs2 = torch.nn.ModuleList()

        self.metadata = metadata
        self.amino_mapper = amino_mapper
        # Normalise keys to tuples so the string-keyed default and tuple-keyed
        # callers are both looked up the same way. A new dict is built, so the
        # shared default argument is never mutated.
        self.hidden_channels = {
            (tuple(key.split('_')) if isinstance(key, str) else tuple(key)): list(widths)
            for key, widths in hidden_channels.items()
        }
        self.out_channels_hidden = out_channels_hidden
        self.in_channels = encoder_out_channels

        widths = self.hidden_channels[self.BACKBONE_EDGE]
        for i in range(len(widths)):
            layer_in = self.in_channels if i == 0 else widths[i - 1]
            self.convs1.append(
                torch.nn.ModuleDict({
                    '_'.join(self.BACKBONE_EDGE): SAGEConv(layer_in, widths[i])
                })
            )

        self.lin = Linear(widths[-1], self.out_channels_hidden)

        self.sigmoid = nn.Sigmoid()

        #detect interaction
        self.detecter = torch.nn.Sequential(
            torch.nn.Linear(self.out_channels_hidden*2, Xdecoder_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(Xdecoder_hidden, Xdecoder_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(Xdecoder_hidden, 1),
            torch.nn.Sigmoid(),
        )

    def _embed_chain(self, z, backbones):
        """Run one chain's latents through the backbone conv stack and the linear projection."""
        for layer in self.convs1:
            for edge_type, conv in layer.items():
                z = F.relu(conv(z, backbones[self.BACKBONE_EDGE]))
        return self.lin(z)

    def forward(self, z1, z2, edge_index, backbones, **kwargs):
        """Score residue pairs between two chains.

        Args:
            z1: Latent vectors of the first chain's residues.
            z2: Latent vectors of the second chain's residues.
            edge_index: Pairs to score; row 0 indexes ``z1`` and row 1 indexes
                ``z2``.
            backbones: Mapping from edge-type tuple to edge index. As in the
                original code, the same mapping is used for both chains.
                NOTE(review): chains of different length presumably need
                separate backbone graphs - confirm with the caller.

        Returns:
            ``(x_interact, edge_probs)``: the pair-level interaction
            probability and the per-pair contact probabilities scaled by it.
        """
        z1 = self._embed_chain(z1, backbones)
        z2 = self._embed_chain(z2, backbones)

        sim_matrix = (z1[edge_index[0]] * z2[edge_index[1]]).sum(dim=1)
        edge_probs = self.sigmoid(sim_matrix)

        # NOTE(review): the original detector input was not recoverable from
        # the source. The detector expects 2 * out_channels_hidden features,
        # so the mean-pooled embeddings of the two chains are concatenated -
        # confirm this is the intended pair summary.
        pair_summary = torch.cat([z1.mean(dim=0), z2.mean(dim=0)], dim=-1)
        x_interact = self.detecter(pair_summary)
        edge_probs = edge_probs * x_interact

        return x_interact, edge_probs



In [None]:
# Rebuild the quantized encoder, restore saved weights and switch it to
# inference mode.
# NOTE(review): ndim, metadata and encoder_save are not defined in the visible
# cells - presumably set earlier in the session; confirm before a fresh run.
encoder = HeteroGAE_Encoder(in_channels=ndim, hidden_channels=[ 100 ]*3 , out_channels= 10, metadata=metadata , num_embeddings=256, commitment_cost= 1.25 )
# The constructor arguments must match those used when the weights were saved,
# otherwise load_state_dict reports mismatched keys/shapes.
encoder.load_state_dict(torch.load(encoder_save))
encoder.eval()

In [None]:

# Build the pairwise decoder on top of the encoder's latent width, with seven
# backbone conv layers of width 40.
# NOTE(review): metadata and aaindex are not defined in the visible cells -
# presumably set earlier in the session; confirm before a fresh run.
decoder = HeteroGAE_Pairwise_Decoder(encoder_out_channels = encoder.out_channels , 
                            hidden_channels={ ( 'res','backbone','res'):[ 40 ] * 7  } , out_channels_hidden= 20 , metadata=metadata , amino_mapper = aaindex  )


In [None]:
from Bio import PDB

def get_structure(pdb_file):
    """Parse ``pdb_file`` with Biopython's ``PDBParser`` and return the structure (id ``"X"``)."""
    return PDB.PDBParser().get_structure("X", pdb_file)

def get_chains(structure):
    """Flatten a structure into a list of its chains, model by model, in iteration order."""
    return [chain for model in structure for chain in model]

def ret_pairwise(c1, c2 , threshold=10):
    """Distance matrix between the atoms whose id is ``"CB"`` in two chains.

    Entry ``[i, j]`` is ``atom_i - atom_j`` for the i-th CB atom of ``c1`` and
    the j-th CB atom of ``c2``. When ``threshold`` is not ``None`` the values
    are clipped in place to ``[0, threshold]``.
    """
    def _cb_atoms(chain):
        # Only atoms with id "CB" take part; others are skipped.
        return [atom for atom in chain.get_atoms() if atom.get_id() == "CB"]

    rows, cols = _cb_atoms(c1), _cb_atoms(c2)
    n_rows, n_cols = len(rows), len(cols)

    # Fill row-major so entry (i, j) pairs rows[i] with cols[j].
    flat = np.fromiter(
        (first - second for first in rows for second in cols),
        dtype=float,
        count=n_rows * n_cols,
    )
    dists = flat.reshape(n_rows, n_cols)

    if threshold is not None:
        np.clip(dists, 0, threshold, out=dists)
    return dists

def get_all_pairwise(chains):
    """Build ``{(i, j): ret_pairwise(chains[i], chains[j])}`` for every ordered pair with ``i != j``."""
    return {
        (i, j): ret_pairwise(first, second)
        for i, first in enumerate(chains)
        for j, second in enumerate(chains)
        if i != j
    }



In [None]:
#use pdb fixer on the pdb file
from pdbfixer import PDBFixer
def fix_pdb(pdb_file):
    """Repair a PDB file with PDBFixer and write the result next to it.

    Runs the missing-residue, non-standard-residue and missing-atom searches,
    then adds missing atoms and hydrogens.

    Args:
        pdb_file: Path of the PDB file to repair.

    Returns:
        Path of the written file: ``pdb_file`` with its extension replaced by
        ``"_fixed.pdb"``.
    """
    import os
    # PDBFixer exposes the repaired model as an OpenMM topology/positions pair
    # and has no writer of its own; OpenMM's PDBFile does the serialisation.
    from openmm.app import PDBFile

    fixer = PDBFixer(pdb_file)
    fixer.findMissingResidues()
    fixer.findNonstandardResidues()
    fixer.findMissingAtoms()
    fixer.addMissingAtoms()
    fixer.addMissingHydrogens()

    # splitext only strips the final extension, so directories or file names
    # containing dots (e.g. "./data/1abc.pdb") keep their full stem.
    stem, _ = os.path.splitext(pdb_file)
    fixed_path = stem + "_fixed.pdb"
    with open(fixed_path, "w") as handle:
        PDBFile.writeFile(fixer.topology, fixer.positions, handle)
    return fixed_path



In [None]:
#use embedding to get z vecs
from torch_geometric.data import HeteroData
#import sageconv
from torch_geometric.nn import SAGEConv , Linear
#import module dict and module list
from torch.nn import ModuleDict, ModuleList
from torch_geometric.nn import global_mean_pool
#import negative sampling
from torch_geometric.utils import negative_sampling
#import graph pooling 
from torch_geometric.nn import global_mean_pool , SAGPooling

class HeteroGAE_pair_Decoder(torch.nn.Module):
    """Decoder producing node-feature reconstructions, edge probabilities and a graph-level interaction score.

    Latent vectors are refined by a stack of SAGEConv layers over the backbone
    edges. From the refined vectors:

    * a linear projection followed by pairwise dot products and a sigmoid
      gives one probability per queried edge;
    * an MLP reconstructs node features from the projected vectors;
    * a chain of SAGPooling layers followed by a global mean pool and a linear
      layer gives an interaction probability per graph.

    Args:
        encoder_out_channels: Width of the incoming latent vectors.
        xdim: Width of the reconstructed node features.
        hidden_channels: Output width of each conv layer.
        pooling: Keep-ratio of each successive SAGPooling layer.
        out_channels: Width of the projected vectors.
        Xdecoder_hidden: Hidden width of the feature-reconstruction MLP.
        metadata: Stored on the instance; not otherwise used here.
    """

    def __init__(self, encoder_out_channels , xdim=20 , hidden_channels = [20,20,20] , pooling = [.5,.5,.1] , out_channels= 10 , Xdecoder_hidden = 100, metadata={}):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.metadata = metadata
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.in_channels = encoder_out_channels
        for i in range(len(self.hidden_channels)):
            layer_in = self.in_channels if i == 0 else self.hidden_channels[i - 1]
            self.convs.append(
                torch.nn.ModuleDict({
                    '_'.join(edge_type): SAGEConv(layer_in, self.hidden_channels[i])
                    for edge_type in [ ( 'res','backbone','res') ]
                })
            )
        self.lin = Linear(hidden_channels[-1], out_channels)
        #sigmoid to predict the edge probabilities after graph conv
        self.sigmoid = nn.Sigmoid()

        # Pooling keeps the feature width and only drops nodes, so every
        # SAGPooling layer takes hidden_channels[-1] input channels; the
        # entries of `pooling` are keep-ratios, not channel counts.
        self.pools = torch.nn.ModuleList(
            SAGPooling(hidden_channels[-1], ratio=ratio, GNN=SAGEConv) for ratio in pooling
        )
        # global_mean_pool is a plain function (applied in forward), not a
        # module, so it is not stored in the ModuleList.

        # The pooled vectors are taken before self.lin, so the interaction
        # head reads hidden_channels[-1] features.
        self.interaction = Linear(hidden_channels[-1], 1)
        self.interaction_sigmoid = nn.Sigmoid()

        # add stack of dense layers to reconstruct the node features
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(out_channels , Xdecoder_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(Xdecoder_hidden, Xdecoder_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(Xdecoder_hidden , xdim),
        )

    def _propagate(self, z, backbone):
        """Apply every conv layer (followed by ReLU) over the backbone edges."""
        for layer in self.convs:
            for edge_type, conv in layer.items():
                z = F.relu(conv(z, backbone))
        return z

    def forward(self, z , edge_index , backbone=None, **kwargs):
        """Decode latents.

        Args:
            z: Latent node vectors.
            edge_index: Edges to score; also the connectivity given to the
                pooling layers, as in the original code.
                NOTE(review): pooling over the backbone may have been
                intended - confirm.
            backbone: Edge index used by the conv layers.
            **kwargs: ``batch`` may supply the node-to-graph assignment; when
                absent all nodes are treated as one graph.

        Returns:
            ``(x_r, edge_probs, interaction)``.
        """
        z = self._propagate(z, backbone)

        batch = kwargs.get('batch')
        if batch is None:
            batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)

        # Pool on separate variables: edge_index must keep referring to the
        # unpooled nodes for the edge scoring below.
        zpool, pool_edge_index = z, edge_index
        for pool in self.pools:
            # batch is passed by keyword; positionally it would be taken as
            # edge_attr.
            zpool, pool_edge_index, _, batch, _, _ = pool(zpool, pool_edge_index, batch=batch)
        zpool = global_mean_pool(zpool, batch)

        #sigmoid to predict the probability of interaction
        interaction = self.interaction_sigmoid(self.interaction(zpool))

        z = self.lin(z)
        sim_matrix =  (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1)
        edge_probs = self.sigmoid(sim_matrix)

        #reconstruct the node features with decoder
        x_r = self.decoder(z)
        return x_r , edge_probs , interaction

    def forward_retz(self, z , edge_index , backbone=None, **kwargs):
        """Like ``forward`` without the pooling branch; returns ``(x_r, edge_probs, z)`` with the projected ``z``."""
        z = self.lin(self._propagate(z, backbone))
        sim_matrix =  (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1)
        edge_probs = self.sigmoid(sim_matrix)
        #reconstruct the node features with decoder
        x_r = self.decoder(z)

        return x_r , edge_probs , z

EPS = 1e-10
def recon_loss( z: Tensor, pos_edge_index: Tensor , backbone:Tensor = None , decoder = None , neg_edge_index: Tensor = None ) -> Tensor:
    r"""Given latent variables :obj:`z`, computes the binary cross
    entropy loss for positive edges :obj:`pos_edge_index` and negative
    sampled edges.

    Args:
        z (torch.Tensor): The latent space :math:`\mathbf{Z}`.
        pos_edge_index (torch.Tensor): The positive edges to train against.
        backbone (torch.Tensor, optional): Passed through to
            :obj:`decoder` as its third argument. (default: :obj:`None`)
        decoder (callable): Called as ``decoder(z, edge_index, backbone)``;
            element ``[1]`` of its result is taken as the edge probabilities.
        neg_edge_index (torch.Tensor, optional): The negative edges to
            train against. If not given, uses negative sampling to
            calculate negative edges. (default: :obj:`None`)

    Returns:
        torch.Tensor: ``-mean(log(p_pos + EPS)) - mean(log(1 - p_neg + EPS))``.

    Raises:
        TypeError: If :obj:`decoder` is not provided.
    """
    if decoder is None:
        raise TypeError("recon_loss requires a decoder callable")

    pos = decoder(z, pos_edge_index, backbone)[1]
    # EPS keeps the logarithm finite when a probability saturates at 0 or 1.
    pos_loss = -torch.log(pos + EPS).mean()

    if neg_edge_index is None:
        neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
    neg = decoder(z, neg_edge_index, backbone)[1]
    neg_loss = -torch.log((1 - neg) + EPS).mean()
    return pos_loss + neg_loss

#define loss for x reconstruction   
def x_reconstruction_loss(x, recon_x):
    """
    Mean squared error between the reconstructed node features ``recon_x``
    and the reference features ``x``.
    """
    mse = F.mse_loss(input=recon_x, target=x)
    return mse


#amino acid onehot loss for x reconstruction
def aa_reconstruction_loss(x, recon_x):
    """
    Categorical cross entropy between the reconstructed node features
    ``recon_x`` (passed as the input/logits) and the targets ``x``.
    """
    cross_entropy = F.cross_entropy(input=recon_x, target=x)
    return cross_entropy

def interaction_loss(pred_x, x):
    """
    Binary cross entropy between predicted probabilities ``pred_x`` and
    targets ``x``.
    """
    bce = F.binary_cross_entropy(input=pred_x, target=x)
    return bce