In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# Training hyperparameters.
batch_size = 100
lr = 0.0002
epochs = 50
latent_dim = 100  # size of the generator's noise input z

# MNIST images are single-channel 28x28, stored as (channels, height, width).
img_size = 28
img_shape = (1, img_size, img_size)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (CPU and all CUDA devices) so weight init, data shuffling and the
# generator's noise are repeatable across Restart & Run All.
seed = 0
torch.manual_seed(seed)

In [None]:
# Convert PIL images to tensors in [0, 1], then normalize to [-1, 1] so real images
# share the range of the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# MNIST training split, downloaded into ./data on first use.
mnist = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches for training.
dataloader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)

In [None]:
class Generator(nn.Module):
    """MLP generator mapping a latent vector to an image with values in [-1, 1].

    Layers: Linear(z_dim -> 128) -> LeakyReLU -> Linear(128 -> 256) -> BatchNorm1d
    -> LeakyReLU -> Linear(256 -> 512) -> BatchNorm1d -> LeakyReLU
    -> Linear(512 -> prod(out_shape)) -> Tanh, reshaped to ``(batch, *out_shape)``.

    Args:
        z_dim: Size of the latent input. Defaults to the notebook-level ``latent_dim``.
        out_shape: Shape of one generated image, e.g. ``(1, 28, 28)``. Defaults to the
            notebook-level ``img_shape``.
    """

    def __init__(self, z_dim=None, out_shape=None):
        super().__init__()
        # Resolve defaults lazily so `Generator()` keeps using the config cell's values.
        self.z_dim = latent_dim if z_dim is None else z_dim
        self.out_shape = tuple(img_shape if out_shape is None else out_shape)
        n_pixels = int(torch.prod(torch.tensor(self.out_shape)))
        # The original passed 0.8 as BatchNorm1d's second positional argument, which is
        # `eps` (not momentum). PyTorch's defaults are used instead. Layer indices are
        # unchanged, so state_dict keys match earlier checkpoints (though weights trained
        # with eps=0.8 will produce different outputs under the corrected eps).
        self.model = nn.Sequential(
            nn.Linear(self.z_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, n_pixels),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from latent codes.

        Args:
            z: Tensor of shape ``(batch, z_dim)``.

        Returns:
            Tensor of shape ``(batch, *out_shape)`` with values in [-1, 1].
        """
        img = self.model(z)
        return img.view(img.size(0), *self.out_shape)



class Discriminator(nn.Module):
    """MLP discriminator producing the probability that an image is real.

    Flattens each image and applies Linear(prod(in_shape) -> 512) -> LeakyReLU
    -> Linear(512 -> 256) -> LeakyReLU -> Linear(256 -> 1) -> Sigmoid.

    Args:
        in_shape: Shape of one input image, e.g. ``(1, 28, 28)``. Defaults to the
            notebook-level ``img_shape``.
    """

    def __init__(self, in_shape=None):
        super().__init__()
        # Resolve the default lazily so `Discriminator()` keeps using the config cell's value.
        self.in_shape = tuple(img_shape if in_shape is None else in_shape)
        n_pixels = int(torch.prod(torch.tensor(self.in_shape)))
        self.model = nn.Sequential(
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor whose first dimension is the batch; the rest is flattened and
                must contain ``prod(in_shape)`` elements per sample.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)

In [None]:
# Build both networks on the configured device (generator constructed first, then discriminator).
generator, discriminator = Generator().to(device), Discriminator().to(device)

In [None]:
# Binary cross-entropy on the discriminator's probabilities (it ends in nn.Sigmoid).
adversarial_loss = nn.BCELoss()

# One Adam optimizer per network, with beta1 lowered to 0.5 from Adam's 0.9 default.
optimizer_G = optim.Adam(
    params=generator.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    params=discriminator.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)

In [None]:
# Adversarial training. Each iteration first updates G to fool the current D, then
# updates D on the real batch and on the (detached) fakes G just produced.
log_interval = 400  # print losses every `log_interval` batches

# G contains BatchNorm, whose behavior depends on train/eval mode; force training mode
# in case a later sampling cell switched to eval and this cell is being re-run.
generator.train()
discriminator.train()

for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        n = imgs.size(0)
        # BCE targets: 1 = real, 0 = fake. Fresh tensors do not require grad by default.
        valid = torch.ones(n, 1, device=device)
        fake = torch.zeros(n, 1, device=device)

        real_imgs = imgs.to(device)

        # --- Generator step: push D(G(z)) toward "real". ---
        optimizer_G.zero_grad()
        z = torch.randn(n, latent_dim, device=device)
        generated_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(generated_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # --- Discriminator step ---
        # zero_grad also clears the gradients the G step accumulated in D's parameters.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # detach() so D's loss does not backpropagate into G.
        fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        if i % log_interval == 0:
            # Implicit string concatenation: the previous backslash continuation inside
            # the f-string embedded the next line's indentation into every log line.
            print(
                f"Epoch {epoch}/{epochs} Batch {i}/{len(dataloader)} "
                f"Loss D: {d_loss.item():.6f}, Loss G: {g_loss.item():.6f}"
            )

Epoch 0/50 Batch 0/600                   Loss D: 0.682669, Loss G: 0.724249
Epoch 0/50 Batch 400/600                   Loss D: 0.340233, Loss G: 1.346480
Epoch 1/50 Batch 0/600                   Loss D: 0.351534, Loss G: 1.308863
Epoch 1/50 Batch 400/600                   Loss D: 0.496394, Loss G: 3.027916
Epoch 2/50 Batch 0/600                   Loss D: 0.437646, Loss G: 0.679266
Epoch 2/50 Batch 400/600                   Loss D: 0.296584, Loss G: 2.031601
Epoch 3/50 Batch 0/600                   Loss D: 0.467400, Loss G: 0.629870
Epoch 3/50 Batch 400/600                   Loss D: 0.259872, Loss G: 1.536625
Epoch 4/50 Batch 0/600                   Loss D: 0.248083, Loss G: 3.128299
Epoch 4/50 Batch 400/600                   Loss D: 0.477719, Loss G: 3.817528
Epoch 5/50 Batch 0/600                   Loss D: 0.312823, Loss G: 1.067905
Epoch 5/50 Batch 400/600                   Loss D: 0.380523, Loss G: 2.907200
Epoch 6/50 Batch 0/600                   Loss D: 0.203087, Loss G: 1.694281


In [12]:
# Persist only the learned parameters/buffers (state_dict) of each network,
# not the pickled module objects themselves.
torch.save(obj=generator.state_dict(), f='generator.pth')
torch.save(obj=discriminator.state_dict(), f='discriminator.pth')

In [None]:
import os
import torchvision.utils as vutils

# Save an 8x8 grid of 64 samples from the trained generator.
os.makedirs('images', exist_ok=True)

# Sample in eval mode: in train mode BatchNorm would normalize with this batch's own
# statistics and overwrite its running statistics (that update happens even under
# no_grad). Restore the previous mode afterwards so the training cell can be re-run.
was_training = generator.training
generator.eval()
try:
    with torch.no_grad():
        z = torch.randn(64, latent_dim, device=device)
        generated_imgs = generator(z)
        # normalize=True rescales the grid to [0, 1] using its min/max before saving.
        vutils.save_image(generated_imgs, 'images/generated_samples.png', nrow=8, normalize=True)
finally:
    generator.train(was_training)