Improve docs, update scripts
ajbrock committed Mar 12, 2019
1 parent ec17149 commit b588d62
Showing 40 changed files with 459 additions and 401 deletions.
File renamed without changes.
18 changes: 12 additions & 6 deletions calculate_inception_moments.py
@@ -1,6 +1,11 @@
-# This script iterates over the dataset and calculates the moments of the
-# activations of the Inception net (needed for FID), and also returns
-# the Inception Score of the training data.
+''' Calculate Inception Moments
+This script iterates over the dataset and calculates the moments of the
+activations of the Inception net (needed for FID), and also returns
+the Inception Score of the training data.
+Note that if you don't shuffle the data, the IS of true data will be under-
+estimated as it is label-ordered. By default, the data is not shuffled
+so as to reduce non-determinism. '''
import numpy as np
import torch
import torch.nn as nn
@@ -15,11 +20,11 @@ def prepare_parser():
usage = 'Calculate and store inception metrics.'
parser = ArgumentParser(description=usage)
parser.add_argument(
-'--dataset', type=str, default='I128',
+'--dataset', type=str, default='I128_hdf5',
help='Which Dataset to train on, out of I128, I256, C10, C100...'
'Append _hdf5 to use the hdf5 version of the dataset. (default: %(default)s)')
parser.add_argument(
-'--dataset_root', type=str, default='/home/s1580274/scratch/data/',
+'--dataset_root', type=str, default='data',
help='Default location where data is stored (default: %(default)s)')
parser.add_argument(
'--batch_size', type=int, default=64,
@@ -67,7 +72,8 @@ def run(config):
print('Calculating inception metrics...')
IS_mean, IS_std = inception_utils.calculate_inception_score(logits)
print('Training data from dataset %s has IS of %5.5f +/- %5.5f' % (config['dataset'], IS_mean, IS_std))
-# Prepare mu and sigma, save to disk
+# Prepare mu and sigma, save to disk. Remove "hdf5" by default
+# (the FID code also knows to strip "hdf5")
print('Calculating means and covariances...')
mu, sigma = np.mean(pool, axis=0), np.cov(pool, rowvar=False)
print('Saving calculated means and covariances to disk...')
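For context, the moments this script stores are just the empirical mean and covariance of the pooled Inception features; a minimal NumPy sketch of what run() computes, with a stand-in feature array and an assumed output filename:

```python
import numpy as np

# Stand-in for the (N, 2048) Inception pool features the script accumulates.
pool = np.random.randn(1000, 2048)

mu, sigma = np.mean(pool, axis=0), np.cov(pool, rowvar=False)
# The '_hdf5' suffix is stripped from the dataset name so the HDF5 and
# ImageFolder variants of a dataset share one moments file (filename assumed).
np.savez('I128_inception_moments.npz', mu=mu, sigma=sigma)
```

These are the mu/sigma that the FID code later compares against the moments of generated samples.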
36 changes: 12 additions & 24 deletions datasets.py
@@ -1,3 +1,6 @@
+''' Datasets
+This file contains definitions for our CIFAR, ImageFolder, and HDF5 datasets
+'''
import os
import os.path
import sys
@@ -10,8 +13,7 @@
from torchvision.datasets.utils import download_url, check_integrity
import torch.utils.data as data
from torch.utils.data import DataLoader

-# Stuff for full imagenet

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


@@ -106,9 +108,12 @@ def __init__(self, root, transform=None, target_transform=None,
loader=default_loader, load_in_mem=False,
index_filename='imagenet_imgs.npz', **kwargs):
classes, class_to_idx = find_classes(root)
+# Load pre-computed image directory walk
if os.path.exists(index_filename):
print('Loading pre-saved Index file %s...' % index_filename)
imgs = np.load(index_filename)['imgs']
+# If first time, walk the folder directory and save the
+# results to a pre-computed file.
else:
print('Generating Index file %s...' % index_filename)
imgs = make_dataset(root, class_to_idx)
@@ -171,19 +176,9 @@ def __repr__(self):
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str


-# class ImageNetA(ImageFolder):
-# def __init__(self, root, transform=None, target_transform=None,
-# loader=default_loader, load_in_mem=False,
-# train=True,download=False, validate_seed=0,
-# val_split=0):
-# super(ImageNetA, self).__init__(root, transform, target_transform,
-# default_loader, load_in_mem, train, download, validate_seed,
-# val_split):
-
-# Imagenet at 256 with '/home/s1580274/scratch/ILSVRC256.hdf5'
+''' ILSVRC_HDF5: A dataset to support I/O from an HDF5 to avoid
+having to load individual images all the time. '''
import h5py as h5
import torch
class ILSVRC_HDF5(data.Dataset):
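The class body is collapsed in this view; a minimal sketch of the lazy HDF5 I/O pattern it implements (the class name and the 'imgs'/'labels' key names here are illustrative assumptions):

```python
import h5py as h5
import torch
import torch.utils.data as data

class HDF5Dataset(data.Dataset):
    ''' Sketch: read images and labels lazily from an HDF5 file. '''
    def __init__(self, root):
        self.root = root
        with h5.File(root, 'r') as f:
            self.num_imgs = len(f['labels'])

    def __getitem__(self, index):
        # Re-open per access; avoids holding an open handle across workers.
        with h5.File(self.root, 'r') as f:
            img = torch.from_numpy(f['imgs'][index]).float() / 255.
            target = int(f['labels'][index])
        return 2 * img - 1, target  # scale uint8 images to [-1, 1]

    def __len__(self):
        return self.num_imgs
```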
@@ -250,7 +245,7 @@ class CIFAR10(dset.CIFAR10):
def __init__(self, root, train=True,
transform=None, target_transform=None,
download=True, validate_seed=0,
-val_split=0, load_in_mem=True):
+val_split=0, load_in_mem=True, **kwargs):
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
@@ -264,9 +259,7 @@ def __init__(self, root, train=True,
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

-# now load the picked numpy arrays
-
-
+# now load the pickled numpy arrays
self.data = []
self.labels = []
for fentry in self.train_list:
@@ -294,15 +287,10 @@ def __init__(self, root, train=True,

# randomly grab 500 elements of each class
np.random.seed(validate_seed)

self.val_indices = []



for l_i in label_indices:
self.val_indices += list(l_i[np.random.choice(len(l_i), int(len(self.data) * val_split) // (max(self.labels) + 1), replace=False)])



if self.train=='validate':
self.data = self.data[self.val_indices]
self.labels = list(np.asarray(self.labels)[self.val_indices])
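The val_indices one-liner above packs the whole split into one expression; unrolled, the per-class validation selection amounts to something like this (function and variable names are illustrative):

```python
import numpy as np

def make_val_indices(labels, val_split, validate_seed=0):
    # Draw an equal share of validation indices from each class.
    labels = np.asarray(labels)
    num_classes = labels.max() + 1
    per_class = int(len(labels) * val_split) // num_classes
    np.random.seed(validate_seed)
    return np.concatenate(
        [np.random.choice(np.where(labels == c)[0], per_class, replace=False)
         for c in range(num_classes)])
```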
23 changes: 12 additions & 11 deletions inception_tf13.py
@@ -1,7 +1,11 @@
-## Tensorflow inception score code
-# Derived from https://github.com/openai/improved-gan
-# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
-# THIS CODE REQUIRES TENSORFLOW 1.3 or EARLIER to run in BATCH MODE
+''' Tensorflow inception score code
+Derived from https://github.com/openai/improved-gan
+Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
+THIS CODE REQUIRES TENSORFLOW 1.3 or EARLIER to run in PARALLEL BATCH MODE
+To use this code, run sample.py on your model with --sample_npz, and then
+pass the experiment name via the --experiment_name argument.
+'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -17,20 +21,18 @@
from six.moves import urllib
import tensorflow as tf

-MODEL_DIR = '/home/s1580274/scratch'
+MODEL_DIR = ''
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
softmax = None

-# Run on eddie with /home/s1580274/group/myconda/tensorflow/bin/python
-#fname='BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs128_nDa4_nGa4_Glr1.0e-04_Dlr4.0e-04_Gnlrelu_Dnlrelu_Ginitxavier_Dinitxavier_Gattn64_Dattn64_Gshared_ema_SAGAN_bs128x4_ema'
def prepare_parser():
usage = 'Parser for TF1.3- Inception Score scripts.'
parser = ArgumentParser(description=usage)
parser.add_argument(
'--experiment_name', type=str, default='',
help="Which experiment's samples.npz file to pull and evaluate")
parser.add_argument(
-'--experiment_root', type=str, default='/home/s1580274/scratch/samples/',
+'--experiment_root', type=str, default='samples',
help='Default location where samples are stored (default: %(default)s)')
parser.add_argument(
'--batch_size', type=int, default=500,
@@ -110,13 +112,12 @@ def _progress(count, block_size, total_size):
logits = tf.matmul(tf.squeeze(pool3), w)
softmax = tf.nn.softmax(logits)

-# if softmax is None:
+# if softmax is None: # No need to functionalize like this.
_init_inception()

fname = '%s/%s/samples.npz' % (config['experiment_root'], config['experiment_name'])
print('loading %s ...'%fname)
-ims = np.load(fname)['x']# + '_samples.npz')['x']
-# ims =
+ims = np.load(fname)['x']
import time
t0 = time.time()
inc = get_inception_score(list(ims.swapaxes(1,2).swapaxes(2,3)), splits=10)
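For reference, the score this script computes is IS = exp(E_x[KL(p(y|x) || p(y))]), averaged over splits; a minimal NumPy sketch of the same computation, assuming preds holds softmaxed Inception outputs:

```python
import numpy as np

def inception_score(preds, num_splits=10):
    # preds: (N, 1000) softmax outputs from the Inception net.
    scores = []
    for part in np.array_split(preds, num_splits):
        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
    return np.mean(scores), np.std(scores)
```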
25 changes: 21 additions & 4 deletions inception_utils.py
@@ -1,5 +1,18 @@
+''' Inception utilities
+This file contains methods for calculating IS and FID, using either
+the original numpy code or an accelerated fully-PyTorch version that
+uses a fast Newton-Schulz approximation for the matrix sqrt. There are also
+methods for acquiring a desired number of samples from the Generator,
+and parallelizing the inbuilt PyTorch inception network.
+NOTE that Inception Scores and FIDs calculated using these methods will
+*not* be directly comparable to values calculated using the original TF
+IS/FID code. You *must* use the TF model if you wish to report and compare
+numbers. This code tends to produce IS values that are 5-10% lower than
+those obtained through TF.
+'''
import numpy as np
-from scipy import linalg # For FID
+from scipy import linalg # For numpy FID
import time

import torch
@@ -8,6 +21,7 @@
from torch.nn import Parameter as P
from torchvision.models.inception import inception_v3


# Module that wraps the inception network to enable use with dataparallel and
# returning pool features and logits.
class WrapInception(nn.Module):
@@ -69,6 +83,7 @@ def forward(self, x):
# 1000 (num_classes)
return pool, logits


# A pytorch implementation of cov, from Modar M. Alfadly
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
def torch_cov(m, rowvar=False):
@@ -103,9 +118,9 @@ def torch_cov(m, rowvar=False):
mt = m.t() # if complex: mt = m.t().conj()
return fact * m.matmul(mt).squeeze()


# Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji
# https://github.com/msubhransu/matrix-sqrt

def sqrt_newton_schulz(A, numIters, dtype=None):
with torch.no_grad():
if dtype is None:
@@ -123,9 +138,9 @@ def sqrt_newton_schulz(A, numIters, dtype=None):
sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
return sA


# FID calculator from TTUR--consider replacing this with GPU-accelerated cov
# calculations using torch?

def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
"""Numpy implementation of the Frechet Distance.
Taken from https://github.com/bioinf-jku/TTUR
@@ -180,7 +195,7 @@ def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):

out = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
return out
-# # return () +


def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
"""Pytorch implementation of the Frechet Distance.
@@ -214,6 +229,8 @@ def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
out = (diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2)
- 2 * torch.trace(covmean))
return out


# Calculate Inception Score mean + std given softmax'd logits and number of splits
def calculate_inception_score(pred, num_splits=10):
scores = []
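The accelerated FID path above computes FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2(sigma1 sigma2)^(1/2)), replacing scipy's sqrtm with the Newton-Schulz iteration. A single-matrix (unbatched) sketch of that iteration, simplified from sqrt_newton_schulz:

```python
import torch

def newton_schulz_sqrt(A, num_iters=50):
    # A: (n, n) positive semi-definite matrix; returns approx. sqrtm(A).
    norm_a = A.norm()                 # Frobenius norm, used to pre-scale A
    Y = A / norm_a
    I = torch.eye(A.size(0), dtype=A.dtype, device=A.device)
    Z = torch.eye(A.size(0), dtype=A.dtype, device=A.device)
    for _ in range(num_iters):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y = Y @ T                     # converges to sqrt(A / ||A||)
        Z = T @ Z                     # converges to the inverse square root
    return Y * norm_a.sqrt()          # undo the pre-scaling
```

Because the iteration is all matmuls, it runs on the GPU and avoids the host round-trip that scipy's sqrtm requires.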
18 changes: 12 additions & 6 deletions layers.py
@@ -1,3 +1,6 @@
+''' Layers
+This file contains various layers for the BigGAN models.
+'''
import numpy as np
import torch
import torch.nn as nn
@@ -8,6 +11,7 @@

from sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d


# Projection of x onto y
def proj(x, y):
return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
@@ -45,11 +49,13 @@ def power_iteration(W, u_, update=True, eps=1e-12):
#svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
return svs, us, vs


# Convenience passthrough function
class identity(nn.Module):
def forward(self, input):
return input


# Spectral normalization base class
class SN(object):
def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
@@ -105,6 +111,7 @@ def forward(self, x):
return F.conv2d(x, self.W_(), self.bias, self.stride,
self.padding, self.dilation, self.groups)


# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
def __init__(self, in_features, out_features, bias=True,
@@ -114,6 +121,7 @@ def __init__(self, in_features, out_features, bias=True,
def forward(self, x):
return F.linear(x, self.W_(), self.bias)


# Embedding layer with spectral norm
# We use num_embeddings as the dim instead of embedding_dim here
# for convenience's sake
@@ -319,7 +327,8 @@ def extra_repr(self):
s = 'out: {output_size}, in: {input_size},'
s +=' cross_replica={cross_replica}'
return s.format(**self.__dict__)



# Normal, non-class-conditional BN
class bn(nn.Module):
def __init__(self, output_size, eps=1e-5, momentum=0.1,
@@ -355,15 +364,14 @@ def forward(self, x, y=None):
else:
return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
self.bias, self.training, self.momentum, self.eps)



# Generator blocks
# Note that this class assumes the kernel size and padding (and any other
# settings) have been selected in the main generator module and passed in
# through the which_conv arg. Similar rules apply with which_bn (the input
# size [which is actually the number of channels of the conditional info] must
# be preselected)
""" Andy's note: I changed activation to NONE to enforce passing in an activation
"""
class GBlock(nn.Module):
def __init__(self, in_channels, out_channels,
which_conv=nn.Conv2d, which_bn=bn, activation=None,
@@ -401,8 +409,6 @@ def forward(self, x, y):


# Residual block for the discriminator
""" Andy's note: I changed activation to NONE to enforce passing in an activation
"""
class DBlock(nn.Module):
def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
preactivation=False, activation=None, downsample=None,):
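For reference, the power_iteration routine that the SN classes build on estimates the leading singular value of a weight matrix by alternating normalized matrix-vector products; a minimal single-vector sketch (the function name is illustrative; the repo version tracks multiple singular values and updates u_ in place):

```python
import torch
import torch.nn.functional as F

def estimate_spectral_norm(W, u, num_itrs=1, eps=1e-12):
    # W: (out_features, in_features) weight; u: (1, out_features) persistent vector.
    for _ in range(num_itrs):
        v = F.normalize(torch.matmul(u, W), eps=eps)      # right singular vector estimate
        u = F.normalize(torch.matmul(v, W.t()), eps=eps)  # left singular vector estimate
    sigma = torch.sum(torch.matmul(u, W) * v)             # leading singular value estimate
    return sigma, u, v
```

Dividing W by this sigma yields the spectrally normalized weight that SNConv2d and SNLinear use in their forward passes.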
