# constexpr编译期计算 - 实践篇

本notebook通过实际代码帮助你理解C++ constexpr的使用方法。

**学习目标：**
- 理解constexpr变量与函数的定义
- 掌握编译期数学计算
- 学会使用static_assert验证编译期结果
- 了解constexpr在模板中的应用


## 1. constexpr基础

constexpr变量和函数的基本用法。


In [None]:
%%writefile constexpr_basics.cpp
#include <iostream>
#include <array>

// ==================== constexpr变量 ====================

// Compile-time constants: an integer, a floating-point value and a built-in array.
constexpr int MAX_SIZE{1024};
constexpr double PI{3.14159265358979};
constexpr int PRIMES[]{2, 3, 5, 7, 11, 13, 17, 19};

// ==================== constexpr函数 ====================

// Minimal constexpr function: returns x multiplied by itself.
// Can be evaluated in a constant expression or called at run time.
constexpr int square(int x) {
    const int squared = x * x;
    return squared;
}

// C++14-style constexpr: local variables and loops are allowed in the body.
// Multiplies 2 * 3 * ... * n; any n below 2 yields 1.
constexpr int factorial(int n) {
    int product = 1;
    int factor = 2;
    while (factor <= n) {
        product *= factor;
        ++factor;
    }
    return product;
}

// Recursive constexpr function: F(n) = F(n-1) + F(n-2) for n > 1.
// Any n <= 1 is returned unchanged (so F(0) = 0 and F(1) = 1).
constexpr int fibonacci(int n) {
    return (n > 1) ? fibonacci(n - 1) + fibonacci(n - 2) : n;
}

// ==================== 编译期 vs 运行期 ====================

// Demo driver: prints the constexpr constants, evaluates the constexpr functions
// both inside constant expressions and with a run-time argument, and verifies
// the compile-time results with static_assert.
int main() {
    // Compile-time checks; they generate no code, so their position does not
    // influence the program output.
    static_assert(square(10) == 100, "square(10) should be 100");
    static_assert(factorial(5) == 120, "factorial(5) should be 120");
    static_assert(fibonacci(10) == 55, "fibonacci(10) should be 55");

    constexpr const char* kRule = "==========================================";
    std::cout << kRule << std::endl;
    std::cout << "        constexpr 基础演示" << std::endl;
    std::cout << kRule << "\n" << std::endl;

    // Compile-time constants
    std::cout << "【constexpr变量】" << std::endl;
    std::cout << "MAX_SIZE = " << MAX_SIZE << std::endl;
    std::cout << "PI = " << PI << std::endl;
    std::cout << "PRIMES[3] = " << PRIMES[3] << std::endl;

    // Initialising constexpr variables forces evaluation at compile time.
    std::cout << "\n【constexpr函数 - 编译期调用】" << std::endl;
    constexpr int squareOfTen = square(10);
    constexpr int factorialOfFive = factorial(5);
    constexpr int tenthFibonacci = fibonacci(10);

    std::cout << "square(10) = " << squareOfTen << std::endl;
    std::cout << "factorial(5) = " << factorialOfFive << std::endl;
    std::cout << "fibonacci(10) = " << tenthFibonacci << std::endl;

    // A constexpr value used as a built-in array bound
    std::cout << "\n【用于数组大小】" << std::endl;
    constexpr int arrSize = square(5);  // 25
    int arr[arrSize];                   // size fixed at compile time
    std::cout << "Array size: " << sizeof(arr) / sizeof(arr[0]) << std::endl;

    // ... and as a std::array template argument
    std::array<int, factorial(4)> arr2;  // 24 elements
    std::cout << "std::array size: " << arr2.size() << std::endl;

    // The same function called with a value that is not a constant expression
    std::cout << "\n【constexpr函数 - 运行期调用】" << std::endl;
    std::cout << "请输入一个数字: ";
    int x = 7;  // stands in for user input
    std::cout << x << std::endl;
    int result = square(x);  // evaluated at run time
    std::cout << "square(" << x << ") = " << result << std::endl;

    std::cout << "\n所有static_assert通过！" << std::endl;

    return 0;
}


In [None]:
!g++ -std=c++17 -o constexpr_basics constexpr_basics.cpp && ./constexpr_basics


## 2. constexpr类

创建可在编译期使用的字面量类型。


In [None]:
%%writefile constexpr_class.cpp
#include <iostream>

// Literal type: the constructor and every member function are constexpr, so
// Point objects can be created and used inside constant expressions.
class Point {
public:
    // Both coordinates default to 0.
    constexpr Point(int x = 0, int y = 0) : x_(x), y_(y) {}

    // Coordinate accessors.
    constexpr int x() const { return x_; }
    constexpr int y() const { return y_; }

    // x^2 + y^2, i.e. the squared distance from the origin.
    constexpr int distanceSquared() const {
        return (x_ * x_) + (y_ * y_);
    }

    // Component-wise sum of two points.
    constexpr Point operator+(const Point& other) const {
        return {x_ + other.x_, y_ + other.y_};
    }

    // Two points are equal when both coordinates match.
    constexpr bool operator==(const Point& other) const {
        return !(x_ != other.x_ || y_ != other.y_);
    }

private:
    int x_;
    int y_;
};

// A more involved literal type: a rational number kept in lowest terms.
//
// Invariant established by the constructor: the fraction is reduced and the
// denominator is positive, so the sign always lives in the numerator and
// operator== can compare the two fields directly.
//
// Precondition: den != 0. A zero denominator is not rejected; Rational(0, 0)
// divides by zero (a compile error in a constant expression, undefined
// behaviour at run time).
class Rational {
    int num_, den_;

    // Greatest common divisor by Euclid's algorithm. The result is never
    // negative, and gcd(0, 0) == 0.
    static constexpr int gcd(int a, int b) {
        while (b != 0) {
            const int r = a % b;
            a = b;
            b = r;
        }
        return a < 0 ? -a : a;
    }

    // Divisor that reduces num/den to lowest terms. It carries the sign of
    // den, so dividing den by it always produces a positive denominator.
    static constexpr int reducer(int num, int den) {
        const int g = gcd(num, den);
        return den < 0 ? -g : g;
    }

public:
    // Builds num/den, reduced; defaults to 0/1.
    constexpr Rational(int num = 0, int den = 1)
        : num_(num / reducer(num, den)), den_(den / reducer(num, den)) {}

    constexpr int numerator() const { return num_; }
    constexpr int denominator() const { return den_; }

    // a/b + c/d = (a*d + c*b) / (b*d); the constructor reduces the result.
    constexpr Rational operator+(const Rational& other) const {
        return Rational(num_ * other.den_ + other.num_ * den_,
                       den_ * other.den_);
    }

    // a/b * c/d = (a*c) / (b*d); the constructor reduces the result.
    constexpr Rational operator*(const Rational& other) const {
        return Rational(num_ * other.num_, den_ * other.den_);
    }

    // Field-wise comparison; valid because both operands are normalised.
    constexpr bool operator==(const Rational& other) const {
        return num_ == other.num_ && den_ == other.den_;
    }
};

// Compile-time kernel configuration keyed on the head dimension, modelled on
// FlashAttention's KernelConfig.
template <int HeadDim>
struct KernelConfig {
    static constexpr int kHeadDim = HeadDim;
    // Head dimensions above 64 use the smaller M tile and the larger warp count.
    static constexpr int kBlockM = (HeadDim > 64) ? 64 : 128;
    static constexpr int kBlockN = 64;
    static constexpr int kNWarps = (HeadDim > 64) ? 8 : 4;

    // Derived constants
    static constexpr int kNThreads = 32 * kNWarps;
    static constexpr int kBlockElements = kBlockN * kBlockM;

    // Shared-memory sizes in bytes
    static constexpr int kSmemQSize = 2 * kBlockM * kHeadDim;  // fp16: 2 bytes per element
    // NOTE(review): presumably 2 bytes (fp16) times 2 buffers (K and V) — confirm.
    static constexpr int kSmemKVSize = 2 * 2 * kBlockN * kHeadDim;
    static constexpr int kSmemTotal = kSmemKVSize + kSmemQSize;
};

int main() {
    std::cout << "==========================================" << std::endl;
    std::cout << "        constexpr 类演示" << std::endl;
    std::cout << "==========================================\n" << std::endl;
    
    // Point类
    std::cout << "【Point 类】" << std::endl;
    constexpr Point p1(3, 4);
    constexpr Point p2(1, 2);
    constexpr Point p3 = p1 + p2;
    constexpr int dist = p1.distanceSquared();
    
    std::cout << "p1 = (" << p1.x() << ", " << p1.y() << ")" << std::endl;
    std::cout << "p2 = (" << p2.x() << ", " << p2.y() << ")" << std::endl;
    std::cout << "p1 + p2 = (" << p3.x() << ", " << p3.y() << ")" << std::endl;
    std::cout << "p1.distanceSquared() = " << dist << std::endl;
    
    static_assert(p1.distanceSquared() == 25, "Distance should be 25");
    static_assert(p3 == Point(4, 6), "Sum should be (4,6)");
    
    // Rational类
    std::cout << "\n【Rational 类】" << std::endl;
    constexpr Rational r1(1, 2);  // 1/2
    constexpr Rational r2(1, 3);  // 1/3
    constexpr Rational sum = r1 + r2;  // 5/6
    constexpr Rational prod = r1 * r2; // 1/6
    
    std::cout << "r1 = " << r1.numerator() << "/" << r1.denominator() << std::endl;
    std::cout << "r2 = " << r2.numerator() << "/" << r2.denominator() << std::endl;
    std::cout << "r1 + r2 = " << sum.numerator() << "/" << sum.denominator() << std::endl;
    std::cout << "r1 * r2 = " << prod.numerator() << "/" << prod.denominator() << std::endl;
    
    // KernelConfig（类似FlashAttention）
    std::cout << "\n【KernelConfig - 类似FlashAttention】" << std::endl;
    
    using Config64 = KernelConfig<64>;
    std::cout << "HeadDim=64:" << std::endl;
    std::cout << "  BlockM=" << Config64::kBlockM << std::endl;
    std::cout << "  BlockN=" << Config64::kBlockN << std::endl;
    std::cout << "  NWarps=" << Config64::kNWarps << std::endl;
    std::cout << "  NThreads=" << Config64::kNThreads << std::endl;
    std::cout << "  SmemTotal=" << Config64::kSmemTotal << " bytes" << std::endl;
    
    using Config128 = KernelConfig<128>;
    std::cout << "\nHeadDim=128:" << std::endl;
    std::cout << "  BlockM=" << Config128::kBlockM << std::endl;
    std::cout << "  BlockN=" << Config128::kBlockN << std::endl;
    std::cout << "  NWarps=" << Config128::kNWarps << std::endl;
    std::cout << "  NThreads=" << Config128::kNThreads << std::endl;
    std::cout << "  SmemTotal=" << Config128::kSmemTotal << " bytes" << std::endl;
    
    return 0;
}


In [None]:
!g++ -std=c++17 -o constexpr_class constexpr_class.cpp && ./constexpr_class


## 3. 编译期数学运算

使用constexpr实现各种数学计算。


In [None]:
%%writefile constexpr_math.cpp
#include <iostream>
#include <cmath>

// ==================== 编译期数学函数 ====================

// Compile-time absolute value: negates x when it is negative.
constexpr int abs_val(int x) {
    if (x < 0) {
        return -x;
    }
    return x;
}

// Compile-time maximum of two ints.
constexpr int max_val(int a, int b) {
    if (b < a) {
        return a;
    }
    return b;
}

// Compile-time minimum of two ints.
constexpr int min_val(int a, int b) {
    if (b > a) {
        return a;
    }
    return b;
}

// Compile-time integer power: multiplies 1 by base, exp times.
// An exponent of zero or less performs no multiplication and yields 1.
constexpr long long power(int base, int exp) {
    long long acc = 1;
    for (int remaining = exp; remaining > 0; --remaining) {
        acc *= base;
    }
    return acc;
}

// Compile-time test for a power of two: n must be positive and have exactly
// one bit set, which is what (n & (n - 1)) == 0 checks.
constexpr bool is_power_of_two(int n) {
    if (n <= 0) {
        return false;
    }
    return (n & (n - 1)) == 0;
}

// Compile-time round-up to a power of two: doubles a candidate, starting at 1,
// until it is no longer below n. Any n <= 1 therefore yields 1.
// The candidate is an int, so the doubling overflows when n exceeds 2^30.
constexpr int next_power_of_two(int n) {
    int candidate = 1;
    for (; candidate < n; candidate *= 2) {
    }
    return candidate;
}

// Compile-time integer square root (Newton's method).
// For n >= 2 returns floor(sqrt(n)). Inputs below 2, including negative
// values, are returned unchanged.
constexpr int isqrt(int n) {
    if (n < 2) return n;
    int x = n;
    // First estimate is ceil(x / 2). Written without `x + 1`, which would
    // overflow (undefined behaviour) when n == INT_MAX.
    int y = x / 2 + x % 2;
    // The estimates decrease until they reach floor(sqrt(n)); stop at the
    // first one that does not improve.
    while (y < x) {
        x = y;
        y = (x + n / x) / 2;
    }
    return x;
}

// Compile-time greatest common divisor (Euclid's algorithm).
// The result is never negative, whatever the signs of the operands
// (the same convention as std::gcd); gcd(0, 0) == 0.
constexpr int gcd(int a, int b) {
    while (b != 0) {
        const int r = a % b;
        a = b;
        b = r;
    }
    // The remainder chain can end on a negative value when an operand is
    // negative, so normalise the sign here.
    return a < 0 ? -a : a;
}

// Compile-time least common multiple.
// Follows the std::lcm convention: returns 0 when either argument is 0 and a
// non-negative value otherwise.
constexpr int lcm(int a, int b) {
    if (a == 0 || b == 0) {
        return 0;  // also avoids dividing by gcd(0, 0) == 0
    }
    // Divide before multiplying to keep the intermediate value small.
    const int multiple = a / gcd(a, b) * b;
    return multiple < 0 ? -multiple : multiple;
}

// Compile-time primality test by trial division with odd divisors.
// Returns false for every n below 2.
constexpr bool is_prime(int n) {
    if (n < 2) return false;
    if (n == 2) return true;
    if (n % 2 == 0) return false;
    // `i <= n / i` is the overflow-free form of `i * i <= n`; the product
    // would exceed INT_MAX for n close to the top of the int range.
    for (int i = 3; i <= n / i; i += 2) {
        if (n % i == 0) return false;
    }
    return true;
}

// Demo driver: evaluates each math helper, prints the results, and verifies a
// few of them at compile time with static_assert.
int main() {
    // Compile-time verification; emits no code, so placement does not affect output.
    static_assert(is_power_of_two(64), "64 is power of 2");
    static_assert(next_power_of_two(100) == 128, "next power of 100 is 128");
    static_assert(isqrt(100) == 10, "sqrt(100) = 10");
    static_assert(gcd(48, 18) == 6, "gcd(48, 18) = 6");
    static_assert(is_prime(97), "97 is prime");

    // Renders a bool the way the demo prints it.
    const auto as_text = [](bool flag) { return flag ? "true" : "false"; };

    std::cout << "==========================================" << std::endl;
    std::cout << "        编译期数学运算" << std::endl;
    std::cout << "==========================================\n" << std::endl;

    // Basic operations
    std::cout << "【基础运算】" << std::endl;
    constexpr int absolute = abs_val(-42);
    constexpr int larger = max_val(10, 20);
    constexpr int smaller = min_val(10, 20);
    std::cout << "abs(-42) = " << absolute << std::endl;
    std::cout << "max(10, 20) = " << larger << std::endl;
    std::cout << "min(10, 20) = " << smaller << std::endl;

    // Powers
    std::cout << "\n【幂运算】" << std::endl;
    constexpr long long twoToTen = power(2, 10);
    constexpr long long threeToFive = power(3, 5);
    std::cout << "2^10 = " << twoToTen << std::endl;
    std::cout << "3^5 = " << threeToFive << std::endl;

    // Powers of two
    std::cout << "\n【2的幂】" << std::endl;
    std::cout << "is_power_of_two(64) = " << as_text(is_power_of_two(64)) << std::endl;
    std::cout << "is_power_of_two(100) = " << as_text(is_power_of_two(100)) << std::endl;
    std::cout << "next_power_of_two(100) = " << next_power_of_two(100) << std::endl;
    std::cout << "next_power_of_two(64) = " << next_power_of_two(64) << std::endl;

    // Integer square root
    std::cout << "\n【整数平方根】" << std::endl;
    constexpr int rootOfHundred = isqrt(100);
    constexpr int rootOfSeventeen = isqrt(17);
    std::cout << "isqrt(100) = " << rootOfHundred << std::endl;
    std::cout << "isqrt(17) = " << rootOfSeventeen << std::endl;

    // GCD / LCM
    std::cout << "\n【GCD/LCM】" << std::endl;
    constexpr int divisor = gcd(48, 18);
    constexpr int multiple = lcm(12, 18);
    std::cout << "gcd(48, 18) = " << divisor << std::endl;
    std::cout << "lcm(12, 18) = " << multiple << std::endl;

    // Primality
    std::cout << "\n【素数判断】" << std::endl;
    std::cout << "is_prime(17) = " << as_text(is_prime(17)) << std::endl;
    std::cout << "is_prime(18) = " << as_text(is_prime(18)) << std::endl;
    std::cout << "is_prime(97) = " << as_text(is_prime(97)) << std::endl;

    std::cout << "\n所有编译期验证通过！" << std::endl;

    return 0;
}


In [None]:
!g++ -std=c++17 -o constexpr_math constexpr_math.cpp && ./constexpr_math


## 4. constexpr与模板结合

在模板编程中使用constexpr实现编译期配置。


In [None]:
%%writefile constexpr_template.cpp
#include <iostream>
#include <array>
#include <type_traits>

// ==================== 编译期配置计算 ====================

// Chooses the block size for a head dimension: the larger the head dimension,
// the smaller the block (256 / 128 / 64 / 32 at thresholds 32 / 64 / 128).
constexpr int compute_block_size(int head_dim) {
    return head_dim <= 32  ? 256
         : head_dim <= 64  ? 128
         : head_dim <= 128 ? 64
                           : 32;
}

// Number of warps for a block: block_size * head_dim / 32 gives the thread
// count, which is then divided by 32 rounding up.
constexpr int compute_num_warps(int block_size, int head_dim) {
    return ((block_size * head_dim) / 32 + 31) / 32;
}

// Kernel configuration derived at compile time from the head dimension,
// modelled on FlashAttention. Every member is a compile-time constant.
template <int HeadDim>
struct AttentionConfig {
    static constexpr int kHeadDim = HeadDim;
    static constexpr int kBlockSize = compute_block_size(kHeadDim);
    static constexpr int kNumWarps = compute_num_warps(kBlockSize, kHeadDim);
    static constexpr int kNumThreads = 32 * kNumWarps;

    // Shared-memory sizes in bytes
    static constexpr int kQSmemSize = 2 * kBlockSize * kHeadDim;         // fp16 = 2 bytes
    static constexpr int kKVSmemSize = 2 * (2 * kBlockSize * kHeadDim);  // K and V
    static constexpr int kTotalSmem = kKVSmemSize + kQSmemSize;

    // Resource limits: at most 1024 threads per block and 48 KB of shared memory.
    static constexpr bool kValidConfig =
        !(kNumThreads > 1024 || kTotalSmem > 48 * 1024);

    // Prints every configuration value to std::cout, one per line.
    static void print() {
        const auto emit = [](const char* label, auto value, const char* suffix = "") {
            std::cout << label << value << suffix << std::endl;
        };
        emit("HeadDim = ", kHeadDim);
        emit("  BlockSize = ", kBlockSize);
        emit("  NumWarps = ", kNumWarps);
        emit("  NumThreads = ", kNumThreads);
        emit("  QSmemSize = ", kQSmemSize, " bytes");
        emit("  KVSmemSize = ", kKVSmemSize, " bytes");
        emit("  TotalSmem = ", kTotalSmem, " bytes");
        emit("  ValidConfig = ", kValidConfig ? "true" : "false");
    }
};

// ==================== 编译期数组生成 ====================

// Builds, at compile time, an array whose element i holds i * i.
template <int N>
constexpr std::array<int, N> generate_squares() {
    std::array<int, N> table{};
    int index = 0;
    for (int& slot : table) {
        slot = index * index;
        ++index;
    }
    return table;
}

// Builds, at compile time, an array holding the first N Fibonacci numbers
// (0, 1, 1, 2, 3, ...). N == 0 yields an empty array.
template <int N>
constexpr std::array<int, N> generate_fibonacci() {
    // Value-initialisation zeroes every element, so F(0) = 0 needs no explicit
    // store. Leaving that store out keeps N == 0 from writing out of bounds.
    std::array<int, N> result{};
    if constexpr (N > 1) {
        result[1] = 1;
    }
    for (int i = 2; i < N; ++i) {
        result[i] = result[i - 1] + result[i - 2];
    }
    return result;
}

// ==================== 编译期字符串处理 ====================

// Compile-time string length: counts characters up to, but not including,
// the terminating '\0'.
constexpr int str_length(const char* str) {
    int count = 0;
    for (const char* cursor = str; *cursor != '\0'; ++cursor) {
        ++count;
    }
    return count;
}

// Compile-time string comparison: walks both strings in step and reports
// whether they hold the same characters and end at the same position.
constexpr bool str_equal(const char* a, const char* b) {
    for (; *a != '\0' && *b != '\0'; ++a, ++b) {
        if (*a != *b) {
            return false;
        }
    }
    // At least one string has ended; they are equal only if both have.
    return *a == *b;
}

// Demo driver: prints three AttentionConfig instantiations, two arrays
// generated at compile time, and the compile-time string helpers' results.
int main() {
    // Compile-time verification; emits no code, so placement does not affect output.
    static_assert(AttentionConfig<64>::kValidConfig, "Config for HeadDim=64 should be valid");
    static_assert(AttentionConfig<128>::kValidConfig, "Config for HeadDim=128 should be valid");
    static_assert(str_length("test") == 4, "Length should be 4");
    static_assert(str_equal("hello", "hello"), "Strings should be equal");
    static_assert(!str_equal("hello", "world"), "Strings should not be equal");

    std::cout << "==========================================" << std::endl;
    std::cout << "    constexpr 与模板结合" << std::endl;
    std::cout << "==========================================\n" << std::endl;

    // Configurations for different head dimensions
    std::cout << "【AttentionConfig - 模拟FlashAttention】\n" << std::endl;

    AttentionConfig<32>::print();
    std::cout << std::endl;

    AttentionConfig<64>::print();
    std::cout << std::endl;

    AttentionConfig<128>::print();
    std::cout << std::endl;

    // Arrays generated at compile time
    std::cout << "【编译期生成数组】" << std::endl;
    constexpr auto squares = generate_squares<10>();
    std::cout << "前10个平方数: ";
    for (const int value : squares) {
        std::cout << value << " ";
    }
    std::cout << std::endl;

    constexpr auto fibs = generate_fibonacci<15>();
    std::cout << "前15个斐波那契数: ";
    for (const int value : fibs) {
        std::cout << value << " ";
    }
    std::cout << std::endl;

    // Compile-time string handling
    std::cout << "\n【编译期字符串处理】" << std::endl;
    constexpr const char* hello = "Hello, World!";
    constexpr int len = str_length(hello);
    std::cout << "字符串: \"" << hello << "\"" << std::endl;
    std::cout << "长度: " << len << std::endl;

    std::cout << "\n所有编译期验证通过！" << std::endl;

    return 0;
}


In [None]:
!g++ -std=c++17 -o constexpr_template constexpr_template.cpp && ./constexpr_template


## 5. 清理临时文件


In [None]:
!rm -f constexpr_basics constexpr_class constexpr_math constexpr_template
!rm -f *.cpp
print("临时文件已清理")


## 总结

通过本notebook，你应该理解了：

### constexpr 的核心用法

1. **constexpr 变量**：声明编译期常量
2. **constexpr 函数**：可在编译期求值的函数
3. **constexpr 类**：字面量类型，支持编译期构造和操作
4. **static_assert**：编译期断言，验证编译期计算结果

### 在 FlashAttention 中的应用

| 用途 | 示例 |
|------|------|
| Kernel配置 | `kBlockM`, `kBlockN`, `kHeadDim` |
| 共享内存大小 | 根据配置参数编译期计算 |
| 循环展开 | 编译期已知的循环边界 |
| 资源检查 | 编译期验证配置是否超出限制 |

### 关键优势

- **零运行时开销**：在常量表达式上下文中（如初始化constexpr变量、static_assert、数组大小、模板实参），计算在编译期完成；同一个constexpr函数用运行期参数调用时仍在运行期执行
- **类型安全**：编译期检查，错误早发现
- **优化友好**：编译器可以基于常量进行更多优化
- **代码清晰**：明确表达设计意图

## 练习

1. 实现一个编译期字符串哈希函数
2. 使用constexpr实现编译期排序算法
3. 设计一个类似FlashAttention的配置系统，根据GPU架构自动选择参数
