In [None]:
from typing import Any, List, Dict, Optional, Union, Tuple
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
from dataclasses import dataclass
import numpy as np
import cv2
import torch
from PIL import Image

@dataclass
class BoundingBox:
    """Rectangle stored as four integer coordinates: xmin, ymin, xmax, ymax."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        """Return the coordinates as the list ``[xmin, ymin, xmax, ymax]``."""
        corners = (self.xmin, self.ymin, self.xmax, self.ymax)
        return list(corners)

@dataclass
class DetectionResult:
    """One detection: a score, a text label, a bounding box and an optional mask."""

    score: float
    label: str
    box: BoundingBox
    # np.ndarray is the array type; np.array is only the factory function and
    # is not a valid type for annotations. The mask stays None until assigned.
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        """
        Build a DetectionResult from a detection dictionary.

        Args:
        - detection_dict (dict): Must provide the keys 'score', 'label' and
          'box', where 'box' is itself a mapping with the keys 'xmin', 'ymin',
          'xmax' and 'ymax'.

        Returns:
        - DetectionResult: New instance; `mask` is left at its default (None).

        Raises:
        - KeyError: If any of the keys listed above is missing.
        """
        coords = detection_dict['box']
        return cls(score=detection_dict['score'],
                   label=detection_dict['label'],
                   box=BoundingBox(xmin=coords['xmin'],
                                   ymin=coords['ymin'],
                                   xmax=coords['xmax'],
                                   ymax=coords['ymax']))

def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """
    Extract the outline of the largest region of a mask as a polygon.

    Args:
    - mask (np.ndarray): Mask array; it is cast to uint8 before contour
      extraction, so non-zero pixels (after the cast) count as foreground.

    Returns:
    - list: [x, y] vertices of the external contour with the largest area, or
      an empty list when the mask has no foreground / no contour is found.
    """
    binary = mask.astype(np.uint8)

    # A blank mask has no contours; max() over an empty sequence would raise
    # ValueError, so report "no polygon" explicitly instead.
    if not binary.any():
        return []

    # Find the external contours in the binary mask
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []

    # Keep only the contour with the largest area
    largest_contour = max(contours, key=cv2.contourArea)

    # Flatten the contour array to a plain list of [x, y] vertices
    return largest_contour.reshape(-1, 2).tolist()

def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    """
    Convert a polygon to a segmentation mask.

    Args:
    - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
    - image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
    - np.ndarray: uint8 mask of shape `image_shape` with the polygon filled
      with 255 and everything else 0. An empty polygon yields an all-zero mask.
    """
    # Create an empty mask
    mask = np.zeros(image_shape, dtype=np.uint8)

    # Convert polygon to an array of points
    pts = np.array(polygon, dtype=np.int32)

    # Nothing to draw: an empty point array is not a valid fillPoly input,
    # so hand back the blank mask instead of calling into OpenCV.
    if pts.size == 0:
        return mask

    # Fill the polygon with white color (255)
    cv2.fillPoly(mask, [pts], color=(255,))

    return mask

def get_boxes(results: List["DetectionResult"]) -> List[List[List[float]]]:
    """
    Collect the boxes of several detections into one nested list.

    Args:
    - results (list): Detections; each item must expose `box.xyxy`.

    Returns:
    - list: A single-element outer list wrapping the list of
      [xmin, ymin, xmax, ymax] boxes, i.e. `[[box_0, box_1, ...]]`.
      No detections yields `[[]]`.
    """
    # The extra outer list is kept on purpose: callers pass the result on
    # as-is (presumably one list of boxes per image — confirm against the
    # SAM processor's `input_boxes` contract).
    return [[result.box.xyxy for result in results]]

def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    """
    Collapse a 4-D mask tensor into a list of 2-D uint8 masks.

    Args:
    - masks (torch.Tensor): 4-D tensor. Dim 1 is averaged away and a pixel is
      foreground when that average is > 0 (presumably the layout is
      (num_boxes, num_candidate_masks, height, width) — confirm against the
      segmenter's post-processing output).
    - polygon_refinement (bool): When True, each non-empty mask is replaced by
      the filled polygon of its largest external contour.

    Returns:
    - list: One np.uint8 array per entry along dim 0. Unrefined masks hold
      0/1; refined masks come from `polygon_to_mask`, which fills with 255.
    """
    masks = masks.cpu().float()
    # Move dim 1 last so it can be reduced with a single mean over axis -1.
    masks = masks.permute(0, 2, 3, 1)
    masks = masks.mean(axis=-1)
    masks = (masks > 0).int()
    masks = list(masks.numpy().astype(np.uint8))

    if polygon_refinement:
        for idx, mask in enumerate(masks):
            # An all-zero mask has no contour to trace; leave it untouched
            # rather than failing inside the contour extraction.
            if not mask.any():
                continue
            polygon = mask_to_polygon(mask)
            masks[idx] = polygon_to_mask(polygon, mask.shape)

    return masks

class SegmentationEngine:
    """
    Text-prompted segmentation in two stages.

    A zero-shot object-detection pipeline proposes boxes for the requested
    labels, then a mask-generation model is prompted with those boxes to
    produce one mask per detection.
    """

    def __init__(self):
        """Load the detector pipeline and the segmentation model/processor."""
        # Prefer the GPU when torch can see one, otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.detector_id = "IDEA-Research/grounding-dino-tiny"
        self.object_detector = pipeline(model=self.detector_id,
                                        task="zero-shot-object-detection",
                                        device=self.device)

        self.segmenter_id = "facebook/sam-vit-huge"

        self.segmentator = AutoModelForMaskGeneration.from_pretrained(self.segmenter_id).to(self.device)
        self.seg_processor = AutoProcessor.from_pretrained(self.segmenter_id)

    def detect(
            self,
            image: Image.Image,
            labels: List[str],
            threshold: float = 0.5,
    ) -> List[DetectionResult]:
        """
        Run zero-shot object detection for the given labels.

        Args:
        - image (Image.Image): Image to search.
        - labels (list): Candidate label strings; a trailing "." is appended
          to every label that does not already end with one.
        - threshold (float): Threshold forwarded to the detector pipeline.

        Returns:
        - list: One DetectionResult per detection returned by the pipeline
          (masks not yet populated).
        """
        labels = [label if label.endswith(".") else label + "." for label in labels]

        raw_results = self.object_detector(image, candidate_labels=labels, threshold=threshold)
        return [DetectionResult.from_dict(result) for result in raw_results]

    def segment(
            self,
            image: Image.Image,
            detection_results: List[DetectionResult],
            polygon_refinement: bool = False,
            segmenter_id: Optional[str] = None,
    ) -> List[DetectionResult]:
        """
        Generate a mask for every detection, prompting the segmenter with its box.

        Args:
        - image (Image.Image): Image the detections belong to.
        - detection_results (list): Detections whose boxes are used as prompts.
        - polygon_refinement (bool): Forwarded to `refine_masks`.
        - segmenter_id (str, optional): Currently unused; kept so existing
          callers that pass it keep working.

        Returns:
        - list: The same DetectionResult objects, with `mask` set in place.
          When there are no detections the input is returned unchanged.
        """
        # With nothing detected there are no box prompts; skip the model call
        # instead of handing the processor an empty box list.
        if not detection_results:
            return detection_results

        boxes = get_boxes(detection_results)
        inputs = self.seg_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self.device)

        # Inference only: disabling autograd avoids keeping activations alive
        # for a backward pass that never happens.
        with torch.no_grad():
            outputs = self.segmentator(**inputs)

        # Index 0: only one image is processed per call, so take its entry.
        masks = self.seg_processor.post_process_masks(
            masks=outputs.pred_masks,
            original_sizes=inputs.original_sizes,
            reshaped_input_sizes=inputs.reshaped_input_sizes
        )[0]

        masks = refine_masks(masks, polygon_refinement)

        for detection_result, mask in zip(detection_results, masks):
            detection_result.mask = mask

        return detection_results

    def grounded_segmentation(
        self,
        image: Image.Image,
        labels: List[str],
        threshold: float = 0.3,
        polygon_refinement: bool = False,
    ) -> Tuple[np.ndarray, List[DetectionResult]]:
        """
        Detect the labelled objects in an image and segment each of them.

        Args:
        - image (Image.Image): Input image.
        - labels (list): Label strings to look for.
        - threshold (float): Detection threshold passed to `detect`.
        - polygon_refinement (bool): Passed to `segment`.

        Returns:
        - tuple: (`np.array(image)`, list of DetectionResult with masks). The
          list is empty when nothing was detected.
        """
        detections = self.detect(image, labels, threshold)
        detections = self.segment(image, detections, polygon_refinement)

        return np.array(image), detections
     

In [None]:
# Build the engine once and reuse it: the constructor loads the Grounding DINO
# detection pipeline and the SAM model + processor onto the selected device.
segmentator = SegmentationEngine()

In [None]:
from transformers import pipeline
from glob import glob
from tqdm import tqdm
import os
import json
import matplotlib.pyplot as plt

model_id = "llava-hf/llava-1.5-7b-hf"

labels = ["cat", "dog", "fox"]
idx = 2  # index into `labels`: selects the single class processed by this run
label = labels[idx]

output_dir = f"/job/processed_data/{label}"
# Create the destination up front; otherwise every save below fails and the
# failure is only visible as one printed error per image.
os.makedirs(output_dir, exist_ok=True)

image_list = sorted(glob(f"/job/dataset/{label}/*.jpg"))

# Maps image file stem -> generated description
json_labeled = {}

pipe = pipeline("image-text-to-text", model=model_id, device_map="auto")
prompt = "USER: <image>\nDescribe what you see on this picture?\nASSISTANT:"

for path in tqdm(image_list):
    file_name = os.path.basename(path)
    stem = os.path.splitext(file_name)[0]
    try:
        # Close the file handle as soon as the pixels are converted.
        with Image.open(path) as raw_image:
            rgb_image = raw_image.convert("RGB")

        image, pipe_res = segmentator.grounded_segmentation(rgb_image, labels=[label])

        # No detection means there is no mask to save; skip the image with a
        # clear message instead of tripping over pipe_res[0].
        if not pipe_res:
            print(f"no '{label}' detected, skipping")
            print(path)
            continue

        plt.imsave(f"{output_dir}/{file_name}", image)
        # Only the first detection's mask is stored.
        plt.imsave(f"{output_dir}/{stem}_mask.jpg", pipe_res[0].mask)

        outputs = pipe(Image.fromarray(image), text=prompt, max_new_tokens=200)
        # Keep only the text after the last "ASSISTANT:" marker, i.e. the
        # part of the generated text that follows the prompt.
        json_labeled[stem] = outputs[0]["generated_text"].split("ASSISTANT:")[-1]
    except Exception as e:
        # Best-effort batch job: report the failing file and keep going.
        print(str(e))
        print(path)

with open(f"{output_dir}/annotations.json", "w", encoding="utf-8") as json_file:
    json.dump(json_labeled, json_file, indent=4)