In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import os

class ImageCaptionDataset(Dataset):
    def __init__(self, df, image_dir, tokenizer, transform):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, row["filename"])
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)

        caption = row["caption"]
        text_inputs = self.tokenizer(
        caption,
        padding="max_length",
        truncation=True,
        max_length=32,
        return_tensors="pt")

        input_ids = text_inputs.input_ids.squeeze(0).long()
        attention_mask = text_inputs.attention_mask.squeeze(0)

        return image, input_ids, attention_mask



In [None]:
import torch.nn as nn
import torchvision.models as models

class VisionEncoder(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        base = models.resnet18(pretrained=True)
        base.fc = nn.Identity()
        self.backbone = base
        self.head = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, embed_dim)
        )

    def forward(self, x):
        x = self.backbone(x)
        return nn.functional.normalize(self.head(x), dim=1)


In [None]:
from transformers import AutoTokenizer

class TextEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=512, hidden_size=768, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.gru = nn.GRU(hidden_size, embed_dim, num_layers=num_layers, batch_first=True)

    def forward(self, input_ids, attention_mask):
        x = self.embedding(input_ids)
        packed_out, _ = self.gru(x)
        x = packed_out[:, -1, :]  # last token
        return nn.functional.normalize(x, dim=1)


In [None]:
class MiniCLIP(nn.Module):
    def __init__(self, vision_encoder, text_encoder):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))

    def forward(self, images, input_ids, attention_mask):
        img_embed = self.vision_encoder(images)        # [B, D]
        txt_embed = self.text_encoder(input_ids, attention_mask)  # [B, D]

        # Cosine similarity scaled
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * img_embed @ txt_embed.T
        logits_per_text = logits_per_image.T

        return logits_per_image, logits_per_text


In [None]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [None]:
import torch.nn.functional as F

def clip_loss(logits_per_image, logits_per_text):
    B1, B2 = logits_per_image.shape
    assert B1 == B2, f"Expected square logits: got {logits_per_image.shape}"
    labels = torch.arange(B1).to(logits_per_image.device)
    loss_i = F.cross_entropy(logits_per_image, labels)
    loss_t = F.cross_entropy(logits_per_text, labels)
    return (loss_i + loss_t) / 2


In [None]:
import torch
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

def compute_recall_at_k(model, val_loader, device, k_list=[1, 5]):
    model.eval()
    all_img_embeds = []
    all_txt_embeds = []

    with torch.no_grad():
        for images, input_ids, attention_mask in val_loader:
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            img_embed = model.vision_encoder(images)       # [B, D]
            txt_embed = model.text_encoder(input_ids, attention_mask)  # [B, D]

            all_img_embeds.append(img_embed)
            all_txt_embeds.append(txt_embed)

    img_embeds = torch.cat(all_img_embeds, dim=0).cpu().numpy()
    txt_embeds = torch.cat(all_txt_embeds, dim=0).cpu().numpy()

    # Cosine similarity
    sim_matrix = cosine_similarity(txt_embeds, img_embeds)  # [N, N]
    ranks = np.argsort(-sim_matrix, axis=1)  # Descending

    recalls = {}
    for k in k_list:
        hits = [(i in ranks[i, :k]) for i in range(len(ranks))]
        recalls[f"Recall@{k}"] = np.mean(hits) * 100

    return recalls


def train_clip_model(model, train_loader, val_loader, epochs=40, lr=1e-4, save_path="/kaggle/working/miniclip_best.pth"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

    best_loss = float("inf")
    train_losses = []
    val_recalls1 = []
    val_recalls5 = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        loop = tqdm(train_loader, desc=f"üìö Epoch {epoch+1}/{epochs}")

        for images, input_ids, attention_mask in loop:
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            optimizer.zero_grad()

            logits_i, logits_t = model(images, input_ids, attention_mask)

            # Contrastive loss
            B1, B2 = logits_i.shape
            assert B1 == B2, f"Expected square logits: got {logits_i.shape}"
            labels = torch.arange(B1).to(device)
            loss_i = F.cross_entropy(logits_i, labels)
            loss_t = F.cross_entropy(logits_t, labels)
            loss = (loss_i + loss_t) / 2

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            avg_loss = total_loss / (loop.n + 1)
            loop.set_postfix(loss=loss.item(), avg=avg_loss)

        avg_epoch_loss = total_loss / len(train_loader)
        train_losses.append(avg_epoch_loss)
        scheduler.step(avg_epoch_loss)
        print(f"\n‚úÖ Epoch {epoch+1}: Avg Loss = {avg_epoch_loss:.4f}")

        # üß™ Validation
        recalls = compute_recall_at_k(model, val_loader, device)
        val_recalls1.append(recalls["Recall@1"])
        val_recalls5.append(recalls["Recall@5"])
        print(f"üéØ Recall@1: {recalls['Recall@1']:.2f}%  |  Recall@5: {recalls['Recall@5']:.2f}%")

        # üíæ Save best
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save(model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict(), save_path)
            print(f"üíæ Best model saved (loss: {best_loss:.4f}) ‚Üí {save_path}")

    return train_losses, val_recalls1, val_recalls5


In [None]:
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import os
import pandas as pd
from tqdm import tqdm
import pandas as pd

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

In [None]:
# Load your caption CSV
df = pd.read_csv("/kaggle/input/blip-captions/blip_captions (1).csv")
image_dir = "/kaggle/input/clothe/clothes_tryon_dataset/train/cloth"

In [None]:
all_embeddings = []
all_filenames = []

batch_size = 32
for i in tqdm(range(0, len(df), batch_size), desc=" Encoding Images"):
    batch_df = df.iloc[i:i+batch_size]
    batch_imgs = [Image.open(os.path.join(image_dir, fname)).convert("RGB") for fname in batch_df["filename"]]
    
    inputs = processor(images=batch_imgs, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        img_feats = model.get_image_features(**inputs)
        img_feats = torch.nn.functional.normalize(img_feats, dim=1)
    
    all_embeddings.append(img_feats.cpu())
    all_filenames.extend(batch_df["filename"].tolist())

# Save
image_embeds = torch.cat(all_embeddings, dim=0)  # [N, 512]
torch.save({"embeds": image_embeds, "files": all_filenames}, "/kaggle/working/clip_image_embeds.pt")
print(" Saved CLIP image embeddings.")

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

def search_clip(query, model, processor, image_embeds, image_filenames, top_k=5):
    inputs = processor(text=[query], return_tensors="pt").to(device)
    with torch.no_grad():
        text_feat = model.get_text_features(**inputs)
        text_feat = torch.nn.functional.normalize(text_feat, dim=1)

    sims = cosine_similarity(text_feat.cpu().numpy(), image_embeds.numpy())[0]
    top_idx = sims.argsort()[::-1][:top_k]
    return [(image_filenames[i], sims[i]) for i in top_idx]


In [None]:
import matplotlib.pyplot as plt

def show_results(results, image_dir):
    plt.figure(figsize=(15, 3))
    for idx, (fname, score) in enumerate(results):
        path = os.path.join(image_dir, fname)
        img = Image.open(path).convert("RGB")
        plt.subplot(1, len(results), idx + 1)
        plt.imshow(img)
        plt.title(f"{score:.2f}")
        plt.axis("off")
    plt.tight_layout()
    plt.show()


In [None]:
data = torch.load("/kaggle/input/clip-embed/pytorch/default/1/clip_image_embeds (1).pt")
image_embeds = data["embeds"]
image_filenames = data["files"]

query = "a red frock"
results = search_clip(query, model, processor, image_embeds, image_filenames, top_k=5)
show_results(results, image_dir)

In [None]:
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
import matplotlib.pyplot as plt

# Load and preprocess the query image
query_image_path = "/kaggle/input/clothe/clothes_tryon_dataset/test/cloth/00094_00.jpg"
img = Image.open(query_image_path).convert("RGB")
inputs = processor(images=img, return_tensors="pt").to(device)

#  Get normalized CLIP image embedding
query_embed = model.get_image_features(**inputs)
query_embed = torch.nn.functional.normalize(query_embed, dim=1)

#  Detach before converting to NumPy
sims = cosine_similarity(query_embed.detach().cpu().numpy(), image_embeds.numpy())[0]

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os

# Config
top_k = 5  # change as needed
image_dir = "/kaggle/input/clothe/clothes_tryon_dataset/train/cloth"

# Get Top-K similar indices
top_indices = sims.argsort()[::-1][:top_k]

# Top-k filenames and scores
top_results = [(image_filenames[i], sims[i]) for i in top_indices]

# Display the query image
plt.figure(figsize=(3, 3))
plt.imshow(img)
plt.title("üîç Query Image")
plt.axis("off")
plt.show()

# Show results
plt.figure(figsize=(12, 4))
for idx, (fname, score) in enumerate(top_results):
    img_path = os.path.join(image_dir, fname)
    img = Image.open(img_path).convert("RGB")
    plt.subplot(1, top_k, idx + 1)
    plt.imshow(img)
    plt.title(f"{score:.2f}")
    plt.axis("off")
plt.tight_layout()
plt.show()


In [None]:
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn.functional as F

def multi_modal_search(
    query_text, 
    query_image, 
    model, 
    processor, 
    image_embeds, 
    image_filenames, 
    top_k=5,
    image_weight=0.5,
    text_weight=0.5
):
    """
    Performs a multi-modal search by combining text and image queries.
    
    Args:
        query_text (str): The text part of the query.
        query_image (PIL.Image): The image part of the query.
        model: The pre-trained CLIP model.
        processor: The CLIP processor.
        image_embeds (torch.Tensor): Pre-computed embeddings of the image database.
        image_filenames (list): List of filenames corresponding to the embeddings.
        top_k (int): Number of top results to return.
        image_weight (float): The influence of the image query.
        text_weight (float): The influence of the text query.
    """
    device = next(model.parameters()).device

    # 1. Process and encode the text query
    text_inputs = processor(text=[query_text], return_tensors="pt").to(device)
    with torch.no_grad():
        text_feat = model.get_text_features(**text_inputs)
        text_feat = F.normalize(text_feat, dim=1)

    # 2. Process and encode the image query
    image_inputs = processor(images=query_image, return_tensors="pt").to(device)
    with torch.no_grad():
        img_feat = model.get_image_features(**image_inputs)
        img_feat = F.normalize(img_feat, dim=1)

    # 3. Combine the embeddings (the core of multi-modal search)
    # We create the final query vector using a weighted average of the two modalities.
    combined_feat = (image_weight * img_feat) + (text_weight * text_feat)
    
    # 4. Normalize the combined embedding
    # This is crucial to ensure it's a valid vector for cosine similarity.
    combined_feat = F.normalize(combined_feat, dim=1)

    # 5. Perform search against the database
    # We compare our new combined feature vector against all pre-computed image embeddings.
    sims = cosine_similarity(combined_feat.cpu().numpy(), image_embeds.numpy())[0]
    
    # 6. Get top results
    top_idx = sims.argsort()[::-1][:top_k]
    
    return [(image_filenames[i], sims[i]) for i in top_idx]

In [None]:
from PIL import Image
import os
import matplotlib.pyplot as plt

# --- Setup: Make sure these variables from your previous code are loaded ---
# model, processor, device
# data = torch.load("/kaggle/input/embdclip/pytorch/default/1/clip_image_embeds.pt")
# image_embeds = data["embeds"]
# image_filenames = data["files"]
# image_dir = "/kaggle/input/clothe/clothes_tryon_dataset/train/cloth"
# --------------------------------------------------------------------------


# 1. Define your multi-modal query
query_image_path = "/kaggle/input/clothe/clothes_tryon_dataset/test/cloth/00075_00.jpg" # A striped t-shirt
query_text = "Give me similar tshirts but with long sleeves" # We want to find something like the image, but as a blue sweater.

# 2. Load the query image
query_image = Image.open(query_image_path).convert("RGB")

# Display the query image and text
print(f"üîç Image Query: An item with this pattern/shape.")
print(f"‚úçÔ∏è Text Query: '{query_text}'")
plt.figure(figsize=(3, 3))
plt.imshow(query_image)
plt.axis("off")
plt.show()

# 3. Run the multi-modal search
# Let's give slightly more weight to the text to emphasize the "blue sweater" aspect.
results = multi_modal_search(
    query_text, 
    query_image, 
    model, 
    processor, 
    image_embeds, 
    image_filenames, 
    top_k=5,
    image_weight=0.4, # Less weight on the original t-shirt pattern
    text_weight=0.6   # More weight on the "blue sweater" concept
)

# 4. Show the results (using your existing `show_results` function)
print("\nTop 5 Multi-Modal Search Results:")
show_results(results, image_dir)

In [None]:
!pip install diffusers transformers accelerate --quiet

In [None]:
import torch
from diffusers import StableDiffusionPipeline

# Make sure the device is set correctly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained Stable Diffusion model
# We use float16 for faster inference and less memory usage on the GPU
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", 
    torch_dtype=torch.float16
)
pipe = pipe.to(device)

print("Stable Diffusion pipeline loaded successfully.")

In [None]:
from PIL import Image

def generate_clothing_design(prompt, num_inference_steps=50, guidance_scale=7.5):
    """
    Generates an image of a clothing item from a text prompt using Stable Diffusion.
    
    Args:
        prompt (str): The user's description of the clothing item.
        num_inference_steps (int): Number of denoising steps. Higher is better quality but slower.
        guidance_scale (float): How much to adhere to the prompt. Higher is stricter.
        
    Returns:
        PIL.Image: The generated image.
    """
    
    # --- Prompt Engineering ---
    # We add keywords to focus the model on generating a clean product image.
    # This is a critical step for getting good results.
    enhanced_prompt = (
        f"A photorealistic image of {prompt}, fashion design, "
        "product shot, studio lighting, on a mannequin, white background, 8k, whole outfit should be visible"
    )
    
    print(f"Generating image with enhanced prompt: '{enhanced_prompt}'")
    
    # Run the diffusion pipeline
    with torch.no_grad():
        # The 'pipe' returns an object with the generated images
        result = pipe(
            enhanced_prompt, 
            num_inference_steps=num_inference_steps, 
            guidance_scale=guidance_scale
        )
        generated_image = result.images[0]
    
    return generated_image

In [None]:
import matplotlib.pyplot as plt

# 1. Define the user's dream clothing item
user_prompt = "a black tshirt with graphic detailed logo on it"

# 2. Generate the design
# This will take a moment to run on the GPU
new_design_image = generate_clothing_design(user_prompt)

# 3. Visualize the newly created design
print("\n‚ú® Generated Design:")
plt.figure(figsize=(6, 6))
plt.imshow(new_design_image)
plt.axis("off")
plt.show()

In [None]:
# We need a function for image-to-image search. 
# We can adapt your previous `search_clip` or create a specific one.

def search_by_image(query_image, model, processor, image_embeds, image_filenames, top_k=5):
    """Performs an image-to-image search using CLIP."""
    device = next(model.parameters()).device
    
    # Preprocess the query image and get its embedding
    inputs = processor(images=query_image, return_tensors="pt").to(device)
    with torch.no_grad():
        query_embed = model.get_image_features(**inputs)
        query_embed = F.normalize(query_embed, dim=1)

    # Perform cosine similarity search
    sims = cosine_similarity(query_embed.cpu().numpy(), image_embeds.numpy())[0]
    top_indices = sims.argsort()[::-1][:top_k]
    
    return [(image_filenames[i], sims[i]) for i in top_indices]

# --- Let's run the integration ---

print("\nUsing the generated design to find similar items in our database...")

# 1. Use the 'new_design_image' we just created as the query
#    (Ensure your other variables like model, processor, image_embeds are loaded)
search_results = search_by_image(
    new_design_image,
    model,
    processor,
    image_embeds,
    image_filenames,
    top_k=5
)

# 2. Show the results (using your existing show_results function)
print("\nTop 5 Real Items Similar to the Generated Design:")
show_results(search_results, image_dir)

In [None]:
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
import numpy as np
import cv2 # We'll use OpenCV to create a mask programmatically

# Ensure you are on a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the inpainting pipeline. This is different from the standard StableDiffusionPipeline.
# This model is specifically designed to understand image+mask+prompt inputs.
inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=torch.float16,
).to(device)

print("Stable Diffusion Inpainting pipeline loaded successfully.")

In [None]:
import matplotlib.pyplot as plt

# --- Setup: Make sure these variables are loaded from your previous code ---
# image_dir = "/kaggle/input/clothe/clothes_tryon_dataset/train/cloth"
# --------------------------------------------------------------------------

# 1. Select the base image we want to edit
base_image_path = os.path.join(image_dir, "00010_00.jpg") # A plain black t-shirt
base_image = Image.open(base_image_path).convert("RGB").resize((512, 512))

# 2. Create the mask programmatically
# In a real app, this mask would be drawn by the user with a brush tool.
# Here, we create a black image and draw a white rectangle on it to define the edit area.
mask_image = np.zeros((512, 512), dtype=np.uint8)
# Let's define a rectangle on the chest of the t-shirt
cv2.rectangle(mask_image, (160, 150), (350, 300), (255, 255, 255), -1) # -1 fills the rectangle
mask_image = Image.fromarray(mask_image)

# 3. Visualize the inputs to understand what we're doing
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(base_image)
plt.title("1. Base Image")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(mask_image, cmap='gray')
plt.title("2. Mask (White area will be edited)")
plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
def create_edited_design(base_image, mask_image, prompt, pipe):
    """
    Generates an edited image using the Stable Diffusion Inpainting pipeline.
    
    Args:
        base_image (PIL.Image): The original image.
        mask_image (PIL.Image): The mask defining the area to change.
        prompt (str): The text description of the desired change.
        pipe: The loaded inpainting pipeline.
        
    Returns:
        PIL.Image: The new, edited image.
    """
    print(f"Applying new design with prompt: '{prompt}'")
    
    # The inpainting pipeline takes the prompt, base image, and mask image as input.
    with torch.no_grad():
        edited_image = pipe(
            prompt=prompt, 
            image=base_image, 
            mask_image=mask_image
        ).images[0]
        
    return edited_image

# Let's define what we want to paint onto the t-shirt's chest
edit_prompt = "A detailed embroidered patch of a roaring tiger's head, hyperrealistic"

# Run the creative process!
edited_design = create_edited_design(base_image, mask_image, edit_prompt, inpaint_pipe)

# Visualize the final result
plt.figure(figsize=(6, 6))
plt.imshow(edited_design)
plt.title("3. Generated Design!")
plt.axis("off")
plt.show()

In [None]:
def enhance_prompt(base_prompt):
    """
    Takes a simple user prompt and expands it into a detailed, high-quality prompt
    for Stable Diffusion, including a negative prompt.
    
    Args:
        base_prompt (str): The user's simple input (e.g., "a wolf").
        
    Returns:
        tuple: A tuple containing the (enhanced_positive_prompt, negative_prompt).
    """
    
    # This is the core of our prompt engineering.
    # We combine descriptive keywords about style, medium, and quality.
    enhanced_positive_prompt = (
        f"A masterpiece, photorealistic, high-quality embroidered patch of a {base_prompt}. "
        "The design features an intricate, hyper-detailed, and sharp focus. "
        "Rendered with cinematic lighting, 8k ultra-high definition, professional product shot on fabric."
    )
    
    # The negative prompt is crucial for avoiding common image generation errors.
    negative_prompt = (
        "lowres, blurry, bad anatomy, error, worst quality, low quality, "
        "jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, "
        "extra fingers, mutated hands, poorly drawn hands, poorly drawn face, "
        "malformed limbs, extra limbs, cloned face, disfigured, gross proportions, "
        "watermark, grain, text, signature"
    )
    
    return enhanced_positive_prompt, negative_prompt

In [None]:
def create_edited_design_enhanced(base_image, mask_image, user_prompt, pipe):
    """
    Generates a high-quality edited image by first enhancing the user's prompt.
    
    Args:
        base_image (PIL.Image): The original image.
        mask_image (PIL.Image): The mask defining the area to change.
        user_prompt (str): The user's SIMPLE text description.
        pipe: The loaded inpainting pipeline.
        
    Returns:
        PIL.Image: The new, high-quality edited image.
    """
    
    # 1. Enhance the user's simple prompt
    enhanced_prompt, negative_prompt = enhance_prompt(user_prompt)
    
    print("--- Prompt Enhancement ---")
    print(f"Original User Prompt: '{user_prompt}'")
    print(f"‚ú® Enhanced Prompt: '{enhanced_prompt}'")
    print(f"üö´ Negative Prompt: '{negative_prompt}'")
    print("--------------------------")
    
    # 2. Run the pipeline with the new, powerful prompts
    with torch.no_grad():
        edited_image = pipe(
            prompt=enhanced_prompt, 
            image=base_image, 
            mask_image=mask_image,
            negative_prompt=negative_prompt, # Add the negative prompt here
            num_inference_steps=50,          # Can use more steps for higher quality
            guidance_scale=8.5               # A slightly higher guidance can help with detailed prompts
        ).images[0]
        
    return edited_image

In [None]:
# The user provides a very simple, "bad" prompt.
user_prompt = "a tshirt with a tiger head on it" 

# Our system will automatically turn this into a masterpiece.
# We use the new function: create_edited_design_enhanced
edited_design_enhanced = create_edited_design_enhanced(
    base_image, 
    mask_image, 
    user_prompt, 
    inpaint_pipe
)

# Visualize the final, high-quality result
plt.figure(figsize=(8, 8))
plt.imshow(edited_design_enhanced)
plt.title(f"Generated from simple prompt: '{user_prompt}'")
plt.axis("off")
plt.show()

# And of course, you can still use this superior image to search your database
# final_results = search_by_image(edited_design_enhanced, ...)
# show_results(final_results, ...)

In [None]:
# --- Setup: Make sure these variables and functions are loaded from your previous code ---
# model, processor
# image_embeds, image_filenames
# image_dir
# def search_by_image(...)
# def show_results(...)
# user_prompt = "a skull"  # The simple prompt from the user
# edited_design_enhanced # The high-quality image we generated
# ----------------------------------------------------------------------------------------

print("\n------------------------------------------------------------------")
print("‚úÖ Creative Step Complete. Now finding similar items in our database...")
print("------------------------------------------------------------------\n")

# Use the 'edited_design_enhanced' (the high-quality image) as the query image.
# This is the key change.
final_results = search_by_image(
    edited_design_enhanced, # <--- Using the image generated with the enhanced prompt
    model,
    processor,
    image_embeds,
    image_filenames,
    top_k=5
)

# Show the final results from your database
# We use the original, simple 'user_prompt' in the title for clarity.
print(f"Top 5 Real Items similar to the generated '{user_prompt}' design:")
show_results(final_results, image_dir)

In [None]:
import pandas as pd
import re
import json

# Load your captions dataset
df = pd.read_csv("/kaggle/input/blip-captions/blip_captions (1).csv")

# This dictionary is our Fashion Knowledge Base.
# We can easily add more keywords or new categories (e.g., 'material') later.
ATTRIBUTE_KEYWORDS = {
    'brand': [
        'levi', 'adidas', 'nike', 'calvin', 'tommy', 'champion', 'guess', 
        'hollister', 'vans', 'puma', 'superdry', 'lacoste', 'boss'
    ],
    'style': [
        't-shirt', 'long sleeve', 'sweatshirt', 'hoodie', 'bodysuit', 
        'polo', 'blouse', 'crop top', 'tank top', 'cardigan', 'jumper',
        'vest', 'cami'
    ],
    'color': [
        'black', 'white', 'red', 'blue', 'green', 'pink', 'yellow', 'grey', 
        'purple', 'orange', 'brown', 'navy', 'silver', 'gold', 'multi',
        'beige', 'khaki', 'maroon', 'burgundy'
    ],
    'pattern': [
        'floral', 'striped', 'polka dot', 'leopard', 'plaid', 'tie dye', 
        'camo', 'checkered', 'paisley', 'zebra', 'houndstooth', 'print'
    ],
    'neckline': [
        'v-neck', 'scoop neck', 'turtle neck', 'round neck', 'crew neck', 
        'square neck', 'halter neck'
    ]
}

print("Fashion attribute keywords defined.")

In [None]:
def generate_tags(caption, keywords_dict):
    """
    Parses a caption and extracts a set of predefined attribute tags.
    
    Args:
        caption (str): The clothing description.
        keywords_dict (dict): The dictionary of attributes and their keywords.
        
    Returns:
        list: A list of unique tags found in the caption (e.g., ['color_black', 'style_t-shirt']).
    """
    if not isinstance(caption, str):
        return []
        
    found_tags = set()
    caption_lower = caption.lower()
    
    for attribute, keywords in keywords_dict.items():
        for keyword in keywords:
            # Use regex to match whole words to avoid partial matches (e.g., 'red' in 'bordeaux')
            # The \b markers stand for word boundaries.
            if re.search(r'\b' + re.escape(keyword) + r'\b', caption_lower):
                # Create a clean tag format, replacing spaces with underscores
                clean_keyword = keyword.replace(' ', '_')
                found_tags.add(f"{attribute}_{clean_keyword}")
                
    return list(found_tags)

# --- Test the function with one example ---
example_caption = df.iloc[15]['caption'] # "adidas 3 stripes tee - t - shirt - grey heather / white"
example_tags = generate_tags(example_caption, ATTRIBUTE_KEYWORDS)
print(f"Caption: '{example_caption}'")
print(f"Generated Tags: {example_tags}")

In [None]:
print("Generating tags for all captions... This may take a moment.")

# Use the .apply() method to efficiently run the function on each row
df['tags'] = df['caption'].apply(lambda c: generate_tags(c, ATTRIBUTE_KEYWORDS))

print("Tag generation complete!")

# --- Let's save our newly tagged data for future use ---
# Saving as a JSON is often better for columns containing lists.
df.to_json('tagged_fashion_data.json', orient='records', lines=True)

print("\nDataFrame with new 'tags' column:")
# Display the filename and the new tags column to see the result
print(df[['filename', 'caption', 'tags']].head())

In [None]:
# For this step, we'll use the DataFrame 'df' we just created.
# In a real app, you would load the 'tagged_fashion_data.json' file.
tagged_df = df.set_index('filename') # Setting filename as index for fast lookups

def filter_results(initial_results, tagged_data, filters):
    """
    Filters a list of search results based on a dictionary of attribute filters.
    
    Args:
        initial_results (list): A list of tuples (filename, score) from a CLIP search.
        tagged_data (pd.DataFrame): The DataFrame containing the pre-generated tags.
        filters (dict): A dictionary of filters, e.g., {'color': 'red', 'style': 't-shirt'}.
        
    Returns:
        list: A new list of tuples that match all the specified filters.
    """
    if not filters:
        return initial_results
        
    # Get the filenames from the initial search results
    initial_filenames = {filename for filename, score in initial_results}
    
    # Build the list of expected tag strings from the filters dictionary
    expected_tags = {f"{attribute}_{value.replace(' ', '_')}" for attribute, value in filters.items()}
    
    matching_filenames = set()
    
    # Iterate through only the relevant items from our search
    for filename in initial_filenames:
        try:
            item_tags = set(tagged_data.loc[filename, 'tags'])
            # Check if all expected tags are a subset of the item's tags
            if expected_tags.issubset(item_tags):
                matching_filenames.add(filename)
        except KeyError:
            # This can happen if a filename from search isn't in our CSV
            continue
            
    # Rebuild the results list, preserving the original order and scores
    final_results = [(filename, score) for filename, score in initial_results if filename in matching_filenames]
    
    return final_results

print("Filter function created.")

In [None]:
import pandas as pd
from collections import defaultdict

# --- Setup: Load the tagged data we created in the last step ---
# In a real app, you would load this from 'tagged_fashion_data.json'
# df = pd.read_json('tagged_fashion_data.json', orient='records', lines=True)
# --------------------------------------------------------------------

def generate_dropdown_options(tagged_df):
    """
    Scans the tagged DataFrame and creates a dictionary of all available
    filter options for a UI.
    
    Args:
        tagged_df (pd.DataFrame): The DataFrame with the 'tags' column.
        
    Returns:
        dict: A dictionary where keys are attributes (e.g., 'color') and
              values are sorted lists of unique options (e.g., ['black', 'blue', ...]).
    """
    # Use defaultdict to automatically handle the creation of new keys
    options = defaultdict(set)
    
    # Iterate through the 'tags' column
    for tag_list in tagged_df['tags']:
        if not tag_list:  # Skip empty lists
            continue
        for tag in tag_list:
            try:
                # Split 'attribute_value' format (e.g., 'color_black')
                attribute, value = tag.split('_', 1)
                # Replace underscores back to spaces for display
                options[attribute].add(value.replace('_', ' '))
            except ValueError:
                # Handle cases where a tag might not have an underscore
                continue
                
    # Convert sets to sorted lists for consistent ordering in the UI
    sorted_options = {attribute: sorted(list(values)) for attribute, values in options.items()}
    
    return sorted_options

# Generate the options for our UI
dropdown_options = generate_dropdown_options(df)

# --- Display the generated options as if they were in a UI ---
print("--- Dynamically Generated Drop-Down Menus ---")
for attribute, values in dropdown_options.items():
    print(f"\n‚ñº {attribute.capitalize()} Menu:")
    # We'll just print the first 10 for brevity
    print(values[:10]) 
    if len(values) > 10:
        print(f"  (...and {len(values) - 10} more)")

In [None]:
# Make sure the tagged_df is indexed by filename for fast lookups
tagged_df = df.set_index('filename')

def flexible_ranked_filter(initial_results, tagged_data, filters):
    """
    Filters and re-ranks results based on how many tags they match.
    Items matching more tags are ranked higher. The original search score is
    used as a tie-breaker.
    
    Args:
        initial_results (list): A list of tuples (filename, score) from CLIP search.
        tagged_data (pd.DataFrame): The DataFrame with pre-generated tags.
        filters (dict): A dictionary of desired filters, e.g., {'color': 'red'}.
        
    Returns:
        list: A new list of tuples, re-ranked based on tag matches and original score.
    """
    if not filters:
        return initial_results
        
    # Build the set of expected tag strings from the filters dictionary
    expected_tags = {f"{attribute}_{value.replace(' ', '_')}" for attribute, value in filters.items()}
    
    scored_results = []
    
    for filename, original_score in initial_results:
        try:
            item_tags = set(tagged_data.loc[filename, 'tags'])
            
            # Calculate how many of the desired tags this item has
            match_count = len(expected_tags.intersection(item_tags))
            
            # We only add items that match at least one tag to the list.
            # You could change this to `if match_count >= 0:` to include everything.
            if match_count > 0:
                # Create a new tuple with the match count for sorting
                scored_results.append((filename, original_score, match_count))
                
        except KeyError:
            continue
            
    # Sort the results. This is the key step.
    # - We sort by 'match_count' in descending order (primary criteria).
    # - Then, we sort by 'original_score' in descending order (tie-breaker).
    # The `lambda` function defines this multi-level sorting.
    sorted_results = sorted(scored_results, key=lambda x: (x[2], x[1]), reverse=True)
    
    # Rebuild the list into the original (filename, score) format
    final_ranked_list = [(filename, score) for filename, score, match_count in sorted_results]
    
    return final_ranked_list

print("Flexible ranked filtering function created.")

In [None]:
# --- Setup: Make sure all necessary variables and functions are loaded ---
# model, processor, image_embeds, image_filenames, image_dir
# def search_clip(...)
# def show_results(...)
# -------------------------------------------------------------------------

# 1. USER ACTION: Perform a broad initial search
initial_query = "a casual top"
print(f"Performing initial search for: '{initial_query}'")
initial_results = search_clip(initial_query, model, processor, image_embeds, image_filenames, top_k=200) # Get a large pool

# 2. USER ACTION: Apply a very specific combination of filters
# Let's imagine the user is looking for a "pink, striped, v-neck polo". 
# This exact item might not exist in our small dataset.
my_filters = {
    'color': 'pink',
    'pattern': 'striped',
    'neckline': 'v-neck',
    'style': 'polo'
}
print("\n---------------------------------------------------------")
print(f"Applying filters (looking for best matches): {my_filters}")
print("---------------------------------------------------------")

# 3. RUN THE FLEXIBLE FILTERING LOGIC
ranked_results = flexible_ranked_filter(initial_results, tagged_df, my_filters)

# 4. SHOW THE RESULTS
# The results will now be ordered by how many tags they matched.
# A pink, striped v-neck (3 matches) will appear before a pink polo (2 matches),
# which will appear before a generic striped shirt (1 match).
if ranked_results:
    print(f"Showing best matches out of {len(ranked_results)} relevant items found:")
    show_results(ranked_results[:5], image_dir) # Show the top 5 best matches
    
    # Let's inspect the tags of the top result to see why it was chosen
    top_result_filename = ranked_results[0][0]
    top_result_tags = tagged_df.loc[top_result_filename, 'tags']
    print(f"\nTags for the top recommended item ('{top_result_filename}'):")
    print(top_result_tags)
else:
    print("No items matched any of your selected filters.")

In [None]:
# ----------------------------------------------------------------------------------
# Cell 1: Install All Dependencies and Write the Streamlit App File
# ----------------------------------------------------------------------------------

# Step 1: Install a Kaggle-compatible PyTorch first.
!pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118

# Step 2: Install all other packages, using specific, compatible versions for the entire ecosystem.
# This is the final fix that pins the 'peft' library to the version required by diffusers.
!pip install streamlit==1.28.2 pyngrok==7.0.0 transformers==4.34.0 diffusers==0.23.1 accelerate==0.24.1 pandas==2.1.3 scikit-learn==1.3.2 opencv-python-headless==4.8.1.78 peft==0.6.2 --quiet

# Step 3: Write the entire Streamlit app to a file named app.py.


In [None]:
# %%writefile app.py
# import streamlit as st
# import torch
# import pandas as pd
# import os
# import re
# from PIL import Image, ImageDraw, ImageFont
# from collections import defaultdict
# from transformers import CLIPProcessor, CLIPModel
# from diffusers import StableDiffusionInpaintPipeline
# from sklearn.metrics.pairwise import cosine_similarity
# import torch.nn.functional as F

# # =============================================================================
# # 1. PAGE CONFIGURATION & INITIALIZATION
# # =============================================================================
# st.set_page_config(
#     page_title="AI Fashion Stylist Pro",
#     page_icon="ü§ñ",
#     layout="wide"
# )

# # Initialize session state for variables that need to persist across reruns
# if 'search_results' not in st.session_state:
#     st.session_state.search_results = []
# if 'generated_design' not in st.session_state:
#     st.session_state.generated_design = None
# if 'similar_items' not in st.session_state:
#     st.session_state.similar_items = []

# # =============================================================================
# # 2. MODEL & DATA LOADING (Cached to run only once)
# # =============================================================================
# @st.cache_resource
# def load_all_models_and_data():
#     """Loads all models and data files into memory, cached for performance."""
#     print("Executing one-time resource loading...")
    
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
#     clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

#     inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
#         "runwayml/stable-diffusion-inpainting",
#         torch_dtype=torch.float16,
#     ).to(device)

#     # --- KAGGLE-SPECIFIC FILE PATHS ---
#     data_dir = '/kaggle/input/clothe/clothes_tryon_dataset/train/cloth'
#     embeddings_path = "/kaggle/input/clip-embed/pytorch/default/1/clip_image_embeds (1).pt"
#     tagged_data_path = "/kaggle/input/tagged-fashion-data/tagged_fashion_data.json"

#     if not os.path.exists(embeddings_path) or not os.path.exists(tagged_data_path):
#         st.error(f"CRITICAL ERROR: Data files not found in '{data_dir}'. Please ensure you have attached the 'fashion-app-data' Kaggle dataset containing 'clip_image_embeds.pt' and 'tagged_fashion_data.json'.")
#         st.stop()
        
#     embedding_data = torch.load(embeddings_path, map_location='cpu')
#     image_embeds = embedding_data["embeds"]
#     image_filenames = embedding_data["files"]

#     tagged_df = pd.read_json(tagged_data_path, orient='records', lines=True)
#     tagged_df_indexed = tagged_df.set_index('filename')

#     print("Resource loading complete.")
#     return device, clip_model, clip_processor, inpaint_pipe, image_embeds, image_filenames, tagged_df, tagged_df_indexed

# DEVICE, clip_model, clip_processor, inpaint_pipe, image_embeds, image_filenames, tagged_df, tagged_df_indexed = load_all_models_and_data()

# # !!! CRITICAL !!! UPDATE THIS PATH to match your clothing image dataset on Kaggle.
# IMAGE_DIR = "/kaggle/input/clothe/clothes_tryon_dataset/train/cloth/"
# if not os.path.exists(IMAGE_DIR):
#     st.error(f"Image directory not found at '{IMAGE_DIR}'. Please update the IMAGE_DIR path in the `app.py` script to match your Kaggle dataset path.")
#     st.stop()

# # =============================================================================
# # 3. BACKEND FUNCTIONS
# # =============================================================================
# @st.cache_data
# def generate_dropdown_options(_df):
#     options = defaultdict(set)
#     for tag_list in _df['tags']:
#         if not isinstance(tag_list, list): continue
#         for tag in tag_list:
#             try:
#                 attribute, value = tag.split('_', 1)
#                 options[attribute].add(value.replace('_', ' '))
#             except ValueError: continue
#     return {attribute: [""] + sorted(list(values)) for attribute, values in options.items()}

# def search_by_text(query, top_k=200):
#     text_inputs = clip_processor(text=[query], return_tensors="pt").to(DEVICE)
#     with torch.no_grad():
#         text_feat = clip_model.get_text_features(**text_inputs)
#         text_feat = F.normalize(text_feat, dim=1)
#     sims = cosine_similarity(text_feat.cpu().numpy(), image_embeds.numpy())[0]
#     top_idx = sims.argsort()[::-1][:top_k]
#     return [(image_filenames[i], sims[i]) for i in top_idx]

# def search_by_image(query_image, top_k=200):
#     if query_image is None: return []
#     inputs = clip_processor(images=query_image, return_tensors="pt").to(DEVICE)
#     with torch.no_grad():
#         query_embed = clip_model.get_image_features(**inputs)
#         query_embed = F.normalize(query_embed, dim=1)
#     sims = cosine_similarity(query_embed.cpu().numpy(), image_embeds.numpy())[0]
#     top_idx = sims.argsort()[::-1][:top_k]
#     return [(image_filenames[i], sims[i]) for i in top_idx]

# def multi_modal_search(query_text, query_image, image_weight=0.5, top_k=200):
#     text_feat = F.normalize(clip_model.get_text_features(**clip_processor(text=[query_text], return_tensors="pt").to(DEVICE)))
#     img_feat = F.normalize(clip_model.get_image_features(**clip_processor(images=query_image, return_tensors="pt").to(DEVICE)))
#     combined_feat = F.normalize((image_weight * img_feat) + ((1.0 - image_weight) * text_feat))
#     sims = cosine_similarity(combined_feat.cpu().detach().numpy(), image_embeds.numpy())[0]
#     top_idx = sims.argsort()[::-1][:top_k]
#     return [(image_filenames[i], sims[i]) for i in top_idx]

# def flexible_ranked_filter(initial_results, filters):
#     if not any(filters.values()): return initial_results
#     expected_tags = {f"{attr}_{val.replace(' ', '_')}" for attr, val in filters.items() if val}
#     scored_results = []
#     for filename, original_score in initial_results:
#         try:
#             item_tags = set(tagged_df_indexed.loc[filename, 'tags'])
#             match_count = len(expected_tags.intersection(item_tags))
#             if match_count > 0:
#                 scored_results.append((filename, original_score, match_count))
#         except KeyError: continue
#     sorted_results = sorted(scored_results, key=lambda x: (x[2], x[1]), reverse=True)
#     return [(filename, score) for filename, score, _ in sorted_results]

# def enhance_prompt(base_prompt):
#     enhanced_positive = f"masterpiece, photorealistic, high-quality professional product shot of a {base_prompt}, intricate, hyper-detailed, sharp focus, cinematic lighting, 8k uhd, on a mannequin, clean white background"
#     negative = "lowres, blurry, bad anatomy, error, worst quality, jpeg artifacts, ugly, duplicate, morbid, out of frame, watermark, text, signature, person, model"
#     return enhanced_positive, negative

# def create_edited_design_enhanced(base_image, mask_image, user_prompt):
#     enhanced_prompt, negative_prompt = enhance_prompt(user_prompt)
#     with torch.no_grad():
#         edited_image = inpaint_pipe(
#             prompt=enhanced_prompt, image=base_image.resize((512, 512)), 
#             mask_image=mask_image.resize((512, 512)),
#             negative_prompt=negative_prompt, num_inference_steps=50, guidance_scale=8.5
#         ).images[0]
#     return edited_image

# def virtual_try_on_placeholder(person_img, cloth_img):
#     placeholder = Image.new('RGB', (768, 1024), color = 'white')
#     draw = ImageDraw.Draw(placeholder)
#     try: font = ImageFont.truetype("LiberationSans-Regular.ttf", 30)
#     except IOError: font = ImageFont.load_default()
#     text = "Virtual Try-On Output\n\nThis is where the magic happens!\n\nA VTON model would:\n1. Parse pose (from OpenPose JSON).\n2. Segment the body (from Human Parsing).\n3. Isolate garment (from Clothing Mask).\n4. Warp clothing onto the person's shape.\n5. Generate a new, photorealistic image."
#     draw.multiline_text((50, 200), text, fill='black', font=font)
#     return placeholder

# # =============================================================================
# # 4. STREAMLIT UI LAYOUT
# # =============================================================================

# st.title("ü§ñ AI Fashion Stylist Pro")
# st.markdown("Discover, create, and virtually try on your next favorite outfit.")

# tab1, tab2, tab3 = st.tabs(["üîé Smart Search & Recommendation", "üé® Creative Director", "üë§ Virtual Try-On Hub"])

# with tab1:
#     with st.sidebar:
#         st.header("Search & Filter Controls")
#         text_query = st.text_input("Text Description", placeholder="e.g., a blue floral blouse")
#         image_query_file = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])
#         image_weight = 0.5
#         if text_query and image_query_file:
#             image_weight = st.slider("Image vs. Text Influence", 0.0, 1.0, 0.5, 0.1)
#         st.markdown("---")
#         st.subheader("Smart Tags")
#         dropdowns = generate_dropdown_options(tagged_df)
#         filters = {}
#         for attr, options in dropdowns.items():
#             filters[attr] = st.selectbox(attr.capitalize(), options)
#         search_button = st.button("Search & Filter", type="primary")

#     st.header("Your Personalized Recommendations")
#     if search_button:
#         with st.spinner("Finding your style..."):
#             initial_results = []
#             image_query = Image.open(image_query_file) if image_query_file else None
#             if text_query and image_query: initial_results = multi_modal_search(text_query, image_query, image_weight)
#             elif image_query: initial_results = search_by_image(image_query)
#             elif text_query: initial_results = search_by_text(text_query)
#             st.session_state.search_results = flexible_ranked_filter(initial_results, filters)

#     if not st.session_state.search_results and search_button:
#         st.warning("No items matched your query or filters. Please broaden your search.")
#     elif not st.session_state.search_results:
#         st.info("Enter a query in the sidebar and click 'Search' to see recommendations.")
#     else:
#         st.success(f"Found {len(st.session_state.search_results)} matching items. Showing top results.")
#         cols = st.columns(6)
#         for i, (fname, score) in enumerate(st.session_state.search_results[:12]):
#             image_path = os.path.join(IMAGE_DIR, fname)
#             if os.path.exists(image_path):
#                 cols[i % 6].image(image_path, caption=f"Score: {score:.2f}", use_column_width=True)

# with tab2:
#     st.header("Unleash Your Inner Designer")
#     col1, col2 = st.columns(2)
#     with col1:
#         st.subheader("1. Upload Base Item & Mask")
#         base_img_file = st.file_uploader("Base Clothing Image", type=['png', 'jpg', 'jpeg'], key="base")
#         mask_img_file = st.file_uploader("Mask Image (White area is replaced)", type=['png', 'jpg', 'jpeg'], key="mask")
#     with col2:
#         st.subheader("2. Describe Your Idea")
#         prompt = st.text_input("Prompt", placeholder="e.g., a roaring tiger head")
#         generate_button = st.button("Generate & Find Similar", type="primary", key="generate")

#     if generate_button:
#         if base_img_file and mask_img_file and prompt:
#             with st.spinner("Creating your masterpiece... This can take up to a minute."):
#                 base_img = Image.open(base_img_file).convert("RGB")
#                 mask_img = Image.open(mask_img_file).convert("RGB")
#                 st.session_state.generated_design = create_edited_design_enhanced(base_img, mask_img, prompt)
#                 st.session_state.similar_items = search_by_image(st.session_state.generated_design, top_k=6)
#         else:
#             st.error("Please provide a base image, a mask image, and a text prompt.")

#     if st.session_state.generated_design:
#         st.markdown("---")
#         st.subheader("3. Your Unique Creation & Similar Real Items")
#         res_col1, res_col2 = st.columns([1, 2])
#         with res_col1:
#             st.image(st.session_state.generated_design, caption="Your Generated Design", use_column_width=True)
#         with res_col2:
#             if st.session_state.similar_items:
#                 cols = st.columns(3)
#                 for i, (fname, score) in enumerate(st.session_state.similar_items):
#                     image_path = os.path.join(IMAGE_DIR, fname)
#                     if os.path.exists(image_path):
#                         cols[i % 3].image(image_path, use_column_width=True)
#             else:
#                 st.info("No similar items found in the database.")

# with tab3:
#     st.header("Experience the Future of Fitting")
#     st.info("This is a conceptual demonstration of a Virtual Try-On system using your dataset.")
#     col1, col2 = st.columns(2)
#     with col1:
#         person_img_file = st.file_uploader("1. Upload Person Image", type=['png', 'jpg', 'jpeg'], key="vton_p")
#     with col2:
#         cloth_img_file = st.file_uploader("2. Upload Clothing Item", type=['png', 'jpg', 'jpeg'], key="vton_c")
        
#     if st.button("Generate Virtual Try-On", type="primary", key="vton_gen"):
#         if person_img_file and cloth_img_file:
#             person_img = Image.open(person_img_file)
#             cloth_img = Image.open(cloth_img_file)
#             st.image(virtual_try_on_placeholder(person_img, cloth_img), caption="VTON Process Simulation")
#         else:
#             st.error("Please upload both a person and a clothing image.")

In [None]:
# # ----------------------------------------------------------------------------------
# # Cell 2: Run the Streamlit Application with ngrok using Threads
# # ----------------------------------------------------------------------------------
# import os
# import threading
# import time
# from pyngrok import ngrok
# from kaggle_secrets import UserSecretsClient

# def run_streamlit():
#     os.system("streamlit run app.py --server.headless true --server.enableCORS false --server.port 8501")

# print("Setting up ngrok tunnel...")
# try:
#     authtoken = UserSecretsClient().get_secret("NGROK_AUTH_TOKEN")
#     ngrok.set_auth_token(authtoken)
#     print("‚úÖ ngrok authtoken set successfully!")
# except Exception as e:
#     print(f"‚ö†Ô∏è ngrok authtoken not found. The tunnel will be temporary.")

# thread = threading.Thread(target=run_streamlit)
# thread.start()
# time.sleep(5) 

# public_url = ngrok.connect(8501)
# print("üöÄ Your Streamlit App is live!")
# print(f"üîó Public URL: {public_url}")

In [None]:
# Step 1: Install a Kaggle-compatible PyTorch first.
!pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118

# Step 2: Install all other packages, including the new canvas library and compatible versions of the HF ecosystem.
!pip install streamlit==1.28.2 streamlit-drawable-canvas==0.9.3 pyngrok==7.0.0 transformers==4.34.0 diffusers==0.23.1 accelerate==0.24.1 pandas==2.1.3 scikit-learn==1.3.2 opencv-python-headless==4.8.1.78 peft==0.6.2 --quiet


In [None]:
# !git clone https://github.com/levihsu/OOTDiffusion.git

# # Step 2: Install a Kaggle-compatible PyTorch first.
# !pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118

# # Step 3: Install all other packages, including OOTDiffusion's specific requirements.
# !pip install streamlit==1.28.2 streamlit-drawable-canvas==0.9.3 pyngrok==7.0.0 transformers==4.34.0 diffusers==0.23.1 accelerate==0.24.1 pandas==2.1.3 scikit-learn==1.3.2 opencv-python-headless==4.8.1.78 peft==0.6.2 ninja --quiet

# # Step 4: Download the pre-trained models for OOTDiffusion
# # This will take a significant amount of time.
# print("Downloading OOTDiffusion models...")
# # Main VTON Model
# !wget -O /kaggle/working/OOTDiffusion/models/vton/checkpoint.pth https://huggingface.co/levihsu/OOTDiffusion/resolve/main/checkpoints/vton/checkpoint.pth
# # Segmentation Model
# !wget -O /kaggle/working/OOTDiffusion/models/parsing/79999_iter.pth https://huggingface.co/levihsu/OOTDiffusion/resolve/main/checkpoints/parsing/79999_iter.pth
# # OpenPose Model
# !wget -O /kaggle/working/OOTDiffusion/models/openpose/body_pose_model.pth https://huggingface.co/levihsu/OOTDiffusion/resolve/main/checkpoints/openpose/body_pose_model.pth
# print("Model downloads complete.")

In [None]:
# %%writefile app.py

# import streamlit as st
# import torch
# import pandas as pd
# import os
# import re
# from PIL import Image, ImageDraw, ImageFont
# from collections import defaultdict
# from transformers import CLIPProcessor, CLIPModel
# from diffusers import StableDiffusionInpaintPipeline
# from sklearn.metrics.pairwise import cosine_similarity
# import torch.nn.functional as F
# from streamlit_drawable_canvas import st_canvas
# import subprocess # To run the VTON script

# # =============================================================================
# # 1. PAGE CONFIGURATION & INITIALIZATION
# # =============================================================================
# st.set_page_config(page_title="AI Fashion Stylist Pro", page_icon="ü§ñ", layout="wide")

# # Initialize session state for all the variables we need to track
# for key in ['search_results', 'generated_design', 'similar_items', 'creative_search_results', 'selected_creative_image_fname', 'vton_result_image']:
#     if key not in st.session_state:
#         st.session_state[key] = None

# # =============================================================================
# # 2. MODEL & DATA LOADING (Cached to run only once)
# # =============================================================================
# @st.cache_resource
# def load_all_models_and_data():
#     """Loads all models and data files into memory, cached for performance."""
#     print("Executing one-time resource loading...")
    
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
#     clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

#     inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
#         "runwayml/stable-diffusion-inpainting",
#         torch_dtype=torch.float16,
#     ).to(device)

#     data_dir = '/kaggle/input/clothe/clothes_tryon_dataset/train/cloth'
#     embeddings_path = "/kaggle/input/clip-embed/pytorch/default/1/clip_image_embeds (1).pt"
#     tagged_data_path = "/kaggle/input/tagged-fashion-data/tagged_fashion_data.json"

#     if not os.path.exists(embeddings_path) or not os.path.exists(tagged_data_path):
#         st.error(f"Data files not found in '{data_dir}'. Attach the 'fashion-app-data' Kaggle dataset.")
#         st.stop()
        
#     embedding_data = torch.load(embeddings_path, map_location='cpu')
#     image_embeds = embedding_data["embeds"]
#     image_filenames = embedding_data["files"]

#     tagged_df = pd.read_json(tagged_data_path, orient='records', lines=True)
#     tagged_df_indexed = tagged_df.set_index('filename')
#     print("Resource loading complete.")
#     return device, clip_model, clip_processor, inpaint_pipe, image_embeds, image_filenames, tagged_df, tagged_df_indexed

# DEVICE, clip_model, clip_processor, inpaint_pipe, image_embeds, image_filenames, tagged_df, tagged_df_indexed = load_all_models_and_data()


# # --- CRITICAL: Define all required paths for the dataset ---
# BASE_DATA_PATH = "/kaggle/input/clothe/clothes_tryon_dataset/train/"
# IMAGE_DIR = "/kaggle/input/clothe/clothes_tryon_dataset/train/cloth"
# PERSON_IMAGE_DIR = "/kaggle/input/clothe/clothes_tryon_dataset/train/image"
# CLOTH_MASK_DIR = "/kaggle/input/clothe/clothes_tryon_dataset/train/cloth-mask"
# POSE_JSON_DIR = "/kaggle/input/clothe/clothes_tryon_dataset/train/openpose_img"

# for path in [IMAGE_DIR, PERSON_IMAGE_DIR, CLOTH_MASK_DIR, POSE_JSON_DIR]:
#     if not os.path.exists(path):
#         st.error(f"Required directory not found: '{path}'. Please check your dataset paths.")
#         st.stop()

# # =============================================================================
# # 3. BACKEND FUNCTIONS
# # =============================================================================
# @st.cache_data
# def generate_dropdown_options(_df):
#     options = defaultdict(set)
#     for tag_list in _df['tags']:
#         if not isinstance(tag_list, list): continue
#         for tag in tag_list:
#             try:
#                 attribute, value = tag.split('_', 1)
#                 options[attribute].add(value.replace('_', ' '))
#             except ValueError: continue
#     return {attribute: [""] + sorted(list(values)) for attribute, values in options.items()}

# def search_by_text(query, top_k=200):
#     text_inputs = clip_processor(text=[query], return_tensors="pt").to(DEVICE)
#     with torch.no_grad():
#         text_feat = clip_model.get_text_features(**text_inputs)
#         text_feat = F.normalize(text_feat, dim=1)
#     sims = cosine_similarity(text_feat.cpu().numpy(), image_embeds.numpy())[0]
#     top_idx = sims.argsort()[::-1][:top_k]
#     return [(image_filenames[i], sims[i]) for i in top_idx]

# def search_by_image(query_image, top_k=200):
#     if query_image is None: return []
#     inputs = clip_processor(images=query_image, return_tensors="pt").to(DEVICE)
#     with torch.no_grad():
#         query_embed = clip_model.get_image_features(**inputs)
#         query_embed = F.normalize(query_embed, dim=1)
#     sims = cosine_similarity(query_embed.cpu().numpy(), image_embeds.numpy())[0]
#     top_idx = sims.argsort()[::-1][:top_k]
#     return [(image_filenames[i], sims[i]) for i in top_idx]

# def multi_modal_search(query_text, query_image, image_weight=0.5, top_k=200):
#     text_feat = F.normalize(clip_model.get_text_features(**clip_processor(text=[query_text], return_tensors="pt").to(DEVICE)))
#     img_feat = F.normalize(clip_model.get_image_features(**clip_processor(images=query_image, return_tensors="pt").to(DEVICE)))
#     combined_feat = F.normalize((image_weight * img_feat) + ((1.0 - image_weight) * text_feat))
#     sims = cosine_similarity(combined_feat.cpu().detach().numpy(), image_embeds.numpy())[0]
#     top_idx = sims.argsort()[::-1][:top_k]
#     return [(image_filenames[i], sims[i]) for i in top_idx]

# def flexible_ranked_filter(initial_results, filters):
#     if not any(filters.values()): return initial_results
#     expected_tags = {f"{attr}_{val.replace(' ', '_')}" for attr, val in filters.items() if val}
#     scored_results = []
#     for filename, original_score in initial_results:
#         try:
#             item_tags = set(tagged_df_indexed.loc[filename, 'tags'])
#             match_count = len(expected_tags.intersection(item_tags))
#             if match_count > 0:
#                 scored_results.append((filename, original_score, match_count))
#         except KeyError: continue
#     sorted_results = sorted(scored_results, key=lambda x: (x[2], x[1]), reverse=True)
#     return [(filename, score) for filename, score, _ in sorted_results]

# def enhance_prompt(base_prompt):
#     enhanced_positive = f"masterpiece, photorealistic, high-quality professional product shot of a {base_prompt}, intricate, hyper-detailed, sharp focus, cinematic lighting, 8k uhd, on a mannequin, clean white background"
#     negative = "lowres, blurry, bad anatomy, error, worst quality, jpeg artifacts, ugly, duplicate, morbid, out of frame, watermark, text, signature, person, model"
#     return enhanced_positive, negative

# def create_edited_design_enhanced(base_image, mask_image, user_prompt):
#     enhanced_prompt, negative_prompt = enhance_prompt(user_prompt)
#     with torch.no_grad():
#         edited_image = inpaint_pipe(
#             prompt=enhanced_prompt, image=base_image.resize((512, 512)), 
#             mask_image=mask_image.resize((512, 512)),
#             negative_prompt=negative_prompt, num_inference_steps=50, guidance_scale=8.5
#         ).images[0]
#     return edited_image

# # --- The REAL VTON Backend Function ---
# def run_virtual_tryon(person_fname, cloth_fname):
#     """Constructs and runs the OOTDiffusion command with the correct working directory."""
#     person_img_src = os.path.join(PERSON_IMAGE_DIR, person_fname)
#     cloth_img_src = os.path.join(IMAGE_DIR, cloth_fname)
    
#     session_id = f"{os.path.splitext(person_fname)[0]}_{os.path.splitext(cloth_fname)[0]}"
#     output_dir = os.path.join("/kaggle/working/OOTDiffusion/results/", session_id)
#     os.makedirs(output_dir, exist_ok=True)
    
#     # --- KEY CHANGE 1: The script path is now relative ---
#     # We are running 'python' on 'run_oot.py' from *within* its own directory.
#     command = [
#         "python", "run_oot.py",
#         "--model_path", person_img_src,
#         "--cloth_path", cloth_img_src,
#         "--model_type", "hd",
#         "--category", "upperbody",
#         "--output_dir", output_dir
#     ]
    
#     # --- KEY CHANGE 2: Set the Current Working Directory (cwd) ---
#     # This tells the subprocess to execute as if it were in the OOTDiffusion folder.
#     ootd_working_dir = "/kaggle/working/OOTDiffusion/"
    
#     try:
#         process = subprocess.run(command, check=True, capture_output=True, text=True, cwd=ootd_working_dir)
#         print("OOTDiffusion STDOUT:", process.stdout)
#         print("OOTDiffusion STDERR:", process.stderr)

#         expected_output_fname = f"{os.path.splitext(os.path.basename(person_img_src))[0]}_{os.path.splitext(os.path.basename(cloth_img_src))[0]}.png"
#         output_path = os.path.join(output_dir, expected_output_fname)

#         if os.path.exists(output_path):
#             return output_path
#         else:
#             for file in os.listdir(output_dir):
#                 if file.endswith(".png"): return os.path.join(output_dir, file)
#             return None
            
#     except subprocess.CalledProcessError as e:
#         print(f"OOTDiffusion script failed. Error: {e.stderr}")
#         st.error(f"Virtual Try-On failed. Error: {e.stderr}")
#         return None
# # =============================================================================
# # 4. STREAMLIT UI LAYOUT
# # =============================================================================

# st.title("ü§ñ AI Fashion Stylist Pro")
# st.markdown("Discover, create, and virtually try on your next favorite outfit.")

# tab1, tab2, tab3 = st.tabs(["üîé Smart Search & Recommendation", "üé® Creative Director", "üë§ Virtual Try-On Hub"])

# with tab1:
#     with st.sidebar:
#         st.header("Search & Filter Controls")
#         text_query = st.text_input("Text Description", placeholder="e.g., a blue floral blouse")
#         image_query_file = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])
#         image_weight = 0.5
#         if text_query and image_query_file:
#             image_weight = st.slider("Image vs. Text Influence", 0.0, 1.0, 0.5, 0.1)
#         st.markdown("---")
#         st.subheader("Smart Tags")
#         dropdowns = generate_dropdown_options(tagged_df)
#         filters = {}
#         for attr, options in dropdowns.items():
#             filters[attr] = st.selectbox(attr.capitalize(), options)
#         search_button = st.button("Search & Filter", type="primary")

#     st.header("Your Personalized Recommendations")
#     if search_button:
#         with st.spinner("Finding your style..."):
#             initial_results = []
#             image_query = Image.open(image_query_file) if image_query_file else None
#             if text_query and image_query: initial_results = multi_modal_search(text_query, image_query, image_weight)
#             elif image_query: initial_results = search_by_image(image_query)
#             elif text_query: initial_results = search_by_text(text_query)
#             st.session_state.search_results = flexible_ranked_filter(initial_results, filters)

#     if not st.session_state.search_results and search_button:
#         st.warning("No items matched your query or filters. Please broaden your search.")
#     elif not st.session_state.search_results:
#         st.info("Enter a query in the sidebar and click 'Search' to see recommendations.")
#     else:
#         st.success(f"Found {len(st.session_state.search_results)} matching items. Showing top results.")
#         cols = st.columns(6)
#         for i, (fname, score) in enumerate(st.session_state.search_results[:12]):
#             image_path = os.path.join(IMAGE_DIR, fname)
#             if os.path.exists(image_path):
#                 cols[i % 6].image(image_path, caption=f"Score: {score:.2f}", use_column_width=True)

# with tab2:
#     st.header("Unleash Your Inner Designer")
#     st.subheader("1. Find a Base Item to Edit")
#     creative_search_term = st.text_input("Search for a clothing item", key="creative_search")
#     if st.button("Search", key="creative_search_btn"):
#         if creative_search_term:
#             with st.spinner("Searching..."):
#                 st.session_state.creative_search_results = search_by_text(creative_search_term, top_k=6)
#         else:
#             st.session_state.creative_search_results = None

#     if st.session_state.creative_search_results:
#         st.markdown("---")
#         st.write("Click 'Select' on an item to start editing:")
#         cols = st.columns(6)
#         for i, (fname, score) in enumerate(st.session_state.creative_search_results):
#             image_path = os.path.join(IMAGE_DIR, fname)
#             if os.path.exists(image_path):
#                 cols[i].image(image_path, use_column_width=True)
#                 if cols[i].button("Select", key=f"select_{fname}"):
#                     st.session_state.selected_creative_image_fname = fname
#                     st.session_state.generated_design = None
#                     st.session_state.similar_items = None
#                     st.experimental_rerun()
    
#     st.markdown("---")
#     if st.session_state.selected_creative_image_fname:
#         st.subheader("2. Paint Over the Area to Change & Describe Your Idea")
#         base_image_path = os.path.join(IMAGE_DIR, st.session_state.selected_creative_image_fname)
#         base_img = Image.open(base_image_path).convert("RGB")
#         col1, col2 = st.columns(2)
#         with col1:
#             st.write("Use the thick brush to paint the area you want to change.")
#             canvas_result = st_canvas(
#                 fill_color="rgba(255, 255, 255, 0)", stroke_width=35,
#                 stroke_color="rgba(255, 0, 0, 0.5)", background_image=base_img.resize((512, 512)),
#                 update_streamlit=True, height=512, width=512, drawing_mode="freedraw", key="canvas",
#             )
#         with col2:
#             prompt = st.text_input("What should the new design be?", placeholder="e.g., a roaring tiger head")
#             generate_button = st.button("Generate & Find Similar", type="primary", key="generate")
#             if generate_button:
#                 if canvas_result.image_data is not None and prompt:
#                     mask_array = canvas_result.image_data[:, :, 3] > 0
#                     if mask_array.sum() > 0:
#                         mask_img = Image.fromarray(mask_array.astype('uint8') * 255)
#                         with st.spinner("Creating your masterpiece..."):
#                             st.session_state.generated_design = create_edited_design_enhanced(base_img, mask_img, prompt)
#                             st.session_state.similar_items = search_by_image(st.session_state.generated_design, top_k=6)
#                     else: st.warning("Please paint on the image to indicate the area to change.")
#                 else: st.error("Please paint on the image and provide a text prompt.")

#     if st.session_state.generated_design:
#         st.markdown("---"); st.subheader("3. Your Unique Creation & Similar Real Items")
#         res_col1, res_col2 = st.columns([1, 2])
#         with res_col1: st.image(st.session_state.generated_design, caption="Your Generated Design", use_column_width=True)
#         with res_col2:
#             if st.session_state.similar_items:
#                 cols = st.columns(3)
#                 for i, (fname, score) in enumerate(st.session_state.similar_items):
#                     image_path = os.path.join(IMAGE_DIR, fname)
#                     if os.path.exists(image_path): cols[i % 3].image(image_path, use_column_width=True)
#             else: st.info("No similar items found in the database.")


# # ### CORRECTED VIRTUAL TRY-ON TAB ###
# with tab3:
#     st.header("‚ú® The Magic Mirror: Virtual Try-On")
#     st.info("Select a person and a clothing item from the dataset to see the try-on result. Generation can take 1-2 minutes.")

#     # Get a sample of available person and clothing images for the dropdowns
#     person_files = sorted(os.listdir(PERSON_IMAGE_DIR))[:200]
#     cloth_files = sorted(os.listdir(IMAGE_DIR))[:200]

#     col1, col2 = st.columns(2)
#     with col1:
#         selected_person = st.selectbox("1. Choose a Person Model", person_files)
#         if selected_person:
#             st.image(os.path.join(PERSON_IMAGE_DIR, selected_person), use_column_width=True)
    
#     with col2:
#         selected_cloth = st.selectbox("2. Choose a Clothing Item", cloth_files)
#         if selected_cloth:
#             st.image(os.path.join(IMAGE_DIR, selected_cloth), use_column_width=True)
            
#     if st.button("Generate Virtual Try-On", type="primary", key="vton_gen"):
#         if selected_person and selected_cloth:
#             with st.spinner("Warming up the Magic Mirror... This will take a moment."):
#                 # OOTDiffusion automatically finds the corresponding mask and pose files based on filename conventions
#                 # So we just need to run the main function
#                 result_path = run_virtual_tryon(selected_person, selected_cloth)
                
#                 if result_path and os.path.exists(result_path):
#                     st.session_state.vton_result_image = Image.open(result_path)
#                 else:
#                     st.session_state.vton_result_image = None
#                     st.error("VTON generation failed or output file not found. Check notebook logs.")
#         else:
#             st.error("Please select both a person and a clothing item.")
            
#     if st.session_state.vton_result_image:
#         st.markdown("---")
#         st.subheader("üéâ Your Virtual Try-On Result")
#         st.image(st.session_state.vton_result_image, use_column_width=True)

In [None]:
# # ----------------------------------------------------------------------------------
# # Cell 2: Run the Streamlit Application with ngrok using Threads
# # ----------------------------------------------------------------------------------
# import os
# import threading
# import time
# from pyngrok import ngrok
# from kaggle_secrets import UserSecretsClient

# # --- Function to run Streamlit in a thread ---
# def run_streamlit():
#     """This function runs the streamlit command in the shell."""
#     os.system("streamlit run app.py --server.headless true --server.enableCORS false --server.port 8501")

# # --- Setup ngrok ---
# print("Setting up ngrok tunnel...")
# # Authenticate ngrok using the secret you added
# try:
#     authtoken = UserSecretsClient().get_secret("NGROK_AUTH_TOKEN")
#     ngrok.set_auth_token(authtoken)
#     print("‚úÖ ngrok authtoken set successfully!")
# except Exception as e:
#     print(f"‚ö†Ô∏è ngrok authtoken not found in Kaggle Secrets. The tunnel will be temporary. Error: {e}")
#     print("Create a free ngrok account and add your authtoken as a secret named NGROK_AUTH_TOKEN for longer sessions.")

# # --- Start Streamlit in a separate thread ---
# thread = threading.Thread(target=run_streamlit)
# thread.start()

# # Give the Streamlit server a moment to start up
# time.sleep(5) 

# # --- Open the ngrok tunnel to the Streamlit port (8501) ---
# public_url = ngrok.connect(8501)
# print("üöÄ Your Streamlit App is live!")
# print(f"üîó Public URL: {public_url}")

In [None]:
# ----------------------------------------------------------------------------------
# Cell 1: Install All Dependencies and Write the Final Streamlit App File
# ----------------------------------------------------------------------------------

# Step 1: Install a Kaggle-compatible PyTorch first.
!pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118

# Step 2: Install all other packages with compatible versions.
!pip install streamlit==1.28.2 streamlit-drawable-canvas==0.9.3 pyngrok==7.0.0 transformers==4.34.0 diffusers==0.23.1 accelerate==0.24.1 pandas==2.1.3 scikit-learn==1.3.2 opencv-python-headless==4.8.1.78 peft==0.6.2 ninja --quiet

# # Step 3: Clone the OOTDiffusion repository if it doesn't exist.
# !git clone https://github.com/levihsu/OOTDiffusion.git /kaggle/working/OOTDiffusion

# # Step 4: Download the pre-trained models for OOTDiffusion.
# print("Downloading OOTDiffusion models...")
# !wget -nc -O /kaggle/working/OOTDiffusion/models/vton/checkpoint.pth https://huggingface.co/levihsu/OOTDiffusion/resolve/main/checkpoints/vton/checkpoint.pth
# !wget -nc -O /kaggle/working/OOTDiffusion/models/parsing/79999_iter.pth https://huggingface.co/levihsu/OOTDiffusion/resolve/main/checkpoints/parsing/79999_iter.pth
# !wget -nc -O /kaggle/working/OOTDiffusion/models/openpose/body_pose_model.pth https://huggingface.co/levihsu/OOTDiffusion/resolve/main/checkpoints/openpose/body_pose_model.pth
# print("Model downloads complete.")

In [None]:
%%writefile app.py

import streamlit as st
import torch
import pandas as pd
import os
import re
from PIL import Image, ImageDraw, ImageFont
from collections import defaultdict
from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration
from diffusers import StableDiffusionInpaintPipeline
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn.functional as F
from streamlit_drawable_canvas import st_canvas
import subprocess
import numpy as np

# =============================================================================
# 1. PAGE CONFIGURATION & INITIALIZATION
# =============================================================================
st.set_page_config(
    page_title="AI Fashion Stylist Pro",
    page_icon="ü§ñ",
    layout="wide"
)

# Initialize session state
for key in ['search_results', 'generated_design', 'similar_items', 'creative_search_results', 
            'selected_creative_image_fname', 'vton_result_image', 'sketch_results', 
            'sketch_caption', 'uploaded_sketch', 'cart', 'items_to_show', 'search_active']:
    if key not in st.session_state:
        st.session_state[key] = [] if key == 'cart' else False if key == 'search_active' else 12 if key == 'items_to_show' else None

# =============================================================================
# 2. MODEL & DATA LOADING
# =============================================================================
@st.cache_resource
def load_all_models_and_data():
    print("Executing one-time resource loading...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
    inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16).to(device)
    
    data_dir = '/kaggle/input/clothe/clothes_tryon_dataset/train/cloth'
    embeddings_path = "/kaggle/input/clip-embed/pytorch/default/1/clip_image_embeds (1).pt"
    tagged_data_path = "/kaggle/input/tagged-fashion-data/tagged_fashion_data.json"
    
    embedding_data = torch.load(embeddings_path, map_location='cpu')
    image_embeds = embedding_data["embeds"].numpy()
    image_filenames = embedding_data["files"]
    tagged_df = pd.read_json(tagged_data_path, orient='records', lines=True)
    tagged_df_indexed = tagged_df.set_index('filename')
    
    print("Resource loading complete.")
    return device, clip_model, clip_processor, blip_model, blip_processor, inpaint_pipe, image_embeds, image_filenames, tagged_df, tagged_df_indexed

DEVICE, clip_model, clip_processor, blip_model, blip_processor, inpaint_pipe, image_embeds, image_filenames, tagged_df, tagged_df_indexed = load_all_models_and_data()

BASE_DATA_PATH = "/kaggle/input/clothe/clothes_tryon_dataset/train/"
IMAGE_DIR = os.path.join(BASE_DATA_PATH, "cloth/")
PERSON_IMAGE_DIR = os.path.join(BASE_DATA_PATH, "image/")

# =============================================================================
# 3. BACKEND FUNCTIONS
# =============================================================================
@st.cache_data
def generate_dropdown_options(_df):
    options = defaultdict(set)
    for attr in ['brand', 'color', 'pattern', 'style', 'neckline']:
        options[attr].add("Any")
    for tag_list in _df['tags']:
        if not isinstance(tag_list, list): continue
        for tag in tag_list:
            try:
                attribute, value = tag.split('_', 1)
                options[attribute].add(value.replace('_', ' '))
            except ValueError: continue
    return {attribute: sorted(list(values)) for attribute, values in options.items()}

def search_by_text(query, top_k=500):
    text_inputs = clip_processor(text=[query], return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        text_feat = clip_model.get_text_features(**text_inputs)
        text_feat = F.normalize(text_feat, dim=1)
    sims = cosine_similarity(text_feat.cpu().numpy(), image_embeds)[0]
    top_idx = np.argsort(-sims)[:top_k]
    return [(image_filenames[i], sims[i]) for i in top_idx]

def search_by_image(query_image, top_k=500):
    if query_image is None: return []
    inputs = clip_processor(images=query_image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        query_embed = clip_model.get_image_features(**inputs)
        query_embed = F.normalize(query_embed, dim=1)
    sims = cosine_similarity(query_embed.cpu().numpy(), image_embeds)[0]
    top_idx = np.argsort(-sims)[:top_k]
    return [(image_filenames[i], sims[i]) for i in top_idx]

def multi_modal_search(query_text, query_image, image_weight=0.5, top_k=500):
    text_feat = F.normalize(clip_model.get_text_features(**clip_processor(text=[query_text], return_tensors="pt").to(DEVICE)))
    img_feat = F.normalize(clip_model.get_image_features(**clip_processor(images=query_image, return_tensors="pt").to(DEVICE)))
    combined_feat = F.normalize((image_weight * img_feat) + ((1.0 - image_weight) * text_feat))
    sims = cosine_similarity(combined_feat.cpu().detach().numpy(), image_embeds)[0]
    top_idx = np.argsort(-sims)[:top_k]
    return [(image_filenames[i], sims[i]) for i in top_idx]

# --- CORRECTED flexible_ranked_filter ---
def flexible_ranked_filter(initial_results, filters):
    # If no filters are selected (all are "Any"), just return the initial results
    if not filters: 
        return initial_results
        
    expected_tags = {f"{attr}_{val.replace(' ', '_')}" for attr, val in filters.items()}
    scored_results = []
    
    for filename, original_score in initial_results:
        try:
            item_tags = set(tagged_df_indexed.loc[filename, 'tags'])
            match_count = len(expected_tags.intersection(item_tags))
            if match_count > 0:
                scored_results.append((filename, original_score, match_count))
        except KeyError: 
            continue
            
    sorted_results = sorted(scored_results, key=lambda x: (x[2], x[1]), reverse=True)
    return [(filename, score) for filename, score, _ in sorted_results]

def add_to_cart(item_filename):
    if item_filename not in st.session_state.cart:
        st.session_state.cart.append(item_filename)

def remove_from_cart(item_filename):
    if item_filename in st.session_state.cart:
        st.session_state.cart.remove(item_filename)

def clear_cart():
    st.session_state.cart = []

def load_more_items():
    st.session_state.items_to_show += 12

def handle_search(text_query, image_query_file, image_weight, filters):
    st.session_state.search_active = True
    initial_results = []
    image_query = Image.open(image_query_file) if image_query_file else None

    # Determine the initial set of candidates based on the search query
    if text_query and image_query:
        initial_results = multi_modal_search(text_query, image_query, image_weight)
    elif image_query:
        initial_results = search_by_image(image_query)
    elif text_query:
        initial_results = search_by_text(text_query)
    else:
        # If no search query, the initial pool is all items, allowing for filtering the entire catalog
        initial_results = [(fname, 1.0) for fname in image_filenames]
    
    # Apply the flexible ranked filter to the initial results
    st.session_state.search_results = flexible_ranked_filter(initial_results, filters)

def clear_search():
    st.session_state.search_active = False
    st.session_state.search_results = []
    st.session_state.items_to_show = 12

# (Other backend functions for other tabs remain the same)
# ...

# =============================================================================
# 4. STREAMLIT UI LAYOUT
# =============================================================================

st.title("ü§ñ AI Fashion Stylist Pro")
st.markdown("Discover, create, and virtually try on your next favorite outfit.")

tab1, tab2, tab3, tab4 = st.tabs(["üõçÔ∏è Browse & Search", "üé® Creative Director", "üë§ Virtual Try-On", "‚úçÔ∏è Sketch to Fashion"])

with tab1:
    with st.sidebar:
        st.header("Search & Filter")
        text_query = st.text_input("Text Description", placeholder="e.g., a blue floral blouse")
        image_query_file = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'], key="main_uploader")
        image_weight = 0.5
        if text_query and image_query_file:
            image_weight = st.slider("Image vs. Text Influence", 0.0, 1.0, 0.5, 0.1)
        
        st.subheader("Smart Tags")
        dropdowns = generate_dropdown_options(tagged_df)
        filters = {}
        for attr, options in dropdowns.items():
            selected = st.selectbox(attr.capitalize(), options, index=0)
            # Only add the filter if a specific option (not "Any") is chosen
            if selected != "Any":
                filters[attr] = selected
        
        col1, col2 = st.columns(2)
        with col1:
            st.button("Search / Filter", on_click=handle_search, args=(text_query, image_query_file, image_weight, filters), type="primary")
        with col2:
            st.button("Browse All", on_click=clear_search)

        st.markdown("---")
        st.header("Shopping Cart")
        if not st.session_state.cart:
            st.info("Your cart is empty.")
        else:
            for item_fname in st.session_state.cart:
                cart_col1, cart_col2 = st.columns([1, 3])
                with cart_col1:
                    st.image(os.path.join(IMAGE_DIR, item_fname), use_column_width=True)
                with cart_col2:
                    st.write(f"**{item_fname}**")
                    st.button("Remove", key=f"remove_{item_fname}", on_click=remove_from_cart, args=(item_fname,))
            st.markdown("---")
            if st.button("Clear Cart"):
                clear_cart()
                st.experimental_rerun()

    # --- Main Display Area Logic ---
    if st.session_state.search_active:
        st.header("Search Results")
        display_items = st.session_state.search_results
        if not display_items:
            st.warning("Your search returned no results. Try broadening your query or click 'Browse All'.")
    else:
        st.header("Browse Our Collection")
        display_items = [(fname, None) for fname in image_filenames]
    
    if display_items:
        items_to_display_now = display_items[:st.session_state.items_to_show]
        num_cols = 6
        cols = st.columns(num_cols)
        for i, (fname, score) in enumerate(items_to_display_now):
            image_path = os.path.join(IMAGE_DIR, fname)
            if os.path.exists(image_path):
                with cols[i % num_cols]:
                    st.image(image_path, use_column_width=True)
                    st.button("Add to Cart", key=f"add_{fname}", on_click=add_to_cart, args=(fname,))
        
        if not st.session_state.search_active and st.session_state.items_to_show < len(display_items):
            st.button("Load More", on_click=load_more_items, use_container_width=True)



with tab2:
    st.header("Unleash Your Inner Designer")
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("1. Upload Base Item & Mask")
        base_img_file = st.file_uploader("Base Clothing Image", type=['png', 'jpg', 'jpeg'], key="base")
        mask_img_file = st.file_uploader("Mask Image (White area is replaced)", type=['png', 'jpg', 'jpeg'], key="mask")
    with col2:
        st.subheader("2. Describe Your Idea")
        prompt = st.text_input("Prompt", placeholder="e.g., a roaring tiger head")
        generate_button = st.button("Generate & Find Similar", type="primary", key="generate")

    if generate_button:
        if base_img_file and mask_img_file and prompt:
            with st.spinner("Creating your masterpiece... This can take up to a minute."):
                base_img = Image.open(base_img_file).convert("RGB")
                mask_img = Image.open(mask_img_file).convert("RGB")
                st.session_state.generated_design = create_edited_design_enhanced(base_img, mask_img, prompt)
                st.session_state.similar_items = search_by_image(st.session_state.generated_design, top_k=6)
        else:
            st.error("Please provide a base image, a mask image, and a text prompt.")

    if st.session_state.generated_design:
        st.markdown("---")
        st.subheader("3. Your Unique Creation & Similar Real Items")
        res_col1, res_col2 = st.columns([1, 2])
        with res_col1:
            st.image(st.session_state.generated_design, caption="Your Generated Design", use_column_width=True)
        with res_col2:
            if st.session_state.similar_items:
                cols = st.columns(3)
                for i, (fname, score) in enumerate(st.session_state.similar_items):
                    image_path = os.path.join(IMAGE_DIR, fname)
                    if os.path.exists(image_path):
                        cols[i % 3].image(image_path, use_column_width=True)
            else:
                st.info("No similar items found in the database.")

with tab3:
    st.header("Experience the Future of Fitting")
    st.info("This is a conceptual demonstration of a Virtual Try-On system using your dataset.")
    col1, col2 = st.columns(2)
    with col1:
        person_img_file = st.file_uploader("1. Upload Person Image", type=['png', 'jpg', 'jpeg'], key="vton_p")
    with col2:
        cloth_img_file = st.file_uploader("2. Upload Clothing Item", type=['png', 'jpg', 'jpeg'], key="vton_c")
        
    if st.button("Generate Virtual Try-On", type="primary", key="vton_gen"):
        if person_img_file and cloth_img_file:
            person_img = Image.open(person_img_file)
            cloth_img = Image.open(cloth_img_file)
            st.image(virtual_try_on_placeholder(person_img, cloth_img), caption="VTON Process Simulation")
        else:
            st.error("Please upload both a person and a clothing image.")

with tab4:
    st.header("Sketch-Based Discovery üñäÔ∏è")
    sketch_file = st.file_uploader("Upload Sketch Image", type=['png', 'jpg', 'jpeg'], key="sketch_uploader")
    
    if sketch_file:
        sketch = Image.open(sketch_file).convert("RGB").resize((224, 224))
        st.image(sketch, caption="Uploaded Sketch",use_column_width=False)
        
        if st.button("Show Recommendations", type="primary"):
            with st.spinner("üß† Describing your sketch and searching..."):
                caption = describe_sketch(sketch)
                sketch_embed = hybrid_clip_embedding(sketch, caption)
                sims = image_embeds.numpy() @ sketch_embed
                topk = np.argsort(-sims)[:6]

                cols = st.columns(6)
                for i, idx in enumerate(topk):
                    img_path = os.path.join(IMAGE_DIR, image_filenames[idx])
                    if os.path.exists(img_path):
                        cols[i].image(img_path, caption=f"Score: {sims[idx]:.2f}", use_column_width=True)

In [4]:
# ----------------------------------------------------------------------------------
# Cell 2: Run the Streamlit Application with ngrok using Threads
# ----------------------------------------------------------------------------------
import os
import threading
import time
from pyngrok import ngrok
from kaggle_secrets import UserSecretsClient

def run_streamlit():
    os.system("streamlit run app.py --server.headless true --server.enableCORS false --server.port 8501")

print("Setting up ngrok tunnel...")
try:
    authtoken = UserSecretsClient().get_secret("NGROK_AUTH_TOKEN")
    ngrok.set_auth_token(authtoken)
    print("‚úÖ ngrok authtoken set successfully!")
except Exception as e:
    print(f"‚ö†Ô∏è ngrok authtoken not found. The tunnel will be temporary.")

for tunnel in ngrok.get_tunnels():
    ngrok.disconnect(tunnel.public_url)
    print(f"üîå Disconnected old tunnel: {tunnel.public_url}")

thread = threading.Thread(target=run_streamlit)
thread.start()
time.sleep(5) 

public_url = ngrok.connect(8501)
print("üöÄ Your Streamlit App is live!")
print(f"üîó Public URL: {public_url}")

Setting up ngrok tunnel...
‚úÖ ngrok authtoken set successfully!
üîå Disconnected old tunnel: https://3d240b3dcfb2.ngrok-free.app


2025-07-21 16:59:39.073 
As a result, 'server.enableCORS' is being overridden to 'true'.

More information:
In order to protect against CSRF attacks, we send a cookie with each request.
To do so, we must specify allowable origins, which places a restriction on
cross-origin resource sharing.

If cross origin resource sharing is required, please disable server.enableXsrfProtection.
            
2025-07-21 16:59:39.176 Port 8501 is already in use



Collecting usage statistics. To deactivate, set browser.gatherUsageStats to False.

üöÄ Your Streamlit App is live!
üîó Public URL: NgrokTunnel: "https://08207d102df3.ngrok-free.app" -> "http://localhost:8501"


2025-07-21 17:01:47.974 Please replace `st.experimental_rerun` with `st.rerun`.

`st.experimental_rerun` will be removed after 2024-04-01.
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 50/50 [00:10<00:00,  4.97it/s]
2025-07-21 17:04:37.814 Please replace `st.experimental_rerun` with `st.rerun`.

`st.experimental_rerun` will be removed after 2024-04-01.
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 50/50 [00:10<00:00,  4.96it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 50/50 [00:09<00:00,  5.03it/s]


In [3]:
%%writefile app.py

import streamlit as st
import torch
import pandas as pd
import os
import re
from PIL import Image, ImageDraw, ImageFont
from collections import defaultdict
from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration
from diffusers import StableDiffusionInpaintPipeline
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn.functional as F
from streamlit_drawable_canvas import st_canvas
import subprocess
import numpy as np

# =============================================================================
# 1. PAGE CONFIGURATION & INITIALIZATION
# =============================================================================
st.set_page_config(
    page_title="AI Fashion Stylist Pro",
    page_icon="ü§ñ",
    layout="wide"
)

# Initialize session state for all the variables we need to track
for key in ['search_results', 'generated_design', 'similar_items', 'creative_search_results', 
            'selected_creative_image_fname', 'vton_result_image', 'sketch_results', 
            'sketch_caption', 'uploaded_sketch', 'cart', 'items_to_show', 'search_active']:
    if key not in st.session_state:
        st.session_state[key] = [] if key == 'cart' else False if key == 'search_active' else 12 if key == 'items_to_show' else None

# =============================================================================
# 2. MODEL & DATA LOADING (Cached to run only once)
# =============================================================================
@st.cache_resource
def load_all_models_and_data():
    """Loads all models and data files into memory, cached for performance."""
    print("Executing one-time resource loading...")
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

    inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.float16,
    ).to(device)

    # Use the correct, verified paths from our previous debugging sessions
    data_dir = '/kaggle/input/clothe/clothes_tryon_dataset/train/cloth'
    embeddings_path = "/kaggle/input/clip-embed/pytorch/default/1/clip_image_embeds (1).pt"
    tagged_data_path = "/kaggle/input/tagged-fashion-data/tagged_fashion_data.json"

    if not os.path.exists(embeddings_path) or not os.path.exists(tagged_data_path):
        st.error(f"CRITICAL ERROR: Data files not found. Ensure you have attached the 'fashion-app-data' Kaggle dataset.")
        st.stop()
        
    embedding_data = torch.load(embeddings_path, map_location='cpu')
    image_embeds = embedding_data["embeds"].numpy()
    image_filenames = embedding_data["files"]

    tagged_df = pd.read_json(tagged_data_path, orient='records', lines=True)
    tagged_df_indexed = tagged_df.set_index('filename')

    print("Resource loading complete.")
    return device, clip_model, clip_processor, inpaint_pipe, image_embeds, image_filenames, tagged_df, tagged_df_indexed, blip_model, blip_processor

DEVICE, clip_model, clip_processor, inpaint_pipe, image_embeds, image_filenames, tagged_df, tagged_df_indexed, blip_model, blip_processor = load_all_models_and_data()

# CRITICAL PATHS - Ensure these match your Kaggle dataset structure
BASE_DATA_PATH = "/kaggle/input/clothe/clothes_tryon_dataset/train/cloth"
IMAGE_DIR = "/kaggle/input/clothe/clothes_tryon_dataset/train/cloth"
PERSON_IMAGE_DIR = "/kaggle/input/clothe/clothes_tryon_dataset/train/image"

# =============================================================================
# 3. BACKEND FUNCTIONS
# =============================================================================
@st.cache_data
def generate_dropdown_options(_df):
    options = defaultdict(set)
    for attr in ['brand', 'color', 'pattern', 'style', 'neckline']:
        options[attr].add("Any")
    for tag_list in _df['tags']:
        if not isinstance(tag_list, list): continue
        for tag in tag_list:
            try:
                attribute, value = tag.split('_', 1)
                options[attribute].add(value.replace('_', ' '))
            except ValueError: continue
    return {attribute: sorted(list(values)) for attribute, values in options.items()}

def search_by_text(query, top_k=500):
    text_inputs = clip_processor(text=[query], return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        text_feat = clip_model.get_text_features(**text_inputs)
        text_feat = F.normalize(text_feat, dim=1)
    sims = cosine_similarity(text_feat.cpu().numpy(), image_embeds)[0]
    top_idx = np.argsort(-sims)[:top_k]
    return [(image_filenames[i], sims[i]) for i in top_idx]

def search_by_image(query_image, top_k=500):
    if query_image is None: return []
    inputs = clip_processor(images=query_image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        query_embed = clip_model.get_image_features(**inputs)
        query_embed = F.normalize(query_embed, dim=1)
    sims = cosine_similarity(query_embed.cpu().numpy(), image_embeds)[0]
    top_idx = np.argsort(-sims)[:top_k]
    return [(image_filenames[i], sims[i]) for i in top_idx]

def multi_modal_search(query_text, query_image, image_weight=0.5, top_k=500):
    text_feat = F.normalize(clip_model.get_text_features(**clip_processor(text=[query_text], return_tensors="pt").to(DEVICE)))
    img_feat = F.normalize(clip_model.get_image_features(**clip_processor(images=query_image, return_tensors="pt").to(DEVICE)))
    combined_feat = F.normalize((image_weight * img_feat) + ((1.0 - image_weight) * text_feat))
    sims = cosine_similarity(combined_feat.cpu().detach().numpy(), image_embeds)[0]
    top_idx = np.argsort(-sims)[:top_k]
    return [(image_filenames[i], sims[i]) for i in top_idx]

def flexible_ranked_filter(initial_results, filters):
    if not filters: return initial_results
    expected_tags = {f"{attr}_{val.replace(' ', '_')}" for attr, val in filters.items()}
    scored_results = []
    for filename, original_score in initial_results:
        try:
            item_tags = set(tagged_df_indexed.loc[filename, 'tags'])
            match_count = len(expected_tags.intersection(item_tags))
            if match_count > 0:
                scored_results.append((filename, original_score, match_count))
        except KeyError: continue
    sorted_results = sorted(scored_results, key=lambda x: (x[2], x[1]), reverse=True)
    return [(filename, score) for filename, score, _ in sorted_results]

def add_to_cart(item_filename):
    if item_filename not in st.session_state.cart:
        st.session_state.cart.append(item_filename)

def remove_from_cart(item_filename):
    if item_filename in st.session_state.cart:
        st.session_state.cart.remove(item_filename)

def clear_cart():
    st.session_state.cart = []

def load_more_items():
    st.session_state.items_to_show += 12

def handle_search(text_query, image_query_file, image_weight, filters):
    st.session_state.search_active = True
    initial_results = []
    image_query = Image.open(image_query_file) if image_query_file else None
    if text_query and image_query: initial_results = multi_modal_search(text_query, image_query, image_weight)
    elif image_query: initial_results = search_by_image(image_query)
    elif text_query: initial_results = search_by_text(text_query)
    else: initial_results = [(fname, 1.0) for fname in image_filenames]
    st.session_state.search_results = flexible_ranked_filter(initial_results, filters)

def clear_search():
    st.session_state.search_active = False
    st.session_state.search_results = []
    st.session_state.items_to_show = 12

def describe_sketch(img: Image.Image) -> str:
    inputs = blip_processor(images=img, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=50)
    return blip_processor.decode(out[0], skip_special_tokens=True)

def hybrid_clip_embedding(image: Image.Image, text: str):
    inputs = clip_processor(text=[text], images=image, return_tensors="pt", padding=True).to(DEVICE)
    with torch.no_grad():
        img_feat  = clip_model.get_image_features(pixel_values=inputs["pixel_values"])
        text_feat = clip_model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    img_feat  = F.normalize(img_feat, dim=-1); text_feat = F.normalize(text_feat, dim=-1)
    hybrid_feat = (img_feat + text_feat) / 2
    return hybrid_feat.cpu().numpy()[0]

def find_top_k(query_emb, dataset_embs, k=5):
    sims = dataset_embs @ query_emb
    idx = np.argsort(-sims)[:k]
    return idx, sims[idx]

def enhance_prompt(base_prompt):
    enhanced_positive = f"masterpiece, photorealistic, high-quality professional product shot of a {base_prompt}, intricate, hyper-detailed, sharp focus, cinematic lighting, 8k uhd, on a mannequin, clean white background"
    negative = "lowres, blurry, bad anatomy, error, worst quality, jpeg artifacts, ugly, duplicate, morbid, out of frame, watermark, text, signature, person, model"
    return enhanced_positive, negative

def create_edited_design_enhanced(base_image, mask_image, user_prompt):
    enhanced_prompt, negative_prompt = enhance_prompt(user_prompt)
    with torch.no_grad():
        edited_image = inpaint_pipe(
            prompt=enhanced_prompt, image=base_image.resize((512, 512)), 
            mask_image=mask_image.resize((512, 512)),
            negative_prompt=negative_prompt, num_inference_steps=50, guidance_scale=8.5
        ).images[0]
    return edited_image

def virtual_try_on_placeholder():
    # Placeholder function, VTON is not implemented in this version
    pass

# =============================================================================
# 4. STREAMLIT UI LAYOUT
# =============================================================================

st.title("ü§ñ AI Fashion Stylist Pro")
st.markdown("Discover, create, and virtually try on your next favorite outfit.")

tab1, tab2, tab3, tab4 = st.tabs(["üõçÔ∏è Browse & Search", "üé® Creative Director", "üë§ Virtual Try-On", "‚úçÔ∏è Sketch to Fashion"])

with tab1:
    with st.sidebar:
        st.header("Search & Filter")
        text_query = st.text_input("Text Description", placeholder="e.g., a blue floral blouse")
        image_query_file = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'], key="main_uploader")
        image_weight = 0.5
        if text_query and image_query_file:
            image_weight = st.slider("Image vs. Text Influence", 0.0, 1.0, 0.5, 0.1)
        st.subheader("Smart Tags")
        dropdowns = generate_dropdown_options(tagged_df)
        filters = {}
        for attr, options in dropdowns.items():
            selected = st.selectbox(attr.capitalize(), options, index=0)
            if selected != "Any":
                filters[attr] = selected
        col1, col2 = st.columns(2)
        with col1:
            st.button("Search / Filter", on_click=handle_search, args=(text_query, image_query_file, image_weight, filters), type="primary")
        with col2:
            st.button("Browse All", on_click=clear_search)
        st.markdown("---")
        st.header("Shopping Cart")
        if not st.session_state.cart:
            st.info("Your cart is empty.")
        else:
            for item_fname in st.session_state.cart:
                cart_col1, cart_col2 = st.columns([1, 3])
                with cart_col1:
                    st.image(os.path.join(IMAGE_DIR, item_fname), use_column_width=True)
                with cart_col2:
                    st.write(f"**{item_fname}**")
                    st.button("Remove", key=f"remove_{item_fname}", on_click=remove_from_cart, args=(item_fname,))
            st.markdown("---")
            if st.button("Clear Cart"):
                clear_cart()
                st.experimental_rerun()

    if st.session_state.search_active:
        st.header("Search Results")
        display_items = st.session_state.search_results
        if not display_items:
            st.warning("Your search returned no results. Try broadening your query or click 'Browse All'.")
    else:
        st.header("Browse Our Collection")
        display_items = [(fname, None) for fname in image_filenames]
    
    if display_items:
        items_to_display_now = display_items[:st.session_state.items_to_show]
        num_cols = 6
        cols = st.columns(num_cols)
        for i, (fname, score) in enumerate(items_to_display_now):
            image_path = os.path.join(IMAGE_DIR, fname)
            if os.path.exists(image_path):
                with cols[i % num_cols]:
                    st.image(image_path, use_column_width=True)
                    st.button("Add to Cart", key=f"add_{fname}", on_click=add_to_cart, args=(fname,))
        if not st.session_state.search_active and st.session_state.items_to_show < len(display_items):
            st.button("Load More", on_click=load_more_items, use_container_width=True)

with tab2:
    st.header("Unleash Your Inner Designer")
    st.subheader("1. Find a Base Item to Edit")
    creative_search_term = st.text_input("Search for a clothing item", key="creative_search")
    if st.button("Search", key="creative_search_btn"):
        if creative_search_term:
            with st.spinner("Searching..."):
                st.session_state.creative_search_results = search_by_text(creative_search_term, top_k=6)
        else:
            st.session_state.creative_search_results = None
    if st.session_state.creative_search_results:
        st.markdown("---")
        st.write("Click 'Select' on an item to start editing:")
        cols = st.columns(6)
        for i, (fname, score) in enumerate(st.session_state.creative_search_results):
            image_path = os.path.join(IMAGE_DIR, fname)
            if os.path.exists(image_path):
                cols[i].image(image_path, use_column_width=True)
                if cols[i].button("Select", key=f"select_{fname}"):
                    st.session_state.selected_creative_image_fname = fname
                    st.session_state.generated_design = None
                    st.session_state.similar_items = None
                    st.experimental_rerun()
    st.markdown("---")
    if st.session_state.selected_creative_image_fname:
        st.subheader("2. Paint Over the Area to Change & Describe Your Idea")
        base_image_path = os.path.join(IMAGE_DIR, st.session_state.selected_creative_image_fname)
        base_img = Image.open(base_image_path).convert("RGB")
        col1, col2 = st.columns(2)
        with col1:
            st.write("Use the thick brush to paint the area you want to change.")
            canvas_result = st_canvas(
                fill_color="rgba(255, 255, 255, 0)", stroke_width=35,
                stroke_color="rgba(255, 0, 0, 0.5)", background_image=base_img.resize((512, 512)),
                update_streamlit=True, height=512, width=512, drawing_mode="freedraw", key="canvas",
            )
        with col2:
            prompt = st.text_input("What should the new design be?", placeholder="e.g., a roaring tiger head")
            generate_button = st.button("Generate & Find Similar", type="primary", key="generate")
            if generate_button:
                if canvas_result.image_data is not None and prompt:
                    mask_array = canvas_result.image_data[:, :, 3] > 0
                    if mask_array.sum() > 0:
                        mask_img = Image.fromarray(mask_array.astype('uint8') * 255)
                        with st.spinner("Creating your masterpiece..."):
                            st.session_state.generated_design = create_edited_design_enhanced(base_img, mask_img, prompt)
                            st.session_state.similar_items = search_by_image(st.session_state.generated_design, top_k=6)
                    else: st.warning("Please paint on the image to indicate the area to change.")
                else: st.error("Please paint on the image and provide a text prompt.")
    if st.session_state.generated_design:
        st.markdown("---"); st.subheader("3. Your Unique Creation & Similar Real Items")
        res_col1, res_col2 = st.columns([1, 2])
        with res_col1: st.image(st.session_state.generated_design, caption="Your Generated Design", use_column_width=True)
        with res_col2:
            if st.session_state.similar_items:
                cols = st.columns(3)
                for i, (fname, score) in enumerate(st.session_state.similar_items):
                    image_path = os.path.join(IMAGE_DIR, fname)
                    if os.path.exists(image_path): cols[i % 3].image(image_path, use_column_width=True)
            else: st.info("No similar items found in the database.")

with tab3:
    # (Placeholder VTON code)
    st.header("Experience the Future of Fitting")
    st.info("This is a conceptual demonstration of a Virtual Try-On system using your dataset.")
    col1, col2 = st.columns(2)
    with col1:
        person_img_file = st.file_uploader("1. Upload Person Image", type=['png', 'jpg', 'jpeg'], key="vton_p")
    with col2:
        cloth_img_file = st.file_uploader("2. Upload Clothing Item", type=['png', 'jpg', 'jpeg'], key="vton_c")
    if st.button("Generate Virtual Try-On", type="primary", key="vton_gen"):
        if person_img_file and cloth_img_file:
            person_img = Image.open(person_img_file)
            cloth_img = Image.open(cloth_img_file)
            st.image(virtual_try_on_placeholder(person_img, cloth_img), caption="VTON Process Simulation")
        else:
            st.error("Please upload both a person and a clothing image.")

with tab4:
    st.header("Sketch-Based Discovery üñäÔ∏è")
    sketch_file = st.file_uploader("Upload Sketch Image", type=['png', 'jpg', 'jpeg'], key="sketch_uploader")
    if sketch_file:
        sketch = Image.open(sketch_file).convert("RGB").resize((224, 224))
        st.image(sketch, caption="Uploaded Sketch", use_column_width=False)
        if st.button("Show Recommendations", type="primary", key="sketch_btn"):
            with st.spinner("üß† Describing your sketch and searching..."):
                caption = describe_sketch(sketch)
                st.session_state.sketch_caption = caption
                sketch_embed = hybrid_clip_embedding(sketch, caption)
                sims = image_embeds @ sketch_embed
                topk = np.argsort(-sims)[:6]
                st.session_state.sketch_results = []
                for i, idx in enumerate(topk):
                    st.session_state.sketch_results.append((image_filenames[idx], sims[idx]))

    if st.session_state.sketch_results:
        st.markdown("---")
        #st.subheader(f"AI Caption: *\"{st.session_state.sketch_caption}\"*")
        st.subheader("Top 6 Matches Found:")
        cols = st.columns(6)
        for i, (fname, score) in enumerate(st.session_state.sketch_results):
            image_path = os.path.join(IMAGE_DIR, fname)
            if os.path.exists(image_path):
                cols[i].image(image_path, caption=f"Score: {score:.2f}", use_column_width=True)

Overwriting app.py


In [5]:
# File: evaluate_similarity.py

import torch
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import os

# =============================================================================
# CONFIGURATION
# =============================================================================
# --- KAGGLE-SPECIFIC FILE PATHS ---
# Ensure these paths are correct for your Kaggle environment or local setup.
EMBEDDINGS_PATH = "/kaggle/input/clip-embed/pytorch/default/1/clip_image_embeds (1).pt"
TAGGED_DATA_PATH = "/kaggle/input/tagged-fashion-data/tagged_fashion_data.json"

# --- EVALUATION PARAMETERS ---
# K defines how many top results we should check for a match.
K = 5 

# =============================================================================
# MAIN EVALUATION SCRIPT
# =============================================================================

def main():
    """
    Main function to load data and run the similarity evaluation.
    """
    print("--- Starting Recommendation Dataset Similarity Evaluation ---")

    # --- 1. Load Data ---
    print(f"Loading embeddings from: {EMBEDDINGS_PATH}")
    if not os.path.exists(EMBEDDINGS_PATH):
        print(f"ERROR: Embeddings file not found at {EMBEDDINGS_PATH}")
        return
        
    embedding_data = torch.load(EMBEDDINGS_PATH, map_location='cpu')
    image_embeds = embedding_data["embeds"].numpy()
    image_filenames = embedding_data["files"]

    print(f"Loading tagged data from: {TAGGED_DATA_PATH}")
    if not os.path.exists(TAGGED_DATA_PATH):
        print(f"ERROR: Tagged data file not found at {TAGGED_DATA_PATH}")
        return
        
    tagged_df = pd.read_json(TAGGED_DATA_PATH, orient='records', lines=True)
    tagged_df_indexed = tagged_df.set_index('filename')

    print(f"Loaded {len(image_filenames)} items to evaluate.")
    print(f"Will check top {K} recommendations for each item.")

    # --- 2. Initialize Counters ---
    total_items_processed = 0
    hits_at_k = 0

    # --- 3. Calculate Similarity Matrix ---
    # This is a one-time, memory-intensive operation but much faster than per-item calculation.
    print("\nCalculating similarity matrix for all items... (This may take a moment)")
    similarity_matrix = cosine_similarity(image_embeds)
    
    # --- 4. Iterate and Evaluate Each Item ---
    print("Evaluating each item as a query...")
    # Use tqdm for a nice progress bar
    for i, query_filename in enumerate(tqdm(image_filenames, desc="Processing Items")):
        try:
            # Get the ground truth tags for our query item
            query_tags = set(tagged_df_indexed.loc[query_filename, 'tags'])
            if not query_tags:
                continue # Skip items that have no tags to compare against
        except KeyError:
            # Skip if the item from embeddings isn't in our tagged data
            continue
            
        total_items_processed += 1
        
        # Get the similarity scores for the current item against all others
        sim_scores = similarity_matrix[i]
        
        # Get the indices of the top K+1 most similar items (because the top one is itself)
        # We use np.argsort to sort and get indices
        top_indices = np.argsort(-sim_scores)[1:K+1] # Exclude the first one (itself)

        # --- 5. Check for a "Hit" ---
        found_hit = False
        for result_idx in top_indices:
            try:
                result_filename = image_filenames[result_idx]
                result_tags = set(tagged_df_indexed.loc[result_filename, 'tags'])
                
                # A "hit" is defined as having at least one tag in common
                if query_tags.intersection(result_tags):
                    found_hit = True
                    break # We found a match, no need to check other results for this query
            except KeyError:
                continue
        
        if found_hit:
            hits_at_k += 1

    # --- 6. Calculate and Display Final Score ---
    if total_items_processed == 0:
        print("\nEvaluation could not be completed. No items with tags were found to process.")
        return
        
    percentage_similarity = (hits_at_k / total_items_processed) * 100
    
    print("\n--- üìà Evaluation Complete ---")
    print(f"Total Items with Tags Evaluated: {total_items_processed}")
    print(f"Successful Recommendations (Hits @ {K}): {hits_at_k}")
    print(f"Percentage Similarity (Recall@{K}): {percentage_similarity:.2f}%")
    print("\nThis means that for a given clothing item, there is a "
          f"{percentage_similarity:.2f}% chance that at least one of the top {K} "
          "visually similar recommendations also shares a semantic tag (like color, brand, or style) with it.")

if __name__ == "__main__":
    main()

--- Starting Recommendation Dataset Similarity Evaluation ---
Loading embeddings from: /kaggle/input/clip-embed/pytorch/default/1/clip_image_embeds (1).pt
Loading tagged data from: /kaggle/input/tagged-fashion-data/tagged_fashion_data.json
Loaded 11647 items to evaluate.
Will check top 5 recommendations for each item.

Calculating similarity matrix for all items... (This may take a moment)
Evaluating each item as a query...


Processing Items: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11647/11647 [00:03<00:00, 3087.74it/s]


--- üìà Evaluation Complete ---
Total Items with Tags Evaluated: 11076
Successful Recommendations (Hits @ 5): 10682
Percentage Similarity (Recall@5): 96.44%

This means that for a given clothing item, there is a 96.44% chance that at least one of the top 5 visually similar recommendations also shares a semantic tag (like color, brand, or style) with it.



