In [1]:
import torch
import torch.nn as nn

device = "mps" if torch.backends.mps.is_available() else "cpu"

In [2]:
'''
Basic arithmetic with broadcasting.
Create a tensor A of shape (3, 1) and a tensor B of shape (1, 4). 
Perform element-wise addition and element-wise multiplication on A and B. 
Explain how broadcasting allows this operation to be performed.
'''

A = torch.rand(size=(3,1)).to(device)
B = torch.rand(size=(1,4)).to(device)

n = A.shape[0]
m = B.shape[1]

C = torch.zeros(size=(n,m)).to(device)
for i in range(n):
    for j in range(m):
        C[i][j] = A[i][0] + B[0][j]

print(C)


tensor([[1.2390, 0.8470, 0.6567, 0.6765],
        [1.7383, 1.3463, 1.1559, 1.1758],
        [0.8531, 0.4611, 0.2707, 0.2906]], device='mps:0')


In [5]:
C_= torch.broadcast_shapes((3,1),(1,4)) # C gets broadcasted shape of (3,4)
C_ = A+B
assert(torch.sum(C_ == C) == C.shape[0]*C.shape[1]) 
# Number of equal elements in both the tensors should be equal to the tot number of tensors

### Broadcasting:
- In linear algebra, addition, subtraction, multiplication, and division of matrices (or tensors) require them to have the exact same shape (element-wise operations).
- Broadcasting is a mechanism that relaxes this constraint. When two tensors have different shapes, PyTorch (and NumPy) attempts to align them by "stretching" the smaller tensor along its dimension(s) so that the resulting shapes are compatible. 



### Rule of Broadcasting:
- Each tensor has at least one dimension.
- When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist.
- If a dimension is missing, always pad in the leading dimensions. Eg. (4,) -> (1,4) 

Eg. A = (3,1,2), B = (2,)
These 2 are broadcastable because:

```
A = (3,1,2)
B = (1,1,2)
```

And hence the result of the broadcasted tensor is (3,1,2)

Eg. 

```
A = (4, 1, 6, 1)
B = (1, 5, 1, 8)
```

C = A + B <br>
C = (4,5,6,8)

In [12]:
# Write a function `is_broadcastable(shape1, shape2)` 
# that returns True/False if the two shapes can be broadcast together
# according to PyTorch rules.
#
# Example:
# is_broadcastable((4,1,3), (1,5,1)) ➜ True
# is_broadcastable((4,2), (3,)) ➜ False

def is_broadcastable(A: torch.Tensor, B:torch.tensor)-> bool:
    shape1 = A.shape
    shape2 = B.shape

    if len(shape1) > len(shape2):
        while(len(shape1) != len(shape2)):
            B = torch.unsqueeze(B, dim=0)
            shape2 = B.shape
    else:
        while(len(shape1) != len(shape2)):
            A = torch.unsqueeze(A, dim=0)
            shape1 = A.shape
    
    assert(len(shape1) == len(shape2))
    print(f"Shape of A : {A.shape} \nShape of B : {B.shape}")
    
    n = len(shape1)
    
    for i in range(n):
        if ((shape1[n-i-1] != shape2[n-i-2]) and (shape1[n-i-1] != 1) and (shape2[n-i-1] != 1)):
            return False
    
    return True


In [13]:
A = torch.rand(size=(4,1,3))
B = torch.rand(size=(1,5,1))
is_broadcastable(A,B)

Shape of A : torch.Size([4, 1, 3]) 
Shape of B : torch.Size([1, 5, 1])


True

In [17]:
# Create a tensor of shape (3,1)
# Expand it to (3,5) using .expand()
# Verify that .expand() doesn’t allocate new memory (use a.is_shared_storage(b))

a = torch.rand(size=(3,1))
b = torch.expand_copy(a,size=(3,5))
print(a.data_ptr())
print(b.data_ptr())

5485646976
5485955584
