# Cloning into CGENNs repo

In [None]:
! git clone https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks.git

In [None]:
cd /content/clifford-group-equivariant-neural-networks

In [None]:
import torch
from torch import nn
torch.set_default_dtype(torch.float64)

import matplotlib.pyplot as plt

from models.modules.fcgp import FullyConnectedSteerableGeometricProductLayer
from models.modules.gp import SteerableGeometricProductLayer
from models.modules.linear import MVLinear
from models.modules.mvsilu import MVSiLU
from models.modules.mvlayernorm import MVLayerNorm
from models.modules.normalization import NormalizationLayer

from algebra.cliffordalgebra import CliffordAlgebra

In [None]:
# Shared Clifford algebra instance, read as a module-level global by `Net`.
# The tuple is passed to CliffordAlgebra unchanged (presumably the metric
# signature of a 3D Euclidean algebra — TODO confirm against the CGENN repo).
algebra = CliffordAlgebra((1., 1., 1.))

In [None]:
cd ..

# Fetching data

In [None]:
! wget http://deep.cs.umsl.edu/pdnet/train-data.tar.gz
! tar zxf train-data.tar.gz

In [None]:
!pip install graph-transformer-pytorch

In [None]:
!pip install pytorch-msssim

In [None]:
!pip install pycuda

In [None]:
# Mount Google Drive: the cells below read precomputed arrays from
# drive/MyDrive/... and write checkpoints there.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import random
from pytorch_msssim import SSIM
import gc
from graph_transformer_pytorch import GraphTransformer
import time
from google.colab import files

In [None]:
# Experiment variant. It selects which extra pairwise maps
# get_nodes_and_edges loads and therefore the edge feature width (edge_n)
# and the index of the distance channel (edge_len) chosen further down.
version = 1

In [None]:
# Extra pairwise-map directories (under drive/MyDrive/) loaded for each
# experiment `version`, in channel order. Versions not listed (e.g. 0) add no
# extra channel.
_EXTRA_EDGE_DIRS = {
    1: ('Output-ALL',),
    2: ('OrientedPoints',),
    3: ('OrientedPoints', 'Output-ALL'),
    4: ('Phi', 'Psi', 'Omega'),
    5: ('Output-ALL', 'Cost_CB'),
    6: ('OrientedPoints', 'OrientedPoints_CB'),
    7: ('OrientedPoints', 'OrientedPoints_CB', 'OrientedPoints_C'),
    8: ('OrientedPoints', 'OrientedPoints_CB', 'OrientedPoints_N'),
    9: ('Output-ALL', 'OrientedPoints', 'OrientedPoints_CB'),
    10: ('Output-ALL', 'OrientedPoints_CB', 'OrientedPoints_N'),
    11: ('Cost_phi', 'Cost_psi', 'Cost_omega'),
}

# Versions whose extra maps are passed through np.nan_to_num before use.
_NAN_TO_NUM_VERSIONS = frozenset({4})


def get_nodes_and_edges(pdb, all_feat_paths, node_n, edge_n):
    """Build the node/edge feature tensors of one protein.

    Args:
        pdb: Protein identifier; used as the stem of the feature pickle and of
            the ``<pdb>-cb.npy`` arrays.
        all_feat_paths: Directory prefixes (each ending with a separator)
            searched for ``<pdb>.pkl``. If several contain the file, the last
            one wins.
        node_n: Width of the node feature axis (27 channels are written).
        edge_n: Width of the edge feature axis.

    Returns:
        ``(nodes, edges, mask, l)`` where ``nodes`` is ``(1, l, node_n)``,
        ``edges`` is ``(1, l, l, edge_n)``, ``mask`` is an all-True boolean
        ``(1, l)`` tensor and ``l`` is the sequence length.

    Raises:
        FileNotFoundError: If no directory in ``all_feat_paths`` holds
            ``<pdb>.pkl``.

    The extra edge channels depend on the module-level ``version``. The LAST
    channel written is always slice ``[2]`` of the DISTANCES array.

    NOTE(review): the distance map is therefore part of the edge tensor that
    the training loop feeds to the network and is also used there as the
    regression target — confirm this is intended.
    """
    feature_file = None
    for path in all_feat_paths:
        candidate = path + pdb + '.pkl'
        if os.path.exists(candidate):
            feature_file = candidate
    if feature_file is None:
        raise FileNotFoundError(
            "no feature file '" + pdb + ".pkl' found in " + str(list(all_feat_paths)))

    # NOTE: pickle (and np.load(allow_pickle=True) below) can execute arbitrary
    # code; only use with trusted dataset files.
    with open(feature_file, 'rb') as handle:
        features = pickle.load(handle)

    l = len(features['seq'])

    nodes = torch.zeros(1, l, node_n)
    edges = torch.zeros(1, l, l, edge_n)
    mask = torch.ones(1, l).bool()

    ######NODES FEATURES########

    ss = features['ss']          # secondary structure
    assert ss.shape == (3, l)
    pssm = features['pssm']      # PSSM
    assert pssm.shape == (l, 22)
    sa = features['sa']          # SA
    assert sa.shape == (l, )
    entropy = features['entropy']
    assert entropy.shape == (l, )

    # One length-l profile per node channel: 3 (ss) + 22 (pssm) + sa + entropy.
    node_columns = [ss[j] for j in range(3)]
    node_columns += [pssm[:, j] for j in range(22)]
    node_columns += [sa, entropy]
    for fi, column in enumerate(node_columns):
        # ascontiguousarray: pssm columns are strided views of the (l, 22) array.
        nodes[:, :, fi] = torch.from_numpy(np.ascontiguousarray(column)).to(nodes)

    ######EDGES FEATURES########

    edge_maps = []
    for key in ('ccmpred', 'freecon', 'potential'):
        pairwise = features[key]
        assert pairwise.shape == ((l, l))
        edge_maps.append(pairwise)

    # Version-specific extra maps; slice [2] of each stored array is used.
    for directory in _EXTRA_EDGE_DIRS.get(version, ()):
        cost = np.load('drive/MyDrive/' + directory + '/' + pdb + '-cb.npy', allow_pickle = True)
        extra = cost[2]
        if version in _NAN_TO_NUM_VERSIONS:
            extra = np.nan_to_num(extra)
        edge_maps.append(extra)

    # The distance map always occupies the last channel written.
    distance = np.load('drive/MyDrive/DISTANCES/'+ pdb + '-cb.npy', allow_pickle = True)
    edge_maps.append(distance[2])

    for gi, pairwise in enumerate(edge_maps):
        edges[:, :, :, gi] = torch.from_numpy(pairwise).to(edges)

    return nodes, edges, mask, l


# Defining the GrT + 3D Projector architecture

In [None]:
class Net(nn.Module):
    """Projector from 27-dim node embeddings to 3D coordinates.

    Each 27-vector is viewed as 9 vectors of 3 components, embedded as
    grade-1 multivectors of the module-level ``algebra``, mixed from 9 down
    to 3 channels by an ``MVLinear`` layer, and the grade-1 parts of the
    result are returned as points.

    ``fc1``, ``cgenn2``, ``norm1`` and ``act1`` are constructed but not used
    by ``forward``. They are kept, in the same registration order, because
    they are part of the state dict that saved checkpoints contain.
    (``FullyConnectedSteerableGeometricProductLayer`` was previously tried in
    place of ``MVLinear`` for ``cgenn1``/``cgenn2``.)
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(27, 3, bias=False)

        self.cgenn1 = MVLinear(algebra=algebra, in_features=9, out_features=3)
        self.cgenn2 = MVLinear(algebra=algebra, in_features=3, out_features=1)

        self.norm1 = MVLayerNorm(algebra, 1)
        self.act1 = MVSiLU(algebra, 1)

    def forward(self, x, old):
        """Map node embeddings ``x`` to a ``(1, -1, 3)`` coordinate tensor.

        ``old`` is accepted for call compatibility and is not used.
        """
        # 27 features -> 9 three-component vectors -> grade-1 multivectors.
        multivectors = algebra.embed_grade(x.reshape(-1, 9, 3), 1)

        mixed = self.cgenn1(multivectors)

        # Keep only the grade-1 part (re-embedded, then extracted again).
        vectors = algebra.get_grade(mixed, 1)
        vectors = algebra.embed_grade(vectors, 1)
        vectors = algebra.get_grade(vectors, 1)

        return vectors.reshape(1, -1, 3)

In [None]:
# Number of per-residue node features written by get_nodes_and_edges.
node_n = 27

# Edge feature width per experiment version: the three base maps (CCMpred,
# FreeContact, potential) + the version-specific extra maps + the distance map.
_EDGE_N_BY_VERSION = {
    0: 4,
    1: 5, 2: 5,
    3: 6,
    4: 7,
    5: 6,
    6: 6,
    7: 7, 8: 7, 9: 7, 10: 7, 11: 7,
}

# Fail loudly instead of leaving edge_n undefined (or stale from an earlier
# notebook run) when `version` is not one of the supported variants.
if version not in _EDGE_N_BY_VERSION:
    raise ValueError("unsupported version: " + str(version))
edge_n = _EDGE_N_BY_VERSION[version]

GT = GraphTransformer(
    dim = node_n,
    depth = 3,
    heads = 4,
    edge_dim = edge_n,             # optional - if left out, edge dimensions is assumed to be the same as the node dimensions above
    with_feedforwards = True,   # whether to add a feedforward after each attention layer, suggested by literature to be needed
    gated_residual = True,      # to use the gated residual to prevent over-smoothing
    rel_pos_emb = True          # set to True if the nodes are ordered, default to False
)

projector3D = Net()

# `model` only bundles both modules (parameters, state dict, device moves);
# the loops below call GT and projector3D separately.
model = nn.Sequential(GT, projector3D)

# Training

In [None]:
class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    ``patience`` is the number of "bad" checks tolerated before stopping and
    ``min_delta`` the margin above the best loss seen so far that a value
    must exceed to count as bad.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss``; return True once training should stop.

        A new best loss resets the bad-check counter. A loss within
        ``min_delta`` of the best leaves the counter untouched.
        """
        if validation_loss < self.min_validation_loss:
            # New best: remember it and start counting afresh.
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        worse_than_margin = validation_loss > (self.min_validation_loss + self.min_delta)
        if not worse_than_margin:
            return False

        self.counter += 1
        return self.counter >= self.patience

In [None]:
dir_dataset = './data'

# NOTE: `dir` shadows the builtin; the name is kept because the other cells
# of this notebook use the same variable.
dir = '/content/drive/MyDrive/DISTANCES/'
lst = os.listdir(dir)
lst.sort()


def _with_features(filenames, feature_dir):
    """Return the entries of *filenames* that have a feature pickle in *feature_dir*.

    The protein id is the filename without its extension and without its last
    three characters (the '-cb' suffix of '<pdb>-cb.npy'). The result is a
    NumPy string array, or an empty list when nothing matches.
    """
    kept = [
        filename for filename in filenames
        if os.path.exists(feature_dir + os.path.splitext(filename)[0][:-3] + '.pkl')
    ]
    return np.array(kept) if kept else []


# Distance files whose protein belongs to the deepcov set (training pool) ...
lst_train = _with_features(lst, dir_dataset + '/deepcov/features/')
# ... and to the psicov set (test set).
lst_test = _with_features(lst, dir_dataset + '/psicov/features/')

In [None]:
# Directory prefix under which the cells below read and write checkpoints
# (GTmodel.pt, BEST_GTmodel.pt) and loss histories.
PATH = "drive/MyDrive/TripleFCGP/"

In [None]:
from sklearn.model_selection import train_test_split

# NOTE: `dir` shadows the builtin; the name is kept because the other cells
# of this notebook use the same variable.
dir = '/content/drive/MyDrive/DISTANCES/'
lst = os.listdir(dir)
lst.sort()

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
torch.cuda.empty_cache()

dir_dataset = './data/'
train_feat_paths = [dir_dataset + '/deepcov/features/']
test_dist_paths = [dir_dataset + '/psicov/distance/']

optimizer = optim.Adam(model.parameters(), lr=1e-2)
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
# against the installed version.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9, verbose = True)

early_stopper = EarlyStopper(patience=4, min_delta=0.1)
stopflag = 0

torch.backends.cudnn.benchmark = True

# Fixed 80/20 split of the training pool into train and validation files.
lsttrain, lstval = train_test_split(lst_train, test_size=0.20, random_state=42)

epochs = 100
batch_size = 8
# Integer division: a trailing partial batch is not trained on.
n_batches = len(lsttrain) // batch_size

# Best validation loss reached by a previous run, if one was saved.
# map_location lets a checkpoint written on a GPU runtime load on a CPU one.
if os.path.isfile(PATH + "BEST_GTmodel.pt"):
    checkpoint = torch.load(PATH + "BEST_GTmodel.pt", map_location=device)
    minimum = checkpoint['minimum']
else:
    minimum = 100000

print("The minimum validation loss is: ", minimum)

L = nn.L1Loss()
ssim_module = SSIM(data_range=255, size_average=True, channel=1)
counter = 0
loss = 0
i_in = 0            # epoch to start from (overwritten when resuming)
j_in = 0            # batch to start from within that epoch
tot_val_loss = 0
alpha = 20          # weight of the (1 - SSIM) term relative to the L1 term

total_loss = 0

final_loss = []
final_mae = []
final_ssim = []
val_loss_arr = []

# Index of the distance map inside the edge tensor: get_nodes_and_edges
# writes it as the last of the edge_n channels for every version.
edge_len = edge_n - 1

# Resume from the rolling checkpoint when present.
if os.path.isfile(PATH + "GTmodel.pt"):
    print('loading model...')
    checkpoint = torch.load(PATH + "GTmodel.pt", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    i_in = checkpoint['epoch']
    total_loss = checkpoint['loss']
    j_in = checkpoint['batch']

    if os.path.isfile(PATH + 'train_loss.npy'):
        final_loss = np.load(PATH + 'train_loss.npy')
        final_loss = final_loss.tolist()
    if os.path.isfile(PATH + 'val_loss.npy'):
        val_loss_arr = np.load(PATH + 'val_loss.npy')
        val_loss_arr = val_loss_arr.tolist()


print(final_loss)
print(val_loss_arr)

def _pairwise_distance_map(coord):
    """Return the (1, n, n) Euclidean distance map of the (1, n, 3) tensor *coord*.

    Only the strictly lower triangle is computed (the same pairs the former
    element-by-element loop visited) and then mirrored; the diagonal stays 0.
    """
    n = coord.shape[1]
    dist = torch.zeros(n, n, device=coord.device)
    rows, cols = torch.tril_indices(n, n, offset=-1, device=coord.device)
    dist[rows, cols] = torch.linalg.norm(coord[0, rows] - coord[0, cols], dim=-1).to(dist.dtype)
    return (dist + dist.T).unsqueeze(0)


def _sample_losses(pdb):
    """Run one protein through GT + projector3D.

    Returns ``(L1 loss, 1 - SSIM)`` between the distance map of the predicted
    coordinates and the distance channel of the input edge tensor.
    """
    nodes, edges, mask, _ = get_nodes_and_edges(pdb, train_feat_paths, node_n, edge_n)
    nodes = nodes.to(device)
    edges = edges.to(device)
    mask = mask.to(device)

    nodes_new, _ = GT(nodes, edges, mask = mask)
    coord = projector3D(nodes_new, nodes)

    Y = edges[:,:,:,edge_len]
    pred_dist = _pairwise_distance_map(coord)

    mae = L(pred_dist, Y)
    dssim = 1 - ssim_module(pred_dist.unsqueeze(0), Y.unsqueeze(0))
    return mae, dssim


for i in range(i_in, epochs):

    print('****')

    np.random.shuffle(lsttrain)
    np.random.shuffle(lstval)

    # j_in is non-zero only for the first epoch after resuming a checkpoint.
    for j in range(j_in, n_batches):

        print("batch n.", j+1, "/", n_batches)
        lstbatch = lsttrain[j*batch_size:(j+1)*batch_size]

        # Gradient accumulation: per-sample losses are summed and a single
        # optimizer step is taken per batch.
        loss = 0
        for filename in lstbatch:
            # '<pdb>-cb.npy' -> '<pdb>'
            pdb = os.path.splitext(filename)[0][:-3]

            loss1, loss2 = _sample_losses(pdb)
            loss += loss1 + alpha*loss2
            total_loss = total_loss + (loss1 + alpha*loss2).item()

            torch.cuda.empty_cache()
            gc.collect()

        if j % 3 == 0:
            # loss1/loss2 are those of the last sample of the batch.
            print('epoch: %d,  batch: %d,  avg loss: %.3f' % (i, j, loss.item()/batch_size))
            print('MAE: %.3f,  SSIM: %.3f' % ( loss1, 1-loss2))

            # Rolling checkpoint, written before this batch's optimizer step.
            torch.save({
                'epoch': i,
                'batch': j,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': total_loss,
                },  PATH + "GTmodel.pt")

            print('model saved!')

        loss = loss / batch_size
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss = 0

    ####################################Validation
    print("Validation")
    tot_val_loss = 0

    # Evaluate without building autograd graphs; restore train mode after.
    model.eval()
    with torch.no_grad():
        for CNT, filename in enumerate(lstval, start=1):
            print(CNT, "/", len(lstval))

            pdb = os.path.splitext(filename)[0][:-3]
            loss1, loss2 = _sample_losses(pdb)
            tot_val_loss = tot_val_loss + (loss1 + alpha*loss2).item()

            torch.cuda.empty_cache()
            gc.collect()
    model.train()

    mean_val_loss = tot_val_loss / len(lstval)

    val_loss_arr = np.append(val_loss_arr, mean_val_loss)
    final_loss = np.append(final_loss, total_loss/len(lsttrain))

    print("....")
    print("validation loss:", val_loss_arr)
    print("training loss:", final_loss)
    print("....")

    np.save(PATH + 'train_loss.npy', final_loss)
    np.save(PATH +'val_loss.npy', val_loss_arr)

    if mean_val_loss < minimum:
        print("FOUND A NEW BEST!!")
        torch.save({
            'epoch': i,
            'batch': j,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            # Accumulated training loss of this epoch (`loss` itself has
            # already been reset to 0 after the last optimizer step).
            'loss': total_loss,
            'minimum': mean_val_loss,
            },  PATH + "BEST_GTmodel.pt")
        minimum = mean_val_loss

    if early_stopper.early_stop(mean_val_loss):
        stopflag = 1
        break

    tot_val_loss = 0
    total_loss = 0

    # After a resumed (partial) epoch, later epochs start from batch 0.
    if j_in != 0:
        j_in = 0

    # Decay the learning rate every second epoch.
    if (i + 1) % 2 == 0:
        scheduler.step()

In [None]:
# The training loop already appends each epoch's averages to both arrays.
# Appending again here would add a spurious entry: zeros after a normal run
# (the loop resets both accumulators at the end of every epoch) or a
# duplicate of the last epoch after an early stop. The histories are
# therefore only printed.
print("....")
print("validation loss:", val_loss_arr)
print("training loss:", final_loss)
print("....")


In [None]:
# Save the current (last, not necessarily best) weights of GT + projector3D
# to the local runtime disk.
torch.save(model.state_dict(), "/content/model.pt")
print('model saved!')

In [None]:
# Persist the per-epoch training and validation loss histories next to the
# checkpoints (same files the training loop writes after every epoch).
np.save(PATH +'train_loss.npy', final_loss)
np.save(PATH +'val_loss.npy', val_loss_arr)

# Testing

In [None]:
from matplotlib.pyplot import imshow
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import structural_similarity as ssim


def _pairwise_distance_map(coord):
    """Return the (1, n, n) Euclidean distance map of the (1, n, 3) tensor *coord*.

    Only the strictly lower triangle is computed and then mirrored; the
    diagonal stays 0.
    """
    n = coord.shape[1]
    dist = torch.zeros(n, n, device=coord.device)
    rows, cols = torch.tril_indices(n, n, offset=-1, device=coord.device)
    dist[rows, cols] = torch.linalg.norm(coord[0, rows] - coord[0, cols], dim=-1).to(dist.dtype)
    return (dist + dist.T).unsqueeze(0)


# NOTE: `dir` shadows the builtin; the name is kept because the other cells
# of this notebook use the same variable.
dir = '/content/drive/MyDrive/DISTANCES/'
lst = os.listdir(dir)
lst.sort()

dir_dataset = './data/'
test_feat_paths = [dir_dataset + '/psicov/features/']

# Index of the distance map inside the edge tensor: get_nodes_and_edges
# writes it as the last of the edge_n channels for every version.
edge_len = edge_n - 1

# GT may already live on the GPU (it is moved there for training), while the
# fresh projector and the inputs are created on the CPU: put everything on
# one device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

projector3D = Net()
model = nn.Sequential(GT, projector3D)
model.to(device)

L = nn.L1Loss()

if os.path.isfile(PATH + "BEST_GTmodel.pt"):
    print('loading model...')
    checkpoint = torch.load(PATH + "BEST_GTmodel.pt", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
else:
    print('WARNING: no BEST_GTmodel.pt found, the projector has freshly initialised weights')


total_loss = []
coordinates_array = []
protein_length = []
similarity = []

# Number of proteins plotted so far (at most 8 are shown).
counter = 0

for filename in lst_test:
    # '<pdb>-cb.npy' -> '<pdb>'
    filename = os.path.splitext(filename)[0][:-3]

    nodes, edges, mask, l = get_nodes_and_edges(filename, test_feat_paths, node_n, edge_n)
    nodes = nodes.to(device)
    edges = edges.to(device)
    mask = mask.to(device)

    with torch.no_grad():
        # The edges returned by GT are ignored: the target distance map is
        # read from the input edge tensor, as in the training loop.
        nodes_new, _ = GT(nodes, edges, mask = mask)
        coord = projector3D(nodes_new, nodes)

        or_dist = edges[:,:,:,edge_len]
        pred_dist = _pairwise_distance_map(coord)

        loss = L(pred_dist, or_dist)

    # NOTE(review): `[0][0]` selects the first ROW of each (1, l, l) map, so
    # this SSIM compares two 1-D profiles rather than the full maps — confirm
    # this is intended.
    img = or_dist.cpu().detach().numpy()[0][0]
    ssimerror = ssim(pred_dist.cpu().detach().numpy()[0][0], img, data_range=img.max() - img.min())

    a = loss.cpu().detach().numpy()
    total_loss = np.append(total_loss, a)

    # np.asscalar was removed from NumPy; ndarray.item() is the replacement.
    mae = a.item()
    print(mae)

    if mae < 3.7  and counter < 8:
        print("MAE:  ", mae)
        print("SSIM: ", ssimerror)
        print(filename)
        print('****')

        plt.figure()
        b = pred_dist.cpu().detach().numpy()
        imshow(np.asarray(b[0]), cmap = "plasma")

        plt.figure()
        b = or_dist.cpu().detach().numpy()
        imshow(np.asarray(b[0]), cmap = "plasma")

        counter = counter + 1

        coordinates_array = np.append(coordinates_array, np.asarray(coord.cpu().detach().numpy()))
        protein_length = np.append(protein_length, int(coord.shape[1]))

    similarity = np.append(similarity, ssimerror)

    del nodes, edges, mask, coord, or_dist, pred_dist
    gc.collect()

print(similarity)


In [None]:
# Persist the per-protein test L1 losses and SSIM values next to the checkpoints.
#np.save('coordinates.npy', coordinates_array)
# np.save('length.npy', protein_length)
np.save(PATH + 'loss.npy', total_loss)
np.save(PATH + 'similarity.npy', similarity)

In [None]:
# Max / median / min of the per-protein test L1 losses.
print(np.max(total_loss))
print(np.median(total_loss))
print(np.min(total_loss))

In [None]:
# Max / median / min of the per-protein test SSIM values.
print(np.max(similarity))
print(np.median(similarity))
print(np.min(similarity))

# Verifying the alignment on selected proteins

In [None]:
def _pairwise_distance_map(coord):
    """Return the (1, n, n) Euclidean distance map of the (1, n, 3) tensor *coord*.

    Only the strictly lower triangle is computed and then mirrored; the
    diagonal stays 0.
    """
    n = coord.shape[1]
    dist = torch.zeros(n, n, device=coord.device)
    rows, cols = torch.tril_indices(n, n, offset=-1, device=coord.device)
    dist[rows, cols] = torch.linalg.norm(coord[0, rows] - coord[0, cols], dim=-1).to(dist.dtype)
    return (dist + dist.T).unsqueeze(0)


types = ["1mk0A-cb.npy", "1z0jB-cb.npy", "1yqhA-cb.npy", "1zv1A-cb.npy", "2d0oB-cb.npy", "2dgbA-cb.npy", "2dm9A-cb.npy", "2ehwA-cb.npy", "2fyuK-cb.npy", "2fztA-cb.npy", "2gomA-cb.npy"]

all_feat_paths = [dir_dataset + '/deepcov/features/', dir_dataset + '/psicov/features/', dir_dataset + '/cameo/features/']

# Make sure both modules and the inputs share one device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
GT.to(device)
projector3D.to(device)

# np.save does not create missing directories.
out_dir = PATH + "version" + str(version)
os.makedirs(out_dir, exist_ok=True)

# NOTE(review): only the first 10 of the 11 listed proteins are processed,
# as before ("2gomA-cb.npy" is skipped) — confirm this is intended.
for lsttest in types[:10]:

    # '<pdb>-cb.npy' -> '<pdb>'
    filename = os.path.splitext(lsttest)[0][:-3]

    print(filename)

    nodes, edges, mask, l = get_nodes_and_edges(filename, all_feat_paths, node_n, edge_n)
    nodes = nodes.to(device)
    edges = edges.to(device)
    mask = mask.to(device)

    with torch.no_grad():
        # The edges returned by GT are ignored: the target distance map is
        # read from the input edge tensor, as in the training loop.
        nodes_new, _ = GT(nodes, edges, mask = mask)
        coord = projector3D(nodes_new, nodes)

        or_dist = edges[:,:,:,edge_len]
        pred_dist = _pairwise_distance_map(coord)

        loss = L(pred_dist, or_dist)

    a = loss.cpu().detach().numpy()

    # NOTE(review): `[0][0]` selects the first ROW of each (1, l, l) map, so
    # this SSIM compares two 1-D profiles rather than the full maps — confirm
    # this is intended.
    img = or_dist.cpu().detach().numpy()[0][0]
    ssimerror = ssim(pred_dist.cpu().detach().numpy()[0][0], img, data_range=img.max() - img.min())

    # np.asscalar was removed from NumPy; ndarray.item() is the replacement.
    print(a.item())
    print(ssimerror)
    print()
    plt.figure()
    b = pred_dist.cpu().detach().numpy()
    imshow(np.asarray(b[0]))

    plt.figure()
    b = or_dist.cpu().detach().numpy()
    imshow(np.asarray(b[0]))

    np.save(out_dir + '/coordinates' + str(filename) + '_version' + str(version) + '.npy', coord.cpu().detach().numpy())

In [None]:
'''
dir = '/content/drive/MyDrive/DISTANCES/'
lst = os.listdir(dir)
lst.sort()
lsttest = lst[200:950]

namelist = []
for filename in lsttest:
    print(filename)
    filename = os.path.splitext(filename)[0]
    filename = filename[:-4]

    namelist = np.append(namelist, filename)
'''

In [None]:
#!zip -r /content/version1.zip /content/version1

# Evaluating GDT scores over the Test set

In [None]:
def rigid_transform_3D(A, B):
    """Find the rigid motion that best maps point set A onto point set B.

    Both inputs are 3xN arrays of corresponding points. Returns ``(R, t)``
    with ``R`` a 3x3 rotation (determinant +1) and ``t`` a 3x1 translation
    such that ``R @ A + t`` approximates ``B``, obtained from the SVD of the
    covariance of the centred point sets.

    Raises:
        AssertionError: If the shapes of A and B differ.
        Exception: If either input does not have exactly 3 rows.
    """
    assert A.shape == B.shape

    rows_a, cols_a = A.shape
    if rows_a != 3:
        raise Exception(f"matrix A is not 3xN, it is {rows_a}x{cols_a}")

    rows_b, cols_b = B.shape
    if rows_b != 3:
        raise Exception(f"matrix B is not 3xN, it is {rows_b}x{cols_b}")

    # Centroids as 3x1 columns so they broadcast over the N points.
    centroid_A = np.mean(A, axis=1).reshape(-1, 1)
    centroid_B = np.mean(B, axis=1).reshape(-1, 1)

    # Covariance of the centred point sets.
    H = (A - centroid_A) @ np.transpose(B - centroid_B)

    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # A negative determinant means the SVD produced a reflection; flipping
    # the last row of Vt turns it into a proper rotation.
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    t = centroid_B - R @ centroid_A

    return R, t

In [None]:
!pip install biopython
!pip install git+https://github.com/pygae/clifford.git@master

In [None]:
from Bio.PDB.PDBParser import PDBParser
from clifford.tools.g3c import *
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Set-up for the GDT evaluation: locate the reference structures and feature
# files, restore the trained weights and create the result accumulators.

# Sorted listing of the reference structure files.  (The earlier listing of
# the DISTANCES folder was overwritten before being used and is dropped.)
# NOTE: `dir` shadows the builtin, but the evaluation loop reads this name, so
# it is kept.
dir = '/content/drive/MyDrive/coordinates/'
lst = sorted(os.listdir(dir))

# Folder(s) searched by get_nodes_and_edges for the per-target feature pickles.
dir_dataset = './data/'
all_feat_paths = [dir_dataset + '/psicov/features/']

# Both networks are wrapped in one container so a single checkpoint file
# restores the weights of each of them.
model_NN = nn.Sequential(GT, projector3D)
#model_NN.cuda()

if os.path.isfile('/content/model.pt'):
    print('loading model...')
    # map_location lets a checkpoint written on a GPU be read on a CPU-only
    # runtime; load_state_dict then copies the values onto whatever device the
    # modules already live on.
    model_NN.load_state_dict(torch.load("/content/model.pt", map_location="cpu"))


parser = PDBParser(PERMISSIVE=1)
structure_id = "chain"

# One entry per evaluated target is appended to each of these.
GDT_TS_TOTAL = []
GDT_HA_TOTAL = []

# ---------------------------------------------------------------------------
# Score every test target: predict coordinates with GT + projector3D, read the
# C-alpha atoms of the target chain from its PDB file, superimpose the two
# point sets and accumulate GDT_TS / GDT_HA.
#
# Relies on names created by other cells: lst_test, lst, dir, parser,
# structure_id, all_feat_paths, node_n, edge_n, GT, projector3D,
# get_nodes_and_edges, rigid_transform_3D, GDT_TS_TOTAL, GDT_HA_TOTAL.
# ---------------------------------------------------------------------------

# Distance cutoffs of the two scores; the high-accuracy variant halves each
# GDT_TS cutoff.  They are expressed in the unit of the PDB coordinates.
GDT_TS_CUTOFFS = (1, 2, 4, 8)
GDT_HA_CUTOFFS = (0.5, 1, 2, 4)


def _gdt_score(distances, cutoffs):
    """Return 100 x the mean, over cutoffs, of the fraction of distances below each cutoff."""
    fractions = [np.sum(distances < cutoff) / len(distances) for cutoff in cutoffs]
    return float(100 * sum(fractions) / len(cutoffs))


# Set for constant-time "is there a reference structure?" checks.
_available_structures = set(lst)

for filename in lst_test:
    # Drop the extension and the trailing three characters; what remains is the
    # feature-file key, whose last character is taken as the chain id.
    feature_key = os.path.splitext(filename)[0][:-3]
    target_chain = feature_key[-1]
    print(feature_key)
    filename = feature_key[:-1]

    # Skip targets without a reference structure before paying for a forward
    # pass of the network.
    pdb_file = 'pdb' + filename + '.ent'
    if pdb_file not in _available_structures:
        continue

    nodes, edges, mask, l = get_nodes_and_edges(feature_key, all_feat_paths, node_n, edge_n)

    #nodes = nodes.cuda()
    #edges = edges.cuda()
    #mask = mask.cuda()

    # Gradients are never needed here: the result goes straight to NumPy.
    # NOTE(review): the networks run in whatever train/eval mode earlier cells
    # left them in - confirm .eval() was called if they contain dropout.
    with torch.no_grad():
        nodes_new, edges = GT(nodes, edges, mask=mask)
        pred_coord = projector3D(nodes_new, nodes)
    pred_coord = pred_coord.detach().cpu().numpy()

    structure = parser.get_structure(structure_id, dir + pdb_file)

    # C-alpha positions of the target chain, in Biopython's iteration order.
    # (The former pass that ran `del atom` on altloc "B" atoms only unbound the
    # loop variable and never changed the structure, so it is omitted.)
    ca_coords = [
        atom.get_coord()
        for model in structure
        for chain in model
        if chain.id == target_chain
        for residue in chain
        for atom in residue
        if atom.name == "CA"
    ]

    # Compare only as many residues as both sides provide.  Previously a
    # structure shorter than the prediction crashed in reshape, and a chain
    # without C-alpha atoms divided by zero.
    # NOTE(review): like the original, this pairs the i-th C-alpha with the
    # i-th predicted residue - confirm the numbering really lines up.
    idx = min(len(ca_coords), pred_coord.shape[1])
    if idx == 0:
        print('no CA atoms for chain ' + target_chain + ' in ' + pdb_file + ', skipped')
        continue

    # Centre both point sets on the centroid of the compared residues.
    coord = np.array(ca_coords[:idx], dtype=np.float64)
    coord = coord - coord.mean(axis=0)
    pred_coord[0] = pred_coord[0] - pred_coord[0, :idx].mean(axis=0)
    Y = pred_coord[0, :idx]

    # rigid_transform_3D expects 3xN arrays with one point per column, so the
    # (N, 3) arrays are transposed; reshape((3, N)) would mix the x/y/z values
    # of different residues.  It returns R, t with R @ A + t ~ B, i.e. the
    # reference is moved onto the prediction.  Because this is already the
    # least-squares superposition, repeating the fit cannot improve it and the
    # former 2000-iteration search is not needed.
    R, t = rigid_transform_3D(coord.T, Y.T)
    X = (R @ coord.T + t).T
    best_X = X

    # Per-residue deviation after superposition.
    distance = np.linalg.norm(X - Y, axis=1)

    # The *_max names are kept from the earlier search-based version; they now
    # hold the scores under the single least-squares superposition.
    GDT_TS = _gdt_score(distance, GDT_TS_CUTOFFS)
    GDT_HA = _gdt_score(distance, GDT_HA_CUTOFFS)
    GDT_TS_max = GDT_TS
    GDT_HA_max = GDT_HA

    # Report the better predictions while the loop is running.
    if GDT_TS_max > 30:
        print(filename)
        print(GDT_TS_max)
        print(idx)
        print("******")

    GDT_TS_TOTAL = np.append(GDT_TS_TOTAL, GDT_TS_max)
    GDT_HA_TOTAL = np.append(GDT_HA_TOTAL, GDT_HA_max)


In [None]:
# Write each score collection to its own .npy file, named PATH + <metric> + '.npy'.
for metric_name, metric_values in (('GDT_TS', GDT_TS_TOTAL), ('GDT_HA', GDT_HA_TOTAL)):
    np.save(PATH + metric_name + '.npy', metric_values)

In [None]:
# GDT_TS over all evaluated targets: highest, median and lowest score.
for summarise in (np.max, np.median, np.min):
    print(summarise(GDT_TS_TOTAL))

In [None]:
# GDT_HA over all evaluated targets: highest, median and lowest score.
for summarise in (np.max, np.median, np.min):
    print(summarise(GDT_HA_TOTAL))

# Counting the number of trainable parameters

In [None]:
# Number of scalar values held by GT's parameters.  Every parameter is
# counted, whether or not it requires gradients.
param_sizes = [weights.numel() for weights in GT.parameters()]
pytorch_total_params = sum(param_sizes)
pytorch_total_params

In [None]:
# Number of scalar values held by projector3D's parameters.  Every parameter
# is counted, whether or not it requires gradients.
param_sizes = [weights.numel() for weights in projector3D.parameters()]
pytorch_total_params = sum(param_sizes)
pytorch_total_params