From 2f8efad989d76596032ca96364344ae21a4f997c Mon Sep 17 00:00:00 2001 From: STomoya Date: Tue, 29 Sep 2020 12:02:19 +0900 Subject: [PATCH 1/8] [UPDATE] move diffaug to gan_utils --- implementations/{general => gan_utils}/DiffAugment_pytorch.py | 0 implementations/general/__init__.py | 2 -- 2 files changed, 2 deletions(-) rename implementations/{general => gan_utils}/DiffAugment_pytorch.py (100%) diff --git a/implementations/general/DiffAugment_pytorch.py b/implementations/gan_utils/DiffAugment_pytorch.py similarity index 100% rename from implementations/general/DiffAugment_pytorch.py rename to implementations/gan_utils/DiffAugment_pytorch.py diff --git a/implementations/general/__init__.py b/implementations/general/__init__.py index 62b8fca..6b83068 100644 --- a/implementations/general/__init__.py +++ b/implementations/general/__init__.py @@ -3,8 +3,6 @@ from .danbooru import DanbooruDataset, GeneratePairImageDanbooruDataset from .danbooru_portrait import DanbooruPortraitDataset -from .DiffAugment_pytorch import DiffAugment - from .fp16 import network_to_half from torch.utils.data import DataLoader From 2005b4e0c1964859ac0bdd130f6bf7fff33719e7 Mon Sep 17 00:00:00 2001 From: STomoya Date: Tue, 29 Sep 2020 12:05:33 +0900 Subject: [PATCH 2/8] [UPDATE] adjust diffaug aimport --- implementations/HoloGAN/utils.py | 3 ++- implementations/StyleGAN2/utils.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/implementations/HoloGAN/utils.py b/implementations/HoloGAN/utils.py index 838ad47..d8dda7f 100644 --- a/implementations/HoloGAN/utils.py +++ b/implementations/HoloGAN/utils.py @@ -7,7 +7,8 @@ from torchvision.utils import save_image from .model import Generator, Discriminator -from ..general import AnimeFaceDataset, to_loader, DiffAugment +from ..general import AnimeFaceDataset, to_loader +from ..gan_utils import DiffAugment def gen_theta( num_gen, minmax_angles=[0, 0, 220, 320, 0, 0], diff --git a/implementations/StyleGAN2/utils.py b/implementations/StyleGAN2/utils.py index 0148e76..2bff653 100644 --- a/implementations/StyleGAN2/utils.py +++ b/implementations/StyleGAN2/utils.py @@ -7,7 +7,8 @@ from .model import Generator, Discriminator -from ..general import AnimeFaceDataset, to_loader, DiffAugment +from ..general import AnimeFaceDataset, to_loader +from ..gan_utils import DiffAugment def toggle_grad(model, state): for param in model.parameters(): From 6df5185c8d35d023b1c3c19de9568b1e7d3fae65 Mon Sep 17 00:00:00 2001 From: STomoya Date: Tue, 29 Sep 2020 13:19:10 +0900 Subject: [PATCH 3/8] [UPDATE] move losses to gan_utils --- implementations/{general => gan_utils}/losses.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename implementations/{general => gan_utils}/losses.py (100%) diff --git a/implementations/general/losses.py b/implementations/gan_utils/losses.py similarity index 100% rename from implementations/general/losses.py rename to implementations/gan_utils/losses.py From c69d141bd451dee43e09a2305933b38581e42dce Mon Sep 17 00:00:00 2001 From: STomoya Date: Tue, 29 Sep 2020 17:20:19 +0900 Subject: [PATCH 4/8] [ADD] utils --- implementations/gan_utils/utils.py | 145 +++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 implementations/gan_utils/utils.py diff --git a/implementations/gan_utils/utils.py b/implementations/gan_utils/utils.py new file mode 100644 index 0000000..2cf802b --- /dev/null +++ b/implementations/gan_utils/utils.py @@ -0,0 +1,145 @@ + +import os +import warnings + +import torch +import torchvision as tv + +# device specification +def get_device(): return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + +# noise sampler +# normal distribution +def sample_nnoise(size, mean=0., std=1., device=None): + if device == None: + device = get_device() + return torch.empty(size, device=device).normal_(mean, std) + +# uniform distribution +def sample_unoise(size, from_=0., to=1., device=None): + if device == None: + device = get_device() + return torch.empty(size, device=device).uniform_(from_, to) + + +class GANTrainingStatus: + '''GAN training status helper''' + def __init__(self): + self.losses = {'G' : [], 'D' : []} + self.batches_done = 0 + + def append_g_loss(self, g_loss): + '''append generator loss''' + if isinstance(g_loss, torch.Tensor): + g_loss = g_loss.item() + self.losses['G'].append(g_loss) + def append_d_loss(self, d_loss): + '''append discriminator loss''' + if isinstance(d_loss, torch.Tensor): + d_loss = d_loss.item() + self.losses['D'].append(d_loss) + self.batches_done += 1 + + def add_loss(self, key): + '''define additional loss''' + if not isinstance(key, str): + raise Exception('Input a String object as the key. Got type {}'.format(type(key))) + self.losses[key] = [] + def append_additional_loss(self, **kwargs): + '''append additional loss''' + for key, value in kwargs.items(): + try: + self.losses[key].append(value) + except KeyError as ke: + warnings.warn('You have tried to append a loss keyed as \'{}\' that is not defined. Please call add_loss() or check the spelling.'.format(key)) + + def append(self, g_loss, d_loss, **kwargs): + '''append loss at once''' + self.append_g_loss(g_loss) + self.append_d_loss(d_loss) + self.append_additional_loss(**kwargs) + + def plot_loss(self, filename='./loss.png'): + '''plot loss''' + import matplotlib + matplotlib.use('agg') + import matplotlib.pyplot as plt + + add_loss = [key for key in self.losses if key not in ['G', 'D']] + G_loss = self.losses['G'] + D_loss = self.losses['D'] + + plt.figure(figsize=(12, 8)) + for key in add_loss: + plt.plot(self.losses[key]) + plt.plot(G_loss) + plt.plot(D_loss) + plt.title('Model Loss') + plt.xlabel('iter') + plt.ylabel('loss') + plt.legend([key for key in add_loss] + ['Generator', 'Discriminator'], loc='upper left') + + plt.tight_layout() + plt.savefig(filename) + plt.close() + + def save_image(self, folder, G, *input, filename=None, nrow=5, normalize=True, range=(-1, 1)): + '''simple save image + save_image func with + sampling images, and only args that I use frequently + ''' + G.eval() + with torch.no_grad(): + images = G(*input) + G.train() + + if filename == None: + filename = '{}.png'.format(self.batches_done) + + tv.utils.save_image( + images, os.path.join(folder, filename), nrow=nrow, normalize=normalize, range=range + ) + + def __str__(self): + '''print the latest losses when calling print() on the object''' + partial_msg = [] + partial_msg += [ + '{:6}'.format(self.batches_done), + '[D Loss : {:.5f}]'.format(self.losses['D'][-1]), + '[G Loss : {:.5f}]'.format(self.losses['G'][-1]), + ] + # verbose additinal loss + add_loss = [key for key in self.losses if key not in ['D', 'G']] + if len(add_loss) > 0: + for key in add_loss: + if self.losses[key] == []: # skip when no entry + continue + partial_msg.append( + '[{} : {:.5f}]'.format(key, self.losses[key][-1]) + ) + return '\t'.join(partial_msg) + + + +if __name__ == "__main__": + '''TEST''' + # device = get_device() + # print(sample_nnoise((3, 64), 0, 0.02).size()) + # print(sample_nnoise((3, 64), 0, 0.02, torch.device('cpu')).size()) + # print(sample_nnoise((3, 64), 0, 0.02, device).size()) + # print(sample_unoise((3, 64), 0, 3).size()) + # print(sample_unoise((3, 64), 0, 3, torch.device('cpu')).size()) + # print(sample_unoise((3, 64), 0, 3, device).size()) + + status = GANTrainingStatus() + + status.add_loss('real') + status.add_loss('fake') + + import math + + for i in range(100): + status.append(math.sin(i), math.cos(i), real=-i/10, fake=i/10) + print(status) + + status.plot_loss() \ No newline at end of file From 223f6adc39f8b907656b09a188bd1e394eed6fff Mon Sep 17 00:00:00 2001 From: STomoya Date: Tue, 29 Sep 2020 17:20:33 +0900 Subject: [PATCH 5/8] [ADD] ema --- implementations/gan_utils/ema.py | 87 ++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 implementations/gan_utils/ema.py diff --git a/implementations/gan_utils/ema.py b/implementations/gan_utils/ema.py new file mode 100644 index 0000000..40da3f9 --- /dev/null +++ b/implementations/gan_utils/ema.py @@ -0,0 +1,87 @@ + +import copy + +import torch +import torch.nn as nn + +class EMA: + '''Exponential Moving Avergae''' + def __init__(self, init_model, decay=0.999): + ''' + args + init_model : nn.Module + the model. + please input the "initialized" model + decay: float (default:0.999) + the decay used to update the model. + usually in [0.9, 0.999] + ''' + self.decay = decay + self.G_ema = copy.deepcopy(init_model) + # freeze and eval mode + for param in self.G_ema.parameters(): + param.requires_grad = False + self.G_ema.cpu().eval() + + def update(self, model_running): + '''update G_ema + + args + model_running: nn.Module + the running model. + must be the same model as EMA + ''' + + # running model to cpu + original_device = next(model_running.parameters()).device + model_running.cpu() + + # update params + ema_param = dict(self.G_ema.named_parameters()) + run_param = dict(G.named_parameters()) + + for key in ema_param.keys(): + ema_param[key].data.mul_(self.decay).add_(run_param[key], alpha=(1-self.decay)) + + # running model to original device + model_running.to(original_device) + + def __call__(self, *args, **kwargs): + '''return the model's output + + args + inputs to the model. + args can be on gpu. + ''' + # move inputs to cpu before input + cpu_args, cpu_kwargs = [], {} + for arg in args: + if isinstance(arg, torch.Tensor): + arg = arg.cpu() + cpu_args.append(arg) + for key, value in kwargs: + if isinstance(value, torch.Tensor): + value = value.cpu() + cpu_kwargs[key] = value + + with torch.no_grad(): + return self.G_ema(*cpu_args, **cpu_kwargs) + +if __name__ == "__main__": + from utils import get_device + import torch.nn as nn + device = get_device() + G = nn.Sequential( + nn.Conv2d(64, 32, 3, padding=1), + nn.ReLU(), + nn.Conv2d(32, 3, 3, padding=1), + nn.Tanh() + ) + G.to(device) + + z = torch.randn(3, 64, 4, 4, device=device) + print(G(z).size()) + + G_ema = EMA(G) + G_ema.update(G) + print(G_ema(z).size()) \ No newline at end of file From a81138901a8a889dd52149d58f85de0c27eab67f Mon Sep 17 00:00:00 2001 From: STomoya Date: Tue, 29 Sep 2020 17:23:52 +0900 Subject: [PATCH 6/8] [UPDATE] add kornia --- docker/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/requirements.txt b/docker/requirements.txt index 84d367d..ee4bf45 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -7,4 +7,5 @@ tqdm chainer illustration2vec scikit-image -scipy \ No newline at end of file +scipy +git+https://github.com/kornia/kornia \ No newline at end of file From 59de4fe818ccb2c461fd0e10b563d14a80d4eee0 Mon Sep 17 00:00:00 2001 From: STomoya Date: Tue, 29 Sep 2020 17:24:25 +0900 Subject: [PATCH 7/8] [ADD] diffaug kornia ver. --- .../gan_utils/DiffAugment_kornia.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 implementations/gan_utils/DiffAugment_kornia.py diff --git a/implementations/gan_utils/DiffAugment_kornia.py b/implementations/gan_utils/DiffAugment_kornia.py new file mode 100644 index 0000000..9f79052 --- /dev/null +++ b/implementations/gan_utils/DiffAugment_kornia.py @@ -0,0 +1,69 @@ + +''' +kornia based Differetiable Augmentation + +I took care so that it can be used by almost the same way as the official implementation, +but modified it so that I can set the parameters from other files. + +Difference from official implementation +(Mostly because of using karnia) +- padding mode for translation + official : zero + here : same +- brightness and contrast range + official : brightness [-0.5, 0.5] + contrast [ 0.5, 1.5] + here : brightness [0.75, 1.25] + contrast [0.75, 1.25] +- denorm before augmentation + official : not needed + here : karnia.augmentation requires data values between [0, 1] + so if normalized, denorm -> augment -> norm +''' + +import torch +import torch.nn as nn +import kornia.augmentation as aug +from kornia.constants import SamplePadding + +class kDiffAugment: + def __init__(self, + brightness=(0.75, 1.25), contrast=(0.75, 1.25), saturation=(0., 2.), translate=(0.125, 0.125), + normalized=True, mean=0.5, std=0.5, device=None + ): + if normalized: + if isinstance(mean, (tuple, list)) and isinstance(std, (tuple, list)): + if not device: + raise Exception('Please specify a torch.device() object when using mean and std for each channels') + mean = torch.Tensor(mean).to(device) + std = torch.Tensor(std).to(device) + self.normalize = aug.Normalize(mean, std) + self.denormalize = aug.Denormalize(mean, std) + else: + self.normalize, self.denormalize = None, None + + color_jitter = aug.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, p=1.) # rand_brightness, rand_contrast, rand_saturation + affine = aug.RandomAffine(degrees=0, translate=translate, padding_mode=SamplePadding.BORDER, p=1.) # rand_translate + cutout = aug.RandomErasing(value=0.5, p=1.) # rand_cutout + + self.augmentations = { + 'color' : color_jitter, + 'translation' : affine, + 'cutout' : cutout + } + + def __call__(self, x, policy): + if self.denormalize: + x = self.denormalize(x) + policy = self.__encode_policy(policy) + for p in policy: + aug_func = self.augmentations[p] + x = aug_func(x) + if self.normalize: + x = self.normalize(x) + return x + + def __encode_policy(self, policy): + if isinstance(policy, (tuple, list)): + return policy + return policy.split(',') \ No newline at end of file From da47e45ca8a756b9edb0b009760b098962962c07 Mon Sep 17 00:00:00 2001 From: STomoya Date: Tue, 29 Sep 2020 17:26:36 +0900 Subject: [PATCH 8/8] [ADD] init file for gan_utils --- implementations/gan_utils/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 implementations/gan_utils/__init__.py diff --git a/implementations/gan_utils/__init__.py b/implementations/gan_utils/__init__.py new file mode 100644 index 0000000..1689b51 --- /dev/null +++ b/implementations/gan_utils/__init__.py @@ -0,0 +1,5 @@ +# import loss functions from gan_utils.losses +from .utils import * +from .ema import EMA +from .DiffAugment_pytorch import DiffAugment +from .DiffAugment_kornia import kDiffAugment \ No newline at end of file