In [54]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einx
import random
from torch.utils.data import Dataset, DataLoader

def genList(listSize):
        return [random.random() for x in range(listSize)]

def genAnswerKey(inputList):
    answerKeyList = []
    for val in inputList:
        numLessThanVal = 0
        for otherVal in inputList:
            if otherVal < val:
                numLessThanVal += 1
        answerKeyList.append(numLessThanVal)
    return answerKeyList


class ListDataset(Dataset):
    def __init__(self, listSize, datasetSize):
        self.datasetSize = datasetSize
        self.listSize = listSize
        self.allLists = [genList(self.listSize) for idx in range(self.datasetSize)]
        self.allAnswerKeys = [genAnswerKey(inputList) for inputList in self.allLists]

    def __len__(self):
        return self.datasetSize

    def __getitem__(self, idx):
        return torch.tensor(self.allLists[idx]).float(), torch.tensor(self.allAnswerKeys[idx])

listDataset = ListDataset(3, 1000)
trainLoader = DataLoader(listDataset, batch_size=1, shuffle=True)
sampleInputTensor = next(iter(trainLoader))[0]

In [61]:

# Transform a (B, L) list to a (B, L, d_emb)
class Stem(nn.Module):
    def __init__(self, embeddingDim):
        super().__init__()
        self.linearToEmbDim = nn.Linear(1, embeddingDim)
        self.gelu = nn.GELU()
        

    # X expected to be a (B, L) list
    def forward(self, x):
        x = einx.rearrange("... -> ... 1", sampleInputTensor)
        x = self.linearToEmbDim(x)
        x = self.gelu(x)
        return x


stem = Stem(2)
stem(sampleInputTensor).shape



torch.Size([1, 3, 2])

In [127]:
keyDim = 2
valueDim = 3

sampleQuery = torch.tensor([
    [
        [0., 1],
        [1, 0],
        [1, 1]
    ]
])

sampleKey = torch.tensor([
    [
        [1., 0],
        [0, 1],
        [1, 1]
    ]
])


# Perform attention from all queries to all keys
dotProd = einx.dot("... q [d], ... k [d] -> ... q k", sampleQuery, sampleKey)
print(dotProd)

softMaxPerQuery = einx.softmax("b q [k]", dotProd / keyDim)
print(softMaxPerQuery)

tensor([[[0., 1., 1.],
         [1., 0., 1.],
         [1., 1., 2.]]])
tensor([[[0.2327, 0.3837, 0.3837],
         [0.3837, 0.2327, 0.3837],
         [0.2741, 0.2741, 0.4519]]])


In [128]:
q = torch.tensor([
    [
        [[0., 1]],
        [[1, 0]],
        [[1, 1]]
    ]
])

k = torch.tensor([
    [
        [[1., 0]],
        [[0, 1]],
        [[1, 1]]
    ]
])
a = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=8)

SolveExpansionException: Failed to solve for the number of axes in the expressions.
Equations:
    b q (h c) = 1 1 3 2
    b k (h c) = 1 1 3 2
    h = 1
Reason: Expansion '(3,)' of expression 'b q (h c)' does not match expansion '(4,)' of expression '1 1 3 2'