In [None]:
import torch
import numpy as np
import cv2
from scipy import interpolate
from basicsr.archs.spynet_arch import SpyNet
from basicsr.utils.registry import ARCH_REGISTRY
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def load_spynet(weight_path):
    """Instantiate SpyNet from the architecture registry and load its weights.

    Args:
        weight_path: Path to a checkpoint readable by ``torch.load``. The
            state dict is taken from the checkpoint's ``'params'`` entry.

    Returns:
        The SpyNet module, with weights loaded on the CPU and switched to
        evaluation mode.
    """
    print("Loading SPyNet model...")
    model_cls = ARCH_REGISTRY.get('SpyNet')
    model = model_cls()
    # map_location='cpu' keeps loading independent of the device the
    # checkpoint was saved from; the caller moves the model afterwards.
    state = torch.load(weight_path, map_location='cpu')
    model.load_state_dict(state['params'])
    model.eval()
    print("SPyNet model loaded successfully.")
    return model

def divide_into_patches(image, patch_size=32):
    """Cut an image into a row-major list of square tiles.

    Args:
        image: Array of shape ``(height, width, channels)``.
        patch_size: Edge length of a tile in pixels.

    Returns:
        A list of array slices of ``image``, ordered left-to-right then
        top-to-bottom. Tiles on the right/bottom border are smaller when the
        image size is not a multiple of ``patch_size``.
    """
    height, width, _ = image.shape
    tiles = [
        image[top:top + patch_size, left:left + patch_size]
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
    print(f"Divided image into {len(tiles)} patches.")
    return tiles

def align_patches(patches, flow, patch_size=32):
    """Warp every patch with the part of the flow field that covers it.

    Args:
        patches: Row-major list of image patches, in the order produced by
            ``divide_into_patches`` for the same ``patch_size``.
        flow: Array of shape ``(height, width, 2)`` holding the flow for the
            full image the patches were cut from.
        patch_size: Edge length used when the patches were cut.

    Returns:
        A list with one warped patch per input patch, in the same order.
    """
    _, w, _ = flow.shape
    # Ceiling division: divide_into_patches also emits a narrower patch for a
    # trailing partial column, so the grid has ceil(w / patch_size) columns.
    # Floor division mis-indexed every patch after the first row (and divided
    # by zero when w < patch_size).
    patches_per_row = max(1, -(-w // patch_size))
    aligned_patches = []
    for i, patch in enumerate(patches):
        row, col = divmod(i, patches_per_row)
        y, x = row * patch_size, col * patch_size
        flow_patch = flow[y:y+patch_size, x:x+patch_size]
        aligned_patches.append(warp_patch(patch, flow_patch))
    print(f"Aligned {len(aligned_patches)} patches.")
    return aligned_patches

def reconstruct_image(patches, image_shape):
    """Reassemble a row-major list of patches into a full image.

    Args:
        patches: Patches in the order produced by ``divide_into_patches``
            (left-to-right, then top-to-bottom). Border patches may be
            smaller than the others.
        image_shape: ``(height, width, channels)`` of the image to rebuild.

    Returns:
        A ``float32`` array of shape ``image_shape`` with every patch written
        back at its grid position. An empty patch list yields an all-zero
        image.
    """
    _, w, _ = image_shape
    reconstructed = np.zeros(image_shape, dtype=np.float32)
    if len(patches) == 0:
        return reconstructed

    # The first patch is full-size along every axis on which the image is at
    # least one patch long, so its larger edge recovers the tiling size even
    # when the image is shorter (or narrower) than a single patch.
    patch_size = max(patches[0].shape[:2])
    # Ceiling division: a trailing partial column still holds a patch, so
    # floor division would map patches to the wrong grid cells.
    patches_per_row = max(1, -(-w // patch_size))

    for i, patch in enumerate(patches):
        row, col = divmod(i, patches_per_row)
        y, x = row * patch_size, col * patch_size
        reconstructed[y:y+patch_size, x:x+patch_size] = patch
    print("Reconstructed image from aligned patches.")
    return reconstructed

def warp_image(image, flow):
    """Resample an image after displacing each pixel by its flow vector.

    Every pixel at ``(row, col)`` is moved to
    ``(row + flow[..., 1], col + flow[..., 0])`` and the scattered samples are
    linearly interpolated back onto the regular pixel grid. Grid positions
    outside the convex hull of the displaced samples are filled with 0.

    Args:
        image: Array of shape ``(height, width, channels)``; any number of
            channels is supported.
        flow: Array of shape ``(height, width, 2)`` whose channel 0 is the
            displacement along columns and channel 1 along rows.

    Returns:
        An array with the same shape and dtype as ``image``.
    """
    h, w = flow.shape[:2]
    # (row, col) coordinate of every pixel, shape (h, w, 2).
    destination = np.moveaxis(np.indices((h, w)), 0, -1)
    # Reorder the flow channels to (row, col) to match `destination`.
    displacement = np.stack((flow[..., 1], flow[..., 0]), axis=-1)
    source = (destination + displacement).reshape(-1, 2)

    warped = np.zeros_like(image)
    # Iterate over however many channels the image has instead of assuming 3,
    # so 1-channel and 4-channel (e.g. RGGB) inputs are warped completely.
    for c in range(image.shape[2]):
        warped[..., c] = interpolate.griddata(
            source, image[..., c].ravel(), destination,
            method='linear', fill_value=0,
        )

    print("Warped image based on optical flow.")
    return warped

def warp_patch(patch, flow_patch):
    """Warp a single patch with its matching flow window.

    Delegates to ``warp_image``; a patch is handled exactly like a small
    image.
    """
    warped = warp_image(patch, flow_patch)
    return warped

def visualize_patches(image, patch_size=32):
    """Display an image with the patch grid drawn on top in red.

    Args:
        image: Array of shape ``(height, width, channels)`` to show.
        patch_size: Edge length of a grid cell in pixels.
    """
    _, axis = plt.subplots(1, figsize=(10, 10))
    axis.imshow(image)
    height, width, _ = image.shape
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Rectangle takes its anchor as (x, y), i.e. (column, row).
            outline = patches.Rectangle(
                (left, top),
                patch_size,
                patch_size,
                linewidth=1,
                edgecolor='r',
                facecolor='none',
            )
            axis.add_patch(outline)
    plt.title('Image with Patch Grid')
    plt.show()

def visualize_optical_flow(flow):
    """Plot the two flow components and a colour-coded view of the flow.

    Args:
        flow: Array of shape ``(height, width, 2)``; channel 0 is shown as
            "Flow X" and channel 1 as "Flow Y".

    The colour-coded panel maps flow direction to hue and flow magnitude
    (min-max normalised to 0-255) to brightness, at full saturation.
    """
    # Compute magnitude and angle (angle is in radians)
    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])

    # Create HSV image
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    # Hue is direction: degrees halved so 0-360 fits OpenCV's 8-bit hue range
    hsv[..., 0] = ang * 180 / np.pi / 2
    hsv[..., 1] = 255  # Full saturation
    hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)  # Value is magnitude

    # Convert to RGB, the channel order matplotlib displays
    flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    plt.figure(figsize=(15, 5))

    plt.subplot(131)
    plt.imshow(flow[..., 0])
    plt.title('Flow X')
    plt.colorbar()

    plt.subplot(132)
    plt.imshow(flow[..., 1])
    plt.title('Flow Y')
    plt.colorbar()

    plt.subplot(133)
    plt.imshow(flow_rgb)
    plt.title('Flow (Color-coded)')

    plt.show()

def patch_alignment(image_burst, spynet, device='cuda'):
    """Align every image of a burst to the middle image using SPyNet flow.

    For each image the optical flow between the reference and that image is
    estimated, the image is cut into patches, each patch is warped with its
    part of the flow, the patches are stitched back together, and the
    stitched image is warped once more with the full flow field.
    Intermediate results are plotted along the way.

    Args:
        image_burst: Non-empty sequence of ``(height, width, channels)``
            numpy images. The image at index ``len(image_burst) // 2`` is the
            reference.
        spynet: Flow network called as ``spynet(reference, image)`` on
            ``(1, channels, height, width)`` float tensors.
        device: Device the input tensors are moved to; the model is expected
            to already live there.

    Returns:
        A list with one warped image per burst image, in burst order.
    """
    aligned_images = []
    reference_index = len(image_burst) // 2  # Use middle image as reference
    reference_image = image_burst[reference_index]
    print(f"Using image {reference_index} as reference.")

    # Visualize patches on reference image
    visualize_patches(reference_image)

    # The reference never changes, so convert it once instead of once per
    # burst image.
    # NOTE(review): pixel values are passed to the network unscaled; confirm
    # whether this SpyNet expects inputs in [0, 1].
    ref_tensor = torch.from_numpy(reference_image).permute(2, 0, 1).float().unsqueeze(0).to(device)

    for i, image in enumerate(image_burst):
        print(f"\nProcessing image {i+1}/{len(image_burst)}")

        # Convert image to a PyTorch tensor and move to device
        img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().unsqueeze(0).to(device)

        # Calculate optical flow using SPyNet
        print("Calculating optical flow...")
        with torch.no_grad():
            flow = spynet(ref_tensor, img_tensor)

        # Convert flow back to numpy for further processing. Only the batch
        # axis is squeezed so a height or width of 1 is not dropped as well.
        flow_np = flow.squeeze(0).permute(1, 2, 0).cpu().numpy()

        # Visualize optical flow
        visualize_optical_flow(flow_np)

        # Divide image into patches
        image_patches = divide_into_patches(image)

        # Align patches based on optical flow
        aligned_patches = align_patches(image_patches, flow_np)

        # Reconstruct aligned image from patches
        aligned_image = reconstruct_image(aligned_patches, image.shape)

        # Warp the aligned image
        # NOTE(review): the patches were already warped with this same flow,
        # so the flow is applied twice here; confirm that is intended.
        warped_image = warp_image(aligned_image, flow_np)

        aligned_images.append(warped_image)

        # Display intermediate results
        plt.figure(figsize=(20, 10))
        plt.subplot(231)
        plt.imshow(image)
        plt.title('Original Image')
        plt.subplot(232)
        plt.imshow(aligned_image)
        plt.title('Aligned Image (Before Warping)')
        plt.subplot(233)
        plt.imshow(warped_image)
        plt.title('Warped Image')

        # Display some sample patches
        for j, sample in enumerate(image_patches[:3]):
            plt.subplot(2, 3, 4 + j)
            plt.imshow(sample)
            plt.title(f'Sample Patch {j+1}')

        plt.tight_layout()
        plt.show()

    return aligned_images

if __name__ == "__main__":
    # Load SPyNet model.
    # The path is a raw string: in a normal literal the "\U" of "C:\Users"
    # starts a Unicode escape and makes the whole file a SyntaxError.
    spynet = load_spynet(r'C:\Users\Arnav\Desktop\Image SuperResolution\patchAlignment\spynetWeights.pth')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    spynet = spynet.to(device)
    print(f"Using device: {device}")

    # Load your burst of images
    burst_directory = 'path/to/your/burst/images/'
    image_burst = []

    print("Loading burst images...")
    for i in range(14):  # Adjust the range if you have a different number of images
        image_path = os.path.join(burst_directory, f'burst_{i}.jpg')
        if not os.path.exists(image_path):
            print(f"Warning: Image {image_path} not found.")
            continue
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread reports an unreadable file by returning None.
            print(f"Warning: Image {image_path} could not be read.")
            continue
        image_burst.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    print(f"Loaded {len(image_burst)} images.")

    if not image_burst:
        # Without any image there is no reference frame to align against.
        print("No burst images were loaded; skipping alignment.")
    else:
        # Perform patch alignment
        aligned_images = patch_alignment(image_burst, spynet, device)

        # Save only the final aligned image
        final_aligned_image = aligned_images[-1]
        # Convert once to 8-bit: needed for saving, and matplotlib would clip
        # a float image holding 0-255 values to the [0, 1] range.
        final_aligned_image_u8 = np.clip(final_aligned_image, 0, 255).astype(np.uint8)
        final_aligned_image_bgr = cv2.cvtColor(final_aligned_image_u8, cv2.COLOR_RGB2BGR)
        cv2.imwrite('final_aligned_image.jpg', final_aligned_image_bgr)
        print("Final aligned image saved as 'final_aligned_image.jpg'.")

        # Display final result
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        plt.imshow(image_burst[0])
        plt.title('First Image in Burst')
        plt.subplot(122)
        plt.imshow(final_aligned_image_u8)
        plt.title('Final Aligned Image')
        plt.show()

print("Patch alignment completed.")

In [None]:
import torch
import cv2
import numpy as np


class SyntheticBurstVal(torch.utils.data.Dataset):
    """Synthetic burst validation set.

    The validation bursts have been generated using the same synthetic
    pipeline as employed in the SyntheticBurst dataset. Each burst is stored
    as ``<root>/<index:04d>/im_raw_<image_id:02d>.png``.
    """

    def __init__(self, root):
        """Store the dataset location.

        Args:
            root: Directory that contains one sub-directory per burst.
        """
        self.root = root
        # Burst indices available in the set (0..499).
        self.burst_list = list(range(500))
        # Number of frames read for every burst.
        self.burst_size = 14

    def __len__(self):
        """Return the number of bursts in the set."""
        return len(self.burst_list)

    def _read_burst_image(self, index, image_id):
        """Read one frame of a burst as a float tensor.

        Args:
            index: Index of the burst.
            image_id: Index of the frame inside the burst.

        Returns:
            A float32 tensor in channel-first layout, with the stored values
            divided by ``2**14``.

        Raises:
            FileNotFoundError: If the frame is missing or cannot be decoded.
        """
        path = '{}/{:04d}/im_raw_{:02d}.png'.format(self.root, index, image_id)
        im = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if im is None:
            # cv2.imread signals failure by returning None; fail with the
            # offending path instead of an AttributeError further down.
            raise FileNotFoundError('Could not read burst image: {}'.format(path))
        # astype already yields float32, so no further cast is needed.
        im_t = torch.from_numpy(im.astype(np.float32)).permute(2, 0, 1) / (2**14)
        return im_t

    def __getitem__(self, index):
        """Load all frames of one burst.

        Args:
            index: Index of the burst.

        Returns:
            burst: LR RAW burst, a torch tensor with the frames stacked along
                dimension 0 (documented upstream as shape [14, 4, 48, 48],
                the 4 channels being the 'R', 'G', 'G', and 'B' values of the
                RGGB bayer mosaick).
            burst_name: Name of the burst sequence, the zero-padded index.
        """
        burst_name = '{:04d}'.format(index)
        burst = [self._read_burst_image(index, i) for i in range(self.burst_size)]
        burst = torch.stack(burst, 0)

        return burst, burst_name