In [15]:
import torchvision
from pathlib import Path
import matplotlib.pyplot as plt

import argparse
import os
import numpy as np
import math
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

from torch.utils.data import Dataset

from natsort import natsorted

from PIL import Image

# cuda is available
if torch.cuda.is_available():
    print('cuda is available')
    # get the number of GPUs
    print('number of GPUs:', torch.cuda.device_count())
    # get the name of the GPU
    print('name of the GPU:', torch.cuda.get_device_name(0))
    # get the current device
    print('current device:', torch.cuda.current_device())

    cuda = True

torch.set_float32_matmul_precision("medium")

cuda is available
number of GPUs: 1
name of the GPU: NVIDIA GeForce RTX 3080
current device: 0


In [16]:
# https://www.kaggle.com/datasets/potatohd404/ffhq-128-70k/data

import os
import glob

files = glob.glob('/images/*.png')
for f in files:
    os.remove(f)

In [17]:
n_epochs = 200
batch_size = 64
lr = 0.0001
n_cpu = 8
latent_dim = 128
img_size = 128
channels = 3
n_critic = 5
clip_value = 0.01
sample_interval = 400

img_shape = (channels, img_size, img_size)




In [18]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity




In [19]:
class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.images = [f for f in os.listdir(img_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        return image

In [20]:
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = CustomImageDataset(img_dir="archive", transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)

In [21]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()

# Optimizers
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=lr)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=lr)

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

batches_done = 0
for epoch in range(n_epochs):

    for i, imgs in enumerate(dataloader):

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))

        # Generate a batch of images
        fake_imgs = generator(z).detach()
        # Adversarial loss
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        loss_D.backward()
        optimizer_D.step()

        # Clip weights of discriminator
        for p in discriminator.parameters():
            p.data.clamp_(-clip_value, clip_value)

        # Train the generator every n_critic iterations
        if i % n_critic == 0:

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Generate a batch of images
            gen_imgs = generator(z)
            # Adversarial loss
            loss_G = -torch.mean(discriminator(gen_imgs))

            loss_G.backward()
            optimizer_G.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [Total loss: %f]"
                % (epoch, n_epochs, batches_done % len(dataloader), len(dataloader), loss_D.item(), loss_G.item(), loss_D.item() + loss_G.item())
            )

        if batches_done % sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
        batches_done += 1

[Epoch 0/200] [Batch 0/1094] [D loss: -0.028886] [G loss: 0.019589] [Total loss: -0.009297]
[Epoch 0/200] [Batch 5/1094] [D loss: -24.168478] [G loss: -1.444064] [Total loss: -25.612542]
[Epoch 0/200] [Batch 10/1094] [D loss: -68.242561] [G loss: -11.810064] [Total loss: -80.052626]
[Epoch 0/200] [Batch 15/1094] [D loss: -121.629654] [G loss: -34.796505] [Total loss: -156.426159]
[Epoch 0/200] [Batch 20/1094] [D loss: -135.212357] [G loss: -68.888519] [Total loss: -204.100876]
[Epoch 0/200] [Batch 25/1094] [D loss: -154.154877] [G loss: -109.747307] [Total loss: -263.902184]
[Epoch 0/200] [Batch 30/1094] [D loss: -143.663239] [G loss: -157.814957] [Total loss: -301.478195]
[Epoch 0/200] [Batch 35/1094] [D loss: -144.752289] [G loss: -208.203415] [Total loss: -352.955704]
[Epoch 0/200] [Batch 40/1094] [D loss: -81.528412] [G loss: -240.032806] [Total loss: -321.561218]
[Epoch 0/200] [Batch 45/1094] [D loss: -20.135925] [G loss: -258.191803] [Total loss: -278.327728]
[Epoch 0/200] [Batch