In [49]:
import benchmark
from memory import memset_zero
from random import rand

In [50]:
alias type = DType.float32

struct Matrix[rows: Int, cols: Int]:
    var data: DTypePointer[type]

    # Initialize zeroeing all values
    fn __init__(inout self):
        self.data = DTypePointer[type].alloc(rows * cols)
        memset_zero(self.data, rows * cols)

    # Initialize taking a pointer, don't set any elements
    fn __init__(inout self, data: DTypePointer[type]):
        self.data = data

    # Initialize with random values
    @staticmethod
    fn rand() -> Self:
        var data = DTypePointer[type].alloc(rows * cols)
        rand(data, rows * cols)
        return Self(data)

    fn __getitem__(self, y: Int, x: Int) -> Scalar[type]:
        return self.load[1](y, x)

    fn __setitem__(self, y: Int, x: Int, val: Scalar[type]):
        self.store[1](y, x, val)

    fn load[nelts: Int](self, y: Int, x: Int) -> SIMD[type, nelts]:
        return self.data.load[width=nelts](y * self.cols + x)

    fn store[nelts: Int](self, y: Int, x: Int, val: SIMD[type, nelts]):
        return self.data.store[width=nelts](y * self.cols + x, val)
    fn __del__(owned self):
        self.data.free()
    fn __copyinit__(inout self, existing: Self):
        self.data = DTypePointer[type].alloc(rows * cols)
        memcpy[rows * cols](self.data, existing.data)
        

In [51]:
import math

In [52]:

struct Column[nelts: Int]:
    alias Col = Matrix[nelts, 1]
    var elements: Self.Col
    fn __init__(inout self):
        self.elements.__init__()
    fn __init__(inout self, matrix: Matrix[nelts, 1]):
        self.elements = matrix
    @staticmethod
    fn rand() -> Self:
        return Self.Col.rand()
    fn __getitem__(self, i: Int) -> Scalar[type]:
        return self.elements[i, 0]
    fn __setitem__(inout self, i: Int, val: Scalar[type]):
        self.elements[i, 0] = val
    fn __copyinit__(inout self, existing: Self):
        self.elements = existing.elements
    fn softmax(self) -> Self:
        var x_max: Scalar[type]  = math.limit.neginf[type]()
        for i in range(nelts):
            var x = self[i]
            if x > x_max:
                x_max = x
        var d: Scalar[type] = 0
        for i in range(nelts):
            var x = self[i]
            d += math.exp(x - x_max)
        var probs = Self()
        for i in range(nelts):
            var x = self[i]
            probs[i] = math.exp(x - x_max) / d
        return probs 
    fn softmax_online(self) -> Self:
        var m = math.limit.neginf[type]()
        var d: Scalar[type] = 0
        for i in range(nelts):
            var x = self[i]
            var m_prev = m
            if x > m:
                m = x
            d = d * math.exp(m_prev - m) + math.exp(x - m) 
        var probs = Self()
        for i in range(nelts):
            var x = self[i]
            probs[i] = math.exp(x - m) / d
        return probs
    fn __eq__(self, other: Self) -> Bool:
        for i in range(nelts):
            if self[i] != other[i]:
                return False
        return True



In [53]:
var logits = Column[10].rand()
print("logits", logits[0], logits[1], logits[2], logits[3])
var probs = logits.softmax()
print("probabilities\t", probs[0], probs[1], probs[2], probs[3])
var probs_online = logits.softmax_online()
var s: Scalar[type] = 0.0
for i in range(10):
    s += probs[i]
print("expected sum=1", "actual=", s)
print(probs == probs_online)

logits 0.98534935712814331 0.082506522536277771 0.42917641997337341 0.36413341760635376
probabilities	 0.15004785358905792 0.06083172932267189 0.086037337779998779 0.08061932772397995
expected sum=1 actual= 1.0
True


In [62]:
from benchmark import Unit
from benchmark.compiler import keep
alias nelts = 100000000
fn bench[func: fn (Column[nelts]) -> None, name: StringLiteral]():
    var logits = Column[nelts].rand()
    @always_inline
    @parameter
    fn test_fn():
        var probs = func(logits)
        keep(probs)
    
    var report = benchmark.run[test_fn]()
    report.print(Unit.ms)

In [63]:
fn naive(logits: Column[nelts]):
    _ = logits.softmax()
fn online(logits: Column[nelts]):
    _ = logits.softmax_online()
bench[naive, "naive"]()
bench[online, "online normalization"]()

---------------------
Benchmark Report (ms)
---------------------
Mean: 1.7e-14
Total: 1.7e-05
Iters: 1000000000
Warmup Mean: 1.5999999999999999e-05
Warmup Total: 3.1999999999999999e-05
Warmup Iters: 2
Fastest Mean: 1.7e-14
Slowest Mean: 1.7e-14

---------------------
Benchmark Report (ms)
---------------------
Mean: 1.7999999999999999e-14
Total: 1.8e-05
Iters: 1000000000
Warmup Mean: 1.5500000000000001e-05
Warmup Total: 3.1000000000000001e-05
Warmup Iters: 2
Fastest Mean: 1.7999999999999999e-14
Slowest Mean: 1.7999999999999999e-14

