In [30]:
import numpy as np
import pandas as pd
import os
import sys
import json
sys.path.append("/home/jinsu/workstation/project/debiasing-multi-modal")


df_zs_embeddings = pd.read_json("/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/waterbirds/RN50/embedding_prediction.json")
df_meta = pd.read_csv("/home/jinsu/workstation/project/debiasing-multi-modal/data/waterbirds/waterbird_complete95_forest2water2/metadata.csv")

In [31]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


In [32]:
import torch
import torch.nn as nn
from __future__ import print_function

import sys
import argparse
import time
import tqdm
import math

import torch
import torch.backends.cudnn as cudnn

from util import AverageMeter
from util import adjust_learning_rate, warmup_learning_rate, accuracy, accuracy_zs
from util import set_optimizer
# from networks.resnet_big import SupConResNet, LinearClassifier

try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass

from resnet import resnet50
from data.waterbirds_embeddings import WaterbirdsEmbeddings, load_waterbirds_embeddings
from data.celeba_embeddings import CelebaEmbeddings, load_celeba_embeddings
model_dict = {'resnet50': [resnet50, 1024]}
new_order_for_print = [
    'weighted_mean_acc',
    'worst_acc',
    'acc_0_0',
    'acc_0_1',
    'acc_1_0',
    'acc_1_1',
    'mean_acc'
]
from functools import partial

class LinearClassifier(nn.Module):
    def __init__(self, input_dim, num_classes=2):
        super(LinearClassifier, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, features):
        return self.fc(features)



class CustomCLIP(nn.Module):
    def __init__(self, adapter, text_embedding_dir, temperature=0.01):
        super().__init__()
        self.text_embedding_dir = text_embedding_dir
        self.adapter = adapter
        self.temperature = temperature # CA default : 0.01, B2T default : 0.02 (?) NOTE
        
        with open(self.text_embedding_dir, 'r') as f:
            self.text_embeddings = json.load(f)
        text_features = []
        for class_template, class_embedding in self.text_embeddings.items():
            text_features.append(torch.tensor(class_embedding))
        self.text_features = torch.stack(text_features, dim=1).cuda() # (B, 2, 1024)
        
    
    def forward(self, features):
        image_features =  self.adapter(features) # Un-normalized (B, 1024)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        text_features = self.text_features # 이미 Normalized(?) NOTE (B, 2, 1024)
        
        logits = image_features @ text_features / self.temperature # (B, 1024) X (B, 2, 1024) = # (B, 2)
        return logits
        
        
class Adapter(nn.Module):
    """
    - Residual connetion : 제외 (original Adapter - 0.2*images + 0.8*adapter)
    - Hidden dimension : 128 고정 (original Adatper - input_dim // 4)
    """
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
    def forward(self, features):
        return self.layers(features)


def parse_option():
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=16,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.1,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.2,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='waterbirds',
                        choices=['celeba', 'waterbirds'], help='dataset')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')

    parser.add_argument('--image_embedding_dir', type=str,
                        help='extracted image embedding')
    parser.add_argument('--text_embedding_dir', type=str,
                        help='extracted text embedding')
    parser.add_argument('--train_target', type=str, default="class", choices=["class", "group", "spurious"]) # label for prediction.
    parser.add_argument('--data_dir', type=str,
                        help='metadata.csv')
    parser.add_argument('--tl_method', type=str, default= "linear_probing", choices=["linear_probing", "adapter", "contrastive_adapter"]
                        ,help='transfer learning method')

    opt = parser.parse_args(args=[])

    # set the path according to the environment

    iterations = opt.lr_decay_epochs.split(',')
    opt.lr_decay_epochs = list([])
    for it in iterations:
        opt.lr_decay_epochs.append(int(it))

    opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.\
        format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
               opt.batch_size)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warm-up for large-batch training,
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate
            
    if opt.dataset == 'celeba':
        opt.n_cls = 2
    elif opt.dataset == 'waterbirds':
        opt.n_cls = 2
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))

    return opt


def set_model(opt):
    # model = SupConResNet(name=opt.model)
    criterion = torch.nn.CrossEntropyLoss()
        
    _ , input_dim = model_dict[opt.model] 
    
    if opt.tl_method =='linear_probing':
        print("Off-the-shelf prediction module : [Linear Classifier]")
        classifier = LinearClassifier(input_dim = input_dim, num_classes = opt.n_cls)
    elif opt.tl_method =='adapter':
        print("Off-the-shelf prediction module : [Adapter + temperatured-image-text-normalized-prediction]")
        adapter = Adapter(input_dim = input_dim, hidden_dim = 128) # Fixed by heuristics
        classifier = CustomCLIP(adapter, opt.text_embedding_dir, temperature=0.01)
    

    # ckpt = torch.load(opt.ckpt, map_location='cpu')
    # state_dict = ckpt['model']

    if torch.cuda.is_available():
        # if torch.cuda.device_count() > 1:
        #     model.encoder = torch.nn.DataParallel(model.encoder)
        # else:
            # new_state_dict = {}
            # for k, v in state_dict.items():
            #     k = k.replace("module.", "")
            #     new_state_dict[k] = v
            # state_dict = new_state_dict
        
        # model = model.cuda()
        classifier = classifier.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

        # model.load_state_dict(state_dict)

    return classifier, criterion # model, 

def update_dict(acc_groups, y, g, logits):
    preds = torch.argmax(logits, axis=1)
    correct_batch = (preds == y)
    g = g.cpu()
    for g_val in np.unique(g):
        mask = g == g_val
        n = mask.sum().item()
        corr = correct_batch[mask].sum().item()
        acc_groups[g_val].update(corr / n, n) # AverageMeter Updater. 

def get_results(acc_groups, get_yp_func, ): # Input 중 acc_groups : AverageMeter()를 담고있는 dict. get_yp_func : 미리 partial을 이용해 n_groups를 저장해놓음. 
    groups = acc_groups.keys() # 0, 1, 2, 3
    results = {
            f"acc_{get_yp_func(g)[0]}_{get_yp_func(g)[1]}": acc_groups[g].avg
            for g in groups
    }
    all_correct = sum([acc_groups[g].sum for g in groups])
    all_total = sum([acc_groups[g].count for g in groups])
    results.update({"mean_acc" : all_correct / all_total})
    results.update({"worst_acc" : min(results.values())})
    
    return results

def get_y_p(g, n_places):
    y = g // n_places
    p = g % n_places
    return y, p


def train(train_loader, classifier, criterion, optimizer, epoch, get_yp_func, target, label='Train'): # model,
    """one epoch training"""
    # model.eval()
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()
    acc_groups = {g_idx : AverageMeter() for g_idx in range(train_loader.dataset.n_groups)}

    end = time.time()
    for idx, data in enumerate(train_loader):
        if opt.dataset == 'waterbirds':
            embeddings, all_labels, img_filenames = data
            labels = all_labels[target] # (y, group, spurious)
            groups = all_labels['group']
        else:
            embeddings, all_labels, img_filenames = data
            labels = all_labels[target] # (y, group, ypurious)
            groups = all_labels['group']
        
        data_time.update(time.time() - end)

        embeddings = embeddings.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # compute loss
        # with torch.no_grad():
        #     features = model.encoder(embeddings)
        output = classifier(embeddings.detach())
        
        loss = criterion(output, labels)

        # update metric
        losses.update(loss.item(), bsz)
        acc1 = accuracy(output, labels, bsz)
        acc.update(acc1, bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        
        # Update acc dict
        update_dict(acc_groups, labels, groups, output)
        
        # print info
        # if (idx + 1) % opt.print_freq == 0:
        #     print(f'{label}: [{0}][{1}/{2}]\t'
        #           'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
        #           'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
        #           'loss {loss.val:.3f} ({loss.avg:.3f})\t'
        #           'Acc@1 {acc.val:.3f} ({acc.avg:.3f})'.format(
        #            epoch, idx + 1, len(train_loader), batch_time=batch_time,
        #            data_time=data_time, loss=losses, acc=acc))
        #     sys.stdout.flush()
            
    group_acc = get_results(acc_groups, get_yp_func) # NOTE declared in [def main]
    group_acc = {key: group_acc[key] for key in new_order_for_print[1:]}
    print(f"{label}:", str(group_acc))
    
    return losses.avg, acc.avg, group_acc


def validate(val_loader, classifier, criterion, get_yp_func, train_group_ratio, target, label='Test', watch=True):
    """validation"""
    
    # model.eval()
    classifier.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()
    acc_groups = {g_idx : AverageMeter() for g_idx in range(val_loader.dataset.n_groups)}

    with torch.no_grad():
        end = time.time()
        for idx, data in enumerate(val_loader):
            if opt.dataset == 'waterbirds':
                embeddings, all_labels, _ = data
                labels = all_labels[target] # (y, group, spurious)
                groups = all_labels['group']
            elif opt.dataset == 'celeba':
                embeddings, all_labels = data
                labels = all_labels[target] # (y, group, spurious)
                groups = all_labels['group']
            
            embeddings = embeddings.float().cuda()
            labels = labels.cuda()
            bsz = labels.shape[0]

            # forward
            output = classifier(embeddings)
            loss = criterion(output, labels)

            # update metric
            losses.update(loss.item(), bsz)
            acc1 = accuracy(output, labels, bsz)
            acc.update(acc1, bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            
            # Update acc dict
            update_dict(acc_groups, labels, groups, output)
        
            # if (idx+1) % opt.print_freq == 0:
            #     if watch:
            #         print(f'{label}: [{0}/{1}]\t'
            #             'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
            #             'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
            #             'Acc@1 {acc.val:.3f} ({acc.avg:.3f})'.format(
            #             idx, len(val_loader), batch_time=batch_time,
            #             loss=losses, acc=acc))
                    
                    
    group_acc = get_results(acc_groups, get_yp_func)
    
    #NOTE add Weighted mean acc.
    groups = range(val_loader.dataset.n_groups) # 0, 1, 2, 3
    group_acc_indiv =  [group_acc[f"acc_{get_yp_func(g)[0]}_{get_yp_func(g)[1]}"] for g in groups]
    weighted_mean_acc = (np.array(group_acc_indiv) * np.array(train_group_ratio)).sum() # Weighted Sum \
    
    group_acc["weighted_mean_acc"] = weighted_mean_acc
    group_acc = {key: group_acc[key] for key in new_order_for_print}
    
    if watch:
        print(f"{label}:", str(group_acc))
        # print(' * Acc@1 {acc.avg:.3f}'.format(acc=acc))
    return losses.avg, acc.avg, group_acc


def main(opt):
    best_acc = 0
    best_epoch = 0
    # opt = parse_option()
    
    print(f"Transfer Learning using [{opt.tl_method}]")
    if opt.dataset == 'waterbirds':
        # build dataset example.
        trainset = WaterbirdsEmbeddings(opt.data_dir, 'train', opt.image_embedding_dir, None, None)
        # build data loader
        
        if opt.train_target == "class":
            print(f"Train Target : {opt.train_target} (Land bird(0) / Water bird(1))")
        elif opt.train_target == "spurious":
            print(f"Train Target : {opt.train_target} (Land background(0) / Water background(1)")
            
        print("Load Data Loader (train, validation, test)")
        train_loader, val_loader, test_loader = load_waterbirds_embeddings(opt.data_dir, opt.image_embedding_dir, opt.batch_size, opt.batch_size)
        
    elif opt.dataset == 'celeba':
        # build dataset example.
        trainset = CelebaEmbeddings(opt.data_dir, 'train', opt.image_embedding_dir, None)
        
        if opt.train_target == "class":
            print(f"Target : {opt.train_target} (non-blond hair(0) / blond hair(1)")
        elif opt.train_target == "spurious":
            print(f"Target : {opt.train_target} (female(0) / male(1))")
            
        # build data loader
        print("Load Data Loader (train, validation, test)")
        train_loader, val_loader, test_loader = load_celeba_embeddings(opt.data_dir, opt.image_embedding_dir, opt.batch_size, opt.batch_size)
    

    get_yp_func = partial(get_y_p, n_places=trainset.n_places)
    train_group_ratio = trainset.group_ratio
    
    # build model and criterion
    print(f"Set Classifier : {opt.tl_method}")
    classifier, criterion = set_model(opt) # model, 

    # build optimizer
    print("Set Optimizer")
    optimizer = set_optimizer(opt, classifier)

    # training routine
    train_losses = []
    train_accs = []
    train_group_accs = []
    val_losses = []
    val_accs = []
    val_group_accs = []
    val_losses_non_target = []
    val_accs_non_target = []
    val_group_accs_non_target = []
    
    test_losses_y = [] # NOTE: Don't peek ! 
    test_accs_y = [] # NOTE: Don't peek ! 
    test_group_accs_y = [] # NOTE: Don't peek ! 
    test_losses_spurious = [] # NOTE: Don't peek ! 
    test_accs_spurious = [] # NOTE: Don't peek ! 
    test_group_accs_spurious = [] # NOTE: Don't peek ! 
    
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)
        print(f'--- Epoch {epoch} ---')
        # train for one epoch
        loss, acc, group_acc = train(train_loader, classifier, criterion,
                          optimizer, epoch, get_yp_func, target=opt.train_target, label=f'Train({opt.train_target})')
        
        train_losses.append(loss)
        train_accs.append(acc)
        train_group_accs.append(group_acc)
        # eval for one epoch
        val_loss, val_acc, val_group_acc = validate(val_loader, classifier, criterion, get_yp_func, train_group_ratio, target=opt.train_target, label=f'Val({opt.train_target})')
        
        if opt.tl_method in ['linear_probing', 'adapter']:
            if val_group_acc['weighted_mean_acc'] > best_acc:
                best_acc = val_group_acc['weighted_mean_acc']
                best_epoch = epoch
        elif opt.tl_method in ['contrastive_adapter']:
            if val_group_acc['worst_acc'] > best_acc:
                best_acc = val_group_acc['worst_acc']
                best_epoch = epoch
            
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_group_accs.append(val_group_acc)
        
        
        non_target = [label for label in ['class', 'spurious'] if (label!=opt.train_target)][0]
        val_loss_non_target, val_acc_non_target, val_group_acc_non_target = validate(val_loader, classifier, criterion, get_yp_func, train_group_ratio, target=non_target, label=f'Val({non_target})')

        val_losses_non_target.append(val_loss_non_target)
        val_accs_non_target.append(val_acc_non_target)
        val_group_accs_non_target.append(val_group_acc_non_target)
        
        
        test_loss_y, test_acc_y, test_group_acc_y = validate(test_loader, classifier, criterion, get_yp_func, train_group_ratio, target='class', label='Test(class)', watch=True)
        test_losses_y.append(test_loss_y)
        test_accs_y.append(test_acc_y)
        test_group_accs_y.append(test_group_acc_y)
        
        test_loss_spurious, test_acc_spurious, test_group_acc_spurious= validate(test_loader, classifier, criterion, get_yp_func, train_group_ratio, target='spurious', label='Test(spurious)', watch=True)
        test_losses_spurious.append(test_loss_spurious)
        test_accs_spurious.append(test_acc_spurious)
        test_group_accs_spurious.append(test_group_acc_spurious)
    
    
    print('===============')
    print('best epoch : {}'.format(best_epoch))
    
    val_group_acc = val_group_accs[best_epoch-1]
    test_group_acc_y = test_group_accs_y[best_epoch-1]
    test_group_acc_spurious = test_group_accs_spurious[best_epoch-1]
    print(f'best validation accuracy on {opt.train_target}: {val_group_acc}')
    
    print(f'best test accuracy (class): {test_group_acc_y}')
    print(f'best test accuracy (spurious): {test_group_acc_spurious}')
    
    
    # return val_group_acc, test_group_acc_y, test_group_acc_spurious # 최종 결과만
    return val_group_acc, test_group_acc_y, test_group_acc_spurious, train_group_accs, val_group_accs, val_group_accs_non_target # 전체 결과까지.

In [33]:
import argparse

parser = argparse.ArgumentParser('argument for training')

parser.add_argument('--print_freq', type=int, default=20,
                    help='print frequency')
parser.add_argument('--save_freq', type=int, default=50,
                    help='save frequency')
parser.add_argument('--batch_size', type=int, default=128,
                    help='batch_size')
parser.add_argument('--num_workers', type=int, default=16,
                    help='num of workers to use')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of training epochs')

# optimization
parser.add_argument('--learning_rate', type=float, default=5,
                    help='learning rate')
parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90',
                    help='where to decay lr, can be a list')
parser.add_argument('--lr_decay_rate', type=float, default=0.2,
                    help='decay rate for learning rate')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='momentum')

# model dataset
parser.add_argument('--model', type=str, default='resnet50')
parser.add_argument('--dataset', type=str, default='waterbirds',
                    choices=['celeba', 'waterbirds'], help='dataset')

# other setting
parser.add_argument('--cosine', action='store_true',
                    help='using cosine annealing')
parser.add_argument('--warm', action='store_true',
                    help='warm-up for large batch training')

parser.add_argument('--image_embedding_dir', type=str, 
                    help='extracted image embedding')
parser.add_argument('--text_embedding_dir', type=str, 
                    help='extracted text embedding')
parser.add_argument('--train_target', type=str, default="class", choices=["class", "spurious", "group"]) # Label for training.
parser.add_argument('--data_dir', type=str,
                    help='folder, in which [metadata.csv] exists')
parser.add_argument('--tl_method', type=str, default="linear_probing", choices=["linear_probing", "adapter", "contrastive_adapter"]
                        ,help='transfer learning method')

opt = parser.parse_args(args=[])   

iterations = opt.lr_decay_epochs.split(',')
opt.lr_decay_epochs = list([])
for it in iterations:
    opt.lr_decay_epochs.append(int(it))

opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.\
    format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
            opt.batch_size)

if opt.cosine:
    opt.model_name = '{}_cosine'.format(opt.model_name)

# warm-up for large-batch training,
if opt.warm:
    opt.model_name = '{}_warm'.format(opt.model_name)
    opt.warmup_from = 0.01
    opt.warm_epochs = 10
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
    else:
        opt.warmup_to = opt.learning_rate
        
if opt.dataset == 'celeba':
    opt.n_cls = 2
elif opt.dataset == 'waterbirds':
    opt.n_cls = 2
else:
    raise ValueError('dataset not supported: {}'.format(opt.dataset))

# Linear Probing  & Adapter
- (Epoch default : 100)
- Test 성능 모니터링하면 안 됨.

## Waterbirds

In [11]:
final_acc_dict = {} # for mean+-std (final) accuracy (over 3 times run) 

full_acc_dict = {} # for full accuracy (on only single run) (for simple reasoning)

opt.epochs = 2
opt.dataset = 'waterbirds'

opt.text_embedding_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings_unnormalized/waterbirds/text_embedding.json"
opt.image_embedding_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings_unnormalized/waterbirds/RN50/embedding_prediction.json"
opt.data_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/waterbirds/waterbird_complete95_forest2water2"

# opt.tl_method = "linear_probing"
opt.tl_method = "adapter"
if __name__ == '__main__':    
    for tl_method in ["linear_probing", "adapter"]:
        opt.tl_method = tl_method
        final_acc_dict[tl_method] = {}
        full_acc_dict[tl_method] = {}
        for train_target in ['class', 'spurious']:
            opt.train_target = train_target
            for iter in range(1, 3):
                print(f"================= {opt.dataset}_ft_on_{opt.train_target} (iter. {iter})=================")
                
                v_acc_t, t_acc_y, t_acc_s, tr_accs_t, v_accs_t, v_accs_non_t = main(opt)
                final_acc_dict[tl_method][f"{opt.dataset}_ft_on_{opt.train_target}({iter})"] = {f'val_{opt.train_target}': v_acc_t, 'test_class': t_acc_y, 'test_spurious': t_acc_s}
                
                if iter==1:
                    full_acc_dict[tl_method][f"{opt.dataset}_ft_on_{opt.train_target}"] = {f"train_{opt.train_target}" : tr_accs_t, f"val_{opt.train_target}" : v_accs_t, f"val_non_target" : v_accs_non_t}                


Transfer Learning using [linear_probing]
Train Target : class (Land bird(0) / Water bird(1))
Load Data Loader (train, validation, test)
Set Classifier : linear_probing
Off-the-shelf prediction module : [Linear Classifier]
Set Optimizer
--- Epoch 1 ---
Train(class): [0][1/2]	BT 0.002 (0.040)	DT 0.001 (0.039)	loss 0.173 (0.360)	Acc@1 0.930 (0.838)
Train(class): {'worst_acc': 0.30357142857142855, 'acc_0_0': 0.9488279016580904, 'acc_0_1': 0.5543478260869565, 'acc_1_0': 0.30357142857142855, 'acc_1_1': 0.804162724692526, 'mean_acc': 0.8942648592283629}
Val(class): {'weighted_mean_acc': 0.9640794620449225, 'worst_acc': 0.37593984962406013, 'acc_0_0': 0.9957173447537473, 'acc_0_1': 0.5944206008583691, 'acc_1_0': 0.37593984962406013, 'acc_1_1': 0.9548872180451128, 'mean_acc': 0.7664720600500416}
Val(spurious): {'weighted_mean_acc': 0.959730755973583, 'worst_acc': 0.4055793991416309, 'acc_0_0': 0.9957173447537473, 'acc_0_1': 0.4055793991416309, 'acc_1_0': 0.6240601503759399, 'acc_1_1': 0.9548872

### Mean+-std (averaged over 3 times run)

In [28]:
df_for_final_report = {} # For reporting final performance.
pd.set_option('display.max_columns', 10)
pd.set_option('display.width', 1000)

In [29]:
# val_target
for tl_method in ["linear_probing", "adapter"]:
    if tl_method == "linear_probing":
        report_label = 'Lin. Prov.'
        val_on = 'avg'
    elif tl_method =='adapter':
        report_label = 'Adapter'
        val_on = 'avg'
    elif tl_method =='contrastive_adapter':
        report_label = 'Contra. Adapter'
        val_on = 'worst'
        
        
    for train_target in ["class", "spurious"]:
        df_for_final_report[f"{report_label} on {train_target}"]  = {}
    
        for eval in [f"val_{train_target}", "test_class", "test_spurious"]:
            if "val" in eval:
                print(f"============== Best [{val_on}]-validation group accuracy on [{train_target}]-trained [{tl_method}] ==============\n")
            elif "test" in eval:
                print(f"===== Corresponding [{eval}] group accuracy on [{train_target}]-trained [{tl_method}] =====\n")
                
            multiple_run = []
            for iter in range(1,3): 
                # val_target 
                single_run = final_acc_dict[tl_method][f"{opt.dataset}_ft_on_{train_target}({iter})"][eval]
                multiple_run.append(single_run)
                    
            df_multiple = pd.DataFrame(multiple_run, index=range(1, 3))
            df_multiple.index.name = 'iter'
            
            df_multiple.loc["mean"] = df_multiple.mean()
            df_multiple.loc["std"] = df_multiple.std()
            print(df_multiple)
            
            # 3-times-averaged performance(Avg. & Worst acc.)
            if "test" in eval:
                df_for_final_report[f"{report_label} on {train_target}"][f"Avg. acc (on {eval.split('_')[-1]})"] = df_multiple.loc["mean"]["weighted_mean_acc"]
                df_for_final_report[f"{report_label} on {train_target}"][f"Worst. acc (on {eval.split('_')[-1]})"] = df_multiple.loc["mean"]["worst_acc"]

            for k, v in df_for_final_report.items():
                print(k, v)
            
        

# for k, v in final_acc_dict[tl_method].items():
#     print(k, v)
#     print(v.keys())
    


      weighted_mean_acc  worst_acc   acc_0_0   acc_0_1   acc_1_0   acc_1_1  mean_acc
iter                                                                                
1              0.965644   0.225564  0.997859  0.856223  0.225564  0.917293  0.848207
2              0.965496   0.586466  0.989293  0.603004  0.586466  0.969925  0.792327
mean           0.965570   0.406015  0.993576  0.729614  0.406015  0.943609  0.820267
std            0.000074   0.180451  0.004283  0.126609  0.180451  0.026316  0.027940
Lin. Prov. on class {}
===== Corresponding [test_class] group accuracy on [class]-trained [linear_probing] =====

      weighted_mean_acc  worst_acc   acc_0_0   acc_0_1   acc_1_0   acc_1_1  mean_acc
iter                                                                                
1              0.956283   0.257009  0.999113  0.864302  0.257009  0.867601  0.849845
2              0.960183   0.549446  0.990244  0.549446  0.591900  0.951713  0.770280
mean           0.958233   0.403228 

### Full acc plot for understanding learning mechanism (일단 생략)

In [30]:
# Linear probing on semantic "class" (foreground in Waterbirds)


In [75]:
# Linear probing on "spurious attributes" (backgorund in Waterbirds)

## CelebA

In [None]:
# df_zs_embeddings = pd.read_json("/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/celeba/RN50/embedding_prediction.json")        
full_dict_celeba = {}
opt.dataset = 'celeba'
opt.image_embedding_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/celeba/RN50/embedding_prediction.json"
opt.data_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/celeba"

if __name__ == '__main__':    
    for target in ['class', 'spurious']:
        opt.train_target = target
        for iter in range(1, 4):
            print(f"======= {opt.dataset}_ft_on_{opt.train_target} (No. {iter} )=======")
            t_acc_y, t_acc_s = main(opt)
            full_dict_celeba[f"{opt.dataset}_ft_on_{opt.train_target}({iter})"] = {'class': t_acc_y, 'spurious': t_acc_s}
        
    

# Zero-Shot Evaluation

In [19]:
def update_dict_zs(acc_groups, y, g, preds):
    # preds = torch.argmax(logits, axis=1)
    correct_batch = (preds == y)
    g = g.cpu()
    for g_val in np.unique(g):
        mask = g == g_val
        n = mask.sum().item()
        corr = correct_batch[mask].sum().item()
        acc_groups[g_val].update(corr / n, n) # AverageMeter Updater. 

def validate_zs(val_loader, get_yp_func, train_group_ratio, target, label='Test', watch=True):
    """validation"""

    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()
    acc_groups = {g_idx : AverageMeter() for g_idx in range(val_loader.dataset.n_groups)}
        
    with torch.no_grad():
        end = time.time()
        for idx, data in enumerate(val_loader):
            if opt.dataset == 'waterbirds':
                _, all_labels, _ = data
                labels = all_labels[target] # (y, y_group, y_spurious)
                groups = all_labels['group']
                preds = all_labels['ebd_y_pred']
            else:
                _,  all_labels, _ = data
                labels = all_labels[target] # (y, y_group, y_spurious)
                groups = all_labels['group']
                preds = all_labels['ebd_y_pred']
              
              
            preds = preds.float().cuda()  
            labels = labels.cuda()
            bsz = labels.shape[0]
            
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            
            # update metric    
            acc1 = accuracy_zs(preds, labels, bsz)
            acc.update(acc1, bsz)
            
            # Update acc dict
            update_dict_zs(acc_groups, labels, groups, preds)
            
            # if (idx+1) % opt.print_freq == 0:
            #     print(f'{label}: [{0}/{1}]\t'
            #         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
            #         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
            #         'Acc@1 {acc.val:.3f} ({acc.avg:.3f})'.format(
            #         idx, len(val_loader), batch_time=batch_time,
            #         loss=losses, acc=acc))
                    
                    
    group_acc = get_results(acc_groups, get_yp_func)
    
    #NOTE add Weighted mean acc.
    groups = range(val_loader.dataset.n_groups) # 0, 1, 2, 3
    group_acc_indiv =  [group_acc[f"acc_{get_yp_func(g)[0]}_{get_yp_func(g)[1]}"] for g in groups]
    weighted_mean_acc = (np.array(group_acc_indiv) * np.array(train_group_ratio)).sum() # Weighted Sum \
    
    group_acc["weighted_mean_acc"] = weighted_mean_acc
    group_acc = {key: group_acc[key] for key in new_order_for_print}
    
    if watch:
        print(f"{label}:", str(group_acc))
        print(' * Acc@1 {acc.avg:.3f}'.format(acc=acc))
        
    return  acc.avg, group_acc


def main_zs(opt):
    best_acc = 0
    best_epoch = 0
    # opt = parse_option()
    
    # build dataset example.
    if opt.dataset == 'waterbirds':
        # build dataset example.
        trainset = WaterbirdsEmbeddings(opt.data_dir, 'train', opt.image_embedding_dir, None, None)
        # build data loader
        print("Load Data Loader (Waterbirds) (test)")
        train_loader, val_loader, test_loader = load_waterbirds_embeddings(opt.data_dir, opt.image_embedding_dir, opt.batch_size, opt.batch_size)
    elif opt.dataset == 'celeba':
        # build dataset example.
        trainset = CelebaEmbeddings(opt.data_dir, 'train', opt.image_embedding_dir, None)
        # build data loader
        print("Load Data Loader (CelebA) (test)")
        train_loader, val_loader, test_loader = load_celeba_embeddings(opt.data_dir, opt.image_embedding_dir, opt.batch_size, opt.batch_size)
    
    get_yp_func = partial(get_y_p, n_places=trainset.n_places)
    train_group_ratio = trainset.group_ratio

    # eval for one epoch
    # val_acc, val_group_acc = validate_zs(val_loader, get_yp_func, train_group_ratio, target='y', label='Val(y)')

    test_acc_y, test_group_acc_y = validate_zs(test_loader, get_yp_func, train_group_ratio, target='class', label='Test(class)', watch=True)
    test_acc_spurious, test_group_acc_spurious = validate_zs(test_loader, get_yp_func, train_group_ratio, target='spurious', label='Test(spurious)', watch=True)
    print('===============================Final Results===============================')
    # print('Zero-shot validation accuracy: {} '.format(val_group_acc))
    
    print('Zero-shot test accuracy (class): {}'.format(test_group_acc_y))
    print('Zero-shot test accuracy (spurious): {}'.format(test_group_acc_spurious))
    
    return test_group_acc_y, test_group_acc_spurious

## Waterbirds

In [16]:
full_dict_wb_zs = {}

opt.dataset = 'waterbirds'
# opt.train_target = 'class'

opt.image_embedding_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/waterbirds/RN50/embedding_prediction.json"
opt.data_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/waterbirds/waterbird_complete95_forest2water2"
if __name__ == '__main__':    
    opt.train_target = target
    t_acc_y, t_acc_s = main_zs(opt)
    full_dict_wb_zs[f"{opt_dataset}_zs"] = {'class': t_acc_y, 'spurious': t_acc_s}
        

Load Data Loader (Waterbirds) (validation / test)
Val(y): {'weighted_mean_acc': 0.9450627926530569, 'worst_acc': 0.3233082706766917, 'acc_0_0': 0.9914346895074947, 'acc_0_1': 0.7145922746781116, 'acc_1_0': 0.3233082706766917, 'acc_1_1': 0.8646616541353384, 'mean_acc': 0.7956630525437864}
 * Acc@1 0.796
Test(y): {'weighted_mean_acc': 0.9289761348579215, 'worst_acc': 0.3909657320872274, 'acc_0_0': 0.9804878048780488, 'acc_0_1': 0.7254988913525499, 'acc_1_0': 0.3909657320872274, 'acc_1_1': 0.822429906542056, 'mean_acc': 0.7984121505005177}
 * Acc@1 0.798
Test(spurious): {'weighted_mean_acc': 0.9142166444777964, 'worst_acc': 0.2745011086474501, 'acc_0_0': 0.9804878048780488, 'acc_0_1': 0.2745011086474501, 'acc_1_0': 0.6090342679127726, 'acc_1_1': 0.822429906542056, 'mean_acc': 0.6470486710390059}
 * Acc@1 0.647
Zero-shot (worst-)validation accuracy: {'weighted_mean_acc': 0.9450627926530569, 'worst_acc': 0.3233082706766917, 'acc_0_0': 0.9914346895074947, 'acc_0_1': 0.7145922746781116, 'acc_

### Double Check

In [13]:
df_zs_embeddings = pd.read_json("/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/waterbirds/RN50/embedding_prediction.json")
from copy import deepcopy
w = deepcopy(df_zs_embeddings)
w= w.T
print("Worst acc: ", len(w[(w['split']=='2') & (w['y']=='1') & (w['place']=='0') & (w['y_pred']=='1')]) / len(w[(w['split']=='2') & (w['y']=='1') & (w['place']=='0')]))

Worst acc:  0.3909657320872274


## CelebA

In [9]:
opt.dataset = 'celeba'
opt.train_target = 'class'
opt.image_embedding_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/celeba/RN50/embedding_prediction.json"
opt.data_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/celeba"
if __name__ == '__main__':    
    main_zs(opt)
    

Load Data Loader (train, validation, test)
Val(y): [0/1]	Time 0.009 (0.277)	Loss 0.0000 (0.0000)	Acc@1 0.137 (0.146)
Val(y): [0/1]	Time 0.003 (0.144)	Loss 0.0000 (0.0000)	Acc@1 0.137 (0.148)
Val(y): [0/1]	Time 0.009 (0.099)	Loss 0.0000 (0.0000)	Acc@1 0.150 (0.147)
Val(y): [0/1]	Time 0.003 (0.078)	Loss 0.0000 (0.0000)	Acc@1 0.119 (0.148)
Val(y): [0/1]	Time 0.002 (0.064)	Loss 0.0000 (0.0000)	Acc@1 0.143 (0.147)
Val(y): [0/1]	Time 0.003 (0.056)	Loss 0.0000 (0.0000)	Acc@1 0.156 (0.147)
Val(y): [0/1]	Time 0.002 (0.049)	Loss 0.0000 (0.0000)	Acc@1 0.152 (0.146)
Val(y): [0/1]	Time 0.003 (0.044)	Loss 0.0000 (0.0000)	Acc@1 0.141 (0.147)
Val(y): [0/1]	Time 0.002 (0.041)	Loss 0.0000 (0.0000)	Acc@1 0.150 (0.146)
Val(y): [0/1]	Time 0.002 (0.038)	Loss 0.0000 (0.0000)	Acc@1 0.139 (0.146)
Val(y): [0/1]	Time 0.002 (0.035)	Loss 0.0000 (0.0000)	Acc@1 0.145 (0.147)
Val(y): [0/1]	Time 0.003 (0.033)	Loss 0.0000 (0.0000)	Acc@1 0.131 (0.147)
Val(y): [0/1]	Time 0.002 (0.032)	Loss 0.0000 (0.0000)	Acc@1 0.168 (0.

### Double Check

In [17]:
from copy import deepcopy
df_zs_embeddings = pd.read_json("/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/celeba/RN50/embedding_prediction.json")
w = deepcopy(df_zs_embeddings)
w= w.T
print("Worst acc:", len(w[(w['split']=='2') & (w['blond']=='1') & (w['male']=='1') & (w['y_pred']=='1')]) / len(w[(w['split']=='2') & (w['blond']=='1') & (w['male']=='1')]))

Worst acc: 0.23333333333333334
