
Grokking Deep Learning Examples implemented in Rust

License: Apache-2.0, MIT (LICENSE-APACHE, LICENSE-MIT)
RustStudy/grokking-deep-learning-rs

 
 


Grokking Deep Learning Rust


The exercises from @iamtrask's book Grokking Deep Learning, implemented in Rust.

This crate isn't published, because ideally you'd do this on your own, but if you insist:

```sh
cargo add grokking_deep_learning_rs --git https://github.com/suyash/grokking-deep-learning-rs
```

The crate is structured as a library: the core library defines common primitives used throughout, and the individual chapters are implemented as examples. To run the exercises from a particular chapter, for example chapter 12:

```sh
cargo run --example chapter12
```

Currently this uses rulinalg for matrix operations. rulinalg uses a Rust implementation of dgemm, which gives a 3x speedup over naive ijk multiplication (see the included benchmark). However, it still isn't as fast as NumPy because it isn't multithreaded. I'm currently working on an alternative of my own.
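For reference, "ijk" multiplication is the textbook triple loop over the rows of A (i), the columns of B (j) and the shared dimension (k). The sketch below uses plain row-major slices rather than rulinalg and shows that loop next to an "ikj" reordering that reads `b` row by row instead of down a column. It only hints at why optimized dgemm kernels (which also block, pack and vectorize) are faster; it is not the code used by rulinalg or by the included benchmark, and the function names are illustrative.

```rust
/// Naive "ijk" matrix multiplication of row-major matrices:
/// c[i][j] = sum over k of a[i][k] * b[k][j].
/// `a` is m x n, `b` is n x p, and the result is m x p.
fn matmul_ijk(a: &[f64], b: &[f64], m: usize, n: usize, p: usize) -> Vec<f64> {
    assert_eq!(a.len(), m * n);
    assert_eq!(b.len(), n * p);
    let mut c = vec![0.0; m * p];
    for i in 0..m {
        for j in 0..p {
            let mut sum = 0.0;
            for k in 0..n {
                // `b` is walked down a column (stride p): poor cache locality
                sum += a[i * n + k] * b[k * p + j];
            }
            c[i * p + j] = sum;
        }
    }
    c
}

/// The same product with the loops reordered to "ikj", so the innermost
/// loop walks a row of `b` and a row of `c` contiguously.
fn matmul_ikj(a: &[f64], b: &[f64], m: usize, n: usize, p: usize) -> Vec<f64> {
    assert_eq!(a.len(), m * n);
    assert_eq!(b.len(), n * p);
    let mut c = vec![0.0; m * p];
    for i in 0..m {
        for k in 0..n {
            let aik = a[i * n + k];
            for j in 0..p {
                c[i * p + j] += aik * b[k * p + j];
            }
        }
    }
    c
}

fn main() {
    // [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
    let a = [1.0, 2.0, 3.0, 4.0];
    let b = [5.0, 6.0, 7.0, 8.0];
    println!("{:?}", matmul_ijk(&a, &b, 2, 2, 2)); // [19.0, 22.0, 43.0, 50.0]
    println!("{:?}", matmul_ikj(&a, &b, 2, 2, 2)); // [19.0, 22.0, 43.0, 50.0]
}
```

Both functions return the same row-major result; only the memory access pattern differs.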

The datasets are extracted into a separate library crate, which currently provides functions for loading 4 datasets and an iterator for batching and shuffling; more datasets are planned. It can be added using:

```sh
cargo add datasets --git https://github.com/suyash/datasets
```
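The batching-and-shuffling iterator idea can be sketched without any dependencies. `ShuffledBatches` below is a hypothetical type, not the `datasets` crate's API: it shuffles the sample indices once with a Fisher-Yates shuffle (driven by a tiny xorshift generator so the example needs no `rand` crate) and then yields them in chunks of `batch_size`, with a shorter final batch when the length does not divide evenly.

```rust
/// Hypothetical batching iterator (not the `datasets` crate's API):
/// yields shuffled sample indices in chunks of `batch_size`.
struct ShuffledBatches {
    order: Vec<usize>,
    batch_size: usize,
    pos: usize,
}

impl ShuffledBatches {
    fn new(len: usize, batch_size: usize, seed: u64) -> Self {
        assert!(batch_size > 0, "batch_size must be positive");
        let mut order: Vec<usize> = (0..len).collect();
        // Fisher-Yates shuffle driven by a xorshift64 generator,
        // used only to avoid an external `rand` dependency.
        let mut state = seed.max(1); // xorshift must not start at 0
        for i in (1..len).rev() {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            let j = (state % (i as u64 + 1)) as usize;
            order.swap(i, j);
        }
        ShuffledBatches { order, batch_size, pos: 0 }
    }
}

impl Iterator for ShuffledBatches {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.pos >= self.order.len() {
            return None;
        }
        // the last batch may be shorter than batch_size
        let end = (self.pos + self.batch_size).min(self.order.len());
        let batch = self.order[self.pos..end].to_vec();
        self.pos = end;
        Some(batch)
    }
}

fn main() {
    // 10 samples in batches of 4: two full batches and one batch of 2
    for batch in ShuffledBatches::new(10, 4, 42) {
        println!("{:?}", batch);
    }
}
```

Shuffling indices rather than the data itself keeps the example independent of how samples are stored; a caller would use each index batch to gather rows from its own inputs and targets.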

Because of the slower matmul, from chapter 8 onwards certain examples are smaller than their Python counterparts.

The Chapter 13 core components were extracted into the core library, so they could be used in later chapters.

As a result, a model can be built and trained with something like:

```rust
use rulinalg::matrix::Matrix;

use grokking_deep_learning_rs::activations::{Sigmoid, Tanh};
use grokking_deep_learning_rs::layers::{Layer, Linear, Sequential};
use grokking_deep_learning_rs::losses::{Loss, MSELoss};
use grokking_deep_learning_rs::optimizers::{Optimizer, SGDOptimizer};
use grokking_deep_learning_rs::tensor::Tensor;

fn main() {
    // 4 samples with 2 features each
    let data = Tensor::new_const(Matrix::new(
        4,
        2,
        vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0],
    ));

    // one target value per sample
    let target = Tensor::new_const(Matrix::new(4, 1, vec![0.0, 1.0, 0.0, 1.0]));

    // 2 -> 3 -> 1 network: tanh hidden activation, sigmoid output
    let model = Sequential::new(vec![
        Box::new(Linear::new(2, 3)),
        Box::new(Tanh),
        Box::new(Linear::new(3, 1)),
        Box::new(Sigmoid),
    ]);

    let criterion = MSELoss;
    // SGD over all model parameters with learning rate 0.5
    let optim = SGDOptimizer::new(model.parameters(), 0.5);

    for _ in 0..10 {
        // predict
        let pred = model.forward(&[&data]);

        // compare
        let loss = criterion.forward(&pred[0], &target);

        println!("Loss: {:?}", loss.0.borrow().data.data());

        // calculate difference: backpropagate, seeding with a 1x1 gradient of ones
        loss.backward(Tensor::grad(Matrix::ones(1, 1)));

        // learn
        optim.step(true);
    }
}
```
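To make the forward / backward / step cycle of that loop concrete, here is a dependency-free sketch that trains a single Linear(2, 1) + Sigmoid unit on the same data with MSE and plain SGD, with the gradients derived by hand. It is an illustration under simplifying assumptions (no hidden layer, fixed initial weights, `train` is a hypothetical helper), not how the crate works internally; the crate computes these gradients through its `Tensor` autograd instead.

```rust
// Hand-derived sketch (not the crate's implementation): one Linear(2, 1)
// layer followed by Sigmoid, trained with MSE loss and full-batch SGD.

fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Runs `epochs` SGD steps and returns each epoch's loss,
/// measured before that epoch's parameter update.
fn train(epochs: usize, lr: f64) -> Vec<f64> {
    let data = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]];
    let target = [0.0, 1.0, 0.0, 1.0];
    let (mut w, mut b) = ([0.1, -0.1], 0.0); // fixed init instead of random
    let n = data.len() as f64;
    let mut losses = Vec::with_capacity(epochs);

    for _ in 0..epochs {
        let (mut gw, mut gb, mut loss) = ([0.0; 2], 0.0, 0.0);
        for (x, &t) in data.iter().zip(target.iter()) {
            // forward: z = w.x + b, pred = sigmoid(z)
            let pred = sigmoid(w[0] * x[0] + w[1] * x[1] + b);
            let diff = pred - t;
            loss += diff * diff / n;
            // backward: dL/dz = (2 * diff / n) * pred * (1 - pred)
            let dz = 2.0 * diff / n * pred * (1.0 - pred);
            gw[0] += dz * x[0];
            gw[1] += dz * x[1];
            gb += dz;
        }
        // step: parameter -= lr * gradient
        w[0] -= lr * gw[0];
        w[1] -= lr * gw[1];
        b -= lr * gb;
        losses.push(loss);
    }
    losses
}

fn main() {
    for (epoch, loss) in train(10, 0.5).iter().enumerate() {
        println!("epoch {}: loss {:.4}", epoch, loss);
    }
}
```

A single layer suffices for this particular data, because the target simply equals the second input feature.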

In Chapter 14, the RNN and LSTM examples suffer from vanishing gradients, and the loss keeps going to NaN. There seems to be a logic bug somewhere in the code, where something is not doing what I think it does; I'm still investigating. I tried reproducing the problem in the final chapter 13 exercise and also implemented min-char-rnn.py in Rust, but no luck so far.
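The vanishing-gradient symptom itself is easy to reproduce in isolation. The sketch below (hypothetical helpers with a scalar state, not code from this repo) backpropagates through the recurrence h_t = tanh(w · h_{t-1} + x): every step multiplies the gradient by w · (1 - h_t²), so with |w| < 1 it shrinks geometrically with the number of steps. It also includes a small guard that reports the first non-finite loss, which may help locate where NaNs start. It does not diagnose the chapter 14 bug.

```rust
/// Hypothetical helper: |d h_T / d h_0| for the scalar recurrence
/// h_t = tanh(w * h_{t-1} + x), run for `steps` steps from `h0`.
fn grad_through_time(w: f64, x: f64, h0: f64, steps: usize) -> f64 {
    let mut h = h0;
    let mut grad = 1.0;
    for _ in 0..steps {
        h = (w * h + x).tanh();
        // chain rule: d h_t / d h_{t-1} = w * (1 - h_t^2)
        grad *= w * (1.0 - h * h);
    }
    grad.abs()
}

/// Hypothetical guard for a training loop: fail as soon as the loss is NaN or infinite.
fn check_loss(loss: f64, step: usize) -> Result<(), String> {
    if loss.is_finite() {
        Ok(())
    } else {
        Err(format!("loss became {} at step {}", loss, step))
    }
}

fn main() {
    // the gradient magnitude drops roughly geometrically with sequence length
    for steps in [1, 5, 10, 20] {
        println!(
            "steps = {:2}: |dh_T/dh_0| = {:e}",
            steps,
            grad_through_time(0.5, 0.1, 0.0, steps)
        );
    }
    println!("{:?}", check_loss(f64::NAN, 42));
}
```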

For Chapter 15, the encrypted federated learning exercise is not implemented. There is a crate for Paillier homomorphic encryption, but its current implementation only works with integers and BigInts, not floating-point numbers. I will try to find a way to make it work.
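One common way to carry floats through an integer-only, additively homomorphic scheme such as Paillier is fixed-point encoding: scale each value by 2^f, round to an integer, operate on the integers, and divide by 2^f after decrypting. The sketch below shows only the encoding step, with ordinary `i128` addition standing in for ciphertext addition. It performs no encryption, the 16 fractional bits are an arbitrary assumption, and a real scheme would additionally have to map negative values into the plaintext range modulo n. It is not what this repo does, since the exercise is not implemented.

```rust
/// Assumed precision: number of fractional bits kept by the encoding.
const FRAC_BITS: u32 = 16;

/// Encode a float as a fixed-point integer (value * 2^FRAC_BITS, rounded).
fn encode(x: f64) -> i128 {
    (x * (1u64 << FRAC_BITS) as f64).round() as i128
}

/// Decode a fixed-point integer back into a float.
fn decode(v: i128) -> f64 {
    v as f64 / (1u64 << FRAC_BITS) as f64
}

fn main() {
    // e.g. model-weight updates from several federated participants
    let updates = [0.25, -1.5, 0.125];
    // Sum in the encoded (integer) domain, then decode once at the end.
    // In an additively homomorphic scheme this sum would be computed on ciphertexts.
    let encoded_sum: i128 = updates.iter().map(|&u| encode(u)).sum();
    println!("{}", decode(encoded_sum)); // prints -1.125
}
```

Values that are not exact multiples of 2^-16 are rounded, so precision is traded against the size of the integers the scheme must hold.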

License

This project is licensed under either of

  • Apache License, Version 2.0 (LICENSE-APACHE)
  • MIT license (LICENSE-MIT)

at your option.

Contribution

Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in this work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.
