## Packages and hyper-parameters 

In [1]:
import itertools

import torch
import torch.nn as nn
import torchvision
import torchvision.utils as vutils
import torchvision.transforms as transforms
import torchvision.models as models
from torch import optim
from torch.utils.data import DataLoader

from models import GeneratorA2B
from models import GeneratorB2A
from models import DiscriminatorA

from utils import train_al
from utils import visualize_fake_C, visualize_p
from utils import weights_init_normal
from utils import weights_init
from utils import LambdaLR

# import different loss functions for GAN B
from geomloss import SamplesLoss

import os
# Create the ./mnist_models directory; exist_ok=True makes re-running this cell safe.
# NOTE(review): no cell visible here reads or writes mnist_models — confirm it is still needed.
os.makedirs('mnist_models', exist_ok=True)
import numpy as np

# hyper-parameters
# Image preprocessing: convert to tensor, then normalize the single channel
# with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST train / test splits, stored under ./datasets (downloaded if absent).
trainset_A = torchvision.datasets.MNIST(root="./datasets", train=True, download=True, transform=transform)
testset_A = torchvision.datasets.MNIST(root="./datasets", train=False, download=True, transform=transform)

# Class split used by the coverage cells: a test point whose label is in
# missing_label counts as covered only when its prediction set is empty.
missing_label = []
present_label = [*range(10)]
all_label = [*present_label, *missing_label]

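## t-SNE dimensionality reduction

Train and test images are embedded together in a single t-SNE run, because t-SNE cannot map new points into an existing embedding.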

In [2]:
from sklearn.neighbors import KernelDensity
from sklearn.manifold import TSNE
# Embed train and test images in ONE t-SNE fit: scikit-learn's TSNE has no
# transform() for unseen points, so both splits must be part of the same run.
import inspect

x_train = trainset_A.data.view(-1, 28*28)
x_test = testset_A.data.view(-1, 28*28)
x_all = torch.cat((x_train, x_test), 0)

# scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` (the old name is removed
# in 1.7); use whichever keyword the installed version accepts.
iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
tsne_all = TSNE(n_components=2, random_state=42, **{iter_kw: 500}).fit_transform(x_all.numpy())

# First len(x_train) rows are the training images, the rest are the test images.
tsne_train = tsne_all[:len(x_train), :]
tsne_test = tsne_all[len(x_train):, :]

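## Class-conditional density estimation

One Gaussian kernel density estimate is fitted per digit class on the 2-D t-SNE coordinates of the training images.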

In [11]:
# Fit one Gaussian KDE per class on the 2-D t-SNE coordinates of the training images.
# kdes is a list indexed by class label: kdes[c] models class c.
# NOTE(review): bandwidth=1 is in t-SNE coordinate units and is not tuned here — confirm.
train_targets = np.asarray(trainset_A.targets)  # uniform handling of tensor or list targets
kdes = [
    KernelDensity(kernel='gaussian', bandwidth=1).fit(tsne_train[train_targets == c])
    for c in range(len(trainset_A.classes))
]

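## Per-class densities of the test points (`dens_classes`)

Each row holds the KDE density of every test point under one class; these are density values, not p-values.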

In [12]:
# Per-class density of every test point: row r holds the density under the KDE of
# class present_label[r]; column j is test image j.
# Only present labels get a row (the array has len(present_label) rows), so iterate
# present_label — iterating all_label would index past the last row whenever
# missing_label is non-empty.
dens_classes = np.zeros((len(present_label), len(testset_A)))
for row, lab in enumerate(present_label):
    # score_samples returns log-densities; exponentiate to get densities (not p-values).
    dens_classes[row, :] = np.exp(kdes[lab].score_samples(tsne_test))

In [13]:
# KDE densities of the first test image, one value per row of dens_classes.
dens_first_test = dens_classes[:, 0]
dens_first_test

array([1.96091345e-14, 2.42076140e-05, 8.68497574e-05, 1.89709612e-05,
       8.06078040e-06, 9.75218717e-08, 5.42156515e-39, 1.11438278e-02,
       4.17691623e-07, 2.61479891e-05])

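## Coverage and average prediction-set size

For each test point, classes are ranked by density and added to the prediction set until their normalized mass exceeds 95%. The first cell builds the sets deterministically; the second includes the boundary class at random so the expected mass is exactly 95%.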

In [14]:
# Result accumulators: each coverage cell below appends one per-class tensor per run.
# Re-run this cell before re-running those cells, otherwise entries pile up.
cover_accs, avg_sizes = [], []

In [15]:
# Deterministic prediction sets: rank the present classes by density and keep the
# smallest prefix whose normalized density mass exceeds COVERAGE (= 1 - alpha).
COVERAGE = 0.95
# Row r of dens_classes belongs to class present_label[r]; map rows back to labels
# so `lab in p_set` compares labels with labels.
row_labels = np.asarray(present_label)
test_targets = np.asarray(testset_A.targets)

cover = torch.zeros(len(all_label))
size = torch.zeros(len(all_label))
count = torch.zeros(len(all_label))
for i in range(len(testset_A)):
    dens = dens_classes[:, i]
    lab = int(test_targets[i])
    # rows ordered by decreasing density, and the densities in that order
    order = np.argsort(-dens)
    dens_sorted = dens[order]
    if dens_sorted[0] == 0:
        # every class assigns zero density: return the empty set
        p_set = np.array([], dtype=int)
    else:
        # first position where the cumulative normalized mass exceeds the target coverage
        idx = np.argmax(np.cumsum(dens_sorted) / np.sum(dens_sorted) > COVERAGE)
        p_set = row_labels[order[:idx + 1]]
        size[lab] += len(p_set)
    if lab in missing_label:
        # a missing class counts as covered only when the prediction set is empty
        if len(p_set) == 0:
            cover[lab] += 1
    elif lab in p_set:
        cover[lab] += 1
    count[lab] += 1

cover_acc = torch.div(cover, count)
avg_size = torch.div(size, count)
cover_accs.append(cover_acc)
avg_sizes.append(avg_size)

In [16]:
# Randomized prediction sets: like the deterministic version, but the boundary class
# (the one that pushes the cumulative mass past COVERAGE) is kept only with
# probability gamma, so the expected included mass equals COVERAGE exactly.
COVERAGE = 0.95
# Seeded generator so the reported coverage / set sizes are reproducible
# (the previous np.random.rand call was unseeded).
rng = np.random.default_rng(0)
# Row r of dens_classes belongs to class present_label[r]; map rows back to labels.
row_labels = np.asarray(present_label)
test_targets = np.asarray(testset_A.targets)

cover = torch.zeros(len(all_label))
size = torch.zeros(len(all_label))
count = torch.zeros(len(all_label))
for i in range(len(testset_A)):
    dens = dens_classes[:, i]
    lab = int(test_targets[i])
    order = np.argsort(-dens)
    dens_sorted = dens[order]
    if dens_sorted[0] == 0:
        # every class assigns zero density: return the empty set
        p_set = np.array([], dtype=int)
    else:
        # normalize only after the zero check, avoiding a 0/0 division
        mass = dens_sorted / np.sum(dens_sorted)
        cum_mass = np.cumsum(mass)
        # first position where the cumulative mass exceeds the target coverage
        idx = np.argmax(cum_mass > COVERAGE)
        # mass accumulated strictly before the boundary class (0 when idx == 0)
        mass_before = cum_mass[idx - 1] if idx > 0 else 0.0
        gamma = (COVERAGE - mass_before) / mass[idx]
        n_keep = idx + 1 if rng.random() < gamma else idx
        p_set = row_labels[order[:n_keep]]
        size[lab] += len(p_set)
    if lab in missing_label:
        # a missing class counts as covered only when the prediction set is empty
        if len(p_set) == 0:
            cover[lab] += 1
    elif lab in p_set:
        cover[lab] += 1
    count[lab] += 1

cover_acc = torch.div(cover, count)
avg_size = torch.div(size, count)
cover_accs.append(cover_acc)
avg_sizes.append(avg_size)

In [17]:
# Per-class coverage (index = class label), one tensor per run of the coverage cells,
# in run order — assuming In [15] (deterministic) and In [16] (randomized) each ran once.
cover_accs

[tensor([0.9939, 0.9974, 0.9787, 0.9802, 0.9817, 0.9843, 0.9885, 0.9767, 0.9610,
         0.9752]),
 tensor([0.9561, 0.9868, 0.9506, 0.9604, 0.9440, 0.9709, 0.9541, 0.9621, 0.9374,
         0.9673])]

In [18]:
# Per-class average prediction-set size (index = class label), one tensor per run of the
# coverage cells — assuming In [15] (deterministic) and In [16] (randomized) each ran once.
avg_sizes

[tensor([1.1173, 1.4141, 1.4060, 1.3980, 1.1965, 1.5538, 1.1075, 1.6060, 1.3563,
         1.7939]),
 tensor([1.0337, 1.2326, 1.2306, 1.2168, 1.1008, 1.3442, 1.0344, 1.3940, 1.2166,
         1.5441])]