# Multiclass Defect Detection with Distributed training using PyTorch Object Detection Models in Snowflake Notebooks


In [None]:
# Show the installed packages whose name contains "snow", then install the
# headless OpenCV wheel.
!pip freeze | grep snow
!pip install opencv-python-headless

import torch
import torchvision
# Print the torch / torchvision versions available in this kernel.
print(torch.__version__)
print(torchvision.__version__)

## Install necessary packages:

* torch
* torchvision
* opencv
* matplotlib
* Pillow

In [None]:
# Install the non-headless OpenCV wheel plus the libsm6 / libxext6 /
# libxrender-dev system packages (presumably shared libraries cv2 links
# against — TODO confirm they are still needed alongside the headless wheel
# installed in the first cell).
!pip install opencv-python
!apt update && apt install -y libsm6 libxext6
!apt-get install -y libxrender-dev

### Import necessary packages

In [None]:
# Import python packages
import streamlit as st
import pandas as pd
from snowflake.snowpark.context import get_active_session

import os
import sys
import time
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
# NOTE(review): this binds the alias `T`; the snowpark `types as T` import
# further down rebinds it, so after this cell `T` refers to snowpark types.
import torchvision.transforms as T
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN


import numpy as np

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as img
from snowflake.ml.registry import Registry

from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T

from snowflake.ml.modeling.distributors.pytorch import PyTorchDistributor, PyTorchScalingConfig, WorkerResourceConfig
from snowflake.ml.data.sharded_data_connector import DataConnector, ShardedDataConnector


import warnings
# Silence every warning for the rest of the notebook session.
warnings.filterwarnings("ignore")

# Snowpark session used by all later cells to read and write tables.
session = get_active_session()

In [None]:
# Report the CUDA devices visible to this kernel and select GPU 0 as default.
if not torch.cuda.is_available():
    print("CUDA is not available. Check your installation or GPU setup.")
else:
    num_gpus = torch.cuda.device_count()
    print("Number of GPU devices available:", num_gpus)

    # One line per device: index and device name.
    for i in range(num_gpus):
        print("Device", i, ":", torch.cuda.get_device_name(i))

    # Make the first GPU the default device.
    torch.cuda.set_device(0)

### View the training dataset

In [None]:
# Fetch five rows of the training table; the trailing `;` presumably keeps the
# notebook from echoing the (large, base64-encoded) rows — TODO confirm.
session.table("training_data").limit(5).collect();

# Training

## Step 1: Define a Training Function for Each Worker

Create a function that defines the training process for an individual worker. This function will be executed independently on each worker during distributed training.

## Step 2: Execute the Training Function Using PyTorchDistributor

Use the PyTorchDistributor to distribute and manage the execution of the training function across multiple workers.

* The **ShardedDataConnector** ensures that the dataset is evenly partitioned (sharded) and distributed across all workers.
* The **PyTorchScalingConfig** specifies the number of workers and necessary resources (e.g., CPUs, GPUs, memory) for each worker.


In [None]:
import base64
import io
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader, IterableDataset
from PIL import Image
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn  
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor  
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights 
import torch.distributed as dist
from snowflake.ml.modeling.distributors.pytorch import get_context
from torch.nn.parallel import DistributedDataParallel as DDP
import tempfile
import cloudpickle as cp

def train_func():
    """Train the Faster R-CNN defect detector on this worker's data shard.

    Intended to be handed to ``PyTorchDistributor`` and run once per worker.
    Each worker joins the NCCL process group, wraps the model in
    ``DistributedDataParallel``, and trains on the shard it obtains from the
    ``"train"`` entry of the context's dataset map. When training finishes,
    rank 0 writes the unwrapped model's ``state_dict`` to
    ``/tmp/models/detectionmodel.pt``.

    Hyperparameters read from the training context:
        batch_size: images per batch (parsed with ``int``).
        num_epochs: number of passes over the shard (parsed with ``int``).

    NOTE(review): the global rank is also used as the CUDA device index. That
    only lines up when all workers live on one node — confirm before running
    with more than one node.
    """
    # Local import keeps the function independent of the cell that imports os.
    import os

    context = get_context()
    rank = context.get_rank()
    dist.init_process_group(backend="nccl")
    print(f"Worker Rank : {rank}, world_size: {context.get_world_size()}")

    class FCBData(IterableDataset):
        """Stream ``(image, target)`` pairs in the format Faster R-CNN expects.

        Each source row must provide ``IMAGE_DATA`` (base64-encoded image),
        ``XMIN``/``YMIN``/``XMAX``/``YMAX``, ``CLASS`` and ``FILENAME``; every
        row contributes exactly one bounding box.
        """

        def __init__(self, source_dataset, transforms=None):
            self.source_dataset = source_dataset
            # Default to ToTensor so the model always receives tensors.
            self.transforms = transforms if transforms else torchvision.transforms.ToTensor()

        def __iter__(self):
            for row in self.source_dataset:
                base64_image = row['IMAGE_DATA']
                # Force three channels, matching how images are decoded at
                # inference time.
                image = Image.open(io.BytesIO(base64.b64decode(base64_image))).convert("RGB")
                image = self.transforms(image)

                # One box and one label per row.
                boxes = [[row[k].item() for k in ["XMIN", "YMIN", "XMAX", "YMAX"]]]
                labels = [row["CLASS"].item()]

                boxes = torch.as_tensor(boxes, dtype=torch.float32)
                labels = torch.as_tensor(labels, dtype=torch.int64)

                target = {
                    'boxes': boxes,
                    'labels': labels,
                    'image_id': torch.tensor([int(row["FILENAME"])]),
                    # Box area = width * height.
                    'area': (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]),
                    # No box is flagged as a crowd region.
                    'iscrowd': torch.zeros((boxes.shape[0],), dtype=torch.uint8)
                }
                yield (image, target)

    with torch.cuda.device(rank):
        # Start from the pretrained detector and swap in a new box predictor.
        weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
        model = fasterrcnn_resnet50_fpn(weights=weights)

        # Number of classes, including background.
        num_classes = 6
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        model.to(rank)
        model = DDP(model, device_ids=[rank])

        optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.0001, weight_decay=0.0005)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

        # Load this worker's shard through the ShardedDataConnector.
        dataset_map = context.get_dataset_map()
        train_shard = dataset_map["train"].get_shard().to_torch_dataset()
        train_dataset = FCBData(train_shard)

        hyper_parms = context.get_hyper_params()

        def collate_fn(batch):
            # Keep images and targets as parallel tuples; detection targets
            # cannot be stacked into a single tensor.
            return tuple(zip(*batch))

        batch_size = int(hyper_parms['batch_size'])
        train_data_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate_fn,
            pin_memory=True,
            pin_memory_device=f"cuda:{rank}"
        )

        # Training loop
        num_epochs = int(hyper_parms['num_epochs'])
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            running_batches = 0
            for images, targets in train_data_loader:
                running_batches += 1
                # ToTensor already yields floats in [0, 1], the range the
                # detector expects, so no further division by 255 is applied.
                images = [image.float().to(rank) for image in images]
                targets = [{k: v.to(rank) for k, v in t.items()} for t in targets]

                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())

                optimizer.zero_grad()
                losses.backward()
                optimizer.step()

                running_loss += losses.item()

            # Guard the average against a worker that received no batches.
            samples_seen = max(running_batches * batch_size, 1)
            print(f"[Rank {rank}] Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / samples_seen:.4f}, Processed {running_batches * (epoch+1) * batch_size} images so far")
            lr_scheduler.step()

        MODEL_PATH = "/tmp/models/detectionmodel.pt"
        if rank == 0:
            # The target directory is not created anywhere else in this function.
            os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
            with open(MODEL_PATH, mode="w+b") as model_file:
                # Save the unwrapped module so the keys carry no DDP prefix.
                torch.save(model.module.state_dict(), model_file)
            print(f"Model written to {MODEL_PATH}")

        print(f"[Rank {rank}] Training completed.")

### For the purpose of this quickstart, we use a small dataset as the data source, but the same approach can scale to millions of rows.

1. Split the dataset (shard) for distributed training across multiple workers.
2. Train a PyTorch model using 4 workers, each utilizing 1 GPU for efficient computation. Control the training with hyperparameters such as batch size and number of epochs.

In [None]:
# Set up PyTorchDistributor
from snowflake.ml.modeling.distributors.pytorch import PyTorchDistributor, PyTorchScalingConfig, WorkerResourceConfig  
from snowflake.ml.data.sharded_data_connector import ShardedDataConnector  

# Training rows come straight from the training_data table.
df = session.table("training_data")

# Wrap the table in a sharded connector for the distributor's dataset map.
train_data = ShardedDataConnector.from_dataframe(df)

# One node, four workers, one GPU requested per worker.
worker_resources = WorkerResourceConfig(num_cpus=0, num_gpus=1)
scaling_config = PyTorchScalingConfig(
    num_nodes=1,
    num_workers_per_node=4,
    resource_requirements_per_worker=worker_resources,
)
pytorch_trainer = PyTorchDistributor(train_func=train_func, scaling_config=scaling_config)

# Launch training; hyperparameters are passed as strings and parsed by train_func.
pytorch_trainer.run(
    dataset_map={"train": train_data},
    hyper_params={"batch_size": "32", "num_epochs": "5"},
)

# MODEL DEPLOYMENT


# Snowflake Model Registry - Securely manage models and their metadata in Snowflake.

The model registry stores machine learning models as first-class schema-level objects in Snowflake.

* Load the model produced by trainer 
* Define custom wrapper for the PyTorch model
* Save it to the Model Registry by specifying the model_name, version_name, a sample input DataFrame for the signature, and conda_dependencies

In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN_ResNet50_FPN_Weights
from PIL import Image
import io
import json
import base64
# Pull a single validation row into pandas.
df=session.table("VAL_IMAGES_LABELS").limit(1).to_pandas()

# Keep only its base64 image and rebuild `df` as a one-row, one-column frame.
first_row = df.iloc[0]  
base64_image = first_row['IMAGE_DATA'] 
df = pd.DataFrame({'IMAGE_DATA': [base64_image]})  

# Snowpark copy of the sample; passed to log_model as sample_input_data.
spdf=session.create_dataframe(df)
# Function to load the model
def load_model(model_path, num_classes=6):
    """Rebuild the detector and load trained weights from ``model_path``.

    Args:
        model_path: Path to a ``state_dict`` checkpoint written by ``torch.save``.
        num_classes: Size of the box-predictor head, background included.
            Defaults to 6 (background + 5 classes).

    Returns:
        The Faster R-CNN model, converted to float64 and put in eval mode.
    """
    weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn(weights=weights)

    # Replace the box predictor so the head matches the trained checkpoint.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # The checkpoint is saved from a model living on a GPU; map it to CPU so
    # loading does not depend on a CUDA device being present.
    state_dict = torch.load(model_path, map_location="cpu")
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False tolerates key mismatches; surface them instead of silently
    # keeping pretrained/random weights for the affected layers.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Warning: checkpoint {model_path} did not match the model exactly. "
            f"Missing keys: {load_result.missing_keys}; "
            f"unexpected keys: {load_result.unexpected_keys}"
        )

    model.double()
    model.eval()
    return model

# Function to decode and transform an image
def decode_and_transform_image(base64_image):
    """Decode a base64-encoded image into a float64 ``[C, H, W]`` tensor.

    The image is converted to RGB, resized to 224x224 and scaled to the
    [0, 1] range by ``ToTensor``. No mean/std normalization is applied here:
    the training path feeds the detector un-normalized tensors, and torchvision
    detection models normalize their inputs internally.

    NOTE(review): because of the fixed resize, predictions are presumably
    expressed in the 224x224 frame rather than the original image size —
    confirm before drawing boxes on the source image.

    Args:
        base64_image: Base64 string (or bytes) holding an encoded image file.

    Returns:
        A ``torch.float64`` tensor of shape ``[3, 224, 224]``.
    """
    image_data = base64.b64decode(base64_image)
    image = Image.open(io.BytesIO(image_data)).convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  # Converts to [C, H, W] with values in [0, 1]
    ])
    # The registered model is converted with .double(), so match its dtype.
    image_tensor = transform(image).double()

    # Debugging: Print the shape after transformation
    print(f"Shape after transformation: {image_tensor.shape}")

    return image_tensor


# Load the checkpoint from the same path train_func writes to on rank 0.
model_path = '/tmp/models/detectionmodel.pt'
model = load_model(model_path)

from snowflake.ml.model import custom_model

class DefectDetectionModel(custom_model.CustomModel):
    """Custom-model wrapper exposing the detector under the ``"rcnn"`` key.

    ``predict`` takes a DataFrame with an ``IMAGE_DATA`` column of
    base64-encoded images and returns a DataFrame with one ``output`` column
    holding, per image, the model's result dict serialized as JSON.
    """

    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)

    @custom_model.inference_api
    def predict(self, input_df: pd.DataFrame) -> pd.DataFrame:
        """Run detection on every image in ``input_df['IMAGE_DATA']``."""
        # Every decoded image has the same size, so the batch can be stacked.
        processed_input = torch.stack(
            [decode_and_transform_image(encoded) for encoded in input_df['IMAGE_DATA']]
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            raw_output = self.context.model_ref("rcnn").forward(processed_input)
        # Tensors are not JSON-serializable; convert each value to nested lists.
        serialized = [
            json.dumps({k: v.detach().cpu().numpy().tolist() for k, v in res.items()})
            for res in raw_output
        ]
        final_output = pd.DataFrame({"output": serialized})
        return final_output

# Wrap the loaded detector; predict() looks it up under the 'rcnn' key.
ddm = DefectDetectionModel(context = custom_model.ModelContext(models={'rcnn': model}))


ml_reg = Registry(session=session)  
# Log the model with the sample input for Snowflake registry
mv = ml_reg.log_model(  
    ddm,  
    model_name="DefectDetectionModel",  
    version_name='v3',  
    sample_input_data=spdf,
    conda_dependencies=["pytorch", "torchvision"],
    # NOTE(review): "relax" may be meant as "relax_version" — confirm the
    # option key against the installed snowflake-ml-python version.
    options={"embed_local_ml_library": True,
             
                "relax": True}

)
    

## Fetch the logged Model from Snowflake Registry



In [None]:
# Usage Example
reg = Registry(session=session) 
model_ref = reg.show_models()
model_ref

## Detect Defects on Validation dataset
Let's assume there is a validation table, VAL_IMAGES_LABELS, that contains the Base64-encoded validation images.

* Get a reference to a specific model from the registry by name using the registry’s get_model method
* Get a reference to a specific version of a model as a ModelVersion instance using the model’s version method.
* Run inference using the model and output the predictions


In [None]:

# Look up the registered model and the version this notebook logs ('v3').
m = reg.get_model("DEFECTDETECTIONMODEL")
mv = m.version("v3")


# Take one validation row and keep only its base64-encoded image.
df=session.table("VAL_IMAGES_LABELS").limit(1).to_pandas()

first_row = df.iloc[0]
base64_image = first_row['IMAGE_DATA'] 
image_data_df = pd.DataFrame({'IMAGE_DATA': [base64_image]})  
image_data_df.head()



# Run the model's predict() method on the one-row frame.
remote_prediction = mv.run(image_data_df, function_name="predict")
remote_prediction.head()

Fetch the predictions and use the display_image_with_boxes() function to display each image with its bounding boxes and labels.


In [None]:
import json
import base64
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import io

# Class mapping dictionary
# Integer label -> defect name; names are numbered from 0 in the order listed.
classes_la = dict(enumerate((
    "open",
    "short",
    "mousebite",
    "spur",
    "copper",
    "pin-hole",
)))

# Function to display the image with bounding boxes and class labels
def display_image_with_boxes(image, boxes, labels, scores, target_size=(800, 600)):
    """Show ``image`` with a red rectangle and caption for each detection.

    Args:
        image: PIL image to draw on (a resized RGB copy is displayed).
        boxes: Iterable of ``(xmin, ymin, xmax, ymax)`` boxes.
        labels: Iterable of integer labels, looked up in ``classes_la``.
        scores: Iterable of scores, printed with two decimals.
        target_size: ``(width, height)`` the image is resized to before display.

    NOTE(review): boxes are drawn with the coordinates given; they are not
    rescaled to ``target_size`` — confirm callers pass boxes in that frame.
    """
    # Resize, force RGB, and hand matplotlib a plain array.
    resized = image.resize(target_size)
    pixels = np.array(resized.convert("RGB"))

    # Figure geometry is fixed: 3x6 inches at 10 DPI.
    fig, ax = plt.subplots(figsize=(3, 6), dpi=10)
    ax.imshow(pixels)

    for label, box, score in zip(labels, boxes, scores):
        left, top, right, bottom = box
        caption = f"{classes_la[label]}: {score:.2f}"

        # Unfilled outline spanning the box, with the caption at its top-left corner.
        outline = patches.Rectangle(
            (left, top),
            right - left,
            bottom - top,
            linewidth=2,
            edgecolor='r',
            facecolor='none',
        )
        ax.text(left, top, caption, verticalalignment='top', color='red', fontsize=13, weight='bold')
        ax.add_patch(outline)

    plt.axis('off')
    # Remove all padding so the image fills the figure.
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.show()

# Combine the image data and remote prediction DataFrames
# Combine the image data and remote prediction DataFrames
combined_df = pd.concat([image_data_df, remote_prediction], axis=1)

# One dict per retained detection; each becomes a row of the output table.
rows = []

for index, row in combined_df.iterrows():
    output_str = row.get('output', None)

    # The prediction must be a JSON string; anything else is skipped.
    if not isinstance(output_str, str):
        print(f"Invalid output type (not a string) in row {index}, skipping this row.")
        continue

    # Only the parse itself can raise JSONDecodeError, so keep the try minimal.
    try:
        output_data = json.loads(output_str)
    except json.JSONDecodeError:
        print(f"Invalid JSON in row {index}, skipping this row.")
        continue

    if not all(key in output_data for key in ('boxes', 'labels', 'scores')):
        print("Missing keys 'boxes', 'labels', or 'scores' in the output data.")
        continue

    boxes = output_data['boxes']
    labels = output_data['labels']
    scores = output_data['scores']

    if len(scores) == 0:
        print("No scores available to limit to top 5.")
        continue

    # Decode the image data
    image_data = base64.b64decode(row['IMAGE_DATA'])
    image = Image.open(io.BytesIO(image_data)).convert("RGB")

    # Keep the five highest-scoring detections, with box/label/score aligned.
    data = pd.DataFrame({
        'box': boxes,
        'label': labels,
        'score': scores
    })
    top_classes = data.nlargest(5, 'score')

    top_boxes = top_classes['box'].tolist()
    top_labels = top_classes['label'].tolist()
    top_scores = top_classes['score'].tolist()

    # Store each retained prediction as its own row.
    rows.extend(
        {
            'image_data': row['IMAGE_DATA'],
            'output': row['output'],
            'label': label,
            'box': box,
            'score': score,
        }
        for label, box, score in zip(top_labels, top_boxes, top_scores)
    )

    # Display the image with bounding boxes and labels
    display_image_with_boxes(image, top_boxes, top_labels, top_scores)

# Create the final DataFrame with the collected rows (one row per label/box/score)
final_df = pd.DataFrame(rows)

# The statement spans several lines, so it needs a triple-quoted string.
session.sql("""create TABLE if not exists PCB_DATASET.PUBLIC.DETECTION_OUTPUTS (
	image_data VARCHAR(16777216),
	output VARCHAR(16777216),
	label NUMBER(38,0),
	box VARIANT,
	score FLOAT
)""").collect()

# Write the DataFrame to the Snowflake table; an empty frame has no columns
# to write, so skip the write when nothing was detected.
if final_df.empty:
    print("No detections collected; nothing written to DETECTION_OUTPUTS.")
else:
    combined_spdf = session.create_dataframe(final_df)
    combined_spdf.write.save_as_table("DETECTION_OUTPUTS", mode="overwrite")
