In [1]:
import torch
from torch import nn
import numpy as np
from torchvision.utils import save_image

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# 1. Dataset

In [2]:
import torchvision

img_size = 32
num_classes = 10

In [3]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((img_size, img_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
])

mnist_images = torchvision.datasets.MNIST(root='mnist_data', train=True, 
                                    download=True, transform=transform)

In [4]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32
dataloader = DataLoader(mnist_images, batch_size=BATCH_SIZE, shuffle=True)

# 2. Model

In [5]:
channels = 1
img_shape = (channels, img_size, img_size)
latent_dim = 100

In [6]:
emb_dim=16
label_emb = nn.Embedding(num_classes, emb_dim, _freeze=True) 
label_emb.to(device)

Embedding(10, 16)

In [9]:
print(label_emb.weight)

Parameter containing:
tensor([[ 0.9258,  2.3938, -0.1443, -0.4354,  1.6637,  1.3191, -1.3365, -0.6047,
         -1.3645, -1.0489, -0.4607, -0.1710, -1.6421,  1.1338,  1.0476,  1.0769],
        [-0.1661,  0.5422, -0.9975, -1.6783,  1.1237, -1.3186, -1.6384, -0.4724,
         -0.5700, -0.0405,  1.0312,  0.9640,  0.6198,  1.5520, -1.2713, -0.3896],
        [-0.9217, -1.4727,  0.1746, -1.7818, -1.0208, -1.1770,  0.1896, -1.1848,
         -1.1443, -0.0965,  0.1967,  0.7601,  0.7213,  0.9259, -2.3169, -0.6858],
        [-0.6363, -0.0180,  0.0915,  1.8618,  0.2904,  0.4646,  1.3824, -0.3942,
          0.8689, -0.1972,  1.4784, -0.2482,  0.8844, -0.6958, -0.3277, -0.5520],
        [ 0.3253,  0.5345,  2.2241,  2.2806,  0.8769,  1.9120, -0.4606, -1.5523,
         -0.4705,  0.3438, -0.4633, -1.6691, -0.4672,  0.3611, -1.3241, -0.7577],
        [ 0.1736,  0.0759,  0.0550,  1.0535,  2.5562,  0.0740, -0.0260, -0.3254,
          2.6763, -0.0517, -0.1028, -1.1810, -1.1921, -0.5814,  0.0085,  0.3317],


In [10]:
class Generator(nn.Module):
    def __init__(self, num_classes, emb_dim):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(latent_dim + emb_dim, 256),
            nn.BatchNorm1d(256,),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
        
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
        
    def forward(self, z, label):
        cond = label_emb(label)
        x = torch.cat([z, cond], 1)
        img = self.model(x)
        img = img.view(img.size(0), *img_shape)
        return img

In [11]:
class Descriminator(nn.Module):
    def __init__(self, num_classes, emb_dim):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape))+emb_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, img, label):
        img_flat = img.view(img.size(0), -1)
        cond = label_emb(label)
        x = torch.cat([img_flat, cond], 1)
        validity = self.model(x)
        return validity

In [12]:
generator = Generator(num_classes=num_classes, emb_dim=16)
discriminator = Descriminator(num_classes=num_classes, emb_dim=16)

In [13]:
generator.to(device)

Generator(
  (model): Sequential(
    (0): Linear(in_features=116, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Linear(in_features=256, out_features=512, bias=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Linear(in_features=512, out_features=1024, bias=True)
    (7): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Linear(in_features=1024, out_features=1024, bias=True)
    (10): Tanh()
  )
)

In [14]:
discriminator.to(device)

Descriminator(
  (model): Sequential(
    (0): Linear(in_features=1040, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Linear(in_features=256, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

# 3. Training

In [15]:
EPOCHS = 1

optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

criterion = nn.BCELoss()
hist = {
        "train_G_loss": [],
        "train_D_loss": [],
}

In [16]:
for epoch in range(EPOCHS):
    running_G_loss = 0.0
    running_D_loss = 0.0

    for i, (imgs, labels) in enumerate(dataloader):
        real_imgs = imgs.to(device)
        labels = labels.to(device)
        
        real_labels = torch.ones((imgs.shape[0], 1)).to(device)
        fake_labels = torch.zeros((imgs.shape[0], 1)).to(device)


        # -------------------------- Train Generator --- 
        optimizer_G.zero_grad()
        
        # Noise input for Generator
        z = torch.randn((imgs.shape[0], latent_dim)).to(device)

        gen_imgs = generator(z, labels)
        validity = discriminator(gen_imgs, labels)
        G_loss = criterion(validity, real_labels)
        running_G_loss += G_loss.item()

        G_loss.backward()
        optimizer_G.step()


        # -------------- Train Discriminator --- 
        optimizer_D.zero_grad()

        real_validity = discriminator(real_imgs, labels)
        real_loss = criterion(real_validity, real_labels)

        fake_validity = discriminator(gen_imgs.detach(), labels)
        fake_loss = criterion(fake_validity, fake_labels)
        
        D_loss = (real_loss + fake_loss) / 2
        running_D_loss += D_loss.item()

        D_loss.backward()
        optimizer_D.step()
    
    epoch_G_loss = running_G_loss / len(dataloader)
    epoch_D_loss = running_D_loss / len(dataloader)
    
    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train G Loss: {epoch_G_loss:.4f}, Train D Loss: {epoch_D_loss:.4f}")

    hist["train_G_loss"].append(epoch_G_loss)
    hist["train_D_loss"].append(epoch_D_loss)

Epoch [1/1], Train G Loss: 26.3742, Train D Loss: 0.0461


In [17]:
print(label_emb.weight)

Parameter containing:
tensor([[ 0.9258,  2.3938, -0.1443, -0.4354,  1.6637,  1.3191, -1.3365, -0.6047,
         -1.3645, -1.0489, -0.4607, -0.1710, -1.6421,  1.1338,  1.0476,  1.0769],
        [-0.1661,  0.5422, -0.9975, -1.6783,  1.1237, -1.3186, -1.6384, -0.4724,
         -0.5700, -0.0405,  1.0312,  0.9640,  0.6198,  1.5520, -1.2713, -0.3896],
        [-0.9217, -1.4727,  0.1746, -1.7818, -1.0208, -1.1770,  0.1896, -1.1848,
         -1.1443, -0.0965,  0.1967,  0.7601,  0.7213,  0.9259, -2.3169, -0.6858],
        [-0.6363, -0.0180,  0.0915,  1.8618,  0.2904,  0.4646,  1.3824, -0.3942,
          0.8689, -0.1972,  1.4784, -0.2482,  0.8844, -0.6958, -0.3277, -0.5520],
        [ 0.3253,  0.5345,  2.2241,  2.2806,  0.8769,  1.9120, -0.4606, -1.5523,
         -0.4705,  0.3438, -0.4633, -1.6691, -0.4672,  0.3611, -1.3241, -0.7577],
        [ 0.1736,  0.0759,  0.0550,  1.0535,  2.5562,  0.0740, -0.0260, -0.3254,
          2.6763, -0.0517, -0.1028, -1.1810, -1.1921, -0.5814,  0.0085,  0.3317],
