In [3]:
import logging.handlers
import os

import matplotlib.colors as colors
import torch

from dataloader import get_loader
from DCEVAE import DCEVAE

import argparse

parser = argparse.ArgumentParser()
# Run / hardware options.
parser.add_argument('--seed', type=int, default=5, help='random seed')
parser.add_argument('--gpu', type=int, default=0, help='index of the GPU to use')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')

# Latent sizes and loss weights; the whole namespace is handed to DCEVAE.
# NOTE(review): the meaning of each beta is defined in the DCEVAE model code.
parser.add_argument('--ur_dim', type=int, default=5, help='dimension of ur')
parser.add_argument('--ud_dim', type=int, default=5, help='dimension of ud')
parser.add_argument('--beta1', type=float, default=1, help='beta1')
parser.add_argument('--beta2', type=float, default=40, help='beta2')
parser.add_argument('--beta3', type=float, default=1, help='beta3')
parser.add_argument('--beta4', type=float, default=3.2, help='beta4')
parser.add_argument('--beta5', type=float, default=1, help='beta5')

# Only 'M' and 'S' have attribute lists defined below; reject anything else
# up front instead of failing later with a NameError.
parser.add_argument('--int', type=str, default='S', choices=['M', 'S'],
                    help='intervention variable (M: mustache; S: smiling)')

# Training schedule.
parser.add_argument('--max_epochs', type=int, default=500, help='max epochs')
parser.add_argument('--save_per_epoch', type=int, default=10, help='save per epoch')
parser.add_argument('--early_stop', type=int, default=30, help='early stop')
parser.add_argument('--lr', type=float, default=3e-4, help='learning rate')

# Parse an explicit empty argv: inside Jupyter sys.argv carries the kernel's
# own flags, so every option takes its default here.
args = parser.parse_args([])

import random

import numpy as np

# Seed the NumPy, PyTorch and Python RNGs and make cuDNN use deterministic,
# non-autotuned kernels so runs are repeatable.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False
random.seed(args.seed)

# Expose only the GPU selected by --gpu, enumerating devices in PCI bus order.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES=str(args.gpu))
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

print('-----------------------------------', args.device, '-----------------------------------', sep='\n')

# CelebA lives two directory levels above the notebook's working directory.
# ``__file__`` is undefined in a notebook, so resolve from the cwd explicitly.
src_path = os.path.abspath(os.path.join(os.path.realpath(os.getcwd()), os.pardir, os.pardir))
data_df = os.path.join(src_path, 'data', 'celebA', 'images')  # image directory (not a DataFrame)
attr_path = os.path.join(src_path, 'data', 'celebA', 'list_attr_celeba.txt')

# Attribute lists per intervention: ``sens_attrs`` is intervened on, ``des``
# are descendants of it, ``whole`` is passed as the remaining attribute list.
if args.int == 'M':
    whole = ['Young', 'Male', 'Eyeglasses', 'Bald', 'Mustache', 'Smiling',
             'Wearing_Lipstick', 'Mouth_Slightly_Open', 'Narrow_Eyes']
    sens_attrs = ['Mustache']
    des = []
elif args.int == 'S':
    # NOTE(review): 'Mustache' appears twice here -- possibly a copy/paste slip.
    # Left unchanged because it may affect rest_dim and hence compatibility with
    # already-trained checkpoints; confirm against the training notebook.
    whole = ['Young', 'Male', 'Eyeglasses', 'Bald', 'Mustache', 'Mustache', 'Wearing_Lipstick']
    sens_attrs = ['Smiling']
    des = ['Mouth_Slightly_Open', 'Narrow_Eyes']
else:
    raise ValueError("args.int must be 'M' or 'S', got {!r}".format(args.int))

# Infer the sens / rest / descendant attribute widths from one test batch.
test_loader = get_loader(data_df, attr_path, whole, sens_attrs, des, mode='test')
args.fixed_batch = next(iter(test_loader))
img, sens, rest_att, des_att, _, _ = args.fixed_batch
sens_dim = sens.shape[1]
rest_dim = rest_att.shape[1]
des_dim = des_att.shape[1]

model = DCEVAE(args, sens_dim=sens_dim, rest_dim=rest_dim, des_dim=des_dim,
               ur_dim=args.ur_dim, ud_dim=args.ud_dim).to(args.device)

# Results go to <parent of cwd>/result/ud_ur_dim_<ud>_<ur>_<beta1..5>/<int><seed>.
# ``__file__`` is undefined in a notebook, so resolve from the cwd explicitly.
src_path = os.path.abspath(os.path.join(os.path.realpath(os.getcwd()), os.pardir))
result_path = os.path.join(src_path, 'result')

run_dir = os.path.join(result_path, "ud_ur_dim_{:d}_{:d}_{:.2f}_{:.2f}_{:.2f}_{:.2f}_{:.2f}"
                       .format(args.ud_dim, args.ur_dim, args.beta1, args.beta2, args.beta3, args.beta4, args.beta5))
args.save_path = os.path.join(run_dir, str(args.int) + str(args.seed))

# makedirs creates every missing level at once and tolerates directories that
# already exist (no exists()/mkdir race, no failure on missing parents).
os.makedirs(args.save_path, exist_ok=True)

-----------------------------------
cuda
-----------------------------------
Finished preprocessing the CelebA dataset...
test 19962


In [4]:
class MidpointNormalize(colors.Normalize):
    """Piecewise-linear normalization that maps ``midpoint`` to 0.5.

    Data in [vmin, midpoint] maps linearly onto [0, 0.5] and data in
    [midpoint, vmax] onto [0.5, 1]. Values outside [vmin, vmax] are clamped
    to 0 / 1 (``np.interp``'s default behaviour). With a diverging colormap
    this puts the neutral colour at ``midpoint`` even for asymmetric ranges.

    Parameters
    ----------
    vmin, vmax : float, optional
        Data range. If ``None`` they are filled from the first array passed
        to ``__call__`` via ``Normalize.autoscale_None``.
    midpoint : float, optional
        Data value mapped to the colormap centre. ``None`` means the middle
        of [vmin, vmax].
    clip : bool
        Kept for ``Normalize`` compatibility; output is always clamped.

    Note: ``np.interp`` requires vmin <= midpoint <= vmax; outside that
    ordering the mapping is meaningless.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        colors.Normalize.__init__(self, vmin, vmax, clip)

    def _anchors(self):
        """Return (data anchors, normalized anchors) for the two linear pieces."""
        midpoint = self.midpoint
        if midpoint is None:
            midpoint = 0.5 * (self.vmin + self.vmax)
        return [self.vmin, midpoint, self.vmax], [0.0, 0.5, 1.0]

    def __call__(self, value, clip=None):
        # ``clip`` is accepted for API compatibility; np.interp already clamps.
        result, is_scalar = self.process_value(value)
        self.autoscale_None(result)
        x, y = self._anchors()
        normed = np.ma.array(np.interp(result, x, y), mask=result.mask, copy=False)
        return normed[0] if is_scalar else normed

    def inverse(self, value):
        """Map normalized values in [0, 1] back to data values.

        Needed so colorbars place ticks consistently with ``__call__``;
        otherwise the linear ``Normalize.inverse`` would be used.
        """
        x, y = self._anchors()
        return np.interp(value, y, x)

In [5]:
import os
import torch

def test(args, test_loader):
    """Load the saved DCEVAE checkpoint and plot its latent covariance on test data.

    Loads ``<args.save_path>/model.pth``, which holds a whole pickled model
    (it is used directly as a module below, not as a state_dict), puts it in
    eval mode and calls ``draw_cov`` with ``test=True`` so the figure is written
    into ``args.save_path``.

    Args:
        args: namespace providing ``save_path`` (and whatever ``draw_cov`` reads).
        test_loader: test-set loader passed through to ``draw_cov``.
    """
    import inspect

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = os.path.join(args.save_path, 'model.pth')

    # map_location lets a checkpoint saved on GPU open on a CPU-only machine.
    # torch.load unpickles arbitrary objects: only load trusted checkpoints.
    # torch>=2.6 defaults to weights_only=True, which cannot restore a whole
    # pickled module, so opt out explicitly when the keyword exists (torch>=1.13).
    load_kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    test_model = torch.load(model_path, **load_kwargs)
    test_model.to(device)
    test_model.eval()

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        draw_cov(args, test_model, test_loader, 0, test=True)
    print('true')

In [25]:
import matplotlib
matplotlib.use('Agg')
import matplotlib.colors as colors

import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
import os
import torchvision.utils as vutils
import torch

def _collect_latents(model, loader, device, max_samples):
    """Encode batches from ``loader``; return the first ``max_samples`` latents u on CPU."""
    chunks = []
    n_seen = 0
    with torch.no_grad():
        for batch in loader:
            data, sens = batch[0].to(device), batch[1].to(device)
            _, _, _, _, u = model(data, sens)
            chunks.append(u.detach().cpu())
            n_seen += u.shape[0]
            if n_seen >= max_samples:
                break
    if not chunks:
        raise ValueError('draw_cov: loader yielded no batches')
    return torch.cat(chunks, 0)[:max_samples]


def _latent_covariance(u):
    """Biased (1/N) covariance of the rows of the 2-D tensor ``u`` as a numpy array."""
    centered = u - u.mean(dim=0, keepdim=True)
    return (centered.T @ centered / u.shape[0]).numpy()


def _within_group_order(cov, split):
    """Index order sorting each of the two latent groups by ascending within-group row sum."""
    first = np.argsort(cov[:split, :split].sum(axis=1), kind='stable')
    second = split + np.argsort(cov[split:, split:].sum(axis=1), kind='stable')
    return np.concatenate([first, second])


def draw_cov(args, model, loader, epoch, test=True, max_samples=1000, group_split=5):
    """Plot and save the covariance matrix of the model's latent code ``u``.

    Encodes up to ``max_samples`` examples from ``loader``. It then computes
    the 1/N covariance of ``u`` (the last of the five values returned by
    ``model(data, sens)``). Dimensions are reordered inside each of the two
    groups ``[0, group_split)`` and ``[group_split, dim)`` by ascending
    within-group row sum. Finally the matrix is drawn with a diverging
    colormap centred on 0 and saved into ``args.save_path``.

    Args:
        args: namespace with ``int`` ('S' first applies a fixed hand-picked
            10-dim permutation) and ``save_path`` (output directory).
        model: module called as ``model(data, sens)`` returning 5 values.
        loader: iterable of batches whose first two items are data and sens.
        epoch: epoch number used in the file name when ``test`` is False.
        test: True -> 'test_cov_rearrange.png', else 'valid_cov_epoch_<epoch>.png'.
        max_samples: number of examples used for the estimate.
        group_split: first index of the second latent group.
            NOTE(review): default 5 matches ur_dim = ud_dim = 5; confirm which
            of ur/ud occupies the leading columns of u.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    u_all = _collect_latents(model, loader, device, max_samples)
    if args.int == 'S':
        # Hand-picked starting order for the smiling experiment (needs 10 latent dims).
        u_all = u_all[:, [3, 2, 4, 0, 1, 6, 9, 8, 7, 5]]

    # Estimate once over all collected samples, then permute rows and columns;
    # this equals the covariance of the reordered latents.
    cov = _latent_covariance(u_all)
    order = _within_group_order(cov, group_split)
    cov = cov[np.ix_(order, order)]

    png_cov = 'test_cov_rearrange.png' if test else 'valid_cov_epoch_{:d}.png'.format(epoch)

    # Keep 0 inside [vmin, vmax] so MidpointNormalize's anchors stay ascending
    # (np.interp is undefined otherwise, e.g. when every covariance is positive).
    vmin = min(float(cov.min()), 0.0)
    vmax = max(float(cov.max()), 0.0)

    fig, ax = plt.subplots()
    im = ax.imshow(cov, cmap='RdBu_r', norm=MidpointNormalize(midpoint=0, vmin=vmin, vmax=vmax))
    fig.colorbar(im, ax=ax)
    # Save before showing: interactive/inline backends may release the figure on show().
    fig.savefig(os.path.join(args.save_path, png_cov))
    plt.show()
    plt.close(fig)

In [26]:
# Evaluate the saved checkpoint on the test loader; draw_cov writes the figure into args.save_path.
test(args=args, test_loader=test_loader)

true
