Alanperry1/mini-tensor

mini-tensor

A from-scratch C++17 tensor library with reverse-mode automatic differentiation — built like a stripped-down PyTorch/TensorFlow core. No runtime dependencies beyond the standard library (optional BLAS for matmul acceleration).

32 tests · 148 assertions · XOR MLP converges to loss < 1e-4


Table of Contents

  1. Why
  2. Features
  3. Build
  4. Quick Start
  5. Architecture Deep Dive
  6. API Reference
  7. Design Decisions
  8. Testing
  9. Project Layout
  10. License

Why

Most "build your own autograd" tutorials are ~200 lines of Python. This goes further: a full C++17 implementation with real stride-based memory layout, zero-copy views, broadcast-aware backward passes, and production-quality numerical stability. It's the kind of thing you'd write if you wanted to understand exactly what happens between loss.backward() and the weight update — at the metal level.


Features

Area What's included
Tensor N-d arrays, C-order strides, zero-copy transpose/reshape/view, offset-based slicing
Broadcasting NumPy-compatible shape inference, broadcast_to, unbroadcast for backward
Autograd Reverse-mode AD, dynamic computation graph, iterative DFS topo sort (stack-safe)
Ops (forward+backward) add, sub, mul, div, neg, pow, matmul, sum, mean, relu, sigmoid, tanh, exp, log, reshape, transpose
nn Linear (Kaiming init), mse_loss, cross_entropy_loss (numerically stable softmax)
Optimizers SGD (with momentum + weight decay), Adam (bias-corrected)
Memory Shared ref-counted storage, lazy grad allocation, detached saves — no ref cycles

Build

Requirements: CMake ≥ 3.14, C++17 compiler (GCC 9+ / Clang 10+), Linux or macOS.

# Release build (recommended)
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build -j$(nproc)

# Debug build with AddressSanitizer
cmake -B build-debug -DCMAKE_BUILD_TYPE=Debug
cmake --build build-debug -j$(nproc)

Optional BLAS acceleration: if OpenBLAS or MKL is found, matmul automatically uses cblas_sgemm; otherwise it falls back to a cache-friendly i-k-j naive loop.

# Verify BLAS was detected
cmake -B build | grep -i blas

Run tests

./build/tensort_tests
[doctest] test cases: 32 | 32 passed | 0 failed
[doctest] assertions: 148 | 148 passed | 0 failed

Run the XOR demo

./build/xor_mlp

Quick Start

#include "tensort/tensor.h"
#include "tensort/functional.h"
#include "tensort/nn.h"

using namespace tensort;
using namespace tensort::nn;

int main() {
    // --- Tensor creation ---
    Tensor a = Tensor::randn({3, 4}, /*requires_grad=*/true);
    Tensor b = Tensor::ones({4, 2});

    // --- Forward pass (builds autograd graph implicitly) ---
    auto c = relu(matmul(a, b));   // shape {3, 2}
    auto loss = mean(c);            // scalar

    // --- Backward pass ---
    loss.backward();

    // --- Inspect gradient ---
    // a.grad()->shape() == {3, 4}

    // --- XOR MLP: 2 → 16 → 1 ---
    Linear fc1(2, 16), fc2(16, 1);

    auto p1 = fc1.parameters(), p2 = fc2.parameters();
    std::vector<Tensor*> params(p1.begin(), p1.end());
    params.insert(params.end(), p2.begin(), p2.end());

    Adam opt(params, 0.01f);

    Tensor X = Tensor::from_data({0,0, 0,1, 1,0, 1,1}, {4, 2});
    Tensor Y = Tensor::from_data({0, 1, 1, 0},          {4, 1});

    for (int epoch = 0; epoch < 2000; ++epoch) {
        opt.zero_grad();

        auto h    = relu(fc1(X));
        auto pred = sigmoid(fc2(h));
        auto loss = mse_loss(pred, Y);

        loss.backward();
        opt.step();
    }
    // Final loss < 1e-4
}

Architecture Deep Dive

Memory Model

Every Tensor is a lightweight handle over a ref-counted storage buffer:

Tensor
 ├── storage_   shared_ptr<vector<float>>   ← actual data (shared between views)
 ├── offset_    size_t                       ← byte start (for slicing)
 ├── shape_     vector<size_t>               ← logical dimensions
 └── strides_   vector<size_t>               ← element steps per dimension

Zero-copy operations: transpose, reshape, unsqueeze, squeeze all return new Tensor objects that point into the same storage_. No heap allocation, no data copy. Only make_contiguous() triggers a copy — and only when the layout is actually non-contiguous.

Flat index formula: element at multi-index [i₀, i₁, ..., iₙ] lives at storage[offset + Σ iₖ * stridesₖ].

Tensor a({1,2,3,4,5,6}, {2,3});  // row-major: strides = {3, 1}
Tensor t = a.t();                // transpose: strides = {1, 3}, shape = {3,2}
// t and a share the same vector<float> — zero bytes copied
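The formula can be sketched as one dot product of indices and strides (flat_index is a hypothetical free function for illustration, not the library's API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative helper: element [i0, i1, ..., in] lives at
// storage[offset + sum(idx[k] * strides[k])].
std::size_t flat_index(std::size_t offset,
                       const std::vector<std::size_t>& idx,
                       const std::vector<std::size_t>& strides) {
    std::size_t flat = offset;
    for (std::size_t k = 0; k < idx.size(); ++k)
        flat += idx[k] * strides[k];
    return flat;
}
```

With the {2,3} example above, flat_index(0, {1,2}, {3,1}) and flat_index(0, {2,1}, {1,3}) resolve to the same storage slot, which is exactly why the transpose needs no copy.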

GradHolder: leaf tensors with requires_grad=true eagerly allocate a shared_ptr<GradHolder>. When you copy a Tensor (Tensor y = x), both copies share the same GradHolder, so gradients accumulate into a single slot regardless of how many aliases exist. This matches Python reference semantics without garbage collection.
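The sharing behaviour can be sketched in a few lines (GradSlot and Handle are stand-in names for illustration, not the library's types):

```cpp
#include <cassert>
#include <memory>

// Stand-ins for GradHolder and Tensor: copying the handle copies the
// shared_ptr, so every alias accumulates into the same grad slot.
struct GradSlot { float grad = 0.0f; };

struct Handle {
    std::shared_ptr<GradSlot> grad_holder = std::make_shared<GradSlot>();
};
```

Handle y = x copies only the pointer, so a gradient written through y is visible through x, the C++ analogue of Python's reference semantics.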


Autograd Engine

The engine is a textbook reverse-mode AD implementation, built on three concepts:

1. GradFn — backward nodes

Every forward op that produces a grad-tracked output allocates a GradFn subclass on the heap:

AddBackward   : saves input shapes for unbroadcast
MulBackward   : saves both inputs (needed for cross-terms)
MatMulBackward: saves both inputs (a, b detached)
SigmoidBackward: saves output (σ(x), used in the σ·(1−σ) backward)
...

Each GradFn holds:

  • next_edges — links to parent GradFns (or AccumulateGrad for leaves)
  • saved_tensors — detached copies of data needed for the backward formula; no shared_ptr to outputs (prevents reference cycles)

2. Graph construction (implicit, dynamic)

The graph is built lazily during the forward pass. Each op calls get_grad_fn(input) which either:

  • Returns input.grad_fn() if input is a non-leaf (already in the graph)
  • Creates (or returns cached) AccumulateGrad if input is a leaf requiring grad

The returned GradFn becomes an Edge in the new node's next_edges.

x ─[AccGrad]─┐
              ├─ AddBackward ─── MulBackward ─── SumBackward ──▶ loss
y ─[AccGrad]─┘

3. Backward pass — run_backward()

1. topo_sort(loss.grad_fn)   → DFS post-order, reversed = root-first order
2. grad_map[root] = ones     → seed gradient
3. For each GradFn in order:
     grads_out = fn.apply(grad_map[fn])
     For each edge → accumulate grads_out[i] into grad_map[edge.function]
4. AccumulateGrad.apply()    → writes final grad into leaf's GradHolder

Key properties:

  • Iterative DFS — uses an explicit stack, not system recursion. Safe for graphs thousands of nodes deep.
  • Gradient accumulation in grad_map — handles diamond-shaped graphs correctly (multiple paths to same node merged before that node's backward runs).
  • Single pass — each GradFn is visited exactly once.

Broadcasting

Broadcasting follows NumPy rules: shapes are right-aligned and dimensions are compatible if equal or one of them is 1.

{3, 1, 4} + {2, 4}  →  {3, 2, 4}

Forward: broadcast_to(t, target_shape) returns a strided view — dimensions that were size-1 get stride 0, so the same element is read repeatedly with zero memory overhead.

Backward: unbroadcast(grad, original_shape) sums the gradient over every dimension that was broadcast. This is the mathematically correct inverse: if a scalar was broadcast to a {3,4} tensor, its gradient is the sum of all 12 incoming grad elements.
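The shape-inference half of those rules fits in one function (an illustrative sketch with an assumed signature, not the library's broadcast implementation):

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// NumPy-style shape inference: right-align the two shapes; each pair of
// dims must be equal, or one of them must be 1 (and is stretched).
std::vector<std::size_t> broadcast_shape(std::vector<std::size_t> a,
                                         std::vector<std::size_t> b) {
    if (a.size() < b.size()) a.swap(b);       // make `a` the longer shape
    std::vector<std::size_t> out = a;
    std::size_t off = a.size() - b.size();    // right-alignment offset
    for (std::size_t i = 0; i < b.size(); ++i) {
        std::size_t da = a[off + i], db = b[i];
        if (da == db || db == 1)  out[off + i] = da;
        else if (da == 1)         out[off + i] = db;
        else throw std::runtime_error("shapes not broadcastable");
    }
    return out;
}
```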


Neural Network Module

Layers implement the Module interface:

class Module {
    virtual Tensor forward(const Tensor& input) = 0;
    virtual vector<Tensor*> parameters()        = 0;
    Tensor operator()(const Tensor& input);     // calls forward
    void   zero_grad();
};

Linear layer — y = x @ W.T + b

  • Weights initialized with Kaiming/He init: W ~ N(0, sqrt(2 / fan_in))
  • Bias initialized to zeros
  • parameters() returns {&weight, &bias} — raw pointers into the layer, so optimizers mutate in place
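A sketch of that initialization (illustrative names and seeding; the layer's actual RNG plumbing may differ):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Kaiming/He normal init: weights drawn from N(0, std^2) with
// std = sqrt(2 / fan_in), suitable for ReLU networks.
std::vector<float> kaiming_normal(std::size_t fan_in, std::size_t fan_out,
                                  unsigned seed = 42) {
    std::mt19937 gen(seed);
    std::normal_distribution<float> dist(0.0f, std::sqrt(2.0f / fan_in));
    std::vector<float> w(fan_in * fan_out);
    for (float& x : w) x = dist(gen);
    return w;
}
```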

Cross-entropy loss — implemented with a fused CrossEntropyBackward that saves the softmax output. Backward is (softmax - one_hot) / batch_size — numerically stable and avoids double-computing softmax.
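The fused backward formula is short enough to sketch over flat row-major buffers (an illustration of the math, not the library's CrossEntropyBackward class):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Given the saved softmax output p (shape {B, C}, row-major) and integer
// class targets, the gradient w.r.t. the logits is (p - one_hot) / B.
std::vector<float> ce_backward(const std::vector<float>& softmax,
                               const std::vector<std::size_t>& target,
                               std::size_t B, std::size_t C) {
    std::vector<float> grad = softmax;
    for (std::size_t b = 0; b < B; ++b)
        grad[b * C + target[b]] -= 1.0f;          // subtract the one-hot
    for (float& g : grad) g /= static_cast<float>(B);
    return grad;
}
```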

Adam optimizer — standard bias-corrected update:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t$$ $$v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2$$ $$\hat{m} = m_t / (1 - \beta_1^t), \quad \hat{v} = v_t / (1 - \beta_2^t)$$ $$\theta \leftarrow \theta - \alpha \cdot \hat{m} / (\sqrt{\hat{v}} + \epsilon)$$
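Translated directly into code over a flat parameter array (a sketch of the equations above, not the library's Adam class; t starts at 1):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One bias-corrected Adam update, matching the equations above.
void adam_step(std::vector<float>& theta, const std::vector<float>& grad,
               std::vector<float>& m, std::vector<float>& v, int t,
               float lr = 1e-3f, float b1 = 0.9f, float b2 = 0.999f,
               float eps = 1e-8f) {
    float c1 = 1.0f - std::pow(b1, t);   // bias-correction denominators
    float c2 = 1.0f - std::pow(b2, t);
    for (std::size_t i = 0; i < theta.size(); ++i) {
        m[i] = b1 * m[i] + (1 - b1) * grad[i];             // first moment
        v[i] = b2 * v[i] + (1 - b2) * grad[i] * grad[i];   // second moment
        float m_hat = m[i] / c1;
        float v_hat = v[i] / c2;
        theta[i] -= lr * m_hat / (std::sqrt(v_hat) + eps);
    }
}
```

On the first step the bias correction makes m_hat equal the raw gradient, so the update magnitude is roughly lr regardless of the gradient's scale.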


API Reference

Tensor API

// Construction
Tensor::zeros({M, N}, requires_grad)
Tensor::ones({M, N})
Tensor::randn({M, N}, requires_grad)   // ~ N(0,1)
Tensor::full({M, N}, value)
Tensor::from_data({1,2,3,4}, {2,2})
Tensor(vector<float> data, Shape shape, requires_grad)

// Shape
t.shape()          // vector<size_t>
t.ndim()           // number of dimensions
t.numel()          // total elements
t.reshape({M, N})  // new shape, view if contiguous
t.t()              // 2D transpose, zero-copy view
t.transpose(d0,d1) // swap any two dims, zero-copy
t.unsqueeze(dim)   // insert size-1 dim
t.squeeze(dim)     // remove size-1 dim
t.make_contiguous()// copy iff non-contiguous

// Data access
t.data_ptr()       // float* to first element
t.item()           // scalar value (asserts numel==1)
t.item(flat_idx)   // flat index access (contiguous only)
t.at({i, j})       // strided multi-index access

// Autograd
t.requires_grad()
t.grad()           // shared_ptr<Tensor>, nullptr if no grad yet
t.zero_grad()      // reset gradient
t.backward()       // trigger backward from scalar loss
t.grad_fn()        // GradFn of non-leaf tensors

Ops & Functional

// Elementwise binary (broadcasting supported)
add(a, b)   a + b
sub(a, b)   a - b
mul(a, b)   a * b
div(a, b)   a / b

// Also works with float scalars:
a + 2.0f,   3.0f * a,   a / 255.0f

// Unary
neg(a)            // -a
pow(a, exponent)  // a^n elementwise
relu(a)           // max(0, x)
sigmoid(a)        // 1 / (1 + e^-x)
tanh_(a)          // tanh(x)
exp_(a)           // e^x
log_(a)           // ln(x)

// Matrix multiplication
matmul(a, b)  // {M,K} @ {K,N} → {M,N}

// Reductions
sum(a)                    // total sum → scalar
sum(a, /*axis=*/1)        // sum along axis 1
sum(a, 1, /*keepdim=*/true)
mean(a)
mean(a, axis, keepdim)

nn Module

// Layers
nn::Linear fc(in_features, out_features, /*bias=*/true);
Tensor out = fc(input);          // operator() calls forward
Tensor out = fc.forward(input);
vector<Tensor*> p = fc.parameters(); // {&weight, &bias}

// Losses
Tensor loss = nn::mse_loss(prediction, target);
Tensor loss = nn::cross_entropy_loss(logits, {0, 2, 1}); // class indices

// Optimizers
nn::SGD  opt(params, lr, momentum=0.0f, weight_decay=0.0f);
nn::Adam opt(params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8);
opt.zero_grad();
opt.step();

Design Decisions

Why shared_ptr<vector<float>> and not raw pointers?
Views (transpose, reshape) and the grad system both need shared ownership. shared_ptr gives automatic lifetime management with no GC. The cost is one atomic ref-count per operation — negligible at float32 tensor granularity.

Why iterative DFS in topo_sort?
Recursive DFS blows the call stack on deep networks (transformers, long RNNs). The iterative version uses an explicit Frame stack and is equivalent but stack-safe.
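A minimal version of that iterative scheme (Node and Frame here are illustrative stand-ins for GradFn and the engine's frame type):

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Iterative post-order DFS over a DAG of backward nodes.
struct Node { std::vector<Node*> next_edges; };

std::vector<Node*> topo_sort(Node* root) {
    std::vector<Node*> order;                 // post-order result
    std::unordered_set<Node*> visited;
    // Frame: a node plus the index of its next unvisited child.
    struct Frame { Node* node; std::size_t child; };
    std::vector<Frame> stack;
    stack.push_back({root, 0});
    visited.insert(root);
    while (!stack.empty()) {
        Frame& top = stack.back();
        if (top.child < top.node->next_edges.size()) {
            Node* next = top.node->next_edges[top.child++];
            if (next && visited.insert(next).second)
                stack.push_back({next, 0});
        } else {
            order.push_back(top.node);        // all children done: emit
            stack.pop_back();
        }
    }
    // Reversed post-order = root-first order for the backward pass.
    return order;
}
```

A diamond graph (two paths from root to the same leaf) comes out with the shared node emitted first, so its gradient is fully accumulated before its own backward runs.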

Why GradHolder instead of a raw Tensor* in AccumulateGrad?
Tensor y = x in C++ creates a copy. If AccumulateGrad stored &x, then gradients destined for x (via y's backward path) would write to a dangling pointer after y is overwritten. GradHolder is a shared_ptr allocated at leaf construction time and copied by value semantics — all aliases share the same grad slot.

Why detached copies in saved_tensors?
If saved_tensors held the original output Tensor (which holds a shared_ptr<GradFn>), and GradFn held the output, you'd have a cycle → memory leak. Detached copies break the cycle: they carry data but no grad tracking.

Why float not double?
Same choice as PyTorch's default: float32 hits the GPU SIMD width, is 2× smaller than float64, and is sufficient for everything in the nn module. The numerical grad check uses double precision for the finite-difference division to avoid cancellation error.

Why naive matmul when no BLAS?
The i-k-j loop order keeps B[k][j] and C[i][j] in L1 cache across the inner j-loop (row-major access), which is ~3–5× faster than the naive i-j-k order. For production you'd link BLAS — this is detected automatically.
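The fallback loop order looks like this (an illustrative sketch over flat row-major buffers; assumes C is zero-initialized):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// i-k-j matmul for C = A @ B, row-major: A is {M,K}, B is {K,N}.
// The inner j-loop sweeps B[k][*] and C[i][*] contiguously (L1-friendly),
// unlike the naive i-j-k order which strides down B's columns.
void matmul_ikj(const std::vector<float>& A, const std::vector<float>& B,
                std::vector<float>& C, std::size_t M, std::size_t K,
                std::size_t N) {
    for (std::size_t i = 0; i < M; ++i)
        for (std::size_t k = 0; k < K; ++k) {
            float a = A[i * K + k];           // loaded once per inner loop
            for (std::size_t j = 0; j < N; ++j)
                C[i * N + j] += a * B[k * N + j];
        }
}
```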


Testing

Tests use doctest (fetched at configure time, no manual install needed).

tests/test_ops.cpp      — 16 numerical gradient checks
tests/test_autograd.cpp — 10 graph correctness tests
tests/test_nn.cpp       — 6 integration tests

Numerical gradient check — for every differentiable op, the analytic gradient is verified against the central-difference approximation:

$$\frac{\partial L}{\partial x_i} \approx \frac{f(x + \epsilon e_i) - f(x - \epsilon e_i)}{2\epsilon}$$

Uses absolute + relative tolerance (atol + rtol × |numerical|) to handle large-magnitude outputs (e.g. matmul on values 7–12 where float32 cancellation gives ~0.01 error at eps=1e-3).
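A stripped-down version of that check (double precision throughout, with assumed default tolerances; the library's grad_check differs in detail):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Central-difference gradient check for a scalar function f : R^n -> R.
// Passes iff every analytic partial matches the numerical estimate
// within atol + rtol * |numerical|.
bool grad_check(const std::function<double(const std::vector<double>&)>& f,
                std::vector<double> x, const std::vector<double>& analytic,
                double eps = 1e-5, double atol = 1e-6, double rtol = 1e-4) {
    for (std::size_t i = 0; i < x.size(); ++i) {
        double orig = x[i];
        x[i] = orig + eps; double fp = f(x);   // f(x + eps * e_i)
        x[i] = orig - eps; double fm = f(x);   // f(x - eps * e_i)
        x[i] = orig;                           // restore the perturbed entry
        double numerical = (fp - fm) / (2.0 * eps);
        if (std::abs(analytic[i] - numerical) >
            atol + rtol * std::abs(numerical))
            return false;
    }
    return true;
}
```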

Autograd correctness tests include: chain rule, diamond graph, ReLU dead neurons, sigmoid saturation, matmul chain, multi-path gradient accumulation, and a long chain (x added to itself 10 times, so dy/dx = 10).

XOR convergence — the hardest integration test. A 2→16→1 MLP with ReLU + Sigmoid trained for 2000 epochs with Adam must reach loss < 0.02. Typically converges to < 1e-4.

Run with full verbosity:

./build/tensort_tests -s          # show all assertions
./build/tensort_tests -tc="grad*" # filter by name

Project Layout

tensort-implementation/
├── CMakeLists.txt             build config, FetchContent doctest, BLAS detection
├── include/tensort/
│   ├── tensor.h               Tensor class, Shape/Strides, broadcasting utilities
│   ├── autograd.h             GradFn, GradHolder, AccumulateGrad, Edge, engine API
│   ├── ops.h                  GradFn subclass declarations (one per op)
│   ├── functional.h           Forward op free functions + operator overloads
│   └── nn.h                   Module, Linear, losses, SGD, Adam
├── src/
│   ├── tensor.cpp             Constructors, factories, shape ops, broadcast impl
│   ├── autograd.cpp           AccumulateGrad::apply, topo_sort, run_backward
│   ├── ops.cpp                All backward formulas
│   ├── functional.cpp         All forward ops + autograd graph construction
│   └── nn.cpp                 Linear, losses, optimizers
├── tests/
│   ├── test_ops.cpp           grad_check utility + 16 op tests
│   ├── test_autograd.cpp      10 graph behavior tests
│   └── test_nn.cpp            6 nn/optimizer tests
└── examples/
    └── xor_mlp.cpp            Full training loop demo

License

MIT
