In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import os

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Training configuration
epochs = 50
lr = 2e-4
batch_size = 64
z_dim = 100                # length of the generator's noise vector
num_classes = 10           # number of label values the models are conditioned on
image_size = 28 * 28       # pixels per flattened Fashion MNIST image (28x28)

# Directory that receives the per-epoch sample grids (no error if it already exists)
os.makedirs('cgan_images', exist_ok=True)

# Preprocessing: convert to a tensor, then shift/scale with mean 0.5 and
# std 0.5 so pixel values land in [-1, 1]
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(preprocessing_steps)

# Fashion MNIST training split (downloaded into ./data when missing),
# served in shuffled mini-batches
train_data = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)

# Conditional Generator
class Generator(nn.Module):
    """Conditional generator: maps (noise, class label) to an image.

    The label is embedded into a ``num_classes``-dimensional vector,
    concatenated with the noise vector and pushed through a fully connected
    stack that ends in ``Tanh``, so every output value lies in [-1, 1].

    Args:
        z_dim: Length of the noise vector fed to ``forward``.
        num_classes: Number of distinct labels; also the embedding width.
        img_dim: Number of values in one flattened output image.
        img_shape: Shape (without the batch axis) the flat output is
            reshaped to. Defaults to ``(1, 28, 28)``. Its element count
            must equal ``img_dim``.

    Raises:
        ValueError: If ``img_shape`` does not contain exactly ``img_dim``
            elements.
    """

    def __init__(self, z_dim, num_classes, img_dim, img_shape=(1, 28, 28)):
        super(Generator, self).__init__()

        # Fail at construction time instead of with an opaque view() error
        # on the first forward pass.
        img_shape = tuple(img_shape)
        n_values = 1
        for extent in img_shape:
            n_values *= extent
        if n_values != img_dim:
            raise ValueError(
                f"img_shape {img_shape} holds {n_values} values, "
                f"but img_dim is {img_dim}"
            )
        self.img_shape = img_shape

        self.label_emb = nn.Embedding(num_classes, num_classes)  # Embedding for labels
        self.model = nn.Sequential(
            nn.Linear(z_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, z, labels):
        """Generate one image per (noise row, label) pair.

        Args:
            z: Noise of shape ``(batch, z_dim)``.
            labels: Integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, *img_shape)``.
        """
        # Concatenate latent vector and label embeddings
        c = self.label_emb(labels)
        x = torch.cat([z, c], dim=1)
        return self.model(x).view(x.size(0), *self.img_shape)

# Conditional Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator: scores an (image, class label) pair.

    The image is flattened, joined with a ``num_classes``-wide embedding of
    its label and passed through a fully connected stack ending in a
    sigmoid, giving one value in (0, 1) per sample.
    """

    def __init__(self, num_classes, img_dim):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Hidden layers: each Linear is followed by LeakyReLU(0.2)
        widths = (img_dim + num_classes, 512, 256)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2))

        # Output head: a single sigmoid-squashed score
        stack.append(nn.Linear(widths[-1], 1))
        stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, img, labels):
        """Return a ``(batch, 1)`` tensor of scores for ``img`` given ``labels``."""
        batch = img.size(0)
        flat_img = img.view(batch, -1)          # one row of pixels per sample
        label_vec = self.label_emb(labels)
        joined = torch.cat((flat_img, label_vec), dim=1)
        return self.model(joined)

# Build both networks, then move their parameters onto the selected device
generator = Generator(z_dim, num_classes, image_size)
generator.to(device)
discriminator = Discriminator(num_classes, image_size)
discriminator.to(device)

# Binary cross-entropy on the discriminator's sigmoid output
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate and betas
optimizer_g = optim.Adam(
    generator.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)
optimizer_d = optim.Adam(
    discriminator.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)

# Training Loop
# Every step updates the discriminator on one real and one generated batch,
# then updates the generator so that its output is scored as real.
# After each epoch a grid of samples is written to cgan_images/.
for epoch in range(epochs):
    for i, (real_images, labels) in enumerate(train_loader):
        real_images, labels = real_images.to(device), labels.to(device)
        # Use the actual batch size: the loader may yield a smaller final batch
        cur_batch = real_images.size(0)

        # Targets for real and fake data (BCE targets, not class labels);
        # allocated directly on the device to avoid a host-to-device copy
        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # Train Discriminator
        optimizer_d.zero_grad()

        # Real images loss
        real_outputs = discriminator(real_images, labels)
        loss_real = criterion(real_outputs, real_labels)

        # Generate fake images conditioned on the same class labels
        z = torch.randn(cur_batch, z_dim, device=device)
        fake_images = generator(z, labels)
        # detach(): the discriminator update must not backpropagate into the generator
        fake_outputs = discriminator(fake_images.detach(), labels)
        loss_fake = criterion(fake_outputs, fake_labels)

        # Total loss and optimization
        loss_d = loss_real + loss_fake
        loss_d.backward()
        optimizer_d.step()

        # Train Generator
        optimizer_g.zero_grad()

        # Generator loss (fool the discriminator): re-score the same fakes
        # with the updated discriminator against the "real" target
        fake_outputs = discriminator(fake_images, labels)
        loss_g = criterion(fake_outputs, real_labels)

        # This also accumulates gradients in the discriminator; they are
        # cleared by optimizer_d.zero_grad() at the start of the next step
        loss_g.backward()
        optimizer_g.step()

        # Log progress
        if i % 100 == 0:
            print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(train_loader)}], "
                  f"Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")

    # Save generated images: 10 samples per class, one class per grid column.
    # No gradients are needed here, so skip building the autograd graph.
    with torch.no_grad():
        sample_labels = torch.arange(0, num_classes, device=device).repeat(10)
        z = torch.randn(sample_labels.size(0), z_dim, device=device)
        sample_images = generator(z, sample_labels)
    save_image(sample_images, f'cgan_images/epoch_{epoch}.png', nrow=num_classes, normalize=True)

# Display generated images after training
# save_image() needs a file path and returns None, so it cannot produce the
# tensor to plot; make_grid() builds the same kind of grid in memory.
from torchvision.utils import make_grid

# Keep the first 25 samples from the last epoch, detached and on the CPU
sample_images = sample_images.detach()[:25].cpu()
grid = make_grid(sample_images, nrow=5, normalize=True)
# make_grid returns (C, H, W); imshow expects (H, W, C)
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.title("Generated Images")
plt.axis('off')
plt.show()


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:01<00:00, 16.7MB/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 276kB/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:00<00:00, 5.00MB/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 14.0MB/s]


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Epoch [0/50], Step [0/938], Loss D: 1.3956, Loss G: 0.6740
Epoch [0/50], Step [100/938], Loss D: 0.6090, Loss G: 1.3286
Epoch [0/50], Step [200/938], Loss D: 0.4601, Loss G: 2.5100
Epoch [0/50], Step [300/938], Loss D: 0.5475, Loss G: 2.9061
Epoch [0/50], Step [400/938], Loss D: 0.5370, Loss G: 2.2823
Epoch [0/50], Step [500/938], Loss D: 0.7484, Loss G: 1.7488
Epoch [0/50], Step [600/938], Loss D: 0.9578, Loss G: 2.5706
Epoch [0/50], Step [700/938], Loss D: 0.9184, Loss G: 2.0503
Epoch [0/50], Step [800/938], Loss D: 0.5534, Loss G: 2.4223
Epoch [0/50], Step [900/938], Loss D: 1.0371, Loss G: 2.6712
Epoch [1/50], Step [0/938], Loss D: 0.8086, Loss G: 2.4999
Epoch [1/50], Step [100/938], Loss D: 0.7907, Loss G: 3.2434
Epoch [1/50], Step [200/938], Loss D: 0.9964, Loss G: 2.4417
Epoch [1/50], Step [300/938], Loss D: 0.8588, Loss G: 1.2442
Epoch [1/50], Step [400/938], Loss D: 1.1314, Loss G: 2.2527
