In [1]:
## load necessary modules
import argparse
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from utils.tools import *
from utils.losses import *
from models.cifar10 import *

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from datetime import datetime
now = datetime.now()
timestamp = now.strftime("%m_%d_%H%M")
print(timestamp)

02_16_1126


In [2]:
# Training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # check if gpu is available

## load datasets
# train_gen, dev_gen, test_gen = load(batch_size, batch_size)
# data = inf_train_gen_mnist(train_gen)
transform    = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),  # Normalize the image
                         (0.5, 0.5, 0.5))
])
train_gen    = dsets.CIFAR10(root="./datasets",train=True, transform=transform, download=True)
test_gen     = dsets.CIFAR10(root="./datasets",train=False, transform=transform, download=True)
# train_loader = DataLoader(train_gen, batch_size=batch_size, shuffle=True)
# test_loader  = DataLoader(test_gen, batch_size=batch_size, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


In [14]:
## hyper-parameters
n_rep = 1
epochs1 = 70
epochs2 = 30
std = 0.1
lr_GI = 1e-4
lr_D = 5e-4
weight_decay = 0.01
batch_size = 250
z_dim = 5
lambda_mmd = 1.0
lambda_gp = 0.1
lambda_power = 0.2
eta = 3.0
present_label = list(range(10))
missing_label = []
all_label     = present_label + missing_label
classes       = train_gen.classes

In [15]:
# ************************
# *** DPI-RG Algorithm ***
# ************************

cover_accs = []
avg_counts = []

# for rep in range(n_rep):
T_trains = []
for lab in present_label:
    ## initialize models
    netI = I_CIFAR10(nz=z_dim)
    netG = G_CIFAR10(nz=z_dim)
    netD = D_CIFAR10(nz=z_dim)
    netI = netI.to(device)
    netG = netG.to(device)
    netD = netD.to(device)
    netI = nn.DataParallel(netI)
    netG = nn.DataParallel(netG)
    netD = nn.DataParallel(netD)

    ## set up optimizers
    optim_I = optim.Adam(netI.parameters(), lr=lr_GI, betas=(0.5, 0.999))
    optim_G = optim.Adam(netG.parameters(), lr=lr_GI, betas=(0.5, 0.999))
    optim_D = optim.Adam(netD.parameters(), lr=lr_D, betas=(0.5, 0.999), 
                         weight_decay=weight_decay)
    ## filter data for each label and train them respectively
    if torch.is_tensor(train_gen.targets):
        idxs = torch.where(train_gen.targets == lab)[0] 
    else:
        idxs = torch.where(torch.Tensor(train_gen.targets) == lab)[0] 
    train_data = torch.utils.data.Subset(train_gen, idxs)
    train_loader  = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    ## train for the first time
    train_al(netI, netG, netD, optim_I, optim_G, optim_D,
             train_gen, train_loader, batch_size, 0, epochs1, 
             z_dim, device, lab, present_label, all_label, 
             lambda_gp, lambda_power, lambda_mmd = lambda_mmd,
             img_size = 32, nc = 3, eta = eta, 
             lr_decay = None, trace=True)

    ## find out fake_zs
    fake_zs = []
    with torch.no_grad(): 
        for i, batch in enumerate(train_loader):
            x, _ = batch
            fake_z = netI(x.to(device))
            fake_zs.append(fake_z)
    fake_zs = torch.cat(fake_zs)
    ## get the empirical distribution for each label
    T_train = torch.sqrt(torch.sum(fake_zs ** 2, dim=1) + 1)

    ## get powers to determine new sample sizes
    powers = []
    for cur_lab in present_label:    
        if cur_lab != lab:
            # fake_Cs for this class
            if torch.is_tensor(train_gen.targets):
                idxs3 = torch.where(train_gen.targets == cur_lab)[0] 
            else:
                idxs3 = torch.where(torch.Tensor(train_gen.targets) == cur_lab)[0] 
            train_data3 = torch.utils.data.Subset(train_gen, idxs3)
            train_loader3  = DataLoader(train_data3, batch_size=batch_size, shuffle=False)
            p_vals = torch.zeros(len(idxs3)) 
            fake_zs = torch.zeros(len(idxs3))
            em_len = len(T_train)

            for i, batch in enumerate(train_loader3):
                x, _ = batch
                fake_z = netI(x.to(device))
                T_batch = torch.sqrt(torch.sum(fake_z ** 2, dim=1) + 1)

                # compute p-value for each sample
                for j in range(len(fake_z)):
                    p1 = torch.sum(T_train > T_batch[j]) / em_len
                    p = p1
                    # calculate the p-value and put it in the corresponding list
                    p_vals[i * batch_size + j] = p.item()
            powers.append(np.sum(np.array(p_vals) <= 0.05) / len(idxs3))
            
    sample_sizes = max(powers) - powers + 0.05
    sample_sizes = (sample_sizes / sum(sample_sizes) * len(idxs3)).astype(int)
    # print(sample_sizes)
    ## train for the second time according to the calculated sample sizes
    train_al(netI, netG, netD, optim_I, optim_G, optim_D,
             train_gen, train_loader, batch_size, epochs1, epochs2, 
             z_dim, device, lab, present_label, all_label, 
             lambda_gp, lambda_power, lambda_mmd = lambda_mmd, sample_sizes = sample_sizes, 
             img_size = 32, nc = 3, eta = eta, 
             lr_decay = None, trace = True)
    
    ## find out fake_zs
    fake_zs = []
    with torch.no_grad(): 
        for i, batch in enumerate(train_loader):
            x, _ = batch
            fake_z = netI(x.to(device))
            fake_zs.append(fake_z)
    fake_zs = torch.cat(fake_zs)
    ## get the empirical distribution for each label
    T_train = torch.sqrt(torch.sum(fake_zs ** 2, dim=1) + 1)
    T_trains.append(T_train)

    ## save net and graphs for each label
    model_save_file = f'cifar10_param/{timestamp}_class{lab}.pt'
    torch.save(netI.state_dict(), model_save_file)
    del netI
    print('Class', lab)

GI: 1.144873
MMD: 0.022351
D: -0.002342
gp: 0.083201
power: 0.222840
GI: 0.861455
MMD: 0.032856
D: -0.001182
gp: 0.081699
power: 0.210729
GI: 0.732750
MMD: 0.023839
D: -0.018673
gp: 0.084783
power: 0.196006
GI: 0.656960
MMD: 0.029777
D: -0.023253
gp: 0.084481
power: 0.182653
GI: 0.629430
MMD: 0.020457
D: -0.011703
gp: 0.085584
power: 0.179821
GI: 0.625214
MMD: 0.020502
D: -0.042227
gp: 0.075454
power: 0.164620
GI: 0.550818
MMD: 0.025694
D: -0.021140
gp: 0.078590
power: 0.160562
GI: 0.621817
MMD: 0.025842
D: -0.061191
gp: 0.077208
power: 0.158691
GI: 0.532500
MMD: 0.030207
D: -0.032460
gp: 0.073779
power: 0.149494
GI: 0.542564
MMD: 0.023767
D: -0.066100
gp: 0.078054
power: 0.140824
GI: 0.535765
MMD: 0.025596
D: -0.102516
gp: 0.074854
power: 0.141254
GI: 0.457590
MMD: 0.027578
D: -0.075387
gp: 0.079415
power: 0.135380
GI: 0.391673
MMD: 0.020192
D: -0.041276
gp: 0.080159
power: 0.124399
GI: 0.432696
MMD: 0.025203
D: -0.148227
gp: 0.080202
power: 0.121018
GI: 0.463569
MMD: 0.023138
D: -0.1

In [16]:
# training for verification
all_p_vals  = []
all_fake_Ts = []

for lab in all_label:    
    if torch.is_tensor(train_gen.targets):
        idxs2 = torch.where(train_gen.targets == lab)[0] 
    else:
        idxs2 = torch.where(torch.Tensor(train_gen.targets) == lab)[0] 
    test_data = torch.utils.data.Subset(train_gen, idxs2)
    test_loader  = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # p_vals and fake_zs store p-values, fake_zs for the current iteration
    fake_Ts = torch.zeros(len(present_label), len(idxs2))
    p_vals = torch.zeros(len(present_label), len(idxs2)) 

    for pidx in range(len(present_label)):
        T_train = T_trains[pidx]
        em_len = len(T_train)
        netI = I_CIFAR10(nz=z_dim)
        netI = netI.to(device)
        netI = torch.nn.DataParallel(netI)
        model_save_file = f'cifar10_param/{timestamp}_class{present_label[pidx]}.pt'
        netI.load_state_dict(torch.load(model_save_file))
        
        for i, batch in enumerate(test_loader):
            images, y = batch
            x = images.view(-1, 3, 32 * 32).to(device)
            fake_z = netI(x)
            T_batch = torch.sqrt(torch.sum(torch.square(fake_z), 1) + 1) 
            ## compute p-value for each sample
            for j in range(len(fake_z)):
                p1 = torch.sum(T_train > T_batch[j]) / em_len
                p = p1
                # calculate the p-value and put it in the corresponding list
                fake_Ts[pidx, i * batch_size + j] = T_batch[j].item()
                p_vals[pidx, i * batch_size + j] = p.item()

    all_p_vals.append(np.array(p_vals))
    ## concatenate torch data
    all_fake_Ts.append(np.array(fake_Ts))
    # print('Finished Label {}'.format(lab))

IndexError: list index out of range

In [None]:
visualize_T(all_fake_Ts, present_label, all_label, missing_label, z_dim, classes)

In [None]:
visualize_p(all_p_vals, present_label, all_label, missing_label, z_dim, classes)

In [None]:
## test data set
all_p_vals  = []
all_fake_Ts = []

for lab in all_label:    
    if torch.is_tensor(test_gen.targets):
        idxs2 = torch.where(test_gen.targets == lab)[0] 
    else:
        idxs2 = torch.where(torch.Tensor(test_gen.targets) == lab)[0] 
    test_data = torch.utils.data.Subset(test_gen, idxs2)
    test_loader  = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # p_vals and fake_zs store p-values, fake_zs for the current iteration
    fake_Ts = torch.zeros(len(present_label), len(idxs2))
    p_vals = torch.zeros(len(present_label), len(idxs2)) 

    for pidx in range(len(present_label)):
        T_train = T_trains[pidx]
        em_len = len(T_train)
        netI = I_CIFAR10(nz=z_dim)
        netI = netI.to(device)
        netI = torch.nn.DataParallel(netI)
        model_save_file = f'cifar10_param/{timestamp}_class{present_label[pidx]}.pt'
        netI.load_state_dict(torch.load(model_save_file))
        
        for i, batch in enumerate(test_loader):
            images, y = batch
            x = images.view(-1, 3, 32 * 32).to(device)
            fake_z = netI(x)
            T_batch = torch.sqrt(torch.sum(torch.square(fake_z), 1) + 1) 
            ## compute p-value for each sample
            for j in range(len(fake_z)):
                p1 = torch.sum(T_train > T_batch[j]) / em_len
                p = p1
                # calculate the p-value and put it in the corresponding list
                fake_Ts[pidx, i * batch_size + j] = T_batch[j].item()
                p_vals[pidx, i * batch_size + j] = p.item()

    all_p_vals.append(np.array(p_vals))
    ## concatenate torch data
    all_fake_Ts.append(np.array(fake_Ts))
    # print('Finished Label {}'.format(lab))

In [None]:
# visualize_T(all_fake_Ts, present_label, all_label, missing_label, z_dim, classes)

In [None]:
visualize_p(all_p_vals, present_label, all_label, missing_label, z_dim, classes)

In [11]:
cover_accs = []
avg_counts = []

cover_acc = torch.zeros(len(all_label))
avg_count = torch.zeros(len(all_label))
for i, lab in enumerate(all_label):
    p_vals = all_p_vals[i]
    n = p_vals.shape[1]
    cover = 0.0
    counts = 0.0
    for j in range(n):
        pred = np.argmax(p_vals[:, j])
        p_set = np.where(p_vals[:, j] > 0.05)[0]
        counts += len(p_set)
        if lab in missing_label:
            if len(p_set) == 0:
                cover += 1
        else:
            if all_label[i] in p_set:
                cover += 1
    cover_acc[i] = cover / n
    avg_count[i] = counts / n
cover_accs.append(cover_acc)
avg_counts.append(avg_count)

In [12]:
print(cover_accs)
print(avg_counts)

[tensor([0.9730, 0.8690, 0.9530, 0.9390, 0.9370, 0.9580, 0.9460, 0.9550, 0.9140,
        0.9500])]
[tensor([3.8330, 4.0980, 4.3980, 3.7430, 3.6970, 3.8000, 4.5480, 3.5420, 3.7370,
        4.0930])]
