In [None]:
def debiasLerp(a, b, p, nBatches):
    """Advance a bias-corrected exponential moving average by one step.

    The running average starts from zero, so a plain
    ``lerp(a, b, p)`` under-estimates the true average for the first few
    updates.  This helper removes that bias the same way Adam does: the
    raw average is divided by ``1 - (1 - p) ** t``.

    Args:
        a: The debiased running average after ``nBatches - 1`` updates
            (all zeros before the first update).
        b: The new observation to fold in.
        p: Momentum, i.e. the weight given to ``b``; must be in ``(0, 1]``.
        nBatches: 1-based index of this update.

    Returns:
        The debiased running average after ``nBatches`` updates.  For the
        first update this is exactly ``b``; for a constant stream of
        observations it stays at that constant.

    Raises:
        ValueError: If ``nBatches < 1`` or ``p`` is outside ``(0, 1]``;
            the correction factor would be zero or meaningless.
    """
    if nBatches < 1:
        raise ValueError("nBatches must be >= 1, got %r" % (nBatches,))
    if not 0 < p <= 1:
        raise ValueError("p must be in (0, 1], got %r" % (p,))

    decay = 1 - p
    # Undo the previous step's correction to recover the raw (biased) average.
    previousCorrection = 1 - decay ** (nBatches - 1)
    currentCorrection = 1 - decay ** nBatches
    rawAverage = torch.lerp(a * previousCorrection, b, p)
    return rawAverage / currentCorrection
    
class BatchNorm(HelpfulModule):
    """Batch normalisation that normalises with debiased running statistics.

    Three running averages are tracked (one value per feature):

    * ``running_mean``            -- running average of the batch mean
    * ``running_sum_of_squares``  -- running average of the batch mean of x**2
    * ``running_squared_sum``     -- running average of (batch mean)**2

    The last two buffer names are historical; both hold averages, not sums.
    Their difference is the running average of the per-batch (biased)
    variance.

    Args:
        inputSize: Number of features; size of the learnable scale/shift.
        mom: Momentum passed to ``debiasLerp`` for every running statistic.
        eps: Lower bound applied to the variance before taking the square
            root, which keeps the division well defined.
    """

    def __init__(self, inputSize, mom=0.1, eps=0.01):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.inputSize = inputSize
        self.multiplicitiveWeight = nn.Parameter(torch.ones([inputSize]))
        self.additiveWeight = nn.Parameter(torch.zeros([inputSize]))
        self.register_buffer('running_mean', torch.zeros(inputSize))
        self.register_buffer('running_sum_of_squares', torch.zeros(inputSize))
        self.register_buffer('running_squared_sum', torch.zeros(inputSize))
        # Number of statistic updates performed so far; forward() increments
        # it *before* use, so the first update is step 1.
        self.nBatches = 0

    def resetParams(self):
        """Zero every running statistic and restart the update counter."""
        self.running_mean.zero_()
        self.running_sum_of_squares.zero_()
        self.running_squared_sum.zero_()
        self.nBatches = 0

    def forward(self, x):
        """Normalise ``x`` and apply the learnable scale and shift.

        Args:
            x: Input tensor; statistics are reduced over dim 0.
                Assumed to be ``(batch, inputSize)`` -- TODO confirm with
                callers.

        Returns:
            ``(x - mean) / sqrt(max(var, eps)) * multiplicitiveWeight
            + additiveWeight``.
        """
        if self.training:
            self.nBatches += 1
            batchMean = x.mean(dim=0)
            # Each statistic is advanced from its OWN previous value.
            mu = debiasLerp(self.running_mean, batchMean,
                            self.mom, self.nBatches)
            meanOfSquares = debiasLerp(self.running_sum_of_squares,
                                       x.pow(2).mean(dim=0),
                                       self.mom, self.nBatches)
            squaredMean = debiasLerp(self.running_squared_sum,
                                     batchMean.pow(2),
                                     self.mom, self.nBatches)
            # Store detached copies: the locals keep the gradient path
            # through the current batch, while the buffers never hold on to
            # an old autograd graph (which would break the next backward()).
            self.running_mean = mu.detach()
            self.running_sum_of_squares = meanOfSquares.detach()
            self.running_squared_sum = squaredMean.detach()
        else:
            # Evaluation must not leak test data into the running statistics.
            mu = self.running_mean
            meanOfSquares = self.running_sum_of_squares
            squaredMean = self.running_squared_sum

        # clamp() with a Python float avoids creating a CPU tensor that would
        # mismatch the device of the statistics.
        var = torch.clamp(meanOfSquares - squaredMean, min=self.eps)
        normalizedOutput = (x - mu) / var.sqrt()
        return normalizedOutput * self.multiplicitiveWeight + self.additiveWeight
    
class MultiHeadSelfAttentionOld(HelpfulModule):
    """Scaled dot-product self-attention over a ``[b, n, d]`` input.

    Despite the name, this implementation uses a single set of Q/K/V
    projections (one head).

    Args:
        n: Stored on the module; not used by the computation below.
        d: Feature size of each input vector (last dim of ``x``).
        k: Size of the projected query/key/value vectors.
    """

    def __init__(self, n, d, k):
        super().__init__()
        self.n, self.d, self.k = n, d, k
        # TODO: compute initialization scaling factors
        # TODO: What about more things than just QKV? Like four or five or something
        self.Q = nn.Parameter(torch.normal(0, 1, [k, d]))
        self.K = nn.Parameter(torch.normal(0, 1, [k, d]))
        self.V = nn.Parameter(torch.normal(0, 1, [k, d]))
        # Projects the attended [.., k] vectors back to [.., d].
        self.Wch = nn.Parameter(torch.normal(0, 1, [d, k]))
        self.softmax = torch.nn.Softmax(dim=2)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape ``[b, n, d]``.

        Returns:
            Tensor of shape ``[b, n, d]``.
        """
        # Project every input vector: [k, d] x [b, n, d] -> [b, n, k].
        queries = torch.einsum("kd,bnd->bnk", self.Q, x)
        keys = torch.einsum("kd,bnd->bnk", self.K, x)
        values = torch.einsum("kd,bnd->bnk", self.V, x)
        # Pairwise query/key dot products, scaled by sqrt of the projection
        # size (the integer self.k, not the key tensor): -> [b, n, n].
        dotQueryKey = torch.einsum("bij,bkj->bik", queries, keys) / math.sqrt(self.k)
        # Softmax over the key axis turns each query's scores into weights.
        queryPrs = self.softmax(dotQueryKey)
        # Weighted sum of value vectors: [b, n, n] x [b, n, k] -> [b, n, k].
        summedRows = torch.einsum("bij,bjk->bik", queryPrs, values)
        return torch.einsum("dk,bnk->bnd", self.Wch, summedRows)
        
        

        
class VerboseMultiHeadSelfAttentionModuleOld(HelpfulModule):
    """Heavily commented scaled dot-product self-attention.

    Despite the name, a single set of Q/K/V projections (one head) is used.

    Args:
        n: Stored on the module; not used by the computation below.
        d: Feature size of each input vector (last dim of ``x``).
        k: Size of the projected query/key/value vectors.
    """

    def __init__(self, n, d, k):
        super().__init__()
        self.n, self.d, self.k = n, d, k
        self.Q = nn.Parameter(torch.normal(0, 1, [k, d]))
        self.K = nn.Parameter(torch.normal(0, 1, [k, d]))
        self.V = nn.Parameter(torch.normal(0, 1, [k, d]))
        self.Wch = nn.Parameter(torch.normal(0, 1, [d, k]))
        self.softmax = torch.nn.Softmax(dim=2)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[b, n, d]``; returns ``[b, n, d]``."""
        # x is [b,n,d]
        # Q, K, and V are [k,d]
        # Q*x[b,i] is [k,d]*[d] = [k]
        # so Q*x should be of size [k,d]*[b,n,d] -> [b,n,k]
        # You can test this works by doing Q@x[0,0] and checking it equals queries[0,0]
        queries = torch.einsum("kd,bnd->bnk", self.Q, x)
        keys = torch.einsum("kd,bnd->bnk", self.K, x)
        values = torch.einsum("kd,bnd->bnk", self.V, x)

        # queries[b,i] is a vector of size k
        # keys[b,j]    is a vector of size k
        # for every pair of (vector from queries, vector from keys), we need to get an output by taking their dot product
        # normally if you have two matrices A and B of size NxM and MxK,
        # when you multiply them, you can think of the output matrix's value in the (i,j)th position as the ith row in A dot jth column in B
        # (thus it is every pair of row from first and column from second)
        # in einsum, torch.einsum("ij,jk->ik", A, B)
        # If instead A and B are of size NxM and NxM and you want to do every pair of rows, you can just do
        # torch.einsum("ij,kj->ik") # transpose second matrix's indices so it takes rows instead of columns
        # we have an additional batch index at the front, so include that
        # The scale factor is sqrt of the projection size self.k (an int), NOT the key tensor.
        dotQueryKey = torch.einsum("bij,bkj->bik", queries, keys) / math.sqrt(self.k)
        # now the [b,i,j]th element of dotQueryKey is the dot product of queries[b,i] and keys[b,j].
        # since queries and keys were of dimension [b,n,k], this becomes [b,n,n]

        # If we look at dotQueryKey[b,i], that is a vector of size n.
        # the jth term is the ith query dot the jth key vector.
        # thus, each term is "how much" the ith query aligns with the jth key.
        # so we will take a softmax so we actually get probabilities (TODO: try sigmoid. Seems odd that it's only one class of things)
        # This doesn't change the dim, so it's still [b,n,n]
        queryPrs = self.softmax(dotQueryKey)

        # the ith output is taking sum over j of (queryPrs[b,i,j])*(values[b,j])
        #                                         scalar            vector
        # queryPrs is [b,n,n]
        # values   is [b,n,k]
        # so         queryPrs[b,i] is of dim n
        #            values[b]     is of dim [n,k]
        # so         j ranges from 0 to n-1
        # fixing b and i and thinking of this as a small matrix, we do
        # queryPrs = [0.4, 0.6] (n=2 in this example)
        #
        # values   = [1.2,     3.4,     5.2] (each row is of length k, k=3 in this example)
        #            [3.4,     2.3,     1.1] (there are n=2 rows)
        # we do
        #            [0.4*1.2, 0.4*3.4, 0.4*5.2 ]
        #            [0.6*3.4, 0.6*2.3, 0.6*1.1 ]
        # and then we sum them:
        #            [0.4*1.2+0.6*3.4, 0.4*3.4+0.6*2.3, 0.4*5.2+0.6*1.1]
        # In other words, for a given i we dot the ith row of queryPrs[b] (dim n) by each column in values[b] (values[b] is [n,k], so each column is dim n)
        # Thus, the output's [b,i,j] value is the ith row of queryPrs[b] dot the jth column of values[b]
        # for regular matrix multiplication of A and B, the [i,j]th value is ith row of A dot jth column of B, so this is just regular matrix multiplication.
        # In einsum: torch.einsum("bij,bjk->bik")
        # which means that our output is [b,n,n] x [b,n,k] -> [b,n,k]
        summedRows = torch.einsum("bij,bjk->bik", queryPrs, values)

        # now we need to project back to a [b,n,d]
        # Wch is a [d,k] dimensional matrix
        # summedRows is [b,n,k]
        # we want to do [d,k]x[b,n,k] -> [b,n,d]
        res = torch.einsum("dk,bnk->bnd", self.Wch, summedRows)
        return res
        
        
        
