Strange silent compiler exit + nim-lang/Nim#16653
mratsim committed Jan 9, 2021
1 parent 8af1ead commit ec43ff1
Showing 4 changed files with 170 additions and 3 deletions.
29 changes: 29 additions & 0 deletions flambeau/cpp/std_cpp.nim
@@ -33,3 +33,32 @@ func `$`*(s: CppString): string =
copyMem(result[0].addr, s.data.unsafeAddr, s.len)

{.pop.}

# std::shared_ptr<T>
# -----------------------------------------------------------------------

{.push header: "<memory>".}

type
CppSharedPtr* {.importcpp: "std::shared_ptr", bycopy.} [T] = object

func make_shared*(T: typedesc): CppSharedPtr[T] {.importcpp: "std::make_shared<'*0>()".}

{.pop.}
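
# A minimal usage sketch of the shared_ptr binding above, assuming the C++
# backend (`nim cpp`). `Foo` is a hypothetical payload type; no dereference
# operator is bound here, so the sketch only creates the smart pointer.
type Foo = object
  x: int

let p: CppSharedPtr[Foo] = make_shared(Foo)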

# std::vector<T>
# -----------------------------------------------------------------------

{.push header: "<vector>".}

type
  CppVector* {.importcpp: "std::vector", header: "<vector>", bycopy.} [T] = object

proc init*(V: type CppVector): V {.importcpp: "std::vector<'*0>()", header: "<vector>", constructor.}
proc init*(V: type CppVector, size: int): V {.importcpp: "std::vector<'*0>(#)", header: "<vector>", constructor.}
proc len*(v: CppVector): int {.importcpp: "#.size()", header: "<vector>".}
proc add*[T](v: var CppVector[T], elem: T) {.importcpp: "#.push_back(#)", header: "<vector>".}
proc `[]`*[T](v: CppVector[T], idx: int): T {.importcpp: "#[#]", header: "<vector>".}
proc `[]`*[T](v: var CppVector[T], idx: int): var T {.importcpp: "#[#]", header: "<vector>".}

{.pop.}
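
# A minimal usage sketch of the CppVector bindings above, again assuming the
# C++ backend; `v` is a hypothetical local.
var v = CppVector[int].init()
v.add(10)
v.add(20)
doAssert v.len == 2
doAssert v[1] == 20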
101 changes: 101 additions & 0 deletions flambeau/raw_bindings/data_api.nim
@@ -0,0 +1,101 @@
# Flambeau
# Copyright (c) 2020 Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
./tensors,
../cpp/std_cpp

# (Almost) raw bindings to PyTorch Data API
# -----------------------------------------------------------------------
#
# This provides almost raw bindings to the PyTorch Data API.
#
# "Nimification" (camelCase), ergonomic indexing and interoperability with Nim types is left to the "high-level" bindings.
# This should ease searching PyTorch and libtorch documentation,
# and make C++ tutorials easily applicable.

# #######################################################################
#
# Datasets
#
# #######################################################################
#
# Custom Dataset example: https://github.com/mhubii/libtorch_custom_dataset
# libtorch/include/torch/csrc/api/include/torch/data/datasets/base.h

type
  Example* {.bycopy, importcpp: "torch::data::Example".}
[Data, Target] = object
data*: Data
target*: Target

# TODO: https://github.com/nim-lang/Nim/issues/16653
# generics + {.inheritable.} doesn't work
BatchDataset*
{.bycopy, pure, inheritable,
importcpp: "torch::data::datasets::BatchDataset".}
# [Self, Batch, BatchRequest] # TODO: generic inheritable https://github.com/nim-lang/Nim/issues/16653
= object
    ## A BatchDataset type.
    ## Self is the class type that implements the Dataset API
    ## (using the Curiously Recurring Template Pattern in the underlying C++).
    ## Batch is by default CppVector[T],
    ## with T being Example[Data, Target].
    ## BatchRequest is by default ArrayRef[csize_t].

Dataset*
{.bycopy, pure,
importcpp: "torch::data::datasets::Dataset".}
[Self, Batch]
= object of BatchDataset # [Self, Batch, ArrayRef[csize_t]]
    ## A Dataset type.
    ## Self is the class type that implements the Dataset API
    ## (using the Curiously Recurring Template Pattern in the underlying C++).
    ## Batch is by default CppVector[T],
    ## with T being Example[Data, Target].

Mnist*
{.bycopy, pure,
importcpp: "torch::data::datasets::MNIST".}
= object of Dataset[Mnist, CppVector[Example[Tensor, Tensor]]]
## The MNIST dataset
## http://yann.lecun.com/exdb/mnist

MnistMode* {.size:sizeof(cint),
importcpp:"torch::data::datasets::MNIST::Mode".} = enum
## Select the train or test mode of the Mnist data
kTrain = 0
kTest = 1
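
# For reference: a sketch of the generic CRTP declarations that the comments
# above aim for once nim-lang/Nim#16653 (generics + {.inheritable.}) is fixed.
# Hypothetical, it does not compile today, which is why the generic parameters
# are commented out above:
#
#   type
#     BatchDataset*
#       {.bycopy, pure, inheritable,
#         importcpp: "torch::data::datasets::BatchDataset".}
#       [Self, Batch, BatchRequest]
#       = object
#
#     Dataset*
#       {.bycopy, pure,
#         importcpp: "torch::data::datasets::Dataset".}
#       [Self, Batch]
#       = object of BatchDataset[Self, Batch, ArrayRef[csize_t]]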

func mnist*(rootPath: cstring, mode = kTrain): Mnist {.constructor, importcpp:"MNIST(@)".}
  ## Loads the MNIST dataset from `rootPath`.
  ## The supplied `rootPath` should contain the *content* of the unzipped
## MNIST dataset, available from http://yann.lecun.com/exdb/mnist.
func get*(dataset: Mnist, index: int): Example[Tensor, Tensor] {.importcpp:"#.get(#)".}
# func size*(dataset: Mnist): optional(int)
func is_train*(dataset: Mnist): bool {.importcpp:"#.is_train()".}
func images*(dataset: Mnist): lent Tensor {.importcpp: "#.images()".}
## Returns all images stacked into a single tensor
func targets*(dataset: Mnist): lent Tensor {.importcpp: "#.targets()".}
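
# A minimal usage sketch of the MNIST bindings above. It assumes the dataset
# content was extracted to a hypothetical "build/mnist" directory, as in the
# proof of concept below.
let ds = mnist("build/mnist", kTrain)
doAssert ds.is_train()
let sample = ds.get(0)       # Example[Tensor, Tensor]: sample.data, sample.target
let allImages = ds.images()  # all images stacked into a single tensor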

# #######################################################################
#
# Dataloader
#
# #######################################################################

# #######################################################################
#
# Samplers
#
# #######################################################################

30 changes: 27 additions & 3 deletions flambeau/raw_bindings/neural_nets.nim
@@ -30,6 +30,22 @@ import ./tensors
# It is suitable for layers with no learning parameters (for example reshaping),
# or when extra flexibility is required at a small cost in ergonomics.
# The high-level Module API uses Functional internally.
#
# Note:
# Functions exist both in ATen's TensorBody.h (namespace at:: or torch::)
# and in torch::nn::functional.
#
# We can have
# func dropout*(input: Tensor, p = 0.5, training=true): Tensor {.importcpp: "torch::nn::functional::dropout(@)".}
# func dropout_mut*(input: var Tensor, p = 0.5, training=true) {.importcpp: "torch::nn::functional::dropout(@, /*inplace=*/ true)".}
#
# OR
#
# func dropout*(input: Tensor, p = 0.5, training=true): Tensor {.importcpp: "torch::dropout(@)".}
# func dropout_mut*(input: var Tensor, p = 0.5, training=true) {.importcpp: "torch::dropout_(@)".}
#
# The functions in torch::nn::functional are thin inlined wrappers over TensorBody.h,
# so we use the TensorBody.h (torch::) versions directly.

# Linear Layers
# -------------------------------------------------------------------------
@@ -53,15 +69,23 @@ func linear*(input, weight, bias: Tensor): Tensor {.importcpp: "torch::nn::funct
## Bias: (out_features)
## Output: (N,∗,out_features)

# Dropout functions
# Activation functions
# -------------------------------------------------------------------------

# func dropout*(input: Tensor, p = 0.5, training=true): Tensor {.importcpp: "torch::nn::functional::dropout(@)".}
# func dropout_mut*(input: var Tensor, p = 0.5, training=true) {.importcpp: "torch::nn::functional::dropout(@, /*inplace=*/ true)".}
func relu*(input: Tensor): Tensor {.importcpp: "torch::relu(@)".}
func relu_mut*(input: var Tensor) {.importcpp: "torch::relu_(@)".}

# Dropout functions
# -------------------------------------------------------------------------

func dropout*(input: Tensor, p = 0.5, training=true): Tensor {.importcpp: "torch::dropout(@)".}
func dropout_mut*(input: var Tensor, p = 0.5, training=true) {.importcpp: "torch::dropout_(@)".}

# Loss functions
# -------------------------------------------------------------------------

func log_softmax*(input: Tensor, axis: int64): Tensor {.importcpp: "torch::log_softmax(@)".}
func log_softmax*(input: Tensor, axis: int64, dtype: ScalarKind): Tensor {.importcpp: "torch::log_softmax(@)".}
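
# A sketch tying these functions together, e.g. the tail of a classifier's
# forward pass. `forward` is a hypothetical helper; `Tensor` comes from
# ./tensors.
func forward(x: Tensor): Tensor =
  var h = relu(x)                           # activation
  h = dropout(h, p = 0.5, training = true)  # pass training = false at inference
  result = log_softmax(h, axis = 1)         # log-probabilities over classes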

# #######################################################################
#
13 changes: 13 additions & 0 deletions proof_of_concepts/poc07_datasets.nim
@@ -0,0 +1,13 @@
import
../flambeau/raw_bindings/[
data_api, tensors
]

let mnist = mnist("build/mnist")

echo "Data"
# mnist.get(0).data.print()
# echo "\n-----------------------"
# echo "Target"
# mnist.get(0).target.print()
# echo "\n-----------------------"
