In [1]:
config_path = "../../configs/train_live/local_vast.toml"
learning_rate = 1e-3
# model_file = "s3://alpha-blokus/full_v2/training_v2/010188896.pth"
model_file = "s3://alpha-blokus/full_v2/training_v3/010260000.pth"
training_file_path = "s3://alpha-blokus/full_v2/games/"
local_game_mirror = "/Users/shivamsarodia/Dev/AlphaBlokus/data/s3_mirrors/full_v2/games"

In [2]:
from alphablokus.files import list_files
from alphablokus.configs import GameConfig, NetworkConfig
from alphablokus.train_utils import load_initial_state, get_loss
from alphablokus.data_loaders import StaticListFileProvider, BufferedGameBatchDataset, build_streaming_dataloader

In [3]:
game_config = GameConfig(config_path)
network_config = NetworkConfig(config_path)

In [4]:
import random
random.seed(42)

all_files = sorted(list_files(training_file_path, ".bin"))
train_files = all_files[-20:]
random.shuffle(train_files)

In [5]:
file_provider = StaticListFileProvider(train_files)
dataset = BufferedGameBatchDataset(
    game_config,
    file_provider,
    128,
    4,
    local_cache_dir=local_game_mirror,
    cleanup_local_files=False,
)
dataloader = build_streaming_dataloader(
    dataset,
    num_workers=2,
    prefetch_factor=60,
)

In [6]:
model, optimizer = load_initial_state(
    network_config,
    game_config,
    learning_rate=1e-3,
    device="mps",
    training_file=model_file,
    skip_loading_optimizer=True,
    optimizer_type="adam",
)

Loading training state from: s3://alpha-blokus/full_v2/training_v3/010260000.pth
Created temporary directory: /var/folders/np/v76cnj490z525wk67wqh68dc0000gn/T/tmpkl7rrqy5


In [7]:
from tqdm import tqdm
import torch

model.eval()
torch.no_grad()

total_value_loss = 0
total_policy_loss = 0

for batch in tqdm(dataloader):
    loss, value_loss, policy_loss = get_loss(
        batch,
        model,
        device="mps",
        policy_loss_weight=0.158,
    )
    total_value_loss += value_loss.item()
    total_policy_loss += policy_loss.item()

print(f"Total value loss: {total_value_loss}")
print(f"Total policy loss: {total_policy_loss}")

0it [00:00, ?it/s]

1553it [01:23, 18.68it/s]

Total value loss: 1664.9245355725288
Total policy loss: 672.4627363085747





### The old, OG model:

- Total value loss: 1672.6777361631393
- Total policy loss: 688.988334029913


### The new, retrained model

- Total value loss: 1664.9245355725288
- Total policy loss: 672.4627363085747

In [10]:
1672.6777361631393 - 1664.9245355725288

7.753200590610504