In [3]:
import numpy as np 
import math


def safe_softmax(vector):
  """Numerically stable softmax computed in three explicit passes.

  Pass 1 finds the largest entry, pass 2 accumulates the sum of
  exp(entry - largest), and pass 3 divides every shifted exponential by
  that sum. Entries are read as vector[0] .. vector[np.size(vector) - 1],
  so the loops are written for 1-D input.

  Returns a new 1-D float array with np.size(vector) entries.
  """
  count = np.size(vector)
  positions = range(count)

  # Pass 1: running maximum; subtracting it keeps every exponent <= 0.
  peak = float("-inf")
  for idx in positions:
    peak = max(vector[idx], peak)

  # Pass 2: normaliser = sum of the shifted exponentials.
  total = 0
  for idx in positions:
    total += math.exp(vector[idx] - peak)

  # Pass 3: normalise each shifted exponential.
  probs = np.zeros(count)
  for idx in positions:
    probs[idx] = math.exp(vector[idx] - peak) / total

  return probs

def ref_softmax(vector):
  """Vectorised NumPy softmax used as the reference result.

  The maximum over all entries of `vector` is subtracted before
  exponentiating, and the exponentials are divided by their overall sum.
  """
  peak = np.max(vector)  # local name chosen so the builtin `max` is not shadowed
  shifted_exp = np.exp(vector - peak)
  return shifted_exp / np.sum(shifted_exp)


def safe_online_softmax(vector):
  """Numerically stable softmax whose max and normaliser share one pass.

  While scanning the input, the running sum of exponentials is rescaled by
  exp(old_max - new_max) every time the running maximum is updated, so the
  maximum and the normaliser are both ready after a single sweep. A second
  sweep writes the normalised values.

  Returns a new 1-D float array with np.size(vector) entries.
  """
  count = np.size(vector)

  running_max = float("-inf")
  running_sum = 0
  for idx in range(count):
    value = vector[idx]
    new_max = max(running_max, value)
    # Re-express the sum accumulated under the old maximum relative to the
    # new one, then add the current element's contribution.
    running_sum = running_sum*math.exp(running_max-new_max) + math.exp(value-new_max)
    running_max = new_max

  probs = np.zeros(count)
  for idx in range(count):
    probs[idx] = math.exp(vector[idx]-running_max)/running_sum
  return probs



In [4]:
# Run the three softmax implementations on the same random input.
vector = np.random.randn(16)
ref_out = ref_softmax(vector)
out = safe_softmax(vector)
online_out = safe_online_softmax(vector)

# Fail loudly if either loop-based version drifts from the NumPy reference,
# rather than relying on a visual comparison of the displayed arrays.
np.testing.assert_allclose(out, ref_out, rtol=1e-10, atol=0)
np.testing.assert_allclose(online_out, ref_out, rtol=1e-10, atol=0)


In [5]:
# Display the safe_online_softmax result for the random input.
online_out


array([0.04818763, 0.28085782, 0.00610432, 0.00608828, 0.0050852 ,
       0.02365462, 0.17424583, 0.05015363, 0.11174846, 0.06179297,
       0.0120769 , 0.00582275, 0.08914052, 0.0652335 , 0.05237347,
       0.00743408])

In [6]:
# Display the ref_softmax (NumPy reference) result for the same input.
ref_out

array([0.04818763, 0.28085782, 0.00610432, 0.00608828, 0.0050852 ,
       0.02365462, 0.17424583, 0.05015363, 0.11174846, 0.06179297,
       0.0120769 , 0.00582275, 0.08914052, 0.0652335 , 0.05237347,
       0.00743408])

In [7]:
# Length-10 vector filled with -inf; no later visible cell reads `x`.
x = np.full(10, float('-inf'))

In [8]:
class single_element_attention:
  """Scalar attention output for one query row and one value column.

  Holds a random query row Q_row (1, D), random transposed keys K_T
  (D, L_k) and a random value column V_col (L_k, 1), plus two (1, L_k)
  scratch buffers: S_row for scores and P_row for probabilities. The four
  compute methods return the same softmax-weighted sum of V_col, using
  progressively fewer passes over the scores.
  """

  def __init__ (self, D, L_k):
    # Random draws happen in the order Q_row, K_T, V_col.
    self.Q_row = np.random.randn(1,D)
    self.K_T = np.random.randn(D,L_k)
    self.V_col = np.random.randn(L_k,1)
    self.S_row = np.zeros((1,L_k))
    self.P_row = np.zeros_like(self.S_row)

  def zero_out(self):
    """Replace both scratch buffers with fresh all-zero arrays."""
    self.S_row = np.zeros_like(self.S_row)
    self.P_row = np.zeros_like(self.S_row)

  def _fill_score(self, lk):
    """Accumulate the query/key dot product into S_row[0, lk] and return it.

    Products are added one at a time, in increasing d, on top of the value
    already stored in S_row[0, lk] (zero right after zero_out()).
    """
    acc = self.S_row[0,lk]
    for d in range(self.K_T.shape[0]):
      acc += self.Q_row[0,d]*self.K_T[d,lk]
    self.S_row[0,lk] = acc
    return acc

  def naive_attention(self):
    """Materialise S_row and P_row in full, then take the weighted sum.

    Uses separate passes for scores, max, normaliser, probabilities and the
    output. Leaves the scores in S_row and the probabilities in P_row.
    """
    self.zero_out()
    n_keys = self.K_T.shape[1]
    key_ids = range(n_keys)

    for lk in key_ids:
      self._fill_score(lk)

    peak = float("-inf")
    for lk in key_ids:
      peak = max(peak, self.S_row[0,lk])

    total = 0
    for lk in key_ids:
      total += math.exp((self.S_row[0,lk]-peak))

    for lk in key_ids:
      self.P_row[0,lk] = math.exp((self.S_row[0,lk]-peak))/total

    result = 0
    for lk in key_ids:
      result += self.P_row[0,lk]*self.V_col[lk,0]
    return result

  def attention_V0(self):
    """Same passes as naive_attention but without writing P_row.

    Each probability is computed on the fly inside the output loop, so
    P_row stays all zeros.
    """
    self.zero_out()
    n_keys = self.K_T.shape[1]
    key_ids = range(n_keys)

    for lk in key_ids:
      self._fill_score(lk)

    peak = float("-inf")
    for lk in key_ids:
      peak = max(peak, self.S_row[0,lk])

    total = 0
    for lk in key_ids:
      total += math.exp((self.S_row[0,lk]-peak))

    result = 0
    for lk in key_ids:
      result += (math.exp((self.S_row[0,lk]-peak))/total)*self.V_col[lk,0]
    return result

  def attention_V1(self):
    """Online softmax: max and normaliser are built while scoring.

    The running normaliser is rescaled by exp(old_max - new_max) whenever
    the running maximum is updated; a second loop then forms the output.
    """
    self.zero_out()
    n_keys = self.K_T.shape[1]

    running_max = float("-inf")
    running_sum = 0
    for lk in range(n_keys):
      score = self._fill_score(lk)
      new_max = max(running_max, score)
      running_sum = running_sum*math.exp(running_max-new_max) + math.exp(score-new_max)
      running_max = new_max

    result = 0
    for lk in range(n_keys):
      result += (math.exp((self.S_row[0,lk]-running_max))/running_sum)*self.V_col[lk,0]
    return result

  def attention_V2(self):
    """Flash-style single loop: the output is accumulated online as well.

    The unnormalised output and the normaliser share the same
    exp(old_max - new_max) correction at every step; the division by the
    normaliser happens once at the end.
    """
    self.zero_out()
    n_keys = self.K_T.shape[1]

    running_max = float("-inf")
    running_sum = 0
    running_out = 0
    for lk in range(n_keys):
      score = self._fill_score(lk)
      new_max = max(running_max, score)
      carry = math.exp(running_max-new_max)   # rescales everything accumulated so far
      weight = math.exp(score-new_max)        # unnormalised weight of this key
      running_sum = running_sum*carry + weight
      running_out = running_out*carry + (weight*self.V_col[lk,0])
      running_max = new_max

    return running_out/running_sum
    

In [9]:
D = 4
L_k = 32
attn = single_element_attention(D,L_k) 

naive_out = attn.naive_attention()
v0_out = attn.attention_V0()
v1_out = attn.attention_V1()
v2_out = attn.attention_V2()

# The four variants compute the same quantity with different pass structures;
# check agreement programmatically instead of only by reading printed values.
np.testing.assert_allclose([v0_out, v1_out, v2_out], naive_out, rtol=1e-9, atol=1e-12)

In [10]:
# One result per line, in the order naive, V0, V1, V2.
print(naive_out, v0_out, v1_out, v2_out, sep="\n")

-0.021774483024603466
-0.021774483024603466
-0.021774483024603466
-0.021774483024603438


In [11]:
class attention: 
  """Batched softmax attention on random Q, K, V, with two tiled variants.

  Q has shape (B, Lq, D); K and V have shape (B, Lkv, D). Scores are the raw
  Q @ K^T products; no method applies a 1/sqrt(D) scaling.
  """

  def __init__ (self, B,Lq,Lkv,D):
    self.Q = np.random.randn(B,Lq,D)
    self.K = np.random.randn(B,Lkv,D)
    self.V = np.random.randn(B,Lkv,D)

  @staticmethod
  def _check_tiles(tile_Lq, tile_Lkv):
    """Reject non-positive tile sizes.

    A negative step would make the tile loops iterate zero times and leave
    the output silently filled with zeros.
    """
    if tile_Lq < 1 or tile_Lkv < 1:
      raise ValueError("tile_Lq and tile_Lkv must be positive")

  def naive_attention(self): 
    """Reference implementation that materialises the full score matrix.

    Returns an array of shape (B, Lq, D).
    """
    scores = np.matmul(self.Q, self.K.transpose(0,2,1))
    row_max = np.max(scores, axis=2, keepdims=True)
    unnorm = np.exp(scores-row_max)   # shifted so the largest exponent per row is 0
    row_sum = np.sum(unnorm, axis=2, keepdims=True)
    return np.matmul(unnorm/row_sum, self.V)

  def flash_attention(self,tile_Lq, tile_Lkv):
    """Tiled attention that merges per-tile softmax statistics.

    For each block of query rows, the K/V rows are streamed in chunks. Every
    chunk yields its own row max, row sum and unnormalised output; these are
    merged into the running state by rescaling both sides to the new common
    maximum. The division by the row sum happens once per query block.

    Tile sizes need not divide Lq or Lkv: the final tile in either direction
    may be shorter.

    Args:
      tile_Lq: number of query rows per block (positive).
      tile_Lkv: number of key/value rows per chunk (positive).

    Returns:
      Array of shape (B, Lq, D).

    Raises:
      ValueError: if a tile size is not positive.
    """
    self._check_tiles(tile_Lq, tile_Lkv)
    B,Lq,D = self.Q.shape 
    _,Lkv,_ = self.K.shape 
    O = np.zeros((B,Lq,D))

    for b in range(B): 
      for lq_start in range(0,Lq,tile_Lq): 
        Q_tile = self.Q[b,lq_start:lq_start+tile_Lq,:] #load Q tile once
        # Size the running state from the rows actually loaded: the last Q
        # tile is shorter than tile_Lq when tile_Lq does not divide Lq.
        rows = Q_tile.shape[0]
        m = np.full((rows,1),-np.inf)
        row_sum = np.zeros((rows,1))
        O_unorm_tile = np.zeros((rows,D))

        for lkv_start in range(0, Lkv, tile_Lkv): #stream over tile_Lkv chunks 
          K_tile = self.K[b,lkv_start:lkv_start + tile_Lkv,:] #load k_tile
          V_tile = self.V[b,lkv_start:lkv_start + tile_Lkv,:] #load V_tile
          S_tile = np.matmul(Q_tile, K_tile.transpose(1,0))

          # statistics of this chunk alone, relative to its own row max
          m_p = np.max(S_tile,axis=1,keepdims=True) 
          P_unorm_tile = np.exp(S_tile-m_p)
          sum_p = np.sum(P_unorm_tile,axis=1,keepdims=True)
          O_unorm_tile_p = np.matmul(P_unorm_tile,V_tile)

          # merge: bring running state and chunk state to the common max
          m_new = np.maximum(m,m_p)
          old_scale = np.exp(m - m_new)
          chunk_scale = np.exp(m_p - m_new)
          row_sum = (old_scale*row_sum) + (chunk_scale*sum_p)
          O_unorm_tile = (old_scale*O_unorm_tile) + (chunk_scale*O_unorm_tile_p)
          m = m_new

        O[b,lq_start:lq_start+rows,:] = O_unorm_tile/row_sum

    return O 

  def flash_attention_algorithmic(self,tile_Lq, tile_Lkv):
    """Tiled attention that exponentiates directly against the updated max.

    Unlike flash_attention, each chunk's scores are shifted by the new
    running maximum, so only the previously accumulated sum and output need
    the exp(old_max - new_max) correction.

    Tile sizes need not divide Lq or Lkv: the final tile in either direction
    may be shorter.

    Args:
      tile_Lq: number of query rows per block (positive).
      tile_Lkv: number of key/value rows per chunk (positive).

    Returns:
      Array of shape (B, Lq, D).

    Raises:
      ValueError: if a tile size is not positive.
    """
    self._check_tiles(tile_Lq, tile_Lkv)
    B,Lq,D = self.Q.shape 
    _,Lkv,_ = self.K.shape 
    O = np.zeros((B,Lq,D))

    for b in range(B): 
      for lq_start in range(0,Lq,tile_Lq): 
        Q_tile = self.Q[b,lq_start:lq_start+tile_Lq,:] #load Q tile once
        # Size the running state from the rows actually loaded: the last Q
        # tile is shorter than tile_Lq when tile_Lq does not divide Lq.
        rows = Q_tile.shape[0]
        m = np.full((rows,1),-np.inf)
        row_sum = np.zeros((rows,1))
        O_unorm_tile = np.zeros((rows,D))

        for lkv_start in range(0, Lkv, tile_Lkv): #stream over tile_Lkv chunks 
          K_tile = self.K[b,lkv_start:lkv_start + tile_Lkv,:] #load k_tile
          V_tile = self.V[b,lkv_start:lkv_start + tile_Lkv,:] #load V_tile
          S_tile = np.matmul(Q_tile, K_tile.transpose(1,0))
          curr_m = np.max(S_tile,axis=1,keepdims=True) 
          new_m = np.maximum(m,curr_m)
          scale = np.exp(m-new_m)   # correction for everything accumulated so far

          P_unorm_tile = np.exp(S_tile-new_m)
          curr_sum = np.sum(P_unorm_tile,axis=1,keepdims=True)
          row_sum = (scale*row_sum) + (curr_sum)
          O_unorm_tile *= scale
          O_unorm_tile += np.matmul(P_unorm_tile,V_tile)

          m = new_m

        O[b,lq_start:lq_start+rows,:] = O_unorm_tile/row_sum

    return O 
      

          
          
          

        
        
      
      
      
  

In [12]:
# Problem size for the batched attention object: batch, query length,
# key/value length, feature dimension.
B, Lq, Lkv, D = 4, 512, 1024, 128

full_attn = attention(B, Lq, Lkv, D)

In [14]:
naive_out = full_attn.naive_attention()
flash_out = full_attn.flash_attention(64,64)
flash_alg_out = full_attn.flash_attention_algorithmic(64,64)

# Check both tiled variants against the reference; the value displayed below
# only covers flash_alg_out.
np.testing.assert_allclose(flash_out, naive_out, rtol=1e-9, atol=1e-9)
np.testing.assert_allclose(flash_alg_out, naive_out, rtol=1e-9, atol=1e-9)

np.max(np.abs(naive_out - flash_alg_out))

np.float64(3.186340080674199e-14)