In [5]:
import pandas as pd
import networkx as nx
#get known complexes
datadir = '/home/dmoi/datasets/foldtree2/complexes/'
complexdf = pd.read_csv(datadir+'contactDefinition.txt' , sep='\t')
print(complexdf.head())


     code ch1 ch2  nres  len1  len2 dom_s1      dom_p1 dom_s2      dom_p2  \
0  4rk1_1   A   B  31.0   267   268  53822   PF13377.1  53822   PF13377.1   
1  2gqn_1   D   B  53.0   392   392  53383  PF01053.15  53383  PF01053.15   
2  2gqn_1   D   C  17.0   392   391  53383  PF01053.15  53383  PF01053.15   
3  2gqn_1   A   C  52.0   391   391  53383  PF01053.15  53383  PF01053.15   
4  2gqn_1   A   D  28.5   391   392  53383  PF01053.15  53383  PF01053.15   

   ident  homo  
0      1   1.0  
1      1   1.0  
2      1   1.0  
3      1   1.0  
4      1   1.0  


In [None]:
class StructureDataset(Dataset):
    def __init__(self, h5dataset ):
        super().__init__()
        #keys should be the structures
        self.h5data = h5dataset
        self.structlist = list(self.structures.keys())

    def __len__(self):
        return len(self.structlist)

    def __getitem__(self, idx , chain=0):
        if type(idx) == str:
            f = self.h5data[idx][chain]
        elif type(idx) == int:
            f = self.h5data[self.structlist[idx]][chain]
        else:
            raise 'use a structure filename or integer'
        data = {}
        hetero_data = HeteroData()
        if 'node' in f.keys():
            for node_type in f['node'].keys():
                node_group = f['node'][node_type]
                # Assuming 'x' exists
                if 'x' in node_group.keys():
                    hetero_data[node_type].x = torch.tensor(node_group['x'][:])
        # Edge data
        if 'edge' in f.keys():
            for edge_name in f['edge'].keys():
                edge_group = f['edge'][edge_name]
                src_type, link_type, dst_type = edge_name.split('_')
                edge_type = (src_type, link_type, dst_type)
                # Assuming 'edge_index' exists
                if 'edge_index' in edge_group.keys():
                    hetero_data[edge_type].edge_index = torch.tensor(edge_group['edge_index'][:])
                
                # If there are edge attributes, load them too
                if 'edge_attr' in edge_group.keys():
                    hetero_data[edge_type].edge_attr = torch.tensor(edge_group['edge_attr'][:])
        #return pytorch geometric heterograph
        return hetero_data




class ComplexDataset(Dataset):
    def __init__(self, h5dataset ):
        super().__init__()
        #keys should be structures and pairs
        self.h5data = h5dataset
        self.structlist = list(self.structures.keys())
        self.npairs = { struct:len(h5dataset[struct]['contacts']) for struct in h5dataset if  }
        
    def __len__(self):
        return sum(self.npairs.values())

    def __getitem__(self, idx , pair=None):
        if type(idx) == str:
            f = self.h5data[idx]['contacts'][pair]
        elif type(idx) == int:
            f = self.h5data[self.structlist[idx]]['contacts'][pair]
        else:
            raise 'use a structure filename or integer'
        data = {}
        hetero_data = HeteroData()
        if 'node' in f.keys():
            for node_type in f['node'].keys():
                node_group = f['node'][node_type]
                # Assuming 'x' exists
                if 'x' in node_group.keys():
                    hetero_data[node_type].x = torch.tensor(node_group['x'][:])
        # Edge data
        if 'edge' in f.keys():
            for edge_name in f['edge'].keys():
                edge_group = f['edge'][edge_name]
                src_type, link_type, dst_type = edge_name.split('_')
                edge_type = (src_type, link_type, dst_type)
                # Assuming 'edge_index' exists
                if 'edge_index' in edge_group.keys():
                    hetero_data[edge_type].edge_index = torch.tensor(edge_group['edge_index'][:])
                
                # If there are edge attributes, load them too
                if 'edge_attr' in edge_group.keys():
                    hetero_data[edge_type].edge_attr = torch.tensor(edge_group['edge_attr'][:])
        #return pytorch geometric heterograph
        return hetero_data






#create the structural graph dataset
def store_pyg(pdbfiles, aaproperties, filename, verbose = True , complex_distmat = True):
    with h5py.File(filename , mode = 'w') as f:
        #create structs list
        for pdbfile in tqdm.tqdm(pdbfiles):
            if verbose:
                print(pdbfile)
            accession = pdbfile.split('/')[-1].split('.')[0]
            
            f.create_group(accession)
            chains = read_pdb(pdbfile)
            for i,chain in enumerate(chains):
                hetero_data = struct2pyg(chain, aaproperties)
                i = str(i)
                if hetero_data:
                    for node_type in hetero_data.node_types:
                        if hetero_data[node_type].x is not None:
                            node_group = f.create_group(f'{accession}/{i}/node/{node_type}')
                            node_group.create_dataset('x', data=hetero_data[node_type].x.numpy())
                    # Iterate over edge types and their connections
                    for edge_type in hetero_data.edge_types:
                        # edge_type is a tuple: (src_node_type, relation_type, dst_node_type)
                        edge_group = f.create_group(f'{accession}/{i}/edge/{edge_type[0]}_{edge_type[1]}_{edge_type[2]}')
                        if hetero_data[edge_type].edge_index is not None:
                            edge_group.create_dataset('edge_index', data=hetero_data[edge_type].edge_index.numpy())
                        # If there are edge features, save them too
                        if hasattr(hetero_data[edge_type], 'edge_attr') and hetero_data[edge_type].edge_attr is not None:
                            edge_group.create_dataset('edge_attr', data=hetero_data[edge_type].edge_attr.numpy())
                if complex_distmat == True and len(chains)>1:
                    #all vs all between chains
                    for j,c1 in enumerate(chains):
                        for k,c2 in enumerate(chains):
                            if j<k:
                                k = str(k)
                                j = str(j)
                                #generate the contact graph...
                                hetero_data = complex2pyg(chain, aaproperties)
                                if hetero_data:
                                    for node_type in hetero_data.node_types:
                                        if hetero_data[node_type].x is not None:
                                            node_group = f.create_group(f'{accession}/contacts/{j}_{k}/node/{node_type}')
                                            node_group.create_dataset('x', data=hetero_data[node_type].x.numpy())
                                    # Iterate over edge types and their connections
                                    for edge_type in hetero_data.edge_types:
                                        # edge_type is a tuple: (src_node_type, relation_type, dst_node_type)
                                        edge_group = f.create_group(f'{accession}/{j}_{k}/edge/{edge_type[0]}_{edge_type[1]}_{edge_type[2]}')
                                        if hetero_data[edge_type].edge_index is not None:
                                            edge_group.create_dataset('edge_index', data=hetero_data[edge_type].edge_index.numpy())
                                        # If there are edge features, save them too
                                        if hasattr(hetero_data[edge_type], 'edge_attr') and hetero_data[edge_type].edge_attr is not None:
                                            edge_group.create_dataset('edge_attr', data=hetero_data[edge_type].edge_attr.numpy())
                    
                #todo. store some other data. sequence. uniprot info etc.
                    
            else:
                print('err' , pdbfile )


In [6]:
import tqdm
import pickle

chunksize = 10000
graphs = {}
for code in tqdm.tqdm(complexdf.code.unique()):
    sub = complexdf[complexdf.code == code]
    chains = set(sub.ch1.unique()).union( set( sub.ch2.unique() ))
    #add chains to graph
    G = nx.MultiGraph()
    G.add_nodes_from(chains)
    for i, row in sub.iterrows():
        #homology link
        if row.homo == 1:
            #add homology type edge
            ekey = G.add_edge(row.ch1, row.ch2, key='homology' )
        #add contact type edge
        ekey = G.add_edge(row.ch1, row.ch2, key='contact' )
    graphs[code] = G
    
    

with open('complexgraphs.pkl' , 'wb') as graphsout:
    graphsout.write( pickle.dumps( graphs ) )

    
#embed all structures
#use decoder and generate z vecs
#use decoder sigmoid to get contact proba

 40%|█████████████████████████████▊                                            | 36180/89764 [10:45<15:56, 56.02it/s]


KeyboardInterrupt: 

In [None]:
from torch_geometric.data import HeteroData
#import sageconv
from torch_geometric.nn import SAGEConv , Linear , FiLMConv , TransformerConv , FeaStConv , GATConv , GINConv , GatedGraphConv
#import module dict and module list
from torch.nn import ModuleDict, ModuleList , L1Loss
from torch_geometric.nn import global_mean_pool
#import negative sampling
from torch_geometric.utils import negative_sampling


EPS = 1e-10
def recon_loss( z1: Tensor,z2: Tensor, pos_edge_index: Tensor , backbone:Tensor = None , decoder = None , poslossmod = 1 , neglossmod= 1) -> Tensor:
    r"""Given latent variables :obj:`z`, computes the binary cross
    entropy loss for positive edges :obj:`pos_edge_index` and negative
    sampled edges.

    Args:
        z (torch.Tensor): The latent space :math:`\mathbf{Z}`.
        pos_edge_index (torch.Tensor): The positive edges to train against.
        neg_edge_index (torch.Tensor, optional): The negative edges to
            train against. If not given, uses negative sampling to
            calculate negative edges. (default: :obj:`None`)
    """
    
    pos =decoder(z1,z2, pos_edge_index, { ( 'res','backbone','res'): backbone } )[1]
    #turn pos edge index into a binary matrix
    pos_loss = -torch.log( pos + EPS).mean()
    neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
    neg = decoder(z1 , z2, neg_edge_index, { ( 'res','backbone','res'): backbone } )[1]
    neg_loss = -torch.log( ( 1 - neg) + EPS ).mean()
    return poslossmod*pos_loss + neglossmod*neg_loss

#define loss for x reconstruction   
def x_reconstruction_loss(x, recon_x):
    """
    compute the loss over the node feature reconstruction.
    """
    return F.l1_loss(recon_x, x)


#amino acid onehot loss for x reconstruction
def aa_reconstruction_loss(x, recon_x):
    """
    compute the loss over the node feature reconstruction.
    using categorical cross entropy
    """
    x = torch.argmax(x, dim=1)
    #recon_x = torch.argmax(recon_x, dim=1)
    return F.cross_entropy(recon_x, x)

def gaussian_loss(mu , logvar , beta= 1.5):
    '''
    
    add beta to disentangle the features
    
    '''
    kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return beta*kl_loss

In [None]:

class HeteroGAE_Pairwise_Decoder(torch.nn.Module):
    def __init__(self, encoder_out_channels, xdim=20, hidden_channels={'res_backbone_res': [20, 20, 20]}, out_channels_hidden=20, nheads = 1 , Xdecoder_hidden=30, metadata={}, amino_mapper= None):
        super(HeteroGAE_Decoder, self).__init__()
        self.convs = torch.nn.ModuleList()
        self.metadata = metadata
        self.hidden_channels = hidden_channels
        self.out_channels_hidden = out_channels_hidden
        self.in_channels = encoder_out_channels
        self.amino_acid_indices = amino_mapper
        for i in range(len(self.hidden_channels[('res', 'backbone', 'res')])):
            self.convs.append(
                torch.nn.ModuleDict({
                    '_'.join(edge_type): SAGEConv(self.in_channels if i == 0 else self.hidden_channels[edge_type][i-1], self.hidden_channels[edge_type][i]  )
                    for edge_type in [('res', 'backbone', 'res')]
                })
            )

        self.lin = Linear(hidden_channels[('res', 'backbone', 'res')][-1], self.out_channels_hidden)
        
        self.sigmoid = nn.Sigmoid()

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear( self.in_channels + self.out_channels_hidden , Xdecoder_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(Xdecoder_hidden, Xdecoder_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(Xdecoder_hidden, Xdecoder_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(Xdecoder_hidden, Xdecoder_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(Xdecoder_hidden, xdim),
            torch.nn.LogSoftmax(dim=1)         )
        

    def forward(self, z1, z2, , edge_index, backbones, **kwargs):
        #copy z for later concatenation
        xrs = []
        zs = []
        for z in [z1,z2]:
            inz = z
            for layer in self.convs:
                for edge_type, conv in layer.items():
                    z = conv(z, backbones[tuple(edge_type.split('_'))])
                    z = F.relu(z)
            z = self.lin(z)
            x_r = self.decoder(  torch.cat( [inz,  z] , axis = 1) )
            xrs.append(x_r)
            zs.append(z)

        sim_matrix = (zs[0][edge_index[0]] * zs[1][edge_index[1]]).sum(dim=1)
        edge_probs = self.sigmoid(sim_matrix)
        #plddt_r = self.plddt_decoder(z)
        #return x_r, plddt_r,  edge_probs
        return x_r[0],x_r[1]  edge_probs

    def x_to_amino_acid_sequence(self, x_r):
        """
        Converts the reconstructed 20-dimensional matrix to a sequence of amino acids.

        Args:
            x_r (Tensor): Reconstructed 20-dimensional tensor.

        Returns:
            str: A string representing the sequence of amino acids.
        """
        # Find the index of the maximum value in each row to get the predicted amino acid
        indices = torch.argmax(x_r, dim=1)
        
        # Convert indices to amino acids
        amino_acid_sequence = ''.join(self.amino_mapper[idx.item()] for idx in indices)
        
        return amino_acid_sequence

In [None]:
#load pretrained encoder
encoder = torch.load("encoder.pt")
encoder.eval()

#create new decoder
decoder = HeteroGAE_Pairwise_Decoder(encoder_out_channels = encoder.out_channels , 
                            hidden_channels={ ( 'res','backbone','res'):[ 300 ] * 5  } , 
                            out_channels_hidden= 150 , metadata=metadata , amino_mapper = aaindex , Xdecoder_hidden=100 )


encoder_save = 'encoder_mk2_aa_50_AAq_transformermk2'
decoder_save = 'decoder_mk2_aa_50_AAq_transformermk2'

betafactor = 2
#put encoder and decoder on the device
encoder = encoder.to(device)
decoder = decoder.to(device)
# Create a DataLoader for training
total_loss_x = 0
total_loss_edge = 0
total_vq=0
total_kl = 0
total_plddt=0
# Training loop
if os.path.exists(encoder_save):
    encoder.load_state_dict(torch.load(encoder_save))
    decoder.load_state_dict(torch.load(decoder_save))
    print("loaded encoder and decoder")


In [None]:
    train_loader = DataLoader(struct_dat, batch_size=40, shuffle=True)

    optimizer = torch.optim.Adam( list(decoder.parameters()), lr=0.001)
    encoder.train()
    decoder.train()
    
    xlosses = []
    edgelosses = []
    vqlosses = []
    plddtlosses = []

    edgeweight = 1
    xweight = 2
    vqweight = 2
    plddtweight = 1

    for epoch in range(1000):
        for data in tqdm.tqdm(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            
            z1 = encoder(data['res'].x, data['AA'].x , data.edge_index_dict)
            z2 = encoder(data['res'].x, data['AA'].x , data.edge_index_dict)
            
            #add positional encoding to give to the decoder
            edgeloss = recon_loss(z1,z2, , data.edge_index_dict[( 'res','contactPoints','res')]
                                  , data.edge_index_dict[( 'res','backbone','res')], decoder)
            
            recon_x , edge_probs = decoder(z, data.edge_index_dict[( 'res','contactPoints','res')] , data.edge_index_dict )        
    
            xloss = aa_reconstruction_loss(data['AA'].x, recon_x)
            #plddtloss = x_reconstruction_loss(data['plddt'].x, recon_plddt)
            loss = xweight*xloss + edgeweight*edgeloss + vqweight*vqloss #+ plddtloss
            loss.backward()
            optimizer.step()
            total_loss_edge += edgeloss.item()
            total_loss_x += xloss.item()
            total_vq += vqloss.item()
            #total_plddt += plddtloss.item()

        if epoch % 100 == 0 :
            #save model
            torch.save(encoder.state_dict(), encoder_save)
            torch.save(decoder.state_dict(), decoder_save)
        """
        for loss in [( total_loss_x , xlosses , xweight ), (total_loss_edge, edgelosses, edgeweight), ( total_vq, vqlosses, vqweight ) ]:
            loss[1].append(loss[0])
            #calculate the mean delta loss for past 10 epochs
            if len(loss[1]) > 10:
                loss[1].pop(0)
                mean_loss = np.mean(loss[0:5])
                #calculate the delta loss for the last 5 epochs
                delta_loss = np.mean(loss[1][-5:])
                delta_loss = delta_loss- mean_loss
                if delta_loss > 0:
                    loss[2] = loss[2]*2
                else:
                    loss[2] = loss[2]/1.5
                loss[2] = np.clip(loss[2], 1e-5, 1e5)
        """    
        print(f'Epoch {epoch}, AALoss: {total_loss_x:.4f}, Edge Loss: {total_loss_edge:.4f}, vq Loss: {total_vq:.4f}') #, plddt Loss: {total_plddt:.4f}')
        total_loss_x = 0
        total_loss_edge = 0
        total_vq = 0
        #total_plddt = 0
    torch.save(encoder.state_dict(), encoder_save)
    torch.save(decoder.state_dict(), decoder_save)

In [None]:
#brainstorming to get complex graphs...

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.nn import TransformerDecoder, TransformerDecoderLayer

class GraphGenerationModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers, num_classes):
        super(GraphGenerationModel, self).__init__()
        self.encoder_layer = TransformerEncoderLayer(d_model=input_dim, nhead=nhead)
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=num_encoder_layers)
        
        self.decoder_layer = TransformerDecoderLayer(d_model=output_dim, nhead=nhead)
        self.transformer_decoder = TransformerDecoder(self.decoder_layer, num_layers=num_decoder_layers)
        
        self.node_mlp = nn.Linear(output_dim, num_classes + 1)  # +1 for EOS token
        self.edge_mlp = nn.Linear(2 * output_dim, 1)
        self.embedding = nn.Embedding(num_classes + 1, output_dim)
        
    def encode(self, x):
        return self.transformer_encoder(x)
    
    def decode(self, z, memory):
        return self.transformer_decoder(z, memory)
    
    def forward(self, x, max_length=20):
        # Encode the input sequence using the transformer encoder
        memory = self.encode(x)  # Shape: [batch_size, seq_len, input_dim]
        
        # Start with an initial input (e.g., a start token)
        start_token = torch.zeros((memory.size(0), 1, memory.size(-1)), device=x.device)
        z = start_token
        
        node_logits = []
        for _ in range(max_length):
            z = self.decode(z, memory)
            node_logit = self.node_mlp(z[:, -1, :])  # Take the last token's output
            node_logits.append(node_logit)
            
            # Get the next input token (embedding of the predicted class)
            next_token = node_logit.argmax(dim=-1).unsqueeze(1)
            next_token_embedding = self.embedding(next_token)
            z = torch.cat([z, next_token_embedding], dim=1)
            
            if (next_token == num_classes).all():  # EOS token
                break
        
        node_logits = torch.stack(node_logits, dim=1)
        
        # Decode edges (fully connected example, customize as needed)
        num_nodes = node_logits.size(1)
        edge_index = torch.combinations(torch.arange(num_nodes), r=2).t().contiguous()
        edge_features = torch.cat([node_logits[:, edge_index[0]], node_logits[:, edge_index[1]]], dim=-1)
        edge_logits = self.edge_mlp(edge_features).squeeze(-1)
        edge_probs = torch.sigmoid(edge_logits)
        
        return node_logits, edge_index, edge_probs

# Example usage
input_dim = 10   # Dimension of input embeddings
hidden_dim = 32  # Dimension of hidden layer
output_dim = 16  # Dimension of node features
nhead = 2        # Number of heads in the multi-head attention
num_encoder_layers = 2  # Number of transformer encoder layers
num_decoder_layers = 2  # Number of transformer decoder layers
num_classes = 5  # Number of node categories (excluding EOS token)

model = GraphGenerationModel(input_dim, hidden_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers, num_classes)

# Example input: batch of sequences (batch_size, seq_len, input_dim)
x = torch.randn(5, 10, input_dim)  # Batch of 5 sequences, each of length 10

node_logits, edge_index, edge_probs = model(x)
print("Node logits:\n", node_logits)
print("Edge index:\n", edge_index)
print("Edge probabilities:\n", edge_probs)


In [None]:
import torch.optim as optim

def train(model, dataloader, epochs=10):
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    edge_loss_fn = nn.BCELoss()  # Loss function for edge probabilities
    node_loss_fn = nn.NLLLoss()  # Loss function for node classification

    for epoch in range(epochs):
        model.train()
        total_edge_loss = 0
        total_node_loss = 0

        for sequences, graphs, node_labels in dataloader:
            optimizer.zero_grad()
            
            node_logits, edge_index, edge_probs = model(sequences)

            # Flatten the edge probabilities and create a target tensor
            target_edge_probs = torch.cat([graph.edge_index for graph in graphs], dim=1)
            target = torch.ones(target_edge_probs.size(1))  # Example target, customize as needed

            # Compute edge loss
            edge_loss = edge_loss_fn(edge_probs, target)

            # Compute node classification loss
            node_labels = node_labels.view(-1)
            node_classes = node_logits.view(-1, num_classes + 1)  # +1 for EOS token
            node_loss = node_loss_fn(node_classes, node_labels)

            # Total loss
            loss = edge_loss + node_loss
            loss.backward()
            optimizer.step()

            total_edge_loss += edge_loss.item()
            total_node_loss += node_loss.item()

        print(f"Epoch {epoch+1}, Edge Loss: {total_edge_loss / len(dataloader)}, Node Loss: {total_node_loss / len(dataloader)}")

# Train the model
train(model, dataloader)
