In [None]:
import torch
import torchvision
import torch.nn as nn

In [None]:
class Generator(nn.Module):
    """MLP generator mapping a 100-d noise vector to a 1x28x28 image.

    Three Linear -> BatchNorm1d -> LeakyReLU(0.2) stages widen the features
    100 -> 256 -> 512 -> 1024, and a final Linear + Tanh produces 28*28 values
    in [-1, 1] that are reshaped to (N, 1, 28, 28). A bias-free linear
    projection of the first stage's output is added to the second stage's
    output before the third stage (a projected skip connection).
    """

    def __init__(self):
        super().__init__()

        def hidden_stage(in_features, out_features):
            # bias=False: the BatchNorm1d that follows has its own learnable
            # shift, so a Linear bias would be redundant.
            return nn.Sequential(
                nn.Linear(in_features, out_features, bias=False),
                nn.BatchNorm1d(out_features),
                nn.LeakyReLU(0.2, inplace=True),
            )

        # Submodules are created in this exact order so that attribute names
        # (state_dict keys) and the random draws used for initialisation match.
        self.fc1 = hidden_stage(100, 256)
        self.fc2 = hidden_stage(256, 512)
        self.fc3 = hidden_stage(512, 1024)
        self.fc4 = nn.Sequential(nn.Linear(1024, 28 * 28), nn.Tanh())
        self.residual = nn.Linear(256, 512, bias=False)

    def forward(self, x):
        """Map noise ``x`` of shape (N, 100) to images of shape (N, 1, 28, 28)."""
        h1 = self.fc1(x)
        skip = self.residual(h1)
        h3 = self.fc3(self.fc2(h1) + skip)
        return self.fc4(h3).view(-1, 1, 28, 28)


In [None]:
import torch.nn.utils.spectral_norm as spectral_norm

class Discriminator(nn.Module):
    """MLP discriminator scoring 28x28 images as real (target 1) or fake (target 0).

    The input is flattened to 28*28 features and passed through three
    spectrally-normalised Linear layers (784 -> 1024 -> 512 -> 256), each
    followed by LeakyReLU(0.2); the first two are also followed by
    Dropout(0.3). A plain (not spectrally-normalised) Linear + Sigmoid head
    returns one probability per sample, shape (N, 1).
    """

    def __init__(self):
        super().__init__()
        # (in_features, out_features, apply_dropout) for each hidden stage.
        hidden_spec = (
            (28 * 28, 1024, True),
            (1024, 512, True),
            (512, 256, False),
        )

        layers = [nn.Flatten()]
        for fan_in, fan_out, with_dropout in hidden_spec:
            layers += [
                spectral_norm(nn.Linear(fan_in, fan_out)),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            if with_dropout:
                layers.append(nn.Dropout(0.3))
        layers += [nn.Linear(256, 1), nn.Sigmoid()]

        # Layer order, and therefore the "model.<idx>" state_dict keys, is the
        # same as a flat nn.Sequential listing of these layers.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the probability that each image in ``x`` is real, shape (N, 1)."""
        return self.model(x)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# `is_available` must be *called*: the bare function object is always truthy,
# which would select "cuda" even on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Binary cross-entropy on the discriminator's Sigmoid probabilities.
loss_fn = nn.BCELoss()

# Generator is built before Discriminator: both draw their initial weights
# from the global RNG, so this order fixes which random values each receives.
generator, discriminator = Generator().to(device), Discriminator().to(device)

In [None]:
import torch.optim as optim
from torchvision import transforms, datasets

# Both networks use Adam with the same learning rate; every other Adam
# hyper-parameter is left at its default.
LEARNING_RATE = 1e-4
generator_optimizer = optim.Adam(params=generator.parameters(), lr=LEARNING_RATE)
discriminator_optimizer = optim.Adam(params=discriminator.parameters(), lr=LEARNING_RATE)

In [None]:
def train_step(real_images):
    """Run one discriminator update followed by one generator update.

    Relies on the module-level ``generator``, ``discriminator``,
    ``generator_optimizer``, ``discriminator_optimizer``, ``loss_fn`` and
    ``device``.

    Args:
        real_images: batch of real images; ``size(0)`` is the batch size.

    Returns:
        ``(disc_loss, gen_loss)`` for this batch as Python floats.
    """
    n = real_images.size(0)
    real_images = real_images.to(device)

    # One noise batch -> one fake batch, shared by both updates below.
    fake_images = generator(torch.randn(n, 100).to(device))

    ones = torch.ones(n, 1, device=device)
    zeros = torch.zeros(n, 1, device=device)

    # Discriminator: real -> 1, fake -> 0. The fake batch is detached so this
    # update sends no gradient into the generator.
    discriminator_optimizer.zero_grad()
    d_real = discriminator(real_images)
    d_fake = discriminator(fake_images.detach())
    d_loss = loss_fn(d_real, ones) + loss_fn(d_fake, zeros)
    d_loss.backward()
    discriminator_optimizer.step()

    # Generator: train it to make the (just-updated) discriminator output 1 on
    # fakes. This backward also writes grads into the discriminator's
    # parameters; they are cleared by the zero_grad() at the next step's start.
    generator_optimizer.zero_grad()
    g_loss = loss_fn(discriminator(fake_images), ones)
    g_loss.backward()
    generator_optimizer.step()

    return d_loss.item(), g_loss.item()

In [None]:
import os
from torchvision.utils import save_image

# Per-epoch sample grids are written to this folder; it is created up front
# and an already-existing folder is not an error.
output_dir = "generated_images"
os.makedirs(name=output_dir, exist_ok=True)

# Function to save generated images
def save_generated_images(epoch, generator, device, num_images=64, noise=None, out_dir=None):
    """Sample images from ``generator`` and save them as one PNG grid.

    The grid (8 images per row) is written to ``<out_dir>/epoch_<epoch+1>.png``.
    The generator is put in eval mode for sampling and afterwards restored to
    the mode it was in before the call, even if sampling raises.

    Args:
        epoch: zero-based epoch index; the file name uses ``epoch + 1``.
        generator: model mapping noise of shape (N, 100) to images; outputs are
            assumed to lie in [-1, 1] (true for a Tanh-terminated generator).
        device: device the noise is placed on.
        num_images: number of fresh noise vectors to draw when ``noise`` is None.
        noise: optional latent batch to use instead of drawing new noise, e.g. a
            fixed tensor reused every epoch so grids are comparable over
            training. When given, ``num_images`` is ignored.
        out_dir: destination directory; defaults to the module-level
            ``output_dir``. Created if it does not exist.
    """
    if out_dir is None:
        out_dir = output_dir
    os.makedirs(out_dir, exist_ok=True)

    # Remember the caller's mode instead of unconditionally forcing train().
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            if noise is None:
                noise = torch.randn(num_images, 100)
            fake_images = generator(noise.to(device))
            # Map [-1, 1] generator output to the [0, 1] range save_image expects.
            fake_images = (fake_images + 1) / 2
            save_image(fake_images, f"{out_dir}/epoch_{epoch+1}.png", nrow=8)
    finally:
        generator.train(was_training)

In [None]:
def train(dataset, epochs):
    """Train the GAN for ``epochs`` passes over ``dataset``.

    Every batch goes through ``train_step``. After each epoch, the mean
    discriminator and generator losses over that epoch's batches are printed,
    and a sample grid is saved via ``save_generated_images`` using the
    module-level ``generator`` and ``device``.

    Args:
        dataset: iterable of ``(images, labels)`` batches (labels are ignored),
            e.g. a ``DataLoader``.
        epochs: number of passes over ``dataset``.
    """
    for epoch in range(epochs):
        d_total, g_total, n_batches = 0.0, 0.0, 0
        for real_images, _ in dataset:
            d_loss, g_loss = train_step(real_images)
            d_total += d_loss
            g_total += g_loss
            n_batches += 1

        # Report epoch means: the last minibatch alone is a very noisy estimate.
        # An empty dataset reports NaN instead of raising UnboundLocalError.
        d_mean = d_total / n_batches if n_batches else float("nan")
        g_mean = g_total / n_batches if n_batches else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], D Loss: {d_mean:.4f}, G Loss: {g_mean:.4f}")
        save_generated_images(epoch, generator, device)

In [None]:
# ToTensor turns each image into a float tensor in [0, 1]; Normalize(0.5, 0.5)
# then maps it to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

train_dataset = datasets.MNIST(
    root="mnist_data",
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.9MB/s]


Extracting mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 441kB/s]


Extracting mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.05MB/s]


Extracting mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 5.82MB/s]

Extracting mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist_data/MNIST/raw






In [None]:
# Full training run: each epoch prints its losses and writes a sample grid.
NUM_EPOCHS = 50
train(dataset=train_loader, epochs=NUM_EPOCHS)

Epoch [1/50], D Loss: 0.7877, G Loss: 1.2987
Epoch [2/50], D Loss: 0.7934, G Loss: 1.3090
Epoch [3/50], D Loss: 0.9433, G Loss: 1.3198
Epoch [4/50], D Loss: 0.8341, G Loss: 1.2501
Epoch [5/50], D Loss: 0.7493, G Loss: 1.3444
Epoch [6/50], D Loss: 0.6360, G Loss: 1.5833
Epoch [7/50], D Loss: 0.7157, G Loss: 1.4484
Epoch [8/50], D Loss: 0.7697, G Loss: 1.5498
Epoch [9/50], D Loss: 0.6966, G Loss: 1.5702
Epoch [10/50], D Loss: 0.6982, G Loss: 1.7610
Epoch [11/50], D Loss: 0.6962, G Loss: 1.8149
Epoch [12/50], D Loss: 0.6015, G Loss: 1.7260
Epoch [13/50], D Loss: 0.6345, G Loss: 1.8528
Epoch [14/50], D Loss: 0.6651, G Loss: 1.9939
Epoch [15/50], D Loss: 0.6649, G Loss: 1.7126
Epoch [16/50], D Loss: 0.5935, G Loss: 1.9379
Epoch [17/50], D Loss: 0.5883, G Loss: 2.0224
Epoch [18/50], D Loss: 0.5665, G Loss: 2.0752
Epoch [19/50], D Loss: 0.7645, G Loss: 2.0782
Epoch [20/50], D Loss: 0.6635, G Loss: 1.8237
Epoch [21/50], D Loss: 0.7023, G Loss: 2.0954
Epoch [22/50], D Loss: 0.8293, G Loss: 1.75