In [15]:
import json
import os
import random
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
from diffusers import StableDiffusionInpaintPipeline
from groundingdino.util.inference import load_model, load_image, predict, annotate
from GroundingDINO.groundingdino.util import box_ops
import argparse
from tqdm import tqdm

In [16]:
# Read the caption file into a dict; its keys are later joined onto the image directory.
captions_path = '/content/drive/Shareddrives/FashionXchange/captions.json'
with open(captions_path) as captions_file:
    cloth_captions = json.load(captions_file)

In [35]:
def find_occurrences(sentence, word_list):
    """
    Collect the entries of ``word_list`` that appear inside ``sentence``.

    Matching is a case-insensitive *substring* test, so an entry also counts
    when it is embedded in a longer word (e.g. "top" inside "stop", or
    "shirt" inside "T-shirt").

    Args:
        sentence (str): Text to search.
        word_list (Iterable[str]): Candidate words or phrases.

    Returns:
        list: Matching entries, in ``word_list`` order and original casing.
    """
    return [
        candidate
        for candidate in word_list
        if candidate.lower() in sentence.lower()
    ]


# Garment vocabulary, grouped by the body region each word is paired with in a prompt.
upper_body_words = ["top", "shirt", "upper clothing", "T-shirt", "sweater"]
lower_body_words = ["trousers", "pants", "shorts", "lower clothing"]
word_list = upper_body_words + lower_body_words


def _mask_prompt(garment):
    """Return ``garment`` with the body-region suffix that belongs to it."""
    if garment in upper_body_words:
        return garment + ", arms"
    if garment in lower_body_words:
        return garment + ", legs"
    # Anything outside both groups (here: the catch-all "clothes") covers both regions.
    return garment + ", arms and legs"


# One list of mask prompts per caption key: every garment word found in the
# caption, followed by the catch-all "clothes" prompt.
image_masks_list = {
    key: [
        _mask_prompt(garment)
        for garment in find_occurrences(value, word_list) + ["clothes"]
    ]
    for key, value in cloth_captions.items()
}

# Persist the prompts so the mask-generation cell can reload them.
with open("/content/drive/Shareddrives/FashionXchange/image_mask_list.json", "w") as json_file:
    json.dump(image_masks_list, json_file)

In [18]:
# --- Checkpoint / config locations ---
sam_checkpoint_path = "/content/sam_vit_h_4b8939.pth"
groundingdino_model_path = "/content/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
groundingdino_weights_path = "/content/groundingdino_swint_ogc.pth"

# Device the SAM model is moved to; later cells also move box tensors onto it.
device = "cuda"

# --- Segment Anything ---
# model_type is the registry key that selects which SAM constructor is used.
model_type = "vit_h"
build_sam = sam_model_registry[model_type]
sam_model = build_sam(checkpoint=sam_checkpoint_path)
sam_model = sam_model.to(device=device)
sam_predictor = SamPredictor(sam_model)

# --- Grounding DINO ---
groundingdino_model = load_model(groundingdino_model_path, groundingdino_weights_path)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


final text_encoder_type: bert-base-uncased


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [23]:
def show_mask(mask, image, random_color=True):
    """
    Blend a segmentation mask over an image as a translucent colour layer.

    Args:
        mask (torch.Tensor): Mask whose last two dimensions are (height, width).
        image (np.ndarray): Image the mask is drawn on top of.
        random_color (bool, optional): When True, tint with a random RGB colour at
                                      alpha 0.8; when False, tint with a fixed
                                      colour at alpha 0.6. Default is True.

    Returns:
        np.ndarray: RGBA array of the image composited with the tinted mask.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) so every pixel gets an RGBA value.
    tinted = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)

    base_layer = Image.fromarray(image).convert("RGBA")
    overlay_u8 = (tinted.cpu().numpy() * 255).astype(np.uint8)
    overlay_layer = Image.fromarray(overlay_u8).convert("RGBA")

    composited = Image.alpha_composite(base_layer, overlay_layer)
    return np.array(composited)

def transform_boxes(predictor, boxes, src, device):
    """
    Scale boxes to the source image size and map them through the predictor's transform.

    Args:
        predictor: Object exposing ``transform.apply_boxes_torch`` (the notebook
            passes its ``SamPredictor``).
        boxes (torch.Tensor): Bounding boxes in the format [x_center, y_center, width, height];
            presumably normalised to the image size, since they are multiplied by it here.
        src (np.ndarray): Source image; its shape must unpack into (height, width, channels).
        device: Target handed to ``.to()`` for the returned tensor.

    Returns:
        torch.Tensor: Transformed boxes on ``device``.
    """
    height, width, _ = src.shape
    corner_boxes = box_ops.box_cxcywh_to_xyxy(boxes)
    # Multiply (x0, y0, x1, y1) by (W, H, W, H) to express the corners in pixels.
    pixel_boxes = corner_boxes * torch.Tensor([width, height, width, height])
    resized_boxes = predictor.transform.apply_boxes_torch(pixel_boxes, src.shape[:2])
    return resized_boxes.to(device)

def save_image(image, file_path):
    """
    Write an image to disk and report the outcome on stdout.

    Failures are not propagated: any exception raised while saving is caught
    and printed instead.

    Args:
        image (PIL.Image.Image): Image to be saved (anything with a ``save(path)`` method).
        file_path (str): Destination path.
    """
    try:
        image.save(file_path)
        outcome = f"Image saved: {file_path}"
    except Exception as err:
        outcome = f"Error saving image to {file_path}: {err}"
    print(outcome)

In [24]:
def edit_image(path, item, prompt, box_threshold, text_threshold):
    """
    Edit an image by replacing objects using segmentation and inpainting.

    Grounding DINO locates ``item``, SAM turns the detected boxes into masks, and
    the mask of the first box is handed to the inpainting ``pipeline`` together
    with ``prompt``.

    Relies on the notebook-level ``groundingdino_model``, ``sam_predictor``,
    ``device`` and ``pipeline`` objects. ``pipeline`` is not created in the cells
    shown alongside this function and must exist before it is called.

    Args:
        path (str): Path to the image file.
        item (str): Object to be recognized in the image.
        prompt (str): Object to replace the selected object in the image.
        box_threshold (float): Threshold for bounding box predictions.
        text_threshold (float): Threshold for text predictions.

    Returns:
        The first entry of the pipeline result's ``images`` — presumably a
        512x512 ``PIL.Image.Image``; confirm against how ``pipeline`` is built.
    """
    src, img = load_image(path)

    # Only the boxes are needed; logits and phrases were used solely for a
    # debug overlay that was never displayed or returned.
    boxes, _, _ = predict(
        model=groundingdino_model,
        image=img,
        caption=item,
        box_threshold=box_threshold,
        text_threshold=text_threshold
    )

    # Set up predictor
    sam_predictor.set_image(src)
    new_boxes = transform_boxes(sam_predictor, boxes, src, device)

    # One mask per detected box
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=new_boxes,
        multimask_output=False,
    )

    # Only the first box's mask is inpainted.
    # NOTE(review): with zero detections this indexing presumably fails — confirm
    # and decide how callers should handle that case.
    first_mask = masks[0][0].cpu().numpy()

    # Apply inpainting pipeline (image and mask are both resized to 512x512)
    edited_image = pipeline(
        prompt=prompt,
        image=Image.fromarray(src).resize((512, 512)),
        mask_image=Image.fromarray(first_mask).resize((512, 512)),
    ).images[0]

    return edited_image

In [49]:
def get_mask(img_path, looking_for, box_threshold=0.3, text_threshold=0.25):
    """
    Detect ``looking_for`` in an image and return one merged segmentation mask.

    Grounding DINO proposes boxes for the text prompt, SAM predicts one mask per
    box, and the per-box masks are OR-ed together along the box dimension.

    A 1x3 matplotlib figure is created on every call (source image, DINO
    detections, and a third panel that is currently left unused). The figure is
    not closed here, so callers that invoke this repeatedly should close figures
    themselves.

    Relies on the notebook-level ``groundingdino_model``, ``sam_predictor`` and
    ``device`` objects.

    Args:
        img_path (str): Path of the image to load.
        looking_for (str): Text prompt passed to Grounding DINO as the caption.
        box_threshold (float, optional): Threshold for bounding box predictions.
        text_threshold (float, optional): Threshold for text predictions.

    Returns:
        tuple: ``(masks, src)`` where ``masks`` is the boolean union of all
        per-box masks with a leading dimension of size 1, and ``src`` is the
        source image returned by ``load_image``.
    """
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))

    # Load from the argument, not from a notebook-level variable, so the
    # function works for any caller.
    src, img = load_image(img_path)
    axs[0].imshow(src, cmap='gray')
    axs[0].set_title('Source Image')
    axs[0].axis('off')

    boxes, logits, phrases = predict(
        model=groundingdino_model,
        image=img,
        caption=looking_for,
        box_threshold=box_threshold,
        text_threshold=text_threshold
    )
    annotated_frame = annotate(image_source=src, boxes=boxes, logits=logits, phrases=phrases)
    # Reverse the channel axis before display (annotate's output is presumably BGR).
    axs[1].imshow(np.flip(annotated_frame, 2), cmap='gray')
    axs[1].set_title('DINO o/p')
    axs[1].axis('off')

    sam_predictor.set_image(src)
    new_boxes = transform_boxes(sam_predictor, boxes, src, device)
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=new_boxes,
        multimask_output=False,
    )
    # Collapse the per-box masks into a single mask covering every detection.
    masks = torch.any(masks, dim=0, keepdim=True)

    # The third panel (a SAM overlay) is intentionally not drawn at the moment.
    return masks, src

In [41]:
import zipfile
import os

# Archive holding the dataset images, and the folder it is unpacked into.
zip_file_path = "/content/drive/Shareddrives/FashionXchange/images.zip"
extract_to_directory = "/content/drive/Shareddrives/FashionXchange"

# Unpack every member of the archive; the context manager closes it afterwards.
with zipfile.ZipFile(zip_file_path, mode='r') as archive:
    archive.extractall(path=extract_to_directory)

In [None]:
image_dir = "/content/drive/Shareddrives/FashionXchange/images"
masks_dir = "/content/drive/Shareddrives/FashionXchange/Masks"

# Mapping of image file name -> list of text prompts to segment in that image.
with open('/content/drive/Shareddrives/FashionXchange/image_mask_list.json') as json_file:
    image_masks = json.load(json_file)

for key, value in tqdm(image_masks.items()):
    # Keep this notebook-level name stable; other cells may refer to it.
    image_path = os.path.join(image_dir, key)

    # Strip the real extension (".jpg", ".jpeg", ".png", ...) instead of
    # assuming it is exactly four characters long.
    image_stem = os.path.splitext(key)[0]

    # One output folder per image; created once rather than once per mask.
    img_mask_dir = os.path.join(masks_dir, image_stem)
    os.makedirs(img_mask_dir, exist_ok=True)

    # Masks are numbered from 1 in prompt order: <stem>1.jpg, <stem>2.jpg, ...
    for cnt, looking_for in enumerate(value, start=1):
        gen_mask, _ = get_mask(image_path, looking_for, box_threshold=0.3, text_threshold=0.25)

        # get_mask opens a new matplotlib figure on every call. Close it right
        # away so a full-dataset run does not accumulate thousands of open
        # figures (this also means the figures are not rendered for this cell).
        plt.close("all")

        gen_mask = np.squeeze(gen_mask.detach().cpu().numpy())
        mask_file_path = os.path.join(img_mask_dir, f"{image_stem}{cnt}.jpg")
        # NOTE(review): JPEG is lossy, so saved masks may not stay strictly
        # binary — consider PNG if downstream code can accept it.
        Image.fromarray(gen_mask).save(mask_file_path)




