# In [5]:
import kfp
import kfp.dsl as dsl
import kfp.onprem as onprem

def echo_op(text):
    """Return a ContainerOp named 'echo' that prints ``text``.

    ``text`` is handed to ``sh -c 'echo "$0"'`` as ``$0``. It is therefore
    expanded inside double quotes rather than spliced into the command string,
    so its contents are not parsed as shell syntax.

    Args:
        text: Message to print (a plain string or a pipeline parameter).

    Returns:
        The constructed ``dsl.ContainerOp``.
    """
    shell_cmd = ['sh', '-c']
    shell_args = ['echo "$0"', text]
    return dsl.ContainerOp(
        name='echo',
        image='library/bash:4.4.23',
        command=shell_cmd,
        arguments=shell_args,
    )

def _model_variant_op(name, image, script, model_path, model_version,
                      output_flag, output_version, volume):
    """Build one post-training step that works on the base model in ``volume``.

    The pruning and quantization steps all share this shape:
    - run ``script`` with ``python`` inside ``image``;
    - pass the base-model location (``--model_path`` / ``--model_version``)
      plus one step-specific ``output_flag``;
    - export ``/mlpipeline-metrics.json`` as the ``mlpipeline-metrics``
      artifact;
    - mount ``volume`` at ``/result``.

    Building the argument list in one place keeps each step's output flag
    paired with the value passed for it.

    Args:
        name: Step (ContainerOp) name.
        image: Container image containing ``script``.
        script: Path of the Python entry point inside the image.
        model_path: Value forwarded as ``--model_path``.
        model_version: Value forwarded as ``--model_version``.
        output_flag: Step-specific flag, e.g. ``"--model_version2"``.
        output_version: Value forwarded right after ``output_flag``.
        volume: Volume mounted at ``/result``.

    Returns:
        The new ``dsl.ContainerOp``.
    """
    return dsl.ContainerOp(
        name=name,
        image=image,
        command=['python', script],
        arguments=[
            "--model_path", model_path,
            "--model_version", model_version,
            output_flag, output_version,
        ],
        output_artifact_paths={'mlpipeline-metrics': '/mlpipeline-metrics.json'},
        pvolumes={"/result": volume},
    )


@dsl.pipeline(
    name='FMnistPipeline',
    description='mnist '
)
def fmnist_pipeline(learning_rate, dropout_rate, model_path, model_version,
                    model_version2, model_version3, model_version4,
                    model_version5):
    """Train a base model, derive pruned/quantized variants, then list results.

    Step graph (all inside an ExitHandler whose exit step echoes "Done!"):
    1. ``fmnist-volume1``: create a 1Gi ReadWriteOnce PVC.
    2. ``Train``: run ``/app/Train.py`` with the volume mounted at ``/result``.
    3. ``prun``, ``quan_aware``, ``dynamic_quan`` and ``integer_quant``:
       each runs after ``Train`` on the same volume.
    4. ``list_list``: runs ``ls -R /result`` after all four variant steps.

    Args:
        learning_rate: Forwarded to the training step as ``--learning_rate``.
        dropout_rate: Forwarded to the training step as ``--dropout_rate``.
        model_path: Forwarded to every model step as ``--model_path``.
        model_version: Base-model file name, forwarded as ``--model_version``.
        model_version2: Output file name for the ``prun`` step.
        model_version3: Output file name for the ``quan_aware`` step.
        model_version4: Output file name for the ``dynamic_quan`` step.
        model_version5: Output file name for the ``integer_quant`` step.
    """
    # Created outside the handler: dsl.ExitHandler takes the exit op as an
    # argument and runs it when the wrapped steps finish.
    exit_task = echo_op("Done!")

    with dsl.ExitHandler(exit_task):
        # Fresh PVC shared by every step below at /result.
        fmnist_vop = dsl.VolumeOp(
            name="fmnist-volume1",
            resource_name="fmnist-pvc",
            modes=dsl.VOLUME_MODE_RWO,
            size="1Gi"
        )

        # Base model training.
        mnist = dsl.ContainerOp(
            name='Train',
            image='khw2126/mnist-simple:5DD5D44A',
            command=['python', '/app/Train.py'],
            arguments=[
                "--learning_rate", learning_rate,
                "--dropout_rate", dropout_rate,
                "--model_path", model_path,
                "--model_version", model_version
            ],
            pvolumes={"/result": fmnist_vop.volume}
        )

        # NOTE(review): the four steps below mount the same ReadWriteOnce PVC
        # and can run in parallel; if the cluster schedules them on different
        # nodes the mount may block. Confirm the storage class or use RWX.
        prun = _model_variant_op(
            'prun', 'khw2126/prun:0.0.11', '/app/pruning_code.py',
            model_path, model_version,
            "--model_version2", model_version2,
            mnist.pvolume,
        )
        quan_aware = _model_variant_op(
            'quan_aware', 'khw2126/quan:0.0.9',
            '/app/quantization_aware_code.py',
            model_path, model_version,
            "--model_version3", model_version3,
            mnist.pvolume,
        )
        dynamic_quan = _model_variant_op(
            'dynamic_quan', 'khw2126/dynamic_quan:0.0.4',
            '/app/dynamic_quant.py',
            model_path, model_version,
            "--model_version4", model_version4,
            mnist.pvolume,
        )
        # Fixed: this previously passed model_version2 (the pruned-model file
        # name) for --model_version5, leaving model_version5 unused.
        integer_quant = _model_variant_op(
            'integer_quant', 'khw2126/integer_quant:0.0.4',
            '/app/integer_quant.py',
            model_path, model_version,
            "--model_version5", model_version5,
            mnist.pvolume,
        )

        # Final listing of everything written to the shared volume.
        result = dsl.ContainerOp(
            name='list_list',
            image='library/bash:4.4.23',
            command=['ls', '-R', '/result'],
            pvolumes={"/result": mnist.pvolume}
        )

        # Explicit ordering: every variant step waits for training...
        for variant in (prun, quan_aware, dynamic_quan, integer_quant):
            variant.after(mnist)
        # ...and the listing waits for every variant, because it mounts
        # mnist.pvolume rather than any variant's output.
        for variant in (prun, quan_aware, dynamic_quan, integer_quant):
            result.after(variant)
        

    
if __name__ == '__main__':
    # Run-time values for fmnist_pipeline's parameters. model_path '/result'
    # matches the volume mount point used by every step. Each model_version*
    # entry is a file name handed to one step; the leading '/' suggests the
    # containers append it to model_path (TODO confirm in the step images).
    arguments = dict(
        learning_rate='0.001397',
        dropout_rate='0.18',
        model_path='/result',
        model_version='/base_model.h5',
        model_version2='/prun_model.h5',
        model_version3='/quan_aware_model.h5',
        model_version4='/dynamic_post_quan_tflite_file.tflite',
        model_version5='/integer_post_quan_tflite_file.tflite',
    )
    # No host is given, so kfp.Client() uses its default endpoint/config.
    kfp.Client().create_run_from_pipeline_func(
        pipeline_func=fmnist_pipeline,
        arguments=arguments,
    )