In [None]:
# | default_exp layers/codebook

# Imports

In [None]:
# | export


import torch
import torch.distributed as dist
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

from vision_architectures.docstrings import populate_docstring
from vision_architectures.utils.custom_base_model import CustomBaseModel, Field, computed_field

# Config

In [None]:
# | export


class CodebookConfig(CustomBaseModel):
    """Settings for a vector-quantization codebook: its size, dead-vector revival and EMA updates."""

    # Codebook geometry
    num_vectors: int = Field(..., description="Number of vectors in the codebook")
    dim: int = Field(..., description="Dimension of each vector in the codebook")

    # Dead-vector revival threshold, counted in training steps
    revive_dead_vectors_after_n_steps: int = Field(
        100,
        description="Number of steps after which a vector is declared dead and is revived (0 means never revive)",
    )

    # EMA decay; None (or a non-positive value) switches EMA updates off
    ema_decay: float | None = Field(0.99, description="EMA decay rate for updating codebook vectors")

    @computed_field(description="Whether to use EMA for updating codebook vectors")
    @property
    def use_ema(self) -> bool:
        """Return True only when a strictly positive EMA decay has been configured."""
        decay = self.ema_decay
        if decay is None:
            return False
        return decay > 0.0

# Codebook

In [None]:
# | export


class Codebook(nn.Module, PyTorchModelHubMixin):
    """Codebook that can be used for vector quantization. This implementation maintains the vectors in distributed
    settings. It also supports exponential moving average (EMA) updates of the codebook vectors as well as reviving
    dead vectors.

    Randomness used while reviving vectors is drawn from a dedicated CPU generator that is seeded identically on all
    ranks, so every rank revives dead vectors in exactly the same way. All generator-driven sampling is therefore
    performed on the CPU and the results are moved to the device of the codebook afterwards.

    When EMA is enabled the codebook vectors are updated exclusively by the EMA statistics; they are excluded from
    gradient-based training and the returned codebook loss is zero.
    """

    # Cluster size given to a freshly revived vector when EMA is used. ``ema_vectors`` stores the (decayed) *sum* of
    # assigned inputs, so it must always equal ``vector * cluster_size``; a revived vector is treated as if exactly
    # one input had been assigned to it.
    _REVIVED_CLUSTER_SIZE = 1.0

    @populate_docstring
    def __init__(self, config: CodebookConfig = {}, **kwargs):
        """Initialize the Codebook.

        Args:
            config: {CONFIG_KWARGS_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = CodebookConfig.model_validate(config | kwargs)

        num_vectors = self.config.num_vectors
        dim = self.config.dim

        self.vectors = nn.Embedding(num_vectors, dim)

        # Usage counter tracks the number of times a vector has been used since it was last revived
        usage_counter = torch.zeros(num_vectors, dtype=torch.long)
        self.register_buffer("usage_counter", usage_counter, persistent=False)
        self.usage_counter: torch.Tensor  # For hinting

        # stale_counter tracks the number of batches a vector has been unused since the last time it was used
        stale_counter = torch.zeros(num_vectors, dtype=torch.long)
        self.register_buffer("stale_counter", stale_counter, persistent=False)
        self.stale_counter: torch.Tensor

        # Create a generator object so that randomness is consistent across all devices
        self.generator = torch.Generator()
        self.generator_initalized = False

        if self.config.use_ema:
            self.decay = self.config.ema_decay

            # The EMA statistics own the codebook update, so the vectors never receive gradients. Freezing them
            # keeps optimizers from touching them and prevents "unused parameter" errors under DDP.
            self.vectors.weight.requires_grad_(False)

            # EMA cluster size tracking
            cluster_size = torch.zeros(self.config.num_vectors)
            self.register_buffer("cluster_size", cluster_size, persistent=False)
            self.cluster_size: torch.Tensor

            # EMA for embedding vectors
            ema_vectors = torch.zeros_like(self.vectors.weight)
            self.register_buffer("ema_vectors", ema_vectors, persistent=False)
            self.ema_vectors: torch.Tensor

    @staticmethod
    def _distributed_is_active() -> bool:
        """Return True when torch.distributed is both available in this build and initialized."""
        return dist.is_available() and dist.is_initialized()

    @torch.no_grad()
    def calculate_perplexity(self, indices: torch.Tensor) -> torch.Tensor:
        """Calculate perplexity of the codebook usage.

        Args:
            indices: Indices of the codebook vectors chosen for each input vector.

        Returns:
            Perplexity of the codebook usage.
        """
        # Get mapping of which BS vector chose which codebook vector
        encodings = self._one_hot_indices(indices)
        # Calculate the fraction of inputs that chose each codebook vector
        avg_probs = encodings.float().mean(dim=0)
        # Calculate perplexity i.e. utilization of codebook (epsilon avoids log(0) for unused vectors)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        return perplexity

    def calculate_losses(self, x: torch.Tensor, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculate codebook and commitment losses.

        The commitment loss pulls the inputs towards their (fixed) quantized vectors and therefore only propagates
        gradients into ``x``. The codebook loss pulls the quantized vectors towards the (fixed) inputs and therefore
        only propagates gradients into the codebook. When EMA is used the codebook is updated by the EMA statistics
        instead, so the codebook loss is zero.

        Args:
            x: Input vectors. Should be of shape (BS, C) where BS is a combination of batch and spatial/temporal
                dimensions.
            z: Quantized vectors. Should be of shape (BS, C).

        Returns:
            codebook_loss: Codebook loss.
            commitment_loss: Commitment loss.
        """
        # Gradient flows to the inputs only
        commitment_loss = torch.mean((x - z.detach()) ** 2)
        if self.config.use_ema:
            codebook_loss = torch.zeros_like(commitment_loss)
        else:
            # Gradient flows to the codebook vectors only
            codebook_loss = torch.mean((z - x.detach()) ** 2)
        return codebook_loss, commitment_loss

    def quantize(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize the input vectors using the codebook and return along with losses and perplexity.

        Args:
            x: Input vectors to be quantized. Should be of shape (BS, C) where BS is a combination of batch and
                spatial/temporal dimensions.

        Returns:
            z: Quantized vectors of shape (BS, C).
            codebook_loss: Codebook loss.
            commitment_loss: Commitment loss.
            perplexity: Perplexity of the codebook usage.
        """
        # Compute distances
        distances = torch.cdist(x, self.vectors.weight)
        # (BS, num_vectors)

        # Find nearest vectors
        indices = torch.argmin(distances, dim=1)
        # (BS,)

        # Quantize
        z: torch.Tensor = self.vectors(indices)
        # (BS, C)

        # Perform EMA
        if self.training and self.config.use_ema:
            self._perform_ema(x, indices)

        # Loss calculations
        codebook_loss, commitment_loss = self.calculate_losses(x, z)

        # Allow gradients to propagate using straight-through estimator
        z = x + (z - x).detach()
        # (BS, C)

        # Calculate perplexity
        perplexity = self.calculate_perplexity(indices)

        # Update counters
        if self.training:
            self._update_counters(indices)

        return z, codebook_loss, commitment_loss, perplexity

    @torch.no_grad()
    def revive_dead_vectors(self):
        """Revive dead vectors in the codebook by replacing them with noised commonly used vectors.

        A vector is dead once its stale counter reaches ``revive_dead_vectors_after_n_steps``. Nothing happens when
        that threshold is not positive (revival disabled) or when no vector is dead.
        """

        assert self.training, "revive_dead_vectors should only be called during training"

        revive_after = self.config.revive_dead_vectors_after_n_steps
        if revive_after <= 0:
            # 0 means never revive
            return

        revive_vector_mask = self.stale_counter >= revive_after
        if not revive_vector_mask.any():
            return

        num_revive_vectors = int(revive_vector_mask.sum())
        device = self.vectors.weight.device

        # Sample commonly used vectors from the codebook. The generator lives on the CPU, so the sampling is done
        # there and only the resulting indices are moved to the device of the codebook.
        sampling_probabilities = self.usage_counter.to(device="cpu", dtype=torch.float32)
        sampling_probabilities = sampling_probabilities.clamp(min=1e-9)  # Don't allow all zero probabilities
        selected_vectors_indices = torch.multinomial(
            sampling_probabilities, num_revive_vectors, replacement=True, generator=self.generator
        ).to(device)
        selected_vectors = self.vectors(selected_vectors_indices).detach()

        # Add noise to the selected vectors. The noise is drawn on the CPU because randn_like does not accept a
        # generator and the generator itself is a CPU generator.
        noise = torch.empty(selected_vectors.shape, dtype=selected_vectors.dtype, device="cpu").normal_(
            generator=self.generator
        )
        noise = noise.to(device)
        std = self._estimate_codebook_distance() * 0.1  # https://openreview.net/pdf?id=HkGGfhC5Y7
        noised_selected_vectors = selected_vectors + noise * std

        # Replace dead vectors with noised selected vectors
        self.vectors.weight.data[revive_vector_mask] = noised_selected_vectors

        self.usage_counter[revive_vector_mask] = 0
        self.stale_counter[revive_vector_mask] = 0

        if self.config.use_ema:
            # Keep the EMA statistics consistent with the new vectors: the codebook is recomputed as
            # ema_vectors / cluster_size, so a cluster size of zero would blow the revived vectors up on the next
            # EMA step. Treat every revived vector as a cluster holding a single sample instead.
            self.ema_vectors.data[revive_vector_mask] = noised_selected_vectors * self._REVIVED_CLUSTER_SIZE
            self.cluster_size.data[revive_vector_mask] = self._REVIVED_CLUSTER_SIZE

    def forward(
        self, x: torch.Tensor, channels_first: bool | None = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize the input tensor using the codebook. Update the codebook vectors if using EMA. Revive dead vectors
        if applicable.

        Args:
            x: Input tensor to be quantized. Should be of shape (B, ..., C) if channels_first is False, (B, C, ...) if
                channels_first is True or None with ndim != 3, else (B, T, C) if channels_first is None and ndim == 3.
            channels_first: Whether the input tensor has channels as the first dimension after batch dimension.

        Returns:
            z: Quantized tensor of the same shape as input.
            codebook_loss: Codebook loss.
            commitment_loss: Commitment loss.
            perplexity: Perplexity of the codebook usage.
        """
        #  If channels_first is None: x: (B, T, C) if ndim == 3, else (B, C, ...)
        #  If channels_first is True: x: (B, C, ...)
        #  If channels_first is False: x: (B, ..., C)

        if not self.generator_initalized:
            self._initialize_generator()

        shape = x.shape
        ndim = x.ndim
        B = shape[0]

        if channels_first is None:
            # 3D inputs are assumed to be token sequences (B, T, C); everything else is assumed channels-first
            channels_first = ndim != 3

        if channels_first:
            forward_pattern = "b c ... -> (b ...) c"
            backward_pattern = "(b s) c -> b c s"  # s stands for flattened spatial dimensions
        else:
            forward_pattern = "b ... c -> (b ...) c"
            backward_pattern = "(b s) c -> b s c"

        # Flatten input
        x = rearrange(x, forward_pattern).contiguous()
        # (BS, C)

        z, codebook_loss, commitment_loss, perplexity = self.quantize(x)

        # Return back to original shape
        z = rearrange(z, backward_pattern, b=B).contiguous().reshape(shape)
        # (x.shape)

        if self.training and self.config.revive_dead_vectors_after_n_steps > 0:
            self.revive_dead_vectors()

        return z, codebook_loss, commitment_loss, perplexity

    def _initialize_generator(self):
        """Initialize the random number generator to have the same seed across all devices."""
        assert not self.generator_initalized, "Generator has already been initialized"
        seed = torch.randint(0, 2**32, (1,))
        if self._distributed_is_active():
            # Collectives need the tensor on a device supported by the backend (e.g. NCCL only handles CUDA
            # tensors), so reduce on the device the codebook lives on.
            seed = seed.to(self.vectors.weight.device)
            dist.all_reduce(seed, op=dist.ReduceOp.MIN)
        self.generator.manual_seed(int(seed.item()))
        self.generator_initalized = True

    @torch.no_grad()
    def _perform_ema(self, x: torch.Tensor, indices: torch.Tensor):
        """Perform EMA update of the codebook vectors.

        Runs without gradient tracking: the EMA statistics and the codebook vectors are plain state updates.

        Args:
            x: Input vectors. Should be of shape (BS, C) where BS is a combination of batch and spatial/temporal
                dimensions.
            indices: Indices of the codebook vectors chosen for each input vector. Should be of shape (BS,).
        """
        # Create one-hot encodings for the selected indices
        encodings = self._one_hot_indices(indices)

        # Number of inputs assigned to each codebook vector in this batch
        batch_cluster_size = encodings.sum(0)  # Sum over batch dimension

        # Sum of the inputs assigned to each codebook vector in this batch
        batch_ema_vectors = torch.matmul(encodings.t(), x)

        # Synchronize across devices if using distributed training
        if self._distributed_is_active():
            dist.all_reduce(batch_cluster_size, op=dist.ReduceOp.SUM)
            dist.all_reduce(batch_ema_vectors, op=dist.ReduceOp.SUM)

        # Update cluster size using EMA
        self.cluster_size.data = self.cluster_size * self.decay + (1 - self.decay) * batch_cluster_size

        # Update EMA for vectors
        self.ema_vectors.data = self.ema_vectors * self.decay + (1 - self.decay) * batch_ema_vectors

        # Laplace smoothing of the cluster sizes so that empty clusters do not cause a division by zero
        n = self.cluster_size.sum()
        cluster_size = (self.cluster_size + 1e-5) / (n + self.config.num_vectors * 1e-5) * n

        # Codebook vectors are the EMA sums normalized by the smoothed cluster sizes
        normalized_vectors = self.ema_vectors / cluster_size.unsqueeze(1)
        self.vectors.weight.data = normalized_vectors

    def _one_hot_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Convert indices of shape (BS,) to one-hot encodings of shape (BS, num_vectors)."""
        encodings = torch.zeros(indices.shape[0], self.config.num_vectors, device=indices.device)
        encodings.scatter_(1, indices.unsqueeze(1), 1)
        return encodings

    @torch.no_grad()
    def _update_counters(self, indices):
        """Update usage and stale counters based on the indices used in the current batch.

        Args:
            indices: Indices of the codebook vectors chosen for each input vector. Should be of shape (BS,).
        """
        # Create a tensor which counts the number of times a vector has been used
        used_vector_indices, counts = torch.unique(indices, return_counts=True)
        usage_counter_increment = torch.zeros_like(self.usage_counter)
        usage_counter_increment[used_vector_indices] = counts

        # Synchronise the usage counts across all devices
        if self._distributed_is_active():
            dist.all_reduce(usage_counter_increment, op=dist.ReduceOp.SUM)

        # Don't allow counters to exceed maximum possible values
        approximate_max_value = int(torch.iinfo(torch.long).max * 0.5)
        self.usage_counter.clamp_(max=approximate_max_value)
        self.stale_counter.clamp_(max=approximate_max_value)

        # Update usage counter
        self.usage_counter += usage_counter_increment

        # Identify vectors that were not used across all devices
        stale_counter_increment = torch.zeros_like(self.stale_counter)
        stale_counter_increment[usage_counter_increment == 0] = 1

        # Increment counts of stale vectors and reset counts of used vectors
        self.stale_counter += stale_counter_increment
        self.stale_counter[usage_counter_increment > 0] = 0

    def _estimate_codebook_distance(self, max_sample=500) -> torch.Tensor:
        """Estimate mean distance between codebook vectors.

        Args:
            max_sample: Maximum number of codebook vectors used for the estimate. Larger codebooks are randomly
                subsampled with the shared generator.

        Returns:
            Scalar tensor holding the mean pairwise distance (zero if fewer than two vectors are available).
        """
        with torch.no_grad():
            vectors_weight = self.vectors.weight
            num_vectors = vectors_weight.shape[0]
            if num_vectors > max_sample:
                # Sample a subset for efficiency. The permutation is drawn on the CPU (where the generator lives)
                # and moved to the device of the codebook before indexing.
                idx = torch.randperm(num_vectors, generator=self.generator, device="cpu")[:max_sample]
                vectors_weight = vectors_weight[idx.to(vectors_weight.device)]

            if vectors_weight.shape[0] < 2:
                # No pairs to measure; the mean over an empty selection would be NaN
                return vectors_weight.new_zeros(())

            distances = torch.cdist(vectors_weight, vectors_weight)
            mask = ~torch.eye(distances.shape[0], dtype=torch.bool, device=distances.device)  # Exclude self-distances
            codebook_distance = distances[mask].mean()

        return codebook_distance

In [None]:
# Smoke test WITHOUT EMA. `use_ema` is a computed field derived from `ema_decay`, so passing `use_ema=False` is
# silently ignored and EMA stays on; it has to be disabled through `ema_decay=None`.
test = Codebook(num_vectors=32, dim=8, revive_dead_vectors_after_n_steps=3, ema_decay=None)
assert not test.config.use_ema, "EMA should be disabled in this cell"
display(test)

# Channels-last sequence input: (B, T, C)
sample_input = torch.randn(2, 2**10, 8, requires_grad=True)
output = test(sample_input)
assert output[0].shape == sample_input.shape
display([output[0].shape, *output[1:]])

# Channels-first volumetric input: (B, C, D, H, W). Repeating the same batch makes unused vectors go stale and get
# revived once the threshold of 3 steps is reached.
sample_input = torch.randn(8, 8, 2, 2, 2, requires_grad=True)
output = test(sample_input)
display([test.usage_counter, test.stale_counter])
output = test(sample_input)
display([test.usage_counter, test.stale_counter])
output = test(sample_input)
display([test.usage_counter, test.stale_counter])
assert output[0].shape == sample_input.shape
display([output[0].shape, *output[1:]])


[1;35mCodebook[0m[1m([0m
  [1m([0mvectors[1m)[0m: [1;35mEmbedding[0m[1m([0m[1;36m32[0m, [1;36m8[0m[1m)[0m
[1m)[0m

[1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1024[0m, [1;36m8[0m[1m][0m[1m)[0m, [1;35mtensor[0m[1m([0m[1;36m0[0m.[1m)[0m, [1;35mtensor[0m[1m([0m[1;36m0.6075[0m, [33mgrad_fn[0m=[1m<[0m[1;95mMeanBackward0[0m[1m>[0m[1m)[0m, [1;35mtensor[0m[1m([0m[1;36m25.3339[0m[1m)[0m[1m][0m


[1m[[0m
    [1;35mtensor[0m[1m([0m[1m[[0m [1;36m16[0m,  [1;36m31[0m, [1;36m113[0m, [1;36m106[0m, [1;36m189[0m,  [1;36m62[0m, [1;36m104[0m,  [1;36m83[0m,  [1;36m91[0m,  [1;36m70[0m,  [1;36m22[0m,   [1;36m7[0m,  [1;36m54[0m,  [1;36m39[0m,
         [1;36m60[0m,  [1;36m36[0m,  [1;36m73[0m,  [1;36m52[0m,   [1;36m5[0m,  [1;36m26[0m,  [1;36m24[0m, [1;36m112[0m,  [1;36m42[0m,  [1;36m12[0m,  [1;36m59[0m, [1;36m109[0m, [1;36m127[0m, [1;36m104[0m,
        [1;36m146[0m,  [1;36m39[0m,  [1;36m77[0m,  [1;36m22[0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1m[[0m[1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m1[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m1[0m, [1;36m0[0m, [1;36m1[0m, [1;36m0[0m, [1;36m0[0m, [1;36m1[0m,
        [1;36m0[0m, [1;36m


[1m[[0m
    [1;35mtensor[0m[1m([0m[1m[[0m [1;36m18[0m,  [1;36m33[0m, [1;36m114[0m, [1;36m112[0m, [1;36m192[0m,  [1;36m67[0m, [1;36m105[0m,  [1;36m86[0m,  [1;36m93[0m,  [1;36m70[0m,  [1;36m26[0m,   [1;36m8[0m,  [1;36m57[0m,  [1;36m40[0m,
         [1;36m62[0m,  [1;36m37[0m,  [1;36m75[0m,  [1;36m53[0m,   [1;36m5[0m,  [1;36m29[0m,  [1;36m24[0m, [1;36m114[0m,  [1;36m45[0m,  [1;36m12[0m,  [1;36m60[0m, [1;36m110[0m, [1;36m128[0m, [1;36m109[0m,
        [1;36m149[0m,  [1;36m40[0m,  [1;36m80[0m,  [1;36m23[0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1m[[0m[1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m2[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m2[0m, [1;36m0[0m, [1;36m2[0m, [1;36m0[0m, [1;36m0[0m, [1;36m2[0m,
        [1;36m0[0m, [1;36m


[1m[[0m
    [1;35mtensor[0m[1m([0m[1m[[0m [1;36m20[0m,  [1;36m35[0m, [1;36m115[0m, [1;36m118[0m, [1;36m195[0m,  [1;36m72[0m, [1;36m106[0m,  [1;36m89[0m,  [1;36m95[0m,   [1;36m0[0m,  [1;36m30[0m,   [1;36m9[0m,  [1;36m60[0m,  [1;36m41[0m,
         [1;36m64[0m,  [1;36m38[0m,  [1;36m77[0m,  [1;36m54[0m,   [1;36m0[0m,  [1;36m32[0m,   [1;36m0[0m, [1;36m116[0m,  [1;36m48[0m,   [1;36m0[0m,  [1;36m61[0m, [1;36m111[0m, [1;36m129[0m, [1;36m114[0m,
        [1;36m152[0m,  [1;36m41[0m,  [1;36m83[0m,  [1;36m24[0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1m[[0m[1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m,
        [1;36m0[0m, [1;36m

[1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m8[0m, [1;36m8[0m, [1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m][0m[1m)[0m, [1;35mtensor[0m[1m([0m[1;36m0[0m.[1m)[0m, [1;35mtensor[0m[1m([0m[1;36m0.4716[0m, [33mgrad_fn[0m=[1m<[0m[1;95mMeanBackward0[0m[1m>[0m[1m)[0m, [1;35mtensor[0m[1m([0m[1;36m23.6258[0m[1m)[0m[1m][0m

In [None]:
# Smoke test WITH EMA. `use_ema` is a computed field derived from `ema_decay`, so EMA is requested explicitly through
# the decay rate instead of through an ignored `use_ema` keyword.
test = Codebook(num_vectors=32, dim=8, revive_dead_vectors_after_n_steps=3, ema_decay=0.99)
assert test.config.use_ema, "EMA should be enabled in this cell"
display(test)

# Channels-last sequence input: (B, T, C)
sample_input = torch.randn(2, 2**10, 8, requires_grad=True)
output = test(sample_input)
assert output[0].shape == sample_input.shape
display([output[0].shape, *output[1:]])

# Channels-first volumetric input: (B, C, D, H, W). Repeating the same batch makes unused vectors go stale and get
# revived once the threshold of 3 steps is reached.
sample_input = torch.randn(8, 8, 2, 2, 2, requires_grad=True)
output = test(sample_input)
display([test.usage_counter, test.stale_counter])
output = test(sample_input)
display([test.usage_counter, test.stale_counter])
output = test(sample_input)
display([test.usage_counter, test.stale_counter])
assert output[0].shape == sample_input.shape
assert torch.isfinite(test.vectors.weight).all(), "Codebook vectors should stay finite after revival"
display([output[0].shape, *output[1:]])


[1;35mCodebook[0m[1m([0m
  [1m([0mvectors[1m)[0m: [1;35mEmbedding[0m[1m([0m[1;36m32[0m, [1;36m8[0m[1m)[0m
[1m)[0m

[1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1024[0m, [1;36m8[0m[1m][0m[1m)[0m, [1;35mtensor[0m[1m([0m[1;36m0[0m.[1m)[0m, [1;35mtensor[0m[1m([0m[1;36m0.6198[0m, [33mgrad_fn[0m=[1m<[0m[1;95mMeanBackward0[0m[1m>[0m[1m)[0m, [1;35mtensor[0m[1m([0m[1;36m25.6295[0m[1m)[0m[1m][0m


[1m[[0m
    [1;35mtensor[0m[1m([0m[1m[[0m[1;36m117[0m,  [1;36m74[0m,  [1;36m61[0m,  [1;36m66[0m,  [1;36m75[0m,  [1;36m88[0m,  [1;36m65[0m,  [1;36m54[0m,  [1;36m57[0m,   [1;36m9[0m,   [1;36m7[0m, [1;36m166[0m,  [1;36m19[0m,  [1;36m19[0m,
         [1;36m30[0m,  [1;36m97[0m,  [1;36m91[0m, [1;36m152[0m,  [1;36m60[0m, [1;36m120[0m,  [1;36m14[0m,  [1;36m26[0m, [1;36m165[0m,  [1;36m51[0m,  [1;36m28[0m,  [1;36m12[0m,  [1;36m80[0m,  [1;36m48[0m,
         [1;36m78[0m,  [1;36m99[0m,  [1;36m32[0m,  [1;36m52[0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1m[[0m[1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m1[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m1[0m, [1;36m0[0m, [1;36m0[0m, [1;36m1[0m, [1;36m0[0m, [1;36m0[0m,
        [1;36m0[0m, [1;36m


[1m[[0m
    [1;35mtensor[0m[1m([0m[1m[[0m[1;36m118[0m,  [1;36m77[0m,  [1;36m62[0m,  [1;36m68[0m,  [1;36m79[0m,  [1;36m92[0m,  [1;36m68[0m,  [1;36m57[0m,  [1;36m61[0m,   [1;36m9[0m,   [1;36m9[0m, [1;36m169[0m,  [1;36m21[0m,  [1;36m20[0m,
         [1;36m32[0m,  [1;36m99[0m,  [1;36m93[0m, [1;36m155[0m,  [1;36m60[0m, [1;36m124[0m,  [1;36m15[0m,  [1;36m26[0m, [1;36m173[0m,  [1;36m52[0m,  [1;36m29[0m,  [1;36m12[0m,  [1;36m82[0m,  [1;36m50[0m,
         [1;36m79[0m,  [1;36m99[0m,  [1;36m32[0m,  [1;36m54[0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1m[[0m[1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m2[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m2[0m, [1;36m0[0m, [1;36m0[0m, [1;36m2[0m, [1;36m0[0m, [1;36m0[0m,
        [1;36m0[0m, [1;36m


[1m[[0m
    [1;35mtensor[0m[1m([0m[1m[[0m[1;36m119[0m,  [1;36m80[0m,  [1;36m63[0m,  [1;36m70[0m,  [1;36m83[0m,  [1;36m96[0m,  [1;36m71[0m,  [1;36m60[0m,  [1;36m65[0m,   [1;36m0[0m,  [1;36m11[0m, [1;36m172[0m,  [1;36m23[0m,  [1;36m21[0m,
         [1;36m34[0m, [1;36m101[0m,  [1;36m95[0m, [1;36m158[0m,   [1;36m0[0m, [1;36m128[0m,  [1;36m16[0m,   [1;36m0[0m, [1;36m181[0m,  [1;36m53[0m,  [1;36m30[0m,   [1;36m0[0m,  [1;36m84[0m,  [1;36m52[0m,
         [1;36m80[0m,   [1;36m0[0m,   [1;36m0[0m,  [1;36m56[0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1m[[0m[1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m, [1;36m0[0m,
        [1;36m0[0m, [1;36m

[1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m8[0m, [1;36m8[0m, [1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m][0m[1m)[0m, [1;35mtensor[0m[1m([0m[1;36m0[0m.[1m)[0m, [1;35mtensor[0m[1m([0m[1;36m0.5265[0m, [33mgrad_fn[0m=[1m<[0m[1;95mMeanBackward0[0m[1m>[0m[1m)[0m, [1;35mtensor[0m[1m([0m[1;36m22.1967[0m[1m)[0m[1m][0m

# nbdev

In [None]:
!nbdev_export