In [None]:
# Install OpenCLIP into the kernel's own environment (%pip, not !pip).
# NOTE(review): unpinned — consider pinning a version for reproducible results.
%pip install open_clip_torch

In [None]:
import torch
from PIL import Image
from tkinter import filedialog
import open_clip
import os

# Exploratory: display the architecture / pretrained-weights combinations OpenCLIP
# knows about, to help choose the names passed to create_model_and_transforms below.
open_clip.list_pretrained()

In [None]:
# OpenCLIP architecture and pretrained-weights tag. The tokenizer is derived from
# the same MODEL_NAME so it always matches the loaded model's text tower, instead
# of being requested for a different architecture by a separate hard-coded name.
MODEL_NAME = 'convnext_base'
PRETRAINED_TAG = 'laion400m_s13b_b51k'

# The second return value (the training-time transform) is not used here;
# `preprocess` is the transform applied to images at inference time below.
model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED_TAG)
model.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

In [22]:
# Candidate location labels, one per line of the label file. Each label is used
# verbatim as its text prompt. Blank lines are skipped so they do not turn into
# empty prompts that compete with real labels in the softmax.
# If this list changes, any cached text features computed from an older list
# must be regenerated.
final_label_file_path = '../label_data/components/place365_and_sun397_with_wordnet.txt'

# Explicit encoding: the platform default (e.g. cp1252 on Windows) is not portable.
with open(final_label_file_path, 'r', encoding='utf-8') as file:
    location_labels = [label for label in (line.strip() for line in file) if label]

if not location_labels:
    raise ValueError(f"No labels found in {final_label_file_path!r}")

# Token tensor with one row per label, consumed by model.encode_text later.
location_description = tokenizer(location_labels)

In [23]:
def save_text_features(text_features, file_path):
    """Persist text embeddings to ``file_path`` using ``torch.save``.

    The parent directory is created when missing, so the first run on a fresh
    checkout (where e.g. ``./encode/`` does not exist yet) does not fail with
    ``FileNotFoundError``.

    Args:
        text_features: Tensor (or any object ``torch.save`` accepts) to store.
        file_path: Destination path, e.g. ``./encode/text_features.pt``.
    """
    parent_dir = os.path.dirname(file_path)
    if parent_dir:  # empty for a bare filename in the current directory
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(text_features, file_path)

# Cache of L2-normalised text embeddings for `location_labels`.
# NOTE(review): the cache is keyed only on this path. It is discarded when its
# row count no longer matches the label list, but switching model or rewording
# labels (same count) still requires deleting the file manually.
text_features_file = './encode/text_features.pt'

text_features = None
if os.path.exists(text_features_file):
    # map_location places the tensor on the model's device even if the cache was
    # written from a session that ran on a different device (e.g. GPU).
    cached_features = torch.load(text_features_file, map_location=next(model.parameters()).device)
    # One embedding row per label; a mismatch means the cache is stale.
    if cached_features.shape[0] == len(location_labels):
        text_features = cached_features

if text_features is None:
    with torch.no_grad():
        text_features = model.encode_text(location_description)
        # Normalise so the image-text dot product below is a cosine similarity.
        text_features /= text_features.norm(dim=-1, keepdim=True)
    save_text_features(text_features, text_features_file)


In [30]:
# Image to classify. Set IMAGE_PATH to a file path to run the notebook without a
# GUI (reproducible); leave it as None to pick a file through a dialog.
IMAGE_PATH = None

if IMAGE_PATH:
    image_path = IMAGE_PATH
else:
    image_path = filedialog.askopenfilename(title="Choose a image", filetypes=[("Image files", "*.png *.jpg *.jpeg")])

# A cancelled dialog yields an empty value; fail clearly instead of inside Image.open.
if not image_path:
    raise ValueError("No image selected.")

# `with` releases the file handle once the transform has read the pixels.
# unsqueeze(0) adds a leading batch dimension of size 1.
with Image.open(image_path) as pil_image:
    image = preprocess(pil_image).unsqueeze(0)


In [33]:
# NOTE: optimise image-generation time (0.5) -> DONE
# NOTE: optimise text encoding
# TODO: optimise the label set
# TODO: optimise image encoding

import time

# How many of the highest-probability labels to print.
TOP_K = 10

# Run on whatever device the model lives on (CPU unless it was moved elsewhere).
device = next(model.parameters()).device

# perf_counter is monotonic and high-resolution, unlike time.time().
start_time = time.perf_counter()

# Mixed precision is enabled only on CUDA. The former torch.cuda.amp.autocast()
# is deprecated and had no effect on a CPU-resident model.
with torch.no_grad(), torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
    image_features = model.encode_image(image.to(device))
    image_features /= image_features.norm(dim=-1, keepdim=True)

    # Both sides are L2-normalised, so this is cosine similarity scaled by a
    # fixed factor of 100, softmaxed over all labels.
    location_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# A stale text-feature cache would otherwise silently pair labels with wrong scores.
assert location_probs.shape[-1] == len(location_labels), (
    f"{location_probs.shape[-1]} text features vs {len(location_labels)} labels; "
    "regenerate the cached text features"
)

# Sort once on the tensor instead of comparing 0-d tensors in a Python sort.
sorted_probs, sorted_indices = location_probs[0].sort(descending=True)
sorted_probs_and_labels = [
    (location_labels[index], prob)
    for index, prob in zip(sorted_indices.tolist(), sorted_probs)
]

end_time = time.perf_counter()
elapsed_time = end_time - start_time

print(f"Running time: {elapsed_time:.4f} second")
print("-----------------------------------------")

for location, prob in sorted_probs_and_labels[:TOP_K]:
    print(f"{location}: {100*prob.item():.4f}%")



Running time: 0.1275 second
-----------------------------------------
downtown: 31.9988%
river: 28.4951%
valley: 18.1065%
dam: 6.1124%
archive: 5.0314%
sky: 3.1605%
outdoor oil refinery: 2.1395%
street: 1.1304%
bayou: 0.6068%
park: 0.5022%
