In [None]:
import os
import importlib
import numpy as np
import torch
from PIL import Image
from wordcloud import WordCloud

# os.chdir('/your/path/to/Clip')
!pwd
!echo $CUDA_VISIBLE_DEVICES
!nvidia-smi

In [2]:
from dtd2.data_api.dataset_api import TextureDescriptionData
from dtd2.models.layers.util import print_tensor_stats
from dtd2.data_api.eval_retrieve import retrieve_eval as dtd2_eval

# All retrieval experiments below are evaluated on this split.
split = 'test'

dataset = TextureDescriptionData(phid_format=None)
print('dataset ready.')

# Dimensions of the image-by-phrase score matrices built in later cells.
img_num, phrase_num = len(dataset.img_splits[split]), len(dataset.phrases)

TextureDescriptionData ready. 
655 phrases with frequency above 10.
Image count: train 3222, val 805, test 1342
dataset ready.


In [3]:
# Total number of 'descriptions' entries across all images of the training split.
desc_num = sum(
    len(dataset.img_data_dict[train_img]['descriptions'])
    for train_img in dataset.img_splits['train']
)
print(desc_num)

14797


In [7]:
# Path of every image in the evaluated split, in split order.
img_paths = [
    os.path.join('data/DTD2/images', file_name)
    for file_name in dataset.img_splits[split]
]


# Specialized model

In [4]:
from utils.dtd2_triplet_encoder import TripletEncoder
# Encoder evaluated in the "Specialized model" section; it is built with its
# default arguments (no checkpoint path is passed from this notebook).
dtd2_encoder = TripletEncoder()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
# Embed every image of the evaluated split with the specialized encoder.
img_vecs = dtd2_encoder.encode_imgs(img_paths)
# Sanity check on the returned embedding tensor.
print(img_vecs.shape)

0 imgs encoded
100 imgs encoded
200 imgs encoded
300 imgs encoded
400 imgs encoded
500 imgs encoded
600 imgs encoded
700 imgs encoded
800 imgs encoded
900 imgs encoded
1000 imgs encoded
1100 imgs encoded
1200 imgs encoded
1300 imgs encoded
1342 imgs encoded
torch.Size([1342, 256])


In [9]:
# Embed the raw dataset phrases (no prompt template) with the specialized encoder.
phrase_vecs = dtd2_encoder.encode_text_list(dataset.phrases)
# Sanity check on the returned embedding tensor.
print(phrase_vecs.shape)

527 remaining texts
399 remaining texts
271 remaining texts
143 remaining texts
15 remaining texts
torch.Size([655, 256])


In [10]:
# Shapes of the two embedding sets being compared (images, then phrases).
print(img_vecs.shape)
print(phrase_vecs.shape)

# Score matrix with one row per image and one column per phrase. Each entry is
# the negated value of dtd2_encoder.dist, and the matrix is later passed to the
# retrieval evaluation as its match scores.
neg_distances = torch.zeros((img_num, phrase_num))
with torch.no_grad():
    for img_i in range(img_num):
        # The image embedding only depends on the outer index, so fetch it once
        # per row instead of once per (image, phrase) pair.
        v1 = img_vecs[img_i]
        for ph_i in range(phrase_num):
            v2 = phrase_vecs[ph_i]
            neg_distances[img_i, ph_i] = - dtd2_encoder.dist(v1, v2)
        # Progress report every 100 images.
        if img_i % 100 == 0:
            print(img_i)

print_tensor_stats(neg_distances, 'pred_scores')
# Keep the scores under a model-specific name for the comparison cell.
mdtd2_scores = neg_distances

torch.Size([1342, 256])
torch.Size([655, 256])
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
STAT pred_scores: shape torch.Size([1342, 655]) device cpu; mean -12.123; min -24.844; max -4.927; std 1.761


In [11]:
# Evaluate the specialized model in both retrieval directions on the same scores.
# NOTE(review): both calls pass the same visualize_path; confirm that
# retrieve_eval separates its output per mode so the second call does not
# overwrite the first.
p2i_result = dtd2_eval(
    mode='phrase2img',
    match_scores=mdtd2_scores,
    dataset=dataset,
    split=split,
    visualize_path='output/dtd2_model_result',
)

i2p_result = dtd2_eval(
    mode='img2phrase',
    match_scores=mdtd2_scores,
    dataset=dataset,
    split=split,
    visualize_path='output/dtd2_model_result',
)


  recall = np.nan_to_num(pred_num * 1.0 / gt_count)


## retrieve_eval phrase2img on test ##
mean_average_precision: 0.1350
mean_reciprocal_rank: 0.3112
precision_at_005: 0.1652
precision_at_010: 0.1565
precision_at_020: 0.1457
precision_at_050: 0.1162
precision_at_100: 0.0885
query_average_precisions: [skipped]
r_precision: 0.0164
recall_at_005: 0.0524
recall_at_010: 0.0978
recall_at_020: 0.1732
recall_at_050: 0.3362
recall_at_100: 0.4738
## retrieve_eval img2phrase on test ##
mean_average_precision: 0.3177
mean_reciprocal_rank: 0.7412
precision_at_005: 0.4170
precision_at_010: 0.3256
precision_at_020: 0.2360
precision_at_050: 0.1375
precision_at_100: 0.0847
query_average_precisions: [skipped]
r_precision: 0.0164
recall_at_005: 0.2017
recall_at_010: 0.3135
recall_at_020: 0.4504
recall_at_050: 0.6488
recall_at_100: 0.7934


# CLIP

In [12]:
from utils.clip_encoder import ClipEncoder
# Encoder evaluated in the "CLIP" section, built with its default arguments.
clip_encoder = ClipEncoder()

ClipEncoder ready.


In [13]:
# Embed the same image list with CLIP.
# NOTE: this rebinds img_vecs, replacing the embeddings produced by
# dtd2_encoder in the earlier cell; re-run that cell before recomputing the
# DTD2 scores.
img_vecs = clip_encoder.encode_imgs(img_paths)
print('images encoded')

1214 remaining imgs
1086 remaining imgs
958 remaining imgs
830 remaining imgs
702 remaining imgs
574 remaining imgs
446 remaining imgs
318 remaining imgs
190 remaining imgs
62 remaining imgs
images encoded


In [14]:
# Each dataset phrase is substituted into a prompt sentence before text encoding.
template = 'An image of %s texture'
phrases = [template % phrase for phrase in dataset.phrases]

phrase_vecs = clip_encoder.encode_text_list(phrases)
print('phrases encoded', phrase_vecs.shape, sep='\n')

527 remaining texts
399 remaining texts
271 remaining texts
143 remaining texts
15 remaining texts
phrases encoded
torch.Size([655, 512])


In [15]:
# Image-by-phrase similarity (matrix product of the two embedding sets), cast to
# float32, softmax-normalised over the phrase axis and moved to the CPU.
# NOTE(review): the softmax is taken per image row (dim=-1), so phrase2img
# ranking compares values that were normalised separately for each image;
# confirm this is intended.
clip_scores = torch.softmax(
    torch.matmul(img_vecs, phrase_vecs.T).float(),
    dim=-1,
).cpu()

print_tensor_stats(clip_scores, 'clip_scores')

STAT clip_scores: shape torch.Size([1342, 655]) device cpu; mean 0.002; min 0.001; max 0.002; std 0.000


In [16]:
# Evaluate CLIP in both retrieval directions on the same score matrix.
# NOTE(review): both calls share one visualize_path (named "p2i" even for the
# img2phrase run); confirm that retrieve_eval keeps the two outputs apart.
clip_p2i = dtd2_eval(
    mode='phrase2img',
    match_scores=clip_scores,
    dataset=dataset,
    split=split,
    visualize_path='output/dtd2_p2i_result',
)

clip_i2p = dtd2_eval(
    mode='img2phrase',
    match_scores=clip_scores,
    dataset=dataset,
    split=split,
    visualize_path='output/dtd2_p2i_result',
)

## retrieve_eval phrase2img on test ##
mean_average_precision: 0.1274
mean_reciprocal_rank: 0.3215
precision_at_005: 0.1695
precision_at_010: 0.1536
precision_at_020: 0.1322
precision_at_050: 0.1005
precision_at_100: 0.0762
query_average_precisions: [skipped]
r_precision: 0.0164
recall_at_005: 0.0616
recall_at_010: 0.1077
recall_at_020: 0.1731
recall_at_050: 0.2982
recall_at_100: 0.4225
## retrieve_eval img2phrase on test ##
mean_average_precision: 0.1220
mean_reciprocal_rank: 0.4006
precision_at_005: 0.1768
precision_at_010: 0.1473
precision_at_020: 0.1142
precision_at_050: 0.0736
precision_at_100: 0.0519
query_average_precisions: [skipped]
r_precision: 0.0164
recall_at_005: 0.0842
recall_at_010: 0.1397
recall_at_020: 0.2154
recall_at_050: 0.3483
recall_at_100: 0.4889


# CLIPP (CLIP Plus)

In [5]:
from utils.clip_plus import ClipPlusEncoder

# Checkpoints to evaluate, newest first: epochs 130, 120, ..., 80.
# Replace with e.g. [100] to evaluate a single checkpoint.
clipp_epochs = range(130, 70, -10)

# Inputs that do not depend on the checkpoint are built once, outside the loop.
img_paths = [os.path.join('data/DTD2/images', img_name) for img_name in dataset.img_splits[split]]
template = 'An image of %s texture'
phrases = [template % p for p in dataset.phrases]

# Metrics for every evaluated checkpoint, keyed by epoch, so that earlier
# results are not lost when the loop moves on to the next checkpoint:
# {epoch: {'phrase2img': <dtd2_eval result>, 'img2phrase': <dtd2_eval result>}}
clipp_results = {}

for epoch in clipp_epochs:
    print('\nepoch%d' % epoch)
    clipp = ClipPlusEncoder(load_path='output/clipp/models/prompt_lr0.001/clipp_epoch%d.pth' % epoch)
    clipp.eval()

    # Pure inference: disable autograd so no graph is kept for the encodings.
    with torch.no_grad():
        img_vecs = clipp.encode_imgs(img_paths)
        print('images encoded')

        phrase_vecs = clipp.encode_text_list(phrases)
        print('phrases encoded')
        print(phrase_vecs.shape)

        # Same scoring as the plain CLIP cell: similarity matrix, softmax over
        # the phrase axis, moved to the CPU.
        clipp_scores = img_vecs @ phrase_vecs.T
        clipp_scores = clipp_scores.float().softmax(dim=-1).cpu().detach()

    print_tensor_stats(clipp_scores, 'clipp_scores')

    clipp_p2i = dtd2_eval(mode='phrase2img', match_scores=clipp_scores, dataset=dataset,
                          split=split, visualize_path='output/clipp/dtd2_p2i_result')

    clipp_i2p = dtd2_eval(mode='img2phrase', match_scores=clipp_scores, dataset=dataset,
                          split=split, visualize_path='output/clipp/dtd2_p2i_result')

    # clipp_p2i / clipp_i2p still hold the last checkpoint's results after the
    # loop; the dict additionally retains every checkpoint.
    clipp_results[epoch] = {'phrase2img': clipp_p2i, 'img2phrase': clipp_i2p}


epoch130
ClipEncoder ready.
CLIPP ready
1214 remaining imgs
1086 remaining imgs
958 remaining imgs
830 remaining imgs
702 remaining imgs
574 remaining imgs
446 remaining imgs
318 remaining imgs
190 remaining imgs
62 remaining imgs
images encoded
527 remaining texts
399 remaining texts
271 remaining texts
143 remaining texts
15 remaining texts
527 remaining texts
399 remaining texts
271 remaining texts
143 remaining texts
15 remaining texts
phrases encoded
torch.Size([655, 512])
STAT clipp_scores: shape torch.Size([1342, 655]) device cpu; mean 0.002; min 0.000; max 0.307; std 0.004


  recall = np.nan_to_num(pred_num * 1.0 / gt_count)


## retrieve_eval phrase2img on test ##
mean_average_precision: 0.1372
mean_reciprocal_rank: 0.3248
precision_at_005: 0.1795
precision_at_010: 0.1685
precision_at_020: 0.1490
precision_at_050: 0.1122
precision_at_100: 0.0833
query_average_precisions: [skipped]
r_precision: 0.0164
recall_at_005: 0.0617
recall_at_010: 0.1122
recall_at_020: 0.1849
recall_at_050: 0.3198
recall_at_100: 0.4398
## retrieve_eval img2phrase on test ##
mean_average_precision: 0.1189
mean_reciprocal_rank: 0.3589
precision_at_005: 0.1610
precision_at_010: 0.1362
precision_at_020: 0.1121
precision_at_050: 0.0776
precision_at_100: 0.0549
query_average_precisions: [skipped]
r_precision: 0.0164
recall_at_005: 0.0770
recall_at_010: 0.1297
recall_at_020: 0.2114
recall_at_050: 0.3671
recall_at_100: 0.5150

epoch120
ClipEncoder ready.
CLIPP ready
1214 remaining imgs
1086 remaining imgs
958 remaining imgs
830 remaining imgs
702 remaining imgs
574 remaining imgs
446 remaining imgs
318 remaining imgs
190 remaining imgs
62 rem

# Comparison and visualization

In [17]:
# Reload the comparison module before importing from it, so edits to
# utils/retrieval_compare.py take effect without restarting the kernel.
import utils.retrieval_compare as rc
importlib.reload(rc)
from utils.retrieval_compare import compare_pred_to_html

# Ground-truth image/phrase matches for the evaluated split.
gt_matrix = dataset.get_img_phrase_match_matrices(split)

# Thumbnail URL for every image of the split, in split order.
img_html_paths = [
    'https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/%s' % thumb_name
    for thumb_name in dataset.img_splits[split]
]

# [label, score matrix] pairs for the methods to compare side by side.
method_scores = [
    ['CLIP', clip_scores],
    ['DTD2_contrastive', mdtd2_scores],
]

compare_pred_to_html(
    img_path_list=img_html_paths,
    phrase_list=dataset.phrases,
    phrase_weight=dataset.phrase_freq,
    gt_matrix=gt_matrix,
    method_score_list=method_scores,
    output_path='output/retrieve_dtd2',
)

CLIP
image to phrase


  recall = np.nan_to_num(pred_num * 1.0 / gt_count)


mean_average_precision: 0.1220
mean_reciprocal_rank: 0.4006
precision_at_001: 0.2280
precision_at_005: 0.1768
precision_at_010: 0.1473
precision_at_020: 0.1142
precision_at_050: 0.0736
precision_at_100: 0.0519
query_average_precisions: [list, skipped]
r_precision: 0.0164
recall_at_001: 0.0214
recall_at_005: 0.0842
recall_at_010: 0.1397
recall_at_020: 0.2154
recall_at_050: 0.3483
recall_at_100: 0.4889
latex string
mean_average_precision & mean_reciprocal_rank & precision_at_005 & precision_at_020 & recall_at_005 & recall_at_020
12.20 & 40.06 & 17.68 & 11.42 & 8.42 & 21.54
phrase to image
mean_average_precision: 0.1274
mean_reciprocal_rank: 0.3215
precision_at_001: 0.1954
precision_at_005: 0.1695
precision_at_010: 0.1536
precision_at_020: 0.1322
precision_at_050: 0.1005
precision_at_100: 0.0762
query_average_precisions: [list, skipped]
r_precision: 0.0164
recall_at_001: 0.0137
recall_at_005: 0.0616
recall_at_010: 0.1077
recall_at_020: 0.1731
recall_at_050: 0.2982
recall_at_100: 0.4225
la