
# PyTorch Broadcasting: A Practical Mini-Notebook

**Goal:** Build intuition for PyTorch broadcasting—what it is, when it works, when it fails—and practice with short exercises.

**You will learn to:**
- Read and reason about tensor shapes
- Apply broadcasting rules correctly
- Use `unsqueeze` and `None` indexing to insert size-1 axes
- Spot & fix common broadcasting bugs (wrong axes, shape mismatch)


## Setup

In [25]:

import torch
torch.__version__, torch.cuda.is_available()


('2.2.2+cu121', False)


## Broadcasting in a Nutshell

**Equal Number of Dimensions:** If the tensors have a different number of dimensions, PyTorch implicitly adds leading (left-side) dimensions of size 1 to the tensor with fewer dimensions until both tensors have the same number of dimensions.

**Compatibility rule (right-aligned):** Compare shapes from **right to left**. For each axis pair:
- If they are **equal**, they're compatible on that axis.
- If **one is 1**, it's broadcast (virtually repeated) to match the other.
- Otherwise, shapes are **incompatible** → runtime error.
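The two rules above are mechanical enough to write down directly. The sketch below is a plain-Python re-implementation for intuition only. `broadcast_shape` is a hypothetical helper, not a PyTorch function. It is cross-checked against the real `torch.broadcast_shapes`.

```python
import torch

def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: return the broadcast shape of two shapes or raise ValueError.

    Step 1: left-pad the shorter shape with 1s.
    Step 2: compare axis pairs (right-aligned after padding).
    """
    ndim = max(len(shape_a), len(shape_b))
    # Prepend size-1 dims so both shapes have the same length
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for da, db in zip(a, b):
        if da == db or db == 1:   # equal, or b is stretched to match a
            result.append(da)
        elif da == 1:             # a is stretched to match b
            result.append(db)
        else:                     # neither equal nor 1 -> incompatible
            raise ValueError(f"incompatible shapes {tuple(shape_a)} and {tuple(shape_b)}")
    return tuple(result)

print(broadcast_shape((1, 6), (5, 1)))          # (5, 6)
print(broadcast_shape((4, 3), (3,)))            # (4, 3)
print(torch.broadcast_shapes((1, 6), (5, 1)))   # torch.Size([5, 6])
```

Once both shapes are padded to the same length, the order in which axis pairs are checked no longer matters. "Right to left" is simply the natural way to do the alignment by hand.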

**Key tools & idioms**
- `x.unsqueeze(dim)` / `x[:, None]` to insert size-1 axes
- `keepdim=True` on reductions such as `mean` and `std`, which keeps the reduced axis as size 1

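**Important:** Broadcasting is **virtual**: the smaller input is not physically copied out to the full shape in memory. Only the result tensor is allocated at the broadcast shape.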
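One way to see what "virtual" means is `Tensor.expand`, which produces the same kind of stretched view explicitly. The expanded view shares storage with the original and uses a stride of 0 along the new axis, so every row reads the same memory. `Tensor.repeat`, by contrast, makes a real copy. This is a sketch for illustration. It does not claim that every arithmetic kernel calls `expand` internally.

```python
import torch

b = torch.tensor([10.0, 20.0, 30.0])    # shape [3]
b_big = b.expand(4, 3)                  # a [4, 3] view of b; no data is copied

print(b_big.shape)                      # torch.Size([4, 3])
print(b_big.stride())                   # (0, 1): moving along dim 0 stays in place
print(b_big.data_ptr() == b.data_ptr()) # True: same underlying storage

# repeat() materializes a genuine [4, 3] copy with ordinary strides
b_rep = b.repeat(4, 1)
print(b_rep.stride())                   # (3, 1)
```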


## Quick Demos

In [26]:
# squeeze and unsqueeze
x = torch.ones(2)
print(x.shape)  # [2]

x = torch.unsqueeze(x, 1)  # insert a size-1 axis at position 1
print(x.shape)  # [2, 1]

x = torch.squeeze(x, 1)  # removes axis 1 if its size is 1; no-op otherwise
# x = x.squeeze(1)       # equivalent method form

print(x.shape)  # [2]


torch.Size([2])
torch.Size([2, 1])
torch.Size([2])
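The same axis manipulations can be written with `None` indexing and negative dims. One trap is calling `squeeze()` with no argument, which removes *every* size-1 axis. This can silently drop a batch axis of size 1. A short sketch:

```python
import torch

x = torch.ones(2)                 # [2]
print(x[:, None].shape)           # torch.Size([2, 1]), same as x.unsqueeze(1)
print(x[None, :].shape)           # torch.Size([1, 2]), same as x.unsqueeze(0)
print(x.unsqueeze(-1).shape)      # torch.Size([2, 1]): negative dims count from the right

# squeeze() with no argument removes *every* size-1 axis
y = torch.ones(1, 3, 1)
print(y.squeeze().shape)          # torch.Size([3])
print(y.squeeze(0).shape)         # torch.Size([3, 1]): only axis 0 removed
print(y.squeeze(1).shape)         # torch.Size([1, 3, 1]): axis 1 has size 3, so no-op
```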


In [28]:

# Demo 1: Add [1, 6] and [5, 1] -> [5, 6]
a = torch.randn(1, 6)
print("a:", a)

c = torch.randn(5, 1)
print("c", c)

out = a + c  # -> [5, 6]
print("out:", out)
print("out.shape:", out.shape)


a: tensor([[-0.7185,  0.5186, -1.3125,  0.1920,  0.5428, -2.2188]])
c tensor([[ 0.2590],
        [-1.0297],
        [-0.5008],
        [ 0.2734],
        [-0.9181]])
out: tensor([[-0.4595,  0.7776, -1.0535,  0.4510,  0.8018, -1.9598],
        [-1.7482, -0.5111, -2.3422, -0.8377, -0.4869, -3.2485],
        [-1.2192,  0.0179, -1.8133, -0.3088,  0.0420, -2.7195],
        [-0.4451,  0.7920, -1.0392,  0.4654,  0.8161, -1.9454],
        [-1.6366, -0.3995, -2.2306, -0.7261, -0.3753, -3.1369]])
out.shape: torch.Size([5, 6])
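Elementwise, Demo 1 computes `out[i, j] = a[0, j] + c[i, 0]`. The check below spells that out with explicit loops and with explicit `expand` calls. The values are random, so only the equalities matter, not the printed numbers.

```python
import torch

a = torch.randn(1, 6)
c = torch.randn(5, 1)
out = a + c                      # [5, 6]

# Broadcasting means out[i, j] = a[0, j] + c[i, 0]; rebuild it with loops
manual = torch.empty(5, 6)
for i in range(5):
    for j in range(6):
        manual[i, j] = a[0, j] + c[i, 0]
print(torch.allclose(out, manual))   # True

# Equivalent: expand both operands to [5, 6] first, then add
print(torch.equal(out, a.expand(5, 6) + c.expand(5, 6)))  # True
```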


In [29]:

# Demo 2: Add [B, D] and [D] -> [B, D]
B, D = 4, 3
x = torch.arange(B*D, dtype=torch.float32).reshape(B, D)
print("x.shape:", x.shape)
print(x)
b = torch.tensor([10.0, 20.0, 30.0])   # shape [D]
print("b.shape:", b.shape)
print(b)
y = x + b
print("x.shape:", x.shape, "| b.shape:", b.shape, "| y.shape:", y.shape)
print(y)


x.shape: torch.Size([4, 3])
tensor([[ 0.,  1.,  2.],
        [ 3.,  4.,  5.],
        [ 6.,  7.,  8.],
        [ 9., 10., 11.]])
b.shape: torch.Size([3])
tensor([10., 20., 30.])
x.shape: torch.Size([4, 3]) | b.shape: torch.Size([3]) | y.shape: torch.Size([4, 3])
tensor([[10., 21., 32.],
        [13., 24., 35.],
        [16., 27., 38.],
        [19., 30., 41.]])
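In Demo 2, the `[D]` bias is left-padded to `[1, D]` and then stretched along the batch axis. Writing that alignment out explicitly gives the same result, as this short check suggests:

```python
import torch

B, D = 4, 3
x = torch.arange(B * D, dtype=torch.float32).reshape(B, D)
b = torch.tensor([10.0, 20.0, 30.0])          # [D]

# Implicit: [D] is left-padded to [1, D], then the 1 is stretched to B
y_implicit = x + b
y_explicit = x + b.reshape(1, D)              # the same alignment, written out
print(torch.equal(y_implicit, y_explicit))    # True

# Row 2 of the result is row 2 of x plus b
print(y_implicit[2])                          # tensor([16., 27., 38.])
```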


In [31]:

# Demo 3: Add [B, D] and [B, 1] -> [B, D]
x = torch.arange(12.0).reshape(4, 3)
print(x)
row_bias = torch.tensor([[100.0],[200.0],[300.0],[400.0]])
y = x + row_bias
print("x.shape:", x.shape, "| row_bias.shape:", row_bias.shape, "| y.shape:", y.shape)
print(y)


tensor([[ 0.,  1.,  2.],
        [ 3.,  4.,  5.],
        [ 6.,  7.,  8.],
        [ 9., 10., 11.]])
x.shape: torch.Size([4, 3]) | row_bias.shape: torch.Size([4, 1]) | y.shape: torch.Size([4, 3])
tensor([[100., 101., 102.],
        [203., 204., 205.],
        [306., 307., 308.],
        [409., 410., 411.]])
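A `[B, 1]` operand like `row_bias` often comes from a per-row reduction. Passing `keepdim=True` keeps the reduced axis as size 1, so the result broadcasts across columns as in Demo 3. The sketch below divides each row by its own sum. Without `keepdim`, the sum has shape `[4]` and would not line up with the rows.

```python
import torch

x = torch.arange(12.0).reshape(4, 3)         # [4, 3]

row_sum = x.sum(dim=1, keepdim=True)         # [4, 1]: one value per row, axis kept
share = x / row_sum                          # [4, 3] / [4, 1] -> [4, 3]
print(row_sum.shape, share.shape)            # torch.Size([4, 1]) torch.Size([4, 3])
print(share.sum(dim=1))                      # each row sums to 1 (up to rounding)

# Without keepdim the sum has shape [4], which right-aligns with the
# size-3 column axis and would raise an error if used in x / ...
print(x.sum(dim=1).shape)                    # torch.Size([4])
```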


In [None]:

# Demo 4: Using unsqueeze to align axes
x = torch.arange(12.0).reshape(4, 3)         # [4, 3]
print(x)
col = torch.tensor([1.0, 2.0, 3.0, 4.0])     # [4]
try:
    y_bad = x + col          # incompatible: [4, 3] + [4] compares 3 vs 4 on the last axis
except RuntimeError as e:
    print("As expected, this fails:", e)

# Fix via unsqueeze at axis 1 (-> [4,1])
y_ok = x + col.unsqueeze(1)  #->[4,1]
print("Fixed:", y_ok.shape)
print(y_ok)


tensor([[ 0.,  1.,  2.],
        [ 3.,  4.,  5.],
        [ 6.,  7.,  8.],
        [ 9., 10., 11.]])
As expected, this fails: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1
Fixed: torch.Size([4, 3])
tensor([[ 1.,  2.,  3.],
        [ 5.,  6.,  7.],
        [ 9., 10., 11.],
        [13., 14., 15.]])
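The Demo 4 fix can be written several equivalent ways. The placement of the new axis matters: `unsqueeze(0)` produces `[1, 4]`, which still clashes with the last axis of size 3. A quick check:

```python
import torch

x = torch.arange(12.0).reshape(4, 3)          # [4, 3]
col = torch.tensor([1.0, 2.0, 3.0, 4.0])      # [4]

# Three equivalent ways to turn [4] into a [4, 1] column
fixes = [col.unsqueeze(1), col[:, None], col.reshape(-1, 1)]
print([f.shape for f in fixes])               # three times torch.Size([4, 1])
print(all(torch.equal(x + f, x + fixes[0]) for f in fixes))  # True

# Inserting the axis on the wrong side gives [1, 4]: still incompatible with [4, 3]
try:
    x + col.unsqueeze(0)
except RuntimeError as e:
    print("Still fails:", e)
```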



### Example: NCHW Images (Batch, Channels, Height, Width)

Common pattern: apply per-channel bias/scale to an image batch `x` with shape `[N, C, H, W]`.


In [33]:

N, C, H, W = 2, 3, 4, 5
x = torch.randn(N, C, H, W)
scale_c = torch.tensor([0.5, 2.0, 1.5])     # [C], one scale per channel
bias_c  = torch.tensor([0.1, -0.2, 0.0])    # [C], one bias per channel
print("x:", x.shape, "scale_c:", scale_c.shape, "bias_c:", bias_c.shape)

# Broadcast to [N, C, H, W] by inserting singleton axes: [1, C, 1, 1]
scale_c_s = scale_c[None, :, None, None]
bias_c_s  = bias_c[None, :, None, None]
print("scale_c_s:", scale_c_s.shape, "bias_c_s:", bias_c_s.shape)
print(scale_c_s)

y = x * scale_c_s + bias_c_s
print("x:", x.shape, "y:", y.shape)


x: torch.Size([2, 3, 4, 5]) scale_c: torch.Size([3]) bias_c: torch.Size([3])
scale_c_s: torch.Size([1, 3, 1, 1]) bias_c_s: torch.Size([1, 3, 1, 1])
tensor([[[[0.5000]],

         [[2.0000]],

         [[1.5000]]]])
x: torch.Size([2, 3, 4, 5]) y: torch.Size([2, 3, 4, 5])
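Because missing leading dimensions are added automatically, the leading `None` in `[None, :, None, None]` is optional. A `[C, 1, 1]` tensor is padded to `[1, C, 1, 1]` on its own. The sketch compares three spellings. `view` works here because `scale_c` is contiguous.

```python
import torch

N, C, H, W = 2, 3, 4, 5
x = torch.randn(N, C, H, W)
scale_c = torch.tensor([0.5, 2.0, 1.5])      # [C]

# The leading 1 is optional: [C, 1, 1] is left-padded to [1, C, 1, 1]
y_full  = x * scale_c[None, :, None, None]   # [1, C, 1, 1]
y_short = x * scale_c[:, None, None]         # [C, 1, 1]
y_view  = x * scale_c.view(1, C, 1, 1)       # explicit reshape to [1, C, 1, 1]

print(torch.equal(y_full, y_short), torch.equal(y_full, y_view))  # True True
```

The trailing `None`s are *not* optional. Without them, `[C]` would be aligned with the width axis `W`.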



## Common Pitfalls & Fixes

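1) **Wrong axis alignment:** A size-1 axis inserted in the wrong position makes the operation fail. Worse, when two sizes happen to coincide, it silently broadcasts along the wrong axis.
- **Fix:** `unsqueeze` (or index with `None`) at the correct axis, and check the result's shape.

2) **Incompatible shapes:** Compare shapes from right to left, and insert size-1 axes where needed.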
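The silent variant of pitfall 1 is the dangerous one. With a square input, a length-`B` vector also fits the column axis. Forgetting the `[:, None]` then raises no error but offsets columns instead of rows. A minimal reproduction:

```python
import torch

# A square input: B == D, so a length-B vector also "fits" the column axis
B = D = 3
x = torch.zeros(B, D)
per_row = torch.tensor([1.0, 2.0, 3.0])      # intended: one offset per ROW

wrong = x + per_row              # [3] aligns with the LAST axis -> offsets columns
right = x + per_row[:, None]     # [3, 1] -> offsets rows, as intended

print(wrong)   # every row is [1., 2., 3.]
print(right)   # row i is filled with per_row[i]
print(torch.equal(wrong, right))   # False, and no error was raised
```

Testing with non-square shapes (for example `B != D`) during development is a cheap way to turn this silent bug into a loud error.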


## Exercise


### Row/Column Normalization (5–7 min)

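Given `X` with shape `[B, D]`:
1. Compute the **column-wise** mean and std → shapes `[D]`.
2. Normalize `X` to `Z = (X - mean) / std` using **broadcasting** (no loops).
3. Repeat the normalization **row-wise** (per sample), using `keepdim=True` so the statistics have shape `[B, 1]`.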

**Starter code (fill the TODOs):**


In [None]:

torch.manual_seed(0)
B, D = 5, 4
X = torch.randn(B, D)
# 1) column-wise mean/std
column_mean = X.mean(dim=0)             # [D]: reduce over dim 0 (the B rows), one value per column
column_std  = X.std(dim=0)              # [D]
print(f'column_mean={column_mean}, column_mean.shape={column_mean.shape}')
print(f'Manual check of 0th column mean X[:,0].mean()={X[:,0].mean()}')
print(f'column_std={column_std}, column_std.shape={column_std.shape}\n')

# 2) Column-wise normalization: Z_feat = (X - column_mean) / column_std
# TODO: broadcast subtraction/division correctly
# Z_feat = ...

# 3) Row-wise (per sample) normalization: Z_row
row_mean = X.mean(dim=1, keepdim=True) # [B, 1]
row_std  = X.std(dim=1, keepdim=True)  # [B, 1]
# TODO: compute Z_row
# Z_row = ...

# Quick checks (should be ~0 and ~1; small numerical deviations are OK)
# print("Column-wise mean ~0:", Z_feat.mean(dim=0))
# print("Column-wise std  ~1:", Z_feat.std(dim=0))
# print("Row-wise mean    ~0:", Z_row.mean(dim=1))
# print("Row-wise std     ~1:", Z_row.std(dim=1))


---
## Solution (reveal after attempting)

### Solution:

In [None]:

torch.manual_seed(0)
B, D = 5, 4
X = torch.randn(B, D)
# 1) column-wise mean/std
column_mean = X.mean(dim=0)              # [D]
column_std  = X.std(dim=0)              # [D]
print(f'column_mean={column_mean}, column_mean.shape={column_mean.shape}')
print(f'Manual check of 0th column mean X[:,0].mean()={X[:,0].mean()}')
print(f'column_std={column_std}, column_std.shape={column_std.shape}\n')


# 2) Column-wise normalization: Z_feat = (X - column_mean) / column_std
# Broadcast [B, D] with [1, D]; the explicit [None, :] is optional because
# a [D] tensor is left-padded to [1, D] automatically
Z_feat = (X - column_mean[None, :]) / column_std[None, :]
# or
# Z_feat = (X - column_mean.unsqueeze(0)) / column_std.unsqueeze(0)

# 3) Row-wise (per sample) normalization: Z_row
row_mean = X.mean(dim=1, keepdim=True)  # [B, 1]
row_std  = X.std(dim=1, keepdim=True)   # [B, 1]
# [B, D] with [B, 1]: each row is normalized with its own statistics
Z_row = (X - row_mean) / row_std

# Quick checks (should be ~0 and ~1; small numerical deviations are OK)
print("Column-wise mean ~0:", Z_feat.mean(dim=0))
print("Column-wise std  ~1:", Z_feat.std(dim=0))
print("Row-wise mean    ~0:", Z_row.mean(dim=1))
print("Row-wise std     ~1:", Z_row.std(dim=1))
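For the column statistics, `keepdim=True` works just as it does for the rows. It yields `[1, D]` directly, and the result matches the plain `[D]` version because `[D]` is padded to `[1, D]` anyway. A short comparison using the same seed and shapes as the solution:

```python
import torch

torch.manual_seed(0)
B, D = 5, 4
X = torch.randn(B, D)

# keepdim=True keeps the reduced axis as size 1: [1, D] instead of [D]
mean_k = X.mean(dim=0, keepdim=True)          # [1, D]
std_k  = X.std(dim=0, keepdim=True)           # [1, D]
Z_keep  = (X - mean_k) / std_k
Z_plain = (X - X.mean(dim=0)) / X.std(dim=0)  # [D] is left-padded to [1, D]

print(mean_k.shape)                           # torch.Size([1, 4])
print(torch.allclose(Z_keep, Z_plain))        # True
```

Using `keepdim=True` for both row and column statistics keeps the code symmetric and makes the intended alignment visible in the shapes.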
