In [35]:
import os
import torch
import logging.handlers

from dataloader import get_loader
from MCEVAE import MCEVAE

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=5, help='random seed')
parser.add_argument('--gpu', type=int, default=0, help='number of gpu')
parser.add_argument('--batch_size', type=int, default=64, help='number of gpu')

parser.add_argument('--ur_dim', type=int, default=5, help='dimension of ur')
parser.add_argument('--ud_dim', type=int, default=5, help='dimension of ud')
parser.add_argument('--beta1', type=float, default=1, help='beta1')
parser.add_argument('--beta2', type=float, default=40, help='beta2')
parser.add_argument('--beta3', type=float, default=1, help='beta3')
parser.add_argument('--beta4', type=float, default=3.2, help='beta4')
parser.add_argument('--beta5', type=float, default=1, help='beta5')

parser.add_argument('--int', type=str, default='S', help='intervention variable) M: mustache; S: smiling')

parser.add_argument('--max_epochs', type=int, default=500, help='max epochs')
parser.add_argument('--save_per_epoch', type=int, default=10, help='save per epoch')
parser.add_argument('--early_stop', type=int, default=30, help='early stop')
parser.add_argument('--lr', type=float, default=3e-4, help='learning rate')

args = parser.parse_args([])

import numpy as np
import random

np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

print('-----------------------------------')
print(args.device)
print('-----------------------------------')

src_path = os.path.dirname(os.path.realpath('__file__'))
src_path = os.path.abspath(os.path.join(src_path, os.pardir))
src_path = os.path.abspath(os.path.join(src_path, os.pardir))
data_df = os.path.join(src_path, 'data', 'celebA', 'images')
attr_path = os.path.join(src_path, 'data', 'celebA','list_attr_celeba.txt')

if args.int == 'M':
    whole = ['Young', 'Male', 'Eyeglasses', 'Bald', 'Mustache', 'Smiling', \
             'Wearing_Lipstick', 'Mouth_Slightly_Open', 'Narrow_Eyes']
    sens = ['Mustache']
    des = []
elif args.int == 'S':
    whole = ['Young', 'Male', 'Eyeglasses', 'Bald', 'Mustache', 'Mustache', 'Wearing_Lipstick']
    sens = ['Smiling']
    des = ['Mouth_Slightly_Open', 'Narrow_Eyes']

# Dimension
test_loader = get_loader(data_df, attr_path, whole, sens, des, mode='test')
args.fixed_batch = next(iter(test_loader))
img, sens, rest_att, des_att = args.fixed_batch
sens_dim = sens.shape[1]
rest_dim = rest_att.shape[1]
des_dim = des_att.shape[1]

model = MCEVAE(args, sens_dim=sens_dim, rest_dim=rest_dim, des_dim=des_dim, u_dim=10).to(args.device)

src_path = os.path.dirname(os.path.realpath('__file__'))
src_path = os.path.abspath(os.path.join(src_path, os.pardir))
result_path = os.path.join(src_path, 'result_S')

hyp = 'u_dim_10_1.0_40.0_1.0_0.1_5.0'
args.save_path = os.path.join(result_path, hyp)
if not os.path.exists(args.save_path):
    os.mkdir(args.save_path)
# args.save_path = os.path.join(args.save_path, str(args.int)+str(args.seed))
# if not os.path.exists(args.save_path):
#     os.mkdir(args.save_path)

-----------------------------------
cuda
-----------------------------------
Finished preprocessing the CelebA dataset...
test 19962


In [36]:
import matplotlib.colors as colors
class MidpointNormalize(colors.Normalize):
    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        colors.Normalize.__init__(self, vmin, vmax, clip)

    def __call__(self, value, clip=None):
        # Note that I'm ignoring clipping and other edge cases here.
        result, is_scalar = self.process_value(value)
        x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
        return np.ma.array(np.interp(value, x, y), mask=result.mask, copy=False)

In [37]:
import os
import torch

def test(args, test_loader):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = os.path.join(args.save_path, 'model.pth')
    test_model = torch.load(model_path)
    test_model.to(device)
    test_model.eval()

    fixed_batch = next(iter(test_loader))
    fixed_data, fixed_sens, fixed_rest, fixed_des = fixed_batch
    fixed_data, fixed_sens, fixed_rest, fixed_des = \
    fixed_data.to(device), fixed_sens.to(device), fixed_rest.to(device), fixed_des.to(device)
    draw_cov(args, test_model, test_loader, 0, test=True)
    print('true')

In [103]:
import matplotlib
matplotlib.use('Agg')
import matplotlib.colors as colors

import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
import os
import torchvision.utils as vutils
import torch

def draw_cov(args, model, loader, epoch, test=True):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    cnt = 0
    for cur_data, cur_sens, cur_rest, cur_des in loader:
        cur_data, cur_sens = cur_data.to(device), cur_sens.to(device)
        cur_rest, cur_des = cur_rest.to(device), cur_des.to(device)


        _, _, u_mu, u_logvar, u = model.forward(cur_data, cur_sens, cur_rest, cur_des)
        
        if cnt == 0:
            u_whole = u
        else:
            u_whole = torch.cat([u_whole, u], 0)
        cnt += 1

        if cnt * cur_data.shape[0] > 1000:
            u_whole = u_whole[:1000, :]
            break

    mu = torch.mean(u_whole, 0)
    u_minus_mu = u_whole - mu
    cov_prev = torch.zeros(u.shape[0], u.shape[1], u.shape[1])

    for i in range(u.shape[0]):
        cov_i = torch.matmul(u_minus_mu[i].unsqueeze_(-1), u_minus_mu[i].unsqueeze_(0))
        cov_prev[i, :, :] = cov_i
        
    cmap = matplotlib.cm.RdBu_r  # set the colormap to soemthing diverging

    cov = torch.mean(cov_prev, 0).detach().numpy()
    
    for i in range(10):
        cov[i,i] = 0
    
    check1 = {}
    for i in range(10):
        check1[np.sum(cov[i, 0:9])] = i
    
    sorted1 = list(check1.keys())
    sorted1.sort()
    sorted1 = sorted1[::-1]
    print(sorted1)
    idx1 = [check1[idx] for idx in sorted1]
    
    cnt = 0
    for cur_data, cur_sens, cur_rest, cur_des in loader:
        if cnt == 0:
            sens = cur_sens
        else:
            sens = torch.cat([sens, cur_sens])
        cur_data, cur_sens = cur_data.to(device), cur_sens.to(device)
        cur_rest, cur_des = cur_rest.to(device), cur_des.to(device)


        _, _, u_mu, u_logvar, u = model.forward(cur_data, cur_sens, cur_rest, cur_des)

        u = u[:, idx1]
        if cnt == 0:
            u_whole = u
        else:
            u_whole = torch.cat([u_whole, u], 0)
        cnt += 1

        if cnt * cur_data.shape[0] > 1000:
            u_whole = u_whole[:1000, :]
            sens = sens[:1000]
            break
            
    print(sens.shape)
    sens = sens.squeeze(1)
    print(sens.shape)
    a0 = (sens==0)
    a0 = a0.sum().item()
    print(a0)
    a1 = (sens==1)
    a1 = a1.sum().item()
    print(a1)
    mu = torch.mean(u_whole, 0)
    u_minus_mu = u_whole - mu
    cov_prev = torch.zeros(u.shape[0], u.shape[1], u.shape[1])
    cov_prev0 = torch.zeros(a0, u.shape[1], u.shape[1])
    cov_prev1 = torch.zeros(a1, u.shape[1], u.shape[1])

    for i in range(u.shape[0]):
        cov_i = torch.matmul(u_minus_mu[i].unsqueeze_(-1), u_minus_mu[i].unsqueeze_(0))
        cov_prev[i, :, :] = cov_i
        
        if sens[i] == 0:
            cov_prev0[i, :, :] = cov_i
        else:
            cov_prev1[i, :, :] = cov_i
        
    cmap = matplotlib.cm.RdBu_r  # set the colormap to soemthing diverging
    
    cov = torch.mean(cov_prev, 0).detach().numpy()
    cov0 = torch.mean(cov_prev0, 0).detach().numpy()
    cov1 = torch.mean(cov_prev1, 0).detach().numpy()
    
    for i in range(10):
        cov[i,i] = 0
        cov0[i,i] = 0
        cov1[i,i] = 0

    png_cov0 = 'test_cov0_rearrange.png' if test == True else 'valid_cov_epoch_{:d}.png'.format(epoch)
    png_cov1 = 'test_cov1_rearrange.png' if test == True else 'valid_cov_epoch_{:d}.png'.format(epoch)
    png_cov = 'test_cov_rearrange.png' if test == True else 'valid_cov_epoch_{:d}.png'.format(epoch)

    # fig = plt.figure()
    # im = plt.imshow(cov, cmap='RdBu')
    # plt.colorbar(im, shrink=0.75)
    # plt.savefig(os.path.join(args.save_path, png_cov))

    elev_min = np.min(cov)
    elev_max = np.max(cov)
    mid_val = 0
    
    elev_min0 = np.min(cov0)
    elev_max0 = np.max(cov0)
    
    elev_min1 = np.min(cov1)
    elev_max1 = np.max(cov1)

    plt.figure()
    plt.imshow(cov, cmap=cmap, clim=(elev_min, elev_max),norm=MidpointNormalize(midpoint=mid_val, vmin=elev_min, vmax=elev_max))
    plt.colorbar()
    plt.show()
    plt.savefig(os.path.join(args.save_path, png_cov))
    
    plt.figure()
    plt.imshow(cov0, cmap=cmap, clim=(elev_min0, elev_max0),norm=MidpointNormalize(midpoint=mid_val, vmin=elev_min, vmax=elev_max))
    plt.colorbar()
    plt.show()
    plt.savefig(os.path.join(args.save_path, png_cov0))
    
    plt.figure()
    plt.imshow(cov1, cmap=cmap, clim=(elev_min1, elev_max1),norm=MidpointNormalize(midpoint=mid_val, vmin=elev_min, vmax=elev_max))
    plt.colorbar()
    plt.show()
    plt.savefig(os.path.join(args.save_path, png_cov1))

In [104]:
test(args, test_loader)

[32.650993, 22.619572, 20.259317, 15.888016, 14.351412, 3.8354201, -0.8162542, -2.3802693, -11.860994, -24.991594]
torch.Size([1000, 1])
torch.Size([1000])
488
512
true
