In [15]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision
from torch.utils.data import DataLoader
import time

In [39]:
batch_size = 64
image_size = 64

transform = torchvision.transforms.Compose([torchvision.transforms.Resize(image_size), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(0.5, 0.5)])

path = r"/Users/desidero/Desktop/Kodlar/GANs/mnist_torch/"
dataset = torchvision.datasets.MNIST(root=path, train=False, download=True, transform=transform)
data = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [40]:
for image, label in data:
    print(image.shape)
    print(label.shape)
    break

torch.Size([64, 1, 64, 64])
torch.Size([64])


In [41]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_

In [42]:
class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
            )

    def forward(self, x):
        return self.main(x)
    

In [43]:
class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False), # discriminating a number between 0 and 1, so out_channel=1
            nn.Sigmoid() # because sigmoid returns a value between 0 and 1
            )

    def forward(self, x):
        output = self.main(x)
        return output.view(-1)
        

In [44]:
device = torch.device('mps')

In [45]:
gen = Generator().to(device)
disc = Discriminator().to(device)

gen.apply(weights_init) # .apply method applies a function on class
disc.apply(weights_init) # .apply method applies a function on class

criterion = nn.BCELoss()
gen_optim = torch.optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999))
disc_optim = torch.optim.Adam(disc.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [48]:
epochs = 100
liste = [i+1 for i in range(epochs)]
j = 0
for epoch in range(epochs):
    print("Epoch: {}".format(epoch))
    for real, label in data: # i will start from 0
       
        # Step 1: Updating the weights of the discriminator 
        disc.zero_grad()
        #print(data.shape)
        # Training the discriminator with a real image of the dataset
        #real, label = data
        real = real.to(device)
        #label = label.to(device)
        input_real = Variable(real)
        target = Variable(torch.ones(input_real.size()[0])).to(device) # target should be 1 because this images are real images
        output = disc.forward(input_real).to(device)
        disc_loss_real = criterion(output, target).to(device)
        
        # Training the discriminator with a fake image generated by the generator
        noise = Variable(torch.randn(input_real.size()[0], 100, 1, 1)).to(device) # giving fake dimensions with 1,1 after 100
        fake = gen.forward(noise).to(device)
        target = Variable(torch.zeros(input_real.size()[0])).to(device) # target should be 1 because this images are fake images
        output = disc.forward(fake.detach()) # .detach to save some cmemory and make the training faster
        disc_loss_fake = criterion(output, target).to(device)

        # Backpropogating the total error
        disc_loss = disc_loss_real + disc_loss_fake
        disc_loss.backward()
        disc_optim.step()

        # Step 2: Updating the weights of the generator
        gen.zero_grad()
        target = Variable(torch.ones(input_real.size()[0])).to(device)
        output = disc.forward(fake).to(device)
        gen_loss = criterion(output, target).to(device)
        gen_loss.backward()
        gen_optim.step()        

        # Step 3: Printing the losses, then saving the real and generated images of the minibatch every 100 steps
        print("Loss_D: %.4f, Loss_G: %.4f" % (disc_loss, gen_loss))
    torchvision.utils.save_image(real, '%s/real_samples_{}.png'.format(liste[j]) %  r"/Users/desidero/Desktop/Kodlar/GANs/mnist_torch/", normalize=True)
    fake = gen.forward(noise)
    torchvision.utils.save_image(fake.data, '%s/fake_samples_{}.png'.format(liste[j]) % r"/Users/desidero/Desktop/Kodlar/GANs/mnist_torch/", normalize=True)
    j += 1

Epoch: 0
Loss_D: 0.0486, Loss_G: 6.6606
Loss_D: 0.2645, Loss_G: 6.1163
Loss_D: 0.4816, Loss_G: 7.6838
Loss_D: 0.2502, Loss_G: 7.1517
Loss_D: 0.1670, Loss_G: 6.7131
Loss_D: 0.1766, Loss_G: 8.0779
Loss_D: 0.0464, Loss_G: 7.7084
Loss_D: 0.0961, Loss_G: 7.1896
Loss_D: 0.0954, Loss_G: 8.1389
Loss_D: 0.0792, Loss_G: 7.8709
Loss_D: 0.0888, Loss_G: 8.9995
Loss_D: 0.0930, Loss_G: 7.5884
Loss_D: 0.2705, Loss_G: 13.8748
Loss_D: 0.2078, Loss_G: 10.5491
Loss_D: 0.0362, Loss_G: 6.6309
Loss_D: 0.9136, Loss_G: 21.1447
Loss_D: 2.0457, Loss_G: 15.1089
Loss_D: 0.0151, Loss_G: 4.4071
Loss_D: 5.6032, Loss_G: 18.2590
Loss_D: 0.1946, Loss_G: 20.6413
Loss_D: 0.2542, Loss_G: 14.5471
Loss_D: 0.2302, Loss_G: 5.1781
Loss_D: 2.4797, Loss_G: 21.0998
Loss_D: 0.7420, Loss_G: 22.7621
Loss_D: 0.0913, Loss_G: 18.1244
Loss_D: 0.0568, Loss_G: 9.3338
Loss_D: 0.1631, Loss_G: 8.5966
Loss_D: 0.1463, Loss_G: 11.1993
Loss_D: 0.0348, Loss_G: 9.2814
Loss_D: 0.1774, Loss_G: 10.4494
Loss_D: 0.1362, Loss_G: 8.5581
Loss_D: 0.1519, Lo