In [None]:
import os
import importlib
import numpy as np
import torch

# Optional: uncomment and point at the repository root so that the relative
# paths used in later cells ('data/...', 'output/...') resolve.
# os.chdir('/your/path/to/Clip')
# Environment sanity checks (IPython shell escapes): show the current working
# directory, the value of CUDA_VISIBLE_DEVICES, and the nvidia-smi GPU report.
!pwd
!echo $CUDA_VISIBLE_DEVICES
!nvidia-smi

In [2]:
from dtd2.applications.fine_grained_classification.cub_dataset import CUBDataset

# CUB dataset wrapper; later cells read its att_names, img_data_list,
# gt_att_labels and img_splits attributes.
cub_data = CUBDataset(data_path='data/CUB_200_2011', split='test')

CUB dataset ready.


In [3]:
# Attribute names have the form '<context>::<adjective>'. Attributes whose
# context contains any of these substrings are left out of the mapping.
SKIPPED_CONTEXT_WORDS = ('length', 'size', 'shape', 'bill', 'eye', 'nape', 'crown')

# adjective -> list of indices (into cub_data.att_names) of the attributes using it
adj_to_att = {}

for att_i, att in enumerate(cub_data.att_names):
    context, adj = att.split('::')
    # Substring match on the context part only (the adjective is not checked).
    if any(word in context for word in SKIPPED_CONTEXT_WORDS):
        continue
    adj_to_att.setdefault(adj, []).append(att_i)

print(list(adj_to_att))


['blue', 'brown', 'iridescent', 'purple', 'rufous', 'grey', 'yellow', 'olive', 'green', 'pink', 'orange', 'black', 'white', 'red', 'buff', 'solid', 'spotted', 'striped', 'multi-colored', 'malar', 'crested', 'masked', 'unique_pattern', 'eyebrow', 'eyering', 'plain', 'eyeline', 'capped']


In [4]:
from dtd2.data_api.dataset_api import TextureDescriptionData

dtd2_data = TextureDescriptionData(phid_format=None)

# Keep, in their original order, only the CUB adjectives that also appear in
# dtd2_data.phrases.
phrases = []
for adj in adj_to_att:
    if adj in dtd2_data.phrases:
        phrases.append(adj)
print(len(phrases), phrases)


TextureDescriptionData ready. 
655 phrases with frequency above 10.
Image count: train 3222, val 805, test 1342
17 ['blue', 'brown', 'iridescent', 'purple', 'grey', 'yellow', 'green', 'pink', 'orange', 'black', 'white', 'red', 'solid', 'spotted', 'striped', 'multi-colored', 'plain']


In [5]:
# Indicator matrix of shape (number of CUB attributes, number of phrases):
# entry [a, p] is 1 when attribute a uses the adjective phrases[p], else 0.
att_to_adj_mat = np.zeros((len(cub_data.att_names), len(phrases)))
for col, phrase in enumerate(phrases):
    # Set every attribute row listed for this phrase in one indexed assignment.
    att_to_adj_mat[adj_to_att[phrase], col] = 1

In [6]:
# Relative image names, one per entry of cub_data.img_data_list.
img_names = [d['img_name'] for d in cub_data.img_data_list]

# Local filesystem paths: os.path.join is the right tool here.
img_paths = [os.path.join('data/CUB_200_2011/images', name) for name in img_names]

# URLs must always be joined with '/', so do not use os.path.join for them
# (it would insert a backslash on Windows and produce broken links in the HTML).
img_html_paths = ['/'.join(('http://maxwell.cs.umass.edu/chenyun/data/CUB_200_2011/images', name))
                  for name in img_names]

# Project per-attribute labels onto the selected phrases.
# NOTE(review): several attributes can map to the same phrase (e.g. the wing
# and upperparts colour attributes share 'blue'), and the dot product adds
# their labels, so an entry can exceed 1. Confirm that the downstream
# evaluation expects counts here rather than 0/1 labels.
gt_matrix = np.dot(cub_data.gt_att_labels, att_to_adj_mat)

In [7]:
# Indices of the test images within the full image list.
test_idx = cub_data.img_splits['test']

# This cell overwrites its own inputs, so it must run exactly once after the
# cell that builds the full-length lists. Fail loudly on a re-run instead of
# silently re-indexing already-subset data or raising an opaque IndexError.
n_all = len(cub_data.img_data_list)
if len(img_paths) != n_all or len(img_html_paths) != n_all:
    raise RuntimeError(
        'img_paths / img_html_paths no longer cover the full image list '
        '(%d / %d entries, expected %d); re-run the cell that builds them '
        'before selecting the test split.' % (len(img_paths), len(img_html_paths), n_all))

img_paths = [img_paths[i] for i in test_idx]
img_html_paths = [img_html_paths[i] for i in test_idx]
gt_matrix = np.stack([gt_matrix[i] for i in test_idx])
print(len(test_idx), len(img_paths), gt_matrix.shape)


5794 5794 (5794, 17)


In [9]:
from utils.clip_encoder import ClipEncoder
clip_encoder = ClipEncoder()

# One text prompt per phrase.
text_l = ['An image of a %s bird' % p for p in phrases]

# Inference only: encode without building autograd graphs, matching the
# CLIPP scoring cell, so the resulting score tensor carries no gradient history.
with torch.no_grad():
    phrase_vecs = clip_encoder.encode_text_list(text_l)
    print('phrases encoded')

    img_vecs = clip_encoder.encode_imgs(img_paths)
    print('imgs encoded')

# Image-by-phrase similarity, then a softmax over the phrase axis per image.
# NOTE(review): the recorded stats for clip_scores are nearly uniform
# (min 0.053, max 0.065) - presumably no logit scale / temperature is applied
# before the softmax; confirm whether that is intended.
clip_scores = img_vecs @ phrase_vecs.T
clip_scores = clip_scores.float().softmax(dim=-1).cpu()


ClipEncoder ready.
phrases encoded
5666 remaining imgs
5538 remaining imgs
5410 remaining imgs
5282 remaining imgs
5154 remaining imgs
5026 remaining imgs
4898 remaining imgs
4770 remaining imgs
4642 remaining imgs
4514 remaining imgs
4386 remaining imgs
4258 remaining imgs
4130 remaining imgs
4002 remaining imgs
3874 remaining imgs
3746 remaining imgs
3618 remaining imgs
3490 remaining imgs
3362 remaining imgs
3234 remaining imgs
3106 remaining imgs
2978 remaining imgs
2850 remaining imgs
2722 remaining imgs
2594 remaining imgs
2466 remaining imgs
2338 remaining imgs
2210 remaining imgs
2082 remaining imgs
1954 remaining imgs
1826 remaining imgs
1698 remaining imgs
1570 remaining imgs
1442 remaining imgs
1314 remaining imgs
1186 remaining imgs
1058 remaining imgs
930 remaining imgs
802 remaining imgs
674 remaining imgs
546 remaining imgs
418 remaining imgs
290 remaining imgs
162 remaining imgs
34 remaining imgs
imgs encoded


In [17]:
from utils.clip_plus import ClipPlusEncoder

# Encoder restored from the checkpoint given by load_path.
clipp_encoder = ClipPlusEncoder(load_path='output/clipp/models/prompt_lr0.001/clipp_epoch100.pth')

# One text prompt per phrase.
text_l = ['An image of a %s bird' % p for p in phrases]

# Encode prompts and images without tracking gradients.
with torch.no_grad():
    phrase_vecs = clipp_encoder.encode_text_list(text_l)
    print('phrases encoded')

    img_vecs = clipp_encoder.encode_imgs(img_paths)
    print('imgs encoded')

# Image-by-phrase similarity, softmax-normalised over the phrase axis, on CPU.
clipp_scores = (img_vecs @ phrase_vecs.T).float().softmax(dim=-1).cpu()

ClipEncoder ready.
CLIPP ready
phrases encoded
5666 remaining imgs
5538 remaining imgs
5410 remaining imgs
5282 remaining imgs
5154 remaining imgs
5026 remaining imgs
4898 remaining imgs
4770 remaining imgs
4642 remaining imgs
4514 remaining imgs
4386 remaining imgs
4258 remaining imgs
4130 remaining imgs
4002 remaining imgs
3874 remaining imgs
3746 remaining imgs
3618 remaining imgs
3490 remaining imgs
3362 remaining imgs
3234 remaining imgs
3106 remaining imgs
2978 remaining imgs
2850 remaining imgs
2722 remaining imgs
2594 remaining imgs
2466 remaining imgs
2338 remaining imgs
2210 remaining imgs
2082 remaining imgs
1954 remaining imgs
1826 remaining imgs
1698 remaining imgs
1570 remaining imgs
1442 remaining imgs
1314 remaining imgs
1186 remaining imgs
1058 remaining imgs
930 remaining imgs
802 remaining imgs
674 remaining imgs
546 remaining imgs
418 remaining imgs
290 remaining imgs
162 remaining imgs
34 remaining imgs
imgs encoded


In [19]:
from dtd2.models.layers.util import print_tensor_stats

# Summary statistics for each score matrix and for their element-wise
# difference. Every call is given a label so that no output line is anonymous
# (the difference was previously printed as a bare 'STAT :').
print_tensor_stats(clip_scores, 'clip_scores')
print_tensor_stats(clipp_scores, 'clipp_scores')
print_tensor_stats(clip_scores - clipp_scores, 'clip_scores - clipp_scores')

STAT clip_scores: shape torch.Size([5794, 17]) device cpu; mean 0.059; min 0.053; max 0.065; std 0.001
STAT clipp_scores: shape torch.Size([5794, 17]) device cpu; mean 0.059; min 0.001; max 0.873; std 0.055
STAT : shape torch.Size([5794, 17]) device cpu; mean 0.000; min -0.809; max 0.056; std 0.055


In [None]:
from utils.dtd2_triplet_encoder import TripletEncoder

dtd2_encoder = TripletEncoder()

# Encode the bare phrases (no prompt template) and the images.
phrase_vecs = dtd2_encoder.encode_text_list(phrases)
print('phrases encoded: ', phrase_vecs.shape)

img_vecs = dtd2_encoder.encode_imgs(img_paths)
print('imgs encoded: ', img_vecs.shape)

img_num, phrase_num = len(img_paths), len(phrases)

# Score every (image, phrase) pair with the negated encoder distance, so that
# a larger value means a closer pair.
neg_distances = torch.zeros((img_num, phrase_num))
with torch.no_grad():
    for img_i in range(img_num):
        for ph_i in range(phrase_num):
            neg_distances[img_i, ph_i] = - dtd2_encoder.dist(img_vecs[img_i], phrase_vecs[ph_i])
        # Progress report every 100 images.
        if img_i % 100 == 0:
            print(img_i)

print_tensor_stats(neg_distances, 'pred_scores')
mdtd2_scores = neg_distances

In [20]:
# Per-image row sums of gt_matrix, i.e. the total ground-truth mass over the
# selected phrases.
# NOTE(review): the recorded output contains a row sum of 32 although the
# matrix has only 17 phrase columns, so some entries are greater than 1
# (gt_matrix holds counts, not 0/1 labels). Confirm that the retrieval
# metrics are meant to receive counts.
np.sum(gt_matrix, axis=1)

array([ 1., 11., 17., ..., 16., 16., 32.])

In [21]:
import utils.retrieval_compare as rc
# Reload so that edits to the module are picked up without restarting the kernel.
importlib.reload(rc)

from utils.retrieval_compare import compare_pred_to_html

# Alternative run, kept for reference (requires mdtd2_scores from the DTD2 cell):
# compare_pred_to_html(img_path_list=img_html_paths,
#                      phrase_list=phrases,
#                      phrase_weight=None,
#                      gt_matrix=gt_matrix,
#                      method_score_list=[['CLIP', clip_scores], ['DTD2', mdtd2_scores]],
#                      output_path='output/retrieve_CUB_CLIP',
#                      word_cloud=False)

# Compare CLIP against CLIPP on the CUB test images. The image URLs (not the
# local paths) are passed in as img_path_list.
compare_pred_to_html(
    img_path_list=img_html_paths,
    phrase_list=phrases,
    phrase_weight=None,
    gt_matrix=gt_matrix,
    method_score_list=[
        ['CLIP', clip_scores],
        ['CLIPP', clipp_scores],
    ],
    output_path='output/retrieve_CUB_CLIPP',
    word_cloud=False,
)

CLIP
image to phrase


  recall = np.nan_to_num(pred_num * 1.0 / gt_count)


mean_average_precision: 0.5406
mean_reciprocal_rank: 0.7593
precision_at_001: 0.6151
precision_at_005: 0.4321
precision_at_010: 0.3621
precision_at_020: -1.0000
precision_at_050: -1.0000
precision_at_100: -1.0000
query_average_precisions: [list, skipped]
r_precision: 0.3055
recall_at_001: 0.1292
recall_at_005: 0.4343
recall_at_010: 0.7110
recall_at_020: -1.0000
recall_at_050: -1.0000
recall_at_100: -1.0000
latex string
mean_average_precision & mean_reciprocal_rank & precision_at_005 & precision_at_020 & recall_at_005 & recall_at_020
54.06 & 75.93 & 43.21 & -100.00 & 43.43 & -100.00
phrase to image
mean_average_precision: 0.5014
mean_reciprocal_rank: 0.9171
precision_at_001: 0.8824
precision_at_005: 0.7294
precision_at_010: 0.7353
precision_at_020: 0.7176
precision_at_050: 0.7047
precision_at_100: 0.6812
query_average_precisions: [list, skipped]
r_precision: 0.3055
recall_at_001: 0.0012
recall_at_005: 0.0046
recall_at_010: 0.0090
recall_at_020: 0.0160
recall_at_050: 0.0375
recall_at_100

In [31]:
# List every CUB attribute name ('<context>::<adjective>'), one per line.
for att_name in cub_data.att_names:
    print(att_name)

has_bill_shape::curved_(up_or_down)
has_bill_shape::dagger
has_bill_shape::hooked
has_bill_shape::needle
has_bill_shape::hooked_seabird
has_bill_shape::spatulate
has_bill_shape::all-purpose
has_bill_shape::cone
has_bill_shape::specialized
has_wing_color::blue
has_wing_color::brown
has_wing_color::iridescent
has_wing_color::purple
has_wing_color::rufous
has_wing_color::grey
has_wing_color::yellow
has_wing_color::olive
has_wing_color::green
has_wing_color::pink
has_wing_color::orange
has_wing_color::black
has_wing_color::white
has_wing_color::red
has_wing_color::buff
has_upperparts_color::blue
has_upperparts_color::brown
has_upperparts_color::iridescent
has_upperparts_color::purple
has_upperparts_color::rufous
has_upperparts_color::grey
has_upperparts_color::yellow
has_upperparts_color::olive
has_upperparts_color::green
has_upperparts_color::pink
has_upperparts_color::orange
has_upperparts_color::black
has_upperparts_color::white
has_upperparts_color::red
has_upperparts_color::buff
has_u