# Tutorial: using reshape to find the new pruning index
When a feature map passes through nodes such as `ReshapeNode` or `PermuteNode`, the pruning index changes. We need to work out which dim the pruning index moves to and what its new values are.
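
For a permute, no marker tensor is needed. Input dim `prune_dim` ends up at position `perm.index(prune_dim)` of the output, and the index values stay the same. The minimal sketch below checks this on the shape used later in this tutorial. `perm` and the helper `permuted_prune_dim` are hypothetical names chosen for illustration, not part of any pruning library.

```python
import torch

# Hypothetical helper (illustration only): after x.permute(*perm), input dim
# `prune_dim` sits at position perm.index(prune_dim) of the output.
def permuted_prune_dim(perm, prune_dim):
    return list(perm).index(prune_dim)

x = torch.randn(2, 3, 4, 5)
perm = (0, 2, 3, 1)
prune_dim = 2
prune_idx = torch.tensor([1, 2])

y = x.permute(*perm)  # shape (2, 4, 5, 3)
new_dim = permuted_prune_dim(perm, prune_dim)
print(new_dim)  # 1

# Selecting the pruned slices before or after the permute yields the same data,
# so the index values [1, 2] carry over unchanged to new_dim.
before = x.index_select(prune_dim, prune_idx).permute(*perm)
after = y.index_select(new_dim, prune_idx)
print(torch.equal(before, after))  # True
```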


## A cheating solution:

### 1. Create a zero-filled tensor `input_index` with the same shape as `input`:

```python
import torch
# Initialize the input, prune_dim, and prune_idx
input = torch.randn(2, 3, 4, 5)
prune_dim = 2
prune_idx = torch.tensor([1, 2])

input_index = torch.zeros_like(input)
```

### 2. Assign `1` at the pruning indices along `prune_dim`:

```python
index_tuple = (  # index_tuple = (:, :, prune_idx, :)
    (slice(None),) * (prune_dim)
    + (prune_idx,)
    + (slice(None),) * (len(input.shape) - prune_dim - 1)
)
print(index_tuple)
input_index[index_tuple] = 1
```

```
(slice(None, None, None), slice(None, None, None), tensor([1, 2]), slice(None, None, None))
```
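
If you only need to mark whole slices along a single dim, `Tensor.index_fill_(dim, index, value)` may be a shorter way to do step 2. The sketch below assumes the same `input`, `prune_dim`, and `prune_idx` as above and checks that both approaches produce the same marker tensor.

```python
import torch

input = torch.randn(2, 3, 4, 5)
prune_dim = 2
prune_idx = torch.tensor([1, 2])

# Step 2 as written above: build (:, :, prune_idx, :) and assign 1.
index_tuple = (
    (slice(None),) * prune_dim
    + (prune_idx,)
    + (slice(None),) * (input.dim() - prune_dim - 1)
)
marker_a = torch.zeros_like(input)
marker_a[index_tuple] = 1

# Alternative: index_fill_ writes 1 into every slice listed in prune_idx along prune_dim.
marker_b = torch.zeros_like(input).index_fill_(prune_dim, prune_idx, 1)

print(torch.equal(marker_a, marker_b))  # True
```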


### 3. Apply the same transformation (here, `reshape`) to `input_index`:

```python
output_index = input_index.reshape(6, 20)
# Indices along output dim 0 (input dims 0 and 1 merged; prune_dim is not here)
print(set(torch.nonzero(output_index)[:, 0].tolist()))
# Indices along output dim 1 (input dims 2 and 3 merged; prune_dim ends up here)
print(set(torch.nonzero(output_index)[:, 1].tolist()))
```

```
{0, 1, 2, 3, 4, 5}
{5, 6, 7, 8, 9, 10, 11, 12, 13, 14}
```


We can now see that the pruning index has moved to output dim 1 and changed from `[1, 2]` to `[5, 6, 7, 8, 9, 10, 11, 12, 13, 14]`. Output dim 0 contains every index, so no rows along it are pruned.
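
The three steps can be wrapped into one function that takes any shape-changing callable. This is a minimal sketch. The name `track_prune_index` is hypothetical, and the closed-form check applies only to this particular `reshape(6, 20)`. That reshape merges input dims 2 and 3 (4 × 5 = 20) into output dim 1, so in row-major order an element at input index `(c, d)` in those dims lands at `c * 5 + d`.

```python
import torch

# Hypothetical helper wrapping steps 1-3 for an arbitrary shape-changing function `fn`.
def track_prune_index(shape, prune_dim, prune_idx, fn):
    """For each output dim, return the sorted indices that touch pruned elements."""
    marker = torch.zeros(shape)
    marker.index_fill_(prune_dim, prune_idx, 1)  # steps 1 and 2: mark pruned slices
    out = fn(marker)                             # step 3: apply the transformation
    nz = torch.nonzero(out)                      # shape (num_nonzero, out.dim())
    return [sorted(set(nz[:, d].tolist())) for d in range(out.dim())]

shape = (2, 3, 4, 5)
prune_dim = 2
prune_idx = torch.tensor([1, 2])

result = track_prune_index(shape, prune_dim, prune_idx, lambda t: t.reshape(6, 20))
print(result[1])  # [5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

# Closed-form check for this reshape: new index = c * 5 + d for c in prune_idx, d in range(5).
closed_form = sorted(c * 5 + d for c in prune_idx.tolist() for d in range(5))
print(closed_form == result[1])  # True

# The same helper also covers a permute: pruned indices stay [1, 2] but move to dim 1.
perm_result = track_prune_index(shape, prune_dim, prune_idx, lambda t: t.permute(0, 2, 3, 1))
print(perm_result[1])  # [1, 2]
```

An output dim that lists every index (output dim 0 in the reshape case) carries no pruning information. Only dims with a strict subset of indices need to be pruned.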