In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import vgg19
import torch.nn.functional as F
import os
import gc
import time
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import vgg19
from torchvision.utils import save_image
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

torch.backends.cudnn.benchmark = True
from torchvision import transforms

2024-06-30 22:42:06.543875: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-30 22:42:06.544001: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-30 22:42:06.680118: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Move Data


In [3]:
# import os
# import shutil

# # Source directory
# source_dir = '/kaggle/input/flickr2k/Flickr2K/Flickr2K_HR/'

# # Destination directory
# destination_dir = '/kaggle/working/dataset'

# # Ensure destination directory exists
# os.makedirs(destination_dir, exist_ok=True)

# # Get list of files in source directory
# files = os.listdir(source_dir)

# # Copy each file from source to destination
# for file in files:
#     source_file = os.path.join(source_dir, file)
#     destination_file = os.path.join(destination_dir, file)
#     shutil.copy(source_file, destination_file)


## Parameters and transforms

In [4]:
# --- Checkpointing ----------------------------------------------------------
LOAD_MODEL = True
SAVE_MODEL = True
# NOTE(review): CHECKPOINT_GEN / CHECKPOINT_DISC are not referenced by the
# visible training code, which writes generator.pth / discriminator.pth.
CHECKPOINT_GEN = "/kaggle/working/gen.pth"
CHECKPOINT_DISC = "/kaggle/working/disc.pth"

# --- Training hyper-parameters ----------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 2e-4
NUM_EPOCHS = 100
BATCH_SIZE = 16
LAMBDA_GP = 10  # NOTE(review): not used anywhere in the visible notebook
NUM_WORKERS = 8

# --- Image geometry ---------------------------------------------------------
HIGH_RES = 128
LOW_RES = HIGH_RES // 4  # x4 super-resolution: 32x32 input -> 128x128 target
IMG_CHANNELS = 3

# Albumentations expects OpenCV interpolation flags (cv2.INTER_*), not PIL
# resampling constants. The integer values do not line up between the two
# libraries (PIL BICUBIC == 3 == cv2.INTER_AREA, PIL LANCZOS == 1 ==
# cv2.INTER_LINEAR), so the cv2 flags are used explicitly here.

# HR target: scale to [-1, 1] and convert HWC uint8 -> CHW float tensor.
highres_transform = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ]
)

# LR input: bicubic downscale of the HR crop, then the same normalisation.
lowres_transform = A.Compose(
    [
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=cv2.INTER_CUBIC),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ]
)

# Geometry shared by the LR/HR pair: applied once, before the two branches
# above, so both views come from the same resized / flipped / rotated image.
both_transforms = A.Compose(
    [
        A.Resize(width=HIGH_RES, height=HIGH_RES, interpolation=cv2.INTER_LANCZOS4),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

# Evaluation-time preprocessing: normalisation only, no resizing.
test_transform = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ]
)


## Utilities

In [5]:
def plot_srgan(loader, generator, device='cuda', num_images=1, log_dir='/kaggle/working/logs',
               denormalize=True):
    """
    Plot low-resolution, generated, and high-resolution images side by side.

    One batch is drawn from ``loader``, the first ``num_images`` low-res
    images are passed through ``generator``, and each (LR, SR, HR) triple is
    shown and saved as ``image_<k>.png`` inside ``log_dir`` (files from a
    previous call are overwritten).

    Parameters:
    loader (DataLoader): Yields ``(low_res, high_res)`` tensor batches.
    generator (torch.nn.Module): Generator model.
    device (str): Device to run the model on ('cuda' or 'cpu').
    num_images (int): Number of images to display; capped at the batch size.
    log_dir (str): Directory to save figures (default: '/kaggle/working/logs').
    denormalize (bool): When True (default) tensors are assumed to be in
        [-1, 1] (``Normalize(mean=0.5, std=0.5)``) and are mapped back to
        [0, 1] before display. Pass False to plot the raw values.
    """
    # Remember the current mode so the caller's state is restored afterwards,
    # even if the forward pass raises.
    was_training = generator.training
    generator.eval()
    try:
        # NOTE: iter(loader) starts a fresh pass over the loader (and its
        # worker processes, if any) on every call.
        lr_images, hr_images = next(iter(loader))

        # Never index past the end of the batch.
        num_images = min(num_images, lr_images.size(0))
        lr_images = lr_images[:num_images].to(device)
        hr_images = hr_images[:num_images]

        with torch.no_grad():
            sr_images = generator(lr_images)
    finally:
        generator.train(was_training)

    def _to_display(batch):
        """Convert an NCHW tensor batch into NHWC numpy arrays for imshow."""
        batch = batch.detach().float().cpu()
        if denormalize:
            # [-1, 1] -> [0, 1]; clamp because generator output is unbounded.
            batch = ((batch + 1.0) / 2.0).clamp(0.0, 1.0)
        return batch.permute(0, 2, 3, 1).numpy()

    panels = (
        ('Low Resolution', _to_display(lr_images)),
        ('Super Resolution', _to_display(sr_images)),
        ('High Resolution', _to_display(hr_images)),
    )

    # Create directory if it doesn't exist
    os.makedirs(log_dir, exist_ok=True)

    # Plot and save images
    for i in range(num_images):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        for ax, (title, images) in zip(axes, panels):
            ax.imshow(images[i])
            ax.set_title(title)
            ax.axis('off')

        # Save figure
        fig_path = os.path.join(log_dir, f'image_{i+1}.png')
        fig.savefig(fig_path)
        plt.show()
        plt.close(fig)

        # Log the figure path in Kaggle's output
        print(f'![image_{i+1}](./{fig_path})')







## DataSet

In [6]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from PIL import Image
import matplotlib.pyplot as plt

class MyImageFolder(Dataset):
    """Dataset yielding ``(low_res, high_res)`` tensor pairs from image folders.

    Each directory in ``root_dirs`` is scanned non-recursively. Every image is
    loaded as RGB, passed through the module-level ``both_transforms`` and then
    through ``highres_transform`` / ``lowres_transform`` to produce the pair.

    Parameters:
    root_dirs (str | Iterable[str]): One directory or several directories.
    extensions (Iterable[str] | None): File extensions (case-insensitive,
        with leading dot) accepted as images. Defaults to
        ``MyImageFolder.IMAGE_EXTENSIONS``.
    """

    # Lower-case extensions treated as images when scanning a directory.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")

    def __init__(self, root_dirs, extensions=None):
        super().__init__()
        # A bare string would otherwise be iterated character by character.
        if isinstance(root_dirs, (str, os.PathLike)):
            root_dirs = [root_dirs]
        self.root_dirs = list(root_dirs)
        self.extensions = tuple(
            ext.lower() for ext in (extensions or self.IMAGE_EXTENSIONS)
        )
        self.image_files = self.collect_image_files()

    def collect_image_files(self):
        """Return the image paths found directly inside each root directory.

        Entries are sorted per directory so indices are stable across runs
        (``os.listdir`` order is arbitrary). Sub-directories and files whose
        extension is not in ``self.extensions`` are skipped, so stray
        metadata files do not crash ``__getitem__`` later.
        """
        image_files = []
        for root_dir in self.root_dirs:
            for name in sorted(os.listdir(root_dir)):
                path = os.path.join(root_dir, name)
                if os.path.isfile(path) and name.lower().endswith(self.extensions):
                    image_files.append(path)
        return image_files

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        """Load image ``index`` and return ``(low_res, high_res)``."""
        img_path = self.image_files[index]
        # Context manager closes the underlying file as soon as pixels are read.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))

        # Shared geometry first so both views come from the same crop.
        image = both_transforms(image=image)["image"]
        high_res = highres_transform(image=image)["image"]
        low_res = lowres_transform(image=image)["image"]

        return low_res, high_res

def custom_collate_fn(batch):
    """Stack a list of ``(low_res, high_res)`` samples into two batch tensors.

    Returns a tuple ``(low_res_batch, high_res_batch)``; each is the
    ``torch.stack`` of the corresponding element of every sample.
    """
    low_res_items = []
    high_res_items = []
    for sample in batch:
        low_res_items.append(sample[0])
        high_res_items.append(sample[1])
    return torch.stack(low_res_items), torch.stack(high_res_items)

def plot_(low_res, high_res, denormalize=True):
    """Show low-res / high-res image pairs, one pair per row.

    Parameters:
    low_res: Sequence or batch of CHW tensors.
    high_res: Sequence or batch of CHW tensors, same length as ``low_res``.
    denormalize (bool): When True (default) the tensors are assumed to be in
        [-1, 1] (``Normalize(mean=0.5, std=0.5)``) and are mapped back to
        [0, 1] for display. Pass False to plot the raw values.
    """
    num_images = len(low_res)
    if num_images == 0:
        # plt.subplots(0, 2) would raise; nothing to draw.
        return

    def _to_display(img):
        """CHW tensor -> HWC numpy array suitable for imshow."""
        arr = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))
        if denormalize:
            arr = np.clip((arr + 1.0) / 2.0, 0.0, 1.0)
        return arr

    # squeeze=False keeps `axes` 2-D even when there is a single row, so the
    # axes[i, j] indexing below also works for num_images == 1.
    fig, axes = plt.subplots(num_images, 2, figsize=(10, 10), squeeze=False)

    for i in range(num_images):
        axes[i, 0].imshow(_to_display(low_res[i]))
        axes[i, 0].set_title('Low Resolution')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(_to_display(high_res[i]))
        axes[i, 1].set_title('High Resolution')
        axes[i, 1].axis('off')

    plt.tight_layout()
    plt.show()


In [7]:
# Directories scanned for training images (Flickr2K high-resolution set).
root_dirs = ['/kaggle/input/flickr2k/Flickr2K/Flickr2K_HR']
dataset = MyImageFolder(root_dirs)

## Model

In [8]:


class ResidualDenseBlock(nn.Module):
    """Five-convolution dense block with a scaled residual connection.

    Each of the first four 3x3 convolutions sees the block input concatenated
    with every earlier intermediate output and emits ``gc`` channels; the
    fifth maps everything back to ``nf`` channels. The result is scaled by
    0.2 and added to the input.

    Parameters:
    nf (int): Number of input/output channels.
    gc (int): Channels added by each intermediate convolution.
    """

    def __init__(self, nf=64, gc=32):
        super().__init__()
        # conv1..conv4 each consume nf + k*gc channels (k = 0..3) and emit gc.
        for k in range(4):
            setattr(self, f"conv{k + 1}",
                    nn.Conv2d(nf + k * gc, gc, 3, padding=1, bias=True))
        # conv5 fuses the input plus all four growth outputs back to nf.
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, padding=1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # dense_feats accumulates [x, x1, x2, ...] for the dense connections.
        dense_feats = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            dense_feats.append(self.lrelu(conv(torch.cat(dense_feats, 1))))
        fused = self.conv5(torch.cat(dense_feats, 1))
        return fused * 0.2 + x

class RRDB(nn.Module):
    """Residual-in-residual dense block.

    Chains three ``ResidualDenseBlock`` modules and adds the chain's output,
    scaled by 0.2, back onto the block input.

    Parameters:
    nf (int): Channel count passed to every inner ``ResidualDenseBlock``.
    """

    def __init__(self, nf):
        super().__init__()
        # Registered as rdb1, rdb2, rdb3 (in that order).
        for idx in (1, 2, 3):
            setattr(self, f"rdb{idx}", ResidualDenseBlock(nf))

    def forward(self, x):
        out = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            out = dense_block(out)
        return out * 0.2 + x

class Generator(nn.Module):
    """RRDB-based x4 super-resolution generator.

    Pipeline: 3x3 conv -> ``num_rrdb_blocks`` RRDBs -> 3x3 conv (added back to
    the first conv's output) -> two (nearest x2 upsample + conv + LeakyReLU)
    stages -> conv + LeakyReLU -> output conv. Spatial size grows by 4; the
    output has no final activation, so values are unbounded.

    Parameters:
    num_rrdb_blocks (int): Number of RRDB modules in the trunk.
    in_channels (int): Channels of the low-resolution input image.
    out_channels (int): Channels of the super-resolved output image.
    nf (int): Feature width used throughout the network.
    """

    def __init__(self, num_rrdb_blocks=23, in_channels=3, out_channels=3, nf=64):
        super().__init__()
        self.conv_first = nn.Conv2d(in_channels, nf, 3, padding=1, bias=True)
        self.body = nn.Sequential(*[RRDB(nf) for _ in range(num_rrdb_blocks)])
        self.conv_body = nn.Conv2d(nf, nf, 3, padding=1, bias=True)
        self.conv_up1 = nn.Conv2d(nf, nf, 3, padding=1, bias=True)
        self.conv_up2 = nn.Conv2d(nf, nf, 3, padding=1, bias=True)
        self.conv_hr = nn.Conv2d(nf, nf, 3, padding=1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_channels, 3, padding=1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Map an ``(N, in_channels, H, W)`` batch to ``(N, out_channels, 4H, 4W)``."""
        feat = self.conv_first(x)
        # Global residual around the whole RRDB trunk.
        feat = feat + self.conv_body(self.body(feat))
        # Two nearest-neighbour x2 upsampling stages, each refined by a conv.
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))

class Discriminator(nn.Module):
    """Convolutional discriminator producing one logit per image.

    ``features`` is a stack of 3x3 convolutions over a 3-channel input whose
    width grows 64 -> 512, with a stride-2 convolution closing each width
    stage (four in total). ``classifier`` global-average-pools the feature
    map and applies two 1x1 convolutions, so any input size that survives
    the four stride-2 layers is accepted.
    """

    def __init__(self):
        super().__init__()

        def _lrelu():
            return nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Stem: the only convolution without batch normalisation.
        layers = [nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), _lrelu()]

        # (in_channels, out_channels, stride) for each Conv-BN-LeakyReLU unit.
        conv_specs = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]
        for c_in, c_out, stride in conv_specs:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(_lrelu())
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            _lrelu(),
            nn.Conv2d(1024, 1, kernel_size=1),
        )

    def forward(self, x):
        """Return raw (pre-sigmoid) scores of shape ``(N, 1)``."""
        logits = self.classifier(self.features(x))
        batch = logits.size(0)
        return logits.view(batch, -1)


## Loss

In [9]:

class FeatureExtractor(nn.Module):
    """Frozen VGG19 feature network used for the perceptual loss.

    Keeps the first 35 layers of the pretrained ``vgg19().features`` stack
    (presumably up to the last convolution before its activation, as in
    ESRGAN - confirm against the torchvision layer list).

    The weights are frozen and the module is put in eval mode: it is a fixed
    loss network, so gradients are only needed with respect to its *input*.
    Without this, every generator backward pass would also compute gradients
    for all VGG parameters.

    NOTE(review): inputs are fed as-is. The rest of the notebook normalises
    images to [-1, 1], while ImageNet-pretrained VGG is normally fed
    ImageNet-normalised data - confirm whether re-normalisation is wanted.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # in favour of the `weights=` argument; kept for compatibility.
        vgg19_model = vgg19(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:35])

        for param in self.feature_extractor.parameters():
            param.requires_grad_(False)
        self.feature_extractor.eval()

    def forward(self, x):
        """Return the VGG feature map for the image batch ``x``."""
        return self.feature_extractor(x)

In [10]:
from tqdm import tqdm

class ESRGAN:
    """Training harness bundling generator, discriminator and loss network.

    Parameters:
    lr (float): Learning rate for both Adam optimizers.
    num_epochs (int): Number of passes over the dataloader in ``train``.
    batch_size (int): Stored on the instance; the effective batch size is
        whatever the dataloader passed to ``train`` yields.
    """

    def __init__(self, lr=1e-5, num_epochs=5, batch_size=8):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.generator = Generator().to(self.device)
        self.discriminator = Discriminator().to(self.device)
        self.feature_extractor = FeatureExtractor().to(self.device)
        # The VGG network is a fixed loss function: freeze it so backward
        # passes do not compute (unused) gradients for its weights.
        self.feature_extractor.eval()
        for param in self.feature_extractor.parameters():
            param.requires_grad_(False)

        self.generator_optimizer = optim.Adam(self.generator.parameters(), lr=lr)
        self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=lr)

        self.content_criterion = nn.L1Loss()
        self.adversarial_criterion = nn.BCEWithLogitsLoss()

        self.num_epochs = num_epochs
        self.batch_size = batch_size

    def train(self, train_dataloader):
        """Run the adversarial training loop.

        Per batch: one discriminator update with a relativistic-average BCE
        loss, then one generator update with
        ``L1 + 0.001 * adversarial + 0.006 * perceptual``.
        Losses are printed (and a sample figure drawn via ``plot_srgan``)
        every 100 steps and on the last step of each epoch. When the global
        ``SAVE_MODEL`` is true the weights are saved after every epoch.

        Parameters:
        train_dataloader: Iterable yielding ``(lr_imgs, hr_imgs)`` batches;
            must support ``len()``.
        """
        num_steps = len(train_dataloader)
        for epoch in range(self.num_epochs):
            tqdm_dataloader = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.num_epochs}", leave=False)
            for i, (lr_imgs, hr_imgs) in enumerate(tqdm_dataloader):
                lr_imgs = lr_imgs.to(self.device)
                hr_imgs = hr_imgs.to(self.device)

                # Single generator forward shared by both updates: the
                # generator weights do not change during the discriminator
                # step, so a second forward would give the same output.
                sr_imgs = self.generator(lr_imgs)

                # ---- Train Discriminator ----
                self.discriminator_optimizer.zero_grad()

                real_preds = self.discriminator(hr_imgs)
                fake_preds = self.discriminator(sr_imgs.detach())

                d_loss_real = self.adversarial_criterion(real_preds - fake_preds.mean(), torch.ones_like(real_preds))
                d_loss_fake = self.adversarial_criterion(fake_preds - real_preds.mean(), torch.zeros_like(fake_preds))
                d_loss = (d_loss_real + d_loss_fake) / 2

                d_loss.backward()
                self.discriminator_optimizer.step()

                # ---- Train Generator ----
                self.generator_optimizer.zero_grad()

                fake_preds = self.discriminator(sr_imgs)
                # Neither term below depends on the generator, so no graph
                # is needed for them.
                with torch.no_grad():
                    real_preds = self.discriminator(hr_imgs)
                    hr_features = self.feature_extractor(hr_imgs)

                content_loss = self.content_criterion(sr_imgs, hr_imgs)
                adversarial_loss = self.adversarial_criterion(fake_preds - real_preds.mean(), torch.ones_like(fake_preds))
                perceptual_loss = self.content_criterion(self.feature_extractor(sr_imgs), hr_features)

                g_loss = content_loss + 0.001 * adversarial_loss + 0.006 * perceptual_loss

                g_loss.backward()
                self.generator_optimizer.step()

                # Log every 100 steps and on the final step of the epoch.
                if i % 100 == 0 or i == num_steps - 1:
                    d_loss_item = d_loss.item()
                    g_loss_item = g_loss.item()
                    print(f"Epoch [{epoch+1}/{self.num_epochs}], Step [{i+1}/{num_steps}], "
                          f"D_loss: {d_loss_item:.4f}, G_loss: {g_loss_item:.4f}, "
                          f"Content_loss: {content_loss.item():.4f}, "
                          f"Adversarial_loss: {adversarial_loss.item():.4f}, "
                          f"Perceptual_loss: {perceptual_loss.item():.4f}")
                    plot_srgan(train_dataloader, self.generator, self.device)

            tqdm_dataloader.close()

            # Save models after each epoch if needed
            if SAVE_MODEL:
                self.save_models("/kaggle/working")

    def save_models(self, path):
        """Write ``generator.pth`` and ``discriminator.pth`` into ``path``.

        The directory is created if it does not exist. Only model weights
        are saved; optimizer state is not.
        """
        os.makedirs(path, exist_ok=True)
        torch.save(self.generator.state_dict(), os.path.join(path, "generator.pth"))
        torch.save(self.discriminator.state_dict(), os.path.join(path, "discriminator.pth"))

    def load_models(self, path):
        """Load weights from ``path`` if the checkpoint files are present.

        A missing file is reported and skipped rather than raising.
        Checkpoints are mapped onto ``self.device`` so that weights saved on
        GPU can be loaded on a CPU-only machine.
        """
        generator_path = os.path.join(path, "generator.pth")
        discriminator_path = os.path.join(path, "discriminator.pth")

        if os.path.exists(generator_path):
            self.generator.load_state_dict(torch.load(generator_path, map_location=self.device))
            print(f"Generator model loaded successfully from {generator_path}.")
        else:
            print(f"Generator model file not found at {generator_path}. Skipping generator loading.")

        if os.path.exists(discriminator_path):
            self.discriminator.load_state_dict(torch.load(discriminator_path, map_location=self.device))
            print(f"Discriminator model loaded successfully from {discriminator_path}.")
        else:
            print(f"Discriminator model file not found at {discriminator_path}. Skipping discriminator loading.")


In [11]:
def main():
    """Build the dataset and loader, then train (and optionally save) ESRGAN.

    Honors the module-level flags: when ``LOAD_MODEL`` is true, existing
    checkpoints in ``/kaggle/working`` are loaded before training (missing
    files are skipped by ``load_models``); when ``SAVE_MODEL`` is true the
    final weights are written back to the same directory.

    NOTE(review): ``ESRGAN()`` is built with its own defaults (lr=1e-5,
    num_epochs=5), not with LEARNING_RATE / NUM_EPOCHS from the config cell -
    confirm which values are intended.
    """
    root_dirs = [
        '/kaggle/input/flickr2k/Flickr2K/Flickr2K_HR',
        '/kaggle/input/flickrfaceshq-dataset-ffhq',
    ]
    dataset = MyImageFolder(root_dirs)
    print(f"total numbers of images to train {len(dataset)}")
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=NUM_WORKERS,
        collate_fn=custom_collate_fn,
    )
    esrgan = ESRGAN()
    if LOAD_MODEL:
        # Resume from the checkpoints that save_models() writes.
        esrgan.load_models("/kaggle/working")
    esrgan.train(loader)
    if SAVE_MODEL:
        esrgan.save_models("/kaggle/working")