Skip to content

Commit 2722e73

Browse files
committed
Added generator, renamed fire.h to defines.h, added Algorithm and Arith implementations.
1 parent 5653d61 commit 2722e73

File tree

15 files changed

+697
-83
lines changed

15 files changed

+697
-83
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
dist*
2-
result/
2+
result/
3+
/TAGS
4+
/result

Main.hs

Lines changed: 0 additions & 5 deletions
This file was deleted.

exe/Main.hs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module Main where
2+
3+
main = print 4

fire.cabal

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,21 @@ library
5252
Haskell2010
5353

5454
executable main
55+
hs-source-dirs:
56+
exe
5557
main-is:
5658
Main.hs
5759
build-depends:
5860
base < 5, fire
5961
default-language:
6062
Haskell2010
63+
64+
executable gen
65+
main-is:
66+
Main.hs
67+
hs-source-dirs:
68+
gen
69+
build-depends:
70+
base < 5, attoparsec, text
71+
default-language:
72+
Haskell2010

gen/Main.hs

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
{-# LANGUAGE OverloadedStrings #-}
2+
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
3+
module Main where
4+
5+
import Control.Applicative
6+
import Control.Monad
7+
import Data.Attoparsec.Text
8+
import qualified Data.Attoparsec.Text as A
9+
import Data.Char
10+
import Data.Maybe
11+
import Data.Text (Text)
12+
import qualified Data.Text as T
13+
import qualified Data.Text.IO as T
14+
import System.Environment
15+
import Text.Printf
16+
17+
main :: IO ()
18+
main = do
19+
arg <- reverse . Prelude.takeWhile (/='/') . drop 2 . reverse
20+
. fromMaybe (error "Please enter C header file")
21+
. listToMaybe <$> getArgs
22+
ls <- T.lines <$> T.readFile arg
23+
forM_ ls $ \input ->
24+
unless (T.null input || "#include" `T.isInfixOf` input) $ do
25+
result <- either error genBinding (parseInput parser input)
26+
T.writeFile (path arg) (file arg <> result)
27+
28+
path :: String -> String
29+
path s = printf "src/Data/Array/Fire/Internal/%s.hsc" s
30+
31+
file :: String -> Text
32+
file a = T.pack $ printf
33+
"module Data.Array.Fire.Internal.%s where\n\n\
34+
\import Data.Array.Fire.Internal.Defines\n\n\
35+
\#include \"%s.h\"\n\n\
36+
\import Foreign.Ptr\n\n" (capitalName a) (lowerCase a)
37+
38+
capitalName, lowerCase :: [Char] -> [Char]
39+
capitalName (x:xs) = toUpper x : xs
40+
lowerCase (x:xs) = toLower x : xs
41+
42+
type Output = Name
43+
44+
newtype Name = Name Text
45+
deriving (Show, Eq, PrintfArg)
46+
47+
data AST = AST Output Name Params
48+
deriving (Show)
49+
50+
type Params = [Param]
51+
52+
data Param = Param Type Name
53+
deriving (Show)
54+
55+
type IsPtr = Bool
56+
57+
data Type = Type IsPtr TypeValue
58+
deriving (Show)
59+
60+
newtype TypeValue = TypeName Text
61+
deriving (Show)
62+
63+
parser :: Parser AST
64+
parser = do
65+
whitespace
66+
a <- AST <$> parseOutput
67+
<*> parseName
68+
<*> parseParams
69+
whitespace
70+
pure a
71+
72+
whitespace = many (char ' ')
73+
74+
parseOutput :: Parser Output
75+
parseOutput = do
76+
result <- string "AFAPI af_err"
77+
whitespace
78+
pure (Name "AFError")
79+
80+
parseName :: Parser Name
81+
parseName = do
82+
result <- A.takeWhile (/='(')
83+
pure $ Name $ T.strip result
84+
85+
parseParams :: Parser Params
86+
parseParams = do
87+
char '('
88+
params <- parseParam `A.sepBy1` (char ',' >> whitespace)
89+
char ')'
90+
char ';'
91+
pure params
92+
93+
parseParam :: Parser Param
94+
parseParam = do
95+
parseModifier
96+
type' <- parseType
97+
name <- parseParamName
98+
pure $ Param type' name
99+
where
100+
parseModifier = do
101+
r <- (Just <$> string "const") <|> pure Nothing
102+
whitespace
103+
pure ()
104+
105+
parseParamName =
106+
Name <$> A.takeWhile (`notElem` (",)" :: String))
107+
108+
parseType = do
109+
typeValue <- getTypeValue
110+
isPtr <- (True <$ char '*') <|> pure False
111+
pure $ Type isPtr typeValue
112+
113+
getTypeValue = do
114+
name <- A.takeWhile (`notElem` (" " :: String))
115+
whitespace
116+
pure $ TypeName (camelize name)
117+
118+
camelize :: Text -> Text
119+
camelize = T.concat . map go . T.splitOn "_"
120+
where
121+
go "af" = "AF"
122+
go xs = T.cons (toUpper c) cs
123+
where
124+
c = T.head xs
125+
cs = T.tail xs
126+
127+
parseInput :: Parser AST -> Text -> Either String AST
128+
parseInput = parseOnly
129+
130+
genBinding :: AST -> IO Text
131+
genBinding (AST (Name output) name params) =
132+
pure (header <> dumpBody <> dumpOutput)
133+
where
134+
dumpOutput = "IO " <> output
135+
header = T.pack $ printf "foreign import ccall unsafe \"%s\"\n %s :: " name name
136+
dumpBody = T.concat $ map toParam params
137+
where
138+
toParam (Param (Type True (TypeName t)) _) = "Ptr " <> t <> " -> "
139+
toParam (Param (Type False (TypeName t)) _) = t <> " -> "
140+
141+
-- test :: Text
142+
-- test =
143+
-- "AFAPI af_err af_stdev(af_array *out, const af_array in, const dim_t dim);"

gen/Main.hs~

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
{-# LANGUAGE OverloadedStrings #-}
2+
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
3+
module Main where
4+
5+
import Control.Applicative
6+
import Control.Monad
7+
import Data.Attoparsec.Text
8+
import qualified Data.Attoparsec.Text as A
9+
import Data.Char
10+
import Data.Maybe
11+
import Data.Text (Text)
12+
import qualified Data.Text as T
13+
import qualified Data.Text.IO as T
14+
import System.Environment
15+
import Text.Printf
16+
17+
main :: IO ()
18+
main = do
19+
arg <- fromMaybe (error "Please enter C header file")
20+
. listToMaybe <$> getArgs
21+
ls <- T.lines <$> T.readFile arg
22+
forM_ ls $ \input ->
23+
unless (T.null input || "#include" `T.isInfixOf` input) $ do
24+
result <- either error genBinding (parseInput parser input)
25+
T.writeFile
26+
27+
file a = T.pack $ printf
28+
"module Data.Array.Fire.Internal.%s where\
29+
\import Data.Array.Fire.Internal.Defines\
30+
\#include \"%s.h\"\
31+
\import Foreign.Ptr\n\n" (capitalName a) (lowerCase a)
32+
33+
34+
type Output = Name
35+
36+
newtype Name = Name Text
37+
deriving (Show, Eq, PrintfArg)
38+
39+
data AST = AST Output Name Params
40+
deriving (Show)
41+
42+
type Params = [Param]
43+
44+
data Param = Param Type Name
45+
deriving (Show)
46+
47+
type IsPtr = Bool
48+
49+
data Type = Type IsPtr TypeValue
50+
deriving (Show)
51+
52+
newtype TypeValue = TypeName Text
53+
deriving (Show)
54+
55+
-- input :: Text
56+
-- input =
57+
-- "AFAPI af_err af_stdev(af_array *out, const af_array in, const dim_t dim);"
58+
59+
parser :: Parser AST
60+
parser = do
61+
whitespace
62+
a <- AST <$> parseOutput
63+
<*> parseName
64+
<*> parseParams
65+
whitespace
66+
pure a
67+
68+
whitespace = many (char ' ')
69+
70+
parseOutput :: Parser Output
71+
parseOutput = do
72+
result <- string "AFAPI af_err"
73+
whitespace
74+
pure (Name "AFError")
75+
76+
parseName :: Parser Name
77+
parseName = do
78+
result <- A.takeWhile (/='(')
79+
pure $ Name $ T.strip result
80+
81+
parseParams :: Parser Params
82+
parseParams = do
83+
char '('
84+
params <- parseParam `A.sepBy1` (char ',' >> whitespace)
85+
char ')'
86+
char ';'
87+
pure params
88+
89+
parseParam :: Parser Param
90+
parseParam = do
91+
parseModifier
92+
type' <- parseType
93+
name <- parseParamName
94+
pure $ Param type' name
95+
where
96+
parseModifier = do
97+
r <- (Just <$> string "const") <|> pure Nothing
98+
whitespace
99+
pure ()
100+
101+
parseParamName =
102+
Name <$> A.takeWhile (`notElem` (",)" :: String))
103+
104+
parseType = do
105+
typeValue <- getTypeValue
106+
isPtr <- (True <$ char '*') <|> pure False
107+
pure $ Type isPtr typeValue
108+
109+
getTypeValue = do
110+
name <- A.takeWhile (`notElem` (" " :: String))
111+
whitespace
112+
pure $ TypeName (camelize name)
113+
114+
camelize :: Text -> Text
115+
camelize = T.concat . map go . T.splitOn "_"
116+
where
117+
go "af" = "AF"
118+
go xs = T.cons (toUpper c) cs
119+
where
120+
c = T.head xs
121+
cs = T.tail xs
122+
123+
parseInput :: Parser AST -> Text -> Either String AST
124+
parseInput = parseOnly
125+
126+
genBinding :: AST -> IO Text
127+
genBinding (AST (Name output) name params) =
128+
pure (header <> dumpBody <> dumpOutput)
129+
where
130+
dumpOutput = "IO " <> output
131+
header = T.pack $ printf "foreign import ccall unsafe \"%s\"\n %s :: " name name
132+
dumpBody = T.concat $ map toParam params
133+
where
134+
toParam (Param (Type True (TypeName t)) _) = "Ptr " <> t <> " -> "
135+
toParam (Param (Type False (TypeName t)) _) = t <> " -> "
136+

gen/templates/arith.h

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
AFAPI af_err af_add (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
2+
AFAPI af_err af_sub (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
3+
AFAPI af_err af_mul (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
4+
AFAPI af_err af_div (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
5+
AFAPI af_err af_lt (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
6+
AFAPI af_err af_gt (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
7+
AFAPI af_err af_le (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
8+
AFAPI af_err af_ge (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
9+
AFAPI af_err af_eq (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
10+
AFAPI af_err af_neq (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
11+
AFAPI af_err af_and (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
12+
AFAPI af_err af_or (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
13+
AFAPI af_err af_not (af_array *out, const af_array in);
14+
AFAPI af_err af_bitand (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
15+
AFAPI af_err af_bitor (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
16+
AFAPI af_err af_bitxor (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
17+
AFAPI af_err af_bitshiftl(af_array *out, const af_array lhs, const af_array rhs, const bool batch);
18+
AFAPI af_err af_bitshiftr(af_array *out, const af_array lhs, const af_array rhs, const bool batch);
19+
AFAPI af_err af_cast (af_array *out, const af_array in, const af_dtype type);
20+
AFAPI af_err af_minof (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
21+
AFAPI af_err af_maxof (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
22+
AFAPI af_err af_clamp(af_array *out, const af_array in, const af_array lo, const af_array hi, const bool batch);
23+
AFAPI af_err af_rem (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
24+
AFAPI af_err af_mod (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
25+
AFAPI af_err af_abs (af_array *out, const af_array in);
26+
AFAPI af_err af_arg (af_array *out, const af_array in);
27+
AFAPI af_err af_sign (af_array *out, const af_array in);
28+
AFAPI af_err af_round (af_array *out, const af_array in);
29+
AFAPI af_err af_trunc (af_array *out, const af_array in);
30+
AFAPI af_err af_floor (af_array *out, const af_array in);
31+
AFAPI af_err af_ceil (af_array *out, const af_array in);
32+
AFAPI af_err af_hypot (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
33+
AFAPI af_err af_sin (af_array *out, const af_array in);
34+
AFAPI af_err af_cos (af_array *out, const af_array in);
35+
AFAPI af_err af_tan (af_array *out, const af_array in);
36+
AFAPI af_err af_asin (af_array *out, const af_array in);
37+
AFAPI af_err af_acos (af_array *out, const af_array in);
38+
AFAPI af_err af_atan (af_array *out, const af_array in);
39+
AFAPI af_err af_atan2 (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
40+
AFAPI af_err af_cplx2 (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
41+
AFAPI af_err af_cplx (af_array *out, const af_array in);
42+
AFAPI af_err af_real (af_array *out, const af_array in);
43+
AFAPI af_err af_imag (af_array *out, const af_array in);
44+
AFAPI af_err af_conjg (af_array *out, const af_array in);
45+
AFAPI af_err af_sinh (af_array *out, const af_array in);
46+
AFAPI af_err af_cosh (af_array *out, const af_array in);
47+
AFAPI af_err af_tanh (af_array *out, const af_array in);
48+
AFAPI af_err af_asinh (af_array *out, const af_array in);
49+
AFAPI af_err af_acosh (af_array *out, const af_array in);
50+
AFAPI af_err af_atanh (af_array *out, const af_array in);
51+
AFAPI af_err af_root (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
52+
AFAPI af_err af_pow (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
53+
AFAPI af_err af_pow2 (af_array *out, const af_array in);
54+
AFAPI af_err af_exp (af_array *out, const af_array in);
55+
AFAPI af_err af_sigmoid (af_array *out, const af_array in);
56+
AFAPI af_err af_expm1 (af_array *out, const af_array in);
57+
AFAPI af_err af_erf (af_array *out, const af_array in);
58+
AFAPI af_err af_erfc (af_array *out, const af_array in);
59+
AFAPI af_err af_log (af_array *out, const af_array in);
60+
AFAPI af_err af_log1p (af_array *out, const af_array in);
61+
AFAPI af_err af_log10 (af_array *out, const af_array in);
62+
AFAPI af_err af_log2 (af_array *out, const af_array in);
63+
AFAPI af_err af_sqrt (af_array *out, const af_array in);
64+
AFAPI af_err af_cbrt (af_array *out, const af_array in);
65+
AFAPI af_err af_factorial (af_array *out, const af_array in);
66+
AFAPI af_err af_tgamma (af_array *out, const af_array in);
67+
AFAPI af_err af_lgamma (af_array *out, const af_array in);
68+
AFAPI af_err af_iszero (af_array *out, const af_array in);
69+
AFAPI af_err af_isinf (af_array *out, const af_array in);
70+
AFAPI af_err af_isnan (af_array *out, const af_array in);

0 commit comments

Comments
 (0)