# **Task 1: Diffusion Model Zero-Shot Classification**

In [1]:
%%capture
# Install the `diffusers` package, which provides the StableDiffusionPipeline
# imported in the next cell. The %%capture cell magic above hides pip's output.
!pip install diffusers

In [2]:
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import numpy as np
from torchvision import transforms
from torchvision.models import resnet50
from torch.nn import CosineSimilarity
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Default to the CPU and switch to the GPU only when PyTorch reports that
# CUDA is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [5]:
def zero_shot_classification(image_path, categories):
    """Pick the category whose generated image is most similar to the input image.

    For every entry in ``categories`` an image is generated with Stable
    Diffusion from the prompt ``"A photo of a {category}"``. The input image
    and each generated image are passed through a pretrained ResNet-50, and
    the category whose output vector has the highest cosine similarity to the
    input image's output vector is returned.

    Args:
        image_path: Path (or file object) of the image to classify; it is
            opened with ``PIL.Image.open`` and converted to RGB.
        categories: Non-empty, indexable sequence of category names.

    Returns:
        The element of ``categories`` with the highest cosine similarity.

    Raises:
        ValueError: If ``categories`` is empty.

    Note:
        No random generator/seed is passed to the pipeline, so the generated
        images (and therefore the prediction) may differ between runs.
        The Stable Diffusion pipeline and ResNet-50 are loaded on every call.
    """
    # Fail fast: otherwise the expensive pipeline is loaded and run before
    # np.argmax raises a ValueError on the empty similarity list.
    if len(categories) == 0:
        raise ValueError("categories must contain at least one entry")

    model_id = "CompVis/stable-diffusion-v1-4"
    print(f"Loading {model_id}...")
    # Half precision is only requested on CUDA; on CPU fall back to float32 so
    # the function still works with the notebook's CPU fallback for `device`.
    on_gpu = str(device).startswith("cuda")
    pipeline_dtype = torch.float16 if on_gpu else torch.float32
    model = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=pipeline_dtype)
    model = model.to(device)

    print("Prprocessing image...")
    input_image = Image.open(image_path).convert("RGB")
    # Same resize + normalisation is applied to the input and generated images.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Use the notebook-level `device` everywhere instead of a hard-coded
    # "cuda", which crashed on machines without a GPU.
    input_tensor = preprocess(input_image).unsqueeze(0).to(device)

    print("Loading Resnet 50...")
    # NOTE(review): `pretrained=True` is deprecated in recent torchvision in
    # favour of the `weights=` argument; kept to preserve the loaded weights.
    feature_extractor = resnet50(pretrained=True).to(device)
    feature_extractor.eval()

    with torch.no_grad():
        input_features = feature_extractor(input_tensor)

    print("Generating Images and Computing Similarities")
    cosine = CosineSimilarity(dim=1)  # built once, reused for every category
    similarities = []
    for category in categories:
        prompt = f"A photo of a {category}"
        generated_image = model(prompt).images[0]
        generated_tensor = preprocess(generated_image).unsqueeze(0).to(device)
        with torch.no_grad():
            generated_features = feature_extractor(generated_tensor)
        similarities.append(cosine(input_features, generated_features).item())

    # argmax returns the first index on ties, so earlier categories win.
    predicted_category = categories[int(np.argmax(similarities))]
    return predicted_category

In [7]:
# Classify one sample image against a small set of candidate labels and
# report the label chosen by zero_shot_classification.
image_path = "/kaggle/input/horse-image/image-10.png"
categories = [
    "cat",
    "dog",
    "bird",
    "car",
    "house",
    "horse",
    "zebra",
]
prediction = zero_shot_classification(image_path=image_path, categories=categories)
print(f"The predicted category is: {prediction}")

Loading CompVis/stable-diffusion-v1-4...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Prprocessing image...
Loading Resnet 50...
Generating Images and Computing Similarities


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

The predicted category is: horse
