In [1]:


import argparse

import time

from copy import deepcopy

from PIL import Image
import numpy as np
from tqdm import tqdm
import pandas as pd
import os

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms


try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
import torchvision.models as models

import clip
from clip.custom_clip import get_coop
from clip.cocoop import get_cocoop
from data.imagnet_prompts import imagenet_classes
from data.datautils import AugMixAugmenter, build_dataset
from utils.tools import Summary, AverageMeter, ProgressMeter, load_model_weight, set_random_seed, create_logger
from data.cls_to_names import *
from data.fewshot_datasets import fewshot_datasets
from data.imagenet_variants import thousand_k_to_200, imagenet_a_mask, imagenet_r_mask, imagenet_v_mask
from clip_retrieval.clip_client import ClipClient, Modality
%load_ext autoreload
%autoreload 2

def _make_knn_client(num_images):
    """Build an image-modality ClipClient for the local LAION-400M kNN service.

    All clients in this notebook share the endpoint, index, modality and
    deduplication setting; only the number of requested neighbours differs.
    """
    return ClipClient(
        url="http://127.0.0.1:1234/knn-service",
        indice_name='laion_400m',
        modality=Modality.IMAGE,
        num_images=num_images,
        deduplicate=False,
    )


client = _make_knn_client(1000)
# NOTE(review): client_backup is not referenced by any code visible in this notebook.
client_backup = _make_knn_client(200)
# Fallback client queried when the primary query fails; configured identically to `client`.
client_backup2 = _make_knn_client(1000)

## Dataset names handled as few-shot datasets
# NOTE(review): this rebinds the `fewshot_datasets` name imported from data.fewshot_datasets.
fewshot_datasets = [
    'DTD', 'Flower102', 'Food101', 'Cars', 'SUN397',
    'Aircraft', 'Pets', 'Caltech101', 'UCF101', 'eurosat',
]
test_sets = 'Caltech101'  # "/"-separated list of dataset names to evaluate


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

def select_confident_samples(logits, top):
    """Keep the `top` fraction of rows of `logits` with the lowest softmax entropy.

    Entropy is computed per row along dim 1. Rows are ordered from most to
    least confident (ascending entropy), and the first int(n_rows * top) are kept.

    Returns:
        (selected rows of `logits`, their row indices)
    """
    row_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
    n_keep = int(row_entropy.size()[0] * top)
    ranking = torch.argsort(row_entropy, descending=False)
    keep = ranking[:n_keep]
    return logits[keep], keep

def avg_entropy(outputs):
    """Entropy-style score of the row-averaged log-softmax of `outputs`.

    Each row is converted to log-probabilities along the last dim. The rows are
    averaged (dim 0), and the averaged vector `a` is scored as -sum(a * exp(a)).

    NOTE(review): averaging log-probabilities gives the log of the *geometric*
    mean of the per-row distributions. That is not normalised when there is
    more than one row, so for R > 1 rows this is not the Shannon entropy of the
    averaged prediction. `log_probs.logsumexp(dim=0) - log(R)` would give the
    log of the arithmetic mean instead. Confirm which one is intended.

    Raises:
        AssertionError: if `outputs` is empty or any intermediate contains NaN.
    """
    assert len(outputs) > 0
    assert not torch.isnan(outputs).any()

    log_probs = outputs.log_softmax(dim=-1)
    assert not torch.isnan(log_probs).any()

    mean_log_probs = log_probs.mean(0)
    if torch.isnan(mean_log_probs).any():
        print("average logits ", outputs.log_softmax(dim=1).mean(0))
    assert not torch.isnan(mean_log_probs).any()

    # Replace -inf (an underflowed probability) by the most negative finite value,
    # so that exp(a) * a evaluates to 0 instead of NaN.
    floor = torch.finfo(mean_log_probs.dtype).min
    mean_log_probs = mean_log_probs.clamp(min=floor)
    assert not torch.isnan(mean_log_probs).any()

    return -(mean_log_probs * mean_log_probs.exp()).sum(dim=-1)

In [3]:

def accuracy(output, target, topk=(1,), caption=None, logger=None, args=None):
    """Top-k accuracy, in percent, of a single prediction against `target`.

    A single-row `output` is ranked directly. With several rows (one per
    retrieved caption in this notebook), the rows are first averaged into one
    score row. Either way, the top-`max(topk)` class indices are laid out as a
    (maxk, 1) column and compared with `target` broadcast to that shape. That
    broadcast only works when `target` holds exactly one label.

    Args:
        output: 2-D scores, rows x classes.
        target: label tensor with a single element.
        topk: the k values to report.
        caption, logger, args: unused; kept for call compatibility.

    Returns:
        A list with one 1-element float tensor per k: 100 * hits / target.size(0).
    """
    with torch.no_grad():
        k_max = max(topk)
        n_targets = target.size(0)

        scores = output if output.shape[0] == 1 else torch.mean(output, 0, keepdim=True)
        _, ranked = scores.topk(k_max, 1, True, True)
        ranked = ranked.reshape(k_max, 1)  # (1, k_max) -> column, best class first
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_targets)
            for k in topk
        ]

In [4]:
# ---- run parameters ---------------------------------------------------------
# NOTE(review): tta_steps and which_loss are not read by any code visible in this
# notebook (the test-time tuning call is commented out); kept for reference.
tta_steps: int = 1
which_loss: str = "entropy"
gpu: int = 4             # CUDA device index passed to torch.cuda.set_device / .cuda()
print_freq: int = 1000   # progress.display() is called every `print_freq` samples
retrieve_K: int = 4      # number of kNN-retrieved captions ensembled per image

In [5]:
def _record(save_result, key, value):
    """Append `value` to the list stored under `key`, creating the list if needed."""
    save_result.setdefault(key, []).append(value)


def _unpack_hits(hits):
    """Split clip-retrieval result dicts into parallel (captions, similarity scores) lists."""
    return [hit['caption'] for hit in hits], [hit['similarity'] for hit in hits]


def _retrieve_captions(img_path):
    """Return (captions, scores) for the top `retrieve_K` kNN neighbours of `img_path`.

    The primary `client` is tried first. If it raises, or returns fewer than
    `retrieve_K` hits, the query is repeated once with `client_backup2`. That
    result is returned as-is, so it may still hold fewer than `retrieve_K` items.
    Errors from the backup query propagate to the caller.
    """
    try:
        hits = client.query(image=img_path)[:retrieve_K]
        if len(hits) == retrieve_K:
            return _unpack_hits(hits)
    except Exception:
        # Deliberate best-effort: any failure of the primary query (network error,
        # malformed payload, missing keys) falls through to the backup client.
        # `Exception` (not a bare `except:`) so Ctrl-C still interrupts the run.
        pass
    return _unpack_hits(client_backup2.query(image=img_path)[:retrieve_K])


def _weighted_caption_probs(output, scores):
    """Similarity-weighted mean of per-caption softmax distributions, as a single row.

    Row j of `output` is softmaxed along its last dim and weighted by
    scores[j] / sum(scores). The weighted rows are summed and reshaped to (1, -1).
    """
    total = sum(scores)
    weighted = [score / total * torch.nn.functional.softmax(logit, dim=-1)
                for score, logit in zip(scores, output)]
    return torch.sum(torch.stack(weighted), axis=0).reshape(1, -1)


def test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, save_result=False):
    """Evaluate caption-ensemble classification using kNN-retrieved captions.

    For every sample of `val_loader`, the prompt and optimizer are reset. Then
    `retrieve_K` captions are retrieved for imagepath[0] (batch size 1 is
    assumed) and scored with `model.caption_ensemble`. The per-caption softmax
    distributions are averaged with weights proportional to their retrieval
    similarity. Samples with fewer than `retrieve_K` retrieved captions are
    skipped, and their count is printed at the end.

    Args:
        val_loader: yields (image, target, imagepath) batches.
        model: must provide eval(), reset() and caption_ensemble(list_of_str).
        model_state: unused; kept for call compatibility.
        optimizer: reset to `optim_state` before every sample.
        optim_state: optimizer state dict to restore.
        scaler: unused; kept for call compatibility.
        save_result: mapping that receives per-sample lists under the keys
            'caption_entropy', 'image_path', 'caption_correct', 'caption_logit'
            and 'caption_gap'. With None/False (the default), records are
            discarded instead of crashing. 'caption_logit' holds the top-1
            weighted *probability*. 'caption_gap' is top1 - top2 when correct,
            otherwise top1 - p(target).

    Returns:
        [Acc@1, Acc@2, Acc@3, Acc@4, Acc@5] averaged over evaluated (non-skipped) samples.
    """
    if save_result is None or save_result is False:
        save_result = {}
    assert gpu is not None

    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    top_meters = [AverageMeter('Acc@{}'.format(k), ':6.2f', Summary.AVERAGE) for k in range(1, 6)]
    progress = ProgressMeter(
        len(val_loader),
        [batch_time] + top_meters,
        prefix='Test: ',
        logger=None)

    # reset model and switch to evaluate mode
    model.eval()
    with torch.no_grad():
        model.reset()
    end = time.time()
    cnt_empty = 0

    # Wrap the loader (not the enumerate object) so tqdm knows the total length.
    for i, (image, target, imagepath) in enumerate(tqdm(val_loader)):
        target = target.cuda(gpu, non_blocking=True)

        # Restore the tunable prompt and the optimizer to their initial state.
        with torch.no_grad():
            model.reset()
        optimizer.load_state_dict(optim_state)

        img_path = imagepath[0]
        retrieved_txt, retrieved_score = _retrieve_captions(img_path)
        if len(retrieved_txt) != retrieve_K:
            cnt_empty += 1
            continue

        with torch.no_grad(), torch.cuda.amp.autocast():
            output = model.caption_ensemble(retrieved_txt)

        _record(save_result, 'caption_entropy', '{:.4f}'.format(avg_entropy(output)))
        _record(save_result, 'image_path', img_path)

        assert len(retrieved_score) == len(output), (len(retrieved_score), len(output))
        weighted = _weighted_caption_probs(output, retrieved_score)

        top_probs, top_idx = weighted.topk(2, 1, True, True)
        top_probs = top_probs.squeeze()
        correct = int(top_idx[:, 0].eq(target).float().sum().item())
        top1_prob = top_probs[0].item()

        _record(save_result, 'caption_correct', correct)
        _record(save_result, 'caption_logit', '{:.4f}'.format(top1_prob))
        if correct == 1:
            reference = top_probs[1].item()  # margin over the runner-up
        else:
            reference = weighted.squeeze()[target].item()  # margin over the true class
        _record(save_result, 'caption_gap', '{:.4f}'.format(top1_prob - reference))

        accs = accuracy(weighted, target, topk=(1, 2, 3, 4, 5), caption=None, logger=None)
        for meter, acc in zip(top_meters, accs):
            meter.update(acc[0], image.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            progress.display(i)

    progress.display_summary()
    if cnt_empty:
        print("=> skipped {} samples with fewer than {} retrieved captions".format(cnt_empty, retrieve_K))
    return [meter.avg for meter in top_meters]

In [6]:

# ---- model / prompt parameters ---------------------------------------------
arch: str = 'ViT-B/16'          # CLIP visual backbone passed to get_coop / reset_classnames
n_ctx: int = 4                  # number of prompt context tokens
ctx_init: str = "a_photo_of_a"  # initial prompt context words, "_"-separated
lr: float = 5e-3                # AdamW learning rate for the prompt learner

In [7]:
# load model
# "<dataset>_classes" lists are looked up by name in the notebook namespace
# (presumably provided by `from data.cls_to_names import *`). A plain namespace
# lookup does the same job as eval() without evaluating arbitrary code.
if test_sets in fewshot_datasets:
    classnames = globals()["{}_classes".format(test_sets.lower())]
model = get_coop(arch, test_sets, gpu, n_ctx, ctx_init)
model_state = None

# Freeze everything except the prompt learner, then record what stays trainable.
for name, param in model.named_parameters():
    if "prompt_learner" not in name:
        param.requires_grad_(False)
cross_check = {name for name, param in model.named_parameters() if param.requires_grad}
print("tuning parameters ", cross_check)

print("=> Model created: visual backbone {}".format(arch))

assert gpu is not None
torch.cuda.set_device(gpu)
model = model.cuda(gpu)

# Only the prompt-learner parameters are optimised. The optimizer's initial
# state is snapshotted so it can be restored later.
trainable_param = model.prompt_learner.parameters()
optimizer = torch.optim.AdamW(trainable_param, lr)
optim_state = deepcopy(optimizer.state_dict())

# setup automatic mixed-precision (Amp) loss scaling
scaler = torch.cuda.amp.GradScaler(init_scale=1000)

print('=> Using native Torch AMP. Training in mixed precision.')

cudnn.benchmark = True

Initializing the contect with given words: [a_photo_of_a]
Initial context: "a photo of a"
Number of context words (tokens): 4
tuing parameters  {'prompt_learner.ctx'}
=> Model created: visual backbone ViT-B/16
=> Using native Torch AMP. Training in mixed precision.


In [8]:

resolution = 224
workers = 4
dataset_mode = 'test'
data = '/data/seongha'

from collections import defaultdict

# norm stats from clip.load()
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
# The transform does not depend on the dataset, so build it once.
data_transform = transforms.Compose([
    transforms.Resize(resolution, interpolation=BICUBIC),
    transforms.CenterCrop(resolution),
    transforms.ToTensor(),
    normalize,
])
batchsize = 1  # test_time_adapt_eval only queries imagepath[0]
path = './notebook/caption_ensemble'


def _write_group_stats(f, group, label, gap_desc):
    """Write row-count, caption_gap mean/std and mean caption_entropy for one subset of rows."""
    gaps = group['caption_gap'].astype(float)
    f.write("{}\n".format(label))
    f.write(" {}\n".format(str(group.shape)))
    f.write("{}\n".format(gap_desc))
    f.write("{:.4f} {:.4f}\n".format(gaps.mean(), gaps.std()))
    f.write("Entropy mean {}\n".format(str(group['caption_entropy'].astype(float).mean())))


# iterating through eval datasets ("/"-separated names in test_sets)
datasets = test_sets.split("/")
results = {}

for set_id in datasets:
    Dict = defaultdict(list)  # per-sample records filled in by test_time_adapt_eval

    print("evaluating: {}".format(set_id))
    # Look the "<dataset>_classes" list up in the notebook namespace instead of via eval().
    classnames = globals()["{}_classes".format(set_id.lower())]
    model.reset_classnames(classnames, arch)

    val_dataset = build_dataset(set_id, data_transform, data, mode=dataset_mode)
    total_length = len(val_dataset)
    print("number of test samples: {}".format(total_length))

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batchsize, shuffle=False,
        num_workers=workers, pin_memory=True)

    results[set_id] = test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, Dict)
    del val_dataset, val_loader

    df = pd.DataFrame(Dict)
    df = df.reset_index()

    os.makedirs(path, exist_ok=True)
    df.to_csv(os.path.join(path, 'caption_ensemble_{}.csv'.format(set_id)))

    with open(os.path.join(path, 'caption_ensemble_{}.txt'.format(set_id)), 'w') as f:
        if 'caption_correct' not in df.columns:
            # Every sample was skipped (no captions retrieved): nothing to summarise.
            f.write("1. Caption Accuracy {:.4f}\n".format(0.0))
            continue

        cap_corr_ind = df.loc[df['caption_correct'] == 1, 'index'].to_list()
        # Skipped samples are absent from df but included in total_length,
        # so they count as misses in this accuracy.
        f.write("1. Caption Accuracy {:.4f}\n".format(len(cap_corr_ind) / total_length))
        # entropy, logit gap
        f.write("Entropy & Logit Gap\n")
        cap_correct = df.loc[df['caption_correct'] == 1]
        _write_group_stats(f, cap_correct, "correct", "top1 - top2 mean, std")
        f.write("\n")  # blank line between the two groups
        cap_wrong = df.loc[df['caption_correct'] == 0]
        _write_group_stats(f, cap_wrong, "wrong", "pred(top1) - target mean, std")
        


evaluating: Caltech101


number of test samples: 2465


1000it [09:46,  1.38it/s]

Test: [ 999/2465]	Time  0.309 ( 0.587)	Acc@1   0.00 ( 90.90)	Acc@2 100.00 ( 94.20)	Acc@3 100.00 ( 95.30)	Acc@4 100.00 ( 95.70)	Acc@5 100.00 ( 96.20)


2000it [19:36,  2.25it/s]

Test: [1999/2465]	Time  0.237 ( 0.589)	Acc@1 100.00 ( 81.85)	Acc@2 100.00 ( 89.25)	Acc@3 100.00 ( 91.80)	Acc@4 100.00 ( 92.50)	Acc@5 100.00 ( 93.35)


2465it [24:11,  1.70it/s]

 *  Acc@1 81.744 Acc@2 88.803 Acc@3 91.034 Acc@4 91.684 Acc@5 92.414



