In [8]:
import matplotlib.pyplot as plt
import numpy as np
import pyfftw
from pyfftw.interfaces.numpy_fft import fftshift
from pyfftw.interfaces.numpy_fft import ifftshift
from pyfftw.interfaces.numpy_fft import rfftn
from pyfftw.interfaces.numpy_fft import irfftn
from pyfftw.interfaces.numpy_fft import rfft2
from pyfftw.interfaces.numpy_fft import irfft2
import scipy as sp
import seaborn as sns

from pyem.ctf import ctf_freqs
from pyem.ctf import eval_ctf
from pyem import mrc
from pyem.star import parse_star
from pyem.star import calculate_apix
from pyem.util import *
from pyem.vop import *

%matplotlib inline
# Apply seaborn's default styling to the matplotlib figures produced below.
sns.set()
# Turn on pyfftw's interfaces cache so repeated transforms through the
# pyfftw.interfaces.numpy_fft functions imported above can reuse prior FFTW setup.
pyfftw.interfaces.cache.enable()

In [None]:
def radial_sum(data, center=(0,0), r=None):
    """Sum the values of a 2D array over integer-radius rings.

    :param data: 2D array of real values to be summed.
    :param center: ``(x, y)`` = (column, row) coordinates of the ring center.
        Only used when ``r`` is not supplied.
    :param r: Optional precomputed array of non-negative integer ring indices
        with the same shape as ``data``. When given, ``center`` is ignored.
    :return: 1D array whose i-th entry is the sum of ``data`` over ring i
        (length ``r.max() + 1``).
    """
    if r is None:
        y, x = np.indices(data.shape)
        # Distances are truncated (not rounded) to integer ring indices.
        # Builtin int is used because the np.int alias was removed in NumPy 1.24.
        r = np.sqrt((x - center[0])**2 + (y - center[1])**2).astype(int)
    return np.bincount(r.ravel(), data.ravel())

In [None]:
def fourier_ring_correlation(p1, p2, r=None):
    """Complex Fourier ring correlation between two 2D Fourier transforms.

    Rings are centred on column 0, row ``p1.shape[0] // 2``, unless a
    precomputed ring-index array ``r`` is given, in which case it is used as-is.

    :param p1: First 2D complex Fourier transform.
    :param p2: Second 2D complex Fourier transform, same shape as ``p1``.
    :param r: Optional integer ring-index array forwarded to ``radial_sum``.
    :return: 1D complex array of per-ring correlations
        sum(p1 * conj(p2)) / sqrt(sum|p1|^2 * sum|p2|^2); rings with no power
        evaluate to 0/0 (NaN).
    """
    ring_center = (0, p1.shape[0] // 2)

    def ring_total(values):
        # Sum `values` within each ring (bincount requires real weights).
        return radial_sum(values, ring_center, r)

    cross = p1 * np.conj(p2)
    numerator = ring_total(np.real(cross)) + ring_total(np.imag(cross)) * 1j
    denominator = np.sqrt(ring_total(np.abs(p1)**2) * ring_total(np.abs(p2)**2))
    return numerator / denominator

In [2]:
def bincorr(p1, p2, bins):
    """Complex correlation of two arrays within integer-labelled bins.

    :param p1: Complex array (e.g. a 2D Fourier transform).
    :param p2: Complex array with the same number of elements as ``p1``.
    :param bins: Non-negative integer array, same number of elements as ``p1``,
        giving the bin label of each element.
    :return: 1D complex array of length ``bins.max() + 1`` holding, per bin,
        sum(p1 * conj(p2)) / sqrt(sum|p1|^2 * sum|p2|^2). Bins with no power
        evaluate to 0/0 (NaN).
    """
    labels = bins.ravel()
    first = p1.ravel()
    second = p2.ravel()
    cross = first * np.conj(second)

    def per_bin(weights):
        # np.bincount only accepts real weights, so real/imag are summed separately.
        return np.bincount(labels, weights)

    numerator = per_bin(np.real(cross)) + per_bin(np.imag(cross)) * 1j
    norm = np.sqrt(per_bin(np.abs(first)**2) * per_bin(np.abs(second)**2))
    return numerator / norm

In [3]:
def vol_ft(vol, pfac=2, threads=8):
    """Return a centered, Nyquist-limited, zero-padded, interpolation-ready 3D Fourier transform.

    :param vol: Volume to be Fourier transformed. It is converted to float64 and
        passed through ``grid_correct`` first. The pad width is derived from
        ``vol.shape[0]`` alone, so a cubic volume is assumed.
    :param pfac: Zero-padding size factor (also forwarded to ``grid_correct``).
    :param threads: Number of threads forwarded to ``rfftn``.
    :return: complex128 array whose first two dimensions are 3 larger than those
        of the padded volume's ``rfftn``, populated by ``fill_ft``.
    """
    corrected = grid_correct(np.double(vol), pfac=pfac, order=1)
    edge = corrected.shape[0]
    pad_width = (edge * pfac - edge) // 2
    padded = np.pad(corrected, pad_width, "constant")
    # ifftshift moves the box center to index 0 before transforming.
    spectrum = rfftn(ifftshift(padded), padded.shape, threads=threads)
    out_shape = (spectrum.shape[0] + 3, spectrum.shape[1] + 3, spectrum.shape[2])
    out = np.zeros(out_shape, dtype=np.complex128)
    fill_ft(spectrum, out, edge)
    return out

In [4]:
# Load the particle metadata from the RELION STAR file; individual rows are
# used as particles further down (e.g. df.iloc[0]).
df = parse_star("meta/tpc1_data_5stack.star", keep_index=False)

In [None]:
# Inspect the available STAR columns; the rln* fields read below must be present.
df.columns

In [None]:
# Whole map and submap, read without headers using RELION header conventions.
tpc1, vsd2 = (mrc.read(path, inc_header=False, compat="relion")
              for path in ("maps/tpc1_wholemap.mrc", "maps/tpc1_submap.mrc"))

In [None]:
# Padded, interpolation-ready Fourier transforms of both maps (default pfac and threads).
tpc1_ft, vsd2_ft = map(vol_ft, (tpc1, vsd2))

In [None]:
# NOTE(review): fixed test orientations (all-zero angles, and the third Euler
# angle set to 1 degree converted to radians) rather than the particle's own
# rlnAngle* values -- confirm this is intended for the subtraction below.
rot1, rot2 = euler2rot(0, 0, 0), euler2rot(0, 0, np.deg2rad(1))

In [5]:
# Box size in pixels, and a nominal pixel size (apix is overwritten by
# calculate_apix(ptcl) in a later cell).
sz = 256
apix = 1.2156
# Integer radius of every pixel of the half-plane (rfft2) layout, measured from
# column 0, row sz // 2.
center = (0, sz // 2)
y, x = np.indices((sz, sz // 2 + 1))
r = np.sqrt((x - center[0])**2 + (y - center[1])**2)
r = r.astype(int)  # builtin int: the np.int alias was removed in NumPy 1.24
# Move the zero-frequency row to row 0 so r lines up with unshifted FFT output;
# the FFTs themselves are left in standard (unshifted) order.
r = fftshift(r, axes=0)

In [31]:
# Spatial-frequency grids (cycles/pixel) in unshifted half-plane layout:
# rows follow np.fft.fftfreq(sz), columns follow np.fft.rfftfreq(sz).
sy, sx = np.meshgrid(np.fft.fftfreq(sz), np.fft.rfftfreq(sz), indexing="ij")
s = np.sqrt(sx**2 + sy**2)  # radial frequency
a = np.arctan2(sy, sx)      # polar angle of each frequency, in radians

In [32]:
# CTF of the first particle on the half-plane grid; s / apix rescales the
# cycles/pixel frequencies by the particle's pixel size.
ptcl = df.iloc[0]
apix = calculate_apix(ptcl)
c = eval_ctf(
    s / apix,
    a,
    ptcl["rlnDefocusU"],
    ptcl["rlnDefocusV"],
    angast=ptcl["rlnDefocusAngle"],
    # NOTE(review): raises KeyError if the STAR file has no rlnPhaseShift column.
    phase=ptcl["rlnPhaseShift"],
    kv=ptcl["rlnVoltage"],
    ac=ptcl["rlnAmplitudeContrast"],
    cs=ptcl["rlnSphericalAberration"],
    bf=0,
    lp=2,
)

In [None]:
# Origin offsets and raw image of the particle, plus its Fourier transform.
xshift, yshift = ptcl["rlnOriginX"], ptcl["rlnOriginY"]
idx, stack = ptcl["rlnImageName"].split("@")
# RELION "N@stack" image numbers start at 1; NOTE(review): this assumes
# mrc.read_imgs takes a 0-based index -- confirm against pyem.mrc.
p1r = mrc.read_imgs(stack, int(idx) - 1, num=1).squeeze()
# ifftshift puts the box center at index 0 before transforming, the same
# convention vol_ft applies to the volumes. The projected slices compared with
# and subtracted from p1 then share its phase origin, and the later
# fftshift(irfft2(...)) returns a centered image.
p1 = rfft2(ifftshift(p1r), threads=4)

In [None]:
# Subtract the CTF-modulated submap slice from the particle, weighted per ring by
# |FRC| between the particle and the CTF-modulated whole-map slice.
pshift = np.exp(-2 * np.pi * 1j * (xshift * sx + yshift * sy))  # origin shift as a Fourier phase ramp
p2 = interpolate_slice_numba(vsd2_ft, rot1) * pshift  # central slice of the submap transform
p3 = interpolate_slice_numba(tpc1_ft, rot2) * pshift  # central slice of the whole-map transform
# Rings with no power in either input come back from bincorr as 0/0 = NaN.
# Treat them as zero weight, so only the subtracted term is dropped there and the
# particle's own coefficients are kept instead of being zeroed.
frc = np.nan_to_num(np.abs(bincorr(p1, p3 * c, r)))
p1s = p1 - p2 * c * frc[r]
p1s[np.isnan(p1s)] = 0  # guard against NaNs originating in the inputs themselves
# Pass the real-space shape explicitly so odd box sizes round-trip correctly.
new_image = fftshift(irfft2(p1s, s=p1r.shape))

In [None]:
f, ax = plt.subplots(2, 2, figsize=(5,5))
ax[0,0].imshow(p1r)
ax[0,0].set_title("particle", fontsize="small")
ax[0,1].imshow(fftshift(irfft2(p1)) - fftshift(irfft2(p2 * c)))
ax[0,1].set_title("particle - CTF * submap", fontsize="small")
ax[1,0].imshow(irfft2(p1))
ax[1,0].set_title("irfft2(p1)", fontsize="small")
ax[1,1].imshow(new_image)
ax[1,1].set_title("FRC-weighted subtraction", fontsize="small")
# Lay out only after all panels and titles exist so the spacing accounts for them.
f.tight_layout()