In [1]:
import torch

In [2]:
import torch
import math

def attention(Q, K, V):
    """Compute scaled dot-product attention: softmax(Q @ K^T / sqrt(D)) @ V.

    Any number of leading batch dimensions is accepted (e.g. ``(B, H)``):
        Q : (..., N_q, D)    queries
        K : (..., N_k, D)    keys
        V : (..., N_k, D_v)  values
    where B is the batch size, H the number of heads, N_q / N_k the query /
    key sequence lengths and D the dimension per head. Self-attention in the
    usual (B, H, N, D) layout is the special case N_q == N_k == N, D_v == D.

    Returns:
        Tensor of shape (..., N_q, D_v). For every query position it holds the
        attention-weighted average of the value vectors.
    """
    d_head = Q.shape[-1]

    # (..., N_q, D) @ (..., D, N_k) -> (..., N_q, N_k), scaled by 1/sqrt(D)
    scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_head)

    # Row-wise softmax over the key axis. Each row becomes a probability
    # distribution describing how strongly query i attends to every key j.
    # Subtracting the row max before exp() prevents overflow for large scores.
    # Softmax is invariant to that per-row shift, so the result is unchanged.
    row_max = scores.max(dim=-1, keepdim=True).values
    exp_scores = torch.exp(scores - row_max)
    probs = exp_scores / exp_scores.sum(dim=-1, keepdim=True)  # (..., N_q, N_k)

    # (..., N_q, N_k) @ (..., N_k, D_v) -> (..., N_q, D_v)
    return torch.matmul(probs, V)