In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit

from pyspark.ml.feature import VectorAssembler

In [2]:
import findspark

In [3]:
from functools import reduce

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchsummary import summary

from torchvision import datasets, models, transforms

In [5]:
from copy import deepcopy

import numpy as np

from termcolor import cprint

In [6]:
findspark.find()

'/Users/haozhang/GitHub/openfl_projet/venv-python3-11/lib/python3.11/site-packages/pyspark'

In [7]:
data_path = "/tmp/files/"

tensor_mnist = datasets.MNIST(
    data_path, train=True, download=True, transform=transforms.ToTensor()
)

tensor_images = torch.stack([tensor_image for tensor_image, _ in tensor_mnist], dim=3)

tensor_images.shape

torch.Size([1, 28, 28, 60000])

In [8]:
tensor_mean = tensor_images.view(1, -1).mean(dim=1)
tensor_mean

tensor([0.1307])

In [9]:
tensor_std = tensor_images.view(1, -1).std(dim=1)
tensor_std

tensor([0.3081])

In [10]:
"""
01. torchvision.transforms.Compose(transforms)
    - Composes several transforms together.

02. torchvision.transforms.Normalize(mean, std, inplace=False)
    - Normalize a tensor image with mean and standard deviation.
    - output[channel] = (input[channel] - mean[channel]) / std[channel]
"""

mnist_train = datasets.MNIST(
    "/tmp/files/",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(tensor_mean, tensor_std),
        ]
    ),
)

mnist_test = datasets.MNIST(
    "/tmp/files/",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(tensor_mean, tensor_std),
        ]
    ),
)

In [11]:
spark = SparkSession.builder.master('local[*]') \
    .appName("model_training") \
    .getOrCreate()

25/04/10 21:22:12 WARN Utils: Your hostname, MacBookPro-2022.local resolves to a loopback address: 127.0.0.1; using 192.168.28.167 instead (on interface en0)
25/04/10 21:22:12 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/10 21:22:12 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [12]:
zero_train = (
    spark.read.format("image")
    .load("/tmp/files/MNIST/raw/train-images-idx3-ubyte")
    .withColumn("label", lit(0))
)
one_train = (
    spark.read.format("image")
    .load("/tmp/files/MNIST/raw/train-images-idx3-ubyte")
    .withColumn("label", lit(1))
)
two_train = (
    spark.read.format("image")
    .load("/tmp/files/MNIST/raw/train-images-idx3-ubyte")
    .withColumn("label", lit(2))
)
three_train = (
    spark.read.format("image")
    .load("/tmp/files/MNIST/raw/train-images-idx3-ubyte")
    .withColumn("label", lit(3))
)
four_train = (
    spark.read.format("image")
    .load("/tmp/files/MNIST/raw/train-images-idx3-ubyte")
    .withColumn("label", lit(4))
)
five_train = (
    spark.read.format("image")
    .load("/tmp/files/MNIST/raw/train-images-idx3-ubyte")
    .withColumn("label", lit(5))
)
six_train = (
    spark.read.format("image")
    .load("/tmp/files/MNIST/raw/train-images-idx3-ubyte")
    .withColumn("label", lit(6))
)
seven_train = (
    spark.read.format("image")
    .load("/tmp/files/MNIST/raw/train-images-idx3-ubyte")
    .withColumn("label", lit(7))
)
eight_train = (
    spark.read.format("image")
    .load("/tmp/files/MNIST/raw/train-images-idx3-ubyte")
    .withColumn("label", lit(8))
)
nine_train = (
    spark.read.format("image")
    .load("/tmp/files/MNIST/raw/train-images-idx3-ubyte")
    .withColumn("label", lit(9))
)

In [13]:
zero_train.select("image.origin", "image.width", "image.height").show(truncate=False)

[Stage 0:>                                                          (0 + 1) / 1]

+---------------------------------------------------+-----+------+
|origin                                             |width|height|
+---------------------------------------------------+-----+------+
|file:///tmp/files/MNIST/raw/train-images-idx3-ubyte|8    |3     |
+---------------------------------------------------+-----+------+



                                                                                

In [14]:
zero_train.show()

+--------------------+-----+
|               image|label|
+--------------------+-----+
|{file:///tmp/file...|    0|
+--------------------+-----+



In [15]:
print((zero_train.count(), len(zero_train.columns)))

(1, 2)


In [21]:
df_list = [
    zero_train,
    one_train,
    two_train,
    three_train,
    four_train,
    five_train,
    six_train,
    seven_train,
    eight_train,
    nine_train,
]

# merge data frame
df_train = reduce(lambda first, second: first.union(second), df_list)

df_train.show()

+--------------------+-----+
|               image|label|
+--------------------+-----+
|{file:///tmp/file...|    0|
|{file:///tmp/file...|    1|
|{file:///tmp/file...|    2|
|{file:///tmp/file...|    3|
|{file:///tmp/file...|    4|
|{file:///tmp/file...|    5|
|{file:///tmp/file...|    6|
|{file:///tmp/file...|    7|
|{file:///tmp/file...|    8|
|{file:///tmp/file...|    9|
+--------------------+-----+



In [17]:
print((df_train.count(), len(df_train.columns)))

(10, 2)


In [18]:
df_train = df_train.repartition(200)
print(df_train.rdd.getNumPartitions())
print((df_train.count(), len(df_train.columns)))
df_train.show()

200
(10, 2)
+--------------------+-----+
|               image|label|
+--------------------+-----+
|{file:///tmp/file...|    5|
|{file:///tmp/file...|    0|
|{file:///tmp/file...|    2|
|{file:///tmp/file...|    3|
|{file:///tmp/file...|    4|
|{file:///tmp/file...|    8|
|{file:///tmp/file...|    6|
|{file:///tmp/file...|    1|
|{file:///tmp/file...|    7|
|{file:///tmp/file...|    9|
+--------------------+-----+



In [19]:
print((train.count(), len(train.columns)))

NameError: name 'train' is not defined

25/04/10 21:22:29 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors


In [None]:
def get_model_for_eval():
    """Gets the broadcasted model."""
    model = models.resnet50(weights=None)
    model.load_state_dict(bc_model_state.value)
    model.eval()
    return model

sc = spark.sparkContext
model_state = models.resnet50(weights=None).state_dict()
bc_model_state = sc.broadcast(model_state)

In [None]:
if torch.backends.mps.is_available():
    cprint("MPS is available", "green")
    device = torch.device("mps:0")
elif torch.backends.cuda.is_available():
    cprint("CUDA is available", "green")
    device = torch.device("cuda:0")
elif torch.backends.cudnn.is_built():
    cprint("CUDNN is available", "green")
    device = torch.device("cuda:0")
else:
    cprint("CUDA and MPS are not available", "red")
    cprint("Using CPU", "red")
    device = torch.device("cpu")

In [None]:
predictions_df = train.select(col("image"), predict_batch_udf(col("label")).alias("prediction"))
predictions_df \
    .write \
    .mode("overwrite") \
    .parquet("hdfs://xxx/output/")

spark.stop()

In [None]:
def predict_batch_udf(paths: pd.Series) -> pd.Series:
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    images = ImageDataset(paths, transform=transform)
    loader = torch.utils.data.DataLoader(images, batch_size=500, num_workers=8)
    model = get_model_for_eval()
    model.to(device)
    all_predictions = []
    with torch.no_grad():
        for batch in loader:
            predictions = list(model(batch.to(device)).cpu().numpy())
            for prediction in predictions:
                all_predictions.append(prediction)
    return pd.Series(all_predictions)