In [None]:
%%capture
#@markdown # YOLOv8 dataset training
# Install runtime dependencies (the %%capture magic above hides the pip output).
!python -m pip install onnx
!python -m pip install abraia
!python -m pip install ultralytics

import os
# Abraia API credentials. When neither environment variable is set, fall back
# to the Colab form fields below and export them into the environment via %env.
# NOTE(review): if exactly one of the two variables is already set, the form is
# skipped and the other one stays unset — confirm this is intended.
if not os.getenv('ABRAIA_ID') and not os.getenv('ABRAIA_KEY'):
    abraia_id = ''  #@param {type: "string"}
    abraia_key = ''  #@param {type: "string"}
    %env ABRAIA_ID=$abraia_id
    %env ABRAIA_KEY=$abraia_key

from abraia import Multiple

# Client used by the helpers below to list, load, download and upload project files.
multiple = Multiple()

In [None]:
from logging.config import valid_ident
#@markdown ### Dataset

import json
import glob
import shutil
import itertools
from PIL import Image
from tqdm.contrib.concurrent import process_map
from sklearn.model_selection import train_test_split


def load_projects():
    """Return the names of the remote project folders.

    Folder entries are taken from ``multiple.list_files()[1]``; the export
    folders (``export`` and ``.export``) are not projects and are skipped.
    """
    folders = multiple.list_files()[1]
    names = []
    for entry in folders:
        name = entry['name']
        if name not in ('export', '.export'):
            names.append(name)
    return names


def load_annotations(dataset):
    """Load ``<dataset>/annotations.json`` and tag each entry with its remote path.

    Every annotation dict gets a ``path`` key set to ``"<dataset>/<filename>"``
    so the image can be downloaded later.
    """
    raw_json = multiple.load_file(f"{dataset}/annotations.json")
    entries = json.loads(raw_json)
    for entry in entries:
        entry.update(path=f"{dataset}/{entry['filename']}")
    return entries


def load_labels(annotations):
    """Return the distinct object labels found in ``annotations``, sorted.

    Sorting makes the class order reproducible across runs. The previous
    ``list(set(...))`` order depended on string-hash randomization, and that
    order fixes the class indices written to the YOLO label files and the
    exported model metadata. Sorting also matches the alphabetical class-folder
    order that Ultralytics classification datasets (ImageFolder convention) use.

    Args:
        annotations: list of annotation dicts, each with an optional
            ``objects`` list of dicts carrying an optional ``label``.

    Returns:
        Sorted list of unique, non-empty labels (assumed to be strings).
    """
    found = {
        obj.get('label')
        for annotation in annotations
        for obj in annotation.get('objects', [])
    }
    return sorted(label for label in found if label)


def load_task(annotations):
    """Infer the YOLO task from the kinds of object annotations present.

    Precedence is ``'segment'`` (any object with a polygon) over ``'detect'``
    (any object with a box) over ``'classify'`` (any object with a label).
    Returns ``None`` when no object carries any of these keys.
    """
    objects = [obj for annotation in annotations for obj in annotation.get('objects', [])]
    if any('polygon' in obj for obj in objects):
        return 'segment'
    if any('box' in obj for obj in objects):
        return 'detect'
    if any('label' in obj for obj in objects):
        return 'classify'
    return None


def download_file(path, folder):
    """Download remote ``path`` into ``folder`` unless a local copy already exists.

    Returns the local destination ``folder/<basename of path>``.
    """
    local_path = os.path.join(folder, os.path.basename(path))
    if os.path.exists(local_path):
        return local_path
    multiple.download_file(path, local_path)
    return local_path


def save_annotation(annotation, folder, classes, task):
    """Write one annotation to disk in the layout Ultralytics YOLO expects.

    ``'classify'``: the image already downloaded to ``folder/<filename>`` is
    moved into its class sub-folder ``folder/<label>/`` (which must exist).
    Only the first labelled object is used, because a file can be moved only
    once. Images without any label are left where they are.

    ``'detect'`` / ``'segment'``: the image is opened from ``folder/images/``
    only to read its size. ``folder/labels/<stem>.txt`` is then written with one
    line per encodable object: ``<class index> <coords normalized by size>``.
    Boxes are treated as ``[x, y, w, h]`` with a top-left origin (as the
    arithmetic assumes — TODO confirm against the annotation tool) and
    converted to YOLO's centre format. Polygons are written as ``x1 y1 x2 y2 ...``.

    Args:
        annotation: dict with a ``filename`` and an optional ``objects`` list.
        folder: split folder, e.g. ``<dataset>/train``.
        classes: ordered class names; the list position is the class index.
        task: ``'classify'``, ``'detect'`` or ``'segment'``.
    """
    filename = annotation['filename']
    objects = annotation.get('objects', [])

    if task == 'classify':
        for obj in objects:
            label = obj.get('label')
            if label:
                src = os.path.join(folder, filename)
                dest = os.path.join(folder, label, filename)
                shutil.move(src, dest)
                # The source file no longer exists after the first move.
                break
        return

    # Only the image size is needed; close the file handle right away.
    with Image.open(os.path.join(folder, 'images', filename)) as im:
        width, height = im.width, im.height

    label_lines = []
    for obj in objects:
        label, bbox, coords = obj.get('label'), obj.get('box'), obj.get('polygon')
        if label not in classes:
            # Unlabelled objects (or unknown labels) cannot be given a class index.
            continue
        class_id = classes.index(label)
        if task == 'segment' and coords:
            points = ' '.join(f"{cord[0] / width} {cord[1] / height}" for cord in coords)
            label_lines.append(f"{class_id} {points}")
        elif task == 'detect' and bbox:
            x_center = (bbox[0] + bbox[2] / 2) / width
            y_center = (bbox[1] + bbox[3] / 2) / height
            label_lines.append(f"{class_id} {x_center} {y_center} {bbox[2] / width} {bbox[3] / height}")
        # Objects that do not match the task (e.g. a label-only object in a
        # detection dataset) are skipped instead of being written as blank lines.

    label_path = os.path.join(folder, 'labels', f"{os.path.splitext(filename)[0]}.txt")
    with open(label_path, 'w') as f:
        f.write('\n'.join(label_lines))


def save_data(annotations, folder, classes, task):
    """Download the images of one split and write their labels.

    Classification splits end up as ``folder/<label>/<image>``: images are
    downloaded into ``folder`` and then moved by ``save_annotation``.
    Detection/segmentation splits use ``folder/images`` and ``folder/labels``.
    """
    classify = task == 'classify'
    image_folder = folder if classify else os.path.join(folder, 'images')
    os.makedirs(image_folder, exist_ok=True)

    remote_paths = [item['path'] for item in annotations]
    process_map(download_file, remote_paths, itertools.repeat(image_folder), max_workers=5)

    if classify:
        label_folders = [os.path.join(folder, name) for name in classes]
    else:
        label_folders = [os.path.join(folder, 'labels')]
    for label_folder in label_folders:
        os.makedirs(label_folder, exist_ok=True)

    for item in annotations:
        save_annotation(item, folder, classes, task)


def save_config(dataset, classes):
    """Write ``<dataset>/data.yaml`` with the split image folders and class names.

    Each value is emitted as a JSON scalar/array. JSON flow syntax is valid
    YAML, so paths or class names containing YAML-significant text such as
    ``': '``, ``' #'`` or backslashes are quoted correctly. The previous
    unquoted paths and Python ``repr`` of the list were not.

    Args:
        dataset: dataset folder relative to the current working directory
            (an absolute path also works, since ``os.path.join`` keeps it).
        classes: ordered class names; the list position is the class index.
    """
    root = os.path.join(os.getcwd(), dataset)
    entries = {
        'train': os.path.join(root, 'train', 'images'),
        'val': os.path.join(root, 'val', 'images'),
        'test': os.path.join(root, 'test', 'images'),
        'names': list(classes),
    }
    lines = [f"{key}: {json.dumps(value, ensure_ascii=False)}" for key, value in entries.items()]
    path = os.path.join(dataset, 'data.yaml')
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')


def split_dataset(annotations, holdout_size=0.3, test_fraction=0.5, random_state=None):
    """Randomly split annotations into train, validation and test lists.

    With the defaults, 70% goes to train and the held-out 30% is halved into
    validation and test, matching the original fixed split.

    Args:
        annotations: list of annotation dicts.
        holdout_size: ``test_size`` of the first split, i.e. the fraction (or
            count) held out of the training set.
        test_fraction: ``test_size`` of the second split, i.e. the part of the
            held-out set used for test; the rest becomes validation.
        random_state: seed forwarded to both ``train_test_split`` calls. Pass
            an int for a reproducible split; ``None`` keeps the previous
            non-deterministic behaviour.

    Returns:
        ``(train, val, test)``.

    Raises:
        ValueError: raised by ``train_test_split`` when there are too few
            annotations to form non-empty splits.
    """
    train, holdout = train_test_split(annotations, test_size=holdout_size, random_state=random_state)
    val, test = train_test_split(holdout, test_size=test_fraction, random_state=random_state)
    return train, val, test


def create_dataset(dataset, task, classes):
    """Materialize the local train/val/test dataset for a project and write its config.

    Annotations are reloaded from the remote project and split. Each split is
    written under ``<dataset>/<split>``, then ``<dataset>/data.yaml`` is saved.
    """
    split_names = ('train', 'val', 'test')
    splits = dict(zip(split_names, split_dataset(load_annotations(dataset))))
    for split_name, split_annotations in splits.items():
        save_data(split_annotations, f"{dataset}/{split_name}", classes, task)
    save_config(dataset, classes)


from ultralytics import YOLO
from PIL import ImageDraw


def build_model_name(model_name, task):
    """Return the Ultralytics checkpoint name for ``task``.

    Segmentation and classification checkpoints carry a ``-seg`` / ``-cls``
    suffix; any other task (e.g. ``'detect'``) returns ``model_name`` as is.
    """
    for task_name, suffix in (('segment', '-seg'), ('classify', '-cls')):
        if task == task_name:
            return f"{model_name}{suffix}"
    return model_name


def train_model(dataset, task, batch=8, epochs=7, imgsz=640, device="cpu", base_model='yolov8n'):
    """Fine-tune a pretrained YOLOv8 checkpoint on the local dataset.

    Args:
        dataset: local dataset folder produced by ``create_dataset``.
        task: ``'classify'``, ``'detect'`` or ``'segment'``; selects the
            checkpoint suffix and how the data is passed to YOLO.
        batch, epochs, imgsz: forwarded to ``YOLO.train``.
        device: training device forwarded to ``YOLO.train`` (default ``"cpu"``).
        base_model: checkpoint family to start from (default ``'yolov8n'``).

    Returns:
        ``(model, model_name)``: the trained model and the checkpoint name
        (e.g. ``'yolov8n-seg'``), used later to name the exported files.
    """
    model_name = build_model_name(base_model, task)
    model = YOLO(f"{model_name}.pt")
    # Classification takes the dataset folder itself; detection and
    # segmentation take the data.yaml written by save_config.
    data = f"{dataset}" if task == 'classify' else f"{dataset}/data.yaml"
    model.train(data=data, batch=batch, epochs=epochs, imgsz=imgsz, device=device)
    return model, model_name


def save_model(model, model_name, dataset, task, classes, imgsz=640):
    """Export ``model`` to ONNX and upload it together with a JSON metadata file.

    Remote files written: ``<dataset>/<model_name>.onnx`` and
    ``<dataset>/<model_name>.json``. The JSON holds the task, an input shape of
    ``[1, 3, imgsz, imgsz]`` and the class list.
    """
    remote_base = f"{dataset}/{model_name}"
    onnx_path = model.export(format="onnx")
    multiple.upload_file(onnx_path, f"{remote_base}.onnx")
    metadata = {'task': task, 'inputShape': [1, 3, imgsz, imgsz], 'classes': classes}
    multiple.save_json(f"{remote_base}.json", metadata)


def plot_results(src, results):
    """Draw the first prediction in ``results`` on the image at ``src`` and display it.

    What is drawn depends on the model type:
    - Segmentation: polygons outlined in yellow.
    - Detection: boxes in green.
    - Classification: the top-1 class name and confidence, written in the
      top-left corner.
    """
    result = results[0]
    with Image.open(src) as source:
        img = source.convert('RGB')
    draw = ImageDraw.Draw(img)
    if result.masks is not None:
        for polygon in result.masks.xy:
            points = [(float(x), float(y)) for x, y in polygon]
            # Tiny masks can yield empty or degenerate contours.
            if len(points) >= 3:
                draw.polygon(points, outline="#ffff00", width=2)
    if result.boxes is not None:
        # Use xyxy rather than unpacking raw ``data`` rows, whose column count
        # varies (e.g. an extra track-id column).
        for x1, y1, x2, y2 in result.boxes.xyxy.tolist():
            draw.rectangle((x1, y1, x2, y2), outline="#00ff00", width=3)
    if result.probs is not None:
        probs = result.probs
        draw.text((0, 0), f"{result.names[probs.top1]} {float(probs.top1conf):.2f}")
    display(img)


import ipywidgets as widgets
from IPython.display import display
from IPython.display import Image as show_image

projects = load_projects()

# Project selector. An account without projects gets an empty dropdown
# (value=None) instead of an IndexError on projects[0].
dropdown_project = widgets.Dropdown(description='Project', options=projects, value=projects[0] if projects else None)
# Filled from the selected project's annotations; both remain user-editable.
text_classes = widgets.Text(description='Classes', value='')
text_task = widgets.Text(description='Task', value='')
# One-line progress/status message.
label_status = widgets.Label(value='')

def dropdown_project_eventhandler(change):
    """Load the selected project's annotations and prefill classes and task.

    Clears any previous training output and infers the classes and task from
    the annotations. It then shows them in the text boxes and enables the
    *Train* button. ``change`` is not used.
    """
    for output in (output_train, output_labels, output_pred):
        output.clear_output()

    label_status.value = 'Loading annotations...'
    annotations = load_annotations(dropdown_project.value)
    detected_classes = load_labels(annotations)
    detected_task = load_task(annotations)
    text_classes.value = str(detected_classes)
    text_task.value = str(detected_task)
    label_status.value = 'Annotations loaded.'

    button_train.disabled = False

dropdown_project.observe(dropdown_project_eventhandler, names='value')

# Train button; created disabled and enabled by dropdown_project_eventhandler
# once a project's annotations have been loaded.
button_train = widgets.Button(description='Train', disabled=True)
# Bordered, scrollable area that captures the training run's output.
output_train = widgets.Output(layout={'border': '1px solid black', 'width': '50%', 'height': '360px', 'overflow': 'auto'})

# Areas for the val_batch0 label and prediction images displayed after training.
output_labels = widgets.Output()
output_pred = widgets.Output()

def button_train_eventhandler(obj):
    """Build the dataset, train and upload a YOLO model, then show its results.

    Runs when the *Train* button is clicked (``obj`` is unused).
    - Progress is reported in ``label_status``; training output goes to
      ``output_train``.
    - On success, the validation label/prediction images from the training run
      folder are shown in ``output_labels`` and ``output_pred``.
    - The button is disabled while the job runs and re-enabled even if a step
      fails.
    """
    import ast

    output_train.clear_output()
    output_labels.clear_output()
    output_pred.clear_output()
    run_dir = None
    with output_train:
        button_train.disabled = True
        try:
            dataset = dropdown_project.value
            # The classes box is user-editable text: parse it as a Python
            # literal instead of executing it with eval().
            classes = list(ast.literal_eval(text_classes.value))
            task = text_task.value

            label_status.value = 'Downloading dataset...'
            create_dataset(dataset, task, classes)

            image_size = 128
            label_status.value = 'Training model...'
            model, model_name = train_model(dataset, task, imgsz=image_size)
            save_model(model, model_name, dataset, task, classes, imgsz=image_size)
            label_status.value = 'Model saved'

            # Test images live in test/<label>/ (classify) or test/images/
            # (detect/segment); accept common image extensions, not only PNG.
            test_images = sorted(
                path for path in glob.glob(f"{dataset}/test/*/*")
                if os.path.splitext(path)[1].lower() in ('.png', '.jpg', '.jpeg')
            )
            if test_images:
                src = test_images[0]
                plot_results(src, model.predict(src))

            # Ultralytics numbers repeated runs (train, train2, ...), so prefer
            # the trainer's actual save folder over a hard-coded path.
            trainer = getattr(model, 'trainer', None)
            run_dir = str(getattr(trainer, 'save_dir', None) or f"runs/{task}/train")
        except Exception:
            label_status.value = 'Training failed.'
            raise
        finally:
            button_train.disabled = False

    if run_dir is None:
        return
    for output, image_name in ((output_labels, 'val_batch0_labels.jpg'), (output_pred, 'val_batch0_pred.jpg')):
        image_path = os.path.join(run_dir, image_name)
        if os.path.exists(image_path):
            with output:
                display(show_image(filename=image_path))

button_train.on_click(button_train_eventhandler)

# Layout: controls on top, training log plus label/prediction images below.
hbox_output = widgets.HBox([output_train, output_labels, output_pred])
dashboard = widgets.VBox([dropdown_project, text_classes, text_task, button_train, label_status, hbox_output])
display(dashboard)

# `observe` only fires when the value *changes*, so load the initially selected
# project explicitly. Otherwise Train stays disabled until another project is
# picked, which is impossible when there is only one project.
if dropdown_project.value:
    dropdown_project_eventhandler(None)


VBox(children=(Dropdown(description='Project', options=('camera', 'circles', 'edge', 'final-demo', 'hymenopter…