In [1]:
# импортируем polars, pandas и numpy
import polars as pl
import pandas as pd
import numpy as np

# импортируем Parallel, delayed
from joblib import Parallel, delayed

# show up to 200 columns when displaying wide DataFrames
pd.set_option("display.max_columns", 200)

# fix the global NumPy seed so the random demo cells below produce the same
# arrays on every Restart & Run All
SEED = 42
np.random.seed(SEED)

## In-place операции

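In-place операции полезны для подсчета гистограмм. Допустим, есть массив целых чисел от 0 до 99, и мы хотим посчитать, сколько раз встречается каждое значение. Для этого подходит `np.add.at`: в отличие от `hist[arr] += 1`, эта операция не буферизуется, поэтому повторяющийся индекс учитывается столько раз, сколько он встречается в массиве.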

In [2]:
# one million random integers drawn from [0, 100)
arr = np.random.randint(low=0, high=100, size=1_000_000)
# one zero-initialised counter per possible value
hist = np.zeros(shape=100, dtype=np.int32)
# unbuffered in-place add: every occurrence of a repeated index bumps its
# counter (plain hist[arr] += 1 would count each distinct index only once)
np.add.at(hist, arr, 1)
hist

array([ 9904, 10195, 10073,  9969,  9985,  9912, 10082, 10000,  9973,
       10031, 10106,  9931, 10010, 10100, 10089,  9982, 10086, 10078,
        9972, 10135, 10093, 10041, 10018,  9832,  9666,  9917, 10082,
        9824,  9961, 10080,  9920, 10006,  9888,  9831,  9908,  9889,
        9873, 10212, 10032,  9908, 10160, 10038, 10075,  9973,  9902,
       10001, 10003,  9809, 10008,  9994, 10154, 10127,  9908, 10059,
       10037, 10003,  9878, 10038,  9981,  9941,  9960, 10011, 10009,
       10049, 10130,  9979, 10018,  9818,  9882,  9933, 10058, 10141,
        9978,  9918, 10079, 10063,  9896, 10028,  9780,  9863, 10077,
       10229, 10106,  9975, 10007, 10079, 10056,  9890, 10031, 10181,
       10011, 10054,  9988,  9897, 10117, 10064,  9845,  9932, 10207,
        9978], dtype=int32)

Можно считать и по нескольким измерениям:

In [3]:
# two label arrays: 10 row categories and 5 column categories
arr0 = np.random.randint(low=0, high=10, size=1_000_000)
arr1 = np.random.randint(low=0, high=5, size=1_000_000)
# zero-initialised 10x5 contingency table
hist = np.zeros(shape=(10, 5), dtype=np.int32)
# a tuple of index arrays addresses (row, column) cells element-wise
np.add.at(hist, (arr0, arr1), 1)
hist

array([[20218, 20056, 19972, 19916, 19922],
       [20031, 20175, 19930, 19787, 19967],
       [20000, 20140, 20110, 19978, 20181],
       [19883, 19967, 20009, 20125, 20090],
       [20054, 19950, 20203, 20035, 20093],
       [20032, 20078, 19992, 20088, 20047],
       [20065, 20172, 19847, 19717, 19788],
       [19869, 19925, 20231, 20016, 19671],
       [19849, 19934, 20046, 19957, 20025],
       [19868, 20042, 20038, 19942, 19969]], dtype=int32)

Можно не только считать количество, но и, например, суммировать значения по группам (аналог groupby sum):

In [4]:
arr0 = np.random.randint(low=0, high=10, size=1_000_000)
arr1 = np.random.randint(low=0, high=5, size=1_000_000)

x = np.random.rand(1_000_000)

hist = np.zeros(shape=(10, 5), dtype=np.float64)
# dense counterpart of pd.Series(x).groupby([arr0, arr1]).sum():
# cell (i, j) accumulates every x[k] with arr0[k] == i and arr1[k] == j
np.add.at(hist, (arr0, arr1), x)
hist

array([[10116.1283585 , 10095.93284622,  9863.3618597 ,  9968.13282714,
        10073.84465794],
       [10032.81650577,  9900.46834809, 10000.79824479,  9949.830191  ,
         9949.12923472],
       [ 9871.97989648,  9897.95408647,  9892.07388805, 10017.37791632,
        10062.57709799],
       [10170.38737255,  9879.49308446,  9989.53851869,  9852.26282826,
         9991.87234051],
       [10015.61809185,  9932.71842797, 10003.54851251, 10026.63651438,
        10008.1755565 ],
       [10083.74114312,  9980.37116598, 10029.33230941, 10069.00008345,
         9841.82017977],
       [10143.06823263, 10053.14544029, 10133.85421869,  9999.41121828,
        10009.71240832],
       [ 9995.23653393, 10021.01648191,  9882.69919671, 10105.71580229,
        10082.87196282],
       [ 9872.75638489, 10087.57081407,  9970.10369584, 10067.38061677,
         9789.91259869],
       [10011.67038218, 10035.69420194,  9987.26771433,  9984.60515479,
         9951.10540039]])

## Fast target encoder

Попробуем сделать быстрый target encoder по кросс-валидации с подбором регуляризации. Делать будем без циклов, используя бродкастинг и in-place операции.

In [5]:
# toy data: a category code, a target value and a fold id for each of 10 rows
arr0 = np.array([0, 1, 2] * 3 + [0])
x = np.array([-5, 1, 10, 1, 3, 9, -2, -1, 7, 2])
fold = np.repeat([0, 1], 5)
print('Закодируем категорию {0} средним по {1}, используя кросс-валидацию {2}'.format(arr0, x, fold))

Закодируем категорию [0 1 2 0 1 2 0 1 2 0] средним по [-5  1 10  1  3  9 -2 -1  7  2], используя кросс-валидацию [0 0 0 0 0 1 1 1 1 1]


In [6]:
# number of categories and folds; assumes codes are dense 0..K-1,
# so the count is the largest code plus one
n_cats, n_folds = arr0.max() + 1, fold.max() + 1
n_cats, n_folds

(3, 2)

In [7]:
# per-fold, per-category target sums, i.e. x.groupby([fold, arr0]).sum(),
# laid out as a dense (n_folds, n_cats) table
f_sum = np.zeros(shape=(n_folds, n_cats), dtype=np.float64)
np.add.at(f_sum, (fold, arr0), x)
f_sum

array([[-4.,  4., 10.],
       [ 0., -1., 16.]])

In [8]:
# per-fold, per-category row counts, i.e. x.groupby([fold, arr0]).count(),
# stored in a float64 (n_folds, n_cats) table
f_count = np.zeros(shape=(n_folds, n_cats), dtype=np.float64)
np.add.at(f_count, (fold, arr0), 1)
f_count

array([[2., 2., 1.],
       [2., 1., 2.]])

In [9]:
# collapsing the fold axis gives the overall per-category sum and count,
# i.e. x.groupby(arr0).sum() and x.groupby(arr0).count()
tot_sum, tot_count = f_sum.sum(axis=0), f_count.sum(axis=0)
tot_sum, tot_count

(array([-4.,  3., 26.]), array([4., 3., 3.]))

In [10]:
# out-of-fold statistics: for each fold, subtract that fold's own
# per-category sum/count from the overall per-category sum/count.
# Broadcasting: tot_sum / tot_count are 1-D with shape (n_cats,), while
# f_sum / f_count are 2-D with shape (n_folds, n_cats); np.newaxis adds the
# leading fold axis so the result is (n_folds, n_cats).
oof_sum = tot_sum[np.newaxis, ...] - f_sum
print(oof_sum)
oof_count = tot_count[np.newaxis, ...] - f_count
print(oof_count)

[[ 0. -1. 16.]
 [-4.  4. 10.]]
[[2. 1. 2.]
 [2. 2. 1.]]


In [11]:
# smoothed mean encoding: (y_sum + a * y_prior) / (y_count + a);
# the strength a is picked below to minimise the out-of-fold squared error.
# The prior is the global target mean over all rows.
prior = np.sum(tot_sum) / np.sum(tot_count)
prior

2.5

In [12]:
# candidate regularisation strengths: 3 evenly spaced values in [0.1, 10]
grid = np.linspace(start=0.1, stop=10, num=3)
grid

array([ 0.1 ,  5.05, 10.  ])

In [13]:
# evaluate the smoothed encoding for every row and every candidate a at once:
# a column of per-row OOF stats, shape (n_rows, 1), broadcast against a row of
# candidates, shape (1, n_grid), yields an (n_rows, n_grid) matrix
enc_candidates = (
    (oof_sum[fold, arr0][:, np.newaxis] + prior * grid[np.newaxis, :])
    / (oof_count[fold, arr0][:, np.newaxis] + grid[np.newaxis, :])
)
enc_candidates

array([[ 0.11904762,  1.79078014,  2.08333333],
       [-0.68181818,  1.9214876 ,  2.18181818],
       [ 7.73809524,  4.06028369,  3.41666667],
       [ 0.11904762,  1.79078014,  2.08333333],
       [-0.68181818,  1.9214876 ,  2.18181818],
       [ 9.31818182,  3.73966942,  3.18181818],
       [-1.78571429,  1.22340426,  1.75      ],
       [ 2.02380952,  2.35815603,  2.41666667],
       [ 9.31818182,  3.73966942,  3.18181818],
       [-1.78571429,  1.22340426,  1.75      ]])

In [14]:
# position in grid of the candidate with the lowest out-of-fold MSE;
# note that best_a is an index into grid, not the regularisation value itself
best_a = np.argmin(np.mean((enc_candidates - x[:, np.newaxis]) ** 2, axis=0))
best_a

0

In [15]:
# out-of-fold encoding produced by the selected regularisation strength
# (the column of enc_candidates at position best_a)
best_oof_enc = enc_candidates[:, best_a]
best_oof_enc

array([ 0.11904762, -0.68181818,  7.73809524,  0.11904762, -0.68181818,
        9.31818182, -1.78571429,  2.02380952,  9.31818182, -1.78571429])

In [16]:
# encoder for new data: smoothed per-category means over all training rows,
# using the regularisation strength grid[best_a]
test_encoder = (tot_sum + grid[best_a] * prior) / (tot_count + grid[best_a])
test_encoder

array([-0.91463415,  1.0483871 ,  8.46774194])

In [17]:
# look up the encoding by category code (applied to arr0 here for
# illustration); a code >= n_cats, e.g. a category unseen in training,
# raises IndexError
best_test_enc = np.take(test_encoder, arr0)
best_test_enc

array([-0.91463415,  1.0483871 ,  8.46774194, -0.91463415,  1.0483871 ,
        8.46774194, -0.91463415,  1.0483871 ,  8.46774194, -0.91463415])

## Parallel groupby

In [18]:
%%time

# для ускорения считываем файл в датафрейм polars,
# потом в датафрейм pandas, файл можно скачать по адресу
# https://www.kaggle.com/competitions/amex-default-prediction/data?select=train_data.csv
pandas_df = pl.read_csv('Data/train_data.csv', rechunk=False).to_pandas()

CPU times: user 1min 20s, sys: 1min 20s, total: 2min 41s
Wall time: 30.7 s


In [19]:
# preview the first five rows
pandas_df.head(5)

Unnamed: 0,customer_ID,S_2,P_2,D_39,B_1,B_2,R_1,S_3,D_41,B_3,D_42,D_43,D_44,B_4,D_45,B_5,R_2,D_46,D_47,D_48,D_49,B_6,B_7,B_8,D_50,D_51,B_9,R_3,D_52,P_3,B_10,D_53,S_5,B_11,S_6,D_54,R_4,S_7,B_12,S_8,D_55,D_56,B_13,R_5,D_58,S_9,B_14,D_59,D_60,D_61,B_15,S_11,D_62,D_63,D_64,D_65,B_16,B_17,B_18,B_19,D_66,B_20,D_68,S_12,R_6,S_13,B_21,D_69,B_22,D_70,D_71,D_72,S_15,B_23,D_73,P_4,D_74,D_75,D_76,B_24,R_7,D_77,B_25,B_26,D_78,D_79,R_8,R_9,S_16,D_80,R_10,R_11,B_27,D_81,D_82,S_17,R_12,B_28,R_13,D_83,R_14,R_15,D_84,R_16,B_29,B_30,S_18,D_86,D_87,R_17,R_18,D_88,B_31,S_19,R_19,B_32,S_20,R_20,R_21,B_33,D_89,R_22,R_23,D_91,D_92,D_93,D_94,R_24,R_25,D_96,S_22,S_23,S_24,S_25,S_26,D_102,D_103,D_104,D_105,D_106,D_107,B_36,B_37,R_26,R_27,B_38,D_108,D_109,D_110,D_111,B_39,D_112,B_40,S_27,D_113,D_114,D_115,D_116,D_117,D_118,D_119,D_120,D_121,D_122,D_123,D_124,D_125,D_126,D_127,D_128,D_129,B_41,B_42,D_130,D_131,D_132,D_133,R_28,D_134,D_135,D_136,D_137,D_138,D_139,D_140,D_141,D_142,D_143,D_144,D_145
0,0000099d6bd597052cdcda90ffabf56573fe9d7c79be5f...,2017-03-09,0.938469,0.001733,0.008724,1.006838,0.009228,0.124035,0.008771,0.004709,,,0.00063,0.080986,0.708906,0.1706,0.006204,0.358587,0.525351,0.255736,,0.063902,0.059416,0.006466,0.148698,1.335856,0.008207,0.001423,0.207334,0.736463,0.096219,,0.023381,0.002768,0.008322,1.001519,0.008298,0.161345,0.148266,0.922998,0.354596,0.152025,0.118075,0.001882,0.158612,0.065728,0.018385,0.063646,0.199617,0.308233,0.016361,0.401619,0.091071,CR,O,0.007126,0.007665,,0.652984,0.00852,,0.00473,6.0,0.272008,0.008363,0.515222,0.002644,0.009013,0.004808,0.008342,0.119403,0.004802,0.108271,0.050882,,0.007554,0.080422,0.069067,,0.004327,0.007562,,0.007729,0.000272,0.001576,0.004239,0.001434,,0.002271,0.004061,0.007121,0.002456,0.00231,0.003532,0.506612,0.008033,1.009825,0.084683,0.00382,0.007043,0.000438,0.006452,0.00083,0.005055,,0.0,0.00572,0.007084,,0.000198,0.008907,,1,0.002537,0.005177,0.006626,0.009705,0.007782,0.00245,1.001101,0.002665,0.007479,0.006893,1.503673,1.006133,0.003569,0.008871,0.00395,0.003647,0.00495,0.89409,0.135561,0.911191,0.974539,0.001243,0.766688,1.008691,1.004587,0.893734,,0.670041,0.009968,0.004572,,1.008949,2.0,,0.004326,,,,1.007336,0.21006,0.676922,0.007871,1.0,0.23825,0.0,4.0,0.23212,0.236266,0.0,0.70228,0.434345,0.003057,0.686516,0.00874,1.0,1.003319,1.007819,1.00008,0.006805,,0.002052,0.005972,,0.004345,0.001535,,,,,,0.002427,0.003706,0.003818,,0.000569,0.00061,0.002674
1,0000099d6bd597052cdcda90ffabf56573fe9d7c79be5f...,2017-04-07,0.936665,0.005775,0.004923,1.000653,0.006151,0.12675,0.000798,0.002714,,,0.002526,0.069419,0.712795,0.113239,0.006206,0.35363,0.521311,0.223329,,0.065261,0.057744,0.001614,0.149723,1.339794,0.008373,0.001984,0.202778,0.720886,0.099804,,0.030599,0.002749,0.002482,1.009033,0.005136,0.140951,0.14353,0.919414,0.326757,0.156201,0.118737,0.00161,0.148459,0.093935,0.013035,0.065501,0.151387,0.265026,0.017688,0.406326,0.086805,CR,O,0.002413,0.007148,,0.647093,0.002238,,0.003879,6.0,0.18897,0.00403,0.509048,0.004193,0.007842,0.001283,0.006524,0.140611,9.4e-05,0.101018,0.040469,,0.004832,0.081413,0.074166,,0.004203,0.005304,,0.001864,0.000979,0.009896,0.007597,0.000509,,0.00981,0.000127,0.005966,0.000395,0.001327,0.007773,0.500855,0.00076,1.009461,0.081843,0.000347,0.007789,0.004311,0.002332,0.009469,0.003753,,0.0,0.007584,0.006677,,0.001142,0.005907,,1,0.008427,0.008979,0.001854,0.009924,0.005987,0.002247,1.006779,0.002508,0.006827,0.002837,1.503577,1.005791,0.000571,0.000391,0.008351,0.00885,0.00318,0.902135,0.136333,0.919876,0.975624,0.004561,0.786007,1.000084,1.004118,0.906841,,0.668647,0.003921,0.004654,,1.003205,2.0,,0.008707,,,,1.007653,0.184093,0.822281,0.003444,1.0,0.247217,0.0,4.0,0.243532,0.241885,0.0,0.707017,0.430501,0.001306,0.686414,0.000755,1.0,1.008394,1.004333,1.008344,0.004407,,0.001034,0.004838,,0.007495,0.004931,,,,,,0.003954,0.003167,0.005032,,0.009576,0.005492,0.009217
2,0000099d6bd597052cdcda90ffabf56573fe9d7c79be5f...,2017-05-28,0.95418,0.091505,0.021655,1.009672,0.006815,0.123977,0.007598,0.009423,,,0.007605,0.068839,0.720884,0.060492,0.003259,0.33465,0.524568,0.189424,,0.066982,0.056647,0.005126,0.151955,1.337179,0.009355,0.007426,0.206629,0.738044,0.134073,,0.048367,0.010077,0.00053,1.009184,0.006961,0.112229,0.137014,1.001977,0.304124,0.153795,0.114534,0.006328,0.139504,0.084757,0.056653,0.070607,0.305883,0.212165,0.063955,0.406768,0.094001,CR,O,0.001878,0.003636,,0.645819,0.000408,,0.004578,6.0,0.495308,0.006838,0.679257,0.001337,0.006025,0.009393,0.002615,0.075868,0.007152,0.103239,0.047454,,0.006561,0.078891,0.07651,,0.001782,0.001422,,0.005419,0.006149,0.009629,0.003094,0.008295,,0.009362,0.000954,0.005447,0.007345,0.007624,0.008811,0.504606,0.004056,1.004291,0.081954,0.002709,0.004093,0.007139,0.008358,0.002325,0.007381,,0.0,0.005901,0.001185,,0.008013,0.008882,,1,0.007327,0.002016,0.008686,0.008446,0.007291,0.007794,1.001014,0.009634,0.00982,0.00508,1.503359,1.005801,0.007425,0.009234,0.002471,0.009769,0.005433,0.939654,0.134938,0.958699,0.974067,0.011736,0.80684,1.003014,1.009285,0.928719,,0.670901,0.001264,0.019176,,1.000754,2.0,,0.004092,,,,1.004312,0.154837,0.853498,0.003269,1.0,0.239867,0.0,4.0,0.240768,0.23971,0.0,0.704843,0.434409,0.003954,0.690101,0.009617,1.0,1.009307,1.007831,1.006878,0.003221,,0.005681,0.005497,,0.009227,0.009123,,,,,,0.003269,0.007329,0.000427,,0.003429,0.006986,0.002603
3,0000099d6bd597052cdcda90ffabf56573fe9d7c79be5f...,2017-06-13,0.960384,0.002455,0.013683,1.0027,0.001373,0.117169,0.000685,0.005531,,,0.006406,0.05563,0.723997,0.166782,0.009918,0.323271,0.530929,0.135586,,0.08372,0.049253,0.001418,0.151219,1.339909,0.006782,0.003515,0.208214,0.741813,0.134437,,0.030063,0.009667,0.000783,1.007456,0.008706,0.102838,0.129017,0.704016,0.275055,0.155772,0.12074,0.00498,0.1381,0.048382,0.012498,0.065926,0.273553,0.2043,0.022732,0.405175,0.094854,CR,O,0.005899,0.005896,,0.654358,0.005897,,0.005207,6.0,0.50867,0.008183,0.515282,0.008716,0.005271,0.004554,0.002052,0.150209,0.005364,0.206394,0.031705,,0.009559,0.07749,0.071547,,0.005595,0.006363,,0.000646,0.009193,0.008568,0.003895,0.005153,,0.004876,0.005665,0.001888,0.004961,3.4e-05,0.004652,0.508998,0.006969,1.004728,0.060634,0.009982,0.008817,0.00869,0.007364,0.005924,0.008802,,0.0,0.00252,0.003324,,0.009455,0.008348,,1,0.007053,0.003909,0.002478,0.006614,0.009977,0.007686,1.002775,0.007791,0.000458,0.00732,1.503701,1.007036,0.000664,0.0032,0.008507,0.004858,6.3e-05,0.913205,0.140058,0.926341,0.975499,0.007571,0.808214,1.001517,1.004514,0.935383,,0.67262,0.002729,0.01172,,1.005338,2.0,,0.009703,,,,1.002538,0.153939,0.844667,5.3e-05,1.0,0.24091,0.0,4.0,0.2394,0.240727,0.0,0.711546,0.436903,0.005135,0.687779,0.004649,1.0,1.001671,1.00346,1.007573,0.007703,,0.007108,0.008261,,0.007206,0.002409,,,,,,0.006117,0.004516,0.0032,,0.008419,0.006527,0.0096
4,0000099d6bd597052cdcda90ffabf56573fe9d7c79be5f...,2017-07-16,0.947248,0.002483,0.015193,1.000727,0.007605,0.117325,0.004653,0.009312,,,0.007731,0.038862,0.720619,0.14363,0.006667,0.231009,0.529305,,,0.0759,0.048918,0.001199,0.154026,1.341735,0.000519,0.001362,0.205468,0.691986,0.121518,,0.054221,0.009484,0.006698,1.003738,0.003846,0.094311,0.129539,0.917133,0.23111,0.154914,0.095178,0.001653,0.126443,0.039259,0.027897,0.063697,0.233103,0.175655,0.031171,0.48746,0.093915,CR,O,0.009479,0.001714,,0.650112,0.007773,,0.005851,6.0,0.216507,0.008605,0.507712,0.006821,0.000152,0.000104,0.001419,0.096441,0.007972,0.10602,0.032733,,0.008156,0.076561,0.074432,,0.004933,0.004831,,0.001833,0.005738,0.003289,0.002608,0.007338,,0.007447,0.004465,0.006111,0.002246,0.002109,0.001141,0.506213,0.00177,1.000904,0.062492,0.00586,0.001845,0.007816,0.00247,0.005516,0.007166,,0.0,0.000155,0.001504,,0.002019,0.002678,,1,0.007728,0.003432,0.002199,0.005511,0.004105,0.009656,1.006536,0.005158,0.003341,0.000264,1.509905,1.002915,0.003079,0.003845,0.00719,0.002983,0.000535,0.921026,0.13162,0.933479,0.978027,0.0182,0.822281,1.006125,1.005735,0.953363,,0.673869,0.009998,0.017598,,1.003175,2.0,,0.00912,,,,1.00013,0.120717,0.811199,0.008724,1.0,0.247939,0.0,4.0,0.244199,0.242325,0.0,0.705343,0.437433,0.002849,0.688774,9.7e-05,1.0,1.009886,1.005053,1.008132,0.009823,,0.00968,0.004848,,0.006312,0.004462,,,,,,0.003671,0.004946,0.008889,,0.00167,0.008126,0.009827


In [20]:
# an ordinary (sequential) aggregation function used as a benchmark payload
def slow_agg_fn(x, n=300, col='P_2'):
    """
    Deliberately slow aggregation used to benchmark groupby strategies.

    Adds the same fixed pseudo-random vector (seed 0) to column ``col``
    ``n`` times and returns the mean of the result.

    Parameters
    ----------
    x : pandas.DataFrame
        One group's rows; must contain column ``col``.
    n : int
        Number of repeated additions (controls how slow the call is).
    col : str
        Column to aggregate (default ``'P_2'``).

    Returns
    -------
    float
        Mean of the column after the ``n`` additions.
    """
    ser = x[col]
    for _ in range(n):
        # a private RandomState seeded with 0 draws exactly the values that
        # np.random.seed(0); np.random.rand(...) produced, but leaves the
        # global NumPy RNG alone (re-seeding it silently reset the random
        # state for every later cell)
        rnd = np.random.RandomState(0).rand(*ser.shape)
        ser = ser + rnd
    return ser.mean()

def wrapper_fn(fn, grp, *args, **kwargs):
    """Apply ``fn`` to the data part of one ``(key, data)`` groupby item.

    Returns ``(key, fn(data, *args, **kwargs))`` so that each group's key
    travels together with its result when the call runs in a joblib worker.
    """
    key, data = grp[0], grp[1]
    return key, fn(data, *args, **kwargs)


# aggregation helper that spreads the groups over several joblib workers
def parallel_aggregate(fn, df, by, cols, n_jobs=4, *args, **kwargs):
    """
    Parallel counterpart of ``df.groupby(by)[cols].apply(fn)`` for
    functions that return one value per group.

    Each ``(key, group)`` pair is sent to a joblib worker; the per-group
    results are collected into a Series indexed by the group keys.

    Parameters
    ----------
    fn : callable
        Aggregation applied to each group's DataFrame; called as
        ``fn(group, *args, **kwargs)``.
    df : pandas.DataFrame
        Source data.
    by : str or list of str
        Grouping column name(s).
    cols : str or list of str
        Column name(s) passed to ``fn`` for each group.
    n_jobs : int
        Number of joblib workers.
    *args, **kwargs
        Extra arguments forwarded to ``fn``. Because ``n_jobs`` precedes
        ``*args``, extra positional arguments require ``n_jobs`` to be
        passed positionally as well.

    Returns
    -------
    pandas.Series
        One value per group, indexed by the group keys; when non-empty, the
        index levels are named after ``by`` (as ``groupby.apply`` does).
    """
    # accept a single column name as well as a list: ``by + cols`` below
    # needs both to be lists
    by = [by] if isinstance(by, str) else list(by)
    cols = [cols] if isinstance(cols, str) else list(cols)

    groupby = df[by + cols].groupby(by)[cols]

    with Parallel(n_jobs) as p:
        res = p(delayed(wrapper_fn)(fn, grp, *args, **kwargs) for grp in groupby)

    result = pd.Series(dict(res))
    # an empty result has a level-less index, so there is nothing to name
    if len(result):
        result.index.names = by
    return result

In [21]:
%%time

temp = pandas_df.groupby(['D_68', 'D_63'])[['P_2']].apply(slow_agg_fn)

CPU times: user 19.3 s, sys: 782 ms, total: 20.1 s
Wall time: 15.9 s


In [22]:
%%time

temp2 = parallel_aggregate(slow_agg_fn, pandas_df, ['D_68', 'D_63'], ['P_2'], n_jobs=8)

CPU times: user 845 ms, sys: 413 ms, total: 1.26 s
Wall time: 8.34 s
