In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
!pip install torch-snippets
from torch_snippets import *

Collecting torch-snippets
  Downloading torch_snippets-0.524-py3-none-any.whl (79 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/79.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.6/79.6 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Collecting dill (from torch-snippets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
Collecting loguru (from torch-snippets)
  Downloading loguru-0.7.2-py3-none-any.whl (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.5/62.5 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
Collecting typing (from torch-snippets)
  Downloading typing-3.7.4.3.tar.gz (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.6/78.6 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25

In [4]:
class Generator(nn.Module):
  """Fully-connected conditional generator.

  A learned embedding of the class label is concatenated with the noise
  vector and passed through a stack of linear blocks. The final ``Tanh``
  output is reshaped to ``(channels, img_size, img_size)`` per sample.

  Args:
    classes: number of distinct labels; also used as the embedding width.
    channels: number of channels of the generated image.
    img_size: height and width of the (square) generated image.
    latent_dim: length of the noise vector passed to ``forward``.
  """

  def __init__(self,classes,channels,img_size,latent_dim):
    super().__init__()

    self.classes = classes
    self.channels = channels
    self.img_size = img_size
    self.latent_dim = latent_dim

    self.img_shape = (self.channels,self.img_size,self.img_size)

    # One `classes`-wide vector per label, concatenated with the noise in forward().
    self.label_embedding = nn.Embedding(self.classes, self.classes)
    self.model = nn.Sequential(
        # The first block skips batch normalisation.
        *self._create_layers(self.latent_dim+self.classes,128,False),
        *self._create_layers(128,256),
        *self._create_layers(256,512),
        *self._create_layers(512,1024),
        nn.Linear(1024,int(np.prod(self.img_shape))),
        nn.Tanh()
    )

  def _create_layers(self,in_ch,out_ch,normalize=True):
    """Build one block: ``Linear`` -> optional ``BatchNorm1d`` -> ``LeakyReLU(0.2)``.

    Args:
      in_ch: input feature count of the linear layer.
      out_ch: output feature count of the linear layer.
      normalize: when true, insert ``BatchNorm1d(out_ch)`` after the linear layer.

    Returns:
      A list of modules, ready to be unpacked into ``nn.Sequential``.
    """
    layers = [nn.Linear(in_ch,out_ch)]
    if normalize:
      # Use append: a list cannot be extended with a single (non-iterable) module.
      layers.append(nn.BatchNorm1d(out_ch))
    layers.append(nn.LeakyReLU(0.2,inplace=True))

    return layers

  def forward(self,noise,labels):
    """Generate images conditioned on ``labels``.

    Args:
      noise: float tensor whose last dimension is ``latent_dim``.
      labels: integer tensor of class indices, one per sample.

    Returns:
      Tensor of shape ``(batch, channels, img_size, img_size)`` with values
      in ``[-1, 1]`` (output of ``Tanh``).
    """
    # Label embedding first, then noise, joined along the feature axis.
    z = torch.cat((self.label_embedding(labels),noise),-1)
    x = self.model(z)
    x = x.view(x.size(0) , *self.img_shape)

    return x

In [5]:
class Discriminator(nn.Module):
  """Fully-connected conditional discriminator.

  The flattened image is concatenated with a learned embedding of its class
  label and passed through a stack of linear blocks ending in ``Sigmoid``,
  giving one score in ``[0, 1]`` per sample.

  Args:
    classes: number of distinct labels; also used as the embedding width.
    channels: number of channels of the input image.
    img_size: height and width of the (square) input image.
    latent_dim: stored on the instance; not used by the network itself.
  """

  def __init__(self,classes,channels,img_size,latent_dim):
    super().__init__()

    self.classes = classes
    self.channels = channels
    self.img_size = img_size
    self.latent_dim = latent_dim
    self.img_shape = (self.channels,self.img_size,self.img_size)

    self.label_embedding = nn.Embedding(self.classes,self.classes)

    self.model = nn.Sequential(
        # Input width = label embedding + flattened image.
        *self._create_layer(self.classes+int(np.prod(self.img_shape)),1024,False,True),
        *self._create_layer(1024,512,True,True),
        *self._create_layer(512,256,True,True),
        *self._create_layer(256,128,False,True),
        # Final projection to a single logit: no dropout, no activation.
        *self._create_layer(128,1,False,False),
        nn.Sigmoid()
    )

    self.adv_loss = nn.BCELoss()

  def _create_layer(self,in_ch,out_ch,drop_out=True,act_func=True):
    """Build one block: ``Linear`` -> optional ``Dropout(0.4)`` -> optional ``LeakyReLU(0.2)``.

    Args:
      in_ch: input feature count of the linear layer.
      out_ch: output feature count of the linear layer.
      drop_out: when true, add ``Dropout(0.4)`` after the linear layer.
      act_func: when true, finish the block with ``LeakyReLU(0.2)``.

    Returns:
      A list of modules, ready to be unpacked into ``nn.Sequential``.
    """
    layers = [nn.Linear(in_ch,out_ch)]
    if drop_out:
      # Use append: a list cannot be extended with a single (non-iterable) module.
      layers.append(nn.Dropout(0.4))
    if act_func:
      layers.append(nn.LeakyReLU(0.2,inplace=True))

    return layers

  def forward(self,image,labels):
    """Score a batch of images given their class labels.

    Args:
      image: tensor that flattens to ``channels * img_size * img_size``
        features per sample.
      labels: integer tensor of class indices, one per sample.

    Returns:
      Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
    """
    # Flatten each image, then append its label embedding along the feature axis.
    x = torch.cat((image.view(image.size(0),-1),self.label_embedding(labels)),-1)
    return self.model(x)

  def loss(self,output,labels):
    """Binary cross-entropy between discriminator scores and targets.

    Args:
      output: scores produced by ``forward`` (values in ``[0, 1]``).
      labels: target tensor with the same shape as ``output``.

    Returns:
      Scalar loss tensor from ``nn.BCELoss``.
    """
    # Compare against the `labels` argument, not the builtin `input`.
    return self.adv_loss(output,labels)

In [None]:
class Model():
  def __init__(self,data_loader,classes,channels,img_size,latent_dim):
    """Set up both networks and their optimizers.

    Args:
      data_loader: loader kept on the instance for use during training.
      classes: number of distinct labels, forwarded to both networks.
      channels: number of image channels, forwarded to both networks.
      img_size: height and width of the (square) images.
      latent_dim: length of the generator's noise vector.
    """
    # Use the GPU when one is available; tensors created from self.device
    # must live on the same device as the networks.
    self.device = "cuda" if torch.cuda.is_available() else "cpu"

    self.data_loader = data_loader
    self.classes = classes
    self.channels = channels
    self.img_size = img_size
    self.latent_dim = latent_dim
    self.image_shape = (self.channels, self.img_size, self.img_size)

    # Move the networks to the selected device before building the optimizers.
    self.GenNet = Generator(self.classes,self.channels,self.img_size,self.latent_dim).to(self.device)
    self.DisNet = Discriminator(self.classes,self.channels,self.img_size,self.latent_dim).to(self.device)

    self.Gen_Optim = self.get_optim(self.GenNet)
    self.Dis_Optim = self.get_optim(self.DisNet)

  def get_optim(self,model):
    """Create the Adam optimizer for ``model``.

    Args:
      model: module whose ``parameters()`` will be optimized.

    Returns:
      ``torch.optim.Adam`` with learning rate ``1e-3`` and betas ``(0.5, 0.999)``.
    """
    trainable = model.parameters()
    adam_settings = {"lr": 1e-3, "betas": (0.5,0.999)}
    return torch.optim.Adam(trainable, **adam_settings)

  def train(self,epochs):
    self.GenNet.train()
    self.DisNet.train()

    viz_noise = torch.randn(self.data_loader.batch_size,self.latent_dim,device = self.device)
    viz_label = torch.LongTensor(np.array([num for _ in range(nrows) for num in range(8)]))