In [42]:
from tqdm import tqdm 
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import time
import os 
import torch
from torchvision.models import resnet50, ResNet50_Weights
from torch import nn
torch.cuda.empty_cache()
from torchvision import transforms
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [43]:
# HIGH_RES_TRAIN_FOLDER = "/kaggle/input/div2k-processed-dataset/processed_data/high_res_train"
# LOW_RES_TRAIN_FOLDER = "/kaggle/input/div2k-processed-dataset/processed_data/low_res_train"
# HIGH_RES_VALID_FOLDER = "/kaggle/input/div2k-processed-dataset/processed_data/high_res_valid"
# LOW_RES_VALID_FOLDER = "/kaggle/input/div2k-processed-dataset/processed_data/low_res_valid"

HIGH_RES_TRAIN_FOLDER = "/Users/archismanchakraborti/Desktop/Final_model/archive/processed_data/high_res_train"
LOW_RES_TRAIN_FOLDER = "/Users/archismanchakraborti/Desktop/Final_model/archive/processed_data/low_res_train"
HIGH_RES_VALID_FOLDER = "/Users/archismanchakraborti/Desktop/Final_model/archive/processed_data/high_res_valid"
LOW_RES_VALID_FOLDER = "/Users/archismanchakraborti/Desktop/Final_model/archive/processed_data/low_res_valid"

In [44]:
# Count number of files in the folders
ntrain_images = len(os.listdir(HIGH_RES_TRAIN_FOLDER))
nvalid_images = len(os.listdir(HIGH_RES_VALID_FOLDER))

print(f"Number of training images: {ntrain_images}")
print(f"Number of validation images: {nvalid_images}")

Number of training images: 33032
Number of validation images: 4200


In [45]:
class PairedImageDataset(Dataset):
    def __init__(self, hr_dir, lr_dir, downscale_factor=4, transform=None, use_half_precision=False):
        """
        Initialize the dataset class
        Args:
            hr_dir (str): Path to the high resolution directory
            lr_dir (str): Path to the low resolution directory
            downscale_factor (int): Factor by which the low resolution image should be downscaled
            transform: Transformations to apply. By default applies the ToTensor operation
            use_half_precision (bool): Whether to use half precision for the images
        """
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.hr_images = sorted(os.listdir(hr_dir))
        self.lr_images = sorted(os.listdir(lr_dir))
        if transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),])
        else:
            self.transform = transform 
        self.use_half_precision = use_half_precision
        self.downscale_factor = downscale_factor

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr_image_path = os.path.join(self.hr_dir, self.hr_images[idx])
        lr_image_path = os.path.join(self.lr_dir, self.lr_images[idx])

        hr_image = Image.open(hr_image_path).convert("RGB")
        lr_image = Image.open(lr_image_path).convert("RGB")
        
        # Decrease the resolution of the lr image
        lr_image = lr_image.resize(
            (lr_image.size[0] // self.downscale_factor, 
             lr_image.size[1] // self.downscale_factor), 
            Image.BICUBIC
        )

        if self.transform:
            hr_image = self.transform(hr_image)
            lr_image = self.transform(lr_image)

        if self.use_half_precision:
            hr_image = hr_image.half()
            lr_image = lr_image.half()

        return hr_image, lr_image

In [46]:
train_dataset = PairedImageDataset(HIGH_RES_TRAIN_FOLDER, LOW_RES_TRAIN_FOLDER, use_half_precision=True)
valid_dataset = PairedImageDataset(HIGH_RES_VALID_FOLDER, LOW_RES_VALID_FOLDER, use_half_precision=True)

BATCH_SIZE: int = 16
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [47]:
class GeneratorResidualBlock(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, 
                 padding=1, use_activation: bool = True, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.activation = nn.LeakyReLU(0.2, inplace=True) if use_activation else nn.Identity()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, 
                                kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        residual = x
        x = self.bn2(self.conv2(self.activation(self.bn1(self.conv1(x)))))
        return x + residual
    
class UpsampleBlock(nn.Module):
    def __init__(self, in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=0, 
                 scate_factor=4, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 
                              stride=stride, padding=padding, bias=True)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=scate_factor)
        self.activation = nn.PReLU()
    
    def forward(self, x):
        return self.pixel_shuffle(self.activation(self.conv(x)))
    

class Generator(nn.Module):
    def __init__(self, in_channels: int = 3, num_residual_blocks: int = 16, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=9, stride=1, padding=4, bias=True)
        self.activation = nn.PReLU()
        self.residual_blocks = nn.Sequential(*[GeneratorResidualBlock() for _ in range(num_residual_blocks)])
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
        self.bn = nn.BatchNorm2d(64)
        self.upsample = nn.Sequential(*[UpsampleBlock(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, scate_factor=2), 
                                        UpsampleBlock(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, scate_factor=2)])
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=9, stride=1, padding=4, bias=True)

    def forward(self, x):
        x = self.activation(self.conv1(x))
        x += self.bn(self.conv2(self.residual_blocks(x))) + x
        return self.conv3(self.upsample(x))    



class DiscriminatorBlock(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=2, 
                 padding=1, use_activation: bool = True, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 
                              stride=stride, padding=padding, bias=True)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.LeakyReLU(0.2, inplace=True) if use_activation else nn.Identity()

    def forward(self, x):
        return self.activation(self.bn(self.conv(x)))
    
class Discriminator(nn.Module):
    def __init__(self, in_channels: int = 3, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
        self.activation = nn.LeakyReLU(0.2, inplace=True)
        self.disc_blocks = nn.Sequential(*[DiscriminatorBlock(in_channels=64, out_channels=64, stride=2), 
                                           DiscriminatorBlock(in_channels=64, out_channels=128, stride=1),
                                           DiscriminatorBlock(in_channels=128, out_channels=128, stride=2),
                                           DiscriminatorBlock(in_channels=128, out_channels=256, stride=1),
                                           DiscriminatorBlock(in_channels=256, out_channels=256, stride=2),
                                           DiscriminatorBlock(in_channels=256, out_channels=512, stride=1),
                                           DiscriminatorBlock(in_channels=512, out_channels=512, stride=2)])
        self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d((6, 6)), nn.Flatten(), nn.Linear(512*6*6, 1024), 
                                        nn.LeakyReLU(0.2, inplace=True), nn.Linear(1024, 1), nn.Sigmoid())
        
    def forward(self, x):
        return self.classifier(self.disc_blocks(self.activation(self.conv1(x))))

In [51]:
class ResNetLoss(nn.Module):
    def __init__(self, use_half_precision: bool = False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.resnet = resnet50(weights = ResNet50_Weights.DEFAULT)
        if use_half_precision:
            self.resnet = self.resnet.half()
        self.loss = nn.MSELoss()

    def forward(self, sr_image, hr_image):
        return self.loss(self.resnet(sr_image), self.resnet(hr_image))
    

def train_srgan(train_loader, valid_loader, generator, discriminator, 
                generator_optim, discriminator_optim, bce_loss, resnet_loss, num_epochs: int = 100, 
                device: str = DEVICE, log_folder: str = "logs/SRGAN", patience: int = 15, 
                use_half_precision: bool = False):
    """ Train the SRGAN model with the given hyperparameters
    train_loader (DataLoader): Training data loader
    valid_loader (DataLoader): Validation data loader
    generator (nn.Module): Generator model
    discriminator (nn.Module): Discriminator model
    generator_optim (Optimizer): Generator optimizer
    discriminator_optim (Optimizer): Discriminator optimizer
    bce_loss (nn.Module): Binary cross entropy loss
    resnet_loss (nn.Module): ResNet loss
    num_epochs (int): Number of epochs
    log_folder (str): Folder to save the logs
    device (torch.device): Device to run the model
    patience (int): Patience for early stopping. Default: 15
    use_half_precision (bool): Use half precision training. Default: False
    """
    # Move model to device
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    # Check for half precision training
    if use_half_precision:
        generator = generator.half()
        discriminator = discriminator.half()
    
    # Create the log folder
    os.makedirs(log_folder, exist_ok=True)
    log_file = os.path.join(log_folder, "log.log")
    generator_wts_path = os.path.join(log_folder, "generator_weights.pth")
    discriminator_wts_path = os.path.join(log_folder, "discriminator_weights.pth")

    # Initialize the best loss for Early stopping
    best_loss = float("inf")
    patience_counter = 0

    # Start the training loop
    with open(log_file, "w") as log:
        log.write("Epoch,GenTrainLoss,DiscTrainLoss,GenValidLoss,DiscValidLoss,Time\n")

        for epoch in tqdm(range(num_epochs), desc="Epoch"):
            start_time = time.time()    

            # training phase
            generator.train()
            discriminator.train()
            total_gen_loss, total_disc_loss = 0.0, 0.0

            for hr_images, lr_images in train_loader:
                print("hi")
                lr_images, hr_images = lr_images.to(device), hr_images.to(device)

                # Train the discriminator
                sr_images = generator(lr_images)
                discriminator_real = discriminator(hr_images)
                discriminator_fake = discriminator(sr_images.detach())

                discriminator_real_loss = bce_loss(discriminator_real, 
                                                   torch.ones_like(discriminator_real, dtype=torch.float16 if use_half_precision else torch.float32))
                discriminator_fake_loss = bce_loss(discriminator_fake, 
                                                   torch.zeros_like(discriminator_fake, dtype=torch.float16 if use_half_precision else torch.float32))
                discriminator_loss = discriminator_real_loss + discriminator_fake_loss
                total_disc_loss += discriminator_loss.item()

                discriminator_optim.zero_grad()
                discriminator_loss.backward()
                discriminator_optim.step()

                # Train the generator
                discriminator_fake = discriminator(sr_images)
                generator_loss = 1e-3 * bce_loss(discriminator_fake, 
                                                 torch.ones_like(discriminator_fake, dtype=torch.float16 if use_half_precision else torch.float32))
                resnet_loss_val = 0.006 * resnet_loss(sr_images, hr_images)
                generator_loss += resnet_loss_val
                total_gen_loss += generator_loss.item()

                generator_optim.zero_grad()
                generator_loss.backward()
                generator_optim.step()

            total_disc_loss /= len(train_loader)
            total_gen_loss /= len(train_loader)

            # Validation phase
            generator.eval()
            discriminator.eval()
            total_gen_valid_loss, total_disc_valid_loss = 0.0, 0.0

            with torch.no_grad():
                for hr_images, lr_images in valid_loader:
                    lr_images, hr_images = lr_images.to(device), hr_images.to(device)

                    # Calculate the generator loss
                    sr_images = generator(lr_images)
                    discriminator_fake = discriminator(sr_images)
                    generator_loss = 1e-3 * bce_loss(discriminator_fake, 
                                                     torch.ones_like(discriminator_fake, dtype=torch.float16 if use_half_precision else torch.float32))
                    resnet_loss_val = 0.006 * resnet_loss(sr_images, hr_images)
                    generator_loss += resnet_loss_val
                    total_gen_valid_loss += generator_loss.item()

                    # Calculate the discriminator loss
                    discriminator_real = discriminator(hr_images)
                    discriminator_fake = discriminator(sr_images.detach())
                    discriminator_real_loss = bce_loss(discriminator_real, 
                                                       torch.ones_like(discriminator_real, dtype=torch.float16 if use_half_precision else torch.float32))
                    discriminator_fake_loss = bce_loss(discriminator_fake, 
                                                       torch.zeros_like(discriminator_fake, dtype=torch.float16 if use_half_precision else torch.float32))
                    discriminator_loss = discriminator_real_loss + discriminator_fake_loss
                    total_disc_valid_loss += discriminator_loss.item()

            total_gen_valid_loss /= len(valid_loader)
            total_disc_valid_loss /= len(valid_loader)

            # Log the losses
            end_time = time.time()
            log.write(f"{epoch + 1},{total_gen_loss},{total_disc_loss},{total_gen_valid_loss},{total_disc_valid_loss},{end_time - start_time}\n")
            print(f"Epoch: {epoch + 1}, GenTrainLoss: {total_gen_loss}, DiscTrainLoss: {total_disc_loss}, GenValidLoss: {total_gen_valid_loss}, DiscValidLoss: {total_disc_valid_loss}, Time: {end_time - start_time}")

            # check for early stopping
            if total_gen_valid_loss < best_loss:
                best_loss = total_gen_valid_loss
                patience_counter = 0
                torch.save(generator.state_dict(), generator_wts_path)
                torch.save(discriminator.state_dict(), discriminator_wts_path)
            else:
                patience_counter += 1
                print(f"Patience Counter: {patience_counter} / {patience}: Loss did not improve from {best_loss}")
                if patience_counter >= patience:
                    print("Early stopping after {epoch + 1} epochs")
                    break
        
    generator.load_state_dict(torch.load(generator_wts_path))
    discriminator.load_state_dict(torch.load(discriminator_wts_path))

In [52]:
vanilla_log_folder = "logs/Vaanilla_SRGAN"
relu_log_folder = "logs/RELU"
prelu_log_folder = "logs/PRELU"
mish_log_folder = "logs/MISH"
dsrelu_log_folder = "logs/DSRELU"

In [53]:
############################ Vanilla SRGAN ################################
NUM_EPOCHS = 100
LEARNING_RATE = 3e-4
GENERATOR = Generator()
DISCRIMINATOR = Discriminator()
GENERATOR_OPTIM = torch.optim.Adam(GENERATOR.parameters(), lr=3e-4)
DISCRIMINATOR_OPTIM = torch.optim.Adam(DISCRIMINATOR.parameters(), lr=3e-4)
BCE_LOSS = nn.BCELoss()
RESNET_LOSS = ResNetLoss(use_half_precision=True)
PATIENCE = 15

train_srgan(train_loader=train_loader, valid_loader=valid_loader, generator=GENERATOR, discriminator=DISCRIMINATOR, 
            generator_optim=GENERATOR_OPTIM, discriminator_optim=DISCRIMINATOR_OPTIM,
            bce_loss=BCE_LOSS, resnet_loss=RESNET_LOSS, num_epochs=NUM_EPOCHS, device=DEVICE,
            log_folder=vanilla_log_folder, patience=PATIENCE, use_half_precision=True)

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

hi


Epoch:   0%|          | 0/100 [00:44<?, ?it/s]


RuntimeError: "binary_cross_entropy" not implemented for 'Half'