In [31]:
def get_metrics(image_features, text_features, labels, logit_scale):
    """Compute, print and return Recall@{1,5,10} for image->text and text->image retrieval.

    Row i of ``image_features`` and row i of ``text_features`` form sample i, and
    ``labels[i]`` identifies the image of that sample. Several samples may share a
    label (one image with several captions):

    * image->text: every caption whose label equals the query's label counts as a hit.
    * text->image: every repeated image column (each occurrence of a label after its
      first) is pushed to the bottom of the ranking, so each image is ranked once.

    Args:
        image_features: 2-D tensor, one image embedding per row (presumably
            L2-normalized by the caller).
        text_features: 2-D tensor with the same number of rows as ``image_features``.
        labels: sequence of hashable image ids, one per row.
        logit_scale: scalar multiplied into the similarity matrix.

    Returns:
        dict mapping ``"image_to_text"`` and ``"text_to_image"`` to
        ``{"r1": float, "r5": float, "r10": float}``. Each value is the fraction of
        query rows with a ground-truth hit among their top-k candidates.

    Raises:
        ValueError: if the row counts of the two feature tensors and ``labels``
            differ, or if there are no samples.
    """
    num_images, num_texts = len(image_features), len(text_features)
    if not (num_images == num_texts == len(labels)):
        raise ValueError(
            f'expected one label per row, got {num_images} image rows, '
            f'{num_texts} text rows and {len(labels)} labels')
    if num_images == 0:
        raise ValueError('get_metrics needs at least one sample')

    logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
    # clone(): .t() is a view and .cpu() is a no-op on a CPU tensor. Without a copy,
    # the in-place masking below would also overwrite logits_per_image.
    logits_per_text = logits_per_image.t().clone()

    logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}

    # Map every label to all row indices that carry it, and remember rows whose
    # label was already seen (repeated copies of the same image).
    label2idx = {}
    repeat_id = []
    for i, label in enumerate(labels):
        if label in label2idx:
            repeat_id.append(i)
        label2idx.setdefault(label, []).append(i)
    ground_truth = [set(label2idx[label]) for label in labels]

    metrics = {}
    for name, logit in logits.items():
        if name == 'text_to_image' and repeat_id:
            # Sink repeated image columns so every image is retrieved at most once.
            logit[:, repeat_id] -= 1e8
        # Candidate indices of the 10 highest scores per query row, best first.
        top10 = torch.argsort(logit, dim=-1, descending=True)[:, :10].tolist()
        hits = {1: 0, 5: 0, 10: 0}
        for gt, candidates in zip(ground_truth, top10):
            first_hit = next((rank for rank, c in enumerate(candidates) if c in gt), None)
            if first_hit is None:
                continue
            for k in hits:
                if first_hit < k:
                    hits[k] += 1
        num_queries = len(logit)
        recall = {f'r{k}': hits[k] / num_queries for k in hits}
        metrics[name] = recall
        print(f"{name} r1:{recall['r1']}, r5:{recall['r5']}, r10:{recall['r10']}")
    return metrics


# COCO-CN (test split)

In [33]:
from transformers import BertTokenizer
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
class COCO_CN(Dataset):
    """COCO-CN test split as a map-style dataset of (image, caption token ids, image key).

    One sample is created per caption line in ``annot_path`` whose image key is listed
    in ``test_img_path``. An image with several captions therefore yields several
    samples that share the same label (its key).

    Args:
        img_root_path: directory that contains the ``train2014`` / ``val2014`` image folders.
        test_img_path: text file with one test-set image key per line.
        annot_path: tab-separated caption file; column 1 is ``<image key>#<suffix>``,
            column 2 is the caption.
        transform: callable applied to the opened PIL image in ``__getitem__``;
            when None the opened image is returned unchanged.

    Raises:
        ValueError: if a non-blank annotation line has fewer than two tab-separated fields.
    """

    def __init__(self, img_root_path='/home/chenweifeng/dataset/coco',
                 test_img_path='/home/chenweifeng/dataset/coco/coco-cn-version1805v1.1/coco-cn_test.txt',
                 annot_path='/home/chenweifeng/dataset/coco/coco-cn-version1805v1.1/imageid.human-written-caption.txt',
                 transform=None):
        import os  # `os` is never imported at the top level of this notebook

        self.images = []
        self.captions = []
        self.labels = []
        self.root = img_root_path

        # Read as UTF-8 explicitly: the captions are Chinese and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open(test_img_path, 'r', encoding='utf-8') as f:
            test_keys = {line.strip() for line in f}

        with open(annot_path, 'r', encoding='utf-8') as f:
            for lineno, raw_line in enumerate(f, start=1):
                fields = raw_line.strip().split('\t')
                if fields == ['']:
                    continue  # blank line
                if len(fields) < 2:
                    raise ValueError(
                        f'{annot_path}:{lineno}: expected "<key>#<n><TAB><caption>", got {raw_line!r}')
                key, caption = fields[0].split('#')[0], fields[1]
                # Keep only captions of test-set images. (To keep a single caption per
                # image, additionally skip lines whose '#' suffix is not '0'.)
                if key not in test_keys:
                    continue
                # Keys containing 'train' live under train2014/, all others under val2014/.
                split_dir = 'train2014' if 'train' in key else 'val2014'
                self.images.append(os.path.join(split_dir, key + '.jpg'))
                self.captions.append(caption)
                self.labels.append(key)

        self.transforms = transform
        self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

        # Captions are padded / truncated to this many tokens (original note: "large model").
        self.context_length = 77

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return (image, caption input_ids of length context_length, image key) for sample ``idx``."""
        import os

        img_path = str(self.images[idx])
        image = Image.open(os.path.join(self.root, img_path))
        # The constructor default is transform=None: return the raw PIL image in that
        # case instead of failing with "'NoneType' object is not callable".
        if self.transforms is not None:
            image = self.transforms(image)
        text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length, padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0]
        label = self.labels[idx]
        return image, text, label

from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
    CenterCrop
def _convert_to_rgb(image):
    return image.convert('RGB')

def image_transform(image_size: int, is_train: bool,
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711)):
    """Build the image preprocessing pipeline.

    Args:
        image_size: output side length in pixels.
        is_train: True -> RandomResizedCrop augmentation (scale 0.9-1.0, bicubic);
            False -> deterministic bicubic Resize followed by CenterCrop.
        mean, std: per-channel statistics passed to Normalize.

    Returns:
        A ``Compose`` of: geometric step(s), RGB conversion, ToTensor, Normalize.
    """
    if is_train:
        resize_steps = [
            RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
        ]
    else:
        resize_steps = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]
    tensor_steps = [_convert_to_rgb, ToTensor(), Normalize(mean=mean, std=std)]
    return Compose(resize_steps + tensor_steps)

# Deterministic evaluation preprocessing (resize + center-crop to 224) over the COCO-CN test
# captions, using the dataset's default paths; shuffle=False keeps a fixed sample order.
val_transform = image_transform(image_size=224, is_train=False)
dataset = COCO_CN(transform=val_transform)
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=False, num_workers=4)

In [23]:
# Number of (image, caption) samples in the COCO-CN test dataset: one per test-set caption line.
# NOTE(review): the saved output below (1000) comes from In [23], which ran before the cell that
# built the current `dataset` (In [33]); the feature-extraction cell later reports 1053 samples.
# Re-run this cell to refresh the output.
len(dataset)

1000

In [36]:
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from transformers import CLIPModel
import torch
# Load the base Taiyi text encoder paired with CLIP ViT-B/32, both in eval mode on GPU.
text_encoder = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese").cuda().eval()
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").cuda().eval()

# Large variant (swap in to evaluate the 326M text encoder with ViT-L/14):
# text_encoder = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese").cuda().eval()
# clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").cuda().eval()

# Encode the whole test set batch by batch. Each row is L2-normalized, so the
# `image_features @ text_features.t()` product inside get_metrics is a cosine similarity.
all_img_features, all_text_features, all_labels = [], [], []
with torch.no_grad():
    for batch_images, batch_captions, batch_labels in dataloader:
        all_labels.extend(batch_labels)

        img_emb = clip_model.get_image_features(batch_images.cuda())
        all_img_features.append(img_emb / img_emb.norm(dim=1, keepdim=True))

        txt_emb = text_encoder(batch_captions.cuda()).logits
        all_text_features.append(txt_emb / txt_emb.norm(dim=1, keepdim=True))

img_features = torch.cat(all_img_features)
text_features = torch.cat(all_text_features)
print(img_features.shape, text_features.shape, len(all_labels))

torch.Size([1053, 512]) torch.Size([1053, 512]) 1053


In [37]:
# Recall@{1,5,10} on COCO-CN with CLIP's logit scale of 100.
get_metrics(image_features=img_features, text_features=text_features, labels=all_labels, logit_scale=100)

image_to_text r1:0.5289648622981956, r5:0.8110161443494777, r10:0.8983855650522318
text_to_image r1:0.4624881291547958, r5:0.7806267806267806, r10:0.8888888888888888


# Flickr30k-CNA (test split)

In [26]:
from transformers import BertTokenizer
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
class flickr30k_CNA(Dataset):
    """Flickr30k-CNA test split as a map-style dataset of (image, caption token ids, image key).

    Every non-blank line of ``text_annot_path`` becomes one sample. An image with
    several captions therefore yields several samples sharing the same label (its key).

    Args:
        img_root_path: directory containing the ``<key>.jpg`` image files.
        text_annot_path: tab-separated caption file; column 1 is ``<image key>#<suffix>``,
            column 2 is the caption.
        transform: callable applied to the opened PIL image in ``__getitem__``;
            when None the opened image is returned unchanged.

    Raises:
        ValueError: if a non-blank annotation line has fewer than two tab-separated fields.
    """

    def __init__(self, img_root_path='/home/chenweifeng/dataset/mm_data/Flickr30k-CNA/flickr30k/images',
                 text_annot_path='/home/chenweifeng/dataset/mm_data/Flickr30k-CNA/test/flickr30k_cn_test.txt',
                 transform=None):
        self.images = []
        self.captions = []
        self.labels = []
        self.root = img_root_path
        # Read as UTF-8 explicitly: the captions are Chinese and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open(text_annot_path, 'r', encoding='utf-8') as f:
            for lineno, raw_line in enumerate(f, start=1):
                fields = raw_line.strip().split('\t')
                if fields == ['']:
                    continue  # blank line
                if len(fields) < 2:
                    raise ValueError(
                        f'{text_annot_path}:{lineno}: expected "<key>#<n><TAB><caption>", got {raw_line!r}')
                key, caption = fields[0].split('#')[0], fields[1]
                self.images.append(key + '.jpg')
                self.captions.append(caption)
                self.labels.append(key)
        self.transforms = transform
        self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

        # Captions are padded / truncated to this many tokens (original note: "large model").
        self.context_length = 77

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return (image, caption input_ids of length context_length, image key) for sample ``idx``."""
        import os  # `os` is never imported at the top level of this notebook

        img_path = str(self.images[idx])
        image = Image.open(os.path.join(self.root, img_path))
        # The constructor default is transform=None: return the raw PIL image in that
        # case instead of failing with "'NoneType' object is not callable".
        if self.transforms is not None:
            image = self.transforms(image)
        text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length, padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0]
        label = self.labels[idx]
        return image, text, label

from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
    CenterCrop
def _convert_to_rgb(image):
    return image.convert('RGB')

def image_transform(image_size: int, is_train: bool,
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711)):
    """Build the image preprocessing pipeline (re-defined so this cell runs on its own).

    Args:
        image_size: output side length in pixels.
        is_train: True -> RandomResizedCrop augmentation (scale 0.9-1.0, bicubic);
            False -> deterministic bicubic Resize followed by CenterCrop.
        mean, std: per-channel statistics passed to Normalize.

    Returns:
        A ``Compose`` of: geometric step(s), RGB conversion, ToTensor, Normalize.
    """
    if is_train:
        resize_steps = [
            RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
        ]
    else:
        resize_steps = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]
    tensor_steps = [_convert_to_rgb, ToTensor(), Normalize(mean=mean, std=std)]
    return Compose(resize_steps + tensor_steps)

# Deterministic evaluation preprocessing over the Flickr30k-CNA test captions;
# shuffle=False keeps a fixed sample order.
img_root = '/home/chenweifeng/dataset/mm_data/Flickr30k-CNA/flickr30k/images'
text_annot_path = '/home/chenweifeng/dataset/mm_data/Flickr30k-CNA/test/flickr30k_cn_test.txt'
val_transform = image_transform(image_size=224, is_train=False)
dataset = flickr30k_CNA(img_root_path=img_root, text_annot_path=text_annot_path, transform=val_transform)
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=False, num_workers=4)

In [27]:
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from transformers import CLIPModel
import torch
# Large variant: Taiyi-CLIP-Roberta-large-326M text encoder + CLIP ViT-L/14 image tower,
# both in eval mode on GPU. Base variant (swap in to compare):
#   BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese")
#   CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = (
    BertForSequenceClassification
    .from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese")
    .cuda()
    .eval()
)
clip_model = (
    CLIPModel
    .from_pretrained("openai/clip-vit-large-patch14")
    .cuda()
    .eval()
)

In [28]:
# Encode the whole test set batch by batch. Each row is L2-normalized, so the
# `image_features @ text_features.t()` product inside get_metrics is a cosine similarity.
all_img_features, all_text_features, all_labels = [], [], []
with torch.no_grad():
    for batch_images, batch_captions, batch_labels in dataloader:
        all_labels.extend(batch_labels)

        img_emb = clip_model.get_image_features(batch_images.cuda())
        all_img_features.append(img_emb / img_emb.norm(dim=1, keepdim=True))

        txt_emb = text_encoder(batch_captions.cuda()).logits
        all_text_features.append(txt_emb / txt_emb.norm(dim=1, keepdim=True))

img_features = torch.cat(all_img_features)
text_features = torch.cat(all_text_features)
print(img_features.shape, text_features.shape, len(all_labels))

torch.Size([5000, 768]) torch.Size([5000, 768]) 5000


In [32]:
# Flickr30k-CNA: each image has 5 captions, so every image appears 5 times as a row.
# All rows are used: text->image masks the repeated image columns, and image->text averages
# over all 5000 rows, which equals the per-image average when the 5 copies of an image get
# identical features.
get_metrics(image_features=img_features, text_features=text_features, labels=all_labels, logit_scale=100)

image_to_text r1:0.659, r5:0.903, r10:0.957
text_to_image r1:0.5108, r5:0.782, r10:0.8594
