In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        d_k: value whose square root is used to scale the raw scores.
    """
    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k

    def forward(self, Q, K, V):
        """
        Compute attention for a whole batch at once.

        Args:
            Q: queries of dim [batch, time_q, input_dim]
            K: keys of dim [batch, time_k, input_dim]
            V: values of dim [batch, time_k, value_dim]

        Returns:
            Tensor of dim [batch, time_q, value_dim]. The result takes its
            last dimension from V, so V may have a different feature width
            than Q and K.
        """
        # Batched Q K^T: only the last two axes of K are swapped, so the
        # batch axis is handled by matmul broadcasting instead of a loop.
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)

        # Normalise over the key axis so each query's weights sum to 1.
        weights = F.softmax(scores, dim=-1)

        return weights @ V

    
class GELU(nn.Module):
    """
    Tanh-based approximation of the Gaussian Error Linear Unit activation.

    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    def forward(self, x):
        """Apply the activation element-wise to ``x`` and return the result."""
        scale = math.sqrt(2.0 / math.pi)
        inner = x + 0.044715 * torch.pow(x, 3.0)
        gate = 1.0 + torch.tanh(scale * inner)
        return 0.5 * x * gate

In [2]:
class Block(nn.Module):
    """
    One encoder block: self-attention with a residual connection, followed
    by a linear layer with a residual connection. A GELU activation is
    applied after each residual sum.

    Args:
        dim_size: feature width of the input; also used as the attention
            scaling dimension and as the in/out width of the linear layer.
    """
    def __init__(self, dim_size):
        super().__init__()
        self.attention = Attention(d_k=dim_size)
        self.linear = nn.Linear(dim_size, dim_size)
        # NOTE: despite its name this attribute holds a GELU activation, not
        # a normalisation layer. The name is kept so existing code that
        # accesses `block.norm` keeps working.
        self.norm = GELU()
    
    def forward(self, X):
        """
        Input of dim: [batch, time_dim, dim_size]; output has the same dim.
        """
        # Self-attention: queries, keys and values are all X.
        v = self.attention(X, X, X)
        v = self.norm(X + v)

        # nn.Linear acts on the last axis, so the whole batch goes through in
        # one call. The residual add is out-of-place on purpose: an in-place
        # update of `v` would overwrite the tensor that `self.linear` saved
        # for its backward pass and make `backward()` raise.
        v = v + self.linear(v)

        v = self.norm(v)

        return v

In [3]:
class Transformer(nn.Module):
    """
    Stack of `Block`s that maps a sequence to a single scalar per sample.

    The input is projected to `feature_dim`, passed through `num_blocks`
    blocks, averaged over the time axis and projected to one output value.

    Args:
        num_blocks: number of `Block`s applied in sequence.
        in_dim: feature width of the raw input.
        feature_dim: feature width used inside the blocks.
    """
    def __init__(self, num_blocks, in_dim, feature_dim):
        super().__init__()
        # nn.ModuleList (not a plain list) registers the blocks as
        # sub-modules, so their parameters are visible to `parameters()`,
        # `state_dict()`, `.to(device)` and optimisers.
        self.blocks = nn.ModuleList(
            [Block(feature_dim) for _ in range(num_blocks)]
        )

        self.feature_dim = feature_dim
        self.input_linear = nn.Linear(in_dim, feature_dim)
        self.output_linear = nn.Linear(feature_dim, 1)

    def forward(self, x):
        """
        Input of dim: [batch, time_dim, input_dim]
        Output of dim: [batch, 1]
        """
        # nn.Linear acts on the last axis, so the whole batch is projected in
        # one call; the result follows the device and dtype of `x` and the
        # layer rather than the global defaults.
        v = self.input_linear(x)

        for block in self.blocks:
            v = block(v)
        
        # Average over the time axis to get one feature vector per sample.
        v = torch.mean(v, dim=1)
        o = self.output_linear(v)

        return o

In [4]:
# Smoke test: build a small model and check the output shape on random input.
t = Transformer(num_blocks=2, in_dim=10, feature_dim=64)

# Random batch of dim [batch=2, time_dim=5, input_dim=10].
X = torch.rand(2, 5, 10)
# Call the module itself rather than `.forward()` so that any registered
# hooks run as part of the forward pass.
o = t(X)
o.shape

  self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))


torch.Size([2, 1])