In [0]:
# Notebook parameters: the Unity Catalog catalog, schema and volume names are
# read from the notebook widgets of the same names, in this order.
CATALOG, SCHEMA, VOLUME = (
    dbutils.widgets.get(widget_name)
    for widget_name in ("CATALOG", "SCHEMA", "VOLUME")
)

In [0]:
# Root folder of the dataset inside the configured volume.
root_path = f'/Volumes/{CATALOG}/{SCHEMA}/{VOLUME}/x-ray-kaggle/chest-xray-pneumonia/chest_xray/'

# One sub-folder per dataset split, each with a trailing slash.
train_path = f'{root_path}train/'
test_path = f'{root_path}test/'
val_path = f'{root_path}val/'

In [0]:
from pyspark.sql import Row

def list_jpeg_files_recursive(path, extensions=('.jpeg', '.jpg')):
    """Collect every file under ``path`` whose name ends with one of ``extensions``.

    Directories are walked depth-first in the order ``dbutils.fs.ls`` returns
    them, so the result order follows the listing order.

    Args:
        path: Directory to list with ``dbutils.fs.ls``.
        extensions: File-name suffix, or iterable of suffixes, to keep.
            Matching is case-insensitive. Defaults to JPEG suffixes.

    Returns:
        A list of the entries returned by ``dbutils.fs.ls`` for the matching
        files (sub-directories themselves are never included).
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)
    # str.endswith accepts a tuple, so one call covers every suffix.
    suffixes = tuple(ext.lower() for ext in extensions)

    files = []
    for f in dbutils.fs.ls(path):
        if f.isDir():
            files.extend(list_jpeg_files_recursive(f.path, suffixes))
        elif f.name.lower().endswith(suffixes):
            files.append(f)
    return files

# Walk the whole dataset root and collect one metadata record per JPEG file.
files = list_jpeg_files_recursive(root_path)

# spark.createDataFrame cannot infer a schema from an empty list and fails with
# an unhelpful error, so stop early with a message that names the real problem.
if not files:
    raise FileNotFoundError(f"No .jpg/.jpeg files found under {root_path}")

# Copy the listing attributes into Rows so Spark can infer the column types.
file_rows = [Row(path=f.path, name=f.name, size=f.size, modificationTime=f.modificationTime) for f in files]
raw_df = spark.createDataFrame(file_rows)
display(raw_df)

In [0]:
from pyspark.sql.functions import regexp_extract, when, regexp_replace

# Strip a leading "dbfs:" scheme from the listed path, if present.
path_col = regexp_replace("path", r"^dbfs:", "")

# The split name is the folder directly below chest_xray/.
stage_expr = regexp_extract(path_col, r"/chest_xray/(train|test|val)/", 1)

# Rename "val" to "validate"; any path outside the three split folders
# matches no branch and ends up with a NULL stage.
stage_col = (
    when(stage_expr == "train", "train")
    .when(stage_expr == "test", "test")
    .when(stage_expr == "val", "validate")
    .otherwise(None)
)

# The class label is the NORMAL / PNEUMONIA folder in the path.
# Note: regexp_extract returns an empty string (not NULL) when nothing matches.
label_col = regexp_extract(path_col, r"/(NORMAL|PNEUMONIA)/", 1)

# Keep only the columns needed downstream: cleaned path, file name, label, stage.
cleaned_df = raw_df.select(
    path_col.alias("path"),
    "name",
    label_col.alias("label"),
    stage_col.alias("stage"),
)

display(cleaned_df)
from databricks.feature_engineering import FeatureEngineeringClient

fe = FeatureEngineeringClient()
table_name = f"{CATALOG}.{SCHEMA}.chest_xray_cleaned"

# fe.create_table raises once the table already exists, which would break a
# second "Run all" of this notebook. Create the table on the first run only;
# on later runs upsert the current metadata by primary key instead.
if spark.catalog.tableExists(table_name):
    fe.write_table(name=table_name, df=cleaned_df, mode="merge")
    feature_table = fe.get_table(name=table_name)
else:
    feature_table = fe.create_table(
        name=table_name,
        primary_keys='name',
        df=cleaned_df,
        description='Cleaned chest x-ray image metadata with name as primary key'
    )

In [0]:
import matplotlib.pyplot as plt
import seaborn as sns

# Count images per (stage, label) pair in Spark, then pull the small result into pandas.
counts_df = cleaned_df.groupBy("stage", "label").count().toPandas()

# Labels become rows and stages become columns; missing combinations show as 0.
heatmap_data = counts_df.pivot(index="label", columns="stage", values="count").fillna(0)

# Draw on an explicit figure/axes pair instead of the implicit pyplot state.
fig, ax = plt.subplots(figsize=(6, 4))
sns.heatmap(heatmap_data, annot=True, fmt=".0f", cmap="YlOrRd", ax=ax)
ax.set_xlabel("Stage")
ax.set_ylabel("Label")
ax.set_title("Number of Images by Stage and Label")
display(fig)
plt.close(fig)

In [0]:
from PIL import Image
import matplotlib.pyplot as plt

# Every sample image is drawn on a figure of the same size:
# the width is fixed and the height follows from a 4:3 aspect ratio.
FIXED_WIDTH = 6  # inches
ASPECT_RATIO = 4 / 3
FIXED_HEIGHT = FIXED_WIDTH / ASPECT_RATIO  # inches

def display_image(path, dpi=300):
    """Render one image file on a fixed-size figure labelled with its pixel size.

    The image is stretched to fill the FIXED_WIDTH x FIXED_HEIGHT figure
    (``aspect="auto"``), so its own aspect ratio is not preserved; the axis
    labels report the true width and height in pixels.

    Args:
        path: Location of an image file that PIL can open.
        dpi: Resolution of the rendered figure, in dots per inch.
    """
    # The context manager releases the underlying file handle once the
    # figure has been rendered.
    with Image.open(path) as img:
        width, height = img.size
        # Honor the dpi argument (it was previously accepted but ignored).
        fig = plt.figure(figsize=(FIXED_WIDTH, FIXED_HEIGHT), dpi=dpi)
        try:
            # NOTE(review): single-channel images are drawn with matplotlib's
            # default colormap; pass cmap="gray" here if grayscale is wanted.
            plt.imshow(img, interpolation="nearest", aspect="auto")
            plt.axis('on')
            plt.xlabel(f'Width: {width} px')
            plt.ylabel(f'Height: {height} px')
            display(fig)
        finally:
            # Always free the figure, even if rendering fails part-way.
            plt.close(fig)

# Show one stored example per class, read back from the cleaned metadata table.
for label in ('NORMAL', 'PNEUMONIA'):
    sample_query = f"""
        SELECT path FROM {CATALOG}.{SCHEMA}.chest_xray_cleaned
        WHERE label = '{label}'
        LIMIT 1
    """
    # LIMIT 1 yields at most one row; the inner loop simply does nothing
    # when the table holds no image for this label.
    sample_path = spark.sql(sample_query).toPandas()['path']
    for path in sample_path:
        display_image(path)