## Imports


In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
from torch import nn, optim
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

## Dataset

In [None]:
# Preprocessing for every MNIST sample: ToTensor converts the image to a
# float tensor, then Normalize computes (x - 0.5) / 0.5, centring pixels
# around zero so real images share the [-1, 1] range of the generator's
# tanh output.
transformer = transforms.Compose(
    transforms=[
        transforms.ToTensor(),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)

In [None]:
# MNIST training split stored under ./data (fetched when missing), with each
# image run through the `transformer` pipeline.
dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transformer,
    download=True,
)

In [None]:
# Mini-batches of 32 samples, reshuffled at the start of every epoch.
loader = DataLoader(dataset, shuffle=True, batch_size=32)

In [None]:
# A single batch of 32 latent vectors (100 dims each), drawn once and reused
# at every logging step so the TensorBoard grids track the same inputs.
fixed_noise = torch.randn(32, 100)

## Model

In [None]:
class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to flattened images.

    Two hidden stages of Linear -> LeakyReLU(0.2) -> BatchNorm1d (256, then
    512 units) are followed by Linear -> tanh. Outputs therefore lie in
    [-1, 1].

    Args:
        latent_dim: size of each input noise vector (default 100).
        img_dim: size of each flattened output image (default 784 = 28 * 28).
    """

    def __init__(self, latent_dim=100, img_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(in_features=latent_dim, out_features=256)
        # BatchNorm1d's second positional parameter is `eps`, not `momentum`,
        # so the value is passed by keyword to make that explicit.
        # NOTE(review): eps=0.2 is far above the default (1e-5); confirm it
        # was not intended as a momentum value.
        self.bn1 = nn.BatchNorm1d(256, eps=0.2)
        self.fc2 = nn.Linear(in_features=256, out_features=512)
        self.bn2 = nn.BatchNorm1d(512, eps=0.2)
        self.fc3 = nn.Linear(in_features=512, out_features=img_dim)

    def forward(self, x):
        """Map noise of shape (batch, latent_dim) to images (batch, img_dim).

        In training mode BatchNorm1d uses per-batch statistics, so the batch
        must contain more than one sample.
        """
        x = F.leaky_relu(self.fc1(x), negative_slope=0.2)
        x = self.bn1(x)

        x = F.leaky_relu(self.fc2(x), negative_slope=0.2)
        x = self.bn2(x)

        # tanh bounds the output to [-1, 1].
        return torch.tanh(self.fc3(x))


In [None]:
class Discriminator(nn.Module):
    """MLP discriminator scoring images as real (towards 1) or fake (towards 0).

    Two LeakyReLU(0.2) hidden layers (512, then 256 units) are followed by a
    single sigmoid unit. Each score is therefore a probability in [0, 1],
    the form expected by ``nn.BCELoss``.

    Args:
        img_dim: number of features per flattened image (default 784 = 28 * 28).
    """

    def __init__(self, img_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(in_features=img_dim, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=1)

    def forward(self, x):
        """Return scores of shape (batch, 1).

        Accepts flattened input (batch, img_dim) or unflattened image
        batches such as (batch, 1, 28, 28). Inputs with more than two
        dimensions are flattened after the batch dimension. 1-D and 2-D
        inputs are used as-is.
        """
        if x.dim() > 2:
            x = x.reshape(x.size(0), -1)

        x = F.leaky_relu(self.fc1(x), negative_slope=0.2)
        x = F.leaky_relu(self.fc2(x), negative_slope=0.2)

        return torch.sigmoid(self.fc3(x))

In [None]:
# Instantiate both networks with freshly initialized weights
# (generator first, then discriminator).
generator, discriminator = Generator(), Discriminator()

## Train

In [None]:
# Training hyperparameters, optimizers, losses and TensorBoard logging state.
EPOCHS = 100

# One Adam optimizer per network, so each step only updates that network.
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0003)
optimizer_G = optim.Adam(generator.parameters(), lr=0.0003)

# Both players use binary cross-entropy on the discriminator's sigmoid scores.
criterion_D = nn.BCELoss()
criterion_G = nn.BCELoss()

# Separate run directories for generated and real image grids. `step` is the
# global step used when adding images to TensorBoard.
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0

In [None]:
# GAN training loop. Each batch performs one discriminator update followed by
# one generator update. On the first batch of every epoch the losses are
# printed and image grids are logged to TensorBoard.
for epoch in range(EPOCHS):

    for i, (real_images, _) in enumerate(loader):
        batch_size = real_images.size(0)

        # Training discriminator
        optimizer_D.zero_grad()

        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Flatten each image to a vector for the MLP discriminator.
        real_images = real_images.view(batch_size, -1)

        noise = torch.randn(batch_size, 100)
        fake_images = generator(noise)

        real_outputs = discriminator(real_images)
        # detach() keeps the discriminator loss from backpropagating into G.
        fake_outputs = discriminator(fake_images.detach())

        real_loss = criterion_D(real_outputs, real_labels)
        fake_loss = criterion_D(fake_outputs, fake_labels)

        loss_d = (real_loss + fake_loss) / 2

        loss_d.backward()
        optimizer_D.step()

        # Training generator: re-score the same fakes with the just-updated
        # discriminator and push its output towards the "real" label.
        optimizer_G.zero_grad()

        fake_outputs = discriminator(fake_images)

        loss_g = criterion_G(fake_outputs, real_labels)

        loss_g.backward()
        optimizer_G.step()

        if i == 0:
            # Implicit string concatenation keeps the message on one line.
            # A backslash continuation inside the literal would embed the
            # next line's indentation in the printed text.
            print(
                f"Epoch [{epoch}/{EPOCHS}] Batch {i}/{len(loader)} "
                f"Loss D: {loss_d.item():.4f}, loss G: {loss_g.item():.4f}"
            )

            # Sample in eval mode so BatchNorm uses its running statistics
            # and this visualization pass does not update them.
            generator.eval()
            with torch.no_grad():
                sample_images = generator(fixed_noise).reshape(-1, 1, 28, 28)
                data = real_images.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(sample_images, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1
            generator.train()

# Flush any pending TensorBoard events and close the event files.
writer_fake.close()
writer_real.close()

In [None]:
# Persist both trained networks. torch.save on a module pickles the whole
# object, so loading it later needs the Generator/Discriminator classes to be
# importable. NOTE(review): saving `.state_dict()` would be more portable if
# the consumers of these files can be updated.
FILE_GEN = 'generator_model.pth'
FILE_DISC = 'discriminator.pth'
for model, path in ((generator, FILE_GEN), (discriminator, FILE_DISC)):
    torch.save(model, path)