# In [None]:  (Jupyter cell prompt retained from the notebook export)
import os
from pathlib import Path

from azure.ai.ml import MLClient
from azure.ai.ml import load_component
from azure.ai.ml.entities import CommandComponent, Input, Output
from azure.ai.ml.entities import CommandJob
from azure.identity import DefaultAzureCredential


# Workspace coordinates. Each value is read from the environment variable of
# the same name; to hard-code a value instead, replace the '' fallback.
SUBSCRIPTION_ID = os.environ.get('SUBSCRIPTION_ID', '')
RESOURCE_GROUP = os.environ.get('RESOURCE_GROUP', '')
WORKSPACE = os.environ.get('WORKSPACE', '')


# Keyword arguments make it explicit which coordinate goes where, so a
# reordering of the constants above cannot silently target another workspace.
credential = DefaultAzureCredential()
mlc = MLClient(
    credential=credential,
    subscription_id=SUBSCRIPTION_ID,
    resource_group_name=RESOURCE_GROUP,
    workspace_name=WORKSPACE,
)


# Register components from YAML.
# Each component lives in ./components/<name>/component.yml.
components_dir = Path('./components')
for comp in ('preprocessing', 'training', 'validation'):
    yaml_path = components_dir / comp / 'component.yml'
    print('Registering', yaml_path)
    # create_or_update() operates on a Component entity, not on a file path,
    # so the YAML definition has to be parsed into an entity first.
    component = load_component(source=yaml_path)
    mlc.components.create_or_update(component)


# Example run orchestrating components: preprocess the raw CSV, then train on
# the cleaned file. Each step is submitted as a standalone command job.
from azure.ai.ml import Input as AMLInput


# Job outputs are written by the remote job, so they must point at cloud
# storage rather than at a path on the submitting machine. The same URIs are
# then reused as inputs of the downstream steps.
# NOTE(review): assumes the workspace still has its default blob datastore
# named 'workspaceblobstore' - confirm, or point this at another datastore.
DATASTORE_ROOT = 'azureml://datastores/workspaceblobstore/paths'

# A command job cannot be submitted without an environment.
# NOTE(review): this curated scikit-learn environment is a presumed default;
# confirm it provides everything preprocess.py and train.py import.
JOB_ENVIRONMENT = 'azureml://registries/azureml/environments/sklearn-1.5/labels/latest'

# The raw data is a local file, which the SDK uploads on submission.
raw_data = AMLInput(type='uri_file', path='./data/diabetes.csv')
preprocess_out = f'{DATASTORE_ROOT}/outputs/preprocessed/diabetes_clean.csv'
train_model_out = f'{DATASTORE_ROOT}/outputs/model/model.pkl'
metrics_out = f'{DATASTORE_ROOT}/outputs/model/metrics.json'


def _submit_and_wait(job):
    """Submit *job*, stream its logs until it ends, and return the submitted job.

    The value returned by ``jobs.stream()`` is deliberately not kept: it is
    ``None``, and storing it would discard the job handle.
    """
    submitted = mlc.jobs.create_or_update(job)
    mlc.jobs.stream(submitted.name)
    return submitted


# Run preprocessing job
pre_cmd = CommandJob(
    display_name='preprocess-job',
    command='python preprocess.py --input ${{inputs.input_data}} --output ${{outputs.output_data}}',
    inputs={'input_data': raw_data},
    outputs={'output_data': Output(type='uri_file', path=preprocess_out)},
    code='./components/preprocessing',
    environment=JOB_ENVIRONMENT,
)
pre_resp = _submit_and_wait(pre_cmd)


# Run training job on the file produced by the preprocessing job
train_cmd = CommandJob(
    display_name='train-job',
    command='python train.py --input ${{inputs.train_data}} --model_out ${{outputs.model}} --metrics_out ${{outputs.metrics}} --n_estimators 100',
    inputs={'train_data': AMLInput(type='uri_file', path=preprocess_out)},
    outputs={
        'model': Output(type='uri_file', path=train_model_out),
        'metrics': Output(type='uri_file', path=metrics_out),
    },
    code='./components/training',
    environment=JOB_ENVIRONMENT,
)
train_resp = _submit_and_wait(train_cmd)



from azure.ai.ml.entities import Model

# Describe the trained artifact as a Model entity first, then hand that entity
# to the workspace; the call returns the model as registered.
model_entity = Model(name='diabetes_rf_model', path=train_model_out)
model = mlc.models.create_or_update(model_entity)
print('Registered model:', model.name)