## A more accurate Human to Anime Feature Matcher

In [33]:
import torch
import timm
from PIL import Image
import numpy as np
from torchvision import transforms
from dotenv import load_dotenv
import os
import clip
import glob

# Load environment variables from .env file
load_dotenv()

True

In [21]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
model = timm.create_model("vit_base_patch14_dinov2.lvd142m", pretrained=True)
model = model.eval().to(device)
clip_model, preprocess_clip = clip.load("ViT-B/32", device=device)

In [10]:
transform = transforms.Compose([
    transforms.Resize((518, 518)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

In [25]:
def get_dino_embedding(img_path):
    img = Image.open(img_path).convert("RGB")
    x = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        emb = model.forward_features(x)  # feature extraction
    return emb.cpu().numpy().flatten()

def get_clip_embedding(img_path):
    img = Image.open(img_path).convert("RGB")
    img_pre = preprocess_clip(img).unsqueeze(0).to(device)
    with torch.no_grad():
        return clip_model.encode_image(img_pre).cpu().numpy().flatten()

In [34]:
# Get all PNG and JPG files from GenshinCharacters directory
avatar_files = glob.glob("../GenshinCharacters/*.png") + glob.glob("../GenshinCharacters/*.jpg")
dino_embeddings = [get_dino_embedding(img) for img in avatar_files]
clip_embeddings = [get_clip_embedding(img) for img in avatar_files]


In [35]:
test_path = os.environ["TEST_IMG_PATH"]
query_dino_emb = get_dino_embedding(test_path)
query_clip_emb = get_clip_embedding(test_path)

In [36]:
def combined_similarity(q_dino, q_clip, a_dino, a_clip, alpha=0.67):
    # normalize
    q_dino /= np.linalg.norm(q_dino)
    q_clip /= np.linalg.norm(q_clip)
    a_dino /= np.linalg.norm(a_dino)
    a_clip /= np.linalg.norm(a_clip)
    
    sim_dino = np.dot(q_dino, a_dino)
    sim_clip = np.dot(q_clip, a_clip)
    return alpha*sim_clip + (1-alpha)*sim_dino

In [37]:
similarities = [combined_similarity(query_dino_emb, query_clip_emb, emb[0], emb[1]) for emb in zip(dino_embeddings, clip_embeddings)]

In [38]:
best_idx = int(np.argmax(similarities))
print("Best match:", avatar_files[best_idx], "\nScore:", similarities[best_idx])

Best match: ../GenshinCharacters/Fischl.png 
Score: 0.43704224
