In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import os
from PIL import Image
from torchvision.transforms import ToPILImage


In [2]:
class Discriminator(nn.Module):
    """Fully connected discriminator that scores images as real (1) or fake (0).

    Args:
        img_dim: Number of features in one flattened image. Defaults to
            ``3 * 64 * 64``, i.e. a 64x64 RGB image.
    """

    def __init__(self, img_dim=3 * 64 * 64):
        super().__init__()
        # The layer order inside ``self.disc`` is part of the state_dict key
        # layout ("disc.0.weight", ...), so it is kept exactly as before.
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # squashes the score to a probability in [0, 1]
        )

    def forward(self, img):
        """Return the probability that each input image is real.

        Args:
            img: Either already-flattened images of shape ``(N, img_dim)``
                (or a single ``(img_dim,)`` vector), or un-flattened image
                batches such as ``(N, C, H, W)``, which are flattened here.

        Returns:
            Tensor with one sigmoid-activated score per image
            (shape ``(N, 1)`` for batched input).
        """
        # Only collapse the trailing dimensions when there are any; 1-D and
        # 2-D inputs pass through untouched, exactly as before.
        if img.dim() > 2:
            img = img.flatten(start_dim=1)
        return self.disc(img)


In [3]:
class Generator(nn.Module):
    """Fully connected generator that maps latent noise to flattened images.

    Args:
        z_dim: Size of the latent noise vector fed to the network.
        img_dim: Number of features in one generated (flattened) image.
            Defaults to ``3 * 64 * 64``, i.e. a 64x64 RGB image.
    """

    def __init__(self, z_dim, img_dim=3 * 64 * 64):
        super().__init__()
        # The layer order inside ``self.gen`` is part of the state_dict key
        # layout ("gen.0.weight", ...), so it is kept exactly as before.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, img_dim),
            nn.Tanh(),  # outputs lie in [-1, 1]
        )

    def forward(self, noise):
        """Generate flattened images from latent vectors.

        Args:
            noise: Latent input whose last dimension is ``z_dim``,
                e.g. shape ``(N, z_dim)``.

        Returns:
            Tensor whose last dimension is ``img_dim`` with values in
            ``[-1, 1]``.
        """
        return self.gen(noise)


In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reproducibility: seed PyTorch's RNGs before any weights or noise are drawn,
# so a fresh "Restart & Run All" starts from the same initial state.
seed = 42
torch.manual_seed(seed)

# Hyperparameters
lr = 1e-4  # Learning rate shared by both optimisers
z_dim = 100  # Latent vector dimension
image_dim = 3 * 64 * 64  # 3 (RGB) x 64 x 64
batch_size = 32
epochs = 200


In [5]:
# Instantiate both networks on the selected device.
disc = Discriminator().to(device)
gen = Generator(z_dim).to(device)

# Single noise vector, reused for every saved sample so that successive
# images show how the generator's output for the same input evolves.
fixed_noise = torch.randn(1, z_dim, device=device)

# Image transformations
# The pipeline is built through the fully-qualified ``torchvision.transforms``
# module rather than the ``transforms`` alias: this assignment rebinds
# ``transforms`` to the Compose object, so looking the classes up through the
# alias would raise AttributeError the second time this cell is executed.
# NOTE: the name ``transforms`` is kept because the dataset cell passes it to
# ImageFolder; after this cell it no longer refers to the module.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((64, 64)),  # Resize images to 64x64
    torchvision.transforms.ToTensor(),        # Convert to a float tensor in [0, 1]
    torchvision.transforms.Normalize(         # Per-channel map from [0, 1] to [-1, 1]
        [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
    ),
])


In [6]:
# Root folder that ImageFolder scans for training images.
dataset_path = "anime"

# Dataset plus a shuffling mini-batch loader over it.
dataset = ImageFolder(dataset_path, transform=transforms)
loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

# One Adam optimiser per network, both using the shared learning rate.
opt_disc = optim.Adam(params=disc.parameters(), lr=lr)
opt_gen = optim.Adam(params=gen.parameters(), lr=lr)

# Binary cross-entropy applied to the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# Folder that receives the generated sample images; created if missing.
output_dir = "generatedimages"
os.makedirs(output_dir, exist_ok=True)


In [7]:
# GAN training loop: for every mini-batch, update the discriminator on a
# real/fake pair of batches, then update the generator against the freshly
# updated discriminator. Every 50 batches the losses are printed and one
# sample generated from ``fixed_noise`` is written to ``output_dir``.
to_pil = ToPILImage()
for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image to one feature vector per sample.
        real = real.flatten(start_dim=1).to(device)
        # Size of *this* batch (the last one may be smaller). Stored under its
        # own name so the global ``batch_size`` hyperparameter is not clobbered.
        cur_batch_size = real.shape[0]

        # Train Discriminator
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() keeps the discriminator update from back-propagating
        # into the generator.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        opt_disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # Train Generator: it is rewarded when the discriminator labels its
        # fakes as real, hence the all-ones target.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        opt_gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Logging and saving images
        if batch_idx % 50 == 0:  # Save image every 50 batches
            # Adjacent f-strings instead of a backslash continuation, which
            # would embed the source indentation in the printed line.
            print(
                f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}"
            )

            with torch.no_grad():
                # Generate a single image using fixed noise
                fake_image = gen(fixed_noise).reshape(3, 64, 64)  # Reshape to (C, H, W)
                fake_image = (fake_image * 0.5 + 0.5).clamp(0, 1)  # Denormalize to [0, 1]

                # Convert to PIL image and save
                pil_image = to_pil(fake_image.cpu())
                image_path = os.path.join(output_dir, f"epoch_{epoch}_batch_{batch_idx}.png")
                pil_image.save(image_path)
                print(f"Saved image: {image_path}")

Epoch [0/200] Batch 0/1987                 Loss D: 0.6810, Loss G: 0.7293
Saved image: generatedimages\epoch_0_batch_0.png
Epoch [0/200] Batch 50/1987                 Loss D: 0.2665, Loss G: 1.4079
Saved image: generatedimages\epoch_0_batch_50.png
Epoch [0/200] Batch 100/1987                 Loss D: 0.0872, Loss G: 2.1639
Saved image: generatedimages\epoch_0_batch_100.png
Epoch [0/200] Batch 150/1987                 Loss D: 0.0625, Loss G: 2.7898
Saved image: generatedimages\epoch_0_batch_150.png
Epoch [0/200] Batch 200/1987                 Loss D: 0.0428, Loss G: 2.9132
Saved image: generatedimages\epoch_0_batch_200.png
Epoch [0/200] Batch 250/1987                 Loss D: 0.0531, Loss G: 2.5435
Saved image: generatedimages\epoch_0_batch_250.png
Epoch [0/200] Batch 300/1987                 Loss D: 0.1270, Loss G: 1.9026
Saved image: generatedimages\epoch_0_batch_300.png
Epoch [0/200] Batch 350/1987                 Loss D: 0.0897, Loss G: 2.7505
Saved image: generatedimages\epoch_0_batc

KeyboardInterrupt: 