In [1]:
import warnings

import numpy as np
import pandas as pd
from darts.models.forecasting.nhits import NHiTSModel
from darts import TimeSeries
import torch
import pickle
from typing import Callable
from tqdm import tqdm

from helpers import predict, load_agent, quality, clip
from preprocess import preprocess_stats
from rl.sim_enviroment import SimulatedQualEnv

from evidently.report import Report
from evidently.metric_preset import DataDriftPreset
from evidently.options import DataDriftOptions

The `LightGBM` module could not be imported. To enable LightGBM support in Darts, follow the detailed instructions in the installation guide: https://github.com/unit8co/darts/blob/master/INSTALL.md
The `Prophet` module could not be imported. To enable Prophet support in Darts, follow the detailed instructions in the installation guide: https://github.com/unit8co/darts/blob/master/INSTALL.md
The `CatBoost` module could not be imported. To enable CatBoost support in Darts, follow the detailed instructions in the installation guide: https://github.com/unit8co/darts/blob/master/INSTALL.md


In [2]:
import pickle

# Load the two pickled scaler objects that optimize_params later applies via
# `.transform(...)` to the covariates and to the target series respectively.
# NOTE(review): both paths are absolute and machine-specific; a configurable
# data directory would make the notebook portable.
# NOTE(review): pickle.load executes arbitrary code embedded in the file --
# only load artifacts produced by a trusted training run.
with open("/home/justbadcarma/Bachelor/data/cov_scaler.pkl", "rb") as f:
    cov_scaler = pickle.load(f)

with open("/home/justbadcarma/Bachelor/data/series_scaler.pkl", "rb") as f:
    series_scaler = pickle.load(f)

In [3]:

from helpers import reward_func


def optimize_params(data: pd.DataFrame, preprocess: Callable = preprocess_stats, device='cpu') -> pd.DataFrame:
    """
    Roll the saved agent through a simulated environment built from ``data``.

    The observations are preprocessed, wrapped into scaled darts ``TimeSeries``
    and handed to a ``SimulatedQualEnv`` driven by the NHiTS state predictor.
    For every step the environment is reset, the agent proposes a two-element
    action, and the resulting state and reward are recorded.

    :param data:        raw observations in pandas DataFrame
    :param preprocess:  function called as ``preprocess(data, columns)``; its
                        result must expose the two target columns first and the
                        four covariate columns last once 'Cell ID' / 'LAC' are
                        dropped (they are selected positionally)
    :param device:      forwarded as ``map_location`` when loading the
                        state-predictor checkpoint
    :return:            DataFrame with one row per simulated step: the six
                        state columns, 'Quality Rate', 'cum_reward' and
                        'mom_reward'
    """
    # Window length shared by the environment (n_past) and the step count below.
    history_len = 7

    columns = ['Cell ID', 'LAC', 'HR Usage Rate', 'TCH Blocking Rate, BH', 'Number of Available\nTCH',
               'TCH Traffic (Erl), BH', 'Lower_limit', 'Upper_limit']
    # Everything after the two identifier columns names a state component.
    state_columns = columns[2:]

    prepared = preprocess(data, columns)
    observations = (
        prepared
        .drop(columns=['Cell ID', 'LAC'], errors='ignore')
        .rename_axis(None, axis=1)
        .reset_index(drop=True)
    )

    agent = load_agent('sac_best_enough_qual.pt', 'pt')
    state_predictor = NHiTSModel.load_from_checkpoint(
        "/home/justbadcarma/Bachelor/rl/nhits_35lw_2l_1b_3s_35_lr4",
        "state_predictor",
        best=True,
        map_location=device,
    )

    # First two columns are the target series, last four the covariates.
    target_series = series_scaler.transform(TimeSeries.from_dataframe(observations.iloc[:, :2]))
    covariate_series = cov_scaler.transform(TimeSeries.from_dataframe(observations.iloc[:, -4:]))

    # Environment used both to advance the state and to compute the reward.
    environment = SimulatedQualEnv(
        quality_function=reward_func,
        env=state_predictor,
        action_range=np.array([1, 1]),
        series=target_series,
        covariates=covariate_series,
        n_past=history_len,
        scaler=series_scaler,
    )

    step_rewards = []
    step_qualities = []
    step_states = []

    n_steps = len(observations) - history_len
    for _ in range(n_steps):
        observation = environment.reset()
        first_action, second_action = predict(observation, agent)
        next_state, step_reward, _done, _info = environment.step(np.array([first_action, second_action]))

        step_rewards.append(step_reward)
        # Quality reported by the environment after the action was applied.
        step_qualities.append(environment.quality_after)
        step_states.append(next_state)

    states_df = pd.DataFrame(step_states, columns=state_columns)

    # Quality Rate = 1 - (2*HR/100 + ln(blocking + 1)) / (1 + ln(101))
    hr_term = 2*states_df['HR Usage Rate']/100
    blocking_term = np.log(states_df['TCH Blocking Rate, BH'] + 1)
    states_df["Quality Rate"] = 1 - (hr_term + blocking_term)/(1 + np.log(101))

    states_df['cum_reward'] = np.cumsum(step_rewards)
    states_df['mom_reward'] = step_rewards

    return states_df


In [4]:
from typing import List


def preprocess_full(data: pd.DataFrame, cols: List[str]=None):
    """
    Select and order the model input columns from a raw per-cell frame.

    'Param 1' / 'Param 2' are renamed to 'Lower_limit' / 'Upper_limit' and the
    six feature columns are returned in the fixed order the environment
    expects (two targets first, four covariates last). The input frame is not
    modified.

    :param data:  raw observations; must contain the four KPI columns plus
                  'Param 1' and 'Param 2' (or the already-renamed limits).
                  Extra columns such as 'Cell ID' or 'DATA' are simply left
                  out of the result, so they may be present or absent.
    :param cols:  ignored. It exists only so the function can be called as
                  ``preprocess(data, columns)`` like other preprocessors.
    :return:      new DataFrame holding exactly the six feature columns
    :raises KeyError: if any of the six feature columns is missing after the
                  rename
    """
    feature_cols = ['HR Usage Rate', 'TCH Blocking Rate, BH', 'Number of Available\nTCH',
                    'TCH Traffic (Erl), BH', 'Lower_limit', 'Upper_limit']
    # rename() returns a new frame, so the caller's data is left untouched and
    # no explicit copy or in-place drop of 'DATA' is needed: the selection
    # below discards every non-feature column anyway.
    renamed = data.rename(columns={'Param 1': feature_cols[-2], 'Param 2': feature_cols[-1]})
    return renamed[feature_cols]

In [5]:
# Load the train and test splits; the first CSV column is used as the row index.
# Paths are relative to the notebook's working directory.
train_df = pd.read_csv('data/train_data.csv', index_col=0) # train data
test_df = pd.read_csv('data/test_data.csv', index_col=0) 

In [6]:
train_df

Unnamed: 0,Cell ID,DATA,Number of Available\nTCH,HR Usage Rate,"TCH Blocking Rate, BH","TCH Traffic (Erl), BH",Param 1,Param 2
23,12086,2020-10-02,2.0,85.03,0.00,2.19,36,39
36,12092,2020-10-02,3.0,50.82,0.00,2.50,54,78
40,26303,2020-10-02,4.0,99.60,0.00,10.99,85,97
53,782,2020-10-02,2.0,19.23,0.00,0.97,32,50
55,783,2020-10-02,2.0,4.83,0.00,0.48,46,62
...,...,...,...,...,...,...,...,...
193989,882,2023-03-28,13.0,88.00,0.00,2.50,26,43
193990,887,2023-03-28,13.0,99.00,0.28,15.60,12,21
193991,883,2023-03-28,12.0,100.00,19.96,22.57,18,29
193992,888,2023-03-28,12.0,86.00,0.00,5.53,19,31


In [7]:
test_df

Unnamed: 0,Cell ID,DATA,Number of Available\nTCH,HR Usage Rate,"TCH Blocking Rate, BH","TCH Traffic (Erl), BH",Param 1,Param 2
0,25771,2023-01-02,13,96.0,0.00,5.47,46,54
1,25772,2023-01-02,13,75.0,0.00,5.34,25,43
2,25773,2023-01-02,21,72.0,0.00,7.71,27,47
3,3361,2023-01-02,21,49.0,0.00,9.02,20,30
4,3362,2023-01-02,6,92.0,0.00,1.38,11,21
...,...,...,...,...,...,...,...,...
193066,12746,2023-03-28,12,52.0,0.00,1.26,8,16
193067,12747,2023-03-28,12,50.0,0.00,1.37,41,49
193068,12781,2023-03-28,12,100.0,0.00,3.70,40,60
193069,12782,2023-03-28,12,99.0,0.00,4.92,40,60


In [20]:
# Debug probe: print the shape of the rows belonging to the most frequent
# Cell ID and stop after the first one.
# value_counts() on a one-column DataFrame is keyed by 1-tuples, so `cell`
# is e.g. (12083,) rather than a bare id.
# NOTE(review): the `== cell` comparison against a 1-tuple appears to rely on
# broadcasting; iterating test_df['Cell ID'] values directly would be clearer.
for cell in tqdm(test_df[['Cell ID']].value_counts().keys()):
    print(test_df[test_df['Cell ID'] == cell].shape)
    break

  0%|          | 0/100 [00:00<?, ?it/s]

(191, 8)





In [22]:
test_df[['Cell ID']].value_counts()

Cell ID
12083      191
12086      191
12472      191
12471      191
12097      191
          ... 
782        185
12752       15
5683         7
5682         7
5681         7
Name: count, Length: 100, dtype: int64

In [6]:
# cell_list = list(map(lambda x: x[0], df[['Cell ID']].value_counts().index[:10].tolist()))
# curr = df[df['Cell ID'].isin(cell_list)]
# reff = df[~df['Cell ID'].isin(cell_list)]

## Train-test

In [8]:
%%time

# Train-test experiment: for every Cell ID in the test split, measure data
# drift against the full training frame and run the agent on that cell's rows.
# `scores` collects one row of summary statistics per cell, `full_rewards`
# keeps the complete per-step series.
scores = []
full_rewards = []
# Reference data for the drift report is the whole training frame.
reff = train_df
# NOTE(review): train_df.columns also contains 'Cell ID' and 'DATA', so those
# columns take part in the drift share as well -- confirm this is intended.
cols = train_df.columns

import logging
# logging.getLogger("pytorch_lightning.utilities.rank_zero").setLevel(logging.WARNING)

# Silence Python warnings for the duration of the loop. The logger level set
# below is NOT undone when the `with` block exits; it stays at WARNING.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

    # Keys of value_counts() on a one-column DataFrame are 1-tuples, which is
    # why the id is read back as `cell[0]` further down.
    for cell in tqdm(test_df[['Cell ID']].value_counts().keys()[:]):
        cell_data = test_df[test_df['Cell ID'] == cell]
        
        # optimize_params runs len(rows) - 7 steps with a 7-row window, so
        # cells with 7 or fewer rows would yield no steps; skip them.
        if len(cell_data) <= 7:
            continue
        
        # A fresh Evidently report per cell: current = this cell's rows,
        # reference = the training frame.
        data_drift_report = Report(metrics=[
            DataDriftPreset(),
        ])
        data_drift_report.run(reference_data=reff, current_data=cell_data[cols],)
        # Share of columns flagged as drifted, taken from the first metric.
        drift = data_drift_report.as_dict()['metrics'][0]['result']['share_of_drifted_columns']

        # Simulate the agent on this cell (slow: loads agent and model anew).
        states = optimize_params(cell_data, preprocess=preprocess_full)

        # Per-cell summary. Note there is deliberately no 'cum_reward_min'
        # key here; the column set below is what downstream cells display.
        scores.append({
            'cell_id': cell[0],
            'drift_score': drift,
            'quality_avg': states['Quality Rate'].mean(),
            'quality_min': states['Quality Rate'].min(),
            'quality_max': states['Quality Rate'].max(),
            'quality_std': states['Quality Rate'].std(),
            'cum_reward_avg': states['cum_reward'].mean(),
            'cum_reward_max': states['cum_reward'].max(),
            'cum_reward_std': states['cum_reward'].std(),
            'mom_reward_avg': states['mom_reward'].mean(),
            'mom_reward_min': states['mom_reward'].min(),
            'mom_reward_max': states['mom_reward'].max(),
            'mom_reward_std': states['mom_reward'].std(),
        })
        # Full per-step series; each value is a pandas Series stored in a cell.
        full_rewards.append({
            'cell_id': cell[0],
            'drift_score': drift,
            'quality': states['Quality Rate'],
            'cum_reward': states['cum_reward'],
            'mom_reward': states['mom_reward'],
        })

# One row per evaluated cell (cells skipped above are absent).
scores_df = pd.DataFrame(scores)
# scores_df.to_csv('drift_scores_rewards_new_agent_train-test_no_sample.csv')

full_rewards_df = pd.DataFrame(full_rewards)
# full_rewards_df.to_pickle('data/cell/full_rewards_new_agent_train-test_no_sample.pkl')
scores_df

100%|██████████| 100/100 [09:22<00:00,  5.63s/it]

CPU times: user 8min 38s, sys: 41.7 s, total: 9min 20s
Wall time: 9min 22s





Unnamed: 0,cell_id,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
0,12083,1.0,0.938036,0.788178,0.968796,0.028284,518.406946,1047.914420,259.952838,4.985153,-129.222951,10.0,19.731482
1,12086,1.0,0.959679,0.664859,0.994878,0.050413,-264.679536,350.000000,394.226976,-5.510784,-217.747710,10.0,44.921209
2,12472,1.0,0.920484,0.495793,0.992449,0.097400,-1305.646924,-20.000000,962.719535,-17.577757,-269.207217,10.0,63.259931
3,12471,1.0,0.902566,0.321151,0.996606,0.143719,-2046.224926,50.000000,1451.462522,-27.004052,-398.833829,10.0,83.854511
4,12097,1.0,0.992145,0.885358,0.997306,0.012103,857.826087,1720.000000,497.177445,9.347826,-20.000000,10.0,4.386853
...,...,...,...,...,...,...,...,...,...,...,...,...,...
92,41798,1.0,0.992323,0.982768,0.996593,0.003523,915.000000,1820.000000,526.830143,10.000000,10.000000,10.0,0.000000
93,24461,1.0,0.912180,0.408314,0.995568,0.127461,-1154.655853,102.076908,929.694255,-16.223734,-273.822333,10.0,69.987364
94,752,1.0,0.989428,0.962648,0.993906,0.003364,871.381215,1730.000000,500.035303,9.502762,-20.000000,10.0,3.840753
95,782,1.0,0.982470,0.842098,0.996267,0.027049,415.731187,746.377686,178.602903,3.435516,-199.673491,10.0,32.084778


In [9]:
scores_df[scores_df.drift_score < 1]

Unnamed: 0,cell_id,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
20,12473,0.875,0.95154,0.370663,0.987925,0.086598,307.234492,727.674105,211.706205,1.262364,-339.879131,10.0,45.30457
23,12747,0.875,0.963015,0.606044,0.996799,0.065175,189.456727,530.0,177.183295,-0.670465,-227.95826,10.0,47.189357
27,24042,0.875,0.967672,0.521078,0.994007,0.059037,66.45553,409.312652,131.243955,1.246807,-247.518479,10.0,40.973754
37,12781,0.875,0.950312,0.491363,0.996952,0.084563,182.245653,597.00535,317.050025,-3.845445,-271.954213,10.0,50.768228
41,10782,0.875,0.980698,0.546848,0.994508,0.050483,554.140994,939.007071,257.690867,4.813616,-251.757999,10.0,29.395749
51,781,0.875,0.962758,0.429241,0.995583,0.08577,503.297881,1030.0,256.591081,2.027805,-306.242026,10.0,41.989559
54,754,0.875,0.986341,0.38299,0.994713,0.048361,839.771481,1660.0,477.866658,7.168349,-331.023768,10.0,25.631036
56,41797,0.875,0.971308,0.62629,0.995286,0.040962,199.396197,392.65282,90.543816,2.133983,-224.159521,10.0,27.811624
65,4821,0.875,0.97441,0.4096,0.995096,0.062623,622.07971,1259.447015,314.399759,4.104267,-297.939318,10.0,33.990477
68,8506,0.875,0.977595,0.51912,0.994102,0.050201,306.904392,773.463507,210.868784,4.203606,-254.779897,10.0,32.136263


In [35]:
scores_df

Unnamed: 0,cell_id,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
0,12083,1.0,0.938013,0.788178,0.966074,0.028063,608.948202,1207.645230,316.178354,5.853255,-129.222951,10.0,18.248177
1,12086,1.0,0.959537,0.664859,0.994064,0.050416,-264.134989,350.000000,393.319951,-5.468896,-217.747710,10.0,44.819021
2,12472,1.0,0.920239,0.495793,0.991855,0.097394,-1305.232005,-20.000000,962.661219,-17.577496,-269.207217,10.0,63.255099
3,12471,1.0,0.902759,0.321151,0.996716,0.143830,-1963.455994,50.000000,1389.198569,-26.061547,-398.833829,10.0,83.780603
4,12097,1.0,0.992244,0.888126,0.996904,0.011857,857.826087,1720.000000,497.177445,9.347826,-20.000000,10.0,4.386853
...,...,...,...,...,...,...,...,...,...,...,...,...,...
92,41798,1.0,0.992314,0.982041,0.997012,0.003670,915.000000,1820.000000,526.830143,10.000000,10.000000,10.0,0.000000
93,24461,1.0,0.912222,0.408314,0.995005,0.127392,-1132.487031,105.930617,903.092628,-15.175110,-273.822333,10.0,69.428609
94,752,1.0,0.989371,0.962648,0.994787,0.003382,871.381215,1730.000000,500.035303,9.502762,-20.000000,10.0,3.840753
95,782,1.0,0.982556,0.842098,0.995773,0.027042,413.955080,746.376656,178.660681,3.435510,-199.673491,10.0,32.090957


In [36]:
scores_df.describe()

Unnamed: 0,cell_id,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
count,97.0,97.0,97.0,97.0,97.0,97.0,97.0,97.0,97.0,97.0,97.0,97.0,97.0
mean,12245.371134,0.974227,0.946243,0.681222,0.992024,0.058301,-367.937847,829.286493,737.95376,-6.197041,-180.101472,9.690722,36.265108
std,10224.8214,0.060573,0.057266,0.239723,0.013812,0.059298,1752.617846,725.729413,842.644393,23.035985,145.2073,3.046038,34.514361
min,721.0,0.833333,0.722807,0.266383,0.870611,0.002085,-6725.652643,-218.856386,75.625585,-116.493282,-491.849328,-20.0,0.0
25%,3371.0,1.0,0.933964,0.462082,0.991906,0.010453,-1049.433023,40.0,317.050025,-14.94257,-273.822333,10.0,7.131916
50%,12084.0,1.0,0.971406,0.741239,0.994964,0.028063,413.95508,773.463507,477.866658,3.767686,-217.782818,10.0,23.564447
75%,13315.0,1.0,0.983661,0.915992,0.996341,0.097272,836.630435,1600.0,683.52915,8.695652,-20.0,10.0,62.160016
max,41798.0,1.0,0.992974,0.984201,0.997855,0.22484,925.0,1840.0,4200.012443,10.0,10.0,10.0,126.791858


In [37]:
scores_df[scores_df.columns[1:]].corr()

Unnamed: 0,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
drift_score,1.0,-0.167015,0.311934,-0.085447,-0.015961,-0.159057,0.07025,0.248586,-0.141625,0.271059,-0.043652,-0.002695
quality_avg,-0.167015,1.0,0.703766,0.433769,-0.894567,0.885869,0.70763,-0.80771,0.967619,0.689232,0.400284,-0.908476
quality_min,0.311934,0.703766,1.0,0.050639,-0.905704,0.71316,0.807121,-0.525844,0.675935,0.960497,0.078483,-0.889497
quality_max,-0.085447,0.433769,0.050639,1.0,-0.141585,0.03187,0.125896,0.002635,0.466241,0.023087,0.901825,-0.208966
quality_std,-0.015961,-0.894567,-0.905704,-0.141585,1.0,-0.884414,-0.796498,0.760906,-0.856382,-0.877951,-0.154389,0.979108
cum_reward_avg,-0.159057,0.885869,0.71316,0.03187,-0.884414,1.0,0.736423,-0.928909,0.871856,0.724348,0.007354,-0.89237
cum_reward_max,0.07025,0.70763,0.807121,0.125896,-0.796498,0.736423,1.0,-0.444502,0.71307,0.780141,0.14817,-0.840019
cum_reward_std,0.248586,-0.80771,-0.525844,0.002635,0.760906,-0.928909,-0.444502,1.0,-0.785478,-0.55658,0.052773,0.74513
mom_reward_avg,-0.141625,0.967619,0.675935,0.466241,-0.856382,0.871856,0.71307,-0.785478,1.0,0.674518,0.491212,-0.888613
mom_reward_min,0.271059,0.689232,0.960497,0.023087,-0.877951,0.724348,0.780141,-0.55658,0.674518,1.0,0.053859,-0.87977


## Train-train

In [44]:
%%time

# Train-train experiment: same procedure as the train-test cell, but both the
# per-cell "current" data and the drift reference come from train_df.
# NOTE(review): this cell rebinds scores, full_rewards, scores_df and
# full_rewards_df, overwriting the results of the train-test run.
scores = []
full_rewards = []
# Reference data for the drift report is the whole training frame.
reff = train_df
# NOTE(review): train_df.columns also contains 'Cell ID' and 'DATA', so those
# columns take part in the drift share as well -- confirm this is intended.
cols = train_df.columns

import logging
# logging.getLogger("pytorch_lightning.utilities.rank_zero").setLevel(logging.WARNING)

# Silence Python warnings for the duration of the loop. The logger level set
# below is NOT undone when the `with` block exits; it stays at WARNING.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

    # Keys of value_counts() on a one-column DataFrame are 1-tuples, which is
    # why the id is read back as `cell[0]` further down.
    for cell in tqdm(train_df[['Cell ID']].value_counts().keys()[:]):
        cell_data = train_df[train_df['Cell ID'] == cell]

        # optimize_params runs len(rows) - 7 steps with a 7-row window, so
        # cells with 7 or fewer rows would yield no steps; skip them.
        if len(cell_data) <= 7:
            continue

        # A fresh Evidently report per cell: current = this cell's rows,
        # reference = the training frame (which contains those same rows).
        data_drift_report = Report(metrics=[
            DataDriftPreset(),
        ])
        data_drift_report.run(reference_data=reff, current_data=cell_data[cols],)
        # Share of columns flagged as drifted, taken from the first metric.
        drift = data_drift_report.as_dict()['metrics'][0]['result']['share_of_drifted_columns']

        # Simulate the agent on this cell (slow: loads agent and model anew).
        states = optimize_params(cell_data, preprocess=preprocess_full)

        # Per-cell summary statistics of quality and rewards.
        scores.append({
            'cell_id': cell[0],
            'drift_score': drift,
            'quality_avg': states['Quality Rate'].mean(),
            'quality_min': states['Quality Rate'].min(),
            'quality_max': states['Quality Rate'].max(),
            'quality_std': states['Quality Rate'].std(),
            'cum_reward_avg': states['cum_reward'].mean(),
            'cum_reward_max': states['cum_reward'].max(),
            'cum_reward_std': states['cum_reward'].std(),
            'mom_reward_avg': states['mom_reward'].mean(),
            'mom_reward_min': states['mom_reward'].min(),
            'mom_reward_max': states['mom_reward'].max(),
            'mom_reward_std': states['mom_reward'].std(),
        })
        # Full per-step series; each value is a pandas Series stored in a cell.
        full_rewards.append({
            'cell_id': cell[0],
            'drift_score': drift,
            'quality': states['Quality Rate'],
            'cum_reward': states['cum_reward'],
            'mom_reward': states['mom_reward'],
        })

# Persist the summary table next to the notebook (existing file is overwritten).
scores_df = pd.DataFrame(scores)
scores_df.to_csv('drift_scores_rewards_new_agent_train-train_no_sample.csv')

# Pickle is used here because the cells hold whole Series objects.
# The 'data/cell/' directory must already exist.
full_rewards_df = pd.DataFrame(full_rewards)
full_rewards_df.to_pickle('data/cell/full_rewards_new_agent_train-train_no_sample.pkl')

100%|██████████| 936/936 [1:46:55<00:00,  6.85s/it]

CPU times: user 1h 40min 32s, sys: 6min 11s, total: 1h 46min 43s
Wall time: 1h 46min 55s





In [39]:
scores_df

Unnamed: 0,cell_id,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
0,12083,1.0,0.938013,0.788178,0.966074,0.028063,608.948202,1207.645230,316.178354,5.853255,-129.222951,10.0,18.248177
1,12086,1.0,0.959537,0.664859,0.994064,0.050416,-264.134989,350.000000,393.319951,-5.468896,-217.747710,10.0,44.819021
2,12472,1.0,0.920239,0.495793,0.991855,0.097394,-1305.232005,-20.000000,962.661219,-17.577496,-269.207217,10.0,63.255099
3,12471,1.0,0.902759,0.321151,0.996716,0.143830,-1963.455994,50.000000,1389.198569,-26.061547,-398.833829,10.0,83.780603
4,12097,1.0,0.992244,0.888126,0.996904,0.011857,857.826087,1720.000000,497.177445,9.347826,-20.000000,10.0,4.386853
...,...,...,...,...,...,...,...,...,...,...,...,...,...
92,41798,1.0,0.992314,0.982041,0.997012,0.003670,915.000000,1820.000000,526.830143,10.000000,10.000000,10.0,0.000000
93,24461,1.0,0.912222,0.408314,0.995005,0.127392,-1132.487031,105.930617,903.092628,-15.175110,-273.822333,10.0,69.428609
94,752,1.0,0.989371,0.962648,0.994787,0.003382,871.381215,1730.000000,500.035303,9.502762,-20.000000,10.0,3.840753
95,782,1.0,0.982556,0.842098,0.995773,0.027042,413.955080,746.376656,178.660681,3.435510,-199.673491,10.0,32.090957


In [40]:
scores_df.describe()

Unnamed: 0,cell_id,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
count,97.0,97.0,97.0,97.0,97.0,97.0,97.0,97.0,97.0,97.0,97.0,97.0,97.0
mean,12245.371134,0.974227,0.946243,0.681222,0.992024,0.058301,-367.937847,829.286493,737.95376,-6.197041,-180.101472,9.690722,36.265108
std,10224.8214,0.060573,0.057266,0.239723,0.013812,0.059298,1752.617846,725.729413,842.644393,23.035985,145.2073,3.046038,34.514361
min,721.0,0.833333,0.722807,0.266383,0.870611,0.002085,-6725.652643,-218.856386,75.625585,-116.493282,-491.849328,-20.0,0.0
25%,3371.0,1.0,0.933964,0.462082,0.991906,0.010453,-1049.433023,40.0,317.050025,-14.94257,-273.822333,10.0,7.131916
50%,12084.0,1.0,0.971406,0.741239,0.994964,0.028063,413.95508,773.463507,477.866658,3.767686,-217.782818,10.0,23.564447
75%,13315.0,1.0,0.983661,0.915992,0.996341,0.097272,836.630435,1600.0,683.52915,8.695652,-20.0,10.0,62.160016
max,41798.0,1.0,0.992974,0.984201,0.997855,0.22484,925.0,1840.0,4200.012443,10.0,10.0,10.0,126.791858


In [41]:
scores_df[scores_df.columns[1:]].corr()

Unnamed: 0,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
drift_score,1.0,-0.167015,0.311934,-0.085447,-0.015961,-0.159057,0.07025,0.248586,-0.141625,0.271059,-0.043652,-0.002695
quality_avg,-0.167015,1.0,0.703766,0.433769,-0.894567,0.885869,0.70763,-0.80771,0.967619,0.689232,0.400284,-0.908476
quality_min,0.311934,0.703766,1.0,0.050639,-0.905704,0.71316,0.807121,-0.525844,0.675935,0.960497,0.078483,-0.889497
quality_max,-0.085447,0.433769,0.050639,1.0,-0.141585,0.03187,0.125896,0.002635,0.466241,0.023087,0.901825,-0.208966
quality_std,-0.015961,-0.894567,-0.905704,-0.141585,1.0,-0.884414,-0.796498,0.760906,-0.856382,-0.877951,-0.154389,0.979108
cum_reward_avg,-0.159057,0.885869,0.71316,0.03187,-0.884414,1.0,0.736423,-0.928909,0.871856,0.724348,0.007354,-0.89237
cum_reward_max,0.07025,0.70763,0.807121,0.125896,-0.796498,0.736423,1.0,-0.444502,0.71307,0.780141,0.14817,-0.840019
cum_reward_std,0.248586,-0.80771,-0.525844,0.002635,0.760906,-0.928909,-0.444502,1.0,-0.785478,-0.55658,0.052773,0.74513
mom_reward_avg,-0.141625,0.967619,0.675935,0.466241,-0.856382,0.871856,0.71307,-0.785478,1.0,0.674518,0.491212,-0.888613
mom_reward_min,0.271059,0.689232,0.960497,0.023087,-0.877951,0.724348,0.780141,-0.55658,0.674518,1.0,0.053859,-0.87977


In [23]:
# Number of distinct cells in `df`.
# NOTE(review): no cell shown in this notebook defines a global `df`
# (only train_df / test_df exist), so this fails on a fresh kernel --
# presumably a leftover from an earlier session; confirm which frame is meant.
len(df['Cell ID'].unique())

1043

In [11]:
300 / len(df['Cell ID'].unique())

In [10]:
scores_df

In [7]:
import logging

loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]

In [13]:
# Debug aid: dump the string form of every logger collected in `loggers`
# to loggers.txt in the working directory (overwriting any existing file).
with open('loggers.txt', 'w') as f:
    for item in loggers:
        # write each item on a new line
        f.write("%s\n" % item)

In [20]:
class LessThanFilter(logging.Filter):
    """Logging filter that lets through only records below a given level.

    ``exclusive_maximum`` is the first level that gets rejected: a record is
    kept when ``record.levelno < exclusive_maximum`` and dropped otherwise.
    ``name`` is forwarded unchanged to ``logging.Filter``.
    """

    def __init__(self, exclusive_maximum, name=""):
        super().__init__(name)
        self.max_level = exclusive_maximum

    def filter(self, record):
        # logging treats a non-zero return value as "emit this record".
        below_cutoff = record.levelno < self.max_level
        return int(below_cutoff)

# Attach the filter to Lightning's rank-zero logger: records logged through it
# at ERROR or above are dropped, anything below ERROR still passes.
# Re-running this cell adds another filter instance each time.
logging.getLogger("pytorch_lightning.utilities.rank_zero").addFilter(LessThanFilter(logging.ERROR))

In [23]:
logging.getLogger("pytorch_lightning.utilities.rank_zero").error('asd')
logging.getLogger("pytorch_lightning.utilities.rank_zero").error('asd')
logging.getLogger("pytorch_lightning.utilities.rank_zero").error('asd')
logging.getLogger("pytorch_lightning.utilities.rank_zero").error('asd')
logging.getLogger("pytorch_lightning.utilities.rank_zero").error('asd')

In [27]:
logging.getLogger("pytorch_lightning.utilities.rank_zero").findCaller()

('/home/rid/Soft/anaconda3/envs/sm_bachelor/lib/python3.9/site-packages/IPython/core/interactiveshell.py',
 3448,
 'run_ast_nodes',
 None)

In [9]:
%%time

# Timing probe: cost of building and running one drift report.
# NOTE(review): relies on `reff` and `cell_data` left behind by the last
# iteration of an evaluation loop cell -- it will not run on a fresh kernel
# unless one of those cells has been executed first. Unlike the loop cells,
# cell_data is passed here without the `[cols]` column selection.
data_drift_report = Report(metrics=[
    DataDriftPreset(),
])
data_drift_report.run(reference_data=reff, current_data=cell_data,)
drift = data_drift_report.as_dict()['metrics'][0]['result']['share_of_drifted_columns']

CPU times: user 6.43 s, sys: 212 ms, total: 6.65 s
Wall time: 6.8 s


In [10]:
%%time

data_drift_report.run(reference_data=reff, current_data=cell_data,)
drift = data_drift_report.as_dict()['metrics'][0]['result']['share_of_drifted_columns']

CPU times: user 6.22 s, sys: 168 ms, total: 6.39 s
Wall time: 6.49 s


In [11]:
%%time

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    states = optimize_params(cell_data, preprocess=preprocess_full)

CPU times: user 42.4 s, sys: 1.72 s, total: 44.1 s
Wall time: 44.1 s


In [2]:
torch.cuda.is_available()

True

In [10]:
%%time

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    states = optimize_params(cell_data, preprocess=preprocess_full)

CPU times: user 1min 27s, sys: 2.28 s, total: 1min 30s
Wall time: 54.5 s


In [11]:
%%time

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    states = optimize_params(cell_data, preprocess=preprocess_full, device='cuda')

CPU times: user 1min 23s, sys: 1.81 s, total: 1min 25s
Wall time: 48.2 s


# Analysis

In [7]:
# Load a previously saved per-cell score table for the analysis below.
# NOTE(review): no cell shown in this notebook writes
# 'drift_scores_rewards_all.csv'; it presumably comes from an earlier run --
# record its provenance. In this file cell_id is stored as a tuple string
# such as "(1946,)", not as a bare integer id.
data = pd.read_csv('drift_scores_rewards_all.csv', index_col=0)

In [8]:
data

Unnamed: 0,cell_id,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
0,"(1946,)",0.875,0.874160,0.734750,1.688068,0.056874,-3.587066e+06,-50,3.191874e+06,-17734.276206,-34940,-50,10101.580818
1,"(1945,)",0.875,0.881240,0.712539,1.069511,0.036131,-3.584984e+06,-45,3.210989e+06,-17866.888519,-35685,-45,10343.056816
2,"(1947,)",1.000,0.842714,0.679234,0.999486,0.047754,-1.135826e+06,-10,1.033637e+06,-5546.775000,-9820,-10,3137.497803
3,"(1941,)",0.875,0.852673,0.678700,1.311538,0.068895,-3.447562e+06,-50,3.098796e+06,-17337.090301,-34600,-50,10117.079323
4,"(1943,)",1.000,0.806085,0.669411,1.095661,0.088054,-3.405338e+06,-45,2.981054e+06,-16600.342809,-32240,-45,9092.147583
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1038,"(12482,)",0.875,0.868466,0.725107,1.062574,0.031370,-3.961857e+05,-45,3.373165e+05,-5342.690476,-10220,-45,2780.442553
1039,"(12483,)",0.875,0.883805,0.840871,0.929947,0.020493,-9.898051e+04,5,9.285494e+04,-2863.457944,-5975,5,1811.103647
1040,"(12481,)",0.875,0.965920,0.886776,1.042712,0.034750,-1.163850e+05,-50,1.039313e+05,-3215.140187,-6395,-50,1861.693343
1041,"(13323,)",0.875,0.981018,0.911628,1.053775,0.032089,-5.948250e+04,-45,5.296255e+04,-2295.000000,-4545,-45,1324.990566


In [9]:
data[data.columns[1:]].corr()

Unnamed: 0,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
drift_score,1.0,-0.252896,-0.098583,-0.039432,-0.12755,0.006945,0.068895,-0.021761,0.066041,0.083738,0.066402,-0.093436
quality_avg,-0.252896,1.0,0.502643,0.33352,0.159484,0.045978,-0.038443,0.009329,-0.147332,-0.238948,-0.042868,0.259073
quality_min,-0.098583,0.502643,1.0,0.119505,-0.134206,0.07053,-0.045746,-0.026577,-0.08806,-0.17231,-0.047839,0.182302
quality_max,-0.039432,0.33352,0.119505,1.0,0.468174,-0.152249,-0.073236,0.173904,-0.215231,-0.239303,-0.114423,0.240701
quality_std,-0.12755,0.159484,-0.134206,0.468174,1.0,-0.085979,-0.064057,0.063141,-0.068361,-0.014295,-0.175647,0.007254
cum_reward_avg,0.006945,0.045978,0.07053,-0.152249,-0.085979,1.0,0.171343,-0.990619,0.932597,0.820249,0.390343,-0.802359
cum_reward_max,0.068895,-0.038443,-0.045746,-0.073236,-0.064057,0.171343,1.0,-0.136974,0.164816,0.088091,0.643997,-0.06382
cum_reward_std,-0.021761,0.009329,-0.026577,0.173904,0.063141,-0.990619,-0.136974,1.0,-0.960482,-0.882851,-0.335615,0.870986
mom_reward_avg,0.066041,-0.147332,-0.08806,-0.215231,-0.068361,0.932597,0.164816,-0.960482,1.0,0.957747,0.382207,-0.94869
mom_reward_min,0.083738,-0.238948,-0.17231,-0.239303,-0.014295,0.820249,0.088091,-0.882851,0.957747,1.0,0.236172,-0.995768


In [10]:
data[data.columns[1:]].corr(method='kendall')

Unnamed: 0,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
drift_score,1.0,-0.222758,-0.091771,-0.029052,-0.115285,0.072701,0.045894,-0.073499,0.08375,0.090964,0.051647,-0.100316
quality_avg,-0.222758,1.0,0.369532,0.226132,0.029151,-0.063192,-0.003026,0.080686,-0.144342,-0.19744,-0.006468,0.20341
quality_min,-0.091771,0.369532,1.0,0.108647,-0.098588,-0.09175,0.017899,0.099998,-0.152579,-0.193956,0.013766,0.190701
quality_max,-0.029052,0.226132,0.108647,1.0,0.245643,-0.148184,-0.051916,0.153613,-0.176647,-0.184298,-0.048123,0.182747
quality_std,-0.115285,0.029151,-0.098588,0.245643,1.0,-0.014785,-0.141789,0.008767,-0.001689,0.01791,-0.141768,-0.020737
cum_reward_avg,0.072701,-0.063192,-0.09175,-0.148184,-0.014785,1.0,0.018777,-0.951919,0.858744,0.737975,0.023881,-0.743357
cum_reward_max,0.045894,-0.003026,0.017899,-0.051916,-0.141789,0.018777,1.0,-0.003903,-0.004915,-0.033414,0.991089,0.042257
cum_reward_std,-0.073499,0.080686,0.099998,0.153613,0.008767,-0.951919,-0.003903,1.0,-0.893549,-0.776462,-0.008326,0.787545
mom_reward_avg,0.08375,-0.144342,-0.152579,-0.176647,-0.001689,0.858744,-0.004915,-0.893549,1.0,0.864664,0.000553,-0.873894
mom_reward_min,0.090964,-0.19744,-0.193956,-0.184298,0.01791,0.737975,-0.033414,-0.776462,0.864664,1.0,-0.02954,-0.941438


In [11]:
data[data.columns[1:]].corr(method='spearman')

Unnamed: 0,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
drift_score,1.0,-0.273376,-0.112515,-0.035694,-0.141639,0.089374,0.048655,-0.090309,0.102738,0.111457,0.054876,-0.122876
quality_avg,-0.273376,1.0,0.527301,0.333326,0.03546,-0.087325,-0.006718,0.112539,-0.210416,-0.295772,-0.011162,0.305211
quality_min,-0.112515,0.527301,1.0,0.159504,-0.143087,-0.132457,0.021839,0.143934,-0.220827,-0.283947,0.016713,0.280873
quality_max,-0.035694,0.333326,0.159504,1.0,0.356822,-0.218206,-0.066568,0.227353,-0.261694,-0.273858,-0.061734,0.271981
quality_std,-0.141639,0.03546,-0.143087,0.356822,1.0,-0.023141,-0.182319,0.014738,-0.004675,0.026537,-0.182486,-0.030574
cum_reward_avg,0.089374,-0.087325,-0.132457,-0.218206,-0.023141,1.0,0.026761,-0.995112,0.967544,0.908117,0.033117,-0.908042
cum_reward_max,0.048655,-0.006718,0.021839,-0.066568,-0.182319,0.026761,1.0,-0.00733,-0.002111,-0.04,0.993497,0.051562
cum_reward_std,-0.090309,0.112539,0.143934,0.227353,0.014738,-0.995112,-0.00733,1.0,-0.979258,-0.929143,-0.012881,0.930875
mom_reward_avg,0.102738,-0.210416,-0.220827,-0.261694,-0.004675,0.967544,-0.002111,-0.979258,1.0,0.972041,0.004592,-0.971656
mom_reward_min,0.111457,-0.295772,-0.283947,-0.273858,0.026537,0.908117,-0.04,-0.929143,0.972041,1.0,-0.035342,-0.994058


In [12]:
data.describe()

Unnamed: 0,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
count,1043.0,1043.0,1043.0,1043.0,1043.0,1043.0,1043.0,1043.0,1043.0,1043.0,1043.0,1043.0
mean,0.898011,0.860945,0.711063,1.258105,0.057515,-2911674.0,-31.299137,2548311.0,-14916.982646,-28383.767977,-46.188878,8182.10846
std,0.04969,0.055507,0.035201,0.225152,0.016506,793813.8,228.456211,713671.0,3429.972718,7613.847676,17.025041,2259.314345
min,0.75,0.718305,0.607482,0.867746,0.015185,-3587066.0,-60.0,52641.49,-17866.888519,-35685.0,-60.0,934.315267
25%,0.875,0.821165,0.688939,1.09438,0.045872,-3436793.0,-50.0,2317062.0,-17510.304231,-34815.0,-50.0,6326.187334
50%,0.875,0.854715,0.710457,1.209379,0.055981,-3302644.0,-45.0,2922087.0,-16623.737288,-32055.0,-45.0,9286.716257
75%,0.875,0.898882,0.731056,1.373682,0.067125,-2855031.0,-45.0,3069585.0,-12960.904458,-23415.0,-45.0,10079.644416
max,1.0,1.037022,0.911628,3.08018,0.178629,-58920.0,5680.0,3210989.0,-1449.232143,-3060.0,155.0,10343.056816


In [13]:
data[data.mom_reward_max > 0]

Unnamed: 0,cell_id,drift_score,quality_avg,quality_min,quality_max,quality_std,cum_reward_avg,cum_reward_max,cum_reward_std,mom_reward_avg,mom_reward_min,mom_reward_max,mom_reward_std
7,"(13312,)",1.0,0.827806,0.683533,1.014011,0.051522,-1488446.0,5680,1681490.0,-9521.92953,-24235,155,7988.393527
8,"(13311,)",1.0,0.766681,0.637207,1.072541,0.050251,-523652.9,15,449387.8,-2489.77349,-4475,10,1283.988759
15,"(22975,)",0.875,0.85106,0.708378,0.965337,0.028175,-560112.9,760,516105.7,-2929.211864,-6100,65,1796.577395
43,"(24233,)",1.0,0.779225,0.680422,1.3894,0.045879,-2962463.0,95,2776694.0,-15662.177966,-32200,15,9789.926311
102,"(9737,)",0.875,0.836526,0.688055,1.260287,0.039131,-472372.7,310,458104.5,-2602.398305,-5740,35,1746.029977
192,"(42857,)",0.875,0.901327,0.631919,1.170445,0.036514,-1261592.0,15,1248601.0,-7061.771186,-15505,10,4867.344108
221,"(8916,)",0.875,0.86885,0.690507,0.966925,0.03412,-1782831.0,30,1786634.0,-10402.563667,-24855,15,7679.013739
248,"(42856,)",0.875,0.925625,0.667989,1.060651,0.031276,-2937212.0,5,2752871.0,-15626.655348,-32880,5,9852.050424
262,"(41872,)",0.875,0.802087,0.709027,1.287416,0.05567,-3419002.0,5,3045949.0,-17262.640068,-33845,5,9851.629123
294,"(22971,)",1.0,0.791741,0.703038,1.269978,0.038279,-3302312.0,5,2894209.0,-16208.149406,-29210,5,8736.894528
