In [7]:
%%capture
#@markdown # DeepLab training
# Setup cell: install the pinned abraia SDK, make sure API credentials are in
# the environment, and create the shared Abraia client. `%%capture` hides all
# output of this cell (including the pip install log).
!python -m pip install abraia[dev]==0.24.1

import os
# Fill the credentials from the Colab form fields only when neither variable is
# already present in the environment.
# NOTE(review): if exactly one of ABRAIA_ID / ABRAIA_KEY is already set, this
# branch is skipped and the missing one is never filled in.
if not os.getenv('ABRAIA_ID') and not os.getenv('ABRAIA_KEY'):
    abraia_id = ''  #@param {type: "string"}
    abraia_key = ''  #@param {type: "string"}
    %env ABRAIA_ID=$abraia_id
    %env ABRAIA_KEY=$abraia_key

from abraia import Abraia, training, utils

# Presumably picks up ABRAIA_ID / ABRAIA_KEY from the environment — confirm in the abraia SDK.
abraia = Abraia()

In [8]:
#@markdown ## Dashboard

import io
import re
import sys
import glob
import math
import random
import itertools
import contextlib
import numpy as np
import pandas as pd
import ipywidgets as widgets
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm
from IPython.display import display, clear_output, HTML
from IPython.display import Image as show_image

def display_images(imgs, labels, figsize=(12, 9)):
    """Show images in a near-square grid, each titled with its label.

    The grid has ``ceil(sqrt(n))`` columns and as many rows as needed. Unused
    grid cells are left blank. Nothing is drawn for an empty ``imgs``.

    Args:
        imgs: sequence of images accepted by ``Axes.imshow``.
        labels: titles paired with ``imgs`` by position (as with ``zip``,
            surplus items in the longer sequence are ignored).
        figsize: matplotlib figure size in inches.
    """
    imgs = list(imgs)
    if not imgs:
        # Avoid the 0/0 division the grid-size computation would hit.
        return
    cols = math.ceil(math.sqrt(len(imgs)))
    rows = math.ceil(len(imgs) / cols)
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so `.flat` always works.
    _, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for ax, img, label in zip(axes.flat, imgs, labels):
        ax.set_title(label)
        ax.imshow(img)
    plt.show()

def add_image_cell(src, width=300):
    """Build an HTML ``<img>`` tag for embedding an image in a DataFrame cell.

    The tag is meant to be rendered with ``DataFrame.to_html(escape=False)``,
    which injects it verbatim. ``src`` is therefore HTML-escaped here, so URLs
    containing quotes or ampersands cannot break the attribute or inject markup.

    Args:
        src: image URL (any value; it is converted with ``str``).
        width: displayed width in pixels.

    Returns:
        The ``<img>`` tag as a string.
    """
    import html
    return f'<img src="{html.escape(str(src), quote=True)}" width="{width}">'

def display_df(df):
    """Render a DataFrame as raw HTML in the current output.

    ``escape=False`` emits cell contents verbatim, so cells holding markup
    (such as ``<img>`` tags) are shown as rendered HTML rather than text.
    """
    rendered = df.to_html(escape=False)
    display(HTML(rendered))

def load_dataset_df(dataset):
    """Tabulate dataset records for display.

    Args:
        dataset: records accepted by ``pd.DataFrame``. They must provide a
            ``url`` field; a ``name`` field, if present, is renamed.

    Returns:
        DataFrame with ``name`` renamed to ``filename`` plus an ``image`` column
        holding an ``<img>`` tag built from each ``url``.
    """
    return (
        pd.DataFrame(dataset)
        .rename(columns={'name': 'filename'})
        .assign(image=lambda frame: frame['url'].apply(add_image_cell))
    )

#def display_dataset(dataset):
#    imgs, labels = [], []
#    for row in dataset:
#        labels.append(row['name'])
#        imgs.append(utils.load_image(utils.load_url(row['url'])))
#    display_images(imgs, labels)

def source_images(dir, exts=('png', 'jpg')):
    """List image files located exactly one directory level below ``dir``.

    Matches ``<dir>/<subdir>/*.<ext>`` for each extension. Results are grouped
    by extension in ``exts`` order; order within a group follows ``glob``.

    Args:
        dir: parent directory. A trailing separator is optional.
        exts: file extensions to match, without the leading dot.

    Returns:
        List of matching file paths.
    """
    return [path
            for ext in exts
            for path in glob.glob(os.path.join(dir, '*', f'*.{ext}'))]

# Dashboard controls. Project choices come from `training.load_projects()`;
# the task, classes and epoch values are filled in by the event handlers.
projects = training.load_projects()
dropdown_project = widgets.Dropdown(description='Project', options=projects, value=None, layout={'width': 'auto'})
dropdown_task = widgets.Dropdown(description='Task', options=[], value=None, layout={'width': 'auto'})
text_classes = widgets.Text(description='Classes', value='', layout={'width': 'auto'})

# Plain IntText has no `min`/`max` traits (a `min=0` kwarg is silently
# ignored), so BoundedIntText is used to actually forbid negative epochs.
# Its default `max` is 100, which would clip the 250-epoch default used for
# detection, so the upper bound is raised explicitly.
int_epochs = widgets.BoundedIntText(description='Epochs', min=0, max=10000, step=25, layout={'width': 'auto'})
button_train = widgets.Button(description='Train', disabled=True, layout={'width': 'auto'})
vbox_train = widgets.VBox([int_epochs, button_train])

# Scrollable area that the handlers print logs, tables and images into.
output = widgets.Output(layout={'border': '1px solid #aaa', 'width': '100%', 'height': '360px', 'overflow': 'auto'})
label_status = widgets.Label(value='')

vbox_dashboard = widgets.VBox([dropdown_project, dropdown_task, text_classes, vbox_train])

def dropdown_project_eventhandler(change):
    """Load the selected project's annotations and refresh the dependent widgets.

    Fills the classes box and the task dropdown (falling back to ``'detect'``
    when the annotations yield no task), updates the Train button, and
    previews up to 9 randomly sampled dataset images in the output area.

    Args:
        change: traitlets change notification for ``dropdown_project.value``
            (unused; the widget's current value is read directly).
    """
    output.clear_output()
    with output:
        label_status.value = 'Loading annotations...'
        try:
            project = dropdown_project.value
            annotations = training.load_annotations(project)
            classes = training.load_labels(annotations)
            text_classes.value = ','.join(classes)
            tasks = training.load_tasks(annotations)
            tasks = tasks if len(tasks) else ['detect']
            dropdown_task.options = tasks
            dropdown_task.value = tasks[-1]
            # The task observer only runs when the task value changes, so
            # refresh the Train button here as well; otherwise it could stay
            # enabled for a project without classes (or vice versa).
            button_train.disabled = not (text_classes.value and dropdown_task.value)
            dataset = training.load_dataset(project)
            dataset_df = load_dataset_df(dataset)
            # DataFrame.sample(n) raises when n exceeds the number of rows.
            preview = dataset_df[['image', 'filename']].sample(min(9, len(dataset_df)))
            display_df(preview)
        except Exception:
            # Don't leave the footer stuck on "Loading annotations...".
            label_status.value = 'Failed to load annotations.'
            raise
        label_status.value = 'Annotations loaded.'

dropdown_project.observe(dropdown_project_eventhandler, names='value')

def dropdown_task_eventhandler(change):
    """React to a task change by gating the Train button and resetting epochs.

    Training is enabled only when both a task and a non-empty classes string
    are set. The epoch default is 25 for ``'classify'`` and 250 for any other task.

    Args:
        change: traitlets change notification (unused; widget values are read directly).
    """
    selected_task = dropdown_task.value
    class_names = text_classes.value
    button_train.disabled = not class_names or not selected_task
    if selected_task == 'classify':
        int_epochs.value = 25
    else:
        int_epochs.value = 250

dropdown_task.observe(dropdown_task_eventhandler, names='value')

def load_dataset():
    """Create the local dataset for the dashboard's current selection if missing.

    Project, task and comma-separated classes are read from the widgets.
    A path named after the project (relative to the working directory) that
    already exists is reused as-is.
    NOTE(review): so a later change of task or classes does not rebuild it.
    """
    output.clear_output()
    with output:
        label_status.value = 'Loading dataset...'
        project_name = dropdown_project.value
        selected_task = dropdown_task.value
        class_names = text_classes.value.split(',')
        already_present = os.path.exists(project_name)
        if not already_present:
            training.create_dataset(project_name, selected_task, class_names)
        label_status.value = 'Dataset loaded.'

def _train_classifier(project, epochs):
    """Train, save and preview a classification model; return its session."""
    training_session = training.classify.Model()
    # create_dataset reports the class names it found; those are saved with the model.
    dataloaders, classes = training_session.create_dataset(project)
    training_session.train(project, epochs=epochs)
    training_session.save(project, classes)
    display(training.classify.visualize_data(dataloaders['train']))
    return training_session


def _train_detector(project, task, classes, epochs, batch, imgsz):
    """Train and save a ``training.detect.Model`` for a non-classify task; return its session."""
    training_session = training.detect.Model(task)

    def print_train_end(trainer):
        print('# End training')
        print('Metrics:', trainer.metrics)

    training_session.model.add_callback('on_train_end', print_train_end)
    # TODO: save the metrics returned by train() together with the model.
    training_session.train(project, epochs=epochs, batch=batch, imgsz=imgsz)
    training_session.save(project, classes, imgsz=imgsz)
    return training_session


def train_model(project, classes, epochs, batch=32, imgsz=640):
    """Train and save a model for ``project`` using the task selected in the dashboard.

    Args:
        project: project/dataset name.
        classes: class names saved with non-classify models. Classification
            uses the classes found by its own dataset loader instead.
        epochs: number of training epochs.
        batch: batch size (non-classify tasks only).
        imgsz: training image size (non-classify tasks only).

    Returns:
        The training session, or ``None`` if training failed. A failure is
        re-raised inside the ``output`` context; under IPython, ipywidgets'
        Output is expected to display and swallow it (confirm for the installed
        version), in which case ``None`` is returned.
    """
    task = dropdown_task.value
    # Bound up front so a failure never leaves the name unassigned at `return`.
    training_session = None
    output.clear_output()
    with output:
        label_status.value = 'Training model...'
        try:
            if task == 'classify':
                training_session = _train_classifier(project, epochs)
            else:
                training_session = _train_detector(project, task, classes, epochs, batch, imgsz)
        except Exception:
            label_status.value = 'Training failed.'
            raise
        label_status.value = 'Model saved.'
    return training_session

def sorted_folders(dir):
    """List the sub-directories of ``dir``, oldest first by ``os.path.getctime``.

    Regular files are skipped, so a caller taking ``[-1]`` as the newest
    folder never gets a stray file. ``getctime`` is the creation time on
    Windows but the last metadata change on Unix.

    Args:
        dir: directory to scan.

    Returns:
        Paths of the form ``os.path.join(dir, name)``, sorted ascending.
    """
    candidates = (os.path.join(dir, name) for name in os.listdir(dir))
    folders = [path for path in candidates if os.path.isdir(path)]
    return sorted(folders, key=os.path.getctime)

def calculate_metrics(training_session, split='val'):
    """Validate ``training_session.model`` on one dataset split and print the metrics.

    Anything written to stderr during validation is captured and discarded.

    Args:
        training_session: object whose ``model.val(split=...)`` returns a
            result exposing ``box.map50``, ``box.p``, ``box.r`` and
            ``confusion_matrix.matrix``.
        split: dataset split name passed to ``val``.

    Returns:
        Tuple ``(map50, precision, recall, confusion_matrix)``.
    """
    with contextlib.redirect_stderr(io.StringIO()):
        results = training_session.model.val(split=split)
        confusion_matrix = results.confusion_matrix.matrix
        box = results.box
        map50, precision, recall = box.map50, box.p, box.r
    print(f"Mean Average Precision: {map50}")
    print(f"Precision by Class: {precision}")
    print(f"Recall by Class: {recall}")
    print(f"Confusion Matrix:\n{confusion_matrix}")
    return map50, precision, recall, confusion_matrix

def _show_test_predictions(training_session, project):
    """Run the model on every test image and display the rendered results as 300px thumbnails."""
    for src in source_images(f"{project}/test/"):
        img = utils.load_image(src)
        results = training_session.run(img)
        print(results)
        rendered = Image.fromarray(utils.render_results(img, results))
        rendered.thumbnail((300, 300))
        display(rendered)


def _show_latest_run_previews(task):
    """Display the batch-0 label/prediction previews from the newest folder under ``runs/<task>/``.

    A missing runs directory or preview file is reported instead of raising.
    """
    runs_dir = f"runs/{task}/"
    folders = sorted_folders(runs_dir) if os.path.isdir(runs_dir) else []
    if not folders:
        print(f"No run folders found in {runs_dir}")
        return
    # NOTE(review): the newest folder presumably comes from the test-split
    # validation run above — confirm against the trainer's output layout.
    latest = folders[-1]
    for name in ('val_batch0_labels.jpg', 'val_batch0_pred.jpg'):
        path = f"{latest}/{name}"
        if os.path.exists(path):
            display(show_image(filename=path))
        else:
            print(f"Preview not found: {path}")


def test_model(training_session, project):
    """Evaluate a trained model and show its metrics and predictions in the output area.

    Does nothing for the ``'classify'`` task. For other tasks it:

    - prints validation and test-split metrics;
    - renders predictions for every image one level below ``<project>/test/``;
    - shows the batch previews from the newest ``runs/<task>/`` folder.
    """
    task = dropdown_task.value
    if task == 'classify':
        return
    output.clear_output()
    with output:
        # calculate_metrics defaults to split='val': these are validation
        # metrics, not training metrics.
        print("Validation metrics")
        calculate_metrics(training_session)
        print("Test metrics")
        calculate_metrics(training_session, split='test')
        _show_test_predictions(training_session, project)
        _show_latest_run_previews(task)

def button_train_eventhandler(obj):
    """Run the full pipeline for the dashboard selection: dataset, training, evaluation.

    The Train button is disabled while the pipeline runs. It is re-enabled in
    ``finally``, so an exception in any step cannot leave it stuck disabled.

    Args:
        obj: the clicked ``Button`` (unused).
    """
    button_train.disabled = True
    try:
        project = dropdown_project.value
        classes = text_classes.value.split(',')
        epochs = int_epochs.value
        load_dataset()
        training_session = train_model(project, classes, epochs)
        # Nothing to evaluate if training produced no session.
        if training_session is not None:
            test_model(training_session, project)
    finally:
        button_train.disabled = False

button_train.on_click(button_train_eventhandler)

# Assemble the dashboard: controls on the left, scrollable output in the
# centre, status line in the footer.
dashboard = widgets.AppLayout(
    header=None,
    left_sidebar=vbox_dashboard,
    center=output,
    right_sidebar=None,
    footer=label_status,
    grid_gap="10px",
)
display(dashboard)

AppLayout(children=(Label(value='', layout=Layout(grid_area='footer')), VBox(children=(Dropdown(description='P…