# Random Forest Self Distillation

In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.cuda.is_available()

True

## Import section

In [3]:
from selfdist_toolkit.data_tools import preprocessing, loading, cleaning, sd_data_utils, analysis
from selfdist_toolkit.randomforest_tools import self_distillation, rf_analysis, normal, creation

import os
from tqdm import tqdm
import json
import pandas as pd
import numpy as np

import sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

from rdkit.Chem.Descriptors import descList
from rdkit.Chem import MolFromSmiles, RDKFingerprint

## Data Preprocessing

In [4]:
# path to dataset
PATH_DATA = "data/"
PATH_MAIN_DATASET = PATH_DATA + "df_assay_entries.csv"

In [5]:
aids = preprocessing.experiment_whole_preprocess(PATH_MAIN_DATASET, PATH_DATA)

Data file already present, no need for download.


100%|████████████████████████████████████████████████████████████████████████████| 2481/2481 [00:00<00:00, 7492.64it/s]


Chemical descriptor data already generated
Fingerprints already generated


## Individual Data Loading

In [6]:
c_sampledata, c_samplelabel = loading.load_chem_desc_data(411)

In [7]:
f_sampledata, f_samplelabel = loading.load_fingerprint_data(411)

In [8]:
c_sampledata

array([[12.0147147 , -0.2333767 , 12.0147147 , ...,  0.        ,
         1.        ,  0.        ],
       [13.68844293, -1.04921672, 13.68844293, ...,  0.        ,
         0.        ,  0.        ],
       [13.75921567, -0.69660748, 13.75921567, ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [12.81165769, -1.13973262, 12.81165769, ...,  0.        ,
         0.        ,  0.        ],
       [13.06424981, -1.12117352, 13.06424981, ...,  0.        ,
         0.        ,  0.        ],
       [12.39603269, -1.12400132, 12.39603269, ...,  0.        ,
         0.        ,  0.        ]])

In [9]:
c_samplelabel

array([0, 0, 0, ..., 0, 0, 0])

In [10]:
f_sampledata

array([[0, 1, 1, ..., 0, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [0, 1, 1, ..., 0, 0, 1],
       ...,
       [1, 1, 0, ..., 0, 0, 1],
       [1, 1, 0, ..., 0, 0, 1],
       [1, 1, 0, ..., 0, 0, 1]])

In [11]:
f_samplelabel

array([0, 0, 0, ..., 0, 0, 0])

In [12]:
print(c_sampledata.shape)
print(f_sampledata.shape)

(68285, 208)
(68285, 2048)


In [13]:
f_samplelabel.sum()

1536

## Preloading all data

In [14]:
# loading.preload_fingerprint_data_all(aids)

In [15]:
# loading.preload_chem_data_all(aids)

## Preparing for self distillation

In [16]:
# define experiment id
aid = 411

In [17]:
# get the prediction data
data, labels = loading.load_chem_desc_data(aid)

In [18]:
# determine number of elements to fetch for self destillation
number_sd = int(data.shape[0]*0.2+0.5) # set to percentage perhaps

In [19]:
# get self distillation elements
sd_data = sd_data_utils.generate_self_distillation_elements(
    aid=aid,
    number_to_generate=number_sd,
    data_gen_method="chem-desc",
    path_data=PATH_DATA
)

## Normal Random Forest

In [20]:
# define experiment id
# aid = 411
aid = 1688

In [21]:
# create random forest
rf = creation.generate_default_rf()

In [22]:
# execute normal random forest testing
accuracy_dict = normal.execute_normal_rf_test(
    rf=rf,
    aid=aid,
    mode="chem-desc"
)

  temp **= 2
  new_unnormalized_variance -= correction**2 / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  upper_bound = n_samples * eps * var + (n_samples * mean * eps) ** 2
5it [02:27, 29.54s/it]


In [23]:
print(json.dumps(accuracy_dict, indent=4))

{
    "accuracy": [
        0.9892221061482406,
        0.9890182689122735,
        0.9890182689122735,
        0.9887634723673147,
        0.9890947078757613
    ],
    "balanced_accuracy": [
        0.5172295768639444,
        0.5080129870021933,
        0.5068737955511842,
        0.5033323323394471,
        0.5103190623589576
    ],
    "roc": [
        0.5172295768639444,
        0.5080129870021933,
        0.5068737955511842,
        0.5033323323394471,
        0.5103190623589576
    ],
    "precision": [
        0.9871089320097642,
        0.9852163291230984,
        0.9855267700857692,
        0.9807987228117808,
        0.9872477544342729
    ],
    "recall": [
        0.9892221061482406,
        0.9890182689122735,
        0.9890182689122735,
        0.9887634723673147,
        0.9890947078757613
    ]
}


In [24]:
normal_mean_dict = rf_analysis.convert_acc_dict(accuracy_dict)
print(json.dumps(normal_mean_dict, indent=4))

{
    "accuracy": 0.9890233648431728,
    "balanced_accuracy": 0.5091535508231453,
    "roc": 0.5091535508231453,
    "precision": 0.9851797016929371,
    "recall": 0.9890233648431728
}


## Self distillation Random Forest

In [25]:
# generating random forests
rf_teacher = creation.generate_default_rf()
rf_student = creation.generate_default_rf()

In [26]:
normal_accuracy_dict, sd_accuracy_dict = self_distillation.execute_sd_rf_test(
    rf_teacher=rf_teacher,
    rf_student=rf_student,
    aid=aid,
    mode="chem-desc"
)

  temp **= 2
  new_unnormalized_variance -= correction**2 / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  upper_bound = n_samples * eps * var + (n_samples * mean * eps) ** 2
5it [05:37, 67.59s/it]


In [27]:
normal_mean_dict = rf_analysis.convert_acc_dict(normal_accuracy_dict)
sd_mean_dict = rf_analysis.convert_acc_dict(sd_accuracy_dict)

In [28]:
print(json.dumps(normal_mean_dict, indent=4))

{
    "accuracy": 0.989002981119576,
    "balanced_accuracy": 0.5086880981064464,
    "roc": 0.5086880981064463,
    "precision": 0.9845927337302527,
    "recall": 0.989002981119576
}


In [29]:
print(json.dumps(sd_mean_dict, indent=4))

{
    "accuracy": 0.9890131729813744,
    "balanced_accuracy": 0.5082381041277991,
    "roc": 0.5082381041277991,
    "precision": 0.9852126380301224,
    "recall": 0.9890131729813744
}


In [30]:
comparison_dict = rf_analysis.compare_accuracy_dict(normal_mean_dict, sd_mean_dict)

In [31]:
print(json.dumps(comparison_dict, indent=4))

{
    "accuracy": 1.0191861798403146e-05,
    "balanced_accuracy": -0.0004499939786473117,
    "roc": -0.0004499939786472007,
    "precision": 0.0006199042998696891,
    "recall": 1.0191861798403146e-05
}


## Testing for self distillation success for a number of experiments

### Chemical descriptor data mode

In [32]:
# get the experiment list to test
exp_to_test = analysis.get_good_experiment_ids(
    number_to_sample=10
)

100%|██████████████████████████████████████████████████████████████████████████████| 2481/2481 [01:17<00:00, 31.96it/s]


In [33]:
exp_to_test

array([   902,   1458,   1461,   1688,   2551, 485313, 485314, 624297,
       651965, 652104], dtype=int64)

In [34]:
# take balanced accuracy and roc into account as it seems to be the most expressive out of all, 
# because it actually is not over 90% from the beginning and from the documentation it seemed 
# to be the best for imbalanced datasets. Although it is astonishing that the average parameter
# to weighted does not work. Maybe a weight needs to be supplied besides the parameter but I
# rather go with the other options

In [35]:
# better worse counter
roc_better = []
roc_worse = []
ba_better = []
ba_worse = []

In [36]:
mode = "chem-desc"
# mode = "fingerprint"

In [37]:
# iterate over experiments to test
for aid in tqdm(exp_to_test):
    
    # generating random forests
    rf_teacher = creation.generate_default_rf()
    rf_student = creation.generate_default_rf()
    
    # execute self destillation test
    normal_accuracy_dict, sd_accuracy_dict = self_distillation.execute_sd_rf_test(
        rf_teacher=rf_teacher,
        rf_student=rf_student,
        aid=aid,
        mode=mode,
        verbose=False
    )
    
    # calculate mean accuracy values
    normal_mean_dict = rf_analysis.convert_acc_dict(normal_accuracy_dict)
    sd_mean_dict = rf_analysis.convert_acc_dict(sd_accuracy_dict)
    
    # get the difference in accuracy scores
    comparison_dict = rf_analysis.compare_accuracy_dict(normal_mean_dict, sd_mean_dict)
    
    # analyze it
    if comparison_dict["roc"] > 0:
        roc_better.append(aid)
    else:
        roc_worse.append(aid)
    if comparison_dict["balanced_accuracy"] > 0:
        ba_better.append(aid)
    else:
        ba_worse.append(aid)

  temp **= 2
  new_unnormalized_variance -= correction**2 / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  upper_bound = n_samples * eps * var + (n_samples * mean * eps) ** 2
  temp **= 2
  new_unnormalized_variance -= correction**2 / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  upper_bound = n_samples * eps * var + (n_samples * mean * eps) ** 2
  temp **= 2
  new_unnormalized_variance -= correction**2 / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  upper_bound = n_samples * eps * var + (n_samples * mean * eps) ** 2
  temp **= 2
  new_unnormalized_variance -= correction**2 / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  upper_bound = n_samples * eps * var + (n_samples * mean * eps) ** 2
  temp **= 2
  new_unnormalized_variance -= correction**2 / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  upper_bound =

In [38]:
print("roc wise {}({}%) experiments worked better with self destillation".format(len(roc_better), len(roc_better)/(len(roc_better)+len(roc_worse))))
print("balanced accuracy wise {}({}%) experiments worked better with self destillation".format(len(ba_better), len(ba_better)/(len(ba_better)+len(ba_worse))))

roc wise 1(0.1%) experiments worked better with self destillation
balanced accuracy wise 1(0.1%) experiments worked better with self destillation


### Fingerprint mode

In [39]:
# get the experiment list to test
exp_to_test = analysis.get_good_experiment_ids(
    number_to_sample=10
)

100%|██████████████████████████████████████████████████████████████████████████████| 2481/2481 [01:12<00:00, 34.42it/s]


In [40]:
exp_to_test

array([   902,   1458,   1461,   1688,   2551, 485313, 485314, 624297,
       651965, 652104], dtype=int64)

In [41]:
# take balanced accuracy and roc into account as it seems to be the most expressive out of all, 
# because it actually is not over 90% from the beginning and from the documentation it seemed 
# to be the best for imbalanced datasets. Although it is astonishing that the average parameter
# to weighted does not work. Maybe a weight needs to be supplied besides the parameter but I
# rather go with the other options

In [42]:
# better worse counter
roc_better = []
roc_worse = []
ba_better = []
ba_worse = []

In [43]:
# mode = "chem-desc"
mode = "fingerprint"

In [44]:
# iterate over experiments to test
for aid in tqdm(exp_to_test):
    
    # generating random forests
    rf_teacher = creation.generate_default_rf()
    rf_student = creation.generate_default_rf()
    
    # execute self destillation test
    normal_accuracy_dict, sd_accuracy_dict = self_distillation.execute_sd_rf_test(
        rf_teacher=rf_teacher,
        rf_student=rf_student,
        aid=aid,
        mode=mode,
        verbose=False
    )
    
    # calculate mean accuracy values
    normal_mean_dict = rf_analysis.convert_acc_dict(normal_accuracy_dict)
    sd_mean_dict = rf_analysis.convert_acc_dict(sd_accuracy_dict)
    
    # get the difference in accuracy scores
    comparison_dict = rf_analysis.compare_accuracy_dict(normal_mean_dict, sd_mean_dict)
    
    # analyze it
    if comparison_dict["roc"] > 0:
        roc_better.append(aid)
    else:
        roc_worse.append(aid)
    if comparison_dict["balanced_accuracy"] > 0:
        ba_better.append(aid)
    else:
        ba_worse.append(aid)

100%|███████████████████████████████████████████████████████████████████████████████| 10/10 [2:44:23<00:00, 986.31s/it]


In [45]:
print("roc wise {}({}%) experiments worked better with self destillation".format(len(roc_better), len(roc_better)/(len(roc_better)+len(roc_worse))))
print("balanced accuracy wise {}({}%) experiments worked better with self destillation".format(len(ba_better), len(ba_better)/(len(ba_better)+len(ba_worse))))

roc wise 0(0.0%) experiments worked better with self destillation
balanced accuracy wise 0(0.0%) experiments worked better with self destillation
