In [None]:
# %cd your workspace

In [None]:
import clip_new
import torch
from torch import nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import Owlv2ForObjectDetection, Owlv2Processor
from PIL import Image
import json
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import os

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset_name = 'cifar10'
zeroshot_dataset_name = 'cifar10'
prompt_type = 'deep'
clip_model = 'ViT-L/14' # [ViT-B/32, ViT-L/14]
prompt_len = 3
batch_size = 32
epochs_num = 20
learning_rate = 0.004

root = # '/path/to/'

In [4]:
class ECORDataset(Dataset):
    def __init__(self, preprocess, json_path, cat2id, rat2id, max_length):
        super().__init__()

        self.preprocess = preprocess
        self.datas = pd.read_json(json_path)    
        self.cat2id = cat2id
        self.rat2id = rat2id
        self.max_length = max_length

    def __getitem__(self, index):
        image = self.preprocess(Image.open(self.datas.iloc[index]['image']))
        category = self.cat2id[self.datas.iloc[index]['category']]

        rationales = torch.tensor(pd.Series(self.datas.iloc[index]['unqiue_rationales']).apply(lambda rat: self.rat2id[rat]).values)    
        rationales = torch.cat([rationales, -torch.ones(self.max_length-len(rationales))]).to(int)
        return image, category, rationales

    def __len__(self):
        return len(self.datas)

In [6]:
class CosineLR:
    def __init__(self, optimizer, base_lr, warmup_steps, total_steps):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.initialize()

    def initialize(self):
        self._counter = 0
        self.step()

    def step(self):
        if self._counter < self.warmup_steps:
            lr = self.base_lr * (self._counter+1)/self.warmup_steps
        
        else:
            e = self._counter - self.warmup_steps
            es = self.total_steps - self.warmup_steps
            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * self.base_lr

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        
        self._counter += 1
        

class VisionPromptTuning(nn.Module):
    def __init__(self, prompt_type, prompt_len, width, layers):
        super().__init__()
        assert prompt_type in {'shallow', 'deep'}
        self.prompt_type = prompt_type
        self.prompt_len = prompt_len
        self.width = width
        self.layers = layers

        scale = width ** (-0.5)
        if prompt_type == 'shallow':
            self.weights = nn.Parameter(scale * torch.randn(prompt_len, width))
        else:
            self.weights = nn.Parameter(scale * torch.randn(layers,prompt_len, width))
        
        self.pos_embeddings = nn.Parameter(scale * torch.randn(prompt_len, width))

    def forward(self):
        return self.weights, self.pos_embeddings

In [None]:
# load models
model, preprocess = clip_new.load(clip_model, device)
model = model.to(torch.float32)
model = model.eval()

tokenPrompts = VisionPromptTuning(prompt_type, prompt_len, model.vision_width, model.vision_layers).to(device)

'\nowl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")\nowl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)\n\nfor param in model.parameters():\n    param.requires_grad = False\n\n'

In [None]:
# load the learnt prompt if neccesacry
tokenPrompts = VisionPromptTuning(prompt_type, prompt_len, model.vision_width, model.vision_layers).to(device)
save_file = torch.load(f'{root}/ECOR_new/DROR_checkpoints/{dataset_name}.pth.tar')
tokenPrompts.load_state_dict(save_file['tokenPrompts'])
print("Epoch ", save_file['epoch'], ": ", "Loss: ", save_file['loss'])

Epoch  20 :  Loss:  5.398271606863961


In [None]:
def preprocess_cat_rat(json_path):
        
        datas = pd.read_json(json_path)
        rationales = []

        for _, data in datas.iterrows():
            rationales.extend(data['unqiue_rationales'])
        
        unique_rats = pd.Series(rationales).apply(lambda rat: rat.lower().strip()).unique()
        rat2id = dict(zip(unique_rats, range(len(unique_rats))))
        
        unique_cats = datas['category'].apply(lambda cat: cat.lower().strip()).unique()
        cat2id = dict(zip(unique_cats, range(len(unique_cats))))
        
        maxLength = datas['unqiue_rationales'].apply(lambda rats: len(rats)).max()        
        
        return cat2id, rat2id, maxLength

train_cat2id, train_rat2id, train_maxLength = preprocess_cat_rat(f'{root}/Datasets/{dataset_name}_dr/stats_train.json')
test_cat2id, test_rat2id, test_maxLength = preprocess_cat_rat(f'{root}/Datasets/{dataset_name}_dr/stats_test.json')
zeroshot_cat2id, zeroshot_rat2id, zeroshot_maxLength = preprocess_cat_rat(f'{root}/Datasets/{zeroshot_dataset_name}_dr/stats_test.json')

"""
cat_names = np.array(list(cat2id.keys()))
rat_names = np.array(list(rat2id.keys()))

cat_names_zeroshot = np.array(list(cat2id.keys()))
rat_names_zeroshot = np.array(list(rat2id.keys()))
"""

'\ncat_names = np.array(list(cat2id.keys()))\nrat_names = np.array(list(rat2id.keys()))\n\ncat_names_zeroshot = np.array(list(cat2id.keys()))\nrat_names_zeroshot = np.array(list(rat2id.keys()))\n'

In [None]:

trainset = ECORDataset(preprocess, f'{root}/Datasets/{dataset_name}_dr/stats_train.json', train_cat2id, train_rat2id, train_maxLength)
trainLoader = DataLoader(trainset, batch_size, shuffle=True)

testset = ECORDataset(preprocess, f'{root}/Datasets/{dataset_name}_dr/stats_test.json', test_cat2id, test_rat2id, test_maxLength)
testLoader = DataLoader(testset, batch_size)


zeroshot_set = ECORDataset(preprocess, f'{root}/Datasets/{zeroshot_dataset_name}_dr/stats_test.json', zeroshot_cat2id, zeroshot_rat2id, zeroshot_maxLength)
zeroshot_loader = DataLoader(zeroshot_set, batch_size)


In [None]:
def compute_cat_rat_embeddings(model, cat2id, rat2id):
    cats_embed = []
    rats_embed = []

    with torch.no_grad():
        block_size = 50
        block_num = np.ceil(len(cat2id)/block_size).astype(int)
        
        for i in range(block_num):
            if i != (block_num-1):
                cats_tokens_temp = clip_new.tokenize(pd.Series(list(cat2id.keys())[i*block_size:(i+1)*block_size]).apply(lambda cat: 'A photo of a ' + cat)).to(device)
            else:
                cats_tokens_temp = clip_new.tokenize(pd.Series(list(cat2id.keys())[i*block_size:]).apply(lambda cat: 'A photo of a ' + cat)).to(device)

            cats_embeds_temp = model.encode_text(cats_tokens_temp)
            cats_embeds_temp /= cats_embeds_temp.norm(dim=1, keepdim=True)

            cats_embed.append(cats_embeds_temp)
        
        block_size = 50
        block_num = np.ceil(len(rat2id)/block_size).astype(int)

        for i in range(block_num):
            if i != (block_num-1):
                rats_tokens_pos_temp = clip_new.tokenize(pd.Series(list(rat2id.keys())[i*block_size:(i+1)*block_size]).apply(lambda rat: 'There is ' + rat)).to(device)
            else:
                rats_tokens_pos_temp = clip_new.tokenize(pd.Series(list(rat2id.keys())[i*block_size:]).apply(lambda rat: 'There is ' + rat)).to(device)

            rats_embeds_pos_temp = model.encode_text(rats_tokens_pos_temp)
            rats_embeds_pos_temp /= rats_embeds_pos_temp.norm(dim=1, keepdim=True)

            rats_embed.append(rats_embeds_pos_temp)

        cats_embed = torch.cat(cats_embed, dim=0)
        rats_embed = torch.cat(rats_embed, dim=0)
    
    return cats_embed, rats_embed



train_cats_embed, train_rats_embed = compute_cat_rat_embeddings(model, train_cat2id, train_rat2id)
test_cats_embed, test_rats_embed = compute_cat_rat_embeddings(model, test_cat2id, test_rat2id)
cats_embed_zeroshot, rats_embed_zeroshot = compute_cat_rat_embeddings(model, zeroshot_cat2id, zeroshot_rat2id)

## Conditionality Evaluation

In [None]:
def compute_cat_rat_embeddings_DROR(model, cat2id, rat2id):
    with torch.no_grad():
        cats_rats_embed = []
        cats_rats_prompts = []
        rats_cats_prompts = []
        rats_cats_embed = []

        for rat in cat2id.keys():
            for cat in rat2id.keys():  
                cats_rats_prompt_temp = f"This is a photo of a {cat} because there is {rat}"
                cats_rats_prompts.append(cats_rats_prompt_temp)

        for cat in rat2id.keys():
            for rat in cat2id.keys():  
                rats_cats_prompt_temp = f"This is a photo of a {rat} because there is {cat}"
                rats_cats_prompts.append(rats_cats_prompt_temp)
        

        block_size = 50
        block_num = np.ceil(len(cats_rats_prompts)/block_size).astype(int)
        for i in range(block_num):
            if i != (block_num-1):
                cats_rats_tokens_temp = clip.tokenize(cats_rats_prompts[i*block_size:(i+1)*block_size]).to(device)
            else:
                cats_rats_tokens_temp = clip.tokenize(cats_rats_prompts[i*block_size:]).to(device)
            
            cats_rats_embed_temp = model.encode_text(cats_rats_tokens_temp)
            cats_rats_embed_temp /= cats_rats_embed_temp.norm(dim=1, keepdim=True)
            cats_rats_embed.append(cats_rats_embed_temp)
        

        block_size = 50
        block_num = np.ceil(len(rats_cats_prompts)/block_size).astype(int)
        for i in range(block_num):
            if i != (block_num-1):
                rats_cats_tokens_temp = clip.tokenize(rats_cats_prompts[i*block_size:(i+1)*block_size]).to(device)
            else:
                rats_cats_tokens_temp = clip.tokenize(rats_cats_prompts[i*block_size:]).to(device)
            
            rats_cats_embed_temp = model.encode_text(rats_cats_tokens_temp)
            rats_cats_embed_temp /= rats_cats_embed_temp.norm(dim=1, keepdim=True)
            rats_cats_embed.append(rats_cats_embed_temp)


        cats_rats_embed = torch.cat(cats_rats_embed, dim=0)
        rats_cats_embed = torch.cat(rats_cats_embed, dim=0)
    
        return cats_rats_embed, rats_cats_embed


train_cats_rats_embed, train_rats_cats_embed = compute_cat_rat_embeddings_DROR(model, train_cat2id, train_rat2id)
test_cats_rats_embed, test_rats_cats_embed = compute_cat_rat_embeddings_DROR(model, test_cat2id, test_rat2id)
zeroshot_cats_rats_embed, zeroshot_cats_rats_prompts = compute_cat_rat_embeddings_DROR(model, zeroshot_cat2id, zeroshot_rat2id)

In [None]:
with torch.no_grad():
    ratios = []
    for (imgs, cats, rats) in tqdm(trainLoader):
        images = imgs.to(device)
        cats = cats.to(device)
        rats = rats.to(device)

        images_embed = model.encode_image(images)
        images_embed /= images_embed.norm(dim=1, keepdim=True) # B x d
        
        cat_projections = []
        rat_projections = []
        for img_embed, cat, rat in zip(images_embed, cats, rats):
            re = train_rats_embed[rat[0]][None, :] # 1 x d
            hyper_plane = torch.cat([re.t(), img_embed[:, None]], dim=1) # d x 2
            res = torch.linalg.lstsq(hyper_plane, train_cats_embed.t())
            cat_projection = hyper_plane @ res.solution # d x |C|
            groundtruth_dir = hyper_plane.sum(dim=1, keepdim=True)
            groundtruth_dir /= groundtruth_dir.norm(dim=0, keepdim=True) # d x 1          
            cat_projections.append(torch.squeeze(groundtruth_dir.t() @ cat_projection)) # |C|

            ce = train_cats_embed[cat][None, :] # 1 x d
            hyper_plane = torch.cat([ce.t(), img_embed[:, None]], dim=1) # d x 2
            ces = torch.linalg.lstsq(hyper_plane, train_rats_embed.t())
            rat_projection = hyper_plane @ ces.solution
            groundtruth_dir = hyper_plane.sum(dim=1, keepdim=True)
            groundtruth_dir /= groundtruth_dir.norm(dim=0, keepdim=True) # d x 1          
            rat_projections.append(torch.squeeze(groundtruth_dir.t() @ rat_projection)) # |R|

        logits_cat = torch.stack(cat_projections, dim=0) # B x |C|
        logits_rat = torch.stack(rat_projections, dim=0) # B x |R|
        probs_cat_rat = logits_cat.softmax(dim=-1) # B x |C|
        probs_rat_cat = logits_rat.softmax(dim=-1) # B x |R|
        

        logits_cat = images_embed @ train_cats_embed.t() # B x |C|
        logits_rat = images_embed @ train_rats_embed.t() # B x |R|
        probs_cat = logits_cat.softmax(dim=-1)
        probs_rat = logits_rat.softmax(dim=-1)
        
        idxs = torch.arange(probs_cat.shape[0]).to(device)
        ratio = (probs_cat[idxs, cats] * probs_rat_cat[idxs, rats[:,0]]) / (probs_rat[idxs, rats[:, 0]] * probs_cat_rat[idxs, cats])
        ratios.append(ratio)

    ratios = torch.cat(ratios, dim=0)
    print(ratios.mean().item())      

100%|██████████| 12/12 [00:06<00:00,  1.75it/s]

1.0021806955337524





In [None]:
with torch.no_grad():
    ratios = []
    for (imgs, cats, rats) in tqdm(trainLoader):
        images = imgs.to(device)
        cats = cats.to(device)
        rats = rats.to(device)

        images_embed = model.encode_image(images)
        images_embed /= images_embed.norm(dim=1, keepdim=True) # B x d
        
        cat_projections = []
        rat_projections = []
        for img_embed, cat, rat in zip(images_embed, cats, rats):
            re = train_rats_embed[rat[0]][None, :] # 1 x d
            res = torch.linalg.lstsq(re.t(), train_cats_embed.t())
            cat_projection = re.t() @ res.solution # d x |C|
            cat_projections.append(cat_projection)

            ce = train_cats_embed[cat][None, :] # 1 x d
            ces = torch.linalg.lstsq(ce.t(), train_rats_embed.t())
            rat_projection = ce.t() @ ces.solution
            rat_projections.append(rat_projection)

        cat_projections = torch.stack(cat_projections, dim=0) # B x d x |C|
        rat_projections = torch.stack(rat_projections, dim=0) # B x d x |R|
        logits_cat = model.logit_scale.exp() * torch.squeeze(images_embed[:, None, :] @ cat_projections) # B x |C|
        logits_rat = model.logit_scale.exp() * torch.squeeze(images_embed[:, None, :] @ rat_projections) # B x |R|
        probs_cat_rat = logits_cat.softmax(dim=-1) # B x |C|
        probs_rat_cat = logits_rat.softmax(dim=-1) # B x |R|

        logits_cat = model.logit_scale.exp() * images_embed @ train_cats_embed.t() # B x |C|
        logits_rat = model.logit_scale.exp() * images_embed @ train_rats_embed.t() # B x |R|
        probs_cat = logits_cat.softmax(dim=-1)
        probs_rat = logits_rat.softmax(dim=-1)
        
        idxs = torch.arange(probs_cat.shape[0]).to(device)
        ratio = (probs_cat[idxs, cats] * probs_rat_cat[idxs, rats[:,0]]) / (probs_rat[idxs, rats[:, 0]] * probs_cat_rat[idxs, cats])
        ratios.append(ratio)

    ratios = torch.cat(ratios, dim=0)
    print(ratios.mean().item())      

  0%|          | 0/12 [00:00<?, ?it/s]

100%|██████████| 12/12 [00:07<00:00,  1.68it/s]

113.16597747802734





In [91]:
with torch.no_grad():
    ratios = []
    for (imgs, cats, rats) in tqdm(trainLoader):
        images = imgs.to(device)
        cats = cats.to(device)
        rats = rats.to(device)

        images_embed = model.encode_image(images)
        images_embed /= images_embed.norm(dim=1, keepdim=True) # B x d
        
        logist_cat_rat = images_embed @ train_cats_rats_embed.t() # B x |C|*|R|
        logist_rat_cat = images_embed @ train_rats_cats_embed.t() # B x |R|*|C|
        probs_cat_rat = logist_cat_rat.softmax(dim=-1)
        probs_rat_cat = logist_rat_cat.softmax(dim=-1)

        logits_cat = images_embed @ train_cats_embed.t() # B x |C|
        logits_rat = images_embed @ train_rats_embed.t() # B x |R|
        probs_cat = logits_cat.softmax(dim=-1)
        probs_rat = logits_rat.softmax(dim=-1)
        
        idxs = torch.arange(probs_cat.shape[0]).to(device)
        ratio = (probs_cat[idxs, cats] * probs_rat_cat[idxs, cats*len(train_rat2id)+rats[:,0]]) / (probs_rat[idxs, rats[:, 0]] * probs_cat_rat[idxs, rats[:,0]*len(train_cat2id)+cats])
        ratios.append(ratio)

    ratios = torch.cat(ratios, dim=0)
    print(ratios.mean().item())      

  0%|          | 0/343 [00:00<?, ?it/s]

100%|██████████| 343/343 [03:43<00:00,  1.53it/s]

4.392852783203125





## DROR Evaluation on Multi-Rationale

In [None]:
def DROR_compute_cat_rat_embeddings(model, cat2id, rat2id):
    with torch.no_grad():
        cats_rats_embed = []
        cats_rats_prompts = []

        for rat in rat2id.keys():
            for cat in cat2id.keys():  
                cats_rats_prompt_temp = f"This is a photo of a {cat} because there is {rat}"
                cats_rats_prompts.append(cats_rats_prompt_temp)

        block_size = 50
        block_num = np.ceil(len(cats_rats_prompts)/block_size).astype(int)
        for i in tqdm(range(block_num), total=block_num):
            if i != (block_num-1):
                cats_rats_tokens_temp = clip_new.tokenize(cats_rats_prompts[i*block_size:(i+1)*block_size]).to(device)
            else:
                cats_rats_tokens_temp = clip_new.tokenize(cats_rats_prompts[i*block_size:]).to(device)
            
            cats_rats_embed_temp = model.encode_text(cats_rats_tokens_temp)
            cats_rats_embed_temp /= cats_rats_embed_temp.norm(dim=1, keepdim=True)
            cats_rats_embed.append(cats_rats_embed_temp)
        
        cats_rats_embed = torch.cat(cats_rats_embed, dim=0)
    
        return cats_rats_embed


DROR_train_cats_rats_embed = DROR_compute_cat_rat_embeddings(model, train_cat2id, train_rat2id)
DROR_test_cats_rats_embed = DROR_compute_cat_rat_embeddings(model, test_cat2id, test_rat2id)
DROR_zeroshot_cats_rats_embed = DROR_compute_cat_rat_embeddings(model, zeroshot_cat2id, zeroshot_rat2id)

100%|██████████| 11/11 [00:00<00:00, 18.47it/s]


In [14]:
@torch.no_grad()
def compute_RR(pred_rats:list, target_rats:list, pred_cats:list, target_cat:int):
    # Compute Category Accuracy
    acc_cat = 0.
    cats_vote = pd.Series(pred_cats).value_counts()
    max_vote_num = cats_vote.iloc[0]
    j=0
    while j<len(cats_vote) and cats_vote.iloc[j] == max_vote_num:
        if cats_vote.index[j] == target_cat:
            acc_cat = 1.
            break
        j += 1
    
    # Compute Rationales Accuracy
    acc_rat = len(set(pred_rats) & set(target_rats)) / len(target_rats)

    return acc_cat, acc_rat
    
@torch.no_grad()
def evaluation(model, testLoader, cats_rats_embed, cat2id):
    RRs = []
    RWs = []
    WRs = []
    WWs = []

    for (images, cats, rats) in tqdm(testLoader, total=len(testLoader)):
        images = images.to(device)
        cats = cats.to(device)
        rats = rats.to(device)

        prompt, pos_embedding = tokenPrompts()
        images_embed = model.encode_image(images, prompt, pos_embedding)
        images_embed = images_embed / images_embed.norm(dim=1, keepdim=True) # B x d
        
        logits_cats_rats = images_embed @ cats_rats_embed.t() # B x |R||C|
        
        for (logits_cat_rat, cat, rat) in zip(logits_cats_rats, cats, rats):
            target_rat = rat[rat!=-1]
            preds_cat_rat = torch.topk(logits_cat_rat, k=len(target_rat), dim=0)[1]
            preds_rat = preds_cat_rat // len(cat2id)
            preds_cat = preds_cat_rat % len(cat2id)
            
            acc_cat, acc_rat = compute_RR(preds_rat.cpu().tolist(), target_rat.cpu().tolist(), preds_cat.cpu().tolist(), cat.item())
            RRs.append(acc_cat * acc_rat)
            RWs.append(acc_cat * (1-acc_rat))
            WRs.append((1-acc_cat) * acc_rat)
            WWs.append((1-acc_cat) * (1-acc_rat))

    RR = np.mean(RRs)
    RW = np.mean(RWs)
    WR = np.mean(WRs)
    WW = np.mean(WWs)
    print(f"RR: {100*RR}, RW: {100*RW}, WR: {100*WR}, WW: {100*WW}")
    return RR, RW, WR, WW


In [15]:
RR, RW, WR, WW = evaluation(model, zeroshot_loader, DROR_zeroshot_cats_rats_embed, zeroshot_cat2id)

100%|██████████| 11/11 [00:06<00:00,  1.63it/s]

RR: 28.325324013397406, RW: 66.78170962574632, WR: 1.6972477064220184, WW: 3.1957186544342515





In [66]:
RR, RW, WR, WW = evaluation(model, testLoader, DROR_test_cats_rats_embed, test_cat2id)

100%|██████████| 80/80 [03:05<00:00,  2.32s/it]

RR: 11.957717381824523, RW: 44.13529674690388, WR: 5.178516919588348, WW: 38.72846895168324





## ECOR Evaluation on Multi-Rationale

In [None]:
def ECOR_compute_cat_rat_embeddings(model, cat2id, rat2id):
    with torch.no_grad():
        cats_rats_embed = []
        cats_rats_prompts = []
        rats_prompts = []
        rats_embed = []

        for rat in rat2id.keys():
            rats_prompt_temp = f"There is {rat}"
            rats_prompts.append(rats_prompt_temp)
        
        for rat in rat2id.keys():
            for cat in cat2id.keys():  
                cats_rats_prompts_temp = f"This is a photo of a {cat} because there is {rat}"
                cats_rats_prompts.append(cats_rats_prompts_temp)

        block_size = 50
        block_num = np.ceil(len(rats_prompts)/block_size).astype(int)
        for i in range(block_num):
            if i != (block_num-1):
                rats_tokens_temp = clip_new.tokenize(rats_prompts[i*block_size:(i+1)*block_size]).to(device)
            else:
                rats_tokens_temp = clip_new.tokenize(rats_prompts[i*block_size:]).to(device)
            
            rats_embed_temp = model.encode_text(rats_tokens_temp)
            rats_embed_temp /= rats_embed_temp.norm(dim=1, keepdim=True)
            rats_embed.append(rats_embed_temp)
        
        rats_embed = torch.cat(rats_embed, dim=0)

        block_size = 50
        block_num = np.ceil(len(cats_rats_prompts)/block_size).astype(int)
        for i in tqdm(range(block_num), total=block_num):
            if i != (block_num-1):
                cats_rats_tokens_temp = clip_new.tokenize(cats_rats_prompts[i*block_size:(i+1)*block_size]).to(device)
            else:
                cats_rats_tokens_temp = clip_new.tokenize(cats_rats_prompts[i*block_size:]).to(device)
            
            cats_rats_embed_temp = model.encode_text(cats_rats_tokens_temp)
            cats_rats_embed_temp /= cats_rats_embed_temp.norm(dim=1, keepdim=True)
            cats_rats_embed.append(cats_rats_embed_temp)
        
        cats_rats_embed = torch.cat(cats_rats_embed, dim=0)
    
        return rats_embed, cats_rats_embed


ECOR_train_rats_embed, ECOR_train_cats_rats_embed = ECOR_compute_cat_rat_embeddings(model, train_cat2id, train_rat2id)
ECOR_test_rats_embed, ECOR_test_cats_rats_embed = ECOR_compute_cat_rat_embeddings(model, test_cat2id, test_rat2id)
ECOR_zeroshot_rats_embed, ECOR_zeroshot_cats_rats_embed = ECOR_compute_cat_rat_embeddings(model, zeroshot_cat2id, zeroshot_rat2id)

100%|██████████| 805/805 [00:42<00:00, 19.03it/s]


In [92]:
@torch.no_grad()
def compute_RR(pred_rats:list, target_rats:list, pred_cats:list, target_cat:int):
    # Compute Category Accuracy
    acc_cat = 0.
    cats_vote = pd.Series(pred_cats).value_counts()
    max_vote_num = cats_vote.iloc[0]
    j=0
    while j<len(cats_vote) and cats_vote.iloc[j] == max_vote_num:
        if cats_vote.index[j] == target_cat:
            acc_cat = 1.
            break
        j += 1
    
    # Compute Rationales Accuracy
    acc_rat = len(set(pred_rats) & set(target_rats)) / len(target_rats)

    return acc_cat, acc_rat


@torch.no_grad()
def evaluation(model, testLoader, rats_embed, cats_rats_embed, cat2id):
    RRs = []
    RWs = []
    WRs = []
    WWs = []

    for (images, cats, rats) in tqdm(testLoader, total=len(testLoader)):
        images = images.to(device)
        cats = cats.to(device)
        rats = rats.to(device)

        prompt, pos_embedding = tokenPrompts()
        images_embed = model.encode_image(images, prompt, pos_embedding)
        images_embed = images_embed / images_embed.norm(dim=1, keepdim=True) # B x d
        
        logits_rats = images_embed @ rats_embed.t() # B x |R|
        probs_rats = logits_rats.softmax(dim=-1) # B x |R|
        logits_cats_rats = images_embed @ cats_rats_embed.t() # B x |R||C|
        probs_cats = logits_cats_rats.softmax(dim=-1) # B x |R||C|
        probs_cats_rats = probs_cats * torch.repeat_interleave(probs_rats, len(cat2id), dim=-1) # B x |R||C|

        for (probs_cat_rat, cat, rat) in zip(probs_cats_rats, cats, rats):
            target_rat = rat[rat!=-1]
            preds_cat_rat = torch.topk(probs_cat_rat, k=len(target_rat), dim=0)[1]
            preds_rat = preds_cat_rat // len(cat2id)
            preds_cat = preds_cat_rat % len(cat2id)
            
            acc_cat, acc_rat = compute_RR(preds_rat.cpu().tolist(), target_rat.cpu().tolist(), preds_cat.cpu().tolist(), cat.item())
            RRs.append(acc_cat * acc_rat)
            RWs.append(acc_cat * (1-acc_rat))
            WRs.append((1-acc_cat) * acc_rat)
            WWs.append((1-acc_cat) * (1-acc_rat))

    RR = np.mean(RRs)
    RW = np.mean(RWs)
    WR = np.mean(WRs)
    WW = np.mean(WWs)
    print(f"RR: {100*RR}, RW: {100*RW}, WR: {100*WR}, WW: {100*WW}")
    return RR, RW, WR, WW

In [93]:
RR, RW, WR, WW = evaluation(model, zeroshot_loader, ECOR_zeroshot_rats_embed, ECOR_zeroshot_cats_rats_embed, zeroshot_cat2id)

  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:50<00:00,  2.53s/it]

RR: 19.28219129822089, RW: 63.619576067171636, WR: 2.9912121034192545, WW: 14.107020531188224





In [15]:
RR, RW, WR, WW = evaluation(model, testLoader, ECOR_test_rats_embed, ECOR_test_cats_rats_embed, test_cat2id)

100%|██████████| 80/80 [03:26<00:00,  2.58s/it]

RR: 12.77706526749894, RW: 45.00354697739902, WR: 5.415436321048565, WW: 36.80395143405348





## Train and Evaluation

In [None]:
@torch.no_grad()
def compute_RR(pred_rats:list, target_rats:list, pred_cats:list, target_cat:int):
    # Compute Category Accuracy
    acc_cat = 0.
    cats_vote = pd.Series(pred_cats).value_counts()
    max_vote_num = cats_vote.iloc[0]
    j=0
    while j<len(cats_vote) and cats_vote.iloc[j] == max_vote_num:
        if cats_vote.index[j] == target_cat:
            acc_cat = 1.
            break
        j += 1
    
    # Compute Rationales Accuracy
    acc_rat = len(set(pred_rats) & set(target_rats)) / len(target_rats)

    return acc_cat, acc_rat
    
@torch.no_grad()
def evaluation(model, testLoader, cats_embed, rats_embed, tokenPrompts=None):
    RRs = []
    RWs = []
    WRs = []
    WWs = []

    for (images, cats, rats) in tqdm(testLoader, total=len(testLoader)):
        images = images.to(device)
        cats = cats.to(device)
        rats = rats.to(device)

        prompt, pos_embedding = None, None
        if tokenPrompts:
            prompt, pos_embedding = tokenPrompts()

        images_embed = model.encode_image(images, prompt, pos_embedding)    
        images_embed = images_embed / images_embed.norm(dim=1, keepdim=True) # B x d
        
        logits_rat = images_embed @ rats_embed.t() # B x |R|
        probs_rat = logits_rat.softmax(dim=-1)
        
        for i, (img_embed, cat, rat) in enumerate(zip(images_embed, cats, rats)):
            hyper_planes = torch.stack([rats_embed, img_embed.expand(rats_embed.shape)], dim=-1) # |R| x d x 2
            res = torch.linalg.lstsq(hyper_planes, cats_embed.t()[None,:,:])
            cat_projection = hyper_planes @ res.solution # |R| x d x |C|
            groundtruth_dir = hyper_planes.sum(dim=-1) # |R| x d
            groundtruth_dir /= groundtruth_dir.norm(dim=-1, keepdim=True) # |R| x d
            
            proj_cats_logit = torch.squeeze(groundtruth_dir[:, None, :] @ cat_projection, dim=1) # |R| x |C|
            probs_proj_cat = proj_cats_logit.softmax(dim=-1) # |R| x |C|
            probs_cat_rat = probs_rat[i,:][:, None] * probs_proj_cat # |R| x |C|
        
            target_rat = rat[rat!=-1]
            preds_cat_rat = torch.topk(probs_cat_rat.reshape((-1,)), k=len(target_rat), dim=0)[1]
            preds_rat = preds_cat_rat // probs_cat_rat.shape[1]
            preds_cat = preds_cat_rat % probs_cat_rat.shape[1]
            
            acc_cat, acc_rat = compute_RR(preds_rat.cpu().tolist(), target_rat.cpu().tolist(), preds_cat.cpu().tolist(), cat.item())
            RRs.append(acc_cat * acc_rat)
            RWs.append(acc_cat * (1-acc_rat))
            WRs.append((1-acc_cat) * acc_rat)
            WWs.append((1-acc_cat) * (1-acc_rat))

    RR = np.mean(RRs)
    RW = np.mean(RWs)
    WR = np.mean(WRs)
    WW = np.mean(WWs)
    print(f"RR: {RR}, RW: {RW}, WR: {WR}, WW: {WW}")
    return RR, RW, WR, WW


@torch.no_grad()
def evaluation_MR(model, testLoader, cats_embed, rats_embed, tokenPrompts=None, k_beam=5):
    RRs = []
    RWs = []
    WRs = []
    WWs = []
    rat_idxs = torch.arange(rats_embed.shape[0]).to(device)

    for (images, cats, rats) in tqdm(testLoader, total=len(testLoader)):
        print("yes..............")
        images = images.to(device)
        cats = cats.to(device)
        rats = rats.to(device)

        prompt, pos_embedding = None, None
        if tokenPrompts:
            prompt, pos_embedding = tokenPrompts()

        images_embed = model.encode_image(images, prompt, pos_embedding)    
        images_embed = images_embed / images_embed.norm(dim=1, keepdim=True) # B x d
        
        logits_rat = images_embed @ rats_embed.t() # B x |R|
        probs_rat = logits_rat.softmax(dim=-1) # B x |R|
        
        for i, (img_embed, cat, rat) in enumerate(zip(images_embed, cats, rats)):
            target_rat = rat[rat!=-1]   
            choosed_rats_mask = torch.zeros(k_beam, rats_embed.shape[0], dtype=torch.bool).to(device)
            for k in range(len(target_rat)):
                choosed_rats_embed = rats_embed[None, :, :].expand((k_beam,-1,-1))[choosed_rats_mask].reshape((k_beam, k, rats_embed.shape[-1])) # k_beam x k x d
                remained_rats_embed = rats_embed[None, :, :].expand((k_beam,-1,-1))[~choosed_rats_mask].reshape((k_beam, rats_embed.shape[0]-k, -1)) # k_beam x (|R|-k) x d
                hyper_planes = torch.stack([remained_rats_embed, img_embed[None, None, :].expand(remained_rats_embed.shape)], dim=-1) # k_beam x (|R|-k) x d x 2
                hyper_planes = torch.cat([hyper_planes, choosed_rats_embed.permute(0,2,1)[:, None, :, :].expand((-1, hyper_planes.shape[1], -1, -1))], dim=-1) # k_beam x |R|-k x d x (k+2)
                res = torch.linalg.lstsq(hyper_planes, cats_embed.t()[None, None,:,:])
                cat_projection = hyper_planes @ res.solution # k_beam x (|R|-k) x d x |C|
                groundtruth_dir = hyper_planes.sum(dim=-1) # k_beam x (|R|-k) x d
                groundtruth_dir /= groundtruth_dir.norm(dim=-1, keepdim=True) # k_beam x (|R|-k) x d
                
                proj_cats_logit = torch.squeeze(groundtruth_dir[:, :, None, :] @ cat_projection, dim=-2) # k_beam x (|R|-k) x |C|
                probs_proj_cat = proj_cats_logit.softmax(dim=-1) # k_beam x (|R|-k) x |C|
                probs_cat_rat = probs_rat[i][None, :].expand((k_beam, -1))[~choosed_rats_mask].reshape((k_beam, -1, 1)) * probs_proj_cat # k_beam x (|R|-k) x |C|
                preds_cat_rat = torch.topk(probs_cat_rat.reshape((-1,),), k=k_beam, dim=0)[1]              
                pred_rat = rat_idxs[None, :].expand((k_beam, -1))[~choosed_rats_mask][preds_cat_rat // probs_cat_rat.shape[-1]]
                choosed_rats_mask = choosed_rats_mask[(preds_cat_rat // probs_cat_rat.shape[-1])//probs_cat_rat.shape[1]]
                choosed_rats_mask[torch.arange(k_beam), pred_rat] = True
           
           
            choosed_rats_embed = rats_embed[None, :, :].expand((k_beam, -1, -1))[choosed_rats_mask].reshape((k_beam, len(target_rat), -1)) # k-Beam x m x d
            hyper_plane = torch.cat([choosed_rats_embed.permute(0,2,1), img_embed[None, :, None].expand(k_beam, -1, -1)], dim=-1) # k_beam x d x (m+1)
            res = torch.linalg.lstsq(hyper_plane, cats_embed.t()[None, :, :])
            cat_projection = hyper_plane @ res.solution # k_beam x d x |C|
            groundtruth_dir = hyper_plane.sum(dim=-1) # k_beam x d
            groundtruth_dir /= groundtruth_dir.norm(dim=-1, keepdim=True) # k_beam x d
            
            proj_cats_logit = torch.squeeze(groundtruth_dir[:, None, :] @ cat_projection, dim=-2) # k_beam x |C|
            probs_proj_cat = proj_cats_logit.softmax(dim=-1) # k_beam x |C|
            preds_cat = torch.topk(probs_proj_cat.reshape(-1), k=len(target_rat), dim=0)[1]
            preds_rat = rat_idxs[choosed_rats_mask[(preds_cat//probs_proj_cat.shape[-1])[0]]]
            preds_cat = preds_cat % probs_proj_cat.shape[-1]
            
            acc_cat, acc_rat = compute_RR(preds_rat.cpu().tolist(), target_rat.cpu().tolist(), preds_cat.cpu().tolist(), cat.item())
            RRs.append(acc_cat * acc_rat)
            RWs.append(acc_cat * (1-acc_rat))
            WRs.append((1-acc_cat) * acc_rat)
            WWs.append((1-acc_cat) * (1-acc_rat))
            

    RR = np.mean(RRs)
    RW = np.mean(RWs)
    WR = np.mean(WRs)
    WW = np.mean(WWs)
    print(f"RR: {RR}, RW: {RW}, WR: {WR}, WW: {WW}")
    return RR, RW, WR, WW

In [11]:
RR, RW, WR, WW = evaluation_MR(model, testLoader, test_cats_embed, test_rats_embed, k_beam=1)

  0%|          | 0/589 [00:00<?, ?it/s]

yes..............


  0%|          | 1/589 [01:21<13:21:13, 81.76s/it]

yes..............


  0%|          | 2/589 [02:22<11:17:53, 69.29s/it]

yes..............


  1%|          | 3/589 [03:12<9:49:58, 60.41s/it] 

yes..............


  1%|          | 4/589 [04:19<10:16:13, 63.20s/it]

yes..............


  1%|          | 5/589 [05:11<9:35:14, 59.10s/it] 

yes..............


  1%|          | 6/589 [05:57<8:51:25, 54.69s/it]

yes..............


  1%|          | 7/589 [06:51<8:48:40, 54.50s/it]

yes..............


  1%|▏         | 8/589 [07:35<8:15:50, 51.21s/it]

yes..............


  2%|▏         | 9/589 [08:18<7:50:26, 48.67s/it]

yes..............


  2%|▏         | 10/589 [09:11<8:02:04, 49.96s/it]

yes..............


  2%|▏         | 11/589 [10:21<8:59:10, 55.97s/it]

yes..............


  2%|▏         | 12/589 [11:15<8:52:32, 55.38s/it]

yes..............


  2%|▏         | 13/589 [12:14<9:03:31, 56.62s/it]

yes..............


  2%|▏         | 14/589 [12:58<8:26:02, 52.81s/it]

yes..............


  3%|▎         | 15/589 [13:48<8:17:08, 51.97s/it]

yes..............


  3%|▎         | 16/589 [14:47<8:33:59, 53.82s/it]

yes..............


  3%|▎         | 17/589 [16:11<10:00:41, 63.01s/it]

yes..............


  3%|▎         | 17/589 [16:21<9:10:20, 57.73s/it] 


KeyboardInterrupt: 

In [None]:
total_steps = epochs_num * len(trainLoader)
warmup_steps = 0.2 * total_steps
optimizer = torch.optim.AdamW(tokenPrompts.parameters(), learning_rate)
scheduler = CosineLR(optimizer, learning_rate, warmup_steps, total_steps)

criterion_cat = torch.nn.CrossEntropyLoss()
criterion_rat = torch.nn.BCEWithLogitsLoss()

prev_RR_zeroshot = 0.

for epoch in range(epochs_num):
    
    cumLoss_cat= 0.0
    cumLoss_rat= 0.0
    count = 0

    model.eval()
    tokenPrompts.train()

    for (images, cats, rats) in tqdm(trainLoader, total=len(trainLoader)):
        images = images.to(device)
        cats = cats.to(device)
        rats = rats.to(device)
        
        prompt, pos_embedding = tokenPrompts()
        images_embed = model.encode_image(images, prompt, pos_embedding)
        images_embed = images_embed / images_embed.norm(dim=1, keepdim=True) # B x d
        
        # Compute Loss rat
        logits_rat =  model.logit_scale.exp() * torch.squeeze(images_embed @ train_rats_embed.t()) # B x |R|
        probs_rat = logits_rat.softmax(dim=-1) # B x |R|

        logits_cat = []
        target_rats = torch.zeros_like(logits_rat).to(device)

        for i, (img_embed, cat, rat) in enumerate(zip(images_embed, cats, rats)):
            target_rat = rat[rat!=-1]
            target_rat_embed = train_rats_embed[target_rat] # m x d
            hyper_plane = torch.cat([target_rat_embed, img_embed[None, :]], dim=0) # (m+1) x d
            res = torch.linalg.lstsq(hyper_plane.t(), train_cats_embed.t())
            cat_projection = hyper_plane.t() @ res.solution # d x |C|
            groundtruth_dir = hyper_plane.sum(dim=0) # d
            groundtruth_dir = groundtruth_dir / groundtruth_dir.norm(dim=-1, keepdim=True) # d
            
            logits_cat.append(model.logit_scale.exp() * groundtruth_dir[None, :] @ cat_projection) # 1 x |C|
            target_rats[i, target_rat] = 1./len(target_rat)

        logits_cat = torch.cat(logits_cat, dim=0) # B x |C|

        loss_rat = torch.sum(-target_rats * torch.log(probs_rat+1e-12), dim=-1).mean()
        #loss_rat = criterion_rat(logits_rat, target_rats)
        loss_cat = criterion_cat(logits_cat, cats)
        loss_total = loss_rat + loss_cat
        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()
        scheduler.step()

        cumLoss_cat += loss_cat.item() * len(cats)
        cumLoss_rat += loss_rat.item() * len(cats)
        count += len(cats)
        
    
    cumLoss_cat /= count
    cumLoss_rat /= count
    print(f'Epoch {epoch+1}: total loss = {cumLoss_cat + cumLoss_rat}, cat loss= {cumLoss_cat}, rat loss = {cumLoss_rat}')

    save_file = {'epoch': epoch+1,
                 'loss': cumLoss_cat+cumLoss_rat,
                 'tokenPrompts': tokenPrompts.state_dict()}

    
    torch.save(save_file, f'{root}/ECOR_new/ECOR_new_checkpoints/{dataset_name}.pth.tar')

100%|██████████| 46/46 [00:54<00:00,  1.18s/it]


Epoch 1: total loss = 3.9587723180847463, cat loss= 0.17338008843207756, rat loss = 3.7853922296526687


100%|██████████| 46/46 [00:54<00:00,  1.19s/it]


Epoch 2: total loss = 3.3979001796674266, cat loss= 0.21006648302243128, rat loss = 3.1878336966449954


100%|██████████| 46/46 [00:55<00:00,  1.21s/it]


Epoch 3: total loss = 3.1523674646724804, cat loss= 0.2204651695110656, rat loss = 2.931902295161415


100%|██████████| 46/46 [00:55<00:00,  1.21s/it]


Epoch 4: total loss = 3.022326851513541, cat loss= 0.2036431256789884, rat loss = 2.8186837258345525


100%|██████████| 46/46 [00:55<00:00,  1.20s/it]


Epoch 5: total loss = 2.9372972411980767, cat loss= 0.18820132972623632, rat loss = 2.7490959114718403


100%|██████████| 46/46 [00:55<00:00,  1.20s/it]


Epoch 6: total loss = 2.858408327316183, cat loss= 0.17891029127323446, rat loss = 2.6794980360429483


100%|██████████| 46/46 [00:55<00:00,  1.20s/it]


Epoch 7: total loss = 2.7967814322081836, cat loss= 0.16828056634049526, rat loss = 2.6285008658676885


100%|██████████| 46/46 [00:55<00:00,  1.21s/it]


Epoch 8: total loss = 2.753096593201243, cat loss= 0.1617149345442797, rat loss = 2.5913816586569633


100%|██████████| 46/46 [00:55<00:00,  1.21s/it]


Epoch 9: total loss = 2.7255153322739223, cat loss= 0.15603383907721086, rat loss = 2.5694814931967116


100%|██████████| 46/46 [00:55<00:00,  1.21s/it]


Epoch 10: total loss = 2.7010565672590197, cat loss= 0.1559157376055236, rat loss = 2.545140829653496


100%|██████████| 46/46 [00:55<00:00,  1.21s/it]


Epoch 11: total loss = 2.6789399683269393, cat loss= 0.152598964001463, rat loss = 2.5263410043254764


100%|██████████| 46/46 [00:55<00:00,  1.20s/it]


Epoch 12: total loss = 2.6536190257939722, cat loss= 0.1500606792843688, rat loss = 2.5035583465096036


100%|██████████| 46/46 [00:55<00:00,  1.20s/it]


Epoch 13: total loss = 2.6462323541339505, cat loss= 0.14767584804430048, rat loss = 2.49855650608965


100%|██████████| 46/46 [00:55<00:00,  1.20s/it]


Epoch 14: total loss = 2.6239817103842484, cat loss= 0.1487511885384636, rat loss = 2.4752305218457846


100%|██████████| 46/46 [00:55<00:00,  1.21s/it]


Epoch 15: total loss = 2.599507568123265, cat loss= 0.14520309525406708, rat loss = 2.4543044728691976


100%|██████████| 46/46 [00:55<00:00,  1.21s/it]


Epoch 16: total loss = 2.58743215656297, cat loss= 0.1459647879767055, rat loss = 2.4414673685862645


100%|██████████| 46/46 [00:55<00:00,  1.21s/it]


Epoch 17: total loss = 2.5788938277201368, cat loss= 0.14456028498472184, rat loss = 2.434333542735415


100%|██████████| 46/46 [00:55<00:00,  1.20s/it]


Epoch 18: total loss = 2.5658058038710476, cat loss= 0.14513832703232765, rat loss = 2.42066747683872


100%|██████████| 46/46 [00:55<00:00,  1.20s/it]


Epoch 19: total loss = 2.5591547728244364, cat loss= 0.14441786020128558, rat loss = 2.414736912623151


100%|██████████| 46/46 [00:55<00:00,  1.21s/it]

Epoch 20: total loss = 2.5560864763718256, cat loss= 0.1441927102509046, rat loss = 2.411893766120921





In [18]:
RR, RW, WR, WW = evaluation_MR(model, testLoader, test_cats_embed, test_rats_embed, tokenPrompts, k_beam=1)

100%|██████████| 11/11 [00:09<00:00,  1.19it/s]

RR: 0.5891655744866754, RW: 0.3955439056356488, WR: 0.0006116207951070337, WW: 0.014678899082568806





In [19]:
RR, RW, WR, WW = evaluation(model, testLoader, test_cats_embed, test_rats_embed, tokenPrompts)

100%|██████████| 11/11 [00:05<00:00,  2.01it/s]

RR: 0.5481869812145042, RW: 0.4212319790301442, WR: 0.005657492354740061, WW: 0.02492354740061162



