### Necessary imports

In [None]:
import os
import PIL
import clip
import torch
import shutil
import requests
from glob import glob
from pathlib import Path
from tqdm.auto import tqdm
from urllib.parse import urlparse, unquote

### Enable GPU support

In [None]:
# Run on the CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Select and load the CLIP model (vision transformer backbone)

In [None]:
# CLIP backbone to load. clip.load() accepts the names returned by
# clip.available_models() (or a path to a local checkpoint); 'ViT-B/32'
# is one alternative to the default below.
model_name = 'ViT-L/14'

# `preprocess` converts a PIL image into the input tensor expected by `model`.
model, preprocess = clip.load(model_name, device=device)

### Categorizing function

In [None]:
def categorize_images(labels, src_dir, img_extension='*.*', dest_folder=None, pred_threshold=0.6, verbose=False):
    """Classify the images in ``src_dir`` with CLIP and copy them into per-label folders.

    Each label is turned into the text prompt ``'a photo of a {label}'``. Every
    image is compared against all prompts, and the scores are softmax-normalized
    across ``labels``, so an image's probabilities sum to 1. The image is copied
    to ``<dest>/<label>/`` for every label whose probability is strictly greater
    than ``pred_threshold``. With a threshold below 0.5, one image can be copied
    into several folders. An image with no label above the threshold is not
    copied at all.

    Uses the notebook-level globals ``model``, ``preprocess`` and ``device``
    (created earlier in the notebook via ``clip.load``).

    Args:
        labels: Class names or short descriptive phrases. A sub-folder named
            after each label is created in the destination directory.
        src_dir: Directory containing the images to classify (not searched
            recursively).
        img_extension: Glob pattern selecting files in ``src_dir``. Only
            regular files are considered; directories are ignored.
        dest_folder: Root directory for the label folders. Defaults to ``src_dir``.
        pred_threshold: Minimum softmax probability an image needs for a label
            before it is copied into that label's folder.
        verbose: If True, print every label's probability for each image.
    """
    import warnings

    from PIL import Image

    # Collect candidate files before creating the label folders, and keep only
    # regular files (a pattern like '*.*' can also match directories).
    # Sorting gives a deterministic processing order.
    filepaths = sorted(p for p in Path(src_dir).glob(img_extension) if p.is_file())

    dest_path = src_dir if dest_folder is None else dest_folder

    # Create one folder per label (= predicted class) if it does not exist yet.
    for label in labels:
        Path(dest_path, label).mkdir(parents=True, exist_ok=True)

    # Encode and L2-normalize the label prompts once; all images share them.
    text_inputs = torch.cat([clip.tokenize(f'a photo of a {label}') for label in labels]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    for filepath in tqdm(filepaths):
        # Open one image at a time and close it as soon as it is preprocessed,
        # so file handles are not kept open for the whole run.
        try:
            with Image.open(filepath) as image:
                image_input = preprocess(image).unsqueeze(0).to(device)
        except OSError as err:
            # Not an image, unreadable, or truncated: skip it instead of
            # aborting the whole batch.
            warnings.warn(f'Skipping {filepath}: {err}')
            continue

        with torch.no_grad():
            image_features = model.encode_image(image_input)
            image_features /= image_features.norm(dim=-1, keepdim=True)

        # Scaled cosine similarities, softmax-normalized over the labels.
        # topk(len(labels)) simply sorts all labels by descending probability.
        similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
        values, indices = similarity[0].topk(len(labels))

        if verbose:
            print(f'{filepath.name}:')
        for prob, index in zip(values.tolist(), indices.tolist()):
            pred_label = labels[index]

            if verbose:
                # Report the label's own probability, even when it is low.
                print(f'  [{pred_label}] probability: {100 * prob:.2f}%')

            if prob > pred_threshold:
                shutil.copy(filepath, Path(dest_path, pred_label, filepath.name))

    if verbose:
        print(f'{"-"*40}\nDone.')

# Demo

### Define the sources of the uncategorized images (here: image URLs from Wikimedia Commons)

In [None]:
# Sample images hosted on Wikimedia Commons (upload.wikimedia.org):
# 8 cat photos and 7 frog photos to be sorted by categorize_images.
cat_urls = [
    'https://upload.wikimedia.org/wikipedia/commons/7/76/TapetumLucidum.JPG',
    'https://upload.wikimedia.org/wikipedia/commons/1/12/Tabby_cat_with_visible_nictitating_membrane.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/4/4f/Felis_silvestris_catus_lying_on_rice_straw.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/b/bb/Kittyply_edit1.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/3/3b/Gato_enervado_pola_presencia_dun_can.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/b/b6/Felis_catus-cat_on_snow.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/5/5e/Domestic_Cat_Face_Shot.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/0/0c/Black_Cat_%287983739954%29.jpg',
]

frog_urls = [
    'https://upload.wikimedia.org/wikipedia/commons/c/c1/Variegated_golden_frog_%28Mantella_baroni%29_Ranomafana.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/6/6e/R._imitator_Chazuta.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/5/55/Atelopus_zeteki1.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/6/68/Wood_Frog_%28Rana_sylvatica%29_%2825234151669%29.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/4/4f/Bombina_bombina_1_%28Marek_Szczepanek%29_tight_crop.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/a/ab/Dendrobates_pumilio.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/5/5b/Bufo_periglenes2.jpg',
]

### Get the working directory of this Jupyter notebook

In [None]:
# The kernel's current working directory (what `pwd` / `echo %cd%` would print).
# Jupyter normally starts the kernel in the notebook's own directory.
# os.getcwd() is plain Python, so it works on every OS and outside IPython.
current_directory = os.getcwd()

### Create the source directory where the downloaded images will be saved

In [None]:
# Folder that receives the downloaded, not-yet-classified images.
source_directory = Path(current_directory) / 'uncategorized'
source_directory.mkdir(parents=True, exist_ok=True)

### Create the destination directory where the categorized images will be copied to

In [None]:
# Root folder under which one sub-folder per label will hold the copied images.
destination_directory = Path(current_directory) / 'categorized'
destination_directory.mkdir(parents=True, exist_ok=True)

### Download the images

In [None]:
def download_file(url, directory, timeout=30):
    """Download ``url`` into ``directory``, named after the URL's last path segment.

    Percent-escapes in that segment are decoded (e.g. ``%28`` -> ``(``), so the
    file on disk gets a readable name. An existing file with the same name is
    overwritten.

    Args:
        url: HTTP(S) URL of the file to fetch.
        directory: Existing directory to save the file in.
        timeout: Seconds to wait for the server (connect and each read) before
            giving up, so a stalled request cannot hang the notebook forever.

    Returns:
        The ``Path`` of the written file.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: For other network failures and timeouts.
    """
    # Identify the client to the server. This is a placeholder User-Agent;
    # replace it with real contact details.
    headers = {'User-Agent': 'CoolBot/0.0 (https://example.org/coolbot/; coolbot@example.org)'}
    filename = Path(directory, unquote(Path(urlparse(url).path).name))

    # Stream the body to disk in chunks instead of buffering it all in memory.
    with requests.get(url, headers=headers, timeout=timeout, stream=True) as response:
        # Fail on HTTP errors before the output file is created.
        response.raise_for_status()
        with open(filename, 'wb') as file:
            for chunk in response.iter_content(chunk_size=64 * 1024):
                file.write(chunk)

    return filename

In [None]:
# Download every sample image. A URL that fails (dead link, HTTP error, timeout)
# is reported and skipped, so the remaining images are still downloaded.
for url in tqdm(cat_urls + frog_urls):
    try:
        download_file(url, source_directory)
    except requests.RequestException as err:
        tqdm.write(f'Could not download {url}: {err}')

### Finally, categorize the downloaded images

In [None]:
# Instead of single words, you can also define phrases that describe the content of the respective images
labels = ['cat', 'frog']

categorize_images(labels=labels, src_dir=source_directory, dest_folder=destination_directory, verbose=True)