Matrix Multiplication in PyTorch


In [1]:
## torch.mm(): matrix product of two 2-D tensors
import torch

# Two 2x2 integer matrices, written one row per line.
A = torch.tensor([
    [1, 2],
    [3, 4],
])
B = torch.tensor([
    [5, 6],
    [7, 8],
])

# Row-by-column product: C[i, j] = sum over k of A[i, k] * B[k, j]
C = torch.mm(A, B)
print(C)


tensor([[19, 22],
        [43, 50]])


In [2]:
## torch.matmul(): the general-purpose product, recommended for everyday use
A = torch.tensor([
    [1, 2],
    [3, 4],
])
B = torch.tensor([
    [5, 6],
    [7, 8],
])

# Both operands are 2-D here, so this is an ordinary matrix product.
C = torch.matmul(A, B)
print(C)


tensor([[19, 22],
        [43, 50]])


In [3]:
## torch.matmul() also multiplies batches of matrices held in higher-dimensional tensors

A = torch.rand((2, 3, 4))  # batch of 2 matrices, each 3x4
B = torch.rand((2, 4, 5))  # batch of 2 matrices, each 4x5

# Each 3x4 matrix in A is multiplied by the 4x5 matrix at the same batch index in B.
C = torch.matmul(A, B)
print(C.size())  # prints torch.Size([2, 3, 5])


torch.Size([2, 3, 5])


In [4]:
## @ operator (the Pythonic way)
# Uses the batched A (2, 3, 4) and B (2, 4, 5) tensors left by the previous cell,
# so that cell has to be run first.
C = A @ B  # Equivalent to torch.matmul(A, B)

# Confirm the equivalence and show the result instead of leaving the cell silent.
assert torch.allclose(C, torch.matmul(A, B))
print(C.shape)  # torch.Size([2, 3, 5])


In [5]:
## torch.bmm(): batch matrix multiplication of two 3-D tensors
A = torch.rand((10, 3, 4))  # 10 stacked 3x4 matrices
B = torch.rand((10, 4, 5))  # 10 stacked 4x5 matrices

# A[i] is multiplied by B[i] for every index i in the batch.
C = torch.bmm(A, B)
print(C.size())  # prints torch.Size([10, 3, 5])


torch.Size([10, 3, 5])


In [6]:
## Element-wise multiplication (torch.mul())
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])

# Multiplies matching entries, C[i, j] = A[i, j] * B[i, j]; this is NOT a matrix product.
C = torch.mul(A, B)  # The * operator is shorthand for the same thing: A * B
print(C)


tensor([[ 5, 12],
        [21, 32]])


In [7]:
## Transposing a matrix before multiplication
A = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)

# A cannot be multiplied by itself: (2, 3) x (2, 3) has mismatched inner dimensions.
# Transposing swaps rows and columns, which yields a compatible (3, 2) operand.
# Note that .T returns a view, so B shares its storage with A.
B = A.T  # [[1, 4], [2, 5], [3, 6]]

C = torch.mm(A, B)  # Shape (2,3) x (3,2) → (2,2)
print(C)


tensor([[14, 32],
        [32, 77]])


In [9]:
## Running the multiplication on a GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allocate directly on the target device; torch.rand(...).to(device) would first
# build each tensor on the CPU and then copy it across.
A = torch.rand(1000, 1000, device=device)
B = torch.rand(1000, 1000, device=device)

C = torch.matmul(A, B)  # Runs on the GPU if CUDA was found, otherwise on the CPU


In [10]:
##Implementation of Strassen’s Algorithm in Python
import torch

def _strassen_recursive(A, B):
    """Strassen recursion for already-validated 2^k x 2^k operands of equal shape."""
    n = A.shape[0]

    # Base case: a 1x1 matrix product is just a scalar multiplication
    if n == 1:
        return A * B

    # Divide each operand into four (n/2 x n/2) quadrants
    mid = n // 2

    A11, A12 = A[:mid, :mid], A[:mid, mid:]
    A21, A22 = A[mid:, :mid], A[mid:, mid:]
    B11, B12 = B[:mid, :mid], B[:mid, mid:]
    B21, B22 = B[mid:, :mid], B[mid:, mid:]

    # The seven Strassen products (a naive block product would need eight)
    M1 = _strassen_recursive(A11 + A22, B11 + B22)
    M2 = _strassen_recursive(A21 + A22, B11)
    M3 = _strassen_recursive(A11, B12 - B22)
    M4 = _strassen_recursive(A22, B21 - B11)
    M5 = _strassen_recursive(A11 + A12, B22)
    M6 = _strassen_recursive(A21 - A11, B11 + B12)
    M7 = _strassen_recursive(A12 - A22, B21 + B22)

    # Recombine the products into the four quadrants of the result
    C11 = M1 + M4 - M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 - M2 + M3 + M6

    # Concatenating (rather than filling a fresh torch.zeros buffer) keeps the
    # result on the operands' device and in the dtype the arithmetic produced.
    top = torch.cat((C11, C12), dim=1)
    bottom = torch.cat((C21, C22), dim=1)
    return torch.cat((top, bottom), dim=0)


def strassen_multiply(A, B):
    """
    Multiply two square matrices using Strassen's algorithm.

    The operands are split into quadrants and the product is assembled from
    seven recursive sub-products, down to 1x1 blocks.

    Args:
        A: 2-D tensor of shape (n, n), where n is a power of two (n = 2^k).
        B: 2-D tensor with the same shape as A.

    Returns:
        Tensor of shape (n, n) holding the matrix product of A and B. It lives
        on the same device as the inputs; when A and B have different dtypes
        the result uses the dtype their element-wise arithmetic produces
        instead of being cast back to A's dtype.

    Raises:
        ValueError: if A and B are not square 2-D tensors of identical shape,
            or if their size is not a power of two.
    """
    if A.dim() != 2 or A.shape[0] != A.shape[1] or A.shape != B.shape:
        raise ValueError(
            "strassen_multiply expects two square 2-D tensors of the same shape, "
            f"got {tuple(A.shape)} and {tuple(B.shape)}"
        )

    n = A.shape[0]
    # n & (n - 1) clears the lowest set bit, so it is zero only for powers of two.
    # Size 0 is rejected as well: halving it never reaches the 1x1 base case.
    if n == 0 or n & (n - 1):
        raise ValueError(
            f"strassen_multiply expects a matrix size that is a power of two, got {n}"
        )

    return _strassen_recursive(A, B)

# Demo: multiply two 2x2 float32 matrices with the Strassen implementation
A = torch.tensor(
    [[1, 2],
     [3, 4]],
    dtype=torch.float32,
)
B = torch.tensor(
    [[5, 6],
     [7, 8]],
    dtype=torch.float32,
)

# The printed result should match the ordinary matrix product of A and B.
C = strassen_multiply(A, B)
print(C)


tensor([[19., 22.],
        [43., 50.]])
