In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from tqdm import tqdm

# ✅ Generator Network
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1), nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


# ✅ Discriminator Network
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, 4, 1, 1)
        )

    def forward(self, x, y):
        # Concatenate input and target images along the channel dimension
        return self.model(torch.cat([x, y], dim=1))


# ✅ Setup
device = "cuda" if torch.cuda.is_available() else "cpu"
G = Generator().to(device)
D = Discriminator().to(device)

opt_G = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

criterion_GAN = nn.BCEWithLogitsLoss()
criterion_L1 = nn.L1Loss()

# Dummy data (simulated input and target images)
x = torch.randn(8, 3, 64, 64).to(device)  # Input images
y = torch.randn(8, 3, 64, 64).to(device)  # Target images

# ✅ Training loop (3 epochs for demo)
for epoch in range(3):
    tqdm.write(f"Epoch {epoch+1}")

    # Generate fake images
    fake_y = G(x)

    # --- Train Discriminator ---
    real_pred = D(x, y)
    fake_pred = D(x, fake_y.detach())

    loss_D_real = criterion_GAN(real_pred, torch.ones_like(real_pred))
    loss_D_fake = criterion_GAN(fake_pred, torch.zeros_like(fake_pred))
    loss_D = 0.5 * (loss_D_real + loss_D_fake)

    opt_D.zero_grad()
    loss_D.backward()
    opt_D.step()

    # --- Train Generator ---
    fake_pred = D(x, fake_y)
    loss_GAN = criterion_GAN(fake_pred, torch.ones_like(fake_pred))
    loss_L1 = criterion_L1(fake_y, y) * 100
    loss_G = loss_GAN + loss_L1

    opt_G.zero_grad()
    loss_G.backward()
    opt_G.step()

    tqdm.write(f"Loss_D: {loss_D.item():.4f} | Loss_G: {loss_G.item():.4f}")

# ✅ Save generated images
save_image((fake_y + 1) / 2, "pix2pix_result.png")
print("✅ Training done. Saved: pix2pix_result.png")


Epoch 1
Loss_D: 0.7118 | Loss_G: 93.9017
Epoch 2
Loss_D: 0.7016 | Loss_G: 89.2064
Epoch 3
Loss_D: 0.6940 | Loss_G: 85.3956
✅ Training done. Saved: pix2pix_result.png
