# In [1]:
import kfp
import kfp.dsl as dsl
import kfp.onprem as onprem

def echo_op(text):
    """Return a ContainerOp named 'echo' that prints ``text`` to stdout.

    The step runs ``sh -c 'echo "$0"' <text>``. The inline script receives
    ``text`` as ``$0``, and the quoted expansion ``"$0"`` prints it as a
    single word, without word splitting or glob expansion.
    In this file it serves as the exit-handler step of ``fmnist_pipeline``.
    """
    inline_script = 'echo "$0"'
    op = dsl.ContainerOp(
        name='echo',
        image='library/bash:4.4.23',
        command=['sh', '-c'],
        arguments=[inline_script, text],
    )
    return op

# Registers fmnist_pipeline as a Kubeflow pipeline named 'FMnistPipeline'.
@dsl.pipeline(
    name='FMnistPipeline',
    description='mnist '
)

def fmnist_pipeline(learning_rate, dropout_rate,model_path, model_version, model_version3, model_version8):
    # Pipeline: train a model, run the quantization-aware TFLite script on
    # the result, then recursively list the shared volume.
    #
    # Every parameter is forwarded verbatim as a CLI argument to a container:
    #   learning_rate, dropout_rate -- only to /app/Train.py.
    #   model_path, model_version   -- to Train.py and to
    #                                  /app/quantization_aware_tflite.py.
    #   model_version3, model_version8 -- only to the quantization script.
    # The volume is mounted at /result in every step, so model_path is
    # expected to point there (the run block at the bottom of the file uses
    # '/result'). How the scripts combine model_path with the model_version*
    # file names is not visible here -- presumably path concatenation;
    # TODO confirm against the container images.
    
    # Exit-handler step: KFP runs it after the steps inside the ExitHandler
    # below finish, regardless of whether they succeeded.
    exit_task = echo_op("Done!")
    
    with dsl.ExitHandler(exit_task): 
        
        # Volume component: create a fresh 1Gi ReadWriteOnce PVC for this run.
        # (Alternative kept for reference: reuse an existing claim with
        #  dsl.PipelineVolume(pvc="tfjob-data-volume", name="tfjob-data-volume").)
        fmnist_vop = dsl.VolumeOp(
            name="fmnist-volume1",
            resource_name="fmnist-pvc",
            modes=dsl.VOLUME_MODE_RWO,
            size="1Gi"
        )
        
        # Base component: train the model with the given hyperparameters,
        # with the new PVC mounted at /result.
        mnist = dsl.ContainerOp(
            name='Train',
            image='khw2126/mnist-simple:C628263',
            command=['python', '/app/Train.py'],
            arguments=[
                "--learning_rate", learning_rate,
                "--dropout_rate", dropout_rate,
                "--model_path", model_path,
                "--model_version", model_version
            ],
            pvolumes={"/result": fmnist_vop.volume}
        )
        
        # Quantization-aware TFLite component. Mounting mnist.pvolume (the
        # same PVC, as handed on by the training step) gives it the training
        # output and, in KFP v1, already makes it depend on the training step.
        # Its metrics file /mlpipeline-metrics.json is exported as the
        # 'mlpipeline-metrics' artifact.
        quant_aware_tflite = dsl.ContainerOp(
            name='test',
            image='khw2126/test:0.0.2',
            command=['python', '/app/quantization_aware_tflite.py'],
            arguments=[
                "--model_path", model_path,
                "--model_version", model_version,
                "--model_version3", model_version3,
                "--model_version8", model_version8,
            ],
            output_artifact_paths={'mlpipeline-metrics': '/mlpipeline-metrics.json'},
            pvolumes={"/result": mnist.pvolume}
        )
        
        # Result component: `ls -R /result` so the run log shows every file
        # left on the shared volume. It mounts the training step's volume, so
        # its ordering after the quantization step comes only from the
        # explicit .after() call below.
        # (Earlier revisions mounted prun.pvolume -- presumably from a pruning
        #  step that no longer exists in this pipeline.)
        result = dsl.ContainerOp(
            name='list_list',
            image='library/bash:4.4.23',
            command=['ls', '-R', '/result'],
            pvolumes={"/result": mnist.pvolume}
        )
        
        # Execution order: Train -> quantization ('test') -> listing
        # ('list_list'). The first call repeats the dependency already implied
        # by mnist.pvolume; the second one is what orders the listing step.
        quant_aware_tflite.after(mnist)
        result.after(quant_aware_tflite)

    
if __name__ == '__main__':
    # Concrete values for the pipeline parameters of this run.
    #   model_path     -- '/result', the mount point of the shared volume.
    #   model_version* -- file names (with a leading slash) that the training
    #                     and quantization scripts presumably join onto
    #                     model_path; confirm in the container images.
    # Kept under the name `arguments` because in a notebook this binding is a
    # global that later cells may refer to.
    arguments = dict(
        learning_rate='0.001397',
        dropout_rate='0.18',
        model_path='/result',
        model_version='/base_model.h5',
        model_version3='/quan_model.h5',
        model_version8='/qaware_dynamic_post_quan_tflite_file.tflite',
    )
    # Compile the pipeline function and submit a run with a default client.
    kfp.Client().create_run_from_pipeline_func(
        pipeline_func=fmnist_pipeline,
        arguments=arguments,
    )