In [None]:
# Reload imported modules automatically before executing each cell, so edits to local
# packages take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import tlc

In [None]:
# Project / dataset / table names for 3LC.
# NOTE(review): these constants are not referenced by the cells shown here — the image-folder
# table below is created with hardcoded "Test"/"Balloons" names. Confirm which set is intended.
PROJECT_NAME = "3LC Tutorials"
DATASET_NAME = "LIACI"
INSTANCE_SEGMENTATION_TABLE_NAME = "instance-segmentation"

In [None]:
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor

# Mask2Former checkpoint fine-tuned on COCO instance segmentation. Defined once so the image
# processor and the model are always loaded from the same checkpoint.
MODEL_CHECKPOINT = "facebook/mask2former-swin-tiny-coco-instance"

processor: Mask2FormerImageProcessor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = Mask2FormerForUniversalSegmentation.from_pretrained(MODEL_CHECKPOINT)

In [None]:
from tlc_tools.common import infer_torch_device

device = infer_torch_device()
# nn.Module.to() moves the parameters in place and returns the module itself. Rebinding, instead
# of leaving a bare expression as the last line, keeps the notebook from rendering the full model
# repr as this cell's output.
model = model.to(device)

In [None]:
def print_gpu_memory():
    """Print the CUDA caching allocator's usage for the current device, in MB.

    Prints two lines, "Allocated" (``torch.cuda.memory_allocated()``) and "Cached"
    (``torch.cuda.memory_reserved()``), each converted from bytes to MB (1024**2 bytes).
    Does nothing at all when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    bytes_per_mb = 1024**2
    allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
    print(f"Allocated: {allocated_mb:.2f}MB")
    reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
    print(f"Cached: {reserved_mb:.2f}MB")

In [None]:
from pathlib import Path

import os

# Folder of images to run inference on. Set the BALLOON_IMAGE_FOLDER environment variable to point
# elsewhere instead of editing a machine-specific absolute path.
image_folder = Path(os.environ.get("BALLOON_IMAGE_FOLDER", r"C:\Data\balloon\train"))

table = tlc.Table.from_image_folder(
    image_folder,
    include_label_column=False,
    table_name="Balloons",
    dataset_name="Test",
    project_name="Test",
    # if_exists="overwrite",  # NOTE(review): presumably re-creates the table when it already exists
)

In [None]:
# Build a 3LC value map from the model's id2label mapping (class id -> class name). It is used for
# the categorical "label" schemas of both the predictions and the placeholder ground-truth column.
# NOTE(review): `_construct_value_map` is a private tlc helper and may change between versions.
value_map = tlc.MapElement._construct_value_map(model.config.id2label)

In [None]:
import numpy as np


# EXIF tag id of the image-orientation field.
_EXIF_ORIENTATION_TAG = 274


def _apply_exif_orientation(pixels, orientation):
    """Return ``pixels`` (an ``(H, W, ...)`` array) re-oriented as the EXIF ``orientation`` asks.

    Uses the same orientation -> transpose mapping as ``PIL.ImageOps.exif_transpose``.
    Orientation 1, a missing tag, or an unknown value leaves the array unchanged.
    The result may be a strided view; callers must make it contiguous before handing it to torch.
    """
    if orientation == 2:  # mirrored left-right
        return pixels[:, ::-1]
    if orientation == 3:  # rotated 180 degrees
        return pixels[::-1, ::-1]
    if orientation == 4:  # mirrored top-bottom
        return pixels[::-1]
    if orientation == 5:  # transpose across the main diagonal
        return np.swapaxes(pixels, 0, 1)
    if orientation == 6:  # needs a 90 degree clockwise rotation
        return np.rot90(pixels, k=-1)
    if orientation == 7:  # transpose across the anti-diagonal
        return np.swapaxes(pixels[::-1, ::-1], 0, 1)
    if orientation == 8:  # needs a 90 degree counter-clockwise rotation
        return np.rot90(pixels, k=1)
    return pixels


def table_map(sample):
    """Turn one table sample into the ``(image_tensor, (width, height))`` pair used downstream.

    The pixels are re-oriented according to the image's EXIF orientation tag. As a result, the
    tensor the model sees and the ``(width, height)`` later used as the mask target size describe
    the same, displayed orientation. Previously only the dimensions were swapped, so rotated
    photos got stretched rather than rotated masks.

    Args:
        sample: the table row; assumed to be a PIL image (the table has no label column).

    Returns:
        ``(img_tensor, (width, height))``: ``img_tensor`` is a ``(3, height, width)`` tensor
        (channels first), and ``width``/``height`` are the dimensions after EXIF re-orientation.

    Raises:
        ValueError: if the image does not decode to an ``(H, W, C)`` array with exactly 3 channels.
    """
    image = sample  # sample is a PIL image
    orientation = image.getexif().get(_EXIF_ORIENTATION_TAG, 1)

    img_array = np.array(image)
    # Validate before permuting: permute(2, 0, 1) on a 2-D (e.g. grayscale) array would otherwise
    # fail with an opaque RuntimeError before the shape check could run.
    if img_array.ndim != 3:
        raise ValueError(f"Expected an image array with 3 dimensions (H,W,C), got shape {img_array.shape}")
    if img_array.shape[2] != 3:
        raise ValueError(f"Expected 3 channels, got {img_array.shape[2]}")

    # torch.from_numpy rejects negative strides, which flips/rotations produce, so force a copy.
    oriented = np.ascontiguousarray(_apply_exif_orientation(img_array, orientation))
    height, width = oriented.shape[:2]
    img_tensor = torch.from_numpy(oriented).permute(2, 0, 1)
    return img_tensor, (width, height)


# Drop any previously registered maps, then register table_map, so re-running this cell leaves a
# single table_map in place (NOTE(review): assumes map() appends rather than replaces).
table.clear_maps()
table.map(table_map)

In [None]:
def preprocessor(batch):
    """Build Mask2Former model inputs from a predictor batch.

    Only ``batch[0]`` is used (presumably the image tensor produced by ``table_map``). A leading
    dimension of size 1 is squeezed away before the image processor is called. The processor
    output is returned as a plain ``dict`` because the predictor is created with
    ``unpack_dicts=True``.
    """
    image_batch = batch[0]
    single_image = image_batch.squeeze(0)
    model_inputs = processor(images=single_image, return_tensors="pt")
    return dict(model_inputs)

In [None]:
# Wrap the model for tlc. preprocessor builds the model inputs from each batch, unpack_dicts=True
# presumably passes the returned dict to the model as keyword arguments, and inference runs on
# `device`.
predictor = tlc.Predictor(model, unpack_dicts=True, preprocess_fn=preprocessor, device=device)

In [None]:
def _binary_maps_to_hwn(binary_maps, num_instances, height, width):
    """Convert post-processed binary instance maps to a C-contiguous ``(height, width, N)`` uint8 array.

    Args:
        binary_maps: the ``"segmentation"`` entry returned by
            ``post_process_instance_segmentation(..., return_binary_maps=True)``; expected to be an
            ``(N, H, W)`` tensor of 0/1 values when instances were found.
        num_instances: number of entries in ``segments_info``; the returned mask count matches it.
        height, width: target image size, used to shape the empty result.
    """
    if num_instances == 0:
        # NOTE(review): with no detections, the transformers post-processor returns a single (H, W)
        # map filled with -1 instead of an empty stack (confirm for the pinned version). That map is
        # not an instance mask and would not be a valid 0/1 mask after a uint8 cast, so emit zero
        # masks to keep the mask count equal to the (empty) label/score lists.
        return np.zeros((height, width, 0), dtype=np.uint8)
    maps = binary_maps.cpu().numpy()
    if maps.ndim == 2:
        # Defensive: treat a lone (H, W) map as a single instance.
        maps = maps[np.newaxis]
    # (N, H, W) -> (H, W, N); ascontiguousarray performs the uint8 cast in the same copy.
    return np.ascontiguousarray(maps.transpose(1, 2, 0), dtype=np.uint8)


def collect_fn(batch, predictor_output):
    """Convert one Mask2Former prediction into a 3LC instance-segmentation metrics row.

    Args:
        batch: the mapped sample; its second element is the ``(width, height)`` pair produced by
            ``table_map`` (presumably batched 1-element tensors, hence the ``int`` conversion).
        predictor_output: predictor result whose ``.forward`` attribute (presumably the model's
            forward outputs) is passed to the image processor's post-processing.

    Returns:
        ``{"predicted_masks": [instances]}``, where ``instances`` holds ``image_height``,
        ``image_width``, ``masks`` as an ``(h, w, N)`` uint8 array, and ``instance_properties``
        with ``label`` and ``scores`` lists of length ``N``. ``N`` may be 0.
    """
    _, (w, h) = batch
    w = int(w)
    h = int(h)

    result = processor.post_process_instance_segmentation(
        predictor_output.forward, target_sizes=[(h, w)], return_binary_maps=True
    )[0]

    infos = result["segments_info"]
    labels = [info["label_id"] for info in infos]
    scores = [info["score"] for info in infos]

    # Mask count is tied to the number of segments so masks and instance properties always line up.
    masks = _binary_maps_to_hwn(result["segmentation"], len(infos), h, w)

    instances = {
        "image_height": h,
        "image_width": w,
        "masks": masks,
        "instance_properties": {"label": labels, "scores": scores},
    }

    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached allocator blocks between samples
    return {"predicted_masks": [instances]}

In [None]:
# Metrics collector that runs collect_fn on each prediction and stores the result in a
# "predicted_masks" InstanceSegmentationMasks column, with a categorical label and a score per
# instance. Aggregates are disabled.
# NOTE(review): "scores" holds the per-segment "score" from Mask2Former post-processing, yet its
# schema is tlc.IoU — presumably chosen as a generic float schema; confirm a dedicated
# confidence/float schema isn't more appropriate.
metrics_collector = tlc.FunctionalMetricsCollector(
    collect_fn,
    column_schemas={
        "predicted_masks": tlc.InstanceSegmentationMasks(
            "predicted_masks",
            instance_properties_structure={
                "label": tlc.CategoricalLabel("label", value_map),
                "scores": tlc.IoU("scores"),
            },
            is_prediction=True,
        ),
    },
    compute_aggregates=False,
)

In [None]:
# Run the predictor over `table` and collect the predicted masks with metrics_collector
# (presumably into the active tlc run, which the next cell retrieves).
tlc.collect_metrics(table, metrics_collector, predictor=predictor, collect_aggregates=False)

In [None]:
# Handle to the currently active tlc run.
run = tlc.active_run()

In [None]:
run.set_status_completed()

In [None]:
run.name

In [None]:
# Name of the run's last metrics table.
run.metrics_tables[-1].name

In [None]:
# Add a "segmentations" column for ground-truth instance masks, using the same label value map as
# the predictions. A single placeholder value (zero image size, no instances) is passed; presumably
# tlc applies it to every row.
# NOTE(review): the placeholder uses an "rles" key while collect_fn emits "masks"; presumably the
# stored InstanceSegmentationMasks form is RLE-encoded — confirm against the tlc schema.
sample_type = tlc.InstanceSegmentationMasks(
    "segmentations", instance_properties_structure={"label": tlc.CategoricalLabel("label", value_map)}
)

column_added_table = table.add_column(
    column_name="segmentations",
    values={"image_height": 0, "image_width": 0, "instance_properties": {"label": []}, "rles": []},
    schema=sample_type.schema,
)

In [None]:
column_added_table