In [11]:

import argparse

import time

from copy import deepcopy

from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms


try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
import torchvision.models as models

import clip
from clip.custom_clip import get_coop
from clip.cocoop import get_cocoop
from data.imagnet_prompts import imagenet_classes
from data.datautils import AugMixAugmenter, build_dataset
from utils.tools import Summary, AverageMeter, ProgressMeter, load_model_weight, set_random_seed, create_logger
from data.cls_to_names import *
from data.fewshot_datasets import fewshot_datasets
from data.imagenet_variants import thousand_k_to_200, imagenet_a_mask, imagenet_r_mask, imagenet_v_mask
from clip_retrieval.clip_client import ClipClient, Modality

from collections import defaultdict
%load_ext autoreload
%autoreload 2

client = ClipClient(
    url="http://127.0.0.1:1234/knn-service",
    indice_name='laion_400m',
    modality=Modality.IMAGE,
    num_images=10,
    deduplicate=False,
)
client_backup = ClipClient(
    url="http://127.0.0.1:1234/knn-service",
    indice_name='laion_400m',
    modality=Modality.IMAGE,
    num_images=200,
    deduplicate=False,
)

client_backup2 = ClipClient(
    url="http://127.0.0.1:1234/knn-service",
    indice_name='laion_400m',
    modality=Modality.IMAGE,
    num_images=1000,
    deduplicate=False,
)

## Class to names mapping
fewshot_datasets = ['DTD', 'Flower102', 'Food101', 'Cars', 'SUN397', 
                    'Aircraft', 'Pets', 'Caltech101', 'UCF101', 'eurosat']
test_sets = 'Caltech101'
if test_sets == 'Caltech101':
    cls2names = ['face', 'leopard', 'motorbike', 'accordion', 'airplane', 'anchor', 'ant', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', 'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi', 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang']
elif test_sets == 'DTD':
    cls2names = dtd_classes
elif test_sets =='Cars':
    cls2names = cars_classes
    # print(cls2names[:5])

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:

def select_confident_samples(logits, top):
    batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
    idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * top)]
    return logits[idx], idx

def avg_entropy(outputs):
    # epsilon = 1e-10
    assert len(outputs) > 0
    assert torch.any(torch.isnan(outputs)) == False
    logits = outputs - outputs.logsumexp(dim=-1, keepdim=True) # logits = outputs.log_softmax(dim=1) [N, 1000]
    assert torch.any(torch.isnan(logits)) == False
    avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0]) # avg_logits = logits.mean(0) [1, 1000]
    # print(avg_logits)
    if torch.any(torch.isnan(avg_logits)):
        print("average logits ", outputs.log_softmax(dim=1).mean(0))
    assert torch.any(torch.isnan(avg_logits)) == False
    
    min_real = torch.finfo(avg_logits.dtype).min
    avg_logits = torch.clamp(avg_logits, min=min_real)
    assert torch.any(torch.isnan(avg_logits)) == False
    return -((avg_logits) * (torch.exp(avg_logits))).sum(dim=-1)

In [13]:

def accuracy(output, target, topk=(1,), caption=None, logger=None):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        if output.shape[0] == 1:#only image prediction
            logit_k, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))
        else: # evaluate captions
            bag = []
            # # length = max(5, output.shape[0]-1)
            # cap_pred = output[1:]
            # # cap_pred = torch.mean(cap_pred, 0,  keepdim=True)
            # _, pred = cap_pred.topk(maxk, 1, True, True) #5, 1 #candidate labels
            # pred = pred.reshape(maxk, 1)
            pred = torch.mean(output, 0, keepdim=True)
            _, pred = pred.topk(maxk, 1, True, True)
            pred = pred.reshape(maxk, 1)
            correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            if k == 1 and correct_k.item() == 0:

                pred = pred.squeeze().tolist()
                pred = [cls2names[lb] for lb in pred]

                if logger: logger.info("wrong prediction, target {} & predicted value {}".format(target, pred))
            elif k==1 and correct_k.item() == 1:
                # logger.info("wrong prediction , logit: ", output)
                pred = pred.squeeze().tolist()
                pred = [cls2names[lb] for lb in pred]

            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [14]:
#Parameters
tta_steps = 1
which_loss = "cosine"
gpu = 3
print_freq = 200

In [15]:
def return_caption(img_path, retrieve_K=1):
    try:
        query_res = client.query(image=img_path)[:retrieve_K]
        retrieved_txt= [D['caption'] for D in query_res]
        retrieved_url = [D['url'] for D in query_res]
        retrieved_score = [D['similarity'] for D in query_res]
    except:
        # print(client.query(image=img_path))
        query_res = client_backup.query(image=img_path)[:retrieve_K]
        retrieved_txt= [D['caption'] for D in query_res]
        retrieved_url = [D['url'] for D in query_res]
        retrieved_score = [D['similarity'] for D in query_res]
    if len(retrieved_txt) != retrieve_K:
        query_res = client_backup2.query(image=img_path)[:retrieve_K]
        retrieved_txt= [D['caption'] for D in query_res]
        retrieved_url = [D['url'] for D in query_res]
        retrieved_score = [D['similarity'] for D in query_res]
        # output = model.caption_ensemble(retrieved_txt)
        
    return retrieved_txt
                    

In [16]:
def test_time_adapt_eval_image(val_loader, model, model_state, optimizer, optim_state, scaler, save_result=None):
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top2 = AverageMeter('Acc@2', ':6.2f', Summary.AVERAGE)
    top3 = AverageMeter('Acc@3', ':6.2f', Summary.AVERAGE)
    top4 = AverageMeter('Acc@4', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1, top2, top3, top4, top5],
        prefix='Test: ',
        logger = None)

    # reset model and switch to evaluate mode
    model.eval()
    with torch.no_grad():
        model.reset()
    end = time.time()
    cnt_empty = 0
    if save_result == None: save_result = defaultdict(list)
    
    for i, (image, target, imagepath) in tqdm(enumerate(val_loader)): 
        assert gpu is not None
        # print("Image Path ", imagepath[0])
        save_result['image_path'].append(imagepath[0])
        
        target = target.cuda(gpu, non_blocking=True)
        
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                #image
                output, output_cap = model.inference(image.cuda(gpu, non_blocking=True), caption= None)
                
        logit_k, pred = output.topk(1, 1, True, True) #1,1
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        correct = correct.reshape(-1).float().sum(0, keepdim=True).item() #1 or 0
        save_result['image_correct'].append(int(correct))
                
        acc1, acc2, acc3, acc4, acc5 = accuracy(output, target, topk=(1, 2, 3, 4, 5), caption=None, logger=None)
        # print(acc1, acc2, acc3, acc4, acc5)
        top1.update(acc1[0], image.size(0))
        top2.update(acc2[0], image.size(0))
        top3.update(acc3[0], image.size(0))
        top4.update(acc4[0], image.size(0))
        top5.update(acc5[0], image.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i+1) % print_freq == 0:
            progress.display(i)
        

    progress.display_summary()
    return [top1.avg, top2.avg, top3.avg, top4.avg, top5.avg], save_result

def test_time_adapt_eval_caption(val_loader, model, model_state, optimizer, optim_state, scaler, save_result=False):
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top2 = AverageMeter('Acc@2', ':6.2f', Summary.AVERAGE)
    top3 = AverageMeter('Acc@3', ':6.2f', Summary.AVERAGE)
    top4 = AverageMeter('Acc@4', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1, top2, top3, top4, top5],
        prefix='Test: ',
        logger = None)

    # reset model and switch to evaluate mode
    model.eval()
    with torch.no_grad():
        model.reset()
    end = time.time()
    cnt_empty = 0
    assert save_result != None
    
    for i, (image, target, imagepath) in tqdm(enumerate(val_loader)): 
        assert gpu is not None
        # print("Image Path ", imagepath[0])
        # save_result['image_path'] = imagepath[0]
        
        target = target.cuda(gpu, non_blocking=True)
        
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                #caption
                retrieved_Caption = return_caption(imagepath[0])
                save_result['caption'].append(retrieved_Caption[0])
                output_caption = model.caption_ensemble(retrieved_Caption)

        logit_k, pred = output_caption.topk(1, 1, True, True) #1,1
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        correct = correct.reshape(-1).float().sum(0, keepdim=True).item() #1 or 0
        save_result['caption_correct'].append(int(correct))
         
        acc1, acc2, acc3, acc4, acc5 = accuracy(output_caption, target, topk=(1, 2, 3, 4, 5), caption=None, logger=None)
        # print(acc1, acc2, acc3, acc4, acc5)
        top1.update(acc1[0], image.size(0))
        top2.update(acc2[0], image.size(0))
        top3.update(acc3[0], image.size(0))
        top4.update(acc4[0], image.size(0))
        top5.update(acc5[0], image.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i+1) % print_freq == 0:
            progress.display(i)
    

    progress.display_summary()
    return [top1.avg, top2.avg, top3.avg, top4.avg, top5.avg], save_result

In [17]:
##parameters
fewshot_datasets = ['DTD', 'Flower102', 'Food101', 'Cars', 'SUN397', 
                    'Aircraft', 'Pets', 'Caltech101', 'UCF101', 'eurosat']
arch='ViT-B/16'
n_ctx=4
ctx_init="a_photo_of_a"
lr = 5e-3

In [18]:
# load model
if test_sets in fewshot_datasets:
    classnames = eval("{}_classes".format(test_sets.lower()))
model = get_coop(arch, test_sets, gpu, n_ctx, ctx_init)
model_state = None

cross_check = set()
for name, param in model.named_parameters():

    if "prompt_learner" not in name:
        param.requires_grad_(False)
    if param.requires_grad : cross_check.add(name)
print("tuing parameters ", cross_check)

print("=> Model created: visual backbone {}".format(arch))

assert gpu is not None
torch.cuda.set_device(gpu)
model = model.cuda(gpu)

trainable_param = model.prompt_learner.parameters()
optimizer = torch.optim.AdamW(trainable_param, lr)
optim_state = deepcopy(optimizer.state_dict())

# setup automatic mixed-precision (Amp) loss scaling
scaler = torch.cuda.amp.GradScaler(init_scale=1000)

print('=> Using native Torch AMP. Training in mixed precision.')

cudnn.benchmark = True

Initializing the contect with given words: [a_photo_of_a]
Initial context: "a photo of a"
Number of context words (tokens): 4
tuing parameters  {'prompt_learner.ctx'}
=> Model created: visual backbone ViT-B/16
=> Using native Torch AMP. Training in mixed precision.


In [19]:
resolution = 224
workers = 4
dataset_mode = 'test'
data = '/data/seongha'
import sys
from collections import defaultdict
    # norm stats from clip.load()
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    
    # iterating through eval datasets
datasets = test_sets.split("/")
results = {}

# with open('inferece_image_caption{}.txt'.format(test_sets), 'w') as f:
    # sys.stdout = f
Dict = defaultdict(list)
for set_id in datasets:

    data_transform = transforms.Compose([
        transforms.Resize(resolution, interpolation=BICUBIC),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        normalize,
    ])
    batchsize = 1
    print("evaluating: {}".format(set_id))
    classnames = eval("{}_classes".format(set_id.lower()))
    model.reset_classnames(classnames, arch)

    val_dataset = build_dataset(set_id, data_transform, data, mode=dataset_mode)
    print("number of test samples: {}".format(len(val_dataset)))

    val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=batchsize, shuffle=False,
                num_workers=workers, pin_memory=True)
        
    results['image'], tmp = test_time_adapt_eval_image(val_loader, model, model_state, optimizer, optim_state, scaler, Dict)
    Dict = tmp
    results['caption'], tmp = test_time_adapt_eval_caption(val_loader, model, model_state, optimizer, optim_state, scaler, Dict)
    Dict = tmp
    assert len(Dict['image_path']) == len(Dict['caption']) and len(Dict['image_correct']) == len(Dict['caption_correct']), [len(v) for k, v in Dict.items()]
    del val_dataset, val_loader

    # try:
    #     print("=> Acc. on testset [{}]: @1 {}/ @2 {}/ @3 {}/ @4 {}/ @5 {}".format(set_id, results[set_id][0], results[set_id][1], results['image'][2], results['image'][3], results['image'[4], results['image'][5]]))
    # except:
    #     print("=> Acc. on testset [{}]: {}".format(set_id, results[set_id]))
# sys.stdout = sys.__stdout__

evaluating: Caltech101
number of test samples: 2465


206it [00:06, 32.33it/s]

Test: [ 199/2465]	Time  0.031 ( 0.032)	Acc@1 100.00 ( 95.00)	Acc@2 100.00 ( 98.50)	Acc@3 100.00 ( 99.00)	Acc@4 100.00 ( 99.50)	Acc@5 100.00 (100.00)


406it [00:12, 32.34it/s]

Test: [ 399/2465]	Time  0.030 ( 0.032)	Acc@1 100.00 ( 97.50)	Acc@2 100.00 ( 99.25)	Acc@3 100.00 ( 99.50)	Acc@4 100.00 ( 99.75)	Acc@5 100.00 (100.00)


606it [00:18, 32.29it/s]

Test: [ 599/2465]	Time  0.030 ( 0.031)	Acc@1 100.00 ( 98.00)	Acc@2 100.00 ( 99.33)	Acc@3 100.00 ( 99.50)	Acc@4 100.00 ( 99.83)	Acc@5 100.00 (100.00)


806it [00:24, 31.63it/s]

Test: [ 799/2465]	Time  0.035 ( 0.031)	Acc@1 100.00 ( 96.25)	Acc@2 100.00 ( 98.50)	Acc@3 100.00 ( 98.75)	Acc@4 100.00 ( 99.25)	Acc@5 100.00 ( 99.75)


1006it [00:31, 31.82it/s]

Test: [ 999/2465]	Time  0.030 ( 0.031)	Acc@1   0.00 ( 95.70)	Acc@2 100.00 ( 98.60)	Acc@3 100.00 ( 98.90)	Acc@4 100.00 ( 99.40)	Acc@5 100.00 ( 99.80)


1206it [00:37, 31.73it/s]

Test: [1199/2465]	Time  0.032 ( 0.031)	Acc@1 100.00 ( 94.00)	Acc@2 100.00 ( 98.67)	Acc@3 100.00 ( 99.00)	Acc@4 100.00 ( 99.42)	Acc@5 100.00 ( 99.83)


1406it [00:43, 31.73it/s]

Test: [1399/2465]	Time  0.030 ( 0.031)	Acc@1 100.00 ( 93.00)	Acc@2 100.00 ( 98.71)	Acc@3 100.00 ( 99.14)	Acc@4 100.00 ( 99.50)	Acc@5 100.00 ( 99.86)


1606it [00:50, 31.67it/s]

Test: [1599/2465]	Time  0.030 ( 0.031)	Acc@1 100.00 ( 93.81)	Acc@2 100.00 ( 98.81)	Acc@3 100.00 ( 99.19)	Acc@4 100.00 ( 99.56)	Acc@5 100.00 ( 99.88)


1806it [00:56, 31.68it/s]

Test: [1799/2465]	Time  0.030 ( 0.031)	Acc@1 100.00 ( 92.00)	Acc@2 100.00 ( 98.44)	Acc@3 100.00 ( 99.11)	Acc@4 100.00 ( 99.44)	Acc@5 100.00 ( 99.78)


2006it [01:02, 31.69it/s]

Test: [1999/2465]	Time  0.031 ( 0.031)	Acc@1 100.00 ( 92.65)	Acc@2 100.00 ( 98.55)	Acc@3 100.00 ( 99.15)	Acc@4 100.00 ( 99.45)	Acc@5 100.00 ( 99.75)


2206it [01:09, 31.35it/s]

Test: [2199/2465]	Time  0.031 ( 0.031)	Acc@1 100.00 ( 92.50)	Acc@2 100.00 ( 98.45)	Acc@3 100.00 ( 99.18)	Acc@4 100.00 ( 99.45)	Acc@5 100.00 ( 99.73)


2406it [01:15, 31.69it/s]

Test: [2399/2465]	Time  0.030 ( 0.031)	Acc@1 100.00 ( 93.04)	Acc@2 100.00 ( 98.58)	Acc@3 100.00 ( 99.25)	Acc@4 100.00 ( 99.50)	Acc@5 100.00 ( 99.75)


2465it [01:17, 31.87it/s]

 *  Acc@1 92.941 Acc@2 98.499 Acc@3 99.270 Acc@4 99.513 Acc@5 99.757



200it [00:49,  4.01it/s]

Test: [ 199/2465]	Time  0.241 ( 0.246)	Acc@1 100.00 ( 37.50)	Acc@2 100.00 ( 50.50)	Acc@3 100.00 ( 73.50)	Acc@4 100.00 ( 75.50)	Acc@5 100.00 ( 76.50)


400it [01:36,  4.14it/s]

Test: [ 399/2465]	Time  0.226 ( 0.242)	Acc@1 100.00 ( 53.00)	Acc@2 100.00 ( 63.75)	Acc@3 100.00 ( 75.75)	Acc@4 100.00 ( 76.75)	Acc@5 100.00 ( 77.25)


600it [02:24,  4.29it/s]

Test: [ 599/2465]	Time  0.232 ( 0.241)	Acc@1 100.00 ( 53.50)	Acc@2 100.00 ( 62.67)	Acc@3 100.00 ( 71.83)	Acc@4 100.00 ( 74.50)	Acc@5 100.00 ( 75.00)


800it [03:11,  4.05it/s]

Test: [ 799/2465]	Time  0.217 ( 0.240)	Acc@1   0.00 ( 46.88)	Acc@2   0.00 ( 54.88)	Acc@3   0.00 ( 62.50)	Acc@4   0.00 ( 66.25)	Acc@5   0.00 ( 67.75)


1000it [03:59,  4.10it/s]

Test: [ 999/2465]	Time  0.261 ( 0.240)	Acc@1 100.00 ( 46.10)	Acc@2 100.00 ( 54.30)	Acc@3 100.00 ( 61.20)	Acc@4 100.00 ( 64.70)	Acc@5 100.00 ( 66.90)


1200it [04:47,  4.14it/s]

Test: [1199/2465]	Time  0.246 ( 0.240)	Acc@1 100.00 ( 41.67)	Acc@2 100.00 ( 50.00)	Acc@3 100.00 ( 56.75)	Acc@4 100.00 ( 59.92)	Acc@5 100.00 ( 62.33)


1400it [05:34,  4.22it/s]

Test: [1399/2465]	Time  0.236 ( 0.239)	Acc@1   0.00 ( 38.00)	Acc@2   0.00 ( 46.07)	Acc@3   0.00 ( 52.93)	Acc@4   0.00 ( 55.86)	Acc@5   0.00 ( 58.14)


1600it [06:22,  4.16it/s]

Test: [1599/2465]	Time  0.251 ( 0.239)	Acc@1   0.00 ( 37.25)	Acc@2   0.00 ( 44.56)	Acc@3   0.00 ( 51.25)	Acc@4   0.00 ( 54.12)	Acc@5   0.00 ( 56.19)


1800it [07:09,  4.23it/s]

Test: [1799/2465]	Time  0.211 ( 0.239)	Acc@1   0.00 ( 35.00)	Acc@2   0.00 ( 42.11)	Acc@3   0.00 ( 48.78)	Acc@4   0.00 ( 51.94)	Acc@5   0.00 ( 54.11)


2000it [07:57,  4.09it/s]

Test: [1999/2465]	Time  0.276 ( 0.239)	Acc@1 100.00 ( 33.10)	Acc@2 100.00 ( 40.05)	Acc@3 100.00 ( 46.25)	Acc@4 100.00 ( 49.15)	Acc@5 100.00 ( 51.20)


2200it [08:44,  4.49it/s]

Test: [2199/2465]	Time  0.208 ( 0.239)	Acc@1   0.00 ( 31.36)	Acc@2   0.00 ( 37.91)	Acc@3   0.00 ( 43.95)	Acc@4   0.00 ( 47.00)	Acc@5   0.00 ( 49.00)


2400it [09:32,  4.10it/s]

Test: [2399/2465]	Time  0.264 ( 0.239)	Acc@1   0.00 ( 32.88)	Acc@2   0.00 ( 39.08)	Acc@3 100.00 ( 44.83)	Acc@4 100.00 ( 47.67)	Acc@5 100.00 ( 49.62)


2465it [09:48,  4.19it/s]

 *  Acc@1 32.414 Acc@2 38.621 Acc@3 44.300 Acc@4 47.099 Acc@5 49.087





KeyError: 'Caltech101'

In [22]:
import pandas as pd

df = pd.DataFrame(Dict)
df = df.reset_index()
print(df.columns)
print(df.head(2))

img_corr_ind = df.loc[df['image_correct'] == 1, 'index'].to_list()
cap_corr_ind = df.loc[df['caption_correct'] == 1, 'index'].to_list()

Index(['index', 'image_path', 'image_correct', 'caption', 'caption_correct'], dtype='object')
   index                                         image_path  image_correct  \
0      0  /data/seongha/caltech-101/101_ObjectCategories...              1   
1      1  /data/seongha/caltech-101/101_ObjectCategories...              1   

                                             caption  caption_correct  
0                   Female hand on white background.                0  
1  Closeup portrait of adorable baby Royalty Free...                1  


In [26]:
print(len(img_corr_ind), len(cap_corr_ind))
print(len(set(img_corr_ind) - set(cap_corr_ind)))
print(len(set(cap_corr_ind) - set(img_corr_ind)))

union = set(img_corr_ind) | set(cap_corr_ind)
intersection = set(img_corr_ind) & set(cap_corr_ind)
img_diff = set(img_corr_ind) - set(cap_corr_ind)
cap_diff = set(cap_corr_ind) - set(img_corr_ind)

2291 801
1511
21


In [31]:
df.iloc[list(union)] 
print("Max accuracy: {:.4f}".format(len(union)/2465*100)) #image accuracy = 92.941

Max accuracy: 93.7931


In [None]:
for each, name in zip([union, intersection, img_diff, cap_diff],['union', 'intersection', 'img_diff', 'cap_diff']):
    df.iloc[list(each)].to_csv('inference_correct_by_{}.csv'.format(name))