Submitted by: **Maxim Gorohovski**

# 📌Implementation of 'Broadcasting' Functionality in PyTorch

## 📖 **Background**

Broadcasting is a mechanism that allows element-wise operations between tensors of different shapes. It works by (virtually) expanding one or both tensors along their size-1 or missing dimensions until the shapes are compatible.
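Here is a quick illustration (our own example, not from the original notebook). Adding a column of shape (3, 1) to a row of shape (1, 4) broadcasts both operands to shape (3, 4):

```python
import torch

# A column vector of shape (3, 1) and a row vector of shape (1, 4)
col = torch.tensor([[0], [10], [20]])   # shape (3, 1)
row = torch.tensor([[1, 2, 3, 4]])      # shape (1, 4)

# Both operands are virtually stretched to (3, 4) before the element-wise add
result = col + row
print(result.shape)  # torch.Size([3, 4])
print(result)
```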

In pure mathematics, for example in linear algebra, operations like **addition** or **multiplication** require vectors and matrices to strictly follow dimensional rules. For instance, two matrices can only be added if they have exactly the same shape. In PyTorch, broadcasting relaxes these rules, and it is extremely useful for several reasons:

- 🚀 **No need for 'for' loops**: Without broadcasting, element-wise operations between tensors of different shapes would typically be written as explicit loops. Such code repeats many small additions and multiplications one element at a time, which makes it slow in Python. With broadcasting, the whole operation is a single vectorized call that PyTorch executes in optimized native code. No loops are needed, so the program is typically much faster and more efficient.

- 🚀 **Saves Memory**: Broadcasting avoids unnecessary memory duplication. Without it, we would have to explicitly expand the smaller vectors and matrices and store the copies before applying any operation. Broadcasting instead expands the smaller tensor *logically*. The expanded dimensions reuse the same underlying data, so no extra memory is allocated for an expanded copy. The result of the operation still needs its own output tensor.
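Both points can be checked with a small sketch (the tensor names are illustrative).

- **Loops:** The loop and the broadcast expression produce the same result, but the broadcast version is a single vectorized call. No timings are measured here.
- **Memory:** The second half uses Tensor.expand to show that an expanded tensor reuses the original storage, with a stride of 0 along the repeated dimension. Calling .contiguous() on it materializes a real copy.

```python
import torch

matrix = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # shape (3, 4)
bias = torch.tensor([1.0, 2.0, 3.0, 4.0])                      # shape (4,)

# 1) Explicit Python loops: one small tensor operation per element
looped = torch.empty_like(matrix)
for i in range(matrix.shape[0]):
    for j in range(matrix.shape[1]):
        looped[i, j] = matrix[i, j] + bias[j]

# Broadcasting: bias is treated as shape (1, 4) and stretched to (3, 4) in one call
broadcast = matrix + bias
print(torch.equal(looped, broadcast))  # True

# 2) Memory: expand() returns a view whose repeated dimension has stride 0
expanded_bias = bias.expand(3, 4)
print(expanded_bias.stride())                       # (0, 1)
print(expanded_bias.data_ptr() == bias.data_ptr())  # True: same storage, nothing copied

# Materializing the expanded tensor does allocate a new buffer
copied_bias = expanded_bias.contiguous()
print(copied_bias.stride())                         # (4, 1)
print(copied_bias.data_ptr() == bias.data_ptr())    # False
```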



## 🔨 Implementation

We will implement broadcasting using the following set of rules:

- Dimensions are compared from **right** to **left**. If one tensor has fewer dimensions, its missing leading dimensions are treated as **1**.
  - Each pair of compared dimensions must either:
    1. be **equal**, or
    2. have one of the dimensions equal to 1.

- Each resulting dimension after broadcasting is the **maximum** of the two compared dimensions.

- If these conditions fail for any pair of dimensions, broadcasting raises an error.
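The rules can be checked step by step against PyTorch's own torch.broadcast_shapes, which applies the same procedure to shapes without creating any tensors. The shapes below are illustrative:

```python
import torch

# Compatible: compare (3, 1, 5) with (4, 5) from right to left
#   5 vs 5 -> 5 (equal)
#   1 vs 4 -> 4 (one of them is 1)
#   3 vs - -> 3 (missing dimension treated as 1)
print(torch.broadcast_shapes((3, 1, 5), (4, 5)))  # torch.Size([3, 4, 5])

# Incompatible: the trailing dimensions 3 and 2 are neither equal nor 1
try:
    torch.broadcast_shapes((2, 3), (4, 2))
except RuntimeError as err:
    print("RuntimeError:", err)
```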


```python
import torch
```

1️⃣ **expand_as** function (a)

```python
def expand_as(tensorA, tensorB):
  # Creating a clone of tensorA so the original tensor is not modified
  tensorC = tensorA.clone()

  # expand_as can only add dimensions, never remove them
  if tensorC.dim() > tensorB.dim():
    raise ValueError(f"tensorA has more dimensions ({tensorC.dim()}) than tensorB ({tensorB.dim()})")

  # If A has fewer dimensions, prepend singleton dimensions on the left of tensorC
  while tensorC.dim() < tensorB.dim():
    tensorC = tensorC.unsqueeze(0)

  repeat_dims = [] # How many times each dimension needs to be repeated
  # Walk the dimensions from right to left, keeping the index i for error messages
  for i in reversed(range(tensorB.dim())):
    dimA, dimB = tensorC.shape[i], tensorB.shape[i]
    if dimA == dimB:
      repeat_dims.insert(0, 1)
    elif dimA == 1: # Expand this dimension of tensorC to tensorB's size
      repeat_dims.insert(0, dimB)
    else:
      raise ValueError(f"dimension mismatch - tensorC.shape[{i}] = {dimA} can't be expanded to tensorB.shape[{i}] = {dimB}")

  # For every dimension index and number of repeats, duplicate the tensor along the i-th dimension
  for i, repeats in enumerate(repeat_dims):
    if repeats != 1:
      tensorC = torch.cat([tensorC] * repeats, dim=i)

  return tensorC
```
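Note that this expand_as materializes the repeated data with torch.cat, so its result owns new memory. For contrast with the memory-saving behaviour described in the Background, here is a sketch of a zero-copy variant.

- **Name:** expand_as_view is a hypothetical name.
- **Approach:** It follows the same right-to-left rules but, instead of copying, builds a view with torch.as_strided in which each expanded dimension gets a stride of 0. This is the same idea PyTorch documents for Tensor.expand.
- **Caveat:** Writes through such a view would affect several logical elements at once, so it is best treated as read-only.

```python
import torch

def expand_as_view(tensor_a, tensor_b):
    """Hypothetical zero-copy variant: returns a view of tensor_a with tensor_b's shape."""
    target = list(tensor_b.shape)
    if tensor_a.dim() > len(target):
        raise ValueError("tensor_a has more dimensions than tensor_b")

    # Left-pad the shape with 1s and the strides with 0s for missing leading dimensions
    pad = len(target) - tensor_a.dim()
    shape = [1] * pad + list(tensor_a.shape)
    strides = [0] * pad + list(tensor_a.stride())

    new_strides = []
    for i, (dim_a, dim_b) in enumerate(zip(shape, target)):
        if dim_a == dim_b:
            new_strides.append(strides[i])  # keep the original step size
        elif dim_a == 1:
            new_strides.append(0)           # stride 0 repeats the single element
        else:
            raise ValueError(f"dimension {i}: {dim_a} can't be expanded to {dim_b}")

    # as_strided builds a view over the same storage; no element data is copied
    return torch.as_strided(tensor_a, target, new_strides, tensor_a.storage_offset())

a = torch.tensor([[1], [2], [3]])  # shape (3, 1)
b = torch.zeros(2, 3, 4)
v = expand_as_view(a, b)
print(torch.equal(v, a.expand_as(b)))  # True
print(v.data_ptr() == a.data_ptr())    # True
```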

2️⃣ **mutually_expandable** function (b)

```python
def mutually_expandable(tensorA, tensorB):
  # No need for clones, simply work with the shapes themselves
  shapeA = list(tensorA.shape)
  shapeB = list(tensorB.shape)
  broadcast_shape = [] # Here we keep the final shape

  # Making sure both shapes have the same number of dimensions
  while len(shapeA) < len(shapeB):
    shapeA.insert(0, 1) # Prepend singleton dimensions on the left of shapeA
  while len(shapeB) < len(shapeA):
    shapeB.insert(0, 1) # Prepend singleton dimensions on the left of shapeB

  for dimA, dimB in zip(reversed(shapeA), reversed(shapeB)):
    if dimA == dimB:
      broadcast_shape.append(dimA)

    elif dimA == 1:
      broadcast_shape.append(dimB) # tensorA will be expanded along this dimension

    elif dimB == 1:
      broadcast_shape.append(dimA) # tensorB will be expanded along this dimension

    else:
      return False, None

  broadcast_shape.reverse() # The loop went from right to left, but a shape is read from left to right
  return True, tuple(broadcast_shape)
```
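The left-padding with singleton dimensions can also be expressed with itertools.zip_longest and fillvalue=1. The helper below is a hypothetical, shape-only sketch of the same rules. It takes shapes rather than tensors, and it returns None instead of (False, None):

```python
import itertools
import torch

def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: returns the broadcast shape as a tuple, or None if incompatible."""
    result = []
    # zip_longest pads the shorter (reversed) shape with 1s, i.e. adds leading singleton dims
    for dim_a, dim_b in itertools.zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if dim_a == dim_b or dim_b == 1:
            result.append(dim_a)
        elif dim_a == 1:
            result.append(dim_b)
        else:
            return None
    return tuple(reversed(result))

print(broadcast_shape((2, 1, 4), (3, 1)))          # (2, 3, 4)
print(broadcast_shape((2, 3), (4, 2)))             # None
print(torch.broadcast_shapes((2, 1, 4), (3, 1)))   # torch.Size([2, 3, 4])
```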

3️⃣ **mutually_broadcast** function (c)

```python
def mutually_broadcast(tensorA, tensorB):
  is_expandable, broadcast_shape = mutually_expandable(tensorA, tensorB)
  if is_expandable:
    # Dummy target tensor for A and B; expand_as only reads its shape
    dummy_tensor = torch.empty(broadcast_shape)
    tensorC = expand_as(tensorA, dummy_tensor)
    tensorD = expand_as(tensorB, dummy_tensor)
    return tensorC, tensorD

  else:
    raise ValueError(f"Can't mutually broadcast because {tensorA.shape} and {tensorB.shape} are not mutually-expandable")
```
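PyTorch already provides torch.broadcast_shapes and Tensor.expand, so a view-based version of mutually_broadcast can be sketched in a few lines. mutually_broadcast_views is a hypothetical name. Unlike the implementation above, it copies no element data. It also signals incompatible shapes with PyTorch's RuntimeError rather than a ValueError:

```python
import torch

def mutually_broadcast_views(tensor_a, tensor_b):
    """Hypothetical view-based variant of mutually_broadcast."""
    # torch.broadcast_shapes raises RuntimeError if the shapes are incompatible
    shape = torch.broadcast_shapes(tensor_a.shape, tensor_b.shape)
    # Tensor.expand returns views, so no element data is copied
    return tensor_a.expand(shape), tensor_b.expand(shape)

a = torch.tensor([1, 2, 3])     # shape (3,)
b = torch.tensor([[10], [20]])  # shape (2, 1)
c, d = mutually_broadcast_views(a, b)

ref_c, ref_d = torch.broadcast_tensors(a, b)
print(c.shape, d.shape)                                 # torch.Size([2, 3]) torch.Size([2, 3])
print(torch.equal(c, ref_c) and torch.equal(d, ref_d))  # True
```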


## 🧪 Tester

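This tester was generated by AI. It covers a range of representative cases, including cases where broadcasting is expected to fail, and runs each of them through all of the functions we've implemented. The results are compared against the corresponding built-in PyTorch functions (torch.broadcast_shapes, Tensor.expand_as and torch.broadcast_tensors) to make sure they behave the same way.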
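Hand-picked cases can be complemented with randomly generated shape pairs. The sketch below is a hypothetical helper, fuzz_mutually_expandable. It assumes the function under test takes two tensors and returns (is_expandable, shape) in the same format as mutually_expandable. It uses torch.broadcast_shapes as the reference and a fixed seed, so any failure is reproducible:

```python
import random
import torch

def fuzz_mutually_expandable(fn, trials=500, seed=0):
    """Hypothetical harness: compare fn(A, B) -> (bool, shape or None) with
    torch.broadcast_shapes on random shape pairs. Returns the number of compatible pairs seen."""
    rng = random.Random(seed)  # fixed seed so failures are reproducible
    compatible = 0
    for _ in range(trials):
        # Start from a common random shape, then perturb copies of it
        base = [rng.randint(1, 4) for _ in range(rng.randint(0, 4))]
        shape_a = [1 if rng.random() < 0.3 else d for d in base]
        shape_b = [1 if rng.random() < 0.3 else d for d in base]
        shape_b = shape_b[rng.randint(0, len(shape_b)):]  # drop some leading dims
        if shape_a and rng.random() < 0.2:
            shape_a[-1] += 1  # occasionally provoke a (likely) mismatch

        A, B = torch.empty(tuple(shape_a)), torch.empty(tuple(shape_b))
        try:
            expected = torch.broadcast_shapes(A.shape, B.shape)
        except RuntimeError:
            expected = None

        ok, shape = fn(A, B)
        if expected is None:
            assert not ok and shape is None, f"{shape_a} vs {shape_b}: expected failure"
        else:
            compatible += 1
            assert ok and tuple(shape) == tuple(expected), \
                f"{shape_a} vs {shape_b}: got {shape}, expected {tuple(expected)}"
    return compatible
```

In the notebook it could be called as fuzz_mutually_expandable(mutually_expandable).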

```python
def broadcast_test():
    test_cases = [
        (torch.tensor(5), torch.tensor([[1, 2], [3, 4]])),
        (torch.tensor(1), torch.zeros(1)),
        (torch.tensor(42), torch.zeros(3, 1, 5)),
        (torch.tensor([1, 2, 3]), torch.zeros(3, 3)),
        (torch.tensor([1]), torch.zeros(2, 2)),
        (torch.ones(3, 4), torch.zeros(3, 4)),
        (torch.ones(1, 4), torch.zeros(3, 4)),
        (torch.ones(3, 1), torch.zeros(3, 5)),
        (torch.ones(1, 3, 1, 5), torch.zeros(2, 3, 4, 5)),
        (torch.ones(2, 1, 4), torch.zeros(2, 3, 4)),
        (torch.ones(1, 1, 4, 1), torch.zeros(2, 3, 4, 5)),
        # Broadcasting failure expected:
        (torch.zeros(2, 3), torch.zeros(4, 2)),
        (torch.zeros(2, 1, 3), torch.zeros(1, 3, 4)),
    ]

    print("Running full broadcast tests...\n")
    for idx, (A, B) in enumerate(test_cases):
        print(f"--- Test Case #{idx + 1} ---")
        print(f"A.shape: {A.shape}, B.shape: {B.shape}")

        # 1️⃣ Test mutually_expandable against torch.broadcast_shapes
        try:
            expected_shape = torch.broadcast_shapes(A.shape, B.shape)
            expandable, shape = mutually_expandable(A, B)
            assert expandable, "mutually_expandable returned False unexpectedly"
            assert shape == expected_shape, f"Shape mismatch: expected {expected_shape}, got {shape}"
            print("✅ mutually_expandable: PASSED")
        except RuntimeError:
            expandable, shape = mutually_expandable(A, B)
            assert not expandable and shape is None, "Expected failure, but got expandable"
            print("✅ mutually_expandable: Correctly failed")

        # 2️⃣ Test expand_as against Tensor.expand_as
        try:
            torch_result = A.expand_as(B)
        except RuntimeError:
            torch_result = None  # PyTorch can't expand A to B's shape

        if torch_result is not None:
            custom_result = expand_as(A, B)
            assert torch.equal(custom_result, torch_result), "expand_as result mismatch"
            print("✅ expand_as: PASSED")
        else:
            try:
                expand_as(A, B)
            except ValueError:
                print("✅ expand_as: Correctly failed")
            else:
                # The else clause is not covered by the except above, so this error can't be swallowed
                raise AssertionError("Custom expand_as should have failed but didn't")

        # 3️⃣ Test mutually_broadcast against torch.broadcast_tensors
        try:
            my_C, my_D = mutually_broadcast(A, B)
            torch_C, torch_D = torch.broadcast_tensors(A, B)
            assert torch.equal(my_C, torch_C), "broadcast_tensors result mismatch (C)"
            assert torch.equal(my_D, torch_D), "broadcast_tensors result mismatch (D)"
            print("✅ mutually_broadcast: PASSED")
        except ValueError:
            try:
                torch.broadcast_tensors(A, B)
                assert False, "Expected PyTorch broadcast_tensors to fail, but it didn't"
            except RuntimeError:
                print("✅ mutually_broadcast: Correctly failed")

        print()

    print("All tests passed successfully!")
```

```python
broadcast_test()
```

Running full broadcast tests...

--- Test Case #1 ---
A.shape: torch.Size([]), B.shape: torch.Size([2, 2])
✅ mutually_expandable: PASSED
✅ expand_as: PASSED
✅ mutually_broadcast: PASSED

--- Test Case #2 ---
A.shape: torch.Size([]), B.shape: torch.Size([1])
✅ mutually_expandable: PASSED
✅ expand_as: PASSED
✅ mutually_broadcast: PASSED

--- Test Case #3 ---
A.shape: torch.Size([]), B.shape: torch.Size([3, 1, 5])
✅ mutually_expandable: PASSED
✅ expand_as: PASSED
✅ mutually_broadcast: PASSED

--- Test Case #4 ---
A.shape: torch.Size([3]), B.shape: torch.Size([3, 3])
✅ mutually_expandable: PASSED
✅ expand_as: PASSED
✅ mutually_broadcast: PASSED

--- Test Case #5 ---
A.shape: torch.Size([1]), B.shape: torch.Size([2, 2])
✅ mutually_expandable: PASSED
✅ expand_as: PASSED
✅ mutually_broadcast: PASSED

--- Test Case #6 ---
A.shape: torch.Size([3, 4]), B.shape: torch.Size([3, 4])
✅ mutually_expandable: PASSED
✅ expand_as: PASSED
✅ mutually_broadcast: PASSED

--- Test Case #7 ---
A.shape: torch.S