In [136]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, datasets
import pytorch_lightning as pl
import torchvision.utils as vutils

In [137]:
class NumGenNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.image = nn.Sequential(
            nn.ConvTranspose2d(100, 64 * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True)
        )
        self.label = nn.Sequential(
            nn.ConvTranspose2d(10, 64 * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True)
        )
        self.main = nn.Sequential(
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(64*2, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, image, label):
        image = self.image(image)
        label = self.label(label)
        incat = torch.cat((image, label), dim=1)
        return self.main(incat)

In [138]:
class NumDiscNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.image = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.label = nn.Sequential(
            nn.Conv2d(10, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.main = nn.Sequential(
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, image, label):
        image = self.image(image)
        label = self.label(label)
        incat = torch.cat((image, label), dim=1)
        return self.main(incat)

In [139]:
loss_fun = nn.BCEWithLogitsLoss()

In [None]:
criterion = nn.BCEWithLogitsLoss()

real_label_num = 1.
fake_label_num = 0.

label_1hots = torch.zeros(10,10)
for i in range(10):
    label_1hots[i,i] = 1
label_1hots = label_1hots.view(10,10,1,1).to("cuda")

label_fills = torch.zeros(10, 10, 32, 32)
ones = torch.ones(32, 32)
for i in range(10):
    label_fills[i][i] = ones
label_fills = label_fills.to("cuda")

fixed_noise = torch.randn(100, 100, 1, 1).to("cuda")
fixed_label = label_1hots[torch.arange(10).repeat(10).sort().values]

In [None]:
class CGAN(pl.LightningModule):
	def __init__(self):
		super().__init__()
		self.generator = NumGenNet()
		self.discriminator = NumDiscNet()

		self.automatic_optimization = False

	def forward(self, z, y):
		return self.generator(z, y)

	def generator_step(self, x):
		z = torch.randn((x.shape[0], 100, 1, 1), dtype=torch.float).to("cuda")

		y = torch.randint(0, 10, size=(x.shape[0], 10, 1, 1), dtype=torch.float, device="cuda")

		print(z.shape)
		print(y.shape)

		generated_imgs = self(z, y)

		d_output = torch.squeeze(self.discriminator(generated_imgs, y))

		g_loss = loss_fun(d_output,
							torch.ones((x.shape[0], 10, 1, 1), dtype=torch.float).to("cuda"))

		return g_loss

	def discriminator_step(self, x, y):
		d_output = torch.squeeze(self.discriminator(x, y))
		loss_real = loss_fun(d_output,
								torch.ones((x.shape[0], 10, 1, 1), dtype=torch.float).to("cuda"))

		z = torch.randn((x.shape[0], 100, 1, 1), dtype=torch.float).to("cuda")
		y = torch.randint(0, 10, size=(x.shape[0], 10, 1, 1), dtype=torch.float, device="cuda")

		generated_imgs = self(z, y)
		d_output = torch.squeeze(self.discriminator(generated_imgs, y))
		loss_fake = loss_fun(d_output,
								torch.zeros((x.shape[0], 10, 1, 1), dtype=torch.float).to("cuda"))

		return loss_real + loss_fake

	def training_step(self, batch, batch_idx):
		X, y = batch

		g_optimizer, d_optimizer = self.optimizers()

		g_loss = self.generator_step(X)
		
		d_loss = self.discriminator_step(X, y)

		return g_loss - d_loss

	def configure_optimizers(self):
		g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
		d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)
		return [g_optimizer, d_optimizer], []

In [140]:
# class CGAN(pl.LightningModule):
# 	def __init__(self):
# 		super().__init__()
# 		self.generator = NumGenNet()
# 		self.discriminator = NumDiscNet()

# 		self.automatic_optimization = False

# 	def forward(self, z, y):
# 		return self.generator(z, y)

# 	def generator_step(self, x):
# 		z = torch.randn((x.shape[0], 100, 1, 1), dtype=torch.float).to("cuda")

# 		y = torch.randint(0, 10, size=(x.shape[0], 10, 1, 1), dtype=torch.float, device="cuda")

# 		print(z.shape)
# 		print(y.shape)

# 		generated_imgs = self(z, y)

# 		d_output = torch.squeeze(self.discriminator(generated_imgs, y))

# 		g_loss = loss_fun(d_output,
# 							torch.ones((x.shape[0], 10, 1, 1), dtype=torch.float).to("cuda"))

# 		return g_loss

# 	def discriminator_step(self, x, y):
# 		d_output = torch.squeeze(self.discriminator(x, y))
# 		loss_real = loss_fun(d_output,
# 								torch.ones((x.shape[0], 10, 1, 1), dtype=torch.float).to("cuda"))

# 		z = torch.randn((x.shape[0], 100, 1, 1), dtype=torch.float).to("cuda")
# 		y = torch.randint(0, 10, size=(x.shape[0], 10, 1, 1), dtype=torch.float, device="cuda")

# 		generated_imgs = self(z, y)
# 		d_output = torch.squeeze(self.discriminator(generated_imgs, y))
# 		loss_fake = loss_fun(d_output,
# 								torch.zeros((x.shape[0], 10, 1, 1), dtype=torch.float).to("cuda"))

# 		return loss_real + loss_fake

# 	def training_step(self, batch, batch_idx):
# 		X, y = batch

# 		g_optimizer, d_optimizer = self.optimizers()

# 		g_loss = self.generator_step(X)
		
# 		d_loss = self.discriminator_step(X, y)

# 		return g_loss - d_loss

# 	def configure_optimizers(self):
# 		g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
# 		d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)
# 		return [g_optimizer, d_optimizer], []

In [141]:
transforms = transforms.Compose([
                                    transforms.Resize(32),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.5], std=[0.5]),
                                    transforms.Lambda(lambda x: x.view(-1, 32*32)),
                                    transforms.Lambda(lambda x: torch.squeeze(x))
                                    ])

data = datasets.MNIST(root='/home/maxim/Documents/TestProject/maxim-lightning/mnist_diffusion', download=True, transform=transforms)

trainloader = DataLoader(data, batch_size=32, shuffle=True, num_workers=20) 

model = CGAN()

trainer = pl.Trainer(max_epochs=10, accelerator="gpu")
trainer.fit(model, trainloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type       | Params
---------------------------------------------
0 | generator     | NumGenNet  | 3.1 M 
1 | discriminator | NumDiscNet | 2.6 M 
---------------------------------------------
5.7 M     Trainable params
0         Non-trainable params
5.7 M     Total params
22.873    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

torch.Size([32, 100, 1, 1])
torch.Size([32, 10, 1, 1])


RuntimeError: Calculated padded input size per channel: (3 x 3). Kernel size: (4 x 4). Kernel size can't be greater than actual input size