In [1]:
import csv
import os
from multiprocessing import Pool

import timm
import faiss
import torch
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import cv2
import numpy as np
import pandas as pd
from PIL import Image

from hat_arch import HAT

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Function to create the patches directory
def create_patches_dir(output_dir):
    """Create ``output_dir`` (and any missing parent folders) if it does not exist yet."""
    # exist_ok=True avoids the check-then-create race of testing os.path.exists() first.
    os.makedirs(output_dir, exist_ok=True)

# Function to split the image into patches
def split_image_into_patches(image, patch_size, overlap):
    """Cut an image into square patches on a regular grid.

    Args:
        image: Array indexed as ``image[y, x, ...]`` (an OpenCV ``H x W x C`` image or a
            2-D grayscale array).
        patch_size: Side length of each square patch in pixels.
        overlap: Fraction of ``patch_size`` shared by neighbouring patches; the grid step
            is ``int(patch_size * (1 - overlap))``.

    Returns:
        List of ``patch_size x patch_size`` slices of ``image`` in row-major order. Only
        patches that fit entirely inside the image are produced, so a right/bottom margin
        narrower than one step is dropped.

    Raises:
        ValueError: if ``patch_size`` and ``overlap`` give a non-positive step.
    """
    h, w = image.shape[:2]
    stride = int(patch_size * (1 - overlap))
    if stride <= 0:
        raise ValueError(
            f"patch_size={patch_size} with overlap={overlap} gives a non-positive stride ({stride})"
        )

    return [
        image[y:y + patch_size, x:x + patch_size]
        for y in range(0, h - patch_size + 1, stride)
        for x in range(0, w - patch_size + 1, stride)
    ]

# Function to process a single image and save patches
def process_image(img_info):
    """Downscale one image, cut it into patches and write every patch as a PNG.

    Args:
        img_info: ``(img_path, output_folder, patch_size, overlap)`` packed into one tuple
            so the function can be mapped directly with ``Pool.imap_unordered``.

    Returns:
        Number of patches written; 0 when the file could not be decoded or the image is
        too small to downscale.
    """
    img_path, output_folder, patch_size, overlap = img_info
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread reports failure by returning None; skip the file instead of
        # crashing the whole worker pool on a single unreadable image.
        return 0

    # Shrink each side by a factor of 32 before cutting patches.
    new_w, new_h = image.shape[1] // 32, image.shape[0] // 32
    if new_w == 0 or new_h == 0:
        # Narrower than 32 px on some side: cv2.resize rejects a zero-sized target.
        return 0
    image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)

    patches = split_image_into_patches(image, patch_size, overlap)
    stem = os.path.splitext(os.path.basename(img_path))[0]
    for i, patch in enumerate(patches):
        cv2.imwrite(os.path.join(output_folder, f"{stem}_patch_{i}.png"), patch)
    return len(patches)

# Main function to process the folders using multiprocessing
def process_folders(input_folder1, input_folder2, output_folder, patch_size=8, overlap=0.5, num_workers=8):
    """Cut every image found in two folders into patches, using a process pool.

    Args:
        input_folder1: First folder scanned (non-recursively) for images.
        input_folder2: Second folder scanned (non-recursively) for images.
        output_folder: Destination of the patch PNGs; created if missing.
        patch_size: Side length in pixels of each square patch.
        overlap: Fraction of ``patch_size`` shared by neighbouring patches.
        num_workers: Number of worker processes.
    """
    create_patches_dir(output_folder)

    # Matched case-insensitively so files such as "IMG_01.JPG" are not silently skipped.
    extensions = ('.png', '.jpg', '.jpeg', '.tif')
    # NOTE(review): patches are named after the source file stem, so two input images
    # with the same stem (in either folder) overwrite each other's patches.
    image_info = [
        (os.path.join(folder, img_file), output_folder, patch_size, overlap)
        for folder in (input_folder1, input_folder2)
        for img_file in sorted(os.listdir(folder))
        if img_file.lower().endswith(extensions)
    ]

    with Pool(num_workers) as pool:
        # Drain the iterator only to drive the progress bar; results are not needed.
        for _ in tqdm(pool.imap_unordered(process_image, image_info), total=len(image_info), desc="Processing images"):
            pass

# Source datasets and the folder that receives their patches (replace with your own paths)
input_folder1, input_folder2 = "Flickr2K", "DIV2K"
output_folder = "patches"

# Build the patch database from both datasets using the default patch settings
process_folders(input_folder1, input_folder2, output_folder=output_folder)

Processing images: 100%|████████████████████| 3450/3450 [02:34<00:00, 22.35it/s]


In [3]:
# Custom Dataset to load images
class ImageDataset(Dataset):
    """Dataset over the image files of a single folder.

    Each item is ``(image, filename)``: the file decoded with OpenCV, converted from BGR
    to RGB, and passed through ``transform`` when one is given.

    Args:
        image_folder: Folder scanned (non-recursively) for image files.
        transform: Optional callable applied to the RGB ``numpy`` array.
    """

    # Compared against the lower-cased filename, so ".PNG"/".JPG" files are included too.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif')

    def __init__(self, image_folder, transform):
        self.image_folder = image_folder
        # Sorted so item order (and therefore downstream output order) is reproducible.
        self.image_files = sorted(
            f for f in os.listdir(image_folder) if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_file = self.image_files[idx]
        img_path = os.path.join(self.image_folder, img_file)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None instead of raising; fail with the offending path.
            raise OSError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image)
        return image, img_file

# Initialize the VGG model from torchvision
def get_vgg_model():
    """Build the frozen VGG16 feature extractor used to embed patches.

    Takes the ImageNet-pretrained convolutional trunk (``vgg16.features``), drops its
    last 9 layers, moves the result to the GPU and puts it in eval mode.
    """
    backbone = models.vgg16(pretrained=True)
    # Only the convolutional trunk is used; avgpool and classifier heads are discarded.
    trunk = backbone.features[:-9]
    trunk = trunk.cuda()
    trunk.eval()
    return trunk

# Function to process a batch of images and extract embeddings
def extract_embeddings_batch(model, batch_images, device='cuda'):
    """Run one batch through ``model`` and return the flattened outputs as float32 numpy.

    Args:
        model: Feature extractor, already placed on ``device``.
        batch_images: Batched tensor of shape (B, ...) as yielded by the ``DataLoader``.
        device: Device the batch is copied to before the forward pass.

    Returns:
        Array of shape (B, F), where F is the product of the model's non-batch output
        dimensions.
    """
    # non_blocking lets the host-to-device copy overlap when the DataLoader pins memory.
    batch_tensor = batch_images.to(device, non_blocking=True)

    with torch.no_grad():
        embeddings = model(batch_tensor).flatten(1).cpu().numpy().astype(np.float32)

    return embeddings

# Main function to process images in batches and save embeddings to CSV
def process_images_and_save_embeddings(image_folder, output_csv, batch_size=64, num_workers=8, append=False):
    """Embed every image in ``image_folder`` with the truncated VGG16 and write a CSV.

    Each output row is ``<filename>,<v0>,<v1>,...`` with no header line. A filename that
    contains a comma or quote character is quoted by the ``csv`` module, so the row
    still splits into the right fields.

    Args:
        image_folder: Folder of images to embed (read through ``ImageDataset``).
        output_csv: Destination CSV path.
        batch_size: Images per forward pass.
        num_workers: ``DataLoader`` worker processes used for decoding.
        append: When False (default) ``output_csv`` is truncated first, so re-running the
            cell does not duplicate every row; pass True to extend an existing file.
    """
    model = get_vgg_model()

    # Transforms suitable for the VGG16 trunk: resize to 8x8, tensorize, normalize.
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((8, 8)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = ImageDataset(image_folder, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    # newline='' with lineterminator='\n' gives plain '\n' row endings.
    with open(output_csv, 'a' if append else 'w', newline='') as f_output:
        writer = csv.writer(f_output, lineterminator='\n')
        for batch_images, image_files in tqdm(dataloader, desc="Processing images"):
            # The GPU copy of the batch lives only inside extract_embeddings_batch, so
            # nothing on the GPU needs to be released per batch here.
            embeddings = extract_embeddings_batch(model, batch_images)
            # Values are formatted with str(); only the filename can ever need quoting.
            writer.writerows(
                [img_file, *map(str, embedding)]
                for img_file, embedding in zip(image_files, embeddings)
            )

    # Final memory cleanup
    del model
    torch.cuda.empty_cache()

# Embed every patch produced by the previous cell into the CSV used by the mosaic step
image_folder, output_csv = "patches", "perceptual_embeddings.csv"
process_images_and_save_embeddings(image_folder, output_csv, num_workers=8, batch_size=64)

# Free any VRAM PyTorch still holds after the run
torch.cuda.empty_cache()

Processing images: 100%|████████████████████| 6812/6812 [10:15<00:00, 11.07it/s]


In [4]:
# Initialize the HAT model
def get_hat_model(checkpoint_path='Real_HAT_GAN_sharper.pth', device='cuda'):
    """Build the 4x HAT super-resolution network and load its EMA weights.

    Args:
        checkpoint_path: Checkpoint file holding a ``'params_ema'`` state dict.
        device: Device the model is moved to.

    Returns:
        The HAT model in eval mode on ``device``.
    """
    model = HAT(
        upscale=4,
        in_chans=3,
        img_size=64,
        window_size=16,
        compress_ratio=3,
        squeeze_factor=30,
        conv_scale=0.01,
        overlap_ratio=0.5,
        img_range=1.,
        depths=[6]*6,
        embed_dim=180,
        num_heads=[6]*6,
        mlp_ratio=2,
        upsampler='pixelshuffle',
        resi_connection='1conv'
    )
    # weights_only=True refuses to unpickle arbitrary objects from the checkpoint file;
    # map_location='cpu' avoids materialising the tensors on the device they were saved
    # from before the model itself is moved.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    model.load_state_dict(checkpoint['params_ema'], strict=True)
    model.eval()
    return model.to(device)

# Function to perform tiled inference
def super_resolve_image(model, input_image_path, tile_size=256, tile_overlap=32, window_size=16):
    """Super-resolve an image tile by tile and save it next to the input as ``<name>_sr<ext>``.

    The image is padded on the right/bottom to a multiple of ``window_size`` and cut into
    ``tile_size`` tiles overlapping by ``tile_overlap`` pixels. Each tile is run through
    ``model`` on its own, so peak memory is bounded by the tile size. Where two tiles
    overlap, the seam is put in the middle of the overlap, so pixels near a tile's
    interior edges are taken from the neighbouring tile instead.

    Args:
        model: Upscaling network applied to (1, 3, h, w) tensors produced by ``ToTensor``.
            ``model.upscale`` (default 1) and ``model.img_range`` (default 1.0) are read when
            present. Inference runs on the device of the model's parameters.
        input_image_path: Image to upscale; it is converted to RGB.
        tile_size: Side length of the input-resolution tiles.
        tile_overlap: Overlap between neighbouring tiles; must be smaller than ``tile_size``.
        window_size: Image and tile sides are padded up to multiples of this.

    Returns:
        Path of the written image.

    Raises:
        ValueError: if ``tile_overlap >= tile_size``.
    """
    stride = tile_size - tile_overlap
    if stride <= 0:
        raise ValueError(f"tile_overlap ({tile_overlap}) must be smaller than tile_size ({tile_size})")

    img_range = getattr(model, 'img_range', 1.0)
    upscale = getattr(model, 'upscale', 1)
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    with Image.open(input_image_path) as src:
        img = src.convert('RGB')
    img_width, img_height = img.size
    img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)  # add batch dimension

    # Pad right and bottom so both dimensions are multiples of window_size.
    mod_pad_w = (window_size - img_width % window_size) % window_size
    mod_pad_h = (window_size - img_height % window_size) % window_size
    img_padded = _pad_right_bottom(img_tensor, mod_pad_w, mod_pad_h)
    _, _, padded_height, padded_width = img_padded.shape

    output = torch.zeros(1, 3, padded_height * upscale, padded_width * upscale, device=device)

    tiles_x = _num_tiles(padded_width, tile_overlap, stride)
    tiles_y = _num_tiles(padded_height, tile_overlap, stride)
    # Every tile after the first in a row/column discards this many leading input pixels,
    # which moves each seam to the middle of the overlap.
    seam = tile_overlap // 2

    with torch.no_grad():
        for ty in tqdm(range(tiles_y), desc='Processing tiles'):
            start_y = ty * stride
            end_y = min(start_y + tile_size, padded_height)
            skip_y = seam if ty > 0 else 0
            for tx in range(tiles_x):
                start_x = tx * stride
                end_x = min(start_x + tile_size, padded_width)
                skip_x = seam if tx > 0 else 0
                tile_height, tile_width = end_y - start_y, end_x - start_x

                # The model needs tile sides divisible by window_size.
                pad_h = (window_size - tile_height % window_size) % window_size
                pad_w = (window_size - tile_width % window_size) % window_size
                input_tile = _pad_right_bottom(img_padded[:, :, start_y:end_y, start_x:end_x], pad_w, pad_h)

                # Crop off the upscaled per-tile padding, then skip the leading overlap half.
                output_tile = model(input_tile)[:, :, :tile_height * upscale, :tile_width * upscale]
                output[:, :,
                       (start_y + skip_y) * upscale:end_y * upscale,
                       (start_x + skip_x) * upscale:end_x * upscale] = \
                    output_tile[:, :, skip_y * upscale:, skip_x * upscale:]

                # Free per-tile VRAM before the next forward pass.
                del input_tile, output_tile
                torch.cuda.empty_cache()

    # Crop away the right/bottom padding.
    output = output[:, :, :img_height * upscale, :img_width * upscale]

    # NOTE(review): this assumes the model returns values in [0, img_range]; with
    # img_range == 1 it is a plain clamp to [0, 1] — confirm before using another range.
    output_img = output.squeeze(0).cpu().clamp(0, img_range)
    if img_range != 1.0:
        output_img = output_img / img_range
    output_img = transforms.ToPILImage()(output_img)

    output_image_path = _sr_output_path(input_image_path)
    output_img.save(output_image_path)

    # Clean up
    del output, img_tensor, img_padded
    torch.cuda.empty_cache()
    return output_image_path


def _pad_right_bottom(tensor, pad_w, pad_h):
    """Pad a (N, C, H, W) tensor with ``pad_w`` columns on the right and ``pad_h`` rows at the bottom."""
    if pad_w == 0 and pad_h == 0:
        return tensor
    height, width = tensor.shape[-2:]
    # 'reflect' requires each pad to be smaller than the padded dimension; fall back to
    # edge replication for tiny inputs.
    mode = 'reflect' if pad_w < width and pad_h < height else 'replicate'
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(tensor, (0, pad_w, 0, pad_h), mode=mode)


def _num_tiles(length, overlap, stride):
    """Smallest tile count n so n tiles at step ``stride`` (each ``stride + overlap`` long) cover ``length``."""
    return max(1, -(-(length - overlap) // stride))


def _sr_output_path(input_image_path):
    """Path next to the input file with ``_sr`` appended to the file stem."""
    base_name, ext = os.path.splitext(os.path.basename(input_image_path))
    return os.path.join(os.path.dirname(input_image_path), f"{base_name}_sr{ext}")

# Main code
if __name__ == '__main__':
    hat_model = get_hat_model()

    # Image to upscale (replace with your own path)
    input_image_path = 'photo_2024-09-30_16-47-00.jpg'

    # 512 px tiles with a 32 px overlap; window_size=16 matches the HAT window size
    super_resolve_image(hat_model, input_image_path,
                        tile_size=512, tile_overlap=32, window_size=16)

    # Release the network and the VRAM it used
    del hat_model
    torch.cuda.empty_cache()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  checkpoint = torch.load('Real_HAT_GAN_sharper.pth')
Processing tiles: 100%|███████████████████████████| 2/2 [00:14<00:00,  7.30s/it]


In [5]:
def load_embeddings(embedding_csv):
    """Read a headerless ``filename,v0,v1,...`` CSV.

    Returns:
        ``(embeddings, filenames)``: a float32 array with one row per CSV line, and the
        first-column values in the same order.
    """
    table = pd.read_csv(embedding_csv, header=None)
    names = table.iloc[:, 0].values
    vectors = table.iloc[:, 1:].values.astype('float32')
    return vectors, names

def initialize_vgg_model():
    """Return the truncated VGG16 trunk (``features`` minus its last 9 layers) on the GPU, in eval mode.

    This must stay identical to ``get_vgg_model`` from the embedding cell; otherwise the
    query embeddings computed here are not comparable to the ones stored in the CSV.
    """
    extractor = models.vgg16(pretrained=True).features[:-9].cuda()
    extractor.eval()
    return extractor

def compute_patch_embeddings(model, patches, transform, batch_size=256, device=None):
    """Embed image patches with ``model``, producing one flattened row per patch.

    Patches are transformed individually, stacked into batches of ``batch_size`` and run
    through the model together instead of one forward pass per patch.

    Args:
        model: Network applied to each batch (expected to be in eval mode already).
        patches: Non-empty iterable of patches accepted by ``transform``.
        transform: Callable mapping one patch to a (C, H, W) tensor.
        batch_size: Number of patches per forward pass.
        device: Device to run on. Defaults to the device of the model's first parameter,
            or ``'cuda'`` when the model has no parameters.

    Returns:
        float32 array of shape (number of patches, embedding_dim), in input order.

    Raises:
        ValueError: if ``patches`` is empty or ``batch_size`` is smaller than 1.
    """
    patches = list(patches)
    if not patches:
        raise ValueError("patches must contain at least one patch")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cuda'

    embeddings = []
    with torch.no_grad():
        for start in range(0, len(patches), batch_size):
            batch = torch.stack([transform(p) for p in patches[start:start + batch_size]]).to(device)
            embeddings.append(model(batch).flatten(1).cpu().numpy().astype(np.float32))
    return np.vstack(embeddings)

def find_nearest_patches(patch_embeddings, database_embeddings):
    """Return, for each query embedding, the row index of its nearest database embedding.

    Uses an exact L2 search over a FAISS ``IndexFlatL2``.

    Args:
        patch_embeddings: (n_queries, d) array of query embeddings.
        database_embeddings: (n_database, d) array of database embeddings.

    Returns:
        1-D integer array of length n_queries with row indices into ``database_embeddings``.

    Raises:
        ValueError: if the two arrays have different embedding widths.
    """
    # FAISS works on C-contiguous float32 buffers; the database array comes from a
    # DataFrame and is not guaranteed to be C-ordered, so normalise both inputs.
    queries = np.ascontiguousarray(patch_embeddings, dtype=np.float32)
    database = np.ascontiguousarray(database_embeddings, dtype=np.float32)
    if queries.shape[1] != database.shape[1]:
        raise ValueError(
            f"embedding width mismatch: queries have {queries.shape[1]}, database has {database.shape[1]}"
        )

    index = faiss.IndexFlatL2(database.shape[1])
    index.add(database)
    _, indices = index.search(queries, 1)  # k=1: single nearest neighbour
    return indices.flatten()

def create_mosaic_image(sr_image_path, patches_folder, embeddings_csv, output_image_path, patch_size=8):
    """Rebuild an image as a mosaic of its nearest database patches.

    The image is cropped to a multiple of ``patch_size`` and cut into non-overlapping
    ``patch_size`` tiles. Each tile is embedded with the VGG trunk and matched to its
    nearest precomputed embedding, then replaced by that entry's patch image.

    Args:
        sr_image_path: Image to reproduce (typically the super-resolved output).
        patches_folder: Folder holding the patch files named in ``embeddings_csv``.
        embeddings_csv: Headerless CSV of ``filename,v0,v1,...`` rows.
        output_image_path: Where the mosaic is written.
        patch_size: Side length in pixels of each mosaic tile.

    Raises:
        ValueError: if the image is smaller than one tile in either dimension.
        RuntimeError: if the nearest-neighbour search returned no match for some tile.
    """
    with Image.open(sr_image_path) as src:
        sr_image = src.convert('RGB')
    sr_width, sr_height = sr_image.size

    # Crop so both dimensions are divisible by patch_size.
    new_width = (sr_width // patch_size) * patch_size
    new_height = (sr_height // patch_size) * patch_size
    if new_width == 0 or new_height == 0:
        raise ValueError(f"{sr_image_path} is smaller than one {patch_size}x{patch_size} tile")
    sr_image = sr_image.crop((0, 0, new_width, new_height))

    # Non-overlapping patch_size x patch_size tiles in row-major order, with their positions.
    positions = [(x, y) for y in range(0, new_height, patch_size) for x in range(0, new_width, patch_size)]
    patches = [sr_image.crop((x, y, x + patch_size, y + patch_size)) for x, y in positions]

    vgg_model = initialize_vgg_model()
    transform = transforms.Compose([
        # Always embed at 8x8, the size the database patches were resized to when embedded.
        transforms.Resize((8, 8)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    print("Computing embeddings for SR image patches...")
    sr_patch_embeddings = compute_patch_embeddings(vgg_model, patches, transform)
    # The network is not needed past this point; free its VRAM early.
    del vgg_model
    torch.cuda.empty_cache()

    print("Loading precomputed embeddings...")
    database_embeddings, database_filenames = load_embeddings(embeddings_csv)

    print("Finding nearest patches...")
    nearest_indices = find_nearest_patches(sr_patch_embeddings, database_embeddings)
    if np.any(nearest_indices < 0):
        # FAISS reports "no result" as -1, which would otherwise index the last filename.
        raise RuntimeError("nearest-neighbour search returned no match for some patches")

    print("Creating the mosaic image...")
    mosaic_image = Image.new('RGB', (new_width, new_height))
    for (x, y), nearest_index in tqdm(zip(positions, nearest_indices), total=len(positions)):
        patch_image_path = os.path.join(patches_folder, database_filenames[nearest_index])
        with Image.open(patch_image_path) as src:
            patch_image = src.convert('RGB').resize((patch_size, patch_size))
        mosaic_image.paste(patch_image, (x, y))

    mosaic_image.save(output_image_path)
    print(f"Mosaic image saved to {output_image_path}")

# Example usage
if __name__ == '__main__':
    # Inputs: the SR image to reproduce, the patch database and its embeddings CSV
    sr_image_path = 'photo_2024-09-30_16-47-00_sr.jpg'
    patches_folder = 'patches'
    embeddings_csv = 'perceptual_embeddings.csv'
    # Output: the assembled mosaic
    output_image_path = 'mosaic_image.jpg'

    create_mosaic_image(sr_image_path, patches_folder, embeddings_csv, output_image_path=output_image_path)

Computing embeddings for SR image patches...
Loading precomputed embeddings...
Finding nearest patches...
Creating the mosaic image...


100%|████████████████████████████████| 307200/307200 [00:24<00:00, 12326.73it/s]


Mosaic image saved to mosaic_image.jpg
