In [None]:

import argparse

import time

from copy import deepcopy

from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms


try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
import torchvision.models as models

import clip
from clip.custom_clip import get_coop
from clip.cocoop import get_cocoop
from data.imagnet_prompts import imagenet_classes
from data.datautils import AugMixAugmenter, build_dataset
from utils.tools import Summary, AverageMeter, ProgressMeter, accuracy, load_model_weight, set_random_seed, create_logger
from data.cls_to_names import *
from data.fewshot_datasets import fewshot_datasets
from data.imagenet_variants import thousand_k_to_200, imagenet_a_mask, imagenet_r_mask, imagenet_v_mask
%load_ext autoreload
%autoreload 2

In [None]:

def select_confident_samples(logits, top):
    """Keep the most confident (lowest-entropy) fraction of rows of ``logits``.

    Args:
        logits: tensor of unnormalized scores; softmax / entropy are taken over
            dim 1, so rows are assumed to be samples and columns classes.
        top: fraction of rows to keep. The number kept is
            ``int(num_rows * top)``, i.e. truncated toward zero.

    Returns:
        ``(logits[idx], idx)`` where ``idx`` are the indices of the kept rows,
        ordered from lowest to highest predictive entropy.
    """
    probs = logits.softmax(1)
    log_probs = logits.log_softmax(1)
    # Shannon entropy of each row's predictive distribution.
    row_entropy = -(probs * log_probs).sum(1)

    num_keep = int(row_entropy.size()[0] * top)
    ranked = torch.argsort(row_entropy, descending=False)
    idx = ranked[:num_keep]
    return logits[idx], idx

def avg_entropy(outputs):
    """Entropy of the prediction averaged over the first dimension of ``outputs``.

    Every row of ``outputs`` is normalized to log-probabilities over its last
    dimension, the rows are averaged in probability space (computed in log
    space for stability), and the Shannon entropy of that averaged
    distribution is returned.

    Args:
        outputs: non-empty floating-point tensor of unnormalized scores with no
            NaN entries; dim 0 indexes the predictions being averaged.

    Returns:
        Tensor holding the entropy, reduced over the last dimension.

    Raises:
        AssertionError: if ``outputs`` is empty or a NaN appears at any stage.
    """
    assert len(outputs) > 0
    assert not torch.isnan(outputs).any()

    # Row-wise log-softmax: subtract the log-partition over the class axis.
    log_probs = outputs - outputs.logsumexp(dim=-1, keepdim=True)
    assert not torch.isnan(log_probs).any()

    # log(mean_i p_i) == logsumexp_i(log p_i) - log(N)
    num_rows = log_probs.shape[0]
    mean_log_probs = log_probs.logsumexp(dim=0) - np.log(num_rows)
    has_nan = torch.any(torch.isnan(mean_log_probs))
    if has_nan:
        # Diagnostic dump emitted just before the assertion below fires.
        print("average logits ", outputs.log_softmax(dim=1).mean(0))
    assert not has_nan

    # Floor -inf at the most negative finite value so x * exp(x) is 0, not NaN.
    floor = torch.finfo(mean_log_probs.dtype).min
    mean_log_probs = mean_log_probs.clamp(min=floor)
    assert not torch.isnan(mean_log_probs).any()

    return -(mean_log_probs * mean_log_probs.exp()).sum(dim=-1)

In [None]:
# Test-time adaptation parameters (read as globals by the functions defined below).
# Number of optimization steps test_time_tuning runs per call.
tta_steps = 1
# Loss selection checked by test_time_tuning: "cosine", "entropy", or "both".
which_loss = "both"
# CUDA device index passed to .cuda(...) and torch.cuda.set_device(...).
gpu = 2
# test_time_adapt_eval displays the progress meter every `print_freq` batches.
print_freq = 10

In [None]:

def test_time_tuning(model, inputs, optimizer, scaler):
    """Tune the model on ``inputs`` for ``tta_steps`` optimization steps.

    The loss is selected by the global ``which_loss``:

    * ``"cosine"``  - 0.5 * the mean ``CosineEmbeddingLoss`` between each
      retrieved text projection (anchor) and the text features of the 2*K
      highest-scoring classes. The top K classes are labelled as positives (+1)
      and ranks K+1..2K as negatives (-1). Note these negatives are the *next*
      K classes, not the K lowest-scoring ones.
    * ``"entropy"`` - ``avg_entropy`` of the similarities between the retrieved
      text projections and the class text features.
    * ``"both"``    - the sum of the two terms.

    Args:
        model: callable returning ``(output_img, text_features)`` for ``inputs``
            and exposing ``nn_project(image=..., print_hits=..., topk=...)``.
        inputs: image batch. The top-K indexing squeezes the batch dimension,
            so this presumably expects a single image (batch size 1) - TODO confirm.
        optimizer: optimizer over the parameters being tuned.
        scaler: AMP gradient scaler used for the scaled backward pass and step.

    Returns:
        None; the model parameters are updated in place.

    Raises:
        ValueError: if ``which_loss`` is not "both", "cosine" or "entropy".
    """
    if which_loss not in ("both", "cosine", "entropy"):
        # Previously an unknown value left `loss` as the int 0 and only failed
        # later, inside the scaler, with an unrelated error.
        raise ValueError(
            "which_loss must be 'both', 'cosine' or 'entropy', got {!r}".format(which_loss))
    use_cosine = which_loss in ("both", "cosine")
    use_entropy = which_loss in ("both", "entropy")

    retrieve_K = 5  # TODO: 1,2,4,8
    K = 5
    criterion = torch.nn.CosineEmbeddingLoss()

    for _ in range(tta_steps):
        with torch.cuda.amp.autocast():
            # Per the original note: logits are (bs, n_cls), text features (n_cls, 512).
            output_img, text_features = model(inputs)

            # Both loss terms consume the retrieved projections, so fetch them
            # unconditionally. They used to be assigned only inside the cosine
            # branch, which made which_loss == "entropy" die with a NameError.
            retrieved_txt_prj, _ = model.nn_project(image=inputs, print_hits=False, topk=retrieve_K)
            assert len(retrieved_txt_prj) == retrieve_K, len(retrieved_txt_prj)

            loss = 0
            if use_cosine:
                # Indices of the 2*K best-scoring classes, best first.
                _, top_ind = output_img.topk(2 * K, 1, True, True)
                ranked = top_ind.squeeze()
                top5_ind = ranked[:K]
                bot5_ind = ranked[K:]
                assert len(top5_ind) == K and len(bot5_ind) == K, (top5_ind, bot5_ind)
                print("Top 5 ", top5_ind)

                x1 = torch.cat((text_features[top5_ind].squeeze(),
                                text_features[bot5_ind].squeeze()), dim=0).cuda(gpu)
                # +1 pulls the top-K class features toward the anchor, -1 pushes the rest away.
                label = torch.tensor([1] * K + [-1] * K, device=x1.device)

                per_anchor = [criterion(x1, anchor.expand_as(x1), label)
                              for anchor in retrieved_txt_prj]
                cosine_loss = sum(per_anchor) / len(per_anchor)
                print("Cosine embedding loss ", cosine_loss.detach().cpu())
                loss = loss + cosine_loss * 0.5
            if use_entropy:
                entropy_loss = avg_entropy(retrieved_txt_prj @ text_features.t())
                # Report the entropy term itself; the old code printed the running
                # total (cosine + entropy) under this label.
                print("Entropy loss: ", entropy_loss.detach().cpu())
                loss = loss + entropy_loss

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    return

In [None]:
def test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler):
    """Evaluate ``model`` with per-sample test-time tuning and report top-1..5 accuracy.

    For every batch the model and optimizer are restored to their initial
    state, ``test_time_tuning`` adapts the model on that batch, and the adapted
    model is then scored on the same batch.

    Args:
        val_loader: sized iterable yielding ``(image, target, imagepath)`` tuples.
        model: model exposing ``eval()``, ``reset()`` and
            ``inference(image, caption=...)``.
        model_state: unused; kept so existing callers keep working.
        optimizer: optimizer over the tunable parameters.
        optim_state: optimizer state dict reloaded before every batch.
        scaler: AMP gradient scaler forwarded to ``test_time_tuning``.

    Returns:
        List ``[top1.avg, top2.avg, top3.avg, top4.avg, top5.avg]``.
    """
    assert gpu is not None

    ks = (1, 2, 3, 4, 5)
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    # One running-average meter per k instead of five hand-written copies.
    topk_meters = [AverageMeter('Acc@{}'.format(k), ':6.2f', Summary.AVERAGE) for k in ks]

    progress = ProgressMeter(
        len(val_loader),
        [batch_time] + topk_meters,
        prefix='Test: ',
        logger=None)

    # reset model and switch to evaluate mode
    model.eval()
    with torch.no_grad():
        model.reset()
    end = time.time()

    # enumerate() hides the loader's length from tqdm, so pass it explicitly;
    # otherwise the bar shows a bare counter without progress or ETA.
    for i, (image, target, imagepath) in tqdm(enumerate(val_loader), total=len(val_loader)):
        print("Image Path ", imagepath)
        # Move the batch to the device once and reuse it for tuning and inference.
        image = image.cuda(gpu, non_blocking=True)
        target = target.cuda(gpu, non_blocking=True)

        # Restore the model (per the original note: the tunable prompt) and the
        # optimizer to their initial state so each sample is adapted independently.
        with torch.no_grad():
            model.reset()
        optimizer.load_state_dict(optim_state)

        test_time_tuning(model, image, optimizer, scaler)

        with torch.no_grad():
            with torch.cuda.amp.autocast():
                output, _ = model.inference(image, caption=None)

        accs = accuracy(output, target, topk=ks, caption=None, logger=None)
        for meter, acc in zip(topk_meters, accs):
            meter.update(acc[0], image.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            progress.display(i)

    progress.display_summary()
    return [meter.avg for meter in topk_meters]

In [None]:
## Experiment parameters
# NOTE(review): this rebinds `fewshot_datasets`, shadowing the name imported
# from data.fewshot_datasets in the imports cell - confirm that is intended.
fewshot_datasets = ['DTD', 'Flower102', 'Food101', 'Cars', 'SUN397', 
                    'Aircraft', 'Pets', 'Caltech101', 'UCF101', 'eurosat']
# Dataset name(s) to evaluate; the evaluation cell splits this string on "/".
test_sets = 'Caltech101'
# Backbone identifier passed to get_coop and model.reset_classnames.
arch='ViT-B/16'
# Passed to get_coop as n_ctx (presumably the number of learnable context tokens - confirm).
n_ctx=4
# Passed to get_coop as ctx_init (presumably the initial prompt text - confirm).
ctx_init="a_photo_of_a"
# Learning rate handed to the AdamW optimizer.
lr = 5e-3

In [None]:
# load model
# Fail fast on a missing device index: `gpu` is consumed by get_coop right
# below, so checking it only after the model was built was too late.
assert gpu is not None

if test_sets in fewshot_datasets:
    # Resolve the `<dataset>_classes` list by name through the module namespace
    # rather than eval()-ing a formatted string (same result for valid names,
    # but it can never execute arbitrary expressions).
    classnames = globals()["{}_classes".format(test_sets.lower())]
model = get_coop(arch, test_sets, gpu, n_ctx, ctx_init)
model_state = None

# Freeze every parameter whose name does not contain "prompt_learner" and
# collect the names that remain trainable as a sanity check.
cross_check = set()
for name, param in model.named_parameters():
    if "prompt_learner" not in name:
        param.requires_grad_(False)
    if param.requires_grad:
        cross_check.add(name)
print("tuing parameters ", cross_check)

print("=> Model created: visual backbone {}".format(arch))

torch.cuda.set_device(gpu)
model = model.cuda(gpu)

# Optimize only the prompt learner; snapshot the pristine optimizer state so
# the evaluation loop can restore it before each sample.
trainable_param = model.prompt_learner.parameters()
optimizer = torch.optim.AdamW(trainable_param, lr)
optim_state = deepcopy(optimizer.state_dict())

# setup automatic mixed-precision (Amp) loss scaling
scaler = torch.cuda.amp.GradScaler(init_scale=1000)

print('=> Using native Torch AMP. Training in mixed precision.')

cudnn.benchmark = True

In [None]:
resolution = 224
workers = 4
dataset_mode = 'test'
# NOTE(review): machine-specific absolute path; adjust to the local dataset root.
data = '/data/seongha'
# norm stats from clip.load()
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])

# The preprocessing pipeline does not depend on the dataset, so build it once
# instead of on every loop iteration.
data_transform = transforms.Compose([
    transforms.Resize(resolution, interpolation=BICUBIC),
    transforms.CenterCrop(resolution),
    transforms.ToTensor(),
    normalize,
])
# One image per step: test_time_tuning squeezes the batch dimension when it
# indexes the top-K classes.
batchsize = 1

# iterating through eval datasets
datasets = test_sets.split("/")
results = {}
for set_id in datasets:
    print("evaluating: {}".format(set_id))
    # Look up the `<dataset>_classes` list by name rather than eval()-ing a
    # formatted string.
    classnames = globals()["{}_classes".format(set_id.lower())]
    model.reset_classnames(classnames, arch)

    val_dataset = build_dataset(set_id, data_transform, data, mode=dataset_mode)
    print("number of test samples: {}".format(len(val_dataset)))

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batchsize, shuffle=False,
        num_workers=workers, pin_memory=True)

    results[set_id] = test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler)
    del val_dataset, val_loader

    # test_time_adapt_eval returns exactly five values (Acc@1..Acc@5). The old
    # call had a misplaced bracket (`results[set_id[4], results[set_id][5]]`)
    # that always raised and was hidden by a bare `except:`, so this formatted
    # line was never actually printed.
    print("=> Acc. on testset [{}]: @1 {}/ @2 {}/ @3 {}/ @4 {}/ @5 {}".format(set_id, *results[set_id]))