In [1]:
import os
import torch
import requests
import numpy as np
import transformers
from PIL import Image
import torch.nn as nn
import skimage.io as io
from tqdm.auto import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [2]:
import configs

configs.set_seed(42)
device = configs.set_device(2)

There are 8 GPU(s) available.
We will use the GPU: NVIDIA A100-SXM4-80GB


In [5]:
model = transformers.CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = transformers.CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

2024-02-01 14:13:19.227323: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-01 14:13:19.227408: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-01 14:13:19.227429: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-01 14:13:19.234420: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## data utils

In [18]:
import sys
sys.path.append('..')
from configs import *
import transformers

def collate_fn(batch):
    return {
        'image': [x['image'] for x in batch],
        'labels': [x['labels'] for x in batch]
    }

def remove_prefixes(strings):
    prefixes = ['a', 'an', 'the']
    result = []

    for string in strings:
        words = string.split()
        if words[0].lower() in prefixes:
            result.append(' '.join(words[1:]))
        else:
            result.append(string)

    return result

def preprocess_loader(loader, concepts: list):
    preprocessed_batches = []
    processor = transformers.CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    for batch in tqdm(loader):
        preprocessed_batch = preprocess_batch(batch, processor, concepts)
        preprocessed_batches.append(preprocessed_batch)
    return preprocessed_batches

def preprocess_batch(batch, processor, concepts: list):
    return processor(text=concepts, images=batch['image'], return_tensors="pt", padding=True), batch['labels']

def prepared_dataloaders(hf_link: str, concepts: list, test_size: int=0.2, prep_loaders='all'):
    from datasets import load_dataset
    from datasets import DatasetDict
    dataset = load_dataset(hf_link)
    dataset = dataset["train"].train_test_split(test_size=0.2)
    val_test = dataset["test"].train_test_split(test_size=0.5)
    dataset = DatasetDict({
        "train": dataset["train"],
        "validation": val_test["train"],
        "test": val_test["test"],
    })

    train_loader = torch.utils.data.DataLoader(dataset["train"], batch_size=32, shuffle=True, collate_fn=collate_fn, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(dataset["validation"], batch_size=32, shuffle=False, collate_fn=collate_fn, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(dataset["test"], batch_size=32, shuffle=False, collate_fn=collate_fn, pin_memory=True)

    if prep_loaders == 'all':
        train_loader_preprocessed = preprocess_loader(train_loader, concepts)
        val_loader_preprocessed = preprocess_loader(val_loader, concepts)
        test_loader_preprocessed = preprocess_loader(test_loader, concepts)
    elif prep_loaders == 'train':
        train_loader_preprocessed = preprocess_loader(train_loader, concepts)
        return train_loader_preprocessed
    elif prep_loaders == 'val':
        val_loader_preprocessed = preprocess_loader(val_loader, concepts)
        return val_loader_preprocessed
    elif prep_loaders == 'test':
        test_loader_preprocessed = preprocess_loader(test_loader, concepts)
        return test_loader_preprocessed

    #return train_loader_preprocessed, val_loader_preprocessed, test_loader_preprocessed

In [21]:
with open("conceptnet_cifar10_filtered_new.txt", "r") as f:
    concepts = f.read().lower().split("\n")
    concepts = remove_prefixes(concepts)

In [26]:
val_loader_preprocessed = prepared_dataloaders("Andron00e/CIFAR10-custom", 
                                                concepts=concepts,
                                                prep_loaders="val")

100%|█████████████████████████████████████████| 188/188 [00:34<00:00,  5.47it/s]


In [None]:
from transformers import datasets

In [37]:
classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

In [45]:
def get_class_features(model, classes:list, model_name: str="openai/clip-vit-base-patch32", device=device):
    tokenizer = transformers.AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    inputs = tokenizer(classes, padding=True, return_tensors="pt").to(device)
    features = model.get_text_features(inputs['input_ids'])
    return features

In [46]:
class_features = get_class_features(model, classes, "openai/clip-vit-large-patch14", device)

In [100]:
from datasets import load_metric

def CMS(loader, model, class_features, device=device): #hf version
    accuracy_metric = load_metric("accuracy")
    with torch.no_grad(), torch.cuda.amp.autocast():
        class_features = class_features.to(device)
        clip_accuracy, cms_accuracy = [], []
        for step, batch in enumerate(loader):
            inputs, labels = batch
            inputs, targets = inputs.to(device), torch.LongTensor(labels).to(device)
            
            image_features = model.get_image_features(inputs['pixel_values'])
            concept_features = model.get_text_features(inputs['input_ids'])
            
            V_matrix = image_features @ concept_features.T
            T_matrix = class_features @ concept_features.T
            
            image_class_similarity = image_features @ class_features.T
            
            clip_preds = image_class_similarity.argmax(dim=1)
            clip_acc = accuracy_metric.compute(references=targets.cpu(), predictions=clip_preds.cpu())
            
            similarity = (V_matrix / V_matrix.norm(dim=1, keepdim=True)) @ (T_matrix / T_matrix.norm(dim=1, keepdim=True)).T
            cms_preds = similarity.argmax(dim=1)
            cms_acc = accuracy_metric.compute(references=targets.cpu(), predictions=cms_preds.cpu())
            
            clip_accuracy.append(clip_acc['accuracy'])
            cms_accuracy.append(cms_acc['accuracy'])
            
    return np.mean(clip_accuracy), np.mean(cms_accuracy)

In [101]:
clip_acc, cms_acc = CMS(val_loader_preprocessed, model, class_features, device)
print("CLIP accuracy: ", clip_acc, "\n")
print("CMS accuracy: ", cms_acc, "\n")

CLIP accuracy:  0.6032247340425532 

CMS accuracy:  0.29188829787234044 

