In [12]:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [4]:
class Generator(nn.Module):
    """Two-layer MLP that maps latent noise vectors to 784-dimensional outputs.

    Architecture: Linear(latent_dim -> 128) -> ReLU -> Linear(128 -> 784) -> Sigmoid,
    so every output element lies in [0, 1]. (784 = 28 * 28; presumably a flattened
    MNIST-sized image — confirm against the data pipeline.)

    Args:
        latent_dim: size of the last dimension of the noise tensor fed to ``forward``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Submodules are registered in this exact order (linear1 before linear2) so the
        # parameter names and the RNG draws made during weight initialisation stay stable.
        self.linear1 = nn.Linear(latent_dim, 128)  # noise  -> hidden features
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(128, 784)         # hidden -> 784 output values
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid(linear2(relu(linear1(x)))): shape (..., latent_dim) -> (..., 784)."""
        hidden = self.relu(self.linear1(x))
        return self.sigmoid(self.linear2(hidden))

In [8]:
class Discriminator(nn.Module):
    """Two-layer MLP that scores a 784-dimensional input with a single value in [0, 1].

    Architecture: Linear(784 -> 128) -> ReLU -> Linear(128 -> 1) -> Sigmoid.
    The output is used as a probability-like score by the training loop.
    """

    def __init__(self):
        super().__init__()
        # Same registration order as the layers are applied, keeping parameter names
        # and initialisation RNG usage unchanged.
        self.linear1 = nn.Linear(784, 128)  # input  -> hidden features
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(128, 1)    # hidden -> single score
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid(linear2(relu(linear1(x)))): shape (..., 784) -> (..., 1)."""
        features = self.relu(self.linear1(x))
        return self.sigmoid(self.linear2(features))

In [11]:
# Reproducibility: seed torch's global RNG before anything random happens, so the
# initial weights, the noise drawn here and later, and (via the global RNG) the
# DataLoader shuffle order repeat on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

latent_dim = 32  # size of the generator's noise input; the training cell must agree

# Smoke test: push one random sample through each freshly initialised network and
# check the output shapes. `generator` and `discriminator` are the instances that the
# later cells optimise and train.
x_g = torch.randn(1, latent_dim)
x_d = torch.randn(1, 784)
generator = Generator(latent_dim)
discriminator = Discriminator()
output_g = generator(x_g)
output_d = discriminator(x_d)
assert output_g.shape == (1, 784) and output_d.shape == (1, 1), (output_g.shape, output_d.shape)
output_g.shape, output_d.shape

(torch.Size([1, 784]), torch.Size([1, 1]))

In [None]:
# One Adam optimiser per network, both with PyTorch's default hyper-parameters
# (lr=1e-3, betas=(0.9, 0.999)); the generator's optimiser is created first.
generator_opt, discriminator_opt = (
    optim.Adam(network.parameters()) for network in (generator, discriminator)
)

In [None]:
# GAN training in the style of Goodfellow et al. (2014), Algorithm 1: every outer
# iteration performs `steps` discriminator updates, then one generator update.
epochs = 100     # number of outer iterations (NOT full passes over the dataset)
steps = 10       # discriminator updates per generator update ("k" in the paper)
batch_size = 32
latent_dim = 32  # noise size; must equal the in_features of generator.linear1
assert latent_dim == generator.linear1.in_features, "latent_dim does not match the Generator"

# Real training data. ToTensor scales pixels to [0, 1], the same range as the
# Generator's sigmoid output. drop_last=True keeps every real batch exactly
# `batch_size` long so it lines up with the noise batch and the label tensors.
# NOTE: download=True fetches MNIST into ./data on the first run (needs network access).
mnist_train = datasets.MNIST(root="data", train=True, download=True, transform=transforms.ToTensor())
real_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)


def iterate_real_batches(loader):
    """Yield flattened (batch, 784) image batches forever, reshuffling on each pass."""
    while True:
        for images, _labels in loader:
            yield images.flatten(start_dim=1)


real_batches = iterate_real_batches(real_loader)

# BCELoss on probabilities: BCE(p, 1) = -log p and BCE(p, 0) = -log(1 - p), averaged
# over the batch. It clamps the log internally, so saturated predictions cannot turn
# the loss into -inf/NaN the way raw torch.log can.
bce = nn.BCELoss()
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)
d_losses, g_losses = [], []  # per-iteration loss history for later inspection

for epoch in range(epochs):
    for _ in range(steps):
        real_data = next(real_batches)
        noise = torch.randn(batch_size, latent_dim)

        discriminator_opt.zero_grad()
        # detach(): the discriminator update must not backpropagate into the generator.
        generated_image = generator(noise).detach()
        prediction_real = discriminator(real_data)
        prediction_generated = discriminator(generated_image)

        # -[log D(x) + log(1 - D(G(z)))], batch-averaged: same objective as before.
        loss_discriminator = bce(prediction_real, real_labels) + bce(prediction_generated, fake_labels)
        loss_discriminator.backward()
        discriminator_opt.step()

    noise = torch.randn(batch_size, latent_dim)
    generator_opt.zero_grad()
    generated_image = generator(noise)

    # Non-saturating generator loss -log D(G(z)) (labels = "real"). It has the same fixed
    # point as minimising log(1 - D(G(z))) but gives strong gradients early on, when the
    # discriminator easily rejects fakes. The fakes go straight into the discriminator;
    # feeding them back through the generator was a shape error.
    loss_generator = bce(discriminator(generated_image), real_labels)
    loss_generator.backward()
    generator_opt.step()

    d_losses.append(loss_discriminator.item())
    g_losses.append(loss_generator.item())