In [0]:
%pip install datasets==2.20.0 transformers==5.0.0 tf-keras==2.17.0 accelerate==1.4.0 mlflow==2.20.2 torchvision==0.20.1 deepspeed==0.14.4
# Pinned package versions used by the fine-tuning workflow in this notebook.
# Restart the Python process so the following cells import the versions installed above.
dbutils.library.restartPython()

In [0]:
# Checkpoint location handed to display() when previewing tables.
checkpoint_path = "dbfs:/Volumes/smart_claims_dev/00_landing/training_imgs/"

# Incremental (streaming) read of the labelled training images in the silver layer.
training_df = (
    spark.readStream
    .table("smart_claims_dev.02_silver.training_images")
)
# display(training_df.limit(1), checkpointLocation = checkpoint_path)

In [0]:
import io
from pyspark.sql.functions import pandas_udf, col
# Edge length, in pixels, of the square images produced by the resize UDF.
IMAGE_RESIZE = 224

# Landing-zone volume that holds the claim files and the Auto Loader metadata.
landing_catalog = "smart_claims_dev"
landing_schema = "00_landing"
base_path = "/".join(["", "Volumes", landing_catalog, landing_schema, "claims"])
metadata_path = base_path + "/autoloader_metadata"

# Resize UDF: center-crop each image to a square, scale it to IMAGE_RESIZE x IMAGE_RESIZE
# and re-encode it as JPEG bytes.
@pandas_udf("binary")
def resize_image_udf(content_series):
  """Center-crop and resize a pandas Series of encoded images.

  Args:
    content_series: pandas Series of encoded image bytes (any format PIL can open).

  Returns:
    pandas Series of JPEG-encoded bytes, each IMAGE_RESIZE x IMAGE_RESIZE pixels.
  """
  def resize_image(content):
    """Resize one encoded image and serialize it back as JPEG."""
    from PIL import Image
    # Close the decoder once the crop has materialized the pixels.
    with Image.open(io.BytesIO(content)) as image:
      width, height = image.size
      side = min(width, height)
      # Integer offsets keep the crop box exactly side x side. A float box gets
      # rounded by PIL and can end up one pixel wider or taller than intended.
      left = (width - side) // 2
      top = (height - side) // 2
      cropped = image.crop((left, top, left + side, top + side))
    resized = cropped.resize((IMAGE_RESIZE, IMAGE_RESIZE), Image.NEAREST)
    # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...). Convert those to RGB
    # instead of failing the whole batch; RGB and grayscale pass through unchanged.
    if resized.mode not in ("RGB", "L"):
      resized = resized.convert("RGB")
    output = io.BytesIO()
    resized.save(output, format='JPEG')
    return output.getvalue()

  return content_series.apply(resize_image)


# Column metadata that enables the image preview of the JPEG bytes in display().
image_meta = {"spark.contentAnnotation" : '{"mimeType": "image/jpeg"}'}

# Resize every training image and persist the result. An availableNow trigger processes
# the data currently available and then stops. toTable() only *starts* the query, so block
# on awaitTermination(); otherwise the following cells, which read the resized table, can
# run against incomplete (or not yet created) output.
resize_query = (training_df
      .withColumn("content", resize_image_udf(col("content")).alias("content", metadata=image_meta))
      .writeStream
      .option("checkpointLocation", f"{metadata_path}/_checkpoint2")
      .trigger(availableNow=True)
      .toTable("smart_claims_dev.02_silver.training_images_resized"))
resize_query.awaitTermination()

In [0]:
# Preview a few of the resized training images.
resized_preview = spark.table("smart_claims_dev.02_silver.training_images_resized").limit(10)
display(resized_preview, checkpointLocation=checkpoint_path)

In [0]:
from datasets import Dataset
import mlflow

# Set up the training experiment in the current user's workspace folder. A hard-coded
# personal path would make the notebook fail (or write elsewhere) for any other user.
current_user = spark.sql("SELECT current_user()").first()[0]
mlflow.set_experiment(f"/Users/{current_user}/image-claims-classifier")

# Convert the Spark DataFrame to pandas first (Serverless doesn't support from_spark)
pandas_df = spark.table("smart_claims_dev.02_silver.training_images_resized").toPandas()
dataset = Dataset.from_pandas(pandas_df).rename_column("content", "image")

# Reproducible 80/20 train/validation split.
splits = dataset.train_test_split(test_size=0.2, seed=42)
train_ds = splits['train']
val_ds = splits['test']

In [0]:
import torch
from transformers import AutoFeatureExtractor, AutoImageProcessor

# pre-trained model from which to fine-tune
# Check the hugging face repo for more details & models: https://huggingface.co/microsoft/resnet-50
model_checkpoint = "microsoft/resnet-50"

from PIL import Image
import io
from torchvision.transforms import CenterCrop, Compose, Normalize, RandomResizedCrop, Resize, ToTensor, Lambda

# Load the model's image processor. It carries the pre-processing settings the checkpoint
# expects (e.g. normalization mean/std), so switching to another model needs no code change.
# AutoImageProcessor is the replacement for the deprecated vision AutoFeatureExtractor.
model_def = AutoImageProcessor.from_pretrained(model_checkpoint)

# Per-image transformation shared by the training and validation sets. There is no resize,
# crop or random augmentation here: the images come from the resized table, already
# center-cropped and scaled to IMAGE_RESIZE. They only need decoding, tensor conversion
# and normalization.
transforms = Compose([
    Lambda(lambda b: Image.open(io.BytesIO(b)).convert("RGB")),  # JPEG bytes -> PIL RGB image
    ToTensor(),  # PIL image -> float tensor
    Normalize(mean=model_def.image_mean, std=model_def.image_std),
])

def preprocess(batch):
    """Decode and normalize every image of a dataset batch.

    Registered with ``Dataset.set_transform``, so it is applied on the fly to each batch
    that is accessed. Replaces ``batch["image"]`` (encoded image bytes) with a list of
    normalized tensors and returns the same batch dict.
    """
    batch["image"] = [transforms(image) for image in batch["image"]]
    return batch

# Set our training / validation transformations
train_ds.set_transform(preprocess)
val_ds.set_transform(preprocess)

In [0]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Mapping between class label and id. Hugging Face stores it in the model config and uses it
# at inference to output the proper label.
# The labels are sorted so the id assignment is deterministic. Iterating a plain set of
# strings depends on per-process hash randomization, so ids could change between sessions.
unique_labels = sorted(set(dataset['label']))
label2id = {label: i for i, label in enumerate(unique_labels)}
id2label = {i: label for label, i in label2id.items()}

# Load the base model from its checkpoint with a classification head sized to our classes.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    num_labels=len(label2id),
    # Lets a head whose shape differs from the checkpoint's be re-initialized instead of
    # raising (also needed when fine-tuning an already fine-tuned checkpoint).
    ignore_mismatched_sizes=True,
)

In [0]:
import os
import gc

# 1. Kill the 'phone home' features that cause hangs
os.environ["DATABRICKS_AUTOLOGGING_ENABLED"] = "false"
os.environ["WANDB_DISABLED"] = "true"

from transformers import TrainingArguments

# Fine-tuning hyper-parameters. Checkpoints go to /tmp/checkpoints; nothing is pushed to the Hub.
args = TrainingArguments(
    output_dir="/tmp/checkpoints",
    push_to_hub=False,
    # Tracker integrations are selected through `report_to`, not a REPORT_TO environment
    # variable. Disable them here so the Trainer does not attach them.
    report_to="none",
    # Keep the raw "image"/"label" columns: collate_fn builds the model inputs from them.
    remove_unused_columns=False,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=1,  # 20
    # A positive max_steps takes precedence over num_train_epochs (short smoke-test run).
    max_steps=10,
    load_best_model_at_end=False,
    learning_rate=5e-5,
    logging_steps=1,
    logging_first_step=True,
)
print("Success! Arguments are ready.")

In [0]:

# args = TrainingArguments(
#     # f"/tmp/huggingface/pcb/{model_name}-finetuned",
#     # no_cuda=True, #Run on CPU for resnet to make it easier
#     remove_unused_columns=False,
#     per_device_train_batch_size=32, 
#     per_device_eval_batch_size=32,
#     evaluation_strategy = "epoch",
#     save_strategy = "epoch",
#     num_train_epochs=1,  #20
#     max_steps=10,
#     load_best_model_at_end=True,
#     learning_rate=5e-5,
#     logging_steps=1,
#     logging_first_step=True
# )

In [0]:
import mlflow
# This wrapper adds steps before and after the inference to simplify the model usage
# Before calling the model: decode the image bytes into PIL images for the pipeline
# After calling the model: only keep the main class with its score as output
class ModelWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper around a Hugging Face image-classification pipeline.

    Input: a pandas DataFrame with a ``content`` column of encoded image bytes.
    Output: a pandas DataFrame with the best ``label`` and its ``score`` per input row.
    """

    def __init__(self, pipeline):
        self.pipeline = pipeline
        # instantiate model in evaluation mode
        self.pipeline.model.eval()

    def predict(self, context, images):
        """Classify the images and return the highest-scoring label for each row.

        Args:
            context: MLflow model context (unused).
            images: pandas DataFrame with a ``content`` column of encoded image bytes.

        Returns:
            pandas DataFrame with one ``label``/``score`` row per input image.
        """
        # Local imports keep the logged model self-contained. It must not depend on names that
        # happened to exist in the notebook session when the wrapper was pickled.
        import io

        import pandas as pd
        import torch
        from PIL import Image

        with torch.no_grad():
            # Convert the bytes to PIL images
            pil_images = images['content'].apply(lambda b: Image.open(io.BytesIO(b))).to_list()
            # For each image the pipeline returns a list of candidates,
            # e.g. [{'score': 0.999038815498352, 'label': 'normal'}, ...]; keep the best one.
            predictions = self.pipeline.predict(pil_images)
            return pd.DataFrame([max(r, key=lambda x: x['score']) for r in predictions])

In [0]:
from transformers import pipeline, DefaultDataCollator, EarlyStoppingCallback
from mlflow.models import infer_signature

with mlflow.start_run(run_name="hugging_face_new") as run:
    mlflow.log_input(mlflow.data.from_huggingface(train_ds, "training"))

    def collate_fn(examples):
        """Stack image tensors and build one-hot targets sized to the real class count."""
        import torch
        pixel_values = torch.stack([e["image"] for e in examples])
        labels = torch.tensor([label2id[e["label"]] for e in examples], dtype=torch.long)
        # NOTE(review): float one-hot targets presumably make Transformers treat this as
        # multi-label classification (BCE loss, sigmoid scores) rather than single-label
        # cross-entropy. Confirm this is intended for mutually exclusive damage levels.
        labels = torch.nn.functional.one_hot(labels, num_classes=len(label2id)).float()
        return {"pixel_values": pixel_values, "labels": labels}

    # `processing_class` replaces Trainer's removed `tokenizer` argument. The image
    # processor is saved together with the model checkpoints.
    trainer = Trainer(
        model,
        args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        processing_class=model_def,
        data_collator=collate_fn,
    )
    train_results = trainer.train()

    # Build the final HF pipeline from the fine-tuned weights.
    # trainer.state.best_model_checkpoint stays None unless a best metric is tracked
    # (load_best_model_at_end=False, no metric_for_best_model). pipeline(model=None) would
    # then fall back to a generic default model with only a warning. Save the trained model
    # explicitly and build the pipeline from that directory instead.
    final_model_dir = f"{args.output_dir}/final"
    trainer.save_model(final_model_dir)
    classifier = pipeline("image-classification", model=final_model_dir, image_processor=model_def)

    import pandas as pd
    wrapped_model = ModelWrapper(classifier)
    # A small sample is enough to sanity-check the wrapper and infer the input/output
    # signature. Scoring the whole table here is unnecessary work.
    test_df = (spark.table("smart_claims_dev.02_silver.training_images_resized")
               .select('content')
               .limit(10)
               .toPandas())
    predictions = wrapped_model.predict(None, test_df)
    signature = infer_signature(test_df, predictions)

    reqs = mlflow.transformers.get_default_pip_requirements(model)

    # LOG the model and CAPTURE the URI
    logged = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=wrapped_model,
        pip_requirements=reqs,
        signature=signature,
    )

# keep these prints to sanity-check
from mlflow import artifacts
print("logged.model_uri:", logged.model_uri)   # e.g., runs:/<run_id>/model
print("logged.run_id  :", logged.run_id)
print("model files    :", artifacts.list_artifacts(logged.model_uri))


In [0]:
from mlflow.tracking import MlflowClient
import mlflow

mlflow.set_registry_uri("databricks-uc")
model_name = "smart_claims_dev.03_gold.claims_damage_level"

# Register the model logged by the training run as a new version of the UC model.
registered = mlflow.register_model(model_uri=logged.model_uri, name=model_name)
new_version = registered.version

# Point the "prod" alias at the version that was just registered.
client = MlflowClient()
client.set_registered_model_alias(name=model_name, alias="prod", version=new_version)

print(f"Registered {model_name} v{new_version} and set alias 'prod'.")

In [0]:
# Load the "prod" model version as a Spark UDF.
predict_damage_udf = mlflow.pyfunc.spark_udf(
    spark, model_uri="models:/smart_claims_dev.03_gold.claims_damage_level@prod"
)
# Input column names are taken from the model signature.
columns = predict_damage_udf.metadata.get_input_schema().input_names()

# Run the inferences
# NOTE(review): this scores the whole resized table, training rows included, so metrics
# derived from it are presumably optimistic. Consider scoring only the validation split.
scored_df = spark.table('smart_claims_dev.02_silver.training_images_resized').withColumn(
    "damage_prediction", predict_damage_udf(*columns)
)
scored_df.write.mode('overwrite').saveAsTable('smart_claims_dev.03_gold.damage_predictions')

In [0]:
# Load the stored predictions and show them.
predictions = spark.read.table('smart_claims_dev.03_gold.damage_predictions')
display(predictions)

In [0]:
# Flatten the prediction struct into plain columns and collect the result as pandas.
results = predictions.selectExpr(
    "path",
    "label",
    "damage_prediction.label as predictions",
    "damage_prediction.score as score",
).toPandas()

In [0]:
import matplotlib.pyplot as plt
import seaborn as sns

# pandas was previously only imported inside the training cell; import it here so this cell
# does not depend on that one.
import pandas as pd

# Confusion matrix: true labels (rows) vs predicted labels (columns).
confusion_matrix = pd.crosstab(results['label'], results['predictions'])

# Plot on an explicit Axes with a title and axis labels so the figure reads on its own.
fig, ax = plt.subplots()
sns.heatmap(confusion_matrix, annot=True, cmap="Blues", fmt='d', ax=ax)
ax.set(title="Damage level confusion matrix", xlabel="Predicted label", ylabel="True label")
fig.tight_layout()

In [0]:
# Claim metadata, joined onto the predictions by image name.
metadata = spark.table("smart_claims_dev.01_bronze.claim_images_meta")

# Score the claim images with the registered model.
# NOTE(review): these images do not go through resize_image_udf like the training images
# did. Confirm the pipeline's own preprocessing is close enough to the training preprocessing.
raw_images = spark.read.table("smart_claims_dev.02_silver.claim_images").withColumn(
    "damage_prediction", predict_damage_udf(*columns)
)

(raw_images
    .join(metadata, on="image_name")
    .write
    .mode('overwrite')
    .saveAsTable("smart_claims_dev.03_gold.claim_images_predicted"))