In [3]:
import numpy as np
import pandas as pd

import sys
# Root of the project checkout; every path used in this cell is derived from it.
PROJECT_ROOT = "/home/jinsu/workstation/project/debiasing-multi-modal"

# Put the project root on sys.path so later imports such as `data.embeddings`
# resolve. The membership check keeps re-running this cell from stacking
# duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# NOTE(review): allow_pickle=True unpickles arbitrary Python objects from the
# file; only load embedding dumps that were produced locally / are trusted.
# `.item()` unwraps the single object stored in the array (used below as a
# mapping via `.keys()`).
dict_embeddings = np.load(
    PROJECT_ROOT + "/data/embeddings/waterbirds/resnet50/image_embedding.npy",
    allow_pickle=True,
).item()

In [9]:
import torch
import torch.nn as nn
from __future__ import print_function

import sys
import argparse
import time
import math

import torch
import torch.backends.cudnn as cudnn

from util import AverageMeter
from util import adjust_learning_rate, warmup_learning_rate, accuracy
from util import set_optimizer
# from networks.resnet_big import SupConResNet, LinearClassifier

try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass

from resnet import resnet50
from data.embeddings import Embeddings, load_embeddings
# Backbone registry: name -> [constructor, feature dimension].
# LinearClassifier reads the second entry as the input width of its linear layer.
model_dict = {'resnet50': [resnet50, 2048]}

class LinearClassifier(nn.Module):
    """A single fully connected layer applied to backbone features.

    The input width is looked up in ``model_dict`` under ``name``; the output
    width is ``num_classes``.
    """

    def __init__(self, name='resnet50', num_classes=2):
        super().__init__()
        # Each registry entry is a [constructor, feature_dim] pair; only the
        # width is needed to size the layer.
        _backbone_ctor, in_features = model_dict[name]
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, features):
        """Apply the linear layer to ``features`` and return the result."""
        scores = self.fc(features)
        return scores

def parse_option(args=None):
    """Build the training options namespace.

    Args:
        args: Optional list of command-line style tokens, e.g.
            ``['--batch_size', '512', '--warm']``. When omitted, an empty list
            is parsed, so ``parse_option()`` returns the declared defaults and
            never reads the notebook kernel's ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed options plus derived fields:
        ``lr_decay_epochs`` (list of int), ``model_name``, ``n_cls`` and,
        only when ``--warm`` is given, ``warmup_from``, ``warm_epochs`` and
        ``warmup_to``.

    Raises:
        ValueError: if ``opt.dataset`` is not a supported dataset name, or if
            an entry of ``--lr_decay_epochs`` is not an integer.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=16,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.1,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.2,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='waterbirds',
                        choices=['celebA', 'waterbirds'], help='dataset')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')

    parser.add_argument('--embedding_dir', type=str,
                        help='extracted embedding')
    parser.add_argument('--target', type=str, default="class", choices=["class", "group", "spurious"])
    parser.add_argument('--data_dir', type=str,
                        help='metadata.csv')

    # Default to an empty argument list (the previous hard-coded behaviour) so
    # the kernel's own argv is never parsed; callers may pass overrides.
    opt = parser.parse_args(args=[] if args is None else list(args))

    # '60,75,90' -> [60, 75, 90]
    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.\
        format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
               opt.batch_size)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warm-up for large-batch training
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            # Warm up to the value the cosine curve (bottoming out at eta_min)
            # takes at the end of the warm-up epochs.
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    # Both supported datasets are treated as two-class problems.
    if opt.dataset == 'celebA':
        opt.n_cls = 2
    elif opt.dataset == 'waterbirds':
        opt.n_cls = 2
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))

    return opt


def set_model(opt):
    """Create the linear classifier and its cross-entropy loss.

    Reads ``opt.model`` and ``opt.n_cls``. When CUDA is available both objects
    are moved to the GPU and ``cudnn.benchmark`` is switched on (a global
    setting); otherwise they are returned as constructed.

    Returns:
        tuple: ``(classifier, criterion)``.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    probe = LinearClassifier(name=opt.model, num_classes=opt.n_cls)

    # CPU-only machine: nothing to move.
    if not torch.cuda.is_available():
        return probe, loss_fn

    probe = probe.cuda()
    loss_fn = loss_fn.cuda()
    cudnn.benchmark = True

    return probe, loss_fn


def train(train_loader, classifier, criterion, optimizer, epoch, opt):
    """Run one training epoch of the classifier on pre-extracted embeddings.

    Args:
        train_loader: iterable of batches. For ``opt.dataset == 'waterbirds'``
            each batch unpacks to ``(embeddings, all_labels, _)``; otherwise to
            ``(embeddings, all_labels)``. ``all_labels`` is indexed with
            ``opt.target``.
        classifier: module being trained; it receives the embeddings directly.
        criterion: loss called as ``criterion(output, labels)``.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number, forwarded to the warm-up schedule and logs.
        opt: options namespace; reads ``dataset``, ``target`` and ``print_freq``.

    Returns:
        tuple: ``(average loss, average accuracy)`` over the epoch, as
        accumulated by the ``AverageMeter`` instances.
    """
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    # Follow whatever device the classifier lives on. set_model() only moves it
    # to the GPU when CUDA is available, so an unconditional .cuda() here would
    # fail on a CPU-only machine.
    device = next(classifier.parameters()).device

    end = time.time()
    for idx, data in enumerate(train_loader):
        if opt.dataset == 'waterbirds':
            embeddings, all_labels, _ = data
        else:
            embeddings, all_labels = data
        # opt.target is one of 'class', 'group', 'spurious'.
        labels = all_labels[opt.target]

        data_time.update(time.time() - end)

        # validate() casts embeddings with .float() before the forward pass;
        # do the same here so both paths feed the classifier the same dtype.
        embeddings = embeddings.float().to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # The loader already yields extracted embeddings, so they go straight
        # into the classifier (the old `features` variable no longer exists).
        output = classifier(embeddings)
        loss = criterion(output, labels)

        # update metric
        losses.update(loss.item(), bsz)
        acc1 = accuracy(output, labels, bsz)
        # NOTE(review): validate() records `acc1[0]` while this records `acc1`
        # as-is; confirm what util.accuracy returns and align the two.
        acc.update(acc1, bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {acc.val:.3f} ({acc.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, acc=acc))
            sys.stdout.flush()

    return losses.avg, acc.avg


def validate(val_loader, classifier, criterion, opt):
    """Evaluate the classifier on pre-extracted embeddings without gradients.

    The signature mirrors ``train`` (no separate encoder/model argument),
    which is how ``main`` calls it.

    Args:
        val_loader: iterable of batches. For ``opt.dataset == 'waterbirds'``
            each batch unpacks to ``(embeddings, all_labels, _)``; otherwise to
            ``(embeddings, all_labels)``. ``all_labels`` is indexed with
            ``opt.target``.
        classifier: module under evaluation; put into eval mode here.
        criterion: loss called as ``criterion(output, labels)``.
        opt: options namespace; reads ``dataset``, ``target`` and ``print_freq``.

    Returns:
        tuple: ``(average loss, average accuracy)`` over the loader, as
        accumulated by the ``AverageMeter`` instances.
    """
    classifier.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    # Follow the classifier's device; set_model() only uses the GPU when CUDA
    # is available, so an unconditional .cuda() would fail on CPU.
    device = next(classifier.parameters()).device

    with torch.no_grad():
        end = time.time()
        for idx, data in enumerate(val_loader):
            if opt.dataset == 'waterbirds':
                embeddings, all_labels, _ = data
            else:
                embeddings, all_labels = data
            # opt.target is one of 'class', 'group', 'spurious'.
            labels = all_labels[opt.target]

            embeddings = embeddings.float().to(device)
            labels = labels.to(device)
            bsz = labels.shape[0]

            # forward: the loader yields extracted embeddings, so there is no
            # encoder pass before the classifier.
            output = classifier(embeddings)
            loss = criterion(output, labels)

            # update metric
            losses.update(loss.item(), bsz)
            acc1 = accuracy(output, labels, bsz)
            # NOTE(review): train() records `acc1` un-indexed; confirm what
            # util.accuracy returns and align the two.
            acc.update(acc1[0], bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {acc.val:.3f} ({acc.avg:.3f})'.format(
                       idx, len(val_loader), batch_time=batch_time,
                       loss=losses, acc=acc))

    print(' * Acc@1 {acc.avg:.3f}'.format(acc=acc))
    return losses.avg, acc.avg


def main(opt):
    """Train the linear classifier on cached embeddings and report accuracy.

    Loads the train/val/test embedding loaders, builds the classifier, loss and
    optimizer, then for each epoch adjusts the learning rate, trains on the
    training split and evaluates on the validation and test splits.

    Args:
        opt: options namespace; reads ``data_dir``, ``embedding_dir``,
            ``model``, ``batch_size`` and ``epochs`` here and is forwarded to
            the helpers.
    """
    best_acc = 0

    # build data loader (batch_size is passed twice, as in the original call)
    train_loader, val_loader, test_loader = load_embeddings(
        opt.data_dir, opt.embedding_dir, opt.model, opt.batch_size, opt.batch_size)

    # build model and criterion
    classifier, criterion = set_model(opt)

    # build optimizer
    optimizer = set_optimizer(opt, classifier)

    # training routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch -- on the training split; the val/test loaders
        # are only used for evaluation below.
        time1 = time.time()
        loss, acc = train(train_loader, classifier, criterion,
                          optimizer, epoch, opt)
        time2 = time.time()
        print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
            epoch, time2 - time1, acc))

        # eval for one epoch
        loss, val_acc = validate(val_loader, classifier, criterion, opt)
        if val_acc > best_acc:
            best_acc = val_acc

        # NOTE(review): best_acc is also raised by the test-split accuracy, so
        # the final number is the max over both splits and all epochs. If it is
        # meant for model selection, track validation and test separately.
        loss, val_acc = validate(test_loader, classifier, criterion, opt)
        if val_acc > best_acc:
            best_acc = val_acc

    print('best accuracy: {:.2f}'.format(best_acc))

In [10]:
import argparse

# Notebook-level copy of the training options. It differs from parse_option()
# only in two defaults: batch_size=512 and learning_rate=5.
parser = argparse.ArgumentParser('argument for training')

# logging / loop control
parser.add_argument('--print_freq', default=10, type=int, help='print frequency')
parser.add_argument('--save_freq', default=50, type=int, help='save frequency')
parser.add_argument('--batch_size', default=512, type=int, help='batch_size')
parser.add_argument('--num_workers', default=16, type=int, help='num of workers to use')
parser.add_argument('--epochs', default=100, type=int, help='number of training epochs')

# optimization
parser.add_argument('--learning_rate', default=5, type=float, help='learning rate')
parser.add_argument('--lr_decay_epochs', default='60,75,90', type=str,
                    help='where to decay lr, can be a list')
parser.add_argument('--lr_decay_rate', default=0.2, type=float,
                    help='decay rate for learning rate')
parser.add_argument('--weight_decay', default=0, type=float, help='weight decay')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')

# model / dataset
parser.add_argument('--model', default='resnet50', type=str)
parser.add_argument('--dataset', default='waterbirds', type=str,
                    choices=['celebA', 'waterbirds'], help='dataset')

# schedule switches
parser.add_argument('--cosine', action='store_true', help='using cosine annealing')
parser.add_argument('--warm', action='store_true',
                    help='warm-up for large batch training')

# inputs
parser.add_argument('--embedding_dir', type=str, help='extracted embedding')
parser.add_argument('--target', default="class", type=str,
                    choices=["class", "group", "spurious"])
parser.add_argument('--data_dir', type=str, help='metadata.csv')

# Parse an empty argument list: only the defaults above are used.
opt = parser.parse_args(args=[])

# Turn the comma-separated epoch string into a list of ints.
opt.lr_decay_epochs = [int(piece) for piece in opt.lr_decay_epochs.split(',')]

# Run name, extended with a suffix per enabled schedule switch.
opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.format(
    opt.dataset, opt.model, opt.learning_rate, opt.weight_decay, opt.batch_size)
if opt.cosine:
    opt.model_name = '{}_cosine'.format(opt.model_name)

# warm-up for large-batch training
if opt.warm:
    opt.model_name = '{}_warm'.format(opt.model_name)
    opt.warmup_from = 0.01
    opt.warm_epochs = 10
    if not opt.cosine:
        opt.warmup_to = opt.learning_rate
    else:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        cosine_factor = (1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * cosine_factor

# Both supported datasets use two classes.
if opt.dataset in ('celebA', 'waterbirds'):
    opt.n_cls = 2
else:
    raise ValueError('dataset not supported: {}'.format(opt.dataset))

In [11]:
# The parser declares --embedding_dir and --data_dir without defaults, so set
# them here; main() forwards both to load_embeddings().
opt.embedding_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings/waterbirds/"
opt.data_dir = "/home/jinsu/workstation/project/debiasing-multi-modal/data/waterbirds/"

In [12]:
# Start training with the options assembled in the cells above.
if __name__ == '__main__':    
    
    main(opt)
    

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/jinsu/anaconda3/envs/cuda_test/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/jinsu/anaconda3/envs/cuda_test/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/jinsu/anaconda3/envs/cuda_test/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings.py", line 61, in __getitem__
    z = self.embeddings_dict[img_filename.split('/')[-1]]
KeyError: 'Black_Footed_Albatross_0009_34.jpg'


In [13]:
# Debug aid for the KeyError raised in the DataLoader worker: show every stored
# embedding key that contains the failing image id, to compare its exact
# spelling with the 'Black_Footed_Albatross_0009_34.jpg' key that was looked up.
for key in dict_embeddings.keys():
    if '0009_34' in key:
        print(key)

In [None]:
def train(train_loader, classifier, criterion, optimizer, epoch, opt)