In [6]:
import torch
from transformers import CLIPModel, CLIPProcessor
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from torchvision import transforms
import os 
from PIL import Image
from torch.utils.data import Dataset

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

class StyledImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # List all image files starting with 'styled_image_'
        self.image_filenames = [f for f in os.listdir(root_dir) if f.startswith('styled_image_')]

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_filenames[idx])
        image = Image.open(img_name).convert('RGB')
        label = int(self.image_filenames[idx].split('_')[2].split('.')[0])

        if self.transform:
            image = self.transform(image)

        return image, label

# Define the directory containing your dataset
data_dir = './CIFAR_Styled'

# Mapping folder names to class indices (CIFAR-10 class names)
class_map = {
    'airplane': 0,
    'automobile': 1,
    'bird': 2,
    'cat': 3,
    'deer': 4,
    'dog': 5,
    'frog': 6,
    'horse': 7,
    'ship': 8,
    'truck': 9
}

# Image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to fit the model input size
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])

# Initialize dataset and dataloader
dataset = StyledImageDataset(root_dir=data_dir, transform=transform)
test_loader = DataLoader(dataset, batch_size=32, shuffle=False)

# Prepare text features for the classes
class_names = list(class_map.keys())
text_inputs = processor(text=class_names, return_tensors="pt", padding=True).to(device)

# Zero-shot classification function
def zero_shot_classification(loader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(loader):
            images = images.to(device)

            # Get image features
            image_features = model.get_image_features(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)  # Normalize image features

            # Get text features
            text_features = model.get_text_features(**text_inputs)
            text_features /= text_features.norm(dim=-1, keepdim=True)  # Normalize text features

            # Compute similarity scores
            logits_per_image = image_features @ text_features.T  # Dot product
            probs = logits_per_image.softmax(dim=-1)  # Convert to probabilities

            # Predicted class indices
            predictions = probs.argmax(dim=-1)

            # Update correct predictions count
            correct += (predictions.cpu() == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    return accuracy

# Calculate accuracy
accuracy = zero_shot_classification(test_loader)
print(f'Zero-shot classification accuracy: {accuracy * 100:.2f}%')


100%|██████████| 4/4 [00:02<00:00,  1.99it/s]

Zero-shot classification accuracy: 52.00%



