In [1]:

import argparse

import time

from copy import deepcopy

from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import pandas as pd
import os

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
import torchvision.models as models

import clip
from clip.custom_clip import get_coop
from clip.cocoop import get_cocoop
from data.imagnet_prompts import imagenet_classes
from data.datautils import AugMixAugmenter, build_dataset
from utils.tools import Summary, AverageMeter, ProgressMeter, load_model_weight, set_random_seed, create_logger
from data.cls_to_names import *
from data.fewshot_datasets import fewshot_datasets
from data.imagenet_variants import thousand_k_to_200, imagenet_a_mask, imagenet_r_mask, imagenet_v_mask
from clip_retrieval.clip_client import ClipClient, Modality

from collections import defaultdict
%load_ext autoreload
%autoreload 2

client = ClipClient(
    url="http://127.0.0.1:1234/knn-service",
    indice_name='laion_400m',
    modality=Modality.IMAGE,
    num_images=20,
    deduplicate=False,
)
client_backup = ClipClient(
    url="http://127.0.0.1:1234/knn-service",
    indice_name='laion_400m',
    modality=Modality.IMAGE,
    num_images=200,
    deduplicate=False,
)

client_backup2 = ClipClient(
    url="http://127.0.0.1:1234/knn-service",
    indice_name='laion_400m',
    modality=Modality.IMAGE,
    num_images=500,
    deduplicate=False,
)

## Class to names mapping
fewshot_datasets = ['DTD', 'Flower102', 'Food101', 'Cars', 'SUN397', 
                    'Aircraft', 'Pets', 'Caltech101', 'UCF101', 'eurosat']
test_sets = 'DTD/Food101/Cars/SUN397/Aircraft/Pets/Caltech101/UCF101/eurosat'#Parameters
tta_steps = 1
which_loss = "cosine"
gpu = 7
print_freq = 100
arch='ViT-B/16'
n_ctx=4
ctx_init="a_photo_of_a"
lr = 5e-3
global retrieve_K
retrieve_K= 1

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import torch.nn.functional as F

def average_entropy(input_tensor):
    # Compute the softmax over the second dimension (N classes)
    probs = F.softmax(input_tensor, dim=-1)
    
    # Compute the entropy
    entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)
    
    # Compute and return the average entropy
    return torch.mean(entropy)

In [3]:


def select_confident_samples(logits, top):
    batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
    idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * top)]
    return logits[idx], idx

def Cosine(img_logit, cap_logit1, cap_logit2, model):
    return 1- F.cosine_similarity(F.softmax(model.alpha * img_logit + (1-model.alpha) * cap_logit1, dim=-1) , F.softmax(model.alpha * img_logit + (1-model.alpha) * cap_logit2, dim=-1) )
    
def JSdiv(logit, logit2, model):
    # assert logit.shape == logit2.shape, (logit.shape, logit2.shape)
    class_num = logit.size()[-1]
    
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
    prob1 = F.softmax(logit, dim=-1)
    prob2 = F.softmax(logit2, dim=-1)
    total_m = torch.sigmoid(model.alpha) * prob1 + (1-torch.sigmoid(model.alpha)) *prob2
    loss = 0.0
    loss += torch.sigmoid(model.alpha)* kl_loss(F.log_softmax(logit, dim=-1), total_m)
    loss += (1-torch.sigmoid(model.alpha)) * kl_loss(F.log_softmax(logit2, dim=-1), total_m)
    return loss
def JeffreyDiv(logit, logit2):

    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
    class_num = logit.size()[-1]
    uni_ent = average_entropy(torch.ones(class_num))
    alpha = (1 - (average_entropy(logit2)/uni_ent))
    input = F.log_softmax(logit, dim=-1)
    target = F.softmax(logit2, dim=-1)
    #(1-entropy(p)/entropy(Unif(C)))
    input_ = F.log_softmax(logit2, dim=-1)
    target_ = F.softmax(logit, dim=-1)
    beta = (1 - (average_entropy(logit)/uni_ent))
    # total = alpha+beta #weight sum 1
    one = alpha * kl_loss(input, target.detach()) #image를 caption에 맞춤
    two = beta * kl_loss(input_, target_.detach())#caption을 Image에 맞춤

    # print(alpha, beta)
    # print(one, two)
    # print(alpha/cap_max, beta/img_max)
    return one + two

def KLreliable(logit, logit2):
    assert logit.shape == logit2.shape, (logit.shape, logit2.shape)
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
    if average_entropy(logit) < average_entropy(logit2) + 0.2:
        # first is more reliable
        # print("image is more reliable")
        input = F.log_softmax(logit2, dim=-1)
        target = F.softmax(logit, dim=-1)
        return kl_loss(input, target)
    else:
        # print("caption is more reliable")
        input = F.log_softmax(logit, dim=-1)
        target = F.softmax(logit2, dim=-1)
        return kl_loss(input, target)

In [4]:

def accuracy(output, target, topk=(1,), caption=None, logger=None):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        
        maxk = max(topk)
        batch_size = target.size(0)
        pred = output.argmax(dim=1)
        assert pred.shape == target.shape, (pred.shape, target.shape)
        correct = pred.eq(target).sum()
        # if output.shape[0] == 1:#only image prediction
        #     logit_k, pred = output.topk(maxk, 1, True, True)
        #     pred = pred.t()
        #     correct = pred.eq(target.view(1, -1).expand_as(pred))
        # else: # evaluate captions
        #     pred = torch.mean(output, 0, keepdim=True)
        #     _, pred = pred.topk(maxk, 1, True, True)
        #     pred = pred.reshape(maxk, 1)
        #     correct = pred.eq(target.view(1, -1).expand_as(pred))
        return correct
        # res = []
        # for k in topk:
        #     correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        #     res.append(correct_k.mul_(100.0 / batch_size))
        # return res

In [5]:

def return_caption(img_path, retrieve_K=1):
    try:
        query_res = client.query(image=img_path)
        assert len(query_res) >= retrieve_K
        query_res = query_res[:retrieve_K]
        retrieved_txt= [D['caption'] for D in query_res]
        retrieved_url = [D['url'] for D in query_res]
        retrieved_score = [D['similarity'] for D in query_res]
        return retrieved_txt, retrieved_score
    except:
        query_res = client_backup2.query(image=img_path)
        if len(query_res) >= retrieve_K:
            query_res = query_res[:retrieve_K]
            retrieved_txt= [D['caption'] for D in query_res]
            retrieved_url = [D['url'] for D in query_res]
            retrieved_score = [D['similarity'] for D in query_res]
            return retrieved_txt, retrieved_score
        else:
            return None, None


                    

In [6]:

def test_time_tuning(model, inputs, optimizer, scaler, imagepath = None):
    # Entropy + Triplet loss function * 0.5
    # Triplet loss function, anchor = retrieved vocab, positive = top5, negative = bottom5
    for j in range(tta_steps):
        with torch.cuda.amp.autocast():
            output_img, text_features = model(inputs) # bs, n_cls, (1, 1000), logit/ n_cls, 512
            # logit_scale = model.logit_scale.exp()
            # ent = avg_entropy(output_img)
            #caption
            retrieved_Caption, retrieved_score = return_caption(imagepath, retrieve_K=retrieve_K)
            
            if retrieved_Caption==None:
                return None
            # print(retrieved_Caption)
            output_caption = model.caption_ensemble(retrieved_Caption)
            # weighted = []
            # for score, logit in zip(retrieved_score, output_caption):
            #     logit = torch.nn.functional.softmax(logit, dim=-1)
            #     weighted.append(score/sum(retrieved_score) * logit)

            # tmp = torch.sum(torch.cat(weighted[:retrieve_K]).reshape(retrieve_K, -1), axis=0).reshape(1, -1)
            loss = []
            # loss_val = JeffreyDiv(output_img.squeeze(), output_caption, model)
            loss_val = JSdiv(output_img, output_caption, model)
            # # print(loss_val)
            # for each in output_caption:
            # #     loss.append(JeffreyDiv(output_img.squeeze(), each.squeeze()))
            #     loss.append(JSdiv(output_img.squeeze(), each.squeeze(), model))
            # loss_val = torch.mean(torch.stack(loss))
            #     # loss /= retrieve_K #average JSdiv
            # loss_val += Cosine(output_img, output_caption[0], output_caption[1], model).squeeze()
            # print("loss ", loss)
            # if retrieve_K >= 64:
            #     loss.sort()
            #     cut = int(len(loss) * 0.1)
            #     loss_val = torch.mean(torch.stack(loss[:cut]))
            # else:
            #     loss_val = torch.mean(torch.stack(loss))
            optimizer.zero_grad() 
            scaler.scale(loss_val).backward()
            scaler.step(optimizer)
            scaler.update()

    # print("empty caption {}".format(cnt_empty))
    return retrieved_Caption, output_caption

In [7]:
img_set = []
cap_set = []
def test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, save_result=False):
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top2 = AverageMeter('Acc@2', ':6.2f', Summary.AVERAGE)
    top3 = AverageMeter('Acc@3', ':6.2f', Summary.AVERAGE)
    top4 = AverageMeter('Acc@4', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1, top2, top3, top4, top5],
        prefix='Test: ',
        logger = None)

    # reset model and switch to evaluate mode
    # model.eval()
    with torch.no_grad():
        model.reset()
    end = time.time()
    cnt_empty = 0
    assert save_result != None
    total_images = 0
    correct_images = 0
    for i, (images, target, imagepath) in tqdm(enumerate(val_loader)): 
        # print(len(image))
        assert gpu is not None
        # print("Image Path ", imagepath[0])
        if isinstance(images, list):
            for k in range(len(images)):
                images[k] = images[k].cuda(gpu, non_blocking=True)
            image = images[0]
        target = target.cuda(gpu, non_blocking=True)
        ### One time training
        images = torch.cat(images, dim=0)
        with torch.no_grad():
            model.reset()
        # optimizer.load_state_dict(optim_state)
        # TTA
        retrieved_Caption, caption_logit = test_time_tuning(model, image, optimizer, scaler, imagepath[0])
        # if retrieved_Caption==None:
            # cnt_empty +=1
        # print(model.gamma_cap, model.gamma_img)
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                #image
                output_img = model.inference(image.cuda(gpu, non_blocking=True))
                # save_result['image_path'] = imagepath[0]
                # save_result['caption'].append(retrieved_Caption[0])
                # output_caption = model.caption_ensemble(retrieved_Caption)
        if len(caption_logit)>1:
            caption_logit = torch.mean(caption_logit, dim=0)
        # print(caption_logit.shape)
        uniform = average_entropy(torch.ones(output_img.shape[-1]))
        # alpha = (1-average_entropy(output_img)/(uniform))
        # beta = (1-average_entropy(caption_logit)/(uniform))
        # img_set.append(alpha.item())
        # cap_set.append(beta.item())
        # print(alpha, beta)
        # print("--- normalize ---")
        # print((alpha  - np.mean(img_set)) / np.std(img_set),(beta  - np.mean(cap_set)) / np.std(cap_set))
        # adjusted = (alpha  - np.mean(img_set)) / np.std(img_set) *output_img + (beta  - np.mean(cap_set)) / np.std(cap_set)*caption_logit 
        adjusted = torch.nn.functional.softmax( torch.sigmoid(model.alpha) * output_img + (1-torch.sigmoid(model.alpha)) * caption_logit, dim=-1)
        # adjusted = torch.sigmoid(model.alpha) * torch.nn.functional.normalize(input, p=2.0, dim=1, eps=1e-12, out=None)
        # class_num = caption_logit.size()[-1]
        # uni_ent = average_entropy(torch.ones(class_num))
        # alpha = (1 - (average_entropy(caption_logit)/uni_ent)) * (1-torch.sigmoid(model.alpha))
        # beta = (1 - (average_entropy(output_img)/uni_ent)) * torch.sigmoid(model.alpha)
        total_images += 1
        correct_images += accuracy(adjusted, target, topk=(1, 2, 3, 4, 5), caption=None, logger=None)
        # correct_images += accuracy(output_img, target, topk=(1, 2, 3, 4, 5), caption=None, logger=None)
        # acc1, acc2, acc3, acc4, acc5 = accuracy(prob, target, topk=(1, 2, 3, 4, 5), caption=None, logger=None)
        # # print(acc1, acc2, acc3, acc4, acc5)
        # top1.update(acc1[0], image.size(0))
        # top2.update(acc2[0], image.size(0))
        # top3.update(acc3[0], image.size(0))
        # top4.update(acc4[0], image.size(0))
        # top5.update(acc5[0], image.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i+1) % print_freq == 0:
            # progress.display(i)
            print("model alpha ", model.alpha)
            print("accuracy  ", correct_images/total_images)
            
    
    print("empty caption count = {}".format(cnt_empty))
    # progress.display_summary()
    print("Accuracy: {}".format(correct_images/total_images) )
    return [top1.avg, top2.avg, top3.avg, top4.avg, top5.avg], save_result

In [8]:
# load model
if test_sets in fewshot_datasets:
    classnames = eval("{}_classes".format(test_sets.lower()))
model = get_coop(arch, test_sets, gpu, n_ctx, ctx_init)
model_state = None
print(model.logit_scale)
print(model.logit_scale.exp())
tmp = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
print(tmp, tmp.exp())

Initializing the contect with given words: [a_photo_of_a]
Initial context: "a photo of a"
Number of context words (tokens): 4
tensor(4.6052, device='cuda:7')
tensor(100., device='cuda:7')
Parameter containing:
tensor(2.6593, requires_grad=True) tensor(14.2857, grad_fn=<ExpBackward0>)


In [9]:
# model.logit_scale
# logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
# logit_scale =torch.nn.Parameter(torch.ones([]))
# print(logit_scale)

# model.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
# print(model.logit_scale)

cross_check = set()
for name, param in model.named_parameters():

    if "alpha" not in name:
        param.requires_grad_(False)
    if param.requires_grad : cross_check.add(name)
print("tuing parameters ", cross_check)

print("=> Model created: visual backbone {}".format(arch))

assert gpu is not None
torch.cuda.set_device(gpu)
model = model.cuda(gpu)

trainable_param = [{'params' : model.prompt_learner.parameters()},
                    {'params' : model.alpha, 'lr': 1e-3}]
                   
# trainable_param = model.parameters()
optimizer = torch.optim.AdamW(trainable_param, lr)
optim_state = deepcopy(optimizer.state_dict())

cross_check = set()
for name, param in model.named_parameters():
    if param.requires_grad : cross_check.add(name)
print("tuing parameters ", cross_check)
print(optim_state)

tuing parameters  {'alpha'}
=> Model created: visual backbone ViT-B/16
tuing parameters  {'alpha'}
{'state': {}, 'param_groups': [{'lr': 0.005, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01, 'amsgrad': False, 'foreach': None, 'maximize': False, 'capturable': False, 'params': [0]}, {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01, 'amsgrad': False, 'foreach': None, 'maximize': False, 'capturable': False, 'params': [1]}]}


In [10]:


# setup automatic mixed-precision (Amp) loss scaling
# scaler = torch.cuda.amp.GradScaler(init_scale=1000)
scaler = torch.cuda.amp.GradScaler()

print('=> Using native Torch AMP. Training in mixed precision.')

cudnn.benchmark = True

resolution = 224
workers = 4
dataset_mode = 'test'
data = '/data/seongha'
import sys
from collections import defaultdict
    # norm stats from clip.load()
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    
    # iterating through eval datasets
datasets = test_sets.split("/")
results = {}

# with open('inferece_image_caption{}.txt'.format(test_sets), 'w') as f:
    # sys.stdout = f
for set_id in datasets:
    for retrieve_K in [32]:
        print("retrieve K: {}".format(retrieve_K))
        Dict = defaultdict(list)
        # data_transform = transforms.Compose([
        #     transforms.Resize(resolution, interpolation=BICUBIC),
        #     transforms.CenterCrop(resolution),
        #     transforms.ToTensor(),
        #     normalize,
        # ])
        base_transform = transforms.Compose([
            transforms.Resize(224, interpolation=BICUBIC),
            transforms.CenterCrop(224)])
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            normalize])
        data_transform = AugMixAugmenter(base_transform, preprocess, n_views=retrieve_K-1, 
                                            augmix=len(set_id)>1)
        batchsize = 1
        print("evaluating: {}".format(set_id))
        classnames = eval("{}_classes".format(set_id.lower()))
        model.reset_classnames(classnames, arch)
        with torch.no_grad():
            model.reset()
        optimizer.load_state_dict(optim_state)
        val_dataset = build_dataset(set_id, data_transform, data, mode=dataset_mode)
        total_length = len(val_dataset)
        print("number of test samples: {}".format(len(val_dataset)))

        val_loader = torch.utils.data.DataLoader(
                    val_dataset,
                    batch_size=batchsize, shuffle=True,
                    num_workers=workers, pin_memory=True)
            
        results['image'], tmp = test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, Dict)
        Dict = tmp
            # assert len(Dict['image_path']) == len(Dict['caption']) and len(Dict['image_correct']) == len(Dict['caption_correct']), [len(v) for k, v in Dict.items()]
        df = pd.DataFrame(Dict)
        df = df.reset_index()

        path = './notebook/JSdiv'
        os.makedirs(path, exist_ok=True)
        df.to_csv(os.path.join(path, 'JSdiv_{}.csv'.format(set_id)))
        
        # with open(os.path.join(path,'JSdiv_{}.txt'.format(set_id)), 'w') as f:
            
        #     correct = df.loc[df['correct'] == 1]
        #     wrong = df.loc[df['correct'] == 0]
        # print("Entropy mean {}\n".format(str(correct['image_entropy'].astype(float).mean())))
        # print("Entropy mean {}\n".format(str(wrong['image_entropy'].astype(float).mean())))
del val_dataset, val_loader

        

=> Using native Torch AMP. Training in mixed precision.
retrieve K: 32
evaluating: DTD


number of test samples: 1692


100it [02:47,  1.42s/it]

model alpha  Parameter containing:
tensor(0.0995, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4600, device='cuda:7')


200it [05:14,  1.39s/it]

model alpha  Parameter containing:
tensor(0.2100, device='cuda:7', requires_grad=True)
accuracy   tensor(0.5000, device='cuda:7')


300it [07:52,  1.80s/it]

model alpha  Parameter containing:
tensor(0.3090, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4933, device='cuda:7')


400it [10:32,  1.41s/it]

model alpha  Parameter containing:
tensor(0.4222, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4750, device='cuda:7')


500it [13:05,  2.01s/it]

model alpha  Parameter containing:
tensor(0.5448, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4740, device='cuda:7')


600it [15:37,  1.50s/it]

model alpha  Parameter containing:
tensor(0.6582, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4800, device='cuda:7')


700it [18:07,  1.45s/it]

model alpha  Parameter containing:
tensor(0.7758, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4786, device='cuda:7')


800it [20:39,  1.40s/it]

model alpha  Parameter containing:
tensor(0.8859, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4812, device='cuda:7')


900it [23:07,  1.34s/it]

model alpha  Parameter containing:
tensor(0.9840, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4756, device='cuda:7')


1000it [25:34,  1.59s/it]

model alpha  Parameter containing:
tensor(1.0844, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4800, device='cuda:7')


1100it [28:12,  1.54s/it]

model alpha  Parameter containing:
tensor(1.1760, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4782, device='cuda:7')


1200it [30:46,  1.45s/it]

model alpha  Parameter containing:
tensor(1.2622, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4733, device='cuda:7')


1300it [33:16,  1.61s/it]

model alpha  Parameter containing:
tensor(1.3612, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4738, device='cuda:7')


1400it [35:47,  1.42s/it]

model alpha  Parameter containing:
tensor(1.4380, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4664, device='cuda:7')


1500it [38:33,  1.82s/it]

model alpha  Parameter containing:
tensor(1.5216, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4680, device='cuda:7')


1600it [41:08,  1.46s/it]

model alpha  Parameter containing:
tensor(1.5991, device='cuda:7', requires_grad=True)
accuracy   tensor(0.4650, device='cuda:7')


1692it [43:27,  1.54s/it]


empty caption count = 0
Accuracy: 0.4651300311088562
retrieve K: 32
evaluating: Food101
number of test samples: 30300


100it [02:34,  1.46s/it]

model alpha  Parameter containing:
tensor(1.7382, device='cuda:7', requires_grad=True)
accuracy   tensor(0.8000, device='cuda:7')


130it [03:19,  1.48s/it]