In [None]:
import torch
import torchvision
import os
import matplotlib.pyplot as plt
import random
%matplotlib inline

In [None]:
# Seed every RNG this notebook draws from: `random` picks the preview images,
# while torch drives weight initialisation, latent noise and DataLoader shuffling.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA generators

# Only a cell's last expression is displayed, so report the device name and
# compute capability together. Guarded so the cell also runs on CPU-only hosts,
# where the torch.cuda device queries raise.
if torch.cuda.is_available():
    gpu_info = (torch.cuda.get_device_name(), torch.cuda.get_device_capability())
else:
    gpu_info = "CUDA not available; running on CPU"
gpu_info

In [None]:
# Resize the shorter side to 128 px, centre-crop to 128x128 and convert to a
# float tensor; ToTensor yields pixel values in [0, 1].
# NOTE(review): the Generator ends in Tanh ([-1, 1]) while real images stay in
# [0, 1], which gives the discriminator a trivial cue. Adding
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) would align the ranges, but the
# preview and sampling cells would then have to map images back to [0, 1]
# before plotting, so that change needs to be made in all of them together.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(128),
    torchvision.transforms.CenterCrop(128),
    torchvision.transforms.ToTensor(),
])

dataset = torchvision.datasets.ImageFolder("/kaggle/input/celeba-dataset/img_align_celeba", transform = transform)

# Never start more loader processes than there are CPUs: extra workers only
# add contention, and DataLoader warns about the oversubscription.
num_workers = min(8, os.cpu_count() or 1)
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=100,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,  # every batch holds exactly batch_size images
)
len(dataset)

In [None]:
# Show four random training images to sanity-check loading, resizing and cropping.
figure, axes = plt.subplots(2, 2)
for ax in axes.flat:
    # Draw the index from the whole dataset, not just its first few files.
    image, _ = dataset[random.randrange(len(dataset))]
    ax.imshow(image.permute(1, 2, 0))  # CHW tensor -> HWC layout for imshow
    ax.axis("off")
plt.show()

In [None]:
def weights_init(m):
    """DCGAN-style initialiser, meant to be used as ``model.apply(weights_init)``.

    - Modules whose class name contains ``Conv``: weight ~ N(0, 0.02); bias untouched.
    - Modules whose class name contains ``BatchNorm``: weight ~ N(1, 0.02), bias = 0.
    - Every other module is left unchanged.

    Parameters that do not exist (e.g. ``affine=False`` BatchNorm, or a custom
    container whose name merely contains "Conv") are skipped instead of raising.

    Args:
        m: the module currently visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    # Name matching covers every Conv*/ConvTranspose*/BatchNorm* variant.
    if 'Conv' in classname:
        weight = getattr(m, 'weight', None)
        if weight is not None:
            # nn.init functions already run under no_grad; no `.data` needed.
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias, 0)

In [None]:
# Training length and hardware: use the first CUDA device when one is usable.
num_epochs = 10
ngpu = torch.cuda.device_count()
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
class Generator(torch.nn.Module):
    """Maps latent noise of shape (N, 100, 1, 1) to images of shape (N, 3, 128, 128).

    A stride-1 transposed convolution (kernel 4, no padding) expands the 1x1
    input to 4x4. Five stride-2 transposed convolutions (kernel 4, padding 1)
    then double the spatial size each time: 4 -> 8 -> 16 -> 32 -> 64 -> 128.
    Hidden layers keep 64 channels and use BatchNorm + ReLU. The output layer
    has 3 channels and a Tanh, so pixel values lie in [-1, 1].

    Args:
        ngpu: number of GPUs; stored on the instance, not used by ``forward``.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        def up_block(in_ch, stride, padding):
            # The repeated hidden unit: transposed conv -> BatchNorm -> ReLU.
            return [
                torch.nn.ConvTranspose2d(in_ch, 64, kernel_size=4, stride=stride, padding=padding, output_padding=0, bias=True),
                torch.nn.BatchNorm2d(64),
                torch.nn.ReLU(inplace=True),
            ]

        layers = up_block(100, stride=1, padding=0)  # 1x1 -> 4x4
        for _ in range(4):  # 4 -> 8 -> 16 -> 32 -> 64
            layers += up_block(64, stride=2, padding=1)
        layers += [
            torch.nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True),  # 64 -> 128
            torch.nn.Tanh(),
        ]
        self.main = torch.nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [None]:
# Build the generator on the target device; spread it across GPUs when more than one is usable.
generator = Generator(ngpu).to(device)
if ngpu > 1 and device.type == 'cuda':
    generator = torch.nn.DataParallel(generator, device_ids=list(range(ngpu)))
# apply() runs weights_init on every submodule and returns the module itself,
# which is then printed.
print(generator.apply(weights_init))

In [None]:
class Discriminator(torch.nn.Module):
    """Convolutional real/fake classifier for 3x128x128 images.

    Six unpadded 3x3 stride-2 convolutions shrink a 128x128 input through
    63 -> 31 -> 15 -> 7 -> 3 -> 1 while widening channels 3 -> 2048. Two 1x1
    convolutions then reduce this to one score, which a sigmoid turns into a
    probability, so ``forward`` returns shape (N, 1, 1, 1) for 128x128 input.

    Args:
        ngpu: number of GPUs; stored on the instance, not used by ``forward``.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        # NOTE: the ReLU after the first conv shifts the Sequential indices, so
        # state_dicts saved from the earlier layout (without it) will not load.
        self.main = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels = 3, out_channels = 64, kernel_size = 3, stride = 2, bias = True),
            # Without this activation the first two convolutions compose into a
            # single linear map; every other hidden conv is already followed by one.
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, stride = 2, bias = True),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = 3, stride = 2, bias = True),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_channels = 256, out_channels = 512, kernel_size = 3, stride = 2, bias = True),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_channels = 512, out_channels = 1024, kernel_size = 3, stride = 2, bias = True),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_channels = 1024, out_channels = 2048, kernel_size = 3, stride = 2, bias = True),
            torch.nn.ReLU(inplace=True),
            # 1x1 head: 2048 -> 256 -> 1 channel.
            torch.nn.Conv2d(in_channels = 2048, out_channels = 256, kernel_size = 1, bias = True),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_channels = 256, out_channels = 1, kernel_size = 1),
            # Probabilities in (0, 1), as required by the BCELoss criterion.
            torch.nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

In [None]:
# Build the discriminator on the target device; spread it across GPUs when more than one is usable.
discriminator = Discriminator(ngpu).to(device)
if ngpu > 1 and device.type == 'cuda':
    discriminator = torch.nn.DataParallel(discriminator, device_ids=list(range(ngpu)))
# apply() runs weights_init on every submodule and returns the module itself,
# which is then printed.
print(discriminator.apply(weights_init))

In [None]:
# Binary cross-entropy on the discriminator's sigmoid outputs. Both networks
# use Adam with the same hyper-parameters (lr 0.0002, betas (0.5, 0.999)).
criterion = torch.nn.BCELoss()
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)
optimizer_generator = torch.optim.Adam(generator.parameters(), **adam_kwargs)

In [None]:
# Number of latent channels fed to the generator; must equal the in_channels
# of the Generator's first ConvTranspose2d (100).
LATENT_DIM = 100

for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(data_loader):
        real_images = real_images.to(device)
        # Size labels and noise from the actual batch rather than repeating the
        # DataLoader's batch_size as a literal.
        batch_size = real_images.size(0)
        # BCELoss needs floating-point targets. torch.full((n,), 1) infers an
        # int64 tensor, which the loss rejects, so build float labels explicitly.
        real_labels = torch.ones(batch_size, device=device)
        fake_labels = torch.zeros(batch_size, device=device)

        # Discriminator step: BCE with label 1 on real and 0 on generated images.
        discriminator.zero_grad()
        output = discriminator(real_images).view(-1)
        error_discriminator_real = criterion(output, real_labels)
        error_discriminator_real.backward()

        noise = torch.randn(batch_size, LATENT_DIM, 1, 1, device=device)
        fake = generator(noise)
        # detach(): the discriminator update must not backpropagate into the generator.
        output = discriminator(fake.detach()).view(-1)
        error_discriminator_fake = criterion(output, fake_labels)
        error_discriminator_fake.backward()
        error_discriminator = error_discriminator_real + error_discriminator_fake  # for logging only
        optimizer_discriminator.step()

        # Generator step: score the fakes against label 1 (non-saturating loss).
        # This backward also fills the discriminator's .grad; those are cleared
        # by discriminator.zero_grad() at the start of the next iteration.
        generator.zero_grad()
        output = discriminator(fake).view(-1)
        error_generator = criterion(output, real_labels)
        error_generator.backward()
        optimizer_generator.step()

        if i % 50 == 0:
            print('[%3d/%3d][%4d/%4d]\tLoss_D: %.4f\tLoss_G: %.4f' % (epoch, num_epochs, i, len(data_loader), error_discriminator.item(), error_generator.item()))

In [None]:
# Unwrap nn.DataParallel (used when several GPUs are present) so the saved keys
# carry no "module." prefix and load directly into a plain Discriminator.
discriminator_to_save = discriminator.module if isinstance(discriminator, torch.nn.DataParallel) else discriminator
torch.save(discriminator_to_save.state_dict(), "/kaggle/working/discriminator")

In [None]:
# Unwrap nn.DataParallel (used when several GPUs are present) so the saved keys
# carry no "module." prefix and load directly into a plain Generator.
generator_to_save = generator.module if isinstance(generator, torch.nn.DataParallel) else generator
torch.save(generator_to_save.state_dict(), "/kaggle/working/generator")

In [None]:
def show_generated_samples(num_samples=11, latent_dim=100, ncols=4):
    """Plot ``num_samples`` images drawn from the trained ``generator`` in a grid.

    Each image is generated from its own (1, latent_dim, 1, 1) noise vector.

    Args:
        num_samples: number of images to draw (>= 1).
        latent_dim: noise channels; must match the Generator's input channels (100).
        ncols: maximum number of images per grid row (>= 1).

    Raises:
        ValueError: if ``num_samples`` or ``ncols`` is less than 1.
    """
    if num_samples < 1 or ncols < 1:
        raise ValueError("num_samples and ncols must be >= 1")
    ncols = min(ncols, num_samples)
    nrows = -(-num_samples // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(2.5 * ncols, 2.5 * nrows), squeeze=False)
    # NOTE(review): the generator is left in its current (training) mode, so
    # BatchNorm normalises each single-image batch with its own statistics and
    # still updates its running stats; call generator.eval() first if the
    # running statistics should be used instead.
    with torch.no_grad():  # inference only: no autograd graph needed
        for index, ax in enumerate(axes.flat):
            ax.axis("off")
            if index >= num_samples:
                continue  # unused slot in the last grid row
            noise = torch.randn(1, latent_dim, 1, 1, device=device)
            image = generator(noise)[0].cpu().permute(1, 2, 0)  # CHW -> HWC
            # Training images are ToTensor output in [0, 1]; clamp so imshow
            # receives that range explicitly instead of clipping with a warning.
            ax.imshow(image.clamp(0, 1))
    plt.show()


show_generated_samples()