For this project we will be using the Pokémon dataset. It contains ~40 images for each of 1,000 Pokémon species, organized into one subdirectory per class. Each image is resized to 128x128 pixels and stored as a PNG file.

This dataset is available here:

https://www.kaggle.com/datasets/noodulz/pokemon-dataset-1000/data

Main objective of our project

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [2]:
# Side length, in pixels, that every input image is resized to.
IMG_SIZE = 128
# Number of samples per mini-batch.
BATCH_SIZE = 64
# Number of channels of the latent noise vector fed to the generator.
LATENT_DIM = 100
# Number of full passes over the dataset during training.
EPOCHS = 50
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

Let's begin with the first task - creating a convolutional neural network for Pokémon image classification

In [None]:
# Preprocessing pipeline applied to every image loaded from disk.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        # ToTensor divides by 255, giving values in [0, 1].
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel maps [0, 1] onto [-1, 1]; the generator
        # ends in tanh, which produces values in that same range.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [None]:
# Root folder of the extracted Kaggle archive; images are read through
# ImageFolder with the preprocessing pipeline defined in `transform`.
data_dir = "data/pokemon-dataset-1000/New folder/dataset"
dataset = datasets.ImageFolder(data_dir, transform=transform)

# Reshuffled mini-batches, loaded by two background worker processes.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

We will instantiate image generators to process our Pokémon. There are three generators for training, testing and validation respectively.

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: latent tensor in, 3-channel image out.

    A latent input of shape (batch, latent_dim, 1, 1) is expanded to an
    8x8 feature map and then doubled in resolution four times, ending at
    (batch, 3, 128, 128) with values squashed into [-1, 1] by tanh.

    Args:
        latent_dim: number of channels of the latent input ``z``.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Stem: an 8x8 transposed conv on a 1x1 input yields 128 feature
        # maps of size 8x8 for every sample in the batch.
        layers = [
            nn.ConvTranspose2d(latent_dim, 128, kernel_size=8, stride=1,
                               padding=0, bias=False),  # (128, 8, 8)
            nn.BatchNorm2d(128),
            nn.ReLU(True),
        ]

        # Each stage (kernel 4, stride 2, padding 1) doubles height and
        # width: 8 -> 16 -> 32 -> 64.
        for c_in, c_out in ((128, 64), (64, 32), (32, 16)):
            layers.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2,
                                   padding=1, bias=False)
            )
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))

        # Head: final doubling to 128x128 RGB; no batch norm before tanh.
        layers.append(
            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1,
                               bias=False)  # (3, 128, 128)
        )
        layers.append(nn.Tanh())

        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Return the generated image batch for latent input ``z``."""
        images = self.net(z)
        return images

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring images as real (1) or fake (0).

    Four stride-2 convolutions halve a 3x128x128 input down to 512x8x8,
    after which an 8x8 convolution and a sigmoid reduce it to a single
    probability per image.
    """

    def __init__(self):
        super().__init__()

        # First stage deliberately has no batch norm: on the first layer
        # it is important to preserve the distribution of the real images.
        stages = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1,
                      bias=False),  # (64, 64, 64)
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Kernel 4 / stride 2 / padding 1 halves height and width:
        # 64 -> 32 -> 16 -> 8.
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            stages.extend(
                [
                    nn.Conv2d(c_in, c_out, kernel_size=4, stride=2,
                              padding=1, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU(0.2, inplace=True),
                ]
            )

        # An 8x8 kernel over the 8x8 map leaves one value, turned into a
        # probability by the sigmoid.
        stages.extend(
            [
                nn.Conv2d(512, 1, kernel_size=8, stride=1, padding=0,
                          bias=False),  # (1, 1, 1)
                nn.Sigmoid(),
            ]
        )

        self.net = nn.Sequential(*stages)

    def forward(self, img):
        """Return a (batch, 1) tensor of real-vs-fake probabilities."""
        scores = self.net(img)
        return scores.view(-1, 1)

In [None]:
def weights_init(m):
    """Initialise one module in place; intended for ``model.apply(weights_init)``.

    Modules are selected by class name:

    * name contains ``"Conv"``: weight drawn from N(0, 0.02)
    * name contains ``"BatchNorm"``: weight drawn from N(1, 0.02), bias set to 0

    ``apply`` visits every sub-module, including containers and wrappers.
    A module whose name matches but which has no ``weight`` / ``bias``
    (for example a user-defined ``ConvBlock`` container, or a batch-norm
    layer built with ``affine=False``) is skipped instead of raising
    ``AttributeError``.

    Args:
        m: the module to initialise.
    """
    classname = type(m).__name__
    # Missing attributes and parameters registered as None both yield None.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
# Build both networks, each followed by a move to the selected device.
generator = Generator(LATENT_DIM)
generator = generator.to(DEVICE)
discriminator = Discriminator()
discriminator = discriminator.to(DEVICE)

In [None]:
# Re-initialise both networks: `apply` calls weights_init on every
# sub-module, replacing PyTorch's default initialisation.
generator.apply(weights_init)
discriminator.apply(weights_init)

In [None]:
# The discriminator ends in a sigmoid, so its output is already a
# probability and plain binary cross-entropy applies.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same hyper-parameters.
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)

# A fixed batch of 16 latent vectors used to render sample images.
fixed_noise = torch.randn((16, LATENT_DIM, 1, 1), device=DEVICE)

In [None]:
# Adversarial training: for every mini-batch, update the discriminator
# once and then the generator once.
for epoch in range(EPOCHS):
    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = imgs.to(DEVICE)
        # Read the size from the batch itself rather than BATCH_SIZE.
        batch_size = real_imgs.shape[0]

        # BCE targets: 1 marks a real image, 0 a generated one.
        valid = torch.ones((batch_size, 1), device=DEVICE)
        fake = torch.zeros_like(valid)

        # ------------------
        # Train the discriminator
        # ------------------
        optimizer_D.zero_grad()

        real_pred = discriminator(real_imgs)
        real_loss = criterion(real_pred, valid)

        z = torch.randn((batch_size, LATENT_DIM, 1, 1), device=DEVICE)
        gen_imgs = generator(z)
        # detach(): this step must not push gradients into the generator.
        fake_pred = discriminator(gen_imgs.detach())
        fake_loss = criterion(fake_pred, fake)

        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # ------------------
        # Train the generator
        # ------------------
        optimizer_G.zero_grad()

        # The generator is rewarded when its images are scored as real.
        fooled_pred = discriminator(gen_imgs)
        g_loss = criterion(fooled_pred, valid)
        g_loss.backward()
        optimizer_G.step()

        # Report progress on every 100th batch only.
        if i % 100 != 0:
            continue
        print(f"[Epoch {epoch}/{EPOCHS}] [Batch {i}/{len(dataloader)}] "
              f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

In [None]:
# Render the 16 fixed latent vectors as a 4x4 grid and save it as a PNG
# named after the last completed epoch.
with torch.no_grad():
    gen_imgs = generator(fixed_noise).cpu()

# De-normalise: map the tanh range [-1, 1] back to [0, 1] for display.
gen_imgs = (gen_imgs + 1) / 2

plt.figure(figsize=(8, 8))
for k in range(16):
    # imshow expects channels last: (C, H, W) -> (H, W, C).
    picture = gen_imgs[k].permute(1, 2, 0).numpy()
    plt.subplot(4, 4, k + 1)
    plt.imshow(picture)
    plt.axis("off")
plt.savefig(f"generated_epoch_{epoch+1}.png")
plt.close()