In [1]:
import torch
from PIL import Image, ImageOps
from tkinter import filedialog
import open_clip
import os

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print(torch.cuda.get_device_name()) 
# open_clip.list_pretrained()

model, _, preprocess = open_clip.create_model_and_transforms('convnext_base', pretrained='laion400m_s13b_b51k')
model.eval() 
tokenizer = open_clip.get_tokenizer('ViT-B-32');

Using device: cuda
NVIDIA GeForce RTX 4050 Laptop GPU


In [5]:
holiday_labels = [];

final_label_file_path = '../label_data/events/holidays_list.txt'


with open(final_label_file_path, 'r') as file:
    for line in file:
        label = line.strip()  
        holiday_labels.append(label)


holiday_description = tokenizer([
     f"This image shows events, activities, and things typically seen during {holiday}. Can you identify what holiday it represents based on these elements?"
    for holiday in holiday_labels
])

def save_text_features(text_features, file_path):
    torch.save(text_features, file_path)

text_features_file = './encode/text_features_event.pt'

if not os.path.exists(text_features_file):
    with torch.no_grad():
        text_features = model.encode_text(holiday_description)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        save_text_features(text_features, text_features_file)
else:
    text_features = torch.load(text_features_file)

  text_features = torch.load(text_features_file)


In [6]:
import time

image_path = filedialog.askopenfilename(title="Choose a image", filetypes=[("Image files", "*.png *.jpg *.jpeg")])

image = preprocess(Image.open(image_path)).unsqueeze(0)

start_time = time.time()

with torch.no_grad(), torch.amp.autocast('cuda'):
    image_features = model.encode_image(image)
    image_features /= image_features.norm(dim=-1, keepdim=True)

    holiday_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

sorted_probs_and_labels = sorted(zip(holiday_labels, holiday_probs[0]), key=lambda x: x[1], reverse=True)


end_time = time.time()
elapsed_time = end_time - start_time

print(f"Running time: {elapsed_time:.4f} second")
print("-----------------------------------------")

for holiday, prob in sorted_probs_and_labels[:10]:
    print(f"{holiday}: {100*prob.item():.4f}%")


Running time: 0.1314 second
-----------------------------------------
Hung Kings commemoration day - Vietnam: 99.9759%
Lunar New Year - Vietnam: 0.0216%
Chuseok - South Korea: 0.0009%
Tomb-Sweeping Day - China: 0.0008%
Dragon Boat Festival - China: 0.0006%
Korean New Year: 0.0001%
Mid-Autumn Festival: 0.0000%
Culture Day - Japan: 0.0000%
Showa Day - Japan: 0.0000%
Children's Day: 0.0000%
