doc fixes, rmsprop learning
JPMoresmau committed Aug 4, 2015
1 parent ebab439 commit 31c0512
Showing 4 changed files with 44 additions and 11 deletions.
21 changes: 14 additions & 7 deletions exe/main.hs
@@ -27,13 +27,12 @@ main = do
rnn<-case args of
("gradient":mg:_) -> do
let maxGen = read mg
rd <- if ex
then do
r <- read <$> readFile fn
return $ RNNData r trainData (fromIntegral $ length $ tdInputs trainData)
else
evalRandIO $ buildTD trainData
rd <- readEx fn ex trainData
gradient rd generateLength maxGen
("rmsprop":mg:_) -> do
let maxGen = read mg
rd <- readEx fn ex trainData
rmsprop rd generateLength maxGen
("genetic":mg:_) -> do
let maxGen = read mg
b <- if ex
@@ -42,12 +41,20 @@
return $ buildExisting rnn1 trainData
else return $ buildTD trainData
genetic b generateLength maxGen
_ -> error "rnn gradient | generic"
_ -> error "rnn gradient <maxgen> | rmsprop <maxgen> | genetic <maxgen>"
writeFile fn $ show rnn
where
readEx fn ex trainData =
if ex
then do
r <- read <$> readFile fn
return $ RNNData r trainData (fromIntegral $ length $ tdInputs trainData)
else
evalRandIO $ buildTD trainData
(train, gener) = (textToTrainData,generate)
genStep maxGen = maxGen `div` 10
gradient (RNNData r td _) generateLength maxGen = learnGradientDescent r td $ test generateLength maxGen
rmsprop (RNNData r td _) generateLength maxGen = learnRMSProp r td $ test generateLength maxGen
genetic r generateLength maxGen = do
fitnessList <- newIORef []
(RNNData rnn _ _) <- runGAIO 64 0.1 r $ stopf2 generateLength maxGen fitnessList
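The refactoring above replaces the duplicated load-or-build logic with the readEx helper: reuse a serialized network from disk when one exists, otherwise build a fresh random one. A minimal self-contained sketch of the same cache-or-build pattern, outside the diff (loadOrBuild is a hypothetical name, and a plain Double stands in for the network):

import Control.Monad.Random (evalRandIO, getRandomR)
import System.Directory (doesFileExist)

-- Reuse a serialized value from disk when present, otherwise
-- build a fresh random one (the same shape as readEx above).
loadOrBuild :: FilePath -> IO Double
loadOrBuild fn = do
  ex <- doesFileExist fn
  if ex
    then read <$> readFile fn
    else evalRandIO (getRandomR (0, 1))

With the corrected usage string, the new mode would presumably be invoked as, e.g., rnn rmsprop 1000.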
1 change: 1 addition & 0 deletions src/AI/Network/RNN.hs
@@ -30,6 +30,7 @@ module AI.Network.RNN
, lstmFullSize
, lstmList
, learnGradientDescent
, learnRMSProp

, textToTrainData
, dataToText
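Exporting learnRMSProp makes the new trainer callable through the public module. A hedged sketch of the call shape implied by the signature later in this diff; it assumes LSTMNetwork and TrainData are both visible via AI.Network.RNN, and relies on the callback returning True to keep training:

import AI.Network.RNN

-- Hypothetical driver: run RMSProp, logging the generation and
-- stopping once maxGen generations have been reached.
trainFor :: Int -> LSTMNetwork -> TrainData a Int -> IO LSTMNetwork
trainFor maxGen net td = learnRMSProp net td report
  where
    report _ _ gen = do
      putStrLn ("rmsprop generation " ++ show gen)
      return (gen < maxGen)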
6 changes: 3 additions & 3 deletions src/AI/Network/RNN/Data.hs
@@ -32,7 +32,7 @@ import AI.Network.RNN.Util


-- | Transform text into training data
-- the encoding of each Char is "one of v": each step is a vector of n values, n being the number of characters, and each character
-- the encoding of each Char is "one of k": each step is a vector of n values, n being the number of characters, and each character
-- is encoded as one value in the vector being 1
textToTrainData :: T.Text -> TrainData (DM.Map Int Char) Int
textToTrainData t =
@@ -54,7 +54,7 @@ textToTrainData t =
toArr :: Int -> Int -> Vector Double
toArr sz idx = M.fromList $ replicate idx 0 ++ [1] ++ replicate (sz-idx-1) 0

-- | Decode the data from the one of v encoding
-- | Decode the data from the one of k encoding
dataToText :: DM.Map Int Char -> [Vector Double] -> T.Text
dataToText m = T.pack . F.toList . fmap toC
where
@@ -74,7 +74,7 @@ randDataToText m = liftM T.pack . mapM (toC . M.toList)
R.fromList m2
--return $ fst $ last $ sortBy (comparing snd) m3

-- | Generate text using one of v encoding and a given network
-- | Generate text using one of k encoding and a given network
generate :: (RandomGen g,RNNEval a sz)
=> DM.Map Int Char -- ^ Character map
-> Int -- ^ Number of characters to generate
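The "one of k" wording these doc fixes settle on is the usual one-hot encoding. A standalone list-based sketch of what toArr above computes (oneOfK is a hypothetical name; the real code builds an hmatrix Vector):

-- Index idx in an alphabet of sz characters becomes a vector of
-- sz values that is 0 everywhere except for a 1 at position idx.
oneOfK :: Int -> Int -> [Double]
oneOfK sz idx = replicate idx 0 ++ [1] ++ replicate (sz - idx - 1) 0

-- oneOfK 5 2 == [0.0,0.0,1.0,0.0,0.0]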
27 changes: 26 additions & 1 deletion src/AI/Network/RNN/LSTM.hs
@@ -98,7 +98,7 @@ cost' sz l is os lstm = let
-- two i o = zipWith (*) (map (\x->1 - x) o) (map (\x->1 - log x) i)
-- three i o= zipWith (+) (one i o) (two i o)

-- | Gradent decent learning
-- | Gradient descent learning
-- The third parameter is a callback function to monitor progress and stop the learning process if needed
learnGradientDescent :: (Monad m) => LSTMNetwork -> TrainData a Int -> (LSTMNetwork -> TrainData a Int -> Int -> m Bool) -> m LSTMNetwork
learnGradientDescent lstm td progressF = go (toList $ toVector lstm) 0
@@ -117,3 +117,28 @@ learnGradientDescent lstm td progressF = go (toList $ toVector lstm) 0
lis = map toList (tdInputs td)
los = map toList (tdOutputs td)
gf = grad (cost' (tdRecSize td) (auto le) (map (map auto) lis) (map (map auto) los))

-- | RMSProp learning, as far as I can make out
-- The third parameter is a callback function to monitor progress and stop the learning process if needed
learnRMSProp :: (Monad m) => LSTMNetwork -> TrainData a Int -> (LSTMNetwork -> TrainData a Int -> Int -> m Bool) -> m LSTMNetwork
learnRMSProp lstm td progressF = go ls0 (replicate myl 0) (replicate myl 0) (replicate myl 0) 0
where
go ls rgs rgs2 ugs gen = do
let rnn::LSTMNetwork = fromVector (tdRecSize td) (fromList ls)
cont <- progressF rnn td gen
if cont
then do
let
gs= gf ls -- gradients using AD
rgup = force $ zipWith (\rg g-> 0.95 * rg + 0.05 * g) rgs gs -- running average of the gradients
rg2up = force $ zipWith (\rg2 g-> 0.95 * rg2 + 0.05 * (g ** 2)) rgs2 gs -- running average of the squared gradients
ugup = force $ zipWith4 (\ud zg rg rg2 -> 0.9 * ud - 1e-4 * zg / sqrt(rg2 - rg ** 2 + 1e-4)) ugs gs rgup rg2up
ls2 = force $ zipWith (+) ls ugup
go ls2 rgup rg2up ugup (gen+1)
else return rnn
le = fromIntegral $ tdRecSize td
lis = map toList (tdInputs td)
los = map toList (tdOutputs td)
gf = grad (cost' (tdRecSize td) (auto le) (map (map auto) lis) (map (map auto) los))
ls0 = toList $ toVector lstm
myl= length ls0
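Read as equations, the update learnRMSProp performs appears to be the RMSProp-with-momentum variant from Graves' sequence-generation work, with decay 0.95, momentum 0.9, learning rate 10^-4, and smoothing term 10^-4; for parameters theta with gradient g_t:

\begin{aligned}
m_t &= 0.95\, m_{t-1} + 0.05\, g_t \\
n_t &= 0.95\, n_{t-1} + 0.05\, g_t^{2} \\
\Delta_t &= 0.9\, \Delta_{t-1} - \frac{10^{-4}\, g_t}{\sqrt{n_t - m_t^{2} + 10^{-4}}} \\
\theta_t &= \theta_{t-1} + \Delta_t
\end{aligned}

The n_t - m_t^2 term estimates the variance of the gradient, so steps are scaled down in directions where the gradient fluctuates heavily.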
