# Image deblurring with a DeblurGAN-style GAN

Based on DeblurGAN: https://github.com/KupynOrest/DeblurGAN

In [None]:
# Install gdown and download the paired blurred/sharp dataset archive from Google Drive
# (saved as blurred_sharp.zip in the working directory, /content on Colab — see output below).
# NOTE(review): recent gdown releases accept the file ID without `--id`; confirm the flag is
# still supported by the installed version.
!pip install gdown
!gdown --id 1frzgOHPrw0RnOGTtnrLA8W7rjJox1E3T

Downloading...
From (original): https://drive.google.com/uc?id=1frzgOHPrw0RnOGTtnrLA8W7rjJox1E3T
From (redirected): https://drive.google.com/uc?id=1frzgOHPrw0RnOGTtnrLA8W7rjJox1E3T&confirm=t&uuid=97aa7c32-3f01-4207-9ad9-2bca6af27e30
To: /content/blurred_sharp.zip
100% 2.28G/2.28G [00:23<00:00, 98.1MB/s]


In [None]:
# Unzip the downloaded archive into /content (creates /content/blurred_sharp)
import zipfile

archive_path = "blurred_sharp.zip"
extract_root = "/content"
with zipfile.ZipFile(archive_path, mode="r") as archive:
    archive.extractall(path=extract_root)

### Split data

In [None]:
import os
import shutil
from sklearn.model_selection import train_test_split

In [None]:
# Paths
base_dir = "/content/blurred_sharp"
blurred_dir = os.path.join(base_dir, "blurred")
sharp_dir = os.path.join(base_dir, "sharp")

# Names of the split sub-directories created under base_dir
splits = ["train", "val", "test"]

# Get all file names; a blurred image and its sharp target share the same file name.
blurred_files = sorted(os.listdir(blurred_dir))
sharp_files = sorted(os.listdir(sharp_dir))

# Ensure both directories contain the same number of files with matching names
assert len(blurred_files) == len(sharp_files), "Blurred and sharp directories must have the same number of files!"
for b, s in zip(blurred_files, sharp_files):
    assert b == s, f"File names do not match: {b} and {s}"

# Split into train (70%), val (15%) and test (15%). Both lists are passed together so
# the same permutation is applied to each and pairs stay aligned.
train_blurred, temp_blurred, train_sharp, temp_sharp = train_test_split(
    blurred_files, sharp_files, test_size=0.3, random_state=42
)
val_blurred, test_blurred, val_sharp, test_sharp = train_test_split(
    temp_blurred, temp_sharp, test_size=0.5, random_state=42
)


def copy_files(file_list, source_dir, target_dir):
    """Copy every file named in `file_list` from `source_dir` into `target_dir` (same name)."""
    for file in file_list:
        shutil.copy(os.path.join(source_dir, file), os.path.join(target_dir, file))


split_assignments = {
    "train": (train_blurred, train_sharp),
    "val": (val_blurred, val_sharp),
    "test": (test_blurred, test_sharp),
}

for split in splits:
    split_blurred, split_sharp = split_assignments[split]
    for file_list, source_dir, kind in (
        (split_blurred, blurred_dir, "blurred"),
        (split_sharp, sharp_dir, "sharp"),
    ):
        target_dir = os.path.join(base_dir, split, kind)
        # Recreate the folder empty: re-running this cell (e.g. with a different seed or
        # ratio) must not leave files from an earlier split behind, which would leak
        # training images into val/test.
        if os.path.isdir(target_dir):
            shutil.rmtree(target_dir)
        os.makedirs(target_dir)
        copy_files(file_list, source_dir, target_dir)

print("Dataset successfully split into train, val, and test sets.")

Dataset successfully split into train, val, and test sets.


### Prepare data

In [None]:
from torchvision import transforms

# Deterministic preprocessing shared by blurred and sharp images:
# resize to a fixed 128x128, convert to a float tensor in [0, 1],
# then map each RGB channel to [-1, 1] via (x - 0.5) / 0.5.
IMAGE_SIZE = (128, 128)
CHANNEL_MEAN = (0.5, 0.5, 0.5)
CHANNEL_STD = (0.5, 0.5, 0.5)

image_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(CHANNEL_MEAN, CHANNEL_STD),
    ]
)

In [None]:
from PIL import Image
from torch.utils.data import Dataset

class DeblurDataset(Dataset):
    """Paired deblurring dataset.

    Item `idx` is the (blurred, sharp) pair of images that share the same file
    name in `blurred_dir` and `sharp_dir`, optionally passed through `transform`.
    """

    def __init__(self, blurred_dir, sharp_dir, transform=None):
        """
        Args:
            blurred_dir (str): Path to the directory containing blurred images.
            sharp_dir (str): Path to the directory containing sharp images.
            transform (callable, optional): Optional transform applied to both images.

        Raises:
            FileNotFoundError: if a file in `blurred_dir` has no same-named file in `sharp_dir`.
        """
        self.blurred_dir = blurred_dir
        self.sharp_dir = sharp_dir
        self.transform = transform
        # Sorted so that an index always maps to the same file, independent of OS listing order.
        self.image_names = sorted(os.listdir(blurred_dir))

        # Fail at construction time instead of crashing mid-epoch on an incomplete pair.
        missing = [
            name for name in self.image_names
            if not os.path.isfile(os.path.join(sharp_dir, name))
        ]
        if missing:
            raise FileNotFoundError(
                f"{len(missing)} blurred image(s) have no sharp counterpart in {sharp_dir}, "
                f"e.g. {missing[0]}"
            )

    def __len__(self):
        return len(self.image_names)

    @staticmethod
    def _load_rgb(path):
        """Load the image at `path` as RGB and release its file handle."""
        # Image.open is lazy and keeps the file open; convert() returns a new, fully
        # loaded image, so the original can be closed right away.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        name = self.image_names[idx]
        blurred_image = self._load_rgb(os.path.join(self.blurred_dir, name))
        sharp_image = self._load_rgb(os.path.join(self.sharp_dir, name))

        # Apply transformations if provided (the same transform is used for both images)
        if self.transform is not None:
            blurred_image = self.transform(blurred_image)
            sharp_image = self.transform(sharp_image)

        return blurred_image, sharp_image

In [None]:
from torch.utils.data import DataLoader

# Root of the split dataset produced in the "Split data" section
DATA_ROOT = "/content/blurred_sharp"
BATCH_SIZE = 16

# Per-split image folders
train_blurred_dir = f"{DATA_ROOT}/train/blurred"
train_sharp_dir = f"{DATA_ROOT}/train/sharp"
val_blurred_dir = f"{DATA_ROOT}/val/blurred"
val_sharp_dir = f"{DATA_ROOT}/val/sharp"
test_blurred_dir = f"{DATA_ROOT}/test/blurred"
test_sharp_dir = f"{DATA_ROOT}/test/sharp"

# One dataset per split, all sharing the same deterministic preprocessing
train_dataset, val_dataset, test_dataset = (
    DeblurDataset(b_dir, s_dir, transform=image_transform)
    for b_dir, s_dir in (
        (train_blurred_dir, train_sharp_dir),
        (val_blurred_dir, val_sharp_dir),
        (test_blurred_dir, test_sharp_dir),
    )
)

# Only the training split is shuffled; val/test keep a fixed order
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Check: inspect one training batch (shapes and value range) instead of dumping raw tensors.
blurred, sharp = next(iter(train_loader))
print(f"Blurred Image Tensor Shape: {blurred.shape}")
print(f"Sharp Image Tensor Shape: {sharp.shape}")
print(f"Blurred value range: [{blurred.min().item():.4f}, {blurred.max().item():.4f}]")
print(f"Sharp value range: [{sharp.min().item():.4f}, {sharp.max().item():.4f}]")

# Pairs must line up, and Normalize((0.5,)*3, (0.5,)*3) should put every pixel in [-1, 1].
assert blurred.shape == sharp.shape, "Blurred and sharp batches must have the same shape"
for batch_name, batch in (("blurred", blurred), ("sharp", sharp)):
    assert -1.0 <= batch.min().item() and batch.max().item() <= 1.0, f"{batch_name} batch is outside [-1, 1]"

Batch 1:
Blurred Image Tensor Shape: torch.Size([16, 3, 128, 128])
Sharp Image Tensor Shape: torch.Size([16, 3, 128, 128])
Blurred Image (first in batch): tensor([[[ 0.2784,  0.2784,  0.2941,  ...,  0.5608,  0.4667,  0.3882],
         [ 0.2078,  0.2157,  0.2235,  ...,  0.6157,  0.5294,  0.4431],
         [ 0.1922,  0.1922,  0.2000,  ...,  0.6235,  0.5216,  0.4118],
         ...,
         [-0.7882, -0.7882, -0.8039,  ..., -0.8196, -0.8118, -0.8118],
         [-0.7882, -0.7882, -0.7961,  ..., -0.8275, -0.8353, -0.8353],
         [-0.7961, -0.7961, -0.7961,  ..., -0.8196, -0.8275, -0.8196]],

        [[ 0.5686,  0.5765,  0.5765,  ...,  0.4902,  0.4902,  0.4902],
         [ 0.5294,  0.5373,  0.5451,  ...,  0.5451,  0.5451,  0.5373],
         [ 0.5137,  0.5216,  0.5294,  ...,  0.5608,  0.5451,  0.5216],
         ...,
         [-0.7804, -0.7804, -0.7882,  ..., -0.8118, -0.8039, -0.8039],
         [-0.7804, -0.7804, -0.7804,  ..., -0.8275, -0.8353, -0.8353],
         [-0.7882, -0.7882, -0.780

### Create model

In [None]:
import torch
import torch.nn as nn

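#### Discriminator

A convolutional discriminator that outputs a grid of real/fake scores (one per image patch) rather than a single score per image.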

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a spatial map of real/fake scores.

    The output is a 1-channel map rather than one scalar per image; each output
    location scores a patch of the input.
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_sigmoid=False):
        """
        Args:
            input_nc (int): Number of input channels (e.g., 3 for RGB images).
            ndf (int): Number of filters in the first layer.
            n_layers (int): Number of stride-2 (downsampling) convolutions when
                n_layers >= 1; the channel count doubles at each, capped at ndf * 8.
                They are followed by one stride-1 feature convolution and a
                1-channel output convolution (n_layers + 2 convolutions in total).
            use_sigmoid (bool): Whether to apply a sigmoid so scores lie in (0, 1).
                Required when the scores are fed to nn.BCELoss.

        Raises:
            ValueError: if n_layers is negative.
        """
        super(Discriminator, self).__init__()
        if n_layers < 0:
            raise ValueError(f"n_layers must be non-negative, got {n_layers}")

        layers = [
            # Initial layer: no normalization directly on the image input.
            nn.Sequential(
                nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            )
        ]

        # Intermediate stride-2 layers; nf_mult scales the number of filters.
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
            layers.append(self._conv_bn_lrelu(ndf * nf_mult_prev, ndf * nf_mult, stride=2))

        # Final stride-1 feature layer before the output.
        nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
        layers.append(self._conv_bn_lrelu(ndf * nf_mult_prev, ndf * nf_mult, stride=1))

        # Output layer: one score channel per spatial location.
        layers.append(nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1))
        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    @staticmethod
    def _conv_bn_lrelu(in_channels, out_channels, stride):
        """4x4 conv -> BatchNorm -> LeakyReLU(0.2).

        Kept as a nested Sequential so state_dict keys stay `model.<i>.<j>.*`
        and previously saved checkpoints still load.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        """Return the score map for a batch of images."""
        return self.model(x)


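#### Generator

A ResNet-style generator: two stride-2 downsampling convolutions, a stack of residual blocks, then two transposed convolutions back to the input resolution, with a Tanh output in [-1, 1].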

In [None]:
class ResnetBlock(nn.Module):
    """Residual block computing x + conv_block(x).

    conv_block is: pad -> conv3x3 -> norm -> ReLU [-> dropout] -> pad -> conv3x3 -> norm,
    so the spatial size and channel count (`dim`) are preserved.
    """

    def __init__(self, dim, padding_type='reflect', norm_layer=nn.BatchNorm2d, use_dropout=False, use_bias=True):
        """
        Args:
            dim: Number of input/output channels.
            padding_type: 'reflect' (default), 'replicate' or 'zero'.
            norm_layer: Normalization layer constructor, called with `dim`.
            use_dropout: Insert Dropout(0.5) between the two convolutions.
            use_bias: Whether the convolutions have a bias term.

        Raises:
            NotImplementedError: for an unknown padding_type.
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding_layers(padding_type):
        """Return (explicit padding modules, conv padding) for a 3x3 conv that keeps size."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            # Let the convolution itself zero-pad.
            return [], 1
        raise NotImplementedError(f"Padding type {padding_type} is not implemented.")

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        # Fresh padding modules for each conv; for 'reflect' the layer order (and hence
        # state_dict keys) is the same as pad, conv, norm, relu, [dropout], pad, conv, norm.
        pad_layers, conv_padding = self._padding_layers(padding_type)
        conv_block = list(pad_layers)
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True)
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        pad_layers, conv_padding = self._padding_layers(padding_type)
        conv_block += pad_layers
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding, bias=use_bias),
            norm_layer(dim)
        ]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)  # Add skip connection


class ResnetGenerator(nn.Module):
    """ResNet-style image-to-image generator: encoder -> residual blocks -> decoder."""

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=6, use_dropout=False, learn_residual=False):
        """
        The input is downsampled twice (stride 2) and upsampled twice, so its height
        and width must be divisible by 4 for the output to have the input's size.
        The output passes through Tanh, i.e. lies in [-1, 1].

        Args:
            input_nc: Number of input channels (e.g., 3 for RGB).
            output_nc: Number of output channels (e.g., 3 for RGB).
            ngf: Number of filters in the first layer.
            n_blocks: Number of ResNet blocks at the bottleneck.
            use_dropout: Passed to every ResnetBlock (dropout between its two convs).
            learn_residual: If True, add the input to the output and clamp to [-1, 1]
                (a global skip connection, as in the DeblurGAN reference code).
                Requires input_nc == output_nc.

        Raises:
            ValueError: if learn_residual is set and input_nc != output_nc.
        """
        super(ResnetGenerator, self).__init__()
        if learn_residual and input_nc != output_nc:
            raise ValueError("learn_residual requires input_nc == output_nc")
        self.learn_residual = learn_residual

        # Initial 7x7 convolution at full resolution
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True)
        ]

        # Downsampling: H, W -> H/4, W/4; channels ngf -> 4 * ngf
        model += [
            nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True)
        ]

        # ResNet blocks at the bottleneck resolution
        for _ in range(n_blocks):
            model += [ResnetBlock(ngf * 4, use_dropout=use_dropout)]

        # Upsampling: back to H, W; output_padding=1 makes each step exactly double the size
        model += [
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True)
        ]

        # Final convolution; Tanh matches the [-1, 1] range of the normalized targets
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        output = self.model(x)
        if self.learn_residual:
            # Global skip: the network predicts a correction added to the input image.
            output = torch.clamp(x + output, min=-1, max=1)
        return output

## Training

### Train GAN

#### Loss functions and optimizers

In [None]:
# --- Define Custom Loss Functions ---

class AdversarialLoss(nn.Module):
    """Binary cross-entropy against an all-real (1) or all-fake (0) target.

    `predictions` must already be probabilities in [0, 1]; nn.BCELoss does not
    apply a sigmoid itself.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.BCELoss()

    def forward(self, predictions, is_real):
        # Target has the same shape/dtype/device as the predictions, filled with 1 or 0.
        fill_value = 1.0 if is_real else 0.0
        target = torch.full_like(predictions, fill_value)
        return self.criterion(predictions, target)

class ContentLoss(nn.Module):
    """Mean absolute error (L1) between the generated and ground-truth images."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, fake_image, real_image):
        l1 = self.criterion(fake_image, real_image)
        return l1

class PerceptualLoss(nn.Module):
    """MSE between frozen VGG19 feature maps (up to conv3_3) of two images.

    Inputs are expected in [-1, 1] (the range produced by the dataset's
    Normalize(0.5, 0.5) and by the generator's Tanh). They are mapped to the
    [0, 1], ImageNet-normalized input that torchvision's pretrained VGG expects.
    """

    # ImageNet statistics used to train torchvision's pretrained VGG weights.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)
    # In vgg19().features index 14 is conv3_3 (13 is relu3_2), so slicing [:15]
    # keeps layers 0..14 and ends at conv3_3, as in the DeblurGAN reference.
    CONV3_3_END = 15

    def __init__(self):
        super(PerceptualLoss, self).__init__()
        # Local import so the class does not depend on a later cell's `import torchvision.models`.
        import torchvision.models as models

        try:
            # torchvision >= 0.13 weights API; same ImageNet weights as pretrained=True.
            vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        except AttributeError:
            # Older torchvision without the weights enums.
            vgg = models.vgg19(pretrained=True)
        self.vgg = vgg.features[:self.CONV3_3_END].eval()
        for param in self.vgg.parameters():
            param.requires_grad = False  # Freeze VGG
        self.criterion = nn.MSELoss()
        self.register_buffer("mean", torch.tensor(self.IMAGENET_MEAN).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(self.IMAGENET_STD).view(1, 3, 1, 1))

    def _to_vgg_input(self, image):
        """Map an image batch from [-1, 1] to ImageNet-normalized VGG input."""
        # Follow the input's device/dtype even if only `self.vgg` was moved to the GPU.
        mean = self.mean.to(device=image.device, dtype=image.dtype)
        std = self.std.to(device=image.device, dtype=image.dtype)
        return ((image + 1) / 2 - mean) / std

    def forward(self, fake_image, real_image):
        fake_features = self.vgg(self._to_vgg_input(fake_image))
        # The target branch needs no autograd graph.
        with torch.no_grad():
            real_features = self.vgg(self._to_vgg_input(real_image))
        return self.criterion(fake_features, real_features)

class GANLoss:
    """Combines the adversarial, L1 content and optional perceptual losses.

    Generator objective:
        adversarial(D(fake) -> real) + content_weight * L1 + perceptual_weight * perceptual
    Discriminator objective:
        0.5 * (adversarial(D(real) -> real) + adversarial(D(fake) -> fake))
    """

    def __init__(self, use_perceptual_loss=False, content_weight=10.0, perceptual_weight=0.1):
        """
        Args:
            use_perceptual_loss (bool): Whether to build the VGG-based PerceptualLoss.
            content_weight (float): Weight of the L1 content loss in the generator objective.
            perceptual_weight (float): Weight of the perceptual loss (ignored when disabled).
        """
        self.adversarial_loss = AdversarialLoss()
        self.content_loss = ContentLoss()
        self.perceptual_loss = PerceptualLoss() if use_perceptual_loss else None
        self.content_weight = content_weight
        self.perceptual_weight = perceptual_weight

    def generator_loss(self, discriminator, fake_image, real_image):
        """Total generator loss for a batch of generated images and their sharp targets.

        Gradients from this call also accumulate in the discriminator's parameters;
        zero them before the next discriminator step.
        """
        # Adversarial loss: fool the discriminator
        g_fake_loss = self.adversarial_loss(discriminator(fake_image), is_real=True)

        # Content loss: match the ground truth sharp image
        g_content_loss = self.content_loss(fake_image, real_image)

        total_loss = g_fake_loss + self.content_weight * g_content_loss

        # Optional: Perceptual loss
        if self.perceptual_loss is not None:
            g_perceptual_loss = self.perceptual_loss(fake_image, real_image)
            total_loss = total_loss + self.perceptual_weight * g_perceptual_loss
        return total_loss

    def discriminator_loss(self, discriminator, fake_image, real_image):
        """Discriminator loss; `fake_image` is detached so the generator receives no gradient."""
        # Real images should be classified as real
        d_real_loss = self.adversarial_loss(discriminator(real_image), is_real=True)

        # Fake images should be classified as fake
        d_fake_loss = self.adversarial_loss(discriminator(fake_image.detach()), is_real=False)

        # Average the two terms
        return (d_real_loss + d_fake_loss) * 0.5

In [None]:
import torchvision.models as models

# Seed torch's global RNG (CPU and CUDA) so weight initialization and the
# DataLoader shuffle order are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models directly on the target device
generator = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=6).to(device)
discriminator = Discriminator(input_nc=3, ndf=64, n_layers=3, use_sigmoid=True).to(device)

# Loss functions (use_perceptual_loss=True builds the pretrained VGG19 feature extractor)
gan_loss = GANLoss(use_perceptual_loss=True)
if gan_loss.perceptual_loss is not None:
    # Move the whole loss module so its submodules and any registered buffers follow.
    gan_loss.perceptual_loss.to(device)

# Optimizers, created after the models are on their final device
lr = 0.0002
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

In [None]:
save_dir = "/content/"

# --- Training Loop ---
num_epochs = 300
# Numbered checkpoints are written every `checkpoint_every` epochs (and at the last epoch).
# Saving both state_dicts every epoch (~40 MB per epoch at the default model sizes) for
# 300 epochs would use many GB of disk.
checkpoint_every = 10
log_every = 100  # print progress every N batches

for epoch in range(num_epochs):
    # The "Test the Generator" cell calls generator.eval(); switch back so re-running
    # this cell trains with BatchNorm batch statistics rather than frozen running stats.
    generator.train()
    discriminator.train()

    for i, (blurred, sharp) in enumerate(train_loader):
        blurred, sharp = blurred.to(device), sharp.to(device)

        # --- Train Discriminator ---
        # discriminator_loss detaches fake_sharp, so this step leaves the generator graph intact.
        discriminator_optimizer.zero_grad()
        fake_sharp = generator(blurred)
        d_loss = gan_loss.discriminator_loss(discriminator, fake_sharp, sharp)
        d_loss.backward()
        discriminator_optimizer.step()

        # --- Train Generator ---
        # Reuses fake_sharp; the adversarial term re-runs the just-updated discriminator.
        generator_optimizer.zero_grad()
        g_loss = gan_loss.generator_loss(discriminator, fake_sharp, sharp)
        g_loss.backward()
        generator_optimizer.step()

        if i % log_every == 0:
            print(f"[Epoch {epoch+1}/{num_epochs}] [Batch {i}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    epoch_num = epoch + 1
    # Always keep the most recent weights on disk (overwritten every epoch).
    torch.save(generator.state_dict(), os.path.join(save_dir, "generator_latest.pth"))
    torch.save(discriminator.state_dict(), os.path.join(save_dir, "discriminator_latest.pth"))
    if epoch_num % checkpoint_every == 0 or epoch_num == num_epochs:
        torch.save(generator.state_dict(), os.path.join(save_dir, f"generator_epoch_{epoch_num}.pth"))
        torch.save(discriminator.state_dict(), os.path.join(save_dir, f"discriminator_epoch_{epoch_num}.pth"))
        print(f"Models saved for epoch {epoch_num}.")

[Epoch 1/300] [Batch 0] | D Loss: 0.7139 | G Loss: 7.4828
Models saved for epoch 1.
[Epoch 2/300] [Batch 0] | D Loss: 0.6881 | G Loss: 2.1054
Models saved for epoch 2.
[Epoch 3/300] [Batch 0] | D Loss: 0.6344 | G Loss: 2.3902
Models saved for epoch 3.
[Epoch 4/300] [Batch 0] | D Loss: 0.6673 | G Loss: 2.1518
Models saved for epoch 4.
[Epoch 5/300] [Batch 0] | D Loss: 0.7160 | G Loss: 2.3804
Models saved for epoch 5.
[Epoch 6/300] [Batch 0] | D Loss: 0.7050 | G Loss: 3.3896
Models saved for epoch 6.
[Epoch 7/300] [Batch 0] | D Loss: 0.6833 | G Loss: 1.6047
Models saved for epoch 7.
[Epoch 8/300] [Batch 0] | D Loss: 0.6801 | G Loss: 1.8144
Models saved for epoch 8.
[Epoch 9/300] [Batch 0] | D Loss: 0.6799 | G Loss: 1.6974
Models saved for epoch 9.
[Epoch 10/300] [Batch 0] | D Loss: 0.7119 | G Loss: 1.8714
Models saved for epoch 10.
[Epoch 11/300] [Batch 0] | D Loss: 0.6888 | G Loss: 1.9843
Models saved for epoch 11.
[Epoch 12/300] [Batch 0] | D Loss: 0.7020 | G Loss: 1.8846
Models saved 

### Test the Generator

In [None]:
import os
from torchvision.utils import save_image

from itertools import islice

# How many test batches to export as PNG files
num_batches_to_save = 4


def to_unit_range(images):
    """Map tensors in [-1, 1] (dataset Normalize(0.5, 0.5) / generator Tanh) back to [0, 1]."""
    return (images * 0.5 + 0.5).clamp(0.0, 1.0)


# Set generator to evaluation mode (BatchNorm uses its running statistics)
generator.eval()
with torch.no_grad():
    for i, (blurred, sharp) in enumerate(islice(test_loader, num_batches_to_save)):
        blurred, sharp = blurred.to(device), sharp.to(device)

        # Generate fake (deblurred) images
        fake_sharp = generator(blurred)

        # save_image expects values in [0, 1]. Saving the raw [-1, 1] tensors would clip
        # every negative value (original pixels darker than mid-gray) to black.
        save_image(to_unit_range(blurred), f"blurred_{i}.png")
        save_image(to_unit_range(fake_sharp), f"fake_sharp_{i}.png")
        save_image(to_unit_range(sharp), f"real_sharp_{i}.png")