In [1]:
import torch
from PIL import Image
import open_clip
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
import glob
import shutil

In [3]:
CLASSES_FOLDER = 'classes'


def load_classes(classes_folder=CLASSES_FOLDER, preloaded_classes=None):
    """Return a mapping ``class prompt -> class directory``.

    Two modes:

    * ``preloaded_classes`` given (e.g. ``['a cat', 'a dog', 'a diagram']``):
      a directory is created under ``classes_folder`` for each prompt
      (spaces replaced by ``-``) and exactly those prompts are returned,
      in the order given.
    * otherwise: the existing sub-directories of ``classes_folder`` are
      discovered and their names turned back into prompts by replacing
      ``-`` with spaces. The result is in sorted path order so the class
      list is the same on every run.

    Note: the name round-trip is lossy — a prompt that itself contains
    ``-`` comes back with a space instead.

    Args:
        classes_folder: Root directory holding one sub-directory per class.
        preloaded_classes: Optional list of class prompts to create.

    Returns:
        dict mapping prompt string to its directory path (empty if no
        classes exist).
    """
    if preloaded_classes:
        classes_return = {}
        for class_name in preloaded_classes:
            new_path = os.path.join(classes_folder, class_name.replace(' ', '-'))
            os.makedirs(new_path, exist_ok=True)
            classes_return[class_name] = new_path
        return classes_return

    # Only directories are classes; a stray file (README, .DS_Store, ...)
    # would otherwise become a bogus label and an invalid move target.
    class_dirs = sorted(
        path for path in glob.glob(os.path.join(classes_folder, '*'))
        if os.path.isdir(path)
    )
    return {os.path.basename(path).replace('-', ' '): path for path in class_dirs}


load_classes()

{'a cat': 'classes\\a-cat',
 'a diagram': 'classes\\a-diagram',
 'a dog': 'classes\\a-dog',
 'ball': 'classes\\ball',
 'men': 'classes\\men'}

In [6]:
IMGAGES_PATH = 'data'


def load_images(images_path=IMGAGES_PATH):
    """Load every image file found directly inside ``images_path``.

    No extension filter is applied: Pillow decides what it can read. Files
    it cannot open (e.g. ``.avif`` without a plugin) are reported and
    skipped, so one bad file does not abort the whole run.

    Args:
        images_path: Directory to scan (non-recursive). Defaults to
            ``IMGAGES_PATH``.

    Returns:
        dict mapping each file path to an in-memory PIL image, in sorted
        path order.
    """
    path_dict = {}
    for path in sorted(glob.glob(os.path.join(images_path, '*'))):
        if not os.path.isfile(path):
            continue  # sub-directories are not images
        try:
            # Image.open is lazy and keeps the file handle open. Copying
            # inside the context manager decodes the pixels and releases the
            # handle, so the file can be moved afterwards (notably on Windows).
            with Image.open(path) as img:
                path_dict[path] = img.copy()
        except Exception as e:  # best-effort: report and skip unreadable files
            print(e)
    return path_dict

In [7]:
# Zero-shot classify every image in IMGAGES_PATH against the class prompts
# and move each file into the directory of its best-matching class.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')

# open_clip returns the model in training mode; eval() is required for inference.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device).eval()
# Mixed precision only helps on GPU; on CPU run in full precision.
use_amp = device == 'cuda'

path_to_images_link = load_images()
classes_to_path_link = load_classes()
classes = list(classes_to_path_link.keys())
if not classes:
    raise ValueError(f"No classes found in {CLASSES_FOLDER!r}; create one sub-directory per class first.")

# The prompts do not depend on the image, so encode them once, outside the loop.
text = tokenizer(classes).to(device)
with torch.no_grad(), torch.autocast(device_type=device, enabled=use_amp):
    text_features = model.encode_text(text)
    text_features /= text_features.norm(dim=-1, keepdim=True)

for path_img, x_img in path_to_images_link.items():
    image = preprocess(x_img).unsqueeze(0).to(device)

    with torch.no_grad(), torch.autocast(device_type=device, enabled=use_amp):
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # 100.0 is the fixed temperature used in the open_clip usage example.
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    class_idx = text_probs.argmax(dim=-1).item()  # [[1., 0., 0.]] -> 0
    class_name = classes[class_idx]

    print(f"Image: {path_img} classified as: {class_name}")
    destination_dir = classes_to_path_link[class_name]
    # shutil.move raises if a same-named file already exists in the target dir;
    # skip it instead of aborting the run with some files already moved.
    if os.path.exists(os.path.join(destination_dir, os.path.basename(path_img))):
        print(f"Skipping {path_img}: a file with that name already exists in {destination_dir}")
        continue
    shutil.move(path_img, destination_dir)


cannot identify image file 'data\\czlowiek-gra-w-pilke-nozna_1368-2994.avif'
cannot identify image file 'data\\przystojny-mezczyzna-otwiera-okno-w-domu-aby-odswiezyc-pokoj_264277-1194.avif'




Image: data\CLIP.png classified as: a diagram
Image: data\czlowiek-gra-w-pilke-nozna_1368-2994.jpg classified as: ball
Image: data\przystojny-mezczyzna-otwiera-okno-w-domu-aby-odswiezyc-pokoj_264277-1194.jpg classified as: men


In [14]:
def restart_classes(classes_folder=None, images_path=None):
    """Undo a classification run by moving class-folder files back to the image folder.

    Only files exactly one level below a class directory
    (``<classes_folder>/<class>/<file>``) are moved. The class directories
    themselves are left in place. A file whose name already exists in
    ``images_path`` is reported and left where it is, because
    ``shutil.move`` refuses to overwrite it and would abort the run half-way.

    Args:
        classes_folder: Root of the class directories. Defaults to ``CLASSES_FOLDER``.
        images_path: Destination directory. Defaults to ``IMGAGES_PATH``.
    """
    if classes_folder is None:
        classes_folder = CLASSES_FOLDER
    if images_path is None:
        images_path = IMGAGES_PATH

    # Without an existing destination dir, shutil.move would rename the first
    # file *to* images_path, and later moves could overwrite it.
    os.makedirs(images_path, exist_ok=True)

    # The pattern is a fixed two levels deep; `recursive=True` would only
    # matter with a `**` component, so it is not used.
    class_files = sorted(
        path for path in glob.glob(os.path.join(classes_folder, '*', '*'))
        if os.path.isfile(path)
    )
    for img_path in class_files:
        target = os.path.join(images_path, os.path.basename(img_path))
        if os.path.exists(target):
            print(f'File: {img_path} skipped, {target} already exists')
            continue
        shutil.move(img_path, images_path)
        print(f'File: {img_path} moved to {images_path}')

# Destructive (moves files on disk), so left commented out; uncomment to reset.
# restart_classes()

File: classes\a-diagram\CLIP.png moved to data
File: classes\ball\czlowiek-gra-w-pilke-nozna_1368-2994.jpg moved to data
File: classes\men\przystojny-mezczyzna-otwiera-okno-w-domu-aby-odswiezyc-pokoj_264277-1194.jpg moved to data
