In [1]:
import numpy as np
from numpy.fft import fft2, ifft2, fftshift
import scipy.signal
import matplotlib.pyplot as plt
from skimage import data, util, transform
from skimage.transform import AffineTransform
from matplotlib import animation, rc
rc('animation', html='html5')
rc('image', cmap='gray', interpolation='nearest')
np.random.seed(10)


In [2]:
size = 128
N = 2*size+1
pad = np.zeros((N, N), np.float64)
q = int((N-size)/2)
center = np.s_[q:q+size, q:q+size]
mask = np.ones((N, N), dtype=bool)
mask[center] = False

def SyntheticBlobs():
    mask = data.binary_blobs(length=pad.shape[0], blob_size_fraction=0.07, volume_fraction=.25, seed=1)
    image = np.random.normal(loc = .7, scale=.003, size=mask.shape)
    image[np.invert(mask)] = .3
    return image

In [3]:
def HanningWindow():
    hann = scipy.signal.hanning(size)
    window = np.zeros_like(pad)
    window[center] = np.sqrt(np.outer(hann, hann))
    return window

In [4]:
class WhiteNoise:
    def __init__(self, variance = 0.0):
        self.variance = variance 
        
    def __call__(self, image):
        return util.random_noise(image, var=self.variance)

In [5]:
class MotionModel:
    def __init__(self):
        self.track = [AffineTransform()]
        
    def __iter__(self):
        return self
    
    def __next__(self):
        xform = self.track[-1] + self.transform()
        self.track.append(xform)
        return xform

class NoDrift(MotionModel):
    def transform(self):
        return AffineTransform()
    
class HorizontalDrift(MotionModel):
    def transform(self):
        return AffineTransform(translation = [1,0])

class VerticalDrift(MotionModel):
    def transform(self):
        return AffineTransform(translation = [0,1]) 

class RandomWalk(MotionModel):
    def transform(self):
        T = np.random.choice(np.arange(-3,3,.1), size=2)
        return AffineTransform(translation=T)


In [6]:
def simulate_capture(image, drift, noise):

    walklen = 16
    transforms = [next(drift) for i in range(walklen)]
    imagelist = [transform.warp(image, T) for T in transforms]    
    imagelist = [noise(image) for image in imagelist]    
    return imagelist

def motion_correct(imagelist):
    mu_image = np.zeros_like(pad)
    frame = np.zeros_like(pad)
    window = HanningWindow()
    
    for epoch in range(3):
        mu_image = sum(imagelist)/len(imagelist)
        mu_Fourier = np.conjugate(fft2(mu_image))
        for i, frame in enumerate(imagelist):      
            Sxm = fft2(window*frame)*mu_Fourier
            Rxm = fftshift(np.absolute(ifft2(Sxm/np.absolute(Sxm))))
            argmax = np.unravel_index(Rxm.argmax(), Rxm.shape)
            argmax = (argmax[1]-N/2+.5, argmax[0]-N/2+.5)
            stabilized = transform.warp(frame, AffineTransform(translation=argmax))
            imagelist[i] = stabilized
            yield (frame, (epoch, i), mu_image, Sxm, Rxm, stabilized)

In [8]:
fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(16,12))
ax = axes.flatten()
  
image_artist = [a.imshow(pad, vmin=0, vmax=1) for a in ax]
image_artist[2] = ax[2].imshow(pad, vmin=0, vmax=30)
image_artist[3] = ax[3].imshow(pad, vmin=-10, vmax=0, zorder=0)
P=32
for a in ax:
    a.axis([P,N-P,P,N-P])
ax[3].axis([N/2-N/10, N/2+N/10, N/2-N/10, N/2+N/10])

for a, title in zip(ax, ['Captured image $I_x$', 'Mean image $I_\mu$', 
                         'Cross Spectrum Magnitude $|S_{x\mu}|$', 
                         'Cross Correlation $R_{x\mu}$',
                         'Stabilized Image', '']):
    a.set_title(title)

text1 = ax[1].text(0, N, "Epoch 0", color='cyan', fontsize=32)
track, = ax[3].plot([], [], color='cyan', lw=1, alpha=.5, zorder=2)
truth, = ax[3].plot(
    [],
    # [pad.shape[0]/2 - x[0] -.5 for x in model.walk], [pad.shape[1]/2 - x[1] -.5 for x in model.walk], 
    color='red', zorder=1, lw=10, alpha=.5)

vmark = ax[3].axvline(N/2, color='yellow', linestyle=':', alpha=1)
hmark = ax[3].axhline(N/2, color='yellow', linestyle=':', alpha=1)

walk_x = []
walk_y = []

#def animate(data):
#    capture, epoch, mu_image, Sxm, Rxm, stabilized = data
def animate((capture, epoch, mu_image, Sxm, Rxm, stabilized)):
    for artist, img in zip(image_artist, 
            [capture, mu_image, fftshift(np.log(np.abs(Sxm))), np.log(Rxm), stabilized, pad]):
        artist.set_array(img)
    argmax = np.unravel_index(Rxm.argmax(), Rxm.shape)
    hmark.set_ydata([argmax[0]])
    vmark.set_xdata([argmax[1]])
    walk_x.append(argmax[0])
    walk_y.append(argmax[1])
    track.set_data(walk_y, walk_x)
    text1.set_text('Epoch {}.{}'.format(*epoch))
    return image_artist
    
imagelist = simulate_capture(SyntheticBlobs(), RandomWalk(), WhiteNoise(0.0001))
stabilized = motion_correct(imagelist)
anim = animation.FuncAnimation(fig, animate, frames=stabilized, blit=True)
anim

SyntaxError: invalid syntax (<ipython-input-8-f460d8c101b6>, line 33)