<a href="https://colab.research.google.com/github/Jainam051/Multi-Modal-Product-Tagger-CLIP-SAM-/blob/main/SAM_%2B_CLIP_Tagger.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install SAM and CLIP directly from their upstream GitHub repositories.
# NOTE(review): none of the installs in this cell are version-pinned.
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install git+https://github.com/openai/CLIP.git
# Imaging and plotting libraries imported by the cells below (cv2, matplotlib).
!pip install opencv-python matplotlib
# SAM checkpoint; loaded later by file name as "sam_vit_h_4b8939.pth".
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
# Class-name list; read later by generate_result as "imagenet_classes.txt".
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
# supervision is imported as `sv` (mask annotation), gradio as `gr` (demo UI).
!pip install supervision
!pip install gradio

In [None]:
import torch
import clip
from PIL import Image, ImageDraw, ImageFont
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import numpy as np
import matplotlib.pyplot as plt
import supervision as sv
import gradio as gr
import random


In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP model plus the image transform that clip.load returns alongside it.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

# SAM model built from the checkpoint file fetched in the setup cell.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device)

# image_preprocess / new_image_preprocess look this name up at module scope,
# so it must exist before either is called; without it a fresh kernel hits a
# NameError on the first segmentation request.
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:

def low_quality(image_path, output_path="low_quality.jpg", quality=10):
    """Re-save an image with a low JPEG quality setting.

    Args:
        image_path: Path of the image to read with ``cv2.imread``.
        output_path: Destination file; defaults to ``"low_quality.jpg"``.
        quality: Value passed for ``cv2.IMWRITE_JPEG_QUALITY`` (default 10).

    Returns:
        ``output_path``, so the result can be fed straight into a loader.

    Raises:
        ValueError: If ``image_path`` cannot be read.
        OSError: If OpenCV reports that the output file was not written.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Failed to load image: {image_path}")

    # cv2.imwrite reports failure through its boolean return value, so an
    # unchecked call can silently leave no output file behind.
    written = cv2.imwrite(
        output_path, image, [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
    )
    if not written:
        raise OSError(f"Failed to write image: {output_path}")
    return output_path


In [None]:
def show_mask(mask, ax, random_color=False):
    """Render ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to ``(height, width, 1)`` before colouring.
        ax: Object exposing ``imshow`` (e.g. a Matplotlib axes).
        random_color: Use a random RGB colour instead of the fixed blue.
    """
    alpha = np.array([0.6])
    rgb = np.random.random(3) if random_color else np.array([30, 144, 255]) / 255
    rgba = np.concatenate([rgb, alpha], axis=0)

    height, width = mask.shape[-2:]
    # Broadcasting (H, W, 1) against (1, 1, 4) paints the colour wherever the mask is set.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_box(box, ax):
    """Outline a corner-format box ``(x_min, y_min, x_max, y_max)`` on ``ax`` in green."""
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill: only the border is visible
        lw=2,
    )
    ax.add_patch(outline)

In [None]:
def image_preprocess(image_path, generator=None):
  """Load an image from disk and segment it.

  Args:
    image_path: Path handed to ``cv2.imread``.
    generator: Optional object with a ``generate(rgb_array)`` method. When
      omitted, the module-level ``mask_generator`` is used if one exists;
      otherwise a ``SamAutomaticMaskGenerator`` is built from the
      module-level ``sam``.

  Returns:
    ``(image, image_rgb, masks)``: the array as loaded by OpenCV, its
    channel-swapped (RGB-ordered) version, and whatever
    ``generator.generate`` returned for it.

  Raises:
    ValueError: If the file cannot be read.
  """
  image = cv2.imread(image_path)
  if image is None:
    # cv2.imread yields None for a missing/unreadable file; fail here with a
    # clear message instead of inside cvtColor.
    raise ValueError(f"Failed to load image: {image_path}")
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

  # A bare reference to `mask_generator` raises NameError when no cell has
  # defined it at module scope, so look it up defensively and fall back.
  if generator is None:
    generator = globals().get("mask_generator")
  if generator is None:
    generator = SamAutomaticMaskGenerator(sam)

  masks = generator.generate(image_rgb)
  return image, image_rgb, masks







In [None]:
def new_image_preprocess(image, generator=None):
  """Convert an in-memory image to an RGB array and segment it.

  Args:
    image: A PIL image (already RGB-ordered) or, as before, an array-like in
      OpenCV's BGR channel order.
    generator: Optional object with a ``generate(rgb_array)`` method. When
      omitted, the module-level ``mask_generator`` is used if one exists;
      otherwise a ``SamAutomaticMaskGenerator`` is built from the
      module-level ``sam``.

  Returns:
    ``(image, image_rgb, masks)``: the input unchanged, its RGB array, and
    whatever ``generator.generate`` returned for that array.
  """
  if isinstance(image, Image.Image):
    # PIL pixels are already RGB; running BGR2RGB on them would swap red and
    # blue and hand colour-inverted data to the models. convert("RGB") only
    # normalises the mode (e.g. drops alpha, expands grayscale).
    image_rgb = np.array(image.convert("RGB"))
  else:
    # Non-PIL input keeps the original contract: treated as a BGR array.
    image_rgb = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)

  # A bare reference to `mask_generator` raises NameError when no cell has
  # defined it at module scope, so look it up defensively and fall back.
  if generator is None:
    generator = globals().get("mask_generator")
  if generator is None:
    generator = SamAutomaticMaskGenerator(sam)

  masks = generator.generate(image_rgb)
  return image, image_rgb, masks

In [None]:
def generate_image(image, masks):
  """Plot ``image`` next to a copy annotated with the given SAM masks.

  Args:
    image: Source image; must support ``.copy()``.
    masks: SAM result passed to ``sv.Detections.from_sam``.
  """
  detections = sv.Detections.from_sam(sam_result=masks)
  annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
  # Annotate a copy so the image shown as "source" is the one passed in.
  segmented = annotator.annotate(scene=image.copy(), detections=detections)

  panels = [image, segmented]
  captions = ['source image', 'segmented image']
  sv.plot_images_grid(images=panels, grid_size=(1, 2), titles=captions)


In [None]:
def generate_result(image, min_confidence=0.10, min_size=10):
  """Segment an image with SAM and tag each segment with its best CLIP class.

  Args:
    image: A PIL image (already RGB-ordered) or, as before, an array-like in
      OpenCV's BGR channel order.
    min_confidence: Predictions whose softmax score is not above this value
      are dropped (default 0.10).
    min_size: Crops narrower or shorter than this many pixels are skipped
      (default 10).

  Returns:
    A list of ``((x, y, w, h), label_text, color)`` tuples, one per kept
    segment. ``label_text`` is ``"<prompt> (<score>)"`` and ``color`` is an
    RGB tuple shared by all segments with the same label.
  """
  # PIL pixels are already RGB; running BGR2RGB on them would swap red and
  # blue, so only array input (treated as BGR) goes through cvtColor.
  if isinstance(image, Image.Image):
    image_rgb = np.array(image.convert("RGB"))
  else:
    image_rgb = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)

  # Segment with the generator built here. Delegating to a helper that reads a
  # module-level `mask_generator` fails with NameError when that global was
  # never defined, and would leave this local generator unused.
  mask_generator = SamAutomaticMaskGenerator(sam)
  masks = mask_generator.generate(image_rgb)

  # One text prompt per class name; the prompt string is also used as the label.
  # NOTE(review): labels therefore carry the "a photo of a " prefix — confirm
  # whether the UI should show the bare class name instead.
  with open("imagenet_classes.txt") as f:
    class_names = [f"a photo of a {line.strip()}" for line in f]
  text_inputs = clip.tokenize(class_names).to(device)
  with torch.no_grad():
    text_features = clip_model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

  result = []
  label_to_color = {}
  for mask in masks:
    # bbox is x, y, width, height; cast so the values are valid slice indices.
    x0, y0, w, h = (int(v) for v in mask["bbox"])
    x1, y1 = x0 + w, y0 + h

    crop = image_rgb[y0:y1, x0:x1]
    if crop.shape[0] < min_size or crop.shape[1] < min_size:
      continue

    image_input = clip_preprocess(Image.fromarray(crop)).unsqueeze(0).to(device)
    with torch.no_grad():
      image_features = clip_model.encode_image(image_input)
      image_features /= image_features.norm(dim=-1, keepdim=True)
      # Scaled cosine similarity against every prompt, turned into probabilities.
      similarity = (100.0 * image_features @ text_features.T)
      probs = similarity.softmax(dim=-1).cpu().numpy()

    pred_idx = int(probs[0].argmax())
    confidence = float(probs[0][pred_idx])
    if confidence <= min_confidence:
      continue

    label = class_names[pred_idx]
    if label not in label_to_color:
      label_to_color[label] = tuple(random.randint(0, 255) for _ in range(3))
    color = label_to_color[label]

    label_text = f"{label} ({confidence:.2f})"
    result.append(((x0, y0, w, h), label_text, color))

  return result


In [None]:
def segment_clip(image):
  """Gradio callback: draw a labelled box for every tagged segment of ``image``.

  Args:
    image: PIL image from the Gradio input, or ``None`` when nothing was
      uploaded.

  Returns:
    An RGB-converted image with boxes and labels drawn on it, or ``None``
    when ``image`` is ``None``.
  """
  if image is None:
    # Submitting without an image hands the callback None; return early
    # instead of failing on image.convert below.
    return None

  result = generate_result(image)
  annotated_img = image.convert("RGB")
  draw = ImageDraw.Draw(annotated_img)

  for (x0, y0, w, h), title, color in result:
    x1, y1 = x0 + w, y0 + h
    draw.rectangle([x0, y0, x1, y1], outline=color, width=3)

    # Measured at the origin, so the right edge doubles as the text width.
    _, text_top, text_right, text_bottom = draw.textbbox((0, 0), title)
    label_height = text_bottom - text_top

    # Right-align the label just inside the top edge of the box.
    label_pos = (x1 - text_right - 4, y0 + 4)
    draw.rectangle([label_pos, (x1, y0 + label_height + 8)], fill=color)
    draw.text(label_pos, title, fill="white")

  return annotated_img




In [None]:
# Gradio UI: a PIL image goes in, segment_clip's annotated PIL image comes out.
demo = gr.Interface(
    fn=segment_clip,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Multi-Modal Product Tagger (CLIP + SAM)",
)
demo.launch(debug=True)
