# Dataset normalization analysis

In this notebook we load the training dataset and compute its per-channel means and standard deviations. These statistics are needed to normalize the data before it is passed to the backbone.

In [1]:
from mmengine.config import Config
from mmengine.runner import Runner
import torch
from tqdm import tqdm

# Build the runner from the experiment config; the train dataloader below is
# derived from the runner's resolved configuration.
CONFIG_PATH = "configs/models/vit-s-p16_videomaev2-vit-g-dist-k710-pre_16x4x1_kinetics-400_base.py"
runner_cfg = Config.fromfile(CONFIG_PATH)
runner = Runner.from_cfg(runner_cfg)

# Use the first entry of the configured train dataloaders.
train_dataloader_cfg = runner.cfg.train_dataloader[0]

# Load data in the main process only. The original author suspected the
# incremental statistics would not be valid with parallel workers (unverified).
for option, value in (("num_workers", 0), ("persistent_workers", False)):
    train_dataloader_cfg[option] = value

train_dataloader = runner.build_dataloader(train_dataloader_cfg, seed=42)

11/27 17:44:55 - mmengine - [4m[97mINFO[0m - 
------------------------------------------------------------
System environment:
    sys.platform: darwin
    Python: 3.10.13 | packaged by conda-forge | (main, Oct 26 2023, 18:09:17) [Clang 16.0.6 ]
    CUDA available: False
    numpy_random_seed: 348990617
    GCC: Apple clang version 15.0.0 (clang-1500.0.40.1)
    PyTorch: 2.1.1
    PyTorch compiling details: PyTorch built with:
  - GCC 4.2
  - C++ Version: 201703
  - clang 13.1.6
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: NO AVX
  - Build settings: BLAS_INFO=accelerate, BUILD_TYPE=Release, CXX_COMPILER=/Applications/Xcode_13.3.1.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/clang++, CXX_FLAGS= -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_PYTORCH_METAL_EXPORT -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DUSE_COREML

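I use an incremental (batch-by-batch) mean and std calculation, so the statistics can be updated one batch at a time without holding the whole dataset in memory, and to avoid numerical instability.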

[Batch statistics](https://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html)

In [2]:
class StatsRecorder:
    """Running per-row mean and standard deviation over a stream of batches.

    Every batch passed to :meth:`update` is a 2-D tensor that is reduced along
    ``dim=1``, so each row (e.g. one colour channel) gets its own statistic.
    Batches are merged with the pairwise update for *population* statistics:

        mean = m/(m+n) * mean_old + n/(m+n) * mean_new
        var  = m/(m+n) * var_old  + n/(m+n) * var_new
               + m*n/(m+n)^2 * (mean_old - mean_new)^2

    where ``m`` is the number of observations seen so far and ``n`` the number
    of observations (columns) in the new batch.

    Attributes:
        nobservations: number of columns accumulated so far.
        epsilon: tolerance used by the early-stop check.
        threshold_counter: consecutive batches that passed the early-stop check.
        mean, std: running statistics; ``None`` until the first non-empty batch.
    """

    def __init__(self, epsilon=1e-3) -> None:
        self.nobservations = 0
        self.epsilon = epsilon
        self.threshold_counter = 0
        # Defined up front so the attributes exist before the first update.
        self.mean = None
        self.std = None

    def update(self, x):
        """Merge one batch ``x`` (rows x observations) into the running stats.

        Raises:
            KeyboardInterrupt: when, for 10 consecutive batches, the batch mean
                and batch std were both within ``epsilon`` of the running
                values (element-wise). The triggering batch is not merged.
        """
        n = x.shape[1]
        if n == 0:
            # An empty batch has NaN statistics and would poison the running
            # values; it carries no information, so skip it.
            return

        newmean = x.mean(dim=1)
        # Population std (no Bessel correction): the merge formula below is
        # only exact for population variances, and this is also well defined
        # for a batch with a single observation.
        newstd = x.std(dim=1, unbiased=False)

        if self.nobservations == 0:
            self.mean = newmean
            self.std = newstd
            self.nobservations = n
            return

        # Early stop: note this compares the *batch* statistics against the
        # running statistics, not successive running estimates.
        mean_close = torch.all(torch.abs(newmean - self.mean) < self.epsilon)
        if mean_close and torch.all(torch.abs(newstd - self.std) < self.epsilon):
            self.threshold_counter += 1
            if self.threshold_counter >= 10:
                print("std and mean are not changing anymore")
                raise KeyboardInterrupt
        else:
            self.threshold_counter = 0

        m = self.nobservations * 1.0
        total = m + n

        old_mean = self.mean

        self.mean = m / total * old_mean + n / total * newmean
        merged_var = (
            m / total * self.std**2
            + n / total * newstd**2
            + m * n / total**2 * (old_mean - newmean) ** 2
        )
        self.std = torch.sqrt(merged_var)

        self.nobservations += n

In [3]:
sr = StatsRecorder()

pbar = tqdm(train_dataloader)

# Pre-set so the abort message below works even when the interrupt arrives
# before the first batch has been produced.
i = 0

try:
    for i, batch in enumerate(pbar):
        # Only element 0 of batch["inputs"] is used.
        # Layout per the original note: [B, 3, T, H, W] -- TODO confirm.
        images = batch["inputs"][0].type(torch.float32)
        # Bring the channel axis (4th from the end) to the front before
        # flattening. A plain view(3, -1) only groups values by channel when
        # every leading dimension has size 1; otherwise it mixes channels.
        x = images.movedim(-4, 0).reshape(3, -1)
        sr.update(x)
        pbar.set_postfix({"mean": sr.mean.tolist(), "std": sr.std.tolist()})

except KeyboardInterrupt:
    # Raised by a manual interrupt or by StatsRecorder's early-stop check.
    print(f"Aborted after {i} batches")
finally:
    # Keep whatever was accumulated, even on failure. Guard against the case
    # where no batch was processed, so the original error is not masked.
    has_stats = sr.nobservations > 0
    mean = sr.mean.tolist() if has_stats else None
    std = sr.std.tolist() if has_stats else None


print(f"mean: {mean}")
print(f"std: {std}")

  0%|          | 0/10764 [00:00<?, ?it/s]



100%|██████████| 10764/10764 [1:31:25<00:00,  1.96it/s, mean=[102.17311096191406, 98.78225708007812, 92.68714141845703], std=[58.04566192626953, 57.004024505615234, 57.3704948425293]]    


mean: [102.17311096191406, 98.78225708007812, 92.68714141845703]
std: [58.04566192626953, 57.004024505615234, 57.3704948425293]


 35%|███▍      | 3736/10764 [31:40<1:21:31,  1.44it/s, mean=[102.19119262695312, 98.79428100585938, 92.73539733886719], std=[58.05506134033203, 57.04253387451172, 57.42849349975586]]   

 49%|████▉     | 5267/10764 [44:25<32:22,  2.83it/s, mean=[102.20238494873047, 98.83967590332031, 92.79312133789062], std=[57.975608825683594, 56.944862365722656, 57.31865310668945]]

 63%|██████▎   | 6764/10764 [56:58<55:57,  1.19it/s, mean=[102.23499298095703, 98.87828063964844, 92.82234191894531], std=[58.03319549560547, 56.99748611450195, 57.366111755371094]]   

 78%|███████▊  | 8419/10764 [1:10:44<21:16,  1.84it/s, mean=[102.10855102539062, 98.75061798095703, 92.70094299316406], std=[58.065406799316406, 57.022430419921875, 57.378849029541016]]   

 100%|██████████| 10764/10764 [1:31:25<00:00,  1.96it/s, mean=[102.17311096191406, 98.78225708007812, 92.68714141845703], std=[58.04566192626953, 57.004024505615234, 57.3704948425293]]    
mean: [102.17311096191406, 98.78225708007812, 92.68714141845703]
std: [58.04566192626953, 57.004024505615234, 57.3704948425293]

In [6]:
# Sanity check: per-channel values contributed by each batch on average.
# Uses the dataloader length (10764 in the run shown above) instead of a
# hard-coded batch count. Only meaningful if the loop ran to completion.
sr.nobservations / len(train_dataloader)

802816.0