In [4]:
import numpy as np 
import math


def safe_softmax(vector):
  """Numerically stable softmax computed with three explicit passes.

  Pass 1 finds the largest entry. Pass 2 evaluates exp(x - largest) for every
  entry and accumulates their sum in index order. Pass 3 divides each term by
  that sum. Shifting by the largest entry keeps every exponent <= 0, so
  math.exp cannot overflow for finite input.

  Args:
    vector: sequence of floats, expected to be 1-D (entries are read
      positionally as vector[idx], with the length taken from np.size).

  Returns:
    A new float64 ndarray of the same length holding the probabilities.
  """
  n = np.size(vector)

  # Pass 1: the maximum entry, used as the shift.
  peak = float("-inf")
  for idx in range(n):
    peak = max(vector[idx], peak)

  # Pass 2: shifted exponentials and their running total.
  terms = [math.exp(vector[idx] - peak) for idx in range(n)]
  total = 0
  for term in terms:
    total += term

  # Pass 3: normalise each term.
  probs = np.zeros(n)
  for idx, term in enumerate(terms):
    probs[idx] = term / total
  return probs

def ref_softmax(vector, axis=None):
  """Vectorised NumPy softmax used as the reference for the loop versions.

  Args:
    vector: array-like of scores.
    axis: axis along which to normalise. The default ``None`` normalises over
      all elements (for 1-D input this is the ordinary softmax). Pass
      ``axis=-1`` for a row-wise softmax of a 2-D score matrix.

  Returns:
    ndarray with the same shape as ``vector``. Its entries sum to 1 along
    ``axis``, or over the whole array when ``axis`` is None.
  """
  x = np.asarray(vector)
  # Shift by the max (kept as a broadcastable dim) so exp never overflows.
  # Named `peak` rather than `max` to avoid shadowing the builtin.
  peak = np.max(x, axis=axis, keepdims=True)
  shifted = np.exp(x - peak)
  return shifted / np.sum(shifted, axis=axis, keepdims=True)


def safe_online_softmax(vector):
  """Numerically stable softmax whose normaliser is built in one streaming pass.

  The first loop visits each entry once and keeps the running maximum `maxim`
  and the running denominator sum(exp(x - maxim)). Whenever the maximum grows,
  the accumulated denominator is rescaled by exp(old_max - new_max). A second
  loop then normalises every entry against the final maximum and denominator.

  Entries equal to -inf contribute exactly 0, so leading -inf values (e.g.
  masked scores) no longer turn the running denominator into NaN. If every
  entry is -inf the softmax is undefined, and an all-NaN array is returned,
  as safe_softmax and ref_softmax also produce.

  Args:
    vector: sequence of floats, expected to be 1-D.

  Returns:
    A new float64 ndarray of the same length holding the probabilities.
  """
  neg_inf = float("-inf")
  l = np.size(vector)
  maxim = neg_inf
  denom = 0
  for i in range(l):
    new_maxim = max(maxim, vector[i])
    if new_maxim == neg_inf:
      # Only -inf entries so far: they add nothing. The update below would
      # evaluate exp(-inf - (-inf)) = exp(nan) and poison denom permanently.
      continue
    denom = denom*math.exp(maxim - new_maxim) + math.exp(vector[i] - new_maxim)
    maxim = new_maxim

  if maxim == neg_inf:
    # Empty input (gives an empty array) or every entry is -inf (0/0 -> NaN).
    return np.full(l, np.nan)

  out = np.zeros(l)
  for i in range(l):
    out[i] = math.exp(vector[i] - maxim)/denom
  return out



In [5]:
# Compare the three softmax implementations on a seeded random vector so the
# run is reproducible. The asserts make the agreement explicit rather than
# relying on comparing displayed arrays by eye.
softmax_rng = np.random.default_rng(0)
vector = softmax_rng.standard_normal(16)
ref_out = ref_softmax(vector)
out = safe_softmax(vector)
online_out = safe_online_softmax(vector)
np.testing.assert_allclose(out, ref_out, rtol=1e-12)
np.testing.assert_allclose(online_out, ref_out, rtol=1e-12)


In [6]:
# Online (single-pass) softmax result, shown for comparison with ref_out.
online_out


array([0.02561199, 0.02492177, 0.01480078, 0.02532539, 0.08511779,
       0.05392862, 0.01046387, 0.00611335, 0.06218651, 0.02227012,
       0.03050929, 0.04931391, 0.40017593, 0.15116785, 0.00879579,
       0.02929706])

In [7]:
# NumPy reference softmax (ref_softmax), shown for comparison with online_out.
ref_out

array([0.02561199, 0.02492177, 0.01480078, 0.02532539, 0.08511779,
       0.05392862, 0.01046387, 0.00611335, 0.06218651, 0.02227012,
       0.03050929, 0.04931391, 0.40017593, 0.15116785, 0.00879579,
       0.02929706])

In [8]:
# Edge case: a fully masked input (every score -inf). Softmax is undefined
# here (0/0), so each implementation should return all-NaN rather than raise.
x = np.array([float('-inf')]*10)
with np.errstate(invalid="ignore"):  # -inf - (-inf) is expected to give NaN
  assert np.isnan(ref_softmax(x)).all()
  assert np.isnan(safe_softmax(x)).all()
  assert np.isnan(safe_online_softmax(x)).all()

In [34]:
class single_element_attention:
  """Attention for one query row against L_k keys, written with scalar loops.

  Holds a query row Q_row (1 x D), keys K_T (D x L_k) and a value column
  V_col (L_k x 1). Every attention_* method returns the same scalar,
  softmax(Q_row @ K_T) @ V_col, but computes it progressively more lazily:

    naive_attention : materialise the scores S_row and probabilities P_row.
    attention_V0    : materialise S_row only; fold P into the output sum.
    attention_V1    : online softmax, with a running max and denominator.
    attention_V2    : flash-style; also rescale a running output, so no
                      second sweep over the scores is needed.

  S_row and P_row are scratch buffers that every attention_* call resets.
  """

  def __init__(self, D, L_k, rng=None):
    """Draw random Q_row, K_T and V_col.

    Args:
      D: feature dimension shared by the query row and the keys.
      L_k: number of keys/values.
      rng: optional random source exposing ``standard_normal(size)`` (e.g.
        ``np.random.default_rng(seed)``) for reproducible draws. Defaults to
        NumPy's global legacy generator; ``standard_normal(size)`` there
        draws the same values ``np.random.randn(*size)`` would.
    """
    if rng is None:
      rng = np.random
    self.Q_row = rng.standard_normal((1, D))
    self.K_T = rng.standard_normal((D, L_k))
    self.S_row = np.zeros((1, L_k))
    self.P_row = np.zeros_like(self.S_row)
    self.V_col = rng.standard_normal((L_k, 1))

  def zero_out(self):
    """Reset the S_row and P_row scratch buffers to zeros."""
    self.S_row = np.zeros_like(self.S_row)
    self.P_row = np.zeros_like(self.S_row)

  def _score(self, lk):
    """Dot product of Q_row with column lk of K_T, accumulated in d order."""
    D = self.K_T.shape[0]
    s = 0.0
    for d in range(D):
      s += self.Q_row[0, d] * self.K_T[d, lk]
    return s

  def naive_attention(self):
    """Reference attention: materialise S_row, then P_row, then the output.

    Uses the three-pass safe softmax (max, denominator, normalise) and
    overwrites S_row and P_row. Returns 0 when there are no keys.
    """
    out = 0
    self.zero_out()
    _, L_k = self.K_T.shape
    for lk in range(L_k):
      self.S_row[0, lk] = self._score(lk)
    maxim = float("-inf")
    denom = 0
    for lk in range(L_k):
      maxim = max(maxim, self.S_row[0, lk])
    for lk in range(L_k):
      denom += math.exp(self.S_row[0, lk] - maxim)
    for lk in range(L_k):
      self.P_row[0, lk] = math.exp(self.S_row[0, lk] - maxim) / denom
    for lk in range(L_k):
      out += self.P_row[0, lk] * self.V_col[lk, 0]
    return out

  def attention_V0(self):
    """Like naive_attention, but P_row is never materialised.

    Each probability exp(s - max) / denom is folded straight into the output
    sum. Overwrites S_row.
    """
    out = 0
    self.zero_out()
    _, L_k = self.K_T.shape
    for lk in range(L_k):
      self.S_row[0, lk] = self._score(lk)
    maxim = float("-inf")
    denom = 0
    for lk in range(L_k):
      maxim = max(maxim, self.S_row[0, lk])
    for lk in range(L_k):
      denom += math.exp(self.S_row[0, lk] - maxim)
    for lk in range(L_k):
      out += (math.exp(self.S_row[0, lk] - maxim) / denom) * self.V_col[lk, 0]
    return out

  def attention_V1(self):
    """Online softmax: one streaming pass for scores, max and denominator.

    While each score is produced, the running maximum and the running
    denominator sum(exp(s - maxim)) are updated. When the maximum grows, the
    old denominator is rescaled by exp(old_max - new_max). A second pass then
    forms the weighted sum of V_col. Overwrites S_row.

    Scores equal to -inf contribute 0. If every score is -inf the softmax is
    undefined and NaN is returned, as naive_attention does.
    """
    neg_inf = float("-inf")
    out = 0
    self.zero_out()
    maxim = neg_inf
    denom = 0
    _, L_k = self.K_T.shape

    for lk in range(L_k):
      s = self._score(lk)
      self.S_row[0, lk] = s
      new_maxim = max(maxim, s)
      if new_maxim == neg_inf:
        # Only -inf scores so far: they contribute 0, and updating would
        # evaluate exp(-inf - (-inf)) = exp(nan), poisoning denom with NaN.
        continue
      denom = denom * math.exp(maxim - new_maxim) + math.exp(s - new_maxim)
      maxim = new_maxim

    if L_k and maxim == neg_inf:
      return float("nan")
    for lk in range(L_k):
      out += (math.exp(self.S_row[0, lk] - maxim) / denom) * self.V_col[lk, 0]
    return out

  def attention_V2(self):
    """Flash-style attention: a single streaming pass with no second sweep.

    Alongside the running max and denominator, keeps a running unnormalised
    output sum(exp(s - maxim) * v). It is rescaled by the same factor
    exp(old_max - new_max) whenever the maximum grows, and the result is
    output / denom. S_row is still filled in, but nothing here reads it back.

    Scores equal to -inf contribute 0. As in naive_attention, the result is
    NaN if every score is -inf and 0 if there are no keys.
    """
    neg_inf = float("-inf")
    self.zero_out()
    maxim = neg_inf
    denom = 0
    out = 0
    _, L_k = self.K_T.shape

    for lk in range(L_k):
      s = self._score(lk)
      self.S_row[0, lk] = s
      new_maxim = max(maxim, s)
      if new_maxim == neg_inf:
        # Nothing but -inf so far; avoid exp(-inf - (-inf)) = exp(nan).
        continue
      scale = math.exp(maxim - new_maxim)
      weight = math.exp(s - new_maxim)
      denom = denom * scale + weight
      out = out * scale + weight * self.V_col[lk, 0]
      maxim = new_maxim

    if L_k == 0:
      return out  # empty weighted sum (0), matching naive_attention
    if maxim == neg_inf:
      return float("nan")
    return out / denom
    

In [35]:
D = 4
L_k = 32
np.random.seed(0)  # single_element_attention draws Q, K, V from NumPy's global RNG
attn = single_element_attention(D,L_k)

naive_out = attn.naive_attention()
v0_out = attn.attention_V0()
v1_out = attn.attention_V1()
v2_out = attn.attention_V2()

# Every variant computes the same softmax-weighted sum; they may differ only
# by floating-point rounding from their different accumulation orders.
for variant_out in (v0_out, v1_out, v2_out):
  assert math.isclose(variant_out, naive_out, rel_tol=1e-12, abs_tol=1e-12)

In [36]:
# One line per variant, in order: naive, V0, V1, V2.
for result in (naive_out, v0_out, v1_out, v2_out):
  print(result)

0.2936816099532507
0.2936816099532507
0.2936816099532507
0.29368160995325077
