In [1]:
import os

In [2]:
import glob

In [3]:
import random

In [7]:
pip install pyspark delta-spark


Collecting pysparkNote: you may need to restart the kernel to use updated packages.

  Downloading pyspark-4.0.0.tar.gz (434.1 MB)
     ---------------------------------------- 0.0/434.1 MB ? eta -:--:--
     ---------------------------------------- 0.0/434.1 MB ? eta -:--:--
     ---------------------------------------- 0.5/434.1 MB 3.2 MB/s eta 0:02:16
     ---------------------------------------- 1.6/434.1 MB 4.0 MB/s eta 0:01:48
     ---------------------------------------- 2.1/434.1 MB 3.7 MB/s eta 0:01:58
     ---------------------------------------- 3.4/434.1 MB 4.2 MB/s eta 0:01:43
     ---------------------------------------- 4.7/434.1 MB 4.6 MB/s eta 0:01:33
      --------------------------------------- 6.0/434.1 MB 4.9 MB/s eta 0:01:28
      --------------------------------------- 6.8/434.1 MB 5.0 MB/s eta 0:01:26
      --------------------------------------- 7.1/434.1 MB 4.8 MB/s eta 0:01:30
      --------------------------------------- 7.1/434.1 MB 4.8 MB/s eta 0:01:30
   


[notice] A new release of pip is available: 25.0.1 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [8]:
from pyspark.sql import SparkSession, Row

In [9]:
from pyspark.sql.types import StructType, BinaryType, IntegerType, StringType

In [None]:
# Create (or reuse) a Spark session with Delta Lake enabled.
#
# The pip-installed `delta-spark` wheel contains only the Python bindings; the JVM
# classes named by the two Delta settings below live in a Maven artifact that Spark
# must fetch at startup through `spark.jars.packages`. Without it the extension and
# catalog classes cannot be loaded and the Delta write later in the notebook fails.
from importlib.metadata import version

# Tie the artifact version to the installed wheel so the Python and JVM sides agree.
# Scala 2.13 is the build used by Spark 4.x (the captured pip output shows 4.0.0).
delta_maven_package = f"io.delta:delta-spark_2.13:{version('delta_spark')}"

# NOTE: `spark.jars.packages` only takes effect when the JVM is launched. If a
# session already exists in this kernel, restart the kernel before running this cell.
spark = (
    SparkSession.builder
    .appName("MNIST Delta Table")
    .config("spark.jars.packages", delta_maven_package)
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .getOrCreate()
)


In [None]:
# Root folder of the MNIST training PNGs. The sampling cell that follows reads the
# per-digit subfolders "0" through "9" beneath it. The path is relative, so it is
# resolved against the kernel's current working directory.
base_path = "flat files/mnist_png/training"


In [None]:
# Build one Row per sampled training image: SAMPLES_PER_DIGIT PNG files from each
# digit folder (0-9) under `base_path`, stored as raw, undecoded bytes.
SAMPLES_PER_DIGIT = 5
SAMPLE_SEED = 42

# A dedicated, seeded generator makes the sample repeatable across
# Restart & Run All without touching the global `random` state.
rng = random.Random(SAMPLE_SEED)

all_rows = []

for digit in range(10):
    digit_folder = os.path.join(base_path, str(digit))

    # glob returns files in arbitrary, OS-dependent order; sorting first means the
    # seeded sample selects the same files on every machine.
    image_paths = sorted(glob.glob(os.path.join(digit_folder, "*.png")))

    # random.sample would raise an opaque "Sample larger than population" error for
    # a missing or sparse folder; fail with a message that names the folder instead.
    if len(image_paths) < SAMPLES_PER_DIGIT:
        raise ValueError(
            f"Expected at least {SAMPLES_PER_DIGIT} PNG files in "
            f"{digit_folder!r}, found {len(image_paths)}"
        )

    for path in rng.sample(image_paths, SAMPLES_PER_DIGIT):
        with open(path, "rb") as f:
            img_bytes = f.read()

        # One record per image: digit label, bare file name, and the PNG bytes.
        all_rows.append(
            Row(label=digit, filename=os.path.basename(path), image=img_bytes)
        )


In [None]:
# Explicit column types for the sampled rows: integer digit label, source file
# name, and the raw image bytes.
schema = (
    StructType()
    .add("label", IntegerType())
    .add("filename", StringType())
    .add("image", BinaryType())
)

# Turn the collected rows into a DataFrame and preview the first five records.
df = spark.createDataFrame(data=all_rows, schema=schema)
df.show(n=5)


In [None]:
# Persist the sampled images as a managed Delta table, replacing any earlier
# contents so the cell can be re-run.
(
    df.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("mnist_images_table")
)


In [None]:
# Read the saved table back through SQL and display up to ten rows.
(
    spark
    .sql("SELECT * FROM mnist_images_table LIMIT 10")
    .show()
)
