# Getting Amino Acids and Small Molecules to "Speak the Same Language"
![EmbeddingStrategy](images/EmbeddingStrategy.png)<br>
## Overview
The first challenge we faced was embedding amino acids and small molecules in the same latent space. The solution we came up with was to represent each atom in a small molecule as a graph node and treat amino acids as a sort of "psuedo-atom". <br>
We used an encoder-decoder architechture that took as an input an amino acid represnted as a graph and output the amino acids associated [BLOSUM62](https://en.wikipedia.org/wiki/BLOSUM) matrix row. The graph contained an additional "master" node connected to all other nodes that facilitated the flow of information and served as a convenient read-out. The encoder consisted of three rounds of message passing. At this point, the encoded information is stored in the master node. The decoder is simply a single densley connected layer.<br>
Each node in the graph represents an atom in the amino acid. The features of each atom came from [Fang et. al](https://www.nature.com/articles/s42256-023-00654-0). These are the same features we use for the sequence prediction network. Our idea was that by briding the gap between atomic and amino acid level information, we would be able to treat amino acids and atoms on equal footing. This is the weakest assumption in the construction of our model and needs furth investigation to determine whether this stradegy is sound

In [1]:
import numpy as np
import pandas as pd
import random
import torch
import math
from Bio import SeqIO
import Bio.PDB
import pickle as pickle
import torch.nn as nn
from torch_geometric.nn import GENConv
from torch_geometric.nn.models import MLP
from torch_geometric.data import Data
from torch_geometric.nn.aggr import MeanAggregation
import matplotlib.pyplot as plt
import os
from Bio import PDB
from rdkit import Chem
import blosum as bl

**Replace File Paths**

In [2]:
class CFG:
    pdbfiles: str = "/home/paul/Desktop/BioHack-Project-Walkthrough/pdbind-refined-set/"
    AA_mol2_files: str = "/home/paul/Desktop/BioHack-Project-Walkthrough/AA_mol2/"

## Graph Construction
This will be discussed in more detail in future notebooks. The graphs were constructed in the same way that the small molecule graphs will be constructed in the future.

In [36]:
# Load atomic embeddings from Fang et al.
with open('atom2emb.pkl', 'rb') as f:
    atom2emb = pickle.load(f)

# Load bond dictionary from earlier notebook 
with open('bond_type_dict.pkl', 'rb') as f:
    bond_type_dict = pickle.load(f)
    
#Self-explanatory dictionary - if only people could use the amino acid codes
one_letter_to_three_letter_dict = {'G':'gly',
                                   'A':'ala',
                                   'V':'val',
                                   'C':'cys',
                                   'P':'pro',
                                   'L':'leu',
                                   'I':'ile',
                                   'M':'met',
                                   'W':'trp',
                                   'F':'phe',
                                   'K':'lys',
                                   'R':'arg',
                                   'H':'his',
                                   'S':'ser',
                                   'T':'thr',
                                   'Y':'tyr',
                                   'N':'asn',
                                   'Q':'gln',
                                   'D':'asp',
                                   'E':'glu'
    
}

#Self-explanatory dictionary - if only people could use the amino acid codes
upper2lower = {
    "ala": "ALA",
    "arg": "ARG",
    "asn": "ASN",
    "asp": "ASP",
    "cys": "CYS",
    "gln": "GLN",
    "glu": "GLU",
    "gly": "GLY",
    "his": "HIS",
    "ile": "ILE",
    "leu": "LEU",
    "lys": "LYS",
    "met": "MET",
    "phe": "PHE",
    "pro": "PRO",
    "ser": "SER",
    "thr": "THR",
    "trp": "TRP",
    "tyr": "TYR",
    "val": "VAL",
}

def BLOSUM_encode_single(seq,AA_dict):
    allowed = set("gavcplimwfkrhstynqdeuogavcplimwfkrhstynqde")
    if not set(seq).issubset(allowed):
        invalid = set(seq) - allowed
        raise ValueError(f"Sequence has broken AA: {invalid}")
    vec = AA_dict[seq]
    return vec

matrix = bl.BLOSUM(62)
allowed_AA = "GAVCPLIMWFKRHSTYNQDE"
BLOSUM_dict_three_letter = {}
for i in allowed_AA:
    vec = []
    for j in allowed_AA:
        vec.append(matrix[i][j])
    BLOSUM_dict_three_letter.update({one_letter_to_three_letter_dict[i]:torch.Tensor(vec)})

    
def read_mol2_bonds_and_atoms(mol2_file):
    bonds = []
    bond_types = []
    atom_types = {}
    atom_coordinates = {}

    with open(mol2_file, 'r') as mol2:
        reading_bonds = False
        reading_atoms = False
        for line in mol2:
            if line.strip() == '@<TRIPOS>BOND':
                reading_bonds = True
                continue
            elif line.strip() == '@<TRIPOS>ATOM':
                reading_atoms = True
                continue
            elif line.strip().startswith('@<TRIPOS>SUBSTRUCTURE'):
                break
            elif reading_bonds and line.strip().startswith('@<TRIPOS>'):
                reading_bonds = False
            elif reading_atoms and line.strip().startswith('@<TRIPOS>'):
                reading_atoms = False


            if reading_bonds:
                parts = line.split()
                if len(parts) >= 4:
                    atom1_index = int(parts[1])
                    atom2_index = int(parts[2])
                    bond_type = parts[3]
                    bonds.append((atom1_index, atom2_index))
                    bond_types.append(bond_type)

            if reading_atoms:
                parts = line.split()
                if len(parts) >= 6:
                    atom_index = int(parts[0])
                    atom_type = parts[5]
                    x, y, z = float(parts[2]), float(parts[3]), float(parts[4])
                    atom_types[atom_index] = atom_type.split('.')[0]
                    atom_coordinates[atom_index] = (x, y, z)

    return bonds, bond_types, atom_types, atom_coordinates  

def molecule2graph_AA(filename,map_distance, norm_map_distance = 12.0):
    node_feature = []
    edge_index = []
    edge_attr = []
    mol2_file = CFG.AA_mol2_files+filename
    bonds, bond_types, atom_types, atom_coordinates = read_mol2_bonds_and_atoms(mol2_file)
    for atom in atom_types:
        node_feature.append(torch.Tensor(atom2emb[atom_types[atom]]))
    

    for atom1 in range(1, len(atom_types)+1):
        for atom2 in range(atom1 + 1, len(atom_types)+1):
            bonded_flag = 0
            for i, bond in enumerate(bonds):
                if (atom1 in bond) and (atom2 in bond):
                    edge_index.append([bond[0] - 1,bond[1] - 1])
                    coord1 = np.array(atom_coordinates[bond[0]])
                    coord2 = np.array(atom_coordinates[bond[1]])
                    dist = math.dist(coord1, coord2)
                    d = []
                    for l in range(12):
                        d.append(np.exp((-1.0*(dist - 2.0*(l + 0.5))**2.0)/norm_map_distance))
                    bond_type = bond_type_dict[bond_types[i]]
                    edge_attr.append(np.hstack((d,d,d,d,d,d,d,d,d,bond_type)))
                    bonded_flag = 1
                
            if bonded_flag == 0:
                coord1 = np.array(atom_coordinates[atom1])
                coord2 = np.array(atom_coordinates[atom2])
                dist = math.dist(coord1, coord2)
                if dist < map_distance:
                    edge_index.append([atom1 - 1,atom2 - 1])
                    d = []
                    for l in range(12):
                        d.append(np.exp((-1.0*(dist - 2.0*(l + 0.5))**2.0)/norm_map_distance))
                    bond_type = bond_type_dict['nc']
                    edge_attr.append(np.hstack((d,d,d,d,d,d,d,d,d,bond_type)))

    
    edge_index = np.array(edge_index)
    edge_index = edge_index.transpose()
    edge_index = torch.Tensor(edge_index)
    edge_index = edge_index.to(torch.int64)
    edge_attr = torch.Tensor(np.array(edge_attr))
    node_feature = torch.stack(node_feature)
    
    #Master_node
    new_edge_index = []
    new_edge_attr = []
    node_feature = torch.cat((node_feature,torch.zeros(len(atom2emb['N'])).unsqueeze(0)),dim = 0)
    
    for i in range(len(node_feature) - 1):
        new_edge_index.append([i,int(len(node_feature)-1)])
        bond_type = bond_type_dict['nc']
        new_edge_attr.append(np.hstack((np.zeros(9*len(d)),bond_type)))
    
    new_edge_index = np.array(new_edge_index)
    new_edge_index = new_edge_index.transpose()
    new_edge_index = torch.Tensor(new_edge_index)
    new_edge_index = new_edge_index.to(torch.int64)
    new_edge_attr = torch.Tensor(np.array(new_edge_attr))    
    
    edge_index = torch.cat((edge_index,new_edge_index), dim = 1)
    edge_attr = torch.cat((edge_attr,new_edge_attr), dim = 0)
    
    graph = Data(x = node_feature, edge_index = edge_index,edge_attr = edge_attr)#, pos = new_mol_coords)
    graph.label = filename.split('.')[0]
    softmax = nn.Softmax(dim = 0)
    graph.y = softmax(BLOSUM_encode_single(graph.label,BLOSUM_dict_three_letter))
    return graph

In [48]:
AA_graphs = []
for filename in os.listdir(CFG.AA_mol2_files):
    AA_graphs.append(molecule2graph_AA(filename,12.0))

Read about the message passing module used in the encoder [here](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GENConv.html#torch_geometric.nn.conv.GENConv)

In [28]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.node_feature_size = 133
        self.node_feature_hidden_size = 133
        self.node_feature_size_out = 20
        self.conv1 = GENConv(self.node_feature_size,self.node_feature_hidden_size,aggr = 'mean',edge_dim = 114, num_layer = 2,norm = 'layer')
        self.conv2 = GENConv(self.node_feature_hidden_size,self.node_feature_hidden_size,aggr = 'mean',edge_dim = 114,num_layer = 2,norm = 'layer')
        self.conv3 = GENConv(self.node_feature_hidden_size,self.node_feature_hidden_size,aggr = 'mean',edge_dim = 114,num_layer = 2,norm = 'layer')
        self.linear1 = nn.Linear(self.node_feature_hidden_size,self.node_feature_size_out)
        self.ReLu = nn.ReLU()
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self,graph):
        x, edge_index, edge_attr = graph.x, graph.edge_index, graph.edge_attr
        x1 = self.conv1(x, edge_index, edge_attr)
        x1 = self.ReLu(x1)
        x1 = self.conv2(x1, edge_index, edge_attr)
        x1 = self.ReLu(x1)
        x1 = self.conv3(x1, edge_index, edge_attr)
        x1 = x1[-1]
        x1 = torch.tanh(x1)
        x1 = self.linear1(x1)
        return x1
    
    def encode(self,graph):
        x, edge_index, edge_attr = graph.x,graph.edge_index,graph.edge_attr
        x1 = self.conv1(x, edge_index,edge_attr)
        x1 = self.ReLu(x1)
        x1 = self.conv2(x1, edge_index,edge_attr)
        x1 = self.ReLu(x1)
        x1 = self.conv3(x1, edge_index,edge_attr)
        x1 = x1[-1]
        x1 = torch.tanh(x1)
        return x1
    
    def decode(self,encoding):
        x1 = self.linear1(encoding)
        return x1

In [29]:
from torch_geometric.loader import DataLoader
train_dl = DataLoader(AA_graphs,batch_size = 1, shuffle = True)

## Training Loop
The model was trained for 1000 epochs to optimize the mean squared error loss. The optimization was done using the adam optimizer with a learning rate of 1e-4 and L2 normaliztion of 1e-4. The learning rate was dropped one order of magnitude after 500 epochs and again after 800 epochs. No hyperparameter tuning was done. The model that achieved the lowest loss is saved to generate the amino acid embeddings

In [30]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(12345)
model = Net()
model.to(DEVICE)

criterion = loss_func = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[500,800], gamma=0.1)

# Training loop
num_epochs = 1000
losses = []
lowest = 0.01

for epoch in range(num_epochs):
    total_loss = 0.0
    val_loss = 0.0
    
        
    for batch in train_dl:
        model.train()
        inputs = batch[0].to(DEVICE)

        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        
        # Compute the loss
        loss = criterion(outputs, inputs.y)
        
        # Backpropagation and optimization
        loss.backward()
        optimizer.step()
        
        inputs= inputs.to('cpu')
        
        total_loss += loss.item()
            
    for batch in train_dl:
        with torch.no_grad():
            model.eval()
            inputs = batch[0].to(DEVICE)
        
            # Forward pass
            outputs = model(inputs)
        
            # Compute the loss
            loss = criterion(outputs, inputs.y)
        
            inputs= inputs.to('cpu')
            
            val_loss += loss.item()
            
    # Print the average loss for this epoch
    avg_loss = total_loss / len(train_dl)
    val_avg_loss = val_loss / len(train_dl)
    
    if lowest > val_avg_loss:
        torch.save(model.state_dict(), 'AA_encoder.pt')
        lowest = val_avg_loss
    
    print(f'Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f} Val Loss: {val_avg_loss:.4f}')
    
print('Training complete')


Epoch [1/1000] Loss: 0.0530 Val Loss: 0.0405
Epoch [2/1000] Loss: 0.0430 Val Loss: 0.0382
Epoch [3/1000] Loss: 0.0417 Val Loss: 0.0389
Epoch [4/1000] Loss: 0.0416 Val Loss: 0.0390
Epoch [5/1000] Loss: 0.0426 Val Loss: 0.0377
Epoch [6/1000] Loss: 0.0401 Val Loss: 0.0380
Epoch [7/1000] Loss: 0.0408 Val Loss: 0.0383
Epoch [8/1000] Loss: 0.0402 Val Loss: 0.0376
Epoch [9/1000] Loss: 0.0395 Val Loss: 0.0368
Epoch [10/1000] Loss: 0.0396 Val Loss: 0.0372
Epoch [11/1000] Loss: 0.0405 Val Loss: 0.0366
Epoch [12/1000] Loss: 0.0391 Val Loss: 0.0367
Epoch [13/1000] Loss: 0.0394 Val Loss: 0.0377
Epoch [14/1000] Loss: 0.0397 Val Loss: 0.0361
Epoch [15/1000] Loss: 0.0387 Val Loss: 0.0366
Epoch [16/1000] Loss: 0.0387 Val Loss: 0.0355
Epoch [17/1000] Loss: 0.0381 Val Loss: 0.0355
Epoch [18/1000] Loss: 0.0378 Val Loss: 0.0350
Epoch [19/1000] Loss: 0.0374 Val Loss: 0.0348
Epoch [20/1000] Loss: 0.0369 Val Loss: 0.0339
Epoch [21/1000] Loss: 0.0370 Val Loss: 0.0351
Epoch [22/1000] Loss: 0.0360 Val Loss: 0.03

Epoch [180/1000] Loss: 0.0012 Val Loss: 0.0010
Epoch [181/1000] Loss: 0.0012 Val Loss: 0.0009
Epoch [182/1000] Loss: 0.0012 Val Loss: 0.0008
Epoch [183/1000] Loss: 0.0010 Val Loss: 0.0010
Epoch [184/1000] Loss: 0.0011 Val Loss: 0.0010
Epoch [185/1000] Loss: 0.0013 Val Loss: 0.0009
Epoch [186/1000] Loss: 0.0013 Val Loss: 0.0008
Epoch [187/1000] Loss: 0.0012 Val Loss: 0.0010
Epoch [188/1000] Loss: 0.0013 Val Loss: 0.0014
Epoch [189/1000] Loss: 0.0011 Val Loss: 0.0012
Epoch [190/1000] Loss: 0.0011 Val Loss: 0.0013
Epoch [191/1000] Loss: 0.0010 Val Loss: 0.0012
Epoch [192/1000] Loss: 0.0010 Val Loss: 0.0009
Epoch [193/1000] Loss: 0.0012 Val Loss: 0.0008
Epoch [194/1000] Loss: 0.0009 Val Loss: 0.0008
Epoch [195/1000] Loss: 0.0009 Val Loss: 0.0007
Epoch [196/1000] Loss: 0.0008 Val Loss: 0.0007
Epoch [197/1000] Loss: 0.0009 Val Loss: 0.0013
Epoch [198/1000] Loss: 0.0009 Val Loss: 0.0010
Epoch [199/1000] Loss: 0.0013 Val Loss: 0.0008
Epoch [200/1000] Loss: 0.0010 Val Loss: 0.0019
Epoch [201/10

Epoch [357/1000] Loss: 0.0004 Val Loss: 0.0003
Epoch [358/1000] Loss: 0.0005 Val Loss: 0.0009
Epoch [359/1000] Loss: 0.0010 Val Loss: 0.0015
Epoch [360/1000] Loss: 0.0024 Val Loss: 0.0017
Epoch [361/1000] Loss: 0.0011 Val Loss: 0.0009
Epoch [362/1000] Loss: 0.0008 Val Loss: 0.0006
Epoch [363/1000] Loss: 0.0007 Val Loss: 0.0004
Epoch [364/1000] Loss: 0.0004 Val Loss: 0.0003
Epoch [365/1000] Loss: 0.0003 Val Loss: 0.0002
Epoch [366/1000] Loss: 0.0002 Val Loss: 0.0002
Epoch [367/1000] Loss: 0.0002 Val Loss: 0.0002
Epoch [368/1000] Loss: 0.0002 Val Loss: 0.0001
Epoch [369/1000] Loss: 0.0002 Val Loss: 0.0001
Epoch [370/1000] Loss: 0.0002 Val Loss: 0.0001
Epoch [371/1000] Loss: 0.0001 Val Loss: 0.0002
Epoch [372/1000] Loss: 0.0002 Val Loss: 0.0001
Epoch [373/1000] Loss: 0.0002 Val Loss: 0.0001
Epoch [374/1000] Loss: 0.0001 Val Loss: 0.0001
Epoch [375/1000] Loss: 0.0001 Val Loss: 0.0001
Epoch [376/1000] Loss: 0.0001 Val Loss: 0.0001
Epoch [377/1000] Loss: 0.0001 Val Loss: 0.0001
Epoch [378/10

Epoch [533/1000] Loss: 0.0002 Val Loss: 0.0003
Epoch [534/1000] Loss: 0.0009 Val Loss: 0.0031
Epoch [535/1000] Loss: 0.0091 Val Loss: 0.0176
Epoch [536/1000] Loss: 0.0234 Val Loss: 0.0186
Epoch [537/1000] Loss: 0.0169 Val Loss: 0.0113
Epoch [538/1000] Loss: 0.0080 Val Loss: 0.0055
Epoch [539/1000] Loss: 0.0061 Val Loss: 0.0076
Epoch [540/1000] Loss: 0.0142 Val Loss: 0.0086
Epoch [541/1000] Loss: 0.0087 Val Loss: 0.0046
Epoch [542/1000] Loss: 0.0056 Val Loss: 0.0061
Epoch [543/1000] Loss: 0.0041 Val Loss: 0.0021
Epoch [544/1000] Loss: 0.0017 Val Loss: 0.0009
Epoch [545/1000] Loss: 0.0009 Val Loss: 0.0005
Epoch [546/1000] Loss: 0.0005 Val Loss: 0.0003
Epoch [547/1000] Loss: 0.0004 Val Loss: 0.0005
Epoch [548/1000] Loss: 0.0003 Val Loss: 0.0002
Epoch [549/1000] Loss: 0.0002 Val Loss: 0.0001
Epoch [550/1000] Loss: 0.0002 Val Loss: 0.0001
Epoch [551/1000] Loss: 0.0001 Val Loss: 0.0001
Epoch [552/1000] Loss: 0.0001 Val Loss: 0.0001
Epoch [553/1000] Loss: 0.0001 Val Loss: 0.0000
Epoch [554/10

Epoch [710/1000] Loss: 0.0146 Val Loss: 0.0054
Epoch [711/1000] Loss: 0.0055 Val Loss: 0.0050
Epoch [712/1000] Loss: 0.0048 Val Loss: 0.0066
Epoch [713/1000] Loss: 0.0044 Val Loss: 0.0044
Epoch [714/1000] Loss: 0.0035 Val Loss: 0.0015
Epoch [715/1000] Loss: 0.0015 Val Loss: 0.0007
Epoch [716/1000] Loss: 0.0010 Val Loss: 0.0004
Epoch [717/1000] Loss: 0.0004 Val Loss: 0.0005
Epoch [718/1000] Loss: 0.0003 Val Loss: 0.0002
Epoch [719/1000] Loss: 0.0002 Val Loss: 0.0001
Epoch [720/1000] Loss: 0.0001 Val Loss: 0.0001
Epoch [721/1000] Loss: 0.0001 Val Loss: 0.0001
Epoch [722/1000] Loss: 0.0001 Val Loss: 0.0001
Epoch [723/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [724/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [725/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [726/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [727/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [728/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [729/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [730/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [731/10

Epoch [887/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [888/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [889/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [890/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [891/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [892/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [893/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [894/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [895/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [896/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [897/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [898/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [899/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [900/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [901/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [902/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [903/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [904/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [905/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [906/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [907/1000] Loss: 0.0000 Val Loss: 0.0000
Epoch [908/10

In [41]:
model.load_state_dict(torch.load('AA_encoder.pt'))
with torch.no_grad():
    AA_embeddings = {}
    for graph in AA_graphs:
        pred = model.encode(graph.to(DEVICE))
        AA_embeddings.update({upper2lower[graph.label]:pred})

In [43]:
with open('AA_embeddings.pkl', 'wb') as f:
    pickle.dump(AA_embeddings, f)