# 4: Model Training

## Imports

In [1]:
from utils.model_utils import *
import torch.onnx

torch.manual_seed(23)

2024-10-05 13:38:32.392069: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


<torch._C.Generator at 0x7935a82211d0>

## Variables Declaration

In [2]:
LR = 2e-5
loss_func = nn.BCELoss()
NUM_HANGING_VALUES = 10
EPOCHS = 30
BITBOARD_SHAPE = (76*6, 8, 8)
CHANGE_LEARNING_RATE = False
UPDATE_EPOCHS = None
RESIDUAL_BLOCKS = 6
RESIDUAL_FILTERS = 64
SE_RATIO = 8
MODEL_FILENAME = "./Models/PikeBot_Models/PikeBot.pth"
ONNX_FILENAME = "./Models/PikeBot_Models/PikeBot.onnx"
LOG_FILE_LOCATION = "./Training_Logs/Training.txt"
CHECKPOINT_FILENAME_LOCATION = "./Models/PikeBot_Models/PikeBot_checkpoint.pth"
TRAIN_GENERATOR_PATH = "./Generators/train_generator.pkl"
VAL_GENERATOR_PATH = "./Generators/val_generator.pkl"
TEST_GENERATOR_PATH = "./Generators/test_generator.pkl"
TEMP_STATE_PATH = "./Models/PikeBot_Models/temp_state_dict.pth"
OPSET_VERSION = 11

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = "cpu"
print(f"Detected Device: {device}")

Detected Device: cuda


## Model Sanity Check

In [3]:
#model = ChessModel_V2(bit_board_shape=BITBOARD_SHAPE, num_float_inputs=NUM_HANGING_VALUES, residual_blocks=RESIDUAL_BLOCKS, residual_filters=RESIDUAL_FILTERS, se_ratio=SE_RATIO)
model = Chess_Model(bit_board_shape=BITBOARD_SHAPE, num_float_inputs=NUM_HANGING_VALUES, channel_multiple=2)
num_params = count_parameters(model)
print("Number of parameters in the model in millions:", round(num_params/(1e6), 4))

Number of parameters in the model in millions: 45.9841


In [4]:
print(model)

Chess_Model(
  (conv1): Conv2d(456, 912, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(912, 912, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(912, 912, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(912, 456, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(456, 912, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(912, 912, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv7): Conv2d(912, 912, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv8): Conv2d(912, 456, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (float_inputs_fc): Linear(in_features=10, out_features=512, bias=True)
  (fc1): Linear(in_features=968, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=64, bias=True)
  (output_layer): Linear(in_features=64, out_features=1, bias=True)
)


In [5]:
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
board_shape = (1, BITBOARD_SHAPE[0], BITBOARD_SHAPE[1], BITBOARD_SHAPE[2])
floats_shape = (1, NUM_HANGING_VALUES)
input_bitboard = torch.tensor(np.random.rand(*board_shape), dtype = torch.float32)
input_floats = torch.tensor(np.random.rand(*floats_shape), dtype = torch.float32)
output = model(input_bitboard, input_floats)
output

tensor([[0.5016]], grad_fn=<SigmoidBackward0>)

## Model Training

In [6]:
train_generator = efficent_load_object(TRAIN_GENERATOR_PATH)
train_generator.__len__()

1041253

In [7]:
val_generator = efficent_load_object(VAL_GENERATOR_PATH)
val_generator.__len__()

105374

In [8]:
test_generator = efficent_load_object(TEST_GENERATOR_PATH)
test_generator.__len__()

104725

In [9]:
del train_generator
del val_generator
del test_generator

In [10]:
model = model.to(device)
model = train(TRAIN_GENERATOR_PATH, VAL_GENERATOR_PATH, TEST_GENERATOR_PATH, model, optimizer, loss_func, NUM_HANGING_VALUES, EPOCHS, device,
              learning_rate=LR, log = 1, log_file = "./Training_Logs/Training.txt", verbose = 1, val = True, early_callback=True, early_callback_epochs=5,
              checkpoint=True, epochs_per_checkpoint=1, break_after_checkpoint=False, checkpoint_filename="./Models/PikeBot_Models/PikeBot_checkpoint.pth",
              change_learning_rate=CHANGE_LEARNING_RATE, update_epochs=UPDATE_EPOCHS)

______________________________________________________________
Epoch 0 Train Loss: 0.0064 | MSE: 0.1318 | MAE: 0.2633 | Accuracy: 0.8094


Epoch 0 Val Loss: 0.0061 | MSE: 0.1243 | MAE: 0.2527  | Accuracy: 0.8221
Epoch 1: Saving checkpoint...


______________________________________________________________
Epoch 1 Train Loss: 0.006 | MSE: 0.1224 | MAE: 0.2446 | Accuracy: 0.825


Epoch 1 Val Loss: 0.0059 | MSE: 0.1209 | MAE: 0.2492  | Accuracy: 0.8282
Epoch 2: Saving checkpoint...


______________________________________________________________
Epoch 2 Train Loss: 0.0059 | MSE: 0.1198 | MAE: 0.2393 | Accuracy: 0.8292


Epoch 2 Val Loss: 0.0059 | MSE: 0.1194 | MAE: 0.2468  | Accuracy: 0.8306
Epoch 3: Saving checkpoint...


______________________________________________________________
Epoch 3 Train Loss: 0.0059 | MSE: 0.1187 | MAE: 0.237 | Accuracy: 0.831


Epoch 3 Val Loss: 0.0059 | MSE: 0.1191 | MAE: 0.2456  | Accuracy: 0.8311
Epoch 4: Saving checkpoint...


______________________________________________________________
Epoch 4 Train Loss: 0.0058 | MSE: 0.1181 | MAE: 0.2358 | Accuracy: 0.8319


Epoch 4 Val Loss: 0.0059 | MSE: 0.1189 | MAE: 0.2444  | Accuracy: 0.8312
Epoch 5: Saving checkpoint...


______________________________________________________________
Epoch 5 Train Loss: 0.0059 | MSE: 0.118 | MAE: 0.2355 | Accuracy: 0.8324


Epoch 5 Val Loss: 0.0059 | MSE: 0.1192 | MAE: 0.2367  | Accuracy: 0.8307
Epoch 6: Saving checkpoint...


______________________________________________________________
Epoch 6 Train Loss: 0.0059 | MSE: 0.1181 | MAE: 0.2358 | Accuracy: 0.8323


Epoch 6 Val Loss: 0.0059 | MSE: 0.1193 | MAE: 0.2378  | Accuracy: 0.8312
Epoch 7: Saving checkpoint...


______________________________________________________________
Epoch 7 Train Loss: 0.0059 | MSE: 0.1181 | MAE: 0.2358 | Accuracy: 0.8323


Epoch 7 Val Loss: 0.0059 | MSE: 0.1198 | MAE: 0.2352  | Accuracy: 0.8309
Epoch 8: Saving checkpoint...


______________________________________________________________
Epoch 8 Train Loss: 0.0059 | MSE: 0.1182 | MAE: 0.2358 | Accuracy: 0.8323


Epoch 8 Val Loss: 0.0059 | MSE: 0.1189 | MAE: 0.236  | Accuracy: 0.8321
Epoch 9: Saving checkpoint...


______________________________________________________________
Epoch 9 Train Loss: 0.0059 | MSE: 0.1181 | MAE: 0.2356 | Accuracy: 0.8324


Epoch 9 Val Loss: 0.0059 | MSE: 0.119 | MAE: 0.2358  | Accuracy: 0.8319
Epoch 10: Saving checkpoint...


______________________________________________________________
Epoch 10 Train Loss: 0.0059 | MSE: 0.1182 | MAE: 0.2357 | Accuracy: 0.8323


Epoch 10 Val Loss: 0.0059 | MSE: 0.1192 | MAE: 0.24  | Accuracy: 0.8314
*****************
Early Callback
*****************


Epoch 10 Test Loss: 0.0058 | MSE: 0.1185 | MAE: 0.2437  | Accuracy: 0.8316


## Saving Model

In [11]:
model

Chess_Model(
  (conv1): Conv2d(456, 912, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(912, 912, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(912, 912, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(912, 456, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(456, 912, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(912, 912, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv7): Conv2d(912, 912, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv8): Conv2d(912, 456, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (float_inputs_fc): Linear(in_features=10, out_features=512, bias=True)
  (fc1): Linear(in_features=968, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=64, bias=True)
  (output_layer): Linear(in_features=64, out_features=1, bias=True)
)

In [12]:
torch.save(model.state_dict(), TEMP_STATE_PATH)

In [13]:
model = Chess_Model(bit_board_shape=BITBOARD_SHAPE, num_float_inputs=NUM_HANGING_VALUES, channel_multiple=2)
model.load_state_dict(torch.load(TEMP_STATE_PATH))
save_model(model, model_filename=MODEL_FILENAME, onnx_filename=ONNX_FILENAME,
           bitboard_input_shape=board_shape, hanging_values_input_shape=floats_shape, opset_version=OPSET_VERSION, device="cpu")

  model.load_state_dict(torch.load(TEMP_STATE_PATH))


Model saved successfully!


## Post-Training Sanity Check

In [14]:
model = torch.load(MODEL_FILENAME)
model = model.to(device)
model.eval()
test_model(model, loss_func=loss_func, num_hanging_values=NUM_HANGING_VALUES, device=device, test_generator_path=TEST_GENERATOR_PATH)

  model = torch.load(MODEL_FILENAME)


Testing Complete, Loss: 0.0058 | MSE: 0.1185 | MAE: 0.2437 | Accuracy: 0.8316
