In [1]:
## load necessary modules
import argparse
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from utils.tools import *
from utils.losses import *
from models.cifar10 import *

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from datetime import datetime
# Stamp this run as month_day_hourminute; the stamp is embedded in the
# checkpoint file names written and re-read by the later cells.
now = datetime.now()
timestamp = f"{now:%m_%d_%H%M}"
print(timestamp)

02_16_2042


In [2]:
# Device selection and CIFAR-10 loading
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## load datasets
# Images are converted to tensors and then normalized per channel as
# (x - 0.5) / 0.5 for each of the three colour channels.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Both splits share the same root folder, transform and download flag.
# Per-class DataLoaders are built later, once batch_size is defined.
cifar_kwargs = dict(root="./datasets", transform=transform, download=True)
train_gen = dsets.CIFAR10(train=True, **cifar_kwargs)
test_gen = dsets.CIFAR10(train=False, **cifar_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
## hyper-parameters
# --- training schedule ---
n_rep = 1                   # not referenced by the visible cells (the repetition loop is disabled)
epochs1 = epochs2 = 50      # the first train_al call receives (0, epochs1), the second (epochs1, epochs2)
batch_size = 250            # batch size of every DataLoader built in this notebook

# --- optimizers ---
lr_GI = 2e-5                # Adam learning rate shared by netI and netG
lr_D = 1e-4                 # Adam learning rate of netD
weight_decay = 0.01         # applied to netD's optimizer only

# --- model / loss ---
z_dim = 5                   # passed as `nz` to the I / G / D network constructors
std = 0.1                   # not referenced by the visible cells
lambda_mmd, lambda_gp, lambda_power = 1.0, 0.1, 0.6   # forwarded to train_al
eta = 0.0                   # forwarded to train_al

# --- labels ---
present_label = list(range(10))                  # one set of networks is trained per present label
missing_label = []                               # labels that get no model of their own
all_label = [*present_label, *missing_label]     # evaluation iterates over these
classes = train_gen.classes                      # class names exposed by the CIFAR10 dataset object

In [4]:
# ************************
# *** DPI-RG Algorithm ***
# ************************
# For every label in `present_label` this cell
#   1. builds fresh I / G / D networks and their Adam optimizers,
#   2. runs a first train_al pass on that label's training images,
#   3. measures, for every *other* present label, the fraction of its images
#      whose one-sided empirical p-value under the current netI is <= 0.05,
#   4. converts those fractions into `sample_sizes` and runs a second
#      train_al pass with them,
#   5. records the label's empirical statistic in `T_trains` and saves netI.


def _indices_of(dataset, label):
    """Return the positions whose target equals `label`.

    `dataset.targets` may be a tensor or a plain sequence; both are handled.
    """
    return torch.where(torch.as_tensor(dataset.targets) == label)[0]


def _latent_stat(net, loader):
    """Return sqrt(||net(x)||^2 + 1) for every sample yielded by `loader`.

    Runs under torch.no_grad() and moves each batch to the global `device`.
    """
    latents = []
    with torch.no_grad():
        for x, _ in loader:
            latents.append(net(x.to(device)))
    latents = torch.cat(latents)
    return torch.sqrt(torch.sum(latents ** 2, dim=1) + 1)


cover_accs = []
avg_counts = []

# Create the checkpoint folder up front so torch.save cannot fail on a missing
# directory after the (long) training of the first label has finished.
os.makedirs('cifar10_param', exist_ok=True)

T_trains = []
for lab in present_label:
    ## initialize models
    netI = nn.DataParallel(I_CIFAR10_2(nz=z_dim).to(device))
    netG = nn.DataParallel(G_CIFAR10(nz=z_dim, ngf=64).to(device))
    netD = nn.DataParallel(D_CIFAR10(nz=z_dim, ndf=64).to(device))

    ## set up optimizers
    optim_I = optim.Adam(netI.parameters(), lr=lr_GI, betas=(0.5, 0.999))
    optim_G = optim.Adam(netG.parameters(), lr=lr_GI, betas=(0.5, 0.999))
    optim_D = optim.Adam(netD.parameters(), lr=lr_D, betas=(0.5, 0.999),
                         weight_decay=weight_decay)

    ## filter data for each label and train them respectively
    idxs = _indices_of(train_gen, lab)
    train_data = torch.utils.data.Subset(train_gen, idxs)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    ## train for the first time
    train_al(netI, netG, netD, optim_I, optim_G, optim_D,
             train_gen, train_loader, batch_size, 0, epochs1,
             z_dim, device, lab, present_label, all_label,
             lambda_gp, lambda_power, lambda_mmd = lambda_mmd,
             img_size = 32, nc = 3, eta = eta,
             lr_decay = None, trace=True)

    ## empirical distribution of the statistic on this label's own images
    T_train = _latent_stat(netI, train_loader)
    em_len = len(T_train)

    ## get powers to determine new sample sizes
    powers = []
    for cur_lab in present_label:
        if cur_lab == lab:
            continue
        idxs3 = _indices_of(train_gen, cur_lab)
        train_data3 = torch.utils.data.Subset(train_gen, idxs3)
        train_loader3 = DataLoader(train_data3, batch_size=batch_size, shuffle=False)

        n_rejected = 0
        # Pure inference: no_grad keeps autograd from retaining a graph for
        # every image of every other class.
        with torch.no_grad():
            for x, _ in train_loader3:
                fake_z = netI(x.to(device))
                T_batch = torch.sqrt(torch.sum(fake_z ** 2, dim=1) + 1)
                # One-sided p-value per sample: share of T_train strictly
                # larger than the sample's statistic (whole batch at once).
                p_batch = (T_train.unsqueeze(0) > T_batch.unsqueeze(1)).sum(dim=1) / em_len
                n_rejected += (p_batch <= 0.05).sum().item()
        powers.append(n_rejected / len(idxs3))
    powers = np.asarray(powers)

    # Classes with lower power get a larger share; +0.05 keeps every share positive.
    sample_sizes = powers.max() - powers + 0.05
    # NOTE(review): the total is scaled by len(idxs3), i.e. the size of the last
    # *other* class visited above, exactly as before -- confirm this is intended
    # rather than len(idxs).
    sample_sizes = (sample_sizes / sample_sizes.sum() * len(idxs3)).astype(int)
    # print(sample_sizes)
    ## train for the second time according to the calculated sample sizes
    train_al(netI, netG, netD, optim_I, optim_G, optim_D,
             train_gen, train_loader, batch_size, epochs1, epochs2,
             z_dim, device, lab, present_label, all_label,
             lambda_gp, lambda_power, lambda_mmd = lambda_mmd, sample_sizes = sample_sizes,
             img_size = 32, nc = 3, eta = eta,
             lr_decay = 10, trace = True)

    ## get the empirical distribution for each label after the second pass
    T_train = _latent_stat(netI, train_loader)
    T_trains.append(T_train)

    ## save net and graphs for each label
    model_save_file = f'cifar10_param/{timestamp}_class{lab}.pt'
    torch.save(netI.state_dict(), model_save_file)
    del netI
    print('-'*100)
    print('Class', lab)
    print('-'*100)

GI: 2.414025
MMD: 0.061850
D: 0.001388
gp: 0.081754
power: 0.043468
GI: 1.895977
MMD: 0.041415
D: -0.016375
gp: 0.078484
power: 0.033978
GI: 1.533893
MMD: 0.021671
D: -0.000313
gp: 0.078327
power: 0.032312
GI: 1.404056
MMD: 0.021221
D: -0.007261
gp: 0.082081
power: 0.027762
GI: 1.283061
MMD: 0.019388
D: -0.001798
gp: 0.079392
power: 0.024166
GI: 1.238168
MMD: 0.020364
D: -0.006988
gp: 0.076218
power: 0.023194
GI: 1.091419
MMD: 0.027706
D: -0.006854
gp: 0.079309
power: 0.023258
GI: 1.041238
MMD: 0.026457
D: -0.022245
gp: 0.080165
power: 0.022166
GI: 1.000254
MMD: 0.025999
D: -0.011686
gp: 0.076808
power: 0.022294
GI: 0.998105
MMD: 0.023003
D: -0.014069
gp: 0.079786
power: 0.022010
GI: 0.977537
MMD: 0.024668
D: -0.022323
gp: 0.076468
power: 0.021798
GI: 0.889499
MMD: 0.028113
D: -0.009361
gp: 0.077654
power: 0.021460
GI: 0.820475
MMD: 0.023508
D: -0.016347
gp: 0.080867
power: 0.020716
GI: 0.847514
MMD: 0.022703
D: -0.017392
gp: 0.071891
power: 0.020728
GI: 0.796209
MMD: 0.018622
D: -0.01

KeyboardInterrupt: 

In [None]:
# Verification on the training split: score every label's training images with
# every saved per-class netI and record the statistic and a two-sided
# empirical p-value against that class's T_train.
all_p_vals  = []
all_fake_Ts = []

for lab in all_label:
    # targets may be a list or a tensor; as_tensor handles both
    idxs2 = torch.where(torch.as_tensor(train_gen.targets) == lab)[0]
    test_data = torch.utils.data.Subset(train_gen, idxs2)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # Row pidx holds the results under the model of present_label[pidx].
    fake_Ts = torch.zeros(len(present_label), len(idxs2))
    p_vals = torch.zeros(len(present_label), len(idxs2))

    for pidx in range(len(present_label)):
        T_train = T_trains[pidx]
        em_len = len(T_train)
        netI = I_CIFAR10_2(nz=z_dim)
        netI = netI.to(device)
        netI = torch.nn.DataParallel(netI)
        model_save_file = f'cifar10_param/{timestamp}_class{present_label[pidx]}.pt'
        # map_location lets a checkpoint written on GPU load on a CPU-only machine.
        netI.load_state_dict(torch.load(model_save_file, map_location=device))

        # Inference only: avoid building autograd graphs.
        with torch.no_grad():
            for i, (images, _) in enumerate(test_loader):
                # Feed the batch exactly as the loader yields it, matching how
                # netI was called when T_trains was computed in the training cell.
                x = images.to(device)
                fake_z = netI(x)
                T_batch = torch.sqrt(torch.sum(torch.square(fake_z), 1) + 1)
                ## two-sided empirical p-value for the whole batch at once
                p1 = (T_train.unsqueeze(0) > T_batch.unsqueeze(1)).sum(dim=1) / em_len
                p2 = (T_train.unsqueeze(0) < T_batch.unsqueeze(1)).sum(dim=1) / em_len
                p = 2 * torch.minimum(p1, p2)
                start = i * batch_size
                stop = start + len(T_batch)
                fake_Ts[pidx, start:stop] = T_batch.cpu()
                p_vals[pidx, start:stop] = p.cpu()

    all_p_vals.append(p_vals.numpy())
    all_fake_Ts.append(fake_Ts.numpy())
    # print('Finished Label {}'.format(lab))

In [None]:
# Plot the statistics in all_fake_Ts as filled by the preceding cell (training split).
# visualize_T is not defined in this notebook; presumably it comes from one of the star imports -- confirm.
visualize_T(all_fake_Ts, present_label, all_label, missing_label, z_dim, classes)

In [None]:
# Plot the p-values in all_p_vals as filled by the preceding evaluation cell (training split).
# visualize_p is not defined in this notebook; presumably it comes from one of the star imports -- confirm.
visualize_p(all_p_vals, present_label, all_label, missing_label, z_dim, classes)

In [None]:
## test data set
# Score every label's held-out test images with every saved per-class netI and
# record the statistic and a two-sided empirical p-value against that class's
# T_train. This overwrites all_p_vals / all_fake_Ts.
all_p_vals  = []
all_fake_Ts = []

for lab in all_label:
    # targets may be a list or a tensor; as_tensor handles both
    idxs2 = torch.where(torch.as_tensor(test_gen.targets) == lab)[0]
    test_data = torch.utils.data.Subset(test_gen, idxs2)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # Row pidx holds the results under the model of present_label[pidx].
    fake_Ts = torch.zeros(len(present_label), len(idxs2))
    p_vals = torch.zeros(len(present_label), len(idxs2))

    for pidx in range(len(present_label)):
        T_train = T_trains[pidx]
        em_len = len(T_train)
        netI = I_CIFAR10_2(nz=z_dim)
        netI = netI.to(device)
        netI = torch.nn.DataParallel(netI)
        model_save_file = f'cifar10_param/{timestamp}_class{present_label[pidx]}.pt'
        # map_location lets a checkpoint written on GPU load on a CPU-only machine.
        netI.load_state_dict(torch.load(model_save_file, map_location=device))

        # Inference only: avoid building autograd graphs.
        with torch.no_grad():
            for i, (images, _) in enumerate(test_loader):
                # Feed the batch exactly as the loader yields it, matching how
                # netI was called when T_trains was computed in the training cell.
                x = images.to(device)
                fake_z = netI(x)
                T_batch = torch.sqrt(torch.sum(torch.square(fake_z), 1) + 1)
                ## two-sided empirical p-value for the whole batch at once
                p1 = (T_train.unsqueeze(0) > T_batch.unsqueeze(1)).sum(dim=1) / em_len
                p2 = (T_train.unsqueeze(0) < T_batch.unsqueeze(1)).sum(dim=1) / em_len
                p = 2 * torch.minimum(p1, p2)
                start = i * batch_size
                stop = start + len(T_batch)
                fake_Ts[pidx, start:stop] = T_batch.cpu()
                p_vals[pidx, start:stop] = p.cpu()

    all_p_vals.append(p_vals.numpy())
    all_fake_Ts.append(fake_Ts.numpy())
    # print('Finished Label {}'.format(lab))

In [None]:
# Plot the statistics in all_fake_Ts as filled by the preceding cell (test split).
# visualize_T is not defined in this notebook; presumably it comes from one of the star imports -- confirm.
visualize_T(all_fake_Ts, present_label, all_label, missing_label, z_dim, classes)

In [None]:
# Plot the p-values in all_p_vals as filled by the preceding evaluation cell (test split).
# visualize_p is not defined in this notebook; presumably it comes from one of the star imports -- confirm.
visualize_p(all_p_vals, present_label, all_label, missing_label, z_dim, classes)

In [None]:
# Coverage and average prediction-set size per label.
# For each sample the prediction set contains every present label whose
# p-value exceeds 0.05.
#   * present label: covered when its own label is in the set
#   * missing label: covered when the set is empty
cover_accs = []
avg_counts = []

cover_acc = torch.zeros(len(all_label))
avg_count = torch.zeros(len(all_label))
for i, lab in enumerate(all_label):
    p_vals = np.asarray(all_p_vals[i])      # rows follow present_label, columns are samples
    n = p_vals.shape[1]
    # in_set[k, j] is True when present_label[k] belongs to sample j's set
    in_set = p_vals > 0.05
    set_sizes = in_set.sum(axis=0)
    if lab in missing_label:
        cover = int(np.sum(set_sizes == 0))
    else:
        # Rows are indexed by position in present_label, not by label value,
        # so look the row up instead of comparing the label with row indices.
        row = present_label.index(lab)
        cover = int(np.sum(in_set[row]))
    cover_acc[i] = cover / n
    avg_count[i] = float(set_sizes.sum()) / n
cover_accs.append(cover_acc)
avg_counts.append(avg_count)

In [None]:
# Show per-label coverage first, then the average prediction-set size, one per line.
print(cover_accs, avg_counts, sep="\n")