## 1. Setup & Imports


In [None]:
import os
from pathlib import Path
import warnings

import rootutils

# Locate the project root by searching from the current working directory for
# the ".project-root" marker file; pythonpath=True asks rootutils to put the
# root on sys.path so the `src.` imports used further down can resolve.
_start_dir = Path.cwd()
rootutils.setup_root(_start_dir, indicator=".project-root", pythonpath=True)

# Prefer the PROJECT_ROOT environment variable when it is set; otherwise fall
# back to the directory the notebook was started from.
ROOT_DIR = Path(os.getenv("PROJECT_ROOT", _start_dir))
print(f"Project root: {ROOT_DIR}")

# Silence every Python warning for the rest of the session to keep the
# notebook output readable.
warnings.filterwarnings(action="ignore")

## 2. Initialize Spark


In [None]:
from src.amazon_reviews_analysis.utils import build_spark

# Create the Spark session through the project's helper so this notebook uses
# whatever configuration that helper applies.
spark = build_spark()

# Confirm the session is usable and report the Spark version it runs on.
print(
    "✓ Spark Session created successfully!",
    f"Spark Version: {spark.version}",
    sep="\n",
)

## 3. Load Model


In [None]:
from pyspark.ml import PipelineModel

# Directory holding the saved Spark ML pipeline, relative to the project root.
MODEL_DIR = ROOT_DIR.joinpath("models", "spark_lr_classifier")

# PipelineModel.load expects a string path, not a pathlib.Path.
model = PipelineModel.load(str(MODEL_DIR))

# List the class name of every fitted stage as a quick sanity check of what
# was actually loaded.
stage_names = [type(stage).__name__ for stage in model.stages]

print(f"✓ Model loaded from {MODEL_DIR}")
print(f"Pipeline stages: {stage_names}")

## 4. Load Test Data


In [None]:
from pyspark.sql.functions import col

# Column names used by this notebook.
TEXT_COL = "text"
TARGET_COL = "label"

# Location of the extracted parquet dataset, relative to the project root.
DATA_DIR = ROOT_DIR / "data/classification/extracted"

# Read the dataset and expose the target as a double-typed "label" column,
# which is the column name the evaluation cells below refer to.
df = (
    spark.read.parquet(str(DATA_DIR))
    .withColumn("label", col(TARGET_COL).cast("double"))
)

# Reproduce the 80/20 split with seed=42 and keep only the 20% part.
# NOTE(review): this yields the training-time test set only if training used
# the same data, weights and seed -- confirm against the training code.
_, test_df = df.randomSplit(weights=[0.8, 0.2], seed=42)

# Report the size and the label distribution of the held-out set.
print(f"Test set: {test_df.count():,} records")
test_df.groupBy("label").count().orderBy("label").show()

## 5. Make Predictions


In [None]:
# Score the held-out set with the loaded pipeline.
#
# The result is cached because the cells below run many independent Spark
# actions on `predictions` (one per evaluator metric, the confusion-matrix
# views, the per-class aggregation and the error-analysis counts). Without
# caching, every one of those actions would re-execute the full plan behind
# `predictions`: parquet read, random split and the pipeline transform.
predictions = model.transform(test_df).cache()

# Preview a few rows; text is truncated to 50 characters to keep the table
# readable.
print("Sample predictions:")
predictions.select(TEXT_COL, "label", "prediction", "probability").show(10, truncate=50)

## 6. Evaluation Metrics


In [None]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Compute each summary metric with its own evaluator over the "label" and
# "prediction" columns; results are keyed by the Spark metric name.
metrics = {
    metric_name: MulticlassClassificationEvaluator(
        labelCol="label",
        predictionCol="prediction",
        metricName=metric_name,
    ).evaluate(predictions)
    for metric_name in ("accuracy", "f1", "weightedPrecision", "weightedRecall")
}

# (line prefix, key in `metrics`) pairs, in display order. The prefixes are
# pre-padded so the values line up in one column.
report_rows = (
    ("Accuracy:           ", "accuracy"),
    ("F1 Score:           ", "f1"),
    ("Weighted Precision: ", "weightedPrecision"),
    ("Weighted Recall:    ", "weightedRecall"),
)

print("RESULTS")
for prefix, key in report_rows:
    print(f"{prefix}{metrics[key]:.4f}")

## 7. Confusion Matrix


In [None]:
# Long format: one row per (actual label, predicted label) pair with its
# count, sorted so the pairs read in order.
print("Confusion Matrix (rows=actual, cols=predicted):")
pair_counts = predictions.groupBy("label", "prediction").count()
confusion_matrix = pair_counts.orderBy("label", "prediction")
confusion_matrix.show(25)

# Wide format: one row per actual label and one column per predicted value.
# Pairs that never occur come out of the pivot as null, so they are replaced
# with 0.
pivot_cm = (
    predictions.groupBy("label")
    .pivot("prediction")
    .count()
    .orderBy("label")
    .fillna(0)
)
print("\nConfusion Matrix (pivoted):")
pivot_cm.show()

## 8. Per-Class Metrics


In [None]:
from pyspark.sql.functions import count, sum as spark_sum, when

# Display names for the numeric class labels; a label missing from this map
# is printed as its raw value.
LABEL_NAMES = {0.0: "negative", 1.0: "neutral", 2.0: "positive"}

# 1 when the prediction matches the label, 0 otherwise; summing this per
# label gives the number of correctly classified rows.
is_hit = when(col("label") == col("prediction"), 1).otherwise(0)

per_class = (
    predictions.groupBy("label")
    .agg(count("*").alias("total"), spark_sum(is_hit).alias("correct"))
    .withColumn("accuracy", col("correct") / col("total"))
)

print("Per-Class Accuracy:")
for stats in per_class.orderBy("label").collect():
    label_value = stats["label"]
    label_name = LABEL_NAMES.get(label_value, str(label_value))
    print(f"{label_name:10s}: {stats['accuracy']:.4f} ({stats['correct']}/{stats['total']})")

## 9. Error Analysis


In [None]:
# Rows where the predicted class differs from the actual label.
misclassified = predictions.filter(col("label") != col("prediction"))

# Each .count() is a Spark action that executes the plan, so run each count
# exactly once and reuse the numbers below.
n_predictions = predictions.count()
n_misclassified = misclassified.count()

print(f"Total misclassified: {n_misclassified:,}")
if n_predictions:
    print(f"Error rate: {n_misclassified / n_predictions * 100:.2f}%")
else:
    # An empty prediction set has no meaningful error rate; report that
    # instead of failing with ZeroDivisionError.
    print("Error rate: n/a (no predictions)")
print("\nSample misclassified reviews:")
misclassified.select(TEXT_COL, "label", "prediction").show(10, truncate=80)

In [None]:
# Stop the Spark session now that the evaluation is finished; `spark` cannot
# be used for further queries in this notebook after this call.
spark.stop()