# In [None]:
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score
import pickle
import os

# Path where the best trained model is pickled:
# <directory of this file>/../../models/trained_model.pkl, made absolute and
# normalized so it does not depend on the process's current working directory
# or on whether __file__ was given as a relative path.
# NOTE: __file__ is defined when Airflow imports this module as a DAG file; it
# is not defined when this code is executed directly inside a notebook cell.
MODEL_PATH = os.path.normpath(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..", "..", "models", "trained_model.pkl",
    )
)

def load_data():
    """Load the scikit-learn Iris dataset.

    Returns:
        tuple: ``(features, labels)`` — the ``data`` and ``target`` attributes
        of the object returned by ``load_iris()``.
    """
    bunch = load_iris()
    return bunch.data, bunch.target

def split_data(test_size=0.2, random_state=42, **kwargs):
    """Split the Iris dataset into train/test partitions and publish them via XCom.

    Pushes four XCom entries from this task under the keys ``X_train``,
    ``X_test``, ``y_train`` and ``y_test``.

    Args:
        test_size: Fraction of samples held out for evaluation (passed to
            ``train_test_split``).
        random_state: Seed for ``train_test_split`` so reruns of the DAG
            produce the same split; pass ``None`` for a fresh random split.
        **kwargs: Airflow task context; only ``kwargs['ti']`` is used.
    """
    X, y = load_data()
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    ti = kwargs['ti']
    # With XCom pickling disabled (Airflow 2's default), XCom values are
    # JSON-serialized and numpy arrays are rejected; nested Python lists
    # round-trip cleanly and scikit-learn accepts them as input.
    for key, value in (
        ('X_train', X_train),
        ('X_test', X_test),
        ('y_train', y_train),
        ('y_test', y_test),
    ):
        ti.xcom_push(key=key, value=value.tolist())

def train_models(**kwargs):
    """Train candidate classifiers on the XCom splits and persist the best one.

    Pulls ``X_train``, ``X_test``, ``y_train`` and ``y_test`` pushed by the
    ``split_data`` task, fits each candidate model, scores it by accuracy on
    the test split, and pickles the highest-scoring model to ``MODEL_PATH``.
    On ties, the first model in ``models`` wins (``max`` keeps the first).

    Args:
        **kwargs: Airflow task context; only ``kwargs['ti']`` is used.

    Raises:
        ValueError: If any of the expected XCom values is missing.
    """
    ti = kwargs['ti']
    # Name the upstream task explicitly instead of relying on the default
    # task_ids=None, whose filtering behavior differs across Airflow versions.
    splits = {
        key: ti.xcom_pull(task_ids='split_data', key=key)
        for key in ('X_train', 'X_test', 'y_train', 'y_test')
    }
    missing = [key for key, value in splits.items() if value is None]
    if missing:
        raise ValueError(f"Missing XCom values from 'split_data': {missing}")
    X_train, X_test = splits['X_train'], splits['X_test']
    y_train, y_test = splits['y_train'], splits['y_test']

    models = {
        # lbfgs may stop at the default max_iter=100 on unscaled features and
        # emit a ConvergenceWarning; allow more iterations to converge.
        "LogisticRegression": LogisticRegression(max_iter=1000),
        # Seeded so tie-breaking between equally good splits is reproducible.
        "DecisionTree": DecisionTreeClassifier(max_depth=3, random_state=42),
        "XGBoost": XGBClassifier(n_estimators=50)
    }

    # name -> (test accuracy, fitted model)
    scores = {}
    for name, model in models.items():
        model.fit(X_train, y_train)
        acc = accuracy_score(y_test, model.predict(X_test))
        scores[name] = (acc, model)

    best_model_name = max(scores, key=lambda k: scores[k][0])
    best_model = scores[best_model_name][1]

    # The target directory is not guaranteed to exist on the worker.
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    with open(MODEL_PATH, "wb") as f:
        pickle.dump(best_model, f)

    print(f"✅ Mejor modelo guardado: {best_model_name}")

# Arguments applied to every operator in the DAG. Only operator arguments take
# effect here; DAG-level settings such as ``catchup`` must be passed to DAG().
default_args = {
    'start_date': datetime(2023, 1, 1),
}

with DAG(
    dag_id='iris_pipeline_dag',
    # No schedule: the DAG runs only when triggered.
    # NOTE(review): Airflow 2.4+ prefers `schedule=`; `schedule_interval` is
    # kept here for compatibility with older 2.x releases — confirm target.
    schedule_interval=None,
    default_args=default_args,
    # catchup is a DAG argument; it has no effect when placed in default_args.
    catchup=False,
    description='Entrenamiento y versionado del modelo Iris',
    tags=['mlops', 'iris', 'didactico']
) as dag:

    split = PythonOperator(
        task_id='split_data',
        python_callable=split_data
    )

    train = PythonOperator(
        task_id='train_models',
        python_callable=train_models
    )

    # train_models reads the XComs that split_data pushes, so it runs after it.
    split >> train
