In [None]:
import torch
torch.manual_seed(41)
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Configurations

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Device:", device)

# On a GPU run, also print the name of card 0 as a sanity check
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))

# --- Adam optimizer settings ---
# learning rate
lr = 0.0001
# (beta1, beta2) coefficients; these are the PyTorch default values
beta1, beta2 = 0.9, 0.999

# --- Training settings ---
batchsize = 128   # images per mini-batch
noise_dim = 64    # length of the random vector fed to the generator
epochs = 50       # full passes over the training set

# Loading Dataset

In [None]:
#download Fashion-MNIST Dataset
#performing transformation on the image data
from torchvision import datasets, transforms as T

# ToTensor converts each PIL image (H x W x C, values 0-255) into a float
# tensor laid out as (C x H x W) with values scaled to [0.0, 1.0]
transform = T.Compose([T.ToTensor()])

# Fashion-MNIST train / test splits, stored under Fashion_MNIST/
# (download=True fetches the files if they are not already there)
training_set = datasets.FashionMNIST(
    root='Fashion_MNIST/',
    train=True,
    download=True,
    transform=transform,
)
test_set = datasets.FashionMNIST(
    root='Fashion_MNIST/',
    train=False,
    download=True,
    transform=transform,
)

# number of samples in each split
print("Total number of trainingset:", len(training_set))
print("Total number of testset:" ,len(test_set))

In [None]:
# Look at one training sample together with its class label
image, label = training_set[7000]
print("Label of the showed image is:",label)
# squeeze() drops the size-1 dimensions so imshow receives a 2-D array
grayscale_pixels = image.squeeze()
plt.imshow(grayscale_pixels, cmap='gray')


# Load Dataset into Batches

In [None]:
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

In [None]:
trainloader = DataLoader(
    training_set,
    batch_size=batchsize,
    shuffle=True,
)
# Number of batches per epoch = ceil(number of images / batch size).
# With 60000 images and batchsize = 128 this gives 469 batches;
# the value changes whenever batchsize changes.
print("Batches in trainloader:", len(trainloader))

# Pull a single batch to inspect it
dataiter = iter(trainloader)
images, label = next(dataiter)
# shape of the images: (batchsize, channel, height, width)
print(images.shape)

#function to show 16 images
#input: images, number of images that will be displayed

def show_images(images, number_images=16):
    """Display the first ``number_images`` images of a batch as a 4-column grid.

    Args:
        images: batch of image tensors, indexed along the first dimension.
            The tensor may live on any device and may still be attached to
            an autograd graph (e.g. generator output).
        number_images: how many images, counted from the start of the batch,
            are placed in the grid.
    """
    # matplotlib converts tensors through NumPy, which raises for CUDA tensors
    # and for tensors that require grad, so take a detached CPU copy first.
    img_cpu = images.detach().cpu()
    img_mesh = make_grid(img_cpu[:number_images], nrow=4)
    # make_grid returns (C, H, W); imshow expects the channel axis last
    plt.imshow(img_mesh.permute(1, 2, 0).squeeze())
    plt.show()


In [None]:
# preview the first 16 images of the batch drawn above
show_images(images, 16)

# Discriminator Network

In [None]:
# basic building block for neural networks
from torch import nn
#summarize the network
from torchsummary import summary
#Relu function
from torch.nn.modules.activation import LeakyReLU
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.flatten import Flatten

In [None]:
def discriminator_network(in_channels, out_channels, kernel_size, stride):
  """Build one convolutional stage of the discriminator.

  The stage is a ``Conv2d`` followed by ``BatchNorm2d`` over the produced
  channels and a ``LeakyReLU`` with negative slope 0.2.

  Args:
      in_channels: channels of the incoming feature map.
      out_channels: channels produced by the convolution.
      kernel_size: convolution kernel size.
      stride: convolution stride.

  Returns:
      An ``nn.Sequential`` holding the three layers in that order.
  """
  layers = [
      nn.Conv2d(in_channels, out_channels, kernel_size, stride),
      nn.BatchNorm2d(out_channels),
      nn.LeakyReLU(0.2),
  ]
  return nn.Sequential(*layers)

In [None]:
class Discriminator(nn.Module):
  """Convolutional discriminator producing one raw score (logit) per image.

  Three stages built by ``discriminator_network`` are followed by a flatten
  and a single linear unit. No sigmoid is applied in this module.
  The linear layer expects 64 flattened features, so the third stage has to
  reduce the feature map to 1x1 (which is the case for 1x28x28 inputs).
  """

  def __init__(self):
    super().__init__()
    # arguments per stage: in channels, out channels, kernel size, stride
    # more infos: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    self.block_1 = discriminator_network(1, 16, (3,3), 2)
    self.block_2 = discriminator_network(16, 32, (5,5), 2)
    self.block_3 = discriminator_network(32, 64, (5,5), 2)
    self.flatten = nn.Flatten()
    self.linear = nn.Linear(in_features = 64, out_features=1)

  def forward(self, images):
    """Return the discriminator score for every image in the batch."""
    features = images
    for stage in (self.block_1, self.block_2, self.block_3):
      features = stage(features)

    # collapse the feature map to a vector and map it to a single score
    return self.linear(self.flatten(features))

In [None]:
# build the discriminator on the selected device (Module.to returns the module itself)
Disc = Discriminator().to(device)

# layer-by-layer overview for a single-channel 28x28 input
summary(Disc, input_size=(1,28,28))

# Generator Network

In [None]:
def generator_network(in_channels, out_channels, kernel_size, stride, final_block = False):
  """Build one transposed-convolution stage of the generator.

  Args:
      in_channels: channels of the incoming feature map.
      out_channels: channels produced by the transposed convolution.
      kernel_size: kernel size of the transposed convolution.
      stride: stride of the transposed convolution.
      final_block: when truthy, build the output stage
          (``ConvTranspose2d`` + ``Tanh``); otherwise build a hidden stage
          (``ConvTranspose2d`` + ``BatchNorm2d`` + ``ReLU``).

  Returns:
      An ``nn.Sequential`` with the layers of the requested stage.
  """
  upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride)
  if final_block:
    # output stage: no normalisation, Tanh bounds the values to [-1, 1]
    return nn.Sequential(upsample, nn.Tanh())
  return nn.Sequential(
      upsample,
      nn.BatchNorm2d(out_channels),
      nn.ReLU() )

In [None]:
class Generator(nn.Module):
  """Transposed-convolution generator turning noise vectors into images.

  Four stages built by ``generator_network`` are applied in sequence; the
  last one is the Tanh output stage with a single channel. Starting from a
  1x1 map, the stages produce spatial sizes 3, 6, 13 and 28.
  """

  def __init__(self, noise_dim):
    super().__init__()

    # length of the noise vector; it becomes the channel count of the 1x1 input map
    self.noise_dim = noise_dim
    self.block_1 = generator_network(noise_dim, 256, (3,3), 2)
    self.block_2 = generator_network(256,128, (4,4), 1)
    self.block_3 = generator_network(128, 64, (3,3), 2)
    self.block_4 = generator_network(64,1, (4,4), 2, final_block=True)

  def forward(self, r_noise_vec):
    """Generate one image per noise vector in ``r_noise_vec``."""
    # reshape (batch_size, noise_dim) -> (batch_size, noise_dim, 1, 1)
    feature_map = r_noise_vec.view(-1, self.noise_dim,1,1)
    for stage in (self.block_1, self.block_2, self.block_3, self.block_4):
      feature_map = stage(feature_map)

    return feature_map

In [None]:
# build the generator on the selected device (Module.to returns the module itself)
Gen = Generator(noise_dim).to(device)

# layer-by-layer overview for one noise vector of length noise_dim
summary(Gen, input_size =(1, noise_dim))

In [None]:
# Replace the default random initialisation with weights drawn from a normal distribution

def weights_init(m):
    """Re-initialise one module in place; meant for ``nn.Module.apply``.

    * ``Conv2d`` / ``ConvTranspose2d``: weights drawn from N(0, 0.02).
    * ``BatchNorm2d``: scale drawn from N(1, 0.02), bias set to 0.
    * any other module type is left untouched.

    Args:
        m: the module handed in by ``apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # The scale must be centred on 1, not 0: a scale near zero would
        # multiply every normalised activation by ~0 at the start of training.
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)

In [None]:
# Module.apply runs weights_init on every sub-module (and on the module itself)
# in place, so Disc and Gen keep referring to the same, now re-initialised, networks
for network in (Disc, Gen):
    network.apply(weights_init)

# Loss Function and Load Optimizer

In [None]:
# two losses: one against "real" targets (all ones), one against "fake" targets (all zeros)
def real_loss(disc_pred):
  """Binary cross-entropy of raw discriminator scores against the label "real".

  Args:
      disc_pred: discriminator outputs as logits (no sigmoid applied).

  Returns:
      Scalar tensor: mean BCE-with-logits loss against an all-ones target.
  """
  all_ones = torch.ones_like(disc_pred)
  return nn.BCEWithLogitsLoss()(disc_pred, all_ones)

def fake_loss(disc_pred):
  """Binary cross-entropy of raw discriminator scores against the label "fake".

  Args:
      disc_pred: discriminator outputs as logits (no sigmoid applied).

  Returns:
      Scalar tensor: mean BCE-with-logits loss against an all-zeros target.
  """
  all_zeros = torch.zeros_like(disc_pred)
  return nn.BCEWithLogitsLoss()(disc_pred, all_zeros)

In [None]:
# each network gets its own Adam optimizer; both share the same hyper-parameters
adam_settings = dict(lr=lr, betas=(beta1, beta2))
Disc_opt = torch.optim.Adam(Disc.parameters(), **adam_settings)
Gen_opt = torch.optim.Adam(Gen.parameters(), **adam_settings)

In [None]:
# Adversarial training: per batch, one discriminator update followed by one
# generator update; per epoch, print the mean losses and show generated samples.
for i in range(epochs):

  # running sums of the per-batch losses for this epoch
  total_disc_loss = 0.0
  total_gen_loss = 0.0

  for real_img, _ in tqdm(trainloader):

    real_img = real_img.to(device)
    # size the fake batch like the real one; the last batch of an epoch can
    # hold fewer than `batchsize` images
    cur_batchsize = real_img.size(0)

    # ---- discriminator update ----
    Disc_opt.zero_grad()

    noise = torch.randn(cur_batchsize, noise_dim, device = device)
    # detach(): this step only trains the discriminator, so gradients must
    # not be propagated back through the generator
    fake_img = Gen(noise).detach()
    D_pred = Disc(fake_img)
    D_fake_loss = fake_loss(D_pred)

    D_pred = Disc(real_img)
    D_real_loss = real_loss(D_pred)

    # average of the two halves of the discriminator objective
    D_loss = (D_fake_loss + D_real_loss)/2
    total_disc_loss += D_loss.item()

    D_loss.backward()
    Disc_opt.step()

    # ---- generator update ----
    Gen_opt.zero_grad()
    noise = torch.randn(cur_batchsize, noise_dim, device= device)

    fake_img = Gen(noise)
    D_pred = Disc(fake_img)
    # the generator is rewarded when the discriminator scores its images as real
    G_loss = real_loss(D_pred)
    total_gen_loss += G_loss.item()
    G_loss.backward()
    Gen_opt.step()

  avg_disc_loss = total_disc_loss / len(trainloader)
  avg_gen_loss = total_gen_loss / len(trainloader)

  print("Epoch: {} | Disc_loss: {} | Gen_loss: {}".format(i+1, avg_disc_loss, avg_gen_loss))

  # fake_img still requires grad and may live on the GPU; matplotlib needs a
  # detached CPU tensor
  show_images(fake_img.detach().cpu())