In [1]:
import benchmark
from memory import memcpy, memset_zero
from random import rand

In [2]:
# Element type shared by every matrix and column in this file.
alias type = DType.float32

# Dense rows x cols matrix of `type` elements kept in a single heap buffer.
# Element (y, x) lives at offset y * cols + x. The matrix owns its buffer:
# it is released in __del__, duplicated on copy and handed over on move.
struct Matrix[rows: Int, cols: Int]:
    var data: DTypePointer[type]

    # Initialize zeroing all values
    fn __init__(inout self):
        self.data = DTypePointer[type].alloc(rows * cols)
        memset_zero(self.data, rows * cols)

    # Initialize taking a pointer, don't set any elements.
    # The matrix takes ownership of `data` (it is freed in __del__), so the
    # caller must not free it or hand it to a second matrix.
    fn __init__(inout self, data: DTypePointer[type]):
        self.data = data

    # Copy: allocate a fresh buffer and duplicate all rows * cols elements
    fn __copyinit__(inout self, existing: Self):
        self.data = DTypePointer[type].alloc(rows * cols)
        memcpy[rows * cols](self.data, existing.data)

    # Move: take over the existing buffer instead of copying it, so that
    # transferring or returning a matrix does not allocate
    fn __moveinit__(inout self, owned existing: Self):
        self.data = existing.data

    # Release the owned buffer
    fn __del__(owned self):
        self.data.free()

    # Initialize with random values
    @staticmethod
    fn rand() -> Self:
        var data = DTypePointer[type].alloc(rows * cols)
        rand(data, rows * cols)
        return Self(data)

    # Read the single element at row y, column x
    fn __getitem__(self, y: Int, x: Int) -> Scalar[type]:
        return self.load[1](y, x)

    # Write the single element at row y, column x
    fn __setitem__(self, y: Int, x: Int, val: Scalar[type]):
        self.store[1](y, x, val)

    # Read nelts consecutive elements starting at row y, column x
    fn load[nelts: Int](self, y: Int, x: Int) -> SIMD[type, nelts]:
        return self.data.load[width=nelts](y * self.cols + x)

    # Write nelts consecutive elements starting at row y, column x
    fn store[nelts: Int](self, y: Int, x: Int, val: SIMD[type, nelts]):
        self.data.store[width=nelts](y * self.cols + x, val)
        

In [3]:
import math

In [4]:

# Column vector of `nelts` values of `type`, stored as a Matrix[nelts, 1].
# Offers three softmax variants (three-pass, three-pass with unrolled loops,
# and online normalization); each returns a new Column and leaves self as is.
struct Column[nelts: Int]:
    alias Col = Matrix[nelts, 1]
    var elements: Self.Col

    # Initialize with a default-constructed (zero-filled) matrix
    fn __init__(inout self):
        self.elements = Self.Col()

    # Initialize from a nelts x 1 matrix; the matrix is copied
    fn __init__(inout self, matrix: Matrix[nelts, 1]):
        self.elements = matrix

    # Copy: duplicates the underlying matrix
    fn __copyinit__(inout self, existing: Self):
        self.elements = existing.elements

    # Move: transfer the underlying matrix rather than duplicating it when a
    # Column is moved or returned (every softmax method returns a local)
    fn __moveinit__(inout self, owned existing: Self):
        self.elements = existing.elements^

    # Initialize with random values
    @staticmethod
    fn rand() -> Self:
        return Self(Self.Col.rand())

    # Read element i
    fn __getitem__(self, i: Int) -> Scalar[type]:
        return self.elements[i, 0]

    # Write element i
    fn __setitem__(inout self, i: Int, val: Scalar[type]):
        self.elements[i, 0] = val

    # Three-pass softmax, identical to softmax() except that each loop is
    # annotated with @unroll(20)
    fn softmax_unroll(self) -> Self:
        # Pass 1: maximum element
        var x_max: Scalar[type] = math.limit.neginf[type]()
        @unroll(20)
        for i in range(nelts):
            var x = self[i]
            if x > x_max:
                x_max = x
        # Pass 2: normalizer d = sum(exp(x - x_max))
        var d: Scalar[type] = 0
        @unroll(20)
        for i in range(nelts):
            var x = self[i]
            d += math.exp(x - x_max)
        # Pass 3: probabilities exp(x - x_max) / d
        var probs = Self()
        @unroll(20)
        for i in range(nelts):
            var x = self[i]
            probs[i] = math.exp(x - x_max) / d
        return probs

    # Three-pass softmax: find the maximum, sum exp(x - max), then divide.
    # Subtracting the maximum keeps every exponent <= 0.
    fn softmax(self) -> Self:
        # Pass 1: maximum element
        var x_max: Scalar[type] = math.limit.neginf[type]()
        for i in range(nelts):
            var x = self[i]
            if x > x_max:
                x_max = x
        # Pass 2: normalizer d = sum(exp(x - x_max))
        var d: Scalar[type] = 0
        for i in range(nelts):
            var x = self[i]
            d += math.exp(x - x_max)
        # Pass 3: probabilities exp(x - x_max) / d
        var probs = Self()
        for i in range(nelts):
            var x = self[i]
            probs[i] = math.exp(x - x_max) / d
        return probs

    # Two-pass softmax with online normalization: the running maximum m and
    # the normalizer d are updated together in a single pass over the input.
    fn softmax_online(self) -> Self:
        var m = math.limit.neginf[type]()
        var d: Scalar[type] = 0
        for i in range(nelts):
            var x = self[i]
            var m_prev = m
            if x > m:
                m = x
            # Rescale the sum accumulated under the previous maximum to the
            # current one, then add the new term
            d = d * math.exp(m_prev - m) + math.exp(x - m)
        var probs = Self()
        for i in range(nelts):
            var x = self[i]
            probs[i] = math.exp(x - m) / d
        return probs

    # Element-wise comparison using exact floating-point equality; columns
    # that differ only by rounding error compare as not equal
    fn __eq__(self, other: Self) -> Bool:
        for i in range(nelts):
            if self[i] != other[i]:
                return False
        return True



In [5]:
# Draw ten random logits and run all three softmax variants on them.
var logits = Column[10].rand()
var probs = logits.softmax()
var probs_unrolled = logits.softmax_unroll()
var probs_online = logits.softmax_online()

# Show the first four inputs and their probabilities.
print("logits", logits[0], logits[1], logits[2], logits[3])
print("probabilities\t", probs[0], probs[1], probs[2], probs[3])

# Sanity check: accumulate the probabilities and print the total.
var s: Scalar[type] = 0.0
for i in range(10):
    s += probs[i]
print("expected sum=1", "actual=", s)

# Compare the variants against the three-pass result; Column.__eq__ uses
# exact float equality.
print(probs == probs_online)
print(probs == probs_unrolled)

logits 0.1315377950668335 0.458650141954422 0.21895918250083923 0.67886471748352051
probabilities	 0.076156176626682281 0.10562536865472794 0.083113528788089752 0.1316455602645874
expected sum=1 actual= 1.0
False
True


In [6]:
from benchmark import Unit
from benchmark.compiler import keep
# Number of elements in the benchmarked column.
alias nelts = 10000000

# Benchmark `func` on one random Column[nelts] and print the report in
# milliseconds, preceded by `name` so consecutive reports can be told apart.
#
# Parameters:
#   func: function under test; receives the column and returns nothing.
#   name: label printed on the line before the report.
fn bench[func: fn (Column[nelts]) -> None, name: StringLiteral]():
    # The input is created once, outside the timed function.
    var logits = Column[nelts].rand()

    @always_inline
    @parameter
    fn test_fn():
        # func returns None, so there is no result to bind or keep here.
        func(logits)

    var report = benchmark.run[test_fn]()
    print(name)
    report.print(Unit.ms)

In [7]:
# Benchmark case: three-pass softmax.
fn naive(logits: Column[nelts]):
    var result = logits.softmax()
    # Hand the result to keep(); presumably so the computation is not
    # optimized away - confirm against the benchmark docs.
    keep(result)
# Benchmark case: three-pass softmax with unrolled loops.
fn naive_unrolled(logits: Column[nelts]):
    var result = logits.softmax_unroll()
    # Hand the result to keep(); presumably so the computation is not
    # optimized away - confirm against the benchmark docs.
    keep(result)
# Benchmark case: softmax with online normalization.
fn online(logits: Column[nelts]):
    var result = logits.softmax_online()
    # Hand the result to keep(); presumably so the computation is not
    # optimized away - confirm against the benchmark docs.
    keep(result)
# Run the three benchmark cases; each call prints its own report, in this
# order: naive, naive unrolled, online normalization.
bench[naive, "naive"]()
bench[naive_unrolled, "naive unrolled"]()
bench[online, "online normalization"]()

---------------------
Benchmark Report (ms)
---------------------
Mean: 100.56180041666666
Total: 2413.4832099999999
Iters: 24
Warmup Mean: 99.403483499999993
Warmup Total: 198.80696699999999
Warmup Iters: 2
Fastest Mean: 100.56180041666667
Slowest Mean: 100.56180041666667

---------------------
Benchmark Report (ms)
---------------------
Mean: 92.765476359999994
Total: 2319.1369089999998
Iters: 25
Warmup Mean: 93.467702000000003
Warmup Total: 186.93540400000001
Warmup Iters: 2
Fastest Mean: 92.765476359999994
Slowest Mean: 92.765476359999994

---------------------
Benchmark Report (ms)
---------------------
Mean: 123.03981573684212
Total: 2337.7564990000001
Iters: 19
Warmup Mean: 123.4338005
Warmup Total: 246.86760100000001
Warmup Iters: 2
Fastest Mean: 1.7976931348623157e+308
Slowest Mean: 123.03981573684212

