In [1]:
# --- SEGMENT 0: INSTALL DEPENDENCIES ---
print("Installing required libraries for AI Tagger...")
# Install a compatible version of NumPy first
!pip install -q "numpy<2.0"
# Install CLIP, scikit-learn for clustering, and webcolors for color names
!pip install -q git+https://github.com/openai/CLIP.git
!pip install -q scikit-learn webcolors
print("✅ Dependencies installed.")


Installing required libraries for AI Tagger...
  Preparing metadata (setup.py) ... [?25l[?25hdone
✅ Dependencies installed.


In [3]:
import os
import torch
import clip
from PIL import Image
from google.colab import drive
from tqdm.notebook import tqdm
import numpy as np
import json
from sklearn.cluster import KMeans
import webcolors

# ==============================================================================
#                                   PART 1: SETUP
# ==============================================================================
print("--- Setting up the environment ---")
drive.mount('/content/drive')

# --- Define Paths ---
CLOTH_DIR = '/content/drive/MyDrive/VirtualFIT_Models/viton_hd_zips/test/cloth'
OUTPUT_JSON_PATH = '/content/drive/MyDrive/VirtualFIT_Models/metadata.json'

if not os.path.exists(CLOTH_DIR):
    raise FileNotFoundError(f"The specified folder was not found: {CLOTH_DIR}")

os.makedirs(os.path.dirname(OUTPUT_JSON_PATH), exist_ok=True)

# --- Load the CLIP Model ---
print("\nLoading the pre-trained CLIP model...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
print("✅ CLIP model loaded successfully.")


# ==============================================================================
#                  PART 2: DEFINE THE CLASSIFICATION FUNCTIONS
# ==============================================================================

def classify_image_attribute(image_path, text_prompts):
    """Uses CLIP to find the best text description for an image."""
    try:
        image = Image.open(image_path)
        image_input = preprocess(image).unsqueeze(0).to(device)
        text_inputs = clip.tokenize(text_prompts).to(device)
        with torch.no_grad():
            logits_per_image, _ = model(image_input, text_inputs)
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()
        best_match_index = probs.argmax()
        return text_prompts[best_match_index]
    except Exception as e:
        print(f"Could not process {os.path.basename(image_path)}: {e}")
        return None

def get_dominant_color_rgb(image_path, num_clusters=3):
    """Finds the most dominant color in the center of an image."""
    try:
        img = Image.open(image_path).convert("RGB")

        # Crop the image to the central 50% to focus on the garment
        width, height = img.size
        left = width * 0.25
        top = height * 0.25
        right = width * 0.75
        bottom = height * 0.75
        img_cropped = img.crop((left, top, right, bottom))

        img_array = np.array(img_cropped)
        pixels = img_array.reshape(-1, 3)

        # Filter out any remaining white/light gray background pixels
        non_background_pixels = pixels[pixels.mean(axis=1) < 245]
        if len(non_background_pixels) < 1:
            return (255, 255, 255) # Default to white

        kmeans = KMeans(n_clusters=num_clusters, n_init='auto', random_state=0)
        kmeans.fit(non_background_pixels)

        unique, counts = np.unique(kmeans.labels_, return_counts=True)
        dominant_cluster_index = unique[counts.argmax()]
        dominant_color = kmeans.cluster_centers_[dominant_cluster_index].astype(int)

        return tuple(dominant_color)
    except Exception:
        return (0, 0, 0)

def categorize_rgb_to_name(rgb_tuple):
    """Categorizes an RGB tuple into a simple color name."""
    r, g, b = rgb_tuple

    if r > 220 and g > 220 and b > 220: return "white"
    if r < 40 and g < 40 and b < 40: return "black"
    if abs(r - g) < 20 and abs(r - b) < 20 and abs(g - b) < 20: return "gray"

    if r > g and r > b: # Red-dominant
        if g > 100: return "orange" if r > 200 else "brown"
        if r > 150 and b > 100: return "pink"
        return "red"
    if g > r and g > b: return "green"
    if b > r and b > g:
        if r > 100: return "purple"
        return "blue"
    if r > 180 and g > 180 and b < 100: return "yellow"

    return "unknown"


# ==============================================================================
#                 PART 3: PROCESS ALL IMAGES AND GENERATE METADATA
# ==============================================================================
print(f"\n--- Analyzing and organizing images from: {CLOTH_DIR} ---")

style_prompts = ["a casual style item", "a formal style item", "an elegant evening top", "a sportswear item", "a professional business item"]
sleeve_prompts = ["a full-sleeve item", "a half-sleeve item", "a sleeveless item"]
season_prompts = ["a summer item", "a winter item", "an autumn item", "a spring item"]

all_metadata = []
image_files = [f for f in os.listdir(CLOTH_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

if not image_files:
    print("No images found in the specified directory.")
else:
    for filename in tqdm(image_files, desc="Processing Images"):
        source_path = os.path.join(CLOTH_DIR, filename)

        style_result = classify_image_attribute(source_path, style_prompts)
        sleeve_result = classify_image_attribute(source_path, sleeve_prompts)
        season_result = classify_image_attribute(source_path, season_prompts)
        dominant_rgb = get_dominant_color_rgb(source_path)
        color_name = categorize_rgb_to_name(dominant_rgb)

        # Robust logic to extract the specific attribute
        def get_attribute(result_string):
            if not result_string: return "unknown"
            words = result_string.split(" ")
            if "evening" in words: return "elegant"
            if "business" in words: return "business"
            return words[1]

        style = get_attribute(style_result)
        sleeve = get_attribute(sleeve_result).replace("-", " ")
        season = get_attribute(season_result)
        cloth_type = "top"

        all_metadata.append({
            "filename": filename,
            "type": cloth_type,
            "color": color_name,
            "style": style,
            "sleeve_length": sleeve,
            "season": season
        })

    with open(OUTPUT_JSON_PATH, 'w') as f:
        json.dump(all_metadata, f, indent=2)

    print("\n--- ✅ Metadata Generation Complete ---")
    print(f"The file 'metadata.json' has been saved to: {OUTPUT_JSON_PATH}")
    print("\n--- Sample of Generated Metadata ---")
    print(json.dumps(all_metadata[:5], indent=2))


--- Setting up the environment ---
Mounted at /content/drive

Loading the pre-trained CLIP model...


100%|███████████████████████████████████████| 338M/338M [00:17<00:00, 20.5MiB/s]


✅ CLIP model loaded successfully.

--- Analyzing and organizing images from: /content/drive/MyDrive/VirtualFIT_Models/viton_hd_zips/test/cloth ---


Processing Images:   0%|          | 0/2032 [00:00<?, ?it/s]


--- ✅ Metadata Generation Complete ---
The file 'metadata.json' has been saved to: /content/drive/MyDrive/VirtualFIT_Models/metadata.json

--- Sample of Generated Metadata ---
[
  {
    "filename": "07782_00.jpg",
    "type": "top",
    "color": "red",
    "style": "elegant",
    "sleeve_length": "full sleeve",
    "season": "autumn"
  },
  {
    "filename": "07118_00.jpg",
    "type": "top",
    "color": "gray",
    "style": "sportswear",
    "sleeve_length": "full sleeve",
    "season": "winter"
  },
  {
    "filename": "07581_00.jpg",
    "type": "top",
    "color": "white",
    "style": "sportswear",
    "sleeve_length": "sleeveless",
    "season": "summer"
  },
  {
    "filename": "09018_00.jpg",
    "type": "top",
    "color": "black",
    "style": "sportswear",
    "sleeve_length": "sleeveless",
    "season": "spring"
  },
  {
    "filename": "08263_00.jpg",
    "type": "top",
    "color": "orange",
    "style": "elegant",
    "sleeve_length": "sleeveless",
    "season": "sprin