In [32]:
import imageio
import torch

# Read the sample photo from disk, then display the dimensions of the loaded array.
bobby_path = 'data/bobby.jpg'
img_arr = imageio.imread(bobby_path)

img_arr.shape

# img_arr is a NumPy array-like object with three dimensions: two spatial dimensions (height and width) and a third dimension corresponding to the
# red, green and blue channels.

(720, 1280, 3)

In [33]:
# use the tensor's 'permute' method, passing the old dimension index for each new dimension, to get an appropriate layout

# given an input tensor of shape H x W x C, we get the proper C x H x W layout by putting dimension 2 first, followed by dimensions 0 and 1

# Wrap the loaded image array as a tensor, then reorder its axes so that old axis 2
# comes first, followed by old axes 0 and 1.
img = torch.from_numpy(img_arr)
chw_order = (2, 0, 1)
out = img.permute(*chw_order)

# note that this operation does not copy the tensor data: 'out' uses the same underlying storage as 'img', so changing a pixel in 'img' also changes 'out'

In [34]:
# to create a dataset of multiple images to use as input for neural networks, store the images in a batch along the first dimension
# to obtain an N x C x H x W tensor.

# create a tensor of an appropriate size and fill it with images loaded from a directory

# Number of image slots in the batch.
batch_size = 3
# Preallocate a zero-filled uint8 buffer with one 3 x 256 x 256 slot per image.
batch = torch.zeros((batch_size, 3, 256, 256), dtype=torch.uint8)

# this indicates that our batch will consist of three RGB images, 256 pixels in height and 256 pixels in width

# each colour will be represented as an 8-bit integer

In [36]:
import os 
data_dir = 'data/cats/'

# Collect every .png file in the directory. os.listdir returns entries in arbitrary
# order, so sort the names to make the image-to-slot assignment reproducible.
filenames = sorted(name for name in os.listdir(data_dir) if os.path.splitext(name)[-1] == '.png')

# Fill at most batch.shape[0] slots so that batch[i] never indexes past the
# preallocated buffer when the directory holds more images than the batch has room for.
for i, filename in enumerate(filenames[:batch.shape[0]]):
    img_arr = imageio.imread(os.path.join(data_dir, filename))
    img_t = torch.from_numpy(img_arr)
    img_t = img_t.permute(2, 0, 1)   # move old axis 2 to the front, matching the batch layout
    img_t = img_t[:3]   # only keep the first 3 channels (presumably drops an alpha channel, if any)
    batch[i] = img_t   # assumes each image matches the slot's spatial size -- TODO confirm
    

tensor([[[156, 152, 124,  ..., 150, 149, 158],
         [174, 134, 165,  ..., 120, 136, 138],
         [127, 156, 107,  ..., 131, 143, 164],
         ...,
         [116, 130, 129,  ..., 127, 118, 112],
         [129, 130, 123,  ..., 115, 121, 114],
         [129, 123, 118,  ..., 113, 121, 120]],

        [[139, 135, 109,  ..., 135, 135, 147],
         [160, 119, 149,  ..., 105, 122, 124],
         [113, 140,  90,  ..., 118, 129, 152],
         ...,
         [ 99, 110, 111,  ..., 117, 108, 103],
         [111, 111, 106,  ..., 106, 112, 105],
         [111, 104, 102,  ..., 103, 110, 111]],

        [[129, 123,  98,  ..., 131, 132, 145],
         [155, 110, 137,  ..., 102, 119, 121],
         [104, 132,  80,  ..., 112, 125, 146],
         ...,
         [ 93, 108, 105,  ..., 125, 115, 108],
         [108, 108,  98,  ..., 110, 117, 110],
         [107,  98,  95,  ..., 108, 115, 116]]], dtype=torch.uint8)
tensor([[[202, 193, 190,  ...,  13,  13,  12],
         [199, 192, 189,  ...,  14,  14,