<a href="https://colab.research.google.com/github/FCUAIC/Basic-ML/blob/main/mnist_cgan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MNIST CGAN

作者: 梁定殷

版權歸FCUAI所有

Input shape: (1, 28, 28)

## 超參數(Hyperparameters)

In [None]:
# Learning rate handed to both Adam optimizers (step size of each update).
LR = 2e-4
# Number of samples in every mini-batch produced by the DataLoaders.
BATCH_SIZE = 32
# Number of passes of the outer training loop.
EPOCHS = 500

# Weight-clipping bound: after each discriminator update every discriminator
# parameter is clamped into [-C, C].
C = 0.03

# Data source switch: False loads MNIST through torchvision, True loads it
# through Keras (MNISTDataset) for when the torchvision download is unavailable.
用離線的MNIST = False

## 載入需要用的Package

In [None]:
import numpy as np
import random
import torch
import torch.nn as nn
from torch.optim import Adam

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid

if 用離線的MNIST:
    from keras.datasets import mnist
else:
    from torchvision.datasets import MNIST


from sklearn.model_selection import  train_test_split
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from PIL import Image

from tqdm.notebook import trange, tqdm

In [None]:
# Fix every random number generator the notebook can draw from so that runs
# are repeatable.
manualSeed = 999
random.seed(manualSeed)        # Python's built-in RNG
np.random.seed(manualSeed)     # NumPy's global RNG (previously left unseeded)
torch.manual_seed(manualSeed)  # PyTorch RNG; kept last so the cell still echoes the Generator

<torch._C.Generator at 0x7f20196c8550>

In [None]:
def weights_init(m):
    """Re-initialise one module in place; meant to be passed to ``Module.apply``.

    The decision is made on the module's class name:

    * names containing ``Conv``      -> weight ~ N(0.0, 0.02); bias untouched
    * names containing ``BatchNorm`` -> weight ~ N(1.0, 0.02); bias set to 0
    * anything else                  -> left unchanged
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

## 模型

In [None]:
class MNISTGenerator(nn.Module):
    """Conditional generator: (noise, digit label) -> 1x28x28 image in [-1, 1].

    The label is embedded into ``embed_dim`` values, concatenated with the
    noise vector along the channel axis and up-sampled by four transposed
    convolutions.

    Args:
        noise_dim: number of channels of the noise input (default 100).
        num_classes: number of distinct labels the embedding accepts (default 10).
        embed_dim: size of the label embedding (default 32).

    With the default arguments the module layout and parameter names are the
    same as before these arguments existed.
    """
    def __init__(self, noise_dim=100, num_classes=10, embed_dim=32):
        super().__init__()
        self.noise_dim = noise_dim
        self.embed_dim = embed_dim
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        self.net = nn.Sequential(
            # (noise_dim + embed_dim, 1, 1) -> (512, 4, 4)
            nn.ConvTranspose2d(in_channels=noise_dim + embed_dim, out_channels=512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            # (512, 4, 4) -> (256, 8, 8)
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # (256, 8, 8) -> (128, 16, 16)
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # (128, 16, 16) -> (1, 28, 28): (16 - 1) * 2 - 2 * 2 + 2 = 28
            nn.ConvTranspose2d(in_channels=128, out_channels=1, kernel_size=2, stride=2, padding=2, bias=False),
            nn.Tanh(),
        )

    def forward(self, noise, label):
        """Generate one image per (noise, label) pair.

        Args:
            noise: tensor of shape (batch, noise_dim, 1, 1).
            label: integer tensor of shape (batch,) with class indices.

        Returns:
            Tensor of shape (batch, 1, 28, 28); values lie in [-1, 1] (Tanh).
        """
        # Give the embedding the same (batch, channels, 1, 1) layout as the
        # noise so both can be concatenated along the channel axis.
        embedded_label = self.label_embedding(label).reshape(label.size(0), self.embed_dim, 1, 1)
        return self.net(torch.cat([noise, embedded_label], 1))
        
# Build the generator on the GPU when CUDA is available and on the CPU
# otherwise (an unconditional .cuda() raises on machines without a GPU), then
# apply the custom weight initialisation to every sub-module.
G = MNISTGenerator().to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
G.apply(weights_init)
print(G)

RuntimeError: ignored

In [None]:
class MNISTDiscriminator(nn.Module):
    """Conditional critic: scores a 1x28x28 image together with its digit label.

    The label is embedded, projected to 28*28 values and stacked onto the
    image as a second channel before the convolutional stack.

    The output is a raw, unbounded score of shape (batch, 1). There is
    deliberately no final Sigmoid: the notebook trains this network with the
    WGAN objective (mean score on real minus mean score on fake, plus weight
    clipping), which needs an unbounded critic rather than a probability.
    """
    def __init__(self):
        super(MNISTDiscriminator, self).__init__()
        # label index -> 32-d embedding -> 784 values (one 28x28 plane)
        self.label_embedding = nn.Sequential(
            nn.Embedding(10, 32),
            nn.Linear(32, 1 * 28 * 28)
        )
        self.cnn = nn.Sequential(
            # (2, 28, 28) -> (16, 14, 14)
            nn.Conv2d(2, 16, 4, 2, 1),
            nn.LeakyReLU(0.2, True),

            # (16, 14, 14) -> (32, 7, 7)
            nn.Conv2d(16, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, True),

            # 32 * 7 * 7 = 1568 features per sample
            nn.Flatten(),
        )

        self.classifier = nn.Sequential(
            nn.Linear(1568, 512),
            nn.Dropout(),
            nn.LeakyReLU(0.2, True),
            # Final linear layer emits the critic score directly.
            nn.Linear(512, 1),
        )
    
    def forward(self, img, label):
        """Score a batch of images conditioned on their labels.

        Args:
            img: tensor of shape (batch, 1, 28, 28).
            label: integer tensor of shape (batch,) with class indices 0-9.

        Returns:
            Tensor of shape (batch, 1) holding one unbounded score per sample.
        """
        label_embed = self.label_embedding(label)
        # Turn the 784 projected values into an extra image channel.
        label_reshaped = torch.reshape(label_embed, (img.size(0), 1, 28, 28))
        inp = torch.cat([img, label_reshaped], 1)
        cnn_oup = self.cnn(inp)
        class_oup = self.classifier(cnn_oup)
        return class_oup

# Build the discriminator on the GPU when CUDA is available and on the CPU
# otherwise (an unconditional .cuda() raises on machines without a GPU), then
# apply the custom weight initialisation to every sub-module.
D = MNISTDiscriminator().to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
D.apply(weights_init)
print(D)

## 資料集(Dataset)

In [None]:
class MNISTDataset(Dataset):
    """MNIST served from Keras, wrapped as a PyTorch ``Dataset``.

    Args:
        train: when True use the training split, otherwise the test split.
        transformer: optional callable applied to each image in
            ``__getitem__``; when given, the label is returned as a
            ``torch.long`` tensor.
    """
    def __init__(self, train=False, transformer=None):
        # Imported here rather than relying on the notebook-level conditional
        # import, so the class works regardless of how that flag was set.
        from keras.datasets import mnist

        (train_feature, train_label), (test_feature, test_label) = mnist.load_data()
        # Keep only the split that was asked for.
        features, labels = (train_feature, train_label) if train else (test_feature, test_label)

        # One [image, label] row per sample, stored as Python objects.
        self.data = np.array([list(pair) for pair in zip(features, labels)], dtype=object)
        self.length = len(features)
        self.transformer = transformer

    def __len__(self):
        """Return the number of samples in the selected split."""
        return self.length

    def __getitem__(self, idx):
        """Return sample ``idx``.

        Without a transformer the stored ``[image, label]`` row is returned
        as-is; with one, ``(transformer(image), label as torch.long tensor)``.
        """
        if self.transformer is None:
            return self.data[idx]
        img, label = self.data[idx]
        return self.transformer(img), torch.tensor(label, dtype=torch.long)

## 載入資料集(Dataset)

In [None]:
# The only preprocessing step is ToTensor, which turns each image into a
# tensor (torchvision scales 8-bit pixel values into [0, 1] while doing so).
preprocessor = transforms.Compose([transforms.ToTensor()])

if not 用離線的MNIST:
    # torchvision path: fetching the training split downloads the files into
    # ./mnist, which the test split then reads.
    print('使用Pytorch的Dataset.')
    mnist_train = MNIST(root='mnist', train=True, transform=preprocessor, download=True)
    mnist_test = MNIST(root='mnist', train=False, transform=preprocessor)
else:
    # Offline path: MNIST comes from Keras through MNISTDataset.
    print('使用離線的Dataset.')
    mnist_train = MNISTDataset(train=True, transformer=preprocessor)
    mnist_test = MNISTDataset(train=False, transformer=preprocessor)

print(f'訓練資料一共有{len(mnist_train)}筆資料\n測試資料一共有{len(mnist_test)}筆資料')

# Wrap both splits in shuffling mini-batch loaders; the names are rebound from
# the datasets to their DataLoaders.
mnist_train, mnist_test = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=True)
    for split in (mnist_train, mnist_test)
)

## 宣告損失函數&優化器

In [None]:
# One Adam optimizer per network, both using the shared learning rate LR:
# g_optim updates the generator G, d_optim updates the discriminator D.
g_optim, d_optim = (Adam(net.parameters(), lr=LR) for net in (G, D))

## 訓練

In [None]:
# Training loop.
# Run on whatever device the generator already lives on, so the loop does not
# hard-code CUDA.
device = next(G.parameters()).device

for epoch in trange(1, EPOCHS + 1, desc='Epoch', unit='次'):
    total_d_loss = 0
    total_g_loss = 0
    # Iterate the TRAINING loader (the loop previously consumed mnist_test).
    # tqdm wraps the loader itself, so it can read the number of batches and
    # draw a real progress bar; enumerate only supplies the batch index.
    for idx, (batch_x, batch_y) in enumerate(tqdm(mnist_train, desc='訓練進度', unit='batch')):
        # Move the images and labels to the training device.
        x = batch_x.to(device)
        y = batch_y.to(device)
        # One noise vector per sample; the last batch may be smaller.
        noise = torch.randn(x.size(0), 100, 1, 1, device=device)

        # ---- discriminator (critic) update ----
        d_optim.zero_grad()

        # Scores for real images.
        d_pred_real = D(x, y).view(-1)
        # Scores for generated images. detach() stops this update from
        # back-propagating into G, whose gradients are not used here.
        fake_data = G(noise, y).detach()
        d_pred_fake = D(fake_data, y).view(-1)
        # WGAN critic loss: raise the score of real data, lower that of fakes.
        d_loss = -torch.mean(d_pred_real) + torch.mean(d_pred_fake)
        d_loss.backward()
        d_optim.step()
        total_d_loss += d_loss.item()

        # Weight clipping: keep every critic parameter inside [-C, C].
        for p in D.parameters():
            p.data.clamp_(-C, C)

        # ---- generator update, once every 5 critic updates ----
        if idx % 5 == 0:
            g_optim.zero_grad()
            d_pred = D(G(noise, y), y).view(-1)
            # The generator tries to maximise the critic score of its samples.
            g_loss = -torch.mean(d_pred)
            g_loss.backward()
            g_optim.step()
            total_g_loss += g_loss.item()

    print(f'EPOCH {epoch} | d_loss: {total_d_loss} | g_loss: {total_g_loss}')
    # Show one generated sample for each digit 0-9.
    with torch.no_grad():
        for i in range(10):
            gen = G(torch.randn(1, 100, 1, 1, device=device), torch.tensor([i], device=device)).cpu()
            img = gen.reshape(28, 28).numpy()
            # The generator ends in Tanh and can emit negative values, which
            # wrap around when cast to uint8. Clip to [0, 1] (the range of the
            # ToTensor-scaled training images) before scaling to 0-255.
            img = Image.fromarray(np.uint8(np.clip(img, 0.0, 1.0) * 255), 'L')
            display(img)