## 1. Setup & Imports


In [None]:
import os
from pathlib import Path
import warnings
import zipfile

import rootutils

# Walk up from the current working directory to the folder containing the
# ".project-root" marker; pythonpath=True asks rootutils to make that folder
# importable so `src.*` modules resolve.
rootutils.setup_root(Path.cwd(), indicator=".project-root", pythonpath=True)

# Resolve the project root from the PROJECT_ROOT environment variable,
# falling back to the current working directory when it is not set.
_root_from_env = os.environ.get("PROJECT_ROOT", Path.cwd())
ROOT_DIR = Path(_root_from_env)
print(f"Project root: {ROOT_DIR}")

# Silence every Python warning for the remainder of the session.
warnings.filterwarnings("ignore")

## 2. Initialize Spark


In [None]:
from pyspark.sql.window import Window
from src.amazon_reviews_analysis.utils import build_spark
import os
import subprocess

# CRITICAL: JAVA_HOME must be set in *this* kernel's environment before the
# Spark session (and its JVM) is built at the end of this cell. Exporting it
# in a terminal is not enough if the kernel was started from elsewhere.
# NOTE: `from pyspark.sql.functions import *` was removed on purpose. It
# shadowed Python builtins (`sum`, `min`, `max`, `round`, ...), and every
# later cell imports the Spark functions it needs explicitly.

# Well-known JDK 17 locations (Homebrew on Apple Silicon / Intel, Temurin and
# Oracle bundles on macOS), tried in this order.
COMMON_JAVA_PATHS = [
    "/opt/homebrew/opt/openjdk@17",
    "/usr/local/opt/openjdk@17",
    "/Library/Java/JavaVirtualMachines/temurin-17.jdk/Contents/Home",
    "/Library/Java/JavaVirtualMachines/jdk-17.jdk/Contents/Home",
]


def _java_home_from_brew():
    """Return Homebrew's ``openjdk@17`` prefix if it exists on disk, else None."""
    try:
        prefix = subprocess.check_output(
            ["brew", "--prefix", "openjdk@17"],
            text=True,
            stderr=subprocess.DEVNULL,
        ).strip()
    except (OSError, subprocess.CalledProcessError):
        # `brew` is not installed (OSError) or the prefix query failed.
        return None
    return prefix if os.path.exists(prefix) else None


def _java_home_from_common_paths():
    """Return the first entry of COMMON_JAVA_PATHS that exists, else None."""
    return next((path for path in COMMON_JAVA_PATHS if os.path.exists(path)), None)


def _java_home_from_macos_helper():
    """Return the JDK 17 home reported by macOS ``/usr/libexec/java_home``, else None."""
    try:
        home = subprocess.check_output(
            ["/usr/libexec/java_home", "-v", "17"],
            text=True,
            stderr=subprocess.DEVNULL,
        ).strip()
    except (OSError, subprocess.CalledProcessError):
        # Not on macOS, or no JDK 17 registered with the helper.
        return None
    return home or None


# Method 1: respect an explicitly configured JAVA_HOME.
java_home = os.environ.get("JAVA_HOME")
if java_home:
    print(f"✓ JAVA_HOME already set to: {java_home}")
else:
    # Methods 2-4: Homebrew, well-known paths, then the macOS helper.
    for finder, found_msg in (
        (_java_home_from_brew, "Found Java via Homebrew"),
        (_java_home_from_common_paths, "Found Java at"),
        (_java_home_from_macos_helper, "Found Java via java_home"),
    ):
        java_home = finder()
        if java_home:
            os.environ["JAVA_HOME"] = java_home
            print(f"✓ {found_msg}: {java_home}")
            break

if java_home:
    # Verify the Java binary actually runs.
    java_bin = os.path.join(java_home, "bin", "java")
    if os.path.exists(java_bin):
        try:
            # `java -version` prints to stderr, so merge stderr into stdout.
            # (capture_output=True cannot be combined with stderr=...; doing so
            # raises ValueError and made this check always fail.)
            result = subprocess.run(
                [java_bin, "-version"],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                timeout=5,
            )
        except (OSError, subprocess.SubprocessError) as e:
            print(f"⚠️  Could not verify Java: {e}")
        else:
            version_line = result.stdout.splitlines()[0] if result.stdout else "N/A"
            if result.returncode == 0:
                print("✓ Java verification successful")
            else:
                print(f"⚠️  Java exited with status {result.returncode}")
            print(f"  Version: {version_line}")
    else:
        print(f"⚠️  Java binary not found at: {java_bin}")

    # Prepend the JDK's bin directory to PATH (compare whole entries, not
    # substrings, and avoid a trailing empty entry when PATH is empty).
    java_bin_dir = os.path.join(java_home, "bin")
    current_path = os.environ.get("PATH", "")
    path_entries = current_path.split(os.pathsep) if current_path else []
    if java_bin_dir not in path_entries:
        os.environ["PATH"] = os.pathsep.join([java_bin_dir, *path_entries])
        print("✓ Added Java to PATH")
else:
    print("❌ ERROR: Could not find Java installation!")
    print("Please install Java 17: brew install openjdk@17")
    print("Or set JAVA_HOME manually in this cell:")
    print("  os.environ['JAVA_HOME'] = '/path/to/java'")

# Stop any session left over from a previous run of this cell before
# building a new one.
try:
    spark.stop()
except NameError:
    pass  # First run: no `spark` session exists yet.
except Exception as exc:  # Best effort: a dead JVM must not block recreation.
    print(f"⚠️  Could not stop previous Spark session: {exc}")

spark = build_spark()
print("✓ Spark Session created successfully!")



## 3. Load Data


In [None]:
DATA_ZIP = ROOT_DIR / "data/classification/classification_reviews.zip"
EXTRACT_DIR = ROOT_DIR / "data/classification/extracted"

# Fail fast if the archive is missing.
if not DATA_ZIP.exists():
    raise FileNotFoundError(f"Zip file not found: {DATA_ZIP}")

# A zero-byte archive is almost certainly a failed download / corrupted file.
if DATA_ZIP.stat().st_size == 0:
    raise ValueError(f"Zip file is empty: {DATA_ZIP}")

# Validate the archive: BadZipFile means it is not a zip at all, and
# testzip() returns the name of the first member whose CRC check fails
# (None when every member is intact).
try:
    with zipfile.ZipFile(DATA_ZIP, "r") as test_zip:
        first_bad_member = test_zip.testzip()
except zipfile.BadZipFile as err:
    raise ValueError(
        f"File is not a valid zip file: {DATA_ZIP}. It may be corrupted or in a different format."
    ) from err
if first_bad_member is not None:
    raise ValueError(f"Corrupted member {first_bad_member!r} in zip file: {DATA_ZIP}")

# Treat an existing but empty directory (e.g. from an aborted extraction)
# as "not extracted yet".
already_extracted = EXTRACT_DIR.is_dir() and any(EXTRACT_DIR.iterdir())
if not already_extracted:
    print(f"📦 Extracting {DATA_ZIP}...")
    with zipfile.ZipFile(DATA_ZIP, "r") as zip_ref:
        zip_ref.extractall(EXTRACT_DIR)
    print("✓ Extraction complete!")
else:
    print("✓ Data already extracted")

print(f"\nData location: {EXTRACT_DIR}")

In [None]:
# Load the extracted parquet data (given as a directory path) into a DataFrame.
parquet_path = str(EXTRACT_DIR)
df = spark.read.parquet(parquet_path)

n_records = df.count()
print(f"Total records: {n_records:,}")
print(f"\nColumns: {df.columns}")
df.printSchema()

In [None]:
df.show(n=5, truncate=50)  # first 5 rows, cells cut to 50 characters

## 4. Data Exploration


In [None]:
# Check target distribution (label: 0=negative, 1=neutral, 2=positive)
df.groupBy("label").count().sort("label").show()  # rows per label, ascending

In [None]:
from pyspark.sql.functions import (
    col,
    count,
    isnan,
    when,
)

# Column with the raw review text fed to the tokenizer.
TEXT_COL = "text"
# Column with the sentiment class: 0=negative, 1=neutral, 2=positive.
TARGET_COL = "label"

## 5. Data Preprocessing


In [None]:
from pyspark.sql.functions import col

# MLlib classifiers expect a numeric (double) label; the labels are already
# encoded as 0, 1, 2, so a cast is enough. Rows whose text is null, or whose
# label is null / not castable to a number, are then dropped. Otherwise the
# tokenizer and classifier would fail on them at fit/transform time instead
# of skipping them.
df_clean = (
    df.withColumn("label", col(TARGET_COL).cast("double"))
    .dropna(subset=[TEXT_COL, "label"])
)

print(f"Clean dataset: {df_clean.count():,} records")
print("\nLabel distribution (0=negative, 1=neutral, 2=positive):")
df_clean.groupBy("label").count().orderBy("label").show()

In [None]:
# Train-Test Split
train_df, test_df = df_clean.randomSplit([0.8, 0.2], seed=42)

# Every count() is a full Spark job; keep the training count for reuse below.
n_train = train_df.count()
print(f"Training set: {n_train:,} records")
print(f"Test set: {test_df.count():,} records")

# Calculate class weights for imbalanced data
from pyspark.sql.functions import col, when, lit

# Per-class row counts in the training set.
class_counts = train_df.groupBy("label").count().collect()
total_samples = n_train  # equals the sum of the per-class counts
num_classes = len(class_counts)

# Balanced weighting: total_samples / (num_classes * class_count), so rarer
# classes get proportionally larger weights. Do NOT bind a name `count` here:
# it would shadow pyspark.sql.functions.count, which later cells call.
class_weights = {
    row["label"]: total_samples / (num_classes * row["count"])
    for row in class_counts
}

print(f"\nClass weights: {class_weights}")

# Build `CASE WHEN label = k THEN w_k ... ELSE 1.0 END` from the computed
# weights instead of hard-coding the labels; unknown labels get weight 1.0.
weight_expr = None
for label_value in sorted(k for k in class_weights if k is not None):
    is_label = col("label") == label_value
    label_weight = lit(class_weights[label_value])
    weight_expr = (
        when(is_label, label_weight)
        if weight_expr is None
        else weight_expr.when(is_label, label_weight)
    )
weight_expr = lit(1.0) if weight_expr is None else weight_expr.otherwise(lit(1.0))

# Add weight column to training data
train_df = train_df.withColumn("weight", weight_expr)

print("✓ Class weights added to training data")

## 6. Feature Engineering Pipeline


In [None]:
from pyspark.ml.feature import Tokenizer, StopWordsRemover, HashingTF, IDF
from pyspark.ml import Pipeline

# Text -> features: lower-case and split on whitespace, drop stop words,
# hash tokens into a fixed-size term-frequency vector, then re-weight by IDF.
tokenizer = Tokenizer(inputCol=TEXT_COL, outputCol="words")
remover = StopWordsRemover(inputCol="words", outputCol="filtered_words")
# 5000 hash buckets (reduced from 10000 to save memory); the trade-off is more
# collisions between distinct tokens.
hashing_tf = HashingTF(inputCol="filtered_words", outputCol="raw_features", numFeatures=5000)
idf = IDF(inputCol="raw_features", outputCol="features")

print("✓ Feature transformers defined")

## 7. Model Training


In [None]:
from pyspark.ml.classification import RandomForestClassifier

# Hyper-parameters are kept small to limit memory use while training.
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="label",
    weightCol="weight",  # per-row class weights computed on the training split
    numTrees=10,  # Spark's default is 20; fewer trees = less memory
    maxDepth=3,  # Spark's default is 5; shallower trees = fewer nodes
    maxBins=16,  # Spark's default is 32; fewer bins = smaller split statistics
    minInstancesPerNode=5,  # each child of a split needs >= 5 rows (default 1)
    minInfoGain=0.0,  # minimum information gain required to split a node
    impurity="gini",  # impurity measure: "gini" or "entropy"
    # Features considered at each split; for a multi-tree classifier this is
    # what "auto" resolves to as well.
    featureSubsetStrategy="sqrt",
    subsamplingRate=0.4,  # each tree is trained on a 40% sample of the rows
    seed=42,
)

# Full pipeline: text featurization stages followed by the forest.
pipeline = Pipeline(stages=[tokenizer, remover, hashing_tf, idf, rf])

In [None]:
print("🚀 Training model...")
# Fits each pipeline stage in order on the weighted training split.
model = pipeline.fit(train_df)
print("✓ Training complete!")

## 8. Model Evaluation


In [None]:
# Score the held-out split with the fitted pipeline.
predictions = model.transform(test_df)

preview_cols = [TEXT_COL, "label", "prediction", "probability"]
predictions.select(*preview_cols).show(10, truncate=50)

In [None]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator


def _score(scored_df, metric_name):
    """Evaluate `scored_df` (label vs prediction) with a single MLlib metric."""
    evaluator = MulticlassClassificationEvaluator(
        labelCol="label", predictionCol="prediction", metricName=metric_name
    )
    return evaluator.evaluate(scored_df)


accuracy = _score(predictions, "accuracy")
f1 = _score(predictions, "f1")  # MLlib's "f1" is the class-weighted F1
precision = _score(predictions, "weightedPrecision")
recall = _score(predictions, "weightedRecall")

print("RESULTS")
print(f"Accuracy:           {accuracy:.4f}")
print(f"F1 Score:           {f1:.4f}")
print(f"Weighted Precision: {precision:.4f}")
print(f"Weighted Recall:    {recall:.4f}")

In [None]:
# One row per (true label, predicted label) pair with its frequency.
confusion_matrix = (
    predictions
    .groupBy("label", "prediction")
    .count()
    .orderBy("label", "prediction")
)
print("Confusion Matrix:")
confusion_matrix.show(n=25)

In [None]:
# Import `count` and `col` explicitly so this cell does not depend on the
# bare name `count` still being Spark's function. Any earlier assignment
# such as `count = row["count"]` would otherwise make `count("*")` raise
# TypeError.
from pyspark.sql.functions import col, count, sum as spark_sum, when

# Per-class accuracy, i.e. the recall of each true label: the share of rows
# with that label that the model predicted correctly.
is_correct = when(col("label") == col("prediction"), 1).otherwise(0)
per_class = predictions.groupBy("label").agg(
    count("*").alias("total"),
    spark_sum(is_correct).alias("correct"),
)
per_class = per_class.withColumn("accuracy", col("correct") / col("total"))
per_class.orderBy("label").show()

## 9. Save Model


In [None]:
# NOTE(review): the directory is named "spark_lr_classifier" although the
# pipeline ends in a RandomForestClassifier; kept unchanged so anything that
# loads this path keeps working.
MODEL_DIR = ROOT_DIR / "models" / "spark_lr_classifier"

# Persist the fitted PipelineModel, replacing any previous save at this path.
model.write().overwrite().save(str(MODEL_DIR))

print(f"✓ Model saved to {MODEL_DIR}")

## 10. Quick Inference Test


In [None]:
from pyspark.ml import PipelineModel

# Reload the saved pipeline from disk to check that it round-trips.
loaded_model = PipelineModel.load(str(MODEL_DIR))

sample_reviews = [
    "This product is amazing! Best purchase I've ever made.",
    "Terrible quality, broke after one day. Don't buy!",
    "It's okay, nothing special but does the job.",
]
sample_rows = [(review,) for review in sample_reviews]
sample_data = spark.createDataFrame(sample_rows, [TEXT_COL])

sample_predictions = loaded_model.transform(sample_data)

print("Sample Predictions:")
sample_predictions.select(TEXT_COL, "prediction").show(truncate=60)

In [None]:
# Release the Spark session now that the notebook has finished.
spark.stop()