In [1]:
import torch
import pandas as pd
import os.path as osp
import numpy as np
from ast import literal_eval
from model.AdapterCLIPMetric_3MLP import adapterCLIP as adapterCLIP_model
from model.func_train_v2 import load_state_dicts
import clip
from artemis.in_out.neural_net_oriented import read_saved_args
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
# PIL must be imported before the try/except: the fallback branch references Image.BICUBIC.
from PIL import Image

try:
    # Newer torchvision releases expose an InterpolationMode enum for transforms.
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    # Older torchvision versions accept PIL's interpolation constants directly.
    BICUBIC = Image.BICUBIC
from model.argument import set_seed

In [2]:
# Experiment configuration: which CLIP backbone / freeze setting to evaluate.
CLIP_name = 'RN50x16'  # backbone name passed to clip.load (e.g. RN50, RN101, RN50x4, RN50x16)
FreezeCase = 1
runRawImg = True  # only raw-image input is supported by the transform cell below
no_ErrorTypes = 3  # number of caption error types (matches the *_3ErrType dataset file)
output_dir = f"output/adapterCLIP_3MLP/{CLIP_name}_F{FreezeCase}"


In [3]:
# Locations of the trained checkpoint and of the saved training arguments.
model_dir = osp.join(output_dir, 'checkpoints')
arg_file = osp.join(output_dir, 'config.json.txt')
saved_model_file = osp.join(model_dir, 'best_model.pt')

In [4]:
# Load the pretrained CLIP model on CPU together with its architecture hyper-parameters.
origCLIP, CLIPtransform, CLIPsettings = clip.load(CLIP_name, "cpu")
(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
 context_length_CLIP, vocab_size_CLIP,
 transformer_width, transformer_heads, transformer_layers) = CLIPsettings

# Only raw (PIL) images are supported as input.
if not runRawImg:
    raise ValueError("Do not support runRawImg != True!")

# Resize/crop to the backbone resolution and normalize with the mean/std used by CLIP's own preprocessing.
img_transform = Compose([
    Resize(image_resolution, interpolation=BICUBIC),
    CenterCrop(image_resolution),
    lambda image: image.convert("RGB"),
    ToTensor(),
    Normalize((0.48145466, 0.4578275, 0.40821073),
              (0.26862954, 0.26130258, 0.27577711)),
])

# Adapter model sharing the loaded backbone's architecture settings.
adapterCLIP = adapterCLIP_model(
    embed_dim, image_resolution, vision_layers, vision_width,
    vision_patch_size, context_length_CLIP,
    vocab_size_CLIP, transformer_width, transformer_heads, transformer_layers,
)

In [5]:
# Copy the weights of the original CLIP into the adapter model.
def copy_params(model_init, model):
    """Load every state-dict entry of ``model_init`` into ``model`` (non-strict).

    Keys present in only one of the two models are ignored. Parameters that
    exist only in ``model`` (e.g. adapter layers) keep their current values.
    Entries with matching keys but different shapes still raise in PyTorch.

    Returns:
        True, unconditionally.
    """
    source_weights = model_init.state_dict()
    model.load_state_dict(source_weights, strict=False)
    return True

copy_params(model_init=origCLIP, model=adapterCLIP)

True

In [6]:
# Restore the training-time arguments and the best checkpoint's weights (mapped to CPU).
args = read_saved_args(arg_file)
loaded_epoch = load_state_dicts(saved_model_file, model=adapterCLIP, map_location='cpu')
print("load model at epoch=", loaded_epoch)

load model at epoch= 3


In [7]:
## Load the caption-set CSV and keep only its test split.
datasetname = 'ArtEmis'  # ArtEmis, Flickr30K, VizWiz, COCO
assert datasetname == 'ArtEmis'  ## Only works with ArtEmis
if datasetname == 'ArtEmis':
    datafile = f'../Dataset/{datasetname}/{datasetname}_IdC/{datasetname}_IdCII_3ErrType.csv'
    img_dir = '../Dataset/ArtEmis/OriginalArtEmis/wikiart/'
else:
    datafile = f'../Dataset/{datasetname}/{datasetname}_IdCII_3ErrType.csv'
    img_dir = f"../Dataset/{datasetname}/Images/rawImages"

df = pd.read_csv(datafile)
# Only the 'test' split is evaluated downstream; re-index it from 0.
df = df[df['split'] == 'test'].reset_index(drop=True)
print('Number of caption sets in the test set:', len(df))
# Use item assignment (not attribute assignment) to update DataFrame columns.
df['img_files'] = [osp.join(img_dir, imgfile) for imgfile in df['img_files']]
# Token lists are stored as their string representation in the CSV; parse them back.
df['captSet_CLIP_tokens'] = df['captSet_CLIP_tokens'].apply(literal_eval)

Number of caption sets in the test set: 15884


In [8]:
class IdCIIDataset(Dataset):
    """Map-style dataset yielding one image together with its caption set.

    Args:
        image_files: indexable sequence of image paths (in this notebook a
            pandas Series re-indexed from 0), or None to skip image loading.
        captsets: indexable sequence; element ``i`` is the token data of the
            caption set belonging to image ``i`` (converted with ``np.array``).
        img_transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, image_files, captsets, img_transform=None):
        super().__init__()
        self.image_files = image_files
        self.captsets = captsets
        self.img_transform = img_transform
        print(img_transform)

    def __getitem__(self, index):
        # np.int64 explicitly: np.long is missing in NumPy 1.24-1.26 and is a
        # 32-bit C long on Windows under NumPy 2; torch.LongTensor wants int64.
        captset = np.array(self.captsets[index]).astype(dtype=np.int64)
        if self.image_files is None:
            image_file = None
            img = []
        else:
            image_file = self.image_files[index]
            # convert() loads the pixel data and returns a new image, so the
            # file handle can be closed here instead of leaking.
            with Image.open(image_file) as opened:
                img = opened.convert('RGB')
            if self.img_transform is not None:
                img = self.img_transform(img)
        return {'image_file': image_file, 'image': img, 'captset': captset, 'index': index}

    def __len__(self):
        # captsets is always provided, whereas image_files may be None.
        return len(self.captsets)

def preprocess_Dataset(df, img_transform, batch_size=None, random_seed=None):
    """Build one IdCIIDataset and one DataLoader per value of ``df['split']``.

    Args:
        df: frame with at least ``split``, ``img_files`` and
            ``captSet_CLIP_tokens`` columns.
        img_transform: callable applied to every loaded image.
        batch_size: batch size for 'train'/'val' loaders; when None, falls back
            to ``args.batch_size`` (global read from the saved config). It is
            only looked up if such a split exists.
        random_seed: seed passed to ``set_seed``; when None, falls back to
            ``args.random_seed``.

    Returns:
        (dataloaders, datasets): two dicts keyed by split name. Only the
        'train' loader shuffles; splits other than 'train'/'val' use batch size 1.
    """
    set_seed(args.random_seed if random_seed is None else random_seed)

    datasets = dict()
    for split, g in df.groupby('split'):
        # Re-index each split from 0 so IdCIIDataset can look rows up by position.
        g = g.reset_index(drop=True)
        img_files = g.img_files
        img_files.name = 'image_files'
        datasets[split] = IdCIIDataset(img_files, g.captSet_CLIP_tokens, img_transform=img_transform)

    dataloaders = dict()
    for split, split_dataset in datasets.items():
        if split in ('train', 'val'):
            b_size = args.batch_size if batch_size is None else batch_size
        else:
            b_size = 1
        dataloaders[split] = DataLoader(dataset=split_dataset,
                                        batch_size=b_size,
                                        shuffle=(split == 'train'))
    return dataloaders, datasets
dataloaders, datasets = preprocess_Dataset(df=df, img_transform=img_transform)
# Only the test split is evaluated below.
dataset = datasets['test']

Compose(
    Resize(size=384, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(384, 384))
    <function <lambda> at 0x7f3e545fc598>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)


In [9]:
# Per-image score rows, filled by the inference loop below.
scores, ita_scores, g_scores = [], [], []
no_imgs = len(dataset)
print(no_imgs)

15884


In [None]:
# Score every test image with the adapter model (inference only).
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
adapterCLIP.to(device)
adapterCLIP.eval()
# no_grad: no autograd graph is needed for evaluation, which saves memory and time.
with torch.no_grad():
    for i in range(no_imgs):
        if i % 1000 == 0:  # sparse progress log instead of one line per image
            print(i)
        data = dataset[i]
        image_inputs = torch.unsqueeze(data['image'], 0)
        text_inputs = torch.LongTensor(data['captset'])
        ita1_score_per_image, _, ita2_score_per_image, _, g_score, score = adapterCLIP(
            image_inputs.to(device), text_inputs.to(device))
        # Average the two ITA scores returned by the model.
        ita_score_per_image = (ita1_score_per_image + ita2_score_per_image) / 2
        scores.append(score.squeeze(0).tolist())
        ita_scores.append(ita_score_per_image.squeeze(0).tolist())
        g_scores.append(g_score.squeeze(0).tolist())

In [11]:
# Debug check: element-wise gap between the two ITA scores of the last processed image.
torch.sub(ita1_score_per_image, ita2_score_per_image)

tensor([[-0.1713,  0.0059, -0.0076, -0.1482]], device='cuda:0',
       grad_fn=<SubBackward0>)

In [12]:
# Persist the per-image ITA and G score rows into output_dir (reload with np.load).
np.save(osp.join(output_dir, 'ita.npy'), np.array(ita_scores))
np.save(osp.join(output_dir, 'g.npy'), np.array(g_scores))
print(osp.join(output_dir, 'g.npy'))

output/adapterCLIP_3MLP/RN50x16_F1/g.npy


In [13]:
# Pairwise accuracy: for error type k, an image counts as correct when its
# reference caption (column 0) gets a strictly higher combined score than the
# caption with error type k (column k). Combined score = ita_weight*ITA - g_weight*G.
no_errType = no_ErrorTypes


def count_pairwise_correct(ita_scores, g_scores, ita_weight, g_weight, n_err_types=3):
    """Count, per error type, images whose reference caption outscores the erroneous one.

    Args:
        ita_scores, g_scores: per-image score rows; column 0 is the reference
            caption, columns 1..n_err_types the erroneous ones (assumes all rows
            have the same length — TODO confirm).
        ita_weight, g_weight: coefficients of ``ita_weight*ita - g_weight*g``.
        n_err_types: number of error-type columns to evaluate.

    Returns:
        (list of int correct counts for error types 1..n_err_types, number of images).

    Raises:
        ValueError: if the score rows have fewer than ``n_err_types + 1`` columns.
    """
    ita = np.array(ita_scores, dtype=float)
    g = np.array(g_scores, dtype=float)
    if ita.shape[0] == 0:
        return [0] * n_err_types, 0
    combined = ita_weight * ita - g_weight * g
    if combined.ndim != 2 or combined.shape[1] <= n_err_types:
        raise ValueError(f"expected score rows with at least {n_err_types + 1} columns, "
                         f"got shape {combined.shape}")
    correct = combined[:, :1] > combined[:, 1:n_err_types + 1]
    return [int(c) for c in correct.sum(axis=0)], combined.shape[0]


def report_pairwise_accuracy(ita_scores, g_scores, ita_weight, g_weight, n_err_types=3):
    """Print per-error-type and overall pairwise accuracy; return the overall accuracy."""
    per_type, n_images = count_pairwise_correct(ita_scores, g_scores, ita_weight, g_weight, n_err_types)
    for err_type, n_correct in enumerate(per_type, start=1):
        acc = n_correct / n_images if n_images else float('nan')
        print(f"Accuracy at errType={err_type}:{n_correct}/{n_images}=", acc)
    total_correct = sum(per_type)
    total = n_images * n_err_types
    acc_all = total_correct / total if total else float('nan')
    print(f"Accuracy for all types:{total_correct}/{total}=", acc_all)
    return acc_all


# Baseline: rank by the G score alone (score = -G, i.e. ITA weight 0).
print("beta=0")
report_pairwise_accuracy(ita_scores, g_scores, ita_weight=0, g_weight=1, n_err_types=no_errType)

# Sweep alpha over 0.0, 0.1, ..., 1.0: score = ITA - alpha * G.
for alpha in range(0, 11):
    alpha_ = alpha / 10
    print("alpha=", alpha_)
    report_pairwise_accuracy(ita_scores, g_scores, ita_weight=1, g_weight=alpha_, n_err_types=no_errType)


beta=0
Accuracy at errType=1:11582/15884= 0.7291614202971544
Accuracy at errType=2:9111/15884= 0.5735960715185092
Accuracy at errType=3:14568/15884= 0.917149332661798
Accuracy for all types:35261/47652= 0.7399689414924872
alpha= 0.0
Accuracy at errType=1:14435/15884= 0.9087761269201713
Accuracy at errType=2:12283/15884= 0.7732938806346008
Accuracy at errType=3:12520/15884= 0.7882145555275749
Accuracy for all types:39238/47652= 0.8234281876941156
alpha= 0.1
Accuracy at errType=1:14372/15884= 0.9048098715688744
Accuracy at errType=2:12078/15884= 0.760387811634349
Accuracy at errType=3:13438/15884= 0.846008562075044
Accuracy for all types:39888/47652= 0.8370687484260891
alpha= 0.2
Accuracy at errType=1:14285/15884= 0.8993326617980357
Accuracy at errType=2:11876/15884= 0.7476706119365399
Accuracy at errType=3:13774/15884= 0.8671619239486276
Accuracy for all types:39935/47652= 0.838055065894401
alpha= 0.3
Accuracy at errType=1:14182/15884= 0.8928481490808361
Accuracy at errType=2:11627/1588

In [None]:
# `alpha` here corresponds to γ in the paper.
# The selected value is alpha = 0.2; its test-set accuracies were:
#alpha= 0.2
#Accuracy at errType=1:14285/15884= 0.8993326617980357
#Accuracy at errType=2:11876/15884= 0.7476706119365399
#Accuracy at errType=3:13774/15884= 0.8671619239486276
#Accuracy for all types:39935/47652= 0.838055065894401