# In [None]:
def align_images(images, upsample_factor=100):
    """
    Align images to the first one, correct their relative shifts, and return
    the corrected images along with the Fourier-space phase factors.

    Based on Mansi's code.

    Each image is registered against ``images[0]`` with
    ``phase_cross_correlation`` and then shifted in Fourier space with
    ``fourier_shift``. The code is written for 2D images but works for any
    dimensionality those two routines accept.

    Input:
    images : sequence of arrays (typically 2D) containing the same object;
        each must have the same shape as ``images[0]``. Real or complex
        values are accepted.
    upsample_factor : The upsampling factor for subpixel accuracy, forwarded
        to ``phase_cross_correlation`` (default: 100).

    Returns:
    corrected_images : list of the aligned (shift-corrected) images. Real
        input yields real output; complex input keeps its imaginary part.
    shifts : list of shift vectors ((dy, dx) for 2D), one per image, as
        returned by ``phase_cross_correlation``: the shift applied to each
        image to register it with the first image (the first entry is ~0).
    phase_factors : list of complex arrays, one per image, holding the
        multiplier exp(-2*pi*i * sum_k shift[k] * f_k) evaluated on the
        unshifted ``np.fft.fftfreq`` grid. This is the same sign convention
        that ``fourier_shift`` applies.

    Raises:
    IndexError : if ``images`` is empty (there is no reference image).
    """
    # Use the first image as the reference
    ref_image = images[0]

    # Prepare outputs
    corrected_images = []
    shifts = []
    phase_factors = []

    for img in images:
        # Shift that registers `img` onto the reference. The registration
        # error and global phase difference are not needed here.
        shift, _, _ = phase_cross_correlation(
            ref_image, img, upsample_factor=upsample_factor
        )
        shifts.append(shift)

        # Apply the (subpixel) shift correction in Fourier space.
        shifted_image_fft = fourier_shift(np.fft.fftn(img), shift)
        corrected_image = np.fft.ifftn(shifted_image_fft)
        # For real input the inverse FFT is real up to rounding, so drop the
        # numerical imaginary residue. Complex input must keep its imaginary
        # part; taking `.real` unconditionally would corrupt it.
        if not np.iscomplexobj(img):
            corrected_image = corrected_image.real
        corrected_images.append(corrected_image)

        # Phase ramp exp(-2*pi*i * sum_k shift[k] * f_k) on the FFT frequency
        # grid. sparse=True yields one broadcastable axis per dimension, so
        # in 2D this equals shift[0] * y[:, None] + shift[1] * x.
        freq_axes = np.meshgrid(
            *(np.fft.fftfreq(n) for n in np.shape(img)),
            indexing="ij",
            sparse=True,
        )
        phase = sum(s * f for s, f in zip(shift, freq_axes))
        phase_factors.append(np.exp(-2j * np.pi * phase))

    return corrected_images, shifts, phase_factors