-
Notifications
You must be signed in to change notification settings - Fork 83
/
OneHot.hs
94 lines (81 loc) · 2.83 KB
/
OneHot.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
module Grenade.Utils.OneHot (
oneHot
, hotMap
, makeHot
, unHot
, sample
) where
import qualified Control.Monad.Random as MR
import Data.List ( group, sort )
import Data.Map ( Map )
import qualified Data.Map as M
import Data.Proxy
import Data.Vector ( Vector )
import qualified Data.Vector as V
import qualified Data.Vector.Storable as VS
import Numeric.LinearAlgebra ( maxIndex )
import Numeric.LinearAlgebra.Devel
import Numeric.LinearAlgebra.Static
import GHC.TypeLits
import Grenade.Core.Shape
-- | From an index, create a 1D shape of length @n@ with that one index
-- hot (1) and every other entry 0.
--
-- Returns 'Nothing' if the index is negative or is not smaller than the
-- length of the vector, so an out-of-range index is never passed to
-- 'writeVector'.
oneHot :: forall n. (KnownNat n)
       => Int -> Maybe (S ('D1 n))
oneHot hot
  | hot >= 0 && hot < len =
      fmap S1D . create $ runSTVector $ do
        vec <- newVector 0 len
        writeVector vec hot 1
        return vec
  | otherwise = Nothing
  where
    -- Length of the target vector, taken from the type-level size.
    len = fromIntegral (natVal (Proxy :: Proxy n)) :: Int
-- | Build a one-hot encoding table from a list of values.
--
-- The distinct values, in ascending order, are numbered from 0. The result
-- pairs the value-to-index 'Map' with a 'Vector' that performs the reverse
-- (index-to-value) lookup. Fails with a message if the number of distinct
-- values is not exactly the type-level size @n@.
hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Either String (Map a Int, Vector a)
hotMap n as
  | count /= expected =
      Left ("Couldn't create hotMap of size " ++ show expected ++ " from vector with " ++ show count ++ " unique characters")
  | otherwise =
      Right (M.fromList (zip distinct [0 ..]), V.fromList distinct)
  where
    expected = fromIntegral (natVal n) :: Int
    -- One representative per run of equal elements after sorting.
    distinct = [ representative | (representative : _) <- group (sort as) ]
    count    = length distinct
-- | From a map and a value, create a 1D shape of length @n@ with the
-- value's index hot (1) and every other entry 0.
--
-- Returns 'Nothing' in any of these cases:
--
-- * the value is not a key of the map;
-- * its index is negative;
-- * its index is not smaller than the length of the vector.
makeHot :: forall a n. (Ord a, KnownNat n)
        => Map a Int -> a -> Maybe (S ('D1 n))
makeHot m x = do
  hot <- M.lookup x m
  let len = fromIntegral (natVal (Proxy :: Proxy n)) :: Int
  -- The map may come from anywhere, not only 'hotMap', so check both
  -- bounds before handing the index to 'writeVector'.
  if hot >= 0 && hot < len
    then
      fmap S1D . create $ runSTVector $ do
        vec <- newVector 0 len
        writeVector vec hot 1
        return vec
    else Nothing
-- | Decode a 1D shape back into a value.
--
-- Takes the index of the largest entry and looks it up in the given
-- reverse-lookup vector (e.g. the one produced by 'hotMap'). Returns
-- 'Nothing' if that index is outside the vector.
unHot :: forall a n. KnownNat n
      => Vector a -> S ('D1 n) -> Maybe a
unHot v (S1D xs) = v V.!? hottest
  where
    hottest = maxIndex (extract xs)
-- | Randomly pick an element of the lookup vector, using the entries of a
-- 1D shape as unnormalised weights adjusted by a temperature.
--
-- Each entry @x@ gets weight @exp (log x / temperature)@, i.e.
-- @x ** (1 / temperature)@. Temperatures below 1 sharpen the distribution
-- and temperatures above 1 flatten it. Entries that are not strictly
-- positive, and any weight that evaluates to NaN, get weight 0.
--
-- NOTE(review): 'MR.fromList' requires a non-zero total weight. The final
-- lookup with 'V.!' is partial if the vector is shorter than @n@.
sample :: forall a n m. (KnownNat n, MR.MonadRandom m)
       => Double -> Vector a -> S ('D1 n) -> m a
sample temperature v (S1D xs) = do
  ix <- MR.fromList . zip [0 ..] . fmap weight . VS.toList . extract $ xs
  return $ v V.! ix
  where
    -- The log of a negative entry is NaN, and in GHC 'toRational' of NaN
    -- is a huge negative Rational that would poison the weighted choice.
    -- Such entries are therefore given zero weight.
    weight :: Double -> Rational
    weight x
      | x > 0, not (isNaN w) = toRational w
      | otherwise            = 0
      where
        w = exp (log x / temperature)