In [None]:
import os

# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PolyCollection
import geopandas as gpd
import pandas as pd
import PIL
from pathlib import Path
import tqdm
import rasterio.features
import rasterio.transform
import shapely.geometry
import glob

# COCO tools
import pycocotools.mask
from pycocotools.coco import COCO

# # Recognize Anything Model & Tag2Text
# from ram.models import ram_plus
# from ram import inference_ram
# import torchvision.transforms as TS

# Grounding Dino
from groundingdino.util.inference import load_model
import groundingdino.datasets.transforms as T
from groundingdino.util.utils import get_phrases_from_posmap

# Segment anything
from segment_anything_hq import (
    SamPredictor as SamPredictor_hq,
    sam_model_registry as sam_model_registry_hq,
)
from segment_anything import SamPredictor, sam_model_registry

In [None]:
from hydra import initialize, compose

# Compose the Hydra config (config.yaml from the "config" directory) and echo it.
with initialize("config", version_base=None):
    cfg = compose("config.yaml")
print(cfg)

# Grounding DINO: loaded from the configured config/checkpoint onto cfg.DEVICE.
dino_model = load_model(
    cfg.GROUNDING_DINO_CONFIG_PATH,
    cfg.GROUNDING_DINO_CHECKPOINT_PATH,
    device=cfg.DEVICE,
)

# Segment Anything: plain SAM unless cfg.USE_SAM_HQ selects the SAM-HQ variant.
if not cfg.USE_SAM_HQ:
    sam = sam_model_registry[cfg.SAM_ENCODER_VERSION](
        checkpoint=cfg.SAM_CHECKPOINT_PATH
    )
    sam = sam.to(device=cfg.DEVICE)
    sam_predictor = SamPredictor(sam)
else:
    print("Initialize SAM-HQ Predictor")
    sam = sam_model_registry_hq[cfg.SAM_HQ_ENCODER_VERSION](
        checkpoint=cfg.SAM_HQ_CHECKPOINT_PATH
    )
    sam = sam.to(device=cfg.DEVICE)
    sam_predictor = SamPredictor_hq(sam)

In [None]:
def find_tile_bounds(root_tilepath, concat=True):
    """
    Recursively collect tile-bound GeoPackages below ``root_tilepath``.

    Every ``.gpkg`` file whose stem contains ``'tiles_intersect'`` is read with
    ``geopandas.read_file``. A ``name`` column is added holding the file stem with
    the ``_tiles_intersects`` (or singular ``_tiles_intersect``) suffix removed.
    The per-file row index is kept as-is (``pd.concat`` does not reset it), so the
    index of each row still identifies its position within its source file.

    Parameters:
    - root_tilepath (str or pathlib.Path): The root directory where the search begins.
    - concat (bool, optional): If True (default), concatenate all found files into a
      single frame. If False, return the list of per-file frames (used internally
      for recursion).

    Returns:
    - The ``pd.concat`` of all per-file frames if ``concat=True``, or a list of
      per-file ``geopandas.GeoDataFrame`` objects if ``concat=False``.

    Raises:
    - ValueError: if ``concat=True`` and no matching file exists below ``root_tilepath``.
    """
    cell_files = []
    # sorted() makes the traversal, and therefore the row order, deterministic.
    for p in sorted(Path(root_tilepath).iterdir()):
        if p.is_dir():
            cell_files.extend(find_tile_bounds(p, concat=False))
        elif p.is_file() and p.suffix == ".gpkg" and "tiles_intersect" in p.stem:
            df = gpd.read_file(p)
            # Strip the plural suffix first, then the singular one, so both naming
            # variants accepted by the filter above yield the bare prefix.
            df["name"] = p.stem.replace("_tiles_intersects", "").replace(
                "_tiles_intersect", ""
            )
            cell_files.append(df)

    if not concat:
        return cell_files
    if not cell_files:
        raise ValueError(
            f"No '*tiles_intersect*.gpkg' files found under {root_tilepath}"
        )
    return pd.concat(cell_files)


# Tile bounds: rows from every *tiles_intersect*.gpkg under the havenhoofden tile dataset.
df_tilebounds = find_tile_bounds(
    root_tilepath=Path(cfg.disk_path) / "tile_dataset_havenhoofden",
    concat=True,
)
df_tilebounds

In [None]:
# Output directory for the per-project GeoPackages written by the inference loop.
out_dir = Path(cfg.disk_path) / "output_havenhoofden"
if out_dir.exists():
    print(f"Directory {out_dir} already exists!")
# parents=True creates missing parent directories; exist_ok=True makes re-running
# this cell (or a concurrent creation) harmless.
out_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Display the tile dataset root that the cells below read from.
Path(cfg.disk_path).joinpath("tile_dataset_havenhoofden")

In [None]:
# Detect objects with Grounding DINO, segment them with SAM and write the
# georeferenced polygons to one GeoPackage per orthomosaic part ("deel").
#
# NOTE(review): `load_image`, `get_grounding_output`, `DINO_BOX_THRESHOLD`,
# `DINO_TEXT_THRESHOLD`, `IOU_THRESHOLD` and `DO_IOU_MERGE` are not defined in
# any visible cell. The RAM branch additionally needs `ram_model`, `transform`
# and `inference_ram`, whose imports are commented out. These names must be
# restored for this cell to run on a fresh kernel.

# Vectorisation parameters in CRS units (EPSG:28992 / RD New is metric).
SIMPLIFY_TOLERANCE = 0.01  # shapely simplify() tolerance
MIN_SHAPE_AREA = 0.01  # polygons with area <= this are dropped

dataDirs = sorted(
    Path(p)
    for p in glob.glob(
        str(
            Path(cfg.disk_path)
            / "tile_dataset_havenhoofden"
            / "20230714 MUG Medemblik Den Oever orthomosaic deel *"
        )
    )
)

for dataDir in dataDirs:
    print(dataDir)

    projName = dataDir.stem

    # Column-wise records; converted to a GeoDataFrame once per project.
    pred_records = dict(
        category=[],
        confidence=[],
        tile_path=[],
        project_name=[],
        tile_fname=[],
        geometry=[],
    )

    imgPaths = sorted(glob.glob(str(dataDir / "tiles_havenhoofden" / "*.jpeg")))

    for imgPath in tqdm.tqdm(imgPaths, desc=projName):
        # Load image: image_pil is used for sizes/SAM, `image` is fed to DINO.
        image_pil, image = load_image(imgPath)
        W, H = image_pil.size

        # Look up the tile's world bounds before running any model (fail fast).
        # The image stem is parsed as "<name>_..._<tile index>" and matched to the
        # `name` column and the per-file row index of df_tilebounds.
        imgName = Path(imgPath).stem
        tile_name = imgName.split("_")[0]
        tile_index = int(imgName.split("_")[-1])
        cellfile = df_tilebounds[
            (df_tilebounds.index == tile_index)
            & (df_tilebounds["name"] == tile_name)
        ]
        assert (
            len(cellfile) == 1
        ), f"Expected exactly one tile bound for {imgName}, found {len(cellfile)}"
        cellfile = cellfile.iloc[0]
        # Pixel -> world transform: north-up, origin at the tile's top-left corner.
        tile_transform = rasterio.transform.from_origin(
            cellfile["xmin"],
            cellfile["ymax"],
            (cellfile["xmax"] - cellfile["xmin"]) / W,
            (cellfile["ymax"] - cellfile["ymin"]) / H,
        )

        # Text prompt: fixed tags from the config, otherwise tags predicted by RAM.
        if cfg.fixed_tags:
            tags = ",".join(cfg.fixed_tags)
        else:
            ram_model = ram_model.to(cfg.DEVICE)
            raw_image = image_pil.resize((384, 384))
            raw_image = transform(raw_image).unsqueeze(0).to(cfg.DEVICE)
            res = inference_ram(raw_image, ram_model)
            tags = res[0].replace(" |", ",")

        # Bounding boxes from Grounding DINO.
        boxes_filt, scores, pred_phrases = get_grounding_output(
            dino_model,
            image,
            tags,
            DINO_BOX_THRESHOLD,
            DINO_TEXT_THRESHOLD,
            device=cfg.DEVICE,
        )
        if boxes_filt.shape[0] == 0:
            # Nothing detected on this tile: nothing to segment or vectorise.
            continue

        # Boxes are treated as normalised (cx, cy, w, h): scale to pixels and
        # convert to (x0, y0, x1, y1) corner format for NMS and SAM.
        boxes_xyxy = torchvision.ops.box_convert(
            boxes_filt.cpu() * torch.tensor([W, H, W, H], dtype=torch.float32),
            in_fmt="cxcywh",
            out_fmt="xyxy",
        )

        # Optionally drop overlapping boxes with non-maximum suppression.
        if DO_IOU_MERGE:
            keep = torchvision.ops.nms(boxes_xyxy, scores, IOU_THRESHOLD).tolist()
            boxes_clean = boxes_xyxy[keep]
            pred_phrases_clean = [pred_phrases[i] for i in keep]
        else:
            boxes_clean = boxes_xyxy
            pred_phrases_clean = pred_phrases

        # Segment every remaining box with SAM.
        image_np = np.array(image_pil)
        sam_predictor.set_image(image_np)
        transformed_boxes = sam_predictor.transform.apply_boxes_torch(
            boxes_clean, image_np.shape[:2]
        ).to(cfg.DEVICE)
        masks, _, _ = sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        assert len(pred_phrases_clean) == len(masks)

        # Vectorise each mask into georeferenced polygons.
        for cat_title, mask in zip(pred_phrases_clean, masks):
            # Phrases are parsed as "<label>(<score>)"; rpartition keeps any "("
            # inside the label intact.
            title, _, score_text = cat_title.rpartition("(")
            confidence = float(score_text.rstrip(")"))
            # With multimask_output=False each mask has one channel; take the 2-D plane.
            mask_np = mask[0].cpu().numpy()
            polygons = rasterio.features.shapes(
                mask_np.astype(np.uint8),
                mask=mask_np,
                connectivity=4,
                transform=tile_transform,
            )
            for geom, _ in polygons:
                shape = shapely.geometry.shape(geom).simplify(
                    SIMPLIFY_TOLERANCE, preserve_topology=True
                )
                if shape.area > MIN_SHAPE_AREA:
                    pred_records["category"].append(title)
                    pred_records["confidence"].append(confidence)
                    pred_records["geometry"].append(shape)

                    pred_records["tile_path"].append(str(imgPath))
                    pred_records["tile_fname"].append(imgName)
                    pred_records["project_name"].append(projName)

    df_pred_shapes = gpd.GeoDataFrame(pred_records, crs="epsg:28992")
    if df_pred_shapes.empty:
        print(f"No shapes predicted for {projName}; nothing written.")
        continue
    prefix = "fix_tags_hq" if cfg.USE_SAM_HQ else "fix_tags"
    df_pred_shapes.to_file(out_dir / f"{prefix}_{projName}_havenhoofden.gpkg")