<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Generative_Adversarial_Networks_(GANs)_with_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Transformer-based generator
class TransformerGenerator(nn.Module):
    """Generator mapping a noise token sequence to a flat sample via ``nn.Transformer``.

    The input is reshaped into ``sequence_length`` tokens of width ``input_dim``.
    The same token sequence is fed as both the encoder source and the decoder
    target. The transformer output is flattened and projected to ``output_dim``
    by one linear layer (no output activation).

    Args:
        input_dim: Token width, used as the transformer's ``d_model`` (PyTorch
            requires it to be divisible by ``num_heads``).
        output_dim: Size of each generated, flattened sample.
        num_layers: Number of encoder layers (and decoder layers unless
            ``num_decoder_layers`` is given).
        num_heads: Attention heads per transformer layer.
        hidden_dim: Feed-forward width inside each transformer layer.
        sequence_length: Number of tokens per sample.
        num_decoder_layers: Decoder depth; defaults to ``num_layers``.
    """

    def __init__(self, input_dim, output_dim, num_layers, num_heads, hidden_dim, sequence_length,
                 num_decoder_layers=None):
        super().__init__()
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        if num_decoder_layers is None:
            # Without this, nn.Transformer silently falls back to its default of
            # 6 decoder layers regardless of ``num_layers``.
            num_decoder_layers = num_layers
        self.transformer = nn.Transformer(
            d_model=input_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=hidden_dim,
            batch_first=True,  # tensors are [batch, seq, feature]
        )
        self.fc = nn.Linear(input_dim * sequence_length, output_dim)

    def forward(self, x):
        """Generate samples from noise.

        Args:
            x: Noise whose trailing dimensions hold ``sequence_length * input_dim``
                values per sample, e.g. ``(batch, sequence_length * input_dim)`` or
                ``(batch, sequence_length, input_dim)``. A 1-D tensor of that size is
                treated as a single sample.

        Returns:
            Tensor of shape ``(batch, output_dim)``.

        Raises:
            ValueError: if a batched input does not carry exactly
                ``sequence_length * input_dim`` values per sample (which would
                otherwise be silently re-batched).
        """
        values_per_sample = self.sequence_length * self.input_dim
        if x.dim() > 1 and x.shape[1:].numel() != values_per_sample:
            raise ValueError(
                f"expected {values_per_sample} values per sample "
                f"(sequence_length={self.sequence_length} * input_dim={self.input_dim}), "
                f"got {x.shape[1:].numel()}"
            )
        # reshape (not view) also accepts non-contiguous inputs
        tokens = x.reshape(-1, self.sequence_length, self.input_dim)
        encoded = self.transformer(tokens, tokens)
        flat = encoded.reshape(encoded.size(0), -1)
        return self.fc(flat)

# Discriminator network
class Discriminator(nn.Module):
    """Three-layer MLP scoring flattened samples as real (near 1) or fake (near 0).

    Architecture: ``input_dim -> 512 -> 256 -> 1`` with LeakyReLU(0.2) between
    the linear layers and a final Sigmoid, so outputs lie in ``(0, 1)``.

    Args:
        input_dim: Number of features in each flattened input sample.
    """

    def __init__(self, input_dim):
        super().__init__()
        # NOTE(review): Sigmoid + nn.BCELoss is less numerically stable than
        # emitting logits and using nn.BCEWithLogitsLoss; kept as-is because
        # callers consume probabilities.
        layers = [
            nn.Linear(input_dim, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1), nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-sample probabilities of shape ``(batch, 1)`` for input ``x``."""
        probabilities = self.fc(x)
        return probabilities

# Training function
def train_gan(generator, discriminator, data_loader, num_epochs=100, sequence_length=None,
              *, lr=0.0002, device=None):
    """Train ``generator`` and ``discriminator`` adversarially with binary cross-entropy.

    Each batch performs one discriminator update (real samples labelled 1,
    detached fake samples labelled 0), then one generator update (fake samples
    labelled 1, i.e. the generator tries to fool the discriminator). After every
    epoch the losses of that epoch's last batch are printed.

    Args:
        generator: Module with an ``input_dim`` attribute that maps noise of shape
            ``(batch, sequence_length * input_dim)`` to flattened samples.
        discriminator: Module mapping flattened samples to probabilities of shape
            ``(batch, 1)`` (``nn.BCELoss`` requires values in [0, 1]).
        data_loader: Iterable of ``(images, labels)`` batches; labels are ignored
            and images are flattened per sample.
        num_epochs: Number of passes over ``data_loader``.
        sequence_length: Noise tokens per sample. When ``None`` it is taken from
            ``generator.sequence_length`` if that attribute exists, else 16 (the
            previous default).
        lr: Adam learning rate for both networks.
        device: Device to train on; defaults to CUDA when available, else CPU.

    Raises:
        ValueError: if ``sequence_length`` disagrees with ``generator.sequence_length``.
    """
    generator_seq_len = getattr(generator, "sequence_length", None)
    if sequence_length is None:
        sequence_length = generator_seq_len if generator_seq_len is not None else 16
    elif generator_seq_len is not None and sequence_length != generator_seq_len:
        # A mismatch makes the generator reinterpret the noise as a different
        # batch size, which would only surface later as a loss-shape error.
        raise ValueError(
            f"sequence_length={sequence_length} does not match "
            f"generator.sequence_length={generator_seq_len}"
        )

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)

    optimizer_g = optim.Adam(generator.parameters(), lr=lr)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr)
    loss_fn = nn.BCELoss()
    noise_dim = sequence_length * generator.input_dim

    for epoch in range(num_epochs):
        # Reset per epoch so an empty data_loader does not hit unbound names below.
        loss_d = loss_g = None
        for real_images, _ in data_loader:
            real_images = real_images.reshape(real_images.size(0), -1).to(device)  # flatten images
            batch_size = real_images.size(0)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # Train Discriminator
            optimizer_d.zero_grad()
            loss_real = loss_fn(discriminator(real_images), real_labels)
            noise = torch.randn(batch_size, noise_dim, device=device)
            fake_data = generator(noise)
            # detach(): the discriminator step must not backpropagate into the generator
            loss_fake = loss_fn(discriminator(fake_data.detach()), fake_labels)
            loss_d = loss_real + loss_fake
            loss_d.backward()
            optimizer_d.step()

            # Train Generator: score fakes against "real" labels to fool the discriminator
            optimizer_g.zero_grad()
            loss_g = loss_fn(discriminator(fake_data), real_labels)
            loss_g.backward()
            optimizer_g.step()

        if loss_d is None:
            print(f"Epoch [{epoch+1}/{num_epochs}], no batches in data_loader")
        else:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")

# ---- Hyperparameters ---------------------------------------------------------
input_dim = 32          # token width fed to the generator's transformer (d_model)
sequence_length = 16    # noise tokens per generated sample
output_dim = 28 * 28    # flattened MNIST image size

# ---- MNIST data pipeline -----------------------------------------------------
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = DataLoader(dataset, batch_size=64, shuffle=True)

# ---- Models and training -----------------------------------------------------
generator = TransformerGenerator(
    input_dim,
    output_dim,
    num_layers=2,
    num_heads=4,
    hidden_dim=256,
    sequence_length=sequence_length,
)
discriminator = Discriminator(output_dim)

train_gan(generator, discriminator, data_loader, num_epochs=10, sequence_length=sequence_length)