# 概要

このノートブックでは,
PyTorchを用いてGANの実装を行う.

In [1]:
from keras.datasets import mnist

In [2]:
(train_X, train_Y), (test_X, test_Y) = mnist.load_data()
# Flatten every image into a single 1-D feature vector and divide by 255
# (assumes raw pixel intensities in 0..255, so the result lies in [0, 1]).
train_X = train_X.reshape(len(train_X), -1) / 255
test_X = test_X.reshape(len(test_X), -1) / 255
# Labels as plain 1-D arrays.
train_Y = train_Y.reshape(-1)
test_Y = test_Y.reshape(-1)

## Generative Adversarial Network (GAN)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [5]:
class Model(nn.Module):
    """Fully connected network: ReLU after every hidden layer, sigmoid on the output.

    ``structure`` lists the layer widths in order. For example
    ``Model(784, 1024, 1)`` registers ``fc1 = Linear(784, 1024)`` and
    ``fc2 = Linear(1024, 1)``.

    Raises:
        ValueError: if fewer than two widths (input and output) are given.
    """

    def __init__(self, *structure):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, registering fc1 below raises AttributeError.
        super().__init__()
        if len(structure) < 2:
            raise ValueError(
                f"structure needs at least an input and an output width, got {structure!r}"
            )
        self.depth = len(structure)
        self.units = structure

        for i in range(self.depth - 1):
            layer = nn.Linear(self.units[i], self.units[i + 1])
            # He (Kaiming) initialisation, matching the ReLU activations in forward.
            nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")
            # Registered as fc1, fc2, ... so parameter / state_dict names stay stable.
            setattr(self, f"fc{i + 1}", layer)

    def forward(self, x):
        """Apply fc1..fc(depth-2) with ReLU, then the last layer with sigmoid."""
        out = x
        for i in range(1, self.depth - 1):
            out = F.relu(getattr(self, f"fc{i}")(out))
        return torch.sigmoid(getattr(self, f"fc{self.depth - 1}")(out))

In [6]:
class Solver(nn.Module):
    def __init__(self, max_epoch=100, batch_size=128,
                 g_structure=[100, 300, 500, 784], d_structure=[784, 1024, 1024, 1]):
        """Build the generator/discriminator pair, their optimizers and the loss.

        Args:
            max_epoch: number of epochs ``fit`` runs.
            batch_size: mini-batch size used when iterating the training data.
            g_structure: layer widths of the generator ``G``; the first entry is
                the dimension of the noise vector sampled for it.
            d_structure: layer widths of the discriminator ``D``.
        """
        # nn.Module.__init__ must run before G, D and criterion (all Modules)
        # are assigned; otherwise the assignment raises AttributeError.
        super().__init__()
        self.max_epoch = max_epoch
        self.batch_size = batch_size

        # Copy the structures so the shared list defaults are never exposed:
        # mutating self.g_structure must not leak into later Solver instances.
        self.g_structure = list(g_structure)
        self.d_structure = list(d_structure)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.G = Model(*self.g_structure).to(self.device)
        self.D = Model(*self.d_structure).to(self.device)
        self.g_optimizer = torch.optim.Adam(self.G.parameters())
        self.d_optimizer = torch.optim.Adam(self.D.parameters())

        self.criterion = nn.BCELoss()
        
    def fit(self, train_X):
        self.g_loss = []
        self.d_loss = []
        for epoch in range(self.max_epoch):
            g_loss, d_loss = self._epoch_procedure(train_X)
            self.g_loss.append(g_loss)
            self.d_loss.append(d_loss)
            
    def _epoch_procedure(self, train_X):
        loader = torch.utils.data.DataLoader(torch.from_numpy(train_X).to(self.device), batch_size=self.batch_size, shuffle=True)
        
        running_d_loss = 0.0
        running_g_loss = 0.0
        for idx, (inputs) in enumerate(loader):
            inputs = Variable(inputs).float()
            real_label = torch.ones(inputs.size(0)).long()
            fake_label = torch.zeros(inputs.size(0)).long()
            z = np.random.normal(0.0, 1.0, (inputs.size(0), self.g_structure[0]))
            z = torch.from_numpy(z).to(self.device)
            z = Variable(z).float()
            