Commit 46049c0 (1 parent: f732f39)
Showing 9 changed files with 2,650 additions and 22 deletions.
@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# empty
@@ -0,0 +1,238 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Streaming images and labels from datasets created with dataset_tool.py."""

import os
import numpy as np
import zipfile
import PIL.Image
import json
import torch
import dnnlib

try:
    import pyspng
except ImportError:
    pyspng = None

#----------------------------------------------------------------------------

class Dataset(torch.utils.data.Dataset):
    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw image data (NCHW).
        max_size    = None,     # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
        use_labels  = False,    # Enable conditioning labels? False = label dimension is zero.
        xflip       = False,    # Artificially double the size of the dataset via x-flips. Applied after max_size.
        random_seed = 0,        # Random seed to use when applying max_size.
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None
        self._label_shape = None

        # Apply max_size.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])

    def _get_raw_labels(self):
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
        return self._raw_labels

    def close(self): # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx): # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self): # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        try:
            self.close()
        except:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        image = self._load_raw_image(self._raw_idx[idx])
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if self._xflip[idx]:
            assert image.ndim == 3 # CHW
            image = image[:, :, ::-1]
        return image.copy(), self.get_label(idx)

    def get_label(self, idx):
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = (int(self._xflip[idx]) != 0)
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3 # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        assert len(self.image_shape) == 3 # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64

#----------------------------------------------------------------------------

class ImageFolderDataset(Dataset):
    def __init__(self,
        path,                   # Path to directory or zip.
        resolution = None,      # Ensure specific resolution, None = highest available.
        **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        self._path = path
        self._zipfile = None

        if os.path.isdir(self._path):
            self._type = 'dir'
            self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        PIL.Image.init()
        self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
        if len(self._image_fnames) == 0:
            raise IOError('No image files found in the specified path')

        name = os.path.splitext(os.path.basename(self._path))[0]
        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def close(self):
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        return dict(super().__getstate__(), _zipfile=None)

    def _load_raw_image(self, raw_idx):
        fname = self._image_fnames[raw_idx]
        with self._open_file(fname) as f:
            if pyspng is not None and self._file_ext(fname) == '.png':
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        if image.ndim == 2:
            image = image[:, :, np.newaxis] # HW => HWC
        image = image.transpose(2, 0, 1) # HWC => CHW
        return image

    def _load_raw_labels(self):
        fname = 'dataset.json'
        if fname not in self._all_fnames:
            return None
        with self._open_file(fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        labels = dict(labels)
        labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
        labels = np.array(labels)
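        # 1-D label arrays hold integer class indices; 2-D arrays hold float label vectors.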
        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
        return labels

#----------------------------------------------------------------------------
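For reference, a minimal usage sketch of the ImageFolderDataset above. The module path (training.dataset) and the zip path are assumptions for illustration, not part of this commit; the zip is expected to be one produced by dataset_tool.py, optionally containing a dataset.json with labels.

# Hedged usage sketch; 'training.dataset' and './datasets/example.zip' are
# assumed names, not part of this commit.
import torch
from training.dataset import ImageFolderDataset

dataset = ImageFolderDataset(path='./datasets/example.zip', use_labels=True, xflip=True)
print(dataset.name, dataset.resolution, dataset.label_dim, len(dataset))

loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, images.dtype, labels.shape)  # uint8 NCHW images, float32 labels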
@@ -0,0 +1,140 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Loss functions."""

import numpy as np
import torch
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import upfirdn2d

#----------------------------------------------------------------------------

class Loss:
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass
        raise NotImplementedError()

#----------------------------------------------------------------------------

class StyleGAN2Loss(Loss):
    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0):
        super().__init__()
        self.device = device
        self.G = G
        self.D = D
        self.augment_pipe = augment_pipe
        self.r1_gamma = r1_gamma
        self.style_mixing_prob = style_mixing_prob
        self.pl_weight = pl_weight
        self.pl_batch_shrink = pl_batch_shrink
        self.pl_decay = pl_decay
        self.pl_no_weight_grad = pl_no_weight_grad
        self.pl_mean = torch.zeros([], device=device)
        self.blur_init_sigma = blur_init_sigma
        self.blur_fade_kimg = blur_fade_kimg

    def run_G(self, z, c, update_emas=False):
        ws = self.G.mapping(z, c, update_emas=update_emas)
        if self.style_mixing_prob > 0:
            with torch.autograd.profiler.record_function('style_mixing'):
                cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
        img = self.G.synthesis(ws, update_emas=update_emas)
        return img, ws

    def run_D(self, img, c, blur_sigma=0, update_emas=False):
        blur_size = np.floor(blur_sigma * 3)
        if blur_size > 0:
            with torch.autograd.profiler.record_function('blur'):
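                # Gaussian-shaped kernel with a base-2 exponent: f[x] = 2^(-(x/sigma)^2), normalized to sum to 1 below.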
                f = torch.arange(-blur_size, blur_size + 1, device=img.device).div(blur_sigma).square().neg().exp2()
                img = upfirdn2d.filter2d(img, f / f.sum())
        if self.augment_pipe is not None:
            img = self.augment_pipe(img)
        logits = self.D(img, c, update_emas=update_emas)
        return logits

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        if self.pl_weight == 0:
            phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
        if self.r1_gamma == 0:
            phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
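        # Blur fades linearly from blur_init_sigma to 0 over the first blur_fade_kimg thousand images.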
        blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0

        # Gmain: Maximize logits for generated images.
        if phase in ['Gmain', 'Gboth']:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if phase in ['Greg', 'Gboth']:
            with torch.autograd.profiler.record_function('Gpl_forward'):
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size])
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad):
                    pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)
            with torch.autograd.profiler.record_function('Gpl_backward'):
                loss_Gpl.mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if phase in ['Dmain', 'Dboth']:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, update_emas=True)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if phase in ['Dmain', 'Dreg', 'Dboth']:
            name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                real_img_tmp = real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth'])
                real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if phase in ['Dmain', 'Dboth']:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if phase in ['Dreg', 'Dboth']:
                    with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                    r1_penalty = r1_grads.square().sum([1,2,3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                (loss_Dreal + loss_Dr1).mean().mul(gain).backward()

#----------------------------------------------------------------------------
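To make the phase protocol concrete, here is a hedged, self-contained sketch that drives accumulate_gradients through the four standard phases. StubG, StubD, the optimizer wiring, and all shapes are placeholders invented for illustration; the real generator, discriminator, and training loop are not part of this excerpt, and 'training.loss' is an assumed module path.

# Hedged sketch: drive StyleGAN2Loss with toy stand-ins for G and D. StubG,
# StubD, and all shapes are hypothetical; they only mimic the interface that
# run_G/run_D expect (G.mapping, G.synthesis, callable D).
import torch
import torch.nn as nn
from training.loss import StyleGAN2Loss  # assumed module path

class StubG(nn.Module):
    def __init__(self, z_dim=8, num_ws=4, res=16):
        super().__init__()
        self.map = nn.Linear(z_dim, z_dim)          # gives ws a gradient path for the PL term
        self.syn = nn.Linear(z_dim, 3 * res * res)
        self.num_ws, self.res = num_ws, res
    def mapping(self, z, c, update_emas=False):
        return self.map(z).unsqueeze(1).repeat(1, self.num_ws, 1)  # [N, num_ws, w_dim]
    def synthesis(self, ws, update_emas=False):
        return self.syn(ws.mean(1)).view(-1, 3, self.res, self.res)

class StubD(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, 3, padding=1)
    def forward(self, img, c, update_emas=False):
        return self.conv(img).mean(dim=[1, 2, 3])   # one logit per sample

G, D = StubG(), StubD()
loss = StyleGAN2Loss(device='cpu', G=G, D=D, r1_gamma=10, pl_weight=2)
opts = {'G': torch.optim.Adam(G.parameters()), 'D': torch.optim.Adam(D.parameters())}

real_img = torch.randn(4, 3, 16, 16)
real_c = gen_c = torch.zeros(4, 0)                  # unconditional: zero-dim labels
for phase in ['Gmain', 'Greg', 'Dmain', 'Dreg']:
    opt = opts[phase[0]]                            # 'G*' phases step G, 'D*' phases step D
    opt.zero_grad(set_to_none=True)
    loss.accumulate_gradients(phase=phase, real_img=real_img, real_c=real_c,
                              gen_z=torch.randn(4, 8), gen_c=gen_c, gain=1, cur_nimg=0)
    opt.step()

Note the separation of concerns this interface encodes: accumulate_gradients only populates .grad buffers (scaled by gain), and the caller decides which module to step and when, which is what makes the lazy-regularization phases (Greg/Dreg at a reduced cadence) possible.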