# Imports

In [3]:
import io
import gc
import os
import sys
import cv2
import torch
import requests
import tempfile
import contextlib
import numpy as np
from PIL import Image
from typing import List
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from transformers import pipeline
from torchvision import transforms
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForCausalLM 
from detector_utils import adapt_mmdet_pipeline, init_detector, process_images_detector
from classes_and_palettes import (
    COCO_KPTS_COLORS,
    COCO_WHOLEBODY_KPTS_COLORS,
    GOLIATH_KPTS_COLORS,
    GOLIATH_SKELETON_INFO,
    GOLIATH_KEYPOINTS,
    GOLIATH_PALETTE, 
    GOLIATH_CLASSES
)

In [4]:
# np.set_printoptions(threshold=sys.maxsize)

# Helpers

In [5]:
"""
    SAM2 Plotting 
"""
# Seed NumPy's global RNG: show_mask(random_color=True) draws its overlay
# colour from np.random, so a fixed seed keeps those colours repeatable.
np.random.seed(3)

def show_mask(mask, ax, random_color=False, borders=True):
    """Draw a translucent mask overlay (optionally outlined) on ``ax``.

    Args:
        mask: Array whose last two axes are (height, width). Any leading axes
            must be singletons, because the mask is reshaped to (height, width).
        ax: Object exposing ``imshow`` (a matplotlib axes in this notebook).
        random_color: When true, pick a random RGB colour from NumPy's global
            RNG; otherwise use a fixed blue. Alpha is 0.6 in both cases.
        borders: When true, trace the external contours of the mask with
            OpenCV and draw them on top of the overlay.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])

    h, w = mask.shape[-2:]
    # Collapse leading singleton axes up front: cv2.findContours needs a 2-D
    # single-channel image, and a (1, H, W) mask would otherwise reach it as-is.
    mask = mask.astype(np.uint8).reshape(h, w)

    # Broadcast the RGBA colour over every mask pixel -> (H, W, 4) float image.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    if borders:
        import cv2
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    ``coords`` is indexed as ``coords[mask]`` and then ``[:, 0]`` / ``[:, 1]``,
    so it must be a 2-D array with x in column 0 and y in column 1; ``labels``
    is compared element-wise against 1 and 0.
    """
    # Positive points are drawn first, then negative ones.
    for label_value, star_color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        xs = selected[:, 0]
        ys = selected[:, 1]
        ax.scatter(xs, ys, color=star_color, marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    """Outline ``box`` (indexed as x0, y0, x1, y1) on ``ax`` with a green, unfilled rectangle."""
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top), width, height,
        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,
    )
    ax.add_patch(outline)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Show one 10x10 figure per (mask, score) pair on top of ``image``.

    Optional prompt points (with their labels) and a prompt box are drawn on
    every figure. A "Mask i, Score" title is added only when more than one
    score is given.
    """
    show_title = len(scores) > 1
    for number, (mask, score) in enumerate(zip(masks, scores), start=1):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)

        if point_coords is not None:
            # Points are meaningless without their positive/negative labels.
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())

        if box_coords is not None:
            show_box(box_coords, plt.gca())

        if show_title:
            plt.title(f"Mask {number}, Score: {score:.3f}", fontsize=18)

        plt.axis('off')
        plt.show()
        
"""
    Sapiens Foundation Models
"""
class ConfigSapiens:
    """Filesystem locations of the Sapiens assets used by this notebook.

    Every path is derived from ``ASSETS_DIR`` so the machine-specific prefix is
    written only once. Set the ``SAPIENS_ASSETS_DIR`` environment variable
    before this class is defined to point at a different checkout; without it
    the values are identical to the previously hard-coded ones.
    """

    # Root directory holding the checkpoints and the detector config.
    ASSETS_DIR = os.environ.get(
        "SAPIENS_ASSETS_DIR",
        "/home/ilyass/workspace/tamp_warm_start/notebooks/sapiens/assets/",
    )
    # Base directory that the relative CHECKPOINTS entries are joined onto.
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints/")
    # Short model key -> checkpoint path relative to CHECKPOINTS_DIR.
    CHECKPOINTS = {
        # Keys ending in "p" live under pose/.
        "0.3bp": "pose/sapiens_0.3b_goliath_best_goliath_AP_575_torchscript.pt2",
        "1bp": "pose/sapiens_1b_goliath_best_goliath_AP_640_torchscript.pt2",

        # Keys ending in "s" live under segmentation/.
        "0.3s": "segmentation/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
        "0.6s": "segmentation/sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
        "1bs": "segmentation/sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    }
    # Weights and config passed to init_detector() for person detection.
    DETECTION_CHECKPOINT = os.path.join(
        CHECKPOINTS_DIR, "pose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth"
    )
    DETECTION_CONFIG = os.path.join(
        ASSETS_DIR, "assets_rtmdet_m_640-8xb32_coco-person_no_nms.py"
    )
    
class ModelManager:
    """Stateless helpers for loading and running the TorchScript Sapiens models."""

    @staticmethod
    def load_model(checkpoint_name: str):
        """Load a TorchScript checkpoint onto the notebook-wide ``device``.

        Args:
            checkpoint_name: Path relative to ``ConfigSapiens.CHECKPOINTS_DIR``,
                or ``None``.

        Returns:
            The module in eval mode on ``device``, or ``None`` when
            ``checkpoint_name`` is ``None``.
        """
        if checkpoint_name is None:
            return None

        checkpoint_path = os.path.join(ConfigSapiens.CHECKPOINTS_DIR, checkpoint_name)
        # map_location restores the weights directly onto the target device
        # instead of the device they were saved from, so the checkpoint can
        # also be opened when that original device is not available.
        model = torch.jit.load(checkpoint_path, map_location=device)
        model.eval()
        model.to(device)

        return model

    @staticmethod
    @torch.inference_mode()
    def run_model_keypoints(model, input_tensor):
        """Run ``model`` on ``input_tensor`` under inference mode and return its raw output."""
        return model(input_tensor)

    @staticmethod
    @torch.inference_mode()
    def run_model_segmentation(model, input_tensor, height, width):
        """Run ``model`` and return the per-pixel arg-max over dimension 1.

        The model output is bilinearly resized to ``(height, width)`` before
        the arg-max, so the result has that spatial size.
        """
        output = model(input_tensor)
        output = torch.nn.functional.interpolate(
            output, size=(height, width), mode="bilinear", align_corners=False
        )
        # Index of the highest-scoring entry along dim 1 for every pixel.
        return output.argmax(dim=1)

class ImageProcessor:
    """Person detection, pose estimation and segmentation on PIL images.

    Constructing an instance builds the preprocessing transform and loads the
    person detector; the pose and segmentation models are loaded on demand by
    the corresponding methods.
    """

    def __init__(self):
        # Preprocessing: resize to 1024x768, convert to a tensor, then
        # normalise each channel with the mean/std below (given on a 0-255
        # scale and divided by 255 to match ToTensor's output range).
        self.transform = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[123.5/255, 116.5/255, 103.5/255],
                                 std=[58.5/255, 57.0/255, 57.5/255])
        ])
        # The detector is created on the CPU.
        self.detector = init_detector(
            ConfigSapiens.DETECTION_CONFIG, ConfigSapiens.DETECTION_CHECKPOINT, device='cpu'
        )
        self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)

    def detect_persons(self, image: Image.Image):
        """Run the person detector on ``image`` and return the filtered bboxes."""
        # The detector helper takes a batch, so wrap the image array in a
        # leading batch axis of size one.
        image = np.array(image)
        image = np.expand_dims(image, axis=0)

        # Perform person detection
        bboxes_batch = process_images_detector(
            image,
            self.detector
        )
        # Only one image was submitted, so only the first result is relevant.
        return self.get_person_bboxes(bboxes_batch[0])

    def get_person_bboxes(self, bboxes_batch, score_thr=0.3):
        """Keep scored boxes above ``score_thr``; give unscored boxes a score of 1.0.

        Args:
            bboxes_batch: Iterable of boxes, each ``[x1, y1, x2, y2, score]``
                or ``[x1, y1, x2, y2]``. Boxes of any other length are dropped.
            score_thr: Exclusive lower bound on the score of 5-element boxes.

        Returns:
            List of 5-element boxes.
        """
        person_bboxes = []
        for bbox in bboxes_batch:
            if len(bbox) == 5:  # [x1, y1, x2, y2, score]
                if bbox[4] > score_thr:
                    person_bboxes.append(bbox)
            elif len(bbox) == 4:  # [x1, y1, x2, y2]
                # list(...) first: with a NumPy row, ``bbox + [1.0]`` would add
                # 1.0 to every coordinate instead of appending a score.
                person_bboxes.append(list(bbox) + [1.0])
        return person_bboxes

    @torch.inference_mode()
    def estimate_pose(self, image: Image.Image, bboxes: List[List[float]], model_name: str, kpt_threshold: float):
        """Estimate and draw keypoints for every bbox in ``bboxes``.

        Args:
            image: Source image; it is not modified.
            bboxes: Person boxes in ``image`` pixel coordinates.
            model_name: Key into ``ConfigSapiens.CHECKPOINTS``.
            kpt_threshold: Minimum confidence for a keypoint to be drawn.

        Returns:
            ``(annotated_image, all_keypoints)`` where ``all_keypoints`` holds
            one keypoint dict per bbox, in the order of ``bboxes``.
        """
        result_image = image.copy()
        all_keypoints = []  # One keypoint dict per detected person

        # Nothing to do: avoid loading the pose checkpoint at all.
        if len(bboxes) == 0:
            return result_image, all_keypoints

        pose_model = ModelManager.load_model(ConfigSapiens.CHECKPOINTS[model_name])

        for bbox in bboxes:
            # Crop from the untouched input, not from result_image, so the
            # overlays drawn for earlier persons never reach the model.
            cropped_img = self.crop_image(image, bbox)
            input_tensor = self.transform(cropped_img).unsqueeze(0).to(device)
            heatmaps = ModelManager.run_model_keypoints(pose_model, input_tensor)
            person_heatmaps = heatmaps[0].cpu().numpy()
            keypoints = self.heatmaps_to_keypoints(person_heatmaps)
            all_keypoints.append(keypoints)
            result_image = self.draw_keypoints(
                result_image, keypoints, bbox, kpt_threshold,
                # Keypoints are heat-map indices; pass the real (width, height)
                # so they are rescaled correctly whatever the model emits.
                heatmap_size=(person_heatmaps.shape[-1], person_heatmaps.shape[-2]),
            )

        return result_image, all_keypoints

    def process_image_keypoints(self, image: Image.Image, model_name: str, kpt_threshold: float):
        """Detect persons in ``image`` and estimate their keypoints.

        ``kpt_threshold`` is passed through ``float()``, so a numeric string is
        accepted as well. Returns ``(annotated_image, keypoints_per_person)``.
        """
        bboxes = self.detect_persons(image)
        result_image, keypoints = self.estimate_pose(image, bboxes, model_name, float(kpt_threshold))
        return result_image, keypoints

    def process_image_segmentation(self, image: Image.Image, model_name: str):
        """Segment ``image`` and save the class mask to a temporary ``.npy`` file.

        Args:
            image: Input image.
            model_name: Key into ``ConfigSapiens.CHECKPOINTS``.

        Returns:
            ``(blended_image, npy_path)``: the colour overlay and the path of
            the saved mask. The file is not deleted by this method.
        """
        model = ModelManager.load_model(ConfigSapiens.CHECKPOINTS[model_name])
        # Same device the model was moved to by load_model (was hard-coded "cuda").
        input_tensor = self.transform(image).unsqueeze(0).to(device)

        preds = ModelManager.run_model_segmentation(model, input_tensor, image.height, image.width)
        mask = preds.squeeze(0).cpu().numpy()

        # Visualize the segmentation
        blended_image = self.visualize_pred_with_overlay(image, mask)

        # Create the file atomically instead of using the deprecated,
        # race-prone tempfile.mktemp(); delete=False keeps it for the caller.
        with tempfile.NamedTemporaryFile(suffix='.npy', delete=False) as npy_file:
            np.save(npy_file, mask)
            npy_path = npy_file.name

        return blended_image, npy_path

    @staticmethod
    def _bbox_corners(bbox):
        """Return the integer ``(x1, y1, x2, y2)`` of a bbox with 4 or more elements.

        Raises:
            ValueError: If ``bbox`` has fewer than four elements.
        """
        if len(bbox) < 4:
            raise ValueError(f"Unexpected bbox format: {bbox}")
        x1, y1, x2, y2 = map(int, bbox[:4])
        return x1, y1, x2, y2

    def crop_image(self, image, bbox):
        """Return the region of ``image`` covered by ``bbox`` (any trailing score is ignored)."""
        return image.crop(self._bbox_corners(bbox))

    @staticmethod
    def heatmaps_to_keypoints(heatmaps):
        """Convert per-joint heat maps into ``{name: (x, y, confidence)}``.

        Each joint's location is the arg-max of its heat map and its
        confidence is the value there; coordinates are heat-map indices.
        Names come from ``GOLIATH_KEYPOINTS`` in order, and pairing stops at
        whichever of the two sequences is shorter.
        """
        keypoints = {}
        for name, heatmap in zip(GOLIATH_KEYPOINTS, heatmaps):
            y, x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
            keypoints[name] = (float(x), float(y), float(heatmap[y, x]))
        return keypoints

    @staticmethod
    def draw_keypoints(image, keypoints, bbox, kpt_threshold, heatmap_size=(192, 256)):
        """Draw the bbox, confident keypoints and skeleton links on a copy of ``image``.

        Args:
            image: Image to annotate (converted to an array; not modified).
            keypoints: ``{name: (x, y, confidence)}`` in heat-map coordinates.
            bbox: Box the keypoints belong to, in image pixel coordinates.
            kpt_threshold: Exclusive lower bound on confidence for drawing.
            heatmap_size: ``(width, height)`` of the heat map the keypoints
                were read from, used to rescale them into the bbox.

        Returns:
            A new PIL image with the annotations.

        Raises:
            ValueError: If ``bbox`` has fewer than four elements.
        """
        image = np.array(image)
        x1, y1, x2, y2 = ImageProcessor._bbox_corners(bbox)
        heatmap_w, heatmap_h = heatmap_size

        # Calculate adaptive radius and thickness based on bounding box size.
        bbox_width = x2 - x1
        bbox_height = y2 - y1
        # Clamp so a degenerate (inverted) box cannot produce NaN below.
        bbox_size = np.sqrt(max(bbox_width * bbox_height, 0))

        radius = max(1, int(bbox_size * 0.006))  # minimum 1 pixel
        thickness = max(1, int(bbox_size * 0.006))  # minimum 1 pixel
        bbox_thickness = max(1, thickness // 4)

        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), bbox_thickness)

        def to_image_coords(x, y):
            # Scale heat-map indices to the bbox extent, then offset by its corner.
            return (int(x * bbox_width / heatmap_w) + x1,
                    int(y * bbox_height / heatmap_h) + y1)

        # Draw keypoints; the i-th keypoint uses the i-th palette colour and
        # keypoints beyond the palette length are skipped.
        for (x, y, conf), color in zip(keypoints.values(), GOLIATH_KPTS_COLORS):
            if conf > kpt_threshold:
                cv2.circle(image, to_image_coords(x, y), radius, color, -1)

        # Draw skeleton links whose two endpoints are both confident.
        for link_info in GOLIATH_SKELETON_INFO.values():
            pt1_name, pt2_name = link_info['link']
            if pt1_name not in keypoints or pt2_name not in keypoints:
                continue
            pt1 = keypoints[pt1_name]
            pt2 = keypoints[pt2_name]
            if pt1[2] > kpt_threshold and pt2[2] > kpt_threshold:
                cv2.line(
                    image,
                    to_image_coords(pt1[0], pt1[1]),
                    to_image_coords(pt2[0], pt2[1]),
                    link_info['color'],
                    thickness=thickness,
                )

        return Image.fromarray(image)

    @staticmethod
    def visualize_pred_with_overlay(img, sem_seg, alpha=0.5):
        """Blend ``img`` with a colour-coded rendering of ``sem_seg``.

        Args:
            img: PIL image; converted to RGB before blending.
            sem_seg: Integer class map with the same height and width as ``img``.
            alpha: Weight of the colour overlay in the blend.

        Returns:
            The blended PIL image. Pixels whose class id is not below
            ``len(GOLIATH_CLASSES)`` get a black overlay.
        """
        img_np = np.array(img.convert("RGB"))
        sem_seg = np.array(sem_seg)

        num_classes = len(GOLIATH_CLASSES)
        overlay = np.zeros((*sem_seg.shape, 3), dtype=np.uint8)

        # Colour each class that is present; ids outside the palette are left black.
        for label in np.unique(sem_seg):
            if label < num_classes:
                overlay[sem_seg == label, :] = GOLIATH_PALETTE[int(label)]

        blended = np.uint8(img_np * (1 - alpha) + overlay * alpha)
        return Image.fromarray(blended)

In [6]:
"""
    Florence VLM
"""
def run_example(task_prompt, image, processor, model, text_input=None):
    """Run one Florence task on ``image`` and return the post-processed answer.

    Args:
        task_prompt: Task token, e.g. ``'<OPEN_VOCABULARY_DETECTION>'``.
        image: Image exposing ``width`` and ``height`` (a PIL image here).
        processor: Processor used to build the inputs and decode the output.
        model: Generation model; inputs are moved to its device and dtype.
        text_input: Optional text appended directly to ``task_prompt``.

    Returns:
        Whatever ``processor.post_process_generation`` returns for the task.
    """
    prompt = task_prompt if text_input is None else task_prompt + text_input

    # Follow the model's own placement and precision instead of assuming
    # CUDA/float16, so inputs always match however the model was loaded.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, model.dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    # Special tokens are kept in the decoded text that is handed to the
    # processor's post-processing step.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )

    return parsed_answer

def convert_to_od_format(data):
    """
    Reduce a detection result to the ``{'bboxes', 'labels'}`` layout used for plotting.

    Parameters:
    - data: Mapping that may contain 'bboxes' and 'bboxes_labels'; any other
      keys (such as 'polygons' or 'polygons_labels') are ignored.

    Returns:
    - A dictionary with 'bboxes' taken from ``data['bboxes']`` and 'labels'
      taken from ``data['bboxes_labels']``; a missing key yields an empty list.
    """
    return {
        'bboxes': data.get('bboxes', []),
        'labels': data.get('bboxes_labels', []),
    }

def plot_bbox(image, data):
    """Show ``image`` with every labelled bounding box from ``data`` drawn on it.

    ``data`` must provide 'bboxes' (each unpacked as x1, y1, x2, y2) and a
    parallel 'labels' sequence.
    """
    _, ax = plt.subplots()
    ax.imshow(image)

    for (x1, y1, x2, y2), label in zip(data['bboxes'], data['labels']):
        # Red, unfilled outline anchored at the top-left corner of the box.
        box_width = x2 - x1
        box_height = y2 - y1
        outline = patches.Rectangle(
            (x1, y1), box_width, box_height,
            linewidth=1, edgecolor='r', facecolor='none',
        )
        ax.add_patch(outline)
        # Label text placed at the top-left corner on a translucent red tag.
        plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))

    # Hide ticks and frame, then render.
    ax.axis('off')
    plt.show()

# Device Definition

In [5]:
# Use the GPU with half precision when CUDA is available; otherwise fall back
# to the CPU with single precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Sapiens Foundation Models

In [7]:
def clean_gpu():
    """Release unused cached CUDA memory.

    Garbage is collected first so that tensors only kept alive by reference
    cycles are freed before the cache is emptied. Does nothing beyond the
    garbage collection when CUDA is unavailable.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [8]:
def sapiens_segmentation(img, model_name="1bs"):
    """Segment ``img`` with a Sapiens model and free the GPU cache afterwards.

    Args:
        img: Input image handed to ``ImageProcessor.process_image_segmentation``.
        model_name: Key into ``ConfigSapiens.CHECKPOINTS``; defaults to the
            previously hard-coded ``"1bs"``.

    Returns:
        ``(segmentation_image, npy_path)`` as produced by
        ``ImageProcessor.process_image_segmentation``.
    """
    image_processor = ImageProcessor()
    try:
        return image_processor.process_image_segmentation(img, model_name)
    finally:
        # Runs on success and on failure, so an error during inference does
        # not leave the cache full.
        clean_gpu()

### Segmentation

In [9]:
# # Plot RGB and segmentation side by side
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# # Plot the first image
# ax1.imshow(rgb_image)
# ax1.set_title('RGB')
# ax1.axis('off')  # Hide axes

# # Plot the second image
# ax2.imshow(segmentation_image)
# ax2.set_title('Segmentation')
# ax2.axis('off')  # Hide axes

# # Adjust the layout and display the plot
# plt.tight_layout()
# plt.show()

### Depth + Removed Segmented Person

In [10]:
# # Read mask information
# mask = np.load(npy_path)

In [11]:
# cv_mask = np.where(mask > 0, 255, 0)
# cv_mask = np.asarray(cv_mask, dtype=np.uint8)

In [12]:
# depth_without_person = cv2.inpaint(np.asarray(depth_image), cv_mask, 40, cv2.INPAINT_TELEA)

In [13]:
# rgb_without_person = cv2.inpaint(np.asarray(rgb_image), cv_mask, 40, cv2.INPAINT_TELEA)

In [14]:
# # Plot the original depth, the person-removed depth and the person-removed RGB side by side
# fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 6))

# # Plot the first image
# ax1.imshow(depth_image)
# ax1.set_title('Original Depth Image')
# ax1.axis('off')  # Hide axes

# # Plot the second image
# ax2.imshow(depth_without_person)
# ax2.set_title('Depth Image with Person Removal')
# ax2.axis('off')  # Hide axes

# # Plot the third image
# ax3.imshow(rgb_without_person)
# ax3.set_title('RGB Image with Person Removal')
# ax3.axis('off')  # Hide axes

# # Adjust the layout and display the plot
# plt.tight_layout()
# plt.show()

### Keypoints

In [15]:
def sapiens_pose(img, model_name="1bp", kpt_threshold=0.3):
    """Detect persons in ``img``, estimate their keypoints and free the GPU cache.

    Args:
        img: Input image handed to ``ImageProcessor.process_image_keypoints``.
        model_name: Key into ``ConfigSapiens.CHECKPOINTS``; defaults to the
            previously hard-coded ``"1bp"``.
        kpt_threshold: Minimum keypoint confidence for drawing; defaults to
            the previously hard-coded ``0.3``.

    Returns:
        ``(keypoints_image, keypoints)`` as produced by
        ``ImageProcessor.process_image_keypoints``.
    """
    image_processor = ImageProcessor()
    try:
        return image_processor.process_image_keypoints(img, model_name, kpt_threshold)
    finally:
        # Runs on success and on failure, so an error during inference does
        # not leave the cache full.
        clean_gpu()

In [16]:
# # Plot RGB and keypoints side by side
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# # Plot the first image
# ax1.imshow(rgb_image)
# ax1.set_title('RGB')
# ax1.axis('off')  # Hide axes

# # Plot the second image
# ax2.imshow(keypoints_image)
# ax2.set_title('Keypoints')
# ax2.axis('off')  # Hide axes

# # Adjust the layout and display the plot
# plt.tight_layout()
# plt.show()

# Retrieval System

#### Matching Photos

In [17]:
# import numpy as np
# from PIL import Image
# from tensorflow.keras.applications import VGG16
# from tensorflow.keras.applications.vgg16 import preprocess_input
# from tensorflow.keras.preprocessing.image import img_to_array
# from sklearn.metrics.pairwise import cosine_similarity
# import os

# # Load pre-trained VGG16 model
# model = VGG16(weights='imagenet', include_top=False, pooling='avg')

# def extract_features(img_path):
#     img = Image.open(img_path).resize((224, 224))
#     img_array = img_to_array(img)
#     img_array = img_array[:,:,:3]
#     img_array = np.expand_dims(img_array, axis=0)
#     img_array = preprocess_input(img_array)
#     features = model.predict(img_array)
#     return features.flatten()

# def find_similar_images(reference_img_path, image_folder, top_n=3):
#     # Extract features for reference image
#     reference_features = extract_features(reference_img_path)
    
#     # Extract features for all images in the folder
#     image_features = []
#     image_paths = []
#     for img_name in os.listdir(image_folder):
#         if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
#             img_path = os.path.join(image_folder, img_name)
#             features = extract_features(img_path)
#             image_features.append(features)
#             image_paths.append(img_path)
    
#     # Calculate similarities
#     similarities = cosine_similarity([reference_features], image_features)[0]
    
#     # Sort and get top N similar images
#     top_indices = similarities.argsort()[-top_n:][::-1]
#     top_similar = [(image_paths[i], similarities[i]) for i in top_indices]
    
#     return top_similar

# # Usage example
# reference_image = '/home/ilyass/reference.png'
# # reference_image = './generated_images/grasp/floor/image0.png'
# image_folder = './generated_images/grasp/floor'
# # image_folder = '/home/ilyass/Pictures/Screenshots'
# similar_images = find_similar_images(reference_image, image_folder)

# for img_path, similarity in similar_images:
#     print(f"Image: {img_path}, Similarity: {similarity}")

# clean_gpu()

# del model
# import gc

# gc.collect()
# torch.cuda.empty_cache()

#### Mask Extraction

In [21]:
def florence_sam2_segmentation(img, text_input="a box"):
    """Locate an object with Florence-2 and segment it with SAM2.

    Florence-2 open-vocabulary detection is run with ``text_input``; the
    centre of the first returned box is then used as a single positive point
    prompt for SAM2.

    Args:
        img: Input image (passed to both the Florence processor and SAM2).
        text_input: Phrase describing the object to find; defaults to the
            previously hard-coded ``"a box"``.

    Returns:
        ``(masks, scores, logits, bbox_results)`` where the first three are
        the outputs of ``predictor.predict`` and ``bbox_results`` is the
        ``{'bboxes', 'labels'}`` dict from ``convert_to_od_format``.

    Raises:
        IndexError: If Florence-2 returns no bounding box.
    """
    model_id = 'microsoft/Florence-2-large'
    task_prompt = '<OPEN_VOCABULARY_DETECTION>'

    # Bound up front so the cleanup in ``finally`` is valid even when one of
    # the loads below fails.
    model = processor = predictor = None
    try:
        # Load Florence model
        model = AutoModelForCausalLM.from_pretrained(
            model_id, trust_remote_code=True, torch_dtype='auto'
        ).eval().to("cuda")
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

        # Load SAM2 model
        predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

        results = run_example(task_prompt, img, processor, model, text_input=text_input)
        bbox_results = convert_to_od_format(results[task_prompt])

        if not bbox_results['bboxes']:
            # Same exception type as the bare [0] lookup, with a usable message.
            raise IndexError(f"Florence-2 returned no bounding box for {text_input!r}")

        # Centre of the first detected box is the single positive prompt.
        first_box = bbox_results['bboxes'][0]
        input_point = np.array([[(first_box[0] + first_box[2]) / 2,
                                 (first_box[1] + first_box[3]) / 2]])
        input_label = np.array([1])

        # NOTE(review): autocast targets "cpu" although the Florence model is
        # placed on CUDA; confirm which device SAM2 actually runs on.
        with torch.inference_mode(), torch.autocast("cpu", dtype=torch.bfloat16):
            predictor.set_image(img)
            masks, scores, logits = predictor.predict(point_coords=input_point,
                                                      point_labels=input_label,
                                                      multimask_output=False)
    finally:
        # Drop the references first, then collect and empty the cache, so the
        # models are released on failure as well as on success.
        del model, processor, predictor
        gc.collect()
        torch.cuda.empty_cache()

    return masks, scores, logits, bbox_results

In [19]:
# # Reference Image
# input_point = None
# input_label = None

# ref_reg_image = Image.open("/home/ilyass/reference.png")
# masks_ref, scores_ref, logits_ref, bbox_results_ref = florence_sam2_segmentation(ref_reg_image)
# show_masks(ref_reg_image, masks_ref, scores_ref, box_coords=bbox_results_ref['bboxes'][0], point_coords=input_point, input_labels=input_label)

In [20]:
# # Matched Image
# input_point = None
# input_label = None

# match_rgb_image = Image.open("./generated_images/grasp/floor/image13.png")
# masks_match, scores_match, logits_match, bbox_results_match = florence_sam2_segmentation(match_rgb_image)
# show_masks(match_rgb_image, masks_match, scores_match, box_coords=bbox_results_match['bboxes'][0], point_coords=input_point, input_labels=input_label)