In [3]:
import torch
from pathlib import Path
import torchvision
from torch import nn
from torchvision import transforms
from dataclasses import dataclass
from torch.utils.tensorboard import SummaryWriter  

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = 'cpu'
device

'cuda'

In [5]:
@dataclass
class ModelArgs:
    latent_vector_size = 100
    device = 'cpu'
    batch_size = 64
    lr = 5e-5
    num_classes = 10
    img_size = 64
    no_of_lables = 10
    no_of_channels = 1
    c = 1e-2
    nCritic = 5


In [6]:
ModelArgs.device = device

In [7]:
#Transforms for images
transforms = torchvision.transforms.Compose([
    transforms.Resize(size=(ModelArgs.img_size,ModelArgs.img_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))

])

In [8]:
#Loading MNIST Dataset
import torchvision
from torch.utils.data import DataLoader
import os

data_path = Path('/data/')

# train_dir = data_path / "train"
# test_dir = data_path / "test"

# Load the training set
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=ModelArgs.batch_size, shuffle=True)

# Load the test set
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms)
testloader = torch.utils.data.DataLoader(testset, batch_size=ModelArgs.batch_size, shuffle=False)

In [9]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)  #mean = 0, std = 0.02


In [10]:
class Generator(nn.Module):
    def __init__(
        self,
        latent_vector_size = 100,
        no_of_channels = 1,
        kernel_size = (4,4),
        stride: int = 2,
        number_of_feature_maps: int = 64,
        padding: int = 1,

        img_size: int = 1

    ):

        super().__init__()

        self.dense = nn.Linear(in_features=latent_vector_size, out_features=img_size * img_size, device=ModelArgs.device)
        self.combined_hidden_layer_dimensions = latent_vector_size + ModelArgs.no_of_lables
        self.embedding = nn.Embedding(num_embeddings=ModelArgs.num_classes, embedding_dim=latent_vector_size, device=ModelArgs.device)

        self.img_size = img_size
        self.main = nn.Sequential(



            nn.ConvTranspose2d(ModelArgs.latent_vector_size + ModelArgs.latent_vector_size, number_of_feature_maps * 16 , kernel_size=kernel_size, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(number_of_feature_maps * 16),
            nn.ReLU(),

            #shape = (...,1024, 4, 4)
            nn.ConvTranspose2d(number_of_feature_maps * 16, number_of_feature_maps * 8 , kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(number_of_feature_maps * 8),
            nn.ReLU(),

            #shape = (..., 512, 8, 8)
            nn.ConvTranspose2d(number_of_feature_maps * 8, number_of_feature_maps * 4 , kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(number_of_feature_maps * 4),
            nn.ReLU(),
            

             #shape = (..., 256, 16, 16)
            nn.ConvTranspose2d(number_of_feature_maps * 4, number_of_feature_maps * 2 , kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(number_of_feature_maps * 2),
            nn.ReLU(),
            

             #shape = (..., 128, 32, 32)
            nn.ConvTranspose2d(number_of_feature_maps * 2, no_of_channels , kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.Tanh()
            #shape = (..., 3, 64, 64)
        )

    def forward(self, x, y):
        labels = self.embedding(y)
        # print(x.shape)
        # print(y.shape)
        labels = labels.unsqueeze(2).unsqueeze(3).view(x.shape[0], ModelArgs.latent_vector_size, 1,1)

        combined = torch.cat([x, labels], dim=1)
        # print(combined.shape)
        out = self.main(combined)
        return out

In [166]:
#Intializing the Generator instance
generator = Generator().to(ModelArgs.device)

#Applying the weights transformation
generator.apply(weights_init)

#Printing the structure
print(generator)

Generator(
  (dense): Linear(in_features=100, out_features=1, bias=True)
  (embedding): Embedding(10, 100)
  (main): Sequential(
    (0): ConvTranspose2d(200, 1024, kernel_size=(4, 4), stride=(2, 2), bias=False)
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): ReLU()
    (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): ReLU()
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2

In [51]:
torch.randint(0, 10, (128, 1), dtype=torch.long, device=ModelArgs.device).shape

torch.Size([128, 1])

In [76]:
from torchinfo import summary

random_data = torch.randn(ModelArgs.batch_size, ModelArgs.latent_vector_size, 1, 1, device=ModelArgs.device)
# labels =
labels = torch.randint(0, 10, (64,), dtype=torch.long, device=ModelArgs.device)
random_data = random_data.to(ModelArgs.device)
summary(model=generator,
        
        # input_size=(128, 100, 10, 1, 1),
        input_data=(random_data, labels),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

torch.Size([64, 100, 1, 1])
torch.Size([64])


Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
Generator (Generator)                    [64, 100, 1, 1]      [64, 1, 64, 64]      101                  True
├─Embedding (embedding)                  [64]                 [64, 100]            1,000                True
├─Sequential (main)                      [64, 200, 1, 1]      [64, 1, 64, 64]      --                   True
│    └─ConvTranspose2d (0)               [64, 200, 1, 1]      [64, 1024, 4, 4]     3,276,800            True
│    └─BatchNorm2d (1)                   [64, 1024, 4, 4]     [64, 1024, 4, 4]     2,048                True
│    └─ReLU (2)                          [64, 1024, 4, 4]     [64, 1024, 4, 4]     --                   --
│    └─ConvTranspose2d (3)               [64, 1024, 4, 4]     [64, 512, 8, 8]      8,388,608            True
│    └─BatchNorm2d (4)                   [64, 512, 8, 8]      [64, 512, 8, 8]      1,024                True
│    └─ReLU (5) 

In [11]:
class Critic(nn.Module):
    def __init__(
        self,
        no_of_channels = 1,
        kernel_size = (4,4),
        stride: int = 2,
        number_of_feature_maps: int = 64,
        padding: int = 1,
        lr_slope=0.2,

    ):

        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=ModelArgs.num_classes, embedding_dim=ModelArgs.img_size * ModelArgs.img_size, device=ModelArgs.device)


        self.main = nn.Sequential(
            nn.Conv2d(no_of_channels + 1, number_of_feature_maps * 2 , kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(number_of_feature_maps * 2),
            nn.LeakyReLU(negative_slope=lr_slope),

                #shape = (...,1024, 32, 32)
            nn.Conv2d(number_of_feature_maps * 2, number_of_feature_maps * 4 , kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(number_of_feature_maps * 4),
            nn.LeakyReLU(negative_slope=lr_slope),

                #shape = (..., 512, 16, 16)
            nn.Conv2d(number_of_feature_maps * 4, number_of_feature_maps * 8 , kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(number_of_feature_maps * 8),
            nn.LeakyReLU(negative_slope=lr_slope),

                #shape = (..., 256, 8, 8)
            nn.Conv2d(number_of_feature_maps * 8, number_of_feature_maps * 16 , kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(number_of_feature_maps * 16),
            nn.LeakyReLU(negative_slope=lr_slope),
            #  shape = (..., 128, 4, 4)

            nn.Conv2d(number_of_feature_maps * 16, 1 , kernel_size=kernel_size, stride=4, padding=padding, bias=False),

         )

    def forward(self, x, y):
        
        y = self.embedding(y)
        B,E = y.shape

        combined = torch.concat([x, y.unsqueeze(2).unsqueeze(3).view(x.shape[0], ModelArgs.no_of_channels, ModelArgs.img_size, ModelArgs.img_size)], dim=1)

        x = self.main(combined)

        return x

In [52]:
#Intializing the Discriminator instance
critic = Critic().to(ModelArgs.device)
#Apply the wieght intilization function layer by layer
critic = critic.apply(weights_init)
#Printing the structure
print(critic)

Critic(
  (embedding): Embedding(10, 4096)
  (main): Sequential(
    (0): Conv2d(2, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
    (3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2)
    (9): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): LeakyReLU(negative_slope=0.2)
    (12): Conv2d(1024, 1, kernel_size=(4, 4), stride=(4, 4), 

In [54]:
from torchinfo import summary

images = torch.randn(64, 1, 64, 64)
labels = torch.randint(0, 10, (64,), dtype=torch.long)

summary(model=critic,
        # input_size=(100, 1, 64, 64, 10),
        input_data=(images.to(ModelArgs.device), labels.to(ModelArgs.device)),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
Critic (Critic)                          [64, 1, 64, 64]      [64, 1, 1, 1]        --                   True
├─Embedding (embedding)                  [64]                 [64, 4096]           40,960               True
├─Sequential (main)                      [64, 2, 64, 64]      [64, 1, 1, 1]        --                   True
│    └─Conv2d (0)                        [64, 2, 64, 64]      [64, 128, 32, 32]    4,096                True
│    └─BatchNorm2d (1)                   [64, 128, 32, 32]    [64, 128, 32, 32]    256                  True
│    └─LeakyReLU (2)                     [64, 128, 32, 32]    [64, 128, 32, 32]    --                   --
│    └─Conv2d (3)                        [64, 128, 32, 32]    [64, 256, 16, 16]    524,288              True
│    └─BatchNorm2d (4)                   [64, 256, 16, 16]    [64, 256, 16, 16]    512                  True
│    └─LeakyReLU

In [25]:


generator = Generator().to(ModelArgs.device).apply(weights_init)
critic = Critic().to(ModelArgs.device).apply(weights_init)



epochs = 10000 #30


optimizerC = torch.optim.RMSprop(params=critic.parameters(), lr=ModelArgs.lr) #For discriminator
optimizerG = torch.optim.RMSprop(params=generator.parameters(), lr=ModelArgs.lr) #For generator



real_label = 1
fake_label = 0


loss_g = []
loss_d = []
img_list = []
    

# # Fixed noise for generating the images
fixed_noise = torch.randn((ModelArgs.batch_size, ModelArgs.latent_vector_size, 1, 1), dtype=torch.float32, device=ModelArgs.device)

In [23]:
import shutil
save_images = Path('output_images/MNIST')


In [26]:
#Training loop
from tqdm import tqdm


generator.train()
critic.train()
iters = 0

writer_fake = SummaryWriter(f"logs/fake")
writer_real = SummaryWriter(f"logs/real")

for epoch in tqdm(range(epochs)):
    

    for _ in range(ModelArgs.nCritic):
        
        # count = ModelArgs.batch_size
        # for X, y in trainloader:
        loss_tot = []
        loss_real= torch.tensor(0.0, device=ModelArgs.device)
        loss_fake= torch.tensor(0.0, device=ModelArgs.device)
        data_iter = iter(trainloader)

        X, y = next(data_iter)
        
        X = X.to(ModelArgs.device)
        y = y.to(ModelArgs.device)
        

        #Train the discriminator

        ############################
        # (1) Update D network: maximize: log(1 - D(G(z)))
        ###########################

        current_batch_size = X.shape[0]  #Getting the current batch size

        critic_real = critic(X,y).view(-1)
        
        noise = torch.randn((current_batch_size, ModelArgs.latent_vector_size, 1, 1), device=ModelArgs.device)
        
        
        
        fake_label = torch.randint(0, ModelArgs.no_of_lables, (current_batch_size, ), device=ModelArgs.device)
        
        noise_generated_by_generator = generator(noise, fake_label)

        critic_fake = critic(noise_generated_by_generator, fake_label).view(-1)
        

        loss_d = -(torch.mean(critic_real) - torch.mean(critic_fake))
        loss_d_output = loss_d

        optimizerC.zero_grad()

        # 4. Loss backward
        loss_d.backward(retain_graph=True)
            
        optimizerC.step()
        # loss_d.append(loss_tot.item())

        # Weight Clipping the parameters of the critic between -c to +c
        for p in critic.parameters():
            p.clamp(-(ModelArgs.c), ModelArgs.c)

    ############################
    # (2) Update G network: minimize log(D(G(z)))
    ###########################


    data_iter = iter(trainloader)
    X, y = next(data_iter)
        
    y = y.to(ModelArgs.device)
        
    
    #1. Forward pass

    noise = torch.randn((current_batch_size, ModelArgs.latent_vector_size, 1, 1), device=ModelArgs.device)
    noise_generated_by_generator = generator(noise, y)
    critic_gen = critic(noise_generated_by_generator, y).view(-1)
    
    loss_g = -(torch.mean(critic_gen))
    loss_g_output = loss_g
    # loss_G  += temp
        
        
    # 3. Optimizer zero grad
    optimizerG.zero_grad()

    # 4. Loss backward
    loss_g.backward(retain_graph=True)
        
    optimizerG.step()
    
    # loss_g.append(loss.item())

    if epoch % 100 == 0:
        print("Epoch: ", epoch, "Generator loss: ", loss_g_output.item(), "Discriminator loss: ", loss_d_output.item())


    with torch.no_grad():
        if epoch % 100 == 0:
            

            print('saving the output')
           
            fake = generator(fixed_noise, y)


            img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
            img_grid_real = torchvision.utils.make_grid(X, normalize=True)
                
            writer_fake.add_image(
                        "Mnist Fake Images", img_grid_fake, global_step=epoch
                    )
            writer_real.add_image(
                        "Mnist Real Images", img_grid_real, global_step=epoch
                    )


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 1/10000 [00:00<2:27:37,  1.13it/s]

Epoch:  0 Generator loss:  67.1589584350586 Discriminator loss:  -110.96104431152344
saving the output


  1%|          | 101/10000 [00:52<1:19:51,  2.07it/s]

Epoch:  100 Generator loss:  204.52890014648438 Discriminator loss:  -457.029296875
saving the output


  2%|▏         | 201/10000 [01:47<1:37:34,  1.67it/s]

Epoch:  200 Generator loss:  416.69805908203125 Discriminator loss:  -739.1362915039062
saving the output


  3%|▎         | 301/10000 [02:39<1:36:46,  1.67it/s]

Epoch:  300 Generator loss:  597.2041015625 Discriminator loss:  -1078.3162841796875
saving the output


  4%|▍         | 401/10000 [03:31<1:24:25,  1.90it/s]

Epoch:  400 Generator loss:  728.41796875 Discriminator loss:  -1369.512451171875
saving the output


  5%|▌         | 501/10000 [04:23<1:29:32,  1.77it/s]

Epoch:  500 Generator loss:  1078.882568359375 Discriminator loss:  -2073.3583984375
saving the output


  6%|▌         | 601/10000 [05:15<1:28:26,  1.77it/s]

Epoch:  600 Generator loss:  1270.7742919921875 Discriminator loss:  -2061.97021484375
saving the output


  7%|▋         | 701/10000 [06:06<1:17:21,  2.00it/s]

Epoch:  700 Generator loss:  1588.5242919921875 Discriminator loss:  -3156.283203125
saving the output


  8%|▊         | 801/10000 [06:58<1:28:44,  1.73it/s]

Epoch:  800 Generator loss:  1786.45947265625 Discriminator loss:  -3382.288330078125
saving the output


  9%|▉         | 901/10000 [07:49<1:24:40,  1.79it/s]

Epoch:  900 Generator loss:  2138.578125 Discriminator loss:  -3863.49609375
saving the output


 10%|█         | 1001/10000 [08:42<1:30:12,  1.66it/s]

Epoch:  1000 Generator loss:  2602.70458984375 Discriminator loss:  -4997.04052734375
saving the output


 11%|█         | 1101/10000 [09:33<1:28:45,  1.67it/s]

Epoch:  1100 Generator loss:  1598.454345703125 Discriminator loss:  -5043.9384765625
saving the output


 12%|█▏        | 1201/10000 [10:24<1:24:14,  1.74it/s]

Epoch:  1200 Generator loss:  2841.954833984375 Discriminator loss:  -6090.7978515625
saving the output


 13%|█▎        | 1301/10000 [11:16<1:25:43,  1.69it/s]

Epoch:  1300 Generator loss:  3862.17724609375 Discriminator loss:  -7859.525390625
saving the output


 14%|█▍        | 1401/10000 [12:07<1:22:03,  1.75it/s]

Epoch:  1400 Generator loss:  4215.3173828125 Discriminator loss:  -8505.833984375
saving the output


 15%|█▌        | 1501/10000 [12:59<1:21:41,  1.73it/s]

Epoch:  1500 Generator loss:  4685.98388671875 Discriminator loss:  -9552.9931640625
saving the output


 16%|█▌        | 1601/10000 [13:51<1:16:10,  1.84it/s]

Epoch:  1600 Generator loss:  5138.703125 Discriminator loss:  -10479.14453125
saving the output


 17%|█▋        | 1701/10000 [14:42<1:18:19,  1.77it/s]

Epoch:  1700 Generator loss:  5558.93310546875 Discriminator loss:  -11350.0859375
saving the output


 18%|█▊        | 1801/10000 [15:35<1:19:09,  1.73it/s]

Epoch:  1800 Generator loss:  6047.86474609375 Discriminator loss:  -12342.248046875
saving the output


 19%|█▉        | 1901/10000 [16:29<1:15:00,  1.80it/s]

Epoch:  1900 Generator loss:  -5707.587890625 Discriminator loss:  -6661.7919921875
saving the output


 20%|██        | 2001/10000 [17:19<1:12:23,  1.84it/s]

Epoch:  2000 Generator loss:  7000.421875 Discriminator loss:  -14293.2548828125
saving the output


 21%|██        | 2101/10000 [18:09<1:06:44,  1.97it/s]

Epoch:  2100 Generator loss:  7503.51953125 Discriminator loss:  -15322.8974609375
saving the output


 22%|██▏       | 2201/10000 [19:01<1:12:09,  1.80it/s]

Epoch:  2200 Generator loss:  8010.5458984375 Discriminator loss:  -16369.478515625
saving the output


 23%|██▎       | 2301/10000 [19:52<1:13:43,  1.74it/s]

Epoch:  2300 Generator loss:  8537.9541015625 Discriminator loss:  -17443.369140625
saving the output


 24%|██▍       | 2401/10000 [20:44<39:13,  3.23it/s]  

Epoch:  2400 Generator loss:  9039.408203125 Discriminator loss:  -18484.79296875
saving the output


 25%|██▌       | 2501/10000 [21:35<1:11:27,  1.75it/s]

Epoch:  2500 Generator loss:  9629.451171875 Discriminator loss:  -19679.86328125
saving the output


 26%|██▌       | 2601/10000 [22:26<1:13:11,  1.68it/s]

Epoch:  2600 Generator loss:  10189.755859375 Discriminator loss:  -20827.61328125
saving the output


 27%|██▋       | 2701/10000 [23:19<1:12:02,  1.69it/s]

Epoch:  2700 Generator loss:  10778.8115234375 Discriminator loss:  -22033.80078125
saving the output


 28%|██▊       | 2801/10000 [24:10<1:12:41,  1.65it/s]

Epoch:  2800 Generator loss:  11367.0224609375 Discriminator loss:  -23238.369140625
saving the output


 29%|██▉       | 2901/10000 [25:02<1:07:19,  1.76it/s]

Epoch:  2900 Generator loss:  11982.8837890625 Discriminator loss:  -24500.53125
saving the output


 30%|███       | 3001/10000 [25:52<1:06:36,  1.75it/s]

Epoch:  3000 Generator loss:  12598.3681640625 Discriminator loss:  -25767.787109375
saving the output


 31%|███       | 3101/10000 [26:43<1:04:37,  1.78it/s]

Epoch:  3100 Generator loss:  13218.451171875 Discriminator loss:  -27042.89453125
saving the output


 32%|███▏      | 3201/10000 [27:32<37:31,  3.02it/s]  

Epoch:  3200 Generator loss:  13882.2451171875 Discriminator loss:  -28398.568359375
saving the output


 33%|███▎      | 3301/10000 [28:25<1:03:18,  1.76it/s]

Epoch:  3300 Generator loss:  14552.1630859375 Discriminator loss:  -29772.41796875
saving the output


 34%|███▍      | 3401/10000 [29:17<1:04:10,  1.71it/s]

Epoch:  3400 Generator loss:  15227.623046875 Discriminator loss:  -31148.751953125
saving the output


 35%|███▌      | 3500/10000 [30:09<55:52,  1.94it/s]  

Epoch:  3500 Generator loss:  15902.123046875 Discriminator loss:  -32545.44921875
saving the output


 36%|███▌      | 3601/10000 [30:59<1:00:19,  1.77it/s]

Epoch:  3600 Generator loss:  16620.8125 Discriminator loss:  -34015.43359375
saving the output


 37%|███▋      | 3701/10000 [31:50<57:50,  1.82it/s]  

Epoch:  3700 Generator loss:  17287.13671875 Discriminator loss:  -35379.5703125
saving the output


 38%|███▊      | 3801/10000 [32:41<57:36,  1.79it/s]

Epoch:  3800 Generator loss:  18048.95703125 Discriminator loss:  -36952.953125
saving the output


 39%|███▉      | 3901/10000 [33:31<57:07,  1.78it/s]

Epoch:  3900 Generator loss:  18813.65234375 Discriminator loss:  -38513.5078125
saving the output


 40%|████      | 4001/10000 [34:22<55:29,  1.80it/s]

Epoch:  4000 Generator loss:  19574.291015625 Discriminator loss:  -40075.62890625
saving the output


 41%|████      | 4101/10000 [35:13<53:18,  1.84it/s]

Epoch:  4100 Generator loss:  20356.451171875 Discriminator loss:  -41688.96875
saving the output


 42%|████▏     | 4201/10000 [36:04<55:29,  1.74it/s]

Epoch:  4200 Generator loss:  21139.09375 Discriminator loss:  -43290.4765625
saving the output


 43%|████▎     | 4300/10000 [36:54<49:06,  1.93it/s]

Epoch:  4300 Generator loss:  21491.25 Discriminator loss:  -44003.625
saving the output


 44%|████▍     | 4401/10000 [37:45<52:13,  1.79it/s]

Epoch:  4400 Generator loss:  22757.015625 Discriminator loss:  -46612.18359375
saving the output


 45%|████▌     | 4501/10000 [38:35<52:16,  1.75it/s]

Epoch:  4500 Generator loss:  23599.984375 Discriminator loss:  -48339.3046875
saving the output


 46%|████▌     | 4601/10000 [39:27<50:56,  1.77it/s]

Epoch:  4600 Generator loss:  24441.6328125 Discriminator loss:  -50070.6484375
saving the output


 47%|████▋     | 4701/10000 [40:17<50:32,  1.75it/s]

Epoch:  4700 Generator loss:  25303.22265625 Discriminator loss:  -51840.6953125
saving the output


 48%|████▊     | 4801/10000 [41:07<42:54,  2.02it/s]

Epoch:  4800 Generator loss:  26186.359375 Discriminator loss:  -53655.609375
saving the output


 49%|████▉     | 4901/10000 [41:59<48:04,  1.77it/s]

Epoch:  4900 Generator loss:  27079.03515625 Discriminator loss:  -55488.6015625
saving the output


 50%|█████     | 5001/10000 [42:49<46:06,  1.81it/s]

Epoch:  5000 Generator loss:  27975.033203125 Discriminator loss:  -57323.02734375
saving the output


 51%|█████     | 5101/10000 [43:40<45:37,  1.79it/s]

Epoch:  5100 Generator loss:  28824.767578125 Discriminator loss:  -59096.9765625
saving the output


 52%|█████▏    | 5201/10000 [44:30<37:41,  2.12it/s]

Epoch:  5200 Generator loss:  29834.671875 Discriminator loss:  -61145.171875
saving the output


 53%|█████▎    | 5301/10000 [45:22<44:16,  1.77it/s]

Epoch:  5300 Generator loss:  30639.1875 Discriminator loss:  -62815.71875
saving the output


 54%|█████▍    | 5401/10000 [46:13<41:44,  1.84it/s]

Epoch:  5400 Generator loss:  31702.916015625 Discriminator loss:  -64985.90625
saving the output


 55%|█████▌    | 5501/10000 [47:01<43:26,  1.73it/s]

Epoch:  5500 Generator loss:  32681.916015625 Discriminator loss:  -66997.0234375
saving the output


 56%|█████▌    | 5601/10000 [47:53<40:01,  1.83it/s]

Epoch:  5600 Generator loss:  33151.375 Discriminator loss:  -67748.71875
saving the output


 57%|█████▋    | 5701/10000 [48:45<37:10,  1.93it/s]

Epoch:  5700 Generator loss:  34667.9765625 Discriminator loss:  -71073.8046875
saving the output


 58%|█████▊    | 5801/10000 [49:35<38:55,  1.80it/s]

Epoch:  5800 Generator loss:  35712.328125 Discriminator loss:  -73214.46875
saving the output


 59%|█████▉    | 5901/10000 [50:25<36:20,  1.88it/s]

Epoch:  5900 Generator loss:  36743.8125 Discriminator loss:  -75336.328125
saving the output


 60%|██████    | 6001/10000 [51:16<36:46,  1.81it/s]

Epoch:  6000 Generator loss:  37794.08203125 Discriminator loss:  -77495.5234375
saving the output


 61%|██████    | 6101/10000 [52:07<38:13,  1.70it/s]

Epoch:  6100 Generator loss:  38779.33984375 Discriminator loss:  -79549.8125
saving the output


 62%|██████▏   | 6201/10000 [52:58<35:29,  1.78it/s]

Epoch:  6200 Generator loss:  39905.953125 Discriminator loss:  -81844.0
saving the output


 63%|██████▎   | 6301/10000 [53:48<34:07,  1.81it/s]

Epoch:  6300 Generator loss:  41021.140625 Discriminator loss:  -84130.3515625
saving the output


 64%|██████▍   | 6401/10000 [54:38<32:40,  1.84it/s]

Epoch:  6400 Generator loss:  42101.421875 Discriminator loss:  -86357.2734375
saving the output


 65%|██████▌   | 6501/10000 [55:29<32:23,  1.80it/s]

Epoch:  6500 Generator loss:  43191.95703125 Discriminator loss:  -88595.03125
saving the output


 66%|██████▌   | 6601/10000 [56:19<31:17,  1.81it/s]

Epoch:  6600 Generator loss:  44357.421875 Discriminator loss:  -90983.421875
saving the output


 67%|██████▋   | 6701/10000 [57:10<30:31,  1.80it/s]

Epoch:  6700 Generator loss:  45070.3671875 Discriminator loss:  -92392.5
saving the output


 68%|██████▊   | 6801/10000 [58:00<29:25,  1.81it/s]

Epoch:  6800 Generator loss:  46594.84375 Discriminator loss:  -95627.453125
saving the output


 69%|██████▉   | 6901/10000 [58:50<26:35,  1.94it/s]

Epoch:  6900 Generator loss:  47814.609375 Discriminator loss:  -98094.1953125
saving the output


 70%|███████   | 7001/10000 [59:41<28:10,  1.77it/s]

Epoch:  7000 Generator loss:  48989.390625 Discriminator loss:  -100522.3125
saving the output


 71%|███████   | 7101/10000 [1:00:31<26:46,  1.80it/s]

Epoch:  7100 Generator loss:  49934.7890625 Discriminator loss:  -102638.25
saving the output


 72%|███████▏  | 7201/10000 [1:01:22<25:30,  1.83it/s]

Epoch:  7200 Generator loss:  51398.70703125 Discriminator loss:  -105459.2265625
saving the output


 73%|███████▎  | 7301/10000 [1:02:13<24:55,  1.80it/s]

Epoch:  7300 Generator loss:  52641.48046875 Discriminator loss:  -108016.421875
saving the output


 74%|███████▍  | 7401/10000 [1:03:03<22:59,  1.88it/s]

Epoch:  7400 Generator loss:  53880.4921875 Discriminator loss:  -110566.46875
saving the output


 75%|███████▌  | 7501/10000 [1:03:54<23:30,  1.77it/s]

Epoch:  7500 Generator loss:  55066.84375 Discriminator loss:  -113025.3046875
saving the output


 76%|███████▌  | 7601/10000 [1:04:45<23:06,  1.73it/s]

Epoch:  7600 Generator loss:  56400.3046875 Discriminator loss:  -115742.421875
saving the output


 77%|███████▋  | 7701/10000 [1:05:36<11:54,  3.22it/s]

Epoch:  7700 Generator loss:  57673.69921875 Discriminator loss:  -118361.40625
saving the output


 78%|███████▊  | 7801/10000 [1:06:27<20:39,  1.77it/s]

Epoch:  7800 Generator loss:  58991.17578125 Discriminator loss:  -121071.234375
saving the output


 79%|███████▉  | 7901/10000 [1:07:17<19:06,  1.83it/s]

Epoch:  7900 Generator loss:  60305.03515625 Discriminator loss:  -123776.484375
saving the output


 80%|████████  | 8001/10000 [1:08:09<19:05,  1.74it/s]

Epoch:  8000 Generator loss:  61640.29296875 Discriminator loss:  -126518.59375
saving the output


 81%|████████  | 8101/10000 [1:08:59<18:05,  1.75it/s]

Epoch:  8100 Generator loss:  62981.32421875 Discriminator loss:  -129279.609375
saving the output


 82%|████████▏ | 8201/10000 [1:09:50<16:29,  1.82it/s]

Epoch:  8200 Generator loss:  64325.0546875 Discriminator loss:  -132045.734375
saving the output


 83%|████████▎ | 8301/10000 [1:10:41<16:20,  1.73it/s]

Epoch:  8300 Generator loss:  65245.69921875 Discriminator loss:  -132941.328125
saving the output


 84%|████████▍ | 8401/10000 [1:11:32<14:56,  1.78it/s]

Epoch:  8400 Generator loss:  67064.375 Discriminator loss:  -137686.046875
saving the output


 85%|████████▌ | 8501/10000 [1:12:22<14:07,  1.77it/s]

Epoch:  8500 Generator loss:  68362.515625 Discriminator loss:  -140413.75
saving the output


 86%|████████▌ | 8601/10000 [1:13:12<08:44,  2.67it/s]

Epoch:  8600 Generator loss:  69765.2578125 Discriminator loss:  -143307.28125
saving the output


 87%|████████▋ | 8701/10000 [1:14:03<12:28,  1.73it/s]

Epoch:  8700 Generator loss:  71225.046875 Discriminator loss:  -146252.875
saving the output


 88%|████████▊ | 8801/10000 [1:14:54<11:40,  1.71it/s]

Epoch:  8800 Generator loss:  72758.875 Discriminator loss:  -149379.125
saving the output


 89%|████████▉ | 8901/10000 [1:15:44<10:12,  1.80it/s]

Epoch:  8900 Generator loss:  74144.9375 Discriminator loss:  -152291.921875
saving the output


 90%|█████████ | 9001/10000 [1:16:34<10:05,  1.65it/s]

Epoch:  9000 Generator loss:  75522.3125 Discriminator loss:  -155115.84375
saving the output


 91%|█████████ | 9101/10000 [1:17:25<07:53,  1.90it/s]

Epoch:  9100 Generator loss:  76614.421875 Discriminator loss:  -157381.5
saving the output


 92%|█████████▏| 9201/10000 [1:18:16<07:36,  1.75it/s]

Epoch:  9200 Generator loss:  78577.4609375 Discriminator loss:  -161387.09375
saving the output


 93%|█████████▎| 9301/10000 [1:19:07<06:33,  1.78it/s]

Epoch:  9300 Generator loss:  80139.3359375 Discriminator loss:  -164577.3125
saving the output


 94%|█████████▍| 9401/10000 [1:19:58<05:32,  1.80it/s]

Epoch:  9400 Generator loss:  81653.078125 Discriminator loss:  -167679.5625
saving the output


 95%|█████████▌| 9501/10000 [1:20:48<04:53,  1.70it/s]

Epoch:  9500 Generator loss:  83221.0546875 Discriminator loss:  -170894.03125
saving the output


 96%|█████████▌| 9601/10000 [1:21:38<03:39,  1.82it/s]

Epoch:  9600 Generator loss:  84771.6640625 Discriminator loss:  -174096.171875
saving the output


 97%|█████████▋| 9701/10000 [1:22:30<02:45,  1.80it/s]

Epoch:  9700 Generator loss:  86342.6328125 Discriminator loss:  -177325.359375
saving the output


 98%|█████████▊| 9801/10000 [1:23:19<01:53,  1.76it/s]

Epoch:  9800 Generator loss:  87919.46875 Discriminator loss:  -180583.71875
saving the output


 99%|█████████▉| 9900/10000 [1:24:10<00:51,  1.93it/s]

Epoch:  9900 Generator loss:  89503.7421875 Discriminator loss:  -183836.8125
saving the output


100%|██████████| 10000/10000 [1:25:00<00:00,  1.96it/s]


In [12]:
#For MNIST
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import numpy as np

plt.figure(figsize=(10,5))
plt.title("Generator and Critic Loss During Training")
plt.plot(loss_g,label="G")
plt.plot(loss_d,label="D")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

