In [1]:
import os
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, Resize, Compose

class DIV2KDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_images = sorted(os.listdir(lr_dir))
        self.hr_images = sorted(os.listdir(hr_dir))
        self.transform = transform

    def __len__(self):
        return len(self.lr_images)

    def __getitem__(self, idx):
        lr_path = os.path.join(self.lr_dir, self.lr_images[idx])
        hr_path = os.path.join(self.hr_dir, self.hr_images[idx])

        lr_image = Image.open(lr_path).convert("RGB")
        hr_image = Image.open(hr_path).convert("RGB")

        if self.transform:
            lr_image = self.transform(lr_image)
            hr_image = self.transform(hr_image)

        return lr_image, hr_image

lr_dir = "DIV2K_train/LR"
hr_dir = "DIV2K_train/HR"

image_size = (128,128)
transform = Compose([
    Resize(image_size),
    ToTensor()
])

train_dataset = DIV2KDataset(lr_dir, hr_dir, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)

In [2]:
import torch.nn as nn

class RRDB(nn.Module):
    def __init__(self, in_channels, growth_channels=64):
        super(RRDB, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, growth_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(growth_channels, growth_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(growth_channels, in_channels, kernel_size=3, stride=1, padding=1)
        )
        self.res_scale = 0.2  # Residual scaling factor

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        out = self.block3(out)
        return x + out * self.res_scale


class Generator(nn.Module):
    def __init__(self, num_rrdb=16, in_channels=3, out_channels=3, feature_channels=64):
        super(Generator, self).__init__()
        self.initial = nn.Conv2d(in_channels, feature_channels, kernel_size=3, stride=1, padding=1)
        self.rrdb_blocks = nn.Sequential(*[RRDB(feature_channels) for _ in range(num_rrdb)])
        self.conv_mid = nn.Conv2d(feature_channels, feature_channels, kernel_size=3, stride=1, padding=1)
        self.upsampling = nn.Sequential(
            nn.Conv2d(feature_channels, feature_channels * 4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feature_channels, feature_channels * 4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.final = nn.Conv2d(feature_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        initial_features = self.initial(x)
        rrdb_features = self.rrdb_blocks(initial_features)
        mid_features = self.conv_mid(rrdb_features)
        upscaled_features = self.upsampling(mid_features + initial_features)
        output = self.final(upscaled_features)
        return output

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Flatten(),
            nn.Linear(512 * 4 * 4, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        return self.model(x)

generator = Generator().cuda()
discriminator = Discriminator().cuda()


In [3]:
from diffusers import UNet2DModel

diffusion_model = UNet2DModel(
    sample_size=128,           
    in_channels=3,              
    out_channels=3,             
    block_out_channels=(128, 256, 512, 1024),  
    layers_per_block=2,         
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
).cuda()


In [5]:
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor


g_optimizer = Adam(generator.parameters(), lr=1e-4)
d_optimizer = Adam(discriminator.parameters(), lr=1e-4)
diffusion_optimizer = Adam(diffusion_model.parameters(), lr=1e-5)

for epoch in range(10):
    for lr_images, _ in train_loader:
        lr_images = lr_images.cuda()
        
        sr_images = generator(lr_images)
        d_loss = discriminator(sr_images)  
        g_loss = generator(sr_images)     
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        d_loss.backward()
        g_loss.backward()
        d_optimizer.step()
        g_optimizer.step()
        
        refined_images = diffusion_model(sr_images)
        loss = nn.MSELoss()(refined_images, lr_images)  
        diffusion_optimizer.zero_grad()
        loss.backward()
        diffusion_optimizer.step()


RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x524288 and 8192x1024)

In [None]:
import matplotlib.pyplot as plt

test_image = Image.open("data/DIV2K_train_LR/001.png")
test_image = transform(test_image).unsqueeze(0).cuda()

with torch.no_grad():
    sr_image = generator(test_image)
    refined_image = diffusion_model(sr_image)

plt.figure(figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.title("Low-Resolution")
plt.imshow(test_image.cpu().squeeze().permute(1, 2, 0))

plt.subplot(1, 3, 2)
plt.title("GAN Output")
plt.imshow(sr_image.cpu().squeeze().permute(1, 2, 0))

plt.subplot(1, 3, 3)
plt.title("Diffusion Refinement")
plt.imshow(refined_image.cpu().squeeze().permute(1, 2, 0))
plt.show()
