In [None]:
import mlflow
import os
import importlib
from transformers import pipeline

# One Hugging Face pipeline per supported task. Only the task name is given,
# so transformers picks its default model for each task.
summarization_model = pipeline(task="summarization")
sentiment_model = pipeline(task="sentiment-analysis")
translation_model = pipeline(task='translation_en_to_fr')
image_classification = pipeline(task="image-classification")

In [None]:
import download_file

def predict_to_save(
    sentiment_model,
    summarization_model,
    translation_model,
    image_classification,
    input_type,
    data=None,
    min_length=0,
    max_length=150,
):
    """Route ``data`` to the pipeline selected by ``input_type`` and return its output.

    The pipelines are received as arguments rather than read from globals.
    In this notebook the function is saved together with them through
    ``cloudflow_model.save(predict_function=..., models=...)``.

    Args:
        sentiment_model: callable used when ``input_type == "sentiment"``.
        summarization_model: callable used when ``input_type == "summarization"``.
        translation_model: callable used when ``input_type == "translation"``.
        image_classification: callable used when ``input_type == "image"``.
        input_type: one of ``"sentiment"``, ``"translation"``, ``"image"`` or
            ``"summarization"``.
        data: the input to predict on.
            - For ``"image"``, it is handed to ``download_file.download_image``
              and the returned object is classified. It is presumably a URL or
              path; confirm against ``download_file``.
            - For ``"summarization"``, ``None`` means a text is fetched with
              ``download_file.download_story()``.
        min_length: minimum length forwarded to the summarization pipeline.
        max_length: maximum length forwarded to the summarization pipeline.

    Returns:
        - For sentiment, translation and image: the raw pipeline output.
        - For summarization: the first result dict, with an extra
          ``"input_text"`` key holding the text that was summarized.
        - For an unrecognized ``input_type``: the string
          ``"mauvais type selectionné"`` ("wrong type selected"). It is
          returned rather than raised, to keep the existing contract.
    """
    if input_type == "sentiment":
        return sentiment_model(data)
    if input_type == "translation":
        return translation_model(data)
    if input_type == "image":
        image = download_file.download_image(data)
        return image_classification(image)
    if input_type == "summarization":
        if data is None:
            data = download_file.download_story()
        # Generation bounds must be passed by keyword. Hugging Face pipelines
        # do not interpret extra positional arguments as generation parameters,
        # so passing them positionally silently dropped them.
        dict_result = summarization_model(
            data, min_length=min_length, max_length=max_length
        )[0]
        dict_result["input_text"] = data
        return dict_result

    return "mauvais type selectionné"


In [None]:
import mlflow
import cloudflow

# Local MLflow store and experiment used for this demo.
# NOTE(review): this path is specific to one machine; adjust before running elsewhere.
tracking_uri = "/Users/simonlemouellic/Documents/test_package_mlflow"
experiment_id = "cloudflow_demo"

# Presumably configures MLflow to use this store and experiment.
# TODO: confirm against cloudflow.prepare_env.
cloudflow.prepare_env(tracking_uri, experiment_id)

with mlflow.start_run(experiment_id=experiment_id) as run:
    print("RUN ID : ", run.info.run_id)

    # Log a constant test metric so the run records at least one value.
    mlflow.log_metric("test_metrics", 0.99)

    # Save predict_to_save together with the four pipelines it routes to.
    # The keys below match predict_to_save's parameter names.
    model = cloudflow.cloudflow_model("DEBUG")
    model.save(
        tracking_uri=tracking_uri,
        experiment_id=experiment_id,
        run_id=run.info.run_id,
        predict_function=predict_to_save,
        models=dict(
            summarization_model=summarization_model,
            image_classification=image_classification,
            translation_model=translation_model,
            sentiment_model=sentiment_model,
        ),
    )