In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from google.colab import drive
import os
import cv2
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

# Choose the compute device up front: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Report the hardware so the notebook output records what the run used.
print("GPU Available:", device.type == "cuda")
if device.type == "cuda":
    print("GPU Name:", torch.cuda.get_device_name(0))

GPU Available: True
GPU Name: Tesla T4


In [7]:
# Mount Google Drive; the full-size training images and the test LMDB
# databases are read from it.
drive.mount('/content/drive')

# Source folders with the full-size training images (inputs and their targets).
train_input_images_folder = '/content/drive/My Drive/train/input/'
train_target_images_folder = '/content/drive/My Drive/train/target/'

# Patch folders use relative paths, so they are created under the current
# working directory rather than on Drive.
train_input_patches_folder = "train_input_patches"
os.makedirs(train_input_patches_folder, exist_ok=True)

train_target_patches_folder = "train_target_patches"
os.makedirs(train_target_patches_folder, exist_ok=True)

# The test split is stored on Drive as LMDB databases.
test_input_lmdb_folder = '/content/drive/My Drive/test/input.lmdb/'
test_target_lmdb_folder = '/content/drive/My Drive/test/target.lmdb/'

# Local folders that receive the images extracted from the test LMDBs.
test_input_images_folder = "test_input_images"
os.makedirs(test_input_images_folder, exist_ok=True)

test_target_images_folder = "test_target_images"
os.makedirs(test_target_images_folder, exist_ok=True)

# Local folders that receive the patches cut from the extracted test images.
test_input_patches_folder = "test_input_patches"
os.makedirs(test_input_patches_folder, exist_ok=True)

test_target_patches_folder = "test_target_patches"
os.makedirs(test_target_patches_folder, exist_ok=True)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [8]:
# patcher function

def compute_stride(image_size, patch_size):
    """
    Return the step between consecutive patch origins along one image axis.

    The number of patches is the ceiling of ``image_size / patch_size``; the
    distance between the first and last possible origin
    (``image_size - patch_size``) is then split evenly, rounding down, across
    the gaps between those patches.

    Args:
        image_size (int): Length of the image along the axis.
        patch_size (int): Length of a patch along the same axis.

    Returns:
        int: The stride, or 1 when a single patch (or none) fits the axis.
    """
    # Integer ceiling division without importing math.
    patches_needed = (image_size + patch_size - 1) // patch_size

    # One patch at most: there is no gap to distribute, fall back to 1.
    if patches_needed <= 1:
        return 1

    gaps = patches_needed - 1
    return (image_size - patch_size) // gaps

def extract_patches(image, patch_size=256):
    """
    Cut a (C, H, W) image tensor into square patches that cover the whole image.

    Patch origins are laid out on a regular grid whose step comes from
    ``compute_stride``. Because that step is rounded down, the regular grid can
    stop short of the bottom/right border; in that case one extra origin,
    aligned with the border, is appended so that every pixel belongs to at
    least one patch.

    Args:
        image (torch.Tensor): Image of shape (C, H, W); any channel count.
        patch_size (int): Side length of the square patches.

    Returns:
        tuple: ``(patches, indices, stride_h, stride_w)`` where ``patches`` has
        shape (num_patches, C, patch_size, patch_size) with the dtype and
        device of ``image``, ``indices`` lists the ``(row, col)`` origin of
        each patch in the same order, and the strides are the regular grid
        steps. An image smaller than ``patch_size`` in either dimension yields
        zero patches.
    """
    channels, h, w = image.shape
    stride_h = compute_stride(h, patch_size)
    stride_w = compute_stride(w, patch_size)

    def _origins(size, stride):
        # Regular grid of origins, plus a border-aligned one when needed.
        last = size - patch_size
        if last < 0:
            return []
        origins = list(range(0, last + 1, stride))
        if origins[-1] != last:
            origins.append(last)
        return origins

    # Row-major order: all columns of the first row of patches, then the next row.
    indices = [(i, j) for i in _origins(h, stride_h) for j in _origins(w, stride_w)]

    # Preallocate using the image's own channel count instead of assuming RGB.
    patches = torch.empty(
        (len(indices), channels, patch_size, patch_size),
        dtype=image.dtype,
        device=image.device,
    )
    for patch_idx, (i, j) in enumerate(indices):
        patches[patch_idx] = image[:, i:i + patch_size, j:j + patch_size]

    return patches, indices, stride_h, stride_w

In [15]:
import os
from PIL import Image
import torch
from torchvision import transforms

def save_patches(image_folder, patch_folder, prefix):
    """
    Cut every PNG in ``image_folder`` into 256x256 patches and save them as PNGs.

    Files are processed in sorted name order. Each patch is written as
    ``{prefix}_{image_index:05d}_{patch_index:03d}.png`` where ``image_index``
    is the zero-based position of the source file in that sorted listing.

    NOTE(review): because patch names depend on the listing position and not on
    the source file name, pairing input and target patches by name assumes both
    folders contain the same number of images in the same sorted order.

    Args:
        image_folder (str): Path to the folder containing images.
        patch_folder (str): Path to save patches (created if missing).
        prefix (str): Prefix for naming output patches.
    """
    os.makedirs(patch_folder, exist_ok=True)  # Ensure output folder exists

    # Build both converters once; the original rebuilt ToPILImage for every patch.
    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()

    # Only files whose name ends in lowercase '.png' are considered.
    img_files = sorted(f for f in os.listdir(image_folder) if f.endswith('.png'))

    processed = 0
    for image_index, img_file in enumerate(img_files):
        img_path = os.path.join(image_folder, img_file)

        # Load as RGB and convert to a (C, H, W) tensor; the file is closed on exit.
        with Image.open(img_path) as img:
            image_tensor = to_tensor(img.convert("RGB"))

        # Only the patches are needed here; origins and strides are discarded.
        patches = extract_patches(image_tensor, patch_size=256)[0]

        for idx, patch in enumerate(patches):
            patch_name = f"{prefix}_{image_index:05d}_{idx:03d}.png"
            to_pil(patch).save(os.path.join(patch_folder, patch_name))

        processed = image_index + 1

        # Print progress every 100 images
        if processed % 100 == 0:
            print(f"Processed {processed} {prefix} images...")

    print(f"Finished processing {processed} {prefix} images.")


In [None]:
# Cut the full-size training images into patches. Inputs and targets go through
# the same routine; patch file names carry the image's position in the sorted
# folder listing plus the patch index, which is what later pairs them up.
save_patches(train_input_images_folder, train_input_patches_folder, 'input')
save_patches(train_target_images_folder, train_target_patches_folder, 'target')

Processed 100 input images...
Processed 200 input images...
Processed 300 input images...
Processed 400 input images...
Processed 500 input images...
Processed 600 input images...
Processed 700 input images...
Processed 800 input images...
Processed 900 input images...
Processed 1000 input images...
Processed 1100 input images...
Processed 1200 input images...
Processed 1300 input images...
Processed 1400 input images...
Processed 1500 input images...
Processed 1600 input images...
Processed 1700 input images...
Processed 1800 input images...
Processed 1900 input images...
Processed 2000 input images...
Processed 2100 input images...
Finished processing 2103 input images.


In [None]:
# Dataloader

import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.multiprocessing as mp

class PairedPatchDataset(Dataset):
    """
    Dataset of (input, target) patch pairs stored as PNG files in two folders.

    Input patches are named ``input_<key>.png`` and target patches
    ``target_<key>.png``; only keys present in both folders are used, in sorted
    key order. Each item is a pair of float tensors produced by ``ToTensor``
    from the RGB-converted images.

    Args:
        input_folder (str): Folder holding the ``input_*.png`` patches.
        target_folder (str): Folder holding the ``target_*.png`` patches.
    """

    def __init__(self, input_folder, target_folder):
        self.input_folder = input_folder
        self.target_folder = target_folder

        # Keys shared by both folders define the dataset.
        input_keys = self._patch_keys(input_folder, "input_")
        target_keys = self._patch_keys(target_folder, "target_")
        self.image_filenames = sorted(input_keys & target_keys)

        # Precompute full paths for efficiency
        self.input_paths = [os.path.join(input_folder, f"input_{name}.png") for name in self.image_filenames]
        self.target_paths = [os.path.join(target_folder, f"target_{name}.png") for name in self.image_filenames]

        # Define transformations
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # Convert images to tensors
        ])

    @staticmethod
    def _patch_keys(folder, prefix):
        """Return the ``<key>`` part of every regular file named ``<prefix><key>.png`` in ``folder``."""
        suffix = ".png"
        keys = set()
        # os.scandir() for faster directory listing; the context manager closes it.
        with os.scandir(folder) as entries:
            for entry in entries:
                name = entry.name
                # Require the exact prefix and extension so that every key maps
                # back to a file that really exists when the path is rebuilt.
                if entry.is_file() and name.startswith(prefix) and name.endswith(suffix):
                    keys.add(name[len(prefix):-len(suffix)])
        return keys

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        # Open through context managers (as save_patches does) so the file
        # handles are released as soon as the pixels have been converted.
        with Image.open(self.input_paths[idx]) as input_image:
            input_tensor = self.transform(input_image.convert("RGB"))
        with Image.open(self.target_paths[idx]) as target_image:
            target_tensor = self.transform(target_image.convert("RGB"))

        return input_tensor, target_tensor

In [3]:
# baseline block

class LayerNorm2d(nn.Module):
    """
    Layer normalisation applied over the channel axis of a 4-D feature map.

    ``nn.LayerNorm`` normalises the last dimension, so the input is permuted so
    that dimension 1 (channels) comes last, normalised, and permuted back.

    Args:
        num_channels (int): Size of the channel dimension being normalised.
        eps (float): Value passed to ``nn.LayerNorm`` for numerical stability.
    """

    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(num_channels, eps=eps)

    def forward(self, x):
        # Switch to channels_last memory first so the permuted view below
        # matches the underlying layout.
        channels_last = x.to(memory_format=torch.channels_last)
        normalised = self.norm(channels_last.permute(0, 2, 3, 1))

        # Restore the original dimension order and return a contiguous tensor.
        restored = normalised.permute(0, 3, 1, 2)
        return restored.contiguous()

class ChannelAttention(nn.Module):
    """
    Channel attention: rescale each channel by a gate computed from its global average.

    The feature map is average-pooled to 1x1, passed through two bias-free 1x1
    convolutions (squeezing to ``in_channels // reduction`` channels with a
    ReLU in between), and squashed with a sigmoid. The resulting per-channel
    gate multiplies the input.

    Args:
        in_channels (int): Number of channels of the input feature map.
        reduction (int): Divisor for the hidden channel count.
    """

    def __init__(self, in_channels, reduction=2):
        super().__init__()
        squeezed = in_channels // reduction

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, squeezed, kernel_size=1, bias=False)
        # In-place is safe here: it only overwrites conv1's fresh output.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(squeezed, in_channels, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.relu(self.conv1(self.pool(x)))
        gate = self.sigmoid(self.conv2(hidden))
        # The gate has spatial size 1x1 and broadcasts over height and width.
        return x * gate

class BaselineBlock(nn.Module):
    """
    Residual block made of two sub-blocks, each wrapped in a skip connection.

    1. Normalise, 1x1 conv, 3x3 depthwise conv, GELU, channel attention, 1x1 conv.
    2. Normalise, 1x1 conv expanding to twice the channels, GELU, 1x1 conv back.

    All convolutions are bias-free and the channel count and spatial size of
    the output match the input.

    Args:
        in_channels (int): Number of input (and output) channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        expanded = 2 * in_channels

        # First sub-block: spatial mixing followed by channel attention.
        self.norm1 = LayerNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)
        # groups == in_channels makes this a depthwise convolution.
        self.dconv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels, bias=False
        )
        self.gelu = nn.GELU()  # shared by both sub-blocks
        self.ca = ChannelAttention(in_channels)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)

        # Second sub-block: pointwise expand / contract.
        self.norm2 = LayerNorm2d(in_channels)
        self.conv3 = nn.Conv2d(in_channels, expanded, kernel_size=1, bias=False)
        self.conv4 = nn.Conv2d(expanded, in_channels, kernel_size=1, bias=False)

    def forward(self, x):
        # Sub-block 1 with its skip connection.
        mixed = self.dconv(self.conv1(self.norm1(x)))
        attended = self.ca(self.gelu(mixed))
        x = self.conv2(attended) + x

        # Sub-block 2 with its skip connection.
        widened = self.gelu(self.conv3(self.norm2(x)))
        return self.conv4(widened) + x

In [4]:
class BaselineModel(nn.Module):
    """
    U-shaped image-to-image network built from ``BaselineBlock`` stages.

    Layout: a 3x3 input convolution, four encoder stages each followed by a
    stride-2 3x3 convolution, a bottleneck stage, four decoder stages each
    preceded by a 1x1 convolution + pixel-shuffle upsampling whose result is
    added to the matching encoder output, and a 3x3 output convolution. Every
    stage has four blocks and the channel count stays at ``width`` throughout.

    The four stride-2 steps require spatial sizes that are multiples of 16 for
    the skip additions to line up. ``forward`` therefore zero-pads the bottom
    and right of the input up to the next multiple of 16 and crops the output
    back, so any input size is accepted; inputs that are already aligned are
    processed exactly as before.

    Args:
        n_channels (int): Number of image channels (input and output).
        width (int): Number of feature channels used inside the network.
    """

    def __init__(self, n_channels=3, width=32):
        super(BaselineModel, self).__init__()
        self.init_conv = nn.Conv2d(n_channels, width, kernel_size=3, padding=1)

        # Encoder (Downsampling path)
        self.enc1 = self._make_stage(width, 4)
        self.down1 = nn.Conv2d(width, width, kernel_size=3, stride=2, padding=1)

        self.enc2 = self._make_stage(width, 4)
        self.down2 = nn.Conv2d(width, width, kernel_size=3, stride=2, padding=1)

        self.enc3 = self._make_stage(width, 4)
        self.down3 = nn.Conv2d(width, width, kernel_size=3, stride=2, padding=1)

        self.enc4 = self._make_stage(width, 4)
        self.down4 = nn.Conv2d(width, width, kernel_size=3, stride=2, padding=1)

        # Bottleneck
        self.bottleneck = self._make_stage(width, 4)

        # Decoder (Upsampling path)
        self.up4 = self._upsample_layer(width)
        self.dec4 = self._make_stage(width, 4)

        self.up3 = self._upsample_layer(width)
        self.dec3 = self._make_stage(width, 4)

        self.up2 = self._upsample_layer(width)
        self.dec2 = self._make_stage(width, 4)

        self.up1 = self._upsample_layer(width)
        self.dec1 = self._make_stage(width, 4)

        # Final output layer
        self.final_conv = nn.Conv2d(width, n_channels, kernel_size=3, padding=1)

    def _make_stage(self, channels, num_blocks):
        """Helper function to create multiple BaselineBlocks."""
        return nn.Sequential(*[BaselineBlock(channels) for _ in range(num_blocks)])

    def _upsample_layer(self, channels):
        """Upsample using pointwise convolution followed by pixel shuffle."""
        # The 1x1 conv produces 4x the channels; PixelShuffle(2) trades them
        # for a 2x larger spatial grid, leaving `channels` channels.
        return nn.Sequential(
            nn.Conv2d(channels, channels * 4, kernel_size=1),
            nn.PixelShuffle(2)
        )

    def forward(self, x):
        """
        Run the network on a batch of images.

        Args:
            x (torch.Tensor): Batch of shape (N, n_channels, H, W).

        Returns:
            torch.Tensor: Output of shape (N, n_channels, H, W).
        """
        height, width = x.shape[-2], x.shape[-1]

        # Four stride-2 steps: sizes must be multiples of 2**4 for the skips to match.
        multiple = 16
        pad_h = (multiple - height % multiple) % multiple
        pad_w = (multiple - width % multiple) % multiple
        if pad_h or pad_w:
            # Pad order is (left, right, top, bottom) for the last two dims.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h))

        x = self.init_conv(x)

        # Encoder: keep each stage's output for the matching decoder level.
        e1 = self.enc1(x)
        x = self.down1(e1)

        e2 = self.enc2(x)
        x = self.down2(e2)

        e3 = self.enc3(x)
        x = self.down3(e3)

        e4 = self.enc4(x)
        x = self.down4(e4)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder: upsample, merge the skip by addition, refine.
        x = self.up4(x) + e4
        x = self.dec4(x)

        x = self.up3(x) + e3
        x = self.dec3(x)

        x = self.up2(x) + e2
        x = self.dec2(x)

        x = self.up1(x) + e1
        x = self.dec1(x)

        # Final output
        x = self.final_conv(x)

        # Drop the rows/columns that were only added for alignment.
        if pad_h or pad_w:
            x = x[:, :, :height, :width]
        return x


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
# from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import CosineAnnealingLR

# Define PSNR Loss
def psnr_loss(pred, target, max_val=1.0, eps=1e-8):
    """
    Negative PSNR (in dB) between ``pred`` and ``target``, for use as a loss.

    PSNR is computed from the mean squared error over all elements as
    ``10 * log10(max_val**2 / (mse + eps))``. The small ``eps`` keeps the loss
    and its gradient finite when the prediction matches the target exactly;
    without it the MSE of zero leads to a division by zero.

    Args:
        pred (torch.Tensor): Predicted images.
        target (torch.Tensor): Reference images, same shape as ``pred``.
        max_val (float): Largest possible pixel value (1.0 for images in [0, 1]).
        eps (float): Stabiliser added to the MSE; caps the PSNR at
            ``10 * log10(max_val**2 / eps)``.

    Returns:
        torch.Tensor: Scalar tensor equal to ``-PSNR``, so minimising it
        maximises PSNR.
    """
    mse = nn.functional.mse_loss(pred, target)
    # 10*log10(max^2 / mse) equals 20*log10(max / sqrt(mse)) but avoids the
    # square root, whose gradient is unbounded at zero.
    psnr = 10 * torch.log10(max_val ** 2 / (mse + eps))
    return -psnr  # Negative because we minimize loss

# Dataset (Dummy Dataset Example)
class RandomDataset(data.Dataset):
    """
    Placeholder dataset that yields random images paired with identical targets.

    Useful for smoke-testing a training loop without real data. A flip
    transform is built and stored on the instance, but ``__getitem__`` does
    not apply it.

    Args:
        size (int): Number of samples the dataset reports.
        img_size (tuple): Shape of each generated image tensor.
    """

    def __init__(self, size=1000, img_size=(3, 256, 256)):
        self.size, self.img_size = size, img_size
        flips = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ]
        self.transform = transforms.Compose(flips)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Fresh uniform noise on every access; idx does not influence the sample.
        sample = torch.rand(self.img_size)
        # The target is a separate copy of the same values (identity mapping).
        return sample, sample.clone()

# Hyperparameters
batch_size = 8
num_iterations = 10000  # Total optimizer steps (iterations, not epochs)
learning_rate = 1e-3  # Initial learning rate for the cosine schedule
min_lr = 1e-6  # Floor the cosine schedule anneals to
width = 32  # Model width parameter

# Prepare DataLoader
# Uses the paired patch folders on disk; the RandomDataset defined above is not
# used by this script.
dataset = PairedPatchDataset(train_input_patches_folder,train_target_patches_folder)
dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model Initialization
model = BaselineModel(3, width).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.9), weight_decay=0)
# The scheduler is stepped once per iteration below, so T_max=num_iterations
# makes the learning rate reach min_lr at the end of the run.
scheduler = CosineAnnealingLR(optimizer, T_max=num_iterations, eta_min=min_lr)

# TensorBoard Writer
# writer = SummaryWriter()

# Training Loop
# Training is counted in iterations: the outer `while` re-iterates the
# dataloader whenever it is exhausted, and the inner loop breaks out as soon as
# num_iterations steps have been taken.
# NOTE(review): no checkpoint is written in this cell; the trained weights live
# only in `model`.
model.train()
iteration = 0
while iteration < num_iterations:
    for imgs, targets in dataloader:
        imgs, targets = imgs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = psnr_loss(outputs, targets)  # Negative PSNR of the current batch
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log Metrics
        if iteration % 100 == 0:
            # The loss is -PSNR, so flip the sign to report PSNR in dB. This is
            # the PSNR of this single training batch, not a validation metric.
            psnr_value = -loss.item()
            #writer.add_scalar('Loss/PSNR', psnr_value, iteration)
            print(f"Iteration {iteration}: PSNR = {psnr_value:.2f} dB")

        iteration += 1
        if iteration >= num_iterations:
            break

#writer.close()
print("Training complete.")


Iteration 0: PSNR = 2.17 dB
Iteration 5: PSNR = 12.27 dB


KeyboardInterrupt: 

In [None]:
import lmdb
import cv2
import numpy as np
import os

def extract_images_from_lmdb(lmdb_folder, output_folder):
    """
    Extracts PNG images from an LMDB database and saves them to the specified folder.

    Every record's key is used as the output file name (with ``.png`` appended
    when missing) and its value is decoded as an encoded image and re-written
    with OpenCV. Records that cannot be decoded are reported and skipped.

    Args:
        lmdb_folder (str): Path to the LMDB folder containing data.mdb and lock.mdb.
        output_folder (str): Path to the folder where extracted images will be saved
            (created if missing).
    """
    # Create the destination once instead of once per record.
    os.makedirs(output_folder, exist_ok=True)

    # Open LMDB environment read-only and without file locking.
    env = lmdb.open(lmdb_folder, readonly=True, lock=False)
    try:
        with env.begin() as txn:
            for key, value in txn.cursor():
                # Decode the key (filename)
                filename = key.decode("utf-8")

                # Ensure filename has .png extension
                if not filename.endswith(".png"):
                    filename += ".png"

                # Interpret the stored bytes as an encoded image; IMREAD_UNCHANGED
                # keeps the stored channel count and bit depth.
                img_array = np.frombuffer(value, dtype=np.uint8)
                img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)

                if img is None:
                    print(f"Warning: Could not decode image for key {filename}, skipping...")
                    continue  # Skip if decoding fails

                # Save the image
                output_path = os.path.join(output_folder, filename)
                cv2.imwrite(output_path, img)
                print(f"Saved {output_path}")
    finally:
        # Release the environment even if decoding or writing fails part-way.
        env.close()

    print("Extraction complete!")


In [None]:
# Unpack the test split from its two LMDB databases on Drive into the local
# image folders created earlier.
extract_images_from_lmdb(test_input_lmdb_folder, test_input_images_folder)
extract_images_from_lmdb(test_target_lmdb_folder, test_target_images_folder)

In [None]:
# Cut the extracted test images into patches with the same routine and naming
# scheme used for the training images.
save_patches(test_input_images_folder, test_input_patches_folder, 'input')
save_patches(test_target_images_folder, test_target_patches_folder, 'target')