In [1]:
from alphablokus.files import latest_file, list_files
from alphablokus.configs import NetworkConfig, GameConfig
from alphablokus.train_utils import load_initial_state

initial_training_file = latest_file("s3://alpha-blokus/full_v2/training_v2/", ".pth")

In [2]:
network_config = NetworkConfig("../../configs/train_offline/simulate_10block_64chan.toml")
game_config = GameConfig("../../configs/train_offline/simulate_10block_64chan.toml")

In [3]:
model, optimizer, samples_last_trained = load_initial_state(
    network_config,
    game_config,
    learning_rate=1e-3,
    device="mps",
    training_file=initial_training_file,
)

Loading training state from: s3://alpha-blokus/full_v2/training_v2/010188896.pth
Created temporary directory: /var/folders/np/v76cnj490z525wk67wqh68dc0000gn/T/tmpqdos2v4v


In [4]:
from torch import nn
import torch

def bn_stat_summary(model):
    s = []
    for name, m in model.named_modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            s.append((name,
                      m.running_mean.float().mean().item(),
                      m.running_var.float().mean().item()))
    return s

def set_bn_train_only(model: nn.Module):
    model.eval()  # everything eval by default (disables dropout)
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.train()  # BN updates running_mean/var
    return model

def reset_bn_running_stats(model: nn.Module):
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.reset_running_stats()

In [5]:
initial_stat_summary = bn_stat_summary(model)
initial_stat_summary

[('convolutional_block.1', -0.22940373420715332, 1.1732454299926758),
 ('residual_blocks.0.convolutional_block.1',
  -6.906782150268555,
  107.43310546875),
 ('residual_blocks.0.convolutional_block.4',
  -1.9352855682373047,
  16.140779495239258),
 ('residual_blocks.1.convolutional_block.1',
  -8.21483325958252,
  183.066650390625),
 ('residual_blocks.1.convolutional_block.4',
  -1.3503600358963013,
  11.196014404296875),
 ('residual_blocks.2.convolutional_block.1',
  -15.7227144241333,
  392.29315185546875),
 ('residual_blocks.2.convolutional_block.4',
  -3.702145576477051,
  23.113161087036133),
 ('residual_blocks.3.convolutional_block.1',
  -17.851551055908203,
  632.838134765625),
 ('residual_blocks.3.convolutional_block.4',
  -3.489837646484375,
  33.3912467956543),
 ('residual_blocks.4.convolutional_block.1',
  -14.017216682434082,
  741.4307250976562),
 ('residual_blocks.4.convolutional_block.4',
  -2.6696581840515137,
  39.646366119384766),
 ('residual_blocks.5.convolutional_bl

In [6]:
from tqdm import tqdm

@torch.no_grad()
def calibrate_bn(model, dataloader, device, num_batches=2000, reset_stats=True):
    model.to(device)

    if reset_stats:
        reset_bn_running_stats(model)

    set_bn_train_only(model)

    # Important: use a reasonably large batch size if possible
    # so BN estimates are stable.

    for i, batch in tqdm(enumerate(dataloader)):
        if i >= num_batches:
            break

        # adapt this to your batch structure
        board, expected_value, expected_policy, valid_policy_mask = batch
        x = board.to(device)

        _ = model(x)  # forward only, updates BN running stats

    model.eval()  # lock everything for inference
    return model

In [7]:
all_files = list_files("s3://alpha-blokus/full_v2/games_against_pentobi/", ".bin")
all_files = sorted(all_files)

In [8]:
from alphablokus.data_loaders import (
    BufferedGameBatchDataset,
    StaticListFileProvider,
    build_streaming_dataloader,
)

train_files = all_files[-40:]

file_provider = StaticListFileProvider(train_files)
dataset = BufferedGameBatchDataset(
    game_config,
    file_provider,
    128,
    3,
    local_cache_dir="/Users/shivamsarodia/Dev/AlphaBlokus/data/s3_mirrors/full_v2/games",
    cleanup_local_files=False,
)
dataloader = build_streaming_dataloader(
    dataset,
    num_workers=0,
    prefetch_factor=1,
)

In [9]:
calibrate_bn(model, dataloader, device="mps")

119it [00:13,  8.73it/s]


NeuralNet(
  (convolutional_block): Sequential(
    (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (residual_blocks): ModuleList(
    (0-9): 10 x ResidualBlock(
      (convolutional_block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (value_head): ValueHead(
    (layers): Sequential(
      (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): 

In [10]:
final_stat_summary = bn_stat_summary(model)
final_stat_summary

[('convolutional_block.1', -0.2223816215991974, 1.1303319931030273),
 ('residual_blocks.0.convolutional_block.1',
  -6.831547737121582,
  108.22123718261719),
 ('residual_blocks.0.convolutional_block.4',
  -1.9040334224700928,
  15.855646133422852),
 ('residual_blocks.1.convolutional_block.1',
  -8.130658149719238,
  183.28887939453125),
 ('residual_blocks.1.convolutional_block.4',
  -1.3152978420257568,
  11.172792434692383),
 ('residual_blocks.2.convolutional_block.1',
  -15.585186004638672,
  397.4678955078125),
 ('residual_blocks.2.convolutional_block.4',
  -3.6276445388793945,
  22.314590454101562),
 ('residual_blocks.3.convolutional_block.1',
  -17.648792266845703,
  636.7513427734375),
 ('residual_blocks.3.convolutional_block.4',
  -3.4205141067504883,
  31.89751625061035),
 ('residual_blocks.4.convolutional_block.1',
  -13.848779678344727,
  743.3154296875),
 ('residual_blocks.4.convolutional_block.4',
  -2.5965676307678223,
  35.730438232421875),
 ('residual_blocks.5.convoluti

In [11]:
for final, initial in zip(final_stat_summary, initial_stat_summary):
    assert final[0] == initial[0]
    print(final[1], initial[1])
    print(final[2], initial[2])
    print()


-0.2223816215991974 -0.22940373420715332
1.1303319931030273 1.1732454299926758

-6.831547737121582 -6.906782150268555
108.22123718261719 107.43310546875

-1.9040334224700928 -1.9352855682373047
15.855646133422852 16.140779495239258

-8.130658149719238 -8.21483325958252
183.28887939453125 183.066650390625

-1.3152978420257568 -1.3503600358963013
11.172792434692383 11.196014404296875

-15.585186004638672 -15.7227144241333
397.4678955078125 392.29315185546875

-3.6276445388793945 -3.702145576477051
22.314590454101562 23.113161087036133

-17.648792266845703 -17.851551055908203
636.7513427734375 632.838134765625

-3.4205141067504883 -3.489837646484375
31.89751625061035 33.3912467956543

-13.848779678344727 -14.017216682434082
743.3154296875 741.4307250976562

-2.5965676307678223 -2.6696581840515137
35.730438232421875 39.646366119384766

-12.588321685791016 -12.688703536987305
802.24169921875 802.0361938476562

-5.48783540725708 -5.575998306274414
109.77433776855469 121.60429382324219

-8.90

In [41]:
from alphablokus.train_utils import save_model_and_state

save_model_and_state(
    model=model,
    optimizer=optimizer,
    name="updated_batch_norm",
    model_directory="s3://alpha-blokus/full_v2/models_simulated/",
    training_directory="s3://alpha-blokus/full_v2/training_simulated/",
    device="mps",
    add_timestamp=True,
)

[2026-01-19 21:20:24] Saving model to: s3://alpha-blokus/full_v2/models_simulated/updated_batch_norm_1768886425.onnx
[torch.onnx] Obtain model graph for `NeuralNet([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `NeuralNet([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 2 of general pattern rewrite rules.
[2026-01-19 21:20:28] Saving training state to: s3://alpha-blokus/full_v2/training_simulated/updated_batch_norm_1768886425.pth
