# 1. Visualize Quality Enhancement by Showing Some Fake Images
Methods compared: (1) Baseline; (2) DRS; (3) MH-GAN; (4) DRE-F-SP+RS; (5) DRE-F-SP+MH; (6) DRE-F-SP+SIR.

In [2]:
import os
wd = '/home/xin/OneDrive/Working_directory/DDRE_Sampling_GANs/CIFAR10'
os.chdir(wd)
import torch
import torchvision
import torchvision.transforms as transforms
import random
import numpy as np
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.nn import functional as F
from torchvision.utils import save_image
import gc
from sklearn.linear_model import LogisticRegression
from itertools import groupby
from tqdm import tqdm

from SimpleProgressBar import SimpleProgressBar
from models import *
from Train_DCGAN import SampDCGAN
from Train_WGAN import SampWGAN
from Train_MMDGAN import SampMMDGAN
from utils import *

###########################################
#             Overall Settings
###########################################
GAN = "DCGAN" #DCGAN, WGANGP, MMDGAN
DRE = "DRE-F-SP" #Candidate: None, disc, disc_MHcal, DRE-F-SP, DRE-F-uLSIF
Sampler = "SIR" #Candidate: None, RS, MH, SIR
Seed = 101

# Fail fast on a typo here instead of with a NameError deep inside the notebook
# (every later stage branches on these three strings).
if GAN not in ("DCGAN", "WGANGP", "MMDGAN"):
    raise ValueError("Unknown GAN %r; expected DCGAN, WGANGP or MMDGAN" % GAN)
if DRE not in ("None", "disc", "disc_MHcal", "DRE-F-SP", "DRE-F-uLSIF"):
    raise ValueError("Unknown DRE %r; expected None, disc, disc_MHcal, DRE-F-SP or DRE-F-uLSIF" % DRE)
if Sampler not in ("None", "RS", "MH", "SIR"):
    raise ValueError("Unknown Sampler %r; expected None, RS, MH or SIR" % Sampler)

###########################################
#              GAN Settings
###########################################
dim_GAN = 128 #latent length, passed to the GAN code as GAN_Latent_Length / k
# pre-trained generator checkpoint for each supported GAN
ckpt_GAN_name = {
    "DCGAN": './Output/saved_models/ckpt_DCGAN_epoch_500_SEED_2019',
    "WGANGP": './Output/saved_models/ckpt_WGANGP_epoch_2000_SEED_2019',
    "MMDGAN": './Output/saved_models/ckpt_MMDGAN_epoch_4000_SEED_2019',
}[GAN]
    
###########################################
#              DRE Settings
###########################################
# feature extractor used by the DRE-F methods
ckpt_PreCNN_name = './Output/saved_models/ckpt_PreCNNForDRE_ResNet34_epoch_200_SEED_2019_Transformation_True'
# density-ratio network checkpoint; only DCGAN lists a uLSIF variant here
if GAN == "DCGAN":
    if DRE == "DRE-F-SP":
        ckpt_DR_name = './Output/saved_models/ckpt_DRE_F_SP_MLP5_ReLU_epoch_200_SEED_2019_Lambda_0.0_DCGAN_epoch_500'
    elif DRE == "DRE-F-uLSIF":
        ckpt_DR_name = './Output/saved_models/ckpt_DRE_F_uLSIF_MLP5_ReLU_epoch_200_SEED_2019_Lambda_0.0_DCGAN_epoch_500'
elif GAN == "WGANGP":
    ckpt_DR_name = './Output/saved_models/ckpt_DRE_F_SP_MLP5_ReLU_epoch_200_SEED_2019_Lambda_0.005_WGANGP_epoch_2000'
elif GAN == "MMDGAN":
    ckpt_DR_name = './Output/saved_models/ckpt_DRE_F_SP_MLP5_ReLU_epoch_200_SEED_2019_Lambda_0.006_MMDGAN_epoch_4000'

# Previously an SP checkpoint was silently loaded for DRE-F-uLSIF on WGANGP/MMDGAN,
# so the saved figures were labelled uLSIF while actually using SP.
if GAN != "DCGAN" and DRE == "DRE-F-uLSIF":
    raise ValueError("No DRE-F-uLSIF checkpoint is configured for %s; only DRE-F-SP is available" % GAN)


###########################################
#            Sampling Settings
###########################################
# InceptionV3 checkpoint used to assign class labels to fake images
ckpt_PreNetFIDIS_name = './Output/saved_models/ckpt_PreCNNForEvalGANs_InceptionV3_epoch_200_SEED_2019_Transformation_True'
NFAKE = 5000 #number of fake images to draw
# Sel_Class = [0,1,2,3,4,5,6,7,8,9]
# Sel_Class = [1,7]
NFAKE_per_class = 50
NPOOL_SIR = 20000 #Pool size for SIR
samp_batch_size = 1000 #batch size when sampling from the GAN
pred_batch_size = 100 #batch size passed to PredictLabel
MH_K = 640 #number of MH proposal steps per chain
MH_mute = True #do not print sampling progress
DR_comp_batch_size = 50 #batch size when computing density ratios
# explicit checks: `assert` statements are stripped when Python runs with -O
if samp_batch_size <= DR_comp_batch_size:
    raise ValueError("samp_batch_size (%d) must be larger than DR_comp_batch_size (%d)"
                     % (samp_batch_size, DR_comp_batch_size))
if NFAKE <= NFAKE_per_class:
    raise ValueError("NFAKE (%d) must be larger than NFAKE_per_class (%d)" % (NFAKE, NFAKE_per_class))
flag_real_imgs = False #only referenced by commented-out code below
nrow = 10 #images per row in the saved image grids

###########################################
#              Other Settings
###########################################
N_CLASS = 10
NC = 3 #number of channels
IMG_SIZE = 32
NGPU = torch.cuda.device_count()
# fall back to CPU when no CUDA device is available instead of failing on the first .to(device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# seed every RNG used in this notebook
random.seed(Seed)
torch.manual_seed(Seed)
torch.backends.cudnn.deterministic = True
np.random.seed(Seed)

# data loader
# ToTensor scales to [0,1]; Normalize then applies (x - mean) / std per channel
means = (0.5, 0.5, 0.5)
stds = (0.5, 0.5, 0.5)
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])
# identical preprocessing for the test split
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)


############################################
#          fn for computing DR
############################################
# load pre-trained GAN (map_location lets a GPU-saved checkpoint load on the active device)
checkpoint = torch.load(ckpt_GAN_name, map_location=device)
if GAN == "DCGAN":
    netG = cnn_generator(NGPU, dim_GAN).to(device)
    netD = cnn_discriminator(True, NGPU).to(device)
    def fn_sampleGAN(nfake, batch_size=samp_batch_size):
        """Draw `nfake` images from the DCGAN generator, `batch_size` at a time."""
        return SampDCGAN(netG, GAN_Latent_Length = dim_GAN, NFAKE = nfake, batch_size = batch_size)
elif GAN == "WGANGP":
    netG = cnn_generator(NGPU, dim_GAN).to(device)
    netD = cnn_discriminator(True, NGPU).to(device)
    def fn_sampleGAN(nfake, batch_size=samp_batch_size):
        """Draw `nfake` images from the WGAN-GP generator, `batch_size` at a time."""
        return SampWGAN(netG, GAN_Latent_Length = dim_GAN, NFAKE = nfake, batch_size = batch_size)
elif GAN == "MMDGAN":
    # this notebook builds no netD for MMD-GAN, so the discriminator-based DREs cannot run
    if DRE in ("disc", "disc_MHcal"):
        raise ValueError("DRE=%r needs a discriminator, which is not loaded for MMDGAN" % DRE)
    G_decoder = MMDGAN_Decoder(IMG_SIZE, NC, k=dim_GAN, ngf=64, ngpu=NGPU)
    netG = MMDGAN_G(G_decoder).to(device)
    def fn_sampleGAN(nfake, batch_size=samp_batch_size):
        """Draw `nfake` images from the MMD-GAN generator, `batch_size` at a time."""
        return SampMMDGAN(netG, GAN_Latent_Length = dim_GAN, NFAKE = nfake, batch_size = batch_size)
else:
    raise ValueError("Unknown GAN %r" % GAN)
netG.load_state_dict(checkpoint['netG_state_dict'])
if GAN != "MMDGAN":
    netD.load_state_dict(checkpoint['netD_state_dict'])

# compute density ratios
#--------------------------------------
#use GAN property
if DRE == "disc":
    def comp_density_ratio(imgs):
        """Estimate density ratios as D / (1 - D) from the discriminator output.

        Args:
            imgs: NumPy array of images, batched along axis 0.

        Returns:
            float64 array of shape (n_imgs, 1), in the same order as `imgs`.
        """
        n_imgs = imgs.shape[0]
        dataset_tmp = IMGs_dataset(imgs)
        dataloader_tmp = torch.utils.data.DataLoader(dataset_tmp, batch_size=DR_comp_batch_size, shuffle=False, num_workers=0)
        density_ratios = np.zeros((n_imgs, 1))

        netD.eval()
        start = 0
        with torch.no_grad():
            # iterate the loader directly so a short final batch is written to the right slice
            for batch_imgs in dataloader_tmp:
                batch_imgs = batch_imgs.type(torch.float).to(device)
                disc_probs = netD(batch_imgs).cpu().numpy().astype(np.float64).reshape(-1, 1)
                # keep D away from 0 and 1 so the ratio stays finite
                disc_probs = np.clip(disc_probs, 1e-14, 1 - 1e-14)
                end = start + disc_probs.shape[0]
                density_ratios[start:end] = disc_probs / (1 - disc_probs)
                start = end
        return density_ratios
#--------------------------------------
#use GAN property and calibration
elif DRE == "disc_MHcal": 
    n_test = testset.data.shape[0]
    batch_size_tmp = DR_comp_batch_size
    cal_labels_fake = np.zeros((n_test,1))
    cal_labels_real = np.ones((n_test,1))
    cal_imgs_fake = fn_sampleGAN(nfake=n_test, batch_size=batch_size_tmp)
    cal_imgs_real = np.transpose(testset.data, (0, 3, 1, 2))
    #standarize real images
    cal_imgs_real = cal_imgs_real/255.0
    for i in range(NC):
        cal_imgs_real[:,i,:,:] = (cal_imgs_real[:,i,:,:]-means[i])/stds[i]
    dataset_fake = IMGs_dataset(cal_imgs_fake)
    dataloader_fake = torch.utils.data.DataLoader(dataset_fake, batch_size=batch_size_tmp, shuffle=False, num_workers=0)
    dataset_real = IMGs_dataset(cal_imgs_real)
    dataloader_real = torch.utils.data.DataLoader(dataset_real, batch_size=batch_size_tmp, shuffle=False, num_workers=0)
    del cal_imgs_fake, cal_imgs_real; gc.collect()

    # Discriminator score recovered as the logit log(D / (1 - D)) of the probability netD returns;
    # the \tilde{D} in Eq.(4) in "Discriminator Rejection Sampling".
    def comp_disc_scores(imgs_dataloader):
        """Return the discriminator logit for every image served by `imgs_dataloader`.

        Args:
            imgs_dataloader: DataLoader yielding image batches (no labels); it is
                iterated once, in order.

        Returns:
            float64 array of shape (len(imgs_dataloader.dataset), 1), in loader order.
        """
        n_imgs = len(imgs_dataloader.dataset)
        disc_scores = np.zeros((n_imgs, 1))
        netD.eval()
        start = 0
        with torch.no_grad():
            # iterate the loader directly so a short final batch is written to the right slice
            for batch_imgs in imgs_dataloader:
                batch_imgs = batch_imgs.type(torch.float).to(device)
                disc_probs = netD(batch_imgs).cpu().numpy().astype(np.float64).reshape(-1, 1)
                # keep D away from 0 and 1 so the logit stays finite
                disc_probs = np.clip(disc_probs, 1e-14, 1 - 1e-14)
                end = start + disc_probs.shape[0]
                disc_scores[start:end] = np.log(disc_probs / (1 - disc_probs))
                start = end
        return disc_scores

    # compute disc score of a given img which is in tensor format
    def fn_disc_score(img):
        """Return the discriminator logit log(D / (1 - D)) for a single image.

        Args:
            img: tensor of shape 1*NC*IMG_SIZE*IMG_SIZE.

        Returns:
            NumPy array of float64 logits with netD's output shape.
        """
        netD.eval()
        with torch.no_grad():
            img = img.type(torch.float).to(device)
            disc_prob = netD(img).cpu().numpy().astype(np.float64)
        # keep D away from 0 and 1 so the logit stays finite
        disc_prob = np.clip(disc_prob, 1e-14, 1 - 1e-14)
        return np.log(disc_prob / (1 - disc_prob))

    # Discriminator logits on the calibration set (fake first, then real)
    cal_disc_scores_fake = comp_disc_scores(dataloader_fake)
    cal_disc_scores_real = comp_disc_scores(dataloader_real)

    # 1-D logistic regression from logit to P(real): the calibration step of "disc_MHcal"
    X_train = np.vstack((cal_disc_scores_fake, cal_disc_scores_real)).reshape(-1, 1)
    y_train = np.vstack((cal_labels_fake, cal_labels_real)).ravel()
    cal_logReg = LogisticRegression(solver="liblinear").fit(X_train, y_train)

    # function for computing density ratios of a bunch of images with the calibrated discriminator
    def comp_density_ratio(imgs):
        """Estimate density ratios p / (1 - p), where p is the calibrated P(real).

        Args:
            imgs: NumPy array of images, batched along axis 0.

        Returns:
            float64 array of shape (n_imgs, 1), in the same order as `imgs`.
        """
        dataset_tmp = IMGs_dataset(imgs)
        dataloader_tmp = torch.utils.data.DataLoader(dataset_tmp, batch_size=batch_size_tmp, shuffle=False, num_workers=0)
        disc_scores = comp_disc_scores(dataloader_tmp)
        # classes_ are [0, 1], so column 1 is the "real" label used for the test images
        disc_probs = cal_logReg.predict_proba(disc_scores)[:, 1].astype(np.float64)
        disc_probs = np.clip(disc_probs, 1e-14, 1 - 1e-14)
        density_ratios = disc_probs / (1 - disc_probs)
        return density_ratios.reshape(-1, 1)
#--------------------------------------
#our DRE method
elif DRE in ["DRE-F-SP","DRE-F-uLSIF"]: 
    # load pre-trained ResNet34
    PreNetDRE = ResNet34(isometric_map = True, num_classes=N_CLASS, ngpu = NGPU).to(device)
    checkpoint = torch.load(ckpt_PreCNN_name)
    PreNetDRE.load_state_dict(checkpoint['net_state_dict'])
    
    # load pre-trained netDR
    netDR = DR_MLP("MLP5", ngpu=NGPU, final_ActFn="ReLU").to(device)
    checkpoint_netDR = torch.load(ckpt_DR_name)
    netDR.load_state_dict(checkpoint_netDR['net_state_dict'])
    
    def comp_density_ratio(imgs):
        """Estimate density ratios with the DRE-F networks.

        Each image is mapped to a feature vector by PreNetDRE (its second
        output), and netDR maps the features to a density ratio.

        Args:
            imgs: NumPy array of images, batched along axis 0.

        Returns:
            float64 array of shape (n_imgs, 1), in the same order as `imgs`.
        """
        n_imgs = imgs.shape[0]
        dataset_tmp = IMGs_dataset(imgs)
        dataloader_tmp = torch.utils.data.DataLoader(dataset_tmp, batch_size=DR_comp_batch_size, shuffle=False, num_workers=0)
        density_ratios = np.zeros((n_imgs, 1))

        netDR.eval()
        PreNetDRE.eval()
        start = 0
        with torch.no_grad():
            # iterate the loader directly so a short final batch is written to the right slice
            for batch_imgs in dataloader_tmp:
                batch_imgs = batch_imgs.type(torch.float).to(device)
                _, batch_features = PreNetDRE(batch_imgs)
                batch_weights = netDR(batch_features).cpu().numpy().reshape(-1, 1)
                end = start + batch_weights.shape[0]
                density_ratios[start:end] = batch_weights
                start = end
        return density_ratios
    
############################################
#          fn for samplers
############################################
#---------------------------------------------
# Rejection Sampling: "Discriminator Rejection Sampling"; based on https://github.com/shinseung428/DRS_Tensorflow/blob/master/config.py
if Sampler == "RS":
    def fn_enhanceSampler(nfake, batch_size=samp_batch_size):
        """Draw `nfake` images by rejection sampling on top of the GAN.

        A burn-in pool of GAN samples gives an initial estimate `M_bar` of the
        largest density ratio, and every new batch updates it. With a
        discriminator-based DRE the DRS acceptance rule is used (sigmoid of F
        shifted by its 80th percentile within each batch). Otherwise a sample
        is accepted with probability ratio / M_bar.

        Args:
            nfake: number of images to return.
            batch_size: GAN sampling batch size (also used for the burn-in).

        Returns:
            float64 array with the first `nfake` accepted images; shape
            (0, NC, IMG_SIZE, IMG_SIZE) when `nfake` <= 0.
        """
        ## Burn-in stage: initial estimate of the maximum density ratio
        n_burnin = 50000
        burnin_imgs = fn_sampleGAN(n_burnin, batch_size=batch_size)
        burnin_densityraios = comp_density_ratio(burnin_imgs)
        M_bar = np.max(burnin_densityraios)
        del burnin_imgs, burnin_densityraios; gc.collect()
        torch.cuda.empty_cache()
        ## Rejection sampling
        # Collect accepted batches and concatenate once at the end;
        # re-concatenating one growing array every batch is quadratic.
        accepted_batches = []
        pb = SimpleProgressBar()
        num_imgs = 0
        while num_imgs < nfake:
            pb.update(float(num_imgs)*100/nfake)
            batch_imgs = fn_sampleGAN(batch_size, batch_size)
            batch_ratios = comp_density_ratio(batch_imgs)
            M_bar = np.max([M_bar, np.max(batch_ratios)])
            #acceptance probabilities
            if DRE in ["disc", "disc_MHcal"]:
                epsilon_tmp = 1e-8
                D_tilde_M = np.log(M_bar)
                batch_F = np.log(batch_ratios) - D_tilde_M - np.log(1-np.exp(np.log(batch_ratios)-D_tilde_M-epsilon_tmp))
                gamma_tmp = np.percentile(batch_F, 80) #80 percentile of each batch; follow DRS's setting
                batch_F_hat = batch_F - gamma_tmp
                batch_p = 1/(1+np.exp(-batch_F_hat))
            else:
                batch_p = batch_ratios/M_bar
            # one uniform draw per generated image
            batch_psi = np.random.uniform(size=batch_imgs.shape[0]).reshape(-1,1)
            indx_accept = np.where(batch_psi <= batch_p)[0]
            if len(indx_accept) > 0:
                accepted_batches.append(batch_imgs[indx_accept])
                num_imgs += len(indx_accept)
            del batch_imgs, batch_ratios; gc.collect()
            torch.cuda.empty_cache()
        pb.update(100)
        if not accepted_batches:
            return np.zeros((0, NC, IMG_SIZE, IMG_SIZE))
        # float64 keeps the dtype of the former np.zeros-seeded buffer
        return np.concatenate(accepted_batches)[:nfake].astype(np.float64)

#---------------------------------------------
# MCMC, Metropolis-Hastings algorithm: MH-GAN
elif Sampler == "MH":
    trainloader_MH = torch.utils.data.DataLoader(trainset, batch_size=samp_batch_size, shuffle=True, num_workers=0)
    def fn_enhanceSampler(nfake, batch_size=samp_batch_size):
        enhanced_imgs = np.zeros((1, NC, IMG_SIZE, IMG_SIZE)) #initilize
        pb = SimpleProgressBar()
        num_imgs = 0
        while num_imgs < nfake:
            data_iter = iter(trainloader_MH)
            batch_imgs_new, _ = data_iter.next()
            batch_imgs_new = batch_imgs_new.cpu().detach().numpy()
            batch_update_flags = np.zeros(batch_size) #if an img in a batch is updated during MH, replace corresponding entry with 1
            for k in tqdm(range(MH_K)):
                if not MH_mute:
                    print((k, num_imgs))
                batch_imgs_old = fn_sampleGAN(batch_size, batch_size)
                batch_U = np.random.uniform(size=batch_size).reshape(-1,1)
                batch_ratios_old = comp_density_ratio(batch_imgs_old)
                batch_ratios_new = comp_density_ratio(batch_imgs_new)
                batch_p = batch_ratios_old/(batch_ratios_new+1e-14)
                batch_p[batch_p>1]=1
                indx_accept = np.where((batch_U<=batch_p)==True)[0]
                if len(indx_accept)>0:
                    batch_imgs_new[indx_accept] = batch_imgs_old[indx_accept]
                    batch_update_flags[indx_accept] = 1 #if an img in a batch is updated during MH, replace corresponding entry with 1
            indx_updated = np.where(batch_update_flags==1)[0]
            enhanced_imgs = np.concatenate((enhanced_imgs, batch_imgs_new[indx_updated]))
            num_imgs=len(enhanced_imgs)-1
            print("MH already got %d/%d images" % (num_imgs,nfake))
            del batch_imgs_new, batch_imgs_old; gc.collect()
            torch.cuda.empty_cache()
        return enhanced_imgs[1:(nfake+1)] #remove the first all zero array

#---------------------------------------------
# Sampling-Importance Resampling
elif Sampler == "SIR":
   def fn_enhanceSampler(nfake, batch_size=samp_batch_size):
       enhanced_imgs = fn_sampleGAN(NPOOL_SIR, batch_size)
       enhanced_ratios = comp_density_ratio(enhanced_imgs)
       weights = enhanced_ratios / np.sum(enhanced_ratios) #normlaize to [0,1]
       resampl_indx = np.random.choice(a = np.arange(len(weights)), size = nfake, replace = True, p = weights.reshape(weights.shape[0]))
       enhanced_imgs = enhanced_imgs[resampl_indx]
       return enhanced_imgs


############################################
#          Draw fake images
############################################
# load pre-trained InceptionV3 (pretrained on CIFAR-10); used below to label the fake images
PreNetFIDIS = Inception3(num_classes=N_CLASS, aux_logits=True, transform_input=False)
# reuse the path from the sampling settings instead of repeating it
Filename_PreCNNForEvalGANs = ckpt_PreNetFIDIS_name
checkpoint_PreNet = torch.load(Filename_PreCNNForEvalGANs, map_location=device)
PreNetFIDIS = nn.DataParallel(PreNetFIDIS).to(device)
PreNetFIDIS.load_state_dict(checkpoint_PreNet['net_state_dict'])

# generate fake samples
# Sampling directly from the GAN only depends on the sampler; previously DRE != "None"
# with Sampler == "None" crashed because no fn_enhanceSampler was defined.
if Sampler == "None":
    print("Directly sample from GAN >>>")
    fake_imgs = fn_sampleGAN(NFAKE, samp_batch_size)
else:
    print("Enhanced Sampling >>>")
    fake_imgs = fn_enhanceSampler(NFAKE, batch_size=samp_batch_size)
torch.cuda.empty_cache()

# classify them into N_CLASS classes
fake_labels = PredictLabel(fake_imgs, PreNetFIDIS, N_CLASS = N_CLASS, BATCH_SIZE = pred_batch_size, resize = (299, 299))
fake_labels = fake_labels.astype(int) # np.int was removed in NumPy 1.24
# Number of fake images per class, indexed by class id. bincount keeps empty classes;
# counting groups of the sorted labels dropped them and shifted later indices.
num_each_class = np.bincount(fake_labels.ravel(), minlength=N_CLASS).tolist()


#print("select NFAKE_per_class samples from each class >>>")
# fake_imgs_even = np.zeros((NFAKE_per_class*N_CLASS, NC, IMG_SIZE, IMG_SIZE))
# pb = SimpleProgressBar()
# tmp = 0
# for i in range(N_CLASS):
#     N_tmp = num_each_class[i] #number of fake images per class
#     assert N_tmp > NFAKE_per_class
#     idx_i_tmp = np.where(fake_labels==i)[0]
#     np.random.shuffle(idx_i_tmp)
#     idx_i_tmp = idx_i_tmp[0:NFAKE_per_class]
#     fake_imgs_even[tmp:(tmp+NFAKE_per_class)] = fake_imgs[idx_i_tmp]
#     tmp += NFAKE_per_class
#     pb.update(float(tmp)*100/(NFAKE_per_class*N_CLASS))

SyntaxError: EOL while scanning string literal (<ipython-input-2-ebb9a336bd64>, line 2)

In [8]:
# select NFAKE_per_class samples from each class
NFAKE_per_class = 10
Sel_Class = [0,1,2,3,4,5,6,7,8,9]
# Sel_Class = [1,7]
print("select NFAKE_per_class samples from each class >>>")
n_sel_total = NFAKE_per_class*len(Sel_Class)
fake_imgs_even = np.zeros((n_sel_total, NC, IMG_SIZE, IMG_SIZE))
pb = SimpleProgressBar()
tmp = 0
for class_id in Sel_Class:
    idx_i_tmp = np.where(fake_labels==class_id)[0]
    # count from the labels themselves; exactly NFAKE_per_class images is enough
    N_tmp = len(idx_i_tmp) #number of fake images in this class
    if N_tmp < NFAKE_per_class:
        raise ValueError("class %d has only %d fake images, need %d" % (class_id, N_tmp, NFAKE_per_class))
    np.random.shuffle(idx_i_tmp)
    idx_i_tmp = idx_i_tmp[0:NFAKE_per_class]
    fake_imgs_even[tmp:(tmp+NFAKE_per_class)] = fake_imgs[idx_i_tmp]
    tmp += NFAKE_per_class
    pb.update(float(tmp)*100/n_sel_total)

#save selected fake images (one class per row when nrow == NFAKE_per_class)
image_filename = './Output/saved_images/Visualize_Quality_Improvement_GAN_'\
                +GAN+'_DRE_'+DRE+'_Sampler_'+Sampler+'.pdf'
save_image(torch.from_numpy(fake_imgs_even).data, image_filename, nrow=nrow, normalize=True)


select NFAKE_per_class samples from each class >>>
100% [##################################################]


In [None]:
#find closest real imgs; using mean square error
#if the closest real images are very similar to their counterpart, we may have overfitting problems here.
print("Find closest real images for fake images >>>")
# NHWC uint8 -> contiguous NCHW float64, then the same per-channel standardization as the GAN data
IMGSr_train = np.ascontiguousarray(np.transpose(trainset.data, (0, 3, 1, 2))) / 255.0
for i in range(NC):
    IMGSr_train[:,i,:,:] = (IMGSr_train[:,i,:,:] - means[i]) / stds[i]
n_fake_sel = fake_imgs_even.shape[0]
fake_flat = fake_imgs_even.reshape(n_fake_sel, -1)
real_flat = IMGSr_train.reshape(IMGSr_train.shape[0], -1)
# All pairwise MSEs at once via ||f - r||^2 = ||f||^2 + ||r||^2 - 2 f.r, replacing a
# Python double loop that made one np.mean call per (fake, real) pair.
sq_dist = (np.einsum('ij,ij->i', fake_flat, fake_flat)[:, None]
           + np.einsum('ij,ij->i', real_flat, real_flat)[None, :]
           - 2.0 * (fake_flat @ real_flat.T))
# rounding can push near-zero distances slightly below zero
mse_fake_to_real = np.maximum(sq_dist, 0.0) / fake_flat.shape[1]
closest_real_imgs_for_fake = IMGSr_train[np.argmin(mse_fake_to_real, axis=1)]

#save closest real images
image_filename = './Output/saved_images/Visualize_Quality_Improvement_GAN_'\
                +GAN+'_DRE_'+DRE+'_Sampler_'+Sampler+'_ClosestRealImgs.pdf'
save_image(torch.from_numpy(closest_real_imgs_for_fake).data, image_filename, nrow=nrow, normalize=True)

In [None]:
# NFAKE_per_class = 10
# Sel_Class = [0,1,2,3,4,5,6,7,8,9]
# # Sel_Class = [1,7]

#if flag_real_imgs: #output real images
real_imgs_test = np.transpose(testset.data, (0, 3, 1, 2))
#rescale to [0,1]
real_imgs_test = real_imgs_test/255.0
# standardize with the same per-channel means/stds as the GAN data (with 0.5/0.5 this gives [-1,1])
for i in range(NC):
    real_imgs_test[:,i,:,:] = (real_imgs_test[:,i,:,:] - means[i]) / stds[i]
real_labels_test = np.array(testset.targets)

# select NFAKE_per_class real samples from each class
print("select NFAKE_per_class real samples from each class >>>")
n_sel_total = NFAKE_per_class*len(Sel_Class)
real_imgs_even = np.zeros((n_sel_total, NC, IMG_SIZE, IMG_SIZE))
pb = SimpleProgressBar()
tmp = 0
for class_id in Sel_Class:
    idx_i_tmp = np.where(real_labels_test==class_id)[0]
    np.random.shuffle(idx_i_tmp)
    idx_i_tmp = idx_i_tmp[0:NFAKE_per_class]
    real_imgs_even[tmp:(tmp+NFAKE_per_class)] = real_imgs_test[idx_i_tmp]
    tmp += NFAKE_per_class
    pb.update(float(tmp)*100/n_sel_total)

#save images
image_filename = './Output/saved_images/Visualize_Quality_Improvement_GAN_'\
                +GAN+'_real_data.pdf'
save_image(torch.from_numpy(real_imgs_even).data, image_filename, nrow=nrow, normalize=True)

# 2. Compare convergence of the uLSIF and SP losses on DCGAN samples
The DR model (MLP5 with a final ReLU) and the sampler (RS) are fixed. <br/>
Plot the training losses under different values of $\lambda$.

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import os
%matplotlib inline

wd = "/home/xin/Working directory/DDRE_Sampling_GANs/CIFAR10/Output/Training_loss_fig/Compare_convergence_diff_loss"
# NOTE(review): this differs from the working directory used in section 1
# ('/home/xin/OneDrive/Working_directory/...'); confirm which path is current.
os.chdir(wd)
# name of npy files; under the optimal setting

file_uLSIF_1 = wd+"/DRE_F_uLSIF_MLP5_LAMBDA0.0_epochDRE500_DCGAN_epochGAN500_TrainLoss.npy"
file_uLSIF_2 = wd+"/DRE_F_uLSIF_MLP5_LAMBDA0.1_epochDRE500_DCGAN_epochGAN500_TrainLoss.npy"
file_uLSIF_3 = wd+"/DRE_F_uLSIF_MLP5_LAMBDA1.0_epochDRE500_DCGAN_epochGAN500_TrainLoss.npy"
file_uLSIF_4 = wd+"/DRE_F_uLSIF_MLP5_LAMBDA5.0_epochDRE500_DCGAN_epochGAN500_TrainLoss.npy"
file_uLSIF_5 = wd+"/DRE_F_uLSIF_MLP5_LAMBDA10.0_epochDRE500_DCGAN_epochGAN500_TrainLoss.npy"

file_SP_1 = wd+"/DRE_F_SP_MLP5_LAMBDA0.0_epochDRE500_DCGAN_epochGAN500_TrainLoss.npy"
file_SP_2 = wd+"/DRE_F_SP_MLP5_LAMBDA0.1_epochDRE500_DCGAN_epochGAN500_TrainLoss.npy"
file_SP_3 = wd+"/DRE_F_SP_MLP5_LAMBDA1.0_epochDRE500_DCGAN_epochGAN500_TrainLoss.npy"
file_SP_4 = wd+"/DRE_F_SP_MLP5_LAMBDA5.0_epochDRE500_DCGAN_epochGAN500_TrainLoss.npy"
file_SP_5 = wd+"/DRE_F_SP_MLP5_LAMBDA10.0_epochDRE500_DCGAN_epochGAN500_TrainLoss.npy"

# load training loss
loss_uLSIF_1 = np.load(file_uLSIF_1)
loss_uLSIF_2 = np.load(file_uLSIF_2)
loss_uLSIF_3 = np.load(file_uLSIF_3)
loss_uLSIF_4 = np.load(file_uLSIF_4)
loss_uLSIF_5 = np.load(file_uLSIF_5)
loss_SP_1 = np.load(file_SP_1)
loss_SP_2 = np.load(file_SP_2)
loss_SP_3 = np.load(file_SP_3)
loss_SP_4 = np.load(file_SP_4)
loss_SP_5 = np.load(file_SP_5)

# (legend label, colour) for the curves, in the same lambda order as the loss files above
LAMBDA_CURVE_STYLES = [
    (r"$\lambda=0$", 'r'),
    (r"$\lambda=0.1$", 'b'),
    (r"$\lambda=1$", 'y'),
    (r"$\lambda=5$", 'g'),
    (r"$\lambda=10$", 'm'),
]

def plot_training_curves(losses, filename):
    """Plot one training-loss curve per lambda, save it to `filename`, and show it.

    Args:
        losses: sequence of 1-D loss arrays, ordered like LAMBDA_CURVE_STYLES.
        filename: output path passed to Figure.savefig.

    Returns:
        The matplotlib Figure.
    """
    fig = plt.figure()
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    for loss, (label, color) in zip(losses, LAMBDA_CURVE_STYLES):
        # x axis taken from the loss length, so runs with any epoch count plot correctly
        epochs = np.arange(start = 1, stop = len(loss)+1)
        plt.plot(epochs, loss, c=color, label=label)
    plt.xlabel("epoch")
    plt.ylabel("training loss")
    plt.legend()
    # save before show: some backends close the figure inside show()
    fig.savefig(filename, bbox_inches='tight')
    plt.show()
    return fig

# plot training curves
filename_uLISF = wd+"/Convergence_uLISF.pdf"
filename_SP = wd+"/Convergence_SP.pdf"
num_epochs = 500
x = np.arange(start = 1, stop = num_epochs+1) # kept for later cells; the plots derive their own x

f1 = plot_training_curves([loss_uLSIF_1, loss_uLSIF_2, loss_uLSIF_3, loss_uLSIF_4, loss_uLSIF_5], filename_uLISF)
f2 = plot_training_curves([loss_SP_1, loss_SP_2, loss_SP_3, loss_SP_4, loss_SP_5], filename_SP)


# 3. Visualize IS or FID under different values of $\lambda$ for DRE-F-SP