In [None]:
import sys
import os

# Add the path to the project directory
utils_path = os.path.abspath('..')
if utils_path not in sys.path:
    sys.path.append(utils_path)
# Add the path to the directory containing utils to sys.path
utils_path = os.path.abspath('../utils')
if utils_path not in sys.path:
    sys.path.append(utils_path)
print(sys.path)

### Authenticate the client

Instantiate a training and prediction client with your endpoint and keys. 

In [2]:
from azure.cognitiveservices.vision.customvision.training import CustomVisionTrainingClient
from azure.cognitiveservices.vision.customvision.prediction import CustomVisionPredictionClient
from msrest.authentication import ApiKeyCredentials
from dotenv import load_dotenv
import os

# Load environment variables
load_dotenv()
print(os.getenv("TRAINING_KEY"), os.getenv("TRAINING_ENDPOINT"))

# Authentication
credentials = ApiKeyCredentials(in_headers={"Training-key": os.getenv("TRAINING_KEY")})
trainer = CustomVisionTrainingClient(endpoint=os.getenv("TRAINING_ENDPOINT"), credentials=credentials)

# Authentication for prediction
prediction_credentials = ApiKeyCredentials(in_headers={"Prediction-key": os.getenv("PREDICTION_KEY")})
predictor = CustomVisionPredictionClient(endpoint=os.getenv("PREDICTION_ENDPOINT"), credentials=prediction_credentials)


90dad624b6664556accbcfd69e2e170d https://crackdetection.cognitiveservices.azure.com/


### Creat or get the project

In [3]:
# Find the object detection domain
obj_detection_domain = next(domain for domain in trainer.get_domains() if domain.type == "ObjectDetection" and domain.name == "General")

# Project name setup
project_name = "WRB-Bad-Detection"

# Find project by name
print(f"Searching for project '{project_name}'...")
projects = trainer.get_projects()

project_id = None
for project in projects:
    if project.name == project_name:
        project_id = project.id
        break

if project_id:
    print(f"Project '{project_name}' found with ID: {project_id}")
else:
    print(f"No project found with the name '{project_name}'")
    # Create a new project
    print ("Creating project...")
    project = trainer.create_project(project_name, domain_id=obj_detection_domain.id)


Searching for project 'WRB-Bad-Detection'...
Project 'WRB-Bad-Detection' found with ID: 1ee0bb48-3b3f-419d-a575-e12c98f91578


### Train Model

Train the model.

In [4]:
import time

# Force training
try:
    iteration = trainer.train_project(
        project_id, 
        training_type="Advanced",
        reserved_budget_in_hours=2,
        force_train=True)

    print(f"Training started for iteration {iteration.id}")
    while (iteration.status != "Completed"):
        iteration = trainer.get_iteration(project.id, iteration.id)
        print ("Training status: " + iteration.status)
        time.sleep(1)

    iteration_id = iteration.id

except Exception as e:
    print(f"Error occurred during training: {e}")

Training started for iteration 19e31227-65c6-4d8a-bafd-aa01bd35c600


In [4]:
# Get iterations
print(f"Fetching iterations for project '{project_name}'(id:'{project_id}')...")
iterations = trainer.get_iterations(project_id)

# List iteration IDs
print("Iteration:")
for iteration in iterations:
    # print(iteration)
    print(f"{iteration.name}, Created at {iteration.created}, Last modified at {iteration.last_modified}")

Fetching iterations for project 'WRB-Bad-Detection'(id:'1ee0bb48-3b3f-419d-a575-e12c98f91578')...
Iteration:
Iteration 3, Created at 2024-06-29 07:41:11.493000+00:00, Last modified at 2024-06-29 09:13:19.213000+00:00
Iteration 2, Created at 2024-06-29 07:19:07.863000+00:00, Last modified at 2024-06-29 07:48:31.821000+00:00
Iteration 1, Created at 2024-06-28 20:03:03.686000+00:00, Last modified at 2024-06-29 07:27:01.080000+00:00


### Publish the current iteration

An iteration is not available in the prediction endpoint until it is published. The following code makes the current iteration of the model available for querying.

In [None]:
publish_iteration_name = "detectModel"

# The iteration is now trained. Publish it to the project endpoint
trainer.publish_iteration(
    project.id, 
    iteration_id, 
    publish_iteration_name, 
    os.getenv("PREDICTION_RESOURCE_ID")
    )
print ("Done!")