In [None]:
# Environment sanity check: print the nvidia-smi report for the attached GPU(s)
# and the version banner of the nvcc compiler that later cells invoke.
!nvidia-smi
!nvcc --version

In [None]:
# Fetch the CUTLASS sources with a shallow (depth 1) clone into ./cutlass.
# The directory test makes the cell re-runnable: without it, a second run fails
# because git refuses to clone into an existing, non-empty directory.
!test -d cutlass || git clone --depth=1 https://github.com/NVIDIA/cutlass.git

In [None]:
%%bash
set -euo pipefail

# Pick the nvcc -arch value for the GPU in this runtime and store it in sm.env
# (as SM=<number>) so that a later cell can source it.

# Name of the first GPU as reported by the driver.
gpu_name="$(nvidia-smi --query-gpu=name --format=csv,noheader | head -n1)"

# Default to a safe-ish arch; override below when we recognize the GPU.
sm=70

# Patterns are tried top to bottom and the first match wins, so "A100" has to
# stay ahead of the shorter "A10" pattern, which would otherwise match it too.
# The glob stars must sit OUTSIDE the quotes: a quoted "*" is a literal asterisk.
case "$gpu_name" in
  *"A100"*) sm=80 ;;
  *"H100"*) sm=90 ;;   # Hopper
  *"L4"*)   sm=89 ;;
  *"A10"*)  sm=86 ;;
  *"RTX 30"*|*"A40"*) sm=86 ;;   # either substring selects sm_86
  *"V100"*) sm=70 ;;
  *"T4"*)   sm=75 ;;
  *"P100"*) sm=60 ;;
esac

echo "Detected GPU: $gpu_name"
echo "Using -arch=sm_${sm}"

# Save for later cells
echo "SM=${sm}" > sm.env


In [4]:
%%bash
set -euo pipefail

# Write a small CUTLASS FP32 GEMM program to cutlass_gemm.cu, then compile it
# with nvcc for the architecture recorded in sm.env by the detection cell.
# The heredoc delimiter is quoted ('CU') so the shell copies the C++ verbatim.
cat > cutlass_gemm.cu <<'CU'
#include <algorithm>   // std::fill
#include <cstdio>
#include <vector>
#include <random>
#include <cuda_runtime.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/layout/matrix.h"

// Evaluate a CUDA runtime call; on failure print the error string with the
// file/line and make the enclosing function return -1. Only usable inside a
// function returning int. Buffers allocated earlier are not freed on this path.
#define CHECK_CUDA(call) do { \
  cudaError_t status = (call); \
  if (status != cudaSuccess) { \
    printf("CUDA Error %s at %s:%d\n", cudaGetErrorString(status), __FILE__, __LINE__); \
    return -1; \
  } \
} while(0)

// Runs one 1024x1024x1024 single-precision GEMM through CUTLASS and prints the
// sum of the first ten output elements. Returns 0 on success, -1 on any CUDA or
// CUTLASS failure.
int main() {
  using Element = float;
  using Layout = cutlass::layout::RowMajor;

  // GEMM: C[M,N] = alpha * A[M,K] * B[K,N] + beta * C[M,N]
  int M = 1024, N = 1024, K = 1024;
  Element alpha = 1.0f, beta = 0.0f;

  // Host buffers
  std::vector<Element> hA(M*K), hB(K*N), hC(M*N);

  // Fill A and B with random values (fixed seed, so every run uses the same inputs)
  std::mt19937 rng(123);
  std::uniform_real_distribution<float> dist(-1.f, 1.f);
  for (auto &x : hA) x = dist(rng);
  for (auto &x : hB) x = dist(rng);
  std::fill(hC.begin(), hC.end(), 0.0f);

  // Device buffers
  Element *dA=nullptr, *dB=nullptr, *dC=nullptr;
  CHECK_CUDA(cudaMalloc((void**)&dA, sizeof(Element)*hA.size()));
  CHECK_CUDA(cudaMalloc((void**)&dB, sizeof(Element)*hB.size()));
  CHECK_CUDA(cudaMalloc((void**)&dC, sizeof(Element)*hC.size()));

  CHECK_CUDA(cudaMemcpy(dA, hA.data(), sizeof(Element)*hA.size(), cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(dB, hB.data(), sizeof(Element)*hB.size(), cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(dC, hC.data(), sizeof(Element)*hC.size(), cudaMemcpyHostToDevice));

  // CUTLASS GEMM type (FP32, row-major); every other template parameter is
  // left at the library default.
  using Gemm = cutlass::gemm::device::Gemm<
      Element, Layout,   // A
      Element, Layout,   // B
      Element, Layout    // C/D
  >;

  Gemm gemm_op;

  // With row-major storage the leading dimension is the column count of each
  // matrix: K for A[M,K], N for B[K,N] and for C/D[M,N]. C and D share dC, so
  // the result is written in place over the zero-filled C.
  Gemm::Arguments args(
      {M, N, K},                 // Problem size
      {dA, K},                   // A ptr, lda
      {dB, N},                   // B ptr, ldb
      {dC, N},                   // C ptr, ldc
      {dC, N},                   // D ptr, ldd (output)
      {alpha, beta}
  );

  // Optional: workspace (some kernels require it)
  size_t workspace_size = Gemm::get_workspace_size(args);
  void* workspace = nullptr;
  if (workspace_size) CHECK_CUDA(cudaMalloc(&workspace, workspace_size));

  cutlass::Status status = gemm_op.can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    printf("GEMM configuration not supported.\n");
    return -1;
  }

  status = gemm_op.initialize(args, workspace);
  if (status != cutlass::Status::kSuccess) {
    printf("GEMM initialize failed.\n");
    return -1;
  }

  status = gemm_op();
  if (status != cutlass::Status::kSuccess) {
    printf("GEMM run failed.\n");
    return -1;
  }

  // Block until the device is idle so that a failure while the GEMM executes
  // is reported at this line rather than by the copy below.
  CHECK_CUDA(cudaDeviceSynchronize());

  CHECK_CUDA(cudaMemcpy(hC.data(), dC, sizeof(Element)*hC.size(), cudaMemcpyDeviceToHost));

  // Simple checksum so we know it worked (smoke test only; the result is not
  // compared against a reference GEMM)
  double checksum = 0.0;
  for (int i = 0; i < 10; ++i) checksum += hC[i];
  printf("Done. C[0..9] sum = %.6f\n", checksum);

  if (workspace) cudaFree(workspace);
  cudaFree(dA); cudaFree(dB); cudaFree(dC);
  return 0;
}
CU

# Read saved SM from earlier step. The explicit ./ keeps `source` from searching
# PATH for a file named sm.env before looking in the current directory.
source ./sm.env

# Common NVCC flags for CUTLASS on Colab; allow-unsupported handles frequent host-compiler versions.
# Kept in an array so each flag reaches nvcc as its own argument without
# depending on unquoted word splitting.
NVCCFLAGS=(
  -O3
  -std=c++17
  --expt-relaxed-constexpr
  --expt-extended-lambda
  --use_fast_math
  -Wno-deprecated-declarations
  --allow-unsupported-compiler
)

# Build: include CUTLASS headers
nvcc cutlass_gemm.cu -I cutlass/include "-arch=sm_${SM}" "${NVCCFLAGS[@]}" -o cutlass_gemm


In [None]:
# Run the binary produced by the build cell; on success it prints
# "Done. C[0..9] sum = ..." and on failure one of its error messages.
!./cutlass_gemm
