In [1]:
import requests
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

In [2]:
class CLIPClassifier:
    """Zero-shot image classifier backed by CLIP (openai/clip-vit-base-patch32).

    The image is downloaded once at construction time and preprocessed together
    with the candidate ``labels``; ``predict`` scores every label against the
    image and returns the most likely one.

    ``processor`` and ``model`` are class attributes: they are loaded once, when
    this class is defined, and shared by every instance.
    """

    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

    def __init__(self, imgURL: str, labels: list[str], timeout: float = 30.0):
        """Download the image at ``imgURL`` and preprocess it with ``labels``.

        Args:
            imgURL: HTTP(S) URL of the image to classify.
            labels: Candidate class names; must be non-empty.
            timeout: Seconds to wait on the image download before giving up.

        Raises:
            ValueError: If ``labels`` is empty.
            requests.HTTPError: If the server responds with an error status.
            requests.Timeout: If the download exceeds ``timeout``.
        """
        if not labels:
            raise ValueError("labels must contain at least one candidate class")
        # Without a timeout a stalled server would hang the kernel forever; the
        # context manager releases the connection once the image has been read.
        with requests.get(imgURL, stream=True, timeout=timeout) as response:
            # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            # `.raw` passes gzip/deflate Content-Encoding through untouched;
            # ask urllib3 to decode it so PIL sees the actual image bytes.
            response.raw.decode_content = True
            # CLIP expects 3-channel input; grayscale / RGBA / palette images
            # would otherwise fail or be mis-normalised. `convert` also forces
            # PIL's lazy loader to read all pixels before the stream is closed.
            img = Image.open(response.raw).convert("RGB")
        # Copy so later mutation of the caller's list cannot desync labels/inputs.
        self.labels = list(labels)
        self.inputs = CLIPClassifier.processor(
            text=self.labels, images=img, return_tensors="pt", padding=True
        )

    def fit(self):
        """Run CLIP's forward pass on the preprocessed image/label inputs.

        Despite the name, nothing is trained: CLIP is used zero-shot, so this
        is pure inference. Gradient tracking is disabled because nothing is
        ever back-propagated, which avoids building an autograd graph.

        Returns:
            The model output object; ``predict`` reads its ``logits_per_image``.
        """
        with torch.no_grad():
            return CLIPClassifier.model(**self.inputs)

    def predict(self, printInfo=True):
        """Return the label CLIP scores as most similar to the image.

        Args:
            printInfo: If true, print each label's softmax probability
                followed by the prediction.

        Returns:
            The winning entry of ``self.labels``.
        """
        outputs = self.fit()
        # logits_per_image has one row per image (a single image was passed to
        # the processor) and one column per label; softmax over the label axis
        # turns the similarity scores into probabilities. Take row 0 so argmax
        # and the per-label loop operate on that image only.
        probs = outputs.logits_per_image.softmax(dim=1)[0]
        pred = self.labels[probs.argmax().item()]
        if printInfo:
            print("=====================")
            print("".join(
                f"label: {label}, prob: {prob.item()}\n"
                for label, prob in zip(self.labels, probs)
            ))
            print(f"Prediction: {pred}")
            print("=====================")
        return pred



### A Photo Of A Cat

<img src="http://images.cocodataset.org/val2017/000000039769.jpg" alt="A photo of a cat" width="200"/>

In [3]:
# Zero-shot check: CLIP picks whichever candidate label best matches the photo.
cat = CLIPClassifier(
    imgURL="http://images.cocodataset.org/val2017/000000039769.jpg",
    labels=["cat", "dog"],
)
cat.predict()

label: cat, prob: 0.9932803511619568
label: dog, prob: 0.006719659082591534

Prediction: cat


'cat'

### A Photo Of A Dog

<img src="https://images.unsplash.com/photo-1587402092301-725e37c70fd8?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cHVwcHklMjBkb2d8ZW58MHx8MHx8&w=1000&q=80" alt="A photo of a dog" width="200"/>

In [4]:
# Same label pair as the cat example, so the two results are directly comparable.
dog = CLIPClassifier(
    imgURL="https://images.unsplash.com/photo-1587402092301-725e37c70fd8?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cHVwcHklMjBkb2d8ZW58MHx8MHx8&w=1000&q=80",
    labels=["cat", "dog"],
)
dog.predict()

label: cat, prob: 0.003522968152537942
label: dog, prob: 0.9964770674705505

Prediction: dog


'dog'

### A Photo Of A Snake

<img src="https://images.foxtv.com/static.foxla.com/www.foxla.com/content/uploads/2021/12/764/432/snake.jpg?ve=1&tl=1" alt="A photo of a snake" width="200"/>

In [5]:
# "hose" is a deliberately look-alike distractor label for the snake photo.
snake = CLIPClassifier(
    imgURL="https://images.foxtv.com/static.foxla.com/www.foxla.com/content/uploads/2021/12/764/432/snake.jpg?ve=1&tl=1",
    labels=["snake", "hose"],
)
snake.predict()

label: snake, prob: 0.9983910918235779
label: hose, prob: 0.001608935184776783

Prediction: snake


'snake'

### A Photo Of Oranges

<img src="http://images.cocodataset.org/val2017/000000050896.jpg" alt="A photo of oranges" width="200"/>

In [6]:
# Three visually similar citrus labels: a harder, fine-grained zero-shot case.
oranges = CLIPClassifier(
    imgURL="http://images.cocodataset.org/val2017/000000050896.jpg",
    labels=["orange", "tangerine", "lemon"],
)
oranges.predict()

label: orange, prob: 0.7922490835189819
label: tangerine, prob: 0.1991378664970398
label: lemon, prob: 0.008613087236881256

Prediction: orange


'orange'