In this notebook we will implement two broadcasting functions from the torch library:

* [torch.Tensor.expand_as](https://pytorch.org/docs/stable/generated/torch.Tensor.expand_as.html)

* [torch.broadcast_tensors](https://pytorch.org/docs/stable/generated/torch.broadcast_tensors.html)



In [None]:
import torch
import numpy as np

## Expand_as

First, given two tensors A and B, we must check whether A can be expanded to match the shape of B.

A can be expanded to B if the following conditions are met:
1. Both A and B have at least one dimension.
2. B has at least as many dimensions as A.
3. Iterating over the dimensions of A from last to first, each dimension of A is either 1 or equal to the corresponding dimension of B (counting from the end).
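
The three conditions can also be tried out on plain shape tuples, without building any tensors. The sketch below is only an illustration: `can_expand_shape` is a hypothetical helper, not part of torch or of this notebook's implementation.

```python
def can_expand_shape(shape_a, shape_b):
    """Hypothetical helper: apply the three conditions to plain shape tuples."""
    # Condition 1: both shapes need at least one dimension
    if len(shape_a) == 0 or len(shape_b) == 0:
        return False
    # Condition 2: B needs at least as many dimensions as A
    if len(shape_a) > len(shape_b):
        return False
    # Condition 3: pair the dimensions from the end; each dim of A is 1 or equals B's
    return all(a == 1 or a == b
               for a, b in zip(reversed(shape_a), reversed(shape_b)))


print(can_expand_shape((4, 1, 4), (3, 4, 2, 4)))     # True
print(can_expand_shape((3, 1, 1, 5), (2, 2, 2, 5)))  # False: 3 != 2 in the first axis
print(can_expand_shape((1, 4, 1, 4), (4, 1, 4)))     # False: A has more dimensions
```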


In [None]:
def has_no_dimensions(tensor):
    """
    Check if the given tensor is None or has no dimensions.

    Args:
    - tensor (torch.Tensor): The tensor to be checked.

    Returns:
    - bool: True if the tensor is None or has no dimensions, False otherwise.
    """

    if tensor is None:
        return True

    if len(tensor.size()) == 0:
        return True

    return False

In [None]:
def is_expandable(A, B):
  """
  Check if tensor A is expandable to match the dimensions of tensor B.

  Args:
  - A (torch.Tensor): The first tensor to be checked for expandability.
  - B (torch.Tensor): The second tensor whose dimensions are being compared for expansion.

  Returns:
  - bool: True if A is expandable to B, False otherwise.
  """

  # Check if either tensor is None or has no dimensions.
  # This must happen before calling .size(), which would fail on None.
  if has_no_dimensions(A) or has_no_dimensions(B):
    return False

  # Obtain the sizes of the tensors
  size_A = A.size()
  size_B = B.size()

  # Check if B has fewer dimensions than A
  if len(size_A) > len(size_B):
    return False

  # Iterate through the dimensions of A from the trailing dimension
  for i in range(-1, -1*len(size_A)-1, -1):

    if size_A[i] == 1: # If the dimension of A is 1, it can be expanded
      continue

    # If the dimensions of A and B at the corresponding index match,
    # there is no need to expand
    if size_A[i] == size_B[i]:
      continue

    return False # If one of the conditions isn't met


  return True

Second, given a tensor A, we may need to pad its shape with leading dimensions of size 1 until it has a given number of dimensions.

In our case, we pad one tensor so that it has as many dimensions as another tensor.
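
At the level of shapes, this padding just prepends ones. A minimal sketch, assuming a hypothetical helper `pad_shape` that works on tuples rather than tensors:

```python
def pad_shape(shape, length):
    """Hypothetical helper: prepend size-1 dimensions until len(shape) == length."""
    missing = max(length - len(shape), 0)  # nothing to add if shape is long enough
    return (1,) * missing + tuple(shape)


print(pad_shape((4, 1, 4), 4))  # (1, 4, 1, 4)
print(pad_shape((2, 3), 2))     # (2, 3), already long enough
```

Prepending (rather than appending) keeps the trailing dimensions aligned, which is what the expandability rule compares.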

In [None]:
def dim_pad(A, length):
  """
  Pad tensor A to have a specified number of dimensions.

  Args:
  - A (torch.Tensor): The tensor to be padded.
  - length (int): The desired number of dimensions for tensor A.

  Returns:
  - torch.Tensor: The padded tensor A.

  Note:
  If tensor A already has a number of dimensions equal to or greater than 'length',
  this function won't perform any padding.
  """

  # Create a detached clone of tensor A
  A0 = A.clone().detach()

  # Prepend size-1 dimensions until A0 has 'length' dimensions
  size_A = A.size()
  if len(size_A) < length:
    for _ in range(length - len(size_A)):
      A0 = A0.unsqueeze(0)

  return A0



Given two tensors A and B with the same number of dimensions, where A is expandable to B, we extend each axis of A that does not match the corresponding axis of B, one axis at a time.
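
To see what extending a single axis does to the data, here is a small sketch on nested Python lists instead of tensors. `repeat_along_axis` is a hypothetical name used only for this illustration.

```python
def repeat_along_axis(nested, axis, times):
    """Hypothetical helper: repeat the content of a nested list along one axis."""
    if axis == 0:
        # Concatenate 'times' copies of the top-level items
        return [item for _ in range(times) for item in nested]
    # Otherwise recurse into each sub-list with the axis shifted by one
    return [repeat_along_axis(sub, axis - 1, times) for sub in nested]


row = [[0, 1, 2]]     # shape (1, 3)
col = [[0], [1]]      # shape (2, 1)
print(repeat_along_axis(row, 0, 2))  # [[0, 1, 2], [0, 1, 2]]
print(repeat_along_axis(col, 1, 3))  # [[0, 0, 0], [1, 1, 1]]
```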

In [None]:
def expand_along_axis(tensor_in, dim, times):

  """
  Expand a tensor along a specified axis by concatenating copies of it.

  Note: torch.cat allocates new memory, so the result is a copy of the data,
  not a view of the input (unlike torch.Tensor.expand).

  Args:
  - tensor_in (torch.Tensor): The input tensor to be expanded.
  - dim (int): The axis along which to expand the tensor.
  - times (int): The number of times to repeat the tensor's content along the specified axis.

  Returns:
  - torch.Tensor: The expanded tensor.
  """

  tensor_out = tensor_in.detach().clone()

  # Append the input to the output (times - 1) more times along the specified axis.
  for _ in range(times - 1):
    tensor_out = torch.cat([tensor_out, tensor_in], dim=dim)

  return tensor_out


If A can be expanded to match B, we pad A so that it has as many dimensions as B, and then expand each axis of A whose size is 1 to the corresponding size in B.
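
For comparison, torch's own `expand` is documented to perform this size-1 expansion without allocating new memory: it returns a view in which the expanded dimension has stride 0. The sketch below imitates that idea with a flat Python list as storage. The helpers `expanded_strides` and `read` are hypothetical and assume the shape is already padded to the target's number of dimensions.

```python
def expanded_strides(shape, strides, target_shape):
    """Hypothetical helper: strides of the expanded view (same number of dims assumed)."""
    return tuple(
        0 if size == 1 and target != 1 else stride  # stride 0 re-reads the same element
        for size, stride, target in zip(shape, strides, target_shape)
    )


def read(buffer, strides, index):
    """Look up one element of a strided view stored in a flat buffer."""
    return buffer[sum(i * s for i, s in zip(index, strides))]


buffer = [10, 20, 30]                 # storage of a tensor with shape (1, 3)
strides = expanded_strides((1, 3), (3, 1), (2, 3))
print(strides)                        # (0, 1)
print(read(buffer, strides, (0, 2)))  # 30
print(read(buffer, strides, (1, 2)))  # 30, same storage cell as row 0
```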

In [None]:
def expand_tensor(input_tensor, *sizes):

    """
    Expand tensor to the specified sizes along the specified dimensions.

    Args:
    - input_tensor (torch.Tensor): The input tensor to be expanded.
    - sizes (int or tuple of ints): The desired sizes along each dimension.

    Returns:
    - torch.Tensor: The expanded tensor.
    """

    # Initialize output tensor
    C = input_tensor.clone()

    # Iterate over the desired sizes from the end
    for i in range(-1, -1* len(sizes) -1, -1):

      # Check if the dimension in the input tensor is 1 and needs expansion
      if input_tensor.size()[i] == 1 and sizes[i]!=1:

        # Expand along the current dimension
        C = expand_along_axis(C, i, sizes[i])


    return C



Putting the pieces together: if A is expandable to B, we pad A to B's number of dimensions and then expand each axis of A that requires expansion. Otherwise, we raise an error.
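
As a shape-only walk-through of these two steps, the sketch below reports the padded shape and which axes would need repeating. `expand_plan` is a hypothetical helper for illustration and assumes A is already known to be expandable to B.

```python
def expand_plan(shape_a, shape_b):
    """Hypothetical helper: padded shape of A and the axes that must be repeated.

    Assumes A is already known to be expandable to B.
    """
    # Step 1: pad with leading ones up to B's number of dimensions
    padded = (1,) * (len(shape_b) - len(shape_a)) + tuple(shape_a)
    # Step 2: collect the size-1 axes that differ from B
    axes = [i for i, (a, b) in enumerate(zip(padded, shape_b)) if a == 1 and b != 1]
    return padded, axes


padded, axes = expand_plan((4, 1, 4), (3, 4, 2, 4))
print(padded)  # (1, 4, 1, 4)
print(axes)    # [0, 2]
```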

In [None]:
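def my_expand_as(A, B):
  """
  Expand tensor A to the size of tensor B, mirroring torch.Tensor.expand_as.

  Args:
  - A (torch.Tensor): The tensor to be expanded.
  - B (torch.Tensor): The tensor whose size A is expanded to.

  Returns:
  - torch.Tensor: The expanded version of tensor A.

  Raises:
  - RuntimeError: If A cannot be expanded to B's dimensions.
  """
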
  size_B = B.size()
  # Check if A is expandable to B and perform necessary operations if so
  if is_expandable(A,B):
    A = dim_pad(A, len(size_B)) # Pad tensor A to match the number of dimensions of tensor B
    C = expand_tensor(A, *size_B) # Expand tensor A to match the dimensions of tensor B
    return C

  else:
    raise RuntimeError("Can't expand A to B's dimensions")


In [None]:
def compare_expand_as(A,B):
  """
  Compare the expansion results of two tensors obtained using my_expand_as and Torch's implementation.

  Args:
  - A (torch.Tensor): The first tensor.
  - B (torch.Tensor): The second tensor.

  Returns:
  - None

  Prints:
  - str: Message indicating whether the expansion results obtained using custom and Torch's implementations are equal.
  """

  torch_expansion = A.expand_as(B) # Our ground truth
  C = my_expand_as(A, B)

  # torch.equal checks that both the shapes and the values match
  if torch.equal(C, torch_expansion):
    print("Expansion results for the provided tensors: my_expand_as matches Torch's implementation.")
  else:
    print("Expansion results for the provided tensors: my_expand_as differs from Torch's implementation.")

Usage Example:

In [None]:
# Examples where A is expandable to B's dimensions

D = torch.arange(16).reshape(4,1,4)
E = torch.arange(96).reshape(3,4,2,4)
compare_expand_as(D,E)

F = torch.arange(5).reshape(1, 1, 1, 1, 5)
G = torch.arange(160).reshape(2, 2, 2, 2, 2, 5)
compare_expand_as(F,G)

Expansion results for the provided tensors: my_expand_as matches Torch's implementation.
Expansion results for the provided tensors: my_expand_as matches Torch's implementation.


As observed, there are no discrepancies.

In [None]:
# Examples where A is not expandable to B's dimensions.
# compare_expand_as calls torch's expand_as first, so the printed error is torch's own.
H = torch.arange(16).reshape(1,4,1,4)
I = torch.arange(16).reshape(4,1,4)
try:
  compare_expand_as(H,I)
except RuntimeError as e:
  print(e)
  print("H is expandable to I: ", is_expandable(H,I))


expand(torch.LongTensor{[1, 4, 1, 4]}, size=[4, 1, 4]): the number of sizes provided (3) must be greater or equal to the number of dimensions in the tensor (4)
H is expandable to I:  False


In [None]:
J = torch.arange(15).reshape(3, 1, 1, 5)
K = torch.arange(40).reshape(2, 2, 2, 5)
try:
  compare_expand_as(J,K)

except RuntimeError as e:
  print(e)
  print("J is expandable to K: ", is_expandable(J,K))

The expanded size of the tensor (2) must match the existing size (3) at non-singleton dimension 0.  Target sizes: [2, 2, 2, 5].  Tensor sizes: [3, 1, 1, 5]
J is expandable to K:  False


## broadcast_tensors


Next, we want to determine whether two tensors can be broadcast together. The procedure resembles the check for whether A is expandable to B, with some differences. Two tensors A and B can be broadcast together if:

1. Both A and B have at least one dimension.
2. Iterating over the dimensions of both tensors from last to first, each pair of corresponding dimensions satisfies one of these three criteria:

  a. One of them is 1.

  b. One of them doesn't exist.

  c. They are equal.


In this implementation, we iterate over the tensor with fewer dimensions. Once we finish iterating through it, no further checks are needed, because the remaining dimensions of the other tensor have no counterpart.

While checking for broadcastability, we also compute the shape that the two tensors will be broadcast to (if they are indeed broadcastable).
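
The same rule can be sketched on shape tuples with `itertools.zip_longest`, which fills in a 1 wherever the shorter shape has no dimension. This is only an illustration of the rule: `broadcast_shape` is a hypothetical helper, not the implementation used in this notebook.

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: broadcast shape of two shape tuples, or None if incompatible."""
    # Condition 1: both shapes need at least one dimension
    if len(shape_a) == 0 or len(shape_b) == 0:
        return None
    out = []
    # Condition 2: walk both shapes from the end; a missing dimension counts as 1
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if a == 1:
            out.append(b)
        elif b == 1 or a == b:
            out.append(a)
        else:
            return None
    return tuple(reversed(out))


print(broadcast_shape((4, 1, 4), (2, 1, 2, 4)))  # (2, 4, 2, 4)
print(broadcast_shape((3, 1), (1, 4)))           # (3, 4)
print(broadcast_shape((3, 2), (1, 5)))           # None
```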


In [None]:
def is_broadcastable(A, B):
  """
  Check if two tensors are broadcastable together.

  Args:
  - A (torch.Tensor): The first tensor to be checked for broadcastability.
  - B (torch.Tensor): The second tensor to be checked for broadcastability.

  Returns:
  - False if the tensors are not broadcastable.
  - tuple (True, out_dim) if they are broadcastable, where out_dim is a list
    holding the shape that both tensors are broadcast to.
  """


  # Check if either tensor is None or has no dimensions
  if has_no_dimensions(A) or has_no_dimensions(B):
      return False


  size_A = A.size()
  size_B = B.size()



  # Store the dimensions of the shorter and longer tensors (number of dimensions wise) independently.
  if len(size_A)<=len(size_B):
    shorter_tensor_size = size_A
    longer_tensor_size = size_B

  else:
    shorter_tensor_size  = size_B
    longer_tensor_size = size_A



  out_dim = []
  # Iterate the tensor with least dimensions from the last dimensions to first
  for i in range(-1, -1*len(shorter_tensor_size)-1, -1):

    # Check if at least one of the tensors in dim=i is 1
    if shorter_tensor_size[i] == 1:
      out_dim.insert(0, longer_tensor_size[i])

    elif longer_tensor_size[i] == 1 :
      out_dim.insert(0, shorter_tensor_size[i])

    # Check if dimension i in both tensors is equal
    elif shorter_tensor_size[i] == longer_tensor_size[i]:
      out_dim.insert(0, shorter_tensor_size[i])

    # If none of the conditions above is met, the tensors are not broadcastable
    else:
      return False

  # Save the dimensions of the longer tensor that we didn't iterate yet to determine the desired broadcasting size
  i-=1
  while i>=-1*len(longer_tensor_size):
    out_dim.insert(0, longer_tensor_size[i])
    i-=1

  # There was no violation so the two tensors are broadcastable
  return True, out_dim




Once more, there might be a need to pad the dimensions of the tensors. This time, we will pad them to match the number of dimensions of the tensor with the greater number of dimensions.

In [None]:
def pad_to_same_size(A, B):

  """
  Pad two tensors to have the same number of dimensions.

  Args:
  - A (torch.Tensor): The first tensor to be padded.
  - B (torch.Tensor): The second tensor to be padded.

  Returns:
  - tuple: A tuple containing the padded tensors:
    - torch.Tensor: The padded version of tensor A.
    - torch.Tensor: The padded version of tensor B.
  """

  size_A = A.size()
  size_B = B.size()

  # Pad the tensor with fewer dimensions to match the other
  if len(size_A)<len(size_B):
    A0 = dim_pad(A, len(size_B))
    return A0, B

  elif len(size_A)>len(size_B):
    B0 = dim_pad(B, len(size_A))
    return A, B0

  else:
    return A, B


Finally, when tensors are broadcastable together and possess an equal number of dimensions, we can perform the broadcasting operation on them simultaneously.

In [None]:
def broadcast(A, B, target_size):

  """
  Broadcast two tensors to match a target size along each dimension.

  Args:
  - A (torch.Tensor): The first tensor to be broadcasted.
  - B (torch.Tensor): The second tensor to be broadcasted.
  - target_size (tuple of ints): The desired size along each dimension after broadcasting.

  Returns:
  - tuple: A tuple containing the broadcasted tensors:
    - torch.Tensor: The broadcasted version of tensor A.
    - torch.Tensor: The broadcasted version of tensor B.
  """

  A0 = A.clone()
  B0 = B.clone()

  size_A = A.size()
  size_B = B.size()

  # Iterate over the target size dimensions from the last to the first
  for i in range(-1, -1* len(target_size) -1, -1):

    # Broadcast tensor A along the current dimension if needed
    if size_A[i] == 1 and target_size[i]!=1:
      A0 = expand_along_axis(A0, i, target_size[i])

    # Broadcast tensor B along the current dimension if needed
    if size_B[i] == 1 and target_size[i]!=1:
      B0 = expand_along_axis(B0, i, target_size[i])

  return A0, B0




If A and B can be broadcast together, we pad them to the same number of dimensions and then expand each axis that requires expansion.
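
To make the result concrete, here is a small sketch of what broadcasting a (3, 1) column against a (1, 4) row produces, written with nested Python lists so the values are visible. The variable names are chosen for this illustration only.

```python
x = [[0], [1], [2]]   # shape (3, 1)
y = [[0, 1, 2, 3]]    # shape (1, 4)
rows, cols = 3, 4     # broadcast shape (3, 4)

# Repeat x along its size-1 axis (columns) and y along its size-1 axis (rows)
x_b = [[x[i][0] for _ in range(cols)] for i in range(rows)]
y_b = [[y[0][j] for j in range(cols)] for _ in range(rows)]

print(x_b)  # [[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]
print(y_b)  # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]

# Once both have the same shape, element-wise operations are straightforward
total = [[x_b[i][j] + y_b[i][j] for j in range(cols)] for i in range(rows)]
print(total)  # [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]]
```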

In [None]:
def my_broadcast_tensors(A,B):
  """
  Broadcast two tensors to perform element-wise operations.

  Args:
  - A (torch.Tensor): The first tensor.
  - B (torch.Tensor): The second tensor.

  Returns:
  - tuple: A tuple containing the broadcasted tensors:
    - torch.Tensor: The broadcasted version of tensor A.
    - torch.Tensor: The broadcasted version of tensor B.

  Raises:
  - ValueError: If the tensors cannot be broadcasted together.
  """

  # If broadcastable, pad tensors to the same size and broadcast them
  broadcast_res = is_broadcastable(A,B)
  if broadcast_res:
    A, B = pad_to_same_size(A, B)
    broadcast_size = broadcast_res[1]
    return broadcast(A, B, broadcast_size)

  else:
    # If not broadcastable, raise a ValueError
    raise ValueError('Tried to broadcast two tensors which are not broadcastable')


In [None]:
def compare_results(X, Y):
    """
    Compare the broadcasting results of two tensors obtained using custom and Torch's implementations.

    Args:
    - X (torch.Tensor): The first tensor.
    - Y (torch.Tensor): The second tensor.

    Returns:
    - None

    Prints:
    - str: Message indicating whether the broadcasting results of X obtained using custom and Torch's implementations are equal.
    - str: Message indicating whether the broadcasting results of Y obtained using custom and Torch's implementations are equal.

    Raises:
    - ValueError: If broadcasting is not possible.

    Notes:
    The function may raise a ValueError if the custom broadcasting implementation in `my_broadcast_tensors` fails.
    """

    # Broadcast tensors X and Y using my_broadcast_tensors
    X0, Y0 = my_broadcast_tensors(X, Y)

    # Broadcast tensors X and Y using Torch's implementation
    X_t, Y_t = torch.broadcast_tensors(X, Y)

    # Compare the broadcasting results of X and Y obtained using my_broadcast_tensors and Torch's implementations
    # Print message indicating whether the results are equal or not
    if torch.all(X0 == X_t) and torch.all(Y0 == Y_t) :
        print("Broadcasting results for the provided tensors: my_broadcast_tensors matches Torch's implementation.")
    else:
        print("Broadcasting results for the provided tensors: my_broadcast_tensors differs from Torch's implementation.")




Usage examples:

In [None]:
# Broadcastable tensors examples

C = torch.arange(24).reshape(2,3,4)
D = torch.tensor([0,1,2,3]) # size 4
compare_results(C, D) # Compare broadcasting results of tensors C and D


R = torch.arange(16).reshape(4,1,4)
S = torch.arange(16).reshape(2,1,2,4)
compare_results(R, S) # Compare broadcasting results of tensors R and S


X = torch.arange(3).reshape(3,1)
Y = torch.arange(4).reshape(1,4)
compare_results(X, Y) # Compare broadcasting results of tensors X and Y

Broadcasting results for the provided tensors: my_broadcast_tensors matches Torch's implementation.
Broadcasting results for the provided tensors: my_broadcast_tensors matches Torch's implementation.
Broadcasting results for the provided tensors: my_broadcast_tensors matches Torch's implementation.


In [None]:
# Unbroadcastable tensors example
G = torch.arange(6).reshape(3,2)
H = torch.arange(5).reshape(1,5)
try:
  compare_results(G, H) # Compare broadcasting results of tensors G and H

except ValueError as e:
  print(e)

Tried to broadcast two tensors which are not broadcastable


In [None]:
M = torch.arange(1080).reshape(1,2,3,3,2,5,3,2)
N = torch.arange(1620).reshape(3,3,3,2,5,3,2)
try:
  compare_results(M, N) # Compare broadcasting results of tensors M and N

except ValueError as e:
  print(e)

Tried to broadcast two tensors which are not broadcastable
