In [None]:
import os
import sys
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image as im
import numpy as np
import subprocess
import matplotlib.pyplot as plt
import sklearn
from scipy import stats
import pickle
%matplotlib inline

In [None]:
class AddGaussianNoise(object):
    """Transform that adds Gaussian noise to a tensor.

    Calling the instance returns ``tensor + noise * std + mean`` where
    ``noise`` is standard-normal and has the same shape as ``tensor``.

    Args:
        mean: constant offset added to every element.
        std: scale applied to the standard-normal noise.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # randn_like draws the noise with the input's dtype and device, so the
        # transform also works on tensors that are not default-dtype CPU tensors.
        noise = torch.randn_like(tensor)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

### Test

In [None]:
# Preprocessing shared by both splits: ToTensor followed by additive
# Gaussian noise (mean 0, std 0.1).
transform = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(0, 0.1),
])

# Training split, served one image per batch in dataset order (no shuffling).
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=1, shuffle=False, num_workers=2)

# Test split, loaded with the same settings.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=1, shuffle=False, num_workers=2)

# Human-readable names, indexed by integer class label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
# functions to show an image

def imshow(img):
    """Show an image tensor with matplotlib.

    The tensor is converted to a NumPy array and its axes are reordered from
    (0, 1, 2) to (1, 2, 0) before plotting -- presumably (C, H, W) to
    (H, W, C); confirm against the caller. No un-normalisation is applied.
    """
    channels_last = img.numpy().transpose((1, 2, 0))
    plt.imshow(channels_last)
    plt.show()


# Fetch the first batch of training images. `trainloader` is created with
# shuffle=False, so these are the first images of the split, not random ones.
dataiter = iter(trainloader)
# Use the built-in next(): current DataLoader iterators implement __next__
# but no longer expose a `.next()` method.
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(images.shape[0])))

### Define main functions

In [None]:
import gs

In [None]:
def get_rlts(X):
    """Compute RLTs for ``X`` by delegating to ``gs.rlts``.

    The call uses ``n = 2500``, ``gamma = (1/128) / (n/5000)`` and
    ``n_threads = 40``; whatever ``gs.rlts`` returns is passed back unchanged.
    """
    n_iter = 2500
    return gs.rlts(
        X,
        gamma=(1 / 128) / (n_iter / 5000),
        n=n_iter,
        n_threads=40,
    )

In [None]:
# Class label of every training image, in dataset order.
# Read the labels straight from the dataset instead of iterating the
# DataLoader: this avoids decoding and transforming every image just to get
# its label, and removes the hard-coded dataset length (50000).
labels_idx = [int(target) for target in trainset.targets]

In [None]:
def get_idx_with_label(cl):
    """Return, in ascending order, the positions in the module-level
    ``labels_idx`` list whose label equals ``cl``."""
    return [position for position, label in enumerate(labels_idx) if label == cl]

In [None]:
def get_statified(num):
    """Collect up to ``num`` sample indices for each label 0..9.

    For every label, the first ``num`` entries of ``get_idx_with_label`` are
    taken; the per-label groups are concatenated in label order.
    """
    return [
        sample_idx
        for class_id in range(10)
        for sample_idx in get_idx_with_label(class_id)[:num]
    ]

In [None]:
# Quick check: with num=1 the result holds at most one index per label 0..9.
get_statified(1)

In [None]:
def cifar10_filtered(allowed_labes = None, transforms_list = None, train = True):
    """Yield CIFAR-10 samples whose label is in ``allowed_labes``.

    Args:
        allowed_labes: container of integer labels to keep. ``None`` (the
            default) keeps all labels 0..9.
        transforms_list: extra transforms applied after ``ToTensor``.
            ``None`` (the default) means no extra transforms.
        train: passed to ``torchvision.datasets.CIFAR10`` to select the split.

    Yields:
        ``(image, label)`` pairs exactly as produced by a ``DataLoader`` with
        ``batch_size=1`` and ``shuffle=False``, i.e. in dataset order.
    """
    # None sentinels instead of mutable default arguments, which would be
    # shared between calls.
    if allowed_labes is None:
        allowed_labes = set(range(10))
    if transforms_list is None:
        transforms_list = []

    transform = transforms.Compose(
        [transforms.ToTensor()] + list(transforms_list))

    dataset = torchvision.datasets.CIFAR10(root='./data', train=train,
                                           download=True, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                         shuffle=False, num_workers=1)
    for image, label in loader:
        if label.item() in allowed_labes:
            yield image, label

In [None]:
# Reproducible permutation of the indices 0..49999 (global NumPy seed 7);
# the first 10000 entries of the permutation are kept as a set.
np.random.seed(7)

all_idx = list(range(50000))
np.random.shuffle(all_idx)

ten_thousands_idx = set(all_idx[:10000])

In [None]:
def write_dir(adir, allowed_labels, selected_idx = None, transforms_list = None, train = True, max_cnt = None):
    """Dump filtered CIFAR-10 images to ``adir`` as PNGs and return them flattened.

    ``adir`` is deleted and recreated on every call. Images come from
    ``cifar10_filtered(allowed_labels, transforms_list, train)``; ``idx`` is the
    position of an image within that filtered stream and is used as file name.

    Args:
        adir: output directory (removed first if it exists).
        allowed_labels: labels to keep, forwarded to ``cifar10_filtered``.
        selected_idx: optional collection of stream positions to keep; other
            positions are skipped. ``None`` keeps every position; an empty
            collection keeps nothing.
        transforms_list: extra transforms forwarded to ``cifar10_filtered``;
            ``None`` means no extra transforms.
        train: which CIFAR-10 split to read.
        max_cnt: stop as soon as ``idx`` reaches this value. ``None`` means no
            limit; ``0`` writes nothing.

    Returns:
        ``np.ndarray`` with one flattened (unquantised) image per row.
    """
    import shutil  # function-scope import keeps this cell self-contained

    # Start from an empty directory. shutil avoids building a shell command
    # from `adir` (breaks on spaces / shell metacharacters, POSIX-only).
    shutil.rmtree(adir, ignore_errors=True)
    os.mkdir(adir)

    if transforms_list is None:
        transforms_list = []
    # A set gives O(1) membership tests; the callers pass lists of up to
    # thousands of indices that are checked once per image.
    wanted = None if selected_idx is None else set(selected_idx)

    idx = 0
    X = []

    for image, label in cifar10_filtered(allowed_labels, transforms_list, train):
        if wanted is not None and idx not in wanted:
            idx += 1
            continue

        if max_cnt is not None and idx == max_cnt:
            break

        # cifar10_filtered loads with batch_size=1, so averaging over axis 0
        # just drops the batch axis.
        npimg = image.numpy().mean(axis = 0)
        hwc = np.transpose(npimg, (1, 2, 0))
        # Round and clip before the uint8 cast: a bare cast truncates and
        # wraps around for values outside [0, 255], which does happen once
        # noise has been added to the image.
        pixels = np.clip(np.rint(255 * hwc), 0, 255).astype('uint8')
        im.fromarray(pixels).save(os.path.join(adir, '%d.png' % idx))

        X.append(npimg.flatten())
        idx += 1

    print('num points', len(X))
    return np.array(X)

In [None]:
def run_exp(func_exp, device = 'cuda:1'):
    """Run one experiment generator and collect FID output and RLTs per step.

    ``func_exp()`` must yield ``(X_base, X)`` pairs. For every pair the function
    appends it to the module-level ``archive``, computes RLTs, and runs
    ``pytorch-fid tmp1 tmp2`` (the directories the experiment generators in
    this notebook write to).

    Args:
        func_exp: zero-argument callable returning an iterable of
            ``(X_base, X)`` pairs.
        device: value passed to ``pytorch-fid --device``.

    Returns:
        ``(res, rlts)`` where ``res`` is a list of ``(step, stdout)`` tuples
        and ``rlts`` holds the RLTs of the first ``X_base`` followed by the
        RLTs of each ``X``.
    """
    rlts = []
    res = []

    for step, (X_base, X) in enumerate(func_exp()):

        archive.append((X_base, X))

        # The baseline RLTs are computed once, from the first pair only.
        if not rlts:
            rlts.append(get_rlts(X_base))

        rlts.append(get_rlts(X))

        cmd = ['pytorch-fid', 'tmp1', 'tmp2', '--device', device]
        proc = subprocess.run(cmd, capture_output=True, text=True)
        if proc.returncode != 0:
            # Surface the failure instead of silently recording empty output;
            # the step is still recorded so indices stay aligned with rlts.
            print('pytorch-fid failed (exit code %d) at step %d:\n%s'
                  % (proc.returncode, step, proc.stderr), file=sys.stderr)

        res.append((step, proc.stdout))

    return res, rlts

In [None]:
def get_q95(rlts_base, rlts):
    """Bootstrap threshold for the baseline RLTs.

    1000 times, ``rlts_base`` is resampled with ``sklearn.utils.resample``
    (default arguments) and the squared distance between the resample's mean
    and the mean of ``rlts_base`` (both over axis 0) is recorded. The value at
    position ``int(1000 * 0.95)`` of the sorted distances is returned, scaled
    by 1e3.

    ``rlts`` is accepted for signature compatibility but is not used.
    """
    base_mean = np.mean(rlts_base, axis=0)

    def _resampled_distance():
        boot_mean = np.mean(sklearn.utils.resample(rlts_base), axis=0)
        return np.sum((base_mean - boot_mean) ** 2)

    distances = sorted(_resampled_distance() for _ in range(1000))
    cutoff = int(len(distances) * 0.95)

    return 1e3 * distances[cutoff]

In [None]:
def print_stat(rlts):
    """Print summary statistics for a list of RLTs.

    ``rlts[0]`` is treated as the baseline: its ``get_q95`` value is printed
    first, then, after a blank line, ``1e3 * gs.geom_score(baseline, r)`` for
    every remaining entry ``r``.
    """
    baseline = rlts[0]

    print('q95', get_q95(baseline, None))
    print()

    for candidate in rlts[1:]:
        print(1e3 * gs.geom_score(baseline, candidate))

In [None]:
# Collects every (X_base, X) pair processed by run_exp, across all experiments.
# Must exist before run_exp is called.
archive = []

### Mode drop

In [None]:
def mode_drop_exp():
    """Yield ``(X_base, X)`` pairs with progressively more classes removed.

    ``X_base`` is written once to ``tmp1`` from the test split with all ten
    labels. Step ``k`` (k = 0..4) writes up to 10000 training images to ``tmp2``
    using all labels except ``0..k-1``, so the first step drops nothing.
    """
    every_label = set(range(10))
    X_base = write_dir('tmp1', every_label, train = False)

    for n_dropped in range(5):
        kept = every_label - set(range(n_dropped))
        yield X_base, write_dir('tmp2', kept, max_cnt = 10000)

In [None]:
# Run the mode-drop experiment: res_drop gets (step, pytorch-fid stdout)
# pairs, rlts_drop gets the RLTs with the baseline first.
res_drop, rlts_drop = run_exp(mode_drop_exp)

In [None]:
# (step, pytorch-fid stdout) for each mode-drop step.
res_drop

In [None]:
# Baseline q95 threshold, then 1e3 * geometry score of each step vs. the baseline.
print_stat(rlts_drop)

### Mode drop by class

In [None]:
def mode_drop_exp2():
    """Yield ``(X_base, X)`` pairs, each with exactly one class left out.

    ``X_base`` is written once to ``tmp1`` from the test split with all ten
    labels. For each label 0..9 in turn, up to 10000 training images of the
    other nine labels are written to ``tmp2``.
    """
    all_labels = set(range(10))
    X_base = write_dir('tmp1', all_labels, train = False)

    for dropped in range(10):
        kept = {label for label in all_labels if label != dropped}
        yield X_base, write_dir('tmp2', kept, max_cnt = 10000)

In [None]:
# Run the per-class mode-drop experiment: res_drop2 gets (step, pytorch-fid
# stdout) pairs, rlts_drop2 gets the RLTs with the baseline first.
res_drop2, rlts_drop2 = run_exp(mode_drop_exp2)

In [None]:
# (step, pytorch-fid stdout) for each left-out class.
res_drop2

In [None]:
# Baseline q95 threshold, then 1e3 * geometry score of each step vs. the baseline.
print_stat(rlts_drop2)

### Mode invention

In [None]:
def mode_invention_exp():
    """Yield ``(X_base, X)`` pairs with progressively more extra classes.

    ``X_base`` is written once to ``tmp1`` from the test split restricted to
    labels 0..4. Step ``k`` (k = 0..4) writes up to 5000 training images to
    ``tmp2`` using labels ``0..4+k``, so the first step adds no new class.
    """
    X_base = write_dir('tmp1', set(range(5)), train = False)

    for n_extra in range(5):
        widened = set(range(5 + n_extra))
        yield X_base, write_dir('tmp2', widened, max_cnt = 5000)

In [None]:
# Run the mode-invention experiment: res_invention gets (step, pytorch-fid
# stdout) pairs, rlts_invention gets the RLTs with the baseline first.
res_invention, rlts_invention = run_exp(mode_invention_exp)

In [None]:
# (step, pytorch-fid stdout) for each mode-invention step.
res_invention

In [None]:
# Baseline q95 threshold, then 1e3 * geometry score of each step vs. the baseline.
print_stat(rlts_invention)

### Intra-mode collapse

In [None]:
def intra_mode_collapse_exp():
    """Yield ``(X_base, X)`` pairs with a growing number of samples per class.

    ``X_base`` is written once to ``tmp1`` from the test split with all ten
    labels. For 1, 10, 100 and 1000 samples per class, the training images at
    the indices returned by ``get_statified`` are written to ``tmp2``.
    """
    all_labels = set(range(10))
    X_base = write_dir('tmp1', all_labels, train = False)

    for per_class in (1, 10, 100, 1000):
        chosen = get_statified(per_class)
        yield X_base, write_dir('tmp2', all_labels, selected_idx = chosen)

In [None]:
# Run the intra-mode-collapse experiment: res_intra gets (step, pytorch-fid
# stdout) pairs, rlts_intra gets the RLTs with the baseline first.
res_intra, rlts_intra = run_exp(intra_mode_collapse_exp)

In [None]:
# (step, pytorch-fid stdout) for each samples-per-class setting.
res_intra

In [None]:
# Baseline q95 threshold, then 1e3 * geometry score of each step vs. the baseline.
print_stat(rlts_intra)

### Random Erase

In [None]:
def random_erase_exp():
    """Yield ``(X_base, X)`` pairs with increasing random-erasing scale.

    ``X_base`` is written once to ``tmp1`` from the test split with all ten
    labels. For each scale in 0.0, 0.01, 0.05, 0.25, up to 10000 training
    images are written to ``tmp2`` after ``transforms.RandomErasing`` with
    ``scale=(s, s)``. Only ``scale`` is set; every other RandomErasing
    parameter keeps its library default.
    """
    label_pool = set(range(10))
    X_base = write_dir('tmp1', label_pool, train = False)

    for erase_scale in (0.0, 0.01, 0.05, 0.25):
        eraser = transforms.RandomErasing(scale = (erase_scale, erase_scale))
        yield X_base, write_dir('tmp2', label_pool, transforms_list = [eraser], max_cnt = 10000)

In [None]:
# Run the random-erase experiment: res_erase gets (step, pytorch-fid stdout)
# pairs, rlts_erase gets the RLTs with the baseline first.
res_erase, rlts_erase = run_exp(random_erase_exp)

In [None]:
# (step, pytorch-fid stdout) for each erasing scale.
res_erase

In [None]:
# Baseline q95 threshold, then 1e3 * geometry score of each step vs. the baseline.
print_stat(rlts_erase)

### Add Gaussian Noise

In [None]:
def gaussian_noise_exp():
    """Yield ``(X_base, X)`` pairs with increasing Gaussian noise.

    ``X_base`` is written once to ``tmp1`` from the test split with all ten
    labels. For each sigma in 0.0, 0.01, 0.02, 0.04, 0.08, up to 10000 training
    images are written to ``tmp2`` after ``AddGaussianNoise(0, sigma)``.
    """
    label_pool = set(range(10))
    X_base = write_dir('tmp1', label_pool, train = False)

    for sigma in (0.0, 0.01, 0.02, 0.04, 0.08):
        noise = AddGaussianNoise(0, sigma)
        yield X_base, write_dir('tmp2', label_pool, transforms_list = [noise], max_cnt = 10000)

In [None]:
# Run the Gaussian-noise experiment: res_gauss gets (step, pytorch-fid stdout)
# pairs, rlts_gauss gets the RLTs with the baseline first.
res_gauss, rlts_gauss = run_exp(gaussian_noise_exp)

In [None]:
# (step, pytorch-fid stdout) for each noise level.
res_gauss

In [None]:
# Baseline q95 threshold, then 1e3 * geometry score of each step vs. the baseline.
print_stat(rlts_gauss)

In [None]:
#pickle.dump(archive, open('archive_v2.pickle', 'wb'))