<a href="https://colab.research.google.com/github/WilliamAshbee/gan/blob/master/catalyst_gan_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



```
# This is formatted as code
```

# Catalyst 20.11 version

In [1]:
# Install Catalyst pinned to the 20.11 release used by this section.
! pip install catalyst==20.11

Collecting catalyst==20.11
[?25l  Downloading https://files.pythonhosted.org/packages/39/45/24a485b76527a2601f11f12c16b5b11f853b42a3ba029d21d5c80c6c30d1/catalyst-20.11-py2.py3-none-any.whl (489kB)
[K     |████████████████████████████████| 491kB 13.7MB/s 
[?25hCollecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/af/0c/4f41bcd45db376e6fe5c619c01100e9b7531c55791b7244815bac6eac32c/tensorboardX-2.1-py2.py3-none-any.whl (308kB)
[K     |████████████████████████████████| 317kB 55.1MB/s 
[?25hCollecting deprecation
  Downloading https://files.pythonhosted.org/packages/02/c3/253a89ee03fc9b9682f1541728eb66db7db22148cd94f89ab22528cd1e1b/deprecation-2.1.0-py2.py3-none-any.whl
Collecting GitPython>=3.1.1
[?25l  Downloading https://files.pythonhosted.org/packages/24/d1/a7f8fe3df258549b303415157328bfcc63e9b11d06a7ad7a3327f3d32606/GitPython-3.1.11-py3-none-any.whl (159kB)
[K     |████████████████████████████████| 163kB 57.7MB/s 
Collecting gitdb<5,>=4.0.1
[?25l  D

In [2]:
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from catalyst import dl
from catalyst.data.cv import ToTensor
from catalyst.contrib.datasets import MNIST
from catalyst.contrib.nn.modules import Flatten, GlobalMaxPool2d, Lambda

# Length of the random noise vector fed to the generator.
latent_dim = 128

# Generator: latent vector -> single-channel image with values in (0, 1).
generator = nn.Sequential(
    # Project the latent vector to 128 * 7 * 7 coefficients, which are then
    # reshaped into a 128-channel 7x7 feature map.
    nn.Linear(latent_dim, 128 * 7 * 7),
    nn.LeakyReLU(0.2, inplace=True),
    Lambda(lambda x: x.view(x.size(0), 128, 7, 7)),
    # Each transposed conv (kernel 4, stride 2, padding 1) doubles the
    # spatial size: 7x7 -> 14x14 -> 28x28.
    nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    # 7x7 conv with padding 3 keeps the spatial size and collapses to 1 channel.
    nn.Conv2d(128, 1, (7, 7), padding=3),
    nn.Sigmoid(),
)

# Discriminator: single-channel image -> one raw logit per image (no sigmoid
# here; the losses are computed with binary_cross_entropy_with_logits).
discriminator = nn.Sequential(
    nn.Conv2d(1, 64, (3, 3), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    # Presumably reduces each of the 128 channels to its spatial maximum -
    # confirm against the Catalyst documentation.
    GlobalMaxPool2d(),
    Flatten(),
    nn.Linear(128, 1)
)

# Both networks and both optimizers are registered under the same keys, which
# the runner and the optimizer callbacks use to look them up.
model = {"generator": generator, "discriminator": discriminator}
optimizer = {
    "generator": torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.5, 0.999)),
    "discriminator": torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999)),
}
# MNIST training split, downloaded into the current working directory.
loaders = {
    "train": DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32),
}

class CustomRunner(dl.Runner):
    """Runner that computes both GAN losses for every batch.

    Expects ``self.model`` to be a dict with ``"generator"`` and
    ``"discriminator"`` entries and reads the module-level ``latent_dim``.
    The losses are only computed here; backpropagation and optimizer steps
    are left to the callbacks passed to ``runner.train``.
    """

    def _handle_batch(self, batch):
        """Compute ``loss_discriminator`` and ``loss_generator`` for one batch.

        Label convention: 1 marks a generated image, 0 marks a real one.

        Args:
            batch: ``(real_images, targets)`` pair; the targets are ignored.
                The images are assumed to already be on ``self.device``.
        """
        real_images, _ = batch
        batch_size = real_images.shape[0]
        generator = self.model["generator"]
        discriminator = self.model["discriminator"]
        batch_metrics = {}

        # ---- Discriminator loss -------------------------------------------
        # Sample latent points directly on the target device.
        random_latent_vectors = torch.randn(batch_size, latent_dim, device=self.device)
        # Detach the fakes so the discriminator loss never reaches the generator.
        generated_images = generator(random_latent_vectors).detach()
        combined_images = torch.cat([generated_images, real_images])

        # Generated images come first (label 1), real images second (label 0).
        labels = torch.cat([
            torch.ones((batch_size, 1), device=self.device),
            torch.zeros((batch_size, 1), device=self.device),
        ])
        # Add random noise to the labels - important trick!
        labels += 0.05 * torch.rand(labels.shape, device=self.device)

        predictions = discriminator(combined_images)
        batch_metrics["loss_discriminator"] = \
            F.binary_cross_entropy_with_logits(predictions, labels)

        # ---- Generator loss -----------------------------------------------
        random_latent_vectors = torch.randn(batch_size, latent_dim, device=self.device)
        # Labels that say "all real images" (0 == real in this convention).
        misleading_labels = torch.zeros((batch_size, 1), device=self.device)

        # The generator loss is backpropagated *through* the discriminator.
        # Freeze the discriminator's parameters while building this graph so
        # that backward pass cannot deposit generator-loss gradients on them;
        # otherwise those gradients would be mixed into the discriminator's
        # own update and pull it toward calling fakes "real".
        frozen_params = [p for p in discriminator.parameters() if p.requires_grad]
        for param in frozen_params:
            param.requires_grad_(False)
        try:
            generated_images = generator(random_latent_vectors)
            predictions = discriminator(generated_images)
        finally:
            # Restore trainability; the discriminator-loss graph built above
            # already references these parameters and is unaffected.
            for param in frozen_params:
                param.requires_grad_(True)
        batch_metrics["loss_generator"] = \
            F.binary_cross_entropy_with_logits(predictions, misleading_labels)

        self.batch_metrics.update(**batch_metrics)

runner = CustomRunner()

# One optimizer callback per network; each metric_key names the loss that
# CustomRunner._handle_batch stores in batch_metrics.
generator_step = dl.OptimizerCallback(
    metric_key="loss_generator",
    optimizer_key="generator",
)
discriminator_step = dl.OptimizerCallback(
    metric_key="loss_discriminator",
    optimizer_key="discriminator",
)

# The generator callback stays first in the list, as in the original call.
runner.train(
    model=model,
    optimizer=optimizer,
    loaders=loaders,
    callbacks=[generator_step, discriminator_step],
    main_metric="loss_generator",
    num_epochs=3,
    verbose=True,
    logdir="./logs_gan",
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/MNIST/raw/train-images-idx3-ubyte.gz to /content/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/MNIST/raw/train-labels-idx1-ubyte.gz to /content/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/MNIST/raw
Processing...
Done!





Attention, there is only one dataloader - train







1/3 * Epoch (train): 100% 1875/1875 [00:37<00:00, 50.50it/s, loss_discriminator=0.805, loss_generator=0.333]
[2020-12-18 15:53:30,545] 
1/3 * Epoch 1 (_base): lr/discriminator=0.0003 | lr/generator=0.0003 | momentum/discriminator=0.5000 | momentum/generator=0.5000
1/3 * Epoch 1 (train): loss_discriminator=0.7683 | loss_generator=0.3420
2/3 * Epoch (train): 100% 1875/1875 [00:36<00:00, 50.77it/s, loss_discriminator=0.834, loss_generator=0.311]
[2020-12-18 15:54:07,582] 
2/3 * Epoch 2 (_base): lr/discriminator=0.0003 | lr/generator=0.0003 | momentum/discriminator=0.5000 | momentum/generator=0.5000
2/3 * Epoch 2 (train): loss_discriminator=0.8314 | loss_generator=0.3116
3/3 * Epoch (train): 100% 1875/1875 [00:37<00:00, 50.09it/s, loss_discriminator=0.789, loss_generator=0.344]
[2020-12-18 15:54:45,119] 
3/3 * Epoch 3 (_base): lr/discriminator=0.0003 | lr/generator=0.0003 | momentum/discriminator=0.5000 | momentum/generator=0.5000
3/3 * Epoch 3 (train): loss_discriminator=0.8260 | loss

# Catalyst master-branch version

In [None]:
# Install Catalyst from the master branch on GitHub. This is not pinned to a
# commit, so the installed code depends on when the cell is run.
! pip install git+https://github.com/catalyst-team/catalyst@master --upgrade

Collecting git+https://github.com/catalyst-team/catalyst@master
  Cloning https://github.com/catalyst-team/catalyst (to revision master) to /tmp/pip-req-build-r04jpj9d
  Running command git clone -q https://github.com/catalyst-team/catalyst /tmp/pip-req-build-r04jpj9d
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected packages: catalyst
  Building wheel for catalyst (PEP 517) ... [?25l[?25hdone
  Created wheel for catalyst: filename=catalyst-20.12-cp36-none-any.whl size=512358 sha256=6491408aa9ab16567730b6a05fe64908154e241e85793551b9f52db5fbe0e67b
  Stored in directory: /tmp/pip-ephem-wheel-cache-00zer4nz/wheels/c5/6b/8c/16132d56af8955e9826cad50d651cbd422fe59e3175aee0efc
Successfully built catalyst
Installing collected packages: catalyst
  Found existing installation: catalyst 20.12
    Uninstalling catalyst-20.12:
      Successfully uninstalled

In [None]:
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from catalyst import dl
from catalyst.contrib.data.cv import ToTensor
from catalyst.contrib.datasets import MNIST
from catalyst.contrib.nn.modules import Flatten, GlobalMaxPool2d, Lambda

# Length of the random noise vector fed to the generator.
latent_dim = 128

# Generator: latent vector -> single-channel image with values in (0, 1).
generator = nn.Sequential(
    # Project the latent vector to 128 * 7 * 7 coefficients, which are then
    # reshaped into a 128-channel 7x7 feature map.
    nn.Linear(latent_dim, 128 * 7 * 7),
    nn.LeakyReLU(0.2, inplace=True),
    Lambda(lambda x: x.view(x.size(0), 128, 7, 7)),
    # Each transposed conv (kernel 4, stride 2, padding 1) doubles the
    # spatial size: 7x7 -> 14x14 -> 28x28.
    nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    # 7x7 conv with padding 3 keeps the spatial size and collapses to 1 channel.
    nn.Conv2d(128, 1, (7, 7), padding=3),
    nn.Sigmoid(),
)

# Discriminator: single-channel image -> one raw logit per image (no sigmoid
# here; the losses are computed with binary_cross_entropy_with_logits).
discriminator = nn.Sequential(
    nn.Conv2d(1, 64, (3, 3), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    # Presumably reduces each of the 128 channels to its spatial maximum -
    # confirm against the Catalyst documentation.
    GlobalMaxPool2d(),
    Flatten(),
    nn.Linear(128, 1)
)

# Both networks and both optimizers are registered under the same keys, which
# the runner and the optimizer callbacks use to look them up.
model = {"generator": generator, "discriminator": discriminator}
optimizer = {
    "generator": torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.5, 0.999)),
    "discriminator": torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999)),
}
# MNIST training split, downloaded into the current working directory.
loaders = {
    "train": DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32),
}

class CustomRunner(dl.Runner):
    """Runner that computes both GAN losses for every batch.

    Expects ``self.model`` to be a dict with ``"generator"`` and
    ``"discriminator"`` entries and reads the module-level ``latent_dim``.
    The losses are only computed here; backpropagation and optimizer steps
    are left to the callbacks passed to ``runner.train``.
    """

    def _handle_batch(self, batch):
        """Compute ``loss_discriminator`` and ``loss_generator`` for one batch.

        Label convention: 1 marks a generated image, 0 marks a real one.

        Args:
            batch: ``(real_images, targets)`` pair; the targets are ignored.
                The images are assumed to already be on ``self.device``.
        """
        real_images, _ = batch
        batch_size = real_images.shape[0]
        generator = self.model["generator"]
        discriminator = self.model["discriminator"]
        batch_metrics = {}

        # ---- Discriminator loss -------------------------------------------
        # Sample latent points directly on the target device.
        random_latent_vectors = torch.randn(batch_size, latent_dim, device=self.device)
        # Detach the fakes so the discriminator loss never reaches the generator.
        generated_images = generator(random_latent_vectors).detach()
        combined_images = torch.cat([generated_images, real_images])

        # Generated images come first (label 1), real images second (label 0).
        labels = torch.cat([
            torch.ones((batch_size, 1), device=self.device),
            torch.zeros((batch_size, 1), device=self.device),
        ])
        # Add random noise to the labels - important trick!
        labels += 0.05 * torch.rand(labels.shape, device=self.device)

        predictions = discriminator(combined_images)
        batch_metrics["loss_discriminator"] = \
            F.binary_cross_entropy_with_logits(predictions, labels)

        # ---- Generator loss -----------------------------------------------
        random_latent_vectors = torch.randn(batch_size, latent_dim, device=self.device)
        # Labels that say "all real images" (0 == real in this convention).
        misleading_labels = torch.zeros((batch_size, 1), device=self.device)

        # The generator loss is backpropagated *through* the discriminator.
        # Freeze the discriminator's parameters while building this graph so
        # that backward pass cannot deposit generator-loss gradients on them;
        # otherwise those gradients would be mixed into the discriminator's
        # own update and pull it toward calling fakes "real".
        frozen_params = [p for p in discriminator.parameters() if p.requires_grad]
        for param in frozen_params:
            param.requires_grad_(False)
        try:
            generated_images = generator(random_latent_vectors)
            predictions = discriminator(generated_images)
        finally:
            # Restore trainability; the discriminator-loss graph built above
            # already references these parameters and is unaffected.
            for param in frozen_params:
                param.requires_grad_(True)
        batch_metrics["loss_generator"] = \
            F.binary_cross_entropy_with_logits(predictions, misleading_labels)

        self.batch_metrics.update(**batch_metrics)

runner = CustomRunner()

# One optimizer callback per network; each metric_key names the loss that
# CustomRunner._handle_batch stores in batch_metrics.
generator_step = dl.OptimizerCallback(
    metric_key="loss_generator",
    optimizer_key="generator",
)
discriminator_step = dl.OptimizerCallback(
    metric_key="loss_discriminator",
    optimizer_key="discriminator",
)

# The generator callback stays first in the list, as in the original call.
runner.train(
    model=model,
    optimizer=optimizer,
    loaders=loaders,
    callbacks=[generator_step, discriminator_step],
    main_metric="loss_generator",
    num_epochs=3,
    verbose=True,
    logdir="./logs_gan2",
)


Attention, there is only one dataloader - train



1/3 * Epoch (train): 100% 1875/1875 [00:52<00:00, 35.73it/s, loss_discriminator=0.804, loss_generator=0.332]
[2020-12-17 20:43:53,820] 
1/3 * Epoch 1 (_base): lr/discriminator=0.0003 | lr/generator=0.0003 | momentum/discriminator=0.5000 | momentum/generator=0.5000
1/3 * Epoch 1 (train): loss_discriminator=0.8221 | loss_generator=0.3164
2/3 * Epoch (train): 100% 1875/1875 [00:52<00:00, 35.57it/s, loss_discriminator=0.762, loss_generator=0.335]
[2020-12-17 20:44:46,638] 
2/3 * Epoch 2 (_base): lr/discriminator=0.0003 | lr/generator=0.0003 | momentum/discriminator=0.5000 | momentum/generator=0.5000
2/3 * Epoch 2 (train): loss_discriminator=0.8335 | loss_generator=0.3106
3/3 * Epoch (train): 100% 1875/1875 [00:52<00:00, 35.68it/s, loss_discriminator=0.777, loss_generator=0.339]
[2020-12-17 20:45:39,303] 
3/3 * Epoch 3 (_base): lr/discriminator=0.0003 | lr/generator=0.0003 | momentum/discriminator=0.5000 | momentum/generator=0.5000
3/3 * Epoch 3 (train): loss_discriminator=0.8257 | loss_gen