# Random Forest Self Distillation

In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.cuda.is_available()

True

## Import section

In [1]:
from selfdest_toolkit.data_tools import preprocessing, loading, cleaning, sd_data_utils

In [5]:
import pandas as pd
import numpy as np
import sklearn
from sklearn.ensemble import RandomForestClassifier
import os
from tqdm import tqdm
import json
from rdkit.Chem.Descriptors import descList
from rdkit.Chem import MolFromSmiles, RDKFingerprint
from sklearn.model_selection import cross_validate

## Data Preprocessing

In [3]:
# locations of the dataset directory and of the main assay-entry CSV inside it
PATH_DATA = "data/"
PATH_MAIN_DATASET = f"{PATH_DATA}df_assay_entries.csv"

In [6]:
# Run the toolkit's whole-dataset preprocessing, passing the main dataset path
# and PATH_DATA. The result is kept as `aids`
# (presumably the experiment/assay ids — TODO confirm against the toolkit).
aids = preprocessing.experiment_whole_preprocess(PATH_MAIN_DATASET, PATH_DATA)

Data file already present, no need for download.


100%|████████████████████████████████████████████████████████████████████████████| 2481/2481 [00:00<00:00, 5353.54it/s]


Chemical descriptor data already generated
Fingerprints already generated


## Individual Data Loading

In [7]:
# chemical-descriptor data and label arrays for experiment 411
c_sampledata, c_samplelabel = loading.load_chem_desc_data(411)

In [8]:
# fingerprint data and label arrays for the same experiment (411)
f_sampledata, f_samplelabel = loading.load_fingerprint_data(411)

In [9]:
c_sampledata

array([[12.0147147 , -0.2333767 , 12.0147147 , ...,  0.        ,
         1.        ,  0.        ],
       [13.68844293, -1.04921672, 13.68844293, ...,  0.        ,
         0.        ,  0.        ],
       [13.75921567, -0.69660748, 13.75921567, ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [12.81165769, -1.13973262, 12.81165769, ...,  0.        ,
         0.        ,  0.        ],
       [13.06424981, -1.12117352, 13.06424981, ...,  0.        ,
         0.        ,  0.        ],
       [12.39603269, -1.12400132, 12.39603269, ...,  0.        ,
         0.        ,  0.        ]])

In [10]:
c_samplelabel

array([0, 0, 0, ..., 0, 0, 0])

In [11]:
f_sampledata

array([[0, 1, 1, ..., 0, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [0, 1, 1, ..., 0, 0, 1],
       ...,
       [1, 1, 0, ..., 0, 0, 1],
       [1, 1, 0, ..., 0, 0, 1],
       [1, 1, 0, ..., 0, 0, 1]])

In [12]:
f_samplelabel

array([0, 0, 0, ..., 0, 0, 0])

In [13]:
print(c_sampledata.shape)
print(f_sampledata.shape)

(68285, 208)
(68285, 2048)


In [14]:
f_samplelabel.sum()

1536

## Preloading all data

In [4]:
# loading.preload_fingerprint_data_all(aids)

In [5]:
# loading.preload_chem_data_all(aids)

## Preparing for self distillation

In [4]:
# define experiment id
aid = 411

In [5]:
# get the prediction data (chemical descriptors and labels) for the selected
# experiment; `aid` is used instead of repeating the literal id so that this
# cell stays in sync with the self-distillation cells that also key off `aid`
data, labels = loading.load_chem_desc_data(aid)

In [6]:
# fraction of the dataset size to generate as self-distillation elements
SD_FRACTION = 0.2

# determine number of elements to fetch for self distillation:
# SD_FRACTION of the row count, rounded to the nearest integer
number_sd = int(data.shape[0] * SD_FRACTION + 0.5)

In [7]:
# Generate `number_sd` self-distillation elements for experiment `aid` using
# the "chem-desc" generation method; PATH_DATA is passed as the data path.
# What the returned `sd_data` contains is not visible here — see sd_data_utils.
sd_data = sd_data_utils.generate_self_distillation_elements(
    aid=aid,
    number_to_generate=number_sd,
    data_gen_method="chem-desc",
    path_data=PATH_DATA
)

## Normal Random Forest

In [1]:
from selfdest_toolkit.randomforest_tools import creation, normal
import json
import typing

In [2]:
# define experiment id
aid = 411

In [3]:
# create random forest
rf = creation.generate_default_rf()

In [4]:
# Execute the normal (non-distilled) random forest test for experiment `aid`
# in "chem-desc" mode. The result maps each metric name to a list of scores,
# one per evaluation round (five in the recorded run).
# NOTE(review): in the recorded output "roc" matches "balanced_accuracy" and
# "recall" matches "accuracy" to within rounding. This looks like ROC AUC being
# computed from hard class predictions and recall being weighted-averaged —
# confirm in randomforest_tools.normal.
# NOTE(review): accuracy is ~0.98 while balanced accuracy is ~0.54; the labels
# shown earlier are heavily imbalanced (1536 positives in 68285 rows), so
# plain accuracy is presumably dominated by the majority class.
accuracy_dict = normal.execute_normal_rf_test(
    rf=rf,
    aid=aid,
    mode="chem-desc"
)

5it [03:29, 41.84s/it]


In [7]:
print(json.dumps(accuracy_dict, indent=4))

{
    "accuracy": [
        0.978692245734788,
        0.9787654682580361,
        0.9786190232115398,
        0.9789119133045324,
        0.9781796880720509
    ],
    "balanced_accuracy": [
        0.5435619563493516,
        0.5404169869096853,
        0.5419332918542376,
        0.5516303724578804,
        0.5336788932095696
    ],
    "roc": [
        0.5435619563493517,
        0.5404169869096854,
        0.5419332918542376,
        0.5516303724578804,
        0.5336788932095696
    ],
    "precision": [
        0.9733954812766568,
        0.9743169963902334,
        0.9731493164173118,
        0.9737573577524699,
        0.9716583967760595
    ],
    "recall": [
        0.978692245734788,
        0.9787654682580361,
        0.9786190232115398,
        0.9789119133045324,
        0.9781796880720509
    ]
}


In [3]:
def convert_acc_dict(
    acc_dict: typing.Dict[str, typing.List[float]]
) -> typing.Dict[str, float]:
    """Collapse per-round metric scores into one arithmetic mean per metric.

    Args:
        acc_dict: mapping from metric name to the list of scores collected
            for that metric.

    Returns:
        A new dict with the same keys (in the same order), each mapped to
        the mean of its score list.

    Raises:
        ZeroDivisionError: if any score list is empty.
    """
    return {
        metric: sum(scores) / len(scores)
        for metric, scores in acc_dict.items()
    }

In [11]:
normal_mean_dict = convert_acc_dict(accuracy_dict)
print(json.dumps(normal_mean_dict, indent=4))

{
    "accuracy": 0.9786336677161895,
    "balanced_accuracy": 0.5422443001561448,
    "roc": 0.5422443001561449,
    "precision": 0.9732555097225463,
    "recall": 0.9786336677161895
}


## Self distillation Random Forest

In [4]:
from selfdest_toolkit.randomforest_tools import self_distillation

In [7]:
# generating random forests
rf_teacher = creation.generate_default_rf()
rf_student = creation.generate_default_rf()

In [8]:
# Run the self-distillation test for experiment `aid` in "chem-desc" mode with
# separate teacher and student forests. Two metric dicts come back: one for the
# normal run and one for the self-distilled run (naming taken from the
# assignment targets — confirm the order in randomforest_tools.self_distillation).
normal_accuracy_dict, sd_accuracy_dict = self_distillation.execute_sd_rf_test(
    rf_teacher=rf_teacher,
    rf_student=rf_student,
    aid=aid,
    mode="chem-desc"
)

5it [08:48, 105.72s/it]


In [9]:
# average the per-round scores of both runs
# NOTE: this rebinds `normal_mean_dict`, replacing the means computed from the
# standalone normal random-forest run earlier in the notebook
normal_mean_dict = convert_acc_dict(normal_accuracy_dict)
sd_mean_dict = convert_acc_dict(sd_accuracy_dict)

In [10]:
print(json.dumps(normal_mean_dict, indent=4))

{
    "accuracy": 0.9785018671743428,
    "balanced_accuracy": 0.5402642575677149,
    "roc": 0.5402642575677149,
    "precision": 0.9727097179111954,
    "recall": 0.9785018671743428
}


In [11]:
print(json.dumps(sd_mean_dict, indent=4))

{
    "accuracy": 0.9786043787068902,
    "balanced_accuracy": 0.5377760446035238,
    "roc": 0.5377760446035238,
    "precision": 0.9737700557750006,
    "recall": 0.9786043787068902
}


In [12]:
def compare_accuracy_dict(normal, sd):
    """Return the per-metric difference ``sd - normal``.

    Only the keys of ``normal`` are visited, in their iteration order; a key
    that is missing from ``sd`` raises ``KeyError``. A positive value means
    the ``sd`` score is higher than the ``normal`` score for that metric.
    """
    return {metric: sd[metric] - normal[metric] for metric in normal}

In [14]:
comparison_dict = compare_accuracy_dict(normal_mean_dict, sd_mean_dict)

In [15]:
print(json.dumps(comparison_dict, indent=4))

{
    "accuracy": 0.00010251153254736689,
    "balanced_accuracy": -0.0024882129641911055,
    "roc": -0.0024882129641911055,
    "precision": 0.0010603378638052163,
    "recall": 0.00010251153254736689
}


## Testing for self distillation success for a number of experiments

In [1]:
from selfdest_toolkit.data_tools import analysis

In [2]:
# get the experiment list to test: ask the analysis helper for 10 experiment ids
# (the selection criteria behind "good" live in data_tools.analysis, not shown here)
exp_to_test = analysis.get_good_experiment_ids(
    number_to_sample=10
)

100%|██████████████████████████████████████████████████████████████████████████████| 2481/2481 [01:52<00:00, 22.02it/s]


In [3]:
exp_to_test

array([  1688, 624297,    902, 485314, 651965,   1461,   1458, 485313,
       652104,   2551], dtype=int64)