# Data Processing Pipeline

This notebook demonstrates the data processing pipeline for computer vision tasks on Databricks.

## Setup

First, let's install required dependencies and import necessary modules.

In [None]:
# Install dependencies
# Installs pycocotools and albumentations into the notebook's Python environment.
# NOTE(review): no versions are pinned, so a fresh run may resolve different
# releases — TODO pin once the tested versions are known.
%pip install pycocotools albumentations

In [None]:
# Import required modules
import re

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import torch
from pyspark.sql import SparkSession

from data.processing.coco_processor import COCOProcessor
from data.processing.data_loader import COCODataset, get_transforms
# NOTE(review): create_dataloader is assumed to live next to COCODataset — confirm the module.
from data.processing.data_loader import create_dataloader
from data.unity_catalog.catalog_manager import CatalogManager

## Initialize Spark Session

Create a Spark session for distributed data processing, then set up the per-user Unity Catalog catalog and schema that will hold the processed dataset.

In [None]:
# Build (or reuse) the Spark session shared by every cell below.
# NOTE(review): when a session already exists (e.g. on a managed cluster),
# getOrCreate() hands back that session — confirm the memory settings below
# actually take effect in that case.
spark = (
    SparkSession.builder
    .appName("CV Data Processing")
    .config("spark.executor.memory", "4g")
    .config("spark.driver.memory", "4g")
    .getOrCreate()
)

In [None]:
# Get current user's login for catalog/schema naming
current_user = spark.sql("SELECT current_user()").collect()[0][0]

# Keep only the part before '@', then replace every character outside
# [A-Za-z0-9_] with '_'. Logins such as "jane.doe@example.com" would otherwise
# yield "jane.doe_cv_catalog": the '.' is read as a namespace separator when the
# catalog name is interpolated, unquoted, into the SQL that later reads the
# saved table, and characters like '-' or '+' are not valid in an unquoted
# identifier either. Prefixes that are already alphanumeric are unchanged.
user_prefix = re.sub(r"[^A-Za-z0-9_]", "_", current_user.split('@')[0])

# Define catalog and schema names
catalog_name = f"{user_prefix}_cv_catalog"
schema_name = "coco_dataset"

# Initialize catalog manager
catalog_manager = CatalogManager(spark)

# Create catalog if it doesn't exist
catalog_manager.create_catalog_if_not_exists(
    catalog_name=catalog_name,
    comment="Catalog for computer vision datasets"
)

# Create schema if it doesn't exist
catalog_manager.create_schema_if_not_exists(
    catalog_name=catalog_name,
    schema_name=schema_name,
    comment="Schema for COCO format datasets"
)

## Initialize COCO Processor

Create a COCO processor instance to handle MS COCO format datasets, and load the annotation file into it.

In [None]:
# Location of the COCO annotation JSON
# (looks like a placeholder path — point it at the real file before running).
annotation_file = "/dbfs/path/to/annotations.json"

# Build the processor on top of the shared Spark session and catalog manager,
# then have it read the annotation file.
processor = COCOProcessor(spark, catalog_manager=catalog_manager)
processor.load_coco_annotations(annotation_file)

## Process Images

Process images and create a DataFrame with image metadata.

In [None]:
# Directory holding the image files
# (looks like a placeholder path — adjust before running).
image_dir = "/dbfs/path/to/images"

# Run the processor over the image directory; the result is kept as `df`
# because the cells below read it.
df = processor.process_images(image_dir)

# Preview the first five rows
preview_df = df.limit(5)
display(preview_df)

## Validate Data

Perform data validation to ensure quality and consistency.

In [None]:
# Run the processor's validation over the processed DataFrame
validation_results = processor.validate_data(df)

# Print one section per category: a blank line, the category name, then one
# "- issue" bullet per reported issue.
print("Validation results:")
for category, issues in validation_results.items():
    report_lines = [f"\n{category}:"]
    report_lines.extend(f"- {issue}" for issue in issues)
    print("\n".join(report_lines))

## Create DataLoader

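Collect the image paths and annotations to the driver, wrap them in a dataset with the training transforms, and create a dataloader for training.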

In [None]:
# Pull both columns to the driver in ONE collect so every image path stays
# paired with its own annotations. Two independent collects run the Spark job
# twice and only line up if both happen to return rows in the same order.
# Going through DataFrame.collect() also avoids the RDD API.
# NOTE(review): the RDD API is reportedly unavailable on some cluster access
# modes — confirm for the target cluster.
# NOTE(review): this materialises every row on the driver — confirm the full
# dataset fits in driver memory.
collected_rows = df.select("image_path", "annotations").collect()
image_paths = [row["image_path"] for row in collected_rows]
annotation_records = [row["annotations"] for row in collected_rows]

# Create dataset
dataset = COCODataset(
    image_paths=image_paths,
    annotations=annotation_records,
    transform=get_transforms(mode='train')
)

# Create dataloader
dataloader = create_dataloader(
    dataset,
    batch_size=32,
    num_workers=4,
    shuffle=True
)

## Save to Delta Lake

Save processed data to Delta Lake format for efficient storage and querying.

In [None]:
# Table that receives the processed data inside the Unity Catalog schema
delta_table = "coco_dataset"

# Write the DataFrame through the processor
processor.save_to_delta(
    df=df,
    catalog_name=catalog_name,
    schema_name=schema_name,
    table_name=delta_table,
    comment="Processed COCO dataset with annotations",
)

# Read the table back by its three-part name and report how many rows it holds
saved_df = spark.sql(f"SELECT * FROM {catalog_name}.{schema_name}.{delta_table}")
record_count = saved_df.count()
print(f"Total records: {record_count}")

## Visualize Sample Data

Visualize sample images and annotations to verify data processing.

In [None]:
def visualize_sample(image, annotations, denormalize=True):
    """Show one image with its bounding boxes outlined in red.

    Args:
        image: Either a torch tensor, which is permuted from axis order
            (0, 1, 2) to (1, 2, 0) and converted to a NumPy array, or any
            object ``imshow`` accepts as-is.
        annotations: Mapping with a ``'boxes'`` entry (tensor or iterable).
            Each box is unpacked as ``(x, y, w, h)`` and drawn as a rectangle
            anchored at ``(x, y)`` with width ``w`` and height ``h``.
        denormalize: When True (the default, matching the previous behaviour)
            a tensor image is mapped back with ``std * image + mean`` using the
            constants below and clipped to ``[0, 1]``. Pass False for tensors
            that were never normalised. Ignored for non-tensor images.
            NOTE(review): the constants look like the usual ImageNet
            statistics — confirm they match the transform used upstream.
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    if torch.is_tensor(image):
        # detach() + cpu() so tensors that track gradients or live on a GPU
        # can still be converted; a bare .numpy() raises for both.
        image = image.detach().cpu().permute(1, 2, 0).numpy()
        if denormalize:
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            image = np.clip(std * image + mean, 0, 1)

    ax.imshow(image)

    # Get boxes from annotations
    boxes = annotations['boxes']
    if torch.is_tensor(boxes):
        boxes = boxes.detach().cpu().numpy()

    for x, y, w, h in boxes:
        ax.add_patch(
            plt.Rectangle(
                (x, y), w, h,
                fill=False, edgecolor='red', linewidth=2
            )
        )

    ax.axis('off')
    plt.show()

# Visualize a few samples: at most three, fewer when the dataset is smaller,
# so a tiny dataset does not raise IndexError.
# NOTE(review): relies on the dataset implementing __len__ — confirm.
for i in range(min(3, len(dataset))):
    image, annotations = dataset[i]
    visualize_sample(image, annotations)