In [1]:
import torch
from torch import nn
import numpy as np
from torchvision.utils import save_image

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# 1. Dataset

In [2]:
import torchvision

img_size = 32
num_classes = 10

In [3]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((img_size, img_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
])

mnist_images = torchvision.datasets.MNIST(root='mnist_data', train=True, 
                                    download=True, transform=transform)

In [4]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32
dataloader = DataLoader(mnist_images, batch_size=BATCH_SIZE, shuffle=True)

# 2. Model

In [5]:
channels = 1
img_shape = (channels, img_size, img_size)
latent_dim = 100

In [6]:
class Generator(nn.Module):
    def __init__(self, num_classes, emb_dim):
        super().__init__()

        self.label_emb = nn.Embedding(num_classes, emb_dim) 
        
        self.model = nn.Sequential(
            nn.Linear(latent_dim + emb_dim, 256),
            nn.BatchNorm1d(256,),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
        
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
        
    def forward(self, z, label):
        cond = self.label_emb(label)
        x = torch.cat([z, cond], 1)
        img = self.model(x)
        img = img.view(img.size(0), *img_shape)
        return img

In [7]:
class Descriminator(nn.Module):
    def __init__(self, num_classes, emb_dim):
        super().__init__()

        self.label_emb = nn.Embedding(num_classes, emb_dim) 

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape))+emb_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, img, label):
        img_flat = img.view(img.size(0), -1)
        cond = self.label_emb(label)
        x = torch.cat([img_flat, cond], 1)
        validity = self.model(x)
        return validity

In [8]:
generator = Generator(num_classes=num_classes, emb_dim=16)
discriminator = Descriminator(num_classes=num_classes, emb_dim=16)

In [9]:
generator.to(device)


Generator(
  (label_emb): Embedding(10, 16)
  (model): Sequential(
    (0): Linear(in_features=116, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Linear(in_features=256, out_features=512, bias=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Linear(in_features=512, out_features=1024, bias=True)
    (7): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Linear(in_features=1024, out_features=1024, bias=True)
    (10): Tanh()
  )
)

In [10]:
discriminator.to(device)


Descriminator(
  (label_emb): Embedding(10, 16)
  (model): Sequential(
    (0): Linear(in_features=1040, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Linear(in_features=256, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [11]:
generator.label_emb.weight

Parameter containing:
tensor([[ 1.3921e+00,  1.5551e+00, -2.0304e+00,  8.5707e-01, -4.8496e-01,
         -1.7082e+00,  1.4773e+00, -7.3380e-01, -1.9674e+00,  2.0983e+00,
          3.4575e-01,  3.3494e-01, -8.1928e-01,  1.7746e-01, -8.3734e-01,
          1.9602e+00],
        [ 2.0074e-01,  1.2127e+00, -3.9997e-02,  1.3045e+00,  5.1198e-01,
         -5.0860e-01, -6.5350e-01,  1.5844e+00,  5.5103e-02, -1.4875e+00,
         -2.5723e-01, -4.1075e-01, -2.5581e-01,  1.4165e+00, -1.0728e+00,
          1.7852e-01],
        [ 8.2097e-01,  5.6045e-01,  1.1845e-01,  1.1224e+00,  1.5691e+00,
         -5.7020e-01,  1.8601e-01, -9.4479e-01,  3.8005e-01,  2.1652e-01,
         -6.0831e-01, -1.6018e+00,  6.7142e-01,  1.4244e-01,  9.4347e-01,
         -7.7309e-01],
        [-1.3265e+00,  1.2587e-01, -1.6731e+00, -3.8101e-01,  8.7552e-01,
         -8.3858e-01,  1.3657e-01, -1.9664e-02, -2.2608e-01, -4.7246e-01,
         -3.6512e-01, -8.3210e-02, -1.6211e-01,  1.9598e+00,  1.2385e+00,
          6.8994e-01]

# 3. Training

In [12]:
EPOCHS = 1

optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

criterion = nn.BCELoss()
hist = {
        "train_G_loss": [],
        "train_D_loss": [],
}

In [13]:
for epoch in range(EPOCHS):
    running_G_loss = 0.0
    running_D_loss = 0.0

    for i, (imgs, labels) in enumerate(dataloader):
        real_imgs = imgs.to(device)
        labels = labels.to(device)
        
        real_labels = torch.ones((imgs.shape[0], 1)).to(device)
        fake_labels = torch.zeros((imgs.shape[0], 1)).to(device)


        # -------------------------- Train Generator --- 
        optimizer_G.zero_grad()
        
        # Noise input for Generator
        z = torch.randn((imgs.shape[0], latent_dim)).to(device)

        gen_imgs = generator(z, labels)
        validity = discriminator(gen_imgs, labels)
        G_loss = criterion(validity, real_labels)
        running_G_loss += G_loss.item()

        G_loss.backward()
        optimizer_G.step()


        # -------------- Train Discriminator --- 
        optimizer_D.zero_grad()

        real_validity = discriminator(real_imgs, labels)
        real_loss = criterion(real_validity, real_labels)

        fake_validity = discriminator(gen_imgs.detach(), labels)
        fake_loss = criterion(fake_validity, fake_labels)
        
        D_loss = (real_loss + fake_loss) / 2
        running_D_loss += D_loss.item()

        D_loss.backward()
        optimizer_D.step()
    
    epoch_G_loss = running_G_loss / len(dataloader)
    epoch_D_loss = running_D_loss / len(dataloader)
    
    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train G Loss: {epoch_G_loss:.4f}, Train D Loss: {epoch_D_loss:.4f}")

    hist["train_G_loss"].append(epoch_G_loss)
    hist["train_D_loss"].append(epoch_D_loss)

Epoch [1/1], Train G Loss: 15.6032, Train D Loss: 0.0677


In [14]:
generator.label_emb.weight

Parameter containing:
tensor([[ 1.4145,  1.5255, -2.0129,  0.8499, -0.5002, -1.6947,  1.4860, -0.7572,
         -1.9573,  2.0881,  0.3743,  0.3425, -0.7990,  0.1912, -0.8259,  1.9661],
        [ 0.2331,  1.2357, -0.0441,  1.3109,  0.5102, -0.5107, -0.6521,  1.5919,
          0.0505, -1.4527, -0.2328, -0.3740, -0.2695,  1.3991, -1.0757,  0.1832],
        [ 0.8225,  0.5861,  0.0866,  1.1562,  1.6056, -0.6152,  0.1480, -0.8960,
          0.3426,  0.2271, -0.6191, -1.5825,  0.6270,  0.1047,  0.9060, -0.8080],
        [-1.3245,  0.1510, -1.6815, -0.3453,  0.9090, -0.8721,  0.1065,  0.0164,
         -0.2515, -0.4504, -0.3913, -0.0860, -0.1969,  1.9252,  1.2106,  0.6615],
        [ 0.8851, -1.5616, -1.3409,  0.0329, -1.4201,  0.0367,  0.5719, -0.8645,
          0.0535, -2.2126,  1.8443, -1.6660,  0.6919,  1.0952,  0.5706,  0.2790],
        [ 0.7413,  0.6840,  0.3124,  2.4086,  1.1710, -0.4486, -0.0706, -0.3337,
         -0.3261, -1.0436,  0.9440,  0.1324, -0.8684, -2.7197,  0.5176,  0.0585],
