In [14]:
import torch
import numpy as np
from PIL import Image
from diffusers import LDMSuperResolutionPipeline
from hyperscope import config
from tqdm import tqdm

def minmax_norm(img):
    """Linearly rescale ``img`` so its values span [0, 1].

    Parameters
    ----------
    img : numpy.ndarray
        Array of any numeric dtype and shape.

    Returns
    -------
    numpy.ndarray
        Array of the same shape. Floating-point inputs keep their dtype;
        integer/bool inputs are promoted to float64 first (this also avoids
        wrap-around in ``max - min`` for signed integer dtypes). A
        constant-valued input has no range to stretch, so an all-zero array
        is returned instead of dividing by zero (which would yield NaNs).
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)
    lo = img.min()
    value_range = img.max() - lo
    if value_range == 0:
        return np.zeros_like(img)
    return (img - lo) / value_range

def split_image_into_patches(image, patch_size=128):
    """Split an image into a row-major grid of ``patch_size`` x ``patch_size`` tiles.

    Tiles on the right/bottom edges that would extend past the image are
    zero-padded (on their bottom/right side) up to the full tile size, so
    every returned patch has identical dimensions.

    Parameters
    ----------
    image : PIL.Image.Image or array-like
        A PIL image (converted to a ``uint8`` array) or an array indexed as
        ``image[row, col, ...]``, e.g. ``(H, W)`` or ``(H, W, C)``.
    patch_size : int, default 128
        Side length of each square tile, in pixels.

    Returns
    -------
    patches : list of PIL.Image.Image
        Tiles in row-major order (left-to-right, then top-to-bottom).
    positions : list of tuple(int, int, int, int)
        ``(x, y, w, h)`` for each tile: its top-left corner in the source
        image and the width/height of the un-padded content, which
        ``reassemble_patches`` uses to crop the padding away again.

    Raises
    ------
    ValueError
        If ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    if isinstance(image, Image.Image):
        image = np.array(image).astype(np.uint8)
    else:
        # Normalise array-likes so every tile below is a NumPy array.
        image = np.asarray(image)
    height, width = image.shape[:2]

    patches = []
    positions = []
    for y in range(0, height, patch_size):
        for x in range(0, width, patch_size):
            # Slicing past the array edge clamps, so h/w are the real content size.
            patch = image[y:y + patch_size, x:x + patch_size]
            h, w = patch.shape[:2]
            if h < patch_size or w < patch_size:
                # Pad in the source dtype and with the source's trailing
                # (channel) dims, so all tiles share dtype and shape.
                padded = np.zeros((patch_size, patch_size) + image.shape[2:], dtype=image.dtype)
                padded[:h, :w] = patch
                patch = padded
            patches.append(Image.fromarray(patch))
            positions.append((x, y, w, h))  # un-padded size, for reconstruction

    return patches, positions

def reassemble_patches(patches, positions, original_size, upscale_factor=4):
    """Stitch upscaled tiles back into a single image.

    Parameters
    ----------
    patches : iterable of PIL.Image.Image
        Upscaled tiles, in the same order as ``positions``.
    positions : iterable of tuple(int, int, int, int)
        ``(x, y, w, h)`` per tile, as produced by ``split_image_into_patches``:
        the tile's top-left corner and un-padded size in *input*
        (pre-upscale) pixels.
    original_size : tuple(int, int)
        ``(width, height)`` of the image before tiling.
    upscale_factor : int, default 4
        Scale applied by the super-resolution model (the LDM checkpoint used
        here upscales by 4x).

    Returns
    -------
    PIL.Image.Image
        RGB canvas of size ``(width * upscale_factor, height * upscale_factor)``
        with every tile pasted at its scaled position.
    """
    width, height = original_size
    final_image = Image.new('RGB', (width * upscale_factor, height * upscale_factor))

    for patch, (x, y, orig_w, orig_h) in zip(patches, positions):
        # Crop every tile to its un-padded content. For interior tiles this is
        # a no-op; for edge tiles it drops the upscaled zero padding. Doing it
        # unconditionally keeps this function independent of the tile size
        # used when splitting (no hard-coded patch size needed).
        content = patch.crop((0, 0, orig_w * upscale_factor, orig_h * upscale_factor))
        final_image.paste(content, (x * upscale_factor, y * upscale_factor))

    return final_image

def process_image(
    image_path,
    pipeline,
    patch_size=128,
    crop=(200, 860, 117, 1440),
    num_inference_steps=2,
    eta=1,
    generator=None,
):
    """Load a saved frame, super-resolve it tile by tile, and stitch the result.

    Parameters
    ----------
    image_path : str or os.PathLike
        Path to a ``.npy`` file holding one frame, indexed as ``[row, col]``.
    pipeline : diffusers.LDMSuperResolutionPipeline
        Pipeline invoked once per tile; ``.images[0]`` of its output is used.
    patch_size : int, default 128
        Tile side length in input pixels.
    crop : tuple(int, int, int, int) or None, default (200, 860, 117, 1440)
        ``(y0, y1, x0, x1)`` region of interest kept before processing;
        ``None`` processes the whole frame.
    num_inference_steps : int, default 2
        Forwarded to ``pipeline``.
    eta : float, default 1
        Forwarded to ``pipeline``.
    generator : torch.Generator, optional
        Forwarded to ``pipeline``; pass a seeded generator for reproducible
        sampling.

    Returns
    -------
    PIL.Image.Image
        The stitched, upscaled RGB image.
    """
    # NOTE(review): the cast assumes the stored frame is integer-valued
    # (e.g. raw 16-bit camera counts); float data would be truncated — confirm.
    frame = np.load(image_path).astype(np.uint16)
    if crop is not None:
        y0, y1, x0, x1 = crop
        frame = frame[y0:y1, x0:x1]

    # Quantize to 8-bit explicitly. Calling Image.fromarray on a float array
    # produces a 32-bit float "F" image, which .convert("RGB") silently truncates.
    frame_u8 = np.clip(np.rint(minmax_norm(frame) * 255), 0, 255).astype(np.uint8)
    img = Image.fromarray(frame_u8).convert("RGB")
    original_size = img.size

    patches, positions = split_image_into_patches(img, patch_size)

    upscaled_patches = [
        pipeline(
            patch,
            num_inference_steps=num_inference_steps,
            eta=eta,
            generator=generator,
        ).images[0]
        for patch in tqdm(patches)
    ]

    return reassemble_patches(upscaled_patches, positions, original_size)

# --- Setup pipeline ---------------------------------------------------------
# Pretrained 4x latent-diffusion super-resolution model, on the GPU if present.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "CompVis/ldm-super-resolution-4x-openimages"
# NOTE(review): this checkpoint has no .safetensors weights, so diffusers falls
# back to pickle-based ("unsafe") deserialization, as the load-time warnings
# report. Only load it from a trusted source/cache.
pipeline = LDMSuperResolutionPipeline.from_pretrained(model_id).to(device)

# --- Process image ----------------------------------------------------------
# Tile, upscale and re-stitch one extracted frame.
image_path = (
    config.INTERIM_DATA_DIR
    / "worms"
    / "extracted"
    / "mkate_live_2025-01-13T10-31-47.631_3_page_25.npy"
)
final_image = process_image(image_path, pipeline)

# --- Save the final image ---------------------------------------------------
final_image.save("ldm_generated_image_full.png")


Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]

An error occurred while trying to fetch /home/rich/.cache/huggingface/hub/models--CompVis--ldm-super-resolution-4x-openimages/snapshots/0b55ddf931a8e3a1b426b3a50ddcf325ff84f668/unet: Error no file named diffusion_pytorch_model.safetensors found in directory /home/rich/.cache/huggingface/hub/models--CompVis--ldm-super-resolution-4x-openimages/snapshots/0b55ddf931a8e3a1b426b3a50ddcf325ff84f668/unet.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
An error occurred while trying to fetch /home/rich/.cache/huggingface/hub/models--CompVis--ldm-super-resolution-4x-openimages/snapshots/0b55ddf931a8e3a1b426b3a50ddcf325ff84f668/vqvae: Error no file named diffusion_pytorch_model.safetensors found in directory /home/rich/.cache/huggingface/hub/models--CompVis--ldm-super-resolution-4x-openimages/snapshots/0b55ddf931a8e3a1b426b3a50ddcf325ff84f668/vqvae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
  0%|          | 0

  0%|          | 0/2 [00:00<?, ?it/s]

  2%|▏         | 1/66 [00:14<16:03, 14.82s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

  3%|▎         | 2/66 [00:29<15:47, 14.81s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

  5%|▍         | 3/66 [00:44<15:40, 14.93s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

  6%|▌         | 4/66 [00:59<15:32, 15.05s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

  8%|▊         | 5/66 [01:14<15:17, 15.04s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

  9%|▉         | 6/66 [01:29<14:56, 14.94s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 11%|█         | 7/66 [01:44<14:42, 14.95s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 12%|█▏        | 8/66 [01:59<14:23, 14.88s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 14%|█▎        | 9/66 [02:13<13:50, 14.57s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 15%|█▌        | 10/66 [02:26<13:13, 14.16s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 17%|█▋        | 11/66 [02:40<12:51, 14.03s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 18%|█▊        | 12/66 [02:53<12:31, 13.91s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 20%|█▉        | 13/66 [03:07<12:09, 13.76s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 21%|██        | 14/66 [03:21<11:54, 13.75s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 23%|██▎       | 15/66 [03:34<11:41, 13.75s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 24%|██▍       | 16/66 [03:48<11:20, 13.62s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 26%|██▌       | 17/66 [04:01<11:06, 13.60s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 27%|██▋       | 18/66 [04:15<10:51, 13.57s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 29%|██▉       | 19/66 [04:28<10:34, 13.51s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 30%|███       | 20/66 [04:42<10:24, 13.58s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 32%|███▏      | 21/66 [04:55<10:08, 13.52s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 33%|███▎      | 22/66 [05:09<09:53, 13.49s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 35%|███▍      | 23/66 [05:22<09:37, 13.43s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 36%|███▋      | 24/66 [05:36<09:27, 13.51s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 38%|███▊      | 25/66 [05:49<09:10, 13.42s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 39%|███▉      | 26/66 [06:02<08:59, 13.48s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 41%|████      | 27/66 [06:16<08:43, 13.43s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 42%|████▏     | 28/66 [06:29<08:27, 13.37s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 44%|████▍     | 29/66 [06:43<08:18, 13.46s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 45%|████▌     | 30/66 [06:56<08:02, 13.41s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 47%|████▋     | 31/66 [07:10<07:52, 13.49s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 48%|████▊     | 32/66 [07:23<07:37, 13.46s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 50%|█████     | 33/66 [07:37<07:26, 13.53s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 52%|█████▏    | 34/66 [07:50<07:10, 13.44s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 53%|█████▎    | 35/66 [08:03<06:58, 13.49s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 55%|█████▍    | 36/66 [08:17<06:44, 13.49s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 56%|█████▌    | 37/66 [08:30<06:30, 13.47s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 58%|█████▊    | 38/66 [08:44<06:17, 13.49s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 59%|█████▉    | 39/66 [08:57<06:03, 13.46s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 61%|██████    | 40/66 [09:11<05:50, 13.50s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 62%|██████▏   | 41/66 [09:24<05:35, 13.43s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 64%|██████▎   | 42/66 [09:38<05:24, 13.50s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 65%|██████▌   | 43/66 [09:51<05:09, 13.47s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 67%|██████▋   | 44/66 [10:05<04:57, 13.53s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 68%|██████▊   | 45/66 [10:18<04:43, 13.49s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 70%|██████▉   | 46/66 [10:32<04:29, 13.49s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 71%|███████   | 47/66 [10:45<04:16, 13.52s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 73%|███████▎  | 48/66 [10:59<04:02, 13.45s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 74%|███████▍  | 49/66 [11:12<03:50, 13.54s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 76%|███████▌  | 50/66 [11:26<03:35, 13.46s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 77%|███████▋  | 51/66 [11:39<03:22, 13.50s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 79%|███████▉  | 52/66 [11:53<03:08, 13.43s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 80%|████████  | 53/66 [12:06<02:55, 13.48s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 82%|████████▏ | 54/66 [12:20<02:41, 13.48s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 83%|████████▎ | 55/66 [12:33<02:28, 13.49s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 85%|████████▍ | 56/66 [12:47<02:14, 13.49s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 86%|████████▋ | 57/66 [13:00<02:00, 13.39s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 88%|████████▊ | 58/66 [13:13<01:47, 13.47s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 89%|████████▉ | 59/66 [13:27<01:34, 13.44s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 91%|█████████ | 60/66 [13:40<01:20, 13.47s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 92%|█████████▏| 61/66 [13:54<01:07, 13.40s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 94%|█████████▍| 62/66 [14:07<00:53, 13.42s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 95%|█████████▌| 63/66 [14:20<00:40, 13.41s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 97%|█████████▋| 64/66 [14:34<00:26, 13.45s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

 98%|█████████▊| 65/66 [14:47<00:13, 13.46s/it]

  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 66/66 [15:01<00:00, 13.66s/it]


In [None]:
# NOTE(review): `predictions` is not assigned in the visible cells, so this
# raises NameError on a fresh kernel unless an earlier cell defines it.
# Presumably the intent was to display `final_image` — confirm.
predictions