<a href="https://colab.research.google.com/github/SisekoC/My-Notebooks/blob/main/GAN_Pytorch_Fashion_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!mkdir diff-run

In [2]:
!mkdir diff-run/images

In [3]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import datetime

In [4]:
torch.manual_seed(1)

<torch._C.Generator at 0x7e5bc82bc2b0>

In [5]:
writer = SummaryWriter('diff-run/py-gan')

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128

In [7]:
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])])
train_dataset = datasets.FashionMNIST(root='./data/', train=True, transform=train_transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:08<00:00, 2956940.53it/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 200453.52it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:03<00:00, 1443285.82it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 6169222.00it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw






In [8]:
image_shape = (1, 28, 28)
image_dim = int(np.prod(image_shape))
latent_dim = 100

In [9]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.model = nn.Sequential(nn.Linear(latent_dim, 128),
                                    nn.LeakyReLU(0.2, inplace=True),
                                    nn.Linear(128, 256),
                      nn.BatchNorm1d(256, 0.8),
                                    nn.LeakyReLU(0.2, inplace=True),
                                    nn.Linear(256, 512),
                      nn.BatchNorm1d(512, 0.8),
                                    nn.LeakyReLU(0.2, inplace=True),
                      nn.Linear(512, 1024),
                      nn.BatchNorm1d(1024, 0.8),
                                    nn.LeakyReLU(0.2, inplace=True),
                                    nn.Linear(1024, image_dim),
                                    nn.Tanh())

    def forward(self, noise_vector):
        image = self.model(noise_vector)
        image = image.view(image.size(0), *image_shape)
        return image




In [10]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(nn.Linear(image_dim, 512),
                                    nn.LeakyReLU(0.2, inplace=True),
                                    nn.Linear(512, 256),
                                    nn.LeakyReLU(0.2, inplace=True),
                                    nn.Linear(256, 1),
                                    nn.Sigmoid())

    def forward(self, image):
        image_flattened = image.view(image.size(0), -1)
        result = self.model(image_flattened)
        return result

In [11]:
generator = Generator().to(device)
discriminator = Discriminator().to(device)

In [12]:
torch.save(generator.state_dict(), 'generator.pth')

In [13]:
 for layer in generator.children():
     print(layer.type)

<bound method Module.type of Sequential(
  (0): Linear(in_features=100, out_features=128, bias=True)
  (1): LeakyReLU(negative_slope=0.2, inplace=True)
  (2): Linear(in_features=128, out_features=256, bias=True)
  (3): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (4): LeakyReLU(negative_slope=0.2, inplace=True)
  (5): Linear(in_features=256, out_features=512, bias=True)
  (6): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (7): LeakyReLU(negative_slope=0.2, inplace=True)
  (8): Linear(in_features=512, out_features=1024, bias=True)
  (9): BatchNorm1d(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (10): LeakyReLU(negative_slope=0.2, inplace=True)
  (11): Linear(in_features=1024, out_features=784, bias=True)
  (12): Tanh()
)>


In [14]:
summary(generator, (100,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 128]          12,928
         LeakyReLU-2                  [-1, 128]               0
            Linear-3                  [-1, 256]          33,024
       BatchNorm1d-4                  [-1, 256]             512
         LeakyReLU-5                  [-1, 256]               0
            Linear-6                  [-1, 512]         131,584
       BatchNorm1d-7                  [-1, 512]           1,024
         LeakyReLU-8                  [-1, 512]               0
            Linear-9                 [-1, 1024]         525,312
      BatchNorm1d-10                 [-1, 1024]           2,048
        LeakyReLU-11                 [-1, 1024]               0
           Linear-12                  [-1, 784]         803,600
             Tanh-13                  [-1, 784]               0
Total params: 1,510,032
Trainable param

In [15]:
summary(discriminator, (1,28,28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 512]         401,920
         LeakyReLU-2                  [-1, 512]               0
            Linear-3                  [-1, 256]         131,328
         LeakyReLU-4                  [-1, 256]               0
            Linear-5                    [-1, 1]             257
           Sigmoid-6                    [-1, 1]               0
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 2.04
Estimated Total Size (MB): 2.05
----------------------------------------------------------------


In [16]:
adversarial_loss = nn.BCELoss()

In [17]:
learning_rate = 0.0002
G_optimizer = optim.Adam(generator.parameters(), lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

In [18]:
cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [19]:
num_epochs = 500
D_loss_plot, G_loss_plot = [], []
for epoch in range(1, num_epochs+1):

    D_loss_list, G_loss_list = [], []

    for index, (real_images, _) in enumerate(train_loader):
        D_optimizer.zero_grad()
        real_images = real_images.to(device)
        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))

        D_real_loss = adversarial_loss(discriminator(real_images), real_target)
        # print(discriminator(real_images))

        noise_vector = Variable(torch.randn(real_images.size(0), latent_dim).to(device))
        #noise_vector = Variable(Tensor(np.random.normal(0, 1, \
        #                                                (real_images.size(0),\
        #                                                 latent_dim))))
        noise_vector = noise_vector.to(device)
        generated_image = generator(noise_vector)

        D_fake_loss = adversarial_loss(discriminator(generated_image),\
                                     fake_target)

        D_total_loss = D_real_loss + D_fake_loss
        D_loss_list.append(D_total_loss)
        D_total_loss.backward()
        D_optimizer.step()

        G_optimizer.zero_grad()
        generated_image = generator(noise_vector)
        G_loss = adversarial_loss(discriminator(generated_image), real_target)
        G_loss_list.append(G_loss)

        G_loss.backward()
        G_optimizer.step()
        d = generated_image.data

        writer.add_scalar('Discriminator Loss',
                            D_total_loss,
                            epoch * len(train_loader) + index)

        writer.add_scalar('Generator Loss',
                            G_loss,
                            epoch * len(train_loader) + index)


    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
            (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)),\
             torch.mean(torch.FloatTensor(G_loss_list))))

    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))
    save_image(generated_image.data[:90], 'diff-run/images/sample_%d'%epoch + '.png', nrow=10, normalize=True)

Epoch: [1/500]: D_loss: 1.038, G_loss: 1.310
Epoch: [2/500]: D_loss: 0.868, G_loss: 2.099
Epoch: [3/500]: D_loss: 1.023, G_loss: 1.740
Epoch: [4/500]: D_loss: 1.064, G_loss: 1.570
Epoch: [5/500]: D_loss: 1.095, G_loss: 1.438
Epoch: [6/500]: D_loss: 1.118, G_loss: 1.392
Epoch: [7/500]: D_loss: 1.116, G_loss: 1.362
Epoch: [8/500]: D_loss: 1.089, G_loss: 1.477
Epoch: [9/500]: D_loss: 1.076, G_loss: 1.469
Epoch: [10/500]: D_loss: 1.096, G_loss: 1.383
Epoch: [11/500]: D_loss: 1.107, G_loss: 1.400
Epoch: [12/500]: D_loss: 1.119, G_loss: 1.358
Epoch: [13/500]: D_loss: 1.161, G_loss: 1.269
Epoch: [14/500]: D_loss: 1.147, G_loss: 1.250
Epoch: [15/500]: D_loss: 1.159, G_loss: 1.239
Epoch: [16/500]: D_loss: 1.158, G_loss: 1.243
Epoch: [17/500]: D_loss: 1.156, G_loss: 1.197
Epoch: [18/500]: D_loss: 1.193, G_loss: 1.161
Epoch: [19/500]: D_loss: 1.201, G_loss: 1.129
Epoch: [20/500]: D_loss: 1.201, G_loss: 1.108
Epoch: [21/500]: D_loss: 1.214, G_loss: 1.092
Epoch: [22/500]: D_loss: 1.202, G_loss: 1.1