<a href="https://colab.research.google.com/github/GuiXu40/deeplearning0/blob/main/Basic_code/Basic_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive into the Colab VM (requires the `google.colab` package, i.e. a Colab runtime).
# NOTE(review): no cell visible in this notebook reads from /content/drive — confirm the mount is needed.
import google.colab.drive as drive
drive.mount('/content/drive')

Mounted at /content/drive


### 1. 导包

In [2]:
import torch
import torch.nn as nn
import torchvision
import numpy as np

### 2. 网络结构

In [11]:
image_size = [1, 28, 28]  # per-image tensor shape (C, H, W): single-channel 28x28
latent_dim = 64           # length of the noise vector fed to the generator


class Generator(nn.Module):
  """MLP generator mapping a latent noise vector to an image with pixels in [0, 1].

  Args:
    z_dim: length of the input noise vector. Defaults to the notebook-level
      ``latent_dim`` read at construction time.
    img_shape: output image shape (C, H, W). Defaults to the notebook-level
      ``image_size`` read at construction time.
  """

  def __init__(self, z_dim=None, img_shape=None):
    super().__init__()
    # Fall back to the notebook settings so a plain `Generator()` behaves as before.
    self.z_dim = latent_dim if z_dim is None else int(z_dim)
    self.img_shape = tuple(image_size if img_shape is None else img_shape)
    # np.prod returns a NumPy scalar; nn.Linear stores whatever it is given,
    # so cast to a plain Python int.
    n_pixels = int(np.prod(self.img_shape))
    self.model = nn.Sequential(
        nn.Linear(self.z_dim, 128, bias=True),
        nn.ReLU(),
        nn.Linear(128, 256, bias=True),
        nn.ReLU(),
        nn.Linear(256, 512, bias=True),
        nn.ReLU(),
        nn.Linear(512, 1024, bias=True),
        nn.ReLU(),
        nn.Linear(1024, n_pixels, bias=True),
        nn.Sigmoid(),  # squash pixel values into [0, 1]
    )

  def forward(self, x):
    """Map noise ``x`` of shape (batch, z_dim) to images of shape (batch, *img_shape)."""
    output = self.model(x)
    return output.reshape(x.shape[0], *self.img_shape)  # output images

class Discriminator(nn.Module):
  """MLP discriminator returning, for each image, the probability that it is real.

  Args:
    img_shape: expected input image shape (C, H, W). Defaults to the
      notebook-level ``image_size`` read at construction time.
  """

  def __init__(self, img_shape=None):
    super().__init__()
    self.img_shape = tuple(image_size if img_shape is None else img_shape)
    # np.prod returns a NumPy scalar; cast so nn.Linear stores a plain int.
    n_pixels = int(np.prod(self.img_shape))
    self.model = nn.Sequential(
        nn.Linear(n_pixels, 512),
        nn.ReLU(),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 32),
        nn.ReLU(),
        nn.Linear(32, 1),
        nn.Sigmoid(),  # probability in [0, 1], consumed by nn.BCELoss
    )

  def forward(self, image):
    """Flatten ``image`` to (batch, -1) and return probabilities of shape (batch, 1)."""
    prob = self.model(image.reshape(image.shape[0], -1))
    return prob


### 3. 加载数据

In [7]:
batch_size = 32

# ToTensor yields float tensors in [0, 1] (per torchvision docs),
# the same range as the generator's Sigmoid output.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(28),
    torchvision.transforms.ToTensor(),
])
dataset = torchvision.datasets.MNIST(
    "./data",
    train=True,
    download=True,
    transform=mnist_transform,
)
print(len(dataset))  # 60000
print(dataset[0][0].shape)  # [1, 28, 28]

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

60000
torch.Size([1, 28, 28])


In [12]:
# Optimizers / loss function

# Build G first, then D, so parameter initialisation consumes the RNG in the same order.
generator = Generator()
discriminator = Discriminator()

# Both networks are trained with identical Adam hyper-parameters.
adam_kwargs = dict(lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
g_optimizer = torch.optim.Adam(generator.parameters(), **adam_kwargs)
d_optimizer = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)

# Binary cross-entropy on the discriminator's sigmoid probabilities.
loss_fn = nn.BCELoss()

In [None]:
# Training: each step first updates G so that D labels its fakes as real,
# then updates D to separate real images from (detached) fakes.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")
epochs = 20
if use_gpu:
  print("use gpu for training")
# NOTE(review): PyTorch recommends moving modules to the device before building
# their optimizers; this relies on .to() keeping the same Parameter objects.
generator = generator.to(device)
discriminator = discriminator.to(device)
loss_fn = loss_fn.to(device)

steps_per_epoch = len(dataloader)
for epoch in range(epochs):
  for i, (gt_images, _) in enumerate(dataloader):
    step = steps_per_epoch * epoch + i
    gt_images = gt_images.to(device)

    # Size labels and noise from the actual batch. The last batch of an epoch is
    # smaller than `batch_size` whenever the dataset size is not a multiple of it,
    # and BCELoss requires target and input to have the same shape.
    cur_batch_size = gt_images.shape[0]
    labels_one = torch.ones(cur_batch_size, 1, device=device)
    labels_zero = torch.zeros(cur_batch_size, 1, device=device)
    # Standard-normal latent noise (the usual GAN prior) rather than uniform [0, 1).
    z = torch.randn(cur_batch_size, latent_dim, device=device)

    # --- Generator step: push D to classify the fakes as real. ---
    g_optimizer.zero_grad()
    pred_images = generator(z)
    g_loss = loss_fn(discriminator(pred_images), labels_one)
    g_loss.backward()
    g_optimizer.step()

    # --- Discriminator step ---
    # zero_grad also clears the gradients that g_loss.backward() left on D's parameters.
    d_optimizer.zero_grad()
    real_loss = loss_fn(discriminator(gt_images), labels_one)
    # detach() so D's loss does not backpropagate into G.
    fake_loss = loss_fn(discriminator(pred_images.detach()), labels_zero)
    d_loss = real_loss + fake_loss
    # Watch real_loss and fake_loss: if they decrease together, reach their minimum
    # together and stay about equal, D has stabilised.
    d_loss.backward()
    d_optimizer.step()

    if i % 50 == 0:
      print(f"step:{step}, g_loss:{g_loss.item()}, d_loss:{d_loss.item()}, real_loss:{real_loss.item()}, fake_loss:{fake_loss.item()}")

    if i % 400 == 0:
      # Save a 4x4 grid of the current fakes; detach() replaces the discouraged `.data`.
      torchvision.utils.save_image(pred_images[:16].detach(), f"image_{step}.png", nrow=4)


In [20]:
!rm -f sample_data

rm: cannot remove 'sample_data': Is a directory
