In [7]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

## Generator


In [3]:
def gen_block(in_dim, out_dim, kernel_size, stride):
    """Build one generator upsampling stage: ConvTranspose2d -> BatchNorm2d -> ReLU.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size: transposed-convolution kernel size.
        stride: transposed-convolution stride. With no padding, the spatial size
            becomes (H - 1) * stride + kernel_size.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    upsample = nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride=stride)
    norm = nn.BatchNorm2d(out_dim)
    activation = nn.ReLU()
    return nn.Sequential(upsample, norm, activation)
    
class Generator(nn.Module):
    """Map a batch of latent vectors to single-channel images in [-1, 1].

    Four ``gen_block`` upsampling stages (in_dim -> 1024 -> 512 -> 256 -> 128
    channels) are followed by a ConvTranspose2d down to one channel and a Tanh.
    With the default kernel_size=2, stride=2, each of the five transposed
    convolutions doubles the spatial size, so a 1x1 input becomes 32x32.

    Args:
        in_dim: length of each latent vector.
        kernel_size: kernel size shared by every transposed convolution.
        stride: stride shared by every transposed convolution.
    """

    def __init__(self, in_dim, kernel_size=2, stride=2):
        super().__init__()
        self.in_dim = in_dim
        widths = [in_dim, 1024, 512, 256, 128]
        stages = [
            gen_block(c_in, c_out, kernel_size, stride)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.gen = nn.Sequential(
            *stages,
            nn.ConvTranspose2d(widths[-1], 1, kernel_size, stride=stride),
            nn.Tanh(),
        )

    def forward(self, x):
        """Reshape ``x`` to (batch, in_dim, 1, 1) and run it through the stack."""
        batch = len(x)
        latent = x.view(batch, self.in_dim, 1, 1)
        return self.gen(latent)

## Discriminator


In [4]:
def disc_block(in_dim, out_dim, kernel_size, stride):
    """Build one discriminator downsampling stage: Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride. With no padding, the spatial size becomes
            floor((H - kernel_size) / stride) + 1.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    downsample = nn.Conv2d(in_dim, out_dim, kernel_size, stride=stride)
    norm = nn.BatchNorm2d(out_dim)
    activation = nn.LeakyReLU(0.2)
    return nn.Sequential(downsample, norm, activation)
    
class Discriminator(nn.Module):
    """Score single-channel images with one raw logit each (no final sigmoid).

    Four ``disc_block`` downsampling stages (1 -> 128 -> 256 -> 512 -> 1024
    channels) are followed by a Conv2d down to one channel. With the default
    kernel_size=2, stride=2, a 32x32 input is reduced to 1x1. The result is
    flattened to (batch, -1).

    Args:
        kernel_size: kernel size shared by every convolution.
        stride: stride shared by every convolution.
    """

    def __init__(self, kernel_size=2, stride=2):
        super().__init__()
        widths = (1, 128, 256, 512, 1024)
        layers = [
            disc_block(c_in, c_out, kernel_size, stride)
            for c_in, c_out in zip(widths, widths[1:])
        ]
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size, stride=stride))
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Run the conv stack and flatten everything but the batch dimension."""
        scores = self.disc(x)
        return scores.view(len(scores), -1)

## Losses


In [5]:
def gen_loss(gen, disc, num_images, z_dim, device=None):
    """Non-saturating generator loss: BCE of the discriminator's scores on fakes vs. "real".

    Samples ``num_images`` latent vectors, generates fakes, and asks the
    discriminator to score them. The generator is rewarded when those scores
    look real (target 1). Gradients flow through ``disc`` into ``gen``.

    Args:
        gen: generator; called as ``gen(noise)`` with noise of shape
            (num_images, z_dim).
        disc: discriminator; its outputs are treated as raw logits, because
            ``BCEWithLogitsLoss`` applies the sigmoid internally.
        num_images: number of fake samples to draw.
        z_dim: latent-vector length.
        device: where to sample the noise. If None, the device of the
            generator's first parameter is used. If ``gen`` has no parameters,
            torch's default device is used (the original behaviour).

    Returns:
        A scalar loss tensor.
    """
    if device is None:
        # Sample noise on the generator's device so CPU/GPU placement matches;
        # a CPU-only randn crashes when the models live on a GPU.
        params = getattr(gen, "parameters", None)
        first = next(params(), None) if callable(params) else None
        device = first.device if first is not None else None
    noise = torch.randn(num_images, z_dim, device=device)
    fake = gen(noise)
    disc_pred = disc(fake)
    criterion = nn.BCEWithLogitsLoss()
    return criterion(disc_pred, torch.ones_like(disc_pred))

def disc_loss(gen, disc, real, num_images, z_dim, device=None):
    """Discriminator BCE loss: fakes scored against 0, real images against 1, averaged.

    The generated batch is detached, so calling ``backward()`` on this loss
    only produces gradients for the discriminator and never touches the
    generator's parameters.

    Args:
        gen: generator; called as ``gen(noise)`` with noise of shape
            (num_images, z_dim).
        disc: discriminator; its outputs are treated as raw logits, because
            ``BCEWithLogitsLoss`` applies the sigmoid internally.
        real: batch of real samples passed straight to ``disc``.
        num_images: number of fake samples to draw.
        z_dim: latent-vector length.
        device: where to sample the noise. If None, ``real.device`` is used
            when available, otherwise torch's default device.

    Returns:
        A scalar loss tensor, ``(real_loss + fake_loss) / 2``.
    """
    criterion = nn.BCEWithLogitsLoss()
    if device is None:
        # Keep the noise on the same device as the real batch; a CPU-only
        # randn crashes when training on a GPU.
        device = getattr(real, "device", None)
    noise = torch.randn(num_images, z_dim, device=device)
    # Detach: the discriminator step must not backprop into (or accumulate
    # stale gradients on) the generator.
    fake = gen(noise).detach()
    disc_pred_fake = disc(fake)
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    disc_pred_real = disc(real)
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    return (real_loss + fake_loss) / 2

In [6]:
# Smoke test: push one batch of noise through both networks and check the
# shapes line up (generator -> 1x32x32 images, discriminator -> one logit each).
torch.manual_seed(0)  # reproducible noise and weight initialisation

batch_size = 16
noise_dim = 100

noise = torch.randn(batch_size, noise_dim)

gen = Generator(in_dim=noise_dim)
generated_images = gen(noise)
print("Generated Images Shape:", generated_images.shape)
# Five transposed convs with kernel 2 / stride 2 take 1x1 -> 32x32.
assert generated_images.shape == (batch_size, 1, 32, 32)

disc = Discriminator()
disc_output = disc(generated_images)
print("Discriminator Output Shape:", disc_output.shape)
assert disc_output.shape == (batch_size, 1)

Generated Images Shape: torch.Size([16, 1, 32, 32])
Discriminator Output Shape: torch.Size([16, 1])
Discriminator Output: tensor([[-0.1464],
        [ 0.6151],
        [ 0.3880],
        [ 0.4352],
        [ 0.4317],
        [-0.3247],
        [ 0.4826],
        [-0.6484],
        [ 0.5375],
        [-0.1112],
        [-0.0263],
        [-0.4229],
        [ 0.2592],
        [-0.2507],
        [-0.1231],
        [-0.2539]], grad_fn=<ViewBackward0>)
