<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/GANs_for_Image_to_Image_Translation_(Pix2Pix).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

class Generator(nn.Module):
    """DCGAN-style generator that maps latent noise to an image.

    For a ``(N, nz, 1, 1)`` latent input, five transposed convolutions
    upsample the spatial size 1 -> 4 -> 8 -> 16 -> 32 -> 64. The result is
    an ``(N, nc, 64, 64)`` image whose values lie in ``[-1, 1]`` because of
    the final ``Tanh``.

    Args:
        nz: Number of channels of the latent input (default 100).
        ngf: Base feature width. The hidden layers use 8x, 4x, 2x and 1x
            this many channels (default 64, giving 512/256/128/64).
        nc: Number of output image channels (default 3).
    """

    def __init__(self, nz=100, ngf=64, nc=3):
        super().__init__()
        # Convolutions feeding a BatchNorm use bias=False because the
        # BatchNorm's learnable shift makes a conv bias redundant.
        self.main = nn.Sequential(
            # 1x1 -> 4x4
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 4x4 -> 8x8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 8x8 -> 16x16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 16x16 -> 32x32
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 32x32 -> 64x64, squashed into [-1, 1] to match inputs
            # normalized with mean 0.5 / std 0.5.
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Return generated images for the latent batch ``input``."""
        return self.main(input)

class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores images as real or fake.

    For a ``(N, nc, 64, 64)`` input, four stride-2 convolutions shrink the
    spatial size 64 -> 32 -> 16 -> 8 -> 4. A final 4x4 convolution then
    reduces it to 1x1, giving an ``(N, 1, 1, 1)`` output.

    The network ends in ``nn.Sigmoid``, so the output is already a
    probability. Train it with ``nn.BCELoss``. Pairing it with
    ``nn.BCEWithLogitsLoss`` would apply the sigmoid a second time.

    Args:
        nc: Number of input image channels (default 3).
        ndf: Base feature width. The layers use 1x, 2x, 4x and 8x this many
            channels (default 64, giving 64/128/256/512).
    """

    def __init__(self, nc=3, ndf=64):
        super().__init__()
        self.main = nn.Sequential(
            # 64x64 -> 32x32. No BatchNorm on the first layer, so the input
            # image is not normalized away.
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x32 -> 16x16
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 4x4
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 -> 1x1 score, mapped to (0, 1)
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return the per-image real/fake probability for ``input``."""
        return self.main(input)

# Train on the GPU when one is available (otherwise everything runs on CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the generator and discriminator on the chosen device
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Define loss functions and optimizers.
# Discriminator already ends in nn.Sigmoid, so its output is a probability.
# BCEWithLogitsLoss would apply a second sigmoid, squashing every prediction
# into (0.5, 0.73) and crippling the adversarial signal. BCELoss expects
# probabilities, so it is the matching criterion.
criterion_gan = nn.BCELoss()
criterion_l1 = nn.L1Loss()
# beta1 = 0.5 is the Adam setting used by DCGAN and pix2pix. The PyTorch
# default (0.9) tends to make GAN training oscillate.
optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Define data transformations and load dataset.
# Normalizing with mean/std 0.5 maps [0, 1] pixels to [-1, 1], the range of
# the generator's Tanh output. ImageFolder's default loader yields RGB images,
# hence three channels.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

DATA_ROOT = 'path_to_data'  # TODO: point at an ImageFolder-style directory
dataset = datasets.ImageFolder(DATA_ROOT, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Training loop
num_epochs = 100
latent_dim = 100  # must match Generator's latent input channels
lambda_l1 = 10    # weight of the L1 term in the generator loss
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(device)
        batch_size = real_images.size(0)
        # Labels are built per batch because the last batch may be smaller.
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # Train Discriminator
        optimizer_d.zero_grad()

        # Real images
        output = discriminator(real_images).view(-1, 1)
        loss_real = criterion_gan(output, real_labels)

        # Fake images. detach() keeps the discriminator update from
        # backpropagating into the generator.
        z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
        fake_images = generator(z)
        output = discriminator(fake_images.detach()).view(-1, 1)
        loss_fake = criterion_gan(output, fake_labels)

        # Backward propagation and optimization
        loss_d = loss_real + loss_fake
        loss_d.backward()
        optimizer_d.step()

        # Train Generator
        optimizer_g.zero_grad()

        # G is rewarded when D classifies its output as real.
        output = discriminator(fake_images).view(-1, 1)
        # NOTE(review): z is independent of real_images. This L1 term
        # therefore pulls each sample toward an unrelated real image (on
        # average, the per-pixel median of the data). It is not the
        # paired-target reconstruction loss of true pix2pix, which conditions
        # G on an input image. Revisit before relying on it.
        loss_g = criterion_gan(output, real_labels) + criterion_l1(fake_images, real_images) * lambda_l1

        # Backward propagation and optimization
        loss_g.backward()
        optimizer_g.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i}/{len(dataloader)}], Loss D: {loss_d.item()}, Loss G: {loss_g.item()}")

# Save the models. Tensors are saved on their current device, so pass
# map_location to torch.load when restoring on a machine without CUDA.
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')