In [1]:
import functools

#import imlib as im
import numpy as np
#import pylib as py
import tensorboardX
import torch
#import torchlib
from torch import nn
from tqdm.auto import trange, tqdm
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms


## Data processing

In [2]:
# Run on the first CUDA GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [3]:
import requests

# (remote url, local file name) pairs fetched before the dataset helper can be imported:
# the `celeba.py` module that defines CelebADataset, and the attribute list it reads.
_downloads = [
    ('https://raw.githubusercontent.com/anastasia-yaschenko/Generative-models/main/celeba%20(1).py',
     'celeba.py'),
    ('https://raw.githubusercontent.com/vpozdnyakov/DeepGenerativeModels/spring-2022/data/celeba/list_attr_celeba.txt',
     'list_attr_celeba.txt'),
]

for url, _local_name in _downloads:
    # a timeout keeps the cell from hanging forever on a stalled connection
    _response = requests.get(url, timeout=60)
    # fail loudly on 4xx/5xx instead of saving an error page as `celeba.py`
    _response.raise_for_status()
    # the context manager closes the file handle deterministically
    with open(_local_name, 'wb') as _fh:
        _fh.write(_response.content)

from celeba import CelebADataset

In [4]:
crop_size = 108

# Top-left corner of a centered crop_size x crop_size window; the constants
# assume 218 x 178 (height x width) input images.
offset_height = (218 - crop_size) // 2
offset_width = (178 - crop_size) // 2


def crop(x):
    """Cut the centered crop_size x crop_size window out of a (C, H, W) tensor."""
    rows = slice(offset_height, offset_height + crop_size)
    cols = slice(offset_width, offset_width + crop_size)
    return x[:, rows, cols]


transform = transforms.Compose([
    # image -> float tensor so that `crop` can slice it
    transforms.ToTensor(),
    transforms.Lambda(crop),
    # back to a PIL image, which is what Resize is applied to here
    transforms.ToPILImage(),
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] output to [-1, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [5]:
# NOTE(review): crop=False presumably turns off the dataset's own cropping because
# `transform` already crops -- confirm against celeba.py.
dataset = CelebADataset(
    attr_file_path='list_attr_celeba.txt',
    transform=transform,
    crop=False,
)

# drop_last=True discards the final incomplete batch, so every batch holds 64 items
data_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
)

# (height, width, channels); shape[-1] is used as the channel count of G and D
shape = (64, 64, 3)

## WGAN architecture

In [6]:
class Identity(torch.nn.Module):
    """Pass-through layer used where a normalization layer is optional.

    The constructor accepts and ignores any arguments so that it can be
    instantiated with the same call as a real norm layer.
    """

    def __init__(self, *args, **keyword_args):
        super(Identity, self).__init__()

    def forward(self, x):
        # hand the very same tensor back, untouched
        return x

def _get_norm_layer_2d(norm):
    if norm == 'none':
        return Identity
    elif norm == 'batch_norm':
        return nn.BatchNorm2d
    elif norm == 'instance_norm':
        return functools.partial(nn.InstanceNorm2d, affine=True)
    elif norm == 'layer_norm':
        return lambda num_features: nn.GroupNorm(1, num_features)
    else:
        raise NotImplementedError


In [7]:
class ConvGenerator(nn.Module):
    """Transposed-convolution generator.

    A first block turns a 1x1 latent map into 4x4, followed by
    `n_upsamplings` stride-2 transposed convolutions (the last one without
    norm, followed by Tanh), each doubling the spatial size.

    Parameters
    ----------
    input_dim : int
        Number of channels of the latent input.
    output_channels : int
        Number of channels of the generated image.
    dim : int
        Base channel width; hidden widths are multiples of it, capped at dim * 8.
    n_upsamplings : int
        Number of stride-2 upsampling layers.
    norm : str
        Name understood by `_get_norm_layer_2d`.
    """

    def __init__(self,
                 input_dim=128,
                 output_channels=3,
                 dim=64,
                 n_upsamplings=4,
                 norm='batch_norm'):
        super().__init__()

        norm_factory = _get_norm_layer_2d(norm)
        # the convolution only carries a bias when no norm layer follows it
        conv_has_bias = norm_factory == Identity
        width_cap = dim * 8

        def up_block(c_in, c_out, kernel_size=4, stride=2, padding=1):
            """ConvTranspose2d -> norm -> ReLU."""
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, kernel_size,
                                   stride=stride, padding=padding, bias=conv_has_bias),
                norm_factory(c_out),
                nn.ReLU(),
            )

        # channel widths of the hidden blocks, widest first
        widths = [min(dim * 2 ** (n_upsamplings - 1), width_cap)]
        widths += [min(dim * 2 ** (n_upsamplings - 2 - i), width_cap)
                   for i in range(n_upsamplings - 1)]

        # 1: 1x1 -> 4x4
        blocks = [up_block(input_dim, widths[0], kernel_size=4, stride=1, padding=0)]

        # 2: upsamplings, 4x4 -> 8x8 -> 16x16 -> ...
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            blocks.append(up_block(c_in, c_out, kernel_size=4, stride=2, padding=1))

        # 3: last upsampling straight to image channels, squashed into [-1, 1]
        blocks.append(nn.ConvTranspose2d(widths[-1], output_channels,
                                         kernel_size=4, stride=2, padding=1))
        blocks.append(nn.Tanh())

        self.net = nn.Sequential(*blocks)

    def forward(self, z):
        return self.net(z)

In [8]:
class ConvDiscriminator(nn.Module):
    """Strided-convolution discriminator producing one logit map per image.

    A plain stride-2 convolution with LeakyReLU comes first, then
    `n_downsamplings - 1` conv/norm/LeakyReLU blocks (each halving the spatial
    size), and finally a 4x4 stride-1 convolution down to a single channel.

    Parameters
    ----------
    input_channels : int
        Number of channels of the input image.
    dim : int
        Base channel width; hidden widths double per block, capped at dim * 8.
    n_downsamplings : int
        Number of stride-2 layers (the un-normalized first one included).
    norm : str
        Name understood by `_get_norm_layer_2d`.
    """

    def __init__(self,
                 input_channels=3,
                 dim=64,
                 n_downsamplings=4,
                 norm='batch_norm'):
        super().__init__()

        norm_factory = _get_norm_layer_2d(norm)
        # the convolution only carries a bias when no norm layer follows it
        conv_has_bias = norm_factory == Identity
        width_cap = dim * 8

        def down_block(c_in, c_out, kernel_size=4, stride=2, padding=1):
            """Conv2d -> norm -> LeakyReLU(0.2)."""
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size,
                          stride=stride, padding=padding, bias=conv_has_bias),
                norm_factory(c_out),
                nn.LeakyReLU(0.2),
            )

        # channel widths after every stride-2 layer, narrowest first
        widths = [dim]
        widths += [min(dim * 2 ** (i + 1), width_cap) for i in range(n_downsamplings - 1)]

        # 1: downsamplings, ... -> 16x16 -> 8x8 -> 4x4
        # (the first layer is kept free of normalization)
        stages = [
            nn.Conv2d(input_channels, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(down_block(c_in, c_out, kernel_size=4, stride=2, padding=1))

        # 2: logit
        stages.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=0))

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)


## Training

In [9]:
def train_G(batch_size=64, z_dim=128):
    """Run one generator update with the WGAN generator loss.

    Works on the notebook-level `G`, `D`, `G_optimizer` and `device`.

    Parameters
    ----------
    batch_size : int
        Number of latent vectors drawn for this step.
    z_dim : int
        Number of channels of each latent vector (must match G's input_dim).

    Returns
    -------
    dict
        ``{'g_loss': scalar tensor}``; the tensor is detached so that keeping
        it around for logging does not keep the autograd graph alive.
    """
    G.train()
    D.train()

    # draw the noise directly on the target device instead of copying it over
    z = torch.randn(batch_size, z_dim, 1, 1, device=device)
    x_fake = G(z)

    # the generator is trained to raise the critic's score on its samples
    x_fake_d_logit = D(x_fake)
    G_loss = -torch.mean(x_fake_d_logit)

    G.zero_grad()
    # backward also writes gradients into D's parameters; only G is stepped
    # here, and train_D zeroes D's gradients before its own backward pass
    G_loss.backward()
    G_optimizer.step()

    return {'g_loss': G_loss.detach()}


In [10]:
def _sample_line(real, fake):
    shape = [real.size(0)] + [1] * (real.dim() - 1)
    alpha = torch.rand(shape, device=real.device)
    sample = real + alpha * (fake - real)
    return sample


def _sample_DRAGAN(real, fake):  # fake is useless
    """Sample between `real` and a noisy copy of `real`.

    The passed-in `fake` is ignored: the far end of the segment is `real`
    plus uniform noise scaled by half the standard deviation of `real`.
    """
    noise = torch.rand_like(real)
    perturbed = real + 0.5 * real.std() * noise
    return _sample_line(real, perturbed)
    
def _norm(x):
    norm = x.view(x.size(0), -1).norm(p=2, dim=1)
    return norm


def _one_mean_gp(grad):
    """Mean squared deviation of the per-sample gradient norm from 1."""
    deviation = _norm(grad) - 1
    return deviation.pow(2).mean()


def _zero_mean_gp(grad):
    """Mean of the squared per-sample gradient norm (pulls the norm towards 0)."""
    return _norm(grad).pow(2).mean()


def _lipschitz_penalty(grad):
    """One-sided penalty: only the part of the gradient norm above 1 is squared."""
    excess = _norm(grad) - 1
    clipped = torch.max(torch.zeros_like(excess), excess)
    return clipped.pow(2).mean()

def gradient_penalty(f, real, fake, gp_mode, sample_mode):
    """Compute a gradient penalty of `f` at points derived from `real` / `fake`.

    Parameters
    ----------
    f : callable
        Function (e.g. the discriminator) whose input gradient is penalized.
    real, fake : torch.Tensor
        Batches used to pick the evaluation points.
    gp_mode : str
        'none' (returns a zero scalar), '1-gp', '0-gp' or 'lp'.
    sample_mode : str
        'line', 'real', 'fake' or 'dragan'; ignored when gp_mode is 'none'.

    Returns
    -------
    torch.Tensor
        Scalar penalty; it is differentiable w.r.t. the parameters of `f`
        because the input gradient is taken with ``create_graph=True``.
    """
    samplers = {
        'line': _sample_line,
        'real': lambda real, fake: real,
        'fake': lambda real, fake: fake,
        'dragan': _sample_DRAGAN,
    }
    penalties = {
        '1-gp': _one_mean_gp,
        '0-gp': _zero_mean_gp,
        'lp': _lipschitz_penalty,
    }

    if gp_mode == 'none':
        return torch.tensor(0, dtype=real.dtype, device=real.device)

    # detach so the evaluation point is a fresh leaf that gradients can be taken at
    x = samplers[sample_mode](real, fake).detach()
    x.requires_grad_(True)
    pred = f(x)
    (grad,) = torch.autograd.grad(pred, x,
                                  grad_outputs=torch.ones_like(pred),
                                  create_graph=True)
    return penalties[gp_mode](grad)

In [11]:
def train_D(x_real, z_dim=128, gp_weight=10.0):
    """Run one critic update with the WGAN loss plus a '1-gp' gradient penalty.

    Works on the notebook-level `G`, `D`, `D_optimizer`, `d_loss_fn` and `device`.

    Parameters
    ----------
    x_real : torch.Tensor
        Batch of real images, already on `device`.
    z_dim : int
        Number of channels of each latent vector (must match G's input_dim).
    gp_weight : float
        Multiplier of the gradient penalty in the critic loss.

    Returns
    -------
    dict
        ``{'d_loss': ..., 'gp': ...}`` as detached scalar tensors.  'd_loss' is
        the sum of the two terms returned by `d_loss_fn` (reported only); the
        optimized loss is the explicit WGAN expression below.
    """
    G.train()
    D.train()

    # as many fakes as reals: the line sampling in the gradient penalty mixes
    # them element-wise, so a fixed 64 would break on any other batch size
    z = torch.randn(x_real.size(0), z_dim, 1, 1, device=device)
    # detach: the generator must not receive gradients from the critic step
    x_fake = G(z).detach()

    x_real_d_logit = D(x_real)
    x_fake_d_logit = D(x_fake)

    x_real_d_loss, x_fake_d_loss = d_loss_fn(x_real_d_logit, x_fake_d_logit)
    gp = gradient_penalty(D, x_real, x_fake, gp_mode='1-gp', sample_mode='line')

    D_loss = -torch.mean(x_real_d_logit) + torch.mean(x_fake_d_logit) + gp * gp_weight

    D.zero_grad()
    D_loss.backward()
    D_optimizer.step()

    return {'d_loss': (x_real_d_loss + x_fake_d_loss).detach(), 'gp': gp.detach()}


def sample(z):
    """Generate images from `z` with the notebook-level G, without tracking gradients.

    Switches G to eval mode and leaves it there.
    """
    G.eval()
    with torch.no_grad():
        return G(z)

In [12]:
def get_gan_losses_fn():
    """Build the GAN loss pair based on binary cross-entropy over raw logits.

    Returns
    -------
    (d_loss_fn, g_loss_fn)
        ``d_loss_fn(r_logit, f_logit)`` -> (real loss, fake loss) with targets
        1 and 0; ``g_loss_fn(f_logit)`` -> loss of fakes against target 1.
    """
    criterion = torch.nn.BCEWithLogitsLoss()

    def g_loss_fn(f_logit):
        wanted = torch.ones_like(f_logit)
        return criterion(f_logit, wanted)

    def d_loss_fn(r_logit, f_logit):
        real_term = criterion(r_logit, torch.ones_like(r_logit))
        fake_term = criterion(f_logit, torch.zeros_like(f_logit))
        return real_term, fake_term

    return d_loss_fn, g_loss_fn


def get_hinge_v1_losses_fn():
    """Build the hinge loss pair in which the generator also uses a hinge.

    Returns
    -------
    (d_loss_fn, g_loss_fn)
        ``d_loss_fn`` -> (mean max(1 - r, 0), mean max(1 + f, 0));
        ``g_loss_fn`` -> mean max(1 - f, 0).
    """
    def _mean_positive_part(term):
        # mean over all elements of max(term, 0)
        return torch.max(term, torch.zeros_like(term)).mean()

    def g_loss_fn(f_logit):
        return _mean_positive_part(1 - f_logit)

    def d_loss_fn(r_logit, f_logit):
        return _mean_positive_part(1 - r_logit), _mean_positive_part(1 + f_logit)

    return d_loss_fn, g_loss_fn


def get_hinge_v2_losses_fn():
    """Build the hinge loss pair in which the generator uses the plain mean logit.

    Returns
    -------
    (d_loss_fn, g_loss_fn)
        ``d_loss_fn`` -> (mean max(1 - r, 0), mean max(1 + f, 0));
        ``g_loss_fn`` -> -mean(f).
    """
    def _mean_positive_part(term):
        # mean over all elements of max(term, 0)
        return torch.max(term, torch.zeros_like(term)).mean()

    def g_loss_fn(f_logit):
        return -f_logit.mean()

    def d_loss_fn(r_logit, f_logit):
        return _mean_positive_part(1 - r_logit), _mean_positive_part(1 + f_logit)

    return d_loss_fn, g_loss_fn


def get_lsgan_losses_fn():
    """Build the least-squares GAN loss pair (mean squared error on logits).

    Returns
    -------
    (d_loss_fn, g_loss_fn)
        ``d_loss_fn`` -> (MSE of r against 1, MSE of f against 0);
        ``g_loss_fn`` -> MSE of f against 1.
    """
    criterion = torch.nn.MSELoss()

    def g_loss_fn(f_logit):
        wanted = torch.ones_like(f_logit)
        return criterion(f_logit, wanted)

    def d_loss_fn(r_logit, f_logit):
        real_term = criterion(r_logit, torch.ones_like(r_logit))
        fake_term = criterion(f_logit, torch.zeros_like(f_logit))
        return real_term, fake_term

    return d_loss_fn, g_loss_fn


def get_wgan_losses_fn():
    """Build the Wasserstein loss pair (plain means of the critic output).

    Returns
    -------
    (d_loss_fn, g_loss_fn)
        ``d_loss_fn`` -> (-mean(r), mean(f)); ``g_loss_fn`` -> -mean(f).
    """
    def g_loss_fn(f_logit):
        return -f_logit.mean()

    def d_loss_fn(r_logit, f_logit):
        real_term = -r_logit.mean()
        fake_term = f_logit.mean()
        return real_term, fake_term

    return d_loss_fn, g_loss_fn


def get_adversarial_losses_fn(mode):
    """Return the ``(d_loss_fn, g_loss_fn)`` pair for the named adversarial loss.

    Parameters
    ----------
    mode : str
        One of 'gan', 'hinge_v1', 'hinge_v2', 'lsgan', 'wgan'.

    Raises
    ------
    NotImplementedError
        If `mode` is not a supported name.  Without this the function would
        return None and the caller's tuple unpacking would fail with an
        unrelated "cannot unpack" error.
    """
    if mode == 'gan':
        return get_gan_losses_fn()
    elif mode == 'hinge_v1':
        return get_hinge_v1_losses_fn()
    elif mode == 'hinge_v2':
        return get_hinge_v2_losses_fn()
    elif mode == 'lsgan':
        return get_lsgan_losses_fn()
    elif mode == 'wgan':
        return get_wgan_losses_fn()
    else:
        raise NotImplementedError(
            "unsupported adversarial loss mode %r; expected one of "
            "'gan', 'hinge_v1', 'hinge_v2', 'lsgan', 'wgan'" % (mode,)
        )

In [13]:
# hyper-parameters
n_G_upsamplings = 4
n_D_downsamplings = 4
lr = 0.0002

# networks: the generator is built first, then the discriminator; both take
# their image channel count from the last entry of `shape`
G = ConvGenerator(input_dim=128,
                  output_channels=shape[-1],
                  n_upsamplings=n_G_upsamplings)
G = G.to(device)
D = ConvDiscriminator(input_channels=shape[-1],
                      n_downsamplings=n_D_downsamplings,
                      norm='layer_norm')
D = D.to(device)
print(G)
print(D)

# adversarial_loss_functions
d_loss_fn, g_loss_fn = get_adversarial_losses_fn('wgan')

# optimizer: one Adam instance per network, same learning rate and betas
G_optimizer = torch.optim.Adam(params=G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = torch.optim.Adam(params=D.parameters(), lr=lr, betas=(0.5, 0.999))

ConvGenerator(
  (net): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(128, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): ConvTranspose2d(64, 3, ke

In [14]:
import skimage.io as iio

def _check(images, dtypes, min_value=-np.inf, max_value=np.inf):
    # check type
    assert isinstance(images, np.ndarray), '`images` should be np.ndarray!'

    # check dtype
    dtypes = dtypes if isinstance(dtypes, (list, tuple)) else [dtypes]
    assert images.dtype in dtypes, 'dtype of `images` shoud be one of %s!' % dtypes

    # check nan and inf
    assert np.all(np.isfinite(images)), '`images` contains NaN or Inf!'

    # check value
    if min_value not in [None, -np.inf]:
        l = '[' + str(min_value)
    else:
        l = '(-inf'
        min_value = -np.inf
    if max_value not in [None, np.inf]:
        r = str(max_value) + ']'
    else:
        r = 'inf)'
        max_value = np.inf
    assert np.min(images) >= min_value and np.max(images) <= max_value, \
        '`images` should be in the range of %s!' % (l + ',' + r)


def to_range(images, min_value=0.0, max_value=1.0, dtype=None):
    """Transform images from [-1.0, 1.0] to [min_value, max_value] of dtype.

    Parameters
    ----------
    images : np.ndarray
        float32/float64 array with values in [-1.0, 1.0] (validated by `_check`).
    min_value, max_value : number
        Bounds of the target range.
    dtype : dtype or None
        Output dtype; None keeps the dtype of `images`.  The cast uses
        ``astype``, so integer targets truncate rather than round.

    Returns
    -------
    np.ndarray
        Linearly rescaled copy of `images`.
    """
    _check(images, [np.float32, np.float64], -1.0, 1.0)
    # compare against None explicitly: `np.dtype` instances are falsy, so a
    # plain truth test would silently ignore e.g. dtype=np.dtype('uint8')
    out_dtype = images.dtype if dtype is None else dtype
    return ((images + 1.) / 2. * (max_value - min_value) + min_value).astype(out_dtype)

def im2uint(images):
    """Transform images from [-1.0, 1.0] to uint8 (target range 0..255)."""
    return to_range(images, min_value=0, max_value=255, dtype=np.uint8)

def imwrite(image, path, quality=95, **plugin_args):
    """Save a [-1.0, 1.0] image.

    The image is converted to uint8 first; `quality` and any extra keyword
    arguments are forwarded to ``skimage.io.imsave``.
    """
    as_uint8 = im2uint(image)
    iio.imsave(path, as_uint8, quality=quality, **plugin_args)

def immerge(images, n_rows=None, n_cols=None, padding=0, pad_value=0):
    """Tile a batch of images into one grid image, filled row by row.

    Parameters
    ----------
    images : numpy.array or object which can be converted to numpy.array
        Images in shape of N * H * W(* C=1 or 3).
    n_rows, n_cols : int or None
        Requested grid size.  `n_rows` takes precedence; the other side is
        derived from N.  With neither given, the grid has int(sqrt(N)) rows.
    padding : int
        Gap in pixels between neighbouring tiles.
    pad_value : scalar
        Value of the gaps and of unused grid cells.

    Returns
    -------
    numpy.ndarray
        Grid of shape (n_rows * h + gaps) * (n_cols * w + gaps)(* C).
    """
    images = np.array(images)
    n = images.shape[0]

    def _needed(count):
        # tiles required along the other axis when one axis holds `count` tiles
        return int(n - 0.5) // count + 1

    if n_rows:
        n_rows = max(min(n_rows, n), 1)
        n_cols = _needed(n_rows)
    elif n_cols:
        n_cols = max(min(n_cols, n), 1)
        n_rows = _needed(n_cols)
    else:
        n_rows = int(n ** 0.5)
        n_cols = _needed(n_rows)

    h = images.shape[1]
    w = images.shape[2]
    canvas_shape = (h * n_rows + padding * (n_rows - 1),
                    w * n_cols + padding * (n_cols - 1))
    if images.ndim == 4:
        # keep the channel axis of colour images
        canvas_shape = canvas_shape + (images.shape[3],)
    canvas = np.full(canvas_shape, pad_value, dtype=images.dtype)

    step_y = h + padding
    step_x = w + padding
    for idx, image in enumerate(images):
        row, col = divmod(idx, n_cols)
        top = row * step_y
        left = col * step_x
        canvas[top:top + h, left:left + w, ...] = image

    return canvas

In [15]:
import os
#import shutil

def save_checkpoint(obj, save_path, is_best=False, max_keep=None):
    """Save `obj` with ``torch.save`` and maintain the directory's checkpoint list.

    Next to the checkpoints a text file named 'checkpoint' lists their file
    names, newest first, one per line.

    Parameters
    ----------
    obj : object
        Anything ``torch.save`` can serialize.
    save_path : str
        Target file; its directory is created when missing.
    is_best : bool
        When true, the saved file is also copied to 'best_model.ckpt' in the
        same directory (the file `load_checkpoint(..., load_best=True)` reads).
    max_keep : int or None
        Keep only the `max_keep` newest listed checkpoints and delete the
        others from disk; None keeps everything.
    """
    import shutil  # function-scope import: only needed for the `is_best` copy

    save_dir = os.path.dirname(save_path)
    if save_dir:
        # torch.save does not create missing directories
        os.makedirs(save_dir, exist_ok=True)

    # save checkpoint
    torch.save(obj, save_path)

    # copy best
    if is_best:
        shutil.copyfile(save_path, os.path.join(save_dir, 'best_model.ckpt'))

    # deal with max_keep
    list_path = os.path.join(save_dir, 'checkpoint')
    ckpt_name = os.path.basename(save_path)

    older_names = []
    if os.path.exists(list_path):
        with open(list_path) as f:
            older_names = [line.rstrip('\n') for line in f]
    # Newest first.  Blank lines are dropped, and so is any older entry with the
    # same name: otherwise re-saving a file would list it twice and the pruning
    # below would delete the checkpoint that was just written.
    ckpt_names = [ckpt_name] + [name for name in older_names
                                if name and name != ckpt_name]

    if max_keep is not None:
        for stale_name in ckpt_names[max_keep:]:
            stale_path = os.path.join(save_dir, stale_name)
            if os.path.exists(stale_path):
                os.remove(stale_path)
        ckpt_names = ckpt_names[:max_keep]

    with open(list_path, 'w') as f:
        f.writelines(name + '\n' for name in ckpt_names)
    
def load_checkpoint(ckpt_dir_or_file, map_location=None, load_best=False):
    """Load a checkpoint from a file, or the newest/best one from a directory.

    Parameters
    ----------
    ckpt_dir_or_file : str
        Either a checkpoint file, or a directory.  For a directory the file
        name is taken from the first line of its 'checkpoint' list file,
        unless `load_best` is set.
    map_location : optional
        Forwarded to ``torch.load``.
    load_best : bool
        For a directory, load 'best_model.ckpt' instead of the newest entry.

    Returns
    -------
    object
        Whatever was stored in the checkpoint file.

    Raises
    ------
    FileNotFoundError
        If the list file is missing or names no checkpoint (other errors of
        ``open`` / ``torch.load`` propagate unchanged).
    """
    if os.path.isdir(ckpt_dir_or_file):
        if load_best:
            ckpt_path = os.path.join(ckpt_dir_or_file, 'best_model.ckpt')
        else:
            list_path = os.path.join(ckpt_dir_or_file, 'checkpoint')
            with open(list_path) as f:
                # strip only the line break; slicing off the last character
                # would truncate the name when the line has no trailing newline
                newest_name = f.readline().rstrip('\n')
            if not newest_name:
                raise FileNotFoundError('no checkpoint is listed in %s' % list_path)
            ckpt_path = os.path.join(ckpt_dir_or_file, newest_name)
    else:
        ckpt_path = ckpt_dir_or_file
    # NOTE(security): torch.load relies on pickle -- only load trusted checkpoints
    ckpt = torch.load(ckpt_path, map_location=map_location)
    return ckpt

    # copy best
    #if is_best:
     #   shutil.copyfile(save_path, os.path.join(save_dir, 'best_model.ckpt'))

In [17]:
# ---- locations ---------------------------------------------------------------
output_dir = 'output/wgan'
ckpt_dir1 = 'checkpoints/all_checkpoints'  # long-term archive, written every 5th epoch
ckpt_dir = 'checkpoints/checkpoints'       # newest checkpoint only; used to resume
sample_dir = 'output/samples_training'     # grids generated from the fixed noise `z`

# torch.save and imwrite do not create missing directories themselves
for _directory in (output_dir, ckpt_dir1, ckpt_dir, sample_dir):
    os.makedirs(_directory, exist_ok=True)

# ---- schedule ----------------------------------------------------------------
# Total number of epochs.  The loop used to be hard-coded to range(75, 81), which
# only made sense when resuming from an epoch-75 checkpoint; resuming from the
# checkpoint's own epoch counter gives the same run in that case and also works
# from scratch.
n_epochs = 81
n_critic = 5        # discriminator updates per generator update
sample_every = 100  # generator updates between two saved sample grids

# ---- resume from the newest checkpoint, if there is one -----------------------
try:
    ckpt = load_checkpoint(ckpt_dir, map_location=device)
except FileNotFoundError:
    # Nothing to resume from.  Any other failure (corrupt file, missing keys,
    # mismatching architecture) is left to surface instead of silently
    # restarting the training from zero.
    ep, it_d, it_g = 0, 0, 0
else:
    ep, it_d, it_g = ckpt['ep'], ckpt['it_d'], ckpt['it_g']
    print(ep)
    D.load_state_dict(ckpt['D'])
    G.load_state_dict(ckpt['G'])
    D_optimizer.load_state_dict(ckpt['D_optimizer'])
    G_optimizer.load_state_dict(ckpt['G_optimizer'])

# ---- main loop ----------------------------------------------------------------
z = torch.randn(100, 128, 1, 1, device=device)  # a fixed noise for sampling

for ep_ in range(ep, n_epochs):
    print('Epoch {}'.format(ep_))
    ep = ep_ + 1  # number of epochs completed once this iteration is done

    # train for an epoch
    for x_real, _ in tqdm(data_loader):
        x_real = x_real.to(device)

        D_loss_dict = train_D(x_real)
        it_d += 1

        if it_d % n_critic == 0:
            G_loss_dict = train_G()
            it_g += 1

            # sample -- only right after a generator update; checking on every
            # discriminator step rewrote the same 'iter-<it_g>' file n_critic times
            if it_g % sample_every == 0:
                x_fake = sample(z)
                # NCHW -> NHWC for the image helpers
                x_fake = np.transpose(x_fake.cpu().numpy(), (0, 2, 3, 1))
                img = immerge(x_fake, n_rows=10).squeeze()
                imwrite(img, os.path.join(sample_dir, 'iter-%09d.jpg' % it_g))

    # save checkpoint
    state = {'ep': ep, 'it_d': it_d, 'it_g': it_g,
             'D': D.state_dict(),
             'G': G.state_dict(),
             'D_optimizer': D_optimizer.state_dict(),
             'G_optimizer': G_optimizer.state_dict()}
    ckpt_name = 'Epoch_(%d).ckpt' % ep

    if ep % 5 == 0:
        # archive copy, never pruned
        save_checkpoint(state, os.path.join(ckpt_dir1, ckpt_name))

    # rolling copy for resuming; older ones are removed
    save_checkpoint(state, os.path.join(ckpt_dir, ckpt_name), max_keep=1)

75
Epoch 75


100%|██████████| 3165/3165 [10:38<00:00,  4.96it/s]


Epoch 76


100%|██████████| 3165/3165 [10:11<00:00,  5.17it/s]


Epoch 77


100%|██████████| 3165/3165 [10:12<00:00,  5.17it/s]


Epoch 78


100%|██████████| 3165/3165 [10:21<00:00,  5.09it/s]


Epoch 79


100%|██████████| 3165/3165 [10:18<00:00,  5.12it/s]


Epoch 80


100%|██████████| 3165/3165 [10:22<00:00,  5.08it/s]
