<a href="https://colab.research.google.com/github/RaduW/volume-rebalance/blob/main/volume_rebalancing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



# Volume Rebalancing Algorithm


The volume rebalancing algorithm is based on the assumption that while given a global sample rate classes we want to adjust the individual class sampling rate in a way that equalises the number of sampled elements in each class while maintaining the overall sample rate.

In [153]:
url = "https://raw.githubusercontent.com/RaduW/volume-rebalance/main/transaction_adjustment_model.py"
!wget --no-cache --backups=1 {url}

from  transaction_adjustment_model import adjust_sample_rate

--2023-03-22 13:55:26--  https://raw.githubusercontent.com/RaduW/volume-rebalance/main/transaction_adjustment_model.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5943 (5.8K) [text/plain]
Saving to: ‘transaction_adjustment_model.py’


2023-03-22 13:55:26 (31.2 MB/s) - ‘transaction_adjustment_model.py’ saved [5943/5943]



In [154]:
from operator import itemgetter

from ipywidgets import interact, widgets
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")


In [155]:
# import transaction data
projects = "https://raw.githubusercontent.com/RaduW/volume-rebalance/main/projects.json"
trans_data = pd.read_json(projects)

PROJECTS = {
    "sentry": 1,
    "javascript": 11276,
    "snuba": 300688,
    "gibpotato-backend": 4504044639748096,
    #"gibpotato-frontend": 4504044640927744,
}

trans_data.sort_values(by=["proj_id","freq"],ignore_index=True, ascending=True, inplace=True)
trans_data

Unnamed: 0,name,freq,proj_id,proj_name
0,/hs/manage-preferences/unsubscribe,1,1,sentry
1,/repo/com/google/errorprone/error_prone_annota...,1,1,sentry
2,/api/0/customers/,1,1,sentry
3,sentry.tasks.email.process_inbound_email,1,1,sentry
4,getsentry.tasks.quotas.send_ondemand_notification,1,1,sentry
...,...,...,...,...
1345,GET /t4,2,4504044639748096,gibpotato-backend
1346,HEAD /login,4,4504044639748096,gibpotato-backend
1347,HEAD /,46,4504044639748096,gibpotato-backend
1348,GET /login,102,4504044639748096,gibpotato-backend


In [156]:
from typing import List, Mapping, MutableMapping, Tuple, Union

    
def counts_to_labeled_counts( counts: Union[float,int,Tuple[str,float]]):
  counts = sorted(counts)
  return [(f"t-{idx}",count) for idx,count in enumerate(counts)]

# Model params

The model has the following input parameters:

*   a list of initial counts representing the number of elements for each class `counts`
*   an overall desired sample rate: `global_rate` (input via slider)



In [157]:
def process_data(global_rate:float, items_high:int, items_low:int, project:str, trans_data: pd.DataFrame):

    data = trans_data.sort_values(by=["freq"],ignore_index=True, ascending=True)
    data = data[data["proj_name"] == project]

    counts = data["freq"]
    ideal_rate = counts.mean() * global_rate
    cnts_min = counts.min()
    cnts_max = counts.max()
    total = counts.sum()
    num_classes = len(counts)

    data = data.set_index("name", drop=False)
    data["original"] = data["freq"]* global_rate
    if items_low + items_high < num_classes:
      if items_high == 0:
        explicit_transactions = data[:items_low]
      else:
        explicit_transactions = pd.concat([data[:items_low], data[-items_high:]])
    else:
      explicit_transactions = data # we resize everything explicitly

    explicit_transactions_tuple = list(explicit_transactions[["name", "freq"]].itertuples(name=None, index=False))
    adjusted, implicit_rate = adjust_sample_rate(classes=explicit_transactions_tuple, rate=global_rate, total_num_classes=num_classes, total=total)
    adjusted_df = pd.DataFrame.from_dict(data = adjusted, orient="index", columns=["adjusted_rate"] )

    # now calculate the adjusted count
    df = data.join(adjusted_df, how="outer")

    # # keep count of the explicit rates
    df["explicit"]= ~df['adjusted_rate'].isna()
    # # set the adjusted rates for the implicit values 
    df["adjusted_rate"] = df["adjusted_rate"].fillna(implicit_rate)

    df["adjusted"] = df["freq"]* df["adjusted_rate"]

    print( f"Num classes:{num_classes} implicit_rate:{implicit_rate} original_rate:{global_rate}")
    return df  



def plot_rates(ax, data, last_low, first_high, global_rate, x_limit=None, log=False):
  if x_limit is not None:
    ax.set_xlim(*x_limit)
    data = data[data["freq"]<x_limit[1]]

  if log:
    plt.xscale("log")
    plt.yscale("log")
  else:
    plt.xscale("linear")
    plt.yscale("linear")


  counts_series = data["freq"]

  cnts_min = counts_series.min()
  cnts_max =  counts_series.max()
  ideal_rate = counts_series.mean() * global_rate
  rate_max = data.rate.max()

  if not log:
    #ideal level
    ax.text((cnts_min+cnts_max)/2, ideal_rate, "ideal rate", horizontalalignment='center', verticalalignment="bottom", size='medium', color='black')
    sns.lineplot( x=[cnts_min, cnts_max], y=[ideal_rate,ideal_rate], ax = ax);
    # border lower values
    sns.lineplot( x=[last_low,last_low+0.001], y=[0, rate_max], color="blue", ax=ax);
    # border higher values
    sns.lineplot( x=[first_high,first_high+0.001], y=[0, rate_max], color="red", ax=ax, );

  sns.scatterplot(data=data, x="freq", y="rate", hue ="series", style="explicit", ax=ax)

def draw_rate_change(ax, df):
  def minor_tick_format(tick_val, tick_pos):
    if tick_val < 0:
      tick_val *= -1
    while tick_val < 1:
      tick_val *=10
    while tick_val > 10:
      tick_val /=10
    return f"{int(tick_val)%10}"
    

  ax.set_yscale("log")
  ax.set_xscale("log")
  ax.grid(which='minor')
  sns.scatterplot(data=df, x="freq", y="adjusted_rate", hue ="explicit", style="explicit", ax=ax)
  ax.xaxis.set_minor_formatter(minor_tick_format)
  ax.set_title("Sample Rate")


def draw_rebalance_graphs(global_rate, items_high, items_low, project, trans_data):

    df = process_data(global_rate, items_high, items_low, project, trans_data)

    counts_series = df["freq"]

    last_low = counts_series.iloc[items_low]
    first_high = counts_series.iloc[-items_high]

    df2 = df.melt(id_vars=["freq","explicit"],value_vars=["adjusted","original"], var_name="series", value_name = "rate")
    fig, ax = plt.subplots(nrows = 4 ,figsize=(20,20))
    
    draw_rate_change(ax[0], df)
    plot_rates(ax[1], df2, last_low, first_high, global_rate)
    plot_rates(ax[2], df2, last_low, first_high, global_rate, x_limit=(-20, 220))
    plot_rates(ax[3], df2, last_low, first_high, global_rate, log=True)





### Counts

In [158]:
trans_data.head()

Unnamed: 0,name,freq,proj_id,proj_name
0,/hs/manage-preferences/unsubscribe,1,1,sentry
1,/repo/com/google/errorprone/error_prone_annota...,1,1,sentry
2,/api/0/customers/,1,1,sentry
3,sentry.tasks.email.process_inbound_email,1,1,sentry
4,getsentry.tasks.quotas.send_ondemand_notification,1,1,sentry


### Sample rate

In [159]:
global_rate=widgets.FloatSlider(min=0,max=1,value=0.1, step=0.001 , description= "sample rate");
items_high=widgets.BoundedIntText(min=0,max=num_classes,value=20,
                                  description="items high");
items_low =widgets.BoundedIntText(min=0,max=num_classes,value=20,step=1,
    description='items low')

project_names = list(PROJECTS.keys())

project= widgets.Dropdown(
    options=project_names,
    value='sentry',
    description='Project:',
)
#global_rate = widgets.FloatText(min=0,max=1,value=0.1, step=0.001 )
def rebalance_generator(trans_data):
  def inner( global_rate, items_high, items_low, project):
    draw_rebalance_graphs(global_rate, items_high,items_low, project, trans_data)
  return inner

widgets.interact(rebalance_generator(trans_data), global_rate = global_rate, items_high=items_high, items_low=items_low, project=project);

interactive(children=(FloatSlider(value=0.1, description='sample rate', max=1.0, step=0.001), BoundedIntText(v…

In [160]:
items_high = 10
items_low = 20
rate = 0.1

process_data(rate, items_high, items_low, "sentry", trans_data)


Num classes:744 implicit_rate:0.29144039410254186 original_rate:0.1


Unnamed: 0,name,freq,proj_id,proj_name,original,adjusted_rate,explicit,adjusted
,,18465,1,sentry,1846.5,0.29144,False,5381.446877
/,/,9630,1,sentry,963.0,0.29144,False,2806.570995
/*,/*,1130,1,sentry,113.0,0.29144,False,329.327645
/*/*,/*/*,1,1,sentry,0.1,0.29144,False,0.291440
/*/config/config.json,/*/config/config.json,1,1,sentry,0.1,0.29144,False,0.291440
...,...,...,...,...,...,...,...,...
sentry.tasks.update_code_owners_schema,sentry.tasks.update_code_owners_schema,121,1,sentry,12.1,0.29144,False,35.264288
sentry.tasks.user_report,sentry.tasks.user_report,11,1,sentry,1.1,0.29144,False,3.205844
tasks.invoices.create_invoices,tasks.invoices.create_invoices,4,1,sentry,0.4,0.29144,False,1.165762
tasks.invoices.retry_failed_tax_transactions,tasks.invoices.retry_failed_tax_transactions,4,1,sentry,0.4,0.29144,False,1.165762


# Scratch pad below

Ignore....