In [131]:
import torch
from torch import tensor

# Check is broadcastable

In [132]:
def is_broadcastable(A:tensor, B:tensor):
	"""
	A tensor can't be broadcasted if:

	(dimA[i] != dimB[i]) and (dimA[i] != 1 and dimB[i] != 1)

	The dimension of A can be longer than B, or vice versa.
	To deal with this, we extend the shorter tensor with 1's.
	"""

	minDim = min(A.dim(), B.dim())
	i = -1
	# Check latest size, after loop, everything is expanded as '1' so we dont need to check
	for _ in range(minDim):
		sA = A.size()[i]
		sB = B.size()[i]
		i = i - 1
		if sA != sB and sA != 1 and sB != 1:
			print(sA, sB, " mismatch")
			return False
		else:
			print(sA, sB, " match")

	return True

In [133]:
def my_expand_as(A:tensor, B:tensor) -> tensor:
	assert is_broadcastable(A, B)
	pass

# Test is broadcastable

In [134]:
# A = 2x1x2
# B = 2x1x1x3
A = torch.arange(4).reshape(2,1,2)
B = torch.arange(6).reshape(2,1,1,3)
print("A = ", A.size())
print("B = ", B.size())
try:	
	C = my_expand_as(A,B)
	print(C)
	raise "ERROR: A,B are not broadcastable, yet my_expand_as() returned a tensor"
except AssertionError:
	print("A and B are not broadcastable")

A =  torch.Size([2, 1, 2])
B =  torch.Size([2, 1, 1, 3])
2 3  mismatch
A and B are not broadcastable


In [135]:
# A = 2x1x2
# B = 5x1
A = torch.arange(4).reshape(2,1,2)
B = torch.arange(5).reshape(5,1)
print("A = ", A.size())
print("B = ", B.size())
C = my_expand_as(A,B)
print(C)
print("A and B are broadcastable")

A =  torch.Size([2, 1, 2])
B =  torch.Size([5, 1])
2 1  match
1 5  match
None
A and B are broadcastable


In [136]:
# A = 2x1x1x3
# B = 5x1
A = torch.arange(6).reshape(2,1,1,3)
B = torch.arange(5).reshape(5,1)
print("A = ", A.size())
print("B = ", B.size())
C = my_expand_as(A,B)
print(C)
print("A and B are broadcastable")

A =  torch.Size([2, 1, 1, 3])
B =  torch.Size([5, 1])
3 1  match
1 5  match
None
A and B are broadcastable


In [137]:
# A = 2x1x2
# B = 1x3x1x2
A = torch.arange(4).reshape(2,1,2)
B = torch.arange(6).reshape(2,1,1,3)
print("A = ", A.size())
print("B = ", B.size())
try:	
	C = my_expand_as(A,B)
	print(C)
	raise "ERROR: A,B are not broadcastable, yet my_expand_as() returned a tensor"
except AssertionError:
	print("A and B are not broadcastable")

A =  torch.Size([2, 1, 2])
B =  torch.Size([2, 1, 1, 3])
2 3  mismatch
A and B are not broadcastable
