In [None]:
import torch
import mat_mat_ops

# Sum gradients over broadcasted dimensions

In [None]:

def sum_over_broadcasted_dims(grad_output, original_shape, broadcasted_shape):
    """
    Reduce a gradient computed in a broadcasted shape back to the pre-broadcast shape.

    This is the backward of broadcasting: every dimension that was expanded
    (size 1 in the original, size != 1 after broadcasting) is summed out. This
    includes the leading dimensions the original tensor did not have at all,
    which are treated as size 1.

    Args:
        grad_output (Tensor): Gradient whose shape is ``broadcasted_shape``.
        original_shape (Sequence[int]): Shape of the tensor before broadcasting.
        broadcasted_shape (Sequence[int]): Shape after broadcasting.

    Returns:
        Tensor: The gradient summed over the broadcasted dimensions and reshaped
        to ``original_shape``.

    Raises:
        ValueError: If ``original_shape`` has more dimensions than
            ``broadcasted_shape`` (it cannot have been broadcast to it).
    """
    # Normalize to tuples so lists / torch.Size can be concatenated below.
    original_shape = tuple(original_shape)
    broadcasted_shape = tuple(broadcasted_shape)
    original_dims = len(original_shape)
    broadcast_dims = len(broadcasted_shape)

    if original_dims > broadcast_dims:
        raise ValueError(
            f"original_shape {original_shape} has more dimensions than "
            f"broadcasted_shape {broadcasted_shape}"
        )

    # Left-pad with 1s so both shapes align from the right (broadcasting rule).
    # Always computed: previously it was only assigned when the original had
    # fewer dims, so same-rank inputs hit an UnboundLocalError.
    original_shape_prepend = (1,) * (broadcast_dims - original_dims) + original_shape

    # A dimension was broadcast if it is 1 in the (padded) original but not after broadcasting.
    broadcasted_dims = tuple(
        i
        for i in range(broadcast_dims)
        if original_shape_prepend[i] == 1 and broadcasted_shape[i] != 1
    )

    if broadcasted_dims:
        grad_output = torch.sum(grad_output, dim=broadcasted_dims, keepdim=False)

    # reshape (rather than view) restores the size-1 dims and tolerates any stride layout.
    return grad_output.reshape(original_shape)

In [None]:
# Example: a (1, 1000, 100) tensor broadcast up to (2, 2, 1000, 100).
broadcasted_shape = (2, 2, 1000, 100)
original_shape = (1, 1000, 100)
grad_output = torch.randn(broadcasted_shape)

# Reduce the broadcasted gradient back to the original (pre-broadcast) shape.
grad_input = sum_over_broadcasted_dims(grad_output, original_shape, broadcasted_shape)
print(f"Original Shape: {original_shape}")
print(f"Broadcasted Shape: {broadcasted_shape}")
print(f"Grad Output Shape: {grad_output.shape}")
print(f"Summed Grad Input Shape: {grad_input.shape}")

# Only the two prepended dims (0 and 1) were broadcast, so exactly they must be summed out.
assert grad_input.shape == original_shape
torch.testing.assert_close(grad_input, grad_output.sum(dim=(0, 1)).reshape(original_shape))

# Dimension Check

In [None]:
def mat_mat_ops_dimension_check(tensor1, tensor2):
    """
    Validate the operands of a (batched) matrix-matrix op and derive its shapes.

    Both operands must be at least 2-D. The last two dims are the matrix dims,
    ``(..., M, K)`` and ``(..., K, N)``. Any leading dims are batch dims, which
    must be broadcast-compatible.

    Args:
        tensor1 (Tensor): Left operand, shape ``(*batch1, M, K)``.
        tensor2 (Tensor): Right operand, shape ``(*batch2, K, N)``.

    Returns:
        tuple: ``(tensor1_broadcasted_shape, tensor2_broadcasted_shape, out_shape)``.

        - Two plain 2-D inputs: each broadcasted shape gets a leading batch dim
          of 1, and ``out_shape`` is ``(M, N)``.
        - Otherwise: the broadcasted shapes are ``(*batch, M, K)`` and
          ``(*batch, K, N)``, and ``out_shape`` is ``(*batch, M, N)``.

    Raises:
        ValueError: If either input has fewer than 2 dims, the batch dims are
            not broadcastable, or the inner matrix dims disagree.
    """
    rank1, rank2 = tensor1.dim(), tensor2.dim()

    # Guard: anything below 2-D is not a matrix operand.
    if rank1 < 2 or rank2 < 2:
        raise ValueError("Invalid dimensions for batch_mat_mat_ops.")

    rows, cols = tensor1.size(-2), tensor2.size(-1)

    if rank1 == 2 and rank2 == 2:
        # Plain matrix product: prepend a batch dim of 1 so that
        # flatten(0, -3) yields a batch of one.
        if tensor1.size(1) != tensor2.size(0):
            raise ValueError(f"Size mismatch: {tensor1.size()} vs {tensor2.size()}")
        return (1,) + tensor1.shape, (1,) + tensor2.shape, (rows, cols)

    # Batched product: at least one operand has batch dims; they must broadcast.
    try:
        batch_shape = torch.broadcast_shapes(tensor1.shape[:-2], tensor2.shape[:-2])
    except RuntimeError as err:
        raise ValueError(f"Batch dimensions are not broadcastable: {tensor1.shape[:-2]} vs {tensor2.shape[:-2]}") from err

    if tensor1.size(-1) != tensor2.size(-2):
        raise ValueError(f"Size mismatch in matrix dimensions: {tensor1.size()} vs {tensor2.size()}")

    return (
        batch_shape + tensor1.shape[-2:],
        batch_shape + tensor2.shape[-2:],
        (*batch_shape, rows, cols),
    )


In [None]:
# Test cases: each call returns (tensor1_broadcasted_shape, tensor2_broadcasted_shape, out_shape).

# 2-D x 2-D: plain matrix product; a batch dim of 1 is prepended to both operands.
t1 = torch.empty(3, 4)
t2 = torch.empty(4, 5)
result = mat_mat_ops_dimension_check(t1, t2)
print(result)
assert tuple(result[2]) == (3, 5)

# 5-D x 2-D: batched product; the 2-D operand is broadcast over batch dims (5, 4, 1).
t1 = torch.empty(5, 4, 1, 3, 4)
t2 = torch.empty(4, 6)
result = mat_mat_ops_dimension_check(t1, t2)
print(result)
assert tuple(result[2]) == (5, 4, 1, 3, 6)

# 5-D x 3-D: batch dims (5, 4, 1) and (5,) broadcast to (5, 4, 5).
t1 = torch.empty(5, 4, 1, 3, 4)
t2 = torch.empty(5, 4, 6)
result = mat_mat_ops_dimension_check(t1, t2)
print(result)
assert tuple(result[2]) == (5, 4, 5, 3, 6)

# Batch dims (5,) and (2,) are not broadcastable -> must raise ValueError.
t1 = torch.empty(5, 3, 4)
t2 = torch.empty(2, 4, 6)
try:
    mat_mat_ops_dimension_check(t1, t2)
except ValueError as e:
    print(e)
else:
    raise AssertionError("expected a ValueError for non-broadcastable batch dims")

# Tests: custom ops vs. PyTorch reference implementations

In [None]:
# Shared configuration for every test below.
SEED = 0
torch.manual_seed(SEED)  # make the torch.rand inputs (and thus tolerance checks) reproducible
device = 'cpu'
requires_grad = True

In [None]:
# Uniform [0, 1) operands: a is (M, K) = (8, 64), b is (K, N) = (64, 512).
a = torch.rand(8, 64, device=device, requires_grad=requires_grad)
b = torch.rand(64, 512, device=device, requires_grad=requires_grad)

In [None]:
# torch.rand samples U[0, 1): minima are >= 0 and maxima strictly < 1 (both close to the bounds).
assert a.min() >= 0 and b.min() >= 0 and a.max() < 1 and b.max() < 1
a.max(), b.max(), a.min(), b.min()

# MatMul

In [None]:
def reference_mat_mat_mul(a, b):
    """
    Ground-truth matrix product used to validate ``mat_mat_mul``.

    Delegates to ``torch.matmul`` via the ``@`` operator, so it follows
    matmul's broadcasting rules and is differentiable through autograd.
    """
    return a @ b

In [None]:
def mat_mat_mul(a, b):
    """
    Matrix product computed by the compiled ``mat_mat_ops`` extension.

    Thin wrapper around ``mat_mat_ops.ops.mat_mat_mul``. No shape, dtype or
    device validation is done here; the extension receives the inputs
    unchanged. NOTE(review): confirm which batch/broadcast shapes the kernel
    accepts. The cells below compare it with ``reference_mat_mat_mul``
    (``torch.matmul``), both forward and backward.
    """
    matmul_op = mat_mat_ops.ops.mat_mat_mul
    return matmul_op(a, b)

In [None]:
# Forward check: custom matmul vs torch.matmul with default float32 tolerances.
out = mat_mat_mul(a, b)
expected_out = reference_mat_mat_mul(a, b)
torch.testing.assert_close(actual=out, expected=expected_out)

In [None]:
# Shapes of the custom-op and reference outputs, side by side.
(out.shape, expected_out.shape)

In [None]:
# First element of each output, for a quick visual comparison.
(out[0, 0], expected_out[0, 0])

In [None]:
grad_out = torch.rand_like(out)  # random upstream gradient shared by both graphs

# Vector-Jacobian products w.r.t. a and b through the custom op and through torch.matmul.
my_a_b_grad = torch.autograd.grad(outputs=out, inputs=[a, b], grad_outputs=grad_out)
expected_a_b_grad = torch.autograd.grad(outputs=expected_out, inputs=[a, b], grad_outputs=grad_out)

# Explicit (looser than default) tolerances for the gradient comparison.
torch.testing.assert_close(my_a_b_grad, expected_a_b_grad, atol=5e-5, rtol=2e-4)

In [None]:
# Shapes of the gradient w.r.t. a: custom op vs reference.
(my_a_b_grad[0].shape, expected_a_b_grad[0].shape)

In [None]:
# Shapes of the gradient w.r.t. b: custom op vs reference.
(my_a_b_grad[1].shape, expected_a_b_grad[1].shape)

In [None]:
# First element of the gradient w.r.t. a: custom op vs reference.
(my_a_b_grad[0][0, 0], expected_a_b_grad[0][0, 0])

# Mat-Mat L1

In [None]:
def reference_mat_mat_l1(a, b):
    """
    Pure-PyTorch reference for the matrix-matrix L1 op.

    For operands broadcast to ``(*batch, M, K)`` and ``(*batch, K, N)`` it computes

        out[..., m, n] = sum_k |a[..., m, k] - b[..., k, n]| / K

    That is, a matmul-shaped op where multiply-accumulate is replaced by the
    mean absolute difference. It is built from differentiable torch ops, so
    autograd yields gradients to compare a custom kernel against.

    Args:
        a (Tensor): float32 tensor of shape ``(*batch_a, M, K)``.
        b (Tensor): float32 tensor of shape ``(*batch_b, K, N)`` on the same
            device as ``a``; its batch dims must broadcast with ``a``'s.

    Returns:
        Tensor: float32 tensor of shape ``out_shape`` from
        ``mat_mat_ops_dimension_check``, allocated on ``a.device``. This is
        ``(M, N)`` for two 2-D inputs and ``(*batch, M, N)`` otherwise.

    Raises:
        ValueError: From ``mat_mat_ops_dimension_check`` for incompatible shapes.
        RuntimeError: Via ``torch._check`` for a dtype or device mismatch.
    """
    a_broadcast_shape, b_broadcast_shape, out_shape = mat_mat_ops_dimension_check(a, b)

    # torch._check requires `message` to be a *callable* returning the text.
    # Passing a plain string makes a failing check raise
    # "TypeError: message must be a callable" and hides the real error.
    torch._check(a.shape[-1] == b.shape[-2], message=lambda: f'For Matrix-Matrix operation satisfy : a.shape[-1]==b.shape[-2], but {a.shape[-1]=} and {b.shape[-2]=}')
    torch._check(a.dtype == torch.float, message=lambda: f'tensor "a" must be float, but {a.dtype=}')
    torch._check(b.dtype == torch.float, message=lambda: f'tensor "b" must be float, but {b.dtype=}')
    torch._check(a.device == b.device, message=lambda: f'tensors "a" and "b" must be on same device, but {a.device=} and {b.device=}')

    # Expand both operands to the common batch shape, then collapse all batch dims into one axis.
    mat1 = a.expand(a_broadcast_shape).flatten(start_dim=0, end_dim=-3)  # (B, M, K)
    mat2 = b.expand(b_broadcast_shape).flatten(start_dim=0, end_dim=-3)  # (B, K, N)

    B, M, K = mat1.shape
    N = mat2.shape[-1]
    mat2_tr = torch.transpose(input=mat2, dim0=1, dim1=2)  # (B, N, K): row n is column n of mat2

    # Allocate on the inputs' device rather than the notebook-global `device`.
    my_out = torch.empty(size=(B, M, N), device=a.device, dtype=torch.float)
    for bs in range(B):
        for m in range(M):
            # Row m of a (K,) against all N columns of b (N, K) -> mean |diff| over K -> (N,)
            my_out[bs, m, :] = torch.abs(mat1[bs, m, :] - mat2_tr[bs, :, :]).sum(dim=-1) / K

    return my_out.reshape(out_shape)

In [None]:
def mat_mat_l1(a, b):
    """
    Matrix-matrix L1 op computed by the compiled ``mat_mat_ops`` extension.

    Thin wrapper around ``mat_mat_ops.ops.mat_mat_l1``; inputs are passed
    through unchanged. The cells below check its output against
    ``reference_mat_mat_l1`` (mean absolute difference over K).
    """
    l1_op = mat_mat_ops.ops.mat_mat_l1
    return l1_op(a, b)

In [None]:
out = mat_mat_l1(a, b)
# out[m, n] = mean_k |a[m, k] - b[k, n]| (the reference divides by K); with a, b in [0, 1)
# every |difference| is < 1, so every output entry lies in [0, 1), not [0, K].
print(out.max(), out.min())
expected_out = reference_mat_mat_l1(a, b)
print(out.shape, expected_out.shape)
torch.testing.assert_close(out, expected_out, atol=5e-5, rtol=2e-4)
assert out.min() >= 0 and out.max() <= 1

In [None]:
# First element of the L1 output: custom op vs reference.
(out[0, 0], expected_out[0, 0])

In [None]:
# Matrix sizes: a is (..., M, K), b is (..., K, N).
M, K = a.shape[-2], a.shape[-1]
N = b.shape[-1]

grad_out = torch.rand_like(out)  # random upstream gradient shared by both graphs

my_a_b_grad = torch.autograd.grad(outputs=out, inputs=[a, b], grad_outputs=grad_out)
expected_a_b_grad = torch.autograd.grad(outputs=expected_out, inputs=[a, b], grad_outputs=grad_out)

# NOTE(review): the custom gradients are checked against the autograd gradients of the
# reference rescaled by K/N (for a) and K/M (for b). Presumably the custom backward
# applies this normalization deliberately; confirm against the mat_mat_l1 kernel.
torch.testing.assert_close(my_a_b_grad[0], expected_a_b_grad[0] * K / N, atol=5e-5, rtol=2e-4)
torch.testing.assert_close(my_a_b_grad[1], expected_a_b_grad[1] * K / M, atol=5e-5, rtol=2e-4)

In [None]:
# First element of the gradient w.r.t. a (unscaled): custom op vs reference.
(my_a_b_grad[0][0, 0], expected_a_b_grad[0][0, 0])

### Distributions of inputs, outputs, and gradients

In [None]:
# Plotting libraries for the distribution histograms in the next cell.
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's "dark" theme to subsequent matplotlib figures.
sns.set_theme(style="dark")
# Render figures inline in the notebook output.
%matplotlib inline

In [None]:
# Histograms (normalized to densities) of the L1-op inputs, output, upstream gradient
# and the custom op's gradients w.r.t. a and b.
panels = [
    ("a", a),
    ("b", b),
    ("out = mat_mat_l1(a, b)", out),
    ("grad_out (upstream gradient)", grad_out),
    ("grad w.r.t. a (custom op)", my_a_b_grad[0]),
    ("grad w.r.t. b (custom op)", my_a_b_grad[1]),
]
fig, axes = plt.subplots(2, 3, figsize=(20, 10))
for ax, (title, values) in zip(axes.flat, panels):
    # reshape (not view) so non-contiguous op outputs also work; detach before converting.
    ax.hist(values.detach().reshape(-1).tolist(), bins=50, density=True)
    ax.set(title=title, xlabel="value", ylabel="density")
fig.tight_layout()