# 4: Model Training

## Imports

In [1]:
from utils.model_utils import *
import torch.onnx

## Variable Declarations

In [None]:
# --- Optimisation settings -------------------------------------------------
LR = 0.1                      # learning rate given to the optimizer and to train()
loss_func = nn.BCELoss()      # binary cross-entropy; the model output ends in a sigmoid
EPOCHS = 30                   # epoch count passed to train()
CHANGE_LEARNING_RATE = True   # forwarded to train(change_learning_rate=...)
UPDATE_EPOCHS = [4, 10, 15]   # forwarded to train(update_epochs=...); presumably the
                              # epochs at which the learning rate is adjusted - TODO confirm

# --- Model architecture ----------------------------------------------------
NUM_HANGING_VALUES = 10       # size of the float-feature input (num_float_inputs)
BITBOARD_SHAPE = (76*2, 8, 8) # bit_board_shape given to ChessModel_V2
RESIDUAL_BLOCKS = 6
RESIDUAL_FILTERS = 64
SE_RATIO = 8

# --- File locations --------------------------------------------------------
# All paths use forward slashes so they resolve to the same directories on
# Windows and POSIX systems alike.
MODEL_FILENAME = "./Models/PikeBot_Models/PikeBot.pth"
ONNX_FILENAME = "./Models/PikeBot_Models/PikeBot.onnx"
LOG_FILE_LOCATION = "./Training_Logs/Training.txt"
CHECKPOINT_FILENAME_LOCATION = "./Models/PikeBot_Models/PikeBot_checkpoint.pth"
TRAIN_GENERATOR_PATH = "./Generators/train_generator.pkl"
VAL_GENERATOR_PATH = "./Generators/val_generator.pkl"
TEST_GENERATOR_PATH = "./Generators/test_generator.pkl"
# Previously written with "\\" separators, which on POSIX names a single file
# in the working directory instead of a file inside Models/PikeBot_Models.
TEMP_STATE_PATH = "./Models/PikeBot_Models/temp_state_dict.pth"

# --- Export ----------------------------------------------------------------
OPSET_VERSION = 11            # ONNX opset handed to save_model()

# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Detected Device: {device}")

## Model Sanity Check

In [2]:
# Build the network from the configuration constants, one argument per line.
model = ChessModel_V2(
    bit_board_shape=BITBOARD_SHAPE,
    num_float_inputs=NUM_HANGING_VALUES,
    residual_blocks=RESIDUAL_BLOCKS,
    residual_filters=RESIDUAL_FILTERS,
    se_ratio=SE_RATIO,
)

# Report the parameter count returned by count_parameters, scaled to millions.
num_params = count_parameters(model)
print(
    "Number of parameters in the model in millions:",
    round(num_params / 1e6, 4),
)

Number of parameters in the model in millions: 4.8321


In [3]:
# Print the module's text representation to inspect the layer layout.
print(model)

Chess_Model(
  (conv1): Conv2d(76, 304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(304, 304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(304, 304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(304, 76, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(76, 304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(304, 304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv7): Conv2d(304, 304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv8): Conv2d(304, 76, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (float_inputs_fc): Linear(in_features=4, out_features=512, bias=True)
  (fc1): Linear(in_features=588, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=64, bias=True)
  (output_layer): Linear(in_features=64, out_features=1, bias=True)
)


In [4]:
# Adam over every model parameter, starting from the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

# Dummy input shapes: a batch dimension of one in front of the configured sizes.
board_shape = (1, *BITBOARD_SHAPE)
floats_shape = (1, NUM_HANGING_VALUES)

# Uniform random inputs drawn with NumPy (bitboard first, then the float
# features) and converted to float32 tensors.
input_bitboard = torch.from_numpy(np.random.rand(*board_shape)).to(torch.float32)
input_floats = torch.from_numpy(np.random.rand(*floats_shape)).to(torch.float32)

# One forward pass; the cell displays the resulting tensor.
output = model(input_bitboard, input_floats)
output

tensor([[0.5097]], grad_fn=<SigmoidBackward0>)

## Model Training

In [6]:
# Load the serialized training generator and show its length.
train_generator = efficent_load_object(TRAIN_GENERATOR_PATH)
len(train_generator)

209557

In [7]:
# Load the serialized validation generator and show its length.
val_generator = efficent_load_object(VAL_GENERATOR_PATH)
len(val_generator)

20931

In [8]:
# Load the serialized test generator and show its length.
test_generator = efficent_load_object(TEST_GENERATOR_PATH)
len(test_generator)

20974

In [9]:
# Drop the three generator names; train() is given the generator paths,
# not these objects.
del train_generator, val_generator, test_generator

In [10]:
# Move the model to the selected device before handing it to the training loop.
model = model.to(device)

# The log and checkpoint locations come from the configuration constants
# instead of repeating them as backslash-separated literals. The constants
# point at the same files and also resolve correctly on POSIX systems.
model = train(
    TRAIN_GENERATOR_PATH,
    VAL_GENERATOR_PATH,
    TEST_GENERATOR_PATH,
    model,
    optimizer,
    loss_func,
    NUM_HANGING_VALUES,
    EPOCHS,
    device,
    learning_rate=LR,
    log=1,
    log_file=LOG_FILE_LOCATION,
    verbose=1,
    val=True,
    early_callback=False,
    early_callback_epochs=None,
    checkpoint=True,
    epochs_per_checkpoint=1,
    break_after_checkpoint=False,
    checkpoint_filename=CHECKPOINT_FILENAME_LOCATION,
    change_learning_rate=CHANGE_LEARNING_RATE,
    update_epochs=UPDATE_EPOCHS,
)

Checkpoint found. Resuming training from checkpoint...
______________________________________________________________
Epoch 14 Train Loss: 0.0074 | MSE: 0.1569 | MAE: 0.3133 | Accuracy: 0.7658
Epoch 14 Val Loss: 0.0076 | MSE: 0.1573 | MAE: 0.3079  | Accuracy: 0.7657
Epoch 15: Saving checkpoint...
______________________________________________________________
Epoch 15 Train Loss: 0.0075 | MSE: 0.157 | MAE: 0.3133 | Accuracy: 0.7657
Epoch 15 Val Loss: 0.0074 | MSE: 0.1572 | MAE: 0.3092  | Accuracy: 0.7657
Epoch 16: Saving checkpoint...
Epoch 16 Test Loss: 0.0075 | MSE: 0.1572 | MAE: 0.3091  | Accuracy: 0.7658


## Saving Model

In [5]:
# Display the model returned by train() before it is saved.
model

Chess_Model(
  (conv1): Conv2d(76, 304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(304, 304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(304, 304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(304, 76, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(76, 304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(304, 304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv7): Conv2d(304, 304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv8): Conv2d(304, 76, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (float_inputs_fc): Linear(in_features=4, out_features=512, bias=True)
  (fc1): Linear(in_features=588, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=64, bias=True)
  (output_layer): Linear(in_features=64, out_features=1, bias=True)
)

In [None]:
# Write the current weights to the temporary state-dict file; the next cell
# reads this file back into a freshly constructed model.
torch.save(obj=model.state_dict(), f=TEMP_STATE_PATH)

In [8]:
# Rebuild a fresh model instance and restore the weights saved to the
# temporary state-dict file.
model = ChessModel_V2(
    bit_board_shape=BITBOARD_SHAPE,
    num_float_inputs=NUM_HANGING_VALUES,
    residual_blocks=RESIDUAL_BLOCKS,
    residual_filters=RESIDUAL_FILTERS,
    se_ratio=SE_RATIO,
)
# map_location="cpu" lets tensors that were saved from a CUDA model load on a
# machine without a GPU; the export below targets the CPU anyway.
model.load_state_dict(torch.load(TEMP_STATE_PATH, map_location="cpu"))

# Save the model to MODEL_FILENAME and ONNX_FILENAME, reusing the dummy input
# shapes defined for the sanity-check forward pass.
save_model(
    model,
    model_filename=MODEL_FILENAME,
    onnx_filename=ONNX_FILENAME,
    bitboard_input_shape=board_shape,
    hanging_values_input_shape=floats_shape,
    opset_version=OPSET_VERSION,
    device="cpu",
)

Model saved successfully!


## Post-Training Sanity Check

In [None]:
# Reload the saved model from disk. map_location places the stored tensors on
# the active device, so a file written on a different device still loads.
# NOTE(review): torch.load unpickles arbitrary objects - only load files you
# trust. Recent PyTorch releases appear to default to weights_only=True, which
# would reject a whole-module pickle; confirm against the installed version.
model = torch.load(MODEL_FILENAME, map_location=device)
model = model.to(device)
model.eval()  # switch to evaluation mode before running the test pass

# Evaluate the reloaded model on the test generator.
test_model(
    model,
    loss_func=loss_func,
    num_hanging_values=NUM_HANGING_VALUES,
    device=device,
    test_generator_path=TEST_GENERATOR_PATH,
)