In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [5]:
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col

def product_category_pairs(
    products: DataFrame,
    categories: DataFrame,
    product_categories: DataFrame
) -> DataFrame:
    """Pair every product with its categories via a many-to-many link table.

    Args:
        products: frame with at least ``product_id`` and ``product_name``.
        categories: frame with at least ``category_id`` and ``category_name``.
        product_categories: link frame with ``product_id`` and ``category_id``.

    Returns:
        A frame with the columns ``product_name`` and ``category_name``: one
        row per product/category link. Both joins are left joins, so a product
        without any link is kept with a null ``category_name``, and a link
        whose ``category_id`` is missing from ``categories`` also yields a
        null ``category_name``.
    """
    # Attach the category attributes to each link row; links are kept even
    # when no matching category exists.
    links_with_names = product_categories.join(categories, "category_id", "left")

    # Start from the products so that every product survives the join,
    # including those that have no link rows at all.
    products_with_categories = products.join(links_with_names, "product_id", "left")

    return products_with_categories.select(col("product_name"), col("category_name"))


In [6]:
# Local Spark session used for the demonstration below.
spark = (
    SparkSession.builder
    .master("local[1]")
    .appName("Test")
    .getOrCreate()
)

# Sample products; "car magazine" (id 4) has no entry in the link table.
products = spark.createDataFrame(
    [
        (1, "Apple"),
        (2, "Banana"),
        (3, "Orange"),
        (4, "car magazine"),
        (5, "cake"),
    ],
    schema=["product_id", "product_name"],
)

# Sample categories; "Electronics" (id 30) is not linked to any product.
categories = spark.createDataFrame(
    [
        (10, "Fruit"),
        (20, "Food"),
        (30, "Electronics"),
    ],
    schema=["category_id", "category_name"],
)

# Many-to-many links as (product_id, category_id) pairs.
product_categories = spark.createDataFrame(
    [
        (1, 10),
        (2, 10),
        (1, 20),
        (3, 10),
        (3, 20),
        (5, 20),
    ],
    schema=["product_id", "category_id"],
)

df = product_category_pairs(products, categories, product_categories)
df.show()


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/20 18:58:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
                                                                                

+------------+-------------+
|product_name|category_name|
+------------+-------------+
|        cake|         Food|
|       Apple|         Food|
|       Apple|        Fruit|
|      Orange|         Food|
|      Orange|        Fruit|
|      Banana|        Fruit|
|car magazine|         NULL|
+------------+-------------+



In [8]:
#test_product_categories.py
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

@pytest.fixture(scope="module")
def spark():
    """Return a SparkSession on the ``local[1]`` master, created once per test module."""
    builder = SparkSession.builder.master("local[1]").appName("test")
    # getOrCreate() hands back an already running session if there is one.
    return builder.getOrCreate()

def test_product_category_pairs(spark):
    """Check the product/category pairs produced for a small sample.

    The sample covers a product with two categories (Apple), a product with
    one category (Banana) and a product without any category (Orange), which
    is expected to appear exactly once with a null category name.
    """
    products = spark.createDataFrame(
        [(1, "Apple"), (2, "Banana"), (3, "Orange")],
        ["product_id", "product_name"]
    )

    categories = spark.createDataFrame(
        [(10, "Fruit"), (20, "Food")],
        ["category_id", "category_name"]
    )

    product_categories = spark.createDataFrame(
        [(1, 10), (2, 10), (1, 20)],
        ["product_id", "category_id"]
    )

    result = product_category_pairs(products, categories, product_categories)
    actual = [(row.product_name, row.category_name) for row in result.collect()]

    expected = {
        ("Apple", "Fruit"),
        ("Apple", "Food"),
        ("Banana", "Fruit"),
        ("Orange", None),
    }

    # Compare as sets so the check does not depend on the row order of the join.
    assert set(actual) == expected
    # A set comparison alone would hide duplicated pairs, so the number of
    # returned rows is verified as well.
    assert len(actual) == len(expected)
