In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, datasets
import pytorch_lightning as pl

In [2]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.image = nn.Sequential(
            nn.ConvTranspose2d(100, 64 * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True)
        )
        self.label = nn.Sequential(
            nn.ConvTranspose2d(10, 64 * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True)
        )
        self.main = nn.Sequential(
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2, 64, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, image, label):
        image = self.image(image)
        label = self.label(label)
        incat = torch.cat((image, label), dim=1)
        return self.main(incat)

In [3]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.image = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.label = nn.Sequential(
            nn.Conv2d(10, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.main = nn.Sequential(
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False)
        )

    def forward(self, image, label):
        image = self.image(image)
        label = self.label(label)
        incat = torch.cat((image, label), dim=1)
        return self.main(incat)

In [4]:
class CGAN(pl.LightningModule):
	def __init__(self):
		super().__init__()
		self.generator = Generator()
		self.discriminator = Discriminator()

		self.automatic_optimization = False

	def forward(self, z, y):
		return self.generator(z, y)

	def generator_step(self, x):
		z = torch.randn(x.shape[0], 100, 1, 1).to("cuda")

		y = torch.randint(0, 10, size=(x.shape[0], 100, 1, 1), device="cuda")

		generated_imgs = self(z, y)

		d_output = torch.squeeze(self.discriminator(generated_imgs, y))

		g_loss = nn.BCELoss()(d_output,
							torch.ones(x.shape[0], 100, 1, 1).to("cuda"))

		return g_loss

	def discriminator_step(self, x, y):
		d_output = torch.squeeze(self.discriminator(x, y))
		loss_real = nn.BCELoss()(d_output,
								torch.ones(x.shape[0], 100, 1, 1).to("cuda"))

		z = torch.randn(x.shape[0], 100, 1, 1).to("cuda")
		y = torch.randint(0, 10, size=(x.shape[0], 100, 1, 1), device="cuda")

		generated_imgs = self(z, y)
		d_output = torch.squeeze(self.discriminator(generated_imgs, y))
		loss_fake = nn.BCELoss()(d_output,
								torch.zeros(x.shape[0], 100, 1, 1).to("cuda"))

		return loss_real + loss_fake

	def training_step(self, batch, batch_idx):
		X, y = batch

		g_optimizer, d_optimizer = self.optimizers()

		g_loss = self.generator_step(X)
		
		d_loss = self.discriminator_step(X, y)

		return g_loss - d_loss

	def configure_optimizers(self):
		g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
		d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)
		return [g_optimizer, d_optimizer], []

In [5]:
criterion = nn.BCEWithLogitsLoss()

real_label_num = 1.
fake_label_num = 0.

label_1hots = torch.zeros(10,10)
for i in range(10):
    label_1hots[i,i] = 1
label_1hots = label_1hots.view(10,10,1,1).to("cuda")

label_fills = torch.zeros(10, 10, 28, 28)
ones = torch.ones(28, 28)
for i in range(10):
    label_fills[i][i] = ones
label_fills = label_fills.to("cuda")

fixed_noise = torch.randn(100, 100, 1, 1).to("cuda")
fixed_label = label_1hots[torch.arange(10).repeat(10).sort().values]

In [6]:
mnist_transforms = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.5], std=[0.5]),
                                    transforms.Lambda(lambda x: x.view(-1, 784)),
                                    transforms.Lambda(lambda x: torch.squeeze(x))
                                    ])

data = datasets.MNIST(root='../data/MNIST', download=True, transform=mnist_transforms)

mnist_dataloader = DataLoader(data, batch_size=128, shuffle=True, num_workers=20) 

model = CGAN()

trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1)
trainer.fit(model, mnist_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 3.2 M 
1 | discriminator | Discriminator | 2.6 M 
------------------------------------------------
5.8 M     Trainable params
0         Non-trainable params
5.8 M     Total params
23.389    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

RuntimeError: Given transposed=1, weight of size [10, 256, 4, 4], expected input[128, 100, 1, 1] to have 10 channels, but got 100 channels instead

In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/