In [0]:
import numpy as np
import torch

In [0]:
# Problem size: C independent products of an (M x K) matrix by a (K x N) matrix.
C = 32            # number of matrix pairs in the batch
M = N = K = 128   # rows of A, columns of B, shared inner dimension

In [0]:
# Random operands of shape (C, M, K) and (C, K, N), plus a (C, M, N) buffer.
# matrix_c starts out random and is overwritten by the CPU matmul variants.
matrix_a = np.random.standard_normal(size=(C, M, K))
matrix_b = np.random.standard_normal(size=(C, K, N))
matrix_c = np.random.standard_normal(size=(C, M, N))

In [0]:
def matmul(m_a, m_b, m_c):
  """Batched matrix product written as explicit Python loops (slow baseline).

  For every batch index c, stores m_a[c] @ m_b[c] into m_c[c] in place.

  Args:
    m_a: 3-D array of shape (C, M, K).
    m_b: 3-D array of shape (C, K, N).
    m_c: 3-D output buffer of shape (C, M, N); its contents are overwritten.

  Returns:
    None. The result is written into m_c.

  Raises:
    ValueError: if the batch or inner dimensions of m_a and m_b disagree.
  """
  # Sizes come from the operands themselves, not from notebook-level globals,
  # so the function stays correct when called with differently sized inputs.
  n_batch, n_rows, n_inner = m_a.shape
  n_cols = m_b.shape[2]
  if tuple(m_b.shape[:2]) != (n_batch, n_inner):
    raise ValueError(
        "shape mismatch: m_a is %s but m_b is %s" % (tuple(m_a.shape), tuple(m_b.shape)))

  for c in range(n_batch):
    for m in range(n_rows):
      for n in range(n_cols):
        # Dot product of row m of m_a[c] with column n of m_b[c].
        val = 0.
        for k in range(n_inner):
          val += m_a[c,m,k] * m_b[c,k,n]
        m_c[c,m,n] = val

In [6]:
# Time the pure-Python loop baseline; matmul writes its result into matrix_c in place.
%time matmul(matrix_a, matrix_b, matrix_c)

CPU times: user 35.1 s, sys: 9.6 ms, total: 35.1 s
Wall time: 35.1 s


In [7]:
# Sanity check: the buffer filled by matmul should agree with NumPy's batched `@`.
np.allclose(matrix_c, (matrix_a @ matrix_b))

True

In [0]:
def matmul_boardcast(m_a, m_b, m_c):
  """Batched matrix product with the innermost loop replaced by array ops.

  Each output element is the sum of the elementwise product of one row of
  m_a[c] and one column of m_b[c]. The function name is kept exactly as it is
  so that existing calls keep working.

  Args:
    m_a: 3-D array of shape (C, M, K).
    m_b: 3-D array of shape (C, K, N).
    m_c: 3-D output buffer of shape (C, M, N); its contents are overwritten.

  Returns:
    None. The result is written into m_c.

  Raises:
    ValueError: if the batch or inner dimensions of m_a and m_b disagree.
  """
  # Sizes come from the operands themselves, not from notebook-level globals,
  # so the function stays correct when called with differently sized inputs.
  n_batch, n_rows, n_inner = m_a.shape
  n_cols = m_b.shape[2]
  if tuple(m_b.shape[:2]) != (n_batch, n_inner):
    raise ValueError(
        "shape mismatch: m_a is %s but m_b is %s" % (tuple(m_a.shape), tuple(m_b.shape)))

  for c in range(n_batch):
    for m in range(n_rows):
      for n in range(n_cols):
        # Row m of m_a[c] times column n of m_b[c], reduced over the inner axis.
        m_c[c,m,n] = (m_a[c,m,:] * m_b[c,:,n]).sum()

In [9]:
# Time the variant whose inner dot product is done with array ops; it also writes into matrix_c in place.
%time matmul_boardcast(matrix_a, matrix_b, matrix_c)

CPU times: user 2.14 s, sys: 78 ms, total: 2.22 s
Wall time: 2.14 s


In [10]:
# Sanity check: the buffer filled by matmul_boardcast should agree with NumPy's batched `@`.
np.allclose(matrix_c, (matrix_a @ matrix_b))

True

In [0]:
!pip install pycuda

In [0]:
import pycuda.autoinit
import pycuda.driver as drv
from pycuda import gpuarray
from pycuda.compiler import SourceModule

In [0]:
kernel = SourceModule("""
// Batched matrix product: for each channel c, mat_c[c] = mat_a[c] * mat_b[c].
// All buffers are dense row-major: mat_a is C x M x K, mat_b is C x K x N and
// mat_c is C x M x N. One thread computes one output element.
__global__ void matmul(double *mat_a, double *mat_b, double *mat_c, int C, int M, int N, int K)
{
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  int channel = blockIdx.z * blockDim.z + threadIdx.z;

  // The launch grid is rounded up to whole blocks, so surplus threads must do nothing.
  if (channel < C && row < M && col < N) {
    // Each operand is addressed with its own strides (row length K for mat_a,
    // N for mat_b and mat_c); reusing M*N / N for all three is only valid when
    // M == N == K. size_t keeps large offsets from overflowing int.
    size_t a_row = ((size_t)channel * M + row) * K;
    size_t b_mat = (size_t)channel * K * N;
    size_t c_idx = ((size_t)channel * M + row) * N + col;

    double val = 0;
    for (int k = 0; k < K; k++)
      val += mat_a[a_row + k] * mat_b[b_mat + (size_t)k * N + col];
    mat_c[c_idx] = val;
  }
}
""")

def matmul_gpu(m_a, m_b, m_c):
  """Batched matrix product on the GPU using the `matmul` CUDA kernel.

  Args:
    m_a: array of shape (C, M, K).
    m_b: array of shape (C, K, N).
    m_c: not read; kept so the signature matches the CPU variants. The
      kernel writes every output element, so no initial contents are needed.

  Returns:
    A new float64 NumPy array of shape (C, M, N). m_c is left untouched.

  Raises:
    ValueError: if the operands are not 3-D or their shapes are incompatible.
  """
  # The kernel reads raw doubles, so force float64 and a dense C layout
  # regardless of what the caller passes in.
  host_a = np.ascontiguousarray(m_a, dtype=np.float64)
  host_b = np.ascontiguousarray(m_b, dtype=np.float64)
  if host_a.ndim != 3 or host_b.ndim != 3:
    raise ValueError("matmul_gpu expects 3-D operands")

  # Sizes come from the operands rather than from notebook-level globals.
  n_batch, n_rows, n_inner = host_a.shape
  n_cols = host_b.shape[2]
  if host_b.shape[:2] != (n_batch, n_inner):
    raise ValueError(
        "shape mismatch: m_a is %s but m_b is %s" % (host_a.shape, host_b.shape))

  if 0 in (n_batch, n_rows, n_cols, n_inner):
    # Nothing to launch (a grid dimension of zero is not a valid launch);
    # an empty product is all zeros.
    return np.zeros((n_batch, n_rows, n_cols))

  dev_a = gpuarray.to_gpu(host_a.reshape(-1))
  dev_b = gpuarray.to_gpu(host_b.reshape(-1))
  # Output only: allocate on the device instead of uploading m_c.
  dev_c = gpuarray.empty((n_batch, n_rows, n_cols), np.float64)

  # Round the grid up so sizes that are not multiples of the tile are fully
  # covered; the kernel's bounds check masks the overhang.
  tile = 32
  grid = ((n_cols + tile - 1) // tile, (n_rows + tile - 1) // tile, n_batch)

  matmul_cuda = kernel.get_function("matmul")
  matmul_cuda(dev_a, dev_b, dev_c,
              np.int32(n_batch), np.int32(n_rows), np.int32(n_cols), np.int32(n_inner),
              block=(tile, tile, 1), grid=grid)
  return dev_c.get()

In [14]:
# Time the GPU path; the measurement includes the host<->device copies done inside matmul_gpu.
%time c = matmul_gpu(matrix_a, matrix_b, matrix_c)

CPU times: user 7.18 ms, sys: 2 ms, total: 9.18 ms
Wall time: 10.3 ms


In [15]:
# Sanity check: the array returned by matmul_gpu should agree with NumPy's batched `@`.
np.allclose(c, (matrix_a @ matrix_b))

True

In [0]:
# Fresh random operands as torch tensors.
# NOTE: this rebinds matrix_a / matrix_b, which held NumPy arrays in the cells above.
matrix_a, matrix_b = torch.randn(C, M, K), torch.randn(C, K, N)

In [24]:
# Time torch's built-in batched matmul; matrix_c is rebound to the resulting tensor.
%time matrix_c = matrix_a.matmul(matrix_b)

CPU times: user 3.38 ms, sys: 0 ns, total: 3.38 ms
Wall time: 2.62 ms


In [25]:
# Cross-check torch's result against an independent float64 NumPy reference.
# Comparing it with `matrix_a @ matrix_b` would only compare torch with itself.
# atol is loosened because the tensors use torch's default dtype (float32
# unless it has been reconfigured).
reference = matrix_a.double().numpy() @ matrix_b.double().numpy()
np.allclose(matrix_c.numpy(), reference, atol=1e-4)

True