In [None]:
from random import sample
import time
import random
import base64
import gzip
from io import BytesIO
import pandas as pd
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
from sqlalchemy import select, func
from eyened_orm import (
    ImageInstance,
    Modality,
    Feature,
    Annotation,
    AnnotationData,
    AnnotationType,
    Segmentation,
    Creator
)
from eyened_orm.Segmentation import Datatype, DataRepresentation
from eyened_orm.db import Database

In [None]:
# Env file handed to Database (presumably holds the connection settings);
# the path is relative to the notebook's working directory.
DB_ENV_PATH = '../dev/eyened_dev.env'
database = Database(DB_ENV_PATH)
session = database.create_session()

In [None]:
def get_annotations_with_annotation_type(annotation_type_ids, where=None, sample_size=None):
    """Fetch active, human-created annotations of the given types with their image instance.

    Args:
        annotation_type_ids: iterable of AnnotationTypeID values to include.
        where: optional extra SQLAlchemy filter clause added to the query.
        sample_size: if given, return at most this many rows (a random subset,
            since the rows are randomly ordered); if None (default), return
            every matching row.

    Returns:
        List of ``(Annotation, ImageInstance)`` rows in random order.
        ImageInstance is outer-joined, so it may be None.
    """
    query = (
        select(Annotation, ImageInstance)
        # Outer join: annotations without an image instance are still returned.
        .join_from(Annotation, ImageInstance, isouter=True)
        .join_from(Annotation, Creator)
        .where(
            ~Annotation.Inactive
            & Annotation.AnnotationTypeID.in_(annotation_type_ids)
            # NOTE(review): CreatorID 1 is excluded on top of the IsHuman
            # filter -- confirm which account this is.
            & (Annotation.CreatorID != 1)
            & Creator.IsHuman
        )
    )

    if where is not None:
        query = query.where(where)

    # func.rand() renders as RAND() (MySQL/MariaDB); with LIMIT this gives a
    # random subset rather than the first N rows.
    query = query.order_by(func.rand())
    if sample_size is not None:
        query = query.limit(sample_size)

    return session.execute(query).all()

In [None]:
# BASIC ANNOTATIONS
# 13 binary mask annotations
# 14 probability annotations
def open_data(dpath, db_res=None):
    """Load a mask image from disk as a binary (D, H, W) uint8 array.

    Multi-channel images are reduced to their first channel. Every non-zero
    pixel becomes 1 and zero pixels stay 0; a 2-D result gets a leading depth
    axis of size 1.

    Args:
        dpath: path of the image file, opened with PIL.
        db_res: unused; kept so existing callers that pass it still work.

    Returns:
        numpy uint8 array with values in {0, 1}; (1, H, W) for ordinary
        single- or multi-channel 2-D images.

    Raises:
        RuntimeError: if the decoded array does not end up 3-D.
    """
    # Image.open is lazy; the context manager guarantees the file handle is
    # released once the pixels have been copied into numpy.
    with Image.open(dpath) as img:
        pixels = np.array(img)

    if pixels.ndim == 3:
        pixels = pixels[..., 0]

    # NOTE(review): thresholding at > 0 also binarises probability maps.
    mask = (pixels > 0).astype(np.uint8)

    if mask.ndim == 2:
        mask = mask[None, ...]

    if mask.ndim != 3:
        raise RuntimeError(f'Found shape {mask.shape} for {dpath}')

    return mask  # DHW


def convert_one_annotation_basic(annot, image_instance):
    """Create and persist a binary Segmentation for one basic annotation.

    The segmentation covers the image instance's full (depth, height, width)
    shape with SparseAxis=0 and the Binary / R8UI representation. It is added
    to the session and flushed before its data is written: an annotation
    without AnnotationData gets an empty segmentation, and one with a single
    AnnotationData row gets that file's mask (loaded via ``open_data``).

    Args:
        annot: the Annotation to convert.
        image_instance: the ImageInstance the annotation refers to; may be
            None when the annotation has no image instance.

    Returns:
        The flushed Segmentation.

    Raises:
        RuntimeError: if image_instance is None, the mask file cannot be
            opened, the mask shape differs from the image instance's shape,
            or the annotation has more than one AnnotationData row.
    """
    if image_instance is None:
        raise RuntimeError(f'No image instance for Annotation ID {annot.AnnotationID}')

    n_data = len(annot.AnnotationData)
    if n_data > 1:
        raise RuntimeError(f'Found {n_data} annotation data for {annot.AnnotationID}')

    # Load and validate the mask before touching the session, so a bad file
    # never produces a flushed segmentation row.
    mask = None
    if n_data == 1:
        annot_data = annot.AnnotationData[0]
        try:
            mask = open_data(annot_data.path)
        except Exception as e:
            raise RuntimeError(f'Error opening {annot_data.path}: {e}') from e
        # Explicit check instead of `assert` (which `python -O` strips), raised
        # outside the try so it is not mislabelled as an open error.
        if mask.shape != image_instance.shape:
            raise RuntimeError(
                f'Shape mismatch for Annotation ID {annot.AnnotationID}, '
                f'found {mask.shape} != {image_instance.shape}'
            )

    depth, height, width = image_instance.shape
    segmentation = Segmentation(
        Depth=depth,
        Height=height,
        Width=width,
        SparseAxis=0,
        ScanIndices=None,
        ImageProjectionMatrix=None,
        DataRepresentation=DataRepresentation.Binary,
        DataType=Datatype.R8UI,
        ImageInstanceID=image_instance.ImageInstanceID,
        CreatorID=annot.CreatorID,
        FeatureID=annot.FeatureID,
    )

    session.add(segmentation)
    # Flush before writing: presumably write_empty/write_data need the
    # database-assigned SegmentationID -- TODO confirm.
    session.flush([segmentation])
    if mask is None:
        segmentation.write_empty()
    else:
        segmentation.write_data(mask)
    return segmentation


def convert_annotations_basic(annotation_type_id):
    """Convert every matching annotation of one type into a binary Segmentation.

    Each annotation is converted with ``convert_one_annotation_basic`` inside
    its own SAVEPOINT. If a conversion fails, even after its segmentation row
    was flushed, only that row is rolled back; the error is printed and the
    annotation is skipped. Successful conversions are committed at the end.
    If the loop is aborted (e.g. KeyboardInterrupt), the whole session is
    rolled back so a later commit cannot persist a partial run.

    NOTE(review): rollback only undoes database rows; any files already
    written by write_data/write_empty for failed conversions are not removed.

    Args:
        annotation_type_id: ID of the annotation type to convert.

    Returns:
        ``(annotations, segmentations)`` -- parallel lists of the successfully
        converted annotations and their new segmentations.
    """
    elems = get_annotations_with_annotation_type([annotation_type_id])
    annotations = []
    segmentations = []

    try:
        for annot, image_instance in tqdm(elems):
            try:
                with session.begin_nested():
                    segmentation = convert_one_annotation_basic(annot, image_instance)
            except Exception as e:
                print(f'Error converting {annot.AnnotationID}: {e}')
                continue
            segmentations.append(segmentation)
            annotations.append(annot)
    except BaseException:
        # Do not leave flushed-but-uncommitted rows behind for a later commit.
        session.rollback()
        raise

    session.commit()
    return annotations, segmentations

In [None]:
# Annotation type 13: binary mask annotations.
annotations, segmentations = convert_annotations_basic(13)
# Number of annotations converted successfully.
len(segmentations)

In [None]:
# Annotation type 14: probability annotations.
# NOTE(review): open_data thresholds at > 0, so these probability maps are
# stored as binary masks -- confirm this is intended.
annotations, segmentations = convert_annotations_basic(14)
# Number of annotations converted successfully.
len(segmentations)

In [None]:
# R/G masks
#   2  Segmentation 2D          R/G mask  19292
#   5  Segmentation OCT Enface  R/G mask  113
def _encode_dual_bitmask(pixels):
    """Encode an R/G mask pixel array as a 2-D uint8 dual-bit mask.

    For a multi-channel (3-D) array, bit 0 is set where channel 0 (red) is
    non-zero and bit 1 where channel 1 (green) is non-zero, giving values
    0..3. For a single-channel (2-D) array only bit 0 is used.
    """
    if pixels.ndim == 3:
        red = (pixels[..., 0] > 0).astype(np.uint8)
        green = (pixels[..., 1] > 0).astype(np.uint8)
        return red | (green << 1)
    return (pixels > 0).astype(np.uint8)


def convert_annotations_rgmasks(annotation_type_id, where=None):
    """Convert R/G-mask annotations of one type into DualBitMask Segmentations.

    Type 5 (OCT enface) masks must be (Columns_x wide, NrOfFrames high) and
    are stored as a (depth, 1, width) volume with SparseAxis=1. Any other type
    must match (Columns_x, Rows_y) and is stored as a single (1, height, width)
    frame with SparseAxis=0.

    Annotations without an image instance, or whose mask file cannot be
    opened, are reported and skipped. A mask whose size does not match the
    image instance, or an annotation with more than one AnnotationData row,
    aborts the run. On abort the session is rolled back, so segmentations
    already flushed by this call are not committed by a later
    ``session.commit()``. Files already written are not removed.

    Args:
        annotation_type_id: ID of the annotation type to convert.
        where: optional extra SQLAlchemy filter forwarded to
            ``get_annotations_with_annotation_type``.

    Returns:
        ``(segmentations, annotations)`` -- note this order is the reverse of
        ``convert_annotations_basic``.

    Raises:
        RuntimeError: on a size mismatch or multiple AnnotationData rows.
    """
    # NOTE(review): nothing below filters out vessel masks, although the
    # original intent was to skip them here (they are inserted with the
    # Artery/Vein annotations). Pass a suitable `where` to exclude them.
    elems = get_annotations_with_annotation_type([annotation_type_id], where=where)
    is_enface = annotation_type_id == 5
    annotations = []
    segmentations = []

    try:
        for annot, image_instance in tqdm(elems):
            if image_instance is None:
                print(
                    f"Found image_instance is None for  annot_id: {annot.AnnotationID}"
                )
                continue

            d, h, w = image_instance.shape
            if is_enface:
                # enface mask: one image row per frame
                expected_size = (image_instance.Columns_x, image_instance.NrOfFrames)
                depth, height, width, sparse_axis = d, 1, w, 1
            else:
                expected_size = (image_instance.Columns_x, image_instance.Rows_y)
                depth, height, width, sparse_axis = 1, h, w, 0

            segmentation = Segmentation(
                Depth=depth,
                Height=height,
                Width=width,
                SparseAxis=sparse_axis,
                ScanIndices=None,
                ImageProjectionMatrix=None,
                DataRepresentation=DataRepresentation.DualBitMask,
                DataType=Datatype.R8UI,
                ImageInstanceID=image_instance.ImageInstanceID,
                CreatorID=annot.CreatorID,
                FeatureID=annot.FeatureID,
            )

            n_data = len(annot.AnnotationData)
            if n_data == 0:
                session.add(segmentation)
                session.flush([segmentation])
                segmentation.write_empty()

            elif n_data == 1:
                annot_data = annot.AnnotationData[0]
                try:
                    mask_image = Image.open(annot_data.path)
                except Exception as e:
                    print(f"Error opening {annot_data.path} for annot_id: {annot.AnnotationID}, image_instance_id: {image_instance.ImageInstanceID}: {e}")
                    continue

                # Context manager releases the file handle once pixels are read.
                with mask_image:
                    if mask_image.size != expected_size:
                        raise RuntimeError(f"Found shape {mask_image.size} != {expected_size} for {annot_data.path}")
                    if mask_image.mode == 'RGBA':
                        pixels = np.array(mask_image.convert("RGB"))
                    else:
                        pixels = np.array(mask_image)

                mask = _encode_dual_bitmask(pixels)
                mask = mask[:, None, :] if is_enface else mask[None, :, :]

                session.add(segmentation)
                session.flush([segmentation])
                segmentation.write_data(mask)

            else:
                raise RuntimeError(f'Found {n_data} annotation data for {annot.AnnotationID}')

            segmentations.append(segmentation)
            annotations.append(annot)
    except BaseException:
        # Discard rows flushed during this run so a later commit (e.g. in the
        # next cell) does not persist a partial conversion.
        session.rollback()
        raise

    session.commit()
    return segmentations, annotations


In [None]:
# OCT enface R/G masks (annotation type 5).
segmentations, annotations = convert_annotations_rgmasks(annotation_type_id=5)

In [None]:
# One row per converted annotation. The DataFrame's rich display is truncated
# by pandas instead of printing every pair into the notebook.
pd.DataFrame(
    [
        (annot.AnnotationID, seg.SegmentationID, seg.ImageInstanceID)
        for annot, seg in zip(annotations, segmentations)
    ],
    columns=['AnnotationID', 'SegmentationID', 'ImageInstanceID'],
)

In [None]:
# 2-D R/G masks (annotation type 2).
segmentations, annotations = convert_annotations_rgmasks(annotation_type_id=2)

In [None]:
# One row per converted annotation. The DataFrame's rich display is truncated
# by pandas instead of printing every pair into the notebook.
pd.DataFrame(
    [
        (annot.AnnotationID, seg.SegmentationID, seg.ImageInstanceID)
        for annot, seg in zip(annotations, segmentations)
    ],
    columns=['AnnotationID', 'SegmentationID', 'ImageInstanceID'],
)