In [2]:
from datasets import load_dataset, load_from_disk, DatasetDict, Dataset
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
import os
import numpy as np
# import torch.nn.functional as F
from torch.utils.data import DataLoader
from preprocessing import fen_to_piece_maps
from tqdm import tqdm

torch.set_float32_matmul_precision('medium')

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE == torch.device("cpu"):
    print("Using CPU, not recommended")

In [22]:
def collate_fn(batch):
    batch_fens = [example['fen'] for example in batch]
    labels = torch.tensor(
        [example['target'] for example in batch],
        dtype=torch.float32
    )

    winning_labels = []
    inputs = []

    for example in batch:
        fen = example['fen']
        target = example['target']
        
        # Parse side to move
        stm = fen.split()[1]  # 'w' or 'b'

        # Compute winning label
        if (stm == 'w' and target > 0) or (stm == 'b' and target < 0):
            winning = 1
        else:
            winning = 0

        winning_labels.append(winning)

        # Process input
        inputs.append(torch.tensor(fen_to_piece_maps(fen), dtype=torch.float32))

    inputs = torch.stack(inputs)
    winning_labels = torch.tensor(winning_labels, dtype=torch.float32)  # Make sure it's float32 for BCEWithLogitsLoss etc.

    return inputs, labels, winning_labels

In [23]:
train_dataset = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_10m/train"))
val_dataset = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_10m/validation"))
test_dataset = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_10m/test"))

num_training_examples = len(train_dataset)

train_dataset = train_dataset.to_iterable_dataset(num_shards=32)
val_dataset = val_dataset.to_iterable_dataset()
test_dataset = test_dataset.to_iterable_dataset()

train_dataset = train_dataset.shuffle(buffer_size=10000)
val_dataset = val_dataset.shuffle(buffer_size=10000)
test_dataset = test_dataset.shuffle(buffer_size=10000)

In [None]:
# dataset = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_10m"))

# # Split the dataset into train, validation, and test sets
# train_size = int(0.98 * len(dataset))
# val_size = int(0.01 * len(dataset))
# test_size = len(dataset) - train_size - val_size

# train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
#     dataset, [train_size, val_size, test_size]
# )

# # Convert to iterable datasets
# train_dataset = train_dataset.dataset.to_iterable_dataset(num_shards=32)
# val_dataset = val_dataset.dataset.to_iterable_dataset()
# test_dataset = test_dataset.dataset.to_iterable_dataset()

# # Shuffle the datasets
# train_dataset = train_dataset.shuffle(buffer_size=10000)
# val_dataset = val_dataset.shuffle(buffer_size=10000)
# test_dataset = test_dataset.shuffle(buffer_size=10000)

In [24]:
train_loader = DataLoader(train_dataset, batch_size=256, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=256, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=256, collate_fn=collate_fn)

In [25]:
import torch
import torch.nn as nn

class FFN(nn.Module):
    def __init__(self, input_size, hidden_size=256, dropout_prob=0.2):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.fc2 = nn.Linear(hidden_size, hidden_size // 2)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        return x

class ChessEvalMLP(nn.Module):
    def __init__(self, input_planes=17, hidden_size=256, dropout_prob=0.2):
        super().__init__()
        input_size = input_planes * 8 * 8  # because of piece maps (channels * 8 * 8)

        self.backbone = FFN(input_size, hidden_size=hidden_size, dropout_prob=dropout_prob)

        # Two heads
        self.eval_head = nn.Linear(hidden_size // 2, 1)  # For regression
        self.win_head = nn.Linear(hidden_size // 2, 1)   # For binary classification

    def forward(self, x):
        x = torch.flatten(x, 1)  # (batch_size, channels*8*8)
        features = self.backbone(x)
        eval_output = self.eval_head(features)     # Regression output (centipawns)
        win_output = self.win_head(features)       # Classification output (winning/not)

        return eval_output, win_output

In [26]:
input_tensor = torch.randn(32, 17, 8, 8)  # A batch of 32 chessboard positions

model = ChessEvalMLP()
output = model(input_tensor)

print(output[0].shape, output[1].shape)

torch.Size([32, 1]) torch.Size([32, 1])


In [27]:
NUM_EPOCHS = 3
total_iters = NUM_EPOCHS * ((10_000_000 // 256) + 1)

model = ChessEvalMLP().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=total_iters, eta_min=1e-4)
eval_criterion = nn.MSELoss()
win_criterion = nn.BCEWithLogitsLoss()

steps, train_losses, train_maes, val_losses, val_maes = [], [], [], [], []       # for tracking performance

In [28]:
best_val_loss = float('inf')
best_val_mae = float('inf')
num_iterations = 0
patience_counter = 0

PATIENCE = 15000
MIN_IMPROVEMENT = 5e-4
VAL_ITERS = 500
LOG_ITERS = 100

early_stop = False

steps = []
train_losses, train_maes, train_win_accs = [], [], []
val_losses, val_maes, val_win_accs = [], [], []

for epoch in range(1, NUM_EPOCHS + 1):
    if early_stop:
        break

    model.train()
    total_loss = 0.0
    total_mae = 0.0
    total_win_correct = 0
    total_samples = 0

    for inputs, labels, winning_labels in tqdm(train_loader, desc=f"Epoch {epoch}"):
        inputs, labels, winning_labels = inputs.to(DEVICE), labels.to(DEVICE), winning_labels.to(DEVICE)

        optimizer.zero_grad()
        eval_pred, win_pred = model(inputs)

        loss_eval = eval_criterion(eval_pred.squeeze(), labels)
        loss_win = win_criterion(win_pred.squeeze(), winning_labels)

        loss = loss_eval + loss_win

        loss.backward()
        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        batch_mae = torch.mean(torch.abs(eval_pred.squeeze() - labels)).item()

        # Calculate batch binary accuracy
        win_pred_labels = (torch.sigmoid(win_pred.squeeze()) > 0.5).float()
        batch_win_acc = (win_pred_labels == winning_labels).float().mean().item()

        total_loss += batch_loss * inputs.size(0)
        total_mae += batch_mae * inputs.size(0)
        total_win_correct += (win_pred_labels == winning_labels).sum().item()
        total_samples += inputs.size(0)
        num_iterations += 1

        if num_iterations % LOG_ITERS == 0:
            avg_train_loss = total_loss / total_samples
            avg_train_mae = total_mae / total_samples
            avg_train_win_acc = total_win_correct / total_samples

            print(f"Step {num_iterations} — train_loss: {avg_train_loss:.4f}, train_mae: {avg_train_mae:.4f}, train_win_acc: {avg_train_win_acc:.4f}")

    # Validation every VAL_ITERS iterations
        if num_iterations % VAL_ITERS == 0:
            model.eval()
            val_loss_sum = 0.0
            val_mae_sum = 0.0
            val_win_correct = 0
            val_total_samples = 0

            with torch.no_grad():
                for val_inputs, val_labels, val_winning_labels in val_loader:
                    val_inputs, val_labels, val_winning_labels = val_inputs.to(DEVICE), val_labels.to(DEVICE), val_winning_labels.to(DEVICE)

                    val_pred, val_win_pred = model(val_inputs)

                    val_loss = eval_criterion(val_pred.squeeze(), val_labels)
                    val_mae = torch.mean(torch.abs(val_pred.squeeze() - val_labels))

                    val_loss_sum += val_loss.item() * val_inputs.size(0)
                    val_mae_sum += val_mae.item() * val_inputs.size(0)

                    val_win_preds = (torch.sigmoid(val_win_pred.squeeze()) > 0.5).float()
                    val_win_correct += (val_win_preds == val_winning_labels).sum().item()

                    val_total_samples += val_inputs.size(0)

            avg_val_loss = val_loss_sum / val_total_samples
            avg_val_mae = val_mae_sum / val_total_samples
            avg_val_win_acc = val_win_correct / val_total_samples

            avg_train_loss = total_loss / total_samples
            avg_train_mae = total_mae / total_samples
            avg_train_win_acc = total_win_correct / total_samples

            print(f"\n[Validation] Step {num_iterations} — train_loss: {avg_train_loss:.4f}, train_mae: {avg_train_mae:.4f}, train_win_acc: {avg_train_win_acc:.4f}, val_loss: {avg_val_loss:.4f}, val_mae: {avg_val_mae:.4f}, val_win_acc: {avg_val_win_acc:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

            steps.append(num_iterations)
            train_losses.append(avg_train_loss)
            train_maes.append(avg_train_mae)
            train_win_accs.append(avg_train_win_acc)

            val_losses.append(avg_val_loss)
            val_maes.append(avg_val_mae)
            val_win_accs.append(avg_val_win_acc)

            # Checkpoint best
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), "best_chess_mlp.pth")
                print(f"Saved new best model after {num_iterations} iters")
            
            # Early stopping on validation MAE
            if avg_val_mae + MIN_IMPROVEMENT < best_val_mae:
                best_val_mae = avg_val_mae
                patience_counter = 0
                print(f"--------Validation MAE improved to {best_val_mae:.6f}--------")
            else:
                patience_counter += VAL_ITERS
                print(f"--------No significant MAE improvement for {patience_counter} iterations--------")

            if patience_counter >= PATIENCE:
                print(f"Early stopping triggered at {num_iterations} iterations (no significant MAE improvement)")
                early_stop = True
                break

            model.train()


Epoch 1: 102it [00:05, 20.42it/s]

Step 100 — train_loss: 0.8589, train_mae: 0.3063, train_win_acc: 0.5989


Epoch 1: 202it [00:10, 20.17it/s]

Step 200 — train_loss: 0.8243, train_mae: 0.2928, train_win_acc: 0.6230


Epoch 1: 301it [00:14, 20.57it/s]

Step 300 — train_loss: 0.8004, train_mae: 0.2834, train_win_acc: 0.6390


Epoch 1: 403it [00:19, 20.65it/s]

Step 400 — train_loss: 0.7808, train_mae: 0.2759, train_win_acc: 0.6511


Epoch 1: 497it [00:24, 20.64it/s]

Step 500 — train_loss: 0.7681, train_mae: 0.2709, train_win_acc: 0.6595


Epoch 1: 503it [01:09,  3.18s/it]


[Validation] Step 500 — train_loss: 0.7681, train_mae: 0.2709, train_win_acc: 0.6595, val_loss: 0.1269, val_mae: 0.2389, val_win_acc: 0.7005, LR: 0.001000
Saved new best model after 500 iters
--------Validation MAE improved to 0.238857--------


Epoch 1: 602it [01:14, 20.26it/s]

Step 600 — train_loss: 0.7571, train_mae: 0.2670, train_win_acc: 0.6652


Epoch 1: 703it [01:19, 19.73it/s]

Step 700 — train_loss: 0.7486, train_mae: 0.2636, train_win_acc: 0.6697


Epoch 1: 803it [01:24, 19.56it/s]

Step 800 — train_loss: 0.7410, train_mae: 0.2605, train_win_acc: 0.6734


Epoch 1: 903it [01:29, 19.70it/s]

Step 900 — train_loss: 0.7339, train_mae: 0.2578, train_win_acc: 0.6766


Epoch 1: 998it [01:34, 20.07it/s]

Step 1000 — train_loss: 0.7288, train_mae: 0.2560, train_win_acc: 0.6794


Epoch 1: 1002it [02:18,  4.21s/it]


[Validation] Step 1000 — train_loss: 0.7288, train_mae: 0.2560, train_win_acc: 0.6794, val_loss: 0.1148, val_mae: 0.2264, val_win_acc: 0.7100, LR: 0.001000
Saved new best model after 1000 iters
--------Validation MAE improved to 0.226386--------


Epoch 1: 1104it [02:23, 20.37it/s]

Step 1100 — train_loss: 0.7234, train_mae: 0.2538, train_win_acc: 0.6821


Epoch 1: 1203it [02:28, 20.54it/s]

Step 1200 — train_loss: 0.7187, train_mae: 0.2521, train_win_acc: 0.6846


Epoch 1: 1302it [02:33, 19.98it/s]

Step 1300 — train_loss: 0.7140, train_mae: 0.2502, train_win_acc: 0.6872


Epoch 1: 1404it [02:38, 19.74it/s]

Step 1400 — train_loss: 0.7099, train_mae: 0.2487, train_win_acc: 0.6893


Epoch 1: 1499it [02:43, 19.83it/s]

Step 1500 — train_loss: 0.7066, train_mae: 0.2472, train_win_acc: 0.6909


Epoch 1: 1504it [03:30,  3.45s/it]


[Validation] Step 1500 — train_loss: 0.7066, train_mae: 0.2472, train_win_acc: 0.6909, val_loss: 0.1097, val_mae: 0.2190, val_win_acc: 0.7219, LR: 0.001000
Saved new best model after 1500 iters
--------Validation MAE improved to 0.218962--------


Epoch 1: 1603it [03:35, 20.07it/s]

Step 1600 — train_loss: 0.7038, train_mae: 0.2461, train_win_acc: 0.6924


Epoch 1: 1703it [03:40, 19.97it/s]

Step 1700 — train_loss: 0.7011, train_mae: 0.2449, train_win_acc: 0.6937


Epoch 1: 1803it [03:45, 18.86it/s]

Step 1800 — train_loss: 0.6982, train_mae: 0.2438, train_win_acc: 0.6952


Epoch 1: 1902it [03:50, 19.88it/s]

Step 1900 — train_loss: 0.6958, train_mae: 0.2426, train_win_acc: 0.6962


Epoch 1: 1998it [03:55, 19.72it/s]

Step 2000 — train_loss: 0.6935, train_mae: 0.2416, train_win_acc: 0.6973


Epoch 1: 2002it [04:41,  4.24s/it]


[Validation] Step 2000 — train_loss: 0.6935, train_mae: 0.2416, train_win_acc: 0.6973, val_loss: 0.1065, val_mae: 0.2142, val_win_acc: 0.7237, LR: 0.000999
Saved new best model after 2000 iters
--------Validation MAE improved to 0.214194--------


Epoch 1: 2104it [04:46, 19.97it/s]

Step 2100 — train_loss: 0.6915, train_mae: 0.2408, train_win_acc: 0.6982


Epoch 1: 2203it [04:51, 19.69it/s]

Step 2200 — train_loss: 0.6897, train_mae: 0.2400, train_win_acc: 0.6990


Epoch 1: 2303it [04:56, 20.04it/s]

Step 2300 — train_loss: 0.6877, train_mae: 0.2391, train_win_acc: 0.7001


Epoch 1: 2403it [05:02, 19.83it/s]

Step 2400 — train_loss: 0.6859, train_mae: 0.2383, train_win_acc: 0.7008


Epoch 1: 2499it [05:06, 19.72it/s]

Step 2500 — train_loss: 0.6842, train_mae: 0.2376, train_win_acc: 0.7015


Epoch 1: 2502it [05:53,  4.90s/it]


[Validation] Step 2500 — train_loss: 0.6842, train_mae: 0.2376, train_win_acc: 0.7015, val_loss: 0.1037, val_mae: 0.2115, val_win_acc: 0.7323, LR: 0.000999
Saved new best model after 2500 iters
--------Validation MAE improved to 0.211451--------


Epoch 1: 2603it [05:58, 20.65it/s]

Step 2600 — train_loss: 0.6825, train_mae: 0.2369, train_win_acc: 0.7025


Epoch 1: 2703it [06:03, 20.04it/s]

Step 2700 — train_loss: 0.6808, train_mae: 0.2362, train_win_acc: 0.7034


Epoch 1: 2804it [06:08, 20.03it/s]

Step 2800 — train_loss: 0.6792, train_mae: 0.2355, train_win_acc: 0.7041


Epoch 1: 2902it [06:13, 19.81it/s]

Step 2900 — train_loss: 0.6776, train_mae: 0.2349, train_win_acc: 0.7050


Epoch 1: 2997it [06:18, 20.05it/s]

Step 3000 — train_loss: 0.6762, train_mae: 0.2343, train_win_acc: 0.7056


Epoch 1: 3004it [07:05,  2.82s/it]


[Validation] Step 3000 — train_loss: 0.6762, train_mae: 0.2343, train_win_acc: 0.7056, val_loss: 0.1031, val_mae: 0.2087, val_win_acc: 0.7346, LR: 0.000999
Saved new best model after 3000 iters
--------Validation MAE improved to 0.208669--------


Epoch 1: 3103it [07:10, 19.68it/s]

Step 3100 — train_loss: 0.6751, train_mae: 0.2339, train_win_acc: 0.7063


Epoch 1: 3203it [07:15, 20.06it/s]

Step 3200 — train_loss: 0.6740, train_mae: 0.2334, train_win_acc: 0.7068


Epoch 1: 3303it [07:20, 19.28it/s]

Step 3300 — train_loss: 0.6729, train_mae: 0.2330, train_win_acc: 0.7074


Epoch 1: 3402it [07:25, 20.75it/s]

Step 3400 — train_loss: 0.6716, train_mae: 0.2324, train_win_acc: 0.7080


Epoch 1: 3497it [07:30, 21.35it/s]

Step 3500 — train_loss: 0.6707, train_mae: 0.2321, train_win_acc: 0.7085


Epoch 1: 3502it [08:15,  3.57s/it]


[Validation] Step 3500 — train_loss: 0.6707, train_mae: 0.2321, train_win_acc: 0.7085, val_loss: 0.1023, val_mae: 0.2098, val_win_acc: 0.7281, LR: 0.000998
Saved new best model after 3500 iters
--------No significant MAE improvement for 500 iterations--------


Epoch 1: 3602it [08:20, 20.05it/s]

Step 3600 — train_loss: 0.6697, train_mae: 0.2317, train_win_acc: 0.7091


Epoch 1: 3704it [08:25, 19.58it/s]

Step 3700 — train_loss: 0.6688, train_mae: 0.2313, train_win_acc: 0.7096


Epoch 1: 3802it [08:30, 20.22it/s]

Step 3800 — train_loss: 0.6676, train_mae: 0.2310, train_win_acc: 0.7102


Epoch 1: 3903it [08:35, 20.00it/s]

Step 3900 — train_loss: 0.6668, train_mae: 0.2306, train_win_acc: 0.7106


Epoch 1: 3998it [08:40, 19.85it/s]

Step 4000 — train_loss: 0.6659, train_mae: 0.2302, train_win_acc: 0.7110


Epoch 1: 4003it [09:25,  3.62s/it]


[Validation] Step 4000 — train_loss: 0.6659, train_mae: 0.2302, train_win_acc: 0.7110, val_loss: 0.0997, val_mae: 0.2059, val_win_acc: 0.7353, LR: 0.000997
Saved new best model after 4000 iters
--------Validation MAE improved to 0.205881--------


Epoch 1: 4103it [09:30, 19.77it/s]

Step 4100 — train_loss: 0.6649, train_mae: 0.2298, train_win_acc: 0.7115


Epoch 1: 4202it [09:36, 20.01it/s]

Step 4200 — train_loss: 0.6641, train_mae: 0.2294, train_win_acc: 0.7118


Epoch 1: 4304it [09:41, 19.93it/s]

Step 4300 — train_loss: 0.6633, train_mae: 0.2291, train_win_acc: 0.7122


Epoch 1: 4402it [09:46, 19.55it/s]

Step 4400 — train_loss: 0.6624, train_mae: 0.2288, train_win_acc: 0.7128


Epoch 1: 4499it [09:51, 19.97it/s]

Step 4500 — train_loss: 0.6617, train_mae: 0.2284, train_win_acc: 0.7131


Epoch 1: 4502it [10:37,  4.29s/it]


[Validation] Step 4500 — train_loss: 0.6617, train_mae: 0.2284, train_win_acc: 0.7131, val_loss: 0.0992, val_mae: 0.2065, val_win_acc: 0.7395, LR: 0.000997
Saved new best model after 4500 iters
--------No significant MAE improvement for 500 iterations--------


Epoch 1: 4604it [10:42, 20.23it/s]

Step 4600 — train_loss: 0.6609, train_mae: 0.2282, train_win_acc: 0.7135


Epoch 1: 4703it [10:47, 19.95it/s]

Step 4700 — train_loss: 0.6602, train_mae: 0.2278, train_win_acc: 0.7138


Epoch 1: 4802it [10:52, 20.28it/s]

Step 4800 — train_loss: 0.6594, train_mae: 0.2276, train_win_acc: 0.7142


Epoch 1: 4904it [10:57, 20.05it/s]

Step 4900 — train_loss: 0.6587, train_mae: 0.2273, train_win_acc: 0.7146


Epoch 1: 4997it [11:01, 20.28it/s]

Step 5000 — train_loss: 0.6579, train_mae: 0.2270, train_win_acc: 0.7151


Epoch 1: 5002it [11:46,  3.54s/it]


[Validation] Step 5000 — train_loss: 0.6579, train_mae: 0.2270, train_win_acc: 0.7151, val_loss: 0.0982, val_mae: 0.2036, val_win_acc: 0.7383, LR: 0.000996
Saved new best model after 5000 iters
--------Validation MAE improved to 0.203617--------


Epoch 1: 5104it [11:52, 20.12it/s]

Step 5100 — train_loss: 0.6571, train_mae: 0.2267, train_win_acc: 0.7154


Epoch 1: 5204it [11:57, 19.87it/s]

Step 5200 — train_loss: 0.6565, train_mae: 0.2264, train_win_acc: 0.7157


Epoch 1: 5303it [12:01, 20.13it/s]

Step 5300 — train_loss: 0.6559, train_mae: 0.2261, train_win_acc: 0.7160


Epoch 1: 5402it [12:06, 19.77it/s]

Step 5400 — train_loss: 0.6553, train_mae: 0.2259, train_win_acc: 0.7163


Epoch 1: 5497it [12:11, 20.23it/s]

Step 5500 — train_loss: 0.6549, train_mae: 0.2257, train_win_acc: 0.7166


Epoch 1: 5503it [12:56,  3.18s/it]


[Validation] Step 5500 — train_loss: 0.6549, train_mae: 0.2257, train_win_acc: 0.7166, val_loss: 0.0983, val_mae: 0.2039, val_win_acc: 0.7411, LR: 0.000995
--------No significant MAE improvement for 500 iterations--------


Epoch 1: 5603it [13:01, 20.75it/s]

Step 5600 — train_loss: 0.6543, train_mae: 0.2254, train_win_acc: 0.7169


Epoch 1: 5704it [13:06, 20.19it/s]

Step 5700 — train_loss: 0.6535, train_mae: 0.2251, train_win_acc: 0.7172


Epoch 1: 5803it [13:11, 20.54it/s]

Step 5800 — train_loss: 0.6530, train_mae: 0.2248, train_win_acc: 0.7175


Epoch 1: 5903it [13:16, 19.96it/s]

Step 5900 — train_loss: 0.6526, train_mae: 0.2246, train_win_acc: 0.7177


Epoch 1: 5997it [13:21, 20.14it/s]

Step 6000 — train_loss: 0.6520, train_mae: 0.2244, train_win_acc: 0.7180


Epoch 1: 6003it [14:06,  3.22s/it]


[Validation] Step 6000 — train_loss: 0.6520, train_mae: 0.2244, train_win_acc: 0.7180, val_loss: 0.0966, val_mae: 0.2032, val_win_acc: 0.7452, LR: 0.000994
Saved new best model after 6000 iters
--------No significant MAE improvement for 1000 iterations--------


Epoch 1: 6103it [14:11, 20.33it/s]

Step 6100 — train_loss: 0.6513, train_mae: 0.2242, train_win_acc: 0.7183


Epoch 1: 6204it [14:16, 21.29it/s]

Step 6200 — train_loss: 0.6508, train_mae: 0.2239, train_win_acc: 0.7186


Epoch 1: 6303it [14:21, 19.59it/s]

Step 6300 — train_loss: 0.6503, train_mae: 0.2237, train_win_acc: 0.7188


Epoch 1: 6402it [14:26, 19.91it/s]

Step 6400 — train_loss: 0.6497, train_mae: 0.2235, train_win_acc: 0.7191


Epoch 1: 6498it [14:31, 20.40it/s]

Step 6500 — train_loss: 0.6491, train_mae: 0.2232, train_win_acc: 0.7194


Epoch 1: 6502it [15:16,  3.89s/it]


[Validation] Step 6500 — train_loss: 0.6491, train_mae: 0.2232, train_win_acc: 0.7194, val_loss: 0.0983, val_mae: 0.2069, val_win_acc: 0.7417, LR: 0.000993
--------No significant MAE improvement for 1500 iterations--------


Epoch 1: 6603it [15:21, 20.38it/s]

Step 6600 — train_loss: 0.6487, train_mae: 0.2231, train_win_acc: 0.7196


Epoch 1: 6702it [15:25, 20.18it/s]

Step 6700 — train_loss: 0.6483, train_mae: 0.2229, train_win_acc: 0.7198


Epoch 1: 6802it [15:30, 20.16it/s]

Step 6800 — train_loss: 0.6477, train_mae: 0.2227, train_win_acc: 0.7201


Epoch 1: 6903it [15:35, 20.10it/s]

Step 6900 — train_loss: 0.6474, train_mae: 0.2225, train_win_acc: 0.7203


Epoch 1: 6997it [15:40, 20.11it/s]

Step 7000 — train_loss: 0.6470, train_mae: 0.2224, train_win_acc: 0.7205


Epoch 1: 7003it [16:25,  3.22s/it]


[Validation] Step 7000 — train_loss: 0.6470, train_mae: 0.2224, train_win_acc: 0.7205, val_loss: 0.0965, val_mae: 0.2027, val_win_acc: 0.7463, LR: 0.000992
Saved new best model after 7000 iters
--------Validation MAE improved to 0.202660--------


Epoch 1: 7103it [16:30, 19.38it/s]

Step 7100 — train_loss: 0.6465, train_mae: 0.2222, train_win_acc: 0.7208


Epoch 1: 7203it [16:35, 19.74it/s]

Step 7200 — train_loss: 0.6461, train_mae: 0.2221, train_win_acc: 0.7210


Epoch 1: 7302it [16:40, 20.48it/s]

Step 7300 — train_loss: 0.6457, train_mae: 0.2219, train_win_acc: 0.7212


Epoch 1: 7403it [16:45, 20.16it/s]

Step 7400 — train_loss: 0.6453, train_mae: 0.2217, train_win_acc: 0.7214


Epoch 1: 7497it [16:50, 20.04it/s]

Step 7500 — train_loss: 0.6449, train_mae: 0.2215, train_win_acc: 0.7216


Epoch 1: 7503it [17:35,  3.22s/it]


[Validation] Step 7500 — train_loss: 0.6449, train_mae: 0.2215, train_win_acc: 0.7216, val_loss: 0.1022, val_mae: 0.2096, val_win_acc: 0.7401, LR: 0.000991
--------No significant MAE improvement for 500 iterations--------


Epoch 1: 7603it [17:40, 20.51it/s]

Step 7600 — train_loss: 0.6444, train_mae: 0.2214, train_win_acc: 0.7219


Epoch 1: 7702it [17:44, 20.84it/s]

Step 7700 — train_loss: 0.6441, train_mae: 0.2212, train_win_acc: 0.7222


Epoch 1: 7802it [17:49, 19.53it/s]

Step 7800 — train_loss: 0.6437, train_mae: 0.2211, train_win_acc: 0.7224


Epoch 1: 7904it [17:55, 20.15it/s]

Step 7900 — train_loss: 0.6433, train_mae: 0.2209, train_win_acc: 0.7225


Epoch 1: 7998it [17:59, 19.08it/s]

Step 8000 — train_loss: 0.6429, train_mae: 0.2207, train_win_acc: 0.7227


Epoch 1: 8002it [18:44,  4.19s/it]


[Validation] Step 8000 — train_loss: 0.6429, train_mae: 0.2207, train_win_acc: 0.7227, val_loss: 0.0955, val_mae: 0.1994, val_win_acc: 0.7482, LR: 0.000990
Saved new best model after 8000 iters
--------Validation MAE improved to 0.199361--------


Epoch 1: 8103it [18:49, 20.11it/s]

Step 8100 — train_loss: 0.6425, train_mae: 0.2205, train_win_acc: 0.7229


Epoch 1: 8193it [18:54,  7.22it/s]


KeyboardInterrupt: 