# In [None]:
# 1. Imports
import os
import random
import torch
import numpy as np
from PIL import Image
from pycocotools import mask as coco_mask
from groundingdino.util.inference import load_model, load_image, predict, annotate
from groundingdino.util import box_ops
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from segment_anything.utils.amg import mask_to_rle_pytorch, rle_to_mask, area_from_rle, remove_small_regions


# 2. Constants and Global Variables
# SAM weights file and the matching key into `sam_model_registry`.
SAM_CHECKPOINT = "/scratch/gpfs/eh0560/segment-anything/sam_models/sam_vit_h_4b8939.pth"
MODEL_TYPE = "vit_h"
# GroundingDINO weights and config, both passed to `load_model` in main().
DINO_MODEL_PATH = "/scratch/gpfs/eh0560/GroundingDINO/models/groundingdino_swinb_cogcoor.pth"
DINO_CONFIG_PATH = "/scratch/gpfs/eh0560/GroundingDINO/groundingdino/config/GroundingDINO_SwinB_cfg.py"
# Device the SAM model and the transformed box prompts are moved to.
DEVICE = "cuda"
# Root directory walked recursively by get_images().
IMAGES_DIR = "/scratch/gpfs/RUSTOW/deskewing_datasets/images/cudl_images"
# Caption and the two thresholds handed to GroundingDINO's `predict`.
TEXT_PROMPT = "scanned document"
BOX_THRESHOLD = 0.45
TEXT_THRESHOLD = 0.25


# 3. Utility Functions
def load_sam_model(checkpoint, model_type, device):
    """Build a SAM model and a predictor that wraps it.

    Args:
        checkpoint: Checkpoint path forwarded to the registry builder.
        model_type: Key selecting the builder in ``sam_model_registry``.
        device: Target passed to the model's ``.to()``.

    Returns:
        A ``(model, predictor)`` tuple: the result of ``sam.to(device)`` and a
        ``SamPredictor`` constructed from the model.
    """
    build_sam = sam_model_registry[model_type]
    sam = build_sam(checkpoint=checkpoint)
    # Move the model first, then wrap it; presumably `.to()` moves the module
    # in place (as torch modules do), so the predictor sees the moved model.
    sam_on_device = sam.to(device)
    predictor = SamPredictor(sam)
    return sam_on_device, predictor

def get_images(directory, extensions=('.jpg',)):
    """Recursively collect image file paths under *directory*.

    Args:
        directory: Root directory to traverse with ``os.walk``.
        extensions: A file-name suffix, or an iterable of suffixes, to accept.
            Matching is case-insensitive. Defaults to ``('.jpg',)``.

    Returns:
        A list of paths (each ``dirpath`` joined with the file name), in the
        order ``os.walk`` yields them. Empty when no file name matches.
    """
    # Accept a bare string so callers can write extensions=".png".
    if isinstance(extensions, str):
        extensions = (extensions,)
    # str.endswith takes a tuple, so every suffix is checked in one call.
    suffixes = tuple(ext.lower() for ext in extensions)

    image_paths = []
    for dirpath, _, filenames in os.walk(directory):
        image_paths.extend(
            os.path.join(dirpath, filename)
            for filename in filenames
            if filename.lower().endswith(suffixes)
        )
    return image_paths

def visualize_mask(mask, image, random_color=True):
    """Alpha-composite a coloured overlay of *mask* onto *image*.

    Args:
        mask: Array whose values scale the overlay colour per pixel; a
            trailing channel axis is appended to it here.
        image: Array accepted by ``Image.fromarray``; converted to RGBA.
        random_color: When true, use a random RGB colour with alpha 0.8;
            otherwise use the fixed colour (30, 144, 255) / 255 with alpha 0.6.

    Returns:
        The composited RGBA image as a NumPy array.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])

    # Broadcast the mask against the 4-component colour to get an RGBA layer.
    overlay = mask.reshape(*mask.shape, 1) * rgba.reshape(1, 1, -1)
    overlay_bytes = (overlay * 255).astype(np.uint8)

    base_pil = Image.fromarray(image).convert("RGBA")
    overlay_pil = Image.fromarray(overlay_bytes).convert("RGBA")
    composited = Image.alpha_composite(base_pil, overlay_pil)
    return np.array(composited)

def remove_background(image, mask):
    """Append an alpha channel to *image* that hides everything outside *mask*.

    Args:
        image: Array of shape ``(H, W, C)``; the alpha channel is concatenated
            along its last axis.
        mask: Array of shape ``(H, W)`` or ``(H, W, 1)``. Entries greater than
            zero become opaque (255); all others become transparent (0).

    Returns:
        An array of shape ``(H, W, C + 1)`` whose last channel is the alpha
        mask (computed as ``uint8`` 0/255 values).
    """
    image = np.asarray(image)
    alpha_channel = (np.asarray(mask) > 0).astype(np.uint8) * 255
    # np.concatenate requires equal ranks: a 2-D mask would raise ValueError
    # against a 3-D image, so give it a trailing channel axis first.
    if alpha_channel.ndim == image.ndim - 1:
        alpha_channel = alpha_channel[..., np.newaxis]
    return np.concatenate([image, alpha_channel], axis=-1)


# 4. Main Function
def main():
    """Segment the prompted object in one random image and display the result.

    Steps: load SAM and GroundingDINO, pick a random image found by
    ``get_images(IMAGES_DIR)``, detect boxes for ``TEXT_PROMPT``, prompt SAM
    with those boxes, clean up the first box's mask, then show the image with
    its background made transparent and the image with the mask overlaid.

    Raises:
        FileNotFoundError: If no image is found under ``IMAGES_DIR``.
        RuntimeError: If GroundingDINO returns no box for the chosen image.
    """
    # Load models (only the predictor is needed below).
    _sam_model, sam_predictor = load_sam_model(SAM_CHECKPOINT, MODEL_TYPE, DEVICE)
    # Pass DEVICE explicitly so DINO follows the same setting as SAM instead of
    # silently relying on the library's default.
    dino_model = load_model(DINO_CONFIG_PATH, DINO_MODEL_PATH, device=DEVICE)

    cudl_images = get_images(IMAGES_DIR)
    if not cudl_images:
        # random.choice would otherwise fail with an opaque IndexError.
        raise FileNotFoundError(f"No images found under {IMAGES_DIR!r}")
    image_path = random.choice(cudl_images)
    image_source, image = load_image(image_path)

    boxes, _logits, _phrases = predict(
        dino_model, image, TEXT_PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, device=DEVICE
    )
    if boxes.shape[0] == 0:
        # Without a box there is nothing to prompt SAM with, and indexing the
        # first mask below would fail.
        raise RuntimeError(
            f"No box detected for prompt {TEXT_PROMPT!r} in {image_path!r}"
        )

    sam_predictor.set_image(image_source)
    height, width = image_source.shape[:2]
    # Convert (cx, cy, w, h) boxes to (x0, y0, x1, y1) and scale by image size.
    scale = torch.tensor(
        [width, height, width, height], dtype=boxes.dtype, device=boxes.device
    )
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * scale
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(
        boxes_xyxy, image_source.shape[:2]
    ).to(DEVICE)
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    # With multimask_output=False there is presumably one mask per box, so the
    # first box's mask is masks[0, 0] (same result as summing the size-1 axis).
    mask = masks[0, 0].cpu().numpy()
    mask, _ = remove_small_regions(mask, 10000, "holes")
    mask, _ = remove_small_regions(mask, 10000, "islands")

    annotated_frame_with_mask = visualize_mask(mask, image_source)
    # remove_background concatenates along the last axis, so hand it the mask
    # with an explicit channel axis to match the (H, W, C) image.
    image_without_background = remove_background(image_source, mask[..., np.newaxis])

    # Display images (for demonstration, the next steps depend on how the visualization is handled)
    Image.fromarray(image_without_background).show()
    Image.fromarray(annotated_frame_with_mask).show()


# 5. Execution Point
# Run the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
