# Super Resolution Auto Encoder
Super Resoution model based on tng simulated images

## Imports and setup

In [1]:
# Check if gpu is free
!nvidia-smi

Thu May 13 17:49:10 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:3B:00.0 Off |                  N/A |
| 30%   27C    P8    24W / 250W |  10898MiB / 11019MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:5E:00.0 Off |                  N/A |
| 30%   26C    P8    22W / 250W |  10898MiB / 11019MiB |      0%      Default |
|       

In [2]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn.functional as F
import torch.nn as nn
from astropy.io import fits
from tqdm import tqdm
import time
import wandb

wandb.login()

data_dir = "/home/ssweere/data/sim" #faster local storage, do not use too much
dataset_dir = "tng300_2048"
root_dir = os.path.join(data_dir, dataset_dir, "fits")
   
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# TODO: uncomment if you want to train on gpu, check if the gpu is free before you enable it
device = "cpu"

print("PyTorch uses device:",device)

[34m[1mwandb[0m: Currently logged in as: [33msamsweere[0m (use `wandb login --relogin` to force relogin)


PyTorch uses device: cpu


## Helper functions

In [3]:
import matplotlib.pyplot as plt
from astropy.visualization import (imshow_norm, MinMaxInterval,
                                   SqrtStretch)
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np

def plot_img(img, title=None, figsize =(10,10)):
    # Generate and display a test image
    # image = np.arange(65536).reshape((256, 256))
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1)
    if title is not None:
        ax.set_title(title)
    ax.autoscale(enable=True)

    pos = ax.imshow(img, cmap='hot', interpolation=None)
    fig.colorbar(pos, ax=ax)
    
    plt.show()
        

def plot_multiple_img(images, titles = None, rows=None, cols=None, figsize=(10,10), merge=False, colorbar = True, save_fig=None, show_plot=True, wandb_log=False):
    """Plot multiple images

    Args:
        images (list): List of images as numpy arrays.
        titles (list) (optional): The titles of the subplots.
        rows (int) (optional): The amount of rows in the subplot, if this is set and cols is None it will expand in the cols direction
        cols (int) (optional): The amount of cols in the subplot, if this is set and rows is None it will expand in the rows direction
        figsize (tuple) (optional): The figsize of the subplots
        merge (boolean) (optional): If this parameter is set all the images are compressed into one big image and displayed
        colorbar (boolean) (optional): If this parameter is set the colorbar is shown
        save_fig (string) (optional): If this parameter is set save the image to a file with the given name
        show_plot (boolean) (optional): If this parameter is set the plot is displayed
        wandb_log (boolean) (optional): If this parameter is set the plot is logged to wandb
    """
        
    # Determine the rows and cols if they are set to none
    if not rows and not cols:
        # Go for a square layout if possible, if not possible expand in the x direction
        s_dim = np.sqrt(len(images))
        if s_dim.is_integer():
            # The number of images is perfect for a square layout
            rows = int(s_dim)
            cols = int(s_dim)
        else:
            # The number of images is not perfect for a square layout
            # Expand in the x direction
            rows = int(np.floor(s_dim))
            cols = int(np.ceil(len(images)/float(rows)))
    elif not rows and cols:
        # rows is None and cols is not, expand in the rows direction
        rows = int(np.ceil(len(images)/float(cols)))
    elif rows and not cols:
        # cols is None and rows is not, expand in the cols direction
        cols = int(np.ceil(len(images)/float(rows)))
    else:
        if len(images) > rows*cols:
            raise ValueError(f"The specified rows ({rows}) and cols ({cols}) do not provide enough room for all the images ({rows*cols}/{len(images)})")
            
    if merge:
        counter = 0
        img_merged = None
        for y in range(rows):
            # First merge the images in x direction
            img_m_x = None
            for x in range(cols):
                img = images[counter]
                counter += 1
                
                if img_m_x is None:
                    img_m_x = img
                else:
                    img_m_x = np.hstack((img_m_x, img))
            
            # Then stack them vertically on the big image
            if img_merged is None:
                img_merged = img_m_x
            else:
                img_merged = np.vstack((img_merged, img_m_x))
         
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(1, 1, 1)
        ax.autoscale(enable=True)
        if titles is not None:
            ax.set_title(' - '.join(titles))

        pos = ax.imshow(img_merged, cmap='hot', interpolation=None)
        if colorbar:
            fig.colorbar(pos, ax=ax)
                
                
    else: 
        fig, axs = plt.subplots(rows, cols, figsize=figsize)

        axs = np.array(axs)
        if len(axs.shape) == 1:
            # We only have one row, expand the dimensions in order to have the same structure 
            # as with multiple rows
            axs = np.expand_dims(axs, axis = 1)

        counter = 0
        for x in range(axs.shape[0]):
            for y in range(axs.shape[1]):
    #             if rows == 1:
    #                 # This is a fix for when we only have one row
    #                 ax = axs[x][y]
    #             else:
    #                 ax = axs[y][x]
                ax = axs[x, y]

                idx = counter
                counter += 1

                # Check if there are enough images to show, if not continue
                if idx >= len(images):
                    continue

                img = images[idx]
                if titles is not None:
                    ax.set_title(titles[idx])

                pos = ax.imshow(img, cmap='hot', interpolation=None)
                
                if colorbar:
                    # create an axes on the right side of ax. The width of cax will be 5%
                    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
                    divider = make_axes_locatable(ax)
                    cax = divider.append_axes("right", size="5%", pad=0.05)

                    plt.colorbar(pos, cax=cax)

        # Correct for a nicer layout
        fig.tight_layout()
#     fig.subplots_adjust(hspace=-0.7)
    
    if save_fig is not None:
        plt.savefig(save_fig, bbox_inches='tight')
#         print("Saved figure to",save_fig)
    
    if wandb_log:
        wandb.log({"images": plt})
    
    if show_plot:
        plt.show()
    else:
        plt.close()

## Dataset Creation
Based on: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

In [4]:
class DownsampleSum(object):
    """Downsample the image summing the values using conv2d
    
    Args:
        output_size (int): Desired output size.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, int)
        self.output_size = output_size
        
        # Create an input_size and weigths parameter to cache previous weights
        self.input_size = 0
        self.kernel_size = None
        self.weights = None
    
    def __call__(self, img):
        input_size = img.shape[-1]
        
        if input_size == self.output_size:
            # The sizes are the same, return itself
            return img
        
        if input_size != self.input_size or self.input_size == 0 or self.kernel_size is None or self.weights is None:          
            # New input_size we cannot use the cache
            self.input_size = input_size
            
            if input_size == self.output_size:
                raise ValueError(f"The desired output size {output_size} is the same as the input size {input_size}.")

            kernel_size = input_size/self.output_size

            if kernel_size%2 != 0:
                raise ValueError(f"The desired output size {output_size} is not a multiple of 2 of the input size {input_size}")

            # The kernel size seems to be valid, convert it to int
            self.kernel_size = int(kernel_size)
            self.input_size = input_size

            # Generate the weights for the conv2d
            weights = torch.ones((self.kernel_size, self.kernel_size))
            self.weights = weights.view(1, 1, self.kernel_size, self.kernel_size).repeat(1, 1, 1, 1)

        # The conv2d needs the data to be in minibatches and have dimensions [1, x, x]
        x = torch.unsqueeze(img, axis=0)
        x = torch.unsqueeze(x, axis=0)
        
        output = torch.nn.functional.conv2d(x, self.weights, stride=self.kernel_size)
        
        # Return the result from the minibatch
        return output[0][0]

    
class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        crop_p (float, list): Desired crop percentage, if list it will randomly sample the crop percentage from the list
    """

    def __init__(self, crop_p):
        assert isinstance(crop_p, (float, list))
        self.crop_p = crop_p

    def __call__(self, image):
        h, w = image.shape[:2]
        
        if type(self.crop_p) is list:
            crop_p = np.random.choice(self.crop_p)
        else:
            crop_p = self.crop_p
        
        new_h = int(h*crop_p)
        new_w = int(w*crop_p)

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)       
        
        image = image[top: top + new_h,
                      left: left + new_w]

        return image
    
class Normalize(object):
    """Normalize and image optionally based on a strectching function. First apply the strectching function then normalize to max = 1. 
        The minimum possible value is 0 independent of the image, the image is thus only normalized based on the max value.
        Returns the normalized image and the max value

    Args:
         stretch_f (string) (optional) : The stretching function options: linear, sqrt. 
    """

    def __init__(self, stretch_f="linear"):
        assert isinstance(stretch_f, str)
        if stretch_f == 'linear':
            self.stretch_f = torch.nn.Identity
        elif stretch_f == 'sqrt':
            self.stretch_f = torch.sqrt
        else:
            raise ValueError(f"Stretching function {stretch_f} is not implemented")

    def __call__(self, image):
        # Calculate the max value
        max_val = torch.max(image)
        
        # Normalize the image
        image = image/max_val
        
        # Apply the stretching function
        image = self.stretch_f(image)

        return image, max_val.item()

class DeNormalize(object):
    """DeNormalize and image optionally based on the originala strectching function and the original max value. 
        Returns the denormalized image

    Args:
         stretch_f (string) (optional) : The stretching function options: linear, sqrt. 
    """

    def __init__(self, stretch_f="linear"):
        assert isinstance(stretch_f, str)
        if stretch_f == 'linear':
            self.stretch_f = torch.nn.Identity
        elif stretch_f == 'sqrt':
            self.stretch_f = lambda x: torch.pow(x, 2)
        else:
            raise ValueError(f"Stretching function {stretch_f} is not implemented")

    def __call__(self, image, max_val):
        # Aply the opposite stretching function
        image = self.stretch_f(image)
       
        # multiply the image with the max val
        image = image*max_val
        
        return image
    
class TngDataset(Dataset):
    """Illustrius TNG X ray simulated images dataset"""
    
    def __init__(self, root_dir, lr_res, hr_res, normalize=None, transform=None, preload=False):
        """
        Args:
            root_dir (string): Directory with all the fits images.
            lr_res (int): The low resolution
            hr_res (int): The high resolution
            normalize (string) (optional): The normalization method, options: linear, sqrt. 
                                            When None no normalization will be done
            transform (callable) (optional): Optional transform to be applied
                                                on a sample.
            preload (boolean) (optinal): Preload the data into ram
                
        """
        self.root_dir = root_dir
        self.lr_res = lr_res
        self.hr_res = hr_res
        self.transform = transform
        self.stretch_f = normalize
        self.preload = preload
        
        self.fits_files = []
        # Only save the files that end with .fits
        for file in os.listdir(root_dir):
            if file.endswith(".fits"):
                self.fits_files.append(file)
                
        if preload:
            # Preload all the image
            print("Preloading the fits images")
            self.fits_images = []
            
            for i in tqdm(range(len(self.fits_files))):
                fits_file = self.fits_files[i]
                self.fits_images.append(self.load_fits(fits_file))
        
        # Create the downsample sum classes
        self.downsample_lr = DownsampleSum(output_size=lr_res)
        self.downsample_hr = DownsampleSum(output_size=hr_res)
        
        if normalize:
            # Create the normalization class
            self.normalize = Normalize(normalize)
        else:
            self.normalize = None
        
    def load_fits(self, fits_file):
        hdu = fits.open(os.path.join(self.root_dir, fits_file))
        # Extract the image data from the fits file and convert to float 
        # (these images will be in int but since we will work with floats in pytorch we convert them to float)
        img = hdu['PRIMARY'].data.astype(np.float32)
        hdu.close()
        
        return img
        
    def __len__(self):
        return len(self.fits_files)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        if self.preload:
            fits_file = self.fits_files[idx]
            # Load the preloaded image
            img = self.fits_images[idx]
        else:
            fits_file = self.fits_files[idx]
            img = self.load_fits(fits_file)
        
        # Convert the img to tensor
        img = torch.tensor(img)
        
        # Apply the transformations
        if self.transform:
            img = self.transform(img)
            
        # Downsample to the desired image resolutions
        lr_img = self.downsample_lr(img)
        hr_img = self.downsample_hr(img)
        
        # Apply the normalization
        if self.normalize:
            lr_img, lr_max = self.normalize(lr_img)
            hr_img, hr_max = self.normalize(hr_img)
        else:
            lr_max = torch.max(lr_img).item()
            hr_max = torch.max(hr_img).item()
        
        # Torch needs the data to have dimensions [1, x, x]
        lr_img = torch.unsqueeze(lr_img, axis=0)
        hr_img = torch.unsqueeze(hr_img, axis=0)
        
        sample = {'lr': lr_img, 'hr': hr_img, 'source': fits_file, 'lr_max': lr_max, 'hr_max': hr_max, 'stretch_f': self.stretch_f}
        
        return sample

In [5]:
def plot_sample(sample):
    # Plot one sample
    images = [sample['lr'].numpy(), sample['hr'].numpy()]
    if len(images[0].shape) == 3:
        # The dimensions are probably [1, x, x]. Make them [x, x]
        images = [x[0] for x in images]
        
    titles = [f"lr {images[0].shape}",f"hr {images[1].shape}"]
    plot_multiple_img(images = images, titles = titles)
    
def preview_dataset(dataset, n=10):
    # Plot n samples randomly chosen from the dataset
    for i in range(n):
        index = np.random.randint(len(dataset))
        sample = dataset[index]
        plot_sample(sample)

## Create the dataset and load the data
For all the possible prebuild transforms see: https://pytorch.org/vision/stable/transforms.html

# TODO: the current image strechting and normalization is probably not suitable for the reconversion to fits, this is a temporary solution. This is also per image normalization, not ideal

In [6]:
# # Dataset settings
# lr_res = 32
# hr_res = 128
# batch_size = 64
# crop_p = [0.25, 0.5]
# # crop_p = [0.125, 0.25, 0.5] #This will random sample from multiple crop percentages
# normalize = 'sqrt' #Other options None and 'linear'
# # normalize = None
# preload = False #Warning this increases the ram usage dramatically

# # Prepare the data transforms
# data_transform = transform=transforms.Compose([
#                                         RandomCrop(crop_p = crop_p),
#                                         transforms.RandomHorizontalFlip(),
#                                         transforms.RandomVerticalFlip(),
#                                     ])

# tng_dataset = TngDataset(root_dir = root_dir, lr_res = lr_res, hr_res = hr_res, normalize=normalize, transform=data_transform, preload=preload)

# train_val_test_split = [0.7, 0.15, 0.15]
# train_len = int(len(tng_dataset)*train_val_test_split[0])
# val_len = int(len(tng_dataset)*train_val_test_split[1])
# test_len = len(tng_dataset) - train_len - val_len

# # Note that the test set has the same transformations as the train an validation set
# tng_datasets = torch.utils.data.random_split(tng_dataset, [train_len, val_len, test_len], generator=torch.Generator().manual_seed(42))
# tng_datasets = {'train':tng_datasets[0], 'val':tng_datasets[1], 'test':tng_datasets[2]}

# print("Train set size:", len(tng_datasets['train']))
# print("Validation set size:", len(tng_datasets['val']))
# print("Test set size:", len(tng_datasets['test']))

# dataloaders = {x: torch.utils.data.DataLoader(tng_datasets[x], batch_size=batch_size,
#                                              shuffle=True, num_workers=0)
#               for x in ['train', 'val', 'test']}

## Visualize a few images from the dataset

In [7]:
# plot_sample(tng_datasets['train'][0])

In [8]:
# len(tng_datasets['train'][0]['lr'].shape)

In [9]:
# preview_dataset(tng_datasets['train'], 4)

In [10]:
# sample_lr = tng_datasets['train'][0]['lr']

In [11]:
# plot_img(sample['lr'][0])

In [12]:
# sample_lr.shape

In [13]:
# m = nn.Upsample(scale_factor=2, mode='nearest')

In [14]:
# model = SRCNN()

In [15]:
# m(torch.unsqueeze(sample_lr, axis=0)).shape

In [16]:
# model(torch.unsqueeze(sample_lr, axis=0)).shape

In [17]:
# # Demonstrate the transformations
# for i in range(6):
#     sample = tng_datasets['train'][i]
#     plot_sample(sample)

### TODO: implement upsamplig comparison: https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html

# Super Resolution Model

Basic idea from: https://debuggercafe.com/image-super-resolution-using-deep-learning-and-pytorch/

In [18]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision.models import vgg19
import math


class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:35])

    def forward(self, img):
        return self.vgg19_54(img)


class DenseResidualBlock(nn.Module):
    """
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    """

    def __init__(self, filters, res_scale=0.2):
        super(DenseResidualBlock, self).__init__()
        self.res_scale = res_scale

        def block(in_features, non_linearity=True):
            layers = [nn.Conv2d(in_features, filters, 3, 1, 1, bias=True)]
            if non_linearity:
                layers += [nn.LeakyReLU()]
            return nn.Sequential(*layers)

        self.b1 = block(in_features=1 * filters)
        self.b2 = block(in_features=2 * filters)
        self.b3 = block(in_features=3 * filters)
        self.b4 = block(in_features=4 * filters)
        self.b5 = block(in_features=5 * filters, non_linearity=False)
        self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]

    def forward(self, x):
        inputs = x
        for block in self.blocks:
            out = block(inputs)
            inputs = torch.cat([inputs, out], 1)
        return out.mul(self.res_scale) + x


class ResidualInResidualDenseBlock(nn.Module):
    def __init__(self, filters, res_scale=0.2):
        super(ResidualInResidualDenseBlock, self).__init__()
        self.res_scale = res_scale
        self.dense_blocks = nn.Sequential(
            DenseResidualBlock(filters), DenseResidualBlock(filters), DenseResidualBlock(filters)
        )

    def forward(self, x):
        return self.dense_blocks(x).mul(self.res_scale) + x


class GeneratorRRDB(nn.Module):
    def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2):
        super(GeneratorRRDB, self).__init__()

        # First layer
        self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)
        # Residual blocks
        self.res_blocks = nn.Sequential(*[ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)])
        # Second conv layer post residual blocks
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)
        # Upsampling layers
        upsample_layers = []
        for _ in range(num_upsample):
            upsample_layers += [
                nn.Conv2d(filters, filters * 4, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                nn.PixelShuffle(upscale_factor=2),
            ]
        self.upsampling = nn.Sequential(*upsample_layers)
        # Final output block
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(filters, channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out


class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        patch_h, patch_w = int(in_height / 2 ** 4), int(in_width / 2 ** 4)
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [19]:
# parser = argparse.ArgumentParser()
# parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
# parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
# parser.add_argument("--dataset_name", type=str, default="img_align_celeba", help="name of the dataset")
# parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
# parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
# parser.add_argument("--b1", type=float, default=0.9, help="adam: decay of first order momentum of gradient")
# parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
# parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
# parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
# parser.add_argument("--hr_height", type=int, default=256, help="high res. image height")
# parser.add_argument("--hr_width", type=int, default=256, help="high res. image width")
# parser.add_argument("--channels", type=int, default=3, help="number of image channels")
# parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving image samples")
# parser.add_argument("--checkpoint_interval", type=int, default=5000, help="batch interval between model checkpoints")
# parser.add_argument("--residual_blocks", type=int, default=23, help="number of residual blocks in the generator")
# parser.add_argument("--warmup_batches", type=int, default=500, help="number of batches with pixel-wise loss only")
# parser.add_argument("--lambda_adv", type=float, default=5e-3, help="adversarial loss weight")
# parser.add_argument("--lambda_pixel", type=float, default=1e-2, help="pixel-wise loss weight")
# opt = parser.parse_args()
# print(opt)


config = dict(
    epoch = 0,
    n_epochs = 5000,
    channels = 1,
    residual_blocks = 23,
    project = "esr_gan",
    root_dir = "/home/ssweere/data/sim/tng300_2048/fits",
    runs_dir = "/home/ssweere/remote_home/data/runs", # The run specific folder will be created automatically
    lr_res = 32,
    hr_res = 128,
    batch_size = 64,
    crop_p = [0.5],
    # crop_p = [0.125, 0.25, 0.5] #This will random sample from multiple crop percentages
    normalize = 'sqrt', #Other options None and 'linear'
    # normalize = None
    preload = True, #Warning this increases the ram usage dramatically
    sample_interval = 10,
    lr = 0.0002,
    b1= 0.9,
    b2 = 0.999,
    warmup_batches = 500,
    lambda_adv = 5e-3,
    lambda_pixel = 1e-2
)

In [20]:
def model_pipeline(hyperparameters):

    # tell wandb to get started
    with wandb.init(project=hyperparameters['project'], config=hyperparameters) as run:
        # access all HPs through wandb.config, so logging matches execution!
        config = wandb.config
        
        # make the model, data, and optimization problem
        generator, discriminator, feature_extractor, dataloaders, criterion_GAN, criterion_content, criterion_pixel , optimizer_G, optimizer_D, Tensor, run_path = make(config)
        
#         print("Generator:")
#         print(generator)
        
#         print("Discriminator:")
#         print(discriminator)

        # and use them to train the model
        train(generator, discriminator, feature_extractor, dataloaders, criterion_GAN, criterion_content, criterion_pixel , optimizer_G, optimizer_D, Tensor, run_path, config)

        #       # and test its final performance
        #       test(model, test_loader)

    return model



In [21]:
"""
Super-resolution of CelebA using Generative Adversarial Networks.
The dataset can be downloaded from: https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=0
(if not available there see if options are listed at http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
Instrustion on running the script:
1. Download the dataset from the provided link
2. Save the folder 'img_align_celeba' to '../../data/'
4. Run the sript using command 'python3 esrgan.py'
"""

import argparse
import os
import numpy as np
import math
import itertools
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

def make(config):
    # Make the data
    # Prepare the data transforms
    data_transform = transform=transforms.Compose([
                                            RandomCrop(crop_p = config.crop_p),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.RandomVerticalFlip(),
                                        ])

    tng_dataset = TngDataset(root_dir = config.root_dir, lr_res = config.lr_res, hr_res = config.hr_res, normalize=config.normalize, transform=data_transform, preload=config.preload)

    train_val_test_split = [0.7, 0.15, 0.15]
    train_len = int(len(tng_dataset)*train_val_test_split[0])
    val_len = int(len(tng_dataset)*train_val_test_split[1])
    test_len = len(tng_dataset) - train_len - val_len

    # Note that the test set has the same transformations as the train an validation set
    tng_datasets = torch.utils.data.random_split(tng_dataset, [train_len, val_len, test_len], generator=torch.Generator().manual_seed(42))
    tng_datasets = {'train':tng_datasets[0], 'val':tng_datasets[1], 'test':tng_datasets[2]}

    print("Train set size:", len(tng_datasets['train']))
    print("Validation set size:", len(tng_datasets['val']))
    print("Test set size:", len(tng_datasets['test']))

    dataloaders = {x: torch.utils.data.DataLoader(tng_datasets[x], batch_size=config.batch_size,
                                                 shuffle=True, num_workers=0)
                  for x in ['train', 'val', 'test']}

    # Create the checkpoint and output folder
    run_name = wandb.run.name
    # Create the run folder by combining the runs dir, project name and the run name
    run_path = os.path.join(config.runs_dir, config.project, run_name)
    if not os.path.exists(run_path):
        os.makedirs(run_path)
    print("Run path:", run_path)

    # Also create the checkpoints and figures folder
    checkpoint_path = os.path.join(run_path, "checkpoints")
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
        
    figs_path = os.path.join(run_path, "figures")
    if not os.path.exists(figs_path):
        os.makedirs(figs_path)

   
    # Make the model
    hr_shape = (config.hr_res, config.hr_res)
    
    # Initialize generator and discriminator
    generator = GeneratorRRDB(config.channels, filters=64, num_res_blocks=config.residual_blocks).to(device)
    discriminator = Discriminator(input_shape=(config.channels, *hr_shape)).to(device)
    feature_extractor = FeatureExtractor().to(device)
        
    # Set feature extractor to inference mode
    feature_extractor.eval()

    # Losses
    criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
    criterion_content = torch.nn.L1Loss().to(device)
    criterion_pixel = torch.nn.L1Loss().to(device)

    if config.epoch != 0:
        # Load pretrained models
        generator.load_state_dict(torch.load("saved_models/generator_%d.pth" % config.epoch))
        discriminator.load_state_dict(torch.load("saved_models/discriminator_%d.pth" % config.epoch))

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=config.lr, betas=(config.b1, config.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=config.lr, betas=(config.b1, config.b2))

    Tensor = torch.Tensor #TODO: torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
        
        
#     # Make the model
#     model = SRCNN().to(device)

#     # Make the loss and optimizer
#     # optimizer
#     optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
#     # loss function 
#     criterion = nn.MSELoss()

    return generator, discriminator, feature_extractor, dataloaders, criterion_GAN, criterion_content, criterion_pixel , optimizer_G, optimizer_D, Tensor, run_path

In [22]:
def train(generator, discriminator, feature_extractor, dataloaders, criterion_GAN, criterion_content, criterion_pixel, 
          optimizer_G, optimizer_D, Tensor, run_path, config):
    # tell wandb to watch what the model gets up to: gradients, weights, and more!
    # This is inspired on: https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch/Simple_PyTorch_Integration.ipynb#scrollTo=bZiTlrNkRKzm
    wandb.watch(generator, log="all", log_freq=10)
    
    wandb.watch(discriminator, log="all", log_freq=10)
    
    
    # ----------
    #  Training
    # ----------
    dataloader = dataloaders['train']

    for epoch in range(config.epoch, config.n_epochs):
        for i, imgs in enumerate(dataloader):

            batches_done = epoch * len(dataloader) + i

            # Configure model input
            imgs_lr = Variable(imgs["lr"].type(Tensor))
            imgs_hr = Variable(imgs["hr"].type(Tensor))

            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            optimizer_G.zero_grad()

            # Generate a high resolution image from low resolution input
            gen_hr = generator(imgs_lr)

            # Measure pixel-wise loss against ground truth
            loss_pixel = criterion_pixel(gen_hr, imgs_hr)

            if batches_done < config.warmup_batches:
                # Warm-up (pixel-wise loss only)
                loss_pixel.backward()
                optimizer_G.step()
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [G pixel: %f]"
                    % (epoch, config.n_epochs, i, len(dataloader), loss_pixel.item())
                )
                continue

            # Extract validity predictions from discriminator
            pred_real = discriminator(imgs_hr).detach()
            pred_fake = discriminator(gen_hr)

            # Adversarial loss (relativistic average GAN)
            loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

            # Content loss
            gen_features = feature_extractor(gen_hr)
            real_features = feature_extractor(imgs_hr).detach()
            loss_content = criterion_content(gen_features, real_features)

            # Total generator loss
            loss_G = loss_content + config.lambda_adv * loss_GAN +configopt.lambda_pixel * loss_pixel

            loss_G.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            pred_real = discriminator(imgs_hr)
            pred_fake = discriminator(gen_hr.detach())

            # Adversarial loss for real and fake images (relativistic average GAN)
            loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)
            loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)

            # Total loss
            loss_D = (loss_real + loss_fake) / 2

            loss_D.backward()
            optimizer_D.step()

            # --------------
            #  Log Progress
            # --------------

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, content: %f, adv: %f, pixel: %f]"
                % (
                    epoch,
                    config.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_content.item(),
                    loss_GAN.item(),
                    loss_pixel.item(),
                )
            )

            # Log to wandb
            wandb.log({"epoch": epoch, "batch":batches_done, "loss_D": loss_D.item(), "loss_G": loss_G.item(), 
                       "loss_content": loss_content.item(), "loss_GAN": loss_GAN.item(), "loss_pixel": loss_pixel.item()})



            if batches_done % config.sample_interval == 0:
                # Save image grid with upsampled inputs and ESRGAN outputs
                imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
                img_grid = denormalize(torch.cat((imgs_lr, gen_hr), -1))

                images = wandb.Image(img_grid)
                wandb.log({"img_grid": images})
                          
                images = []
                for i in range(len(imgs_lr)):
                    images += [imgs_lr[i][0], gen_hr[i][0], imgs_hr[i][0]]

                # TODO: do this in the plot img function
                #                     lr = np.repeat(image_lr[i][0], int(image_hr.shape[-1]/image_lr.shape[-1]), axis=0)
                #                     lr = np.repeat(lr, int(image_hr.shape[-1]/image_lr.shape[-1]), axis=1)


                #                 titles += ['Input','Generated', "Target"]
                #                 plot_multiple_img([image_lr[i][0], outputs[i][0], image_hr[i][0]], titles = ['Input','Generated', "Target"])

                plot_multiple_img(images, titles = ['Input','Generated', "Label"], cols=3, merge=True, colorbar=False, figsize=(3*5, len(outputs)*5), save_fig=save_results, show_plot=show_results, wandb_log=True)
            

    #             save_image(img_grid, os.path.join(run_path, "figures",f"training_{batches_done}.png"), nrow=1, normalize=False)

            if batches_done % config.checkpoint_interval == 0:
                # Save model checkpoints
                torch.save(generator.state_dict(), os.path.join(run_path, "checkpoints" , f"generator_{epoch}.pt"))
                torch.save(discriminator.state_dict(), os.path.join(run_path, "checkpoints" , "discriminator_{epoch}.pt"))

In [23]:
# Build, train and analyze the model with the pipeline
model = model_pipeline(config)

  0%|          | 3/1200 [00:00<00:43, 27.28it/s]

Preloading the fits images


100%|██████████| 1200/1200 [00:33<00:00, 35.77it/s]


Train set size: 840
Validation set size: 180
Test set size: 180
Run path: /home/ssweere/remote_home/data/runs/esr_gan/sage-flower-20
[Epoch 0/5000] [Batch 0/14] [G pixel: 0.236737]
[Epoch 0/5000] [Batch 1/14] [G pixel: 0.211381]
[Epoch 0/5000] [Batch 4/14] [G pixel: 0.121034]
[Epoch 0/5000] [Batch 5/14] [G pixel: 0.118620]
[Epoch 0/5000] [Batch 6/14] [G pixel: 0.084931]
[Epoch 0/5000] [Batch 7/14] [G pixel: 0.073117]
[Epoch 0/5000] [Batch 8/14] [G pixel: 0.075345]
[Epoch 0/5000] [Batch 9/14] [G pixel: 0.067249]
[Epoch 0/5000] [Batch 10/14] [G pixel: 0.062188]
[Epoch 0/5000] [Batch 11/14] [G pixel: 0.063626]
[Epoch 0/5000] [Batch 12/14] [G pixel: 0.064401]
[Epoch 0/5000] [Batch 13/14] [G pixel: 0.057113]
[Epoch 1/5000] [Batch 0/14] [G pixel: 0.057759]
[Epoch 1/5000] [Batch 1/14] [G pixel: 0.065911]
[Epoch 1/5000] [Batch 2/14] [G pixel: 0.058643]
[Epoch 1/5000] [Batch 3/14] [G pixel: 0.055523]
[Epoch 1/5000] [Batch 4/14] [G pixel: 0.054242]
[Epoch 1/5000] [Batch 5/14] [G pixel: 0.054791]

VBox(children=(Label(value=' 0.10MB of 0.10MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
_runtime,11782
_timestamp,1620932735
_step,0


0,1
_runtime,▁
_timestamp,▁
_step,▁


RuntimeError: Given groups=1, weight of size [64, 3, 3, 3], expected input[64, 1, 128, 128] to have 3 channels, but got 1 channels instead