Evaluating Color Bias of CLIP-ViT-Base

What Is Color Bias?

$$
\text{Color Bias} = \frac{\text{Color Accuracy}}{\text{Total Accuracy}}
$$

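Color bias is a model's overreliance on color, rather than on other features such as shape or texture, when making predictions. We measure it with two datasets: MNIST and Colorized MNIST. Total Accuracy is the zero-shot accuracy on the original MNIST test set, and Color Accuracy is the zero-shot accuracy on Colorized MNIST.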
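
As a minimal sketch of the ratio defined above, the helper below divides one accuracy by the other. The function name `color_bias` and the two example accuracies are hypothetical and serve only to illustrate the arithmetic; they are not results from this notebook.

```python
def color_bias(color_accuracy: float, total_accuracy: float) -> float:
    """Ratio of accuracy on the colorized data to accuracy on the original data."""
    if total_accuracy <= 0:
        raise ValueError("total_accuracy must be positive")
    return color_accuracy / total_accuracy


# Hypothetical accuracies, chosen only to illustrate the arithmetic.
example = color_bias(color_accuracy=0.25, total_accuracy=0.50)
print(f"Color Bias: {example:.2f}")
```

Both inputs are fractions in [0, 1], which matches what the classification function in this notebook returns.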

Make All Necessary Imports

In [6]:
import os

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

Load the CLIP model and processor. The model used here is "openai/clip-vit-base-patch32". The processor prepares the image and text inputs that CLIP-ViT-Base expects. The model is moved to the GPU if one is available.

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

Load the MNIST Dataset and Apply the Required Transformations

In [8]:
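mnist_test = datasets.MNIST(root='./data', train=False, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize((224, 224)),  # input resolution used by CLIP ViT-B/32
                                 transforms.Grayscale(num_output_channels=3),  # replicate the single channel to 3 channels
                                 transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
                                 # Note: these are the ImageNet statistics; CLIPProcessor uses its own mean/std.
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                      std=[0.229, 0.224, 0.225]),
                             ]))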
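
To make the effect of the `Normalize` step concrete, here is a small sketch of the per-channel arithmetic it applies, `(value - mean) / std`. The helper name `normalize_pixel` is hypothetical; the mean and standard deviation are the red-channel values from the transform above.

```python
def normalize_pixel(value: float, mean: float, std: float) -> float:
    """Per-channel normalization as applied by transforms.Normalize."""
    return (value - mean) / std


# Red-channel mean and std from the transform above.
RED_MEAN, RED_STD = 0.485, 0.229

black = normalize_pixel(0.0, RED_MEAN, RED_STD)  # background pixel after ToTensor
white = normalize_pixel(1.0, RED_MEAN, RED_STD)  # stroke pixel after ToTensor
print(f"black -> {black:.3f}, white -> {white:.3f}")
```

After this step the pixel values are no longer in [0, 1], so any two datasets being compared should go through the same normalization to keep their inputs on the same scale.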

Define the DataLoader, which splits the dataset into batches, and the class names for the MNIST dataset.

In [9]:
mnist_dataloader = DataLoader(mnist_test, batch_size=128, shuffle=False)

mnist_classes = [str(i) for i in range(10)]
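
As a quick sanity check on the batch size, the sketch below computes how many batches a DataLoader yields when `drop_last` is left at its default of `False`. The helper `num_batches` is a hypothetical name; the 10,000 figure is the size of the standard MNIST test split.

```python
import math


def num_batches(num_samples: int, batch_size: int) -> int:
    """Number of batches when the last, smaller batch is kept."""
    return math.ceil(num_samples / batch_size)


MNIST_TEST_SIZE = 10_000  # size of the standard MNIST test split
print(num_batches(MNIST_TEST_SIZE, 128))  # 79
```

The result, 79, matches the `79/79` iteration count shown by the progress bar in the evaluation run.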

Function for Zero-Shot Image Classification - Calculating Total Accuracy

In [10]:
def zero_shot_classification(model, processor, dataloader, device):
    model.eval()

    # Encode the class names once; they act as the zero-shot text prompts.
    text_inputs = processor(text=mnist_classes, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_features = model.get_text_features(**text_inputs)
    # Normalize once, outside the loop, so that dot products are cosine similarities.
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    correct_predictions = 0
    total_images = 0

    for images, labels in tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            image_features = model.get_image_features(pixel_values=images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Scaled cosine similarity between every image and every class prompt.
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        # The most similar prompt is the predicted class.
        predicted_class = similarity.argmax(dim=1)

        correct_predictions += (predicted_class == labels).sum().item()
        total_images += labels.size(0)

    # Accuracy is returned as a fraction in [0, 1] and printed as a percentage.
    accuracy = correct_predictions / total_images
    print(f"Total accuracy: {accuracy * 100:.2f}%")
    return accuracy
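
The core of the function above is a nearest-prompt lookup: normalize the embeddings, take dot products, and pick the largest. The sketch below reproduces that step in plain Python on toy 2-D vectors. The vectors are hypothetical values, not real CLIP features, and `l2_normalize` and `predict` are illustrative names.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def predict(image_vec, text_vecs):
    """Return the index of the text vector with the highest cosine similarity."""
    image_unit = l2_normalize(image_vec)
    sims = [
        sum(a * b for a, b in zip(image_unit, l2_normalize(t)))
        for t in text_vecs
    ]
    return max(range(len(sims)), key=sims.__getitem__)


# Toy 2-D embeddings (hypothetical values, not real CLIP features).
text_vecs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
image_vec = [0.9, 1.1]
print(predict(image_vec, text_vecs))  # 2
```

Because softmax preserves the ordering of its inputs, applying it before `argmax` does not change which class is selected; it only turns the scores into probabilities.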


Function Call to Calculate and Report Total Accuracy

In [11]:
Total_Accuracy = zero_shot_classification(model, processor, mnist_dataloader, device)

100%|██████████| 79/79 [17:27<00:00, 13.26s/it]

Total accuracy: 32.66%





Next, we compute the Color Accuracy. The first step is to load the Colorized MNIST dataset, which the following class does.

In [12]:
class ColorizedMNISTDataset(Dataset):
    """Loads images from root_dir/<digit>/ folders; the folder name gives the label."""

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        
        for idx, category in enumerate(mnist_classes):
            category_dir = os.path.join(root_dir, category)
            
            for img_file in os.listdir(category_dir):
                if img_file.endswith(('.png', '.jpg', '.jpeg')):  
                    img_path = os.path.join(category_dir, img_file)
                    self.image_paths.append(img_path)
                    self.labels.append(idx)  

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')

        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)
        
        return image, label
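
The class above expects one sub-folder per digit under the root directory. The sketch below shows that layout on a throwaway directory tree and collects `(path, label)` pairs in the same way. The helper `collect_samples` and the file names are hypothetical; sorting the directory listing is an added choice to make the order stable, since `os.listdir` returns entries in arbitrary order.

```python
import os
import tempfile


def collect_samples(root_dir, class_names, extensions=(".png", ".jpg", ".jpeg")):
    """Return (path, label) pairs from root_dir/<class_name>/ folders."""
    samples = []
    for label, name in enumerate(class_names):
        class_dir = os.path.join(root_dir, name)
        for file_name in sorted(os.listdir(class_dir)):  # sorted for a stable order
            if file_name.endswith(extensions):
                samples.append((os.path.join(class_dir, file_name), label))
    return samples


# Build a tiny throwaway directory tree to show the expected layout.
with tempfile.TemporaryDirectory() as root:
    for name in ("0", "1"):
        os.makedirs(os.path.join(root, name))
        open(os.path.join(root, name, "a.png"), "wb").close()
    open(os.path.join(root, "1", "notes.txt"), "wb").close()  # ignored: not an image
    samples = collect_samples(root, ["0", "1"])

labels = [label for _, label in samples]
print(labels)  # [0, 1]
```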

Define the transformations applied to the dataset: resizing and conversion to a tensor.

In [13]:
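transform = transforms.Compose([
    transforms.Resize((224, 224)),  # input resolution used by CLIP ViT-B/32
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    # Note: no Normalize step is applied here, unlike in the MNIST pipeline.
])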

Load the Colorized MNIST dataset and split it into batches with a DataLoader.

In [14]:
root_dir = './Colorized_MNIST'
colorized_mnist_dataset = ColorizedMNISTDataset(root_dir=root_dir, transform=transform)
colorized_mnist_dataloader = DataLoader(colorized_mnist_dataset, batch_size=64, shuffle=True)


Now calculate the zero-shot accuracy on the Colorized MNIST dataset. The function is redefined below with a different print message.

In [15]:
def zero_shot_classification(model, processor, dataloader, device):
    model.eval()

    # Encode the class names once; they act as the zero-shot text prompts.
    text_inputs = processor(text=mnist_classes, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_features = model.get_text_features(**text_inputs)
    # Normalize once, outside the loop, so that dot products are cosine similarities.
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    correct_predictions = 0
    total_images = 0

    for images, labels in tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            image_features = model.get_image_features(pixel_values=images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Scaled cosine similarity between every image and every class prompt.
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        # The most similar prompt is the predicted class.
        predicted_class = similarity.argmax(dim=1)

        correct_predictions += (predicted_class == labels).sum().item()
        total_images += labels.size(0)

    accuracy = correct_predictions / total_images
    print(f"Zero-shot accuracy on Colorized MNIST: {accuracy * 100:.2f}%")
    return accuracy


Now, to Get the Color Accuracy, We Can Do the Following

In [16]:
Color_accuracy = zero_shot_classification(model, processor, colorized_mnist_dataloader, device)
print(f"Color Accuracy is : {(Color_accuracy)*100:.2f}%")

100%|██████████| 157/157 [16:55<00:00,  6.47s/it]

Zero-shot accuracy on Colorized MNIST: 25.85%
Color Accuracy is : 25.85%





As a reminder, Color Bias is:

$$
\text{Color Bias} = \frac{\text{Color Accuracy}}{\text{Total Accuracy}}
$$

In [17]:
Color_Bias = Color_accuracy/Total_Accuracy

print(f"Color Bias is : {(Color_Bias)*100:.2f}%")

Color Bias is : 79.15%
