In [1]:
import datasets
import transformers
import torch
import torch.nn.functional as F
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import torch.nn as nn

from collections import defaultdict
from collections.abc import (
    Callable,
    Iterable
)
import numpy as np

from util_110724 import (
    get_dataloader_random_reshuffle,
    to_ensembled
)

In [2]:
# Global configuration shared by the data pipeline and the model layers.
config = {
    "seed": 0,
    # Fall back to the CPU so the notebook still runs top-to-bottom on a
    # machine without a CUDA device.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # dtype of the input feature tensors.
    "features_dtype": torch.float32,
    # dtype of learnable parameters; `LayerNorm` requires this key.
    "float_dtype": torch.float32,
    # Number of independently-parameterized ensemble members (as a shape).
    "ensemble_shape": (16,)
}

In [3]:
class LayerNorm(torch.nn.Module):
    """
    Ensemble-ready layer normalization layer

    Arguments
    ---------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"device"` : `str`
            The device to store parameters on.
        `"ensemble_shape"` : `tuple[int]`
            The shape of the ensemble of affine transformations
            the model represents.
        `"float_dtype"` : `torch.dtype`
            The floating point datatype to use for the parameters.
    normalized_shape : `int | tuple[int]`
        The part of the shape of the incoming tensors
        that are to be normalized together with batch dimensions.
        We view the following as batch dimensions:
        ```
        range(
            len(ensemble_shape),
            -len(normalized_shape) - normalized_offset
        )
        ```
        If an integer, we view it as a single-element tuple.
    bias : `bool`, optional
        If `elementwise_affine`, whether to include offset
        in the learned transformation. Default: `True`.
    elementwise_affine : `bool`, optional
        Whether to include learnable scale. If this and `bias`,
        then we also include learnable offset. These will be tensors
        of shape `ensemble_shape + normalized_shape` that are
        broadcast to the incoming feature tensors appropriately.
        Default: `True`.
    epsilon : `float`, optional
        Small positive value, to be included in the divisor when we
        divide by the variance, for numerical stability. Default: `1e-5`.
    normalized_offset : `int`, optional
        We get `normalized_shape` out of an incoming feature tensor
        at dimensions
        ```
        range(
            -len(normalized_shape) - normalized_offset,
            -normalized_offset
        )
        ```
        Default: `0`.

    Calling
    -------
    Instance calls require one positional argument:
    batch : `dict`
        The input data dictionary. Required key:
        `"features"` : `torch.Tensor`
            Tensor of features.
    """
    def __init__(
        self,
        config: dict,
        normalized_shape: int | tuple[int],
        bias=True,
        elementwise_affine=True,
        epsilon=1e-5,
        normalized_offset=0
    ):
        super().__init__()

        if hasattr(normalized_shape, "__int__"):
            self.normalized_shape = (normalized_shape,)
        else:
            self.normalized_shape = normalized_shape

        self.ensemble_shape = config["ensemble_shape"]
        self.epsilon = epsilon
        self.normalized_offset = normalized_offset

        if elementwise_affine:
            self.scale = torch.nn.Parameter(torch.ones(
                self.ensemble_shape + self.normalized_shape + (1,) * normalized_offset,
                device=config["device"],
                dtype=config["float_dtype"]
            ))
            if bias:
                self.bias = torch.nn.Parameter(torch.zeros_like(self.scale))
            else:
                self.bias = None

        else:
            self.bias, self.scale = None, None


    def forward(self, batch: dict) -> dict:
        features: torch.Tensor = batch["features"]

        ensemble_dim = len(self.ensemble_shape)
        features = to_ensembled(self.ensemble_shape, features)

        normalized_dim = len(self.normalized_shape)

        batch_dim = len(features.shape) - ensemble_dim - normalized_dim - self.normalized_offset
        normalized_range = tuple(range(
            ensemble_dim,
            ensemble_dim + batch_dim
        )) + tuple(range(
            -normalized_dim - self.normalized_offset,
            -self.normalized_offset
        ))

        features = features - features.mean(dim=normalized_range, keepdim=True)
        features = features / features.std(dim=normalized_range, keepdim=True)

        if self.scale is not None:
            scale = self.scale.unflatten(
                ensemble_dim,
                (1,) * batch_dim + self.normalized_shape[:1]
            )

            features = features * scale

            if self.bias is not None:
                bias = self.bias.unflatten(
                    ensemble_dim,
                    (1,) * batch_dim + self.normalized_shape[:1]
                )
                features = features + bias

        return batch | {"features": features}