In [1]:
# findspark.init() has to run before pyspark is imported (see the note on the
# import below), so the order of these three lines matters.
import findspark
findspark.init()
import pyspark #only run after findspark.init()

In [2]:
from pyspark.ml.clustering import KMeans
from pyspark.ml.image import ImageSchema
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.functions import array_to_vector
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql import Row
from pyspark.sql.functions import udf, input_file_name, col
from pyspark.sql.types import ArrayType, FloatType
import numpy as np
from PIL import Image
import io
import os
import shutil

print("Imports done!")

Imports done!


In [3]:
from pyspark.sql import SparkSession

# Session settings as (key, value) pairs; they are applied to the builder in
# exactly the order listed here.
spark_settings = [
    # Fault-handler switches for the Python worker and Python UDF execution.
    ("spark.python.worker.faulthandler.enabled", "true"),
    ("spark.sql.execution.pyspark.udf.faulthandler.enabled", "true"),
    # Driver memory, debug string width and result-size limit.
    ("spark.driver.memory", "8g"),
    ("spark.sql.debug.maxToStringFields", "10000"),
    ("spark.driver.maxResultSize", "2g"),
    # Kryo serializer and its maximum buffer size.
    ("spark.serializer", "org.apache.spark.serializer.KryoSerializer"),
    ("spark.kryoserializer.buffer.max", "512m"),
]

session_builder = SparkSession.builder.appName("chest-xray-bigdata")
for setting_key, setting_value in spark_settings:
    session_builder = session_builder.config(setting_key, setting_value)

# Reuse a running session if one exists, otherwise start a new one.
spark = session_builder.getOrCreate()

print("Spark Session Start!")

Spark Session Start!


In [4]:
# Glob pattern of the PNG files used to fit the clustering model.
# Smaller subset for quick test runs (1000 images):
#image_dir = "../data/images_001/training/*.png"
# Full set for the real run (4999 images):
image_dir = "../data/images_001/images/*.png"

In [5]:
# Load the PNG files matched by `image_dir` with Spark's "binaryFile" reader.
# Each file's raw bytes end up in the `content` column (see the schema below).
xray_reader = (
    spark.read.format("binaryFile")
    .option("pathGlobFilter", "*.png")
    .option("recursiveFileLookup", "true")
)
xray_df = xray_reader.load(image_dir)

# Show the column names and types of the loaded DataFrame.
xray_df.printSchema()

root
 |-- path: string (nullable = true)
 |-- modificationTime: timestamp (nullable = true)
 |-- length: long (nullable = true)
 |-- content: binary (nullable = true)



In [6]:
# Add a "filename" column holding the source file of each row.
# Presumably this is what lets an image be matched to its cluster later on;
# the export cell writes this column out per cluster.
# withColumn replaces an existing column of the same name, so re-running this
# cell does not add a duplicate.
xray_df = xray_df.withColumn("filename", input_file_name())
# Preview the first two rows.
xray_df.show(2)

+--------------------+-------------------+------+--------------------+--------------------+
|                path|   modificationTime|length|             content|            filename|
+--------------------+-------------------+------+--------------------+--------------------+
|file:/C:/Users/Bu...|2017-08-01 13:35:36|726419|[89 50 4E 47 0D 0...|file:///C:/Users/...|
|file:/C:/Users/Bu...|2017-08-01 13:40:34|699102|[89 50 4E 47 0D 0...|file:///C:/Users/...|
+--------------------+-------------------+------+--------------------+--------------------+
only showing top 2 rows



In [7]:
# Define UDF using Pillow (PIL)
@udf(returnType=ArrayType(FloatType()))
def extract_features_pil(file_content):
    """Turn the raw bytes of one image file into a flat list of pixel values.

    The image is decoded with Pillow, converted to 8-bit grayscale, resized to
    256x256 and scaled from 0..255 to 0.0..1.0.

    Args:
        file_content: Encoded image bytes (the `content` column of a
            "binaryFile" DataFrame), or None.

    Returns:
        A list of 256 * 256 = 65536 floats in row-major order, or None when
        the input is None or the bytes cannot be decoded/processed. Callers
        are expected to filter out the None rows.
    """
    if file_content is None:
        return None

    try:
        # Open inside a context manager so Pillow releases the decoder and the
        # underlying buffer as soon as this row is done, instead of leaving
        # that to garbage collection in a long-lived Python worker.
        with Image.open(io.BytesIO(file_content)) as image:
            # Single-channel grayscale ('L'), then a fixed 256x256 size so
            # every image yields a vector of the same length.
            resized = image.convert('L').resize((256, 256))

            # uint8 pixels -> float32 in [0, 1].
            pixels = np.asarray(resized, dtype=np.float32) / 255.0

        return pixels.ravel().tolist()

    except Exception:
        # Best effort: one corrupt or unsupported file must not fail the whole
        # Spark job, so the row is marked as null and dropped downstream.
        return None

# Run the feature extractor over the raw bytes of every file.
xray_feature_df = xray_df.withColumn(
    "features",
    extract_features_pil(col("content")),
)

# Keep only rows with a feature array; the extractor returns null for files
# it could not process.
has_features = col("features").isNotNull()
clean_xray_df = xray_feature_df.where(has_features)

# Preview the first five rows with long values truncated.
clean_xray_df.show(5, truncate=True)

+--------------------+-------------------+------+--------------------+--------------------+--------------------+
|                path|   modificationTime|length|             content|            filename|            features|
+--------------------+-------------------+------+--------------------+--------------------+--------------------+
|file:/C:/Users/Bu...|2017-08-01 13:35:36|726419|[89 50 4E 47 0D 0...|file:///C:/Users/...|[0.5882353, 0.623...|
|file:/C:/Users/Bu...|2017-08-01 13:40:34|699102|[89 50 4E 47 0D 0...|file:///C:/Users/...|[0.08627451, 0.08...|
|file:/C:/Users/Bu...|2017-07-19 09:17:10|693527|[89 50 4E 47 0D 0...|file:///C:/Users/...|[0.38039216, 0.26...|
|file:/C:/Users/Bu...|2017-08-01 13:38:04|690140|[89 50 4E 47 0D 0...|file:///C:/Users/...|[0.28235295, 0.21...|
|file:/C:/Users/Bu...|2017-07-19 09:27:12|687283|[89 50 4E 47 0D 0...|file:///C:/Users/...|[0.09411765, 0.08...|
+--------------------+-------------------+------+--------------------+--------------------+-----

In [8]:
# Turn the array<float> feature column into an ML vector column and keep the
# filename next to it.
feature_vector = array_to_vector(col("features")).alias("features")
final_df = clean_xray_df.select(col("filename"), feature_vector)

final_df.show(5, truncate=True)

+--------------------+--------------------+
|            filename|            features|
+--------------------+--------------------+
|file:///C:/Users/...|[0.58823531866073...|
|file:///C:/Users/...|[0.08627451211214...|
|file:///C:/Users/...|[0.38039216399192...|
|file:///C:/Users/...|[0.28235295414924...|
|file:///C:/Users/...|[0.09411764889955...|
+--------------------+--------------------+
only showing top 5 rows



In [9]:
# Configure k-means (15 clusters, fixed seed) on the "features" column and
# fit it on the training vectors.
kmeans = KMeans(
    featuresCol="features",
    k=15,
    seed=42,
)
model = kmeans.fit(final_df)

In [10]:
# Assign each training image to a cluster with the fitted model. The result
# carries a "prediction" column, which the following cells group and filter on.
xray_clustered_df = model.transform(final_df)

In [11]:
# Count the images per cluster id and display the table.
cluster_sizes = xray_clustered_df.groupBy("prediction").count()
cluster_sizes.show()

+----------+-----+
|prediction|count|
+----------+-----+
|         1|  276|
|        13|  374|
|         5|  449|
|         8|  375|
|         7|  324|
|        14|  308|
|        12|  445|
|         3|  300|
|        11|  407|
|         2|  358|
|         9|  325|
|         4|  299|
|         0|  266|
|        10|  128|
|         6|  365|
+----------+-----+



In [12]:
# Root directory for the per-cluster filename exports (one sub-directory per
# cluster is written below it).
output_dir = "../docs/cluster_outputs"

In [13]:
# Delete the cluster outputs of a previous run, if any, so that no stale
# cluster directories remain next to the new export.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)

In [14]:
# Export up to 20 filenames for each cluster, one text output per cluster
# under `output_dir`.
# The number of clusters is taken from the fitted model rather than repeated
# as a literal, so this cell stays correct when `k` is changed in the KMeans
# cell.
n_clusters = len(model.clusterCenters())
for c in range(n_clusters):
    output_path = f"{output_dir}/cluster_{c}"
    (
        xray_clustered_df
        .filter(col("prediction") == c)
        .select("filename")
        .limit(20)
        .write
        .mode("overwrite")
        .text(output_path)
    )

In [15]:
# Glob pattern of the new PNG files to be assigned to the existing clusters.
# Smaller subset for quick test runs (100 images):
#new_image_dir = "../data/images_001/test/*.png"
# Full set for the real run (10000 images):
new_image_dir = "../data/images_002/images/*.png"

In [16]:
# Load the new PNG files with the same "binaryFile" reader options that were
# used for the training images.
new_reader = (
    spark.read.format("binaryFile")
    .option("pathGlobFilter", "*.png")
    .option("recursiveFileLookup", "true")
)
new_df = new_reader.load(new_image_dir)

In [17]:
# Feature extraction like during training, including the null filter:
# the extractor returns null for files it cannot process, and those rows are
# dropped here, as they were for the training set, instead of being passed on
# to the model as null vectors.
new_features = (
    new_df
    .withColumn("features", extract_features_pil(col("content")))
    .filter(col("features").isNotNull())
    .select(array_to_vector(col("features")).alias("features"))
)

In [18]:
# Assign each new image to one of the clusters learned from the training set;
# the model adds a "prediction" column next to "features".
prediction = model.transform(new_features)
# Preview the assignments (show() prints the first 20 rows by default).
prediction.show()

+--------------------+----------+
|            features|prediction|
+--------------------+----------+
|[0.21960784494876...|         5|
|[0.16470588743686...|         5|
|[0.40784314274787...|         8|
|[0.16470588743686...|         5|
|[0.10196078568696...|         5|
|[0.09019608050584...|         8|
|[0.20000000298023...|         5|
|[0.87450981140136...|         8|
|[0.63137257099151...|         5|
|[0.29803922772407...|        12|
|[0.26666668057441...|         5|
|[0.37647059559822...|         1|
|[0.87843137979507...|         7|
|[0.12941177189350...|         5|
|[0.83137255907058...|        14|
|[0.0,0.0,0.0,0.0,...|        14|
|[0.10588235408067...|         5|
|[0.85098040103912...|        14|
|[0.56862747669219...|         5|
|[0.92156863212585...|         5|
+--------------------+----------+
only showing top 20 rows



In [19]:
# Shut down the Spark session now that all results have been produced.
# DataFrames created above cannot be evaluated after this point.
spark.stop()

print("Spark Session Stop!")

Spark Session Stop!
