### When connecting to a new runtime, run the cell below first to install the required dependencies.

In [None]:
# Install all the dependencies
!pip install mrcfile numba pyfftw scipy tqdm

## Initialization Part:



In [None]:
# import pacakages
import concurrent.futures
import copy
import math
import multiprocessing
import os
import time

import mrcfile
import numba
import numpy as np
import pyfftw
import scipy.fft
from numba.typed import List
from scipy.ndimage import convolve, correlate
from scipy.spatial.transform import Rotation as R
from tqdm import tqdm

In [None]:
# Global pyFFTW settings: use every CPU core reported by multiprocessing, and
# let FFTW measure candidate plans (slower planning, faster transforms).
pyfftw.config.NUM_THREADS = multiprocessing.cpu_count()
pyfftw.config.PLANNER_EFFORT = "FFTW_MEASURE"

In [None]:
class mrc_obj:
    """In-memory copy of an MRC density map plus slots for derived statistics.

    Attributes set by ``__init__``:
        xdim, ydim, zdim: grid dimensions from the header (nx, ny, nz).
        xwidth, ywidth, zwidth: voxel size along each axis.
        cent: grid centre in voxel units (half of each dimension).
        orig: header origin (x, y, z).
        data: the file's density array with axes 0 and 2 swapped (presumably
            giving [x][y][z] indexing, which is how fastVEC indexes it).
        dens: the file's density array flattened in its own (C) order.
        vec: zero-initialised float32 array of shape (xdim, ydim, zdim, 3).
        dsum, Nact, ave, std_norm_ave, std: statistics filled in later (None here).
    """

    def __init__(self, path):
        # Read everything needed while the file is open, so the handle is
        # released instead of staying open for the lifetime of the object.
        with mrcfile.open(path) as mrc:
            # Own, writeable copy: later code zeroes voxels in place.
            data = np.array(mrc.data)
            header = mrc.header
            self.xdim = int(header.nx)
            self.ydim = int(header.ny)
            self.zdim = int(header.nz)
            self.xwidth = mrc.voxel_size.x
            self.ywidth = mrc.voxel_size.y
            self.zwidth = mrc.voxel_size.z
            self.orig = np.array([header.origin.x, header.origin.y, header.origin.z])
        self.cent = np.array([self.xdim * 0.5, self.ydim * 0.5, self.zdim * 0.5])
        self.data = np.swapaxes(data, 0, 2)
        # flatten() copies, so dens and data do not share memory.
        self.dens = data.flatten()
        self.vec = np.zeros((self.xdim, self.ydim, self.zdim, 3), dtype="float32")
        self.dsum = None
        self.Nact = None
        self.ave = None
        self.std_norm_ave = None
        self.std = None

In [None]:
def mrc_set_vox_size(mrc, th=0.01, voxel_size=7.0):
    """Threshold a map and describe a cubic, FFT-friendly grid enclosing it.

    Voxels below ``th`` are zeroed in place in both ``mrc.dens`` and
    ``mrc.data``. A negative ``th`` first shifts the whole map up by ``-th``
    so that the threshold maps to zero.

    The new grid is centred on the old grid centre (converted to world
    coordinates). Its edge length covers twice the largest distance from that
    centre to a non-zero voxel, rounded to a length accepted by
    ``pyfftw.next_fast_len``.

    Args:
        mrc: map object to threshold (modified in place).
        th: density threshold.
        voxel_size: voxel width of the new grid.

    Returns:
        (mrc, mrc_new): the thresholded input and a deep copy of it. The copy
        carries the new orig/cent/dims/voxel widths. Its data/dens/vec arrays
        are still copies of the input's arrays, not resized.

    Raises:
        ValueError: if no voxel remains non-zero after thresholding.
    """
    # A negative threshold shifts the map so that `th` becomes 0. Shift the
    # 3-D `data` as well as the flat `dens`: `data` is the array that is
    # zeroed below and sampled later by fastVEC.
    if th < 0:
        mrc.dens = mrc.dens - th
        mrc.data = mrc.data - th
        th = 0.0

    # Zero all values below the threshold.
    mrc.dens[mrc.dens < th] = 0.0
    mrc.data[mrc.data < th] = 0.0

    # Largest distance (in voxels) from the grid centre to a non-zero voxel.
    non_zero_index_list = np.argwhere(mrc.data)
    if non_zero_index_list.size == 0:
        raise ValueError("No voxel has density above the threshold; cannot size the grid.")
    dmax = np.linalg.norm(non_zero_index_list - mrc.cent, axis=1).max()

    print("#dmax=" + str(dmax / mrc.xwidth))
    dmax = dmax * mrc.xwidth

    # Centre of the new grid in world coordinates.
    new_cent = mrc.cent * mrc.xwidth + mrc.orig

    tmp_size = 2 * dmax / voxel_size

    # Edge length that FFTW handles efficiently.
    new_xdim = pyfftw.next_fast_len(int(tmp_size))

    new_orig = new_cent - 0.5 * new_xdim * voxel_size

    mrc_new = copy.deepcopy(mrc)
    mrc_new.orig = new_orig
    mrc_new.xdim = new_xdim
    mrc_new.ydim = new_xdim
    mrc_new.zdim = new_xdim
    mrc_new.cent = new_cent
    mrc_new.xwidth = mrc_new.ywidth = mrc_new.zwidth = voxel_size

    print("Nvox= " + str(mrc_new.xdim) + ", " + str(mrc_new.ydim) + ", " + str(mrc_new.zdim))
    print("cent= " + str(new_cent[0]) + ", " + str(new_cent[1]) + ", " + str(new_cent[2]))
    print("ori= " + str(new_orig[0]) + ", " + str(new_orig[1]) + ", " + str(new_orig[2]))

    return mrc, mrc_new

In [None]:
@numba.jit(nopython=True)
def calc(stp, endp, pos, mrc1_data, fsiv):
    """Gaussian-weighted density sum and first moment over a voxel box.

    Visits every voxel (ix, iy, iz) in the half-open box [stp, endp) of
    ``mrc1_data``. Each density value is weighted by
    ``exp(-1.5 * d2 * fsiv)``, where ``d2`` is the squared distance (in
    voxels) from ``pos``.

    Returns:
        (total, moment): the weighted density sum, and a length-3 float array
        holding the sums of weight * ix, weight * iy and weight * iz.
    """
    total = 0.0
    moment = np.zeros((3,))

    for ix in range(stp[0], endp[0]):
        dx2 = (float(ix) - pos[0]) ** 2
        for iy in range(stp[1], endp[1]):
            dy2 = (float(iy) - pos[1]) ** 2
            for iz in range(stp[2], endp[2]):
                dz2 = (float(iz) - pos[2]) ** 2
                weight = mrc1_data[ix][iy][iz] * math.exp(-1.5 * (dx2 + dy2 + dz2) * fsiv)
                total += weight
                moment[0] += weight * ix
                moment[1] += weight * iy
                moment[2] += weight * iz

    return total, moment

In [None]:
def fastVEC(mrc_source, mrc_dest, dreso=16.0):
    """Resample a source map onto the destination grid with Gaussian smoothing.

    For every voxel (x, y, z) of ``mrc_dest``, the matching source position is
    ``pos = ((x, y, z) * dest.xwidth + dest.orig - source.orig) / source.xwidth``.
    The destination voxel is cleared when ``pos`` lies outside the source grid
    or the source voxel at ``int(pos)`` is zero.

    Otherwise ``calc`` sums source density over a box of half-width
    ``fmaxd = 2 * dreso / gstep`` voxels around ``pos``, clipped to the grid.
    Each voxel is weighted by ``exp(-1.5 * d^2 * fsiv)`` with
    ``fsiv = 1 / (dreso / gstep / 2) ** 2``.

    Written into ``mrc_dest``:
        dens[ind] and data[x][y][z]: the weighted density sum, where
            ind = xdim * ydim * z + xdim * y + x.
        vec[x][y][z]: unit vector from ``pos`` towards the density-weighted
            centroid. It is a zero vector when the sum is zero or when the
            centroid coincides with ``pos``.
        dsum, Nact: sum and count of voxels that received non-zero density.
        ave: dsum / Nact.
        std: L2 norm of the positive entries of dens.
        std_norm_ave: L2 norm of (positive dens entries - ave).

    Args:
        mrc_source: map to sample from (uses data, dims, xwidth, orig).
        mrc_dest: grid to fill; its dens/data/vec arrays are modified in place.
        dreso: Gaussian bandwidth, in the same units as ``xwidth``.

    Returns:
        ``mrc_dest``.

    Raises:
        ValueError: if no destination voxel received density, which leaves the
            average undefined.
    """
    print("#Start VEC")
    gstep = mrc_source.xwidth
    fs = (dreso / gstep) * 0.5
    fs = fs ** 2
    fsiv = 1.0 / fs
    fmaxd = (dreso / gstep) * 2.0
    print("#maxd= {fmaxd}".format(fmaxd=fmaxd))
    print("#fsiv= " + str(fsiv))

    src_dims = (mrc_source.xdim, mrc_source.ydim, mrc_source.zdim)
    dsum = 0.0
    Nact = 0

    for x in tqdm(range(mrc_dest.xdim)):
        for y in range(mrc_dest.ydim):
            for z in range(mrc_dest.zdim):
                pos = (np.array((x, y, z)) * mrc_dest.xwidth + mrc_dest.orig - mrc_source.orig) / mrc_source.xwidth
                ind = mrc_dest.xdim * mrc_dest.ydim * z + mrc_dest.xdim * y + x

                # Skip voxels outside the source grid or on empty source density.
                # Clear data as well as dens/vec so all three stay consistent.
                outside = any(pos[k] < 0 or pos[k] >= src_dims[k] for k in range(3))
                if outside or mrc_source.data[int(pos[0])][int(pos[1])][int(pos[2])] == 0:
                    mrc_dest.dens[ind] = 0.0
                    mrc_dest.data[x][y][z] = 0.0
                    mrc_dest.vec[x][y][z] = 0.0
                    continue

                # Box [stp, endp) around pos, clipped to the source grid.
                stp = (pos - fmaxd).astype(np.int32)
                endp = (pos + fmaxd + 1).astype(np.int32)
                for k in range(3):
                    if stp[k] < 0:
                        stp[k] = 0
                    if endp[k] >= src_dims[k]:
                        endp[k] = src_dims[k]

                dtotal, pos2 = calc(stp, endp, pos, mrc_source.data, fsiv)

                mrc_dest.dens[ind] = dtotal
                mrc_dest.data[x][y][z] = dtotal

                if dtotal == 0:
                    mrc_dest.vec[x][y][z] = 0.0
                    continue

                # Density-weighted centroid, then unit vector pointing to it.
                pos2 *= 1.0 / dtotal
                tmpcd = pos2 - pos
                dvec = math.sqrt(tmpcd[0] ** 2 + tmpcd[1] ** 2 + tmpcd[2] ** 2)
                if dvec == 0:
                    dvec = 1.0
                mrc_dest.vec[x][y][z] = tmpcd * (1.0 / dvec)

                dsum += dtotal
                Nact += 1

    print("#End LDP")
    print(dsum)
    print(Nact)

    mrc_dest.dsum = dsum
    mrc_dest.Nact = Nact
    if Nact == 0:
        raise ValueError("fastVEC: no destination voxel received density; check thresholds and grid placement.")
    mrc_dest.ave = dsum / float(Nact)

    positive = mrc_dest.dens[mrc_dest.dens > 0]
    mrc_dest.std = np.linalg.norm(positive)
    mrc_dest.std_norm_ave = np.linalg.norm(positive - mrc_dest.ave)

    print("#MAP AVE={ave} STD={std} STD_norm={std_norm}".format(ave=mrc_dest.ave, std=mrc_dest.std,
                                                                std_norm=mrc_dest.std_norm_ave))
    return mrc_dest

In [None]:
@numba.jit(nopython=True)
def rot_pos_mtx(mtx, vec):
    """Return the matrix product ``vec @ mtx``, computed in float32.

    Both operands are cast to float32 first. Each row of ``vec`` is multiplied
    on the right by ``mtx``.
    """
    return vec.astype(np.float32) @ mtx.astype(np.float32)

In [None]:
def rot_mrc(orig_mrc_data, orig_mrc_vec, angle):
    """Rotate a cubic density grid and its vector field about the grid centre.

    Each destination voxel is inverse-mapped into the source grid. The source
    index is its position truncated toward zero (all in-bounds positions are
    non-negative). The density there is copied, and the vector there is rotated
    with the forward matrix. Destination voxels whose source lies out of bounds,
    or on zero density, stay zero.

    Args:
        orig_mrc_data: density array; its first three dims are assumed to match
            ``orig_mrc_vec`` (TODO confirm for non-cubic inputs).
        orig_mrc_vec: per-voxel 3-vector array of shape (dim, dim, dim, 3).
        angle: three Euler angles in degrees, "ZYX" convention (scipy).

    Returns:
        (new_vec_array, new_data_array), shaped and typed like the inputs.
    """
    dim = orig_mrc_vec.shape[0]

    # Integer index of every destination voxel. Keep these for the final
    # scatter instead of recovering them from the centred float positions:
    # truncating k - cent toward zero breaks for odd `dim`.
    grid_idx = np.array(np.meshgrid(np.arange(dim), np.arange(dim), np.arange(dim),)).T.reshape(-1, 3)

    cent = 0.5 * float(dim)
    new_pos = grid_idx - cent

    r = R.from_euler("ZYX", angle, degrees=True)
    mtx = r.as_matrix()
    mtx[np.isclose(mtx, 0, atol=1e-15)] = 0

    # Inverse-map destination positions back to source positions.
    old_pos = rot_pos_mtx(np.flip(mtx).T, new_pos) + cent

    in_bound_mask = np.all((old_pos >= 0) & (old_pos < dim), axis=1)
    src_idx = old_pos[in_bound_mask].astype(np.int32)
    dst_idx = grid_idx[in_bound_mask]

    # Keep only voxels whose source density is non-zero.
    src_dens = orig_mrc_data[src_idx[:, 0], src_idx[:, 1], src_idx[:, 2]]
    non_zero = src_dens != 0.0
    src_idx = src_idx[non_zero]
    dst_idx = dst_idx[non_zero]
    non_zero_dens = src_dens[non_zero]

    non_zero_vec = orig_mrc_vec[src_idx[:, 0], src_idx[:, 1], src_idx[:, 2]]
    new_vec = rot_pos_mtx(np.flip(mtx), non_zero_vec)

    new_vec_array = np.zeros_like(orig_mrc_vec)
    new_data_array = np.zeros_like(orig_mrc_data)
    new_vec_array[dst_idx[:, 0], dst_idx[:, 1], dst_idx[:, 2]] = new_vec
    new_data_array[dst_idx[:, 0], dst_idx[:, 1], dst_idx[:, 2]] = non_zero_dens

    return new_vec_array, new_data_array

In [None]:
def ang_to_mtx_ZYX(angle):
    """Return a float32 matrix derived from "ZYX" Euler angles (degrees).

    The scipy rotation matrix has entries within 1e-15 of zero snapped to 0.
    It is then flipped along both axes and transposed.
    """
    rot_matrix = R.from_euler("ZYX", angle, degrees=True).as_matrix()
    rot_matrix = np.where(np.isclose(rot_matrix, 0, atol=1e-15), 0, rot_matrix)
    return np.flip(rot_matrix).T.astype(np.float32)

In [None]:
# find the best translation based on list of fft results
def find_best_trans_list(input_list):
    """Sum equally-shaped score grids and locate the maximum of the sum.

    Returns:
        (best, trans): the largest summed value and its index tuple.
    """
    total = sum(input_list, np.zeros_like(input_list[0]))
    peak_flat = np.argmax(total)
    return np.amax(total), np.unravel_index(peak_flat, total.shape)

In [None]:
def get_score(
    target_map_data, search_map_data, target_map_vec, search_map_vec, trans, ave1, ave2, std1, std2, pstd1, pstd2
):
    """Print overlap / CC / PCC for one placement and return per-voxel vector scores.

    The search map is placed by the voxel shift ``trans``. Components above
    ``dim / 2`` are wrapped to negative shifts. Target voxel p is compared with
    search voxel p + t wherever p + t lies inside the grid.

    Printed values:
        * sums of the clipped (non-negative) target and search densities;
        * Overlap = Nm / total, where Nm counts voxels positive in both maps
          and total counts the union of non-zero voxels;
        * CC = sum(d1 * d2) / (std1 * std2);
        * PCC = sum((d1 - ave1)(d2 - ave2)) / (pstd1 * pstd2), computed over
          positive voxels only;
        * Score = sum of the per-voxel vector dot products.

    Returns:
        Array shaped like ``search_map_data`` holding, at each compared search
        voxel, the dot product of the target and search vectors (zero elsewhere).
    """
    dim = target_map_data.shape[0]

    # FFT shifts wrap around: anything past half the box is a negative shift.
    t = np.array(trans)
    for k in range(3):
        if trans[k] > 0.5 * dim:
            t[k] -= dim

    target_pos = np.array(np.meshgrid(np.arange(dim), np.arange(dim), np.arange(dim),)).T.reshape(-1, 3)

    # Every non-zero target voxel belongs to the union, shifted or not.
    total = np.count_nonzero(target_map_data[target_pos[:, 0], target_pos[:, 1], target_pos[:, 2]])

    search_pos = target_pos + t
    in_bounds = np.all((search_pos >= 0) & (search_pos < dim), axis=1)
    target_pos = target_pos[in_bounds]
    search_pos = search_pos[in_bounds]

    target_vals = target_map_data[target_pos[:, 0], target_pos[:, 1], target_pos[:, 2]]
    search_vals = search_map_data[search_pos[:, 0], search_pos[:, 1], search_pos[:, 2]]

    d1 = np.where(target_vals <= 0, 0.0, target_vals)
    d2 = np.where(search_vals <= 0, 0.0, search_vals)

    print(np.sum(d1))
    print(np.sum(d2))

    pd1 = np.where(d1 <= 0, 0.0, d1 - ave1)
    pd2 = np.where(d2 <= 0, 0.0, d2 - ave2)

    cc = np.sum(d1 * d2)
    pcc = np.sum(pd1 * pd2)

    target_zero_mask = target_vals == 0
    target_non_zero_mask = target_vals > 0
    search_non_zero_mask = search_vals > 0

    # Add search voxels that fall on empty target voxels to complete the union.
    total += np.count_nonzero(target_zero_mask & search_non_zero_mask)
    Nm = np.count_nonzero(target_non_zero_mask & search_non_zero_mask)

    trimmed_target_vec = target_map_vec[target_pos[:, 0], target_pos[:, 1], target_pos[:, 2]]
    trimmed_search_vec = search_map_vec[search_pos[:, 0], search_pos[:, 1], search_pos[:, 2]]

    sco_arr = np.zeros_like(search_map_data)
    sco_arr[search_pos[:, 0], search_pos[:, 1], search_pos[:, 2]] = np.einsum(
        "ij,ij->i", trimmed_target_vec, trimmed_search_vec
    )
    sco_sum = np.sum(sco_arr)

    overlap = float(Nm) / float(total) if total else 0.0

    print(
        "Overlap= "
        + str(overlap)
        + " "
        + str(Nm)
        + "/"
        + str(total)
        + " CC= "
        + str(cc / (std1 * std2))
        + " PCC= "
        + str(pcc / (pstd1 * pstd2))
    )
    print("Score=", sco_sum)
    return sco_arr

In [None]:
def fft_search_score_trans(target_X, target_Y, target_Z, search_vec, a, b, c, fft_object, ifft_object):
    """Best translation of a vector field against precomputed target spectra.

    Each component of ``search_vec`` (last axis) is correlated with the
    matching target spectrum via ``fft_search_best_dot``. The three
    correlation grids are summed, and the peak is located.

    Returns:
        (best, trans) as produced by ``find_best_trans_list``.
    """
    # Reuse the shared FFT helper instead of repeating the same
    # copy -> FFT -> multiply -> inverse-FFT sequence three times.
    query_list = [copy.deepcopy(search_vec[..., k]) for k in range(3)]
    dot_list = fft_search_best_dot([target_X, target_Y, target_Z], query_list, a, b, c, fft_object, ifft_object)
    return find_best_trans_list(dot_list)

In [None]:
def fft_search_best_dot(target_list, query_list, a, b, c, fft_object, ifft_object):
    """Correlate each real query grid with its paired target spectrum via FFT.

    For each (target spectrum, query) pair:
      1. copy the query into ``a``;
      2. take ``fft_object(a)``;
      3. multiply by the target spectrum;
      4. copy the product into ``b``;
      5. copy ``ifft_object(b)`` into a fresh array shaped and typed like the
         query.

    The FFT output is copied out before ``b`` is overwritten. This matters if
    the FFT object returns a buffer it reuses (``search_map_fft`` builds
    ``fft_object`` with ``b`` as its output).

    Args:
        target_list: complex target spectra (already conjugated by the caller).
        query_list: real-space query grids, one per target spectrum.
        a, b: buffers fed to ``fft_object`` and ``ifft_object`` respectively.
        c: unused.
        fft_object, ifft_object: callables performing the forward / inverse FFT.

    Returns:
        List of real correlation grids, one per pair.
    """
    def _correlate(target_complex, query_real):
        np.copyto(a, query_real)
        spectrum = np.zeros_like(target_complex)
        np.copyto(spectrum, fft_object(a))
        np.copyto(b, target_complex * spectrum)
        real_out = np.zeros_like(query_real)
        np.copyto(real_out, ifft_object(b))
        return real_out

    return [_correlate(t_spec, q_grid) for t_spec, q_grid in zip(target_list, query_list)]

In [None]:
def fft_search_score_trans_1d(target_X, search_data, a, b, fft_object, ifft_object, mode, ave=None):
    """Best translation of a scalar search map against one target spectrum.

    The search map is first transformed according to ``mode``:
        "Overlap":   1 where density > 0, else 0;
        "CC":        density where > 0, else 0;
        "PCC":       density - ``ave`` where > 0, else 0;
        "Laplacian": convolution with the 3-D 6-neighbour Laplacian stencil
                     (zero padding at the borders).
    The result is correlated with ``target_X`` via ``fft_search_best_dot``,
    and the peak is located with ``find_best_trans_list``.

    Returns:
        (best, trans) as produced by ``find_best_trans_list``.

    Raises:
        ValueError: for an unknown ``mode``, or ``mode == "PCC"`` without ``ave``.
    """
    if mode == "Overlap":
        query = np.where(search_data > 0, 1.0, 0.0)
    elif mode == "CC":
        query = np.where(search_data > 0, search_data, 0.0)
    elif mode == "PCC":
        if ave is None:
            raise ValueError('mode "PCC" requires the map average `ave`')
        query = np.where(search_data > 0, search_data - ave, 0.0)
    elif mode == "Laplacian":
        weights = np.array(
            [
                [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
                [[0.0, 1.0, 0.0], [1.0, -6.0, 1.0], [0.0, 1.0, 0.0]],
                [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
            ]
        )
        query = convolve(search_data, weights, mode="constant")
    else:
        # Previously an unknown mode silently correlated the raw density.
        raise ValueError("unsupported mode for 1-D search: {mode!r}".format(mode=mode))

    dot_list = fft_search_best_dot([target_X], [query], a, b, None, fft_object, ifft_object)
    return find_best_trans_list(dot_list)

In [None]:
# Modes accepted by search_map_fft.
_SEARCH_MODES = ("VecProduct", "Overlap", "CC", "PCC", "Laplacian")

# 3-D 6-neighbour discrete Laplacian stencil for the "Laplacian" mode.
_SEARCH_LAPLACIAN_KERNEL = np.array(
    [
        [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
        [[0.0, 1.0, 0.0], [1.0, -6.0, 1.0], [0.0, 1.0, 0.0]],
        [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
    ]
)


def _search_angle_grid(ang):
    """Return all Euler-angle triples sampled every ``ang`` degrees.

    The first two angles take values 0, ang, 2*ang, ... below 360. The third
    takes values up to and including 180.
    """
    if ang <= 0:
        raise ValueError("ang must be positive, got {ang}".format(ang=ang))
    full_turn = []
    value = 0
    while value < 360:
        full_turn.append(value)
        value += ang
    half_turn = []
    value = 0
    while value <= 180:
        half_turn.append(value)
        value += ang
    return np.array(np.meshgrid(full_turn, full_turn, half_turn)).T.reshape(-1, 3)


def _search_target_spectra(mrc_target, mode):
    """Return (conjugated real-FFT spectra of the target for ``mode``, grid shape).

    "VecProduct" yields three spectra, one per vector component. Every other
    mode yields a single spectrum of a scalar map derived from ``mrc_target.data``.
    """
    if mode == "VecProduct":
        components = [copy.deepcopy(mrc_target.vec[:, :, :, k]) for k in range(3)]
    elif mode == "Overlap":
        components = [np.where(mrc_target.data > 0, 1.0, 0.0)]
    elif mode == "CC":
        components = [np.where(mrc_target.data > 0, mrc_target.data, 0.0)]
    elif mode == "PCC":
        components = [np.where(mrc_target.data > 0, mrc_target.data - mrc_target.ave, 0.0)]
    else:  # "Laplacian"
        components = [convolve(mrc_target.data, _SEARCH_LAPLACIAN_KERNEL, mode="constant")]
    spectra = [np.conj(np.fft.rfftn(component)) for component in components]
    return spectra, components[0].shape


def _search_rotate_all(mrc_search, angles, rot_vec_dict, rot_data_dict):
    """Rotate the search map to every angle not yet cached, using a thread pool.

    Results are stored in ``rot_vec_dict`` / ``rot_data_dict`` keyed by
    ``tuple(angle)``.
    """
    pending = [angle for angle in angles if tuple(angle) not in rot_vec_dict]
    # os.cpu_count() may return None.
    workers = (os.cpu_count() or 1) + 4
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
        futures = {executor.submit(rot_mrc, mrc_search.data, mrc_search.vec, angle): angle for angle in pending}
        for future in concurrent.futures.as_completed(futures):
            rot_vec, rot_data = future.result()
            key = tuple(futures[future])
            rot_vec_dict[key] = rot_vec
            rot_data_dict[key] = rot_data


def _search_score(mode, spectra, rot_vec, rot_data, mrc_target, mrc_search, a, b, c, fft_object, ifft_object):
    """Return (best, trans) for one rotated search map against the target spectra.

    For CC / PCC, ``best`` is divided by the product of the target and search
    norms, matching the normalisation used by ``get_score``.
    """
    if mode == "VecProduct":
        query_list = [copy.deepcopy(rot_vec[..., k]) for k in range(3)]
        dot_list = fft_search_best_dot(spectra, query_list, a, b, c, fft_object, ifft_object)
        return find_best_trans_list(dot_list)

    # PCC centres the search map on its own average, as get_score does with ave2.
    best, trans = fft_search_score_trans_1d(
        spectra[0], rot_data, a, b, fft_object, ifft_object, mode, mrc_search.ave
    )
    if mode == "CC":
        best = best / (mrc_target.std * mrc_search.std)
    elif mode == "PCC":
        best = best / (mrc_target.std_norm_ave * mrc_search.std_norm_ave)
    return best, trans


def search_map_fft(mrc_target, mrc_search, TopN=10, ang=30, mode="VecProduct", is_eval_mode=False, save_path="."):
    """Exhaustive rotational search of ``mrc_search`` against ``mrc_target`` using FFTs.

    For every Euler-angle triple on a grid with spacing ``ang``:
      * the search map is rotated with ``rot_mrc``;
      * it is correlated with the target for all translations at once via FFT;
      * the best translation is kept.

    When ``ang > 5`` the ``TopN`` rotations are refined on a +/-5 degree grid.
    The final ``TopN`` placements are scored with ``get_score`` and written as
    ``model_top_<i>.pdb`` in ``save_path``.

    Args:
        mrc_target, mrc_search: maps on the same cubic grid. Each needs data,
            vec, ave, std and std_norm_ave populated (e.g. by ``fastVEC``).
        TopN: number of best rotations kept and written.
        ang: angular sampling step in degrees (must be positive).
        mode: one of "VecProduct", "Overlap", "CC", "PCC", "Laplacian".
        is_eval_mode: if true, only print usage hints and return None.
        save_path: directory for the PDB output files.

    Returns:
        List of ``[angle_tuple, score, trans, rot_vec, rot_data]`` entries,
        best first, or ``None`` in evaluation mode.

    Raises:
        ValueError: for an unknown ``mode`` or a non-positive ``ang``.
    """
    time_start = time.time()

    if is_eval_mode:
        print("#For Evaluation Mode")
        print("#Please use the same coordinate system and map size for map1 and map2.")
        print("#Example:")
        print("#In Chimera command line: open map1 and map2 as #0 and #1, then type")
        print("#> open map1.mrc")
        print("#> open map2.mrc")
        print("#> vop #1 resample onGrid #0")
        print("#> volume #2 save new.mrc")
        print("#Chimera will generate the resampled map2.mrc as new.mrc")
        return

    if mode not in _SEARCH_MODES:
        raise ValueError("unknown mode {mode!r}; expected one of {modes}".format(mode=mode, modes=_SEARCH_MODES))

    spectra, grid_shape = _search_target_spectra(mrc_target, mode)

    # The backward FFTW plan is unnormalised, so scores are scaled by 1 / N^3.
    rd3 = 1.0 / mrc_target.xdim ** 3

    angle_comb = _search_angle_grid(ang)

    rot_vec_dict = {}
    rot_data_dict = {}
    _search_rotate_all(mrc_search, angle_comb, rot_vec_dict, rot_data_dict)

    time_rot = time.time()
    print("Rotation time: " + str(time_rot-time_start))

    # FFTW plans sharing aligned buffers: a (real) -> b (complex) -> c (real).
    a = pyfftw.empty_aligned(grid_shape, dtype="float32")
    b = pyfftw.empty_aligned((a.shape[0], a.shape[1], a.shape[2] // 2 + 1), dtype="complex64")
    c = pyfftw.empty_aligned(grid_shape, dtype="float32")

    fft_object = pyfftw.FFTW(a, b, axes=(0, 1, 2))
    ifft_object = pyfftw.FFTW(b, c, direction="FFTW_BACKWARD", axes=(0, 1, 2), normalise_idft=False)

    def score_angle(angle):
        """Return ``[angle_tuple, scaled best score, trans]`` for one cached rotation."""
        key = tuple(angle)
        best, trans = _search_score(
            mode, spectra, rot_vec_dict[key], rot_data_dict[key],
            mrc_target, mrc_search, a, b, c, fft_object, ifft_object,
        )
        return [key, best * rd3, trans]

    angle_score = [score_angle(angle) for angle in tqdm(angle_comb, desc="FFT Process")]

    score_arr = np.array([row[1] for row in angle_score])
    ave = np.mean(score_arr)
    std = np.std(score_arr)
    print("Std= " + str(std) + " Ave= " + str(ave))

    sorted_topN = sorted(angle_score, key=lambda x: x[1], reverse=True)[:TopN]
    for x in sorted_topN:
        print(x)

    time_fft = time.time()
    print("FFT time: " + str(time_fft-time_rot))

    if ang > 5.0 and sorted_topN:
        # Refine each top hit on a 3x3x3 grid of -5 / 0 / +5 degree offsets.
        refine_ang_arr = np.concatenate(
            [
                np.array(np.meshgrid(*[[v - 5, v, v + 5] for v in hit[0]])).T.reshape(-1, 3)
                for hit in sorted_topN
            ],
            axis=0,
        )
        print(refine_ang_arr.shape)

        _search_rotate_all(mrc_search, refine_ang_arr, rot_vec_dict, rot_data_dict)

        refined_score = []
        for angle in tqdm(refine_ang_arr, desc="Refine FFT Process"):
            key, sco, trans = score_angle(angle)
            refined_score.append([key, sco, trans, rot_vec_dict[key], rot_data_dict[key]])
        refined_list = sorted(refined_score, key=lambda x: x[1], reverse=True)[:TopN]
    else:
        # Attach the rotated vec/data so every entry has the same 5-field layout
        # that the output loop below reads (indices 3 and 4).
        refined_list = [
            [key, sco, trans, rot_vec_dict[key], rot_data_dict[key]]
            for key, sco, trans in sorted_topN
        ]

    time_refine = time.time()
    print("Refinement time: " + str(time_refine-time_fft))

    # Score each final placement and write it out as a PDB file.
    for i, t_mrc in enumerate(refined_list):
        print("R=" + str(t_mrc[0]) + " T=" + str(t_mrc[2]))
        sco = get_score(
            mrc_target.data,
            t_mrc[4],
            mrc_target.vec,
            t_mrc[3],
            t_mrc[2],
            mrc_target.ave,
            mrc_search.ave,
            mrc_target.std,
            mrc_search.std,
            mrc_target.std_norm_ave,
            mrc_search.std_norm_ave,
        )
        show_vec(mrc_target.orig, t_mrc[3], t_mrc[4], sco, mrc_search.xwidth, t_mrc[2], "model_top_" + str(i + 1) + ".pdb", save_path)

    time_writefile = time.time()
    print("File Write time: " + str(time_writefile-time_refine))

    return refined_list

In [None]:
def show_vec(origin, sampled_mrc_vec, sampled_mrc_data, sampled_mrc_score, sample_width, trans, file_name, save_path):
    """Write each non-zero voxel as a CA/CB pseudo-atom pair in PDB format.

    For a voxel (x, y, z) with non-zero density:
      * the CA atom sits at ``(x, y, z) * sample_width + offset``;
      * the CB atom sits at ``((x, y, z) + vec) * sample_width + offset``.

    Here ``offset = origin - t * sample_width``, and ``t`` is ``trans`` with
    components above half the grid size wrapped to negative shifts. The
    voxel's score is written in the B-factor column; occupancy is 1.0. Voxels
    are written in x, then y, then z order, and each pair gets its own
    residue number.
    """
    out_path = os.path.join(save_path, file_name)
    n = sampled_mrc_data.shape[0]

    shift = np.array(trans)
    for axis in range(3):
        if 2 * shift[axis] > n:
            shift[axis] -= n

    offset = origin - shift * sample_width

    ca_fmt = "ATOM{:>7d}  CA  ALA{:>6d}    "
    cb_fmt = "ATOM{:>7d}  CB  ALA{:>6d}    "
    coord_fmt = "{:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f}"

    with open(out_path, "w") as handle:
        # argwhere yields indices in C (x, y, z lexicographic) order.
        for res_zero, (x, y, z) in enumerate(np.argwhere(sampled_mrc_data != 0.0)):
            res_id = res_zero + 1
            grid_pt = np.array([x, y, z])
            score = sampled_mrc_score[x][y][z]

            ca = grid_pt * sample_width + offset
            handle.write(ca_fmt.format(2 * res_id - 1, res_id) + coord_fmt.format(ca[0], ca[1], ca[2], 1.0, score) + "\n")

            cb = (grid_pt + sampled_mrc_vec[x][y][z]) * sample_width + offset
            handle.write(cb_fmt.format(2 * res_id, res_id) + coord_fmt.format(cb[0], cb[1], cb[2], 1.0, score) + "\n")

## Change the parameters below before running the cell:

In [None]:
# Mount Google Drive at /content/drive (optional). Only needed when the map
# paths or output_path in the parameter cell point into Drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# ---------------------------------------------------------------------------
# User parameters
# ---------------------------------------------------------------------------
vox_size = 7.0                # Voxel spacing of the resampled grid, default=7.0
sample_ang = 30.0             # Angular sampling step in degrees, default=30.0
dens_thres_target = 0.0       # Density threshold for the target map, default=0.0
dens_thres_query = 0.0        # Density threshold for the query map, default=0.0
gaussian_bandwith = 16.0      # Gaussian filter bandwidth (passed to fastVEC as dreso)
topN = 10                     # Number of top models to refine and write, default=10
match_mode = "VecProduct"     # One of "VecProduct", "Overlap", "CC", "PCC", "Laplacian", default="VecProduct"

mrc_target_path = "/content/drive/MyDrive/Data/emd_8097.mrc"                    # target map
mrc_query_path = "/content/drive/MyDrive/Data/ChainA_simulated_resample.mrc"    # query map
# Directory for the result PDB files. Files saved in the runtime's current
# directory are lost when the Colab runtime disconnects.
output_path = "."

# ---------------------------------------------------------------------------
# Load both maps and derive their resampling grids
# ---------------------------------------------------------------------------
target_mrc_obj = mrc_obj(mrc_target_path)
query_mrc_obj = mrc_obj(mrc_query_path)

target_mrc_obj, mrc_N1 = mrc_set_vox_size(target_mrc_obj, th=dens_thres_target, voxel_size=vox_size)
query_mrc_obj, mrc_N2 = mrc_set_vox_size(query_mrc_obj, th=dens_thres_query, voxel_size=vox_size)

# Both grids must share one cube size. Use the larger edge for both, keep each
# grid centred on its own map, and start from empty arrays.
dim = max(mrc_N1.xdim, mrc_N2.xdim)
for _grid in (mrc_N1, mrc_N2):
    _grid.xdim = _grid.ydim = _grid.zdim = dim
    _grid.orig = _grid.cent - 0.5 * vox_size * dim
    _grid.dens = np.zeros((dim ** 3, 1))
    _grid.vec = np.zeros((dim, dim, dim, 3), dtype="float32")
    _grid.data = np.zeros((dim, dim, dim))
del _grid

# ---------------------------------------------------------------------------
# Gaussian-smoothed resampling, then the rotational FFT search
# ---------------------------------------------------------------------------
mrc_N1 = fastVEC(target_mrc_obj, mrc_N1, dreso=gaussian_bandwith)
mrc_N2 = fastVEC(query_mrc_obj, mrc_N2, dreso=gaussian_bandwith)

score_list = search_map_fft(mrc_N1, mrc_N2, TopN=topN, ang=sample_ang, mode=match_mode, save_path=output_path)

## Visualization of the result (Optional):

In [None]:
!pip install biopython nglview ipywidgets py3dmol

In [None]:
from Bio.PDB import *
import nglview as nv
import ipywidgets

In [None]:
# Enable Colab's custom widget manager so third-party Jupyter widgets can
# render (presumably needed for the nglview / ipywidgets viewer imported above).
from google.colab import output
output.enable_custom_widget_manager()

In [None]:
# Parse each top model and re-save it through Biopython's PDBIO. The cleaned
# copies are written to output_path, which is where the viewer cell reads them.
pdb_file_list = ["model_top_" + str(i + 1) + ".pdb" for i in range(topN)]
file_list = [os.path.join(output_path, pdb_file) for pdb_file in pdb_file_list]

pdb_parser = PDBParser()
# NOTE(review): this global shadows the stdlib module name `io`.
io = PDBIO()

structure_list = []
for idx, file_name in enumerate(file_list):
    structure = pdb_parser.get_structure("PDB_" + str(idx), file_name)
    io.set_structure(structure)
    # Save next to the inputs (previously this went to the CWD while the
    # viewer cell read from output_path).
    io.save(os.path.join(output_path, "PDB_" + str(idx + 1) + "_clean.txt"))
    structure_list.append(structure)

In [None]:
# w = nv.show_biopython(structure_list[0][0])
# w

In [None]:
import py3Dmol

# Overlay the cleaned top-N models in one 3Dmol viewer. The first positional
# argument of py3Dmol.view is the query, so the size is passed by keyword.
view = py3Dmol.view(width=1280, height=720)
for i in range(topN):
    with open(os.path.join(output_path, "PDB_" + str(i + 1) + "_clean.txt"), "r") as pdb_handle:
        view.addModel(pdb_handle.read(), "pdb")
# One style call after loading applies to every model.
view.setStyle({"sphere": {"color": "grey"}})
# Optional: view.addSurface(py3Dmol.MS, {'opacity': 0.5})

view.zoomTo()
view.show()