In [1]:
import torch
import torch.nn as nn
from torchvision import models, transforms
import json
from pathlib import Path
from PIL import Image, ImageEnhance
import numpy as np
import faiss
from transformers import CLIPModel, CLIPProcessor
import pickle
from typing import Dict, Tuple, List
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:


# ===== CONFIG =====
MODEL_PATH = "best_resnet50_sneakers.pt"
CLASS_INDICES_PATH = "class_indices.json"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 50
CONFIDENCE_THRESHOLD = 0.20

# ===== CLASS INDEX MAP =====
# class_indices.json maps class name -> classifier output index, e.g.
#   {"adidas_forum_high": 0, "nike_air_force_1": 1, ...}
# It is inverted so a predicted index can be turned back into a class name.
if not Path(CLASS_INDICES_PATH).exists():
    raise FileNotFoundError("class_indices.json not found.")

with open(CLASS_INDICES_PATH, "r") as f:
    raw_map = json.load(f)
class_map = {v: k for k, v in raw_map.items()}

# Every output unit of the classifier head must map to exactly one class name;
# otherwise predict_sneaker_class would hit a KeyError (missing index) or
# silently mislabel (duplicate indices collapse in the inversion above).
if sorted(class_map) != list(range(NUM_CLASSES)):
    raise ValueError(
        f"{CLASS_INDICES_PATH} must map exactly {NUM_CLASSES} classes to the "
        f"integer indices 0..{NUM_CLASSES - 1}; got {len(class_map)} distinct indices."
    )

# ===== CLASSIFIER =====
# ResNet50 backbone with the custom 2-layer head; the architecture must match
# the checkpoint exactly because the state dict is loaded with strict=True.
classifier_model = models.resnet50(weights=None)
classifier_model.fc = nn.Sequential(
    nn.Linear(classifier_model.fc.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, NUM_CLASSES),
)

# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, and avoids executing arbitrary pickled code
# (it also silences torch's FutureWarning about the old default).
state = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
classifier_model.load_state_dict(state, strict=True)
classifier_model.eval().to(DEVICE)

print("ResNet50 model loaded.")

# Inference preprocessing: fixed 224x224 resize + ImageNet mean/std
# normalization (presumably matching the training-time transform — TODO confirm).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])


def predict_sneaker_class(img_path, threshold=CONFIDENCE_THRESHOLD):
    """Classify a sneaker image with the ResNet50 classifier.

    Prints the predicted class, its softmax confidence, and whether that
    confidence clears ``threshold``.

    Args:
        img_path: Path (str or Path) of the image to classify.
        threshold: Minimum top-1 probability for the image to be reported as
            a known sneaker. Defaults to CONFIDENCE_THRESHOLD.

    Returns:
        ``(class_name, confidence)`` for the top-1 class, where confidence is
        the softmax probability as a Python float. It is returned even when
        it falls below ``threshold``; the caller decides what to do with it.
    """
    # The context manager closes the file handle; convert() loads the pixel
    # data into a new image before the file is closed.
    with Image.open(img_path) as raw:
        img = raw.convert("RGB")
    x = preprocess(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = classifier_model(x)
        probs = torch.softmax(logits, dim=1)[0]

    top_idx = int(torch.argmax(probs))
    conf = float(probs[top_idx])
    class_name = class_map[top_idx]

    print(f"Predicted class: {class_name}")
    print(f"Confidence: {conf:.3f}")

    if conf < threshold:
        print("→ This image is probably NOT a known sneaker (below threshold).")
    else:
        print("→ This image is a valid sneaker match.")

    return class_name, conf

# Smoke test: classify a local sample image. The returned
# (class_name, confidence) tuple is displayed as the cell's output.
predict_sneaker_class("testing images/jordan1 low.jpeg")




  state = torch.load(MODEL_PATH, map_location=DEVICE)


ResNet50 model loaded.
Predicted class: nike_air_jordan_1_low
Confidence: 0.976
→ This image is a valid sneaker match.


('nike_air_jordan_1_low', 0.9757837653160095)

In [3]:
# image_search_class_scoped_v2.py


# ==== CONFIG ====
DATA_ROOT = Path("Scraping_part/goat_data")
INDEX_CACHE_DIR = Path("faiss_cache")
INDEX_CACHE_DIR.mkdir(parents=True, exist_ok=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the module-level RNG that augment_image draws from, so query-time
# augmentation (and therefore the similarity scores) is reproducible on a
# fresh Restart & Run All.
SEED = 42
random.seed(SEED)

# ==== MODEL ====
# NOTE: `model` and `processor` are read as globals by embed_image,
# build_class_index and search_in_class — do not rebind these names in later cells.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", use_safetensors=True).to(DEVICE)
model.eval()  # inference only (explicit, although from_pretrained already does this)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)

# In-memory cache: class name -> (FAISS index, source image path of each stored vector)
_index_cache: Dict[str, Tuple[faiss.Index, List[str]]] = {}

# ==== AUGMENTATION ====
def augment_image(img: Image.Image, strength='light', rng=None):
    """Return the original image followed by randomly perturbed copies.

    Args:
        img: Source PIL image. The white ``fillcolor`` used for rotation
            assumes an RGB image (callers in this notebook pass
            ``.convert("RGB")`` output).
        strength:
            'light' adds brightness and contrast jitter (3 images total).
            'medium' additionally adds saturation jitter, a small rotation and
            a random zoom-crop (6 total).
            'heavy' additionally adds a horizontal flip and sharpness jitter
            (8 total).
            Any other value returns only the original image.
        rng: Optional ``random.Random`` instance for reproducible augmentation.
            Defaults to the module-level ``random`` generator, which is the
            previous behavior. Random draws happen in the same order either way.

    Returns:
        List of PIL images; element 0 is always ``img`` itself.
    """
    if rng is None:
        rng = random

    augmented = [img]

    if strength in ('light', 'medium', 'heavy'):
        augmented.append(ImageEnhance.Brightness(img).enhance(rng.uniform(0.85, 1.15)))
        augmented.append(ImageEnhance.Contrast(img).enhance(rng.uniform(0.9, 1.1)))

    if strength in ('medium', 'heavy'):
        augmented.append(ImageEnhance.Color(img).enhance(rng.uniform(0.9, 1.1)))
        augmented.append(img.rotate(rng.uniform(-10, 10), fillcolor=(255, 255, 255)))

        # Random square zoom-crop scaled back to the original size (simulates a
        # different camera distance). max(1, ...) keeps tiny images from
        # producing an empty crop that resize() would reject.
        w, h = img.size
        crop_size = max(1, int(min(w, h) * rng.uniform(0.85, 0.95)))
        left = rng.randint(0, w - crop_size)
        top = rng.randint(0, h - crop_size)
        cropped = img.crop((left, top, left + crop_size, top + crop_size))
        augmented.append(cropped.resize((w, h), Image.LANCZOS))

    if strength == 'heavy':
        # Horizontal flip: only appropriate if mirrored shoes should still match.
        augmented.append(img.transpose(Image.FLIP_LEFT_RIGHT))
        augmented.append(ImageEnhance.Sharpness(img).enhance(rng.uniform(0.8, 1.2)))

    return augmented

def embed_image(path: Path, augment=False, aug_strength='light'):
    """Compute a unit-length CLIP image embedding for the file at ``path``.

    Args:
        path: Image file path; the image is converted to RGB.
        augment: If True, embed the original plus the copies produced by
            ``augment_image`` and average them (test-time augmentation).
        aug_strength: Strength passed to ``augment_image`` when ``augment``.

    Returns:
        1-D float32 numpy array with unit L2 norm. In the augmented case the
        mean of the per-view unit vectors is re-normalized, so both branches
        honor the same contract.
    """
    # The context manager closes the file handle; convert() loads pixel data first.
    with Image.open(path) as raw:
        img = raw.convert("RGB")

    views = augment_image(img, strength=aug_strength) if augment else [img]

    # One batched forward pass over all views instead of one call per view.
    inputs = processor(images=views, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        emb = model.get_image_features(**inputs)
    vecs = emb.cpu().numpy()
    vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)

    # For a single view the mean is that view itself.
    v = vecs.mean(axis=0)
    return v / np.linalg.norm(v)

def build_class_index(class_dir: Path, augment_index=False, aug_per_image=5):
    """
    Build an exact-cosine FAISS index over every image of one class.

    Images are read from the immediate sub-directories of ``class_dir`` (one
    per sneaker slug); files placed directly in ``class_dir`` are ignored.
    Files are matched by a case-insensitive .jpg/.jpeg/.png suffix. Directories
    and files are visited in sorted order so the index layout is reproducible.

    Args:
        class_dir: Path to class directory
        augment_index: If True, store several embeddings per image: the
            original followed by 'medium' augmentations from ``augment_image``.
        aug_per_image: Maximum number of embeddings per original image when
            ``augment_index`` is True (at most 6 are available for 'medium').

    Returns:
        ``(index, paths)``, where ``index`` is a ``faiss.IndexFlatIP`` over
        L2-normalized vectors and ``paths[i]`` is the source image of vector
        ``i``. The same path repeats for augmented copies.

    Raises:
        RuntimeError: if no image under ``class_dir`` could be embedded.
    """
    image_suffixes = {".jpg", ".jpeg", ".png"}

    def _embed_views(views):
        # One batched CLIP forward pass; returns L2-normalized float32 rows.
        inputs = processor(images=views, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            emb = model.get_image_features(**inputs)
        arr = emb.cpu().numpy().astype("float32")
        return arr / np.linalg.norm(arr, axis=1, keepdims=True)

    vectors, paths = [], []

    for slug_dir in sorted(p for p in class_dir.iterdir() if p.is_dir()):
        # Case-insensitive suffix match so files like "IMG_01.JPG" are not
        # silently skipped on case-sensitive filesystems.
        image_files = sorted(
            p for p in slug_dir.iterdir()
            if p.is_file() and p.suffix.lower() in image_suffixes
        )
        for img_path in image_files:
            try:
                if augment_index:
                    with Image.open(img_path) as raw:
                        img = raw.convert("RGB")
                    views = augment_image(img, strength='medium')[:aug_per_image]
                    rows = _embed_views(views) if views else []
                else:
                    rows = [embed_image(img_path)]
            except Exception as e:
                # Best effort: one unreadable/corrupt file must not abort the build.
                print(f"Skip {img_path}: {e}")
                continue

            # Append only after the whole image succeeded, so a failure cannot
            # leave a partial set of augmented vectors in the index.
            for v in rows:
                vectors.append(np.asarray(v, dtype="float32"))
                paths.append(str(img_path))

    if not vectors:
        raise RuntimeError(f"No images found under {class_dir.resolve()}")

    arr = np.stack(vectors)
    faiss.normalize_L2(arr)  # in-place; inner product == cosine similarity below

    # IndexFlatIP performs exact (brute-force) inner-product search.
    index = faiss.IndexFlatIP(arr.shape[1])
    index.add(arr)

    unique_images = len(set(paths))
    print(f"Indexed {len(paths)} embeddings ({unique_images} unique images) for {class_dir.name}")
    return index, paths

def get_or_build_index(class_name: str, rebuild=False, augment_index=False):
    """Return ``(index, paths)`` for ``class_name``, building it if needed.

    Lookup order:
      1. The in-memory ``_index_cache``.
      2. The on-disk pickle ``INDEX_CACHE_DIR/<class_name>.pkl``.
      3. A fresh ``build_class_index`` over ``DATA_ROOT/<class_name>``; the
         result is then written to both caches.

    Args:
        class_name: Class directory name under DATA_ROOT.
        rebuild: Ignore both caches and rebuild from the images.
        augment_index: Forwarded to ``build_class_index`` when a build happens.
            The cache key is the class name alone, so an already-cached index
            is returned as-is regardless of how it was built.
    """
    if not rebuild and class_name in _index_cache:
        return _index_cache[class_name]

    cache_file = INDEX_CACHE_DIR / f"{class_name}.pkl"

    if not rebuild and cache_file.exists():
        print(f"Loading cached index for {class_name}...")
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only load cache files this notebook wrote itself.
        with open(cache_file, 'rb') as f:
            index_data = pickle.load(f)
        _index_cache[class_name] = (
            faiss.deserialize_index(index_data['index']),
            index_data['paths'],
        )
        return _index_cache[class_name]

    index, paths = build_class_index(DATA_ROOT / class_name, augment_index=augment_index)

    # Write to a temporary file and then atomically replace the cache file, so
    # an interrupted run cannot leave a truncated pickle that breaks later loads.
    tmp_file = cache_file.with_name(cache_file.name + ".tmp")
    with open(tmp_file, 'wb') as f:
        pickle.dump({
            'index': faiss.serialize_index(index),
            'paths': paths,
        }, f)
    tmp_file.replace(cache_file)

    _index_cache[class_name] = (index, paths)
    return index, paths

def search_in_class(query_img, class_name, top_k=5, use_query_augmentation=True,
                    augment_index=False, rebuild_index=False):
    """
    Find the catalog images in one class that are most similar to a query image.

    Args:
        query_img: Path to query image
        class_name: Class directory name
        top_k: Number of distinct images to return
        use_query_augmentation: Average multiple augmented query embeddings
        augment_index: Build index with augmented images (slower, better recall).
            Only used when the index is (re)built; a cached index is reused as-is.
        rebuild_index: Force rebuild cached index

    Returns:
        List of up to ``top_k`` ``(image_path, cosine_similarity)`` tuples,
        sorted by descending similarity, with each image appearing once.
    """
    from collections import Counter

    index, paths = get_or_build_index(class_name, rebuild=rebuild_index,
                                      augment_index=augment_index)

    if use_query_augmentation:
        print("Using query-time augmentation...")
        qvec = embed_image(Path(query_img), augment=True, aug_strength='medium')
    else:
        qvec = embed_image(Path(query_img))

    qvec = qvec.astype("float32").reshape(1, -1)
    faiss.normalize_L2(qvec)

    # An image owns several vectors when the index was built with augmentation,
    # possibly by an earlier call, since the cache ignores augment_index. That
    # is why duplicates are measured from `paths` instead of trusting the flag.
    # With at most `max_copies` vectors per image, the top
    # top_k * max_copies hits contain at least top_k distinct images (when the
    # class has that many). Clamp to ntotal so FAISS is never asked for more
    # results than the index holds.
    max_copies = max(Counter(paths).values())
    search_k = min(top_k * max_copies, index.ntotal)
    sims, idxs = index.search(qvec, search_k)

    # Keep each image's best score.
    best_score = {}
    for i, score in zip(idxs[0], sims[0]):
        if i < 0:  # FAISS pads with -1 when fewer than search_k hits exist
            continue
        path = paths[i]
        if path not in best_score or score > best_score[path]:
            best_score[path] = score

    results = sorted(best_score.items(), key=lambda x: x[1], reverse=True)[:top_k]

    print(f"\nTop {top_k} matches in {class_name}:")
    for rank, (path, score) in enumerate(results, start=1):
        # The parent directory name is the sneaker model slug.
        model_name = Path(path).parent.name
        print(f"{rank:>2}. {model_name} | {Path(path).name} (sim: {score:.3f})")

    return results

# ==== USAGE EXAMPLES ====
# Inside a notebook __name__ is "__main__", so this guard always runs here; it
# only matters if this cell is saved and imported as a .py module.
if __name__ == "__main__":
    # Example 1 (recommended): search a known class with query-time augmentation.
    demo_query = "testing images/jordan 1 low 2).jpeg"
    demo_class = "nike_air_jordan_1_low"
    search_in_class(
        query_img=demo_query,
        class_name=demo_class,
        top_k=5,
        use_query_augmentation=True,  # averages embeddings of augmented query views
    )

    # Example 2: rebuild one class index with augmented copies of every catalog
    # image (slower to build; can help recall on small datasets). Run it once:
    # the rebuilt index is cached to disk and reused by later calls.
    # search_in_class(
    #     query_img="path/to/query.jpg",
    #     class_name="nike_air_jordan_1_low",
    #     top_k=5,
    #     use_query_augmentation=True,
    #     augment_index=True,  # up to 5 embeddings per catalog image
    #     rebuild_index=True,  # ignore any cached index
    # )

    # Example 3: rebuild every class index with augmentation.
    # for class_dir in DATA_ROOT.iterdir():
    #     if class_dir.is_dir():
    #         get_or_build_index(class_dir.name, rebuild=True, augment_index=True)


Loading cached index for nike_air_jordan_1_low...
Using query-time augmentation...

Top 5 matches in nike_air_jordan_1_low:
 1. air-jordan-1-low-alt-ps-cobalt-bliss-fn7376-400 | FN7376_400.png.png (sim: 0.850)
 2. air-jordan-1-low-black-white-dark-gum-hv5968-001 | 1526203_02.jpg.jpeg (sim: 0.849)
 3. air-jordan-1-low-spruce-aura-cw1381-003 | 601728_08.jpg.jpeg (sim: 0.847)
 4. air-jordan-1-low-iron-grey-553558-152 | 1452846_02.jpg.jpeg (sim: 0.846)
 5. air-jordan-1-low-alt-td-black-medium-olive-dr9747-092 | DR9747_092.png.png (sim: 0.845)


In [41]:
# ===== BLOCK 3: Example end-to-end usage =====
# 1) Classify the query with ResNet50.
# 2) If the classifier is confident enough, retrieve the most similar catalog
#    images inside the predicted class with CLIP + FAISS.

QUERY_IMAGE = "testing images/jordan1 low.jpeg"   # change as needed

predicted_class, confidence = predict_sneaker_class(QUERY_IMAGE)

# Reset on every run so a low-confidence query never leaves `results` holding
# the matches of a previous, different query (hidden notebook state).
results = []
if confidence >= CONFIDENCE_THRESHOLD:
    print(f"\nSearching inside predicted class: {predicted_class}\n")
    results = search_in_class(
        query_img=QUERY_IMAGE,
        class_name=predicted_class,
        top_k=5,
        use_query_augmentation=True,
    )
else:
    print("\nConfidence too low — skipping FAISS similarity search.")


Predicted class: nike_air_jordan_1_low
Confidence: 0.976
→ This image is a valid sneaker match.

Searching inside predicted class: nike_air_jordan_1_low

Using query-time augmentation...

Top 5 matches in nike_air_jordan_1_low:
 1. wmns-air-jordan-1-low-french-blue-dc0774-402 | 1274656_08.jpg.jpeg (sim: 0.856)
 2. air-jordan-1-low-golf-midnight-navy-dd9315-104 | 1114577_03.jpg.jpeg (sim: 0.856)
 3. air-jordan-1-low-gs-usa-cv9844-400 | 672670_08.jpg.jpeg (sim: 0.853)
 4. wmns-air-jordan-1-low-marina-blue-dc0774-114 | 913992_08.jpg.jpeg (sim: 0.850)
 5. air-jordan-1-low-gs-game-royal-553560-124 | 664458_08.jpg.jpeg (sim: 0.847)


In [4]:
import pandas as pd
from catboost import CatBoostRegressor

# ===============================
# Load trained resale-price model
# ===============================
# Bound to `price_model`, NOT `model`: rebinding `model` would clobber the
# CLIP model that embed_image / build_class_index read as a global, breaking
# any image search run after this cell.
model_path = "sneaker_price_model.cbm"

price_model = CatBoostRegressor()
price_model.load_model(model_path)

# ===============================
# Example input — replace these values with real inputs.
# retail_price_usd and release_age should be taken from the product
# spreadsheet once the sneaker has been identified (see get_input_data),
# or entered manually.
# NOTE(review): class_name here is hyphenated ("adidas-forum-high") while the
# ResNet classifier's class names use underscores ("adidas_forum_high"), and
# the product CSV spells the silhouette "Forum" rather than "forum". Confirm
# which spellings the price model was trained on; categorical values are
# matched exactly.
# ===============================
input_data = {
    "class_name": "adidas-forum-high",
    "brand": "adidas",
    "silhouette": "forum",
    "retail_price_usd": 230,
    "release_age": 3,
}

# CatBoost expects tabular input: one DataFrame row per prediction.
df_input = pd.DataFrame([input_data])

# ===============================
# Predict
# ===============================
predicted_price = price_model.predict(df_input)[0]

print(f"Predicted Resale Price: ${predicted_price:.2f}")


Predicted Resale Price: $91.93


In [45]:
import pandas as pd

PRODUCTS_CSV = "Scraping_part/scraper/data/products_nodup.csv"
# Year that release_age is measured from. It is kept fixed rather than
# "today" because the price model presumably saw ages computed against this
# year — TODO confirm against the training script.
RELEASE_AGE_REFERENCE_YEAR = 2025

# Load the product catalogue once; get_input_data looks rows up by slug.
df = pd.read_csv(PRODUCTS_CSV)
# Normalize slugs the same way get_input_data normalizes its argument.
df["slug"] = df["slug"].astype(str).str.lower().str.strip()

# release_age = reference year - release year. Unparseable dates become NaT,
# so their release_age is NaN.
df["release_date"] = pd.to_datetime(df["release_date"], errors="coerce")
df["release_age"] = RELEASE_AGE_REFERENCE_YEAR - df["release_date"].dt.year

def get_input_data(slug: str):
    """Build the price-model feature dict for the product with ``slug``.

    The slug is trimmed and lower-cased, then matched against the normalized
    ``slug`` column of the global ``df``; the first matching row wins.

    Returns:
        Dict with keys class_name, brand, silhouette, retail_price_usd
        (float), and release_age (int, or None when the release date is unknown).

    Raises:
        ValueError: if no row has this slug.
    """
    key = slug.lower().strip()
    matches = df.loc[df["slug"] == key]
    if len(matches) == 0:
        raise ValueError(f"Slug not found: {key}")

    first = matches.iloc[0]
    age = first["release_age"]
    return {
        "class_name": first["class_name"],
        "brand": first["brand"],
        "silhouette": first["silhouette"],
        "retail_price_usd": float(first["retail_price_usd"]),
        "release_age": None if pd.isna(age) else int(age),
    }

# Example lookup. In a notebook __name__ is "__main__", so this always runs here.
if __name__ == "__main__":
    example_slug = "forum-84-high-white-black-gy5847"
    print(get_input_data(example_slug))


{'class_name': 'adidas-forum-high', 'brand': 'adidas', 'silhouette': 'Forum', 'retail_price_usd': 120.0, 'release_age': 3}
