<a href="https://colab.research.google.com/github/Pathairush/pytest_data_unittest/blob/main/pytest_for_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://downloads.apache.org/spark/spark-3.2.1/spark-3.2.1-bin-hadoop2.7.tgz
!tar xzf spark-3.2.1-bin-hadoop2.7.tgz
!pip install -q findspark 
!pip install -q great-expectations==0.14.13
!pip install -q ruamel-yaml
!pip install pytest -q
!pip install ipytest -q

In [2]:
import os

# Tell Spark where the Java runtime and the unpacked Spark distribution live
# (paths match the Colab install steps earlier in this notebook).
os.environ.update(
    {
        "JAVA_HOME": "/usr/lib/jvm/java-8-openjdk-amd64",
        "SPARK_HOME": "/content/spark-3.2.1-bin-hadoop2.7",
    }
)

# Imported after the environment is set, mirroring the original statement order.
import findspark

# Expected to locate the Spark install (via SPARK_HOME) and make pyspark importable.
findspark.init()

In [3]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import Window
from pyspark.sql import types as T
import pyspark.pandas as ps

# Create (or reuse, if one already exists) the Spark session used by the rest of
# the notebook: local mode, application name "colab", Spark UI port set to 4050.
spark = (
    SparkSession.builder
    .master("local")
    .appName("colab")
    .config('spark.ui.port', '4050')
    .getOrCreate()
)



In [4]:
import numpy as np
import pandas as pd
import pytest
import ipytest
# Apply ipytest's automatic configuration so pytest can be run from notebook cells
# (the tests are executed by the ipytest.run() call at the end of the notebook).
ipytest.autoconfig()

In [13]:
@pytest.fixture()
def df():
    """Input fixture: a pandas DataFrame with columns 'a', 'b', 'c', each holding [1, 2, 3].

    The frame is built as a pandas-on-Spark DataFrame and then converted to pandas.
    """
    columns = {name: [1, 2, 3] for name in ('a', 'b', 'c')}
    return ps.DataFrame(columns).to_pandas()

@pytest.fixture()
def expect_df_add_col():
    """Expected-result fixture: columns 'a', 'b', 'c' of [1, 2, 3] plus 'd' of [4, 4, 4].

    Built as a pandas-on-Spark DataFrame and converted to pandas.
    """
    expected_columns = dict(
        a=[1, 2, 3],
        b=[1, 2, 3],
        c=[1, 2, 3],
        d=[4, 4, 4],
    )
    expected_psdf = ps.DataFrame(expected_columns)
    return expected_psdf.to_pandas()
    

def add_col_to_df(df):
    """Return a copy of ``df`` with a column ``'d'`` whose every value is 4.

    The input frame is left untouched: ``DataFrame.assign`` builds a new frame,
    so callers (and fixtures shared between tests) never observe a side effect.
    An existing ``'d'`` column is overwritten in the returned copy.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to extend.

    Returns
    -------
    pandas.DataFrame
        New frame with the original columns plus ``'d'``.
    """
    return df.assign(d=4)

def test_add_col_to_df(df, expect_df_add_col):
    """Check that add_col_to_df turns the input fixture into the expected frame."""
    pd.testing.assert_frame_equal(add_col_to_df(df), expect_df_add_col)


# Reference
- [Writing tests for RESTful APIs in Python](https://www.ontestautomation.com/writing-tests-for-restful-apis-in-python-using-requests-part-2-data-driven-tests/#:~:text=Pytest%20supports%20data%20driven%20testing,the%20test%20data%20object%20itself.)
- [Writing Better Code Through Testing](https://towardsdatascience.com/writing-better-code-through-testing-f3150abec6ca)
- [Testing Best Practices for Machine Learning Libraries](https://towardsdatascience.com/testing-best-practices-for-machine-learning-libraries-41b7d0362c95)
- [pandas.testing.assert_frame_equal](https://pandas.pydata.org/docs/reference/api/pandas.testing.assert_frame_equal.html)

In [35]:
@pytest.fixture()
def fund_portfolio_psdf():
    """Input fixture: a pandas-on-Spark DataFrame of one customer's holdings in four funds.

    Columns are 'customer_number', 'fund_code' and 'available_amount'.
    """
    fund_codes = ['A', 'B', 'C', 'D']
    holdings = {
        'customer_number': ['110XXXXXXXXXX'] * len(fund_codes),
        'fund_code': fund_codes,
        'available_amount': [300, 200, 100, 150],
    }
    return ps.DataFrame(holdings)

@pytest.fixture()
def fact_customer_products():
    """Expected-result fixture: one row per customer with the list of selected fund codes.

    Built as a pandas-on-Spark DataFrame and converted to pandas.
    """
    expected_psdf = ps.DataFrame(
        {
            'customer_number': ['110XXXXXXXXXX'],
            'product_fund_list': [['A', 'B', 'D']],
        }
    )
    return expected_psdf.to_pandas()

def get_fact_customer_products(psdf, top_n=3):
    """Collect, per customer, the fund codes of their highest-value holdings.

    Each customer's rows are ranked by ``available_amount`` in descending order,
    rows with rank <= ``top_n`` are kept, and the distinct ``fund_code`` values
    are gathered into a sorted array.

    Parameters
    ----------
    psdf : pyspark.pandas.DataFrame
        Frame with ``customer_number``, ``fund_code`` and ``available_amount`` columns.
    top_n : int, default 3
        Highest rank to keep. The default reproduces the previous hard-coded cut-off.

    Returns
    -------
    pandas.DataFrame
        One row per ``customer_number`` with a ``product_fund_list`` column.

    Notes
    -----
    ``F.rank()`` gives tied amounts the same rank, so a customer with ties at the
    cut-off can end up with more than ``top_n`` funds in the list.
    """
    # Rank within each customer, largest available amount first.
    by_amount_desc = (
        Window
        .partitionBy('customer_number')
        .orderBy(F.desc('available_amount'))
    )
    return (
        psdf
        .to_spark()
        .withColumn('rank', F.rank().over(by_amount_desc))
        .filter(F.col('rank') <= top_n)
        .groupby('customer_number')
        # collect_set de-duplicates fund codes; array_sort makes the order deterministic.
        .agg(F.array_sort(F.collect_set('fund_code')).alias('product_fund_list'))
        .toPandas()
    )

def test_get_fact_customer_products(fund_portfolio_psdf, fact_customer_products):
    """Check that the portfolio fixture aggregates to the expected per-customer fund list."""
    result = get_fact_customer_products(fund_portfolio_psdf)
    pd.testing.assert_frame_equal(result, fact_customer_products)

In [36]:
# Run pytest on the tests defined in this notebook; the exit code is the cell's output.
ipytest.run()

[32m.[0m[32m.[0m[32m                                                                                           [100%][0m
[32m[32m[1m2 passed[0m[32m in 0.88s[0m[0m


<ExitCode.OK: 0>