## Plan for continuous training

Here's the plan. The script boots up and checks a continuous play directory which is populated with a number of training round sub-directories. Each sub-directory is labelled in sequence and contains a self_play sub-sub-directory. The script takes stock of the contents of the latest training round's self_play directory.

If the number of files (tournaments) m saved in the self_play directory of the latest training round is less than some minimum M (i.e. m < M), then the self-play script is kicked off to play M - m more tournaments, completing the M required self-play tournaments and saving them to that self_play sub-directory.

If the number of files m saved in the self_play directory of the latest training round satisfies m >= M, then the script checks whether there is already a model saved in the training round directory.

If there is not a model saved in the training round directory, the script will kick off a training routine to create a new model (or load a saved model from a possibly existing previous round), compile a training dataset from the most recent k training rounds (augmented with the catalogue of known checkmate positions) and train the model to stopping. The script will then save the model in the training round sub-directory.

Then the script will create a new training round directory, labelled with the next integer in the training round sequence, containing a self_play directory. At this point we can essentially continue the loop from the top.

In [1]:
import os, torch
from torch.utils.data import DataLoader
from chess_selfplay import harvest_checkmates
from chess_model import TransformerModel, ChessDataset, TanhLoss, train
# Pick the compute device once, up front: the GPU when torch reports CUDA as usable,
# otherwise the CPU. The chosen name is echoed so it shows up in the cell output.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Minimum number of saved self-play tournaments a round needs before a model is trained on it.
min_tournaments_each_round = 50 # A new model roughly once a day.

# Directory holding one sub-directory per training round.
root_dir = os.path.join('data', 'output')


def dir_sort_key(d):
    """Return the integer that follows the last underscore in directory name *d*.

    Used as the sort key for training round directories so that, for example,
    'round_10' sorts after 'round_9'. Raises ValueError when the trailing piece
    of the name is not an integer.
    """
    suffix = d.rpartition('_')[2]
    return int(suffix)

## LOOP STARTS - just kill the machine any old time when you have stuff to do and you can fire it up again whenever you're ready.
while True:

    training_round_dirs = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir,d))], key=dir_sort_key)
    current_training_round_dir = training_round_dirs[-1]
    previous_training_round_dir = training_round_dirs[-2]

    current_self_play_path = os.path.join(root_dir, current_training_round_dir, 'self_play')
    if os.path.exists(current_self_play_path) and os.path.isdir(current_self_play_path):
        number_games_played_already = len(os.listdir(current_self_play_path))
    else:
        os.mkdir(current_self_play_path)
        number_games_played_already = 0

    tournaments_left_to_play = min_tournaments_each_round - number_games_played_already
    print(f'Current: {current_training_round_dir}, already played: {number_games_played_already}, left to play: {tournaments_left_to_play}')

    if tournaments_left_to_play > 0:
        
        # self-play script to play tournaments_left_to_play more tournaments, saving to current_self_play_dir.
        model_source_path = os.path.join(root_dir, previous_training_round_dir, 'model.pt')
        assert os.path.exists(model_source_path) and os.path.isfile(model_source_path), 'ERROR: MODEL NOT FOUND.'

        # Same base model with different look-ahead strength configuration 
        model_kwargs = {'nlayers':6, 'nheads':3, 'embed_dim':18, 'dk':5, 'device':device, 'load_path':model_source_path}
        agent0_spec = {'type': 'transformer', 'kwargs': model_kwargs, 'num_simgames': 200, 'max_simmoves': 5, 'C': 1, 'p': 0.4, 'k': float('inf')}
        agent1_spec = {'type': 'transformer', 'kwargs': model_kwargs, 'num_simgames':  1, 'max_simmoves': 1, 'C': 1, 'p': 0.4, 'k': float('inf')}
        self_play_args = {
            'num_workers':2, 'num_tournaments': tournaments_left_to_play, 'agents_spec': [agent0_spec, agent1_spec], 
            'num_games':1, 'starting_state':None, 'max_moves':200, 'save':True, 'result_dest':current_self_play_path
        }

        # Let's play
        print(f'Playing {tournaments_left_to_play} tournaments...')
        %run -i "chess_selfplay.py"
        print(f'Self-play complete.')

    # Extract the checkmates from the current_self_play_dir tournament games and save them in the current_training_round_dir.
    current_round_checkmates_path = os.path.join(root_dir, current_training_round_dir, 'checkmates.pkl')
    if not (os.path.exists(current_round_checkmates_path) and os.path.isfile(current_round_checkmates_path)):
        _ = harvest_checkmates(os.path.join(root_dir, current_training_round_dir))

    latest_model_path = os.path.join(root_dir, current_training_round_dir, 'model.pt')
    if not (os.path.exists(latest_model_path) and os.path.isfile(latest_model_path)):

        # No model saved here yet. Create and train a new model based on the previous k rounds of self-play data.
        # We currently have 591 tournaments saved in baseline and 191 in round1. We could use k = 10 and go from there?
        model_kwargs = {'nlayers':6, 'nheads':3, 'embed_dim':18, 'dk':5, 'device':device, 'load_path':None} # Brand new model
        model = TransformerModel(**model_kwargs)
        optimizer = torch.optim.Adam(model.parameters(), lr=0, weight_decay=0)
        loss_fn = TanhLoss()
        dataset = ChessDataset(root_dir=root_dir, look_back=10, device=device)
        train_set, test_set = torch.utils.data.random_split(dataset, [int(len(dataset)*0.8), len(dataset) - int(len(dataset)*0.8)])
        train_loader = DataLoader(train_set, batch_size=1000, shuffle=True, num_workers=0)
        test_loader = DataLoader(test_set, batch_size=1000, shuffle=True, num_workers=0)
        print(f'Training on {len(train_set):,.0f} examples in {len(train_loader):,.0f} batches.')

        # Train on the data
        model_dest = os.path.join(root_dir, current_training_round_dir, 'model.pt')
        model = train(model, loss_fn, optimizer, train_loader, test_loader, warmup_passes=4, max_lr=1e-4, save_dir=model_dest, stopping=5)
        torch.cuda.empty_cache()
        
    # Create next training round directory containing self_play sub-directory, and start loop from the top.
    next_index = int(os.path.split(current_training_round_dir)[-1].split('_')[-1]) + 1
    next_training_round_dir = f'round_{next_index}'
    print(f'Creating next training round directory {next_training_round_dir}')
    next_training_round_path = os.path.join(root_dir, next_training_round_dir)
    os.mkdir(next_training_round_path)
    next_training_round_self_play_path = os.path.join(next_training_round_path, 'self_play')
    os.mkdir(next_training_round_self_play_path)

cuda
Current: round_9, already played: 80, left to play: -30
Training on 238383 examples in 239 batches.


100%|██████████| 239/239 [00:51<00:00,  4.67it/s]
100%|██████████| 60/60 [00:08<00:00,  7.10it/s]


Pass: 1, train loss: 3.44938, test loss: 3.44535, stopping count: 0


100%|██████████| 239/239 [01:02<00:00,  3.85it/s]
100%|██████████| 60/60 [00:08<00:00,  7.08it/s]


Pass: 2, train loss: 3.44867, test loss: 3.44192, stopping count: 0


100%|██████████| 239/239 [01:00<00:00,  3.95it/s]
100%|██████████| 60/60 [00:07<00:00,  7.57it/s]


Pass: 3, train loss: 3.44434, test loss: 3.43911, stopping count: 0


100%|██████████| 239/239 [00:56<00:00,  4.24it/s]
100%|██████████| 60/60 [00:07<00:00,  7.85it/s]


Pass: 4, train loss: 3.44206, test loss: 3.43694, stopping count: 0


100%|██████████| 239/239 [00:55<00:00,  4.33it/s]
100%|██████████| 60/60 [00:07<00:00,  8.03it/s]


Pass: 5, train loss: 3.44039, test loss: 3.43618, stopping count: 0


100%|██████████| 239/239 [00:53<00:00,  4.48it/s]
100%|██████████| 60/60 [00:07<00:00,  7.99it/s]


Pass: 6, train loss: 3.43943, test loss: 3.43546, stopping count: 0


100%|██████████| 239/239 [00:52<00:00,  4.59it/s]
100%|██████████| 60/60 [00:07<00:00,  8.03it/s]


Pass: 7, train loss: 3.43880, test loss: 3.43411, stopping count: 0


100%|██████████| 239/239 [00:52<00:00,  4.53it/s]
100%|██████████| 60/60 [00:07<00:00,  8.05it/s]


Pass: 8, train loss: 3.43860, test loss: 3.43433, stopping count: 1


100%|██████████| 239/239 [00:54<00:00,  4.42it/s]
100%|██████████| 60/60 [00:07<00:00,  8.08it/s]


Pass: 9, train loss: 3.43819, test loss: 3.43514, stopping count: 2


100%|██████████| 239/239 [00:52<00:00,  4.59it/s]
100%|██████████| 60/60 [00:07<00:00,  7.71it/s]


Pass: 10, train loss: 3.43810, test loss: 3.43422, stopping count: 3


100%|██████████| 239/239 [00:51<00:00,  4.60it/s]
100%|██████████| 60/60 [00:07<00:00,  8.00it/s]


Pass: 11, train loss: 3.43790, test loss: 3.43330, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.66it/s]
100%|██████████| 60/60 [00:07<00:00,  7.99it/s]


Pass: 12, train loss: 3.43764, test loss: 3.43419, stopping count: 1


100%|██████████| 239/239 [00:51<00:00,  4.67it/s]
100%|██████████| 60/60 [00:07<00:00,  8.04it/s]


Pass: 13, train loss: 3.43740, test loss: 3.43293, stopping count: 0


100%|██████████| 239/239 [00:52<00:00,  4.58it/s]
100%|██████████| 60/60 [00:07<00:00,  8.15it/s]


Pass: 14, train loss: 3.43752, test loss: 3.43284, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.63it/s]
100%|██████████| 60/60 [00:07<00:00,  8.01it/s]


Pass: 15, train loss: 3.43747, test loss: 3.43260, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.63it/s]
100%|██████████| 60/60 [00:07<00:00,  8.16it/s]


Pass: 16, train loss: 3.43706, test loss: 3.43255, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.63it/s]
100%|██████████| 60/60 [00:07<00:00,  8.29it/s]


Pass: 17, train loss: 3.43686, test loss: 3.43244, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.66it/s]
100%|██████████| 60/60 [00:07<00:00,  7.90it/s]


Pass: 18, train loss: 3.43687, test loss: 3.43244, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.60it/s]
100%|██████████| 60/60 [00:07<00:00,  7.97it/s]


Pass: 19, train loss: 3.43681, test loss: 3.43245, stopping count: 1


100%|██████████| 239/239 [00:51<00:00,  4.66it/s]
100%|██████████| 60/60 [00:07<00:00,  8.06it/s]


Pass: 20, train loss: 3.43665, test loss: 3.43233, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.68it/s]
100%|██████████| 60/60 [00:07<00:00,  8.20it/s]


Pass: 21, train loss: 3.43667, test loss: 3.43223, stopping count: 0


100%|██████████| 239/239 [00:50<00:00,  4.72it/s]
100%|██████████| 60/60 [00:07<00:00,  8.23it/s]


Pass: 22, train loss: 3.43652, test loss: 3.43247, stopping count: 1


100%|██████████| 239/239 [00:51<00:00,  4.62it/s]
100%|██████████| 60/60 [00:07<00:00,  8.09it/s]


Pass: 23, train loss: 3.43630, test loss: 3.43260, stopping count: 2


100%|██████████| 239/239 [00:51<00:00,  4.65it/s]
100%|██████████| 60/60 [00:07<00:00,  8.07it/s]


Pass: 24, train loss: 3.43664, test loss: 3.43314, stopping count: 3


100%|██████████| 239/239 [00:51<00:00,  4.66it/s]
100%|██████████| 60/60 [00:07<00:00,  7.85it/s]


Pass: 25, train loss: 3.43626, test loss: 3.43190, stopping count: 0


100%|██████████| 239/239 [00:52<00:00,  4.60it/s]
100%|██████████| 60/60 [00:07<00:00,  7.54it/s]


Pass: 26, train loss: 3.43615, test loss: 3.43201, stopping count: 1


100%|██████████| 239/239 [00:51<00:00,  4.67it/s]
100%|██████████| 60/60 [00:07<00:00,  8.04it/s]


Pass: 27, train loss: 3.43637, test loss: 3.43173, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.63it/s]
100%|██████████| 60/60 [00:07<00:00,  8.12it/s]


Pass: 28, train loss: 3.43604, test loss: 3.43299, stopping count: 1


100%|██████████| 239/239 [00:51<00:00,  4.67it/s]
100%|██████████| 60/60 [00:07<00:00,  8.22it/s]


Pass: 29, train loss: 3.43628, test loss: 3.43190, stopping count: 2


100%|██████████| 239/239 [00:51<00:00,  4.66it/s]
100%|██████████| 60/60 [00:07<00:00,  8.13it/s]


Pass: 30, train loss: 3.43602, test loss: 3.43161, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.62it/s]
100%|██████████| 60/60 [00:07<00:00,  8.02it/s]


Pass: 31, train loss: 3.43617, test loss: 3.43151, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.63it/s]
100%|██████████| 60/60 [00:07<00:00,  8.09it/s]


Pass: 32, train loss: 3.43590, test loss: 3.43162, stopping count: 1


100%|██████████| 239/239 [00:51<00:00,  4.67it/s]
100%|██████████| 60/60 [00:07<00:00,  8.09it/s]


Pass: 33, train loss: 3.43595, test loss: 3.43146, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.64it/s]
100%|██████████| 60/60 [00:07<00:00,  8.21it/s]


Pass: 34, train loss: 3.43585, test loss: 3.43235, stopping count: 1


100%|██████████| 239/239 [00:52<00:00,  4.58it/s]
100%|██████████| 60/60 [00:07<00:00,  8.09it/s]


Pass: 35, train loss: 3.43584, test loss: 3.43152, stopping count: 2


100%|██████████| 239/239 [00:51<00:00,  4.63it/s]
100%|██████████| 60/60 [00:07<00:00,  8.33it/s]


Pass: 36, train loss: 3.43600, test loss: 3.43443, stopping count: 3


100%|██████████| 239/239 [00:51<00:00,  4.60it/s]
100%|██████████| 60/60 [00:07<00:00,  8.27it/s]


Pass: 37, train loss: 3.43573, test loss: 3.43207, stopping count: 4


100%|██████████| 239/239 [00:51<00:00,  4.60it/s]
100%|██████████| 60/60 [00:07<00:00,  8.23it/s]


Pass: 38, train loss: 3.43573, test loss: 3.43145, stopping count: 0


100%|██████████| 239/239 [00:52<00:00,  4.56it/s]
100%|██████████| 60/60 [00:07<00:00,  8.05it/s]


Pass: 39, train loss: 3.43575, test loss: 3.43135, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.68it/s]
100%|██████████| 60/60 [00:07<00:00,  8.17it/s]


Pass: 40, train loss: 3.43591, test loss: 3.43156, stopping count: 1


100%|██████████| 239/239 [00:51<00:00,  4.67it/s]
100%|██████████| 60/60 [00:07<00:00,  8.16it/s]


Pass: 41, train loss: 3.43564, test loss: 3.43218, stopping count: 2


100%|██████████| 239/239 [00:51<00:00,  4.60it/s]
100%|██████████| 60/60 [00:07<00:00,  7.87it/s]


Pass: 42, train loss: 3.43568, test loss: 3.43163, stopping count: 3


100%|██████████| 239/239 [00:51<00:00,  4.67it/s]
100%|██████████| 60/60 [00:07<00:00,  8.15it/s]


Pass: 43, train loss: 3.43585, test loss: 3.43278, stopping count: 4


100%|██████████| 239/239 [00:51<00:00,  4.62it/s]
100%|██████████| 60/60 [00:07<00:00,  7.97it/s]


Pass: 44, train loss: 3.43579, test loss: 3.43117, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.62it/s]
100%|██████████| 60/60 [00:07<00:00,  8.00it/s]


Pass: 45, train loss: 3.43564, test loss: 3.43148, stopping count: 1


100%|██████████| 239/239 [00:51<00:00,  4.66it/s]
100%|██████████| 60/60 [00:07<00:00,  8.06it/s]


Pass: 46, train loss: 3.43561, test loss: 3.43128, stopping count: 2


100%|██████████| 239/239 [00:52<00:00,  4.54it/s]
100%|██████████| 60/60 [00:07<00:00,  7.51it/s]


Pass: 47, train loss: 3.43558, test loss: 3.43110, stopping count: 0


100%|██████████| 239/239 [00:51<00:00,  4.62it/s]
100%|██████████| 60/60 [00:09<00:00,  6.23it/s]


Pass: 48, train loss: 3.43572, test loss: 3.43124, stopping count: 1


100%|██████████| 239/239 [00:50<00:00,  4.77it/s]
100%|██████████| 60/60 [00:08<00:00,  6.90it/s]


Pass: 49, train loss: 3.43546, test loss: 3.43290, stopping count: 2


100%|██████████| 239/239 [00:50<00:00,  4.70it/s]
100%|██████████| 60/60 [00:07<00:00,  7.69it/s]


Pass: 50, train loss: 3.43567, test loss: 3.43158, stopping count: 3


100%|██████████| 239/239 [00:53<00:00,  4.44it/s]
100%|██████████| 60/60 [00:07<00:00,  7.79it/s]


Pass: 51, train loss: 3.43535, test loss: 3.43137, stopping count: 4


100%|██████████| 239/239 [00:49<00:00,  4.84it/s]
100%|██████████| 60/60 [00:07<00:00,  8.05it/s]


Pass: 52, train loss: 3.43559, test loss: 3.43136, stopping count: 5
Creating next training round directory round_10
Current: round_10, already played: 0, left to play: 50
Playing 50 tournaments...
