# 05: Manual Artifact Logging (Plan B)

**Purpose:** This notebook is a recovery script. Use it **only if** the main `04_model` training notebook successfully trained a model but **failed to log the model artifact** to MLflow (e.g., the "Artifacts" section in the MLflow run is empty).

This script will:
1.  Re-establish the Spark session's authentication to Azure storage.
2.  Re-create the necessary objects (`processor`, `test_set`) that were lost when the cluster terminated.
3.  Load the best model from the saved **checkpoint files** on disk.
4.  Resume the *existing* MLflow run (using its Run ID).
5.  Manually log the model artifact, processor, and signature to that existing run.

### 1. Import Essential Libraries

Install the required system and Python packages, restart the Python kernel, then import only the libraries needed to load data, re-create objects, and log the model.

In [0]:
%sh
# Update package lists and install the required AIO library
sudo apt-get update && sudo apt-get install -y libaio-dev

In [0]:
%pip install azure-identity azure-keyvault-secrets

# Pin transformers to a version compatible with our MLflow version (2.21.3)
%pip install transformers==4.49.0

In [0]:
# After installing new libraries, we must restart the Python kernel 
# for the changes to take effect in the notebook's environment.
dbutils.library.restartPython()

In [0]:
import mlflow, transformers
# Verify the installed MLflow and transformers versions.
print(mlflow.__version__)
print(transformers.__version__)
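
The printed versions still have to be compared by eye against the pins above. As a minimal sketch, the comparison could be automated with the standard library. `EXPECTED_VERSIONS` and `find_version_mismatches` are hypothetical names introduced here for illustration; the version numbers are the ones mentioned in this notebook.

```python
from importlib import metadata

# Hypothetical pin table: the versions this notebook mentions.
EXPECTED_VERSIONS = {"mlflow": "2.21.3", "transformers": "4.49.0"}


def find_version_mismatches(expected, get_version=metadata.version):
    """Return {package: (expected, installed)} for every package that differs."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = get_version(package)
        except metadata.PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches
```

Calling `find_version_mismatches(EXPECTED_VERSIONS)` after the kernel restart returns an empty dict when the environment matches the pins; a non-empty result is a signal to fix the environment before attempting to log the model.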

In [0]:
import io
import time
import os
import numpy as np
import torch
import mlflow
from mlflow.models import infer_signature
from PIL import Image
from datasets import Dataset, ClassLabel
from transformers import AutoImageProcessor, AutoModelForImageClassification
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient

### 2. Configure Spark Authentication

**Reason:** We must re-configure the Spark session's access to Azure storage. This is required because `Dataset.from_spark(train_df)` (in the next step) will launch a new Spark job to read the Delta table, and this job needs permission to access the underlying storage.

In [0]:
# --- 1. Fetch secrets from Azure Key Vault ---
key_vault_url = "https://anomalykeyvault10222025.vault.azure.net/"
key_name_tenant_id = "tenant-id"
key_name_client_id = "client-id"
key_name_client_secret = "client-secret"
storage_account_name = "stlabelingdevwestus001"

credential = DefaultAzureCredential()
client = SecretClient(vault_url=key_vault_url, credential=credential)

tenant_id = client.get_secret(key_name_tenant_id).value
client_id = client.get_secret(key_name_client_id).value
client_secret = client.get_secret(key_name_client_secret).value

# --- 2. Set Service Principal credentials on the global Spark session ---
spark.conf.set(f"fs.azure.account.auth.type.{storage_account_name}.dfs.core.windows.net", "OAuth")
spark.conf.set(f"fs.azure.account.oauth.provider.type.{storage_account_name}.dfs.core.windows.net", "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider")
spark.conf.set(f"fs.azure.account.oauth2.client.id.{storage_account_name}.dfs.core.windows.net", client_id)
spark.conf.set(f"fs.azure.account.oauth2.client.secret.{storage_account_name}.dfs.core.windows.net", client_secret)
spark.conf.set(f"fs.azure.account.oauth2.client.endpoint.{storage_account_name}.dfs.core.windows.net", f"https://login.microsoftonline.com/{tenant_id}/oauth2/token")

print(f"Spark session configured for Service Principal access to: {storage_account_name}")
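
The five `spark.conf.set` calls share the same account suffix, so one option is to build the settings as a dictionary first and apply them in a loop. The sketch below assumes the same keys and endpoint format used in the cell above; `build_abfs_oauth_conf` and `apply_conf` are hypothetical helper names, not part of any library.

```python
def build_abfs_oauth_conf(storage_account, client_id, client_secret, tenant_id):
    """Return the Spark settings for Service Principal (OAuth) access to one storage account."""
    suffix = f"{storage_account}.dfs.core.windows.net"
    return {
        f"fs.azure.account.auth.type.{suffix}": "OAuth",
        f"fs.azure.account.oauth.provider.type.{suffix}":
            "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider",
        f"fs.azure.account.oauth2.client.id.{suffix}": client_id,
        f"fs.azure.account.oauth2.client.secret.{suffix}": client_secret,
        f"fs.azure.account.oauth2.client.endpoint.{suffix}":
            f"https://login.microsoftonline.com/{tenant_id}/oauth2/token",
    }


def apply_conf(spark_conf, settings):
    """Apply each key/value pair through any object that exposes set(key, value)."""
    for key, value in settings.items():
        spark_conf.set(key, value)
```

In the notebook this would be called as `apply_conf(spark.conf, build_abfs_oauth_conf(...))` with the account name and the three secrets fetched from Key Vault. Keeping the dictionary separate also makes it easy to inspect the keys without printing the secret values.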

### 3. Re-create `processor` and `test_set`

**Reason:** The original `processor` and `test_set` objects existed only in the memory of the terminated cluster. We must re-create them.
* `processor` is required by `mlflow.transformers.log_model` to save the complete pipeline.
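* `test_set` is required to get a sample input for generating the model's signature. The split below is not seeded, so the re-created `test_set` may differ from the original one; this is acceptable here because only a single sample is needed.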

In [0]:
# --- 1. Define dataset name and load processor ---
dataset_name = "train_dataset"
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224', use_fast = True, do_resize = True, size = 224)

# --- 2. Define the preprocessing function (must be identical to training) ---
def preprocess(example):
    image = Image.open(io.BytesIO(example['image'])).convert('RGB')
    processed_img = processor(images = image, return_tensors = 'pt')
    example['pixel_values'] = processed_img['pixel_values'].squeeze()
    return example

# --- 3. Re-load and re-process the data ---
print("Loading Spark DataFrame...")
train_df = spark.read.table(dataset_name)

print("Converting to Hugging Face Dataset...")
train_dataset = Dataset.from_spark(train_df)

print("Casting label column...")
class_label_feature = ClassLabel(names=['surfing', 'noise'])
train_dataset = train_dataset.cast_column('label', class_label_feature)

print("Mapping preprocessing function...")
train_dataset = train_dataset.map(preprocess)

print("Setting torch format...")
train_dataset.set_format(type='torch', columns=['pixel_values', 'label'])

print("Splitting dataset to get 'test_set'...")
train_dataset = train_dataset.train_test_split(test_size = 0.2, stratify_by_column='label')
test_set = train_dataset['test']

print("Processor and test_set are re-created.")

### 4. Load from Checkpoint and Log to Existing MLflow Run

This is the main recovery step. We will:
1.  Define the `run_id` of the failed run (copy this from the MLflow UI).
2.  Set `best_checkpoint_name` to the checkpoint folder (e.g., `checkpoint-26`) saved by the `Trainer`, and build the full path to it.
3.  Load this checkpoint from disk into a model object.
4.  Use `mlflow.start_run(run_id=...)` to "re-open" the existing MLflow run.
5.  Manually generate the model signature.
6.  Call `mlflow.transformers.log_model` to finally save the artifact.
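
If the best checkpoint name is not known offhand, it can often be recovered from disk. The Hugging Face `Trainer` writes a `trainer_state.json` file into each checkpoint folder, and that file carries a `best_model_checkpoint` field when best-model tracking was enabled during training. The sketch below assumes that was the case in the `04_model` notebook; `find_best_checkpoint` is a hypothetical helper name.

```python
import json
import os
import re


def find_best_checkpoint(output_dir):
    """Return the folder name of the best checkpoint recorded by the Trainer, or None."""
    # Collect checkpoint-<step> folders together with their step numbers.
    steps = []
    for name in os.listdir(output_dir):
        match = re.fullmatch(r"checkpoint-(\d+)", name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            steps.append((int(match.group(1)), name))

    # Walk from the latest checkpoint backwards until a state file names a best checkpoint.
    for _, name in sorted(steps, reverse=True):
        state_path = os.path.join(output_dir, name, "trainer_state.json")
        if not os.path.isfile(state_path):
            continue
        with open(state_path) as f:
            best = json.load(f).get("best_model_checkpoint")
        if best:
            # The field stores a path; keep only the folder name.
            return os.path.basename(os.path.normpath(best))
    return None
```

Calling `find_best_checkpoint(output_path)` returns a name such as `checkpoint-26`, or `None` if no state file records a best checkpoint. In the `None` case the name still has to be taken from the training logs.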

In [0]:
# --- 1. User Configuration ---

# PASTE THE RUN ID from the MLflow UI
run_id_to_resume = "5b44ff1a1abb4885a01dacec8f233d0e" 

# PASTE the name of the best checkpoint folder saved during training
best_checkpoint_name = "checkpoint-26" 

# Define paths (ensure these match your training script)
model_name = 'vit-base-patch16-224-anomaly'
mount_path = "/mnt/datamount"
output_path = f"/dbfs{mount_path}/{model_name}"
best_model_path = f"{output_path}/{best_checkpoint_name}"

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device set to use {device}")
# -------------------------------

print(f"Loading best model from checkpoint: {best_model_path}")
try:
    # 2. Load the trained model from the checkpoint file
    model_to_log = AutoModelForImageClassification.from_pretrained(best_model_path)
    model_to_log.to(device) # Ensure model is on the correct device
    model_to_log.eval()     # Set model to evaluation mode
    print("Model loaded successfully.")

    # 3. Resume the previously failed MLflow Run
    with mlflow.start_run(run_id=run_id_to_resume):
        
        print(f"Resumed MLflow run: {run_id_to_resume}")
        print("Logging model artifacts manually...")

        # 4. Manually create the signature
        # Get a sample input tensor from the test_set
        sample_input = next(iter(test_set))
        input_tensor = sample_input['pixel_values'].unsqueeze(0).to(device)
        
        # Run the model to get a sample output
        with torch.no_grad():
            output_tensor = model_to_log(input_tensor)

        # Convert input and output to numpy for the signature
        input_array = input_tensor.cpu().numpy()
        output_array = output_tensor.logits.cpu().numpy()

        # Create the signature object
        signature = infer_signature(input_array, output_array)
        
        # 5. Log the model artifacts
        mlflow.transformers.log_model(
            transformers_model={
                "model": model_to_log,
                "image_processor": processor # Use the 'image_processor' key
            }, 
            artifact_path="model",
            signature=signature,          # Use the manual signature
            task="image-classification"   # Explicitly state the task
        )
        print("Model artifacts logged successfully to existing run.")

except Exception as e:
    print(f"Error: Failed to load model or log artifacts: {e}")
    raise  # Re-raise so the cell fails visibly instead of appearing to succeed

print("Manual logging process finished.")
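
After the cell completes, it is worth confirming that the artifact actually landed in the run rather than relying on the printed message. A minimal sketch: list the files under the `model` artifact path and look for the `MLmodel` descriptor that MLflow writes for every logged model. The helpers take any client object with a `list_artifacts(run_id, path)` method, which `mlflow.tracking.MlflowClient` provides; `list_logged_files` and `model_was_logged` are hypothetical names, and the layout assumed here is the MLflow 2.x one used in this notebook.

```python
def list_logged_files(client, run_id, artifact_path="model"):
    """Recursively collect the file paths stored under artifact_path for a run."""
    files = []
    for entry in client.list_artifacts(run_id, artifact_path):
        if entry.is_dir:
            # Descend into sub-folders such as the model weights directory.
            files.extend(list_logged_files(client, run_id, entry.path))
        else:
            files.append(entry.path)
    return files


def model_was_logged(client, run_id, artifact_path="model"):
    """Return True if the run contains an MLmodel file under artifact_path."""
    return f"{artifact_path}/MLmodel" in list_logged_files(client, run_id, artifact_path)
```

In the notebook this could be called as `model_was_logged(MlflowClient(), run_id_to_resume)` after `from mlflow.tracking import MlflowClient`. The same check can be run before the recovery step to confirm that the run really is missing its model artifact.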