In [7]:
# https://stackoverflow.com/questions/40731433/understanding-tf-extract-image-patches-for-extracting-patches-from-an-image
import tensorflow as tf
import numpy as np

In [3]:
# Reference notes on tf.extract_image_patches, kept as a bare string literal.
# The string is the last (and only) expression in this cell, so the notebook
# echoes its repr as the cell output; it is not assigned to any name.
'''
API:
    tf.extract_image_patches(
        images,
        ksizes,
        strides,
        rates,
        padding,
        name=None
    )

ksizes: decide the dimensions of each patch, that is how many pixels each patch should contain
strides: denotes the length of the gap between the start of one patch and the start of the next consecutive patch 
             within the original image
rates: a number that essentially means our patch should jump by rates pixels in the original image 
        for each consecutive pixel that ends up in our patch
padding: "VALID", which means every patch must be fully contained in the image
         "SAME", which means patches are allowed to be incomplete (the remaining pixels will be filled in with zeroes)  
'''

'\nAPI:\n    tf.extract_image_patches(\n        images,\n        ksizes,\n        strides,\n        rates,\n        padding,\n        name=None\n    )\n\nksizes: decide the dimensions of each patch, that is how many pixels each patch should contain\nstrides: denotes the length of the gap between the start of one patch and the start of the next consecutive patch \n             within the original image\nrates: a number that essentially means our patch should jump by rates pixels in the original image \n        for each consecutive pixel that ends up in our patch\npadding: "VALID", which means every patch must be fully contained in the image\n         "SAME", which means patches are allowed to be incomplete (the remaining pixels will be filled in with zeroes)  \n'

In [54]:
# Example input: one single-channel n x n image.
n = 10
# Nested list of shape (1, n, n, 1) holding 1..n*n in row-major order,
# i.e. images[0][r][c] == [r * n + c + 1].
images = np.arange(1, n * n + 1).reshape(1, n, n, 1).tolist()
# Sanity check of the layout that is passed to tf.extract_image_patches.
print(np.asarray(images).shape)

(1, 10, 10, 1)


In [10]:
# Four extractions from the same image, printed one after another.
# Each entry is (ksizes, strides, rates, padding):
#   1. 3x3 patches with stride length 5
#   2. same as 1, but with the rate increased to 2
#   3. 4x4 patches with stride length 7 ('VALID'): only one patch fits
#   4. same as 3, but with padding set to 'SAME'
patch_configs = [
    ([1, 3, 3, 1], [1, 5, 5, 1], [1, 1, 1, 1], 'VALID'),
    ([1, 3, 3, 1], [1, 5, 5, 1], [1, 2, 2, 1], 'VALID'),
    ([1, 4, 4, 1], [1, 7, 7, 1], [1, 1, 1, 1], 'VALID'),
    ([1, 4, 4, 1], [1, 7, 7, 1], [1, 1, 1, 1], 'SAME'),
]
with tf.Session() as sess:
    # Tensor.eval() relies on the default session installed by this `with` block.
    for position, (ksizes, strides, rates, padding) in enumerate(patch_configs):
        patches = tf.extract_image_patches(
            images=images, ksizes=ksizes, strides=strides, rates=rates, padding=padding
        ).eval()
        if position < len(patch_configs) - 1:
            # Every result except the last is followed by a blank-line separator.
            print(patches, '\n\n')
        else:
            print(patches)

[[[[ 1  2  3 11 12 13 21 22 23]
   [ 6  7  8 16 17 18 26 27 28]]

  [[51 52 53 61 62 63 71 72 73]
   [56 57 58 66 67 68 76 77 78]]]] 


[[[[  1   3   5  21  23  25  41  43  45]
   [  6   8  10  26  28  30  46  48  50]]

  [[ 51  53  55  71  73  75  91  93  95]
   [ 56  58  60  76  78  80  96  98 100]]]] 


[[[[ 1  2  3  4 11 12 13 14 21 22 23 24 31 32 33 34]]]] 


[[[[  1   2   3   4  11  12  13  14  21  22  23  24  31  32  33  34]
   [  8   9  10   0  18  19  20   0  28  29  30   0  38  39  40   0]]

  [[ 71  72  73  74  81  82  83  84  91  92  93  94   0   0   0   0]
   [ 78  79  80   0  88  89  90   0  98  99 100   0   0   0   0   0]]]]


In [35]:
# Evaluate four patch extractions of `images` and keep the resulting arrays
# as patch1..patch4 for the visualisation cells.
# Each tuple is (ksizes, strides, rates, padding), in the order patch1..patch4.
with tf.Session() as sess:
    patch1, patch2, patch3, patch4 = [
        tf.extract_image_patches(
            images=images, ksizes=ksizes, strides=strides, rates=rates, padding=padding
        ).eval()
        for ksizes, strides, rates, padding in (
            ([1, 3, 3, 1], [1, 5, 5, 1], [1, 1, 1, 1], 'VALID'),
            ([1, 3, 3, 1], [1, 5, 5, 1], [1, 2, 2, 1], 'VALID'),
            ([1, 4, 4, 1], [1, 7, 7, 1], [1, 1, 1, 1], 'VALID'),
            ([1, 4, 4, 1], [1, 7, 7, 1], [1, 1, 1, 1], 'SAME'),
        )
    ]

In [32]:
# Print every pixel value of the image, ten per line, with a blank line
# between consecutive lines of output.
print('original_image\n')
for position, pixel in enumerate(np.asarray(images).flatten()):
    # A new output line starts at every tenth pixel, but not before the first one.
    if position > 0 and position % 10 == 0:
        print('\n')
    print(pixel, end=' ')

original_image

1 2 3 4 5 6 7 8 9 10 

11 12 13 14 15 16 17 18 19 20 

21 22 23 24 25 26 27 28 29 30 

31 32 33 34 35 36 37 38 39 40 

41 42 43 44 45 46 47 48 49 50 

51 52 53 54 55 56 57 58 59 60 

61 62 63 64 65 66 67 68 69 70 

71 72 73 74 75 76 77 78 79 80 

81 82 83 84 85 86 87 88 89 90 

91 92 93 94 95 96 97 98 99 100 

In [47]:
# Overlay patch1 on the original image: every pixel whose value occurs
# anywhere in patch1 is printed as '*', all other pixels keep their value.
print('patch1\n')
print('3x3 patches with stride length 5\n')
print(patch1.shape, end='\n\n')

# Flatten the patch array once (not once per pixel) and use a set so each
# membership test is O(1). Matching by value is only valid because every
# pixel value in `images` is unique.
selected_values = set(patch1.flatten().tolist())

pixel_grid = np.asarray(images)
# Row length is taken from the image's column axis instead of a hard-coded 10.
row_width = pixel_grid.shape[2]

rendered_rows = [
    ''.join(('*' if value in selected_values else str(value)) + ' ' for value in row)
    for row in pixel_grid.reshape(-1, row_width).tolist()
]
# Image rows are separated by a blank line; nothing is printed after the last row.
print('\n\n'.join(rendered_rows), end='')

patch1

3x3 patches with stride length 5

(1, 2, 2, 9)

* * * 4 5 * * * 9 10 

* * * 14 15 * * * 19 20 

* * * 24 25 * * * 29 30 

31 32 33 34 35 36 37 38 39 40 

41 42 43 44 45 46 47 48 49 50 

* * * 54 55 * * * 59 60 

* * * 64 65 * * * 69 70 

* * * 74 75 * * * 79 80 

81 82 83 84 85 86 87 88 89 90 

91 92 93 94 95 96 97 98 99 100 

In [49]:
# Overlay patch2 on the original image: every pixel whose value occurs
# anywhere in patch2 is printed as '*', all other pixels keep their value.
print('patch2\n')
print('Same as above, but the rate is increased to 2\n')
print(patch2.shape, end='\n\n')

# Flatten the patch array once (not once per pixel) and use a set so each
# membership test is O(1). Matching by value is only valid because every
# pixel value in `images` is unique.
selected_values = set(patch2.flatten().tolist())

pixel_grid = np.asarray(images)
# Row length is taken from the image's column axis instead of a hard-coded 10.
row_width = pixel_grid.shape[2]

rendered_rows = [
    ''.join(('*' if value in selected_values else str(value)) + ' ' for value in row)
    for row in pixel_grid.reshape(-1, row_width).tolist()
]
# Image rows are separated by a blank line; nothing is printed after the last row.
print('\n\n'.join(rendered_rows), end='')

patch2

Same as above, but the rate is increased to 2

(1, 2, 2, 9)

* 2 * 4 * * 7 * 9 * 

11 12 13 14 15 16 17 18 19 20 

* 22 * 24 * * 27 * 29 * 

31 32 33 34 35 36 37 38 39 40 

* 42 * 44 * * 47 * 49 * 

* 52 * 54 * * 57 * 59 * 

61 62 63 64 65 66 67 68 69 70 

* 72 * 74 * * 77 * 79 * 

81 82 83 84 85 86 87 88 89 90 

* 92 * 94 * * 97 * 99 * 

In [50]:
# Overlay patch3 on the original image: every pixel whose value occurs
# anywhere in patch3 is printed as '*', all other pixels keep their value.
print('patch3\n')
print('4x4 patches with stride length 7; only one patch should be generated\n')
print(patch3.shape, end='\n\n')

# Flatten the patch array once (not once per pixel) and use a set so each
# membership test is O(1). Matching by value is only valid because every
# pixel value in `images` is unique.
selected_values = set(patch3.flatten().tolist())

pixel_grid = np.asarray(images)
# Row length is taken from the image's column axis instead of a hard-coded 10.
row_width = pixel_grid.shape[2]

rendered_rows = [
    ''.join(('*' if value in selected_values else str(value)) + ' ' for value in row)
    for row in pixel_grid.reshape(-1, row_width).tolist()
]
# Image rows are separated by a blank line; nothing is printed after the last row.
print('\n\n'.join(rendered_rows), end='')

patch3

4x4 patches with stride length 7; only one patch should be generated

(1, 1, 1, 16)

* * * * 5 6 7 8 9 10 

* * * * 15 16 17 18 19 20 

* * * * 25 26 27 28 29 30 

* * * * 35 36 37 38 39 40 

41 42 43 44 45 46 47 48 49 50 

51 52 53 54 55 56 57 58 59 60 

61 62 63 64 65 66 67 68 69 70 

71 72 73 74 75 76 77 78 79 80 

81 82 83 84 85 86 87 88 89 90 

91 92 93 94 95 96 97 98 99 100 

In [53]:
# Overlay patch4 on the original image: every pixel whose value occurs
# anywhere in patch4 is printed as '*', all other pixels keep their value.
print('patch4\n')
print('Same as above, but with padding set to SAME\n')
print(patch4.shape, end='\n\n')

# Flatten the patch array once (not once per pixel) and use a set so each
# membership test is O(1). Matching by value is only valid because every
# pixel value in `images` is unique. The zeros that 'SAME' padding adds to
# the patches never match, since the image values start at 1.
selected_values = set(patch4.flatten().tolist())

pixel_grid = np.asarray(images)
# Row length is taken from the image's column axis instead of a hard-coded 10.
row_width = pixel_grid.shape[2]

rendered_rows = [
    ''.join(('*' if value in selected_values else str(value)) + ' ' for value in row)
    for row in pixel_grid.reshape(-1, row_width).tolist()
]
# Image rows are separated by a blank line; nothing is printed after the last row.
print('\n\n'.join(rendered_rows), end='')

patch4

Same as above, but with padding set to SAME

(1, 2, 2, 16)

* * * * 5 6 7 * * * 

* * * * 15 16 17 * * * 

* * * * 25 26 27 * * * 

* * * * 35 36 37 * * * 

41 42 43 44 45 46 47 48 49 50 

51 52 53 54 55 56 57 58 59 60 

61 62 63 64 65 66 67 68 69 70 

* * * * 75 76 77 * * * 

* * * * 85 86 87 * * * 

* * * * 95 96 97 * * * 