In [None]:
import os
#virtually move to parent directory
os.chdir("..")

In [6]:
import json
with open("multimodal/vocab.json", 'r') as f:
    print(list(json.load(f).keys())) 

['<pad>', '<unk>', '<sos>', '<eos>', '.', ',', 'you', '?', 'the', 'yeah', "'s", 'a', 'to', 'it', 'and', 'that', 'we', 'i', 'there', 'is', 'do', 'want', 'are', '"', 'go', 'can', '!', 'on', 'okay', 'in', 'your', 'have', "n't", 'this', 'put', 'oh', 'one', 'here', 'all', 'of', 'what', 'now', 'see', 'some', 'get', 'let', 'right', 'with', 'ok', "'re", '...', 'me', 'going', 'know', 'look', 'like', 'at', 'no', 'where', 'he', 'not', 'alright', 'they', 'for', 'out', 'up', 'so', 'think', 'just', 'if', 'but', 'ball', 'sam', 'little', 'more', '-', 'kitty', 'try', 'good', 's', 'very', 'then', 'done', 'be', 'or', 'my', 'them', 'baby', 'did', "'m", "'ll", 'too', 'again', 'back', 'will', 'down', 'how', 'off', 'na', 'play', 'gon', 'should', 'come', 'read', 'big', 'book', 'was', 'bit', 'got', 'wanna', 'turn', 'need', 'other', 'an', 'about', 'does', 'over', 'those', 'huh', 'well', 'two', 'yea', 'make', 'train', 'time', '..', 'another', 'bye', 'water', 'bear', 'its', 'these', 'take', 'ready', 'which', 'lot

In [84]:
from utils import set_seed
set_seed(42)

In [10]:
import data_utils
import importlib
importlib.reload(data_utils)
device = 'cuda'
model_name = 'clip'
model, transform = data_utils.get_model(model_name, device)
data = data_utils.AnimalDataset("/home/Dataset/xueyi/Animals_with_Attributes2", transform, 'classes.txt',baby_vocab=True, use_attr=True, continuous=True, top_n=5)

Loading CLIP...
Successfully loaded CLIP-ViT-L/14


In [14]:
print(data.class_descriptions)

{2: 'big, furry, strong, brown, cave', 3: 'big, ocean, black, water, fish', 6: 'furry, tail, white, ground, slow', 7: 'fast, big, tail, brown, strong', 9: 'big, water, ocean, blue, strong', 10: 'furry, lean, small, ground, tail', 12: 'small, brown, ground, furry, black', 13: 'stripes, strong, big, ground, fast', 16: 'brown, big, strong, ground, slow', 17: 'tree, tail, hands, furry, fast', 18: 'big, water, ocean, strong, smart', 19: 'big, gray, strong, ground, slow', 22: 'smart, fast, tail, furry, ground', 23: 'white, ground, furry, slow, black', 27: 'tail, furry, small, gray, tree', 29: 'furry, ground, white, small, fast', 31: 'big, lean, ground, yellow, tail', 38: 'stripes, black, white, ground, fast', 42: 'ground, smelly, big, tail, slow', 43: 'strong, big, furry, fast, tail', 44: 'small, tail, ground, white, fast', 45: 'white, big, furry, strong, fish', 49: 'big, tail, brown, ground, white'}


In [88]:
import torch
from tqdm import tqdm
import clip

class ZeroShotClassifier:
    def __init__(self, model_name, model, device):
        self.model_name = model_name
        self.model = model
        self.device = device
        self.model.eval()

    def get_txt_feature(self, clean_cls_name, cls_desc=None, prefix="", use_attr=False):
        if use_attr:
            if cls_desc is None:
                raise ValueError("cls_desc must be provided when use_attr is True.")            
            combined_texts = [f"{prefix}{name}, {desc}" for name, desc in zip(clean_cls_name, cls_desc.values())]
            if "cvcl" in self.model_name:
                print(combined_texts)  # 打印以验证正确性
                text_tokens = [self.model.tokenize(text) for text in combined_texts]
            elif "clip" in self.model_name:
                print(combined_texts)  # 打印以验证正确性
                text_inputs = clip.tokenize(combined_texts).to(self.device)
        else:
            texts = [f"{prefix}{c}" for c in clean_cls_name]
            if "cvcl" in self.model_name:
                text_tokens = [self.model.tokenize(c) for c in texts]
            elif "clip" in self.model_name:
                text_inputs = clip.tokenize(texts).to(self.device)

        if "cvcl" in self.model_name:
            text_inputs = torch.cat([txt[0] for txt in text_tokens]).to(self.device)
            text_lens = torch.cat([txt[1] for txt in text_tokens]).to(self.device)
            text_features = self.model.encode_text(text_inputs, text_lens)
        elif "clip" in self.model_name:
            text_features = self.model.encode_text(text_inputs)

        normalized_text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return normalized_text_features
        
    def get_img_feature(self, dataloader):
        img_features_list = []
        all_labels_list = []
        for data in tqdm(dataloader, desc="Encoding Images"):
            img, label = data[:2] # can return more than 2 items
            img = img.to(self.device)
            img_features = self.model.encode_image(img)
            img_features_list.append(img_features)
            all_labels_list.append(label.to(self.device)) 
        img_features_tensor = torch.cat(img_features_list)
        norm_img_features = img_features_tensor / img_features_tensor.norm(dim=-1, keepdim=True)
        all_labels_tensor = torch.cat(all_labels_list)
        return norm_img_features, all_labels_tensor
    
    def compute_similarity(self, img_features, text_features):
        similarity = (100.0 * img_features @ text_features.T).softmax(dim=-1)
        return similarity

    def predict_labels(self, similarity, index_map):
        preds = similarity.argmax(dim=-1)  # 这里获取的是相对于文本特征索引的最大值索引
        # 使用从数据集中获取的映射来转换预测的索引
        mapped_preds = [index_map[pred.item()] for pred in preds]
        return torch.tensor(mapped_preds, device=preds.device)
    
    def predict(self, dataloader, prefix=None, use_attr=False):
        with torch.no_grad():
            text_features = self.get_txt_feature(dataloader.dataset.clean_cls_names, dataloader.dataset.class_descriptions, prefix, use_attr)
            img_features, all_labels = self.get_img_feature(dataloader)
            similarity = self.compute_similarity(img_features, text_features)
            index_map = dataloader.dataset.index_map
            all_preds = self.predict_labels(similarity, index_map)  # Pass the full_class_indices here
            similarities = similarity.max(dim=1)[0].cpu().numpy()  # max value of each row
        return similarities, all_preds, all_labels

In [89]:
classifier = ZeroShotClassifier(model_name, model, device)
print(data.classes)
cls_desc = data.class_descriptions
print(cls_desc)

    class_index class_name
0             7      horse
1            12       mole
2            13      tiger
3            16      moose
4            19   elephant
5            22        fox
6            23      sheep
7            27   squirrel
8            29     rabbit
9            31    giraffe
10           38      zebra
11           42        pig
12           43       lion
13           44      mouse
14           49        cow
{7: 'fast, big, tail, brown, strong, ground, lean, black, white, furry, smart', 12: 'small, brown, ground, furry, black, gray, slow, lean, fast, tail, smart', 13: 'stripes, strong, big, ground, fast, orange, tail, furry, lean, bush, black', 16: 'brown, big, strong, ground, slow, furry, tail, smelly, fast, smart, gray', 19: 'big, gray, strong, ground, slow, tail, smelly, bush, smart, brown, white', 22: 'smart, fast, tail, furry, ground, small, brown, lean, orange, red, fish', 23: 'white, ground, furry, slow, black, smelly, gray, small, bush, big, tail', 27: 'tail

In [90]:
dataloader = torch.utils.data.DataLoader(data, batch_size=512, shuffle=False, num_workers=4)
similarities, predictions, labels = classifier.predict(dataloader, prefix="", use_attr=True)
print(predictions)
print(labels)

['horse, fast, big, tail, brown, strong, ground, lean, black, white, furry, smart', 'mole, small, brown, ground, furry, black, gray, slow, lean, fast, tail, smart', 'tiger, stripes, strong, big, ground, fast, orange, tail, furry, lean, bush, black', 'moose, brown, big, strong, ground, slow, furry, tail, smelly, fast, smart, gray', 'elephant, big, gray, strong, ground, slow, tail, smelly, bush, smart, brown, white', 'fox, smart, fast, tail, furry, ground, small, brown, lean, orange, red, fish', 'sheep, white, ground, furry, slow, black, smelly, gray, small, bush, big, tail', 'squirrel, tail, furry, small, gray, tree, brown, fast, ground, lean, smart, hands', 'rabbit, furry, ground, white, small, fast, tail, brown, gray, black, bush, smelly', 'giraffe, big, lean, ground, yellow, tail, bush, brown, fast, strong, slow, orange', 'zebra, stripes, black, white, ground, fast, tail, big, bush, lean, furry, strong', 'pig, ground, smelly, big, tail, slow, brown, strong, gray, white, smart, black'

Encoding Images: 100%|██████████| 29/29 [01:35<00:00,  3.31s/it]


tensor([16,  7, 16,  ..., 49, 49, 49], device='cuda:0')
tensor([ 7,  7,  7,  ..., 49, 49, 49], device='cuda:0')


In [91]:
print(data.index_map)
print(data.classes)
print(data.clean_cls_names)

{0: 7, 1: 12, 2: 13, 3: 16, 4: 19, 5: 22, 6: 23, 7: 27, 8: 29, 9: 31, 10: 38, 11: 42, 12: 43, 13: 44, 14: 49}
    class_index class_name
0             7      horse
1            12       mole
2            13      tiger
3            16      moose
4            19   elephant
5            22        fox
6            23      sheep
7            27   squirrel
8            29     rabbit
9            31    giraffe
10           38      zebra
11           42        pig
12           43       lion
13           44      mouse
14           49        cow
['horse', 'mole', 'tiger', 'moose', 'elephant', 'fox', 'sheep', 'squirrel', 'rabbit', 'giraffe', 'zebra', 'pig', 'lion', 'mouse', 'cow']


In [92]:
from utils import calculate_accuracy
acc, _ = calculate_accuracy(predictions, labels)
print(f"Accuracy: {acc * 100:.2f}%")

Accuracy: 96.04%


Top N?

In [93]:
model, _  = get_model('cvcl_res', device)
description = "horse,fast,big,tail,brown,strong,ground,lean,smart,furry,white,black"
tokens, token_lengths = model.tokenize(description)
print(f"Tokens: {tokens}")
print(f"Token lengths: {token_lengths}")

Loading CVCL...


Lightning automatically upgraded your loaded checkpoint from v1.5.8 to v2.2.1. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../.cache/huggingface/hub/models--wkvong--cvcl_s_dino_resnext50_embedding/snapshots/fe96aa69683bad69e5dd5195fc874a3edb8cb691/cvcl_s_dino_resnext50_embedding.ckpt`


Successfully loaded CVCL-resnext50
Tokens: tensor([[   2,  381,    5,  531,    5,  104,    5,  720,    5,  290,    5, 1552,
            5,  795,    5, 1259,    5, 1915,    5,  998,    5,  416,    5,  450,
            3]])
Token lengths: tensor([25])


/home/xke001/miniconda3/envs/cvcl/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'vision_encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['vision_encoder'])`.
/home/xke001/miniconda3/envs/cvcl/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'text_encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['text_encoder'])`.


In [94]:
print(data.attribute_file)

    attribute_index attribute_name
0                 1          black
1                 2          white
2                 3           blue
3                 4          brown
4                 5           gray
5                 6         orange
6                 7            red
7                 8         yellow
10               11        stripes
11               12          furry
14               15            big
15               16          small
17               18           lean
19               20          hands
25               26           tail
33               34         smelly
39               40           fast
40               41           slow
41               42         strong
51               52           fish
67               68           bush
73               74          ocean
74               75         ground
75               76          water
76               77           tree
77               78           cave
80               81          smart


In [95]:
class_index = 7  # 1-based index, 用于选取类似“horse”这样的类

data = AnimalDataset("/home/Dataset/xueyi/Animals_with_Attributes2", transform, 'classes.txt', True, True, True)
attribute_vector = data.attribute_matrix[class_index - 1]

# 获取属性值的索引，这次只取最高的8个或10个值
top_indices = np.argsort(attribute_vector)[-11:]  # 可以修改这里的数值为-8或-10
top_values = attribute_vector[top_indices]
# 确保属性名称正确地对应于这些索引
top_attribute_names = [data.attribute_file['attribute_name'].iloc[i] for i in top_indices]

# 打印结果，确保以降序打印属性
print("Top Attributes and their values:")
for name, value in zip(top_attribute_names, top_values):
    print(f"{name}: {value}")

Top Attributes and their values:
smart: 37.28
furry: 40.58
white: 42.91
black: 44.9
lean: 47.96
ground: 56.52
strong: 69.13
brown: 69.41
tail: 70.42
big: 71.5
fast: 81.68


more attri more noise? top11=>Accuracy: 6.84% top7=>8.64% 
no attri: 8+
no baby: 4+
