Skip to content

Commit

Permalink
snocket: introduce LocalSocket & FileDescriptor
Browse files Browse the repository at this point in the history
FileDescriptor is a newtype wrapper for file descriptor numbers.  For
sockets on Posix and Unix we use file descriptor numbers, for windows
named pipes we use the memory address as the file descriptor number.
  • Loading branch information
coot authored and newhoggy committed Nov 23, 2020
1 parent 21e6b59 commit 724cc1e
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 83 deletions.
125 changes: 72 additions & 53 deletions ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs
Expand Up @@ -6,7 +6,6 @@
module Ouroboros.Network.Snocket
( -- * Snocket Interface
Accept (..)
, fmapAccept
, AddressFamily (..)
, Snocket (..)
-- ** Socket based Snocktes
Expand All @@ -17,15 +16,20 @@ module Ouroboros.Network.Snocket
--
, LocalSnocket
, localSnocket
, LocalSocket (..)
, LocalAddress (..)
, LocalFD
, localAddressFromPath

, FileDescriptor
, socketFileDescriptor
, localSocketFileDescriptor
) where

import Control.Exception
import Control.Monad (when)
import Control.Monad.Class.MonadTime (DiffTime)
import Control.Tracer (Tracer)
import Data.Bifunctor (Bifunctor (..))
import Data.Hashable
#if !defined(mingw32_HOST_OS)
import Network.Socket ( Family (AF_UNIX) )
Expand All @@ -36,6 +40,7 @@ import Network.Socket ( Socket
import qualified Network.Socket as Socket
#if defined(mingw32_HOST_OS)
import Data.Bits
import Foreign.Ptr (IntPtr (..), ptrToIntPtr)
import qualified System.Win32 as Win32
import qualified System.Win32.NamedPipes as Win32
import qualified System.Win32.Async as Win32.Async
Expand Down Expand Up @@ -90,16 +95,10 @@ newtype Accept addr fd = Accept
{ runAccept :: IO (fd, addr, Accept addr fd)
}


-- | Arguments of 'Accept' are in the wrong order.
--
-- TODO: this can be fixed later.
--
fmapAccept :: (addr -> addr') -> Accept addr fd -> Accept addr' fd
fmapAccept f ac = Accept $ g <$> runAccept ac
where
g (fd, addr, next) = (fd, f addr, fmapAccept f next)

instance Bifunctor Accept where
bimap f g ac = Accept $ h <$> runAccept ac
where
h (fd, addr, next) = (g fd, f addr, bimap f g next)


-- | BSD accept loop.
Expand Down Expand Up @@ -274,16 +273,27 @@ socketSnocket ioManager = Snocket {
-- NamedPipes based Snocket
--


#if defined(mingw32_HOST_OS)
type HANDLESnocket = Snocket IO Win32.HANDLE LocalAddress
type LocalHandle = Win32.HANDLE
#else
type LocalHandle = Socket
#endif

-- | System dependent LocalSnocket type
newtype LocalSocket = LocalSocket { getLocalHandle :: LocalHandle }

-- | System dependent LocalSnocket
type LocalSnocket = Snocket IO LocalSocket LocalAddress



#if defined(mingw32_HOST_OS)
-- | Create a Windows Named Pipe Snocket.
--
namedPipeSnocket
:: IOManager
-> FilePath
-> HANDLESnocket
-> LocalSnocket
namedPipeSnocket ioManager path = Snocket {
getLocalAddr = \_ -> return localAddress
, getRemoteAddr = \_ -> return localAddress
Expand All @@ -306,7 +316,7 @@ namedPipeSnocket ioManager path = Snocket {
`catch` \(SomeAsyncException _) -> do
Win32.closeHandle hpipe
throwIO e
pure hpipe
pure (LocalSocket hpipe)

-- To connect, simply create a file whose name is the named pipe name.
, openToConnect = \(LocalAddress pipeName) -> do
Expand All @@ -324,26 +334,26 @@ namedPipeSnocket ioManager path = Snocket {
`catch` \(SomeAsyncException _) -> do
Win32.closeHandle hpipe
throwIO e
return hpipe
return (LocalSocket hpipe)
, connect = \_ _ -> pure ()

-- Bind and listen are no-op.
, bind = \_ _ -> pure ()
, listen = \_ -> pure ()

, accept = \hpipe -> Accept $ do
, accept = \sock@(LocalSocket hpipe) -> Accept $ do
Win32.Async.connectNamedPipe hpipe
return (hpipe, localAddress, acceptNext)
return (sock, localAddress, acceptNext)

, close = Win32.closeHandle
, close = Win32.closeHandle . getLocalHandle

, toBearer = \_sduTimeout -> namedPipeAsBearer
, toBearer = \_sduTimeout tr -> namedPipeAsBearer tr . getLocalHandle
}
where
localAddress :: LocalAddress
localAddress = LocalAddress path

acceptNext :: Accept LocalAddress Win32.HANDLE
acceptNext :: Accept LocalAddress LocalSocket
acceptNext = Accept $ do
hpipe <- Win32.createNamedPipe
path
Expand All @@ -354,44 +364,33 @@ namedPipeSnocket ioManager path = Snocket {
16384 -- inbound pipe size
0 -- default timeout
Nothing -- default security
`catch` \(e :: IOException) -> do
putStrLn $ "accept: " ++ show e
throwIO e
associateWithIOManager ioManager (Left hpipe)
Win32.Async.connectNamedPipe hpipe
return (hpipe, localAddress, acceptNext)
return (LocalSocket hpipe, localAddress, acceptNext)
#endif


--
-- Windows/POSIX type aliases
--

localSnocket :: IOManager -> FilePath -> LocalSnocket
-- | System dependent LocalSnocket type
#if defined(mingw32_HOST_OS)
type LocalSnocket = HANDLESnocket
type LocalFD = Win32.HANDLE

localSnocket = namedPipeSnocket
#else
type LocalSnocket = Snocket IO Socket LocalAddress
type LocalFD = Socket

localSnocket ioManager _ = Snocket {
getLocalAddr = fmap toLocalAddress . Socket.getSocketName
, getRemoteAddr = fmap toLocalAddress . Socket.getPeerName
, addrFamily = const LocalFamily
, connect = \s addr -> do
Socket.connect s (fromLocalAddress addr)
, bind = \fd addr -> Socket.bind fd (fromLocalAddress addr)
, listen = flip Socket.listen 1
, accept = fmapAccept toLocalAddress . (berkeleyAccept ioManager)
, open = openSocket
, openToConnect = \_addr -> openSocket LocalFamily
, close = Socket.close
, toBearer = Mx.socketAsMuxBearer
}
localSnocket ioManager _ =
Snocket {
getLocalAddr = fmap toLocalAddress . Socket.getSocketName . getLocalHandle
, getRemoteAddr = fmap toLocalAddress . Socket.getPeerName . getLocalHandle
, addrFamily = const LocalFamily
, connect = \(LocalSocket s) addr ->
Socket.connect s (fromLocalAddress addr)
, bind = \(LocalSocket fd) addr -> Socket.bind fd (fromLocalAddress addr)
, listen = flip Socket.listen 1 . getLocalHandle
, accept = bimap toLocalAddress LocalSocket
. berkeleyAccept ioManager
. getLocalHandle
, open = openSocket
, openToConnect = \_addr -> openSocket LocalFamily
, close = Socket.close . getLocalHandle
, toBearer = \df tr (LocalSocket sd) -> Mx.socketAsMuxBearer df tr sd
}
where
toLocalAddress :: SockAddr -> LocalAddress
toLocalAddress (SockAddrUnix path) = LocalAddress path
Expand All @@ -400,7 +399,7 @@ localSnocket ioManager _ = Snocket {
fromLocalAddress :: LocalAddress -> SockAddr
fromLocalAddress = SockAddrUnix . getFilePath

openSocket :: AddressFamily LocalAddress -> IO Socket
openSocket :: AddressFamily LocalAddress -> IO LocalSocket
openSocket LocalFamily = do
sd <- Socket.socket AF_UNIX Socket.Stream Socket.defaultProtocol
associateWithIOManager ioManager (Right sd)
Expand All @@ -413,8 +412,28 @@ localSnocket ioManager _ = Snocket {
`catch` \(SomeAsyncException _) -> do
Socket.close sd
throwIO e
return sd
return (LocalSocket sd)
#endif

localAddressFromPath :: FilePath -> LocalAddress
localAddressFromPath = LocalAddress

-- | Socket file descriptor.
--
newtype FileDescriptor = FileDescriptor { getFileDescriptor :: Int }
deriving Eq

instance Show FileDescriptor where
show fd = "<file-descriptor: " ++ show (getFileDescriptor fd) ++ ">"

socketFileDescriptor :: Socket -> IO FileDescriptor
socketFileDescriptor = fmap (FileDescriptor . fromIntegral) . Socket.socketToFd

localSocketFileDescriptor :: LocalSocket -> IO FileDescriptor
#if defined(mingw32_HOST_OS)
localSocketFileDescriptor =
\(LocalSocket fd) -> case ptrToIntPtr fd of
IntPtr i -> return (FileDescriptor i)
#else
localSocketFileDescriptor = socketFileDescriptor . getLocalHandle
#endif
Expand Up @@ -16,7 +16,7 @@ import Data.Functor.Identity (Identity (..))

import Ouroboros.Network.Snocket ( LocalAddress
, LocalSnocket
, LocalFD
, LocalSocket
)
import Ouroboros.Network.ErrorPolicy ( ErrorPolicies
, ErrorPolicyTrace
Expand Down Expand Up @@ -47,7 +47,7 @@ clientSubscriptionWorker
-> Tracer IO (WithAddr LocalAddress ErrorPolicyTrace)
-> NetworkMutableState LocalAddress
-> ClientSubscriptionParams a
-> (LocalFD -> IO a)
-> (LocalSocket -> IO a)
-> IO Void
clientSubscriptionWorker snocket
tracer
Expand Down
59 changes: 34 additions & 25 deletions ouroboros-network/src/Ouroboros/Network/Diffusion.hs
@@ -1,7 +1,8 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}

module Ouroboros.Network.Diffusion
( DiffusionTracers (..)
Expand Down Expand Up @@ -30,7 +31,11 @@ import Network.Mux (MuxTrace (..), WithMuxBearer (..))
import Network.Socket (AddrInfo, SockAddr)
import qualified Network.Socket as Socket

import Ouroboros.Network.Snocket (LocalAddress, SocketSnocket)
import Ouroboros.Network.Snocket ( LocalAddress
, LocalSnocket
, LocalSocket (..)
, SocketSnocket
)
import qualified Ouroboros.Network.Snocket as Snocket

import Ouroboros.Network.Protocol.Handshake.Version
Expand Down Expand Up @@ -62,7 +67,7 @@ import Ouroboros.Network.Tracers

data DiffusionTracers = DiffusionTracers {
dtIpSubscriptionTracer :: Tracer IO (WithIPList (SubscriptionTrace SockAddr))
-- ^ IP subscription tracer
-- ^ IP subscription tracer
, dtDnsSubscriptionTracer :: Tracer IO (WithDomainName (SubscriptionTrace SockAddr))
-- ^ DNS subscription tracer
, dtDnsResolverTracer :: Tracer IO (WithDomainName DnsTrace)
Expand All @@ -78,7 +83,6 @@ data DiffusionTracers = DiffusionTracers {
, dtErrorPolicyTracer :: Tracer IO (WithAddr SockAddr ErrorPolicyTrace)
, dtLocalErrorPolicyTracer :: Tracer IO (WithAddr LocalAddress ErrorPolicyTrace)
, dtAcceptPolicyTracer :: Tracer IO AcceptConnectionsPolicyTrace
-- ^ Trace rate limiting of accepted connections
}


Expand Down Expand Up @@ -139,7 +143,7 @@ instance Exception DiffusionFailure

runDataDiffusion
:: DiffusionTracers
-> DiffusionArguments
-> DiffusionArguments
-> DiffusionApplications
RemoteAddress LocalAddress
NodeToNodeVersionData NodeToClientVersionData
Expand All @@ -156,7 +160,6 @@ runDataDiffusion tracers
}
applications@DiffusionApplications { daErrorPolicies } =
withIOManager $ \iocp -> do

let -- snocket for remote communication.
snocket :: SocketSnocket
snocket = Snocket.socketSnocket iocp
Expand Down Expand Up @@ -293,7 +296,12 @@ runDataDiffusion tracers
-> IO ()
runLocalServer iocp networkLocalState =
bracket
(
localServerInit
localServerCleanup
localServerBody
where
localServerInit :: IO (LocalSocket, LocalSnocket)
localServerInit =
case daLocalAddress of
#if defined(mingw32_HOST_OS)
-- Windows uses named pipes so can't take advantage of existing sockets
Expand All @@ -303,23 +311,25 @@ runDataDiffusion tracers
a <- Socket.getSocketName sd
case a of
(Socket.SockAddrUnix path) ->
return (sd, Snocket.localSnocket iocp path)
_ ->
-- TODO: This should be logged.
throwIO UnsupportedLocalSocketType
return (LocalSocket sd, Snocket.localSnocket iocp path)
_unsupportedAddr ->
throwIO UnsupportedLocalSocketType
#endif
Right a -> do
let sn = Snocket.localSnocket iocp a
sd <- Snocket.open sn (Snocket.addrFamily sn $ Snocket.localAddressFromPath a)
Right addr -> do
let sn = Snocket.localSnocket iocp addr
sd <- Snocket.open sn (Snocket.addrFamily sn $ Snocket.localAddressFromPath addr)
return (sd, sn)
)
(\(sd,sn) -> Snocket.close sn sd) -- We close the socket here, even if it was provided for us.
(\(sd,sn) -> do

-- We close the socket here, even if it was provided for us.
localServerCleanup :: (LocalSocket, LocalSnocket) -> IO ()
localServerCleanup (sd, sn) = Snocket.close sn sd

localServerBody :: (LocalSocket, LocalSnocket) -> IO ()
localServerBody (sd, sn) = do
case daLocalAddress of
Left _ -> pure () -- If a socket was provided it should be ready to accept
Right a -> do
Snocket.bind sn sd $ Snocket.localAddressFromPath a
Right path -> do
Snocket.bind sn sd $ Snocket.localAddressFromPath path
Snocket.listen sn sd

void $ NodeToClient.withServer
Expand All @@ -333,7 +343,6 @@ runDataDiffusion tracers
sd
(daLocalResponderApplication applications)
localErrorPolicy
)

runServer :: SocketSnocket -> NetworkMutableState SockAddr -> Either Socket.Socket SockAddr -> IO ()
runServer sn networkState address =
Expand All @@ -348,8 +357,8 @@ runDataDiffusion tracers

case address of
Left _ -> pure () -- If a socket was provided it should be ready to accept
Right a -> do
Snocket.bind sn sd a
Right addr -> do
Snocket.bind sn sd addr
Snocket.listen sn sd

void $ NodeToNode.withServer
Expand Down

0 comments on commit 724cc1e

Please sign in to comment.