<a href="https://colab.research.google.com/github/billsioros/thesis/blob/master/Nanorough_surface_Super_resolution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ✔️ Prerequisites

First of all we need to take care of a few **prerequisites**, most notably:

- Install the various pip modules that we will be using.
- Install some Linux-specific dependencies of our [content loss](#content-loss).
- Initialize the Random Number Generator(s), so that our experiments can be replicated.
- Determine:
  - The current working directory, as it is going to be used to reference various files, such as the dataset and our model checkpoints.
  - The available hardware backend. GPU utilization is preferable, as it results in lower completion time.
- `(Optionally)` Mount Google Drive, from which we can load our dataset.

## Installing [graphviz](https://graphviz.org/) & [libgraphviz-dev](https://packages.debian.org/jessie/libgraphviz-dev)

The aforementioned packages are required by [PyINSECT](https://github.com/billsioros/PyINSECT/tree/implementing-HPGs) and more specifically its graph plotting methods.

In [81]:
!sudo apt-get install graphviz libgraphviz-dev

Reading package lists... Done
Building dependency tree       
Reading state information... Done
graphviz is already the newest version (2.40.1-2).
libgraphviz-dev is already the newest version (2.40.1-2).
0 upgraded, 0 newly installed, 0 to remove and 39 not upgraded.


## Installing the required `pip` modules

- [torch](https://pytorch.org/) is our machine learning framework of choice.
- [numpy](https://numpy.org/), [sympy](https://www.sympy.org/en/index.html) and [scipy](https://www.scipy.org/) are used in the context of nanorough surface generation.
- [plotly](https://plotly.com/) (which requires [pandas](https://pandas.pydata.org/)) as well as [matplotlib](https://matplotlib.org/) are used in order to plot various graphs.

In [82]:
!pip install torch numpy sympy scipy plotly pandas scikit-learn matplotlib==3.1.1 git+https://github.com/billsioros/PyINSECT.git@FEATURE_Implementing_HPGraphCollector

Collecting git+https://github.com/billsioros/PyINSECT.git@FEATURE_Implementing_HPGraphCollector
  Cloning https://github.com/billsioros/PyINSECT.git (to revision FEATURE_Implementing_HPGraphCollector) to /tmp/pip-req-build-xiui7qd3
  Running command git clone -q https://github.com/billsioros/PyINSECT.git /tmp/pip-req-build-xiui7qd3
  Running command git checkout -b FEATURE_Implementing_HPGraphCollector --track origin/FEATURE_Implementing_HPGraphCollector
  Switched to a new branch 'FEATURE_Implementing_HPGraphCollector'
  Branch 'FEATURE_Implementing_HPGraphCollector' set up to track remote branch 'FEATURE_Implementing_HPGraphCollector' from 'origin'.
Building wheels for collected packages: PyINSECT
  Building wheel for PyINSECT (setup.py) ... done
  Created wheel for PyINSECT: filename=PyINSECT-0.0.39-cp37-none-any.whl size=22984 sha256=d5a8df3ae4e03f8b21c9966caed2a6e23e54bdb79fb317c3b26a028976e4e5f9
  Stored in directory: /tmp/pip-ephem-wheel-cache-cy7euvic/wheels/92/7c/e

## Initializing (a.k.a `Seeding`) the Random Number Generator(s)

We are required to seed the various random number generation engines, so that our experiments can be replicated at a later date.

In [83]:
SEED = 1234

In [84]:
import torch
import numpy as np
import random
import os

if SEED is not None:
  np.random.seed(SEED)
  random.seed(SEED)
  torch.manual_seed(SEED)
  torch.cuda.manual_seed(SEED)
  torch.backends.cudnn.deterministic = True
  os.environ['PYTHONHASHSEED'] = str(SEED)
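
As a quick sanity check of the idea, the minimal sketch below seeds two of the generators used above (`random` and `numpy`), draws a few values, re-seeds and draws again. The helper name `draw` is hypothetical and exists only for this illustration.

```python
import random

import numpy as np

SEED = 1234  # same value as in the cell above


def draw(seed):
    """Seed both generators, then draw a few numbers from each."""
    random.seed(seed)
    np.random.seed(seed)
    return random.random(), np.random.rand(3)


first_py, first_np = draw(SEED)
second_py, second_np = draw(SEED)

# Re-seeding with the same value reproduces exactly the same draws
assert first_py == second_py
assert np.array_equal(first_np, second_np)
```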

## Determining the Current Working Directory

In [85]:
from pathlib import Path

BASE_DIR = Path.cwd()

## Mounting Google Drive

In [86]:
GDRIVE_DIR = BASE_DIR / 'drive'

In [87]:
from google.colab import drive

drive.mount(f'{GDRIVE_DIR}')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Determining available backend

We fall back to the CPU backend whenever no GPU is available.

In [88]:
device = "cpu"
if torch.cuda.is_available():
  device = "cuda:0"

In [89]:
device = torch.device(device)

## Configuring our Loggers

In [90]:
import logging

In [91]:
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

# 🚙 General Purpose Utilities

Here we define the *general purpose* functions and classes that we are going to be using in the following sections. At the moment, these are three [decorator](https://realpython.com/primer-on-python-decorators/) functions.

## A debugging decorator

The following decorator emits details regarding the decorated function's calls. In more detail, the information emitted is:
- The function's name.
- Its positional and keyword arguments for the function call at hand.
- Its return value.
- Any exception that the function `raises`.

In addition to that, the `debug` decorator passes a special boolean keyword argument by the name `debug`, if and only if it is included in the function signature with a default value. You can then utilize this argument inside the decorated function and emit additional information.

In [92]:
from functools import wraps
import inspect

def debug(method):
    signature = inspect.signature(method)

    defaults = {
      k: v.default
      for k, v in signature.parameters.items()
      if v.default is not inspect.Parameter.empty
    }

    @wraps(method)
    def wrapper(*args, **kwargs):
        # Positional arguments first, followed by the defaulted keyword arguments
        called_with = ', '.join(
            [str(x) for x in args]
            + [f"{x}={kwargs.get(x, defaults[x])}" for x in defaults]
        )

        if 'debug' in defaults and 'debug' not in kwargs:
          kwargs['debug'] = True

        try:
          rv = method(*args, **kwargs)
        except Exception as e:
          print(f"{method.__name__}({called_with}) raised {e}")
          raise

        print(f"{method.__name__}({called_with}) returned {rv}")

        return rv

    return wrapper
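
The decorator relies on `inspect.signature` to discover which parameters declare a default value. The minimal sketch below isolates that step; the function `example` and the variable `accepts_debug` are hypothetical and serve only as an illustration.

```python
import inspect


def example(x, y=2, *, debug=False):
    """Hypothetical function used only for illustration."""
    return x + y


signature = inspect.signature(example)

# Keep only the parameters that declare a default value
defaults = {
    name: parameter.default
    for name, parameter in signature.parameters.items()
    if parameter.default is not inspect.Parameter.empty
}

print(defaults)  # {'y': 2, 'debug': False}

# The decorator injects `debug=True` only when this lookup succeeds
accepts_debug = "debug" in defaults
```

Note that, with this approach, a `debug` parameter declared without a default value would not be picked up, since it never makes it into `defaults`.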

## A benchmarking decorator

The following decorator measures the decorated function's execution time. We use it to benchmark our various approaches and compare their efficiency.

In [93]:
from functools import wraps
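from time import perf_counter

def benchmark(method):
  @wraps(method)
  def wrapper(*args, **kwargs):
    # perf_counter is a monotonic, high-resolution clock intended for measuring intervals
    beg = perf_counter()
    rv = method(*args, **kwargs)
    end = perf_counter()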

    print("%s returned after %7.3f seconds" % (method.__name__, (end - beg)))

    return rv

  return wrapper

## A multi-instance decorator

The following decorator, given a multidimensional matrix, applies the decorated function to every row of the provided matrix and returns a one-dimensional tensor consisting of the return values of all the calls. If the matrix has no more than `expected_ndim` dimensions, the decorated function is applied to it directly and a single value is returned.

In [506]:
from functools import wraps

def per_row(method=None, *, expected_ndim=2):
  def wrapper(method):
    @wraps(method)
    def wrapper_wrapper(self, matrix, *args, **kwargs):
      if len(matrix.shape) > expected_ndim:
        return torch.tensor([method(self, row, *args, **kwargs) for row in matrix])
      
      return method(self, matrix, *args, **kwargs)
    
    return wrapper_wrapper
  
  return wrapper if method is None else wrapper(method)
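
The dispatch rule can be illustrated in isolation. The sketch below is not the decorator itself: it is a simplified, hypothetical stand-alone function (`per_row_sketch`) that uses NumPy instead of `torch` and takes the function as an argument, but it follows the same "more dimensions than expected means iterate over rows" logic.

```python
import numpy as np


def per_row_sketch(function, matrix, expected_ndim=2):
    """Apply `function` to each row if `matrix` has more dimensions than expected."""
    if matrix.ndim > expected_ndim:
        return np.array([function(row) for row in matrix])

    return function(matrix)


batch = np.arange(12, dtype=float).reshape(3, 2, 2)  # three 2x2 "surfaces"

per_surface = per_row_sketch(np.sum, batch)  # one value per 2x2 surface
single = per_row_sketch(np.sum, batch[0])    # a single 2x2 surface, one value
```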

# 🔢 Metrics

Here we define various metrics that are going to be used throughout our implementation.

## Correlation

In [450]:
def correlation(z_ngs):
  N = z_ngs.shape[0]
  
  rdif, hhcf1d = np.arange(N // 2), np.zeros(N // 2)

  for ndif in range(N // 2):
    surf1 = z_ngs[:N, :(N - ndif)]
    surf2 = z_ngs[:N, ndif:N]
    difsur2 = (surf1 - surf2) ** 2
    hhcf1d[ndif] = np.sqrt(np.mean(np.mean(difsur2)))
  
  return rdif, hhcf1d
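
A simple way to build intuition for this metric is to evaluate it on a surface whose answer is known in advance. The sketch below re-states the same computation in a compact form (the name `height_height_correlation` is hypothetical, and a square 2D array is assumed) and applies it to a linear ramp, where the height difference at lag `r` is exactly `r` everywhere.

```python
import numpy as np


def height_height_correlation(z):
    """Same computation as `correlation` above, on a plain square 2D array."""
    n = z.shape[0]
    lags = np.arange(n // 2)
    values = np.array([
        np.sqrt(np.mean((z[:, : n - lag] - z[:, lag:]) ** 2)) for lag in lags
    ])
    return lags, values


n = 8
ramp = np.tile(np.arange(n, dtype=float), (n, 1))  # z[i, j] = j

lags, values = height_height_correlation(ramp)

# On a ramp, the root mean square height difference at lag r equals r
assert np.allclose(values, lags)
```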

# 📈 Plotting Utilities


Here we define our plotting mechanisms. We are going to be using both [plotly](https://plotly.com/) as well as [matplotlib](https://matplotlib.org/), depending on the situation. `plotly`'s greatest benefit over `matplotlib` is that it produces interactive graphs, which support zooming and panning by default.

## Plotting the correlation

In [451]:
import plotly.express as px

def plot_correlation(array):
  x, y = correlation(array)
  
  fig = px.line(
    # title="1-D height-height correlation function",
    # x="r(nm)", y="G(r) (nm)",
    x=x, y=y,
    log_x=True, log_y=True
  )
  
  fig.update_layout(
    # title=title,
    autosize=True,
    width=500, height=500,
    # margin=dict(l=65, r=50, b=65, t=90)
  )

  fig.show()

## Plotting a surface as a 3D surface

In [452]:
import plotly.graph_objects as go

def as_3d_surface(array, autosize=False):
  fig = go.Figure(data=[go.Surface(z=array)])

  fig.update_layout(
    # title=title,
    autosize=autosize,
    width=500, height=500,
    # margin=dict(l=65, r=50, b=65, t=90)
  )

  fig.show()

## Plotting a surface as a grayscale image

In [453]:
import plotly.express as px

def as_grayscale_image(array):
  fig = px.imshow(array, color_continuous_scale='gray')
  fig.update_layout(coloraxis_showscale=False)
  fig.update_xaxes(showticklabels=False)
  fig.update_yaxes(showticklabels=False)

  fig.update_layout(
    # title=title,
    autosize=True,
    width=500, height=500,
    # margin=dict(l=65, r=50, b=65, t=90)
  )

  fig.show()

## Plotting two distributions against each other

In [454]:
import matplotlib.pyplot as plt

def plot_against(first, second, title="", xlabel="", ylabel="", labels=("", "")):
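  x = list(range(max(len(first), len(second))))

  # Plot each series against its own index range, so that series of different lengths are supported
  plt.plot(range(len(first)), first, label=labels[0])
  plt.plot(range(len(second)), second, label=labels[1])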

  plt.grid()
  plt.title(title)
  plt.xlabel(xlabel)
  plt.ylabel(ylabel)

  plt.xlim([min(x), max(x)])

  plt.legend()

  plt.show()

# ⛰️ Surface Generation

## The base class `SurfaceGenerator`

In [455]:
from abc import ABC, abstractmethod

import numpy as np
import sympy
from scipy import stats

class SurfaceGenerator(ABC):
    def __init__(
        self, n_points, rms, skewness, kurtosis, corlength_x, corlength_y, alpha
    ):
        self.n_points = n_points
        self.rms = rms
        self.skewness = skewness
        self.kurtosis = kurtosis
        self.corlength_x = corlength_x
        self.corlength_y = corlength_y
        self.alpha = alpha

        self._mean = 0
        self._length = 0

    def __str__(self):
        return f"{self.__class__.__name__}({self.n_points}, {self.rms}, {self.skewness}, {self.kurtosis}, {self.corlength_x}, {self.corlength_y}, {self.alpha})"

    def __repr__(self):
        return f"<{self}>"

    def __call__(self, length):
        self._length = length

        return self

    def __len__(self):
        return self._length

    def __iter__(self):
        for _ in range(self._length):
            yield self.generate_surface()

    def sort(self, elements):
        indices = np.argsort(elements, axis=0)

        return elements[indices], indices

    @abstractmethod
    def autocorrelation(self, tx, ty):
        raise NotImplementedError

    def generate_surface(self):
        # 1st step: Generation of a Gaussian surface

        # Determine the autocorrelation function R(tx,ty)
        R = np.zeros((self.n_points, self.n_points))

        txmin = -self.n_points // 2
        txmax = self.n_points // 2

        tymin = -self.n_points // 2
        tymax = self.n_points // 2

        dtx = (txmax - txmin) // self.n_points
        dty = (tymax - tymin) // self.n_points

        for tx in range(txmin, txmax, dtx):
            for ty in range(tymin, tymax, dty):
                R[tx + txmax, ty + tymax] = self.autocorrelation(tx, ty)

        # According to the Wiener-Khinchine theorem FR is the power spectrum of the desired profile
        FR = np.fft.fft2(R, s=(self.n_points, self.n_points))
        AMPR = np.sqrt(dtx ** 2 + dty ** 2) * abs(FR)

        # 2nd step: Generate a white noise, normalize it and take its Fourier transform
        X = np.random.rand(self.n_points, self.n_points)
        aveX = np.mean(np.mean(X))

        dif2X = (X - aveX) ** 2
        stdX = np.sqrt(np.mean(np.mean(dif2X)))
        X = X / stdX

        XF = np.fft.fft2(X, s=(self.n_points, self.n_points))

        # 3rd step: Multiply the two Fourier transforms
        YF = XF * np.sqrt(AMPR)

        # 4th step: Perform the inverse Fourier transform of YF and get the desired surface
        zaf = np.fft.ifft2(YF, s=(self.n_points, self.n_points))
        z = np.real(zaf)

        avez = np.mean(np.mean(z))
        dif2z = (z - avez) ** 2
        stdz = np.sqrt(np.mean(np.mean(dif2z)))
        z = ((z - avez) * self.rms) / stdz

        # Define the fraction of the surface to be analysed
        xmin = 0
        xmax = self.n_points
        ymin = 0
        ymax = self.n_points
        z_gs = z[xmin:xmax, ymin:ymax]

        # 2nd step: Generation of a non-Gaussian noise NxN
        z_ngn = stats.pearson3.rvs(
            self.skewness,
            loc=self._mean,
            scale=self.rms,
            size=(self.n_points, self.n_points),
        )

        # as_grayscale_image(z_ngn)
        # 3rd step: Combination of z_gs with z_ngn to output a z_ms
        v_gs = z_gs.flatten(order="F")
        v_ngn = z_ngn.flatten(order="F")

        Igs = np.argsort(v_gs)

        vs_ngn = np.sort(v_ngn)

        v_ngs = np.zeros_like(vs_ngn)
        v_ngs[Igs] = vs_ngn

        z_ngs = np.asmatrix(v_ngs.reshape(self.n_points, self.n_points, order="F")).H

        return z_ngs
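
The last step of `generate_surface` is a rank-order mapping: the sorted non-Gaussian noise values are placed at the positions that the Gaussian surface ranks from lowest to highest. The minimal sketch below reproduces only that step on small 1D stand-ins; the random inputs are hypothetical and are not drawn the way the generator draws them.

```python
import numpy as np

rng = np.random.default_rng(0)

v_gs = rng.normal(size=16)        # stand-in for the flattened Gaussian surface
v_ngn = rng.exponential(size=16)  # stand-in for the flattened non-Gaussian noise

order = np.argsort(v_gs)          # positions of v_gs, from lowest to highest value
v_ngs = np.zeros_like(v_ngn)
v_ngs[order] = np.sort(v_ngn)     # k-th smallest noise value goes where v_gs is k-th smallest

# Same values as the noise, arranged in the rank order of the Gaussian surface
assert np.array_equal(np.sort(v_ngs), np.sort(v_ngn))
assert np.array_equal(np.argsort(v_ngs), order)
```

The result therefore inherits its height distribution from the noise and its spatial arrangement from the Gaussian surface.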

## A simple non-gaussian surface generator

In [456]:
#@title
class NonGaussianSurfaceGenerator(SurfaceGenerator):
    def __init__(self, n_points=128, rms=1, skewness=0, kurtosis=3, corlength_x=4, corlength_y=4, alpha=1):
        super().__init__(n_points, rms, skewness, kurtosis, corlength_x, corlength_y, alpha)

    def autocorrelation(self, tx, ty):
        return ((self.rms ** 2) * np.exp(-(abs(np.sqrt((tx / self.corlength_x) ** 2 + (ty / self.corlength_y) ** 2))) ** (2 * self.alpha)))
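
Read off the code above, the autocorrelation function can be written in closed form. The symbols are our own shorthand and not part of the code: $\sigma$ stands for `rms`, $\xi_x$ and $\xi_y$ for `corlength_x` and `corlength_y`, and $\alpha$ for `alpha`. Since the square root is raised to the power $2\alpha$, the two simplify to a single exponent $\alpha$:

```latex
R(\tau_x, \tau_y) = \sigma^2 \exp\left( -\left[ \left(\frac{\tau_x}{\xi_x}\right)^2 + \left(\frac{\tau_y}{\xi_y}\right)^2 \right]^{\alpha} \right)
```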

### Example

In [457]:
generate = NonGaussianSurfaceGenerator()

In [458]:
for surface in generate(1):
  as_grayscale_image(surface)
#   as_3d_surface(surface)
#   plot_correlation(surface)

## A Bessel function based non-gaussian surface generator

In [459]:
class BeselNonGaussianSurfaceGenerator(NonGaussianSurfaceGenerator):
    def __init__(self, n_points=128, rms=1, skewness=0, kurtosis=3, corlength_x=4, corlength_y=4, alpha=1, beta_x=1, beta_y=1):
        super().__init__(n_points, rms, skewness, kurtosis, corlength_x, corlength_y, alpha)

        self.beta_x, self.beta_y = beta_x, beta_y

    def autocorrelation(self, tx, ty):
        return super().autocorrelation(tx, ty) * sympy.besselj(0, (2 * np.pi * np.sqrt((tx / self.beta_x) ** 2 + (ty / self.beta_y) **2)))

### Example

In [460]:
besel_generate = BeselNonGaussianSurfaceGenerator(128, 1, 0, 3, 16, 16, 0.5, 4000, 4000)

In [461]:
for surface in besel_generate(1):
  as_grayscale_image(surface)
#   as_3d_surface(surface)
#   plot_correlation(surface)

# 🔄 Dataset Loading and Preprocessing

## Defining the preprocessing pipeline

In [462]:
from abc import ABC, abstractmethod

class Transform(ABC):
  def __init__(self, *args, **kwargs):
    pass
  
  @abstractmethod
  def __call__(self, *args, **kwargs):
    raise NotImplementedError

In [463]:
from torch import flatten

class Flatten(Transform):
  def __call__(self, tensor):
    return flatten(tensor)

In [464]:
class To(Transform):
  def __init__(self, device):
    self.device = device

  def __call__(self, tensor):
    return tensor.to(self.device)

In [465]:
class Normalize(Transform):
  def callback(self, dataset):
    self.min = torch.min(dataset.surfaces).item()
    self.max = torch.max(dataset.surfaces).item()

  def __call__(self, tensor):
    if self.max - self.min > 0:
      return (tensor - self.min) / (self.max - self.min)

    return torch.zeros_like(tensor)

In [466]:
class View(Transform):
  def __init__(self, *args):
    self.args = args
  
  def __call__(self, tensor):
    return tensor.view(*self.args)

In [None]:
import numpy as np
# FIXME: work in progress. Tiles every row of the input 4 times along axis 1
class Pad(Transform):
  def __call__(self, tensor):
    return np.apply_along_axis(lambda row: np.tile(row, 4), 1, tensor)

## Defining the `Dataset` loading procedure

## The base `NanoroughSurfaceDataset` class

In [467]:
from torch.utils.data.dataset import  Dataset
import torch
import numpy as np

class NanoroughSurfaceDataset(Dataset):
  """A dataset of pre-generated nanorough surfaces"""
  def __init__(self, surfaces, subsampling_factor=4, transforms=[]):
    self.surfaces = np.array(surfaces)
    self.surfaces = torch.from_numpy(self.surfaces)

    self.subsampling_factor = subsampling_factor
    self.subsampling_value = int(surfaces[0].shape[1] / subsampling_factor)

    self.transforms = transforms

    for transform in self.transforms:
      if hasattr(transform, 'callback'):
        transform.callback(self)

  def __len__(self):
    return len(self.surfaces)
  
  def __getitem__(self, idx):
    surface = self.surfaces[idx]

    for transform in self.transforms:
      surface = transform(surface)

    return surface

## A dataset of pre-generated nanorough surfaces in `.mat` format

In [468]:
import scipy.io as sio
import itertools

class NanoroughSurfaceMatLabDataset(NanoroughSurfaceDataset):
  """A dataset of pre-generated nanorough surfaces in `.mat` format"""
  def __init__(self, surface_dir, subsampling_factor=4, variable_name='data', transforms=[], limit=None):
    assert surface_dir.is_dir(), "%s does not exist or is not a directory" % (surface_dir,)

    surfaces = []
    for file in itertools.islice(surface_dir.iterdir(), limit):
      if file.is_dir() or file.suffix != '.mat':
        continue

      surfaces.append(self.from_matlab(file, variable_name))

    super().__init__(surfaces, subsampling_factor=subsampling_factor, transforms=transforms)

  @classmethod
  def from_matlab(cls, path_to_mat, variable_name):
    matlab_array = sio.loadmat(str(path_to_mat))
    numpy_array = matlab_array[variable_name]
    
    return numpy_array

# 🏋️ Training and Testing Utilities

## Performing a single training epoch

In [870]:
def per_epoch(generator, discriminator, dataloader, optimizer_generator, optimizer_discriminator, criterion, content_loss=None, loss_weights=None, log_every_n=None, debug=False):
  generator.train()

  if content_loss is None:
    content_loss_weight, criterion_weight = 0, 1
  else:
    content_loss_weight, criterion_weight = loss_weights

  generator_loss, discriminator_loss, discriminator_output_real, discriminator_output_fake = 0, 0, 0, 0
  for train_iteration, X_batch in enumerate(dataloader):
    if log_every_n is not None and not train_iteration % log_every_n:
      print(f"Training Iteration #{train_iteration:04d}")

    # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
    ## Train with all-real batch
    discriminator.zero_grad()
    # Format batch
    label = torch.full((X_batch.size(0),), 1, dtype=X_batch.dtype, device=X_batch.device)
    # Forward pass real batch through D
    output = discriminator(X_batch, debug=debug).view(-1)
    # Calculate loss on all-real batch
    discriminator_error_real = criterion(output, label)
    # Calculate gradients for D in backward pass
    discriminator_error_real.backward()
    discriminator_output_real_batch = output.mean().item()

    ## Train with all-fake batch
    # Generate batch of latent vectors
    noise = torch.randn(X_batch.size(0), *generator.feature_dims, dtype=X_batch.dtype, device=X_batch.device)
    # Generate fake image batch with G
    fake = generator(noise, debug=debug)
    label.fill_(0)
    # Classify all fake batch with D
    output = discriminator(fake.detach(), debug=debug).view(-1)
    # Calculate D's loss on the all-fake batch
    discriminator_error_fake = criterion(output, label)
    # Calculate the gradients for this batch
    discriminator_error_fake.backward()
    # Add the gradients from the all-real and all-fake batches
    discriminator_error_total = discriminator_error_real + discriminator_error_fake
    # Update D
    optimizer_discriminator.step()

    # (2) Update G network: maximize log(D(G(z)))
    generator.zero_grad()
    label.fill_(1)  # fake labels are real for generator cost
    # Since we just updated D, perform another forward pass of all-fake batch through D
    output = discriminator(fake, debug=debug).view(-1)
    # Calculate G's loss based on this output
    if content_loss_weight <= 0:
      discriminator_error_fake = criterion(output, label)
    else:
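      # The content loss is computed on a detached NumPy copy, so it is not part of the
      # autograd graph: it shifts the reported loss value but contributes no gradients to G
      generator_content_loss = content_loss(fake.cpu().detach().numpy().squeeze())
      generator_content_loss = torch.mean(generator_content_loss).to(fake.device)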

      discriminator_error_fake = content_loss_weight * generator_content_loss + criterion_weight * criterion(output, label)
    # Calculate gradients for G, which propagate through the discriminator
    discriminator_error_fake.backward()
    discriminator_output_fake_batch = output.mean().item()
    # Update G
    optimizer_generator.step()

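    # Accumulate plain Python floats, so that the autograd graphs of past batches can be freed
    generator_loss += discriminator_error_fake.item() / len(dataloader)
    discriminator_loss += discriminator_error_total.item() / len(dataloader)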
    discriminator_output_real += discriminator_output_real_batch / len(dataloader)
    discriminator_output_fake += discriminator_output_fake_batch / len(dataloader)
    
  return generator_loss, discriminator_loss, discriminator_output_real, discriminator_output_fake

## Splitting the original dataset into Training and Testing subsets

In [817]:
from torch.utils.data import random_split

def train_test_split(dataset, train_ratio=0.8):
  train_size = int(len(dataset) * train_ratio)
  test_size = len(dataset) - train_size

  train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

  return train_dataset, test_dataset
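
The subset sizes follow from simple arithmetic: the training size is rounded down and the remainder goes to the testing subset, so no sample is ever dropped. The helper `split_sizes` below is hypothetical and only mirrors that arithmetic.

```python
def split_sizes(n_samples, train_ratio=0.8):
    """Return (train_size, test_size) the way `train_test_split` computes them."""
    train_size = int(n_samples * train_ratio)  # rounds down
    test_size = n_samples - train_size         # the remainder goes to the test subset
    return train_size, test_size


print(split_sizes(103))  # (82, 21)
print(split_sizes(10))   # (8, 2)
```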

## Configuring the Training and Testing `DataLoader`s

In [818]:
from torch.utils.data import DataLoader

def train_test_dataloaders(dataset, train_ratio=0.8, **kwargs):
  train_dataset, test_dataset = train_test_split(dataset, train_ratio=train_ratio)

  train_dataloader = DataLoader(train_dataset, **kwargs)
  test_dataloader = DataLoader(test_dataset, **kwargs)

  return train_dataloader, test_dataloader

## The Training Manager

In [819]:
class Configuration:
  def __init__(self, **kwargs):
    for key, value in kwargs.items():
      if isinstance(value, dict):
        value = Configuration(**value)

      setattr(self, key, value)
  
  def to_dict(self):
    rv = {}
    for key, value in self.__dict__.items():
      if isinstance(value, Configuration):
        value = value.to_dict()
      
      rv[key] = value
    
    return rv
  
  def __str__(self):
    return str(self.to_dict())

  def __repr__(self):
    return f"<{self.__class__.__name__} '{str(self)}'>"

In [820]:
from torch.optim import Adam
from functools import partial

class TrainingManager(Configuration):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)

    if not hasattr(self, "debug"):
      self.debug = False

    if not hasattr(self, "verbose"):
      self.verbose = False

    if not hasattr(self, "benchmark"):
      self.benchmark = False

    if not hasattr(self, "log_every_n"):
      self.log_every_n = None
    
    if not hasattr(self, "checkpoint_dir"):
      self.checkpoint_dir = None
    
    if self.checkpoint_dir is not None:
      if not hasattr(self, "checkpoint_multiple"):
        self.checkpoint_multiple = False
    
    if not hasattr(self, "content_loss"):
      self.content_loss = None

    if isinstance(self.criterion, tuple):
      self.criterion, self.criterion_weight = self.criterion
    else:
      self.criterion_weight = 0.5
  
    if isinstance(self.content_loss, tuple):
      self.content_loss, self.content_loss_weight = self.content_loss
    else:
      self.content_loss_weight = 0.5

  def __call__(self, generator, discriminator, dataset):
    train_dataloader, test_dataloader = train_test_dataloaders(dataset, train_ratio=self.train_ratio, **self.dataloader.to_dict())

    optimizer_generator = Adam(generator.parameters(), **self.optimizer.to_dict())
    optimizer_discriminator = Adam(discriminator.parameters(), **self.optimizer.to_dict())

    train_epoch_f = self.train_epoch

    if self.benchmark is True:
      train_epoch_f = benchmark(train_epoch_f)

    if self.debug is True:
      tmp = train_epoch_f
      if self.verbose is True:
        train_epoch_f = debug(tmp)
      else:
        train_epoch_f = lambda *args, **kwargs: tmp(*args, **{**kwargs, "debug": True})

    generator_losses, discriminator_losses, discriminator_output_reals, discriminator_output_fakes = [], [], [], []
    for epoch in range(self.n_epochs):
      generator_loss, discriminator_loss, discriminator_output_real, discriminator_output_fake = train_epoch_f(
        generator, discriminator,
        train_dataloader,
        optimizer_generator, optimizer_discriminator,
        self.criterion,
        content_loss=self.content_loss, loss_weights=(self.content_loss_weight, self.criterion_weight),
        log_every_n=self.log_every_n
      )
      
      if self.checkpoint_dir is not None and (not generator_losses or generator_loss < min(generator_losses)):
        generator_mt, discriminator_mt = f'{generator.__class__.__name__}', f'{discriminator.__class__.__name__}'
        
        if self.checkpoint_multiple is True:
          generator_mt += f'_{epoch:03d}'
          discriminator_mt += f'_{epoch:03d}'

        generator_mt += '.mt'
        discriminator_mt += '.mt'

        torch.save(generator.state_dict(), self.checkpoint_dir / generator_mt)
        torch.save(discriminator.state_dict(), self.checkpoint_dir / discriminator_mt)

      generator_losses.append(generator_loss)
      discriminator_losses.append(discriminator_loss)
      discriminator_output_reals.append(discriminator_output_real)
      discriminator_output_fakes.append(discriminator_output_fake)
      
      if self.verbose is True:
        print("Epoch: %02d, Generator Loss: %7.3f, Discriminator Loss: %7.3f" % (epoch, generator_loss, discriminator_loss))
        print("Epoch: %02d, Discriminator Output: [Real: %7.3f, Fake: %7.3f]" % (epoch, discriminator_output_real, discriminator_output_fake))
    
    return generator_losses, discriminator_losses, discriminator_output_reals, discriminator_output_fakes

<a name="content-loss"></a>
# 💸 Designing a content loss function

## Quantizing our input data

The `Quantizer` sub-classes are responsible for quantizing our input data, which consists of floating point values. These values are going to serve as symbols for the **n-gram graph representation**, so they must first be mapped onto a small, finite set of symbols; treating every distinct floating point value as its own symbol would yield a practically unbounded alphabet.

In [821]:
from abc import abstractmethod

class Quantizer(Configuration):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)
  
  @abstractmethod
  def __call__(self, tensor):
    raise NotImplementedError

In [822]:
from sklearn.preprocessing import KBinsDiscretizer

class KBinsDiscretizerQuantizer(Configuration):
  def __init__(self, surfaces=None, **kwargs):
    if 'encode' not in kwargs:
      kwargs['encode'] = 'ordinal'

    self.underlying = KBinsDiscretizer(**kwargs)

    self.original_shape = surfaces.shape[1:]

    self.surfaces = self.underlying.fit_transform(surfaces.reshape(surfaces.shape[0], -1))
    self.surfaces = self.surfaces.reshape(*surfaces.shape)

  def __call__(self, tensor):
    return self.underlying.transform(tensor.reshape(1, -1)).reshape(*self.original_shape)
  
  def __str__(self):
    return str({
        'underlying': self.underlying,
        'shape': self.surfaces.shape
    })
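
To make the idea of quantile based quantization concrete, here is a minimal NumPy-only sketch. It is an approximation of the idea rather than a re-implementation of `KBinsDiscretizer`: the latter fits its bins per feature (column), whereas this sketch bins a single 1D array of hypothetical height values.

```python
import numpy as np

n_bins = 5
values = np.linspace(0.0, 1.0, 20)  # hypothetical height values

# Interior bin edges placed at the 20%, 40%, 60% and 80% quantiles
edges = np.quantile(values, np.linspace(0, 1, n_bins + 1)[1:-1])

# Each value is replaced by the index of the bin it falls into (0 .. n_bins - 1)
symbols = np.digitize(values, edges)

print(np.bincount(symbols))  # [4 4 4 4 4]
```

With quantile edges every symbol is used roughly equally often, which keeps the symbol alphabet small and balanced.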

### Example Usage

In [823]:
tensors = torch.rand(10, 4, 4)

In [824]:
tensors

tensor([[[0.2272, 0.3859, 0.8938, 0.2223],
         [0.8264, 0.1611, 0.4622, 0.9708],
         [0.5817, 0.9566, 0.5254, 0.5117],
         [0.0281, 0.7803, 0.2970, 0.0325]],

        [[0.7418, 0.7476, 0.7340, 0.4841],
         [0.4993, 0.6643, 0.9929, 0.4214],
         [0.6197, 0.6231, 0.0436, 0.2856],
         [0.0431, 0.3364, 0.7246, 0.9227]],

        [[0.3309, 0.4782, 0.2224, 0.3792],
         [0.0661, 0.6548, 0.6756, 0.8326],
         [0.7765, 0.3025, 0.3636, 0.1533],
         [0.0426, 0.5856, 0.4252, 0.6871]],

        [[0.5881, 0.9430, 0.2360, 0.6515],
         [0.9873, 0.7596, 0.3875, 0.1417],
         [0.2552, 0.2061, 0.6752, 0.5782],
         [0.5626, 0.4235, 0.3997, 0.1353]],

        [[0.2809, 0.4551, 0.7114, 0.9169],
         [0.8249, 0.5040, 0.7811, 0.1022],
         [0.7030, 0.9000, 0.1161, 0.3014],
         [0.3909, 0.9987, 0.8447, 0.2886]],

        [[0.5549, 0.3742, 0.2603, 0.1898],
         [0.0962, 0.2662, 0.8381, 0.8291],
         [0.3315, 0.8784, 0.0055, 0.4176],
 

In [825]:
quantizer = KBinsDiscretizerQuantizer(tensors)

In [826]:
quantizer

<KBinsDiscretizerQuantizer '{'underlying': KBinsDiscretizer(encode='ordinal', n_bins=5, strategy='quantile'), 'shape': (10, 4, 4)}'>

In [827]:
tensor = quantizer(tensors[0])

In [828]:
tensor

array([[0., 1., 3., 0.],
       [3., 1., 1., 4.],
       [2., 4., 3., 2.],
       [0., 3., 0., 0.]])

## Designing our **Content Loss** interface

In [829]:
from abc import ABC, abstractmethod

class ContentLoss(Configuration):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    
    self.quantizer = KBinsDiscretizerQuantizer(**kwargs)

    self.surfaces = self.quantizer.surfaces

  @abstractmethod
  def __call__(self, surface):
    return self.quantizer(surface)

## Implementing an **n-gram graph** based content loss

In [830]:
from pyinsect.collector.NGramGraphCollector import NGramGraphCollector

class NGramGraphContentLoss(ContentLoss):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)

    self.surfaces = self.surfaces.reshape(self.surfaces.shape[0], -1)

    self._collector = NGramGraphCollector()
    
    for surface in self.surfaces:
      self._collector.add(surface)
  
  def __len__(self):
    return len(self.surfaces)
  
  @per_row(expected_ndim=1)
  def __call__(self, surface):
    return self._collector.appropriateness_of(super().__call__(surface))

  def __str__(self):
    return str({'shape': self.surfaces.shape})

### Example Usage

In [831]:
tensors = torch.rand(10, 4, 4).reshape(10, -1)

In [832]:
tensors

tensor([[0.0552, 0.3802, 0.7901, 0.2030, 0.3818, 0.1825, 0.7268, 0.2082, 0.9378,
         0.3789, 0.4485, 0.0793, 0.9778, 0.8802, 0.3579, 0.5760],
        [0.8552, 0.4507, 0.5936, 0.5000, 0.9827, 0.4072, 0.2315, 0.6974, 0.4185,
         0.1830, 0.5947, 0.4661, 0.8348, 0.7724, 0.3074, 0.0312],
        [0.7843, 0.5402, 0.5818, 0.6145, 0.4407, 0.4258, 0.8166, 0.7396, 0.4151,
         0.6283, 0.9360, 0.8064, 0.8682, 0.6643, 0.1807, 0.3297],
        [0.3893, 0.8497, 0.9841, 0.6584, 0.0937, 0.5908, 0.2150, 0.6398, 0.1869,
         0.2232, 0.0607, 0.0472, 0.2460, 0.5917, 0.3957, 0.1040],
        [0.7885, 0.0475, 0.8092, 0.8681, 0.8261, 0.3964, 0.9780, 0.7370, 0.5840,
         0.3121, 0.0569, 0.5596, 0.1319, 0.7419, 0.7878, 0.0095],
        [0.1092, 0.0336, 0.1537, 0.7776, 0.0203, 0.9537, 0.6353, 0.9993, 0.8226,
         0.6675, 0.2386, 0.6891, 0.7977, 0.5491, 0.0989, 0.3293],
        [0.9978, 0.9927, 0.2431, 0.4759, 0.8045, 0.0895, 0.5137, 0.0566, 0.3335,
         0.9775, 0.8021, 0.0031, 0.78

In [833]:
content_loss = NGramGraphContentLoss(surfaces=tensors)

In [834]:
max([content_loss(row.reshape(-1)) for row in tensors])

0.6931838093569301

In [835]:
content_loss(torch.rand(4, 4).reshape(-1))

0.06027685298755914

## Implementing a **2D array graph** based content loss

In [836]:
from pyinsect.collector.NGramGraphCollector import NGramGraphCollector
from pyinsect.structs.array_graph import ArrayGraph2D
from pyinsect.documentModel.representations import DocumentNGramGraph

class ArrayGraph2DCollector(NGramGraphCollector):
  def __init__(
    self, n=3, window_size=3, deep_copy=False, commutative=True, distributional=True, stride=1
  ):
    super().__init__(
      n=n,
      window_size=window_size,
      deep_copy=deep_copy,
      commutative=commutative,
      distributional=distributional,
    )

    # Store the caller-supplied stride instead of a hard-coded 1
    self._stride = stride

  def __str__(self):
    return "{0}, stride: {1}".format(
      super().__str__(), self._stride
    )

  def _construct_graph(self, data, *args, **kwargs):
    return ArrayGraph2D(
      data, self._window_size, stride=self._stride
    ).as_graph(DocumentNGramGraph, self._n, self._window_size)

In [837]:
class ArrayGraph2DContentLoss(ContentLoss):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)

    self._collector = ArrayGraph2DCollector()
    
    for surface in self.surfaces:
      self._collector.add(surface)
  
  def __len__(self):
    return len(self.surfaces)
  
  @per_row
  def __call__(self, surface):
    return self._collector.appropriateness_of(super().__call__(surface))

  def __str__(self):
    return str({'shape': self.surfaces.shape})

### Example Usage

In [838]:
tensors = torch.rand(10, 4, 4)

In [839]:
tensors

tensor([[[0.4946, 0.6615, 0.9520, 0.4164],
         [0.8176, 0.2058, 0.9560, 0.8269],
         [0.9653, 0.1678, 0.4448, 0.8542],
         [0.7149, 0.1618, 0.7260, 0.3033]],

        [[0.7914, 0.3050, 0.5384, 0.3472],
         [0.7971, 0.5702, 0.7402, 0.8172],
         [0.3259, 0.0706, 0.3510, 0.0108],
         [0.1771, 0.1687, 0.2226, 0.5673]],

        [[0.4956, 0.9977, 0.9787, 0.7012],
         [0.5776, 0.2108, 0.2094, 0.2103],
         [0.1312, 0.1744, 0.4950, 0.8296],
         [0.7259, 0.7932, 0.0715, 0.1923]],

        [[0.0720, 0.3741, 0.3558, 0.4536],
         [0.4310, 0.1235, 0.9796, 0.6803],
         [0.4593, 0.2945, 0.8288, 0.6707],
         [0.2772, 0.4915, 0.9320, 0.8724]],

        [[0.4505, 0.9287, 0.3501, 0.3619],
         [0.1747, 0.8492, 0.5471, 0.4012],
         [0.4967, 0.1953, 0.7368, 0.5051],
         [0.7077, 0.2939, 0.8068, 0.7698]],

        [[0.7545, 0.2291, 0.2098, 0.7403],
         [0.4889, 0.0316, 0.4731, 0.8066],
         [0.4210, 0.7971, 0.5055, 0.7556],
 

In [840]:
content_loss = ArrayGraph2DContentLoss(surfaces=tensors)

In [841]:
max([content_loss(tensors[i]) for i in range(tensors.shape[0])])

0.33563769563769563

In [842]:
content_loss(torch.rand(4, 4))

0.2533333333333333

## Implementing a **Hierarchical Proximity Graph (HPG)** based content loss

In [843]:
from pyinsect.collector.NGramGraphCollector import HPG2DCollector

class HPG2DContentLoss(ContentLoss):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)

    self._collector = HPG2DCollector()
    
    for surface in self.surfaces:
      self._collector.add(surface)

  def __len__(self):
    return len(self.surfaces)
  
  @per_row
  def __call__(self, surface):
    return self._collector.appropriateness_of(super().__call__(surface))

  def __str__(self):
    return str({'shape': self.surfaces.shape})

### Example Usage

In [844]:
tensors = torch.rand(10, 4, 4)

In [845]:
tensors

tensor([[[0.3337, 0.5888, 0.1400, 0.0690],
         [0.5492, 0.0544, 0.9653, 0.9625],
         [0.3312, 0.7115, 0.5643, 0.0771],
         [0.8369, 0.0958, 0.9946, 0.5701]],

        [[0.0685, 0.9037, 0.7746, 0.6577],
         [0.8239, 0.4319, 0.9405, 0.5094],
         [0.7709, 0.3717, 0.0264, 0.2666],
         [0.8952, 0.0907, 0.0652, 0.6515]],

        [[0.7680, 0.6458, 0.6208, 0.1817],
         [0.5021, 0.5062, 0.7803, 0.4151],
         [0.3620, 0.4597, 0.8408, 0.2037],
         [0.3498, 0.9792, 0.0883, 0.0899]],

        [[0.6784, 0.9644, 0.9122, 0.1051],
         [0.0508, 0.1099, 0.0824, 0.6277],
         [0.3986, 0.5741, 0.9761, 0.5329],
         [0.5510, 0.6237, 0.2613, 0.4392]],

        [[0.6539, 0.5767, 0.8610, 0.1679],
         [0.6808, 0.8040, 0.7693, 0.8442],
         [0.7558, 0.5916, 0.2174, 0.7612],
         [0.1074, 0.2295, 0.8896, 0.1444]],

        [[0.4598, 0.8916, 0.9436, 0.7026],
         [0.3458, 0.7180, 0.6753, 0.1231],
         [0.0279, 0.8841, 0.2126, 0.5642],
 

In [846]:
content_loss = HPG2DContentLoss(surfaces=tensors)

In [847]:
max([content_loss(tensors[i]) for i in range(tensors.shape[0])])

0.4071593915343915

In [848]:
content_loss(torch.rand(4, 4))

0.23346560846560846

# 🙃 A naive approach

## Our naive `Generator` and `Discriminator` networks

In [849]:
import torch
from torch import nn

class PerceptronGenerator(nn.Module):
  def __init__(self, in_features, out_features, dtype=torch.float64):
    super().__init__()

    self.in_features, self.out_features = in_features, out_features

    self.feature_dims = (in_features,)

    self.linear = nn.Linear(in_features, out_features)
    self.activation = nn.ReLU()

    self.to(dtype=dtype)
  
  def forward(self, batch, debug=False):
    if debug is True:
      print(f"[DEBUG]: {self.__class__.__name__}: Input {batch.shape}")
      
    batch = self.activation(self.linear(batch))

    if debug is True:
      print(f"[DEBUG]: {self.__class__.__name__}: Output {batch.shape}")
    
    return batch
  
  @classmethod
  def from_dataset(cls, dataset, dtype=torch.float64, device=None):
    in_features = dataset.subsampling_value ** 2
    out_features = (dataset.subsampling_factor * dataset.subsampling_value) ** 2

    model = cls(in_features, out_features, dtype=dtype)

    # `device` defaults to None, so guard before reading `device.type`
    if device is None:
      return model

    # Wrap in DataParallel only when more than one GPU is available
    if device.type == 'cuda' and torch.cuda.device_count() > 1:
      model = nn.DataParallel(model)

    return model.to(device)

In [850]:
import torch
from torch import nn

class PerceptronDiscriminator(nn.Module):
  def __init__(self, in_features, dtype=torch.float64):
    super().__init__()

    self.feature_dims = (in_features,)

    self.linear = nn.Linear(in_features, 1)
    self.activation = nn.Sigmoid()

    self.to(dtype=dtype)
  
  def forward(self, batch, debug=False):
    if debug is True:
      print(f"[DEBUG]: {self.__class__.__name__}: Input {batch.shape}")

    batch = self.activation(self.linear(batch))

    if debug is True:
      print(f"[DEBUG]: {self.__class__.__name__}: Output {batch.shape}")
    
    return batch
  
  @classmethod
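  def from_generator(cls, generator, dtype=torch.float64, device=None):
    model = cls(generator.out_features, dtype=dtype)

    # `device` defaults to None, so guard before reading `device.type`
    if device is None:
      return model

    # Wrap in DataParallel only when more than one GPU is available
    if device.type == 'cuda' and torch.cuda.device_count() > 1:
      model = nn.DataParallel(model)

    return model.to(device)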
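
The feature counts computed in `from_dataset` can be checked by hand. The sketch below only mirrors that arithmetic. The values `subsampling_value = 32` and `subsampling_factor = 4` are hypothetical, chosen for illustration, and are not taken from the dataset used here.

```python
def perceptron_feature_sizes(subsampling_value, subsampling_factor):
    """Mirror the arithmetic of PerceptronGenerator.from_dataset."""
    # A low resolution surface of side `subsampling_value`, flattened
    in_features = subsampling_value ** 2
    # The high resolution surface is `subsampling_factor` times larger per side
    out_features = (subsampling_factor * subsampling_value) ** 2
    return in_features, out_features


# Hypothetical values, for illustration only
in_features, out_features = perceptron_feature_sizes(32, 4)
print(in_features, out_features)  # 1024 16384

# A single nn.Linear layer holds one weight per (input, output) pair plus one bias per output
n_parameters = in_features * out_features + out_features
print(n_parameters)  # 16793600
```

Under these assumed values the single linear layer already holds roughly 16.8 million parameters, which is one reason to treat this design as a baseline only.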

## Loading the Dataset

In [851]:
DATASET_PATH = GDRIVE_DIR / 'MyDrive' / 'Thesis' / 'Datasets' / 'surfaces.zip'

In [852]:
DATASET_SIZE = 50 #FIXME

In [853]:
from zipfile import ZipFile

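# Define SURFACES_DIR unconditionally, so that the next cell can test it
# even when the archive is missing
SURFACES_DIR = BASE_DIR / 'surfaces'

if DATASET_PATH.is_file():
  SURFACES_DIR.mkdir(parents=True, exist_ok=True)

  with ZipFile(DATASET_PATH, 'r') as zip_file:
    zip_file.extractall(SURFACES_DIR)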

In [854]:
if SURFACES_DIR.is_dir():
  dataset = NanoroughSurfaceMatLabDataset(SURFACES_DIR, transforms=[Flatten(), To(device)], limit=DATASET_SIZE)
else:
  generate = NonGaussianSurfaceGenerator()
  dataset = NanoroughSurfaceDataset(list(generate(DATASET_SIZE)), transforms=[Flatten(), To(device)])

## Instantiating the content-loss

In [855]:
content_loss = NGramGraphContentLoss(surfaces=dataset.surfaces)

KeyboardInterrupt: ignored

## Instantiating the **Generator** and the **Discriminator** Networks

In [None]:
generator = PerceptronGenerator.from_dataset(dataset, device=device)

In [None]:
print(generator)

In [None]:
discriminator = PerceptronDiscriminator.from_generator(generator, device=device)

In [None]:
print(discriminator)

## Training

In [None]:
from torch.nn import BCELoss

criterion = BCELoss().to(device)
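
As a reminder of what the criterion measures, here is a minimal pure-Python sketch of binary cross entropy for a single prediction. It assumes the usual GAN labelling convention (1 for real samples, 0 for generated ones), which this section does not state explicitly. `nn.BCELoss` applies the same formula element-wise and averages over the batch by default.

```python
import math


def binary_cross_entropy(prediction, target):
    """BCE for one probability: -[y * log(p) + (1 - y) * log(1 - p)]."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


# The discriminator outputs 0.9 ("probably real") in both cases below
loss_real = binary_cross_entropy(0.9, 1.0)  # the sample was real: small loss
loss_fake = binary_cross_entropy(0.9, 0.0)  # the sample was generated: large loss

print(round(loss_real, 3), round(loss_fake, 3))  # 0.105 2.303
```

The asymmetry is the point: a confident but wrong discriminator output is penalised far more heavily than a confident and correct one.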

In [None]:
from pathlib import Path

CHECKPOINT_DIR = BASE_DIR / 'checkpoint'
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
from functools import partial
from pyinsect.documentModel.comparators.NGramGraphSimilarity import SimilarityVS

training_manager = TrainingManager(
    benchmark=True,
    verbose=False,
    debug=True,
    checkpoint_dir=CHECKPOINT_DIR,
    checkpoint_multiple=False,
    train_epoch=per_epoch,
    log_every_n=10,
    criterion=criterion,
    content_loss=content_loss,
    n_epochs=10,
    train_ratio=0.8,
    optimizer={
      'lr': 0.0005,
      'weight_decay': 0
    },
    dataloader={
      'batch_size': 256,
      'shuffle': True,
      'num_workers': 0,
    }
)
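
It can help to check how many batches these settings produce. The sketch below assumes that `TrainingManager` splits the dataset proportionally according to `train_ratio`. That is an assumption about its internals, which are not shown in this section, and `batches_per_epoch` is a hypothetical helper.

```python
import math


def batches_per_epoch(dataset_size, train_ratio, batch_size):
    """Return the number of training samples and of (possibly partial) batches."""
    n_train = round(dataset_size * train_ratio)
    return n_train, math.ceil(n_train / batch_size)


# Values taken from the cells above: DATASET_SIZE = 50, train_ratio = 0.8, batch_size = 256
n_train, n_batches = batches_per_epoch(dataset_size=50, train_ratio=0.8, batch_size=256)
print(n_train, n_batches)  # 40 1
```

Under that assumption, the placeholder `DATASET_SIZE` of 50 yields a single, partially filled batch per epoch, so the batch size only matters once the full dataset is loaded.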

In [None]:
generator_losses, discriminator_losses, discriminator_output_reals, discriminator_output_fakes = training_manager(generator, discriminator, dataset)

In [None]:
plot_against(generator_losses, discriminator_losses, title="Mean Generator vs Discriminator loss per epoch", xlabel="Epoch", ylabel="Loss", labels=("Generator", "Discriminator"))

In [None]:
plot_against(discriminator_output_reals, discriminator_output_fakes, title="Mean Discriminator Output per epoch", xlabel="Epoch", ylabel="Discriminator Output", labels=("Real Data", "Generator Data"))

# 😎 A CNN-based approach

## Our CNN-based `Generator` and `Discriminator` networks

In [856]:
import torch  # needed for the default `dtype=torch.float64`
from torch import nn

class CNNGenerator(nn.Module):
  def __init__(self, in_channels=100, out_channels=128, training_channels=1, dtype=torch.float64):
    super().__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.training_channels = training_channels

    self.feature_dims = (in_channels, 1, 1)

    self.module_list = nn.ModuleList([
      nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels * 16, 4, 1, 0, bias=False),
        nn.BatchNorm2d(out_channels * 16),
        nn.ReLU(True)
      ),
      nn.Sequential(
        nn.ConvTranspose2d(out_channels * 16, out_channels * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(out_channels * 8),
        nn.ReLU(True)
      ),
      nn.Sequential(
        nn.ConvTranspose2d(out_channels * 8, out_channels * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(out_channels * 4),
        nn.ReLU(True)
      ),
      nn.Sequential(
        nn.ConvTranspose2d(out_channels * 4, out_channels * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(out_channels * 2),
        nn.ReLU(True)
      ),
      nn.Sequential(
        nn.ConvTranspose2d(out_channels * 2, out_channels, 4, 2, 1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(True)
      ),
      nn.Sequential(
        nn.ConvTranspose2d(out_channels, training_channels, 4, 2, 1, bias=False),
        nn.ReLU()
      )                              
    ])

    self.to(dtype)

  @classmethod
  def from_device(cls, device):
    model =  cls()

    if device.type == 'cuda' and torch.cuda.device_count() > 1:
      model = nn.DataParallel(model)

    model = model.to(device)

    return model

  def forward(self, batch, debug=False):
    if debug is True:
      print(f"[DEBUG]: {self.__class__.__name__}: Input {batch.shape}")

    for i, module in enumerate(self.module_list):
      batch = module(batch)

      if debug is True:
        # Label intermediate outputs 01, 02, ... and the final one 'Output'
        label = 'Output' if i == len(self.module_list) - 1 else f"{i + 1:02d}"
        print(f"[DEBUG]: {self.__class__.__name__}:{module.__class__.__name__}: {label} {batch.shape}")

    return batch
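
The spatial size of the generator output can be verified without running the network. The sketch below applies the output-size formula for `nn.ConvTranspose2d` (with `dilation=1` and `output_padding=0`, the defaults used above) to each layer's `(kernel_size, stride, padding)`, starting from the `1 x 1` latent input implied by `feature_dims`. The helper name is hypothetical.

```python
def conv_transpose2d_out(size, kernel_size, stride, padding):
    """Output side length of ConvTranspose2d for dilation=1, output_padding=0."""
    return (size - 1) * stride - 2 * padding + kernel_size


# (kernel_size, stride, padding) of the six transposed convolutions above
layers = [(4, 1, 0)] + [(4, 2, 1)] * 5

size = 1  # the latent vector is viewed as an (in_channels, 1, 1) tensor
generator_sizes = []
for kernel_size, stride, padding in layers:
    size = conv_transpose2d_out(size, kernel_size, stride, padding)
    generator_sizes.append(size)

print(generator_sizes)  # [4, 8, 16, 32, 64, 128]
```

The first layer expands the latent vector to `4 x 4`, and each subsequent layer doubles the side length, ending at the `128 x 128` surfaces that the dataset provides.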

In [857]:
import torch  # needed for the default `dtype=torch.float64`
from torch import nn

class CNNDiscriminator(nn.Module):
  def __init__(self, in_channels=1, out_channels=128, dtype=torch.float64):
    super().__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels

    self.module_list = nn.ModuleList([
      nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True)
      ),
      nn.Sequential(
        nn.Conv2d(out_channels, out_channels * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(out_channels * 2),
        nn.LeakyReLU(0.2, inplace=True)
      ),
      nn.Sequential(
        nn.Conv2d(out_channels * 2, out_channels * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(out_channels * 4),
        nn.LeakyReLU(0.2, inplace=True)
      ),
      nn.Sequential(
        nn.Conv2d(out_channels * 4, out_channels * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(out_channels * 8),
        nn.LeakyReLU(0.2, inplace=True)
      ),
      nn.Sequential(
        nn.Conv2d(out_channels * 8, 1, 8, 1, 0, bias=False),
        nn.Sigmoid()
      )
    ])

    self.to(dtype=dtype)

  @classmethod
  def from_device(cls, device):
    model =  cls()

    if device.type == 'cuda' and torch.cuda.device_count() > 1:
      model = nn.DataParallel(model)

    model = model.to(device)

    return model

  def forward(self, batch, debug=False):
    if debug is True:
      print(f"[DEBUG]: {self.__class__.__name__}: Input {batch.shape}")

    for i, module in enumerate(self.module_list):
      batch = module(batch)

      if debug is True:
        # Label intermediate outputs 01, 02, ... and the final one 'Output'
        label = 'Output' if i == len(self.module_list) - 1 else f"{i + 1:02d}"
        print(f"[DEBUG]: {self.__class__.__name__}:{module.__class__.__name__}: {label} {batch.shape}")

    return batch
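
The same kind of check applies to the discriminator. The sketch below uses the output-size formula for `nn.Conv2d` (with `dilation=1`, the default used above) and assumes a `128 x 128` input, as produced by the `View(1, 128, 128)` transform in the dataset cells. The helper name is hypothetical.

```python
def conv2d_out(size, kernel_size, stride, padding):
    """Output side length of Conv2d for dilation=1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride, padding) of the five convolutions above
layers = [(4, 2, 1)] * 4 + [(8, 1, 0)]

size = 128  # side length of the input surface
discriminator_sizes = []
for kernel_size, stride, padding in layers:
    size = conv2d_out(size, kernel_size, stride, padding)
    discriminator_sizes.append(size)

print(discriminator_sizes)  # [64, 32, 16, 8, 1]
```

Each strided convolution halves the side length, and the final `8 x 8` kernel collapses the remaining `8 x 8` feature map into a single value per sample, which the sigmoid turns into a probability.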

## Loading the Dataset

In [858]:
DATASET_PATH = GDRIVE_DIR / 'MyDrive' / 'Thesis' / 'Datasets' / 'surfaces.zip'

In [859]:
DATASET_SIZE = 50 #FIXME

In [860]:
from zipfile import ZipFile

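# Define SURFACES_DIR unconditionally, so that the next cell can test it
# even when the archive is missing
SURFACES_DIR = BASE_DIR / 'surfaces'

if DATASET_PATH.is_file():
  SURFACES_DIR.mkdir(parents=True, exist_ok=True)

  with ZipFile(DATASET_PATH, 'r') as zip_file:
    zip_file.extractall(SURFACES_DIR)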

In [861]:
if SURFACES_DIR.is_dir():
  dataset = NanoroughSurfaceMatLabDataset(SURFACES_DIR, transforms=[To(device), View(1, 128, 128)], limit=DATASET_SIZE)
else:
  generate = NonGaussianSurfaceGenerator()
  dataset = NanoroughSurfaceDataset(list(generate(DATASET_SIZE)), transforms=[To(device), View(1, 128, 128)])

## Instantiating the content-loss

In [None]:
content_loss = HPG2DContentLoss(surfaces=dataset.surfaces)

## Instantiating the **Generator** and the **Discriminator** Networks

In [None]:
generator = CNNGenerator.from_device(device)

In [None]:
print(generator)

In [None]:
discriminator = CNNDiscriminator.from_device(device)

In [None]:
print(discriminator)

## Training

In [None]:
from torch.nn import BCELoss

criterion = BCELoss().to(device)

In [None]:
from functools import partial
from pyinsect.documentModel.comparators.NGramGraphSimilarity import SimilarityVS

training_manager = TrainingManager(
    benchmark=True,
    verbose=False,
    debug=True,
    checkpoint_dir=None,
    train_epoch=per_epoch,
    log_every_n=10,
    criterion=criterion,
    content_loss=content_loss,
    n_epochs=10,
    train_ratio=0.8,
    optimizer={
      'lr': 0.0002,
      'betas': (0.5, 0.999)
    },
    dataloader={
      'batch_size': 256,
      'shuffle': True,
      'num_workers': 0,
    }
)

In [None]:
generator_losses, discriminator_losses, discriminator_output_reals, discriminator_output_fakes = training_manager(generator, discriminator, dataset)

In [None]:
plot_against(generator_losses, discriminator_losses, title="Mean Generator vs Discriminator loss per epoch", xlabel="Epoch", ylabel="Loss", labels=("Generator", "Discriminator"))

In [None]:
plot_against(discriminator_output_reals, discriminator_output_fakes, title="Mean Discriminator Output per epoch", xlabel="Epoch", ylabel="Discriminator Output", labels=("Real Data", "Generator Data"))