Skip to content

Commit

Permalink
added softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielJDufour committed Nov 23, 2022
1 parent a031a8d commit e8a6e55
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 2 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ The following functions are supported:
- [round](#round)
- [root](#root)
- [sign](#sign)
- [softmax](#softmax)
- [sort](#sort)
- [square_root](#square_root)
- [subtract](#subtract)
Expand Down Expand Up @@ -300,6 +301,16 @@ sort(["1", "2", "3"], { direction: "descending" })
["3", "2", "1"]
```

### softmax
Calculate the [softmax function](https://en.wikipedia.org/wiki/Softmax_function)
```js
import softmax from "preciso/softmax.js";

// example data from https://en.wikipedia.org/wiki/Softmax_function
softmax(["1", "2", "3", "4", "1", "2", "3"], { max_decimal_digits: 8 });
["0.02364054", "0.06426166", "0.1746813", "0.474833", "0.02364054", "0.06426166", "0.1746813"]
```

### square root
```js
import square_root from "preciso/square_root.js";
Expand Down
11 changes: 11 additions & 0 deletions build_e.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
const start = performance.now();
console.log("[preciso] building Euler's Number");

const fs = require("node:fs");
const e = require("./eulers_number.js");

const str = e({ steps: 1000, max_decimal_digits: 1000 });

fs.writeFileSync("./constants/e.js", `module.exports = { E: "${str}" };\n`);

console.log("[preciso] building Euler's Number took " + Math.round((performance.now() - start) / 1000) + " seconds");
1 change: 1 addition & 0 deletions constants/e.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions exp.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ const is_zero = require("./is_zero.js");
const eulers_number = require("./eulers_number.js");
const pow = require("./pow.js");

function exp(power, { max_decimal_digits } = {}) {
const e = eulers_number(100);
function exp(power, { max_decimal_digits = 100 } = {}) {
const e = eulers_number({ max_decimal_digits: 2 * max_decimal_digits });

if (is_negative_infinity(power)) return "0";
if (is_positive_infinity(power)) return "Infinity";
Expand Down
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"round_last_decimal.js",
"sign.js",
"sign_nonzero.js",
"softmax.js",
"sort.js",
"square_root.js",
"subtract.js",
Expand Down
3 changes: 3 additions & 0 deletions preciso.js
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ const root = require("./root.js");
const root_integer_digits = require("./root_integer_digits.js");
const round_last_decimal = require("./round_last_decimal.js");

const softmax = require("./softmax.js");

const square_root = require("./square_root.js");
const subtract = require("./subtract.js");
const long_subtraction = require("./long_subtraction.js");
Expand Down Expand Up @@ -128,6 +130,7 @@ const module_exports = {
sign,
sign_nonzero,

softmax,
sort,

sum,
Expand Down
16 changes: 16 additions & 0 deletions softmax.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"use strict";

const divide = require("./divide.js");
const exp = require("./exp.js");
const sum = require("./sum.js");

function softmax(vector, { max_decimal_digits }) {
vector = vector.map(n => exp(n, { max_decimal_digits }));

const total = sum(vector);

return vector.map(n => divide(n, total, { max_decimal_digits, ellipsis: false }));
}

module.exports = softmax;
module.exports.default = softmax;
7 changes: 7 additions & 0 deletions test.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const {
round,
round_last_decimal,
sign,
softmax,
sort,
square_root,
sum,
Expand All @@ -54,6 +55,12 @@ const {

// const nthroot = (radicand, root) => Math.pow(radicand, 1 / root);

test("softmax", ({ eq }) => {
const actual = ["1", "2", "3", "4", "1", "2", "3"];
const expected = ["0.02364054", "0.06426166", "0.1746813", "0.474833", "0.02364054", "0.06426166", "0.1746813"];
eq(softmax(actual, { max_decimal_digits: 8 }), expected);
});

test("root", ({ eq }) => {
eq(root("-9", "1"), "-9");
eq(root("-9", "2"), "3i");
Expand Down

0 comments on commit e8a6e55

Please sign in to comment.