In [None]:
# Copyright 2022 Sony Semiconductor Solutions Corp. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import dataset into CVAT

## Imports

In [None]:
import errno
import glob
import json
import jsonschema
import pathlib
import re
import os
from pprint import pprint

from cvat_sdk import make_client
from cvat_sdk.api_client import ApiClient, Configuration, models
from cvat_sdk.core.proxies.tasks import ResourceType

## Load Configurations

Load the configuration file and set the variables.

In [None]:
def validate_symlink(path: pathlib.Path):
    """Refuse to continue when ``path`` is a symbolic link.

    Args:
        path: File or directory location to check.

    Raises:
        OSError: With ``errno.ELOOP`` when ``path`` is a symbolic link.
    """
    # Guard clause: anything that is not a symbolic link is accepted as-is.
    if not path.is_symlink():
        return
    reason = "Symbolic link is not supported. Please use real folder or file"
    raise OSError(errno.ELOOP, reason, str(path))


# Read the user-editable settings; a symbolic link is rejected before opening.
configuration_path = pathlib.Path("./configuration.json")
validate_symlink(configuration_path)
with configuration_path.open("r") as f:
    app_configuration = json.load(f)

# Read the JSON Schema used to check the settings for the import notebook.
configuration_schema_path = pathlib.Path("./configuration_schema_import.json")
validate_symlink(configuration_schema_path)
with configuration_schema_path.open("r") as f:
    json_schema = json.load(f)

# Check the settings against the schema before any value is used.
jsonschema.validate(instance=app_configuration, schema=json_schema)

# CVAT account and target project. Only the user name is read with [],
# a missing password or project id falls back to an empty string.
cvat_username = app_configuration["cvat_username"]
cvat_password = app_configuration.get("cvat_password", "")
cvat_project_id = app_configuration.get("cvat_project_id", "")

# Dataset location with the OS path separator replaced by "/", followed by
# the image file extension and the name for the CVAT task.
import_dir = "/".join(app_configuration["import_dir"].split(os.path.sep))
validate_symlink(pathlib.Path(import_dir))
import_image_extension = app_configuration["import_image_extension"]
import_task_name = app_configuration["import_task_name"]

## Set Authentication

In [None]:
# Settings for the low-level CVAT API client: the server address plus the
# user name and password taken from the configuration file.
configuration = Configuration(
    password=cvat_password,
    username=cvat_username,
    host="http://localhost:8080/",
)

## Load Dataset Files

In [None]:
def atoi(text):
    """Convert a run of decimal digits to ``int``; return anything else unchanged.

    ``str.isdecimal`` is used instead of ``str.isdigit`` because ``isdigit``
    also accepts characters such as superscript digits ("²") that ``int``
    cannot parse, which would raise ``ValueError`` for such input.

    Args:
        text: A string fragment.

    Returns:
        ``int(text)`` when ``text`` consists solely of decimal digits,
        otherwise ``text`` itself (this includes the empty string).
    """
    return int(text) if text.isdecimal() else text


def natural_keys(text):
    """Build a sort key that orders embedded numbers by value.

    The text is split on runs of digits (the runs themselves are kept because
    the pattern uses a capturing group) and every piece is passed through
    ``atoi``.

    Args:
        text: The string to build a key for, e.g. a file path.

    Returns:
        A list with one ``atoi`` result per piece, in their original order.
    """
    key = []
    for piece in re.split(r"(\d+)", text):
        key.append(atoi(piece))
    return key


# Collect every image below import_dir (sub-folders included) in natural order,
# so that e.g. "img2" sorts before "img10".
# glob.escape() keeps characters such as "[" or "*" in the directory name from
# being treated as wildcards; only the "**/*.<extension>" part is a pattern.
image_pattern = str(
    pathlib.Path(glob.escape(import_dir)) / ("**/*." + import_image_extension)
)
files_img = sorted(glob.glob(image_pattern, recursive=True), key=natural_keys)

# Fail early when no image matched.
if not files_img:
    raise FileNotFoundError(
        f"Image files for dataset not found in the import_dir: {import_dir}"
    )

## Upload Dataset and Link to Project
Import the dataset into CVAT by creating a CVAT task that contains the image files, and then link the created task to the CVAT project.

In [None]:
# Create a Client instance bound to a local server and authenticate using basic auth.
# task_id is assigned only after the task has been created, so a failed import
# cannot leave a made-up id behind in the kernel.
with make_client(
    host="http://localhost", port="8080", credentials=(cvat_username, cvat_password)
) as client:
    # Task parameters: the project the task belongs to and the task name.
    task_spec = {"project_id": cvat_project_id, "name": import_task_name}
    # Create the task from the local image files using the task repository,
    # which is accessed as a member of the Client instance.
    task = client.tasks.create_from_data(
        spec=task_spec, resource_type=ResourceType.LOCAL, resources=files_img
    )
    # If an object is modified on the server, the local object is not updated
    # automatically; fetch() refreshes it before the id is read.
    task.fetch()
    task_id = task.id

print("Import process is completed.")

print("Project and Task linking.")
with ApiClient(configuration) as api_client:
    # Request body that only sets the project of the task.
    patched_task_write_request = models.PatchedTaskWriteRequest(
        project_id=cvat_project_id,
    )
    # The task id is passed directly instead of through a cell-level variable
    # named `id`, which would shadow the built-in id() for every later cell.
    (data, response) = api_client.tasks_api.partial_update(
        task_id,
        patched_task_write_request=patched_task_write_request,
    )
    pprint(f"project id: {cvat_project_id} linked task id: {task_id}")

print("Project and Task linking completed.")