## Flatten a tensor 

In [1]:
import torch


In [2]:
import numpy as np

In [4]:
d = [[1,2,3],[4,5,6],[7,8,9],[10,11,12]]

In [5]:
t = torch.tensor(d)
t

tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

### A tensor flatten operation is a common operation inside convolutional neural networks. This is because the output of a convolutional layer must be flattened before it can be passed to a fully connected layer, which only accepts flat input.
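Before moving on to batches, here is a minimal sketch of what flattening does to the 4 x 3 tensor `t` built from `d` above. The names `flat_a`, `flat_b` and `flat_c` are only illustrative. All three calls produce the same 12-element rank-1 tensor:

```python
import torch

# The same 4 x 3 tensor as t above (rank 2, 12 elements)
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])

flat_a = t.flatten()                 # Tensor.flatten(): all axes -> one axis
flat_b = t.reshape(-1)               # -1 asks PyTorch to infer the length (4 * 3 = 12)
flat_c = t.reshape(1, -1).squeeze()  # 1 x 12 first, then squeeze() drops the length-1 axis

print(flat_a.shape)     # torch.Size([12])
print(flat_a.tolist())  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
print(torch.equal(flat_a, flat_b), torch.equal(flat_a, flat_c))  # True True
```

The row-major order is preserved: the first row comes first, then the second, and so on.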

### What if we want to flatten only specific axes within the tensor? This is typically required when working with CNNs.

##### Tensor inputs to a convolutional neural network typically have 4 axes: one for the batch size, one for the color channels, and one each for height and width.
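To make the shape bookkeeping concrete, here is a small sketch of a 4-axis batch going through a convolutional layer and then a fully connected layer. The layer sizes (1 input channel, 4 filters, a 3 x 3 kernel, 28 x 28 images, 10 outputs) are arbitrary assumptions for illustration, not values from this notebook. The flatten keeps the batch axis so that the fully connected layer still produces one row per image:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical batch: 2 grayscale 28 x 28 images -> [batch, channels, height, width]
x = torch.rand(2, 1, 28, 28)

conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3)  # no padding: 28 -> 26
fc = nn.Linear(in_features=4 * 26 * 26, out_features=10)        # expects flat rows

features = conv(x)                    # [2, 4, 26, 26]
flat = features.flatten(start_dim=1)  # [2, 2704]: keep the batch axis, flatten the rest
out = fc(flat)                        # [2, 10]: one output row per image

print(features.shape)  # torch.Size([2, 4, 26, 26])
print(flat.shape)      # torch.Size([2, 2704])
print(out.shape)       # torch.Size([2, 10])
```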

In [16]:
t1 = torch.tensor([[1,1,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1]])
t1

tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])

In [17]:
t2 = torch.tensor([[2,2,2,2],[2,2,2,2],[2,2,2,2],[2,2,2,2]])
t2

tensor([[2, 2, 2, 2],
        [2, 2, 2, 2],
        [2, 2, 2, 2],
        [2, 2, 2, 2]])

In [18]:
t3 = torch.tensor([[3,3,3,3],[3,3,3,3],[3,3,3,3],[3,3,3,3]])
t3

tensor([[3, 3, 3, 3],
        [3, 3, 3, 3],
        [3, 3, 3, 3],
        [3, 3, 3, 3]])

### So now we have three 4 x 4 images, and each image tensor has rank 2. Now let’s create a batch to pass these images to a CNN. Batches are represented using a single tensor, so we’ll need to combine these three tensors into a single larger tensor that has three axes instead of two.

In [34]:
t = torch.stack((t1,t2,t3))
#we used the stack() method to concatenate our sequence of three tensors along a new axis.
t

tensor([[[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]],

        [[2, 2, 2, 2],
         [2, 2, 2, 2],
         [2, 2, 2, 2],
         [2, 2, 2, 2]],

        [[3, 3, 3, 3],
         [3, 3, 3, 3],
         [3, 3, 3, 3],
         [3, 3, 3, 3]]])
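`stack()` is easy to confuse with `cat()`. A minimal comparison, rebuilding `t1`, `t2` and `t3` with `torch.ones` (an assumption: the notebook builds them from nested lists, but the values and dtype are the same). `stack()` adds a new batch axis, while `cat()` joins along an existing axis and loses the per-image boundary:

```python
import torch

# Three 4 x 4 "images" filled with 1s, 2s and 3s (int64, like the tensors above)
t1 = torch.ones(4, 4, dtype=torch.int64)
t2 = torch.ones(4, 4, dtype=torch.int64) * 2
t3 = torch.ones(4, 4, dtype=torch.int64) * 3

stacked = torch.stack((t1, t2, t3))     # new axis 0: a batch of 3 images
concatenated = torch.cat((t1, t2, t3))  # existing axis 0: one tall 12 x 4 tensor

print(stacked.shape)       # torch.Size([3, 4, 4])
print(concatenated.shape)  # torch.Size([12, 4])
```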

In [35]:
t.shape
# Now it is a rank-3 tensor: the first axis is the batch size (3), and the last two axes
# are the height and width of each image. All that's missing is the color channel axis;
# let's assume these are grayscale images.


torch.Size([3, 4, 4])

In [37]:
# Grayscale images have a single color channel, so add a length-1 axis at position 1.
# reshape() spells out the full new shape; unsqueeze(dim=1) would work as well.
t = t.reshape(3,1,4,4)
t


tensor([[[[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]]],


        [[[2, 2, 2, 2],
          [2, 2, 2, 2],
          [2, 2, 2, 2],
          [2, 2, 2, 2]]],


        [[[3, 3, 3, 3],
          [3, 3, 3, 3],
          [3, 3, 3, 3],
          [3, 3, 3, 3]]]])

In [36]:
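# Equivalent alternative: unsqueeze() inserts a length-1 axis at dim=1 and returns a new
# tensor (t itself is unchanged). Applying it to the already-reshaped t would add a
# second length-1 axis, so start from the 3-axis stack to get the same [3, 1, 4, 4] result.
torch.stack((t1, t2, t3)).unsqueeze(dim=1)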


tensor([[[[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]]],


        [[[2, 2, 2, 2],
          [2, 2, 2, 2],
          [2, 2, 2, 2],
          [2, 2, 2, 2]]],


        [[[3, 3, 3, 3],
          [3, 3, 3, 3],
          [3, 3, 3, 3],
          [3, 3, 3, 3]]]])
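Both routes give the same tensor, but they differ in one practical way: `unsqueeze()` always adds an axis, so running it on a tensor that already has four axes gives five. A small check (the name `batch` is illustrative):

```python
import torch

# Rebuild the 3-axis batch [3, 4, 4] of 1s, 2s and 3s
batch = torch.stack([torch.ones(4, 4, dtype=torch.int64) * k for k in (1, 2, 3)])

a = batch.reshape(3, 1, 4, 4)                # spell out the new shape explicitly
b = batch.unsqueeze(dim=1)                   # insert a length-1 axis at position 1
c = batch.unsqueeze(dim=1).unsqueeze(dim=1)  # unsqueezing again adds yet another axis

print(a.shape, b.shape)   # torch.Size([3, 1, 4, 4]) torch.Size([3, 1, 4, 4])
print(torch.equal(a, b))  # True
print(c.shape)            # torch.Size([3, 1, 1, 4, 4])
```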

In [38]:
t.shape

torch.Size([3, 1, 4, 4])

In [39]:
len(t.shape)

4
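`len(t.shape)` is the tensor's rank. PyTorch also exposes it directly as `t.dim()` (or the `t.ndim` attribute), and `t.numel()` gives the total number of elements, which is exactly the length a full flatten will produce. A quick sketch on a tensor of the same shape (filled with zeros, since only the shape matters here):

```python
import torch

t = torch.zeros(3, 1, 4, 4, dtype=torch.int64)  # same shape as the batch above

print(len(t.shape), t.dim(), t.ndim)  # 4 4 4 -> rank
print(t.numel())                      # 48 = 3 * 1 * 4 * 4 elements
print(t.flatten().shape)              # torch.Size([48])
```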

### The first axis has 3 elements, and each element of the first axis represents an image. For each image, we have a single color channel on the channel axis. Each of these channels contains 4 arrays, and each array contains 4 numbers, or scalar components.

In [40]:
t[0]
#We have the first image.

tensor([[[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]])

In [41]:
t[0][0]
#We have the first color channel in the first image.

tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])

In [42]:
t[0][0][0]
#We have the first row of pixels in the first color channel of the first image.

tensor([1, 1, 1, 1])

In [43]:
t[0][0][0][0]
#We have the first pixel value in the first row of the first color channel of the first image.

tensor(1)
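Chained indexing like `t[0][0][0][0]` works, but PyTorch also accepts one index per axis inside a single pair of brackets, which is the more common style. `.item()` turns a one-element result such as `tensor(1)` into a plain Python number. A small sketch on an equivalent tensor:

```python
import torch

# Same [3, 1, 4, 4] batch: image k is filled with the value k
t = torch.stack([torch.ones(4, 4, dtype=torch.int64) * k for k in (1, 2, 3)]).unsqueeze(dim=1)

# Chained indexing and a single multi-axis index select the same data
print(torch.equal(t[0][0][0], t[0, 0, 0]))  # True: first row of the first channel of the first image
print(t[2, 0, 0, 0])                        # tensor(3): first pixel of the third image
print(t[2, 0, 0, 0].item())                 # 3: converted to a Python int
```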

## Flattening the tensor batch
#### The whole batch is a single tensor that will be passed to the CNN, so we don’t want to flatten the whole thing. We only want to flatten the image tensors within the batch tensor.

#### Let’s flatten the whole thing first, just to see what it looks like, using PyTorch’s built-in flatten() method.

In [44]:
t.flatten()

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

In [45]:
t.flatten().shape
# All 48 values from the three images now sit on one axis, so we can no longer tell where one image ends and the next begins.

torch.Size([48])

#### This flattened batch won’t work well inside our CNN because we need individual predictions for each image within our batch tensor, and now we have a flattened mess.

#### The solution here is to flatten each image while still maintaining the batch axis. This means we want to flatten only part of the tensor: the color channel axis together with the height and width axes. We skip over the batch axis, so to speak, leaving it intact.

In [46]:
t.flatten(start_dim = 1 )

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]])

In [47]:
t.flatten(start_dim = 1 ).shape
# Now we have 3 images, each flattened into 16 pixel values (1 channel * 4 * 4).

torch.Size([3, 16])
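`flatten(start_dim=1)` is not the only way to do this. A sketch comparing it with `reshape()` using `-1`, and with the `nn.Flatten` layer, which is how the flatten step is usually written inside a model definition (its default `start_dim` is 1, so it preserves the batch axis):

```python
import torch
import torch.nn as nn

# Same [3, 1, 4, 4] batch as above
t = torch.stack([torch.ones(4, 4, dtype=torch.int64) * k for k in (1, 2, 3)]).unsqueeze(dim=1)

by_flatten = t.flatten(start_dim=1)     # flatten axes 1..end, keep axis 0
by_reshape = t.reshape(t.shape[0], -1)  # keep the batch size, infer the rest (1 * 4 * 4)
by_module = nn.Flatten()(t)             # nn.Flatten defaults: start_dim=1, end_dim=-1

print(by_flatten.shape)  # torch.Size([3, 16])
print(torch.equal(by_flatten, by_reshape), torch.equal(by_flatten, by_module))  # True True
```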

### If we flatten an RGB image, what happens to the color?
#### Each color channel will be flattened first. Then, the flattened channels will be lined up side by side on a single axis of the tensor. Let's look at an example in code.

In [55]:
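# Image 1: red, green and blue channels, each 1 x 2 x 2, filled with 1, 2 and 3
r1 = torch.ones(1,2,2)
g1 = torch.ones(1,2,2) + 1
b1 = torch.ones(1,2,2) + 2
# Concatenate along the existing channel axis (dim=0) -> shape [3, 2, 2]
img1 = torch.cat(
    (r1,g1,b1)
    ,dim=0
)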

In [56]:
r2 = torch.ones(1,2,2)+3
g2 = torch.ones(1,2,2) + 4
b2 = torch.ones(1,2,2) + 5

img2 = torch.cat(
    (r2,g2,b2)
    ,dim=0
)

In [57]:
r3 = torch.ones(1,2,2)+6
g3 = torch.ones(1,2,2) + 7
b3 = torch.ones(1,2,2) + 8

img3 = torch.cat(
    (r3,g3,b3)
    ,dim=0
)

In [58]:
# Stack the three [3, 2, 2] images along a new batch axis -> [3, 3, 2, 2]
img = torch.stack((img1,img2,img3))

In [59]:
img

tensor([[[[1., 1.],
          [1., 1.]],

         [[2., 2.],
          [2., 2.]],

         [[3., 3.],
          [3., 3.]]],


        [[[4., 4.],
          [4., 4.]],

         [[5., 5.],
          [5., 5.]],

         [[6., 6.],
          [6., 6.]]],


        [[[7., 7.],
          [7., 7.]],

         [[8., 8.],
          [8., 8.]],

         [[9., 9.],
          [9., 9.]]]])

In [60]:
img.shape

torch.Size([3, 3, 2, 2])
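Building each channel by hand makes the layout explicit. As an alternative sketch (the name `img_alt` is hypothetical), the same [3, 3, 2, 2] batch can be generated from the values 1 to 9 with `torch.arange()` and broadcast over the pixels with `expand()`:

```python
import torch

# Value for (image i, channel c) is 3*i + c + 1, i.e. 1..9, matching img above
values = torch.arange(1, 10, dtype=torch.float32).reshape(3, 3, 1, 1)

# Broadcast each per-channel value over a 2 x 2 grid of pixels
img_alt = values.expand(3, 3, 2, 2)

print(img_alt.shape)           # torch.Size([3, 3, 2, 2])
print(img_alt[1, 2].tolist())  # [[6.0, 6.0], [6.0, 6.0]]: second image, blue channel
```

`expand()` returns a view that shares memory with `values` instead of copying it, so call `.clone()` on the result before modifying it in place.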

In [63]:
imgflat = img.flatten(start_dim =2 )
imgflat


tensor([[[1., 1., 1., 1.],
         [2., 2., 2., 2.],
         [3., 3., 3., 3.]],

        [[4., 4., 4., 4.],
         [5., 5., 5., 5.],
         [6., 6., 6., 6.]],

        [[7., 7., 7., 7.],
         [8., 8., 8., 8.],
         [9., 9., 9., 9.]]])

In [64]:
imgflat.shape
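# Now we have 3 images with 3 color channels each, and each channel is flattened into 4 pixels (the images are 2 x 2).
# The channels are still separate because start_dim=2 leaves the channel axis intact.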

torch.Size([3, 3, 4])
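`start_dim=2` keeps the three channels on separate rows. To see the behaviour described for RGB images, where each channel is flattened and the flattened channels are then lined up side by side on a single axis, flatten from axis 1 instead. A sketch that rebuilds `img` with the same `torch.cat` and `torch.stack` approach in a loop (the values match the cells above):

```python
import torch

# Same [3, 3, 2, 2] batch: image i, channel c is filled with 3*i + c + 1
img = torch.stack([
    torch.cat([torch.ones(1, 2, 2) + (3 * i + c) for c in range(3)], dim=0)
    for i in range(3)
])

flat = img.flatten(start_dim=1)  # channels, height and width merged into one axis

print(flat.shape)        # torch.Size([3, 12])
print(flat[0].tolist())  # [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]
```

Each image row now holds all red pixels, then all green pixels, then all blue pixels.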