In [None]:
from PIL import Image
import os
import numpy as np
import torch
from tqdm import tqdm, trange

In [None]:
from patched_clip.patched_clip import CLIP_args, get_clip_embeddings, load_clip

In [None]:
# Load the patched CLIP model together with its image preprocessing transform.
# NOTE(review): load_clip() takes no arguments here, so it presumably reads its
# configuration (checkpoint, device, ...) from CLIP_args — confirm in patched_clip.
model, preprocess = load_clip()
# Device recorded in the shared CLIP_args config.
# NOTE(review): presumably the device the model was loaded onto — confirm.
device = CLIP_args.device

In [None]:
def is_valid_image(filename):
    """Check whether `filename` points to an image file this notebook can process.

    A path qualifies when its name ends in .png, .jpg or .jpeg (compared
    case-insensitively) AND it refers to an existing regular file
    (`os.path.isfile`), so directories and missing paths are rejected
    even if their names carry an image extension.
    """
    has_image_extension = filename.lower().endswith(('.png', '.jpg', '.jpeg'))
    return has_image_extension and os.path.isfile(filename)

In [None]:
# Input image directory and the directory that receives one CLIP feature file per image.
# NOTE(review): absolute, machine-specific paths — consider moving them to a config cell
# or deriving them from a configurable data root.
image_dir = '/home/roger/gaussian_feature/feat_data/bulldozer_sample/images'
target_feat_dir = '/home/roger/gaussian_feature/feat_data/bulldozer_sample/clip_features'

# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so that the
# processing order (and the image_paths <-> output_paths pairing) is deterministic.
image_paths = sorted(
    path
    for path in (os.path.join(image_dir, fn) for fn in os.listdir(image_dir))
    if is_valid_image(path)
)

os.makedirs(target_feat_dir, exist_ok=True)

# One .npy per image, named after the image's stem (e.g. frame_001.png -> frame_001.npy),
# index-aligned with image_paths.
output_paths = [
    os.path.join(target_feat_dir, os.path.splitext(os.path.basename(path))[0] + '.npy')
    for path in image_paths
]

# Images sharing a stem but differing in extension (a.png / a.jpg) would map to the same
# .npy and silently overwrite each other's features; fail loudly instead.
_feat_owner = {}
for _img_path, _feat_path in zip(image_paths, output_paths):
    if _feat_path in _feat_owner:
        raise ValueError(
            f'{_img_path} and {_feat_owner[_feat_path]} would both be written to {_feat_path}'
        )
    _feat_owner[_feat_path] = _img_path


In [None]:
# Passed to get_clip_embeddings as skip_center_crop.
# NOTE(review): presumably True keeps the full image extent (no center crop) for a
# higher-resolution feature map — confirm in patched_clip.
high_res_feature = True

# Get CLIP embeddings for every image and save each one to its paired .npy path.
# no_grad is entered once for the whole loop: this is inference only.
with torch.no_grad():
    for image_path, output_path in tqdm(zip(image_paths, output_paths), total=len(image_paths)):
        # The context manager closes the underlying file once the embedding is computed,
        # instead of leaving one open handle per image for the garbage collector.
        with Image.open(image_path) as image_pil:
            descriptors = get_clip_embeddings([image_pil],
                                              to_cpu=False,
                                              model=model,
                                              preprocess=preprocess,
                                              skip_center_crop=high_res_feature)
        # Output is BCHW with B == 1. Move straight to the CPU (no detour through `device`)
        # and drop only the batch axis, so a size-1 channel/spatial dim is never squeezed away.
        descriptors = descriptors.cpu().squeeze(0).numpy()
        np.save(output_path, descriptors)