In [1]:
from pyspark.ml.feature import IndexToString, StringIndexerModel
from pyspark.ml.pipeline import PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType

# Start (or reuse) the Spark session used by this notebook.
spark = (
    SparkSession.builder
    .appName("Load Local Model and Predict")
    .getOrCreate()
)

# Path (relative to the working directory) where the fitted pipeline was saved.
model_save_path = "text_classification_model2"

# Restore the fitted PipelineModel from disk.
loaded_model = PipelineModel.load(model_save_path)

# Sample headlines to classify; each row is a one-element tuple so that
# createDataFrame builds a single-column frame named "Content".
new_data = [
    ("Indonesia-Filipina Sepakat Pulangkan Terpidana Mati Kasus Narkoba Mary Jane",),
    ("Gara-gara Judi, Jenderal Bobby Razia Semua Prajurit TNI Kodam Hasanuddin",),
    ("Bank Mandiri Dukung Program Makan Bergizi Gratis, Dorong Pertumbuhan Ekonomi Kerakyatan yang Berkelanjutan",),
]
columns = ["Content"]
new_df = spark.createDataFrame(new_data, columns)

# Run the whole pipeline. The input has no "Label" column, so Spark skips the
# pipeline's StringIndexerModel here and logs a warning; that is expected at
# inference time.
predictions = loaded_model.transform(new_df)

# Recover the label vocabulary learned during training. The indexer is located
# by type rather than by position (the original used stages[-2]) so that this
# keeps working if stages are added to or removed from the pipeline. When
# several indexers exist, the last one is used, which matches the original
# behaviour whenever the indexer really is the second-last stage.
label_indexers = [
    stage for stage in loaded_model.stages if isinstance(stage, StringIndexerModel)
]
if not label_indexers:
    raise ValueError(
        "Loaded pipeline contains no StringIndexerModel stage; "
        "cannot map predictions back to label names"
    )
labels = label_indexers[-1].labels

# Translate numeric predictions back to label strings with Spark ML's built-in
# inverse of StringIndexer instead of a row-at-a-time Python UDF.
label_decoder = IndexToString(
    inputCol="prediction",
    outputCol="predicted_label",
    labels=labels,
)
predictions = label_decoder.transform(predictions)

# Show each headline next to its predicted label, without truncating the text.
predictions.select("Content", "predicted_label").show(truncate=False)

# Stop the SparkSession.
# NOTE: once the session is stopped, `new_df` and `predictions` can no longer
# be evaluated; re-run this cell to obtain a fresh session before using them.
spark.stop()

24/12/07 08:55:08 WARN Utils: Your hostname, codespaces-be24d3 resolves to a loopback address: 127.0.0.1; using 10.0.7.5 instead (on interface eth0)
24/12/07 08:55:08 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/07 08:55:09 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/12/07 08:55:21 WARN StringIndexerModel: Input column Label does not exist during transformation. Skip StringIndexerModel for this column.
                                                                                

+----------------------------------------------------------------------------------------------------------+---------------+
|Content                                                                                                   |predicted_label|
+----------------------------------------------------------------------------------------------------------+---------------+
|Indonesia-Filipina Sepakat Pulangkan Terpidana Mati Kasus Narkoba Mary Jane                               |bisnis         |
|Gara-gara Judi, Jenderal Bobby Razia Semua Prajurit TNI Kodam Hasanuddin                                  |berita         |
|Bank Mandiri Dukung Program Makan Bergizi Gratis, Dorong Pertumbuhan Ekonomi Kerakyatan yang Berkelanjutan|bisnis         |
+----------------------------------------------------------------------------------------------------------+---------------+



In [None]:
# Shut down the Spark session held in `spark` once the notebook is finished.
spark.stop()


In [13]:
# Inspect the column wiring of pipeline stage 4. Judging by the recorded output
# ("Label" in, "label" out) this stage is the label StringIndexer, not the
# Tokenizer: it reads the raw "Label" column and writes the indexed "label".
label_stage = loaded_model.stages[4]
print(label_stage.getInputCol(), label_stage.getOutputCol(), sep="\n")


Label
label
