In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# Seed PyTorch's global RNG so weight initialisation (torch.nn.init) and sampled noise
# (torch.randn) are repeatable across Restart & Run All.
torch.manual_seed(0)

<class 'torch.FloatTensor'>


생성자

In [2]:
# Output size per spatial dimension:
#   ConvTranspose2d: H_out = (H_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1
#   Conv2d:          H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
# With kernel 4 / stride 2 the generator goes 1 -> 4 (padding 0) -> 8 -> 16 -> 32 -> 64 (padding 1).
class Generator(nn.Module):
    """DCGAN-style generator that upsamples a noise vector into an image.

    Args:
        z_dim: length of each noise vector.
        im_chan: channels of the generated image.
        hidden_dim: base channel width; the blocks use hidden_dim*8, *4, *2, *1 channels.

    forward(noise) takes `noise` with batch * z_dim elements (e.g. shape (batch, z_dim))
    and returns images of shape (batch, im_chan, 64, 64) with values in [-1, 1] (final Tanh).
    """

    def __init__(self, z_dim=10, im_chan=3, hidden_dim=32):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.gen = nn.Sequential(
            self.gen_block(z_dim, hidden_dim * 8, padding=0),       # 1x1   -> 4x4
            self.gen_block(hidden_dim * 8, hidden_dim * 4),         # 4x4   -> 8x8
            self.gen_block(hidden_dim * 4, hidden_dim * 2),         # 8x8   -> 16x16
            self.gen_block(hidden_dim * 2, hidden_dim),             # 16x16 -> 32x32
            self.gen_block(hidden_dim, im_chan, final_layer=True),  # 32x32 -> 64x64
        )

    def gen_block(self, in_channel, out_channel, kernel_size=4, padding=1, stride=2, dilation=1, final_layer=False):
        """Build one upsampling stage.

        Returns nn.Sequential(ConvTranspose2d, BatchNorm2d, ReLU) for hidden stages and
        nn.Sequential(ConvTranspose2d, Tanh) when `final_layer` is True.
        `kernel_size`, `padding`, `stride` and `dilation` are all forwarded to the
        transposed convolution in both cases. Previously `dilation` was ignored, and the
        final layer hard-coded stride=2 and padding=1. The defaults produce the same
        layers as before.
        """
        conv = nn.ConvTranspose2d(
            in_channel, out_channel,
            kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
        )
        if final_layer:
            return nn.Sequential(conv, nn.Tanh())
        return nn.Sequential(conv, nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))

    def unsqueeze_noise(self, noise):
        """Reshape noise to (batch, z_dim, 1, 1): a 1x1 'image' with z_dim channels."""
        return noise.view(len(noise), self.z_dim, 1, 1)

    def forward(self, noise):
        x = self.unsqueeze_noise(noise)
        return self.gen(x)

Critic

In [3]:
class Critic(nn.Module):
    """WGAN-GP critic: maps a batch of images to one unbounded score per image.

    Args:
        im_chan: number of input image channels.
        hidden_dim: base channel width; doubled in each of the first four blocks.

    The four stride-2 blocks reduce a 64x64 input 64 -> 32 -> 16 -> 8 -> 4, and the
    final stride-1 / padding-0 block reduces 4 -> 1. forward() therefore returns a
    tensor of shape (batch, 1).
    """

    def __init__(self, im_chan=3, hidden_dim=16):
        super(Critic, self).__init__()
        self.crit = nn.Sequential(
            self.crit_block(im_chan, hidden_dim),
            self.crit_block(hidden_dim, hidden_dim * 2),
            self.crit_block(hidden_dim * 2, hidden_dim * 4),
            self.crit_block(hidden_dim * 4, hidden_dim * 8),
            # 4x4 -> 1x1, i.e. a single score per image. The previous stride-2 /
            # padding-1 setting left a 2x2 map, i.e. four scores per image.
            self.crit_block(hidden_dim * 8, 1, stride=1, padding=0, final_layer=True),
        )

    def crit_block(self, in_, out, kernel_size=4, stride=2, padding=1, final_layer=False):
        """Build one downsampling stage.

        Returns nn.Sequential(Conv2d, GroupNorm, LeakyReLU(0.2)) for hidden stages and
        nn.Sequential(Conv2d) (raw scores, no activation) when `final_layer` is True.
        """
        conv = nn.Conv2d(in_, out, kernel_size=kernel_size, stride=stride, padding=padding)
        if final_layer:
            return nn.Sequential(conv)
        return nn.Sequential(
            conv,
            # Per-sample normalisation. GroupNorm with one group equals LayerNorm over
            # (C, H, W). BatchNorm would mix statistics across the batch, while the
            # gradient penalty assumes each score depends only on its own image.
            nn.GroupNorm(1, out),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        crit_pred = self.crit(image)
        # Flatten everything after the batch dimension.
        return crit_pred.view(len(crit_pred), -1)

One-hot vector 생성

In [4]:
import torch.nn.functional as F
def get_one_Hot_labels(labels, n_clases):
  """One-hot encode integer class `labels` over `n_clases` classes via F.one_hot.

  `labels` must be an integer (int64) tensor. The result has one extra trailing
  dimension of size `n_clases`. The parameter keeps its original spelling so
  keyword callers still work.
  """
  # Use the parameter; the previous body read an undefined global `n_classes`.
  return F.one_hot(labels, n_clases)

Latent vector와 one-hot vector의 concatenation

In [5]:
def combine_vectors(x, y):
  """Concatenate `x` and `y` along dim 1 after casting both to float32.

  Intended for appending one-hot class vectors to latent noise vectors.
  """
  combined = torch.cat((x.float(), y.float()), dim=1)
  # Previously returned the misspelled name `combied` (NameError).
  return combined

노이즈 생성

In [6]:
def get_noise(n_samples, z_dim, device='cpu'):
    """Draw an (n_samples, z_dim) batch of standard-normal noise vectors on `device`."""
    shape = (n_samples, z_dim)
    return torch.randn(shape, device=device)

하이퍼파라미터 설정

In [7]:
n_epochs = 100
z_dim = 64          # length of each generator noise vector
display_step = 50   # print losses / show images every `display_step` generator steps
batch_size = 128
lr = 0.0002
beta_1 = 0.5        # Adam betas for both optimizers
beta_2 = 0.999
c_lambda = 10       # weight of the gradient penalty in the critic loss
crit_repeats = 5    # critic updates per generator update
# Fall back to CPU so the notebook still runs (slowly) on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

데이터 로딩

In [8]:
# Mount Google Drive at /content/drive using Colab's helper (google.colab exists only inside Colab).
# The following cells read the CelebA images from a folder on this Drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [16]:
cd /content/drive/My Drive/celeba/

/content/drive/My Drive/celeba


In [20]:
# Image folder, relative to the working directory set by the %cd / cd cell.
# NOTE(review): ImageFolder expects class sub-directories under `root`
# (e.g. img_align_celeba/<subdir>/*.jpg); the training loop ignores the labels it yields.
url = 'img_align_celeba/'

# Resize the short side to 64 and centre-crop to 64x64. Normalising with mean/std 0.5
# maps pixels to [-1, 1], matching the generator's Tanh output range.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = torchvision.datasets.ImageFolder(root=url, transform=transform)
# Use the configured batch size instead of a duplicated literal 128.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Optional alternative data source: running this cell replaces the CelebA `dataloader`.
# MNIST digits are 1x28x28, but Generator/Critic are built with im_chan=3 and sized for
# 64x64 images (as is show_tensor_images' default size). So upscale to 64x64 and
# replicate the grey channel to 3 channels. Normalising with 0.5/0.5 maps pixels to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataloader = DataLoader(
    MNIST('.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

Models, weight initialization, and optimizers

In [21]:
def weights_init(m):
    """DCGAN-style in-place initialisation of one module; use via nn.Module.apply.

    - Conv2d / ConvTranspose2d weights ~ N(0, 0.02).
    - BatchNorm2d scale ~ N(1, 0.02), shift = 0.
    Other modules (including other normalisation layers) keep PyTorch's defaults.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        # Scale centred on 1. Centring it on 0 (as before) started every BN layer with
        # a gain of about 0.02, so its output began close to the zero shift.
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)


gen = Generator(z_dim).to(device)
crit = Critic().to(device)
# .apply re-initialises in place and returns the module itself.
gen = gen.apply(weights_init)
crit = crit.apply(weights_init)

gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))

Gradient penalty

In [22]:
def get_gradient(crit, real, fake, epsilon):
  """Gradient of the critic's scores w.r.t. images interpolated between `real` and `fake`.

  Args:
    crit: the critic; called once on the interpolated batch.
    real, fake: image batches of the same shape.
    epsilon: interpolation weights broadcastable against the images
      (e.g. shape (batch, 1, 1, 1)). Each interpolated image is
      epsilon * real + (1 - epsilon) * fake.

  Returns:
    A tensor shaped like the interpolated images. It is built with create_graph=True,
    so a penalty computed from it can be back-propagated into `crit`.
  """
  mixed_images = real * epsilon + fake * (1 - epsilon)
  # autograd.grad needs the interpolated images in the graph. If none of
  # real/fake/epsilon requires grad, mixed_images is a plain leaf and autograd.grad
  # would raise, so mark it explicitly. This is skipped when it already requires grad.
  if not mixed_images.requires_grad:
    mixed_images.requires_grad_(True)
  mixed_scores = crit(mixed_images)

  # grad_outputs of ones -> gradient of the sum of all scores w.r.t. the inputs.
  gradient = torch.autograd.grad(
      inputs=mixed_images,
      outputs=mixed_scores,
      grad_outputs=torch.ones_like(mixed_scores),
      create_graph=True,
      retain_graph=True,
  )[0]
  return gradient

In [23]:
def gradient_penalty(gradient):
  """Two-sided WGAN-GP penalty: batch mean of (||g_i||_2 - 1)^2.

  Each sample's gradient (first dim = batch) is flattened before taking its L2 norm.
  """
  per_sample = gradient.view(len(gradient), -1)
  norms = per_sample.norm(2, dim=1)
  return ((norms - 1) ** 2).mean()

loss 함수

In [24]:
def get_gen_loss(crit_fake_pred):
  """Generator loss: the negated mean critic score on generated images."""
  return -crit_fake_pred.mean()

def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
  """Critic loss: mean(fake scores) - mean(real scores) + c_lambda * gp."""
  score_gap = crit_fake_pred.mean() - crit_real_pred.mean()
  return score_gap + c_lambda * gp

Image display

In [25]:
def show_tensor_images(image_tensor, num_images=25, size=(3,64,64), nrow=5):
    """Plot the first `num_images` images of `image_tensor` as a grid.

    Args:
        image_tensor: batch of images with values in [-1, 1] (generator Tanh output, or
            data normalised with mean 0.5 / std 0.5); reshaped to (-1, *size).
        num_images: how many images from the front of the batch to draw.
        size: (channels, height, width) of one image.
        nrow: images per grid row (default 5, as before).
    """
    # Map [-1, 1] back to [0, 1]. Otherwise matplotlib clips the negative half of the range.
    image_unflat = ((image_tensor.detach().cpu() + 1) / 2).view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze().clamp(0, 1))
    plt.show()

Training

In [None]:
import matplotlib.pyplot as plt

def plot_losses(generator_losses, critic_losses, step_bins=20):
  """Plot both loss curves, each averaged over consecutive windows of `step_bins` steps."""
  num_examples = (len(generator_losses) // step_bins) * step_bins
  x = range(num_examples // step_bins)
  plt.plot(
      x,
      torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
      label="Generator loss"
  )
  plt.plot(
      x,
      torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
      label="Critic loss"
  )
  plt.legend()
  plt.show()


cur_step = 0
generator_losses = []
critic_losses = []
for epoch in range(n_epochs):
  for real, _ in tqdm(dataloader):
    cur_batch_size = len(real)
    real = real.to(device)

    # Critic: `crit_repeats` updates per generator update.
    mean_iteration_critic_loss = 0
    for _ in range(crit_repeats):
      crit_opt.zero_grad()
      fake_noise = get_noise(cur_batch_size, z_dim, device=device)
      # The critic step never back-propagates into the generator, so skip building its graph.
      with torch.no_grad():
        fake = gen(fake_noise)
      crit_fake_pred = crit(fake)
      crit_real_pred = crit(real)

      # requires_grad on epsilon puts the interpolated images into the autograd graph
      # that get_gradient differentiates.
      epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
      gradient = get_gradient(crit, real, fake, epsilon)
      gp = gradient_penalty(gradient)
      crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

      mean_iteration_critic_loss += crit_loss.item() / crit_repeats
      # No retain_graph: this graph is never back-propagated through again.
      crit_loss.backward()
      crit_opt.step()
    critic_losses += [mean_iteration_critic_loss]

    # Generator: one update against the freshly updated critic. Gradients this
    # leaves on the critic's parameters are cleared by crit_opt.zero_grad() above.
    gen_opt.zero_grad()
    fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
    fake_2 = gen(fake_noise_2)
    crit_fake_pred = crit(fake_2)

    gen_loss = get_gen_loss(crit_fake_pred)
    gen_loss.backward()
    gen_opt.step()

    generator_losses += [gen_loss.item()]

    if cur_step % display_step == 0 and cur_step > 0:
      gen_mean = sum(generator_losses[-display_step:]) / display_step
      crit_mean = sum(critic_losses[-display_step:]) / display_step
      print(f"step : {cur_step}, Generator loss: {gen_mean}, Critic loss: {crit_mean}")
      show_tensor_images(fake)
      show_tensor_images(real)
      plot_losses(generator_losses, critic_losses, step_bins=20)
    cur_step += 1
                 

HBox(children=(FloatProgress(value=0.0, max=1220.0), HTML(value='')))