
In [2]:
## My attempt to address my biggest NumPy fear: I assume a function transforms an array in a certain way, but what if it doesn't? Checking here, so that it does not come crashing down later.

import numpy as np

In [15]:
## What does reshape REALLY do
x = np.arange(24)
print(x)
print(x.reshape((6, 4)))
print(x.reshape((6,4), order='F'))
print(x.reshape((2,3,4), order='C')) ## last index changing fastest (default)
print(x.reshape((2,3,4), order='F')) ## first index changing fastest
print(x.reshape((2,3,4)).reshape((2,3,2,2)))
print(x.reshape((2,3,4)).reshape((2,3,2,2), order='F'))
## From the NumPy docs:
## The order keyword gives the index ordering both for fetching the values from a, and then placing the values into the output array.
## You can think of reshaping as first raveling the array (using the given index order), then inserting the elements from the raveled array into the new array using the same kind of index ordering as was used for the raveling.
## For the target shape (2, 3, 4):
## Last index changing fastest (order='C') => 0 -> (0, 0, 0); 1 -> (0, 0, 1); 2 -> (0, 0, 2) ...
## First index changing fastest (order='F') => 0 -> (0, 0, 0); 1 -> (1, 0, 0); 2 -> (0, 1, 0); 3 -> (1, 1, 0); 4 -> (0, 2, 0) ...
## With the first index changing fastest, the last index changes every 6th element: (0, 0, 0) -> 0; (0, 0, 1) -> 6; ...
## The middle index changes every 2nd element.
## The first index changes every element.

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]
 [16 17 18 19]
 [20 21 22 23]]
[[ 0  6 12 18]
 [ 1  7 13 19]
 [ 2  8 14 20]
 [ 3  9 15 21]
 [ 4 10 16 22]
 [ 5 11 17 23]]
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
[[[ 0  6 12 18]
  [ 2  8 14 20]
  [ 4 10 16 22]]

 [[ 1  7 13 19]
  [ 3  9 15 21]
  [ 5 11 17 23]]]
[[[[ 0  1]
   [ 2  3]]

  [[ 4  5]
   [ 6  7]]

  [[ 8  9]
   [10 11]]]


 [[[12 13]
   [14 15]]

  [[16 17]
   [18 19]]

  [[20 21]
   [22 23]]]]
[[[[ 0  2]
   [ 1  3]]

  [[ 4  6]
   [ 5  7]]

  [[ 8 10]
   [ 9 11]]]


 [[[12 14]
   [13 15]]

  [[16 18]
   [17 19]]

  [[20 22]
   [21 23]]]]
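The "ravel, then refill" description quoted from the docs can be checked directly. The sketch below is only an illustration: `reshape_by_hand` is a hypothetical helper written for this notebook, not a NumPy function, and it is far slower than the real `reshape`. It uses `np.unravel_index` to turn each flat position into a multi-index in the chosen order.

```python
import numpy as np

def reshape_by_hand(a, shape, order="C"):
    """Hypothetical helper: mimic a.reshape(shape, order=order) step by step."""
    # Step 1: ravel the source using the given index order.
    flat = a.ravel(order=order)
    # Step 2: place the values into the output, visiting its indices
    # in the same order that was used for the raveling.
    out = np.empty(shape, dtype=a.dtype)
    for k, value in enumerate(flat):
        out[np.unravel_index(k, shape, order=order)] = value
    return out

x = np.arange(24)
y = x.reshape((6, 4))  # a 2-D source, so the ravel order matters too

for order in ("C", "F"):
    assert np.array_equal(reshape_by_hand(x, (2, 3, 4), order),
                          x.reshape((2, 3, 4), order=order))
    assert np.array_equal(reshape_by_hand(y, (2, 3, 4), order),
                          y.reshape((2, 3, 4), order=order))

# Where flat positions 3 and 6 land when the first index changes fastest.
pos3 = tuple(int(i) for i in np.unravel_index(3, (2, 3, 4), order="F"))
pos6 = tuple(int(i) for i in np.unravel_index(6, (2, 3, 4), order="F"))
print(pos3, pos6)
```

If the assertions pass, the mental model holds for both index orders, including the case where the source array is already multidimensional.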


In [19]:
## What does a multidimensional transpose really do
x = np.arange(24)
y = x.reshape((2,3,4)) ## Batch size x sequence length x embedding dim
print(y)
z = y.reshape((2,3,2,2)) ## Each embedding is split into 2 (n_heads) embeddings of size 2 (embedding dim//n_heads)
print(z)
## Swapping axes 1 and 2 gathers each head's pieces into n_heads (2) arrays of shape sequence length x (embedding dim//n_heads)
print(z.transpose((0, 2, 1, 3))) ## Batch size x n_heads x sequence_length x (embedding dim//n_heads)
z[0,:,:,:].reshape((-1)) == y[0,:,:].reshape((-1)) ## The reshape leaves the batch axis untouched, so each batch still holds exactly the same elements in the same order


[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
[[[[ 0  1]
   [ 2  3]]

  [[ 4  5]
   [ 6  7]]

  [[ 8  9]
   [10 11]]]


 [[[12 13]
   [14 15]]

  [[16 17]
   [18 19]]

  [[20 21]
   [22 23]]]]
[[[[ 0  1]
   [ 4  5]
   [ 8  9]]

  [[ 2  3]
   [ 6  7]
   [10 11]]]


 [[[12 13]
   [16 17]
   [20 21]]

  [[14 15]
   [18 19]
   [22 23]]]]


array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True])
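To make the axis bookkeeping explicit, here is a small sketch of the same split-heads pattern with an index-by-index check. The variable names (`batch`, `seq_len`, `n_heads`, `head_dim`, `heads`, `merged`) are assumptions chosen for this illustration. It checks three things: `transpose` only relabels axes, it returns a view rather than a copy, and the split can be undone.

```python
import numpy as np

batch, seq_len, n_heads, head_dim = 2, 3, 2, 2
y = np.arange(24).reshape((batch, seq_len, n_heads * head_dim))

# Split heads: (batch, seq, emb) -> (batch, seq, heads, head_dim)
#                                -> (batch, heads, seq, head_dim)
z = y.reshape((batch, seq_len, n_heads, head_dim))
heads = z.transpose((0, 2, 1, 3))

# transpose only permutes axes: element [b, h, s, d] of the result is
# element [b, s, h, d] of z, which is y[b, s, h * head_dim + d].
for b, h, s, d in np.ndindex(heads.shape):
    assert heads[b, h, s, d] == z[b, s, h, d] == y[b, s, h * head_dim + d]

# No data moves: the result is a view on the same buffer with permuted strides.
assert np.shares_memory(heads, y)
assert heads.strides == tuple(z.strides[i] for i in (0, 2, 1, 3))

# Merge heads: apply the same swap again, then reshape back.
merged = heads.transpose((0, 2, 1, 3)).reshape((batch, seq_len, n_heads * head_dim))
assert np.array_equal(merged, y)
```

Swapping axes 1 and 2 is its own inverse, which is why the same permutation `(0, 2, 1, 3)` both splits and merges the heads.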

In [4]:
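## Multidimensional matrix multiplication: @ multiplies the last two axes and treats the leading axis as a batch, so (2,3,4) @ (2,4,3) -> (2,3,3)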
x = np.arange(24).reshape((2,3,4))
y = np.arange(101, 125).reshape((2,4,3))
print(x)
print(y)
print(x @ y)

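## Hadamard (elementwise) product with broadcasting: x of shape (2,3) is reused for each leading index of y, which has shape (2,2,3)
print('Hadamard')
x = np.arange(6).reshape((2,3))
y = np.arange(12).reshape((2,2,3))
print(x)
print(y)
print(x * y)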

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
[[[101 102 103]
  [104 105 106]
  [107 108 109]
  [110 111 112]]

 [[113 114 115]
  [116 117 118]
  [119 120 121]
  [122 123 124]]]
[[[  648   654   660]
  [ 2336  2358  2380]
  [ 4024  4062  4100]]

 [[ 6360  6414  6468]
  [ 8240  8310  8380]
  [10120 10206 10292]]]
Hadamard
[[0 1 2]
 [3 4 5]]
[[[ 0  1  2]
  [ 3  4  5]]

 [[ 6  7  8]
  [ 9 10 11]]]
[[[ 0  1  4]
  [ 9 16 25]]

 [[ 0  7 16]
  [27 40 55]]]
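Both results above can be cross-checked against a more explicit formulation. The sketch below assumes nothing beyond standard NumPy calls (`np.stack`, `np.einsum`, `np.broadcast_to`); the names `per_batch`, `a`, and `b` are made up for this illustration.

```python
import numpy as np

x = np.arange(24).reshape((2, 3, 4))
y = np.arange(101, 125).reshape((2, 4, 3))

# @ on 3-D inputs: one ordinary 2-D matrix product per leading (batch) index.
per_batch = np.stack([x[i] @ y[i] for i in range(x.shape[0])])
assert np.array_equal(x @ y, per_batch)

# The same contraction written with explicit index labels.
assert np.array_equal(x @ y, np.einsum("bij,bjk->bik", x, y))

# * with shapes (2, 3) and (2, 2, 3): the 2-D array is treated as
# shape (1, 2, 3) and reused for every index along the leading axis.
a = np.arange(6).reshape((2, 3))
b = np.arange(12).reshape((2, 2, 3))
assert np.array_equal(a * b, np.stack([a * b[i] for i in range(b.shape[0])]))
assert np.array_equal(a * b, np.broadcast_to(a, b.shape) * b)
```

The matmul case pairs up batches one to one, while the elementwise case repeats the smaller array; the two operators handle the leading axis differently here only because the input shapes differ.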
