In [1]:
# image_search_test.py
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from pathlib import Path
import numpy as np
import faiss
import os

# ==== CONFIG ====
DATA_ROOT = Path("Scraping_part\goat_data")  # each subfolder = class, contains slugs/images
MODEL_PATH = "best_resnet50_sneakers.pt"

# ==== MODEL ====
NUM_CLASSES = 50
base = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
base.fc = nn.Sequential(
    nn.Linear(base.fc.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, NUM_CLASSES)
)
state = torch.load(MODEL_PATH, map_location="cpu")
base.load_state_dict(state, strict=True)
feature_extractor = nn.Sequential(*list(base.children())[:-1]).eval()  # 2048-d after avgpool

device = "cuda" if torch.cuda.is_available() else "cpu"
feature_extractor.to(device)

preprocess = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

def embed_image(p: Path) -> np.ndarray:
    img = Image.open(p).convert("RGB")
    x = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        v = feature_extractor(x).squeeze().cpu().numpy()  # [2048]
    v = v / (np.linalg.norm(v) + 1e-8)
    return v.astype("float32")


  state = torch.load(MODEL_PATH, map_location="cpu")


In [6]:
from PIL import Image
import torch
import numpy as np
from transformers import CLIPModel, CLIPProcessor
from pathlib import Path

# ==== LOAD MODEL ====
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", use_safetensors=True).to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)

# ==== EMBEDDING FUNCTION ====
def embed_image(path: str):
    img = Image.open(path).convert("RGB")
    inputs = processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        emb = model.get_image_features(**inputs)
    vec = emb[0].cpu().numpy()
    return vec / np.linalg.norm(vec)

# ==== TEST ====
img1 = Path("s-l1200.jpg")
img2 = Path(r"Scraping_part\goat_data\adidas_forum_high\cuts-slices-x-forum-84-high-pull-up-beloved-fz6567\1045923_01.jpg.jpeg")


v1 = embed_image(img1)
v2 = embed_image(img2)

similarity = np.dot(v1, v2)  # cosine similarity since both normalized
print(f"Similarity: {similarity:.3f}")


Similarity: 0.769


In [2]:

# --- build: one index per class ---
class2index = {}   # class -> faiss index
class2paths = {}   # class -> [image paths]

for class_dir in sorted(p for p in DATA_ROOT.iterdir() if p.is_dir()):
    cls = class_dir.name
    vecs, paths = [], []

    for slug_dir in (d for d in class_dir.iterdir() if d.is_dir()):
        for img_path in slug_dir.glob("*.[jp][pn]g"):
            vecs.append(embed_image(img_path))
            paths.append(str(img_path))

    if not vecs:
        continue
    V = np.stack(vecs)
    faiss.normalize_L2(V)
    idx = faiss.IndexFlatIP(V.shape[1])   # cosine via dot product
    idx.add(V)

    class2index[cls] = idx
    class2paths[cls]  = paths
    print(f"{cls}: indexed {len(paths)} images.")

print(f"Built {len(class2index)} class indexes.")


adidas_forum_high: indexed 30 images.
adidas_forum_low: indexed 20 images.
adidas_gazelle: indexed 56 images.
adidas_nmd_r1: indexed 38 images.
adidas_samba: indexed 46 images.
adidas_stan_smith: indexed 79 images.
adidas_superstar: indexed 88 images.
adidas_ultraboost: indexed 59 images.
asics_gel-lyte_iii: indexed 85 images.
converse_chuck_70_high: indexed 77 images.
converse_chuck_70_low: indexed 70 images.
converse_chuck_taylor_all-star_high: indexed 83 images.
converse_chuck_taylor_all-star_low: indexed 102 images.
converse_one_star: indexed 76 images.
new_balance_327: indexed 48 images.
new_balance_550: indexed 31 images.
new_balance_574: indexed 61 images.
new_balance_992: indexed 18 images.
nike_air_force_1_high: indexed 76 images.
nike_air_force_1_low: indexed 44 images.
nike_air_force_1_mid: indexed 85 images.
nike_air_jordan_11: indexed 56 images.
nike_air_jordan_1_high: indexed 47 images.
nike_air_jordan_1_low: indexed 44 images.
nike_air_jordan_3: indexed 37 images.
nike_a

In [3]:
def search_in_class(cls: str, query_img: str, k=5):
    if cls not in class2index:
        raise ValueError(f"class '{cls}' not found")
    q = embed_image(Path(query_img)).reshape(1,-1)
    faiss.normalize_L2(q)
    D, I = class2index[cls].search(q, k)
    out = [(class2paths[cls][int(i)], float(s)) for i, s in zip(I[0], D[0])]
    return out

# --- example usage ---
if __name__ == "__main__":
    # choose the class you want to restrict to
    target_class = "adidas_forum_high"
    query_path = "s-l1200.jpg"  # your test image

    results = search_in_class(target_class, query_path, k=5)
    for r, (p, score) in enumerate(results, 1):
        print(f"{r}. {p}  score={score:.3f}")


1. Scraping_part\goat_data\adidas_forum_high\forum-84-high-off-white-team-dark-green-gw4328\GW4328.png.png  score=0.765
2. Scraping_part\goat_data\adidas_forum_high\girls-are-awesome-x-wmns-forum-high-cloud-white-purple-gy2632\GY2632.png.png  score=0.762
3. Scraping_part\goat_data\adidas_forum_high\packer-shoes-x-forum-84-high-college-pack-collegiate-maroon-gx1520\GX1520.png.png  score=0.758
4. Scraping_part\goat_data\adidas_forum_high\snipes-x-forum-high-313-day-gz8376\GZ8376.png.png  score=0.749
5. Scraping_part\goat_data\adidas_forum_high\packer-shoes-x-forum-84-high-college-pack-collegiate-green-gx1519\GX1519.png.png  score=0.746


In [3]:
# image_search_class_scoped.py
from pathlib import Path
from PIL import Image
import numpy as np
import torch
import faiss
from transformers import CLIPModel, CLIPProcessor

# ==== CONFIG ====
DATA_ROOT = Path("Scraping_part/goat_data")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ==== MODEL ====
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", use_safetensors=True).to(DEVICE)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)

# ==== FUNCTIONS ====
def embed_image(path: Path):
    img = Image.open(path).convert("RGB")
    inputs = processor(images=img, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        emb = model.get_image_features(**inputs)
    v = emb[0].cpu().numpy()
    return v / np.linalg.norm(v)

def build_class_index(class_dir: Path):
    """Embed every image under goat_data/<class>/<slug> folders and build a FAISS index."""
    vectors, paths = [], []

    for slug_dir in class_dir.iterdir():
        if not slug_dir.is_dir():
            continue
        for ext in ("*.jpg", "*.jpeg", "*.png"):
            for img_path in slug_dir.glob(ext):
                try:
                    v = embed_image(img_path)
                    vectors.append(v.astype("float32"))
                    paths.append(str(img_path))
                except Exception as e:
                    print(f"Skip {img_path}: {e}")

    if not vectors:
        raise RuntimeError(f"No images found under {class_dir.resolve()}")

    arr = np.stack(vectors)
    faiss.normalize_L2(arr)
    index = faiss.IndexFlatIP(arr.shape[1])
    index.add(arr)
    print(f"Indexed {len(paths)} images for {class_dir.name}")
    return index, paths


def search_in_class(query_img, class_name, top_k=5):
    """Search only within one clean_class folder."""
    class_dir = DATA_ROOT / class_name
    index, paths = build_class_index(class_dir)
    qvec = embed_image(Path(query_img)).astype("float32").reshape(1, -1)
    faiss.normalize_L2(qvec)
    sims, idxs = index.search(qvec, top_k)
    print(f"\nTop {top_k} matches in {class_name}:")
    for rank, (i, score) in enumerate(zip(idxs[0], sims[0]), start=1):
        print(f"{rank:>2}. {paths[i]} (similarity {score:.3f})")




In [6]:
# ==== RUN EXAMPLE ====
query = "testing images/images (11).jpeg"  # any sneaker photo
search_in_class(query, "converse_one_star", top_k=5)

Indexed 270 images for converse_one_star

Top 5 matches in converse_one_star:
 1. Scraping_part\goat_data\converse_one_star\one-star-low-triple-black-162950c\162950C.png.png (similarity 0.765)
 2. Scraping_part\goat_data\converse_one_star\awake-ny-x-one-star-pro-low-black-a07143c\A07143C.png.png (similarity 0.763)
 3. Scraping_part\goat_data\converse_one_star\undefeated-x-one-star-academy-pro-black-a12131c\1505868_03.jpg.jpeg (similarity 0.755)
 4. Scraping_part\goat_data\converse_one_star\golf-le-fleur-x-one-star-suede-mono-black-162129c\356438_08.jpg.jpeg (similarity 0.754)
 5. Scraping_part\goat_data\converse_one_star\hello-kitty-x-one-star-low-top-velcro-td-black-763908c\763908C.png.png (similarity 0.754)
