<a href="https://colab.research.google.com/github/MarcYu0303/GAN-and-Diffusion-Model/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import torch.nn as nn
import numpy as np
import os

In [None]:
# Shape of one MNIST image as (channels, height, width).
image_size = torch.tensor([1, 28, 28])
# Dimensionality of the Gaussian noise vector fed to the generator.
latent_dim = 64
# Number of samples per mini-batch.
batch_size = 32

# Fix the global torch RNG so weight init, noise sampling and DataLoader
# shuffling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [None]:
class Generator(nn.Module):
  """MLP generator mapping latent noise vectors to images.

  Uses the module-level ``latent_dim`` (input width) and ``image_size``
  (a tensor of the per-image shape, e.g. ``[1, 28, 28]``).
  """

  def __init__(self):
    super(Generator, self).__init__()
    # nn.Linear expects a plain int for out_features. A 0-dim tensor happens to
    # work, but it then leaks into attributes such as `out_features` and the repr.
    num_pixels = int(torch.prod(image_size))

    self.model = nn.Sequential(
        nn.Linear(latent_dim, 64),
        nn.ReLU(inplace=True),
        nn.Linear(64, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, 256),
        nn.ReLU(inplace=True),
        nn.Linear(256, 512),
        nn.ReLU(inplace=True),
        nn.Linear(512, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, num_pixels),  # one output per pixel (C * H * W)
        # Squash into (-1, 1), the same range as the Normalize(0.5, 0.5) training data.
        nn.Tanh(),
    )

  def forward(self, z):
    """Generate a batch of images from noise.

    Args:
      z: noise of shape (batch, latent_dim), e.g. from torch.randn.

    Returns:
      Tensor of shape (batch, *image_size) with values in (-1, 1).
    """
    output = self.model(z)
    # z.shape[0] is the batch size; unpack image_size as plain ints for reshape.
    image = output.reshape(z.shape[0], *image_size.tolist())
    return image

In [None]:
class Discriminator(nn.Module):
  """MLP discriminator scoring images as real (target 1) or generated (target 0).

  Images are flattened to ``prod(image_size)`` features, using the module-level
  ``image_size`` tensor.
  """

  def __init__(self):
    super(Discriminator, self).__init__()
    # nn.Linear expects a plain int for in_features, not a 0-dim tensor.
    num_pixels = int(torch.prod(image_size))

    self.model = nn.Sequential(
        nn.Linear(num_pixels, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 512),
        nn.ReLU(inplace=True),
        nn.Linear(512, 256),
        nn.ReLU(inplace=True),
        nn.Linear(256, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, 1),
        nn.Sigmoid(),  # probability output, as nn.BCELoss requires
    )

  def forward(self, image):
    """Return a (batch, 1) tensor of probabilities that each image is real.

    Args:
      image: batch of images; every dim after the first is flattened,
        e.g. (batch, 1, 28, 28) -> (batch, 784). The flattened size must equal
        prod(image_size).
    """
    prob = self.model(image.reshape(image.shape[0], -1))
    return prob

In [None]:
# Download (if missing) and load the MNIST training split.
# ToTensor yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5) then maps them
# to [-1, 1]. Resize(28) leaves the native 28x28 MNIST images unchanged.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(28),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
])
dataset = torchvision.datasets.MNIST(
    "minist_data",
    train=True,
    download=True,
    transform=mnist_transform,
)
print(dataset)

Dataset MNIST
    Number of datapoints: 60000
    Root location: minist_data
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=28, interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.5], std=[0.5])
           )


In [None]:
# Peek at (at most) the first three samples: each is an (image, label) pair.
for sample_idx in range(min(3, len(dataset))):
  sample_image, sample_label = dataset[sample_idx]
  print(sample_image.shape)
  print(sample_label)

torch.Size([1, 28, 28])
5


In [None]:
# Training
# Each iteration updates the generator first, then the discriminator.

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

generator = Generator()
discriminator = Discriminator()

loss_fn = nn.BCELoss()  # real-vs-fake is binary classification -> binary cross-entropy

g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0001)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0001)

num_epochs = 100
save_every = 1000  # iterations between dumps of generated samples

for epoch in range(num_epochs):
  for i, mini_batch in enumerate(dataloader):
    get_images, _ = mini_batch
    # Size noise and labels from the actual batch. The last batch of an epoch is
    # smaller whenever len(dataset) is not a multiple of batch_size, and BCELoss
    # requires input and target shapes to match.
    current_batch_size = get_images.shape[0]
    real_labels = torch.ones(current_batch_size, 1)
    fake_labels = torch.zeros(current_batch_size, 1)

    z = torch.randn(current_batch_size, latent_dim)
    pred_images = generator(z)

    # Generator step: push D(G(z)) toward the "real" label.
    # This backward also accumulates grads in the discriminator's parameters;
    # d_optimizer.zero_grad() below clears them before the discriminator step.
    g_optimizer.zero_grad()
    g_loss = loss_fn(discriminator(pred_images), real_labels)
    g_loss.backward()
    g_optimizer.step()

    # Discriminator step: real images -> 1, generated images -> 0.
    # detach() cuts the graph so no gradient flows back into the generator.
    d_optimizer.zero_grad()
    d_loss = (loss_fn(discriminator(get_images), real_labels)
              + loss_fn(discriminator(pred_images.detach()), fake_labels))
    d_loss.backward()
    d_optimizer.step()

    if i % save_every == 0:
      # Generator output is in (-1, 1) (Tanh), but save_image expects [0, 1] and
      # clamps. Without rescaling, every negative value would be saved as black.
      samples = (pred_images.detach() + 1) / 2
      for index, image in enumerate(samples):
        torchvision.utils.save_image(image, f"image_{index}.png")

KeyboardInterrupt: ignored