# Time-Series Generation using Contrastive Learning

Consider learning a generative model for time-series data.

The sequential setting poses a unique challenge: Not only should the generator capture the conditional dynamics of (stepwise) transitions, but its open-loop rollouts should also preserve the joint distribution of (multi-step) trajectories.

On one hand, autoregressive models trained by MLE allow learning and computing explicit transition distributions, but suffer from compounding error during rollouts.

On the other hand, adversarial models based on GAN training alleviate such exposure bias, but transitions are implicit and hard to assess.
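
To make compounding error concrete, the toy sketch below compares a noise-free linear process with a model whose coefficient is slightly off. It is an illustration only, not part of the method: the coefficients, the horizon, and the helper name `rollout` are assumptions. With teacher forcing, the model restarts from the true previous value and its error stays at the one-step level. In an open-loop rollout, the model consumes its own predictions and the gap widens with the horizon.

```python
def rollout(coef: float, x0: float, steps: int) -> list:
    """Open-loop rollout of x[t+1] = coef * x[t] (noise-free for clarity)."""
    xs = [x0]
    for _ in range(steps):
        xs.append(coef * xs[-1])
    return xs

true_coef, model_coef = 1.00, 1.02   # assumed: the model is 2% off per step
steps = 24                           # illustrative horizon

truth = rollout(true_coef, 1.0, steps)

# Teacher forcing: the model always starts from the true previous value.
one_step_errors = [abs(model_coef * truth[t] - truth[t + 1]) for t in range(steps)]

# Open loop: the model starts once and then consumes its own predictions.
open_loop = rollout(model_coef, 1.0, steps)
open_loop_errors = [abs(open_loop[t + 1] - truth[t + 1]) for t in range(steps)]

print(f"max one-step error : {max(one_step_errors):.3f}")
print(f"final rollout error: {open_loop_errors[-1]:.3f}")
```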

In this work, we study a generative framework that seeks to combine the strengths of both: Motivated by a moment-matching objective to mitigate compounding error, we optimize a local (but forward-looking) *transition policy*, where the reinforcement signal is provided by a global (but stepwise-decomposable) *energy model* trained by contrastive estimation.

At **training**, the two components are learned cooperatively, avoiding the instabilities typical of adversarial objectives. 

At **inference**, the learned policy serves as the generator for iterative sampling, and the learned energy serves as a trajectory-level measure for evaluating sample quality.

By expressly training a policy to imitate sequential behavior of time-series features in a dataset, this approach embodies *“generation by imitation”*. Theoretically, we illustrate the correctness of this formulation and the consistency of the algorithm.

Empirically, we evaluate its ability to generate predictively useful samples from real-world datasets, verifying that it performs at the standard of existing benchmarks.
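
As a rough illustration of the energy component described above, the sketch below scores a trajectory by summing per-step energies (the "stepwise-decomposable" property). It then evaluates a simple contrastive objective that lowers the energy of data trajectories relative to generated ones. The quadratic step energy, the function names, and the exact form of the loss are assumptions made for illustration; the formulation studied in this work may differ.

```python
import numpy as np

def step_energy(x_prev, x_next, w):
    """Hypothetical per-step energy: squared error of a linear transition."""
    return np.sum((x_next - x_prev @ w) ** 2, axis=-1)

def trajectory_energy(traj, w):
    """Energy of a batch of trajectories [batch, time, dim],
    computed as the sum of the per-step energies."""
    return step_energy(traj[:, :-1], traj[:, 1:], w).sum(axis=1)

def contrastive_loss(real, fake, w):
    """Low energy on data trajectories, high energy on generated ones."""
    return trajectory_energy(real, w).mean() - trajectory_energy(fake, w).mean()

rng = np.random.default_rng(0)
dim, seq_len = 3, 24
w = np.eye(dim)  # energy parameters (here: identity dynamics)

# Smooth "real" paths versus unstructured "generated" paths.
real = np.cumsum(0.01 * rng.normal(size=(32, seq_len, dim)), axis=1)
fake = rng.normal(size=(32, seq_len, dim))

print(contrastive_loss(real, fake, w))  # negative: the data has lower energy
```

In the full framework, the policy would then be rewarded for producing transitions that this energy scores as data-like; that step is omitted here.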

## 1 Setup

### 1.1 Install libraries

Run the cell below to **install** the necessary libraries.

In [21]:
# !pip install wandb
# !pip install pytorch-lightning
# !pip install pyyaml
# !pip install torchvision
# !pip install plotly
!pip install tensorflow

Found existing installation: tensorflow 2.15.0.post1
Uninstalling tensorflow-2.15.0.post1:
  Would remove:
    /home/dima/.local/bin/estimator_ckpt_converter
    /home/dima/.local/bin/import_pb_to_tensorboard
    /home/dima/.local/bin/saved_model_cli
    /home/dima/.local/bin/tensorboard
    /home/dima/.local/bin/tf_upgrade_v2
    /home/dima/.local/bin/tflite_convert
    /home/dima/.local/bin/toco
    /home/dima/.local/bin/toco_from_protos
    /home/dima/.local/lib/python3.10/site-packages/tensorflow-2.15.0.post1.dist-info/*
    /home/dima/.local/lib/python3.10/site-packages/tensorflow/*
Proceed (Y/n)? ^C
ERROR: Operation cancelled by user

### 1.2 Import Libraries

Run the cell below to **import** the necessary libraries.

In [12]:
from typing import Sequence, List, Dict, Tuple, Optional, Any, Set, Union, Callable, Mapping
import itertools

import dataclasses
from dataclasses import dataclass
from dataclasses import asdict
from pathlib import Path
# from pprint import pprint
# from urllib.request import urlopen
# import random

from PIL import Image  # Image.BICUBIC is used by the image transforms in the model
# import PIL

# import torchvision.utils
# import matplotlib.pyplot as plt
# import plotly.graph_objects as go
# import plotly.express as px

import numpy as np
import torch
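import torchvision  # torchvision.utils.make_grid is used when logging images
import torchvision.transforms as transforms  # used by the image pipelines in the model
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Dataset  # used by the dataloaders and type hints
from torch import nn, optim  # used in the model's type hints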
# import torch.nn.functional as F

import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers.wandb import WandbLogger  # used in the training cell
# from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

# import torchvision
# from torchvision import transforms
# from tqdm.notebook import tqdm

## Necessary packages
#import warnings
#warnings.filterwarnings("ignore")

# # 1. TimeGAN model
# from timegan import timegan
# 2. Data loading (requires a local data_loading.py, as in the TimeGAN code)
from data_loading import real_data_loading, sine_data_generation
# # 3. Metrics
# from metrics.discriminative_metrics import discriminative_score_metrics
# from metrics.predictive_metrics import predictive_score_metrics
# from metrics.visualization_metrics import visualization


import random

### 1.3 Hyper-parameters

The cell below contains *all* the hyper-parameters needed by this script, for easy tweaking.

In [None]:
c = 1.0 # . . . Domain bounds for the loss functions
M = 32 #. . . . Mini-batch size
lr = 0.0007 # . Learning Rate
k = 1.0 # . . . Regularization coefficient 

## Data loading
data_name = 'stock' # . . which dataset to use
seq_len = 24 #. . . . . . max length of the input sequence

## Network parameters
module = 'gru' #. . . . . Can be 'gru', 'lstm' or 'lstmLN'
hidden_dim = 24 # . . . . Hidden dimensions
num_layer = 3 # . . . . . Number of layers
iterations = 10000 #. . . Number of training iterations
batch_size = 128 #. . . . Number of samples in each batch

metric_iteration = 5 #. . Number of iterations for each metric

Parameters that control how the notebook runs, rather than the model itself.

In [18]:
use_wandb = False # . . will require login for Weights & Biases

### 1.4 Utils

In [None]:
# Just a function to count the number of parameters
def count_parameters(model: torch.nn.Module) -> int:
  """ Counts the number of trainable parameters of a module

  :param model: model that contains the parameters to count
  :returns: the number of parameters in the model
  """
  return sum(p.numel() for p in model.parameters() if p.requires_grad)
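
A quick sanity check of `count_parameters`, as a sketch: the layer sizes are arbitrary, and the bias is frozen only to show that non-trainable parameters are excluded from the count.

```python
import torch

def count_parameters(model: torch.nn.Module) -> int:
    """Count the trainable parameters of a module (same logic as above)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

layer = torch.nn.Linear(3, 2)          # 3*2 weights + 2 biases
n_all = count_parameters(layer)        # 8

layer.bias.requires_grad_(False)       # freeze the bias
n_trainable = count_parameters(layer)  # 6: only the weights remain trainable

print(n_all, n_trainable)
```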

### 1.5 Initialization

Initialize the modules needed by running the cells in this section.

#### 1.5.1 Reproducibility

In [13]:
np.random.seed(0)
random.seed(0)

torch.cuda.manual_seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True  # Note that this Deterministic mode can have a performance impact
torch.backends.cudnn.benchmark = False

_ = pl.seed_everything(0)

Global seed set to 0


#### 1.5.2 Utils

In [19]:
if use_wandb:
    !wandb login

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#### 1.5.3 Data

In [None]:
if data_name in ['stock', 'energy']:
  ori_data = real_data_loading(data_name, seq_len)
elif data_name == 'sine':
  # Set number of samples and its dimensions
  no, dim = 10000, 5
  ori_data = sine_data_generation(no, seq_len, dim)
else:
  raise ValueError(f"Unknown dataset: {data_name}")

print(data_name + ' dataset has been loaded.')
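
`real_data_loading` and `sine_data_generation` are expected to come from an external `data_loading` module rather than from this notebook. For experimenting without that module, a sine generator taking the same three arguments could look like the sketch below. The name `make_sine_data`, the `seed` argument, the frequency and phase ranges, and the rescaling to [0, 1] are all assumptions, not a description of the original helper.

```python
import numpy as np

def make_sine_data(no: int, seq_len: int, dim: int, seed: int = 0) -> list:
    """Generate `no` multivariate sine sequences, each of shape [seq_len, dim].

    Every feature gets its own random frequency and phase (assumed ranges),
    and values are rescaled from [-1, 1] to [0, 1].
    """
    rng = np.random.default_rng(seed)
    t = np.arange(seq_len)[:, None]                   # [seq_len, 1]
    data = []
    for _ in range(no):
        freq = rng.uniform(0.0, 0.1, size=(1, dim))   # one frequency per feature
        phase = rng.uniform(0.0, 0.1, size=(1, dim))  # one phase per feature
        seq = np.sin(freq * t + phase)                # broadcasts to [seq_len, dim]
        data.append((seq + 1.0) * 0.5)                # rescale to [0, 1]
    return data

sine_data = make_sine_data(no=100, seq_len=24, dim=5)
print(len(sine_data), sine_data[0].shape)
```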

#### 1.5.4 Network Parameters

In [None]:
# Dataclasses are convenient classes for holding data

# Working with dataclasses is particularly comfortable
# since you can specify types and get autocomplete/suggestion
# of the available hyperparameters

@dataclass
class Config:
    #dataset_name: str = "ukiyoe2photo"  # name of the dataset

    # Runs we did:
    # map: 200 epochs, 100 decay
    # ukiyo: 40 epochs, 20 decay (due to time constraints)
    # They took ~7 hours each on a 2080ti
    n_epochs: int = 200  # number of epochs of training
    decay_epoch: int = 100  # epoch from which to start lr decay

    img_height: int = 128  # size of image height # default 256x256
    img_width: int = 128  # size of image width

    batch_size: int = 1  # size of the batches
    lr: float = 0.0002  # adam: learning rate
    b1: float = 0.5  # adam: decay of first order momentum of gradient
    b2: float = 0.999  # adam: decay of second order momentum of gradient

    channels: int = 3  # number of image channels
    n_residual_blocks: int = 6  # number of residual blocks in generator # original 9
    lambda_cyc: float = 10.0  # cycle loss weight
    lambda_id: float = 5.0  # identity loss weight

    n_cpu: int = 8  # number of cpu threads to use for the dataloaders

    log_images: int = min(25, 100)  # number of images to log


cfg = Config()
##pprint(asdict(cfg))
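
The `Config` above still describes an image-translation setup. A time-series counterpart that gathers the hyper-parameters from section 1.3 into a dataclass might look like the sketch below. The class name `TimeSeriesConfig`, its field grouping, and the validation in `__post_init__` are hypothetical; the default values are copied from section 1.3.

```python
from dataclasses import dataclass, asdict

@dataclass
class TimeSeriesConfig:
    # Data loading
    data_name: str = "stock"     # which dataset to use
    seq_len: int = 24            # max length of the input sequence

    # Network parameters
    module: str = "gru"          # 'gru', 'lstm' or 'lstmLN'
    hidden_dim: int = 24         # hidden dimensions
    num_layer: int = 3           # number of layers
    iterations: int = 10000      # number of training iterations
    batch_size: int = 128        # number of samples in each batch

    # Optimisation
    lr: float = 0.0007           # learning rate

    def __post_init__(self) -> None:
        # Fail early on an unsupported recurrent module.
        if self.module not in ("gru", "lstm", "lstmLN"):
            raise ValueError(f"Unknown module: {self.module}")

ts_cfg = TimeSeriesConfig(seq_len=48)   # override a single field
print(asdict(ts_cfg))
```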

### 1.6 Model

#### 1.6.1 Network Modules

##### 1.6.1.1 Embedder

##### 1.6.1.2 Recovery

##### 1.6.1.3 Supervisor

##### 1.6.1.4 Generator

##### 1.6.1.5 Discriminator

#### 1.6.2 Full Model

In [None]:
class CycleGAN(pl.LightningModule):
    def __init__(
        self,
        hparams: Union[Dict, Config],
        trainA_folder: Path,
        trainB_folder: Path,
        testA_folder: Path,
        testB_folder: Path,
    ) -> None:
        """
        The CycleGAN model.

        :param hparams: dictionary that contains all the hyperparameters
        :param trainA_folder: Path to the folder that contains the trainA images
        :param trainB_folder: Path to the folder that contains the trainB images
        :param testA_folder: Path to the folder that contains the testA images
        :param testB_folder: Path to the folder that contains the testB images
        """
        super().__init__()
        self.save_hyperparameters(asdict(hparams) if not isinstance(hparams, Mapping) else hparams)

        # Dataset paths
        self.trainA_folder = trainA_folder
        self.trainB_folder = trainB_folder
        self.testA_folder = testA_folder
        self.testB_folder = testB_folder

        # Expected image shape
        self.input_shape = (self.hparams["channels"], self.hparams["img_height"], self.hparams["img_width"])

        # Generators A->B and B->A
        self.G_AB = GeneratorResNet(self.input_shape, self.hparams["n_residual_blocks"])
        self.G_BA = GeneratorResNet(self.input_shape, self.hparams["n_residual_blocks"])

        # Discriminators
        self.D_A = Discriminator(self.input_shape)
        self.D_B = Discriminator(self.input_shape)

        # Initialize weights
        # https://pytorch.org/docs/stable/nn.html?highlight=nn%20module%20apply#torch.nn.Module.apply
        self.G_AB.apply(self.weights_init_normal)
        self.G_BA.apply(self.weights_init_normal)
        self.D_A.apply(self.weights_init_normal)
        self.D_B.apply(self.weights_init_normal)

        # Image Normalizations
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(int(self.hparams["img_height"] * 1.12), Image.BICUBIC),
                transforms.RandomCrop((self.hparams["img_height"], self.hparams["img_width"])),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        # Image Normalization for the validation: remove source of randomness
        self.val_image_transforms = transforms.Compose(
            [
                transforms.Resize(int(self.hparams["img_height"] * 1.12), Image.BICUBIC),
                transforms.CenterCrop((self.hparams["img_height"], self.hparams["img_width"])),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        # Image buffers
        self.fake_A_buffer = ReplayBuffer()
        self.fake_B_buffer = ReplayBuffer()

        # Forward pass cache to avoid re-doing some computation
        self.fake_A = None
        self.fake_B = None

        # Losses
        self.mse = torch.nn.MSELoss()
        self.l1 = torch.nn.L1Loss()

        # Ignore this.
        # It avoids wandb logging when Lightning does a sanity check on the validation set
        self.is_sanity = True

    def forward(self, x: torch.Tensor, a_to_b: bool) -> torch.Tensor:
        """
        Forward pass for this model.

        This is not used while training!

        :param x: input of the forward pass with shape [batch, channel, w, h]
        :param a_to_b: if True uses the mapping A->B, otherwise uses B->A

        :returns: the translated image with shape [batch, channel, w, h]
        """
        if a_to_b:
            return self.G_AB(x)
        else:
            return self.G_BA(x)

    def weights_init_normal(self, m: nn.Module) -> None:
        """
        Initialize the weights with a gaussian N(0, 0.02) as described in the paper.

        :param m: the module that contains the weights to initialise
        """
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
            if hasattr(m, "bias") and m.bias is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm2d") != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)

    def train_dataloader(self) -> DataLoader:
        """ Create the train set DataLoader

        :returns: the train set DataLoader
        """
        train_loader = DataLoader(
            DatasetUnpaired(
                self.trainA_folder, self.trainB_folder, transform=self.image_transforms
            ),
            batch_size=self.hparams["batch_size"],
            shuffle=True,
            num_workers=2,
            pin_memory=True,
        )
        return train_loader

    def val_dataloader(self, custom_batch_size: Optional[int] = None) -> DataLoader:
        """ Create the validation set DataLoader.

        It is deterministic.
        It does not shuffle and does not use random transformation on each image.

        :returns: the validation set DataLoader
        """
        test_loader = DataLoader(
            DatasetUnpaired(
                self.testA_folder,
                self.testB_folder,
                transform=self.val_image_transforms,
                fixed_pairs=True,
            ),
            batch_size=custom_batch_size if custom_batch_size is not None else 32,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
        )
        return test_loader

    def configure_optimizers(
        self,
    ) -> Tuple[Sequence[optim.Optimizer], Sequence[Dict[str, Any]]]:
        """ Instantiate the optimizers and schedulers.

        We have three optimizers (and relative schedulers):

        - Optimizer with index 0: optimizes the parameters of both generators
        - Optimizer with index 1: optimizes the parameters of D_A
        - Optimizer with index 2: optimizes the parameters of D_B

        Each scheduler implements a linear decay to 0 starting at epoch `self.hparams["decay_epoch"]`

        :returns: the optimizers and relative schedulers (look at the return type!)
        """
        # Optimizers
        optimizer_G = torch.optim.Adam(
            itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()),
            lr=self.hparams["lr"],
            betas=(self.hparams["b1"], self.hparams["b2"]),
        )
        optimizer_D_A = torch.optim.Adam(
            self.D_A.parameters(), lr=self.hparams["lr"], betas=(self.hparams["b1"], self.hparams["b2"])
        )
        optimizer_D_B = torch.optim.Adam(
            self.D_B.parameters(), lr=self.hparams["lr"], betas=(self.hparams["b1"], self.hparams["b2"])
        )

        # Schedulers for each optimizers
        lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            optimizer_G,
            lr_lambda=LambdaLR(self.hparams["n_epochs"], self.hparams["decay_epoch"]).step,
        )
        lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_A,
            lr_lambda=LambdaLR(self.hparams["n_epochs"], self.hparams["decay_epoch"]).step,
        )
        lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_B,
            lr_lambda=LambdaLR(self.hparams["n_epochs"], self.hparams["decay_epoch"]).step,
        )

        return (
            [optimizer_G, optimizer_D_A, optimizer_D_B],
            [
                {"scheduler": lr_scheduler_G, "interval": "epoch", "frequency": 1},
                {"scheduler": lr_scheduler_D_A, "interval": "epoch", "frequency": 1},
                {"scheduler": lr_scheduler_D_B, "interval": "epoch", "frequency": 1},
            ],
        )

    def criterion_GAN(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """ The loss criterion for GAN losses

        :param x: tensor with any shape
        :param y: tensor with any shape

        :returns: the mse between x and y
        """
        return self.mse(x, y)

    def criterion_cycle(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """ The loss criterion for Cycle losses

        :param x: tensor with any shape
        :param y: tensor with any shape

        :returns: the l1 between x and y
        """
        return self.l1(x, y)

    def criterion_identity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """ The loss criterion for Identity losses

        :param x: tensor with any shape
        :param y: tensor with any shape

        :returns: the l1 between x and y
        """
        return self.l1(x, y)

    def identity_loss(self, image: torch.Tensor, generator: nn.Module) -> torch.Tensor:
        """ Implements the identity loss for the given generator

        :param generator: a generator module that maps X -> Y
        :param image: an image in the Y distribution with shape [batch, channel, w, h]

        :returns: the identity loss for these (generator, image)
        """
        return self.criterion_identity(generator(image), image)

    def gan_loss(
        self,
        generator: nn.Module,
        discriminator: nn.Module,
        image: torch.Tensor,
        expected_label: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Implements the GAN loss for the given generator and discriminator

        :param image: the input image with shape [batch, channel, w, h]
        :param generator: the generator module to use to translate the image from X -> Y
        :param discriminator: the discriminator that tries to distinguish fake and real images
        :param expected_label: tensor with shape compatible to the discriminator's output.
                         It is full of ones when training the generator. We feed a fake
                         image to the discriminator and we expect to get ones
                         (for the discriminator this is an error!)

        :returns: the GAN loss for these (image, generator, discriminator),
                  together with the generated fake image
        """
        fake_image = generator(image)
        predicted_label = discriminator(fake_image)
        loss_GAN = self.criterion_GAN(predicted_label, expected_label)
        return loss_GAN, fake_image

    def cycle_loss(
        self,
        fake_image: torch.Tensor,
        reverse_generator: nn.Module,
        original_image: torch.Tensor,
    ) -> torch.Tensor:
        """ Implements the cycle consistency loss

        It takes in input a fake image, to avoid repeated computation,
        thus it only needs the reverse mapping that produced that fake image.

        :param fake_image: an image produced by a mapping X->Y with shape [batch, channel, w, h]
        :param reverse_generator: the generator module that maps Y->X
        :param original_image: the original image in X with shape [batch, channel, w, h]
                               to compare with the reconstructed fake image

        :returns: the cycle consistency loss for this (fake_image, reverse_generator, original_image)
        """
        recovered_image = reverse_generator(fake_image)
        return self.criterion_cycle(recovered_image, original_image)

    def discriminator_loss(
        self,
        discriminator: nn.Module,
        proposed_image: torch.Tensor,
        expected_label: torch.Tensor,
    ) -> torch.Tensor:
        """ Implements the loss used to train the discriminator

        :param discriminator: the discriminator model to train
        :param proposed_image: the fake or real image proposed with shape [batch, channel, w, h]
        :param expected_label: tensor with shape compatible to the discriminator's output,
                               full of zeros if the proposed image is fake
                               full of ones if the proposed image is real

        :returns: the discriminator loss for this (discriminator, proposed_image, expected_label)
        """
        predicted_label = discriminator(proposed_image)
        return self.criterion_GAN(predicted_label, expected_label)

    def training_step(
        self, batch: Dict[str, torch.Tensor], batch_nb: int, optimizer_idx: int
    ) -> torch.Tensor:
        """ Implements a single training step

        The parameter `optimizer_idx` identifies which optimizer "called" this training step,
        thus we can change the behaviour of the training depending on which optimizer
        is currently performing the optimization

        :param batch: current training batch
        :param batch_nb: the index of the current batch
        :param optimizer_idx: the index of the optimizer in use, see the function `configure_optimizers`

        :returns: the total loss for the current training step; the individual loss terms
                  are logged through `self.log_dict`
        """
        # Unpack the batch
        real_A = batch["A"]
        real_B = batch["B"]

        # Adversarial ground truths
        valid = torch.ones(
            (real_A.size(0), *self.D_A.output_shape), device=real_A.device
        )
        fake = torch.zeros(
            (real_A.size(0), *self.D_A.output_shape), device=real_A.device
        )

        # The first optimizer is for the two generators!
        if optimizer_idx == 0:

            # Identity A and B loss
            loss_id_A = self.identity_loss(real_A, self.G_BA)
            loss_id_B = self.identity_loss(real_B, self.G_AB)
            loss_identity = self.hparams["lambda_id"] * ((loss_id_A + loss_id_B) / 2)

            # GAN A loss and GAN B loss
            loss_GAN_AB, self.fake_B = self.gan_loss(self.G_AB, self.D_B, real_A, valid)
            loss_GAN_BA, self.fake_A = self.gan_loss(self.G_BA, self.D_A, real_B, valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss: A -> B -> A  and  B -> A -> B
            loss_cycle_A = self.cycle_loss(self.fake_B, self.G_BA, real_A)
            loss_cycle_B = self.cycle_loss(self.fake_A, self.G_AB, real_B)
            loss_cycle = self.hparams["lambda_cyc"] * ((loss_cycle_A + loss_cycle_B) / 2)

            # Total loss
            loss_G = loss_GAN + loss_cycle + loss_identity

            self.log_dict({
                    "total_loss_generators": loss_G,
                    "loss_GAN": loss_GAN,
                    "loss_cycle": loss_cycle,
                    "loss_identity": loss_identity,
                }
            )
            return loss_G

        # The second optimizer is to train the D_A discriminator
        elif optimizer_idx == 1:

            # Real loss
            loss_real = self.discriminator_loss(self.D_A, real_A, valid)

            # Fake loss (on batch of previously generated samples)
            loss_fake = self.discriminator_loss(
                self.D_A, self.fake_A_buffer.push_and_pop(self.fake_A).detach(), fake
            )

            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2
            self.log_dict({
                    "total_D_A": loss_D_A,
                    "loss_D_A_real": loss_real,
                    "loss_D_A_fake": loss_fake,
                }
            )
            return loss_D_A


        # The third optimizer is to train the D_B discriminator
        elif optimizer_idx == 2:

            # Real loss
            loss_real = self.discriminator_loss(self.D_B, real_B, valid)

            # Fake loss (on batch of previously generated samples)
            loss_fake = self.discriminator_loss(
                self.D_B, self.fake_B_buffer.push_and_pop(self.fake_B).detach(), fake
            )

            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            self.log_dict({
                    "total_D_B": loss_D_B,
                    "loss_D_B_real": loss_real,
                    "loss_D_B_fake": loss_fake,
                }
            )
            return loss_D_B

        raise RuntimeError("There is an error in the optimizers configuration!")

    def get_image_examples(
        self, real: torch.Tensor, fake: torch.Tensor
    ) -> Sequence[wandb.Image]:
        """
        Given real and "fake" translated images, produce side-by-side image pairs to log

        :param real: the real images with shape [batch, channel, w, h]
        :param fake: the fake image with shape [batch, channel, w, h]

        :returns: a sequence of wandb.Image to log and visualize the performance
        """
        example_images = []
        for i in range(real.shape[0]):
            couple = torchvision.utils.make_grid(
                [real[i], fake[i]],
                nrow=2,
                normalize=True,
                scale_each=True,
                pad_value=1,
                padding=4,
            )
            example_images.append(
                wandb.Image(couple.permute(1, 2, 0).detach().cpu().numpy(), mode="RGB")
            )
        return example_images

    def validation_step(
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> Dict[str, Union[torch.Tensor,Sequence[wandb.Image]]]:
        """ Implements a single validation step

        In each validation step some translation examples are produced and a
        validation loss that uses the cycle consistency is computed

        :param batch: the current validation batch
        :param batch_idx: the index of the current validation batch

        :returns: the loss and example images
        """

        real_A = batch["A"]
        real_B = batch["B"]

        # Translate in both directions (reused for the examples and the cycle loss)
        fake_B = self.G_AB(real_A)
        fake_A = self.G_BA(real_B)

        images_BA = self.get_image_examples(real_B, fake_A)
        images_AB = self.get_image_examples(real_A, fake_B)

        # Cycle loss A -> B -> A
        recov_A = self.G_BA(fake_B)
        loss_cycle_A = self.criterion_cycle(recov_A, real_A)

        # Cycle loss B -> A -> B
        recov_B = self.G_AB(fake_A)
        loss_cycle_B = self.criterion_cycle(recov_B, real_B)

        # Cycle loss aggregation
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
        loss_cycle = self.hparams["lambda_cyc"] * loss_cycle

        # Total loss
        loss_G = loss_cycle

        return {"val_loss": loss_G, "images_BA": images_BA, "images_AB": images_AB}

    def validation_epoch_end(
        self, outputs: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, Union[torch.Tensor, Dict[str, Union[torch.Tensor,Sequence[wandb.Image]]]]]:
        """ Implements the behaviour at the end of a validation epoch

        Currently it gathers all the produced examples and log them to wandb,
        limiting the logged examples to `hparams["log_images"]`.

        Then computes the mean of the losses and returns it.
        Updates the progress bar label with this loss.

        :param outputs: a sequence that aggregates all the outputs of the validation steps

        :returns: the aggregated validation loss and information to update the progress bar
        """
        images_AB = []
        images_BA = []

        for x in outputs:
            images_AB.extend(x["images_AB"])
            images_BA.extend(x["images_BA"])

        images_AB = images_AB[: self.hparams["log_images"]]
        images_BA = images_BA[: self.hparams["log_images"]]

        if not self.is_sanity:  # skip the initial sanity check, which is not a real validation epoch
            print(f"Logged {len(images_AB)} images for each category.")

            self.logger.experiment.log(
                {"images_AB": images_AB, "images_BA": images_BA},
                step=self.global_step,
            )
        self.is_sanity = False

        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        self.log_dict({"val_loss": avg_loss})
        return {"val_loss": avg_loss}
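
`configure_optimizers` relies on a `LambdaLR` helper that is not defined in this notebook. Going by its docstring (a linear decay to 0 that starts at `decay_epoch`), one possible implementation is sketched below. The class body is an assumption; only the two constructor arguments and the `step` method are implied by the call sites above.

```python
class LambdaLR:
    """Learning-rate multiplier: 1.0 until `decay_start_epoch`,
    then a linear decay that reaches 0.0 at `n_epochs`."""

    def __init__(self, n_epochs: int, decay_start_epoch: int) -> None:
        if decay_start_epoch >= n_epochs:
            raise ValueError("Decay must start before training ends")
        self.n_epochs = n_epochs
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch: int) -> float:
        # Fraction of the decay phase already completed, clipped to [0, 1].
        progress = (epoch - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)
        return 1.0 - min(1.0, max(0.0, progress))

schedule = LambdaLR(n_epochs=200, decay_start_epoch=100)
for epoch in (0, 100, 150, 200):
    print(epoch, schedule.step(epoch))
```

The returned value acts as a multiplier on the base learning rate, which is how `torch.optim.lr_scheduler.LambdaLR` interprets its `lr_lambda` argument.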

## 2 Train

This section trains the model using the hyper-parameters defined in [Hyper-parameters](#13-hyper-parameters).

In [None]:
# ⏱⏱⏱ slow executing cell ⏱⏱⏱
# Suggested to use pre-trained models!

# Instantiate the model
gan_model = CycleGAN(hparams=cfg,
                     trainA_folder=trainA,
                     trainB_folder=trainB,
                     testA_folder=testA,
                     testB_folder=testB)

# Define the logger
# https://www.wandb.com/articles/pytorch-lightning-with-weights-biases.
wandb_logger = WandbLogger(project="CycleGAN Tutorial 2021", log_model=True)

## Currently it does not log the model weights, there is a bug in wandb and/or lightning.
wandb_logger.experiment.watch(gan_model, log='all', log_freq=100)

# Define the trainer
trainer = pl.Trainer(logger=wandb_logger,
                     max_epochs=cfg.n_epochs,
                     gpus=1,
                     limit_val_batches=.2,
                     val_check_interval=0.25)

# Start the training
trainer.fit(gan_model)

# Log the trained model
trainer.save_checkpoint('model.pth')
wandb.save('model.pth')

## 3 Validation

## 4 Visualize Results