In [1]:
import time

import torch
import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np

import matplotlib.pyplot as plt
import random
from torchsummary import summary

import os

from model import Generator_32, Discriminator, weights_init, compute_acc
from resnet20 import ResNetCIFAR
from train_util import test

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
batch_size = 100
input_size = 110
num_classes = 10
image_size = 32
EPOCH = 100
noise_sd = 1
LR = 0.0002

In [4]:
resnet = ResNetCIFAR(num_layers=20)
resnet = resnet.to(device)
resnet.load_state_dict(torch.load("./model/resnet20.pt"))

<All keys matched successfully>

In [5]:
test(resnet)

Files already downloaded and verified
Test Loss=0.3238, Test accuracy=0.9116


In [6]:
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [7]:
model = Generator_32().to(device)
model.load_state_dict(torch.load(os.path.join('./model/', "netG_epoch_999.pth")))

<All keys matched successfully>

In [8]:
def get_sample_images(genmod, num_classes, noise_sd, n_images):
    random.seed(123)
    torch.manual_seed(123)
    noise = torch.FloatTensor(n_images, 110, 1, 1).to(device)
    label = np.random.randint(0, num_classes, n_images)
    noise_ = np.random.normal(0, noise_sd, (n_images, 110))
    class_onehot = np.zeros((n_images, num_classes))
    class_onehot[np.arange(n_images), label] = 1
    noise_[np.arange(n_images), :num_classes] = class_onehot[np.arange(n_images)]
    noise_ = (torch.from_numpy(noise_))
    noise.data.copy_(noise_.view(n_images, 110, 1, 1))
    fake = genmod(noise)
    return label, fake

In [9]:
labels, fake_imgs = get_sample_images(model, 10, 1, 50000)

In [11]:
np.save('./inception/labels.npy', labels)

In [13]:
np.save('./inception/fake_img_arr.npy', fake_imgs.detach().cpu().numpy())

In [9]:
def calculate_inception_scores(fake_imgs, resmod, n_splits, eps=1e-15):
    scores, scores_class = [], []
    for i in range(n_splits):
        cur_split = fake_imgs[i*n_splits: (i+1)*n_splits]
        predict = resmod(cur_split)
        p_yx = F.softmax(predict, dim=1)
        p_y = p_yx.mean(0, keepdim=True)
        KL = p_yx * (torch.log(p_yx + eps) - torch.log(p_y + eps))
        avg_KL = KL.sum(1).mean()
        inception_score = torch.exp(avg_KL)
        scores.append(inception_score.detach().cpu().numpy())

    return np.array(scores)

In [11]:
scores = calculate_inception_scores(fake_imgs=fake_imgs, resmod=resnet, n_splits=10, eps=1e-15)

In [12]:
print("mean score: %.4f, sd score:: %.4f" % (np.mean(scores), np.std(scores)))

mean score: 3.2761, sd score:: 0.7480


In [93]:
np.unique(labels, return_counts=True)

(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 array([4978, 4869, 5046, 4968, 5023, 4992, 5047, 5143, 4956, 4978]))

In [31]:
scores, scores_class = [], []
eps=1e-15
p_yx = np.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
p_y = p_yx.mean(axis=0,keepdims=True)
KL = p_yx * (np.log(p_yx + eps) - np.log(p_y + eps))
avg_KL = np.mean(KL.sum(axis=1))
inception_score = np.exp(avg_KL)
inception_score

2.9999999999999942

## IS on the raw training set

In [10]:
transform_train = transforms.Compose([
        transforms.Resize(32),
#         transforms.RandomCrop(image_size, padding=4),
#         transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

    ])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=16)



Files already downloaded and verified


In [11]:
outputs = []
for i, data in enumerate(trainloader, 0):
    image, label = data
    outputs.append(image)
cifar_traindata = torch.cat(outputs).cuda()
print(cifar_traindata.shape)

torch.Size([50000, 3, 32, 32])


In [15]:
raw_scores = calculate_inception_scores(cifar_traindata, resnet, 10, 1e-15)
print(np.mean(raw_scores), np.std(raw_scores))

4.784825 0.6176199
