In [None]:
import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriter
# Pick the best available accelerator: Apple-silicon MPS, then CUDA, otherwise CPU,
# so the notebook still runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
def show_tensor_images(image_tensor, num_images=32, size=(1, 64, 64), title='Fake Images', nrow=5):
    """Plot the first ``num_images`` images of a batch as a single grid figure.

    Args:
        image_tensor: tensor whose elements can be reshaped to (-1, *size). It may live on
            any device and may require grad; it is detached and copied to CPU first.
            Pixel values are expected in [-1, 1], as produced by the generator's Tanh and
            by the Normalize(0.5, 0.5) transform used in this notebook.
        num_images: number of images, from the start of the batch, to draw.
        size: per-image (channels, height, width) used to un-flatten the tensor.
        title: figure title.
        nrow: number of images per grid row.
    """
    image_unflat = image_tensor.detach().cpu().reshape(-1, *size)
    # Map [-1, 1] -> [0, 1]. make_grid replicates 1-channel images to 3 channels, and
    # imshow clips float RGB data to [0, 1], so without this the darker half of the
    # range would be flattened to black.
    image_unflat = ((image_unflat + 1) / 2).clamp(0, 1)
    image_grid = torchvision.utils.make_grid(image_unflat[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.title(title)
    plt.show()

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an image batch to one score in (0, 1) per image.

    Every convolution has stride 2. For 64x64 inputs the spatial size goes
    64 -> 32 -> 16 -> 8 -> 4, and the final 4x4 convolution (no padding) leaves a
    (N, 1, 1, 1) output that is squashed by a sigmoid.

    Args:
        image_channels: number of channels of the input images.
        features_d: width of the first conv layer; each following block doubles it.
    """

    def __init__(self, image_channels, features_d):
        super().__init__()
        # Channel widths d, 2d, 4d, 8d.
        widths = [features_d * (2 ** k) for k in range(4)]

        # The input layer is a plain conv + LeakyReLU (no BatchNorm).
        layers = [
            nn.Conv2d(image_channels, widths[0], 4, 2, 1),
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(self._block(c_in, c_out, 4, 2, 1))
        layers += [
            nn.Conv2d(widths[-1], 1, 4, 2, 0),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv (bias disabled, BatchNorm follows) -> BatchNorm2d -> LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2))

    def forward(self, x):
        """Return the per-image scores for the batch ``x``."""
        scores = self.disc(x)
        return scores

In [None]:
# Shape check: a batch of 64x64 single-channel images should give one score per image.
disc = Discriminator(1, 64).to(device)
x = torch.randn(32, 1, 64, 64).to(device)

with torch.no_grad():  # inference only; no autograd graph needed
    out = disc(x)
assert out.shape == (32, 1, 1, 1), out.shape
out.shape

In [None]:
class Generator(nn.Module):
    """DCGAN generator: upsamples latent noise into images.

    With a (N, z_dim, 1, 1) input the spatial size goes 1 -> 4 (first block,
    stride 1, no padding) -> 8 -> 16 -> 32 -> 64. The channel widths go
    16g -> 8g -> 4g -> 2g -> image_channels, and the Tanh output lies in [-1, 1].

    Args:
        z_dim: number of latent channels of the input noise.
        image_channels: number of channels of the generated images.
        features_g: base width g of the transposed-conv stack.
    """

    def __init__(self, z_dim, image_channels, features_g):
        super().__init__()
        widths = [features_g * mult for mult in (16, 8, 4, 2)]

        layers = [self._block(z_dim, widths[0], 4, 1, 0)]
        layers.extend(self._block(c_in, c_out, 4, 2, 1) for c_in, c_out in zip(widths, widths[1:]))
        layers.append(nn.ConvTranspose2d(widths[-1], image_channels, 4, 2, 1))
        layers.append(nn.Tanh())
        self.gen = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """ConvTranspose (bias disabled, BatchNorm follows) -> BatchNorm2d -> ReLU."""
        upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        return nn.Sequential(upconv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Map the noise batch ``x`` to generated images."""
        images = self.gen(x)
        return images

In [None]:
# Shape check: latent noise (N, 100, 1, 1) should become (N, 1, 64, 64) images.
gen = Generator(100, 1, 64).to(device)
x = torch.randn(32, 100, 1, 1).to(device)

with torch.no_grad():  # inference only; `out` is plotted in a later cell
    out = gen(x)
assert out.shape == (32, 1, 64, 64), out.shape
out.shape

In [None]:
def initialize_weights(model, std=0.02):
    """Apply DCGAN-style initialisation, in place, to every conv / batch-norm layer of ``model``.

    - Conv2d / ConvTranspose2d weights ~ N(0, std); their biases are left untouched.
    - BatchNorm2d scale (weight) ~ N(1, std) and shift (bias) = 0, so each BN layer
      starts close to the identity. Drawing the scale around 0 instead would multiply
      every normalised activation by roughly 0.

    Args:
        model: any nn.Module; submodules are visited recursively via ``modules()``.
        std: standard deviation of the normal initialisation (default 0.02).

    Returns:
        None. ``model`` is modified in place.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            # nn.init functions already run under no_grad, so `.data` is unnecessary.
            nn.init.normal_(m.weight, 0.0, std)
        elif isinstance(m, nn.BatchNorm2d) and m.affine:
            nn.init.normal_(m.weight, 1.0, std)
            nn.init.zeros_(m.bias)

In [None]:
import matplotlib.pyplot as plt

# Visual check of one sample from the untrained generator: drop the channel axis so
# imshow receives a plain (H, W) grayscale array.
plt.imshow(out[0].detach().cpu().squeeze(0), cmap='gray')
plt.title('Untrained generator sample');

In [None]:
# Training configuration.
lr = 2e-4            # Adam learning rate (shared by both networks)
batch_size = 128
image_size = 64      # images are resized to image_size x image_size
image_channels = 1   # MNIST is single-channel
features = 64        # base conv width for both generator and discriminator
z_dim = 100          # latent noise channels fed to the generator
n_epochs = 10
seed = 0

# Seed torch's global RNG so weight init, noise sampling and shuffling are reproducible.
torch.manual_seed(seed)

In [None]:
import torchvision

# Resize to the model resolution, convert to a tensor, then apply (x - 0.5) / 0.5 per
# channel. This maps ToTensor's [0, 1] range onto [-1, 1], matching the generator's Tanh output.
norm_mean = [0.5] * image_channels
norm_std = [0.5] * image_channels
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(image_size),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(norm_mean, norm_std),
])

In [None]:
# MNIST training split, stored under Datasets/ (downloaded if missing) and preprocessed with `transforms`.
dataset = torchvision.datasets.MNIST(
    root="Datasets/",
    train=True,
    transform=transforms,
    download=True,
)

In [None]:
# Reshuffled every epoch; the last batch may be smaller than batch_size (drop_last is off).
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Sanity check: show the first image of one shuffled batch, as (H, W, C) for imshow.
first_batch_images, _ = next(iter(loader))
plt.imshow(first_batch_images[0].permute(1, 2, 0), cmap='gray')

In [None]:
# Build both networks on the target device, then apply the custom weight init.
disc = Discriminator(image_channels, features).to(device)
gen = Generator(z_dim, image_channels, features).to(device)
for net in (disc, gen):
    initialize_weights(net)

# Both players use the same Adam settings.
adam_kwargs = dict(lr=lr, betas=(0.5, 0.999))
opt_disc = torch.optim.Adam(disc.parameters(), **adam_kwargs)
opt_gen = torch.optim.Adam(gen.parameters(), **adam_kwargs)
criterion = nn.BCELoss()

# One fixed latent batch, reused for every progress snapshot so samples stay comparable.
fixed_noise = torch.randn((32, z_dim, 1, 1)).to(device)

In [None]:
# Put both networks in training mode (BatchNorm uses batch statistics).
# The loop form keeps the cell from echoing the full module repr.
for net in (gen, disc):
    net.train()

In [None]:
import tqdm

step = 0  # number of visualisation snapshots taken (one every 150 batches)

for epoch in range(n_epochs):
    for batch_idx, (real, _) in enumerate(tqdm.tqdm(loader)):
        real = real.to(device)
        # Size the noise from the real batch: the last batch of an epoch can be
        # smaller than batch_size.
        cur_batch_size = real.size(0)
        noise = torch.randn((cur_batch_size, z_dim, 1, 1)).to(device)
        fake = gen(noise)

        # --- Discriminator step: real -> label 1, fake -> label 0 ---
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        # Detach so this backward pass stops at the generator output. The generator graph
        # is neither traversed nor retained here, and no gradients reach gen's parameters.
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (loss_disc_real + loss_disc_fake) / 2.0

        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # --- Generator step: non-saturating loss, push D(G(z)) toward 1 ---
        # Re-run the just-updated discriminator on the non-detached fakes.
        output = disc(fake).reshape(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx % 150 == 0:
            image_shape = (image_channels, image_size, image_size)
            with torch.no_grad():
                # NOTE(review): gen is in train mode here, so this forward pass also
                # updates BatchNorm running stats from fixed_noise.
                fake = gen(fixed_noise).reshape(-1, *image_shape)
                data = real.reshape(-1, *image_shape)
                show_tensor_images(fake, num_images=25, size=image_shape)
                show_tensor_images(data, num_images=25, size=image_shape, title='Real Images')
            step += 1

        if batch_idx % 300 == 0:
            print(
                f"Epoch [{epoch+1}/{n_epochs}], Step {step} \nDiscriminator Loss: {lossD.item():.4f}, Generator Loss: {lossG.item():.4f}"
            )

In [None]:
# Draw a fresh batch from the trained generator; plotting needs no autograd graph.
# NOTE(review): gen is still in train mode, so BatchNorm uses this batch's statistics;
# call gen.eval() first if the running statistics should be used instead.
with torch.no_grad():
    z = torch.randn(25, z_dim, 1, 1).to(device)
    out = gen(z)
show_tensor_images(out, num_images=25)
# `real` is the last real batch seen by the training loop.
show_tensor_images(real, num_images=25, title='Real Images')