In [None]:
%%capture
# Environment Set-Up: fetch the dataset archive and extract it. It is expected to produce
# final_dataset/{train,val} as referenced by TRAIN_PATH / VAL_PATH — TODO confirm archive layout.
# %%capture must be the first line of the cell (IPython only recognises cell magics there).
# `-o` overwrites without prompting, so re-running the cell does not block on unzip's prompt.
!gdown 1YS6NdHvEQb19rTZL6RI-NbAyWy0kX08m
!unzip -o -q final_dataset.zip

In [None]:
# Imports
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from IPython.display import clear_output
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Constants
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 2e-4
FEATURES = 64
BATCH_SIZE = 16
NUM_WORKERS = 2
IMG_SIZE = 256
CHANNELS = 3
NUM_EPOCHS = 35
SEED = 42

# Paths
TRAIN_PATH = '/content/final_dataset/train'
VAL_PATH = '/content/final_dataset/val'
CHECKPOINT_PATH = '/content/drive/MyDrive/Weights'

# Loss Functions
# BCEWithLogitsLoss expects raw logits (the discriminator has no final sigmoid).
BCE_LOSS = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()

# Reproducibility: seed every RNG the pipeline draws from (python `random`, numpy, torch).
# torch.manual_seed also seeds all CUDA devices, and DataLoader worker seeds derive from it.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Progress Bars
# Registry of live tqdm bars keyed by name; a None slot means "not currently shown".
bars = {
    'epoch': None,
    'loading': None,
    'training': None,
    'validation': None
}

def update_bar(name, total=None, progress=0, desc='', unit='', position=None, postfix=None):
    """Create or advance the progress bar `name`, then clear the cell output and reprint all live bars.

    Args:
        name: key in `bars`; names not pre-registered are added on first use.
        total, unit, position: only used when the bar is created.
        progress: absolute count. A new bar starts at this value; an existing bar is advanced
            (or moved back) so that its count equals it.
        desc: new description; the empty string leaves the current one untouched.
        postfix: dict shown after the bar; falsy values leave the current postfix untouched.
    """
    # `bars` is only mutated, never rebound, so no `global` statement is needed.
    clear_output()
    bar = bars.get(name)
    if bar is None:
        # `initial` makes the first reported `progress` count instead of being dropped.
        bars[name] = tqdm(total=total, initial=progress, desc=desc, unit=unit, leave=False,
                          position=position, postfix=postfix)
    else:
        bar.update(progress - bar.n)
        if desc != '':
            bar.set_description(desc)
        if postfix:
            bar.set_postfix(postfix)

    # Print all bars
    for live_bar in bars.values():
        if live_bar is not None:
            print(live_bar)

def close_bar(name):
    """Close bar `name` if it is open and empty its slot, so the next update_bar recreates it."""
    bar = bars.get(name)
    if bar is not None:
        bar.close()
        bars[name] = None

In [None]:
class ImageDataset(Dataset):
    """Paired dehazing dataset yielding (gt_image, hazy_image) tuples.

    Images are read from `<root_dir>/GT` and `<root_dir>/hazy`. Pairs are formed by position
    after sorting each folder's file names. Both folders must therefore hold the same number
    of files, and their sorted orders must correspond.

    Args:
        root_dir: directory containing the `GT` and `hazy` sub-folders.
        transform: optional callable applied to both PIL images of a pair. Its random draws
            (torch and python `random`) are replayed, so the GT and hazy image get the same
            flip, jitter, and so on.

    Raises:
        ValueError: if the two folders contain different numbers of files, because the
            index-based pairing would then silently misalign.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.gt_dir = os.path.join(self.root_dir, "GT")
        self.hazy_dir = os.path.join(self.root_dir, "hazy")
        self.gt_filenames = sorted(os.listdir(self.gt_dir))
        self.hazy_filenames = sorted(os.listdir(self.hazy_dir))
        if len(self.gt_filenames) != len(self.hazy_filenames):
            raise ValueError(
                f"{self.gt_dir} has {len(self.gt_filenames)} files but {self.hazy_dir} has "
                f"{len(self.hazy_filenames)}; index-based GT/hazy pairing would be misaligned."
            )
        self.transform = transform

    def __len__(self):
        return len(self.gt_filenames)

    def _paired_transform(self, gt_image, hazy_image):
        """Apply self.transform to both images with identical random draws."""
        # Snapshot the RNGs, transform GT, rewind, then transform hazy. Random ops such as
        # RandomHorizontalFlip/ColorJitter then make the same choice for both images.
        torch_state = torch.get_rng_state()
        python_state = random.getstate()
        gt_image = self.transform(gt_image)
        torch.set_rng_state(torch_state)
        random.setstate(python_state)
        hazy_image = self.transform(hazy_image)
        return gt_image, hazy_image

    def __getitem__(self, idx):
        gt_name = os.path.join(self.gt_dir, self.gt_filenames[idx])
        gt_image = Image.open(gt_name).convert("RGB")

        hazy_name = os.path.join(self.hazy_dir, self.hazy_filenames[idx])
        hazy_image = Image.open(hazy_name).convert("RGB")

        if self.transform:
            gt_image, hazy_image = self._paired_transform(gt_image, hazy_image)

        return gt_image, hazy_image

In [None]:
class CNNBlock(nn.Module):
    """4x4 conv (reflect padding 1, no bias) -> BatchNorm2d -> LeakyReLU(0.2).

    NOTE(review): not referenced by Generator or Discriminator in this notebook.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        stages = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                bias=False,
                padding_mode="reflect",
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        return out

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator that scores (hazy, candidate-clear) image pairs patch by patch.

    The two images are concatenated along dim 1, so `in_channels` must equal the sum of their
    channel counts (default CHANNELS*2). The output is a spatial map of raw logits with no
    sigmoid, which is the input BCE_LOSS (BCEWithLogitsLoss) expects. The instance owns its
    Adam optimizer (`self.optimizer`).

    Args:
        in_channels: channels of the concatenated input pair.
        features: channel widths of the successive conv stages. The first stage has no
            BatchNorm, and the last stage keeps the resolution (stride 1).
    """

    def __init__(self, in_channels=CHANNELS*2, features=(64, 128, 256, 512)):
        super(Discriminator, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2)
        )
        # Plain Python list; its modules are registered as parameters through self.model below.
        self.layers = []
        in_channels = features[0]
        for feature in features[1:]:
            self.layers.append(self._block(in_channels, feature, stride=1 if feature == features[-1] else 2))
            in_channels = feature

        self.layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*self.layers)

        # He initialization. This must run after self.model is registered, otherwise
        # self.modules() only reaches self.initial and the remaining convs keep the default init.
        # a=0.2 matches the LeakyReLU(0.2) slope used throughout.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

        self.optimizer = optim.Adam(self.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

    def _block(self, in_channels, out_channels, stride):
        """4x4 conv -> BatchNorm2d -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x, y):
        # Concatenate the hazy image (x) and the candidate clear image (y) along the channel dimension
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x

    def train_step(self, x, y, y_fake=None):
        """Run one optimizer step on the discriminator and return the loss as a float.

        Args:
            x: hazy (conditioning) images.
            y: real clear images. Pairs (x, y) are pushed towards the label 1.
            y_fake: optional generated images. When given, pairs (x, y_fake) are pushed
                towards 0 and the loss is the mean of the real and fake terms, as in pix2pix.
                y_fake is detached so this step never back-propagates into the generator.
                When omitted, only the real term is used.
        """
        self.train()
        real_output = self.forward(x, y)
        disc_loss = BCE_LOSS(real_output, torch.ones_like(real_output))
        if y_fake is not None:
            fake_output = self.forward(x, y_fake.detach())
            fake_loss = BCE_LOSS(fake_output, torch.zeros_like(fake_output))
            disc_loss = (disc_loss + fake_loss) / 2

        self.optimizer.zero_grad()
        disc_loss.backward()
        self.optimizer.step()

        return disc_loss.item()

In [None]:
class Block(nn.Module):
    """U-Net stage: stride-2 resample -> BatchNorm2d -> activation, with optional Dropout(0.5).

    Args:
        down: True for a 4x4 reflect-padded Conv2d (halves H and W); False for a 4x4
            ConvTranspose2d (doubles H and W). Neither has a bias.
        leaky: LeakyReLU(0.2) when True, ReLU otherwise.
        use_dropout: apply Dropout(0.5) after the activation.
    """

    def __init__(self, in_channels, out_channels, down=True, leaky=True, use_dropout=False):
        super().__init__()
        if down:
            resample = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
        else:
            resample = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        activation = nn.LeakyReLU(0.2) if leaky else nn.ReLU()
        self.conv = nn.Sequential(resample, nn.BatchNorm2d(out_channels), activation)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        features = self.conv(x)
        if not self.use_dropout:
            return features
        return self.dropout(features)

In [None]:
class Generator(nn.Module):
    """U-Net encoder/decoder with skip connections that maps a hazy image to a clear one.

    initial_down, down1-down6 and the bottleneck each halve H and W (8 halvings in total), and
    the decoder doubles them back. Input sizes divisible by 256 keep every skip connection
    aligned; IMG_SIZE=256 gives a 1x1 bottleneck. The output goes through Tanh, so it lies in
    [-1, 1]. The instance owns its Adam optimizer (`self.optimizer`).

    Args:
        in_channels: image channels; the output has the same number.
        features: base channel width; deeper stages use multiples of it (capped at 8x).
    """

    def __init__(self, in_channels=CHANNELS, features=FEATURES):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2))

        self.down1 = Block(features*1, features*2, down=True, leaky=True, use_dropout=False)
        self.down2 = Block(features*2, features*4, down=True, leaky=True, use_dropout=False)
        self.down3 = Block(features*4, features*8, down=True, leaky=True, use_dropout=False)
        self.down4 = Block(features*8, features*8, down=True, leaky=True, use_dropout=False)
        self.down5 = Block(features*8, features*8, down=True, leaky=True, use_dropout=False)
        self.down6 = Block(features*8, features*8, down=True, leaky=True, use_dropout=False)

        self.bottleneck = nn.Sequential(nn.Conv2d(features*8, features*8, 4, 2, 1), nn.ReLU())

        # Every decoder stage after up1 receives [previous output, matching encoder skip],
        # hence the *2 on the input width.
        self.up1 = Block(features*8*1, features*8, down=False, leaky=False, use_dropout=True)
        self.up2 = Block(features*8*2, features*8, down=False, leaky=False, use_dropout=True)
        self.up3 = Block(features*8*2, features*8, down=False, leaky=False, use_dropout=True)
        self.up4 = Block(features*8*2, features*8, down=False, leaky=False, use_dropout=False)
        self.up5 = Block(features*8*2, features*4, down=False, leaky=False, use_dropout=False)
        self.up6 = Block(features*4*2, features*2, down=False, leaky=False, use_dropout=False)
        self.up7 = Block(features*2*2, features*1, down=False, leaky=False, use_dropout=False)

        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh())

        # He initialization. Only nn.Conv2d matches here: nn.ConvTranspose2d is not a Conv2d
        # subclass, so the decoder keeps PyTorch's default init. a=0.2 matches the encoder's
        # LeakyReLU(0.2) slope.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

        self.optimizer = optim.Adam(self.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
        # NOTE(review): not used by train_step (no autocast / scaled backward in this class).
        self.scaler = torch.cuda.amp.GradScaler()

    def forward(self, x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)

        bottleneck = self.bottleneck(d7)

        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1, d7], 1))
        up3 = self.up3(torch.cat([up2, d6], 1))
        up4 = self.up4(torch.cat([up3, d5], 1))
        up5 = self.up5(torch.cat([up4, d4], 1))
        up6 = self.up6(torch.cat([up5, d3], 1))
        up7 = self.up7(torch.cat([up6, d2], 1))

        return self.final_up(torch.cat([up7, d1], 1))

    def train_step(self, x, y, disc=None, l1_lambda=100.0):
        """Run one optimizer step on the generator and return the loss as a float.

        Args:
            x: hazy input images.
            y: target clear images.
            disc: optional discriminator called as disc(x, y_fake) that returns logits. When
                omitted the loss is L1(G(x), y) alone. When given, the pix2pix objective
                BCE(disc(x, G(x)), 1) + l1_lambda * L1(G(x), y) is used instead.
            l1_lambda: weight of the L1 term; only used when `disc` is given.

        Only this generator's optimizer steps. Gradients that reach `disc`'s parameters are
        left for the discriminator's own zero_grad() to clear.
        """
        self.train()
        y_fake = self.forward(x)
        gen_loss = L1_LOSS(y_fake, y)
        if disc is not None:
            disc_fake = disc(x, y_fake)
            gen_loss = BCE_LOSS(disc_fake, torch.ones_like(disc_fake)) + l1_lambda * gen_loss

        self.optimizer.zero_grad()
        gen_loss.backward()
        self.optimizer.step()

        return gen_loss.item()

In [None]:
def save_checkpoint(gen, disc, epoch, checkpoint_dir=None):
    """Save generator/discriminator state to `<checkpoint_dir>/checkpoint_epoch_<epoch>.pth.tar`.

    The file holds 'epoch', 'gen_state_dict' and 'disc_state_dict'. For each model that has a
    non-None `.optimizer`, it also holds 'gen_optimizer_state_dict' /
    'disc_optimizer_state_dict', so a resumed run keeps its Adam moments.

    Args:
        gen, disc: the models to save.
        epoch: epoch index stored in the file and used in its name.
        checkpoint_dir: target directory; defaults to CHECKPOINT_PATH and is created if missing.

    Returns:
        The path of the written checkpoint file.
    """
    checkpoint_dir = CHECKPOINT_PATH if checkpoint_dir is None else checkpoint_dir
    # The Drive folder may not exist yet on a fresh account/runtime.
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth.tar")

    state = {'epoch': epoch,
             'gen_state_dict': gen.state_dict(),
             'disc_state_dict': disc.state_dict()}
    for key, model in (('gen_optimizer_state_dict', gen), ('disc_optimizer_state_dict', disc)):
        optimizer = getattr(model, 'optimizer', None)
        if optimizer is not None:
            state[key] = optimizer.state_dict()

    torch.save(state, checkpoint_path)
    print(f"Saved checkpoint at epoch {epoch}")
    return checkpoint_path


def load_checkpoint(model, checkpoint_dir=None):
    """Restore a single model (and its `.optimizer`) from the newest '.tar' file in `checkpoint_dir`.

    Expects the keys 'model_state_dict', 'optimizer_state_dict' and 'epoch'.
    NOTE: a later cell redefines `load_checkpoint(gen, disc)`; once that cell has run, this
    single-model version is shadowed.

    Args:
        model: module with an `.optimizer` attribute.
        checkpoint_dir: directory to search; defaults to CHECKPOINT_PATH.

    Returns:
        (model, start_epoch), where start_epoch is the saved epoch + 1, or 0 when no checkpoint
        exists (including when the directory is missing).
    """
    checkpoint_dir = CHECKPOINT_PATH if checkpoint_dir is None else checkpoint_dir
    checkpoint_files = []
    if os.path.isdir(checkpoint_dir):
        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.tar')]
    if not checkpoint_files:
        print("No checkpoint files found in the directory.")
        return model, 0

    # Find the latest checkpoint file. Names must be joined with the directory, because
    # os.path functions resolve bare names against the current working directory.
    latest_checkpoint = max(
        checkpoint_files,
        key=lambda name: os.path.getmtime(os.path.join(checkpoint_dir, name)))
    checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)

    # Load on CPU so a GPU-saved file also works on a CPU-only runtime; load_state_dict copies
    # the values into the model's existing tensors.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1

    print(f"Loaded checkpoint '{latest_checkpoint}' (epoch {checkpoint['epoch']})")
    return model, start_epoch

In [None]:
def load_checkpoint(gen, disc, checkpoint_dir=None):
    """Restore generator/discriminator weights from the newest '.tar' file in `checkpoint_dir`.

    Expects the keys that save_checkpoint writes: 'epoch', 'gen_state_dict' and
    'disc_state_dict'. If the file also has 'gen_optimizer_state_dict' /
    'disc_optimizer_state_dict' and the matching model has a non-None `.optimizer`, that
    optimizer state is restored as well.

    Args:
        gen, disc: models to load into (in place).
        checkpoint_dir: directory to search; defaults to CHECKPOINT_PATH.

    Returns:
        (gen, disc, start_epoch), where start_epoch is the saved epoch + 1, or 0 when no
        checkpoint exists (including when the directory itself is missing).
    """
    checkpoint_dir = CHECKPOINT_PATH if checkpoint_dir is None else checkpoint_dir
    checkpoint_files = []
    if os.path.isdir(checkpoint_dir):
        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.tar')]
    if not checkpoint_files:
        start_epoch = 0
        print("No checkpoint files found in the directory.")
        return gen, disc, start_epoch

    # Find the latest checkpoint file. Names must be joined with the directory: os.path
    # functions resolve a bare name against the current working directory, not checkpoint_dir.
    latest_checkpoint = max(
        checkpoint_files,
        key=lambda name: os.path.getmtime(os.path.join(checkpoint_dir, name)))
    checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)

    # Load on CPU so a GPU-saved file also works on a CPU-only runtime; load_state_dict copies
    # the values into the models' existing tensors.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    gen.load_state_dict(checkpoint['gen_state_dict'])
    disc.load_state_dict(checkpoint['disc_state_dict'])
    for model, key in ((gen, 'gen_optimizer_state_dict'), (disc, 'disc_optimizer_state_dict')):
        optimizer = getattr(model, 'optimizer', None)
        if key in checkpoint and optimizer is not None:
            optimizer.load_state_dict(checkpoint[key])
    start_epoch = checkpoint['epoch'] + 1

    print(f"Loaded checkpoint '{latest_checkpoint}' (epoch {checkpoint['epoch']})")
    return gen, disc, start_epoch

In [None]:
def validate_epochs(generator, val_loader):
    """Evaluate `generator` on `val_loader` and return (mean PSNR, mean SSIM) over all images.

    `val_loader` yields (gt, hazy) batches in the order ImageDataset.__getitem__ returns them.
    The hazy batch is the generator input and the GT batch is the reference. Tensors are
    assumed to be normalized to [-1, 1], which matches Normalize(0.5, 0.5) in this notebook and
    the generator's Tanh output. They are mapped back to [0, 1] and scored with data_range=1.0.
    SSIM is computed on HWC images with the channel axis last. The generator's previous
    train/eval mode is restored on exit. Returns (nan, nan) for an empty loader.
    """
    was_training = generator.training
    generator.eval()
    psnr_values = []
    ssim_values = []

    def _to_unit_hwc(batch):
        # [-1, 1] NCHW tensor -> [0, 1] NHWC numpy array.
        return ((batch + 1) / 2).clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy()

    try:
        with torch.no_grad(), tqdm(total=len(val_loader), desc='Validation') as pbar:
            for gt_images, hazy_images in val_loader:
                input_data = hazy_images.to(DEVICE)
                target_data = gt_images.to(DEVICE)

                output_np = _to_unit_hwc(generator(input_data))
                target_np = _to_unit_hwc(target_data)

                # Per-image metrics, so the averages do not depend on the batch size.
                for output_img, target_img in zip(output_np, target_np):
                    psnr_values.append(psnr(target_img, output_img, data_range=1.0))
                    ssim_values.append(ssim(target_img, output_img, data_range=1.0, channel_axis=-1))

                # Update progress bar
                pbar.update(1)
                pbar.set_postfix({'PSNR': np.mean(psnr_values), 'SSIM': np.mean(ssim_values)})
    finally:
        generator.train(was_training)

    avg_psnr = float(np.mean(psnr_values)) if psnr_values else float('nan')
    avg_ssim = float(np.mean(ssim_values)) if ssim_values else float('nan')
    print(f'Validation - Avg PSNR: {avg_psnr:.2f}, Avg SSIM: {avg_ssim:.4f}')

    return avg_psnr, avg_ssim

In [None]:
# Shared normalization: ToTensor gives [0, 1]; mean = std = 0.5 maps each channel to [-1, 1],
# the same range as the generator's Tanh output.
normalize = transforms.Normalize(mean=[0.5] * CHANNELS, std=[0.5] * CHANNELS)

# NOTE(review): the random flip/jitter must affect the GT and hazy image of a pair identically.
# Confirm that ImageDataset applies this transform with shared random draws.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    normalize,
])
train_dataset = ImageDataset(root_dir=TRAIN_PATH, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    normalize,
])
val_dataset = ImageDataset(root_dir=VAL_PATH, transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)

generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

# Resume from the newest checkpoint if one exists; start_epoch is 0 otherwise.
generator, discriminator, start_epoch = load_checkpoint(generator, discriminator)

No checkpoint files found in the directory.


In [None]:
# pix2pix weighting of the L1 reconstruction term relative to the adversarial term.
L1_LAMBDA = 100

# Start from a fresh epoch bar (a re-run would otherwise reuse a stale one without a total).
close_bar('epoch')
update_bar('epoch', total=NUM_EPOCHS, progress=start_epoch, desc='Epoch')

# Training loop: one conditional-GAN (pix2pix) update per batch.
for epoch in range(start_epoch, NUM_EPOCHS):
    update_bar('epoch', progress=epoch + 1, desc=f'Epoch {epoch+1}/{NUM_EPOCHS}')

    generator.train()
    discriminator.train()
    for batch_idx, (gt_images, hazy_images) in enumerate(train_loader):
        gt_images = gt_images.to(DEVICE)
        hazy_images = hazy_images.to(DEVICE)

        fake_images = generator(hazy_images)

        # Discriminator step: real (hazy, GT) pairs -> 1, generated pairs -> 0. Without the fake
        # term the discriminator learns to answer "real" for everything. detach() keeps this
        # step from back-propagating into the generator.
        disc_real = discriminator(hazy_images, gt_images)
        disc_fake = discriminator(hazy_images, fake_images.detach())
        disc_loss = (BCE_LOSS(disc_real, torch.ones_like(disc_real))
                     + BCE_LOSS(disc_fake, torch.zeros_like(disc_fake))) / 2
        discriminator.optimizer.zero_grad()
        disc_loss.backward()
        discriminator.optimizer.step()

        # Generator step: fool the updated discriminator while staying close to GT (L1).
        # Gradients that land in the discriminator are cleared by its zero_grad() next batch.
        disc_fake = discriminator(hazy_images, fake_images)
        gen_loss = (BCE_LOSS(disc_fake, torch.ones_like(disc_fake))
                    + L1_LAMBDA * L1_LOSS(fake_images, gt_images))
        generator.optimizer.zero_grad()
        gen_loss.backward()
        generator.optimizer.step()

        update_bar('training', total=len(train_loader), progress=batch_idx + 1,
                   desc=f'Training Epoch {epoch+1}',
                   postfix={'Generator Loss': gen_loss.item(), 'Discriminator Loss': disc_loss.item()})

    close_bar('training')
    # Validation
    avg_psnr, avg_ssim = validate_epochs(generator, val_loader)
    # Save checkpoint after each epoch
    save_checkpoint(generator, discriminator, epoch)

close_bar('epoch')
print("Training complete.")


Training Epoch 1: : 257 [1:36:43, 20.87s/, Generator Loss=0.28, Discriminator Loss=0.000505][A
Training Epoch 1: : 257 [1:36:43, 20.87s/, Generator Loss=0.28, Discriminator Loss=0.000505][A
Training Epoch 1: : 257 [1:36:43, 20.87s/, Generator Loss=0.317, Discriminator Loss=0.000525][A

Epoch 1/35: : 1 [1:37:11, 5831.06s/]
Training Epoch 1: : 257 [1:36:43, 20.87s/, Generator Loss=0.317, Discriminator Loss=0.000525]
