In [None]:
# Links
# https://github.com/v-loves-avocados/chicago-taxi
# Pipeline based NYC taxi: https://github.com/LilHelmer/NYC-Taxi-Fare-Prediction/blob/master/ny_taxi_fare_prediction.ipynb
# https://www.kaggle.com/margulisshahar/taxi-fare-prediction-and-tip-classification
# NYC Taxi: https://www.kaggle.com/aiswaryaramachandran/eda-and-feature-engineering
# Training idea: https://medium.com/analytics-vidhya/machine-learning-to-predict-taxi-fare-part-two-predictive-modelling-f80461a8072e
# GCP Scikit guide: https://cloud.google.com/ai-platform/training/docs/training-scikit-learn

import os
from IPython.core.display import display, HTML
from datetime import datetime
import mlflow
import pymysql

In [None]:
# Jupyter magic jinja template to create Python file with variable substitution.
# Dictionaries available to the template: env[] for OS environment variables and var[] for notebook global variables
from IPython.core.magic import register_line_cell_magic
from jinja2 import Template

@register_line_cell_magic
def writetemplate(line, cell):
    """Cell magic: render the cell body as a Jinja template and write it to a file.

    Usage: put ``%%writetemplate path/to/file.py`` on the first line of a cell.
    The template context contains:
      * ``env`` - ``os.environ`` (OS environment variables),
      * ``var`` - this notebook's global variables.
    Missing parent directories of the target path are created. The cell is
    rendered before the file is opened, so a template error does not leave a
    truncated/empty file behind.

    Raises:
        ValueError: if no target path is given or the magic is used as a line magic.
    """
    path = line.strip()  # tolerate extra spaces after the magic name
    if not path:
        raise ValueError("%%writetemplate needs a target file path")
    if cell is None:
        raise ValueError("%%writetemplate must be used as a cell magic")
    rendered = Template(cell).render({'env': os.environ, 'var': globals()})
    dirname = os.path.dirname(path)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(rendered)

#### Global parameters

In [None]:
# How many AI Platform training jobs the generated DAG launches in parallel.
number_of_parallel_trainings = 2
# MLflow experiment name; the DAG also uses it as the model name in the MLflow model registry.
experiment_name = "chicago-taxi-m1"

In [None]:
# Select (or create) the MLflow experiment used by this notebook.
mlflow.set_experiment(experiment_name)
mlflow_tracking_uri = mlflow.get_tracking_uri()

# Deployment settings come from environment variables (KeyError if one is missing).
MLFLOW_TRACKING_EXTERNAL_URI = os.environ["MLFLOW_TRACKING_EXTERNAL_URI"]  # public MLflow UI address
REGION = os.environ["MLOPS_REGION"]
ML_IMAGE_URI = os.environ["ML_IMAGE_URI"]
COMPOSER_NAME = os.environ["MLOPS_COMPOSER_NAME"]
MLFLOW_GCS_ROOT_URI = os.environ["MLFLOW_GCS_ROOT_URI"]

print(f"Cloud Composer instance name: {COMPOSER_NAME}")
print(f"Cloud Composer region: {REGION}")
print(f"MLflow tracking server URI: {mlflow_tracking_uri}")
print(f"MLflow GCS root: {MLFLOW_GCS_ROOT_URI}")

experiment_path = MLFLOW_GCS_ROOT_URI.replace("gs://","")
display(HTML('<hr>You can check results of this test in MLflow and GCS folder:'))
display(HTML(f'<h4><a href="{MLFLOW_TRACKING_EXTERNAL_URI}" rel="noopener noreferrer" target="_blank">Click to open MLflow UI</a></h4>'))
display(HTML(f'<h4><a href="https://console.cloud.google.com/storage/browser/{experiment_path}/experiments" rel="noopener noreferrer" target="_blank">Click to open MLFlow GCS folder</a></h4>'))

# Skeleton of the trainer package. Pure Python instead of `!mkdir`/`!touch`:
# a failure raises here instead of only printing a shell error and continuing.
os.makedirs("./package/training", exist_ok=True)
init_file = "./package/training/__init__.py"
with open(init_file, "a"):
    pass
os.utime(init_file, None)  # same as `touch`: refresh the mtime of an existing file

In [None]:
%%writefile ./package/setup.py
from setuptools import find_packages, setup

# Dependencies declared for the 'trainer' package (pinned versions).
REQUIRED_PACKAGES = [
    'mlflow==1.11.0',
    'PyMySQL==0.9.3',
]

setup(
    name='trainer',
    version='0.1',
    description='Customer training setup.',
    packages=find_packages(),
    include_package_data=True,
    install_requires=REQUIRED_PACKAGES,
)

In [None]:
%%writetemplate ./package/training/task.py

import sys, stat
import argparse
import os
import numpy as np
import pandas as pd
import glob
from scipy import stats

from sklearn.linear_model import LogisticRegression # Only for train_test
from sklearn.ensemble import RandomForestRegressor
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

import mlflow
import mlflow.sklearn
from mlflow.models.signature import infer_signature

from joblib import dump, load
from google.cloud import storage

# Field separator of the CSV splits exported by the DAG (BigQueryToCloudStorageOperator field_delimiter).
csv_delimiter = '|'

def copy_local_directory_to_gcs(local_path, gcs_uri):
    """Recursively upload the local directory ``local_path`` under ``gcs_uri``.

    Files are written to ``<gcs_uri>/<basename of local_path>/<relative path>``,
    e.g. copying 'trainer_0.joblib' to 'gs://bucket/exp/dmt_1' produces objects
    under 'gs://bucket/exp/dmt_1/trainer_0.joblib/' - the path train_model()
    logs as 'model_file'.

    Args:
        local_path: existing local directory to upload.
        gcs_uri: destination URI of the form 'gs://bucket[/prefix]'.
    Raises:
        NotADirectoryError: if ``local_path`` is not a directory.
    """
    if not os.path.isdir(local_path):
        raise NotADirectoryError(local_path)
    # Split 'gs://bucket/some/prefix' into 'bucket' and 'some/prefix'.
    # (str.lstrip() removes a *set of characters*, not a prefix, so it must not
    # be used to cut off the bucket name: 'my-bucket/exp'.lstrip('my-bucket/')
    # yields 'xp'.)
    uri = gcs_uri[len('gs://'):] if gcs_uri.startswith('gs://') else gcs_uri
    bucket_id, _, prefix = uri.partition('/')
    prefix = prefix.strip('/')
    dir_name = os.path.basename(os.path.normpath(local_path))
    destination = f'{prefix}/{dir_name}' if prefix else dir_name
    bucket = storage.Client().bucket(bucket_id)
    _upload_local_to_gcs(local_path, bucket, destination)
        
def _upload_local_to_gcs(local_path, bucket, bucket_path):
    for local_file in glob.glob(local_path + '/**'):
        if not os.path.isfile(local_file):
           _upload_local_to_gcs(local_file, bucket, bucket_path + "/" + os.path.basename(local_file))
        else:
           remote_path = os.path.join(bucket_path, local_file[1 + len(local_path):])
           blob = bucket.blob(remote_path)
           blob.upload_from_filename(local_file)

def feature_engineering(data):
    """Clean raw taxi trips and split them into model features and target.

    Steps:
      1. fill missing 'company' with 'N/A' and missing 'tolls' with 0,
      2. drop rows that still contain any null,
      3. add 'abs_distance' = Euclidean pickup/dropoff lat/lon distance * 100,
      4. drop rows whose |z-score| is >= 3 in any of trip_seconds, trip_miles,
         fare or abs_distance (a constant column never removes rows),
      5. round the four coordinates to 3 decimals.

    The caller's DataFrame is not modified.

    Args:
        data: pandas DataFrame with the columns exported by the DAG's BQ query.
    Returns:
        Tuple (X, y): X is the cleaned frame without 'fare' and
        'trip_start_timestamp'; y is the 'fare' Series.
    """
    data = data.fillna(value={'company': 'N/A', 'tolls': 0})
    data = data.dropna(how='any', axis='rows')
    # Pickup and dropoff locations distance (in degrees, scaled by 100)
    data = data.assign(abs_distance=np.hypot(
        data['dropoff_latitude'] - data['pickup_latitude'],
        data['dropoff_longitude'] - data['pickup_longitude']) * 100)

    # Remove extremes, outliers
    possible_outliers_cols = ['trip_seconds', 'trip_miles', 'fare', 'abs_distance']
    with np.errstate(divide='ignore', invalid='ignore'):
        z_scores = np.abs(stats.zscore(data[possible_outliers_cols].to_numpy(dtype=float)))
    # A constant column has standard deviation 0 and therefore NaN z-scores; since
    # NaN < 3 is False, every row would be discarded. Constant columns have no outliers.
    z_scores = np.where(np.isnan(z_scores), 0.0, z_scores)
    data = data[(z_scores < 3).all(axis=1)].copy()
    # Reduce location accuracy
    data = data.round({'pickup_latitude': 3, 'pickup_longitude': 3, 'dropoff_latitude': 3, 'dropoff_longitude': 3})

    # Returns training only features (X) and fare (y)
    return (
        data.drop(['fare', 'trip_start_timestamp'], axis=1),
        data['fare'],
    )

def build_pipeline(number_of_estimators = 20, max_features = 1.0):
    """Build the (unfitted) preprocessing + random-forest regression pipeline.

    Preprocessing is a ColumnTransformer (unlisted columns are dropped, which is
    ColumnTransformer's default):
      * 'trip_start_hour' one-hot encoded over hours 0..23,
      * 'trip_start_day_of_week' one-hot encoded over abbreviated weekday names,
      * numeric trip/location columns standardized with StandardScaler.

    Args:
        number_of_estimators: number of trees of the RandomForestRegressor.
        max_features: forwarded to RandomForestRegressor. The default 1.0 (all
            features) is what 'auto' meant for regressors; 'auto' was deprecated
            in scikit-learn 1.1 and is rejected since 1.3.
    Returns:
        sklearn Pipeline with steps 'ct' (ColumnTransformer) and 'forest_reg'.
    """
    import inspect

    # scikit-learn 1.2 renamed OneHotEncoder(sparse=...) to sparse_output=...
    # and 1.4 removed the old keyword; use whichever this installation accepts.
    encoder_params = inspect.signature(OneHotEncoder).parameters
    dense_flag = 'sparse_output' if 'sparse_output' in encoder_params else 'sparse'

    def dense_one_hot(categories):
        # One-hot encoder over a fixed category list with dense (ndarray) output.
        return OneHotEncoder(categories=[categories], **{dense_flag: False})

    ct_pipe = ColumnTransformer(transformers=[
        ('hourly_cat', dense_one_hot(list(range(0, 24))), ["trip_start_hour"]),
        ('dow', dense_one_hot(['Mon', 'Tue', 'Sun', 'Wed', 'Sat', 'Fri', 'Thu']), ["trip_start_day_of_week"]),
        ('std_scaler', StandardScaler(), [
            'trip_start_year',
            'abs_distance',
            'pickup_longitude',
            'pickup_latitude',
            'dropoff_longitude',
            'dropoff_latitude',
            'trip_miles',
            'trip_seconds']),
    ])
    rfr_pipe = Pipeline([
        ('ct', ct_pipe),
        # n_jobs=-1 uses all available cores; the fixed random_state makes fits repeatable.
        ('forest_reg', RandomForestRegressor(n_estimators=number_of_estimators, max_features=max_features,
                                             n_jobs=-1, random_state=3)),
    ])
    return rfr_pipe

def train_model(args):
    """Train, evaluate and export one taxi-fare model, tracking it in MLflow.

    Inside a nested MLflow run of experiment ``args.experiment_name``:
      1. logs parameters/tags; returns early (after a message) if a source URI is missing,
      2. reads the training CSV (csv_delimiter separated), engineers features and
         logs the 5-fold cross-validated training RMSE,
      3. fits the pipeline on all training rows and logs it as artifact
         'chicago_rnd_forest',
      4. reads the evaluation CSV and logs
         - 'eval_rmse': RMSE of the fitted model on the evaluation rows, and
         - 'eval_cross_valid_score_rmse_mean': 5-fold CV RMSE of the same pipeline
           re-fitted on evaluation folds (kept because the DAG selects models by it),
      5. saves an MLflow model directory named '<version_tag>.joblib' and copies it
         to ``args.job_dir``.

    Args:
        args: argparse.Namespace from main() (experiment_name, number_of_estimators,
            version_tag, job_name, gcs_train_source, gcs_eval_source, job_dir).
    """
    print("Taxi fare estimation model training step started...")
    mlflow.set_experiment(args.experiment_name)
    with mlflow.start_run(nested=True):
        mlflow.log_param("number_of_estimators", args.number_of_estimators)
        mlflow.set_tag("version", args.version_tag)
        mlflow.set_tag("job_name", args.job_name)
        mlflow.log_param("gcs_train_source", args.gcs_train_source)
        if not args.gcs_train_source:
            print("Missing GCS training source URI")
            return
        mlflow.log_param("gcs_eval_source", args.gcs_eval_source)
        if not args.gcs_eval_source:
            print("Missing GCS evaluation source URI")
            return

        # Training split: cross-validate, then fit on all rows.
        train_df = pd.read_csv(args.gcs_train_source, sep=csv_delimiter)
        mlflow.log_param('training_shape', f'{train_df.shape}')
        X_train, y_train = feature_engineering(train_df)
        rfr_pipe = build_pipeline(number_of_estimators=args.number_of_estimators)
        train_scores = cross_val_score(rfr_pipe, X_train, y_train, scoring="neg_mean_squared_error", cv=5)
        mlflow.log_metric("train_cross_valid_score_rmse_mean", np.sqrt(-train_scores).mean())
        final_model = rfr_pipe.fit(X_train, y_train)
        mlflow.sklearn.log_model(final_model, "chicago_rnd_forest")

        # Evaluation split: holdout error of the fitted model.
        eval_df = pd.read_csv(args.gcs_eval_source, sep=csv_delimiter)
        mlflow.log_param('eval_shape', f'{eval_df.shape}')
        X_eval, y_eval = feature_engineering(eval_df)
        fare_pred = final_model.predict(X_eval)
        mlflow.log_metric("eval_rmse", float(np.sqrt(np.mean((np.asarray(y_eval) - fare_pred) ** 2))))
        # cross_val_score clones the pipeline and re-fits it on folds of the eval
        # split, so this metric does not measure final_model itself.
        eval_scores = cross_val_score(final_model, X_eval, y_eval, scoring='neg_mean_squared_error', cv=5)
        mlflow.log_metric("eval_cross_valid_score_rmse_mean", np.sqrt(-eval_scores).mean())

        # Save model
        model_file_name = f'{args.version_tag}.joblib'
        mlflow.sklearn.save_model(final_model, model_file_name)
        copy_local_directory_to_gcs(model_file_name, args.job_dir)
        mlflow.log_param('model_file', args.job_dir + '/' + model_file_name)

    print("Training finished.")

def main():
    """Command-line entry point of the training job.

    Parses the trainer arguments (unknown arguments are ignored), points MLflow
    at http://127.0.0.1:80 and runs train_model().
    """
    print("Training arguments: " + " ".join(sys.argv[1:]))
    parser = argparse.ArgumentParser()
    # Default matches build_pipeline(); without it a missing flag would pass
    # n_estimators=None to the RandomForestRegressor.
    parser.add_argument("--number_of_estimators", type=int, default=20)
    parser.add_argument("--job-dir", type=str)
    parser.add_argument("--local_data", type=str)
    parser.add_argument("--gcs_train_source", type=str)
    parser.add_argument("--gcs_eval_source", type=str)
    parser.add_argument("--experiment_name", type=str)
    parser.add_argument("--version_tag", type=str)
    parser.add_argument("--job_name", type=str)

    args, _unknown_args = parser.parse_known_args()

    # CLOUD_ML_JOB contains other CAIP Training runtime parameters in a JSON object:
    # job = os.environ["CLOUD_ML_JOB"]

    # MLflow is expected to be reachable locally inside the training container.
    mlflow.set_tracking_uri("http://127.0.0.1:80")

    train_model(args)

if __name__ == "__main__":
    main()

In [None]:
# Create trainer packege
!cd package && python ./setup.py sdist

# Copy to Composer data folder
!gcloud composer environments storage data import \
    --environment {COMPOSER_NAME} \
    --location {REGION} \
    --source ./package/dist \
    --destination multi_model_trainer_dag

In [None]:
# Copy package files to composer 'data' folder
!gcloud composer environments storage data import \
    --environment {COMPOSER_NAME} \
    --location {REGION} \
    --source ./package \
    --destination multi_model_trainer_dag


#### Create model trainer Airflow DAG
Notice: The entire cell below is a template that will be written to the 'multi_model_trainer_dag.py' file.
        The 'writetemplate' magic renders it with Jinja, and Airflow later applies its own Jinja templating to runtime parameters.
        Airflow parameters must therefore be escaped like this: {{ "{{ ts_nodash }}" }}, because they are templates inside a template: the first rendering produces the Airflow expression, which Airflow renders at run time.

In [None]:
%%writetemplate multi_model_trainer_dag.py

# Train multiple models in separate AI Platform Training Jobs (PythonOperator)
#  Input: data in GCS
#  Output: model1.joblib model2.joblib
#  Note: eval metric (one eval split) is stored in MLflow

# Evaluate the previous model on the current eval split
#  Input: experiment Id (fetch the last (registered) model)
#  Output: eval stored in MLflow for the previous model

# Validate the model (PythonOperator)
#  Input: MLflow metric
#  Output: which model (path) to register

# Register the model (PythonOperator) 
#  Input: Path of the winning model
#  Output: Model in specific GCS location
#  Registering model to MLFlow

import os
import logging
from datetime import (datetime, timedelta)
import random
import uuid

import mlflow
import mlflow.sklearn

#import tensorflow_data_validation as tfdv

import airflow
from airflow import DAG
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator
# TODO: Change to airflow.providers
from airflow.contrib.operators.bigquery_operator import BigQueryOperator
from airflow.contrib.operators.bigquery_table_delete_operator import BigQueryTableDeleteOperator
from airflow.contrib.operators.bigquery_to_gcs import BigQueryToCloudStorageOperator
from airflow.providers.google.cloud.operators.mlengine import MLEngineStartTrainingJobOperator

# Field separator of the exported CSV splits; the trainer reads them with the same '|'.
csv_delimiter = '|'
# Values substituted by the notebook's 'writetemplate' magic when this file is
# generated (notebook globals and environment variables).
experiment_name = "{{ var['experiment_name'] }}"
ML_IMAGE_URI = "{{ var['ML_IMAGE_URI'] }}"
# GCS root for this experiment's exported data and training job directories.
job_experiment_root = f"{{ env['MLFLOW_GCS_ROOT_URI'] }}/experiments/{experiment_name}"

# NOTE(review): falls back to a hard-coded project/region when the Composer
# environment variables are not set - confirm these defaults are intended.
PROJECT_ID = os.getenv("GCP_PROJECT", default="edgeml-demo")
REGION = os.getenv("COMPOSER_LOCATION", default="us-central1")

# Postfixes for temporary BQ tables and output CSV files
# NOTE(review): not referenced anywhere else in this file.
TRAINING_POSTFIX = "_training"
EVAL_POSTFIX = "_eval"
VALIDATION_POSTFIX = "_validation"

# Dataset/table name prefix of the temporary per-split tables in PROJECT_ID.
BQ_DATASET = "chicago_taxi_trips"
BQ_TABLE = "taxi_trips"
# Split query over the public Chicago taxi table. Its two '{}' placeholders are
# filled with str.format():
#   1st - a BETWEEN range over MOD(ABS(FARM_FINGERPRINT(unique_key)), 100), i.e. a
#         hash-based split of trips by unique_key,
#   2nd - the row LIMIT applied after ORDER BY RAND() (a random sample of the split).
# The outer SELECT derives year/month/day/hour/day-of-week features from the
# Chicago-local trip_start_timestamp.
BQ_QUERY = """
with tmp_table as (
SELECT trip_seconds, trip_miles, fare, tolls, 
    company, pickup_latitude, pickup_longitude, dropoff_latitude, dropoff_longitude,
    DATETIME(trip_start_timestamp, 'America/Chicago') trip_start_timestamp,
    DATETIME(trip_end_timestamp, 'America/Chicago') trip_end_timestamp,
    CASE WHEN (pickup_community_area IN (56, 64, 76)) OR (dropoff_community_area IN (56, 64, 76)) THEN 1 else 0 END is_airport,
FROM `bigquery-public-data.chicago_taxi_trips.taxi_trips`
WHERE
    dropoff_latitude IS NOT NULL and
    dropoff_longitude IS NOT NULL and
    pickup_latitude IS NOT NULL and
    pickup_longitude IS NOT NULL and
    fare > 0 and 
    trip_miles > 0 and
    MOD(ABS(FARM_FINGERPRINT(unique_key)), 100) {}
ORDER BY RAND()
LIMIT {})
SELECT *,
    EXTRACT(YEAR FROM trip_start_timestamp) trip_start_year,
    EXTRACT(MONTH FROM trip_start_timestamp) trip_start_month,
    EXTRACT(DAY FROM trip_start_timestamp) trip_start_day,
    EXTRACT(HOUR FROM trip_start_timestamp) trip_start_hour,
    FORMAT_DATE('%a', DATE(trip_start_timestamp)) trip_start_day_of_week
FROM tmp_table
"""

# Full column list of the public table for TFDV statistics.
# NOTE(review): not referenced in this file while the TFDV step is disabled.
BQ_QUERY_FOR_TFDV = """
SELECT unique_key, taxi_id, trip_start_timestamp, trip_end_timestamp, trip_seconds, trip_miles, pickup_census_tract, 
    dropoff_census_tract, pickup_community_area, dropoff_community_area, fare, tips, tolls, extras, trip_total, 
    payment_type, company, pickup_latitude, pickup_longitude, pickup_location, dropoff_latitude, dropoff_longitude, dropoff_location
FROM `bigquery-public-data.chicago_taxi_trips.taxi_trips` 
"""

def generate_tfdv_statistics(gcs_file_name, **kwargs):
    """Placeholder for TFDV statistics of one exported CSV split.

    Only logs the file name and returns None; the TFDV calls are disabled
    because of pip module version issues in the Airflow environment.
    """
    logging.info("Processing %s", gcs_file_name)
    # Disabled TFDV calls, kept for reference:
    # train_stats = tfdv.generate_statistics_from_csv(gcs_file_name)
    # tfdv.WriteStatisticsToTFRecord(output_path = gcs_file_name + ".tfrecord")

def joiner_func(training_gcs_file_name, eval_gcs_file_name, **kwargs):
    """No-op join step: logs the training and eval CSV paths and returns None."""
    logging.info("Joining %s, eval GCS files %s", training_gcs_file_name, eval_gcs_file_name)

def model_trainer(training_gcs_file_name, eval_gcs_file_name, model_file, **kwargs):
    """Placeholder in-process trainer: logs the CSV paths and returns None.

    ``model_file`` is accepted but unused.
    NOTE(review): not wired into any active operator in this file.
    """
    logging.info("Training %s, eval GCS files %s", training_gcs_file_name, eval_gcs_file_name)

def fake_model_tracking(**kwargs):
    """Simulated training step: log a fake run with random metrics to MLflow.

    Reads 'job_name' from ``templates_dict`` and records a nested run in
    ``experiment_name`` tagged version='fake', with 0 estimators and train/eval
    RMSE metrics drawn uniformly from [1, 2). Returns None.
    NOTE(review): not wired into any active operator in this file.
    """
    job_name = kwargs.get('templates_dict').get('job_name')
    print(f"Fake model tracking: '{job_name}'")
    mlflow.set_experiment(experiment_name)
    with mlflow.start_run(nested=True):
        mlflow.log_param("number_of_estimators", 0)
        mlflow.set_tag("version", "fake")
        mlflow.set_tag("job_name", job_name)
        for metric_key in ("train_cross_valid_score_rmse_mean", "eval_cross_valid_score_rmse_mean"):
            mlflow.log_metric(metric_key, 1 + random.random())

def register_model(run_id, model_name, artifact_path="chicago_rnd_forest"):
    """Register the model logged by MLflow run ``run_id`` under ``model_name``.

    Args:
        run_id: MLflow run that logged the model.
        model_name: registered model name to create a new version of.
        artifact_path: run-relative artifact path of the logged model. Defaults to
            "chicago_rnd_forest", the path the trainer passes to
            mlflow.sklearn.log_model(). (The registered model name is a different
            string and is not an artifact path of the run.)
    Returns:
        The object returned by mlflow.register_model().
    """
    model_uri = f'runs:/{run_id}/{artifact_path}'
    registered_model = mlflow.register_model(model_uri, model_name)
    print(registered_model)
    return registered_model

def compare_to_registered_model(model_name, best_run, metric_to_compare):
    """Register ``best_run`` as a new version of ``model_name`` if it is better.

    "Better" means a smaller ``metric_to_compare`` (an error metric such as RMSE)
    than the run behind the newest registered version. If the model has no
    registered version yet, ``best_run`` is registered unconditionally.

    Args:
        model_name: registered model name in the MLflow model registry.
        best_run: row of mlflow.search_runs() output; needs 'run_id' and
            'metrics.<metric_to_compare>'.
        metric_to_compare: metric key logged by the training runs.
    Raises:
        KeyError: if the registered version's run did not log ``metric_to_compare``.
    """
    mlflow_client = mlflow.tracking.MlflowClient()
    registered_models = mlflow_client.search_registered_models(
        filter_string=f"name='{model_name}'", max_results=1, order_by=['timestamp DESC'])
    latest_versions = registered_models[0].latest_versions if len(registered_models) > 0 else []
    if not latest_versions:
        # Nothing registered yet: the best run of this DAG run wins by default.
        register_model(best_run.run_id, model_name)
        return None

    # latest_versions holds the latest version per stage; compare with the newest one.
    last_version = max(latest_versions, key=lambda version: int(version.version))
    run = mlflow_client.get_run(last_version.run_id)
    if not run:
        print('Registered version run missing!')
        return None

    last_registered_metric = run.data.metrics[metric_to_compare]
    best_run_metric = best_run['metrics.' + metric_to_compare]
    # Smaller value is better
    if last_registered_metric > best_run_metric:
        print(f'Register better version with metric: {best_run_metric}')
        register_model(best_run.run_id, model_name)
    else:
        print(f'Registered version still better. Metric: {last_registered_metric}')
    return None

def model_blessing(**kwargs):
    """Pick the best training run of this DAG run and pass it to the registry check.

    Candidates are runs of ``experiment_name`` whose 'job_name' tag matches
    '<templates_dict job_name>_%' (MLflow ILIKE). The run with the LOWEST
    'eval_cross_valid_score_rmse_mean' (an RMSE, so smaller is better) is passed
    to compare_to_registered_model().

    Raises:
        ValueError: if the experiment does not exist or no candidate run logged the metric.
    """
    metric = 'eval_cross_valid_score_rmse_mean'
    job_name = kwargs.get('templates_dict').get('job_name')
    print(f"Model blessing: '{job_name}'")

    # Select the best from current training jobs
    experiment = mlflow.get_experiment_by_name(experiment_name)
    if experiment is None:
        raise ValueError(f"MLflow experiment '{experiment_name}' not found")
    filter_string = f"tags.job_name ILIKE '{job_name}_%'"
    runs = mlflow.search_runs([experiment.experiment_id], filter_string=filter_string)

    column = 'metrics.' + metric
    if column not in runs.columns or runs[column].isna().all():
        raise ValueError(f"No runs with metric '{metric}' found for job '{job_name}'")
    # RMSE: the best run has the smallest value (idxmin skips NaN entries).
    best_run = runs.loc[runs[column].idxmin()]

    compare_to_registered_model(experiment_name, best_run, metric)

# Per-split export settings: 'dataset_range' is the BETWEEN range over the
# FARM_FINGERPRINT bucket (0-99) in BQ_QUERY, 'limit' the row LIMIT applied after
# random ordering. The DAG later adds operators and 'gcs_file_name' to each entry.
tasks = dict(
    training=dict(dataset_range="between 0 and 80", limit=4000),
    eval=dict(dataset_range="between 80 and 100", limit=1000),
)

with DAG("multi_model_trainer",
         description = "Train evaluate and validate multi models on taxi fare dataset. Select the best one and register it to Mlflow v0.86",
         schedule_interval = None,  # manual trigger only
         start_date = datetime(2021, 1, 1),
         max_active_runs = 3,
         catchup = False,
         default_args = { 'provide_context': True}
         ) as dag:

    # Data preparation: one delete -> split -> export -> statistics chain per split.
    for task_key, task in tasks.items():
        # NOTE: table names and CSV paths are fixed per split, so overlapping DAG
        # runs (max_active_runs > 1) can overwrite each other's intermediate data.
        table_name = f"{PROJECT_ID}.{BQ_DATASET}.{BQ_TABLE}_{task_key}"
        task["gcs_file_name"] = f"{job_experiment_root}/data/ds_{task_key}.csv"

        # Delete the previous temporary table of this split (no error if absent).
        task["delete_table"] = BigQueryTableDeleteOperator(
            task_id = "delete_table_" + task_key,
            deletion_dataset_table = table_name,
            ignore_if_missing = True)

        # Materialize this split's sample of the source table.
        task["split_table"] = BigQueryOperator(
            task_id = "split_table_" + task_key,
            use_legacy_sql=False,
            destination_dataset_table = table_name,
            sql = BQ_QUERY.format(task["dataset_range"], task["limit"]),
            location = REGION)

        # Export the split table to a delimited CSV file in GCS.
        task["extract_to_gcs"] = BigQueryToCloudStorageOperator(
            task_id = "extract_to_gcs_" + task_key,
            source_project_dataset_table = table_name,
            destination_cloud_storage_uris = [task["gcs_file_name"]],
            field_delimiter = csv_delimiter)

        # Generate statistics by TFDV
        task["tfdv_statistics"] = PythonOperator(
            task_id = "tfdv_statistics_for_" + task_key,
            python_callable = generate_tfdv_statistics,
            op_kwargs={'gcs_file_name': task["gcs_file_name"]})

        task["delete_table"] >> task["split_table"] >> task["extract_to_gcs"] >> task["tfdv_statistics"]

    # Join point: training starts only after both splits are exported.
    joiner_1 = PythonOperator(
        task_id = "joiner_1",
        python_callable = joiner_func,
        op_kwargs={ 'training_gcs_file_name': tasks["training"]["gcs_file_name"],
                    'eval_gcs_file_name': tasks["eval"]["gcs_file_name"]})
    [tasks["training"]["tfdv_statistics"], tasks["eval"]["tfdv_statistics"]] >> joiner_1

    # Evaluated whenever this file is parsed (scheduler and workers).
    submit_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    job_dir = f"{job_experiment_root}/dmt_{submit_time}"

    # Filled with str.format() below; the Airflow runtime template (ts_nodash)
    # arrives already embedded in the job_name value.
    training_command_tmpl="""gcloud ai-platform jobs submit training {job_name} \
        --region {region} \
        --scale-tier BASIC \
        --job-dir {job_dir} \
        --package-path /home/airflow/gcs/data/multi_model_trainer_dag/package/training/ \
        --module-name training.task \
        --master-image-uri {ml_image_uri} \
        --stream-logs \
        -- \
        --experiment_name {experiment_name} \
        --gcs_train_source {gcs_train_source} \
        --gcs_eval_source {gcs_eval_source} \
        --version_tag {version_tag} \
        --number_of_estimators {number_of_estimators} \
        --job_name {job_name}"""

    training_tasks = []
    for training_id in range(0, {{var['number_of_parallel_trainings']}}):
        # Simulated training without AI Platform (swap for the BashOperator below):
        # trainer = PythonOperator(
        #     task_id = f'trainer_{training_id}',
        #     python_callable = fake_model_tracking,
        #     templates_dict={'job_name': 'training_job_{{ "{{ ts_nodash }}" }}'+f'_{training_id}'})

        # Alternative: MLEngineStartTrainingJobOperator (imported above) with
        # package_uris pointing at the uploaded sdist in the Composer data folder.
        trainer = BashOperator(
            task_id=f'trainer_{training_id}',
            bash_command=training_command_tmpl.format(
                 region = REGION,
                 job_name = 'training_job_{{ "{{ ts_nodash }}" }}'+f'_{training_id}',
                 job_dir = job_dir+f'_{training_id}',
                 ml_image_uri = ML_IMAGE_URI,
                 gcs_train_source = tasks["training"]["gcs_file_name"],
                 gcs_eval_source = tasks["eval"]["gcs_file_name"],
                 experiment_name = experiment_name,
                 version_tag = f'trainer_{training_id}',
                 # The only difference in trainings:
                 number_of_estimators = random.randrange(60,200))
        )
        training_tasks.append(trainer)

    # Select the best model of this run. Named *_task so the module-level
    # model_blessing callable is not shadowed by the operator.
    model_blessing_task = PythonOperator(
        task_id = "model_blessing",
        python_callable = model_blessing,
        templates_dict={'job_name': 'training_job_{{ "{{ ts_nodash }}" }}'})

    # Fan out from the joiner to every trainer, fan back in to the blessing step.
    for trainer in training_tasks:
        joiner_1 >> trainer >> model_blessing_task


In [None]:
# Copy DAG file to Cloud Composer
!gcloud composer environments storage dags import \
  --environment {COMPOSER_NAME}  \
  --location {REGION} \
  --source multi_model_trainer_dag.py