# What is this?
An AI that will guess your Halloween costume! More specifically, it is a program that matches images of Halloween costumes against [5,000 pre-defined](https://raw.githubusercontent.com/janelleshane/halloween-costume-dataset/master/costumes) possible Halloween costumes.
# How does this work?
The program uses [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) to calculate a similarity score between the uploaded image and every Halloween costume idea in this [dataset](https://raw.githubusercontent.com/janelleshane/halloween-costume-dataset/master/costumes). In other words, CLIP performs [zero-shot](https://en.wikipedia.org/wiki/Zero-shot_learning) image classification over a dataset of 5,000 Halloween costume ideas. Unfortunately, if your costume is not in the dataset, this program will be unable to guess it correctly, though it might still find an entry in the dataset that is a decent match for your costume.
# Things to note
Uploading a larger image takes longer.

Be sure to use a GPU runtime.

In [None]:
# Install the transformers and gradio packages that the import cell below relies on
!pip install transformers
!pip install gradio

In [None]:
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel # import the CLIP model from huggingface
import gradio as gr
import requests

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fetch the pretrained CLIP preprocessor and model from huggingface,
# then move the model onto the selected device
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

In [None]:

def preprocess_costume(costume):
    """Turn a raw costume idea into the text prompt that is given to CLIP.

    Whitespace surrounding the idea is removed and " halloween costume" is
    appended. The suffix encourages CLIP to interpret the image as a costume
    of the idea instead of as an actual photograph of the idea.
    """
    idea = costume.strip()
    return idea + " halloween costume"


def get_logits(costume_image, costumes):
    """Score the provided image against every prompt in ``costumes``.

    Relies on the notebook-level ``processor``, ``model`` and ``device``.
    Returns the model's ``logits_per_image`` output, i.e. the image-text
    similarity scores for this batch of costume prompts.
    """
    # Convert the prompts and the image into padded PyTorch tensors and move
    # them to the device the model lives on.
    batch = processor(
        text=costumes,
        images=costume_image,
        return_tensors="pt",
        padding=True,
    )
    batch = batch.to(device)

    # We only run inference here, so no gradients need to be tracked.
    with torch.no_grad():
        return model(**batch).logits_per_image


def get_top_k(costume_image, costumes, k=5, batch_size=512):
    """Return the ``k`` most probable costume ideas for an image.

    A similarity score is calculated for every entry in ``costumes`` (in
    batches, to not blow up our memory usage), the scores are normalized to a
    probability distribution over all the costume ideas, and the ``k`` entries
    with the highest probability are returned.

    Args:
        costume_image: The image to classify; passed through to ``get_logits``.
        costumes: List of costume prompts to score the image against.
        k: Number of results to return. Values larger than ``len(costumes)``
            are clamped; values below 1 yield an empty result.
        batch_size: Number of prompts scored per ``get_logits`` call.

    Returns:
        A dict mapping costume prompt to probability, ordered from most to
        least probable. Empty when there is nothing to rank.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Nothing to rank: torch.cat would fail on an empty list of batches.
    top_count = min(k, len(costumes))
    if top_count < 1:
        return {}

    batch_logits = [
        get_logits(costume_image, costumes[start:start + batch_size])
        for start in range(0, len(costumes), batch_size)
    ]

    # Stitch the per-batch scores back together along the costume dimension
    # and take the softmax over *all* costumes, so the probabilities are
    # comparable across batches. Index 0 selects the single image's row.
    all_logits = torch.cat(batch_logits, dim=1)
    probs = all_logits.softmax(dim=1)[0]

    # topk returns the largest entries in descending order without sorting
    # the whole distribution.
    top_probs, top_indices = torch.topk(probs, top_count)

    return {
        costumes[index]: prob
        for index, prob in zip(top_indices.tolist(), top_probs.tolist())
    }


def classify(image, costume_list_url="https://raw.githubusercontent.com/janelleshane/halloween-costume-dataset/master/costumes"):
    """Return the top 5 most probable costumes for the given image.

    The costume ideas are downloaded from ``costume_list_url`` (one idea per
    line), turned into prompts with ``preprocess_costume``, de-duplicated and
    ranked against the image with ``get_top_k``.

    Args:
        image: The image to classify; passed through to ``get_top_k``.
        costume_list_url: URL of a plain-text file with one costume idea per line.

    Returns:
        The dict produced by ``get_top_k``: costume prompt -> probability.

    Raises:
        requests.RequestException: If the costume list cannot be downloaded
            (connection problem, timeout or HTTP error status).
    """
    # If you run this locally you may want to read your own file of costume ideas instead:
    # with open(costume_list_file, "r") as f:
    #     costume_list = f.read()

    # Bound the wait so a stalled download cannot hang the app forever.
    response = requests.get(costume_list_url, timeout=30)
    # An HTTP error page must not be parsed as if it were a list of costume ideas.
    response.raise_for_status()
    costume_list = response.text

    # Skip blank lines (e.g. the one after a trailing newline); they would
    # otherwise become a bogus " halloween costume" candidate. The set removes
    # duplicate ideas and sorting keeps the candidate order reproducible.
    costumes = sorted({
        preprocess_costume(line)
        for line in costume_list.splitlines()
        if line.strip()
    })

    return get_top_k(image, costumes)

In [None]:
# Host the costume classifier through a gradio app
demo = gr.Interface(
    fn=classify,
    inputs="image",
    outputs="label",
    title="Costume Classifier",
    description="Upload an image and CLIP will guess what costume it is!",
    # Gradio expects one of the strings "never" / "auto" / "manual" here;
    # the boolean form is deprecated.
    allow_flagging="never",
)

# Enable request queuing on the app itself. The `enable_queue` argument of
# `launch()` is deprecated in Gradio 3 and no longer accepted by Gradio 4+,
# which is what the unpinned `pip install gradio` above will install.
demo.queue()

# share=True requests a public share link; debug=True keeps the cell running
# so errors are printed in the notebook output.
demo.launch(
    share=True,
    debug=True,
)