# Categorical GAN with pytorch

This notebook contains code for a Categorical GAN (CatGAN), with several alternative networks for both the **generator** and the **discriminator**.

- The code is written to run in *Google Colab*, with the relevant files located on *Google Drive*.

In [None]:
from __future__ import print_function
%matplotlib inline
import os
import random
import torch.nn as nn
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision.models as models
import matplotlib.pyplot as plt
import matplotlib.animation as animation

import numpy as np
import random

from functools import partial
from torch.autograd import grad as torch_grad
from PIL import Image

import math
from math import floor, log2
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset

from google.colab import drive

# Set random seed for reproducibility.
# Only Python's `random` module and PyTorch are seeded here; NumPy's global
# RNG is not seeded by this cell.
manualSeed = 1
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
# Last expression of the cell: the returned torch Generator object is echoed
# as the cell output.
torch.manual_seed(manualSeed)

Random Seed:  1


<torch._C.Generator at 0x7fc7411c71c8>

Mount the Google Drive folder.
An authorization code is required when:
- Mounting for the first time
- Re-mounting after prolonged inactivity

In [None]:
# Mount Google Drive under /content/drive/; the dataset paths used in this
# notebook live below that mount point.
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
def data_sel(option):
    """Return the Google Drive root directory of the selected dataset.

    Args:
        option: Name of the dataset, either ``'KDEF'`` or ``'MMI'``.

    Returns:
        The root directory of the chosen dataset as a string.

    Raises:
        ValueError: If ``option`` is not one of the known dataset names.
            (Previously this case printed a message and then failed with an
            UnboundLocalError on the return statement.)
    """
    if option == 'KDEF':
        # Root directory for KDEF dataset
        return "/content/drive/My Drive/Colab Notebooks/KDEF"
    if option == 'MMI':
        # Root directory for MMI_sel dataset
        return "/content/drive/My Drive/Colab Notebooks/MMI_selected"
    raise ValueError(
        "No dataset selected: expected 'KDEF' or 'MMI', got %r" % (option,)
    )

In [None]:
# Root directory of the dataset to train on (see data_sel for accepted names)
dataroot = data_sel('KDEF')

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 32

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Size of z latent vector (i.e. size of generator input)
nz = 128

# Number of training epochs
num_epochs = 50

# Learning rate for optimizers
lr = 1e-4

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

# Output width of the final linear layer of the discriminator networks defined
# in this notebook (each ends in a softmax over K values)
K = 6

Generate the train and test sets from a random split of the selected dataset folder.
- Train set - 70% of the dataset
- Test set - 30% of the dataset

The images are resized so that training fits within the Colab environment.

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Wrap an indexable collection of ``(x, y)`` pairs with an optional transform.

    The transform, when given and truthy, is applied to ``x`` only; ``y`` is
    returned untouched.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        """Number of samples in the wrapped collection."""
        return len(self.subset)

    def __getitem__(self, index):
        """Return the (possibly transformed) sample and its label."""
        sample, label = self.subset[index]
        if not self.transform:
            return sample, label
        return self.transform(sample), label

In [None]:
# We can use an image folder dataset the way we have it setup.
# Create the dataset

def datagen(dataroot):
    """Build the train and test data loaders for an image-folder dataset.

    The dataset is split at random into 70% train / 30% test. Each split is
    wrapped in ``MyDataset`` so that it carries its own preprocessing.

    Args:
        dataroot: Root directory passed to ``dset.ImageFolder``.

    Returns:
        ``(train_loader, test_loader)``. The train loader uses the module-level
        ``batch_size``; the test loader returns the whole test split as a
        single batch. Both shuffle and both use ``workers`` worker processes.
    """
    dataset = dset.ImageFolder(root=dataroot)

    # The two pipelines are currently identical (resize to 60x48, convert to
    # tensor, normalise each channel with mean 0.5 / std 0.5). They are kept
    # separate so that train-only augmentation can be added to the first one.
    # The original chained assignment (`transform1 = transform = ...`) leaked an
    # unused local and has been removed.
    transform1 = transforms.Compose([
        transforms.Resize((60, 48)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    transform2 = transforms.Compose([
        transforms.Resize((60, 48)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # 30% of the samples (rounded down) go to the test split.
    test_len = int(len(dataset) * 0.3)
    train_len = len(dataset) - test_len
    train_set, test_set = torch.utils.data.random_split(dataset, [train_len, test_len])

    train_set = MyDataset(train_set, transform1)
    test_set = MyDataset(test_set, transform2)

    # Create the dataloaders
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                               shuffle=True, num_workers=workers)
    # NOTE: batch_size=test_len means a dataset with fewer than 4 images
    # (test_len == 0) is rejected by DataLoader.
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=test_len,
                                              shuffle=True, num_workers=workers)

    return train_loader, test_loader
# # Plot some training images
# real_batch = next(iter(train_loader))
# plt.figure(figsize=(8,8))
# plt.axis("off")
# plt.title("Training Images")
# plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

In [None]:
# def weights_init(m):
#     classname = m.__class__.__name__
#     if classname.find('Conv') != -1:
#         nn.init.normal_(m.weight.data, 0.0, 0.02)
#     elif classname.find('BatchNorm') != -1:
#         nn.init.normal_(m.weight.data, 1.0, 0.02)
#         nn.init.constant_(m.bias.data, 0.0)
#     elif classname.find('Linear') != -1:
#       nn.init.normal_(m.weight.data,0.0,0.01)
#       nn.init.constant_(m.bias.data,0.0) 

### Generator and Discriminator

The initial networks are from the original CatGAN paper, except that the dimension of the input data is different.

Two other generator networks have been set up
- **Variational AutoEncoder**
- **StyleGAN 2**

Four other discriminator networks have been set up
 - **Alexnet**
 - **VGG 16**
 - **ResNet 50**
 - **GoogLeNet**

In [None]:
class ClippedReLU(nn.Module):
    """ReLU activation whose output is additionally capped at 10."""

    def __init__(self):
        super().__init__()
        # In-place rectification: the tensor handed to forward() is modified.
        self.clipped = nn.ReLU(inplace=True)

    def forward(self, x):
        """Rectify ``x`` in place, then return a copy clamped to at most 10."""
        rectified = self.clipped(x)
        return rectified.clamp(max=10)

class inception(nn.Module):
    """Inception module with four parallel branches concatenated on channels.

    Branches: 1x1 conv; 1x1 reduce -> 3x3 conv; 1x1 reduce -> 5x5 conv;
    3x3 max-pool -> 1x1 conv. Every convolution is followed by batch norm and
    the selected activation. The output has
    ``filter1 + filter3 + filter5 + filterpool`` channels and the same spatial
    size as the input.

    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'. For any
    other value no activation is registered and ``forward`` will fail.
    """

    def __init__(self, in_channel, acti, filter1, filter3r, filter3, filter5r, filter5, filterpool):
        super().__init__()
        # Factories are lazy so that only the selected activation is built.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        if acti in activations:
            self.acti = activations[acti]()

        # 1x1 branch
        self.conv1 = nn.Conv2d(in_channel, filter1, 1, 1)
        self.norm1 = nn.BatchNorm2d(filter1)
        # 3x3 branch (1x1 reduction first)
        self.conv3r = nn.Conv2d(in_channel, filter3r, 1, 1)
        self.norm3r = nn.BatchNorm2d(filter3r)
        self.conv3 = nn.Conv2d(filter3r, filter3, 3, 1, 1)
        self.norm3 = nn.BatchNorm2d(filter3)
        # 5x5 branch (1x1 reduction first)
        self.conv5r = nn.Conv2d(in_channel, filter5r, 1, 1)
        self.norm5r = nn.BatchNorm2d(filter5r)
        self.conv5 = nn.Conv2d(filter5r, filter5, 5, 1, 2)
        self.norm5 = nn.BatchNorm2d(filter5)
        # pooling branch
        self.pool = nn.MaxPool2d(3, 1, 1)
        self.convpool = nn.Conv2d(in_channel, filterpool, 1, 1)
        self.normpool = nn.BatchNorm2d(filterpool)

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate them along dim 1."""
        branch_1x1 = self.acti(self.norm1(self.conv1(x)))

        reduced3 = self.acti(self.norm3r(self.conv3r(x)))
        branch_3x3 = self.acti(self.norm3(self.conv3(reduced3)))

        reduced5 = self.acti(self.norm5r(self.conv5r(x)))
        branch_5x5 = self.acti(self.norm5(self.conv5(reduced5)))

        pooled = self.pool(x)
        branch_pool = self.acti(self.normpool(self.convpool(pooled)))

        return torch.cat((branch_1x1, branch_3x3, branch_5x5, branch_pool), 1)


class identity_block(nn.Module):
    """Residual bottleneck block with an identity shortcut.

    Main path: 1x1 conv -> 3x3 conv -> 1x1 conv, each followed by batch norm
    and the activation. The input is then added to the result and the
    activation is applied once more. ``filter3`` must equal ``in_channel`` for
    the addition to be valid.

    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'. For any
    other value no activation is registered and ``forward`` will fail.
    """

    def __init__(self, in_channel, acti, filter1, filter2, filter3):
        super().__init__()
        # Factories are lazy so that only the selected activation is built.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        if acti in activations:
            self.acti = activations[acti]()

        self.conv1 = nn.Conv2d(in_channel, filter1, 1, 1)
        self.norm1 = nn.BatchNorm2d(filter1)
        self.conv2 = nn.Conv2d(filter1, filter2, 3, 1, 1)
        self.norm2 = nn.BatchNorm2d(filter2)
        self.conv3 = nn.Conv2d(filter2, filter3, 1, 1)
        self.norm3 = nn.BatchNorm2d(filter3)

    def forward(self, x):
        """Return ``acti(main_path(x) + x)``."""
        residual = x
        for conv, norm in ((self.conv1, self.norm1),
                           (self.conv2, self.norm2),
                           (self.conv3, self.norm3)):
            residual = self.acti(norm(conv(residual)))
        return self.acti(torch.add(residual, x))


class conv_block(nn.Module):
    """Residual bottleneck block with a projection (1x1 conv) shortcut.

    Main path: strided 1x1 conv -> 3x3 conv -> 1x1 conv, each followed by batch
    norm and the activation. Shortcut: strided 1x1 conv + batch norm, without
    activation. The two are added and the activation is applied to the sum.

    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'. For any
    other value no activation is registered and ``forward`` will fail.
    """

    def __init__(self, in_channel, acti, filter1, filter2, filter3, stride = (2, 2)):
        super().__init__()
        # Factories are lazy so that only the selected activation is built.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        if acti in activations:
            self.acti = activations[acti]()

        # Projection shortcut: matches channel count and spatial stride.
        self.convshort = nn.Conv2d(in_channel, filter3, 1, stride)
        self.normshort = nn.BatchNorm2d(filter3)

        # Bottleneck main path.
        self.conv1 = nn.Conv2d(in_channel, filter1, 1, stride)
        self.norm1 = nn.BatchNorm2d(filter1)
        self.conv2 = nn.Conv2d(filter1, filter2, 3, 1, 1)
        self.norm2 = nn.BatchNorm2d(filter2)
        self.conv3 = nn.Conv2d(filter2, filter3, 1, 1)
        self.norm3 = nn.BatchNorm2d(filter3)

    def forward(self, x):
        """Return ``acti(shortcut(x) + main_path(x))``."""
        shortcut = self.normshort(self.convshort(x))

        main = self.acti(self.norm1(self.conv1(x)))
        main = self.acti(self.norm2(self.conv2(main)))
        main = self.acti(self.norm3(self.conv3(main)))

        return self.acti(torch.add(shortcut, main))

In [None]:
class EqualLinear(nn.Module):
    """Linear layer whose weight and bias are scaled by ``lr_mul`` at call time.

    Args:
        in_dim: Size of each input feature vector.
        out_dim: Size of each output feature vector.
        lr_mul: Multiplier applied to both weight and bias in ``forward``.
        bias: When False the layer has no bias term.
    """

    def __init__(self, in_dim, out_dim, lr_mul = 1, bias = True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))
        else:
            # Register the name so that forward() can test it; previously
            # bias=False left `self.bias` undefined and forward() raised
            # AttributeError.
            self.register_parameter('bias', None)

        self.lr_mul = lr_mul

    def forward(self, input):
        """Apply the scaled affine map to ``input``."""
        scaled_bias = None if self.bias is None else self.bias * self.lr_mul
        return F.linear(input, self.weight * self.lr_mul, bias=scaled_bias)


class Mapping(nn.Module):
    """Stack of ``depth`` EqualLinear layers, each followed by an activation.

    The input is L2-normalised along dim 1 before entering the stack. A single
    activation module instance is shared by every layer.

    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'.
    """

    def __init__(self, emb, depth, lr_mul = 0.1, acti = 'leaky'):
        super().__init__()
        # Factories are lazy so that only the selected activation is built.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        # An unrecognised value is passed through unchanged (and is then
        # rejected by nn.Sequential if any layer is built).
        activation = activations[acti]() if acti in activations else acti

        stack = []
        for _ in range(depth):
            stack.append(EqualLinear(emb, emb, lr_mul))
            stack.append(activation)

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Normalise ``x`` along dim 1 and pass it through the layer stack."""
        return self.net(F.normalize(x, dim=1))


class RGBBlock(nn.Module):
    """Project a feature map to RGB(A) with a style-modulated 1x1 convolution.

    The style vector is mapped to ``input_channel`` modulation values, the
    projection of a previous block (if any) is added, and the result is
    optionally upsampled by 2 with bilinear interpolation.
    """

    def __init__(self, latent_dim, input_channel, upsample, rgba = False):
        super().__init__()
        self.input_channel = input_channel
        self.to_style = nn.Linear(latent_dim, input_channel)

        # 4 output channels with alpha, 3 without.
        out_filters = 4 if rgba else 3
        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)

        if upsample:
            self.upsample = nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False)
        else:
            self.upsample = None

    def forward(self, x, prev_rgb, istyle):
        """Return the image-space projection of ``x``, accumulated onto ``prev_rgb``."""
        # Values are unused; the unpacking still rejects inputs that are not 4-D.
        _batch, _chan, _height, _width = x.shape

        image = self.conv(x, self.to_style(istyle))

        if prev_rgb is not None:
            image = image + prev_rgb

        if self.upsample is None:
            return image
        return self.upsample(image)

class Conv2DMod(nn.Module):
    """2-D convolution whose kernel is modulated per sample by a style vector.

    In the default mode, ``forward(x, y)`` scales the kernel's input channels
    by ``y + 1`` for each sample, optionally demodulates (normalises each
    output filter to unit L2 norm), and applies the per-sample kernels with a
    single grouped convolution.

    With ``basic=True`` the layer is a plain convolution with the learned
    ``self.weight`` and ``y`` is ignored. The previous implementation drew a
    brand-new random kernel on every call in this mode, so the layer could not
    be trained and gave different outputs for the same input.

    Args:
        in_chan: Number of input channels.
        out_chan: Number of output channels.
        kernel: Square kernel size.
        demod: Whether to demodulate the modulated kernels.
        stride, dilation: Only used to compute the padding; NOTE(review) they
            are not forwarded to the convolution itself, so values other than
            1 are effectively unsupported.
        basic: Use the plain, unmodulated convolution path.
        acti, **kwargs: Accepted for interface compatibility; unused (the
            layer applies no activation).
    """

    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, basic =False, acti= 'leaky', **kwargs):
        super().__init__()

        self.filters = out_chan
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.basic = basic
        self.chan = in_chan

        self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in')

    def _get_same_padding(self, size, kernel, dilation, stride):
        """Padding that keeps a ``size``-long axis unchanged by the convolution."""
        return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2

    def forward(self, x, y = None):
        """Convolve ``x`` (4-D) using style ``y`` (batch x in_chan) unless ``basic``."""
        b, c, h, w = x.shape
        padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)

        if self.basic:
            # Plain convolution with the registered (trainable) kernel.
            x = F.conv2d(x, self.weight, padding=padding)
            return x.reshape(-1, self.filters, h, w)

        # Per-sample kernels: (b, out, in, k, k), input channels scaled by y + 1.
        w1 = y[:, None, :, None, None]
        w2 = self.weight[None, :, :, :, :]
        weights = w2 * (w1 + 1)

        if self.demod:
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + 1e-6)
            weights = weights * d

        # Fold the batch into the channel axis so that one grouped convolution
        # (groups = batch size) applies each sample's own kernels.
        x = x.reshape(1, -1, h, w)

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.filters, *ws)

        x = F.conv2d(x, weights, padding=padding, groups=b)

        x = x.reshape(-1, self.filters, h, w)

        return x

class GeneratorBlock(nn.Module):
    """One synthesis stage: optional x2 upsample, two modulated 3x3 convs, RGB head.

    Each convolution is modulated by a style derived from ``istyle`` and
    followed by additive per-pixel noise and the activation. When
    ``architecture == 'resnet'`` a 1x1-convolved copy of the (upsampled) block
    input is added and the sum is scaled by ``1 / sqrt(2)``.

    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'. For any
    other value no activation is registered and ``forward`` will fail.
    """

    def __init__(self, latent_dim, input_channels, filters, architecture, acti= 'leaky', upsample = True, upsample_rgb = True, rgba = False):
        super().__init__()
        if upsample:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        else:
            self.upsample = None
        self.architecture = architecture

        # First modulated convolution with its style and noise projections.
        self.to_style1 = nn.Linear(latent_dim, input_channels)
        self.to_noise1 = nn.Linear(1, filters)
        self.conv1 = Conv2DMod(input_channels, filters, 3)

        # Second modulated convolution with its style and noise projections.
        self.to_style2 = nn.Linear(latent_dim, filters)
        self.to_noise2 = nn.Linear(1, filters)
        self.conv2 = Conv2DMod(filters, filters, 3)

        # 1x1 convolution on the residual branch.
        self.resconv = Conv2DMod(input_channels, filters, 1, basic = True)

        # Factories are lazy so that only the selected activation is built.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        if acti in activations:
            self.activation = activations[acti]()

        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)

    def forward(self, x, prev_rgb, istyle, inoise):
        """Return ``(features, rgb)`` for this stage."""
        skip = x

        if self.upsample is not None:
            x = self.upsample(x)

        # Crop the noise to the current resolution; after the linear layer the
        # permutation moves the feature axis to dim 1 and swaps dims 1 and 2.
        height, width = x.shape[2], x.shape[3]
        inoise = inoise[:, :width, :height, :]
        noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
        noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))

        x = self.conv1(x, self.to_style1(istyle))
        x = self.activation(x + noise1)

        x = self.conv2(x, self.to_style2(istyle))
        x = self.activation(x + noise2)

        if self.architecture == 'resnet':
            if self.upsample is not None:
                skip = self.upsample(skip)
            skip = self.resconv(skip)
            x = (x + skip) * (1 / np.sqrt(2))

        return x, self.to_rgb(x, prev_rgb, istyle)

In [None]:
class StyleGAN(nn.Module):
    """StyleGAN2-style generator built from ``GeneratorBlock`` stages.

    The network starts from a learned 15x12 constant (or, with
    ``no_const=True``, from a transposed convolution of the mean style) and
    runs ``num_layers`` generator blocks; every block except the first
    upsamples by 2.

    Args:
        image_size: Stored on the instance; not otherwise used by this class.
        style_depth: Number of layers in the mapping network.
        latent_dim: Size of the latent / style vectors.
        network_capacity: Base multiplier for the per-block channel counts.
        transparent: Produce RGBA instead of RGB output.
        attn_layers: Unused; accepted for interface compatibility.
        no_const: Derive the initial feature map from the styles instead of
            using a learned constant.
        fmap_max: Upper bound on the channel count of any block.
        acti: Activation name passed to the mapping network and the blocks.
    """

    def __init__(self, image_size, style_depth = 8, latent_dim = 512, network_capacity = 16, transparent = False, attn_layers = None, no_const = False, fmap_max = 512, acti = 'leaky'):
        super(StyleGAN, self).__init__()
        self.image_size = image_size
        self.latent_dim = latent_dim
        # Fixed at 3 blocks regardless of image_size.
        self.num_layers = int(log2(16) - 1)

        # Channel counts per block, widest first, capped at fmap_max.
        filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1]
        filters = [min(fmap_max, f) for f in filters]
        init_channels = filters[0]
        filters = [init_channels, *filters]

        in_out_pairs = zip(filters[:-1], filters[1:])
        self.no_const = no_const

        if no_const:
            # forward() feeds this layer the mean style, which has latent_dim
            # channels; the previous in_channels (the global nz) failed
            # whenever latent_dim != nz.
            self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False)
        else:
            self.initial_block = nn.Parameter(torch.randn((1, init_channels, 15, 12)))

        self.initial_conv = nn.Conv2d(filters[0], filters[0], 3, padding=1)
        self.blocks = nn.ModuleList([])

        self.mapp = Mapping(latent_dim, style_depth, lr_mul = 0.1, acti = acti)

        for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
            not_first = ind != 0
            not_last = ind != (self.num_layers - 1)

            block = GeneratorBlock(
                latent_dim,
                in_chan,
                out_chan,
                'resnet',
                acti,
                upsample = not_first,
                upsample_rgb = not_last,
                rgba = transparent
            )
            self.blocks.append(block)

    def latent_to_w(self, Mapping, latent_descr):
        """Apply the mapping callable to the latent of each ``(z, num_layers)`` pair."""
        return [(Mapping(z), num_layers) for z, num_layers in latent_descr]

    def styles_def_to_tensor(self, styles_def):
        """Repeat each style ``n`` times along a new dim 1 and concatenate the results."""
        return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)

    def forward(self, styles, input_noise):
        """Generate images from latent descriptions and a noise tensor.

        Args:
            styles: Iterable of ``(z, num_layers)`` pairs; each ``z`` is passed
                through the mapping network and used for ``num_layers`` blocks.
            input_noise: Noise tensor handed unchanged to every block.

        Returns:
            ``(style, rgb)``: the per-layer style tensor and the RGB(A) output
            of the last block.
        """
        mapped = self.latent_to_w(self.mapp, styles)
        style = self.styles_def_to_tensor(mapped)

        batch_size = style.shape[0]

        if self.no_const:
            avg_style = style.mean(dim=1)[:, :, None, None]
            x = self.to_initial_block(avg_style)
        else:
            x = self.initial_block.expand(batch_size, -1, -1, -1)

        rgb = None
        x = self.initial_conv(x)
        # One style per block; zip stops at the shorter of the two sequences.
        for block_style, block in zip(style.transpose(0, 1), self.blocks):
            x, rgb = block(x, rgb, block_style, input_noise)

        return style, rgb


In [None]:
class VAE(nn.Module):
    """Variational auto-encoder: convolutional encoder plus ``Generator`` decoder.

    The encoder applies four 5x5 conv + batch-norm + activation stages to a
    3-channel image, halving the resolution after the second and fourth stage,
    then maps the flattened 192x15x12 features to the mean and log-variance of
    an ``nz``-dimensional latent.

    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'. For any
    other value no activation is registered and ``forward`` will fail.
    """

    def __init__(self, acti = 'leaky'):
        super().__init__()
        # Factories are lazy so that only the selected activation is built.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        if acti in activations:
            self.acti = activations[acti]()

        # Encoder: the input is a 3-channel image.
        self.conv1 = nn.Conv2d(3, 96, 5, 1, 2, bias=False)
        self.bn1 = nn.BatchNorm2d(96)

        self.conv2 = nn.Conv2d(96, 96, 5, 1, 2, bias=False)
        self.bn2 = nn.BatchNorm2d(96)
        self.pool2 = nn.MaxPool2d(3, 2, 1)

        self.conv3 = nn.Conv2d(96, 192, 5, 1, 2, bias=False)
        self.bn3 = nn.BatchNorm2d(192)

        self.conv4 = nn.Conv2d(192, 192, 5, 1, 2, bias=False)
        self.bn4 = nn.BatchNorm2d(192)
        self.pool4 = nn.MaxPool2d(3, 2, 1)

        # Latent heads: mean and log-variance.
        self.mu = nn.Linear(15*12*192 , nz)
        self.sigma = nn.Linear(15*12*192 , nz)

        self.decoder = Generator(acti)

    def reparamatize(self, mu, logvar, batch):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` of shape (batch, nz)."""
        std = torch.exp(0.5*logvar)
        noise = torch.randn(batch, nz, device=device)
        return noise * std + mu

    def forward(self, x, batch):
        """Encode ``x``, sample a latent and decode it.

        Returns:
            ``(mu, logvar, out)`` where ``out`` is the decoder output.
        """
        feat = self.bn1(self.conv1(x))
        feat = self.acti(feat)

        feat = self.bn2(self.conv2(feat))
        feat = self.pool2(self.acti(feat))

        feat = self.bn3(self.conv3(feat))
        feat = self.acti(feat)

        feat = self.bn4(self.conv4(feat))
        feat = self.pool4(self.acti(feat))

        flat = feat.view(-1, 15*12*192)
        mu = self.mu(flat)
        logvar = self.sigma(flat)

        latent = self.reparamatize(mu, logvar, batch)
        return mu, logvar, self.decoder(latent)


In [None]:
class Generator(nn.Module):
    """Generator mapping a latent vector of size ``nz`` to a 3-channel image.

    A linear layer produces a 192x15x12 feature map. It is upsampled by 2,
    passed through two 5x5 convolutions, upsampled by 2 again and passed
    through two more 5x5 convolutions; the output goes through tanh.

    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'. For any
    other value no activation is registered and ``forward`` will fail.
    """

    def __init__(self, acti = 'leaky'):
        super().__init__()
        # Factories are lazy so that only the selected activation is built.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        if acti in activations:
            self.acti = activations[acti]()

        # Latent vector -> flattened 192 x 15 x 12 feature map.
        self.lin = nn.Linear(nz, 192*15*12)
        self.bn = nn.BatchNorm1d(192*15*12)
        self.up = nn.Upsample(scale_factor=2)

        self.conv1 = nn.Conv2d(192, 192, 5, 1, 2, bias=False)
        self.bn1 = nn.BatchNorm2d(192)
        self.conv2 = nn.Conv2d(192, 96, 5, 1, 2, bias=False)
        self.bn2 = nn.BatchNorm2d(96)
        self.conv3 = nn.Conv2d(96, 96, 5, 1, 2, bias=False)
        self.bn3 = nn.BatchNorm2d(96)
        self.conv4 = nn.Conv2d(96, 3, 5, 1, 2, bias=False)

        self.tanh = nn.Tanh()

    def forward(self, x):
        """Decode latent batch ``x`` into images with values in [-1, 1]."""
        hidden = self.acti(self.bn(self.lin(x)))
        fmap = self.up(hidden.view(-1, 192, 15, 12))

        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            fmap = self.acti(norm(conv(fmap)))

        fmap = self.up(fmap)
        fmap = self.acti(self.bn3(self.conv3(fmap)))
        return self.tanh(self.conv4(fmap))


In [None]:
class Discriminator(nn.Module):
    """Convolutional classifier ending in a softmax over ``K`` outputs.

    Nine conv + batch-norm + activation layers with two pooling steps reduce
    the input to 10 channels; the flattened 10x15x12 map feeds a linear layer
    followed by a softmax.

    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'. For any
    other value no activation is registered and ``forward`` will fail.
    """

    def __init__(self, acti = 'leaky'):
        super().__init__()
        # Factories are lazy so that only the selected activation is built.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        if acti in activations:
            self.acti = activations[acti]()

        # Stage 1: three convolutions at 96 channels, then 2x2 max-pool.
        self.conv1 = nn.Conv2d(3, 96, 5, 1, 2)
        self.norm1 = nn.BatchNorm2d(96)
        self.conv2 = nn.Conv2d(96, 96, 3, 1, 1)
        self.norm2 = nn.BatchNorm2d(96)
        self.conv3 = nn.Conv2d(96, 96, 3, 1, 1)
        self.norm3 = nn.BatchNorm2d(96)
        self.pool1 = nn.MaxPool2d(2, 2)
        # Stage 2: three convolutions at 192 channels, then strided 3x3 max-pool.
        self.conv4 = nn.Conv2d(96, 192, 3, 1, 1)
        self.norm4 = nn.BatchNorm2d(192)
        self.conv5 = nn.Conv2d(192, 192, 3, 1, 1)
        self.norm5 = nn.BatchNorm2d(192)
        self.conv6 = nn.Conv2d(192, 192, 3, 1, 1)
        self.norm6 = nn.BatchNorm2d(192)
        self.pool2 = nn.MaxPool2d(3, 2, 1)
        # Stage 3: one 3x3 and two 1x1 convolutions down to 10 channels.
        self.conv7 = nn.Conv2d(192, 192, 3, 1, 1)
        self.norm7 = nn.BatchNorm2d(192)
        self.conv8 = nn.Conv2d(192, 192, 1)
        self.norm8 = nn.BatchNorm2d(192)
        self.conv9 = nn.Conv2d(192, 10, 1)
        self.norm9 = nn.BatchNorm2d(10)
        # Classification head.
        self.softmax = nn.Softmax(dim=-1)
        self.fc = nn.Linear(10*15*12, K)

    def forward(self, x):
        """Return softmax probabilities of shape (batch, K)."""
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            x = self.acti(norm(conv(x)))
        x = self.pool1(self.acti(self.norm3(self.conv3(x))))

        for conv, norm in ((self.conv4, self.norm4), (self.conv5, self.norm5)):
            x = self.acti(norm(conv(x)))
        x = self.pool2(self.acti(self.norm6(self.conv6(x))))

        for conv, norm in ((self.conv7, self.norm7),
                           (self.conv8, self.norm8),
                           (self.conv9, self.norm9)):
            x = self.acti(norm(conv(x)))

        flat = x.view(-1, 10*15*12)
        return self.softmax(self.fc(flat))




In [None]:
class Alexnet(nn.Module):
    """AlexNet-style classifier ending in a softmax over ``K`` outputs.

    # Reference:
    - [ImageNet classification with deep convolutional neural networks]
        (https://doi.org/10.1145/3065386)

    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'. For any
    other value no activation is registered and ``forward`` will fail.
    """

    def __init__(self, acti = 'leaky'):
        super().__init__()

        # Factories are lazy so that only the selected activation is built.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        if acti in activations:
            self.acti = activations[acti]()

        # Feature extractor.
        self.conv1 = nn.Conv2d(3, 96, 11, 4)
        self.pool1 = nn.MaxPool2d(3, 2, 1)
        self.norm1 = nn.BatchNorm2d(96)

        self.conv2 = nn.Conv2d(96, 256, 5, 1, 2)
        self.pool2 = nn.MaxPool2d(3, 2, 1)
        self.norm2 = nn.BatchNorm2d(256)

        self.conv3 = nn.Conv2d(256, 384, 3, 1, 1)
        self.norm3 = nn.BatchNorm2d(384)
        self.conv4 = nn.Conv2d(384, 384, 3, 1, 1)
        self.norm4 = nn.BatchNorm2d(384)
        self.conv5 = nn.Conv2d(384, 384, 3, 1, 1)
        self.norm5 = nn.BatchNorm2d(384)
        self.pool3 = nn.MaxPool2d(3, 2, 1)

        # Classifier: two hidden layers with dropout, then K outputs.
        self.fc1 = nn.Linear(384*2*2,4096)
        self.drop1 = nn.Dropout(0.5)

        self.fc2 = nn.Linear(4096, 4096)
        self.drop2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(4096, K)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return softmax probabilities of shape (batch, K)."""
        feat = self.norm1(self.conv1(x))
        feat = self.pool1(self.acti(feat))

        feat = self.norm2(self.conv2(feat))
        feat = self.pool2(self.acti(feat))

        for conv, norm in ((self.conv3, self.norm3), (self.conv4, self.norm4)):
            feat = self.acti(norm(conv(feat)))

        feat = self.norm5(self.conv5(feat))
        feat = self.pool3(self.acti(feat))

        hidden = feat.view(-1,384*2*2)
        hidden = self.drop1(self.acti(self.fc1(hidden)))
        hidden = self.drop2(self.acti(self.fc2(hidden)))
        return self.softmax(self.fc3(hidden))


In [None]:
class VGG_16(nn.Module):
    """VGG-16 classifier with batch normalisation after every convolution.

    # Reference:
    - [Very Deep Convolutional Networks for Large-Scale Image Recognition]
        (https://arxiv.org/pdf/1409.1556)

    Thirteen 3x3 convolutions are grouped into five stages (2, 2, 3, 3, 3
    layers), each closed by a 2x2 max-pool, followed by three fully connected
    layers and a softmax over ``K`` outputs.

    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'. For any
    other value no activation is registered and ``forward`` will fail.
    """

    def __init__(self, K, acti = 'leaky'):
        super().__init__()

        # Factories are lazy so that only the selected activation is built.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        if acti in activations:
            self.acti = activations[acti]()

        # (in_channels, out_channels) of every convolution, grouped by stage.
        stage_channels = (
            ((3, 64), (64, 64)),
            ((64, 128), (128, 128)),
            ((128, 256), (256, 256), (256, 256)),
            ((256, 512), (512, 512), (512, 512)),
            ((512, 512), (512, 512), (512, 512)),
        )
        # Registers conv1/norm1 ... conv13/norm13 and pool1 ... pool5.
        layer_idx = 0
        for stage_idx, stage in enumerate(stage_channels, start=1):
            for in_ch, out_ch in stage:
                layer_idx += 1
                setattr(self, 'conv%d' % layer_idx, nn.Conv2d(in_ch, out_ch, 3, 1, 1))
                setattr(self, 'norm%d' % layer_idx, nn.BatchNorm2d(out_ch))
            setattr(self, 'pool%d' % stage_idx, nn.MaxPool2d(2, 2))

        self.fc1 = nn.Linear(512*1*1,4096)
        self.drop1 = nn.Dropout(0.5)

        self.fc2 = nn.Linear(4096, 4096)
        self.drop2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(4096, K)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Classify ``x``.

        Returns:
            ``(features, fc1_out, probs)``: the flattened 512-d conv features,
            the first fully connected layer's output and the softmax
            probabilities. NOTE: the built-in activations run in place, so
            ``fc1_out`` is returned after the (un-clipped) activation has been
            applied to it.
        """
        layer_idx = 0
        for stage_idx, depth in enumerate((2, 2, 3, 3, 3), start=1):
            for _ in range(depth):
                layer_idx += 1
                conv = getattr(self, 'conv%d' % layer_idx)
                norm = getattr(self, 'norm%d' % layer_idx)
                x = self.acti(norm(conv(x)))
            x = getattr(self, 'pool%d' % stage_idx)(x)

        features = x.view(-1,512*1*1)
        fc1_out = self.fc1(features)
        hidden = self.drop1(self.acti(fc1_out))
        hidden = self.drop2(self.acti(self.fc2(hidden)))
        probs = self.softmax(self.fc3(hidden))
        return features, fc1_out, probs

In [None]:
class VGG_16_nobn(nn.Module):
    """VGG-16 classifier without batch normalisation.

    # Reference:
    - [Very Deep Convolutional Networks for Large-Scale Image Recognition]
        (https://arxiv.org/pdf/1409.1556)

    Thirteen 3x3 convolutions are grouped into five stages (2, 2, 3, 3, 3
    layers), each closed by a 2x2 max-pool, followed by three fully connected
    layers and a softmax over ``K`` outputs.

    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'. For any
    other value no activation is registered and ``forward`` will fail.
    """

    def __init__(self, K, acti = 'leaky'):
        super().__init__()

        # Factories are lazy so that only the selected activation is built.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        if acti in activations:
            self.acti = activations[acti]()

        # (in_channels, out_channels) of every convolution, grouped by stage.
        stage_channels = (
            ((3, 64), (64, 64)),
            ((64, 128), (128, 128)),
            ((128, 256), (256, 256), (256, 256)),
            ((256, 512), (512, 512), (512, 512)),
            ((512, 512), (512, 512), (512, 512)),
        )
        # Registers conv1 ... conv13 and pool1 ... pool5.
        layer_idx = 0
        for stage_idx, stage in enumerate(stage_channels, start=1):
            for in_ch, out_ch in stage:
                layer_idx += 1
                setattr(self, 'conv%d' % layer_idx, nn.Conv2d(in_ch, out_ch, 3, 1, 1))
            setattr(self, 'pool%d' % stage_idx, nn.MaxPool2d(2, 2))

        self.fc1 = nn.Linear(512*1*1,4096)
        self.drop1 = nn.Dropout(0.5)

        self.fc2 = nn.Linear(4096, 4096)
        self.drop2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(4096, K)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Classify ``x``.

        Returns:
            ``(features, fc1_out, probs)``: the flattened 512-d conv features,
            the first fully connected layer's output and the softmax
            probabilities. NOTE: the built-in activations run in place, so
            ``fc1_out`` is returned after the (un-clipped) activation has been
            applied to it.
        """
        layer_idx = 0
        for stage_idx, depth in enumerate((2, 2, 3, 3, 3), start=1):
            for _ in range(depth):
                layer_idx += 1
                conv = getattr(self, 'conv%d' % layer_idx)
                x = self.acti(conv(x))
            x = getattr(self, 'pool%d' % stage_idx)(x)

        features = x.view(-1,512*1*1)
        fc1_out = self.fc1(features)
        hidden = self.drop1(self.acti(fc1_out))
        hidden = self.drop2(self.acti(self.fc2(hidden)))
        probs = self.softmax(self.fc3(hidden))
        return features, fc1_out, probs

In [None]:
class ResNet_50(nn.Module):
    """ResNet-50-style classifier ending in a softmax over ``K`` outputs.

    # Reference:
    - [Deep Residual Learning for Image Recognition]
        (https://arxiv.org/abs/1512.03385)

    A 7x7 stem convolution and max-pool are followed by four residual stages
    (a-d). Each stage is one ``conv_block`` followed by 2, 3, 5 and 2
    ``identity_block`` modules respectively. Average pooling, a linear layer
    and a softmax form the head.

    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'. For any
    other value no activation is registered and ``forward`` will fail.
    """

    def __init__(self, acti = 'leaky'):
        super().__init__()

        # Factories are lazy so that only the selected activation is built.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        if acti in activations:
            self.acti = activations[acti]()

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3)
        self.norm1 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(3, 2, 1)

        # (tag, in_channels, bottleneck filters, number of identity blocks, stride)
        stages = (
            ('a', 64, (64, 64, 256), 2, (1, 1)),
            ('b', 256, (128, 128, 512), 3, (2, 2)),
            ('c', 512, (256, 256, 1024), 5, (2, 2)),
            ('d', 1024, (512, 512, 2048), 2, (2, 2)),
        )
        # Registers conv_<tag>1 followed by iden_<tag>1 ... iden_<tag>n.
        for tag, in_ch, (f1, f2, f3), n_iden, stride in stages:
            setattr(self, 'conv_%s1' % tag, conv_block(in_ch, acti, f1, f2, f3, stride))
            for i in range(1, n_iden + 1):
                setattr(self, 'iden_%s%d' % (tag, i), identity_block(f3, acti, f1, f2, f3))

        #Pooling kernel reduced due to the dimension of our dataset - [*, *, 2, 2]
        self.avgpool = nn.AvgPool2d(2, 1)
        self.fc = nn.Linear(2048,K)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return softmax probabilities of shape (batch, K)."""
        x = self.norm1(self.conv1(x))
        x = self.pool(self.acti(x))

        for tag, n_iden in (('a', 2), ('b', 3), ('c', 5), ('d', 2)):
            x = getattr(self, 'conv_%s1' % tag)(x)
            for i in range(1, n_iden + 1):
                x = getattr(self, 'iden_%s%d' % (tag, i))(x)

        pooled = self.avgpool(x)
        flat = pooled.view(-1, 2048)
        return self.softmax(self.fc(flat))

In [None]:
class GoogleNet(nn.Module):

    """
    GoogLeNet-style classifier built from ``inception`` blocks.

    # Reference:
    - [Going Deeper with Convolutions]
        (https://arxiv.org/pdf/1409.4842)

    The number of output classes is read from the module-level name ``K``.
    ``acti`` selects the activation: 'relu', 'leaky' or 'clipped'; any other
    value leaves ``self.acti`` unset.
    """

    def __init__(self, acti = 'leaky'):
        super(GoogleNet, self).__init__()

        # Lazily-built activation table so only the chosen module is created.
        activations = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.3, inplace=True),
            'clipped': lambda: ClippedReLU(),
        }
        make_acti = activations.get(acti)
        if make_acti is not None:
            self.acti = make_acti()

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.norm1 = nn.BatchNorm2d(64)

        self.conv2 = nn.Conv2d(64, 64, kernel_size=1, stride=1)
        self.norm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1)
        self.norm3 = nn.BatchNorm2d(192)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Inception stage 3.
        self.incep3a = inception(192, acti, 64, 96, 128, 16, 32, 32)
        self.incep3b = inception(256, acti, 128, 128, 192, 32, 96, 64)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Inception stage 4.
        self.incep4a = inception(480, acti, 192, 96, 208, 16, 48 ,64)
        self.incep4b = inception(512, acti, 160, 112, 224, 24, 64, 64)
        self.incep4c = inception(512, acti, 128, 128, 256, 24, 64, 64)
        self.incep4d = inception(512, acti, 112, 144, 288, 32, 64, 64)
        self.incep4e = inception(528, acti, 256, 160, 320, 32, 128, 128)
        self.pool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Inception stage 5.
        self.incep5a = inception(832, acti, 256, 160, 320, 32, 128, 128)
        self.incep5b = inception(832, acti, 384, 192, 384, 48, 128, 128)

        # Auxiliary head on incep4a.
        # Pooling kernel reduced due to the dimension of our dataset - [*, *, 4, 3]
        self.avgpool4a = nn.AvgPool2d(kernel_size=3, stride=3)
        self.conv4a = nn.Conv2d(512, 128, kernel_size=1, stride=1)
        self.fc4a = nn.Linear(128, 1024)
        self.drop4a = nn.Dropout(p=0.7)
        self.fc4a_1 = nn.Linear(1024, K)

        # Auxiliary head on incep4d.
        # Pooling kernel reduced due to the dimension of our dataset - [*, *, 4, 3]
        self.avgpool4d = nn.AvgPool2d(kernel_size=3, stride=3)
        self.conv4d = nn.Conv2d(528, 128, kernel_size=1, stride=1)
        self.fc4d = nn.Linear(128, 1024)
        self.drop4d = nn.Dropout(p=0.7)
        self.fc4d_1 = nn.Linear(1024, K)

        # Main head on incep5b.
        # Pooling kernel reduced due to the dimension of our dataset - [*, *, 2, 2]
        self.avgpool5b = nn.AvgPool2d(kernel_size=2, stride=1)
        self.drop5b = nn.Dropout(p=0.4)
        self.fc5b = nn.Linear(1024, K)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return the softmax output of the main (incep5b) head.

        Both auxiliary heads are evaluated as well but their outputs are
        discarded; only the main head's result is returned.
        """
        # Stem.
        stem = self.pool1(self.acti(self.norm1(self.conv1(x))))
        stem = self.acti(self.norm2(self.conv2(stem)))
        stem = self.pool2(self.acti(self.norm3(self.conv3(stem))))

        # Inception trunk.
        feat3 = self.pool3(self.incep3b(self.incep3a(stem)))
        feat4a = self.incep4a(feat3)
        feat4d = self.incep4d(self.incep4c(self.incep4b(feat4a)))
        feat5 = self.incep5b(self.incep5a(self.pool4(self.incep4e(feat4d))))

        # Auxiliary head on incep4a (result unused).
        aux_a = self.acti(self.conv4a(self.avgpool4a(feat4a)))
        aux_a = self.acti(self.fc4a(aux_a.view(-1, 128)))
        aux_a = self.softmax(self.fc4a_1(self.drop4a(aux_a)))

        # Auxiliary head on incep4d (result unused).
        aux_d = self.acti(self.conv4d(self.avgpool4d(feat4d)))
        aux_d = self.acti(self.fc4d(aux_d.view(-1, 128)))
        aux_d = self.softmax(self.fc4d_1(self.drop4d(aux_d)))

        # Main head.
        main = self.avgpool5b(feat5).view(-1, 1024)
        main = self.fc5b(self.drop5b(main))
        return self.softmax(main)

In [None]:
def disc_sel(option, K, acti = 'leaky'):
    """Build the discriminator named by ``option`` and move it to ``device``.

    Args:
        option: one of 'catgan', 'alex', 'vgg', 'vggbn', 'resnet', 'google'
            or 'prune'.
        K: number of output classes. Only the two VGG constructors receive
            it; the other constructors are called with ``acti`` alone.
        acti: activation name forwarded to the network constructor.

    Returns:
        The constructed network, already moved to ``device``.

    Raises:
        ValueError: if ``option`` is not one of the names above. (Previously
            this printed a message and returned None, which only failed
            later when the caller used the result.)
    """
    if option == 'catgan':
        net = Discriminator(acti)
    elif option == 'alex':
        net = Alexnet(acti)
    elif option == 'vgg':
        net = VGG_16(K, acti)
    elif option == 'vggbn':
        net = VGG_16_nobn(K, acti)
    elif option == 'resnet':
        net = ResNet_50(acti)
    elif option == 'google':
        net = GoogleNet(acti)
    elif option == 'prune':
        net = prune_disc(acti)
    else:
        raise ValueError('No discriminator option named {!r}'.format(option))
    return net.to(device)

In [None]:
def gen_sel(option, acti = 'leaky'):
    """Build the generator named by ``option`` and move it to ``device``.

    Args:
        option: 'catgan', 'vae' or 'style'.
        acti: activation name forwarded to the generator constructor.
            The 'style' branch previously ignored it and always used
            'leaky'; it is now passed through.

    Returns:
        The constructed generator, already moved to ``device``.

    Raises:
        ValueError: if ``option`` is not recognised. (Previously this
            printed a message and returned None.)
    """
    if option == 'catgan':
        net = Generator(acti)
    elif option == 'vae':
        net = VAE(acti)
    elif option == 'style':
        net = StyleGAN(image_size = 60, style_depth = 8, latent_dim = 512,
                       network_capacity = 16, transparent = False,
                       attn_layers = [], no_const = False, fmap_max = 512,
                       acti = acti)
    else:
        raise ValueError('No generator option named {!r}'.format(option))
    return net.to(device)

As explained in the CatGAN paper, the semi-supervised case requires three different entropy functions:
- Marginal Entropy
- Conditional Entropy
- Cross Entropy

In [None]:
class MarginalHLoss(nn.Module):
  """Entropy of the class distribution averaged over the first axis."""

  def __init__(self):
    super(MarginalHLoss, self).__init__()

  def forward(self, x):  # x: NxK
    # Average over axis 0, then take -sum(p * log p); 1e-6 guards log(0).
    marginal = x.mean(axis=0)
    log_marginal = torch.log(marginal + 1e-6)
    return -torch.sum(marginal * log_marginal)

class JointHLoss(nn.Module):
  """Mean per-sample entropy of the rows of ``x``.

  Computes ``sum(-x * log(x + 1e-6))`` over all entries and divides by the
  number of rows (``x.shape[0]``).

  The divisor used to be the global ``batch_size``, which over-normalised
  any batch shorter than that (e.g. the last batch of an epoch) and made
  the loss depend on a module-level variable.
  """

  def __init__(self):
    super(JointHLoss, self).__init__()

  def forward(self, x):  # x: NxK
    # 1e-6 guards against log(0) for hard 0/1 rows.
    entropy_terms = -x * torch.log(x + 1e-6)
    return (1.0 / x.shape[0]) * torch.sum(entropy_terms)

class CrossHLoss(nn.Module):
  """Summed cross entropy ``-sum(x * log(y + 1e-6))``.

  The result is a plain sum over every entry; it is not divided by the
  batch size.
  """

  def __init__(self):
    super(CrossHLoss, self).__init__()

  def forward(self, x, y):
    # x weights the log of y; 1e-6 guards against log(0).
    log_y = torch.log(y + 1e-6)
    weighted = x * log_y
    return -torch.sum(weighted)

class KLLoss(nn.Module):
  """KL-style regulariser ``-0.5 * sum(1 + y - x**2 - exp(y))`` per sample.

  ``genn`` calls this as ``KL(mu, logvar)``, so ``x`` is the mean and ``y``
  the log-variance there. The sum over all entries is divided by the number
  of rows (``x.shape[0]``).

  The divisor used to be the global ``batch_size``, which mis-scaled the
  loss for any batch shorter than that (e.g. the last batch of an epoch).
  """

  def __init__(self):
    super(KLLoss, self).__init__()

  def forward(self, x, y):
    terms = 1 + y - x**2 - y.exp()
    total = -0.5 * torch.sum(terms)
    return (1.0 / x.shape[0]) * total


class EMA():
    """Exponential moving average helper with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``new`` into ``old``; with no previous value return ``new``."""
        if old is not None:
            decay = self.beta
            return old * decay + (1 - decay) * new
        return new

class PL_length(nn.Module):
  """Path-length estimate: gradient norm of noise-projected images w.r.t. styles."""

  def __init__(self):
    super(PL_length, self).__init__()

  def forward(self, styles, images):
    # Scale the random projection by sqrt(H * W) of the image tensor.
    height, width = images.shape[2], images.shape[3]
    pl_noise = torch.randn(images.shape).to(device=device) / math.sqrt(height * width)
    projected = torch.sum(images * pl_noise)

    # create_graph=True keeps the result differentiable for the penalty term.
    seed = torch.ones(projected.shape).to(device=device)
    grads = torch_grad(outputs=projected, inputs=styles,
                       grad_outputs=seed,
                       create_graph=True, retain_graph=True, only_inputs=True)
    pl_grads = grads[0]

    squared_norm = (pl_grads ** 2).sum()
    return squared_norm.mean().sqrt()

class PathPenalty(nn.Module):
  """Squared deviation of path lengths from their running mean."""

  def __init__(self):
    super(PathPenalty, self).__init__()

  def forward(self, pl_lengths, plmean):
    # No running mean yet -> no penalty; callers check for None.
    if is_empty(plmean):
      return None
    deviation = pl_lengths - plmean
    return (deviation ** 2).mean()

class GenEMA(nn.Module):
    """Keeps an averaged copy of a generator's weights (decay 0.995)."""

    def __init__(self):
        super().__init__()
        self.ema_updater = EMA(0.995)

    def EMA(self, GE, G_net):
        """Blend the parameters of ``G_net`` into ``GE`` in place."""
        blend = self.ema_updater.update_average
        for live_param, avg_param in zip(G_net.parameters(), GE.parameters()):
            avg_param.data = blend(avg_param.data, live_param.data)

### Optimisers

Three different optimisers are used for this project
- Adam
- RMSProp
- SGD Momentum

This section also includes the code for saving checkpoints.

In [None]:
def opti_sel(model, option, lr):
  """Create an optimiser over ``model.parameters()``.

  Args:
    model: module whose parameters are optimised.
    option: 'adam', 'rmsprop' or 'sgdm' (SGD with momentum 0.9).
    lr: learning rate.

  Returns:
    The configured ``torch.optim`` optimiser.

  Raises:
    ValueError: if ``option`` is not recognised. (Previously the function
      fell off the end and returned None.)
  """
  if option == 'adam':
    return optim.Adam(model.parameters(), lr = lr, eps = 1e-7)
  if option == 'rmsprop':
    return optim.RMSprop(model.parameters(), lr = lr, eps = 1e-7, alpha = 0.9)
  if option == 'sgdm':
    return optim.SGD(model.parameters(), lr = lr, momentum = 0.9)
  raise ValueError('No optimiser option named {!r}'.format(option))

In [None]:
def generate_and_save_images(epoch, model):
    """Sample four images from ``model``, append their grid to ``img_list`` and show it.

    ``epoch`` is currently unused (the save-to-file call is disabled).
    Reads the globals ``nz``, ``device`` and ``img_list``.
    """
    # Four latent vectors of size nz, drawn from a standard normal.
    latent = torch.randn(4, nz).to(device=device)
    samples = model(latent).float().detach().cpu()

    grid = vutils.make_grid(samples, padding=2, normalize=True)
    img_list.append(grid)

    plt.figure(figsize=(8,8))
    plt.axis("off")
    # Every stored grid is drawn again (CHW -> HWC); the last one ends up on top.
    for stored in img_list:
        plt.imshow(np.transpose(stored, (1, 2, 0)))
    # plt.savefig(os.path.join(checkpoints_dir, 'images/generator_epoch_{:04d}.png'.format(epoch)))
    plt.show()

## Training

Trains the network with the presets shown above.

During training, it prints the network losses and the outputs from the generator every 20 steps (and at the first step of each epoch).

After training, it shows the graph of the generator and discriminator losses.

In [None]:
# Lists to keep track of progress (module-level; pltfig/giffig read these).
img_list = []
G_losses = []
D_losses = []
iters = 0

# D_net = torchvision.models.vgg16(num_classes=6).to(device)
# D_optimizer = opti_sel(D_net, option = 'adam')

# Loss modules and helpers shared as globals by the training functions.
jointH = JointHLoss()
marginalH = MarginalHLoss()
crossH = CrossHLoss()
KL = KLLoss()
crossPY = nn.CrossEntropyLoss()
bcnloss = nn.BCELoss()
sigmoid = nn.Sigmoid()
pathlength = PathPenalty()
pll = PL_length()
# NOTE(review): this rebinds the name GenEMA from the class to an instance.
# Re-running this cell without re-running the class definition calls the
# existing instance instead of constructing a new one.
GenEMA = GenEMA()
def noise(n, latent_dim):
    """Return an ``(n, latent_dim)`` standard-normal tensor on ``device``."""
    sample = torch.randn(n, latent_dim)
    return sample.to(device=device)

def noise_list(n, layers, latent_dim):
    """Return a one-element list holding ``(latent, layers)``."""
    latent = noise(n, latent_dim)
    return [(latent, layers)]

def mixed_list(n, layers, latent_dim):
    """Return two ``(latent, layer_count)`` pairs whose counts sum to ``layers``.

    The split point is ``int(u * layers)`` for a uniform random ``u``.
    """
    split = int(torch.rand(()).numpy() * layers)
    head = noise_list(n, split, latent_dim)
    tail = noise_list(n, layers - split, latent_dim)
    return head + tail
def is_empty(t):
    """Return True for None or a tensor with zero elements."""
    if not isinstance(t, torch.Tensor):
        return t is None
    return t.nelement() == 0



In [None]:
def GAN(G_net, D_net, D_optimizer, G_optimizer, img_list, G_losses, D_losses, fixed_noise, train_loader, test_loader, data):
    """Train ``G_net`` against ``D_net`` with the entropy-based losses and save both.

    Each batch does one discriminator step followed by one generator step.
    ``D_net`` is expected to return an indexable result whose element ``[2]``
    is fed to the entropy losses (presumably the class probabilities - TODO
    confirm against the discriminator in use).

    Args:
        G_net, D_net: generator and discriminator modules.
        D_optimizer, G_optimizer: their optimisers.
        img_list, fixed_noise, test_loader: accepted but not used by the
            active code (only by commented-out lines).
        G_losses, D_losses: lists that receive the losses every 20 steps
            and at the first step of each epoch.
        train_loader: yields ``(data, labels)`` batches.
        data: dataset name, 'KDEF' (6 classes) or 'MMI' (5 classes).

    Reads the globals ``num_epochs``, ``device``, ``nz``, ``checkpoints_dir``
    and the loss modules ``crossH``, ``jointH``, ``marginalH``.
    """
    print("Starting Training Loop...")
    # Keep the dataset name: `data` is reused as the batch variable below.
    dat_name = data
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, (data, labels) in enumerate(train_loader) :

            b_size=data.shape[0] 

            data = data.to(device=device)
            # NOTE(review): any other dataset name leaves `label` unbound.
            if dat_name == 'KDEF':
              label = F.one_hot(labels, num_classes=6).to(device=device)
            elif dat_name == 'MMI':
              label = F.one_hot(labels, num_classes=5).to(device=device)
            #Train D

            D_net.zero_grad()
            y_real = D_net(data)
            
            cross_entropy = crossH(label, y_real[2])
            
            joint_entropy_real = jointH(y_real[2])#minimize uncertainty

            marginal_entropy_real = marginalH(y_real[2])#maximize uncertainty

            z = torch.rand(b_size, nz).to(device=device) #uniform distribution sampling
            # z = torch.rand(b_size,nz,1,1).to(device=device)#uniform distribution sampling
            fake_images = G_net(z)

            # detach(): the discriminator step must not backpropagate into G_net.
            y_fake = D_net(fake_images.detach())
            
            joint_entropy_fake = jointH(y_fake[2])#maximize uncertainty

            # Minimised quantity: CE + H_real(cond) - H_real(marginal) - H_fake(cond).
            loss_D = - (- cross_entropy  + marginal_entropy_real+ joint_entropy_fake -joint_entropy_real)
           
            loss_D.backward(retain_graph=True)
            D_optimizer.step()

            #Train G
            del y_fake, data
            G_net.zero_grad()
            # Re-evaluate on the non-detached images so gradients reach G_net.
            y_fake = D_net(fake_images)
            marginal_entropy_fake = marginalH(y_fake[2])#maximize uncertainty
            joint_entropy_fake = jointH(y_fake[2])#maximize uncertainty
  
            loss_G = joint_entropy_fake - marginal_entropy_fake

            loss_G.backward(retain_graph=True)
            G_optimizer.step()


            # Log at the first step and then every 20 steps.
            if (i+1)%20 == 0 or i ==0:

              print('Epoch [{}/{}], Step [{}/{}] \n D_loss: {:.4f}, H_x[p(y|D)] : {:.4f}, E[H[p(y|x,D)]] : {:.4f}, \
                      E[H[p(y|G(z),D)]]:{:.4f}\n G_loss: {:.4f}, H_G[p(y|D)] : {:.4f}, E[H[p(y|G(z),D)]]: {:.4f}, E[CE[y,p(y|x,D)]: {:.4f}]' 
                    .format(epoch, num_epochs, i+1, len(train_loader), loss_D.item(), marginal_entropy_real.item(), joint_entropy_real.item()\
                            ,joint_entropy_fake.item(),loss_G.item(),marginal_entropy_fake.item(),joint_entropy_fake.item(), cross_entropy.item()))
              G_losses.append(loss_G.item())
              D_losses.append(loss_D.item())
              # print(y_real[0], y_fake[0], label[0])#, loss_D.item(), loss_G.item())
              # with torch.no_grad():

              #     fake = G_net(fixed_noise).float().detach().cpu()
              #     img_list.append(vutils.make_grid(torch.reshape(fake,(b_size,3,224,224))[:64], padding=2, normalize=True))
              #     transform_PIL=transforms.ToPILImage()
              #     transform_PIL(img_list[-1]).save(str(epoch)+"CATGAN_MNIST_Last.png")
              #     del fake

            # Drop per-batch tensors before the next iteration.
            del y_real, y_fake, cross_entropy, joint_entropy_fake, joint_entropy_real, marginal_entropy_real
            del label, marginal_entropy_fake, fake_images, loss_G, loss_D
        # (disabled) per-epoch sample generation:


        # if epoch % 1 == 0:
        #     generate_and_save_images(epoch, G_net)


        # Checkpoints are written once, after the last epoch.
        # if epoch == (num_epochs - 1):
    torch.save(G_net.state_dict(), os.path.join(checkpoints_dir, 'G_net_{}_GAN.pth'.format(dat_name)))
    torch.save(D_net.state_dict(), os.path.join(checkpoints_dir, 'D_net_{}_GAN.pth'.format(dat_name)))

In [None]:
def VAEGAN(G_net, D_net, D_optimizer, G_optimizer, img_list, G_losses, D_losses, fixed_noise, train_loader, test_loader):
    """Adversarially train ``G_net.decoder`` against ``D_net``.

    Each batch does one discriminator step and one generator step with the
    entropy-based losses. Here the output of ``D_net`` is used directly as
    the tensor passed to the losses, and labels are one-hot encoded with
    6 classes.

    Args:
        G_net: module exposing a ``decoder`` that maps latent vectors to images.
        D_net: discriminator.
        D_optimizer, G_optimizer: their optimisers.
        img_list: receives an image grid at every logging step.
        G_losses, D_losses: receive the losses at every logging step.
        fixed_noise: latent batch decoded for the logged image grid.
        train_loader: yields ``(data, labels)`` batches.
        test_loader: passed to ``test`` after every epoch.

    Reads the globals ``num_epochs``, ``device``, ``nz`` and the loss modules.
    """
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # print(torch.cuda.memory_summary())
        # For each batch in the dataloader
        for i, (data, labels) in enumerate(train_loader) :

            b_size=data.shape[0] 

            data = data.to(device=device)#.half()
            label = F.one_hot(labels, num_classes=6).to(device=device)
            #Train D

            D_net.zero_grad()
            y_real = D_net(data)
            
            cross_entropy = crossH(label, y_real)
            joint_entropy_real = jointH(y_real)#minimize uncertainty
            marginal_entropy_real = marginalH(y_real)#maximize uncertainty

            # z = torch.rand(b_size, 3, 60, 48).to(device=device)#.half() #uniform distribution sampling
            z = torch.randn(b_size,nz).to(device=device)# standard normal sampling (randn)
            fake_images = G_net.decoder(z)

            # detach(): the discriminator step must not backpropagate into G_net.
            y_fake = D_net(fake_images.detach())
            joint_entropy_fake = jointH(y_fake)#maximize uncertainty

            # Minimised quantity: CE + H_real(cond) - H_real(marginal) - H_fake(cond).
            loss_D = - (- cross_entropy  + marginal_entropy_real+ joint_entropy_fake -joint_entropy_real)
           
            loss_D.backward(retain_graph=True)
            D_optimizer.step()

            #Train G
            del y_fake, data
            G_net.zero_grad()
            # Re-evaluate on the non-detached images so gradients reach G_net.
            y_fake = D_net(fake_images)

            marginal_entropy_fake = marginalH(y_fake)#maximize uncertainty

            joint_entropy_fake = jointH(y_fake)#maximize uncertainty

            # reparamatize_loss = KL(m, logvar)
            loss_G = joint_entropy_fake - marginal_entropy_fake #+ reparamatize_loss

            loss_G.backward(retain_graph=True)
            G_optimizer.step()


            # Log at the first step and then every 20 steps.
            if (i+1)%20 == 0 or i ==0:
              # with open(os.path.join(checkpoints_dir, 'results.txt'), 'a+') as f:
              #   print('Epoch [{}/{}], Step [{}/{}] \n D_loss: {:.4f}, H_x[p(y|D)] : {:.4f}, E[H[p(y|x,D)]] : {:.4f}, \n \
              #           E[H[p(y|G(z),D)]]:{:.4f}\n G_loss: {:.4f}, H_G[p(y|D)] : {:.4f}, E[H[p(y|G(z),D)]]: {:.4f}, \n \
              #           E[CE[y,p(y|x,D)]: {:.4f}'#, KL: {:.4f}' 
              #         .format(epoch, num_epochs, i+1, len(train_loader), loss_D.item(), marginal_entropy_real.item(), joint_entropy_real.item()\
              #                 ,joint_entropy_fake.item(),loss_G.item(),marginal_entropy_fake.item(), joint_entropy_fake.item(), cross_entropy.item()), file = f)#\
              #                 # ,reparamatize_loss.item()))
              #   print(y_real[0], y_fake[0], label[0], file = f)#, loss_D.item(), loss_G.item())
              print('Epoch [{}/{}], Step [{}/{}] \n D_loss: {:.4f}, H_x[p(y|D)] : {:.4f}, E[H[p(y|x,D)]] : {:.4f}, \n \
                      E[H[p(y|G(z),D)]]:{:.4f}\n G_loss: {:.4f}, H_G[p(y|D)] : {:.4f}, E[H[p(y|G(z),D)]]: {:.4f}, \n \
                      E[CE[y,p(y|x,D)]: {:.4f}'#, KL: {:.4f}' 
                    .format(epoch, num_epochs, i+1, len(train_loader), loss_D.item(), marginal_entropy_real.item(), joint_entropy_real.item()\
                            ,joint_entropy_fake.item(),loss_G.item(),marginal_entropy_fake.item(), joint_entropy_fake.item(), cross_entropy.item()))#\
              G_losses.append(loss_G.item())
              D_losses.append(loss_D.item())

              with torch.no_grad():
                  fake = G_net.decoder(fixed_noise).detach().cpu()
                  # NOTE(review): the reshape uses the current batch size, so it
                  # presumably requires fixed_noise to hold exactly b_size
                  # latents - confirm for the last (possibly shorter) batch.
                  img_list.append(vutils.make_grid(torch.reshape(fake,(b_size,3,60,48))[:64], padding=2, normalize=True))
                  transform_PIL=transforms.ToPILImage()
                  # The file name depends only on the epoch, so later logging
                  # steps of the same epoch overwrite it.
                  transform_PIL(img_list[-1]).save(str(epoch)+"CATGAN_MNIST_Last.png")
                  del fake

            # del y_real, y_fake, cross_entropy, joint_entropy_fake, joint_entropy_real, marginal_entropy_real
            # del label, marginal_entropy_fake, fake_images, loss_G, loss_D

        # Evaluate after every epoch (epoch % 1 is always 0).
        # NOTE(review): the `test` function defined in this file takes a
        # required third argument `data`; this two-argument call looks like it
        # would raise TypeError - confirm which `test` is in scope.
        if epoch % 1 == 0:
            test(D_net, test_loader)
            
        # if epoch % 10  == 0:
        #     generate_and_save_images(epoch, G_net)


        # (disabled) checkpointing at the last epoch:
        # if epoch == (num_epochs - 1):
        #     torch.save(G_net.state_dict(), '%s/netG_epoch_%d.pth' % (os.path.join(checkpoints_dir), 40))
        #     torch.save(D_net.state_dict(), '%s/netD_epoch_%d.pth' % (os.path.join(checkpoints_dir), 40))

In [None]:
from random import random
def SGAN(G_net, D_net, GE, D_optimizer, G_optimizer, img_list, G_losses, D_losses, fixed_noise, train_loader, test_loader):
    """Train the style-based generator against ``D_net`` with a path-length penalty.

    Each batch does one discriminator step and one generator step with the
    entropy-based losses; once a running mean of the path length exists, the
    squared deviation from it is added to the generator loss.

    Args:
        G_net: generator called as ``G_net(style, noise)`` and returning
            ``(styles, images)``.
        D_net: discriminator whose output is passed directly to the losses.
        GE: averaged copy of the generator, updated through the global
            ``GenEMA`` instance.
        D_optimizer, G_optimizer: the optimisers.
        img_list: receives an image grid at every logging step.
        G_losses, D_losses: receive the losses at every logging step.
        fixed_noise: noise input used for the logged image grid.
        train_loader: yields ``(data, labels)`` batches.
        test_loader: passed to ``test`` after every epoch.

    Reads the globals ``num_epochs``, ``device``, ``nz``, ``pll``,
    ``pathlength``, ``GenEMA`` and the loss modules.
    """
    print("Starting Training Loop...")
    plmean = None
    pl_length_ma = EMA(0.99)
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, (data, labels) in enumerate(train_loader) :

            b_size=data.shape[0]      
            data = data.to(device=device)
            label = F.one_hot(labels, num_classes=6).to(device=device)
            #Train D

            D_net.zero_grad()
            y_real = D_net(data)
            
            cross_entropy = crossH(label, y_real)
            joint_entropy_real = jointH(y_real)#minimize uncertainty
            marginal_entropy_real = marginalH(y_real)#maximize uncertainty


            # With probability 0.9 use two mixed latents, otherwise a single one.
            get_latents_fn = mixed_list if random() < 0.9 else noise_list
            style = get_latents_fn(b_size, int(log2(60) - 1), 512)

            z = torch.rand(b_size, nz, nz, 1).to(device=device)
            styles, fake_images = G_net(style, z)

            # detach(): the discriminator step must not backpropagate into G_net.
            y_fake = D_net(fake_images.detach())
            joint_entropy_fake = jointH(y_fake)#maximize uncertainty

            # Minimised quantity: CE + H_real(cond) - H_real(marginal) - H_fake(cond).
            loss_D = - (- cross_entropy  + marginal_entropy_real+ joint_entropy_fake -joint_entropy_real)
           
            loss_D.backward(retain_graph=True)
            D_optimizer.step()

            #Train G


            del y_fake, data
            G_net.zero_grad()
            # Re-evaluate on the non-detached images so gradients reach G_net.
            y_fake = D_net(fake_images)
            marginal_entropy_fake = marginalH(y_fake)#maximize uncertainty
            joint_entropy_fake = jointH(y_fake)#maximize uncertainty
            pl_lengths = pll(styles, fake_images)
            avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
            # None until a running mean (plmean) is available.
            path_length = pathlength(pl_lengths, plmean)
            
            if path_length == None:
              loss_G = joint_entropy_fake - marginal_entropy_fake
            else:
              loss_G = joint_entropy_fake - marginal_entropy_fake + path_length

            loss_G.backward(retain_graph=True)
            G_optimizer.step()

            # Update the running mean of the path length, skipping NaN values.
            if not np.isnan(avg_pl_length):
                plmean = pl_length_ma.update_average(plmean, avg_pl_length)

            # The averaged generator is only updated during every 10th epoch,
            # from the 7th batch onwards.
            if epoch %10 == 0 and i > 5:
              
              GenEMA.EMA(GE, G_net)


            # Log at the first step and then every 20 steps.
            if (i+1)%20 == 0 or i ==0:

              print('Epoch [{}/{}], Step [{}/{}] \n D_loss: {:.4f}, H_x[p(y|D)] : {:.4f}, E[H[p(y|x,D)]] : {:.4f}, \
                      E[H[p(y|G(z),D)]]:{:.4f}\n G_loss: {:.4f}, H_G[p(y|D)] : {:.4f}, E[H[p(y|G(z),D)]]: {:.4f}, E[CE[y,p(y|x,D)]: {:.4f}' 
                    .format(epoch, num_epochs, i+1, len(train_loader), loss_D.item(), marginal_entropy_real.item(), joint_entropy_real.item()\
                            ,joint_entropy_fake.item(),loss_G.item(),marginal_entropy_fake.item(),joint_entropy_fake.item(), cross_entropy.item()))
              G_losses.append(loss_G.item())
              D_losses.append(loss_D.item())
              if path_length == None:
                pass
              else:
                print('Path Length: {:.4f}'.format(path_length.item()))
              print(y_real[0], y_fake[0], label[0])#, loss_D.item(), loss_G.item())
              with torch.no_grad():
                  _, fake = G_net(style, fixed_noise)
                  fake = fake.detach().cpu()
                  img_list.append(vutils.make_grid(torch.reshape(fake,(b_size,3,60,48))[:64], padding=2, normalize=True))
                  transform_PIL=transforms.ToPILImage()
                  # The file name depends only on the epoch, so later logging
                  # steps of the same epoch overwrite it.
                  transform_PIL(img_list[-1]).save(str(epoch)+"CATGAN_MNIST_Last.png")
                  del fake

            # Drop per-batch tensors before the next iteration.
            del y_real, y_fake, cross_entropy, joint_entropy_fake, joint_entropy_real, marginal_entropy_real
            del label, marginal_entropy_fake, fake_images, loss_G, loss_D


        # Evaluate after every epoch (epoch % 1 is always 0).
        # NOTE(review): the `test` function defined in this file takes a
        # required third argument `data`; this two-argument call looks like it
        # would raise TypeError - confirm which `test` is in scope.
        if epoch % 1 == 0:
            test(D_net, test_loader)

        # if epoch % 10 == 0:
        #     generate_and_save_images(epoch, G_net)
        # (disabled) checkpointing at the last epoch:
        # if epoch == (num_epochs - 1):
        #     torch.save(G_net.state_dict(), '%s/netG_epoch_%d.pth' % (os.path.join(checkpoints_dir), 40))
        #     torch.save(D_net.state_dict(), '%s/netD_epoch_%d.pth' % (os.path.join(checkpoints_dir), 40))

In [None]:
def genn(G_net, G_optimizer, img_list, G_losses, D_losses, train_loader):
    """Pre-train ``G_net`` on its own with a BCE reconstruction term plus the KL term.

    Args:
        G_net: module called as ``G_net(data, b_size)`` and returning
            ``(mu, logvar, output)``.
        G_optimizer: optimiser for ``G_net``.
        img_list: accepted but not used.
        G_losses: receives the total loss at every logging step.
        D_losses: receives a 0 placeholder at every logging step, keeping it
            the same length as ``G_losses``.
        train_loader: yields ``(data, labels)`` batches.

    Reads the globals ``num_epochs``, ``device``, ``KL``, ``sigmoid`` and
    ``bcnloss``.
    """
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, (data, labels) in enumerate(train_loader) :

            b_size = data.shape[0] 

            data = data.to(device=device)
            #label = labels.to(device=device)
            # `label` is computed but not used by this function.
            label = F.one_hot(labels, num_classes=6).to(device=device)
            #Train G (this function has no discriminator)

            G_net.zero_grad()

            mu, logvar, output = G_net(data, b_size)

            reparam = KL(mu, logvar)

            # The reconstruction target is the sigmoid of the input batch.
            target = sigmoid(data)
            BCE = bcnloss(output, target)

            loss_G = BCE + reparam 

            loss_G.backward(retain_graph=True)
            
            G_optimizer.step()      

            # Log at the first step and then every 20 steps.
            if (i+1)%20 == 0 or i ==0:
              # with open(os.path.join(checkpoints_dir, 'results.txt'), 'a+') as f:
              #   print('Epoch [{}/{}], Step [{}/{}] \n D_loss: {:.4f}' 
              #         .format(epoch, num_epochs, i+1, len(train_loader), loss_G.item()), file = f)
              #   print(BCE.item(), reparam.item(), file = f)
              # G_losses.append(loss_G.item())
              # G_losses.append(loss_G.item())
              # D_losses.append(0)
              
              # NOTE(review): the message is labelled 'D_loss' but the value
              # printed is loss_G.
              print('Epoch [{}/{}], Step [{}/{}] \n D_loss: {:.4f}' 
                .format(epoch, num_epochs, i+1, len(train_loader), loss_G.item()))
              G_losses.append(loss_G.item())
              D_losses.append(0)


        # if (epoch)% 10 == 0 and epoch != 0:
        #   if (epoch) % 30  == 0:
        #     err = pruning(D_net, remove_layer= True)
        #     if err == -1:
        #       break
        #   else:          
        #     pruning(D_net)
        # (disabled) checkpointing at the last epoch:
        # if epoch == (num_epochs - 1):
        #     torch.save(G_net.state_dict(), '%s/netG_epoch_%d.pth' % (os.path.join(checkpoints_dir), 40))
        #     torch.save(D_net.state_dict(), '%s/netD_epoch_%d.pth' % (os.path.join(checkpoints_dir), 40))

In [None]:
def orig(D_net, D_optimizer, img_list, D_losses, train_loader, test_loader, data):
    """Train ``D_net`` alone as a supervised classifier and save its weights.

    The loss is the summed cross entropy (``crossH``) between the one-hot
    labels and the third value returned by ``D_net``.

    Args:
        D_net: classifier returning three values; the last is used as the
            prediction.
        D_optimizer: optimiser for ``D_net``.
        img_list, test_loader: accepted but not used by the active code.
        D_losses: receives the loss at every logging step.
        train_loader: yields ``(data, labels)`` batches.
        data: dataset name, 'KDEF' (6 classes) or 'MMI' (5 classes).

    Reads the globals ``num_epochs``, ``device``, ``crossH`` and
    ``checkpoints_dir``.
    """
    # Keep the dataset name: `data` is reused as the batch variable below.
    dat_name = data
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, (data, labels) in enumerate(train_loader) :

            b_size = data.shape[0] 

            data = data.to(device=device)
            #label = labels.to(device=device)
            # NOTE(review): any other dataset name leaves `label` unbound.
            if dat_name == 'KDEF':
              label = F.one_hot(labels, num_classes=6).to(device=device)
            elif dat_name == 'MMI':
              label = F.one_hot(labels, num_classes=5).to(device=device)
            #Train D

            # G_net.zero_grad()

            # mu, logvar, output = G_net(data, b_size)

            # reparam = KL(mu, logvar)

            # target = sigmoid(data)
            # BCE = bcnloss(output, target)

            # loss_G = BCE + reparam 

            # loss_G.backward(retain_graph=True)
            
            # G_optimizer.step()      
            D_net.zero_grad()
            # Only the third returned value is used here.
            _, _, y_real = D_net(data)
            

            
            cross_entropy = crossH(label, y_real)

            loss_D = cross_entropy 
            loss_D.backward()
            D_optimizer.step()
            

            # Log at the first step and then every 20 steps.
            if (i+1)%20 == 0 or i ==0:
              # with open(os.path.join(checkpoints_dir, 'results.txt'), 'a+') as f:
              #   print('Epoch [{}/{}], Step [{}/{}] \n D_loss: {:.4f}' 
              #         .format(epoch, num_epochs, i+1, len(train_loader), loss_G.item()), file = f)
              #   print(BCE.item(), reparam.item(), file = f)
              # G_losses.append(loss_G.item())
              # G_losses.append(loss_G.item())
              # D_losses.append(0)
              
              print('Epoch [{}/{}], Step [{}/{}] \n D_loss: {:.4f}' 
                .format(epoch, num_epochs, i+1, len(train_loader), loss_D.item()))
              D_losses.append(loss_D.item())
              # test(D_net)
              print(y_real[0],label[0])

        # if (epoch)% 10 == 0 and epoch != 0:
        #   if (epoch) % 30  == 0:
        #     err = pruning(D_net, remove_layer= True)
        #     if err == -1:
        #       break
        #   else:          
        #     pruning(D_net)
        # (disabled) generator checkpointing at the last epoch:
        # if epoch == (num_epochs - 1):
        #     torch.save(G_net.state_dict(), '%s/netG_epoch_%d.pth' % (os.path.join(checkpoints_dir), 40))
    # The checkpoint file name is chosen from the class name, once training ends.
    # NOTE(review): class 'VGG_16' is saved with the 'VGG_NOBN' suffix and every
    # other class with 'VGG_BN'; runall tests the 'vggbn' network (built as
    # VGG_16_nobn) under the name 'VGG_NOBN', so the suffixes look swapped -
    # confirm against the class definitions.
    if D_net.__class__.__name__ == 'VGG_16':
      torch.save(D_net.state_dict(), os.path.join(checkpoints_dir, 'D_net_{}_VGG_NOBN.pth'.format(dat_name)))
    else:
      torch.save(D_net.state_dict(), os.path.join(checkpoints_dir, 'D_net_{}_VGG_BN.pth'.format(dat_name)))

In [None]:
def pltfig():
  """Plot the global ``G_losses`` / ``D_losses`` curves and save them as train_graph.png."""
  plt.figure(figsize=(10,5))
  plt.title("Generator and Discriminator Loss During Training")
  # Generator curve first, then discriminator, so legend order is G, D.
  for series, tag in ((G_losses, "G"), (D_losses, "D")):
    plt.plot(series, label=tag)
  plt.xlabel("iterations")
  plt.ylabel("Loss")
  plt.legend()
  out_path = os.path.join(checkpoints_dir, 'train_graph.png')
  plt.savefig(out_path)

In [None]:
from IPython.display import HTML
def giffig():
  """Animate the grids stored in the global ``img_list`` and save them as gen_result.gif."""
  writergif = animation.PillowWriter(fps=3)

  fig = plt.figure(figsize=(8,8))
  plt.axis("off")

  # One single-artist frame per stored grid (CHW -> HWC for imshow).
  frames = []
  for grid in img_list:
    frames.append([plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)])

  ani = animation.ArtistAnimation(fig, frames, interval=1000, repeat_delay=1000, blit=True)
  gif_path = os.path.join(checkpoints_dir, 'gen_result.gif')
  ani.save(gif_path, writer = writergif)
  fig.show()

## Testing

Test the trained network.

It first shows the array of images from the generator

It also visualises the result of the discriminator
- Shows image of a data in the test set
- Generates a bar chart for that sample from the output of the discriminator (the probability that the input belongs to each class)
- Checks whether the highest probability matches with the actual labeled block
 - If True - blue, and if False - red





In [None]:
def plot_image(i, predictions_array, true_label, img, class_names):
  """Draw sample ``i`` with a caption of predicted class, confidence and true class.

  The caption is blue when the prediction matches the label and red otherwise.
  """
  probs = predictions_array[i].cpu().detach()
  target = true_label[i]
  picture = img[i]

  plt.grid(False)
  plt.xticks([])
  plt.yticks([])

  predicted_label = np.argmax(probs).numpy()
  color = 'blue' if predicted_label == target.numpy() else 'red'

  # CHW -> HWC for imshow.
  plt.imshow(np.transpose(picture, (1, 2, 0)))

  caption = "{} {:2.0f}% ({})".format(class_names[predicted_label],
                                100*probs.max().numpy(),
                                class_names[target])
  plt.xlabel(caption, color=color)

def plot_value_array(i, predictions_array, true_label):
  """Draw the class-probability bar chart for sample ``i``.

  The predicted class is coloured red and the true class blue (blue wins
  when they coincide because it is applied last).

  The number of bars now follows the length of the prediction vector; it
  used to be hard-coded to 6, which does not match the 5-class MMI dataset.
  """
  predictions_array, true_label = predictions_array[i].cpu().detach(), true_label[i]
  plt.grid(False)
  plt.xticks([])
  plt.yticks([])
  num_classes = len(predictions_array)
  thisplot = plt.bar(range(num_classes), predictions_array, color="#777777")
  plt.ylim([0, 1])
  predicted_label = np.argmax(predictions_array)

  thisplot[predicted_label].set_color('red')
  thisplot[true_label].set_color('blue')

In [None]:
def result(predictions, test_la, A, class_names):
  """Save a grid of (image, probability bar chart) pairs as test_result.png.

  The grid has room for 25 x 10 samples. Fewer samples are plotted when
  ``predictions`` holds fewer than 250 entries; previously that case raised
  IndexError part-way through plotting.
  """
  num_rows = 25
  num_cols = 10
  num_images = min(num_rows*num_cols, len(predictions))
  plt.figure(figsize=(2*2*num_cols, 2*num_rows))
  for i in range(num_images):
    # Each sample takes two adjacent cells: image on the left, bars on the right.
    plt.subplot(num_rows, 2*num_cols, 2*i+1)
    plot_image(i, predictions, test_la, A, class_names)
    plt.subplot(num_rows, 2*num_cols, 2*i+2)
    plot_value_array(i, predictions, test_la)
  plt.savefig(os.path.join(checkpoints_dir, 'test_result.png'))

In [None]:
def accuracy(predictions, test_la):
  """Return the fraction of rows whose arg-max equals the matching label."""
  total = predictions.shape[0]
  hits = sum(
      1
      for idx in range(total)
      if np.argmax(predictions[idx].cpu().detach()) == test_la[idx]
  )

  # with open(os.path.join(checkpoints_dir, 'accuracy.txt'), 'w') as f:
  #     print(correctness/predictions.shape[0], file=f)  # Python 3.x
  return hits/total

def feature_ext(predictions, features1, features2, class_names, data, name):
  """Save the two feature rows of every sample as .npy files.

  Files are written to ``<checkpoints_dir>/<data>_<name>/`` and named
  ``<data>_<name>_<predicted class>_<index>_feature{0,1}.npy``.

  The tensor-to-numpy conversion and the directory check used to run once
  per sample; both are loop-invariant and now run once. The directory is
  created with ``exist_ok=True``, so it is created even when
  ``predictions`` is empty.
  """
  out_dir = os.path.join(checkpoints_dir, '{}_{}'.format(data, name))
  os.makedirs(out_dir, exist_ok=True)

  feature = features1.cpu().detach().numpy()
  feature2 = features2.cpu().detach().numpy()

  for i in range(predictions.shape[0]):
    # The predicted class name becomes part of the file name.
    n = class_names[np.argmax(predictions[i].cpu().detach())]
    with open(os.path.join(out_dir, '{}_{}_{}_{}_feature0.npy'.format(data, name, n, i)), 'wb') as f:
      np.save(f, feature[i,:])
    with open(os.path.join(out_dir, '{}_{}_{}_{}_feature1.npy'.format(data, name, n, i)), 'wb') as f:
      np.save(f, feature2[i,:])


In [None]:
def test(D_net, test_loader, data, name = None, pruning = None):
  """Evaluate ``D_net`` batch by batch and print the accuracy of each batch.

  Args:
    D_net: network returning ``(features, features1, predictions)``.
    test_loader: yields ``(images, labels)`` batches.
    data: 'KDEF' or 'MMI'; selects the class-name list.
    name: when given, features are exported through ``feature_ext`` and the
      batch accuracy is written to ``<name>_<data>_accuracy.txt``.
    pruning: when True, return the accuracy of the first batch immediately.
  """
  if data == 'KDEF':
    class_names = ['Angry', 'Frown', 'Neutral', 'Sad', 'Smile', 'Surprise']
  elif data == 'MMI':
    class_names = ['Angry', 'Frown', 'Sad', 'Smile', 'Surprise']

  for batch_images, batch_labels in test_loader:
    with torch.no_grad():
        # predictions = D_net(test_im.to(device).half()).float()
        feats0, feats1, predictions = D_net(batch_images.to(device))

    acc = accuracy(predictions, batch_labels)
    print(acc)
    if pruning == True:
      return acc
    if name != None:
      feature_ext(predictions, feats0, feats1, class_names, data, name)
      # Opened in 'w' mode, so each batch replaces the previous content.
      acc_path = os.path.join(checkpoints_dir, '{}_{}_accuracy.txt'.format(name, data))
      with open(acc_path, 'w') as f:
        print(acc, file=f)  # Python 3.x
    # result(predictions, test_la, A, class_names)


In [None]:
def runn(gen, n, opt, af, lr, train_loader, test_loader):
  """Build a generator/discriminator pair and run the training routine for ``gen``.

  Args:
    gen: generator kind; ``'vae'`` runs ``genn`` followed by ``VAEGAN``,
      ``'style'`` runs ``SGAN``.  Any other value builds the networks and
      optimizers but trains nothing.
    n: discriminator kind passed to ``disc_sel``.
    opt: optimizer option passed to ``opti_sel``.
    af: activation option passed to ``gen_sel`` and ``disc_sel``.
    lr: learning rate for both optimizers.
    train_loader: training data loader.
    test_loader: test data loader.

  Reads the module-level ``fixed_noise``.

  NOTE(review): ``runall`` calls ``disc_sel`` with three arguments (network,
  class count, activation) whereas this function passes two -- confirm
  against ``disc_sel``'s signature.
  """
  images, gen_losses, disc_losses = [], [], []
  generator = gen_sel(gen, af)
  discriminator = disc_sel(n, af)
  gen_optim = opti_sel(generator, option=opt, lr=lr)
  disc_optim = opti_sel(discriminator, option=opt, lr=lr)

  if gen == 'vae':
    genn(generator, gen_optim, images, gen_losses, disc_losses, train_loader)
    VAEGAN(generator, discriminator, disc_optim, gen_optim, images,
           gen_losses, disc_losses, fixed_noise, train_loader, test_loader)
  elif gen == 'style':
    # Second generator instance with all of its parameters frozen.
    frozen_gen = gen_sel(gen, af)
    for param in frozen_gen.parameters():
      param.requires_grad = False
    SGAN(generator, discriminator, frozen_gen, disc_optim, gen_optim, images,
         gen_losses, disc_losses, fixed_noise, train_loader, test_loader)

  plt.close('all')
  del generator, discriminator, gen_optim, disc_optim


In [None]:
def runall(n, opt, af, lr, data, train_loader, test_loader):
  """Train and test three discriminators on one dataset.

  Steps, in order:
    1. Build a ``'catgan'`` generator, a discriminator of kind ``n`` and a
       ``'vggbn'`` discriminator, each with the dataset's class count.
    2. Run ``GAN`` with the generator and the ``n`` discriminator, then test
       that discriminator under the tag ``'VGG_GAN'``.
    3. Run ``orig`` on the ``'vggbn'`` discriminator and test it under
       ``'VGG_NOBN'``.
    4. Rebuild a fresh ``n`` discriminator, run ``orig`` on it and test it
       under ``'VGG_BN'``.

  Args:
    n: discriminator kind passed to ``disc_sel``.
    opt: optimizer option passed to ``opti_sel``.
    af: activation option passed to ``gen_sel`` and ``disc_sel``.
    lr: learning rate for every optimizer.
    data: dataset tag; ``'MMI'`` (5 classes) or ``'KDEF'`` (6 classes).
    train_loader: training data loader.
    test_loader: test data loader.

  Raises:
    ValueError: if ``data`` is not a supported dataset tag.

  Reads the module-level ``fixed_noise``.

  NOTE(review): the ``'vggbn'`` network is saved under ``'VGG_NOBN'`` and the
  ``n`` network under ``'VGG_BN'`` -- confirm the tags are not swapped.
  NOTE(review): ``img_list`` and ``D_losses`` are shared by all three training
  calls, so later runs receive lists already holding earlier entries if those
  routines append to them -- confirm this is intended.
  """
  class_counts = {'MMI': 5, 'KDEF': 6}
  if data not in class_counts:
    # Previously this surfaced later as an unbound-variable NameError.
    raise ValueError(
        "Unsupported dataset {!r}; expected 'KDEF' or 'MMI'".format(data))
  num_classes = class_counts[data]

  img_list = []
  G_losses = []
  D_losses = []
  G_net = gen_sel('catgan', af)
  D_net = disc_sel(n, num_classes, af)
  D_net2 = disc_sel('vggbn', num_classes, af)
  G_optimizer = opti_sel(G_net, option=opt, lr=lr)
  D_optimizer = opti_sel(D_net, option=opt, lr=lr)
  D_optimizer2 = opti_sel(D_net2, option=opt, lr=lr)

  GAN(G_net, D_net, D_optimizer, G_optimizer, img_list, G_losses, D_losses,
      fixed_noise, train_loader, test_loader, data)
  test(D_net, test_loader, data, 'VGG_GAN')

  orig(D_net2, D_optimizer2, img_list, D_losses, train_loader, test_loader, data)
  test(D_net2, test_loader, data, 'VGG_NOBN')

  # Drop the GAN-trained discriminator before building its fresh replacement
  # (presumably to release its memory first).
  del D_net, D_optimizer
  D_net = disc_sel(n, num_classes, af)
  D_optimizer = opti_sel(D_net, option=opt, lr=lr)

  orig(D_net, D_optimizer, img_list, D_losses, train_loader, test_loader, data)
  test(D_net, test_loader, data, 'VGG_BN')
  plt.close('all')
  del G_net, D_net, D_net2, D_optimizer2, G_optimizer, D_optimizer

In [None]:
# root = "/content/drive/My Drive/Colab Notebooks/pytorch_check_mod/"

# fixed_noise = torch.rand(batch_size, nz, device=device)
# arr = os.listdir(root)
# activations = ['LeakyReLU', 'ClippedReLU', 'ReLU', ]
# actis = ['leaky', 'clipped', 'relu']
# optimizers = ['SGDM', 'RMSProp', 'Adam']
# optis = ['sgdm', 'rmsprop', 'adam']
# gens = ['vae', 'style']
# gen_n = ['VAE','StyleGAN']

# n = 'vgg'
# net = 'VGG 16'
# lr 
# for m in range(2):
#   if m == 0:
#     dataroot = data_sel('MMI')
#     data = 'MMI'
#   elif m == 1:
#     dataroot = data_sel('KDEF')
#     data = 'KDEF'
#   train_loader, test_loader = datagen(dataroot)
#   for i in range(len(activations)):
#       af = actis[i]
#       checkpoints_dir = os.path.join(root,net,activations[i],optimizers[2])
#       # print(type(G_net).__name__)
#       print('learning with','VGG', '-', af, '-', optis[2])
#       lr = 1e-4
#       runall(n, optis[2], af, lr, data, train_loader, test_loader)

In [None]:
# root = "/content/drive/My Drive/Colab Notebooks/pytorch_check_mod/"

# fixed_noise = torch.rand(batch_size, nz, device=device)
# arr = os.listdir(root)
# activations = ['LeakyReLU', 'ClippedReLU', 'ReLU', ]
# actis = ['leaky', 'clipped', 'relu']
# optimizers = ['SGDM', 'RMSProp', 'Adam']
# optis = ['sgdm', 'rmsprop', 'adam']
# gens = ['vae', 'style']
# gen_n = ['VAE','StyleGAN']

# n = 'vgg'
# net = 'VGG 16'

# for m in range(2):
#   if m == 0:
#     dataroot = data_sel('MMI')
#     data = 'MMI'
#   elif m == 1:
#     dataroot = data_sel('KDEF')
#     data = 'KDEF'
#   train_loader, test_loader = datagen(dataroot)
#   for i in range(len(activations)):
#       af = actis[i]
#       checkpoints_dir = os.path.join(root,net,activations[i],optimizers[1])
#       # print(type(G_net).__name__)
#       print('learning with','VGG', '-', af, '-', optis[1])
#       lr = 1e-4
#       runall(n, optis[1], af, lr, data, train_loader, test_loader)

In [None]:
root = "/content/drive/My Drive/Colab Notebooks/pytorch_check_mod/"

# Fixed batch of latent vectors, reused to visualize how the generator
# progresses during training.
fixed_noise = torch.rand(batch_size, nz, device=device)#.half()
# fixed_noise = torch.rand(batch_size, nz, 1, 1, device=device)
arr = os.listdir(root)
# networks = ['ResNet 50', 'AlexNet', 'VGG 16', 'Original', 'GoogleNet']
# nets = ['resnet', 'alex', 'vgg', 'catgan', 'google']

# Parallel lists: display names (used in checkpoint paths) and option keys.
activations = ['LeakyReLU', 'ClippedReLU', 'ReLU', ]
actis = ['leaky', 'clipped', 'relu']
optimizers = ['SGDM', 'RMSProp', 'Adam']
optis = ['sgdm', 'rmsprop', 'adam']
gens = ['vae', 'style']
gen_n = ['VAE','StyleGAN']

n = 'vgg'
net = 'VGG 16'

# This cell sweeps every activation with the first optimizer (SGDM),
# first on MMI and then on KDEF.
for m, data in enumerate(('MMI', 'KDEF')):
  dataroot = data_sel(data)
  train_loader, test_loader = datagen(dataroot)
  for i, af in enumerate(actis):
    # checkpoints_dir is a module-level name assigned here, before runall.
    checkpoints_dir = os.path.join(root, net, activations[i], optimizers[0])
    print('learning with','VGG', '-', af, '-', optis[0])
    lr = 1e-4
    runall(n, optis[0], af, lr, data, train_loader, test_loader)


learning with VGG - leaky - sgdm
Starting Training Loop...
Epoch [0/50], Step [1/6] 
 D_loss: 49.5948, H_x[p(y|D)] : 1.6073, E[H[p(y|x,D)]] : 1.5821,                       E[H[p(y|G(z),D)]]:1.5632
 G_loss: -0.0227, H_G[p(y|D)] : 1.5859, E[H[p(y|G(z),D)]]: 1.5632, E[CE[y,p(y|x,D)]: 51.2033]
Epoch [1/50], Step [1/6] 
 D_loss: 49.3729, H_x[p(y|D)] : 1.5335, E[H[p(y|x,D)]] : 1.4973,                       E[H[p(y|G(z),D)]]:1.5254
 G_loss: -0.0177, H_G[p(y|D)] : 1.5431, E[H[p(y|G(z),D)]]: 1.5254, E[CE[y,p(y|x,D)]: 50.9432]
Epoch [2/50], Step [1/6] 
 D_loss: 42.1049, H_x[p(y|D)] : 1.5008, E[H[p(y|x,D)]] : 1.4241,                       E[H[p(y|G(z),D)]]:1.5180
 G_loss: -0.0319, H_G[p(y|D)] : 1.5499, E[H[p(y|G(z),D)]]: 1.5180, E[CE[y,p(y|x,D)]: 43.6640]
Epoch [3/50], Step [1/6] 
 D_loss: 40.9008, H_x[p(y|D)] : 1.5206, E[H[p(y|x,D)]] : 1.3416,                       E[H[p(y|G(z),D)]]:1.5424
 G_loss: -0.0566, H_G[p(y|D)] : 1.5990, E[H[p(y|G(z),D)]]: 1.5424, E[CE[y,p(y|x,D)]: 42.6000]
Epoch [4/50],