### Notebook description:

This notebook:
1. reads the PNG images in `data/drone_imgs` and embeds each one with a CLIP image encoder (ViT-L/14);
1. saves every embedding, paired with its image filename, to a pickle file.

The result is a collection of image embeddings that can be used to find clusters or for image retrieval.

In [1]:
from pathlib import Path
def _path_ls(self):
    """Return this directory's entries as a list of Paths, sorted by path.

    `Path.iterdir` yields entries in arbitrary, filesystem-dependent order;
    sorting makes listings (and anything built from them) reproducible.
    """
    return sorted(self.iterdir())


Path.ls = _path_ls  # convenience helper, e.g. DIR_IMGS.ls()

from tqdm import tqdm
import shutil

In [2]:
# Input locations, relative to the notebook's working directory.
DATA_DIR = Path('data')
DIR_IMGS = DATA_DIR.joinpath('drone_imgs')  # folder with the images to embed

# Get image embeddings

In [3]:
# Uncomment to install CLIP if it is missing (%pip installs into this kernel's environment):
# %pip install -q git+https://github.com/openai/CLIP.git

In [4]:
import torch
import clip
from PIL import Image

In [5]:
from dataclasses import dataclass

@dataclass
class ImageEmbedding:
    """A CLIP image embedding paired with the name of the image it came from.

    Attributes:
        filename: Name of the source image file as a plain string
            (this notebook passes ``img_file.name``, not a full Path).
        emb: Image features produced by the CLIP image encoder for that image.
    """
    # Fields are declared in constructor order so the dataclass-generated
    # __init__ is ImageEmbedding(filename, emb) -- the same positional and
    # keyword signature as the previous hand-written __init__.
    filename: str
    emb: torch.Tensor  # expected (B, 768) for ViT-L/14; B is presumably 1 here -- TODO confirm

In [6]:
# List the CLIP checkpoints the installed `clip` package can load;
# the model-loading cell below uses "ViT-L/14".
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [7]:
# Single place to choose the CLIP checkpoint; see clip.available_models().
CLIP_MODEL_NAME = "ViT-L/14"
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load(CLIP_MODEL_NAME, device=device)
# Inference only. The trailing ';' stops Jupyter from displaying the
# (very long) module repr that model.eval() returns.
model.eval();

1

In [8]:
# Embed every PNG in DIR_IMGS, one image at a time.
# Sorting fixes the order of `processed_images` (directory listings are
# otherwise in arbitrary, filesystem-dependent order), so reruns produce the
# same collection in the same order. The suffix check is case-insensitive
# so files ending in `.PNG` are not silently skipped.
png_files = sorted(p for p in DIR_IMGS.ls() if p.suffix.lower() == '.png')
assert png_files, f"no .png images found in {DIR_IMGS}"

processed_images = []
for img_file in tqdm(png_files):
    # `with` closes the image's file handle as soon as preprocessing is done.
    with Image.open(img_file) as img:
        image = preprocess(img).unsqueeze(0).to(device)
    with torch.inference_mode():
        image_features = model.encode_image(image)

    processed_images.append(ImageEmbedding(img_file.name, image_features.cpu()))

100%|███████████████████████████████████████████████████████████████████████████████████| 8201/8201 [03:16<00:00, 41.68it/s]


In [10]:
# Sanity check: the embedding loop appends one entry per processed PNG image.
len(processed_images)

8201

# Save processed images to pickle

In [9]:
import pickle

def write_pickle(filename, obj):
    """Serialize ``obj`` into the file at ``filename`` with pickle.

    The file is created or truncated, and the default pickle protocol is used.
    """
    with open(filename, mode='wb') as sink:
        pickle.dump(obj, sink)

def read_pickle(filename):
    """Load and return the object stored in the pickle file ``filename``.

    Security: unpickling can execute arbitrary code, so only call this on
    files you created yourself or otherwise trust.
    """
    with open(filename, mode='rb') as source:
        restored = pickle.load(source)
    return restored

In [11]:
# Persist the embedding collection next to the input data (derived from
# DATA_DIR rather than a hard-coded path); reload it with read_pickle(...).
write_pickle(DATA_DIR / 'processed_images.pickle', processed_images)