In [702]:
import torch
from torch import tensor
from typing import List, Optional

# Check is broadcastable

In [703]:
def is_broadcastable(A:tensor, B:tensor) -> bool:
	"""
	A tensor can't be broadcasted if:

	(dimA[i] != dimB[i]) and (dimA[i] != 1 and dimB[i] != 1)

	The dimension of A can be longer than B, or vice versa.
	To deal with this, we extend the shorter tensor with 1's.

	We also check if A.size()[i] < B.size()[i] because we can't expand from higher dimension to lower dimension.
	"""

	minDim = min(A.dim(), B.dim())
	i = -1
	# Check latest size, after loop, everything is expanded as '1' so we dont need to check
	for _ in range(minDim):
		sA = A.size()[i]
		sB = B.size()[i]
		i = i - 1
		if sA != sB and sA != 1 and sB != 1:
			print(sA, sB, " mismatch")
			return False
		elif sA > sB:
			print(f"Dim {sA} in A is longer than {sB} in B, mismatch")
			return False
		else:
			print(sA, sB, " match")
	print("Broadcastable")
	return True

# Broadcast function

In [704]:
def my_expand_as(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
	assert is_broadcastable(A, B)

	"""
	Algorithm to broadcast

	dimensions:
	A = |2|1|1|
	B = |3|2|2|3|9|

	broadcast to:
	A = |1|1|2|1|1|
	B = |3|2|2|3|9|

	expands to:
	C = |3|2|2|3|9|
	
	"""
	# We want to save B to the 'skeleton' of A
	res = torch.clone(A)
	diff = B.dim() - A.dim()

	for i in range(A.dim()):
		if A.shape[i] != B.shape[diff+i]: 
			lst = [res] * B.shape[diff+i]  # like repeat
			res = torch.stack(lst, i+1)
			res.squeeze_(i)

	for i in range(diff, 0, -1):
		lst = [res] * B.shape[i]
		res = torch.stack(lst, 0)

	return res


# Test is broadcastable

In [719]:
# A = 2x1x2
# B = 2x1x1x3
A = torch.arange(4).reshape(2,1,2)
B = torch.arange(6).reshape(2,1,1,3)
print("A = ", A.size())
print("B = ", B.size())
try:	
	C = my_expand_as(A,B)
	print(C)
	raise "ERROR: A,B are not broadcastable, yet my_expand_as() returned a tensor"
except AssertionError:
	print("A and B are not broadcastable")

A =  torch.Size([2, 1, 2])
B =  torch.Size([2, 1, 1, 3])
2 3  mismatch
A and B are not broadcastable


In [720]:
# A = 2x1x2
# B = 5x1
A = torch.arange(4).reshape(2,1,2)
B = torch.arange(5).reshape(5,1)
print("A = ", A.size())
print("B = ", B.size())
try:
	C = my_expand_as(A,B)
	print(C)
	raise "ERROR: A,B are not broadcastable, yet my_expand_as() returned a tensor"
except AssertionError:
	print("A and B are not broadcastable")

A =  torch.Size([2, 1, 2])
B =  torch.Size([5, 1])
Dim 2 in A is longer than 1 in B, mismatch
A and B are not broadcastable


In [721]:
# A = 2x1x1x3
# B = 5x1
A = torch.arange(6).reshape(2,1,1,3)
B = torch.arange(5).reshape(5,1)
print("A = ", A.size())
print("B = ", B.size())
try:
	C = my_expand_as(A,B)
	print(C)
	print("A and B are broadcastable")
	raise "ERROR: A,B are not broadcastable, yet my_expand_as() returned a tensor"
except AssertionError:
	print("A and B are not broadcastable")

A =  torch.Size([2, 1, 1, 3])
B =  torch.Size([5, 1])
Dim 3 in A is longer than 1 in B, mismatch
A and B are not broadcastable


In [708]:
# A = 2x1x2
# B = 1x3x1x2
A = torch.arange(4).reshape(2,1,2)
B = torch.arange(6).reshape(2,1,1,3)
print("A = ", A.size())
print("B = ", B.size())
try:	
	C = my_expand_as(A,B)
	print(C)
	raise "ERROR: A,B are not broadcastable, yet my_expand_as() returned a tensor"
except AssertionError:
	print("A and B are not broadcastable")

A =  torch.Size([2, 1, 2])
B =  torch.Size([2, 1, 1, 3])
2 3  mismatch
A and B are not broadcastable


In [722]:
# A = 1x1x1x1
# B = 5x5x5x1
A = torch.arange(1).reshape(1,1,1,1)
B = torch.arange(125).reshape(5,5,5,1)
print("A = ", A.size())
print("B = ", B.size())
C = my_expand_as(A,B)
print(C)
print("A and B are broadcastable")

A =  torch.Size([1, 1, 1, 1])
B =  torch.Size([5, 5, 5, 1])
1 1  match
1 5  match
1 5  match
1 5  match
Broadcastable
tensor([[[[0],
          [0],
          [0],
          [0],
          [0]],

         [[0],
          [0],
          [0],
          [0],
          [0]],

         [[0],
          [0],
          [0],
          [0],
          [0]],

         [[0],
          [0],
          [0],
          [0],
          [0]],

         [[0],
          [0],
          [0],
          [0],
          [0]]],


        [[[0],
          [0],
          [0],
          [0],
          [0]],

         [[0],
          [0],
          [0],
          [0],
          [0]],

         [[0],
          [0],
          [0],
          [0],
          [0]],

         [[0],
          [0],
          [0],
          [0],
          [0]],

         [[0],
          [0],
          [0],
          [0],
          [0]]],


        [[[0],
          [0],
          [0],
          [0],
          [0]],

         [[0],
          [0],
 

In [710]:
# A = 6
# B = 6x6
A = torch.arange(6)
B = torch.arange(36).reshape(6,6)
print("A = ", A.size())
print("B = ", B.size())
C = my_expand_as(A,B)
print(C)
print("A and B are broadcastable")

A =  torch.Size([6])
B =  torch.Size([6, 6])
6 6  match
Broadcastable
tensor([[0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5]])
A and B are broadcastable


# Test my_expand_as function correctness

In [711]:
A=torch.arange(9).reshape(3,3)  # 3x3
B=torch.arange(3)  # 1x3
print("A = ", A)
print("B = ",B)
print("A + B = ", A+B)
print(A.size(), B.size())

print(B.expand_as(A))
print(my_expand_as(B,A))
assert torch.all(B.expand_as(A) == my_expand_as(B,A))

A =  tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
B =  tensor([0, 1, 2])
A + B =  tensor([[ 0,  2,  4],
        [ 3,  5,  7],
        [ 6,  8, 10]])
torch.Size([3, 3]) torch.Size([3])
tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2]])
3 3  match
Broadcastable
tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2]])
3 3  match
Broadcastable


In [712]:
A = torch.rand(6).reshape(2,3)
B = torch.rand(6).reshape(1,2,3)

my_res = my_expand_as(A,B)
torch_res = A.expand_as(B)

assert torch.all(my_res == torch_res).item()

3 3  match
2 2  match
Broadcastable


In [713]:
# def broadcast_dim(src: torch.Tensor, dst: torch.Tensor) -> tuple:
# 	"""
# 	Return list of new size of broadcasted dimensions
# 	"""
# 	dimSrc = list(src.size())
# 	dimDst = list(dst.size())
# 	print(dimSrc, dimDst)

# 	if len(dimSrc) != len(dimDst):
# 		diff = abs(len(dimSrc) - len(dimDst))
# 		# add ones to the smaller dimension
# 		if len(dimSrc) < len(dimDst):
# 			for _ in range(diff):
# 				dimSrc.insert(0, 1)
# 		else:
# 			for _ in range(diff):
# 				dimDst.insert(0, 1)

# 	print(dimSrc, dimDst)
# 	new_size = []
# 	for dimA, dimB in zip(dimSrc, dimDst):
# 		new_size.append(max(dimA, dimB))

# 	return new_size

# def my_expand_as_two(A: tensor, B: tensor) -> torch.Tensor:
# 	assert is_broadcastable(A, B)

# 	new_size = broadcast_dim(A,B)
# 	print("new_size: ", new_size)
# 	return torch.zeros(new_size).squeeze()

# a = torch.rand(3).reshape(3)
# b = torch.rand(9).reshape(3,3) 
# my_expand_as_two(a,b).size()

# a = torch.rand(9).reshape(1,1,1,3,3)
# b = torch.rand(9).reshape(3,3) 
# my_expand_as_two(a,b).size()

# Check parallel broadcasting

In [714]:
def is_broadcastable_parallel_tensors(A:torch.Tensor, B:torch.Tensor) -> tuple[bool, List[int]]:
	# For semantics purpose, we order the tensors by dim
	(a,b) = (A,B)
	if A.dim() > B.dim():
		(a,b) = (B,A)
		
	# We squeeze 'a' so that it has the same number of dimensions as 'b'
	for i in range(B.dim()-A.dim()):
		a = torch.unsqueeze(a, 0)
	print("shapeA: ", a.dim())
	print("shapeB: ", b.dim())

	new_size = []
	for i in range(a.dim()):
		# The rules in PDF 5
		if a.shape[i] != 1 and b.shape[i] != 1 and a.shape[i] != b.shape[i]:
			return False, None
		
		# Keep storing the dimensions
		if a.shape[i]>=b.shape[i]:
			new_size.append(a.shape[i])
		else:
			new_size.append(b.shape[i])

	return True, new_size

# Broadcast two tensors in parallel

In [715]:
def my_broadcast_tensors(A:torch.Tensor, B:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
	is_parallel_broadcastable, desired_sizes = is_broadcastable_parallel_tensors(A, B)
	assert is_parallel_broadcastable
	assert desired_sizes is not None

	container = torch.empty(desired_sizes)  # Sekeleton tensor
	return my_expand_as(A, container), my_expand_as(B, container)

# Test my_broadcast_tensors

In [723]:
a=torch.arange(3).reshape(3)
b=torch.arange(9).reshape(1,3,3)

torch_resA, torch_resB = torch.broadcast_tensors(a,b)
my_resA, my_resB = my_broadcast_tensors(a,b)

print("torch_res = ", torch_resA)
print("my_res = ", my_resA)

assert (torch.all(torch_resA == my_resA) and torch.all(torch_resB == my_resB)).item()

shapeA:  3
shapeB:  3
3 3  match
Broadcastable
3 3  match
3 3  match
1 1  match
Broadcastable
torch_res =  tensor([[[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]]])
my_res =  tensor([[[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]],

        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]],

        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]]])


In [717]:
a=torch.arange(9).reshape(1,9)
b=torch.arange(3).reshape(1,3)

try:
	torch_resA, torch_resB = torch.broadcast_tensors(a,b)
	raise "ERROR: Something went wrong"
except:
	pass

try:
	my_resA, my_resB = my_broadcast_tensors(a,b)
	raise "ERROR: Something went wrong, shouldn;t be able to broadcast"
except:
	pass

shapeA:  2
shapeB:  2
