In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import copy
from PIL import Image
import os
import time
import scipy
import gc

import math
import torch.nn.functional as F
import torchvision
import torch.nn as nn
from torchvision import models

from mpl_toolkits.axes_grid1 import ImageGrid
from skimage import exposure
from skimage.exposure import match_histograms
from datetime import datetime
from torchvision import models

# Function Definitions

In [None]:
def composition(im_path, IM_SIZE, cropped=True, center=(89, 121), crop_size=128):
    """Load an image, optionally crop a fixed square box from it, and resize it.

    Parameters
    ----------
    im_path : str
        Path of the image file.
    IM_SIZE : int
        Side length of the returned (square) image.
    cropped : bool
        If True, first crop a ``crop_size`` x ``crop_size`` box centred on
        ``center``.
    center : (int, int)
        (x, y) pixel coordinates of the crop centre. The default (89, 121) is
        the value this function always used; it is a fixed offset, presumably
        chosen for the face region of aligned CelebA images, and not
        necessarily the geometric centre of the image.
    crop_size : int
        Side length of the square crop (default 128, as before).

    Returns
    -------
    PIL.Image.Image
        RGB image of size IM_SIZE x IM_SIZE. It is independent of the
        underlying file, which is closed on return.
    """
    cx, cy = center
    half = crop_size // 2
    left, top = cx - half, cy - half

    with Image.open(im_path) as im1:
        # Downstream code reshapes pixels to (..., 3); force 3 channels so
        # grayscale/palette files do not break it. No-op for RGB images.
        img = im1.convert('RGB')
        if cropped:
            img = img.crop((left, top, left + crop_size, top + crop_size))
        return img.resize((IM_SIZE, IM_SIZE))

In [None]:
# Code for computing the FID score, adapted from https://github.com/mseitzer/pytorch-fid
class InceptionV3(nn.Module):
    """Pretrained torchvision Inception-v3 trunk that returns intermediate feature maps.

    The network is split into up to four sequential blocks; ``forward`` runs
    the input through them in order and collects the outputs of the blocks
    listed in ``output_blocks``.
    """

    # Block whose output is returned by default: the final average pooling.
    DEFAULT_BLOCK_INDEX = 3

    # Feature dimensionality -> index of the block that produces it.
    BLOCK_INDEX_BY_DIM = {
        64: 0,   # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-aux classifier features
        2048: 3  # Final average pooling features
    }

    def __init__(self,
                 output_blocks=[DEFAULT_BLOCK_INDEX],
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False):
        """Build the blocks needed for ``output_blocks``.

        Parameters
        ----------
        output_blocks : list of int
            Indices (0..3) of the blocks whose outputs ``forward`` returns.
            Only read, never mutated.
        resize_input : bool
            Bilinearly resize inputs to 299x299 before the first block.
        normalize_input : bool
            Map inputs from (0, 1) to (-1, 1) before the first block.
        requires_grad : bool
            Value assigned to ``requires_grad`` of every parameter.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        net = models.inception_v3(pretrained=True)

        # One builder per block, in block-index order; each returns the layer list.
        block_builders = (
            # Block 0: input to maxpool1
            lambda: [net.Conv2d_1a_3x3,
                     net.Conv2d_2a_3x3,
                     net.Conv2d_2b_3x3,
                     nn.MaxPool2d(kernel_size=3, stride=2)],
            # Block 1: maxpool1 to maxpool2
            lambda: [net.Conv2d_3b_1x1,
                     net.Conv2d_4a_3x3,
                     nn.MaxPool2d(kernel_size=3, stride=2)],
            # Block 2: maxpool2 to aux classifier
            lambda: [net.Mixed_5b, net.Mixed_5c, net.Mixed_5d,
                     net.Mixed_6a, net.Mixed_6b, net.Mixed_6c,
                     net.Mixed_6d, net.Mixed_6e],
            # Block 3: aux classifier to final avgpool
            lambda: [net.Mixed_7a, net.Mixed_7b, net.Mixed_7c,
                     nn.AdaptiveAvgPool2d(output_size=(1, 1))],
        )

        # Block 0 is always built; later blocks only up to last_needed_block.
        n_blocks_to_build = max(self.last_needed_block, 0) + 1
        for build in block_builders[:n_blocks_to_build]:
            self.blocks.append(nn.Sequential(*build()))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Run ``inp`` through the blocks and collect the requested feature maps.

        Parameters
        ----------
        inp : torch.Tensor
            Input of shape Bx3xHxW with values expected in range (0, 1).

        Returns
        -------
        list of torch.Tensor
            Outputs of the blocks in ``output_blocks``, in ascending block order.
        """
        features = inp

        if self.resize_input:
            features = F.interpolate(features,
                                     size=(299, 299),
                                     mode='bilinear',
                                     align_corners=False)

        if self.normalize_input:
            # Scale from range (0, 1) to range (-1, 1)
            features = 2 * features - 1

        selected = []
        for block_idx, block in enumerate(self.blocks):
            features = block(features)
            if block_idx in self.output_blocks:
                selected.append(features)
            if block_idx == self.last_needed_block:
                break

        return selected
    
def calculate_activation_statistics(images, model, batch_size=100, dims=2048,
                                    cuda=False):
    """Mean and covariance of ``model`` features over all of ``images``.

    Parameters
    ----------
    images : torch.Tensor
        Images sliced along dim 0 into batches (NCHW as built by
        ``calculate_fretchet``).
    model : callable with an ``eval()`` method
        ``model(batch)[0]`` must be a (B, dims, H, W) feature map; maps whose
        spatial size is not 1x1 are average-pooled to 1x1.
    batch_size : int
        Number of images per forward pass. The last batch may be smaller.
    dims : int
        Feature dimensionality.
    cuda : bool
        Move each batch to the GPU before the forward pass.

    Returns
    -------
    mu : np.ndarray, shape (dims,)
    sigma : np.ndarray, shape (dims, dims)
    """
    model.eval()
    n_images = len(images)
    act = np.empty((n_images, dims))

    with torch.no_grad():
        # Step through *all* images, including a trailing partial batch;
        # otherwise the tail rows of ``act`` would stay uninitialised memory
        # and still be folded into the statistics below.
        for start in range(0, n_images, batch_size):
            stop = min(start + batch_size, n_images)
            batch = images[start:stop]
            if cuda:
                batch = batch.cuda()

            pred = model(batch)[0]

            # If model output is not scalar, apply global spatial average pooling.
            # This happens if you choose a dimensionality not equal 2048.
            if pred.size(2) != 1 or pred.size(3) != 1:
                pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))

            act[start:stop] = pred.cpu().numpy().reshape(pred.size(0), -1)

    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Parameters
    ----------
    mu1, mu2 : array-like
        Mean vectors (must have equal shapes).
    sigma1, sigma2 : array-like
        Covariance matrices (must have equal shapes).
    eps : float
        Added to both diagonals if the first matrix square root is not finite.

    Returns
    -------
    float
        The squared Frechet distance d^2.

    Raises
    ------
    AssertionError
        If the means or the covariances have mismatched shapes.
    ValueError
        If the matrix square root has a non-negligible imaginary diagonal.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not load
    # ``scipy.linalg`` on SciPy releases without lazy submodule loading.
    from scipy import linalg

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give a slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)


def calculate_fretchet(images_real, images_fake, batch_size=10, cuda=None):
    """Compute the FID between two image sets using InceptionV3 pool features.

    Parameters
    ----------
    images_real, images_fake : np.ndarray
        Image stacks in NHWC layout (permuted to NCHW here). InceptionV3.forward
        expects values in range (0, 1).
    batch_size : int
        Batch size for the InceptionV3 forward passes.
    cuda : bool or None
        Run InceptionV3 on the GPU. ``None`` (default) uses CUDA when it is
        available; previously CUDA was always used, which failed on CPU-only
        machines.

    Returns
    -------
    float
        The Frechet distance between the two feature distributions (also printed).
    """
    if cuda is None:
        cuda = torch.cuda.is_available()

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
    model = InceptionV3([block_idx])
    if cuda:
        model = model.cuda()

    real = torch.from_numpy(images_real).float().permute(0, 3, 1, 2)
    fake = torch.from_numpy(images_fake).float().permute(0, 3, 1, 2)

    # The second return value is a covariance matrix, not a standard deviation.
    mu_real, sigma_real = calculate_activation_statistics(real, model, batch_size=batch_size, cuda=cuda)
    mu_fake, sigma_fake = calculate_activation_statistics(fake, model, batch_size=batch_size, cuda=cuda)

    # get fretchet distance
    fid_value = calculate_frechet_distance(mu_real, sigma_real, mu_fake, sigma_fake)
    print('FID score', fid_value)
    return fid_value

In [None]:
# Helpers for inspecting and clearing GPU memory between runs, if necessary
def pretty_size(size):
    """Render a ``torch.Size`` as its dimensions joined by " × " (e.g. "2 × 3")."""
    assert(isinstance(size, torch.Size))
    dims_as_text = [str(dim) for dim in size]
    return " × ".join(dims_as_text)

def dump_tensors(gpu_only=True):
    """Print the tensors tracked by the garbage collector and their total element count.

    Parameters
    ----------
    gpu_only : bool
        If True, only CUDA tensors (or objects whose ``.data`` is a CUDA
        tensor) are listed and counted.

    Prints one line per tensor (type, GPU/pinned/grad flags, shape) followed
    by ``Total size: <number of elements>``. Returns None.
    """
    def _pinned(tensor):
        # ``is_pinned`` is a method (the old code tested the bound method,
        # which is always truthy); guard builds where it may raise.
        try:
            return tensor.is_pinned()
        except RuntimeError:
            return False

    total_size = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                if not gpu_only or obj.is_cuda:
                    print("%s:%s%s %s" % (type(obj).__name__,
                                          " GPU" if obj.is_cuda else "",
                                          " pinned" if _pinned(obj) else "",
                                          pretty_size(obj.size())))
                    total_size += obj.numel()
            elif hasattr(obj, "data") and torch.is_tensor(obj.data):
                data = obj.data
                # Use the wrapped tensor's attributes: the wrapper itself need
                # not have ``is_cuda``. ``volatile`` no longer exists on
                # tensors and used to make this branch always raise.
                if not gpu_only or data.is_cuda:
                    print("%s → %s:%s%s%s %s" % (type(obj).__name__,
                                                 type(data).__name__,
                                                 " GPU" if data.is_cuda else "",
                                                 " pinned" if _pinned(data) else "",
                                                 " grad" if getattr(obj, "requires_grad", False) else "",
                                                 pretty_size(data.size())))
                    total_size += data.numel()
        except Exception:
            # Best effort: arbitrary gc-tracked objects may raise on attribute
            # access; skip them rather than abort the dump.
            continue
    print("Total size:", total_size)


In [None]:
# Layerwise training of ProCoGAN
def layerwise(truncation_indices, root,
              attribute_path=None, attribute_index=20, 
              init_resolution=4, final_resolution=128, 
              n_validation=1000, training_samples=50000, 
              training_gen_samples=50000, d_latent=48, 
              upsample=True, plot_histogram=False, 
              bias=True, match_hist=True, 
              cropped=False, use_attr=False, 
              not_attr=False, fit_upsampled=False,
              avg=False, save_img=True, skip_connection=False, device='cpu',
              save_root='..'):
    """Implementation of layerwise training of ProCoGAN.

    Blocks are fit one resolution at a time, from ``init_resolution`` doubling
    up to ``final_resolution``. For each block, fresh Gaussian latents are
    pushed through the already-fitted layers, and a new linear layer ``W_g`` is
    computed in closed form from the SVDs of that input and of the mean-centred
    real images. ``truncation_indices[i]`` singular directions are kept and
    ``beta`` is the next squared singular value. Generated samples are then
    shown next to real ones, and optionally saved or plotted as histograms.

    Parameters
    ----------
    truncation_indices : sequence of int
        Number of singular values of the real data kept at each block; needs
        at least one entry per block.
    root : str
        Directory of the CelebA images; every file in it is considered.
    attribute_path, attribute_index, use_attr, not_attr
        When ``use_attr`` is True and ``attribute_path`` is given, only images
        whose column ``attribute_index + 1`` of the attribute file is +1 are
        used (or -1 when ``not_attr`` is True).
    init_resolution, final_resolution : int
        First and last image side length; the resolution doubles per block.
    n_validation : int
        Number of samples generated after each block for display and FID.
    training_samples : int
        Maximum number of real images loaded.
    training_gen_samples : int
        Number of latent samples used to fit each layer. With
        ``skip_connection`` it must equal the number of loaded images.
    d_latent : int
        Latent dimension.
    upsample : bool
        Nearest-neighbour 2x upsampling between layers (forced off when
        ``fit_upsampled``).
    plot_histogram : bool
        Plot pixel-value histograms of nine generated samples.
    bias : bool
        Append a constant-one column to each layer input.
    match_hist : bool
        Histogram-match generated images to the real data before display,
        saving and FID.
    cropped : bool
        Forwarded to ``composition``.
    fit_upsampled : bool
        Fit every block at ``final_resolution`` on data resized to the current
        resolution and nearest-upsampled back.
    avg : bool
        Display the average of the first and second halves of the samples.
    save_img : bool
        Save up to 50 displayed images per block in a new timestamped folder
        under ``save_root``.
    skip_connection : bool
        Add each layer's (truncated) input to its output, and subtract the
        current generator output from the centred data before fitting.
    device : str
        ``'cuda'`` runs the products and SVDs with torch on the GPU; any other
        value uses NumPy on the CPU.
    save_root : str
        Parent directory of the saved-image folders (default ``'..'``, as before).

    Returns
    -------
    validation_generated_data : np.ndarray
        Last block's samples, shape (n_validation, H, W, 3), before averaging
        and histogram matching.
    weights : list of np.ndarray
        Fitted layer matrices, one per block.
    fid
        ``calculate_fretchet`` between the last block's real data and the
        displayed samples.
    """
    # NOTE(review): the message talks about the truncation dimension, but the
    # expression does not compare d_latent with truncation_indices[0]; confirm intent.
    assert d_latent >= 3/truncation_indices[0]*(init_resolution)**2, "latent dimension must be greater than truncation dimension"
    start_all = time.time()

    if fit_upsampled:
        # Every block is fit at final_resolution, so no upsampling between layers.
        upsample = False

    print('loading data')
    if use_attr and attribute_path is not None:
        attr_data = np.loadtxt(attribute_path, skiprows=2, dtype='str')
        fnames = attr_data[:, 0]
        attr_values = attr_data[:, attribute_index + 1].astype(np.int32)
        # Shifting by -1 / +1 turns the excluded value into 0 (False);
        # assumes entries are +1/-1 as in CelebA's attribute list.
        if not_attr:
            attrs = (attr_values - 1).astype('bool')
        else:
            attrs = (attr_values + 1).astype('bool')
        relevant_images = set(fnames[attrs])

        images_path = np.array([os.path.join(root, item) for item in os.listdir(root)
                                if item in relevant_images])
        print(len(images_path), 'images with desired attribute')
    else:
        images_path = np.array([os.path.join(root, item) for item in os.listdir(root)])

    weights = []
    num_blocks = int(np.log2(final_resolution // init_resolution)) + 1

    curr_images_path = images_path[:training_samples]
    lazy_arrays = [composition(fn, final_resolution, cropped=cropped) for fn in curr_images_path]

    curr_res = init_resolution

    for i in range(num_blocks):
        with torch.no_grad():
            print('res', curr_res, 'reshaping data')

            if fit_upsampled:
                curr_arrs = [np.array(im1.resize((curr_res, curr_res)).resize((final_resolution, final_resolution), resample=0),
                                      dtype='uint8') for im1 in lazy_arrays]
            else:
                curr_arrs = [np.array(im1.resize((curr_res, curr_res)), dtype='uint8') for im1 in lazy_arrays]
            image_data = np.stack(curr_arrs)
            celebA = image_data / 255

            celeb_a_means = np.mean(celebA, axis=0)
            celeb_a_centered = (celebA - celeb_a_means).reshape((celebA.shape[0], -1))
            d = celeb_a_centered.shape[1]

            # keep fraction of singular values for this experiment at each block
            truncation_idx = truncation_indices[i]

            print('computing SVD of inputs')
            Z_hat = np.random.randn(training_gen_samples, d_latent)
            if bias:
                Z_hat = np.concatenate((Z_hat, np.ones((len(Z_hat), 1))), 1)

            # Input of the new layer: latents through every fitted layer,
            # including upsampling/bias after the last one.
            Z_hat = _layerwise_forward(Z_hat, weights, init_resolution, upsample, bias,
                                       skip_connection, device, expand_after_last=True)

            start = time.time()
            if device == 'cuda':
                if curr_res < 64:
                    _, D, q = torch.svd(torch.from_numpy(Z_hat).float().cuda())
                else:
                    _, D, q = torch.svd_lowrank(torch.from_numpy(Z_hat).float().cuda(), q=truncation_idx)
                qt = q.t()
            else:
                # Only singular values and right singular vectors are used;
                # full_matrices=False avoids building an n x n left factor.
                _, D, qt = np.linalg.svd(Z_hat, full_matrices=False)
            del _
            print('svd took', time.time() - start)

            if skip_connection and i > 0:
                celeb_a_centered -= Z_hat[:, :d]

            print('computing svd')
            start = time.time()
            if device == 'cuda':
                if curr_res < 64:
                    _, s, v = torch.svd(torch.from_numpy(celeb_a_centered).float().cuda())
                else:
                    _, s, v = torch.svd_lowrank(torch.from_numpy(celeb_a_centered).float().cuda(), q=truncation_idx+1)
                vt = v.t()
            else:
                _, s, vt = np.linalg.svd(celeb_a_centered, full_matrices=False)
            del _
            print('svd took', time.time() - start)

            beta = s[truncation_idx]**2
            print('beta', beta)

            # W_g = Q_k diag(1 / D_k) diag(sqrt(s_k^2 - beta)) V_k^T with k = truncation_idx
            if device == 'cuda':
                W_g = qt[:truncation_idx, :].t() @ torch.diag(1/D[:truncation_idx]) @ torch.diag(torch.sqrt(s[:truncation_idx]**2-beta)) @ vt[:truncation_idx, :]
                W_g = W_g.detach().cpu().numpy()
            else:
                W_g = qt[:truncation_idx, :].T @ np.diag(1/D[:truncation_idx]) @ np.diag(np.sqrt(s[:truncation_idx]**2-beta)) @ vt[:truncation_idx, :]

            weights.append(W_g)

            print('computing validation instances')
            out = np.random.randn(n_validation, d_latent)
            if bias:
                out = np.concatenate((out, np.ones((len(out), 1))), 1)

            # Full generator: no upsampling/bias after the final layer.
            out = _layerwise_forward(out, weights, init_resolution, upsample, bias,
                                     skip_connection, device, expand_after_last=False)

            out_res = final_resolution if fit_upsampled else curr_res
            validation_generated_data = out.reshape((n_validation, out_res, out_res, 3))
            validation_generated_data += celeb_a_means

            display_img = validation_generated_data

            if avg:
                display_img = (display_img[:n_validation//2] + display_img[n_validation//2:])/2

            if match_hist:
                display_img = _match_histograms_rgb(display_img, celebA)

            print('real data')
            _show_image_grid(celebA)
            plt.show()

            print('generated data')
            _show_image_grid(display_img)

            if plot_histogram:
                print('histogram of generated data samples')
                fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(10., 10.))
                for ax, sample in zip(axes.flat, validation_generated_data[:9]):
                    ax.hist(sample.flatten())

            if save_img:
                print('saving images')
                # Filesystem-safe timestamp (str(datetime.now()) has spaces and colons).
                stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')
                dirname = os.path.join(save_root, 'generated_images' + stamp)
                os.makedirs(dirname, exist_ok=True)
                # Separate index name so the block index ``i`` is not shadowed.
                for img_idx, im in enumerate(display_img[:50]):
                    # imsave rejects float RGB values outside [0, 1].
                    plt.imsave(os.path.join(dirname, str(img_idx) + '.png'), np.clip(im, 0, 1))

            plt.show()
            curr_res *= 2

    print('total running time', time.time()-start_all)

    print('calculating FID score')
    fid = calculate_fretchet(celebA, display_img, batch_size=20)

    return validation_generated_data, weights, fid


def _layerwise_forward(x, weights, init_resolution, upsample, bias,
                       skip_connection, device, expand_after_last):
    """Apply the fitted linear layers in ``weights`` to the rows of ``x``.

    After every layer (every layer but the last when ``expand_after_last`` is
    False), the output is optionally nearest-neighbour upsampled 2x, viewing
    rows as (size, size, 3) images, and extended with a constant-one bias
    column. With ``skip_connection``, every layer after the first adds the
    leading columns of its input to its output. With ``device == 'cuda'`` the
    products run on the GPU; results are converted back to NumPy. With no
    weights, ``x`` is returned unchanged.
    """
    tmp_im_size = init_resolution
    n_layers = len(weights)
    for j, w in enumerate(weights):
        if device == 'cuda':
            x = torch.from_numpy(x).float().cuda()
            w = torch.from_numpy(w).float().cuda()

        if skip_connection and j > 0:
            x_prev = x
            x = x @ w
            x = x + x_prev[:, :x.shape[1]]
        else:
            x = x @ w

        if device == 'cuda':
            x = x.detach().cpu().numpy()

        if expand_after_last or j < n_layers - 1:
            if upsample:
                x = x.reshape((x.shape[0], tmp_im_size, tmp_im_size, 3))
                x = x.repeat(2, axis=1).repeat(2, axis=2)
                x = x.reshape((x.shape[0], -1))
                tmp_im_size *= 2
            if bias:
                x = np.concatenate((x, np.ones((len(x), 1))), 1)
    return x


def _match_histograms_rgb(images, reference):
    """Histogram-match ``images`` to ``reference`` channel-wise along the last axis."""
    try:
        # Newer scikit-image spells this ``channel_axis`` and no longer
        # accepts ``multichannel``.
        return match_histograms(images, reference, channel_axis=-1)
    except TypeError:
        # Older scikit-image only knows ``multichannel``.
        return match_histograms(images, reference, multichannel=True)


def _show_image_grid(images):
    """Draw the first nine ``images`` on a 3x3 ImageGrid and return the figure (not shown)."""
    fig = plt.figure(figsize=(10., 10.))
    grid = ImageGrid(fig, 111,  # similar to subplot(111)
                     nrows_ncols=(3, 3),
                     axes_pad=0.1,  # pad between axes in inch.
                     )
    for ax, im in zip(grid, images[:9]):
        ax.imshow(im)
    return fig

# Example Run

In [None]:
# Set the CelebA image directory here, or via the CELEBA_ROOT environment variable
# (the hard-coded path below is only the fallback).
root = os.environ.get("CELEBA_ROOT", "/mnt/raid3/sahiner/img_align_celeba")

In [None]:
# Runs 3 times with the lower beta values reported in Figure 3 of the main paper.
# Set SAVE_IMAGES = True to save 50 representative images from each resolution
# (first run only). Clears GPU memory between each run.

runs_low_beta = 3
SAVE_IMAGES = False  # both branches of the former `if i == 0` set this to False
fid_scores = []      # FID of every run; earlier runs' full results are not kept (weights are large)

for i in range(runs_low_beta):
    # Only the first run may write images, so repeated runs don't create duplicate folders.
    saveimg = SAVE_IMAGES and i == 0
    res = layerwise(truncation_indices=[20, 30, 40, 100, 175],
                    root=root, init_resolution=4, 
                    final_resolution=64, 
                    training_samples=50000, 
                    training_gen_samples=50000, d_latent=48, 
                    upsample=True, plot_histogram=False, 
                    bias=True, match_hist=True, 
                    cropped=False, not_attr=True, use_attr=False, 
                    fit_upsampled=False, avg=False, save_img=saveimg, 
                    skip_connection=False, device='cuda')
    fid_scores.append(res[2])

    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    gc.collect()
    dump_tensors()