In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.utils as vutils
import numpy as np

In [2]:
# --- Training hyperparameters -------------------------------------------
batch_size = 64   # samples per mini-batch; also the number of fixed preview samples
n_epoch = 50      # nominal epoch budget (NOTE(review): the visible driver loop does not read it)
lr = 0.001        # Adam learning rate, shared by both optimizers
beta1 = 0.5       # first Adam beta coefficient

# --- Model dimensions ----------------------------------------------------
nz = 100          # number of channels of the latent noise fed to the generator
nch = 1           # number of image channels
nch_g = 28        # base channel width of the generator
nch_d = 28        # base channel width of the discriminator

# --- Output location and compute device ----------------------------------
result_dir = "./gan_results"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: {}".format(device))

# exist_ok=True tolerates an already-existing directory but still surfaces
# genuine failures (permissions, path occupied by a regular file) right here,
# instead of hiding them until the first image is written.
os.makedirs(result_dir, exist_ok=True)

device: cuda


In [3]:
def loadData(img_path, label_path):
    """Build a TensorDataset from a pair of ``.npz`` archives.

    Args:
        img_path: Path to an ``.npz`` archive whose ``"arr_0"`` entry holds
            the images; it must be reshapeable to ``(-1, 1, 28, 28)``.
        label_path: Path to an ``.npz`` archive whose ``"arr_0"`` entry holds
            the labels, one per image.

    Returns:
        ``torch.utils.data.TensorDataset`` yielding ``(image, label)`` pairs,
        where each image has shape ``(1, 28, 28)`` and has been divided by 255.
        NOTE(review): presumably the raw pixels are uint8 in 0..255, which
        makes the scaled images float64 in [0, 1] -- confirm against the data.
    """
    # np.load on an .npz returns a lazy NpzFile that keeps the archive open;
    # the context manager closes it once the array has been materialised.
    with np.load(img_path) as img_archive:
        img_np = img_archive["arr_0"].reshape(-1, 1, 28, 28)
    img_np = img_np / 255

    with np.load(label_path) as label_archive:
        label_np = label_archive["arr_0"]

    return torch.utils.data.TensorDataset(torch.from_numpy(img_np),
                                          torch.from_numpy(label_np))

# Training split: shuffled mini-batches; the final incomplete batch is dropped
# so every step sees exactly `batch_size` samples.
train_data = loadData(
    img_path="kmnist/k49-train-imgs.npz",
    label_path="kmnist/k49-train-labels.npz",
)
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Test split: loader left at DataLoader's default settings.
test_data = loadData(
    img_path="kmnist/k49-test-imgs.npz",
    label_path="kmnist/k49-test-labels.npz",
)
test_loader = torch.utils.data.DataLoader(test_data)

In [4]:
def weights_init(m):
    """Initialise one module in place; intended for ``net.apply(weights_init)``.

    Modules are matched by substring of their class name:

    * ``Conv*`` / ``Linear*``: weight ~ N(0, 0.02), bias = 0
    * ``BatchNorm*``:          weight ~ N(1, 0.02), bias = 0

    Any other module is left untouched. Parameters that do not exist
    (``bias=False`` layers, ``affine=False`` batch norm, or a container whose
    class name merely contains one of the substrings) are skipped instead of
    raising ``AttributeError``.

    Args:
        m: The ``nn.Module`` to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        weight_mean = 0.0
    elif 'BatchNorm' in classname:
        weight_mean = 1.0
    else:
        return

    weight = getattr(m, 'weight', None)
    if weight is not None:
        nn.init.normal_(weight, weight_mean, 0.02)

    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.zeros_(bias)

In [5]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping latent noise to an image.

    Four stages (``l1`` .. ``l4``) are applied in sequence. The first three are
    ConvTranspose2d -> BatchNorm2d -> ReLU; the last is ConvTranspose2d -> Tanh,
    so outputs lie in [-1, 1].

    With the default arguments the feature-map shapes are
    (100, 1, 1) -> (112, 3, 3) -> (56, 7, 7) -> (28, 14, 14) -> (1, 28, 28).

    Args:
        nz: Number of channels of the latent input.
        nch_g: Base channel width; hidden stages use 4x, 2x and 1x this value.
        nch: Number of channels of the generated image.
    """

    def __init__(self, nz=100, nch_g=28, nch=1):
        super().__init__()
        wide, medium, narrow = nch_g * 4, nch_g * 2, nch_g

        # (nz, 1, 1) -> (4*nch_g, 3, 3)
        self.l1 = self._up_stage(nz, wide, kernel_size=3, stride=1, padding=0)
        # (4*nch_g, 3, 3) -> (2*nch_g, 7, 7)
        self.l2 = self._up_stage(wide, medium, kernel_size=5, stride=2, padding=1)
        # (2*nch_g, 7, 7) -> (nch_g, 14, 14)
        self.l3 = self._up_stage(medium, narrow, kernel_size=4, stride=2, padding=1)
        # (nch_g, 14, 14) -> (nch, 28, 28); no normalisation on the output stage
        self.l4 = nn.Sequential(
            nn.ConvTranspose2d(narrow, nch, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    @staticmethod
    def _up_stage(in_ch, out_ch, kernel_size, stride, padding):
        """Return a ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride, padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run ``x`` through ``l1`` .. ``l4`` and return the generated images."""
        for stage in (self.l1, self.l2, self.l3, self.l4):
            x = stage(x)
        return x
        

# The generator must emit `nch` image channels (not `nch_g`, its hidden width)
# so that its output matches the channel count the discriminator is built for.
net_g = Generator(nz=nz, nch_g=nch_g, nch=nch).to(device)
net_g.apply(weights_init)
print(net_g)

Generator(
  (l1): Sequential(
    (0): ConvTranspose2d(100, 112, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (l2): Sequential(
    (0): ConvTranspose2d(112, 56, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (l3): Sequential(
    (0): ConvTranspose2d(56, 28, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (l4): Sequential(
    (0): ConvTranspose2d(28, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): Tanh()
  )
)


In [6]:
class Discriminator(nn.Module):
    """Strided-convolution discriminator producing one raw score per image.

    Stages ``l1`` .. ``l3`` each halve the spatial size with a 4x4 stride-2
    convolution followed by LeakyReLU(0.2); ``l2`` and ``l3`` additionally
    apply BatchNorm2d. ``l4`` is a bare 3x3 convolution with no activation.

    With the default arguments the feature-map shapes are
    (1, 28, 28) -> (28, 14, 14) -> (56, 7, 7) -> (112, 3, 3) -> (1, 1, 1).

    Args:
        nch: Number of channels of the input image.
        nch_d: Base channel width; stages use 1x, 2x and 4x this value.
    """

    def __init__(self, nch=1, nch_d=28):
        super().__init__()
        slope = 0.2

        # (nch, 28, 28) -> (nch_d, 14, 14); no normalisation on the input stage
        self.l1 = nn.Sequential(
            nn.Conv2d(nch, nch_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=slope),
        )
        # (nch_d, 14, 14) -> (2*nch_d, 7, 7)
        self.l2 = self._down_stage(nch_d, nch_d * 2, slope)
        # (2*nch_d, 7, 7) -> (4*nch_d, 3, 3)
        self.l3 = self._down_stage(nch_d * 2, nch_d * 4, slope)
        # (4*nch_d, 3, 3) -> (1, 1, 1)
        self.l4 = nn.Conv2d(nch_d * 4, 1, kernel_size=3, stride=1)

    @staticmethod
    def _down_stage(in_ch, out_ch, slope):
        """Return a Conv2d(4, stride 2, pad 1) -> BatchNorm2d -> LeakyReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(negative_slope=slope),
        )

    def forward(self, x):
        """Run ``x`` through ``l1`` .. ``l4`` and return the raw scores."""
        for stage in (self.l1, self.l2, self.l3, self.l4):
            x = stage(x)
        return x
    
# Build the discriminator, move it to the selected device, then re-initialise
# every sub-module with the custom scheme before printing the architecture.
net_d = Discriminator(nch=nch, nch_d=nch_d)
net_d = net_d.to(device)
net_d.apply(weights_init)
print(net_d)

Discriminator(
  (l1): Sequential(
    (0): Conv2d(1, 28, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
  )
  (l2): Sequential(
    (0): Conv2d(28, 56, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (l3): Sequential(
    (0): Conv2d(56, 112, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (l4): Conv2d(112, 1, kernel_size=(3, 3), stride=(1, 1))
)


In [7]:
# Squared-error loss between discriminator scores and the 0/1 targets.
criterion = nn.MSELoss()

# Both networks use identically configured Adam optimizers.
adam_options = dict(lr=lr, betas=(beta1, 0.999), weight_decay=1e-5)
optimizer_g = torch.optim.Adam(net_g.parameters(), **adam_options)
optimizer_d = torch.optim.Adam(net_d.parameters(), **adam_options)

# A fixed latent batch, reused for every saved sample grid so that grids from
# different epochs are directly comparable.
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)

In [8]:
def train(epoch):
    """Run one epoch of adversarial training and save a grid of samples.

    For every mini-batch from ``train_loader`` the discriminator is updated
    on a real batch (target 1) and a generated batch (target 0), then the
    generator is updated so that the discriminator scores its output as 1.
    The losses of the first step are printed. After the epoch, images
    generated from ``fixed_noise`` are written to
    ``<result_dir>/fake_samples_epoch_NNN.png``.

    Relies on the module-level ``net_g``, ``net_d``, ``optimizer_g``,
    ``optimizer_d``, ``criterion``, ``train_loader``, ``fixed_noise``, ``nz``,
    ``device`` and ``result_dir``.

    Args:
        epoch: Zero-based epoch index; reported and saved as ``epoch + 1``.
    """
    for step, (data, _) in enumerate(train_loader):
        real_image = data.to(device, dtype=torch.float)
        sample_size = real_image.size(0)
        noise = torch.randn(sample_size, nz, 1, 1, device=device)

        # Targets are created as float explicitly: an integer fill value gives
        # int64 tensors on recent PyTorch, which the loss rejects against
        # float predictions.
        real_target = torch.ones(sample_size, dtype=torch.float, device=device)
        fake_target = torch.zeros(sample_size, dtype=torch.float, device=device)

        # ---- Discriminator update ----
        net_d.zero_grad()
        # The discriminator returns (N, 1, 1, 1); flatten to (N,) so the loss
        # compares element-wise instead of broadcasting to an N x N grid.
        output = net_d(real_image).view(-1)
        loss_d_real = criterion(output, real_target)

        fake_image = net_g(noise)
        # detach(): the discriminator step must not backpropagate into net_g.
        output = net_d(fake_image.detach()).view(-1)
        loss_d_fake = criterion(output, fake_target)

        loss_d = loss_d_real + loss_d_fake
        loss_d.backward()
        optimizer_d.step()

        # ---- Generator update ----
        net_g.zero_grad()
        # Re-score the same fakes with the just-updated discriminator, this
        # time keeping the graph so gradients reach the generator.
        output = net_d(fake_image).view(-1)
        loss_g = criterion(output, real_target)
        loss_g.backward()
        optimizer_g.step()

        if step == 0:
            print("Epoch:{}, Loss_G:{:.4}, Loss_D:{:.4}".format(epoch+1, loss_g.item(), loss_d.item()))

    # Sampling is inference only; skip building an autograd graph for it.
    with torch.no_grad():
        fake_image = net_g(fixed_noise)
    vutils.save_image(fake_image, '{}/fake_samples_epoch_{:03d}.png'.format(result_dir, epoch + 1), normalize=True, nrow=8)

In [9]:
# Short demonstration run: only a few epochs are trained here.
n_demo_epochs = 3
for epoch in range(n_demo_epochs):
    train(epoch)

Epoch:1, Loss_G:5.732, Loss_D:0.8799
Epoch:2, Loss_G:0.8599, Loss_D:0.1057
Epoch:3, Loss_G:0.8004, Loss_D:0.06943
