# Set up packages and dataframes

In [1]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

In [2]:
import os
from nltk.stem.snowball import SnowballStemmer
import re
import regex

In [3]:
DATA_FOLDER = 'data/'

transactions = spark.read.options(    
        header=True,  
        inferSchema=True
    ).csv(
        os.path.join(DATA_FOLDER, 'sales_train.csv'), 
    )

items = spark.read.options(    
        header=True,  
        inferSchema=True
    ).csv(
        os.path.join(DATA_FOLDER, 'items.csv'), 
    )

item_categories = spark.read.options(    
        header=True,  
        inferSchema=True
    ).csv(
        os.path.join(DATA_FOLDER, 'item_categories.csv'), 
    )

shops = spark.read.options(    
        header=True,  
        inferSchema=True
    ).csv(
        os.path.join(DATA_FOLDER, 'shops.csv'), 
    )

test = spark.read.options(    
        header=True,  
        inferSchema=True
    ).csv(
        os.path.join(DATA_FOLDER, 'test.csv'), 
    )

# EDA

## Look at dataframes

Print the top 10 rows and the total number of rows.

In [4]:
print('Total number of rows: {}'.format(transactions.count()))
transactions.show(10)

print('Total number of rows: {}'.format(items.count()))
items.show(10)

print('Total number of rows: {}'.format(item_categories.count()))
item_categories.show(10)

print('Total number of rows: {}'.format(shops.count()))
shops.show(10)

print('Total number of rows: {}'.format(test.count()))
test.show(10)

Total number of rows: 2935849
+----------+--------------+-------+-------+----------+------------+
|      date|date_block_num|shop_id|item_id|item_price|item_cnt_day|
+----------+--------------+-------+-------+----------+------------+
|02.01.2013|             0|     59|  22154|     999.0|         1.0|
|03.01.2013|             0|     25|   2552|     899.0|         1.0|
|05.01.2013|             0|     25|   2552|     899.0|        -1.0|
|06.01.2013|             0|     25|   2554|   1709.05|         1.0|
|15.01.2013|             0|     25|   2555|    1099.0|         1.0|
|10.01.2013|             0|     25|   2564|     349.0|         1.0|
|02.01.2013|             0|     25|   2565|     549.0|         1.0|
|04.01.2013|             0|     25|   2572|     239.0|         1.0|
|11.01.2013|             0|     25|   2572|     299.0|         1.0|
|03.01.2013|             0|     25|   2573|     299.0|         3.0|
+----------+--------------+-------+-------+----------+------------+
only showing top 1

We see that in the training dataset, there are 2935849 transactions, 22170 items, 84 item categories and 60 shops.

# Feature Extraction

## Join all data onto the transactions dataframe 

Upon inspection, one sees that all features appearing in the training dataframes above can be joined onto the transactions dataframe using the appropriate ID.

In [5]:
transactions.createOrReplaceTempView('transactions')
items.createOrReplaceTempView('items')

transactions = spark.sql(('SELECT transactions.*, items.item_name, items.item_category_id '
                  ' FROM transactions '
                  ' LEFT JOIN items '
                  '  ON transactions.item_id = items.item_id '))

In [6]:
transactions.createOrReplaceTempView('transactions')
item_categories.createOrReplaceTempView('item_categories')

transactions = spark.sql(('SELECT transactions.*, item_categories.item_category_name '
                  ' FROM transactions '
                  ' LEFT JOIN item_categories '
                  '  ON transactions.item_category_id = item_categories.item_category_id '))

In [7]:
transactions.createOrReplaceTempView('transactions')
shops.createOrReplaceTempView('shops')

transactions = spark.sql(('SELECT transactions.*, shops.shop_name '
                  ' FROM transactions '
                  ' LEFT JOIN shops '
                  '  ON transactions.shop_id = shops.shop_id '))

## Extract the day, month and year from the date.

In [8]:
transactions = transactions.withColumn('date', F.to_date(transactions.date, format='dd.MM.yyyy'))

In [9]:
transactions = transactions.withColumn('day', F.dayofyear(transactions.date))
transactions = transactions.withColumn('month', F.month(transactions.date))
transactions = transactions.withColumn('year', F.year(transactions.date))

## Extract text-based features

Define a stemmer that can handle both Russian and English text using nltk's Snowball Stemmer.

In [10]:
en_stemmer = SnowballStemmer('english')
ru_stemmer = SnowballStemmer('russian')

In [11]:
def clean_text(text):
    """ Removes punctuation from string, unwanted unicode characters, and numbers. Returns in lowercase.
    
    Args:
        text (str): The text to clean.
    
    Returns:
        The cleaned text after filtered by the regex expression and made lowercase.
    
    For more information on the unicode categories used in the regex expression see here:
    https://www.regular-expressions.info/unicode.html#category
    
    >>> clean_text("!$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ Can't, - Trademark™ ...「（Punctuation）」42.32 ?")
    cant trademark punctuation
    
    """
    # remove URLs
    text = re.sub(r"http\S+", "", text)
    # remove apostrophes 
    text = text.replace("'", "")
    
    # Define regex unicode Categories and strip from string
    remove = regex.compile('[\p{C}|\p{M}|\p{P}|\p{S}|\p{Z}|\p{N}]+', regex.UNICODE)
    text = remove.sub(" ", text).strip()
    
    # make lowercase
    text = text.lower()
    
    return text

def stemmer(text):
    """Identify the words written in Cyrillic and Latin characters in a string,
    and apply a Russian or English stemmer, respectively.
    
    Args:
        text(str): The string whose Cyrillic and Latin text will be stemmed.
    
    Returns:
        A stemmed version of the text.
    """
    cyr_regex = regex.compile('\p{Cyrillic}+', regex.UNICODE)
    lat_regex = regex.compile('\p{Latin}+', regex.UNICODE)
    
    text = clean_text(text)

    words = re.split('\s', text)
    stemmed_text = ''
    for word in words:
        ru = regex.search(cyr_regex, word)
        en = regex.search(lat_regex, word)
        if ru:
            stemmed_word = ru_stemmer.stem(word)
        elif en:
            stemmed_word = en_stemmer.stem(word)
        else:
            stemmed_word = word
        stemmed_text = stemmed_text + ' ' + stemmed_word
    # remove leading space from stemmed_text
    stemmed_text = stemmed_text[1:]
    
    return stemmed_text

Demonstrate function on sample text from the dataset.

In [12]:
text = '(Кино) - Blu-Ray'

print(stemmer(text))

кин blu ray


Apply stemmer to columns containing text.

In [13]:
from pyspark.sql.types import StringType

udf_stemmer = F.udf(stemmer, StringType())

In [14]:
transactions = transactions.withColumn('stemmed_item_name', udf_stemmer(transactions.item_name))
transactions = transactions.withColumn('stemmed_item_category_name', udf_stemmer(transactions.item_category_name))
transactions = transactions.withColumn('stemmed_shop_name', udf_stemmer(transactions.shop_name))

## Display the resulting dataframe.

In [15]:
transactions.limit(10).toPandas()

PythonException: 
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "C:\spark\spark-3.1.1-bin-hadoop2.7\python\lib\pyspark.zip\pyspark\worker.py", line 588, in main
  File "C:\spark\spark-3.1.1-bin-hadoop2.7\python\lib\pyspark.zip\pyspark\worker.py", line 447, in read_udfs
  File "C:\spark\spark-3.1.1-bin-hadoop2.7\python\lib\pyspark.zip\pyspark\worker.py", line 249, in read_single_udf
  File "C:\spark\spark-3.1.1-bin-hadoop2.7\python\lib\pyspark.zip\pyspark\worker.py", line 69, in read_command
  File "C:\spark\spark-3.1.1-bin-hadoop2.7\python\lib\pyspark.zip\pyspark\serializers.py", line 160, in _read_with_length
    return self.loads(obj)
  File "C:\spark\spark-3.1.1-bin-hadoop2.7\python\lib\pyspark.zip\pyspark\serializers.py", line 430, in loads
    return pickle.loads(obj, encoding=encoding)
  File "C:\spark\spark-3.1.1-bin-hadoop2.7\python\lib\pyspark.zip\pyspark\cloudpickle\cloudpickle.py", line 562, in subimport
    __import__(name)
ModuleNotFoundError: No module named 'regex'
