In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.utils as vutils
import numpy as np

In [2]:
batch_size = 49
n_epoch = 50
lr_g = 0.001
lr_d = 0.0002
beta1 = 0.5
nz = 100
nch = 1
nch_g = 28
nch_d = 28
n_label = 49
result_dir = "./gan_results"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: {}".format(device))
try:
    os.makedirs(result_dir)
except OSError:
    pass

device: cuda


In [3]:
def loadData(img_path, label_path):
    img_np = np.load(img_path)["arr_0"].reshape(-1, 1, 28, 28)
    img_np = img_np/255
    label_np = np.load(label_path)["arr_0"]
    return torch.utils.data.TensorDataset(torch.from_numpy(img_np),
                                            torch.from_numpy(label_np))

train_data = loadData("kmnist/k49-train-imgs.npz", "kmnist/k49-train-labels.npz")
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
test_data = loadData("kmnist/k49-test-imgs.npz", "kmnist/k49-test-labels.npz")
test_loader = torch.utils.data.DataLoader(test_data)

In [4]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [5]:
class Generator(nn.Module):
    def __init__(self, nz=100, nch_g=28, nch=1):
        super(Generator, self).__init__()
        self.l1 =  nn.Sequential(
            nn.ConvTranspose2d(nz, nch_g * 4, 3, 1), 
            nn.BatchNorm2d(nch_g * 4),    
            nn.ReLU()
        )  # (100, 1, 1) -> (112, 3, 3)
        self.l2 = nn.Sequential(
            nn.ConvTranspose2d(nch_g * 4, nch_g * 2, 5, 2, 1),
            nn.BatchNorm2d(nch_g * 2),
            nn.ReLU()
        )  # (112, 3, 3) -> (56, 7, 7)
        self.l3 = nn.Sequential(
            nn.ConvTranspose2d(nch_g * 2, nch_g, 4, 2, 1),
            nn.BatchNorm2d(nch_g),
            nn.ReLU()
        )  # (56, 7, 7) -> (28, 14, 14)
        self.l4 = nn.Sequential(
            nn.ConvTranspose2d(nch_g, nch, 4, 2, 1),
            nn.Tanh()
        )   # (28, 14, 14) -> (1, 28, 28)
    
    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.l4(x)
        return x
        

net_g = Generator(nz=(nz + n_label), nch_g=nch_g, nch=nch).to(device)
net_g.apply(weights_init)
print(net_g)

Generator(
  (l1): Sequential(
    (0): ConvTranspose2d(149, 112, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (l2): Sequential(
    (0): ConvTranspose2d(112, 56, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (l3): Sequential(
    (0): ConvTranspose2d(56, 28, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (l4): Sequential(
    (0): ConvTranspose2d(28, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): Tanh()
  )
)


In [6]:
class Discriminator(nn.Module):
    def __init__(self, nch=1, nch_d=28):
        super(Discriminator, self).__init__()
        self.l1 = nn.Sequential(
            nn.Conv2d(nch, nch_d, 4, 2, 1),
            nn.LeakyReLU(negative_slope=0.2)
            # (1, 28, 28) -> (28, 14, 14)
        )
        self.l2 = nn.Sequential(
            nn.Conv2d(nch_d, nch_d * 2, 4, 2, 1),
            nn.BatchNorm2d(nch_d * 2),
            nn.LeakyReLU(negative_slope=0.2)
            # (28, 14, 14) -> (56, 7, 7)
        )
        self.l3 = nn.Sequential(
            nn.Conv2d(nch_d * 2, nch_d * 4, 4, 2, 1),
            nn.BatchNorm2d(nch_d * 4),
            nn.LeakyReLU(negative_slope=0.2)
            # (56, 7, 7) -> (112, 3, 3)
        )
        self.l4 = nn.Conv2d(nch_d * 4, 1, 3, 1)
            # (112, 4, 4) -> (1, 1, 1)
        
    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.l4(x)
        return x
    
net_d = Discriminator(nch=(nch + n_label), nch_d=nch_d).to(device)
net_d.apply(weights_init)
print(net_d)

Discriminator(
  (l1): Sequential(
    (0): Conv2d(50, 28, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
  )
  (l2): Sequential(
    (0): Conv2d(28, 56, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (l3): Sequential(
    (0): Conv2d(56, 112, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (l4): Conv2d(112, 1, kernel_size=(3, 3), stride=(1, 1))
)


In [7]:
def onehot_encode(label, device, n_label=49):
    eye = torch.eye(n_label, device=device)
    return eye[label].view(-1, n_label, 1, 1)

def concat_image_label(image, label, device, n_label=49):
    b, c, h, w = image.shape
    oh_label = onehot_encode(label, device)
    oh_label = oh_label.expand(b, n_label, h, w)
    return torch.cat((image, oh_label), dim=1)

def concat_noise_label(noise, label, device):
    oh_label = onehot_encode(label, device)
    return torch.cat((noise, oh_label), dim=1)

In [8]:
criterion = nn.MSELoss()
optimizer_g = torch.optim.Adam(net_g.parameters(), lr=lr_g, betas=(beta1, 0.999), weight_decay=1e-5)
optimizer_d = torch.optim.Adam(net_d.parameters(), lr=lr_d, betas=(beta1, 0.999), weight_decay=1e-5)

fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)
fixed_label = [i for i in range(n_label)] * (batch_size // n_label)
fixed_label = torch.tensor(fixed_label, dtype=torch.long, device=device)
fixed_noise_label = concat_noise_label(fixed_noise, fixed_label, device)

In [9]:
def train(epoch):
    for step, (data, target) in enumerate(train_loader, 0):
        real_image = data.to("cuda", dtype=torch.float)        
        real_label = target.to("cuda", dtype=torch.long)
        real_image_label = concat_image_label(real_image, real_label, device)
        
        sample_size = real_image.size(0)
        noise = torch.randn(sample_size, nz, 1, 1, device=device)
        fake_label = torch.randint(n_label, (sample_size,), dtype=torch.long, device=device)
        fake_noise_label = concat_noise_label(noise, fake_label, device)
        
        real_target = torch.full((sample_size,), 1, device=device)
        fake_target = torch.full((sample_size,), 0, device=device)
        
        net_d.zero_grad()
        output = net_d(real_image_label)
        loss_d_real = criterion(output, real_target)
        d_x = output.mean().item()

        fake_image = net_g(fake_noise_label)
        fake_image_label = concat_image_label(fake_image, fake_label, device=device)
        output = net_d(fake_image_label.detach())
        loss_d_fake = criterion(output, fake_target)
        d_g_z1 = output.mean().item()
        
        loss_d = loss_d_real + loss_d_fake
        loss_d.backward()
        optimizer_d.step()
        
        net_g.zero_grad()
        output = net_d(fake_image_label)
        loss_g = criterion(output, real_target)
        loss_g.backward()
        d_g_z2 = output.mean().item()
        optimizer_g.step()
        if(step == 0):
            print("Epoch:{}, Loss_G:{:.4}, Loss_D:{:.4}".format(epoch+1, loss_g.item(), loss_d.item()))
    fake_image = net_g(fixed_noise_label)
    vutils.save_image(fake_image.detach(), '{}/fake_samples_epoch_{:03d}.png'.format(result_dir, epoch + 1), normalize=True, nrow=7)
    if (epoch + 1) % 10 == 0:  
            torch.save(net_g.state_dict(), '{}/net_g_epoch_{}.pth'.format(result_dir, epoch + 1))
            torch.save(net_d.state_dict(), '{}/net_d_epoch_{}.pth'.format(result_dir, epoch + 1))

In [10]:
for epoch in range(50):
    train(epoch)

Epoch:1, Loss_G:0.7664, Loss_D:1.367
Epoch:2, Loss_G:1.167, Loss_D:0.02186
Epoch:3, Loss_G:1.034, Loss_D:0.01668
Epoch:4, Loss_G:0.7952, Loss_D:0.02059
Epoch:5, Loss_G:0.917, Loss_D:0.003356
Epoch:6, Loss_G:1.005, Loss_D:0.002974
Epoch:7, Loss_G:0.9656, Loss_D:0.002365
Epoch:8, Loss_G:0.9799, Loss_D:0.005243
Epoch:9, Loss_G:1.025, Loss_D:0.001608
Epoch:10, Loss_G:0.5724, Loss_D:0.1224
Epoch:11, Loss_G:1.038, Loss_D:0.0124
Epoch:12, Loss_G:0.8795, Loss_D:0.01976
Epoch:13, Loss_G:1.117, Loss_D:0.03371
Epoch:14, Loss_G:1.001, Loss_D:0.009833
Epoch:15, Loss_G:0.9277, Loss_D:0.01022
Epoch:16, Loss_G:1.04, Loss_D:0.007725
Epoch:17, Loss_G:0.9564, Loss_D:0.004904
Epoch:18, Loss_G:1.017, Loss_D:0.006856
Epoch:19, Loss_G:0.9428, Loss_D:0.02481
Epoch:20, Loss_G:0.9732, Loss_D:0.02724
Epoch:21, Loss_G:1.007, Loss_D:0.0124
Epoch:22, Loss_G:1.035, Loss_D:0.03262
Epoch:23, Loss_G:1.053, Loss_D:0.1369
Epoch:24, Loss_G:0.9568, Loss_D:0.01224
Epoch:25, Loss_G:1.138, Loss_D:0.007819
Epoch:26, Loss_G:0.9