# WGAN (2D)
---

<font size = 4>**Wasserstein Generative Adversarial Network** (WGAN) is an alternative to traditional GAN training, published by [Arjovsky, Martin and Chintala, Soumith and Bottou, Leon](http://proceedings.mlr.press/v70/arjovsky17a/arjovsky17a.pdf). Instead of the usual GAN loss, it trains the network with a loss based on the Wasserstein distance, also known as the Earth Mover's Distance. The authors report more stable learning, and instability is one of the main problems in GAN training. In this notebook, a WGAN is used to recover a high-resolution (HR) image from a low-resolution (LR) image.
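
<font size = 4>To make the Earth Mover's Distance concrete, here is a minimal pure-Python sketch. It is not part of this notebook or of the original WGAN code, and the function name `wasserstein_1d` is hypothetical. It covers only the simplest case: two 1D samples of equal size with uniform weights, where the optimal transport plan pairs the sorted values. WGAN never computes the distance this way. Instead, a critic network (the `Discriminator` in this notebook) estimates it. This notebook's training code softly enforces the critic's 1-Lipschitz constraint with a gradient penalty (the WGAN-GP variant) rather than the weight clipping used in the original paper.

```python
def wasserstein_1d(u, v):
    """Earth Mover's (Wasserstein-1) distance between two equally sized 1D samples.

    With uniform weights and the same number of points, the optimal transport
    plan matches the i-th smallest value of u with the i-th smallest value of v,
    so W1 is the mean absolute difference of the sorted samples.
    """
    if len(u) != len(v) or len(u) == 0:
        raise ValueError("u and v must be non-empty and have the same length")
    return sum(abs(a - b) for a, b in zip(sorted(u), sorted(v))) / len(u)


real = [0.0, 1.0, 2.0, 3.0]
fake_close = [0.5, 1.5, 2.5, 3.5]    # same shape, shifted by 0.5
fake_far = [10.0, 11.0, 12.0, 13.0]  # same shape, shifted by 10

print(wasserstein_1d(real, fake_close))  # 0.5
print(wasserstein_1d(real, fake_far))    # 10.0
```

<font size = 4>The distance grows with how far the fake sample is shifted, even when the two samples do not overlap at all. The WGAN authors argue that this is the kind of informative signal a generator needs. By contrast, the Jensen-Shannon divergence behind the original GAN loss saturates when the distributions have disjoint supports.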

---

<font size = 4>*Disclaimer*:

<font size = 4>This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It was jointly developed by [Iván Hidalgo Cenalmor](https://github.com/IvanHCenalmor), [Pablo Alonso Pérez](https://www.linkedin.com/in/palonso998/?originalSubdomain=es), [Estibaliz Gómez de Mariscal](https://github.com/esgomezm) and [Ignacio Arganda-Carreras](https://sites.google.com/site/iargandacarreras/).

<font size = 4>This notebook is largely based on the paper:

<font size = 4>**Wasserstein Generative Adversarial Networks** from M. Arjovsky, S. Chintala, and L. Bottou: https://proceedings.mlr.press/v70/arjovsky17a.html

<font size = 4>**The original code** is freely available on GitHub:
https://github.com/martinarjovsky/WassersteinGAN

<font size = 4>**Please also cite this original paper when using or developing this notebook.**

# **How to use this notebook?**

---

<font size = 4>Videos describing how to use ZeroCostDL4Mic notebooks are available on YouTube:
  - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook
  - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook


---
### **Structure of a notebook**

<font size = 4>The notebook contains two types of cells:

<font size = 4>**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.

<font size = 4>**Code cells** contain code, which can be modified by selecting the cell. To execute the cell, move your cursor over the `[ ]`-mark on the left side of the cell (a play button appears). Click it to execute the cell. When execution is done, the play button stops animating. You can create a new code cell by clicking `+ Code`.

---
### **Table of contents, Code snippets** and **Files**

<font size = 4>On the top left side of the notebook you will find three tabs which contain, from top to bottom:

<font size = 4>*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.

<font size = 4>*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.

<font size = 4>*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. 

<font size = 4>**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.

<font size = 4>**Note:** The "sample data" in "Files" contains default files. Do not upload anything in here!

---
### **Making changes to the notebook**

<font size = 4>**You can make a copy** of the notebook and save it to your Google Drive. To do this, click File -> Save a copy in Drive.

<font size = 4>To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).
You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment.

# **0. Before getting started**
---

<font size = 4>**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.

<font size = 4> **Additionally, the corresponding input and output files need to have the same name**.

<font size = 4> Please note that you currently can **only use .tif files!**


<font size = 4>Here's a common data structure that can work:
*   Experiment A
    - **Training dataset**
      - Low resolution (LR) images (Training_source)
        - img_1.tif, img_2.tif, ...
      - High resolution (HR) images (Training_target)
        - img_1.tif, img_2.tif, ...
    - **Quality control dataset**
      - Low resolution (LR) images
        - img_1.tif, img_2.tif
      - High resolution (HR) images
        - img_1.tif, img_2.tif
    - **Data to be predicted**
    - **Results**
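
<font size = 4>Because source and target images are matched by filename, it can help to check the folders before training. The sketch below uses a hypothetical helper, `check_paired_tif_folders`, which is not part of the notebook. It lists which `.tif` names appear in both folders and which are missing from one side. It only compares filenames, not image sizes or contents.

```python
import os
import tempfile


def check_paired_tif_folders(source_dir, target_dir):
    """Compare the .tif filenames in a source (LR) and a target (HR) folder.

    Returns (paired, only_in_source, only_in_target), each a sorted list of
    filenames. Non-.tif files and sub-folders are ignored.
    """
    def tif_names(folder):
        return {
            name
            for name in os.listdir(folder)
            if name.lower().endswith(".tif")
            and os.path.isfile(os.path.join(folder, name))
        }

    source = tif_names(source_dir)
    target = tif_names(target_dir)
    return sorted(source & target), sorted(source - target), sorted(target - source)


# Small self-contained demo on temporary folders (the files are empty placeholders).
with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as tgt:
    for name in ("img_1.tif", "img_2.tif", "notes.txt"):
        open(os.path.join(src, name), "wb").close()
    for name in ("img_1.tif", "img_3.tif"):
        open(os.path.join(tgt, name), "wb").close()
    print(check_paired_tif_folders(src, tgt))
    # (['img_1.tif'], ['img_2.tif'], ['img_3.tif'])
```

<font size = 4>On your own data you would pass your `Training_source` and `Training_target` folder paths instead. If the second and third lists are empty, every image has a partner with the same name.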

---
<font size = 4>**Important note**

<font size = 4>- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.

<font size = 4>- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.

<font size = 4>- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.

---

# **1. Install WGAN and dependencies**
---

#@markdown ##Install Network and dependencies
Notebook_version = '1.15.1'
Network = 'WGAN 2D'

!pip install pytorch-lightning==1.6.4 --quiet

from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import torchvision
import random
import shutil
import csv
import os

#Create a variable to get and store relative base path
base_path = os.getcwd()

from collections import defaultdict, OrderedDict
from tqdm import tqdm
from datetime import datetime

import skimage
from skimage import transform
from skimage import filters
from skimage import metrics
from skimage import io
from skimage.util import random_noise
from skimage.util import img_as_ubyte
from skimage.util import img_as_uint

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.loggers import CSVLogger

import tensorflow as tf
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from tensorflow.image import ssim_multiscale as mssim

from pip._internal.operations.freeze import freeze
import subprocess

from ipywidgets import interact, interactive, fixed, interact_manual
from scipy.ndimage import zoom as npzoom
from torchvision.utils import save_image
from builtins import any as b_any

###
def get_requirements_path():
    # Store requirements file in 'base_path' directory 
    current_dir = os.getcwd()
    dir_count = current_dir.count('/') - 1
    path = '../' * (dir_count) + 'requirements.txt'
    return path
def filter_files(file_list, filter_list):
    filtered_list = []
    for fname in file_list:
        if b_any(fname.split('==')[0] in s for s in filter_list):
            filtered_list.append(fname)
    return filtered_list

def build_requirements_file(before, after):
    path = get_requirements_path()

    # Exporting requirements.txt for local run
    !pip freeze > $path

    # Get minimum requirements file (header=None keeps the first package line)
    df = pd.read_csv(path, header=None)
    mod_list = [m.split('.')[0] for m in after if m not in before]
    req_list = [x[0] for x in df.values.tolist()]

    # Replace with package name and handle cases where import name is different to module name
    mod_name_map = {'sklearn': 'scikit-learn', 'skimage': 'scikit-image'}
    mod_replace_list = [mod_name_map.get(s, s) for s in mod_list]
    filtered_list = filter_files(req_list, mod_replace_list)

    # Overwrite the file with the filtered requirements only
    with open(path, 'w') as file:
        for item in filtered_list:
            file.write(item + '\n')

# Initialize requirements storage
import sys
before = [str(m) for m in sys.modules]

def set_seed(seed=56):
    ''' Sets the seed for the random operations in all
        the packages that use it.
    Args:
        seed (int): seed that will be established.
    '''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed()

###

def em_crappify(img, scale):
    ''' Degrades and downscales the given image, simulating
        the degradation that an electron microscopy image would suffer.
        Source https://www.biorxiv.org/content/10.1101/740548v3.
    Args:
        img (array): image whose low-resolution counterpart will be calculated.
        scale (int): scale factor for the image downscaling. Example: 2.
    Returns:
        Low-resolution image obtained by degrading the received high-resolution image.
    '''
    img = filters.gaussian(img, sigma=3) + 1e-6
    return npzoom(img, 1/scale, order=1)

def fluo_crappify(img,scale):
    ''' Degrades and downscales the given image, simulating
        the degradation that a fluorescence microscopy image would suffer.
        Source https://www.biorxiv.org/content/10.1101/740548v3.
    Args:
        img (array): image whose low-resolution counterpart will be calculated.
        scale (int): scale factor for the image downscaling. Example: 2.
    Returns:
        Low-resolution image obtained by degrading the received high-resolution image.
    '''
    img = random_noise(img, mode='salt', amount=0.005)
    img = random_noise(img, mode='pepper', amount=0.005)
    img = filters.gaussian(img, sigma=5) + 1e-10
    return npzoom(img, 1/scale, order=1)
  
###

def calculate_down_factor(lr_imgs_basedir, hr_imgs_basedir):
    _, hr_extension = os.path.splitext(os.listdir(hr_imgs_basedir)[0])
    hr_filenames = [hr_imgs_basedir + '/' + x for x in os.listdir(hr_imgs_basedir) if x.endswith(hr_extension)]

    _, lr_extension = os.path.splitext(os.listdir(lr_imgs_basedir)[0])
    lr_filenames = [lr_imgs_basedir + '/' + x for x in os.listdir(lr_imgs_basedir) if x.endswith(lr_extension)]

    hr_img = io.imread(hr_filenames[0])
    lr_img = io.imread(lr_filenames[0])

    down_scale_x = hr_img.shape[0] / lr_img.shape[0]
    down_scale_y = hr_img.shape[1] / lr_img.shape[1]

    return [down_scale_x, down_scale_y]

###

class bcolors:
    '''Definition of a color for the displays.'''
    WARNING = '\033[31m'

class ToTensor(object):
    '''Convert ndarrays in sample to Tensors.'''
    def __call__(self, sample):
        hr, lr = sample['hr'], sample['lr']

        # NumPy patches are (height, width, channels); PyTorch expects (channels, height, width) per sample
        hr = hr.transpose((2, 0, 1))
        lr = lr.transpose((2, 0, 1))
        return {'hr': torch.from_numpy(hr),
                'lr': torch.from_numpy(lr)}

class RandomHorizontalFlip(object):
    '''Random horizontal flip.'''
    def __init__(self):
        self.rng = np.random.default_rng()

    def __call__(self, sample):
        hr, lr = sample['hr'], sample['lr']

        if self.rng.random() < 0.5:
            hr = np.flip(hr, 1)
            lr = np.flip(lr, 1)

        return {'hr': hr.copy(),
                'lr': lr.copy()}

class RandomVerticalFlip(object):
    '''Random vertical flip.'''
    def __init__(self):
        self.rng = np.random.default_rng()

    def __call__(self, sample):
        hr, lr = sample['hr'], sample['lr']

        if self.rng.random() < 0.5:
            hr = np.flip(hr, 0)
            lr = np.flip(lr, 0)

        return {'hr': hr.copy(),
                'lr': lr.copy()}

class RandomRotate(object):
    '''Random rotation.'''
    def __init__(self):
        self.rng = np.random.default_rng()

    def __call__(self, sample):
        hr, lr = sample['hr'], sample['lr']

        k = self.rng.integers(4)

        hr = np.rot90(hr, k=k)
        lr = np.rot90(lr, k=k)

        return {'hr': hr.copy(),
                'lr': lr.copy()}

class EMDataset(Dataset):
    ''' PyTorch Dataset object used to obtain the training and
        validation data during the training process. It stores the
        filenames as an attribute and only loads the images required for
        each training batch, reducing the RAM needed during
        and after training.
    '''
    def __init__(self, 
                 patch_size_x, 
                 patch_size_y,
                 down_factor,
                 transform=None, 
                 validation=False, 
                 validation_split=None,
                 hr_imgs_basedir="", 
                 lr_imgs_basedir="",
                 only_high_resolution_data=False,
                 only_hr_imgs_basedir="",
                 type_of_data="Electron microscopy"):

        if only_high_resolution_data:
            used_hr_imgs_basedir = only_hr_imgs_basedir 
        else: 
            used_hr_imgs_basedir = hr_imgs_basedir

        _, hr_extension = os.path.splitext(os.listdir(used_hr_imgs_basedir)[0])

        hr_filenames = [used_hr_imgs_basedir + '/' + x for x in os.listdir(used_hr_imgs_basedir) if x.endswith(hr_extension)]
        hr_filenames.sort()

        if validation_split is not None:
            val_files = int(len(hr_filenames) * validation_split)
            if validation:
                self.hr_img_names = hr_filenames[:val_files]
            else:
                self.hr_img_names = hr_filenames[val_files:]
        else:
            self.hr_img_names = hr_filenames

        if not only_high_resolution_data:
            _, lr_extension = os.path.splitext(os.listdir(lr_imgs_basedir)[0])

            lr_filenames = [lr_imgs_basedir + '/' + x for x in os.listdir(lr_imgs_basedir) if x.endswith(lr_extension)]
            lr_filenames.sort()

            if validation_split is not None:
                val_lr_files = int(len(lr_filenames) * validation_split)
                if validation:
                    self.lr_img_names = lr_filenames[:val_lr_files]
                else:
                    self.lr_img_names = lr_filenames[val_lr_files:]
            else:
                self.lr_img_names = lr_filenames

        self.only_high_resolution_data = only_high_resolution_data

        self.transform = transform
        self.down_factor = down_factor
        self.type_of_data = type_of_data

        self.lr_patch_size_x = patch_size_x
        self.lr_patch_size_y = patch_size_y
        self.hr_patch_size_x = patch_size_x * down_factor
        self.hr_patch_size_y = patch_size_y * down_factor

        self.lr_shape = io.imread(self.hr_img_names[0]).shape[0]//down_factor

    def hr_to_lr(self, hr_img, down_factor, type_of_data):

        if type_of_data == "Electron microscopy":
            lr_img = em_crappify(hr_img, down_factor)
        else:
            lr_img = fluo_crappify(hr_img, down_factor)
        return lr_img

    def __len__(self):
        return len(self.hr_img_names) * (self.lr_shape//self.lr_patch_size_x)**2

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_idx = idx // (self.lr_shape//self.lr_patch_size_x)**2
        
        hr_img = img_as_ubyte(io.imread(self.hr_img_names[img_idx]))
        
        if self.only_high_resolution_data:
            lr_img = self.hr_to_lr(hr_img, self.down_factor, self.type_of_data)
        else:
            lr_img = img_as_ubyte(io.imread(self.lr_img_names[img_idx])) / 255.0

        hr_img = hr_img / 255.0 

        hr_img = hr_img.astype(np.float32)
        lr_img = lr_img.astype(np.float32)

        # randint's upper bound is exclusive, so +1 also allows the last valid patch position
        lr_idx_x = np.random.randint(0, lr_img.shape[0] - self.lr_patch_size_x + 1)
        lr_idx_y = np.random.randint(0, lr_img.shape[1] - self.lr_patch_size_y + 1)

        lr_patch = lr_img[lr_idx_x : lr_idx_x + self.lr_patch_size_x, 
                          lr_idx_y : lr_idx_y + self.lr_patch_size_y]
        lr_patch = lr_patch[:,:,np.newaxis]

        hr_idx_x = lr_idx_x * self.down_factor
        hr_idx_y = lr_idx_y * self.down_factor

        hr_patch = hr_img[hr_idx_x : hr_idx_x + self.hr_patch_size_x, 
                          hr_idx_y : hr_idx_y + self.hr_patch_size_y]

        hr_patch = hr_patch[:,:,np.newaxis]

        sample = {'hr': hr_patch, 'lr': lr_patch}

        if self.transform:
            sample = self.transform(sample)

        return sample

###

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, k=3, p=1):
        super(ResidualBlock, self).__init__()
        self.net = nn.Sequential(
          nn.Conv2d(in_channels, out_channels, kernel_size=k, padding=p),
          nn.PReLU(),
          nn.Conv2d(out_channels, out_channels, kernel_size=k, padding=p),
        )

    def forward(self, x):
        return x + self.net(x)

class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, scaleFactor, k=3, p=1):
        super(UpsampleBlock, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * (scaleFactor ** 2), kernel_size=k, padding=p),
            nn.PixelShuffle(scaleFactor),
            nn.PReLU()
        )

    def forward(self, x):
        return self.net(x)
        
class GeneratorUpsample(nn.Module):
    def __init__(self, n_residual=8, down_factor=4):
        super(GeneratorUpsample, self).__init__()
        self.n_residual = n_residual
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=5, padding=2),
            nn.PReLU()
        )
        
        for i in range(n_residual):
            self.add_module('residual' + str(i+1), ResidualBlock(64, 64))
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.PReLU()
        )
        
        upsamples = [UpsampleBlock(64, 2) for x in range(int(np.log2(down_factor)))]
        
        self.upsample = nn.Sequential(
            *upsamples,
            nn.Conv2d(64, 1, kernel_size=5, padding=2)
        )

    def forward(self, lr):
        y = self.conv1(lr)
        cache = y.clone()
        
        for i in range(self.n_residual):
            y = self.__getattr__('residual' + str(i+1))(y)
            
        y = self.conv2(y)
        y = self.upsample(y + cache)
        return torch.tanh(y)

class GeneratorModule(LightningModule):
    def __init__(self, n_residual=8, down_factor=4, lr=0.001):
        super(GeneratorModule, self).__init__()
        
        self.save_hyperparameters()
        
        self.generator = GeneratorUpsample(n_residual=n_residual, down_factor=down_factor)
        
        self.l1loss = nn.L1Loss()
        
    def forward(self, x):
        y = self.generator(x)
        return y
    
    def training_step(self, batch, batch_idx):
        lr, hr = batch['lr'], batch['hr']
        
        fake = self(lr)
        
        loss = self.l1loss(fake, hr)
        
        return loss

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        
        sched = {
            'scheduler': torch.optim.lr_scheduler.OneCycleLR(
                opt, 
                0.0001, 
                epochs=5, 
                steps_per_epoch=712
            ),
            'interval': 'step'
        }
        
        return [opt], [sched]

###

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(

            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
                
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=3, stride=1, padding=1),
            #nn.AdaptiveAvgPool2d(1)
            )
    
    def forward(self, img):
        score = self.model(img)
        return torch.mean(score, dim=(-1,-2,-3))

###

from skimage.metrics import structural_similarity, peak_signal_noise_ratio

class WGANGP(LightningModule):
    def __init__(self,
               g_layers: int = 5,
               d_layers: int = 5,
               recloss: float = 10.0,
               lambda_gp: float = 10.0,
               batchsize: int = 8,
               patch_size: int = 256,
               down_factor: int = 2,
               learning_rate_d: float = 0.0001,
               learning_rate_g: float = 0.0001,
               n_critic_steps: int = 5,
               validation_split: float = 0.1,
               epochs: int = 151,
               rotation: bool = True,
               horizontal_flip: bool = True,
               vertical_flip: bool = True,
               hr_imgs_basedir: str = "", 
               lr_imgs_basedir: str = "",
               only_high_resolution_data: bool = False,
               only_hr_images_basedir: str = "",
               type_of_data: str = "Electron microscopy",
               gen_checkpoint: str = None, 
               save_basedir: str = None,
               only_predict: bool = False
               ):
        super(WGANGP, self).__init__()
        
        self.save_hyperparameters()

        if gen_checkpoint is not None:
            checkpoint = torch.load(gen_checkpoint)
            self.generator = GeneratorModule(n_residual=checkpoint['n_residuals'], down_factor=checkpoint['down_factor'])
            self.generator.load_state_dict(checkpoint['model_state_dict'])
            self.best_valid_loss = checkpoint['best_valid_loss']
        else:
            self.generator = GeneratorModule(n_residual=g_layers, down_factor=down_factor)
            self.best_valid_loss = float('inf')

        self.discriminator = Discriminator()

        self.mae = nn.L1Loss()

        self.opt_g = None
        self.opt_d = None

        if not only_predict:
            self.len_data = len(self.train_dataloader())

    def save_model(self, filename):
        if self.hparams.save_basedir is not None:
            torch.save({
                        'model_state_dict': self.generator.state_dict(),
                        'optimizer_state_dict': self.opt_g.state_dict(),
                        'n_residuals': self.hparams.g_layers,
                        'down_factor': self.hparams.down_factor,
                        'best_valid_loss': self.best_valid_loss
                        }, self.hparams.save_basedir + '/' + filename)
        else:
            raise Exception('No save_basedir was specified in the construction of the WGAN object.')

    def forward(self, x):
        return self.generator(x)
  
    def compute_gradient_penalty(self, real_samples, fake_samples):
        ''' Calculates the gradient penalty loss for WGAN GP.
            Source: https://github.com/nocotan/pytorch-lightning-gans
        '''
        # Random weight term for interpolation between real and fake samples
        alpha = torch.Tensor(np.random.random((real_samples.size(0), 1, 1, 1))).to(self.device)
        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        interpolates = interpolates.to(self.device)
        d_interpolates = self.discriminator(interpolates)
        fake = torch.Tensor(d_interpolates.shape).fill_(1.0).to(self.device)
        # Get gradient w.r.t. interpolates
        gradients = torch.autograd.grad(
          outputs=d_interpolates,
          inputs=interpolates,
          grad_outputs=fake,
          create_graph=True,
          retain_graph=True,
          only_inputs=True,
        )
        gradients = gradients[0]
        
        gradients = gradients.view(gradients.size(0), -1).to(self.device)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty
  
    def training_step(self, batch, batch_idx, optimizer_idx):
        lr, hr = batch['lr'], batch['hr']

        # Optimize generator
        if optimizer_idx == 0:
            generated = self(lr)

            adv_loss = -1*self.discriminator(generated).mean()
            error = self.mae(generated, hr)

            g_loss = adv_loss + error * self.hparams.recloss

            self.log('g_loss', g_loss, prog_bar=True, on_epoch=True)
            self.log('g_l1', error, prog_bar=True, on_epoch=True)

            return g_loss

        # Optimize discriminator
        elif optimizer_idx == 1:
            generated = self(lr)

            real_logits = self.discriminator(hr).mean()
            fake_logits = self.discriminator(generated).mean()

            gradient_penalty = self.compute_gradient_penalty(hr.data, generated.data)
            
            wasserstein = real_logits - fake_logits
            
            d_loss = -wasserstein + self.hparams.lambda_gp * gradient_penalty

            self.log('d_loss', d_loss, prog_bar=True, on_epoch=True)
            self.log('d_wasserstein', wasserstein, prog_bar=False, on_epoch=True)
            self.log('d_gp', gradient_penalty, prog_bar=False, on_epoch=True)

            return d_loss

    def configure_optimizers(self):
        n_critic = self.hparams.n_critic_steps

        if self.hparams.gen_checkpoint is not None:
            self.opt_g = torch.optim.Adam(self.generator.parameters())
    
            checkpoint = torch.load(self.hparams.gen_checkpoint)
            self.opt_g.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            self.opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.hparams.learning_rate_g, betas=(0.5,0.9))
        self.opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.hparams.learning_rate_d, betas=(0.5,0.9))
        
        sched_g = {
            'scheduler': torch.optim.lr_scheduler.OneCycleLR(
                self.opt_g, 
                self.hparams.learning_rate_g, 
                epochs=self.hparams.epochs, 
                steps_per_epoch=self.len_data
            ),
            'interval': 'step',
            'name': 'g_lr'
        }
        sched_d = {
            'scheduler': torch.optim.lr_scheduler.OneCycleLR(
                self.opt_d,
                self.hparams.learning_rate_d,
                epochs=self.hparams.epochs,
                steps_per_epoch=self.len_data
            ),
            'interval': 'step',
            'name': 'd_lr'
        }

        return (
            {'optimizer': self.opt_g, 'frequency': 1, 'lr_scheduler': sched_g},
            {'optimizer': self.opt_d, 'frequency': n_critic, 'lr_scheduler': sched_d}
        )

    def validation_step(self, batch, batch_idx):
        # Computes per-image SSIM and PSNR on the validation batch
        lr, hr = batch['lr'], batch['hr']

        generated = self(lr)

        true = hr.cpu().numpy()
        fake = generated.cpu().numpy()
        
        for i in range(lr.size(0)):
            ssim = structural_similarity(true[i,0,...], fake[i,0,...], data_range=1.0)
            self.log('val_ssim', ssim)
            
            psnr = peak_signal_noise_ratio(true[i,0,...], fake[i,0,...], data_range=1.0)
            self.log('val_psnr', psnr)
            
        return lr, hr, generated
       
    def validation_step_end(self, val_step_outputs):
        # Logs validation losses and saves generator checkpoints (best_valid_loss is compared per validation batch)
        lr, hr, generated = val_step_outputs

        adv_loss = -1*self.discriminator(generated).mean()
        error = self.mae(generated, hr)

        g_loss = adv_loss + error * self.hparams.recloss

        self.log('val_g_loss', g_loss)
        self.log('val_g_l1', error)

        real_logits = self.discriminator(hr).mean()
        fake_logits = self.discriminator(generated).mean()

        wasserstein = real_logits - fake_logits
        
        self.log('val_d_wasserstein', wasserstein)

        if g_loss < self.best_valid_loss:
            self.best_valid_loss = g_loss
            self.save_model('best_checkpoint.pth')
        self.save_model('last_checkpoint.pth')

    def on_train_end(self):
        self.save_model('last_checkpoint.pth')

    def train_dataloader(self):
      
        transformations = []
        
        if self.hparams.horizontal_flip: 
            transformations.append(RandomHorizontalFlip())
        if self.hparams.vertical_flip:
            transformations.append(RandomVerticalFlip())
        if self.hparams.rotation:
            transformations.append(RandomRotate())

        transformations.append(ToTensor())

        transform = torchvision.transforms.Compose(transformations)

        dataset = EMDataset(self.hparams.patch_size, self.hparams.patch_size,
                            self.hparams.down_factor, transform=transform, validation=False, 
                            validation_split=self.hparams.validation_split, 
                            hr_imgs_basedir=self.hparams.hr_imgs_basedir, 
                            lr_imgs_basedir=self.hparams.lr_imgs_basedir,
                            only_high_resolution_data=self.hparams.only_high_resolution_data, 
                            only_hr_imgs_basedir=self.hparams.only_hr_images_basedir,
                            type_of_data=self.hparams.type_of_data)

        return DataLoader(dataset, batch_size=self.hparams.batchsize, 
                          shuffle=True, num_workers=4)
        
    def val_dataloader(self):
        transform = ToTensor()

        dataset = EMDataset(self.hparams.patch_size, self.hparams.patch_size,
                            self.hparams.down_factor, transform=transform, validation=True, 
                            validation_split=self.hparams.validation_split, 
                            hr_imgs_basedir=self.hparams.hr_imgs_basedir,
                            lr_imgs_basedir=self.hparams.lr_imgs_basedir,
                            only_high_resolution_data=self.hparams.only_high_resolution_data, 
                            only_hr_imgs_basedir=self.hparams.only_hr_images_basedir,
                            type_of_data=self.hparams.type_of_data)

        return DataLoader(dataset, batch_size=self.hparams.batchsize, 
                          shuffle=False, num_workers=4)
###

def extract(logger_path, desired_tags):
    summary_iterators = EventAccumulator(logger_path).Reload()
    tags = summary_iterators.Tags()['scalars']

    out = defaultdict(list)
    steps = []

    for tag in tags:
        if tag in desired_tags:
            steps = [e.step for e in summary_iterators.Scalars(tag)]
            out[tag].append([e.value for e in summary_iterators.Scalars(tag)])

    return out, steps

def to_csv(logger_path, csv_path, desired_tags):
    d, steps = extract(logger_path, desired_tags)
    tags, values = zip(*d.items())
    np_values = np.transpose(np.squeeze(np.array(values)))

    df = pd.DataFrame(np_values, index=steps, columns=tags)
    df.to_csv(csv_path)


#--------------------- Display QC Maps --------------------------------

def visualise_image_comparison_QC(image, dimension, Source_folder, Prediction_folder, Ground_truth_folder, QC_folder, QC_scores):
    
    _, image_extension = os.path.splitext(image)

    if image_extension == '.tif':
        img_SSIM_GTvsSource = io.imread(os.path.join(QC_folder, 'SSIM_GTvsSource_'+image))
        img_SSIM_GTvsPrediction = io.imread(os.path.join(QC_folder, 'SSIM_GTvsPrediction_'+image))
        img_RSE_GTvsSource = io.imread(os.path.join(QC_folder, 'RSE_GTvsSource_'+image))
        img_RSE_GTvsPrediction = io.imread(os.path.join(QC_folder, 'RSE_GTvsPrediction_'+image))
    else:
        # img_as_uint maps to the uint16 range [0, 65535], so divide by 65535 to rescale to [0, 1]
        img_SSIM_GTvsSource = img_as_uint(io.imread(os.path.join(QC_folder, 'SSIM_GTvsSource_'+image)))/65535
        img_SSIM_GTvsPrediction = img_as_uint(io.imread(os.path.join(QC_folder, 'SSIM_GTvsPrediction_'+image)))/65535
        img_RSE_GTvsSource = img_as_uint(io.imread(os.path.join(QC_folder, 'RSE_GTvsSource_'+image)))/65535
        img_RSE_GTvsPrediction = img_as_uint(io.imread(os.path.join(QC_folder, 'RSE_GTvsPrediction_'+image)))/65535

    # Filter on the QC_scores argument itself rather than on a global DataFrame
    SSIM_GTvsP_forDisplay = QC_scores.loc[QC_scores['image #'] == image, 'Prediction v. GT mSSIM'].tolist()
    SSIM_GTvsS_forDisplay = QC_scores.loc[QC_scores['image #'] == image, 'Input v. GT mSSIM'].tolist()
    NRMSE_GTvsP_forDisplay = QC_scores.loc[QC_scores['image #'] == image, 'Prediction v. GT NRMSE'].tolist()
    NRMSE_GTvsS_forDisplay = QC_scores.loc[QC_scores['image #'] == image, 'Input v. GT NRMSE'].tolist()
    PSNR_GTvsP_forDisplay = QC_scores.loc[QC_scores['image #'] == image, 'Prediction v. GT PSNR'].tolist()
    PSNR_GTvsS_forDisplay = QC_scores.loc[QC_scores['image #'] == image, 'Input v. GT PSNR'].tolist()


    plt.figure(figsize=(15,15))

    # Source
    plt.subplot(3,3,1)
    plt.axis('off')
    img_Source = io.imread(os.path.join(Source_folder, image))
    plt.imshow(img_Source,'gray')
    plt.title('Source',fontsize=15)

    # Target (Ground-truth)
    plt.subplot(3,3,2)
    plt.axis('off')
    img_GT = io.imread(os.path.join(Ground_truth_folder, image))
    plt.imshow(img_GT ,'gray')
    plt.title('Target',fontsize=15)


    #Prediction
    plt.subplot(3,3,3)
    plt.axis('off')
    #img_Prediction = predictions
    img_Prediction = io.imread(os.path.join(Prediction_folder, image))
    plt.imshow(img_Prediction,'gray')
    plt.title('Prediction',fontsize=15)

    #Setting up colours
    cmap = plt.cm.CMRmap

    #SSIM between GT and Source
    plt.subplot(3,3,5)
    #plt.axis('off')
    plt.tick_params(
            axis='both',            # changes apply to the x-axis and y-axis
            which='both',        # both major and minor ticks are affected
            bottom=False,        # ticks along the bottom edge are off
            top=False,                # ticks along the top edge are off
            left=False,         # ticks along the left edge are off
            right=False,                 # ticks along the right edge are off
            labelbottom=False,
            labelleft=False)     
    imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)
    plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)
    plt.title('Target vs. Source',fontsize=15)
    plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay[0],3)),fontsize=14)
    plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)


    #SSIM between GT and Prediction
    plt.subplot(3,3,6)
    #plt.axis('off')
    plt.tick_params(
            axis='both',            # changes apply to the x-axis and y-axis
            which='both',        # both major and minor ticks are affected
            bottom=False,        # ticks along the bottom edge are off
            top=False,                # ticks along the top edge are off
            left=False,         # ticks along the left edge are off
            right=False,                 # ticks along the right edge are off
            labelbottom=False,
            labelleft=False)    
    imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)
    plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)
    plt.title('Target vs. Prediction',fontsize=15)
    plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay[0],3)),fontsize=14)


    #Root Squared Error between GT and Source
    plt.subplot(3,3,8)
    #plt.axis('off')
    plt.tick_params(
            axis='both',            # changes apply to the x-axis and y-axis
            which='both',        # both major and minor ticks are affected
            bottom=False,        # ticks along the bottom edge are off
            top=False,                # ticks along the top edge are off
            left=False,         # ticks along the left edge are off
            right=False,                 # ticks along the right edge are off
            labelbottom=False,
            labelleft=False) 
    imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)
    plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)
    plt.title('Target vs. Source',fontsize=15)
    plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay[0],3))+', PSNR: '+str(round(PSNR_GTvsS_forDisplay[0],3)),fontsize=14)   
    plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)

    #Root Squared Error between GT and Prediction
    plt.subplot(3,3,9)
    #plt.axis('off')
    plt.tick_params(
            axis='both',            # changes apply to the x-axis and y-axis
            which='both',        # both major and minor ticks are affected
            bottom=False,        # ticks along the bottom edge are off
            top=False,                # ticks along the top edge are off
            left=False,         # ticks along the left edge are off
            right=False,                 # ticks along the right edge are off
            labelbottom=False,
            labelleft=False) 
    imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)
    plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)
    plt.title('Target vs. Prediction',fontsize=15)
    plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay[0],3))+', PSNR: '+str(round(PSNR_GTvsP_forDisplay[0],3)),fontsize=14)
    plt.savefig(full_QC_model_path+'/QC_example_data.png', bbox_inches='tight',pad_inches=0)


All_notebook_versions = pd.read_csv("https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv", dtype=str)
print('Notebook version: '+Notebook_version)
Latest_Notebook_version = All_notebook_versions[All_notebook_versions["Notebook"] == Network]['Version'].iloc[0]
print('Latest notebook version: '+Latest_Notebook_version)
if Notebook_version == Latest_Notebook_version:
  print("This notebook is up-to-date.")
else:
  print(bcolors.WARNING +"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki")

!pip install fpdf2
from fpdf import FPDF, HTMLMixin

def pdf_export(trained = False, pretrained_model = False):
  class MyFPDF(FPDF, HTMLMixin):
    pass

  pdf = MyFPDF()
  pdf.add_page()
  pdf.set_right_margin(-1)
  pdf.set_font("Arial", size = 11, style='B') 

  #model_name = 'little_CARE_test'
  day = datetime.now()
  datetime_str = str(day)[0:10]

  Header = 'Training report for {} model ({})\nDate: {}'.format(Network, Model_name, datetime_str)
  pdf.multi_cell(180, 5, txt = Header, align = 'L') 
  pdf.ln(1)
    
  # add another cell 
  if trained:
    training_time = "Training time:  {} hour(s) {} min(s) {} sec(s)".format(
                      hours, mins, round(secs))
    pdf.cell(190, 5, txt = training_time, ln = 1, align='L')
  pdf.ln(1)

  Header_2 = 'Information for your materials and method:'
  pdf.cell(190, 5, txt=Header_2, ln=1, align='L')

  all_packages = ''
  for requirement in freeze(local_only=True):
    all_packages = all_packages + requirement + ', '
  #print(all_packages)

  #Main Packages
  main_packages = ''
  version_numbers = []
  for name in ['tensorflow','numpy','Keras']:
    find_name=all_packages.find(name)
    main_packages = main_packages + all_packages[find_name:all_packages.find(',',find_name)]+', '
    #Version numbers only here:
    version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])

  try:
    cuda_version = subprocess.run(["nvcc","--version"],stdout=subprocess.PIPE)
    cuda_version = cuda_version.stdout.decode('utf-8')
    cuda_version = cuda_version[cuda_version.find(', V')+3:-1]
  except:
    cuda_version = ' - No cuda found - '
  try:
    gpu_name = subprocess.run(["nvidia-smi"],stdout=subprocess.PIPE)
    gpu_name = gpu_name.stdout.decode('utf-8')
    gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]
  except:
    gpu_name = ' - No GPU found - '
  #print(cuda_version[cuda_version.find(', V')+3:-1])
  #print(gpu_name)
  #dataset_size = len(os.listdir(Training_source))

  #text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+' GPU.'
  text = ('The '+Network+' model was trained from scratch for '+str(Number_of_epochs)+''
          ' epochs with a scaling factor of '+str(Down_factor)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version+') (von Chamier & Laine et al., 2021). '
          'Training used '+str(n_patches)+' paired patches in batches '
          'of '+str(Batch_size)+', with patch dimensions of ('+str(input_patch_shape[0])+', '+str(input_patch_shape[1])+')'
          ' for the input (low resolution) and ('+str(output_patch_shape[0])+', '+str(output_patch_shape[1])+') for the output (high resolution). The original paired images '
          'had dimensions of ('+str(original_input_shape[0])+', '+str(original_input_shape[1])+') for the input (low resolution) and '
          '('+str(original_output_shape[0])+', '+str(original_output_shape[1])+') for the output (high resolution).\n\n'
          'Following the original paper (I. Gulrajani, F. Ahmed, M. Arjovsky, et al., 2017), '
          'the discriminator was trained with the Wasserstein distance with gradient '
          'penalty, while the generator was trained with a '
          'combination of the L1 loss (mean absolute error) and the discriminator output (the adversarial loss).\n\n'
          'Key python packages used include: TensorFlow '
          '(v '+version_numbers[0]+'), NumPy (v '+version_numbers[1]+'), Keras '
          '(v '+version_numbers[2]+'), CUDA (v '+cuda_version.replace('\n',' ')+').\n'
          'The training was accelerated using a '+gpu_name+' GPU.')
  if pretrained_model:
    text = ('The '+Network+' model was initialised from a pretrained model (taken from '+Pretrained_model_path+') and then trained for a further '+str(Number_of_epochs)+''
            ' epochs with a scaling factor of '+str(Down_factor)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version+') (von Chamier & Laine et al., 2021). '
            'Training used '+str(n_patches)+' paired patches in batches '
            'of '+str(Batch_size)+', with patch dimensions of ('+str(input_patch_shape[0])+', '+str(input_patch_shape[1])+')'
            ' for the input (low resolution) and ('+str(output_patch_shape[0])+', '+str(output_patch_shape[1])+') for the output (high resolution). The original paired images '
            'had dimensions of ('+str(original_input_shape[0])+', '+str(original_input_shape[1])+') for the input (low resolution) and '
            '('+str(original_output_shape[0])+', '+str(original_output_shape[1])+') for the output (high resolution).\n\n'
            'Following the original paper (I. Gulrajani, F. Ahmed, M. Arjovsky, et al., 2017), '
            'the discriminator was trained with the Wasserstein distance with gradient '
            'penalty, while the generator was trained with a '
            'combination of the L1 loss (mean absolute error) and the discriminator output (the adversarial loss).\n\n'
            'Key python packages used include: TensorFlow '
            '(v '+version_numbers[0]+'), NumPy (v '+version_numbers[1]+'), Keras '
            '(v '+version_numbers[2]+'), CUDA (v '+cuda_version.replace('\n',' ')+').\n'
            'The training was accelerated using a '+gpu_name+' GPU.')
    

  pdf.set_font('')
  pdf.set_font_size(10.)
  pdf.multi_cell(180, 5, txt = text, align='L')
  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font("Arial", size = 11, style='B')
  pdf.ln(1)

  if Only_high_resolution_data:
    only_hr_text = ('Only high resolution images were provided for '
                    'training; therefore, the corresponding low resolution images '
                    'were generated synthetically. The images are of type '
                    +Type_of_data+', so a degradation function specific '
                    'to that type of data was used, taken from '
                    'Fang, L., Monroe, F., et al., 2021.')
    
    pdf.ln(3)
    pdf.set_font('Arial', size=10)
    pdf.multi_cell(180, 5, txt = only_hr_text, align='L')
    pdf.ln(1)

  pdf.ln(3)
  pdf.set_font('Arial', size=10)
  pdf.cell(200, 5, txt='The following parameters were used for training:')
  pdf.ln(1)
  html = """ 
  <table width=70% style="margin-left:0px;">
    <tr>
      <th width = 50% align="left">Training Parameter</th>
      <th width = 50% align="left">Value</th>
    </tr>
    <tr>
      <td width = 50%>Number_of_epochs</td>
      <td width = 50%>{0}</td>
    </tr>
    <tr>
      <td width = 50%>Batch_size</td>
      <td width = 50%>{1}</td>
    </tr>
    <tr>
      <td width = 50%>Percentage_validation</td>
      <td width = 50%>{2}</td>
    </tr>
    <tr>
      <td width = 50%>Generator_initial_learning_rate</td>
      <td width = 50%>{3}</td>
    </tr>
    <tr>
      <td width = 50%>Discriminator_initial_learning_rate</td>
      <td width = 50%>{4}</td>
    </tr>
  </table>
  """.format(Number_of_epochs, Batch_size,
             Percentage_validation, Generator_initial_learning_rate,
             Discriminator_initial_learning_rate)
  pdf.write_html(html)

  pdf.ln(1)
  # pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)
  pdf.set_font('')
  pdf.multi_cell(170, 5, txt = Model_path+'/'+Model_name, align = 'L')
  pdf.ln(1)

  pdf.ln(1)
  pdf.cell(60, 5, txt = 'Example Training Images', ln=1)
  pdf.ln(1)
  exp_size = io.imread(full_model_path+'/TrainingDataExample_WGAN2D.png').shape
  pdf.image(full_model_path+'/TrainingDataExample_WGAN2D.png', 
            x = 50, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))
  pdf.ln(1)
  

  pdf.ln(2)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.ln(3)
  pdf.cell(80, 5, txt = 'References:', ln=1)
  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font_size(10.)
  ref_1 = '- ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. "Democratising deep learning for microscopy with ZeroCostDL4Mic." Nature Communications (2021).'
  pdf.multi_cell(190, 5, txt = ref_1, align='L')
  pdf.ln(1)
  ref_2 = '- Improved WGAN: I. Gulrajani, F. Ahmed, M. Arjovsky, et al., "Improved training of Wasserstein GANs." Advances in Neural Information Processing Systems, vol. 30, 2017.'
  pdf.multi_cell(190, 5, txt = ref_2, align='L')
  pdf.ln(1)
  #'Fang, L., Monroe, F., Novak, S.W. et al. Deep learning-based point-scanning super-resolution imaging. Nat Methods 18, 406–416 (2021).''
  pdf.ln(3)
  reminder = 'Important:\nRemember to perform the quality control step on all newly trained models\nPlease consider depositing your training dataset on Zenodo'
  pdf.set_font('Arial', size = 11, style='B')
  pdf.multi_cell(190, 5, txt=reminder, align='C')
  pdf.ln(1)

  pdf.output(full_model_path+'/'+Model_name+'_training_report.pdf')
  print('------------------------------')
  print('PDF report exported in '+full_model_path+'/'+Model_name+'_training_report.pdf')



def qc_pdf_export():
  class MyFPDF(FPDF, HTMLMixin):
    pass

  pdf = MyFPDF()
  pdf.add_page()
  pdf.set_right_margin(-1)
  pdf.set_font("Arial", size = 11, style='B') 

  day = datetime.now()
  datetime_str = str(day)[0:10]

  Header = 'Quality Control report for '+Network+' model ('+os.path.basename(QC_model_path)+')\nDate: '+datetime_str
  pdf.multi_cell(180, 5, txt = Header, align = 'L') 
  pdf.ln(1)

  all_packages = ''
  for requirement in freeze(local_only=True):
    all_packages = all_packages+requirement+', '



  pdf.set_font('')
  pdf.set_font('Arial', size = 11, style = 'B')
  pdf.ln(2)
  pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')
  pdf.ln(1)

  pdf.ln(2)
  pdf.cell(190, 3, txt = 'Training and validation losses in each epoch:', ln=1, align='L')
  pdf.ln(1)

  if os.path.exists(full_QC_model_path+'/TrainVal_epoch_losses_plots.png'):
    exp_size = io.imread(full_QC_model_path+'/TrainVal_epoch_losses_plots.png').shape
    pdf.image(full_QC_model_path+'/TrainVal_epoch_losses_plots.png', x = 40, y = None, 
              w = round(exp_size[1]/6), h = round(exp_size[0]/6))
  else:
    pdf.set_font('')
    pdf.set_font('Arial', size=10)
    pdf.cell(190, 5, txt='If you would like to see the evolution of the train and validation values in each epoch during training please play the first cell of the QC section in the notebook.')

  pdf.ln(70)
  pdf.cell(190, 3, txt = 'Training losses in each train step:', ln=1, align='L')
  pdf.ln(1)

  if os.path.exists(full_QC_model_path+'/Train_steps_losses_plots.png'):
    exp_size = io.imread(full_QC_model_path+'/Train_steps_losses_plots.png').shape
    pdf.image(full_QC_model_path+'/Train_steps_losses_plots.png', x = 40, y = None, 
              w = round(exp_size[1]/7), h = round(exp_size[0]/7))
  else:
    pdf.set_font('')
    pdf.set_font('Arial', size=10)
    pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function in each train step during training please play the second cell of the QC section in the notebook.')

  pdf.ln(2)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.ln(3)
  pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)
  exp_size = io.imread(full_QC_model_path+'/QC_example_data.png').shape
  pdf.image(full_QC_model_path+'/QC_example_data.png', x = 30, y = None, 
            w = round(exp_size[1]/6), h = round(exp_size[0]/6))
  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font('Arial', size = 11, style = 'B')
  pdf.ln(1)
  pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)
  pdf.set_font('')
  pdf.set_font_size(10.)
  
  pdf.ln(1)
  html = """
  <body>
  <font size="7" face="Courier" >
  <table width=94% style="margin-left:0px;">"""
  with open(full_QC_model_path+"/Quality Control/QC_metrics_"+QC_model_name+".csv", 'r') as csvfile:
    metrics = csv.reader(csvfile, delimiter=',')
    header = next(metrics)
    image = header[0]
    mSSIM_PvsGT = header[1]
    mSSIM_SvsGT = header[2]
    NRMSE_PvsGT = header[3]
    NRMSE_SvsGT = header[4]
    PSNR_PvsGT = header[5]
    PSNR_SvsGT = header[6]
    header = """
    <tr>
    <th width = 10% align="left">{0}</th>
    <th width = 15% align="left">{1}</th>
    <th width = 15% align="center">{2}</th>
    <th width = 15% align="left">{3}</th>
    <th width = 15% align="center">{4}</th>
    <th width = 15% align="left">{5}</th>
    <th width = 15% align="center">{6}</th>
    </tr>""".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)
    html = html+header
    for row in metrics:
      image = row[0]
      mSSIM_PvsGT = row[1]
      mSSIM_SvsGT = row[2]
      NRMSE_PvsGT = row[3]
      NRMSE_SvsGT = row[4]
      PSNR_PvsGT = row[5]
      PSNR_SvsGT = row[6]
      cells = """
        <tr>
          <td width = 10% align="left">{0}</td>
          <td width = 15% align="center">{1}</td>
          <td width = 15% align="center">{2}</td>
          <td width = 15% align="center">{3}</td>
          <td width = 15% align="center">{4}</td>
          <td width = 15% align="center">{5}</td>
          <td width = 15% align="center">{6}</td>
        </tr>""".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))
      html = html+cells
    html = html+"""</table></font></body>"""
    
  pdf.write_html(html)
  
  pdf.ln(2)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.ln(3)
  pdf.cell(80, 5, txt = 'References:', ln=1)
  pdf.set_font('')
  pdf.set_font_size(10.)
  ref_1 = '- ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. "Democratising deep learning for microscopy with ZeroCostDL4Mic." Nature Communications (2021).'
  pdf.multi_cell(190, 5, txt = ref_1, align='L')
  pdf.ln(1)
  ref_2 = '- Improved WGAN: I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. C. Courville, "Improved training of Wasserstein GANs." Advances in Neural Information Processing Systems, vol. 30, 2017.'
  pdf.multi_cell(190, 5, txt = ref_2, align='L')
  pdf.ln(1)

  pdf.ln(3)
  reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'

  pdf.set_font('Arial', size = 11, style='B')
  pdf.multi_cell(190, 5, txt=reminder, align='C')
  pdf.ln(1)

  pdf.output(full_QC_model_path+'/'+os.path.basename(QC_model_name)+'_QC_report.pdf')


  print('------------------------------')
  print('QC PDF report exported as '+full_QC_model_path+'/'+os.path.basename(QC_model_name)+'_QC_report.pdf')


# Build requirements file for local run
after = [str(m) for m in sys.modules]
build_requirements_file(before, after)

# **2. Initialise the Colab session**
---







## **2.1. Check for GPU access**
---

By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:

<font size = 4>Go to **Runtime -> Change the Runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*

<font size = 4>**Accelerator: GPU** *(Graphics processing unit)*


In [None]:
#@markdown ##Run this cell to check if you have GPU access

import tensorflow as tf
if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.') 
  print('Did you change your runtime?') 
  print('If the runtime setting is correct then Google did not allocate a GPU for your session')
  print('Expect slow performance. To access GPU try reconnecting later')

else:
  print('You have GPU access')
  !nvidia-smi

## **2.2. Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of the notebook.

In [None]:
#@markdown ##Run this cell to connect your Google Drive to Colab

#@markdown * Click on "Connect to Google Drive". 

#@markdown * Sign in to your Google Account. 

#@markdown * Click on "Allow". 

#@markdown * Click on the `Files` tab on the left and refresh it. Your Google Drive folder should now be available there as `gdrive`. 

#mounts user's Google Drive to Google Colab.

from google.colab import drive
drive.mount(base_path + '/gdrive')

# **3. Select your paths and parameters**

---


## **3.1. Setting the main training parameters**
---

<font size = 4>The code below allows the user to enter the paths to where the training data is and to define the training parameters. Note that the execution of the cell will take some time as the images from the folders will be read into memory.

<font size = 5> **Paths for training, predictions and results**

<font size = 4>**`Training_source`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.
  

<font size = 4>**`Only_high_resolution_data`:** Select this option if your dataset only contains high resolution (HR) data. In that case, the low resolution (LR) data will be generated artificially with a degradation ("crappification") function. **Default value: False**

<font size = 4>**`Type_of_data`:** To generate the low resolution data, the notebook needs to know whether the images come from *Electron microscopy* or *Fluorescence* microscopy, so that it can apply the matching crappification function described in the [Deep Learning-Based Point-Scanning Super-Resolution Imaging](https://www.biorxiv.org/content/10.1101/740548v8) paper. **Default value: Electron microscopy**
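<font size = 4>To make the idea of synthetic LR generation more concrete, here is a minimal sketch of a degradation function. It is an illustrative assumption only: it adds Gaussian noise and then downsamples by averaging each `down_factor` x `down_factor` block. It is **not** the Electron microscopy or Fluorescence function from Fang et al. that this notebook uses, and `crappify` and `noise_std` are hypothetical names.

```python
import numpy as np

def crappify(hr, down_factor, noise_std=0.05, seed=None):
    """Hypothetical degradation: add Gaussian noise, then block-average downsample.

    hr: 2D float array with values in [0, 1]; both dimensions must be
    divisible by down_factor.
    """
    rng = np.random.default_rng(seed)
    h, w = hr.shape
    if h % down_factor or w % down_factor:
        raise ValueError("Image dimensions must be divisible by down_factor")
    noisy = hr + rng.normal(0.0, noise_std, size=hr.shape)
    # Group pixels into (down_factor x down_factor) blocks and average each block
    lr = noisy.reshape(h // down_factor, down_factor,
                       w // down_factor, down_factor).mean(axis=(1, 3))
    # Keep intensities in the same [0, 1] range as the input
    return np.clip(lr, 0.0, 1.0)

hr = np.random.default_rng(0).random((256, 256))
lr = crappify(hr, down_factor=2, seed=1)
print(hr.shape, lr.shape)  # (256, 256) (128, 128)
```

<font size = 4>Real degradation models for microscopy usually also include blur and detector-specific noise, which is why the notebook distinguishes between data types.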

<font size = 4>**`Down_factor`:** Scaling factor by which each dimension of the HR images is reduced. For example, if an HR image is 256x256 and its LR counterpart is 128x128, the down factor is 2. Typical values are 2, 4, etc. This is a **critical parameter** that depends on how your LR and HR images were acquired. When paired LR and HR folders are provided, the down factor is computed from the image sizes and overrides the value entered here. **Default value: 2**
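<font size = 4>The following sketch shows how a down factor can be inferred from a pair of image shapes, mirroring the checks performed in the parameters cell of this section (same factor in both dimensions, whole number). `down_factor_from_shapes` is a hypothetical helper; it is not necessarily how the notebook's `calculate_down_factor` is implemented.

```python
def down_factor_from_shapes(hr_shape, lr_shape):
    """Hypothetical helper: infer the integer scaling factor between an HR and an LR image."""
    factor_y = hr_shape[0] / lr_shape[0]
    factor_x = hr_shape[1] / lr_shape[1]
    # Both dimensions must be reduced by the same amount
    if factor_x != factor_y:
        raise ValueError("Width and height are not scaled by the same factor")
    # The factor must be an integer such as 2 or 4
    if factor_x % 1 != 0:
        raise ValueError("Scaling factor is not a whole number")
    return int(factor_x)

print(down_factor_from_shapes((256, 256), (128, 128)))  # 2
print(down_factor_from_shapes((512, 512), (128, 128)))  # 4
```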

<font size = 4>**`Training_high_resolution_folder`:** If *Only_high_resolution_data* is selected, the two folder paths above are ignored; enter the path to the folder containing the high resolution images here instead.

<font size = 4>**`Model_name`:** Use only my_model -style, not my-model (Use "_" not "-"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.

<font size = 4>**`Model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).

<font size = 5>**Training parameters**

<font size = 4>**`Number_of_epochs`:** Input how many epochs (rounds) the network will be trained for. Since we use a fast-convergence algorithm, preliminary results can already be observed after 5-10 epochs, and full training could be achieved with as few as 15-20 epochs. Evaluate the performance after training (see section 5). **Default value: 10**


<font size = 5>**Advanced parameters - experienced users only**

<font size =4>**`Batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 2**

<font size = 4>**`Percentage_validation`:**  Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** 
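<font size = 4>As a rough illustration of what `Percentage_validation` and `Batch_size` mean in practice, the sketch below shuffles a list of file names, holds out a percentage for validation (using the same `Percentage_validation/100` conversion as the parameters cell) and counts the training steps per epoch. It assumes the split is done per image; the notebook's `EMDataset` may split differently, and `split_train_validation` is a hypothetical name.

```python
import math
import random

def split_train_validation(filenames, percentage_validation, seed=42):
    """Hypothetical helper: shuffle filenames and hold out a percentage for validation."""
    val_split = percentage_validation / 100  # same conversion as in the parameters cell
    files = sorted(filenames)
    random.Random(seed).shuffle(files)  # reproducible shuffle
    n_val = round(len(files) * val_split)
    return files[n_val:], files[:n_val]

files = [f"img_{i:03d}.tif" for i in range(50)]
train, val = split_train_validation(files, 10)
batch_size = 8
print(len(train), len(val))                # 45 5
print(math.ceil(len(train) / batch_size))  # 6 training steps per epoch
```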

<font size = 4>**`Generator_initial_learning_rate`:**  Input the initial value to be used as learning rate for the generator network. **Default value: 0.0001**

<font size = 4>**`Discriminator_initial_learning_rate`:**  Input the initial value to be used as learning rate for the discriminator network. **Default value: 0.0001**

<font size = 4>**`Source_patch_size`:** Size along both the X and Y axes of the training patches taken from each training image. Smaller patches lead to faster training. **Default value: 128**
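<font size = 4>When working with paired data, LR and HR patches must cover the same field of view. The sketch below assumes, as the commented line `target_patch_size = source_patch_size * down_factor` in the parameters cell suggests, that the patch size refers to the LR (source) image, so the HR patch is `down_factor` times larger on each side. `random_paired_patch` is a hypothetical helper, not the notebook's own patch extraction.

```python
import numpy as np

def random_paired_patch(lr, hr, patch_size, down_factor, rng):
    """Hypothetical helper: crop aligned LR/HR patches.

    The LR patch is patch_size x patch_size; the HR patch covers the same
    field of view, so it is (patch_size * down_factor) pixels on each side.
    """
    # Random top-left corner in LR coordinates (upper bound is exclusive)
    y = rng.integers(0, lr.shape[0] - patch_size + 1)
    x = rng.integers(0, lr.shape[1] - patch_size + 1)
    lr_patch = lr[y:y + patch_size, x:x + patch_size]
    # Scale the corner and size to HR coordinates
    hy, hx, hp = y * down_factor, x * down_factor, patch_size * down_factor
    hr_patch = hr[hy:hy + hp, hx:hx + hp]
    return lr_patch, hr_patch

rng = np.random.default_rng(0)
hr = rng.random((512, 512))
lr = hr[::2, ::2]  # stand-in LR image with down_factor = 2
lr_p, hr_p = random_paired_patch(lr, hr, 128, 2, rng)
print(lr_p.shape, hr_p.shape)  # (128, 128) (256, 256)
```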



In [None]:
#@markdown ###Path to training images:

Training_source = "" #@param {type:"string"}
Training_target = "" #@param {type:"string"}


#@markdown #####If you do not have paired LR and HR images, leave the two parameters above empty and complete the ones below: 

Only_high_resolution_data = False #@param {type:"boolean"}
Type_of_data = "Flourescence" #@param ["Electron microscopy", "Flourescence"]
Training_high_resolution_folder = "" #@param {type:"string"}

Down_factor = 2 #@param {type:"number"}
if not Only_high_resolution_data:
    down_factor_x, down_factor_y = calculate_down_factor(Training_source, Training_target)
    if down_factor_x != down_factor_y:
        raise ValueError("Down factors on the width and height of the received images are not the same")
    else:
        Down_factor = down_factor_x
    if Down_factor % 1 !=0:
        raise ValueError("Down factor between HR and LR images is not a whole number.")
    else:
        Down_factor = int(Down_factor)
        

# Model name and path
#@markdown ###Names of both the model and path to the folder:
Model_name = "" #@param {type:"string"}
Model_path = "" #@param {type:"string"}

# Other parameters for training.
#@markdown ###Training Parameters
#@markdown Number of epochs:

Number_of_epochs = 100#@param {type:"number"}

#@markdown ###Advanced Parameters

Use_Default_Advanced_Parameters = False #@param {type:"boolean"}
#@markdown ###If not, please input:

Batch_size =   8#@param {type:"number"}
Percentage_validation =  10 #@param {type:"number"}
Generator_initial_learning_rate = 0.0001 #@param {type:"number"}
Discriminator_initial_learning_rate = 0.0001 #@param {type:"number"}

Source_patch_size =  128#@param {type:"number"}

if (Use_Default_Advanced_Parameters): 
  print("Default advanced parameters enabled")
  Batch_size = 2
  Percentage_validation = 10
  Generator_initial_learning_rate = 0.0001
  Discriminator_initial_learning_rate = 0.0001
  Source_patch_size =  128

# target_patch_size = source_patch_size * down_factor

#Here we define the percentage to use for validation
val_split = Percentage_validation/100


# Check whether a model with the same name already exists; if so, delete it
#if os.path.exists(model_path+'/'+model_name):
#  shutil.rmtree(model_path+'/'+model_name)

full_model_path = os.path.join(Model_path, Model_name)
if os.path.exists(full_model_path):
  print(bcolors.WARNING+'!! WARNING: Model folder already exists and has been deleted; a new one will be created !!')
  shutil.rmtree(full_model_path)
  
os.makedirs(full_model_path)
os.makedirs(os.path.join(full_model_path,'Quality Control'))

## **3.2. Data augmentation**
---
<font size = 4>

<font size = 4>Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.

<font size = 4>Data augmentation is performed here by rotating the patches in the XY plane and flipping them along the X and Y axes. This only works if the images are square in XY.
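<font size = 4>The key point for paired super-resolution data is that the LR and HR patches must receive exactly the same rotation and flips, otherwise they stop being aligned. Below is a minimal sketch using NumPy's `rot90`, `fliplr` and `flipud`; it assumes a 50% chance per flip and a uniformly chosen multiple of 90 degrees, which may differ from the notebook's own augmentation code. `augment_pair` and `block_mean` are hypothetical names.

```python
import numpy as np

def augment_pair(lr, hr, rng, rotation=True, horizontal_flip=True, vertical_flip=True):
    """Hypothetical helper: apply the same random rotation/flips to an LR/HR pair."""
    # Rotating a non-square patch would change its shape, hence the square requirement
    if lr.shape[0] != lr.shape[1] or hr.shape[0] != hr.shape[1]:
        raise ValueError("Rotation/flip augmentation assumes square patches")
    if rotation:
        k = int(rng.integers(0, 4))  # 0, 90, 180 or 270 degrees
        lr, hr = np.rot90(lr, k), np.rot90(hr, k)
    if horizontal_flip and rng.random() < 0.5:
        lr, hr = np.fliplr(lr), np.fliplr(hr)
    if vertical_flip and rng.random() < 0.5:
        lr, hr = np.flipud(lr), np.flipud(hr)
    return lr, hr

def block_mean(img, f):
    """Downsample by averaging f x f blocks (used here only to check alignment)."""
    h, w = img.shape
    return img.reshape(h // f, f, w // f, f).mean(axis=(1, 3))

rng = np.random.default_rng(0)
hr = rng.random((8, 8))
lr = block_mean(hr, 2)  # LR stand-in that is aligned with hr
aug_lr, aug_hr = augment_pair(lr, hr, rng)
print(np.allclose(block_mean(aug_hr, 2), aug_lr))  # True: the pair is still aligned
```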


In [None]:
#@markdown ###Data augmentation
Use_Data_augmentation = True #@param{type:"boolean"}

#@markdown Select this option if you want to use augmentation to increase the size of your dataset

#@markdown **Rotate each image randomly by 90 degrees.**
Rotation = True #@param{type:"boolean"}

#@markdown **Flip each image once around the x and y axis of the stack.**
Horizontal_flip = True #@param{type:"boolean"}
Vertical_flip = True #@param{type:"boolean"}


if Use_Data_augmentation:
  print("Data augmentation enabled")
else:
  print(bcolors.WARNING+"Data augmentation disabled")

## **3.3. Using weights from a pre-trained model as initial weights**
---
<font size = 4>  Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a WGAN model**. 

<font size = 4> This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.

<font size = 4> In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
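<font size = 4>The cell below looks for a file named `last_checkpoint.pth` or `best_checkpoint.pth` inside the folder you provide. As a minimal sketch of that lookup (the helper name `resolve_checkpoint` is hypothetical), the path can be resolved and checked before training starts:

```python
import os

def resolve_checkpoint(model_folder, weights_choice="last"):
    """Return the path to '<choice>_checkpoint.pth' in model_folder, or None if it is missing.

    Follows the file naming used by this notebook; the helper itself is illustrative.
    """
    if weights_choice not in ("last", "best"):
        raise ValueError("weights_choice must be 'last' or 'best'")
    path = os.path.join(model_folder, weights_choice + "_checkpoint.pth")
    return path if os.path.isfile(path) else None

# Example (hypothetical folder): fall back to training from scratch if nothing is found
checkpoint = resolve_checkpoint("/path/to/my_wgan_model", "best")
print("Training from scratch" if checkpoint is None else "Loading " + checkpoint)
```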

In [None]:
# @markdown ##Loading weights from a pre-trained network

Use_pretrained_model = False #@param {type:"boolean"}

Weights_choice = "last" #@param ["last", "best"]

#@markdown ###If you do, please provide the path to the model folder:
Pretrained_model_path = "" #@param {type:"string"}


# Check if we load a previously trained model
if not Use_pretrained_model:
  print(bcolors.WARNING+'No pretrained network will be used.')

else:
  checkpoint_file_path = os.path.join(Pretrained_model_path, Weights_choice+"_checkpoint.pth")

  # Check the model exist
  if not os.path.exists(checkpoint_file_path):
    print(bcolors.WARNING+'WARNING: Pretrained model does not exist.')
    Use_pretrained_model = False
    print(bcolors.WARNING+'No pretrained network will be used.')
  else:
    print("Pretrained model "+os.path.basename(Pretrained_model_path)+" was found and will be loaded prior to training.")


# **4. Train the network**
---

## **4.1. Prepare the data and model for training**
---
<font size = 4>Here, we use the information from section 3 to build the model and convert the training data into a suitable format for training. A pair of LR-HR training images will be displayed at the end of the process.
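<font size = 4>Each training sample is a pair of patches covering the same field of view. Assuming, as the commented line `target_patch_size = source_patch_size * down_factor` in section 3 suggests, that the HR image is `Down_factor` times larger than the LR image, an LR patch of size `Source_patch_size` starting at `(y, x)` corresponds to an HR patch of size `Source_patch_size * Down_factor` starting at `(y * Down_factor, x * Down_factor)`. The hypothetical helper below sketches that mapping; the actual patch extraction happens inside the `WGANGP` class and may differ in its details.

```python
import numpy as np

def extract_paired_patch(lr_img, hr_img, lr_patch_size, down_factor, rng):
    """Crop a random LR patch and the HR patch covering the same region (illustrative sketch)."""
    h, w = lr_img.shape
    if hr_img.shape != (h * down_factor, w * down_factor):
        raise ValueError("HR shape must be the LR shape multiplied by down_factor")
    # Random top-left corner of the LR patch
    y = int(rng.integers(0, h - lr_patch_size + 1))
    x = int(rng.integers(0, w - lr_patch_size + 1))
    lr_patch = lr_img[y:y + lr_patch_size, x:x + lr_patch_size]
    # The same region in HR coordinates is down_factor times larger
    hr_size = lr_patch_size * down_factor
    hr_y, hr_x = y * down_factor, x * down_factor
    hr_patch = hr_img[hr_y:hr_y + hr_size, hr_x:hr_x + hr_size]
    return lr_patch, hr_patch

rng = np.random.default_rng(0)
lr_img = np.arange(64, dtype=np.float32).reshape(8, 8)
hr_img = np.kron(lr_img, np.ones((2, 2), dtype=np.float32))  # toy 2x HR image
lr_patch, hr_patch = extract_paired_patch(lr_img, hr_img, 4, 2, rng)
print(lr_patch.shape, hr_patch.shape)
```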

In [None]:
#@markdown ##Play this cell to prepare the model for training
if Use_pretrained_model:
  gen_checkpoint = checkpoint_file_path
else:
  gen_checkpoint = None

logger = CSVLogger(full_model_path + '/Quality Control', name='Logger')

lr_monitor = LearningRateMonitor(logging_interval='epoch')
#checkpoints = ModelCheckpoint(monitor='val_ssim', mode='max', save_top_k=3, every_n_train_steps=5, save_last=True, filename="{epoch:02d}-{val_ssim:.3f}")

model = WGANGP(
    g_layers=15, 
    d_layers=5, 
    batchsize=Batch_size,
    patch_size=Source_patch_size,
    down_factor=Down_factor,
    recloss=100.0,
    learning_rate_g=Generator_initial_learning_rate,
    learning_rate_d=Discriminator_initial_learning_rate,
    validation_split = val_split,
    epochs = Number_of_epochs,
    rotation = Rotation,
    horizontal_flip = Horizontal_flip,
    vertical_flip = Vertical_flip,
    hr_imgs_basedir = Training_target, 
    lr_imgs_basedir = Training_source,
    only_high_resolution_data = Only_high_resolution_data,
    only_hr_images_basedir = Training_high_resolution_folder,
    type_of_data = Type_of_data,
    save_basedir = full_model_path,
    gen_checkpoint = gen_checkpoint
)


trainer = Trainer(
    gpus=1, 
    max_epochs=Number_of_epochs, 
    logger=logger, 
    callbacks= [lr_monitor]#[checkpoints, lr_monitor]
)

# For the pdf generation
n_patches = model.len_data * Batch_size

if Only_high_resolution_data:
    hr_basedir = Training_high_resolution_folder 
else: 
    hr_basedir = Training_target

_, hr_extension = os.path.splitext(os.listdir(hr_basedir)[0])
hr_filenames = [hr_basedir + '/' + x for x in os.listdir(hr_basedir) if x.endswith(hr_extension)]

original_output_shape = np.array(io.imread(hr_filenames[0]).shape)
original_input_shape = original_output_shape // Down_factor

data = next(iter(model.train_dataloader()))  # Python 3 iterators have no .next() method
lr = data['lr'][0][0]
hr = data['hr'][0][0]

input_patch_shape = np.array(lr.shape)
output_patch_shape = np.array(hr.shape)  # the output patch shape is given by the HR patch

print('Example of the patches that will be used for training:')
f = plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(lr, 'gray')
plt.title('Input patch (low resolution)')
plt.axis('off');

plt.subplot(1,2,2)
plt.imshow(hr, 'gray')
plt.title('Ground truth patch (high resolution)')
plt.axis('off');

plt.savefig(full_model_path+'/TrainingDataExample_WGAN2D.png',
            bbox_inches='tight',pad_inches=0)

pdf_export(pretrained_model = Use_pretrained_model)

## **4.2. Start Training**
---
<font size = 4>When playing the cell below you should see updates after each epoch (round). Network training can take some time.

<font size = 4>* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.

<font size = 4>Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder.
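<font size = 4>To stay within the 12-hour limit, you can time a short trial run and estimate how many epochs fit in a session. The sketch below is a back-of-the-envelope helper with hypothetical names; the 80 % safety margin is an arbitrary assumption that leaves time for data preparation, quality control and export. `format_duration` splits a duration the same way the training cell does with `divmod`.

```python
def max_epochs_within_budget(seconds_per_epoch, budget_hours=12.0, safety_margin=0.8):
    """Estimate how many epochs fit in a session, using only a fraction of the time budget."""
    if seconds_per_epoch <= 0:
        raise ValueError("seconds_per_epoch must be positive")
    usable_seconds = budget_hours * 3600 * safety_margin
    return int(usable_seconds // seconds_per_epoch)

def format_duration(seconds):
    """Split a duration in seconds into whole (hours, minutes, seconds)."""
    mins, secs = divmod(round(seconds), 60)
    hours, mins = divmod(mins, 60)
    return hours, mins, secs

print(max_epochs_within_budget(120))   # 2-minute epochs within 80 % of 12 hours
print(format_duration(3725.4))         # (1, 2, 5)
```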

In [None]:
#@markdown ##Start training
import time

start = time.time()

trainer.fit(model)

dt = time.time() - start
mins, secs = divmod(dt, 60) 
hours, mins = divmod(mins, 60) 
print("Time elapsed:", int(hours), "hour(s)", int(mins), "min(s)", round(secs), "sec(s)")
        
pdf_export(trained = True, pretrained_model = Use_pretrained_model)


# **5. Evaluate your model**
---

<font size = 4>This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 

<font size = 4>**We highly recommend performing quality control on all newly trained models.**



In [None]:
# model name and path
#@markdown ###Do you want to assess the model you just trained?
Use_the_current_trained_model = True #@param {type:"boolean"}

#@markdown ###If not, please provide the path to the model folder and the scale factor between the low and high resolution images:

QC_model_folder = "" #@param {type:"string"}
QC_down_factor = 2 #@param {type:"number"}

#Here we define the loaded model name and path
QC_model_name = os.path.basename(QC_model_folder)
QC_model_path = os.path.dirname(QC_model_folder)

if Use_the_current_trained_model: 
  QC_model_name = Model_name
  QC_model_path = Model_path
else:
  Down_factor = QC_down_factor

full_QC_model_path = os.path.join(QC_model_path, QC_model_name)

#print(full_QC_model_path)

if os.path.exists(full_QC_model_path):
  print("The "+QC_model_name+" network will be evaluated")
else:
  W  = '\033[0m'  # white (normal)
  R  = '\033[31m' # red
  print(R+'!! WARNING: The chosen model does not exist !!'+W)
  print('Please make sure you provide a valid model path and model name before proceeding further.')


## **5.1. Inspection of the loss function**
---

<font size = 4>First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*

<font size = 4>**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.

<font size = 4>**Validation loss** describes the same error value between the model's prediction on a validation image and its target.

<font size = 4>During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.

<font size = 4>Decreasing **Training loss** and **Validation loss** indicate that training is still necessary, and increasing `Number_of_epochs` is recommended. Note that the curves can look flat towards the right side just because of the y-axis scaling. The network has reached convergence once the curves flatten out; after this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, the network is overfitting to the training data. In other words, the network is memorizing the exact patterns from the training data and no longer generalizes well to unseen data. In this case, the training dataset has to be increased.
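<font size = 4>Reading the curves by eye is usually enough, but the rule of thumb above can also be written down as a rough heuristic. The sketch below uses hypothetical helpers, and the window size and tolerance are arbitrary assumptions: it smooths both curves with a moving average and compares the last smoothed value with the one `window` epochs earlier. It is meant for loss curves that are expected to decrease, such as the generator's total and L1 losses.

```python
import numpy as np

def moving_average(values, window=5):
    """Smooth a loss curve with a simple moving average (returned unchanged if too short)."""
    values = np.asarray(values, dtype=float)
    if len(values) < window:
        return values
    return np.convolve(values, np.ones(window) / window, mode="valid")

def diagnose_curves(train_loss, val_loss, window=5, tol=1e-3):
    """Heuristic reading of training vs. validation curves; not part of the notebook."""
    t = moving_average(train_loss, window)
    v = moving_average(val_loss, window)
    if len(t) <= window or len(v) <= window:
        return "too few epochs to judge"
    dt = t[-1] - t[-1 - window]   # recent change of the smoothed training loss
    dv = v[-1] - v[-1 - window]   # recent change of the smoothed validation loss
    if dv > tol and dt < -tol:
        return "possible overfitting"
    if dt < -tol and dv < -tol:
        return "still improving"
    return "plateau"

epochs = np.arange(30)
print(diagnose_curves(1.0 / (1 + epochs), 1.0 / (1 + epochs) + 0.05))
```

<font size = 4>With the metrics parsed in the next cell, the generator's L1 curves could be passed as `[e[1] for e in train_metrics['g_l1_epoch']]` and `[e[1] for e in train_metrics['val_g_l1']]`.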

In [None]:
#@markdown ##Play the cell to show the plots of training errors vs. epoch number

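logger_path = os.path.join(full_QC_model_path, 'Quality Control', 'Logger')
all_logger_versions = [os.path.join(logger_path, dname) for dname in os.listdir(logger_path)]
# os.listdir() returns entries in arbitrary order, so pick the most recently modified logger version
last_logger = max(all_logger_versions, key=os.path.getmtime)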

train_csv_path = last_logger + '/metrics.csv'

if not os.path.exists(train_csv_path):
  print('The path does not contain a csv file containing the loss and validation evolution of the model')
else:
  with open(train_csv_path,'r') as csvfile:
    csvRead = csv.reader(csvfile, delimiter=',')
    keys = next(csvRead)
    # Locate the 'step' column from the header instead of assuming it is the third column
    step_idx = keys.index('step')
    keys.pop(step_idx)
    train_metrics = {'g_lr':[], 'd_lr':[],
                    'g_loss_step':[], 'g_l1_step':[],
                    'd_loss_step':[], 'd_wasserstein_step':[], 'd_gp_step':[],
                    'epoch':[],
                    'val_ssim':[], 'val_psnr':[],
                    'val_g_loss':[], 'val_g_l1':[],
                    'val_d_wasserstein':[],
                    'g_loss_epoch':[], 'g_l1_epoch':[],
                    'd_loss_epoch':[], 'd_wasserstein_epoch':[], 'd_gp_epoch':[]
                    }

    for row in csvRead:
      step = int(row[step_idx])
      row.pop(step_idx)
      for i, row_value in enumerate(row):
        if row_value:
          # setdefault avoids a KeyError if the logger wrote a metric not listed above
          train_metrics.setdefault(keys[i], []).append([step, float(row_value)])


  epochNumber = range(len(train_metrics['g_loss_epoch']))
  plt.figure(figsize=(12,15))

  plt.subplot(3,1,1)
  plt.tight_layout(w_pad=2.0)
  plt.plot(epochNumber,[e[1] for e in train_metrics['g_loss_epoch']], label='Training loss')
  plt.plot(epochNumber,[e[1] for e in train_metrics['val_g_loss']], label='Validation loss')
  plt.title('Generator\'s total training and validation loss VS epochs', fontsize=16)
  plt.ylabel('Generator loss', fontsize=14)
  plt.xlabel('Epoch number', fontsize=14)
  plt.legend()

  plt.subplot(3,1,2)
  plt.tight_layout(w_pad=2.0)
  plt.plot(epochNumber,[e[1] for e in train_metrics['g_l1_epoch']], label='Training loss')
  plt.plot(epochNumber,[e[1] for e in train_metrics['val_g_l1']], label='Validation loss')
  plt.title('Generator\'s L1 training and validation loss VS epochs', fontsize=16)
  plt.ylabel('L1 loss', fontsize=14)
  plt.xlabel('Epoch number', fontsize=14)
  plt.legend()

  plt.subplot(3,1,3)
  plt.tight_layout(w_pad=2.0)
  plt.plot(epochNumber,[e[1] for e in train_metrics['d_wasserstein_epoch']], label='Training loss')
  plt.plot(epochNumber,[e[1] for e in train_metrics['val_d_wasserstein']], label='Validation loss')
  plt.title('Discriminator\'s Wasserstein training and validation loss VS epochs', fontsize=16)
  plt.ylabel('Discriminator\'s loss', fontsize=14)
  plt.xlabel('Epoch number', fontsize=14)
  plt.legend()

  plt.savefig(full_QC_model_path+'/TrainVal_epoch_losses_plots.png')
  plt.show()


## **5.2. Image predictions**
---


In [None]:
#@markdown ##Choose the folders that contain your Quality Control dataset
Source_QC_folder = "" #@param{type:"string"}
Target_QC_folder = "" #@param{type:"string"}

#@markdown #####In case you do not have paired LR and HR images, instead of filling in the two parameters above, complete the ones below:

Only_high_resolution_data = False #@param {type:"boolean"}
Type_of_data = "Flourescence" #@param ["Electron microscopy", "Flourescence"]
QC_high_resolution_folder = "" #@param {type:"string"}

# Create a list of sources

# Insert code to perform predictions on all datasets in the Source_QC folder
if Only_high_resolution_data:
  Source_QC_folder = base_path + "/LR_images"
  Target_QC_folder = QC_high_resolution_folder

# If only HR data, generate a folder to save generated LR data
if Only_high_resolution_data:
    if os.path.exists(Source_QC_folder):
        shutil.rmtree(Source_QC_folder)
    os.makedirs(Source_QC_folder)


_, test_extension = os.path.splitext(os.listdir(Target_QC_folder)[0])
test_filenames = [x for x in os.listdir( Target_QC_folder ) if x.endswith(test_extension)]
test_filenames.sort()

print('Available images : ' + str(len(test_filenames)))

model = WGANGP(
    g_layers=15, 
    d_layers=5,
    batchsize=1,
    patch_size=256,
    down_factor=Down_factor,
    recloss=100.0,
    hr_imgs_basedir=Target_QC_folder,
    lr_imgs_basedir=Source_QC_folder,
    gen_checkpoint = full_QC_model_path + '/best_checkpoint.pth',
    only_predict = True
)

gen = model.generator

del model

gen.eval()

# Save the predictions
prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')
if os.path.exists(prediction_QC_folder):
  shutil.rmtree(prediction_QC_folder)
os.makedirs(prediction_QC_folder)

_MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)

psnr_array = []
ssim_array = []
mssim_array = []

for i in tqdm(range(len(test_filenames))):

    hr_img = img_as_ubyte(io.imread( Target_QC_folder + '/' + test_filenames[i]))

    # If only HR data, simulate the LR counterparts
    if Only_high_resolution_data:
        if Type_of_data == "Electron microscopy":
            lr_img = em_crappify(hr_img, Down_factor)
        else:
            lr_img = fluo_crappify(hr_img, Down_factor)
        tf.keras.preprocessing.image.save_img(Source_QC_folder+'/'+test_filenames[i], 
                                              lr_img, data_format=None, file_format=None)
    else:
        lr_img = img_as_ubyte(io.imread(Source_QC_folder + '/' + test_filenames[i])) / 255

    hr_img = hr_img / 255

    # Transform data to the model
    hr_img = np.expand_dims(hr_img, axis=-1)
    hr_img = hr_img.astype(np.float32)
    hr_img = hr_img.transpose((2, 0, 1))
    hr_img = torch.from_numpy(hr_img)

    lr_img = np.expand_dims(lr_img, axis=-1)
    lr_img = lr_img.astype(np.float32)
    lr_img = lr_img.transpose((2, 0, 1))
    lr_img = torch.from_numpy(lr_img)

    pred = gen(lr_img)
    pred = np.expand_dims(np.squeeze(pred.detach().numpy()),axis=-1)
    pred = np.clip( pred, a_min=0, a_max=1 )

    tf.keras.preprocessing.image.save_img(prediction_QC_folder+'/'+test_filenames[i], 
                                          pred, data_format=None, file_format=None)

    

    hr_image = np.expand_dims(np.squeeze(hr_img.detach().numpy()),axis=-1)
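    # Ground truth goes first; the data range is given explicitly since both images are floats in [0, 1]
    psnr_array.append(metrics.peak_signal_noise_ratio(hr_image[:,:,0], pred[:,:,0], data_range=1.0))
    ssim_array.append(metrics.structural_similarity(hr_image[:,:,0], pred[:,:,0], data_range=1.0))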
    mssim_array.append(mssim(pred, hr_image.astype('float32'), max_val=1, 
                              power_factors=_MSSSIM_WEIGHTS, filter_size=11,
                              filter_sigma=1.5, k1=0.01, k2=0.03))
    del pred

psnr_mean = np.mean(psnr_array)
ssim_mean = np.mean(ssim_array)
mssim_mean = np.mean(mssim_array)
        
# ------------- For display ------------
print('\n------------------------------------------------------------------------------------------------------------------------------------------------------------------------------')
print("Choose the image file (Once selected, the new set of images will take a few seconds to appear):")
print(" ")
@interact
def show_prediction_results(file = sorted(os.listdir(prediction_QC_folder))):
  imageLR = io.imread(os.path.join(Source_QC_folder, file))
  imageHR = io.imread(os.path.join(Target_QC_folder, file))
  imageP = io.imread(os.path.join(prediction_QC_folder, file))

  plt.figure(figsize=(25,25))
  plt.subplot(3, 3, 1)
  plt.imshow( imageLR, 'gray' )
  plt.title( 'Low resolution' )
  # Side by side with its "ground truth"
  plt.subplot(3, 3, 2)
  plt.imshow( imageHR, 'gray' )
  plt.title( 'High resolution' )
  # and its prediction
  plt.subplot(3, 3, 3)
  plt.imshow( imageP, 'gray' )
  plt.title( 'Prediction' ) 

print('\n------------------------------------------------------------------------------------------------------------------------------------------------------------------------------')

## **5.3. Error mapping and quality metrics estimation**
---


<font size = 4>This section will calculate the SSIM, PSNR and MSSIM metrics between the predicted and target images to evaluate the quality of the results.

<font size = 4>**1. The SSIM (structural similarity) map** 

<font size = 4>The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric, and an SSIM of 1 indicates perfect similarity between two images. Therefore, for SSIM, the closer to 1, the better.
<font size=4>**mSSIM** is the SSIM value calculated across the entire window of both images. The **MSSIM** value printed below is the multi-scale SSIM (MS-SSIM), which combines SSIM values computed at several image resolutions.

<font size = 4>**2. PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.
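<font size = 4>For reference, PSNR is defined as $10 \log_{10}(R^2 / \mathrm{MSE})$, where $R$ is the data range of the images, and SSIM combines luminance, contrast and structure terms built from the means, variances and covariance of the two images. The NumPy sketch below implements both from their textbook definitions (hypothetical helper names). The SSIM values in this notebook come from `skimage.metrics.structural_similarity` with Gaussian weighting (`sigma=1.5`), which averages SSIM over local windows, so the single-window `global_ssim` shown here will generally give a different number; it only illustrates the formula.

```python
import numpy as np

def psnr(gt, pred, data_range=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    gt = np.asarray(gt, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    mse = np.mean((gt - pred) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10 * np.log10(data_range ** 2 / mse))

def global_ssim(gt, pred, data_range=1.0, k1=0.01, k2=0.03):
    """SSIM computed over one window spanning the whole image (illustration of the formula)."""
    x = np.asarray(gt, dtype=np.float64)
    y = np.asarray(pred, dtype=np.float64)
    c1, c2 = (k1 * data_range) ** 2, (k2 * data_range) ** 2  # stabilizing constants
    mx, my = x.mean(), y.mean()
    vx, vy = x.var(), y.var()
    cov = ((x - mx) * (y - my)).mean()
    return float(((2 * mx * my + c1) * (2 * cov + c2)) /
                 ((mx ** 2 + my ** 2 + c1) * (vx + vy + c2)))

gt = np.zeros((4, 4))
pred = np.full((4, 4), 0.1)
print(psnr(gt, pred))  # about 20 dB, since MSE = 0.01
```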





In [None]:
#@markdown ##Run to calculate the PSNR, SSIM and MSSIM metrics

print("PSNR: ", psnr_mean)
print("SSIM: ", ssim_mean)
print("MSSIM:", mssim_mean)

<font size = 4>The metrics shown in the previous cell are averaged over the whole test dataset. In the following cell, the metrics refer only to the displayed image. The metrics for each image are saved in a CSV file in the Quality Control folder.
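<font size = 4>The per-image CSV can also be summarized outside the notebook. The sketch below (a hypothetical helper using only the standard library) averages every metric column of a CSV laid out like the one written in the next cell, with the image name in the first column. Keep in mind that the CSV metrics are computed after `norm_minmse` rescaling, whereas the averages printed above are computed on the raw [0, 1] predictions, so the two sets of numbers are not expected to match exactly.

```python
import csv

def summarize_qc_csv(csv_path):
    """Return {column name: mean value} for every metric column of a QC CSV.

    Assumes the first column holds the image name and the others hold numbers.
    """
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader)
        rows = [row for row in reader if row]  # skip empty lines
    means = {}
    for col, name in enumerate(header[1:], start=1):
        values = [float(row[col]) for row in rows]
        means[name] = sum(values) / len(values) if values else float("nan")
    return means

# Example (paths from section 5):
# summarize_qc_csv(full_QC_model_path + "/Quality Control/QC_metrics_" + QC_model_name + ".csv")
```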

In [None]:
#@markdown ##SSIM and RSE map
#@markdown Now we will show the SSIM and RSE maps between the original images (upsampled using simple interpolation) and the target images, together with the maps between the predicted and target images.

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity

def ssim(img1, img2):
  return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)


def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):
    """This function is adapted from Martin Weigert"""
    """Percentile-based image normalization."""

    mi = np.percentile(x,pmin,axis=axis,keepdims=True)
    ma = np.percentile(x,pmax,axis=axis,keepdims=True)
    return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)


def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32
    """This function is adapted from Martin Weigert"""
    if dtype is not None:
        x   = x.astype(dtype,copy=False)
        mi  = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)
        ma  = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)
        eps = dtype(eps)

    try:
        import numexpr
        x = numexpr.evaluate("(x - mi) / ( ma - mi + eps )")
    except ImportError:
        x =(x - mi) / ( ma - mi + eps )

    if clip:
        x = np.clip(x,0,1)

    return x

def norm_minmse(gt, x, normalize_gt=True):
    """Normalize and affinely scale an image pair so that the MSE is minimized.

    This function is adapted from Martin Weigert.

    Parameters
    ----------
    gt: ndarray
        the ground truth image
    x: ndarray
        the image that will be affinely scaled
    normalize_gt: bool
        set to True if the gt image should be normalized (default)

    Returns
    -------
    gt_scaled, x_scaled
    """
    if normalize_gt:
        gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)
    x = x.astype(np.float32, copy=False) - np.mean(x)
    #x = x - np.mean(x)
    gt = gt.astype(np.float32, copy=False) - np.mean(gt)
    #gt = gt - np.mean(gt)
    x_scaled = transform.resize( x, (gt.shape[0], gt.shape[1]), order=1 )
    
    scale = np.cov(x_scaled.flatten(), gt.flatten())[0, 1] / np.var(x_scaled.flatten())
    return gt, scale * x_scaled
    #return gt, x_scaled


maps_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Maps')
if os.path.exists(maps_QC_folder):
  shutil.rmtree(maps_QC_folder)
os.makedirs(maps_QC_folder)

# Open and create the csv file that will contain all the QC metrics
with open(full_QC_model_path+"/Quality Control/QC_metrics_"+QC_model_name+".csv", "w", newline='') as file:
    writer = csv.writer(file)

    # Write the header in the csv file
    writer.writerow(["image #","Prediction v. GT mSSIM","Input v. GT mSSIM", "Prediction v. GT NRMSE", "Input v. GT NRMSE", "Prediction v. GT PSNR", "Input v. GT PSNR"])  

    # Let's loop through the provided dataset in the QC folders
    print('Computing maps...')

    _, image_extension = os.path.splitext(os.listdir(Target_QC_folder)[0])

    for i in os.listdir(Target_QC_folder):
      if not os.path.isdir(os.path.join(Source_QC_folder,i)) and i.endswith(image_extension):

      # -------------------------------- Target test data (Ground truth) --------------------------------
        test_GT = io.imread(os.path.join(Target_QC_folder, i))
        #test_GT = test_patches_gt

      # -------------------------------- Source test data --------------------------------
        test_source = io.imread(os.path.join(Source_QC_folder,i))
        #test_source = test_patches_wf

      # Normalize the images wrt each other by minimizing the MSE between GT and Source image
        test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)

      # -------------------------------- Prediction --------------------------------
        test_prediction = io.imread(os.path.join(prediction_QC_folder,i))

      # Normalize the images wrt each other by minimizing the MSE between GT and prediction
        test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True)        

      # -------------------------------- Calculate the metric maps and save them --------------------------------

      # Calculate the SSIM maps
        index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)
        index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)

      #Save ssim_maps
        img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)
        img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)
        if image_extension == '.tif':
          io.imsave(full_QC_model_path+'/Quality Control/Maps/SSIM_GTvsPrediction_'+i, img_SSIM_GTvsPrediction_32bit)
          io.imsave(full_QC_model_path+'/Quality Control/Maps/SSIM_GTvsSource_'+i, img_SSIM_GTvsSource_32bit)
        else:
          io.imsave(full_QC_model_path+'/Quality Control/Maps/SSIM_GTvsPrediction_'+i, img_as_uint(img_SSIM_GTvsPrediction_32bit))
          io.imsave(full_QC_model_path+'/Quality Control/Maps/SSIM_GTvsSource_'+i, img_as_uint(img_SSIM_GTvsSource_32bit))
            
      # Calculate the Root Squared Error (RSE) maps
        img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))
        img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))

      # Save RSE maps
        img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)
        img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)
        if image_extension == '.tif':
          io.imsave(full_QC_model_path+'/Quality Control/Maps/RSE_GTvsPrediction_'+i, np.clip(img_RSE_GTvsPrediction_32bit, -1, 1))
          io.imsave(full_QC_model_path+'/Quality Control/Maps/RSE_GTvsSource_'+i, np.clip(img_RSE_GTvsSource_32bit, -1, 1))
        else:
          io.imsave(full_QC_model_path+'/Quality Control/Maps/RSE_GTvsPrediction_'+i, img_as_uint(np.clip(img_RSE_GTvsPrediction_32bit, -1, 1)))
          io.imsave(full_QC_model_path+'/Quality Control/Maps/RSE_GTvsSource_'+i, img_as_uint(np.clip(img_RSE_GTvsSource_32bit, -1, 1)))


      # -------------------------------- Calculate the RSE metrics and save them --------------------------------

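      # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)
      # The RSE maps hold |error|, so square them back before averaging (sqrt(mean(|e|)) is not an RMSE)
        NRMSE_GTvsPrediction = np.sqrt(np.mean(np.square(img_RSE_GTvsPrediction)))
        NRMSE_GTvsSource = np.sqrt(np.mean(np.square(img_RSE_GTvsSource)))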
        
      # We can also measure the peak signal to noise ratio between the images
        PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)
        PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)

        writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),
                         str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])

# All data is now processed and saved
Test_FileList = os.listdir(prediction_QC_folder) # this assumes, as it should, that both source and target are named the same


plt.figure(figsize=(15,15))
# Currently only displays the last computed set, from memory

df = pd.read_csv(full_QC_model_path+"/Quality Control/QC_metrics_"+QC_model_name+".csv")

# ------------- For display ------------
print('\n------------------------------------------------------------------------------------------------------------------------')
print("Choose the image file (once selected, the new set of images will take a few seconds to appear):")
print(" ")
@interact
def show_QC_results(file = sorted(os.listdir(Source_QC_folder))):
  visualise_image_comparison_QC(image = file, dimension='2D', Source_folder=Source_QC_folder , Prediction_folder= prediction_QC_folder, Ground_truth_folder=Target_QC_folder, QC_folder=full_QC_model_path+"/Quality Control/Maps", QC_scores= df )  

print('\n------------------------------------------------------------------------------------------------------------------------')


# Export a PDF with a summary of the QC results
qc_pdf_export()

# **6. Using the trained model**

---

<font size = 4>In this section, unseen data is processed using the model trained in section 4. First, your unseen images are loaded and prepared for prediction. Then the trained model is applied to them, and the predictions are saved to your Google Drive.

## **6.1. Generate prediction(s) from unseen dataset**
---
<font size = 4>The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the path to the model folder to use, together with its down factor. Predicted output images are saved in your **Result_folder** folder with the same file names (and hence the same file format) as the input images.

<font size = 4>**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.

<font size = 4>**`Result_folder`:** This folder will contain the predicted output images.
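<font size = 4>Before prediction, each 8-bit image is rescaled to [0, 1] and reshaped to a channels-first `(1, H, W)` float32 array, which the cell below then wraps with `torch.from_numpy` and passes to the generator. The NumPy-only sketch below reproduces that conversion and its inverse (hypothetical helper names). The inverse clips to [0, 1] as the quality-control cell does; the prediction cell itself does not clip and relies on `save_img` rescaling the output.

```python
import numpy as np

def to_model_input(image_uint8):
    """8-bit (H, W) image -> float32 (1, H, W) array in [0, 1], channels first."""
    image = np.asarray(image_uint8).astype(np.float32) / 255.0
    image = np.expand_dims(image, axis=-1)    # (H, W) -> (H, W, 1)
    return image.transpose((2, 0, 1))         # (H, W, 1) -> (1, H, W)

def from_model_output(prediction):
    """Generator output of shape (1, H, W) or (1, 1, H, W) -> 8-bit (H, W) image."""
    prediction = np.clip(np.squeeze(np.asarray(prediction)), 0.0, 1.0)
    return (prediction * 255.0).round().astype(np.uint8)

img = np.array([[0, 255], [128, 64]], dtype=np.uint8)
x = to_model_input(img)
print(x.shape, x.dtype)
```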

In [None]:
#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.

Data_folder = "" #@param {type:"string"}
Result_folder = "" #@param {type:"string"}


#@markdown ###Does **Data_folder** contain *High resolution* images?
#@markdown ##### If your data only contains *High resolution* images, check the option below. The *Low resolution* images will be generated from the *High resolution* data given in the **Data_folder**.    

Are_high_resolution_images = False #@param {type:"boolean"}

Type_of_data = "Flourescence" #@param ["Electron microscopy", "Flourescence"]

#@markdown ##### Besides the predictions, the generated *Low resolution* images can also be saved by checking the option below and providing the folder where they should be stored.

Save_generated_low_resolution_images = False #@param {type:"boolean"}
Low_resolution_folder = "" #@param{type:"string"}


# model name and path
#@markdown ###Do you want to use the current trained model?
Use_the_current_trained_model = True #@param {type:"boolean"}

#@markdown ###If not, provide the name of the model and path to model folder:
#@markdown #####During training, the model files are automatically saved inside a folder named after model_name in section 3. Provide the path to this folder below.
Prediction_model_folder = "" #@param {type:"string"}
Prediction_down_factor = 2 #@param {type:"number"}

prediction_QC_folder = os.path.join(Result_folder, 'Quality Control', 'Prediction')

## Remove directory. Better avoid this.
# if os.path.exists(Result_folder):
#   shutil.rmtree(Result_folder)
# os.makedirs(Result_folder)
if not os.path.exists(Result_folder):
  os.makedirs(Result_folder)
else:
  print(bcolors.WARNING+'!! WARNING: The results folder already exists and files may be overwritten !!')



#Here we find the loaded model name and parent path
Prediction_model_name = os.path.basename(Prediction_model_folder)
Prediction_model_path = os.path.dirname(Prediction_model_folder)

if (Use_the_current_trained_model): 
    print("Using current trained network")
    Prediction_model_name = Model_name
    Prediction_model_path = Model_path
else:
    Down_factor = Prediction_down_factor

full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)
if os.path.exists(full_Prediction_model_path):
    print("The "+Prediction_model_name+" network will be used.")
else:
    W  = '\033[0m'  # white (normal)
    R  = '\033[31m' # red
    print(R+'!! WARNING: The chosen model does not exist !!'+W)
    print('Please make sure you provide a valid model path and model name before proceeding further.')

# Read the list of file names


_, extension = os.path.splitext(os.listdir(Data_folder)[0])

filenames = [x for x in os.listdir( Data_folder ) if x.endswith(extension)]
filenames.sort()
print( 'Available images: ' + str( len(filenames)) )

model = WGANGP(
    g_layers=15, 
    d_layers=5,
    batchsize=1,
    patch_size=256,
    down_factor=Down_factor,
    recloss=100.0,
    hr_imgs_basedir=Data_folder,
    lr_imgs_basedir=Data_folder,
    gen_checkpoint = full_Prediction_model_path + '/best_checkpoint.pth',
    only_predict = True
)

gen = model.generator

del model

#model.cuda()

if Save_generated_low_resolution_images:
    if not os.path.exists(Low_resolution_folder):
        print("The "+Low_resolution_folder+ " path does not exist, it will be generated.")
        os.makedirs(Low_resolution_folder)

print("Predicting...")

for i in tqdm(filenames):
    image = img_as_ubyte( io.imread( Data_folder + '/' + i ) )

    if Are_high_resolution_images:
        if Type_of_data == "Electron microscopy":
            image = em_crappify(image, Down_factor)
        else:
            image = fluo_crappify(image, Down_factor)
    else:
        image = image/255 # normalize between 0 and 1

    if Save_generated_low_resolution_images:
        tf.keras.preprocessing.image.save_img(Low_resolution_folder+'/'+i, image, data_format=None, file_format=None)  # i is already the file name

    image = np.expand_dims(image, axis=-1)
    image = image.astype(np.float32)
    image = image.transpose((2, 0, 1))
    image = torch.from_numpy(image)

    prediction = gen(image)
    prediction = np.expand_dims(np.squeeze(prediction.detach().numpy()),axis=-1)

    tf.keras.preprocessing.image.save_img(Result_folder+'/'+i, prediction)
    del image
    del prediction

print("Images saved into folder:", Result_folder)


## **6.2. Inspect the predicted output**
---



In [None]:
# @markdown  ##Run this cell to see some inputs with their corresponding outputs

# ------------- For display ------------
print('---------------------------------------------------------------------------------------------------------------------------------------------------')
print("Choose the image file (Once selected, the new set of images will take a few seconds to appear):")
print(" ")
@interact
def show_prediction_results(file = sorted(os.listdir(Data_folder))):
  imageLR = io.imread(os.path.join(Data_folder, file))
  imageP = io.imread(os.path.join(Result_folder, file))

  plt.figure(figsize=(25,25))
  plt.subplot(3, 3, 1)
  plt.imshow( imageLR, 'gray' )
  plt.title( 'Low resolution' )
  # and its prediction
  plt.subplot(3, 3, 2)
  plt.imshow( imageP, 'gray' )
  plt.title( 'Prediction' ) 

print('---------------------------------------------------------------------------------------------------------------------------------------------------')


## **6.3. Download your predictions**
---

<font size = 4>**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name.
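<font size = 4>Downloading a folder with many files from Google Drive can be slow, so packing it into a single archive first is often easier. Below is a minimal sketch using the standard library's `shutil.make_archive`; the helper name and example paths are hypothetical. Write the archive outside the folder being packed, otherwise the archive could end up inside itself.

```python
import os
import shutil

def archive_folder(folder, archive_base):
    """Pack `folder` into '<archive_base>.zip' and return the path of the archive."""
    if not os.path.isdir(folder):
        raise FileNotFoundError(folder)
    # root_dir makes the paths inside the archive relative to `folder`
    return shutil.make_archive(archive_base, "zip", root_dir=folder)

# Example (hypothetical paths):
# zip_path = archive_folder(Result_folder, "/content/drive/MyDrive/wgan_predictions")
```

<font size = 4>In Colab, the resulting file can then be downloaded with `from google.colab import files` followed by `files.download(zip_path)`.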

# **7. Version log**
---




<font size = 4>**v1.15.1**:  
*  Fixed the PDF generation (fpdf2, delimiter, Courier)
*  Removed the paths to /content

<font size = 4>**v1.15.1**:  
*   First version of this notebook


# **Thank you for using WGAN!**