In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

# Pick the compute device: a CUDA GPU when PyTorch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- Hyperparameters -------------------------------------------------------
batch_size = 64        # samples per optimisation step
lr = 2e-4              # Adam learning rate (same value as 0.0002)
z_dim = 100            # not referenced by the code visible in this cell
epochs = 30            # passes over the MNIST training split
image_size = 28        # side length of the square input images
input_channels = 1     # single (grayscale) channel

# Folder that receives one sample grid per epoch.
os.makedirs('pix2pix_images', exist_ok=True)

# ---- MNIST training data ---------------------------------------------------
# ToTensor() produces values in [0, 1]; Normalize with mean = std = 0.5 then
# maps them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
mnist_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
data_loader = DataLoader(dataset=mnist_data, batch_size=batch_size, shuffle=True)

# Build (image, label-map) pairs for MNIST; the label map encodes the class as an intensity.
def create_paired_data(images, labels):
    """Concatenate each image with a synthetic label map along the channel axis.

    For every sample the label map has the same shape as the image and is
    zero everywhere except the square at rows/cols 7..20 (14x14), which is
    filled with ``(label + 1) / 10``. Digit 0 therefore becomes 0.1 and
    digit 9 becomes 1.0. This is an intensity code, not a binary mask.

    Args:
        images: batch of images, either a tensor of shape (N, C, H, W) or a
            sequence of (C, H, W) tensors.
        labels: N integer class labels (tensor or sequence of ints).

    Returns:
        Tensor of shape (N, 2*C, H, W) holding the image channels first and
        the label-map channels second, with the device and dtype of ``images``.
    """
    if not torch.is_tensor(images):
        images = torch.stack(list(images))
    labels = torch.as_tensor(labels, device=images.device)

    # Vectorised over the whole batch instead of a per-sample Python loop.
    label_maps = torch.zeros_like(images)
    intensity = ((labels + 1) / 10.0).to(images.dtype)
    label_maps[:, :, 7:21, 7:21] = intensity.reshape(-1, 1, 1, 1)
    return torch.cat([images, label_maps], dim=1)

# Generator
class Generator(nn.Module):
    """Small encoder-decoder that maps an input tensor to one of the same spatial size.

    Two stride-2 convolutions halve the spatial size twice (e.g. 28 -> 14 -> 7),
    then two stride-2 transposed convolutions restore it (7 -> 14 -> 28). The
    closing ``Tanh`` keeps every output value in [-1, 1].

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the generated tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        encoder = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        ]
        decoder = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        # One flat Sequential named ``model`` keeps the state_dict keys
        # (model.0.weight, ...) identical to a hand-written Sequential.
        self.model = nn.Sequential(*encoder, *decoder)

    def forward(self, x):
        """Run ``x`` of shape (N, in_channels, H, W) through the encoder-decoder."""
        generated = self.model(x)
        return generated

# Discriminator
class Discriminator(nn.Module):
    """Convolutional critic that scores its input on a coarse spatial grid.

    Three stride-2 convolutions downsample the input, and a final stride-1
    convolution emits a single-channel map of raw logits (no sigmoid). For a
    28x28 input the output has shape (N, 1, 2, 2).

    Args:
        in_channels: number of channels of the tensor being scored.
    """

    def __init__(self, in_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two identical downsampling stages: conv -> batch norm -> leaky ReLU.
        for width_in, width_out in ((64, 128), (128, 256)):
            layers += [
                nn.Conv2d(width_in, width_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(width_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(256, 1, kernel_size=4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logit map for ``x`` of shape (N, in_channels, H, W)."""
        logits = self.model(x)
        return logits

# Initialize models
# The generator maps the 1-channel input image (``inputs = paired_data[:, :1]``
# in the training loop) to a 1-channel label map. The discriminator scores
# ``torch.cat([inputs, target], dim=1)``, i.e. 1 + 1 = 2 channels.
generator = Generator(in_channels=1, out_channels=1).to(device)  # Input: image, Output: label map
discriminator = Discriminator(in_channels=2).to(device)          # Input: (image, real-or-fake label map) pair

# Loss function and optimizers
# BCEWithLogitsLoss applies the sigmoid itself, matching the discriminator's
# final Conv2d, which has no activation.
criterion = nn.BCEWithLogitsLoss()
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Training loop
# NOTE(review): the pix2pix objective (Isola et al., 2017) adds a weighted L1
# term between fake_targets and real_targets to the generator loss; this loop
# optimises only the adversarial term. Confirm that is intended.
for epoch in range(epochs):
    for i, (images, labels) in enumerate(data_loader):
        images, labels = images.to(device), labels.to(device)
        paired_data = create_paired_data(images, labels).to(device)

        # Split paired data into the conditioning images (channel 0) and the
        # real label maps (remaining channels).
        inputs = paired_data[:, :1, :, :]
        real_targets = paired_data[:, 1:, :, :]

        # ---- Train Discriminator: real pairs -> 1, generated pairs -> 0 ----
        optimizer_d.zero_grad()

        real_pair = torch.cat([inputs, real_targets], dim=1)
        real_output = discriminator(real_pair)
        real_loss = criterion(real_output, torch.ones_like(real_output))

        fake_targets = generator(inputs)
        fake_pair = torch.cat([inputs, fake_targets], dim=1)
        # detach(): the discriminator update must not backpropagate into the generator.
        fake_output = discriminator(fake_pair.detach())
        fake_loss = criterion(fake_output, torch.zeros_like(fake_output))

        loss_d = real_loss + fake_loss
        loss_d.backward()
        optimizer_d.step()

        # ---- Train Generator: make the just-updated discriminator call fakes real ----
        optimizer_g.zero_grad()
        # Re-score the non-detached pair so gradients reach the generator. This
        # backward pass also writes .grad on the discriminator's parameters;
        # optimizer_d.zero_grad() clears them at the start of the next step.
        fake_output = discriminator(fake_pair)
        loss_g = criterion(fake_output, torch.ones_like(fake_output))
        loss_g.backward()
        optimizer_g.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(data_loader)}], "
                  f"Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")

    # Save the generator outputs of the epoch's last batch. detach() replaces
    # the legacy .data attribute: same values, no autograd history.
    save_image(fake_targets.detach()[:25], f'pix2pix_images/epoch_{epoch}.png', nrow=5, normalize=True)

# Display some results
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# save_image() requires a destination path and returns None, so it cannot
# provide an in-memory grid. make_grid() returns the tiled image as a
# (C, H, W) tensor instead.
grid = make_grid(fake_targets.detach()[:25], nrow=5, normalize=True)
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
plt.title("Generated Labels")
plt.axis('off')
plt.show()
