# Train a model with PyTorch

> This notebook is a simplified version of our workflow. It exposes the basic details of the traning and evaluation loop more explicitly, but does not offer advanced features like early stopping, mini-batches or validation. Use the `*-lightning` version for those.

## How to use

Run `python run_notebook.py --help` for more information.

In [1]:
# If this is the template file (and not a copy) and you are introducing changes,
# update VERSION with the current date (YYYY.MM.DD)
VERSION = "2021.04.09"

## ✏ Define hyper parameters

In [2]:
# TEMPLATE VALUES -- these are overriden (see below if executed) by papermill using a YAML or Python file as input

# DATA -- Glob paths must be relative to the root of the repository: REPO / features
NPZ_FILES = [
    "path/to/*.npz",
]

# Model -- specified with the full import path to the class object
MODEL_CLS = "kinoml.ml.torch_models.NeuralNetworkRegression"
MODEL_KWARGS = {"hidden_size": 350}  # input_shape is defined dynamically during training
WITH_OBSERVATION_MODEL = True

# Adam
LEARNING_RATE = 0.001
EPSILON = 1e-7
BETAS = 0.9, 0.999

# Trainer
MAX_EPOCHS = 50
N_SPLITS = 5
SHUFFLE_FOLDS = False
VALIDATION = False  # TODO: VALIDATION=True is not implemented yet!
MIN_ITEMS_PER_DATASET = 50  # skip datasets if len(data) < N

# Bootstrapping
N_BOOTSTRAPS = 1
BOOTSTRAP_SAMPLE_RATIO = 1

# Output
VERBOSE = False

## IGNORE THIS ONE
HERE = _dh[-1]

In [3]:
# Parameters
NPZ_FILES = [
    "example-ligand-only-chembl28-morgan512-1k-subsample/_output/ligand__SmilesToLigandFeaturizer__MorganFingerprintFeaturizer_nbits=512_radius=2/ChEMBLDatasetProvider/*.npz"
]
MODEL_CLS = "kinoml.ml.torch_models.NeuralNetworkRegression"
MODEL_KWARGS = {"hidden_shape": 350}
WITH_OBSERVATION_MODEL = True
LEARNING_RATE = 0.001
EPSILON = 1e-07
BETAS = [0.9, 0.999]
MAX_EPOCHS = 50
N_SPLITS = 5
SHUFFLE_FOLDS = False
VALIDATION = False
MIN_ITEMS_PER_DATASET = 10
N_BOOTSTRAPS = 1
BOOTSTRAP_SAMPLE_RATIO = 1
VERBOSE = False
HERE = "/home/jaime/devel/py/openkinome/experiments-binding-affinity/experiments/000_example-ligand-only-chembl28-subset"


⚠ From here on, you should _not_ need to modify anything else 🤞

---

Define key paths for data and outputs:

In [4]:
from pathlib import Path
from datetime import datetime

HERE = Path(HERE)

for parent in HERE.parents:
    if next(parent.glob(".github/"), None):
        REPO = parent
        break

FEATURES_STORE = REPO / "features"
        
OUT = HERE / "_output" / datetime.now().strftime("%Y%m%d-%H%M%S")
OUT.mkdir(parents=True, exist_ok=True)

print(f"This notebook:           HERE = ~/{HERE.relative_to(Path.home())}")
print(f"This repo:               REPO = ~/{REPO.relative_to(Path.home())}")
print(f"Features:      FEATURES_STORE = ~/{FEATURES_STORE.relative_to(Path.home())}")
print(f"Outputs in:               OUT = ~/{OUT.relative_to(Path.home())}")

This notebook:           HERE = ~/devel/py/openkinome/experiments-binding-affinity/experiments/000_example-ligand-only-chembl28-subset
This repo:               REPO = ~/devel/py/openkinome/experiments-binding-affinity
Features:      FEATURES_STORE = ~/devel/py/openkinome/experiments-binding-affinity/features
Outputs in:               OUT = ~/devel/py/openkinome/experiments-binding-affinity/experiments/000_example-ligand-only-chembl28-subset/_output/20210409-163512


In [5]:
# Nasty trick: save all-caps local variables (CONSTANTS working as hyperparameters) so far in a dict to save it later
_hparams = {key: value for key, value in locals().items() if key.upper() == key and not key.startswith(("_", "OE_"))}

In [6]:
# TODO: Make all datasets use the same kinase identifiers
ONE_KINASE = {
    "ChEMBLDatasetProvider": "P35968",
    "PKIS2DatasetProvider": "ABL2",
}

In [7]:
from collections import defaultdict
from warnings import warn
import sys
import shutil

from IPython.display import Markdown
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
import pytorch_lightning as pl

from kinoml.utils import seed_everything, import_object
from kinoml.core import measurements as measurement_types
from kinoml.datasets.torch_datasets import XyNpzTorchDataset
from kinoml.core.measurements import null_observation_model

# Fix the seed for reproducible random splits -- otherwise we get mixed train/test groups every time, biasing the model evaluation
seed_everything();
print("Run started at", datetime.now())



Run started at 2021-04-09 16:35:13.808918


## Load featurized data and create observation models

We assume this path structure: `$REPO/features/_output/<FEATURIZATION>/<DATASET>/<GROUP>.npz`

In [8]:
DATASETS = []
MEASUREMENT_TYPES = set()
KINASES = set()
FEATURIZATIONS = set()
for glob in NPZ_FILES:
    npzs = list(FEATURES_STORE.glob(glob))
    if not npzs:
        warn(f"⚠ NPZ glob `{glob}` did not match any files!")
        continue
        
    for npz in npzs:
        kinase, measurement_type = npz.stem.split("__")
        dataset = npz.parent.name
        featurization = npz.parents[1].name
        
        MEASUREMENT_TYPES.add(measurement_type)
        KINASES.add(kinase)
        FEATURIZATIONS.add(featurization)
        
        ds = XyNpzTorchDataset(npz)
        ds.metadata = {
            "kinase": kinase,
            "measurement_type": measurement_type,
            "dataset": dataset,
            "featurization": featurization
        }
        DATASETS.append(ds)
        if not VALIDATION:
            ds.indices["test"] = np.concatenate([ds.indices["test"], ds.indices["val"]])
            ds.indices["val"] = np.array([])

if not DATASETS:
    raise ValueError("Provided `NPZ_FILES` did not result in any valid datasets!")

In [9]:
print("Observed...")
print(" - Measurement types:", len(MEASUREMENT_TYPES), "-->", *MEASUREMENT_TYPES)
print(" - Kinases:", len(KINASES), "-->", *KINASES)

Observed...
 - Measurement types: 3 --> pKiMeasurement pKdMeasurement pIC50Measurement
 - Kinases: 195 --> Q9H4B4 Q16644 P30530 P41743 Q96GD4 Q12866 Q15303 Q99558 Q08881 P07947 Q15759 O00311 O60674 Q56UN5 O43293 Q5S007 Q8TD19 O15111 Q16832 P53778 Q16513 Q96KB5 Q9HAZ1 Q9HC98 P22455 P49674 P51813 P53779 O43353 P19784 Q9NRP7 P54760 O14965 P45983 O96017 O60331 Q96RR4 P37173 P08581 P04629 P51955 Q9UPN9 P51617 O75385 P43405 P22607 Q12852 Q9UHD2 P78527 P41279 P09619 Q16539 P23443 P31751 Q08345 P21802 Q86YV6 P50750 Q15118 P06241 P06213 P29317 Q04759 P35968 P28482 O43318 Q5VT25 P16234 P48736 Q9NWZ3 P49840 P29376 Q58F21 Q02156 P07949 P49137 O00329 O60885 P51812 Q9Y616 P27361 Q8TBX8 Q04771 Q13315 O75716 Q00535 P53667 P06493 P30291 Q15835 P33981 Q14680 P53350 P52333 Q13627 Q15746 Q8N4C8 P31749 P42338 Q05513 Q14164 P08631 P27448 P17948 Q9UM73 Q05397 P68400 P00533 Q9UEE5 P25098 P29597 P11309 P42681 Q8NI60 O95835 P08069 P50613 P20794 P21675 O75460 P04049 P24941 P07333 Q9P1W9 P49760 Q8WTQ7 O14920 Q134

## Check X duplication

There's a chance we have several measurements per ligand, or, depending on the featurization scheme, even hash collisions... Let's quantify the amount of input tensor duplication we are facing.

In [10]:
for mtype in MEASUREMENT_TYPES:
    display(Markdown(f"#### {mtype}"))
    unique = {}
    for ds in DATASETS:
        if ds.metadata["measurement_type"] == mtype:
            all_ = ds.data_X.shape[0]
            unique_ = np.unique(ds.data_X, axis=0).shape[0]
            unique[ds.metadata["kinase"]] = {"all": all_, "unique": unique_}
    df = pd.DataFrame.from_dict(unique).T
    df["uniqueness"] = df["unique"] / df["all"]
    # This is how you highlight rows in pandas!
    df = df.describe().style.apply(lambda x: ['font-weight: bold' for v in x], subset=pd.IndexSlice[["mean", "std"], :])
    display(df)

#### pKiMeasurement

Unnamed: 0,all,unique,uniqueness
count,37.0,37.0,37.0
mean,2.27027,2.27027,1.0
std,2.063555,2.063555,0.0
min,1.0,1.0,1.0
25%,1.0,1.0,1.0
50%,1.0,1.0,1.0
75%,2.0,2.0,1.0
max,10.0,10.0,1.0


#### pKdMeasurement

Unnamed: 0,all,unique,uniqueness
count,58.0,58.0,58.0
mean,1.155172,1.155172,1.0
std,0.410465,0.410465,0.0
min,1.0,1.0,1.0
25%,1.0,1.0,1.0
50%,1.0,1.0,1.0
75%,1.0,1.0,1.0
max,3.0,3.0,1.0


#### pIC50Measurement

Unnamed: 0,all,unique,uniqueness
count,159.0,159.0,159.0
mean,5.339623,5.339623,1.0
std,7.259067,7.259067,0.0
min,1.0,1.0,1.0
25%,1.0,1.0,1.0
50%,3.0,3.0,1.0
75%,6.0,6.0,1.0
max,51.0,51.0,1.0


Now that we have all the data-dependent objects, we can start with the model-specific definitions.

### Training loop

In [11]:
from kinoml.ml.lightning_modules import KFold3Way, KFold
from IPython.display import Markdown
from tqdm.auto import trange, tqdm
from kinoml.ml.torch_models import NeuralNetworkRegression
from ipywidgets import HBox, VBox, Output, HTML
from kinoml.analysis.plots import predicted_vs_observed, performance
from kinoml.utils import fill_until_next_multiple
import pandas as pd
import torch.nn as nn

if VALIDATION:
    kfold = KFold3Way(n_splits=N_SPLITS, shuffle=SHUFFLE_FOLDS)
    ttypes = ["train", "val", "test"]
else:
    kfold = KFold(n_splits=N_SPLITS, shuffle=SHUFFLE_FOLDS)
    ttypes = ["train", "test"]

ModelCls = import_object(MODEL_CLS)
    
kinase_metrics = defaultdict(dict)
for dataset in tqdm(DATASETS):
    kinase = dataset.metadata["kinase"]
    mtype = dataset.metadata["measurement_type"]
    if dataset.data_X.shape[0] < MIN_ITEMS_PER_DATASET:
        warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
        continue
            
    if VERBOSE:
        display(Markdown(f"#### {mtype}"))

    mtype_class = getattr(measurement_types, mtype)
    obs_model = mtype_class.observation_model(backend="pytorch")
    metrics = defaultdict(list)

    for fold_index, splits in enumerate(kfold.split(dataset.data_X, dataset.data_y)):
        if VALIDATION:
            train_indices, val_indices, test_indices = splits
        else:
            train_indices, test_indices = splits

        if VERBOSE:
            display(Markdown(f"##### Fold {fold_index}"))

        ####
        # TRAIN
        ####
        x_train = dataset.data_X[train_indices].float()
        x_test = dataset.data_X[test_indices].float()
        y_train = dataset.data_y[train_indices]
        y_test = dataset.data_y[test_indices]

        if VALIDATION:
            x_val = dataset.data_X[val_indices].float()
            y_val = dataset.data_y[val_indices]
        
        input_shape = ModelCls.estimate_input_shape(x_train)
        nn_model = ModelCls(input_shape=input_shape, **MODEL_KWARGS)
        nn_model.train(True)

        optimizer = torch.optim.Adam(nn_model.parameters(), lr=LEARNING_RATE, eps=EPSILON, betas=BETAS)
        loss_function = torch.nn.MSELoss()

        if VERBOSE:
            range_epochs = trange(MAX_EPOCHS, desc="Epochs (+ featurization...)")
        else:
            range_epochs = range(MAX_EPOCHS)
        for epoch in range_epochs:
            optimizer.zero_grad()

            prediction = nn_model(x_train)
            if WITH_OBSERVATION_MODEL:
                prediction = obs_model(prediction)

            prediction = prediction.view_as(y_train)

            loss = loss_function(prediction, y_train)
            if VERBOSE:
                range_epochs.set_description(f"Epochs (loss={loss.item():.2e})")

            if VALIDATION:
                warn("Validation step not implemented yet")


            # Gradients w.r.t. parameters
            loss.backward()

            # Optimizer
            optimizer.step()
        
        ###
        # Save model's state -- you will still need to instantiate the model class!
        # Possibly using something like:
        # model = import_object(MODEL_CLS)(**MODEL_KWARGS)
        # model.load_state_dict(torch.load("state_dict.pt"))
        ###
        torch.save(nn_model.state_dict(), OUT / f"state_dict_{kinase}_{mtype}_fold{fold_index}.pt")
        
        ####
        # EVAL
        ####
        nn_model.eval()
        outputs = []
        for ttype in ttypes:
            output = Output()
            with output:
                title = f"fold={fold_index}, {ttype}={locals()[f'{ttype}_indices'].shape[0]}"
                print(title)
                print("-"*(len(title)))

                observed = locals()[f"y_{ttype}"]

                with torch.no_grad():
                    predicted = nn_model(locals()[f"x_{ttype}"])
                    if WITH_OBSERVATION_MODEL:
                        predicted = obs_model(predicted)

                predicted = predicted.view_as(observed).detach().numpy()
                observed = observed.detach().numpy()
                these_metrics = performance(predicted, observed, n_boot=N_BOOTSTRAPS, sample_ratio=BOOTSTRAP_SAMPLE_RATIO)
                metrics[ttype].append(these_metrics)
                if VERBOSE:
                    display(predicted_vs_observed(predicted, observed, mtype_class, with_metrics=False))

            outputs.append(output)
        if VERBOSE:
            display(HBox(outputs))

    # Average performances

    average = defaultdict(dict)
    for key in metrics["test"][0]:
        for label in ttypes:
            # this zero here ---v is super important! we only want the mean of the means!
            values =  [fold[key][0] for fold in metrics[label]]
            average[label][key] = {
                "mean": np.mean(values),
                "std": np.std(values)
            }
    if VERBOSE:
        for label in ttypes:    
            display(HTML(f"Bootstrapped average across folds ({label}):"))
            display(pd.DataFrame.from_dict(average[label]))
    kinase_metrics[kinase][mtype] = average

  0%|          | 0/254 [00:00<?, ?it/s]

  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASE

  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASE

  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASE

  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASE

  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASE

  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASE

  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASE

  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASE

  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASE

  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")
  warn(f"Ignoring {kinase} because it has less than {MIN_ITEMS_PER_DATASET} entries for type {mtype}")


### Summary

`kinase_metrics` is a nested dictionary with these dimensions:

- kinase name
- measurement type
- metric
- mean & standard deviation

In [12]:
import json

display(Markdown(f"""
### Configuration 

```json
{json.dumps(_hparams, default=str, indent=2)}
```
"""))

if VERBOSE:
    display(Markdown(f"""

    ### Kinase metrics

    ```json
    {json.dumps(kinase_metrics, default=str, indent=2)}
    ```
    """))


### Configuration 

```json
{
  "VERSION": "2021.04.09",
  "NPZ_FILES": [
    "example-ligand-only-chembl28-morgan512-1k-subsample/_output/ligand__SmilesToLigandFeaturizer__MorganFingerprintFeaturizer_nbits=512_radius=2/ChEMBLDatasetProvider/*.npz"
  ],
  "MODEL_CLS": "kinoml.ml.torch_models.NeuralNetworkRegression",
  "MODEL_KWARGS": {
    "hidden_shape": 350
  },
  "WITH_OBSERVATION_MODEL": true,
  "LEARNING_RATE": 0.001,
  "EPSILON": 1e-07,
  "BETAS": [
    0.9,
    0.999
  ],
  "MAX_EPOCHS": 50,
  "N_SPLITS": 5,
  "SHUFFLE_FOLDS": false,
  "VALIDATION": false,
  "MIN_ITEMS_PER_DATASET": 10,
  "N_BOOTSTRAPS": 1,
  "BOOTSTRAP_SAMPLE_RATIO": 1,
  "VERBOSE": false,
  "HERE": "/home/jaime/devel/py/openkinome/experiments-binding-affinity/experiments/000_example-ligand-only-chembl28-subset",
  "REPO": "/home/jaime/devel/py/openkinome/experiments-binding-affinity",
  "FEATURES_STORE": "/home/jaime/devel/py/openkinome/experiments-binding-affinity/features",
  "OUT": "/home/jaime/devel/py/openkinome/experiments-binding-affinity/experiments/000_example-ligand-only-chembl28-subset/_output/20210409-163512"
}
```


In [13]:
for mtype in MEASUREMENT_TYPES:
    display(Markdown(f"#### {mtype}"))

    dict_of_flattened_metrics = {}
    for kinase_name, measurement_type_dict in sorted(kinase_metrics.items(), key=lambda kv: kv[0].lower()):
        flattened_metrics = {}
        for train_test_key, train_test_dict in measurement_type_dict.get(mtype, {}).items():
            for metric_key, mean_std_dict in train_test_dict.items():
                for mean_std_key, value in mean_std_dict.items():
                    flattened_metrics[f"{train_test_key}_{metric_key}_{mean_std_key}"] = (value,)
        if flattened_metrics:
            dict_of_flattened_metrics[kinase_name] = pd.DataFrame.from_dict(flattened_metrics)
    
    if not dict_of_flattened_metrics:
        continue
    
    df = pd.concat(dict_of_flattened_metrics)
    df.index = [index[0] for index in df.index]
    with pd.option_context("display.float_format", "{:.3f}".format, "display.max_rows", len(df)):
        display(df.style.background_gradient(subset=["train_r2_mean", "test_r2_mean"], low=0, high=1, vmin=0, vmax=1))
        display(df.describe()[["train_r2_mean", "train_r2_std", "test_r2_mean", "test_r2_std"]].describe().style.apply(lambda x: ['font-weight: bold' for v in x], subset=pd.IndexSlice[["mean", "std"], :]))

#### pKiMeasurement

Unnamed: 0,train_mae_mean,train_mae_std,train_mse_mean,train_mse_std,train_r2_mean,train_r2_std,train_rmse_mean,train_rmse_std,test_mae_mean,test_mae_std,test_mse_mean,test_mse_std,test_r2_mean,test_r2_std,test_rmse_mean,test_rmse_std
P11309,1.085747,0.101512,1.398925,0.200672,0.431285,0.197629,1.179673,0.085417,0.828195,0.61018,1.058226,1.182908,0.0,0.0,0.828195,0.61018


Unnamed: 0,train_r2_mean,train_r2_std,test_r2_mean,test_r2_std
count,7.0,7.0,7.0,7.0
mean,0.51253,0.312253,0.142857,0.142857
std,0.214954,0.303268,0.377964,0.377964
min,0.431285,0.197629,0.0,0.0
25%,0.431285,0.197629,0.0,0.0
50%,0.431285,0.197629,0.0,0.0
75%,0.431285,0.197629,0.0,0.0
max,1.0,1.0,1.0,1.0


#### pKdMeasurement

#### pIC50Measurement

Unnamed: 0,train_mae_mean,train_mae_std,train_mse_mean,train_mse_std,train_r2_mean,train_r2_std,train_rmse_mean,train_rmse_std,test_mae_mean,test_mae_std,test_mse_mean,test_mse_std,test_r2_mean,test_r2_std,test_rmse_mean,test_rmse_std
O00329,0.525646,0.043479,0.39319,0.071145,0.699709,0.082722,0.624597,0.055393,1.45805,0.531108,2.677714,1.710035,-4.760219,5.477612,1.539789,0.553862
O60674,0.446959,0.101456,0.301975,0.139246,0.328279,0.251472,0.531679,0.138897,1.08738,0.477022,1.822168,1.324814,-2.298713,1.544693,1.238906,0.535986
O60885,0.435203,0.071846,0.279574,0.075745,0.724047,0.069094,0.523652,0.07323,1.754681,0.544482,3.943568,2.2168,-5.924624,6.804049,1.902226,0.570178
P00533,0.762394,0.160354,1.095256,0.447002,0.164854,0.30365,1.026564,0.203522,1.239931,0.271924,2.132873,0.841865,-2.021537,2.474959,1.427306,0.309308
P04629,0.448477,0.091394,0.305146,0.095114,0.715593,0.090705,0.546416,0.08109,1.540383,1.129902,4.28852,4.909982,-1.041084,1.683812,1.676578,1.215569
P08069,0.286614,0.023407,0.132554,0.021286,0.919971,0.022884,0.362915,0.029092,0.929444,0.254186,1.40589,0.696593,-0.413803,1.272483,1.139183,0.328864
P08581,0.464253,0.099122,0.288348,0.118349,0.367026,0.2022,0.524824,0.113612,1.392757,0.283145,2.770075,1.104212,-13.472117,16.952626,1.622786,0.369651
P11309,0.535987,0.14711,0.357249,0.184009,0.77935,0.090054,0.578114,0.151766,1.341001,1.568042,4.257041,7.645704,0.0,0.0,1.341001,1.568042
P11362,0.484812,0.167248,0.34929,0.181865,0.829609,0.098085,0.561656,0.183937,1.129823,0.519262,1.778863,1.702729,-0.511379,1.928131,1.221592,0.535327
P12931,0.383771,0.064298,0.228303,0.057068,0.87725,0.012091,0.474251,0.058211,1.249128,0.723984,2.497248,2.118746,-48.633258,95.345515,1.395922,0.74071


Unnamed: 0,train_r2_mean,train_r2_std,test_r2_mean,test_r2_std
count,8.0,8.0,8.0,8.0
mean,3.440361,3.13562,-13.606561,50.605678
std,8.312912,8.431346,49.907399,83.061527
min,-0.044062,0.012091,-128.639628,0.0
25%,0.344357,0.102446,-17.098206,1.716792
50%,0.621378,0.129586,-1.488227,26.258118
75%,0.798736,0.261692,6.0,42.246385
max,24.0,24.0,34.075038,248.583672


In [14]:
print("Run finished at", datetime.now())

Run finished at 2021-04-09 16:35:35.991999


### Save reports to disk

In [15]:
from kinoml.utils import watermark
w = watermark()

Watermark
---------
Last updated: 2021-04-09T16:35:36.040873+02:00

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.22.0

Compiler    : GCC 9.3.0
OS          : Linux
Release     : 4.19.128-microsoft-standard
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit

Hostname: jrodriguez

Git hash: d5cd0b3e0d894abb37c63da95d9a510eb6b1997d

sys              : 3.8.8 | packaged by conda-forge | (default, Feb 20 2021, 16:22:27) 
[GCC 9.3.0]
pandas           : 1.2.3
pytorch_lightning: 1.2.7
json             : 2.0.9
torch            : 1.7.1.post2
kinoml           : 0+untagged.409.gc193ba7.dirty
numpy            : 1.20.2

Watermark: 2.2.0


nvidia-smi
----------
stdout:
NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.

conda info
----------


sys.version: 3.7.6 | packaged by conda-forge | (defau...
sys.prefix: /opt/miniconda
sys.executable: /opt/miniconda/bin/python
conda location: /opt/miniconda/lib/python3.7/site-packages/conda
conda-build: /opt/miniconda/bin/conda-build
conda-convert: /opt/miniconda/bin/conda-convert
conda-debug: /opt/miniconda/bin/conda-debug
conda-develop: /opt/miniconda/bin/conda-develop
conda-env: /opt/miniconda/bin/conda-env
conda-index: /opt/miniconda/bin/conda-index
conda-inspect: /opt/miniconda/bin/conda-inspect
conda-metapackage: /opt/miniconda/bin/conda-metapackage
conda-render: /opt/miniconda/bin/conda-render
conda-server: /opt/miniconda/bin/conda-server
conda-skeleton: /opt/miniconda/bin/conda-skeleton
conda-smithy: /opt/miniconda/bin/conda-smithy
user site dirs: ~/.local/lib/python3.8
                ~/.local/lib/python3.7
                ~/.local/lib/python3.6

CIO_TEST: <not set>
CONDA_DEFAULT_ENV: experiments-binding-affinity
CONDA_EXE: /opt/miniconda/bin/conda
CONDA_PREFIX: /home/jaime/.

# packages in environment at /home/jaime/.conda/envs/experiments-binding-affinity:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                      1_llvm    conda-forge
_py-xgboost-mutex         2.0                       cpu_0    conda-forge
absl-py                   0.12.0             pyhd8ed1ab_0    conda-forge
aiohttp                   3.7.4            py38h497a2fe_0    conda-forge
alabaster                 0.7.12                   pypi_0    pypi
amberlite                 16.0                     pypi_0    pypi
ambertools                20.15                    pypi_0    pypi
ansiwrap                  0.8.4                      py_0    conda-forge
anyio                     2.2.0            py38h578d9bd_0    conda-forge
appdirs                   1.4.4              pyh9f0ad1d_0    conda-forge
argon2-cffi               20.1.0           py38h497a2fe_2    conda

In [16]:
%%capture cap --no-stderr
w = watermark()

In [17]:
import json

df.to_csv(OUT / "performance.csv")

with open(OUT / "performance.json", "w") as f:
    json.dump(kinase_metrics, f, default=str, indent=2)
    
with open(OUT/ "watermark.txt", "w") as f:
    f.write(cap.stdout)

with open(OUT / "hparams.json", "w") as f:
    json.dump(_hparams, f, default=str, indent=2)