<a href="https://colab.research.google.com/github/OneFineStarstuff/TheOneEverAfter/blob/main/Training_the_GAN_with_Discriminator_Feedback.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the Generator model
class Generator(nn.Module):
    """Map 100-dim latent noise vectors to 3x64x64 images with values in [0, 1].

    A linear layer projects the noise to a 256x8x8 feature map. Three stride-2
    transposed convolutions (kernel 4, padding 1) then double the spatial size
    each time: 8 -> 16 -> 32 -> 64. The final Sigmoid bounds pixels to [0, 1],
    the same range ``transforms.ToTensor`` produces for the real images.

    The submodule names (``fc``, ``upsample_layers``) and the Sequential indices
    determine the ``state_dict`` keys, so keep them stable for saved weights.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(100, 256 * 8 * 8)

        # Channel widths through the upsampling stack: 256 -> 128 -> 64 -> 3.
        widths = (256, 128, 64, 3)
        last_stage = len(widths) - 2
        stages = []
        for stage, (in_ch, out_ch) in enumerate(zip(widths, widths[1:])):
            stages.append(nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1))
            # Hidden stages use ReLU; the output stage squashes pixels to [0, 1].
            stages.append(nn.Sigmoid() if stage == last_stage else nn.ReLU())
        self.upsample_layers = nn.Sequential(*stages)

    def forward(self, x):
        """Return images of shape (N, 3, 64, 64) for noise ``x`` of shape (N, 100)."""
        feature_map = torch.relu(self.fc(x))
        feature_map = feature_map.view(-1, 256, 8, 8)
        return self.upsample_layers(feature_map)

# Define the Discriminator model
class Discriminator(nn.Module):
    """Score 3x64x64 images with the probability that each one is real.

    Three stride-2 convolutions (kernel 4, padding 1) halve the spatial size
    each time, 64 -> 32 -> 16 -> 8. The resulting 256x8x8 feature map is
    reduced by a linear layer to one score per image, squashed by a sigmoid.

    The submodule names (``downsample_layers``, ``fc``) and the Sequential
    indices determine the ``state_dict`` keys, so keep them stable for saved
    weights.

    NOTE(review): returning probabilities pairs with ``nn.BCELoss``. Returning
    raw logits and using ``nn.BCEWithLogitsLoss`` would be more numerically
    stable, but it would change this module's output contract.
    """

    # Per-sample feature-map shape that ``fc`` was sized for (64x64 inputs).
    _FEATURE_SHAPE = (256, 8, 8)

    def __init__(self):
        super().__init__()
        self.downsample_layers = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Linear(256 * 8 * 8, 1)

    def forward(self, x):
        """Return real-vs-fake probabilities of shape (N, 1) for images ``x``.

        Args:
            x: image batch of shape (N, 3, 64, 64). A single unbatched
                (3, 64, 64) image is also accepted and yields shape (1, 1).

        Raises:
            ValueError: if the images are not 64x64. The convolutions would
                produce a feature map that ``fc`` was not built for.
        """
        features = self.downsample_layers(x)
        # Check explicitly rather than trusting the flatten below.
        # ``reshape(-1, 256*8*8)`` on a larger feature map would fold the extra
        # spatial positions into the batch dimension and silently return scores
        # for the wrong number of images.
        if tuple(features.shape[-3:]) != self._FEATURE_SHAPE:
            raise ValueError(
                "Discriminator expects 3x64x64 images (feature map "
                f"{self._FEATURE_SHAPE}), got feature map of shape "
                f"{tuple(features.shape[-3:])} from input of shape {tuple(x.shape)}"
            )
        # reshape (not view) also handles non-contiguous layouts such as channels_last.
        features = features.reshape(-1, 256 * 8 * 8)
        return torch.sigmoid(self.fc(features))

# Define device: use the CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Instantiate both networks on the selected device (generator first) and give
# each its own Adam optimizer. The learning rate is 2e-4; the other Adam
# hyperparameters stay at the PyTorch defaults. The criterion is binary
# cross-entropy, applied to the discriminator's sigmoid (probability) outputs.
# NOTE(review): DCGAN-style training commonly uses betas=(0.5, 0.999) for Adam;
# the defaults are kept here -- confirm this is intentional.
generator, discriminator = Generator().to(device), Discriminator().to(device)
g_optimizer = optim.Adam(generator.parameters(), lr=2e-4)
d_optimizer = optim.Adam(discriminator.parameters(), lr=2e-4)
criterion = nn.BCELoss(reduction="mean")

# Hyperparameters: mini-batch size and number of passes over the dataset,
# plus the BCE target values (real images -> 1.0, generated images -> 0.0).
batch_size, epochs = 32, 50
real_label, fake_label = 1.0, 0.0

# Data preparation: resize every image to 64x64, the resolution both networks
# are built for. Then convert it to a float tensor, which ToTensor scales to
# [0, 1] to match the generator's Sigmoid output range.
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])

# Use your face dataset here: replace the placeholder path with the dataset
# root. ImageFolder's class labels are discarded by the training loop.
dataset = datasets.ImageFolder('/path/to/faces/dataset', transform=transform)
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# Training loop: every mini-batch does one discriminator update followed by one
# generator update. The generator uses the non-saturating loss: its fakes are
# scored against the "real" label.
for epoch in range(epochs):
    # Detached running sums of per-batch losses. Calling .item() only once per
    # epoch avoids a host-device synchronisation on every batch.
    d_loss_total = 0.0
    g_loss_total = 0.0
    num_batches = 0

    for images, _ in data_loader:  # ImageFolder class labels are unused
        images = images.to(device)

        # Train Discriminator: push D(real) -> real_label and D(fake) -> fake_label.
        d_optimizer.zero_grad()
        real_output = discriminator(images)
        d_real_loss = criterion(real_output, torch.full_like(real_output, real_label))

        # One 100-dim latent vector (the Generator's input size) per real image.
        z = torch.randn(images.size(0), 100, device=device)
        fake_images = generator(z)
        # detach() keeps the discriminator loss from backpropagating into the generator.
        fake_output = discriminator(fake_images.detach())
        d_fake_loss = criterion(fake_output, torch.full_like(fake_output, fake_label))

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Train Generator: re-score the same fakes with the just-updated
        # discriminator and push them towards the "real" label. This backward
        # pass also accumulates gradients in the discriminator's parameters.
        # d_optimizer.zero_grad() clears them at the start of the next batch,
        # before any discriminator step uses them.
        g_optimizer.zero_grad()
        fake_output = discriminator(fake_images)
        g_loss = criterion(fake_output, torch.full_like(fake_output, real_label))
        g_loss.backward()
        g_optimizer.step()

        d_loss_total += d_loss.detach()
        g_loss_total += g_loss.detach()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("data_loader yielded no batches; check the dataset path and contents")

    # Report the mean of the per-batch losses over the epoch. The last
    # mini-batch's loss alone is too noisy to be meaningful for GAN training.
    d_loss_mean = (d_loss_total / num_batches).item()
    g_loss_mean = (g_loss_total / num_batches).item()
    print(f"Epoch [{epoch+1}/{epochs}], D Loss: {d_loss_mean:.4f}, G Loss: {g_loss_mean:.4f}")

# Persist only the learned parameters (state_dicts) of both trained networks,
# written to files in the current working directory.
torch.save(obj=generator.state_dict(), f='gan_generator.pth')
torch.save(obj=discriminator.state_dict(), f='gan_discriminator.pth')