Skip to content

Commit

Permalink
cleanup gradient health code
Browse files Browse the repository at this point in the history
  • Loading branch information
AbhinavTuli committed Mar 12, 2021
1 parent 07ee267 commit 83afdc9
Show file tree
Hide file tree
Showing 27 changed files with 174 additions and 886 deletions.
File renamed without changes.
170 changes: 170 additions & 0 deletions gradient_health/explore.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import hub\n",
"import tensorflow as tf\n",
"from time import sleep\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# helper function to visualize images\n",
"def visualize(image):\n",
" image = image.reshape(512, 512)\n",
" plt.figure(figsize=(5, 5))\n",
" plt.axis('off')\n",
" plt.imshow(image, cmap='gray', vmin=0, vmax=1)\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds = hub.Dataset(\"s3://snark-gradient-raw-data/output_single_8_all_samples_max_4_boolean_m5_fixed_final_400/ds3\") \n",
"print(ds.shape) # the number of samples \n",
"print(ds.schema) # the structure of the dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image_sequence = ds[\"image\", 100040].compute() # or access any other sample\n",
"# visualize(img)\n",
"image_sequence.shape\n",
"visualize(image_sequence[0]) # visualize first image in sequence\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for item in ds:\n",
" print(item[\"label_chexpert\"].compute()) # or you can access any other key from schema\n",
" print(item[\"viewPosition\"].compute()) # the ClassLabels are stored as integers\n",
" print(item[\"viewPosition\"].compute(label_name=True)) # strings labels are retrieved in this manner\n",
" break\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"subset = ds[500:1000] # take a subset of the dataset \n",
"print(len(subset))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def only_frontal(sample):\n",
" viewPosition = sample[\"viewPosition\"].compute(True)\n",
" return True if \"PA\" in viewPosition or \"AP\" in viewPosition else False\n",
"\n",
"filtered = subset.filter(only_frontal)\n",
"print(len(filtered))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tds = filtered.to_tensorflow()\n",
"# alternatively we can send a subset of keys to tf that are relevant for training\n",
"# this is faster as otherwise other irrelevant data is fetched too, that can slow things down\n",
"tds = filtered.to_tensorflow(key_list=[\"image\", \"label_chexpert\", \"viewPosition\"])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_image(viewPosition, images):\n",
" for i, vp in enumerate(viewPosition):\n",
" if vp in [5, 12]:\n",
" return np.concatenate((images[i], images[i], images[i]), axis=2)\n",
"\n",
"def to_model_fit(sample):\n",
" viewPosition = sample[\"viewPosition\"]\n",
" images = sample[\"image\"]\n",
" image = tf.py_function(get_image, [viewPosition, images], tf.uint16)\n",
" labels = sample[\"label_chexpert\"]\n",
" return image, labels\n",
"\n",
"# converts the data into X, y format format for training\n",
"tds_train = tds.map(to_model_fit)\n",
"\n",
"# batch and prefetch\n",
"tds_train = tds_train.batch(8).prefetch(tf.data.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"\n",
"for batch in tds_train:\n",
" # do something\n",
" sleep(0.1) # simulate training delay "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
]
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 2 additions & 0 deletions gradient_health/model_training/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Keras==2.4.3
tensorflow==2.4.1
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,15 @@
import shutil
import os
import pickle
from callback import MultipleClassAUROC, MultiGPUModelCheckpoint
from configparser import ConfigParser
from generator import AugmentedImageSequence
from keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.optimizers import Adam

# from keras.utils import multi_gpu_model
from models.keras import ModelFactory
from utility import get_sample_counts
from weights import get_class_weights
from augmenter import augmenter
import tensorflow as tf
import hub
import numpy as np

import wandb
wandb.init()


def train_gen():
for _ in range(10000000):
yield np.random.randint(low=0, high=2, size=(512, 512, 3)), np.random.randint(low=0, high=2, size=(14,))


def val_gen():
for _ in range(0):
yield np.random.randint(low=0, high=2, size=(512, 512, 3)), np.random.randint(low=0, high=2, size=(14,))


dummy_train = tf.data.Dataset.from_generator(
train_gen,
output_signature=(
tf.TensorSpec(shape=(512, 512, 3), dtype=tf.uint16),
tf.TensorSpec(shape=(14,), dtype=tf.int32),
))

dummy_val = tf.data.Dataset.from_generator(
val_gen,
output_signature=(
tf.TensorSpec(shape=(512, 512, 3), dtype=tf.uint16),
tf.TensorSpec(shape=(14,), dtype=tf.int32),
))

dummy_train = dummy_train.batch(8).prefetch(tf.data.AUTOTUNE)
dummy_val = dummy_val.batch(8).prefetch(tf.data.AUTOTUNE)


def only_frontal(sample):
viewPosition = sample["viewPosition"].compute(True)
Expand Down Expand Up @@ -76,7 +39,6 @@ def main():

# default config
output_dir = cp["DEFAULT"].get("output_dir")
image_source_dir = cp["DEFAULT"].get("image_source_dir")
base_model_name = cp["DEFAULT"].get("base_model_name")
print(base_model_name)
class_names = cp["DEFAULT"].get("class_names").split(",")
Expand All @@ -95,8 +57,6 @@ def main():
patience_reduce_lr = cp["TRAIN"].getint("patience_reduce_lr")
min_lr = cp["TRAIN"].getfloat("min_lr")
validation_steps = cp["TRAIN"].get("validation_steps")
positive_weights_multiply = cp["TRAIN"].getfloat("positive_weights_multiply")
dataset_csv_dir = cp["TRAIN"].get("dataset_csv_dir")
# if previously trained weights is used, never re-split
if use_trained_model_weights:
# resuming mode
Expand Down Expand Up @@ -171,16 +131,6 @@ def main():
)
print(f"** validation_steps: {validation_steps} **")

# compute class weights
print("** compute class weights from training data **")
class_weights = get_class_weights(
train_counts,
train_pos_counts,
multiply=positive_weights_multiply,
)
print("** class_weights **")
print(class_weights)

print("** load model **")
if use_trained_model_weights:
if use_best_weights:
Expand Down Expand Up @@ -237,15 +187,7 @@ def main():
model_train = model
model_train.compile(optimizer=optimizer, loss="binary_crossentropy")

auroc = MultipleClassAUROC(
sequence=tds_val,
class_names=class_names,
weights_path=output_weights_path,
stats=training_stats,
workers=generator_workers,
)
callbacks = [
# checkpoint,
TensorBoard(
log_dir=os.path.join(output_dir, "logs"), batch_size=batch_size
),
Expand All @@ -257,7 +199,6 @@ def main():
mode="min",
min_lr=min_lr,
),
# auroc,
]
print("** start training **")
history = model_train.fit(
Expand All @@ -277,7 +218,6 @@ def main():
pickle.dump(
{
"history": history.history,
"auroc": auroc.aurocs,
},
f,
)
Expand Down
File renamed without changes.
8 changes: 0 additions & 8 deletions gradient_health_training copy/augmenter.py

This file was deleted.

Loading

0 comments on commit 83afdc9

Please sign in to comment.