Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Utils #33

Merged
merged 8 commits into from
Sep 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ tqdm
chainer
illustration2vec
scikit-image
scipy
scipy
git+https://github.com/kornia/kornia
3 changes: 2 additions & 1 deletion implementations/HoloGAN/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from torchvision.utils import save_image

from .model import Generator, Discriminator
from ..general import AnimeFaceDataset, to_loader, DiffAugment
from ..general import AnimeFaceDataset, to_loader
from ..gan_utils import DiffAugment

def gen_theta(
num_gen, minmax_angles=[0, 0, 220, 320, 0, 0],
Expand Down
3 changes: 2 additions & 1 deletion implementations/StyleGAN2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from .model import Generator, Discriminator

from ..general import AnimeFaceDataset, to_loader, DiffAugment
from ..general import AnimeFaceDataset, to_loader
from ..gan_utils import DiffAugment

def toggle_grad(model, state):
for param in model.parameters():
Expand Down
69 changes: 69 additions & 0 deletions implementations/gan_utils/DiffAugment_kornia.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@

'''
kornia based Differetiable Augmentation

I took care so that it can be used by almost the same way as the official implementation,
but modified it so that I can set the parameters from other files.

Difference from official implementation
(Mostly because of using karnia)
- padding mode for translation
official : zero
here : same
- brightness and contrast range
official : brightness [-0.5, 0.5]
contrast [ 0.5, 1.5]
here : brightness [0.75, 1.25]
contrast [0.75, 1.25]
- denorm before augmentation
official : not needed
here : karnia.augmentation requires data values between [0, 1]
so if normalized, denorm -> augment -> norm
'''

import torch
import torch.nn as nn
import kornia.augmentation as aug
from kornia.constants import SamplePadding

class kDiffAugment:
def __init__(self,
brightness=(0.75, 1.25), contrast=(0.75, 1.25), saturation=(0., 2.), translate=(0.125, 0.125),
normalized=True, mean=0.5, std=0.5, device=None
):
if normalized:
if isinstance(mean, (tuple, list)) and isinstance(std, (tuple, list)):
if not device:
raise Exception('Please specify a torch.device() object when using mean and std for each channels')
mean = torch.Tensor(mean).to(device)
std = torch.Tensor(std).to(device)
self.normalize = aug.Normalize(mean, std)
self.denormalize = aug.Denormalize(mean, std)
else:
self.normalize, self.denormalize = None, None

color_jitter = aug.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, p=1.) # rand_brightness, rand_contrast, rand_saturation
affine = aug.RandomAffine(degrees=0, translate=translate, padding_mode=SamplePadding.BORDER, p=1.) # rand_translate
cutout = aug.RandomErasing(value=0.5, p=1.) # rand_cutout

self.augmentations = {
'color' : color_jitter,
'translation' : affine,
'cutout' : cutout
}

def __call__(self, x, policy):
if self.denormalize:
x = self.denormalize(x)
policy = self.__encode_policy(policy)
for p in policy:
aug_func = self.augmentations[p]
x = aug_func(x)
if self.normalize:
x = self.normalize(x)
return x

def __encode_policy(self, policy):
if isinstance(policy, (tuple, list)):
return policy
return policy.split(',')
5 changes: 5 additions & 0 deletions implementations/gan_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# import loss functions from gan_utils.losses
from .utils import *
from .ema import EMA
from .DiffAugment_pytorch import DiffAugment
from .DiffAugment_kornia import kDiffAugment
87 changes: 87 additions & 0 deletions implementations/gan_utils/ema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@

import copy

import torch
import torch.nn as nn

class EMA:
'''Exponential Moving Avergae'''
def __init__(self, init_model, decay=0.999):
'''
args
init_model : nn.Module
the model.
please input the "initialized" model
decay: float (default:0.999)
the decay used to update the model.
usually in [0.9, 0.999]
'''
self.decay = decay
self.G_ema = copy.deepcopy(init_model)
# freeze and eval mode
for param in self.G_ema.parameters():
param.requires_grad = False
self.G_ema.cpu().eval()

def update(self, model_running):
'''update G_ema

args
model_running: nn.Module
the running model.
must be the same model as EMA
'''

# running model to cpu
original_device = next(model_running.parameters()).device
model_running.cpu()

# update params
ema_param = dict(self.G_ema.named_parameters())
run_param = dict(G.named_parameters())

for key in ema_param.keys():
ema_param[key].data.mul_(self.decay).add_(run_param[key], alpha=(1-self.decay))

# running model to original device
model_running.to(original_device)

def __call__(self, *args, **kwargs):
'''return the model's output

args
inputs to the model.
args can be on gpu.
'''
# move inputs to cpu before input
cpu_args, cpu_kwargs = [], {}
for arg in args:
if isinstance(arg, torch.Tensor):
arg = arg.cpu()
cpu_args.append(arg)
for key, value in kwargs:
if isinstance(value, torch.Tensor):
value = value.cpu()
cpu_kwargs[key] = value

with torch.no_grad():
return self.G_ema(*cpu_args, **cpu_kwargs)

if __name__ == "__main__":
from utils import get_device
import torch.nn as nn
device = get_device()
G = nn.Sequential(
nn.Conv2d(64, 32, 3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 3, 3, padding=1),
nn.Tanh()
)
G.to(device)

z = torch.randn(3, 64, 4, 4, device=device)
print(G(z).size())

G_ema = EMA(G)
G_ema.update(G)
print(G_ema(z).size())
File renamed without changes.
145 changes: 145 additions & 0 deletions implementations/gan_utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@

import os
import warnings

import torch
import torchvision as tv

# device specification
def get_device(): return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# noise sampler
# normal distribution
def sample_nnoise(size, mean=0., std=1., device=None):
if device == None:
device = get_device()
return torch.empty(size, device=device).normal_(mean, std)

# uniform distribution
def sample_unoise(size, from_=0., to=1., device=None):
if device == None:
device = get_device()
return torch.empty(size, device=device).uniform_(from_, to)


class GANTrainingStatus:
'''GAN training status helper'''
def __init__(self):
self.losses = {'G' : [], 'D' : []}
self.batches_done = 0

def append_g_loss(self, g_loss):
'''append generator loss'''
if isinstance(g_loss, torch.Tensor):
g_loss = g_loss.item()
self.losses['G'].append(g_loss)
def append_d_loss(self, d_loss):
'''append discriminator loss'''
if isinstance(d_loss, torch.Tensor):
d_loss = d_loss.item()
self.losses['D'].append(d_loss)
self.batches_done += 1

def add_loss(self, key):
'''define additional loss'''
if not isinstance(key, str):
raise Exception('Input a String object as the key. Got type {}'.format(type(key)))
self.losses[key] = []
def append_additional_loss(self, **kwargs):
'''append additional loss'''
for key, value in kwargs.items():
try:
self.losses[key].append(value)
except KeyError as ke:
warnings.warn('You have tried to append a loss keyed as \'{}\' that is not defined. Please call add_loss() or check the spelling.'.format(key))

def append(self, g_loss, d_loss, **kwargs):
'''append loss at once'''
self.append_g_loss(g_loss)
self.append_d_loss(d_loss)
self.append_additional_loss(**kwargs)

def plot_loss(self, filename='./loss.png'):
'''plot loss'''
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

add_loss = [key for key in self.losses if key not in ['G', 'D']]
G_loss = self.losses['G']
D_loss = self.losses['D']

plt.figure(figsize=(12, 8))
for key in add_loss:
plt.plot(self.losses[key])
plt.plot(G_loss)
plt.plot(D_loss)
plt.title('Model Loss')
plt.xlabel('iter')
plt.ylabel('loss')
plt.legend([key for key in add_loss] + ['Generator', 'Discriminator'], loc='upper left')

plt.tight_layout()
plt.savefig(filename)
plt.close()

def save_image(self, folder, G, *input, filename=None, nrow=5, normalize=True, range=(-1, 1)):
'''simple save image
save_image func with
sampling images, and only args that I use frequently
'''
G.eval()
with torch.no_grad():
images = G(*input)
G.train()

if filename == None:
filename = '{}.png'.format(self.batches_done)

tv.utils.save_image(
images, os.path.join(folder, filename), nrow=nrow, normalize=normalize, range=range
)

def __str__(self):
'''print the latest losses when calling print() on the object'''
partial_msg = []
partial_msg += [
'{:6}'.format(self.batches_done),
'[D Loss : {:.5f}]'.format(self.losses['D'][-1]),
'[G Loss : {:.5f}]'.format(self.losses['G'][-1]),
]
# verbose additinal loss
add_loss = [key for key in self.losses if key not in ['D', 'G']]
if len(add_loss) > 0:
for key in add_loss:
if self.losses[key] == []: # skip when no entry
continue
partial_msg.append(
'[{} : {:.5f}]'.format(key, self.losses[key][-1])
)
return '\t'.join(partial_msg)



if __name__ == "__main__":
'''TEST'''
# device = get_device()
# print(sample_nnoise((3, 64), 0, 0.02).size())
# print(sample_nnoise((3, 64), 0, 0.02, torch.device('cpu')).size())
# print(sample_nnoise((3, 64), 0, 0.02, device).size())
# print(sample_unoise((3, 64), 0, 3).size())
# print(sample_unoise((3, 64), 0, 3, torch.device('cpu')).size())
# print(sample_unoise((3, 64), 0, 3, device).size())

status = GANTrainingStatus()

status.add_loss('real')
status.add_loss('fake')

import math

for i in range(100):
status.append(math.sin(i), math.cos(i), real=-i/10, fake=i/10)
print(status)

status.plot_loss()
2 changes: 0 additions & 2 deletions implementations/general/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from .danbooru import DanbooruDataset, GeneratePairImageDanbooruDataset
from .danbooru_portrait import DanbooruPortraitDataset

from .DiffAugment_pytorch import DiffAugment

from .fp16 import network_to_half

from torch.utils.data import DataLoader
Expand Down