In [1]:
import torch
import numpy as np
from basicsr.archs.spynet_arch import SpyNet
from basicsr.utils.registry import ARCH_REGISTRY
import matplotlib.pyplot as plt
from PIL import Image
import os

def _extract_state_dict(checkpoint):
    """Return the parameter dict stored in a SPyNet checkpoint.

    basicsr-style checkpoints wrap the weights, e.g. ``{'params': state_dict}``;
    some store EMA weights under ``'params_ema'`` instead. ``'params'`` wins when
    both are present. Anything else is assumed to already be a bare state dict.
    """
    if isinstance(checkpoint, dict):
        for key in ('params', 'params_ema'):
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint


def load_spynet(weight_path):
    """Build a SPyNet from the basicsr registry and load pretrained weights into it.

    Args:
        weight_path: Path to a checkpoint file readable by ``torch.load``. Both
            wrapped (``{'params': ...}`` / ``{'params_ema': ...}``) and bare state
            dicts are accepted.

    Returns:
        The SPyNet module in eval mode. Weights are mapped to the CPU; move the
        model with ``.to(device)`` afterwards.

    Raises:
        RuntimeError: From ``load_state_dict`` if the checkpoint's keys do not
            match the SPyNet architecture.
    """
    print("Loading SPyNet model...")
    spynet = ARCH_REGISTRY.get('SpyNet')()
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(weight_path, map_location='cpu')
    spynet.load_state_dict(_extract_state_dict(checkpoint))
    spynet.eval()
    print("SPyNet model loaded successfully.")
    return spynet

def visualize_optical_flow(flow, max_magnitude=None):
    """Colour-code a dense optical-flow field as an RGB image.

    Flow direction is mapped to hue over the full circle, so e.g. upward and
    downward motion get different colours. Flow magnitude is mapped to
    brightness, and saturation is fixed at 1 (the usual HSV flow encoding).

    Args:
        flow: Array of shape (..., 2) holding (dx, dy) per pixel.
        max_magnitude: Magnitude that maps to full brightness. Defaults to the
            largest magnitude present in ``flow``; larger values are clipped.

    Returns:
        Float array of shape (..., 3) with RGB values in [0, 1], suitable for
        ``plt.imshow``. Zero flow renders as black.
    """
    flow = np.asarray(flow, dtype=np.float64)
    dx, dy = flow[..., 0], flow[..., 1]
    magnitude = np.hypot(dx, dy)
    # arctan2 returns angles in [-pi, pi]. Wrap them to [0, 2*pi) so negative
    # angles get their own hues instead of being clipped to a single colour.
    hue = np.mod(np.arctan2(dy, dx), 2 * np.pi) / (2 * np.pi)

    if max_magnitude is None:
        max_magnitude = magnitude.max() if magnitude.size else 0.0
    if max_magnitude > 0:
        value = np.clip(magnitude / max_magnitude, 0.0, 1.0)
    else:
        value = np.zeros_like(magnitude)

    # HSV -> RGB with saturation fixed at 1, so the "p" term is always 0.
    h6 = hue * 6.0
    sector = np.floor(h6).astype(np.int64) % 6  # % 6 folds hue == 1.0 back to 0
    frac = h6 - np.floor(h6)
    zeros = np.zeros_like(value)
    # Candidate channel values, indexed 0=v, 1=q, 2=p, 3=t (standard HSV notation).
    candidates = np.stack([value, value * (1.0 - frac), zeros, value * frac], axis=-1)
    # For each hue sector, which candidate goes into R, G and B.
    sector_lut = np.array([
        [0, 3, 2],
        [1, 0, 2],
        [2, 0, 3],
        [2, 1, 0],
        [3, 2, 0],
        [0, 2, 1],
    ])
    return np.take_along_axis(candidates, sector_lut[sector], axis=-1)

def _module_device(module):
    """Device of the first parameter of ``module``; CPU if it has none."""
    try:
        return next(module.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device('cpu')


def _to_tensor(image, device):
    """Convert an HxWxC numpy image to a float32 (1, C, H, W) tensor on ``device``."""
    return (torch.from_numpy(np.ascontiguousarray(image))
            .permute(2, 0, 1).float().unsqueeze(0).to(device))


def _flow_warp(image_tensor, flow):
    """Backward-warp a (N, C, H, W) tensor by a (N, 2, H, W) pixel flow (dx, dy).

    Output pixel (y, x) samples the input bilinearly at (y + dy, x + dx).
    Samples that fall outside the image read as zero.
    """
    _, _, h, w = image_tensor.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=flow.dtype, device=flow.device),
        torch.arange(w, dtype=flow.dtype, device=flow.device),
        indexing='ij',
    )
    sample_x = xs + flow[:, 0]
    sample_y = ys + flow[:, 1]
    # grid_sample expects coordinates in [-1, 1]. With align_corners=True,
    # -1 and +1 are the centres of the first and last pixels.
    grid = torch.stack(
        (2.0 * sample_x / max(w - 1, 1) - 1.0,
         2.0 * sample_y / max(h - 1, 1) - 1.0),
        dim=-1,
    )
    return torch.nn.functional.grid_sample(
        image_tensor, grid.to(image_tensor.dtype), mode='bilinear',
        padding_mode='zeros', align_corners=True)


def _show_alignment(image, flow_np, warped, reference_image):
    """Plot the frame, its flow, the warped frame and |reference - warped| side by side."""
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    panels = [
        (image, 'Original Image'),
        (visualize_optical_flow(flow_np), 'Optical Flow'),
        (warped, 'Warped Image'),
        (np.abs(reference_image - warped), 'Difference'),
    ]
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data)
        ax.set_title(title)
    fig.tight_layout()
    plt.show()
    # Close explicitly so long bursts don't accumulate open figures.
    plt.close(fig)


def patch_alignment(image_burst, spynet, device=None, visualize=True):
    """Align every frame of a burst to its middle frame using optical flow.

    For each frame, ``spynet(reference, frame)`` estimates a flow field. The
    frame is then backward-warped with bilinear sampling along that flow so it
    lines up with the reference.

    Args:
        image_burst: Sequence of HxWx3 numpy arrays, all the same shape. The
            loader in this notebook passes RGB floats in [0, 1]; presumably
            that is what the SPyNet weights expect (confirm for other weights).
        spynet: Callable ``spynet(ref, supp) -> (1, 2, H, W)`` flow tensor in
            pixels, channel 0 = horizontal (dx), channel 1 = vertical (dy).
        device: Device to run on. ``None`` uses the device of ``spynet``'s
            parameters (CPU if it has none).
        visualize: If True, show a 4-panel diagnostic figure for each frame.

    Returns:
        List of aligned frames, in the same order and shape as ``image_burst``.
        Frames are floating-point numpy arrays (float input keeps its dtype).
        An empty burst returns an empty list.
    """
    if len(image_burst) == 0:
        return []
    if device is None:
        device = _module_device(spynet)

    ref_index = len(image_burst) // 2
    reference_image = image_burst[ref_index]
    print(f"Using image {ref_index} as reference.")
    # The reference tensor is the same for every frame, so build it only once.
    ref_tensor = _to_tensor(reference_image, device)

    aligned_images = []
    for i, image in enumerate(image_burst):
        print(f"\nProcessing image {i+1}/{len(image_burst)}")
        img_tensor = _to_tensor(image, device)

        with torch.no_grad():
            print("Calculating optical flow...")
            flow = spynet(ref_tensor, img_tensor)
            print("Warping image based on optical flow...")
            warped_tensor = _flow_warp(img_tensor, flow)

        warped = warped_tensor[0].permute(1, 2, 0).cpu().numpy()
        warped = warped.astype(np.result_type(np.asarray(image).dtype, np.float32), copy=False)
        aligned_images.append(warped)

        if visualize:
            flow_np = flow[0].permute(1, 2, 0).cpu().numpy()
            _show_alignment(image, flow_np, warped, reference_image)

    return aligned_images

if __name__ == "__main__":
    # Machine-specific paths. Environment variables override the defaults so the
    # notebook can run elsewhere without editing this cell.
    spynet_weights_path = os.environ.get(
        'SPYNET_WEIGHTS',
        r'C:\Users\Arnav\Desktop\Image SuperResolution\patchAlignment\spynetWeights.pth',
    )
    burst_directory = os.environ.get('BURST_DIR', 'path/to/your/burst/images/')
    image_extensions = ('.png', '.jpg', '.jpeg')

    # Load the SPyNet model and move it to the GPU when one is available.
    spynet = load_spynet(spynet_weights_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    spynet = spynet.to(device)
    print(f"Using device: {device}")

    # Load the burst as RGB floats in [0, 1], in sorted filename order.
    image_burst = []
    print("Loading burst images...")
    for filename in sorted(os.listdir(burst_directory)):
        # Case-insensitive so files like 'IMG_0001.JPG' are not silently skipped.
        if filename.lower().endswith(image_extensions):
            image_path = os.path.join(burst_directory, filename)
            with Image.open(image_path) as img:
                image_burst.append(np.array(img.convert('RGB')) / 255.0)
    print(f"Loaded {len(image_burst)} images.")
    if not image_burst:
        raise FileNotFoundError(
            f"No {', '.join(image_extensions)} images found in {burst_directory!r}.")

    # Perform patch alignment
    aligned_images = patch_alignment(image_burst, spynet, device)

    # Save the aligned version of the LAST burst frame (not a fusion of the burst).
    final_aligned_image = aligned_images[-1]
    # Round and clip before casting. A bare astype(np.uint8) truncates (biasing
    # dark) and wraps any value that float error pushes outside [0, 255].
    final_uint8 = np.clip(np.rint(final_aligned_image * 255.0), 0, 255).astype(np.uint8)
    Image.fromarray(final_uint8).save('final_aligned_image.jpg')
    print("Final aligned image saved as 'final_aligned_image.jpg'.")

    # Display results
    reference_index = len(image_burst) // 2
    fig, axes = plt.subplots(1, 3, figsize=(20, 5))
    panels = [
        (image_burst[0], 'First Burst Image'),
        (image_burst[reference_index], 'Reference Image'),
        (final_aligned_image, 'Final Aligned Image'),
    ]
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data)
        ax.set_title(title)
    fig.tight_layout()
    plt.show()

print("Patch alignment completed.")

ImportError: DLL load failed while importing cv2: The specified module could not be found.

In [7]:
# Sanity check: does OpenCV import on its own in this kernel? The first cell's
# ImportError names cv2 even though that cell never imports it directly, so one
# of its imports (presumably basicsr) pulls OpenCV in transitively.
import cv2

print(f"{cv2.__version__}")

ImportError: DLL load failed while importing cv2: The specified module could not be found.