In [1]:
import os, glob, sys, math
import concurrent.futures
import argparse

import shutil

#from meetdock import *
# from scipy.spatial.transform import Rotation as R
from sklearn.neighbors import NearestNeighbors
import numpy as np
import pandas as pd
from Bio.PDB import *
from Bio.PDB.ResidueDepth import get_surface
from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.ResidueDepth import min_dist
from pyquaternion import Quaternion

from utils import pdbtools
from utils import pdb_resdepth
from utils import matrice_distances
from utils import Lennard_Jones
from utils import electrostatic
from utils import combine_methods as cm
from utils import tm_score as tm

# from surface import *
p = PDBParser()

recognized_residues = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET',
                           'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'NH', 'OC']
atom_types = [['N'], ['CA'], ['C'], ['O'], ['GLYCA'],
                  ['ALACB', 'ARGCB', 'ASNCB', 'ASPCB', 'CYSCB', 'GLNCB', 'GLUCB', 'HISCB', 'ILECB', 'LEUCB', 'LYSCB',
                   'METCB', 'PHECB', 'PROCB', 'PROCG', 'PROCD', 'THRCB', 'TRPCB', 'TYRCB', 'VALCB'],
                  ['LYSCE', 'LYSNZ'], ['LYSCD'], ['ASPCG', 'ASPOD1', 'ASPOD2', 'GLUCD', 'GLUOE1', 'GLUOE2'],
                  ['ARGCZ', 'ARGNH1', 'ARGNH2'],
                  ['ASNCG', 'ASNOD1', 'ASNND2', 'GLNCD', 'GLNOE1', 'GLNNE2'], ['ARGCD', 'ARGNE'],
                  ['SERCB', 'SEROG', 'THROG1', 'TYROH'],
                  ['HISCG', 'HISND1', 'HISCD2', 'HISCE1', 'HISNE2', 'TRPNE1'], ['TYRCE1', 'TYRCE2', 'TYRCZ'],
                  ['ARGCG', 'GLNCG', 'GLUCG', 'ILECG1', 'LEUCG', 'LYSCG', 'METCG', 'METSD', 'PHECG', 'PHECD1', 'PHECD2',
                   'PHECE1', 'PHECE2', 'PHECZ', 'THRCG2', 'TRPCG', 'TRPCD1', 'TRPCD2', 'TRPCE2', 'TRPCE3', 'TRPCZ2',
                   'TRPCZ3', 'TRPCH2', 'TYRCG', 'TYRCD1', 'TYRCD2'],
                  ['ILECG2', 'ILECD1', 'ILECD', 'LEUCD1', 'LEUCD2', 'METCE', 'VALCG1', 'VALCG2'], ['CYSSG']]

rng = np.random.default_rng(0)

In [2]:
def chaindef(file, rec_chain):
    
    structure=p.get_structure('1bth',file)
    coordinatesr = np.empty((0,3))
    tobi_residuesr = []
    residue_id=[]
    boundary_residue_coord=np.empty((0,3))
    atom_coord=np.empty((0,3))
    boundary_residue_id=[]
    boundary_residue_name=[]
    #rcc=0
    for model in structure:
        surface = get_surface(model)
        for chain in model:
            if chain.id in rec_chain:
                for residue in chain:
                    #print('hi')
                    cx = 0.0
                    cy = 0.0
                    cz = 0.0
                    count = 0
                    residue_index=recognized_residues.index(residue.get_resname())
                    atom_set=np.empty((0,3))
                    for atom in residue:
                        if  not atom.name=='H':
                            ax=atom.get_coord()[0]
                            ay=atom.get_coord()[1]
                            az=atom.get_coord()[2]
                            atom_set=np.append(atom_set,[atom.get_coord()], axis=0)
                            atom_coord=np.append(atom_coord,[atom.get_coord()], axis=0)
                            cur_atom=residue.get_resname()+atom.name
                            for typ in atom_types:
                                if  cur_atom in typ or atom.name in ['N','CA','C','O']:	#typ:#atom.name now added
                                    cx += ax
                                    cy += ay
                                    cz += az
                                    count += 1
                                else:
                                    pass
                    cx/= float(count)
                    cy/= float(count)
                    cz/= float(count)
                    coordinatesr=np.append(coordinatesr,[[cx, cy, cz]], axis=0)
                    #rcc+=1
                    tobi_residuesr.append(residue_index)
                    residue_id.append(str(residue.get_id()[1])+residue.get_id()[2])
                    fji=0     #check whether any of of the atoms in the resdue are at a distance 3 A from surface
                    for ji in range(len(atom_set)):
                        if min_dist(atom_set[ji], surface) < 2:
                            fji=1
                            break
                    if fji==1:
                        boundary_residue_coord=np.append(boundary_residue_coord,[[cx, cy, cz]],axis=0)
                        #boundary_atom_name.append(atom.name)
                        boundary_residue_id.append(str(residue.get_id()[1])+residue.get_id()[2])
                        boundary_residue_name.append(residue.get_resname())
    #print(rcc)
    return boundary_residue_coord,boundary_residue_name, boundary_residue_id, atom_coord

In [3]:
def findPointNormals(points, numNeighbours, viewPoint, residue_id, residue_name,f):
    """
    xu=[]
    for i in points:
        k=[]
        for j in i:
             k.append(float(j))
        xu.append(k)
    viewPoint =[float(x) for x in viewPoint]
    X=xu
    """
    nbrs = NearestNeighbors(n_neighbors=numNeighbours+1, algorithm='kd_tree').fit(points)
    distances, indices = nbrs.kneighbors(points)
    n = [] #indices[:,2:]
    [n.append(indices[i][1:].tolist()) for i in range(0,len(indices))]

    
      #%find difference in position from neighbouring points
    n=np.asarray(n).flatten('F')    
    p = np.tile(points,(numNeighbours,1)) - points[n]
    x=np.zeros((3,len(points),numNeighbours))
    for i in range(0,3):
        for j in range(0,len(points)):
            for k in range(0,numNeighbours):
                x[i,j,k]=p[k*len(points)+j,i]
    p=x
    C = np.zeros((len(points),6))
    C[:,0]= np.sum(np.multiply(p[0],p[0]),axis=1)
    C[:,1]= np.sum(np.multiply(p[0],p[1]),axis=1)
    C[:,2]= np.sum(np.multiply(p[0],p[2]),axis=1)
    C[:,3]= np.sum(np.multiply(p[1],p[1]),axis=1)
    C[:,4]= np.sum(np.multiply(p[1],p[2]),axis=1)
    C[:,5]= np.sum(np.multiply(p[2],p[2]),axis=1)
    C = np.divide(C, numNeighbours)
    normals = np.zeros((len(points),3))
    curvature = np.zeros((len(points),1))
    for i in range(0,len(points)):
        Cmat = [[C[i,0], C[i,1] ,C[i,2]], [C[i,1], C[i,3], C[i,4]], [C[i,2], C[i,4], C[i,5]]]
        [value,vector] = np.linalg.eigh(Cmat)
        [lam,k] = min(value), value.tolist().index(min(value))
        normals[i,:] = vector[:,k] #np.transpose(vector[:,k])
        curvature[i]= lam / sum(value)
        
    return normals, curvature

In [4]:
def do_something(args):
    output_file='out'+str(args[1])+'.pdb'
    out=open(os.path.join(mypath, output_file),'w')
    sc=open('score.txt','a')
    in1=open(inp2,'r')
    in2=open(inp1,'r')
    for line in in1:
        if 'ATOM' in line:
              out.write(line) 
    indexing=0
    new_co=args[0]
    for line in in2:
        if 'ATOM' in line:
            #print(line)
            l=line.split()
            l[0]=l[0].ljust(5)
            l[1]=l[1].rjust(5)
            l[2]=l[2].ljust(3)
            l[3]=l[3].ljust(3)
            l[4]=line[21]
            l[5]=('%4d' % (int(line[22:26]))).rjust(4)
            l[6]=('%8.3f' % (float(new_co[indexing][0]))).rjust(8)
            l[7]=('%8.3f' % (float(new_co[indexing][1]))).rjust(8)
            l[8]=('%8.3f' % (float(new_co[indexing][2]))).rjust(8)
            out.write('{0} {1}  {2} {3} {4}{5}    {6}{7}{8}' .format(l[0],l[1],l[2],l[3],l[4],l[5],l[6],l[7],l[8]))
            out.write('\n')
            indexing+=1
    out.close()
    #print("depth ok")
    pdbfile=os.path.join(mypath, output_file)
    my_struct = pdbtools.read_pdb(pdbfile)
    try:
        depth_dict = pdb_resdepth.calculate_resdepth(structure=my_struct, pdb_filename=pdbfile, method="msms")
    except:
        os.remove(os.path.join(mypath, output_file))
        return
    
    distmat = matrice_distances.calc_distance_matrix(structure=my_struct, depth= depth_dict, chain_R=rec_chain, chain_L=lig_chain, dist_max=dist, method = depth)

    vdw = Lennard_Jones.lennard_jones(dist_matrix=distmat)
    electro = electrostatic.electrostatic(inter_resid_dict=distmat, pH =pH)
    score=vdw+electro
    #if score>=0:
    #      os.remove(os.path.join(mypath, output_file))#eliminate the bad solutions

    #      return
    #else:

    #score=np.random.randint(-30,20)
    #sc.write(str(args[1])+'   '+ str(score)+'\n')
    #sc.close()
    return score, args[1], args[2], args[3]

In [5]:
def find_score(args):
    output_file='out'+str(args[1])+'.pdb'
    shape, electro, jones, proba = True, True, True, False
    pH = 7
    dist = 8.6
    with open(os.path.join(mypath, output_file),'w') as out:
        in1 = open(inp2, "r")
        in2 = open(inp1, "r")
        for line in in1:
            if "ATOM" in line:
                out.write(line)
        indexing = 0
        new_co = args[0]
        for line in in2:
            if "ATOM" in line:
                # print(line)
                l = line.split()
                l[0] = l[0].ljust(5)
                l[1] = l[1].rjust(5)
                l[2] = l[2].ljust(3)
                l[3] = l[3].ljust(3)
                l[4] = line[21]
                l[5] = ("%4d" % (int(line[22:26]))).rjust(4)
                l[6] = ("%8.3f" % (float(new_co[indexing][0]))).rjust(8)
                l[7] = ("%8.3f" % (float(new_co[indexing][1]))).rjust(8)
                l[8] = ("%8.3f" % (float(new_co[indexing][2]))).rjust(8)
                out.write(
                    "{0} {1}  {2} {3} {4}{5}    {6}{7}{8}".format(
                        l[0], l[1], l[2], l[3], l[4], l[5], l[6], l[7], l[8]
                    )
                )
                out.write("\n")
                indexing += 1

    pdbfile = os.path.join(mypath, output_file)
    res = cm.combine_score(pdbfile, recepChain=rec_chain, ligChain=lig_chain, statpotrun=proba, vdwrun=jones, electrorun=electro, shaperun=shape, pH=pH, depth=depth, dist=dist)
    mydf = pd.DataFrame(res,  index=[0])
    mydf = mydf.set_index('pdb')
    score = tm.tm_score(mydf, execdir='.')
    return float(score['tm_score_prediction']), args[1], args[2], args[3]

In [6]:
def pdbpre(file1):
    with open(os.path.join(args.pdb, file1), "r") as pdb_in: 
        with open(file1 + "1.pdb", "w") as out: 
            atmno = 1
            resno = 0
            res = ""
            fr = ""
            l = [""] * 11
            for line in pdb_in:
                if "ATOM" in line[0:4]:
                    li = line.split()
                    l[0] = li[0].ljust(6)
                    l[1] = str(atmno).rjust(4)
                    l[2] = li[2].ljust(3)
                    l[3] = li[3].ljust(3)
                    l[4] = line[21]
                    if fr != line[21]:
                        atmno = 1
                        resno = 0
                        res = ""
                        fr = line[21]
                    if line[22:26] == res:
                        l[5] = ("%4d" % (int(resno))).rjust(4)
                    else:
                        resno += 1
                        res = line[22:26]
                        l[5] = ("%4d" % (int(resno))).rjust(4)
                    # if len(l[6])>10:
                    l[6] = ("%8.3f" % (float(line[29:38]))).rjust(8)
                    l[7] = ("%8.3f" % (float(line[38:46]))).rjust(8)
                    l[8] = ("%8.3f" % (float(line[46:54]))).rjust(8)
                    l[9] = ("%6.2f" % (float(line[55:60]))).rjust(6)
                    l[10] = ("%6.2f" % (float(line[60:66]))).ljust(6)
                    out.write(
                        "{0} {1}  {2} {3} {4}{5}    {6}{7}{8}{9}{10}".format(
                            l[0], l[1], l[2], l[3], l[4], l[5], l[6], l[7], l[8], l[9], l[10]
                        )
                    )
                    out.write("\n")
                    atmno += 1
    return file1 + "1.pdb"

In [53]:
def do_something1(args):
    output_file='out_'+str(args[1])+'.pdb'
    with open(os.path.join(mypath, output_file),'w') as out:
        in2=open(inp1,'r')
        indexing=0
        new_co=args[0]
        for line in in2:
            if 'ATOM' in line:
                #print(line)
                l=line.split()
                l[0]=l[0].ljust(5)
                l[1]=l[1].rjust(5)
                l[2]=l[2].ljust(3)
                l[3]=l[3].ljust(3)
                l[4]=line[21]
                l[5]=('%4d' % (int(line[22:26]))).rjust(4)
                l[6]=('%8.3f' % (float(new_co[indexing][0]))).rjust(8)
                l[7]=('%8.3f' % (float(new_co[indexing][1]))).rjust(8)
                l[8]=('%8.3f' % (float(new_co[indexing][2]))).rjust(8)
                out.write('{0} {1}  {2} {3} {4}{5}    {6}{7}{8}' .format(l[0],l[1],l[2],l[3],l[4],l[5],l[6],l[7],l[8]))
                out.write('\n')
                indexing+=1
    alignment = tmscoring.TMscoring(inp2, os.path.join('poses/', output_file))
    score = alignment.tmscore(**alignment.get_current_values())
    score1 = alignment.rmsd(**alignment.get_current_values())
    
    #tms()
    #score = tms.get_tm_score()
    print(score, score1)
    return score, args[1], args[2], args[3]

In [56]:
res = do_something1(Doargs[16])

0.0008423485683393902 52.30314890225165


In [11]:
pdb0 = '4dn4' 
pdb1 = ['LH', 'M']

In [12]:
class ArgsVal:
    def __init__(self, pdb):
        self.pdb = pdb

In [13]:
args = ArgsVal("Data/4dn4_LH:M")

In [14]:
depth = "msms"
dist = 8.6
pH = 7

In [15]:
rpdb=pdb1[0]+'_model_st.pdb'
lpdb=pdb1[1]+'_model_st.pdb'
lig_chain=[]
rec_chain=[]
for i in pdb1[0]:
    rec_chain.append(i)
for i in pdb1[1]:
    lig_chain.append(i)

In [16]:
inp1 = pdbpre(lpdb)
inp2 = pdbpre(rpdb)

In [17]:
lig_coord, lig_res,lig_res_id, lig_atom=chaindef(inp1, lig_chain)
rec_coord, rec_res,rec_res_id, rec_atom=chaindef(inp2, rec_chain)
print(len(rec_chain))
print(rec_chain)
#print(rec_coord)

2
['L', 'H']


In [18]:
rec_normal, rec_curve = findPointNormals(rec_coord, 20,[0,0,0], rec_res_id, rec_res, 'r')
lig_normal, lig_curve = findPointNormals(lig_coord, 20,[0,0,0], lig_res_id, lig_res, 'r')

In [46]:
frogs = 50 ## No of frogs (population)

StructInfo = {}
init = 0
mypath='poses/'

### Multi - with Function

In [133]:
frogs = 200 ## No of frogs (population)

StructInfo = {}
init = 0
mypath='poses/'

In [135]:
def generate_init_population():
    global init
    with concurrent.futures.ProcessPoolExecutor() as executor:
        Doargs = []
        for _ in range(frogs):
            Doargs += generate_one_frog(init)
            init += 1
        results = executor.map(find_score, Doargs)
        for r in results:
            if r:
                StructInfo[r[1]] = [r[0], r[2]]    

In [None]:
generate_init_population()

In [223]:
len(Doargs)

50

In [None]:
with concurrent.futures.ProcessPoolExecutor() as executor:
    Doargs = []
    for _ in range(frogs):
        Doargs += generate_one_frog(init)
        init += 1
    results = executor.map(find_score, Doargs)
    for r in results:
        if r:
            StructInfo[r[1]] = [r[0], r[2]]

In [None]:
StructInfo

In [149]:
rng.integers(0, 12)

7

In [32]:
def sort_frog(mplx_no):
    sorted_fitness = np.array(sorted(StructInfo, key = lambda x: StructInfo[x][0], reverse=True))

    memeplexes = np.zeros((mplx_no, int(frogs/mplx_no)))

    for j in range(memeplexes.shape[1]):
        for i in range(mplx_no):
            memeplexes[i, j] = sorted_fitness[i + (mplx_no*j)] 
    return memeplexes

In [54]:
def shuffle_memeplexes(memeplexes):
    """Shuffles the memeplexes and sorting them.
    
    Arguments:
        frogs {numpy.ndarray} -- All the frogs
        memeplexes {numpy.ndarray} -- The memeplexes
    
    Returns:
        None
    """
    mplx_no = memeplexes.shape[0]
    temp = memeplexes.flatten()
    temp = np.array(sorted(temp, key = lambda x: StructInfo[x][0], reverse=True))
    for j in range(memeplexes.shape[1]):
        for i in range(mplx_no):
            memeplexes[i, j] = temp[i + (mplx_no*j)]

In [240]:
shuffle_memeplexes(mem)

In [20]:
StructInfo

{}

## Local Search

#### Local Search Main Code -- Multi

In [138]:
memeplexes = sort_frog(20)

In [191]:
def local_search_one_memeplex(inp):
    """
        q: The number of frogs in submemeplex
        N: No of mutations
    """
    im, N, q, Frog_gb = inp
    FrogsEach = int(frogs/len(memeplexes))
    weights = [2*(FrogsEach+1-j)/(FrogsEach*(FrogsEach+1)) for j in range(1, FrogsEach+1)] 
    
    Omega = [np.amin(rec_normal), np.amax(rec_normal)]
    max_step = (Omega[1]-Omega[0])/2 # maximum step size
    
    for iN in range(N):
        uId = init + im + 1
        rValue = rng.random(FrogsEach) * weights # random value with probability weights
        subindex = np.sort(np.argsort(rValue)[::-1][0:q]) # index of selected frogs in memeplex
        submemeplex = memeplexes[im][subindex] 

        #--- Improve the worst frog's position ---#
        # Learn from local best Pb #
        Pb = StructInfo[int(submemeplex[0])] # mark the best frog in submemeplex
        Pw = StructInfo[int(submemeplex[q-1])] # mark the worst frog in memeplex

        S = rng.random() * (Pb[1] - Pw[1]) 
        Uq = Pw[1] + S

        globStep = False
        censorship = False
        # Check feasible space and the performance #
        if Omega[0] <= min(Uq) and max(Uq) <= Omega[1]: # check feasible space
            final = np.array([Uq.rotate(i) for i in lig_atom])  
            results = find_score([final, uId, Uq, im])

            if results[0] < Pw[0]:
                globStep = True

        if globStep:
            S = rng.random() * (Frog_gb[1] - Pw[1])
            for i in range(4):
                if S[i] > 0:
                    S[i] = min(S[i],max_step)
                else:
                    S[i] = max(S[i],-max_step)
            Uq = Pw[1] + S

            if Omega[0] <= min(Uq) and max(Uq) <= Omega[1]: # check feasible space
                final = np.array([Uq.rotate(i) for i in lig_atom])  
                results = find_score([final, uId, Uq, im])
                if results[0] < Pw[0]:
                    censorship = True
            else:
                censorship = True

        if censorship:
            params = generate_one_frog(uId)
            results = find_score(params)            


        #StructInfo[im] = [results[0], results[2]]
        shutil.move(os.path.join('poses/', 'out'+str(uId)+'.pdb'), os.path.join('poses/', 'out'+ str(submemeplex[q-1]) + '.pdb'))
        StructInfo[int(submemeplex[q-1])] = [results[0], results[2]]
        memeplexes[im] = memeplexes[im][np.argsort(memeplexes[im])]

In [198]:
def local_search():
    Frog_gb = StructInfo[int(memeplexes[0][0])]
    N = 10
    q = 6
    with concurrent.futures.ProcessPoolExecutor() as executor:
        doargs = [[im, N, q, Frog_gb] for im in range(len(memeplexes))]
        results = executor.map(local_search_one_memeplex, doargs)

In [None]:
%%time
local_search()

In [194]:
memeplexes

array([[  1., 192., 133.,   8.,  11.,  25.,  36., 160.,   9.,  48.],
       [ 83.,  43.,  69.,  59.,  68.,  76.,  86.,  15., 157., 174.],
       [161., 103., 145.,  53., 119., 134., 142., 183.,   6.,  51.],
       [ 18.,  44.,  71.,  93., 171., 180., 191.,  16., 116., 175.],
       [ 88., 105., 149., 155.,  13.,  29.,  38.,  19., 118.,  82.],
       [162.,  45.,  75.,  54.,  70.,  78.,  87.,  32., 123., 178.],
       [ 27., 109., 159.,  97., 120., 135., 144.,  37.,   2.,  84.],
       [ 89.,  46.,  56., 156., 172., 184., 193.,  47., 126., 182.],
       [163., 112.,   4.,   0.,  14.,  30.,  39.,  50.,  10.,  94.],
       [ 28.,  52.,   7.,  55.,  72.,  80.,  90.,  60., 127., 189.],
       [ 96., 113.,  42., 104., 121., 138., 147.,  61.,  20.,  95.],
       [167.,  57.,  77., 164., 173., 187., 196.,  65., 143., 199.],
       [ 31., 115., 110.,   3.,  17.,  33.,  41.,  99.,  22., 102.],
       [ 98.,  63., 130.,  58.,  73.,  81.,  91., 107., 154., 108.],
       [181., 124., 136., 106., 12

## MAIN RUN CODE

In [None]:
def run_sfla(mplx_no, n_iter):
    generate_init_population()
    memeplexes = sort_frog(mplx_no=mplx_no)
    for _ in range(n_iter):
        local_search()
        shuffle_memeplexes(memeplexes)

In [210]:
scores = []
for i in memeplexes:
    for j in i:
        scores.append(StructInfo[j][0])

In [213]:
npScores = np.array(scores)

In [214]:
npScores.sort()

In [217]:
npScores

array([0.52040086, 0.52538213, 0.52560013, 0.53092888, 0.541294  ,
       0.541294  , 0.541294  , 0.541294  , 0.541294  , 0.541294  ,
       0.541294  , 0.541294  , 0.541294  , 0.541294  , 0.541294  ,
       0.541294  , 0.541294  , 0.541294  , 0.541294  , 0.541294  ,
       0.541294  , 0.541294  , 0.541294  , 0.541294  , 0.541294  ,
       0.541294  , 0.541294  , 0.541294  , 0.541294  , 0.541294  ,
       0.541294  , 0.541294  , 0.541294  , 0.541294  , 0.541294  ,
       0.541294  , 0.541294  , 0.55506038, 0.55523138, 0.55523138,
       0.55523138, 0.55523138, 0.55523138, 0.55523138, 0.55523138,
       0.55523138, 0.55523138, 0.55523138, 0.55523138, 0.55523138,
       0.55523138, 0.55523138, 0.55523138, 0.55523138, 0.55523138,
       0.55523138, 0.55523138, 0.55523138, 0.55523138, 0.55523138,
       0.55523138, 0.55576539, 0.55576539, 0.55576539, 0.55576539,
       0.55576539, 0.55576539, 0.55576539, 0.55576539, 0.55576539,
       0.55576539, 0.55576539, 0.55576539, 0.55576539, 0.55576

In [234]:
memeplexes

array([ 8., 21., 11., 27.,  3.])

In [204]:
StructInfo[memeplexes[-1][-1]]

[0.5204008611111113,
 Quaternion(0.8253056370515306, -0.5565910091238468, 0.05185394732923476, -0.0799263546012341)]

In [205]:
StructInfo[memeplexes[0][0]]

[0.657472333333334,
 Quaternion(0.8990546147122486, -0.4105507045149465, 0.05919979793793172, -0.14015813465867544)]

In [76]:
memeplexes = sort_frog(10)

In [None]:
%%time
# im to count the number of memeplexes
# iN to count the number of evolution

Frog_gb = StructInfo[int(memeplexes[0][0])]
FrogsEach = int(frogs/len(memeplexes)) #the number of frogs in each memeplex
weights = [2*(FrogsEach+1-j)/(FrogsEach*(FrogsEach+1)) for j in range(1, FrogsEach+1)] 

Omega = [np.amin(rec_normal), np.amax(rec_normal)]
max_step = (Omega[1]-Omega[0])/2 # maximum step size
q = 3 # int, the number of frogs in submemeplex -- CHANGE
N = 2
for im in range(len(memeplexes)):
    for iN in range(N):
        rValue = rng.random(FrogsEach) * weights # random value with probability weights
        subindex = np.sort(np.argsort(rValue)[::-1][0:q]) # index of selected frogs in memeplex
        submemeplex = memeplexes[im][subindex] 
        
        #--- Improve the worst frog's position ---#
        # Learn from local best Pb #
        Pb = StructInfo[int(submemeplex[0])] # mark the best frog in submemeplex
        Pw = StructInfo[int(submemeplex[q-1])] # mark the worst frog in memeplex
        
        S = rng.random() * (Pb[1] - Pw[1]) 
        Uq = Pw[1] + S
        
        globStep = False
        censorship = False
        
        # Check feasible space and the performance #
        if Omega[0] <= min(Uq) and max(Uq) <= Omega[1]: # check feasible space
            final = np.array([Uq.rotate(i) for i in lig_atom])  
            results = find_score([final, init+1, Uq, im])
            
            if results[0] < Pw[0]:
                globStep = True
        
        if globStep:
            S = rng.random() * (Frog_gb[1] - Pw[1])
            for i in range(4):
                if S[i] > 0:
                    S[i] = min(S[i],max_step)
                else:
                    S[i] = max(S[i],-max_step)
            Uq = Pw[1] + S
            
            if Omega[0] <= min(Uq) and max(Uq) <= Omega[1]: # check feasible space
                final = np.array([Uq.rotate(i) for i in lig_atom])  
                results = find_score([final, init+1, Uq, im])
                if results[0] < Pw[0]:
                    censorship = True
            else:
                censorship = True
        
        if censorship:
            recRandIdx = rng.integers(0, rec_coord.shape[0] - 1)
            ligRandIdx = rng.integers(0, lig_coord.shape[0] - 1)

            axis = rec_coord[recRandIdx]
            a = rec_normal[recRandIdx]
            b = lig_normal[ligRandIdx]
            
            dotProduct = np.dot(a, b)
            theta = np.arccos(dotProduct) * 2 - np.pi
            Quater = Quaternion(axis=a, angle=theta)
            final = np.array([Quater.rotate(i) for i in lig_atom])
            results = find_score([final, init+1, Quater, im])            
        
        
        #StructInfo[im] = [results[0], results[2]]
        shutil.move(os.path.join('poses/', 'out'+str(init+1)+'.pdb'), os.path.join('poses/', 'out'+ str(submemeplex[q-1]) + '.pdb'))
        StructInfo[int(submemeplex[q-1])] = [results[0], results[2]]
        memeplexes[im] = memeplexes[im][np.argsort(memeplexes[im])]

* 5.35
* 3.25

In [83]:
memeplexes

array([[ 7., 28.,  0., 43., 12.],
       [30., 40.,  1., 23., 16.],
       [49., 42.,  2., 44., 21.],
       [13., 39., 34., 25., 24.],
       [32., 46.,  4., 45., 38.],
       [20.,  3., 35., 29.,  5.],
       [26., 10., 14.,  6., 41.],
       [33., 15., 36., 11.,  8.],
       [27., 31., 18., 17., 47.],
       [37., 48., 22., 19.,  9.]])