In [1]:
import numpy as np
import scipy as sp
from skimage import io
from matplotlib import pyplot as plt
from skimage import transform
from tqdm import tqdm
from time import time
import pickle
import os
import sys
import pandas as pd
from glob import glob
import mrcfile
from skimage import registration as im_reg_met
from pathlib import Path
import warnings
import json

# Suppress every warning category for the rest of the session.
warnings.filterwarnings('ignore')
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Make the helper modules in './compute results/' importable. The entry is a
# relative path, so it depends on the kernel's current working directory.
sys.path.append('./compute results/')
from registration_functions import (
    precompute_w_params, run_fbm,
    run_fbm_laguerre, run_fast_fbm_laguerre,
    fixed_image_precompute 
)
from util_functions import normalize, apply_transform, mean_std_normalize, save_arr_mrc

In [2]:
from pathlib import Path

# Root directory of the synthetic dataset.
global_path = '/home/ubuntu/Data/cryo-em samples/synthetic_data_2023'

# Sub-folders processed in this run.
# Other values previously listed here: 'no_trans', 'big_trans'.
folders = ['small_trans']

# Models processed in this run. Other value previously listed here: 'model6185'.
models = ['model2096']

# (stack suffix, info suffix) pairs.
# Other pair previously listed here: ('_iter_stack.mrc', '_iter_info.csv').
names = [('_stack.mrc', '_info.csv')]

# Numeric file-name prefix of the stack for each model, aligned with `models`.
nums = dict(zip(models, (2,)))

In [3]:
# Root output directory for all registration results.
results_path = Path(global_path) / 'results_compare'
results_path.mkdir(parents=True, exist_ok=True)

# Active method name (alternative previously noted here: 'fast_fbm_laguerre').
method = 'fbm'

# Each method writes into its own sub-directory of the results root.
method_path = results_path / method
method_path.mkdir(parents=True, exist_ok=True)

# Registration routine available under each method name.
functions = {
    'fbm': run_fbm,
    'fbm_laguerre': run_fbm_laguerre,
    'fast_fbm_laguerre': run_fast_fbm_laguerre,
}

func = functions[method]

In [4]:
# Re-select the active method (other accepted keys: 'fbm', 'fast_fbm_laguerre').
# This rebinds `method`, `method_path` and `func` for the cells that follow.
method = 'fbm_laguerre'

method_path = results_path / method
method_path.mkdir(parents=True, exist_ok=True)

func = functions[method]

In [5]:
# Settings shared by every registration method.
image_radius = 120
pixel_sampling = 0.5
com_offset_initial = 10

# Settings that are only forwarded when a Laguerre-based method is selected.
lag_func_num = 50
lag_scale = 5
lag_num_dots = 2000

center = (128, 128)
normalization = 'standard'

# Keyword arguments forwarded to the registration routine. The same mapping is
# what gets serialized alongside the results, so key order is kept stable.
input_values = dict(
    image_radius=image_radius,
    pixel_sampling=pixel_sampling,
    com_offset_initial=com_offset_initial,
    normalization=normalization,
)

if 'laguerre' in method:
    input_values.update(
        lag_func_num=lag_func_num,
        lag_scale=lag_scale,
        lag_num_dots=lag_num_dots,
    )
    

In [6]:
# Build the parameter dictionary that is later handed to the registration
# routine as `func_parameters`; its keys are printed as a quick sanity check.
params = precompute_w_params(
    image_radius=image_radius,
    pixel_sampling=pixel_sampling,
    com_offset_initial=com_offset_initial,
    lag_func_num=lag_func_num,
    lag_scale=lag_scale,
    lag_num_dots=lag_num_dots,
    compute_zeros=False,
)
print(params.keys())

# Run the selected registration routine (`func`) on every per-SNR stack of
# every folder/model pair and save the registered stack plus its info table
# under `method_path / folder / model / <snr>`.
for folder in folders:
    # Store the parameters of this run once per folder. They do not change
    # inside the loops below, so there is no need to rewrite the file after
    # every processed stack.
    folder_out = method_path / folder
    os.makedirs(folder_out, exist_ok=True)
    with open(folder_out / 'params.json', 'w') as file:
        json.dump(input_values, file)

    for model in models:
        model_dir = Path(global_path) / folder / model
        stack_name = f'{nums[model]}_stack.mrc'

        # Stack stored directly in the model folder (presumably the noise-free
        # version -- TODO confirm). Its first frame, Gaussian-smoothed with
        # sigma 1.3 and normalized, is used as the fixed image.
        for path in glob(str(model_dir / stack_name)):
            print(path)
            with mrcfile.open(path) as mrc:
                seq = mrc.data
            fixed_image = normalize(sp.ndimage.gaussian_filter(seq[0].copy(), 1.3))
            # NOTE(review): method='fbm' is hard-coded here even when `method`
            # selects a Laguerre variant -- confirm this is intended.
            # `params` is rebound to the returned value, so it carries over
            # from one model/folder to the next.
            params = fixed_image_precompute(fixed_image,
                                            params, method='fbm',
                                            image_radius=input_values['image_radius'])

        for snr_value in [0.1, 0.5, 1., 2.]:
            snr_dir = model_dir / str(snr_value)
            for data_name in glob(str(snr_dir / stack_name)):
                with mrcfile.open(data_name) as mrc:
                    seq = mrc.data
                df, fbm_seq, fbm_seq_shift = func(seq=seq, func_parameters=params, **input_values)

                out_dir = method_path / folder / model / str(snr_value)
                os.makedirs(out_dir, exist_ok=True)

                # Keep the input file name; only the final path component is
                # rewritten to obtain the matching '*_info.csv' name.
                opath = out_dir / Path(data_name).name
                save_arr_mrc(opath, fbm_seq)
                df.to_csv(str(opath.with_name(opath.name.replace('stack.mrc', 'info.csv'))))


120
s_ang = 753.9822368615503
s_rad = 239.99999999999997
bandwidth = 376.99111843077515
len(Im1) 63 len(Ih1) 21 len(Imm) 378
dict_keys(['integration_intervals', 'alphas', 'laguerre_functions', 'precomputed_c1_coefs', 'precomputed_c2_coefs', 'precomputed_coef_exp'])
/home/ubuntu/Data/cryo-em samples/synthetic_data_2023/small_trans/model2096/2_stack.mrc


100%|███████████████████████████████████████████| 99/99 [06:53<00:00,  4.18s/it]
100%|███████████████████████████████████████████| 99/99 [06:56<00:00,  4.21s/it]
100%|███████████████████████████████████████████| 99/99 [17:31<00:00, 10.62s/it]
100%|███████████████████████████████████████████| 99/99 [15:27<00:00,  9.37s/it]


In [9]:
# Load the info table saved for the selected method at each listed SNR value
# and print its first four rows.
for snr_value in [2.]:
    print('SNR', snr_value)
    info_csv = (
        f'/home/ubuntu/Data/cryo-em samples/synthetic_data_2023/results_model2096_small/{method}/'
        f'small_trans/model2096/{snr_value}/2_info.csv'
    )
    gt_df2 = pd.read_csv(info_csv)
    print(gt_df2.head(4))
    
    

SNR 2.0
   Unnamed: 0         x         y         ang       ksi  eta_prime  \
0           0  1.653695  2.653109   63.704244  1.013417   4.712389   
1           1  0.807621  1.295709   36.966844  1.013417   3.769911   
2           2 -0.795667 -3.073046  159.194960  4.459035   1.884956   
3           3  2.088625 -0.896299  197.391247  5.877819   2.199115   

   omega_prime  dft_x  dft_y  
0     1.033309   0.93   5.49  
1     0.566653   2.34   0.15  
2     2.699936   0.37  -4.76  
3     3.366587   5.24   1.26  
