# In [None]:
# mlflow_remote_run.py
import os
import sys
import time
from pathlib import Path

import requests

# Set this to True to attempt using the remote MLflow server first.
USE_REMOTE = True
# NOTE(review): plain-HTTP endpoint on a public IP — anything sent to it
# (params, metrics, credentials if auth is added) travels unencrypted.
REMOTE_MLFLOW_URI = "http://104.197.55.178:5000"   # <- change if needed
# Fallback local file store, resolved against the working directory at import time.
LOCAL_MLRUNS_DIR = os.path.join(os.getcwd(), "mlruns")

# Path to your GCP service account key (if needed for artifacts / GCS).
# setdefault keeps a credential path the environment already provides (e.g. in CI
# or on a GCP VM) instead of silently overwriting it with this local default.
# The default is a relative path, so it presumably resolves against the working
# directory of whatever reads it — keep the key file out of version control.
os.environ.setdefault(
    "GOOGLE_APPLICATION_CREDENTIALS",
    r"my-test-project-466009-b10d519c6fd7.json",
)

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
import mlflow
import mlflow.keras
from mlflow.tracking import MlflowClient

# Training/config
EPOCHS = 5
BATCH_SIZE = 32
INPUT_SHAPE = (64, 64, 3)  # only the first two entries are used as the image target_size
# NOTE(review): the name mentions hyperparameter tuning, but this script trains one
# fixed configuration; left unchanged so runs keep landing in the existing experiment.
EXPERIMENT_NAME = "cat-vs-dog-hyperparam-tuning"
RUN_NAME = "mobilenetv2-cat-vs-dog"

# Dataset paths. The defaults are specific to one machine; set CATDOG_TRAIN_DIR,
# CATDOG_TEST_DIR and/or CATDOG_SAMPLE_IMAGE to point at your own copy without
# editing this file. Unset or empty variables fall back to the defaults below.
TRAIN_DIR = os.environ.get("CATDOG_TRAIN_DIR") or r"C:\Users\pares\OneDrive\Documents\vs code\are-u-cat\dataset\training_set"
TEST_DIR = os.environ.get("CATDOG_TEST_DIR") or r"C:\Users\pares\OneDrive\Documents\vs code\are-u-cat\dataset\test_set"
SAMPLE_TEST_IMAGE = os.environ.get("CATDOG_SAMPLE_IMAGE") or r"C:\Users\pares\OneDrive\Documents\vs code\are-u-cat\dataset\test_set\dogs\dog.4025.jpg"

# Helper: probe whether the remote MLflow server answers HTTP requests.
def is_server_reachable(url, timeout=4):
    """Return True when a GET request to ``url`` answers with status 200, 401 or 403.

    401/403 count as reachable: the UI may sit behind authentication, and an
    auth challenge still proves a server is listening. Any exception raised
    while sending the request (refused connection, timeout, malformed URL, ...)
    is reported as unreachable so the caller can fall back to a local store.

    Args:
        url: URL to probe.
        timeout: Seconds passed through to ``requests.get``.
    """
    up_statuses = (200, 401, 403)
    try:
        response = requests.get(url, timeout=timeout)
    except Exception:
        return False
    else:
        return response.status_code in up_statuses

# Choose tracking URI (remote preferred, local file store as fallback)
if USE_REMOTE and is_server_reachable(REMOTE_MLFLOW_URI):
    mlflow_tracking_uri = REMOTE_MLFLOW_URI
    print(f"Using remote MLflow tracking server: {mlflow_tracking_uri}")
else:
    # Fall back to local file-based mlruns if remote is disabled or not reachable
    os.makedirs(LOCAL_MLRUNS_DIR, exist_ok=True)
    # Build the file URI with pathlib rather than by hand. The old
    # f"file:///{path}" form produced "file:////home/..." (four slashes) for
    # POSIX absolute paths and left characters such as spaces unescaped.
    # as_uri() emits forward slashes and percent-encodes as needed on every OS.
    # absolute() keeps this working if LOCAL_MLRUNS_DIR is ever made relative,
    # because as_uri() rejects relative paths.
    mlflow_tracking_uri = Path(LOCAL_MLRUNS_DIR).absolute().as_uri()
    print(f"Remote MLflow not reachable — falling back to local tracking URI: {mlflow_tracking_uri}")

# IMPORTANT: set tracking URI BEFORE autolog and BEFORE starting any runs
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_experiment(EXPERIMENT_NAME)

# Enable autolog AFTER setting the tracking URI so autolog writes to the chosen store
mlflow.keras.autolog()

# Start run: autolog and the explicit mlflow.log_*/set_tag calls below attach to it.
with mlflow.start_run(run_name=RUN_NAME):
    # Build model: frozen ImageNet MobileNetV2 backbone + small binary head.
    # NOTE(review): 64x64 inputs are far smaller than the resolutions these ImageNet
    # weights are normally used at; Keras may warn — confirm accuracy is acceptable.
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=INPUT_SHAPE,
        include_top=False,
        weights='imagenet'
    )
    base_model.trainable = False  # only the new head below is trained

    model = models.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(1, activation='sigmoid')  # probability of class index 1
    ])

    model.compile(
        optimizer=optimizers.Adam(),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    # The pretrained backbone expects MobileNetV2's own input scaling
    # (preprocess_input maps pixels to [-1, 1]), not a 1/255 rescale to [0, 1].
    # Feeding differently-scaled pixels to frozen ImageNet weights degrades the
    # features the head is trained on, so train, validation and the sample
    # prediction all use the same function.
    preprocess = tf.keras.applications.mobilenet_v2.preprocess_input

    # Data generators (no augmentation)
    train_datagen = ImageDataGenerator(preprocessing_function=preprocess)
    training_set = train_datagen.flow_from_directory(
        TRAIN_DIR,
        target_size=INPUT_SHAPE[:2],
        batch_size=BATCH_SIZE,
        class_mode='binary'
    )

    test_datagen = ImageDataGenerator(preprocessing_function=preprocess)
    test_set = test_datagen.flow_from_directory(
        TEST_DIR,
        target_size=INPUT_SHAPE[:2],
        batch_size=BATCH_SIZE,
        class_mode='binary'
    )

    # Train (the test set doubles as validation data)
    history = model.fit(training_set, validation_data=test_set, epochs=EPOCHS)

    # Avoid overwriting autolog params: save additional info as tags or manual_ keys
    mlflow.set_tag("note", "Using autolog; manual tags used to avoid param collisions")
    mlflow.set_tag("manual_input_shape", f"{INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}x{INPUT_SHAPE[2]}")
    mlflow.set_tag("manual_optimizer", "Adam")
    mlflow.set_tag("manual_epochs", str(EPOCHS))
    mlflow.set_tag("manual_preprocessing", "mobilenet_v2.preprocess_input")

    # If you want params, give them unique keys (prefix manual_)
    mlflow.log_param("manual_batch_size", BATCH_SIZE)

    # Final-epoch metrics (autolog usually logs per-epoch metrics already).
    # .get() plus a truthiness check also skips missing or empty histories.
    train_acc = history.history.get('accuracy')
    val_acc = history.history.get('val_accuracy')
    if train_acc and val_acc:
        mlflow.log_metric("final_train_accuracy", float(train_acc[-1]))
        mlflow.log_metric("final_val_accuracy", float(val_acc[-1]))

    # Explicit model log; a failure is reported but does not abort the run.
    # NOTE(review): autolog may already have logged a model under the same
    # artifact path — confirm a second copy is wanted.
    try:
        mlflow.keras.log_model(model, artifact_path="model")
    except Exception as e:
        print("Warning: mlflow.keras.log_model raised an exception:", e)

    # Guarded like the metric logging above; indexing history.history directly
    # raised KeyError whenever validation accuracy was not recorded.
    if val_acc:
        print(f"✅ Final Validation Accuracy: {val_acc[-1]:.4f}")
    print("🎯 MLflow run completed successfully!")

    # Quick sample prediction (best effort: any failure is only printed)
    try:
        test_image = load_img(SAMPLE_TEST_IMAGE, target_size=INPUT_SHAPE[:2])
        test_image = preprocess(img_to_array(test_image))
        test_image = np.expand_dims(test_image, axis=0)  # add a batch axis of size 1
        result = model.predict(test_image)
        # Map the sigmoid output back to a class folder name using the generator's
        # own label mapping instead of hard-coding which index means 'dog'.
        index_to_class = {idx: name for name, idx in training_set.class_indices.items()}
        predicted_index = 1 if result[0][0] >= 0.5 else 0
        prediction = index_to_class.get(predicted_index, str(predicted_index))
        print("Prediction (sample):", prediction)
    except Exception as e:
        print("Could not run sample prediction:", e)

# Sanity check: query the tracking backend directly and list the newest runs,
# confirming that the run above was recorded. Failures here are only printed.
try:
    client = MlflowClient(tracking_uri=mlflow.get_tracking_uri())
    experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
    if experiment is not None:
        print(f"Experiment '{EXPERIMENT_NAME}' found. id={experiment.experiment_id}")
        recent_runs = client.search_runs(
            [experiment.experiment_id],
            order_by=["attributes.start_time DESC"],
            max_results=5,
        )
        print(f"Latest runs (up to 5) for experiment '{EXPERIMENT_NAME}':")
        for run in recent_runs:
            # start_time is divided by 1000 before formatting (ms -> s, presumably)
            started = time.ctime(run.info.start_time / 1000)
            print(f"- run_id: {run.info.run_id} | status: {run.info.status} | start_time: {started}")
            print("  params:", run.data.params)
            # Show at most the first five metrics to keep the output short
            print("  metrics (sample):", dict(list(run.data.metrics.items())[:5]))
    else:
        print("Experiment not found on tracking server (unexpected).")
except Exception as e:
    print("Could not inspect remote tracking server:", e)

print("Active MLflow tracking URI:", mlflow.get_tracking_uri())
