In [11]:

import argparse

import time

from copy import deepcopy

from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import pandas as pd
import os

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
import torchvision.models as models

import clip
from clip.custom_clip import get_coop
from clip.cocoop import get_cocoop
from data.imagnet_prompts import imagenet_classes
from data.datautils import AugMixAugmenter, build_dataset
from utils.tools import Summary, AverageMeter, ProgressMeter, load_model_weight, set_random_seed, create_logger
from data.cls_to_names import *
from data.fewshot_datasets import fewshot_datasets
from data.imagenet_variants import thousand_k_to_200, imagenet_a_mask, imagenet_r_mask, imagenet_v_mask
from clip_retrieval.clip_client import ClipClient, Modality

from collections import defaultdict
%load_ext autoreload
%autoreload 2

client = ClipClient(
    url="http://127.0.0.1:1234/knn-service",
    indice_name='laion_400m',
    modality=Modality.IMAGE,
    num_images=20,
    deduplicate=False,
)
client_backup = ClipClient(
    url="http://127.0.0.1:1234/knn-service",
    indice_name='laion_400m',
    modality=Modality.IMAGE,
    num_images=200,
    deduplicate=False,
)

client_backup2 = ClipClient(
    url="http://127.0.0.1:1234/knn-service",
    indice_name='laion_400m',
    modality=Modality.IMAGE,
    num_images=500,
    deduplicate=False,
)

## Class to names mapping
fewshot_datasets = ['DTD', 'Flower102', 'Food101', 'Cars', 'SUN397', 
                    'Aircraft', 'Pets', 'Caltech101', 'UCF101', 'eurosat']
test_sets = 'DTD/Flower102/Food101/Cars/SUN397/Aircraft/Pets/Caltech101/UCF101/eurosat'#Parameters
tta_steps = 1
which_loss = "cosine"
gpu = 7
print_freq = 100
arch='ViT-B/16'
n_ctx=4
ctx_init="a_photo_of_a"
lr = 5e-3
global retrieve_K
retrieve_K= 1

tau_dict = {'DTD': {'i_tau': 0.0, 'c_tau': 0.08},
 'Flower102': {'i_tau': 0.46, 'c_tau': 0.0},
 'Food101': {'i_tau': 0.0, 'c_tau': 0.99},
 'Cars': {'i_tau': 0.0, 'c_tau': 0.99},
 'SUN397': {'i_tau': 0.0, 'c_tau': 0.72},
 'Aircraft': {'i_tau': 0.0, 'c_tau': 0.37},
 'Pets': {'i_tau': 1.47, 'c_tau': 0.0},
 'Caltech101': {'i_tau': 0.21, 'c_tau': 0.0},
 'UCF101': {'i_tau': 0.0, 'c_tau': 0.68},
 'eurosat': {'i_tau': 0.0, 'c_tau': 0.54}}

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
import torch
import torch.nn.functional as F

def avg_entropy(outputs):
    logits = outputs - outputs.logsumexp(dim=-1, keepdim=True) # logits = outputs.log_softmax(dim=1) [N, 1000]
    # logits = outputs.log_softmax(dim=-1) #[N, 1000]
    avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0]) # avg_logits = logits.mean(0) [1, 1000]
    # avg_logits = logits.mean(0) #[1, 1000]
    # print(avg_logits
    
    min_real = torch.finfo(avg_logits.dtype).min
    avg_logits = torch.clamp(avg_logits, min=min_real)
    return -((avg_logits) * (torch.exp(avg_logits))).sum(dim=-1)

In [13]:

kl_loss = torch.nn.KLDivLoss(reduction="batchmean")

def select_confident_samples(logits, top):
    batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
    idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * top)]
    return logits[idx], idx

def Cosine(img_logit, cap_logit, model):
    # (retrieveK, C)
    cosine = torch.nn.CosineEmbeddingLoss(reduce=True)
    y_hat = F.softmax(torch.sigmoid(model.alpha)* img_logit + (1-torch.sigmoid(model.alpha)) * cap_logit, dim=-1)
    # print(y_hat.shape, img_logit.shape)
    return cosine(y_hat, y_hat, torch.ones(img_logit.shape[0]).cuda(gpu))
    
def JSdiv(logit, logit2, model):
    # assert logit.shape == logit2.shape, (logit.shape, logit2.shape)
    class_num = logit.size()[-1]
    
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
    prob1 = F.softmax(logit, dim=-1)
    prob2 = F.softmax(logit2, dim=-1)
    total_m = 0.5 *prob1 + 0.5*prob2
    loss = 0.0
    loss += torch.sigmoid(model.alpha)* kl_loss(F.log_softmax(logit, dim=-1), total_m)
    loss += (1-torch.sigmoid(model.alpha)) * kl_loss(F.log_softmax(logit2, dim=-1), total_m)
    return loss
def JeffreyDiv(logit, logit2, model):
    class_num = logit.size()[-1]
    uni_ent = avg_entropy(torch.ones(class_num))
    uni_ent.requires_grad_(True)
    # print(uni_ent.requires_grad)
    a = torch.sigmoid(model.c_tau)
    alpha = (1 - ((avg_entropy(logit2) * a)/uni_ent))
    input = F.log_softmax(logit, dim=-1)
    input.requires_grad_(True)
    target = F.softmax(logit2, dim=-1)
    #(1-entropy(p)/entropy(Unif(C)))
    input_ = F.log_softmax(logit2, dim=-1)
    target_ = F.softmax(logit, dim=-1)
    b = torch.sigmoid(model.i_tau)
    beta = (1 - ((avg_entropy(logit) * b)/uni_ent))
    # total = alpha+beta #weight sum 1
    one = alpha * kl_loss(input, target.detach()) #image를 caption에 맞춤
    two = beta * kl_loss(input_, target_.detach())#caption을 Image에 맞춤
    # assert alpha.requires_grad and beta.requires_grad and input.requires_grad and input_.requires_grad and logit.requires_grad and alpha.requires_grad
    # print(alpha.requires_grad, beta.requires_grad, input.requires_grad, input_.requires_grad, logit.requires_grad, alpha.requires_grad)
    return one + two

def KLreliable(logit, logit2):
    assert logit.shape == logit2.shape, (logit.shape, logit2.shape)
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
    if average_entropy(logit) < average_entropy(logit2):
        # first is more reliable
        # print("image is more reliable")
        input = F.log_softmax(logit2, dim=-1)
        target = F.softmax(logit, dim=-1)
        return kl_loss(input, target)
    else:
        # print("caption is more reliable")
        input = F.log_softmax(logit, dim=-1)
        target = F.softmax(logit2, dim=-1)
        return kl_loss(input, target)

In [14]:

def accuracy(output, target, topk=(1,), caption=None, logger=None):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        pred = output.argmax(dim=-1)
        assert pred.shape == target.shape, (pred.shape, target.shape)
        correct = pred.eq(target).sum()
        return correct


In [15]:

def return_caption(img_path, retrieve_K=1):
    try:
        query_res = client.query(image=img_path)
        assert len(query_res) >= retrieve_K
        query_res = query_res[:retrieve_K]
        retrieved_txt= [D['caption'] for D in query_res]
        retrieved_url = [D['url'] for D in query_res]
        retrieved_score = [D['similarity'] for D in query_res]
        return retrieved_txt, retrieved_score
    except:
        query_res = client_backup2.query(image=img_path)
        if len(query_res) >= retrieve_K:
            query_res = query_res[:retrieve_K]
            retrieved_txt= [D['caption'] for D in query_res]
            retrieved_url = [D['url'] for D in query_res]
            retrieved_score = [D['similarity'] for D in query_res]
            return retrieved_txt, retrieved_score
        else:
            return None, None    

In [16]:
test1= torch.tensor([2.0, 2, 2, 2, 3, 4])
def confidence(x):
    uniform_ent = avg_entropy(torch.ones(x.shape[-1]))
    logit_ent = avg_entropy(x)
    logit_ent1 = avg_entropy(x/x.norm(dim=-1))
    print(logit_ent, logit_ent1, uniform_ent)
    return logit_ent

print(confidence(test1))

tensor(0.2986) tensor(0.2986) tensor(0.2986)
tensor(0.2986)


In [17]:

def test_time_tuning(model, inputs, optimizer, scaler, imagepath = None):
    # Entropy + Triplet loss function * 0.5
    # Triplet loss function, anchor = retrieved vocab, positive = top5, negative = bottom5
    for j in range(tta_steps):
        with torch.cuda.amp.autocast():
            output_img, text_features = model(inputs) # bs, n_cls, (1, 1000), logit/ n_cls, 512
            # logit_scale = model.logit_scale.exp()
            # ent = avg_entropy(output_img)
            #caption
            retrieved_Caption, retrieved_score = return_caption(imagepath, retrieve_K=retrieve_K)
            
            if retrieved_Caption==None:
                return None
            # print(retrieved_Caption)
            output_caption = model.caption_ensemble(retrieved_Caption)
            
            # print(output_img.norm(dim=-1), output_caption.norm(dim=-1))
            # weighted = []
            # for score, logit in zip(retrieved_score, output_caption):
            #     logit = torch.nn.functional.softmax(logit, dim=-1)
            #     weighted.append(score/sum(retrieved_score) * logit)

            # tmp = torch.sum(torch.cat(weighted[:retrieve_K]).reshape(retrieve_K, -1), axis=0).reshape(1, -1)
            # loss = []
            # loss_val = JeffreyDiv(output_img/output_img.norm(dim=-1, keepdim=True), output_caption/output_caption.norm(dim=-1, keepdim=True), model)
            loss_val = JeffreyDiv(output_img, output_caption, model)
            # print(loss_val)
            # loss_val += Cosine(output_img, output_caption, model)

            optimizer.zero_grad() 
            scaler.scale(loss_val).backward()
            scaler.step(optimizer)
            scaler.update()

    # print("empty caption {}".format(cnt_empty))
    return retrieved_Caption, output_caption

In [18]:


def test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, save_result=False, set_id=''):
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top2 = AverageMeter('Acc@2', ':6.2f', Summary.AVERAGE)
    top3 = AverageMeter('Acc@3', ':6.2f', Summary.AVERAGE)
    top4 = AverageMeter('Acc@4', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1, top2, top3, top4, top5],
        prefix='Test: ',
        logger = None)

    # reset model and switch to evaluate mode
    model.eval()
    with torch.no_grad():
        model.reset()
    end = time.time()
    cnt_empty = 0
    assert save_result != None
    ## to measure accuracy
    total_images = 0
    correct_images = 0
    
    ## image average entropy mean
    mean_stat = torch.tensor([], dtype=torch.float32)
    # cmean_stat = dict()
    
    ## caption average entropy mean
    cap_mean_stat = torch.tensor([], dtype=torch.float32)
    # cap_cmean_stat = 0
    
    # count accuracy of when using caption!
    cnt_cap = 0
    cnt_cap_correct = 0
    for i, (images, target, imagepath) in tqdm(enumerate(val_loader)): 
        # print(len(image))
        assert gpu is not None
        # print("Image Path ", imagepath[0])
        if isinstance(images, list):
            for k in range(len(images)):
                images[k] = images[k].cuda(gpu, non_blocking=True)
            image = images[0]
            images = torch.cat(images, dim=0)
        else:
            image = images.cuda(gpu, non_blocking=True)
        target = target.cuda(gpu, non_blocking=True)
        ### One time training
        
        with torch.no_grad():
            model.reset()
        optimizer.load_state_dict(optim_state)
        
        assert not torch.isinf(model.i_tau).any(), model.i_tau
        assert not torch.isnan(model.c_tau).any(), model.c_tau
        # TTA
        retrieved_Caption, caption_logit = test_time_tuning(model, image, optimizer, scaler, imagepath[0])
        # if retrieved_Caption==None:
            # cnt_empty +=1
        # print(model.gamma_cap, model.gamma_img)
        assert model.i_tau.is_leaf == True
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                #image
                output_img = model.inference(image.cuda(gpu, non_blocking=True))
                #caption
                # retrieved_Caption, retrieved_score = return_caption(imagepath[0], retrieve_K=retrieve_K)
                # if retrieved_Caption == None:
                #     cnt_empty +=1
                    # continue
                # caption_logit = model.caption_ensemble(retrieved_Caption)
        #TODO: scaled normalization for each modality
        conf_img = avg_entropy(output_img)
        conf_cap = avg_entropy(caption_logit)
        # print(conf_img.item(), conf_cap.item())
        if conf_img > torch.mean(mean_stat) * (1+model.i_tau) and conf_cap < torch.mean(cap_mean_stat) * (1-model.c_tau):
            # 이미자 못 맞추는걸 캡션으로 해보자
            # print(caption_logit.shape)
            correct_ = accuracy(caption_logit, target, topk=(1, 2, 3, 4, 5), caption=None, logger=None).item()
            cnt_cap += 1
            if correct_ : cnt_cap_correct +=1
        else:
            correct_ = accuracy(output_img , target, topk=(1, 2, 3, 4, 5), caption=None, logger=None).item()
        #update mean
        mean_stat = torch.cat([mean_stat, torch.tensor([conf_img],dtype=torch.float32)])
        cap_mean_stat = torch.cat([cap_mean_stat, torch.tensor([conf_cap], dtype=torch.float32)])
        total_images += 1
        correct_images += correct_


        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if (i+1) % print_freq == 0:
            # progress.display(i)
            
            print("accuracy  ", correct_images/total_images)
            print("image confidence score mean stat", torch.mean(mean_stat).item())
            print("caption confidence score mean stat", torch.mean(cap_mean_stat).item())
            print("count caption correct {} out of {}".format(cnt_cap_correct, cnt_cap))
            print("c tau {} i tau {}".format(model.c_tau, model.i_tau))
            save_result['accuracy'].append(correct_images/total_images)
            save_result['image_conf_mean'].append(torch.mean(mean_stat).item())
            save_result['caption_conf_mean'].append(torch.mean(cap_mean_stat).item())
            save_result['cap_cnt'].append(cnt_cap)
            save_result['cap_corr'].append(cnt_cap_correct)
            # save_result['c_tau'].append(c_tau)
            # save_result['i_tau'].append(i_tau)
            df = pd.DataFrame(save_result)
            df = df.reset_index()
            path = './notebook/JSdiv/{}'.format(arch.replace('/', ''))
            os.makedirs(path, exist_ok=True)
            df.to_csv(os.path.join(path, 'JSdiv_{}.csv'.format(set_id)))
    
    print("empty caption count = {}".format(cnt_empty))
    # progress.display_summary()
    print("Accuracy: {}".format(correct_images/total_images) )
    return save_result

In [19]:
# load model
if test_sets in fewshot_datasets:
    classnames = eval("{}_classes".format(test_sets.lower()))
model = get_coop(arch, test_sets, gpu, n_ctx, ctx_init)

model_state = None

dtype  torch.float32
Initializing the contect with given words: [a_photo_of_a]
Initial context: "a photo of a"
Number of context words (tokens): 4


In [20]:


cudnn.benchmark = True

resolution = 224
workers = 4
dataset_mode = 'test'
data = '/data/seongha'
import sys
from collections import defaultdict
    # norm stats from clip.load()
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    
    # iterating through eval datasets
datasets = test_sets.split("/")
results = {}

# with open('inferece_image_caption{}.txt'.format(test_sets), 'w') as f:
    # sys.stdout = f
for set_id in datasets:

    i_tau = tau_dict[set_id]['i_tau']
    c_tau = tau_dict[set_id]['c_tau']

    model.set_tau(i_tau=i_tau, c_tau = c_tau)
    assert gpu is not None
    torch.cuda.set_device(gpu)
    model = model.cuda(gpu)

    trainable_param = [
        # {'params' : model.prompt_learner.parameters()},
                        {'params' : model.c_tau, 'lr': 1e-3},
                        {'params' : model.i_tau, 'lr': 1e-3},
                        ]
                    
    # trainable_param = model.parameters()
    optimizer = torch.optim.AdamW(trainable_param, lr)
    optim_state = deepcopy(optimizer.state_dict())

    cross_check = set()
    for name, param in model.named_parameters():
        if name not in "i_tau" and name not in 'c_tau': param.requires_grad = False
        if param.requires_grad : cross_check.add(name)
    print("tuning parameters ", cross_check)
    print(optim_state)
    # setup automatic mixed-precision (Amp) loss scaling
    scaler = torch.cuda.amp.GradScaler()
    # scaler = torch.cuda.amp.GradScaler()

    print('=> Using native Torch AMP. Training in mixed precision.')
    for retrieve_K in [1]:
        print("retrieve K: {}".format(retrieve_K))
        Dict = defaultdict(list)
        data_transform = transforms.Compose([
            transforms.Resize(resolution, interpolation=BICUBIC),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            normalize,
        ])
        # base_transform = transforms.Compose([
        #     transforms.Resize(224, interpolation=BICUBIC),
        #     transforms.CenterCrop(224)])
        # preprocess = transforms.Compose([
        #     transforms.ToTensor(),
        #     normalize])
        # data_transform = AugMixAugmenter(base_transform, preprocess, n_views=retrieve_K-1, 
        #                                     augmix=len(set_id)>1)
        batchsize = 1
        print("evaluating: {}".format(set_id))
        classnames = eval("{}_classes".format(set_id.lower()))
        model.reset_classnames(classnames, arch)
        with torch.no_grad():
            model.reset()

            # model.alpha = torch.nn.Parameter(torch.tensor(0.0))
            # model.i_tau = torch.nn.Parameter(torch.tensor(tau_dict[set_id]['i_tau']))
            # model.c_tau = torch.nn.Parameter(torch.tensor(tau_dict[set_id]['c_tau']))
        optimizer.load_state_dict(optim_state)
        val_dataset = build_dataset(set_id, data_transform, data, mode=dataset_mode)
        total_length = len(val_dataset)

        assert next(model.parameters()).is_cuda and  model.i_tau.is_cuda, model.i_tau.is_cuda
        
        cross_check = set()
        for name, param in model.named_parameters():
            if param.requires_grad : cross_check.add(name)
        print("tuing parameters ", cross_check)
        print("number of test samples: {}".format(len(val_dataset)))

        val_loader = torch.utils.data.DataLoader(
                    val_dataset,
                    batch_size=batchsize, shuffle=True,
                    num_workers=workers, pin_memory=True)
            
        results = test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, Dict, set_id)
            # assert len(Dict['image_path']) == len(Dict['caption']) and len(Dict['image_correct']) == len(Dict['caption_correct']), [len(v) for k, v in Dict.items()]
        df = pd.DataFrame(results)
        df = df.reset_index()

        path = './notebook/JSdiv'
        os.makedirs(path, exist_ok=True)
        df.to_csv(os.path.join(path, 'JSdiv_{}.csv'.format(set_id)))
del val_dataset, val_loader

        

tuning parameters  {'i_tau', 'c_tau'}
{'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01, 'amsgrad': False, 'foreach': None, 'maximize': False, 'capturable': False, 'params': [0]}, {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01, 'amsgrad': False, 'foreach': None, 'maximize': False, 'capturable': False, 'params': [1]}]}
=> Using native Torch AMP. Training in mixed precision.
retrieve K: 1
evaluating: DTD
tuing parameters  {'i_tau', 'c_tau'}
number of test samples: 1692


100it [00:39,  2.64it/s]

accuracy   0.36
image confidence score mean stat 1.7302119731903076
caption confidence score mean stat 1.2301819324493408
count caption correct 34 out of 97
c tau Parameter containing:
tensor(-2.4413, device='cuda:7', requires_grad=True) i tau Parameter containing:
tensor(-16.1169, device='cuda:7', requires_grad=True)


200it [01:17,  2.70it/s]

accuracy   0.325
image confidence score mean stat 1.7156692743301392
caption confidence score mean stat 1.1788921356201172
count caption correct 63 out of 197
c tau Parameter containing:
tensor(-2.4413, device='cuda:7', requires_grad=True) i tau Parameter containing:
tensor(-16.1171, device='cuda:7', requires_grad=True)


300it [01:56,  2.55it/s]

accuracy   0.29
image confidence score mean stat 1.6457427740097046
caption confidence score mean stat 1.2010986804962158
count caption correct 85 out of 297
c tau Parameter containing:
tensor(-2.4413, device='cuda:7', requires_grad=True) i tau Parameter containing:
tensor(-16.1169, device='cuda:7', requires_grad=True)


400it [02:35,  2.47it/s]

accuracy   0.2675
image confidence score mean stat 1.6830614805221558
caption confidence score mean stat 1.2281428575515747
count caption correct 105 out of 397
c tau Parameter containing:
tensor(-2.4413, device='cuda:7', requires_grad=True) i tau Parameter containing:
tensor(-16.1169, device='cuda:7', requires_grad=True)


500it [03:14,  2.47it/s]

accuracy   0.27
image confidence score mean stat 1.696144938468933
caption confidence score mean stat 1.2303662300109863
count caption correct 133 out of 497
c tau Parameter containing:
tensor(-2.4413, device='cuda:7', requires_grad=True) i tau Parameter containing:
tensor(-16.1169, device='cuda:7', requires_grad=True)


600it [03:53,  2.69it/s]

accuracy   0.26166666666666666
image confidence score mean stat 1.6995246410369873
caption confidence score mean stat 1.2460707426071167
count caption correct 155 out of 597
c tau Parameter containing:
tensor(-2.4413, device='cuda:7', requires_grad=True) i tau Parameter containing:
tensor(-16.1169, device='cuda:7', requires_grad=True)


660it [04:16,  2.71it/s]

662it [04:17,  2.56it/s]