In [10]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="2"

## Load dataset

In [11]:
from datasets import load_dataset, load_from_disk
#food_dataset = load_dataset('food101')
food_dataset = load_from_disk('dataset/salads/hf_dataset/')

In [12]:
import json

#with open('mapping.json') as f_in:
with open('dataset/salads/mapping_salad.json') as f_in:
    name_to_index = json.load(f_in)

index_to_name = {v: k for k, v in name_to_index.items()}

num_classes = len(index_to_name)
print(num_classes)

184


In [13]:
from itertools import chain
import random

label_index = {}

for i, lbl in enumerate(food_dataset['train']['label']):
    if lbl not in label_index:
        label_index[lbl] = []
    
    # if len(label_index[lbl]) >= hash(lbl) % 10 + 1:
        # continue
    
    label_index[lbl].append(i)

fs_indices = list(chain(*label_index.values()))
fs_subset = food_dataset['train'].select(fs_indices)

print(len(fs_subset))

588


## Load Model

In [14]:
import spacy
from itertools import islice

nlp = spacy.load("en_core_web_sm")

def label_to_string(label, template='A photo of a {name}. A picture of food.'):
    label_name = index_to_name[label]
    label_name = label_name.replace('_', ' ')
    return template.format(name=label_name)

def article_text_shorten(text, num_sentences=2):
    sents = islice(nlp(text).sents, num_sentences)
    sents = [str(s) for s in sents]
    return ' '.join(sents)

In [15]:
from transformers import CLIPProcessor, CLIPModel

hidden_size = 768

# laion/CLIP-ViT-H-14-laion2B-s32B-b79K
# laion/CLIP-ViT-B-32-laion2B-s34B-b79K
# openai/clip-vit-large-patch14
def load_clip_model(device='cuda', model_name="openai/clip-vit-large-patch14"):
    model = CLIPModel.from_pretrained(model_name)
    model.eval()
    model.to(device)

    processor = CLIPProcessor.from_pretrained(model_name)

    return model, processor

model, processor = load_clip_model()

## Get CLIP embeddings

In [16]:
from tqdm import tqdm
import torch
from torch.nn import functional as F

ref_images = torch.zeros((len(fs_subset), hidden_size))
ref_text = torch.zeros((num_classes, hidden_size))
labels_all = []

for i, ex in enumerate(tqdm(fs_subset)):
    image, label = ex['image'], ex['label']
    labels_all.append(label)
    
    label_str = label_to_string(label)
    #label_str = article_text_shorten(ex['text'], num_sentences=1)

    model_input = processor(
        images=image, text=label_str, return_tensors="pt", padding=True, truncation=True
    ).to('cuda')

    model_output = model(**model_input)

    ref_images[i] = model_output.image_embeds.detach().cpu()
    ref_text[label] = model_output.text_embeds.detach().cpu()

labels_oh = F.one_hot(torch.tensor(labels_all), num_classes=num_classes).float()
labels_oh = labels_oh / torch.sum(labels_oh, dim=0, keepdim=True)
assert not torch.isnan(labels_oh).any()

 44%|████▎     | 256/588 [00:44<00:57,  5.76it/s]


KeyboardInterrupt: 

## Main TIP-Adapter Code

In [None]:
def infer_single(test_image, beta=1.0, alpha=1.0):
    test_image_embeds = model(
        **(processor(images=test_image, text='', return_tensors="pt", padding=True).to('cuda'))
    ).image_embeds.cpu()

    img_sim = torch.matmul(test_image_embeds, ref_images.T)
    img_sim = ((-1) * (beta - beta * img_sim)).exp()
    class_sim_img = torch.matmul(img_sim, labels_oh) # (1, num_classes)

    class_sim_text = torch.matmul(test_image_embeds, ref_text.T) # (1, num_classes)

    class_sim = alpha * class_sim_img + class_sim_text # (1, num_classes)

    return class_sim

## Evaluate

In [None]:
from sklearn.metrics import accuracy_score, classification_report

food_dataset['validation'] = food_dataset['validation'].shuffle()

def validate(alpha=1.0, beta=1.0, limit_test=500, verbose=False):
    labels = []
    preds = []
    top_5_all = []

    total_test = min(len(food_dataset['validation']), limit_test)

    for test_idx, test_sample in enumerate(tqdm(food_dataset['validation'], total=total_test)):
        if test_idx >= limit_test:
            break
        
        test_label = test_sample['label']
        test_image = test_sample['image']

        logits = infer_single(test_image, alpha=alpha, beta=beta)
        preds_top_5 = torch.sort(logits, descending=True)[1][0,:5]
        preds_top_1 = preds_top_5[0]

        labels.append(test_label)
        preds.append(preds_top_1)
        top_5_all.append(test_label in preds_top_5)

    if verbose:
        print(classification_report(labels, preds))

    return accuracy_score(labels, preds)

In [None]:
validate(alpha=0.1, beta=0.1, limit_test=5000, verbose=True)

100%|██████████| 65/65 [00:12<00:00,  5.34it/s]

              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           2       0.00      0.00      0.00         1
           4       0.00      0.00      0.00         1
           7       0.00      0.00      0.00         0
           8       1.00      1.00      1.00         1
          13       0.00      0.00      0.00         1
          16       0.00      0.00      0.00         1
          18       0.00      0.00      0.00         1
          19       0.00      0.00      0.00         1
          20       0.00      0.00      0.00         0
          21       0.00      0.00      0.00         1
          25       0.00      0.00      0.00         0
          26       0.00      0.00      0.00         0
          30       0.00      0.00      0.00         0
          32       0.00      0.00      0.00         1
          36       1.00      1.00      1.00         1
          40       0.00      0.00      0.00         0
          41       0.00    


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


0.3076923076923077

## Hyperparameter Search

In [None]:
import itertools

param_grid = {
    'alpha': [0.01, 0.1, 1, 2, 5],
    'beta': [0.01, 0.1, 1, 2, 5]
}

best_hyps = None
best_score = float('-inf')

for combination in itertools.product(*param_grid.values()):
    hyps = dict(zip(param_grid.keys(), combination))
    print(f'testing {hyps}')
    score = validate(**hyps)

    if score > best_score:
        print(f'new_score={score} > best_score={best_score}; best_config={hyps}')
        best_hyps = hyps
        best_score = score
    else:
        print(f'new =_score={score} <= best_score={best_score}; best_config={best_hyps}')

testing {'alpha': 0.01, 'beta': 0.01}


100%|██████████| 500/500 [00:24<00:00, 20.64it/s]


new_score=0.958 > best_score=-inf; best_config={'alpha': 0.01, 'beta': 0.01}
testing {'alpha': 0.01, 'beta': 0.1}


100%|██████████| 500/500 [00:24<00:00, 20.71it/s]


new =_score=0.958 <= best_score=0.958; best_config={'alpha': 0.01, 'beta': 0.01}
testing {'alpha': 0.01, 'beta': 1}


100%|██████████| 500/500 [00:23<00:00, 21.06it/s]


new =_score=0.958 <= best_score=0.958; best_config={'alpha': 0.01, 'beta': 0.01}
testing {'alpha': 0.01, 'beta': 2}


100%|██████████| 500/500 [00:23<00:00, 21.06it/s]


new_score=0.96 > best_score=0.958; best_config={'alpha': 0.01, 'beta': 2}
testing {'alpha': 0.01, 'beta': 5}


100%|██████████| 500/500 [00:24<00:00, 20.79it/s]


new =_score=0.96 <= best_score=0.96; best_config={'alpha': 0.01, 'beta': 2}
testing {'alpha': 0.1, 'beta': 0.01}


100%|██████████| 500/500 [00:24<00:00, 20.82it/s]


new =_score=0.958 <= best_score=0.96; best_config={'alpha': 0.01, 'beta': 2}
testing {'alpha': 0.1, 'beta': 0.1}


100%|██████████| 500/500 [00:24<00:00, 20.56it/s]


new =_score=0.958 <= best_score=0.96; best_config={'alpha': 0.01, 'beta': 2}
testing {'alpha': 0.1, 'beta': 1}


100%|██████████| 500/500 [00:24<00:00, 20.75it/s]


new =_score=0.96 <= best_score=0.96; best_config={'alpha': 0.01, 'beta': 2}
testing {'alpha': 0.1, 'beta': 2}


100%|██████████| 500/500 [00:24<00:00, 20.23it/s]


new_score=0.962 > best_score=0.96; best_config={'alpha': 0.1, 'beta': 2}
testing {'alpha': 0.1, 'beta': 5}


100%|██████████| 500/500 [00:24<00:00, 20.68it/s]


new =_score=0.962 <= best_score=0.962; best_config={'alpha': 0.1, 'beta': 2}
testing {'alpha': 1, 'beta': 0.01}


 73%|███████▎  | 365/500 [00:17<00:06, 20.53it/s]


KeyboardInterrupt: 