In [None]:

import torch.cuda.nvtx as nvtx
import torch
import math

# Default problem sizes used when random attention inputs are generated below.
batch_size = 32  # leading dimension of the generated Q/K/V tensors
seq_len = 200    # sequence-length dimension of the generated Q/K/V tensors
d_k = 64         # last dimension of the generated Q and K tensors
d_v = 64         # last dimension of the generated V tensor
d_model = 512    # number of rows of the output projection weight (d_model, d_v)

@nvtx.range('scaled dot product attention')
def annotated_scaled_dot_product_attention(Q, K, V, mask, seq_len, batch_dims):
    """Single-head scaled dot-product attention with NVTX ranges around each stage.

    Args:
        Q: Query tensor of shape (batch, seq_len, d_k). If None, a random CUDA
            tensor of shape (batch_size, seq_len, d_k) is created from the
            module-level defaults.
        K: Key tensor of shape (batch, seq_len, d_k). If None, a random tensor
            matching Q's batch size, head dimension and device is created.
        V: Value tensor of shape (batch, seq_len, d_v). If None, a random
            tensor with the module-level ``d_v`` as last dimension is created.
        mask: Boolean tensor broadcastable to (batch, seq_len, seq_len) where
            True marks positions that may be attended to. If None, a causal
            (lower-triangular) mask of size ``seq_len`` is built.
        seq_len: Sequence length used for generated inputs and the causal mask.
        batch_dims: Unused; kept so existing callers keep working.

    Returns:
        Tensor of shape (batch, seq_len, d_model): the attention output
        multiplied by a freshly sampled random projection weight.
    """
    # Only synthesise inputs the caller did not supply. The previous version
    # overwrote the arguments unconditionally and also read the module-level
    # ``d_k`` before a local of the same name was assigned, which raised
    # UnboundLocalError on every call.
    if Q is None:
        Q = torch.randn(batch_size, seq_len, d_k, device='cuda')
    device = Q.device
    batch_size_actual = Q.shape[0]
    head_dim = Q.shape[-1]  # distinct name so the module-level d_k is not shadowed
    if K is None:
        K = torch.randn(batch_size_actual, seq_len, head_dim, device=device)
    if V is None:
        V = torch.randn(batch_size_actual, seq_len, d_v, device=device)
    o_proj_weight = torch.randn(d_model, V.shape[-1], device=device)

    if mask is None:
        # Causal mask: entry (i, j) is True when i >= j, i.e. each position may
        # only attend to itself and earlier positions.
        mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        # Broadcast over the batch dimension: (batch, seq_len, seq_len).
        mask = mask.expand(batch_size_actual, seq_len, seq_len)

    with nvtx.range('computing attention scores'):
        scores = Q @ K.transpose(-2, -1)  # (batch, seq_len, seq_len)
        scores = scores / math.sqrt(head_dim)
        # Disallowed positions get -inf so softmax assigns them zero weight.
        scores = scores.masked_fill(~mask, float('-inf'))

    with nvtx.range('computing softmax'):
        attention_weights = torch.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)
        attention = attention_weights @ V  # (batch, seq_len, d_v)

    with nvtx.range('final matmul'):
        output = attention @ o_proj_weight.T  # (batch, seq_len, d_model)

    return output

In [None]:

import triton
import triton.language as tl
import torch

@triton.jit
def flash_fwd_kernel(
    Q_ptr, K_ptr, V_ptr,
    O_ptr, L_ptr,
    stride_qb, stride_qq, stride_qd,  # stride_qb: elements to skip to move Q from one batch entry to the next
    stride_kb, stride_kk, stride_kd,
    stride_vb, stride_vk, stride_vd,
    stride_ob, stride_oq, stride_od,
    stride_lb, stride_lq,
    N_QUERIES, N_KEYS,
    scale,
    D: tl.constexpr,
    Q_TILE_SIZE: tl.constexpr,
    K_TILE_SIZE: tl.constexpr,
):
    """Tiled (FlashAttention-style) attention forward pass without masking.

    Launch grid: (cdiv(N_QUERIES, Q_TILE_SIZE), batch). Each program owns one
    query tile of one batch entry, streams over every key/value tile while
    maintaining an online softmax, and writes:

      * O[batch, q_tile, :] = softmax(Q K^T / scale) V
      * L[batch, q_tile]    = log-sum-exp of the scaled scores per query row

    ``scale`` is the divisor applied to the raw scores (the caller passes
    sqrt(D)). D, Q_TILE_SIZE and K_TILE_SIZE are compile-time constants.
    """
    query_tile_index = tl.program_id(0)  # which query tile this program handles
    batch_index = tl.program_id(1)       # which batch entry this program handles

    # Block pointers. order=(1, 0) tells Triton that the last dimension is the
    # fastest-varying one (row-major layout).
    Q_block_ptr = tl.make_block_ptr(
        # Jump to this batch entry, e.g. for a contiguous Q of shape
        # [4, 128, 64] batch 1 starts 8192 elements after Q_ptr.
        Q_ptr + batch_index * stride_qb,
        shape=(N_QUERIES, D),                          # full (global) extent
        strides=(stride_qq, stride_qd),
        offsets=(query_tile_index * Q_TILE_SIZE, 0),   # start of this program's tile
        block_shape=(Q_TILE_SIZE, D),                  # extent of one loaded tile
        order=(1, 0),
    )

    K_block_ptr = tl.make_block_ptr(
        K_ptr + batch_index * stride_kb,
        shape=(N_KEYS, D),
        strides=(stride_kk, stride_kd),
        offsets=(0, 0),
        block_shape=(K_TILE_SIZE, D),
        order=(1, 0),
    )

    V_block_ptr = tl.make_block_ptr(
        V_ptr + batch_index * stride_vb,
        shape=(N_KEYS, D),
        strides=(stride_vk, stride_vd),
        offsets=(0, 0),
        block_shape=(K_TILE_SIZE, D),
        order=(1, 0),
    )

    # Outputs are written at the same query offset the Q tile was read from.
    O_block_ptr = tl.make_block_ptr(
        O_ptr + batch_index * stride_ob,
        shape=(N_QUERIES, D),
        strides=(stride_oq, stride_od),
        offsets=(query_tile_index * Q_TILE_SIZE, 0),
        block_shape=(Q_TILE_SIZE, D),
        order=(1, 0),
    )

    # L is one value per query row, so its block pointer is 1-D.
    L_block_ptr = tl.make_block_ptr(
        L_ptr + batch_index * stride_lb,
        shape=(N_QUERIES,),
        strides=(stride_lq,),
        offsets=(query_tile_index * Q_TILE_SIZE,),
        block_shape=(Q_TILE_SIZE,),
        order=(0,),
    )

    # The query tile is loaded once and reused for every key tile.
    q_tile = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")

    # Online-softmax state, kept in float32 for accuracy.
    o_acc = tl.zeros((Q_TILE_SIZE, D), dtype=tl.float32)              # unnormalised output
    l_acc = tl.zeros((Q_TILE_SIZE,), dtype=tl.float32)                # running sum of exp
    m_acc = tl.full((Q_TILE_SIZE,), float("-inf"), dtype=tl.float32)  # running row maximum

    for key_tile_index in range(tl.cdiv(N_KEYS, K_TILE_SIZE)):
        k_tile = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option="zero")
        v_tile = tl.load(V_block_ptr, boundary_check=(0, 1), padding_option="zero")

        scores = tl.dot(q_tile, tl.trans(k_tile)) / scale  # (Q_TILE_SIZE, K_TILE_SIZE)

        # Keys past N_KEYS were zero-padded by the load; force their score to
        # -inf so they contribute exp(-inf) = 0 to the softmax.
        key_offsets = key_tile_index * K_TILE_SIZE + tl.arange(0, K_TILE_SIZE)
        scores = tl.where(key_offsets[None, :] < N_KEYS, scores, float("-inf"))

        m_new = tl.maximum(m_acc, tl.max(scores, axis=1))
        p = tl.exp(scores - m_new[:, None])
        # Rescale the previous state to the new running maximum.
        alpha = tl.exp(m_acc - m_new)
        l_acc = alpha * l_acc + tl.sum(p, axis=1)
        o_acc = alpha[:, None] * o_acc + tl.dot(p.to(v_tile.dtype), v_tile)
        m_acc = m_new

        K_block_ptr = K_block_ptr.advance((K_TILE_SIZE, 0))
        V_block_ptr = V_block_ptr.advance((K_TILE_SIZE, 0))

    o_acc = o_acc / l_acc[:, None]
    log_sum_exp = m_acc + tl.log(l_acc)

    tl.store(O_block_ptr, o_acc.to(O_ptr.type.element_ty), boundary_check=(0, 1))
    tl.store(L_block_ptr, log_sum_exp.to(L_ptr.type.element_ty), boundary_check=(0,))


#uv run pytest -k test_flash_forward_pass_triton
class triton_kernel_flash_attention_fwd(torch.autograd.Function):
    """Autograd wrapper that launches ``flash_fwd_kernel`` for the forward pass."""

    @staticmethod
    def forward(ctx, Q, K, V, is_causal = False):
        """Compute attention output O for Q, K, V of shape (batch, seq, D).

        Args:
            ctx: autograd context; Q, K, V, O and the per-row log-sum-exp L are
                saved on it for a future backward pass.
            Q: query tensor, shape (batch, N_QUERIES, D).
            K: key tensor, shape (batch, N_KEYS, D).
            V: value tensor, shape (batch, N_KEYS, D).
            is_causal: stored on ``ctx`` only.
                NOTE(review): the kernel call below has no masking argument, so
                this flag does not currently change the result.

        Returns:
            O with the same shape, dtype and device as Q.
        """
        batch_size, N_QUERIES, D = Q.shape
        N_KEYS = K.shape[-2]
        # The kernel divides the raw scores by ``scale``. ``** 0.5`` avoids
        # depending on ``math`` having been imported by another cell.
        scale = D ** 0.5

        # Each program processes 16 rows at a time. A fixed power-of-two tile
        # replaces the previous "tile = whole sequence" choice.
        Q_TILE_SIZE = 16
        K_TILE_SIZE = 16

        # Allocate outputs on the same device as the inputs; the kernel fills
        # every in-range element, so no zero-initialisation is needed.
        O = torch.empty((batch_size, N_QUERIES, D), device=Q.device, dtype=Q.dtype)
        L = torch.empty((batch_size, N_QUERIES), device=Q.device, dtype=torch.float32)

        # One program per (query tile, batch entry).
        grid = (triton.cdiv(N_QUERIES, Q_TILE_SIZE), batch_size)
        flash_fwd_kernel[grid](
            Q, K, V,
            O, L,

            Q.stride(0),
            Q.stride(1),
            Q.stride(2),

            K.stride(0),
            K.stride(1),
            K.stride(2),

            V.stride(0),
            V.stride(1),
            V.stride(2),

            O.stride(0),
            O.stride(1),
            O.stride(2),

            L.stride(0),
            L.stride(1),

            N_QUERIES,
            N_KEYS,
            scale,
            D,
            Q_TILE_SIZE,
            K_TILE_SIZE,
        )

        # save_for_backward must be called once; a second call replaces the first.
        ctx.save_for_backward(Q, K, V, O, L)
        ctx.is_causal = is_causal
        return O

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Backward pass placeholder.

        Raises:
            NotImplementedError: always. (``NotImplemented`` is a comparison
            sentinel, not an exception, so raising it produced a TypeError.)
        """
        raise NotImplementedError("backward pass is not implemented")




NameError: name 'torch' is not defined

In [2]:
from optimizer import AdamW


In [None]:
# ============================================================================
# NEW IMPLEMENTATION: Using cProfile for performance analysis
# ============================================================================
# This version works on both CPU and GPU, doesn't require NVIDIA tools
# ============================================================================
import cProfile
import pstats
import io
from pstats import SortKey
import torch
import math

# Default problem sizes used by the attention and profiling functions below.
batch_size = 32  # leading dimension of the generated Q/K/V tensors
seq_len = 200    # sequence-length dimension of the generated Q/K/V tensors
d_k = 64         # last dimension of the generated Q and K tensors
d_v = 64         # last dimension of the generated V tensor
d_model = 512    # number of rows of the output projection weight (d_model, d_v)

def scaled_dot_product_attention(Q, K, V, mask=None, seq_len=200, batch_dims=1):
    """
    Scaled dot-product attention (single head) followed by a random output projection.
    Written without nvtx so it can be profiled with cProfile on CPU or GPU.

    Args:
        Q: query tensor (batch, seq_len, d_k), or None to sample a random one
            of shape (batch_size, seq_len, d_k) from the module-level defaults.
        K: key tensor (batch, seq_len, d_k), or None to sample one matching Q.
        V: value tensor (batch, seq_len, d_v), or None to sample one.
        mask: boolean tensor broadcastable to (batch, seq_len, seq_len), True
            where attention is allowed; None builds a causal mask of size
            ``seq_len``.
        seq_len: sequence length for sampled inputs and the causal mask.
        batch_dims: unused; kept for signature compatibility.

    Returns:
        Tensor of shape (batch, seq_len, d_model).
    """
    # Sample any missing input. Q fixes the device and batch size the other
    # defaults follow.
    if Q is None:
        default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        Q = torch.randn(batch_size, seq_len, d_k, device=default_device)
    device = Q.device
    batch_size_actual = Q.shape[0]
    # Named head_dim (not d_k): assigning to d_k here would make it a local for
    # the whole function and break the default-sampling branch above.
    head_dim = Q.shape[-1]
    if K is None:
        K = torch.randn(batch_size_actual, seq_len, head_dim, device=device)
    if V is None:
        V = torch.randn(batch_size_actual, seq_len, d_v, device=device)

    # Output projection weight, sized to whatever value dimension V has.
    o_proj_weight = torch.randn(d_model, V.shape[-1], device=device)

    # No optimizer is built here: this version only runs a forward pass, and
    # the previous AdamW(...) call referenced undefined names (NameError).

    if mask is None:
        # Causal mask: entry (i, j) is True when i >= j, so each position only
        # sees itself and earlier tokens.
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        # Broadcast over the batch: (batch, seq_len, seq_len).
        mask = causal_mask.expand(batch_size_actual, seq_len, seq_len)

    # Attention scores.
    scores = Q @ K.transpose(-2, -1)  # (batch, seq_len, seq_len)
    scores = scores / math.sqrt(head_dim)
    # Disallowed positions get -inf so softmax gives them zero weight.
    scores = scores.masked_fill(~mask, float('-inf'))

    # Softmax and weighted sum of values.
    attention_weights = torch.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)
    attention = attention_weights @ V  # (batch, seq_len, d_v)

    # Final projection.
    output = attention @ o_proj_weight.T  # (batch, seq_len, d_model)

    return output

# Profile wrapper function
def profile_attention_function(n_iterations=100):
    """Profile ``scaled_dot_product_attention`` with cProfile.

    Runs the attention function ``n_iterations`` times on random inputs, prints
    the top 20 entries by cumulative time, and writes the raw profile to
    'attention_profile.prof' in the current directory.

    Args:
        n_iterations: number of profiled calls; may be 0.

    Returns:
        (output, profiler): the result of the last call (None when
        ``n_iterations`` is 0) and the ``cProfile.Profile`` instance.
    """
    # Test inputs on the GPU when one is available, otherwise on the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    Q = torch.randn(batch_size, seq_len, d_k, device=device)
    K = torch.randn(batch_size, seq_len, d_k, device=device)
    V = torch.randn(batch_size, seq_len, d_v, device=device)

    profiler = cProfile.Profile()

    # Defined up front so the return below works even with zero iterations.
    output = None

    profiler.enable()
    for _ in range(n_iterations):
        output = scaled_dot_product_attention(Q, K, V, None, seq_len, batch_dims=1)
    if device == 'cuda':
        # CUDA kernels are launched asynchronously; wait for them so the
        # profile covers the GPU work and not only the launches.
        torch.cuda.synchronize()
    profiler.disable()

    # Render the statistics into a string buffer, sorted by cumulative time.
    s = io.StringIO()
    ps = pstats.Stats(profiler, stream=s)
    ps.sort_stats(SortKey.CUMULATIVE)

    print("="*80)
    print("cProfile Results - Top 20 functions by cumulative time")
    print("="*80)
    ps.print_stats(20)
    print(s.getvalue())

    # Persist the raw profile for external viewers.
    profiler.dump_stats('attention_profile.prof')
    print("\n✓ Profile saved to 'attention_profile.prof'")
    print("  View with: python -m pstats attention_profile.prof")
    print("  Or install snakeviz: pip install snakeviz && snakeviz attention_profile.prof")

    return output, profiler


In [2]:
# Execute the profiling pass (100 profiled calls) and keep both results.
output, profiler = profile_attention_function(n_iterations=100)

# The returned profiler can be inspected further; here the report is
# restricted to entries whose name matches 'torch', at most 10 of them.
print("\n" + "=" * 80, "Filtered results - Only torch functions:", "=" * 80, sep="\n")
stats = pstats.Stats(profiler).sort_stats('cumulative')
stats.print_stats('torch', 10)


cProfile Results - Top 20 functions by cumulative time
         901 function calls in 0.413 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      100    0.254    0.003    0.413    0.004 /var/folders/cp/vqdxm_290_j_myz50smbrjwc0000gn/T/ipykernel_7947/448703527.py:15(scaled_dot_product_attention)
      100    0.053    0.001    0.053    0.001 {built-in method torch.softmax}
      100    0.041    0.000    0.041    0.000 {method 'masked_fill' of 'torch._C.TensorBase' objects}
      100    0.036    0.000    0.036    0.000 {built-in method torch.randn}
      100    0.023    0.000    0.023    0.000 {built-in method torch.ones}
      100    0.003    0.000    0.003    0.000 {built-in method torch.tril}
      100    0.002    0.000    0.002    0.000 {method 'expand' of 'torch._C.TensorBase' objects}
      100    0.001    0.000    0.001    0.000 {method 'transpose' of 'torch._C.TensorBase' objects}
      100    0.000    0.000    0.000

<pstats.Stats at 0x10776c040>

In [3]:
# ============================================================================
# NEW IMPLEMENTATION: Using cProfile for performance analysis
# ============================================================================
# This version works on both CPU and GPU, doesn't require NVIDIA tools
# ============================================================================
import cProfile
import pstats
import io
from pstats import SortKey
import torch
import math
import torch.nn as nn

from torch.nn.functional import cross_entropy

# Default problem sizes used by the attention and profiling functions below.
batch_size = 32  # leading dimension of the generated Q/K/V tensors
seq_len = 200    # sequence-length dimension of the generated Q/K/V tensors
d_k = 64         # last dimension of the generated Q and K tensors
d_v = 64         # last dimension of the generated V tensor
d_model = 512    # number of rows of the output projection weight (d_model, d_v)
from optimizer import AdamW


def scaled_dot_product_attention(Q, K, V, mask=None, seq_len=200, batch_dims=1, vocab_size=5054):
    """
    Scaled dot-product attention (single head) plus one full training step.

    Runs causal attention, an output projection and an LM head, computes a
    cross-entropy loss against random targets, backpropagates and applies one
    AdamW step, so that forward, backward and optimizer costs all show up in a
    cProfile run.

    Args:
        Q: query tensor (batch, seq_len, d_k), or None to sample a random one
            of shape (batch_size, seq_len, d_k) from the module-level defaults.
        K: key tensor (batch, seq_len, d_k), or None to sample one matching Q.
        V: value tensor (batch, seq_len, d_v), or None to sample one.
        mask: boolean tensor broadcastable to (batch, seq_len, seq_len), True
            where attention is allowed; None builds a causal mask of size
            ``seq_len``.
        seq_len: sequence length for sampled inputs, the mask and the targets.
        batch_dims: unused; kept for signature compatibility.
        vocab_size: number of rows of the LM head / number of target classes.

    Returns:
        The projected attention output, shape (batch, seq_len, d_model).
    """
    # Sample any missing input. Q fixes the device and batch size the other
    # defaults follow.
    if Q is None:
        default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        Q = torch.randn(batch_size, seq_len, d_k, device=default_device)
    device = Q.device
    batch_size_actual = Q.shape[0]
    # Named head_dim (not d_k): assigning to d_k here would make it a local for
    # the whole function and break the default-sampling branch above.
    head_dim = Q.shape[-1]
    if K is None:
        K = torch.randn(batch_size_actual, seq_len, head_dim, device=device)
    if V is None:
        V = torch.randn(batch_size_actual, seq_len, d_v, device=device)

    # Trainable weights: output projection and LM head.
    o_proj_weight = torch.nn.Parameter(torch.randn(d_model, V.shape[-1], device=device))
    lm_head_weight = torch.nn.Parameter(torch.randn(vocab_size, d_model, device=device))

    # Optimizer over the two trainable parameters.
    optimizer = AdamW(
        params=[o_proj_weight, lm_head_weight]
    )
    optimizer.zero_grad()

    if mask is None:
        # Causal mask: entry (i, j) is True when i >= j, so each position only
        # sees itself and earlier tokens.
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        # Broadcast over the batch: (batch, seq_len, seq_len).
        mask = causal_mask.expand(batch_size_actual, seq_len, seq_len)

    # Attention scores.
    scores = Q @ K.transpose(-2, -1)  # (batch, seq_len, seq_len)
    scores = scores / math.sqrt(head_dim)
    # Disallowed positions get -inf so softmax gives them zero weight.
    scores = scores.masked_fill(~mask, float('-inf'))

    # Softmax and weighted sum of values.
    attention_weights = torch.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)
    attention = attention_weights @ V  # (batch, seq_len, d_v)

    # Final projection.
    output = attention @ o_proj_weight.T  # (batch, seq_len, d_model)

    # Loss against random targets. The targets are created on the same device
    # as the logits; cross_entropy rejects mixed CPU/GPU operands.
    logits = output @ lm_head_weight.T
    targets = torch.randint(
        0, vocab_size, (batch_size_actual, seq_len), dtype=torch.long, device=device
    )
    logits_flat = logits.view(-1, vocab_size)
    targets_flat = targets.view(-1)
    loss = cross_entropy(logits_flat, targets_flat)

    loss.backward()
    optimizer.step()
    return output



# Profile wrapper function
def profile_attention_function(n_iterations=100):
    """Profile ``scaled_dot_product_attention`` with cProfile.

    Performs one unprofiled warm-up call, then profiles ``n_iterations`` calls
    on random inputs, prints the top 20 entries by cumulative time, and writes
    the raw profile to 'attention_profile.prof' in the current directory.

    Args:
        n_iterations: number of profiled calls; may be 0.

    Returns:
        (output, profiler): the result of the last profiled call (None when
        ``n_iterations`` is 0) and the ``cProfile.Profile`` instance.
    """
    # Test inputs on the GPU when one is available, otherwise on the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    on_cuda = device == 'cuda'
    Q = torch.randn(batch_size, seq_len, d_k, device=device)
    K = torch.randn(batch_size, seq_len, d_k, device=device)
    V = torch.randn(batch_size, seq_len, d_v, device=device)

    profiler = cProfile.Profile()

    # Warm-up call outside the profiled region.
    _ = scaled_dot_product_attention(Q, K, V, None, seq_len, batch_dims=1)
    if on_cuda:
        # Make sure warm-up GPU work is finished before timing starts.
        torch.cuda.synchronize()

    # Defined up front so the return below works even with zero iterations.
    output = None

    profiler.enable()
    for _ in range(n_iterations):
        output = scaled_dot_product_attention(Q, K, V, None, seq_len, batch_dims=1)
    if on_cuda:
        # CUDA kernels are launched asynchronously; wait for them so the
        # profile covers the GPU work and not only the launches.
        torch.cuda.synchronize()
    profiler.disable()

    # Render the statistics into a string buffer, sorted by cumulative time.
    s = io.StringIO()
    ps = pstats.Stats(profiler, stream=s)
    ps.sort_stats(SortKey.CUMULATIVE)

    print("="*80)
    print("cProfile Results - Top 20 functions by cumulative time")
    print("="*80)
    ps.print_stats(20)
    print(s.getvalue())

    # Persist the raw profile for external viewers.
    profiler.dump_stats('attention_profile.prof')
    print("\n✓ Profile saved to 'attention_profile.prof'")
    print("  View with: python -m pstats attention_profile.prof")
    print("  Or install snakeviz: pip install snakeviz && snakeviz attention_profile.prof")

    return output, profiler


# Execute a single profiled iteration and keep both results.
output, profiler = profile_attention_function(n_iterations=1)

# The returned profiler can be inspected further; here the report is
# restricted to entries whose name matches 'torch', at most 10 of them.
print("\n" + "=" * 80, "Filtered results - Only torch functions:", "=" * 80, sep="\n")
stats = pstats.Stats(profiler).sort_stats('cumulative')
stats.print_stats('torch', 10)


cProfile Results - Top 20 functions by cumulative time
         226 function calls (216 primitive calls) in 0.173 seconds

   Ordered by: cumulative time
   List reduced from 81 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.031    0.031    0.173    0.173 /var/folders/cp/vqdxm_290_j_myz50smbrjwc0000gn/T/ipykernel_15297/15839886.py:24(scaled_dot_product_attention)
        1    0.000    0.000    0.093    0.093 /Users/richard/opt/anaconda3/envs/deepDR/OneRel_chinese/lib/python3.8/site-packages/torch/_tensor.py:465(backward)
        1    0.000    0.000    0.093    0.093 /Users/richard/opt/anaconda3/envs/deepDR/OneRel_chinese/lib/python3.8/site-packages/torch/autograd/__init__.py:183(backward)
        1    0.000    0.000    0.092    0.092 /Users/richard/opt/anaconda3/envs/deepDR/OneRel_chinese/lib/python3.8/site-packages/torch/autograd/graph.py:764(_engine_run_backward)
        1    0.092    0.092    0.092    0.092 {meth

<pstats.Stats at 0x105d477c0>