Evaluating the Shape Bias of CLIP-ViT-Base

Importing the necessary libraries

In [13]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm


Class Names in the CIFAR-10G Dataset

In [14]:
CIFAR10G_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

Custom dataset class to load the CIFAR-10G dataset. The root folder contains six subdirectories, each holding one set of stylised out-of-domain generalisation test images (line drawings, contours and silhouettes, each in a normal and an inverted version). Each of these subdirectories contains ten folders named after the CIFAR-10 classes. The class below loops over every style and category folder and records the path and class index of every image it finds. The __getitem__ method opens an image, applies the transformations to it and returns the transformed image together with its label.

In [15]:
class CustomCIFAR10(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        
        for image_type in ['line_drawings', 'line_drawings_inverted', 'contours', 
                           'contours_inverted', 'silhouettes', 'silhouettes_inverted']:
            image_type_dir = os.path.join(root_dir, image_type)
            
            for idx, category in enumerate(CIFAR10G_CLASSES):
                category_dir = os.path.join(image_type_dir, category)
                
                # Sort so that the image order (and hence the batch contents) is reproducible
                for img_file in sorted(os.listdir(category_dir)):
                    if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
                        img_path = os.path.join(category_dir, img_file)
                        self.image_paths.append(img_path)
                        self.labels.append(idx) 

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path)
        label = self.labels[idx]

        if self.transform:
            
            image = self.transform(image)
        
        return image, label
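Because CustomCIFAR10 silently skips non-image files and simply collects whatever it finds, it can help to check the folder layout before running the evaluation. The sketch below assumes the layout described above (root/style/class/image files); the helper name count_images_per_folder is hypothetical and not part of the notebook. It returns the number of images per (style, class) folder and raises if a folder is missing.

```python
import os
import tempfile

# Style subfolders and class names as listed in the notebook.
STYLES = ['line_drawings', 'line_drawings_inverted', 'contours',
          'contours_inverted', 'silhouettes', 'silhouettes_inverted']
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
IMAGE_EXTS = ('.png', '.jpg', '.jpeg')


def count_images_per_folder(root_dir):
    """Hypothetical helper: return {(style, class_name): image_count}."""
    counts = {}
    for style in STYLES:
        for class_name in CLASSES:
            folder = os.path.join(root_dir, style, class_name)
            if not os.path.isdir(folder):
                raise FileNotFoundError(f"missing folder: {folder}")
            # Count only files with an image extension, as CustomCIFAR10 does.
            counts[(style, class_name)] = sum(
                1 for f in os.listdir(folder) if f.lower().endswith(IMAGE_EXTS)
            )
    return counts


# Usage on a tiny fake tree: two empty "images" and one text file per folder.
with tempfile.TemporaryDirectory() as tmp:
    for style in STYLES:
        for class_name in CLASSES:
            folder = os.path.join(tmp, style, class_name)
            os.makedirs(folder)
            for name in ('a.png', 'b.jpg', 'notes.txt'):
                open(os.path.join(folder, name), 'w').close()
    counts = count_images_per_folder(tmp)

print(len(counts), sum(counts.values()))  # 60 folders, 120 images
```

On the real dataset one would call count_images_per_folder('./CIFAR-10G-Ahsan') and look for folders with unexpectedly few images.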

We define a custom transformation pipeline so that the images from our torchvision-style dataset can be fed to CLIP directly, bypassing the CLIP processor's image preprocessing. transforms.Grayscale(num_output_channels=1) converts every image to a single-channel greyscale image, and transforms.ToTensor() turns it into a float tensor of shape (1, H, W) with values in [0, 1]. No resizing and no CLIP-specific mean/std normalisation are applied.

In [16]:
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

Below we define the path to the CIFAR-10G dataset and load it. We also define the DataLoader, which splits the CIFAR-10G dataset into batches of 8 images (in a fixed order, since shuffle=False).

In [17]:
root_dir = './CIFAR-10G-Ahsan'
dataset = CustomCIFAR10(root_dir=root_dir, transform=transform)

dataloader = DataLoader(dataset, batch_size=8, shuffle=False)
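With the default drop_last=False, a DataLoader over a map-style dataset yields ceil(N / batch_size) batches, the last one possibly smaller than 8. The small sketch below (num_batches is a hypothetical helper, not a PyTorch function) shows this arithmetic and what it implies for the 75 batches reported by the progress bar further down: the dataset size is only pinned down to a range, not an exact count.

```python
import math


def num_batches(num_images, batch_size=8, drop_last=False):
    """Hypothetical helper: batches a DataLoader yields for a map-style dataset."""
    if drop_last:
        # Incomplete final batch is discarded.
        return num_images // batch_size
    # Incomplete final batch is kept.
    return math.ceil(num_images / batch_size)


# Which dataset sizes produce exactly 75 batches of (at most) 8 images?
sizes = [n for n in range(1, 1000) if num_batches(n) == 75]
print(min(sizes), max(sizes))  # 593 600
```

In the notebook itself, len(dataset) and len(dataloader) give the exact numbers directly.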

Load the CLIP model and processor. The model we use is "openai/clip-vit-base-patch32"; the processor prepares the text inputs in the format CLIP-ViT-Base expects. The model is moved to the GPU if one is available. The central modification is needed because our images are single-channel greyscale images, while CLIP's patch-embedding layer expects three-channel RGB input: we replace the layer's convolution weights with their mean over the three input channels, so that the layer accepts one-channel images.

In [18]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
with torch.no_grad():
    model.vision_model.embeddings.patch_embedding.weight = torch.nn.Parameter(
        model.vision_model.embeddings.patch_embedding.weight.mean(dim=1, keepdim=True)
    )
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
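The weight-averaging trick can be checked on a small stand-in layer. In the Hugging Face CLIP vision model the patch embedding is a Conv2d with kernel size and stride equal to the patch size (32) and no bias; the sketch below uses a layer of that form but with only 8 output channels instead of 768, and random weights and images. Because convolution is linear in its weights, summing the kernels over the input channels reproduces exactly what the original layer outputs for the same greyscale image copied into R, G and B, while averaging (as in the notebook) gives one third of that.

```python
import torch

torch.manual_seed(0)

# Stand-in for CLIP's patch embedding: 3 input channels, 32x32 patches, stride 32, no bias.
rgb_conv = torch.nn.Conv2d(3, 8, kernel_size=32, stride=32, bias=False)

gray = torch.rand(2, 1, 64, 64)        # batch of 1-channel images
gray_as_rgb = gray.repeat(1, 3, 1, 1)  # the same image copied into R, G and B

with torch.no_grad():
    reference = rgb_conv(gray_as_rgb)

    # Option A (the notebook's choice): average the kernels over input channels.
    mean_conv = torch.nn.Conv2d(1, 8, kernel_size=32, stride=32, bias=False)
    mean_conv.weight.copy_(rgb_conv.weight.mean(dim=1, keepdim=True))

    # Option B: sum the kernels over input channels.
    sum_conv = torch.nn.Conv2d(1, 8, kernel_size=32, stride=32, bias=False)
    sum_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))

    out_mean = mean_conv(gray)
    out_sum = sum_conv(gray)

print(torch.allclose(out_sum, reference, atol=1e-5))       # True
print(torch.allclose(3 * out_mean, reference, atol=1e-5))  # True
```

Since CLIP adds class and position embeddings to the patch embeddings and then applies a LayerNorm, scaling the patch embeddings by one third is in general not neutral. This sketch only shows the relationship between the two options; it does not measure how the choice affects the shape-bias figure.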

Function for zero-shot image classification

This is where the cosine similarities and predictions are computed. Text embeddings of the ten class names and image embeddings are computed and L2-normalised, so that their dot product equals their cosine similarity. Each image is assigned the class whose text embedding has the highest cosine similarity with it. This is done for the complete dataset, and the accuracy is returned at the end.

In [19]:
def zero_shot_classification(model, processor, dataloader, device):
    model.eval()

    # Encode the class names once and L2-normalise them (no need to repeat this per batch).
    text_inputs = processor(text=CIFAR10G_CLASSES, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_features = model.get_text_features(**text_inputs)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    correct_predictions = 0
    total_images = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader):
            images = images.to(device)
            labels = labels.to(device)

            # Encode the greyscale images and L2-normalise the embeddings.
            image_features = model.get_image_features(pixel_values=images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Scaled cosine similarities -> class probabilities (softmax does not change the argmax).
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            predicted_class = similarity.argmax(dim=1)

            correct_predictions += (predicted_class == labels).sum().item()
            total_images += labels.size(0)

    accuracy = correct_predictions / total_images
    return accuracy
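The core of the function above can be checked in isolation on random vectors. The sketch below assumes 512-dimensional embeddings (the projection size of CLIP ViT-B/32) and uses random tensors in place of real CLIP outputs. It shows that normalising both sides and taking a matrix product gives the same values as PyTorch's built-in cosine similarity, and that the factor of 100 followed by softmax leaves the predicted class unchanged.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Random stand-ins for CLIP outputs: 4 image embeddings, 10 class-name embeddings.
image_features = torch.randn(4, 512)
text_features = torch.randn(10, 512)

# L2-normalise, then one matrix product gives every image-class cosine similarity.
image_norm = image_features / image_features.norm(dim=-1, keepdim=True)
text_norm = text_features / text_features.norm(dim=-1, keepdim=True)
cosine = image_norm @ text_norm.T  # shape (4, 10)

# The same values from PyTorch's built-in cosine similarity, via broadcasting.
builtin = F.cosine_similarity(image_features[:, None, :], text_features[None, :, :], dim=-1)

# Scaling by 100 and applying softmax turns similarities into probabilities,
# but it is monotonic within each row, so the argmax (predicted class) is unchanged.
probs = (100.0 * cosine).softmax(dim=-1)

print(torch.allclose(cosine, builtin, atol=1e-5))              # True
print(torch.equal(cosine.argmax(dim=1), probs.argmax(dim=1)))  # True
```

For accuracy alone, the softmax could therefore be dropped; it only matters if the class probabilities themselves are needed.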

The accuracy we obtained on CIFAR-10 was 87.06%, so we use it as the total accuracy. To calculate the shape bias we use the following formula:

$$
\text{Shape Bias} = \frac{\text{Shape Accuracy}}{\text{Total Accuracy}}
$$
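The formula is a plain ratio, so both accuracies must be on the same scale (both fractions or both percentages). A minimal sketch, with shape_bias as a hypothetical helper name; the notebook's print expression (accuracy*100/87.06)*100 is shape_bias(accuracy*100, 87.06) expressed as a percentage.

```python
def shape_bias(shape_accuracy, total_accuracy):
    """Hypothetical helper: shape accuracy divided by total accuracy.

    Both arguments must use the same scale (both fractions or both percentages).
    """
    if total_accuracy <= 0:
        raise ValueError("total_accuracy must be positive")
    return shape_accuracy / total_accuracy


# Illustrative example: a model keeping half of its CIFAR-10 accuracy on CIFAR-10G.
print(shape_bias(43.53, 87.06))  # 0.5
```

Inverting the ratio, a printed shape bias of 77.724% with a total accuracy of 87.06% corresponds to a CIFAR-10G accuracy of about 67.67%.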


The CIFAR-10G accuracy returned by zero_shot_classification is the shape accuracy and goes in the numerator. The following cell runs the function and computes the shape bias; the factors of 100 express the result as a percentage.

In [20]:
accuracy = zero_shot_classification(model, processor, dataloader, device)

print(f"Shape Bias is : {(accuracy*100/87.06)*100:.3f}%")


100%|██████████| 75/75 [01:38<00:00,  1.32s/it]

Shape Bias is : 77.724%



