tch-rs

Rust bindings for PyTorch. The goal of the tch crate is to provide some thin wrappers around the C++ PyTorch API (a.k.a. libtorch). It aims to stay as close as possible to the original C++ API. More idiomatic Rust bindings can then be developed on top of this. The documentation can be found on docs.rs.


The code-generation part for the C API on top of libtorch comes from ocaml-torch.

Getting Started

This crate requires the C++ version of PyTorch (libtorch) to be available on your system. You can either install it manually and let the build script know about it via the LIBTORCH environment variable, or leave LIBTORCH unset, in which case the build script will try to download and extract a pre-built binary version of libtorch.
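
If you are pulling the crate from crates.io, the dependency entry in Cargo.toml looks roughly as follows (the version requirement here is a placeholder; check crates.io for the current release):

[dependencies]
# Replace "*" with the latest tch release published on crates.io.
tch = "*"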

Libtorch Manual Install

export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
You should now be able to run some examples, e.g. cargo run --example basics.
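
For reference, the basics example essentially exercises a few tensor operations along these lines (a minimal sketch rather than the exact example source):

use tch::Tensor;

fn main() {
    // Create a 1-D tensor from a slice, scale it, and print the result.
    let t = Tensor::of_slice(&[3, 1, 4, 1, 5]);
    let t = t * 2;
    t.print();
}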

Examples

Writing a Simple Neural Network

The following code defines a simple model with one hidden layer.

struct Net {
    fc1: nn::Linear,
    fc2: nn::Linear,
}

impl Net {
    fn new(vs: &mut nn::VarStore) -> Net {
        // IMAGE_DIM, HIDDEN_NODES and LABELS are constants defined elsewhere
        // in the example (e.g. 784, 128 and 10 for MNIST).
        let fc1 = nn::Linear::new(vs, IMAGE_DIM, HIDDEN_NODES);
        let fc2 = nn::Linear::new(vs, HIDDEN_NODES, LABELS);
        Net { fc1, fc2 }
    }
}

impl nn::Module for Net {
    fn forward(&self, xs: &Tensor) -> Tensor {
        xs.apply(&self.fc1).relu().apply(&self.fc2)
    }
}

This model can be trained on the MNIST dataset by running the following command.

cargo run --example mnist

More details on the training loop can be found in the detailed tutorial.
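
For orientation, a training loop built on the Net module above could look roughly like the sketch below. The dataset helper (tch::vision::mnist::load_dir), the Adam optimizer setup and the hyper-parameters are illustrative assumptions and may differ from the actual example:

// A rough sketch of a training loop for the Net defined above; the dataset
// helper, optimizer and hyper-parameters are assumptions, not the exact
// contents of the mnist example.
let m = tch::vision::mnist::load_dir("data")?;
let mut vs = nn::VarStore::new(tch::Device::Cpu);
let net = Net::new(&mut vs);
let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
for epoch in 1..200 {
    // Compute the cross-entropy loss on the training set.
    let loss = net
        .forward(&m.train_images)
        .cross_entropy_for_logits(&m.train_labels);
    // Back-propagate and update the weights held in the variable store.
    opt.backward_step(&loss);
    // Track the accuracy on the test set.
    let test_accuracy = net
        .forward(&m.test_images)
        .accuracy_for_logits(&m.test_labels);
    println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * f64::from(&test_accuracy));
}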

Using a Pre-Trained Model

The pretrained-models example illustrates how to run a pre-trained computer vision model on an image. The weights, which have been extracted from the PyTorch implementation, can be downloaded here: resnet18.ot and resnet34.ot.

The example can then be run via the following command:

cargo run --example pretrained-models -- resnet18.ot tiger.jpg

This should print the top 5 ImageNet categories for the image. The code for this example is pretty simple.

    // First the image is loaded and resized to 224x224.
    let image = imagenet::load_image_and_resize(image_file)?;

    // A variable store is created to hold the model parameters.
    let vs = tch::nn::VarStore::new(tch::Device::Cpu);

    // Then the model is built on this variable store, and the weights are loaded.
    let resnet18 = tch::vision::resnet::resnet18(vs.root(), imagenet::CLASS_COUNT);
    vs.load(weight_file)?;

    // Apply the forward pass of the model to get the logits and convert them
    // to probabilities via a softmax.
    let output = resnet18
        .forward_t(&image.unsqueeze(0), /*train=*/ false)
        .softmax(-1);

    // Finally print the top 5 categories and their associated probabilities.
    for (probability, class) in imagenet::top(&output, 5).iter() {
        println!("{:50} {:5.2}%", class, 100.0 * probability)
    }

Further examples, including a character-level RNN (char-rnn) and a policy-gradient reinforcement learning agent, can be found in the examples directory of the repository.
