<a href="https://colab.research.google.com/github/HanzhouLiu/python_fromzero/blob/main/first_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [25]:
import torchvision.transforms as transforms
class Discriminator(nn.Module):
  def __init__(self, in_features):
    super().__init__() # super.() lets you access methods in a parent class
    self.disc = nn.Sequential(
        nn.Linear(in_features, 128),
        nn.LeakyReLU(0.1),
        nn.Sigmoid(),
    )

  def forward(self, x):
    return self.disc(x)

class Generator(nn.Module):
  def __init__(self, z_dim, img_dim):
    super().__init__()
    self.gen = nn.Sequential(
        nn.Linear(z_dim, 256),
        nn.LeakyReLU(0.1),
        nn.Linear(256, img_dim),
        nn.Tanh(),
    )

  def forward(self, x):
    return self.gen(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28*28*1
batch_size = 32
num_epochs = 50

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0

for epoch in range(num_epochs):
  for batch_idx, (real, _) in enumerate(loader):
    real = real.view(-1, 784).to(device)
    batch_size = real.shape[0]

    # train discriminator: max_log(D(real)) + log(1-D(G(z)))
    noise = torch.randn(batch_size, z_dim).to(device)
    fake = gen(noise)
    disc_real = disc(real).view(-1)
    lossD_real = criterion(disc_real, torch.ones_like(disc_real))
    disc_fake = disc(fake).view(-1)
    lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    lossD = (lossD_real + lossD_fake) / 2
    disc.zero_grad()
    lossD.backward(retain_graph=True)
    opt_disc.step()

    ### Train Generator min log(1 - D(G(Z))) <---> max log(D(G(z)))
    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    lossG.backward()
    opt_disc.step()

    if batch_idx == 0:
      print(
          f"Epoch [{epoch}/{num_epochs}] \ "
          f"Loss D: {lossD:.4f}, Loss G: {lossG:.4f}"
      )

      with torch.no_grad():
        fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
        data = real.reshape(-1, 1, 28, 28)
        img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
        img_grid_real = torchvision.utils.make_grid(data, normalize=True)

        writer_fake.add_image(
            "Mnist Fake Images", img_grid_fake, global_step=step
        )

        writer_fake.add_image(
            "Mnist Fake Images", img_grid_fake, global_step=step
        )


Epoch [0/50] \ Loss D: 0.6555, Loss G: 0.6723
Epoch [1/50] \ Loss D: 0.3406, Loss G: 0.7054
Epoch [2/50] \ Loss D: 0.3423, Loss G: 0.7018
Epoch [3/50] \ Loss D: 0.3424, Loss G: 0.7016
Epoch [4/50] \ Loss D: 0.3431, Loss G: 0.7005
Epoch [5/50] \ Loss D: 0.3431, Loss G: 0.7003
Epoch [6/50] \ Loss D: 0.3432, Loss G: 0.7000
Epoch [7/50] \ Loss D: 0.3436, Loss G: 0.6991
Epoch [8/50] \ Loss D: 0.3436, Loss G: 0.6992
Epoch [9/50] \ Loss D: 0.3438, Loss G: 0.6989
Epoch [10/50] \ Loss D: 0.3437, Loss G: 0.6990
Epoch [11/50] \ Loss D: 0.3439, Loss G: 0.6986
Epoch [12/50] \ Loss D: 0.3440, Loss G: 0.6985
Epoch [13/50] \ Loss D: 0.3441, Loss G: 0.6982
Epoch [14/50] \ Loss D: 0.3442, Loss G: 0.6981
Epoch [15/50] \ Loss D: 0.3443, Loss G: 0.6978
Epoch [16/50] \ Loss D: 0.3446, Loss G: 0.6973
Epoch [17/50] \ Loss D: 0.3444, Loss G: 0.6976
Epoch [18/50] \ Loss D: 0.3447, Loss G: 0.6970
Epoch [19/50] \ Loss D: 0.3446, Loss G: 0.6973
Epoch [20/50] \ Loss D: 0.3447, Loss G: 0.6969
Epoch [21/50] \ Loss D:

In [None]:
print(3e-4)

0.0003


In [None]:
class CheckingAccount: # an example of class
	def __init__(self, name, balance):
		self.name = name
		self.balance = balance

	def checkBalance(self):
		print("Balance: ", self.balance)

In [None]:
CheckingAccount('lily', 200).checkBalance() # print the class's attribute

Balance:  200


In [None]:
class ChildAccount(CheckingAccount):
	def __init__(self, name, balance, is_child=True):
		self.is_child = is_child
		super().__init__(name, balance)

	def checkBalance(self):
		print("You are a child account holder!")
		super().checkBalance()

In [None]:
ChildAccount('bob', 100).checkBalance()

You are a child account holder!
Balance:  100


In [None]:
from typing 