In [1]:
# Output locations; both end with '/' because file names are appended by plain string concatenation.
FIGURES_PATH, DATASETS_PATH = 'out/figures/', 'out/datasets/'

In [2]:
import pandas as pd
from datetime import datetime, timedelta
import os
import multiprocessing
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import random
from tqdm.notebook import tqdm
from multiprocesspandas import applyparallel
from pandarallel import pandarallel
import psutil
from sys import getsizeof

import pickle
import gc


tqdm.pandas()
from helper import *

In [3]:
# Read at most this many rows of the processed dataset.
NROWS = 10 ** 6

In [4]:
# Load the processed receipts, drop the saved index column and parse the timestamps.
data = (
    pd.read_csv(DATASETS_PATH + 'data_processed.csv', nrows=NROWS)
    .drop(columns=['Unnamed: 0'])
    .assign(datetime=lambda frame: pd.to_datetime(frame['datetime']))
)

In [5]:
def process_batch(x):
    """
    Count how many times each ``product_id`` occurs in ``x``.

    :param x: pandas.DataFrame with a ``product_id`` column (one user's rows)
    :return: dict mapping each product_id value to its number of occurrences,
        keys in first-seen order
    """
    counts = {}
    for product in x['product_id'].values:
        counts[product] = counts.get(product, 0) + 1
    return counts


def get_user_purchases(data):
    """
    Count how many times every user bought every product.

    :param data: receipts - pandas.DataFrame with ``gid`` (user) and ``product_id`` columns
    :return: ans: nested dict, ans[i][j] = count of purchases (rows) by the user i
        of the product j; an empty dict for empty input
    """
    def count_products(user_rows):
        # Kept local instead of calling the module-level ``process_batch``: that name
        # is redefined later in the notebook with a different contract, so re-running
        # this cell afterwards would silently use the wrong function.
        counts = dict()
        for product in user_rows['product_id'].values:
            counts[product] = counts.get(product, 0) + 1
        return counts

    purchases = data[['gid', 'product_id']]
    if purchases.empty:
        return dict()

    pandarallel.initialize(progress_bar=True, use_memory_fs=True, nb_workers=psutil.cpu_count(logical=False))
    ans = purchases.groupby(by='gid').parallel_apply(count_products)

    return ans.to_dict()

In [None]:
def delete_users_with_some(data, some=1, unique=False):
    """
    Remove every user (``gid``) who has at most ``some`` purchases.

    :param data: receipts - pandas.DataFrame with ``gid`` and ``product_id`` columns
    :param some: users with this many purchases or fewer are dropped
    :param unique: if True, count distinct ``product_id`` values per user;
        otherwise count rows
    :return: ``data`` without the rows of the dropped users
    """
    per_user = data.groupby(by='gid')['product_id']
    # Built-in vectorised aggregations; no parallel workers or Python lambdas needed.
    counts = per_user.nunique() if unique else per_user.size()

    # '<=' (not '>='): with '>=' and some=1, every user matched and all data was deleted.
    to_drop = counts.index[counts <= some]
    return data.loc[~data['gid'].isin(to_drop)]


In [None]:
# Filter users through delete_users_with_some (threshold 1, counting distinct products), then reclaim memory.
data = delete_users_with_some(data, unique=True, some=1)
gc.collect()

In [8]:
# Per-user purchase counts: up[gid][product_id] -> count.
up = get_user_purchases(data=data)

INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=14928), Label(value='0 / 14928')))…

In [9]:
# Persist the per-user purchase counts, then release them from memory.
with open(f'{DATASETS_PATH}user_purchases.pkl', mode='wb') as f:
    pickle.dump(up, f)

del up
gc.collect()

0

In [5]:
# Approximate number of rows per batch in get_date_distances_map.
BATCH_SIZE = 150 * 1_000

In [6]:
def get_date_distances_map(data, interval=None):
    """
    For every user, find the closest (by absolute date difference) purchases of each
    ordered pair of products, and collect those per-user distances for every pair.

    Timestamps are truncated to calendar dates, so distances are whole days and signed:
    ``date(product_1) - date(product_2)``. Pairs formed by two different rows with the
    same product are included.

    :param data: preprocessed receipts with ``gid``, ``product_id`` and datetime64
        ``datetime`` columns; rows with a missing ``datetime`` are ignored
    :param interval: if given, split the data into consecutive windows of ``interval``
        days and compute distances inside each window separately; if None, process the
        users in ``data.shape[0] // BATCH_SIZE + 1`` batches without ever splitting one
        user's purchases between batches
    :return: ans: ans[(product_1, product_2)] = list of nearest day distances, one per
        user (per user and window when ``interval`` is given) who bought both products.
        If system memory usage reaches 90%, processing stops early with a RuntimeWarning
        and the result covers only the batches processed so far.
    """
    import warnings

    ans = dict()

    def split_by_interval(frame, days):
        # Consecutive windows [start, start + days) from the earliest to the latest date.
        frame = frame.sort_values(by='datetime')
        dates = frame['datetime']
        start, end = dates.iloc[0], dates.iloc[-1]
        batches = []
        while start <= end:
            sub_end = start + timedelta(days=days)
            batches.append(frame.loc[(dates >= start) & (dates < sub_end)])
            start = sub_end
        return batches

    def split_by_users(frame):
        # Whole users are assigned to batches round-robin, so each user's purchases stay
        # together and the per-user nearest distance sees the user's full history.
        n_batches = frame.shape[0] // BATCH_SIZE + 1
        batch_of_row = frame.groupby(by='gid', sort=False).ngroup() % n_batches
        return [part for _, part in frame.groupby(batch_of_row)]

    def fill_ans(x):
        # Nearest signed day distance for every ordered pair of this user's rows.
        products = x['product_id'].tolist()
        days = [d.toordinal() for d in x['datetime']]
        res = dict()
        for i1, (p1, d1) in enumerate(zip(products, days)):
            for i2, (p2, d2) in enumerate(zip(products, days)):
                if i1 == i2:
                    continue
                delta_days = d1 - d2
                key = (p1, p2)
                # Strict '>' keeps the first distance found among equal absolute values.
                if key not in res or abs(res[key]) > abs(delta_days):
                    res[key] = delta_days
        return res

    def concat_dicts(res):
        # Append each user's distance to the pair's list in the shared result.
        for pair_map in res.values:
            for key, value in pair_map.items():
                ans.setdefault(key, []).append(value)

    frame = (
        data[['gid', 'product_id', 'datetime']]
        .dropna(subset=['datetime'])
        # assign() builds a new frame instead of writing into a slice (SettingWithCopy).
        .assign(datetime=lambda df: df['datetime'].dt.date)
    )
    if frame.empty:
        return ans

    if interval is not None:
        batches = split_by_interval(frame, interval)
    else:
        batches = split_by_users(frame)

    pandarallel.initialize(progress_bar=False, use_memory_fs=True, nb_workers=psutil.cpu_count(logical=False))
    for batch in tqdm(batches):
        if psutil.virtual_memory().percent >= 90:
            warnings.warn(
                'Memory usage reached 90%: stopped early, the result covers only part of the data.',
                RuntimeWarning,
            )
            break
        if batch.empty:
            continue
        temp = batch.groupby(by='gid').parallel_apply(fill_ans)
        concat_dicts(temp.dropna())
    return ans

In [7]:
# Nearest purchase-date distances for every product pair (no time-window splitting).
dists_map = get_date_distances_map(data, interval=None)

INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


  0%|          | 0/7 [00:00<?, ?it/s]

In [13]:
# Reclaim memory left over from the previous computations.
gc.collect()

323

In [14]:
# Logical CPU count of this machine.
psutil.cpu_count(logical=True)

16

In [9]:
# TODO: rewrite with a pandas groupby instead of per-key Python loops.
def process_batch(batch):
    """
    Summarise every key's list of numbers.

    :param batch: dict mapping a key to a non-empty sequence of numbers
    :return: dict mapping the same keys to ``[mean, count, IQR]``, where IQR is the
        75th minus the 25th percentile (numpy's default linear interpolation)
    """
    res = dict()
    # No tqdm here: this runs inside multiprocessing.Pool worker processes, where a
    # tqdm.notebook widget has no notebook frontend to render into.
    for key, values in batch.items():
        arr = np.asarray(values)
        q25, q75 = np.quantile(arr, [0.25, 0.75])
        res[key] = [arr.mean(), arr.shape[0], q75 - q25]
    return res


def concat_batches(dist):
    """Pool callback: merge one worker's result dict into the global ``trans_dists``."""
    for key, value in dist.items():
        trans_dists[key] = value
    

def transform_dists(dists):
    """
    Summarise every key's list of distances in parallel worker processes.

    ``dists`` is split into chunks. Each chunk is sent to ``process_batch`` in a
    ``multiprocessing.Pool`` worker, and each worker's result is passed to the
    ``concat_batches`` callback, which merges it into the global ``trans_dists``.
    That dict must exist before this is called. Worker errors are printed, not raised.

    :param dists: dict mapping a key (e.g. a product pair) to a list of numbers
    :return: None; results arrive through ``concat_batches``
    """
    from itertools import islice

    def chunks(dictionary, size):
        # Lazily yield consecutive sub-dicts of at most ``size`` items.
        items = iter(dictionary.items())
        while True:
            chunk = dict(islice(items, size))
            if not chunk:
                return
            yield chunk

    def custom_error_callback(error):
        print(f'Got an Error: {error}', flush=True)

    n_procs = multiprocessing.cpu_count()
    # Aim for about two chunks per process. The size must never be 0: for small or
    # empty inputs the old formula produced 0, which made range() raise ValueError.
    chunk_size = max(1, len(dists) // n_procs // 2)

    # The context manager terminates the workers even if submitting raises.
    with multiprocessing.Pool(processes=n_procs) as pool:
        for batch in chunks(dists, chunk_size):
            pool.apply_async(
                process_batch,
                args=(batch,),
                callback=concat_batches,
                error_callback=custom_error_callback,
            )
        pool.close()
        pool.join()

In [10]:
# transform_dists() fills the global trans_dists through its pool callback, so create it first.
trans_dists = {}
transform_dists(dists_map)
print('Done')





Done


In [8]:
# The full dict is hundreds of MiB; show its size and the first few entries instead of rendering it all.
len(trans_dists), {key: trans_dists[key] for key, _ in zip(trans_dists, range(5))}

NameError: name 'trans_dists' is not defined

In [18]:
import sys
def sizeof_fmt(num, suffix='B'):
    ''' by Fred Cirera,  https://stackoverflow.com/a/1094933/1870254, modified'''
    # Human-readable size using binary (IEC) prefixes, one step per factor of 1024;
    # anything still >= 1024 Zi is reported in Yi.
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    for prefix in prefixes:
        if abs(num) < 1024.0:
            return f"{num:3.1f} {prefix}{suffix}"
        num /= 1024.0
    return f"{num:.1f} Yi{suffix}"

# Print the ten largest objects in the notebook namespace (shallow sys.getsizeof sizes).
for name, size in sorted(
    ((name, sys.getsizeof(value)) for name, value in list(globals().items())),
    key=lambda item: item[1],
    reverse=True,
)[:10]:
    print(f"{name:>30}: {sizeof_fmt(size):>8}")
    
    

                             _: 320.0 MiB
                     dists_map: 320.0 MiB
                   trans_dists: 320.0 MiB
                           _17: 320.0 MiB
                          data: 89.1 MiB
                          _i11:  8.0 KiB
                          _iii:  3.2 KiB
                          _i15:  3.2 KiB
                          tqdm:  2.0 KiB
                   pandarallel:  1.0 KiB


In [19]:
# Persist the per-pair summary statistics, then release them from memory.
with open(f'{DATASETS_PATH}date_distances.pkl', mode='wb') as f:
    pickle.dump(trans_dists, f)

del trans_dists