In [0]:
#utils.py
import os
import yaml
import torch
import random
import inspect
import numpy as np
from typing import Dict
from pickle import load, dump
from functools import partial
from fractions import Fraction
import torch.nn.functional as F
from argparse import ArgumentParser

# Small constant added inside square roots (see pixel_norm) so an all-zero input
# does not cause a division by zero.
EPSILON = 1e-8
# Module-level cache used by get_half(): a tensor filled with 0.5 that is only
# rebuilt when the requested shape changes.
half_tensor = None


def generate_samples(generator, gen_input):
    """Run ``generator`` on ``gen_input`` and return the generated samples as a NumPy array.

    The generator's output is indexed as ``output[0]['x']``; that tensor is
    detached through ``.data`` and copied to the CPU before conversion.
    """
    outputs = generator(gen_input)
    samples = outputs[0]['x']
    return samples.data.cpu().numpy()


def save_pkl(file_name, obj):
    """Serialize ``obj`` into ``file_name`` with pickle protocol 4 (overwrites the file)."""
    pickle_protocol = 4
    with open(file_name, mode='wb') as target:
        dump(obj, target, protocol=pickle_protocol)


def load_pkl(file_name):
    """Return the object unpickled from ``file_name``, or ``None`` if the path does not exist.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    if os.path.exists(file_name):
        with open(file_name, 'rb') as source:
            return load(source)
    return None


def get_half(num_latents, latent_size):
    """Return a ``(num_latents, latent_size)`` tensor filled with 0.5.

    The tensor is cached in the module-level ``half_tensor`` and reused (same
    object) while the requested shape stays the same.
    """
    global half_tensor
    wanted_shape = (num_latents, latent_size)
    cache_is_stale = half_tensor is None or half_tensor.size() != wanted_shape
    if cache_is_stale:
        half_tensor = torch.ones(wanted_shape) * 0.5
    return half_tensor


def random_latents(num_latents, latent_size, z_distribution='normal'):
    """Sample a ``(num_latents, latent_size)`` batch of latent vectors.

    Args:
        num_latents: number of latent vectors (rows).
        latent_size: dimensionality of each latent vector.
        z_distribution: ``'normal'`` (standard Gaussian), ``'censored'``
            (standard Gaussian passed through ReLU, i.e. clipped at 0) or
            ``'bernoulli'`` (independent 0/1 draws with p = 0.5).

    Raises:
        ValueError: if ``z_distribution`` is not one of the names above.
    """
    if z_distribution == 'normal':
        return torch.randn(num_latents, latent_size)
    if z_distribution == 'censored':
        return F.relu(torch.randn(num_latents, latent_size))
    if z_distribution == 'bernoulli':
        return torch.bernoulli(get_half(num_latents, latent_size))
    raise ValueError("Unknown z_distribution {!r}; expected 'normal', 'censored' or 'bernoulli'"
                     .format(z_distribution))


def mkdir(path):
    """Create directory ``path`` (including missing parents) unless it already exists.

    Uses try/except instead of check-then-create: with the old
    ``exists()`` + ``makedirs()`` pair, two processes creating the same directory
    concurrently could race into a ``FileExistsError``. As before, an existing
    entry at ``path`` is left untouched.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        pass


def create_result_subdir(results_dir, experiment_name, dir_pattern='{new_num:03}-{exp_name}'):
    """Create and return a new numbered sub-directory of ``results_dir``.

    The number is one more than the largest purely numeric prefix (the text
    before the first ``'-'``) among existing entries, or 1 if there is none. The
    name is ``dir_pattern`` formatted with ``new_num`` and ``exp_name``
    (``001-name``, ``002-name``, ... by default). ``results_dir`` is created if
    missing.

    Raises:
        FileExistsError: if the computed sub-directory already exists.
    """
    os.makedirs(results_dir, exist_ok=True)
    prefixes = (entry.split('-')[0] for entry in os.listdir(results_dir))
    # isdecimal() is False for '' (entries such as '-notes'), which used to reach int('')
    # and crash; it also rejects digit-like characters int() cannot parse (e.g. '²').
    max_num = max((int(prefix) for prefix in prefixes if prefix.isdecimal()), default=0)
    path = os.path.join(results_dir, dir_pattern.format(new_num=max_num + 1, exp_name=experiment_name))
    os.makedirs(path, exist_ok=False)
    return path


def num_params(net):
    """Count the scalar entries of all parameters of ``net`` that require gradients."""
    total = 0
    for param in trainable_params(net):
        total += np.prod(param.size())
    return total


def generic_arg_parse(x, hinttype=None):
    """Convert a command-line string into a Python value.

    If ``hinttype`` is ``int``, ``float`` or ``str``, ``hinttype(x)`` is returned.
    Otherwise up to two layers of surrounding quotes are stripped and the text
    is evaluated as a Python expression, so ``"True"``, ``"(0.0, 0.99)"`` or
    ``"[1, 2]"`` become real objects. If evaluation fails (e.g. a bare file path),
    the stripped string itself is returned.

    SECURITY NOTE: this calls ``eval``; only feed it trusted input such as the
    local command line. ``ast.literal_eval`` would be safer but rejects
    arithmetic expressions that existing command lines may rely on.
    """
    if hinttype in (int, float, str):
        return hinttype(x)
    try:
        for _ in range(2):
            x = x.strip('\'').strip("\"")
        return eval(x, {}, {})
    except Exception:  # not a valid expression (probably a path or free text): keep it as a string
        # Exception rather than a bare except, so KeyboardInterrupt/SystemExit still propagate.
        return x


def create_params(classes, excludes=None, overrides=None):
    """Collect the keyword defaults of each class's ``__init__``.

    Args:
        classes: iterable of classes to introspect.
        excludes: optional ``{class_name: container of parameter names}`` to skip.
        overrides: optional ``{class_name: {parameter_name: value}}`` replacing
            the declared defaults.

    Returns:
        ``{class_name: {parameter_name: default}}`` covering every ``__init__``
        parameter that declares a default and is not excluded.
    """
    excludes = excludes or {}
    overrides = overrides or {}
    params = {}
    for cls in classes:
        name = cls.__name__
        cls_excludes = excludes.get(name, ())
        cls_overrides = overrides.get(name, {})
        params[name] = {
            k: cls_overrides[k] if k in cls_overrides else v.default
            for k, v in inspect.signature(cls.__init__).parameters.items()
            # identity test: ``!=`` against the sentinel breaks on defaults such as
            # NumPy arrays / tensors whose comparison is element-wise.
            if v.default is not inspect.Parameter.empty and k not in cls_excludes
        }
    return params


def get_structured_params(params):
    """Group dotted keys into per-prefix sub-dicts.

    Only the first ``'.'`` splits a key: ``{'Adam.betas': b, 'lr': 1}`` becomes
    ``{'Adam': {'betas': b}, 'lr': 1}`` and ``'A.b.c'`` becomes ``{'A': {'b.c': ...}}``.
    Keys without a dot are copied unchanged.
    """
    structured = {}
    for key, value in params.items():
        head, dot, tail = key.partition('.')
        if dot:
            structured.setdefault(head, {})[tail] = value
        else:
            structured[key] = value
    return structured


def cudize(thing):
    """Move ``thing`` to the current CUDA device if CUDA is available.

    ``None`` is returned unchanged. Dicts get every value moved (a new dict is
    returned), lists/tuples every item (a new *list* is returned), and anything
    else has ``.cuda()`` called on it. Without CUDA, ``thing`` is returned as-is.
    """
    if thing is None or not torch.cuda.is_available():
        return thing
    if isinstance(thing, dict):
        return {key: value.cuda() for key, value in thing.items()}
    if isinstance(thing, (list, tuple)):
        return [item.cuda() for item in thing]
    return thing.cuda()


def trainable_params(model):
    """Return a lazy iterator over the parameters of ``model`` with ``requires_grad`` set."""
    return (param for param in model.parameters() if param.requires_grad)


def pixel_norm(h):
    """Scale ``h`` to unit root-mean-square along dimension 1.

    Computes ``h / sqrt(mean(h ** 2, dim=1) + EPSILON)``. For ``(B, C, T)``
    activations, dimension 1 is the channel axis. EPSILON keeps all-zero vectors
    finite.
    """
    mean_square = (h * h).mean(dim=1, keepdim=True)
    inv_rms = (mean_square + EPSILON).rsqrt()
    return h * inv_rms


def simple_argparser(default_params):
    """Build a CLI from ``default_params``, parse ``sys.argv`` and return the structured result.

    Every key becomes a ``--key`` option. Its value is converted by
    ``generic_arg_parse``, using the type of the default as a hint. Dotted keys
    are then grouped per prefix by ``get_structured_params``.
    """
    parser = ArgumentParser()
    for name, default in default_params.items():
        converter = partial(generic_arg_parse, hinttype=type(default))
        parser.add_argument('--{}'.format(name), type=converter)
    parser.set_defaults(**default_params)
    parsed = parser.parse_args()
    return get_structured_params(vars(parsed))


def enable_benchmark():
    """Enable cuDNN autotuning; it speeds up training when input sizes stay (almost) constant."""
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = True


def load_model(model_path, return_all=False):
    """Load a ``torch.save`` checkpoint onto the CPU.

    Returns ``checkpoint['model']``, or ``(model, optimizer, cur_nimg)`` read
    from the same keys when ``return_all`` is true.
    NOTE: ``torch.load`` unpickles; only load trusted checkpoint files.
    """
    checkpoint = torch.load(model_path, map_location='cpu')
    model = checkpoint['model']
    if return_all:
        return model, checkpoint['optimizer'], checkpoint['cur_nimg']
    return model


def parse_config(default_params, need_arg_classes, exclude_adam=True, read_cli=True):
    """Build the run configuration from defaults, class signatures, the CLI and a YAML file.

    Steps:
      1. every key of ``default_params`` becomes a ``--key`` CLI option;
      2. the ``__init__`` defaults of ``need_arg_classes`` become ``--Class.param``
         options. When ``exclude_adam`` is true, Adam's ``lr``/``amsgrad`` are
         skipped and its ``betas`` default becomes ``(0.0, 0.99)``;
      3. values are read from the CLI (``read_cli``) or taken from the defaults;
      4. if ``config_file`` is set, the YAML mapping in it overrides the values
         (one level of nesting is flattened to ``Class.param`` keys);
      5. dotted keys are grouped per class;
      6. Python, NumPy and torch RNGs are seeded with ``random_seed``. On CUDA,
         ``cuda_device`` is selected and cuDNN benchmarking is enabled.

    ``default_params`` itself is not modified; a copy is extended instead.

    Returns:
        The structured parameter dict.
    """
    default_params = dict(default_params)  # do not leak the auto-generated keys into the caller's dict
    parser = ArgumentParser()
    if exclude_adam:
        excludes = {'Adam': {'lr', 'amsgrad'}}
        default_overrides = {'Adam': {'betas': (0.0, 0.99)}}
        auto_args = create_params(need_arg_classes, excludes, default_overrides)
    else:
        auto_args = create_params(need_arg_classes)
    for k in default_params:
        parser.add_argument('--{}'.format(k), type=partial(generic_arg_parse, hinttype=type(default_params[k])))
    for cls, cls_args in auto_args.items():
        group = parser.add_argument_group(cls, 'Arguments for initialization of class {}'.format(cls))
        for k, default in cls_args.items():
            name = '{}.{}'.format(cls, k)
            group.add_argument('--{}'.format(name), type=generic_arg_parse)
            default_params[name] = default
    parser.set_defaults(**default_params)
    params = vars(parser.parse_args()) if read_cli else default_params
    config_file = params.get('config_file')
    if config_file:
        print('loading config_file')
        with open(config_file) as f:
            # safe_load: yaml.load() without an explicit Loader is rejected by
            # PyYAML >= 6 and could build arbitrary Python objects on older versions.
            # An empty file yields None, treated as "no overrides".
            params = _update_params(params, yaml.safe_load(f) or {})
    params = get_structured_params(params)
    seed = params['random_seed']
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(params['cuda_device'])
        torch.cuda.manual_seed_all(seed)
        enable_benchmark()
    return params


def _update_params(params: Dict, given_conf: Dict):
    """Merge ``given_conf`` into ``params`` in place and return ``params``.

    One level of nesting is flattened: ``{'Adam': {'betas': b}}`` sets
    ``params['Adam.betas'] = b``. Non-dict values are assigned under their own key.
    """
    for key, value in given_conf.items():
        if not isinstance(value, dict):
            params[key] = value
            continue
        params.update(('{}.{}'.format(key, sub_key), sub_value) for sub_key, sub_value in value.items())
    return params


def upsample_signal(signal, upsample_factor):
    """Lengthen a ``(batch, channels, time)`` tensor by ``upsample_factor`` using linear interpolation."""
    return F.interpolate(
        signal,
        scale_factor=upsample_factor,
        mode='linear',
        align_corners=False,
    )


def downsample_signal(signal, downsample_factor):
    """Shorten a ``(batch, channels, time)`` tensor by ``downsample_factor`` via average pooling.

    Windows do not overlap (kernel == stride) and there is no padding. Trailing
    samples that do not fill a whole window are dropped (``ceil_mode=False``).
    """
    window = downsample_factor
    return F.avg_pool1d(signal, kernel_size=window, stride=window, padding=0,
                        ceil_mode=False, count_include_pad=True)


def expand3d(signal):
    """Prepend singleton axes so a 1-D or 2-D tensor becomes 3-D.

    ``(T,)`` -> ``(1, 1, T)`` and ``(C, T)`` -> ``(1, C, T)``. Tensors of any other
    rank are returned unchanged.
    """
    rank = signal.dim()
    if rank == 1:
        return signal[None, None]
    if rank == 2:
        return signal[None]
    return signal


def _resample_ratio(signal_freq, desired_freq):
    """Return ``desired_freq / signal_freq`` as an exact ``Fraction``.

    Each frequency is converted separately, so int/float mixes work. Floats are
    snapped to the nearest fraction with denominator <= 1e6, so ``0.1`` becomes
    1/10 instead of the float's exact (huge) binary expansion.
    """
    return Fraction(desired_freq).limit_denominator() / Fraction(signal_freq).limit_denominator()


def resample_signal(signal, signal_freq, desired_freq, pytorch=False):
    """Resample ``signal`` from ``signal_freq`` to ``desired_freq`` along its last axis.

    ``signal`` is a NumPy array or tensor of rank 1 ``(T,)``, 2 ``(C, T)`` or 3
    ``(B, C, T)``. Lower ranks are lifted to 3-D by adding leading axes. The
    ratio ``desired / signal = p / q`` (in lowest terms) is applied as linear
    upsampling by ``p`` followed by average-pool downsampling by ``q``.

    With ``pytorch=True`` the 3-D tensor is returned. In that mode an input of
    rank < 3, or one whose time axis has length 1, is instead tiled
    ``desired / signal`` times along time; this ratio must be an integer.
    With ``pytorch=False`` the result is converted to NumPy and given back the
    input's rank: 1-D and 2-D inputs lose the added axes, 3-D inputs keep all axes.
    """
    new_signal = torch.from_numpy(signal) if isinstance(signal, np.ndarray) else signal
    orig_dim = new_signal.dim()
    if orig_dim == 2:
        new_signal = new_signal[None]
    elif orig_dim == 1:
        new_signal = new_signal[None, None]
    if pytorch and (orig_dim < 3 or new_signal.shape[2] == 1):
        ratio = desired_freq / signal_freq
        assert ratio == int(ratio), 'tiling needs an integer frequency ratio, got {}'.format(ratio)
        return new_signal.repeat(1, 1, int(ratio))
    # Previously only desired_freq was checked for float: (int, float) pairs made
    # Fraction() raise TypeError, and non-integer float ratios became enormous fractions.
    ratio = _resample_ratio(signal_freq, desired_freq)
    if ratio.numerator != 1:
        new_signal = upsample_signal(new_signal, ratio.numerator)
    if ratio.denominator != 1:
        new_signal = downsample_signal(new_signal, ratio.denominator)
    if pytorch:
        return new_signal
    # Undo the rank lifting. 3-D inputs keep their batch/channel axes; they used to
    # be silently truncated to new_signal[0, 0].
    if orig_dim == 2:
        return new_signal[0].numpy()
    if orig_dim == 1:
        return new_signal[0, 0].numpy()
    return new_signal.numpy()


In [0]:
#dataset.py
import os
import glob
from random import shuffle

import torch
import numpy as np
from scipy.io import loadmat
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate

# Version tag of the on-disk dataset cache. It is embedded in the cache file names
# built by EEGDataset.from_config, so bumping it makes existing caches be ignored
# (the datasets are rebuilt from the raw files).
DATASET_VERSION = 6

# Notebook setup commands kept as an unused string literal: they install gdown and
# download files from Google Drive by id (presumably the raw EEG recordings).
# Run them in a notebook/shell cell; this string has no effect when executed.
'''
!pip install gdown
!gdown --id 1rVey-8ZN1mAYcCISie5kGuZ_2WNWZ8Jb
!gdown --id 1QMKt1c6v-Pa6sGXDia4x_OwNkSNSocCW
!gdown --id 1vMVSmtdUzIt7MfDnNzB5CO8u8Xl19k5s
!gdown --id 1GpBX7pkmUHu6JP63WH9FcvHp9Wxz7ZUi
'''


# bio_sampling_freq: 1 -> 4 -> 8 -> 16 -> 24 -> 32 -> 40 -> 60 -> 100
class EEGDataset(Dataset):
    """Windowed multi-channel EEG dataset with depth-dependent (progressive) resolution.

    Recordings are read either from MATLAB ``*.mat`` files (variable
    ``eeg_signal``, ``(channels, samples)``) or from per-channel text files
    ``<name>_<k>.txt``. Each recording is resampled from ``data_sampling_freq``
    to ``end_sampling_freq``, min-max normalized and cut into overlapping
    windows of ``seq_len`` samples, spaced ``stride`` samples apart.

    ``model_depth`` selects the resolution returned by ``__getitem__``: each
    window is downsampled through the last ``max_dataset_depth - model_depth``
    stages of ``progression_scale_up`` / ``progression_scale_down``. An
    ``alpha`` below 1 blends in a lower-resolution version of the window.
    """
    # for 60(sampling), starting from 1 hz(sampling) [32 samples at the beginning]
    progression_scale_up = [4, 2, 2, 3, 4, 5, 3]
    progression_scale_down = [1, 1, 1, 2, 3, 4, 2]

    # for 60(sampling), starting from 0.25 hz(sampling) [8 samples at the beginning]
    # progression_scale_up   = [2, 2] + progression_scale_up
    # progression_scale_down = [1, 1] + progression_scale_down

    # for 100(sampling), starting from 1 hz(sampling) [32 samples at the beginning]
    # progression_scale_up   = progression_scale_up + [5]
    # progression_scale_down = progression_scale_down + [3]

    # for 100(sampling), starting from 0.25 hz(sampling) [8 samples at the beginning]
    # progression_scale_up   = [2, 2] + progression_scale_up + [5]
    # progression_scale_down = [1, 1] + progression_scale_down + [3]

    # None = use the first ``num_channels`` channels; otherwise the listed (0-based) channel ids
    picked_channels = None

    # picked_channels = [3, 5, 9, 15, 16]
    # picked_channels = [3, 5, 9, 12, 13, 14, 15, 16]
    # picked_channels = [3, 5, 8, 9, 10, 13, 15, 16]

    def __init__(self, train_files, norms, given_data, validation_ratio: float = 0.1, dir_path: str = './',
                 data_sampling_freq: float = 80, start_sampling_freq: float = 1, end_sampling_freq: float = 60,
                 start_seq_len: int = 32, stride: float = 0.5, num_channels: int = 5, number_of_files: int = 100000,
                 per_user_normalization: bool = True, per_channel_normalization: bool = False):
        """Build the dataset from raw files, or from previously cached arrays.

        Args:
            train_files: indices (into the sorted file list) of the training
                files, or ``None`` to build the training split: a random
                ``1 - validation_ratio`` share of the files. When given, this
                instance uses all *other* files.
            norms: ``(max, min)`` for global normalization, or ``None`` to
                compute them from this split.
            given_data: ``(metadata_dict, arrays)`` from a cache. When not
                ``None``, nothing is read from ``dir_path``.
            stride: window spacing as a fraction of ``seq_len``.
            per_user_normalization: normalize each recording with its own
                statistics (otherwise with shared statistics, see ``normalize_all``).
            per_channel_normalization: use one max/min per channel.

        Raises:
            ValueError: if ``stride`` rounds down to 0 samples.
        """
        super().__init__()
        self.model_depth = 0
        self.alpha = 1.0
        self.dir_path = dir_path
        self.end_sampling_freq = end_sampling_freq
        seq_len = start_seq_len * end_sampling_freq / start_sampling_freq
        assert seq_len == int(seq_len), 'seq_len must be an int'
        seq_len = int(seq_len)
        self.seq_len = seq_len
        self.initial_kernel_size = start_seq_len
        self.stride = int(seq_len * stride)
        if self.stride <= 0:
            raise ValueError('stride * seq_len must be at least one sample (stride={}, seq_len={})'
                             .format(stride, seq_len))
        self.per_user_normalization = per_user_normalization
        self.per_channel_normalization = per_channel_normalization
        self.max_dataset_depth = len(self.progression_scale_up)
        self.norms = norms
        self.num_channels = num_channels if self.picked_channels is None else len(self.picked_channels)
        if given_data is not None:
            meta, arrays = given_data[0], given_data[1]
            self.sizes = meta['sizes']
            self.files = meta['files']
            self.norms = meta['norms']
            self.data_pointers = meta['pointers']
            self.datas = [arrays['arr_{}'.format(i)] for i in range(len(arrays.keys()))]
            return
        # sorted: the split is stored as indices into this list, so its order must be
        # reproducible (glob order depends on the filesystem)
        all_files = sorted(glob.glob(os.path.join(dir_path, '*_1.txt')))[:number_of_files]
        is_matlab = len(all_files) == 0
        if is_matlab:
            all_files = sorted(glob.glob(os.path.join(dir_path, '*.mat')))[:number_of_files]
        files = list(range(len(all_files)))
        if train_files is None:
            shuffle(files)
            files = files[:int(len(all_files) * (1.0 - validation_ratio))]
        else:
            files = sorted(set(files) - set(train_files))  # deterministic order
        self.files = files
        sizes = []
        self.datas = []
        load = self._load_matlab if is_matlab else self._load_text
        for i in tqdm(files):
            loaded = load(all_files[i], data_sampling_freq, end_sampling_freq, num_channels)
            if loaded is None:
                continue  # unreadable, or too short for a single window
            data, size = loaded
            if per_user_normalization:
                data, is_ok = self.normalize(data, self.per_channel_normalization)
                if not is_ok:
                    continue  # constant recording/channel: cannot be min-max normalized
            self.datas.append(data)
            sizes.append(size)
        self.sizes = sizes
        self.data_pointers = [(i, j) for i, s in enumerate(self.sizes) for j in range(s)]
        if not per_user_normalization:
            self.normalize_all()

    def _window_count(self, length):
        """Number of ``seq_len`` windows, spaced ``self.stride`` apart, that fit in ``length`` samples."""
        return int(np.ceil((length - self.seq_len + 1) / self.stride))

    def _covered_length(self, num_windows):
        """Number of samples spanned by ``num_windows`` windows; recordings are trimmed to it."""
        return (num_windows - 1) * self.stride + self.seq_len

    def _load_matlab(self, path, data_sampling_freq, end_sampling_freq, num_channels):
        """Load one ``.mat`` recording. Return ``(data, num_windows)`` or ``None`` if unusable."""
        try:
            signal = resample_signal(loadmat(path)['eeg_signal'], data_sampling_freq, end_sampling_freq)
            size = self._window_count(signal.shape[1])
        except Exception:  # unreadable / malformed file: skipped (no longer a bare except)
            return None
        if size <= 0:
            return None
        rows = slice(None, num_channels) if self.picked_channels is None else self.picked_channels
        return signal[rows, :self._covered_length(size)], size

    def _load_text(self, first_channel_path, data_sampling_freq, end_sampling_freq, num_channels):
        """Load one recording stored as ``<name>_<k>.txt`` per channel (``k`` starts at 1).

        Returns ``(data, num_windows)``, with ``data`` a float32
        ``(channels, samples)`` array. Returns ``None`` if the first channel is
        too short for one window.
        """
        channel_ids = list(range(num_channels)) if self.picked_channels is None else list(self.picked_channels)
        prefix = first_channel_path[:-len('_1.txt')]
        data, size = None, 0
        for row, channel in enumerate(channel_ids):
            with open('{}_{}.txt'.format(prefix, channel + 1)) as f:
                signal = np.array(list(map(float, f.read().split())), dtype=np.float32)
            signal = resample_signal(signal, data_sampling_freq, end_sampling_freq)
            if data is None:
                # the first channel fixes the number of windows for the whole recording
                size = self._window_count(len(signal))
                if size <= 0:
                    return None
                data = np.zeros((len(channel_ids), self._covered_length(size)), dtype=np.float32)
            # index by row, not channel id: picked_channels such as [3, 15] used to index
            # past the end of an array that only had num_channels rows
            data[row, :] = signal[:data.shape[1]]
        if data is None:
            return None
        return data, size

    @classmethod
    def from_config(cls, validation_ratio: float, dir_path: str, number_of_files: int,
                    data_sampling_freq: float, start_sampling_freq: float, end_sampling_freq: float,
                    start_seq_len: int, stride: float, num_channels: int,
                    per_user_normalization: bool, per_channel_normalization: bool):
        """Build ``(train_dataset, val_dataset)``, reusing on-disk caches in ``dir_path`` when present.

        Each split is cached as ``<name>.npz`` (the arrays) plus ``<name>.npz.pkl``
        (sizes, pointers, norms, file indices). The name encodes the
        configuration and ``DATASET_VERSION``. The validation split uses the
        files and normalization statistics of the training split. With
        ``validation_ratio == 0`` the validation cache is always rebuilt.
        """
        assert end_sampling_freq <= data_sampling_freq
        mode = per_user_normalization * 2 + per_channel_normalization * 1
        train_files = None
        train_norms = None
        datasets = []
        for split in ('train', 'val'):
            target_location = os.path.join(dir_path, '{}%_{}c_{}m_{}s_{}v_{}ss_{}es_{}l_{}n_{}.npz'
                                           .format(validation_ratio, num_channels, mode, stride,
                                                   DATASET_VERSION, start_sampling_freq, end_sampling_freq,
                                                   start_seq_len, number_of_files, split))
            given_data = None
            reuse_cache = (os.path.exists(target_location) and os.path.exists(target_location + '.pkl')
                           and not (split == 'val' and validation_ratio == 0.0))
            if reuse_cache:
                print('loading {} dataset from file'.format(split))
                # read every array eagerly and close the archive (NpzFile keeps its file open otherwise)
                with np.load(target_location) as npz:
                    arrays = {key: npz[key] for key in npz.files}
                given_data = (load_pkl(target_location + '.pkl'), arrays)
            else:
                print('creating {} dataset from scratch'.format(split))
            dataset = cls(train_files, train_norms, given_data, validation_ratio, dir_path, data_sampling_freq,
                          start_sampling_freq, end_sampling_freq, start_seq_len, stride, num_channels,
                          number_of_files, per_user_normalization, per_channel_normalization)
            if train_files is None:
                train_files = dataset.files
                train_norms = dataset.norms
            if given_data is None:
                np.savez_compressed(target_location, *dataset.datas)
                save_pkl(target_location + '.pkl', {'sizes': dataset.sizes, 'pointers': dataset.data_pointers,
                                                    'norms': dataset.norms, 'files': dataset.files})
            datasets.append(dataset)
        return datasets[0], datasets[1]

    def normalize_all(self):
        """Min-max normalize every recording with shared statistics.

        Uses ``self.norms = (max, min)`` when set (e.g. the training split's
        statistics for validation). Otherwise the statistics are computed over
        all recordings (per channel if ``per_channel_normalization``) and
        stored in ``self.norms``.

        Raises:
            ValueError: if the shared max equals the shared min (constant data).
        """
        axis = 1 if self.per_channel_normalization else None
        if self.norms is None:
            all_max = np.max(np.array([data.max(axis=axis) for data in self.datas]), axis=0)
            all_min = np.min(np.array([data.min(axis=axis) for data in self.datas]), axis=0)
            self.norms = (all_max, all_min)
        else:
            all_max, all_min = self.norms
        all_ok = True
        for i, data in enumerate(self.datas):
            self.datas[i], is_ok = self.normalize(data, self.per_channel_normalization, all_max, all_min)
            all_ok = all_ok and is_ok  # accumulate instead of keeping only the last file's flag
        if not all_ok:
            raise ValueError('data is constant!')

    @staticmethod
    def normalize(arr, per_channel, arr_max=None, arr_min=None):
        """Min-max scale a ``(channels, time)`` array into ``[-1, 1]``.

        Args:
            arr: 2-D array ``(channels, time)``.
            per_channel: use one max/min per row instead of a single scalar.
            arr_max, arr_min: statistics to use instead of computing them from
                ``arr`` (scalars, or one value per channel when ``per_channel``).

        Returns:
            ``(scaled, is_ok)``. ``is_ok`` is false if max == min (for any
            channel); the array is then only shifted, not scaled.
        """
        axis = 1 if per_channel else None
        if arr_max is None:
            arr_max = arr.max(axis=axis)
        if arr_min is None:
            arr_min = arr.min(axis=axis)
        if per_channel:
            # (C,) -> (C, 1): statistics must broadcast along time, not across the last axis
            arr_max = np.reshape(arr_max, (-1, 1))
            arr_min = np.reshape(arr_min, (-1, 1))
        is_ok = bool(np.all(arr_max != arr_min))
        scale = (arr_max - arr_min) if is_ok else 1.0
        return ((arr - arr_min) / scale) * 2.0 - 1.0, is_ok

    @property
    def shape(self):
        """``(num_windows, num_channels, seq_len)`` at full resolution."""
        return len(self), self.num_channels, self.seq_len

    def __len__(self):
        return len(self.data_pointers)

    def load_file(self, item):
        """Return the full-resolution ``(channels, seq_len)`` window number ``item``."""
        i, k = self.data_pointers[item]
        res = self.datas[i][:, k * self.stride:k * self.stride + self.seq_len]
        return res

    def resample_data(self, data, index, forward=True, alpha_fade=False):
        """Apply progression stage ``index`` (``index - 1`` when ``alpha_fade``) to a ``(1, C, T)`` tensor.

        ``forward`` resamples by ``up / down`` (towards the finer resolution),
        otherwise by ``down / up`` (towards the coarser one).
        """
        up_scale = self.progression_scale_up[index - (1 if alpha_fade else 0)]
        down_scale = self.progression_scale_down[index - (1 if alpha_fade else 0)]
        if forward:
            return resample_signal(data, down_scale, up_scale, True)
        return resample_signal(data, up_scale, down_scale, True)

    def __getitem__(self, item):
        """Return ``{'x': window}``, a float32 ``(C, T)`` tensor at the current ``model_depth``."""
        with torch.no_grad():
            datapoint = torch.from_numpy(self.load_file(item).astype(np.float32)).unsqueeze(0)
            target_depth = self.model_depth
            if self.max_dataset_depth != target_depth:
                datapoint = self.create_datapoint_from_depth(datapoint, target_depth)
        return {'x': self.alpha_fade(datapoint).squeeze(0)}

    def create_datapoint_from_depth(self, datapoint, target_depth):
        """Downsample a full-resolution window through the last ``max_dataset_depth - target_depth`` stages."""
        depth_diff = (self.max_dataset_depth - target_depth)
        for index in reversed(list(range(len(self.progression_scale_up)))[-depth_diff:]):
            datapoint = self.resample_data(datapoint, index, False)
        return datapoint

    def alpha_fade(self, datapoint):
        """Blend ``datapoint`` with its down-then-up resampled version, weighted by ``1 - alpha``.

        NOTE(review): with ``model_depth == 0`` the stage index ``-1`` selects the
        last progression stage; presumably alpha < 1 only occurs for depth >= 1.
        """
        if self.alpha == 1:
            return datapoint
        t = self.resample_data(datapoint, self.model_depth, False, alpha_fade=True)
        t = self.resample_data(t, self.model_depth, True, alpha_fade=True)
        return datapoint + (t - datapoint) * (1 - self.alpha)


def get_collate_real(max_sampling_freq, max_len):
    """Return a DataLoader ``collate_fn`` for real samples.

    The returned function stacks the batch with ``default_collate`` and moves
    it to CUDA through ``cudize``. ``max_sampling_freq`` and ``max_len`` are
    accepted for interface compatibility but are not used.
    """
    def collate_real(batch):
        collated = default_collate(batch)
        return cudize(collated)

    return collate_real


def get_collate_fake(latent_size, z_distribution, collate_real):
    """Return a ``collate_fn`` that turns a real batch into generator inputs.

    The batch is collated with ``collate_real``, which keeps any conditioning
    entries. The ``'x'`` samples are removed, and ``'z'`` is set to fresh
    latents from ``random_latents``, one row per sample.
    """
    def collate_fake(batch):
        collated = collate_real(batch)  # extract condition(features)
        samples = collated.pop('x')
        collated['z'] = random_latents(samples.size(0), latent_size, z_distribution)
        return collated

    return collate_fake


In [0]:
#layers.py
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import calculate_gain
from torch.nn.utils import spectral_norm


class PixelNorm(nn.Module):
    """Module wrapper around ``pixel_norm``: rescales inputs to unit RMS along dimension 1."""

    def forward(self, x):
        normalized = pixel_norm(x)
        return normalized


class ScaledTanh(nn.Tanh):
    """Compute ``tanh(scale * x)``: a tanh whose input is first multiplied by ``scale``."""

    def __init__(self, scale=0.5):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        scaled_input = self.scale * x
        return super().forward(scaled_input)


class GDropLayer(nn.Module):
    """Multiplicative noise layer (generalized dropout).

    In training mode, ``x`` is multiplied by ``(1 + strength) ** n`` with
    ``n ~ N(0, 1)`` drawn from NumPy's global RNG. One noise value is drawn per
    index along the dimensions in ``axes`` and broadcast over all other
    dimensions. The layer is the identity when ``strength`` is 0, when
    ``deterministic=True``, or in eval mode.
    """

    def __init__(self, strength=0.2, axes=(0, 1)):
        super().__init__()
        self.strength = strength
        self.axes = [axes] if isinstance(axes, int) else list(axes)

    def forward(self, x, deterministic=False):
        # ``not self.training``: like nn.Dropout, only perturb activations while
        # training. Previously the noise was applied in eval mode as well.
        if deterministic or not self.training or not self.strength:
            return x
        rnd_shape = [size if axis in self.axes else 1 for axis, size in enumerate(x.size())]
        rnd = (1 + self.strength) ** np.random.normal(size=rnd_shape)
        # .to(x) converts to x's dtype and device in one step
        return x * torch.from_numpy(rnd).to(x)


class SelfAttention(nn.Module):
    """Self-attention over the time axis of a ``(B, C, T)`` tensor.

    Queries, keys and values come from 1x1 convolutions. Keys and values are
    max-pooled by 4 along time, so ``T`` should be at least 4. The attended
    values are projected back to ``C`` channels and added to the input, scaled
    by the learned gate ``gamma`` (initialised to 0). ``forward`` returns
    ``(output, attention_map)``.
    """

    def __init__(self, channels_in, spectral, init='xavier_uniform'):
        super().__init__()
        key_channels = max(channels_in // 8, 2)
        half_channels = channels_in // 2
        conv_conf = dict(kernel_size=1, equalized=False, spectral=spectral,
                         init=init, bias=False, act_alpha=-1)
        # registration/creation order kept: it fixes parameter order and RNG use at init
        self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)
        self.pooling = nn.MaxPool1d(4)
        self.key_conv = GeneralConv(channels_in, key_channels, **conv_conf)
        self.query_conv = GeneralConv(channels_in, key_channels, **conv_conf)
        self.value_conv = GeneralConv(channels_in, half_channels, **conv_conf)
        self.final_conv = GeneralConv(half_channels, channels_in, **conv_conf)
        self.softmax = nn.Softmax(dim=-1)
        self.scale = 1.0

    def forward(self, x):  # x: (B, C, T)
        q = self.query_conv(x)  # (B, Ck, T)
        k = self.pooling(self.key_conv(x))  # (B, Ck, T/4)
        v = self.pooling(self.value_conv(x))  # (B, C/2, T/4)
        scores = k.transpose(1, 2).bmm(q) / self.scale  # (B, T/4, T)
        attention_map = F.softmax(scores, dim=1)  # normalized over the pooled key positions
        attended = self.final_conv(v.bmm(attention_map))  # (B, C, T)
        return self.gamma * attended + x, attention_map


class MinibatchStddev(nn.Module):
    """Append a minibatch standard-deviation channel to a ``(B, C, T)`` tensor.

    The batch is viewed as ``G = min(B, group_size)`` groups. ``B`` must be
    divisible by ``G``; ``group_size=0`` means "one group for the whole batch".
    Statistics are taken across samples whose batch indices are equal modulo
    ``B // G``. For every window of ``stride_size`` time steps, the per-feature
    standard deviation across the group is averaged to one scalar per sample
    slot and broadcast over the window. The result is concatenated as one extra
    channel, giving an output of shape ``(B, C + 1, T)``.
    """

    def __init__(self, group_size=4, temporal_groups_per_window=1, kernel_size=32):
        super().__init__()
        self.group_size = group_size if group_size != 0 else 1e6
        self.kernel_size = kernel_size
        self.stride_size = self.kernel_size // temporal_groups_per_window

    def forward(self, x):  # B, C, T
        s = x.size()
        group_size = min(s[0], self.group_size)
        window = self.stride_size
        all_y = []
        # range(0, T, window) also covers a final partial window. Previously a T not
        # divisible by the window left the stats channel shorter than x, and torch.cat failed.
        for start in range(0, s[2], window):
            y = x[..., start:start + window]
            width = y.size(2)
            y = y.view(group_size, -1, s[1], width)  # G,B//G,C,W
            y = y - y.mean(dim=0, keepdim=True)  # G,B//G,C,W
            # epsilon keeps sqrt differentiable (finite first/second-order grads, as
            # needed by gradient penalties) when a group's variance is exactly zero
            y = torch.sqrt((y ** 2).mean(dim=0) + 1e-8)  # B//G,C,W
            y = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # B//G,1,1
            y = y.repeat((group_size, 1, width))  # B,1,W
            all_y.append(y)
        return torch.cat([x, torch.cat(all_y, dim=2)], dim=1)


class ConditionalBatchNorm(nn.Module):
    """Batch norm whose affine parameters may be predicted from a condition.

    The mode is chosen from the sizes passed in:
      * ``'BN'``:  ``latent_size == num_classes == 0``; plain affine ``BatchNorm1d``.
      * ``'CBN'``: only classes; gamma/beta come from ``y`` through a 1x1 conv.
      * ``'SM'``:  only latents; gamma/beta come from ``z`` through two 1x1 convs
        (ReLU in between).
      * ``'CSM'``: both; ``z`` is tiled along ``y``'s time axis and concatenated
        with ``y`` before the two 1x1 convs.
    The predicted gamma/beta are resampled to the length of ``x`` and applied as
    ``(1 + gamma) * x_hat + beta``.
    """

    def __init__(self, num_features, num_classes, latent_size, spectral):
        super().__init__()
        no_cond = latent_size == 0 and num_classes == 0
        self.num_features = num_features
        self.normalizer = nn.BatchNorm1d(num_features, affine=no_cond)
        if no_cond:
            self.mode = 'BN'  # batch norm
        elif latent_size == 0:
            self.embed = GeneralConv(num_classes, num_features * 2, kernel_size=1, equalized=False,
                                     act_alpha=-1, spectral=spectral, bias=False)
            self.mode = 'CBN'  # conditional batch norm
        else:  # both 'SM'(self modulation) and 'CSM'(conditional self modulation)
            # NOTE maybe reduce it to a single layer linear network?
            self.embed = nn.Sequential(
                GeneralConv(latent_size + num_classes, num_features * 2, kernel_size=1,
                            equalized=False, act_alpha=0.0, spectral=spectral, bias=True),
                GeneralConv(num_features * 2, num_features * 2, kernel_size=1,
                            equalized=False, act_alpha=-1, spectral=spectral, bias=False))
            if num_classes == 0:
                self.mode = 'SM'  # self modulation
            else:
                self.mode = 'CSM'  # conditional self modulation(biggan)

    def forward(self, x, y, z):  # y = B*num_classes*Ty ; x = B*num_features*Tx ; z = B*latent_size
        out = self.normalizer(x)
        if self.mode == 'BN':
            return out
        if y is not None and y.ndimension() == 2:
            y = y.unsqueeze(2)
        if z is not None and z.dim() == 2:
            # (B, latent) -> (B, latent, 1): latents must sit on the channel axis.
            # expand3d alone would prepend an axis, giving (1, B, latent), which does not
            # match the embed conv's latent_size (+ num_classes) input channels.
            z = z.unsqueeze(2)
        if self.mode == 'CBN':
            cond = y
        elif self.mode == 'CSM':
            z = expand3d(z)
            cond = torch.cat([resample_signal(z, z.size(2), y.size(2), pytorch=True), y], dim=1)
        else:
            cond = expand3d(z)
        embed = self.embed(cond)  # B, num_features*2, Ty
        embed = resample_signal(embed, embed.shape[2], out.shape[2], pytorch=True)
        gamma, beta = embed.chunk(2, dim=1)
        # (1 + gamma) * out + beta: the effective scale is centred on 1 rather than 0 at init
        return out + gamma * out + beta


class EqualizedSeparableConv1d(nn.Module):
    """Depthwise-separable 1-D convolution built from two ``EqualizedConv1d`` layers.

    The first layer is a depthwise ``kernel_size`` convolution: one filter per
    input channel, with bias, ``act_alpha=-1`` and the given ``stride``. The
    second is a pointwise 1x1 convolution to ``out_channels`` that uses the
    given ``act_alpha``, ``bias`` and ``groups``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 padding, spectral, equalized, init, act_alpha, bias, groups, stride):
        super().__init__()
        depthwise = EqualizedConv1d(in_channels, in_channels, kernel_size, padding=padding,
                                    spectral=spectral, equalized=equalized, init=init, act_alpha=-1,
                                    bias=True, groups=in_channels, stride=stride)
        pointwise = EqualizedConv1d(in_channels, out_channels, 1, padding=0,
                                    spectral=spectral, equalized=equalized, init=init, act_alpha=act_alpha,
                                    bias=bias, groups=groups, stride=1)
        self.net = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        depthwise, pointwise = self.net
        return pointwise(depthwise(x))


class EqualizedConv1d(nn.Module):
    """``nn.Conv1d`` with a selectable init, optional runtime weight scaling and spectral norm.

    Args:
        init: ``'kaiming_normal'``, ``'xavier_uniform'`` or ``'orthogonal'``. Any
            other value keeps ``nn.Conv1d``'s default initialization.
        act_alpha: describes the activation that follows this conv and sets the
            init gain: ``0`` = ReLU, ``> 0`` = LeakyReLU negative slope,
            ``< 0`` = no activation (gain 1).
        equalized: if true, the weights are stored divided by their initial RMS
            (``self.scale``) and the input is multiplied by ``self.scale`` in
            ``forward``, so the function computed at init is unchanged.
        spectral: wrap the convolution with ``spectral_norm``.
        Remaining arguments are forwarded to ``nn.Conv1d``; the bias (if any)
        starts at zero.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding,
                 spectral, equalized, init, act_alpha, bias, groups, stride):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, stride=stride,
                              kernel_size=kernel_size, padding=padding, bias=bias, groups=groups)
        if bias:
            self.conv.bias.data.zero_()
        # Slope fed to the leaky_relu gain formula sqrt(2 / (1 + slope**2)):
        # "no activation" (act_alpha < 0) uses slope 1 -> gain 1. ReLU (act_alpha == 0)
        # must keep slope 0 -> gain sqrt(2); the former `> 0` test gave ReLU the linear gain.
        gain_slope = act_alpha if act_alpha >= 0 else 1
        if init == 'kaiming_normal':
            torch.nn.init.kaiming_normal_(self.conv.weight, a=gain_slope)
        elif init == 'xavier_uniform':
            torch.nn.init.xavier_uniform_(self.conv.weight, gain=calculate_gain('leaky_relu', param=gain_slope))
        elif init == 'orthogonal':
            torch.nn.init.orthogonal_(self.conv.weight, gain=calculate_gain('leaky_relu', param=gain_slope))
        if equalized:
            self.scale = ((torch.mean(self.conv.weight.data ** 2)) ** 0.5).item()
        else:
            self.scale = 1.0
        self.conv.weight.data.copy_(self.conv.weight.data / self.scale)
        if spectral:
            self.conv = spectral_norm(self.conv)

    def forward(self, x):
        return self.conv(x * self.scale)


class GeneralConv(nn.Module):
    """Convolution, then optional (conditional) batch norm, activation, pixel norm and GDrop.

    Pipeline:
      * conv: ``EqualizedSeparableConv1d`` if ``separable`` else ``EqualizedConv1d``.
        Padding defaults to ``(kernel_size - 1) // 2``; the bias is disabled when
        ``act_norm == 'batch'``.
      * ``act_norm == 'batch'``: ``ConditionalBatchNorm`` on the conv output,
        conditioned on ``y`` (``num_classes``) and/or ``z`` (``z_to_bn_size``).
      * activation: ReLU if ``act_alpha == 0``, LeakyReLU(``act_alpha``) if
        ``act_alpha > 0``, none if ``act_alpha < 0``.
      * ``act_norm == 'pixel'``: ``PixelNorm`` after the activation.
      * ``do != 0``: ``GDropLayer(strength=do)`` last.
    ``forward`` optionally multiplies the conv output by ``conv_noise``.
    """

    def __init__(self, in_channels, out_channels, z_to_bn_size=0, kernel_size=3, equalized=True,
                 pad=None, act_alpha=0.2, do=0, num_classes=0, act_norm=None, spectral=False,
                 init='kaiming_normal', bias=True, separable=False, stride=1):
        super().__init__()
        if pad is None:
            pad = (kernel_size - 1) // 2
        conv_class = EqualizedSeparableConv1d if separable else EqualizedConv1d
        use_batch_norm = act_norm == 'batch'
        conv = conv_class(in_channels, out_channels, kernel_size, padding=pad, spectral=spectral,
                          equalized=equalized, init=init, act_alpha=act_alpha,
                          bias=False if use_batch_norm else bias, groups=1, stride=stride)
        norm = ConditionalBatchNorm(out_channels, num_classes, z_to_bn_size, spectral) if use_batch_norm else None
        self.conv = conv
        self.norm = norm
        layers = []
        # DO NOT use inplace activations, gradient penalty will break
        if act_alpha == 0:
            layers.append(nn.ReLU())
        elif act_alpha > 0:
            layers.append(nn.LeakyReLU(act_alpha))
        if act_norm == 'pixel':
            layers.append(PixelNorm())
        if do != 0:
            layers.append(GDropLayer(strength=do))
        self.net = nn.Sequential(*layers)

    def forward(self, x, y=None, z=None, conv_noise=None):
        h = self.conv(x)
        if conv_noise is not None:
            h = h * conv_noise
        if self.norm is not None:
            h = self.norm(h, y, z)
        return self.net(h)


class PassChannelResidual(nn.Module):
    """Residual addition between tensors whose channel counts (dim 1) may differ.

    If ``x`` has at least as many channels as ``y``, ``y`` is added to the first
    ``y.size(1)`` channels of ``x`` and the remaining channels pass through.
    Otherwise only the first ``x.size(1)`` channels of ``y`` are added to ``x``.
    The result always has ``x``'s channel count; the inputs are not modified.
    """

    def forward(self, x, y):
        n = y.size(1)
        if x.size(1) >= n:
            # Out-of-place. The old in-place slice assignment mutated the caller's ``x``
            # and can break autograd when ``x`` is needed for backward (e.g. gradient penalty).
            return torch.cat([x[:, :n] + y, x[:, n:]], dim=1)
        return y[:, :x.size(1)] + x


class ConcatResidual(nn.Module):
    """Residual addition that widens ``x`` from ``ch_in`` to ``ch_out`` channels.

    The missing ``ch_out - ch_in`` channels come from a 1x1 ``GeneralConv``
    with no activation and are concatenated after ``x``. When
    ``ch_in == ch_out``, ``x`` is added unchanged.
    """

    def __init__(self, ch_in, ch_out, equalized, spectral, init):
        super().__init__()
        assert ch_out >= ch_in
        extra_channels = ch_out - ch_in
        self.net = (GeneralConv(ch_in, extra_channels, kernel_size=1, equalized=equalized, act_alpha=-1,
                                spectral=spectral, init=init)
                    if extra_channels else None)

    def forward(self, h, x):
        if not self.net:
            return h + x
        widened = torch.cat([x, self.net(x)], dim=1)
        return h + widened


In [0]:
#losses.py
import torch
import random
import torch.nn.functional as F
from torch.autograd import Variable, grad

# Lazily built tensor caches reused by the helpers below; each one is rebuilt
# whenever a different batch size is requested.
one = None  # ones: BCE targets and autograd grad_outputs (get_one)
zero = None  # zeros: BCE targets (get_zero)
mixing_factors = None  # U(0, 1) real/fake interpolation weights (get_mixing_factor)


def get_mixing_factor(batch_size):
    """Return a ``(batch_size, 1, 1)`` tensor refilled in place with fresh U(0, 1) samples.

    The buffer is cached in ``mixing_factors`` and moved through ``cudize``.
    The same tensor object is returned while the batch size is unchanged.
    """
    global mixing_factors
    needs_new_buffer = mixing_factors is None or mixing_factors.size(0) != batch_size
    if needs_new_buffer:
        mixing_factors = cudize(torch.FloatTensor(batch_size, 1, 1))
    return mixing_factors.uniform_()


def get_one(batch_size):
    """Return a cached length-``batch_size`` tensor of ones (passed through ``cudize``)."""
    global one
    if one is not None and one.size(0) == batch_size:
        return one
    one = cudize(torch.ones(batch_size))
    return one


def get_zero(batch_size):
    """Return a cached length-``batch_size`` tensor of zeros (passed through ``cudize``)."""
    global zero
    if zero is not None and zero.size(0) == batch_size:
        return zero
    zero = cudize(torch.zeros(batch_size))
    return zero


def calc_grad(x_hat, pred_hat):
    """Return d(pred_hat)/d(x_hat), keeping the graph so the result is itself differentiable.

    ``pred_hat`` must be 1-D ``(batch,)``: it is paired with a ones vector of
    that length as ``grad_outputs``.
    """
    # only_inputs is deprecated and ignored by torch.autograd.grad (always True), so it is omitted
    gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=get_one(pred_hat.size(0)),
                     create_graph=True, retain_graph=True)
    return gradients[0]


def generator_loss(dis: torch.nn.Module, gen: torch.nn.Module, real, z, loss_type: str,
                   random_multiply: bool, feature_matching_lambda: float = 0.0):
    """Compute the generator loss for one batch.

    ``gen(z)`` must return ``(fake, _)`` and ``dis(x)`` must return
    ``(score, features, _)``. The generator's gradients are zeroed first.

    Args:
        loss_type: ``'hinge'`` or any name starting with ``'wgan'`` gives
            ``-mean(D(fake))``. ``'rsgan'``, ``'rasgan'`` and ``'rahinge'`` select
            the relativistic losses; for those, ``D(real)`` is evaluated without
            gradients.
        random_multiply: scale the final loss by a ``random.random()`` factor.
        feature_matching_lambda: weight of the feature-matching term
            ``mean((batch_mean(f_real) - batch_mean(f_fake)) ** 2)``; 0 disables it.

    Raises:
        ValueError: for an unknown ``loss_type``.
    """
    gen.zero_grad()
    g_, _ = gen(z)
    d_fake, fake_features, _ = dis(g_)
    real_features = None
    scale = random.random() if random_multiply else 1.0
    if loss_type == 'hinge' or loss_type.startswith('wgan'):
        g_loss = -d_fake.mean()
    else:
        with torch.no_grad():
            d_real, real_features, _ = dis(real)
        if loss_type == 'rsgan':
            g_loss = F.binary_cross_entropy_with_logits(d_fake - d_real, get_one(d_fake.size(0)))
        elif loss_type == 'rasgan':
            batch_size = d_fake.size(0)
            g_loss = (F.binary_cross_entropy_with_logits(d_fake - d_real.mean(), get_one(batch_size)) +
                      F.binary_cross_entropy_with_logits(d_real - d_fake.mean(), get_zero(batch_size))) / 2.0
        elif loss_type == 'rahinge':
            g_loss = (torch.mean(F.relu(1.0 + (d_real - torch.mean(d_fake)))) +
                      torch.mean(F.relu(1.0 - (d_fake - torch.mean(d_real))))) / 2
        else:
            raise ValueError('Invalid loss type')
    if feature_matching_lambda != 0.0:
        if real_features is None:
            with torch.no_grad():
                _, real_features, _ = dis(real)
        diff = real_features.mean(dim=0) - fake_features.mean(dim=0)
        # weight the term; the lambda was previously only used as an on/off switch
        g_loss = g_loss + feature_matching_lambda * (diff * diff).mean()
    return g_loss * scale


def discriminator_loss(dis: torch.nn.Module, gen: torch.nn.Module, real, z, loss_type: str,
                       iwass_drift_epsilon: float, grad_lambda: float, iwass_target: float):
    """Compute the discriminator loss, plus an optional gradient penalty, for one batch.

    :param dis: discriminator; called as ``dis(x)`` and unpacked into (scores, _, _)
    :param gen: generator; evaluated under ``torch.no_grad()`` so it receives no gradients
    :param real: real batch; the gradient penalty reads ``real['x']`` and also interpolates
        keys starting with 'global_' or 'temporal_'
    :param z: latent input for ``gen``
    :param loss_type: 'hinge', 'rsgan', 'rasgan', 'rahinge', 'wgan_theirs' or 'wgan_gp'
    :param iwass_drift_epsilon: weight of the drift term of the WGAN variants
    :param grad_lambda: gradient-penalty weight; 0 disables the penalty
    :param iwass_target: target gradient norm of the penalty
    :return: scalar loss tensor
    :raises ValueError: for an unknown ``loss_type``
    """
    dis.zero_grad()
    d_real, _, _ = dis(real)
    with torch.no_grad():
        g_, _ = gen(z)
    d_fake, _, _ = dis(g_)
    batch_size = d_real.size(0)
    gp_gain = 1.0 if grad_lambda != 0 else 0
    if loss_type == 'hinge':
        d_loss = F.relu(1.0 - d_real).mean() + F.relu(1.0 + d_fake).mean()
    elif loss_type == 'rsgan':
        d_loss = F.binary_cross_entropy_with_logits(d_real - d_fake, get_one(batch_size))
    elif loss_type == 'rasgan':
        d_loss = (F.binary_cross_entropy_with_logits(d_real - d_fake.mean(), get_one(batch_size)) +
                  F.binary_cross_entropy_with_logits(d_fake - d_real.mean(), get_zero(batch_size))) / 2.0
    elif loss_type == 'rahinge':
        d_loss = (torch.mean(F.relu(1.0 - (d_real - d_fake.mean()))) +
                  torch.mean(F.relu(1.0 + (d_fake - d_real.mean())))) / 2
    elif loss_type.startswith('wgan'):  # only 'wgan_theirs' and 'wgan_gp' are accepted
        d_fake_mean = d_fake.mean()
        d_real_mean = d_real.mean()
        if loss_type == 'wgan_theirs':
            d_loss = d_fake_mean - d_real_mean + (d_fake_mean + d_real_mean) ** 2 * iwass_drift_epsilon
            # penalty is scaled by the (non-negative) real/fake score gap
            gp_gain = F.relu(d_real_mean - d_fake_mean)
        elif loss_type == 'wgan_gp':
            d_loss = d_fake_mean - d_real_mean + (d_real ** 2).mean() * iwass_drift_epsilon
            gp_gain = 1
        else:
            raise ValueError('Invalid loss type')
    else:
        raise ValueError('Invalid loss type')
    if gp_gain != 0 and grad_lambda != 0:
        alpha = get_mixing_factor(real['x'].size(0))
        # Interpolate real and fake inputs into fresh leaf tensors so the penalty gradient is
        # taken with respect to the interpolated input only.
        x_hat = {'x': (alpha * real['x'] + (1.0 - alpha) * g_['x']).detach().requires_grad_(True)}
        beta = alpha.squeeze(dim=2)
        for k in real.keys():
            if k.startswith('global_') or k.startswith('temporal_'):
                x_hat[k] = (beta * real[k] + (1.0 - beta) * g_[k]).detach().requires_grad_(True)
        pred_hat, _, _ = dis(x_hat)
        g = calc_grad(x_hat['x'], pred_hat).view(batch_size, -1)
        gp = g.norm(p=2, dim=1) - iwass_target
        if loss_type == 'wgan_theirs':
            gp = F.relu(gp)
        gp_loss = gp_gain * (gp ** 2).mean() * grad_lambda / (iwass_target ** 2)
        d_loss = d_loss + gp_loss
    return d_loss


In [0]:
#network.py
import torch
from torch import nn
from torch.nn.utils import spectral_norm


class GBlock(nn.Module):
    """One generator stage: two (four when ``deep``) ``GeneralConv`` layers plus a to-RGB head.

    :param ch_in: input channels
    :param ch_out: output channels
    :param ch_rgb: channels produced by the to-RGB head
    :param k_size: kernel size of the inner convolutions
    :param initial_kernel_size: if given, this is the first block; ``c1`` then uses this kernel
        size with padding ``initial_kernel_size - 1`` and no skip connection is built
    :param is_residual: add a skip connection (1x1 projection when ch_in != ch_out)
    :param no_tanh: if False the to-RGB head is followed by ``ScaledTanh``
    :param deep: four convolutions with a ``ch_in // 4`` hidden size and ``PassChannelResidual``
    :param per_channel_noise: learn a zero-initialised per-channel noise scale for each conv,
        passed to it as ``conv_noise``
    :param to_rgb_mode: 'pggan', 'sngan', 'sagan' or 'biggan'
    :param layer_settings: forwarded to every inner ``GeneralConv``; must contain
        'equalized', 'spectral' and 'init'
    :raises ValueError: for an unknown ``to_rgb_mode``
    """

    def __init__(self, ch_in, ch_out, ch_rgb, k_size=3, initial_kernel_size=None, is_residual=False,
                 no_tanh=False, deep=False, per_channel_noise=False, to_rgb_mode='pggan', **layer_settings):
        super().__init__()
        is_first = initial_kernel_size is not None
        first_k_size = initial_kernel_size if is_first else k_size
        hidden_size = (ch_in // 4) if deep else ch_out
        self.c1 = GeneralConv(ch_in, hidden_size, kernel_size=first_k_size,
                              pad=initial_kernel_size - 1 if is_first else None, **layer_settings)
        self.c2 = GeneralConv(hidden_size, hidden_size, kernel_size=k_size, **layer_settings)
        if per_channel_noise:
            self.c1_noise_weight = nn.Parameter(torch.zeros(1, hidden_size, 1))
            self.c2_noise_weight = nn.Parameter(torch.zeros(1, hidden_size, 1))
        else:
            self.c1_noise_weight, self.c2_noise_weight = None, None
        if deep:
            self.c3 = GeneralConv(hidden_size, hidden_size, kernel_size=k_size, **layer_settings)
            self.c4 = GeneralConv(hidden_size, ch_out, kernel_size=k_size, **layer_settings)
            if per_channel_noise:
                self.c3_noise_weight = nn.Parameter(torch.zeros(1, hidden_size, 1))
                self.c4_noise_weight = nn.Parameter(torch.zeros(1, ch_out, 1))
            else:
                self.c3_noise_weight, self.c4_noise_weight = None, None
        # Settings for the auxiliary (to-RGB / skip) convs. Each key comes from its own entry;
        # previously 'spectral' and 'init' were both copied from 'equalized'.
        reduced_layer_settings = dict(equalized=layer_settings['equalized'], spectral=layer_settings['spectral'],
                                      init=layer_settings['init'])
        if to_rgb_mode == 'pggan':
            to_rgb = GeneralConv(ch_out, ch_rgb, kernel_size=1, act_alpha=-1, **reduced_layer_settings)
        elif to_rgb_mode in {'sngan', 'sagan'}:
            to_rgb = GeneralConv(ch_out, ch_rgb if to_rgb_mode == 'sngan' else ch_out,
                                 kernel_size=3, act_alpha=0.2, **reduced_layer_settings)
            if to_rgb_mode == 'sagan':
                # The 3x3 conv keeps ch_out channels, so the linear 1x1 projection to ch_rgb must
                # follow it (the reverse order fed ch_rgb channels into a conv expecting ch_out).
                to_rgb = nn.Sequential(
                    to_rgb, GeneralConv(ch_out, ch_rgb, kernel_size=1, act_alpha=-1, **reduced_layer_settings))
        elif to_rgb_mode == 'biggan':
            to_rgb = nn.Sequential(nn.BatchNorm1d(ch_out), nn.ReLU(),
                                   GeneralConv(ch_out, ch_rgb, kernel_size=3, act_alpha=-1, **reduced_layer_settings))
        else:
            raise ValueError()
        if no_tanh:
            self.toRGB = to_rgb
        else:
            self.toRGB = nn.Sequential(to_rgb, ScaledTanh())
        if deep:
            self.residual = PassChannelResidual()
        else:
            if not is_first and is_residual:
                self.residual = nn.Sequential() if ch_in == ch_out else \
                    GeneralConv(ch_in, ch_out, 1, act_alpha=-1, **reduced_layer_settings)
            else:
                self.residual = None
        self.deep = deep

    @staticmethod
    def get_per_channel_noise(noise_weight):
        """Return standard-normal noise scaled by ``noise_weight``, or None when it is None.

        ``randn_like`` samples on the weight's own device and dtype; ``torch.randn`` always
        produced a CPU tensor, which fails against CUDA weights.
        """
        return None if noise_weight is None else torch.randn_like(noise_weight) * noise_weight

    def forward(self, x, y=None, z=None, last=False):
        """Run the convolutions (and skip connection); return ``toRGB(h)`` when ``last`` else ``h``."""
        h = self.c1(x, y=y, z=z, conv_noise=self.get_per_channel_noise(self.c1_noise_weight))
        h = self.c2(h, y=y, z=z, conv_noise=self.get_per_channel_noise(self.c2_noise_weight))
        if self.deep:
            h = self.c3(h, y=y, z=z, conv_noise=self.get_per_channel_noise(self.c3_noise_weight))
            h = self.c4(h, y=y, z=z, conv_noise=self.get_per_channel_noise(self.c4_noise_weight))
            h = self.residual(h, x)
        elif self.residual is not None:
            h = h + self.residual(x)
        if last:
            return self.toRGB(h)
        return h


class Generator(nn.Module):
    def __init__(self, initial_kernel_size, num_rgb_channels, fmap_base, fmap_max, fmap_min, kernel_size,
                 self_attention_layers, progression_scale_up, progression_scale_down, residual, separable,
                 equalized, init, act_alpha, num_classes, deep, z_distribution, spectral=False,
                 latent_size=256, no_tanh=False, per_channel_noise=False, to_rgb_mode='pggan', z_to_bn=False,
                 split_z=False, dropout=0.2, act_norm='pixel', conv_only=False, shared_embedding_size=32,
                 normalize_latents=True, rgb_generation_mode='pggan'):
        """
        Progressive generator: ``block0`` followed by ``len(progression_scale_up)`` ``GBlock`` stages.

        :param initial_kernel_size: int, this should be always correct regardless of conv_only
        :param num_rgb_channels: int
        :param fmap_base: int, stage s uses min(max(int(fmap_base / 2 ** s), fmap_min), fmap_max) channels
        :param fmap_max: int
        :param fmap_min: int
        :param kernel_size: int
        :param self_attention_layers: list[int]
        :param progression_scale_up: list[int]
        :param progression_scale_down: list[int], must have the same length as progression_scale_up
        :param residual: bool
        :param separable: bool
        :param equalized: bool
        :param spectral: bool
        :param init: 'kaiming_normal' or 'xavier_uniform' or 'orthogonal'
        :param act_alpha: float, 0 is relu, -1 is linear and 0.2 is recommended
        :param z_distribution: 'normal' or 'bernoulli' or 'censored'
        :param latent_size: int, or None to use the stage-0 channel count
        :param no_tanh: bool
        :param deep: bool, in case it's true it will turn off split_z
        :param per_channel_noise: bool
        :param to_rgb_mode: 'pggan' or 'sagan' or 'sngan' or 'biggan'
        :param z_to_bn: bool, whether to concatenate z with y(if available) to feed to cbn or not
        :param split_z: bool
        :param dropout: float
        :param num_classes: int, input y.shape == (batch_size, num_classes, T_y)
        :param act_norm: 'batch' or 'pixel' or None
        :param conv_only: bool
        :param shared_embedding_size: int, in case it's non-zero, y will be transformed to (batch_size, shared_embedding_size, T_y)
        :param normalize_latents: bool, apply pixel_norm to z before use
        :param rgb_generation_mode: 'residual' (resampled sum of every stage's RGB), 'mean' (that sum
            divided by the number of stages) or 'pggan' (last stage's RGB only)
        """
        super().__init__()
        R = len(progression_scale_up)
        assert len(progression_scale_up) == len(progression_scale_down)
        self.progression_scale_up = progression_scale_up
        self.progression_scale_down = progression_scale_down
        self.depth = 0
        self.alpha = 1.0
        if deep:
            split_z = False
        self.split_z = split_z
        self.z_distribution = z_distribution
        self.conv_only = conv_only
        self.initial_kernel_size = initial_kernel_size
        self.normalize_latents = normalize_latents
        self.z_to_bn = z_to_bn

        def nf(stage):
            return min(max(int(fmap_base / (2.0 ** stage)), fmap_min), fmap_max)

        if latent_size is None:
            latent_size = nf(0)
        self.input_latent_size = latent_size
        if num_classes != 0:
            if shared_embedding_size > 0:
                self.y_encoder = GeneralConv(num_classes, shared_embedding_size, kernel_size=1,
                                             equalized=False, act_alpha=act_alpha, spectral=False, bias=False)
                num_classes = shared_embedding_size
            else:
                self.y_encoder = nn.Sequential()
        else:
            self.y_encoder = None
        if split_z:
            latent_size //= R + 2  # we also give part of the z to the first layer
        self.latent_size = latent_size
        block_settings = dict(ch_rgb=num_rgb_channels, k_size=kernel_size, is_residual=residual, deep=deep,
                              no_tanh=no_tanh, per_channel_noise=per_channel_noise, to_rgb_mode=to_rgb_mode)
        layer_settings = dict(z_to_bn_size=latent_size if z_to_bn else 0, equalized=equalized, spectral=spectral,
                              init=init, act_alpha=act_alpha, do=dropout, num_classes=num_classes, act_norm=act_norm,
                              bias=True, separable=separable)
        self.block0 = GBlock(latent_size, nf(1), **block_settings, **layer_settings,
                             initial_kernel_size=None if conv_only else initial_kernel_size)
        dummy = []  # to make SA layers registered
        self.self_attention = dict()
        for layer in self_attention_layers:
            dummy.append(SelfAttention(nf(layer + 1), spectral, init))
            self.self_attention[layer] = dummy[-1]
        if len(dummy) != 0:
            self.dummy = nn.ModuleList(dummy)
        self.blocks = nn.ModuleList(
            [GBlock(nf(i + 1), nf(i + 2), **block_settings, **layer_settings) for i in range(R)])
        self.max_depth = len(self.blocks)
        self.deep = deep
        self.rgb_generation_mode = rgb_generation_mode

    def _split_z(self, l, z):
        """Return the z fed to layer ``l``'s conditioning (l == -1 is block0).

        None when ``z_to_bn`` is off, the full z when not splitting, otherwise
        chunk ``l + 2`` of width ``latent_size`` (chunk 0 is the block0 input).
        """
        if not self.z_to_bn:
            return None
        if self.split_z:
            return z[:, (2 + l) * self.latent_size:(3 + l) * self.latent_size]
        return z

    def _do_layer(self, l, h, y, z):
        """Apply optional self-attention, resample, then run ``blocks[l]``; return (h, attention map)."""
        if l in self.self_attention:
            h, attention_map = self.self_attention[l](h)
        else:
            attention_map = None
        h = resample_signal(h, self.progression_scale_down[l], self.progression_scale_up[l], True)
        return self.blocks[l](h, y, self._split_z(l, z), last=False), attention_map

    def _combine_rgbs(self, last_rgb, saved_rgbs):
        """Merge per-stage RGB outputs according to ``rgb_generation_mode``, fading by ``alpha``."""
        if self.rgb_generation_mode == 'residual':
            return_value = saved_rgbs[0]
            for rgb in saved_rgbs[1:]:
                return_value = resample_signal(return_value, return_value.size(2), rgb.size(2)) + rgb
            if self.alpha == 1.0:
                return return_value
            return return_value - (1.0 - self.alpha) * saved_rgbs[-1]
        elif self.rgb_generation_mode == 'mean':
            return_value = saved_rgbs[0]
            for rgb in saved_rgbs[1:]:
                return_value = resample_signal(return_value, return_value.size(2), rgb.size(2)) + rgb
            num_rgbs = len(saved_rgbs)
            return_value = return_value / num_rgbs
            if self.alpha == 1.0:
                return return_value
            if num_rgbs == 1:
                # No coarser stage to fade from: its contribution is zero, as in the 'residual'
                # branch. Dividing by num_rgbs - 1 == 0 here used to produce NaN outputs.
                return return_value * self.alpha
            previous_mean = (return_value * num_rgbs - saved_rgbs[-1]) / (num_rgbs - 1)
            return previous_mean * (1.0 - self.alpha) + return_value * self.alpha
        return last_rgb

    def _wrap_output(self, last_rgb, all_rgbs, y):
        return {'x': self._combine_rgbs(last_rgb, all_rgbs), 'y': y}

    def forward(self, z):
        """Generate a batch.

        :param z: latent tensor, a (z, y) tuple, or a dict with 'z' and optional 'y'
        :return: ({'x': generated signal, 'y': encoded y or None}, attention maps keyed by layer index)
        """
        if isinstance(z, dict):
            z, y = z['z'], z.get('y', None)
        elif isinstance(z, tuple):
            z, y = z
        else:
            y = None
        if y is not None:
            if y.ndimension() == 2:
                y = y.unsqueeze(2)
            if self.y_encoder is not None:
                y = self.y_encoder(y)
            else:
                y = None
        if self.normalize_latents:
            z = pixel_norm(z)
        if z.ndimension() == 2:
            z = z.unsqueeze(2)
        if self.split_z and not self.deep:
            h = z[:, :self.latent_size, :]
        else:
            h = z
        save_rgb = self.rgb_generation_mode != 'pggan'
        saved_rgbs = []
        if self.depth == 0:
            h = self.block0(h, y, self._split_z(-1, z), last=True)
            if save_rgb:
                saved_rgbs.append(h)
            return self._wrap_output(h, saved_rgbs, y), {}
        h = self.block0(h, y, self._split_z(-1, z))
        if save_rgb:
            saved_rgbs.append(self.block0.toRGB(h))
        all_attention_maps = {}
        for i in range(self.depth - 1):
            h, attention_map = self._do_layer(i, h, y, z)
            if save_rgb:
                saved_rgbs.append(self.blocks[i].toRGB(h))
            if attention_map is not None:
                all_attention_maps[i] = attention_map
        h = resample_signal(h, self.progression_scale_down[self.depth - 1], self.progression_scale_up[self.depth - 1],
                            True)
        ult = self.blocks[self.depth - 1](h, y, self._split_z(self.depth - 1, z), True)
        if save_rgb:
            saved_rgbs.append(ult)
        if self.alpha == 1.0:
            return self._wrap_output(ult, saved_rgbs, y), all_attention_maps
        # fade-in: previous stage's RGB head applied to its (resampled) features
        preult_rgb = self.blocks[self.depth - 2].toRGB(h) if self.depth > 1 else self.block0.toRGB(h)
        return self._wrap_output(preult_rgb * (1.0 - self.alpha) + ult * self.alpha, saved_rgbs, y), all_attention_maps


class DBlock(nn.Module):
    """One discriminator stage: a ``GeneralConv`` stack, optional skip connection, and a from-RGB conv.

    :param ch_in: input channels (also the output channels of ``fromRGB``)
    :param ch_out: output channels
    :param ch_rgb: channels of the raw input signal
    :param k_size: kernel size of the inner convolutions
    :param initial_kernel_size: if given, this is the final block: ``MinibatchStddev`` is prepended
        (adding one input channel) and, unless ``conv_disc``, the last conv uses this kernel size
        with no padding
    :param is_residual: add a skip connection (1x1 projection when ch_in != ch_out)
    :param deep: ``ch_out // 4`` hidden size, two extra convs and a ``ConcatResidual``
    :param group_size: forwarded to ``MinibatchStddev``
    :param temporal_groups_per_window: forwarded to ``MinibatchStddev``
    :param conv_disc: keep the final block convolutional (padded ``k_size`` conv)
    :param layer_settings: forwarded to every inner ``GeneralConv``; must contain
        'equalized', 'spectral', 'init' and 'act_alpha'
    """

    def __init__(self, ch_in, ch_out, ch_rgb, k_size=3, initial_kernel_size=None, is_residual=False,
                 deep=False, group_size=4, temporal_groups_per_window=1, conv_disc=False, **layer_settings):
        super().__init__()
        is_last = initial_kernel_size is not None
        self.net = []
        if is_last:
            self.net.append(MinibatchStddev(group_size, temporal_groups_per_window, initial_kernel_size))
        hidden_size = (ch_out // 4) if deep else ch_in
        self.net.append(
            GeneralConv(ch_in + (1 if is_last else 0), hidden_size, kernel_size=k_size, **layer_settings))
        if deep:
            self.net.append(
                GeneralConv(hidden_size, hidden_size, kernel_size=k_size, **layer_settings))
            self.net.append(
                GeneralConv(hidden_size, hidden_size, kernel_size=k_size, **layer_settings))
        is_linear_last = is_last and not conv_disc
        self.net.append(GeneralConv(hidden_size, ch_out, kernel_size=initial_kernel_size if is_linear_last else k_size,
                                    pad=0 if is_linear_last else None, **layer_settings))
        self.net = nn.Sequential(*self.net)
        # Settings for the auxiliary (from-RGB / skip) convs. Each key comes from its own entry;
        # previously 'spectral' and 'init' were both copied from 'equalized'.
        reduced_layer_settings = dict(equalized=layer_settings['equalized'], spectral=layer_settings['spectral'],
                                      init=layer_settings['init'])
        self.fromRGB = GeneralConv(ch_rgb, ch_in, kernel_size=1, act_alpha=layer_settings['act_alpha'],
                                   **reduced_layer_settings)
        if deep:
            self.residual = ConcatResidual(ch_in, ch_out, **reduced_layer_settings)
        else:
            if is_residual and (not is_last or conv_disc):
                self.residual = nn.Sequential() if ch_in == ch_out else GeneralConv(ch_in, ch_out, kernel_size=1,
                                                                                    act_alpha=-1,
                                                                                    **reduced_layer_settings)
            else:
                self.residual = None
        self.deep = deep

    def forward(self, x, first=False):
        """Apply the block; when ``first`` the input is the raw signal and goes through ``fromRGB``."""
        if first:
            x = self.fromRGB(x)
        h = self.net(x)
        if self.deep:
            return self.residual(h, x)
        # Compare against None: an identity ``nn.Sequential()`` has len 0 and is falsy, so a
        # truthiness test silently dropped the skip connection whenever ch_in == ch_out.
        if self.residual is not None:
            h = h + self.residual(x)
        return h


class Discriminator(nn.Module):
    def __init__(self, initial_kernel_size, num_rgb_channels, fmap_base, fmap_max, fmap_min, kernel_size,
                 self_attention_layers, progression_scale_up, progression_scale_down, residual, separable,
                 equalized, init, act_alpha, num_classes, deep, spectral=False, dropout=0.2, act_norm=None,
                 group_size=4, temporal_groups_per_window=1, conv_only=False, input_to_all_layers=False):
        """
        Progressive discriminator: ``len(progression_scale_up)`` ``DBlock`` stages followed by a final
        block and a 1x1 output conv.
        NOTE we only support global conditioning (not temporal) for now

        :param initial_kernel_size: kernel size of the final block's last conv
        :param num_rgb_channels: channels of the input signal
        :param fmap_base: stage s uses min(max(int(fmap_base / 2 ** s), fmap_min), fmap_max) channels
        :param fmap_max:
        :param fmap_min:
        :param kernel_size: kernel size of the inner convolutions
        :param self_attention_layers: layer indices that get a ``SelfAttention`` module
        :param progression_scale_up: must have the same length as progression_scale_down
        :param progression_scale_down:
        :param residual: add skip connections inside blocks
        :param separable:
        :param equalized:
        :param spectral: also wraps the class embedding in ``spectral_norm``
        :param init:
        :param act_alpha:
        :param num_classes: 0 disables the class-embedding (projection) term
        :param deep:
        :param dropout:
        :param act_norm:
        :param group_size: forwarded to the final block's ``MinibatchStddev``
        :param temporal_groups_per_window: forwarded to the final block's ``MinibatchStddev``
        :param conv_only: keep the final block convolutional
        :param input_to_all_layers: also feed a resampled copy of the input into every later stage
        """
        super().__init__()
        R = len(progression_scale_up)
        assert len(progression_scale_up) == len(progression_scale_down)
        self.progression_scale_up = progression_scale_up
        self.progression_scale_down = progression_scale_down
        self.depth = 0
        self.alpha = 1.0
        self.input_to_all_layers = input_to_all_layers

        def nf(stage):
            return min(max(int(fmap_base / (2.0 ** stage)), fmap_min), fmap_max)

        layer_settings = dict(equalized=equalized, spectral=spectral, init=init, act_alpha=act_alpha,
                              do=dropout, num_classes=0, act_norm=act_norm, bias=True, separable=separable)
        block_settings = dict(ch_rgb=num_rgb_channels, k_size=kernel_size, is_residual=residual, conv_disc=conv_only,
                              group_size=group_size, temporal_groups_per_window=temporal_groups_per_window, deep=deep)

        last_block = DBlock(nf(1), nf(0), initial_kernel_size=initial_kernel_size, **block_settings, **layer_settings)
        dummy = []  # to make SA layers registered
        self.self_attention = dict()
        for layer in self_attention_layers:
            dummy.append(SelfAttention(nf(layer + 1), spectral, init))
            self.self_attention[layer] = dummy[-1]
        if len(dummy):
            self.dummy = nn.ModuleList(dummy)
        self.blocks = nn.ModuleList(
            [DBlock(nf(i + 2), nf(i + 1), **block_settings, **layer_settings) for i in range(R - 1, -1, -1)] + [
                last_block])

        if num_classes != 0:
            self.class_emb = nn.Linear(num_classes, nf(0), False)
            if spectral:
                self.class_emb = spectral_norm(self.class_emb)
        else:
            self.class_emb = None
        self.linear = GeneralConv(nf(0), 1, kernel_size=1, equalized=equalized, act_alpha=-1,
                                  spectral=spectral, init=init)
        self.max_depth = len(self.blocks) - 1

    def forward(self, x):
        """Score a batch.

        :param x: dict with 'x' (and optional 'y'), an (x, y) tuple, or a bare tensor
        :return: (per-sample scores, final features ``h``, attention maps keyed by ``i``,
            i.e. self-attention layer index + 2)
        """
        if isinstance(x, dict):
            x, y = x['x'], x.get('y', None)
        elif isinstance(x, tuple):
            x, y = x
        else:
            y = None
        h = self.blocks[-(self.depth + 1)](x, True)
        if self.depth > 0:
            h = resample_signal(h, self.progression_scale_up[self.depth - 1],
                                self.progression_scale_down[self.depth - 1], True)
            if self.alpha < 1.0 or self.input_to_all_layers:
                x_lowres = resample_signal(x, self.progression_scale_up[self.depth - 1],
                                           self.progression_scale_down[self.depth - 1], True)
                preult_rgb = self.blocks[-self.depth].fromRGB(x_lowres)
                if self.input_to_all_layers:
                    h = (h * self.alpha + preult_rgb) / (1.0 + self.alpha)
                else:
                    h = h * self.alpha + (1.0 - self.alpha) * preult_rgb
        all_attention_maps = {}
        for i in range(self.depth, 0, -1):
            h = self.blocks[-i](h)
            if i > 1:
                h = resample_signal(h, self.progression_scale_up[i - 2], self.progression_scale_down[i - 2], True)
                if self.input_to_all_layers:
                    x_lowres = resample_signal(x_lowres, self.progression_scale_up[i - 2],
                                               self.progression_scale_down[i - 2], True)
                    h = (h + self.blocks[-i + 1].fromRGB(x_lowres)) / 2.0
            if (i - 2) in self.self_attention:
                h, attention_map = self.self_attention[i - 2](h)
                if attention_map is not None:
                    all_attention_maps[i] = attention_map
        # Squeeze only the channel axis so a batch of one still yields a (1,) score vector.
        o = self.linear(h).mean(dim=2).squeeze(1)
        if y is not None and self.class_emb is not None:
            # Projection term on time-averaged features (pooled like ``o``). y is ignored when the
            # model was built with num_classes == 0, mirroring the Generator dropping y without an encoder.
            emb = self.class_emb(y)
            cond_loss = (emb * h.mean(dim=2)).sum(dim=1)
        else:
            cond_loss = 0.0
        return o + cond_loss, h, all_attention_maps


In [0]:
#torch_utils.py
from collections import defaultdict


class Plugin(object):
    """Base class for trainer plugins.

    ``trigger_interval`` holds the (duration, unit) trigger(s) that a
    ``Trainer`` uses to schedule the plugin; subclasses implement ``register``.
    """

    def __init__(self, interval=None):
        self.trigger_interval = [] if interval is None else interval

    def register(self, trainer):
        """Attach the plugin to ``trainer``; must be overridden by subclasses."""
        raise NotImplementedError


class Monitor(Plugin):
    """Plugin that records a per-iteration scalar plus a running and a per-epoch average.

    Subclasses provide ``stat_name`` (the key under ``trainer.stats``) and
    ``_get_value(*args)``, which turns the iteration arguments into a number.
    Values live in ``trainer.stats[stat_name]`` under 'last', 'running_avg',
    'epoch_stats' (sum, count) and 'epoch_mean'.
    """

    def __init__(self, running_average=True, epoch_average=True, smoothing=0.7,
                 precision=None, number_format=None, unit=''):
        if number_format is None:
            number_format = '.{}f'.format(4 if precision is None else precision)
        fmt = ':' + number_format
        super(Monitor, self).__init__([(1, 'iteration'), (1, 'epoch')])

        self.smoothing = smoothing
        self.with_running_average = running_average
        self.with_epoch_average = epoch_average

        self.log_format = fmt
        self.log_unit = unit
        iter_fields = ['{last' + fmt + '}' + unit]
        if self.with_running_average:
            iter_fields.append(' ({running_avg' + fmt + '}' + unit + ')')
        self.log_iter_fields = iter_fields
        self.log_epoch_fields = ['{epoch_mean' + fmt + '}' + unit] if self.with_epoch_average else None

    def register(self, trainer):
        """Attach to ``trainer`` and publish this monitor's log formats into its stats."""
        self.trainer = trainer
        entry = trainer.stats.setdefault(self.stat_name, {})
        entry.update(log_format=self.log_format, log_unit=self.log_unit,
                     log_iter_fields=self.log_iter_fields)
        if self.with_epoch_average:
            entry['log_epoch_fields'] = self.log_epoch_fields
            entry['epoch_stats'] = (0, 0)

    def iteration(self, *args):
        """Record the latest value and update the epoch accumulator and running average."""
        entry = self.trainer.stats.setdefault(self.stat_name, {})
        value = self._get_value(*args)
        entry['last'] = value

        if self.with_epoch_average:
            # element-wise (sum, count) += (value, 1)
            entry['epoch_stats'] = tuple(map(sum, zip(entry['epoch_stats'], (value, 1))))

        if self.with_running_average:
            decay = self.smoothing
            entry['running_avg'] = entry.get('running_avg', 0) * decay + value * (1 - decay)

    def epoch(self, idx):
        """Turn the accumulated (sum, count) into 'epoch_mean' and reset the accumulator."""
        entry = self.trainer.stats.setdefault(self.stat_name, {})
        if not self.with_epoch_average:
            return
        total, count = entry['epoch_stats']
        entry['epoch_mean'] = total / count
        entry['epoch_stats'] = (0, 0)


class LossMonitor(Monitor):
    """``Monitor`` that records ``loss.item()`` under the 'loss' stats key."""
    stat_name = 'loss'

    def _get_value(self, iteration, input, target, output, loss):
        # NOTE(review): takes five positional arguments, but Trainer.train in this file calls
        # iteration plugins with (iteration, g_loss, d_loss) — confirm this monitor is still used.
        scalar = loss.item()
        return scalar


class Logger(Plugin):
    """Plugin that prints selected entries of ``trainer.stats`` every iteration and epoch.

    ``fields`` are dotted paths into ``trainer.stats`` (e.g. ``'loss.last'``).
    Dict-valued stats are rendered with their own list of format strings
    (looked up under 'log_iter_fields' / 'log_epoch_fields'); scalar stats use
    the parent's 'log_format' and 'log_unit'. Column widths are remembered so
    that successive lines stay aligned.
    """
    alignment = 4
    separator = '#' * 80

    def __init__(self, fields, interval=None):
        super(Logger, self).__init__([(1, 'iteration'), (1, 'epoch')] if interval is None else interval)
        self.field_widths = defaultdict(lambda: defaultdict(int))
        self.fields = [path.split('.') for path in fields]

    def _join_results(self, results):
        return '\t'.join('{}: {}'.format(name, ' '.join(parts)) for name, parts in results)

    def log(self, msg):
        print(msg)

    def register(self, trainer):
        self.trainer = trainer

    def gather_stats(self):
        return {}

    def _align_output(self, field_idx, output):
        """Pad entries narrower than the remembered width; otherwise remember the new width."""
        widths = self.field_widths[field_idx]
        for idx, text in enumerate(output):
            target = widths[idx]
            if len(text) >= target:
                widths[idx] = len(text)
            else:
                output[idx] = text + ' ' * (target - len(text))

    def _gather_outputs(self, field, log_fields, stat_parent, stat, require_dict=False):
        """Return (display name, formatted strings) for one stat."""
        if isinstance(stat, dict):
            name = stat.get('log_name', '.'.join(field))
            return name, [template.format(**stat) for template in stat.get(log_fields, [])]
        if require_dict:
            return '', []
        number_format = stat_parent.get('log_format', '')
        unit = stat_parent.get('log_unit', '')
        return '.'.join(field), [('{' + number_format + '}' + unit).format(stat)]

    def _log_all(self, log_fields, prefix=None, suffix=None, require_dict=False):
        """Format every configured field and log one aligned line, wrapped by prefix/suffix."""
        results = []
        for field_idx, path in enumerate(self.fields):
            parent, stat = None, self.trainer.stats
            for key in path:
                parent, stat = stat, stat[key]
            name, output = self._gather_outputs(path, log_fields, parent, stat, require_dict)
            if output:
                self._align_output(field_idx, output)
                results.append((name, output))
        if not results:
            return
        for line in (prefix, self._join_results(results), suffix):
            if line is not None:
                self.log(line)

    def iteration(self, *args):
        self._log_all('log_iter_fields')

    def epoch(self, epoch_idx):
        self._log_all('log_epoch_fields', prefix=self.separator + '\nEpoch summary:',
                      suffix=self.separator, require_dict=True)


In [0]:
#trainer.py
import heapq


class Trainer(object):
    """Alternating GAN training loop with a plugin scheduler.

    Each ``train()`` call runs ``d_training_repeats`` discriminator updates followed by one
    generator update. Plugins live in per-unit heaps ('iteration', 'epoch' (= tick), 'end')
    of (next trigger time, registration index, plugin) entries.
    """

    def __init__(self, discriminator: Discriminator, generator: Generator, d_loss, g_loss, dataset,
                 random_latents_generator, resume_nimg, optimizer_g, optimizer_d, d_training_repeats: int = 5,
                 tick_kimg_default: float = 5.0):
        assert d_training_repeats >= 1
        self.d_training_repeats = d_training_repeats
        self.discriminator = discriminator
        self.generator = generator
        self.d_loss = d_loss
        self.g_loss = g_loss
        self.dataset = dataset
        self.cur_nimg = resume_nimg
        self.random_latents_generator = random_latents_generator
        self.tick_start_nimg = self.cur_nimg
        self.tick_duration_nimg = int(tick_kimg_default * 1000)
        self.iterations = 0
        self.cur_tick = 0
        self.time = 0
        self.lr_scheduler_g = None
        self.lr_scheduler_d = None
        self.optimizer_g = optimizer_g
        self.optimizer_d = optimizer_d
        self.stats = {
            'kimg_stat': {'val': self.cur_nimg / 1000., 'log_epoch_fields': ['{val:8.3f}'], 'log_name': 'kimg'},
            'tick_stat': {'val': self.cur_tick, 'log_epoch_fields': ['{val:5}'], 'log_name': 'tick'}
        }
        self.plugin_queues = {
            'iteration': [],
            'epoch': [],  # this is tick
            'end': []
        }

    @staticmethod
    def _trigger_list(plugin):
        """Return ``plugin.trigger_interval`` as a list of (duration, unit) pairs (a bare pair is wrapped)."""
        intervals = plugin.trigger_interval
        return intervals if isinstance(intervals, list) else [intervals]

    def register_plugin(self, plugin):
        """Register ``plugin`` and schedule it for every (duration, unit) trigger it declares."""
        plugin.register(self)
        for duration, unit in self._trigger_list(plugin):
            queue = self.plugin_queues[unit]
            queue.append((duration, len(queue), plugin))

    def call_plugins(self, queue_name, time, *args):
        """Run every plugin in ``queue_name`` that is due at ``time`` and reschedule it.

        Each due plugin's ``<queue_name>`` method is called with ``(time, *args)``; it is then
        rescheduled at ``time`` plus its interval for this unit (the last matching trigger wins).
        """
        args = (time,) + args
        queue = self.plugin_queues[queue_name]
        if not queue:
            return
        while queue[0][0] <= time:
            plugin = queue[0][2]
            getattr(plugin, queue_name)(*args)
            # Normalise like register_plugin so a single (duration, unit) tuple is handled here too;
            # iterating the raw tuple used to index into its int/str elements.
            interval = [duration for duration, unit in self._trigger_list(plugin) if unit == queue_name][-1]
            # Replace the entry that was just served with its rescheduled copy.
            heapq.heapreplace(queue, (time + interval, queue[0][1], plugin))

    def run(self, total_kimg=1):
        """Train until ``total_kimg`` thousand real samples have been seen, firing tick plugins."""
        for q in self.plugin_queues.values():
            heapq.heapify(q)
        total_nimg = int(total_kimg * 1000)
        try:
            while self.cur_nimg < total_nimg:
                self.train()
                if self.cur_nimg >= self.tick_start_nimg + self.tick_duration_nimg or self.cur_nimg >= total_nimg:
                    self.cur_tick += 1
                    self.tick_start_nimg = self.cur_nimg
                    self.stats['kimg_stat']['val'] = self.cur_nimg / 1000.
                    self.stats['tick_stat']['val'] = self.cur_tick
                    self.call_plugins('epoch', self.cur_tick)
        except KeyboardInterrupt:
            # interrupted runs return without firing the 'end' plugins
            return
        self.call_plugins('end', 1)

    def train(self):
        """One iteration: ``d_training_repeats`` discriminator steps, then one generator step.

        NOTE(review): ``self.dataiter`` is not set in ``__init__``; presumably a plugin
        (e.g. the depth manager) assigns it before training — confirm.
        """
        if self.lr_scheduler_g is not None:
            self.lr_scheduler_g.step(self.cur_nimg / self.d_training_repeats)
        fake_latents_in = cudize(next(self.random_latents_generator))
        for i in range(self.d_training_repeats):
            if self.lr_scheduler_d is not None:
                self.lr_scheduler_d.step(self.cur_nimg)
            real_images_expr = cudize(next(self.dataiter))
            self.cur_nimg += real_images_expr['x'].size(0)
            d_loss = self.d_loss(self.discriminator, self.generator, real_images_expr, fake_latents_in)
            d_loss.backward()
            self.optimizer_d.step()
            fake_latents_in = cudize(next(self.random_latents_generator))
        g_loss = self.g_loss(self.discriminator, self.generator, real_images_expr, fake_latents_in)
        g_loss.backward()
        self.optimizer_g.step()
        self.iterations += 1
        self.call_plugins('iteration', self.iterations, g_loss, d_loss)


In [0]:
import gc
import os
import time
from copy import deepcopy
from datetime import timedelta
from glob import glob

import matplotlib
import numpy as np
import pandas as pd
import torch
from scipy import misc
from sklearn.utils.extmath import randomized_svd

matplotlib.use('Agg')
import matplotlib.pyplot as plt


class DepthManager(Plugin):
    """Progressive-growing controller driven by the number of images seen.

    On every ``'iteration'`` event it maps ``trainer.cur_nimg`` to a network
    depth and a fade-in factor ``alpha`` and pushes them to the generator,
    discriminator and dataset. When the depth changes, it also:

    * rebuilds the real-data loader and the latent generator with the
      depth-specific minibatch size;
    * updates the tick length;
    * if ``reset_optimizer`` is set, recreates the optimizers with a learning
      rate of ``default_lr * minibatch_default / minibatch_size``.
    """
    # Overrides keyed by depth relative to ``depth_offset`` (see ``iteration``).
    minibatch_override = {0: 256, 1: 256, 2: 128, 3: 128, 4: 48, 5: 32,
                          6: 32, 7: 32, 8: 16, 9: 16, 10: 8, 11: 8}

    tick_kimg_override = {4: 4, 5: 4, 6: 4, 7: 3, 8: 3, 9: 2, 10: 2, 11: 1}
    # Overrides looked up with key ``i + 1`` for ``i`` in ``range(depth_offset, max_depth)``
    # (see ``pre_compute_alpha_map``).
    training_kimg_override = {1: 200, 2: 200, 3: 200, 4: 200}
    transition_kimg_override = {1: 200, 2: 200, 3: 200, 4: 200}

    def __init__(self,  # everything starts from 0 or 1
                 create_dataloader_fun, create_rlg, max_depth,
                 tick_kimg_default, get_optimizer, default_lr,
                 reset_optimizer: bool = True, disable_progression=False,
                 minibatch_default=256, depth_offset=0,  # starts from 0
                 lod_training_kimg=400, lod_transition_kimg=400):
        """
        Args:
            create_dataloader_fun: ``minibatch_size -> DataLoader`` of real samples.
            create_rlg: ``minibatch_size -> iterator`` of generator inputs.
            max_depth: deepest depth reachable by the progression.
            tick_kimg_default: tick length (kimg) when ``tick_kimg_override`` has no entry.
            get_optimizer: ``lr -> (opt_g, opt_d, lr_scheduler_g, lr_scheduler_d)``.
            default_lr: learning rate used at ``minibatch_default``.
            reset_optimizer: recreate the optimizers whenever the depth changes.
            disable_progression: always train at ``max_depth`` with ``alpha == 1``.
            minibatch_default: minibatch size when ``minibatch_override`` has no entry.
            depth_offset: depth the progression starts from.
            lod_training_kimg: default kimg of plain training at each depth.
            lod_transition_kimg: default kimg of fading in each new depth.
        """
        super().__init__([(1, 'iteration')])
        self.reset_optimizer = reset_optimizer
        self.minibatch_default = minibatch_default
        self.tick_kimg_default = tick_kimg_default
        self.create_dataloader_fun = create_dataloader_fun
        self.create_rlg = create_rlg
        self.trainer = None
        self.depth = -1
        self.alpha = -1
        self.get_optimizer = get_optimizer
        self.disable_progression = disable_progression
        self.depth_offset = depth_offset
        self.max_depth = max_depth
        self.default_lr = default_lr
        self.alpha_map = self.pre_compute_alpha_map(self.depth_offset, max_depth, lod_training_kimg,
                                                    self.training_kimg_override, lod_transition_kimg,
                                                    self.transition_kimg_override)

    def register(self, trainer):
        """Attach to ``trainer``, create the ``alpha`` stat and apply the initial depth/alpha.

        Optimizers already present on the trainer (i.e. restored when resuming)
        are kept for the initial depth.
        """
        self.trainer = trainer
        self.trainer.stats['minibatch_size'] = self.minibatch_default
        self.trainer.stats['alpha'] = {'log_name': 'alpha', 'log_epoch_fields': ['{val:.2f}'], 'val': self.alpha}
        self.iteration(is_resuming=self.trainer.optimizer_d is not None)

    @staticmethod
    def pre_compute_alpha_map(start_depth, max_depth, lod_training_kimg, lod_training_kimg_overrides,
                              lod_transition_kimg, lod_transition_kimg_overrides):
        """Return the cumulative image counts at which the progression changes phase.

        For every depth step ``i`` in ``range(start_depth, max_depth)``, two
        points are appended: where plain training at the current depth ends
        (even index) and where the fade-in of the next depth ends (odd index).
        """
        points = []
        pointer = 0
        for i in range(start_depth, max_depth):
            pointer += int(lod_training_kimg_overrides.get(i + 1, lod_training_kimg) * 1000)
            points.append(pointer)
            pointer += int(lod_transition_kimg_overrides.get(i + 1, lod_transition_kimg) * 1000)
            points.append(pointer)
        return points

    def calc_progress(self, cur_nimg=None):
        """Return ``(depth, alpha)`` for ``cur_nimg`` (default: ``trainer.cur_nimg``).

        During a transition, ``depth`` is already the new depth and ``alpha``
        grows linearly from 0 to 1; otherwise ``alpha`` is 1. With
        ``disable_progression`` the result is always ``(max_depth, 1.0)``.
        """
        if cur_nimg is None:
            cur_nimg = self.trainer.cur_nimg
        depth = self.depth_offset
        alpha = 1.0
        for i, point in enumerate(self.alpha_map):
            if cur_nimg == point:
                break
            if cur_nimg > point and i % 2 == 0:
                depth += 1
            if cur_nimg < point and i % 2 == 1:
                alpha = (cur_nimg - self.alpha_map[i - 1]) / (point - self.alpha_map[i - 1])
                break
            if cur_nimg < point:
                break
        depth = min(self.max_depth, depth)
        if self.disable_progression:
            depth = self.max_depth
            alpha = 1.0
        return depth, alpha

    def iteration(self, *args, is_resuming=False):
        """Apply the depth/alpha that corresponds to the current ``trainer.cur_nimg``.

        The plugin queue calls this as ``iteration(iterations, g_loss, d_loss)``;
        those positional values are not needed and are absorbed by ``*args``.
        ``is_resuming`` is keyword-only because a positional slot would receive
        the (truthy) iteration counter and silently suppress ``reset_optimizer``
        for the rest of training.

        Args:
            is_resuming: keep the trainer's existing optimizers for this call
                (used by ``register`` when they were restored from a checkpoint).
        """
        depth, alpha = self.calc_progress()
        dataset = self.trainer.dataset
        if depth != self.depth:
            self.trainer.discriminator.depth = self.trainer.generator.depth = dataset.model_depth = depth
            self.depth = depth
            minibatch_size = self.minibatch_override.get(depth - self.depth_offset, self.minibatch_default)
            # Optimizers are (re)built when requested for every new depth (unless they were just
            # restored from a checkpoint) and, in any case, when none exist yet.
            missing_optimizer = self.trainer.optimizer_g is None or self.trainer.optimizer_d is None
            if missing_optimizer or (self.reset_optimizer and not is_resuming):
                (self.trainer.optimizer_g, self.trainer.optimizer_d,
                 self.trainer.lr_scheduler_g, self.trainer.lr_scheduler_d) = self.get_optimizer(
                    self.minibatch_default * self.default_lr / minibatch_size)
            self.data_loader = self.create_dataloader_fun(minibatch_size)
            self.trainer.dataiter = iter(self.data_loader)
            self.trainer.random_latents_generator = self.create_rlg(minibatch_size)
            tick_duration_kimg = self.tick_kimg_override.get(depth - self.depth_offset, self.tick_kimg_default)
            self.trainer.tick_duration_nimg = int(tick_duration_kimg * 1000)
            self.trainer.stats['minibatch_size'] = minibatch_size
        if alpha != self.alpha:
            self.trainer.discriminator.alpha = self.trainer.generator.alpha = dataset.alpha = alpha
            self.alpha = alpha
        self.trainer.stats['depth'] = depth
        self.trainer.stats['alpha']['val'] = alpha


class EfficientLossMonitor(LossMonitor):
    """Loss monitor that reads an already-computed loss from the iteration
    arguments and aborts training when the loss misbehaves.

    * Any NaN loss value raises ``ValueError`` immediately.
    * After ``warmup`` ticks, if ``|epoch_mean|`` of the monitored stat stays
      above ``threshold`` for more than ``patience`` consecutive ticks, a
      ``ValueError`` is raised.
    """

    def __init__(self, loss_no, stat_name, monitor_threshold: float = 10.0,
                 monitor_warmup: int = 50, monitor_patience: int = 5):
        super().__init__()
        # Position of this loss among the extra 'iteration' arguments.
        self.loss_no, self.stat_name = loss_no, stat_name
        self.threshold, self.warmup, self.patience = monitor_threshold, monitor_warmup, monitor_patience
        # Number of consecutive over-threshold ticks seen so far.
        self.counter = 0

    def _get_value(self, iteration, *args):
        value = args[self.loss_no].item()
        # NaN is the only value that compares unequal to itself.
        if value != value:
            raise ValueError('loss value is NaN :((')
        return value

    def epoch(self, idx):
        super().epoch(idx)
        if idx <= self.warmup:
            return
        mean_loss = self.trainer.stats[self.stat_name]['epoch_mean']
        # Written as ``not (... > ...)`` so a NaN mean resets the counter, as before.
        if not abs(mean_loss) > self.threshold:
            self.counter = 0
            return
        self.counter += 1
        if self.counter > self.patience:
            raise ValueError('loss value exceeded the threshold')


class AbsoluteTimeMonitor(Plugin):
    """Records wall-clock timing statistics at the end of every tick.

    Stats written to ``trainer.stats``:

    * ``'time'``: ``timedelta`` since this plugin was constructed.
    * ``'sec'['tick']``: seconds spent in the last tick.
    * ``'sec'['kimg']``: seconds per thousand images during the last tick
      (``nan`` if no images were consumed).
    """

    def __init__(self):
        super().__init__([(1, 'epoch')])
        self.start_time = time.time()
        self.epoch_start = self.start_time
        self.start_nimg = None
        self.epoch_time = 0

    def register(self, trainer):
        self.trainer = trainer
        self.start_nimg = trainer.cur_nimg
        self.trainer.stats['sec'] = {'log_format': ':.1f'}

    def epoch(self, epoch_index):
        cur_time = time.time()
        tick_time = cur_time - self.epoch_start
        self.epoch_start = cur_time
        tick_nimg = self.trainer.cur_nimg - self.start_nimg
        # An empty tick would otherwise crash training with ZeroDivisionError.
        kimg_time = tick_time / tick_nimg * 1000 if tick_nimg > 0 else float('nan')
        self.start_nimg = self.trainer.cur_nimg
        # Reuse the same timestamp so 'time' and 'sec.tick' are mutually consistent.
        self.trainer.stats['time'] = timedelta(seconds=cur_time - self.start_time)
        self.trainer.stats['sec']['tick'] = tick_time
        self.trainer.stats['sec']['kimg'] = kimg_time


class SaverPlugin(Plugin):
    """Periodically checkpoints the generator and discriminator.

    Each checkpoint stores ``cur_nimg``, the model ``state_dict`` and the
    optimizer ``state_dict`` in
    ``<checkpoints_path>/network-snapshot-<name>-<kimg:06>.dat``.
    With ``keep_old_checkpoints=False``, every other file matching
    ``network-snapshot-*-*.dat`` is removed after the new files have been
    written successfully.
    """
    last_pattern = 'network-snapshot-{}-{}.dat'

    def __init__(self, checkpoints_path, keep_old_checkpoints: bool = True, network_snapshot_ticks: int = 50):
        super().__init__([(network_snapshot_ticks, 'epoch')])
        self.checkpoints_path = checkpoints_path
        self.keep_old_checkpoints = keep_old_checkpoints

    def register(self, trainer: Trainer):
        self.trainer = trainer

    def epoch(self, epoch_index):
        kimg = '{:06}'.format(self.trainer.cur_nimg // 1000)
        written = set()
        for model, optimizer, name in [(self.trainer.generator, self.trainer.optimizer_g, 'generator'),
                                       (self.trainer.discriminator, self.trainer.optimizer_d, 'discriminator')]:
            # Build each path in one step so braces in checkpoints_path are never re-formatted.
            dest = os.path.join(self.checkpoints_path, self.last_pattern.format(name, kimg))
            torch.save({'cur_nimg': self.trainer.cur_nimg, 'model': model.state_dict(),
                        'optimizer': optimizer.state_dict()}, dest)
            written.add(os.path.abspath(dest))
        # Delete stale snapshots only after the new ones exist, so a failed save
        # never leaves the run without any checkpoint.
        if not self.keep_old_checkpoints:
            self._clear(self.last_pattern.format('*', '*'), keep=written)

    def end(self, *args):
        self.epoch(*args)

    def _clear(self, pattern, keep=()):
        """Remove files in ``checkpoints_path`` matching ``pattern``, except absolute paths in ``keep``."""
        pattern = os.path.join(self.checkpoints_path, pattern)
        for file_name in glob(pattern):
            if os.path.abspath(file_name) not in keep:
                os.remove(file_name)


class EvalDiscriminator(Plugin):
    """Every ``output_snapshot_ticks`` ticks, averages the discriminator's output
    on the validation set (at the dataset's current depth/alpha) and stores it
    in ``stats['memorization']``.
    """

    def __init__(self, create_dataloader_fun, output_snapshot_ticks):
        super().__init__([(1, 'epoch')])
        self.create_dataloader_fun = create_dataloader_fun
        self.output_snapshot_ticks = output_snapshot_ticks

    def register(self, trainer):
        self.trainer = trainer
        self.trainer.stats['memorization'] = {
            'log_name': 'memorization',
            'log_epoch_fields': ['{val:.2f}', '{epoch:.2f}'],
            'val': float('nan'), 'epoch': 0,
        }

    def epoch(self, epoch_index):
        if epoch_index % self.output_snapshot_ticks != 0:
            return
        batch_means = []
        with torch.no_grad():
            val_loader = self.create_dataloader_fun(min(self.trainer.stats['minibatch_size'], 1024), False,
                                                    self.trainer.dataset.model_depth, self.trainer.dataset.alpha)
            for data in val_loader:
                d_real, _, _ = self.trainer.discriminator(cudize(data))
                batch_means.append(d_real.mean().item())
        # The loader may yield nothing (e.g. a validation set smaller than one batch when the
        # last partial batch is dropped); report nan explicitly instead of np.mean's warning.
        self.trainer.stats['memorization']['val'] = float(np.mean(batch_means)) if batch_means else float('nan')
        self.trainer.stats['memorization']['epoch'] = epoch_index


# TODO
class OutputGenerator(Plugin):
    """Keeps an exponential moving average (EMA) of the generator weights and
    periodically saves/plots samples from the averaged generator.

    On every tick the EMA is updated as
    ``avg = old_weight * avg + (1 - old_weight) * current``. Every
    ``output_snapshot_ticks`` ticks, the EMA weights are temporarily loaded
    into the generator, saved as a ``smooth_generator`` snapshot and used to
    draw ``samples_count`` samples, which are plotted to PNG files in
    ``checkpoints_dir``.

    Args:
        sample_fn: callable taking a sample count and returning an iterator of
            generator inputs (``next`` is called on it once per snapshot).
        checkpoints_dir: directory for the EMA snapshots and PNG plots.
        seq_len: full-resolution sequence length.
        max_freq: sampling frequency of a ``seq_len``-long sequence; the plot
            frequency is scaled to the generated length.
        samples_count: number of samples plotted per snapshot.
        output_snapshot_ticks: snapshot period, in ticks.
        old_weight: EMA decay (weight of the previous average).
    """

    def __init__(self, sample_fn, checkpoints_dir: str, seq_len: int, max_freq: float,
                 samples_count: int = 8, output_snapshot_ticks: int = 25, old_weight: float = 0.59):
        super().__init__([(1, 'epoch')])
        self.old_weight = old_weight
        self.sample_fn = sample_fn
        self.samples_count = samples_count
        self.checkpoints_dir = checkpoints_dir
        self.seq_len = seq_len
        self.max_freq = max_freq
        self.my_g_clone = None
        self.output_snapshot_ticks = output_snapshot_ticks

    @staticmethod
    def flatten_params(model):
        """Return deep copies of ``model``'s parameter tensors, in ``parameters()`` order."""
        return deepcopy(list(p.data for p in model.parameters()))

    @staticmethod
    def load_params(flattened, model):
        """Copy ``flattened`` tensors into ``model``'s parameters in place."""
        for p, avg_p in zip(model.parameters(), flattened):
            p.data.copy_(avg_p)

    def register(self, trainer):
        self.trainer = trainer
        self.my_g_clone = self.flatten_params(self.trainer.generator)

    @staticmethod
    def running_mean(x, n=8):
        """Trailing moving average of window ``n`` (the first ``n - 1`` values are NaN)."""
        return pd.Series(x).rolling(window=n).mean().values

    @staticmethod
    def get_images(frequency, epoch, generated, my_range=range):
        """Render each sample in ``generated`` to an RGB ``uint8`` image array.

        ``generated`` is indexed as ``[sample, channel, time]``. Each channel gets
        four panels: the time signal, and the magnitude spectrum (raw, smoothed
        and log-scaled). ``my_range`` is used in place of ``range`` for the
        sample loop.
        """
        num_channels = generated.shape[1]
        seq_len = generated.shape[2]
        # Sample k is taken at time k / frequency.
        t = np.arange(seq_len) / frequency
        f = np.fft.rfftfreq(seq_len, d=1. / frequency)
        color = (0.8, 0, 0, 0.5)
        images = []
        for index in my_range(len(generated)):
            fig, axs = plt.subplots(num_channels, 4)
            if num_channels == 1:
                axs = axs.reshape(1, -1)
            fig.set_figheight(40)
            fig.set_figwidth(40)
            for ch in range(num_channels):
                data = generated[index, ch, :]
                spectrum = np.abs(np.fft.rfft(data))
                axs[ch][0].plot(t, data, color=color, label='time domain')
                axs[ch][1].plot(f, spectrum, color=color, label='freq domain')
                axs[ch][2].plot(f, OutputGenerator.running_mean(spectrum),
                                color=color, label='freq domain(smooth)')
                axs[ch][3].semilogy(f, spectrum, color=color, label='freq domain(log)')
                axs[ch][0].set_ylim([-1.1, 1.1])
                for ax in axs[ch]:
                    ax.legend()
            fig.suptitle('epoch: {}, sample: {}'.format(epoch, index))
            fig.canvas.draw()
            # buffer_rgba() replaces the deprecated tostring_rgb(); dropping alpha yields the same RGB.
            width, height = fig.canvas.get_width_height()
            rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(height, width, 4)
            images.append(rgba[..., :3].copy())
            plt.close(fig)
        return images

    def epoch(self, epoch_index):
        for p, avg_p in zip(self.trainer.generator.parameters(), self.my_g_clone):
            avg_p.mul_(self.old_weight).add_((1.0 - self.old_weight) * p.data)
        if epoch_index % self.output_snapshot_ticks != 0:
            return
        z = next(self.sample_fn(self.samples_count))
        gen_input = cudize(z)
        original_param = self.flatten_params(self.trainer.generator)
        self.load_params(self.my_g_clone, self.trainer.generator)
        try:
            dest = os.path.join(self.checkpoints_dir, SaverPlugin.last_pattern.format(
                'smooth_generator', '{:06}'.format(self.trainer.cur_nimg // 1000)))
            torch.save({'cur_nimg': self.trainer.cur_nimg, 'model': self.trainer.generator.state_dict()}, dest)
            # Sampling only; no autograd graph is needed.
            with torch.no_grad():
                out = generate_samples(self.trainer.generator, gen_input)
        finally:
            # Always restore the live training weights, even if saving or sampling failed.
            self.load_params(original_param, self.trainer.generator)
        frequency = self.max_freq * out.shape[2] / self.seq_len
        images = self.get_images(frequency, epoch_index, out)
        for i, image in enumerate(images):
            # scipy.misc.imsave was removed in SciPy 1.2; matplotlib writes RGB uint8 PNGs directly.
            plt.imsave(os.path.join(self.checkpoints_dir, '{}_{}.png'.format(epoch_index, i)), image)


class TeeLogger(Logger):
    """Logger that prints each message to stdout, prefixed with the experiment
    name, and also appends it to a line-buffered log file."""

    def __init__(self, log_file, exp_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # buffering=1 -> line buffered, so each logged line reaches disk promptly.
        self.log_file = open(log_file, 'a', 1)
        self.exp_name = exp_name

    def log(self, msg):
        print(self.exp_name, msg, flush=True)
        self.log_file.write(''.join((msg, '\n')))

    def epoch(self, epoch_idx):
        # Emit all configured fields using their per-tick ('epoch') formats.
        self._log_all('log_epoch_fields')


class WatchSingularValues(Plugin):
    """Every ``output_snapshot_ticks`` ticks, checks each ``Conv1d`` in ``network``.

    Raises ``ValueError(module)`` when the ratio of the two largest singular
    values of a flattened weight matrix exceeds ``one_divided_two``.
    """

    def __init__(self, network, one_divided_two: float = 10.0, output_snapshot_ticks: int = 20):
        super().__init__([(1, 'epoch')])
        self.network = network
        self.one_divided_two = one_divided_two
        self.output_snapshot_ticks = output_snapshot_ticks

    def register(self, trainer):
        self.trainer = trainer

    def epoch(self, epoch_index):
        if epoch_index % self.output_snapshot_ticks != 0:
            return
        # ``modules`` is a method; iterating the bound method itself raises TypeError.
        for module in self.network.modules():
            if not isinstance(module, torch.nn.Conv1d):
                continue
            weight = module.weight.data.cpu().numpy()
            # Conv1d weight is (out_channels, in_channels / groups, kernel_size); flatten to 2-D.
            matrix = weight.reshape(weight.shape[0], -1)
            # The ratio needs two singular values; a single-row/column matrix has only one.
            if min(matrix.shape) < 2:
                continue
            _, s, _ = randomized_svd(matrix, n_components=2)
            if abs(s[0] / s[1]) > self.one_divided_two:
                raise ValueError(module)


class SlicedWDistance(Plugin):
    """Multi-scale sliced Wasserstein distance (SWD) between fake and real signals.

    Every ``output_snapshot_ticks`` ticks, about ``max_items`` generated and real
    signals are collected. At each scale (the signal is subsampled by
    ``progression_scale`` per level while at least ``patch_size`` long),
    ``patches_per_item`` random patches per item are extracted and the SWD
    between the two descriptor sets is computed. ``stats['swd']['val']`` holds
    the per-scale distances (finest first) followed by their mean.
    """

    def __init__(self, progression_scale: int, output_snapshot_ticks: int, patches_per_item: int = 16,
                 patch_size: int = 49, max_items: int = 1024, number_of_projections: int = 512,
                 dir_repeats: int = 4, dirs_per_repeat: int = 128):
        super().__init__([(1, 'epoch')])
        self.output_snapshot_ticks = output_snapshot_ticks
        self.progression_scale = progression_scale
        self.patches_per_item = patches_per_item
        self.patch_size = patch_size
        self.max_items = max_items
        self.number_of_projections = number_of_projections
        self.dir_repeats = dir_repeats
        self.dirs_per_repeat = dirs_per_repeat

    def register(self, trainer):
        self.trainer = trainer
        self.trainer.stats['swd'] = {
            'log_name': 'swd',
            'log_epoch_fields': ['{val:.2f}', '{epoch:.2f}'],
            'val': float('nan'), 'epoch': 0,
        }

    def sliced_wasserstein(self, A, B):
        """Approximate SWD between two equally sized point sets ``A`` and ``B`` (rows are points)."""
        results = []
        for repeat in range(self.dir_repeats):
            dirs = torch.randn(A.shape[1], self.dirs_per_repeat)  # (descriptor_component, direction)
            dirs /= torch.sqrt(
                (dirs * dirs).sum(dim=0, keepdim=True) + EPSILON)  # normalize descriptor components for each direction
            projA = torch.matmul(A, dirs)  # (neighborhood, direction)
            projB = torch.matmul(B, dirs)
            projA = torch.sort(projA, dim=0)[0]  # sort neighborhood projections for each direction
            projB = torch.sort(projB, dim=0)[0]
            dists = (projA - projB).abs()  # pointwise wasserstein distances
            results.append(dists.mean())  # average over neighborhoods and directions
        return torch.mean(torch.stack(results)).item()  # average over repeats

    def epoch(self, epoch_index):
        if epoch_index % self.output_snapshot_ticks != 0:
            return
        gc.collect()
        all_fakes = []
        all_reals = []
        with torch.no_grad():
            remaining_items = self.max_items
            while remaining_items > 0:
                z = next(self.trainer.random_latents_generator)
                fake_latents_in = cudize(z)
                all_fakes.append(self.trainer.generator(fake_latents_in)[0]['x'].data.cpu())
                if all_fakes[-1].size(2) < self.patch_size:
                    break
                remaining_items -= all_fakes[-1].size(0)
            all_fakes = torch.cat(all_fakes, dim=0)
            remaining_items = self.max_items
            while remaining_items > 0:
                all_reals.append(next(self.trainer.dataiter)['x'])
                if all_reals[-1].size(2) < self.patch_size:
                    break
                remaining_items -= all_reals[-1].size(0)
            all_reals = torch.cat(all_reals, dim=0)
        swd = self.get_descriptors(all_fakes, all_reals)
        if len(swd) > 0:
            swd.append(np.array(swd).mean())
        self.trainer.stats['swd']['val'] = swd
        self.trainer.stats['swd']['epoch'] = epoch_index

    def get_descriptors(self, batch1, batch2):
        """Return the SWD between ``batch1`` and ``batch2`` at every scale.

        Args:
            batch1, batch2: signals indexed as ``[item, channel, time]``.

        Returns:
            list of floats, one per scale (finest first); empty if the signals
            are shorter than ``patch_size`` or a batch is empty.
        """
        # Both sides must yield the same number of descriptor rows for the pointwise
        # comparison in ``sliced_wasserstein``, so only the common item count is used.
        b = min(batch1.shape[0], batch2.shape[0])
        if b == 0:
            return []
        c, t = batch1.shape[1], batch1.shape[2]
        num_levels = 0
        while t >= self.patch_size:
            num_levels += 1
            t //= self.progression_scale
        swd = []
        # Created once, outside the level loop, so the subsampling at the end of each level
        # carries over to the next one (otherwise every level would see full resolution).
        batchs = [batch1, batch2]
        for _ in range(num_levels):
            both_descriptors = []
            for i in range(2):
                batch = batchs[i]
                # Valid patch starts are 0..max_start inclusive (randint's upper bound is exclusive).
                max_start = batch.shape[2] - self.patch_size
                patches = []
                for j in range(b):
                    for _ in range(self.patches_per_item):
                        start = np.random.randint(0, max_start + 1)
                        patches.append(batch[j, :, start:start + self.patch_size])
                descriptors = torch.stack(patches, dim=0)  # (N, c, patch_size)
                # One row per time step, one column per channel: (N * patch_size, c).
                # Permuting first keeps each row's components from the same time step.
                descriptors = descriptors.permute(0, 2, 1).reshape(-1, c)
                descriptors = descriptors - torch.mean(descriptors, dim=0, keepdim=True)
                descriptors = descriptors / (torch.std(descriptors, dim=0, keepdim=True) + EPSILON)
                both_descriptors.append(descriptors)
                batchs[i] = batch[:, :, ::self.progression_scale]
            swd.append(self.sliced_wasserstein(both_descriptors[0], both_descriptors[1]))
        return swd


In [0]:
import os
import signal
import subprocess
import time
from functools import partial

import numpy as np
import torch
import yaml
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.sampler import SubsetRandomSampler

# Top-level run configuration defaults. Per-class settings (e.g. 'Generator',
# 'DepthManager') live in separate sub-dicts that main() reads from the merged params.
default_params = {
    'result_dir': 'results',
    'exp_name': '',
    'lr': 0.001,  # generator learning rate (D uses the same, x4 when ttur is enabled)
    'total_kimg': 6000,
    'resume_network': '',  # e.g. 001-test/network-snapshot-{}-000025.dat
    'num_data_workers': 0,
    'random_seed': 1373,
    'grad_lambda': 10.0,  # set to zero to disable the gp loss (even for non wgan based losses)
    'iwass_drift_epsilon': 0.001,
    'iwass_target': 1.0,
    'feature_matching_lambda': 0.0,
    'loss_type': 'wgan_gp',  # wgan_gp, hinge, wgan_theirs, rsgan, rasgan, rahinge
    'cuda_device': 0,
    'ttur': False,
    'config_file': None,
    'fmap_base': 1024,
    'fmap_max': 256,
    'fmap_min': 64,
    'equalized': True,
    'kernel_size': 3,
    'self_attention_layers': [],  # starts from 0 or null (for G it means putting it after ith layer)
    'random_multiply': False,
    'lr_rampup_kimg': 0.0,  # 0 disables the ramp-up (used to be 40)
    'z_distribution': 'normal',  # or 'bernoulli' or 'censored'
    'init': 'kaiming_normal',  # or xavier_uniform or orthogonal
    'act_alpha': 0.2,
    'residual': False,
    'calc_swd': False,
    'separable': False,
    'num_classes': 0,
    'deep': False,
}


class InfiniteRandomSampler(SubsetRandomSampler):
    """``SubsetRandomSampler`` that never runs out.

    After each complete pass over ``indices`` it starts a fresh random
    permutation. An empty index list yields nothing, rather than looping
    forever without producing an item.
    """

    def __iter__(self):
        if len(self.indices) == 0:
            return
        while True:
            yield from super().__iter__()


def load_models(resume_network, result_dir, logger):
    """Load the generator and discriminator checkpoints used to resume training.

    Args:
        resume_network: path relative to ``result_dir`` with one ``{}``
            placeholder, filled with ``'generator'`` and ``'discriminator'``
            (e.g. ``001-test/network-snapshot-{}-000025.dat``).
        result_dir: root results directory.
        logger: object with a ``log(str)`` method.

    Returns:
        ``(generator, g_optimizer, discriminator, d_optimizer, cur_nimg)`` as
        produced by ``load_model`` for the two checkpoints.

    Raises:
        ValueError: if the two checkpoints were saved at different ``cur_nimg``.
    """
    logger.log('Resuming {}'.format(resume_network))
    dest = os.path.join(result_dir, resume_network)
    generator, g_optimizer, g_cur_img = load_model(dest.format('generator'), True)
    discriminator, d_optimizer, d_cur_img = load_model(dest.format('discriminator'), True)
    # Explicit check instead of ``assert`` (asserts are stripped under ``python -O``).
    if g_cur_img != d_cur_img:
        raise ValueError('generator and discriminator checkpoints do not match: '
                         'cur_nimg {} vs {}'.format(g_cur_img, d_cur_img))
    return generator, g_optimizer, discriminator, d_optimizer, g_cur_img


def thread_exit(_signal, frame):
    """Signal handler that terminates the current process with exit status 0.

    Uses ``sys.exit`` rather than the ``exit`` helper injected by the ``site``
    module, which is not guaranteed to exist (e.g. under ``python -S``).
    """
    import sys
    sys.exit(0)


def worker_init(x):
    """DataLoader ``worker_init_fn``: route SIGINT in the worker to ``thread_exit``.

    Args:
        x: worker id supplied by the DataLoader (unused).
    """
    handler = thread_exit
    signal.signal(signal.SIGINT, handler)


def main(params):
    """Build datasets, models, optimizers and plugins from ``params`` and train.

    ``params`` holds the top-level options (see ``default_params``) plus
    sub-dicts keyed by class name (``params['Generator']``,
    ``params['DepthManager']``, ...), as read below. A new numbered
    sub-directory of ``params['result_dir']`` receives the log, the effective
    configuration (``conf.yml``) and the plugins' checkpoints/samples.
    """
    dataset_params = params['EEGDataset']
    dataset, val_dataset = EEGDataset.from_config(**dataset_params)
    if params['config_file'] and params['exp_name'] == '':
        # Default the experiment name to the config file's base name.
        params['exp_name'] = params['config_file'].split('/')[-1].split('.')[0]
    result_dir = create_result_subdir(params['result_dir'], params['exp_name'])

    losses = ['G_loss', 'D_loss']
    stats_to_log = ['tick_stat', 'kimg_stat']
    stats_to_log.extend(['depth', 'alpha', 'minibatch_size'])
    stats_to_log.extend(['time', 'sec.tick', 'sec.kimg'] + losses)
    if dataset_params['validation_ratio'] > 0:
        stats_to_log.extend(['memorization.val', 'memorization.epoch'])
    if params['calc_swd']:
        stats_to_log.extend(['swd.val', 'swd.epoch'])

    logger = TeeLogger(os.path.join(result_dir, 'log.txt'), params['exp_name'], stats_to_log, [(1, 'epoch')])
    shared_model_params = dict(initial_kernel_size=dataset.initial_kernel_size, num_rgb_channels=dataset.num_channels,
                               fmap_base=params['fmap_base'], fmap_max=params['fmap_max'], fmap_min=params['fmap_min'],
                               kernel_size=params['kernel_size'], self_attention_layers=params['self_attention_layers'],
                               progression_scale_up=dataset.progression_scale_up,
                               progression_scale_down=dataset.progression_scale_down, residual=params['residual'],
                               separable=params['separable'], equalized=params['equalized'], init=params['init'],
                               act_alpha=params['act_alpha'], num_classes=params['num_classes'], deep=params['deep'])
    for n in ('Generator', 'Discriminator'):
        p = params[n]
        if p['spectral']:
            if p['act_norm'] == 'pixel':
                logger.log('Warning, setting pixel normalization with spectral norm in {} is not a good idea'.format(n))
            if params['equalized']:
                logger.log('Warning, setting equalized weights with spectral norm in {} is not a good idea'.format(n))
    if params['DepthManager']['disable_progression'] and not params['residual']:
        logger.log('Warning, you have set the residual to false and disabled the progression')
    if params['Discriminator']['act_norm'] is not None:
        logger.log('Warning, you are using an activation normalization in discriminator')
    generator = Generator(**shared_model_params, z_distribution=params['z_distribution'], **params['Generator'])
    discriminator = Discriminator(**shared_model_params, **params['Discriminator'])

    def rampup(cur_nimg):
        """LR multiplier: exp(-5 * (1 - progress)^2) during the first ``lr_rampup_kimg`` kimg, then 1."""
        if cur_nimg < params['lr_rampup_kimg'] * 1000:
            p = max(0.0, 1 - cur_nimg / (params['lr_rampup_kimg'] * 1000))
            return np.exp(-p * p * 5.0)
        else:
            return 1.0

    if params['ttur']:
        params['Adam']['betas'] = (0, 0.9)

    def get_optimizers(g_lr):
        """Create Adam optimizers for G and D (D's lr is 4x G's with TTUR) and optional LR schedulers."""
        d_lr = g_lr
        if params['ttur']:
            d_lr *= 4.0
        opt_g = Adam(trainable_params(generator), g_lr, **params['Adam'])
        opt_d = Adam(trainable_params(discriminator), d_lr, **params['Adam'])
        if params['lr_rampup_kimg'] > 0:
            lr_scheduler_g = LambdaLR(opt_g, rampup, -1)
            lr_scheduler_d = LambdaLR(opt_d, rampup, -1)
            return opt_g, opt_d, lr_scheduler_g, lr_scheduler_d
        return opt_g, opt_d, None, None

    if params['resume_network'] != '':
        logger.log('resuming networks')
        generator_state, opt_g_state, discriminator_state, opt_d_state, train_cur_img = load_models(
            params['resume_network'], params['result_dir'], logger)
        generator.load_state_dict(generator_state)
        discriminator.load_state_dict(discriminator_state)
        opt_g, opt_d, _, _ = get_optimizers(params['lr'])
        opt_g.load_state_dict(opt_g_state)
        opt_d.load_state_dict(opt_d_state)
    else:
        opt_g = None
        opt_d = None
        train_cur_img = 0
    latent_size = generator.input_latent_size
    generator.train()
    discriminator.train()
    generator = cudize(generator)
    discriminator = cudize(discriminator)
    # Restored optimizer state was created before the models were moved; move its tensors too.
    if opt_g is not None:
        for opt in [opt_g, opt_d]:
            for state in opt.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cudize(v)
    d_loss_fun = partial(discriminator_loss, loss_type=params['loss_type'], iwass_target=params['iwass_target'],
                         iwass_drift_epsilon=params['iwass_drift_epsilon'], grad_lambda=params['grad_lambda'])
    g_loss_fun = partial(generator_loss, random_multiply=params['random_multiply'], loss_type=params['loss_type'],
                         feature_matching_lambda=params['feature_matching_lambda'])
    max_depth = generator.max_depth

    logger.log('exp name: {}'.format(params['exp_name']))
    try:
        commit = subprocess.check_output(['git', 'describe', '--always'], stderr=subprocess.DEVNULL)
        logger.log('commit hash: {}'.format(commit.decode('utf-8', 'replace').strip()))
    except (OSError, subprocess.CalledProcessError):
        # git is unavailable or this is not a git checkout: log a timestamp instead.
        logger.log('current time: {}'.format(time.time()))
    logger.log('training dataset shape: {}'.format(dataset.shape))
    if dataset_params['validation_ratio'] > 0:
        logger.log('val dataset shape: {}'.format(val_dataset.shape))
    logger.log('Total number of parameters in Generator: {}'.format(num_params(generator)))
    logger.log('Total number of parameters in Discriminator: {}'.format(num_params(discriminator)))

    mb_def = params['DepthManager']['minibatch_default']

    collate_real = get_collate_real(dataset.end_sampling_freq, dataset.seq_len)
    collate_fake = get_collate_fake(latent_size, params['z_distribution'], collate_real)

    def get_dataloader(minibatch_size, is_training=True, depth=0, alpha=1, is_real=True):
        ds = dataset if is_training else val_dataset
        shared_dataloader_params = {'dataset': ds, 'batch_size': minibatch_size, 'drop_last': True,
                                    'worker_init_fn': worker_init, 'num_workers': params['num_data_workers'],
                                    'collate_fn': collate_real if is_real else collate_fake}
        if not is_training:
            ds.model_depth = depth
            ds.alpha = alpha
            # NOTE you must drop last in order to be compatible with D.stats layer
            return DataLoader(**shared_dataloader_params, shuffle=True)
        return DataLoader(**shared_dataloader_params, sampler=InfiniteRandomSampler(list(range(len(ds)))))

    # NOTE you can not put the if inside your function (a function should either return or yield)
    def get_random_latents(minibatch_size, is_training=True, depth=0, alpha=1):
        while True:
            yield {'z': cudize(random_latents(minibatch_size, latent_size, params['z_distribution']))}

    trainer = Trainer(discriminator, generator, d_loss_fun, g_loss_fun, dataset, get_random_latents(mb_def),
                      train_cur_img, opt_g, opt_d, **params['Trainer'])
    dm = DepthManager(get_dataloader, get_random_latents, max_depth, params['Trainer']['tick_kimg_default'],
                      get_optimizers, params['lr'], **params['DepthManager'])
    trainer.register_plugin(dm)
    for i, loss_name in enumerate(losses):
        trainer.register_plugin(EfficientLossMonitor(i, loss_name, **params['EfficientLossMonitor']))
    trainer.register_plugin(SaverPlugin(result_dir, **params['SaverPlugin']))
    trainer.register_plugin(
        OutputGenerator(lambda x: get_random_latents(x), result_dir, dataset.seq_len,
                        dataset.end_sampling_freq, **params['OutputGenerator']))
    if dataset_params['validation_ratio'] > 0:
        trainer.register_plugin(EvalDiscriminator(get_dataloader, params['SaverPlugin']['network_snapshot_ticks']))
    if params['calc_swd']:
        trainer.register_plugin(
            SlicedWDistance(dataset.progression_scale, params['SaverPlugin']['network_snapshot_ticks'],
                            **params['SlicedWDistance']))
    trainer.register_plugin(AbsoluteTimeMonitor())
    if params['Generator']['spectral']:
        trainer.register_plugin(WatchSingularValues(generator, **params['WatchSingularValues']))
    if params['Discriminator']['spectral']:
        trainer.register_plugin(WatchSingularValues(discriminator, **params['WatchSingularValues']))
    trainer.register_plugin(logger)
    # Record the values actually used so conf.yml fully describes this run.
    params['EEGDataset']['progression_scale_up'] = dataset.progression_scale_up
    params['EEGDataset']['progression_scale_down'] = dataset.progression_scale_down
    params['EEGDataset']['picked_channels'] = dataset.picked_channels
    params['DepthManager']['minibatch_override'] = dm.minibatch_override
    params['DepthManager']['tick_kimg_override'] = dm.tick_kimg_override
    params['DepthManager']['training_kimg_override'] = dm.training_kimg_override
    params['DepthManager']['transition_kimg_override'] = dm.transition_kimg_override
    # Close (and thereby flush) the config file before the long training run starts.
    with open(os.path.join(result_dir, 'conf.yml'), 'w') as conf_file:
        yaml.dump(params, conf_file)
    trainer.run(params['total_kimg'])
    del trainer


if __name__ == "__main__":
    # Classes handed to parse_config alongside default_params; main() later reads a
    # per-class sub-dict such as params['Generator'] or params['DepthManager'].
    need_arg_classes = [
        Trainer, Generator, Discriminator, Adam, OutputGenerator, DepthManager, SaverPlugin,
        SlicedWDistance, EfficientLossMonitor, EvalDiscriminator, EEGDataset, WatchSingularValues,
    ]
    main(
        parse_config(default_params, need_arg_classes, True, False)
    )
    print('training finished!')


loading train dataset from file
loading val dataset from file
 exp name: 
 current time: 1554705972.6557548
 training dataset shape: (9501, 5, 1920)
 val dataset shape: (1155, 5, 1920)
 Total number of parameters in Generator: 2976360
 Total number of parameters in Discriminator: 2978497
 tick:     1	kimg:    5.120	depth: 0	alpha: 1.00	minibatch_size: 256	time: 0:00:12.304734	sec.tick: 12.3	sec.kimg: 2.4	G_loss: 0.0427	D_loss: 8.7106	memorization.val: nan	memorization.epoch: 0
 tick:     2	kimg:   10.240	depth: 0	alpha: 1.00	minibatch_size: 256	time: 0:00:18.866216	sec.tick: 6.6 	sec.kimg: 1.3	G_loss: 0.1024	D_loss: 6.8374	memorization.val: nan	memorization.epoch: 0
 tick:     3	kimg:   15.360	depth: 0	alpha: 1.00	minibatch_size: 256	time: 0:00:25.456018	sec.tick: 6.6 	sec.kimg: 1.3	G_loss: 0.1803	D_loss: 4.3948	memorization.val: nan	memorization.epoch: 0
 tick:     4	kimg:   20.480	depth: 0	alpha: 1.00	minibatch_size: 256	time: 0:00:32.020246	sec.tick: 6.6 	sec.kimg: 1.3	G_loss: 0.328