In [1]:
import math, os, random, json, pickle, sys, pdb
import string, shutil, time, argparse
import numpy as np
import itertools

from sklearn.metrics import average_precision_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from tqdm import tqdm as tqdm
from PIL import Image

import torch.nn.functional as F
import torch, torchvision
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Function

from torchvision.utils import save_image
from torch.utils.data import DataLoader

from data_loader import ImSituVerbGender
from adv_model import VerbClassificationAdv
from logger import Logger

from tqdm.notebook import tqdm, trange

In [2]:

def save_checkpoint(args, state, is_best, filename):
    """Serialize `state` to `filename`; mirror it as the best model when flagged.

    Args:
        args: object exposing `save_dir`, the directory receiving the best-model copy.
        state: object handed to `torch.save`.
        is_best: when truthy, the freshly written file is also copied to
            `<args.save_dir>/model_best.pth.tar`.
        filename: destination path of the checkpoint.
    """
    torch.save(state, filename)
    if not is_best:
        return
    best_path = os.path.join(args.save_dir, 'model_best.pth.tar')
    shutil.copyfile(filename, best_path)


def train(args, epoch, model, criterion, train_loader, optimizer, \
        train_logger, logging=True):
    """Train `model` for one epoch and report task / adversary metrics.

    For every processed batch the task loss (`criterion`) and the adversary
    cross-entropy are summed and one optimizer step is taken on that sum.
    Softmax task scores, targets and gender indicators are collected on the
    CPU so that macro F1, mean average precision (overall and per gender
    column) and adversary accuracy can be computed once the loop finishes.

    Args:
        args: configuration object; `batch_balanced` and `batch_size` are read here.
        epoch: epoch number, used in the progress-bar label and as the logging step.
        model: callable returning `(task_pred, adv_pred)` for a batch of images.
        criterion: loss applied to `task_pred` and the arg-max index of each `targets` row.
        train_loader: iterable yielding `(images, targets, genders, image_ids)`.
            `targets` and `genders` are indexed as 2-D matrices; column 0 of
            `genders` is treated as "man" and column 1 as "woman"
            (presumably one-hot rows -- TODO confirm against the dataset class).
        optimizer: optimizer stepped once per processed batch.
        train_logger: object exposing `scalar_summary(tag, value, step)`.
        logging: when False nothing is written to `train_logger`.

    Raises:
        RuntimeError: if no batch was processed (empty loader, or every batch
            skipped by the batch-balancing filter).

    Note:
        NOTE(review): `average_precision_score` can yield NaN when a class has
        no positive sample in the evaluated subset -- confirm whether that is
        acceptable for the per-gender numbers.
    """
    model.train()
    task_loss_logger = AverageMeter()
    adv_loss_logger = AverageMeter()

    # Per-batch CPU tensors; concatenated once after the loop instead of
    # re-concatenating the growing history on every step.
    task_pred_batches, task_truth_batches, gender_batches = [], [], []
    adv_preds, adv_truth = [], []

    t = tqdm(train_loader, desc = 'Train %d' % epoch)
    for batch_idx, (images, targets, genders, image_ids) in enumerate(t):

        if args.batch_balanced:
            # reshape(-1) always yields a 1-D index tensor, including for 0 or 1 matches.
            man_idx = genders[:, 0].nonzero().reshape(-1)
            woman_idx = genders[:, 1].nonzero().reshape(-1)
            selected_num = min(len(man_idx), len(woman_idx))

            # Skip batches that cannot supply half a batch of each gender.
            if selected_num < args.batch_size / 2:
                continue

            # Floor division: slice bounds must be integers (true division
            # returns a float in Python 3 and made the slices below fail).
            selected_num = int(args.batch_size) // 2
            selected_idx = torch.cat((man_idx[:selected_num], woman_idx[:selected_num]), 0)

            images = torch.index_select(images, 0, selected_idx)
            targets = torch.index_select(targets, 0, selected_idx)
            genders = torch.index_select(genders, 0, selected_idx)

        images = images.cuda()
        targets = targets.cuda()
        genders = genders.cuda()

        # Forward pass and the two losses.
        task_pred, adv_pred = model(images)

        task_loss = criterion(task_pred, targets.max(1, keepdim=False)[1])
        adv_loss = F.cross_entropy(adv_pred, genders.max(1, keepdim=False)[1], reduction='mean')

        task_loss_logger.update(task_loss.item())
        adv_loss_logger.update(adv_loss.item())

        # Bookkeeping for the epoch-level metrics (detached, on the CPU).
        adv_pred = np.argmax(F.softmax(adv_pred, dim=1).cpu().detach().numpy(), axis=1)
        adv_preds += adv_pred.tolist()
        adv_truth += genders.cpu().max(1, keepdim=False)[1].numpy().tolist()

        task_pred_batches.append(F.softmax(task_pred.detach(), dim=1).cpu())
        task_truth_batches.append(targets.cpu())
        gender_batches.append(genders.cpu())

        loss = task_loss + adv_loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if not task_pred_batches:
        raise RuntimeError('train(): no batch was processed in epoch %d' % epoch)

    task_preds = torch.cat(task_pred_batches, 0)
    task_truth = torch.cat(task_truth_batches, 0)
    total_genders = torch.cat(gender_batches, 0)

    task_f1_score = f1_score(task_truth.max(1)[1].numpy(), task_preds.max(1)[1].numpy(), average = 'macro')

    man_idx = total_genders[:, 0].nonzero().reshape(-1)
    woman_idx = total_genders[:, 1].nonzero().reshape(-1)
    preds_man = torch.index_select(task_preds, 0, man_idx)
    preds_woman = torch.index_select(task_preds, 0, woman_idx)
    targets_man = torch.index_select(task_truth, 0, man_idx)
    targets_woman = torch.index_select(task_truth, 0, woman_idx)
    meanAP = average_precision_score(task_truth.numpy(), task_preds.numpy(), average='macro')
    meanAP_man = average_precision_score(targets_man.numpy(), preds_man.numpy(), average='macro')
    meanAP_woman = average_precision_score(targets_woman.numpy(), preds_woman.numpy(), average='macro')
    adv_acc = accuracy_score(adv_truth, adv_preds)

    if logging:
        train_logger.scalar_summary('task loss', task_loss_logger.avg, epoch)
        train_logger.scalar_summary('adv loss', adv_loss_logger.avg, epoch)
        train_logger.scalar_summary('task_f1_score', task_f1_score, epoch)
        train_logger.scalar_summary('meanAP', meanAP, epoch)
        train_logger.scalar_summary('meanAP_man', meanAP_man, epoch)
        train_logger.scalar_summary('meanAP_woman', meanAP_woman, epoch)
        train_logger.scalar_summary('adv acc', adv_acc, epoch)

    print('man size: {} woman size: {}'.format(len(man_idx), len(woman_idx)))
    print('Train epoch  : {}, meanAP: {:.2f}, meanAP_man: {:.2f}, meanAP_woman: {:.2f}, adv acc: {:.2f}, '.format(
        epoch, meanAP*100, meanAP_man*100, meanAP_woman*100, adv_acc*100))

def test_balanced(args, epoch, model, criterion, val_loader, val_logger, print_every=10000, logging=True):
    """Evaluate `model` on the gender-balanced validation loader.

    Runs a forward-only pass over `val_loader`, then computes macro F1, mean
    average precision (overall and per gender column) and adversary accuracy.
    Only the adversary loss and accuracy are written to `val_logger`.

    Args:
        args: configuration object (not read inside this function).
        epoch: epoch number, used in the progress-bar label and as the logging step.
        model: callable returning `(task_pred, adv_pred)` for a batch of images.
        criterion: loss applied to `task_pred` and the arg-max index of each `targets` row.
        val_loader: iterable yielding `(images, targets, genders, image_ids)`;
            column 0 of `genders` is treated as "man" and column 1 as "woman".
        val_logger: object exposing `scalar_summary(tag, value, step)`.
        print_every: accepted for interface compatibility; not used.
        logging: when False nothing is written to `val_logger`.

    Returns:
        The macro-averaged F1 score of the task predictions.

    Raises:
        RuntimeError: if `val_loader` yields no batch.

    Note:
        NOTE(review): `average_precision_score` can yield NaN when a class has
        no positive sample in the evaluated subset -- confirm whether that is
        acceptable for the per-gender numbers.
    """
    model.eval()
    task_loss_logger = AverageMeter()
    adv_loss_logger = AverageMeter()

    # Per-batch CPU tensors; concatenated once after the loop.
    task_pred_batches, task_truth_batches, gender_batches = [], [], []
    adv_preds, adv_truth = [], []

    t = tqdm(val_loader, desc = 'Val balanced %d' % epoch)
    # Evaluation never back-propagates, so skip building autograd graphs.
    with torch.no_grad():
        for batch_idx, (images, targets, genders, image_ids) in enumerate(t):
            images = images.cuda()
            targets = targets.cuda()
            genders = genders.cuda()

            # Forward pass only.
            task_pred, adv_pred = model(images)

            task_loss = criterion(task_pred, targets.max(1, keepdim=False)[1])
            adv_loss = F.cross_entropy(adv_pred, genders.max(1, keepdim=False)[1], reduction='mean')

            task_loss_logger.update(task_loss.item())
            adv_loss_logger.update(adv_loss.item())

            adv_pred = np.argmax(F.softmax(adv_pred, dim=1).cpu().detach().numpy(), axis=1)
            adv_preds += adv_pred.tolist()
            adv_truth += genders.cpu().max(1, keepdim=False)[1].numpy().tolist()

            task_pred_batches.append(F.softmax(task_pred, dim=1).detach().cpu())
            task_truth_batches.append(targets.cpu())
            gender_batches.append(genders.cpu())

    if not task_pred_batches:
        raise RuntimeError('test_balanced(): the validation loader produced no batch')

    task_preds = torch.cat(task_pred_batches, 0)
    task_truth = torch.cat(task_truth_batches, 0)
    total_genders = torch.cat(gender_batches, 0)

    task_f1_score = f1_score(task_truth.max(1)[1].numpy(), task_preds.max(1)[1].numpy(), average = 'macro')

    # reshape(-1) keeps the index tensors 1-D even when only one row matches.
    man_idx = total_genders[:, 0].nonzero().reshape(-1)
    woman_idx = total_genders[:, 1].nonzero().reshape(-1)
    preds_man = torch.index_select(task_preds, 0, man_idx)
    preds_woman = torch.index_select(task_preds, 0, woman_idx)
    targets_man = torch.index_select(task_truth, 0, man_idx)
    targets_woman = torch.index_select(task_truth, 0, woman_idx)
    meanAP = average_precision_score(task_truth.numpy(), task_preds.numpy(), average='macro')
    meanAP_man = average_precision_score(targets_man.numpy(), preds_man.numpy(), average='macro')
    meanAP_woman = average_precision_score(targets_woman.numpy(), preds_woman.numpy(), average='macro')
    adv_acc = accuracy_score(adv_truth, adv_preds)

    if logging:
        val_logger.scalar_summary('adv loss balanced', adv_loss_logger.avg, epoch)
        val_logger.scalar_summary('adv acc balanced', adv_acc, epoch)

    print('man size: {} woman size: {}'.format(len(man_idx), len(woman_idx)))
    print('Test epoch(f): {}, meanAP: {:.2f}, meanAP_man: {:.2f}, meanAP_woman: {:.2f}, adv acc: {:.2f}, '.format(
        epoch, meanAP*100, meanAP_man*100, meanAP_woman*100, adv_acc*100))

    return task_f1_score

def test(args, epoch, model, criterion, val_loader, val_logger, print_every=10000, logging=True):
    """Evaluate `model` on a validation loader and log the full metric set.

    Runs a forward-only pass over `val_loader`, then computes macro F1, mean
    average precision (overall and per gender column) and adversary accuracy.
    Both average losses and all metrics are written to `val_logger`.

    Args:
        args: configuration object (not read inside this function).
        epoch: epoch number, used in the progress-bar label and as the logging step.
        model: callable returning `(task_pred, adv_pred)` for a batch of images.
        criterion: loss applied to `task_pred` and the arg-max index of each `targets` row.
        val_loader: iterable yielding `(images, targets, genders, image_ids)`;
            column 0 of `genders` is treated as "man" and column 1 as "woman".
        val_logger: object exposing `scalar_summary(tag, value, step)`.
        print_every: accepted for interface compatibility; not used.
        logging: when False nothing is written to `val_logger`.

    Returns:
        The macro-averaged F1 score of the task predictions.

    Raises:
        RuntimeError: if `val_loader` yields no batch.

    Note:
        NOTE(review): `average_precision_score` can yield NaN when a class has
        no positive sample in the evaluated subset -- confirm whether that is
        acceptable for the per-gender numbers.
    """
    model.eval()
    task_loss_logger = AverageMeter()
    adv_loss_logger = AverageMeter()

    # Per-batch CPU tensors; concatenated once after the loop.
    task_pred_batches, task_truth_batches, gender_batches = [], [], []
    adv_preds, adv_truth = [], []

    t = tqdm(val_loader, desc = 'Val %d' % epoch)
    # Evaluation never back-propagates, so skip building autograd graphs.
    with torch.no_grad():
        for batch_idx, (images, targets, genders, image_ids) in enumerate(t):
            images = images.cuda()
            targets = targets.cuda()
            genders = genders.cuda()

            # Forward pass only.
            task_pred, adv_pred = model(images)

            task_loss = criterion(task_pred, targets.max(1, keepdim=False)[1])
            adv_loss = F.cross_entropy(adv_pred, genders.max(1, keepdim=False)[1], reduction='mean')

            task_loss_logger.update(task_loss.item())
            adv_loss_logger.update(adv_loss.item())

            adv_pred = np.argmax(F.softmax(adv_pred, dim=1).cpu().detach().numpy(), axis=1)
            adv_preds += adv_pred.tolist()
            adv_truth += genders.cpu().max(1, keepdim=False)[1].numpy().tolist()

            task_pred_batches.append(F.softmax(task_pred, dim=1).detach().cpu())
            task_truth_batches.append(targets.cpu())
            gender_batches.append(genders.cpu())

    if not task_pred_batches:
        raise RuntimeError('test(): the validation loader produced no batch')

    task_preds = torch.cat(task_pred_batches, 0)
    task_truth = torch.cat(task_truth_batches, 0)
    total_genders = torch.cat(gender_batches, 0)

    task_f1_score = f1_score(task_truth.max(1)[1].numpy(), task_preds.max(1)[1].numpy(), average = 'macro')

    # reshape(-1) keeps the index tensors 1-D even when only one row matches.
    man_idx = total_genders[:, 0].nonzero().reshape(-1)
    woman_idx = total_genders[:, 1].nonzero().reshape(-1)
    preds_man = torch.index_select(task_preds, 0, man_idx)
    preds_woman = torch.index_select(task_preds, 0, woman_idx)
    targets_man = torch.index_select(task_truth, 0, man_idx)
    targets_woman = torch.index_select(task_truth, 0, woman_idx)
    meanAP = average_precision_score(task_truth.numpy(), task_preds.numpy(), average='macro')
    meanAP_man = average_precision_score(targets_man.numpy(), preds_man.numpy(), average='macro')
    meanAP_woman = average_precision_score(targets_woman.numpy(), preds_woman.numpy(), average='macro')
    adv_acc = accuracy_score(adv_truth, adv_preds)

    if logging:
        val_logger.scalar_summary('task loss', task_loss_logger.avg, epoch)
        val_logger.scalar_summary('adv loss', adv_loss_logger.avg, epoch)
        val_logger.scalar_summary('task_f1_score', task_f1_score, epoch)
        val_logger.scalar_summary('meanAP', meanAP, epoch)
        val_logger.scalar_summary('meanAP_man', meanAP_man, epoch)
        val_logger.scalar_summary('meanAP_woman', meanAP_woman, epoch)
        val_logger.scalar_summary('adv acc', adv_acc, epoch)

    print('man size: {} woman size: {}'.format(len(man_idx), len(woman_idx)))
    print('Test epoch   : {}, meanAP: {:.2f}, meanAP_man: {:.2f}, meanAP_woman: {:.2f}, adv acc: {:.2f}, '.format(
        epoch, meanAP*100, meanAP_man*100, meanAP_woman*100, adv_acc*100))

    return task_f1_score

class AverageMeter(object):
    """Running-mean tracker exposing `val`, `sum`, `count` and `avg`."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, the weighted sum, the count and the mean back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the newest value, weighted `n` times, and refresh the mean."""
        weighted = val * n
        self.val = val
        self.count += n
        self.sum += weighted
        self.avg = self.sum / self.count


In [3]:
class Arg:
    """Notebook stand-in for a command-line argument namespace.

    Defaults are class attributes; a run overrides them by assigning instance
    attributes. `repr()` / `str()` list every public attribute (defaults and
    overrides alike) so that writing `str(args)` to a file records the actual
    configuration rather than an object address.

    NOTE(review): several flags below are not read in this notebook; they are
    presumably consumed by the dataset / model classes that receive `args` --
    confirm against `data_loader` and `adv_model`.
    """
    # Experiment identity and output location.
    exp_id =  "exp_id"
    log_dir = "log_dir"
    ratio = "0"
    # Number of verb classes handed to the model.
    num_verb = 211
    # Dataset locations.
    annotation_dir = "../data"
    image_dir = "../data/of500_images_resized"
    # Data-balancing switches.
    balanced = False
    gender_balanced = False
    batch_balanced = False
    no_image = False
    # Adversary configuration.
    adv_on = False
    layer = None
    adv_conv = False
    no_avgpool = False
    adv_capacity = 0
    adv_lambda = 1e-0
    dropout = 1e-1
    # Image perturbation switches.
    blackout = False
    blackout_box = False
    blackout_face = False
    blur = False
    grayscale = False
    edges = False
    # Optimisation and schedule.
    resume = False
    learning_rate = 1e-4
    finetune = False
    num_epochs = 50
    batch_size = 64
    crop_size = 224
    image_size = 256
    start_epoch = 1
    seed = 1

    def __repr__(self):
        """Return `Arg(name=value, ...)` over all public, non-callable attributes."""
        # dir() merges class-level defaults with instance-level overrides, sorted by name.
        names = [name for name in dir(self)
                 if not name.startswith('_') and not callable(getattr(self, name))]
        body = ', '.join('%s=%r' % (name, getattr(self, name)) for name in names)
        return '%s(%s)' % (type(self).__name__, body)

In [4]:
# Instantiate the configuration; later cells override defaults by setting instance attributes.
args = Arg()

In [5]:
# Per-run overrides of the defaults declared on Arg.
args.exp_id = "Test_ADV-Off_Layers-conv4"
args.adv_on = False
args.layer = "conv4"
args.no_avgpool = True
args.adv_capacity = 300
args.adv_lambda = 1
args.learning_rate = 0.00005
args.num_epochs = 100
args.batch_size = 128

# Apply the configured seed to every RNG in use, so that shuffling,
# augmentation and weight initialisation start from the same state on each
# Restart & Run All. Previously `args.seed` was defined but never used.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

In [6]:
# Checkpoints go to ./models/<layer>_<adv_capacity>_<adv_lambda>_<dropout>_<exp_id>.
args.save_dir = os.path.join(
    './models',
    '_'.join([args.layer, str(args.adv_capacity), str(args.adv_lambda), str(args.dropout), args.exp_id]))
if not os.path.exists(args.save_dir):
    os.makedirs(args.save_dir)

In [7]:
# Logs go to ./logs/<layer>_<adv_capacity>_<adv_lambda>_<dropout>_<exp_id>/{train,val}.
args.log_dir = os.path.join(
    './logs',
    '_'.join([args.layer, str(args.adv_capacity), str(args.adv_lambda), str(args.dropout), args.exp_id]))

train_log_dir = os.path.join(args.log_dir, 'train')
if not os.path.exists(train_log_dir):
    os.makedirs(train_log_dir)
train_logger = Logger(train_log_dir)

val_log_dir = os.path.join(args.log_dir, 'val')
if not os.path.exists(val_log_dir):
    os.makedirs(val_log_dir)
val_logger = Logger(val_log_dir)




In [8]:
# Append the textual form of the run arguments to <log_dir>/arguments.txt.
arguments_path = os.path.join(args.log_dir, "arguments.txt")
with open(arguments_path, "a") as arguments_file:
    arguments_file.write(str(args)+'\n')

In [9]:
# Per-channel normalization, used as the last step of both image pipelines below.
# NOTE(review): these values match the commonly used ImageNet statistics -- confirm they suit the backbone.
normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])

In [10]:
# Image preprocessing: both pipelines resize, crop, convert to a tensor and
# normalize; only the training pipeline uses a random crop and a random flip.
train_steps = [
    transforms.Resize(args.image_size),
    transforms.RandomCrop(args.crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]
val_steps = [
    transforms.Resize(args.image_size),
    transforms.CenterCrop(args.crop_size),
    transforms.ToTensor(),
    normalize,
]
train_transform = transforms.Compose(train_steps)
val_transform = transforms.Compose(val_steps)

In [11]:
# Datasets: the training split, the validation split, and a second view of
# the validation split built with the gender-balanced flag switched on.
train_data = ImSituVerbGender(args,
                              annotation_dir = args.annotation_dir,
                              image_dir = args.image_dir,
                              split = 'train',
                              transform = train_transform)
val_data = ImSituVerbGender(args,
                            annotation_dir = args.annotation_dir,
                            image_dir = args.image_dir,
                            split = 'val',
                            transform = val_transform)

# The flag lives on the shared `args` object, so it is reset in `finally`:
# if the constructor raises, a re-run of this cell must not build the train /
# val datasets above with gender balancing still switched on.
args.gender_balanced = True
try:
    val_data_gender_balanced = ImSituVerbGender(args,
                                                annotation_dir = args.annotation_dir,
                                                image_dir = args.image_dir,
                                                split = 'val',
                                                transform = val_transform)
finally:
    args.gender_balanced = False

ImSituVerbGender dataloader
loading train annotations..........
dataset size: 24301
man size : 14199 and woman size: 10102
ImSituVerbGender dataloader
loading val annotations..........
dataset size: 7730
man size : 4457 and woman size: 3273
ImSituVerbGender dataloader
loading val annotations..........
dataset size: 7730
man size : 3000 and woman size: 3000


In [12]:
# Data loaders. With batch balancing the training loader draws batches 2.5x
# the nominal size; train() then picks half a batch per gender from each.
train_batch_size = int(2.5 * args.batch_size) if args.batch_balanced else int(args.batch_size)

loader_options = dict(num_workers = 0, pin_memory = True)
train_loader = DataLoader(train_data, batch_size = train_batch_size,
                          shuffle = True, **loader_options)
val_loader = DataLoader(val_data, batch_size = args.batch_size,
                        shuffle = False, **loader_options)
val_loader_gender_balanced = DataLoader(val_data_gender_balanced, batch_size = args.batch_size,
                                        shuffle = False, **loader_options)

In [13]:
# Build the model
# Construct the network from the run configuration, then move it to the GPU.
model = VerbClassificationAdv(args, args.num_verb, args.adv_capacity, args.dropout, args.adv_lambda)
model = model.cuda()

Build a VerbClassification Model[conv4]
Load weights from Resnet18/50 done


In [14]:
# Task loss: cross-entropy weighted per verb class with the weights supplied
# by the training set, averaged over the batch and moved to the GPU.
verb_weights = torch.FloatTensor(train_data.getVerbWeights())
criterion = nn.CrossEntropyLoss(weight=verb_weights, reduction='mean')
criterion = criterion.cuda()

In [15]:
# Optimizer: Adam over the parameters that still require gradients.
def trainable_params():
    """Yield every parameter of the global `model` whose `requires_grad` is set."""
    yield from (param for param in model.parameters() if param.requires_grad)

num_trainable_params = sum(param.numel() for param in trainable_params())
print('num_trainable_params:', num_trainable_params)
optimizer = torch.optim.Adam(trainable_params(), lr = args.learning_rate, weight_decay = 1e-5)

num_trainable_params: 60826841


In [16]:
# Optionally restore the epoch counter, best score and matching weights from
# <save_dir>/checkpoint.pth.tar.
best_performance = 0
if args.resume:
    checkpoint_path = os.path.join(args.save_dir, 'checkpoint.pth.tar')
    if not os.path.isfile(checkpoint_path):
        print("=> no checkpoint found at '{}'".format(args.save_dir))
    else:
        print("=> loading checkpoint '{}'".format(args.save_dir))
        checkpoint = torch.load(checkpoint_path)
        args.start_epoch = checkpoint['epoch']
        best_performance = checkpoint['best_performance']
        # Partial load: only entries whose key exists in the current model are taken.
        model_dict = model.state_dict()
        pretrained_dict = {}
        for key, value in checkpoint['state_dict'].items():
            if key in model_dict:
                pretrained_dict[key] = value
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

In [17]:
# Baseline: score the model on both validation loaders as epoch 0, without logging.
print('before training, evaluate the model')
test_balanced(args, epoch=0, model=model, criterion=criterion,
              val_loader=val_loader_gender_balanced, val_logger=val_logger, logging=False)
test(args, epoch=0, model=model, criterion=criterion,
     val_loader=val_loader, val_logger=val_logger, logging=False)

before training, evaluate the model


HBox(children=(FloatProgress(value=0.0, description='Val balanced 0', max=47.0, style=ProgressStyle(descriptio…




  recall = tps / tps[-1]


man size: 3000 woman size: 3000
Test epoch(f): 0, meanAP: 0.73, meanAP_man: nan, meanAP_woman: nan, adv acc: 50.00, 


  recall = tps / tps[-1]


HBox(children=(FloatProgress(value=0.0, description='Val 0', max=61.0, style=ProgressStyle(description_width='…


man size: 4457 woman size: 3273
Test epoch   : 0, meanAP: 0.69, meanAP_man: 0.73, meanAP_woman: nan, adv acc: 57.66, 


  recall = tps / tps[-1]


0.001921062242070278

In [None]:
# Main loop: train, evaluate on both validation loaders, checkpoint every epoch.
for epoch in range(args.start_epoch, args.num_epochs + 1):
    train(args, epoch, model, criterion, train_loader, optimizer, train_logger, logging=True)
    test_balanced(args, epoch, model, criterion, val_loader_gender_balanced, val_logger, logging=True)
    current_performance = test(args, epoch, model, criterion, val_loader, val_logger, logging=True)

    # Track the best score returned by test() so far.
    is_best = current_performance > best_performance
    if not best_performance > current_performance:
        best_performance = current_performance

    # 'epoch' stores the next epoch to run, which is what the resume cell reads back.
    model_state = dict(
        epoch=epoch + 1,
        state_dict=model.state_dict(),
        best_performance=best_performance)
    latest_path = os.path.join(args.save_dir, 'checkpoint.pth.tar')
    save_checkpoint(args, model_state, is_best, latest_path)

    # After the last epoch, keep an extra copy named after the epoch budget.
    if epoch == args.num_epochs:
        final_path = os.path.join(args.save_dir, 'checkpoint_%s.pth.tar' % str(args.num_epochs))
        torch.save(model_state, final_path)