In [3]:
import torch
import quickstart_utils as utils

class BasicTransformerLayer(torch.nn.Module):
    """Pre-LayerNorm Transformer layer built from plain PyTorch modules.

    The layer applies two residual sub-blocks in sequence:

    1. ``LayerNorm -> fused QKV projection -> attention -> output projection
       -> dropout``, added back to the layer input.
    2. ``LayerNorm -> MLP``, added back to the output of the first sub-block.

    Args:
        hidden_size: Size of the feature (last) dimension of the input.
        ffn_hidden_size: Hidden size forwarded to ``utils.BasicMLP``.
        num_attention_heads: Number of attention heads. Must be positive and
            divide ``hidden_size`` evenly.
        layernorm_eps: ``eps`` passed to both ``torch.nn.LayerNorm`` modules.
        attention_dropout: Dropout value forwarded to
            ``utils.DotProductAttention``.
        hidden_dropout: Dropout probability applied to the attention output
            projection before the first residual add.

    Raises:
        ValueError: If ``num_attention_heads`` is not positive or does not
            divide ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        layernorm_eps: float = 1e-5,
        attention_dropout: float = 0.1,
        hidden_dropout: float = 0.1,
    ):
        super().__init__()

        # Fail early with a clear message: otherwise the floor division below
        # silently truncates and the mismatch only shows up later as a shape
        # error inside ``forward``.
        if num_attention_heads <= 0 or hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_attention_heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        # Per-head feature size.
        self.kv_channels = hidden_size // num_attention_heads

        # --- Attention sub-block ---
        self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
        # A single linear layer produces Q, K and V at once (3 * hidden_size).
        self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.attention = utils.DotProductAttention(
            num_attention_heads=num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=attention_dropout,
        )
        self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)
        self.dropout = torch.nn.Dropout(hidden_dropout)

        # --- MLP sub-block ---
        self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
        self.mlp = utils.BasicMLP(
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Run the layer on ``x``.

        Args:
            x: 3-D input whose last dimension is ``hidden_size``.
                NOTE(review): the meaning of the two leading axes
                (e.g. sequence vs. batch) is defined by
                ``utils.DotProductAttention`` -- confirm against that helper.
            attention_mask: Passed through unchanged to ``self.attention``.

        Returns:
            The result of the second residual add (MLP output plus its input).
        """
        residual = x
        x = self.ln1(x)

        # Fused QKV projection: reshape so each head owns one contiguous slab
        # of 3 * kv_channels features, then cut that slab into Q, K and V.
        qkv = self.qkv_projection(x)
        qkv = qkv.view(
            qkv.size(0),
            qkv.size(1),
            self.num_attention_heads,
            3 * self.kv_channels,
        )
        q, k, v = torch.split(qkv, self.kv_channels, dim=3)

        # The attention output feeds a Linear(hidden_size, hidden_size), so its
        # last dimension is expected to be hidden_size again.
        x = self.attention(q, k, v, attention_mask)
        x = self.projection(x)
        x = self.dropout(x)
        x = residual + x

        # Second residual sub-block: pre-norm MLP.
        residual = x
        x = self.ln2(x)
        x = self.mlp(x)

        return x + residual

ModuleNotFoundError: No module named 'transformer_engine'