-
Notifications
You must be signed in to change notification settings - Fork 624
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
07ee267
commit 83afdc9
Showing
27 changed files
with
174 additions
and
886 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
{ | ||
"metadata": { | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.10-final" | ||
}, | ||
"orig_nbformat": 2, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3", | ||
"language": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2, | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import hub\n", | ||
"import tensorflow as tf\n", | ||
"from time import sleep\n", | ||
"import numpy as np\n", | ||
"from matplotlib import pyplot as plt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# helper function to visualize images\n", | ||
"def visualize(image):\n", | ||
" image = image.reshape(512, 512)\n", | ||
" plt.figure(figsize=(5, 5))\n", | ||
" plt.axis('off')\n", | ||
" plt.imshow(image, cmap='gray', vmin=0, vmax=1)\n", | ||
" \n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"ds = hub.Dataset(\"s3://snark-gradient-raw-data/output_single_8_all_samples_max_4_boolean_m5_fixed_final_400/ds3\") \n", | ||
"print(ds.shape) # the number of samples \n", | ||
"print(ds.schema) # the structure of the dataset" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"image_sequence = ds[\"image\", 100040].compute() # or access any other sample\n", | ||
"# visualize(img)\n", | ||
"image_sequence.shape\n", | ||
"visualize(image_sequence[0]) # visualize first image in sequence\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"for item in ds:\n", | ||
" print(item[\"label_chexpert\"].compute()) # or you can access any other key from schema\n", | ||
" print(item[\"viewPosition\"].compute()) # the ClassLabels are stored as integers\n", | ||
" print(item[\"viewPosition\"].compute(label_name=True)) # strings labels are retrieved in this manner\n", | ||
" break\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"subset = ds[500:1000] # take a subset of the dataset \n", | ||
"print(len(subset))\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def only_frontal(sample):\n", | ||
" viewPosition = sample[\"viewPosition\"].compute(True)\n", | ||
" return True if \"PA\" in viewPosition or \"AP\" in viewPosition else False\n", | ||
"\n", | ||
"filtered = subset.filter(only_frontal)\n", | ||
"print(len(filtered))\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"tds = filtered.to_tensorflow()\n", | ||
"# alternatively we can send a subset of keys to tf that are relevant for training\n", | ||
"# this is faster as otherwise other irrelevant data is fetched too, that can slow things down\n", | ||
"tds = filtered.to_tensorflow(key_list=[\"image\", \"label_chexpert\", \"viewPosition\"])\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_image(viewPosition, images):\n", | ||
" for i, vp in enumerate(viewPosition):\n", | ||
" if vp in [5, 12]:\n", | ||
" return np.concatenate((images[i], images[i], images[i]), axis=2)\n", | ||
"\n", | ||
"def to_model_fit(sample):\n", | ||
" viewPosition = sample[\"viewPosition\"]\n", | ||
" images = sample[\"image\"]\n", | ||
" image = tf.py_function(get_image, [viewPosition, images], tf.uint16)\n", | ||
" labels = sample[\"label_chexpert\"]\n", | ||
" return image, labels\n", | ||
"\n", | ||
"# converts the data into X, y format format for training\n", | ||
"tds_train = tds.map(to_model_fit)\n", | ||
"\n", | ||
"# batch and prefetch\n", | ||
"tds_train = tds_train.batch(8).prefetch(tf.data.AUTOTUNE)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 47, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"for batch in tds_train:\n", | ||
" # do something\n", | ||
" sleep(0.1) # simulate training delay " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
] | ||
} |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
Keras==2.4.3 | ||
tensorflow==2.4.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.