In [1]:
import torch
from torch import nn

## Positional Encoding

In [2]:
import math
import numpy as np

class PositionalEncoder(nn.Module):
    """Fixed sinusoidal positional encoding for tokens laid out on a 3-D grid.

    The channel dimension is split into three equal thirds, one per grid axis
    (x, y, z). Inside each third the even channels hold ``sin`` and the odd
    channels hold ``cos`` of ``pos_scale * pos / 10_000 ** (6 * i / d_model)``
    for ``i = 0 .. d_model // 6 - 1``. The table is flattened in C order to
    shape ``(x * y * z, d_model)`` and registered as the buffer ``pe``.

    Args:
        d_model: Channel width. Must be a multiple of 6 (three axes, each
            made of sin/cos pairs).
        grid_shape: Number of positions along the (x, y, z) axes.
        pos_scale: Factor applied to the integer position before it is
            divided by the per-channel frequency term.

    Raises:
        ValueError: If ``d_model`` is not a multiple of 6 or ``grid_shape``
            does not have exactly three entries.
    """

    def __init__(self, d_model=12, grid_shape=(9, 11, 11), pos_scale=0.33):
        super().__init__()
        self.d_model = d_model
        if self.d_model % 6 != 0:
            raise ValueError("d_model must be divisible by 6!")

        grid_shape = tuple(int(size) for size in grid_shape)
        if len(grid_shape) != 3:
            raise ValueError("grid_shape must have exactly 3 entries (x, y, z)")
        self.grid_shape = grid_shape

        pairs_per_axis = d_model // 6
        channels_per_axis = d_model // 3
        # The frequency term is shared by all three axes, so compute it once.
        denominators = 10_000 ** (6 * np.arange(pairs_per_axis) / d_model)

        pe = np.zeros((*grid_shape, d_model))
        for axis, size in enumerate(grid_shape):
            # angles[p, i] is the argument for position p and frequency i.
            angles = (pos_scale * np.arange(size))[:, None] / denominators
            # Put the positions on `axis` and broadcast over the other two axes.
            broadcast_shape = [1, 1, 1, pairs_per_axis]
            broadcast_shape[axis] = size
            first = axis * channels_per_axis
            last = first + channels_per_axis
            pe[..., first:last:2] = np.sin(angles).reshape(broadcast_shape)
            pe[..., first + 1:last:2] = np.cos(angles).reshape(broadcast_shape)

        num_grid_tokens = grid_shape[0] * grid_shape[1] * grid_shape[2]
        pe = pe.reshape(num_grid_tokens, d_model)
        pe = torch.tensor(pe).float()
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Scale ``x`` and add the positional table to its leading tokens.

        Args:
            x: Tensor whose dimension 1 indexes tokens and whose last
                dimension is ``d_model`` (e.g. ``(batch, num_tokens, d_model)``).
                The first ``pe.shape[0]`` tokens are treated as grid cells;
                any tokens after them are auxiliary and receive no encoding.
                Fewer tokens than grid cells yields a negative padding
                length, which torch rejects.

        Returns:
            ``x * sqrt(d_model // 3)`` with ``pe`` added to the grid tokens.
        """
        # TODO: is this scaling actually needed?
        x = x * math.sqrt(self.d_model // 3)
        num_tokens = x.shape[1]
        aux_tokens = num_tokens - self.pe.shape[0]

        # Allocate the padding from the buffer so it follows the module's
        # device and keeps the buffer's dtype regardless of global defaults.
        padding = self.pe.new_zeros((aux_tokens, self.d_model))
        x = x + torch.cat([self.pe, padding], dim=0)
        return x

## Encoder

In [3]:
class TransformerEncoderLayer(nn.Module):
    """Pre-normalization Transformer encoder layer with positional queries/keys.

    The layer applies self-attention followed by a two-layer feed-forward
    network, each wrapped in a residual connection with ``LayerNorm`` applied
    *before* the sub-layer. The positional encoding is added to the queries
    and keys only; the values are the normalized input without it.

    Args:
        d_model: Channel width of the tokens. It is passed to
            ``PositionalEncoder``, which requires a multiple of 6.
        num_heads: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        activation: Zero-argument callable (e.g. an ``nn.Module`` class) that
            is called once to build the feed-forward activation.
    """

    def __init__(self, d_model, num_heads, dim_feedforward=64, activation=nn.ReLU):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=0.0, batch_first=True)
        # Implementation of feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.pos_encoder = PositionalEncoder(d_model=d_model)
        self.activation = activation()

    def forward(self, src):
        """Run the layer on ``src``.

        Args:
            src: Tensor of shape ``(batch, num_tokens, d_model)``; the
                attention module is built with ``batch_first=True``.

        Returns:
            Tensor with the same shape as ``src``.
        """
        # Forward with prenormalization
        src2 = self.norm1(src)
        q = k = self.pos_encoder(src2)
        # Only the attention output is used, so skip building the
        # head-averaged attention-weight matrix.
        src2 = self.self_attn(q, k, value=src2, need_weights=False)[0]
        src = src + src2
        src2 = self.norm2(src)
        src2 = self.linear2(self.activation(self.linear1(src2)))
        src = src + src2
        return src

In [4]:
# Smoke test: a single encoder layer on one random sequence made of the
# 9 * 11 * 11 grid tokens followed by 50 extra tokens, 12 channels each.
src = torch.rand((1, 9 * 11 * 11 + 50, 12))
encoder = TransformerEncoderLayer(
    d_model=12,
    num_heads=4,
)
out = encoder(src)
out.shape

torch.Size([1, 1139, 12])

In [5]:
import copy

def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones

In [6]:
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep copies of ``encoder_layer`` applied in sequence.

    Args:
        encoder_layer: Module to replicate; each copy maps a tensor to a tensor.
        num_layers: How many copies of ``encoder_layer`` to stack.
        norm: Optional callable applied to the output of the last layer.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        # Registration order (layers, then norm) determines parameter order.
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src):
        """Feed ``src`` through every layer, then through ``norm`` if one was given."""
        hidden = src
        for block in self.layers:
            hidden = block(hidden)
        return hidden if self.norm is None else self.norm(hidden)

In [16]:
# Smoke test of the stacked encoder: one wide-FFN layer (d_model=6, a single
# head) on the 9 * 11 * 11 grid tokens plus 50 extra tokens.
src = torch.rand((1, 9 * 11 * 11 + 50, 6))
encoder_layer = TransformerEncoderLayer(
    d_model=6,
    num_heads=1,
    dim_feedforward=1024,
)
encoder = TransformerEncoder(encoder_layer=encoder_layer, num_layers=1)
out = encoder(src)
out.shape

torch.Size([1, 1139, 6])

## Decoder

In [17]:
# Total number of scalar parameters in `encoder` (registered buffers are not included).
sum(map(torch.numel, encoder.parameters()))

13510