<a href="https://colab.research.google.com/github/TD008/OTransfomer/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import os
import numpy as np

In [8]:
class AttHead(torch.nn.Module):
    """Single scaled dot-product self-attention head with an output projection.

    The input is projected to keys/queries (width ``dk``) and values
    (width ``dv``); the attention-weighted values are then projected back
    to ``dmodel`` so the output has the same trailing width as the input.

    Args:
        dmodel: Size of the last dimension of the input and of the output.
        dk: Width of the key and query projections.
        dv: Width of the value projection.
        decoder: When True, apply a causal mask so that a position can only
            attend to itself and to earlier positions.
    """

    def __init__(self, dmodel, dk, dv, decoder=True):
        super().__init__()
        self.dmodel = dmodel
        self.dk = dk
        self.dv = dv
        self.decoder = decoder

        self.key = torch.nn.Linear(dmodel, dk)
        self.query = torch.nn.Linear(dmodel, dk)
        self.value = torch.nn.Linear(dmodel, dv)
        self.out = torch.nn.Linear(dv, dmodel)
        # Softmax over the last axis of the score matrix, i.e. over the key
        # positions, so every query's weights sum to 1.
        self.smax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, dmodel)``; both batched
                ``(batch, seq_len, dmodel)`` and unbatched
                ``(seq_len, dmodel)`` inputs are accepted.

        Returns:
            Tensor of shape ``(..., seq_len, dmodel)``.
        """
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Pairwise query/key dot products, scaled by sqrt(dk).
        affinities = q @ k.transpose(-2, -1)
        affinities = affinities / (self.dk ** 0.5)

        if self.decoder:
            # The sequence axis is second-to-last regardless of batching.
            seq_len = x.size(-2)
            # Build the mask on the input's device so the head also works
            # after being moved off the CPU. True marks "future" keys.
            future = torch.ones(seq_len, seq_len, device=x.device)
            future = future.triu(diagonal=1).bool()
            # Out-of-place fill: -inf scores become zero weight after softmax.
            affinities = affinities.masked_fill(future, float("-inf"))

        # Turn the scores into weights that sum to 1 for every query.
        affinities = self.smax(affinities)

        output = affinities @ v
        return self.out(output)


In [16]:
class AttBlock(torch.nn.Module):
    """Multi-head attention block built from independent ``AttHead`` modules.

    Every head is created as ``AttHead(dmodel, dmodel, dmodel)`` and already
    projects its result back to ``dmodel``, so the block combines the heads
    by summing their outputs.

    Args:
        num_heads: Number of attention heads (must be at least 1).
        dmodel: Model width handed to every head.

    Raises:
        ValueError: If ``num_heads`` is smaller than 1.
    """

    def __init__(self, num_heads, dmodel):
        super().__init__()
        if num_heads < 1:
            raise ValueError("num_heads must be at least 1")
        self.num_heads = num_heads
        self.dmodel = dmodel
        # ModuleList (not a plain list) so the heads' parameters are
        # registered: they then show up in .parameters(), follow .to(), and
        # are saved in state_dict().
        self.heads = torch.nn.ModuleList(
            [AttHead(dmodel, dmodel, dmodel) for _ in range(num_heads)]
        )

    def forward(self, idx):
        """Run every head on ``idx`` and sum the per-head outputs.

        Args:
            idx: Input tensor, passed unchanged to each head.

        Returns:
            Element-wise sum of all head outputs.
        """
        # NOTE(review): summing is used because each head carries its own
        # output projection to dmodel — confirm this is the intended way to
        # combine heads.
        head_outputs = [head(idx) for head in self.heads]
        return torch.stack(head_outputs).sum(dim=0)

    # Kept so the previous method name still resolves; nn.Module itself only
    # dispatches to ``forward``.
    __forward__ = forward

In [13]:
# Smoke test: push a small random batch through a single attention head.
batch_size, block_size = 2, 4
dmodel = dk = dv = 6

# The input is sampled before the head is constructed; both draw from the
# global RNG, so this order determines the printed numbers for a given seed.
data = torch.randn((batch_size, block_size, dmodel))
print("data:\n", data)

model = AttHead(dmodel, dk, dv)

# The result has one dmodel-wide row per input position.
forward = model(data)
print("forward:\n", forward)

data:
 tensor([[[-0.0375, -0.7401, -0.5613,  0.6808,  0.6946, -1.4405],
         [-0.3934,  0.1113,  0.0411, -0.0616,  1.6976, -0.9105],
         [-0.8492,  1.0304, -0.3294, -0.3590, -1.4121, -0.0631],
         [ 0.1757, -0.4288, -1.0191, -0.7912, -0.3052, -0.7022]],

        [[ 0.3466, -0.9306,  0.3998, -0.7576, -0.6818, -0.2977],
         [-0.2433, -0.5938,  0.5151, -0.1521,  2.3302, -0.7203],
         [ 0.6507,  0.2706, -0.2490, -0.6644, -0.6329,  0.2481],
         [ 0.9437, -0.7151,  0.3176, -0.6237,  0.3903,  0.2886]]])
forward:
 tensor([[[-1.6528e-01,  5.2006e-01, -5.4918e-01,  7.3699e-01, -2.3925e-01,
           3.2250e-01],
         [-1.4137e-01,  3.8645e-01, -5.6850e-01,  5.6132e-01, -2.4060e-01,
           2.0878e-01],
         [-7.3153e-02,  3.1676e-01, -5.0752e-01,  4.1274e-01, -6.0772e-02,
           1.2568e-01],
         [-1.0167e-01,  4.2677e-01, -5.9716e-01,  4.6476e-01, -1.0579e-01,
           1.5061e-01]],

        [[ 7.7240e-02,  5.4421e-01, -4.9427e-01,  4.0210e-01,