In [1]:

import argparse

import time

from copy import deepcopy

from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import pandas as pd
import os

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
import torchvision.models as models

import clip
from clip.custom_clip import get_coop
from clip.cocoop import get_cocoop
from data.imagnet_prompts import imagenet_classes
from data.datautils import AugMixAugmenter, build_dataset
from utils.tools import Summary, AverageMeter, ProgressMeter, load_model_weight, set_random_seed, create_logger
from data.cls_to_names import *
from data.fewshot_datasets import fewshot_datasets
from data.imagenet_variants import thousand_k_to_200, imagenet_a_mask, imagenet_r_mask, imagenet_v_mask
from clip_retrieval.clip_client import ClipClient, Modality

from collections import defaultdict
%load_ext autoreload
%autoreload 2

client = ClipClient(
    url="http://127.0.0.1:1234/knn-service",
    indice_name='laion_400m',
    modality=Modality.IMAGE,
    num_images=10,
    deduplicate=False,
)
client_backup = ClipClient(
    url="http://127.0.0.1:1234/knn-service",
    indice_name='laion_400m',
    modality=Modality.IMAGE,
    num_images=200,
    deduplicate=False,
)

client_backup2 = ClipClient(
    url="http://127.0.0.1:1234/knn-service",
    indice_name='laion_400m',
    modality=Modality.IMAGE,
    num_images=1000,
    deduplicate=False,
)

## Class to names mapping
fewshot_datasets = ['DTD', 'Flower102', 'Food101', 'Cars', 'SUN397', 
                    'Aircraft', 'Pets', 'Caltech101', 'UCF101', 'eurosat']
test_sets = 'Caltech101'#Parameters
tta_steps = 1
which_loss = "cosine"
gpu = 4
print_freq = 200
arch='ViT-B/16'
n_ctx=4
ctx_init="a_photo_of_a"
lr = 5e-3
global retrieve_K
retrieve_K= 32

  from .autonotebook import tqdm as notebook_tqdm


In [2]:


def select_confident_samples(logits, top):
    batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
    idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * top)]
    return logits[idx], idx

def avg_entropy(outputs):
    # epsilon = 1e-10
    assert len(outputs) > 0
    assert torch.any(torch.isnan(outputs)) == False
    # logits = outputs - outputs.logsumexp(dim=-1, keepdim=True) # logits = outputs.log_softmax(dim=1) [N, 1000]
    logits = outputs.log_softmax(dim=-1) #[N, 1000]
    assert torch.any(torch.isnan(logits)) == False
    # avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0]) # avg_logits = logits.mean(0) [1, 1000]
    avg_logits = logits.mean(0) #[1, 1000]
    # print(avg_logits)
    if torch.any(torch.isnan(avg_logits)):
        print("average logits ", outputs.log_softmax(dim=1).mean(0))
    assert torch.any(torch.isnan(avg_logits)) == False
    
    min_real = torch.finfo(avg_logits.dtype).min
    avg_logits = torch.clamp(avg_logits, min=min_real)
    assert torch.any(torch.isnan(avg_logits)) == False
    return -((avg_logits) * (torch.exp(avg_logits))).sum(dim=-1)


# logit = torch.tensor([12.0, 3, 3, 3.0], requires_grad=True)
# logit2 = torch.tensor([3, 3, 3, 12.0], requires_grad=True)
# GT = torch.tensor([3, 3, 3, 9.0], requires_grad=True)
# for each in [logit, logit2]:
#     print("logit ", each)
#     print("target dist ", GT)
#     print(avg_entropy(each))
#     # print(avg_entropy(torch.nn.functional.softmax(each, dim=-1)))
    
#     input = F.log_softmax(each, dim=-1)
#     # input = F.log_softmax(torch.nn.functional.softmax(each, dim=-1), dim=-1)
#     target = F.softmax(GT, dim=-1)
#     one = 0.5 *kl_loss(input, target)
#     print(input)
#     print("kl loss: ", kl_loss(input, target))

#     input = F.log_softmax(GT, dim=-1)
#     # input = F.log_softmax(torch.nn.functional.softmax(GT, dim=-1), dim=-1)
#     target = F.softmax(each, dim=-1)
#     two = 0.5 *kl_loss(input, target)
#     print("kl loss: ", kl_loss(input, target))
#     print("JS loss : {}\n".format( one + two))
    
def JSdiv(logit, logit2):
    assert logit.shape == logit2.shape, (logit.shape, logit2.shape)
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
    input = F.log_softmax(logit, dim=-1)
    target = F.softmax(logit2, dim=-1)
    
    one = 0.5 *kl_loss(input, target)
    
    input = F.log_softmax(logit2, dim=-1)
    target = F.softmax(logit, dim=-1)
    two = 0.5 *kl_loss(input, target)
    return one+two

def KLreliable(logit, logit2):
    assert logit.shape == logit2.shape, (logit.shape, logit2.shape)
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
    if avg_entropy(logit) + 0.2 < avg_entropy(logit2) :
        # first is more reliable
        # print("image is more reliable")
        input = F.log_softmax(logit2, dim=-1)
        target = F.softmax(logit, dim=-1)
        return kl_loss(input, target)
    elif avg_entropy(logit2) + 0.2 < avg_entropy(logit) :
        # print("caption is more reliable")
        input = F.log_softmax(logit, dim=-1)
        target = F.softmax(logit2, dim=-1)
        return kl_loss(input, target)
    else:
        return None

In [3]:

def accuracy(output, target, topk=(1,), caption=None, logger=None):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        if output.shape[0] == 1:#only image prediction
            logit_k, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))
        else: # evaluate captions
            pred = torch.mean(output, 0, keepdim=True)
            _, pred = pred.topk(maxk, 1, True, True)
            pred = pred.reshape(maxk, 1)
            correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [4]:

def return_caption(img_path, retrieve_K=1):
    try:
        query_res = client.query(image=img_path)[:retrieve_K]
        retrieved_txt= [D['caption'] for D in query_res]
        retrieved_url = [D['url'] for D in query_res]
        retrieved_score = [D['similarity'] for D in query_res]
        assert len(query_res) == retrieve_K
    except:
        query_res = client_backup2.query(image=img_path)[:retrieve_K]
        retrieved_txt= [D['caption'] for D in query_res]
        retrieved_url = [D['url'] for D in query_res]
        retrieved_score = [D['similarity'] for D in query_res]
    if len(retrieved_txt) == retrieve_K:
        return retrieved_txt, retrieved_score
    else:
        return None, None
                    

In [5]:

def test_time_tuning(model, inputs, optimizer, scaler, imagepath = None):
    # Entropy + Triplet loss function * 0.5
    # Triplet loss function, anchor = retrieved vocab, positive = top5, negative = bottom5
    for j in range(tta_steps):
        with torch.cuda.amp.autocast():
            output_img, text_features = model(inputs) # bs, n_cls, (1, 1000), logit/ n_cls, 512
            logit_scale = model.logit_scale.exp()
            # ent = avg_entropy(output_img)
            #caption
            retrieved_Caption, retrieved_score = return_caption(imagepath, retrieve_K=retrieve_K)
            
            if retrieved_Caption==None:
                return None
            output_caption = model.caption_ensemble(retrieved_Caption)
            # weighted = []
            # for score, logit in zip(retrieved_score, output_caption):
            #     logit = torch.nn.functional.softmax(logit, dim=-1)
            #     weighted.append(score/sum(retrieved_score) * logit)

            # tmp = torch.sum(torch.cat(weighted[:retrieve_K]).reshape(retrieve_K, -1), axis=0).reshape(1, -1)
            loss = []
            for each in output_caption:
                # loss.append(JSdiv(output_img.squeeze(), each.squeeze()))
                L = KLreliable(output_img.squeeze(), each.squeeze())
                if L: loss.append(L)
            if len(loss) == 0: continue
            # loss /= retrieve_K #average JSdiv
            if retrieve_K >= 32:
                loss.sort()
                cut = int(len(loss) * 0.1)
                loss_val = torch.mean(torch.stack(loss[:cut]))
            else:
                loss_val = torch.mean(torch.stack(loss))
            optimizer.zero_grad() 
            scaler.scale(loss_val).backward()
            scaler.step(optimizer)
            scaler.update()

    # print("empty caption {}".format(cnt_empty))
    return retrieved_Caption

In [6]:
def test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, save_result=False):
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top2 = AverageMeter('Acc@2', ':6.2f', Summary.AVERAGE)
    top3 = AverageMeter('Acc@3', ':6.2f', Summary.AVERAGE)
    top4 = AverageMeter('Acc@4', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1, top2, top3, top4, top5],
        prefix='Test: ',
        logger = None)

    # reset model and switch to evaluate mode
    model.eval()
    with torch.no_grad():
        model.reset()
    end = time.time()
    cnt_empty = 0
    assert save_result != None
    
    for i, (image, target, imagepath) in tqdm(enumerate(val_loader)): 
        assert gpu is not None
        # print("Image Path ", imagepath[0])
        
        target = target.cuda(gpu, non_blocking=True)
        ### One time training

        with torch.no_grad():
            model.reset()
        optimizer.load_state_dict(optim_state)
        # TTA
        retrieved_Caption = test_time_tuning(model, image.cuda(gpu, non_blocking=True), optimizer, scaler, imagepath[0])
        if retrieved_Caption==None:
            cnt_empty +=1
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                #image
                output_img = model.inference(image.cuda(gpu, non_blocking=True))
                save_result['image_path'] = imagepath[0]
                # save_result['caption'].append(retrieved_Caption[0])
                # output_caption = model.caption_ensemble(retrieved_Caption)

        # output_caption = output_caption / output_caption.norm(dim=-1, keepdim=True)
        ent_img = avg_entropy(output_img)
        # ent_caption = avg_entropy(output_caption)
        save_result['image_entropy'].append('{:.4f}'.format(ent_img))
        # save_result['caption_entropy'].append('{:.4f}'.format(ent_caption))
        
        #TODO pick smaller entropy
        # if ent_img < ent_caption:
        prob = torch.nn.functional.softmax(output_img, dim=-1)
        # else:
            # prob = torch.nn.functional.softmax(output_caption, dim=-1)
        logit_k, pred = prob.topk(2, 1, True, True) #1,2
        logit_k = logit_k.squeeze()
        pred = pred[:,0].t()
        correct = pred.eq(target)
        correct = correct.reshape(-1).float().sum(0, keepdim=True).item() #1 or 0
        save_result['correct'].append(int(correct))
        save_result['logit'].append('{:.4f}'.format(logit_k[0].item()))
        if correct == 1:
            # correct label - top2
            # print(logit_k)
            save_result['logit_gap'].append('{:.4f}'.format(logit_k[0].item() - logit_k[1].item()))
        else:
            # incorrect top1 - real label
            # print(target, output.squeeze()[target], logit_k )
            save_result['logit_gap'].append('{:.4f}'.format(logit_k[0].item() - prob.squeeze()[target].item()))
        
         
        acc1, acc2, acc3, acc4, acc5 = accuracy(prob, target, topk=(1, 2, 3, 4, 5), caption=None, logger=None)
        # print(acc1, acc2, acc3, acc4, acc5)
        top1.update(acc1[0], image.size(0))
        top2.update(acc2[0], image.size(0))
        top3.update(acc3[0], image.size(0))
        top4.update(acc4[0], image.size(0))
        top5.update(acc5[0], image.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i+1) % print_freq == 0:
            progress.display(i)
    
    print("empty caption count = {}".format(cnt_empty))
    progress.display_summary()
    return [top1.avg, top2.avg, top3.avg, top4.avg, top5.avg], save_result

In [7]:
# load model
if test_sets in fewshot_datasets:
    classnames = eval("{}_classes".format(test_sets.lower()))
model = get_coop(arch, test_sets, gpu, n_ctx, ctx_init)
model_state = None

cross_check = set()
for name, param in model.named_parameters():

    if "prompt_learner" not in name:
        param.requires_grad_(False)
    if param.requires_grad : cross_check.add(name)
print("tuing parameters ", cross_check)

print("=> Model created: visual backbone {}".format(arch))

assert gpu is not None
torch.cuda.set_device(gpu)
model = model.cuda(gpu)

trainable_param = model.prompt_learner.parameters()
optimizer = torch.optim.AdamW(trainable_param, lr)
optim_state = deepcopy(optimizer.state_dict())

# setup automatic mixed-precision (Amp) loss scaling
scaler = torch.cuda.amp.GradScaler(init_scale=1000)

print('=> Using native Torch AMP. Training in mixed precision.')

cudnn.benchmark = True

resolution = 224
workers = 4
dataset_mode = 'test'
data = '/data/seongha'
import sys
from collections import defaultdict
    # norm stats from clip.load()
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    
    # iterating through eval datasets
datasets = test_sets.split("/")
results = {}

# with open('inferece_image_caption{}.txt'.format(test_sets), 'w') as f:
    # sys.stdout = f
for set_id in datasets:
    print("retrieve K: {}".format(retrieve_K))
    Dict = defaultdict(list)
    data_transform = transforms.Compose([
        transforms.Resize(resolution, interpolation=BICUBIC),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        normalize,
    ])
    batchsize = 1
    print("evaluating: {}".format(set_id))
    classnames = eval("{}_classes".format(set_id.lower()))
    model.reset_classnames(classnames, arch)

    val_dataset = build_dataset(set_id, data_transform, data, mode=dataset_mode)
    total_length = len(val_dataset)
    print("number of test samples: {}".format(len(val_dataset)))

    val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=batchsize, shuffle=False,
                num_workers=workers, pin_memory=True)
        
    results['image'], tmp = test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, Dict)
    Dict = tmp
        # assert len(Dict['image_path']) == len(Dict['caption']) and len(Dict['image_correct']) == len(Dict['caption_correct']), [len(v) for k, v in Dict.items()]
del val_dataset, val_loader

        

Initializing the contect with given words: [a_photo_of_a]
Initial context: "a photo of a"
Number of context words (tokens): 4
tuing parameters  {'prompt_learner.ctx'}
=> Model created: visual backbone ViT-B/16
=> Using native Torch AMP. Training in mixed precision.
retrieve K: 32
evaluating: Caltech101
number of test samples: 2465


200it [14:02,  4.69s/it]

Test: [ 199/2465]	Time  6.726 ( 4.215)	Acc@1 100.00 ( 95.00)	Acc@2 100.00 ( 98.50)	Acc@3 100.00 ( 99.00)	Acc@4 100.00 ( 99.50)	Acc@5 100.00 (100.00)


379it [27:52,  5.31s/it]