# **`torch.squeeze()` Summary**



## 1. **Definition**

```python
torch.squeeze(input, dim=None) → Tensor
```

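Removes dimensions of size **1** (singleton dimensions) from a tensor’s shape.
It undoes `unsqueeze()`, which adds a singleton dimension: `x.unsqueeze(d).squeeze(d)` has the same shape as `x`. Called without `dim`, however, `squeeze()` removes *every* singleton dimension, including ones that were already in `x`, so in that form it is not a strict inverse.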

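A quick check of that relationship, with arbitrary example shapes: `squeeze(d)` exactly undoes `unsqueeze(d)`, while a bare `squeeze()` also drops a singleton axis that was present before.

```python
import torch

x = torch.zeros(1, 3)          # already has a singleton axis at dim=0

u = x.unsqueeze(2)             # add a singleton axis at the end → (1, 3, 1)
print(u.shape)                 # torch.Size([1, 3, 1])

print(u.squeeze(2).shape)      # torch.Size([1, 3]): exactly undoes unsqueeze(2)
print(u.squeeze().shape)       # torch.Size([3]): also drops the original dim=0
```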

<br>


## 2. **Parameters**

| Parameter          | Description                                                                                                                                                                                              |
| ------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `input`            | The input tensor.                                                                                                                                                                                        |
| `dim` *(optional)* | If omitted (`None`), **all** singleton dimensions are removed. If given, only that dimension is removed, and only **if its size is 1**; otherwise the shape is left unchanged. Since PyTorch 2.0, `dim` may also be a tuple of dimensions. |

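For code that must also run on PyTorch versions before 2.0 (where `dim` takes a single `int`), one portable sketch is to chain single-dim squeezes from the highest axis down, so that removing one axis does not renumber the ones still to be processed. `squeeze_dims` is a hypothetical helper written for this note, not part of PyTorch.

```python
import torch

def squeeze_dims(t: torch.Tensor, dims) -> torch.Tensor:
    """Hypothetical helper: squeeze only the listed dims; non-size-1 dims are left alone."""
    ndim = t.dim()
    normalized = set()
    for d in dims:
        if not -ndim <= d < ndim:
            raise IndexError(f"dim {d} out of range for a {ndim}-D tensor")
        normalized.add(d % ndim)          # map negative indices to positive ones
    # Work from the last axis to the first so earlier indices stay valid.
    for d in sorted(normalized, reverse=True):
        t = t.squeeze(d)                  # no-op if that axis is not size 1
    return t

x = torch.zeros(1, 3, 1, 5, 1)
print(squeeze_dims(x, (0, 2)).shape)      # torch.Size([3, 5, 1]): trailing 1 kept
print(squeeze_dims(x, (0, -1)).shape)     # torch.Size([3, 1, 5])
```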
<br>

## 3. **Return Value**

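* A **view** of the original tensor: it shares the same underlying data (no copy is made).
* The shape loses one dimension per size-1 axis removed; if none are removed, the shape is unchanged.
* The element values and their order are unchanged.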

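One way to see that the result is a view: compare storage pointers with `data_ptr()` and write through the squeezed tensor. The shapes here are arbitrary.

```python
import torch

x = torch.zeros(1, 3, 1, 5)
y = x.squeeze()                       # shape (3, 5), no data copied

print(y is x)                         # False: a new tensor object...
print(y.data_ptr() == x.data_ptr())   # True: ...backed by the same memory

y[0, 0] = 7.0                         # in-place write through the view
print(x[0, 0, 0, 0].item())           # 7.0: the change shows up in x
```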
<br>

## 4. **Examples**

### Example 1: Remove all size-1 dims

```python
import torch

x = torch.zeros(1, 3, 1, 5)   # shape (1,3,1,5)
y = torch.squeeze(x)

print(y.shape)   # torch.Size([3, 5])
```

Explanation: Removed **dim=0** and **dim=2** (both had size 1).

<br>

### Example 2: Remove only a specific dim

```python
x = torch.zeros(1, 3, 1, 5)

y = torch.squeeze(x, dim=0)   # remove dim=0
print(y.shape)                # torch.Size([3, 1, 5])

z = torch.squeeze(x, dim=1)   # dim=1 has size 3 → unchanged
print(z.shape)                # torch.Size([1, 3, 1, 5])
```

<br>

### Example 3: No singleton dims → unchanged

```python
x = torch.tensor([[1, 2], [3, 4]])   # shape (2,2)
y = x.squeeze()

print(y.shape)  # torch.Size([2, 2]) → nothing removed
```

<br>

### Example 4: Negative dim

```python
x = torch.zeros(2, 1, 4)
y = torch.squeeze(x, dim=-2)  # -2 → second-to-last axis
print(y.shape)                # torch.Size([2, 4])
```
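A negative `dim` is counted from the end: for a tensor with `x.dim()` axes, it refers to axis `dim + x.dim()`. A short check with the tensor from Example 4:

```python
import torch

x = torch.zeros(2, 1, 4)      # x.dim() == 3
d = -2
print(d + x.dim())            # 1: the axis that -2 refers to

print(x.squeeze(d).shape)             # torch.Size([2, 4])
print(x.squeeze(d + x.dim()).shape)   # torch.Size([2, 4]): same axis, same result
```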

<br>

## 5. **Real-World Use Cases**

### a) Removing batch or channel dimensions

Model outputs often carry extra singleton dims, e.g. a batch of one or 1×1 spatial maps left after global pooling:

```python
output = torch.randn(1, 10, 1, 1)  # (batch=1, classes=10, 1, 1)
output = output.squeeze()          # (10,) → clean vector
```
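A caveat with the bare `squeeze()` above: it also removes the batch axis whenever the batch size happens to be 1, so the output shape depends on the batch size. A sketch of a safer pattern, assuming the same `(batch, classes, 1, 1)` layout, is to squeeze only the trailing spatial axes:

```python
import torch

shapes = {}
for batch in (1, 4):
    output = torch.randn(batch, 10, 1, 1)           # (batch, classes, 1, 1)
    everything = output.squeeze()                   # also drops batch when batch == 1
    spatial_only = output.squeeze(-1).squeeze(-1)   # drops only the two trailing 1s
    shapes[batch] = (tuple(everything.shape), tuple(spatial_only.shape))

print(shapes)   # {1: ((10,), (1, 10)), 4: ((4, 10), (4, 10))}
```

Here `spatial_only` keeps a consistent `(batch, 10)` shape for every batch size; `output.flatten(1)` is another common way to get the same result for this layout.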

<br>

### b) After slicing/indexing

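Integer indexing drops the indexed axis, but slicing with a range keeps it as a size-1 axis that may need squeezing:

```python
x = torch.rand(5, 1, 10)

print(x[:, 0, :].shape)     # torch.Size([5, 10]): integer index drops dim=1
print(x[:, 0:1, :].shape)   # torch.Size([5, 1, 10]): slice keeps dim=1

y = x.squeeze(1)            # explicitly remove dim=1
print(y.shape)              # torch.Size([5, 10])
```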

<br>

### c) Preparing data for loss functions

Some loss functions expect targets of shape `(N,)` rather than `(N, 1)`; for example, `nn.CrossEntropyLoss` with class-index targets:

```python
labels = torch.tensor([[0], [1], [2]])   # shape (3,1)
labels = labels.squeeze(1)               # shape (3,)
```
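A sketch of that step in context, assuming a 3-class problem with random logits standing in for real model output. `nn.CrossEntropyLoss` takes logits of shape `(N, C)` and integer class-index targets of shape `(N,)`.

```python
import torch
import torch.nn as nn

logits = torch.randn(3, 3)                  # (N=3, C=3) raw scores, placeholder values
labels = torch.tensor([[0], [1], [2]])      # (3, 1) int64 class indices

loss_fn = nn.CrossEntropyLoss()
# Class-index targets must be 1-D here, so squeeze the extra axis first.
loss = loss_fn(logits, labels.squeeze(1))   # targets now (3,)
print(loss.item())
```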

<br>

## 6. **Comparison with Related Functions**

| Function         | Purpose                | Example                            |
| ---------------- | ---------------------- | ---------------------------------- |
| `squeeze()`      | Removes singleton dims | `(1, 3, 1, 5)` → `(3, 5)`          |
| `unsqueeze(dim)` | Adds a singleton dim   | `(3, 5)` → `(1, 3, 5)` with `dim=0` |
| `reshape()`      | General reshaping; target shape given explicitly | `(1, 3, 1, 5)` → `(3, 5)` with `reshape(3, 5)` |

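The table's rows can be checked directly. For a `(1, 3, 1, 5)` tensor, `squeeze()` and `reshape(3, 5)` give the same result; the difference is that `reshape` needs the target shape spelled out, while `squeeze` infers it from which axes have size 1. The values come from `torch.arange` purely for illustration.

```python
import torch

x = torch.arange(15).reshape(1, 3, 1, 5)   # shape (1, 3, 1, 5)

a = x.squeeze()                             # (3, 5): singleton axes found automatically
b = x.reshape(3, 5)                         # (3, 5): target shape given explicitly
print(torch.equal(a, b))                    # True

c = a.unsqueeze(0)                          # (1, 3, 5): add a leading singleton axis
print(c.shape)                              # torch.Size([1, 3, 5])
```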
<br>

## 7. **Important Notes**

* ❗ If `dim` is specified but that axis does not have size 1, no error is raised; the result simply keeps the original shape.
* ❗ Since the result is a **view**, in-place changes to one tensor are visible in the other (call `.clone()` for an independent copy).
* ✅ Works with any dtype (`int`, `float`, `bool`) and on both CPU and GPU.

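To illustrate the view note: writes through a squeezed view reach the original tensor, while a `.clone()` owns separate memory. The example uses a `bool` tensor to show that the dtype does not matter.

```python
import torch

mask = torch.zeros(1, 4, dtype=torch.bool)   # any dtype works; bool here

view = mask.squeeze(0)           # shares memory with mask
copy = mask.squeeze(0).clone()   # independent copy

view[0] = True                   # visible in mask
copy[1] = True                   # not visible in mask

print(mask.tolist())             # [[True, False, False, False]]
print(copy.tolist())             # [False, True, False, False]
```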
<br>

## 8. **Quick Reference Table**

| Input Shape    | Code            | Output Shape | Explanation                |
| -------------- | --------------- | ------------ | -------------------------- |
| `(1, 3, 1, 5)` | `x.squeeze()`   | `(3, 5)`     | Remove all singleton dims  |
| `(1, 3, 1, 5)` | `x.squeeze(0)`  | `(3, 1, 5)`  | Remove only dim=0          |
| `(1, 3, 1, 5)` | `x.squeeze(2)`  | `(1, 3, 5)`  | Remove only dim=2          |
| `(2, 3, 5)`    | `x.squeeze()`   | `(2, 3, 5)`  | Nothing removed            |
| `(2, 1, 4)`    | `x.squeeze(-2)` | `(2, 4)`     | Remove second-to-last axis |

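The quick reference rows can be reproduced with a small loop; each case pairs an input shape and a `dim` argument (`None` meaning no argument) with the expected output shape from the table.

```python
import torch

# (input shape, dim or None, expected output shape) taken from the table above
cases = [
    ((1, 3, 1, 5), None, (3, 5)),
    ((1, 3, 1, 5), 0, (3, 1, 5)),
    ((1, 3, 1, 5), 2, (1, 3, 5)),
    ((2, 3, 5), None, (2, 3, 5)),
    ((2, 1, 4), -2, (2, 4)),
]

results = []
for shape, dim, expected in cases:
    x = torch.zeros(shape)
    y = x.squeeze() if dim is None else x.squeeze(dim)
    results.append(tuple(y.shape) == expected)
    print(shape, dim, tuple(y.shape))

print(all(results))   # True
```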
<br>

**In short:**

* `.squeeze()` → removes singleton axes (`size == 1`).
* Useful for **cleaning up shapes** in DL pipelines (batch/channel/loss prep).
* The counterpart of `.unsqueeze()`.



```python
import torch

torch.manual_seed(42)

# A 3-D tensor with a leading singleton dimension: shape (1, 3, 3)
matrix_2 = torch.randint(low=0, high=30, size=(1, 3, 3))
print(matrix_2)
```

```
tensor([[[12, 17, 16],
         [ 4,  6,  5],
         [10, 14, 10]]])
```

```python
squeeze_matrix = matrix_2.squeeze()   # dim=0 removed → shape (3, 3)
print(squeeze_matrix)
```

```
tensor([[12, 17, 16],
        [ 4,  6,  5],
        [10, 14, 10]])
```

<br><br><br>

In the next example the tensor has shape `(2, 3, 3)`, so no dimension has size 1 and `squeeze()` removes nothing: the result has the same shape and values.

```python
# No dimension has size 1: shape (2, 3, 3)
matrix_2 = torch.randint(low=0, high=30, size=(2, 3, 3))
print(matrix_2)
```

```
tensor([[[23, 28,  4],
         [20, 24, 21],
         [22,  5, 25]],

        [[27, 16, 19],
         [26, 13, 11],
         [29, 13, 11]]])
```

```python
squeeze_matrix = matrix_2.squeeze()   # nothing to remove → shape stays (2, 3, 3)
print(squeeze_matrix)
```

```
tensor([[[23, 28,  4],
         [20, 24, 21],
         [22,  5, 25]],

        [[27, 16, 19],
         [26, 13, 11],
         [29, 13, 11]]])
```