In [None]:
!pip install mlflow

Collecting mlflow
  Downloading mlflow-3.1.4-py3-none-any.whl.metadata (29 kB)
Collecting mlflow-skinny==3.1.4 (from mlflow)
  Downloading mlflow_skinny-3.1.4-py3-none-any.whl.metadata (30 kB)
Collecting alembic!=1.10.0,<2 (from mlflow)
  Downloading alembic-1.16.4-py3-none-any.whl.metadata (7.3 kB)
Collecting docker<8,>=4.0.0 (from mlflow)
  Downloading docker-7.1.0-py3-none-any.whl.metadata (3.8 kB)
Collecting graphene<4 (from mlflow)
  Downloading graphene-3.4.3-py2.py3-none-any.whl.metadata (6.9 kB)
Collecting gunicorn<24 (from mlflow)
  Downloading gunicorn-23.0.0-py3-none-any.whl.metadata (4.4 kB)
Collecting databricks-sdk<1,>=0.20.0 (from mlflow-skinny==3.1.4->mlflow)
  Downloading databricks_sdk-0.59.0-py3-none-any.whl.metadata (39 kB)
Collecting opentelemetry-api<3,>=1.9.0 (from mlflow-skinny==3.1.4->mlflow)
  Downloading opentelemetry_api-1.35.0-py3-none-any.whl.metadata (1.5 kB)
Collecting opentelemetry-sdk<3,>=1.9.0 (from mlflow-skinny==3.1.4->mlflow)
  Downloading opentele

In [None]:
import os
import random
import tempfile
from typing import Optional

import mlflow
import mlflow.spark
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import rand, randn


In [None]:
def _build_synthetic_frame(spark, num_rows, target_seed=None):
    """Build a Spark DataFrame of ``num_rows`` synthetic rows.

    Columns: ``id`` (from ``spark.range``), ``feature_1`` (``rand`` with seed 42),
    ``feature_2`` (``randn`` with seed 123) and ``target`` (0/1, from an
    independent ``rand`` column; ``target_seed=None`` lets Spark choose a seed).
    """
    return (
        spark.range(num_rows)
        .withColumn("feature_1", rand(seed=42))
        .withColumn("feature_2", randn(seed=123))
        .withColumn("target", (rand(seed=target_seed) > 0.5).cast("int"))
    )


def log_synthetic_data(
    tracking_uri: str = "http://mlflow-server:5000",
    experiment_name: str = "synthetic_data_experiment",
    num_rows: int = 1000,
    seed: Optional[int] = None,
) -> str:
    """Log one MLflow run with synthetic params/metrics and a Spark-generated dataset.

    Args:
        tracking_uri: MLflow tracking server URI.
        experiment_name: Experiment to log into; ``mlflow.set_experiment`` creates
            it if it does not exist (the notebook's instructions ask for your
            roll number here).
        num_rows: Number of rows in the synthetic Spark DataFrame.
        seed: Seed for the synthetic hyperparameters/metrics and the ``target``
            column. ``None`` (default) keeps the original non-reproducible behaviour.

    Returns:
        The MLflow run ID of the logged run.
    """
    mlflow.set_tracking_uri(tracking_uri)
    # set_experiment creates the experiment when missing and makes it active,
    # so no create/except/lookup sequence is needed.
    mlflow.set_experiment(experiment_name)

    # Local RNG: seedable, and neither reads nor disturbs the global `random` state.
    rng = random.Random(seed)

    spark = SparkSession.builder.appName("MLflow Synthetic Data").getOrCreate()
    try:
        with mlflow.start_run() as run:
            params = {
                "learning_rate": rng.uniform(0.001, 0.1),
                "batch_size": rng.choice([32, 64, 128, 256]),
                "epochs": rng.randint(10, 100),
            }
            mlflow.log_params(params)

            metrics = {
                "accuracy": rng.uniform(0.7, 0.95),
                "loss": rng.uniform(0.1, 0.8),
                "f1_score": rng.uniform(0.65, 0.92),
            }
            mlflow.log_metrics(metrics)

            target_seed = None if seed is None else rng.randrange(2**31)
            df = _build_synthetic_frame(spark, num_rows, target_seed)

            # Count real feature columns; `id` is a row index and `target` the label,
            # so `len(df.columns) - 1` would over-count.
            feature_cols = [c for c in df.columns if c.startswith("feature_")]
            mlflow.log_params({
                "dataset_size": df.count(),
                "num_features": len(feature_cols),
            })

            # A fresh temp dir per run avoids collisions on a shared /tmp path and is
            # cleaned up afterwards.
            # NOTE(review): a driver-local path presumably only works when executors
            # share the driver's filesystem (e.g. local mode) — confirm for a cluster.
            with tempfile.TemporaryDirectory() as tmp_dir:
                data_path = os.path.join(tmp_dir, "synthetic_data")
                df.write.mode("overwrite").parquet(data_path)
                mlflow.log_artifacts(data_path, "synthetic_dataset")

            model_summary = "\n".join([
                "Model Type: Random Forest",
                f"Learning Rate: {params['learning_rate']}",
                f"Batch Size: {params['batch_size']}",
                f"Epochs: {params['epochs']}",
                f"Final Accuracy: {metrics['accuracy']:.4f}",
                f"Final Loss: {metrics['loss']:.4f}",
            ]) + "\n"
            # log_text writes the artifact directly; no stray file left in the cwd.
            mlflow.log_text(model_summary, "model_summary.txt")

            run_id = run.info.run_id
            print("Run completed successfully!")
            print(f"Experiment: {experiment_name}")
            print(f"Run ID: {run_id}")
            print(f"Accuracy: {metrics['accuracy']:.4f}")
            print(f"Loss: {metrics['loss']:.4f}")
    finally:
        # Always release the session, even if logging fails. Note getOrCreate may
        # have returned a pre-existing session, which this also stops.
        spark.stop()

    return run_id



In [None]:
# In a Jupyter kernel `__name__` is always "__main__", so an
# `if __name__ == "__main__":` guard is a no-op; call the entry point directly.
log_synthetic_data()
