**Imports & Setup**

In [1]:

import random
from pathlib import Path
from random import shuffle
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from tqdm import tqdm

DATA_DIR = Path("prepared_data/batches/")  # directory of prepared .pt batch files
SAVE_DIR = Path("model/")                  # checkpoints and the training log are written here
TRAIN_VAL_RATIO = 0.8                      # fraction of loaded samples used for training; the rest is validation
VALUE_LOSS_WEIGHT = 0.01                   # weight of the value-head loss in the total loss (policy loss weight is 1)
LR = 5e-5                                  # initial AdamW learning rate
EPOCHS = 30
BATCH_SIZE = 128
SEED = 42                                  # global RNG seed for reproducible Restart & Run All

# Seed Python's `random` (used to shuffle batch-file order) and torch (weight init,
# DataLoader shuffling). GPU kernels may still be nondeterministic.
random.seed(SEED)
torch.manual_seed(SEED)

# Ensure save directory exists
SAVE_DIR.mkdir(parents=True, exist_ok=True)

**Custom PyTorch Dataset**

In [2]:
class ChessDataset(Dataset):
    """Custom PyTorch Dataset for loading preprocessed chess training samples.

    Samples are kept in memory exactly as stored in the batch files; conversion to
    tensors happens lazily in ``__getitem__``.
    """

    def __init__(self, data_dir: Path, max_samples: Optional[int] = 100_000, seed: Optional[int] = None):
        """
        Initializes the dataset by loading .pt batch files containing training samples.

        :param data_dir: Directory containing .pt files with serialized training data
        :param max_samples: Maximum number of samples to load into memory; None loads every sample
        :param seed: Optional seed for the batch-file shuffle. When None the global ``random``
            state is used (seed it beforehand for reproducibility)
        :raises RuntimeError: If loading any file fails
        """
        self.samples = []
        # Sort first so the result does not depend on the filesystem's directory listing
        # order; the shuffle below then fully determines which files are read first.
        self.batch_paths = sorted(data_dir.glob("*.pt"))
        if seed is None:
            shuffle(self.batch_paths)  # global `random` state
        else:
            random.Random(seed).shuffle(self.batch_paths)

        batches_loaded = 0

        for path in self.batch_paths:
            # Check the cap before loading, so max_samples=0 loads nothing and no file is
            # read once the cap has been reached.
            if max_samples is not None and len(self.samples) >= max_samples:
                break
            try:
                # SECURITY: weights_only=False unpickles arbitrary Python objects; only
                # point this at trusted, locally produced batch files.
                batch = torch.load(path, weights_only=False)
            except Exception as e:
                raise RuntimeError(f"Failed to load {path.name}: {e}") from e
            batches_loaded += 1

            # Accumulate individual samples from the loaded batch up to the cap
            for sample in batch:
                if max_samples is not None and len(self.samples) >= max_samples:
                    break
                self.samples.append(sample)

        print(f"Loaded {len(self.samples)} samples from {batches_loaded} batch files.")

    def __len__(self) -> int:
        """Returns the total number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """
        Retrieves the sample at a specific index.

        :param idx: Index of the sample
        :return: Tuple of (board_tensor, move_index, result_value) as tensors
            (float32 of shape (21, 8, 8), long scalar, float32 scalar)
        :raises ValueError: If the stored board does not have shape (21, 8, 8)
        """
        board_tensor, move_index, result_value = self.samples[idx]
        # Convert first so list/array boards are validated too; raise instead of
        # `assert`, which is stripped when Python runs with -O.
        board = torch.as_tensor(board_tensor, dtype=torch.float32)
        if board.shape != (21, 8, 8):
            raise ValueError(f"Bad tensor shape: {board.shape}")
        return (
            board,
            torch.as_tensor(move_index, dtype=torch.long),
            torch.as_tensor(result_value, dtype=torch.float32)
        )

**Model Implementation**

In [3]:
class CoordConv(nn.Module):
    """Applies convolution after appending coordinate channels to the input tensor.

    Adds two additional channels representing normalized X and Y coordinates
    (each a linear ramp over [-1, 1] along width and height respectively),
    allowing the network to better reason about spatial position.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, bias=False):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels + 2,  # +2 for the coordinate channels
            out_channels,
            kernel_size,
            padding=padding,
            # Defaults to False to keep the original parameter set. Note that in ChessNet
            # this conv is followed directly by a ReLU (no normalization layer), so
            # enabling the bias can be worthwhile for newly trained models.
            bias=bias
        )

    def forward(self, x):
        batch, _, height, width = x.size()

        # Coordinate ramps in [-1, 1], created in the input's dtype/device so torch.cat
        # and the conv see a single dtype (e.g. when the model runs in half/double).
        xs = torch.linspace(-1, 1, width, device=x.device, dtype=x.dtype)
        ys = torch.linspace(-1, 1, height, device=x.device, dtype=x.dtype)

        # Broadcast (no copy) to (batch, 1, height, width): X varies along width, Y along height
        xx_channel = xs.view(1, 1, 1, width).expand(batch, 1, height, width)
        yy_channel = ys.view(1, 1, height, 1).expand(batch, 1, height, width)

        # Concatenate coordinates after the input channels and convolve
        return self.conv(torch.cat([x, xx_channel, yy_channel], dim=1))


class SEBlock(nn.Module):
    """Squeeze-and-Excitation block for channel-wise attention.

    Global-average-pools each channel, passes the pooled vector through a bottleneck
    MLP ending in a sigmoid, and rescales every input channel by its gate.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        # At least one hidden unit: for channels < reduction the integer division would
        # yield a zero-width bottleneck whose gates no longer depend on the input.
        hidden = max(1, channels // reduction)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        s = self.pool(x).view(b, c)  # Global average pooling
        s = self.fc(s).view(b, c, 1, 1)  # Channel recalibration
        return x * s


class ResidualBlock(nn.Module):
    """Two 3x3 conv + GroupNorm stages wrapped in an identity skip connection.

    The second GroupNorm starts with zero scale and shift, so a freshly built block
    outputs ReLU(input). When ``use_se`` is True a squeeze-and-excitation stage
    reweights channels before the skip addition.
    """

    def __init__(self, channels, use_se=False):
        super().__init__()
        # First conv stage (normalized, then ReLU in forward)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.norm1 = nn.GroupNorm(8, channels)
        # Second conv stage (normalized, no activation before the skip)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.norm2 = nn.GroupNorm(8, channels)

        # Zero the affine params of the last norm so the residual branch starts as a no-op
        for affine_param in (self.norm2.weight, self.norm2.bias):
            nn.init.zeros_(affine_param)

        if use_se:
            self.se = SEBlock(channels)
        else:
            self.se = nn.Identity()

    def forward(self, x):
        branch = F.relu(self.norm1(self.conv1(x)))
        branch = self.se(self.norm2(self.conv2(branch)))
        return F.relu(branch + x)


class ChessNet(nn.Module):
    """Policy-value network for chess.

    Pipeline: CoordConv stem -> stack of residual blocks (every third one carries an
    SE block) -> two heads that share the trunk features:
      * policy head: logits over ``policy_out_dim`` move indices
      * value head: one scalar per position, squashed to [-1, 1] by tanh
    """

    def __init__(self, input_channels=21, num_blocks=8, num_filters=96, policy_out_dim=1968):
        super().__init__()
        board_cells = 8 * 8  # both heads flatten an 8x8 board

        # Stem: input feature planes (+2 coordinate planes) -> num_filters channels
        self.coord_conv = CoordConv(input_channels, num_filters)

        # Trunk: blocks at indices 2, 5, 8, ... include squeeze-and-excitation
        self.shared_blocks = nn.Sequential(
            *[ResidualBlock(num_filters, use_se=(idx % 3 == 2)) for idx in range(num_blocks)]
        )

        # Policy head: 1x1 conv down to 2 planes, then an MLP producing move logits
        self.policy_head = nn.Sequential(
            nn.Conv2d(num_filters, 2, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(2 * board_cells, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, policy_out_dim),
        )

        # Value head: 1x1 conv down to 1 plane, then an MLP ending in tanh
        self.value_head = nn.Sequential(
            nn.Conv2d(num_filters, 1, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(board_cells, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 1),
            nn.Tanh(),  # Output in [-1, 1] for win/loss/draw
        )

    def forward(self, x):
        trunk = self.shared_blocks(F.relu(self.coord_conv(x)))
        logits = self.policy_head(trunk)
        value = self.value_head(trunk).squeeze(-1)  # (batch, 1) -> (batch,)
        return logits, value

**Validation**

In [4]:
def validate(model, val_loader, policy_loss_fn, value_loss_fn, device, top_k=5, use_amp=False, value_loss_weight=None):
    """
    Evaluates the model on the validation set.

    The model is left in eval mode; callers that continue training must switch it
    back with ``model.train()``.

    :param model: Trained policy-value model returning ``(policy_logits, value_preds)``
    :param val_loader: DataLoader yielding ``(board_tensor, move_index, result_value)`` batches
    :param policy_loss_fn: Loss function for the policy head
    :param value_loss_fn: Loss function for the value head
    :param device: torch.device used for evaluation
    :param top_k: Top-K for accuracy computation, defaults to 5. Clamped to the number of
        policy classes so ``torch.topk`` cannot fail on a small policy head
    :param use_amp: Whether to use automatic mixed precision (AMP), defaults to False
    :param value_loss_weight: Weight of the value loss in the reported total loss; when
        None the module-level ``VALUE_LOSS_WEIGHT`` is used
    :return: Tuple of (avg_policy_loss, avg_value_loss, avg_total_loss, top1_accuracy,
        topk_accuracy). Every entry is NaN when the loader yields no samples.
    """
    weight = VALUE_LOSS_WEIGHT if value_loss_weight is None else value_loss_weight

    model.eval()
    total_policy_loss = 0.0
    total_value_loss = 0.0
    total_correct_top1 = 0
    total_correct_topk = 0
    total_samples = 0

    # Pre-set the reported metrics so an empty loader yields NaNs instead of an
    # UnboundLocalError in the summary below.
    avg_policy = avg_value = avg_total_loss = acc_top1 = acc_topk = float("nan")

    val_bar = tqdm(val_loader, desc="[Validation]", unit="batch", leave=False)

    with torch.no_grad():
        for board_tensor, move_index, result_value in val_bar:
            board_tensor = board_tensor.to(device)
            move_index = move_index.to(device)
            result_value = result_value.to(device)

            # Forward pass with optional AMP
            with autocast(device_type=device.type, enabled=use_amp):
                policy_logits, value_preds = model(board_tensor)
                policy_loss = policy_loss_fn(policy_logits, move_index)
                value_loss = value_loss_fn(value_preds, result_value)

            # Top-1 accuracy
            predicted_top1 = torch.argmax(policy_logits, dim=1)
            correct_top1 = (predicted_top1 == move_index).sum().item()

            # Top-k accuracy (k cannot exceed the number of policy classes)
            k = min(top_k, policy_logits.size(1))
            topk_preds = torch.topk(policy_logits, k, dim=1).indices
            correct_topk = (topk_preds == move_index.unsqueeze(1)).any(dim=1).sum().item()

            # Update running totals (losses are batch means, so weight by batch size)
            batch_size = board_tensor.size(0)
            total_policy_loss += policy_loss.item() * batch_size
            total_value_loss += value_loss.item() * batch_size
            total_correct_top1 += correct_top1
            total_correct_topk += correct_topk
            total_samples += batch_size

            # Update averaged metrics for progress bar
            avg_policy = total_policy_loss / total_samples
            avg_value = total_value_loss / total_samples
            avg_total_loss = avg_policy + weight * avg_value
            acc_top1 = total_correct_top1 / total_samples
            acc_topk = total_correct_topk / total_samples

            val_bar.set_postfix({
                "policy_loss": f"{avg_policy:.4f}",
                "value_loss": f"{avg_value:.4f}",
                "total_loss": f"{avg_total_loss:.4f}",
                "top1_acc": f"{acc_top1:.4f}",
                f"top{top_k}_acc": f"{acc_topk:.4f}"
            })

    # Final validation stats
    tqdm.write(
        f"[Validation] Policy Loss: {avg_policy:.4f}, "
        f"Value Loss: {avg_value:.4f}, "
        f"Total Loss: {avg_total_loss:.4f}, "
        f"Top-1 Acc: {acc_top1:.4f}, Top-{top_k} Acc: {acc_topk:.4f}"
    )

    return avg_policy, avg_value, avg_total_loss, acc_top1, acc_topk

**Training Loop**

In [5]:
def _save_scripted_model(model, save_path, device):
    """
    Saves `model` as a TorchScript archive scripted from its CPU weights.

    The model is put in eval mode, moved to the CPU, scripted and saved. It is then moved
    back to `device` inside ``finally``, so a scripting or I/O failure cannot leave the
    training model stranded on the CPU.

    :param model: Model to serialize (moved between devices in place)
    :param save_path: Destination path for the TorchScript file
    :param device: torch.device to return the model to afterwards
    """
    model.eval()
    try:
        scripted = torch.jit.script(model.cpu())
        scripted.save(str(save_path))
    finally:
        model.to(device)


def train(model, train_loader, val_loader, optimizer, scheduler, policy_loss_fn, value_loss_fn, device, epochs=5, validate_after=False, save_best=False, use_amp=False):
    """
    Trains the policy-value network using supervised learning.

    Writes a per-epoch summary to ``SAVE_DIR / "train_log.log"`` (overwritten on each call).

    :param model: The model to train
    :param train_loader: DataLoader for the training set
    :param val_loader: DataLoader for the validation set
    :param optimizer: Optimizer for training
    :param scheduler: Learning rate scheduler (ReduceLROnPlateau). It is stepped with the
        validation total loss, so it only advances when validate_after is True
    :param policy_loss_fn: Loss function for policy head
    :param value_loss_fn: Loss function for value head
    :param device: torch.device to run training on
    :param epochs: Number of training epochs, defaults to 5
    :param validate_after: Whether to run validation after each epoch, defaults to False
    :param save_best: Whether to save the best model during training, defaults to False.
        Best means highest validation top-1 accuracy when validating, otherwise lowest
        training total loss
    :param use_amp: Use automatic mixed precision (AMP), defaults to False
    """

    best_metric = float("-inf")  # Higher is better: top-1 accuracy, or negated train loss
    log_path = SAVE_DIR / "train_log.log"
    scaler = GradScaler(device=device.type, enabled=use_amp)

    # Initialize training log
    with open(log_path, "w") as log_file:
        log_file.write("epoch | train_policy_loss - train_value_loss - train_total_loss | val_policy_loss - val_value_loss - val_total_loss | top1_acc - top5_acc | saved\n")

    for epoch in range(epochs):
        model.train()
        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_samples = 0
        saved_model = False
        # Defaults so an empty train_loader produces a NaN summary instead of an
        # UnboundLocalError further down.
        avg_policy = avg_value = avg_total_loss = float("nan")

        epoch_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch", leave=False)

        for board_tensor, move_index, result_value in epoch_bar:
            board_tensor = board_tensor.to(device)
            move_index = move_index.to(device)
            result_value = result_value.to(device)

            optimizer.zero_grad()

            # Forward pass with AMP
            with autocast(device_type=device.type, enabled=use_amp):
                policy_logits, value_preds = model(board_tensor)
                policy_loss = policy_loss_fn(policy_logits, move_index)
                value_loss = value_loss_fn(value_preds, result_value)
                loss = policy_loss + VALUE_LOSS_WEIGHT * value_loss

            # Backward and optimizer step (the scaler is a pass-through when AMP is disabled)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Update running loss stats (losses are batch means, so weight by batch size)
            batch_size = board_tensor.size(0)
            total_policy_loss += policy_loss.item() * batch_size
            total_value_loss += value_loss.item() * batch_size
            total_samples += batch_size

            avg_policy = total_policy_loss / total_samples
            avg_value = total_value_loss / total_samples
            avg_total_loss = avg_policy + VALUE_LOSS_WEIGHT * avg_value

            # Update progress bar
            epoch_bar.set_postfix({
                "policy_loss": f"{avg_policy:.4f}",
                "value_loss": f"{avg_value:.4f}",
                "total_loss": f"{avg_total_loss:.4f}"
            })

        tqdm.write(f"[Epoch {epoch+1}] Train Policy Loss: {avg_policy:.4f}, Value Loss: {avg_value:.4f}")

        val_policy_loss = val_value_loss = val_total_loss = acc_top1 = acc_topk = float("nan")

        if validate_after:
            # Evaluate on validation set
            val_policy_loss, val_value_loss, val_total_loss, acc_top1, acc_topk = validate(
                model, val_loader, policy_loss_fn, value_loss_fn, device, top_k=5, use_amp=use_amp
            )
            scheduler.step(val_total_loss)

            # Save best model by validation accuracy
            if save_best and acc_top1 > best_metric:
                best_metric = acc_top1
                save_path = SAVE_DIR / f"model_epoch{epoch+1:02d}_acc{acc_top1:.4f}.pt"
                _save_scripted_model(model, save_path, device)
                saved_model = True
                tqdm.write(f"Saved new best model (Top-1 Accuracy: {best_metric:.4f})")

        elif save_best:
            # Save best model by lowest training total loss (negated so higher is better)
            current_metric = -avg_total_loss
            if current_metric > best_metric:
                best_metric = current_metric
                save_path = SAVE_DIR / f"model_epoch{epoch+1:02d}_trainloss{-current_metric:.4f}.pt"
                _save_scripted_model(model, save_path, device)
                saved_model = True
                tqdm.write(f"Saved new best scripted model (Train Loss: {-best_metric:.4f})")

        # Append results to log
        with open(log_path, "a") as log_file:
            log_file.write(
                f"{epoch+1:02d} | {avg_policy:.4f} - {avg_value:.4f} - {avg_total_loss:.4f} | {val_policy_loss:.4f} - "
                f"{val_value_loss:.4f} - {val_total_loss:.4f} | {acc_top1:.4f} - {acc_topk:.4f} | {int(saved_model)}\n"
            )

**Run: Load Data, Build Model & Train**

In [6]:
# Load up to 4M samples from the prepared batch files
dataset = ChessDataset(DATA_DIR, max_samples=4_000_000)

# Split sizes: TRAIN_VAL_RATIO of the samples for training, the remainder for validation
total_samples = len(dataset)
train_size = int(total_samples * TRAIN_VAL_RATIO)
val_size = total_samples - train_size

# A dedicated, seeded generator keeps the train/val partition stable across runs
generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(dataset, [train_size, val_size], generator=generator)

# Shared loader settings; only training data is reshuffled each epoch
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)

Loaded 4000000 samples from 400 batch files.


In [7]:
# Build the network on the GPU when one is available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ChessNet(num_blocks=8, num_filters=96).to(device)

# Policy: label-smoothed cross-entropy over move indices; value: MSE against the game result
policy_loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
value_loss_fn = nn.MSELoss()

# AdamW with decoupled weight decay
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)

# Halve the LR once the monitored loss stops improving: with patience=2 the first two
# non-improving epochs are tolerated and the reduction happens on the third
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# Head output sizes: final Linear of the policy head, and the Linear just before Tanh in the value head
policy_out_features = model.policy_head[-1].out_features
value_out_features = model.value_head[-2].out_features
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model parameters: {trainable_params:,}")
print(f"Number of policy output classes: {policy_out_features}")
print(f"Value head output size: {value_out_features}")
print(f"Using device: {device}")

Model parameters: 1,898,185
Number of policy output classes: 1968
Value head output size: 1
Using device: cuda


In [8]:
# Full supervised run: validate after every epoch (this also drives the LR scheduler)
# and keep the checkpoint with the best validation top-1 accuracy
run_options = dict(
    epochs=EPOCHS,
    validate_after=True,
    save_best=True,
    use_amp=False,
)
train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    policy_loss_fn=policy_loss_fn,
    value_loss_fn=value_loss_fn,
    device=device,
    **run_options,
)

                                                                                                                              

[Epoch 1] Train Policy Loss: 4.6979, Value Loss: 0.8688


                                                                                                                                                                 

[Validation] Policy Loss: 3.9068, Value Loss: 0.8657, Total Loss: 3.9155, Top-1 Acc: 0.2471, Top-5 Acc: 0.5460
Saved new best model (Top-1 Accuracy: 0.2471)


                                                                                                                              

[Epoch 2] Train Policy Loss: 3.6953, Value Loss: 0.8615


                                                                                                                                                                 

[Validation] Policy Loss: 3.5646, Value Loss: 0.8583, Total Loss: 3.5732, Top-1 Acc: 0.3009, Top-5 Acc: 0.6292
Saved new best model (Top-1 Accuracy: 0.3009)


                                                                                                                              

[Epoch 3] Train Policy Loss: 3.4415, Value Loss: 0.8557


                                                                                                                                                                 

[Validation] Policy Loss: 3.3918, Value Loss: 0.8533, Total Loss: 3.4003, Top-1 Acc: 0.3327, Top-5 Acc: 0.6761
Saved new best model (Top-1 Accuracy: 0.3327)


                                                                                                                              

[Epoch 4] Train Policy Loss: 3.2917, Value Loss: 0.8508


                                                                                                                                                                 

[Validation] Policy Loss: 3.2810, Value Loss: 0.8471, Total Loss: 3.2895, Top-1 Acc: 0.3538, Top-5 Acc: 0.7051
Saved new best model (Top-1 Accuracy: 0.3538)


                                                                                                                              

[Epoch 5] Train Policy Loss: 3.1856, Value Loss: 0.8461


                                                                                                                                                                 

[Validation] Policy Loss: 3.1959, Value Loss: 0.8448, Total Loss: 3.2043, Top-1 Acc: 0.3706, Top-5 Acc: 0.7294
Saved new best model (Top-1 Accuracy: 0.3706)


                                                                                                                              

[Epoch 6] Train Policy Loss: 3.1049, Value Loss: 0.8426


                                                                                                                                                                 

[Validation] Policy Loss: 3.1418, Value Loss: 0.8429, Total Loss: 3.1502, Top-1 Acc: 0.3817, Top-5 Acc: 0.7459
Saved new best model (Top-1 Accuracy: 0.3817)


                                                                                                                              

[Epoch 7] Train Policy Loss: 3.0413, Value Loss: 0.8398


                                                                                                                                                                 

[Validation] Policy Loss: 3.0901, Value Loss: 0.8387, Total Loss: 3.0985, Top-1 Acc: 0.3900, Top-5 Acc: 0.7583
Saved new best model (Top-1 Accuracy: 0.3900)


                                                                                                                              

[Epoch 8] Train Policy Loss: 2.9899, Value Loss: 0.8376


                                                                                                                                                                 

[Validation] Policy Loss: 3.0613, Value Loss: 0.8347, Total Loss: 3.0696, Top-1 Acc: 0.3947, Top-5 Acc: 0.7674
Saved new best model (Top-1 Accuracy: 0.3947)


                                                                                                                              

[Epoch 9] Train Policy Loss: 2.9474, Value Loss: 0.8360


                                                                                                                                                                 

[Validation] Policy Loss: 3.0292, Value Loss: 0.8366, Total Loss: 3.0376, Top-1 Acc: 0.4014, Top-5 Acc: 0.7760
Saved new best model (Top-1 Accuracy: 0.4014)


                                                                                                                               

[Epoch 10] Train Policy Loss: 2.9112, Value Loss: 0.8348


                                                                                                                                                                 

[Validation] Policy Loss: 3.0070, Value Loss: 0.8387, Total Loss: 3.0154, Top-1 Acc: 0.4076, Top-5 Acc: 0.7806
Saved new best model (Top-1 Accuracy: 0.4076)


                                                                                                                               

[Epoch 11] Train Policy Loss: 2.8804, Value Loss: 0.8337


                                                                                                                                                                 

[Validation] Policy Loss: 2.9861, Value Loss: 0.8381, Total Loss: 2.9945, Top-1 Acc: 0.4116, Top-5 Acc: 0.7872
Saved new best model (Top-1 Accuracy: 0.4116)


                                                                                                                               

[Epoch 12] Train Policy Loss: 2.8537, Value Loss: 0.8327


                                                                                                                                                                 

[Validation] Policy Loss: 2.9763, Value Loss: 0.8338, Total Loss: 2.9846, Top-1 Acc: 0.4133, Top-5 Acc: 0.7902
Saved new best model (Top-1 Accuracy: 0.4133)


                                                                                                                               

[Epoch 13] Train Policy Loss: 2.8296, Value Loss: 0.8322


                                                                                                                                                                 

[Validation] Policy Loss: 2.9624, Value Loss: 0.8318, Total Loss: 2.9707, Top-1 Acc: 0.4169, Top-5 Acc: 0.7933
Saved new best model (Top-1 Accuracy: 0.4169)


                                                                                                                               

[Epoch 14] Train Policy Loss: 2.8083, Value Loss: 0.8311


                                                                                                                                                                 

[Validation] Policy Loss: 2.9506, Value Loss: 0.8317, Total Loss: 2.9589, Top-1 Acc: 0.4172, Top-5 Acc: 0.7953
Saved new best model (Top-1 Accuracy: 0.4172)


                                                                                                                               

[Epoch 15] Train Policy Loss: 2.7886, Value Loss: 0.8305


                                                                                                                                                                 

[Validation] Policy Loss: 2.9393, Value Loss: 0.8296, Total Loss: 2.9476, Top-1 Acc: 0.4216, Top-5 Acc: 0.7986
Saved new best model (Top-1 Accuracy: 0.4216)


                                                                                                                               

[Epoch 16] Train Policy Loss: 2.7710, Value Loss: 0.8300


                                                                                                                                                                 

[Validation] Policy Loss: 2.9275, Value Loss: 0.8292, Total Loss: 2.9358, Top-1 Acc: 0.4231, Top-5 Acc: 0.8011
Saved new best model (Top-1 Accuracy: 0.4231)


                                                                                                                               

[Epoch 17] Train Policy Loss: 2.7546, Value Loss: 0.8294


                                                                                                                                                                 

[Validation] Policy Loss: 2.9239, Value Loss: 0.8290, Total Loss: 2.9322, Top-1 Acc: 0.4257, Top-5 Acc: 0.8043
Saved new best model (Top-1 Accuracy: 0.4257)


                                                                                                                               

[Epoch 18] Train Policy Loss: 2.7397, Value Loss: 0.8291


                                                                                                                                                                 

[Validation] Policy Loss: 2.9165, Value Loss: 0.8287, Total Loss: 2.9248, Top-1 Acc: 0.4277, Top-5 Acc: 0.8048
Saved new best model (Top-1 Accuracy: 0.4277)


                                                                                                                               

[Epoch 19] Train Policy Loss: 2.7255, Value Loss: 0.8287


                                                                                                                                                                 

[Validation] Policy Loss: 2.9232, Value Loss: 0.8296, Total Loss: 2.9315, Top-1 Acc: 0.4267, Top-5 Acc: 0.8051


                                                                                                                               

[Epoch 20] Train Policy Loss: 2.7122, Value Loss: 0.8282


                                                                                                                                                                 

[Validation] Policy Loss: 2.9120, Value Loss: 0.8286, Total Loss: 2.9203, Top-1 Acc: 0.4274, Top-5 Acc: 0.8063


                                                                                                                               

[Epoch 21] Train Policy Loss: 2.6999, Value Loss: 0.8278


                                                                                                                                                                 

[Validation] Policy Loss: 2.9119, Value Loss: 0.8280, Total Loss: 2.9202, Top-1 Acc: 0.4276, Top-5 Acc: 0.8067


                                                                                                                               

[Epoch 22] Train Policy Loss: 2.6883, Value Loss: 0.8275


                                                                                                                                                                 

[Validation] Policy Loss: 2.9069, Value Loss: 0.8272, Total Loss: 2.9151, Top-1 Acc: 0.4312, Top-5 Acc: 0.8086
Saved new best model (Top-1 Accuracy: 0.4312)


                                                                                                                               

[Epoch 23] Train Policy Loss: 2.6770, Value Loss: 0.8274


                                                                                                                                                                 

[Validation] Policy Loss: 2.9110, Value Loss: 0.8273, Total Loss: 2.9193, Top-1 Acc: 0.4312, Top-5 Acc: 0.8090
Saved new best model (Top-1 Accuracy: 0.4312)


                                                                                                                               

[Epoch 24] Train Policy Loss: 2.6669, Value Loss: 0.8272


                                                                                                                                                                 

[Validation] Policy Loss: 2.9083, Value Loss: 0.8291, Total Loss: 2.9166, Top-1 Acc: 0.4298, Top-5 Acc: 0.8091


                                                                                                                               

[Epoch 25] Train Policy Loss: 2.6569, Value Loss: 0.8270


                                                                                                                                                                 

[Validation] Policy Loss: 2.9052, Value Loss: 0.8281, Total Loss: 2.9135, Top-1 Acc: 0.4327, Top-5 Acc: 0.8090
Saved new best model (Top-1 Accuracy: 0.4327)


                                                                                                                               

[Epoch 26] Train Policy Loss: 2.6473, Value Loss: 0.8271


                                                                                                                                                                 

[Validation] Policy Loss: 2.9005, Value Loss: 0.8299, Total Loss: 2.9088, Top-1 Acc: 0.4343, Top-5 Acc: 0.8109
Saved new best model (Top-1 Accuracy: 0.4343)


                                                                                                                               

[Epoch 27] Train Policy Loss: 2.6384, Value Loss: 0.8269


                                                                                                                                                                 

[Validation] Policy Loss: 2.9053, Value Loss: 0.8289, Total Loss: 2.9136, Top-1 Acc: 0.4320, Top-5 Acc: 0.8108


                                                                                                                               

[Epoch 28] Train Policy Loss: 2.6297, Value Loss: 0.8267


                                                                                                                                                                 

[Validation] Policy Loss: 2.9038, Value Loss: 0.8283, Total Loss: 2.9121, Top-1 Acc: 0.4335, Top-5 Acc: 0.8115


                                                                                                                               

[Epoch 29] Train Policy Loss: 2.6216, Value Loss: 0.8266


                                                                                                                                                                 

[Validation] Policy Loss: 2.9111, Value Loss: 0.8273, Total Loss: 2.9193, Top-1 Acc: 0.4318, Top-5 Acc: 0.8099


                                                                                                                               

[Epoch 30] Train Policy Loss: 2.5476, Value Loss: 0.8245


                                                                                                                                                                 

[Validation] Policy Loss: 2.8844, Value Loss: 0.8265, Total Loss: 2.8926, Top-1 Acc: 0.4378, Top-5 Acc: 0.8165
Saved new best model (Top-1 Accuracy: 0.4378)
