<a href="https://colab.research.google.com/github/aboubacardiallo050/ODC/blob/main/Classification_dag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import logging
import os
import tempfile
from datetime import datetime

import pandas as pd
from airflow import DAG
from airflow.exceptions import AirflowException
from airflow.operators.python import PythonOperator
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Defaults applied to every task in the DAG.
# ``catchup`` is intentionally absent: it is a DAG-level argument and has no
# effect inside default_args (operators only pick up keys they accept), so it
# is passed to ``DAG(...)`` directly instead.
default_args = {
    'start_date': datetime(2023, 1, 1),
    # Task params read via context['params'] by the callables:
    # test_size -> fraction of rows held out; random_state -> split seed.
    'params': {
        'test_size': 0.2,
        'random_state': 42,
    },
}

def load_and_split(**context):
    """Load the Iris dataset, split it, and persist the splits as CSV files.

    ``test_size`` and ``random_state`` are read from ``context['params']``
    and forwarded to ``train_test_split``.

    Args:
        **context: Airflow task context; only ``context['params']`` is used.

    Returns:
        str: Path of a newly created directory containing ``X_train.csv``,
        ``X_test.csv``, ``y_train.csv`` and ``y_test.csv``. Airflow pushes this
        return value to XCom so the downstream task can locate the files.
    """
    params = context['params']
    iris = load_iris(as_frame=True)
    X_train, X_test, y_train, y_test = train_test_split(
        iris.data, iris.target,
        test_size=params['test_size'],
        random_state=params['random_state'],
    )

    # mkdtemp instead of TemporaryDirectory: the previous context manager
    # deleted the directory as soon as this function returned, so the
    # downstream task could never read the files. Nothing in this function
    # removes the directory afterwards.
    # NOTE(review): a local path only works when both tasks run on the same
    # host (e.g. LocalExecutor); distributed executors need shared storage.
    tmp_dir = tempfile.mkdtemp(prefix='iris_')
    X_train.to_csv(f'{tmp_dir}/X_train.csv', index=False)
    X_test.to_csv(f'{tmp_dir}/X_test.csv', index=False)
    pd.DataFrame(y_train).to_csv(f'{tmp_dir}/y_train.csv', index=False)
    pd.DataFrame(y_test).to_csv(f'{tmp_dir}/y_test.csv', index=False)
    return tmp_dir

def train_and_evaluate(tmp_dir, **context):
    """Train a random forest on the persisted Iris splits and record accuracy.

    Args:
        tmp_dir: Directory holding ``X_train.csv``, ``X_test.csv``,
            ``y_train.csv`` and ``y_test.csv``. It arrives as the rendered
            ``xcom_pull`` template string from ``op_kwargs``.
        **context: Airflow task context. ``context['ti']`` is used to push the
            accuracy under XCom key ``'accuracy'``. If
            ``context['params']['random_state']`` is present, it seeds the
            classifier.

    Raises:
        AirflowException: If ``tmp_dir`` is not an existing directory, or if
            loading, training, scoring or the XCom push fails.
    """
    # A missing upstream XCom renders as the literal string "None"; fail early
    # with a message that points at the real problem.
    if not os.path.isdir(tmp_dir):
        raise AirflowException(f"Training data directory not found: {tmp_dir!r}")

    try:
        X_train = pd.read_csv(f'{tmp_dir}/X_train.csv')
        X_test = pd.read_csv(f'{tmp_dir}/X_test.csv')
        # Targets were saved as single-column frames; flatten to 1-D arrays.
        y_train = pd.read_csv(f'{tmp_dir}/y_train.csv').values.ravel()
        y_test = pd.read_csv(f'{tmp_dir}/y_test.csv').values.ravel()

        # Reuse the split seed so repeated runs yield reproducible accuracy;
        # None (no seed) when the param is absent.
        random_state = context.get('params', {}).get('random_state')
        model = RandomForestClassifier(random_state=random_state)
        model.fit(X_train, y_train)
        predictions = model.predict(X_test)
        # Plain float keeps the XCom value trivially serializable.
        acc = float(accuracy_score(y_test, predictions))

        logging.info(f'Accuracy: {acc:.4f}')
        context['ti'].xcom_push(key='accuracy', value=acc)
    except Exception as e:
        # logging.exception preserves the traceback; chaining keeps the cause.
        logging.exception(f"Training failed: {str(e)}")
        raise AirflowException("Model training error") from e

# Manually triggered DAG: load/split the Iris data, then train and score a model.
# NOTE(review): Airflow 2.4+ prefers ``schedule=`` over ``schedule_interval=``;
# kept as-is for compatibility with older 2.x installs.
with DAG('iris_classification',
         schedule_interval=None,
         default_args=default_args,
         # catchup is a DAG argument; it has no effect inside default_args.
         catchup=False,
         description='Classification avec Iris et scikit-learn',
         tags=['ml'],
         doc_md="""## DAG de Classification Iris""") as dag:

    # Airflow 2 removed ``provide_context``: the context is supplied to
    # callables accepting **kwargs automatically, and BaseOperator treats the
    # old flag as an invalid argument.
    load_task = PythonOperator(
        task_id='load_and_split_data',
        python_callable=load_and_split,
    )

    train_task = PythonOperator(
        task_id='train_model',
        python_callable=train_and_evaluate,
        # Rendered at run time to the directory path returned by the upstream task.
        op_kwargs={'tmp_dir': "{{ ti.xcom_pull(task_ids='load_and_split_data') }}"},
    )

    load_task >> train_task