In [1]:
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

# Training configuration -- every tunable value for the run lives here.

# Optimisation schedule
num_epochs = 500   # number of full passes over the training loader
batch_size = 128   # samples per batch handed to the DataLoaders
lr = 2e-4          # learning rate shared by both Adam optimizers

# Model / data geometry
latent_dim = 100       # length of the noise vector fed to the Generator
image_size = 32 ** 2   # pixels in one 32x32 CIFAR-10 channel (not referenced in the cells shown here)

## Download Dataset

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Output directory for the sample grids written by the training cell.
os.makedirs('./data/cifar10', exist_ok=True)

# CIFAR-10 dataset
# ToTensor yields values in [0, 1]; normalising each channel with mean 0.5 and
# std 0.5 rescales them to [-1, 1], the same range as the Generator's Tanh output.
transform_cifar10 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 training dataset
# drop_last=True keeps every batch at exactly `batch_size` samples.
train_dataset_cifar10 = datasets.CIFAR10(root='./data', train=True, transform=transform_cifar10, download=True)
train_loader_cifar10 = DataLoader(dataset=train_dataset_cifar10, batch_size=batch_size, shuffle=True, drop_last=True)

# Load CIFAR-10 test dataset
test_dataset_cifar10 = datasets.CIFAR10(root='./data', train=False, transform=transform_cifar10, download=True)
test_loader_cifar10 = DataLoader(dataset=test_dataset_cifar10, batch_size=batch_size, shuffle=False, drop_last=True)

# Batches are moved to `device` inside the training loop, one at a time.
# Iterating over the whole dataset here just to call .to(device) would decode
# every image and then throw the result away, so no such warm-up pass is done.

# Sanity check: pull a single batch and report its shape.
print("Verifying data loading...")
images, labels = next(iter(train_loader_cifar10))
print(f'Batch shape: {images.shape}')


Files already downloaded and verified
Files already downloaded and verified
Verifying data loading...
Batch shape: torch.Size([128, 3, 32, 32])


## Models

In [3]:
class Generator(nn.Module):
    """Map a latent noise vector to an image with values in [-1, 1].

    The noise is projected by a linear layer to a 512x4x4 feature map, which
    is then upsampled three times by a factor of 2 (4 -> 8 -> 16 -> 32) with
    3x3 convolutions in between. A final 3x3 convolution followed by Tanh
    produces ``img_channels`` output channels.

    Args:
        img_channels: Number of channels in the generated image.

    Note:
        The input width of the projection is read from the module-level
        ``latent_dim`` at construction time.
    """

    # (in_channels, out_channels, upsample_before_conv) for each conv stage.
    _STAGES = (
        (512, 256, True),
        (256, 256, False),
        (256, 128, True),
        (128, 64, False),
        (64, 64, True),
        (64, 32, False),
    )

    def __init__(self, img_channels=3):
        super().__init__()
        # Kept inside a Sequential so the parameter names remain "l1.0.*".
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 512 * 4 * 4))  # Projecting to a smaller space

        layers = [nn.BatchNorm2d(512)]
        for in_ch, out_ch, upsample in self._STAGES:
            if upsample:
                layers.append(nn.Upsample(scale_factor=2))
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))

        # Output head: collapse to the image channels and squash with Tanh.
        layers.append(nn.Conv2d(32, img_channels, kernel_size=3, stride=1, padding=1))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate one image per row of the latent batch ``z``."""
        projected = self.l1(z)
        # Reshape the flat projection into a (batch, 512, 4, 4) feature map.
        feature_map = projected.view(projected.size(0), 512, 4, 4)
        return self.model(feature_map)

class Discriminator(nn.Module):
    """Score images with a probability-like value in (0, 1) of being real.

    Five 3x3 convolutions with stride 2 and padding 1 each halve the spatial
    resolution, so a 32x32 input is reduced 32 -> 16 -> 8 -> 4 -> 2 -> 1 and
    ends up as a 2048x1x1 feature map. That map is flattened and passed
    through a single linear unit followed by a sigmoid.

    Args:
        img_channels: Number of channels in the input images.

    Note:
        The linear head is sized for a 1x1 final feature map, i.e. for 32x32
        inputs; other input resolutions will not match ``adv_layer``.
    """

    def __init__(self, img_channels=3):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            # First stage has no BatchNorm; it acts directly on the input image.
            nn.Conv2d(img_channels, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),

            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),

            nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),

            nn.Conv2d(1024, 2048, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(2048),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4)
        )

        # Sigmoid output so the result can be fed straight into nn.BCELoss.
        self.adv_layer = nn.Sequential(
            nn.Linear(2048 * 1 * 1, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return a (batch, 1) tensor of real/fake scores for ``img``."""
        out = self.model(img)
        # Flatten (batch, 2048, 1, 1) to (batch, 2048) for the linear head.
        out = out.view(out.size(0), -1)
        validity = self.adv_layer(out)
        return validity

## Training Stage

In [None]:
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Optimizers
# NOTE(review): both use Adam's default betas; GAN recipes often lower beta1
# (e.g. 0.5) -- left unchanged here, consider tuning.
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

# Loss function
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

print("Starting training...")

# Training
for epoch in range(num_epochs):
    print(f'Starting epoch {epoch+1}/{num_epochs}')
    for i, (images, _) in enumerate(train_loader_cifar10):
        real_images = images.to(device)
        # Use a local name so the global `batch_size` hyperparameter is not
        # overwritten as a side effect of running this cell.
        n_samples = real_images.size(0)
        # Allocate targets directly on the device (no CPU -> device copy).
        real_labels = torch.ones(n_samples, 1, device=device)
        fake_labels = torch.zeros(n_samples, 1, device=device)

        # --------- Train the Discriminator --------- #
        d_optimizer.zero_grad()
        outputs = discriminator(real_images)
        d_real_loss = criterion(outputs, real_labels)
        z = torch.randn(n_samples, latent_dim, device=device)
        fake_images = generator(z)
        # detach() keeps this backward pass from reaching the generator.
        outputs = discriminator(fake_images.detach())
        d_fake_loss = criterion(outputs, fake_labels)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # --------- Train the Generator --------- #
        g_optimizer.zero_grad()
        # Re-score the same fakes with the freshly updated discriminator; the
        # generator is rewarded when they are classified as real.
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 100 == 0:  # More frequent logging
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader_cifar10)}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

    # Save generated images every epoch
    # Uses the last batch of the epoch; detach() replaces the legacy `.data`.
    save_image(fake_images.detach()[:25], f'./data/cifar10/fake_image_{epoch+1:03d}.png', nrow=5, normalize=True)

print("Training complete.")

Starting training...
Starting epoch 1/500
Epoch [1/500], Step [100/390], D Loss: 1.3668229579925537, G Loss: 0.742272138595581
Epoch [1/500], Step [200/390], D Loss: 1.4243297576904297, G Loss: 0.6729264259338379
Epoch [1/500], Step [300/390], D Loss: 1.324495553970337, G Loss: 0.8380115032196045
Starting epoch 2/500
Epoch [2/500], Step [100/390], D Loss: 1.432347297668457, G Loss: 0.7294477224349976
Epoch [2/500], Step [200/390], D Loss: 1.339842677116394, G Loss: 0.7587568163871765
Epoch [2/500], Step [300/390], D Loss: 1.3487144708633423, G Loss: 0.898546040058136
Starting epoch 3/500
Epoch [3/500], Step [100/390], D Loss: 1.503427267074585, G Loss: 0.8152139782905579
Epoch [3/500], Step [200/390], D Loss: 1.3597757816314697, G Loss: 0.8042879104614258
Epoch [3/500], Step [300/390], D Loss: 1.1412049531936646, G Loss: 0.879277765750885
Starting epoch 4/500
Epoch [4/500], Step [100/390], D Loss: 1.1273716688156128, G Loss: 1.0551891326904297
Epoch [4/500], Step [200/390], D Loss: 0.6