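# CelebA

Inspect test images from a protected group that a trained CelebA attribute classifier misses (positive label, negative prediction), then explore the raw images, attribute labels and CelebRace race scores.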

In [None]:
# Re-import modules automatically before each cell runs, so edits to the local
# helpers (celeb_race, post_hoc_celeba) are picked up without restarting.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline below each cell.
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
from os import listdir
from os.path import isfile, join
from pathlib import Path
from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import yaml
from sklearn.metrics import roc_auc_score
from torchvision import models, transforms

from celeb_race import CelebRace, unambiguous
from post_hoc_celeba import load_celeba, get_resnet_model

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
torch.manual_seed(0)

# Column names for the label tensors used by show_bad_predictions.
descriptions = (
    # The 40 CelebA binary attributes (presumably in the header order of
    # list_attr_celeba.txt -- verified by an assert where that file is loaded).
    ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald',
     'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
     'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
     'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
     'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
     'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks',
     'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
     'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young']
    # Extra label columns: race labels and the image number used to locate the
    # jpg (TODO confirm against post_hoc_celeba.load_celeba).
    + ['White', 'Black', 'Asian', 'Index']
)

In [None]:
def image_from_index(index, folder='~/post_hoc_debiasing/data/celeba/img_align_celeba/', show=False):
    """Load the aligned CelebA image with the given file number.

    Args:
        index: image number. It is converted to int and zero-padded to six digits
            (42 -> '000042.jpg'). Floats such as 42.0, as read from a float
            label tensor, are accepted too.
        folder: directory holding the images; '~' is expanded.
        show: if True, also display the image with matplotlib.

    Returns:
        The decoded PIL image.
    """
    path = Path(folder).expanduser() / f'{int(index):06d}.jpg'
    img = Image.open(path)
    # Image.open is lazy and keeps the file handle open until the pixels are read.
    # Decoding now lets Pillow close the file, so callers can collect many images
    # without piling up open descriptors.
    img.load()
    if show:
        plt.imshow(img)
        plt.show()
    return img

def imshow_from_tensor(img):
    """Plot a channel-first image tensor with matplotlib.

    The tensor's axes are reordered from (C, H, W) to (H, W, C) and plotted as-is.
    Values are not un-normalized, so this only looks right for untransformed data.
    """
    channels_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()
    
def imshow_group(data, n):
    """Show images from ``data`` side by side in one row of ``n`` columns.

    Args:
        data: indexable sequence of images accepted by ``plt.imshow`` (e.g.
            numpy arrays or PIL images).
        n: number of columns in the row. If ``data`` has fewer than ``n``
            items, only the available images are drawn. This avoids an
            IndexError, so callers can pass short slices such as ``inds[8:16]``.
    """
    plt.figure(figsize=(20, 10))
    for i in range(min(n, len(data))):
        plt.subplot(1, n, i + 1)
        plt.axis('off')
        plt.imshow(data[i])

In [None]:
def show_bad_predictions(checkpoint='by_checkpoint.pt',
                         prediction_attr='Young',
                         protected_attr='Black',
                         threshold=0,
                         n=8):
    """Display up to ``n`` false negatives from the protected group.

    Scans the test split returned by ``load_celeba`` and collects images that
    meet all three conditions:

    - the ``protected_attr`` label is truthy;
    - the ``prediction_attr`` label is truthy;
    - the model's score (first output column) is below ``threshold``.

    Each collected score is printed, then the images are shown in one row.

    Args:
        checkpoint: path to a saved dict holding ``'model_state_dict'`` for the
            network built by ``get_resnet_model``.
        prediction_attr: name (in the global ``descriptions``) of the target label.
        protected_attr: name (in ``descriptions``) of the protected-group label.
        threshold: scores strictly below this count as a negative prediction.
        n: maximum number of images to collect and display.

    NOTE(review): the truthiness tests assume 0/1 labels from ``load_celeba``.
    With raw CelebA -1/+1 labels, -1 would also count as positive -- confirm.
    """
    # Only the test split is needed; the first five return values are unused.
    *_, testloader = load_celeba(trainsize=1000,
                                 testsize=10000,
                                 num_workers=0,
                                 batch_size=4,
                                 transform_type='tensor')

    net = get_resnet_model()
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
    net.load_state_dict(torch.load(checkpoint, map_location=device)['model_state_dict'])
    # Inputs are sent to `device`, so the model must live there too. eval() switches
    # layers such as batch-norm/dropout to inference behaviour, so scores don't
    # depend on which images share a batch.
    net.to(device)
    net.eval()

    prediction_index = descriptions.index(prediction_attr)
    protected_index = descriptions.index(protected_attr)
    index_column = descriptions.index('Index')

    bad_imgs = []
    with torch.no_grad():  # inference only: no need to build the autograd graph
        for inputs, labels in testloader:
            scores = net(inputs.to(device))[:, 0].cpu()
            for row, score in zip(labels, scores):
                label = row[prediction_index].item()
                protected = row[protected_index].item()
                output = score.item()

                # Keep protected-group images with a positive label but a negative prediction.
                if protected and label and output < threshold:
                    print('prediction', output)
                    # The index column may be stored as a float; filenames need an int.
                    bad_imgs.append(image_from_index(int(row[index_column].item())))

            if len(bad_imgs) >= n:
                break

    if not bad_imgs:
        print('no matching misclassified images found')
        return
    shown = bad_imgs[:n]
    imshow_group(shown, len(shown))

In [None]:
# Show test images of the 'Black' group labelled 'Young' whose model score is below 1.
show_bad_predictions(
    checkpoint='by_checkpoint.pt',
    prediction_attr='Young',
    protected_attr='Black',
    threshold=1,
)

# Data exploration

In [None]:
def load(n=100, folder='~/post_hoc_debiasing/data/celeba/img_align_celeba/'):
    """Load images 000001.jpg .. n (zero-padded to six digits) into one numpy array.

    Args:
        n: number of images to read, starting from file number 1.
        folder: directory holding the images; '~' is expanded.

    Returns:
        ``np.array`` of the per-image arrays. When all images share a size this
        is stacked along a new leading axis (for RGB jpgs: (n, H, W, 3)).
    """
    root = Path(folder).expanduser()
    data = []
    for i in range(1, n + 1):
        # Close each file as soon as its pixels have been copied into an array.
        with Image.open(root / f'{i:06d}.jpg') as img:
            data.append(np.array(img))
    return np.array(data)



def load_race(filepath='~/post_hoc_debiasing/celebrace/'):
    """Load the CelebRace race arrays.

    Returns a list ``[black, asian, white]`` of arrays, read with ``np.load``
    from ``black_full.npy``, ``asian_full.npy`` and ``white_full.npy`` under
    ``filepath`` ('~' is expanded).
    """
    names = ('black_full.npy', 'asian_full.npy', 'white_full.npy')
    return [np.load(os.path.expanduser(os.path.join(filepath, name))) for name in names]

def load_attrs(file='~/post_hoc_debiasing/data/celeba/list_attr_celeba.txt', max_n=-1):
    """Parse a CelebA attribute annotation file.

    Expected layout:

    - line 0 holds the image count (ignored);
    - line 1 holds the whitespace-separated attribute names;
    - every later line holds an image filename followed by one integer per
      attribute.

    Blank lines are skipped.

    Args:
        file: path to the annotation file; '~' is expanded.
        max_n: stop before the line with this 0-based line number, so at most
            ``max_n - 2`` images are read. Values below 2, including the
            default -1, read the whole file.

    Returns:
        ``(attrs, descriptions)``. ``attrs`` is an integer numpy array with one
        row per image and one column per attribute. ``descriptions`` lists the
        attribute names in column order. Also prints ``attrs.shape``.
    """
    attrs = []
    descriptions = []
    # Context manager so the file handle is closed even if parsing fails.
    with open(os.path.expanduser(file), "r") as f:
        for index, line in enumerate(f):
            if index == 0:
                continue  # image-count header; not needed
            if index == 1:
                descriptions = line.split()
                continue
            if index == max_n:
                break
            fields = line.split()
            if not fields:
                continue  # a trailing blank line would otherwise add an empty, ragged row
            # Drop the leading filename column.
            attrs.append([int(value) for value in fields[1:]])

    attrs = np.array(attrs)
    print(attrs.shape)
    return attrs, descriptions

In [None]:
# Load the raw images, attribute labels and race arrays used by the exploration cells.
N_IMAGES = 20000  # the full aligned CelebA set has 202599 images
data = load(n=N_IMAGES)
print(data.shape)
# Keep the file's attribute names under a separate name. Rebinding the global
# `descriptions` here would drop its 'White'/'Black'/'Asian'/'Index' entries and
# break show_bad_predictions if that cell is re-run afterwards.
attrs, attr_descriptions = load_attrs()
# The cells below index `attrs` via `descriptions.index(...)`. That is only valid
# if the file's header order matches the leading entries of `descriptions`.
assert descriptions[:len(attr_descriptions)] == attr_descriptions, attr_descriptions
races = load_race()


In [None]:
# Sanity-check the labels: show the first three images next to a few of their
# attribute values and their race-0 ('black_full.npy') value.
print(descriptions)
for img_idx in range(3):
    plt.imshow(data[img_idx])
    plt.show()
    for attr_name in ['Male', 'Attractive', 'Smiling', 'Pale_Skin']:
        column = descriptions.index(attr_name)
        print(attr_name, attrs[img_idx][column])
    print('black', races[0][img_idx])

In [None]:
# Spot-check one attribute: show its 9th-16th positive examples among the first 1000 images.
print(descriptions)
attr = 'Goatee'
attr_col = descriptions.index(attr)
inds = [i for i in range(1000) if attrs[i][attr_col] == 1]
page = [data[i] for i in inds[8:16]]
imshow_group(page, 8)

In [None]:
# Check the race labels: show loaded images whose race score is above 0.6.
# `race` indexes load_race's [black, asian, white]; `k` picks which page of 8 to show.
for race in range(1):
    # Search only images that are actually in `data`, so this stays in sync with
    # the n passed to load() instead of relying on a hard-coded 20000.
    n_searchable = min(len(data), len(races[race]))
    inds = [i for i in range(n_searchable) if races[race][i] > .6]
    print(len(inds))
    k = 0
    print(inds[8 * k:8 * (k + 1)])
    imshow_group([data[i] for i in inds[8 * k:8 * (k + 1)]], 8)


In [None]:
# For races 0 and 2 (black and white in load_race order), show the first 8 loaded
# images with race score > 0.8 labelled Attractive (+1), then 8 labelled -1.
attractive_col = descriptions.index('Attractive')
for race in [0, 2]:
    # Search the same range of loaded images for both labels so the two groups
    # are comparable, and never index past what load() returned.
    n_searchable = min(len(data), len(attrs), len(races[race]))
    for title, label in (('Attractive', 1), ('Unattractive', -1)):
        print(title)
        inds = [i for i in range(n_searchable)
                if races[race][i] > .8 and attrs[i][attractive_col] == label]
        imshow_group([data[i] for i in inds[0:8]], 8)
        plt.show()


In [None]:
# One row of up to 8 images (from the first 1000) per Male x Attractive combination,
# in the order: (male, attractive), (female, attractive), (male, not), (female, not).
male_col = descriptions.index('Male')
attractive_col = descriptions.index('Attractive')
for attractive_value in (1, -1):
    for male_value in (1, -1):
        inds = [i for i in range(1000)
                if attrs[i][male_col] == male_value and attrs[i][attractive_col] == attractive_value]
        imshow_group([data[i] for i in inds[0:8]], 8)

In [None]:
# Three pages of 8 images (from the first 1000) labelled Attractive = +1, then three with -1.
attractive_col = descriptions.index('Attractive')
for wanted in (1, -1):
    inds = [i for i in range(1000) if attrs[i][attractive_col] == wanted]
    for start in (0, 8, 16):
        imshow_group([data[i] for i in inds[start:start + 8]], 8)


In [None]:
# Fraction of images whose score for each race (black, asian, white in load_race
# order) exceeds 0.501. The comparison is vectorized instead of looping over
# ~200k entries in Python. Assumes each race array is 1-D (one value per image).
counts = [float(np.mean(np.asarray(scores) > .501)) for scores in races[:3]]
print(counts)