In [15]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import numpy as np
import math
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [16]:
pip install torchvision

Note: you may need to restart the kernel to use updated packages.


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets
from torchvision import models
from torchvision import transforms
from torch.utils.data import DataLoader

In [18]:
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
    

In [19]:
config = AttrDict()
config.data_path = 'data/'
config.save_path = 'save/'
config.dataset = 'MNIST'
config.n_epoch = 200
config.log_interval = 100
config.save_interval = 20
config.batch_size = 64
config.learning_rate = 0.0002
config.b1 = 0.5
config.b2 = 0.999
config.img_shape = (1, 28, 28)
config.latent_size = 100
"""
모델 입력 이미지에 수행할 normalization과 모델 생성 결과 이미지에 수행할 denormalization을 정의합니다.
"""
config.augmentation = transforms.Compose([
                        transforms.Resize((config.img_shape[1], config.img_shape[2])),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.5], std=[0.5]) 
                      ])
config.denormalize = lambda x: x*0.5+0.5
config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [20]:
if not os.path.isdir(config.data_path):
    os.makedirs(config.data_path)
if not os.path.isdir(os.path.join(config.save_path, config.dataset)):
    os.makedirs(os.path.join(config.save_path, config.dataset))

In [21]:
config.device

device(type='cpu')

In [22]:
if config.dataset == 'MNIST':
    train_dataset = datasets.MNIST(config.data_path,
                                    train=True,
                                    download=True,
                                    transform=config.augmentation
                                  ) 
###
#elif config.dataset == 'CIFAR10': 
    #train_dataset = datasets.CIFAR10(config.data_path,
                                       #train=True,
                                       #download=True,
                                       #transform=config.augmentation
                                     #)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
print(train_dataset)

Dataset MNIST
    Number of datapoints: 60000
    Root location: data/
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.5], std=[0.5])
           )


In [23]:
class Discriminator(nn.Module):
    def __init__(self, config):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(config.img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img = img.reshape(img.shape[0], -1)
        validity = self.model(img)
        return validity

In [24]:
class Generator(nn.Module):
    def __init__(self, config):
        super(Generator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(config.latent_size, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, int(np.prod(config.img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.reshape(img.shape[0], *config.img_shape)
        return img
    
    def block(self, input_size, output_size, batchnorm=True):
        layers = [nn.Linear(input_size, output_size)]
        if batchnorm:
            layers.append(nn.BatchNorm1d(output_size))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

In [25]:
criterion = nn.BCELoss()
generator = Generator(config).to(config.device)
discriminator = Discriminator(config).to(config.device)

optimizer_g = torch.optim.Adam(generator.parameters(), lr=config.learning_rate, betas=(config.b1, config.b2))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=config.learning_rate, betas=(config.b1, config.b2))

In [26]:
generator.model

Sequential(
  (0): Linear(in_features=100, out_features=256, bias=True)
  (1): LeakyReLU(negative_slope=0.2, inplace=True)
  (2): Linear(in_features=256, out_features=512, bias=True)
  (3): LeakyReLU(negative_slope=0.2, inplace=True)
  (4): Linear(in_features=512, out_features=1024, bias=True)
  (5): LeakyReLU(negative_slope=0.2, inplace=True)
  (6): Linear(in_features=1024, out_features=784, bias=True)
  (7): Tanh()
)

In [27]:
discriminator.model

Sequential(
  (0): Linear(in_features=784, out_features=512, bias=True)
  (1): LeakyReLU(negative_slope=0.2, inplace=True)
  (2): Linear(in_features=512, out_features=256, bias=True)
  (3): LeakyReLU(negative_slope=0.2, inplace=True)
  (4): Linear(in_features=256, out_features=1, bias=True)
  (5): Sigmoid()
)

In [28]:
g_loss_list = []
d_loss_list = []
for epoch in tqdm(range(config.n_epoch)):
    for i, (real_img, _) in enumerate(train_loader):
        
        real_img = real_img.to(config.device)
        
        valid_label = torch.ones((real_img.shape[0], 1), device=config.device, dtype=torch.float32)
        fake_label = torch.zeros((real_img.shape[0], 1), device=config.device, dtype=torch.float32)

        
        z = torch.randn((real_img.shape[0], config.latent_size), device=config.device, dtype=torch.float32)
        gen_img = generator(z)

        real_loss = criterion(discriminator(real_img), valid_label)
        fake_loss = criterion(discriminator(gen_img.detach()), fake_label)
        d_loss = (real_loss + fake_loss) * 0.5
        
        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()
        
        z = torch.randn((real_img.shape[0], config.latent_size), device=config.device, dtype=torch.float32)
        gen_img = generator(z)
        
        g_loss = criterion(discriminator(gen_img), valid_label)
        
        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        


        if (i+1) % config.log_interval == 0:
            g_loss_list.append(g_loss.item())
            d_loss_list.append(d_loss.item())
            print('Epoch [{}/{}] Batch [{}/{}] Discriminator loss: {:.4f} Generator loss: {:.4f}'.format(
                epoch+1, config.n_epoch, i+1, len(train_loader), d_loss.item(), g_loss.item()))

    if (epoch+1) % config.save_interval == 0:
        save_path = os.path.join(config.save_path, config.dataset, 'epoch_[{}].png'.format(epoch+1))
        gen_img = config.denormalize(gen_img)
        torchvision.utils.save_image(gen_img.data[:25], save_path, nrow=5, normalize=True)

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch [1/200] Batch [100/938] Discriminator loss: 0.7079 Generator loss: 0.7254
Epoch [1/200] Batch [200/938] Discriminator loss: 0.6800 Generator loss: 0.7843
Epoch [1/200] Batch [300/938] Discriminator loss: 0.5256 Generator loss: 0.9310
Epoch [1/200] Batch [400/938] Discriminator loss: 0.5359 Generator loss: 0.9733
Epoch [1/200] Batch [500/938] Discriminator loss: 0.5488 Generator loss: 1.3646
Epoch [1/200] Batch [600/938] Discriminator loss: 0.4027 Generator loss: 1.4724
Epoch [1/200] Batch [700/938] Discriminator loss: 0.4820 Generator loss: 1.3167
Epoch [1/200] Batch [800/938] Discriminator loss: 0.4823 Generator loss: 2.1213
Epoch [1/200] Batch [900/938] Discriminator loss: 0.4688 Generator loss: 1.6766
Epoch [2/200] Batch [100/938] Discriminator loss: 0.4939 Generator loss: 3.0066
Epoch [2/200] Batch [200/938] Discriminator loss: 0.4346 Generator loss: 1.5072
Epoch [2/200] Batch [300/938] Discriminator loss: 0.3850 Generator loss: 1.8061
Epoch [2/200] Batch [400/938] Discrimina

plt.title('GAN training loss on {} data'.format(config.dataset))
plt.plot(g_loss_list, label='generator loss')
plt.plot(d_loss_list, label='discriminator loss')
plt.legend()
plt.show()

In [31]:
save_path = os.path.join(config.save_path, config.dataset)
for image_path in os.listdir(save_path):
    if image_path.endswith('.png'):
        plt.figure(figsize=(5,5))
        image = image.open(os.path.join(save_path,image_path))
        plt.title(image_path)
        plt.imshow(image)
        plt.show()

NameError: name 'image' is not defined

<Figure size 360x360 with 0 Axes>