#### Final Project: H&M Personalized Fashion Recommendations
#### Model Building: MBA, ALS Implicit, Ensemble
#### Last Updated: May 24, 2022
<br>

#### Overview

This notebook reads the parquet files created by the "preprocessing.ipynb" notebook and uses them to build several models for personalized fashion recommendations. The models are listed below: <br>

(1) Seasonal Model <br>
(2) Trending Product Model <br>
(3) ALS Implicit Model <br>
(4) Market Basket Analysis <br>
(5) Ensemble Model

<br>

In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.sql.functions import (
    col, collect_set, count, date_format, from_unixtime, lit, month,
    row_number, unix_timestamp, when, year,
)
from pyspark.sql.window import Window
from pyspark.ml.fpm import FPGrowth
from pyspark.mllib.evaluation import RankingMetrics
from pyspark.ml.feature import Bucketizer, StringIndexer
from pyspark.ml.recommendation import ALS
from pyspark.ml import Pipeline


spark = SparkSession.builder \
    .master('local[*]') \
    .config("spark.driver.memory", "100g") \
    .appName('my-cool-app') \
    .getOrCreate()

<br>

#### Load Data

In [2]:
customers = spark.read.csv('/sfs/qumulo/qhome/sar2jf/Documents/hm_fashion/Data/raw_data/customers.csv',  inferSchema=True, header = True)
articles = spark.read.csv('/sfs/qumulo/qhome/sar2jf/Documents/hm_fashion/Data/raw_data/articles.csv',  inferSchema=True, header = True)
transactions_full = spark.read.csv('/sfs/qumulo/qhome/sar2jf/Documents/hm_fashion/Data/raw_data/transactions_train.csv',  inferSchema=True, header = True)

<br>

#### Additional Processing

##### Create Customer Index, Article Index

In [3]:
# Map customer_id to an index

cust_idx = customers.select(col('customer_id')).distinct()

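# Ordering by a constant gives all rows the same sort key, so which customer gets which index is arbitrary
w = Window().orderBy(lit('A'))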
cust_idx = cust_idx.withColumn("cust_idx", row_number().over(w))

cust_idx.show(5)


+--------------------+--------+
|         customer_id|cust_idx|
+--------------------+--------+
|000346516dd355b40...|       1|
|0003e56a4332b2503...|       2|
|0011a72ff27917972...|       3|
|0022058e10f379f15...|       4|
|0028449d82fdf6771...|       5|
+--------------------+--------+
only showing top 5 rows



In [4]:
# Map article_id to an index

article_idx = articles.select(col('article_id')).distinct()

w = Window().orderBy(lit('A'))
article_idx = article_idx.withColumn("art_idx", row_number().over(w))

article_idx.show(5)


+----------+-------+
|article_id|art_idx|
+----------+-------+
| 126589006|      1|
| 201219001|      2|
| 241412052|      3|
| 247072032|      4|
| 266873009|      5|
+----------+-------+
only showing top 5 rows



##### Create month / year column for transactions_train

In [17]:
transactions_full = transactions_full.withColumn("t_year", transactions_full.t_dat.substr(1,4))
transactions_full = transactions_full.withColumn("t_month", transactions_full.t_dat.substr(6,2))

transactions_full.show(5)

+----------+--------------------+----------+--------------------+----------------+------+-------+
|     t_dat|         customer_id|article_id|               price|sales_channel_id|t_year|t_month|
+----------+--------------------+----------+--------------------+----------------+------+-------+
|2018-09-20|000058a12d5b43e67...| 663713001|0.050830508474576264|               2|  2018|     09|
|2018-09-20|000058a12d5b43e67...| 541518023| 0.03049152542372881|               2|  2018|     09|
|2018-09-20|00007d2de826758b6...| 505221004| 0.01523728813559322|               2|  2018|     09|
|2018-09-20|00007d2de826758b6...| 685687003|0.016932203389830508|               2|  2018|     09|
|2018-09-20|00007d2de826758b6...| 685687004|0.016932203389830508|               2|  2018|     09|
+----------+--------------------+----------+--------------------+----------------+------+-------+
only showing top 5 rows



##### Load Validation Set and Format

In [9]:
validation = spark.read.parquet("/sfs/qumulo/qhome/sar2jf/Documents/hm_fashion/Data/processed_data/val_par")

In [10]:
validation = validation.groupBy('customer_id') \
                      .agg(collect_set('article_id') \
                      .alias('labels'))

validation.show(5)

+--------------------+--------------------+
|         customer_id|              labels|
+--------------------+--------------------+
|0038bf2b66fdc1de4...|[864415002, 89475...|
|004432f08708cc499...|[933406001, 85078...|
|005ddabf9bc77f963...|[745232001, 91381...|
|006ae0656ded2215d...|[803083002, 85001...|
|0086f22a4967559a6...|[828295001, 87131...|
+--------------------+--------------------+
only showing top 5 rows



##### Below Table is List of Customers To Make Predictions On

In [35]:
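# Get list of customers that require predictions, merge to customers to get demographic data
# (cust_sex and age_group are added to customers in the sections below; per the cell numbers, this cell ran after them)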

customer_val = validation.select(col('customer_id'))

temp = customers.select(col('customer_id'), col('cust_sex'), col('age_group'))
customer_val = customer_val.join(temp, on = 'customer_id', how = 'inner')

customer_val.show(5)

+--------------------+--------+---------+
|         customer_id|cust_sex|age_group|
+--------------------+--------+---------+
|0038bf2b66fdc1de4...|       F|      5.0|
|0038bf2b66fdc1de4...|       F|      5.0|
|0038bf2b66fdc1de4...|       F|      5.0|
|0038bf2b66fdc1de4...|       F|      5.0|
|004432f08708cc499...|       F|      2.0|
+--------------------+--------+---------+
only showing top 5 rows



<br>

#### Determine Sex of Customers

The sex of each customer must be predicted from the types of articles they purchase. First, the index_name of each article is mapped to M, F, or Unknown. Then a sex is predicted for each customer from the number of products purchased in each category.

In [19]:
# Map index names of articles to either M/F/Unknown

articles = articles.withColumn('prod_sex',
    when(articles.index_name == 'Menswear', 'M')\
    .when(articles.index_name == 'Lingeries/Tights', 'F')\
    .when(articles.index_name == 'Ladies Accessories', 'F')\
    .when(articles.index_name == 'Ladieswear', 'F')\
    .otherwise('Unknown')
)

articles.select(col('index_name'), col('prod_sex')).show(5)

+----------------+--------+
|      index_name|prod_sex|
+----------------+--------+
|      Ladieswear|       F|
|      Ladieswear|       F|
|      Ladieswear|       F|
|Lingeries/Tights|       F|
|Lingeries/Tights|       F|
+----------------+--------+
only showing top 5 rows



In [20]:
# Join prod_sex to transactions_full

art_temp = articles.select(col('article_id'), col('prod_sex'))
transactions_full = transactions_full.join(art_temp, on = 'article_id', how = 'inner')

transactions_full.show(5)

+----------+----------+--------------------+--------------------+----------------+------+-------+--------+
|article_id|     t_dat|         customer_id|               price|sales_channel_id|t_year|t_month|prod_sex|
+----------+----------+--------------------+--------------------+----------------+------+-------+--------+
| 663713001|2018-09-20|000058a12d5b43e67...|0.050830508474576264|               2|  2018|     09|       F|
| 541518023|2018-09-20|000058a12d5b43e67...| 0.03049152542372881|               2|  2018|     09|       F|
| 505221004|2018-09-20|00007d2de826758b6...| 0.01523728813559322|               2|  2018|     09| Unknown|
| 685687003|2018-09-20|00007d2de826758b6...|0.016932203389830508|               2|  2018|     09|       F|
| 685687004|2018-09-20|00007d2de826758b6...|0.016932203389830508|               2|  2018|     09|       F|
+----------+----------+--------------------+--------------------+----------------+------+-------+--------+
only showing top 5 rows



In [21]:
# Get counts of clothing by sex for each customer

transactions_full.createOrReplaceTempView("transactions_temp")
sex_ind = spark.sql("SELECT customer_id, SUM(CASE WHEN prod_sex = 'M' then 1 ELSE 0 END) AS m_prod_cnt,\
                    SUM(CASE WHEN prod_sex = 'F' then 1 ELSE 0 END) AS f_prod_cnt,\
                    SUM(CASE WHEN prod_sex = 'Unknown' then 1 ELSE 0 END) AS unknown\
                    FROM transactions_temp GROUP BY customer_id")

sex_ind.show(5)

+--------------------+----------+----------+-------+
|         customer_id|m_prod_cnt|f_prod_cnt|unknown|
+--------------------+----------+----------+-------+
|05f65801b9a2d28a5...|         0|        40|     10|
|05f79b715286a38a8...|         0|        13|      4|
|072d11a8c0a1e6d0f...|         5|         2|      0|
|0d8f74345c5153236...|        11|        10|     19|
|102b78267df2dbc2d...|        36|        23|     17|
+--------------------+----------+----------+-------+
only showing top 5 rows



In [22]:
sex_ind = sex_ind.withColumn('cust_sex',
    when(((sex_ind.f_prod_cnt) >= (sex_ind.m_prod_cnt)) &  ((sex_ind.f_prod_cnt) >= (sex_ind.unknown)), 'F')\
    .when(((sex_ind.m_prod_cnt) >= (sex_ind.f_prod_cnt)) &  ((sex_ind.m_prod_cnt) >= (sex_ind.unknown)), 'M')\
    .otherwise('unknown')
)

sex_ind.show(5)

+--------------------+----------+----------+-------+--------+
|         customer_id|m_prod_cnt|f_prod_cnt|unknown|cust_sex|
+--------------------+----------+----------+-------+--------+
|05f65801b9a2d28a5...|         0|        40|     10|       F|
|05f79b715286a38a8...|         0|        13|      4|       F|
|072d11a8c0a1e6d0f...|         5|         2|      0|       M|
|0d8f74345c5153236...|        11|        10|     19| unknown|
|102b78267df2dbc2d...|        36|        23|     17|       M|
+--------------------+----------+----------+-------+--------+
only showing top 5 rows
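
The decision rule in the cell above can be restated in plain Python to make its tie-breaking visible. This is only a sketch: `predict_sex` is a hypothetical helper, not part of the notebook, and it mirrors the order of the `when` clauses (F is tested first, so F wins ties with M).

```python
def predict_sex(m_cnt: int, f_cnt: int, unknown_cnt: int) -> str:
    """Mirror the when/otherwise chain: F is tested first, then M."""
    if f_cnt >= m_cnt and f_cnt >= unknown_cnt:
        return "F"
    if m_cnt >= f_cnt and m_cnt >= unknown_cnt:
        return "M"
    return "unknown"


# Counts taken from the sex_ind output above.
print(predict_sex(0, 40, 10))   # F
print(predict_sex(5, 2, 0))     # M
print(predict_sex(11, 10, 19))  # unknown

# Hypothetical tie: equal M and F counts resolve to F.
print(predict_sex(3, 3, 1))     # F
```

Because both comparisons use `>=`, a customer with equal M and F counts is labelled F. If that is not intended, the first comparison would need to be strict.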



In [23]:
# Join predicted sex onto the customers table

sex_temp = sex_ind.select(col('customer_id'), col('cust_sex'))

customers = customers.join(sex_temp, on = 'customer_id', how = 'left')

In [24]:
customers.show(5)

+--------------------+----+------+------------------+----------------------+---+--------------------+--------+
|         customer_id|  FN|Active|club_member_status|fashion_news_frequency|age|         postal_code|cust_sex|
+--------------------+----+------+------------------+----------------------+---+--------------------+--------+
|000346516dd355b40...|null|  null|            ACTIVE|                  NONE| 25|5b05b77e09b7c5895...| unknown|
|0003e56a4332b2503...| 1.0|   1.0|            ACTIVE|             Regularly| 70|87a11433a6d5bfbd2...|       F|
|0011a72ff27917972...|null|  null|            ACTIVE|                  NONE| 35|ba653343093af1c4c...|       F|
|0022058e10f379f15...|null|  null|            ACTIVE|                  NONE| 48|b8be2bccfe2c6da84...|       F|
|0028449d82fdf6771...|null|  null|        PRE-CREATE|                  NONE| 49|734bf53cf5df1e4a1...|       F|
+--------------------+----+------+------------------+----------------------+---+--------------------+--------+
only showing top 5 rows

<br>

#### Create prod_index_id

The product_type_no column (included in the original dataset) categorizes articles only by product type (e.g. pants) and is the same across index_name values. For example, pants in Ladieswear and pants in Menswear share one product_type_no.

The prod_index_id is created from the distinct combinations of product_type_name and index_name.

In [25]:
# Create distinct pairings of product types, and index

art_lookup = articles.select(col('product_type_name'), col('index_name')) \
                       .distinct() \
                       .sort(col('product_type_no'))


# Create ID for unique product, index types

w = Window().orderBy(lit('A'))
art_lookup = art_lookup.withColumn("prod_index_id", row_number().over(w))


# Rejoin with articles, extract relevant columns

art_lookup = articles.join(art_lookup, 
                           on = ['product_type_name', 'index_name'], 
                           how = 'inner')

art_lookup = art_lookup.select(col('article_id'),
                  col('product_type_name'), col('product_code'),
                  col('product_type_no'), 
                  col('index_name'), 
                  col('prod_index_id')).sort(col('prod_index_id'))

art_lookup_short = art_lookup.select(col('article_id'), 
                  col('prod_index_id')).sort(col('prod_index_id'))

In [26]:
art_lookup.show(5)

+----------+-----------------+------------+---------------+--------------------+-------------+
|article_id|product_type_name|product_code|product_type_no|          index_name|prod_index_id|
+----------+-----------------+------------+---------------+--------------------+-------------+
| 519243001|          Unknown|      519243|             -1|Children Accessor...|            1|
| 867969002|          Unknown|      867969|             -1|             Divided|            2|
| 691704002|          Unknown|      691704|             -1|             Divided|            2|
| 724906018|          Unknown|      724906|             -1|             Divided|            2|
| 724906019|          Unknown|      724906|             -1|             Divided|            2|
+----------+-----------------+------------+---------------+--------------------+-------------+
only showing top 5 rows
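
The idea behind prod_index_id can be shown on a few rows in plain Python. This is a sketch under assumptions: the article rows are made up, and the pairs are sorted before numbering so that the IDs are reproducible. The notebook instead numbers rows with a window ordered by a constant, where the assignment is arbitrary.

```python
# Hypothetical rows: (article_id, product_type_name, index_name)
articles = [
    (1, "Trousers", "Ladieswear"),
    (2, "Trousers", "Menswear"),
    (3, "Trousers", "Ladieswear"),
    (4, "Bra", "Lingeries/Tights"),
]

# Distinct (product_type_name, index_name) pairs, sorted for a stable order.
pairs = sorted({(ptype, index) for _, ptype, index in articles})

# Number the pairs from 1, like row_number() does.
pair_to_id = {pair: i for i, pair in enumerate(pairs, start=1)}

# Map every article to the ID of its pair.
article_to_id = {
    aid: pair_to_id[(ptype, index)] for aid, ptype, index in articles
}

print(article_to_id)  # {1: 2, 2: 3, 3: 2, 4: 1}
```

Articles 1 and 3 share an ID because they share both product type and index, while article 2 gets its own ID even though its product type is the same.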



<br>

#### Create Age Buckets For Customers Table

In [27]:
customers.show(5)

+--------------------+----+------+------------------+----------------------+---+--------------------+--------+
|         customer_id|  FN|Active|club_member_status|fashion_news_frequency|age|         postal_code|cust_sex|
+--------------------+----+------+------------------+----------------------+---+--------------------+--------+
|000346516dd355b40...|null|  null|            ACTIVE|                  NONE| 25|5b05b77e09b7c5895...| unknown|
|0003e56a4332b2503...| 1.0|   1.0|            ACTIVE|             Regularly| 70|87a11433a6d5bfbd2...|       F|
|0011a72ff27917972...|null|  null|            ACTIVE|                  NONE| 35|ba653343093af1c4c...|       F|
|0022058e10f379f15...|null|  null|            ACTIVE|                  NONE| 48|b8be2bccfe2c6da84...|       F|
|0028449d82fdf6771...|null|  null|        PRE-CREATE|                  NONE| 49|734bf53cf5df1e4a1...|       F|
+--------------------+----+------+------------------+----------------------+---+--------------------+--------+
only showing top 5 rows

In [28]:
# Add in age buckets

bucketizer = Bucketizer(splits=[0, 18, 20, 25, 30, 40, 50, 120],inputCol="age", outputCol="age_group")
customers = bucketizer.setHandleInvalid("keep").transform(customers)

dem_short = customers.select(col('customer_id'), col('age_group'))

dem_short.show(5)

+--------------------+---------+
|         customer_id|age_group|
+--------------------+---------+
|000346516dd355b40...|      3.0|
|0003e56a4332b2503...|      6.0|
|0011a72ff27917972...|      4.0|
|0022058e10f379f15...|      5.0|
|0028449d82fdf6771...|      5.0|
+--------------------+---------+
only showing top 5 rows
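
To make the bucket numbers easier to read, the mapping from age to age_group can be sketched in plain Python. `age_bucket` is a hypothetical helper, not Spark's implementation. It assumes the usual Bucketizer convention that each bucket is `[lower, upper)` except the last, which also includes its upper bound, and it returns `None` for a missing age, in line with the null age_group rows that appear later in the notebook.

```python
from bisect import bisect_right
from typing import Optional, Sequence

SPLITS = [0, 18, 20, 25, 30, 40, 50, 120]  # same splits as the cell above


def age_bucket(age: Optional[float],
               splits: Sequence[float] = SPLITS) -> Optional[float]:
    """Return the bucket index for an age, or None when the age is missing."""
    if age is None:
        return None
    if age < splits[0] or age > splits[-1]:
        raise ValueError(f"age {age} is outside the split range")
    idx = bisect_right(splits, age) - 1
    # The top split (120) belongs to the last bucket instead of opening a new one.
    return float(min(idx, len(splits) - 2))


# Ages from the customers table shown above.
print([age_bucket(a) for a in (25, 70, 35, 48, 49)])  # [3.0, 6.0, 4.0, 5.0, 5.0]
```

Read this way, age_group 0.0 is under 18, 1.0 is 18-19, 2.0 is 20-24, 3.0 is 25-29, 4.0 is 30-39, 5.0 is 40-49, and 6.0 is 50 and over.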



In [29]:
# Join with transactions_full

transactions_full = transactions_full.join(dem_short, on = 'customer_id', how = 'inner')

transactions_full.show(5)

+--------------------+----------+----------+--------------------+----------------+------+-------+--------+---------+
|         customer_id|article_id|     t_dat|               price|sales_channel_id|t_year|t_month|prod_sex|age_group|
+--------------------+----------+----------+--------------------+----------------+------+-------+--------+---------+
|000346516dd355b40...| 534210011|2018-11-28| 0.01523728813559322|               2|  2018|     11|       F|      3.0|
|000346516dd355b40...| 666084001|2018-11-28|0.022016949152542376|               2|  2018|     11|       M|      3.0|
|000346516dd355b40...| 557248003|2018-12-06|0.027101694915254236|               2|  2018|     12|       M|      3.0|
|000346516dd355b40...| 507909001|2019-02-13|0.025406779661016947|               1|  2019|     02|       F|      3.0|
|000346516dd355b40...| 642437010|2019-05-30|0.033881355932203386|               1|  2019|     05| Unknown|      3.0|
+--------------------+----------+----------+--------------------+----------------+------+-------+--------+---------+

<br>

#### (1) Seasonal Model

This model recommends products based on seasonality. As shown below, the last transaction date is 9/22/2020, so the prediction period is the following week. The most popular products sold in September and October of previous years are therefore used as recommendations, with the more recent season weighted more heavily.

Note that only previous seasons are considered, i.e. 2018 and 2019.
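
The weighting scheme can be illustrated on toy data in plain Python. This is a sketch, not the notebook's Spark code: the transactions are made up, `top_n_by_group` is a hypothetical helper, and ties are broken by article ID as an assumption (the Spark window leaves tie order unspecified).

```python
from collections import defaultdict

YEAR_WEIGHT = {"2018": 0.5, "2019": 1.0}  # heavier weight for the more recent season


def top_n_by_group(transactions, n=12):
    """Sum year weights per (sex, age_group, article) and keep the top n per group."""
    scores = defaultdict(float)
    for year, sex, age_group, article in transactions:
        weight = YEAR_WEIGHT.get(year)
        if weight is None:  # years outside 2018/2019 are not considered
            continue
        scores[(sex, age_group, article)] += weight

    grouped = defaultdict(list)
    for (sex, age_group, article), score in scores.items():
        grouped[(sex, age_group)].append((score, article))

    return {
        group: [a for _, a in sorted(items, key=lambda t: (-t[0], t[1]))[:n]]
        for group, items in grouped.items()
    }


# Hypothetical transactions: (t_year, prod_sex, age_group, article_id)
tx = [
    ("2018", "F", 4.0, "A"), ("2018", "F", 4.0, "A"), ("2018", "F", 4.0, "A"),
    ("2019", "F", 4.0, "B"), ("2019", "F", 4.0, "B"),
    ("2019", "F", 4.0, "C"),
    ("2020", "F", 4.0, "D"),
]
print(top_n_by_group(tx, n=2))  # {('F', 4.0): ['B', 'A']}
```

Article A was bought three times in 2018 (score 1.5) but ranks below B, bought twice in 2019 (score 2.0), which is the intended effect of weighting recent seasons more heavily.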

In [30]:
# Find prediction period

transactions_full.agg({'t_dat': 'max'}).collect()

[Row(max(t_dat)='2020-09-22')]

In [31]:
# Subset transactions_full to only sales in Sep, Oct of 2018 and 2019

tx_SepOct = transactions_full.filter(col('t_month').isin('09', '10'))
tx_SepOct = tx_SepOct.filter(col('t_year').isin('2018', '2019'))


# Create weights, heavier for more recent year

tx_SepOct = tx_SepOct.withColumn(
    'weight',
    when(tx_SepOct.t_year == '2018', 0.5)\
    .when(tx_SepOct.t_year == '2019', 1)
)

tx_SepOct.show(5)

+--------------------+----------+----------+--------------------+----------------+------+-------+--------+---------+------+
|         customer_id|article_id|     t_dat|               price|sales_channel_id|t_year|t_month|prod_sex|age_group|weight|
+--------------------+----------+----------+--------------------+----------------+------+-------+--------+---------+------+
|0011a72ff27917972...| 779059003|2019-10-14|0.050830508474576264|               2|  2019|     10|       F|      4.0|   1.0|
|0011a72ff27917972...| 624257003|2019-10-14| 0.08472881355932205|               2|  2019|     10|       F|      4.0|   1.0|
|0011a72ff27917972...| 745973002|2019-10-14|0.033881355932203386|               2|  2019|     10| Unknown|      4.0|   1.0|
|0011a72ff27917972...| 796314001|2019-10-14|0.050830508474576264|               2|  2019|     10|       F|      4.0|   1.0|
|0011a72ff27917972...| 807756001|2019-10-14|0.008457627118644067|               2|  2019|     10|       F|      4.0|   1.0|
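+--------------------+----------+----------+--------------------+----------------+------+-------+--------+---------+------+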

##### Below table shows most popular products sold in months of September, October by sex and age-group

In [34]:
# Below gives the final Top-n purchase table

tx_SepOct_small = tx_SepOct.select(col('article_id'), col('age_group'), col('weight'), col('prod_sex'))

top_n = tx_SepOct_small.groupBy('prod_sex', 'age_group', 'article_id').sum('weight')
top_n = top_n.sort(col('prod_sex'), col('age_group'), col('sum(weight)').desc())


# Only keep top 12 most popular products

windowDept = Window.partitionBy(['prod_sex', 'age_group']).orderBy(col("sum(weight)").desc())
top_n = top_n.withColumn("row",row_number().over(windowDept))
top_n = top_n.filter(col('row') <= 12)

top_n = top_n.sort(col('prod_sex'), col('age_group'), col('sum(weight)').desc())

top_n.show()

+--------+---------+----------+-----------+---+
|prod_sex|age_group|article_id|sum(weight)|row|
+--------+---------+----------+-----------+---+
|       F|     null| 673677002|       31.0|  1|
|       F|     null| 778064001|       30.0|  2|
|       F|     null| 464297007|       29.0|  3|
|       F|     null| 767834001|       29.0|  4|
|       F|     null| 399256001|       27.0|  5|
|       F|     null| 372860001|       25.5|  6|
|       F|     null| 568601006|       24.0|  7|
|       F|     null| 562245046|       23.0|  8|
|       F|     null| 562245001|       22.5|  9|
|       F|     null| 608776002|       22.5| 10|
|       F|     null| 752814004|       21.0| 12|
|       F|     null| 470985003|       21.0| 11|
|       F|      0.0| 399256001|       11.0|  1|
|       F|      0.0| 673677002|        6.0|  2|
|       F|      0.0| 738943010|        5.0|  4|
|       F|      0.0| 573085004|        5.0|  7|
|       F|      0.0| 777070002|        5.0|  5|
|       F|      0.0| 693243019|        5

##### Create Predictions By Merging Top-N Table, Format for Evaluation

In [40]:
# Merge top_n to get predictions

predictions = customer_val.join(top_n, (customer_val.cust_sex == top_n.prod_sex) & (customer_val.age_group == top_n.age_group), 
                                how = 'inner')


predictions = predictions.groupBy('customer_id') \
                      .agg(collect_set('article_id') \
                      .alias('preds'))

pred_labels = predictions.join(validation, on = 'customer_id', how = 'inner') 

pred_labels.show(5)

+--------------------+--------------------+--------------------+
|         customer_id|               preds|              labels|
+--------------------+--------------------+--------------------+
|0038bf2b66fdc1de4...|[399256001, 56860...|[864415002, 89475...|
|004432f08708cc499...|[507909001, 67367...|[933406001, 85078...|
|005ddabf9bc77f963...|[562245001, 46429...|[745232001, 91381...|
|006ae0656ded2215d...|[562245001, 67367...|[803083002, 85001...|
|0086f22a4967559a6...|[685813001, 73086...|[828295001, 87131...|
+--------------------+--------------------+--------------------+
only showing top 5 rows



##### Evaluate Predictions

In [45]:
pred_labels = pred_labels.select(col("preds"), col("labels")).rdd
pred_labels = pred_labels.map(lambda x: (x[0], x[1]))

In [46]:
metrics = RankingMetrics(pred_labels)

In [47]:
metrics.meanAveragePrecisionAt(12)

0.0020068105715145322
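
For reference, mean average precision at k can be sketched in plain Python. This follows the common MAP@k definition and is not a copy of Spark's code; `average_precision_at_k` and `map_at_k` are hypothetical helpers, and dividing by `min(len(labels), k)` is an assumption about the normalization.

```python
def average_precision_at_k(preds, labels, k):
    """Average precision over the first k predictions for one customer."""
    labels = set(labels)
    if not labels:
        return 0.0
    hits = 0
    precision_sum = 0.0
    for i, item in enumerate(preds[:k]):
        if item in labels:
            hits += 1
            precision_sum += hits / (i + 1)  # precision at this rank
    return precision_sum / min(len(labels), k)


def map_at_k(pred_label_pairs, k):
    """Mean of the per-customer average precision."""
    pairs = list(pred_label_pairs)
    return sum(average_precision_at_k(p, l, k) for p, l in pairs) / len(pairs)


labels = ["a", "c"]
print(round(average_precision_at_k(["a", "b", "c"], labels, 3), 4))  # 0.8333
print(round(average_precision_at_k(["b", "c", "a"], labels, 3), 4))  # 0.5833
```

The two calls use the same three predictions in a different order and score differently, so the metric depends on rank. Since `collect_set` does not guarantee element order, the `preds` arrays built above may not follow the popularity ranking in `top_n`, which is worth keeping in mind when comparing scores.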

<br>

#### (2) Trending Product Model

This model recommends trending products, i.e. the most popular products in the months leading up to the prediction period. Specifically, the most popular products of Aug-Sep 2020 are recommended, with September sales weighted more heavily than August sales.

In [48]:
tx_AugSep = transactions_full.filter(col('t_month').isin('08', '09'))
tx_AugSep = tx_AugSep.filter(col('t_year').isin('2020'))

# Weight September sales twice as heavily as August sales
# (July is excluded by the filter above, so it needs no weight)
tx_AugSep = tx_AugSep.withColumn('weight',
    when(tx_AugSep.t_month == '08', 0.5)\
    .when(tx_AugSep.t_month == '09', 1)
)

tx_AugSep.show(5)

+--------------------+----------+----------+--------------------+----------------+------+-------+--------+---------+------+
|         customer_id|article_id|     t_dat|               price|sales_channel_id|t_year|t_month|prod_sex|age_group|weight|
+--------------------+----------+----------+--------------------+----------------+------+-------+--------+---------+------+
|0038bf2b66fdc1de4...| 864415002|2020-08-14|0.016932203389830508|               1|  2020|     08|       F|      5.0|   0.5|
|0038bf2b66fdc1de4...| 894756001|2020-08-14|0.022016949152542376|               1|  2020|     08|       F|      5.0|   0.5|
|0038bf2b66fdc1de4...| 904820001|2020-08-14|0.033881355932203386|               1|  2020|     08|       F|      5.0|   0.5|
|0038bf2b66fdc1de4...| 911870003|2020-09-12|0.033881355932203386|               1|  2020|     09|       F|      5.0|   1.0|
|004432f08708cc499...| 850784003|2020-09-06|0.025406779661016947|               1|  2020|     09| Unknown|      2.0|   1.0|
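+--------------------+----------+----------+--------------------+----------------+------+-------+--------+---------+------+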

##### Below table shows most popular products sold in August and September 2020 by sex and age-group

In [50]:
# Below gives the final Top-n purchase table

tx_AugSep_small = tx_AugSep.select(col('article_id'), col('age_group'), col('weight'), col('prod_sex'))

top_n = tx_AugSep_small.groupBy('prod_sex', 'age_group', 'article_id').sum('weight')
top_n = top_n.sort(col('prod_sex'), col('age_group'), col('sum(weight)').desc())

windowDept = Window.partitionBy(['prod_sex', 'age_group']).orderBy(col("sum(weight)").desc())
top_n = top_n.withColumn("row",row_number().over(windowDept))
top_n = top_n.filter(col('row') <= 12)

top_n = top_n.sort(col('prod_sex'), col('age_group'), col('sum(weight)').desc())

top_n.show()

+--------+---------+----------+-----------+---+
|prod_sex|age_group|article_id|sum(weight)|row|
+--------+---------+----------+-----------+---+
|       F|     null| 863595006|       10.5|  1|
|       F|     null| 778064001|       10.5|  2|
|       F|     null| 158340001|       10.0|  3|
|       F|     null| 160442010|       10.0|  4|
|       F|     null| 751471043|        8.5|  5|
|       F|     null| 896169002|        8.0|  7|
|       F|     null| 610776002|        8.0|  6|
|       F|     null| 767862001|        7.5|  9|
|       F|     null| 554598001|        7.5|  8|
|       F|     null| 736870001|        7.0| 12|
|       F|     null| 896152002|        7.0| 10|
|       F|     null| 778064028|        7.0| 11|
|       F|      0.0| 552716001|       29.0|  1|
|       F|      0.0| 639448001|       24.5|  2|
|       F|      0.0| 456163086|       24.5|  3|
|       F|      0.0| 850917001|       20.5|  4|
|       F|      0.0| 907409001|       17.0|  5|
|       F|      0.0| 894756001|       16

##### Create Predictions By Merging Top-N Table, Format for Evaluation

In [51]:
# Merge top_n to get predictions

predictions = customer_val.join(top_n, (customer_val.cust_sex == top_n.prod_sex) & (customer_val.age_group == top_n.age_group), 
                                how = 'inner')


predictions = predictions.groupBy('customer_id') \
                      .agg(collect_set('article_id') \
                      .alias('preds'))

pred_labels = predictions.join(validation, on = 'customer_id', how = 'inner') 

pred_labels.show(5)


+--------------------+--------------------+--------------------+
|         customer_id|               preds|              labels|
+--------------------+--------------------+--------------------+
|0038bf2b66fdc1de4...|[865929003, 78334...|[864415002, 89475...|
|004432f08708cc499...|[915526001, 78161...|[933406001, 85078...|
|005ddabf9bc77f963...|[610776002, 91552...|[745232001, 91381...|
|006ae0656ded2215d...|[863583001, 91552...|[803083002, 85001...|
|0086f22a4967559a6...|[685814063, 68581...|[828295001, 87131...|
+--------------------+--------------------+--------------------+
only showing top 5 rows



##### Evaluate Predictions

In [52]:
pred_labels = pred_labels.select(col("preds"), col("labels")).rdd
pred_labels = pred_labels.map(lambda x: (x[0], x[1]))

In [53]:
metrics = RankingMetrics(pred_labels)

In [54]:
metrics.meanAveragePrecisionAt(12)

0.006887753792953343

<br>

#### (3) ALS Implicit Model

##### Format TrainTest_par file for training of ALS Model

In [5]:
# Load TrainTest_par

TrainTest = spark.read.parquet("/sfs/qumulo/qhome/sar2jf/Documents/hm_fashion/Data/processed_data/TrainTest_par")
TrainTest_als = TrainTest.groupby('customer_id', 'article_id').count()


# Need art_idx and cust_idx

TrainTest_als = TrainTest_als.join(cust_idx, on = 'customer_id', how = 'inner')
TrainTest_als = TrainTest_als.join(article_idx, on = 'article_id', how = 'inner')

TrainTest_als.show(5)

+----------+--------------------+-----+--------+-------+
|article_id|         customer_id|count|cust_idx|art_idx|
+----------+--------------------+-----+--------+-------+
| 694966001|000346516dd355b40...|    1|       1|  14954|
| 485176004|000346516dd355b40...|    1|       1|   9496|
| 666084001|000346516dd355b40...|    1|       1|  57811|
| 809672001|000346516dd355b40...|    1|       1|  93320|
| 766955002|000346516dd355b40...|    1|       1|  15091|
+----------+--------------------+-----+--------+-------+
only showing top 5 rows



##### Train ALS Model

In [6]:
als = ALS(userCol = "cust_idx", 
        itemCol = "art_idx", 
        ratingCol = "count",
        coldStartStrategy = "drop", 
        nonnegative = True,
        implicitPrefs = True
       )

model_als = als.fit(TrainTest_als)
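
With `implicitPrefs = True`, the purchase counts are not treated as ratings to reproduce. In the implicit-feedback formulation of Hu, Koren and Volinsky, which the Spark documentation cites for this mode, each count becomes a binary preference plus a confidence weight. The sketch below shows that transformation in plain Python. `implicit_feedback` is a hypothetical helper, and `alpha=1.0` is assumed because the cell above leaves `alpha` at its default.

```python
def implicit_feedback(count: float, alpha: float = 1.0):
    """Turn a purchase count into (preference, confidence).

    preference: 1.0 if the customer bought the article at all, else 0.0
    confidence: grows linearly with the number of purchases
    """
    preference = 1.0 if count > 0 else 0.0
    confidence = 1.0 + alpha * count
    return preference, confidence


print(implicit_feedback(1))  # (1.0, 2.0)
print(implicit_feedback(3))  # (1.0, 4.0)
print(implicit_feedback(0))  # (0.0, 1.0)
```

Under this reading, a repeat purchase does not raise the target value. It only makes the model more confident that the customer likes the article.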

##### Make Predictions

In [11]:
# Create list to make predictions on, based on cust_idx

user_subset = validation.select(col('customer_id'))
user_subset = user_subset.join(cust_idx, on = 'customer_id', how = 'inner')
user_subset = user_subset.select(col('cust_idx'))

user_subset.show(5)

+--------+
|cust_idx|
+--------+
|       9|
|      11|
|      14|
|      15|
|      18|
+--------+
only showing top 5 rows



In [27]:
# Recommend products 

pred_als = model_als.recommendForUserSubset(user_subset, 12)

pred_als.show(5)

+--------+--------------------+
|cust_idx|     recommendations|
+--------+--------------------+
|     148|[{35801, 0.127422...|
|    4935|[{51857, 0.160507...|
|    6466|[{7857, 0.0731965...|
|    6620|[{23821, 0.094929...|
|    7340|[{87396, 0.186739...|
+--------+--------------------+
only showing top 5 rows



##### Evaluate Predictions

In [16]:
# Format validation data for ALS model evaluation

val_als = spark.read.parquet("/sfs/qumulo/qhome/sar2jf/Documents/hm_fashion/Data/processed_data/val_par")

val_als = val_als.join(article_idx, on = 'article_id', how = 'inner')
val_als = val_als.join(cust_idx, on = 'customer_id', how = 'inner')


val_als = val_als.groupBy('cust_idx') \
                      .agg(collect_set('art_idx') \
                      .alias('labels'))

val_als.show(5)

+--------+--------------------+
|cust_idx|              labels|
+--------+--------------------+
|     148|             [82870]|
|    4935|[99746, 19946, 7819]|
|    6466|     [87682, 101258]|
|    6620|[11863, 97930, 20...|
|    7340|[44166, 36724, 14...|
+--------+--------------------+
only showing top 5 rows



In [32]:
# Create predictions and labels dataframe

pred_label_als = pred_als.select(col('recommendations.art_idx'), col('cust_idx'))
pred_label_als = pred_label_als.join(val_als, on = 'cust_idx', how = 'inner')

pred_label_als.show(5)

+--------+--------------------+--------------------+
|cust_idx|             art_idx|              labels|
+--------+--------------------+--------------------+
|     148|[35801, 85631, 43...|             [82870]|
|    4935|[51857, 68776, 82...|[99746, 19946, 7819]|
|    6466|[7857, 46719, 152...|     [87682, 101258]|
|    6620|[23821, 26974, 22...|[11863, 97930, 20...|
|    7340|[87396, 102680, 5...|[44166, 36724, 14...|
+--------+--------------------+--------------------+
only showing top 5 rows



In [33]:
pred_label_als = pred_label_als.select('art_idx', "labels").rdd
pred_label_als = pred_label_als.map(lambda x: (x[0], x[1]))

In [34]:
metrics = RankingMetrics(pred_label_als)

In [35]:
metrics.meanAveragePrecisionAt(12)

0.004910241737678567

<br>

#### (4) Market Basket Analysis (MBA)

In [36]:
# Read in train data from parquet

parquetFile = spark.read.parquet("/sfs/qumulo/qhome/sar2jf/Documents/hm_fashion/Data/processed_data/train_mba_par")

parquetFile.createOrReplaceTempView("train_mba")
train_mba = spark.sql("SELECT * FROM train_mba")

train_mba.show(5)


+--------------------+--------------------+
|         customer_id|               items|
+--------------------+--------------------+
|0005d926defbe130e...|[780297002, 84985...|
|000fd548002726b75...|[693242006, 47078...|
|00299609d2f3cd490...|[777504004, 53613...|
|003cac311b5886fd1...|         [726768001]|
|003dcdb1164332b8a...|[769620001, 38685...|
+--------------------+--------------------+
only showing top 5 rows



In [37]:
# Read in validation data from parquet

parquetFile = spark.read.parquet("/sfs/qumulo/qhome/sar2jf/Documents/hm_fashion/Data/processed_data/val_mba_par")
parquetFile.createOrReplaceTempView("val_mba")

val_mba = spark.sql("SELECT * FROM val_mba")

val_mba.show(5)

+--------------------+--------------------+--------------------+
|         customer_id|              labels|             history|
+--------------------+--------------------+--------------------+
|002a72483ab192e7c...|[278811009, 27881...|[740943003, 74814...|
|00b3d59d31333ebd0...|[919273002, 87152...|[699867001, 79079...|
|00c0875075599d2b4...|[715624001, 67793...|[785018005, 81659...|
|00c51f82b53ff1b4b...|         [337991002]|[763529002, 85570...|
|00c80932cf9461749...|[803969009, 86441...|[689898003, 78586...|
+--------------------+--------------------+--------------------+
only showing top 5 rows



In [44]:
# Specify Model

fpGrowth = FPGrowth(itemsCol="items", minSupport=0.01, minConfidence=0.01)

In [None]:
mba_model = fpGrowth.fit(train_mba)

In [None]:
# Show top 20 association rules by confidence

mba_rules = mba_model.associationRules.sort(col('confidence').desc())

mba_rules.show(20)
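
The support and confidence that rank these rules can be computed by hand on a toy example. This plain-Python sketch is not FPGrowth itself: the baskets are made up and `rule_stats` is a hypothetical helper that evaluates a single rule directly.

```python
def rule_stats(baskets, antecedent, consequent):
    """Return (support, confidence) for the rule antecedent -> consequent."""
    antecedent, consequent = frozenset(antecedent), frozenset(consequent)
    n_both = sum(1 for b in baskets if (antecedent | consequent) <= set(b))
    n_antecedent = sum(1 for b in baskets if antecedent <= set(b))
    support = n_both / len(baskets)  # share of baskets containing the whole rule
    confidence = n_both / n_antecedent if n_antecedent else 0.0
    return support, confidence


# Hypothetical baskets, one per customer.
baskets = [
    ["Vest top", "Bra", "Bikini top"],
    ["Vest top", "Bra"],
    ["Bikini top"],
    ["Vest top", "Bra", "Bikini top"],
]

support, confidence = rule_stats(baskets, ["Vest top", "Bra"], ["Bikini top"])
print(support)               # 0.5
print(round(confidence, 4))  # 0.6667
```

Here three baskets contain the antecedent and two of those also contain the consequent, so the confidence is 2/3. `minSupport` and `minConfidence` in the model specification above act as lower bounds on these two quantities.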

In [25]:
# View items of first rule

print("Antecedents:")

art_lookup.filter(col('prod_index_id') == '227').select(col('prod_index_id'), col('product_type_name')).show(1)
art_lookup.filter(col('prod_index_id') == '21').select(col('prod_index_id'), col('product_type_name')).show(1)
art_lookup.filter(col('prod_index_id') == '454').select(col('prod_index_id'), col('product_type_name')).show(1)

print("")
print("Consequent:")

art_lookup.filter(col('prod_index_id') == '415').select(col('prod_index_id'), col('product_type_name')).show(1)

Antecedents:
+-------------+-----------------+
|prod_index_id|product_type_name|
+-------------+-----------------+
|          227|         Vest top|
+-------------+-----------------+
only showing top 1 row

+-------------+-----------------+
|prod_index_id|product_type_name|
+-------------+-----------------+
|           21|  Swimwear bottom|
+-------------+-----------------+
only showing top 1 row

+-------------+-----------------+
|prod_index_id|product_type_name|
+-------------+-----------------+
|          454|              Bra|
+-------------+-----------------+
only showing top 1 row


Consequent:
+-------------+-----------------+
|prod_index_id|product_type_name|
+-------------+-----------------+
|          415|       Bikini top|
+-------------+-----------------+
only showing top 1 row

