In [5]:
import numpy as np
import pandas as pd

import sys
sys.path.append("/home/jinsu/workstation/project/debiasing-multi-modal")


df_zs_embeddings = pd.read_json("/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/waterbirds/RN50/embedding_prediction.json")
df_meta = pd.read_csv("/home/jinsu/workstation/project/debiasing-multi-modal/data/waterbirds/waterbird_complete95_forest2water2/metadata.csv")

In [6]:
import torch
import torch.nn as nn
from __future__ import print_function

import sys
import argparse
import time
import tqdm
import math

import torch
import torch.backends.cudnn as cudnn

from util import AverageMeter
from util import adjust_learning_rate, warmup_learning_rate, accuracy, accuracy_zs
from util import set_optimizer
# from networks.resnet_big import SupConResNet, LinearClassifier

try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass

from resnet import resnet50
from data.waterbirds_embeddings import WaterbirdsEmbeddings, load_waterbirds_embeddings
from data.celeba_embeddings import CelebaEmbeddings, load_celeba_embeddings
model_dict = {'resnet50': [resnet50, 1024]}
new_order_for_print = [
    'weighted_mean_acc',
    'worst_acc',
    'acc_0_0',
    'acc_0_1',
    'acc_1_0',
    'acc_1_1',
    'mean_acc'
]
from functools import partial

class LinearClassifier(nn.Module):
    def __init__(self, name='resnet50', num_classes=2):
        super(LinearClassifier, self).__init__()
        _, feat_dim = model_dict[name]
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, features):
        return self.fc(features)

def parse_option():
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=16,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.1,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.2,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='waterbirds',
                        choices=['celeba', 'waterbirds'], help='dataset')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')

    parser.add_argument('--embedding_dir', type=str,
                        help='extracted embedding')
    parser.add_argument('--target', type=str, default="class", choices=["class", "group", "spurious"])
    parser.add_argument('--data_dir', type=str,
                        help='metadata.csv')

    opt = parser.parse_args(args=[])

    # set the path according to the environment

    iterations = opt.lr_decay_epochs.split(',')
    opt.lr_decay_epochs = list([])
    for it in iterations:
        opt.lr_decay_epochs.append(int(it))

    opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.\
        format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
               opt.batch_size)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warm-up for large-batch training,
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate
            
    if opt.dataset == 'celeba':
        opt.n_cls = 2
    elif opt.dataset == 'waterbirds':
        opt.n_cls = 2
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))

    return opt


def set_model(opt):
    # model = SupConResNet(name=opt.model)
    criterion = torch.nn.CrossEntropyLoss()

    classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)

    # ckpt = torch.load(opt.ckpt, map_location='cpu')
    # state_dict = ckpt['model']

    if torch.cuda.is_available():
        # if torch.cuda.device_count() > 1:
        #     model.encoder = torch.nn.DataParallel(model.encoder)
        # else:
            # new_state_dict = {}
            # for k, v in state_dict.items():
            #     k = k.replace("module.", "")
            #     new_state_dict[k] = v
            # state_dict = new_state_dict
        
        # model = model.cuda()
        classifier = classifier.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

        # model.load_state_dict(state_dict)

    return classifier, criterion # model, 

def update_dict(acc_groups, y, g, logits):
    preds = torch.argmax(logits, axis=1)
    correct_batch = (preds == y)
    g = g.cpu()
    for g_val in np.unique(g):
        mask = g == g_val
        n = mask.sum().item()
        corr = correct_batch[mask].sum().item()
        acc_groups[g_val].update(corr / n, n) # AverageMeter Updater. 

def get_results(acc_groups, get_yp_func, ): # Input 중 acc_groups : AverageMeter()를 담고있는 dict. get_yp_func : 미리 partial을 이용해 n_groups를 저장해놓음. 
    groups = acc_groups.keys() # 0, 1, 2, 3
    results = {
            f"acc_{get_yp_func(g)[0]}_{get_yp_func(g)[1]}": acc_groups[g].avg
            for g in groups
    }
    all_correct = sum([acc_groups[g].sum for g in groups])
    all_total = sum([acc_groups[g].count for g in groups])
    results.update({"mean_acc" : all_correct / all_total})
    results.update({"worst_acc" : min(results.values())})
    
    return results

def get_y_p(g, n_places):
    y = g // n_places
    p = g % n_places
    return y, p


def train(train_loader, classifier, criterion, optimizer, epoch, get_yp_func, target, label='Train'): # model,
    """one epoch training"""
    # model.eval()
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()
    acc_groups = {g_idx : AverageMeter() for g_idx in range(train_loader.dataset.n_groups)}

    end = time.time()
    for idx, data in enumerate(train_loader):
        if opt.dataset == 'waterbirds':
            embeddings, all_labels, _ = data
            labels = all_labels[target] # (y, group, spurious)
            groups = all_labels['group']
        else:
            embeddings, all_labels = data
            labels = all_labels[target] # (y, group, ypurious)
            groups = all_labels['group']
        
        data_time.update(time.time() - end)

        embeddings = embeddings.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # compute loss
        # with torch.no_grad():
        #     features = model.encoder(embeddings)
        output = classifier(embeddings.detach())
        loss = criterion(output, labels)

        # update metric
        losses.update(loss.item(), bsz)
        acc1 = accuracy(output, labels, bsz)
        acc.update(acc1, bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        
        # Update acc dict
        update_dict(acc_groups, labels, groups, output)
        
        # print info
        if (idx + 1) % opt.print_freq == 0:
            print(f'{label}: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {acc.val:.3f} ({acc.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, acc=acc))
            sys.stdout.flush()
            
    group_acc = get_results(acc_groups, get_yp_func) # NOTE declared in [def main]
    group_acc = {key: group_acc[key] for key in new_order_for_print[1:]}
    print(f"{label}:", str(group_acc))
    
    return losses.avg, acc.avg, group_acc


def validate(val_loader, classifier, criterion, get_yp_func, train_group_ratio, target, label='Test', watch=True):
    """validation"""
    
    # model.eval()
    classifier.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()
    acc_groups = {g_idx : AverageMeter() for g_idx in range(val_loader.dataset.n_groups)}

    with torch.no_grad():
        end = time.time()
        for idx, data in enumerate(val_loader):
            if opt.dataset == 'waterbirds':
                embeddings, all_labels, _ = data
                labels = all_labels[target] # (y, group, spurious)
                groups = all_labels['group']
            elif opt.dataset == 'celeba':
                embeddings, all_labels = data
                labels = all_labels[target] # (y, group, spurious)
                groups = all_labels['group']
            
            embeddings = embeddings.float().cuda()
            labels = labels.cuda()
            bsz = labels.shape[0]

            # forward
            output = classifier(embeddings)
            loss = criterion(output, labels)

            # update metric
            losses.update(loss.item(), bsz)
            acc1 = accuracy(output, labels, bsz)
            acc.update(acc1, bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            
            # Update acc dict
            update_dict(acc_groups, labels, groups, output)
        
            if (idx+1) % opt.print_freq == 0:
                if watch:
                    print(f'{label}: [{0}/{1}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Acc@1 {acc.val:.3f} ({acc.avg:.3f})'.format(
                        idx, len(val_loader), batch_time=batch_time,
                        loss=losses, acc=acc))
                    
                    
    group_acc = get_results(acc_groups, get_yp_func)
    
    #NOTE add Weighted mean acc.
    groups = range(val_loader.dataset.n_groups) # 0, 1, 2, 3
    group_acc_indiv =  [group_acc[f"acc_{get_yp_func(g)[0]}_{get_yp_func(g)[1]}"] for g in groups]
    weighted_mean_acc = (np.array(group_acc_indiv) * np.array(train_group_ratio)).sum() # Weighted Sum \
    
    group_acc["weighted_mean_acc"] = weighted_mean_acc
    group_acc = {key: group_acc[key] for key in new_order_for_print}
    
    if watch:
        print(f"{label}:", str(group_acc))
        # print(' * Acc@1 {acc.avg:.3f}'.format(acc=acc))
    return losses.avg, acc.avg, group_acc


def main(opt):
    best_acc = 0
    best_epoch = 0
    # opt = parse_option()
    
    if opt.dataset == 'waterbirds':
        # build dataset example.
        trainset = WaterbirdsEmbeddings(opt.data_dir, 'train', opt.embedding_dir, None, None)
        # build data loader
        print("Load Data Loader (train, validation, test)")
        train_loader, val_loader, test_loader = load_waterbirds_embeddings(opt.data_dir, opt.embedding_dir, opt.batch_size, opt.batch_size)
    elif opt.dataset == 'celeba':
        # build dataset example.
        trainset = CelebaEmbeddings(opt.data_dir, 'train', opt.embedding_dir, None)
        # build data loader
        print("Load Data Loader (train, validation, test)")
        train_loader, val_loader, test_loader = load_celeba_embeddings(opt.data_dir, opt.embedding_dir, opt.batch_size, opt.batch_size)
        
    get_yp_func = partial(get_y_p, n_places=trainset.n_places)
    train_group_ratio = trainset.group_ratio
    
    # build model and criterion
    print("Set Linear Classifier")
    classifier, criterion = set_model(opt) # model, 

    # build optimizer
    print("Set Optimizer")
    optimizer = set_optimizer(opt, classifier)

    # training routine
    train_losses = []
    train_accs = []
    train_group_accs = []
    val_losses = []
    val_accs = []
    val_group_accs = []
    
    test_losses_y = [] # NOTE: Don't peek ! 
    test_accs_y = [] # NOTE: Don't peek ! 
    test_group_accs_y = [] # NOTE: Don't peek ! 
    test_losses_spurious = [] # NOTE: Don't peek ! 
    test_accs_spurious = [] # NOTE: Don't peek ! 
    test_group_accs_spurious = [] # NOTE: Don't peek ! 
    
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)


        print(f'--- Epoch {epoch} ---')
        # train for one epoch
        loss, acc, group_acc = train(train_loader, classifier, criterion,
                          optimizer, epoch, get_yp_func, target='y', label='Train(y)')
        
        train_losses.append(loss)
        train_accs.append(acc)
        train_group_accs.append(group_acc)
        # eval for one epoch
        val_loss, val_acc, val_group_acc = validate(val_loader, classifier, criterion, get_yp_func, train_group_ratio, target='y', label='Val(y)')
        if val_group_acc['worst_acc'] > best_acc:
            best_acc = val_group_acc['worst_acc']
            best_epoch = epoch
        
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_group_accs.append(val_group_acc)
            
        test_loss_y, test_acc_y, test_group_acc_y = validate(test_loader, classifier, criterion, get_yp_func, train_group_ratio, target='y', label='Test(y)', watch=True)
        test_losses_y.append(test_loss_y)
        test_accs_y.append(test_acc_y)
        test_group_accs_y.append(test_group_acc_y)
        
        test_loss_spurious, test_acc_spurious, test_group_acc_spurious= validate(test_loader, classifier, criterion, get_yp_func, train_group_ratio, target='spurious', label='Test(spurious)', watch=True)
        test_losses_spurious.append(test_loss_spurious)
        test_accs_spurious.append(test_acc_spurious)
        test_group_accs_spurious.append(test_group_acc_spurious)
    
    
    print('==================================================================')
    print('best epoch : {}'.format(best_epoch))
    print('best (worst-)validation accuracy: {} '.format(val_group_accs[best_epoch-1]))
    
    print('best test accuracy (class): {}'.format(test_group_accs_y[best_epoch-1]))
    print('best test accuracy (spurious): {}'.format(test_group_accs_spurious[best_epoch-1]))

In [7]:
import argparse

parser = argparse.ArgumentParser('argument for training')

parser.add_argument('--print_freq', type=int, default=20,
                    help='print frequency')
parser.add_argument('--save_freq', type=int, default=50,
                    help='save frequency')
parser.add_argument('--batch_size', type=int, default=512,
                    help='batch_size')
parser.add_argument('--num_workers', type=int, default=16,
                    help='num of workers to use')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of training epochs')

# optimization
parser.add_argument('--learning_rate', type=float, default=5,
                    help='learning rate')
parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90',
                    help='where to decay lr, can be a list')
parser.add_argument('--lr_decay_rate', type=float, default=0.2,
                    help='decay rate for learning rate')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='momentum')

# model dataset
parser.add_argument('--model', type=str, default='resnet50')
parser.add_argument('--dataset', type=str, default='waterbirds',
                    choices=['celeba', 'waterbirds'], help='dataset')

# other setting
parser.add_argument('--cosine', action='store_true',
                    help='using cosine annealing')
parser.add_argument('--warm', action='store_true',
                    help='warm-up for large batch training')

parser.add_argument('--embedding_dir', type=str, 
                    help='extracted embedding')
parser.add_argument('--target', type=str, default="y", choices=["y", "group", "spurious"]) # Label for linear proving
parser.add_argument('--data_dir', type=str,
                    help='folder, in which [metadata.csv] exists')

opt = parser.parse_args(args=[])   

iterations = opt.lr_decay_epochs.split(',')
opt.lr_decay_epochs = list([])
for it in iterations:
    opt.lr_decay_epochs.append(int(it))

opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.\
    format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
            opt.batch_size)

if opt.cosine:
    opt.model_name = '{}_cosine'.format(opt.model_name)

# warm-up for large-batch training,
if opt.warm:
    opt.model_name = '{}_warm'.format(opt.model_name)
    opt.warmup_from = 0.01
    opt.warm_epochs = 10
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
    else:
        opt.warmup_to = opt.learning_rate
        
if opt.dataset == 'celeba':
    opt.n_cls = 2
elif opt.dataset == 'waterbirds':
    opt.n_cls = 2
else:
    raise ValueError('dataset not supported: {}'.format(opt.dataset))

# Linear Proving  
- (Epoch default : 100)
- Test 성능 모니터링하면 안 됨.

## Waterbirds

In [10]:
opt.dataset = 'waterbird'
opt.target = 'class'
opt.embedding_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/waterbirds/RN50/embedding_prediction.json"
opt.data_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/waterbirds/waterbird_complete95_forest2water2"
if __name__ == '__main__':    
    main(opt)
    

Load Data Loader (train, validation, test)
Set Linear Classifier
Set Optimizer
--- Epoch 1 ---
Train(y): {'worst_acc': 0.30357142857142855, 'acc_0_0': 0.8776443682104059, 'acc_0_1': 0.5054347826086957, 'acc_1_0': 0.30357142857142855, 'acc_1_1': 0.5534531693472091, 'mean_acc': 0.7851929092805006}
Val(y): {'weighted_mean_acc': 0.9556073321666032, 'worst_acc': 0.18796992481203006, 'acc_0_0': 0.9935760171306209, 'acc_0_1': 0.5579399141630901, 'acc_1_0': 0.18796992481203006, 'acc_1_1': 0.9398496240601504, 'mean_acc': 0.7289407839866555}
 * Acc@1 0.729
Test(y): {'weighted_mean_acc': 0.9490767917756933, 'worst_acc': 0.19003115264797507, 'acc_0_0': 0.9973392461197339, 'acc_0_1': 0.5516629711751663, 'acc_1_0': 0.19003115264797507, 'acc_1_1': 0.8987538940809969, 'mean_acc': 0.7235070762858129}
 * Acc@1 0.724
Test(spurious): {'weighted_mean_acc': 0.9523519817025319, 'worst_acc': 0.4483370288248337, 'acc_0_0': 0.9973392461197339, 'acc_0_1': 0.4483370288248337, 'acc_1_0': 0.8099688473520249, 'acc_1

## CelebA

In [4]:
# df_zs_embeddings = pd.read_json("/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/celeba/RN50/embedding_prediction.json")
opt.dataset = 'celeba'
opt.target = 'class'
opt.embedding_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/celeba/RN50/embedding_prediction.json"
opt.data_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/celeba"
if __name__ == '__main__':    
    main(opt)
    

# Zero-Shot Evaluation

In [8]:
def update_dict_zs(acc_groups, y, g, preds):
    # preds = torch.argmax(logits, axis=1)
    correct_batch = (preds == y)
    g = g.cpu()
    for g_val in np.unique(g):
        mask = g == g_val
        n = mask.sum().item()
        corr = correct_batch[mask].sum().item()
        acc_groups[g_val].update(corr / n, n) # AverageMeter Updater. 

def validate_zs(val_loader, get_yp_func, train_group_ratio, target, label='Test', watch=True):
    """validation"""

    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()
    acc_groups = {g_idx : AverageMeter() for g_idx in range(val_loader.dataset.n_groups)}
        
    with torch.no_grad():
        end = time.time()
        for idx, data in enumerate(val_loader):
            if opt.dataset == 'waterbirds':
                _, all_labels, _ = data
                labels = all_labels[target] # (y, y_group, y_spurious)
                groups = all_labels['group']
                preds = all_labels['ebd_y_pred']
            else:
                _,  all_labels, _ = data
                labels = all_labels[target] # (y, y_group, y_spurious)
                groups = all_labels['group']
                preds = all_labels['ebd_y_pred']
              
              
            preds = preds.float().cuda()  
            labels = labels.cuda()
            bsz = labels.shape[0]
            
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            
            # update metric    
            acc1 = accuracy_zs(preds, labels, bsz)
            acc.update(acc1, bsz)
            
            # Update acc dict
            update_dict_zs(acc_groups, labels, groups, preds)
            
            if (idx+1) % opt.print_freq == 0:
                print(f'{label}: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Acc@1 {acc.val:.3f} ({acc.avg:.3f})'.format(
                    idx, len(val_loader), batch_time=batch_time,
                    loss=losses, acc=acc))
                    
                    
    group_acc = get_results(acc_groups, get_yp_func)
    
    #NOTE add Weighted mean acc.
    groups = range(val_loader.dataset.n_groups) # 0, 1, 2, 3
    group_acc_indiv =  [group_acc[f"acc_{get_yp_func(g)[0]}_{get_yp_func(g)[1]}"] for g in groups]
    weighted_mean_acc = (np.array(group_acc_indiv) * np.array(train_group_ratio)).sum() # Weighted Sum \
    
    group_acc["weighted_mean_acc"] = weighted_mean_acc
    group_acc = {key: group_acc[key] for key in new_order_for_print}
    
    if watch:
        print(f"{label}:", str(group_acc))
        print(' * Acc@1 {acc.avg:.3f}'.format(acc=acc))
        
    return  acc.avg, group_acc


def main_zs(opt):
    best_acc = 0
    best_epoch = 0
    # opt = parse_option()
    
    # build dataset example.
    if opt.dataset == 'waterbirds':
        # build dataset example.
        trainset = WaterbirdsEmbeddings(opt.data_dir, 'train', opt.embedding_dir, None, None)
        # build data loader
        print("Load Data Loader (train, validation, test)")
        train_loader, val_loader, test_loader = load_waterbirds_embeddings(opt.data_dir, opt.embedding_dir, opt.batch_size, opt.batch_size)
    elif opt.dataset == 'celeba':
        # build dataset example.
        trainset = CelebaEmbeddings(opt.data_dir, 'train', opt.embedding_dir, None)
        # build data loader
        print("Load Data Loader (train, validation, test)")
        train_loader, val_loader, test_loader = load_celeba_embeddings(opt.data_dir, opt.embedding_dir, opt.batch_size, opt.batch_size)
    
    get_yp_func = partial(get_y_p, n_places=trainset.n_places)
    train_group_ratio = trainset.group_ratio

    # eval for one epoch
    val_acc, val_group_acc = validate_zs(val_loader, get_yp_func, train_group_ratio, target='y', label='Val(y)')
    test_acc_y, test_group_acc_y = validate_zs(test_loader, get_yp_func, train_group_ratio, target='y', label='Test(y)', watch=True)
    test_acc_spurious, test_group_acc_spurious = validate_zs(test_loader, get_yp_func, train_group_ratio, target='spurious', label='Test(spurious)', watch=True)
    
    print('===============================Final Results===============================')
    print('Zero-shot (worst-)validation accuracy: {} '.format(val_group_acc))
    
    print('Zero-shot test accuracy (class): {}'.format(test_group_acc_y))
    print('Zero-shot test accuracy (spurious): {}'.format(test_group_acc_spurious))

## Watearbirds

In [20]:
# 
opt.dataset = 'waterbird'
opt.target = 'class'
opt.embedding_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/waterbirds/RN50/embedding_prediction.json"
opt.data_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/waterbirds/waterbird_complete95_forest2water2"
if __name__ == '__main__':    
    main_zs(opt)
    

Load Data Loader (train, validation, test)
Val(y): {'weighted_mean_acc': 0.9450627926530569, 'worst_acc': 0.3233082706766917, 'acc_0_0': 0.9914346895074947, 'acc_0_1': 0.7145922746781116, 'acc_1_0': 0.3233082706766917, 'acc_1_1': 0.8646616541353384, 'mean_acc': 0.7956630525437864}
 * Acc@1 0.796
Test(y): {'weighted_mean_acc': 0.9289761348579215, 'worst_acc': 0.3909657320872274, 'acc_0_0': 0.9804878048780488, 'acc_0_1': 0.7254988913525499, 'acc_1_0': 0.3909657320872274, 'acc_1_1': 0.822429906542056, 'mean_acc': 0.7984121505005177}
 * Acc@1 0.798
Zero-shot (worst-)validation accuracy: {'weighted_mean_acc': 0.9450627926530569, 'worst_acc': 0.3233082706766917, 'acc_0_0': 0.9914346895074947, 'acc_0_1': 0.7145922746781116, 'acc_1_0': 0.3233082706766917, 'acc_1_1': 0.8646616541353384, 'mean_acc': 0.7956630525437864} 
Zero-shot test accuracy (class): {'weighted_mean_acc': 0.9289761348579215, 'worst_acc': 0.3909657320872274, 'acc_0_0': 0.9804878048780488, 'acc_0_1': 0.7254988913525499, 'acc_1_0

### Double Check

In [6]:
df_zs_embeddings = pd.read_json("/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/waterbirds/RN50/embedding_prediction.json")
from copy import deepcopy
w = deepcopy(df_zs_embeddings)
w= w.T
print("Worst acc: ", len(w[(w['split']=='2') & (w['y']=='1') & (w['place']=='0') & (w['y_pred']=='1')]) / len(w[(w['split']=='2') & (w['y']=='1') & (w['place']=='0')]))

0.3909657320872274


## CelebA

In [10]:
df_zs_embeddings

Unnamed: 0,000001.jpg,000002.jpg,000003.jpg,000004.jpg,000005.jpg,000006.jpg,000007.jpg,000008.jpg,000009.jpg,000010.jpg,...,202590.jpg,202591.jpg,202592.jpg,202593.jpg,202594.jpg,202595.jpg,202596.jpg,202597.jpg,202598.jpg,202599.jpg
blond,0,0,0,0,0,0,0,0,0,0,...,0,0,1,0,0,1,1,0,0,1
male,0,0,1,0,0,0,1,1,0,0,...,1,0,0,0,0,0,1,1,0,0
group,0,0,1,0,0,0,1,1,0,0,...,1,0,2,0,0,2,3,1,0,2
split,0,0,0,0,0,0,0,0,0,0,...,2,2,2,2,2,2,2,2,2,2
image_embedding,"[-0.009254455566406, 0.006446838378906, -0.000...","[-0.014373779296875002, -1.072883605957031e-06...","[-0.020751953125, -0.01275634765625, -0.002752...","[0.008903503417968, 0.005790710449218001, -0.0...","[-0.020309448242187, -0.0140380859375, -0.0012...","[-0.016098022460937, 0.002260208129882, 0.0353...","[0.013648986816406, 0.010597229003906, 0.02252...","[-0.011924743652343, 0.020156860351562, 0.0206...","[0.007244110107421001, 0.008682250976562, -0.0...","[-0.000575065612792, 0.03582763671875, 0.02078...",...,"[0.005859375, 0.02935791015625, -0.01418304443...","[0.03826904296875, -0.0186767578125, 0.0081253...","[-0.012893676757812, 0.022796630859375003, 0.0...","[0.04962158203125, -0.0041542053222650005, -0....","[-0.036895751953125, -0.00086498260498, 0.0211...","[-0.0157470703125, 0.013694763183593, 0.056549...","[0.004299163818359, 0.0045204162597650005, 0.0...","[0.006679534912109, 0.01129150390625, 0.012878...","[0.037109375, 0.007495880126953001, -0.0129013...","[0.008468627929687, -0.000998497009277, 0.0297..."
y_pred,1,1,1,1,0,1,1,1,1,0,...,1,1,1,1,1,0,0,1,1,0


In [9]:
opt.dataset = 'celeba'
opt.target = 'class'
opt.embedding_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/celeba/RN50/embedding_prediction.json"
opt.data_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/celeba"
if __name__ == '__main__':    
    main_zs(opt)
    

Load Data Loader (train, validation, test)
Val(y): [0/1]	Time 0.009 (0.277)	Loss 0.0000 (0.0000)	Acc@1 0.137 (0.146)
Val(y): [0/1]	Time 0.003 (0.144)	Loss 0.0000 (0.0000)	Acc@1 0.137 (0.148)
Val(y): [0/1]	Time 0.009 (0.099)	Loss 0.0000 (0.0000)	Acc@1 0.150 (0.147)
Val(y): [0/1]	Time 0.003 (0.078)	Loss 0.0000 (0.0000)	Acc@1 0.119 (0.148)
Val(y): [0/1]	Time 0.002 (0.064)	Loss 0.0000 (0.0000)	Acc@1 0.143 (0.147)
Val(y): [0/1]	Time 0.003 (0.056)	Loss 0.0000 (0.0000)	Acc@1 0.156 (0.147)
Val(y): [0/1]	Time 0.002 (0.049)	Loss 0.0000 (0.0000)	Acc@1 0.152 (0.146)
Val(y): [0/1]	Time 0.003 (0.044)	Loss 0.0000 (0.0000)	Acc@1 0.141 (0.147)
Val(y): [0/1]	Time 0.002 (0.041)	Loss 0.0000 (0.0000)	Acc@1 0.150 (0.146)
Val(y): [0/1]	Time 0.002 (0.038)	Loss 0.0000 (0.0000)	Acc@1 0.139 (0.146)
Val(y): [0/1]	Time 0.002 (0.035)	Loss 0.0000 (0.0000)	Acc@1 0.145 (0.147)
Val(y): [0/1]	Time 0.003 (0.033)	Loss 0.0000 (0.0000)	Acc@1 0.131 (0.147)
Val(y): [0/1]	Time 0.002 (0.032)	Loss 0.0000 (0.0000)	Acc@1 0.168 (0.

### Double Check

In [7]:
from copy import deepcopy
df_zs_embeddings = pd.read_json("/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/celeba/RN50/embedding_prediction.json")
w = deepcopy(df_zs_embeddings)
w= w.T
print("Worst acc:", len(w[(w['split']=='2') & (w['blond']=='1') & (w['male']=='1') & (w['y_pred']=='1')]) / len(w[(w['split']=='2') & (w['blond']=='1') & (w['male']=='1')]))

0.23333333333333334
