<a href="https://colab.research.google.com/github/OfekYa/Deep-Learning/blob/main/implement_pytorch_broadcast.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implement the broadcasting functionality of PyTorch

In [None]:
import torch

# **Question 1_a**

In [None]:

def expand_as(tensor_A, tensor_B):
    """Return a new tensor holding tensor_A broadcast to the shape of tensor_B.

    The two shapes are aligned on their rightmost dimension.  tensor_A can be
    expanded only if it has no more dimensions than tensor_B and every one of
    its dimensions either equals the matching dimension of tensor_B or is 1.
    Missing leading dimensions of tensor_A are treated as size 1.

    Neither argument is modified, and the result keeps tensor_A's dtype.

    Args:
        tensor_A: The tensor to expand.
        tensor_B: The tensor that supplies the target shape; its values are
            ignored.

    Returns:
        A newly allocated tensor with tensor_B's shape that contains the
        values of tensor_A repeated along the broadcast dimensions.

    Raises:
        Exception: If tensor_A cannot be broadcast to the shape of tensor_B.
    """
    A_shape = tuple(tensor_A.shape)
    B_shape = tuple(tensor_B.shape)

    def _cannot_broadcast():
        # Built lazily so the tensor is only stringified on the failure path.
        return Exception('tensor ' + str(tensor_A) + ' cannot broadcast to tensor shape: ' + str(tensor_B.shape))

    # A tensor can never be expanded to a shape with fewer dimensions.
    if len(A_shape) > len(B_shape):
        raise _cannot_broadcast()

    # Conceptually prepend size-1 dimensions to A until both ranks match.
    padded_A_shape = (1,) * (len(B_shape) - len(A_shape)) + A_shape

    # Every dimension of A must either match B or be 1 (and thus stretchable).
    for A_dim, B_dim in zip(padded_A_shape, B_shape):
        if A_dim != B_dim and A_dim != 1:
            raise _cannot_broadcast()

    # Adding zeros of the target shape materialises the broadcast.  The zeros
    # are a fresh tensor in A's dtype and on A's device, so tensor_B is left
    # untouched and the result is not type-promoted to B's dtype.
    zeros = torch.zeros(B_shape, dtype=tensor_A.dtype, device=tensor_A.device)
    return tensor_A + zeros

# **Question 1_b**

In [None]:

def broadcastible_together(tensor_A, tensor_B):
    """Check whether two tensors can be broadcast together and to which shape.

    The shapes are aligned on their rightmost dimension.  Two dimensions are
    compatible when they are equal or when one of them is 1; a dimension
    that is missing from the shorter shape is treated as 1.

    The result is derived from the shapes only, so no tensor arithmetic is
    performed and no output-sized tensor is allocated.

    Args:
        tensor_A: First tensor.
        tensor_B: Second tensor.

    Returns:
        The bare value ``False`` when the tensors cannot be broadcast
        together; otherwise the tuple ``(True, size)`` where ``size`` is the
        ``torch.Size`` both tensors broadcast to.  Note that the two outcomes
        have different types, so callers must check for ``False`` before
        unpacking.
    """
    A_shape = tuple(tensor_A.shape)
    B_shape = tuple(tensor_B.shape)

    # Left-pad the shorter shape with 1s so both can be walked in lockstep.
    rank = max(len(A_shape), len(B_shape))
    padded_A = (1,) * (rank - len(A_shape)) + A_shape
    padded_B = (1,) * (rank - len(B_shape)) + B_shape

    broadcast_dims = []
    for A_dim, B_dim in zip(padded_A, padded_B):
        if A_dim == B_dim or B_dim == 1:
            broadcast_dims.append(A_dim)
        elif A_dim == 1:
            # A is stretched to B's size (this also covers B_dim == 0).
            broadcast_dims.append(B_dim)
        else:
            return False  # Sizes differ and neither is 1: unable to broadcast.

    return True, torch.Size(broadcast_dims)  # The size the tensors will broadcast to.


# **Question 1_c**

In [None]:

def broadcast_tensors(tensor_A, tensor_B):
    """Broadcast two tensors against each other to their common shape.

    Args:
        tensor_A: First tensor.
        tensor_B: Second tensor.

    Returns:
        A tuple ``(expanded_A, expanded_B)`` in which both tensors have the
        common broadcast shape and each keeps the dtype of its input.

    Raises:
        RuntimeError: If the two tensors cannot be broadcast together.
    """
    # broadcastible_together returns the bare value False on failure and a
    # (True, size) tuple on success, so test for False before unpacking.
    outcome = broadcastible_together(tensor_A, tensor_B)
    if outcome is False:
        raise RuntimeError(
            'tensors of shape ' + str(tuple(tensor_A.shape)) + ' and '
            + str(tuple(tensor_B.shape)) + ' cannot be broadcast together'
        )
    _, size = outcome

    # Use one zero template per input, in that input's dtype and on its
    # device, so expanding an integer tensor is not promoted to the default
    # float dtype of torch.zeros.
    template_A = torch.zeros(size, dtype=tensor_A.dtype, device=tensor_A.device)
    template_B = torch.zeros(size, dtype=tensor_B.dtype, device=tensor_B.device)

    expanded_A = expand_as(tensor_A, template_A)  # broadcast A
    expanded_B = expand_as(tensor_B, template_B)  # broadcast B
    return expanded_A, expanded_B



# **Question 1_d**

In [None]:

def compare_my_expand_as_to_pytorch(test_cases):
    """Print, for every test case, whether expand_as agrees with PyTorch.

    Each case is compared against ``Tensor.expand_as``: either both calls
    must raise, or both must return tensors that ``torch.equal`` considers
    equal.

    Args:
        test_cases: A sequence of cases; element 0 of each case is the tensor
            to expand and element 1 is the tensor that supplies the shape.
    """
    for case in test_cases:

        A = case[0]
        B = case[1]
        expected_error, threw_error = False, False

        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit still stop the run.
        try:
            expected_tensor = A.expand_as(B)
        except Exception:
            expected_error = True
        try:
            actual_tensor = expand_as(A, B)
        except Exception:
            threw_error = True

        if expected_error != threw_error:
            print("FAILED: EXCEPTION: expected_error != threw_error")

        elif threw_error:
            print("SUCCESS: both return exception")

        elif not torch.equal(expected_tensor, actual_tensor):
            print("FAILED: EXCEPTION: expected tensor!= my tensor")

        else:
            print("SUCCESS: both return the same tensor")
##############################################################################################################

def compare_my_broadcastible_together_to_pytorch(test_cases):
    """Print, for every test case, whether broadcastible_together agrees with PyTorch.

    The reference answer is derived from ``torch.broadcast_tensors``: if it
    raises, the expected result is ``False``; otherwise it is
    ``(True, shape)`` with the shape of the broadcast output.

    Args:
        test_cases: A sequence of cases; elements 0 and 1 of each case are
            the two tensors to check.
    """
    for case in test_cases:

        A = case[0]
        B = case[1]

        # The results will be either tuple values or the bool value False.
        try:
            expected_result = True, torch.broadcast_tensors(A, B)[0].shape
        except Exception:
            expected_result = False

        # A crash in the function under test is a failure of this case; it
        # must not leave actual_result unbound or carrying the previous
        # iteration's value.
        try:
            actual_result = broadcastible_together(A, B)
        except Exception as error:
            print("FAILED: EXCEPTION: broadcastible_together raised " + repr(error))
            continue

        if expected_result is False and actual_result is False:
            print("SUCCESS: both return the same bool (FALSE)")

        elif expected_result is False or actual_result is False:
            # Only one side says "not broadcastable".
            print("FAILED: EXCEPTION: expected_result != actual_result")

        elif expected_result[1] != actual_result[1]:
            print("FAILED: EXCEPTION: expected_result[1] != actual_result[1]")

        else:
            print("SUCCESS: both return the same result")


##############################################################################################################


def compare_my_broadcast_tensors_to_pytorch(test_cases):
    """Print, for every test case, whether broadcast_tensors agrees with PyTorch.

    Each case is compared against ``torch.broadcast_tensors``: either both
    calls must raise, or both returned pairs must be element-wise equal
    according to ``torch.equal``.  Exactly one SUCCESS line, or one or more
    FAILED lines, is printed per case.

    Args:
        test_cases: A sequence of cases; elements 0 and 1 of each case are
            the two tensors to broadcast.
    """
    for case in test_cases:

        A = case[0]
        B = case[1]

        expected_error, threw_error = False, False

        try:
            expected_result = torch.broadcast_tensors(A, B)
        except Exception:
            expected_error = True

        try:
            actual_result = broadcast_tensors(A, B)
        except Exception:
            threw_error = True

        # When either side raised there is nothing to compare; move on so the
        # code below never reads results left over from an earlier case.
        if expected_error != threw_error:
            print("FAILED: EXCEPTION: expected_error != threw_error")
            continue
        if threw_error:
            print("SUCCESS: both return exception")
            continue

        expected_a, expected_b = expected_result
        actual_a, actual_b = actual_result

        all_equal = True
        if not torch.equal(expected_a, actual_a):
            print("FAILED: bad result for A")
            all_equal = False

        if not torch.equal(expected_b, actual_b):
            print("FAILED: bad result for B")
            all_equal = False

        # Report success only when neither comparison failed.
        if all_equal:
            print("SUCCESS: both return the same result")



In [None]:

# Shared inputs for the three comparison runs below.  Every entry is a pair
# [A, B]; the trailing comment gives the two shapes.
test_cases = [
    # Small hand-written tensors.
    [torch.tensor([4, 4]), torch.tensor([2])],                                # (2,)      vs (1,)
    [torch.tensor([[1, 6, 7], [1, 6, 7]]), torch.tensor([2])],                # (2, 3)    vs (1,)
    [torch.tensor([1, 2]), torch.tensor([[2, 3, 4], [5, 6, 7]])],             # (2,)      vs (2, 3)
    [torch.tensor([1, 2, 3]), torch.tensor([[2, 3, 4], [5, 6, 7]])],          # (3,)      vs (2, 3)
    [torch.tensor([[1, 2, 3]]), torch.tensor([[2, 3, 4], [5, 6, 7]])],        # (1, 3)    vs (2, 3)
    [torch.tensor([[[1, 2, 3]]]), torch.tensor([[2, 3, 4], [5, 6, 7]])],      # (1, 1, 3) vs (2, 3)

    # Larger tensors built from arange, with various size-1 dimensions.
    [
        torch.arange(10 ** 4).reshape(10, 10, 10, 1, 10),
        torch.arange(10 ** 5).view(10, 10, 10, 10, 10),
    ],
    [
        torch.arange(10 ** 3).reshape(10, 1, 10, 1, 10),
        torch.arange(10 ** 4).view(10, 10, 10, 10),
    ],
    [
        torch.arange(10 ** 3).reshape(10, 10, 1, 10),
        torch.arange(10 ** 3).view(10, 10, 10),
    ],
    [
        torch.arange(10 ** 2).reshape(10, 1, 1, 10),
        torch.arange(10 ** 5).view(10, 10, 10, 10, 10),
    ],
    [
        torch.arange(10 ** 2).reshape(10, 10),
        torch.arange(10 ** 5).view(10, 10, 10, 10, 10),
    ],
]

# Run the three comparison suites in order, each preceded by its heading.
for heading, run_suite in (
    ("TEST: expand_as\n", compare_my_expand_as_to_pytorch),
    ("\n\nTEST: broadcastible_together\n", compare_my_broadcastible_together_to_pytorch),
    ("\n\nTEST: broadcast_together\n", compare_my_broadcast_tensors_to_pytorch),
):
    print(heading)
    run_suite(test_cases)

TEST: expand_as

SUCCESS: both return exception
SUCCESS: both return exception
SUCCESS: both return exception
SUCCESS: both return the same tensor
SUCCESS: both return the same tensor
SUCCESS: both return exception
SUCCESS: both return the same tensor
SUCCESS: both return exception
SUCCESS: both return exception
SUCCESS: both return the same tensor
SUCCESS: both return the same tensor


TEST: broadcastible_together

SUCCESS: both return the same result
SUCCESS: both return the same result
SUCCESS: both return the same bool (FALSE)
SUCCESS: both return the same result
SUCCESS: both return the same result
SUCCESS: both return the same result
SUCCESS: both return the same result
SUCCESS: both return the same result
SUCCESS: both return the same result
SUCCESS: both return the same result
SUCCESS: both return the same result


TEST: broadcast_together

SUCCESS: both return the same result
SUCCESS: both return the same result
SUCCESS: both return exception
SUCCESS: both return the same resu