# Libraries



In [1]:
!pip install torch diffusers  transformers pillow matplotlib mediapipe opencv-python opencv-contrib-python gradio blip lpips torchvision requests beautifulsoup4 fake_useragent

Collecting mediapipe
  Downloading mediapipe-0.10.20-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (9.7 kB)
Collecting gradio
  Downloading gradio-5.9.1-py3-none-any.whl.metadata (16 kB)
Collecting blip
  Downloading blip-0.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Collecting fake_useragent
  Downloading fake_useragent-2.0.3-py3-none-any.whl.metadata (17 kB)
Collecting sounddevice>=0.4.4 (from mediapipe)
  Downloading sounddevice-0.5.1-py3-none-any.whl.metadata (1.4 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.6-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.5.2 (from gradio)
  Downloading gradio_client-1.5.2-py3-none-any.whl.metadata (7.1 kB)
Collecting markupsafe

In [2]:
# Refresh the apt package index before installing anything.
!apt-get update
# System packages installed ahead of Chrome (download/unpack tools, a virtual
# display server, shared libraries and fonts).
!apt-get install -y wget unzip xvfb libxi6 libgconf-2-4
!apt-get install -y libappindicator1 fonts-liberation
# Download the current stable Chrome package and install it; if dpkg reports
# unmet dependencies, let apt resolve them.
!wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb
!dpkg -i google-chrome-stable_current_amd64.deb || apt-get -fy install
# The downloaded package file is not needed after installation.
!rm google-chrome-stable_current_amd64.deb

# Python-side browser automation: Selenium, requests and webdriver-manager.
!pip install selenium requests webdriver-manager

0% [Working]            Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
Get:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Get:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [1,197 kB]
Get:4 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Hit:5 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:7 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Hit:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:9 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Get:10 https://r2u.stat.illinois.edu/ubuntu jammy/main all Packages [8,563 kB]
Hit:11 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Get:12 http://security.ubuntu.com/ubuntu jammy-security/universe amd64 Packages [1,226 kB]
Get:13 

In [3]:
%%capture

!pip install groq langchain_community sentence_transformers
!pip install llama-index-llms-groq
!pip install groq

# Data Scraping

In [4]:
import os
import logging
import requests
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.chrome.options import Options
from webdriver_manager.chrome import ChromeDriverManager
from bs4 import BeautifulSoup
import time
import shutil
import urllib.parse
import concurrent.futures
import gradio as gr
from PIL import Image
import torch
from diffusers import StableDiffusionImg2ImgPipeline
import torchvision.transforms as transforms
from typing import List, Dict
from groq import Groq
from transformers import BlipProcessor, BlipForConditionalGeneration

# Configure logging
# INFO level makes the scraper's progress messages visible in the cell output.
logging.basicConfig(level=logging.INFO)

# Websites to scrape: site name -> product-listing URL.
# The site name selects the scraper in scrape_data() and prefixes the saved file names.
websites = {
    'junaidjamshed': 'https://www.junaidjamshed.com/womens/kurti.html?product_list_dir=desc&product_list_order=top_rated',
    'khaadi': 'https://pk.khaadi.com/ready-to-wear/essentials/kurta/kurta/?prefn1=filter_categories&prefv1=Kurta&srule=most-popular&start=0&sz=96',
}

# Keywords to filter images (specific to shirts).
# An image is kept only if its alt text contains one of these (case-insensitive).
keywords = ['shirt', 'kurta', 'kurti']

# Folder to save images
output_folder = "scraped_images"

# Clear output folder before scraping, so every run starts from an empty folder.
# NOTE: this permanently deletes everything under `output_folder` each time the cell runs.
if os.path.exists(output_folder):
    shutil.rmtree(output_folder)  # Delete the folder and its contents
os.makedirs(output_folder, exist_ok=True)  # Recreate the folder

# Selenium setup: Chrome is started headless with sandboxing and /dev/shm usage disabled.
options = Options()
options.add_argument("--headless")
options.add_argument("--no-sandbox")
options.add_argument("--disable-dev-shm-usage")
# Presumably the path where the Chrome package from the setup cell is installed - TODO confirm.
options.binary_location = "/usr/bin/google-chrome"

# Shared browser session used by scrape_images_junaidjamshed(); it is closed with
# driver.quit() after scraping. ChromeDriverManager().install() supplies the chromedriver path.
driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()), options=options)

# Function to fetch image using requests
def fetch_image(img_url):
    """Download the raw bytes of a single image.

    Args:
        img_url: Absolute URL of the image; it must start with ``http``.

    Returns:
        The response body as ``bytes``, or ``None`` when the URL is missing or
        not HTTP(S), the server answers with an error status, or the request
        fails for any other reason. Failures are logged, never raised.
    """
    try:
        if not img_url or not img_url.startswith('http'):
            logging.warning(f"Invalid image URL: {img_url}")
            return None

        response = requests.get(img_url, timeout=10)
        # Without this check the body of an error page (404, 403, ...) would be
        # returned as if it were image data and later written to disk as a .jpg.
        response.raise_for_status()
        return response.content
    except Exception as e:
        logging.error(f"Failed to fetch image {img_url}: {e}")
        return None

# Function to save images (with concurrency)
def save_images(site_name, images, max_images=10):
    """Download image URLs in parallel and write them into ``output_folder``.

    Args:
        site_name: Prefix used for the saved file names.
        images: List of image URLs.
        max_images: Only the first ``max_images`` URLs are downloaded
            (default 10). Pass ``None`` to download all of them.

    Files are named ``{site_name}shirt{N}.jpg`` where ``N`` is the 1-based
    position of the URL in the downloaded slice, so a failed download leaves a
    gap in the numbering. Write errors are logged and do not stop the loop.
    """
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # executor.map keeps the results in the same order as the input URLs.
        img_data_list = list(executor.map(fetch_image, images[:max_images]))

    for position, img_data in enumerate(img_data_list, start=1):
        if not img_data:
            continue  # fetch_image already logged why this one is missing
        img_name = f"{site_name}shirt{position}.jpg"
        img_path = os.path.join(output_folder, img_name)
        try:
            with open(img_path, 'wb') as img_file:
                img_file.write(img_data)
            logging.info(f"Saved {img_name}")
        except Exception as e:
            logging.error(f"Failed to save image {img_name}: {e}")

# Function to scrape images from Junaid Jamshed using Selenium
def scrape_images_junaidjamshed(site_name, url):
    """Collect product-image URLs from a Junaid Jamshed listing page and save them.

    The page is opened in the shared Selenium ``driver`` and scrolled a few
    times; every img element whose alt text contains one of ``keywords`` is
    then passed to ``save_images``.

    Args:
        site_name: Name used in log messages and as the saved-file prefix.
        url: Listing page to open.

    Any exception is logged and swallowed so that other sites can still be scraped.
    """
    try:
        driver.get(url)

        # Scroll to load all images (limited number of scrolls). Scrolling stops
        # early as soon as a scroll no longer changes the page height.
        last_height = driver.execute_script("return document.body.scrollHeight")
        scroll_limit = 5
        for _ in range(scroll_limit):
            driver.execute_script("window.scrollBy(0, 1000);")
            time.sleep(2)  # Wait for images to load
            new_height = driver.execute_script("return document.body.scrollHeight")
            if new_height == last_height:
                break
            last_height = new_height

        time.sleep(3)

        images = []
        seen_urls = set()  # To track already seen images

        for img in driver.find_elements(By.TAG_NAME, "img"):
            img_url = _pick_image_url(
                img.get_attribute('src'),
                img.get_attribute('data-src'),
                img.get_attribute('srcset'),
            )
            alt_text = img.get_attribute('alt')

            # Skip elements without a usable URL, and duplicates.
            if not img_url or img_url in seen_urls:
                continue
            seen_urls.add(img_url)

            if alt_text and any(keyword.lower() in alt_text.lower() for keyword in keywords):
                images.append(img_url)

        save_images(site_name, images)
    except Exception as e:
        logging.error(f"Error scraping {site_name}: {e}")


def _pick_image_url(src, data_src, srcset):
    """Return the first real image URL among an img element's sources, or None.

    Empty values and inline ``data:image`` placeholders are skipped, so a
    lazy-loaded image whose ``src`` is a placeholder still yields its ``data-src``.
    """
    candidates = [src, data_src]
    if srcset:
        # srcset is a comma-separated list of "URL descriptor" entries
        # (e.g. "a.jpg 1x, b.jpg 2x"); only the first URL is used.
        first_entry = srcset.split(',')[0].split()
        candidates.append(first_entry[0] if first_entry else None)

    for candidate in candidates:
        if candidate and not candidate.startswith('data:image'):
            return candidate
    return None

# Function to scrape images from Khaadi using BeautifulSoup
def scrape_images_khaadi(site_name, url):
    """Collect product-image URLs from a Khaadi listing page and save them.

    The page HTML is fetched with ``requests`` and parsed with BeautifulSoup;
    every img tag whose alt text contains one of ``keywords`` is passed to
    ``save_images``.

    Args:
        site_name: Name used in log messages and as the saved-file prefix.
        url: Listing page to download; also the base for relative image URLs.

    Any exception (including an HTTP error status) is logged and swallowed.
    """
    try:
        response = requests.get(url, timeout=10)
        # Do not parse an error page as if it were the product listing.
        response.raise_for_status()
        soup = BeautifulSoup(response.content, 'html.parser')

        images = []
        seen_urls = set()
        for img in soup.find_all('img'):
            img_url = img.get('src') or img.get('data-src')
            alt_text = img.get('alt')

            # Skip tags without a source (urljoin would otherwise turn the
            # missing value into the page URL itself) and base64 images.
            if not img_url or img_url.startswith('data:image'):
                continue

            # Handle relative URLs
            img_url = urllib.parse.urljoin(url, img_url)

            # Skip duplicate URLs
            if img_url in seen_urls:
                continue
            seen_urls.add(img_url)

            # Filter images by keywords in alt text
            if alt_text and any(keyword.lower() in alt_text.lower() for keyword in keywords):
                images.append(img_url)

        save_images(site_name, images)
    except Exception as e:
        logging.error(f"Error scraping {site_name}: {e}")

# Load all images from the 'scraped_images' folder
def load_images_from_folder(folder_path):
    """Load every ``.jpg`` file in a folder.

    Args:
        folder_path: Directory to read.

    Returns:
        A list of ``(image, image_path)`` tuples, where ``image`` is the result
        of ``preprocess_image`` for that file, ordered by file name.
    """
    images = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the result
    # (and therefore the gallery order) the same on every run.
    for filename in sorted(os.listdir(folder_path)):
        if filename.endswith(".jpg"):
            image_path = os.path.join(folder_path, filename)
            image = preprocess_image(image_path)
            images.append((image, image_path))  # Store image and path tuple
    return images

# Preprocess input images
def preprocess_image(image_path, size=(512, 512)):
    """Open an image file, convert it to RGB and resize it.

    Args:
        image_path: Path of the image file.
        size: Target size handed to ``transforms.Resize`` (default 512x512).

    Returns:
        The resized RGB image.
    """
    # convert() produces an independent in-memory image, so the file handle can
    # be released right away instead of staying open until garbage collection.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    resize_transform = transforms.Resize(size)
    return resize_transform(rgb_image)

# Main scraping function
def scrape_data():
    """Run the scraper of every configured site, then load what was saved.

    Sites in ``websites`` without a matching scraper are only logged.

    Returns:
        The list produced by ``load_images_from_folder(output_folder)``.
    """
    scrapers = {
        'junaidjamshed': scrape_images_junaidjamshed,
        'khaadi': scrape_images_khaadi,
    }

    for site_name, url in websites.items():
        logging.info(f"Scraping {site_name} for shirts...")
        scraper = scrapers.get(site_name)
        if scraper is not None:
            scraper(site_name, url)

    # Reload the gallery images after scraping
    return load_images_from_folder(output_folder)

# Start scraping
try:
    images = scrape_data()
    logging.info(f"Scraping complete. Images saved: {[img[1] for img in images]}")
finally:
    # Close the driver even when scraping or image loading raises, so that no
    # headless Chrome process is left running in the background.
    driver.quit()

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

# Generative Model

In [5]:
import gradio as gr
import os
from PIL import Image
import torch
from diffusers import StableDiffusionImg2ImgPipeline
import torchvision.transforms as transforms
from transformers import BlipProcessor, BlipForConditionalGeneration

# Set up models and processor (Preload for efficiency)
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# Half precision is only used on the GPU; on a CPU-only runtime the pipeline is
# loaded in float32, since float16 inference is poorly supported on CPU.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

(…)ature_extractor/preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

text_encoder/config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

scheduler/scheduler_config.json:   0%|          | 0.00/308 [00:00<?, ?B/s]

safety_checker/config.json:   0%|          | 0.00/4.72k [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

vae/config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

unet/config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [4]:
import gradio as gr
import os
from PIL import Image
import torch
from diffusers import StableDiffusionImg2ImgPipeline
import torchvision.transforms as transforms
from typing import List, Dict
from groq import Groq
from transformers import BlipProcessor, BlipForConditionalGeneration
from functools import lru_cache  # Using lru_cache for caching
import lpips  # LPIPS similarity model

# Set up Groq API key
# The key is never stored in the notebook: it is taken from the GROQ_API_KEY
# environment variable and, if that is not set, asked for interactively.
# NOTE: any key that was previously committed in this cell must be treated as
# leaked and revoked in the Groq console.
if not os.environ.get("GROQ_API_KEY"):
    from getpass import getpass
    os.environ["GROQ_API_KEY"] = getpass("Enter your Groq API key: ")
client = Groq()  # Initialize Groq API client (reads GROQ_API_KEY from the environment)
# NOTE(review): verify that this model id is still served by Groq.
DEFAULT_MODEL = "llama-3.1-70b-versatile"

# Load BLIP model and processor for image captioning
# (the BLIP model is not moved to `device`; only the LPIPS model below is)
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# LPIPS model for similarity checks
# `device` is "cuda" when a GPU is available, otherwise "cpu"; it is also used
# by load_pipeline() and calculate_similarity().
device = "cuda" if torch.cuda.is_available() else "cpu"
lpips_model = lpips.LPIPS(net="alex").to(device)

# Function to create assistant message format
def assistant(content: str):
    """Wrap ``content`` as a chat message sent by the assistant."""
    return dict(role="assistant", content=content)

# Function to create user message format
def user(content: str):
    """Wrap ``content`` as a chat message sent by the user."""
    return dict(role="user", content=content)

# Function for chat completion with Groq
def chat_completion(messages: List[Dict], model=DEFAULT_MODEL, temperature=0.6, top_p=0.9) -> str:
    """Send a conversation to the Groq chat API and return the reply text.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        model: Model identifier passed to the API.
        temperature: Sampling temperature passed to the API.
        top_p: Nucleus-sampling value passed to the API.

    Returns:
        The message content of the first choice in the response.
    """
    request_options = {
        "messages": messages,
        "model": model,
        "temperature": temperature,
        "top_p": top_p,
    }
    completion = client.chat.completions.create(**request_options)
    first_choice = completion.choices[0]
    return first_choice.message.content

# Preprocess input images
def preprocess_image(image_path, size=(512, 512)):
    """Open an image file, convert it to RGB and resize it.

    Args:
        image_path: Path of the image file.
        size: Target size handed to ``transforms.Resize`` (default 512x512).

    Returns:
        The resized RGB image.
    """
    # convert() produces an independent in-memory image, so the file handle can
    # be released right away instead of staying open until garbage collection.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    return transforms.Resize(size)(rgb_image)

# Load all images from the 'scraped_images' folder
def load_images_from_folder(folder_path):
    """Load every ``.jpg`` file in a folder.

    Args:
        folder_path: Directory to read.

    Returns:
        A list of ``(image, image_path)`` tuples, where ``image`` is the result
        of ``preprocess_image`` for that file, ordered by file name.
    """
    # os.listdir() returns entries in arbitrary order; sorting keeps the gallery
    # order (and the index -> path mapping built on it) stable between runs.
    jpg_paths = [
        os.path.join(folder_path, filename)
        for filename in sorted(os.listdir(folder_path))
        if filename.endswith(".jpg")
    ]
    return [(preprocess_image(image_path), image_path) for image_path in jpg_paths]

# Global variables to hold loaded images and descriptions
# `images`: list of (image, file path) tuples in gallery order; replaced by the
# UI's "Fetch New Data" handler and indexed by gallery selections.
images = []
# `descriptions`: image file name (basename) -> caption text.
descriptions = {}

# Generate a description for an image using BLIP
def generate_description_with_blip(image_path):
    """Caption one image file with the BLIP captioning model.

    Args:
        image_path: Path of the image file.

    Returns:
        The decoded caption string (special tokens removed).
    """
    # The RGB copy lives in memory, so the file handle is closed immediately.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    out = blip_model.generate(**inputs)
    return processor.decode(out[0], skip_special_tokens=True)

# Generate a description for an image using BLIP and save it as a .txt file
def generate_description_with_blip_cached(image_path):
    """Return the BLIP caption of an image, using a ``.txt`` file as cache.

    The caption is stored next to the image under the same name with a
    ``.txt`` extension. If that file exists, is non-empty and is at least as
    new as the image, its content is returned without running the model;
    otherwise the caption is generated and the file is (re)written.

    Args:
        image_path: Path of the image file.

    Returns:
        The caption string.
    """
    text_file_path = os.path.splitext(image_path)[0] + ".txt"

    try:
        # A caption older than its image belongs to a previous file of the same name.
        cache_is_fresh = os.path.getmtime(text_file_path) >= os.path.getmtime(image_path)
    except OSError:
        cache_is_fresh = False  # no caption file yet (or the image is unreadable)

    if cache_is_fresh:
        with open(text_file_path, "r", encoding="utf-8") as text_file:
            cached_description = text_file.read()
        if cached_description:
            return cached_description

    description = generate_description_with_blip(image_path)
    # Save the description to a .txt file
    with open(text_file_path, "w", encoding="utf-8") as text_file:
        text_file.write(description)
    return description

# Build a Stable Diffusion prompt from the captions of the selected images
def generate_prompt_from_selected_images(image_paths):
    """
    Turn the captions of the selected images into one prompt via the Groq chat model.

    Captions are looked up in the global ``descriptions`` dict by file name; a
    missing or empty caption is generated with
    ``generate_description_with_blip_cached`` and stored back into the dict.

    Returns the model's reply, or a fixed message when ``image_paths`` is empty.
    """
    def caption_for(path):
        key = os.path.basename(path)
        caption = descriptions.get(key, "")
        if not caption:
            caption = generate_description_with_blip_cached(path)
            descriptions[key] = caption
        return caption

    captions = [caption_for(path) for path in image_paths]
    if not captions:
        return "No descriptions found for the selected images."

    # Combine the descriptions and generate a prompt
    combined_descriptions = " ".join(captions)
    instruction = f"Emphasize the design details, patterns, and fabric texture. Ensure the description captures the style, structure, and overall appearance of the clothing while cropping out the figure's face entirely. {combined_descriptions}"
    return chat_completion([user(instruction)])


# Load the Stable Diffusion pipeline
def load_pipeline():
    return StableDiffusionImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16
    ).to(device)

# Blend and generate image using the Stable Diffusion pipeline
def blend_and_generate_image(image_paths, alpha1, alpha2, generated_prompt):
    """Blend the selected images and run Stable Diffusion img2img on the result.

    Args:
        image_paths: Up to two image paths; empty or ``None`` entries (e.g. an
            empty path textbox) are ignored.
        alpha1: Weight of the first image, clamped to ``[0, 1]``.
        alpha2: Weight of the second image. The weights must sum to one, so the
            effective value is always ``1 - alpha1``.
        generated_prompt: Text prompt for the pipeline.

    Returns:
        The generated image, or ``None`` when no usable image path was given.
        With a single path, that image alone is the img2img input.
    """
    valid_paths = [path for path in (image_paths or []) if path]
    if not valid_paths:
        return None

    alpha1 = max(0, min(1, alpha1))
    alpha2 = 1 - alpha1

    with Image.open(valid_paths[0]) as first:
        blended_image = first.convert("RGB")

    if len(valid_paths) > 1:
        with Image.open(valid_paths[1]) as second:
            # Image.blend needs both images to have the same size and mode.
            second_rgb = second.convert("RGB").resize(blended_image.size)
        # Image.blend computes first * (1 - alpha) + second * alpha.
        blended_image = Image.blend(blended_image, second_rgb, alpha2)

    pipe = load_pipeline()
    output_image = pipe(
        prompt=generated_prompt,
        image=blended_image,
        strength=0.80,
        guidance_scale=7.5,
        num_inference_steps=50,
        generator=torch.manual_seed(42),  # fixed seed: same inputs give the same image
    ).images[0]

    return output_image

# Calculate similarity using LPIPS
def calculate_similarity(image1_path, image2_path):
    """Compute the LPIPS perceptual distance between two images.

    Args:
        image1_path: File path of the first image, or an already-loaded PIL image.
        image2_path: File path of the second image, or an already-loaded PIL image.

    Returns:
        The LPIPS distance as a float (lower means more similar), or the string
        ``"Please select two valid images."`` when either argument is missing.
    """
    if not (image1_path and image2_path):
        return "Please select two valid images."

    to_tensor = transforms.ToTensor()

    def as_batch(image_or_path):
        """Load/normalise one input into a batch-of-one tensor on ``device``."""
        if isinstance(image_or_path, (str, os.PathLike)):
            image = preprocess_image(image_or_path)
        else:
            # Already-loaded image (e.g. the generated output): apply the same
            # RGB conversion and 512x512 resize that preprocess_image defaults to.
            image = transforms.Resize((512, 512))(image_or_path.convert("RGB"))
        return to_tensor(image).unsqueeze(0).to(device)

    # ToTensor yields values in [0, 1] while LPIPS expects [-1, 1];
    # normalize=True makes the model rescale its inputs accordingly.
    with torch.no_grad():
        distance = lpips_model(as_batch(image1_path), as_batch(image2_path), normalize=True)
    return distance.item()

# Find which of the selected images is closest to the generated image
def compare_generated_with_selected(generated_image, selected_images):
    """Pick the selected image with the lowest similarity score.

    Scores come from ``calculate_similarity(generated_image, path)``; a lower
    score is treated as "more similar".

    Args:
        generated_image: The image to compare against.
        selected_images: Paths of the candidate images.

    Returns:
        ``(most_similar_image, score)``. When an input is missing or no
        candidate produced a numeric score, ``(None, message)`` is returned so
        that callers can always unpack two values.
    """
    failure = (None, "Please select images to compare.")
    if not selected_images or not generated_image:
        return failure

    best_score = float('inf')
    most_similar_image = None

    for img_path in selected_images:
        similarity_score = calculate_similarity(generated_image, img_path)
        # calculate_similarity answers with a message string for invalid input;
        # such entries cannot be ranked.
        if not isinstance(similarity_score, (int, float)):
            continue

        # Track the most similar image (lowest similarity score)
        if similarity_score < best_score:
            best_score = similarity_score
            most_similar_image = img_path

    if most_similar_image is None:
        return failure
    return most_similar_image, best_score

# Build the Gradio UI: gallery selection, prompt generation, similarity check and image generation
def create_ui():
    """Assemble the Gradio Blocks app and wire up its event handlers.

    The handlers read and replace the module-level ``images`` and
    ``descriptions``; the current selection (at most two paths) is kept in a
    list local to this function.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``launch()``.
    """
    with gr.Blocks() as app:
        with gr.Row():
            gr.Markdown("### Select images and generate a new image based on descriptions and blending")

        with gr.Row():
            fetch_data_button = gr.Button("Fetch New Data")
            upload_image = gr.File(label="Upload Image", type="filepath", file_types=[".jpg"])
            upload_button = gr.Button("Upload")

        with gr.Row():
            gallery = gr.Gallery(label="Loaded Images", value=[], interactive=True, columns=4, height="auto")

        with gr.Row():
            image1_display = gr.Textbox(label="First Selected Image Path", interactive=False)
            image2_display = gr.Textbox(label="Second Selected Image Path", interactive=False)

        with gr.Row():
            generated_prompt_display = gr.Textbox(label="Generated Prompt", interactive=True, lines=3)

        with gr.Row():
            similarity_score_display = gr.Textbox(label="Most Similarity Score", interactive=False)

        with gr.Row():
            alpha1_slider = gr.Slider(0, 1, value=0.5, step=0.01, label="Weight of First Image")
            alpha2_slider = gr.Slider(0, 1, value=0.5, step=0.01, label="Weight of Second Image")

        with gr.Row():
            output_generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)

        # Paths of the currently selected images (at most two, oldest first).
        selected_images = []

        def handle_selection(evt: gr.SelectData):
            """Toggle the clicked gallery image in the selection and refresh the prompt."""
            selected_path = images[evt.index][1] if evt.index < len(images) else None
            if selected_path:
                if selected_path in selected_images:
                    selected_images.remove(selected_path)
                elif len(selected_images) < 2:
                    selected_images.append(selected_path)
                else:
                    # Already two selected: drop the oldest to make room.
                    selected_images.pop(0)
                    selected_images.append(selected_path)

            generated_prompt = generate_prompt_from_selected_images(selected_images) if selected_images else ""
            image1_path = selected_images[0] if len(selected_images) > 0 else ""
            image2_path = selected_images[1] if len(selected_images) > 1 else ""
            return image1_path, image2_path, generated_prompt

        gallery.select(handle_selection, None, [image1_display, image2_display, generated_prompt_display])

        # Keep the two weights summing to one. `.input` fires only on user
        # interaction; with `.change`, the programmatic update of one slider
        # would trigger the other slider's handler in turn.
        alpha1_slider.input(lambda x: 1 - x, inputs=alpha1_slider, outputs=alpha2_slider)
        alpha2_slider.input(lambda x: 1 - x, inputs=alpha2_slider, outputs=alpha1_slider)

        def fetch_new_data_and_update_gallery():
            """Reload the scraped images, caption each one and show them in the gallery."""
            global images, descriptions
            updated_images = load_images_from_folder("scraped_images")
            descriptions = {os.path.basename(img[1]): generate_description_with_blip_cached(img[1]) for img in updated_images}
            images = updated_images
            return gr.update(value=[img[0] for img in updated_images])

        fetch_data_button.click(fetch_new_data_and_update_gallery, None, gallery)

        def add_uploaded_image(uploaded_path):
            """Append the uploaded file to the loaded images and refresh the gallery."""
            global images
            if uploaded_path:
                # `images` must stay in step with the gallery, because
                # handle_selection maps a gallery index to images[index].
                images = images + [(preprocess_image(uploaded_path), uploaded_path)]
            return gr.update(value=[img[0] for img in images])

        upload_button.click(add_uploaded_image, [upload_image], gallery)

        def check_similarity(image1_path, image2_path):
            """Score the two selected images against each other."""
            if not (image1_path and image2_path):
                # Leave the image output untouched and explain what is missing.
                return "Please select two valid images.", gr.update()
            most_similar_image, max_similarity = compare_generated_with_selected(image2_path, [image1_path])
            return max_similarity, most_similar_image

        similarity_button = gr.Button("Check Similarity")
        similarity_button.click(check_similarity, [image1_display, image2_display], [similarity_score_display, output_generated_image])

        def compare_generated_and_selected(generated_image):
            """Find the selected image closest to the generated image."""
            if generated_image is None or not selected_images:
                return gr.update(), gr.update()  # nothing to compare: change nothing
            most_similar_image, max_similarity = compare_generated_with_selected(generated_image, selected_images)
            return most_similar_image, max_similarity  # Return the image and similarity score

        # NOTE(review): this is a second handler on the same button; it also
        # writes the score box and overwrites the second path box with the most
        # similar path - confirm that this wiring is intended.
        similarity_button.click(compare_generated_and_selected, [output_generated_image], [image2_display, similarity_score_display])

        def blend_and_generate(image1_path, image2_path, alpha1, alpha2, generated_prompt):
            """Forward the UI values to blend_and_generate_image."""
            return blend_and_generate_image([image1_path, image2_path], alpha1, alpha2, generated_prompt)

        generate_button = gr.Button("Generate Image")
        generate_button.click(blend_and_generate, [image1_display, image2_display, alpha1_slider, alpha2_slider, generated_prompt_display], output_generated_image)

    return app


# Build and serve the UI only when this cell/file runs as the main program.
if __name__ == "__main__":
    app = create_ui()
    # launch() starts the Gradio server for the app.
    app.launch()


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://49ba1a3b72ef001361.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
