In [None]:
import torch
import torch.nn  as nn
import torchvision
import torch.optim as optim
import numpy as np
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler
from torchvision import transforms 
import random
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

In [None]:
class Network(nn.Module):
    """Fully connected network used for both the GAN generator and discriminator.

    Architecture: ``n_layers`` hidden blocks of Linear -> ReLU -> BatchNorm1d,
    followed by a Linear output layer and a squashing activation.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the output layer.
        n_layers: Number of hidden blocks to build (may be 0).
        hidden_dim: Hidden widths; entry ``i`` is the width of block ``i``. Must
            hold at least ``n_layers`` entries (an IndexError is raised otherwise);
            extra entries are ignored.
        flag: If True the output activation is Sigmoid (values in (0, 1));
            otherwise it is Tanh (values in (-1, 1)).
    """

    def __init__(self, input_dim, output_dim, n_layers, hidden_dim, flag):
        super().__init__()
        layers = []
        in_dim = input_dim
        for i in range(n_layers):
            width = hidden_dim[i]
            layers.append(nn.Linear(in_dim, width))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(width))
            in_dim = width

        # The output layer must attach to the width of the last layer actually
        # built (in_dim), not hidden_dim[-1]: the two differ when
        # n_layers < len(hidden_dim), and hidden_dim[-1] raises for an empty list.
        head = nn.Sigmoid() if flag else nn.Tanh()
        self.network = nn.Sequential(*layers, nn.Linear(in_dim, output_dim), head)

    def forward(self, x):
        """Apply the network to ``x``.

        Expects ``x`` of shape ``(batch, input_dim)``. In training mode the
        BatchNorm1d layers need a batch size greater than 1.
        """
        return self.network(x)
    
    

In [None]:
# Discriminator: flattened 28x28 image (784 values) -> 1 output squashed by Sigmoid.
discriminator = Network(input_dim=784, output_dim=1, n_layers=2, hidden_dim=[32, 32], flag=True)
# Generator: 64-dim noise vector -> 784 pixel values squashed by Tanh.
generator = Network(input_dim=64, output_dim=784, n_layers=2, hidden_dim=[32, 32], flag=False)

# One Adam optimizer per network, same learning rate for both.
D_optim = optim.Adam(params=discriminator.parameters(), lr=3e-4)
G_optim = optim.Adam(params=generator.parameters(), lr=3e-4)

# Binary cross-entropy on the discriminator's Sigmoid output.
loss = nn.BCELoss()

In [None]:
# Show both architectures, one per line.
print(discriminator.network, generator.network, sep="\n")

# A single latent batch reused whenever generated images are logged,
# so successive image grids are directly comparable.
fixed_noise = np.random.normal(size=(32, 64))

# Separate TensorBoard writers for generated and real image grids.
writer_fake, writer_real = SummaryWriter("logs/fake"), SummaryWriter("logs/real")
step = 0  # global_step counter for the TensorBoard images

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 then maps
# them to [-1, 1], the same range as the generator's Tanh output.
# Normalize takes one mean/std per channel: MNIST has a single channel, so the
# values are 1-tuples. A bare `(0.5)` is just a float, not a tuple.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = datasets.MNIST(root="./", train=True, download=True, transform=transform)

In [None]:
# Train the GAN with k discriminator updates per generator update.
# Uses globals from earlier cells: dataset, generator, discriminator, D_optim,
# G_optim, loss, fixed_noise, writer_fake, writer_real and step.
k = 1  # discriminator updates per generator update
batch_size = 32
latent_dim = 64  # must equal the generator's input_dim
n_iters = 100000
log_every = 100  # print losses / write TensorBoard images every N iterations

save_sample = []  # detached fake batches, appended whenever G_loss hits a new best
best_g_loss = float("inf")

# With replacement=True each pass over the sampler draws a fresh random batch
# of indices, so one sampler can be reused for every iteration.
sampler = RandomSampler(dataset, replacement=True, num_samples=batch_size)
real_labels = torch.ones((batch_size, 1))
fake_labels = torch.zeros((batch_size, 1))

for e in range(n_iters):
    for _ in range(k):
        real_sample = torch.stack([dataset[idx][0] for idx in sampler])

        noise = torch.randn(batch_size, latent_dim)
        fake_sample = generator(noise).reshape(batch_size, 1, 28, 28)

        # Detach the fakes for the discriminator step: its loss must not
        # backpropagate into the generator, so no graph needs to be retained.
        D_loss = (
            loss(discriminator(torch.flatten(real_sample, start_dim=1)), real_labels)
            + loss(discriminator(torch.flatten(fake_sample.detach(), start_dim=1)), fake_labels)
        )

        D_optim.zero_grad()
        D_loss.backward()
        D_optim.step()

    # Generator step: train the generator so that the freshly updated
    # discriminator labels the last fake batch as real.
    G_loss = loss(discriminator(torch.flatten(fake_sample, start_dim=1)), real_labels)

    G_optim.zero_grad()
    G_loss.backward()
    G_optim.step()

    # Store a Python float and a detached tensor so old autograd graphs are
    # freed instead of accumulating for the whole run.
    g_loss_value = G_loss.item()
    if g_loss_value < best_g_loss:
        best_g_loss = g_loss_value
        save_sample.append(fake_sample.detach())

    if e % log_every == 0:
        print("D_loss", D_loss.item())
        print("G_loss", g_loss_value)

        # Writing two image grids on every iteration bloats the event files,
        # so images are logged on the same cadence as the printed losses.
        # NOTE: the generator stays in train mode here, so its BatchNorm layers
        # still update their running statistics on fixed_noise under no_grad.
        with torch.no_grad():
            fake = generator(torch.tensor(fixed_noise, dtype=torch.float)).reshape(-1, 1, 28, 28)
            data = real_sample.reshape(-1, 1, 28, 28)
            img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
            img_grid_real = torchvision.utils.make_grid(data, normalize=True)

            writer_fake.add_image(
                "Mnist Fake Images", img_grid_fake, global_step=step
            )
            writer_real.add_image(
                "Mnist Real Images", img_grid_real, global_step=step
            )
            step += 1
        


