In [1]:
import torch
# Fix PyTorch's global RNG seed so that any randomly generated tensors are
# reproducible when the notebook is re-run from a fresh kernel.
torch.manual_seed(0)

<torch._C.Generator at 0x7cdd60ee39b0>

In [2]:
def Scaled_Product(Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Reference scaled dot-product attention weights.

    Computes ``softmax(Q @ K^T / sqrt(d_k))`` where ``d_k`` is the size of the
    last dimension of ``K``. The unscaled ``Q @ K^T`` matrix is printed as a
    side effect so it can be compared against the approximation.

    Args:
        Q: Query tensor whose last dimension is the feature dimension.
        K: Key tensor whose last dimension is the feature dimension.

    Returns:
        Attention weights, normalised with softmax over the last (key) axis.
    """
    # Raw similarity logits: contract the feature axis of Q against K.
    raw_scores = Q @ K.transpose(-2, -1)
    print(f"QKT: {raw_scores}")

    # Divide by sqrt(d_k) before normalising over the key axis.
    scale = K.size(-1) ** 0.5
    return torch.softmax(raw_scores / scale, dim=-1)

In [3]:
def Approx_Scale_Product(Q: torch.Tensor, K: torch.Tensor, max_iter: int, verbose: bool = True) -> torch.Tensor:
    """Greedy approximation of scaled dot-product attention weights.

    For every query row, the element-wise products ``K * q`` are the individual
    terms whose row sums give ``q @ K^T``. Instead of summing all of them, each
    iteration adds the largest and the smallest not-yet-used term to the score
    of the key row it belongs to. Negative accumulated scores are then clamped
    to zero, the scores are divided by ``sqrt(d_k)`` and passed through softmax.

    Every product term is consumed at most once, so once ``max_iter`` reaches
    half the number of terms the accumulated scores equal the exact ``q @ K^T``
    (before clamping), and larger ``max_iter`` values change nothing.

    Args:
        Q: 2D query tensor of shape ``(num_queries, d_k)``.
        K: 2D key tensor of shape ``(num_keys, d_k)``.
        max_iter: Number of greedy iterations per query; each iteration
            consumes up to two product terms (one maximum, one minimum).
        verbose: When True (default) print the intermediate products and the
            per-iteration selections, as the original notebook cell did.

    Returns:
        Tensor of shape ``(num_queries, num_keys)`` with softmax-normalised
        approximate attention weights.

    Raises:
        AssertionError: If ``Q`` or ``K`` is not 2D, or their feature
            dimensions differ.
    """
    assert Q.dim() == 2 and K.dim() == 2, "Q and K must be 2D tensors"
    assert Q.size(1) == K.size(1), "Q and K must have the same feature dimension"
    num_queries, num_keys, d_k = Q.size(0), K.size(0), K.size(-1)
    dtype, device = Q.dtype, Q.device

    def log(*args, **kwargs) -> None:
        # Single switch for all debug output.
        if verbose:
            print(*args, **kwargs)

    # Pre-allocated so that an empty Q yields an empty (0, num_keys) result
    # instead of failing on torch.stack([]).
    total_score = torch.zeros((num_queries, num_keys), dtype=dtype, device=device)

    for q_id in range(num_queries):
        # Broadcasting multiplies every key row by the query vector.
        product = K * Q[q_id]
        # Working copies in which consumed terms are replaced by sentinels.
        # A consumed term is masked in BOTH copies so it can never be picked a
        # second time (neither by the other selector nor on a later iteration).
        product_for_max = product.clone()
        product_for_min = product.clone()
        remaining = product.numel()
        greedy_score = torch.zeros(num_keys, dtype=dtype, device=device)
        log(f"Product: {product}")

        for step in range(max_iter):
            if remaining == 0:
                # All terms used: stop before the +/-inf sentinels could be
                # selected and turn the score into inf/NaN.
                break

            # Largest unused term.
            max_row, max_col = divmod(product_for_max.argmax().item(), d_k)
            max_val = product[max_row, max_col]
            product_for_max[max_row, max_col] = float('-inf')
            product_for_min[max_row, max_col] = float('inf')
            remaining -= 1
            greedy_score[max_row] += max_val
            log(f"Iteration {step}: max_val: {max_val:.2f}, max_idx: {max_row}", end = " | ")

            if remaining == 0:
                # Odd number of terms: the max pick took the last one.
                log("min_val: n/a (all products consumed)")
            else:
                # Smallest unused term.
                min_row, min_col = divmod(product_for_min.argmin().item(), d_k)
                min_val = product[min_row, min_col]
                product_for_max[min_row, min_col] = float('-inf')
                product_for_min[min_row, min_col] = float('inf')
                remaining -= 1
                greedy_score[min_row] += min_val
                log(f"min_val: {min_val:.2f}, min_idx: {min_row}")
            log(f"Greedy Score: {greedy_score}")

        # Set the <0 values to 0
        greedy_score[greedy_score < 0] = 0
        total_score[q_id] = greedy_score

    log(f"Approx QKT: {total_score}")

    # Scale by sqrt(d_k) and normalise over the key axis.
    total_score = total_score / (d_k ** 0.5)
    return torch.nn.functional.softmax(total_score, dim=-1)

In [4]:
# Toy example: a single query vector with 3 features ...
Q = torch.tensor([
    [0.8, -0.3, 0.4],
])
# ... attended over 4 key vectors with the same 3 features.
K = torch.tensor([
    [-0.6, 0.1, 0.8],
    [0.1, -0.2, -0.9],
    [0.8, 0.6, 0.7],
    [0.5, 0.7, 0.5]
])
# K * q has 4 x 3 = 12 product terms; each greedy iteration picks one maximum
# and one minimum, so 6 iterations select 12 terms for this data.
max_iter = 6

# Exact attention weights vs. the greedy approximation (both print debug info).
reference = Scaled_Product(Q, K)
ours = Approx_Scale_Product(Q, K, max_iter)
# Last expression of the cell: displayed as the cell output for comparison.
reference, ours


QKT: tensor([[-0.1900, -0.2200,  0.7400,  0.3900]])
Product: tensor([[-0.4800, -0.0300,  0.3200],
        [ 0.0800,  0.0600, -0.3600],
        [ 0.6400, -0.1800,  0.2800],
        [ 0.4000, -0.2100,  0.2000]])
Iteration 0: max_val: 0.64, max_idx: 2 | min_val: -0.48, min_idx: 0
Greedy Score: tensor([-0.4800,  0.0000,  0.6400,  0.0000])
Iteration 1: max_val: 0.40, max_idx: 3 | min_val: -0.36, min_idx: 1
Greedy Score: tensor([-0.4800, -0.3600,  0.6400,  0.4000])
Iteration 2: max_val: 0.32, max_idx: 0 | min_val: -0.21, min_idx: 3
Greedy Score: tensor([-0.1600, -0.3600,  0.6400,  0.1900])
Iteration 3: max_val: 0.28, max_idx: 2 | min_val: -0.18, min_idx: 2
Greedy Score: tensor([-0.1600, -0.3600,  0.7400,  0.1900])
Iteration 4: max_val: 0.20, max_idx: 3 | min_val: -0.03, min_idx: 0
Greedy Score: tensor([-0.1900, -0.3600,  0.7400,  0.3900])
Iteration 5: max_val: 0.08, max_idx: 1 | min_val: 0.06, min_idx: 1
Greedy Score: tensor([-0.1900, -0.2200,  0.7400,  0.3900])
Approx QKT: tensor([[0.0000, 

(tensor([[0.1964, 0.1930, 0.3360, 0.2745]]),
 tensor([[0.2090, 0.2090, 0.3203, 0.2617]]))