Adapt MNIST example to use the original LeCun data files #97

Open

wants to merge 34 commits into master

34 commits
3d39f38
Work in progress
jrp2014 Jan 25, 2020
8f6e858
WIP
jrp2014 Jan 26, 2020
109ae26
Adapt MNIST example to use original data files
jrp2014 Jan 26, 2020
5100f35
remove data files from repo
jrp2014 Jan 26, 2020
e355d7d
remove some unnecessary build-depends
jrp2014 Jan 26, 2020
0fa41f2
Reinstate import of Data.Semigroup (<>) for older ghcs
jrp2014 Jan 26, 2020
128856a
typo
jrp2014 Jan 26, 2020
c2af380
resync with original
jrp2014 Jan 27, 2020
8cea7ed
Merge branch 'master' into master
jrp2014 Jan 27, 2020
426a8a1
Convert gen-mnist example to use LeCun training data; works
jrp2014 Jan 28, 2020
081e2fa
ci typo
jrp2014 Jan 28, 2020
f104e50
ci typo
jrp2014 Jan 28, 2020
726cb2d
Cleanup in response to erikd comments
jrp2014 Jan 29, 2020
9667486
Redo mnist example using gan-mnist transformations
jrp2014 Jan 30, 2020
e5f52b1
Temporarily reduce number of mnist samples used; now runs!
jrp2014 Jan 30, 2020
a5ed003
limit samples from command line
jrp2014 Feb 1, 2020
19c3521
Work in progress
jrp2014 Jan 25, 2020
b544b93
Adapt MNIST example to use original data files
jrp2014 Jan 26, 2020
7012c5f
Reinstate import of Data.Semigroup (<>) for older ghcs
jrp2014 Jan 26, 2020
a88d710
examples: Use CPP to support multiple GHC versions
erikd Jan 26, 2020
0e1462d
Convert gen-mnist example to use LeCun training data; works
jrp2014 Jan 28, 2020
bb98d46
limit samples from command line
jrp2014 Feb 1, 2020
5b19f1e
Fix compiler warning
erikd Jan 23, 2020
165b447
examples: Use CPP to support multiple GHC versions
erikd Jan 26, 2020
d386d4a
Adapt MNIST example to use original data files
jrp2014 Jan 26, 2020
99e4cd6
Reinstate import of Data.Semigroup (<>) for older ghcs
jrp2014 Jan 26, 2020
f7b53b0
Convert gen-mnist example to use LeCun training data; works
jrp2014 Jan 28, 2020
6ab97a4
Add Iris example
jrp2014 Feb 2, 2020
517da92
Add Iris example
jrp2014 Feb 2, 2020
cc547a5
Use simpler network for MNIST example
jrp2014 Feb 2, 2020
24678b0
feedforward lint
jrp2014 Feb 2, 2020
2f724dc
Further example in MNIST comments
jrp2014 Feb 2, 2020
b0b671b
Add hlinting to integration testing
jrp2014 Feb 2, 2020
ccd2679
Suspend hlinting from Travis job
jrp2014 Feb 2, 2020
1 change: 1 addition & 0 deletions .gitignore
@@ -3,3 +3,4 @@ cabal.project.local
cabal.sandbox.config
dist/
dist-newstyle/
data
64 changes: 64 additions & 0 deletions .hlint.yaml
@@ -0,0 +1,64 @@
# HLint configuration file
# https://github.com/ndmitchell/hlint
##########################

# This file contains a template configuration file, which is typically
# placed as .hlint.yaml in the root of your project


- ignore: {name: "Redundant bracket"}
- ignore: {name: "Reduce duplication"}


# Specify additional command line arguments
#
# - arguments: [--color, --cpp-simple, -XQuasiQuotes]


# Control which extensions/flags/modules/functions can be used
#
# - extensions:
# - default: false # all extension are banned by default
# - name: [PatternGuards, ViewPatterns] # only these listed extensions can be used
# - {name: CPP, within: CrossPlatform} # CPP can only be used in a given module
#
# - flags:
# - {name: -w, within: []} # -w is allowed nowhere
#
# - modules:
# - {name: [Data.Set, Data.HashSet], as: Set} # if you import Data.Set qualified, it must be as 'Set'
# - {name: Control.Arrow, within: []} # Certain modules are banned entirely
#
# - functions:
# - {name: unsafePerformIO, within: []} # unsafePerformIO can only appear in no modules


# Add custom hints for this project
#
# Will suggest replacing "wibbleMany [myvar]" with "wibbleOne myvar"
# - error: {lhs: "wibbleMany [x]", rhs: wibbleOne x}


# Turn on hints that are off by default
#
# Ban "module X(module X) where", to require a real export list
# - warn: {name: Use explicit module export list}
#
# Replace a $ b $ c with a . b $ c
# - group: {name: dollar, enabled: true}
#
# Generalise map to fmap, ++ to <>
# - group: {name: generalise, enabled: true}


# Ignore some builtin hints
# - ignore: {name: Use let}
# - ignore: {name: Use const, within: SpecialModule} # Only within certain modules


# Define some custom infix operators
# - fixity: infixr 3 ~^#^~


# To generate a suitable file for HLint do:
# $ hlint --default > .hlint.yaml
8 changes: 7 additions & 1 deletion .travis.yml
@@ -18,6 +18,7 @@ before_install:
install:
- echo "$(ghc --version) [$(ghc --print-project-git-commit-id 2> /dev/null || echo '?')]"
- travis_retry cabal-3.0 update
# - cabal-3.0 install hlint

notifications:
email: false
@@ -27,4 +28,9 @@ script:
- cabal-3.0 configure --enable-tests
- cabal-3.0 build all
- cabal-3.0 test --test-show-details=direct -j1

# - cabal-3.0 run feedforward
# - cabal-3.0 run recurrent
- ./runMNIST.sh data -l 5000 -i 3
- ./runGAN-MNIST.sh data -i 3
- ./runIris.sh
#- hlint .
36 changes: 27 additions & 9 deletions examples/grenade-examples.cabal
@@ -15,8 +15,6 @@ source-repository head
type: git
location: https://github.com/HuwCampbell/grenade.git

library

executable feedforward
ghc-options: -Wall -threaded -O2
main-is: main/feedforward.hs
@@ -39,37 +37,57 @@ executable mnist
ghc-options: -Wall -threaded -O2
main-is: main/mnist.hs
build-depends: base
, cereal
, grenade
, attoparsec
, either
, filepath
, bytestring
, optparse-applicative >= 0.13 && < 0.16
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.21
, transformers
, semigroups
, singletons
, MonadRandom
, vector
, split
, zlib

executable iris
Collaborator

A new example is a great idea, but should probably be a separate PR from the other changes you are making.

Keeping them separate would make reviewing both PRs easier.

Author

Understood. In this case the changes are focused on making the examples run (which they now do, as you can see from the Travis output). I think I've got to a stage where the pieces are in place. If you want to improve the run scripts, that would be good. I'd be happy to flesh out some of the documentation, if that would help.

Collaborator

Not too sure if the examples should run in CI. Will have a look.

ghc-options: -Wall -threaded -O2
main-is: main/iris.hs
build-depends: base
, MonadRandom
, bytestring
, cassava
, filepath
, grenade
, hmatrix >= 0.18 && < 0.21
, mtl >= 2.2.1 && < 2.3
, optparse-applicative >= 0.13 && < 0.16
, random-shuffle
, semigroups
, singletons
, transformers
, vector

executable gan-mnist
ghc-options: -Wall -threaded -O2
main-is: main/gan-mnist.hs
build-depends: base
, cereal
, grenade
, attoparsec
, filepath
, bytestring
, cereal
, either
, optparse-applicative >= 0.13 && < 0.16
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.21
, transformers
, semigroups
, singletons
, split
, MonadRandom
, vector
, zlib

executable recurrent
ghc-options: -Wall -threaded -O2
4 changes: 1 addition & 3 deletions examples/main/feedforward.hs
@@ -91,9 +91,7 @@ feedForward' =
main :: IO ()
main = do
FeedForwardOpts examples rate load save <- execParser (info (feedForward' <**> helper) idm)
net0 <- case load of
Just loadFile -> netLoad loadFile
Nothing -> randomNet
net0 <- maybe randomNet netLoad load

net <- netTrain net0 rate examples
netScore net
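The feedforward.hs change above simply collapses an explicit case on the Maybe FilePath into the maybe combinator. A tiny standalone sketch of the same pattern; loadGreeting and defaultGreeting are hypothetical stand-ins for netLoad and randomNet, not code from this PR:

-- Hypothetical stand-ins: loadGreeting plays the role of netLoad,
-- defaultGreeting the role of randomNet.
loadGreeting :: FilePath -> IO String
loadGreeting = readFile

defaultGreeting :: IO String
defaultGreeting = pure "hello"

-- maybe fallback loader mpath is equivalent to:
--   case mpath of { Just p -> loader p; Nothing -> fallback }
greeting :: Maybe FilePath -> IO String
greeting = maybe defaultGreeting loadGreeting
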
116 changes: 81 additions & 35 deletions examples/main/gan-mnist.hs
@@ -4,7 +4,6 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}

@@ -39,23 +38,27 @@
import Control.Applicative
import Control.Monad
import Control.Monad.Random
import Control.Monad.Trans.Except

import qualified Data.Attoparsec.Text as A
import qualified Data.ByteString as B
import Codec.Compression.GZip ( decompress )
import Data.Serialize ( Get )
import qualified Data.Serialize as Serialize
import qualified Data.ByteString.Lazy as B

import Data.List ( foldl' )
import Data.List.Split ( chunksOf )
import Data.Maybe ( fromMaybe )
#if ! MIN_VERSION_base(4,13,0)
import Data.Semigroup ( (<>) )
#endif
import Data.Serialize
import qualified Data.Text as T
import qualified Data.Text.IO as T

import Data.Word ( Word32 , Word8 )
import qualified Data.Vector.Storable as V

import qualified Numeric.LinearAlgebra.Static as SA
import Numeric.LinearAlgebra.Data ( toLists )

import Options.Applicative
import System.FilePath ( (</>) )

import Grenade
import Grenade.Utils.OneHot
@@ -104,11 +107,13 @@ trainExample rate discriminator generator realExample noiseSource
in ( newDiscriminator, newGenerator )


ganTest :: (Discriminator, Generator) -> Int -> FilePath -> LearningParameters -> ExceptT String IO (Discriminator, Generator)
ganTest (discriminator0, generator0) iterations trainFile rate = do
trainData <- fmap fst <$> readMNIST trainFile
ganTest :: (Discriminator, Generator) -> Int -> FilePath -> LearningParameters -> IO (Discriminator, Generator)
ganTest (discriminator0, generator0) iterations dataDir rate = do
-- Note that for this example we use only the samples, and not the labels
trainData <- fmap fst <$> readMNIST (dataDir </> "train-images-idx3-ubyte.gz")
(dataDir </> "train-labels-idx1-ubyte.gz")

lift $ foldM (runIteration trainData) ( discriminator0, generator0 ) [1..iterations]
foldM (runIteration trainData) ( discriminator0, generator0 ) [1..iterations]

where

@@ -127,10 +132,9 @@ ganTest (discriminator0, generator0) iterations trainFile rate = do

runIteration :: [S ('D2 28 28)] -> (Discriminator, Generator) -> Int -> IO (Discriminator, Generator)
runIteration trainData ( !discriminator, !generator ) _ = do
trained' <- foldM ( \(!discriminatorX, !generatorX ) realExample -> do
fakeExample <- randomOfShape
return $ trainExample rate discriminatorX generatorX realExample fakeExample
) ( discriminator, generator ) trainData
trained' <- foldM ( \(!discriminatorX, !generatorX ) realExample ->
trainExample rate discriminatorX generatorX realExample <$> randomOfShape )
( discriminator, generator ) trainData


showShape' . snd . runNetwork (snd trained') =<< randomOfShape
@@ -140,7 +144,7 @@ ganTest (discriminator0, generator0) iterations trainFile rate = do
data GanOpts = GanOpts FilePath Int LearningParameters (Maybe FilePath) (Maybe FilePath)

mnist' :: Parser GanOpts
mnist' = GanOpts <$> argument str (metavar "TRAIN")
mnist' = GanOpts <$> argument str (metavar "DATADIR")
<*> option auto (long "iterations" <> short 'i' <> value 15)
<*> (LearningParameters
<$> option auto (long "train_rate" <> short 'r' <> value 0.01)
@@ -159,27 +163,69 @@ main = do
Just loadFile -> netLoad loadFile
Nothing -> (,) <$> randomDiscriminator <*> randomGenerator

res <- runExceptT $ ganTest nets0 iter mnist rate
case res of
Right nets1 -> case save of
Just saveFile -> B.writeFile saveFile $ runPut (put nets1)
Nothing -> return ()

Left err -> putStrLn err

readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))]
readMNIST mnist = ExceptT $ do
mnistdata <- T.readFile mnist
return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata)
nets1 <- ganTest nets0 iter mnist rate
case save of
Just saveFile -> B.writeFile saveFile $ Serialize.runPutLazy (Serialize.put nets1)
Nothing -> return ()


-- Adapted from https://github.com/tensorflow/haskell/blob/master/tensorflow-mnist/src/TensorFlow/Examples/MNIST/Parse.hs
-- Could also have used Data.IDX, although that uses a different Vector variant from the one needed for fromStorable
readMNIST :: FilePath -> FilePath -> IO [(S ( 'D2 28 28), S ( 'D1 10))]
readMNIST iFP lFP = do
labels <- readMNISTLabels lFP
samples <- readMNISTSamples iFP
return $ zip
(fmap (fromMaybe (error "bad samples") . fromStorable) samples)
(fromMaybe (error "bad labels") . oneHot . fromIntegral <$> labels)

-- | Checks the file's endianness, throwing an error if it's not as expected.
checkEndian :: Get ()
checkEndian = do
magic <- Serialize.getWord32be
when (magic `notElem` ([2049, 2051] :: [Word32]))
$ error "Expected big endian, but image file is little endian."

-- | Reads an MNIST file and returns a list of samples.
readMNISTSamples :: FilePath -> IO [V.Vector Double]
readMNISTSamples path = do
raw <- decompress <$> B.readFile path
either fail ( return . fmap (V.map normalize) ) $ Serialize.runGetLazy getMNIST raw
where
getMNIST :: Get [V.Vector Word8]
getMNIST = do
checkEndian
-- Parse header data.
cnt <- fromIntegral <$> Serialize.getWord32be
rows <- fromIntegral <$> Serialize.getWord32be
cols <- fromIntegral <$> Serialize.getWord32be
-- Read all of the data, then split into samples.
pixels <- Serialize.getLazyByteString $ fromIntegral $ cnt * rows * cols
return $ V.fromList <$> chunksOf (rows * cols) (B.unpack pixels)

normalize :: Word8 -> Double
normalize = (/ 255) . fromIntegral
-- There are other normalization functions in the literature, such as
-- normalize = (/ 0.3081) . (`subtract` 0.1307) . (/ 255) . fromIntegral
-- but we need values in the range [0..1] for the showShape' pretty printer

-- | Reads a list of MNIST labels from a file and returns them.
readMNISTLabels :: FilePath -> IO [Word8]
readMNISTLabels path = do
raw <- decompress <$> B.readFile path
either fail return $ Serialize.runGetLazy getLabels raw
where
getLabels :: Get [Word8]
getLabels = do
checkEndian
-- Parse header data.
cnt <- fromIntegral <$> Serialize.getWord32be
-- Read all of the labels.
B.unpack <$> Serialize.getLazyByteString cnt

parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10))
parseMNIST = do
Just lab <- oneHot <$> A.decimal
pixels <- many (A.char ',' >> A.double)
image <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels)
return (image, lab)

netLoad :: FilePath -> IO (Discriminator, Generator)
netLoad modelPath = do
modelData <- B.readFile modelPath
either fail return $ runGet (get :: Get (Discriminator, Generator)) modelData
either fail return $
Serialize.runGetLazy (Serialize.get :: Get (Discriminator, Generator)) modelData
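
For reference, the two magic numbers accepted by checkEndian are the IDX identifiers for label files (2049) and image files (2051). Below is a minimal sketch, not part of this PR, of exercising the two readers on their own; previewMNIST is an illustrative name, and it assumes the gzipped LeCun files have already been downloaded into a data directory with the file names used by ganTest:

previewMNIST :: IO ()
previewMNIST = do
  labels  <- readMNISTLabels  ("data" </> "train-labels-idx1-ubyte.gz")
  samples <- readMNISTSamples ("data" </> "train-images-idx3-ubyte.gz")
  putStrLn $ "labels read:  " ++ show (length labels)
  putStrLn $ "samples read: " ++ show (length samples)
  print (take 10 labels)   -- the first ten digit labels, as Word8 values 0-9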