In [103]:
import torch
from torch import nn
import math
def approx_equals(a, b):
    assert torch.allclose(a, b, 0.0001), str(a) + "!=" + str(b)

In [104]:

class HelpfulModule(nn.Module):
    def __init__(self):
        super().__init__()
        self._myHyperParams = {}
        
    def __setattr__(self, attr, val):
        super().__setattr__(attr, val) # make sure to call super because torch.nn.Module also overrides this
        simpleTypes = [int, str, float]
        if type(val) in simpleTypes or (type(val) is list and (len(val) == 0 or type(val[0]) in simpleTypes)):
            self._myHyperParams[attr] = val
            
    
    def extra_repr(self):
        return ", ".join([(str(param) + ": " + str(val)) for param, val in self._myHyperParams.items()])

    
def debiasLerp(a, b, p, nBatches):
    return torch.lerp(a, b/torch.tensor(p).pow(nBatches), p)
    
class BatchNorm(HelpfulModule):
    def __init__(self, inputSize, mom=0.1, eps=0.01):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.inputSize = inputSize
        self.multiplicitiveWeight = nn.Parameter(torch.ones([inputSize]))
        self.additiveWeight = nn.Parameter(torch.zeros([inputSize]))
        self.register_buffer('running_mean', torch.zeros(inputSize))
        self.register_buffer('running_sum_of_squares', torch.zeros(inputSize))
        self.register_buffer('running_squared_sum', torch.zeros(inputSize))
        self.nBatches = 1
    
    def resetParams(self):
        self.running_mean.zero_()
        self.running_sum_of_squares.zero_()
        self.running_squared_sum.zero_()
        self.nBatches = 1
        
    def forward(self, x):
        self.nBatches += 1
        self.running_mean = debiasLerp(self.running_mean, x.mean(axis=0), self.mom, self.nBatches)
        self.running_mean_of_squares = debiasLerp(self.running_mean, x.pow(2).mean(axis=0), self.mom, self.nBatches)
        self.running_squared_mean = debiasLerp(self.running_mean, x.mean(axis=0).pow(2), self.mom, self.nBatches)
        mu = self.running_mean
        var = torch.max(self.running_mean_of_squares - self.running_squared_mean, torch.tensor(self.eps))
        normalizedOutput = (x-mu)/var
        return normalizedOutput*self.multiplicitiveWeight+self.additiveWeight

class LayerNorm(HelpfulModule):
    def __init__(self, inputSize, mom=0.1, eps=0.01):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.inputSize = inputSize
        self.multiplicitiveWeight = nn.Parameter(torch.ones(inputSize))
        self.additiveWeight = nn.Parameter(torch.zeros(inputSize))
        self.register_buffer('running_mean', torch.tensor(0.0))
        self.register_buffer('running_sum_of_squares', torch.tensor(0.0))
        self.register_buffer('running_squared_sum', torch.tensor(0.0))
        self.nBatches = 0
    
    def resetParams(self):
        self.running_mean.zero_()
        self.running_sum_of_squares.zero_()
        self.running_squared_sum.zero_()
        self.nBatches = 0
        
    def forward(self, x):
        self.nBatches += 1
        self.running_mean = debiasLerp(self.running_mean, x.mean(), self.mom, self.nBatches)
        self.running_mean_of_squares = debiasLerp(self.running_mean, x.pow(2).mean(), self.mom, self.nBatches)
        self.running_squared_mean = debiasLerp(self.running_mean, x.mean().pow(2), self.mom, self.nBatches)
        mu = self.running_mean
        var = torch.max(self.running_mean_of_squares - self.running_squared_mean, torch.tensor(self.eps))
        normalizedOutput = (x-mu)/var
        return normalizedOutput*self.multiplicitiveWeight+self.additiveWeight


        
class TransformerBlock(HelpfulModule):
    def __init__(self, n, d, k):
        super().__init__()
        self.n, self.d, self.k = n,d,k
        self.Q = nn.Parameter(torch.normal(0, 1, [k, d]))
        self.K = nn.Parameter(torch.normal(0, 1, [k, d]))
        self.V = nn.Parameter(torch.normal(0, 1, [k, d]))
        self.Wch = nn.Parameter(torch.normal(0, 1, [d,k]))
        self.softmax = torch.nn.Softmax(dim=2)
    def forward(self, x):
        # x is [b,n,d]
        # Q, K, and V are [k,d]
        # Q*x[i] is [k,d]*[d] = [k]
        # so Q*x should be of size [k,d]*[b,n,d] -> [b,n,k]
        qh = torch.einsum("kd,bnd->bnk", self.Q, x) # You can test this works by doing Q@x[0,0] and seeing first row is the same row of outputs
        kh = torch.einsum("kd,bnd->bnk", self.K, x)
        vh = torch.einsum("kd,bnd->bnk", self.V, x)
        
        # q[i,j] is a vector of size k
        # k[i,j] is a vector of size k
        # for every pair of (vector from q, vector from k), we need to get an output by taking their dot product
        # normally if you have two matrices A and B of size NxM and MxK,
        # when you multiply them, you can think of the output matrix's value in the (i,j)th position as the ith row in A dot jth column in B
        # (thus it is every pair of row from first and column from second)
        # in einsum, torch.einsum("ij,jk->ik", A, B)
        # If instead A and B are of size NxM and NxM and you want to do every pair of rows, you can just do
        # torch.einsum("ij,kj->ik") # transpose second matrices indices so it takes rows instead of columns
        # we have an additional batch index at the front, so include that
        dotQueryKey = torch.einsum("bij, bkj->bik", q, k)/math.sqrt(k)
        # now the [b,i,j]th element of dotQueryKey is the dot product of q[b,i] and k[b,j].
        # since q and k were of dimension [b,n,k], this becomes [b,n,n]
        
        # If we look at dotQueryKey[b,i], that is a vector of size n.
        # the jth term is the ith query dot the jth key vector.
        # thus, each term is "how much" the ith query aligns with the jth key.
        # so we will take a softmax so we actually get probabilities (TODO: try sigmoid. Seems odd that it's only one class of things)
        # This doesn't change the dim, so it's still [b,n,n]
        queryPrs = self.softmax(dotQueryKey)
        
        # Now we take sum over i of (queryPrs[b,i,j])*(vh[b,j])
        #                            scalar            vector
        
        
        # Now we need to take queryP
        
        
        
        


In [109]:
def testTransformer2():
    b, n, d, k = 2, 3, 4, 5
    from minGPT.mingpt.utils import set_seed
    set_seed(27)
    queryPrs = torch.normal(0, 1, [b,n,n])
    vh = torch.normal(0, 1, [b,n,k])
    print("qprs", queryPrs)
    print("vh", vh)
    # Now we take sum over i of (queryPrs[b,i,j])*(vh[b,j])
    #                            scalar            vector
    res = torch.einsum("bij,bjw->bjw", queryPrs, vh)
    val = 0.0
    bv = 0
    jv = 0
    for i in range(n):
        val += queryPrs[bv,i,jv]*vh[bv,jv]
    print("ayy", val)
    print("res", res)
    
    
    a = torch.normal(0, 1, [2, 3])
    b = torch.normal(0, 1, [2, 3, 4])
    
testTransformer2()

qprs tensor([[[ 1.7650,  0.0664,  0.0753],
         [-1.3867,  0.0756, -0.4957],
         [-0.8165, -0.0069, -1.7975]],

        [[-0.3770, -1.6994,  1.0734],
         [-0.1926,  0.5088,  1.2001],
         [ 1.0033,  0.3197, -0.6699]]])
vh tensor([[[ 0.4601,  0.3644, -1.4775,  0.4753, -0.3383],
         [-0.5367,  1.5008, -0.7286,  0.4594,  0.4356],
         [-0.2073, -1.0252, -1.1372,  1.0307,  0.4656]],

        [[-0.8964,  0.5814, -0.9950, -0.9881, -0.1613],
         [ 0.1007,  0.9505,  0.9992, -0.8928,  1.6873],
         [ 0.4901,  0.2179,  0.0329,  0.0506, -0.0541]]])


RuntimeError: dimension mismatch for operand 1: equation 2 tensor 3

In [89]:
def testTransformer():
    b, n, d, k = 2, 3, 4, 5
    from minGPT.mingpt.utils import set_seed
    set_seed(27)
    x = torch.normal(0, 1, [b,n,d])
    print("x:", x)
    x
    Q = torch.normal(0, 1, [k,d])
    K = torch.normal(0, 1, [k,d])
    V = torch.normal(0, 1, [k,d])
    # In other words
    print("Q:", Q)
    res = torch.einsum('kd,bnd->bnk', Q, x)
    # Check that it's the same both ways
    approx_equals(Q@(x[0,0]), res[0,0])
    approx_equals(Q@(x[1,0]), res[1,0])
    approx_equals(Q@(x[0,1]), res[0,1])
    approx_equals(Q@(x[1,1]), res[1,1])
    
    q = torch.einsum("kd,bnd->bnk", Q, x)
    k = torch.einsum("kd,bnd->bnk", Q, x)
    v = torch.einsum("kd,bnd->bnk", Q, x)
    
    print("q:", q)
    print("k:", k)
    dotQueryKey = torch.einsum("bij, bkj->bik", q, k)
    print("dq", dotQueryKey.shape, dotQueryKey)
    
    # dotQueryKey[b,i,j] is q[b,i] dot k[b,j]
    for ba in range(b):
        approx_equals(q[ba,0]@k[ba,0], dotQueryKey[ba,0,0])
        approx_equals(q[ba,0]@k[ba,1], dotQueryKey[ba,0,1])
        approx_equals(q[ba,1]@k[ba,0], dotQueryKey[ba,1,0])
        approx_equals(q[ba,1]@k[ba,1], dotQueryKey[ba,1,1])
    
testTransformer()

x: tensor([[[ 1.7650,  0.0664, -0.0706, -0.1672],
         [-0.4266,  1.5005, -0.2636, -1.0210],
         [-1.7975, -0.3770,  0.6140,  0.5948]],

        [[-0.8629, -0.9511, -0.9195, -0.7592],
         [ 0.3197, -0.6699,  1.5661,  0.8074],
         [-1.6036,  0.1696, -0.0308,  0.0434]]])
Q: tensor([[ 1.5008, -0.7286, -0.5098,  0.4431],
        [-0.9389,  1.5772,  1.6559, -0.4713],
        [ 0.4656, -0.8964,  0.5814, -0.9950],
        [ 0.6763,  0.1337,  0.0659,  0.5385],
        [ 0.9992, -0.8928,  1.6873,  0.4901]])
q: tensor([[[ 2.5625, -1.5905,  0.8876,  1.1078,  1.5033],
         [-2.0517,  2.8120, -0.6810, -0.6551, -2.7111],
         [-2.4725,  1.8294, -0.7338, -0.9053, -0.1320]],

        [[-0.4698, -1.8548,  0.6716, -1.1802, -1.9367],
         [ 0.5273,  0.8561,  0.8565,  0.6647,  3.9557],
         [-2.4954,  1.7016, -0.9597, -1.0404, -1.7845]]])
k: tensor([[[ 2.5625, -1.5905,  0.8876,  1.1078,  1.5033],
         [-2.0517,  2.8120, -0.6810, -0.6551, -2.7111],
         [-2.4725, 

In [85]:
from minGPT.mingpt.utils import set_seed
set_seed(27)
a = torch.normal(0, 1, [4, 5])
b = torch.normal(0, 1, [5, 4])
print("a", a)
print("b", b)
print(a[0]@b[:,0])
torch.einsum("ij,jk->ik", a, b)
set_seed(27)
print("later stuff")
a = torch.normal(0, 1, [4, 5])
b = torch.normal(0, 1, [4, 5])
print(a[0]@b[0])
print(a[0]@b[1])
print(a[1]@b[0])
print(a[1]@b[1])
# we want to go from a [nxm],[nxm] to a [n,n]
torch.einsum("ij,kj->ik", a, b)

a tensor([[ 1.7650,  0.0664, -0.0706, -0.1672,  0.0756],
        [-0.4957, -0.8165, -0.0069, -1.7975, -0.3770],
        [ 0.6140,  0.5948, -0.1926,  0.5088,  1.2001],
        [ 1.0033,  0.3197, -0.6699,  1.5661,  0.8074]])
b tensor([[-1.4775,  0.4753, -0.3383, -0.5367],
        [-0.8237, -0.4236,  0.3272, -1.9896],
        [-0.9389,  1.5772,  1.6559, -0.4713],
        [ 0.2374, -0.1400, -1.0862,  0.5188],
        [ 0.6763,  0.1337,  0.0659,  0.5385]])
tensor(-2.5848)
later stuff
tensor(-2.5248)
tensor(-0.3093)
tensor(1.6219)
tensor(1.0495)


tensor([[-2.5248, -0.3093,  2.8159,  0.9808],
        [ 1.6219,  1.0495,  0.2236, -1.1318],
        [-1.8210,  1.7328, -0.6841,  1.3748],
        [-2.6093,  0.8155,  0.2554,  1.1852]])

In [102]:
from minGPT.mingpt.utils import set_seed
set_seed(27)
a = torch.normal(0, 1, [2, 3, 4])
print(a)

softmax(a)[0,1].sum()

tensor([[[ 1.7650,  0.0664, -0.0706, -0.1672],
         [-0.4266,  1.5005, -0.2636, -1.0210],
         [-1.7975, -0.3770,  0.6140,  0.5948]],

        [[-0.8629, -0.9511, -0.9195, -0.7592],
         [ 0.3197, -0.6699,  1.5661,  0.8074],
         [-1.6036,  0.1696, -0.0308,  0.0434]]])


tensor(1.0000)

In [None]:
bn = LayerNorm(10)
inputs = torch.normal(0, 1, [30, 10])
ayy = bn(inputs)
print(ayy.mean(), ayy.std())