Java handwritten digit (MNIST dataset) classifier, using the GPU for accelerated training and inference (via JCuda & JCublas)

Jazz-Coding/NeuralNetwork4J_CUDA

NeuralNetwork4J_CUDA

A handwritten digit classifier written in Java that uses the GPU to accelerate training and inference. It is trained on the MNIST handwritten digit dataset, which is included in .csv format and split into training and test data.

  • The JCublas and JCuda libraries serve as the interface to the native cuBLAS and CUDA libraries (version 12.0 or later must be installed beforehand).

By using the GPU, significant speedups over CPU-based training are achieved.

The main code for training/inference can be found in nn/gpu/NN_GPU.java

Currently available

Layer Types:

- Fully-connected

Training Algorithms:

- Stochastic Gradient Descent (SGD)

Cost Functions:

- Mean-squared error (MSE)

Activation Functions:

- Sigmoid

Parameter saving: Networks can be saved to and loaded from custom local file formats (examples are provided under "saved_networks").
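To make the building blocks listed above concrete, here is a minimal CPU-only sketch of one fully-connected sigmoid layer and the MSE cost. It is not the code in nn/gpu/NN_GPU.java; the class name `DenseLayerSketch`, the `weights[j][i]` layout, and the choice to average the squared error over the output neurons are assumptions made for illustration.

```java
import java.util.Arrays;

/** CPU-only illustration of a fully-connected sigmoid layer and the MSE cost (hypothetical class). */
final class DenseLayerSketch {

    /** Logistic sigmoid: 1 / (1 + e^-z). */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Computes sigmoid(W x + b); weights[j][i] links input i to output neuron j (assumed layout). */
    static double[] forward(double[][] weights, double[] biases, double[] input) {
        double[] out = new double[biases.length];
        for (int j = 0; j < out.length; j++) {
            double z = biases[j];
            for (int i = 0; i < input.length; i++) {
                z += weights[j][i] * input[i];
            }
            out[j] = sigmoid(z);
        }
        return out;
    }

    /** Mean-squared error, averaged over the output neurons (assumed convention). */
    static double mse(double[] output, double[] target) {
        double sum = 0.0;
        for (int j = 0; j < output.length; j++) {
            double diff = output[j] - target[j];
            sum += diff * diff;
        }
        return sum / output.length;
    }

    public static void main(String[] args) {
        double[][] w = {{0.5, -0.5}, {1.0, 1.0}};
        double[] b = {0.0, -1.0};
        double[] x = {1.0, 1.0};
        double[] a = forward(w, b, x);
        System.out.println("activations = " + Arrays.toString(a));
        System.out.println("mse         = " + mse(a, new double[] {1.0, 0.0}));
    }
}
```

On the GPU, the inner loops would typically be replaced by a matrix multiply over a whole batch, which is the kind of operation cuBLAS provides.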

Example Performance

Network Performance

Test Accuracy=95.96%

Network specifications: 784x32x10 (i.e. a single hidden layer with 32 neurons)

(stored as saved_networks/tiny.txt)

Hyper-parameters: Batch size = 32, Learning rate = 0.1
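As a rough illustration of how these two hyper-parameters enter a mini-batch SGD step, the sketch below applies one update to a flat parameter array. The class `SgdStepSketch` is hypothetical, and averaging the per-example gradients over the batch is an assumption; the repository may sum them or scale them differently.

```java
/** Illustrative mini-batch SGD update (hypothetical class, not the repository's code). */
final class SgdStepSketch {

    /**
     * Updates params in place: w <- w - learningRate * (mean gradient over the batch).
     * perExampleGradients[n][p] is the gradient of the cost for example n
     * with respect to parameter p.
     */
    static void applyBatch(double[] params, double[][] perExampleGradients, double learningRate) {
        int batchSize = perExampleGradients.length;
        for (int p = 0; p < params.length; p++) {
            double sum = 0.0;
            for (double[] gradient : perExampleGradients) {
                sum += gradient[p];
            }
            params[p] -= learningRate * sum / batchSize;
        }
    }

    public static void main(String[] args) {
        double[] params = {1.0, 2.0};
        // Two examples here for brevity; the README's setting would use 32 per batch.
        double[][] gradients = {{1.0, 0.0}, {3.0, 2.0}};
        applyBatch(params, gradients, 0.1);
        System.out.println(java.util.Arrays.toString(params));
    }
}
```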

Training took approximately 1 minute on my machine (with an RTX 4090 GPU). GPU utilization peaked at only 20% on such a small network, but larger networks such as 784x3000x10 (saved_networks/wide.txt) take the same amount of time to train and utilize nearly 100% of the GPU.
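To put the two network sizes in perspective, the following sketch counts trainable parameters for a stack of fully-connected layers. It assumes one bias per non-input neuron, which this README does not state; `ParameterCount` is a hypothetical helper.

```java
/** Counts weights and biases for a fully-connected network (hypothetical helper). */
final class ParameterCount {

    /** layerSizes lists the neuron count per layer, input first, e.g. 784, 32, 10. */
    static long count(int... layerSizes) {
        long total = 0;
        for (int l = 1; l < layerSizes.length; l++) {
            total += (long) layerSizes[l - 1] * layerSizes[l]; // weight matrix
            total += layerSizes[l];                            // biases (assumed one per neuron)
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println("784x32x10   -> " + count(784, 32, 10));   // 25450
        System.out.println("784x3000x10 -> " + count(784, 3000, 10)); // 2385010
    }
}
```

Under that assumption the wide network has roughly 94 times as many parameters as the tiny one, which is consistent with it keeping the GPU far busier.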
