# Local PySpark on SageMaker Studio

This notebook shows how to run local PySpark code within a SageMaker Studio notebook. For this example we use the **Data Science - Python3** image and kernel, but this methodology should work for any kernel within SM Studio, including BYO custom images.

## Setup
We need an available Java installation to run pyspark. The easiest way to do this is to install JDK and set the proper paths using conda

In [2]:
# Setup - Run only once per Kernel App
%conda update -n base -c defaults conda -q
%conda install openjdk -q -y

Retrieving notices: ...working... done
Channels:
 - defaults
 - conda-forge
Platform: linux-64
Collecting package metadata (repodata.json): ...working... done
Solving environment: ...working... done

## Package Plan ##

  environment location: /opt/conda

  added / updated specs:
    - conda


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    ca-certificates-2024.7.2   |       h06a4308_0         127 KB
    certifi-2024.7.4           |  py310h06a4308_0         158 KB
    conda-23.11.0              |  py310h06a4308_0         997 KB
    ------------------------------------------------------------
                                           Total:         1.3 MB

The following packages will be UPDATED:

  ca-certificates    conda-forge::ca-certificates-2024.2.2~ --> pkgs/main::ca-certificates-2024.7.2-h06a4308_0 
  certifi            conda-forge/noarch::certifi-2024.2.2-~ --> pkgs/main/linux-64

## Install PySpark

In [2]:
%pip install pyspark==3.5.0 
%pip install bokeh==2.4.0
%pip install pandas==1.5.1 matplotlib==3.5.2 
%pip install psycopg2-binary==2.9.9 
%pip install scikit-learn==1.0.2 statsmodels==0.13.2 scipy==1.7.3 
%pip install symbulate==0.5.7 seaborn==0.11.2 
%pip install tenacity==8.0.1 SQLAlchemy==2.0.23
%pip install xgboost==2.0.2 pyarrow==14.0.1 
%pip install asyncio==3.4.3 nest-asyncio==1.5.8 aiohttp==3.9.1 
%pip install boto3==1.21.33 botocore
%pip install fsspec==2023.12.1 fastparquet==2023.10.1
%pip install watchtower
%pip install polars==1.3.0

[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
Collecting polars==1.3.0
  Using cached polars-1.3.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (14 kB)
Using cached polars-1.3.0-cp38-ab

## Utilize S3 Data within local PySpark
* By specifying the `hadoop-aws` jar in our Spark config we're able to access S3 datasets using the s3a file prefix. 
* Since we've already authenticated ourself to SageMaker Studio , we can use our assumed SageMaker ExecutionRole for any S3 reads/writes by setting the credential provider as `ContainerCredentialsProvider`

In [2]:
# Import pyspark and build Spark session
from pyspark.sql import SparkSession
from libs.config.config_vars import CONFIG, ENVIRONMENT, DME_PROJECT, S3_PREFIX, ENVIRONMENT_VARS, S3_DATA_PREFIX, \
    S3_BUCKET

env = ENVIRONMENT_VARS

spark = (
    SparkSession.builder.appName("PySparkApp")
    .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.4.0")
    .config(
        "fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.ContainerCredentialsProvider",
    )
    .config('spark.driver.extraJavaOptions',
          '-Dio.netty.tryReflectionSetAccessible=true')
    .config('spark.executor.extraJavaOptions',
          '-Dio.netty.tryReflectionSetAccessible=true')
    .getOrCreate()
)

spark.sparkContext.setLogLevel("ERROR")

print(spark.version)



:: loading settings :: url = jar:file:/opt/conda/lib/python3.10/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /root/.ivy2/cache
The jars for the packages stored in: /root/.ivy2/jars
org.apache.hadoop#hadoop-aws added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-9ee0ff5b-1e5b-41d8-89f9-e7fdae6b40f2;1.0
	confs: [default]
	found org.apache.hadoop#hadoop-aws;3.4.0 in central
	found software.amazon.awssdk#bundle;2.23.19 in central
	found org.wildfly.openssl#wildfly-openssl;1.1.3.Final in central
:: resolution report :: resolve 458ms :: artifacts dl 34ms
	:: modules in use:
	org.apache.hadoop#hadoop-aws;3.4.0 from central in [default]
	org.wildfly.openssl#wildfly-openssl;1.1.3.Final from central in [default]
	software.amazon.awssdk#bundle;2.23.19 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	--------------------------------------------------------

3.5.0


In [3]:
import sys
import shutil
import os
import time
import boto3
import pandas as pd
import polars as pl
from __future__ import annotations
from libs.denodo.denodo_connection import DenodoConnection
from libs.dme_sql_queries import *
from libs.dme_pyspark_sql_queries import *
from libs.metric_utils import *
from libs.regression_utils import reg_adjust_parallel_rm_pyspark

### testing new breakout process

ap_data_sector = "CORN_NA_SUMMER"
analysis_run_group = "genomic_prediction"
analysis_year = 2023
target_pipeline_runid = "20240723_00_00_00"
force_refresh = "True"
breakout_level = "market_segment"
#current_source_ids = "('224WNMTYG501')"
# current_source_ids = "('2024_TPP11_PLC4_140_3_SYNR')"

### init- get current source ID's
t0 = time.time()
current_source_ids = get_source_ids(ap_data_sector, analysis_year, analysis_run_group, target_pipeline_runid, force_refresh)
print("source_id call time: {0}".format((time.time() - t0)))
print(current_source_ids)

# spark.sparkContext.getConf().getAll()





            SELECT DISTINCT pvs.source_id, pvs.pipeline_runid
                FROM(
                    SELECT source_id, trait, MAX(pipeline_runid) AS pipeline_runid
                    FROM "managed"."rv_ap_all_pvs"
                    WHERE ap_data_sector = 'CORN_NA_SUMMER'
                    AND CAST(source_year as integer) IN (2023)
                    AND analysis_type in ('GenoPred')
                    AND LOWER(loc) = 'all'
                    GROUP BY source_id, trait
                ) pvs
            INNER JOIN(
                SELECT decision_group
                    FROM "managed"."rv_ap_sector_experiment_config"
                WHERE ap_data_sector_name = 'CORN_NA_SUMMER' 
                    AND CAST(analysis_year as integer) IN (2023)
                    AND adapt_display = 1
            ) asec
                ON asec.decision_group = pvs.source_id
            INNER JOIN(
                SELECT trait 
                    FROM "managed"."rv_ap_sector_trait_config"
   

  df = pd.read_sql_query(sql, self.__denodo_con)


source_id call time: 2.6266045570373535
('23SUZYYG461', '2023_W_ALL_STG3_105_2_SNYR', '2023_N_ALL_STG2_100_1_SNYR', '2023_DB_ALL_STG4_95_1_SNYR', '23SUWEYG36Z', '23SUNEYG30A', '2023_N_ALL_STG3_95_1_SNYR', '2023_DB_ALL_STG2_90_1_SNYR', '2023_N_ALL_STG3_100_2_SNYR', '2023_W_ALL_STG2_105_3_SNYR', '23SUWEYG35Z', '2023_W_ZY_STG3_115_2_SNYR', '2023_N_ZY_STG3_95_2_SNYR', '2023_DB_ALL_STG1_105_1_SNYR', '23SUDBYG42A', '2023_DB_ALL_STG4_105_1_SNYR', '2023_DB_ALL_STG2_95_2_SNYR', '2023_W_ALL_STG4_115_1_SNYR', '23SUZYYG430', '23SUDBYG41A', '23SUWEYG34Z', '23SUNEYG30B', '2023_W_ALL_STG2_105_1_SNYR', '2023_DB_ALL_STG2_115_2_SNYR', '2023_E_ALL_STG3_115_1_SNYR', '2023_E_ALL_STG2_115_2_SNYR', '2023_DB_ALL_STG2_90_2_SNYR', '2023_DB_ALL_STG2_85_1_SNYR', '2023_W_ALL_STG4_110_1_SNYR', '2023_E_ALL_STG2_120_2_SNYR', '2023_E_ALL_STG3_120_1_SNYR', '2023_E_ALL_STG3_110_2_SNYR', '2023_E_ALL_STG3_105_1_SNYR', '2023_E_ALL_STG2_105_1_SNYR', '2023_DB_ALL_STG2_100_2_SNYR', '23SUZYYG311', '2023_N_ALL_STG4_80_1_SNYR', 

In [4]:
### get checks
data_sector_config = get_data_sector_config(ap_data_sector)

checks_df = query_check_entries(ap_data_sector,
                analysis_year)
print("initial number of rows in checks df: {0}".format(checks_df.shape[0]))
print("number of unique source id's in checks df: {0}".format(checks_df.get_column("decision_group").unique().shape[0]))
print("number of unique entry id's in checks df: {0}".format(checks_df.get_column("be_bid").unique().shape[0]))
print(checks_df.head(5))

checks_df = create_check_df(analysis_run_group, checks_df)

print("number of rows in final checks df: {0}".format(checks_df.shape[0]))
print("number of uniques in final checks df: {0}".format(checks_df.select("ap_data_sector", "analysis_year", "decision_group", "be_bid", "material_type").unique().shape[0]))
print("number of unique source id's in final checks df: {0}".format(checks_df.get_column("decision_group").unique().shape[0]))
print("number of unique entry id's in final checks df: {0}".format(checks_df.get_column("be_bid").unique().shape[0]))

# checks_df["par1_be_bid"] = checks_df.par1_be_bid.fillna('')
# checks_df["par2_be_bid"] = checks_df.par2_be_bid.fillna('')

# for mat_type in checks_df["material_type"].drop_duplicates():
#     print(mat_type)
#     print(checks_df.loc[checks_df.material_type == mat_type, ["cpifl", "cperf", "cagrf", "cmatf"]].describe())
    
# print(checks_df.loc[(checks_df.decision_group == '24WNMTYG501') & (checks_df.cperf == 1), :].head(5))

checks_df.write_csv("notebook_output/checks_output.csv")


                        SELECT 
                            ap_data_sector_name,
                            spirit_crop_guid
                          FROM "managed"."rv_ap_data_sector_config"
                        WHERE "ap_data_sector_name" = 'CORN_NA_SUMMER'


  df = pd.read_sql_query(sql, self.__denodo_con)



    SELECT DISTINCT
        "checks"."ap_data_sector" AS "ap_data_sector",
        "checks"."analysis_year" AS "analysis_year",
        "checks"."decision_group" AS "decision_group",
        "checks"."untested_entry_display" AS "untested_entry_display",
        "checks"."be_bid" AS "be_bid",
        CAST(MAX(checks."cpifl",
             "checks"."cperf", 
             "checks"."cagrf", 
             "checks"."cmatf",
             "checks"."cregf",
             "checks"."crtnf") AS boolean) AS "cpifl",
        "checks"."cperf" AS "cperf",
        "checks"."cagrf" AS "cagrf",
        "checks"."cmatf" AS "cmatf",
        "checks"."cregf" AS "cregf",
        "checks"."crtnf" AS "crtnf",
        CASE 
            WHEN "cmt"."fp_het_pool" = 'pool1' 
                THEN checks.fp_ltb
            WHEN cmt.mp_het_pool = 'pool1'
                THEN checks.mp_ltb
            WHEN  "cmt"."tester_role" = 'M' 
                THEN CAST(MAX("checks"."cpifl",
                 "checks"."cperf", 
   

KeyboardInterrupt: 

In [12]:
#pvs pipeline
analysis_type = get_analysis_types(analysis_run_group)
n_partitions = current_source_ids.count(",")+1

pvs_input_df = query_pvs_input(ap_data_sector, analysis_year, analysis_run_group, current_source_ids, breakout_level)
print("number of rows in pvs: {0}".format(pvs_input_df.shape[0]))
print("number of unique source id's in pvs: {0}".format(pvs_input_df.get_column("decision_group").unique().shape[0]))
print("number of unique entry id's in pvs: {0}".format(pvs_input_df.get_column("be_bid").unique().shape[0]))
print(pvs_input_df.select(["prediction", "stderr", "count"]).describe())
pvs_input_df.write_csv('notebook_output/pvs_input.csv')

metric_config_sp_df = spark.read.csv('dme_core_pipeline/data/metric_config.csv', inferSchema=True, header=True)
regression_config_sp_df = spark.read.csv('dme_core_pipeline/data/regression_cfg.csv', inferSchema=True, header=True)
regression_config_sp_df = regression_config_sp_df.filter("analysis_year = {0} and ap_data_sector = '{1}' and analysis_type IN {2}".format(analysis_year, ap_data_sector,analysis_type))

pvs_input_sp_df = spark.createDataFrame(pvs_input_df.to_pandas()).repartition(n_partitions, 'decision_group', 'trait')
checks_sp_df = spark.createDataFrame(checks_df.to_pandas())

pvs_input_sp_df.createOrReplaceTempView('pvs_input')
checks_sp_df.createOrReplaceTempView('cpifl_table')
metric_config_sp_df.createOrReplaceTempView('metric_cfg')

missing_checks_df = pvs_input_sp_df.join(checks_sp_df, ['ap_data_sector', 'analysis_year', 'decision_group', 'be_bid'], 'left')
print("number of entries missing check info: {0}".format(missing_checks_df.filter('cpifl IS NULL').count()))

print("pvs_input: number of rows: {0}".format(pvs_input_sp_df.count()))
print("number of unique source id's in pvs: {0}".format(pvs_input_sp_df.select("decision_group").distinct().count()))
print("number of unique entry id's in pvs: {0}".format(pvs_input_sp_df.select("be_bid").distinct().count()))
pvs_input_sp_df.select("breakout_level", "breakout_level_value").distinct().show()

# Set recipe variables
alpha = 0.3

gr_cols = ['ap_data_sector', 'analysis_year', 'analysis_type', 'decision_group', 'material_type',
           'breakout_level', 'breakout_level_value', 'x', 'y']

gr_cols2 = ['ap_data_sector', 'analysis_year', 'analysis_type', 'decision_group', 
            'material_type', 'breakout_level', 'breakout_level_value', 'trait']

id_cols = ['be_bid', 'count', 'prediction', 'stddev', 'chkfl']

pvs_input_sp_df.createOrReplaceTempView('pvs_input')

if regression_config_sp_df.count() > 0:
    regression_config_sp_df.createOrReplaceTempView('regression_cfg')
    regression_input = merge_pvs_regression_input(spark)
    regression_input = regression_input.pandas_api()

    pvs_regression_output_df = regression_input.groupby(gr_cols).apply(reg_adjust_parallel_rm_pyspark,
                                                                       alpha=alpha)
    if pvs_regression_output_df.shape[0] > 0:
        pvs_regression_output_df = pvs_regression_output_df.loc[
            pvs_regression_output_df.adjusted == 'Yes'].to_spark()
        pvs_regression_output_df.createOrReplaceTempView('pvs_reg_output')
        pvs_metric_raw_df = merge_pvs_cpifl_regression(spark)
        spark.catalog.dropTempView('pvs_reg_output')
        
        pvs_regression_output_df.printSchema()
    else:
        pvs_metric_raw_df = merge_pvs_cpifl(spark)
else:
    pvs_metric_raw_df = merge_pvs_cpifl(spark)


print("number of rows in pvs after merging in checks and reg: {0}".format(pvs_metric_raw_df.count()))
print("number of unique source id's in pvs x check: {0}".format(pvs_metric_raw_df.select("decision_group").distinct().count()))
print("number of unique entry id's in pvs x check: {0}".format(pvs_metric_raw_df.select("be_bid").distinct().count()))
pvs_metric_raw_df.select("count", "prediction", "stddev", "cpifl", "chkfl").summary().show()

    
# Compute recipe outputs
pvs_metric_raw_df.createOrReplaceTempView('pvs_metric_raw')
spark.catalog.dropTempView('regression_cfg')
spark.catalog.dropTempView('pvs_input')

pvs_df = merge_pvs_config(spark, pvs_metric_raw_df, gr_cols2)
spark.catalog.dropTempView('pvs_metric_raw')

print("pvs shape after merging in metric_config: {0}".format(pvs_df.count()))
print("number of unique source id's in pvs x metric: {0}".format(pvs_df.select("decision_group").distinct().count()))
print("number of unique entry id's in pvs x metric: {0}".format(pvs_df.select("be_bid").distinct().count()))
pvs_df.select("breakout_level", "breakout_level_value").distinct().show()
pvs_df.select("count", "prediction", "stddev", "cpifl", "chkfl").summary().show()
pvs_df.repartition(1).write.csv('notebook_output/pvs_intermediate_output.csv', mode='overwrite', header = True)

if pvs_df.count() == 0:
    pvs_df = create_empty_out()
else:
    # pvs_df = pvs_df.pandas_api().pandas_on_spark.apply_batch(run_pvs_metrics, id_cols = id_cols, gr_cols2 = gr_cols2)
    pvs_df = pvs_df.toPandas()
    pvs_df = run_pvs_metrics(pvs_df, id_cols, gr_cols2)
    
#print("pvs shape after running metrics: {0}".format(pvs_df.shape))
#pvs_df.spark.repartition(1).to_csv('notebook_output/pvs_output.csv')
pvs_df.to_csv('notebook_output/pvs_output.csv')

#pvs_df.info()
#missing_pvs_df = pvs_input_sp_df.join(pvs_df.to_spark(), ['ap_data_sector', 'analysis_type', 'analysis_year', 'decision_group', 'be_bid', 'trait'], 'left')
# print("number of entries missing metrics info: {0}".format(missing_pvs_df.filter('abs_mean_prob IS NULL').count()))
# missing_pvs_df.filter('abs_mean_prob IS NULL').show()



    SELECT 
        pvs.ap_data_sector,
        pvs.analysis_type,
        pvs.source_year AS analysis_year,
        pvs.source_id AS decision_group,
        dg_rm_list.stage,
        dg_rm_list.decision_group_rm,
        pvs.market_seg,
        COALESCE( pvs.loc, 'all') AS pvs_loc,
        pvs.trait,
        pvs.entry_identifier AS be_bid,
        COALESCE(LOWER(pvs.material_type), 'entry') AS material_type,
        COALESCE(CAST(pvs.count AS integer), 0) AS count,
        CASE WHEN LOWER(pvs.material_type) IN ('entry', 'hybrid')
            THEN MAX(0.0, pvs.prediction)
            ELSE pvs.prediction
        END AS prediction,
        CASE 
            WHEN pvs.stderr IS NULL 
                AND LOWER(pvs.material_type) = 'entry' 
                AND LOWER(trait_cfg.distribution_type) LIKE 'norm%'
                THEN pvs.prediction*10
            WHEN pvs.stderr > 1000 
                AND pvs.prediction < 500
                THEN MIN(pvs.stderr, pvs.prediction*10) 
            E

                                                                                

number of entries missing check info: 0


                                                                                

pvs_input: number of rows: 40265


                                                                                

number of unique source id's in pvs: 1
number of unique entry id's in pvs: 722


                                                                                

+--------------+--------------------+
|breakout_level|breakout_level_value|
+--------------+--------------------+
|          meso|                sumt|
|          meso|                cnms|
|            na|                 all|
|          meso|                sogo|
|          meso|                padf|
|          meso|                trmg|
+--------------+--------------------+



                                                                                

number of rows in pvs after merging in checks and reg: 40265
number of unique source id's in pvs x check: 1
number of unique entry id's in pvs x check: 722


                                                                                

+-------+-----------------+------------------+-------------------+--------------------+
|summary|            count|        prediction|              cpifl|               chkfl|
+-------+-----------------+------------------+-------------------+--------------------+
|  count|            40265|             40265|              40265|               40265|
|   mean|6.774071774493978| 708.4817157792082|0.05051533589966472|0.003849497081832...|
| stddev|17.69752545935651|2595.6569111996187|0.21900851117007597| 0.06192554959334477|
|    min|                1|               0.0|                  0|                   0|
|    25%|                2| 0.423253951134534|                  0|                   0|
|    50%|                3|  2.76522379226292|                  0|                   0|
|    75%|                7|           18.1702|                  0|                   0|
|    max|              842|        13751.7173|                  1|                   1|
+-------+-----------------+-----

                                                                                

pvs shape after merging in metric_config: 40265


                                                                                

number of unique source id's in pvs x metric: 1
number of unique entry id's in pvs x metric: 722


                                                                                

+--------------+--------------------+
|breakout_level|breakout_level_value|
+--------------+--------------------+
|          meso|                sumt|
|          meso|                cnms|
|            na|                 all|
|          meso|                sogo|
|          meso|                padf|
|          meso|                trmg|
+--------------+--------------------+



                                                                                

+-------+------------------+------------------+------------------+-------------------+--------------------+
|summary|             count|        prediction|            stddev|              cpifl|               chkfl|
+-------+------------------+------------------+------------------+-------------------+--------------------+
|  count|             40265|             40265|             38895|              40265|               40265|
|   mean| 6.774071774493978| 708.4817157791889|               NaN|0.05051533589966472|0.003849497081832...|
| stddev|17.697525459356626|2595.6569111996178|               NaN| 0.2190085111700764| 0.06192554959334482|
|    min|                 1|               0.0|            1.0E-4|                  0|                   0|
|    25%|                 2| 0.423253951134534|0.3107563312900829|                  0|                   0|
|    50%|                 3|  2.76522379226292|1.1879762203007265|                  0|                   0|
|    75%|                 7|

                                                                                

In [6]:
## stability (trial data) pipeline
from pyspark import StorageLevel

alpha = 0.3

gr_cols = ['ap_data_sector', 'analysis_year', 'trial_id', 'x', 'y']
cols = ['ap_data_sector', 'analysis_year', 'trial_id', 'be_bid', 'year', 'experiment_id', 'x', 'y',
        'function', 'plot_barcode',
        'trait', 'prediction_x', 'prediction', 'analysis_target_y', 'trial_pts', 'analysis_pts',
        'adjusted_prediction', 'adj_model',
        'adj_outlier', 'p_value', 'slope1', 'slope2', 'intercept', 'residual', 'adjusted']

col_partitions = ['ap_data_sector', 'analysis_year', 'decision_group', 'breakout_level',
                  'breakout_level_value']

metric_input_cols = ['ap_data_sector', 'analysis_year', 'analysis_type', 'decision_group', 'be_bid',
                     'material_type', 'breakout_level', 'breakout_level_value',
                     'trial_id', 'trait', 'result_numeric_value', 'metric_name',
                     'cpifl', 'chkfl', 'distribution_type', 'direction', 'threshold_factor',
                     'spread_factor', 'weight', 'adv_weight']

gr_cols2 = ['ap_data_sector', 'analysis_year', 'analysis_type', 'decision_group', 'be_bid',
            'material_type', 'breakout_level', 'breakout_level_value',
            'trait', 'metric_name', 'distribution_type', 'direction']

t0 = time.time()
trial_numeric_input_df = query_trial_input(ap_data_sector,
                            analysis_year,
                            analysis_run_group,
                            current_source_ids,
                            breakout_level,
                            'numeric')

trial_numeric_input_df.write_csv("notebook_output/trial_numeric.csv")
print("trial data call time: {0}".format((time.time() - t0)))

print('compute_trial_comparison_metric_input: trial_numeric_input_df data count={0}'.format(trial_numeric_input_df.shape[0]))
print("compute_trial_comparison_metric_input: trial_numeric_input_df data unique source_id's: {0}".format(trial_numeric_input_df.get_column("experiment_id").unique().shape[0]))
print("compute_trial_comparison_metric_input: trial_numeric_input_df data unique entry_id's: {0}".format(trial_numeric_input_df.get_column("be_bid").unique().shape[0]))
print(trial_numeric_input_df.head())

n_partitions = (current_source_ids.count(",")+1)*10

trial_data_sp_df = spark.createDataFrame(trial_numeric_input_df.to_pandas()).repartition(n_partitions, 'decision_group', 'trait', 'breakout_level', 'breakout_level_value')

## Bypass regression

print('compute_trial_comparison_metric_input: trial_data_sp_df data count={0}'.format(trial_data_sp_df.count()))
trial_data_sp_df.show(n=5)

trial_data_sp_df.createOrReplaceTempView('tr_data1')

# Use cpifl table to get parentage, and then create trial pheno data for parents and append to entry-level data
trial_check_sp_df = merge_trial_cpifl(spark, 'numeric')
print('compute_trial_comparison_metric_input: trial_check_sp_df count={0}'.format(trial_check_sp_df.count()))

trial_window = Window.partitionBy('trial_id')
trial_check_sp_df = trial_check_sp_df.withColumn('mincpi',
                                                 F.min('cpifl').over(trial_window)) \
    .where(0 == F.col('mincpi')) \
    .drop('mincpi')

trial_check_sp_df.repartition(n_partitions, col_partitions).createOrReplaceTempView(
    'trc_data3')

print('compute_trial_comparison_metric_input_df: After partitionBy trial_check_sp_df count={0}'.format(trial_check_sp_df.count()))


# Apply metric config
trial_check_met_sp_df = merge_trial_config(spark, 'numeric')
print('compute_trial_comparison_metric_input_df: After merge_trial_pheno_config count={0}'.format(trial_check_met_sp_df.count()))

trial_check_met_sp_df.createOrReplaceTempView('trial_pheno_metric_input')

trial_check_met_sp_df = trial_check_met_sp_df.filter(
    "distribution_type != 'rating'").persist(StorageLevel.MEMORY_AND_DISK)

# Create metric_type = threshold/rating input
# Rating metric input
trial_rating_metric_input_df2 = prepare_rating_metric_input(trial_check_met_sp_df,
                                                            metric_input_cols, gr_cols2)
print('compute_trial_comparison_metric_input_df: After prepare_rating_metric_input count={0}'.format(trial_rating_metric_input_df2.count()))
# Create metric_type = pct_check output
# create h2h structure
h2h_input = merge_trial_h2h(spark, trial_check_met_sp_df)
print('compute_trial_comparison_metric_input_df: After merge_h2h_placement count={0}'.format(h2h_input.count()))

h2h_input = h2h_input.unionByName(trial_rating_metric_input_df2, allowMissingColumns=True)
print('compute_trial_comparison_metric_input_df: After unionByName, count={0}'.format(h2h_input.count()))
h2h_input.show(n=5)

# spark.catalog.dropTempView('trial_pheno_metric_input')
# trial_check_met_sp_df.unpersist()
# h2h_output = h2h_input.pandas_api().pandas_on_spark.apply_batch(run_trial_metrics)
# print('compute_trial_comparison_metric_input_df: Apply run_metrics count={0}'.format(h2h_output.count()))

# h2h_output.spark.repartition(1).to_csv('notebook_output/trial_output.csv')

h2h_input = h2h_input.toPandas()
h2h_output = run_trial_metrics(h2h_input)
h2h_output.to_csv('notebook_output/trial_output.csv')


    SELECT DISTINCT
        asec.ap_data_sector,
        CAST(asec.analysis_year as integer) as analysis_year,
        asec.experiment_id,
        asec.decision_group,
        asec.decision_group_rm,
        asec.stage,
        astc.trait,
        astc.distribution_type,
        astc.direction,
        astc.conv_operator,
        CAST(astc.conv_factor AS float) as conv_factor,
        CAST(astc.yield_trait AS integer) AS yield_trait,
        astc.level,
        astc.metric_name,
        astc.dme_chkfl,
        astc.dme_reg_x,
        astc.dme_reg_y,
        astc.dme_rm_est,
        astc.dme_weighted_trait
      FROM (
        SELECT
            CAST(analysis_year AS integer) AS analysis_year,
            trait,
            FIRST(distribution_type) AS distribution_type,
            FIRST(direction) AS direction,
            FIRST(conv_operator) AS conv_operator,
            FIRST(conv_factor) AS conv_factor,
            FIRST(yield_trait) AS yield_trait,
            FIRST(level) AS lev

NameError: name 'spark' is not defined

In [21]:
## text output

t0 = time.time()
trial_alpha_input_df = query_trial_input(ap_data_sector,
                            analysis_year,
                            analysis_run_group,
                            current_source_ids,
                            breakout_level,
                            'alpha')
print("trial alpha data call time: {0}".format((time.time() - t0)))

### create empty text metric output
#text_df = merge_trial_text_input(ap_data_sector, analysis_year, current_source_ids)
#text_sp_df = spark.createDataFrame(text_df)
text_metric_output_df = create_empty_out(spark).toPandas()

# Merges pvs_output and trial_output and then appends text output.

# Generate H2H metrics
def compute_mti(df, group_cols, weight_col="weight", h2h_mode=False):
    base_group_cols = [
        "ap_data_sector",
        "analysis_year",
        "analysis_type",
        "decision_group",
        "material_type",
        "breakout_level",
        "breakout_level_value",
        "be_bid",
    ]

    sum_cols = ["count"]
    mean_cols = ["prediction", "stddev"]
    max_cols = ["cpifl", "chkfl"]

    if h2h_mode:
        sum_cols = sum_cols + ["check_count"]
        mean_cols = mean_cols + ["check_prediction", "check_stddev"]
        max_cols = max_cols + ["check_chkfl"]

    if weight_col == None:
        mean_cols = mean_cols + ["weight", "adv_weight", "pctchk", "statistic"]
    elif weight_col == "weight":
        mean_cols = mean_cols + ["adv_weight"]

    if weight_col == None:
        df = df.group_by(base_group_cols + group_cols).agg(
            [pl.col(f"{c}").sum() for c in sum_cols]
            + [pl.col(f"{c}").mean() for c in mean_cols]
            + [pl.col(f"{c}").max() for c in max_cols]
            + [np.exp(np.log(pl.col("metric_value")).mean())]
        )
    else:
        df = df.group_by(base_group_cols + group_cols).agg(
            [pl.col(f"{c}").sum() for c in sum_cols]
            + [pl.col(f"{c}").mean() for c in mean_cols]
            + [pl.col(f"{c}").max() for c in max_cols]
            + [
                pl.lit("aggregate").alias("metric_method"),
                ((pl.col("pctchk") * pl.col(weight_col)).sum())
                / (pl.col(weight_col).sum()).alias("pctchk"),
                ((pl.col("statistic") * pl.col(weight_col)).sum())
                / (pl.col(weight_col).sum()).alias("statistic"),
                (
                    np.exp(
                        ((np.log(pl.col("metric_value")) * pl.col(weight_col)).sum())
                        / (pl.col(weight_col).sum())
                    )
                ),
            ]
        )

    if h2h_mode:
        df = df.with_columns(pl.lit("aggregate").alias("trait"))

    if weight_col == "adv_weight":
        df = df.with_columns(pl.lit("advancement").alias("metric_name"))

    return df

base_group_cols = [
    "ap_data_sector",
    "analysis_year",
    "analysis_type",
    "decision_group",
    "material_type",
    "breakout_level",
    "breakout_level_value",
    "be_bid",
]

bebid_data_cols = [
    "count",
    "prediction",
    "stddev",
    "cpifl",
    "chkfl",
]

check_bebid_data_cols = [
    "check_be_bid",
    "check_count",
    "check_prediction",
    "check_stddev",
    "check_chkfl",
]

base_metric_cols = [
    "metric_name",
    "pctchk",
    "statistic",
    "metric_value",
    "metric_method",
]

stacked_df_cols = (
    base_group_cols
    + ["trait", "weight", "adv_weight"]
    + bebid_data_cols
    + check_bebid_data_cols
    + base_metric_cols
)

stacked_df = pl.concat(
    [
        pl.from_pandas(pvs_df).select(stacked_df_cols),
        pl.from_pandas(h2h_output).select(stacked_df_cols),
        pl.from_pandas(text_metric_output_df).select(stacked_df_cols),
    ],
    how="vertical_relaxed",
)

h2h_metric_df = compute_mti(
    stacked_df.filter(
        (pl.col("metric_name") != "h2h")
        & (pl.col("weight") > 0)
        & (pl.col("check_be_bid").is_not_null())
    ),
    group_cols=["check_be_bid", "metric_name"],
    weight_col="weight",
    h2h_mode=True,
)

h2h_adv_df = compute_mti(
    h2h_metric_df.filter(pl.col("adv_weight") > 0),
    group_cols=["check_be_bid"],
    weight_col="adv_weight",
    h2h_mode=True,
)

h2h_metric_cols = (
    base_group_cols
    + ["trait"]
    + bebid_data_cols
    + check_bebid_data_cols
    + base_metric_cols
)

h2h_metric_df = pl.concat(
    [
        stacked_df.filter(
            (pl.col("check_be_bid").is_not_null())
        ).select(h2h_metric_cols),
        h2h_metric_df.select(h2h_metric_cols),
        h2h_adv_df.select(h2h_metric_cols),
    ],
    how="vertical",
)

h2h_metric_df.write_csv('notebook_output/h2h_metric_df.csv')

agg_trait_df = compute_mti(
    stacked_df.filter(
        (pl.col("metric_name") != "h2h") & (pl.col("weight") > 0)
    ),
    group_cols=["trait", "metric_name"],
    weight_col=None,
    h2h_mode=False,
)

agg_metric_df = compute_mti(
    agg_trait_df,
    group_cols=["metric_name"],
    weight_col="weight",
    h2h_mode=False,
)

agg_adv_df = compute_mti(
    agg_metric_df,
    group_cols=[],
    weight_col="adv_weight",
    h2h_mode=False,
)

agg_metric_cols = base_group_cols + bebid_data_cols + base_metric_cols

agg_metric_df = pl.concat(
    [
        agg_metric_df.select(agg_metric_cols),
        agg_adv_df.select(agg_metric_cols),
    ],
    how="vertical",
)

agg_metric_df = agg_metric_df.with_columns(
    pl.when((pl.col("pctchk") >= 0) & (pl.col("pctchk") < 1000))
    .then(pl.col("pctchk"))
    .otherwise(pl.lit(-1)),
    pl.when((pl.col("statistic") > -99) & (pl.col("statistic") < 99))
    .then(pl.col("statistic"))
    .otherwise(pl.lit(-99)),
    pl.when(
        (pl.col("metric_value") > 1) & (pl.col("metric_value") < 99)
    )
    .then(pl.col("metric_value"))
    .otherwise(pl.lit(0)),
)

agg_metric_df.write_csv('notebook_output/agg_metric_df.csv')
# Generate old agg format


    SELECT DISTINCT
        asec.ap_data_sector,
        CAST(asec.analysis_year as integer) as analysis_year,
        asec.experiment_id,
        asec.decision_group,
        asec.decision_group_rm,
        asec.stage,
        astc.trait,
        astc.distribution_type,
        astc.direction,
        astc.conv_operator,
        CAST(astc.conv_factor AS float) as conv_factor,
        CAST(astc.yield_trait AS integer) AS yield_trait,
        astc.level,
        astc.metric_name,
        astc.dme_chkfl,
        astc.dme_reg_x,
        astc.dme_reg_y,
        astc.dme_rm_est,
        astc.dme_weighted_trait
      FROM (
        SELECT
            CAST(analysis_year AS integer) AS analysis_year,
            trait,
            FIRST(distribution_type) AS distribution_type,
            FIRST(direction) AS direction,
            FIRST(conv_operator) AS conv_operator,
            FIRST(conv_factor) AS conv_factor,
            FIRST(yield_trait) AS yield_trait,
            FIRST(level) AS lev