In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import configargparse
from aspire.volume import Volume
from aspire.utils.rotation import Rotation as aspire_Rotation
from scipy.spatial.transform import Rotation as scipy_Rotation
import numpy as np
import os
import logging
import time

from cryomap_align.utils import init_config, try_mkdir, center_vol
from cryomap_align.gauss_opt_utils import run_gaussian_opt
from cryomap_align.opt_refinement import run_nelder_mead_refinement

  from .autonotebook import tqdm as notebook_tqdm
2023-12-08 11:14:43,788	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


I developed this module in a way that it would be easy for me to use in the future, so some of the functionalities it has are not useful for the project (like reading all my inputs from config files). For this reason I will have to do things in a non-optimal way below. This does not impact the method as it's just unimportant software stuff.

In [3]:
import sys

sys.argv = ["--config", "config.ini"]
parser = configargparse.ArgumentParser()

parser.add_argument(
    "-c",
    "--config",
    is_config_file=True,
    help="Path to config file.",
    required=True,
)

init_config(parser)
config = parser.parse_args(sys.argv)

# set up logging
config.full_save_path = os.path.join(config.save_path, config.experiment_name)

if not try_mkdir(config.full_save_path):
    raise SystemError("Could not create output directory")

# create experiment directory
logging.captureWarnings(False)

logger = logging.getLogger()
fhandler = logging.FileHandler(
    filename="log.log", mode="a"
)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
fhandler.setFormatter(formatter)

logger.addHandler(fhandler)
logger.setLevel(logging.INFO)

In [4]:
def run_align_test(vol_fnames, n_iter, config, param_setups):

    init_cand = np.eye(3).reshape(1, 3, 3) # start from identity matrix

    results = {
        "true_quats": np.zeros((len(vol_fnames), len(param_setups), n_iter, 4)),
        "optim_quats_wemd": np.zeros((len(vol_fnames), len(param_setups), n_iter, 4)),
        "refin_quats_wemd": np.zeros((len(vol_fnames), len(param_setups), n_iter, 4)),
        "optim_quats_l2": np.zeros((len(vol_fnames), len(param_setups), n_iter, 4)),
        "refin_quats_l2": np.zeros((len(vol_fnames), len(param_setups), n_iter, 4)),
        "run_time_wemd": np.zeros((len(vol_fnames), len(param_setups), n_iter, 2)),
        "run_time_l2": np.zeros((len(vol_fnames), len(param_setups), n_iter, 2)),
    }

    for i, vol_fname in enumerate(vol_fnames):

        vol_ref = Volume.load(vol_fname)
        vol_ref = center_vol(vol_ref, config)
        
        for j, param in enumerate(param_setups):
            
            config.downsample_res = param[0]
            config.max_iter = param[1]
            vol_ref_ds = vol_ref.downsample(config.downsample_res)

            for k in range(n_iter):
                
                # generate random rotation
                true_rot = aspire_Rotation.generate_random_rotations(1)
                results["true_quats"][i, j, k] = scipy_Rotation.from_matrix(true_rot._matrices[0]).as_quat()
                vol = vol_ref.rotate(true_rot).downsample(config.downsample_res)

                # run for wemd
                config.loss_type = "wemd"
                config.corr_length = 0.75
                
                start_time = time.time()
                opt_rot = run_gaussian_opt(vol, vol_ref_ds, init_cand, config)
                opt_time = time.time()
                ref_rot = run_nelder_mead_refinement(vol, vol_ref_ds, opt_rot, config)
                end_time = time.time()

                results["run_time_wemd"][i, j, k] = [opt_time - start_time, end_time - start_time]
                results["optim_quats_wemd"][i, j, k] = scipy_Rotation.from_matrix(opt_rot).as_quat()
                results["refin_quats_wemd"][i, j, k] = scipy_Rotation.from_matrix(ref_rot).as_quat()

                # run for l2
                config.loss_type = "l2"
                config.corr_length = 1.0

                start_time = time.time()
                opt_rot = run_gaussian_opt(vol, vol_ref_ds, init_cand, config)
                opt_time = time.time()
                ref_rot = run_nelder_mead_refinement(vol, vol_ref_ds, opt_rot, config)
                end_time = time.time()

                results["optim_quats_l2"][i, j, k] = scipy_Rotation.from_matrix(opt_rot).as_quat()
                results["refin_quats_l2"][i, j, k] = scipy_Rotation.from_matrix(ref_rot).as_quat()
                results["run_time_l2"][i, j, k] = [opt_time - start_time, end_time - start_time]


                # output results to log
                err_wemd_opt = np.linalg.norm(results["optim_quats_wemd"][i, j, k] - results["true_quats"][i, j, k]) / np.linalg.norm(results["true_quats"][i, j, k])

                err_wemd_ref = np.linalg.norm(results["refin_quats_wemd"][i, j, k] - results["true_quats"][i, j, k]) / np.linalg.norm(results["true_quats"][i, j, k])

                err_l2_opt = np.linalg.norm(results["optim_quats_l2"][i, j, k] - results["true_quats"][i, j, k]) / np.linalg.norm(results["true_quats"][i, j, k])

                err_l2_ref = np.linalg.norm(results["refin_quats_l2"][i, j, k] - results["true_quats"][i, j, k]) / np.linalg.norm(results["true_quats"][i, j, k])

                logging.info(f"Results for volume {vol_fname}, parameters (downsample_res, max_iter) = {param}, iteration {k}:")
                logging.info(f"Error for wemd opt: {err_wemd_opt}")
                logging.info(f"Error for wemd ref: {err_wemd_ref}")
                logging.info(f"Error for l2 opt: {err_l2_opt}")
                logging.info(f"Error for l2 ref: {err_l2_ref}")
                logging.info(f"Time for wemd opt: {results['run_time_wemd'][i, j, k, 0]}")
                logging.info(f"Time for wemd ref: {results['run_time_wemd'][i, j, k, 1]}")
                logging.info(f"Time for l2 opt: {results['run_time_l2'][i, j, k, 0]}")
                logging.info(f"Time for l2 ref: {results['run_time_l2'][i, j, k, 1]}")

    return results

In [5]:
def run_noise_test(vol_fname, n_iter, config, signal_noise_ratios, param_setups):

    init_cand = np.eye(3).reshape(1, 3, 3) # start from identity matrix

    results = {
        "true_quats": np.zeros((len(signal_noise_ratios), len(param_setups), n_iter, 4)),
        "optim_quats_wemd": np.zeros((len(signal_noise_ratios), len(param_setups), n_iter, 4)),
        "refin_quats_wemd": np.zeros((len(signal_noise_ratios), len(param_setups), n_iter, 4)),
        "optim_quats_l2": np.zeros((len(signal_noise_ratios), len(param_setups), n_iter, 4)),
        "refin_quats_l2": np.zeros((len(signal_noise_ratios), len(param_setups), n_iter, 4)),
        "run_time_wemd": np.zeros((len(signal_noise_ratios), len(param_setups), n_iter, 2)),
        "run_time_l2": np.zeros((len(signal_noise_ratios), len(param_setups), n_iter, 2)),
    }

    vol_ref = Volume.load(vol_fname)
    vol_ref = center_vol(vol_ref, config)
    
    for i, snr in enumerate(signal_noise_ratios):
        noise_std = np.sqrt(np.linalg.norm(vol_ref._data[0])**2 / (vol_ref._data.shape[1] ** 3 * snr))

        for j, param in enumerate(param_setups):
            
            config.downsample_res = param[0]
            config.max_iter = param[1]

            for k in range(n_iter):
                
                # generate random rotation
                true_rot = aspire_Rotation.generate_random_rotations(1)
                results["true_quats"][i, j, k] = scipy_Rotation.from_matrix(true_rot._matrices[0]).as_quat()
                vol = vol_ref.rotate(true_rot)

                # add noise
                vol = vol + np.random.normal(0, noise_std, vol._data[0].shape)
                vol_ref_noisy = vol_ref + np.random.normal(0, noise_std, vol_ref._data[0].shape)

                # downsample
                vol = vol.downsample(config.downsample_res)
                vol_ref_noisy = vol_ref_noisy.downsample(config.downsample_res)

                # run for wemd
                config.loss_type = "wemd"
                config.corr_length = 0.75
                
                start_time = time.time()
                opt_rot = run_gaussian_opt(vol, vol_ref_noisy, init_cand, config)
                opt_time = time.time()
                ref_rot = run_nelder_mead_refinement(vol, vol_ref_noisy, opt_rot, config)
                end_time = time.time()

                results["run_time_wemd"][i, j, k] = [opt_time - start_time, end_time - start_time]
                results["optim_quats_wemd"][i, j, k] = scipy_Rotation.from_matrix(opt_rot).as_quat()
                results["refin_quats_wemd"][i, j, k] = scipy_Rotation.from_matrix(ref_rot).as_quat()

                # run for l2
                config.loss_type = "l2"
                config.corr_length = 1.0

                start_time = time.time()
                opt_rot = run_gaussian_opt(vol, vol_ref_noisy, init_cand, config)
                opt_time = time.time()
                ref_rot = run_nelder_mead_refinement(vol, vol_ref_noisy, opt_rot, config)
                end_time = time.time()

                results["optim_quats_l2"][i, j, k] = scipy_Rotation.from_matrix(opt_rot).as_quat()
                results["refin_quats_l2"][i, j, k] = scipy_Rotation.from_matrix(ref_rot).as_quat()
                results["run_time_l2"][i, j, k] = [opt_time - start_time, end_time - start_time]

    return results

In [6]:
%%capture cap

vol_fnames = [
    #"volumes/emd_3683.map.gz",
    #"volumes/emd_25892.map.gz",
    "volumes/emd_9515.map.gz",
    #"volumes/emd_23006.map.gz",
]
param_setups = [
    [32, 150],
    #[32, 200],
    #[64, 150],
    #[64, 200],
]
n_iter = 3

results = run_align_test(vol_fnames, n_iter, config, param_setups)

2023-12-08 11:14:48,086 INFO [aspire.volume.volume] volumes/emd_9515.map.gz with dtype float32 loaded as None
2023-12-08 11:14:48,093 INFO [root] Centering volume
2023-12-08 11:14:48,177 INFO [root] Center of mass: [ 39.8367519  -36.75667459  78.89296519]
2023-12-08 11:14:52,203 INFO [root] Volume centered
2023-12-08 11:16:11,476 INFO [root] Results for volume volumes/emd_9515.map.gz, parameters (downsample_res, max_iter) = [32, 150], iteration 0:
2023-12-08 11:16:11,478 INFO [root] Error for wemd opt: 0.435671570974375
2023-12-08 11:16:11,481 INFO [root] Error for wemd ref: 0.004104328035205957
2023-12-08 11:16:11,485 INFO [root] Error for l2 opt: 0.4427501529981646
2023-12-08 11:16:11,487 INFO [root] Error for l2 ref: 0.038128035295933276
2023-12-08 11:16:11,490 INFO [root] Time for wemd opt: 16.83385944366455
2023-12-08 11:16:11,491 INFO [root] Time for wemd ref: 25.824294328689575
2023-12-08 11:16:11,496 INFO [root] Time for l2 opt: 16.442373752593994
2023-12-08 11:16:11,497 INFO [