In [1]:
import numpy as np

In [2]:
# 6x4 int8 matrix holding 1..24 in row-major order (1 byte per element).
a = np.arange(1, 25, dtype=np.int8).reshape((6, 4))
print(a)

[[ 1  2  3  4]
 [ 5  6  7  8]
 [ 9 10 11 12]
 [13 14 15 16]
 [17 18 19 20]
 [21 22 23 24]]


In [3]:
# Dimensions of `a`: (rows, columns).
np.shape(a)

(6, 4)

In [4]:
# Bytes to step along each axis: one row of 4 int8 values is 4 bytes, one column is 1 byte.
a.strides

(4, 1)

In [5]:
# View each 2x2 block of `a` as its own sub-array, without copying.
# The byte strides are derived from `a.strides` (bytes per row, bytes per column)
# instead of being hard-coded, so the cell stays correct if the dtype of `a` changes.
# For the int8 `a` above this gives shape (3, 2, 2, 2) and strides (8, 2, 4, 1).
frag_h, frag_w = 2, 2
row_stride, col_stride = a.strides
b = np.lib.stride_tricks.as_strided(
    a,
    shape=(a.shape[0] // frag_h, a.shape[1] // frag_w, frag_h, frag_w),
    strides=(
        frag_h * row_stride,  # next block down
        frag_w * col_stride,  # next block across
        row_stride,           # next row inside a block
        col_stride,           # next column inside a block
    ),
)
print(b)

[[[[ 1  2]
   [ 5  6]]

  [[ 3  4]
   [ 7  8]]]


 [[[ 9 10]
   [13 14]]

  [[11 12]
   [15 16]]]


 [[[17 18]
   [21 22]]

  [[19 20]
   [23 24]]]]


In [6]:
# (block rows, block columns, rows per block, columns per block).
np.shape(b)

(3, 2, 2, 2)

In [7]:
# (8, 2, 4, 1): the next 2x2 block down is 2 rows (8 bytes) away and the next block across
# is 2 columns (2 bytes) away; inside a block, a row is 4 bytes and a column is 1 byte.
b.strides

(8, 2, 4, 1)

In [8]:
# Re-striding the view `b` with the same shape/strides reproduces the same layout:
# as_strided measures byte offsets from the data pointer of its input.
c = np.lib.stride_tricks.as_strided(
    b,
    shape=(3, 2, 2, 2),
    strides=(8, 2, 4, 1),
)
print(c)

[[[[ 1  2]
   [ 5  6]]

  [[ 3  4]
   [ 7  8]]]


 [[[ 9 10]
   [13 14]]

  [[11 12]
   [15 16]]]


 [[[17 18]
   [21 22]]

  [[19 20]
   [23 24]]]]


In [9]:
# Same dimensions as `b`.
np.shape(c)

(3, 2, 2, 2)

In [10]:
# Same byte strides that were passed to as_strided when building `c`.
c.strides

(8, 2, 4, 1)

# Creating a Generalized Method

In [1]:
def split_image_in_fragments(_img, fragment_height, fragment_width):
    """Split a channel-first image into non-overlapping rectangular fragments.

    Builds a zero-copy strided view of ``_img`` laid out as
    ``(channel, n_rows, n_cols, fragment_height, fragment_width)``, where
    ``n_rows = image_height // fragment_height`` and
    ``n_cols = image_width // fragment_width``. Trailing rows/columns that do
    not fill a whole fragment are dropped.

    Parameters
    ----------
    _img : numpy.ndarray
        3-D array of shape ``(channel, image_height, image_width)``. It does
        not need to be contiguous: the view is built from the array's own
        byte strides, so slices, transposes and flips are handled correctly.
    fragment_height, fragment_width : int
        Size of each fragment in pixels; both must be positive.

    Returns
    -------
    numpy.ndarray
        The 5-D view described above. It shares memory with ``_img``, so
        writing to it writes to ``_img``.

    Raises
    ------
    ValueError
        If ``_img`` is not 3-D or a fragment dimension is not positive.
    """
    if _img.ndim != 3:
        raise ValueError(
            f"expected a 3-D (channel, height, width) array, got shape {_img.shape}"
        )
    if fragment_height <= 0 or fragment_width <= 0:
        raise ValueError(
            "fragment_height and fragment_width must be positive, "
            f"got {fragment_height} and {fragment_width}"
        )

    channel, image_height, image_width = _img.shape
    # Use the array's real byte strides rather than recomputing them from the
    # shape. This is correct for any dtype (element size, not dtype alignment),
    # and for non-contiguous inputs. In particular, the channel step is one
    # whole image plane, not a function of the fragment size.
    channel_stride, row_stride, col_stride = _img.strides

    shape = (
        channel,
        image_height // fragment_height,
        image_width // fragment_width,
        fragment_height,
        fragment_width,
    )
    strides = (
        channel_stride,                # next channel
        fragment_height * row_stride,  # next fragment down
        fragment_width * col_stride,   # next fragment across
        row_stride,                    # next row inside a fragment
        col_stride,                    # next column inside a fragment
    )

    return np.lib.stride_tricks.as_strided(_img, shape=shape, strides=strides)

## Single-Channel (2-D) Image

### Create a Single-Channel (2-D) Image

In [23]:
# Single-channel 4x6 "image" of int8 values 0..23, channel-first layout (1, H, W).
fake_image = np.arange(24, dtype=np.int8).reshape((1, 4, 6))
print(fake_image)

[[[ 0  1  2  3  4  5]
  [ 6  7  8  9 10 11]
  [12 13 14 15 16 17]
  [18 19 20 21 22 23]]]


In [14]:
# Cut the single-channel image into 2x2 fragments.
print(split_image_in_fragments(fake_image, fragment_height=2, fragment_width=2))

[[[[[ 0  1]
    [ 6  7]]

   [[ 2  3]
    [ 8  9]]

   [[ 4  5]
    [10 11]]]


  [[[12 13]
    [18 19]]

   [[14 15]
    [20 21]]

   [[16 17]
    [22 23]]]]]


In [15]:
# Reference view with byte strides computed by hand for an int8 (1, 4, 6) image:
# channel = 24 (4*6), fragment row = 12 (2*6), fragment column = 2, row = 6, column = 1.
np.lib.stride_tricks.as_strided(
    fake_image,
    (1, 2, 3, 2, 2),
    (24, 12, 2, 6, 1),
)

array([[[[[ 0,  1],
          [ 6,  7]],

         [[ 2,  3],
          [ 8,  9]],

         [[ 4,  5],
          [10, 11]]],


        [[[12, 13],
          [18, 19]],

         [[14, 15],
          [20, 21]],

         [[16, 17],
          [22, 23]]]]], dtype=int8)

## Multi-Channel (3-D) Image

### Create a Multi-Channel (3-D) Image

In [16]:
# Three-channel 4x6 "image" of int8 values 0..71, channel-first layout (C, H, W).
channeled_fake_image = np.arange(72, dtype=np.int8).reshape((3, 4, 6))
print(channeled_fake_image)

[[[ 0  1  2  3  4  5]
  [ 6  7  8  9 10 11]
  [12 13 14 15 16 17]
  [18 19 20 21 22 23]]

 [[24 25 26 27 28 29]
  [30 31 32 33 34 35]
  [36 37 38 39 40 41]
  [42 43 44 45 46 47]]

 [[48 49 50 51 52 53]
  [54 55 56 57 58 59]
  [60 61 62 63 64 65]
  [66 67 68 69 70 71]]]


### Test

In [17]:
# Cut every channel of the 3-channel image into 2x2 fragments.
print(split_image_in_fragments(channeled_fake_image, fragment_height=2, fragment_width=2))

[[[[[ 0  1]
    [ 6  7]]

   [[ 2  3]
    [ 8  9]]

   [[ 4  5]
    [10 11]]]


  [[[12 13]
    [18 19]]

   [[14 15]
    [20 21]]

   [[16 17]
    [22 23]]]]



 [[[[24 25]
    [30 31]]

   [[26 27]
    [32 33]]

   [[28 29]
    [34 35]]]


  [[[36 37]
    [42 43]]

   [[38 39]
    [44 45]]

   [[40 41]
    [46 47]]]]



 [[[[48 49]
    [54 55]]

   [[50 51]
    [56 57]]

   [[52 53]
    [58 59]]]


  [[[60 61]
    [66 67]]

   [[62 63]
    [68 69]]

   [[64 65]
    [70 71]]]]]


### Check Against Hand-Computed Strides

In [18]:
# Reference view with byte strides computed by hand for an int8 (3, 4, 6) image:
# channel = 24 (4*6), fragment row = 12 (2*6), fragment column = 2, row = 6, column = 1.
# Stored under a new name so `channeled_fake_image` stays the 3-D source image;
# overwriting it with this 5-D view broke re-running the Test cell.
expected_fragments = np.lib.stride_tricks.as_strided(
    channeled_fake_image, shape=(3, 2, 3, 2, 2), strides=(24, 12, 2, 6, 1)
)
assert np.array_equal(
    expected_fragments, split_image_in_fragments(channeled_fake_image, 2, 2)
), "split_image_in_fragments disagrees with the hand-computed strides"
print(expected_fragments)

[[[[[ 0  1]
    [ 6  7]]

   [[ 2  3]
    [ 8  9]]

   [[ 4  5]
    [10 11]]]


  [[[12 13]
    [18 19]]

   [[14 15]
    [20 21]]

   [[16 17]
    [22 23]]]]



 [[[[24 25]
    [30 31]]

   [[26 27]
    [32 33]]

   [[28 29]
    [34 35]]]


  [[[36 37]
    [42 43]]

   [[38 39]
    [44 45]]

   [[40 41]
    [46 47]]]]



 [[[[48 49]
    [54 55]]

   [[50 51]
    [56 57]]

   [[52 53]
    [58 59]]]


  [[[60 61]
    [66 67]]

   [[62 63]
    [68 69]]

   [[64 65]
    [70 71]]]]]
