In [18]:
import math
import pymc as pm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import OneHotEncoder

%matplotlib inline

In [25]:
df = pd.read_csv("../../Data.csv")
df.dropna(inplace=True)

In [26]:
def transform_df(df):
    purchases = df[df['Quantity'] > 0].copy()
    returns = df[df['Quantity'] < 0].copy()
    
    returns['Quantity'] = returns['Quantity'].abs()
    
    returns_grouped = returns.groupby(['Customer ID', 'StockCode']).agg({'Quantity': 'sum'}).reset_index()
    returns_grouped['is_returned'] = 1
    
    result = pd.merge(purchases, returns_grouped[['Customer ID', 'StockCode', 'is_returned']], on=['Customer ID', 'StockCode'], how='left')
    
    # Заполнение NaN значений в is_returned как 0 (нет возврата)
    result['is_returned'] = result['is_returned'].fillna(0).astype(int)

    return result

transformed_df = transform_df(df)

In [37]:
with pm.Model() as model:
    # Априорные распределения
    price = pm.Normal('price', mu=transformed_df['Price'].mean(), sigma=transformed_df['Price'].std())
    quantity = pm.Normal('quantity', mu=transformed_df['Quantity'].mean(), sigma=transformed_df['Quantity'].std())

    # Условное распределение для is_returned
    theta = pm.Deterministic('theta', pm.math.sigmoid(price + quantity))
    is_returned = pm.Bernoulli('is_returned', p=theta, observed=transformed_df['is_returned'])

    # Обучение модели
    trace = pm.sample(1000)

# Анализ результатов
pm.summary(trace)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [price, quantity]


Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1391 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Chain 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.


Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
price,19.874,34.386,-44.007,78.582,15.817,11.958,5.0,12.0,2.31
quantity,-22.671,34.386,-81.377,41.22,15.817,11.958,5.0,12.0,2.31
theta,0.058,0.0,0.057,0.058,0.0,0.0,3678.0,3145.0,1.0
