In [1]:
import torch
import pandas as pd
import os.path as osp
import numpy as np
from ast import literal_eval
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import pdb
import clip 
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
    

In [2]:
## Prepare the evaluation dataset: load the caption sets and resolve image paths.
datasetname = 'COCO'  # one of: ArtEmis, Flickr30K, VizWiz, COCO

# ArtEmis keeps its files one folder deeper (`<name>_IdC`); the other datasets
# share the same layout under ../Dataset/<name>.
dataset_root = f'../Dataset/{datasetname}'
if datasetname == 'ArtEmis':
    dataset_root = f'{dataset_root}/{datasetname}_IdC'
datafile = f'{dataset_root}/{datasetname}_IdCII_3ErrType.csv'
img_dir = f'{dataset_root}/Images/rawImages'

df = pd.read_csv(datafile)
if datasetname == 'ArtEmis':
    # The ArtEmis CSV carries a `split` column; keep only the test rows.
    # (The other CSVs are used whole — presumably test-only; TODO confirm.)
    df = df[df['split'] == 'test']
# Non-inplace reset: yields a fresh frame instead of mutating a filtered slice,
# so the column assignments below cannot raise SettingWithCopyWarning.
df = df.reset_index(drop=True)
print('Number of caption sets in the test set:', len(df))

# The token column is read from CSV as text; literal_eval turns each cell back
# into the nested Python list it was written from.
df['captSet_CLIP_tokens'] = df['captSet_CLIP_tokens'].apply(literal_eval)
# Item assignment (not `df.img_files = ...`): pandas only treats attribute
# assignment as a column write for pre-existing columns, which is fragile.
df['img_files'] = [osp.join(img_dir, imgfile) for imgfile in df['img_files']]

Number of caption sets in the test set: 1699


In [3]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

CLIP_name = 'RN50x16'
# This `clip.load` returns a third value (model hyper-parameters) in addition to
# the model and its preprocessing transform — presumably a project-local fork of
# CLIP; the stock package returns only two values.
CLIPmodel, CLIPtransform, CLIPsettings = clip.load(CLIP_name, device, jit=False)
(
    embed_dim,
    image_resolution,
    vision_layers,
    vision_width,
    vision_patch_size,
    context_length_CLIP,
    vocab_size_CLIP,
    transformer_width,
    transformer_heads,
    transformer_layers,
) = CLIPsettings

In [4]:
def preprocess_dataset(df, img_dim):
    """Wrap a caption-set DataFrame in `Dataset` with CLIP-style image preprocessing.

    Args:
        df: DataFrame with an `img_files` column (image paths) and a
            `captSet_CLIP_tokens` column (one token-id list per row).
        img_dim: side length in pixels that images are resized and
            center-cropped to (previously ignored in favour of the global
            `image_resolution`; callers already pass that value).

    Returns:
        A `Dataset` whose items are dicts with 'image', 'captSet' and 'index'.
    """
    img_transform = Compose([
        Resize(img_dim, interpolation=BICUBIC),
        CenterCrop(img_dim),
        lambda image: image.convert("RGB"),
        ToTensor(),
        # Per-channel mean/std; these match the constants in CLIP's reference
        # preprocessing — keep in sync with CLIPtransform if the model changes.
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    return Dataset(df['img_files'], df['captSet_CLIP_tokens'], img_transform=img_transform)
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each image with its set of tokenised captions.

    The base class is spelled `torch.utils.data.Dataset` rather than the bare
    name `Dataset`, which this class shadows: with the bare name, re-running
    the defining cell would make the class inherit from its own previous
    definition.
    """

    def __init__(self, image_files, captSets, img_transform=None):
        """
        Args:
            image_files: indexable sequence of image paths, or None to skip
                loading images (items then carry an empty list as 'image').
            captSets: indexable sequence; element `index` is converted to an
                int64 numpy array of token ids.
            img_transform: optional callable applied to the RGB PIL image.
        """
        super().__init__()
        self.image_files = image_files
        self.captSets = captSets
        self.img_transform = img_transform

    def __getitem__(self, index):
        # int64 explicitly: `np.long` was removed in NumPy 1.20-1.26 and is a
        # platform-dependent C long in NumPy 2.x; token ids feed torch.LongTensor.
        captSet = np.array(self.captSets[index]).astype(np.int64)
        if self.image_files is None:
            img = []
        else:
            # The context manager closes the file handle; convert() returns a
            # new, fully loaded image (a copy even when already RGB), so it
            # remains valid after the file is closed.
            with Image.open(self.image_files[index]) as raw_img:
                img = raw_img.convert('RGB')
            if self.img_transform is not None:
                img = self.img_transform(img)
        return {'image': img, 'captSet': captSet, 'index': index}

    def __len__(self):
        return len(self.captSets)


In [5]:
# Extract caption scores: cosine similarity between each image and every
# caption in its caption set, as scored by CLIP.
def _caption_similarities(item):
    """Return one item's image-to-caption cosine similarities as a flat list."""
    with torch.no_grad():
        # Add a batch dimension of 1 in front of the single preprocessed image.
        img_emb = CLIPmodel.encode_image(item['image'].unsqueeze(0).to(device))
        # LongTensor is already int64, so no further dtype cast is needed.
        txt_emb = CLIPmodel.encode_text(torch.LongTensor(item['captSet']).to(device))
    # L2-normalise both sides so the matrix product is a cosine similarity.
    img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
    txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
    # One image row against all captions -> drop the batch axis.
    return (img_emb @ txt_emb.T).squeeze(0).tolist()


dataset = preprocess_dataset(df, image_resolution)
targets = []
no_imgs = len(dataset)
print(no_imgs)
similarities = [_caption_similarities(dataset[idx]) for idx in range(no_imgs)]

1699


In [6]:
import numpy as np
def report_accuracy(similarities, datasetname, no_errType=3):
    """Print how often the caption at index 0 outscores each erroneous caption.

    Args:
        similarities: one list of image-caption similarities per caption set;
            entry 0 is compared against entries 1..no_errType. (Entry 0
            presumably holds the correct caption — TODO confirm against the CSV.)
        datasetname: label shown in the printed header.
        no_errType: number of erroneous captions per set.

    Returns:
        dict mapping each errType (int) and 'all' to its accuracy; the value is
        NaN when there is nothing to score (instead of ZeroDivisionError).
    """
    def _accuracy(correct, total):
        # Guard the empty case rather than dividing by zero.
        return correct / total if total else float('nan')

    print("Dataset:", datasetname, ", Number of caption sets:", len(similarities))
    accuracies = {}
    cnt_corr_all = 0
    cnt_total_all = 0
    for errType in range(1, no_errType + 1):
        # Strict '>' : a tie counts as a failure to prefer entry 0.
        cnt_corr = sum(sim[0] > sim[errType] for sim in similarities)
        cnt_total = len(similarities)
        cnt_corr_all += cnt_corr
        cnt_total_all += cnt_total
        accuracies[errType] = _accuracy(cnt_corr, cnt_total)
        print(f"Accuracy at errType={errType}:{cnt_corr}/{cnt_total}=", accuracies[errType])

    accuracies['all'] = _accuracy(cnt_corr_all, cnt_total_all)
    print(f"Accuracy for all types:{cnt_corr_all}/{cnt_total_all}=", accuracies['all'])
    return accuracies


# Assigned (not left as the last expression) so the dict is not echoed twice.
accuracies = report_accuracy(similarities, datasetname)

Dataset: COCO , Number of caption sets: 1699
Accuracy at errType=1:1641/1699= 0.9658622719246616
Accuracy at errType=2:1510/1699= 0.8887580929958799
Accuracy at errType=3:1334/1699= 0.785167745732784
Accuracy for all types:4485/5097= 0.8799293702177752


In [7]:
# Results summary, recorded from separate runs of this notebook with
# `datasetname` set to each dataset in turn (kept as comments so the cell runs).
#
# Dataset: COCO , Number of caption sets: 1699
# Accuracy at errType=1:1641/1699= 0.9658622719246616
# Accuracy at errType=2:1510/1699= 0.8887580929958799
# Accuracy at errType=3:1334/1699= 0.785167745732784
# Accuracy for all types:4485/5097= 0.8799293702177752
#
# Dataset: VizWiz , Number of caption sets: 1160
# Accuracy at errType=1:1005/1160= 0.8663793103448276
# Accuracy at errType=2:1005/1160= 0.8663793103448276
# Accuracy at errType=3:817/1160= 0.7043103448275863
# Accuracy for all types:2827/3480= 0.8123563218390805
#
# Dataset: Flickr30K , Number of caption sets: 595
# Accuracy at errType=1:560/595= 0.9411764705882353
# Accuracy at errType=2:522/595= 0.8773109243697479
# Accuracy at errType=3:420/595= 0.7058823529411765
# Accuracy for all types:1502/1785= 0.8414565826330532
#
# Dataset: ArtEmis , Number of caption sets: 15884
# Accuracy at errType=1:11739/15884= 0.7390455804583228
# Accuracy at errType=2:12090/15884= 0.7611432888441199
# Accuracy at errType=3:9195/15884= 0.5788844119869051
# Accuracy for all types:33024/47652= 0.6930244270964493
#
# NOTE(review): VizWiz reports identical counts for errType=1 and errType=2
# (1005/1160); worth checking that those two error columns actually differ.