From e3414d654d40b46549eda95245b264532ef2093c Mon Sep 17 00:00:00 2001 From: aymericdamien Date: Sun, 26 Jul 2020 12:29:43 -0700 Subject: [PATCH 1/4] fix ml intro --- .../0_Prerequisite/mnist_dataset_intro.ipynb | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tensorflow_v2/notebooks/0_Prerequisite/mnist_dataset_intro.ipynb b/tensorflow_v2/notebooks/0_Prerequisite/mnist_dataset_intro.ipynb index f1813c85..74f8a91f 100644 --- a/tensorflow_v2/notebooks/0_Prerequisite/mnist_dataset_intro.ipynb +++ b/tensorflow_v2/notebooks/0_Prerequisite/mnist_dataset_intro.ipynb @@ -1,10 +1,8 @@ { "cells": [ { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ "\n", "# MNIST Dataset Introduction\n", @@ -27,12 +25,10 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": { "collapsed": true }, - "outputs": [], "source": [ "# Import MNIST\n", "from tensorflow.examples.tutorials.mnist import input_data\n", @@ -53,12 +49,10 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": { "collapsed": true }, - "outputs": [], "source": [ "# Get the next 64 images array and labels\n", "batch_X, batch_Y = mnist.train.next_batch(64)" @@ -88,9 +82,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", - "version": "2.7.13" + "version": "2.7.18" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 1 } From a1516d2303f31942b4cff615e7c39f1b548157b4 Mon Sep 17 00:00:00 2001 From: aymericdamien Date: Sun, 26 Jul 2020 12:30:48 -0700 Subject: [PATCH 2/4] fix ml intro --- .../0_Prerequisite/mnist_dataset_intro.ipynb | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow_v2/notebooks/0_Prerequisite/mnist_dataset_intro.ipynb b/tensorflow_v2/notebooks/0_Prerequisite/mnist_dataset_intro.ipynb index 74f8a91f..93c9e79e 100644 --- a/tensorflow_v2/notebooks/0_Prerequisite/mnist_dataset_intro.ipynb +++ b/tensorflow_v2/notebooks/0_Prerequisite/mnist_dataset_intro.ipynb @@ -25,10 +25,10 @@ ] }, { - "cell_type": "markdown", - "metadata": { - "collapsed": true - }, + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Import MNIST\n", "from tensorflow.examples.tutorials.mnist import input_data\n", @@ -49,10 +49,10 @@ ] }, { - "cell_type": "markdown", - "metadata": { - "collapsed": true - }, + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Get the next 64 images array and labels\n", "batch_X, batch_Y = mnist.train.next_batch(64)" From a48127035f3ae05b1be66c817ac6bcce7073df79 Mon Sep 17 00:00:00 2001 From: aymericdamien Date: Sat, 19 Sep 2020 00:40:12 -0700 Subject: [PATCH 3/4] add multi gpu example --- .../6_Hardware/multigpu_training.ipynb | 371 ++++++++++++++++++ 1 file changed, 371 insertions(+) create mode 100644 tensorflow_v2/notebooks/6_Hardware/multigpu_training.ipynb diff --git a/tensorflow_v2/notebooks/6_Hardware/multigpu_training.ipynb b/tensorflow_v2/notebooks/6_Hardware/multigpu_training.ipynb new file mode 100644 index 00000000..46b07000 --- /dev/null +++ b/tensorflow_v2/notebooks/6_Hardware/multigpu_training.ipynb @@ -0,0 +1,371 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multi-GPU Training Example\n", + "\n", + "Train a convolutional neural network on multiple GPU with TensorFlow 2.0+.\n", + "\n", + "- Author: Aymeric Damien\n", + "- Project: https://github.com/aymericdamien/TensorFlow-Examples/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training with multiple GPU cards\n", + "\n", + "In this example, we are using data parallelism to split the training accross multiple GPUs. Each GPU has a full replica of the neural network model, and the weights (i.e. variables) are updated synchronously by waiting that each GPU process its batch of data.\n", + "\n", + "First, each GPU process a distinct batch of data and compute the corresponding gradients, then, all gradients are accumulated in the CPU and averaged. The model weights are finally updated with the gradients averaged, and the new model weights are sent back to each GPU, to repeat the training process.\n", + "\n", + "\"Parallelism\"\n", + "\n", + "## CIFAR10 Dataset Overview\n", + "\n", + "The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.\n", + "\n", + "![CIFAR10 Dataset](https://storage.googleapis.com/kaggle-competitions/kaggle/3649/media/cifar-10.png)\n", + "\n", + "More info: https://www.cs.toronto.edu/~kriz/cifar.html" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import absolute_import, division, print_function\n", + "\n", + "import tensorflow as tf\n", + "from tensorflow.keras import Model, layers\n", + "import time\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# MNIST dataset parameters.\n", + "num_classes = 10 # total classes (0-9 digits).\n", + "num_gpus = 4\n", + "\n", + "# Training parameters.\n", + "learning_rate = 0.001\n", + "training_steps = 1000\n", + "# Split batch size equally between GPUs.\n", + "# Note: Reduce batch size if you encounter OOM Errors.\n", + "batch_size = 1024 * num_gpus\n", + "display_step = 20\n", + "\n", + "# Network parameters.\n", + "conv1_filters = 64 # number of filters for 1st conv layer.\n", + "conv2_filters = 128 # number of filters for 2nd conv layer.\n", + "conv3_filters = 256 # number of filters for 2nd conv layer.\n", + "fc1_units = 2048 # number of neurons for 1st fully-connected layer." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare MNIST data.\n", + "from tensorflow.keras.datasets import cifar10\n", + "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n", + "# Convert to float32.\n", + "x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)\n", + "# Normalize images value from [0, 255] to [0, 1].\n", + "x_train, x_test = x_train / 255., x_test / 255.\n", + "y_train, y_test = np.reshape(y_train, (-1)), np.reshape(y_test, (-1))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Use tf.data API to shuffle and batch data.\n", + "train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", + "train_data = train_data.repeat().shuffle(batch_size * 10).batch(batch_size).prefetch(num_gpus)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class ConvNet(Model):\n", + " # Set layers.\n", + " def __init__(self):\n", + " super(ConvNet, self).__init__()\n", + " \n", + " # Convolution Layer with 64 filters and a kernel size of 3.\n", + " self.conv1_1 = layers.Conv2D(conv1_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n", + " self.conv1_2 = layers.Conv2D(conv1_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n", + " # Max Pooling (down-sampling) with kernel size of 2 and strides of 2. \n", + " self.maxpool1 = layers.MaxPool2D(2, strides=2)\n", + "\n", + " # Convolution Layer with 128 filters and a kernel size of 3.\n", + " self.conv2_1 = layers.Conv2D(conv2_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n", + " self.conv2_2 = layers.Conv2D(conv2_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n", + " self.conv2_3 = layers.Conv2D(conv2_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n", + " # Max Pooling (down-sampling) with kernel size of 2 and strides of 2. \n", + " self.maxpool2 = layers.MaxPool2D(2, strides=2)\n", + "\n", + " # Convolution Layer with 256 filters and a kernel size of 3.\n", + " self.conv3_1 = layers.Conv2D(conv3_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n", + " self.conv3_2 = layers.Conv2D(conv3_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n", + " self.conv3_3 = layers.Conv2D(conv3_filters, kernel_size=3, padding='SAME', activation=tf.nn.relu)\n", + "\n", + " # Flatten the data to a 1-D vector for the fully connected layer.\n", + " self.flatten = layers.Flatten()\n", + "\n", + " # Fully connected layer.\n", + " self.fc1 = layers.Dense(1024, activation=tf.nn.relu)\n", + " # Apply Dropout (if is_training is False, dropout is not applied).\n", + " self.dropout = layers.Dropout(rate=0.5)\n", + "\n", + " # Output layer, class prediction.\n", + " self.out = layers.Dense(num_classes)\n", + "\n", + " # Set forward pass.\n", + " @tf.function\n", + " def call(self, x, is_training=False):\n", + " x = self.conv1_1(x)\n", + " x = self.conv1_2(x)\n", + " x = self.maxpool1(x)\n", + " x = self.conv2_1(x)\n", + " x = self.conv2_2(x)\n", + " x = self.conv2_3(x)\n", + " x = self.maxpool2(x)\n", + " x = self.conv3_1(x)\n", + " x = self.conv3_2(x)\n", + " x = self.conv3_3(x)\n", + " x = self.flatten(x)\n", + " x = self.fc1(x)\n", + " x = self.dropout(x, training=is_training)\n", + " x = self.out(x)\n", + " if not is_training:\n", + " # tf cross entropy expect logits without softmax, so only\n", + " # apply softmax when not training.\n", + " x = tf.nn.softmax(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Cross-Entropy Loss.\n", + "# Note that this will apply 'softmax' to the logits.\n", + "@tf.function\n", + "def cross_entropy_loss(x, y):\n", + " # Convert labels to int 64 for tf cross-entropy function.\n", + " y = tf.cast(y, tf.int64)\n", + " # Apply softmax to logits and compute cross-entropy.\n", + " loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)\n", + " # Average loss across the batch.\n", + " return tf.reduce_mean(loss)\n", + "\n", + "# Accuracy metric.\n", + "@tf.function\n", + "def accuracy(y_pred, y_true):\n", + " # Predicted class is the index of highest score in prediction vector (i.e. argmax).\n", + " correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))\n", + " return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1)\n", + " \n", + "\n", + "@tf.function\n", + "def backprop(batch_x, batch_y, trainable_variables):\n", + " # Wrap computation inside a GradientTape for automatic differentiation.\n", + " with tf.GradientTape() as g:\n", + " # Forward pass.\n", + " pred = conv_net(batch_x, is_training=True)\n", + " # Compute loss.\n", + " loss = cross_entropy_loss(pred, batch_y)\n", + " # Compute gradients.\n", + " gradients = g.gradient(loss, trainable_variables)\n", + " return gradients\n", + "\n", + "# Build the function to average the gradients.\n", + "@tf.function\n", + "def average_gradients(tower_grads):\n", + " avg_grads = []\n", + " for tgrads in zip(*tower_grads):\n", + " grads = []\n", + " for g in tgrads:\n", + " expanded_g = tf.expand_dims(g, 0)\n", + " grads.append(expanded_g)\n", + " \n", + " grad = tf.concat(axis=0, values=grads)\n", + " grad = tf.reduce_mean(grad, 0)\n", + " \n", + " avg_grads.append(grad)\n", + " \n", + " return avg_grads" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "with tf.device('/cpu:0'):\n", + " # Build convnet.\n", + " conv_net = ConvNet()\n", + " # Stochastic gradient descent optimizer.\n", + " optimizer = tf.optimizers.Adam(learning_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# Optimization process.\n", + "def run_optimization(x, y):\n", + " # Save gradients for all GPUs.\n", + " tower_grads = []\n", + " # Variables to update, i.e. trainable variables.\n", + " trainable_variables = conv_net.trainable_variables\n", + "\n", + " with tf.device('/cpu:0'):\n", + " for i in range(num_gpus):\n", + " # Split data between GPUs.\n", + " gpu_batch_size = int(batch_size/num_gpus)\n", + " batch_x = x[i * gpu_batch_size: (i+1) * gpu_batch_size]\n", + " batch_y = y[i * gpu_batch_size: (i+1) * gpu_batch_size]\n", + " \n", + " # Build the neural net on each GPU.\n", + " with tf.device('/gpu:%i' % i):\n", + " grad = backprop(batch_x, batch_y, trainable_variables)\n", + " tower_grads.append(grad)\n", + " \n", + " # Last GPU Average gradients from all GPUs.\n", + " if i == num_gpus - 1:\n", + " gradients = average_gradients(tower_grads)\n", + "\n", + " # Update vars following gradients.\n", + " optimizer.apply_gradients(zip(gradients, trainable_variables))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 1, loss: 2.302630, accuracy: 0.101318, speed: 16342.138481 examples/sec\n", + "step: 20, loss: 2.296755, accuracy: 0.108398, speed: 5355.197204 examples/sec\n", + "step: 40, loss: 2.216037, accuracy: 0.299072, speed: 12388.080848 examples/sec\n", + "step: 60, loss: 2.189814, accuracy: 0.362305, speed: 12033.404638 examples/sec\n", + "step: 80, loss: 2.137831, accuracy: 0.410156, speed: 12189.852065 examples/sec\n", + "step: 100, loss: 2.102876, accuracy: 0.437744, speed: 12212.349483 examples/sec\n", + "step: 120, loss: 2.077521, accuracy: 0.460693, speed: 12160.290400 examples/sec\n", + "step: 140, loss: 2.006775, accuracy: 0.545166, speed: 12202.175380 examples/sec\n", + "step: 160, loss: 1.994143, accuracy: 0.554443, speed: 12168.070368 examples/sec\n", + "step: 180, loss: 1.964281, accuracy: 0.597412, speed: 12244.148312 examples/sec\n", + "step: 200, loss: 1.893395, accuracy: 0.658203, speed: 12197.382402 examples/sec\n", + "step: 220, loss: 1.880256, accuracy: 0.672363, speed: 12178.323620 examples/sec\n", + "step: 240, loss: 1.868853, accuracy: 0.676025, speed: 12224.851444 examples/sec\n", + "step: 260, loss: 1.837151, accuracy: 0.705322, speed: 12101.154436 examples/sec\n", + "step: 280, loss: 1.799418, accuracy: 0.736816, speed: 12185.701420 examples/sec\n", + "step: 300, loss: 1.790719, accuracy: 0.755615, speed: 12126.826668 examples/sec\n", + "step: 320, loss: 1.732242, accuracy: 0.807861, speed: 12229.926783 examples/sec\n", + "step: 340, loss: 1.732089, accuracy: 0.806885, speed: 12167.651100 examples/sec\n", + "step: 360, loss: 1.693968, accuracy: 0.835693, speed: 12060.687471 examples/sec\n", + "step: 380, loss: 1.665804, accuracy: 0.862305, speed: 12130.389108 examples/sec\n", + "step: 400, loss: 1.627162, accuracy: 0.890381, speed: 12152.946766 examples/sec\n", + "step: 420, loss: 1.594189, accuracy: 0.920654, speed: 12057.401941 examples/sec\n", + "step: 440, loss: 1.575212, accuracy: 0.929688, speed: 12196.589206 examples/sec\n", + "step: 460, loss: 1.569351, accuracy: 0.942383, speed: 12147.345871 examples/sec\n", + "step: 480, loss: 1.520648, accuracy: 0.974609, speed: 11998.473978 examples/sec\n", + "step: 500, loss: 1.507439, accuracy: 0.982666, speed: 12152.490287 examples/sec\n", + "step: 520, loss: 1.495090, accuracy: 0.989746, speed: 12071.718912 examples/sec\n", + "step: 540, loss: 1.490940, accuracy: 0.989502, speed: 12049.224039 examples/sec\n", + "step: 560, loss: 1.476727, accuracy: 0.996338, speed: 12134.827424 examples/sec\n", + "step: 580, loss: 1.475038, accuracy: 0.995850, speed: 12128.228532 examples/sec\n", + "step: 600, loss: 1.469776, accuracy: 0.997559, speed: 12113.386949 examples/sec\n", + "step: 620, loss: 1.466832, accuracy: 0.999756, speed: 11939.016031 examples/sec\n", + "step: 640, loss: 1.466991, accuracy: 0.999023, speed: 12095.815773 examples/sec\n", + "step: 660, loss: 1.466177, accuracy: 0.999023, speed: 12035.037908 examples/sec\n", + "step: 680, loss: 1.465074, accuracy: 0.999512, speed: 11789.118097 examples/sec\n", + "step: 700, loss: 1.464655, accuracy: 0.999512, speed: 11965.087437 examples/sec\n", + "step: 720, loss: 1.465109, accuracy: 0.999512, speed: 11855.853520 examples/sec\n", + "step: 740, loss: 1.465021, accuracy: 0.999023, speed: 11774.901096 examples/sec\n", + "step: 760, loss: 1.463057, accuracy: 1.000000, speed: 11930.138289 examples/sec\n", + "step: 780, loss: 1.462609, accuracy: 1.000000, speed: 11766.752011 examples/sec\n", + "step: 800, loss: 1.462320, accuracy: 0.999756, speed: 11744.213314 examples/sec\n", + "step: 820, loss: 1.462975, accuracy: 1.000000, speed: 11700.815885 examples/sec\n", + "step: 840, loss: 1.462328, accuracy: 1.000000, speed: 11759.141371 examples/sec\n", + "step: 860, loss: 1.462561, accuracy: 1.000000, speed: 11650.397252 examples/sec\n", + "step: 880, loss: 1.462608, accuracy: 0.999512, speed: 11581.170575 examples/sec\n", + "step: 900, loss: 1.462178, accuracy: 0.999756, speed: 11562.545711 examples/sec\n", + "step: 920, loss: 1.461582, accuracy: 1.000000, speed: 11616.172231 examples/sec\n", + "step: 940, loss: 1.462402, accuracy: 1.000000, speed: 11709.561795 examples/sec\n", + "step: 960, loss: 1.462436, accuracy: 1.000000, speed: 11629.547741 examples/sec\n", + "step: 980, loss: 1.462415, accuracy: 1.000000, speed: 11623.658645 examples/sec\n", + "step: 1000, loss: 1.461925, accuracy: 1.000000, speed: 11579.716701 examples/sec\n" + ] + } + ], + "source": [ + "# Run training for the given number of steps.\n", + "ts = time.time()\n", + "for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):\n", + " # Run the optimization to update W and b values.\n", + " run_optimization(batch_x, batch_y)\n", + " \n", + " if step % display_step == 0 or step == 1:\n", + " dt = time.time() - ts\n", + " speed = batch_size * display_step / dt\n", + " pred = conv_net(batch_x)\n", + " loss = cross_entropy_loss(pred, batch_y)\n", + " acc = accuracy(pred, batch_y)\n", + " print(\"step: %i, loss: %f, accuracy: %f, speed: %f examples/sec\" % (step, loss, acc, speed))\n", + " ts = time.time()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 00ecf083ecac27fd57afdfc82798d1fef365769b Mon Sep 17 00:00:00 2001 From: aymericdamien Date: Sat, 19 Sep 2020 00:51:18 -0700 Subject: [PATCH 4/4] add multi gpu example --- README.md | 26 ++++++++++++++------------ tensorflow_v2/README.md | 3 +++ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 1bba79c4..e7de7049 100644 --- a/README.md +++ b/README.md @@ -13,13 +13,13 @@ It is suitable for beginners who want to find clear and concise examples about T - [Introduction to MNIST Dataset](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/0_Prerequisite/mnist_dataset_intro.ipynb). #### 1 - Introduction -- **Hello World** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/1_Introduction/helloworld.ipynb)). Very simple example to learn how to print "hello world" using TensorFlow 2.0. -- **Basic Operations** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/1_Introduction/basic_operations.ipynb)). A simple example that cover TensorFlow 2.0 basic operations. +- **Hello World** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/1_Introduction/helloworld.ipynb)). Very simple example to learn how to print "hello world" using TensorFlow 2.0+. +- **Basic Operations** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/1_Introduction/basic_operations.ipynb)). A simple example that cover TensorFlow 2.0+ basic operations. #### 2 - Basic Models -- **Linear Regression** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/2_BasicModels/linear_regression.ipynb)). Implement a Linear Regression with TensorFlow 2.0. -- **Logistic Regression** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/2_BasicModels/logistic_regression.ipynb)). Implement a Logistic Regression with TensorFlow 2.0. -- **Word2Vec (Word Embedding)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/2_BasicModels/word2vec.ipynb)). Build a Word Embedding Model (Word2Vec) from Wikipedia data, with TensorFlow 2.0. +- **Linear Regression** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/2_BasicModels/linear_regression.ipynb)). Implement a Linear Regression with TensorFlow 2.0+. +- **Logistic Regression** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/2_BasicModels/logistic_regression.ipynb)). Implement a Logistic Regression with TensorFlow 2.0+. +- **Word2Vec (Word Embedding)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/2_BasicModels/word2vec.ipynb)). Build a Word Embedding Model (Word2Vec) from Wikipedia data, with TensorFlow 2.0+. - **GBDT (Gradient Boosted Decision Trees)** ([notebooks](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/2_BasicModels/gradient_boosted_trees.ipynb)). Implement a Gradient Boosted Decision Trees with TensorFlow 2.0+ to predict house value using Boston Housing dataset. #### 3 - Neural Networks @@ -27,26 +27,28 @@ It is suitable for beginners who want to find clear and concise examples about T - **Simple Neural Network** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/3_NeuralNetworks/neural_network.ipynb)). Use TensorFlow 2.0 'layers' and 'model' API to build a simple neural network to classify MNIST digits dataset. - **Simple Neural Network (low-level)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/3_NeuralNetworks/neural_network_raw.ipynb)). Raw implementation of a simple neural network to classify MNIST digits dataset. -- **Convolutional Neural Network** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/3_NeuralNetworks/convolutional_network.ipynb)). Use TensorFlow 2.0 'layers' and 'model' API to build a convolutional neural network to classify MNIST digits dataset. +- **Convolutional Neural Network** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/3_NeuralNetworks/convolutional_network.ipynb)). Use TensorFlow 2.0+ 'layers' and 'model' API to build a convolutional neural network to classify MNIST digits dataset. - **Convolutional Neural Network (low-level)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/3_NeuralNetworks/convolutional_network_raw.ipynb)). Raw implementation of a convolutional neural network to classify MNIST digits dataset. - **Recurrent Neural Network (LSTM)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/3_NeuralNetworks/recurrent_network.ipynb)). Build a recurrent neural network (LSTM) to classify MNIST digits dataset, using TensorFlow 2.0 'layers' and 'model' API. -- **Bi-directional Recurrent Neural Network (LSTM)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/3_NeuralNetworks/bidirectional_rnn.ipynb)). Build a bi-directional recurrent neural network (LSTM) to classify MNIST digits dataset, using TensorFlow 2.0 'layers' and 'model' API. -- **Dynamic Recurrent Neural Network (LSTM)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/3_NeuralNetworks/dynamic_rnn.ipynb)). Build a recurrent neural network (LSTM) that performs dynamic calculation to classify sequences of variable length, using TensorFlow 2.0 'layers' and 'model' API. +- **Bi-directional Recurrent Neural Network (LSTM)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/3_NeuralNetworks/bidirectional_rnn.ipynb)). Build a bi-directional recurrent neural network (LSTM) to classify MNIST digits dataset, using TensorFlow 2.0+ 'layers' and 'model' API. +- **Dynamic Recurrent Neural Network (LSTM)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/3_NeuralNetworks/dynamic_rnn.ipynb)). Build a recurrent neural network (LSTM) that performs dynamic calculation to classify sequences of variable length, using TensorFlow 2.0+ 'layers' and 'model' API. ##### Unsupervised - **Auto-Encoder** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/3_NeuralNetworks/autoencoder.ipynb)). Build an auto-encoder to encode an image to a lower dimension and re-construct it. - **DCGAN (Deep Convolutional Generative Adversarial Networks)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/3_NeuralNetworks/dcgan.ipynb)). Build a Deep Convolutional Generative Adversarial Network (DCGAN) to generate images from noise. #### 4 - Utilities -- **Save and Restore a model** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/4_Utils/save_restore_model.ipynb)). Save and Restore a model with TensorFlow 2.0. -- **Build Custom Layers & Modules** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/4_Utils/build_custom_layers.ipynb)). Learn how to build your own layers / modules and integrate them into TensorFlow 2.0 Models. +- **Save and Restore a model** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/4_Utils/save_restore_model.ipynb)). Save and Restore a model with TensorFlow 2.0+. +- **Build Custom Layers & Modules** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/4_Utils/build_custom_layers.ipynb)). Learn how to build your own layers / modules and integrate them into TensorFlow 2.0+ Models. - **Tensorboard** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/4_Utils/tensorboard.ipynb)). Track and visualize neural network computation graph, metrics, weights and more using TensorFlow 2.0+ tensorboard. #### 5 - Data Management - **Load and Parse data** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/5_DataManagement/load_data.ipynb)). Build efficient data pipeline with TensorFlow 2.0 (Numpy arrays, Images, CSV files, custom data, ...). -- **Build and Load TFRecords** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/5_DataManagement/tfrecords.ipynb)). Convert data into TFRecords format, and load them with TensorFlow 2.0. -- **Image Transformation (i.e. Image Augmentation)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/5_DataManagement/image_transformation.ipynb)). Apply various image augmentation techniques with TensorFlow 2.0, to generate distorted images for training. +- **Build and Load TFRecords** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/5_DataManagement/tfrecords.ipynb)). Convert data into TFRecords format, and load them with TensorFlow 2.0+. +- **Image Transformation (i.e. Image Augmentation)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/5_DataManagement/image_transformation.ipynb)). Apply various image augmentation techniques with TensorFlow 2.0+, to generate distorted images for training. +#### 6 - Hardware + **Multi-GPU Training** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/6_Hardware/multigpu_training.ipynb)). Train a convolutional neural network with multiple GPUs on CIFAR-10 dataset. ## TensorFlow v1 diff --git a/tensorflow_v2/README.md b/tensorflow_v2/README.md index 83d9ad0e..ffccd7e5 100644 --- a/tensorflow_v2/README.md +++ b/tensorflow_v2/README.md @@ -41,6 +41,9 @@ - **Build and Load TFRecords** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/5_DataManagement/tfrecords.ipynb)). Convert data into TFRecords format, and load them with TensorFlow 2.0. - **Image Transformation (i.e. Image Augmentation)** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/5_DataManagement/image_transformation.ipynb)). Apply various image augmentation techniques with TensorFlow 2.0, to generate distorted images for training. +#### 6 - Hardware + **Multi-GPU Training** ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/tensorflow_v2/notebooks/6_Hardware/multigpu_training.ipynb)). Train a convolutional neural network with multiple GPUs on CIFAR-10 dataset. + ## Installation To install TensorFlow 2.0, simply run: