In [1]:
from pyspark.sql import SparkSession
import pyspark.sql.types as T
import numpy as np
import polars as pl

# Let polars print up to 100 rows of a table.
pl.Config(tbl_rows=100)

# Spark session for the preparation job: master "local[6]",
# driver memory set to 15g.
session_builder = SparkSession.builder.master("local[6]")
session_builder = session_builder.appName('PySpark_Prepare')
session_builder = session_builder.config("spark.driver.memory", "15g")
spark = session_builder.getOrCreate()

# Column layout of the training CSV as (column name, Spark type class) pairs,
# listed in file order. Every column is declared nullable.
_train_columns = (
    ('slctn_nmbr', T.IntegerType),
    ('client_id', T.StringType),
    ('npo_account_id', T.StringType),
    ('npo_accnts_nmbr', T.IntegerType),
    ('pmnts_type', T.IntegerType),
    ('year', T.IntegerType),
    ('quarter', T.StringType),
    ('gender', T.IntegerType),
    ('age', T.IntegerType),
    ('clnt_cprtn_time_d', T.IntegerType),
    ('actv_prd_d', T.IntegerType),
    ('lst_pmnt_rcnc_d', T.IntegerType),
    ('balance', T.FloatType),
    ('oprtn_sum_per_qrtr', T.FloatType),
    ('oprtn_sum_per_year', T.FloatType),
    ('frst_pmnt_date', T.StringType),
    ('lst_pmnt_date_per_qrtr', T.IntegerType),
    ('frst_pmnt', T.FloatType),
    ('lst_pmnt', T.FloatType),
    ('pmnts_sum', T.FloatType),
    ('pmnts_nmbr', T.IntegerType),
    ('pmnts_sum_per_qrtr', T.FloatType),
    ('pmnts_sum_per_year', T.FloatType),
    ('pmnts_nmbr_per_qrtr', T.IntegerType),
    ('pmnts_nmbr_per_year', T.IntegerType),
    ('incm_sum', T.FloatType),
    ('incm_per_qrtr', T.FloatType),
    ('incm_per_year', T.FloatType),
    ('mgd_accum_period', T.FloatType),
    ('mgd_payment_period', T.FloatType),
    ('phone_number', T.IntegerType),
    ('email', T.IntegerType),
    ('lk', T.IntegerType),
    ('assignee_npo', T.IntegerType),
    ('assignee_ops', T.IntegerType),
    ('postal_code', T.StringType),
    ('region', T.StringType),
    ('citizen', T.IntegerType),
    ('fact_addrss', T.IntegerType),
    ('appl_mrkr', T.IntegerType),
    ('evry_qrtr_pmnt', T.IntegerType),
    ('churn', T.IntegerType),
)

# One nullable StructField per column, in the order given above.
dataset_schema = [
    T.StructField(column, spark_type(), True)
    for column, spark_type in _train_columns
]

dataset_struct = T.StructType(fields=dataset_schema)

# Read the training CSV (comma-separated, with a header row) using the
# explicit schema instead of letting Spark infer the types.
dataset = spark.read.csv('dataset/train.csv', sep=',', header=True, schema=dataset_struct)

# Columns that are not needed for the prepared dataset.
unused_columns = (
    'oprtn_sum_per_year', 'frst_pmnt_date', 'lst_pmnt_date_per_qrtr', 'pmnts_sum_per_year',
    'pmnts_nmbr_per_year', 'incm_per_year', 'postal_code', 'npo_accnts_nmbr',
    'slctn_nmbr', 'client_id',
)

# Drop the unused columns, remove duplicate rows and rows containing nulls,
# then sort by account and quarter.
dataset = (
    dataset
    .drop(*unused_columns)
    .dropDuplicates()
    .dropna()
    .sort(['npo_account_id', 'quarter'])
)

# Number of report rows per account that go into one prepared sample.
lag = 8

# Auxiliary tables used to enrich the dataset.
# The contributors file is semicolon-separated; the others use the default separator.
data_cntrbtrs = pl.read_csv('dataset/cntrbtrs.csv', separator=';')
region_encoder, gdp, mrot, usd = map(
    pl.read_csv,
    (
        'dataset/region_encoder.csv',
        'dataset/gdp.csv',
        'dataset/mrot.csv',
        'dataset/usd.csv',
    ),
)

# Keep only the accounts that have at least lag + 1 rows (8 + 1 with lag = 8).
rows_per_account = dataset.groupBy('npo_account_id').count().withColumnRenamed('count', 'lag')
filter_ids = rows_per_account.filter(rows_per_account['lag'] >= lag + 1).select('npo_account_id')
# collect() returns one single-field Row per account; flatten to a 1-D array of ids.
ids = np.array(filter_ids.collect()).ravel()

# Convert the whole Spark frame to polars via pandas.
# Author's note: about 80% of the run time is spent on this line.
dataset_pl = pl.from_pandas(dataset.toPandas())

24/08/27 20:45:52 WARN Utils: Your hostname, MacBook-Pro-Danil.local resolves to a loopback address: 127.0.0.1; using 192.168.0.131 instead (on interface en0)
24/08/27 20:45:52 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/08/27 20:45:52 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/08/27 20:45:54 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

ModuleNotFoundError: No module named 'distutils'

24/08/27 20:46:04 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors


In [None]:
# Let polars print up to 1000 rows of a table.
pl.Config(tbl_rows=1000)

# Names of the flattened feature columns; filled in once by the preparation loop.
column_names = []
def get_region(region):
    """Return the encoded value of a region.

    Only the text before the first space of ``region`` is used as the lookup
    key in the ``region_encoder`` table; the ``value`` of the first matching
    row is returned.
    """
    key, _, _ = region.partition(' ')
    matches = region_encoder.filter(pl.col('region') == key)
    return matches['value'][0]

def get_mrot(year):
    """Return the MROT value in rubles for ``year`` as an ``int``.

    Looks up the first row of the ``mrot`` table whose ``year`` matches.
    """
    matches = mrot.filter(pl.col('year') == year)
    rubles = matches['rubles'][0]
    return int(rubles)

def get_gdp(year):
    """Return the GDP per capita in USD for ``year`` as an ``int``.

    Looks up the first row of the ``gdp`` table whose ``year`` matches.
    """
    matches = gdp.filter(pl.col('year') == year)
    dollars = matches['usd'][0]
    return int(dollars)

def get_usd(quarter):
    """Return the dollar exchange rate in rubles for ``quarter``.

    Looks up the first row of the ``usd`` table whose ``quarter`` matches.
    """
    matches = usd.filter(pl.col('quarter') == quarter)
    return matches['rubles'][0]

def get_pens_type(id):
    """Return the pension contribution type of an account.

    Reads ``accnt_pnsn_schm`` from the first row of ``data_cntrbtrs`` whose
    ``npo_accnt_id`` equals ``id``.
    """
    matches = data_cntrbtrs.filter(pl.col('npo_accnt_id') == id)
    return matches['accnt_pnsn_schm'][0]


# Build one flat sample per eligible account: the `lag` report rows that
# precede the account's last row, enriched with macro indicators, followed by
# the pension contribution type and the churn value of the last row.
new_dataset = []
# The loop variable is deliberately not called `id`: at notebook scope that
# would replace the builtin id() for the rest of the session.
for account_id in ids:
    # All rows of this account; the account id itself is not a feature.
    account_rows = dataset_pl.filter(pl.col('npo_account_id') == account_id)
    account_rows = account_rows.drop('npo_account_id')

    # The target is the churn value of the account's last row.
    target = account_rows.tail(1)['churn'][0]
    pens = get_pens_type(account_id)

    account_rows = account_rows.drop('churn')
    # Keep the `lag` rows immediately before the last row; the last row only
    # supplies the target and is not part of the features.
    account_rows = account_rows.slice(len(account_rows) - lag - 1, lag)

    # Each step is a separate with_columns call on purpose: later steps read
    # columns created by earlier ones, and the resulting column order defines
    # the layout of the saved array.
    account_rows = (
        account_rows
        # encode the region
        .with_columns(pl.col('region')
                      .map_elements(get_region, return_dtype=pl.Float64)
                      .alias('region'))
        # dollar exchange rate in rubles, looked up by the quarter string
        .with_columns(pl.col('quarter')
                      .map_elements(get_usd, return_dtype=pl.Float64)
                      .alias('usd'))
        # MROT in rubles
        .with_columns(pl.col('year')
                      .map_elements(get_mrot, return_dtype=pl.Int64)
                      .alias('mrot_rub'))
        # MROT in USD
        .with_columns((pl.col('mrot_rub') / pl.col('usd'))
                      .alias('mrot_usd'))
        # GDP in USD
        .with_columns(pl.col('year')
                      .map_elements(get_gdp, return_dtype=pl.Int64)
                      .alias('gdp_usd'))
        # GDP in rubles
        .with_columns((pl.col('gdp_usd') * pl.col('usd'))
                      .alias('gdp_rub'))
        # replace the quarter string by the integer parsed from its character at index 5
        .with_columns(pl.col('quarter')
                      .map_elements(lambda x: int(x[5:6]), return_dtype=pl.Int64)
                      .alias('quarter'))
        .drop('year')
    )

    # Column names are derived once, from the first processed account:
    # every feature name suffixed with the 1-based row position, then the
    # two trailing scalar columns.
    if not column_names:
        feature_names = account_rows.columns
        column_names.extend(
            name + '_' + str(step)
            for step in range(1, lag + 1)
            for name in feature_names
        )
        column_names.append('pens_type')
        column_names.append('target')

    # Flatten the lag x features block row by row and append the two scalars.
    sample = np.append(account_rows.to_numpy().reshape(-1), (pens, target))
    new_dataset.append(sample)
    print(len(new_dataset))  # progress: number of accounts processed so far
np.save('dataset_train_prepared.npy', np.array(new_dataset))

In [None]:
# Display the feature column names generated by the preparation loop.
column_names