# MAL With Annotation Types
* Image model-assisted labeling (MAL) with subclassifications.
* This is the same task as the image MAL tutorial, but we also add a subclassification indicating whether each detected person is holding a bag.

In [None]:
try:
    import labelbox
except: 
    !git clone https://github.com/Labelbox/labelbox-python.git
    !cd labelbox-python && git checkout ms/annotation-examples && pip install .[data]

In [2]:
# Run these if running in a colab notebook
COLAB = "google.colab" in str(get_ipython())
if COLAB:
    !git clone https://github.com/Labelbox/labelbox-python.git
    !mv labelbox-python/examples/model_assisted_labeling/image_model.py .
else:
    import sys
    sys.path.append('../model_assisted_labeling')

In [3]:
#Used this as a reference for the model
#https://colab.research.google.com/github/tensorflow/tpu/blob/master/models/official/mask_rcnn/mask_rcnn_demo.ipynb#scrollTo=6lCL-ZcwaJbA
from labelbox.schema.ontology import OntologyBuilder, Tool, Classification, Option
from labelbox import Client, LabelingFrontend
from labelbox.data.annotation_types import (
    LabelList,
    RasterData,
    Rectangle,
    ObjectAnnotation,
    ClassificationAnnotation,
    Point,
    ClassificationAnswer,
    Radio,
    Mask,
    Label
)
from labelbox.data.serialization import NDJsonConverter
from image_model import predict, class_mappings, load_model
from typing import Dict, Any, Tuple, List
import numpy as np
from PIL import Image
import requests
import ndjson
import uuid
from io import BytesIO
import os
from getpass import getpass

In [4]:
# If you don't want to give google access to drive you can skip this cell
# and manually set `API_KEY` below.
if COLAB:
    !pip install colab-env -qU
    from colab_env import envvar_handler
    envvar_handler.envload()

API_KEY = os.environ.get("LABELBOX_API_KEY")
if not os.environ.get("LABELBOX_API_KEY"):
    API_KEY = getpass("Please enter your labelbox api key")
    if COLAB:
        envvar_handler.add_env("LABELBOX_API_KEY", API_KEY)

In [5]:
# API_KEY was resolved in the previous cell (environment variable or prompt).
# Prefer the environment variable when it is set; otherwise keep the prompted
# value instead of failing with a KeyError.
API_KEY = os.environ.get("LABELBOX_API_KEY", API_KEY)
# Only update this if you have an on-prem deployment
ENDPOINT = "https://api.labelbox.com/graphql"

In [6]:
# Labelbox client used by every remote call below (project, dataset, uploads).
client = Client(endpoint=ENDPOINT, api_key=API_KEY)

In [7]:
# Download the pretrained weights and load the detection model.
# `load_model` comes from the local `image_model` module (imported above); the
# TensorFlow log below shows it restoring a Mask R-CNN checkpoint from GCS.
load_model()

Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
INFO:tensorflow:Restoring parameters from gs://cloud-tpu-checkpoints/mask-rcnn/1555659850/variables/variables


## Local first
* No API keys and no Labelbox references are needed yet: we can start making inferences locally.

In [8]:
def has_bag(person, bags):
    """Return True if any bag's centroid falls inside the person's geometry.

    Args:
        person: annotation whose ``value.shapely`` geometry is tested.
        bags: iterable of annotations; each bag's ``value.shapely.centroid``
            is checked against the person's geometry.

    Returns:
        True as soon as one bag centroid is contained, False otherwise
        (including when ``bags`` is empty).
    """
    return any(
        person.value.shapely.contains(bag.value.shapely.centroid)
        for bag in bags
    )

def get_annotations(boxes, classes, seg_masks):
    """Convert raw model detections into Labelbox ``ObjectAnnotation`` objects.

    Only three classes are kept:
      * ``person`` and ``handbag`` become bounding boxes (``Rectangle``);
      * ``car`` becomes a segmentation ``Mask`` built from its mask array.
    Detections of any other class are dropped.

    Args:
        boxes: per-detection boxes. ``box[1]``/``box[0]`` are used as the start
            x/y and ``box[3]``/``box[2]`` as the end x/y, so this assumes
            [ymin, xmin, ymax, xmax] ordering -- TODO confirm against
            ``image_model.predict``.
        classes: per-detection class indices, mapped to names via
            ``class_mappings``.
        seg_masks: per-detection 2D mask arrays (only used for ``car``).

    Returns:
        list of ObjectAnnotation, in detection order.
    """
    annotations = []
    for box, class_idx, seg in zip(boxes, classes, seg_masks):
        name = class_mappings[class_idx]
        if name in ('person', 'handbag'):
            value = Rectangle(
                start=Point(x=box[1], y=box[0]),
                end=Point(x=box[3], y=box[2]),
            )
        elif name == 'car':
            # color=(1, 1, 1) picks out the mask pixels; assumes `seg` is a
            # binary 0/1 mask -- TODO confirm.
            value = Mask(mask=RasterData.from_2D_arr(arr=seg), color=(1, 1, 1))
        else:
            # Not a class this tutorial annotates.
            continue
        annotations.append(ObjectAnnotation(name=name, value=value))
    return annotations

def update_bag_classifications(annotations):
    """Attach a ``has_bag`` radio subclassification to every person annotation.

    The radio answer is ``"True"`` or ``"False"`` (the string form of
    ``has_bag``), i.e. whether any handbag annotation's centroid lies inside
    that person. Person annotations are mutated in place (their
    ``classifications`` list is replaced); nothing is returned.
    """
    handbags = [a for a in annotations if a.name == 'handbag']
    persons = [a for a in annotations if a.name == 'person']
    for person in persons:
        holds_bag = has_bag(person, handbags)
        answer = ClassificationAnswer(name=str(holds_bag))
        person.classifications = [
            ClassificationAnnotation(name='has_bag', value=Radio(answer=answer))
        ]


In [9]:
# We can start creating predictions before any Labelbox project exists:
# everything in this cell runs locally.
# Replace with paths to your own images.
image_paths = [os.path.expanduser('~/Downloads/kitano_st.jpeg')]
# Detections scoring below this threshold are discarded by `predict`.
MIN_SCORE = 0.5

labellist = LabelList([])

for image_path in image_paths:
    image_data = RasterData(file_path=image_path)
    # The first two dimensions of the loaded image are used as height and width.
    height, width = image_data.data.shape[:2]
    prediction = predict(
        np.array([image_data.im_bytes]),
        min_score=MIN_SCORE,
        height=height,
        width=width,
    )
    annotations = get_annotations(
        prediction['boxes'], prediction['class_indices'], prediction['seg_masks']
    )
    # Adds the `has_bag` subclassification to each person, in place.
    update_bag_classifications(annotations)
    labellist.append(Label(data=image_data, annotations=annotations))

In [10]:
# At this time, building annotations from raw predictions is a bit verbose.
# We plan to add helper functions to make this easier, for example
# a `geometry.from_prediction` that automatically infers the geometry type,
# or a simpler constructor such as `bbox.from_points` (names are illustrative).

### Project Setup
* Same as in the previous tutorial, except that we can add the data and ontology directly from the `LabelList`.
* See `labellist.get_ontology()` below, followed by the section that adds the required IDs.

In [11]:
# Set up a project to label.
# See the Ontology, Project, and Project_setup notebooks for more detail on this section.
project = client.create_project(name="subclass_mal_project")
dataset = client.create_dataset(name="subclass_mal_dataset")
editor_frontends = client.get_labeling_frontends(
    where=LabelingFrontend.name == 'editor')
editor = next(editor_frontends)
# Derive the project ontology directly from the labels we just built.
ontology = labellist.get_ontology().asdict()
project.setup(editor, ontology)
project.datasets.connect(dataset)
project.enable_model_assisted_labeling()

True

### Add IDs required for MAL

In [12]:
def signer(_bytes):
    """Upload raw bytes to Labelbox with ``sign=True`` and return the result.

    Passed to the ``LabelList`` helpers below, which use it to turn in-memory
    masks and image data into hosted URLs.
    """
    return client.upload_data(content=_bytes, sign=True)


(
    labellist
    .add_url_to_masks(signer)
    .add_url_to_data(signer)
    .assign_schema_ids(OntologyBuilder.from_project(project))
    .add_to_dataset(dataset, signer)
)

1it [00:04,  4.40s/it]
1it [00:05,  5.17s/it]
1it [00:00, 9686.61it/s]


<labelbox.data.annotation_types.collection.LabelList at 0x181106940>

### Convert to Prediction import format (NDJson)
* We want to create a JSON payload that matches this: https://docs.labelbox.com/data-model/en/index-en#annotations
* Here we serialize the predictions for all of our data (only one image this time).

In [13]:
# Serialize every label into Labelbox's NDJSON import format and preview the first record.
ndjsons = [*NDJsonConverter.serialize(labellist)]
first_record = ndjsons[0]
print(first_record)

{'uuid': '2be62464-6f40-4022-800a-60e28ed21ad6', 'dataRow': {'id': 'ckrjmvvkcu9hb0ypk3o255gm8'}, 'schemaId': 'ckrjmvkw9lned0y8u0uz6cpmd', 'classifications': [{'schemaId': 'ckrjmvkwwlnej0y8uaxy0btap', 'answer': {'schemaId': 'ckrjmvkxdlnel0y8ugxoh9qw9'}}], 'bbox': {'top': 2166.56884765625, 'left': 3658.38427734375, 'height': 562.581787109375, 'width': 207.585205078125}}


### Upload the annotations

In [14]:
job_name = f"upload-job-{uuid.uuid4()}"
upload_task = project.upload_annotations(
    name=job_name,
    annotations=ndjsons,
    validate=True,
)
# Block until the upload task reports completion.
upload_task.wait_until_done()

In [15]:
# Review the upload status
for status in upload_task.statuses:
    print(status)

{'uuid': '2be62464-6f40-4022-800a-60e28ed21ad6', 'dataRow': {'id': 'ckrjmvvkcu9hb0ypk3o255gm8'}, 'status': 'SUCCESS'}
{'uuid': '4272fb66-8b8f-4d38-99f7-1fda2ff7fb3d', 'dataRow': {'id': 'ckrjmvvkcu9hb0ypk3o255gm8'}, 'status': 'SUCCESS'}
{'uuid': '2ff416c4-8cf7-4f07-81c9-009ae2b550b7', 'dataRow': {'id': 'ckrjmvvkcu9hb0ypk3o255gm8'}, 'status': 'SUCCESS'}
{'uuid': '9709fb30-dc6a-4c8f-a2ca-b2741bd4f676', 'dataRow': {'id': 'ckrjmvvkcu9hb0ypk3o255gm8'}, 'status': 'SUCCESS'}
{'uuid': 'ed8b7e18-8f9e-4f3b-b24f-a0d274a52f8c', 'dataRow': {'id': 'ckrjmvvkcu9hb0ypk3o255gm8'}, 'status': 'SUCCESS'}
{'uuid': '4735ddad-6b2d-4652-9c8d-fa07b63a7b1e', 'dataRow': {'id': 'ckrjmvvkcu9hb0ypk3o255gm8'}, 'status': 'SUCCESS'}
{'uuid': '3fc59e3f-2388-41ca-9999-583efa6830f6', 'dataRow': {'id': 'ckrjmvvkcu9hb0ypk3o255gm8'}, 'status': 'SUCCESS'}
{'uuid': '7bfc2697-7467-40b9-8d6f-3fc9e097c51a', 'dataRow': {'id': 'ckrjmvvkcu9hb0ypk3o255gm8'}, 'status': 'SUCCESS'}
{'uuid': 'c08fbe2e-39b8-420b-992f-1a610f045cb7', 'dataRo