The following is an implementation of Clear-GAN as described in the proposal titled: ***Clear-GAN: Cloud Removal in Satellite Imagery via Cyclic Multispectral Generative Adversarial Networks***.

In what follows, PyTorch is used for data preprocessing, model construction, training, and evaluation.

# **Imports and Dependencies**

In [None]:
import os

import random
import itertools
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as t

from torchvision.io import read_image
from torch.utils.data import Dataset, random_split
import torch.utils.data

import matplotlib.pyplot as plt
from tqdm import tqdm
!pip install jupyterplot
from jupyterplot import ProgressPlot

# **Dataset**

The training dataset is stored in the `dataset.zip` file. When unzipped, the `dataset` directory contains two subdirectories: `A`, which holds **cloudy** images, and `B`, which holds **cloudless** images. Each class contains `1735` images per channel. Both classes must contain the same number of images: this implementation reads image `k` from `A` together with image `k` from `B`, and it takes the dataset length from `A` alone.

```
dataset
    ├── A
        ├── image_1_R.jpg
        ├── image_1_G.jpg
        ├── image_1_B.jpg
        ├── image_1_NIR.jpg
        ├── image_2_R.jpg
        ├── ...
    └── B
        ├── image_1_R.jpg
        ├── image_1_G.jpg
        ├── image_1_B.jpg
        ├── image_1_NIR.jpg
        ├── image_2_R.jpg
        ├── ...
```

The evaluation dataset (for PSNR calculation) is stored in the `eval.zip` file. When unzipped, the `eval` directory contains paired cloudy and cloudless images (of the same location). Each class contains `608` images for each channel.

``` 
eval
    image_1_R_clear.jpg
    image_1_R_cloudy.jpg
    image_1_G_clear.jpg
    image_1_G_cloudy.jpg
    image_1_B_clear.jpg
    image_1_B_cloudy.jpg
    image_1_NIR_clear.jpg
    image_1_NIR_cloudy.jpg
    image_2_R_clear.jpg
    image_2_R_cloudy.jpg
    ...

```


Each satellite image collected from Sentinel-2 is stored as 4 separate ```.jpg``` files, each consisting of a single channel (R, G, B, NIR).

Now, we download the datasets from Google Drive and unzip them.

In [None]:
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1qESOHWNCzOpljDj1r34QK5LMiHkE1f0t' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1qESOHWNCzOpljDj1r34QK5LMiHkE1f0t" -O dataset.zip && rm -rf /tmp/cookies.txt
!unzip dataset.zip
!rm -rf __MACOSX

!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1hZ6qzpQYxTOG9BVQfSpi4LlnhhFo76Wj' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1hZ6qzpQYxTOG9BVQfSpi4LlnhhFo76Wj" -O eval.zip && rm -rf /tmp/cookies.txt
!unzip eval.zip
!mv ./content/eval ./
!rm -rf ./content

Now, we will define custom datasets compatible with PyTorch. This makes it easier to handle the data for preprocessing, training, and evaluation using a PyTorch `DataLoader` object.

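The `__getitem__(i)` method returns two 4-channel images (R, G, B, NIR) for the zero-based index `i` (0 to 1734). They are read from the files named `image_{i+1}_*.jpg`, which are numbered 1 to 1735. The first is a cloudy image and the second a cloudless image. Unlike the evaluation set, the two images are unpaired: they are not captures of the same location.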

In [5]:
def _count_files(path):
    """Return the number of non-directory entries directly inside ``path``.

    Raises:
        FileNotFoundError: if ``path`` cannot be listed. ``os.walk`` silently
            yields nothing in that case, which would otherwise surface as a
            confusing ``StopIteration`` from ``next``.
    """
    try:
        _, _, files = next(os.walk(path))
    except StopIteration:
        raise FileNotFoundError(f"Cannot list dataset directory: {path}") from None
    return len(files)


def _ordered_bands(channel_map):
    """Return the band names of ``channel_map`` sorted by their integer key."""
    return [channel_map[key] for key in sorted(channel_map)]


def _read_bands(prefix, bands, suffix=''):
    """Read one JPEG per band and stack them into a ``(len(bands), H, W)`` tensor.

    Each file is named ``prefix + band + suffix + '.jpg'``. Only the first
    channel of every decoded JPEG is kept, and values are divided by 255 so
    8-bit pixels map to [0, 1]. Each file is decoded exactly once.
    """
    planes = [read_image(prefix + band + suffix + '.jpg')[0] for band in bands]
    return torch.stack(planes).to(torch.get_default_dtype()) / 255.0


def _apply_transforms(images, transform, target_transform):
    """Apply ``transform`` and then ``target_transform`` (when set) to every image."""
    if transform:
        images = [transform(img) for img in images]
    if target_transform:
        images = [target_transform(img) for img in images]
    return tuple(images)


class Sentinel2(Dataset):
    """Unpaired training set of 4-band (R, G, B, NIR) Sentinel-2 images.

    Expects ``dir/A`` (cloudy) and ``dir/B`` (cloudless), each holding files
    named ``image_{k}_{band}.jpg`` for k = 1..len(self) and band in R, G, B, NIR.

    Index ``i`` always returns cloudy image ``i + 1`` together with cloudless
    image ``i + 1``. The A/B pairing is therefore fixed, not re-drawn per access.

    Args:
        dir: dataset root directory.
        transform: optional callable applied to both returned images.
        target_transform: optional callable applied to both images after ``transform``.
    """

    def __init__(self, dir, transform = None, target_transform = None):
        self.dir = dir
        self.transform = transform
        self.target_transform = target_transform
        self.channel_map = {1: "R",
                            2: "G",
                            3: "B",
                            4: "NIR"}

    def __len__(self):
        # Four single-band files per image; the length is taken from domain A only.
        return _count_files(os.path.join(self.dir, 'A')) // 4

    def __getitem__(self, i):
        """Return ``(cloudy, clear)`` float tensors (4 bands) for zero-based index ``i``."""
        name = 'image_' + str(i + 1) + '_'
        bands = _ordered_bands(self.channel_map)
        cloudy_image = _read_bands(os.path.join(self.dir, 'A', name), bands)
        clear_image = _read_bands(os.path.join(self.dir, 'B', name), bands)
        return _apply_transforms((cloudy_image, clear_image), self.transform, self.target_transform)


class Eval(Dataset):
    """Paired evaluation set: cloudy and cloudless 4-band images of the same scene.

    Expects ``dir`` to hold ``image_{k}_{band}_cloudy.jpg`` and
    ``image_{k}_{band}_clear.jpg`` for every band in R, G, B, NIR.

    Args:
        dir: directory containing the evaluation files.
        transform: optional callable applied to both returned images.
        target_transform: optional callable applied to both images after ``transform``.
    """

    def __init__(self, dir, transform = None, target_transform = None):
        self.dir = dir
        self.transform = transform
        self.target_transform = target_transform
        self.channel_map = {1: "R",
                            2: "G",
                            3: "B",
                            4: "NIR"}

    def __len__(self):
        # Eight files per sample: four bands, each stored as *_cloudy.jpg and *_clear.jpg.
        return _count_files(self.dir) // 8

    def __getitem__(self, i):
        """Return the paired ``(cloudy, clear)`` float tensors for zero-based index ``i``."""
        prefix = os.path.join(self.dir, 'image_' + str(i + 1) + '_')
        bands = _ordered_bands(self.channel_map)
        cloudy_image = _read_bands(prefix, bands, '_cloudy')
        clear_image = _read_bands(prefix, bands, '_clear')
        return _apply_transforms((cloudy_image, clear_image), self.transform, self.target_transform)

Now, we apply a `Resize` transformation that resizes each image to shape `(4, 256, 256)`. Only the spatial size changes; all four bands are kept.

In [6]:
# Spatial resize only: every band of the 4-channel tensor is resized to 256x256.
transform = t.Resize(size=(256, 256))

# The same transform is shared by the training and evaluation datasets.
dataset = Sentinel2(dir='dataset', transform=transform)
eval_dataset = Eval(dir='eval', transform=transform)

We visualize a few random samples from the training dataset, displaying (unpaired) cloudy images in the top row and cloudless images in the bottom row (RGB channels only).

In [8]:
n = 4
f = plt.figure(figsize = (4 * n, 2 * n))

# Row 0 shows cloudy (domain A) images and row 1 cloudless (domain B) images.
# Every panel draws its own random index, so the rows show unrelated scenes.
for row in range(2):
    for col in range(n):
        sample = dataset[np.random.randint(low = 0, high = len(dataset))]
        rgb = sample[row][:3, :, :].permute(1, 2, 0)  # drop NIR, move channels last
        f.add_subplot(2, n, row * n + col + 1)
        plt.axis('off')
        plt.imshow(rgb)

plt.show()

We also visualize a few random paired samples from the evaluation dataset in the same order mentioned above.

In [9]:
n = 4
f = plt.figure(figsize = (4 * n, 2 * n))

# One random evaluation pair per column: cloudy on top, its paired cloudless image below.
rand_inds = [np.random.randint(low = 0, high = len(eval_dataset)) for _ in range(n)]

for col, idx in enumerate(rand_inds):
    # Load each pair once; every eval_dataset[idx] call decodes 8 JPEG files.
    cloudy, clear = eval_dataset[idx]
    for row, img in enumerate((cloudy, clear)):
        f.add_subplot(2, n, row * n + col + 1)
        plt.axis('off')
        plt.imshow(img[:3, :, :].permute(1, 2, 0))  # RGB bands only, channels last

plt.show()

# **Model Construction**

CycleGAN, as described in the original paper, consists of two discriminators and two generators.

<br />

**Discriminators**

Discriminators in this context are deep convolutional neural nets that perform image classification (real/fake).

The first discriminator `(D_A)` classifies images from Domain A (cloudy), while the second `(D_B)` classifies images from Domain B (cloudless).

Each discriminator model consists of a set of (1) convolutional layers, (2) Instance Normalization layers, and (3) Leaky ReLU activation functions. The final 1-channel score map is averaged over its spatial dimensions to give one score per image.

In [10]:
class Discriminator(nn.Module):
    """Convolutional real/fake discriminator producing one raw score per image.

    The layers are:
      * four 4x4 convolutions (64 -> 128 -> 256 -> 512 channels); the first
        three use stride 2;
      * InstanceNorm after every one of them except the first;
      * LeakyReLU(0.2) activations.
    A final 4x4 convolution yields a 1-channel score map. That map is averaged
    over its spatial dimensions, giving a single score per sample.

    No sigmoid is applied. Training scores the raw output with ``nn.MSELoss``
    (least-squares GAN) against 0/1 targets.

    Args:
        in_channels: number of input bands (default 4: R, G, B, NIR). The
            default keeps the original layer list and ``state_dict`` keys.
    """

    def __init__(self, in_channels=4):
        super(Discriminator, self).__init__()

        self.main = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, stride = 2, padding = 1),
            nn.LeakyReLU(0.2, inplace = True),

            nn.Conv2d(64, 128, 4, stride = 2, padding = 1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace = True),

            nn.Conv2d(128, 256, 4, stride = 2, padding = 1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace = True),

            nn.Conv2d(256, 512, 4, padding = 1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace = True),

            nn.Conv2d(512, 1, 4, padding = 1),
            )

    def forward(self, x):
        """Map a batch ``(N, in_channels, H, W)`` to scores of shape ``(N, 1)``."""
        x = self.main(x)
        # Global average over the score map -> one score per sample
        x = F.avg_pool2d(x, x.size()[2:])
        return torch.flatten(x, 1)

**Generators**

In this context, a generator is an encoder-decoder architecture which downsamples (encodes) the input image, interprets the encoding through a series of `ResNet` blocks, then upsamples (decodes) the result.

Each `ResNet` block consists of two `Conv2d` layers with `3x3` filters.

The first generator `(G_A2B)` maps an image from Domain A (cloudy) to Domain B (cloudless), while the second generator `(G_B2A)` maps an image from Domain B (cloudless) to Domain A (cloudy).

In [11]:
class Generator(nn.Module):
    """Encoder / residual stack / decoder image-to-image generator.

    Layout:
      * a reflection-padded 7x7 conv to 64 channels;
      * two stride-2 3x3 convs (64 -> 128 -> 256), each halving H and W;
      * ``n_residual_blocks`` ``ResidualBlock(256)`` units;
      * two stride-2 transposed convs (256 -> 128 -> 64), each doubling H and W;
      * a reflection-padded 7x7 conv back to ``in_channels``, then Tanh, so
        outputs lie in [-1, 1].
    Every conv except the last is followed by InstanceNorm and ReLU. For inputs
    whose H and W are divisible by 4, the output has the same spatial size.

    Args:
        in_channels: number of image bands in and out (default 4: R, G, B, NIR).
        n_residual_blocks: depth of the residual stack (default 9).

    The defaults build exactly the original layer list, so the
    ``main.<index>`` ``state_dict`` keys of existing checkpoints still match.
    """

    def __init__(self, in_channels=4, n_residual_blocks=9):
        super(Generator, self).__init__()
        layers = [
            #Initial Convolution
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace = True),

            #Encoder
            nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace = True),
            nn.Conv2d(128, 256, 3, stride = 2, padding = 1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace = True),
        ]

        #Residual Block Sequence
        layers += [ResidualBlock(256) for _ in range(n_residual_blocks)]

        layers += [
            #Decoder
            nn.ConvTranspose2d(256, 128, 3, stride = 2, padding = 1, output_padding = 1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace = True),
            nn.ConvTranspose2d(128, 64, 3, stride = 2, padding = 1, output_padding = 1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace = True),

            #Output
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, in_channels, 7),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Translate a batch ``(N, in_channels, H, W)`` into the other domain."""
        return self.main(x)


class ResidualBlock(nn.Module):
    """Residual unit: ``x + F(x)`` with two reflection-padded 3x3 convolutions.

    F is pad -> conv -> InstanceNorm -> ReLU -> pad -> conv -> InstanceNorm.
    Channel count and spatial size are preserved.
    """

    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()

        def padded_conv_norm():
            # One reflection-padded 3x3 conv that keeps H, W and channel count
            return [nn.ReflectionPad2d(1),
                    nn.Conv2d(in_channels, in_channels, 3),
                    nn.InstanceNorm2d(in_channels)]

        self.res = nn.Sequential(*padded_conv_norm(),
                                 nn.ReLU(inplace = True),
                                 *padded_conv_norm())

    def forward(self, x):
        residual = self.res(x)
        return x + residual

# **Model Training**

First, we define the class `DecayLR` for learning rate decay. This will be used during weight optimization.

In [12]:
class DecayLR:
    """Linear learning-rate decay factor, intended as ``LambdaLR``'s ``lr_lambda``.

    The factor is 1.0 until ``epoch + offset`` reaches ``decay_epochs``. It then
    falls linearly and reaches 0.0 when ``epoch + offset`` equals ``epochs``.
    The result is clamped at 0.0, so stepping past ``epochs`` never yields a
    negative learning rate. That can happen, for example, with a non-zero
    ``offset`` when resuming.

    Args:
        epochs: epoch at which the factor reaches 0.
        offset: added to the scheduler's epoch counter (e.g. epochs already done).
        decay_epochs: epoch at which the linear decay begins; must be < ``epochs``.
    """

    def __init__(self, epochs, offset, decay_epochs):
        epoch_flag = epochs - decay_epochs
        assert (epoch_flag > 0), "Decay must start before the training session ends!"
        self.epochs = epochs
        self.offset = offset
        self.decay_epochs = decay_epochs

    def step(self, epoch):
        """Return the LR multiplier in [0, 1] for scheduler epoch ``epoch``."""
        elapsed = max(0, epoch + self.offset - self.decay_epochs)
        return max(0.0, 1.0 - elapsed / (self.epochs - self.decay_epochs))

We also define a few helper functions.

In [15]:
#Weight Initialization
def weights_init(m):
    """Initialise one module's weights in place; meant for ``model.apply(weights_init)``.

    ``model.apply`` visits every submodule, including containers, so each
    module is matched by its class name:

    * ``*Conv*`` layers (including transposed convs): weight ~ N(0, 0.02).
      The bias keeps PyTorch's default initialisation.
    * ``*BatchNorm*`` layers: weight ~ N(1, 0.02) and bias = 0.

    A module whose class name matches but that has no weight tensor is
    skipped instead of raising. Examples are a non-affine BatchNorm or a
    custom container class named ``ConvBlock``. Any other module is left
    unchanged (this includes the non-affine InstanceNorm layers used here).
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    if not isinstance(weight, torch.Tensor):
        return

    if "Conv" in classname:
        torch.nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        torch.nn.init.normal_(weight, 1.0, 0.02)
        if isinstance(getattr(m, "bias", None), torch.Tensor):
            torch.nn.init.zeros_(m.bias)
        
#Replay buffer of previously generated images, fed to the discriminators during training
class ReplayBuffer:
    """Fixed-capacity pool of past generator outputs.

    While the pool is filling, every incoming image is stored and returned
    unchanged. Once the pool is full, each incoming image has a 50% chance of
    being returned as-is. Otherwise it replaces a randomly chosen stored image,
    and that older image is returned in its place.

    Args:
        max_size: pool capacity (must be > 0).
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def _exchange(self, sample):
        """Return ``sample`` itself or a stored image that ``sample`` replaces."""
        if len(self.data) < self.max_size:
            self.data.append(sample)
            return sample
        if random.uniform(0, 1) <= 0.5:
            return sample
        slot = random.randint(0, self.max_size - 1)
        previous = self.data[slot].clone()
        self.data[slot] = sample
        return previous

    def push_and_pop(self, data):
        """Push a batch through the pool and return a batch of the same size (detached)."""
        return torch.cat([self._exchange(torch.unsqueeze(sample, 0)) for sample in data.data])

#Peak Signal-to-Noise Ratio (used for evaluation of the generator models)
def PSNR(real_image, gen_image, max_val=1.0):
    """Return the peak signal-to-noise ratio between two images, in dB.

    PSNR = 20 * log10(max_val / sqrt(MSE)). The MSE is taken over *all*
    elements of the inputs, so a whole batch yields one pooled value.

    Args:
        real_image: reference tensor (presumably floating point; an unsigned
            integer dtype would wrap on subtraction — confirm at call sites).
        gen_image: tensor broadcast-compatible with ``real_image``.
        max_val: largest possible pixel value. The default 1.0 matches images
            scaled to [0, 1]; use 255 for raw 8-bit data.

    Returns:
        A 0-dim tensor, or ``inf`` when the images are identical.
    """
    mse = torch.mean((real_image - gen_image) ** 2)
    return 20 * torch.log10(max_val / torch.sqrt(mse))

Now, we train the four models `D_A, D_B, G_A2B, and G_B2A`.

In [None]:
def _discriminator_step(D, optimizer, real, fake, buffer, adv_loss, real_label, fake_label):
    """Run one least-squares update of discriminator ``D`` and return its loss tensor.

    ``fake`` first passes through the replay ``buffer``, which may swap some
    freshly generated images for older ones. The real and fake losses are
    averaged before backpropagation.
    """
    optimizer.zero_grad()

    #Real image loss
    err_real = adv_loss(D(real), real_label)

    #Fake image loss (detached, so no gradient reaches the generators)
    fake = buffer.push_and_pop(fake)
    err_fake = adv_loss(D(fake.detach()), fake_label)

    #Combine losses, compute gradients and update D's weights
    err = (err_real + err_fake) / 2
    err.backward()
    optimizer.step()
    return err


def train(trainset, hp, first_time = True, num_saved_epochs = 0):
    """Train, or resume training of, the four CycleGAN models on ``trainset``.

    Args:
        trainset: dataset yielding ``(cloudy, clear)`` image tensors.
        hp: hyperparameter dict with these keys:
            * ``batch_size``: images per step;
            * ``lr``: initial Adam learning rate;
            * ``epochs``: number of epochs to run in this call;
            * ``decay_epochs``: epoch within this call at which linear LR decay starts;
            * ``lambda_identity`` (optional, default 5.0): identity-loss weight;
            * ``lambda_cycle`` (optional, default 10.0): cycle-loss weight.
        first_time: if True, initialise fresh weights with ``weights_init``.
            Otherwise load ``weights/<model>_epoch_{num_saved_epochs - 1}.pth``.
        num_saved_epochs: epochs already completed when resuming. It offsets the
            progress display and the numbering of new checkpoints.

    Returns:
        ``(G_A2B, G_B2A, D_A, D_B)`` after the last epoch.

    Side effects:
        Writes the ``state_dict`` of all four models to ``weights/`` after every epoch.

    NOTE(review): on resume the LR schedule restarts, because ``DecayLR`` is
    built with offset 0. The Adam optimiser state is also not restored, since
    it is never saved.
    """
    #Checkpoint folder (no-op if it already exists; real errors such as permissions still raise)
    os.makedirs("weights", exist_ok=True)

    #DataLoader Instantiation
    dataloader = torch.utils.data.DataLoader(trainset, batch_size=hp['batch_size'], shuffle=True, pin_memory=True)

    #Switch to GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    #Model Instantiation
    G_A2B = Generator().to(device)
    G_B2A = Generator().to(device)
    D_A = Discriminator().to(device)
    D_B = Discriminator().to(device)
    models = (("G_A2B", G_A2B), ("G_B2A", G_B2A), ("D_A", D_A), ("D_B", D_B))

    #Weight Initialization
    if first_time:
        for _, model in models:
            model.apply(weights_init)
    else:
        # map_location lets checkpoints written on a GPU be resumed on a CPU-only machine
        for name, model in models:
            state = torch.load(f"weights/{name}_epoch_{num_saved_epochs - 1}.pth", map_location=device)
            model.load_state_dict(state)

    #Loss Functions and weights
    cyc_loss = torch.nn.L1Loss().to(device)
    idt_loss = torch.nn.L1Loss().to(device)
    adv_loss = torch.nn.MSELoss().to(device)
    lambda_idt = hp.get('lambda_identity', 5.0)
    lambda_cyc = hp.get('lambda_cycle', 10.0)

    #Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(G_A2B.parameters(), G_B2A.parameters()),
                                   lr=hp['lr'],
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=hp['lr'], betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=hp['lr'], betas=(0.5, 0.999))

    #Learning Rate Schedulers
    lr_lambda = DecayLR(hp['epochs'], 0, hp['decay_epochs']).step
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lr_lambda)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lr_lambda)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lr_lambda)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    curr_iter = 0 if first_time else num_saved_epochs * len(dataloader)
    pp = ProgressPlot(line_names = ["Loss_D", "Loss_G"],
                      x_iterator=False,
                      x_label="Epochs")

    for epoch in range(hp['epochs']):
        progress_bar = tqdm(enumerate(dataloader), total = len(dataloader))
        for i, data in progress_bar:
            #Load data batch
            real_A = data[0].to(device)
            real_B = data[1].to(device)
            batch_size = real_A.size(0)

            real_label = torch.full((batch_size, 1), 1, device=device, dtype=torch.float32)
            fake_label = torch.full((batch_size, 1), 0, device=device, dtype=torch.float32)

            #########################################
            #Update G network: Generators A2B and B2A
            #########################################
            optimizer_G.zero_grad()

            #Identity loss (G_B2A(A) should equal A if real A is fed)
            idt_A = G_B2A(real_A)
            loss_idt_A = idt_loss(idt_A, real_A) * lambda_idt

            #Identity Loss (G_A2B(B) should equal B if real B is fed)
            idt_B = G_A2B(real_B)
            loss_idt_B = idt_loss(idt_B, real_B) * lambda_idt

            #GAN loss D_A(G_B2A(B)): fakes in domain A should be scored as real
            fake_A = G_B2A(real_B)
            loss_GAN_B2A = adv_loss(D_A(fake_A), real_label)

            #GAN loss D_B(G_A2B(A)): fakes in domain B should be scored as real
            fake_B = G_A2B(real_A)
            loss_GAN_A2B = adv_loss(D_B(fake_B), real_label)

            #Cycle loss
            recovered_image_A = G_B2A(fake_B)
            loss_cycle_ABA = cyc_loss(recovered_image_A, real_A) * lambda_cyc

            recovered_image_B = G_A2B(fake_A)
            loss_cycle_BAB = cyc_loss(recovered_image_B, real_B) * lambda_cyc

            #Combine losses, compute gradients and update G_A2B and G_B2A
            errG = loss_idt_A + loss_idt_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            errG.backward()
            optimizer_G.step()

            #########################################
            #Update D networks: Discriminators A and B
            #########################################
            errD_A = _discriminator_step(D_A, optimizer_D_A, real_A, fake_A, fake_A_buffer,
                                         adv_loss, real_label, fake_label)
            errD_B = _discriminator_step(D_B, optimizer_D_B, real_B, fake_B, fake_B_buffer,
                                         adv_loss, real_label, fake_label)

            #Display progress and losses
            loss_D = (errD_A + errD_B).item()
            loss_G = errG.item()
            progress_bar.set_description(
                f"[{epoch + num_saved_epochs + 1}/{hp['epochs'] + num_saved_epochs}][{i}/{len(dataloader) - 1}] "
                f"Loss_D: {loss_D:.4f} "
                f"Loss_G: {loss_G:.4f} "
                f"Loss_G_identity: {(loss_idt_A + loss_idt_B).item():.4f} "
                f"Loss_G_GAN: {(loss_GAN_A2B + loss_GAN_B2A).item():.4f} "
                f"Loss_G_cycle: {(loss_cycle_ABA + loss_cycle_BAB).item():.4f}")

            #Plot every second iteration to keep the live plot responsive
            if curr_iter % 2 == 0:
                pp.update(curr_iter / len(dataloader), [[loss_D, loss_G]])

            curr_iter += 1

        #Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        #Create model weight checkpoints (numbered by total completed epochs, zero-based)
        ep = epoch + num_saved_epochs if not first_time else epoch
        for name, model in models:
            torch.save(model.state_dict(), f"weights/{name}_epoch_{ep}.pth")

    pp.finalize()

    return G_A2B, G_B2A, D_A, D_B

#Define hyperparameters
hp = {'batch_size': 4,
      'lr': 2e-4,
      'epochs': 20,
      'decay_epochs': 10}

#Train the models
G_A2B, G_B2A, D_A, D_B = train(trainset = dataset, hp = hp, first_time = False, num_saved_epochs = 50)

In [42]:
def load_models(num_saved_epochs):
    """Rebuild the four CycleGAN networks and load a saved checkpoint into them.

    Args:
        num_saved_epochs: number of completed epochs. Weights are read from
            ``weights/<name>_epoch_{num_saved_epochs - 1}.pth``, since training
            numbers its checkpoints from epoch 0.

    Returns:
        ``(G_A2B, G_B2A, D_A, D_B)`` on CUDA if available, else on CPU.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    models = {
        "G_A2B": Generator().to(device),
        "G_B2A": Generator().to(device),
        "D_A": Discriminator().to(device),
        "D_B": Discriminator().to(device),
    }

    last_epoch = num_saved_epochs - 1
    for name, model in models.items():
        # map_location lets checkpoints written on a GPU load on a CPU-only machine
        state = torch.load(f"weights/{name}_epoch_{last_epoch}.pth", map_location=device)
        model.load_state_dict(state)

    return models["G_A2B"], models["G_B2A"], models["D_A"], models["D_B"]

def evaluate(G_A2B, G_B2A, eval_dataset):
    """Average per-image PSNR of both generators on a paired evaluation set.

    For each ``(cloudy, clear)`` pair, the two tasks are scored as:
    * cloud removal: ``PSNR(clear, G_A2B(cloudy))``;
    * cloud addition: ``PSNR(cloudy, G_B2A(clear))``.

    Args:
        G_A2B: cloudy -> cloudless generator.
        G_B2A: cloudless -> cloudy generator.
        eval_dataset: dataset yielding paired ``(cloudy, clear)`` tensors.

    Returns:
        ``(avg_removal_psnr, avg_addition_psnr)`` in dB, as Python floats.

    Raises:
        ValueError: if ``eval_dataset`` is empty.

    Both generators are left in eval mode afterwards.
    """
    # Send inputs to wherever the generators live (assumes both share one device)
    device = next(G_A2B.parameters()).device
    G_A2B.train(False)
    G_B2A.train(False)

    # batch_size=1 so PSNR is computed per image rather than pooled over a batch
    dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size = 1, shuffle = False)
    num_samples = len(dataloader)
    if num_samples == 0:
        raise ValueError("eval_dataset is empty; cannot average PSNR")

    total_removal_PSNR = 0.0
    total_addition_PSNR = 0.0
    with torch.no_grad():
        for real_cloudy, real_clear in dataloader:
            real_cloudy = real_cloudy.to(device)
            real_clear = real_clear.to(device)

            # Accumulate as Python floats (double precision) rather than float32 tensors
            total_removal_PSNR += PSNR(real_clear, G_A2B(real_cloudy)).item()
            total_addition_PSNR += PSNR(real_cloudy, G_B2A(real_clear)).item()

    return total_removal_PSNR / num_samples, total_addition_PSNR / num_samples


#Evaluate the checkpoint saved after 70 completed epochs (files *_epoch_69.pth)
G_A2B, G_B2A, D_A, D_B = load_models(70)
av_removal_PSNR, av_addition_PSNR = evaluate(G_A2B, G_B2A, eval_dataset)

for task, score in (('Removal', av_removal_PSNR), ('Addition', av_addition_PSNR)):
    print(f'Average PSNR for Cloud {task}: {score} dB')

Now, we visualize generated cloud removal examples using images from the evaluation dataset. The top row shows the real cloudy images, the second row their paired real cloudless images, and the third row the generated cloudless images.

In [41]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

n = 3
f = plt.figure(figsize = (4 * n, 4 * n))

# One random evaluation pair per column.
# Rows: real cloudy, real cloudless (paired), generated cloudless.
rand_inds = [np.random.randint(low = 0, high = len(eval_dataset)) for _ in range(n)]

G_A2B.train(False)
with torch.no_grad():  # inference only; no autograd graph needed
    for col, idx in enumerate(rand_inds):
        # Load each pair once; every eval_dataset[idx] call decodes 8 JPEG files.
        cloudy, clear = eval_dataset[idx]
        gen = G_A2B(cloudy.unsqueeze(0).to(device)).squeeze(0)
        # The generator ends in Tanh (range [-1, 1]); clamp so imshow receives values in [0, 1].
        gen_rgb = gen[:3].clamp(0.0, 1.0).permute(1, 2, 0).cpu().numpy()

        panels = (cloudy[:3].permute(1, 2, 0), clear[:3].permute(1, 2, 0), gen_rgb)
        for row, img in enumerate(panels):
            f.add_subplot(3, n, row * n + col + 1)
            plt.axis('off')
            plt.imshow(img)

plt.show()

# Credits

This project's CycleGAN implementation relied heavily on that contributed by GitHub user Lornatang (https://github.com/Lornatang/CycleGAN-PyTorch).