<a href="https://colab.research.google.com/github/albertomariapepe/Using-a-Graph-Transformer-network-to-predict-3D-coordinates-of-proteins-via-Geometric-Algebra-model/blob/main/Graph_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the PDNET training archive and extract it into the working directory.
# NOTE(review): later cells read features from ./data/ - presumably that tree
# comes from this archive; confirm after extraction.
! wget http://deep.cs.umsl.edu/pdnet/train-data.tar.gz
! tar zxf train-data.tar.gz

In [None]:
# %pip installs into the environment of the running kernel (unlike `!pip`).
# TODO: pin the release this notebook was tested with.
%pip install graph-transformer-pytorch

In [None]:
# %pip installs into the environment of the running kernel (unlike `!pip`).
# TODO: pin the release this notebook was tested with.
%pip install pytorch-msssim

In [None]:
# %pip installs into the environment of the running kernel (unlike `!pip`).
# TODO: pin the release this notebook was tested with.
%pip install pycuda

In [None]:
# Mount Google Drive at /content/drive; later cells read the distance/cost
# maps from paths below drive/MyDrive/.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import random
from pytorch_msssim import SSIM
import gc
from graph_transformer_pytorch import GraphTransformer
import time
from google.colab import files
import pycuda.driver as cuda

In [None]:
def get_nodes_and_edges(pdb, all_feat_paths, node_n, edge_n):
    """Build the node and edge feature tensors for one protein chain.

    Parameters
    ----------
    pdb : str
        Chain identifier; used as the stem of ``<pdb>.pkl`` (features) and of
        ``<pdb>-cb.npy`` (cost / distance maps).
    all_feat_paths : sequence of str
        Directory prefixes that are concatenated directly with
        ``pdb + '.pkl'``, so each one must end with a path separator.
    node_n : int
        Size of the last axis of the node tensor. 27 channels are written
        (3 from 'ss', 22 from 'pssm', 1 from 'sa', 1 from 'entropy').
    edge_n : int
        Size of the last axis of the edge tensor. 5 channels are written
        ('ccmpred', 'freecon', 'potential', cost map, distance map).

    Returns
    -------
    nodes : torch.Tensor, shape (1, l, node_n)
    edges : torch.Tensor, shape (1, l, l, edge_n)
    mask : torch.BoolTensor, shape (1, l), all True
    l : int
        Sequence length, ``len(features['seq'])``.

    Raises
    ------
    FileNotFoundError
        If ``<pdb>.pkl`` exists under none of ``all_feat_paths`` or one of the
        two ``.npy`` maps is missing.
    AssertionError
        If a feature array does not have the expected shape.
    """
    found = [path + pdb + '.pkl' for path in all_feat_paths
             if os.path.exists(path + pdb + '.pkl')]
    if not found:
        raise FileNotFoundError(
            "no feature file '%s.pkl' found under any of %s" % (pdb, list(all_feat_paths)))

    # When several directories contain the chain, the last one takes
    # precedence (same as before), but only that file is unpickled.
    # NOTE(review): pickle.load can execute arbitrary code - only use it on
    # feature files from a trusted source.
    with open(found[-1], 'rb') as handle:
        features = pickle.load(handle)
    l = len(features['seq'])

    nodes = torch.zeros(1, l, node_n)
    edges = torch.zeros(1, l, l, edge_n)
    mask = torch.ones(1, l).bool()

    ######NODES FEATURES########

    ss = features['ss']            # secondary structure
    assert ss.shape == (3, l)
    pssm = features['pssm']
    assert pssm.shape == (l, 22)
    sa = features['sa']
    assert sa.shape == (l, )
    entropy = features['entropy']
    assert entropy.shape == (l, )

    # One length-l track per node channel, in the order ss, pssm, sa, entropy.
    node_tracks = [ss[j] for j in range(3)]
    node_tracks += [pssm[:, j] for j in range(22)]
    node_tracks += [sa, entropy]
    for fi, track in enumerate(node_tracks):
        # ascontiguousarray: pssm columns are strided views of the matrix.
        nodes[0, :, fi] = torch.from_numpy(np.ascontiguousarray(track)).to(nodes)

    ######EDGES FEATURES########

    ccmpred = features['ccmpred']
    assert ccmpred.shape == ((l, l))
    freecon = features['freecon']   # FreeContact
    assert freecon.shape == ((l, l))
    potential = features['potential']
    assert potential.shape == ((l, l))

    # Entry 2 of each stored array is used as an (l, l) map.
    # NOTE(review): the meaning of the other entries is not visible here.
    # allow_pickle=True is kept from the original; it is only safe for
    # trusted files.
    #cost = np.load('drive/MyDrive/Output-ALL/'+ pdb + '-cb.npy', allow_pickle = True)
    cost = np.load('drive/MyDrive/COSTS-predicted-realdistance/'+ pdb + '-cb.npy', allow_pickle = True)
    #distance = np.load('drive/MyDrive/DISTANCES/'+ pdb + '-cb.npy', allow_pickle = True)
    distance = np.load('drive/MyDrive/DISTANCES-predicted/'+ pdb + '-cb.npy', allow_pickle = True)

    edge_maps = [ccmpred, freecon, potential, cost[2], distance[2]]
    for gi, pair_map in enumerate(edge_maps):
        edges[0, :, :, gi] = torch.from_numpy(np.ascontiguousarray(pair_map)).to(edges)

    return nodes, edges, mask, l


In [None]:
class Net(nn.Module):
    """Bias-free linear head mapping node embeddings to output coordinates.

    Parameters
    ----------
    in_features : int, default 27
        Size of the last axis of the input.
    out_features : int, default 3
        Size of the last axis of the output.
    """

    def __init__(self, in_features = 27, out_features = 3):
        super(Net, self).__init__()
        # The attribute name `fc1` is part of the checkpoint format
        # (state_dict key 'fc1.weight'); do not rename it.
        self.fc1 = nn.Linear(in_features, out_features, bias = False)

    def forward(self, x, old = None):
        """Apply the linear projection to ``x``.

        ``old`` is accepted for compatibility with existing callers and is
        ignored.
        """
        return self.fc1(x)

In [None]:
# Widths of the per-residue (node) and per-pair (edge) feature vectors; both
# are passed to get_nodes_and_edges when the tensors are built.
node_n = 27
edge_n = 5

# Encoder, created and moved to the GPU in one expression
# (nn.Module.cuda() returns the module itself).
GT = GraphTransformer(
    dim = node_n,
    depth = 3,
    heads = 4,
    edge_dim = edge_n,          # when omitted, edges are assumed to be as wide as the nodes
    with_feedforwards = True,   # feedforward block after every attention layer
    gated_residual = True,      # gated residuals, meant to prevent over-smoothing
    rel_pos_emb = True          # the nodes are ordered
).cuda()

# Linear head applied to the encoder's node output.
projector3D = Net().cuda()

# Container used for parameters() / state_dict(); the two parts are called
# individually in the cells below.
model = nn.Sequential(GT, projector3D)
model.cuda()

In [None]:
# ---------------------------------------------------------------------------
# Training cell: fits GT + projector3D so that the pairwise distances of the
# predicted 3-D coordinates match the distance channel of the edge features.
# ---------------------------------------------------------------------------

# NOTE(review): `dir` shadows the built-in of the same name; the variable name
# is kept because it is a notebook-level global.
dir = '/content/drive/MyDrive/DISTANCES/'
# NOTE(review): os.listdir() order is not guaranteed, so the slices taken from
# `lst` below are only reproducible on the same filesystem state.
lst = os.listdir(dir)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cuda.init()
dir_dataset = './data/'
all_feat_paths = [dir_dataset + '/deepcov/features/', dir_dataset + '/psicov/features/', dir_dataset + '/cameo/features/']
all_dist_paths = [dir_dataset + '/deepcov/distance/', dir_dataset + '/psicov/distance/', dir_dataset + '/cameo/distance/']


def _pairwise_distance_map(coord):
    """Return the symmetric (1, n, n) Euclidean distance map of coord[0] (n, 3).

    Only the strict lower triangle is computed and then mirrored, so the
    diagonal is an exact, gradient-free zero.
    """
    points = coord[0]
    n_res = points.shape[0]
    rows, cols = torch.tril_indices(n_res, n_res, offset = -1, device = points.device)
    lower = torch.zeros(n_res, n_res, dtype = points.dtype, device = points.device)
    lower[rows, cols] = torch.linalg.norm(points[rows] - points[cols], dim = -1)
    return (lower + lower.T).unsqueeze(0)


def _predict_distance_map(pdb):
    """Run GT and projector3D on one chain; return (predicted map, target map)."""
    nodes, edges, mask, _ = get_nodes_and_edges(pdb, all_feat_paths, node_n, edge_n)
    nodes = nodes.to(device)
    edges = edges.to(device)
    mask = mask.to(device)

    nodes_new, _ = GT(nodes, edges, mask = mask)
    coord = projector3D(nodes_new, nodes)

    # NOTE(review): the target channel is also part of the edge features fed
    # to GT above - confirm that this is intended.
    target = edges[:, :, :, edge_len]
    return _pairwise_distance_map(coord), target


if os.path.isfile('/content/model.pt'):
    print('loading model...')
    model.load_state_dict(torch.load("/content/model.pt", map_location = device))

optimizer = optim.Adam(model.parameters(), lr=1e-2)
# The scheduler's `verbose` flag is deprecated in recent PyTorch releases; the
# learning rate is printed explicitly after every scheduler step instead.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

torch.backends.cudnn.benchmark = True

lsttrain = lst[0:200]
lstval = lst[1000:1200]
if not lsttrain:
    raise ValueError('no training files found in ' + dir)
if not lstval:
    raise ValueError('validation slice lst[1000:1200] is empty; found only %d files in %s' % (len(lst), dir))

epochs = 5
n_batches = int(len(lsttrain))
batch_size = len(lsttrain) // n_batches   # evaluates to 1 with the line above
#batch_size = 1

L = nn.L1Loss()
ssim_module = SSIM(data_range=255, size_average=True, channel=1)
counter = 0
loss = 0
j_in = 0   # batch to resume from in the first executed epoch
i_in = 0   # epoch to resume from
alpha = 20 # weight of the (1 - SSIM) term relative to the L1 term

edge_len = 4 #it is 4 when we include cost maps
for i in range(i_in, epochs):
    # One validation chain is drawn per epoch and reused for all its batches.
    filename1 = random.choice(lstval)
    val_pdb = os.path.splitext(filename1)[0][:-3]   # drop extension and '-cb'
    print('****')

    first_batch = j_in if i == i_in else 0
    for j in range(first_batch, n_batches):
        lstbatch = lsttrain[j*batch_size:(j+1)*batch_size]

        loss = 0
        val_loss = 0
        counter = 0

        for filename in lstbatch:
            pdb_id = os.path.splitext(filename)[0][:-3]
            pred_dist, Y = _predict_distance_map(pdb_id)

            loss1 = L(pred_dist, Y)
            loss2 = 1 - ssim_module(pred_dist.unsqueeze(0), Y.unsqueeze(0))
            # Accumulate over the batch (the log line divides by batch_size).
            # With batch_size == 1 this equals the single-sample loss.
            loss = loss + loss1 + alpha*loss2
            counter = counter + 1

            del pred_dist, Y

        ####################################Validation
        # torch.no_grad() only has an effect as a context manager; without it
        # the validation pass would build an autograd graph for nothing.
        with torch.no_grad():
            val_pred, val_target = _predict_distance_map(val_pdb)
            val_loss = L(val_pred, val_target) + alpha*(1 - ssim_module(val_pred.unsqueeze(0), val_target.unsqueeze(0)))
        del val_pred, val_target

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log and checkpoint after the step so the saved weights include it;
        # the printed losses are the ones computed before the step.
        if i % 1 == 0 and j % 10 == 0:
            print('epoch: %d,  batch: %d,  total loss: %.3f,  val loss: %.3f' % ( i, j, loss.item()/batch_size, val_loss.item()))
            print('MAE: %.3f,  SSIM: %.3f' % ( loss1.item(), (1-loss2).item()))
            torch.save(model.state_dict(), "/content/model.pt")
            print('model saved!')

        # Chains differ in length, so release cached blocks between batches.
        gc.collect()
        torch.cuda.empty_cache()

    scheduler.step()
    print('learning rate: %s' % scheduler.get_last_lr())

    counter = 0

   

In [None]:
# ---------------------------------------------------------------------------
# Evaluation cell: runs the checkpointed model on a slice of the file list and
# collects, per chain, the L1 error (MAE) and the SSIM between the predicted
# and the reference distance map.
# ---------------------------------------------------------------------------
from matplotlib.pyplot import imshow
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import structural_similarity as ssim

dir = '/content/drive/MyDrive/DISTANCES/'
lst = os.listdir(dir)
edge_len = 4 #it is 4 when we include cost maps

dir_dataset = './data/'

# A fresh projector is created here; its weights only become meaningful once
# the checkpoint below has been loaded.
projector3D = Net()
GT.cuda()
projector3D.cuda()
model = nn.Sequential(GT, projector3D)
model.cuda()

L = nn.L1Loss()

if os.path.isfile('/content/model.pt'):
    print('loading model...')
    model.load_state_dict(torch.load("/content/model.pt"))
else:
    print('WARNING: /content/model.pt not found - projector3D is randomly initialised')


all_feat_paths = [dir_dataset + '/deepcov/features/', dir_dataset + '/psicov/features/', dir_dataset + '/cameo/features/']
all_dist_paths = [dir_dataset + '/deepcov/distance/', dir_dataset + '/psicov/distance/', dir_dataset + '/cameo/distance/']

total_loss = []
coordinates_array = []
protein_length = []
similarity = []

counter = 0
lsttest = lst[3000:3150]

# Chains with an MAE below the threshold are plotted and their coordinates
# kept, for at most MAX_PLOTTED chains.
MAE_PLOT_THRESHOLD = 3.7
MAX_PLOTTED = 8

for filename in lsttest:
    filename = os.path.splitext(filename)[0]
    filename = filename[:-3]   # drop the trailing '-cb'

    nodes, edges, mask, l = get_nodes_and_edges(filename, all_feat_paths, node_n, edge_n)

    nodes = nodes.cuda()
    edges = edges.cuda()
    mask = mask.cuda()

    # Reference map: read from the input features before the model is called.
    or_dist = edges[:,:,:,edge_len]

    with torch.no_grad():   # inference only, no autograd graph needed
        # The second return value of GT is discarded so that the reference
        # above cannot be replaced by a model output.
        nodes_new, _ = GT(nodes, edges, mask = mask)
        coord = projector3D(nodes_new, nodes)

        # Pairwise distances: fill the strict lower triangle in one shot and
        # mirror it; the diagonal stays zero.
        points = coord[0]
        n_res = points.shape[0]
        rows, cols = torch.tril_indices(n_res, n_res, offset = -1, device = points.device)
        lower = torch.zeros(n_res, n_res, dtype = points.dtype, device = points.device)
        lower[rows, cols] = torch.linalg.norm(points[rows] - points[cols], dim = -1)
        pred_dist = (lower + lower.T).unsqueeze(0)

        loss = L(pred_dist, or_dist)

    # SSIM is computed on the full 2-D maps (indexing with [0][0] would pick
    # only the first row of each map).
    pred_map = pred_dist[0].cpu().numpy()
    target_map = or_dist[0].cpu().numpy()
    ssimerror = ssim(pred_map, target_map, data_range=target_map.max() - target_map.min())

    # Tensor.item() replaces np.asscalar, which no longer exists in NumPy.
    a = loss.item()

    total_loss = np.append(total_loss, a)

    print(a)

    if a < MAE_PLOT_THRESHOLD and counter < MAX_PLOTTED:
        print("MAE:  ", a)
        print("SSIM: ", ssimerror)
        print(filename)
        print('****')

        plt.figure()
        plt.title('predicted distance map: ' + filename)
        imshow(pred_map, cmap = "plasma")

        plt.figure()
        plt.title('reference distance map: ' + filename)
        imshow(target_map, cmap = "plasma")

        counter = counter + 1

        # np.append flattens: coordinates of all kept chains end up in one
        # 1-D array, with protein_length holding the length of each chain.
        coordinates_array = np.append(coordinates_array, coord.cpu().numpy())
        protein_length = np.append(protein_length, int(coord.shape[1]))

    similarity = np.append(similarity, ssimerror)

    del nodes, edges, mask, nodes_new, coord, points, lower, or_dist, pred_dist
    gc.collect()

print(similarity)


In [None]:
# Write the arrays collected by the evaluation cell to .npy files in the
# current working directory.
for npy_name, values in (
    ('coordinates.npy', coordinates_array),
    ('length.npy', protein_length),
    ('loss.npy', total_loss),
    ('similarity.npy', similarity),
):
    np.save(npy_name, values)

In [None]:
# ---------------------------------------------------------------------------
# Single-chain inspection: predict one distance map, report MAE / SSIM, plot
# predicted and reference maps and save the predicted coordinates.
# ---------------------------------------------------------------------------

# File to inspect. Set PROTEIN_OVERRIDE to None to use the test chain with the
# lowest MAE from the evaluation cell instead of the fixed file.
PROTEIN_OVERRIDE = '1zv1A-cb.npy'
if PROTEIN_OVERRIDE is None:
    index = np.argmin(total_loss)
    #index = np.argmax(similarity)
    lsttest = lst[3000:3150][index]
else:
    lsttest = PROTEIN_OVERRIDE

filename = os.path.splitext(lsttest)[0]
filename = filename[:-3]   # drop the trailing '-cb'

print(filename)

nodes, edges, mask, l = get_nodes_and_edges(filename, all_feat_paths, node_n, edge_n)

nodes = nodes.cuda()
edges = edges.cuda()
mask = mask.cuda()

# Reference map: read from the input features before the model is called.
or_dist = edges[:,:,:,edge_len]

with torch.no_grad():   # inference only, no autograd graph needed
    nodes_new, _ = GT(nodes, edges, mask = mask)
    coord = projector3D(nodes_new, nodes)

    # Pairwise distances: fill the strict lower triangle in one shot and
    # mirror it; the diagonal stays zero.
    points = coord[0]
    n_res = points.shape[0]
    rows, cols = torch.tril_indices(n_res, n_res, offset = -1, device = points.device)
    lower = torch.zeros(n_res, n_res, dtype = points.dtype, device = points.device)
    lower[rows, cols] = torch.linalg.norm(points[rows] - points[cols], dim = -1)
    pred_dist = (lower + lower.T).unsqueeze(0)

    loss = L(pred_dist, or_dist)

# Tensor.item() replaces np.asscalar, which no longer exists in NumPy.
a = loss.item()

# SSIM is computed on the full 2-D maps (indexing with [0][0] would pick only
# the first row of each map).
pred_map = pred_dist[0].cpu().numpy()
target_map = or_dist[0].cpu().numpy()
ssimerror = ssim(pred_map, target_map, data_range=target_map.max() - target_map.min())

print(a)
print(ssimerror)
print()

plt.figure()
plt.title('predicted distance map: ' + filename)
imshow(pred_map)

plt.figure()
plt.title('reference distance map: ' + filename)
imshow(target_map)

np.save('coordinates' + str(filename) + '_nocost.npy', coord.cpu().detach().numpy())

In [None]:
# Max, mean, min and standard deviation (in this order) of the per-chain L1
# losses collected in total_loss.
for summarise in (np.max, np.mean, np.min, np.std):
    print(summarise(total_loss))

In [None]:
# Max, mean, min and standard deviation (in this order) of the per-chain SSIM
# values collected in similarity.
for summarise in (np.max, np.mean, np.min, np.std):
    print(summarise(similarity))

In [None]:
# Number of trainable scalars in the combined model (GT + projector3D).
pytorch_total_params = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
pytorch_total_params

In [None]:
# Number of scalars in the graph transformer alone (trainable or not).
pytorch_total_params = sum(
    weight.numel()
    for weight in GT.parameters()
)
pytorch_total_params

In [None]:
# Number of scalars in the linear projection head alone (trainable or not).
pytorch_total_params = sum(
    weight.numel()
    for weight in projector3D.parameters()
)
pytorch_total_params