# Model Training (ViT Anomaly Detection)

This notebook trains a Vision Transformer (ViT) model for anomaly detection. The process involves:

1.  **Environment Setup**: Installing system-level dependencies (`libaio`) and specific Python libraries (`transformers==4.49.0`) to fix a linker crash and an MLflow compatibility issue.
2.  **Authentication Setup**: Configuring the Spark session with Azure Key Vault secrets to allow MLflow to access storage for data lineage.
3.  **Data Preparation**: Loading data from a Delta table, converting it to a Hugging Face `Dataset`, and applying preprocessing.
4.  **Model Definition**: Loading the pre-trained `google/vit-base-patch16-224` checkpoint and replacing its classification head with one for our binary (normal/anomaly) task.
5.  **Training Setup**: Defining `TrainingArguments` and a custom `compute_metrics` function for our imbalanced dataset (using F1-score).
6.  **Training & Manual Logging**: Running the `Trainer` and, due to `autolog()` unreliability in this environment, manually logging the final model, processor, and data lineage to MLflow.

### 1. Environment Setup: System Dependencies

We must first install `libaio-dev` (Linux Asynchronous I/O development library).

**Reason**: On certain Linux environments such as this one (Ubuntu Noble), the `transformers` and `accelerate` libraries depend on this system library. Without it, the `Trainer` fails to initialize and crashes with a cryptic `ld: cannot find -laio` linker error. This is a system-level fix, not a Python one.
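
If the `Trainer` still fails after the install, one quick check is whether the linker can now find `libaio.so`. The `-laio` flag links against the unversioned `libaio.so` symlink. The versioned runtime library alone is not enough, and `libaio-dev` is the package that provides the symlink. The helper below is a hypothetical sketch. Its search directories are assumptions for Debian/Ubuntu-style images and may need adjusting.

```python
import glob
import os


def find_libaio_dev(search_dirs=None):
    """Return paths of unversioned `libaio.so` files found in `search_dirs`.

    `ld -laio` resolves to `libaio.so` (the dev symlink), not to the
    versioned runtime library, so this checks for that exact filename.
    The default directory list is an assumption for Debian/Ubuntu images.
    """
    if search_dirs is None:
        search_dirs = ["/usr/lib", "/usr/lib64", "/lib", "/lib64"]
        # Multiarch directories, e.g. /usr/lib/x86_64-linux-gnu on Ubuntu
        search_dirs += glob.glob("/usr/lib/*-linux-gnu") + glob.glob("/lib/*-linux-gnu")
    hits = []
    for directory in search_dirs:
        candidate = os.path.join(directory, "libaio.so")
        if os.path.exists(candidate):
            hits.append(candidate)
    return hits
```

For example, `print(find_libaio_dev() or "libaio.so not found; install libaio-dev")` prints the symlink path(s) once the package is installed.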

In [0]:
%sh
# Update package lists and install the required AIO library
sudo apt-get update && sudo apt-get install -y libaio-dev

### 2. Environment Setup: Python Libraries

Next, we install specific Python libraries:

* `azure-identity` & `azure-keyvault-secrets`: Required to authenticate with Azure Key Vault and retrieve the secrets (e.g., Service Principal credentials) needed to access our storage.
* `transformers==4.49.0`: **This version pin is critical.** Our `mlflow==2.21.3` version is only officially compatible with `transformers <= 4.49.0`. Using a newer version (like 4.50.x) caused `mlflow.transformers.autolog()` to **silently fail**: it would log parameters and metrics, but **fail to save the final model artifact.** Downgrading ensures compatibility.
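
The failure mode described above is silent, so it can help to check the installed version explicitly after the kernel restart. The sketch below is ours, and the helper names are hypothetical. The `4.49.0` ceiling comes from the note above. The comparison only looks at the numeric release part, which is a simplification; `packaging.version` handles pre-releases properly.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(version_str):
    """Extract the leading numeric release, e.g. '4.50.0.dev0' -> (4, 50, 0).

    Simplification: pre-release/dev suffixes are ignored, so '4.49.0rc1'
    compares equal to '4.49.0'.
    """
    match = re.match(r"\d+(?:\.\d+)*", version_str)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version_str!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def is_supported(installed, max_supported="4.49.0"):
    """True if `installed` is at or below the pinned ceiling."""
    return release_tuple(installed) <= release_tuple(max_supported)


def check_transformers_pin(max_supported="4.49.0"):
    """Report whether the installed transformers respects the pin."""
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        return "transformers is not installed"
    if is_supported(installed, max_supported):
        return f"transformers {installed} is within the supported range"
    return f"transformers {installed} exceeds {max_supported}; re-pin before training"
```

The check returns a message instead of raising, so it can sit next to the version printout without stopping the notebook.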

In [0]:
%pip install azure-identity azure-keyvault-secrets

# Pin transformers to a version compatible with our MLflow version (2.21.3)
%pip install transformers==4.49.0

In [0]:
# After installing new libraries, we must restart the Python kernel 
# for the changes to take effect in the notebook's environment.
dbutils.library.restartPython()

In [0]:
import mlflow, transformers
# Verify the MLflow and transformers versions.
print(mlflow.__version__)
print(transformers.__version__)

### 3. Environment Setup: Spark Authentication

**Reason**: This step is required for **MLflow Data Lineage** (`mlflow.log_input`). When we call `mlflow.data.from_spark(train_df)`, MLflow tries to launch a *new* Spark job to profile the data. This new job needs permission to read the original Delta table files from `abfss://` storage.

The credentials used to mount `/mnt/datamount` are not automatically inherited by this new MLflow job. Therefore, we must *explicitly* set the Service Principal credentials (fetched from Key Vault) on the **global Spark session configuration**. This ensures any Spark job spawned by this session (including MLflow's) has the necessary permissions.
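
The five per-account settings can also be generated by a small helper. This makes the key pattern explicit and avoids repeating the account suffix five times. The sketch below is a refactoring suggestion with hypothetical helper names. The keys and values mirror the ones set in the authentication cell of this notebook.

```python
def build_abfs_oauth_conf(storage_account, tenant_id, client_id, client_secret):
    """Build the per-account ABFS OAuth (client credentials) Spark settings."""
    suffix = f"{storage_account}.dfs.core.windows.net"
    return {
        f"fs.azure.account.auth.type.{suffix}": "OAuth",
        f"fs.azure.account.oauth.provider.type.{suffix}":
            "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider",
        f"fs.azure.account.oauth2.client.id.{suffix}": client_id,
        f"fs.azure.account.oauth2.client.secret.{suffix}": client_secret,
        f"fs.azure.account.oauth2.client.endpoint.{suffix}":
            f"https://login.microsoftonline.com/{tenant_id}/oauth2/token",
    }


def apply_conf(spark_conf, settings):
    """Apply settings to any object exposing .set(key, value), e.g. spark.conf."""
    for key, value in settings.items():
        spark_conf.set(key, value)
```

Usage would look like `apply_conf(spark.conf, build_abfs_oauth_conf("<storage-account>", tenant_id, client_id, client_secret))`. Keeping `apply_conf` generic over `.set()` lets the key generation be checked without a live Spark session.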

In [0]:
# --- 1. Fetch secrets from Azure Key Vault ---
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient

key_vault_url = "https://anomalykeyvault10222025.vault.azure.net/"
key_name_tenant_id = "tenant-id"
key_name_client_id = "client-id"
key_name_client_secret = "client-secret"
storage_account_name = "stlabelingdevwestus001" # The storage account we need to access

credential = DefaultAzureCredential()
client = SecretClient(vault_url=key_vault_url, credential=credential)

tenant_id = client.get_secret(key_name_tenant_id).value
client_id = client.get_secret(key_name_client_id).value
client_secret = client.get_secret(key_name_client_secret).value

# --- 2. Set Service Principal credentials on the global Spark session ---
# This allows MLflow's background jobs to authenticate to Azure Data Lake Storage (ADLS).

spark.conf.set(f"fs.azure.account.auth.type.{storage_account_name}.dfs.core.windows.net", "OAuth")
spark.conf.set(f"fs.azure.account.oauth.provider.type.{storage_account_name}.dfs.core.windows.net", "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider")
spark.conf.set(f"fs.azure.account.oauth2.client.id.{storage_account_name}.dfs.core.windows.net", client_id)
spark.conf.set(f"fs.azure.account.oauth2.client.secret.{storage_account_name}.dfs.core.windows.net", client_secret)
spark.conf.set(f"fs.azure.account.oauth2.client.endpoint.{storage_account_name}.dfs.core.windows.net", f"https://login.microsoftonline.com/{tenant_id}/oauth2/token")

print(f"Spark session configured for Service Principal access to: {storage_account_name}")

In [0]:
# Define and verify the mount path for our data
mount_path = "/mnt/datamount"
dbutils.fs.ls(mount_path)

In [0]:
# Load the augmented image dataset from the Parquet files
print("Loading original DataFrame...")
df = spark.read.format('parquet').load(f'{mount_path}/images_augmented')

# Define the name for the Delta table
dataset_name = "train_dataset"

In [0]:
# Save the DataFrame as a Delta table for reliability and versioning.
# This is the source table MLflow will track for lineage.
df.write.mode("overwrite").format("delta").saveAsTable(dataset_name)

In [0]:
# Import Hugging Face datasets library
from datasets import Dataset, ClassLabel

# Load the Delta table back as a Spark DataFrame
train_df = spark.read.table(dataset_name)

# Convert the Spark DataFrame to a Hugging Face Dataset object
# This materializes the data into a local Arrow cache for use by the Trainer.
train_dataset = Dataset.from_spark(train_df)

# Define the class labels. 'surfing' will be 0, 'noise' will be 1.
class_label_feature = ClassLabel(names=['surfing', 'noise'])
print("Casting 'label' column to ClassLabel...")

# Apply the ClassLabel to the 'label' column.
# This converts the string labels ("surfing", "noise") into integer indices (0, 1)
# which are required by the model for training.
train_dataset = train_dataset.cast_column('label', class_label_feature)

display(train_dataset)
display(train_dataset.features)

# Verify that the label has been converted to an integer
print(f"\nSample label (as integer): {train_dataset[0]['label']}")
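
Conceptually, the cast gives each name its index in `names`. The order of that list therefore fixes the label ids, which is why `pos_label=1` later refers to `'noise'`. The plain-Python sketch below shows the mapping. `encode_labels` is a hypothetical helper, and the real `ClassLabel` does more than shown here.

```python
def encode_labels(labels, names=("surfing", "noise")):
    """Map string labels to integer ids by their position in `names`.

    Mirrors the idea behind ClassLabel(names=[...]): id = index in names.
    Unknown strings raise instead of being silently mapped.
    """
    label2id = {name: i for i, name in enumerate(names)}
    encoded = []
    for label in labels:
        if label not in label2id:
            raise ValueError(f"Unknown label {label!r}; expected one of {list(names)}")
        encoded.append(label2id[label])
    return encoded
```

For example, `encode_labels(["noise", "surfing", "noise"])` returns `[1, 0, 1]`.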

In [0]:
# Verify the 'image' column contains raw bytes, as expected by our preprocess function
type(train_dataset[0]['image'])

In [0]:
import io
import time
import os
import numpy as np
import pandas as pd
from PIL import Image

import torch
from sklearn.metrics import accuracy_score
import mlflow.pytorch
import mlflow.data
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer, EarlyStoppingCallback

# set device
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")
    
device = get_device()
print(f"Using device: {device}") # Will be 'cpu' in this environment
    

In [0]:
# Get the number of labels from our ClassLabel feature
num_labels = train_dataset.features['label'].num_classes
print(f"\n{num_labels} classes found in features.")

# Load the pre-trained model's image processor
# This processor knows how to resize (224x224), normalize, 
# and convert images to the exact format the ViT model expects.
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224', use_fast = True, do_resize = True, size = 224)

# Load the pre-trained ViT model
# 'ignore_mismatched_sizes = True' is essential.
# It discards the original 1000-class ImageNet classifier head 
# and replaces it with a new, randomly initialized classifier 
# that matches our 'num_labels' (2).
model = AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels = num_labels, ignore_mismatched_sizes = True)

# Set label-to-ID mappings in the model's config for clarity
model.config.id2label = {i: label for i, label in enumerate(class_label_feature.names)}
model.config.label2id = {label: i for i, label in enumerate(class_label_feature.names)}
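
The processor's value-level steps (rescale, normalize, channels-first layout) can be written out in NumPy to make the tensor values concrete. This sketch makes two assumptions:

* It skips resizing, so the input must already be a 224x224 RGB image.
* It uses a per-channel mean and std of 0.5, which we believe matches this checkpoint's processor config. Check `processor.image_mean` and `processor.image_std` before relying on it.

`vit_normalize` is a hypothetical helper, not part of `transformers`.

```python
import numpy as np


def vit_normalize(image_hwc_uint8, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Approximate the processor's rescale + normalize + channels-first steps.

    Assumes the image is already resized to 224x224 RGB (uint8, HWC).
    mean/std default to 0.5 per channel (assumed; verify on the processor).
    """
    x = image_hwc_uint8.astype(np.float32) / 255.0  # rescale to [0, 1]
    x = (x - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)  # -> [-1, 1]
    return np.transpose(x, (2, 0, 1))  # HWC -> CHW, i.e. (3, 224, 224)
```

With these constants, black pixels map to -1.0 and white pixels to 1.0. A grayscale image converted to RGB yields three identical channels. Small numerical differences from the real processor are possible.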

In [0]:
def preprocess(example):
    '''
    Preprocess a single data sample whose 'image' field holds raw image bytes.
    input:  example = {'name': '...', 'path': '...', 'label': 0, 'image': b'...'}
    return: the same dict plus 'pixel_values', a tensor of shape [3, 224, 224]
    '''
    # 1. Decode the raw 'bytes' into a PIL Image object
    # 2. Convert to 'RGB' (3 channels). This is critical, as it turns our
    #    binary/grayscale images into the 3-channel format ViT expects.
    image = Image.open(io.BytesIO(example['image'])).convert('RGB')

    # 3. Use the processor to resize, normalize, and convert the PIL image 
    #    into a PyTorch tensor. Output shape is [1, 3, 224, 224].
    processed_img = processor(images = image, return_tensors = 'pt')
    
    # 4. Remove only the batch dimension (axis 0).
    #    Shape becomes [3, 224, 224]. The Trainer's data collator will
    #    re-add the batch dimension when it stacks samples into batches.
    example['pixel_values'] = processed_img['pixel_values'].squeeze(0)
    return example

print("Starting preprocessing with .map() on Databricks...")
# Apply the preprocessing function to all samples in the dataset.
# .map() caches the results for fast training.
train_dataset = train_dataset.map(preprocess)
print("Preprocessing finished and cached.")

# Check a sample to verify the new 'pixel_values' column
sample = train_dataset[0]
print(sample.keys())
print(f"Pixel values shape: {torch.tensor(sample['pixel_values']).shape}")
print(f"Label: {sample['label']}")

# Set the dataset format to return PyTorch tensors directly
train_dataset.set_format(type='torch', columns=['pixel_values', 'label'])

# Split the dataset into train and test sets
# 'stratify_by_column' is crucial for imbalanced data:
# it ensures train_set and test_set keep the same normal/noise class ratio
# as the original dataset (test_size=0.2 separately sets the 80/20 train/test split).
train_dataset = train_dataset.train_test_split(test_size = 0.2, stratify_by_column='label')
train_set = train_dataset['train']
test_set = train_dataset['test']

# Count the class distribution of the training split (logged to MLflow later)
train_labels = torch.as_tensor(train_set['label'])
train_normal_num = int((train_labels == 0).sum())
train_noise_num = int((train_labels == 1).sum())
print(f"train_normal_num: {train_normal_num}, train_noise_num: {train_noise_num}")
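
The idea behind `stratify_by_column` can be sketched as a per-class split. Each class is shuffled and divided at `test_size` separately, so the class ratio survives in both splits. This simplified illustration uses a hypothetical `stratified_split` helper. It is not the `datasets` implementation, whose shuffling and rounding may differ.

```python
import numpy as np


def stratified_split(labels, test_size=0.2, seed=42):
    """Return (train_idx, test_idx) with each class split at ~test_size.

    Simplified sketch: per class, shuffle its indices and send
    round(test_size * count) of them to the test set.
    """
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)
        n_test = int(round(test_size * len(idx)))
        test_idx.extend(idx[:n_test].tolist())
        train_idx.extend(idx[n_test:].tolist())
    return sorted(train_idx), sorted(test_idx)
```

Consider 100 samples with 80 of class 0 and 20 of class 1. The test split gets 16 and 4 of them, keeping the 80/20 class ratio.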

In [0]:
# Although 'ignore_mismatched_sizes=True' already created a new head,
# we can also explicitly re-initialize it for full control.
model.classifier = torch.nn.Linear(model.config.hidden_size, num_labels)

# Initialize the weights of the new classifier layer
# 'xavier_uniform_' is a common and robust initialization strategy.
torch.nn.init.xavier_uniform_(model.classifier.weight)
model.classifier.bias.data.fill_(0)

# We comment this out because the Trainer will handle it.
# The Trainer automatically moves the model to the correct device ('cpu' or 'cuda')
# when training begins, so a manual .to(device) is not needed.
# model = model.to(device)
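
For reference, `torch.nn.init.xavier_uniform_` samples from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)). For a `Linear(768, 2)` weight (ViT-Base hidden size 768, two labels) that gives a = sqrt(6 / 770) ≈ 0.088. The NumPy sketch below reproduces that distribution for illustration. `xavier_uniform` is a hypothetical helper, and its default `gain=1.0` matches PyTorch's default.

```python
import math

import numpy as np


def xavier_uniform(fan_out, fan_in, gain=1.0, seed=0):
    """Sample a (fan_out, fan_in) weight matrix from U(-a, a).

    a = gain * sqrt(6 / (fan_in + fan_out)); returns (weights, a).
    Shape follows torch.nn.Linear.weight: (out_features, in_features).
    """
    bound = gain * math.sqrt(6.0 / (fan_in + fan_out))
    rng = np.random.default_rng(seed)
    weights = rng.uniform(-bound, bound, size=(fan_out, fan_in))
    return weights, bound
```

The bound shrinks as the layer gets wider. This keeps the variance of the head's initial logits modest, so training starts close to an uninformative 50/50 prediction.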

In [0]:

model_name = 'vit-base-patch16-224-anomaly'
# Define the output directory on DBFS (Databricks File System) for persistent storage.
# '/dbfs' is the FUSE mount point for DBFS.
output_path = f"/dbfs{mount_path}/{model_name}"
os.makedirs(output_path, exist_ok=True)

train_args = TrainingArguments(
  output_dir=output_path,   # Where to save checkpoints
  num_train_epochs=10,
  per_device_train_batch_size=16, # Num samples per batch (on this device)
  # Simulate a larger batch size for stable gradients. 
  # Effective batch size = 16 * 4 = 64.
  gradient_accumulation_steps=4,      
  per_device_eval_batch_size=16,  # Batch size for evaluation      
  learning_rate=2e-5,           # Standard learning rate for ViT fine-tuning
  weight_decay=0.01,            # L2 regularization to prevent overfitting
  # Use a learning rate warmup for 10% of training steps.
  # This helps stabilize training at the beginning.
  warmup_ratio=0.1, 
  eval_strategy="epoch",        # Run evaluation at the end of each epoch (replaces the deprecated 'evaluation_strategy')
  save_strategy="epoch",        # Save a checkpoint at the end of each epoch
  logging_strategy='steps',
  logging_steps=10,             # Log training loss every 10 steps
  # Our classes are imbalanced (roughly 80/20), so 'accuracy' is misleading.
  # The F1 score of the minority ('noise') class is the metric to optimize.
  metric_for_best_model='f1',          
  # This ensures that 'trainer.model' at the end is the best-performing one.
  load_best_model_at_end=True,
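  # Limit stored checkpoints to save disk space. With load_best_model_at_end=True,
  # the Trainer always retains the best checkpoint, so up to two may remain
  # (the best and the most recent).
  save_total_limit=1,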
  
  # We are training on CPU, so FP16 (mixed precision) must be disabled;
  # the Trainer's fp16 mode requires a GPU.
  fp16=False,                          
)
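
The batch-size and warmup settings translate into optimizer steps roughly as follows. This is our estimate of the `Trainer`'s arithmetic:

* The last partial batch is kept.
* Optimizer steps per epoch = batches // accumulation steps.
* Warmup = ceil(warmup_ratio * max_steps).

Exact rounding can vary between `transformers` versions, and early stopping may end training earlier. `estimate_schedule` is a hypothetical helper.

```python
import math


def estimate_schedule(num_train_samples, per_device_batch=16, grad_accum=4,
                      num_epochs=10, warmup_ratio=0.1, num_devices=1):
    """Rough optimizer-step schedule for the TrainingArguments above.

    Assumes drop_last=False (the last partial batch counts) and that
    optimizer steps per epoch = batches // grad_accum (at least 1).
    """
    batches_per_epoch = math.ceil(num_train_samples / (per_device_batch * num_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    max_steps = math.ceil(num_epochs * steps_per_epoch)
    warmup_steps = math.ceil(warmup_ratio * max_steps)
    return {
        "effective_batch_size": per_device_batch * grad_accum * num_devices,
        "steps_per_epoch": steps_per_epoch,
        "max_steps": max_steps,
        "warmup_steps": warmup_steps,
    }
```

For example, with 1,000 training samples and 3 epochs it estimates 15 optimizer steps per epoch, 45 in total, and 5 warmup steps. With so few steps, `logging_steps=10` produces only a handful of training-loss logs per run.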

In [0]:
import evaluate # Hugging Face's modern evaluation library

# Load the F1 and Accuracy metric calculators
f1_metric = evaluate.load("f1")
acc_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Calculates F1 and Accuracy for the Trainer."""
    # eval_pred is an EvalPrediction with .predictions and .label_ids attributes
    # predictions are raw logits (e.g., [1.7, -0.3]), not probabilities
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    
    # Get the predicted class (0 or 1) by finding the index 
    # with the highest logit score for each sample (axis=1).
    predicts = np.argmax(logits, axis=1)
    
    # --- Calculate Accuracy ---
    acc_result = acc_metric.compute(predictions = predicts, references = labels)
    
    # --- Calculate F1 Score (for the minority class) ---
    # 'pos_label=1' specifies that our 'positive' class (the one we 
    # want to detect) is 'noise' (label 1).
    # 'average="binary"' tells the function to *only* calculate the F1 
    # score for this specific positive class (label 1).
    f1_result = f1_metric.compute(predictions = predicts, references = labels, average = "binary", pos_label = 1)

    # Combine both metrics into a single dictionary
    results = {}
    results.update(acc_result)
    results.update(f1_result)

    return results
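
The sketch below makes `average="binary"` concrete. It computes F1 for the positive class directly from true positives, false positives, and false negatives, and treats an undefined precision or recall as 0. It uses a hypothetical `binary_f1` helper and is not the `evaluate` implementation.

```python
import numpy as np


def binary_f1(predictions, references, pos_label=1):
    """F1 for the positive class only (what average="binary" reports)."""
    predictions = np.asarray(predictions)
    references = np.asarray(references)
    tp = np.sum((predictions == pos_label) & (references == pos_label))
    fp = np.sum((predictions == pos_label) & (references != pos_label))
    fn = np.sum((predictions != pos_label) & (references == pos_label))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return float(2 * precision * recall / (precision + recall))
```

This also shows why accuracy is misleading here. Take `binary_f1([0] * 10, [0] * 8 + [1] * 2)`, which predicts 'surfing' for all ten samples of an 8/2 split. It scores 0.8 accuracy but returns 0.0 F1 for 'noise'.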




In [0]:
run_name = f"{model_name}-run-{time.strftime('%Y%m%d%H%M%S')}"
# Define the early stopping callback
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

with mlflow.start_run(run_name=run_name):

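    # 1. Enable autologging, but disable model logging.
    # autolog() was unreliably failing to save model artifacts in this
    # environment, so we log the model manually after training.
    # Note: per the MLflow docs, mlflow.transformers.autolog() mainly suppresses
    # spurious autologging of sub-models; per-step params/metrics are written
    # by the Trainer's built-in MLflow integration (MLflowCallback).
    mlflow.transformers.autolog(log_models = False)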

    # 2. Log custom parameters that autolog cannot know
    mlflow.log_param("train_normal_num", train_normal_num)
    mlflow.log_param("train_noise_num", train_noise_num)
    mlflow.log_param("total_train_samples", len(train_set))
    mlflow.log_param("total_test_samples", len(test_set))

    mlflow.set_tag("task", "anomaly_detection")
    mlflow.set_tag("model_family", "ViT")

    # 3. Log the input dataset for data lineage tracking.
    # This records 'train_df' (read from the Delta table) as an input of this run.
    # This step *requires* the Spark authentication setup (Section 3) to work.
    src_dataset = mlflow.data.from_spark(train_df)
    mlflow.log_input(src_dataset, context="Training-Input-DataFrame")

    # 4. Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=train_args,
        train_dataset=train_set,
        eval_dataset=test_set,
        compute_metrics=compute_metrics,
        callbacks=[early_stopping]
    )

    # 5. Start the model training
    # Params/metrics are logged to the active MLflow run during training.
    print("Starting model training...")
    train_result = trainer.train()
    print("Training finished.")

    # 6. Manually log the model artifacts
    # This block runs *after* training is complete.
    print("Logging model artifacts manually...")
    try:
        # Get a sample input for signature inference
        sample_input = next(iter(test_set))
        # Add the batch dimension: [3, 224, 224] -> [1, 3, 224, 224]
        input_example = sample_input['pixel_values'].unsqueeze(0).cpu().numpy()
        
        # Log the model using mlflow.transformers.log_model
        mlflow.transformers.log_model(
            # CRITICAL: We must pass a dict containing *both* the model
            # and the processor. 'trainer.model' contains the best model
            # thanks to 'load_best_model_at_end=True'.
            transformers_model = {"model": trainer.model, "image_processor": processor}, 
            artifact_path = "model", # Folder name in MLflow artifacts
            input_example = input_example # This infers the signature automatically
        )
        print("Model artifacts logged successfully.")
    
    except Exception as e:
        print(f"Error: Failed to log model artifacts: {e}")

print("MLflow run completed!!")
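
The callback above uses `early_stopping_patience=3` and `metric_for_best_model='f1'`. Evaluation runs once per epoch, so training stops once three consecutive epochs fail to improve on the best F1 seen so far. The sketch below is our simplified model of that rule, not `EarlyStoppingCallback`'s source. It assumes a higher-is-better metric and the default `early_stopping_threshold=0.0`, and `early_stop_epoch` is a hypothetical helper.

```python
def early_stop_epoch(eval_scores, patience=3, threshold=0.0):
    """Return the 1-based epoch after which training would stop, or None.

    The counter resets on an improvement larger than `threshold` over the
    best score so far; training stops once the counter reaches `patience`.
    """
    best = None
    counter = 0
    for epoch, score in enumerate(eval_scores, start=1):
        if best is None or (score > best and abs(score - best) > threshold):
            best = score
            counter = 0
        else:
            counter += 1
        if counter >= patience:
            return epoch
    return None
```

For example, F1 scores of `[0.50, 0.60, 0.60, 0.59, 0.58, 0.70]` stop training after epoch 5. The later 0.70 is never reached, and `load_best_model_at_end=True` restores the epoch-2 checkpoint.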
