Transfer learning with PDX

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt
import PIL

import os
import sys
from pathlib import Path
fdir = Path.cwd()
print(fdir)
sys.path.append(str(fdir/'..'))

/vol/ml/apartin/projects/pdx-histo/nbs


In [2]:
# https://stackoverflow.com/questions/37893755/tensorflow-set-cuda-visible-devices-within-jupyter
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [3]:
# %env CUDA_DEVICE_ORDER=PCI_BUS_ID
# %env CUDA_VISIBLE_DEVICES=1

In [4]:
# https://www.codegrepper.com/code-examples/python/suppres+tensorflow+warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from collections import OrderedDict
import glob
from pathlib import Path
from pprint import pprint, pformat
import shutil
from time import time

# import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

import warnings
warnings.filterwarnings("ignore")

import tensorflow as tf
assert tf.__version__ >= "2.0"

from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Dense, Dropout, Activation, BatchNormalization
from tensorflow.keras import losses
from tensorflow.keras import optimizers
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping

import src
from src.config import cfg
from src.models import build_model_rsp, build_model_rsp_baseline, keras_callbacks, load_best_model
from src.ml.scale import get_scaler
from src.ml.evals import calc_scores, save_confusion_matrix
from src.ml.keras_utils import plot_prfrm_metrics
from src.utils.classlogger import Logger
from src.utils.utils import (cast_list, create_outdir, create_outdir_2, dump_dict, fea_types_to_str_name,
                             get_print_func, read_lines, Params, Timer)
from src.datasets.tidy import split_data_and_extract_fea, extract_fea, TidyData
from src.tf_utils import get_tfr_files
from src.sf_utils import (create_manifest, create_tf_data, calc_class_weights,
                          parse_tfrec_fn_rsp, parse_tfrec_fn_rna)
from src.sf_utils import bold, green, blue, yellow, cyan, red

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  1


In [5]:
dataname="tidy_drug_pairs_all_samples"
prjname="bin_rsp_drug_pairs_all_samples"
id_name="smp"
target="Response"
split_on="Group"
n_samples=-1
tfr_dir_name="PDX_FIXED_RSP_DRUG_PAIR_0.1_of_tiles"
trn_phase="train"
use_tile=True
use_ge=True
use_dd1=True
use_dd2=True
scale_fea=False

print_fn = print

In [6]:
# Create project dir (if it doesn't exist)
prjdir = cfg.MAIN_PRJDIR/prjname

# fea_names = "tile_dd1_dd2"
fea_names = "tile_ge_dd1_dd2"
prm_file_path = fdir/f"../default_params/default_params_{fea_names}.json"
params = Params(prm_file_path)

# Load dataframe (annotations)
annotations_file = cfg.DATA_PROCESSED_DIR/dataname/cfg.SF_ANNOTATIONS_FILENAME
data = pd.read_csv(annotations_file)
data = data.astype({"image_id": str, "slide": str})
print(data.shape)

# Determine tfr_dir (where TFRecords are stored)
# if args.target[0] == "Response":
#     if params.single_drug:
#         tfr_dir = cfg.SF_TFR_DIR_RSP
#     else:
#         tfr_dir = (cfg.DATADIR/args.tfr_dir_name).resolve()
# elif args.target[0] == "ctype":
#     tfr_dir = cfg.SF_TFR_DIR_RNA_NEW

tfr_dir = (cfg.DATADIR/tfr_dir_name).resolve()
label = f"{params.tile_px}px_{params.tile_um}um"
tfr_dir = tfr_dir/label

# Scalers for each feature set
ge_scaler, dd1_scaler, dd2_scaler = None, None, None

ge_cols  = [c for c in data.columns if c.startswith("ge_")]
dd1_cols = [c for c in data.columns if c.startswith("dd1_")]
dd2_cols = [c for c in data.columns if c.startswith("dd2_")]

if scale_fea:
    if use_ge and len(ge_cols) > 0:
        ge_scaler = get_scaler(data[ge_cols])
    if use_dd1 and len(dd1_cols) > 0:
        dd1_scaler = get_scaler(data[dd1_cols])
    if use_dd2 and len(dd2_cols) > 0:
        dd2_scaler = get_scaler(data[dd2_cols])

(6962, 4950)


In [7]:
# -----------------------------------------------
# Data splits
# -----------------------------------------------
# if params.drug_specific is None:
if use_dd1 is False and use_dd2 is False:
    splitdir = cfg.DATADIR/"PDX_Transfer_Learning_Classification/Processed_Data/Data_For_MultiModal_Learning/Data_Partition_Drug_Specific"
    splitdir = splitdir/params.drug_specific
    split_id = 0
else:
    splitdir = cfg.DATADIR/"PDX_Transfer_Learning_Classification/Processed_Data/Data_For_MultiModal_Learning/Data_Partition"
    split_id = 81
    # split_id = 0

tr_id = cast_list(read_lines(str(splitdir/f"cv_{split_id}"/"TrainList.txt")), int)
vl_id = cast_list(read_lines(str(splitdir/f"cv_{split_id}"/"ValList.txt")), int)
te_id = cast_list(read_lines(str(splitdir/f"cv_{split_id}"/"TestList.txt")), int)

# Update ids
index_col_name = "index"
tr_id = sorted(set(data[index_col_name]).intersection(set(tr_id)))
vl_id = sorted(set(data[index_col_name]).intersection(set(vl_id)))
te_id = sorted(set(data[index_col_name]).intersection(set(te_id)))

# Subsample train samples
if n_samples > 0:
    if n_samples < len(tr_id):
        tr_id = tr_id[:n_samples]
    if n_samples < len(vl_id):
        vl_id = vl_id[:n_samples]
    if n_samples < len(te_id):
        te_id = te_id[:n_samples]

In [8]:
# --------------
# w/o TidyData
# --------------
kwargs = {"ge_cols": ge_cols,
          "dd1_cols": dd1_cols,
          "dd2_cols": dd2_cols,
          "ge_scaler": ge_scaler,
          "dd1_scaler": dd1_scaler,
          "dd2_scaler": dd2_scaler,
          "ge_dtype": cfg.GE_DTYPE,
          "dd_dtype": cfg.DD_DTYPE,
          "index_col_name": index_col_name,
          "split_on": split_on
          }
tr_ge, tr_dd1, tr_dd2, tr_meta = split_data_and_extract_fea(data, ids=tr_id, **kwargs)
vl_ge, vl_dd1, vl_dd2, vl_meta = split_data_and_extract_fea(data, ids=vl_id, **kwargs)
te_ge, te_dd1, te_dd2, te_meta = split_data_and_extract_fea(data, ids=te_id, **kwargs)

ge_shape = (tr_ge.shape[1],)
dd_shape = (tr_dd1.shape[1],)

# Make sure indices do not overlap
assert len( set(tr_id).intersection(set(vl_id)) ) == 0, "Overlapping indices btw tr and vl"
assert len( set(tr_id).intersection(set(te_id)) ) == 0, "Overlapping indices btw tr and te"
assert len( set(vl_id).intersection(set(te_id)) ) == 0, "Overlapping indices btw vl and te"

# Print split ratios
print_fn("")
print_fn("Train samples {} ({:.2f}%)".format( len(tr_id), 100*len(tr_id)/data.shape[0] ))
print_fn("Val   samples {} ({:.2f}%)".format( len(vl_id), 100*len(vl_id)/data.shape[0] ))
print_fn("Test  samples {} ({:.2f}%)".format( len(te_id), 100*len(te_id)/data.shape[0] ))

tr_grp_unq = set(tr_meta[split_on].values)
vl_grp_unq = set(vl_meta[split_on].values)
te_grp_unq = set(te_meta[split_on].values)
print_fn("")
print_fn(f"Total intersects on {split_on} btw tr and vl: {len(tr_grp_unq.intersection(vl_grp_unq))}")
print_fn(f"Total intersects on {split_on} btw tr and te: {len(tr_grp_unq.intersection(te_grp_unq))}")
print_fn(f"Total intersects on {split_on} btw vl and te: {len(vl_grp_unq.intersection(te_grp_unq))}")
print_fn(f"Unique {split_on} in tr: {len(tr_grp_unq)}")
print_fn(f"Unique {split_on} in vl: {len(vl_grp_unq)}")
print_fn(f"Unique {split_on} in te: {len(te_grp_unq)}")


# --------------------------
# Obtain T/V/E tfr filenames
# --------------------------
# List of sample names for T/V/E
tr_smp_names = list(tr_meta[id_name].values)
vl_smp_names = list(vl_meta[id_name].values)
te_smp_names = list(te_meta[id_name].values)

# TFRecords filenames
train_tfr_files = get_tfr_files(tfr_dir, tr_smp_names)
val_tfr_files = get_tfr_files(tfr_dir, vl_smp_names)
test_tfr_files = get_tfr_files(tfr_dir, te_smp_names)
print("Total samples {}".format(len(train_tfr_files) + len(val_tfr_files) + len(test_tfr_files)))

# Missing tfrecords
print("\nThese samples miss a tfrecord:")
df_miss = data.loc[~data[id_name].isin(tr_smp_names + vl_smp_names + te_smp_names), ["smp", "image_id"]]
print(df_miss)

assert sorted(tr_smp_names) == sorted(tr_meta[id_name].values.tolist()), "Sample names in the tr_smp_names and tr_meta don't match."
assert sorted(vl_smp_names) == sorted(vl_meta[id_name].values.tolist()), "Sample names in the vl_smp_names and vl_meta don't match."
assert sorted(te_smp_names) == sorted(te_meta[id_name].values.tolist()), "Sample names in the te_smp_names and te_meta don't match."


Train samples 5574 (80.06%)
Val   samples 690 (9.91%)
Test  samples 698 (10.03%)

Total intersects on Group btw tr and vl: 0
Total intersects on Group btw tr and te: 0
Total intersects on Group btw vl and te: 0
Unique Group in tr: 760
Unique Group in vl: 97
Unique Group in te: 102
Total samples 6962

These samples miss a tfrecord:
Empty DataFrame
Columns: [smp, image_id]
Index: []


In [9]:
# -------------------------------
# Class weight
# -------------------------------
tile_cnts = pd.read_csv(tfr_dir/"tile_counts_per_slide.csv")
tile_cnts.insert(loc=0, column="tfr_abs_fname", value=tile_cnts["tfr_fname"].map(lambda s: str(tfr_dir/s)))
cat = tile_cnts[tile_cnts["tfr_abs_fname"].isin(train_tfr_files)]
cat = cat.groupby(target).agg({"smp": "nunique", "max_tiles": "sum", "n_tiles": "sum", "slide": "nunique"}).reset_index()
categories = {}
for i, row_data in cat.iterrows():
    dct = {"num_samples": row_data["smp"], "num_tiles": row_data["n_tiles"]}
    categories[row_data[target]] = dct

class_weight = calc_class_weights(train_tfr_files,
                                  class_weights_method=params.class_weights_method,
                                  categories=categories)
print(categories)
print(class_weight)

{0: {'num_samples': 5309, 'num_tiles': 201000}, 1: {'num_samples': 265, 'num_tiles': 11422}}
{0: 0.5284129353233831, 1: 9.29880931535633}


In [10]:
# -------------------------------
# Parsing funcs
# -------------------------------
if target == "Response":
    # Response
    parse_fn = parse_tfrec_fn_rsp
    parse_fn_train_kwargs = {
        "use_tile": use_tile,
        "use_ge": use_ge,
        "use_dd1": use_dd1,
        "use_dd2": use_dd2,
        "ge_scaler": ge_scaler,
        "dd1_scaler": dd1_scaler,
        "dd2_scaler": dd2_scaler,
        "id_name": id_name,
        "augment": params.augment,
        "application": params.base_image_model,
    }
else:
    # Ctype
    parse_fn = parse_tfrec_fn_rna
    parse_fn_train_kwargs = {
        'use_tile': use_tile,
        'use_ge': use_ge,
        'ge_scaler': ge_scaler,
        'id_name': id_name,
        'MODEL_TYPE': params.model_type,
        'AUGMENT': params.augment,
    }

parse_fn_non_train_kwargs = parse_fn_train_kwargs.copy()
parse_fn_non_train_kwargs["augment"] = False


# ----------------------------------------
# Number of tiles/examples in each dataset
# ----------------------------------------
# import ipdb; ipdb.set_trace()
tr_tiles = tile_cnts[tile_cnts[id_name].isin(tr_smp_names)]["n_tiles"].sum()
vl_tiles = tile_cnts[tile_cnts[id_name].isin(vl_smp_names)]["n_tiles"].sum()
te_tiles = tile_cnts[tile_cnts[id_name].isin(te_smp_names)]["n_tiles"].sum()

eval_batch_size = 8 * params.batch_size
tr_steps = tr_tiles // params.batch_size
vl_steps = vl_tiles // eval_batch_size
te_steps = te_tiles // eval_batch_size


# -------------------------------
# Create TF datasets
# -------------------------------
print("\nCreating TF datasets.")

# Training
# import ipdb; ipdb.set_trace()
train_data = create_tf_data(
    batch_size=params.batch_size,
    deterministic=False,
    include_meta=False,
    interleave=True,
    n_concurrent_shards=params.n_concurrent_shards,  # 32, 64
    parse_fn=parse_fn,
    prefetch=1,  # 2
    repeat=True,
    seed=None,  # cfg.seed,
    shuffle_files=True,
    shuffle_size=params.shuffle_size,  # 8192
    tfrecords=train_tfr_files,
    **parse_fn_train_kwargs)

# Determine feature shapes from data
bb = next(train_data.__iter__())

# Infer dims of features from the data
# import ipdb; ipdb.set_trace()
if use_ge:
    ge_shape = bb[0]["ge_data"].numpy().shape[1:]
else:
    ge_shape = None

if use_dd1:
    dd_shape = bb[0]["dd1_data"].numpy().shape[1:]
else:
    dd_shape = None

# Print keys and dims
for i, item in enumerate(bb):
    print(f"\nItem {i}")
    if isinstance(item, dict):
        for k in item.keys():
            print(f"\t{k}: {item[k].numpy().shape}")
    elif isinstance(item.numpy(), np.ndarray):
        print(item)

# for i, rec in enumerate(train_data.take(2)):
#     tf.print(rec[1])

# Evaluation (val, test, train)
create_tf_data_eval_kwargs = {
    "batch_size": eval_batch_size,
    "include_meta": False,
    "interleave": False,
    "parse_fn": parse_fn,
    "prefetch": None,  # 2
    "repeat": False,
    "seed": None,
    "shuffle_files": False,
    "shuffle_size": None,
}

create_tf_data_eval_kwargs.update({"tfrecords": val_tfr_files, "include_meta": False})
val_data = create_tf_data(
    **create_tf_data_eval_kwargs,
    **parse_fn_non_train_kwargs
)

create_tf_data_eval_kwargs.update({"tfrecords": test_tfr_files, "include_meta": True})
test_data = create_tf_data(
    **create_tf_data_eval_kwargs,
    **parse_fn_non_train_kwargs
)

create_tf_data_eval_kwargs.update({"tfrecords": val_tfr_files, "include_meta": True})
eval_val_data = create_tf_data(
    **create_tf_data_eval_kwargs,
    **parse_fn_non_train_kwargs
)

create_tf_data_eval_kwargs.update({"tfrecords": train_tfr_files, "include_meta": True})
eval_train_data = create_tf_data(
    **create_tf_data_eval_kwargs,
    **parse_fn_non_train_kwargs
)


Creating TF datasets.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'

Item 0
	tile_image: (64, 299, 299, 3)
	ge_data: (64, 942)
	dd1_data: (64, 1993)
	dd2_data: (64, 1993)

Item 1
tf.Tensor(
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], shape=(64,), dtype=int64)


In [11]:
# Mixed precision
if params.use_fp16:
    if int(tf.keras.__version__.split(".")[1]) == 4:  # TF 2.4
        from tensorflow.keras import mixed_precision
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
    elif int(tf.keras.__version__.split(".")[1]) == 3:  # TF 2.3
        from tensorflow.keras.mixed_precision import experimental as mixed_precision
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_policy(policy)
        
# Target
loss = losses.BinaryCrossentropy()

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: GeForce RTX 2080 Ti, compute capability 7.5


In [12]:
# Calc output bias
if use_tile:
    # from sf_utils import get_categories_from_manifest
    # categories = get_categories_from_manifest(train_tfr_files, manifest, outcomes)
    neg = categories[0]["num_tiles"]
    pos = categories[1]["num_tiles"]
else:
    neg, pos = np.bincount(tr_meta[args.target[0]].values)

total = neg + pos
print_fn("Samples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n".format(total, pos, 100 * pos / total))
output_bias = np.log([pos/neg])
print_fn(f"Output bias: {output_bias}")
# output_bias = None

Samples:
    Total: 212422
    Positive: 11422 (5.38% of total)

Output bias: [-2.86776359]


In [13]:
# # bb = train_data.take(1)
# bb = next(eval_train_data.__iter__())
# print(len(bb))
# print(bb[0].keys())
# print(len(bb[1]))
# print(bb[2].keys())

## Show images

In [14]:
plt.figure(figsize=(10, 10))
for xdata, labels, meta in eval_train_data.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        # plt.imshow(xdata["tile_image"][i].numpy().astype("uint8"))
        plt.imshow(xdata["tile_image"][i].numpy())
        # plt.title(f"{labels.numpy()[i]}; {meta['tile_id'].numpy()[i]}; {meta['image_id'].numpy()[i]}; {meta['ctype'].numpy()[i]}")
        plt.title(f"{labels.numpy()[i]}")
        plt.axis("off")

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


In [15]:
# plt.figure(figsize=(20, 20))
# for xdata, labels, meta in eval_train_data.take(1):
#     for i in range(9):
#         ax = plt.subplot(3, 3, i + 1)
#         plt.imshow(xdata["tile_image"][i].numpy()) # .astype("uint8"))
#         plt.title("{}; slide: {}; tile: {}; trt; {}".format(labels.numpy()[i],
#                                                             meta['image_id'].numpy()[i].decode("utf-8"),
#                                                             meta['tile_id'].numpy()[i].decode("utf-8"),
#                                                             meta['trt'].numpy()[i].decode("utf-8")))
#         plt.axis("off")

In [16]:
x = xdata["tile_image"][i].numpy()
print(x.mean())
print(x.std())
print(x.min())
print(x.max())

0.33830264
0.35412648
-1.0
1.0


## Model

In [17]:
callbacks = []
monitor = "val_loss"
patience = 7

reduce_lr = ReduceLROnPlateau(monitor=monitor,
                              factor=0.5,
                              patience=5,
                              verbose=1,
                              mode="auto",
                              min_delta=0.0001,
                              cooldown=0,
                              min_lr=0)
callbacks.append(reduce_lr)

early_stop = EarlyStopping(monitor=monitor,
                           patience=patience,
                           mode="auto",
                           restore_best_weights=True,
                           verbose=1)
callbacks.append(early_stop)

In [18]:
dense1_dd1 = params.dense1_dd1
dense1_dd2 = params.dense1_dd2
dense1_ge = params.dense1_ge
dense1_img = params.dense1_img
dense2_img = params.dense2_img
dense1_top = params.dense1_top
learning_rate = params.learning_rate
dropout1_top = params.dropout1_top

pretrain = "imagenet"
pooling = "avg"

In [19]:
# if output_bias is not None:
#     output_bias = tf.keras.initializers.Constant(output_bias)

# model_inputs = []
# merge_inputs = []

# if use_tile:
#     image_shape = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE, 3)
#     tile_input_tensor = tf.keras.Input(shape=image_shape, name="tile_image")
#     base_img_model = tf.keras.applications.Xception(
#         include_top=False,
#         weights=pretrain,
#         input_shape=None,
#         input_tensor=None,
#         pooling=pooling)

#     # import ipdb; ipdb.set_trace()
#     # print(len(base_img_model.trainable_weights))
#     # print(len(base_img_model.non_trainable_weights))
#     # print(len(base_img_model.layers))
#     # base_img_model.trainable = False
#     # x_tile = keras.layers.GlobalAveragePooling2D()(tile_input_tensor)

#     x_tile = base_img_model(tile_input_tensor)
#     model_inputs.append(tile_input_tensor)

#     if dense1_img > 0:
#         x_tile = Dense(dense1_img, activation=tf.nn.relu, name="dense1_img")(x_tile)
#         # x_tile = BatchNormalization(name="batchnorm_im")(x_tile)
#     if dense2_img > 0:
#         x_tile = Dense(dense2_img, activation=tf.nn.relu, name="dense2_img")(x_tile)
#     if (dense1_img > 0) or (dense2_img > 0):
#         x_tile = BatchNormalization(name="batchnorm_im")(x_tile)
#     merge_inputs.append(x_tile)
#     del tile_input_tensor, x_tile

# if use_ge:
#     ge_input_tensor = tf.keras.Input(shape=ge_shape, name="ge_data")
#     x_ge = Dense(dense1_ge, activation=tf.nn.relu, name="dense1_ge")(ge_input_tensor)
#     x_ge = BatchNormalization(name="batchnorm_ge")(x_ge)
#     # x_ge = Dropout(0.4)(x_ge)
#     model_inputs.append(ge_input_tensor)
#     merge_inputs.append(x_ge)
#     del ge_input_tensor, x_ge

# if use_dd1:
#     dd1_input_tensor = tf.keras.Input(shape=dd_shape, name="dd1_data")
#     x_dd1 = Dense(dense1_dd1, activation=tf.nn.relu, name="dense1_dd1")(dd1_input_tensor)
#     x_dd1 = BatchNormalization(name="batchnorm_dd1")(x_dd1)
#     # x_dd1 = Dropout(0.4)(x_dd1)
#     model_inputs.append(dd1_input_tensor)
#     merge_inputs.append(x_dd1)
#     del dd1_input_tensor, x_dd1

# if use_dd2:
#     dd2_input_tensor = tf.keras.Input(shape=dd_shape, name="dd2_data")
#     x_dd2 = Dense(dense1_dd2, activation=tf.nn.relu, name="dense1_dd2")(dd2_input_tensor)
#     x_dd2 = BatchNormalization(name="batchnorm_dd2")(x_dd2)
#     # x_dd2 = Dropout(0.4)(x_dd2)
#     model_inputs.append(dd2_input_tensor)
#     merge_inputs.append(x_dd2)
#     del dd2_input_tensor, x_dd2

# # Merge towers
# merged_model = layers.Concatenate(axis=1, name="merger")(merge_inputs)

# merged_model = tf.keras.layers.Dense(dense1_top, activation=tf.nn.relu,
#                                      name="dense1_top", kernel_regularizer=None)(merged_model)
# merged_model = BatchNormalization(name="batchnorm_top")(merged_model)
# if dropout1_top > 0:
#     merged_model = Dropout(dropout1_top)(merged_model)

# softmax_output = tf.keras.layers.Dense(
#     1, activation="sigmoid", bias_initializer=output_bias, name="Response")(merged_model)

# # Assemble final model
# model = tf.keras.Model(inputs=model_inputs, outputs=softmax_output)

# metrics = [
#       keras.metrics.AUC(name="roc-auc", curve="ROC"),
#       keras.metrics.AUC(name="pr-auc", curve="PR"),
# ]
# if optimizer == "SGD":
#     optimizer = optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True)
# elif optimizer == "Adam":
#     optimizer = optimizers.Adam(learning_rate=learning_rate)

# model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

In [20]:
# fit_verbose = 1

# print_fn(f"Train steps:      {tr_steps}")
# print_fn(f"Validation steps: {vl_steps}")

# history = model.fit(x=train_data,
#                     validation_data=val_data,
#                     steps_per_epoch=tr_steps,
#                     validation_steps=vl_steps,
#                     class_weight=class_weight,
#                     epochs=params.epochs,
#                     verbose=fit_verbose,
#                     callbacks=callbacks)

In [21]:
if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)

model_inputs = []
merge_inputs = []

# Image
image_shape = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE, 3)
tile_input_tensor = tf.keras.Input(shape=image_shape, name="tile_image")
base_img_model = tf.keras.applications.Xception(
    include_top=False,
    weights="imagenet",
    input_shape=None,
    input_tensor=None,
    pooling="avg")
base_img_model.trainable = False  # freeze weights

x_tile = base_img_model(tile_input_tensor, training=False)
model_inputs.append(tile_input_tensor)

x_tile = Dense(dense1_img, activation=tf.nn.relu, name="dense1_img")(x_tile)
x_tile = Dense(dense2_img, activation=tf.nn.relu, name="dense2_img")(x_tile)
merge_inputs.append(x_tile)
del tile_input_tensor, x_tile

# GE
ge_input_tensor = tf.keras.Input(shape=ge_shape, name="ge_data")
x_ge = Dense(dense1_ge, activation=tf.nn.relu, name="dense1_ge")(ge_input_tensor)
x_ge = BatchNormalization(name="batchnorm_ge")(x_ge)
model_inputs.append(ge_input_tensor)
merge_inputs.append(x_ge)
del ge_input_tensor, x_ge

# DD1
dd1_input_tensor = tf.keras.Input(shape=dd_shape, name="dd1_data")
x_dd1 = Dense(dense1_dd1, activation=tf.nn.relu, name="dense1_dd1")(dd1_input_tensor)
x_dd1 = BatchNormalization(name="batchnorm_dd1")(x_dd1)
model_inputs.append(dd1_input_tensor)
merge_inputs.append(x_dd1)
del dd1_input_tensor, x_dd1

# DD2
dd2_input_tensor = tf.keras.Input(shape=dd_shape, name="dd2_data")
x_dd2 = Dense(dense1_dd2, activation=tf.nn.relu, name="dense1_dd2")(dd2_input_tensor)
x_dd2 = BatchNormalization(name="batchnorm_dd2")(x_dd2)
model_inputs.append(dd2_input_tensor)
merge_inputs.append(x_dd2)
del dd2_input_tensor, x_dd2

# Merge towers
merged_model = layers.Concatenate(axis=1, name="merger")(merge_inputs)

# Dense layers of the top classfier
merged_model = tf.keras.layers.Dense(dense1_top, activation=tf.nn.relu,
                                     name="dense1_top", kernel_regularizer=None)(merged_model)
merged_model = Dropout(0.2)(merged_model)

# Output
softmax_output = tf.keras.layers.Dense(
    1, activation="sigmoid", bias_initializer=output_bias, name="Response")(merged_model)

# Assemble final model
model = tf.keras.Model(inputs=model_inputs, outputs=softmax_output)

metrics = [keras.metrics.AUC(name="roc-auc", curve="ROC"),
           keras.metrics.AUC(name="pr-auc", curve="PR")
]

# optimizer = optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True)
optimizer = optimizers.Adam(learning_rate=learning_rate)

model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

In [22]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
tile_image (InputLayer)         [(None, 299, 299, 3) 0                                            
__________________________________________________________________________________________________
xception (Functional)           (None, 2048)         20861480    tile_image[0][0]                 
__________________________________________________________________________________________________
ge_data (InputLayer)            [(None, 942)]        0                                            
__________________________________________________________________________________________________
dd1_data (InputLayer)           [(None, 1993)]       0                                            
______________________________________________________________________________________________

In [23]:
# Layers & models have three weight attributes:
# weights: list of all weights variables of the layer.
# trainable_weights: list of those that are meant to be updated (via gradient descent) to minimize the loss during training.
# non_trainable_weights: list of those that aren't meant to be trained. Typically they are updated by the model during the forward pass.
# In general, all weights are trainable weights. The only built-in layer that has non-trainable weights is the BatchNormalization layer. It uses non-trainable weights to keep track of the mean and variance of its inputs during training.
print("\nBase model.")
print("trainable_weights:", len(base_img_model.trainable_weights))
print("non_trainable_weights:", len(base_img_model.non_trainable_weights))
print("layers:", len(base_img_model.layers))
print("output_shape:", base_img_model.output_shape)
print("trainable_variables:", len(base_img_model.trainable_variables))

print("\nFull model.")
print("trainable_weights:", len(model.trainable_weights))
print("non_trainable_weights:", len(model.non_trainable_weights))
print("layers:", len(model.layers))
print("output_shape:", model.output_shape)
print("trainable_variables:", len(model.trainable_variables))


Base model.
trainable_weights: 0
non_trainable_weights: 234
layers: 133
output_shape: (None, 2048)
trainable_variables: 0

Full model.
trainable_weights: 20
non_trainable_weights: 240
layers: 17
output_shape: (None, 1)
trainable_variables: 20


In [24]:
# initial_epochs = 10
initial_epochs = 4
# initial_epochs = 50

res = model.evaluate(val_data, steps=vl_steps, verbose=1)
print("Loss: {:0.4f}".format(res[0]))

Loss: 0.2432


In [25]:
print_fn(f"Train steps:      {tr_steps}")
print_fn(f"Validation steps: {vl_steps}")

history = model.fit(x=train_data,
                    validation_data=val_data,
                    steps_per_epoch=tr_steps,
                    validation_steps=vl_steps,
                    class_weight=class_weight,
                    epochs=initial_epochs,
                    verbose=1,
                    callbacks=callbacks)

Train steps:      3319
Validation steps: 47
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


In [33]:
# acc = history.history['pr-auc']
# val_acc = history.history['val_pr-auc']

# loss = history.history['loss']
# val_loss = history.history['val_loss']

# plt.figure(figsize=(8, 8))
# plt.subplot(2, 1, 1)
# plt.plot(acc, label='Training Accuracy')
# plt.plot(val_acc, label='Validation Accuracy')
# plt.legend(loc='lower right')
# plt.ylabel('Accuracy')
# plt.ylim([min(plt.ylim()),1])
# plt.title('Training and Validation Accuracy')

# plt.subplot(2, 1, 2)
# plt.plot(loss, label='Training Loss')
# plt.plot(val_loss, label='Validation Loss')
# plt.legend(loc='upper right')
# plt.ylabel('Cross Entropy')
# plt.ylim([0,1.0])
# plt.title('Training and Validation Loss')
# plt.xlabel('epoch')
# plt.show()

### Finetune

In [27]:
base_img_model.trainable = True

In [28]:
print("\nBase model.")
print("trainable_weights:", len(base_img_model.trainable_weights))
print("non_trainable_weights:", len(base_img_model.non_trainable_weights))
print("layers:", len(base_img_model.layers))
print("output_shape:", base_img_model.output_shape)
print("trainable_variables:", len(base_img_model.trainable_variables))

print("\nFull model.")
print("trainable_weights:", len(model.trainable_weights))
print("non_trainable_weights:", len(model.non_trainable_weights))
print("layers:", len(model.layers))
print("output_shape:", model.output_shape)
print("trainable_variables:", len(model.trainable_variables))


Base model.
trainable_weights: 154
non_trainable_weights: 80
layers: 133
output_shape: (None, 2048)
trainable_variables: 154

Full model.
trainable_weights: 174
non_trainable_weights: 86
layers: 17
output_shape: (None, 1)
trainable_variables: 174


In [29]:
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_img_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_img_model.layers[:fine_tune_at]:
    layer.trainable =  False

Number of layers in the base model:  133


In [30]:
print("\nBase model.")
print("trainable_weights:", len(base_img_model.trainable_weights))
print("non_trainable_weights:", len(base_img_model.non_trainable_weights))
print("layers:", len(base_img_model.layers))
print("output_shape:", base_img_model.output_shape)
print("trainable_variables:", len(base_img_model.trainable_variables))

print("\nFull model.")
print("trainable_weights:", len(model.trainable_weights))
print("non_trainable_weights:", len(model.non_trainable_weights))
print("layers:", len(model.layers))
print("output_shape:", model.output_shape)
print("trainable_variables:", len(model.trainable_variables))


Base model.
trainable_weights: 39
non_trainable_weights: 195
layers: 133
output_shape: (None, 2048)
trainable_variables: 39

Full model.
trainable_weights: 59
non_trainable_weights: 201
layers: 17
output_shape: (None, 1)
trainable_variables: 59


In [31]:
# model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
#               optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
#               metrics=['accuracy'])

loss = losses.BinaryCrossentropy()
# optimizer = optimizers.SGD(learning_rate=learning_rate/10, momentum=0.9, nesterov=True)
optimizer = optimizers.Adam(learning_rate=learning_rate/10)
model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

In [32]:
# fine_tune_epochs = 10
fine_tune_epochs = 4
total_epochs =  initial_epochs + fine_tune_epochs

# history_fine = model.fit(train_dataset,
#                          epochs=total_epochs,
#                          initial_epoch=history.epoch[-1],
#                          validation_data=validation_dataset)

history_fine = model.fit(x=train_data,
                    validation_data=val_data,
                    steps_per_epoch=tr_steps,
                    validation_steps=vl_steps,
                    class_weight=class_weight,
                    epochs=total_epochs,
                    initial_epoch=history.epoch[-1],
                    verbose=1,
                    callbacks=callbacks)

Epoch 4/8
Epoch 5/8
Epoch 6/8
Epoch 7/8
Epoch 8/8


In [None]:
# acc += history_fine.history['accuracy']
# val_acc += history_fine.history['val_accuracy']

# loss += history_fine.history['loss']
# val_loss += history_fine.history['val_loss']

In [None]:
# plt.figure(figsize=(8, 8))
# plt.subplot(2, 1, 1)
# plt.plot(acc, label='Training Accuracy')
# plt.plot(val_acc, label='Validation Accuracy')
# plt.ylim([0.8, 1])
# plt.plot([initial_epochs-1, initial_epochs-1], plt.ylim(), label='Start Fine Tuning')
# plt.legend(loc='lower right')
# plt.title('Training and Validation Accuracy')

# plt.subplot(2, 1, 2)
# plt.plot(loss, label='Training Loss')
# plt.plot(val_loss, label='Validation Loss')
# plt.ylim([0, 1.0])
# plt.plot([initial_epochs-1, initial_epochs-1], plt.ylim(), label='Start Fine Tuning')
# plt.legend(loc='upper right')
# plt.title('Training and Validation Loss')
# plt.xlabel('epoch')
# plt.show()

In [None]:
pretrain = "imagenet"
pooling = "avg"
image_shape = (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE, 3)
tile_input_tensor = tf.keras.Input(shape=image_shape, name="tile_image")
mm = tf.keras.applications.Xception(
    include_top=False,
    weights=pretrain,
    input_shape=None,
    input_tensor=None,
    pooling=pooling)

x = mm(tile_input_tensor, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(tile_input_tensor, outputs)

import ipdb; ipdb.set_trace()
# Layers & models have three weight attributes:
# weights: list of all weights variables of the layer.
# trainable_weights: list of those that are meant to be updated (via gradient descent) to minimize the loss during training.
# non_trainable_weights: list of those that aren't meant to be trained. Typically they are updated by the model during the forward pass.
# In general, all weights are trainable weights. The only built-in layer that has non-trainable weights is the BatchNormalization layer. It uses non-trainable weights to keep track of the mean and variance of its inputs during training.
print("trainable_weights:", len(mm.trainable_weights))
print("non_trainable_weights:", len(mm.non_trainable_weights))
print("layers:", len(mm.layers))
print("output shape:", mm.output_shape)
mm.trainable = False

In [None]:
# Utils for transfer learing with Keras
def print_trainable_layers(model, print_all=False):
    """ Print the trainable state of layers. """
    print('Trainable layers:')
    for layer in model.layers:
        if layer.trainable:
            print(layer.name, layer.trainable)
        if not layer.trainable and print_all:
            print(layer.name, layer.trainable)

            
def freeze_layers(model, freeze_up_to='all'):
    """ Freeze up to layer freeze_up_to, including! """
    # freeze_layers = ['1', '2', '3', '4']
    if freeze_up_to=='all':
        for layer in model.layers:
            layer.trainable = False

    #for layer in model.layers:
    #    if any([True for i in layers_ids if i in layer.name]):
    #        layer.trainable = False
    for layer in model.layers:
        # if freeze_up_to.lower() != layer.name.lower():
        if freeze_up_to.lower() not in layer.name.lower():
            layer.trainable = False
        else:
            layer.trainable = False
            break

def pop_layers(model, keep_up_to):
    # pop_layers = ['4', '5', 'outputs']
    model_layers = model.layers
    #for layer in model_layers[::-1]:
    #    if any([True for i in layers_ids if i in layer.name]):
    #        model.layers.pop()  
    
    for layer in model_layers[::-1]:
        # if keep_up_to.lower() != layer.name.lower():
        if keep_up_to.lower() not in layer.name.lower():
            model.layers.pop()
        else:
           break