Skip to content

Commit

Permalink
Train CNN Refactoring (#201)
Browse files Browse the repository at this point in the history
* Seperate pytorch and mxnet imports

* Refactoring of supervised training

* Removed importing MXNet when training with pytorch
* Added data_loader helper method
* Removed _template suffix from config files
* Updated requirement.txt
* Updated READMEs
* Updated Dockerfile (now downloads cutechess, example NNs and calls git pull)
  • Loading branch information
QueensGambit authored Jun 30, 2023
1 parent c8ad1de commit 608e958
Show file tree
Hide file tree
Showing 13 changed files with 338 additions and 214 deletions.
4 changes: 0 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,6 @@ venv.bak/
# avoid pushing any lock files
*.~lock

# avoid pushing config files
main_config.py
train_config.py

# avoid pushing log-files generated by uci-communication
CrazyAra-log.txt
score-log.txt
Expand Down
4 changes: 2 additions & 2 deletions DeepCrazyhouse/configs/README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
### Configuration files for CrazyAra

This is the configuration directory of CrazyAra for both supervised training and reinforcement learning.

If you want to test sample MCTS predictions in [MCTS_eval_demo.ipynb](https://github.com/QueensGambit/CrazyAra/blob/master/DeepCrazyhouse/src/samples/MCTS_eval_demo.ipynb),
then follow these steps:

* First rename `main_config_template.py` into `main_config.py`

* Specify the fields `main_config["model_architecture_dir"]` and `main_config["model_weights_dir"]` in the file
`main_config.py` to the appropriate paths of your system. Make sure that the path has a "/" at the end of the path.
File renamed without changes.
File renamed without changes.
23 changes: 14 additions & 9 deletions DeepCrazyhouse/src/training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,26 @@

### Prerequisites

Make sure to have a recent [MXNet](https://mxnet.incubator.apache.org/index.html) version with CUDA support installed:
```bash
pip install mxnet-cu<cuda_version>==<version_id>
```
Make sure to have a recent [Pytorch](https://pytorch.org/get-started/locally/) version with CUDA support installed.

For supervised training you need the following [additional libraries](https://github.com/QueensGambit/CrazyAra/blob/master/DeepCrazyhouse/src/training/requirements.txt):

```bash
pip install -r requirements.txt
```

* zarr (chunked, compressed, N-dimensional array library)
* numcodecs (compression codec library)
* tqdm: (progress bar library)
* MXBoard (logging MXNet data for visualization in TensorBoard)
#### Training with MXNet or Gluon (deprecated)
Make sure to have a recent [MXNet](https://mxnet.incubator.apache.org/index.html) version with CUDA support installed:
```bash
pip install mxnet-cu<cuda_version>==<version_id>
```

You need to install the following libraries when training with MXNet:
```bash
pip install -y mxboard
pip uninstall -y onnx
pip install onnx==1.3.0
```

### Training data specification
Specify the directories `"planes_train_dir"`, `"planes_val_dir"`, `"planes_test_dir"`, `"planes_mate_in_one_dir"` at
Expand All @@ -32,7 +37,7 @@ Use `train_cnn.ipynb` to conduct a training run.
* <https://jupyter.org/install.html>

Jupyter notebooks are displayed in a web-browser and can be launched with `jupyter notebook` from the command line.
After a successfull training run you can export the outputs as a html-file: `File->Download as->Html(.html)`.
After a successful training run you can export the outputs as a html-file: `File->Download as->Html(.html)`.

### Tensorboard
The [tensorboard](https://github.com/tensorflow/tensorboard) log files will be exported in `./logs` which can be viewed with tensorboard during training.
Expand Down
8 changes: 3 additions & 5 deletions DeepCrazyhouse/src/training/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
torch
numpy
python-chess==0.23.11
zarr
numcodecs
ipywidgets
matplotlib
tqdm
mxboard
tensorflow
tensorboard
onnx==1.3.0
onnx
ipytree
setproctitle
rtpt
dataclasses
mxnet
torch
torchsummary
onnxruntime
onnxsim
fvcore
Loading

0 comments on commit 608e958

Please sign in to comment.