In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import os

In [None]:
!pip install opendatasets --quiet

In [None]:
import opendatasets as od

In [None]:
# Download the FFHQ face dataset from Kaggle. opendatasets may prompt interactively for a
# Kaggle username and API key, and extracts the archive into ./ffhq-face-data-set.
dataset_url = 'https://www.kaggle.com/datasets/greatgamedota/ffhq-face-data-set'
od.download(dataset_url)

Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
Your Kaggle username: rajpriyesh
Your Kaggle Key: ··········
Dataset URL: https://www.kaggle.com/datasets/greatgamedota/ffhq-face-data-set
Downloading ffhq-face-data-set.zip to ./ffhq-face-data-set


100%|██████████| 1.97G/1.97G [01:38<00:00, 21.5MB/s]





In [None]:
class FaceDataset(Dataset):
    """Map-style dataset over a flat directory of image files.

    Every regular file directly inside ``root_dir`` whose extension (compared
    case-insensitively) is one of ``extensions`` becomes one sample. File names
    are sorted so that index ``i`` refers to the same image on every run and on
    every filesystem (``os.listdir`` returns entries in arbitrary order).
    """

    def __init__(self, root_dir, transform=None, extensions=('.jpg', '.jpeg', '.png')):
        """
        Initialize the FaceDataset.

        Args:
            root_dir (str): Directory containing the image files (not searched recursively).
            transform (callable, optional): Applied to each loaded RGB PIL image in
                ``__getitem__``.
            extensions (str or tuple of str, optional): Accepted file extensions, matched
                case-insensitively. Defaults to ``('.jpg', '.jpeg', '.png')``.
        """
        self.root_dir = root_dir
        self.transform = transform
        if isinstance(extensions, str):
            # A bare string would otherwise be iterated character by character.
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        self.image_files = sorted(
            name for name in os.listdir(root_dir)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: The number of images in the dataset.
        """
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Load the image at the given index.

        Args:
            idx (int): Index into the sorted list of image files.

        Returns:
            The image converted to RGB: a ``PIL.Image.Image`` when no transform was
            given, otherwise whatever ``transform`` returns for it.
        """
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        # The context manager closes the file; convert() returns a new, fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: latent noise (N, latent_dim, 1, 1) -> images (N, C, 64, 64) in [-1, 1]."""

    def __init__(self, latent_dim, base_channels=64, out_channels=3):
        """
        Initialize the Generator network.

        The stack is four ConvTranspose2d/BatchNorm2d/ReLU blocks with widths
        ``8, 4, 2, 1 x base_channels``, followed by a final ConvTranspose2d and Tanh.
        The spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64. The layer order, and hence
        the ``main.<index>`` state-dict keys, is identical for any argument values.

        Args:
            latent_dim (int): The dimension of the latent vector.
            base_channels (int, optional): Width of the last hidden block. Defaults to 64.
            out_channels (int, optional): Channels of the generated image. Defaults to 3 (RGB).
        """
        super().__init__()
        widths = [base_channels * 8, base_channels * 4, base_channels * 2, base_channels]
        layers = []
        in_channels = latent_dim
        for block_idx, width in enumerate(widths):
            # The first block projects 1x1 -> 4x4; every later block doubles H and W.
            stride, padding = (1, 0) if block_idx == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(in_channels, width, 4, stride, padding, bias=False),
                nn.BatchNorm2d(width),
                nn.ReLU(True),
            ]
            in_channels = width
        layers += [
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.Tanh(),  # matches images normalized to [-1, 1]
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map latent noise of shape (N, latent_dim, 1, 1) to images of shape (N, out_channels, 64, 64)."""
        return self.main(input)

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator: images (N, C, 64, 64) -> one raw logit per image, shape (N,)."""

    def __init__(self, in_channels=3, base_channels=64):
        """
        Initialize the Discriminator network.

        The stack is four stride-2 Conv2d blocks with widths ``1, 2, 4, 8 x base_channels``
        and LeakyReLU(0.2). Every block except the first also has BatchNorm2d. A final
        4x4 Conv2d produces a single logit. There is no sigmoid, so pair the logits with
        ``BCEWithLogitsLoss``. The layer order, and hence the ``main.<index>`` state-dict
        keys, is identical for any argument values.

        Args:
            in_channels (int, optional): Channels of the input image. Defaults to 3 (RGB).
            base_channels (int, optional): Width of the first conv block. Defaults to 64.
        """
        super().__init__()
        widths = [base_channels, base_channels * 2, base_channels * 4, base_channels * 8]
        layers = []
        prev_channels = in_channels
        for block_idx, width in enumerate(widths):
            layers.append(nn.Conv2d(prev_channels, width, 4, 2, 1, bias=False))
            if block_idx > 0:  # the first block has no batch norm
                layers.append(nn.BatchNorm2d(width))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            prev_channels = width
        # For 64x64 inputs the feature map is 4x4 here, so this conv yields 1x1.
        layers.append(nn.Conv2d(prev_channels, 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the flattened logits; shape (N,) for 64x64 inputs."""
        return self.main(input).view(-1)

In [None]:
def train_gan(dataloader, output_dir, num_epochs, latent_dim, lr, device, checkpoint_file=None,
              log_every=50, save_every=10):
    """
    Train a Generator/Discriminator pair with the standard non-saturating GAN loss.

    Args:
        dataloader: Iterable of real-image batches (tensors whose first dimension is the
            batch size). Presumably normalized to [-1, 1] to match the generator's Tanh.
        output_dir (str): Directory for sample grids and checkpoints; created if missing.
        num_epochs (int): Epoch index to stop at (exclusive), also when resuming.
        latent_dim (int): Size of the latent noise vector.
        lr (float): Base learning rate. The generator uses ``0.5 * lr`` and the
            discriminator uses ``0.2 * lr``.
        device (torch.device): Device for the models and all tensors.
        checkpoint_file (str, optional): Checkpoint written by this function to resume
            from. ``None`` starts a fresh run.
        log_every (int): Print the losses every ``log_every`` batches (positive).
        save_every (int): Save a sample grid and a checkpoint every ``save_every`` epochs
            (positive). Both are also saved after the final epoch.

    Returns:
        tuple: ``(generator, discriminator)`` after the last epoch.
    """
    # save_image/torch.save do not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    generator = Generator(latent_dim).to(device)
    discriminator = Discriminator().to(device)

    g_optimizer = optim.Adam(generator.parameters(), lr=lr * 0.5, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr * 0.2, betas=(0.5, 0.999))

    # Stepped once per epoch: both learning rates are multiplied by 0.8 every 100 epochs.
    g_scheduler = StepLR(g_optimizer, step_size=100, gamma=0.8)
    d_scheduler = StepLR(d_optimizer, step_size=100, gamma=0.8)

    # The discriminator returns raw logits, so the sigmoid lives inside the loss.
    criterion = nn.BCEWithLogitsLoss()

    fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)
    real_label = 1.0
    fake_label = 0.0

    start_epoch = 0
    if checkpoint_file:
        # The checkpoint holds only tensors and plain containers, so weights_only loading works.
        checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=True)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
        d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        if 'g_scheduler_state_dict' in checkpoint:
            g_scheduler.load_state_dict(checkpoint['g_scheduler_state_dict'])
            d_scheduler.load_state_dict(checkpoint['d_scheduler_state_dict'])
        else:
            # Older checkpoints lack scheduler state. The restored optimizers already hold the
            # decayed learning rates, so only the epoch counter must be realigned.
            g_scheduler.last_epoch = start_epoch
            d_scheduler.last_epoch = start_epoch
        if 'fixed_noise' in checkpoint:
            # Reuse the saved noise so sample grids stay comparable across resumed runs.
            fixed_noise = checkpoint['fixed_noise'].to(device)
        print(f"Resuming training from epoch {start_epoch}")

    for epoch in range(start_epoch, num_epochs):
        for i, real_images in enumerate(dataloader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            real_targets = torch.full((batch_size,), real_label, device=device)
            fake_targets = torch.full((batch_size,), fake_label, device=device)

            # Train Discriminator: real images -> 1, generated images -> 0.
            d_optimizer.zero_grad()
            d_loss_real = criterion(discriminator(real_images), real_targets)
            d_loss_real.backward()

            noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
            fake_images = generator(noise)
            # detach(): the discriminator update must not backpropagate into the generator.
            d_loss_fake = criterion(discriminator(fake_images.detach()), fake_targets)
            d_loss_fake.backward()

            d_loss = d_loss_real + d_loss_fake
            d_optimizer.step()

            # Train Generator: push the updated discriminator to label the fakes as real.
            g_optimizer.zero_grad()
            g_loss = criterion(discriminator(fake_images), real_targets)
            g_loss.backward()
            g_optimizer.step()

            if i % log_every == 0:
                print(f'Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] '
                      f'D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}')

        g_scheduler.step()
        d_scheduler.step()

        if epoch % save_every == 0 or epoch == num_epochs - 1:
            with torch.no_grad():
                samples = generator(fixed_noise).cpu()
            vutils.save_image(samples, os.path.join(output_dir, f'fake_images_epoch_{epoch}.png'),
                              normalize=True)
            torch.save({
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'g_optimizer_state_dict': g_optimizer.state_dict(),
                'd_optimizer_state_dict': d_optimizer.state_dict(),
                'g_scheduler_state_dict': g_scheduler.state_dict(),
                'd_scheduler_state_dict': d_scheduler.state_dict(),
                'fixed_noise': fixed_noise.cpu(),
                'epoch': epoch,
            }, os.path.join(output_dir, f'checkpoint_epoch_{epoch}.pt'))

    return generator, discriminator


In [None]:
# Set parameters
# od.download() extracts the Kaggle archive into ./ffhq-face-data-set (relative to the working
# directory), so build the path the same way instead of hard-coding Colab's /content prefix.
data_path = os.path.join('ffhq-face-data-set', 'thumbnails128x128')
output_dir = 'generated_images'
image_size = 64  # the Generator/Discriminator layer stacks are sized for 64x64 images
batch_size = 64
num_epochs = 1000
latent_dim = 100
lr = 0.0002

# Seed PyTorch's RNGs (weight init, noise, shuffling, augmentations); cuDNN kernels may
# still introduce nondeterminism on GPU.
seed = 42
torch.manual_seed(seed)

In [None]:
# Define data transformations
# These augmentations are applied to real images only, so the generator learns to imitate
# whatever they produce. RandomRotation is deliberately not used: it fills the uncovered
# corners with black (fill=0 by default), which the generator would learn as an artifact.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),  # a mirrored face is still a plausible face
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    # Map [0, 1] to [-1, 1], the output range of the generator's final Tanh.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [None]:
# Create dataset and dataloader
dataset = FaceDataset(data_path, transform=transform)
# With drop_last=True, fewer than batch_size images would give zero batches per epoch and
# training would silently do nothing, so fail fast with a clear message instead.
assert len(dataset) >= batch_size, (
    f"Found {len(dataset)} images in {data_path!r}; need at least batch_size={batch_size}"
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only runtimes.
    pin_memory=torch.cuda.is_available(),
    drop_last=True,
)

In [None]:
# Set device: run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Start training: a fresh run with every hyperparameter taken from the configuration cell.
generator, discriminator = train_gan(
    device=device,
    dataloader=dataloader,
    output_dir=output_dir,
    latent_dim=latent_dim,
    num_epochs=num_epochs,
    lr=lr,
)

print("Training completed.")

Epoch [0/1000] Batch [0/1093] D_loss: 1.3044 G_loss: 1.0830
Epoch [0/1000] Batch [50/1093] D_loss: 0.4579 G_loss: 3.7086
Epoch [0/1000] Batch [100/1093] D_loss: 0.4196 G_loss: 4.3230
Epoch [0/1000] Batch [150/1093] D_loss: 0.2960 G_loss: 4.4512
Epoch [0/1000] Batch [200/1093] D_loss: 0.2229 G_loss: 4.9200
Epoch [0/1000] Batch [250/1093] D_loss: 0.2307 G_loss: 5.2754
Epoch [0/1000] Batch [300/1093] D_loss: 0.1623 G_loss: 5.5245
Epoch [0/1000] Batch [350/1093] D_loss: 0.5067 G_loss: 2.4484
Epoch [0/1000] Batch [400/1093] D_loss: 0.5744 G_loss: 4.3721
Epoch [0/1000] Batch [450/1093] D_loss: 0.5262 G_loss: 4.5842
Epoch [0/1000] Batch [500/1093] D_loss: 0.2454 G_loss: 3.8514
Epoch [0/1000] Batch [550/1093] D_loss: 0.6492 G_loss: 2.5806
Epoch [0/1000] Batch [600/1093] D_loss: 0.1841 G_loss: 4.4133
Epoch [0/1000] Batch [650/1093] D_loss: 0.4006 G_loss: 3.6624
Epoch [0/1000] Batch [700/1093] D_loss: 0.4514 G_loss: 3.2191
Epoch [0/1000] Batch [750/1093] D_loss: 0.3636 G_loss: 2.6085
Epoch [0/10

In [None]:
def train_gan(dataloader, output_dir, num_epochs, latent_dim, lr, device, checkpoint_file=None,
              log_every=50, save_every=10):
    """
    Train a Generator/Discriminator pair with the standard non-saturating GAN loss.

    Args:
        dataloader: Iterable of real-image batches (tensors whose first dimension is the
            batch size). Presumably normalized to [-1, 1] to match the generator's Tanh.
        output_dir (str): Directory for sample grids and checkpoints; created if missing.
        num_epochs (int): Epoch index to stop at (exclusive), also when resuming.
        latent_dim (int): Size of the latent noise vector.
        lr (float): Base learning rate. The generator uses ``0.5 * lr`` and the
            discriminator uses ``0.2 * lr``.
        device (torch.device): Device for the models and all tensors.
        checkpoint_file (str, optional): Checkpoint written by this function to resume
            from. ``None`` starts a fresh run.
        log_every (int): Print the losses every ``log_every`` batches (positive).
        save_every (int): Save a sample grid and a checkpoint every ``save_every`` epochs
            (positive). Both are also saved after the final epoch.

    Returns:
        tuple: ``(generator, discriminator)`` after the last epoch.
    """
    # save_image/torch.save do not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    generator = Generator(latent_dim).to(device)
    discriminator = Discriminator().to(device)

    g_optimizer = optim.Adam(generator.parameters(), lr=lr * 0.5, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr * 0.2, betas=(0.5, 0.999))

    # Stepped once per epoch: both learning rates are multiplied by 0.8 every 100 epochs.
    g_scheduler = StepLR(g_optimizer, step_size=100, gamma=0.8)
    d_scheduler = StepLR(d_optimizer, step_size=100, gamma=0.8)

    # The discriminator returns raw logits, so the sigmoid lives inside the loss.
    criterion = nn.BCEWithLogitsLoss()

    fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)
    real_label = 1.0
    fake_label = 0.0

    start_epoch = 0
    if checkpoint_file:
        # The checkpoint holds only tensors and plain containers, so weights_only loading works.
        checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=True)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
        d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        if 'g_scheduler_state_dict' in checkpoint:
            g_scheduler.load_state_dict(checkpoint['g_scheduler_state_dict'])
            d_scheduler.load_state_dict(checkpoint['d_scheduler_state_dict'])
        else:
            # Older checkpoints lack scheduler state. The restored optimizers already hold the
            # decayed learning rates, so only the epoch counter must be realigned.
            g_scheduler.last_epoch = start_epoch
            d_scheduler.last_epoch = start_epoch
        if 'fixed_noise' in checkpoint:
            # Reuse the saved noise so sample grids stay comparable across resumed runs.
            fixed_noise = checkpoint['fixed_noise'].to(device)
        print(f"Resuming training from epoch {start_epoch}")

    for epoch in range(start_epoch, num_epochs):
        for i, real_images in enumerate(dataloader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            real_targets = torch.full((batch_size,), real_label, device=device)
            fake_targets = torch.full((batch_size,), fake_label, device=device)

            # Train Discriminator: real images -> 1, generated images -> 0.
            d_optimizer.zero_grad()
            d_loss_real = criterion(discriminator(real_images), real_targets)
            d_loss_real.backward()

            noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
            fake_images = generator(noise)
            # detach(): the discriminator update must not backpropagate into the generator.
            d_loss_fake = criterion(discriminator(fake_images.detach()), fake_targets)
            d_loss_fake.backward()

            d_loss = d_loss_real + d_loss_fake
            d_optimizer.step()

            # Train Generator: push the updated discriminator to label the fakes as real.
            g_optimizer.zero_grad()
            g_loss = criterion(discriminator(fake_images), real_targets)
            g_loss.backward()
            g_optimizer.step()

            if i % log_every == 0:
                print(f'Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] '
                      f'D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}')

        g_scheduler.step()
        d_scheduler.step()

        if epoch % save_every == 0 or epoch == num_epochs - 1:
            with torch.no_grad():
                samples = generator(fixed_noise).cpu()
            vutils.save_image(samples, os.path.join(output_dir, f'fake_images_epoch_{epoch}.png'),
                              normalize=True)
            torch.save({
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'g_optimizer_state_dict': g_optimizer.state_dict(),
                'd_optimizer_state_dict': d_optimizer.state_dict(),
                'g_scheduler_state_dict': g_scheduler.state_dict(),
                'd_scheduler_state_dict': d_scheduler.state_dict(),
                'fixed_noise': fixed_noise.cpu(),
                'epoch': epoch,
            }, os.path.join(output_dir, f'checkpoint_epoch_{epoch}.pt'))

    return generator, discriminator

In [None]:
# Start training or resume from checkpoint.
# Set checkpoint_file to None for a fresh run; otherwise training continues with the epoch
# after the one stored in the checkpoint.
checkpoint_file = '/content/checkpoint_epoch_830.pt'
run_kwargs = dict(
    dataloader=dataloader,
    output_dir=output_dir,
    num_epochs=num_epochs,
    latent_dim=latent_dim,
    lr=lr,
    device=device,
    checkpoint_file=checkpoint_file,
)
generator, discriminator = train_gan(**run_kwargs)

print("Training completed.")

Resuming training from epoch 831


  checkpoint = torch.load(checkpoint_file, map_location=device)


Epoch [831/1000] Batch [0/1093] D_loss: 0.0025 G_loss: 10.1432
Epoch [831/1000] Batch [50/1093] D_loss: 0.0095 G_loss: 12.2694
Epoch [831/1000] Batch [100/1093] D_loss: 0.0112 G_loss: 6.7801
Epoch [831/1000] Batch [150/1093] D_loss: 0.0242 G_loss: 9.7713
Epoch [831/1000] Batch [200/1093] D_loss: 1.8219 G_loss: 5.0339
Epoch [831/1000] Batch [250/1093] D_loss: 0.0158 G_loss: 8.3451
Epoch [831/1000] Batch [300/1093] D_loss: 0.0398 G_loss: 8.7841
Epoch [831/1000] Batch [350/1093] D_loss: 0.0199 G_loss: 7.6603
Epoch [831/1000] Batch [400/1093] D_loss: 0.0165 G_loss: 7.4633
Epoch [831/1000] Batch [450/1093] D_loss: 0.0009 G_loss: 10.6555
Epoch [831/1000] Batch [500/1093] D_loss: 0.0141 G_loss: 7.8914
Epoch [831/1000] Batch [550/1093] D_loss: 0.0691 G_loss: 7.5628
Epoch [831/1000] Batch [600/1093] D_loss: 0.0051 G_loss: 9.9967
Epoch [831/1000] Batch [650/1093] D_loss: 0.0120 G_loss: 5.5973
Epoch [831/1000] Batch [700/1093] D_loss: 0.0179 G_loss: 8.0801
Epoch [831/1000] Batch [750/1093] D_loss