# Model Training

This notebook trains the following model:

1) A model to recommend product categories based on customer clusters' history

## Imports

In [14]:
import os

# ETL and Data Manipulation
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexerModel, StringIndexer

# Matrix Factorization
from pyspark.ml.recommendation import ALS

In [15]:
# Start (or reuse) a Spark session running in local mode
spark = (
    SparkSession.builder
    .appName("LocalSparkForTesting")
    .master("local[*]")
    .getOrCreate()
)

## Load data

In [16]:
# Root directory of the processed parquet datasets
DATA_PATH = os.path.join('/sparkdata/wholesale-recommender', 'processed')


def _read_processed(dataset_name):
    """Read the parquet dataset named `dataset_name` located under DATA_PATH."""
    return spark.read.parquet(os.path.join(DATA_PATH, dataset_name))


# Customer feature table
customers = _read_processed("customers_features")

# Customer cluster table
customer_clusters = _read_processed("customer_cluster")

# Interaction records
interactions = _read_processed("interactions")

## Prepare data

### Reindex customer IDs in the cluster table

In [12]:
# Directory holding the persisted models
MODELS_PATH = os.path.join('/sparkdata/wholesale-recommender', 'models')

# Restore the saved customer-ID indexer from disk
customer_model_path = os.path.join(MODELS_PATH, "customer_indexer_model")
customer_model = StringIndexerModel.load(customer_model_path)

# Apply the indexer to the cluster table so its customer IDs follow the same
# indexing as the interactions dataframe.
# NOTE(review): customer_clusters_indexed is not referenced by any later cell
# visible in this notebook — confirm whether the cluster join should use it.
customer_clusters_indexed = customer_model.transform(customer_clusters)

### Index product category

In [13]:
# Map the "Product Category" strings to numeric labels in "category_index".
# The fitted indexer is saved under MODELS_PATH, the shared models directory
# defined in the customer-indexer cell above.
category_indexer = StringIndexer(inputCol="Product Category", outputCol="category_index")

# Drop any "category_index" left over from a previous run of this cell so the
# cell can be re-executed: the indexer's transform fails when its output column
# already exists. On a fresh kernel the column is absent and the drop is a no-op.
interactions_unindexed = interactions.drop("category_index")

# Fit the label mapping and attach the index column
category_indexer_model = category_indexer.fit(interactions_unindexed)
interactions = category_indexer_model.transform(interactions_unindexed)

# Persist the fitted indexer so category indices can be mapped back to names later
category_indexer_model.write().overwrite().save(os.path.join(MODELS_PATH, "category_indexer_model"))

### Add customer clusters to the interactions

In [29]:
# Attach the customer cluster to every interaction.
# An inner join is used so that interactions whose customer has no row in
# customer_clusters are excluded: a left join would leave them with a null
# "cluster", which ALS cannot use as a user id when fitting.
# NOTE(review): the key on the cluster side is the raw "Customer ID" renamed to
# "customer_id" — confirm it uses the same id scheme as interactions.customer_id.
interactions_with_customer_clusters = interactions.join(
    customer_clusters.withColumnRenamed("Customer ID", "customer_id"),
    on="customer_id",
    how="inner"
)

## Train ALS Model

In [31]:
# Implicit-feedback ALS: customer clusters act as the "users" and indexed
# product categories as the "items".
als = ALS(
    userCol="cluster",
    itemCol="category_index",
    ratingCol="rating",
    implicitPrefs=True,         # treat "rating" as implicit feedback
    coldStartStrategy="drop",   # Avoid NaNs in output
    rank=10,
    maxIter=10,
    regParam=0.1,
    seed=42                     # fixed seed so repeated runs are reproducible
)

# Fit the factorization on the cluster-annotated interactions
als_model = als.fit(interactions_with_customer_clusters)

## Save

In [32]:
# Destination directory for the trained cluster-to-category recommender
MODEL_PATH = os.path.join('/sparkdata/wholesale-recommender','models','cluster_cat_rec')

# Overwrite any model written by a previous run; a plain save() raises when the
# target path already exists, which breaks re-running the notebook.
als_model.write().overwrite().save(MODEL_PATH)