<a href="https://colab.research.google.com/github/PsorTheDoctor/Sekcja-SI/blob/master/neural_networks/GAN/stargan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# StarGAN

## Import bibliotek

In [0]:
import glob
import random
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys
from PIL import Image

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.autograd as autograd

import torch.nn as nn
import torch.nn.functional as F
import torch

## Hiperparametry

In [0]:
os.makedirs('images', exist_ok=True)
os.makedirs('saved_models', exist_ok=True)

In [0]:
epoch = 0                          # epoka, od której zaczyna się trening
n_epochs = 200                     # ilość epok 
dataset_name = 'img_align_celeba_4000'  # nazwa zbioru danych 
batch_size = 16                    # rozmiar wsadów
lr = 0.0002                        # współczynnik uczenia
b1 = 0.5                           # momentum pierwszego rzędu gradientu
b2 = 0.999                         # momentum pierwszego rzędu gradientu
decay_epoch = 100                  # epoka, od której zaczyna się spadek współczynnika uczenia
n_cpu = 8                          # ilość wątków cpu użytych podczas generacji wsadu
img_height = 128                   # wysokość obrazu
img_width = 128                    # szerokość obrazu
channels = 3                       # ilość kanałów
sample_interval = 400              # przerwa między zapisem próbek generatora
checkpoint_interval = -1           # przerwa między punktami kontrolnymi modelu
residual_blocks = 6                # ilość bloków w generatorze

# wybrane atrybuty do zbioru danych CelebA
selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']

n_critic = 5                       # ilość iteracji trenowania dyskryminatora WGAN

c_dim = len(selected_attrs)
img_shape = (channels, img_height, img_width)

In [4]:
cuda = torch.cuda.is_available()
cuda

True

In [0]:
# Funkcje straty
criterion_cycle = torch.nn.L1Loss()

def criterion_cls(logit, target):
  return F.binary_cross_entropy_with_logits(logit, target, size_average=False) / logit.size(0)

# Wagi straty
lambda_cls = 1
lambda_rec = 10
lambda_gp = 10

## Funkcje
### Załadowanie danych

In [0]:
# num_of_files = 2000

class CelebADataset(Dataset):
  def __init__(self, root, transforms_=None, mode='train', attributes=None):
    self.transform = transforms.Compose(transforms_)

    self.selected_attrs = attributes 
    self.files = sorted(glob.glob('%s/*.jpg' % root))
    self.files = self.files[:-2000] if mode == 'train' else self.files[-2000:]
    self.label_path = glob.glob('%s/*.txt' % root)[0]
    self.annotations = self.get_annotations()

  def get_annotations(self):
    """Wyodrębnia adnotacje dla CelebA"""
    annotations = {}
    lines = [line.rstrip() for line in open(self.label_path, 'r')]
    self.label_names = lines[1].split()
    for _, line in enumerate(lines[2:]):
      filename, *values = line.split()
      labels = []
      for attr in self.selected_attrs:
        idx = self.label_names.index(attr)
        labels.append(1 * (values[idx] == '1'))
      annotations[filename] = labels
    return annotations

  def __getitem__(self, index):
    filepath = self.files[index % len(self.files)]
    filename = filepath.split('/')[-1]
    img = self.transform(Image.open(filepath))
    label = self.annotations[filename]
    label = torch.FloatTensor(np.array(label))

    return img, label

  def __len__(self):
    return len(self.files)

### Budowa modelu

In [0]:
def weights_init_normal(m):
  classname = m.__class__.__name__
  if classname.find('Conv') != -1:
    torch.nn.init.normal_(m.weight.data, 0.0, 0.02)

In [0]:
class ResidualBlock(nn.Module):
  def __init__(self, in_features):
    super(ResidualBlock, self).__init__()

    conv_block = [
      nn.Conv2d(in_features, in_features, 3, stride=1, padding=1, bias=False),
      nn.InstanceNorm2d(in_features, affine=True, track_running_stats=True),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_features, in_features, 3, stride=1, padding=1, bias=False),
      nn.InstanceNorm2d(in_features, affine=True, track_running_stats=True),
    ]

    self.conv_block = nn.Sequential(*conv_block)

  def forward(self, x):
    return x + self.conv_block(x)

#### Generator

In [0]:
class GeneratorResNet(nn.Module):
  def __init__(self, img_shape=(3, 128, 128), res_blocks=9, c_dim=5):
    super(GeneratorResNet, self).__init__()
    channels, img_size, _ = img_shape

    # Wstępny blok konwolucyjny
    model = [
      nn.Conv2d(channels + c_dim, 64, 7, stride=1, padding=3, bias=False),
      nn.InstanceNorm2d(64, affine=True, track_running_stats=True),
      nn.ReLU(inplace=True)
    ]

    # Próbkowanie w dół
    curr_dim = 64
    for _ in range(2):
      model += [
        nn.Conv2d(curr_dim, curr_dim * 2, 4, stride=2, padding=1, bias=False),
        nn.InstanceNorm2d(curr_dim* 2, affine=True, track_running_stats=True),
        nn.ReLU(inplace=True),
      ]
      curr_dim *= 2

    # Pozostałe bloki
    for _ in range(res_blocks):
      model += [ResidualBlock(curr_dim)]

    # Próbkowanie w górę
    for _ in range(2):
      model += [
        nn.ConvTranspose2d(curr_dim, curr_dim // 2, 4, stride=2, padding=1, bias=False),
        nn.InstanceNorm2d(curr_dim // 2, affine=True, track_running_stats=True),
        nn.ReLU(inplace=True),
      ]
      curr_dim = curr_dim // 2
    
    # Warstwa wyjściowa
    model += [nn.Conv2d(curr_dim, channels, 7, stride=1, padding=3), nn.Tanh()]

    self.model = nn.Sequential(*model)

  def forward(self, x, c):
    c = c.view(c.size(0), c.size(1), 1, 1)
    c = c.repeat(1, 1, x.size(2), x.size(3))
    x = torch.cat((x, c), 1)
    return self.model(x)

#### Dyskryminator

In [0]:
class Discriminator(nn.Module):
  def __init__(self, img_shape=(3, 128, 128), c_dim=5, n_strided=6):
    super(Discriminator, self).__init__()
    channels, img_size, _ = img_shape

    def discriminator_block(in_filters, out_filters):
      """Zwraca próbkowane w dół warstwy każdego bloku dyskryminatora"""
      layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1), nn.LeakyReLU(0.01)]
      return layers

    layers = discriminator_block(channels, 64)
    curr_dim = 64
    for _ in range(n_strided - 1):
      layers.extend(discriminator_block(curr_dim, curr_dim * 2))
      curr_dim *= 2

    self.model = nn.Sequential(*layers)

    # Wyjście 1: PatchGAN
    self.out1 = nn.Conv2d(curr_dim, 1, 3, padding=1, bias=False)
    # Wyjście 2: Predykcja klasy
    kernel_size = img_size // 2 ** n_strided
    self.out2 = nn.Conv2d(curr_dim, c_dim, kernel_size, bias=False)

  def forward(self, img):
    feature_repr = self.model(img)
    out_adv = self.out1(feature_repr)
    out_cls = self.out2(feature_repr)
    return out_adv, out_cls.view(out_cls.size(0), -1)

## Inicjalizacja modelu

In [0]:
# Inicjalizacja generatora i dyskryminatora
generator = GeneratorResNet(img_shape=img_shape, res_blocks=residual_blocks, c_dim=c_dim)
discriminator = Discriminator(img_shape=img_shape, c_dim=c_dim)

In [0]:
if cuda:
  generator = generator.cuda()
  discriminator = discriminator.cuda()
  criterion_cycle.cuda()

if epoch != 0:
  # Załadowanie przetrenowanych modeli
  generator.load_state_dict(torch.load('saved_models/generator_%d.pth' % epoch))
  discriminator.load_state_dict(torch.load('saved_models/discriminator_%d.pth' % epoch))
else:
  generator.apply(weights_init_normal)
  discriminator.apply(weights_init_normal)

In [0]:
# Optymalizatory
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

## Załadowanie danych

In [14]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
# import zipfile
# zip_ref = zipfile.ZipFile('/content/drive/My Drive/img_align_celeba_4000.zip', 'r')
# zip_ref.extractall('/content/drive/My Drive')
# zip_ref.close()

### Zbiór treningowy

In [0]:
# Konfiguracja ładowania danych
train_transforms = [
  transforms.Resize(int(1.12 * img_height), Image.BICUBIC),
  transforms.RandomCrop(img_height),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

In [0]:
dataloader = DataLoader(
    CelebADataset(
        '/content/drive/My Drive/%s' % dataset_name, transforms_=train_transforms, mode='train', attributes=selected_attrs
    ),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)

### Zbiór walidacyjny

In [0]:
val_transforms = [
  transforms.Resize((img_height, img_width), Image.BICUBIC),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

In [0]:
val_dataloader = DataLoader(
    CelebADataset(
        '/content/drive/My Drive/%s' % dataset_name, transforms_=val_transforms, mode='val', attributes=selected_attrs
    ),
    batch_size=10,
    shuffle=True,
    num_workers=1,
)

## Funkcje

In [0]:
# Typ tensora
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [0]:
def compute_gradient_penalty(D, real_samples, fake_samples):
  # Losowa waga do interpolacji między prawdziwymi i fałszywymi próbkami
  alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
  # Wzięcie losowej interpolacji między prawdziwymi i fałszywymi próbkami
  interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
  d_interpolates, _ = D(interpolates)
  fake = Variable(Tensor(np.ones(d_interpolates.shape)), requires_grad=False)
  
  gradients = autograd.grad(
      outputs=d_interpolates,
      inputs=interpolates,
      grad_outputs=fake,
      create_graph=True,
      retain_graph=True,
      only_inputs=True,
  )[0]
  gradients = gradients.view(gradients.size(0), -1)
  gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
  return gradient_penalty

In [0]:
# Działa tylko dla selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']!

# (a, b)
# a      - indeks cechy w selected_attrs, do której się odwołujemy
# b = -1 - negacja 
# b = 0  - obojętność
# b = 1  - wystąpienie

label_changes = [
  ((0, 1), (1, 0), (2, 0)),  # Ustawienie czarnych włosów
  ((0, 0), (1, 1), (2, 0)),  # Ustawienie blond włosów
  ((0, 0), (1, 0), (2, 1)),  # Ustawienie brązowych włosów
  ((3, -1),),  # Zmiana płci
  ((4, -1),),  # Zmiana wieku
]

In [0]:
def sample_images(batches_done):
  """Zapis wygenerowanej próbki"""
  val_imgs, val_labels = next(iter(val_dataloader))
  val_imgs = Variable(val_imgs.type(Tensor))
  val_labels = Variable(val_labels.type(Tensor))
  img_samples = None
  for i in range(10):
    img, label = val_imgs[i], val_labels[i]
    
    imgs = img.repeat(c_dim, 1, 1, 1)
    labels = label.repeat(c_dim, 1)

    for sample_i, changes in enumerate(label_changes):
      for col, val in changes:
        labels[sample_i, col] = 1 - labels[sample_i, col] if val == -1 else val
    
    gen_imgs = generator(imgs, labels)
    gen_imgs = torch.cat([x for x in gen_imgs.data], -1)
    img_sample = torch.cat((img.data, gen_imgs), -1)

    img_samples = img_sample if img_samples is None else torch.cat((img_samples, img_sample), -2)

  save_image(img_samples.view(1, *img_samples.shape), 'images/%s.png' % batches_done, normalize=True)

## Trening modelu

In [23]:
saved_samples = []
start_time = time.time()
for epoch in range(epoch, n_epochs):
  for i, (imgs, labels) in enumerate(dataloader):

    # Wejścia modelu
    imgs = Variable(imgs.type(Tensor))
    labels = Variable(labels.type(Tensor))

    # Przykładowe etykiety wejść generatora
    sampled_c = Variable(Tensor(np.random.randint(0, 2, (imgs.size(0), c_dim))))
    # Generacja fałszywego wsadu obrazów
    fake_imgs = generator(imgs, sampled_c)

    # Trening dyskryminatora

    optimizer_D.zero_grad()

    # Prawdziwe obrazy
    real_validity, pred_cls = discriminator(imgs)
    # Fałszywe obrazy
    fake_validity, _ = discriminator(fake_imgs.detach())

    gradient_penalty = compute_gradient_penalty(discriminator, imgs.data, fake_imgs.data)
    # Strata antagonistyczna
    loss_D_adv = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
    # Strata klasyfikacji
    loss_D_cls = criterion_cls(pred_cls, labels)
    # Strata całkowita
    loss_D = loss_D_adv + lambda_cls * loss_D_cls

    loss_D.backward()
    optimizer_D.step()

    optimizer_G.zero_grad()

    if i % n_critic == 0:

      # Trening generatora

      # Tłumaczenie i rekonstrukcja obrazu
      gen_imgs = generator(imgs, sampled_c)
      recov_imgs = generator(gen_imgs, labels)
      # Dyskryminator rozważa przetłumaczony obraz
      fake_validity, pred_cls = discriminator(gen_imgs)
      # Strata antagonistyczna
      loss_G_adv = -torch.mean(fake_validity)
      # Strata klasyfikacji
      loss_G_cls = criterion_cls(pred_cls, sampled_c)
      # Strata rekonstrukcji
      loss_G_rec = criterion_cycle(recov_imgs, imgs)
      # Strata całkowita
      loss_G = loss_G_adv + lambda_cls * loss_G_cls + lambda_rec * loss_G_rec
      
      loss_G.backward()
      optimizer_G.step()

      # Logi postępu

      batches_done = epoch * len(dataloader) + i
      batches_left = n_epochs * len(dataloader) - batches_done
      time_left = datetime.timedelta(seconds=batches_left * (time.time() - start_time) / (batches_done + 1))

      # Wypisz log
      sys.stdout.write(
          '\r[Epoch %d/%d] [Batch %d/%d] [D adv: %f, aux: %f] [G loss: %f, adv: %f, aux: %f, cycle: %f] ETA: %s'
          % (
             epoch,
             n_epochs,
             i,
             len(dataloader),
             loss_D_adv.item(),
             loss_D_cls.item(),
             loss_G.item(),
             loss_G_adv.item(),
             loss_G_cls.item(),
             loss_G_rec.item(),
             time_left,
          )
      )

      if batches_done % sample_interval == 0:
        sample_images(batches_done)

  if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
    # Zapis punktów kontrolnych modelu
    torch.save(generator.state_dict(), '/content/saved_models/generator_%d.pth' % epoch)
    torch.save(discriminator.state_dict(), '/content/saved_models/discriminator_%d.pth' % epoch)



[Epoch 199/200] [Batch 125/126] [D adv: -3.337196, aux: 4.035940] [G loss: 9.611263, adv: -0.089546, aux: 8.015087, cycle: 0.168572] ETA: 0:00:00.533382

In [0]:
# torch.save(generator.state_dict(), '/content/saved_models/generator_%d.pth' % epoch)
# torch.save(discriminator.state_dict(), '/content/drive/My Drive/discriminator_%d.pth' % epoch)

## Pobranie przetrenowanych modeli i obrazów

In [0]:
from google.colab import files
files.download('/content/saved_models/discriminator_199.pth')

In [0]:
files.download('/content/saved_models/generator_199.pth')

In [34]:
!zip -r /content/images.zip /content/images/
files.download('/content/images.zip')

  adding: content/images/ (stored 0%)
  adding: content/images/19600.png (deflated 0%)
  adding: content/images/12000.png (deflated 0%)
  adding: content/images/20800.png (deflated 0%)
  adding: content/images/6400.png (deflated 0%)
  adding: content/images/3200.png (deflated 0%)
  adding: content/images/10800.png (deflated 0%)
  adding: content/images/7600.png (deflated 0%)
  adding: content/images/15200.png (deflated 0%)
  adding: content/images/22800.png (deflated 0%)
  adding: content/images/16400.png (deflated 0%)
  adding: content/images/0.png (deflated 0%)
  adding: content/images/24000.png (deflated 0%)
  adding: content/images/2000.png (deflated 0%)
