In [0]:
#@title Prepare the environment and download necessary files.
# download needed libraries and files
!pip install comet_ml;
!wget https://raw.githubusercontent.com/egebeyazit/infogan/master/discriminator.py
!wget https://raw.githubusercontent.com/egebeyazit/infogan/master/generator.py
!wget https://raw.githubusercontent.com/egebeyazit/infogan/master/params.py
!wget https://raw.githubusercontent.com/egebeyazit/infogan/master/utils.py

from comet_ml import Experiment
import os
import numpy as np
import itertools
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch
import params
import utils

# Reproducibility: fix the torch and numpy seeds and ask cuDNN for
# deterministic kernels (autotuning benchmark mode is switched off).
torch.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Create the output folders for sampled images (no error if they already exist).
for sample_dir in (
    "./images/static/",
    "./images/generator1/",  # varied c, generator 1
    "./images/generator2/",  # varied c, generator 2
    "./images/static1/",     # static images by generator 1
    "./images/static2/",     # presumably static images by generator 2 -- TODO confirm
):
    os.makedirs(sample_dir, exist_ok=True)

# Hyper-parameters / options object provided by params.py.
opt = params.opt

# Pick CUDA tensor constructors when utils reports that a GPU should be used,
# otherwise fall back to the CPU tensor types.
cuda = utils.cuda
if cuda:
    FloatTensor, LongTensor = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    FloatTensor, LongTensor = torch.FloatTensor, torch.LongTensor

# Build the MNIST data loader (downloads the dataset on first use).
dataloader = utils.get_MNIST_loader()

'''
To keep the Colab session from disconnecting, open the browser console (Ctrl+Shift+I) and run:

function ClickConnect(){console.log("Working");document.querySelector("colab-toolbar-button#connect").click()}setInterval(ClickConnect,60000)

Don't close the console until "Working" appears as output in the console window. The script then clicks the connect button every 60 seconds, which prevents the session from disconnecting.
'''

Collecting comet_ml
[?25l  Downloading https://files.pythonhosted.org/packages/99/c6/fac88f43f2aa61a09fee4ffb769c73fe93fe7de75764246e70967d31da09/comet_ml-3.0.2-py3-none-any.whl (170kB)
[K     |████████████████████████████████| 174kB 5.2MB/s 
[?25hCollecting comet-git-pure>=0.19.11
[?25l  Downloading https://files.pythonhosted.org/packages/fa/91/b191ae375380332f82aaa83a41c45844ee1809198085cd267fbcb95cce86/comet_git_pure-0.19.14-py3-none-any.whl (401kB)
[K     |████████████████████████████████| 409kB 74.1MB/s 
[?25hCollecting websocket-client>=0.55.0
[?25l  Downloading https://files.pythonhosted.org/packages/4c/5f/f61b420143ed1c8dc69f9eaec5ff1ac36109d52c80de49d66e0c36c3dfdf/websocket_client-0.57.0-py2.py3-none-any.whl (200kB)
[K     |████████████████████████████████| 204kB 76.4MB/s 
[?25hCollecting everett[ini]>=1.0.1; python_version >= "3.0"
  Downloading https://files.pythonhosted.org/packages/12/34/de70a3d913411e40ce84966f085b5da0c6df741e28c86721114dd290aaa0/everett-1.0.2-py

0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


9920512it [00:01, 9728837.18it/s]                            


Extracting ../../data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to ../../data/mnist/MNIST/raw


  0%|          | 0/28881 [00:00<?, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


32768it [00:00, 136875.30it/s]           
  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../../data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ../../data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


1654784it [00:00, 2331802.07it/s]                            
0it [00:00, ?it/s]

Extracting ../../data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ../../data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


8192it [00:00, 51878.95it/s]            

Extracting ../../data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../../data/mnist/MNIST/raw
Processing...
Done!





In [0]:
# Two generators are trained against one discriminator. Only the generator is
# kept from the first init_GAN() call; the second call supplies the
# discriminator and the three loss objects used during training.
generator1, _, _, _, _ = utils.init_GAN()
generator2, discriminator, adversarial_loss, categorical_loss, continuous_loss = utils.init_GAN()

# Every optimizer is Adam with the same learning rate and betas taken from opt.
adam_settings = dict(lr=opt.lr, betas=(opt.b1, opt.b2))

# Adversarial updates: one optimizer per generator, one for the discriminator.
optimizer_G1 = torch.optim.Adam(generator1.parameters(), **adam_settings)
optimizer_G2 = torch.optim.Adam(generator2.parameters(), **adam_settings)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_settings)

# Information-loss updates: each optimizer covers one generator together with
# the (shared) discriminator parameters.
optimizer_info1 = torch.optim.Adam(
    itertools.chain(generator1.parameters(), discriminator.parameters()), **adam_settings
)
optimizer_info2 = torch.optim.Adam(
    itertools.chain(generator2.parameters(), discriminator.parameters()), **adam_settings
)

# Fixed generator inputs returned by utils.get_static_gen_input().
static_z, static_label, static_code = utils.get_static_gen_input()

In [0]:
# Train two generators against one shared discriminator and log to Comet.
#
# The Comet API key is read from the COMET_API_KEY environment variable so that
# the credential is not stored in the notebook.
# NOTE(review): the key that was previously hardcoded in this cell should be
# treated as leaked and revoked.
experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"), project_name="bn-infogan", workspace="egebeyazit93")
experiment.log_parameters(vars(opt))

n_batches = len(dataloader)

#  Training
for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        batch_size = imgs.shape[0]
        # Global step counter. It is used both for metric logging and for the
        # sampling interval, so logged steps increase monotonically across epochs
        # (the former "(epoch + 1) * i" restarted at 0 every epoch).
        batches_done = epoch * n_batches + i

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)
        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        # NOTE(review): the converted real labels are not used anywhere below in this cell.
        labels = utils.to_categorical(labels.numpy(), num_columns=opt.n_classes)

        # -----------------
        #  Train Generator 1
        # -----------------
        optimizer_G1.zero_grad()
        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
        label_input = utils.to_categorical(np.random.randint(0, opt.n_classes, batch_size), num_columns=opt.n_classes)
        code_input = Variable(FloatTensor(np.random.uniform(-1, 1, (batch_size, opt.code_dim))))
        # Generate a batch of images
        gen_imgs1 = generator1(z, label_input, code_input)
        # Loss measures generator's ability to fool the discriminator
        validity, _, _ = discriminator(gen_imgs1)
        g_loss1 = adversarial_loss(validity, valid)
        experiment.log_metric("g1_loss", g_loss1.item(), step=batches_done)
        g_loss1.backward()
        optimizer_G1.step()

        # -----------------
        #  Train Generator 2
        # -----------------
        optimizer_G2.zero_grad()
        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
        label_input = utils.to_categorical(np.random.randint(0, opt.n_classes, batch_size), num_columns=opt.n_classes)
        code_input = Variable(FloatTensor(np.random.uniform(-1, 1, (batch_size, opt.code_dim))))
        # Generate a batch of images
        gen_imgs2 = generator2(z, label_input, code_input)
        # Loss measures generator's ability to fool the discriminator
        validity, _, _ = discriminator(gen_imgs2)
        g_loss2 = adversarial_loss(validity, valid)
        experiment.log_metric("g2_loss", g_loss2.item(), step=batches_done)
        g_loss2.backward()
        optimizer_G2.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Loss for real images
        real_pred, _, _ = discriminator(real_imgs)
        d_real_loss = adversarial_loss(real_pred, valid)

        # Loss for fake images: average over the two generators' (detached) batches
        fake_pred, _, _ = discriminator(gen_imgs1.detach())
        d_fake_loss1 = adversarial_loss(fake_pred, fake)
        fake_pred, _, _ = discriminator(gen_imgs2.detach())
        d_fake_loss2 = adversarial_loss(fake_pred, fake)
        d_fake_loss = (d_fake_loss1 + d_fake_loss2) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2
        experiment.log_metric("d_loss", d_loss.item(), step=batches_done)
        d_loss.backward()
        optimizer_D.step()

        # ------------------
        # Information Loss
        # ------------------
        # Sample labels
        sampled_labels = np.random.randint(0, opt.n_classes, batch_size)
        # Ground truth labels
        gt_labels = Variable(LongTensor(sampled_labels), requires_grad=False)
        # Sample noise, labels and code as generator input (shared by both generators)
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
        label_input = utils.to_categorical(sampled_labels, num_columns=opt.n_classes)
        code_input = Variable(FloatTensor(np.random.uniform(-1, 1, (batch_size, opt.code_dim))))

        # Information Loss 1
        optimizer_info1.zero_grad()
        gen_imgs1 = generator1(z, label_input, code_input)
        _, pred_label1, pred_code1 = discriminator(gen_imgs1)
        info_loss1 = params.lambda_cat * categorical_loss(pred_label1, gt_labels) + params.lambda_con * continuous_loss(pred_code1, code_input)

        # Information Loss 2
        optimizer_info2.zero_grad()
        gen_imgs2 = generator2(z, label_input, code_input)
        _, pred_label2, pred_code2 = discriminator(gen_imgs2)
        info_loss2 = params.lambda_cat * categorical_loss(pred_label2, gt_labels) + params.lambda_con * continuous_loss(pred_code2, code_input)

        # Hook for experimenting with a contrastive term between the two generators
        # (currently disabled, so these are plain aliases).
        info_loss1_1 = info_loss1 #- info_loss2
        info_loss2_2 = info_loss2 #- info_loss1

        experiment.log_metric("info_loss1", info_loss1_1.item(), step=batches_done)
        experiment.log_metric("info_loss2", info_loss2_2.item(), step=batches_done)

        info_loss1_1.backward(retain_graph=True)
        info_loss2_2.backward()

        # NOTE(review): both info optimizers include the discriminator parameters, so
        # the discriminator is stepped twice here on the accumulated gradients of both
        # info losses -- confirm that this is intended.
        optimizer_info1.step()
        optimizer_info2.step()

        # --------------
        # Log Progress
        # --------------
        # Print once per epoch, on the last batch.
        if i == n_batches - 1:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G1 loss: %f] [G2 loss: %f] [info loss: %f]"
                % (epoch, opt.n_epochs, i, n_batches, d_loss.item(), g_loss1.item(), g_loss2.item(), info_loss1.item())
            )
        if batches_done % opt.sample_interval == 0:
            utils.sample_image(generator1, generator2, n_row=10, batches_done=batches_done)

# Save the final weights. The target directory is created first; without it
# torch.save fails with FileNotFoundError after the whole training run.
os.makedirs('./trained_models', exist_ok=True)
model_path = './trained_models/model_final_{}'.format(opt.n_epochs)
torch.save({'generator1': generator1.state_dict(),
            'generator2': generator2.state_dict(),
            'discriminator': discriminator.state_dict(),
            'parameters': opt}, model_path)

# Upload the checkpoint, the top-level files of the working directory and the
# sampled images to the Comet experiment.
experiment.log_asset(model_path)
experiment.log_asset_folder('.', step=None, log_file_name=False, recursive=False)
experiment.log_asset_folder('./images', step=None, log_file_name=True, recursive=True)

# Explicitly close the experiment; in a notebook the kernel keeps running, so the
# run is otherwise never marked as finished.
experiment.end()

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/egebeyazit93/bn-infogan/8e9a1e2fb7cd4517ad61a3ccc7a11076

  input = module(input)


[Epoch 0/200] [Batch 937/938] [D loss: 0.227171] [G1 loss: 0.260906] [G2 loss: 0.281181] [info loss: 1.469292]
[Epoch 1/200] [Batch 937/938] [D loss: 0.255146] [G1 loss: 0.250885] [G2 loss: 0.311090] [info loss: 1.470873]
[Epoch 2/200] [Batch 937/938] [D loss: 0.215086] [G1 loss: 0.277747] [G2 loss: 0.339698] [info loss: 1.493131]
[Epoch 3/200] [Batch 937/938] [D loss: 0.236408] [G1 loss: 0.207537] [G2 loss: 0.268218] [info loss: 1.478878]
[Epoch 4/200] [Batch 937/938] [D loss: 0.236159] [G1 loss: 0.321709] [G2 loss: 0.309771] [info loss: 1.472713]
[Epoch 5/200] [Batch 937/938] [D loss: 0.195843] [G1 loss: 0.330316] [G2 loss: 0.291886] [info loss: 1.481140]
[Epoch 6/200] [Batch 937/938] [D loss: 0.297802] [G1 loss: 0.281432] [G2 loss: 0.227950] [info loss: 1.510706]
[Epoch 7/200] [Batch 937/938] [D loss: 0.208144] [G1 loss: 0.317507] [G2 loss: 0.247291] [info loss: 1.469072]
[Epoch 8/200] [Batch 937/938] [D loss: 0.262268] [G1 loss: 0.342625] [G2 loss: 0.283586] [info loss: 1.502209]
[

FileNotFoundError: ignored

In [0]:
# Fail with the setup hint below (instead of a bare KeyError) when the runtime
# does not expose a TPU address. Relies on `os` imported by the setup cell.
# NOTE(review): the rest of this notebook selects CUDA tensors, so this TPU check
# looks like a leftover -- confirm whether it is still needed.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

NameError: ignored