diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..47fb517 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +/objs +/dist +.*.swp +*~ +*.o +test diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..6f45040 --- /dev/null +++ b/LICENSE @@ -0,0 +1,30 @@ +Copyright (c) 2006-2009 Galois Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + + * Neither the name of Galois, Inc. nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER +OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Setup.hs b/Setup.hs new file mode 100644 index 0000000..fae177c --- /dev/null +++ b/Setup.hs @@ -0,0 +1,7 @@ +module Main (main) where + +import Distribution.Simple + +main :: IO () +main = defaultMain + diff --git a/TODO b/TODO new file mode 100644 index 0000000..e26e122 --- /dev/null +++ b/TODO @@ -0,0 +1,25 @@ + +-- src/Layer/Arp.hs ------------------------------------------------------------ + * There is currently no way to remove an address that has been added to the arp + layer + + +-- src/Layer/Ethernet.hs ------------------------------------------------------- + complete + + +-- src/Layer/IP4 --------------------------------------------------------------- + * There is no way to remove a route + + +-- src/Layer/Icmp4.hs ---------------------------------------------------------- + * The only ICMP message that is currently handled is EchoRequest + * DestinationUnreachable needs to be communicated up to relevant layers + + +-- src/Layer/Timer.hs ---------------------------------------------------------- + complete + + +-- src/Layer/Udp.hs ------------------------------------------------------------ + complete diff --git a/cbits/.gitignore b/cbits/.gitignore new file mode 100644 index 0000000..1103553 --- /dev/null +++ b/cbits/.gitignore @@ -0,0 +1,2 @@ +send +receive diff --git a/cbits/Makefile b/cbits/Makefile new file mode 100644 index 0000000..3421b21 --- /dev/null +++ b/cbits/Makefile @@ -0,0 +1,22 @@ + +OBJS = tapdevice.o + +ifndef V + QUIET_CC = @echo ' ' CC $@; +endif + +all : $(OBJS) send receive + +send : send.c + $(QUIET_CC)$(CC) -Wall -o $@ $< + +receive : receive.c + $(QUIET_CC)$(CC) -Wall -o $@ $< + +tapdevice.o : tapdevice.c + $(QUIET_CC)$(CC) -Wall -c -o $@ $< + +clean: + $(RM) tapdevice.o + $(RM) send + $(RM) receive diff --git a/cbits/receive.c b/cbits/receive.c new file mode 100644 index 0000000..c0b7316 --- /dev/null +++ b/cbits/receive.c @@ -0,0 +1,46 @@ +#include +#include +#include +#include +#include +#include + +int main(int argc, char **argv) +{ + int fd, res, len; + struct sockaddr_in server, client; + char buf[1000]; + + socklen_t clilen; + + memset(&server, 0x0, sizeof(server)); + server.sin_family = AF_INET; + server.sin_port = htons(40000); + server.sin_addr.s_addr = htonl(INADDR_ANY); + + fd = socket(AF_INET, SOCK_DGRAM, 0); + if(fd < 0) { + fprintf(stderr, "socket failed\n"); + return 1; + } + + res = bind(fd, (struct sockaddr *)&server, sizeof(server)); + if(res < 0) { + fprintf(stderr, "bind failed\n"); + close(fd); + return 1; + } + + while(1) { + memset(buf, 0x0, sizeof(buf)); + len = recvfrom(fd, buf, sizeof(buf), 0, + (struct sockaddr *)&client, &clilen); + printf("Message from: %s\n\t%s\n", inet_ntoa(client.sin_addr), + buf); + sendto(fd, buf, len, 0, (struct sockaddr *)&client, clilen); + } + + close(fd); + + return 0; +} diff --git a/cbits/send.c b/cbits/send.c new file mode 100644 index 0000000..e4f5534 --- /dev/null +++ b/cbits/send.c @@ -0,0 +1,42 @@ +#include +#include +#include +#include +#include + +int main(int argc, char **argv) +{ + int fd, res, len; + struct sockaddr_in server; + char buf[1000] = "Hello, world."; + + if(argc < 2) { + fprintf(stderr, "usage: %s \n", argv[0]); + return 1; + } + + memset(&server, 0x0, sizeof(server)); + server.sin_family = AF_INET; + server.sin_port = htons(40000); + + res = inet_pton(AF_INET, argv[1], &server.sin_addr); + if(res < 0) { + fprintf(stderr, "Unable to parse ip\n"); + return 1; + } + + fd = socket(AF_INET, SOCK_DGRAM, 0); + if(fd < 0) { + fprintf(stderr, "socket failed\n"); + return 1; + } + + sendto(fd, buf, sizeof(buf), 0, + (struct sockaddr *)&server, sizeof(server)); + + len = recvfrom(fd, buf, sizeof(buf), 0, NULL, NULL); + + printf("received: %s\n", buf); + + return 0; +} diff --git a/cbits/tapdevice.c b/cbits/tapdevice.c new file mode 100644 index 0000000..242a7fa --- /dev/null +++ b/cbits/tapdevice.c @@ -0,0 +1,36 @@ + +#include +#include +#include + +#include +#include +#include +#include +#include + +int init_tap_device(char *name) { + int fd, ret; + struct ifreq ifr; + + if(name == NULL) { + return -1; + } + + fd = open("/dev/net/tun", O_RDWR); + if(fd < 0) { + return -2; + } + + memset(&ifr, 0x0, sizeof(struct ifreq)); + ifr.ifr_flags = IFF_TAP | IFF_NO_PI; + strncpy(ifr.ifr_name, name, IFNAMSIZ); + + ret = ioctl(fd, TUNSETIFF, (void*) &ifr); + if(ret != 0) { + close(fd); + return -3; + } + + return fd; +} diff --git a/cbits/tapdevice.h b/cbits/tapdevice.h new file mode 100644 index 0000000..af2072a --- /dev/null +++ b/cbits/tapdevice.h @@ -0,0 +1,6 @@ +#ifndef __TAP_DEVICE_H +#define __TAP_DEVICE_H + +int init_tap_device(char *name); + +#endif diff --git a/example/WebServer.hs b/example/WebServer.hs new file mode 100644 index 0000000..03da917 --- /dev/null +++ b/example/WebServer.hs @@ -0,0 +1,153 @@ +module WebServer where + +import Control.Concurrent (forkIO,threadDelay) +import Data.Time.Calendar (Day(..)) +import Data.Time.Clock (UTCTime(..),addUTCTime) +import Data.Time.Clock.POSIX (POSIXTime,getPOSIXTime,posixSecondsToUTCTime) +import Data.Time.Format (formatTime) +import Hans.Layer.Tcp.Socket + (Socket,readLine,sendSocket,acceptSocket,listenPort,closeSocket + ,SocketError(..)) +import Hans.Message.Tcp (TcpPort) +import Hans.Setup (NetworkStack(nsTcp)) +import System.Exit (exitFailure) +import System.Locale (defaultTimeLocale) +import qualified Control.Exception as X +import qualified Data.ByteString as S + +webserver :: NetworkStack -> TcpPort -> IO () +webserver ns port = body `X.catch` \se -> print (se :: X.SomeException) + where + body = do + start <- getPOSIXTime + sock <- initServer ns port + serverLoop start sock + +accept :: Socket -> IO Socket +accept sock = loop + where + loop = X.catch (acceptSocket sock) $ \se -> do + case se of + AcceptError err -> putStrLn ("Accept error: " ++ err) + _ -> putStrLn ("Socket error: " ++ show se) + loop + +serverLoop :: POSIXTime -> Socket -> IO () +serverLoop start sock = loop + where + loop = do + client <- accept sock + _ <- forkIO (handleClient start client) + loop + +initServer :: NetworkStack -> TcpPort -> IO Socket +initServer ns port = listenPort (nsTcp ns) port `X.catch` h + where + h ListenError{} = do + putStrLn ("Unable to listen on port: " ++ show port) + exitFailure + h se = do + print se + exitFailure + +handleClient :: POSIXTime -> Socket -> IO () +handleClient start client = body `X.catch` \se -> print (se :: X.SomeException) + where + body = do + mb <- processRequest client + case mb of + Nothing -> closeSocket client + Just (url,req) -> do + sendSocket client =<< makeResponse start url req + threadDelay 1000000 + closeSocket client + +processRequest :: Socket -> IO (Maybe (String,[S.ByteString])) +processRequest sock = do + ls <- readRequest sock + case ls of + [] -> return Nothing + l:_ -> return (Just (parseUrl l, ls)) + +crlf :: S.ByteString +crlf = S.pack [0x0d, 0x0a] + +readRequest :: Socket -> IO [S.ByteString] +readRequest sock = loop + where + loop = do + line <- readLine sock + if line == crlf + then return [] + else do + rest <- loop + return (line:rest) + +parseUrl :: S.ByteString -> String +parseUrl = head . drop 1 . words . toString + +fromString :: String -> S.ByteString +fromString = S.pack . map (toEnum . fromEnum) + +toString :: S.ByteString -> String +toString = map (toEnum . fromEnum) . S.unpack + +status200 :: S.ByteString +status200 = fromString "HTTP/1.1 200 OK\r\n" + +contentLength :: Int -> S.ByteString +contentLength len = fromString ("Content-Length: " ++ show len) + +contentType :: String -> S.ByteString +contentType ty = fromString ("Content-Type: " ++ ty) + +response404 :: S.ByteString +response404 = fromString $ concat + [ "HTTP/1.1 404 Not Found\r\n" + , "Content-Length: 0\r\n" + , "\r\n" + ] + +connectionClose :: S.ByteString +connectionClose = fromString "Connection: close" + +makeResponse :: POSIXTime -> String -> [S.ByteString] -> IO S.ByteString +makeResponse start url req + | url == "/favicon.ico" = return response404 + | otherwise = do + uptime <- timePassed start + let date = posixSecondsToUTCTime start + body = fromString (concat + [ "HaLVM" + , "

Welcome to the HaLVM!


\r\n\r\n" + , "Started on: " + , formatDate date + , ", and up for " + , uptime + , "\r\n

HTTP Request:

\r\n
"
+               ]) `S.append` S.concat req
+                  `S.append` fromString "
" + return $! S.concat + [ status200 + , contentLength (S.length body), crlf + , contentType "text/html", crlf + , connectionClose, crlf + , crlf + , body + ] + +formatDate :: UTCTime -> String +formatDate = formatTime defaultTimeLocale "%c" + +zeroUTCTime :: UTCTime +zeroUTCTime = UTCTime (ModifiedJulianDay 0) 0 + +timePassed :: POSIXTime -> IO String +timePassed start = do + now <- getPOSIXTime + let date@(UTCTime day _) = addUTCTime (now - start) zeroUTCTime + return $ concat + [ show (toModifiedJulianDay day) + , " days, " + , formatTime defaultTimeLocale "%k hours, %M minutes, %S seconds." date + ] diff --git a/example/dhcp.hs b/example/dhcp.hs new file mode 100644 index 0000000..a433215 --- /dev/null +++ b/example/dhcp.hs @@ -0,0 +1,178 @@ +{-# LANGUAGE CPP #-} + +module Main where + +import Hans.Setup + +import Hans.Address +import Hans.Address.IP4 +import Hans.Address.Mac +import Hans.Layer.Ethernet +import Hans.Layer.Icmp4 +import Hans.Layer.Udp +import Hans.Message.Udp +import Hans.Message.Dhcp4 +import Hans.Message.Ip4 +import Hans.Message.Dhcp4Options +import Hans.Message.Dhcp4Codec +import Hans.Message.EthernetFrame +import Hans.Utils (Packet) + +import Control.Concurrent (threadDelay) +import Control.Monad (forever) +import Data.Serialize.Get +import Data.Serialize.Put + +import qualified Data.ByteString as S + +#ifdef xen_HOST_OS +import Hans.Device.Xen +import Hypervisor.Debug +import Hypervisor.Kernel +import XenDevice.NIC +#else +import Hans.Device.Tap +#endif + +output :: String -> IO () +outputBS :: S.ByteString -> IO () + +#ifdef xen_HOST_OS +output str = writeDebugConsole (showString str "\n") +outputBS = output . map (toEnum . fromEnum) . S.unpack +#else +output = putStrLn +outputBS = S.putStrLn +#endif + +ip :: IP4 +ip = IP4 192 168 80 9 + +broad :: IP4 +broad = IP4 255 255 255 255 + +dst :: IP4 +dst = IP4 192 168 80 10 + +mac :: Mac +mac = Mac 0x52 0x54 0x00 0x12 0x34 0x56 + +bootpc = UdpPort 68 +bootps = UdpPort 67 + +udpProtocol = IP4Protocol 0x11 + +ethernet = EtherType 0x800 + +zeroAddr = IP4 0 0 0 0 + +broadcastMac = Mac 0xff 0xff 0xff 0xff 0xff 0xff + +initEthernetDevice :: NetworkStack -> IO () +#ifdef xen_HOST_OS +initEthernetDevice ns = do + Just nic <- openXenDevice "" + addEthernetDevice (nsEthernet ns) mac (xenSend nic) (xenReceiveLoop nic) +#else +initEthernetDevice ns = do + Just dev <- openTapDevice "tap0" + addEthernetDevice (nsEthernet ns) mac (tapSend dev) (tapReceiveLoop dev) +#endif + + +main :: IO () +#ifdef xen_HOST_OS +main = halvm_kernel [dNICs] $ \ _ -> do + writeDebugConsole "hans2 test started\n" +#else +main = do +#endif + ns <- setup + initEthernetDevice ns + apply [ toOption (LocalEthernet (IP4Mask broad 0) mac) -- ip/24 network + , toOption (Route (zeroAddr `withMask` 0) ip) -- default route + ] ns + startEthernetDevice (nsEthernet ns) mac + addUdpHandler (nsUdp ns) bootps (simpleDhcpServerHandler ns) + forever $ threadDelay (1000 * 1000) + +simpleDhcpServerHandler ns remoteaddr remoteport packet + | remoteport /= bootpc = return () + | otherwise = case runGet getDhcp4Message packet of + Left err -> output err + Right msg -> case parseDhcpMessage msg of + Just (Left (RequestMessage req)) -> do + output (show req) + let ack = requestToAck simpleDhcpSettings req + sendResponse ns (ackToMessage ack) + Just (Left (DiscoverMessage disc)) -> do + output (show disc) + let offer = discoverToOffer simpleDhcpSettings disc + sendResponse ns (offerToMessage offer) + msg1 -> output (show msg) >> output (show msg1) + +simpleDhcpClient ns = do + let discover = mkDiscover (Xid 0x12345) mac + sendRequest ns (discoverToMessage discover) + +simpleDhcpClientHandler ns remoteaddr remoteport packet + | remoteport /= bootps = return () + | otherwise = case runGet getDhcp4Message packet of + Left err -> output err + Right msg -> case parseDhcpMessage msg of + Just (Right (OfferMessage offer)) -> do + output (show offer) + let req = offerToRequest offer + sendRequest ns (requestToMessage req) + Just (Right (AckMessage ack)) -> do + output (show ack) + output "Install assigned IP address" + msg1 -> do + output (show msg) + output (show msg1) + +sendRequest ns resp = do + ipBytes <- mkIpBytes zeroAddr broad bootpc bootps + (runPut (putDhcp4Message resp)) + + sendEthernet (nsEthernet ns) EthernetFrame + { etherDest = dhcp4ClientHardwareAddr resp + , etherSource = mac + , etherType = ethernet + , etherData = ipBytes + } + +sendResponse ns resp = do + ipBytes <- mkIpBytes ip broad bootps bootpc $ runPut (putDhcp4Message resp) + + sendEthernet (nsEthernet ns) EthernetFrame + { etherDest = broadcastMac + , etherSource = mac + , etherType = ethernet + , etherData = ipBytes + } + +mkIpBytes srcAddr dstAddr srcPort dstPort payload = do + udpBytes <- let udpHeader = UdpHeader srcPort dstPort 0 + udp = UdpPacket udpHeader payload + mk = mkIP4PseudoHeader srcAddr dstAddr udpProtocol + in renderUdpPacket udp mk + + ipBytes <- let iphdr = emptyIP4Header udpProtocol srcAddr dstAddr + ip = IP4Packet iphdr udpBytes + in renderIP4Packet ip + + return ipBytes + +simpleDhcpSettings :: ServerSettings +simpleDhcpSettings = Settings + { staticClientAddr = IP4 192 168 80 10 + , staticServerAddr = ip + , staticLeaseTime = 3600 + , staticSubnet = SubnetMask 24 + , staticBroadcast = IP4 192 168 80 255 + , staticDomainName = "galois.com" + , staticRouters = [IP4 192 168 80 9] + , staticDNS = [IP4 192 168 80 11] + , staticTimeOffset = 12345 + } diff --git a/example/setup b/example/setup new file mode 100755 index 0000000..88d5bd9 --- /dev/null +++ b/example/setup @@ -0,0 +1,15 @@ +#!/bin/sh + +case $1 in + start) + dev=$(sudo tunctl -b -u $(whoami)) + sudo ip link set dev $dev up + echo $dev + ;; + + stop) + dev=$2 + sudo ip link set dev $dev down + sudo tunctl -d $dev -b + ;; +esac diff --git a/example/test.hs b/example/test.hs new file mode 100644 index 0000000..f4b0eae --- /dev/null +++ b/example/test.hs @@ -0,0 +1,107 @@ +{-# LANGUAGE CPP #-} + +module Main where + +import WebServer + +import Hans.Address +import Hans.Address.IP4 +import Hans.Address.Mac +import Hans.DhcpClient (dhcpDiscover) +import Hans.Layer.Ethernet +import Hans.Message.Tcp (TcpPort(..)) +import Hans.Setup + +import System.Exit (exitFailure) +import qualified Data.ByteString as S + +#ifdef xen_HOST_OS +import Communication.IVC (InChannelEx,OutChannelEx,Bin) +import Hans.Device.Ivc +import Hans.Device.Xen +import Hypervisor.Debug +import Hypervisor.Kernel +import RendezvousLib.PeerToPeer (p2pConnection) +import XenDevice.NIC +#else +import Hans.Device.Tap +import System.Environment (getArgs) +#endif + +output :: String -> IO () +outputBS :: S.ByteString -> IO () + +#ifdef xen_HOST_OS +output str = writeDebugConsole (showString str "\n") +outputBS = output . map (toEnum . fromEnum) . S.unpack + +_buildInput :: IO (OutChannelEx Bin Bytes) +buildInput :: IO (InChannelEx Bin Bytes) +(_buildInput,buildInput) = p2pConnection "ethernet_dev_input" + +_buildOutput :: IO (InChannelEx Bin Bytes) +buildOutput :: IO (OutChannelEx Bin Bytes) +(_buildOutput,buildOutput) = p2pConnection "ethernet_dev_output" + +#else +output = putStrLn +outputBS = S.putStrLn +#endif + + + +initEthernetDevice :: NetworkStack -> IO Mac +#ifdef xen_HOST_OS +initEthernetDevice ns = do + Just nic <- openXenDevice "" + let mac = read (getNICName nic) + print mac + addEthernetDevice (nsEthernet ns) mac (xenSend nic) (xenReceiveLoop nic) + --let mac = Mac 0x52 0x54 0x00 0x12 0x34 0x56 + --putStrLn "Waiting for input channel..." + --input <- buildInput + --putStrLn "Waiting for output channel..." + --output <- buildOutput + --addEthernetDevice (nsEthernet ns) mac (ivcSend output) (ivcReceiveLoop input) + return mac +#else +initEthernetDevice ns = do + let mac = Mac 0x52 0x54 0x00 0x12 0x34 0x56 + Just dev <- openTapDevice "tap0" + addEthernetDevice (nsEthernet ns) mac (tapSend dev) (tapReceiveLoop dev) + return mac +#endif + +main :: IO () +#ifdef xen_HOST_OS +main = halvm_kernel [dNICs] $ \ args -> do +#else +main = do + args <- getArgs +#endif + ns <- setup + mac <- initEthernetDevice ns + startEthernetDevice (nsEthernet ns) mac + setAddress args mac ns + webserver ns (TcpPort 8000) + +setAddress :: [String] -> Mac -> NetworkStack -> IO () +setAddress args mac ns = + case args of + ["dhcp"] -> dhcpDiscover ns mac print + [ip,gw] -> apply (addrOptions ip gw mac) ns + _ -> do + putStrLn "Usage: dhcp" + putStrLn " " + exitFailure + +addrOptions :: String -> String -> Mac -> [SomeOption] +addrOptions ip gw mac = + [ SomeOption (LocalEthernet (ip4 `withMask` 24) mac) + , SomeOption (Route (IP4 0 0 0 0 `withMask` 0) gw4) + ] + where + ip4 :: IP4 + ip4 = read ip + gw4 :: IP4 + gw4 = read gw diff --git a/example/test_config b/example/test_config new file mode 100644 index 0000000..d9ad231 --- /dev/null +++ b/example/test_config @@ -0,0 +1,7 @@ +name = "test" +kernel = "test" +memory = 64 +vif = [''] +context = "system_u:system_r:domU_t" +on_crash = "destroy" +on_shutdown = "destroy" diff --git a/hans.cabal b/hans.cabal new file mode 100644 index 0000000..aff9110 --- /dev/null +++ b/hans.cabal @@ -0,0 +1,130 @@ +name: hans +version: 2.1.0.0 +cabal-version: >= 1.8 +license: OtherLicense +license-file: LICENSE +author: http://galois.com +maintainer: http://galois.com +category: Networking +synopsis: IPv4 Network Stack +build-type: Simple + +source-repository head + type: git + location: git://src.galois.com/srv/git/HaNS.git + +flag bounded-channels + default: False + description: Use bounded channels for message passing + +flag example + default: False + description: Build the example program + +library + if(os(xen)) + build-depends: XenDevice, communication + exposed-modules: Hans.Device.Xen, + Hans.Device.Ivc + else + build-depends: unix + include-dirs: cbits + c-sources: cbits/tapdevice.c + exposed-modules: Hans.Device.Tap + + if flag(bounded-channels) + build-depends: BoundedChan + cpp-options: -DBOUNDED_CHANNELS + + ghc-options: -Wall -O2 + hs-source-dirs: src + build-depends: base == 4.*, + cereal == 0.3.*, + bytestring == 0.9.1.*, + containers >= 0.4.0.0 && < 0.5.0.0, + monadLib == 3.6.*, + time >= 1.2.0.0 && < 1.3.0.0, + fingertree == 0.0.1.*, + random == 1.0.0.* + exposed-modules: Data.PrefixTree, + Hans.Address, + Hans.Utils.Checksum, + Hans.Layer, + Hans.Ports, + Hans.Setup, + Hans.Utils, + Hans.Channel, + Hans.DhcpClient, + Hans.Simple, + Hans.Layer.Tcp, + Hans.Layer.Arp.Table, + Hans.Layer.Tcp.Handlers, + Hans.Layer.Tcp.Monad, + Hans.Layer.Tcp.Socket, + Hans.Layer.IP4, + Hans.Layer.Udp, + Hans.Layer.Arp, + Hans.Layer.IP4.Routing, + Hans.Layer.IP4.Fragmentation, + Hans.Layer.Timer, + Hans.Layer.Ethernet, + Hans.Layer.Icmp4, + Hans.Message.Tcp, + Hans.Message.Types, + Hans.Message.Udp, + Hans.Message.Ip4, + Hans.Message.EthernetFrame, + Hans.Message.Arp, + Hans.Message.Dhcp4Codec, + Hans.Message.Dhcp4, + Hans.Message.Dhcp4Options, + Hans.Message.Icmp4, + Hans.Address.Mac, + Hans.Address.IP4, + Network.TCP.Aux.SockMonad, + Network.TCP.Aux.Param, + Network.TCP.Aux.Misc, + Network.TCP.Aux.Output, + Network.TCP.LTS.User, + Network.TCP.LTS.Out, + Network.TCP.LTS.InPassive, + Network.TCP.LTS.InActive, + Network.TCP.LTS.InMisc, + Network.TCP.LTS.In, + Network.TCP.LTS.InData, + Network.TCP.Type.Socket, + Network.TCP.Type.Syscall, + Network.TCP.Type.Datagram, + Network.TCP.Type.Base, + Network.TCP.Type.Timer + +executable test + if flag(example) + buildable: True + else + buildable: False + + if(flag(bounded-channels)) + build-depends: BoundedChan + cpp-options: -DBOUNDED_CHANNELS + + main-is: test.hs + other-modules: WebServer + ghc-options: -Wall + hs-source-dirs: example + build-depends: base == 4.*, + cereal == 0.3.0.*, + bytestring == 0.9.1.*, + containers >= 0.4.0.0 && < 0.5.0.0, + monadLib == 3.6.*, + time >= 1.2.0.0 && < 1.3.0.0, + old-locale == 1.0.0.*, + hans + + if os(xen) + build-depends: XenDevice == 1.0.0, + RendezvousLib == 1.0.0, + HALVMCore == 1.0.0, + communication == 1.0.0 + else + ghc-options: -threaded diff --git a/src/Data/PrefixTree.hs b/src/Data/PrefixTree.hs new file mode 100644 index 0000000..75c04c7 --- /dev/null +++ b/src/Data/PrefixTree.hs @@ -0,0 +1,244 @@ +{-# LANGUAGE CPP #-} + +module Data.PrefixTree ( + PrefixTree + + -- * Construction + , empty + , singleton + , insert + , delete + , toList + , fromList + + -- * Querying + , lookup + , member + , matches + , match + , elems + , keys + , key + ) where + +import Prelude hiding (lookup) +import Data.Maybe (isJust,listToMaybe) + +#ifdef TESTS +import Data.List (nub) +import Test.QuickCheck +#endif + +data PrefixTree a + = Empty + | Prefix Key (Maybe a) (PrefixTree a) + | Branch (PrefixTree a) (PrefixTree a) + deriving Show + +type Key = [Bool] + +-- Prefix Manipulation --------------------------------------------------------- + +matchPrefix :: Key -> Key -> (Key,Key,Key) +matchPrefix = loop id + where + loop k (a:as) (b:bs) | a == b = loop (k . (a:)) as bs + loop k as bs = (k [], as, bs) + + +-- Construction ---------------------------------------------------------------- + +empty :: PrefixTree a +empty = Empty + +singleton :: Key -> a -> PrefixTree a +singleton ks a = Prefix ks (Just a) empty + +fromList :: [(Key,a)] -> PrefixTree a +fromList = foldr (uncurry insert) empty + +toList :: PrefixTree a -> [([Bool], a)] +toList t = + case t of + Empty -> [] + + Prefix ls mb t' -> + case mb of + Nothing -> map (prefix ls) (toList t') + Just a -> (ls,a) : map (prefix ls) (toList t') + + Branch l r -> toList l ++ toList r + + where + prefix ls (ks,a) = (ls ++ ks, a) + +elems :: PrefixTree a -> [a] +elems t = + case t of + Empty -> [] + Prefix _ (Just a) t' -> a : elems t' + Prefix _ _ t' -> elems t' + Branch l r -> elems l ++ elems r + +insert :: Key -> a -> PrefixTree a -> PrefixTree a +insert ks a t = + case t of + + Empty -> singleton ks a + + Prefix ls mb t' -> + case matchPrefix ks ls of + -- empty node + ([],[],[]) -> Prefix [] (Just a) t' + + -- empty key + ([],[],_) -> Prefix [] (Just a) t + + -- empty node, full key + ([],_,[]) -> Prefix [] mb (insert ks a t') + + -- no common prefix, branch. + ([], k:_, _) + | k -> Branch (singleton ks a) t + | otherwise -> Branch t (singleton ks a) + + -- complete match, replace the value + (_, [], []) -> Prefix ks (Just a) t' + + -- complete prefix match, but partial key match + (_ ,ks',[]) -> Prefix ls mb (insert ks' a t') + + -- complete key match, partial prefix match + (_,[],ls') -> Prefix ks (Just a) (Prefix ls' mb t') + + -- partial common prefix, but not the full key + (ps,ks'@(k:_),ls') -> Prefix ps Nothing br + where + t1 = singleton ks' a + t2 = Prefix ls' mb t' + br | k = Branch t1 t2 + | otherwise = Branch t2 t1 + + Branch l r -> + case ks of + [] -> Prefix [] (Just a) t + b:_ | b -> Branch (insert ks a l) r + | otherwise -> Branch l (insert ks a r) + + +delete :: Key -> PrefixTree a -> PrefixTree a +delete ks t = + case t of + + Empty -> Empty + + Prefix ls mb t' -> + case matchPrefix ks ls of + (_,[],[]) -> compact (Prefix ls Nothing t') + ([],ks',[]) -> compact (Prefix ls mb (delete ks' t')) + _ -> t + + Branch l r -> + case ks of + [] -> t + b:bs | b -> compact (Branch (delete bs l) r) + | otherwise -> compact (Branch l (delete bs r)) + +compact :: PrefixTree a -> PrefixTree a +compact t = + case t of + + Prefix ls Nothing (Prefix ks mb t') -> Prefix (ls ++ ks) mb t' + + Branch l Empty -> l + Branch Empty r -> r + + _ -> t + + +-- Querying -------------------------------------------------------------------- + +member :: Key -> PrefixTree a -> Bool +member ks t = + case t of + + Empty -> False + + Prefix ls mb t' -> + case matchPrefix ks ls of + (_,[], []) -> isJust mb + (_,ks',[]) -> member ks' t' + _ -> False + + Branch l r -> + case ks of + [] -> False + b:_ | b -> member ks l + | otherwise -> member ks r + +matches :: Key -> PrefixTree a -> [a] +matches = loop [] + where + loop ms ks t = + case t of + + Empty -> ms + + Prefix ls mb t' -> + case matchPrefix ks ls of + (_,[], []) -> maybe ms (:ms) mb + (_,ks',[]) -> loop (maybe ms (:ms) mb) ks' t' + _ -> ms + + Branch l r -> + case ks of + [] -> ms + b:_ | b -> loop ms ks l + | otherwise -> loop ms ks r + +match :: Key -> PrefixTree a -> Maybe a +match k t = listToMaybe (matches k t) + +lookup :: Key -> PrefixTree a -> Maybe a +lookup = match + +keys :: Key -> PrefixTree a -> [Key] +keys = keys' [] [] + where + keys' as p ks t = + case t of + + Empty -> as + + Prefix ls _ t' -> + case matchPrefix ks ls of + (ps,ks',[]) -> keys' (p':as) p' ks' t' + where p' = p ++ ps + _ -> as + + Branch l r -> keys' ls p ks r + where ls = keys' as p ks l + +key :: Key -> PrefixTree a -> Maybe Key +key ks t = listToMaybe (keys ks t) + + +-- Tests ----------------------------------------------------------------------- + +#ifdef TESTS +forAllUniqueLists :: (Testable prop, Arbitrary a, Show a, Eq a) + => ([a] -> prop) -> Property +forAllUniqueLists = forAll (nub `fmap` arbitrary) + +prop_toList_fromList = forAllUniqueLists p + where + p :: [([Bool],())] -> Bool + p bs = length bs == length bs' && all (`elem` bs) bs' + where bs' = toList (fromList bs) + +prop_matchesOrder bs = and (map (f . fst) bs) + where + t1 = fromList bs + t2 = fromList (reverse bs) + f k = matches k t1 == matches k t2 +#endif diff --git a/src/Hans/Address.hs b/src/Hans/Address.hs new file mode 100644 index 0000000..1c4b2fe --- /dev/null +++ b/src/Hans/Address.hs @@ -0,0 +1,24 @@ +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE FlexibleContexts #-} + +module Hans.Address where + +import Data.Serialize (Serialize) +import Data.Word (Word8) + +class (Ord a, Serialize a) => Address a where + addrSize :: a -> Word8 + toBits :: a -> [Bool] + + +class Address addr => Mask mask addr | addr -> mask, mask -> addr where + masksAddress :: mask -> addr -> Bool + withMask :: addr -> Int -> mask + getMaskComponents :: mask -> (addr,Int) + getMaskRange :: mask -> (addr,addr) + broadcastAddress :: mask -> addr + + +isBroadcast :: (Eq addr, Mask mask addr) => mask -> addr -> Bool +isBroadcast m a = broadcastAddress m == a diff --git a/src/Hans/Address/IP4.hs b/src/Hans/Address/IP4.hs new file mode 100644 index 0000000..bfa7993 --- /dev/null +++ b/src/Hans/Address/IP4.hs @@ -0,0 +1,104 @@ +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE MultiParamTypeClasses #-} + +module Hans.Address.IP4 where + +import Hans.Address +import Hans.Utils (Endo) + +import Control.Monad (guard,liftM2) +import Data.Serialize (Serialize(..)) +import Data.Serialize.Get (getWord32be) +import Data.Serialize.Put (putWord32be) +import Data.Bits (Bits((.&.),(.|.),shiftL,shiftR)) +import Data.Data (Data) +import Data.List (intersperse) +import Data.Typeable (Typeable) +import Data.Word (Word8,Word32) +import Numeric (readDec) + + +data IP4 = IP4 + {-# UNPACK #-} !Word8 + {-# UNPACK #-} !Word8 + {-# UNPACK #-} !Word8 + {-# UNPACK #-} !Word8 + deriving (Ord,Eq,Typeable,Data) + +instance Address IP4 where + addrSize _ = 4 + + toBits (IP4 a b c d) = f 0x80 a (f 0x80 b (f 0x80 c (f 0x80 d []))) + where + f 0 _ xs = xs + f m i xs = (i .&. m == 0) : f (m `shiftR` 1) i xs + +instance Serialize IP4 where + get = do + n <- getWord32be + return $! convertFromWord32 n + put ip = putWord32be (convertToWord32 ip) + +instance Show IP4 where + showsPrec _ (IP4 a b c d) = foldl (.) id + $ intersperse (showChar '.') + [shows a, shows b, shows c, shows d] + +instance Read IP4 where + readsPrec _ rest0 = do + (a, '.':rest1) <- readDec rest0 + (b, '.':rest2) <- readDec rest1 + (c, '.':rest3) <- readDec rest2 + (d, rest4) <- readDec rest3 + return (IP4 a b c d, rest4) + +convertToWord32 :: IP4 -> Word32 +convertToWord32 (IP4 a b c d) + = fromIntegral a `shiftL` 24 + + fromIntegral b `shiftL` 16 + + fromIntegral c `shiftL` 8 + + fromIntegral d + +convertFromWord32 :: Word32 -> IP4 +convertFromWord32 n = IP4 a b c d + where + a = fromIntegral (n `shiftR` 24) + b = fromIntegral (n `shiftR` 16) + c = fromIntegral (n `shiftR` 8) + d = fromIntegral n + + +data IP4Mask = IP4Mask + {-# UNPACK #-} !IP4 + {-# UNPACK #-} !Word8 + deriving (Eq,Ord,Typeable,Data,Show) + +instance Serialize IP4Mask where + put (IP4Mask i m) = put i >> put m + get = liftM2 IP4Mask get get + +instance Read IP4Mask where + readsPrec x rest0 = do + (addr,'/':rest1) <- readsPrec x rest0 + (bits, rest2) <- readsPrec x rest1 + guard (bits >= 0 && bits <= 32) + return (IP4Mask addr bits, rest2) + +instance Mask IP4Mask IP4 where + masksAddress mask@(IP4Mask _ bits) a2 = + clearHostBits mask == clearHostBits (IP4Mask a2 bits) + getMaskRange x = (clearHostBits x, setHostBits x) + withMask addr bits = IP4Mask addr (fromIntegral bits) + getMaskComponents (IP4Mask addr bits) = (addr,fromIntegral bits) + broadcastAddress = setHostBits + +modifyAsWord32 :: Endo Word32 -> Endo IP4 +modifyAsWord32 f = convertFromWord32 . f . convertToWord32 + +clearHostBits :: IP4Mask -> IP4 +clearHostBits (IP4Mask addr bits) = modifyAsWord32 (.&. mask) addr + where mask = -2 ^ (32 - bits) + +setHostBits :: IP4Mask -> IP4 +setHostBits (IP4Mask addr bits) = modifyAsWord32 (.|. mask) addr + where mask = 2 ^ (32 - bits) - 1 diff --git a/src/Hans/Address/Mac.hs b/src/Hans/Address/Mac.hs new file mode 100644 index 0000000..9973286 --- /dev/null +++ b/src/Hans/Address/Mac.hs @@ -0,0 +1,77 @@ + +module Hans.Address.Mac ( + Mac(..) + , showsMac + , macMask + ) where + +import Hans.Address +import Hans.Utils (showPaddedHex) + +import Data.Serialize (Serialize(..)) +import Data.Serialize.Get (getWord16be,getWord32be) +import Data.Serialize.Put (putByteString) +import Data.Bits (Bits(shiftR,testBit,complement)) +import Data.List (intersperse) +import Data.Word (Word8) +import Numeric (readHex) +import qualified Data.ByteString as S + + +-- | Mac addresses. +data Mac = Mac + {-# UNPACK #-} !Word8 + {-# UNPACK #-} !Word8 + {-# UNPACK #-} !Word8 + {-# UNPACK #-} !Word8 + {-# UNPACK #-} !Word8 + {-# UNPACK #-} !Word8 + deriving ( Eq, Ord ) + + +-- | Show a Mac address. +showsMac :: Mac -> ShowS +showsMac (Mac a b c d e f) = foldl1 (.) + $ intersperse (showChar ':') + $ map showPaddedHex [a,b,c,d,e,f] + +-- | Generates a mask tailored to the given MAC address. +macMask :: Mac -> Mac +macMask (Mac a b c d e f) = + Mac (complement a) + (complement b) + (complement c) + (complement d) + (complement e) + (complement f) + +instance Show Mac where + showsPrec _ = showsMac + +instance Read Mac where + readsPrec _ = loop 6 [] + where + loop :: Int -> [Word8] -> String -> [(Mac,String)] + loop 0 [f,e,d,c,b,a] str = [(Mac a b c d e f,str)] + loop 0 _ _ = [] + loop n acc str = case readHex str of + [(a,':':rest)] -> loop (n-1) (a:acc) rest + [(a, rest)] -> loop 0 (a:acc) rest + _ -> [] + + +instance Address Mac where + addrSize _ = 6 + + toBits (Mac a b c d e f) = concatMap k [a,b,c,d,e,f] + where k i = map (testBit i) [0 .. 7] + +instance Serialize Mac where + get = do + n <- getWord32be + m <- getWord16be + let f x d = fromIntegral (x `shiftR` d) + return $! Mac (f n 24) (f n 16) (f n 8) (fromIntegral n) + (f m 8) (fromIntegral m) + + put (Mac a b c d e f) = putByteString (S.pack [a,b,c,d,e,f]) diff --git a/src/Hans/Channel.hs b/src/Hans/Channel.hs new file mode 100644 index 0000000..b24282d --- /dev/null +++ b/src/Hans/Channel.hs @@ -0,0 +1,34 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE CPP #-} + +module Hans.Channel where + +#ifdef BOUNDED_CHANNELS +import Control.Concurrent.BoundedChan as C + +type Channel a = BoundedChan a + +newChannel :: IO (Channel a) +newChannel = newBoundedChan 20 + +receive :: Channel a -> IO a +receive c = readChan c + +send :: Channel a -> a -> IO () +send c a = writeChan c $! a + +#else +import Control.Concurrent.Chan as C + +type Channel a = Chan a + +newChannel :: IO (Channel a) +newChannel = newChan + +receive :: Channel a -> IO a +receive = readChan + +send :: Channel a -> a -> IO () +send c a = writeChan c $! a + +#endif diff --git a/src/Hans/Device/Ivc.hs b/src/Hans/Device/Ivc.hs new file mode 100644 index 0000000..27159e2 --- /dev/null +++ b/src/Hans/Device/Ivc.hs @@ -0,0 +1,30 @@ +{-# LANGUAGE MultiParamTypeClasses #-} + +module Hans.Device.Ivc ( + ivcSend + , ivcReceiveLoop + , Bytes(getBytes) + ) where + +import Hans.Layer.Ethernet (EthernetHandle,queueEthernet) + +import Communication.IVC as IVC (put,OutChannelEx,get,InChannelEx,Bin) +import Control.Monad (forever,when) +import Data.Serialize (Serialize(get,put),getByteString,putByteString,remaining) +import qualified Data.ByteString as S + +newtype Bytes = Bytes + { getBytes :: S.ByteString + } deriving Show + +instance Serialize Bytes where + get = Bytes `fmap` (getByteString =<< remaining) + put = putByteString . getBytes + +ivcSend :: OutChannelEx Bin Bytes -> S.ByteString -> IO () +ivcSend chan = IVC.put chan . Bytes + +ivcReceiveLoop :: InChannelEx Bin Bytes -> EthernetHandle -> IO () +ivcReceiveLoop chan eth = forever $ do + Bytes bs <- IVC.get chan + when (S.length bs > 14) (queueEthernet eth bs) diff --git a/src/Hans/Device/Tap.hs b/src/Hans/Device/Tap.hs new file mode 100644 index 0000000..b5bd310 --- /dev/null +++ b/src/Hans/Device/Tap.hs @@ -0,0 +1,62 @@ +{-# LANGUAGE ForeignFunctionInterface #-} + +module Hans.Device.Tap where + +import Hans.Layer.Ethernet +import Hans.Utils + +import Control.Concurrent (threadWaitRead) +import Control.Monad (forever) +import Data.Word (Word8) +import Foreign.C.String (CString,withCString) +import Foreign.C.Types (CLong,CSize) +import Foreign.ForeignPtr (withForeignPtr) +import Foreign.Ptr (Ptr) +import System.Posix.Types (Fd) +import qualified Data.ByteString as S +import qualified Data.ByteString.Internal as S + + + +-- | Open a device by name. +openTapDevice :: DeviceName -> IO (Maybe Fd) +openTapDevice "" = return Nothing +openTapDevice name = withCString name $ \str -> do + ret <- c_init_tap_device str + if ret < 0 then return Nothing else return (Just ret) + + +-- | Send an ethernet frame via a tap device. +tapSend :: Fd -> Packet -> IO () +tapSend fd packet = do + let (fptr, 0, len) = S.toForeignPtr packet + _res <- withForeignPtr fptr $ \ptr -> c_write fd ptr (fromIntegral len) + -- XXX: make sure to continue sending if res < len + return () + + +-- | Fork a reciever loop, and return an IO action to kill the running thread. +tapReceiveLoop :: Fd -> EthernetHandle -> IO () +tapReceiveLoop fd eh = forever (k =<< tapReceive fd) + where k pkt = queueEthernet eh pkt + + +-- | Recieve an ethernet frame from a tap device. +tapReceive :: Fd -> IO Packet +tapReceive fd = do + threadWaitRead fd + let packet ptr = fromIntegral `fmap` c_read fd ptr 1514 + bs <- S.createAndTrim 1514 packet + if S.length bs <= 14 + then tapReceive fd + else return bs + + +foreign import ccall unsafe "init_tap_device" + c_init_tap_device :: CString -> IO Fd + +foreign import ccall unsafe "write" + c_write :: Fd -> Ptr Word8 -> CSize -> IO CLong + +foreign import ccall unsafe "read" + c_read :: Fd -> Ptr Word8 -> CSize -> IO CLong diff --git a/src/Hans/Device/Xen.hs b/src/Hans/Device/Xen.hs new file mode 100644 index 0000000..caac398 --- /dev/null +++ b/src/Hans/Device/Xen.hs @@ -0,0 +1,43 @@ +module Hans.Device.Xen where + +import Hans.Layer.Ethernet +import Hans.Utils + +import Data.Maybe (listToMaybe) +import XenDevice.NIC as Xen +import qualified Data.ByteString as S +import qualified Data.ByteString.Lazy as L + + +-- Utilities ------------------------------------------------------------------- + +infixl 1 >>=? +(>>=?) :: Monad m => m (Maybe a) -> (a -> m (Maybe b)) -> m (Maybe b) +m >>=? f = do + mb <- m + case mb of + Nothing -> return Nothing + Just a -> f a + +returnJust :: Monad m => a -> m (Maybe a) +returnJust = return . Just + + +-- Xen NIC --------------------------------------------------------------------- + + +openXenDevice :: String -> IO (Maybe NIC) +openXenDevice _ = + listToMaybe `fmap` Xen.potentialNICs >>=? \ dev -> + initializeNIC dev Nothing + + +xenSend :: NIC -> S.ByteString -> IO () +xenSend nic bs = void (Xen.transmitPacket nic (L.fromChunks [bs])) + + +xenReceiveLoop :: NIC -> EthernetHandle -> IO () +xenReceiveLoop nic eh = Xen.setReceiveHandler nic k + where + k bs | L.length bs <= 14 = return () + | otherwise = queueEthernet eh (S.concat (L.toChunks bs)) diff --git a/src/Hans/DhcpClient.hs b/src/Hans/DhcpClient.hs new file mode 100644 index 0000000..b4a3d7f --- /dev/null +++ b/src/Hans/DhcpClient.hs @@ -0,0 +1,217 @@ +module Hans.DhcpClient ( + dhcpDiscover + ) where + +import Hans.Address +import Hans.Address.IP4 +import Hans.Address.Mac +import Hans.Layer.Ethernet (sendEthernet,addEthernetHandler) +import Hans.Layer.IP4 (connectEthernet) +import Hans.Layer.Timer (delay) +import Hans.Layer.Udp (addUdpHandler,removeUdpHandler,queueUdp) +import Hans.Message.Dhcp4 +import Hans.Message.Dhcp4Codec +import Hans.Message.Dhcp4Options +import Hans.Message.EthernetFrame +import Hans.Message.Ip4 +import Hans.Message.Udp +import Hans.Setup + +import Control.Monad (guard) +import Data.Serialize (runGet,runPut) +import Data.Maybe (fromMaybe) +import System.Random (randomIO) +import qualified Data.ByteString as S + + +-- Protocol Constants ---------------------------------------------------------- + +-- | BOOTP server port. +bootps :: UdpPort +bootps = UdpPort 67 + +-- | BOOTP client port. +bootpc :: UdpPort +bootpc = UdpPort 68 + +currentNetwork :: IP4 +currentNetwork = IP4 0 0 0 0 + +broadcastIP4 :: IP4 +broadcastIP4 = IP4 255 255 255 255 + +broadcastMac :: Mac +broadcastMac = Mac 0xff 0xff 0xff 0xff 0xff 0xff + +udpProtocol :: IP4Protocol +udpProtocol = IP4Protocol 0x11 + +ethernetIp4 :: EtherType +ethernetIp4 = EtherType 0x0800 + +defaultRoute :: IP4Mask +defaultRoute = IP4 0 0 0 0 `withMask` 0 + + +-- DHCP ------------------------------------------------------------------------ + +type AckHandler = IP4 -> IO () + +-- | Discover a dhcp server, and request an address. +dhcpDiscover :: NetworkStack -> Mac -> AckHandler -> IO () +dhcpDiscover ns mac h = do + w32 <- randomIO + let xid = Xid (fromIntegral (w32 :: Int)) + + addEthernetHandler (nsEthernet ns) ethernetIp4 (dhcpIP4Handler ns) + addUdpHandler (nsUdp ns) bootpc (handleOffer ns (Just h)) + + let disc = discoverToMessage (mkDiscover xid mac) + sendMessage ns disc currentNetwork broadcastIP4 broadcastMac + +-- | Restore the connection between the Ethernet and IP4 layers. +restoreIp4 :: NetworkStack -> IO () +restoreIp4 ns = connectEthernet (nsIp4 ns) (nsEthernet ns) + +-- | Handle IP4 messages from the Ethernet layer, passing all relevant DHCP +-- messages to the UDP layer. +dhcpIP4Handler :: NetworkStack -> S.ByteString -> IO () +dhcpIP4Handler ns bytes = + case runGet parseIP4Packet bytes of + Left err -> putStrLn err >> return () + Right (hdr,ihl,len) + | ip4Protocol hdr == udpProtocol -> queue + | otherwise -> return () + where + queue = queueUdp (nsUdp ns) (ip4SourceAddr hdr) (ip4DestAddr hdr) + $ S.take (len - ihl) + $ S.drop ihl bytes + +-- | Handle a DHCP Offer message. +-- +-- * Remove the current UDP handler +-- * Install an DHCP Ack handler +-- * Send a DHCP Request +handleOffer :: NetworkStack -> Maybe AckHandler -> IP4 -> UdpPort + -> S.ByteString -> IO () +handleOffer ns mbh _src _srcPort bytes = + case runGet (getDhcp4Message) bytes of + Right msg -> case parseDhcpMessage msg of + + Just (Right (OfferMessage offer)) -> do + removeUdpHandler (nsUdp ns) bootpc + let req = requestToMessage (offerToRequest offer) + addUdpHandler (nsUdp ns) bootpc (handleAck ns offer mbh) + sendMessage ns req currentNetwork broadcastIP4 broadcastMac + + msg1 -> do + putStrLn (show msg) + putStrLn (show msg1) + + Left err -> putStrLn err + +-- | Handle a DHCP Ack message. +-- +-- * Remove the custom IP4 handler +-- * Restore the connection between the Ethernet and IP4 layers +-- * Remove the bootpc UDP listener +-- * Configure the network stack with options from the Ack +-- * Install a timer that renews the address after 50% of the lease time +-- has passed +handleAck :: NetworkStack -> Offer -> Maybe AckHandler -> IP4 -> UdpPort + -> S.ByteString -> IO () +handleAck ns offer mbh _src _srcPort bytes = + case runGet (getDhcp4Message) bytes of + Right msg -> case parseDhcpMessage msg of + + Just (Right (AckMessage ack)) -> do + removeUdpHandler (nsUdp ns) bootpc + restoreIp4 ns + apply (ackNsOptions ack) ns + let ms = fromIntegral (ackLeaseTime ack) * 500 + delay (nsTimers ns) ms (dhcpRenew ns offer) + putStrLn ("Bound to: " ++ show (ackYourAddr ack)) + + case mbh of + Nothing -> return () + Just h -> h (ackYourAddr ack) + + msg1 -> do + putStrLn (show msg) + putStrLn (show msg1) + + Left err -> putStrLn err + +-- | Perform a DHCP Renew. +-- +-- * Re-install the DHCP IP4 handler +-- * Add a UDP handler for an Ack message +-- * Re-send a renquest message, generated from the offer given. +dhcpRenew :: NetworkStack -> Offer -> IO () +dhcpRenew ns offer = do + addEthernetHandler (nsEthernet ns) ethernetIp4 (dhcpIP4Handler ns) + + let req = requestToMessage (offerToRequest offer) + addUdpHandler (nsUdp ns) bootpc (handleAck ns offer Nothing) + sendMessage ns req currentNetwork broadcastIP4 broadcastMac + + +-- NetworkStack Config --------------------------------------------------------- + +lookupGateway :: [Dhcp4Option] -> Maybe IP4 +lookupGateway = foldr p Nothing + where + p (OptRouters rs) _ = guard (not (null rs)) >> Just (head rs) + p _ a = a + +lookupSubnet :: [Dhcp4Option] -> Maybe Int +lookupSubnet = foldr p Nothing + where + p (OptSubnetMask (SubnetMask i)) _ = Just i + p _ a = a + +-- | Produce options for the network stack from a DHCP Ack. +ackNsOptions :: Ack -> [SomeOption] +ackNsOptions ack = + [ toOption (LocalEthernet (addr `withMask` mask) mac) + , toOption (Route defaultRoute gateway) + ] + where + mac = ackClientHardwareAddr ack + addr = ackYourAddr ack + opts = ackOptions ack + mask = fromMaybe 24 (lookupSubnet opts) + gateway = fromMaybe (ackRelayAddr ack) (lookupGateway opts) + + +-- Packet Helpers -------------------------------------------------------------- + +sendMessage :: NetworkStack -> Dhcp4Message -> IP4 -> IP4 -> Mac -> IO () +sendMessage ns resp src dst hwdst = do + ipBytes <- mkIpBytes src dst bootpc bootps + (runPut (putDhcp4Message resp)) + let mac = dhcp4ClientHardwareAddr resp + let frame = EthernetFrame + { etherDest = hwdst + , etherSource = mac + , etherType = ethernetIp4 + , etherData = ipBytes + } + putStrLn (show mac ++ " -> " ++ show hwdst) + + sendEthernet (nsEthernet ns) frame + +mkIpBytes :: IP4 -> IP4 -> UdpPort -> UdpPort -> S.ByteString -> IO S.ByteString +mkIpBytes srcAddr dstAddr srcPort dstPort payload = do + udpBytes <- do + let udpHdr = UdpHeader srcPort dstPort 0 + udp = UdpPacket udpHdr payload + mk = mkIP4PseudoHeader srcAddr dstAddr udpProtocol + renderUdpPacket udp mk + + ipBytes <- do + let ipHdr = emptyIP4Header udpProtocol srcAddr dstAddr + ip = IP4Packet ipHdr udpBytes + renderIP4Packet ip + + return ipBytes diff --git a/src/Hans/Layer.hs b/src/Hans/Layer.hs new file mode 100644 index 0000000..57cd4b7 --- /dev/null +++ b/src/Hans/Layer.hs @@ -0,0 +1,140 @@ +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE Rank2Types #-} + +module Hans.Layer where + +import Hans.Utils (just) + +import Control.Applicative (Applicative(..),Alternative(..)) +import Control.Monad (ap,MonadPlus(mzero,mplus)) +import Data.Monoid (Monoid(..)) +import Data.Time.Clock.POSIX +import MonadLib (StateM(get,set)) +import qualified Control.Exception as X +import qualified Data.Map as Map + +data LayerState i = LayerState + { lsNow :: POSIXTime + , lsState :: i + } + +data Action = Nop | Action (IO ()) + +instance Monoid Action where + mempty = Nop + + mappend (Action a) (Action b) = Action (a >> b) + mappend Nop b = b + mappend a _ = a + +runAction :: Action -> IO () +runAction Nop = return () +runAction (Action m) = m `X.catch` \ se -> print (se :: X.SomeException) + +data Result i a + = Error Action + | Result (LayerState i) a Action + +-- | Failure continuation +type Failure i r = Action -> Result i r + +-- | Success continuation +type Success a i r = a -> LayerState i -> Action -> Result i r + +newtype Layer i a = Layer + { getLayer :: forall r. LayerState i -> Action + -> Failure i r -> Success a i r + -> Result i r } + +runLayer :: LayerState i -> Layer i a -> Result i a +runLayer i0 m = getLayer m i0 mempty Error success + where success a i o = Result i a o + +loopLayer :: i -> IO msg -> (msg -> Layer i ()) -> IO () +loopLayer i0 msg k = loop (LayerState 0 i0) + where + loop i = do + a <- msg + now <- getPOSIXTime + let res = runLayer (i {lsNow = now }) (k a) + X.evaluate res `X.catch` \ se -> print (se :: X.SomeException) >> return res + case res of + Error m -> runAction m >> loop i + Result i' () m -> runAction m >> loop i' + +instance Functor (Layer i) where + fmap g m = Layer (\i0 o0 f k -> getLayer m i0 o0 f (\a i o -> k (g a) i o)) + +instance Applicative (Layer i) where + pure = return + (<*>) = ap + +instance Alternative (Layer i) where + empty = Layer (\_ o0 f _ -> f o0) + a <|> b = Layer (\i o f k -> getLayer a i o (\_ -> getLayer b i o f k) k) + +instance Monad (Layer i) where + return x = Layer (\i o _ k -> k x i o) + m >>= g = Layer $ \i0 o0 f k -> getLayer m i0 o0 f $ \a i o -> + getLayer (g a) i o f k + +instance MonadPlus (Layer i) where + mzero = empty + mplus = (<|>) + +instance StateM (Layer i) i where + get = Layer (\i0 o0 _ k -> k (lsState i0) i0 o0) + set i = Layer (\i0 o0 _ k -> k () (i0 { lsState = i }) o0) + + +-- Utilities ------------------------------------------------------------------- + +dropPacket :: Layer i a +dropPacket = empty + +time :: Layer i POSIXTime +time = Layer $ \i0 o0 _ k -> k (lsNow i0) i0 o0 + +output :: IO () -> Layer i () +output m = Layer $ \i0 o0 _ k -> k () i0 (o0 `mappend` Action m) + +liftRight :: Either String b -> Layer i b +liftRight (Right b) = return b +liftRight (Left err) = do + output (putStrLn err) + dropPacket + +-- Handler Generalization ------------------------------------------------------ + +type Handlers k a = Map.Map k a + +emptyHandlers :: Handlers k a +emptyHandlers = Map.empty + + +class ProvidesHandlers i k a | i -> k a where + getHandlers :: i -> Handlers k a + setHandlers :: Handlers k a -> i -> i + + +getHandler :: (Ord k, ProvidesHandlers i k a) => k -> Layer i a +getHandler k = do + state <- get + just (Map.lookup k (getHandlers state)) + + +addHandler :: (Ord k, ProvidesHandlers i k a) => k -> a -> Layer i () +addHandler k a = do + state <- get + let hs' = Map.insert k a (getHandlers state) + hs' `seq` set (setHandlers hs' state) + + +removeHandler :: (Ord k, ProvidesHandlers i k a) => k -> Layer i () +removeHandler k = do + state <- get + let hs' = Map.delete k (getHandlers state) + hs' `seq` set (setHandlers hs' state) diff --git a/src/Hans/Layer/Arp.hs b/src/Hans/Layer/Arp.hs new file mode 100644 index 0000000..e8b867c --- /dev/null +++ b/src/Hans/Layer/Arp.hs @@ -0,0 +1,243 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE FlexibleContexts #-} + +module Hans.Layer.Arp ( + ArpHandle + , runArpLayer + + -- External Interface + , arpWhoHas + , arpIP4Packet + , addLocalAddress + ) where + +import Hans.Address.IP4 +import Hans.Address.Mac +import Hans.Channel +import Hans.Layer +import Hans.Layer.Arp.Table +import Hans.Layer.Ethernet +import Hans.Layer.Timer +import Hans.Message.Arp +import Hans.Message.EthernetFrame +import Hans.Utils + +import Control.Concurrent (forkIO,takeMVar,putMVar,newEmptyMVar) +import Control.Monad (forM_,mplus,guard,unless,when) +import Data.Serialize (decode,encode) +import MonadLib (BaseM(inBase),set,get) +import qualified Data.Map as Map + +-- | A handle to a running arp layer. +type ArpHandle = Channel (Arp ()) + + +-- | Start an arp layer. +runArpLayer :: ArpHandle -> EthernetHandle -> TimerHandle -> IO () +runArpLayer h eth th = do + addEthernetHandler eth (EtherType 0x0806) (send h . handleIncoming) + let i = emptyArpState h eth th + void (forkIO (loopLayer i (receive h) id)) + + +-- | Lookup the hardware address associated with an IP address. +arpWhoHas :: BaseM m IO => ArpHandle -> IP4 -> m (Maybe Mac) +arpWhoHas h !ip = inBase $ do + var <- newEmptyMVar + send h (whoHas ip (putMVar var)) + takeMVar var + + +-- | Send an IP packet via the arp layer, to resolve the underlying hardware +-- addresses. +arpIP4Packet :: ArpHandle -> IP4 -> IP4 -> Packet -> IO () +arpIP4Packet h !src !dst !pkt = send h (handleOutgoing src dst pkt) + + +addLocalAddress :: ArpHandle -> IP4 -> Mac -> IO () +addLocalAddress h !ip !mac = send h (handleAddAddress ip mac) + + +-- Message Handling ------------------------------------------------------------ + +type Arp = Layer ArpState + +data ArpState = ArpState + { arpTable :: ArpTable + , arpAddrs :: Map.Map IP4 Mac -- this layer's addresses + , arpWaiting :: Map.Map IP4 [Maybe Mac -> IO ()] + , arpEthernet :: EthernetHandle + , arpTimers :: TimerHandle + , arpSelf :: ArpHandle + } + +emptyArpState :: ArpHandle -> EthernetHandle -> TimerHandle -> ArpState +emptyArpState h eth ts = ArpState + { arpTable = Map.empty + , arpAddrs = Map.empty + , arpWaiting = Map.empty + , arpEthernet = eth + , arpTimers = ts + , arpSelf = h + } + +ethernetHandle :: Arp EthernetHandle +ethernetHandle = arpEthernet `fmap` get + +timerHandle :: Arp TimerHandle +timerHandle = arpTimers `fmap` get + +addEntry :: IP4 -> Mac -> Arp () +addEntry spa sha = do + state <- get + now <- time + let table' = addArpEntry now spa sha (arpTable state) + table' `seq` set state { arpTable = table' } + runWaiting spa (Just sha) + +addWaiter :: IP4 -> (Maybe Mac -> IO ()) -> Arp () +addWaiter addr cont = do + state <- get + set state { arpWaiting = Map.alter f addr (arpWaiting state) } + where + f Nothing = Just [cont] + f (Just ks) = Just (cont:ks) + +runWaiting :: IP4 -> Maybe Mac -> Arp () +runWaiting spa sha = do + state <- get + let (mb,waiting') = Map.updateLookupWithKey f spa (arpWaiting state) + where f _ _ = Nothing + -- run the callbacks associated with this protocol address + let run cb = output (cb sha) + mapM_ run (maybe [] reverse mb) + waiting' `seq` set state { arpWaiting = waiting' } + +updateExistingEntry :: IP4 -> Mac -> Arp Bool +updateExistingEntry spa sha = do + state <- get + let update = do + guard (spa `Map.member` arpTable state) + addEntry spa sha + return True + update `mplus` return False + +localHwAddress :: IP4 -> Arp Mac +localHwAddress pa = do + state <- get + just (Map.lookup pa (arpAddrs state)) + +sendArpPacket :: ArpPacket Mac IP4 -> Arp () +sendArpPacket msg = do + eth <- ethernetHandle + let frame = EthernetFrame + { etherSource = arpSHA msg + , etherDest = arpTHA msg + , etherType = 0x0806 + , etherData = encode msg + } + output (sendEthernet eth frame) + +advanceArpTable :: Arp () +advanceArpTable = do + now <- time + state <- get + let (table', timedOut) = stepArpTable now (arpTable state) + set state { arpTable = table' } + forM_ timedOut $ \ x -> runWaiting x Nothing + +-- | Handle a who-has request +whoHas :: IP4 -> (Maybe Mac -> IO ()) -> Arp () +whoHas ip k = (k' =<< localHwAddress ip) `mplus` query + where + k' addr = output (k (Just addr)) + + query = do + advanceArpTable + state <- get + case lookupArpEntry ip (arpTable state) of + KnownAddress mac -> k' mac + Pending -> addWaiter ip k + Unknown -> do + let addrs = Map.toList (arpAddrs state) + msg (spa,sha) = ArpPacket + { arpHwType = 0x1 + , arpPType = 0x0800 + , arpSHA = sha + , arpSPA = spa + , arpTHA = Mac 0xff 0xff 0xff 0xff 0xff 0xff + , arpTPA = ip + , arpOper = ArpRequest + } + now <- time + let table' = addPending now ip (arpTable state) + set state { arpTable = table' } + addWaiter ip k + mapM_ (sendArpPacket . msg) addrs + th <- timerHandle + output (delay th 10000 (send (arpSelf state) advanceArpTable)) + +-- Message Handling ------------------------------------------------------------ + +-- | Process an incoming arp packet +handleIncoming :: Packet -> Arp () +handleIncoming bs = do + msg <- liftRight (decode bs) + -- ?Do I have the hardware type in ar$hrd + -- Yes: (This check is enforced by the type system) + -- [optionally check the hardware length ar$hln] + -- ?Do I speak the protocol in ar$pro? + -- Yes: (This check is also enforced by the type system) + -- [optionally check the protocol length ar$pln] + -- Merge_flag := false + -- If the pair is + -- already in my translation table, update the sender + -- hardware address field of the entry with the new + -- information in the packet and set Merge_flag to true. + let sha = arpSHA msg + let spa = arpSPA msg + merge <- updateExistingEntry spa sha + -- ?Am I the target protocol address? + let tpa = arpTPA msg + lha <- localHwAddress tpa + -- Yes: + -- If Merge_flag is false, add the triplet to + -- the translation table. + unless merge (addEntry spa sha) + -- ?Is the opcode ares_op$REQUEST? (NOW look at the opcode!!) + -- Yes: + when (arpOper msg == ArpRequest) $ do + -- Swap hardware and protocol fields, putting the local + -- hardware and protocol addresses in the sender fields. + let msg' = msg { arpSHA = lha , arpSPA = tpa + , arpTHA = sha , arpTPA = spa + -- Set the ar$op field to ares_op$REPLY + , arpOper = ArpReply } + -- Send the packet to the (new) target hardware address on + -- the same hardware on which the request was received. + sendArpPacket msg' + + +-- | Handle a request to associate an ip with a mac address for a local device +handleAddAddress :: IP4 -> Mac -> Arp () +handleAddAddress ip mac = do + state <- get + let addrs' = Map.insert ip mac (arpAddrs state) + addrs' `seq` set state { arpAddrs = addrs' } + + +-- | Output a packet to the ethernet layer. +handleOutgoing :: IP4 -> IP4 -> Packet -> Arp () +handleOutgoing src dst bs = do + eth <- ethernetHandle + lha <- localHwAddress src + let frame dha = EthernetFrame + { etherDest = dha + , etherSource = lha + , etherType = 0x0800 + , etherData = bs + } + whoHas dst $ \ res -> case res of + Nothing -> return () + Just dha -> sendEthernet eth (frame dha) diff --git a/src/Hans/Layer/Arp/Table.hs b/src/Hans/Layer/Arp/Table.hs new file mode 100644 index 0000000..bce4923 --- /dev/null +++ b/src/Hans/Layer/Arp/Table.hs @@ -0,0 +1,59 @@ +module Hans.Layer.Arp.Table where + +import Hans.Address.Mac +import Hans.Address.IP4 + +import Control.Arrow (second) +import Data.Time.Clock.POSIX (POSIXTime) +import qualified Data.Map as Map + + +-- Arp Table ------------------------------------------------------------------- + +arpEntryTimeout :: POSIXTime +arpEntryTimeout = 60 + +data ArpEntry + = ArpEntry { arpMac :: Mac + , arpTimeout :: POSIXTime + } + | ArpPending { arpTimeout :: POSIXTime + } + deriving Show + +data ArpResult + = KnownAddress Mac + | Pending + | Unknown + +type ArpTable = Map.Map IP4 ArpEntry + +stepArpTable :: POSIXTime -> ArpTable -> (ArpTable, [IP4]) +stepArpTable now tab = second (Map.keys) (Map.partition p tab) + where + p ent = arpTimeout ent >= now + +addArpEntry :: POSIXTime -> IP4 -> Mac -> ArpTable -> ArpTable +addArpEntry now ip mac = Map.insert ip ent where + ent = ArpEntry + { arpMac = mac + , arpTimeout = now + arpEntryTimeout + } + +-- | Assumption: there is not already a pending ARP query recorded in the +-- ARP table for the given IP address. +addPending :: POSIXTime -> IP4 -> ArpTable -> ArpTable +addPending now ip = Map.insert ip ent where + ent = ArpPending + { arpTimeout = now + arpEntryTimeout -- FIXME: should queries stay longer? + } + +-- | If the ARP table has a fully realized entry for the given IP address, +-- then return it. Otherwise return Pending if we're waiting for this info, +-- or Unknown if nothing is currently known about it. +lookupArpEntry :: IP4 -> ArpTable -> ArpResult +lookupArpEntry ip arp = + case Map.lookup ip arp of + Just (ArpEntry mac _) -> KnownAddress mac + Just (ArpPending _) -> Pending + _ -> Unknown diff --git a/src/Hans/Layer/Ethernet.hs b/src/Hans/Layer/Ethernet.hs new file mode 100644 index 0000000..f1bcae3 --- /dev/null +++ b/src/Hans/Layer/Ethernet.hs @@ -0,0 +1,182 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FlexibleContexts #-} + +module Hans.Layer.Ethernet ( + EthernetHandle + , runEthernetLayer + + -- * External Interface + , Tx + , Rx + , sendEthernet + , queueEthernet + , addEthernetDevice + , removeEthernetDevice + , addEthernetHandler + , removeEthernetHandler + , startEthernetDevice + , stopEthernetDevice + ) where + +import Hans.Address.Mac +import Hans.Channel +import Hans.Layer +import Hans.Message.EthernetFrame +import Hans.Utils (Packet,void,just) + +import Control.Concurrent (forkIO,ThreadId,killThread) +import Control.Monad (mplus) +import Data.Serialize (decode,encode) +import MonadLib (get,set) +import qualified Data.Map as Map + + +-- Messages -------------------------------------------------------------------- + +type Handler = Packet -> IO () + +type Tx = Packet -> IO () +type Rx = EthernetHandle -> IO () + +type EthernetHandle = Channel (Eth ()) + +-- | Run the ethernet layer. +runEthernetLayer :: EthernetHandle -> IO () +runEthernetLayer h = + void (forkIO (loopLayer (emptyEthernetState h) (receive h) id)) + +sendEthernet :: EthernetHandle -> EthernetFrame -> IO () +sendEthernet h !frame = send h (handleOutgoing frame) + +queueEthernet :: EthernetHandle -> Packet -> IO () +queueEthernet h !pkt = send h (handleIncoming pkt) + +startEthernetDevice :: EthernetHandle -> Mac -> IO () +startEthernetDevice h !m = send h (startDevice m) + +stopEthernetDevice :: EthernetHandle -> Mac -> IO () +stopEthernetDevice h !m = send h (stopDevice m) + +addEthernetDevice :: EthernetHandle -> Mac -> Tx -> Rx -> IO () +addEthernetDevice h !mac tx rx = send h (addDevice mac tx rx) + +removeEthernetDevice :: EthernetHandle -> Mac -> IO () +removeEthernetDevice h !mac = send h (delDevice mac) + +addEthernetHandler :: EthernetHandle -> EtherType -> Handler -> IO () +addEthernetHandler h !et k = send h (addHandler et k) + +removeEthernetHandler :: EthernetHandle -> EtherType -> IO () +removeEthernetHandler h !et = send h (removeHandler et) + + +-- Ethernet Message Monad ------------------------------------------------------ + +data EthernetDevice = EthernetDevice + { devTx :: Tx + , devRx :: IO () + , devUp :: Maybe ThreadId + } + +emptyDevice :: Tx -> IO () -> EthernetDevice +emptyDevice tx rx = EthernetDevice + { devTx = tx + , devRx = rx + , devUp = Nothing + } + + +type Eth = Layer EthernetState + +data EthernetState = EthernetState + { ethHandlers :: Handlers EtherType Handler + , ethDevices :: Map.Map Mac EthernetDevice + , ethHandle :: EthernetHandle + } + +instance ProvidesHandlers EthernetState EtherType Handler where + getHandlers = ethHandlers + setHandlers hs i = i { ethHandlers = hs } + +emptyEthernetState :: EthernetHandle -> EthernetState +emptyEthernetState h = EthernetState + { ethHandlers = emptyHandlers + , ethDevices = Map.empty + , ethHandle = h + } + +self :: Eth EthernetHandle +self = ethHandle `fmap` get + +-- Message Handling ------------------------------------------------------------ + +-- | Handle an incoming packet, from a device. +handleIncoming :: Packet -> Eth () +handleIncoming pkt = do + frame <- liftRight (decode pkt) + h <- getHandler (etherType frame) + output (h (etherData frame)) + + +-- | Get the device associated with a mac address. +getDevice :: Mac -> Eth EthernetDevice +getDevice mac = do + state <- get + just (Map.lookup mac (ethDevices state)) + + +-- | Set the device associated with a mac address. +setDevice :: Mac -> EthernetDevice -> Eth () +setDevice mac dev = do + state <- get + let ds' = Map.insert mac dev (ethDevices state) + ds' `seq` set state { ethDevices = ds' } + + +-- | Send an outgoing ethernet frame via the device that it's associated with. +handleOutgoing :: EthernetFrame -> Eth () +handleOutgoing frame = do + dev <- getDevice (etherSource frame) + output (devTx dev (encode frame)) + + +-- | Add an ethernet device to the state. +addDevice :: Mac -> Tx -> Rx -> Eth () +addDevice mac tx rx = do + stopDevice mac `mplus` return () + h <- self + setDevice mac (emptyDevice tx (rx h)) + + +-- | Remove a device +delDevice :: Mac -> Eth () +delDevice mac = do + stopDevice mac + state <- get + let ds' = Map.delete mac (ethDevices state) + ds' `seq` set state { ethDevices = ds' } + + +-- | Stop an ethernet device. +stopDevice :: Mac -> Eth () +stopDevice mac = do + dev <- getDevice mac + case devUp dev of + Nothing -> return () + Just tid -> do + output (killThread tid) + setDevice mac dev { devUp = Nothing } + + +-- | Start an ethernet device. +startDevice :: Mac -> Eth () +startDevice mac = do + dev <- getDevice mac + case devUp dev of + Just _ -> return () + -- XXX: add functionality to pipe the threadid back into the layer state. + Nothing -> output (void (forkIO (devRx dev))) + --setDevice mac dev { devUp = Just tid } diff --git a/src/Hans/Layer/IP4.hs b/src/Hans/Layer/IP4.hs new file mode 100644 index 0000000..31761a9 --- /dev/null +++ b/src/Hans/Layer/IP4.hs @@ -0,0 +1,220 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE FlexibleInstances #-} + +module Hans.Layer.IP4 ( + IP4Handle + , runIP4Layer + , Rule(..) + + -- * External Interface + , connectEthernet + , withIP4Source + , sendIP4Packet + , addIP4RoutingRule + , addIP4Handler + , removeIP4Handler + ) where + +import Hans.Address +import Hans.Address.IP4 +import Hans.Channel +import Hans.Layer +import Hans.Layer.Arp +import Hans.Layer.Ethernet +import Hans.Layer.IP4.Fragmentation +import Hans.Layer.IP4.Routing +import Hans.Message.EthernetFrame +import Hans.Message.Ip4 +import Hans.Utils +import Hans.Utils.Checksum + +import Control.Concurrent (forkIO) +import Control.Monad (guard,mplus,(<=<)) +import Data.Serialize.Get (runGet) +import MonadLib (get,set) +import qualified Data.ByteString as S + + +type Handler = IP4 -> IP4 -> Packet -> IO () + +type IP4Handle = Channel (IP ()) + +runIP4Layer :: IP4Handle -> ArpHandle -> EthernetHandle -> IO () +runIP4Layer h arp eth = do + void (forkIO (loopLayer (emptyIP4State arp) (receive h) id)) + connectEthernet h eth + +connectEthernet :: IP4Handle -> EthernetHandle -> IO () +connectEthernet h eth = + addEthernetHandler eth (EtherType 0x0800) (send h . handleIncoming) + +withIP4Source :: IP4Handle -> IP4 -> (IP4 -> IO ()) -> IO () +withIP4Source h !dst k = send h (handleSource dst k) + +addIP4RoutingRule :: IP4Handle -> Rule IP4Mask IP4 -> IO () +addIP4RoutingRule h !rule = send h (handleAddRule rule) + +sendIP4Packet :: IP4Handle -> IP4Protocol -> IP4 -> Packet -> IO () +sendIP4Packet h !prot !dst !pkt = send h (handleOutgoing prot dst pkt) + +addIP4Handler :: IP4Handle -> IP4Protocol -> Handler -> IO () +addIP4Handler h !prot k = send h (addHandler prot k) + +removeIP4Handler :: IP4Handle -> IP4Protocol -> IO () +removeIP4Handler h !prot = send h (removeHandler prot) + +-- IP4 State ------------------------------------------------------------------- + +type IP = Layer IP4State + +data IP4State = IP4State + { ip4Fragments :: FragmentationTable IP4 + , ip4Routes :: RoutingTable IP4 + , ip4Handlers :: Handlers IP4Protocol Handler + , ip4NextIdent :: Ident + , ip4ArpHandle :: ArpHandle + } + +instance ProvidesHandlers IP4State IP4Protocol Handler where + getHandlers = ip4Handlers + setHandlers hs i = i { ip4Handlers = hs } + + +emptyIP4State :: ArpHandle -> IP4State +emptyIP4State arp = IP4State + { ip4Fragments = emptyFragmentationTable + , ip4Routes = emptyRoutingTable + , ip4Handlers = emptyHandlers + , ip4NextIdent = 0 + , ip4ArpHandle = arp + } + + +-- IP4 Utilities --------------------------------------------------------------- + +arpHandle :: IP ArpHandle +arpHandle = ip4ArpHandle `fmap` get + +sendBytes :: IP4Protocol -> IP4 -> Packet -> IP () +sendBytes prot dst bs = do + rule@(src,_,mtu) <- findRoute dst + let hdr = emptyIP4Header prot src dst + hdr' <- if fromIntegral (S.length bs) + 20 < mtu + then return hdr + else do + i <- nextIdent + return (setIdent i hdr) + sendPacket' (IP4Packet hdr' bs) rule + +sendPacket :: IP4Packet -> IP () +sendPacket pkt = do + rule@(src,_,_) <- findRoute (ip4DestAddr (ip4Header pkt)) + guard (src /= ip4SourceAddr (ip4Header pkt)) + sendPacket' pkt rule + +-- | Send a packet using a given routing rule +sendPacket' :: IP4Packet -> (IP4,IP4,Mtu) -> IP () +sendPacket' pkt (src,dst,mtu) = do + arp <- arpHandle + output $ do + let frags = splitPacket mtu pkt + mapM_ (arpIP4Packet arp src dst <=< renderIP4Packet) frags + + +-- | Find a route to an address +findRoute :: IP4 -> IP (IP4,IP4,Mtu) +findRoute addr = do + state <- get + just (route addr (ip4Routes state)) + + +-- | Route a packet that is forwardable +forward :: IP4Packet -> IP () +forward pkt = sendPacket pkt + + +-- | Require that an address is local. +localAddress :: IP4 -> IP () +localAddress ip = do + state <- get + guard (ip `elem` localAddrs (ip4Routes state)) + + +findSourceMask :: IP4 -> IP IP4Mask +findSourceMask ip = do + state <- get + just (sourceMask ip (ip4Routes state)) + + +broadcastDestination :: IP4 -> IP () +broadcastDestination ip = do + mask <- findSourceMask ip + guard (isBroadcast mask ip) + +-- | Route a message to a local handler +routeLocal :: IP4Packet -> IP () +routeLocal pkt@(IP4Packet hdr _) = do + let dest = ip4DestAddr hdr + localAddress dest `mplus` broadcastDestination dest + h <- getHandler (ip4Protocol hdr) + mb <- handleFragments pkt + case mb of + Nothing -> return () + Just bs -> output (h (ip4SourceAddr hdr) (ip4DestAddr hdr) bs) + + +handleFragments :: IP4Packet -> IP (Maybe Packet) +handleFragments pkt = do + state <- get + now <- time + let (table',mb) = processIP4Packet now (ip4Fragments state) pkt + table' `seq` set state { ip4Fragments = table' } + return mb + +nextIdent :: IP Ident +nextIdent = do + state <- get + let i = ip4NextIdent state + set state { ip4NextIdent = i + 1 } + return i + +-- Message Handling ------------------------------------------------------------ + +-- | Incoming packet from the network +handleIncoming :: Packet -> IP () +handleIncoming bs = do + (hdr,hlen,plen) <- liftRight (runGet parseIP4Packet bs) + let (header,rest) = S.splitAt hlen bs + let payload = S.take plen rest + let checksum = computeChecksum 0 header + let pkt = IP4Packet hdr payload + guard $ and + [ S.length bs >= 20 + , hlen >= 20 + , checksum == 0 + , ip4Version hdr == 4 + ] + + -- forward? + routeLocal pkt `mplus` forward pkt + + +-- | Outgoing packet +handleOutgoing :: IP4Protocol -> IP4 -> Packet -> IP () +handleOutgoing prot dst bs = do + sendBytes prot dst bs + + +handleAddRule :: Rule IP4Mask IP4 -> IP () +handleAddRule rule = do + state <- get + let routes' = addRule rule (ip4Routes state) + routes' `seq` set state { ip4Routes = routes' } + + +handleSource :: IP4 -> (IP4 -> IO ()) -> IP () +handleSource dst k = do + (s,_,_) <- findRoute dst + output (k s) diff --git a/src/Hans/Layer/IP4/Fragmentation.hs b/src/Hans/Layer/IP4/Fragmentation.hs new file mode 100644 index 0000000..301542b --- /dev/null +++ b/src/Hans/Layer/IP4/Fragmentation.hs @@ -0,0 +1,111 @@ + +module Hans.Layer.IP4.Fragmentation where + +import Hans.Address +import Hans.Address.IP4 +import Hans.Message.Ip4 +import Hans.Utils + +import Data.Ord (comparing) +import Data.Time.Clock.POSIX (POSIXTime) +import qualified Data.ByteString as S +import qualified Data.Map as Map + + +type FragmentationTable addr = Map.Map (Ident,addr,addr) Fragments + +emptyFragmentationTable :: FragmentationTable IP4 +emptyFragmentationTable = Map.empty + + +data Fragments = Fragments + { startTime :: !POSIXTime + , totalSize :: !Int + , fragments :: [Fragment] + } deriving Show + +data Fragment = Fragment + { fragmentOffset :: !Int + , fragmentLength :: !Int + , fragmentPayload :: !Packet + } deriving (Eq,Show) + +instance Ord Fragment where + compare = comparing fragmentOffset + + +-- | The end of a fragment. +fragmentEnd :: Fragment -> Int +fragmentEnd f = fragmentOffset f + fragmentLength f + +-- | Check the ordering of two fragments. +comesBefore :: Fragment -> Fragment -> Bool +comesBefore f g = fragmentEnd f == fragmentOffset g + +-- | Check the ordering of two fragments. +comesAfter :: Fragment -> Fragment -> Bool +comesAfter = flip comesBefore + +-- | Merge two fragments. +-- +-- Note: This doesn't do a validity check to make sure that they're actually +-- adjacent. +combineFragments :: Fragment -> Fragment -> Fragment +combineFragments f g = Fragment (fragmentOffset f) len pay + where + len = fragmentLength f + fragmentLength g + pay = fragmentPayload f `S.append` fragmentPayload g + + +-- | Given a group of fragments, a new fragment, and a possible total size, +-- create a new group of fragments that incorporates the new fragment. +expandGroup :: Fragments -> Fragment -> Int -> Fragments +expandGroup fs newfrag x = case totalSize fs of + -1 | x >= 0 -> expandGroup fs{ totalSize = x } newfrag x + _ -> fs { fragments = addFragment newfrag (fragments fs) } + + +-- | Add a fragment to a list of fragments, in a position that is relative to +-- its offset and length. +addFragment :: Fragment -> [Fragment] -> [Fragment] +addFragment f fs = case fs of + [] -> [f] + g:rest | f `comesBefore` g -> addFragment (combineFragments f g) rest + | f `comesAfter` g -> addFragment (combineFragments g f) rest + | f < g -> f:fs + | otherwise -> g:(addFragment f rest) + + +-- | Process a packet fragment through the system, potentially returning a +-- fully-processed packet if this fragment completes an existing packet or +-- is itself a fully-complete packet. +processFragment :: Address addr + => POSIXTime -> FragmentationTable addr -> Bool -> Int + -> addr -> addr -> Ident -> Packet + -> (FragmentationTable addr, Maybe Packet) +processFragment _ table False 0 _ _ _ bs = (table, Just bs) +processFragment now table areMore off src dest ident bs = case group of + Fragments _ x [Fragment 0 y bs'] + | x == y -> (Map.delete entry table, Just bs') + _ -> (Map.insert entry group table, Nothing) + where + entry = (ident,src,dest) + group = case Map.lookup (ident,src,dest) table of + Nothing -> Fragments now newTotalLen [cur] + Just g -> expandGroup g cur newTotalLen + curlen = fromIntegral (S.length bs) + cur = Fragment off curlen bs + newTotalLen | areMore = -1 + | otherwise = off + curlen + + +processIP4Packet :: POSIXTime -> FragmentationTable IP4 -> IP4Packet + -> (FragmentationTable IP4, Maybe Packet) +processIP4Packet now table (IP4Packet hdr bs) = + processFragment now table areMore off src dest ident bs + where + off = fromIntegral (ip4FragmentOffset hdr) + ident = fromIntegral (ip4Ident hdr) + areMore = ip4MoreFragments hdr + src = ip4SourceAddr hdr + dest = ip4DestAddr hdr diff --git a/src/Hans/Layer/IP4/Routing.hs b/src/Hans/Layer/IP4/Routing.hs new file mode 100644 index 0000000..87dc9f2 --- /dev/null +++ b/src/Hans/Layer/IP4/Routing.hs @@ -0,0 +1,87 @@ +module Hans.Layer.IP4.Routing ( + -- * Routing Rules + Rule(..) + , Mtu + + -- * Routing Table + , RoutingTable + , emptyRoutingTable + , addRule + , route + , localAddrs + , sourceMask + ) where + +import Data.PrefixTree as PT +import Hans.Address.IP4 +import Hans.Address + +import Data.Maybe (mapMaybe) + + +-- Routing Rules --------------------------------------------------------------- + +type Mtu = Int + +data Rule mask addr + = Direct mask addr Mtu + | Indirect mask addr + deriving Show + + +-- Routing Table --------------------------------------------------------------- + +type RoutingTable addr = PrefixTree (Dest addr) + +data Dest addr + = NextHop addr + | Via addr Mtu + deriving Show + + +emptyRoutingTable :: Address addr => RoutingTable addr +emptyRoutingTable = PT.empty + + +{-# SPECIALIZE addRule :: Rule IP4Mask IP4 -> RoutingTable IP4 + -> RoutingTable IP4 #-} +-- | Add a rule to the routing table. +addRule :: Mask mask addr + => Rule mask addr -> RoutingTable addr -> RoutingTable addr +addRule rule table = case rule of + Direct mask addr mtu -> k mask (Via addr mtu) + Indirect mask addr -> k mask (NextHop addr) + where + k m d = insert ks d table + where + (addr,bits) = getMaskComponents m + ks = take bits (toBits addr) + + +{-# SPECIALIZE route :: IP4 -> RoutingTable IP4 -> Maybe (IP4,IP4,Mtu) #-} +-- | Discover the source and destination when trying to route an address. +route :: Address addr => addr -> RoutingTable addr -> Maybe (addr,addr,Mtu) +route addr t = do + r <- match (toBits addr) t + case r of + Via s mtu -> return (s,addr,mtu) + NextHop hop -> do + Via s mtu <- match (toBits hop) t + return (s,hop,mtu) + + +{-# SPECIALIZE sourceMask :: IP4 -> RoutingTable IP4 -> Maybe IP4Mask #-} +-- | Find the mask that would be used to route an address. +sourceMask :: Mask mask addr => addr -> RoutingTable addr -> Maybe mask +sourceMask addr table = do + src <- key (toBits addr) table + let bits = fromIntegral (addrSize addr * 8) - length src + return (addr `withMask` bits) + + +-- | Dump all local addresses. +localAddrs :: Address addr => RoutingTable addr -> [addr] +localAddrs table = mapMaybe p (PT.elems table) + where + p (Via s _) = Just s + p _ = Nothing diff --git a/src/Hans/Layer/Icmp4.hs b/src/Hans/Layer/Icmp4.hs new file mode 100644 index 0000000..14c05f4 --- /dev/null +++ b/src/Hans/Layer/Icmp4.hs @@ -0,0 +1,88 @@ +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE FlexibleInstances #-} + +module Hans.Layer.Icmp4 ( + Icmp4Handle + , runIcmp4Layer + , addIcmp4Handler + ) where + +import Hans.Address.IP4 +import Hans.Channel +import Hans.Layer +import Hans.Layer.IP4 +import Hans.Message.Icmp4 +import Hans.Message.Ip4 +import Hans.Utils + +import Control.Concurrent (forkIO) +import Data.Serialize (decode,encode) +import MonadLib (get,set) + +type Handler = Icmp4Packet -> IO () + +type Icmp4Handle = Channel (Icmp4 ()) + +icmpProtocol :: IP4Protocol +icmpProtocol = IP4Protocol 0x1 + +runIcmp4Layer :: Icmp4Handle -> IP4Handle -> IO () +runIcmp4Layer h ip4 = do + let handles = Icmp4Handles ip4 [] + addIP4Handler ip4 icmpProtocol + $ \src dst bs -> send h (handleIncoming src dst bs) + void (forkIO (loopLayer handles (receive h) id)) + +data Icmp4Handles = Icmp4Handles + { icmpIp4 :: IP4Handle + , icmpHandlers :: [Handler] + } + +type Icmp4 = Layer Icmp4Handles + +ip4Handle :: Icmp4 IP4Handle +ip4Handle = icmpIp4 `fmap` get + +sendPacket :: IP4 -> Icmp4Packet -> Icmp4 () +sendPacket dst pkt = do + ip4 <- ip4Handle + output $ sendIP4Packet ip4 icmpProtocol dst (encode pkt) + +-- | Add a handler for Icmp4 messages that match the provided predicate. +addIcmp4Handler :: Icmp4Handle -> Handler -> IO () +addIcmp4Handler h k = send h (handleAdd k) + +-- Message Handling ------------------------------------------------------------ + +-- | Handle incoming ICMP packets +handleIncoming :: IP4 -> IP4 -> Packet -> Icmp4 () +handleIncoming src _dst bs = do + pkt <- liftRight (decode bs) + matchHandlers pkt + case pkt of + -- XXX: Only echo-request is handled at the moment + Echo ident seqNum dat -> handleEchoRequest src ident seqNum dat + _ty -> do + --output (putStrLn ("Unhandled ICMP message type: " ++ show ty)) + dropPacket + + +-- | Add an icmp packet handler. +handleAdd :: Handler -> Icmp4 () +handleAdd k = do + s <- get + set s { icmpHandlers = k : icmpHandlers s } + + +-- | Respond to an echo request +handleEchoRequest :: IP4 -> Identifier -> SequenceNumber -> Packet -> Icmp4 () +handleEchoRequest src ident seqNum dat = do + sendPacket src (EchoReply ident seqNum dat) + + +-- | Output the IO actions for each handler that's registered. +matchHandlers :: Icmp4Packet -> Icmp4 () +matchHandlers pkt = do + s <- get + output (mapM_ ($ pkt) (icmpHandlers s)) diff --git a/src/Hans/Layer/Tcp.hs b/src/Hans/Layer/Tcp.hs new file mode 100644 index 0000000..43d1609 --- /dev/null +++ b/src/Hans/Layer/Tcp.hs @@ -0,0 +1,49 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DeriveDataTypeable #-} + +module Hans.Layer.Tcp ( + TcpHandle + , runTcpLayer + , queueTcp + + , module Exports + ) where + +import Hans.Address.IP4 +import Hans.Channel +import Hans.Layer +import Hans.Layer.IP4 +import Hans.Layer.Tcp.Monad (Tcp,TcpHandle,TcpState(..),emptyTcpState) +import Hans.Layer.Tcp.Socket as Exports +import Hans.Layer.Timer (TimerHandle,udelay) +import Hans.Message.Tcp (tcpProtocol) +import Hans.Layer.Tcp.Handlers (handleIncomingTcp,handleOutgoing) +import Hans.Utils (void) + +import Network.TCP.Type.Base (posixtime_to_time) +import Network.TCP.Type.Socket (update_host_time) + +import Control.Concurrent (forkIO) +import MonadLib (get,set) +import qualified Data.ByteString as S + + +runTcpLayer :: TcpHandle -> IP4Handle -> TimerHandle -> IO () +runTcpLayer tcp ip4 t = do + let s0 = emptyTcpState tcp ip4 t + void (forkIO (loopLayer s0 (receive tcp) updateTimeAndRun)) + addIP4Handler ip4 tcpProtocol (queueTcp tcp) + +-- | Queue a tcp packet. +queueTcp :: TcpHandle -> IP4 -> IP4 -> S.ByteString -> IO () +queueTcp tcp !src !dst !bs = send tcp (handleIncomingTcp src dst bs) + +-- | Pull the time out of the Layer monad, and convert it to a value that can be +-- used with the TCP layer. +updateTimeAndRun :: Tcp () -> Tcp () +updateTimeAndRun body = do + now <- time + s <- get + set $! s { tcpHost = update_host_time (posixtime_to_time now) (tcpHost s) } + body + handleOutgoing diff --git a/src/Hans/Layer/Tcp/Handlers.hs b/src/Hans/Layer/Tcp/Handlers.hs new file mode 100644 index 0000000..0f6d41e --- /dev/null +++ b/src/Hans/Layer/Tcp/Handlers.hs @@ -0,0 +1,74 @@ +module Hans.Layer.Tcp.Handlers ( + handleIncomingTcp + , handleOutgoing + ) where + +import Hans.Address.IP4 (IP4,convertFromWord32) +import Hans.Channel (send) +import Hans.Layer (output,liftRight) +import Hans.Layer.IP4 (sendIP4Packet,withIP4Source) +import Hans.Layer.Tcp.Monad + (Tcp,TcpState(..),ip4Handle,ip4Handle,ip4Handle,ip4Handle) +import Hans.Layer.Timer (udelay) +import Hans.Message.Tcp + (tcpProtocol,renderWithTcpChecksumIP4,TcpPacket(..),getTcpPacket + ,recreateTcpChecksumIP4,TcpHeader(..)) + +import Network.TCP.LTS.In (tcp_deliver_in_packet) +import Network.TCP.Type.Base (get_ip,bufferchain_collapse,IPAddr(..)) +import Network.TCP.Type.Datagram + (ICMPDatagram(..),UDPDatagram(..),TCPSegment(..),IPMessage(..) + ,mkTCPSegment) +import Network.TCP.Type.Socket (Host(..)) + +import Control.Monad (unless,guard) +import Data.Serialize (runGet) +import MonadLib (get,set) +import qualified Data.ByteString as S + + +-- | Handle a TCP message from the IP4 layer. +handleIncomingTcp :: IP4 -> IP4 -> S.ByteString -> Tcp () +handleIncomingTcp src dst bytes = do + let cs = recreateTcpChecksumIP4 src dst bytes + pkt@(TcpPacket hdr _body) <- liftRight (runGet getTcpPacket bytes) + guard (tcpChecksum hdr == cs) + tcp_deliver_in_packet (mkTCPSegment src dst pkt) + +-- | Force packets out of the pure layer. +handleOutgoing :: Tcp () +handleOutgoing = do + s <- get + let h = tcpHost s + set (s { tcpHost = h { output_queue = [], ready_list = [] } }) + let msgs = output_queue h + unless (null msgs) (mapM_ deliverIPMessage msgs) + let ready = ready_list h + unless (null ready) (mapM_ output ready) + +deliverIPMessage :: IPMessage -> Tcp () +deliverIPMessage msg = + case msg of + TCPMessage seg -> deliverTCPSegment seg + ICMPMessage icmp -> deliverICMPDatagram icmp + UDPMessage udp -> deliverUDPDatagram udp + +deliverTCPSegment :: TCPSegment -> Tcp () +deliverTCPSegment seg = do + let hdr = tcp_header seg + IPAddr dst = get_ip (tcp_dst seg) + dstAddr = convertFromWord32 dst + ip4 <- ip4Handle + output $ withIP4Source ip4 dstAddr $ \ srcAddr -> do + body <- bufferchain_collapse (tcp_data seg) + let pkt = renderWithTcpChecksumIP4 srcAddr dstAddr (TcpPacket hdr body) + sendIP4Packet ip4 tcpProtocol dstAddr pkt + +deliverICMPDatagram :: ICMPDatagram -> Tcp () +deliverICMPDatagram _icmp = do + output (putStrLn "Ignoring TCP icmp packet") + +deliverUDPDatagram :: UDPDatagram -> Tcp () +deliverUDPDatagram _udp = do + output (putStrLn "Ignoring TCP udp packet") + diff --git a/src/Hans/Layer/Tcp/Monad.hs b/src/Hans/Layer/Tcp/Monad.hs new file mode 100644 index 0000000..e56acd1 --- /dev/null +++ b/src/Hans/Layer/Tcp/Monad.hs @@ -0,0 +1,92 @@ +module Hans.Layer.Tcp.Monad where + +import Hans.Channel +import Hans.Layer +import Hans.Layer.IP4 +import Hans.Layer.Timer + +import MonadLib +import Network.TCP.Type.Datagram (IPMessage) +import Network.TCP.Type.Socket (Host(..),empty_host,TCPSocket) +import Network.TCP.Type.Syscall (SocketID) +import qualified Data.Map as Map + + +-- TCP Monad ------------------------------------------------------------------- + +type TcpHandle = Channel (Tcp ()) + +type Tcp = Layer (TcpState (IO ())) + +data TcpState t = TcpState + { tcpSelf :: TcpHandle + , tcpIP4 :: IP4Handle + , tcpTimers :: TimerHandle + , tcpHost :: Host t + } + +emptyTcpState :: TcpHandle -> IP4Handle -> TimerHandle -> TcpState t +emptyTcpState tcp ip4 timer = TcpState + { tcpSelf = tcp + , tcpIP4 = ip4 + , tcpTimers = timer + , tcpHost = empty_host + } + +-- | The handle to this layer. +self :: Tcp TcpHandle +self = tcpSelf `fmap` get + +-- | Get the handle to the IP4 layer. +ip4Handle :: Tcp IP4Handle +ip4Handle = tcpIP4 `fmap` get + +-- | Get the handle to the Timer layer. +timerHandle :: Tcp TimerHandle +timerHandle = tcpTimers `fmap` get + + +-- Compatibility Layer --------------------------------------------------------- + +type HMonad t = Layer (TcpState t) + +get_host :: HMonad t (Host t) +get_host = tcpHost `fmap` get + +put_host :: Host t -> HMonad t () +put_host h = do + s <- get + set $! s { tcpHost = h } + +modify_host :: (Host t -> Host t) -> HMonad t () +modify_host f = do + h <- get_host + put_host $! f h + +emit_segs :: [IPMessage] -> HMonad t () +emit_segs segs = modify_host (\h -> h { output_queue = output_queue h ++ segs }) + +emit_ready :: [t] -> HMonad t () +emit_ready ts = modify_host (\h -> h { ready_list = ready_list h ++ ts }) + +has_sock :: SocketID -> HMonad t Bool +has_sock sid = (Map.member sid . sock_map) `fmap` get_host + +lookup_sock :: SocketID -> HMonad t (TCPSocket t) +lookup_sock sid = do + h <- get_host + case Map.lookup sid (sock_map h) of + Nothing -> fail "lookup_sock: sid not found" + Just res -> return res + +delete_sock :: SocketID -> HMonad t () +delete_sock sid = + modify_host (\h -> h { sock_map = Map.delete sid (sock_map h) } ) + +update_sock :: SocketID -> (TCPSocket t -> TCPSocket t) -> HMonad t () +update_sock sid f = + modify_host (\h -> h { sock_map = Map.adjust f sid (sock_map h) }) + +insert_sock :: SocketID -> TCPSocket t -> HMonad t () +insert_sock sid sock = do + modify_host (\h -> h { sock_map = Map.insert sid sock (sock_map h) }) diff --git a/src/Hans/Layer/Tcp/Socket.hs b/src/Hans/Layer/Tcp/Socket.hs new file mode 100644 index 0000000..06fd12a --- /dev/null +++ b/src/Hans/Layer/Tcp/Socket.hs @@ -0,0 +1,189 @@ +{-# LANGUAGE DeriveDataTypeable #-} + +module Hans.Layer.Tcp.Socket ( + -- * Socket Layer + Socket() + , SocketError(..) + , listenPort + , acceptSocket + , connect + , sendSocket + , closeSocket + , readBytes + , readLine + ) where + +import Hans.Address.IP4 +import Hans.Channel +import Hans.Layer +import Hans.Layer.Tcp.Monad +import Hans.Message.Tcp (TcpPort(..)) + +import Network.TCP.LTS.User (tcp_process_user_request) +import Network.TCP.Type.Base + (IPAddr(..),SocketID,TCPAddr(..)) +import Network.TCP.Type.Syscall (SockReq(..),SockRsp(..)) + +import Control.Exception (throwIO,Exception) +import Control.Concurrent (MVar,newMVar,newEmptyMVar,takeMVar,putMVar) +import Data.Typeable (Typeable) +import qualified Data.ByteString as S +import qualified Data.ByteString.Lazy as L + +-- Socket Layer ---------------------------------------------------------------- + +data Socket = Socket + { socketTcpHandle :: TcpHandle + , socketId :: !SocketID + , socketBuffer :: MVar L.ByteString + } + +data SocketResult a + = SocketResult a + | SocketError SocketError + +data SocketError + = ListenError String + | AcceptError String + | ConnectError String + | SendError String + | RecvError String + | CloseError String + deriving (Typeable,Show) + +instance Exception SocketError + +-- | Block on a socket operation, waiting for the TCP layer to finish an action. +blockResult :: TcpHandle -> (MVar (SocketResult a) -> Tcp ()) -> IO a +blockResult tcp k = do + var <- newEmptyMVar + send tcp (k var) + sr <- takeMVar var + case sr of + SocketResult a -> return a + SocketError se -> throwIO se + +-- | Call @output@ if the @Tcp@ action returns a @Just@. +maybeOutput :: Tcp (Maybe (IO ())) -> Tcp () +maybeOutput body = do + mb <- body + case mb of + Just m -> output m + Nothing -> return () + +-- | Listen on a port. +listenPort :: TcpHandle -> TcpPort -> IO Socket +listenPort tcp (TcpPort port) = blockResult tcp $ \ res -> do + let mkError = SocketError . ListenError + k rsp = case rsp of + SockNew sid -> do + buf <- newMVar L.empty + putMVar res (SocketResult (Socket tcp sid buf)) + SockError err -> putMVar res (mkError err) + _ -> putMVar res (mkError "Unexpected response") + maybeOutput (tcp_process_user_request (SockListen port,k)) + +-- | Accept a client connection on a @Socket@. +acceptSocket :: Socket -> IO Socket +acceptSocket sock = blockResult (socketTcpHandle sock) $ \ res -> do + let mkError = SocketError . AcceptError + k rsp = case rsp of + SockNew sid -> do + buf <- newMVar L.empty + putMVar res (SocketResult (Socket (socketTcpHandle sock) sid buf)) + SockError err -> putMVar res (mkError err) + _ -> putMVar res (mkError "Unexpected response") + maybeOutput (tcp_process_user_request (SockAccept (socketId sock),k)) + +-- | Connect to a remote server. +connect :: TcpHandle -> IP4 -> IP4 -> TcpPort -> IO Socket +connect tcp src dst (TcpPort port) = blockResult tcp $ \ res -> do + let us = IPAddr (convertToWord32 src) + them = TCPAddr (IPAddr (convertToWord32 dst), port) + mkError = SocketError . ConnectError + k rsp = case rsp of + SockNew sid -> do + buf <- newMVar L.empty + putMVar res (SocketResult (Socket tcp sid buf)) + SockError err -> putMVar res (mkError err) + _ -> putMVar res (mkError "Unexpected response") + maybeOutput (tcp_process_user_request (SockConnect us them,k)) + +-- | Send on a @Socket@. +sendSocket :: Socket -> S.ByteString -> IO () +sendSocket sock bytes = blockResult (socketTcpHandle sock) $ \ res -> do + let mkError = SocketError . SendError + k rsp = putMVar res $! case rsp of + SockOK -> SocketResult () + SockError err -> mkError err + _ -> mkError "Unexpected response" + maybeOutput (tcp_process_user_request (SockSend (socketId sock) bytes,k)) + +-- | Receive from a @Socket@. +recvSocket :: Socket -> IO S.ByteString +recvSocket sock = blockResult (socketTcpHandle sock) $ \ res -> do + let mkError = SocketError . RecvError + k rsp = putMVar res $! case rsp of + SockData bs -> SocketResult bs + SockError err -> mkError err + _ -> mkError "Unexpected response" + maybeOutput (tcp_process_user_request (SockRecv (socketId sock),k)) + +-- | Close a socket. +closeSocket :: Socket -> IO () +closeSocket sock = + blockResult (socketTcpHandle sock) $ \ res -> do + let mkError = SocketError . CloseError + k rsp = putMVar res $! case rsp of + SockOK -> SocketResult () + SockError err -> mkError err + _ -> mkError "Unexpected response" + maybeOutput (tcp_process_user_request (SockClose (socketId sock),k)) + + +-- Derived Interaction --------------------------------------------------------- + +-- | Read n bytes from a @Socket@. +readBytes :: Socket -> Int -> IO S.ByteString +readBytes sock goal = do + buf <- takeMVar (socketBuffer sock) + loop buf (fromIntegral (L.length buf)) + where + loop buf len + | goal <= len = finish buf + | otherwise = do + bytes <- recvSocket sock + if S.null bytes + then finish buf + else loop (buf `L.append` L.fromChunks [bytes]) (len + S.length bytes) + + finish buf = do + let (as,bs) = L.splitAt (fromIntegral goal) buf + putMVar (socketBuffer sock) bs + return (S.concat (L.toChunks as)) + +-- | Read until a CRLF, LF or CR are read. +readLine :: Socket -> IO S.ByteString +readLine sock = do + buf <- takeMVar (socketBuffer sock) + loop False 0 buf + where + loop cr ix buf + | L.null buf = fillBuffer cr ix buf + | otherwise = + case L.index buf ix of + 0x0d -> loop True (ix+1) buf + 0x0a -> finish (ix+1) buf + _ | cr -> finish ix buf + | otherwise -> loop False (ix+1) buf + + fillBuffer cr ix buf = do + bytes <- recvSocket sock + if S.null bytes + then finish ix buf + else loop cr ix (buf `L.append` L.fromChunks [bytes]) + + finish ix buf = do + let (as,bs) = L.splitAt ix buf + putMVar (socketBuffer sock) bs + return (S.concat (L.toChunks as)) diff --git a/src/Hans/Layer/Timer.hs b/src/Hans/Layer/Timer.hs new file mode 100644 index 0000000..e4a105c --- /dev/null +++ b/src/Hans/Layer/Timer.hs @@ -0,0 +1,131 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# OPTIONS -fno-warn-orphans #-} + +module Hans.Layer.Timer ( + TimerHandle + , runTimerLayer + + , Milliseconds + , delay + , Microseconds + , udelay + ) where + +import Hans.Channel + +import Control.Concurrent (forkIO,threadDelay,MVar,newMVar,takeMVar,putMVar) +import Data.FingerTree as FT +import Data.Monoid (Monoid(..)) +import Data.Time.Clock.POSIX (POSIXTime,getPOSIXTime) +import MonadLib + + +-- The Timers Structure -------------------------------------------------------- + +data Timer a = Timer + { timerAt :: POSIXTime + , timerValue :: a + } + +instance Monoid POSIXTime where + mempty = 0 + mappend a b | a > b = a + | otherwise = b + +instance Measured POSIXTime (Timer a) where + measure t = timerAt t + +type Timers a = FingerTree POSIXTime (Timer a) + + +-- | Add an action to happen sometime in the future. +at :: POSIXTime -> a -> Timers a -> Timers a +at fut a ts = as' >< bs + where + (as,bs) = runnable fut ts + as' = as |> Timer fut a + + +-- | Partition the actions into runnable and deferred. +runnable :: POSIXTime -> Timers a -> (Timers a, Timers a) +runnable now ts = FT.split (> now) ts + + +-- | Step through the timers in a group. +stepTimers :: Timers a -> Maybe (a, Timers a) +stepTimers ts = case viewl ts of + EmptyL -> Nothing + r :< rs -> Just (timerValue r,rs) + + +-- | Run all timers, when the action is an IO action. +runTimers :: Timers (IO a) -> IO (Timers (IO a)) +runTimers ts = do + now <- getPOSIXTime + let loop (Just (a,as)) = a >> loop (stepTimers as) + loop Nothing = return () + (rs,rest) = runnable now ts + loop (stepTimers rs) + return rest + + +-- The External Message Interface ---------------------------------------------- + +type ActionTimers = Timers (IO ()) + +-- | Mesages handled by the timer thread. +data TimerMessage = AddTimer POSIXTime (IO ()) + +-- | A channel to the timer thread. +type TimerHandle = Channel TimerMessage + + +-- | Start the timer thread. +runTimerLayer :: TimerHandle -> IO () +runTimerLayer h = do + timers <- newMVar FT.empty + _ <- forkIO (actionHandler timers) + _ <- forkIO (messageHandler h timers) + return () + + +type Milliseconds = Int + +-- | Add a message to happen after some number of milliseconds. +delay :: TimerHandle -> Milliseconds -> IO () -> IO () +delay h !off = udelay h (off * 1000) + +type Microseconds = Int + +-- | Add a message to happen after some number of microseconds. +udelay :: TimerHandle -> Microseconds -> IO () -> IO () +udelay h !micros k = do + now <- getPOSIXTime + -- the granularity for a NominalTimeDiff is 10^-12 + let off = fromIntegral micros / 1000000 + send h (AddTimer (now + off) k) + + +-- Internal Loops -------------------------------------------------------------- + +-- | The delay granularity of the timer layer. +timerStep :: Microseconds +timerStep = 100 + + +-- | Loop, running available timer actions. +actionHandler :: MVar ActionTimers -> IO () +actionHandler timers = forever $ do + threadDelay timerStep + putMVar timers =<< runTimers =<< takeMVar timers + + +-- | Loop, processing timer add requests. +messageHandler :: TimerHandle -> MVar ActionTimers -> IO () +messageHandler h timers = forever $ do + AddTimer t k <- receive h + ts <- takeMVar timers + putMVar timers $! at t k ts diff --git a/src/Hans/Layer/Udp.hs b/src/Hans/Layer/Udp.hs new file mode 100644 index 0000000..086beb4 --- /dev/null +++ b/src/Hans/Layer/Udp.hs @@ -0,0 +1,127 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE FlexibleInstances #-} + +module Hans.Layer.Udp ( + UdpHandle + , runUdpLayer + + , queueUdp + , sendUdp + , addUdpHandler + , removeUdpHandler + ) where + +import Hans.Address.IP4 +import Hans.Channel +import Hans.Layer +import Hans.Layer.IP4 +import Hans.Layer.Icmp4 +import Hans.Message.Ip4 +import Hans.Message.Udp +import Hans.Ports +import Hans.Utils + +import Control.Concurrent (forkIO) +import Data.Serialize.Get (runGet) +import MonadLib (get,set) + + +type Handler = IP4 -> UdpPort -> Packet -> IO () + +type UdpHandle = Channel (Udp ()) + +udpProtocol :: IP4Protocol +udpProtocol = IP4Protocol 0x11 + +runUdpLayer :: UdpHandle -> IP4Handle -> Icmp4Handle -> IO () +runUdpLayer h ip4 icmp4 = do + addIP4Handler ip4 udpProtocol (queueUdp h) + void (forkIO (loopLayer (emptyUdp4State ip4 icmp4) (receive h) id)) + +sendUdp :: UdpHandle -> IP4 -> Maybe UdpPort -> UdpPort -> Packet -> IO () +sendUdp h !dst mb !dp !bs = send h (handleOutgoing dst mb dp bs) + +queueUdp :: UdpHandle -> IP4 -> IP4 -> Packet -> IO () +queueUdp h !src !dst !bs = send h (handleIncoming src dst bs) + +addUdpHandler :: UdpHandle -> UdpPort -> Handler -> IO () +addUdpHandler h !sp k = send h (handleAddHandler sp k) + +removeUdpHandler :: UdpHandle -> UdpPort -> IO () +removeUdpHandler h !sp = send h (handleRemoveHandler sp) + + +-- Udp State ------------------------------------------------------------------- + +type Udp = Layer UdpState + +data UdpState = UdpState + { udpPorts :: PortManager UdpPort + , udpHandlers :: Handlers UdpPort Handler + , udpIp4Handle :: IP4Handle + , udpIcmp4Handle :: Icmp4Handle + } + +emptyUdp4State :: IP4Handle -> Icmp4Handle -> UdpState +emptyUdp4State ip4 icmp4 = UdpState + { udpPorts = emptyPortManager [maxBound, maxBound - 1 .. 1 ] + , udpHandlers = emptyHandlers + , udpIp4Handle = ip4 + , udpIcmp4Handle = icmp4 + } + +instance ProvidesHandlers UdpState UdpPort Handler where + getHandlers = udpHandlers + setHandlers hs s = s { udpHandlers = hs } + + +-- Utilities ------------------------------------------------------------------- + +ip4Handle :: Udp IP4Handle +ip4Handle = udpIp4Handle `fmap` get + +--icmp4Handle :: Udp Icmp4Handle +--icmp4Handle = udpIcmp4Handle `fmap` get + +maybePort :: Maybe UdpPort -> Udp UdpPort +maybePort (Just p) = return p +maybePort Nothing = do + state <- get + (p,pm') <- nextPort (udpPorts state) + pm' `seq` set state { udpPorts = pm' } + return p + +-- Message Handling ------------------------------------------------------------ + +handleAddHandler :: UdpPort -> Handler -> Udp () +handleAddHandler sp k = do + state <- get + pm' <- reserve sp (udpPorts state) + pm' `seq` set state { udpPorts = pm' } + addHandler sp k + +handleRemoveHandler :: UdpPort -> Udp () +handleRemoveHandler sp = do + state <- get + pm' <- unreserve sp (udpPorts state) + pm' `seq` set state { udpPorts = pm' } + removeHandler sp + + +handleIncoming :: IP4 -> IP4 -> Packet -> Udp () +handleIncoming src _dst bs = do + UdpPacket hdr pkt <- liftRight (runGet parseUdpPacket bs) + h <- getHandler (udpDestPort hdr) + output (h src (udpSourcePort hdr) pkt) + + +handleOutgoing :: IP4 -> Maybe UdpPort -> UdpPort -> Packet -> Udp () +handleOutgoing dst mb dp bs = do + sp <- maybePort mb + ip4 <- ip4Handle + let udp = UdpPacket (UdpHeader sp dp 0) bs + output $ withIP4Source ip4 dst $ \ src -> do + pkt <- renderUdpPacket udp (mkIP4PseudoHeader src dst udpProtocol) + sendIP4Packet ip4 udpProtocol dst pkt diff --git a/src/Hans/LayerPrime.hs b/src/Hans/LayerPrime.hs new file mode 100644 index 0000000..51e9f23 --- /dev/null +++ b/src/Hans/LayerPrime.hs @@ -0,0 +1,191 @@ +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FlexibleContexts #-} + +module Hans.LayerPrime where + +import Hans.Device (Add(..)) +import Hans.Message (Message(sendTo)) + +import Control.Applicative (Applicative(..),Alternative(..)) +import Control.Monad (MonadPlus(..)) +import Data.Monoid (Monoid(..)) +import Data.Time.Clock.POSIX +import MonadLib +import qualified Data.Map as Map + + +-- Continuation Monad ---------------------------------------------------------- + +newtype Cont m r a = Cont { getCont :: (a -> m r) -> m r } + +{-# INLINE runCont #-} +runCont :: (a -> m r) -> Cont m r a -> m r +runCont k m = getCont m k + +instance Functor (Cont m r) where + fmap f m = Cont (\k -> runCont (k . f) m) + +instance Applicative (Cont m r) where + pure a = Cont (\k -> k a) + f <*> x = Cont (\k -> getCont f (\g -> getCont x (\y -> k (g y)))) + +instance Monad (Cont m r) where + {-# INLINE return #-} + return a = pure a + m >>= f = Cont (\k -> runCont (runCont k . f) m) + +instance StateM m i => StateM (Cont m r) i where + get = liftC get + set i = liftC (set i) + +instance MonadPlus m => MonadPlus (Cont m r) where + mzero = liftC mzero + mplus a b = Cont (\k -> getCont a k `mplus` getCont b k) + +liftC :: Monad m => m a -> Cont m r a +liftC m = Cont (=<< m) + + +-- Network Stack Layer Monad --------------------------------------------------- + +data LayerState i = LayerState + { lsNow :: POSIXTime + , lsState :: i + } + +data Action = Nop | Action (IO ()) + +instance Monoid Action where + {-# INLINE mempty #-} + mempty = Nop + + mappend (Action a) (Action b) = Action (a >> b) + mappend Nop b = b + mappend a _ = a + +{-# INLINE runAction #-} +runAction :: Action -> IO () +runAction Nop = return () +runAction (Action m) = m + +type Layer i = Cont (LayerM i) () + +data Result i a + = Error Action + | Result (LayerState i) a Action + +newtype LayerM i a = LayerM { getLayerM :: LayerState i -> Result i a } + +{-# SPECIALIZE runLayerM :: LayerState i -> LayerM i () -> Result i () #-} +runLayerM :: LayerState i -> LayerM i a -> Result i a +runLayerM i m = getLayerM m i + +{-# INLINE loopLayer #-} +loopLayer :: i -> IO msg -> (msg -> Layer i ()) -> IO () +loopLayer i0 msg k = loop (LayerState 0 i0) + where + loop i = do + a <- msg + now <- getPOSIXTime + case runLayerM (i {lsNow = now }) (runCont return (k a)) of + Error a -> runAction a >> loop i + Result i' () a -> runAction a >> loop i' + +instance Functor (LayerM i) where + fmap f m = LayerM $ \i0 -> + case runLayerM i0 m of + Error k -> Error k + Result i a k -> Result i (f a) k + +instance Applicative (LayerM i) where + pure x = LayerM $ \i0 -> Result i0 x mempty + + f <*> x = LayerM $ \i0 -> case runLayerM i0 f of + Error k -> Error k + Result i g k -> case runLayerM i x of + Error l -> Error (k `mappend` l) + Result i y l -> Result i (g y) (k `mappend` l) + +instance Alternative (LayerM i) where + {-# INLINE empty #-} + empty = LayerM $ \ _ -> Error mempty + + {-# INLINE (<|>) #-} + a <|> b = LayerM $ \i0 -> case runLayerM i0 a of + Error _ -> runLayerM i0 b + Result i a k -> Result i a k + +instance Monad (LayerM i) where + {-# INLINE return #-} + return x = pure x + + m >>= f = LayerM $ \i0 -> case runLayerM i0 m of + Error k -> Error k + Result i a k -> case runLayerM i (f a) of + Error l -> Error (k `mappend` l) + Result i b l -> Result i b (k `mappend` l) + +instance MonadPlus (LayerM i) where + {-# INLINE mzero #-} + mzero = empty + + {-# INLINE mplus #-} + mplus = (<|>) + +instance StateM (LayerM i) i where + get = LayerM $ \i0 -> Result i0 (lsState i0) mempty + set i = LayerM $ \i0 -> Result (i0 { lsState = i }) () mempty + + +-- Utilities ------------------------------------------------------------------- + +{-# INLINE dropPacket #-} +dropPacket :: Layer i a +dropPacket = liftC empty + +time :: Layer i POSIXTime +time = liftC $ LayerM $ \i0 -> Result i0 (lsNow i0) mempty + +output :: IO () -> Layer i () +output m = liftC $ LayerM $ \i0 -> Result i0 () (Action m) + +-- Handler Generalization ------------------------------------------------------ + +type Handlers k a = Map.Map k a + +emptyHandlers :: Handlers k a +emptyHandlers = Map.empty + + +class ProvidesHandlers i k a | i -> k a where + getHandlers :: i -> Handlers k a + setHandlers :: Handlers k a -> i -> i + + +setHandler :: Message h (Add (k,a)) => h -> k -> a -> IO () +setHandler h k a = sendTo h (Add (k,a)) + + +getHandler :: (Ord k, ProvidesHandlers i k a) => k -> Layer i a +getHandler k = liftC $ do + state <- get + case Map.lookup k (getHandlers state) of + Nothing -> mzero + Just h -> return h + + +addHandler :: (Ord k, ProvidesHandlers i k a) => k -> a -> Layer i () +addHandler k a = liftC $ do + state <- get + let hs' = Map.insert k a (getHandlers state) + hs' `seq` set (setHandlers hs' state) + + +removeHandler :: (Ord k, ProvidesHandlers i k a) => k -> Layer i () +removeHandler k = liftC $ do + state <- get + let hs' = Map.delete k (getHandlers state) + hs' `seq` set (setHandlers hs' state) diff --git a/src/Hans/Message/Arp.hs b/src/Hans/Message/Arp.hs new file mode 100644 index 0000000..f8faf6a --- /dev/null +++ b/src/Hans/Message/Arp.hs @@ -0,0 +1,74 @@ +module Hans.Message.Arp where + +import Hans.Address + +import Data.Serialize (Serialize(..)) +import Data.Serialize.Get (getWord8,getWord16be) +import Data.Serialize.Put (putWord16be,putWord8) +import Data.Word (Word16) + + +-- Arp Packets ----------------------------------------------------------------- + +data ArpPacket hw p = ArpPacket + { arpHwType :: !Word16 + , arpPType :: !Word16 + , arpOper :: ArpOper + , arpSHA :: hw + , arpSPA :: p + , arpTHA :: hw + , arpTPA :: p + } + +-- | Decode an Arp message. +instance (Address hw, Address p) => Serialize (ArpPacket hw p) where + get = do + hty <- getWord16be + pty <- getWord16be + _ <- getWord8 + _ <- getWord8 + oper <- get + sha <- get + spa <- get + tha <- get + tpa <- get + return $! ArpPacket + { arpHwType = hty + , arpPType = pty + , arpOper = oper + , arpSHA = sha + , arpSPA = spa + , arpTHA = tha + , arpTPA = tpa + } + + put msg = do + putWord16be (arpHwType msg) + putWord16be (arpPType msg) + putWord8 (addrSize (arpSHA msg)) + putWord8 (addrSize (arpSPA msg)) + put (arpOper msg) + put (arpSHA msg) + put (arpSPA msg) + put (arpTHA msg) + put (arpTPA msg) + + +-- Arp Opcodes ----------------------------------------------------------------- + +data ArpOper + = ArpRequest -- ^ 0x1 + | ArpReply -- ^ 0x2 + deriving (Eq) + + +instance Serialize ArpOper where + get = do + b <- getWord16be + case b of + 0x1 -> return ArpRequest + 0x2 -> return ArpReply + _ -> fail "invalid Arp opcode" + + put ArpRequest = putWord16be 0x1 + put ArpReply = putWord16be 0x2 diff --git a/src/Hans/Message/Dhcp4.hs b/src/Hans/Message/Dhcp4.hs new file mode 100644 index 0000000..4ca95fc --- /dev/null +++ b/src/Hans/Message/Dhcp4.hs @@ -0,0 +1,568 @@ +{- | The 'Hans.Message.Dhcp4' module defines the various messages and + transitions used in the DHCPv4 protocol. This module provides both + a high-level view of the message types as well as a low-level + intermediate form which is closely tied to the binary format. + + References: + RFC 2131 - Dynamic Host Configuration Protocol + http://www.faqs.org/rfcs/rfc2131.html +-} +module Hans.Message.Dhcp4 + ( + -- ** High-level client types + RequestMessage(..) + , Request(..) + , Discover(..) + + -- ** High-level server types + , ServerSettings(..) + , ReplyMessage(..) + , Ack(..) + , Offer(..) + + -- ** Low-level message types + , Dhcp4Message(..) + , Xid(..) + + -- ** Server message transition logic + , requestToAck + , discoverToOffer + + -- ** Client message transition logic + , mkDiscover + , offerToRequest + + -- ** Convert high-level message types to low-level format + , requestToMessage + , ackToMessage + , offerToMessage + , discoverToMessage + + -- ** Convert low-level message type to high-level format + , parseDhcpMessage + + -- ** Convert low-level message type to binary format + , getDhcp4Message + , putDhcp4Message + + ) where + +import Hans.Address.IP4 (IP4(..)) +import Hans.Address.Mac (Mac) +import Hans.Message.Dhcp4Codec +import Hans.Message.Dhcp4Options + +import Control.Applicative ((<*), (<$>)) +import Control.Monad (unless) +import Data.Bits (testBit,bit) +import Data.Maybe (mapMaybe) +import Data.Serialize.Get (Get, getByteString, isolate, remaining, label, skip) +import Data.Serialize.Put (Put, putByteString) +import Data.Word (Word8,Word16,Word32) +import Numeric (showHex) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Char8 as BS8 + +-- DHCP Static Server Settings --------------------------------------------- + +-- |'ServerSettings' define all of the information that would be needed to +-- act as a DHCP server for one client. The server is defined to be able to +-- issue a single "lease" whose parameters are defined below. +data ServerSettings = Settings + { staticServerAddr :: IP4 -- ^ The IPv4 address of the DHCP server + , staticTimeOffset :: Word32 -- ^ Lease: timezone offset in seconds from UTC + , staticClientAddr :: IP4 -- ^ Lease: client IPv4 address on network + , staticLeaseTime :: Word32 -- ^ Lease: duration in seconds + , staticSubnet :: SubnetMask -- ^ Lease: subnet mask on network + , staticBroadcast :: IP4 -- ^ Lease: broadcast address on network + , staticRouters :: [IP4] -- ^ Lease: gateway routers on network + , staticDomainName :: String -- ^ Lease: client's assigned domain name + , staticDNS :: [IP4] -- ^ Lease: network DNS servers + } + deriving (Show) + +-- Structured DHCP Messages ------------------------------------------------ + +-- |'RequestMessage' is a sum of the client request messages. +data RequestMessage = RequestMessage Request + | DiscoverMessage Discover + deriving (Show) + +-- |'ReplyMessage' is a sum of the server response messages. +data ReplyMessage = AckMessage Ack + | OfferMessage Offer + deriving (Show) + +-- |'Request' is used by the client to accept an offered lease. +data Request = Request + { requestXid :: Xid -- ^ Transaction ID of offer + , requestBroadcast :: Bool -- ^ Set 'True' to instruct server to send to broadcast hardware address + , requestClientHardwareAddress :: Mac -- ^ Hardware address of the client + , requestParameters :: [Dhcp4OptionTag] -- ^ Used to specify the information that client needs + , requestAddress :: Maybe IP4 -- ^ Used to specify the address which was accepted + } + deriving (Show) + +-- |'Discover' is used by the client to discover what servers are available. +-- This message is sent to the IPv4 broadcast. +data Discover = Discover + { discoverXid :: Xid -- ^ Transaction ID of this and subsequent messages + , discoverBroadcast :: Bool -- ^ Set 'True' to instruct the server to send to broadcast hardware address + , discoverClientHardwareAddr :: Mac -- ^ Hardware address of the client + , discoverParameters :: [Dhcp4OptionTag]-- ^ Used to specify the information that client needs in the offers + } + deriving (Show) + +-- |'Ack' is sent by the DHCPv4 server to acknowledge a sucessful 'Request' +-- message. Upon receiving this message the client has completed the +-- exchange and has successfully obtained a lease. +data Ack = Ack + { ackHops :: Word8 -- ^ The maximum number of relays this message can use. + , ackXid :: Xid -- ^ Transaction ID for this exchange + , ackYourAddr :: IP4 -- ^ Lease: assigned client address + , ackServerAddr :: IP4 -- ^ DHCP server's IPv4 address + , ackRelayAddr :: IP4 -- ^ DHCP relay server's address + , ackClientHardwareAddr :: Mac -- ^ Client's hardware address + , ackLeaseTime :: Word32 -- ^ Lease: duration of lease in seconds + , ackOptions :: [Dhcp4Option] -- ^ Subset of information requested in previous 'Request' + } + deriving (Show) + +-- |'Offer' is sent by the DHCPv4 server in response to a 'Discover'. +-- This offer is only valid for a short period of time as the client +-- might receive many offers. The client must next request a lease +-- from a specific server using the information in that server's offer. +data Offer = Offer + { offerHops :: Word8 -- ^ The maximum number of relays this message can use. + , offerXid :: Xid -- ^ Transaction ID of this exchange + , offerYourAddr :: IP4 -- ^ The IPv4 address that this server is willing to lease + , offerServerAddr :: IP4 -- ^ The IPv4 address of the DHCPv4 server + , offerRelayAddr :: IP4 -- ^ The IPv4 address of the DHCPv4 relay server + , offerClientHardwareAddr :: Mac -- ^ The hardware address of the client + , offerOptions :: [Dhcp4Option] -- ^ The options that this server would include in a lease + } + deriving (Show) + +-- |'requestToAck' creates 'Ack' messages suitable for responding to 'Request' +-- messages given a static 'ServerSettings' configuration. +requestToAck :: ServerSettings -- ^ DHCPv4 server settings + -> Request -- ^ Client's request message + -> Ack +requestToAck settings request = Ack + { ackHops = 1 + , ackXid = requestXid request + , ackYourAddr = staticClientAddr settings + , ackServerAddr = staticServerAddr settings + , ackRelayAddr = staticServerAddr settings + , ackClientHardwareAddr = requestClientHardwareAddress request + , ackLeaseTime = staticLeaseTime settings + , ackOptions = mapMaybe lookupOption (requestParameters request) + } + where + lookupOption tag = case tag of + OptTagSubnetMask + -> Just (OptSubnetMask (staticSubnet settings)) + OptTagBroadcastAddress + -> Just (OptBroadcastAddress (staticBroadcast settings)) + OptTagTimeOffset + -> Just (OptTimeOffset (staticTimeOffset settings)) + OptTagRouters + -> Just (OptRouters (staticRouters settings)) + OptTagDomainName + -> Just (OptDomainName (NVTAsciiString (staticDomainName settings))) + OptTagNameServers + -> Just (OptNameServers (staticDNS settings)) + _ -> Nothing + +-- |'discoverToOffer' creates a suitable 'Offer' in response to a client's +-- 'Discover' message using the configuration settings specified in the +-- given 'ServerSettings'. +discoverToOffer :: ServerSettings -- ^ DHCPv4 server settings + -> Discover -- ^ Client's discover message + -> Offer +discoverToOffer settings discover = Offer + { offerHops = 1 + , offerXid = discoverXid discover + , offerYourAddr = staticClientAddr settings + , offerServerAddr = staticServerAddr settings + , offerRelayAddr = staticServerAddr settings + , offerClientHardwareAddr = discoverClientHardwareAddr discover + , offerOptions = mapMaybe lookupOption (discoverParameters discover) + } + where + lookupOption tag = case tag of + OptTagSubnetMask + -> Just (OptSubnetMask (staticSubnet settings)) + OptTagBroadcastAddress + -> Just (OptBroadcastAddress (staticBroadcast settings)) + OptTagTimeOffset + -> Just (OptTimeOffset (staticTimeOffset settings)) + OptTagRouters + -> Just (OptRouters (staticRouters settings)) + OptTagDomainName + -> Just (OptDomainName (NVTAsciiString (staticDomainName settings))) + OptTagNameServers + -> Just (OptNameServers (staticDNS settings)) + _ -> Nothing + +-- Unstructured DHCP Message ----------------------------------------------- + +-- |'Dhcp4Message' is a low-level message container that is very close to +-- the binary representation of DHCPv4 message. It is suitable for containing +-- any DHCPv4 message. Values of this type should only be created using the +-- publicly exported functions. +data Dhcp4Message = Dhcp4Message + { dhcp4Op :: Dhcp4Op -- ^ Message op code / message type. 1 = BOOTREQUEST, 2 = BOOTREPLY + , dhcp4Hops :: Word8 -- ^ Client sets to zero, optionally used by relay agents when booting via a relay agent. + , dhcp4Xid :: Xid -- ^ Transaction ID, a random number chosen by the client, used by the client and server to associate messages and responses between a client and a server. + , dhcp4Secs :: Word16 -- ^ Filled in by client, seconds elapsed since client began address acquisition or renewal process. + , dhcp4Broadcast :: Bool -- ^ Client requests messages be sent to hardware broadcast address + , dhcp4ClientAddr :: IP4 -- ^ Client IP address; only filled in if client is in BOUND, RENEW or REBINDING state and can respond to ARP requests. + , dhcp4YourAddr :: IP4 -- ^ 'your' (client) address + , dhcp4ServerAddr :: IP4 -- ^ IP address of next server to use in bootstrap; returned in DHCPOFFER, DHCPACK by server + , dhcp4RelayAddr :: IP4 -- ^ Relay agent IP address, used in booting via a relay agent + , dhcp4ClientHardwareAddr :: Mac -- ^ Client hardware address + , dhcp4ServerHostname :: String -- ^ Optional server host name, null terminated string + , dhcp4BootFilename :: String -- ^ Boot file name, full terminated string; "generic" name of null in DHCPDISCOVER, fully qualified directory-path name in DHCPOFFER + , dhcp4Options :: [Dhcp4Option] -- ^ Optional parameters field. + } deriving (Eq,Show) + +-- |'getDhcp4Message' is the binary decoder for parsing 'Dhcp4Message' values. +getDhcp4Message :: Get Dhcp4Message +getDhcp4Message = do + op <- getAtom + hwtype <- getAtom + len <- getAtom + unless (len == hardwareTypeAddressLength hwtype) + (fail "Hardware address length does not match hardware type.") + hops <- label "hops" getAtom + xid <- label "xid" getAtom + secs <- label "secs" getAtom + flags <- label "flags" getAtom + ciaddr <- label "ciaddr" getAtom + yiaddr <- label "yiaddr" getAtom + siaddr <- label "siaddr" getAtom + giaddr <- label "giaddr" getAtom + chaddr <- label "chaddr" $ isolate 16 $ getAtom <* (skip =<< remaining) + snameBytes <- label "sname field" (getByteString 64) + fileBytes <- label "file field" (getByteString 128) + (sname, file, opts) <- getDhcp4Options snameBytes fileBytes + return $! Dhcp4Message + { dhcp4Op = op + , dhcp4Hops = hops + , dhcp4Xid = xid + , dhcp4Secs = secs + , dhcp4Broadcast = broadcastFlag flags + , dhcp4ClientAddr = ciaddr + , dhcp4YourAddr = yiaddr + , dhcp4ServerAddr = siaddr + , dhcp4RelayAddr = giaddr + , dhcp4ClientHardwareAddr = chaddr + , dhcp4ServerHostname = sname + , dhcp4BootFilename = file + , dhcp4Options = opts + } + +-- |'getDhcp4Message' is the binary encoder for rendering 'Dhcp4Message' values. +putDhcp4Message :: Dhcp4Message -> Put +putDhcp4Message dhcp = do + putAtom (dhcp4Op dhcp) + let hwType = Ethernet + putAtom hwType + putAtom (hardwareTypeAddressLength hwType) + putAtom (dhcp4Hops dhcp) + putAtom (dhcp4Xid dhcp) + putAtom (dhcp4Secs dhcp) + putAtom Flags { broadcastFlag = dhcp4Broadcast dhcp } + putAtom (dhcp4ClientAddr dhcp) + putAtom (dhcp4YourAddr dhcp) + putAtom (dhcp4ServerAddr dhcp) + putAtom (dhcp4RelayAddr dhcp) + putAtom (dhcp4ClientHardwareAddr dhcp) + putByteString $ BS.replicate (16 {- chaddr field length -} + - fromIntegral (hardwareTypeAddressLength hwType)) 0 + putPaddedByteString 64 (BS8.pack (dhcp4ServerHostname dhcp)) + putPaddedByteString 128 (BS8.pack (dhcp4BootFilename dhcp)) + putDhcp4Options (dhcp4Options dhcp) + +-- Transaction ID -------------------------------------------------------------- + +-- |'Xid' is a Transaction ID, a random number chosen by the client, +-- used by the client and server to associate messages and responses between a +-- client and a server. +newtype Xid = Xid Word32 + deriving (Eq, Show) + +instance CodecAtom Xid where + getAtom = Xid <$> getAtom + putAtom (Xid xid) = putAtom xid + atomSize _ = atomSize (0 :: Word32) + +-- Opcodes --------------------------------------------------------------------- + +data Dhcp4Op + = BootRequest + | BootReply + deriving (Eq,Show) + +instance CodecAtom Dhcp4Op where + getAtom = do + b <- getAtom + case b :: Word8 of + 1 -> return BootRequest + 2 -> return BootReply + _ -> fail ("Unknown DHCP op 0x" ++ showHex b "") + + putAtom BootRequest = putAtom (0x1 :: Word8) + putAtom BootReply = putAtom (0x2 :: Word8) + + atomSize _ = atomSize (0 :: Word8) + +-- | HardwareType is an enumeration of the supported hardware types as assigned +-- in the ARP RFC http://www.iana.org/assignments/arp-parameters/ +data HardwareType + = Ethernet + deriving (Eq, Show) + +instance CodecAtom HardwareType where + getAtom = getAtom >>= \ b -> case b :: Word8 of + 1 -> return Ethernet + _ -> fail ("Unsupported hardware type 0x" ++ showHex b "") + + putAtom Ethernet = putAtom (1 :: Word8) + + atomSize _ = atomSize (1 :: Word8) + +hardwareTypeAddressLength :: HardwareType -> Word8 +hardwareTypeAddressLength Ethernet = 6 + +-- RFC 2131, Section 2, Page 11: +-- 1 1 1 1 1 1 +-- 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- |B| MBZ | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- B: BROADCAST flag +-- MBZ: MUST BE ZERO (reserved for future use) +-- Figure 2: Format of the ’flags’ field + +data Flags = Flags { broadcastFlag :: Bool } + deriving (Show, Eq) + +instance CodecAtom Flags where + getAtom = do + b <- getAtom :: Get Word16 + return Flags { broadcastFlag = testBit b 15 } + + putAtom flags = putAtom $ if broadcastFlag flags then bit 15 :: Word16 + else 0 + + atomSize _ = atomSize (0 :: Word16) + +putPaddedByteString :: Int -> BS.ByteString -> Put +putPaddedByteString n bs = do + putByteString $ BS.take n bs + putByteString $ BS.replicate (n - BS.length bs) 0 + + +selectKnownTags :: [OptionTagOrError] -> [Dhcp4OptionTag] +selectKnownTags = mapMaybe aux + where + aux (KnownTag t) = Just t + aux _ = Nothing + +-- +-- Unstructured to Structured logic +-- +-- |'parseDhcpMessage' attempts to find a valid high-level message +-- contained in a low-level message. The 'Dhcp4Message' is a large +-- type and can encode invalid combinations of options. +parseDhcpMessage :: Dhcp4Message -> Maybe (Either RequestMessage ReplyMessage) +parseDhcpMessage msg = do + messageType <- lookupMessageType (dhcp4Options msg) + case dhcp4Op msg of + + BootRequest -> Left <$> case messageType of + + Dhcp4Request -> RequestMessage <$> do + params <- lookupParams (dhcp4Options msg) + let params' = selectKnownTags params + let addr = lookupRequestAddr (dhcp4Options msg) + return Request + { requestXid = dhcp4Xid msg + , requestBroadcast = dhcp4Broadcast msg + , requestClientHardwareAddress = dhcp4ClientHardwareAddr msg + , requestParameters = params' + , requestAddress = addr + } + + Dhcp4Discover -> DiscoverMessage <$> do + params <- lookupParams (dhcp4Options msg) + let params' = selectKnownTags params + return Discover + { discoverXid = dhcp4Xid msg + , discoverBroadcast = dhcp4Broadcast msg + , discoverClientHardwareAddr = dhcp4ClientHardwareAddr msg + , discoverParameters = params' + } + + _ -> Nothing + + BootReply -> Right <$> case messageType of + + Dhcp4Ack -> AckMessage <$> do + leaseTime <- lookupLeaseTime (dhcp4Options msg) + return Ack + { ackHops = dhcp4Hops msg + , ackXid = dhcp4Xid msg + , ackYourAddr = dhcp4YourAddr msg + , ackServerAddr = dhcp4ServerAddr msg + , ackRelayAddr = dhcp4RelayAddr msg + , ackClientHardwareAddr = dhcp4ClientHardwareAddr msg + , ackLeaseTime = leaseTime + , ackOptions = dhcp4Options msg + } + + Dhcp4Offer -> OfferMessage <$> do + return Offer + { offerHops = dhcp4Hops msg + , offerXid = dhcp4Xid msg + , offerYourAddr = dhcp4YourAddr msg + , offerServerAddr = dhcp4ServerAddr msg + , offerRelayAddr = dhcp4RelayAddr msg + , offerClientHardwareAddr = dhcp4ClientHardwareAddr msg + , offerOptions = dhcp4Options msg + } + + _ -> Nothing + +-- +-- Structured to unstrucured logic +-- +-- |'discoverToMessage' embeds 'Discover' messages in the low-level +-- 'Dhcp4Message' type, typically for the purpose of serialization. +discoverToMessage :: Discover -> Dhcp4Message +discoverToMessage discover = Dhcp4Message + { dhcp4Op = BootRequest + , dhcp4Hops = 0 + , dhcp4Xid = discoverXid discover + , dhcp4Secs = 0 + , dhcp4Broadcast = False + , dhcp4ClientAddr = IP4 0 0 0 0 + , dhcp4YourAddr = IP4 0 0 0 0 + , dhcp4ServerAddr = IP4 0 0 0 0 + , dhcp4RelayAddr = IP4 0 0 0 0 + , dhcp4ClientHardwareAddr = discoverClientHardwareAddr discover + , dhcp4ServerHostname = "" + , dhcp4BootFilename = "" + , dhcp4Options = [ OptMessageType Dhcp4Discover + , OptParameterRequestList + $ map KnownTag + $ discoverParameters discover + ] + } + +-- |'ackToMessage' embeds 'Ack' messages in the low-level +-- 'Dhcp4Message' type, typically for the purpose of serialization. +ackToMessage :: Ack -> Dhcp4Message +ackToMessage ack = Dhcp4Message + { dhcp4Op = BootReply + , dhcp4Hops = ackHops ack + , dhcp4Xid = ackXid ack + , dhcp4Secs = 0 + , dhcp4Broadcast = False + , dhcp4ClientAddr = IP4 0 0 0 0 + , dhcp4YourAddr = ackYourAddr ack + , dhcp4ServerAddr = ackServerAddr ack + , dhcp4RelayAddr = ackRelayAddr ack + , dhcp4ClientHardwareAddr = ackClientHardwareAddr ack + , dhcp4ServerHostname = "" + , dhcp4BootFilename = "" + , dhcp4Options = OptMessageType Dhcp4Ack + : OptServerIdentifier (ackServerAddr ack) + : OptIPAddressLeaseTime (ackLeaseTime ack) + : ackOptions ack + } + +-- |'offerToMessage' embeds 'Offer' messages in the low-level +-- 'Dhcp4Message' type, typically for the purpose of serialization. +offerToMessage :: Offer -> Dhcp4Message +offerToMessage offer = Dhcp4Message + { dhcp4Op = BootReply + , dhcp4Hops = offerHops offer + , dhcp4Xid = offerXid offer + , dhcp4Secs = 0 + , dhcp4Broadcast = False + , dhcp4ClientAddr = IP4 0 0 0 0 + , dhcp4YourAddr = offerYourAddr offer + , dhcp4ServerAddr = offerServerAddr offer + , dhcp4RelayAddr = offerRelayAddr offer + , dhcp4ClientHardwareAddr = offerClientHardwareAddr offer + , dhcp4ServerHostname = "" + , dhcp4BootFilename = "" + , dhcp4Options = OptMessageType Dhcp4Offer + : OptServerIdentifier (offerServerAddr offer) + : offerOptions offer + } + +-- |'requestToMessage' embeds 'Request' messages in the low-level +-- 'Dhcp4Message' type, typically for the purpose of serialization. +requestToMessage :: Request -> Dhcp4Message +requestToMessage request = Dhcp4Message + { dhcp4Op = BootRequest + , dhcp4Hops = 0 + , dhcp4Xid = requestXid request + , dhcp4Secs = 0 + , dhcp4Broadcast = requestBroadcast request + , dhcp4ClientAddr = IP4 0 0 0 0 + , dhcp4YourAddr = IP4 0 0 0 0 + , dhcp4ServerAddr = IP4 0 0 0 0 + , dhcp4RelayAddr = IP4 0 0 0 0 + , dhcp4ClientHardwareAddr = requestClientHardwareAddress request + , dhcp4ServerHostname = "" + , dhcp4BootFilename = "" + , dhcp4Options = [ OptMessageType Dhcp4Request + , OptParameterRequestList + $ map KnownTag + $ requestParameters request + ] ++ maybe [] (\x -> [OptRequestIPAddress x]) + (requestAddress request) + } + +-- |'mkDiscover' creates a new 'Discover' message with a set +-- of options suitable for configuring a basic network stack. +mkDiscover :: Xid -- ^ New randomly generated transaction ID + -> Mac -- ^ The client's hardware address + -> Discover +mkDiscover xid mac = Discover + { discoverXid = xid + , discoverBroadcast = False + , discoverClientHardwareAddr = mac + , discoverParameters = [ OptTagSubnetMask + , OptTagBroadcastAddress + , OptTagTimeOffset + , OptTagRouters + , OptTagDomainName + , OptTagNameServers + , OptTagHostName + ] + } + +-- |'offerToRequest' creates a 'Request' message suitable for accepting +-- an 'Offer' from the DHCPv4 server. +offerToRequest :: Offer -- ^ The offer as received from the server + -> Request +offerToRequest offer = Request + { requestXid = offerXid offer + , requestBroadcast = False + , requestClientHardwareAddress = offerClientHardwareAddr offer + , requestParameters = [ OptTagSubnetMask + , OptTagBroadcastAddress + , OptTagTimeOffset + , OptTagRouters + , OptTagDomainName + , OptTagNameServers + , OptTagHostName + ] + , requestAddress = Just (offerYourAddr offer) + } diff --git a/src/Hans/Message/Dhcp4Codec.hs b/src/Hans/Message/Dhcp4Codec.hs new file mode 100644 index 0000000..39edefc --- /dev/null +++ b/src/Hans/Message/Dhcp4Codec.hs @@ -0,0 +1,90 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +module Hans.Message.Dhcp4Codec where + +import Control.Applicative +import Data.List (find) +import qualified Data.Serialize +import Data.Serialize.Get +import Data.Serialize.Put +import Data.Word (Word8, Word16, Word32) + +import Hans.Address.IP4 (IP4,IP4Mask) +import Hans.Address.Mac (Mac) +import Hans.Address (Mask(..)) + +class CodecAtom a where + getAtom :: Get a + putAtom :: a -> Put + atomSize :: a -> Int + +instance (CodecAtom a, CodecAtom b) => CodecAtom (a,b) where + getAtom = (,) <$> getAtom <*> getAtom + putAtom (a,b) = putAtom a *> putAtom b + atomSize (a,b)= atomSize a + atomSize b + +instance CodecAtom Word8 where + getAtom = getWord8 + putAtom n = putWord8 n + atomSize _ = 1 + +instance CodecAtom Word16 where + getAtom = getWord16be + putAtom n = putWord16be n + atomSize _ = 2 + +instance CodecAtom Word32 where + getAtom = getWord32be + putAtom n = putWord32be n + atomSize _ = 4 + +instance CodecAtom Bool where + getAtom = do b <- getWord8 + case b of + 0 -> return False + 1 -> return True + _ -> fail "Expected 0/1 in boolean option" + putAtom False = putWord8 0 + putAtom True = putWord8 1 + atomSize _ = 1 + +instance CodecAtom IP4 where + getAtom = Data.Serialize.get + putAtom = Data.Serialize.put + atomSize _ = 4 + +instance CodecAtom IP4Mask where + getAtom = withMask <$> getAtom <*> (unmask <$> getAtom) + putAtom ip4mask = putAtom addr *> putAtom (SubnetMask mask) + where (addr, mask) = getMaskComponents ip4mask + atomSize _ = atomSize (undefined :: IP4) + + atomSize (undefined :: SubnetMask) + +instance CodecAtom Mac where + getAtom = Data.Serialize.get + putAtom = Data.Serialize.put + atomSize _ = 6 + +----------------------------------------------------------------------- +-- Subnet parser/unparser operations ---------------------------------- +----------------------------------------------------------------------- + +newtype SubnetMask = SubnetMask { unmask :: Int} + deriving (Show, Eq) + +word32ToSubnetMask :: Word32 -> Maybe SubnetMask +word32ToSubnetMask mask = + SubnetMask <$> find (\ i -> computeMask i == mask) [0..32] + +subnetMaskToWord32 :: SubnetMask -> Word32 +subnetMaskToWord32 (SubnetMask n) = computeMask n + +computeMask :: Int -> Word32 +computeMask n = 0-2^(32-n) + +instance CodecAtom SubnetMask where + getAtom = do x <- getAtom + case word32ToSubnetMask x of + Just mask -> return mask + Nothing -> fail "Invalid subnet mask" + putAtom = putAtom . subnetMaskToWord32 + atomSize _ = atomSize (undefined :: Word32) diff --git a/src/Hans/Message/Dhcp4Options.hs b/src/Hans/Message/Dhcp4Options.hs new file mode 100644 index 0000000..ee9e5f0 --- /dev/null +++ b/src/Hans/Message/Dhcp4Options.hs @@ -0,0 +1,869 @@ +module Hans.Message.Dhcp4Options where + +import Control.Monad (unless) +import Control.Applicative +import Data.Maybe (fromMaybe) +import Data.Foldable (traverse_) +import Data.Traversable (sequenceA) +import Data.Word (Word8, Word16, Word32) +import Data.Serialize.Get +import Data.Serialize.Put +import Data.ByteString (ByteString) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Char8 as BS8 +import Numeric (showHex) + +import Hans.Address.IP4 (IP4,IP4Mask) +import Hans.Message.Dhcp4Codec + +----------------------------------------------------------------------- +-- Magic constants ---------------------------------------------------- +----------------------------------------------------------------------- + +data MagicCookie = MagicCookie + +dhcp4MagicCookie :: Word32 +dhcp4MagicCookie = 0x63825363 + +instance CodecAtom MagicCookie where + getAtom = do cookie <- getAtom + unless (cookie == dhcp4MagicCookie) + (fail "Incorrect magic cookie.") + return MagicCookie + putAtom MagicCookie = putAtom dhcp4MagicCookie + atomSize MagicCookie = atomSize dhcp4MagicCookie + + +----------------------------------------------------------------------- +-- DHCP option type and operations ------------------------------------ +----------------------------------------------------------------------- + +data Dhcp4Option + = OptSubnetMask SubnetMask + | OptTimeOffset Word32 + | OptRouters [IP4] + | OptTimeServers [IP4] + | OptIEN116NameServers [IP4] + | OptNameServers [IP4] + | OptLogServers [IP4] + | OptCookieServers [IP4] + | OptLPRServers [IP4] + | OptImpressServers [IP4] + | OptResourceLocationServers [IP4] + | OptHostName NVTAsciiString + | OptBootFileSize Word16 + | OptMeritDumpFile NVTAsciiString + | OptDomainName NVTAsciiString + | OptSwapServer IP4 + | OptRootPath NVTAsciiString + | OptExtensionsPath NVTAsciiString + | OptEnableIPForwarding Bool + | OptEnableNonLocalSourceRouting Bool + | OptPolicyFilters [IP4Mask] + | OptMaximumDatagramReassemblySize Word16 + | OptDefaultTTL Word8 + | OptPathMTUAgingTimeout Word32 + | OptPathMTUPlateauTable [Word16] + | OptInterfaceMTU Word16 + | OptAllSubnetsAreLocal Bool + | OptBroadcastAddress IP4 + | OptPerformMaskDiscovery Bool + | OptShouldSupplyMasks Bool + | OptShouldPerformRouterDiscovery Bool + | OptRouterSolicitationAddress IP4 + | OptStaticRoutes [(IP4,IP4)] + | OptShouldNegotiateArpTrailers Bool + | OptArpCacheTimeout Word32 + | OptUseRFC1042EthernetEncapsulation Bool + | OptTcpDefaultTTL Word8 + | OptTcpKeepaliveInterval Word32 + | OptTcpKeepaliveUseGarbage Bool + | OptNisDomainName NVTAsciiString + | OptNisServers [IP4] + | OptNtpServers [IP4] + | OptVendorSpecific ByteString + | OptNetBiosNameServers [IP4] + | OptNetBiosDistributionServers [IP4] + | OptNetBiosNodeType NetBiosNodeType + | OptNetBiosScope NVTAsciiString + | OptXWindowsFontServer [IP4] + | OptXWindowsDisplayManagers [IP4] + | OptNisPlusDomain NVTAsciiString + | OptNisPlusServers [IP4] + | OptSmtpServers [IP4] + | OptPopServers [IP4] + | OptNntpServers [IP4] + | OptWwwServers [IP4] + | OptFingerServers [IP4] + | OptIrcServers [IP4] + | OptStreetTalkServers [IP4] + | OptStreetTalkDirectoryAssistanceServers [IP4] + | OptFQDN NVTAsciiString -- RFC 4702 + | OptRequestIPAddress IP4 + | OptIPAddressLeaseTime Word32 + | OptOverload OverloadOption + | OptTftpServer NVTAsciiString + | OptBootfileName NVTAsciiString + | OptMessageType Dhcp4MessageType + | OptServerIdentifier IP4 + | OptParameterRequestList [OptionTagOrError] + | OptErrorMessage NVTAsciiString + | OptMaxDHCPMessageSize Word16 + | OptRenewalTime Word32 + | OptRebindingTime Word32 + | OptVendorClass NVTAsciiString + | OptClientIdentifier ByteString + | OptNetWareDomainName NVTAsciiString -- RFC 2242 + | OptNetWareInfo ByteString -- RFC 2242 + | OptAutoconfiguration Bool -- RFC 2563 + deriving (Show,Eq) + +getDhcp4Option :: Get (Either ControlTag Dhcp4Option) +getDhcp4Option = do + mb_tag <- getOptionTag + case mb_tag of + UnknownTag t -> do xs <- getBytes =<< remaining + fail ("getDhcp4Option failed tag (" ++ show t ++ ") " ++ show xs) + KnownTag tag -> do + let r con = Right . con <$> getOption + case tag of + OptTagPad -> Left <$> pure ControlPad + OptTagEnd -> Left <$> pure ControlEnd + OptTagSubnetMask -> r OptSubnetMask + OptTagTimeOffset -> r OptTimeOffset + OptTagRouters -> r OptRouters + OptTagTimeServers -> r OptTimeServers + OptTagIEN116NameServers -> r OptIEN116NameServers + OptTagNameServers -> r OptNameServers + OptTagLogServers -> r OptLogServers + OptTagCookieServers -> r OptCookieServers + OptTagLPRServers -> r OptLPRServers + OptTagImpressServers -> r OptImpressServers + OptTagResourceLocationServers -> r OptResourceLocationServers + OptTagHostName -> r OptHostName + OptTagBootFileSize -> r OptBootFileSize + OptTagMeritDumpFile -> r OptMeritDumpFile + OptTagDomainName -> r OptDomainName + OptTagSwapServer -> r OptSwapServer + OptTagRootPath -> r OptRootPath + OptTagExtensionsPath -> r OptExtensionsPath + OptTagEnableIPForwarding -> r OptEnableIPForwarding + OptTagEnableNonLocalSourceRouting -> r OptEnableNonLocalSourceRouting + OptTagPolicyFilters -> r OptPolicyFilters + OptTagMaximumDatagramReassemblySize -> r OptMaximumDatagramReassemblySize + OptTagDefaultTTL -> r OptDefaultTTL + OptTagPathMTUAgingTimeout -> r OptPathMTUAgingTimeout + OptTagPathMTUPlateauTable -> r OptPathMTUPlateauTable + OptTagInterfaceMTU -> r OptInterfaceMTU + OptTagAllSubnetsAreLocal -> r OptAllSubnetsAreLocal + OptTagBroadcastAddress -> r OptBroadcastAddress + OptTagPerformMaskDiscovery -> r OptPerformMaskDiscovery + OptTagShouldSupplyMasks -> r OptShouldSupplyMasks + OptTagShouldPerformRouterDiscovery -> r OptShouldPerformRouterDiscovery + OptTagRouterSolicitationAddress -> r OptRouterSolicitationAddress + OptTagStaticRoutes -> r OptStaticRoutes + OptTagShouldNegotiateArpTrailers -> r OptShouldNegotiateArpTrailers + OptTagArpCacheTimeout -> r OptArpCacheTimeout + OptTagUseRFC1042EthernetEncapsulation -> r OptUseRFC1042EthernetEncapsulation + OptTagTcpDefaultTTL -> r OptTcpDefaultTTL + OptTagTcpKeepaliveInterval -> r OptTcpKeepaliveInterval + OptTagTcpKeepaliveUseGarbage -> r OptTcpKeepaliveUseGarbage + OptTagNisDomainName -> r OptNisDomainName + OptTagNisServers -> r OptNisServers + OptTagNtpServers -> r OptNtpServers + OptTagVendorSpecific -> r OptVendorSpecific + OptTagNetBiosNameServers -> r OptNetBiosNameServers + OptTagNetBiosDistributionServers -> r OptNetBiosDistributionServers + OptTagNetBiosNodeType -> r OptNetBiosNodeType + OptTagNetBiosScope -> r OptNetBiosScope + OptTagXWindowsFontServer -> r OptXWindowsFontServer + OptTagXWindowsDisplayManagers -> r OptXWindowsDisplayManagers + OptTagNisPlusDomain -> r OptNisPlusDomain + OptTagNisPlusServers -> r OptNisPlusServers + OptTagSmtpServers -> r OptSmtpServers + OptTagPopServers -> r OptPopServers + OptTagNntpServers -> r OptNntpServers + OptTagWwwServers -> r OptWwwServers + OptTagFingerServers -> r OptFingerServers + OptTagIrcServers -> r OptIrcServers + OptTagStreetTalkServers -> r OptStreetTalkServers + OptTagStreetTalkDirectoryAssistanceServers -> r OptStreetTalkDirectoryAssistanceServers + OptTagFQDN -> r OptFQDN + OptTagRequestIPAddress -> r OptRequestIPAddress + OptTagIPAddressLeaseTime -> r OptIPAddressLeaseTime + OptTagOverload -> r OptOverload + OptTagTftpServer -> r OptTftpServer + OptTagBootfileName -> r OptBootfileName + OptTagMessageType -> r OptMessageType + OptTagServerIdentifier -> r OptServerIdentifier + OptTagParameterRequestList -> r OptParameterRequestList + OptTagErrorMessage -> r OptErrorMessage + OptTagMaxDHCPMessageSize -> r OptMaxDHCPMessageSize + OptTagRenewalTime -> r OptRenewalTime + OptTagRebindingTime -> r OptRebindingTime + OptTagVendorClass -> r OptVendorClass + OptTagClientIdentifier -> r OptClientIdentifier + OptTagNetWareDomainName -> r OptNetWareDomainName + OptTagNetWareInfo -> r OptNetWareInfo + OptTagAutoconfiguration -> r OptAutoconfiguration + +putDhcp4Option :: Dhcp4Option -> Put +putDhcp4Option opt = + let p tag val = putAtom (KnownTag tag) *> putOption val in + case opt of + OptSubnetMask mask -> p OptTagSubnetMask mask + OptTimeOffset offset -> p OptTagTimeOffset offset + OptRouters routers -> p OptTagRouters routers + OptTimeServers servers -> p OptTagTimeServers servers + OptIEN116NameServers servers -> p OptTagIEN116NameServers servers + OptNameServers servers -> p OptTagNameServers servers + OptLogServers servers -> p OptTagLogServers servers + OptCookieServers servers -> p OptTagCookieServers servers + OptLPRServers servers -> p OptTagLPRServers servers + OptImpressServers servers -> p OptTagImpressServers servers + OptResourceLocationServers servers -> p OptTagResourceLocationServers servers + OptHostName hostname -> p OptTagHostName hostname + OptBootFileSize sz -> p OptTagBootFileSize sz + OptMeritDumpFile file -> p OptTagMeritDumpFile file + OptDomainName domainname -> p OptTagDomainName domainname + OptSwapServer server -> p OptTagSwapServer server + OptRootPath path -> p OptTagRootPath path + OptExtensionsPath path -> p OptTagExtensionsPath path + OptEnableIPForwarding enabled -> p OptTagEnableIPForwarding enabled + OptEnableNonLocalSourceRouting enab -> p OptTagEnableNonLocalSourceRouting enab + OptPolicyFilters filters -> p OptTagPolicyFilters filters + OptMaximumDatagramReassemblySize n -> p OptTagMaximumDatagramReassemblySize n + OptDefaultTTL ttl -> p OptTagDefaultTTL ttl + OptPathMTUAgingTimeout timeout -> p OptTagPathMTUAgingTimeout timeout + OptPathMTUPlateauTable mtus -> p OptTagPathMTUPlateauTable mtus + OptInterfaceMTU mtu -> p OptTagInterfaceMTU mtu + OptAllSubnetsAreLocal arelocal -> p OptTagAllSubnetsAreLocal arelocal + OptBroadcastAddress addr -> p OptTagBroadcastAddress addr + OptPerformMaskDiscovery perform -> p OptTagPerformMaskDiscovery perform + OptShouldSupplyMasks should -> p OptTagShouldSupplyMasks should + OptShouldPerformRouterDiscovery b -> p OptTagShouldPerformRouterDiscovery b + OptRouterSolicitationAddress addr -> p OptTagRouterSolicitationAddress addr + OptStaticRoutes routes -> p OptTagStaticRoutes routes + OptShouldNegotiateArpTrailers b -> p OptTagShouldNegotiateArpTrailers b + OptArpCacheTimeout timeout -> p OptTagArpCacheTimeout timeout + OptUseRFC1042EthernetEncapsulation b-> p OptTagUseRFC1042EthernetEncapsulation b + OptTcpDefaultTTL ttl -> p OptTagTcpDefaultTTL ttl + OptTcpKeepaliveInterval interval -> p OptTagTcpKeepaliveInterval interval + OptTcpKeepaliveUseGarbage use -> p OptTagTcpKeepaliveUseGarbage use + OptNisDomainName domainname -> p OptTagNisDomainName domainname + OptNisServers servers -> p OptTagNisServers servers + OptNtpServers servers -> p OptTagNtpServers servers + OptVendorSpecific bs -> p OptTagVendorSpecific bs + OptNetBiosNameServers servers -> p OptTagNetBiosNameServers servers + OptNetBiosDistributionServers srvs -> p OptTagNetBiosDistributionServers srvs + OptNetBiosNodeType node -> p OptTagNetBiosNodeType node + OptNetBiosScope scope -> p OptTagNetBiosScope scope + OptXWindowsFontServer servers -> p OptTagXWindowsFontServer servers + OptXWindowsDisplayManagers servers -> p OptTagXWindowsDisplayManagers servers + OptNisPlusDomain domain -> p OptTagNisPlusDomain domain + OptNisPlusServers servers -> p OptTagNisPlusServers servers + OptSmtpServers servers -> p OptTagSmtpServers servers + OptPopServers servers -> p OptTagPopServers servers + OptNntpServers servers -> p OptTagNntpServers servers + OptWwwServers servers -> p OptTagWwwServers servers + OptFingerServers servers -> p OptTagFingerServers servers + OptIrcServers servers -> p OptTagIrcServers servers + OptStreetTalkServers servers -> p OptTagStreetTalkServers servers + OptStreetTalkDirectoryAssistanceServers servers -> p OptTagStreetTalkDirectoryAssistanceServers servers + OptFQDN fqdn -> p OptTagFQDN fqdn + OptRequestIPAddress addr -> p OptTagRequestIPAddress addr + OptIPAddressLeaseTime time -> p OptTagIPAddressLeaseTime time + OptOverload overload -> p OptTagOverload overload + OptTftpServer server -> p OptTagTftpServer server + OptBootfileName filename -> p OptTagBootfileName filename + OptMessageType t -> p OptTagMessageType t + OptServerIdentifier server -> p OptTagServerIdentifier server + OptParameterRequestList ps -> p OptTagParameterRequestList ps + OptErrorMessage msg -> p OptTagErrorMessage msg + OptMaxDHCPMessageSize maxsz -> p OptTagMaxDHCPMessageSize maxsz + OptRenewalTime time -> p OptTagRenewalTime time + OptRebindingTime time -> p OptTagRebindingTime time + OptVendorClass str -> p OptTagVendorClass str + OptClientIdentifier client -> p OptTagClientIdentifier client + OptNetWareDomainName name -> p OptTagNetWareDomainName name + OptNetWareInfo info -> p OptTagNetWareInfo info + OptAutoconfiguration autoconf -> p OptTagAutoconfiguration autoconf + +----------------------------------------------------------------------- +-- Message Type type and operations ----------------------------------- +----------------------------------------------------------------------- + +data Dhcp4MessageType + = Dhcp4Discover + | Dhcp4Offer + | Dhcp4Request + | Dhcp4Decline + | Dhcp4Ack + | Dhcp4Nak + | Dhcp4Release + | Dhcp4Inform + deriving (Eq,Show) + +instance Option Dhcp4MessageType where + getOption = defaultFixedGetOption + putOption = defaultFixedPutOption + +instance CodecAtom Dhcp4MessageType where + getAtom = do + b <- getAtom + case b :: Word8 of + 1 -> return Dhcp4Discover + 2 -> return Dhcp4Offer + 3 -> return Dhcp4Request + 4 -> return Dhcp4Decline + 5 -> return Dhcp4Ack + 6 -> return Dhcp4Nak + 7 -> return Dhcp4Release + 8 -> return Dhcp4Inform + _ -> fail ("Unknown DHCP Message Type 0x" ++ showHex b "") + + putAtom t = putAtom $ case t of + Dhcp4Discover -> 1 :: Word8 + Dhcp4Offer -> 2 + Dhcp4Request -> 3 + Dhcp4Decline -> 4 + Dhcp4Ack -> 5 + Dhcp4Nak -> 6 + Dhcp4Release -> 7 + Dhcp4Inform -> 8 + + atomSize _ = 1 + +----------------------------------------------------------------------- +-- Control tag type and operations ------------------------------------ +----------------------------------------------------------------------- + +data ControlTag + = ControlPad + | ControlEnd + deriving (Eq, Show) + +putControlOption :: ControlTag -> Put +putControlOption opt = case opt of + ControlPad -> putAtom (KnownTag OptTagPad) + ControlEnd -> putAtom (KnownTag OptTagEnd) + +----------------------------------------------------------------------- +-- Option tag type and operations ------------------------------------- +----------------------------------------------------------------------- + +data Dhcp4OptionTag + = OptTagPad + | OptTagEnd + | OptTagSubnetMask + | OptTagTimeOffset + | OptTagRouters + | OptTagTimeServers + | OptTagIEN116NameServers + | OptTagNameServers + | OptTagLogServers + | OptTagCookieServers + | OptTagLPRServers + | OptTagImpressServers + | OptTagResourceLocationServers + | OptTagHostName + | OptTagBootFileSize + | OptTagMeritDumpFile + | OptTagDomainName + | OptTagSwapServer + | OptTagRootPath + | OptTagExtensionsPath + | OptTagEnableIPForwarding + | OptTagEnableNonLocalSourceRouting + | OptTagPolicyFilters + | OptTagMaximumDatagramReassemblySize + | OptTagDefaultTTL + | OptTagPathMTUAgingTimeout + | OptTagPathMTUPlateauTable + | OptTagInterfaceMTU + | OptTagAllSubnetsAreLocal + | OptTagBroadcastAddress + | OptTagPerformMaskDiscovery + | OptTagShouldSupplyMasks + | OptTagShouldPerformRouterDiscovery + | OptTagRouterSolicitationAddress + | OptTagStaticRoutes + | OptTagShouldNegotiateArpTrailers + | OptTagArpCacheTimeout + | OptTagUseRFC1042EthernetEncapsulation + | OptTagTcpDefaultTTL + | OptTagTcpKeepaliveInterval + | OptTagTcpKeepaliveUseGarbage + | OptTagNisDomainName + | OptTagNisServers + | OptTagNtpServers + | OptTagVendorSpecific + | OptTagNetBiosNameServers + | OptTagNetBiosDistributionServers + | OptTagNetBiosNodeType + | OptTagNetBiosScope + | OptTagXWindowsFontServer + | OptTagXWindowsDisplayManagers + | OptTagNisPlusDomain + | OptTagNisPlusServers + | OptTagSmtpServers + | OptTagPopServers + | OptTagNntpServers + | OptTagWwwServers + | OptTagFingerServers + | OptTagIrcServers + | OptTagStreetTalkServers + | OptTagStreetTalkDirectoryAssistanceServers + | OptTagFQDN + | OptTagRequestIPAddress + | OptTagIPAddressLeaseTime + | OptTagOverload + | OptTagTftpServer + | OptTagBootfileName + | OptTagMessageType + | OptTagServerIdentifier + | OptTagParameterRequestList + | OptTagErrorMessage + | OptTagMaxDHCPMessageSize + | OptTagRenewalTime + | OptTagRebindingTime + | OptTagVendorClass + | OptTagClientIdentifier + | OptTagNetWareDomainName + | OptTagNetWareInfo + | OptTagAutoconfiguration + deriving (Show,Eq) + +data OptionTagOrError = UnknownTag Word8 | KnownTag Dhcp4OptionTag + deriving (Show,Eq) + +getOptionTag :: Get OptionTagOrError +getOptionTag = f =<< getWord8 + where + r = return . KnownTag + f 0 = r OptTagPad + f 1 = r OptTagSubnetMask + f 2 = r OptTagTimeOffset + f 3 = r OptTagRouters + f 4 = r OptTagTimeServers + f 5 = r OptTagIEN116NameServers + f 6 = r OptTagNameServers + f 7 = r OptTagLogServers + f 8 = r OptTagCookieServers + f 9 = r OptTagLPRServers + f 10 = r OptTagImpressServers + f 11 = r OptTagResourceLocationServers + f 12 = r OptTagHostName + f 13 = r OptTagBootFileSize + f 14 = r OptTagMeritDumpFile + f 15 = r OptTagDomainName + f 16 = r OptTagSwapServer + f 17 = r OptTagRootPath + f 18 = r OptTagExtensionsPath + f 19 = r OptTagEnableIPForwarding + f 20 = r OptTagEnableNonLocalSourceRouting + f 21 = r OptTagPolicyFilters + f 22 = r OptTagMaximumDatagramReassemblySize + f 23 = r OptTagDefaultTTL + f 24 = r OptTagPathMTUAgingTimeout + f 25 = r OptTagPathMTUPlateauTable + f 26 = r OptTagInterfaceMTU + f 27 = r OptTagAllSubnetsAreLocal + f 28 = r OptTagBroadcastAddress + f 29 = r OptTagPerformMaskDiscovery + f 30 = r OptTagShouldSupplyMasks + f 31 = r OptTagShouldPerformRouterDiscovery + f 32 = r OptTagRouterSolicitationAddress + f 33 = r OptTagStaticRoutes + f 34 = r OptTagShouldNegotiateArpTrailers + f 35 = r OptTagArpCacheTimeout + f 36 = r OptTagUseRFC1042EthernetEncapsulation + f 37 = r OptTagTcpDefaultTTL + f 38 = r OptTagTcpKeepaliveInterval + f 39 = r OptTagTcpKeepaliveUseGarbage + f 40 = r OptTagNisDomainName + f 41 = r OptTagNisServers + f 42 = r OptTagNtpServers + f 43 = r OptTagVendorSpecific + f 44 = r OptTagNetBiosNameServers + f 45 = r OptTagNetBiosDistributionServers + f 46 = r OptTagNetBiosNodeType + f 47 = r OptTagNetBiosScope + f 48 = r OptTagXWindowsFontServer + f 49 = r OptTagXWindowsDisplayManagers + f 50 = r OptTagRequestIPAddress + f 51 = r OptTagIPAddressLeaseTime + f 52 = r OptTagOverload + f 53 = r OptTagMessageType + f 54 = r OptTagServerIdentifier + f 55 = r OptTagParameterRequestList + f 56 = r OptTagErrorMessage + f 57 = r OptTagMaxDHCPMessageSize + f 58 = r OptTagRenewalTime + f 59 = r OptTagRebindingTime + f 60 = r OptTagVendorClass + f 61 = r OptTagClientIdentifier + f 62 = r OptTagNetWareDomainName + f 63 = r OptTagNetWareInfo + f 64 = r OptTagNisPlusDomain + f 65 = r OptTagNisPlusServers + f 66 = r OptTagTftpServer + f 67 = r OptTagBootfileName + f 69 = r OptTagSmtpServers + f 70 = r OptTagPopServers + f 71 = r OptTagNntpServers + f 72 = r OptTagWwwServers + f 73 = r OptTagFingerServers + f 74 = r OptTagIrcServers + f 75 = r OptTagStreetTalkServers + f 76 = r OptTagStreetTalkDirectoryAssistanceServers + f 81 = r OptTagFQDN + f 116 = r OptTagAutoconfiguration + f 255 = r OptTagEnd + f t = return (UnknownTag t) + +putOptionTag :: OptionTagOrError -> Put +putOptionTag (UnknownTag t) = putAtom t +putOptionTag (KnownTag t) = putAtom (f t) + where + f :: Dhcp4OptionTag -> Word8 + f OptTagPad = 0 + f OptTagEnd = 255 + f OptTagSubnetMask = 1 + f OptTagTimeOffset = 2 + f OptTagRouters = 3 + f OptTagTimeServers = 4 + f OptTagIEN116NameServers = 5 + f OptTagNameServers = 6 + f OptTagLogServers = 7 + f OptTagCookieServers = 8 + f OptTagLPRServers = 9 + f OptTagImpressServers = 10 + f OptTagResourceLocationServers = 11 + f OptTagHostName = 12 + f OptTagBootFileSize = 13 + f OptTagMeritDumpFile = 14 + f OptTagDomainName = 15 + f OptTagSwapServer = 16 + f OptTagRootPath = 17 + f OptTagExtensionsPath = 18 + f OptTagEnableIPForwarding = 19 + f OptTagEnableNonLocalSourceRouting = 20 + f OptTagPolicyFilters = 21 + f OptTagMaximumDatagramReassemblySize = 22 + f OptTagDefaultTTL = 23 + f OptTagPathMTUAgingTimeout = 24 + f OptTagPathMTUPlateauTable = 25 + f OptTagInterfaceMTU = 26 + f OptTagAllSubnetsAreLocal = 27 + f OptTagBroadcastAddress = 28 + f OptTagPerformMaskDiscovery = 29 + f OptTagShouldSupplyMasks = 30 + f OptTagShouldPerformRouterDiscovery = 31 + f OptTagRouterSolicitationAddress = 32 + f OptTagStaticRoutes = 33 + f OptTagShouldNegotiateArpTrailers = 34 + f OptTagArpCacheTimeout = 35 + f OptTagUseRFC1042EthernetEncapsulation = 36 + f OptTagTcpDefaultTTL = 37 + f OptTagTcpKeepaliveInterval = 38 + f OptTagTcpKeepaliveUseGarbage = 39 + f OptTagNisDomainName = 40 + f OptTagNisServers = 41 + f OptTagNtpServers = 42 + f OptTagVendorSpecific = 43 + f OptTagNetBiosNameServers = 44 + f OptTagNetBiosDistributionServers = 45 + f OptTagNetBiosNodeType = 46 + f OptTagNetBiosScope = 47 + f OptTagXWindowsFontServer = 48 + f OptTagXWindowsDisplayManagers = 49 + f OptTagRequestIPAddress = 50 + f OptTagIPAddressLeaseTime = 51 + f OptTagOverload = 52 + f OptTagMessageType = 53 + f OptTagServerIdentifier = 54 + f OptTagParameterRequestList = 55 + f OptTagErrorMessage = 56 + f OptTagMaxDHCPMessageSize = 57 + f OptTagRenewalTime = 58 + f OptTagRebindingTime = 59 + f OptTagVendorClass = 60 + f OptTagClientIdentifier = 61 + f OptTagNetWareDomainName = 62 + f OptTagNetWareInfo = 63 + f OptTagNisPlusDomain = 64 + f OptTagNisPlusServers = 65 + f OptTagTftpServer = 66 + f OptTagBootfileName = 67 + f OptTagSmtpServers = 69 + f OptTagPopServers = 70 + f OptTagNntpServers = 71 + f OptTagWwwServers = 72 + f OptTagFingerServers = 73 + f OptTagIrcServers = 74 + f OptTagStreetTalkServers = 75 + f OptTagStreetTalkDirectoryAssistanceServers = 76 + f OptTagFQDN = 81 + f OptTagAutoconfiguration = 116 + +----------------------------------------------------------------------- +-- NetBIOS node type and operations ----------------------------------- +----------------------------------------------------------------------- + +data NetBiosNodeType + = BNode + | PNode + | MNode + | HNode + deriving (Show,Eq) + +instance Option NetBiosNodeType where + getOption = defaultFixedGetOption + putOption = defaultFixedPutOption + +instance CodecAtom NetBiosNodeType where + getAtom = do + b <- getAtom + case b :: Word8 of + 0x1 -> return BNode + 0x2 -> return PNode + 0x4 -> return MNode + 0x8 -> return HNode + _ -> fail "Unknown NetBIOS node type" + + putAtom t = putAtom $ case t of + BNode -> 0x1 :: Word8 + PNode -> 0x2 + MNode -> 0x4 + HNode -> 0x8 + + atomSize _ = 1 + +----------------------------------------------------------------------- +-- Overload option type and operations -------------------------------- +----------------------------------------------------------------------- + +data OverloadOption + = UsedFileField + | UsedSNameField + | UsedBothFields + deriving (Show, Eq) + +instance Option OverloadOption where + getOption = defaultFixedGetOption + putOption = defaultFixedPutOption + +instance CodecAtom OverloadOption where + getAtom = do b <- getAtom + case b :: Word8 of + 1 -> return UsedFileField + 2 -> return UsedSNameField + 3 -> return UsedBothFields + _ -> fail ("Bad overload value 0x" ++ showHex b "") + putAtom t = putAtom $ case t of + UsedFileField -> 1 :: Word8 + UsedSNameField -> 2 + UsedBothFields -> 3 + + atomSize _ = atomSize (undefined :: Word8) + +----------------------------------------------------------------------- +-- Options list operations -------------------------------------------- +----------------------------------------------------------------------- + +getDhcp4Options :: ByteString -> ByteString + -> Get (String, String, [Dhcp4Option]) +getDhcp4Options sname file = do + MagicCookie <- getAtom + options0 <- remainingAsOptions + case lookupOverload options0 of + Nothing -> return (nullTerminated sname, nullTerminated file, options0) + + Just UsedFileField -> do + options1 <- localParse file remainingAsOptions + let options = options0 ++ options1 + NVTAsciiString fileString + = fromMaybe (NVTAsciiString "") (lookupFile options) + return (nullTerminated sname, fileString, options) + + Just UsedSNameField -> do + options1 <- localParse sname remainingAsOptions + let options = options0 ++ options1 + NVTAsciiString snameString + = fromMaybe (NVTAsciiString "") (lookupSname options) + return (snameString, nullTerminated file, options) + + Just UsedBothFields -> do + + -- The file field MUST be interpreted for options before the sname field. + -- RFC 2131, Section 4.1, Page 24 + options1 <- localParse file remainingAsOptions + options2 <- localParse sname remainingAsOptions + let options = options0 ++ options1 ++ options2 + NVTAsciiString snameString + = fromMaybe (NVTAsciiString "") (lookupSname options) + NVTAsciiString fileString + = fromMaybe (NVTAsciiString "") (lookupFile options) + return (snameString, fileString, options) + + where + remainingAsOptions = scrubControls =<< repeatedly getDhcp4Option + + localParse bs m = case runGet m bs of + Right x -> return x + Left err -> fail err + + +putDhcp4Options :: [Dhcp4Option] -> Put +putDhcp4Options opts = putAtom MagicCookie + *> traverse_ putDhcp4Option opts + *> putControlOption ControlEnd + +scrubControls :: (Applicative m, Monad m) + => [Either ControlTag Dhcp4Option] -> m [Dhcp4Option] +scrubControls [] = fail "No END option found" +scrubControls (Left ControlPad : xs) = scrubControls xs +scrubControls (Left ControlEnd : xs) = [] <$ traverse_ eatPad xs +scrubControls (Right o : xs) = (o :) <$> scrubControls xs + +-- | 'eatPad' fails on any non 'ControlPad' option with an error message. +eatPad :: Monad m => Either ControlTag Dhcp4Option -> m () +eatPad (Left ControlPad) = return () +eatPad _ = fail "Unexpected option after END option" + +replicateA :: Applicative f => Int -> f a -> f [a] +replicateA n f = sequenceA (replicate n f) + +repeatedly :: Get a -> Get [a] +repeatedly m = do + done <- isEmpty + if done then return [] + else (:) <$> m <*> repeatedly m + +nullTerminated :: ByteString -> String +nullTerminated = takeWhile (/= '\NUL') . BS8.unpack + +lookupOverload :: [Dhcp4Option] -> Maybe OverloadOption +lookupOverload = foldr f Nothing + where f (OptOverload o) _ = Just o + f _ a = a + +lookupFile :: [Dhcp4Option] -> Maybe NVTAsciiString +lookupFile = foldr f Nothing + where f (OptBootfileName fn) _ = Just fn + f _ a = a + +lookupSname :: [Dhcp4Option] -> Maybe NVTAsciiString +lookupSname = foldr f Nothing + where f (OptTftpServer n) _ = Just n + f _ a = a + +lookupParams :: [Dhcp4Option] -> Maybe [OptionTagOrError] +lookupParams = foldr f Nothing + where f (OptParameterRequestList n) _ = Just n + f _ a = a + +lookupMessageType :: [Dhcp4Option] -> Maybe Dhcp4MessageType +lookupMessageType = foldr f Nothing + where f (OptMessageType n) _ = Just n + f _ a = a + +lookupRequestAddr :: [Dhcp4Option] -> Maybe IP4 +lookupRequestAddr = foldr f Nothing + where f (OptRequestIPAddress n) _ = Just n + f _ a = a + +lookupLeaseTime :: [Dhcp4Option] -> Maybe Word32 +lookupLeaseTime = foldr f Nothing + where f (OptIPAddressLeaseTime t) _ = Just t + f _ a = a + +----------------------------------------------------------------------- +-- Protected parser and unparser monad -------------------------------- +----------------------------------------------------------------------- + +class Option a where + getOption :: Get a + putOption :: a -> Put + +instance CodecAtom a => Option [a] where + getOption = do + let (n, m) = getRecord + len <- getLen + let (count, remainder) = divMod len n + unless (remainder == 0) (fail ("Length was not a multiple of " ++ show n)) + unless (count > 0) (fail "Minimum length not met") + replicateA count $ label "List of fixed-length values" $ isolate n m + putOption xs = do putLen (atomSize (head xs) * length xs) + traverse_ putAtom xs + +instance (CodecAtom a, CodecAtom b) => Option (a,b) where + getOption = defaultFixedGetOption + putOption = defaultFixedPutOption +instance Option Bool where + getOption = defaultFixedGetOption + putOption = defaultFixedPutOption +instance Option Word8 where + getOption = defaultFixedGetOption + putOption = defaultFixedPutOption +instance Option Word16 where + getOption = defaultFixedGetOption + putOption = defaultFixedPutOption +instance Option Word32 where + getOption = defaultFixedGetOption + putOption = defaultFixedPutOption +instance Option IP4 where + getOption = defaultFixedGetOption + putOption = defaultFixedPutOption +instance Option SubnetMask where + getOption = defaultFixedGetOption + putOption = defaultFixedPutOption +instance Option ByteString where + getOption = do len <- getLen + getByteString len + putOption bs = do putLen (BS.length bs) + putByteString bs + +defaultFixedGetOption :: CodecAtom a => Get a +defaultFixedGetOption = fixedLen n m + where (n,m) = getRecord + +defaultFixedPutOption :: CodecAtom a => a -> Put +defaultFixedPutOption x = do + putLen (atomSize x) + putAtom x + +fixedLen :: Int -> Get a -> Get a +fixedLen expectedLen m = do + len <- getLen + unless (len == expectedLen) (fail "Bad length on \"fixed-length\" option.") + label "Fixed length field" (isolate expectedLen m) + +getRecord :: CodecAtom a => (Int, Get a) +getRecord = (atomSize undef, m) + where + (undef, m) = (undefined, getAtom) :: CodecAtom a => (a, Get a) + + +instance CodecAtom OptionTagOrError where + getAtom = getOptionTag + putAtom x = putOptionTag x + atomSize _ = 1 + +newtype NVTAsciiString = NVTAsciiString String + deriving (Eq, Show) + +instance Option NVTAsciiString where + getOption = do len <- getLen + bs <- getByteString len + return (NVTAsciiString (nullTerminated bs)) + putOption (NVTAsciiString str) = do + putLen (length str) + putByteString (BS8.pack str) + +getLen :: Get Int +getLen = fromIntegral <$> getWord8 + +putLen :: Int -> Put +putLen n = putWord8 $ fromIntegral n diff --git a/src/Hans/Message/EthernetFrame.hs b/src/Hans/Message/EthernetFrame.hs new file mode 100644 index 0000000..93f2977 --- /dev/null +++ b/src/Hans/Message/EthernetFrame.hs @@ -0,0 +1,44 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} + +module Hans.Message.EthernetFrame ( + EtherType(..) + , EthernetFrame(..) + ) where + +import Hans.Address.Mac (Mac) + +import Data.Serialize (Serialize(..)) +import Data.Serialize.Get (Get,getBytes,remaining) +import Data.Serialize.Put (Put,putByteString) +import Data.ByteString (ByteString) +import Data.Word (Word16) +import Numeric (showHex) + +newtype EtherType = EtherType { getEtherType :: Word16 } + deriving (Eq,Num,Ord,Serialize) + +instance Show EtherType where + showsPrec _ (EtherType et) = showString "EtherType 0x" . showHex et + +data EthernetFrame = EthernetFrame + { etherDest :: !Mac + , etherSource :: !Mac + , etherType :: !EtherType + , etherData :: ByteString + } deriving (Eq,Show) + +instance Serialize EthernetFrame where + get = parseEthernetFrame + put = renderEthernetFrame + +parseEthernetFrame :: Get EthernetFrame +parseEthernetFrame = do + dst <- get + src <- get + ty <- get + body <- getBytes =<< remaining + return $! EthernetFrame dst src ty body + +renderEthernetFrame :: EthernetFrame -> Put +renderEthernetFrame (EthernetFrame s d t da) = + put s >> put d >> put t >> putByteString da diff --git a/src/Hans/Message/Icmp4.hs b/src/Hans/Message/Icmp4.hs new file mode 100644 index 0000000..020a7a2 --- /dev/null +++ b/src/Hans/Message/Icmp4.hs @@ -0,0 +1,398 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} + +module Hans.Message.Icmp4 where + +import Hans.Address.IP4 (IP4) +import Hans.Message.Types (Lifetime) +import Hans.Utils (Packet) +import Hans.Utils.Checksum (pokeChecksum, computeChecksum) + +import Control.Monad (liftM2, unless, when, replicateM) +import Data.Serialize (Serialize(..)) +import Data.Serialize.Get (getWord8, getByteString, remaining, skip, Get, label, + lookAhead, getBytes, isEmpty) +import Data.Serialize.Put (putWord8,putByteString, Put, runPut) +import Data.Int (Int32) +import Data.Word (Word8,Word16,Word32) +import System.IO.Unsafe (unsafePerformIO) + + +-- General ICMP Packets -------------------------------------------------------- + +data Icmp4Packet + -- RFC 792 - Internet Control Message Protocol + = EchoReply Identifier SequenceNumber Packet + | DestinationUnreachable DestinationUnreachableCode Packet + | SourceQuench Packet + | Redirect RedirectCode IP4 Packet + | Echo Identifier SequenceNumber Packet + | RouterAdvertisement Lifetime [RouterAddress] + | RouterSolicitation + | TimeExceeded TimeExceededCode Packet + | ParameterProblem Word8 Packet + | Timestamp Identifier SequenceNumber Word32 Word32 Word32 + | TimestampReply Identifier SequenceNumber Word32 Word32 Word32 + | Information Identifier SequenceNumber + | InformationReply Identifier SequenceNumber + + -- rfc 1393 - Traceroute Using an IP Option + | TraceRoute TraceRouteCode Identifier Word16 Word16 Word32 Word32 + + -- rfc 950 - Internet Standard Subnetting Procedure + | AddressMask Identifier SequenceNumber + | AddressMaskReply Identifier SequenceNumber Word32 + deriving Show + +noCode :: String -> Get () +noCode str = do + code <- getWord8 + unless (code == 0) + (fail (str ++ " expects code 0")) + +instance Serialize Icmp4Packet where + get = label "ICMP" $ do + rest <- lookAhead (getBytes =<< remaining) + unless (computeChecksum 0 rest == 0) + (fail "Bad checksum") + ty <- get + + let firstGet :: Serialize a => String -> (a -> Get b) -> Get b + firstGet labelString f = label labelString $ do + code <- get + skip 2 -- checksum + f code + + case (ty :: Word8) of + 0 -> firstGet "Echo Reply" $ \ NoCode -> do + ident <- get + seqNum <- get + dat <- getByteString =<< remaining + return $! EchoReply ident seqNum dat + + 3 -> firstGet "DestinationUnreachable" $ \ code -> do + skip 4 -- unused + dat <- getByteString =<< remaining + return $! DestinationUnreachable code dat + + 4 -> firstGet "Source Quence" $ \ NoCode -> do + skip 4 -- unused + dat <- getByteString =<< remaining + return $! SourceQuench dat + + 5 -> firstGet "Redirect" $ \ code -> do + gateway <- get + dat <- getByteString =<< remaining + return $! Redirect code gateway dat + + 8 -> firstGet "Echo" $ \ NoCode -> do + ident <- get + seqNum <- get + dat <- getByteString =<< remaining + return $! Echo ident seqNum dat + + 9 -> firstGet "Router Advertisement" $ \ NoCode -> do + n <- getWord8 + sz <- getWord8 + unless (sz == 2) + (fail ("Expected size 2, got: " ++ show sz)) + lifetime <- get + addrs <- replicateM (fromIntegral n) get + return $! RouterAdvertisement lifetime addrs + + 10 -> firstGet "Router Solicitation" $ \ NoCode -> do + skip 4 -- reserved + return RouterSolicitation + + 11 -> firstGet "Time Exceeded" $ \ code -> do + skip 4 -- unused + dat <- getByteString =<< remaining + return $! TimeExceeded code dat + + 12 -> firstGet "Parameter Problem" $ \ NoCode -> do + ptr <- getWord8 + skip 3 -- unused + dat <- getByteString =<< remaining + return $! ParameterProblem ptr dat + + 13 -> firstGet "Timestamp" $ \ NoCode -> do + ident <- get + seqNum <- get + origTime <- get + recvTime <- get + tranTime <- get + return $! Timestamp ident seqNum origTime recvTime tranTime + + 14 -> firstGet "Timestamp Reply" $ \ NoCode -> do + ident <- get + seqNum <- get + origTime <- get + recvTime <- get + tranTime <- get + return $! TimestampReply ident seqNum origTime recvTime tranTime + + 15 -> firstGet "Information" $ \ NoCode -> do + ident <- get + seqNum <- get + return $! Information ident seqNum + + 16 -> firstGet "Information Reply" $ \ NoCode -> do + ident <- get + seqNum <- get + return $! InformationReply ident seqNum + + 17 -> firstGet "Address Mask" $ \ NoCode -> do + ident <- get + seqNum <- get + skip 4 -- address mask + return $! AddressMask ident seqNum + + 18 -> firstGet "Address Mask Reply" $ \ NoCode -> do + ident <- get + seqNum <- get + mask <- get + return $! AddressMaskReply ident seqNum mask + + 30 -> firstGet "Trace Route" $ \ code -> do + ident <- get + skip 2 -- unused + outHop <- get + retHop <- get + speed <- get + mtu <- get + return $! TraceRoute code ident outHop retHop speed mtu + + _ -> fail ("Unknown type: " ++ show ty) + + + put = putByteString . unsafePerformIO . setChecksum . runPut . put' + -- Argument for safety: The bytestring being + -- destructively modified here is only accessible + -- through the composition and will never escape + where + setChecksum pkt = pokeChecksum (computeChecksum 0 pkt) pkt 2 + + firstPut :: Serialize a => Word8 -> a -> Put + firstPut ty code + = do put ty + put code + put (0 :: Word16) + + put' (EchoReply ident seqNum dat) + = do firstPut 0 NoCode + put ident + put seqNum + putByteString dat + + put' (DestinationUnreachable code dat) + = do firstPut 3 code + put (0 :: Word32) -- unused + putByteString dat + + put' (SourceQuench dat) + = do firstPut 4 NoCode + put (0 :: Word32) -- unused + putByteString dat + + put' (Redirect code gateway dat) + = do firstPut 5 code + put gateway + putByteString dat + + put' (Echo ident seqNum dat) + = do firstPut 8 NoCode + put ident + put seqNum + putByteString dat + + put' (RouterAdvertisement lifetime addrs) + = do let len = length addrs + addrSize :: Word8 + addrSize = 2 + + when (len > 255) + (fail "Too many routers in Router Advertisement") + + firstPut 9 NoCode + put (fromIntegral len :: Word8) + put addrSize + put lifetime + mapM_ put addrs + + put' RouterSolicitation + = do firstPut 10 NoCode + put (0 :: Word32) -- RESERVED + + put' (TimeExceeded code dat) + = do firstPut 11 code + put (0 :: Word32) -- unused + putByteString dat + + put' (ParameterProblem ptr dat) + = do firstPut 12 NoCode + put ptr + put (0 :: Word8) -- unused + put (0 :: Word16) -- unused + putByteString dat + + put' (Timestamp ident seqNum origTime recvTime tranTime) + = do firstPut 13 NoCode + put ident + put seqNum + put origTime + put recvTime + put tranTime + + put' (TimestampReply ident seqNum origTime recvTime tranTime) + = do firstPut 14 NoCode + put ident + put seqNum + put origTime + put recvTime + put tranTime + + put' (Information ident seqNum) + = do firstPut 15 NoCode + put ident + put seqNum + + put' (InformationReply ident seqNum) + = do firstPut 16 NoCode + put ident + put seqNum + + put' (AddressMask ident seqNum) + = do firstPut 17 NoCode + put ident + put seqNum + put (0 :: Word32) -- address mask + + put' (AddressMaskReply ident seqNum mask) + = do firstPut 17 NoCode + put ident + put seqNum + put mask + + put' (TraceRoute code ident outHop retHop speed mtu) + = do firstPut 30 code + put ident + put (0 :: Word16) -- unused + put outHop + put retHop + put speed + put mtu + +data NoCode = NoCode + +instance Serialize NoCode where + get = do b <- getWord8 + unless (b == 0) + (fail ("Expected code 0, got code: " ++ show b)) + return NoCode + put NoCode = putWord8 0 + +data DestinationUnreachableCode + = NetUnreachable + | HostUnreachable + | ProtocolUnreachable + | PortUnreachable + | FragmentationUnreachable + | SourceRouteFailed + deriving Show + +instance Serialize DestinationUnreachableCode where + get = do b <- getWord8 + case b of + 0 -> return NetUnreachable + 1 -> return HostUnreachable + 2 -> return ProtocolUnreachable + 3 -> return PortUnreachable + 4 -> return FragmentationUnreachable + 5 -> return SourceRouteFailed + _ -> fail "Invalid code for Destination Unreachable" + + put NetUnreachable = putWord8 0 + put HostUnreachable = putWord8 1 + put ProtocolUnreachable = putWord8 2 + put PortUnreachable = putWord8 3 + put FragmentationUnreachable = putWord8 4 + put SourceRouteFailed = putWord8 5 + +data TimeExceededCode + = TimeToLiveExceededInTransit + | FragmentReassemblyTimeExceeded + deriving Show + +instance Serialize TimeExceededCode where + get = do b <- getWord8 + case b of + 0 -> return TimeToLiveExceededInTransit + 1 -> return FragmentReassemblyTimeExceeded + _ -> fail "Invalid code for Time Exceeded" + + put TimeToLiveExceededInTransit = putWord8 0 + put FragmentReassemblyTimeExceeded = putWord8 1 + +data RedirectCode + = RedirectForNetwork + | RedirectForHost + | RedirectForTypeOfServiceAndNetwork + | RedirectForTypeOfServiceAndHost + deriving Show + +instance Serialize RedirectCode where + get = do b <- getWord8 + case b of + 0 -> return RedirectForNetwork + 1 -> return RedirectForHost + 2 -> return RedirectForTypeOfServiceAndNetwork + 3 -> return RedirectForTypeOfServiceAndHost + _ -> fail "Invalid code for Time Exceeded" + + put RedirectForNetwork = putWord8 0 + put RedirectForHost = putWord8 1 + put RedirectForTypeOfServiceAndNetwork = putWord8 2 + put RedirectForTypeOfServiceAndHost = putWord8 3 + +data TraceRouteCode + = TraceRouteForwarded + | TraceRouteDiscarded + deriving Show + +instance Serialize TraceRouteCode where + get = do b <- getWord8 + case b of + 0 -> return TraceRouteForwarded + 1 -> return TraceRouteDiscarded + _ -> fail "Invalid code for Trace Route" + + put TraceRouteForwarded = putWord8 0 + put TraceRouteDiscarded = putWord8 1 + +-- Router Discovery Data ------------------------------------------------------- + +newtype PreferenceLevel = PreferenceLevel Int32 + deriving (Show,Eq,Ord,Num,Serialize) + +data RouterAddress = RouterAddress + { raAddr :: IP4 + , raPreferenceLevel :: PreferenceLevel + } deriving Show + +instance Serialize RouterAddress where + get = liftM2 RouterAddress get get + + put ra = do + put (raAddr ra) + put (raPreferenceLevel ra) + +newtype Identifier = Identifier Word16 + deriving (Show, Eq, Ord, Num, Serialize) + +newtype SequenceNumber = SequenceNumber Word16 + deriving (Show, Eq, Ord, Num, Serialize) + +getUntilDone :: Serialize a => Get [a] +getUntilDone = do + empty <- isEmpty + if empty then return [] + else liftM2 (:) get getUntilDone diff --git a/src/Hans/Message/Ip4.hs b/src/Hans/Message/Ip4.hs new file mode 100644 index 0000000..3b94f4d --- /dev/null +++ b/src/Hans/Message/Ip4.hs @@ -0,0 +1,280 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} + +module Hans.Message.Ip4 where + +import Hans.Address.IP4 (IP4) +import Hans.Utils +import Hans.Utils.Checksum + +import Control.Monad (unless) +import Data.Serialize (Serialize(..)) +import Data.Serialize.Get (Get,getWord8,getWord16be,getByteString,isolate,label) +import Data.Serialize.Put (runPut,runPutM,putWord8,putWord16be,putByteString) +import Data.Bits (Bits((.&.),(.|.),testBit,setBit,shiftR,shiftL,bit)) +import Data.Word (Word8,Word16) +import qualified Data.ByteString as S + + +-- IP4 Pseudo Header ----------------------------------------------------------- + +-- 0 7 8 15 16 23 24 31 +-- +--------+--------+--------+--------+ +-- | source address | +-- +--------+--------+--------+--------+ +-- | destination address | +-- +--------+--------+--------+--------+ +-- | zero |protocol| length | +-- +--------+--------+--------+--------+ +mkIP4PseudoHeader :: IP4 -> IP4 -> IP4Protocol -> MkPseudoHeader +mkIP4PseudoHeader src dst prot len = runPut $ do + put src + put dst + putWord8 0 >> put prot >> putWord16be (fromIntegral len) + + +-- IP4 Packets ----------------------------------------------------------------- + +newtype Ident = Ident { getIdent :: Word16 } + deriving (Eq,Ord,Num,Show,Serialize,Integral,Real,Enum) + +newtype IP4Protocol = IP4Protocol { getIP4Protocol :: Word8 } + deriving (Eq,Ord,Num,Show,Serialize) + +data IP4Packet = IP4Packet + { ip4Header :: !IP4Header + , ip4Payload :: S.ByteString + } deriving Show + +data IP4Header = IP4Header + { ip4Version :: !Word8 + , ip4TypeOfService :: !Word8 + , ip4Ident :: !Ident + , ip4MayFragment :: Bool + , ip4MoreFragments :: Bool + , ip4FragmentOffset :: !Word16 + , ip4TimeToLive :: !Word8 + , ip4Protocol :: !IP4Protocol + , ip4Checksum :: !Word16 + , ip4SourceAddr :: !IP4 + , ip4DestAddr :: !IP4 + , ip4Options :: [IP4Option] + } deriving Show + +emptyIP4Header :: IP4Protocol -> IP4 -> IP4 -> IP4Header +emptyIP4Header prot src dst = IP4Header + { ip4Version = 4 + , ip4TypeOfService = 0 + , ip4Ident = 0 + , ip4MayFragment = False + , ip4MoreFragments = False + , ip4FragmentOffset = 0 + , ip4TimeToLive = 127 + , ip4Protocol = prot + , ip4Checksum = 0 + , ip4SourceAddr = src + , ip4DestAddr = dst + , ip4Options = [] + } + + +noMoreFragments :: IP4Header -> IP4Header +noMoreFragments hdr = hdr { ip4MoreFragments = False } + +moreFragments :: IP4Header -> IP4Header +moreFragments hdr = hdr { ip4MoreFragments = True } + +addOffset :: Word16 -> IP4Header -> IP4Header +addOffset off hdr = hdr { ip4FragmentOffset = ip4FragmentOffset hdr + off } + +setIdent :: Ident -> IP4Header -> IP4Header +setIdent i hdr = hdr { ip4Ident = i } + + +-- | Calculate the size of an IP4 packet +ip4PacketSize :: IP4Packet -> Int +ip4PacketSize (IP4Packet hdr bs) = + ip4HeaderSize hdr + fromIntegral (S.length bs) + +-- | Calculate the size of an IP4 header +ip4HeaderSize :: IP4Header -> Int +ip4HeaderSize hdr = 20 + sum (map ip4OptionSize (ip4Options hdr)) + + +-- | Fragment a single IP packet into one or more, given an MTU to fit into. +splitPacket :: Int -> IP4Packet -> [IP4Packet] +splitPacket mtu pkt + | ip4PacketSize pkt > mtu = fragmentPacket mtu' pkt + | otherwise = [pkt] + where + mtu' = fromIntegral (mtu - ip4HeaderSize (ip4Header pkt)) + + +-- | Given a fragment size and a packet, fragment the packet into multiple +-- smaller ones. +fragmentPacket :: Int -> IP4Packet -> [IP4Packet] +fragmentPacket mtu pkt@(IP4Packet hdr bs) + | payloadLen <= mtu = [pkt { ip4Header = noMoreFragments hdr }] + | otherwise = frag : fragmentPacket mtu pkt' + where + payloadLen = S.length bs + (as,rest) = S.splitAt mtu bs + alen = fromIntegral (S.length as) + pkt' = pkt { ip4Header = hdr', ip4Payload = rest } + hdr' = addOffset alen hdr + frag = pkt { ip4Header = moreFragments hdr, ip4Payload = as } + + +-- 0 1 2 3 +-- 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- |Version| IHL |Type of Service| Total Length | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- | Identification |Flags| Fragment Offset | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- | Time to Live | Protocol | Header Checksum | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- | Source Address | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- | Destination Address | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +parseIP4Packet :: Get (IP4Header, Int, Int) +parseIP4Packet = do + b0 <- getWord8 + let ver = b0 `shiftR` 4 + let ihl = fromIntegral ((b0 .&. 0xf) * 4) + label "IP4 Header" $ isolate (ihl - 1) $ do + tos <- getWord8 + len <- getWord16be + ident <- get + b1 <- getWord16be + let flags = b1 `shiftR` 13 + let off = b1 .&. 0x1fff + ttl <- getWord8 + prot <- get + cs <- getWord16be + source <- get + dest <- get + let optlen = ihl - 20 + opts <- label "IP4 Options" + $ isolate optlen + $ getOptions + $ fromIntegral optlen + let hdr = IP4Header + { ip4Version = ver + , ip4TypeOfService = tos + , ip4Ident = ident + , ip4MayFragment = flags `testBit` 1 + , ip4MoreFragments = flags `testBit` 0 + , ip4FragmentOffset = off * 8 + , ip4TimeToLive = ttl + , ip4Protocol = prot + , ip4Checksum = cs + , ip4SourceAddr = source + , ip4DestAddr = dest + , ip4Options = opts + } + return (hdr, fromIntegral ihl, fromIntegral len) + + +-- | The final step to render an IP header and its payload out as a bytestring. +renderIP4Packet :: IP4Packet -> IO Packet +renderIP4Packet (IP4Packet hdr pkt) = do + let (len,bs) = runPutM $ do + let (optbs,optlen) = renderOptions (ip4Options hdr) + let ihl = 20 + optlen + putWord8 (ip4Version hdr `shiftL` 4 .|. (ihl `div` 4)) + putWord8 (ip4TypeOfService hdr) + putWord16be (fromIntegral (S.length pkt) + fromIntegral ihl) + + put (ip4Ident hdr) + let frag | ip4MayFragment hdr = (`setBit` 1) + | otherwise = id + let morefrags | ip4MoreFragments hdr = (`setBit` 0) + | otherwise = id + let flags = frag (morefrags 0) + let off = ip4FragmentOffset hdr `div` 8 + putWord16be (flags `shiftL` 13 .|. off .&. 0x1fff) + + putWord8 (ip4TimeToLive hdr) + put (ip4Protocol hdr) + putWord16be 0 -- (ip4Checksum hdr) + + put (ip4SourceAddr hdr) + + put (ip4DestAddr hdr) + + putByteString optbs + + putByteString pkt + + return ihl + let cs = computeChecksum 0 (S.take (fromIntegral len) bs) + pokeChecksum cs bs 10 + + +-- IP4 Options ----------------------------------------------------------------- + +renderOptions :: [IP4Option] -> (S.ByteString,Word8) +renderOptions opts = case optlen `mod` 4 of + 0 -> (optbs,fromIntegral optlen) + -- pad with no-ops + n -> (optbs `S.append` S.replicate n 0x1, fromIntegral (optlen + n)) + where + optbs = runPut (mapM_ put opts) + optlen = S.length optbs + + +getOptions :: Int -> Get [IP4Option] +getOptions len + | len <= 0 = return [] + | otherwise = do + o <- get + rest <- getOptions (len - ip4OptionSize o) + return $! (o : rest) + + +data IP4Option = IP4Option + { ip4OptionCopied :: !Bool + , ip4OptionClass :: !Word8 + , ip4OptionNum :: !Word8 + , ip4OptionData :: S.ByteString + } deriving Show + + +ip4OptionSize :: IP4Option -> Int +ip4OptionSize opt = case ip4OptionNum opt of + 0 -> 1 + 1 -> 1 + _ -> 2 + fromIntegral (S.length (ip4OptionData opt)) + + +instance Serialize IP4Option where + get = do + b <- getWord8 + let optCopied = testBit b 7 + let optClass = (b `shiftR` 5) .&. 0x3 + let optNum = b .&. 0x1f + bs <- case optNum of + 0 -> return S.empty + 1 -> return S.empty + _ -> do + len <- getWord8 + unless (len >= 2) (fail "Option length parameter is to small") + getByteString (fromIntegral (len - 2)) + return $! IP4Option + { ip4OptionCopied = optCopied + , ip4OptionClass = optClass + , ip4OptionNum = optNum + , ip4OptionData = bs + } + put opt = do + let copied | ip4OptionCopied opt = bit 7 + | otherwise = 0 + putWord8 (copied .|. ((ip4OptionClass opt .&. 0x3) `shiftL` 5) + .|. ip4OptionNum opt .&. 0x1f) + case ip4OptionNum opt of + 0 -> return () + 1 -> return () + _ -> do + putWord8 (fromIntegral (S.length (ip4OptionData opt))) + putByteString (ip4OptionData opt) diff --git a/src/Hans/Message/Tcp.hs b/src/Hans/Message/Tcp.hs new file mode 100644 index 0000000..4c8ea65 --- /dev/null +++ b/src/Hans/Message/Tcp.hs @@ -0,0 +1,446 @@ +{-# LANGUAGE FlexibleInstances #-} + +module Hans.Message.Tcp where + +import Hans.Address.IP4 (IP4) +import Hans.Message.Ip4 (mkIP4PseudoHeader,IP4Protocol(..)) +import Hans.Utils.Checksum (computePartialChecksum,computeChecksum,pokeChecksum) + +import Control.Monad (when,unless,ap) +import Data.Bits ((.&.),setBit,testBit,shiftL,shiftR) +import Data.List (foldl',find) +import Data.Serialize + (Get,Put,Putter,getWord16be,putWord16be,getWord32be,putWord32be,getWord8 + ,putWord8,putByteString,getBytes,remaining,label,isolate,skip,runPut) +import Data.Word (Word8,Word16,Word32) +import System.IO.Unsafe (unsafePerformIO) +import qualified Data.ByteString as S + + +-- Tcp Support Types ----------------------------------------------------------- + +tcpProtocol :: IP4Protocol +tcpProtocol = IP4Protocol 0x6 + +newtype TcpPort = TcpPort + { getPort :: Word16 + } deriving Show + +putTcpPort :: Putter TcpPort +putTcpPort (TcpPort w16) = putWord16be w16 + +getTcpPort :: Get TcpPort +getTcpPort = TcpPort `fmap` getWord16be + + +newtype TcpSeqNum = TcpSeqNum + { getSeqNum :: Word32 + } deriving (Eq,Ord,Show) + +putTcpSeqNum :: Putter TcpSeqNum +putTcpSeqNum (TcpSeqNum w32) = putWord32be w32 + +getTcpSeqNum :: Get TcpSeqNum +getTcpSeqNum = TcpSeqNum `fmap` getWord32be + + +newtype TcpAckNum = TcpAckNum + { getAckNum :: Word32 + } deriving (Eq,Ord,Show) + +putTcpAckNum :: Putter TcpAckNum +putTcpAckNum (TcpAckNum w32) = putWord32be w32 + +getTcpAckNum :: Get TcpAckNum +getTcpAckNum = TcpAckNum `fmap` getWord32be + + +-- Tcp Header ------------------------------------------------------------------ + +-- 0 1 2 3 +-- 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- | Source Port | Destination Port | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- | Sequence Number | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- | Acknowledgment Number | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- | Data | |C|E|U|A|P|R|S|F| | +-- | Offset| Res. |W|C|R|C|S|S|Y|I| Window | +-- | | |R|E|G|K|H|T|N|N| | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- | Checksum | Urgent Pointer | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- | Options | Padding | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-- | data | +-- +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +data TcpHeader = TcpHeader + { tcpSourcePort :: !TcpPort + , tcpDestPort :: !TcpPort + , tcpSeqNum :: !TcpSeqNum + , tcpAckNum :: !TcpAckNum + , tcpCwr :: !Bool + , tcpEce :: !Bool + , tcpUrg :: !Bool + , tcpAck :: !Bool + , tcpPsh :: !Bool + , tcpRst :: !Bool + , tcpSyn :: !Bool + , tcpFin :: !Bool + , tcpWindow :: !Word16 + , tcpChecksum :: !Word16 + , tcpUrgentPointer :: !Word16 + , tcpOptions :: [TcpOption] + } deriving Show + +instance HasTcpOptions TcpHeader where + findTcpOption tag hdr = findTcpOption tag (tcpOptions hdr) + setTcpOption opt hdr = hdr { tcpOptions = setTcpOption opt (tcpOptions hdr) } + +emptyTcpHeader :: TcpHeader +emptyTcpHeader = TcpHeader + { tcpSourcePort = TcpPort 0 + , tcpDestPort = TcpPort 0 + , tcpSeqNum = TcpSeqNum 0 + , tcpAckNum = TcpAckNum 0 + , tcpCwr = False + , tcpEce = False + , tcpUrg = False + , tcpAck = False + , tcpPsh = False + , tcpRst = False + , tcpSyn = False + , tcpFin = False + , tcpWindow = 0 + , tcpChecksum = 0 + , tcpUrgentPointer = 0 + , tcpOptions = [] + } + +-- | The length of the fixed part of the TcpHeader, in 4-byte octets. +tcpFixedHeaderLength :: Int +tcpFixedHeaderLength = 5 + +-- | Calculate the length of a TcpHeader, in 4-byte octets. +tcpHeaderLength :: TcpHeader -> Int +tcpHeaderLength hdr = + tcpFixedHeaderLength + tcpOptionsLength (tcpOptions hdr) + +-- | Render a TcpHeader. The checksum value is never rendered, as it is +-- expected to be calculated and poked in afterwords. +putTcpHeader :: Putter TcpHeader +putTcpHeader hdr = do + putTcpPort (tcpSourcePort hdr) + putTcpPort (tcpDestPort hdr) + putTcpSeqNum (tcpSeqNum hdr) + putTcpAckNum (tcpAckNum hdr) + putWord8 (fromIntegral (tcpHeaderLength hdr) `shiftL` 4) + putTcpControl hdr + putWord16be (tcpWindow hdr) + putWord16be 0 + putWord16be (tcpUrgentPointer hdr) + putTcpOptions (tcpOptions hdr) + +-- | Parse out a TcpHeader, and its length. The resulting length is in bytes, +-- and is derived from the data offset. +getTcpHeader :: Get (TcpHeader,Int) +getTcpHeader = do + src <- getTcpPort + dst <- getTcpPort + seqNum <- getTcpSeqNum + ackNum <- getTcpAckNum + b <- getWord8 + let len = fromIntegral ((b `shiftR` 4) .&. 0xf) + cont <- getWord8 + win <- getWord16be + cs <- getWord16be + urgent <- getWord16be + let optsLen = len - tcpFixedHeaderLength + opts <- getTcpOptions optsLen + let hdr = setTcpControl cont emptyTcpHeader + { tcpSourcePort = src + , tcpDestPort = dst + , tcpSeqNum = seqNum + , tcpAckNum = ackNum + , tcpWindow = win + , tcpChecksum = cs + , tcpUrgentPointer = urgent + , tcpOptions = opts + } + return (hdr,len * 4) + +-- | Render out the @Word8@ that contains the Control field of the TcpHeader. +putTcpControl :: Putter TcpHeader +putTcpControl c = + putWord8 $ putBit 7 tcpCwr + $ putBit 6 tcpEce + $ putBit 5 tcpUrg + $ putBit 4 tcpAck + $ putBit 3 tcpPsh + $ putBit 2 tcpRst + $ putBit 1 tcpSyn + $ putBit 0 tcpFin + 0 + where + putBit n prj w | prj c = setBit w n + | otherwise = w + +-- | Parse out the control flags from the octet that contains them. +setTcpControl :: Word8 -> TcpHeader -> TcpHeader +setTcpControl w hdr = hdr + { tcpCwr = testBit w 7 + , tcpEce = testBit w 6 + , tcpUrg = testBit w 5 + , tcpAck = testBit w 4 + , tcpPsh = testBit w 3 + , tcpRst = testBit w 2 + , tcpSyn = testBit w 1 + , tcpFin = testBit w 0 + } + + +-- Tcp Options ----------------------------------------------------------------- + +class HasTcpOptions a where + findTcpOption :: TcpOptionTag -> a -> Maybe TcpOption + setTcpOption :: TcpOption -> a -> a + +data TcpOptionTag + = OptTagEndOfOptions + | OptTagNoOption + | OptTagMaxSegmentSize + | OptTagWindowScaling + | OptTagTimestamp + | OptTagUnknown !Word8 + deriving (Eq,Show) + +getTcpOptionTag :: Get TcpOptionTag +getTcpOptionTag = do + ty <- getWord8 + return $! case ty of + 0 -> OptTagEndOfOptions + 1 -> OptTagNoOption + 2 -> OptTagMaxSegmentSize + 3 -> OptTagWindowScaling + 8 -> OptTagTimestamp + _ -> OptTagUnknown ty + +putTcpOptionTag :: Putter TcpOptionTag +putTcpOptionTag tag = + putWord8 $ case tag of + OptTagEndOfOptions -> 0 + OptTagNoOption -> 1 + OptTagMaxSegmentSize -> 2 + OptTagWindowScaling -> 3 + OptTagTimestamp -> 8 + OptTagUnknown ty -> ty + +instance HasTcpOptions [TcpOption] where + findTcpOption tag = find p + where + p opt = tag == tcpOptionTag opt + + setTcpOption opt = loop + where + tag = tcpOptionTag opt + loop [] = [opt] + loop (o:opts) + | tcpOptionTag o == tag = opt : opts + | otherwise = o : loop opts + + +data TcpOption + = OptEndOfOptions + | OptNoOption + | OptMaxSegmentSize !Word16 + | OptWindowScaling !Word8 + | OptTimestamp !Word32 !Word32 + | OptUnknown !Word8 !Word8 !S.ByteString + deriving Show + +tcpOptionTag :: TcpOption -> TcpOptionTag +tcpOptionTag opt = case opt of + OptEndOfOptions{} -> OptTagEndOfOptions + OptNoOption{} -> OptTagNoOption + OptMaxSegmentSize{} -> OptTagMaxSegmentSize + OptWindowScaling{} -> OptTagWindowScaling + OptTimestamp{} -> OptTagTimestamp + OptUnknown ty _ _ -> OptTagUnknown ty + +-- | Get the length of a TcpOptions, in 4-byte words. This rounds up to the +-- nearest 4-byte word. +tcpOptionsLength :: [TcpOption] -> Int +tcpOptionsLength opts + | left == 0 = len + | otherwise = len + 1 + where + (len,left) = foldl' step 0 opts `quotRem` 4 + step acc opt = tcpOptionLength opt + acc + +tcpOptionLength :: TcpOption -> Int +tcpOptionLength OptEndOfOptions{} = 1 +tcpOptionLength OptNoOption{} = 1 +tcpOptionLength OptMaxSegmentSize{} = 4 +tcpOptionLength OptWindowScaling{} = 3 +tcpOptionLength OptTimestamp{} = 10 +tcpOptionLength (OptUnknown _ len _) = fromIntegral len + + +-- | Render out the tcp options, and pad with zeros if they don't fall on a +-- 4-byte boundary. +putTcpOptions :: Putter [TcpOption] +putTcpOptions opts = do + let len = tcpOptionsLength opts + left = len `rem` 4 + padding + | left == 0 = 0 + | otherwise = 4 - left + mapM_ putTcpOption opts + when (padding > 0) (putByteString (S.replicate padding 0)) + +putTcpOption :: Putter TcpOption +putTcpOption opt = + case opt of + OptEndOfOptions -> putWord8 0 + OptNoOption -> putWord8 1 + OptMaxSegmentSize mss -> putMaxSegmentSize mss + OptWindowScaling w -> putWindowScaling w + OptTimestamp v r -> putTimestamp v r + OptUnknown ty len bs -> putUnknown ty len bs + +-- | Parse in known tcp options. +getTcpOptions :: Int -> Get [TcpOption] +getTcpOptions len = label ("Tcp Options (" ++ show len ++ ")") + $ isolate (len * 4) loop + where + loop = do + left <- remaining + if left <= 0 then return [] else body + + body = do + opt <- getTcpOption + case opt of + OptNoOption -> loop + + OptEndOfOptions -> do + skip =<< remaining + return [] + + _ -> do + rest <- loop + return (opt:rest) + +getTcpOption :: Get TcpOption +getTcpOption = do + tag <- getTcpOptionTag + case tag of + OptTagEndOfOptions -> return OptEndOfOptions + OptTagNoOption -> return OptNoOption + OptTagMaxSegmentSize -> getMaxSegmentSize + OptTagWindowScaling -> getWindowScaling + OptTagTimestamp -> getTimestamp + OptTagUnknown ty -> getUnknown ty + +getMaxSegmentSize :: Get TcpOption +getMaxSegmentSize = label "Max Segment Size" $ isolate 3 $ do + len <- getWord8 + unless (len == 4) (fail ("Unexpected length: " ++ show len)) + OptMaxSegmentSize `fmap` getWord16be + +putMaxSegmentSize :: Putter Word16 +putMaxSegmentSize w16 = do + putWord8 4 + putWord16be w16 + +getWindowScaling :: Get TcpOption +getWindowScaling = label "Window Scaling" $ isolate 2 $ do + len <- getWord8 + unless (len == 3) (fail ("Unexpected length: " ++ show len)) + OptWindowScaling `fmap` getWord8 + +putWindowScaling :: Putter Word8 +putWindowScaling w = do + putWord8 3 + putWord8 w + +getTimestamp :: Get TcpOption +getTimestamp = label "Timestamp" $ isolate 9 $ do + len <- getWord8 + unless (len == 10) (fail ("Unexpected length: " ++ show len)) + OptTimestamp `fmap` getWord32be `ap` getWord32be + +putTimestamp :: Word32 -> Word32 -> Put +putTimestamp v r = do + putWord8 8 + putWord8 10 + putWord32be v + putWord32be r + +getUnknown :: Word8 -> Get TcpOption +getUnknown ty = do + len <- getWord8 + body <- isolate (fromIntegral len - 2) (getBytes =<< remaining) + return (OptUnknown ty len body) + +putUnknown :: Word8 -> Word8 -> S.ByteString -> Put +putUnknown ty len body = do + putWord8 ty + putWord8 len + putByteString body + + +-- Tcp Packet ------------------------------------------------------------------ + +data TcpPacket = TcpPacket + { tcpHeader :: !TcpHeader + , tcpBody :: !S.ByteString + } deriving Show + +-- | Parse a TcpPacket. +getTcpPacket :: Get TcpPacket +getTcpPacket = do + pktLen <- remaining + (hdr,hdrLen) <- getTcpHeader + body <- getBytes (pktLen - hdrLen) + return (TcpPacket hdr body) + +-- | Render out a TcpPacket, without calculating its checksum. +putTcpPacket :: Putter TcpPacket +putTcpPacket (TcpPacket hdr body) = do + putTcpHeader hdr + putByteString body + +-- | Calculate the checksum of a TcpHeader, and its body. +renderWithTcpChecksumIP4 :: IP4 -> IP4 -> TcpPacket -> S.ByteString +renderWithTcpChecksumIP4 src dst pkt@(TcpPacket _ body) = hdrbs `S.append` body + where + (hdrbs,_) = computeTcpChecksumIP4 src dst pkt + +-- | Calculate the checksum of a tcp packet, and return its rendered header. +computeTcpChecksumIP4 :: IP4 -> IP4 -> TcpPacket -> (S.ByteString,Word16) +computeTcpChecksumIP4 src dst (TcpPacket hdr body) = + -- this is safe, as the header bytestring that gets modified is modified at + -- its creation time. + (cs `seq` unsafePerformIO (pokeChecksum cs hdrbs 16), cs) + where + hdrbs = runPut (putTcpHeader hdr { tcpChecksum = 0 }) + phcs = computePartialChecksum 0 + $ mkIP4PseudoHeader src dst tcpProtocol + $ S.length hdrbs + S.length body + hdrcs = computePartialChecksum phcs hdrbs + cs = computeChecksum hdrcs body + +-- | Re-create the checksum, minimizing duplication of the original, rendered +-- TCP packet. +recreateTcpChecksumIP4 :: IP4 -> IP4 -> S.ByteString -> Word16 +recreateTcpChecksumIP4 src dst bytes = computeChecksum hdrcs rest + where + phcs = computePartialChecksum 0 + $ mkIP4PseudoHeader src dst tcpProtocol + $ S.length bytes + (hdrbs,rest) = S.splitAt 18 bytes + hdrbs' = unsafePerformIO (pokeChecksum 0 (S.copy hdrbs) 16) + hdrcs = computePartialChecksum phcs hdrbs' + diff --git a/src/Hans/Message/Types.hs b/src/Hans/Message/Types.hs new file mode 100644 index 0000000..d63f931 --- /dev/null +++ b/src/Hans/Message/Types.hs @@ -0,0 +1,8 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +module Hans.Message.Types where + +import Data.Serialize (Serialize) +import Data.Word (Word16) + +newtype Lifetime = Lifetime Word16 + deriving (Show,Eq,Ord,Num,Serialize) diff --git a/src/Hans/Message/Udp.hs b/src/Hans/Message/Udp.hs new file mode 100644 index 0000000..855692b --- /dev/null +++ b/src/Hans/Message/Udp.hs @@ -0,0 +1,59 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} + +module Hans.Message.Udp where + +import Hans.Utils +import Hans.Utils.Checksum + +import Data.Serialize (Serialize(..)) +import Data.Serialize.Get (Get,getWord16be,getByteString,isolate,label) +import Data.Serialize.Put (runPut,putWord16be,putByteString) +import Data.Word (Word16) +import qualified Data.ByteString as S + +newtype UdpPort = UdpPort { getUdpPort :: Word16 } + deriving (Eq,Ord,Num,Show,Serialize,Enum,Bounded) + +data UdpPacket = UdpPacket + { udpHeader :: !UdpHeader + , udpPayload :: S.ByteString + } deriving Show + +data UdpHeader = UdpHeader + { udpSourcePort :: !UdpPort + , udpDestPort :: !UdpPort + , udpChecksum :: !Word16 + } deriving Show + +parseUdpPacket :: Get UdpPacket +parseUdpPacket = do + src <- get + dst <- get + b16 <- getWord16be + let len = fromIntegral b16 + label "UDPPacket" $ isolate (len - 6) $ do + cs <- getWord16be + bs <- getByteString (len - 8) + let hdr = UdpHeader + { udpSourcePort = src + , udpDestPort = dst + , udpChecksum = cs + } + return $! UdpPacket hdr bs + +-- | Given a way to make the pseudo header, render the UDP packet. +renderUdpPacket :: UdpPacket -> MkPseudoHeader -> IO Packet +renderUdpPacket (UdpPacket hdr bs) mk = do + let hdrSize = 8 + let len = S.length bs + hdrSize + let ph = mk len + let pcs = computePartialChecksum 0 ph + let bytes = runPut $ do + put (udpSourcePort hdr) + put (udpDestPort hdr) + putWord16be (fromIntegral len) + putWord16be 0 -- initial checksum + putByteString bs + -- the checksum is 6 bytes into the rendered packet + let cs = computeChecksum pcs bytes + pokeChecksum cs bytes 6 diff --git a/src/Hans/Ports.hs b/src/Hans/Ports.hs new file mode 100644 index 0000000..d8450c0 --- /dev/null +++ b/src/Hans/Ports.hs @@ -0,0 +1,56 @@ + +module Hans.Ports ( + -- * Port Management + PortManager + , emptyPortManager + , isReserved + , reserve + , unreserve + , nextPort + ) where + +import Control.Monad (MonadPlus(mzero),guard) +import Data.List (delete) +import qualified Data.Set as Set + + +-- Port Management ------------------------------------------------------------- + +data PortManager i = PortManager + { portNext :: [i] + , portActive :: Set.Set i + } + +emptyPortManager :: [i] -> PortManager i +emptyPortManager range = PortManager + { portNext = range + , portActive = Set.empty + } + +isReserved :: (Eq i, Ord i) => i -> PortManager i -> Bool +isReserved i pm = i `Set.member` portActive pm + +reserve :: (MonadPlus m, Eq i, Ord i) => i -> PortManager i -> m (PortManager i) +reserve i pm = do + guard (not (isReserved i pm)) + return $! pm + { portNext = delete i (portNext pm) + , portActive = Set.insert i (portActive pm) + } + +unreserve :: (MonadPlus m, Eq i, Ord i) + => i -> PortManager i -> m (PortManager i) +unreserve i pm = do + guard (isReserved i pm) + return $! pm + { portNext = i : portNext pm + , portActive = Set.delete i (portActive pm) + } + +nextPort :: (MonadPlus m, Eq i, Ord i) + => PortManager i -> m (i, PortManager i) +nextPort pm = case portNext pm of + [] -> mzero + i:_ -> do + pm' <- reserve i pm + return $! (i,pm') diff --git a/src/Hans/Setup.hs b/src/Hans/Setup.hs new file mode 100644 index 0000000..55dcad0 --- /dev/null +++ b/src/Hans/Setup.hs @@ -0,0 +1,96 @@ +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE FlexibleInstances #-} + +module Hans.Setup where + +import Hans.Address +import Hans.Address.IP4 +import Hans.Address.Mac +import Hans.Channel +import Hans.Layer.Arp +import Hans.Layer.Ethernet +import Hans.Layer.IP4 +import Hans.Layer.Icmp4 +import Hans.Layer.Tcp +import Hans.Layer.Timer +import Hans.Layer.Udp + + +data NetworkStack = NetworkStack + { nsArp :: ArpHandle + , nsEthernet :: EthernetHandle + , nsIp4 :: IP4Handle + , nsIcmp4 :: Icmp4Handle + , nsTimers :: TimerHandle + , nsUdp :: UdpHandle + , nsTcp :: TcpHandle + } + + +setup :: IO NetworkStack +setup = do + eth <- newChannel + arp <- newChannel + ip4 <- newChannel + icmp <- newChannel + th <- newChannel + udp <- newChannel + tcp <- newChannel + + runTimerLayer th + runEthernetLayer eth + runArpLayer arp eth th + runIP4Layer ip4 arp eth + runIcmp4Layer icmp ip4 + runUdpLayer udp ip4 icmp + runTcpLayer tcp ip4 th + + return NetworkStack + { nsArp = arp + , nsEthernet= eth + , nsIp4 = ip4 + , nsIcmp4 = icmp + , nsTimers = th + , nsUdp = udp + , nsTcp = tcp + } + + +data SomeOption = forall o. Option o => SomeOption o + +instance Show SomeOption where + showsPrec p (SomeOption o) = parens (showsPrec 11 o) + where + parens body | p > 10 = showString "(SomeOption " . body . showChar ')' + | otherwise = showString "SomeOption " . body + +toOption :: Option o => o -> SomeOption +toOption = SomeOption + +class Show o => Option o where + apply :: o -> NetworkStack -> IO () + + +instance Option SomeOption where + apply (SomeOption o) ns = apply o ns + +instance Option o => Option [o] where + apply os ns = mapM_ (`apply` ns) os + + +data OptEthernet mask = LocalEthernet mask Mac + deriving Show + +instance Option (OptEthernet IP4Mask) where + apply (LocalEthernet mask mac) ns = do + let (addr,_) = getMaskComponents mask + addLocalAddress (nsArp ns) addr mac + addIP4RoutingRule (nsIp4 ns) (Direct mask addr 1500) + + +data OptRoute mask addr = Route mask addr + deriving Show + +instance Option (OptRoute IP4Mask IP4) where + apply (Route mask addr) ns = addIP4RoutingRule (nsIp4 ns) (Indirect mask addr) diff --git a/src/Hans/Simple.hs b/src/Hans/Simple.hs new file mode 100644 index 0000000..fbb2d1f --- /dev/null +++ b/src/Hans/Simple.hs @@ -0,0 +1,59 @@ +module Hans.Simple ( + -- * UDP Messages + renderUdp + , renderIp4 + + -- * Ident + , Ident + , newIdent + , nextIdent + ) where + +import Hans.Address.IP4 (IP4) +import Hans.Message.Ip4 (IP4Packet(..),IP4Header(..),IP4Protocol(..) + ,mkIP4PseudoHeader,splitPacket,renderIP4Packet + ,emptyIP4Header) +import Hans.Message.Udp (UdpHeader(..),UdpPort,UdpPacket(..),renderUdpPacket) +import qualified Hans.Message.Ip4 as IP4 + +import Control.Concurrent (MVar,newMVar,modifyMVar) +import Data.Word (Word16) +import qualified Data.ByteString as S + +newtype Ident = Ident (MVar IP4.Ident) + +newIdent :: IO Ident +newIdent = Ident `fmap` newMVar 0 + +nextIdent :: Ident -> IO IP4.Ident +nextIdent (Ident var) = modifyMVar var (\i -> return (i+1, i)) + +type MTU = Word16 + +fromMTU :: Maybe MTU -> Int +fromMTU = maybe ip4Max (min ip4Max . fromIntegral) + where ip4Max = 0xffff + +-- | Render a UDP message to an unfragmented IP4 packet. +renderUdp :: Ident -> Maybe MTU -> IP4 -> IP4 -> UdpPort -> UdpPort + -> S.ByteString + -> IO [S.ByteString] +renderUdp i mb source dest srcPort destPort payload = do + let prot = IP4Protocol 0x11 + let mk = mkIP4PseudoHeader source dest prot + let hdr = UdpHeader + { udpSourcePort = srcPort + , udpDestPort = destPort + , udpChecksum = 0 + } + udp <- renderUdpPacket (UdpPacket hdr payload) mk + renderIp4 i mb prot source dest udp + + +-- | Render an IP4 packet. +renderIp4 :: Ident -> Maybe MTU -> IP4Protocol -> IP4 -> IP4 -> S.ByteString + -> IO [S.ByteString] +renderIp4 ident mb prot source dest payload = do + i <- nextIdent ident + let hdr = (emptyIP4Header prot source dest) { ip4Ident = i } + mapM renderIP4Packet (splitPacket (fromMTU mb) (IP4Packet hdr payload)) diff --git a/src/Hans/Utils.hs b/src/Hans/Utils.hs new file mode 100644 index 0000000..a601ca2 --- /dev/null +++ b/src/Hans/Utils.hs @@ -0,0 +1,32 @@ +module Hans.Utils where + +import Control.Monad (MonadPlus(mzero)) +import Data.ByteString (ByteString) +import Numeric (showHex) + +type DeviceName = String + +type Packet = ByteString + +type MkPseudoHeader = Int -> Packet + +type Endo a = a -> a + +-- | Discard the result of a monadic computation. +void :: Monad m => m a -> m () +void m = m >> return () + +-- | Show a single hex number, padded with a leading 0. +showPaddedHex :: (Integral a) => a -> ShowS +showPaddedHex x + | x < 0x10 = showChar '0' . base + | otherwise = base + where base = showHex x + +-- | Lift a maybe into MonadPlus +just :: MonadPlus m => Maybe a -> m a +just = maybe mzero return + +-- | Make a singleton list. +singleton :: a -> [a] +singleton x = [x] diff --git a/src/Hans/Utils/Checksum.hs b/src/Hans/Utils/Checksum.hs new file mode 100644 index 0000000..431b8b3 --- /dev/null +++ b/src/Hans/Utils/Checksum.hs @@ -0,0 +1,70 @@ +{-# LANGUAGE BangPatterns #-} +-- BANNERSTART +-- - Copyright 2006-2008, Galois, Inc. +-- - This software is distributed under a standard, three-clause BSD license. +-- - Please see the file LICENSE, distributed with this software, for specific +-- - terms and conditions. +-- Author: Adam Wick +-- BANNEREND +-- |A module providing checksum computations to other parts of Hans. The +-- checksum here is the standard Internet 16-bit checksum (the one's +-- complement of the one's complement sum of the data). +module Hans.Utils.Checksum( + computeChecksum + , computePartialChecksum + , clearChecksum + , pokeChecksum + ) + where + +import Control.Exception (assert) +import Data.Bits (Bits(shiftL,shiftR,complement,clearBit,(.&.),rotate)) +import Data.Word (Word8,Word16,Word32) +import Foreign.Storable (pokeByteOff) +import qualified Data.ByteString as S +import qualified Data.ByteString.Unsafe as S + + +-- | Clear the two bytes at the checksum offset of a rendered packet. +clearChecksum :: S.ByteString -> Int -> IO S.ByteString +clearChecksum b off = S.unsafeUseAsCStringLen b $ \(ptr,len) -> do + assert (len > off + 1) (pokeByteOff ptr off (0 :: Word16)) + return b + +-- | Poke a checksum into a bytestring. +pokeChecksum :: Word16 -> S.ByteString -> Int -> IO S.ByteString +pokeChecksum cs b off = S.unsafeUseAsCStringLen b $ \(ptr,len) -> do + assert (off < len + 1) (pokeByteOff ptr off (rotate cs 8)) + return b + +-- | Compute the final checksum, using the given initial value. +computeChecksum :: Word32 -> S.ByteString -> Word16 +computeChecksum c0 = + complement . fromIntegral . fold32 . fold32 . computePartialChecksum c0 + +-- | Compute a partial checksum, yielding a value suitable to be passed to +-- computeChecksum. +computePartialChecksum :: Word32 -> S.ByteString -> Word32 +computePartialChecksum base b = result + where + !n' = S.length b + + !result + | odd n' = step most hi 0 + | otherwise = most + where hi = S.unsafeIndex b (n'-1) + + !most = loop (fromIntegral base) 0 + + loop !acc off + | off < n = loop (step acc hi lo) (off + 2) + | otherwise = acc + where hi = S.unsafeIndex b off + lo = S.unsafeIndex b (off+1) + n = clearBit n' 0 + +step :: Word32 -> Word8 -> Word8 -> Word32 +step acc hi lo = acc + fromIntegral hi `shiftL` 8 + fromIntegral lo + +fold32 :: Word32 -> Word32 +fold32 x = (x .&. 0xFFFF) + (x `shiftR` 16) diff --git a/src/Network/TCP/Aux/Misc.hs b/src/Network/TCP/Aux/Misc.hs new file mode 100644 index 0000000..3b24070 --- /dev/null +++ b/src/Network/TCP/Aux/Misc.hs @@ -0,0 +1,372 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.Aux.Misc where + +import Network.TCP.Type.Base +import Network.TCP.Type.Timer +import Network.TCP.Type.Datagram +import Network.TCP.Type.Socket +import Network.TCP.Aux.Param +import Foreign +import Data.Map as Map +import Data.List as List +import Data.Maybe +import Data.List as List +import System.IO.Unsafe +import Control.Exception + +debug :: (Monad m) => String -> m a +debug s = seq (unsafePerformIO $ putStrLn s) return undefined + + +bound_ports :: Map SocketID (TCPSocket threadt) -> [Port] +bound_ports sockmap = List.map get_local_port (keys sockmap) + +-- not considering SO_REUSEADDR +-- bound_port_allowed :: Map SocketID (TCPSocket threadt) -> Port -> Bool +-- bound_port_allowed m p = not $ List.elem p (bound_ports m) + +-- lookup_socketid_by_seg :: Map SocketID (TCPSocket threadt) -> TCPSegment -> Maybe SocketID +-- lookup_socketid_by_seg m s = +-- let fakeid = (tcp_dst s, tcp_src s) in +-- if (member fakeid m) then +-- Just fakeid +-- else +-- Nothing +-- +create_timer (curr_time :: Time) (offset :: Time) = curr_time + offset +slow_timer = create_timer + +create_timewindow (curr_time :: Time) (offset :: Time) a = Just (Timed a (create_timer curr_time offset)) + +-- queues + +-- enqueue_message msg q = addToQueue q msg +-- enqueue_messages msgs q = foldl addToQueue q msgs +-- +accept_incoming_q0 :: SocketListen -> Bool +accept_incoming_q0 lis = + (length $ lis_q lis) < (backlog_fudge (lis_qlimit lis)) +accept_incoming_q lis = + (length $ lis_q lis) < 3 * (backlog_fudge (lis_qlimit lis `div` 2)) +drop_from_q0 lis = + (length $ lis_q0 lis) >= tcp_q0maxlimit + +do_tcp_options :: Time -> Bool -> (TimeWindow Timestamp) -> Timestamp -> Maybe (Timestamp,Timestamp) +do_tcp_options curr_time cb_tf_doing_tstmp cb_ts_recent cb_ts_val = + if cb_tf_doing_tstmp then + let ts_ecr' = case timewindow_val curr_time cb_ts_recent of + Just x -> x + Nothing -> Timestamp 0 + in Just(cb_ts_val, ts_ecr') + else + Nothing + +calculate_tcp_options_len cb_tf_doing_tstmp = + if cb_tf_doing_tstmp then 12 else 0 + +rounddown bs v = if v < bs then v else (v `div` bs) * bs +roundup bs v = ((v+(bs-1)) `div` bs) * bs + +calculate_buf_sizes (cb_t_maxseg :: Int) + (seg_mss :: Maybe Int) + (bw_delay_product_for_rt :: Maybe Int) + (is_local_conn :: Bool) + (rcvbufsize :: Int) + (sndbufsize :: Int) + (cb_tf_doing_tstmp :: Bool) + = let t_maxseg' = + let maxseg = (min cb_t_maxseg (max 64 $ (case seg_mss of Nothing -> mssdflt; Just x-> x))) in + -- BSD + maxseg - (calculate_tcp_options_len cb_tf_doing_tstmp) + in + let t_maxseg'' = rounddown mclbytes (t_maxseg') in + let rcvbufsize' = case bw_delay_product_for_rt of Nothing->rcvbufsize; Just x->x in + let (rcvbufsize'', t_maxseg''') = ( if rcvbufsize' < t_maxseg'' + then (rcvbufsize', rcvbufsize') + else (min (sb_max) (roundup (t_maxseg'') rcvbufsize'), + t_maxseg'')) in + let sndbufsize' = case bw_delay_product_for_rt of Nothing->sndbufsize; Just x->x in + let sndbufsize'' = (if sndbufsize' < t_maxseg''' + then sndbufsize' + else min (sb_max) (roundup (t_maxseg'') sndbufsize')) in + let snd_cwnd = t_maxseg''' * ((if is_local_conn then ss_fltsz_local else ss_fltsz)) in + (rcvbufsize'', sndbufsize'', t_maxseg''', snd_cwnd) + + +calculate_bsd_rcv_wnd :: TCPSocket t -> Int +calculate_bsd_rcv_wnd (tcp_sock :: TCPSocket t)= + let cb = cb_rcv tcp_sock in + assert ((rcv_adv cb) >= (rcv_nxt cb)) $ -- assertion for debugging + max (seq_diff (rcv_adv cb) (rcv_nxt cb)) + (freebsd_so_rcvbuf - (bufc_length $ rcvq cb)) + +send_queue_space sndq_max sndq_size = (sndq_max - sndq_size) + + + +update_idle (curr_time :: Time) tcp_sock = + let tt_keep' = if not (st tcp_sock == SYN_RECEIVED && tf_needfin (cb tcp_sock)) then + Just (slow_timer curr_time tcptv_keep_idle) + else + tt_keep $ cb_time tcp_sock + tt_fin_wait_2' = if st tcp_sock == FIN_WAIT_2 then + Just (slow_timer curr_time tcptv_maxidle ) + else + tt_fin_wait_2 $ cb_time tcp_sock + in + (tt_keep', tt_fin_wait_2') + +-- tcp timing and rtt + + +tcp_backoffs = tcp_bsd_backoffs +tcp_syn_backoffs = tcp_syn_backoffs + +mode_of :: Maybe (Timed (RexmtMode,Int)) -> Maybe RexmtMode +mode_of (Just (Timed (x,_) _)) = Just x +mode_of Nothing = Nothing + +shift_of :: Maybe (Timed (RexmtMode,Int)) -> Int +shift_of (Just (Timed (_,shift) _ )) = shift + +-- todo: check types! + +-- compute the retransmit timeout to use +computed_rto :: [Int] -> Int -> Rttinf -> Time +computed_rto (backoffs :: [Int]) (shift :: Int) (ri::Rttinf) = + (to_Int64 $ backoffs !! shift ) * (max (t_rttmin ri) ((t_srtt ri) + 4*(t_rttvar ri))) + +-- compute the last-used rxtcur +computed_rxtcur (ri :: Rttinf) = + max (t_rttmin ri) + (min (tcptv_rexmtmax) + ((computed_rto ( if t_wassyn ri then tcp_syn_backoffs else tcp_backoffs ) + (t_lastshift ri) ri ))) + +start_tt_rexmt_gen (mode :: RexmtMode) (backoffs :: [Int]) (shift :: Int) + (wantmin :: Bool) (ri :: Rttinf) (curr_time :: Time) = + let rxtcur = max (if wantmin + then max (t_rttmin ri) (t_lastrtt ri + (2*1000*1000 `div` 100)) -- 2s/100 + else t_rttmin ri ) + ( min (tcptv_rexmtmax ) + ( computed_rto backoffs shift ri) ) + in + Just ( Timed (mode,shift) (create_timer curr_time rxtcur ) ) + +start_tt_rexmt = start_tt_rexmt_gen Rexmt tcp_backoffs +start_tt_rexmtsyn = start_tt_rexmt_gen RexmtSyn tcp_syn_backoffs + +start_tt_persist (shift :: Int) (ri::Rttinf) (curr_time :: Time) = + let cur = max (tcptv_persmin) + (min (tcptv_persmax) + (computed_rto tcp_backoffs shift ri) ) + in + Just ( Timed (Persist, shift) (create_timer curr_time cur)) + +update_rtt :: Time -> Rttinf -> Rttinf +update_rtt rtt ri = + let (t_srtt'', t_rttvar'') + = if tf_srtt_valid ri then + let delta = (rtt - 1000*10) - (t_srtt ri) -- 1000*10 = 1/HZ + vardelta = (abs delta) - (t_rttvar ri) + t_srtt' = max (1000*1000 `div` (32*100)) (t_srtt ri + (delta `div` 8)) + t_rttvar'=max (1000*1000 `div` (16*100)) (t_rttvar ri + (vardelta `div` 4)) + in (t_srtt', t_rttvar') + else + let t_srtt' = rtt + t_rttvar' = rtt `div` 2 + in (t_srtt',t_rttvar') + in + ri { t_rttupdated = t_rttupdated ri + 1 + , tf_srtt_valid = True + , t_srtt = t_srtt'' + , t_rttvar = t_rttvar'' + , t_lastrtt = rtt + , t_lastshift = 0 + , t_wassyn = False + } + +expand_cwnd ssthresh maxseg maxwin cwnd + = min maxwin (cwnd + (if cwnd > ssthresh then (maxseg * maxseg) `div` cwnd else maxseg)) + +-- Path MTU Discovery + +mtu_tab = [65535, 32000, 17914, 8166, 4352, 2002, 1492, 1006, 508, 296, 88] + +next_smaller :: [Int] -> Int -> Int +next_smaller (x:xs) value = if value >= x then x else next_smaller xs value + + +initial_cb_time = TCBTiming + { tt_keep = Nothing + , tt_conn_est = Nothing + , tt_fin_wait_2 = Nothing + , tt_2msl = Nothing + , t_idletime = 0 + , ts_recent = Nothing + , t_badrxtwin = Nothing + } + +initial_cb_snd = TCBSending + { sndq = bufferchain_empty + , snd_una = SeqLocal 0 + , snd_wnd = 0 + , snd_wl1 = SeqForeign 0 + , snd_wl2 = SeqLocal 0 + , snd_cwnd = tcp_maxwin `shiftL` tcp_maxwinscale + , snd_nxt = SeqLocal 0 + , snd_max = SeqLocal 0 + , t_dupacks = 0 + , t_rttinf = Rttinf { t_rttupdated = 0 + , tf_srtt_valid = False + , t_srtt = tcptv_rtobase + , t_rttvar = tcptv_rttvarbase + , t_rttmin = tcptv_min + , t_lastrtt = 0 + , t_lastshift = 0 + , t_wassyn = False + } + , t_rttseg = Nothing + , tt_rexmt = Nothing + } + +{-# INLINE hasfin #-} +{-# INLINE tcp_reass #-} +{-# INLINE tcp_reass_prune #-} + +hasfin seg = case trs_FIN seg of True -> 1; False -> 0 + +-- returns (1) the string +-- (2) the SEQ for the next byte +-- (3) whether FIN has been reached +-- (4) remaining... +-- this is a very SLOW algorithm and should be replaced .... + +tcp_reass :: SeqForeign -> [TCPReassSegment] -> (BufferChain, SeqForeign, Bool, [TCPReassSegment]) +tcp_reass seq rsegq = + let searchpkt rseg = + let seq1 = (trs_seq rseg) + seq2 = seq1 `seq_plus` (bufc_length $ trs_data rseg) `seq_plus` (hasfin rseg) + in (seq >= seq1 && seq < seq2) + in + case List.find searchpkt rsegq of + Nothing -> + (bufferchain_empty, seq, False, rsegq) + Just rseg -> + let data_to_trim = seq `seq_diff` (trs_seq rseg) in + let result_buf = bufferchain_drop data_to_trim (trs_data rseg) in + let next_seq = (trs_seq rseg) `seq_plus` (bufc_length $ trs_data rseg) `seq_plus` (hasfin rseg) in + let new_rsegq = tcp_reass_prune next_seq rsegq in + if trs_FIN rseg then + (result_buf + , next_seq + , True + , new_rsegq + ) + else + let (bufc2, next_seq2, hasfin2, rsegq2) = tcp_reass next_seq new_rsegq in + ( bufferchain_concat result_buf bufc2 + , next_seq2 + , hasfin2 + , rsegq2 + ) + +tcp_reass_prune :: SeqForeign -> [TCPReassSegment] -> [TCPReassSegment] +tcp_reass_prune seq rsegq = + List.filter (\seg -> + let nxtseq = (trs_seq seg) `seq_plus` (bufc_length $ trs_data seg) `seq_plus` (hasfin seg) + in nxtseq > seq + ) rsegq + +initial_cb_rcv = TCBReceiving + { last_ack_sent = SeqForeign 0 + , tf_rxwin0sent = False + , tf_shouldacknow = False + , tt_delack = False + , rcv_adv = SeqForeign 0 + , rcv_wnd = 0 + , rcv_nxt = SeqForeign 0 + , rcvq = bufferchain_empty + , t_segq = [] + } + +initial_cb_misc = TCBMisc + { -- retransmission + snd_ssthresh = tcp_maxwin `shiftL` tcp_maxwinscale + , snd_cwnd_prev = 0 + , snd_ssthresh_prev = 0 + , snd_recover = SeqLocal 0 + -- some tags + , cantsndmore = False + , cantrcvmore = False + , bsd_cantconnect = False + -- initialization parameters + , self_id = SocketID (0,TCPAddr (IPAddr 0,0)) + , parent_id = SocketID (0,TCPAddr (IPAddr 0,0)) + , local_addr = TCPAddr (IPAddr 0,0) + , remote_addr = TCPAddr (IPAddr 0,0) + , t_maxseg = mssdflt + , t_advmss = Nothing + , tf_doing_ws = False + , tf_doing_tstmp = False + , tf_req_tstmp = False + , request_r_scale = Nothing + , snd_scale = 0 + , rcv_scale = 0 + , iss = SeqLocal 0 + , irs = SeqForeign 0 + -- other things i don't use for the moment + , sndurp = Nothing + , rcvurp = Nothing + , iobc = NO_OOBDATA + , rcv_up = SeqForeign 0 + , tf_needfin = False + } + +initial_tcp_socket = TCPSocket + { st = CLOSED + , cb_time = initial_cb_time + , cb_snd = initial_cb_snd + , cb_rcv = initial_cb_rcv + , cb = initial_cb_misc + , sock_listen = SocketListen [] [] 0 + , waiting_list = [] + } + +empty_sid :: SocketID +empty_sid = SocketID (0,TCPAddr (IPAddr 0,0)) + + diff --git a/src/Network/TCP/Aux/Output.hs b/src/Network/TCP/Aux/Output.hs new file mode 100644 index 0000000..003bff2 --- /dev/null +++ b/src/Network/TCP/Aux/Output.hs @@ -0,0 +1,213 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.Aux.Output where + +import Hans.Message.Tcp + +import Network.TCP.Type.Base +import Network.TCP.Type.Timer +import Network.TCP.Type.Datagram as Datagram +import Network.TCP.Type.Socket +import Network.TCP.Type.Syscall +import Network.TCP.Aux.Param +import Network.TCP.Aux.Misc +import Hans.Layer.Tcp.Monad +import Foreign +import Control.Exception +import Control.Monad + +make_syn_segment :: Time -> TCPSocket t -> Timestamp -> TCPSegment +make_syn_segment curr_time sock (ts_val::Timestamp) = + let ws = request_r_scale $ cb sock -- should assert it's <= tcp_maxwinscale ? + mss = t_advmss $ cb $ sock + ts = do_tcp_options curr_time (tf_req_tstmp $ cb sock) (ts_recent $ cb_time $ sock ) ts_val + hdr = + set_tcp_mss mss $ + set_tcp_ws ws $ + set_tcp_ts ts $ emptyTcpHeader + { tcpSeqNum = TcpSeqNum (seq_val (iss (cb sock))) + , tcpAckNum = TcpAckNum 0 + , tcpSyn = True + , tcpWindow = fromIntegral (rcv_wnd (cb_rcv sock)) + } + in mkTCPSegment' + (local_addr (cb sock)) (remote_addr (cb sock)) hdr bufferchain_empty + +make_syn_ack_segment curr_time sock (addrfrom::TCPAddr) (addrto::TCPAddr) (ts_val::Timestamp) = + let urp_any = 0 + tcb = cb sock + win = rcv_wnd (cb_rcv sock) -- `shiftR` (rcv_scale $ cb sock) + ws = if tf_doing_ws tcb then Just (rcv_scale tcb) else Nothing + mss = t_advmss tcb + ts = do_tcp_options curr_time (tf_req_tstmp tcb) (ts_recent $ cb_time sock) ts_val + hdr = + set_tcp_mss mss $ + set_tcp_ws ws $ + set_tcp_ts ts $ emptyTcpHeader + { tcpSeqNum = TcpSeqNum (seq_val (iss tcb)) + , tcpAckNum = TcpAckNum (fseq_val (rcv_nxt (cb_rcv sock))) + , tcpAck = True + , tcpSyn = True + , tcpWindow = fromIntegral win + , tcpUrgentPointer = urp_any + } + in mkTCPSegment' addrfrom addrto hdr bufferchain_empty + +make_ack_segment curr_time sock (fin::Bool) (ts_val::Timestamp) = + let urp_garbage = 0 + tcb = cb sock + win = (rcv_wnd $ cb_rcv sock) `shiftR` (rcv_scale tcb) + ts = do_tcp_options curr_time (tf_req_tstmp tcb) (ts_recent $ cb_time sock) ts_val + sn | fin = snd_una (cb_snd sock) + | otherwise = snd_nxt (cb_snd sock) + hdr = + set_tcp_ts ts $ emptyTcpHeader + { tcpSeqNum = TcpSeqNum (seq_val sn) + , tcpAckNum = TcpAckNum (fseq_val (rcv_nxt (cb_rcv sock))) + , tcpAck = True + , tcpFin = fin + , tcpWindow = fromIntegral win + , tcpUrgentPointer = urp_garbage + } + in mkTCPSegment' (local_addr tcb) (remote_addr tcb) hdr bufferchain_empty + +bsd_make_phantom_segment curr_time sock (addrfrom::TCPAddr) (addrto::TCPAddr) (ts_val::Timestamp) (cantsendmore::Bool) = + let urp_garbage = 0 + tcb = cb sock + scb = cb_snd sock + rcb = cb_rcv sock + win = (rcv_wnd rcb) `shiftR` (rcv_scale tcb) + fin = (cantsendmore && seq_lt (snd_una scb) (seq_minus (snd_max scb) 1)) + ts = do_tcp_options curr_time (tf_req_tstmp tcb) (ts_recent $ cb_time sock) ts_val + sn | fin = snd_una scb + | otherwise = snd_max scb + hdr = + set_tcp_ts ts emptyTcpHeader + { tcpSourcePort = TcpPort 0 + , tcpDestPort = TcpPort 0 + , tcpSeqNum = TcpSeqNum (seq_val sn) + , tcpAckNum = TcpAckNum (fseq_val (rcv_nxt rcb)) + , tcpFin = fin + , tcpWindow = fromIntegral win + , tcpUrgentPointer = urp_garbage + } + in + mkTCPSegment' addrfrom addrto hdr bufferchain_empty + +make_rst_segment_from_cb sock (addrfrom::TCPAddr) (addrto::TCPAddr) = + let hdr = emptyTcpHeader + { tcpSourcePort = TcpPort 0 + , tcpDestPort = TcpPort 0 + , tcpSeqNum = TcpSeqNum (seq_val (snd_nxt (cb_snd sock))) + , tcpAckNum = TcpAckNum (fseq_val (rcv_nxt (cb_rcv sock))) + , tcpAck = False + , tcpRst = False + } + in mkTCPSegment' addrfrom addrto hdr bufferchain_empty + +make_rst_segment_from_seg (seg::TCPSegment) = + let tcp_ACK' = not (tcp_ACK seg) + seq' = if tcp_ACK seg then tcp_ack seg else 0 + ack' = if tcp_ACK' + then let s1 = tcp_seq seg + in s1 `seq_plus` bufc_length (tcp_data seg) + `seq_plus` (if tcp_SYN seg then 1 else 0) + else 0 + + + hdr = emptyTcpHeader + { tcpSeqNum = TcpSeqNum seq' + , tcpAckNum = TcpAckNum ack' + , tcpAck = tcp_ACK' + , tcpRst = True + } + in + mkTCPSegment' (tcp_src seg) (tcp_dst seg) hdr bufferchain_empty + + +dropwithreset (seg::TCPSegment) = + if tcp_RST seg then [] + else let seg' = make_rst_segment_from_seg seg + in [TCPMessage seg'] + +dropwithreset_ignore_or_fail = dropwithreset + +tcp_close_temp sock = + sock { cb = (cb sock) { cantrcvmore = True + , cantsndmore = True + , local_addr = TCPAddr (IPAddr 0,0) + , remote_addr = TCPAddr (IPAddr 0,0) + , bsd_cantconnect = True + } + , st = CLOSED + , cb_snd = (cb_snd sock) { sndq = bufferchain_empty } + } + + +tcp_close :: SocketID -> HMonad t () +tcp_close sid = + do b <- has_sock sid + when b $ do + sock <- lookup_sock sid + let pending_tasks = waiting_list sock + has_parent = (get_local_port $ parent_id $ cb sock) /= 0 + let result = map (\(_,cont) -> cont (SockError "tcpclose")) pending_tasks + emit_ready result + delete_sock sid + when (not has_parent) $ free_local_port $ get_local_port sid + +tcp_drop_and_close :: SocketID -> HMonad t () +tcp_drop_and_close sid = + do b <- has_sock sid + when b $ do + sock <- lookup_sock sid + let outsegs = if st sock `notElem` [CLOSED,LISTEN,SYN_SENT] + then [TCPMessage $ make_rst_segment_from_cb + (sock) (local_addr $ cb sock) (remote_addr $ cb sock)] + else [] + emit_segs outsegs + tcp_close sid + +alloc_local_port :: HMonad t (Maybe Port) +alloc_local_port = do + h <- get_host + case local_ports h of + [] -> return Nothing + port:rest -> do put_host $ h { local_ports = rest } + return $ Just port + +free_local_port port = + modify_host $ \h -> h { local_ports = port:(local_ports h) } + diff --git a/src/Network/TCP/Aux/Param.hs b/src/Network/TCP/Aux/Param.hs new file mode 100644 index 0000000..e063c00 --- /dev/null +++ b/src/Network/TCP/Aux/Param.hs @@ -0,0 +1,120 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.Aux.Param where + +import Network.TCP.Type.Base + +dschedmax = seconds_to_time 1 +dinput_queuemax = seconds_to_time 1 +doutput_queuemax = seconds_to_time 1 + +hz = 100 + +tickintvlmin = seconds_to_time $ 100/(105*hz) +tickintvlmax = seconds_to_time $ 105/(100*hz) + +slow_timer_intvl = seconds_to_time $ 1/2 +-- slow_timer_model_intvl = seconds_to_time $ 1/1000 + +fast_timer_intvl = seconds_to_time $ 1/5 +-- fast_timer_model_intvl = seconds_to_time $ 1/1000 + +kern_timer_intvl = tickintvlmax +-- kern_timer_model_intvl = + + +-- listen queue length +somaxconn ::Int= 128 + +-- buffers +mclbytes ::Int= 2048 +msize ::Int= 256 +sb_max ::Int= 256*1024 + +-- rfc limits + +dtsinval::Time = seconds_to_time $ 24*24*60*60 +tcp_maxwin :: Int = 65535 +tcp_maxwinscale :: Int = 14 + + +--default +freebsd_so_rcvbuf :: Int = 42080 +freebsd_so_sndbuf :: Int = 9216 + +-- tcp parameters + +mssdflt :: Int = 1400 -- 512 +ss_fltsz_local :: Int = 4 +ss_fltsz :: Int = 1 +tcp_do_newreno = True +tcp_q0minlimit :: Int = 30 +tcp_q0maxlimit :: Int = 512*30 + +backlog_fudge :: Int -> Int +backlog_fudge n = min somaxconn n + +-- time values (TCP only) + +tcptv_delack = seconds_to_time 0.1 +tcptv_rtobase = seconds_to_time 3 +tcptv_rttvarbase = seconds_to_time 0 +tcptv_min = seconds_to_time 1 +tcptv_rexmtmax = seconds_to_time 64 +tcptv_msl = seconds_to_time 1 -- this is too stringent... good for testing +tcptv_persmin = seconds_to_time 5 +tcptv_persmax = seconds_to_time 60 +tcptv_keep_init = seconds_to_time 75 +tcptv_keep_idle = seconds_to_time $ 120*60 +tcptv_keepintvl = seconds_to_time $ 75 +tcptv_keepcnt = seconds_to_time $ 8 +tcptv_maxidle = tcptv_keepintvl*8 + + +-- timing related parameters (TCP only) + +tcp_bsd_backoffs :: [Int]= [1,2,4,8,16,32,64, 64,64,64, 64,64,64] +tcp_linux_backoffs = [1,2,4,8,16,32,64, 128,256,512, 512] +tcp_winxp_backoffs = [1,2,4,8,16] + +--tcp_maxrxtshift = 12 +tcp_maxrxtshift :: Int = 3 -- this is not right... for testing only +tcp_synackmaxrxtshift :: Int = 3 + +tcp_syn_bsd_backoffs :: [Int] = [1,1,1,1,1,2,4,8,16,32,64,64,64] +tcp_syn_linux_backoffs =[1,2,4,8,16] +tcp_syn_winxp_backoffs=[1,2] + +listen_qlimit :: Int = 100 diff --git a/src/Network/TCP/Aux/SockMonad.hs b/src/Network/TCP/Aux/SockMonad.hs new file mode 100644 index 0000000..c53b095 --- /dev/null +++ b/src/Network/TCP/Aux/SockMonad.hs @@ -0,0 +1,127 @@ +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.Aux.SockMonad where + +import Network.TCP.Type.Base +import Network.TCP.Type.Socket +import Network.TCP.Aux.Misc +import Hans.Layer.Tcp.Monad +import Control.Exception + +data HState t = HState + { hs_host :: !(Host t) + , hs_sock :: !(TCPSocket t) + } + +newtype SMonad t a = SMonad (HState t -> (a, HState t)) + +instance Monad (SMonad t) where + return a = SMonad $ \s -> (a,s) + x >>= f = bindSMonad x f + {-# INLINE return #-} + {-# INLINE (>>=) #-} + +bindSMonad :: SMonad t a -> (a -> SMonad t b) -> SMonad t b +bindSMonad (SMonad x) f = + SMonad $ \s -> let (res, s') = x s + (SMonad z) = f res in z s' +{-# INLINE bindSMonad #-} + +get_host_ :: SMonad t (Host t) +get_host_ = SMonad $ \s -> (hs_host s,s) + +modify_host_ f = SMonad $ \s -> ((), s { hs_host = f (hs_host s)}) +emit_segs_ segs = modify_host_ $ \h -> h { output_queue = (output_queue h)++ segs} +emit_ready_ threads = modify_host_ $ \h -> h { ready_list = (ready_list h)++ threads} + +{-# INLINE get_host_ #-} +{-# INLINE modify_host_ #-} +{-# INLINE emit_segs_ #-} +{-# INLINE emit_ready_ #-} + +-------------------------------------------------- +-- get_sid = do +-- SMonad $ \s -> (hs_sid s,s) +-- {-# INLINE get_sid #-} + +get_sock = do + SMonad $ \s -> (hs_sock s,s) + +put_sock sock = do + SMonad $ \s -> ((), s { hs_sock = sock}) + +modify_sock f = do + SMonad $ \s -> ((), s { hs_sock = f (hs_sock s)}) +modify_cb f = + SMonad $ \s-> let sock=hs_sock s in ((), s { hs_sock=sock { cb =f (cb sock) }}) +modify_cb_snd f = + SMonad $ \s-> let sock=hs_sock s in ((), s { hs_sock=sock { cb_snd=f (cb_snd sock) }}) +modify_cb_rcv f = + SMonad $ \s-> let sock=hs_sock s in ((), s { hs_sock=sock { cb_rcv=f (cb_rcv sock) }}) +modify_cb_time f= + SMonad $ \s-> let sock=hs_sock s in ((), s { hs_sock=sock {cb_time=f (cb_time sock)}}) +{-# INLINE get_sock #-} +{-# INLINE put_sock #-} +{-# INLINE modify_sock #-} +{-# INLINE modify_cb #-} +{-# INLINE modify_cb_snd #-} +{-# INLINE modify_cb_rcv #-} +{-# INLINE modify_cb_time #-} + + +----------------------------------------------------- +-- has_sock_ :: SocketID -> SMonad t Bool +-- has_sock_ sid = do +-- h <- get_host_ +-- return $ Map.member sid (sock_map h) +-- +-- lookup_sock_ sid = do +-- h <- get_host_ +-- res <- Map.lookup sid (sock_map h) +-- return res +-- +-- {-# INLINE has_sock_ #-} +-- {-# INLINE lookup_sock_ #-} +-- +------------------ + +runSMonad :: SocketID -> (SMonad t a) -> HMonad t a +runSMonad sid (SMonad m) = do + h <- get_host + sock <- lookup_sock sid + let initstate = HState h sock + let (res, finalstate) = m initstate + put_host $ hs_host finalstate + update_sock sid $ \_ -> hs_sock finalstate + return res diff --git a/src/Network/TCP/LTS/In.hs b/src/Network/TCP/LTS/In.hs new file mode 100644 index 0000000..8e37e6b --- /dev/null +++ b/src/Network/TCP/LTS/In.hs @@ -0,0 +1,200 @@ +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.LTS.In + ( tcp_deliver_in_packet + ) +where + +import Hans.Layer.Tcp.Monad +import Hans.Message.Tcp + +import Foreign +import Foreign.C +import Data.List as List +import Control.Exception +import Control.Monad + +import Network.TCP.Type.Base +import Network.TCP.Type.Syscall +import Network.TCP.Type.Timer +import Network.TCP.Type.Socket +import Network.TCP.Type.Datagram +import Network.TCP.Aux.Param + +import Network.TCP.Aux.Misc +import Network.TCP.Aux.Output +import Network.TCP.Aux.SockMonad +import Network.TCP.Aux.Output + +import Network.TCP.LTS.InPassive +import Network.TCP.LTS.InActive +import Network.TCP.LTS.InData +import Network.TCP.LTS.User +import Network.TCP.LTS.Out + +tcp_deliver_in_packet seg = do + let sid = SocketID (get_port (tcp_dst seg), tcp_src seg) + ok <- has_sock sid + if ok + then tcp_deliver_packet_to_sock sid seg + else if tcp_SYN seg && (not $ tcp_ACK seg) && (not $ tcp_RST seg) + then tcp_deliver_syn_packet seg + else emit_segs $ dropwithreset seg + +-- Note: if there exists a socket in TIME_WAIT state, and an SYN +-- packet matches it, the SYN packet will always be delivered to this +-- socket; it will never be delivered to a listening socket. This +-- makes the implementation simpler... + +--pre-condition: sid exists +tcp_deliver_packet_to_sock :: SocketID -> TCPSegment -> HMonad t () +tcp_deliver_packet_to_sock sid seg = + do h <- get_host + sock <- lookup_sock sid + let tcb = cb sock + rcb = cb_rcv sock + scb = cb_snd sock + seqnum = SeqForeign (tcp_seq seg) + acknum = SeqLocal (tcp_ack seg) + + success <- header_prediction seg h sid sock tcb rcb scb seqnum acknum + when (not success) $ + case st sock of + CLOSED -> assert (False) return () + LISTEN -> assert (False) return () + SYN_SENT -> let goodack = (iss tcb) < acknum && acknum <= (snd_max scb) in + if tcp_RST seg then + when (tcp_ACK seg && goodack) $ tcp_close sid + else + if tcp_SYN seg && tcp_ACK seg then + if goodack then runSMonad sid $ deliver_in_2 seg + else emit_segs $ dropwithreset seg + else return () + SYN_RECEIVED -> + let invalidack = acknum <= snd_una scb || acknum > snd_max scb in + if tcp_RST seg then + tcp_close sid + else if tcp_SYN seg || not (tcp_ACK seg) then -- check with spec? + return () + else if invalidack || (seqnum < (irs tcb)) then + return () + else do + sock <- runSMonad sid $ deliver_in_3 seg + if st sock == CLOSED then + tcp_close sid + else when (st sock /= SYN_RECEIVED) $ + di3_socks_update sid + _ -> if tcp_RST seg then + when (st sock /= TIME_WAIT) $ tcp_close sid + else if tcp_SYN seg then + when (st sock==TIME_WAIT) $ emit_segs $ dropwithreset seg + else + if st sock `elem` [FIN_WAIT_1, CLOSING, LAST_ACK, FIN_WAIT_2, TIME_WAIT] + && seqnum `seq_plus` (bufc_length $ tcp_data seg) > (rcv_nxt rcb) + then return () -- data coming into closing socket? + else do sock <- runSMonad sid $ deliver_in_3 seg + --debug $ (show $ st sock) + when (st sock == CLOSED) $ tcp_close sid + +{-# INLINE header_prediction #-} +header_prediction seg h sid sock tcb rcb scb seqnum acknum = + if st sock == ESTABLISHED + && not (tcp_SYN seg) + && not (tcp_FIN seg) + && not (tcp_URG seg) + && not (tcp_RST seg) + && tcp_ACK seg + && seqnum == rcv_nxt rcb + && snd_wnd scb == fromIntegral (tcp_win seg) `shiftL` snd_scale tcb + && snd_max scb == snd_nxt scb + then if bufc_length (tcp_data seg) == 0 + && acknum > (snd_una scb) + && acknum <= (snd_max scb) + && snd_cwnd scb >= snd_wnd scb + && t_dupacks scb < 3 + then do -- pure ack for outstanding data + -------------------------------------------------------------------------------- + --debug $ "prediction 2.1!" + let emission_time = case (tcp_ts seg, t_rttseg scb) of + (Just (ts_val, ts_ecr), _ ) -> Just (ts_ecr `seq_minus` 1) + (Nothing, Just (ts0, seq0)) -> if acknum > seq0 then Just ts0 else Nothing + (Nothing, Nothing) -> Nothing + let t_rttinf' = case emission_time of + Just emtime -> assert ((ticks h) >= emtime) $ + update_rtt ( ((ticks h) `seq_diff` emtime)*10000 ) (t_rttinf scb) + Nothing -> t_rttinf scb + let tt_rexmt' = if acknum == snd_max scb then + Nothing + else case mode_of (tt_rexmt scb) of + Nothing -> start_tt_rexmt 0 True t_rttinf' (clock h) + Just Rexmt -> start_tt_rexmt 0 True t_rttinf' (clock h) + _ -> tt_rexmt scb + let acked = acknum `seq_diff` (snd_una scb) + let snd_wnd' = snd_wnd scb - acked + let sndq' = bufferchain_drop acked (sndq scb) + runSMonad sid $ do + modify_sock $ \s -> s { cb_snd = scb + { sndq = sndq' + , t_dupacks = 0 + , t_rttinf = t_rttinf' + , tt_rexmt = tt_rexmt' + , t_rttseg = if emission_time == Nothing then t_rttseg scb else Nothing + , snd_cwnd = expand_cwnd (snd_ssthresh tcb) + (t_maxseg tcb) + (tcp_maxwin `shiftL` (snd_scale tcb)) + (snd_cwnd scb) + , snd_wnd = snd_wnd' + , snd_una = acknum + --, snd_nxt = max acknum (snd_nxt scb) + } + } + tcp_wakeup + tcp_output_all + return True + -------------------------------------------------------------------------------- + else if acknum == snd_una scb + && List.null (t_segq rcb) + && bufc_length (tcp_data seg) < (freebsd_so_rcvbuf - (bufc_length $ rcvq rcb)) + then do -- pure in-sequence data packet + -------------------------------------------------------------------------------- + return False + -------------------------------------------------------------------------------- + else do + -- debug $ "predictions 2.1, 2.2 fail!" + return False + else do + -- debug $ "prediction 1 fail!" ++ (show $ snd_wnd tcb) + -- ++ " " ++ (show (tcp_win seg)) ++ " " ++ (show $ snd_scale tcb) + return False + diff --git a/src/Network/TCP/LTS/InActive.hs b/src/Network/TCP/LTS/InActive.hs new file mode 100644 index 0000000..efc9943 --- /dev/null +++ b/src/Network/TCP/LTS/InActive.hs @@ -0,0 +1,170 @@ +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.LTS.InActive where + +import Foreign +import Foreign.C +import Data.Maybe +import Network.TCP.Type.Base +import Network.TCP.Type.Syscall +import Network.TCP.Aux.Misc +import Network.TCP.Aux.Param +import Network.TCP.Aux.Output +import Network.TCP.Type.Socket +import Network.TCP.Type.Datagram +import Network.TCP.Aux.SockMonad +import Network.TCP.LTS.User + +deliver_in_2 seg = do + sock <- get_sock + h <- get_host_ + --debug $ "deliver_in_3 " ++ (show seg) + let tcb = cb sock + scb = cb_snd sock + rcb = cb_rcv sock + acknum = SeqLocal (tcp_ack seg) + seqnum = SeqForeign (tcp_seq seg) + let { + -- window scaling + (rcv_scale', snd_scale', tf_doing_ws') = + ( case (request_r_scale tcb, tcp_ws seg) of + (Just rs, Just ss) -> (rs, ss, True) + _ -> (0,0,False) + ); + + -- timestamping + + tf_rcvd_tstmp' = isJust $ tcp_ts seg; + tf_doing_tstmp' = tf_rcvd_tstmp' && (tf_req_tstmp tcb); + -- mss negotiation + ourmss = ( case (t_advmss tcb) of + Nothing -> (t_maxseg tcb) + Just v -> v + ); + + (rcvbufsize', sndbufsize', t_maxseg'', snd_cwnd') = + calculate_buf_sizes ourmss (tcp_mss seg) Nothing False + (freebsd_so_rcvbuf) (freebsd_so_sndbuf) tf_doing_tstmp'; + + rcv_window = min tcp_maxwin freebsd_so_rcvbuf; + + emission_time = + ( case tcp_ts seg of + Just (ts_val, ts_ecr) -> Just (ts_ecr `seq_minus` 1) + Nothing -> case t_rttseg scb of + Just (ts0, seq0) -> if acknum > seq0 then Just ts0 else Nothing + Nothing -> Nothing; + ); + + t_rttseg' = ( case emission_time of + Nothing -> Nothing + Just _ -> t_rttseg scb ); + + t_rttinf' = ( case emission_time of + Just emtime -> update_rtt ( ((ticks h) `seq_diff` emtime)*10*1000 ) (t_rttinf scb) + Nothing -> t_rttinf scb ); + + tt_rexmt' = if acknum == snd_max scb then Nothing else tt_rexmt scb; + + fin' = tcp_FIN seg; + rcvq' = tcp_data seg; + rcv_nxt' = seqnum `seq_plus` 1 `seq_plus` (if fin' then 1 else 0); + rcv_wnd' = rcv_window - (bufc_length $ tcp_data seg); + + cantrcvmore' = if fin' then True else cantrcvmore tcb; + + new_st = if fin' then + if cantsndmore tcb then LAST_ACK else CLOSE_WAIT + else + if cantsndmore tcb then + if snd_max scb > iss tcb `seq_plus` 1 && acknum >= snd_max scb then + FIN_WAIT_2 + else + FIN_WAIT_1 + else + ESTABLISHED; + + newsock = sock + { st = new_st + , cb_time = (cb_time sock) + { t_idletime = clock h + , tt_keep = Just (create_timer (clock h) tcptv_keep_idle) + , tt_conn_est = Nothing + , ts_recent = case tcp_ts seg of + Nothing -> ts_recent $ cb_time sock + Just (ts_val, ts_ecr) -> create_timewindow (clock h) dtsinval (ts_val) + } + , cb_snd = scb + { tt_rexmt = tt_rexmt' + , snd_una = acknum + , snd_nxt = if cantsndmore tcb then acknum else snd_nxt scb + , snd_max = if cantsndmore tcb && acknum > snd_max scb then acknum else snd_max scb + , snd_wl1 = seqnum `seq_plus` 1 + , snd_wl2 = acknum + , snd_wnd = fromIntegral (tcp_win seg) `shiftL` snd_scale' + , snd_cwnd = if acknum > (iss tcb `seq_plus` 1) + then min snd_cwnd' (tcp_maxwin `shiftL` snd_scale') + else snd_cwnd' + , t_rttseg = t_rttseg' + , t_rttinf = t_rttinf' + } + , cb_rcv = rcb + { rcvq = rcvq' + , tt_delack = False + , rcv_nxt = rcv_nxt' + , rcv_wnd = rcv_wnd' + , tf_rxwin0sent = (rcv_wnd' == 0) + , rcv_adv = rcv_nxt' `seq_plus` (( rcv_wnd' `shiftR` rcv_scale') `shiftL` rcv_scale') + , last_ack_sent = rcv_nxt' + } + , cb = tcb + { --local_addr = tcp_dst seg + rcv_scale = rcv_scale' + , snd_scale = snd_scale' + , tf_doing_ws = tf_doing_ws' + , irs = seqnum + , t_maxseg = t_maxseg'' + , tf_req_tstmp = tf_doing_tstmp' + , tf_doing_tstmp = tf_doing_tstmp' + , cantrcvmore = cantrcvmore' + } + }; + } + put_sock newsock + emit_segs_ [ TCPMessage $ make_ack_segment (clock h) newsock + (cantsndmore tcb && acknum < (iss tcb `seq_plus` 2)) (ticks h)] + tcp_wakeup + return () + + diff --git a/src/Network/TCP/LTS/InData.hs b/src/Network/TCP/LTS/InData.hs new file mode 100644 index 0000000..c88d2ec --- /dev/null +++ b/src/Network/TCP/LTS/InData.hs @@ -0,0 +1,385 @@ +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.LTS.InData where + +import Foreign +import Foreign.C +import Control.Exception +import Control.Monad +import Data.List as List + +import Network.TCP.Type.Base +import Network.TCP.Type.Timer +import Network.TCP.Type.Socket +import Network.TCP.Type.Datagram as Datagram +import Network.TCP.Type.Syscall + +import Network.TCP.Aux.Param +import Network.TCP.Aux.Misc +import Network.TCP.Aux.Output +import Hans.Layer.Tcp.Monad +import Network.TCP.Aux.SockMonad + +import Network.TCP.LTS.User +import Network.TCP.LTS.Out + +deliver_in_3 seg = + do sock <- get_sock + h <- get_host_ + --debug $ "deliver_in_3 " ++ (show seg) + let tcb = cb sock + scb = cb_snd sock + acknum = SeqLocal (tcp_ack seg) + seqnum = SeqForeign (tcp_seq seg) `seq_plus` + if tcp_SYN seg then 1 else 0 + seg_win = fromIntegral (tcp_win seg) `shiftL` (snd_scale tcb) + let wesentafin = (snd_max scb) > (snd_una scb `seq_plus` (bufc_length $ sndq scb)) + ourfinisacked = wesentafin && tcp_ACK seg && acknum >= (snd_max scb) + + -- update idle time + -- seqnum bound checking + drop_it <- di3_topstuff seg seqnum acknum h + when (not drop_it) $ do + -- acknum bound checking + -- fast retransmit + -- correct bad retransmit + -- update send queue + ack_ok <- di3_ackstuff seg seqnum acknum seg_win h ourfinisacked + when ack_ok $ do + -- update send window + -- receive data + fin_reass <- di3_datastuff seg seqnum acknum seg_win h ourfinisacked + -- update socket state + di3_ststuff fin_reass h ourfinisacked acknum + tcp_wakeup + tcp_output_all + get_sock + +{-# INLINE di3_topstuff #-} +di3_topstuff seg seqnum acknum h = + do sock <- get_sock + let tcb = cb sock + scb = cb_snd sock + rcb = cb_rcv sock + let rseq = seqnum `seq_plus` (bufc_length $ tcp_data seg) + let seg_ts = tcp_ts seg + -- PAWS check: -- todo + let paws_failed = False + let rcv_wnd' = calculate_bsd_rcv_wnd sock + let segment_off_right_hand_edge = + (seqnum >= (rcv_nxt rcb `seq_plus` rcv_wnd')) + && (rseq > (rcv_nxt rcb `seq_plus` rcv_wnd')) + && (rcv_wnd' /= 0) + let drop_it = paws_failed || segment_off_right_hand_edge + let Just seg_ts_val = seg_ts + let (tt_keep', tt_fin_wait_2') = update_idle (clock h) sock + let ts_recent'' = if not drop_it && seg_ts /= Nothing && seqnum <= (last_ack_sent rcb) + then create_timewindow (clock h) dtsinval (fst $ seg_ts_val) + else ts_recent $ cb_time sock + modify_cb_time $ \t -> t { tt_keep = tt_keep' + , tt_fin_wait_2 = tt_fin_wait_2' + , t_idletime = clock h + , ts_recent = ts_recent'' + } + return drop_it + +{-# INLINE di3_ackstuff #-} +di3_ackstuff seg seqnum acknum seg_win h ourfinisacked = + do sock <- get_sock + let scb = cb_snd sock + if acknum > snd_max scb then return False + else if acknum > snd_una scb + then di3_newackstuff sock seg acknum h ourfinisacked + else di3_oldackstuff sock seg seqnum acknum seg_win h + +{-# INLINE di3_oldackstuff #-} +di3_oldackstuff sock seg seqnum acknum seg_win h = + let tcb = cb sock + scb = cb_snd sock + rcb = cb_rcv sock in + let has_data = bufc_length (tcp_data seg) > 0 + && (rcv_nxt rcb) < (seqnum `seq_plus` (bufc_length $ tcp_data seg)) + && seqnum < ( (rcv_nxt rcb) `seq_plus` (rcv_wnd rcb)) in + let maybe_dup_ack = not has_data + && seg_win == (snd_wnd scb) + && mode_of (tt_rexmt scb) == Just Rexmt in + if not maybe_dup_ack then do + modify_cb_snd $ \c -> c { t_dupacks = 0 } + return True + else + let t_dupacks' = t_dupacks scb + 1 in + if acknum < (snd_una scb) then + do modify_cb_snd $ \c -> c { t_dupacks = 0} + return False + else if t_dupacks' < 3 then + do modify_cb_snd $ \c -> c { t_dupacks = t_dupacks'} + return True -- in case FIN is set + else if t_dupacks' > 3 || (t_dupacks' == 3 && tcp_do_newreno && acknum < (snd_recover tcb)) then + do modify_cb_snd $ \c -> c { t_dupacks = if t_dupacks' == 3 then 0 else t_dupacks' + , snd_cwnd = (snd_cwnd scb) + (t_maxseg tcb) + } + tcp_output False + return False + else -- t_dupacks' == 3 && not (tcp_do_newreno && acknum < (snd_recover tcb)) + do modify_cb_snd $ \c -> c { t_dupacks = t_dupacks' + , tt_rexmt = Nothing + , t_rttseg = Nothing + , snd_nxt = acknum + , snd_cwnd = t_maxseg tcb + } + modify_cb $ \c -> c { snd_ssthresh = (max 2 ( (min (snd_wnd scb) (snd_cwnd scb)) + `div` 2 `div` (t_maxseg tcb))) + * (t_maxseg tcb) + , snd_recover = if tcp_do_newreno then snd_max scb else snd_recover c + } + tcp_output False + modify_cb_snd $ \c -> c { snd_cwnd = (snd_ssthresh tcb) + (t_maxseg tcb) * t_dupacks' + , snd_nxt = max (snd_nxt scb) (snd_nxt c) + } + return False + +{-# INLINE di3_newackstuff #-} +di3_newackstuff sock seg acknum h ourfinisacked = + do let seg_ts = tcp_ts seg + let tcb = cb sock + scb = cb_snd sock + rcb = cb_rcv sock + if (not tcp_do_newreno) || t_dupacks scb < 3 then + modify_cb_snd $ \c->c { t_dupacks = 0 + , snd_cwnd = if t_dupacks c >= 3 + then min (snd_cwnd c) (snd_ssthresh tcb) + else snd_cwnd c } + -- below: tcp_do_newreno && t_dupacks scb >= 3 + else if acknum < (snd_recover tcb) then + do modify_cb_snd $ \c -> c { tt_rexmt = Nothing + , t_rttseg = Nothing + , snd_nxt = acknum + , snd_cwnd = t_maxseg tcb } + tcp_output False + modify_cb_snd $ \c->c{ snd_cwnd = (snd_cwnd c-(acknum `seq_diff` (snd_una c))+(t_maxseg tcb)) + , snd_nxt = snd_nxt scb } + else --acknum >= snd_recover tcb + modify_cb_snd $ \c -> c { t_dupacks = 0 + , snd_cwnd = if snd_max c `seq_diff` acknum < (snd_ssthresh tcb) + then snd_max c `seq_diff` acknum + (t_maxseg tcb) + else snd_ssthresh tcb } + + let revert_rexmt = mode_of (tt_rexmt scb) `elem` [ Just Rexmt, Just RexmtSyn ] + && shift_of (tt_rexmt scb) == 1 + && timewindow_open (clock h) (t_badrxtwin $ cb_time sock) + when revert_rexmt $ do + modify_cb_snd $ \c -> c { snd_cwnd = snd_cwnd_prev tcb + , snd_nxt = snd_max scb + } + modify_cb_time $ \c -> c { t_badrxtwin = Nothing } + modify_cb $ \c -> c { snd_ssthresh = snd_ssthresh_prev tcb } + + -- to understand: timestamping + let emission_time = case (seg_ts, t_rttseg scb) of + (Just (ts_val, ts_ecr), _ ) -> Just (ts_ecr `seq_minus` 1) + (Nothing, Just (ts0, seq0)) -> if acknum > seq0 then Just ts0 else Nothing + (Nothing, Nothing) -> Nothing + -- to understand: rtt update + let t_rttinf' = case emission_time of + Just emtime -> assert ((ticks h) >= emtime) $ + update_rtt ( ((ticks h) `seq_diff` emtime)*10*1000 ) (t_rttinf scb) + Nothing -> t_rttinf scb + let tt_rexmt' = if acknum == snd_max scb then + Nothing + else case mode_of (tt_rexmt scb) of + Nothing -> start_tt_rexmt 0 True t_rttinf' (clock h) + Just Rexmt -> start_tt_rexmt 0 True t_rttinf' (clock h) + _ -> tt_rexmt scb + let (snd_wnd', sndq') = if ourfinisacked then + (snd_wnd scb - (bufc_length $ sndq scb), bufferchain_empty) + else + (snd_wnd scb - (acknum `seq_diff` (snd_una scb)), + bufferchain_drop (acknum `seq_diff` (snd_una scb)) (sndq scb)) + + modify_cb_snd $ \c -> c { t_rttinf = t_rttinf' + , tt_rexmt = tt_rexmt' + , t_rttseg = if emission_time == Nothing then t_rttseg c else Nothing + , snd_cwnd = if not tcp_do_newreno || t_dupacks scb == 0 then + expand_cwnd (snd_ssthresh tcb) + (t_maxseg tcb) + (tcp_maxwin `shiftL` (snd_scale tcb)) + (snd_cwnd c) + else snd_cwnd c + , snd_wnd = snd_wnd' + , snd_una = acknum + , snd_nxt = max acknum (snd_nxt c) + , sndq = sndq' + } + + when (st sock == TIME_WAIT) $ + modify_cb_time $ \c -> c { tt_2msl = Just (create_timer (clock h) (2*tcptv_msl))} + + if (st sock == LAST_ACK) && ourfinisacked then do + modify_sock tcp_close_temp + return False + else return True + +{-# INLINE di3_datastuff #-} +di3_datastuff seg seqnum acknum seg_win h ourfinisacked = do + sock <- get_sock + let tcb = cb sock + scb = cb_snd sock + rcb = cb_rcv sock + let update_send_window = + tcp_ACK seg + && seqnum <= ( (rcv_nxt rcb) `seq_plus` (rcv_wnd rcb) ) + && ( snd_wl1 scb < seqnum + || ( snd_wl1 scb == seqnum + && ( snd_wl2 scb < acknum + || ( snd_wl2 scb == acknum && seg_win > snd_wnd scb ) + ) + ) + || (st sock == SYN_RECEIVED && not (tcp_FIN seg) ) + ) + let seq_trimmed = max seqnum (min (rcv_nxt rcb) (seqnum `seq_plus` (bufc_length $ tcp_data seg))) + when update_send_window $ + --debug $ "send window updated" + modify_cb_snd $ \c -> c { snd_wnd = seg_win + , snd_wl1 = seq_trimmed + , snd_wl2 = acknum + } + if st sock == TIME_WAIT || (st sock == CLOSING && ourfinisacked) + then do modify_cb $ \c -> c { rcv_up = max (rcv_up c) (rcv_nxt rcb) } + return False + else di3_datastuff_really seg seqnum acknum seg_win h + +{-# INLINE di3_datastuff_really #-} +di3_datastuff_really seg seqnum acknum seg_win h = + do let dat = tcp_data seg + sock <- get_sock + let tcb = cb sock + scb = cb_snd sock + rcb = cb_rcv sock + + let trim_amt_left = if rcv_nxt rcb > seqnum + then min (rcv_nxt rcb `seq_diff` seqnum) (bufc_length dat) + else 0 + data_trimmed_left = bufferchain_drop trim_amt_left dat + seq_trimmed = seqnum `seq_plus` trim_amt_left + let data_trimmed_left_right = bufferchain_take (rcv_wnd rcb) data_trimmed_left + fin_trimmed = if bufc_length data_trimmed_left_right == + bufc_length data_trimmed_left then tcp_FIN seg else False + let rseg = TCPReassSegment { trs_seq = seq_trimmed + , trs_FIN = fin_trimmed + , trs_data = data_trimmed_left_right + } + -- processing incoming data + if seq_trimmed == rcv_nxt rcb + && seq_trimmed `seq_plus` (bufc_length data_trimmed_left_right) + `seq_plus` (if fin_trimmed then 1 else 0) > (rcv_nxt rcb) + && rcv_wnd rcb > 0 + then do + -- case 1: reassambling possible + let have_stuff_to_ack = bufc_length data_trimmed_left_right >0 || fin_trimmed + let delay_ack = st sock `elem` [ESTABLISHED, CLOSE_WAIT, FIN_WAIT_1, FIN_WAIT_2, CLOSING, LAST_ACK] + && have_stuff_to_ack && not fin_trimmed && List.null (t_segq rcb) + && not (tf_rxwin0sent rcb) + && tt_delack rcb == False + let rsegq = rseg:(t_segq rcb) + let (data_reass, rcv_nxt', fin_reass0, t_segq') = tcp_reass (rcv_nxt rcb) rsegq + let rcvq' = bufferchain_concat (rcvq rcb) data_reass + let rcv_wnd' = rcv_wnd rcb - (bufc_length data_reass) + modify_cb_rcv $ \c -> c + { tt_delack = if delay_ack then True else tt_delack c + , tf_shouldacknow = if have_stuff_to_ack then not delay_ack else tf_shouldacknow c + , t_segq = t_segq' + , rcv_nxt = rcv_nxt' + , rcv_wnd = rcv_wnd' + , rcvq = rcvq' + } + return fin_reass0 + else if seq_trimmed > (rcv_nxt rcb) + && seq_trimmed < ((rcv_nxt rcb) `seq_plus` (rcv_wnd rcb)) + && bufc_length data_trimmed_left_right + (if fin_trimmed then 1 else 0) > 0 + && rcv_wnd rcb > 0 + then do + -- case 2: wait for future reassambling + modify_cb_rcv $ \c -> c { t_segq = rseg:(t_segq c) + , tf_shouldacknow = True + } + return False + else if tcp_ACK seg && seq_trimmed == rcv_nxt rcb + && bufc_length dat + (if tcp_FIN seg then 1 else 0) == 0 then + -- case 3: no data + return False + else do + -- case 4: other cases... maybe windows is closed + modify_cb_rcv $ \c -> c { tf_shouldacknow = True } + return False + +{-# INLINE di3_ststuff #-} +di3_ststuff fin_reass h ourfinisacked acknum = + do sock <- get_sock + let tcb = cb sock + let enter_TIME_WAIT = do + modify_sock $ \s -> s { st = TIME_WAIT } + modify_cb_time $ \c -> c { tt_2msl = Just (create_timer (clock h) (2*tcptv_msl)) + , tt_keep = Nothing + , tt_conn_est = Nothing + , tt_fin_wait_2 = Nothing + } + modify_cb_snd $ \c -> c { tt_rexmt = Nothing } + modify_cb_rcv $ \c -> c { tt_delack = False } + + when fin_reass $ + modify_cb $ \s -> s { cantrcvmore = True } + + case (st sock, fin_reass) of + (SYN_RECEIVED,False) -> when (acknum >= (iss tcb) `seq_plus` 1 ) $ + modify_sock $ \s -> s + { st = if not (cantsndmore tcb) then ESTABLISHED else + if ourfinisacked then FIN_WAIT_2 else FIN_WAIT_1 + } + (SYN_RECEIVED, True) -> modify_sock $ \s -> s { st = CLOSE_WAIT } + (ESTABLISHED, False) -> return () + (ESTABLISHED, True) -> modify_sock $ \s -> s { st = CLOSE_WAIT } + (CLOSE_WAIT, _ ) -> return () + (FIN_WAIT_1, False) -> when ourfinisacked $ do + modify_sock $ \s -> s { st = FIN_WAIT_2 } + when (cantrcvmore tcb) $ + modify_cb_time $ \c -> c { tt_fin_wait_2 = + Just (create_timer (clock h) (tcptv_maxidle)) } + (FIN_WAIT_1, True) -> if ourfinisacked then enter_TIME_WAIT + else modify_sock $ \s->s { st=CLOSING } + (FIN_WAIT_2, False) -> return () + (FIN_WAIT_2, True) -> enter_TIME_WAIT + (CLOSING, _) -> when ourfinisacked enter_TIME_WAIT + (LAST_ACK, False) -> return () + (LAST_ACK, True) -> error "di3_ststuff" + (TIME_WAIT, _ ) -> return () diff --git a/src/Network/TCP/LTS/InMisc.hs b/src/Network/TCP/LTS/InMisc.hs new file mode 100644 index 0000000..8f03452 --- /dev/null +++ b/src/Network/TCP/LTS/InMisc.hs @@ -0,0 +1,36 @@ +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.LTS.InMisc where + + diff --git a/src/Network/TCP/LTS/InPassive.hs b/src/Network/TCP/LTS/InPassive.hs new file mode 100644 index 0000000..d124dc7 --- /dev/null +++ b/src/Network/TCP/LTS/InPassive.hs @@ -0,0 +1,192 @@ +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.LTS.InPassive where +import Data.List as List +import Control.Exception +import Control.Monad + +--import Foreign.C +import Network.TCP.Type.Base +import Network.TCP.Type.Syscall +import Network.TCP.Type.Timer +import Network.TCP.Type.Socket +import Network.TCP.Type.Datagram +import Network.TCP.Aux.Misc +import Network.TCP.Aux.Param +import Network.TCP.Aux.Output +import Hans.Layer.Tcp.Monad +import Network.TCP.Aux.SockMonad + +import Network.TCP.LTS.Out +import Network.TCP.LTS.User + +tcp_deliver_syn_packet seg = do + -- precondition: sid does not exist + -- try if seg matches a listening socket... + let sidlisten = SocketID ((get_port $ tcp_dst seg), TCPAddr (IPAddr 0,0)) + h <- get_host + haslisten <- has_sock sidlisten + if not haslisten then return () else do + -- matches a socket... + sock <- lookup_sock sidlisten + if st sock /= LISTEN then return () else do + -- now we find a listening socket maching incoming SYN=1 ACK=0 RST=0 + if accept_incoming_q0 (sock_listen sock) + then deliver_in_1 sidlisten sock seg + else return () + +deliver_in_1 sid sock seg = + do let newsid = SocketID ((get_port $ tcp_dst seg), tcp_src seg) + h <- get_host + -- at this point, newsid is an unique socket id in the system... + + -- drop the first sid from q0 if needed, append newsid to q0. + let lis1 = sock_listen sock + should_drop = drop_from_q0 lis1 + drop_sid = head $ lis_q0 lis1 + oldq = lis_q0 lis1 + newq = if should_drop then tail oldq else oldq + lis2 = lis1 { lis_q0 = newq++[newsid]} + -- update listening socket (sid) + update_sock sid $ \_ -> sock { sock_listen = lis2 } + -- delete old socket if needed + when should_drop $ tcp_close drop_sid + + -- Create a new socket + let advmss = mssdflt -- todo: lookup interface mss + advmss' = Nothing -- not advertising MSS (todo: change it) + + tf_rcvd_tstmp = case tcp_ts seg of Just _ -> True; Nothing -> False + tf_doing_tstmp' = False -- not doing timestamping (todo: change it) + + (rcvbufsize', sndbufsize', t_maxseg', snd_cwnd') = + calculate_buf_sizes advmss (tcp_mss seg) Nothing False + (freebsd_so_rcvbuf) (freebsd_so_sndbuf) tf_doing_tstmp' + + tf_doing_ws' = False -- not doing window scaling (todo: change it) + rcv_scale' = 0 + snd_scale' = 0 + rcv_window = min tcp_maxwin freebsd_so_rcvbuf + + newiss = SeqLocal 1000 -- beginning iss. (todo: add more randomness) + t_rttseg' = Just (ticks h, newiss) + seqnum = SeqForeign (tcp_seq seg) + acknum = SeqLocal (tcp_ack seg) + ack' = seqnum `seq_plus` 1 + cb_time' = (cb_time sock) + { tt_keep = Just (create_timer (clock h) tcptv_keep_idle) + , ts_recent = case (tcp_ts seg) of + Nothing -> ts_recent (cb_time sock) + Just (ts_val, ts_ecr) -> create_timewindow (clock h) (dtsinval) ts_val + } + cb_snd' = (cb_snd sock) + { tt_rexmt = start_tt_rexmt 0 False (t_rttinf (cb_snd sock)) (clock h) + , snd_una = newiss + , snd_max = newiss `seq_plus` 1 + , snd_nxt = newiss `seq_plus` 1 + , snd_cwnd = snd_cwnd' + , t_rttseg = t_rttseg' + } + cb_rcv' = (cb_rcv sock) + { rcv_wnd = rcv_window + , tf_rxwin0sent = (rcv_window == 0) + , last_ack_sent = ack' + , rcv_adv = ack' `seq_plus` rcv_window + , rcv_nxt = ack' + } + cb' = (cb sock) + { iss = newiss + , irs = seqnum + , rcv_up = seqnum `seq_plus` 1 + , t_maxseg = t_maxseg' + , t_advmss = advmss' + , rcv_scale = rcv_scale' + , snd_scale = snd_scale' + , tf_doing_ws = tf_doing_ws' + , tf_req_tstmp = tf_doing_tstmp' + , tf_doing_tstmp = tf_doing_tstmp' + , local_addr = tcp_dst seg + , remote_addr = tcp_src seg + , self_id = newsid + , parent_id = sid + } + -- create new socket (newsid) + let newsock = initial_tcp_socket + { st = SYN_RECEIVED + , cb = cb' + , cb_time = cb_time' + , cb_snd = cb_snd' + , cb_rcv = cb_rcv' + } + insert_sock newsid newsock + -- emit [SYN,ACK] packet + emit_segs [TCPMessage $ make_syn_ack_segment (clock h) newsock + (tcp_dst seg) (tcp_src seg) (ticks h) ] + +-- After receiving ACK on SYN_RECEIVED, a connection is established. +-- Now we need to update the queues of the listening socket... +di3_socks_update sid = do + h <- get_host + -- precondition: sid exists + newsock <- lookup_sock sid + let tcb = cb newsock + rcb = cb_rcv newsock + sidlisten = parent_id tcb + haslisten <- has_sock sidlisten + assert (haslisten) return () + listensock <- lookup_sock sidlisten + let lis1 = sock_listen listensock + assert (sid `elem` (lis_q0 lis1)) return () + -- found the listening socket! + if accept_incoming_q lis1 then do + -- delete socket from q0 + -- move into completed queue + let lis2 = lis1 { lis_q0 = List.delete sid (lis_q0 lis1) + , lis_q = sid : (lis_q lis1) + } + let rcv_window = calculate_bsd_rcv_wnd newsock + let newcb = (cb_rcv newsock) { rcv_wnd = rcv_window + , rcv_adv = (rcv_nxt rcb) `seq_plus` (rcv_wnd rcb) + } + update_sock sidlisten $ \_ -> listensock { sock_listen = lis2 } + update_sock sid $ \_ -> newsock { cb_rcv = newcb } + runSMonad sidlisten $ tcp_wakeup + else do + -- delete socket from q0, backlog full -> delete socket + let lis2 = lis1 { lis_q0 = List.delete sid (lis_q0 lis1) } + update_sock sidlisten $ \_ -> listensock { sock_listen = lis2 } + tcp_close sid + --endif + + diff --git a/src/Network/TCP/LTS/Out.hs b/src/Network/TCP/LTS/Out.hs new file mode 100644 index 0000000..3eeb227 --- /dev/null +++ b/src/Network/TCP/LTS/Out.hs @@ -0,0 +1,225 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.LTS.Out + ( tcp_output_all + , tcp_output + , tcp_close + , tcp_drop_and_close + ) +where + +import Hans.Message.Tcp + +import Data.List as List +import Network.TCP.Aux.Output +import Network.TCP.Aux.Misc +import Network.TCP.Type.Base +import Network.TCP.Type.Syscall +import Network.TCP.Type.Socket +import Hans.Layer.Tcp.Monad +import Network.TCP.Aux.SockMonad +import Control.Monad +import Control.Exception +import Foreign +import Network.TCP.Type.Timer +import Network.TCP.Type.Datagram as Datagram +import Network.TCP.Aux.Param + +tcp_output_all :: SMonad t () +tcp_output_all = do + h <- get_host_ + sock <- get_sock + let scb = cb_snd sock + tcb = cb sock + when ((st sock `elem` [ESTABLISHED, CLOSE_WAIT, FIN_WAIT_1, + FIN_WAIT_2, CLOSING, LAST_ACK, TIME_WAIT] + && (snd_una scb /= iss tcb)) -- does this make sense? + || ( st sock `elem` [SYN_SENT, SYN_RECEIVED] && + cantsndmore tcb && (tf_shouldacknow $ cb_rcv sock))) $ + output_loop h sock + +output_loop h sock = + let (sock1, outsegs) = tcp_output_really (clock h) False (ticks h) sock in + if List.null outsegs then + put_sock sock1 + else do + --debug $ "tcp_output_all: " ++ (show outsegs) + emit_segs_ $! outsegs + output_loop h sock1 + +{-# INLINE tcp_output_all #-} +{-# INLINE output_loop #-} + + +{-# INLINE tcp_output_really #-} + +tcp_output_really (curr_time :: Time) (window_probe::Bool) (ts_val'::Timestamp) tcp_sock = + let tcb = cb tcp_sock + scb = cb_snd tcp_sock + rcb = cb_rcv tcp_sock + in + assert ((rcv_adv rcb) >= (rcv_nxt rcb)) $ + assert ((snd_nxt scb) >= (snd_una scb)) $ + let snd_cwnd' = if snd_max scb == snd_una scb && + (t_idletime $ cb_time tcp_sock) - curr_time + >= (computed_rxtcur $ t_rttinf scb) + then (t_maxseg tcb) * ss_fltsz -- has been idle for a while, slowstart + else snd_cwnd scb + win0 = min (snd_wnd scb) snd_cwnd' + win = if window_probe && win0==0 then 1 else win0 + snd_wnd_unused ::Int = win - ((snd_nxt scb) `seq_diff` (snd_una scb)) + syn_not_acked = (st tcp_sock `elem` [SYN_SENT, SYN_RECEIVED]) + fin_required = (cantsndmore tcb && st tcp_sock `notElem` [FIN_WAIT_2, TIME_WAIT]) + last_sndq_data_seq = (snd_una scb) `seq_plus` (bufc_length $ sndq scb) + last_sndq_data_and_fin_seq = last_sndq_data_seq `seq_plus` + (if fin_required then 1 else 0) `seq_plus` + (if syn_not_acked then 1 else 0) + have_data_to_send = (snd_nxt scb) < last_sndq_data_seq + have_data_or_fin_to_send = (snd_nxt scb) < last_sndq_data_and_fin_seq + window_update_delta = (min (tcp_maxwin `shiftL` (rcv_scale tcb)) + (freebsd_so_rcvbuf - (bufc_length $ rcvq rcb)) + ) - ( (rcv_adv rcb) `seq_diff` (rcv_nxt rcb)) + need_to_send_a_window_update = (window_update_delta >= 2 * (t_maxseg tcb)) || + (2*window_update_delta >= freebsd_so_rcvbuf) + do_output = ( have_data_or_fin_to_send && (if have_data_to_send then snd_wnd_unused>0 else True) ) + || need_to_send_a_window_update -- sndurp tcp_sock /= Nothing + || tf_shouldacknow rcb + cant_send = (not do_output) && + (bufc_length (sndq scb) > 0 ) && + mode_of (tt_rexmt scb) == Nothing + window_shrunk = win==0 && + snd_wnd_unused <0 && + st tcp_sock /= SYN_SENT + tcp_sock0 = if cant_send then + tcp_sock { cb_snd = scb {tt_rexmt = start_tt_persist 0 (t_rttinf scb) curr_time}} + else if window_shrunk then + tcp_sock { cb_snd = scb { + tt_rexmt = case tt_rexmt scb of + Just(Timed (Persist, shift) d ) -> Just (Timed (Persist, 0) d) + _ -> start_tt_persist 0 (t_rttinf scb) curr_time + , snd_nxt = snd_una scb + }} + else tcp_sock + in + if (not do_output) then (tcp_sock0, []) else + ------------ really do it --------------------------------------------- + let tcp_sock = tcp_sock0 + scb = cb_snd tcp_sock + + data' = bufferchain_drop (snd_nxt scb `seq_diff` (snd_una scb)) (sndq scb) + data_to_send = bufferchain_take (min (snd_wnd_unused) ( t_maxseg tcb)) data' + bFIN = fin_required && (snd_nxt scb) `seq_plus` (bufc_length data_to_send) >= last_sndq_data_seq + bACK = if bFIN && st tcp_sock == SYN_SENT then False else True + snd_nxt' = if bFIN && + ((snd_nxt scb `seq_plus` (bufc_length data_to_send) == + last_sndq_data_seq `seq_plus` 1 && snd_una scb /= iss tcb ) + || (snd_nxt scb) `seq_diff` (iss tcb) == 2) + then snd_nxt scb `seq_minus` 1 + else snd_nxt scb + bPSH = bufc_length data_to_send > 0 && + snd_nxt scb `seq_plus` (bufc_length data_to_send) == last_sndq_data_seq + rcv_wnd'' = calculate_bsd_rcv_wnd tcp_sock + rcv_wnd' = max (rcv_adv rcb `seq_diff` (rcv_nxt rcb)) + (min (tcp_maxwin `shiftL` (rcv_scale tcb)) + (if rcv_wnd'' < (freebsd_so_rcvbuf `div` 4) && rcv_wnd'' < (t_maxseg tcb) + then 0 else rcv_wnd'')) + want_tstmp = if st tcp_sock == SYN_SENT then tf_req_tstmp tcb else tf_doing_tstmp tcb + ts_ = do_tcp_options curr_time want_tstmp (ts_recent $ cb_time tcp_sock) ts_val' + in + let win_ = rcv_wnd' `shiftR` (rcv_scale tcb) + hdr = set_tcp_ts ts_ emptyTcpHeader + { tcpSeqNum = TcpSeqNum (seq_val snd_nxt') + , tcpAckNum = TcpAckNum (fseq_val (rcv_nxt rcb)) + , tcpAck = bACK + , tcpPsh = bPSH + , tcpFin = bFIN + , tcpWindow = fromIntegral win_ + } + seg = mkTCPSegment' (local_addr tcb) (remote_addr tcb) hdr data_to_send + st' = if bFIN then + case st tcp_sock of + ESTABLISHED -> FIN_WAIT_1 + CLOSE_WAIT -> LAST_ACK + xxx -> xxx + else + st tcp_sock + snd_nxt'' = snd_nxt' `seq_plus` (bufc_length data_to_send) `seq_plus` (if bFIN then 1 else 0) + snd_max' = max (snd_max scb) snd_nxt'' + tt_rexmt' = if (mode_of (tt_rexmt scb) == Nothing || + (mode_of (tt_rexmt scb) == Just Persist && not window_probe)) && + snd_nxt'' > (snd_una scb) then + start_tt_rexmt 0 False (t_rttinf scb) curr_time + else if (window_probe {-- || sndurp tcp_sock /= Nothing --} ) && win0 /= 0 && + mode_of (tt_rexmt scb) == Just Persist then + Nothing + else + tt_rexmt scb + t_rttseg' = if t_rttseg scb == Nothing && (bufc_length data_to_send > 0 || bFIN) && + snd_nxt'' > (snd_max scb) && not window_probe then + Just (ts_val', snd_nxt') + else + t_rttseg scb + tcp_sock' = tcp_sock + { st = st' + , cb_snd = scb { tt_rexmt = tt_rexmt' + , snd_cwnd = snd_cwnd' + , t_rttseg = t_rttseg' + , snd_max = snd_max' + , snd_nxt = snd_nxt'' + } + , cb_rcv = rcb { last_ack_sent = rcv_nxt rcb + , rcv_adv = rcv_nxt rcb `seq_plus` rcv_wnd' + , tt_delack = False + , rcv_wnd = rcv_wnd' + , tf_rxwin0sent = (rcv_wnd' == 0) + , tf_shouldacknow = False + } + } + outsegs' = [TCPMessage seg] + in + (tcp_sock', outsegs') + +{-# INLINE tcp_output #-} +tcp_output :: Bool -> SMonad t () +tcp_output win_probe = + do sock <- get_sock + h <- get_host_ + let (newsock, segs) = tcp_output_really (clock h) win_probe (ticks h) sock + put_sock newsock + emit_segs_ segs + --if List.null segs then return () else debug $ "tcp_output: " ++ (show segs) + + + diff --git a/src/Network/TCP/LTS/Time.hs b/src/Network/TCP/LTS/Time.hs new file mode 100644 index 0000000..6745e29 --- /dev/null +++ b/src/Network/TCP/LTS/Time.hs @@ -0,0 +1,261 @@ +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.LTS.Time + ( tcp_update_timers + ) +where + +import Foreign +import Foreign.C +import Data.Maybe +import Data.Map as Map +import Data.List as List +import Control.Monad + +import Network.TCP.Type.Base +import Network.TCP.Type.Datagram as Datagram +import Network.TCP.Type.Syscall +import Network.TCP.Type.Socket +import Network.TCP.Type.Timer +import Hans.Layer.Tcp.Monad +import Network.TCP.Aux.SockMonad +import Network.TCP.Aux.Output +import Network.TCP.Aux.Misc +import Network.TCP.Aux.Param +import Network.TCP.LTS.Out + +tcp_update_timers :: HMonad t () +tcp_update_timers = + do h <- get_host + when (clock h >= (fst $ next_timers h)) $ do + mapM update_fasttimer (keys (sock_map h)) + modify_host $ \h->h {next_timers = ( (fst $ next_timers h) + 200*1000, (snd $ next_timers h))} + + h <- get_host + when (clock h >= (snd $ next_timers h)) $ do + mapM update_slowtimer (keys (sock_map h)) + modify_host $ \h->h {next_timers = ((fst $ next_timers h), (snd $ next_timers h) + 500*1000)} + +update_fasttimer sid = + runSMonad sid $ do + sock <- get_sock + let tcb = cb_rcv sock + when (tt_delack tcb) $ + timer_tt_delack_1 sid sock + +timer_tt_delack_1 sid sock = + do modify_sock $ \sock-> sock { cb_rcv = (cb_rcv sock) { tt_delack = False } } + tcp_output False + +update_slowtimer sid = + do h <- get_host + sock <- lookup_sock sid + let scb = cb_snd sock + tcb = cb_time sock + curr_time = clock h + when (maybe_timed_expires curr_time (tt_rexmt scb)) $ + case tt_rexmt scb of + Just (Timed (RexmtSyn, shift) tmr) -> timer_tt_rexmtsyn sid sock (shift) + Just (Timed (Rexmt, shift) tmr) -> timer_tt_rexmt sid sock (shift) + Just (Timed (Persist, shift) tmr) -> timer_tt_persist sid sock (shift) + _ -> return () + + when (maybe_timer_expires curr_time $ tt_keep tcb) $ + timer_tt_keep sid sock + + when (maybe_timer_expires curr_time $ tt_conn_est tcb) $ + timer_tt_conn_est sid + + when (maybe_timer_expires curr_time $ tt_2msl tcb) $ + timer_tt_2msl sid + + when (maybe_timer_expires curr_time $ tt_fin_wait_2 tcb) $ + timer_tt_fin_wait_2 sid + +timer_tt_rexmtsyn sid sock shift = + let tcb = cb sock in + let scb = cb_snd sock in + when (st sock == SYN_SENT) $ do + if shift+1 >= tcp_maxrxtshift then tcp_drop_and_close sid else do + h <- get_host + let { + (snd_cwnd_prev', snd_ssthresh_prev', t_badrxtwin') = + if shift==0 && (tf_srtt_valid $ t_rttinf scb) then + (snd_cwnd scb, snd_ssthresh tcb, create_timewindow (clock h) (t_srtt (t_rttinf scb) `div` 2 ) ()) + else + (snd_cwnd_prev tcb, snd_ssthresh_prev tcb, t_badrxtwin $ cb_time sock); + tf_req_tstmp' = if shift==2 then False else tf_req_tstmp tcb; + req_r_scale' = if shift==2 then Nothing else request_r_scale tcb; + t_rttinf' = if shift+1 > tcp_maxrxtshift `div` 4 + then (t_rttinf scb) { tf_srtt_valid = False} else t_rttinf scb; + newsock = sock + { cb_snd = scb + { tt_rexmt = start_tt_rexmtsyn (shift+1) False (t_rttinf scb) (clock h) + , t_rttinf = t_rttinf' { t_lastshift = shift+1, t_wassyn = True } + , snd_cwnd = t_maxseg tcb + , t_dupacks = 0 + , t_rttseg = Nothing + } + , cb_time = (cb_time sock) + { t_badrxtwin = t_badrxtwin' + } + , cb = tcb + { tf_req_tstmp = tf_req_tstmp' + , request_r_scale = req_r_scale' + , snd_ssthresh = (t_maxseg tcb) * + (max 2 (min (snd_wnd scb) (snd_cwnd scb) `div` (2 * (t_maxseg tcb)))) + , snd_cwnd_prev = snd_cwnd_prev' + , snd_ssthresh_prev = snd_ssthresh_prev' + } + } + } + update_sock sid $ \_ -> newsock + emit_segs [ TCPMessage $ make_syn_segment (clock h) newsock (ticks h)] + +timer_tt_rexmt sid sock shift = + let tcb = cb sock in + let scb = cb_snd sock in + when (st sock `notElem` [CLOSED,SYN_SENT,CLOSE_WAIT,FIN_WAIT_2,TIME_WAIT]) $ + if shift+1 > (if st sock == SYN_RECEIVED then tcp_synackmaxrxtshift else tcp_maxrxtshift) + then tcp_drop_and_close sid else do + + h <- get_host + let { + (snd_cwnd_prev', snd_ssthresh_prev', t_badrxtwin') = + if shift+1==1 && tf_srtt_valid (t_rttinf scb) then + (snd_cwnd scb, snd_ssthresh tcb, + create_timewindow (clock h) ( t_srtt (t_rttinf scb) `div` 2 ) () ) + else + (snd_cwnd_prev tcb, snd_ssthresh_prev tcb, t_badrxtwin $ cb_time sock); + t_rttinf' = if shift+1 > tcp_maxrxtshift `div` 4 then + (t_rttinf scb) { tf_srtt_valid = False + , t_srtt = (t_srtt $ t_rttinf scb) `div` 4 + } + else t_rttinf scb; + sock1 = sock + { cb_snd = scb + { tt_rexmt = start_tt_rexmt (shift+1) False (t_rttinf scb) (clock h) + , t_rttinf = t_rttinf' { t_lastshift = shift + 1 + , t_wassyn = False + } + , snd_nxt = (snd_una scb) + , t_rttseg = Nothing + , snd_cwnd = t_maxseg tcb + , t_dupacks = 0 + } + , cb_time = (cb_time sock) { t_badrxtwin = t_badrxtwin' } + , cb = tcb { snd_recover = snd_max scb + , snd_ssthresh = (t_maxseg tcb) * (max 2 + (min (snd_wnd scb) (snd_cwnd scb) `div` (2 * (t_maxseg tcb)))) + , snd_cwnd_prev = snd_cwnd_prev' + , snd_ssthresh_prev = snd_ssthresh_prev' + } + }; + } + if st sock == SYN_RECEIVED then do + + emit_segs [ TCPMessage $ make_syn_ack_segment + (clock h) sock1 (local_addr tcb) (remote_addr tcb) (ticks h) ] + update_sock sid $ \_ -> sock1 { cb_snd = (cb_snd sock1) + { snd_nxt = (snd_nxt $ cb_snd $ sock1) `seq_plus` 1 }} + + else if st sock == LISTEN then do + + let seg' = bsd_make_phantom_segment (clock h) sock1 + (local_addr tcb) (remote_addr tcb) (ticks h) (cantsndmore tcb) + emit_segs [ TCPMessage $ seg'] + update_sock sid $ \_ -> sock1 { cb_snd = (cb_snd sock1) + { tt_rexmt = if tcp_FIN seg' then tt_rexmt (cb_snd sock1) else Nothing } } + + else runSMonad sid $ do + put_sock sock1 + tcp_output False + +timer_tt_persist sid sock shift = + runSMonad sid $ do + h <- get_host_ + let scb = cb_snd sock + put_sock $ sock { cb_snd = scb {tt_rexmt = start_tt_persist (shift+1) (t_rttinf scb) (clock h) }} + tcp_output True + +timer_tt_keep sid sock = + do h <- get_host + let tcb = cb sock + scb = cb_snd sock + rcb = cb_rcv sock + let win_ = (rcv_wnd rcb `shiftR` (rcv_scale tcb)) + let ts'= if tf_doing_tstmp tcb then + let ts_ecr' = case timewindow_val (clock h) (ts_recent $ cb_time sock) of + Just q -> q + Nothing -> Timestamp 0 + in + Just ( (ticks h), ts_ecr') + else + Nothing + let seg = TCPSegment + { tcp_src = local_addr tcb + , tcp_dst = remote_addr tcb + , tcp_seq = snd_una scb `seq_minus` 1 + , tcp_ack = rcv_nxt rcb + , tcp_URG = False + , tcp_ACK = True + , tcp_PSH = False + , tcp_RST = False + , tcp_SYN = False + , tcp_FIN = False + , tcp_win = win_ + , tcp_urp = 0 + , tcp_data = bufferchain_empty + -- option: window scaling + , tcp_ws = Nothing + -- option: max segment size + , tcp_mss = Nothing + -- option: RFC1323 + , tcp_ts = ts' + } + emit_segs [TCPMessage seg] + update_sock sid $ \_ -> sock { cb_time = (cb_time sock) { tt_keep = Just (create_timer (clock h) tcptv_keepintvl) } + , cb_rcv = rcb { last_ack_sent = tcp_ack seg } + } + return () + +timer_tt_conn_est sid = + tcp_drop_and_close sid + +timer_tt_2msl sid = + tcp_close sid + +timer_tt_fin_wait_2 sid = + tcp_close sid + diff --git a/src/Network/TCP/LTS/User.hs b/src/Network/TCP/LTS/User.hs new file mode 100644 index 0000000..e00c242 --- /dev/null +++ b/src/Network/TCP/LTS/User.hs @@ -0,0 +1,293 @@ +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.LTS.User + ( tcp_process_user_request + , tcp_wakeup + ) +where + +import Foreign.C +import Data.List as List +import Control.Monad +import Data.Maybe +import Network.TCP.Type.Base +import Network.TCP.Type.Datagram +import Network.TCP.Type.Syscall +import Network.TCP.Type.Socket +import Hans.Layer.Tcp.Monad +import Network.TCP.Aux.SockMonad +import Network.TCP.Aux.Misc +import Network.TCP.Aux.Param +import Network.TCP.Aux.Output + +import Network.TCP.LTS.Out + +--------------------------------------------------------------------------------------- +-- input: a list of sock request +-- output: threads that have taken the completed requests (back in running state) +-- side effect: host state changes +-- blocked threads goes into the wait queue of each socket +-- tcp_process_user_requests :: (Monad m) => [(SockReq,SockRsp->t)] -> HMonad t m [t] +-- tcp_process_user_requests reqs = +-- do r <- mapM tcp_process_user_request reqs +-- return $ concat r + +tcp_process_user_request :: (SockReq,SockRsp->t) -> HMonad t (Maybe t) +tcp_process_user_request (req, cont) = + case req of + SockListen addr -> process_listen addr cont + SockClose sid -> process_close sid cont + SockConnect local addr -> process_connect local addr cont + SockAccept sid -> process_accept sid cont + SockSend sid d -> process_send sid d cont + SockRecv sid -> process_recv sid cont + +tcp_wakeup_request req cont = + case req of + SockConnect local addr -> wakeup_connect cont + SockAccept sid -> wakeup_accept sid cont + SockSend sid d -> wakeup_send sid d cont + SockRecv sid -> wakeup_recv sid cont + +-- pre-cond: sock is set +-- post-cond: sock is set, tcp_output needed +tcp_wakeup = + do sock <- get_sock + case waiting_list sock of + [] -> return () + (req,cont):reqs -> do + res <- tcp_wakeup_request req cont + case res of + Nothing -> return () + Just th -> do + emit_ready_ [th] + modify_sock $ \s -> s {waiting_list = reqs} + +-- pre-cond: sock not set +-- post-cond: sock not set +process_listen :: Port -> (SockRsp->t) -> HMonad t (Maybe t) +process_listen port cont = + do let sock_id = SocketID (port, TCPAddr (IPAddr 0,0)) + h <- get_host + -- check if port has been used... + if port `elem` (local_ports h) then + do let listen = SocketListen [] [] listen_qlimit + let newsock = initial_tcp_socket + { cb = (cb initial_tcp_socket) { local_addr = TCPAddr (IPAddr 0,port), self_id=sock_id } + , st = LISTEN + , sock_listen = listen + } + insert_sock sock_id newsock + modify_host $ \h -> h { local_ports = List.delete port (local_ports h) } + return $ Just $ cont $ SockNew sock_id + else + return $ Just $ cont $ SockError "Port not available" + +process_close :: SocketID -> (SockRsp->t) -> HMonad t (Maybe t) +process_close sid cont = + do ok <- has_sock sid + if not ok then + return $ Just $ cont $ SockError "Socket not found" + else do + sock <- lookup_sock sid + if st sock `elem` [CLOSED,SYN_SENT,SYN_RECEIVED] then do + -- close_7 : delete sid + tcp_close sid + return $ Just $ cont $ SockOK + else if st sock /= LISTEN then runSMonad sid $ do + -- close_1 : change the flags so FIN can be sent later + modify_sock $ \sock-> sock { cb = (cb sock) { cantsndmore=True, cantrcvmore=True} + , cb_rcv = (cb_rcv sock) { rcvq=bufferchain_empty } + } + tcp_output_all + return $ Just $ cont $ SockOK + else do + -- close_8 : closing a LISTEN socket + -- todo: not implemented yet + return $ Just $ cont $ SockError "not implemented yet: close_8 : closing a LISTEN socket" + +-- pre-cond: sock not set +-- post-cond: sock not set +process_accept :: SocketID -> (SockRsp->t) -> HMonad t (Maybe t) +process_accept sid cont = + do ok <- has_sock sid + if not ok then + return $ Just $ cont $ SockError "Socket not found" + else runSMonad sid $ do + res <- try_accept cont + when (isNothing res) $ + -- put thread in waiting list + modify_sock $ \sock -> sock {waiting_list = (waiting_list sock)++[(SockAccept sid, cont)] } + return res +wakeup_accept sid cont + = try_accept cont + +-- pre-cond: sock is set +-- post-cond: sock is set, listening queue updated +try_accept :: (SockRsp->t) -> SMonad t (Maybe t) +try_accept cont = + do sock <- get_sock + if st sock /= LISTEN then + return $ Just $ cont $ SockError "Socket not in LISTEN state" + else do + -- find the listen queue + let listen = sock_listen sock + case lis_q listen of + [] -> return Nothing -- no connection to accept, can't proceed + (sid2:qs) -> do -- try to accept sid2, either success or fail + modify_sock $ \sock -> sock { sock_listen = listen { lis_q = qs } } + return $ Just $ cont $ SockNew sid2 + +process_recv :: SocketID -> (SockRsp->t) -> HMonad t (Maybe t) +process_recv sid cont = + do ok <- has_sock sid + if not ok then + return $ Just $ cont $ SockError "Socket not found" + else runSMonad sid $ do + res <- try_recv cont + when (isNothing res) $ + -- put thread in waiting list + modify_sock $ \sock -> sock {waiting_list = (waiting_list sock)++[(SockRecv sid, cont)] } + return res + +wakeup_recv sid cont = + try_recv cont + +try_recv :: (SockRsp->t) -> SMonad t (Maybe t) +try_recv cont = + do sock <- get_sock + let q = rcvq $ cb_rcv sock + if st sock `elem` [ CLOSED, SYN_SENT, SYN_RECEIVED] then + return $ Just $ cont $ SockError "Socket not in synchronized state" + else if bufc_length q == 0 then + if cantrcvmore $ cb sock + then return $ Just $ cont $ SockData buffer_empty -- EOF + else return Nothing -- no data, can't proceed + else do + -- let rcvnum = min size (length q) + -- (q1,q2) = splitAt rcvnum q + put_sock $ sock { cb_rcv = (cb_rcv sock) {rcvq = bufferchain_tail q }} + return $ Just $ cont $ SockData $ bufferchain_head q + +process_send :: SocketID -> Buffer -> (SockRsp->t) -> HMonad t (Maybe t) +process_send sid d cont = + do ok <- has_sock sid + if not ok then + return $ Just $ cont $ SockError "Socket not found" + else runSMonad sid $ do + (res,remain) <- try_send d cont + when (isNothing res) $ + -- put thread in waiting list + modify_sock $ \sock -> sock {waiting_list = (waiting_list sock)++[(SockSend sid remain, cont)] } + return res + +wakeup_send sid d cont = + do (res,remain) <- try_send d cont + when (isNothing res) $ + -- put thread in waiting list again + modify_sock $ \sock -> sock {waiting_list = (tail $ waiting_list sock)++[(SockSend sid remain, cont)] } + return res + +try_send :: Buffer -> (SockRsp->t) -> SMonad t (Maybe t, Buffer) +try_send d cont = + do sock <- get_sock + if st sock `notElem` [ ESTABLISHED, CLOSE_WAIT] then + return (Just $ cont $ SockError "Socket not in synchronized state", buffer_empty ) + else if cantsndmore $ cb sock then + return (Just $ cont $ SockError "Socket cantsndmore=true, cannot send...", buffer_empty) + else do + let max_can_send = freebsd_so_sndbuf - (bufc_length $ sndq $ cb_snd sock) + num_to_send = min max_can_send (buf_len d) + (d1,d2) = buffer_split num_to_send d + modify_cb_snd $ \c -> c { sndq = (sndq c) `bufferchain_append` d1 } + --if (bufc_length (sndq $ cb_snd $ sock) == 0) then + -- modify_cb_rcv $ \c -> c { tt_delack = True } + -- else + tcp_output_all + if buf_len d2 == 0 then return (Just $ cont $ SockOK, d2) + else return (Nothing, d2) + + +process_connect :: IPAddr -> TCPAddr -> (SockRsp->t) -> HMonad t (Maybe t) +process_connect local addr cont = do + h <- get_host + m_port <- alloc_local_port + if m_port == Nothing then return $ Just $ cont $ SockError "cannot allocate local port" else do + let (Just port) = m_port + sock_id = SocketID (port, addr) + newiss = SeqLocal 1000 -- beginning iss. (todo: add more randomness) + request_r_scale' = 0 + rcv_wnd' = freebsd_so_rcvbuf + adv_mss = Just mssdflt + tf_req_tstmp' = False -- todo: change it + t_rttseg' = Just (ticks h, newiss) + let { newsock = initial_tcp_socket + { st = SYN_SENT + , cb_time = initial_cb_time + { tt_conn_est = Just (create_timer (clock h) tcptv_keep_init) + } + , cb_snd = initial_cb_snd + { tt_rexmt = start_tt_rexmtsyn 0 False (t_rttinf initial_cb_snd) (clock h) + , snd_una = newiss + , snd_nxt = newiss `seq_plus` 1 + , snd_max = newiss `seq_plus` 1 + , t_rttseg = t_rttseg' + } + , cb_rcv = initial_cb_rcv + { rcv_wnd = rcv_wnd' + , rcv_adv = (rcv_nxt initial_cb_rcv) `seq_plus` rcv_wnd' + , tf_rxwin0sent = (rcv_wnd' == 0) + } + , cb = initial_cb_misc + { local_addr = TCPAddr (local,port) + , remote_addr = addr + , self_id=sock_id + , cantsndmore = False + , cantrcvmore = False + , iss = newiss + , request_r_scale = Just request_r_scale' + , t_advmss = adv_mss + , tf_req_tstmp = tf_req_tstmp' + } + }} + insert_sock sock_id newsock + emit_segs $ [TCPMessage $ make_syn_segment (clock h) newsock (ticks h)] + return $ Nothing + +wakeup_connect :: (SockRsp->t) -> SMonad t (Maybe t) +wakeup_connect cont = do + sock <- get_sock + if st sock == SYN_SENT + then return Nothing + else return $ Just $ cont $ SockNew $ self_id $ cb sock diff --git a/src/Network/TCP/Pure.hs b/src/Network/TCP/Pure.hs new file mode 100644 index 0000000..18db2b1 --- /dev/null +++ b/src/Network/TCP/Pure.hs @@ -0,0 +1,116 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.Pure +( module Network.TCP.Type.Syscall +, Host, IPMessage, Time +, tcp_init_host -- :: Time -> [Port] -> Host t +, tcp_user_req -- :: (SockReq, SockRsp -> t) -> Host t -> (Host t, Maybe t) +, tcp_user_rsp -- :: Host t -> (Host t, [t]) +, tcp_packet_in -- :: IPMessage -> Host t -> Host t +, tcp_packet_out -- :: Host t -> (Host t, [IPMessage]) +, tcp_timer -- :: Time -> Host t -> Host t +, tcp_timer_check -- :: Time -> Host t -> IO (Host t) +) +where + +import Network.TCP.Type.Base +import Network.TCP.Type.Datagram +import Network.TCP.Type.Socket +import Network.TCP.Type.Syscall +-- import TCP.Impl.PacketIO + +import Network.TCP.LTS.In +import Network.TCP.LTS.Out +import Network.TCP.LTS.Time +import Network.TCP.LTS.User + +import Hans.Layer.Tcp.Monad + +import Foreign.C +import Data.Map as Map + +tcp_init_host :: Time -> [Port] -> Host t +tcp_init_host curr_time ports = Host + { output_queue = [] + , sock_map = Map.empty + , clock = curr_time + , ticks = Timestamp 0 + , next_timers = (curr_time + 200*1000,curr_time + 500*1000) + , ready_list = [] + , local_ports = ports + } + +tcp_user_req :: (SockReq, SockRsp -> t) -> Host t -> (Host t, Maybe t) +tcp_user_req req h = + runHMonad_ (tcp_process_user_request req) h + +tcp_user_rsp :: Host t -> (Host t, [t]) +tcp_user_rsp h = ( h { ready_list = [] }, ready_list h ) + +tcp_timer :: Time -> Host t -> Host t +tcp_timer tm = runHMonad $ do + h <- get_host + let oldclock = clock h + newclock = tm + newtick = (ticks h) `seq_plus` (fromIntegral (newclock - oldclock) `div` 10000) + put_host $ h { clock = newclock, ticks=newtick } + tcp_update_timers + +tcp_packet_in :: IPMessage -> Host t -> Host t +tcp_packet_in (TCPMessage seg) = + runHMonad $ tcp_deliver_in_packet seg + +tcp_packet_out :: Host t -> (Host t, [IPMessage]) +tcp_packet_out h = ( h { output_queue = [] }, output_queue h ) + +tcp_timer_check :: Time -> Host t -> IO (Host t) +tcp_timer_check curr_time h = do + let deadline = min (fst $ next_timers h) (snd $ next_timers h) + error :: Integer = (fromIntegral curr_time) - (fromIntegral $ deadline) + percent = fromInteger error / 100000.0 * 100 + + if (curr_time > deadline +250*1000) then do + putStrLn $ "Warning: too slow to maintain 200ms/500ms timer ticks, distance (should be within 200000us) = "++ + (show $ curr_time - deadline) ++ "us" + return $ h { next_timers= (curr_time,curr_time) } + else return h + +{-# INLINE tcp_init_host #-} +{-# INLINE tcp_packet_in #-} +{-# INLINE tcp_packet_out #-} +{-# INLINE tcp_timer #-} +{-# INLINE tcp_user_req #-} +{-# INLINE tcp_user_rsp #-} +{-# INLINE tcp_timer_check #-} diff --git a/src/Network/TCP/Type/Base.hs b/src/Network/TCP/Type/Base.hs new file mode 100644 index 0000000..60d85e8 --- /dev/null +++ b/src/Network/TCP/Type/Base.hs @@ -0,0 +1,271 @@ +{-# OPTIONS_GHC -fglasgow-exts #-} + +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.Type.Base where + +import Data.Time.Clock.POSIX (POSIXTime,getPOSIXTime) +import Foreign +import Foreign.C +import System.IO.Unsafe +import Control.Exception +import qualified Data.ByteString as S +import qualified Data.ByteString.Lazy as L + + +to_Int x = (fromIntegral x)::Int +to_Int8 x = (fromIntegral x)::Int8 +to_Int16 x = (fromIntegral x)::Int16 +to_Int32 x = (fromIntegral x)::Int32 +to_Int64 x = (fromIntegral x)::Int64 + +to_Word x = (fromIntegral x)::Word +to_Word8 x = (fromIntegral x)::Word8 +to_Word16 x = (fromIntegral x)::Word16 +to_Word32 x = (fromIntegral x)::Word32 +to_Word64 x = (fromIntegral x)::Word64 + + +{-# INLINE to_Int #-} +{-# INLINE to_Int8 #-} +{-# INLINE to_Int16 #-} +{-# INLINE to_Int32 #-} +{-# INLINE to_Int64 #-} +{-# INLINE to_Word #-} +{-# INLINE to_Word8 #-} +{-# INLINE to_Word16 #-} +{-# INLINE to_Word32 #-} +{-# INLINE to_Word64 #-} + +-- Port numbers, IP addresses + +type Port = Word16 +newtype IPAddr = IPAddr Word32 deriving (Eq,Ord) +newtype TCPAddr = TCPAddr (IPAddr, Port) deriving (Eq,Ord) +newtype SocketID = SocketID (Port, TCPAddr) deriving (Eq,Ord,Show) + +instance Show IPAddr where + show (IPAddr w) = (show $ w .&. 255) ++ "." ++ + (show $ (w `shiftR` 8) .&. 255) ++ "." ++ + (show $ (w `shiftR` 16) .&. 255) ++ "." ++ + (show $ (w `shiftR` 24) .&. 255) +instance Show TCPAddr where + show (TCPAddr (ip,pt)) = (show ip) ++ ":" ++ (show pt) + + +get_ip :: TCPAddr -> IPAddr +get_ip (TCPAddr (i,p)) = i + +get_port :: TCPAddr -> Port +get_port (TCPAddr (i,p)) = p + +get_remote_addr :: SocketID -> TCPAddr +get_remote_addr (SocketID (p,a)) = a + +get_local_port :: SocketID -> Port +get_local_port (SocketID (p,a)) = p + +{-# INLINE get_ip #-} +{-# INLINE get_port #-} +{-# INLINE get_remote_addr #-} +{-# INLINE get_local_port #-} + +-- TCP Sequence numbers + +class (Eq a) => Seq32 a where + seq_val :: a -> Word32 + seq_lt :: a -> a -> Bool + seq_leq :: a -> a -> Bool + seq_gt :: a -> a -> Bool + seq_geq :: a -> a -> Bool + seq_plus :: (Integral n) => a -> n -> a + seq_minus :: (Integral n) => a -> n -> a + seq_diff :: (Integral n) => a -> a -> n + +instance Seq32 Word32 where + seq_val w = w + seq_lt x y = (to_Int32 (x-y)) < 0 + seq_leq x y = (to_Int32 (x-y)) <= 0 + seq_gt x y = (to_Int32 (x-y)) > 0 + seq_geq x y = (to_Int32 (x-y)) >= 0 + seq_plus s i = assert (i>=0) $ s + (to_Word32 i) + seq_minus s i = assert (i>=0) $ s - (to_Word32 i) + seq_diff s t = let res=fromIntegral $ to_Int32 (s-t) in assert (res>=0) res + {-# INLINE seq_val #-} + {-# INLINE seq_lt #-} + {-# INLINE seq_leq #-} + {-# INLINE seq_gt #-} + {-# INLINE seq_geq #-} + {-# INLINE seq_plus #-} + {-# INLINE seq_minus #-} + {-# INLINE seq_diff #-} + +newtype SeqLocal = SeqLocal Word32 deriving (Eq,Show,Seq32) +newtype SeqForeign = SeqForeign Word32 deriving (Eq,Show,Seq32) +newtype Timestamp = Timestamp Word32 deriving (Eq,Show,Seq32) + +instance Ord SeqLocal where + (<) = seq_lt + (>) = seq_gt + (<=) = seq_leq + (>=) = seq_geq + {-# INLINE (<) #-} + {-# INLINE (>) #-} + {-# INLINE (<=) #-} + {-# INLINE (>=) #-} +instance Ord SeqForeign where + (<) = seq_lt + (>) = seq_gt + (<=) = seq_leq + (>=) = seq_geq + {-# INLINE (<) #-} + {-# INLINE (>) #-} + {-# INLINE (<=) #-} + {-# INLINE (>=) #-} +instance Ord Timestamp where + (<) = seq_lt + (>) = seq_gt + (<=) = seq_leq + (>=) = seq_geq + {-# INLINE (<) #-} + {-# INLINE (>) #-} + {-# INLINE (<=) #-} + {-# INLINE (>=) #-} + +seq_flip_ltof (SeqLocal w) = SeqForeign w +seq_flip_ftol (SeqForeign w) = SeqLocal w + +fseq_val :: SeqForeign -> Word32 +fseq_val (SeqForeign w32) = w32 + +{-# INLINE seq_flip_ltof #-} +{-# INLINE seq_flip_ftol #-} + + +-- | Clock time, in microseconds. +type Time = Int64 + +seconds_to_time :: RealFrac a => a -> Time +seconds_to_time f = round (f * 1000*1000) + +{-# INLINE seconds_to_time #-} + +get_current_time :: IO Time +get_current_time = posixtime_to_time `fmap` getPOSIXTime + +posixtime_to_time :: POSIXTime -> Time +posixtime_to_time = seconds_to_time . toRational + +--------------------------------------------------------------------------- + +type Buffer = S.ByteString + +buf_len :: Buffer -> Int +buf_len = S.length + +buffer_ok :: Buffer -> Bool +buffer_ok _ = True + +buffer_empty :: Buffer +buffer_empty = S.empty + +buffer_to_string :: Buffer -> IO String +buffer_to_string = return . map (toEnum . fromEnum) . S.unpack + +string_to_buffer :: String -> IO Buffer +string_to_buffer = return . S.pack . map (toEnum . fromEnum) + +buffer_split :: Int -> Buffer -> (Buffer,Buffer) +buffer_split = S.splitAt + +buffer_take = S.take +buffer_drop = S.drop + +buffer_merge :: Buffer -> Buffer -> [Buffer] +buffer_merge bs1 bs2 + | S.length bs1 == 0 = [bs2] + | S.length bs2 == 0 = [bs1] + | otherwise = [bs1,bs2] + + +type BufferChain = L.ByteString + +bufc_length :: BufferChain -> Int +bufc_length = fromIntegral . L.length + +bufferchain_empty = L.empty +bufferchain_singleton b + | S.null b = L.empty + | otherwise = L.fromChunks [b] + +bufferchain_add bs bc = bufferchain_singleton bs `L.append` bc + +bufferchain_get :: BufferChain -> Int -> BufferChain +bufferchain_get bc ix = L.take 1 (L.drop (fromIntegral ix) bc) + +bufferchain_append bc bs = bc `L.append` bufferchain_singleton bs + +bufferchain_concat :: BufferChain -> BufferChain -> BufferChain +bufferchain_concat = L.append + +bufferchain_head :: BufferChain -> Buffer +bufferchain_head = head . L.toChunks + +bufferchain_tail :: BufferChain -> BufferChain +bufferchain_tail = L.fromChunks . tail . L.toChunks + +bufferchain_take :: Int -> BufferChain -> BufferChain +bufferchain_take = L.take . fromIntegral + +bufferchain_drop :: Int -> BufferChain -> BufferChain +bufferchain_drop = L.drop . fromIntegral + +bufferchain_split_at :: Int -> BufferChain -> (BufferChain,BufferChain) +bufferchain_split_at = L.splitAt . fromIntegral + +bufferchain_collapse :: BufferChain -> IO Buffer +bufferchain_collapse = return . S.concat . L.toChunks + +-- bufferchain_output bc@(BufferChain lst len) (ptr::Ptr CChar) = +-- copybuf ptr lst +-- where copybuf ptrDest [] = return () +-- copybuf ptrDest (x:xs) = +-- withForeignPtr (buf_ptr x) +-- (\ptrSrc -> do +-- copyArray ptrDest (ptrSrc `plusPtr` (buf_offset x)) (buf_len x) +-- copybuf (ptrDest `plusPtr` (buf_len x)) xs +-- ) + +bufferchain_ok :: BufferChain -> Bool +bufferchain_ok _ = True diff --git a/src/Network/TCP/Type/Datagram.hs b/src/Network/TCP/Type/Datagram.hs new file mode 100644 index 0000000..124cd59 --- /dev/null +++ b/src/Network/TCP/Type/Datagram.hs @@ -0,0 +1,217 @@ +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.Type.Datagram +( TCPSegment (..) +, UDPDatagram (..) +, Protocol (..) +, ICMPType (..) +, ICMPDatagram (..) +, IPMessage (..) +, tcp_ws +, set_tcp_ws +, tcp_mss +, set_tcp_mss +, tcp_ts +, set_tcp_ts +, tcp_seq +, tcp_ack +, tcp_URG +, tcp_ACK +, tcp_PSH +, tcp_RST +, tcp_SYN +, tcp_FIN +, tcp_win +, tcp_urp +, mkTCPSegment +, mkTCPSegment' +) +where + +import Hans.Address.IP4 (IP4,convertToWord32) +import Hans.Message.Tcp + (TcpHeader(..),TcpPacket(..),TcpPort(..),TcpAckNum(..),TcpSeqNum(..) + ,findTcpOption,setTcpOption,TcpOptionTag(..),TcpOption(..)) + +import Network.TCP.Type.Base + +data TCPSegment = TCPSegment + { tcp_src :: !TCPAddr + , tcp_dst :: !TCPAddr + , tcp_header :: !TcpHeader + , tcp_data :: !BufferChain + } + +mkTCPSegment :: IP4 -> IP4 -> TcpPacket -> TCPSegment +mkTCPSegment src dst (TcpPacket hdr body) = TCPSegment + { tcp_src = TCPAddr (IPAddr (convertToWord32 src),srcP) + , tcp_dst = TCPAddr (IPAddr (convertToWord32 dst),dstP) + , tcp_header = hdr + , tcp_data = bufferchain_singleton body + } + where + TcpPort srcP = tcpSourcePort hdr + TcpPort dstP = tcpDestPort hdr + +mkTCPSegment' :: TCPAddr -> TCPAddr -> TcpHeader -> BufferChain -> TCPSegment +mkTCPSegment' s@(TCPAddr (_, srcP)) d@(TCPAddr (_, dstP)) hdr body = + TCPSegment + { tcp_src = s + , tcp_dst = d + , tcp_header = hdr + { tcpSourcePort = TcpPort srcP + , tcpDestPort = TcpPort dstP + } + , tcp_data = body + } + where + +tcp_seq = getSeqNum . tcpSeqNum . tcp_header +tcp_ack = getAckNum . tcpAckNum . tcp_header +tcp_URG = tcpUrg . tcp_header +tcp_ACK = tcpAck . tcp_header +tcp_PSH = tcpPsh . tcp_header +tcp_RST = tcpRst . tcp_header +tcp_SYN = tcpSyn . tcp_header +tcp_FIN = tcpFin . tcp_header +tcp_win = tcpWindow . tcp_header +tcp_urp = tcpUrgentPointer . tcp_header + +tcp_ws :: TCPSegment -> Maybe Int +tcp_ws = fmap prj . findTcpOption OptTagWindowScaling . tcp_header + where + prj (OptWindowScaling ws) = fromIntegral ws + +set_tcp_ws :: Maybe Int -> TcpHeader -> TcpHeader +set_tcp_ws Nothing = id +set_tcp_ws (Just ws) = setTcpOption (OptWindowScaling (fromIntegral ws)) + +tcp_mss :: TCPSegment -> Maybe Int +tcp_mss = fmap prj . findTcpOption OptTagMaxSegmentSize . tcp_header + where + prj (OptMaxSegmentSize mss) = fromIntegral mss + +set_tcp_mss :: Maybe Int -> TcpHeader -> TcpHeader +set_tcp_mss Nothing = id +set_tcp_mss (Just mss) = setTcpOption (OptMaxSegmentSize (fromIntegral mss)) + +tcp_ts :: TCPSegment -> Maybe (Timestamp,Timestamp) +tcp_ts = fmap prj . findTcpOption OptTagTimestamp . tcp_header + where + prj (OptTimestamp v r) = (Timestamp v, Timestamp r) + +set_tcp_ts :: Maybe (Timestamp,Timestamp) -> TcpHeader -> TcpHeader +set_tcp_ts Nothing = id +set_tcp_ts (Just (Timestamp v, Timestamp r)) = setTcpOption (OptTimestamp v r) + +{- +data TCPSegment = TCPSegment + { tcp_src :: !TCPAddr + , tcp_dst :: !TCPAddr + , tcp_seq :: !SeqLocal + , tcp_ack :: !SeqForeign + , tcp_URG :: !Bool + , tcp_ACK :: !Bool + , tcp_PSH :: !Bool + , tcp_RST :: !Bool + , tcp_SYN :: !Bool + , tcp_FIN :: !Bool + , tcp_win :: !Int + , tcp_urp :: !Int + , tcp_data :: !BufferChain + -- option: window scaling + , tcp_ws :: !(Maybe Int) + -- option: max segment size + , tcp_mss :: !(Maybe Int) + -- option: RFC1323 + , tcp_ts :: !(Maybe (Timestamp, Timestamp)) + } -} + +instance Show TCPSegment where + show seg = + let part1 = + if (get_port $ tcp_src seg) > 9999 || (get_port $ tcp_dst seg) == 8888 then + "<==" ++ (show $ tcp_src seg) + ++ " ack=" ++(show $ seq_val $ tcp_ack seg) + ++ " seq=" ++(show $ seq_val $ tcp_seq seg) + else + "==>" ++ (show $ tcp_dst seg) + ++ " seq=" ++(show $ seq_val $ tcp_seq seg) + ++ " ack=" ++(show $ seq_val $ tcp_ack seg) + in + let part2 = + " ["++ + (if tcp_URG seg then " URG(urp="++ (show $ tcp_urp seg) ++ ")" else "") ++ + (if tcp_SYN seg then " SYN" else "") ++ + (if tcp_FIN seg then " FIN" else "") ++ + (if tcp_RST seg then " RST" else "") ++ + (if tcp_ACK seg then " ACK" else "") ++ + (if tcp_PSH seg then " PSH" else "") ++ + " ]" + in + part1 + ++ " WIN=" ++(show $ tcp_win seg) + ++ " LEN=" ++(show $ bufc_length $ tcp_data seg) + ++ part2 + +data UDPDatagram = UDPDatagram + { udp_src :: TCPAddr + , udp_dst :: TCPAddr + , udp_data :: [Char] + } deriving (Show, Eq) + +data Protocol = PROTO_TCP | PROTO_UDP deriving (Show, Eq) + +data ICMPType = + ICMP_UNREACH Int + | ICMP_SOURCE_QUENCE Int + | ICMP_REDIRECT Int + | ICMP_TIME_EXCEEDED Int + | ICMP_PARAMPROB Int + deriving (Show, Eq) + +data ICMPDatagram = ICMPDatagram + { icmp_send :: IPAddr + , icmp_recv :: IPAddr + , icmp_src :: Maybe TCPAddr + , icmp_dst :: Maybe TCPAddr + , icmp_proto :: Protocol + , icmp_seq :: Maybe SeqLocal + , icmp_t :: ICMPType + } deriving (Show, Eq) + +data IPMessage = TCPMessage !TCPSegment + | ICMPMessage !ICMPDatagram + | UDPMessage !UDPDatagram + deriving (Show) diff --git a/src/Network/TCP/Type/Socket.hs b/src/Network/TCP/Type/Socket.hs new file mode 100644 index 0000000..9131203 --- /dev/null +++ b/src/Network/TCP/Type/Socket.hs @@ -0,0 +1,199 @@ +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.Type.Socket +where + +import Network.TCP.Type.Base +import Network.TCP.Type.Timer +import Network.TCP.Type.Datagram +import Network.TCP.Type.Syscall +import Data.Map as Map + +data TCPState = CLOSED + | LISTEN + | SYN_SENT + | SYN_RECEIVED + | ESTABLISHED + | CLOSE_WAIT + | FIN_WAIT_1 + | FIN_WAIT_2 + | CLOSING + | LAST_ACK + | TIME_WAIT + deriving (Show,Eq) + +data TCPReassSegment = TCPReassSegment + { trs_seq :: !SeqForeign + , trs_FIN :: !Bool + , trs_data :: !BufferChain + } deriving (Show) + +data RexmtMode = RexmtSyn + | Rexmt + | Persist + deriving (Show,Eq) + +data Rttinf = Rttinf + { t_rttupdated :: !Int + , tf_srtt_valid :: !Bool + , t_srtt :: !Time + , t_rttvar :: !Time + , t_rttmin :: !Time + , t_lastrtt :: !Time + , t_lastshift :: !Int + , t_wassyn :: !Bool + } deriving (Show) + + +data IOBC = NO_OOBDATA | OOBDATA Buffer | HAD_OOBDATA deriving (Show) + +data SocketListen = SocketListen + { lis_q0 :: ![SocketID] -- q0 + , lis_q :: ![SocketID] -- q + , lis_qlimit :: !Int + } deriving (Show,Eq) + +data TCBTiming = TCBTiming + { tt_keep :: !(Maybe Time) + , tt_conn_est :: !(Maybe Time) + , tt_fin_wait_2 :: !(Maybe Time) + , tt_2msl :: !(Maybe Time) + , t_idletime :: !Time + , ts_recent :: !(TimeWindow Timestamp) + , t_badrxtwin :: !(TimeWindow ()) + } deriving (Show) +data TCBSending = TCBSending + { sndq :: !BufferChain + , snd_una :: !SeqLocal + , snd_wnd :: !Int + , snd_wl1 :: !SeqForeign + , snd_wl2 :: !SeqLocal + , snd_cwnd :: !Int + , snd_nxt :: !SeqLocal + , snd_max :: !SeqLocal + , t_dupacks :: !Int + , t_rttinf :: !Rttinf + , t_rttseg :: !(Maybe (Timestamp, SeqLocal)) + , tt_rexmt :: !(Maybe (Timed (RexmtMode, Int))) + } deriving (Show) +data TCBReceiving = TCBReceiving + { last_ack_sent :: !SeqForeign + , tf_rxwin0sent :: !Bool + , tf_shouldacknow :: !Bool + , tt_delack :: !Bool + , rcv_adv :: !SeqForeign + , rcv_wnd :: !Int + , rcv_nxt :: !SeqForeign + , rcvq :: !BufferChain + , t_segq :: ![TCPReassSegment] + } deriving (Show) +data TCBMisc = TCBMisc + { -- retransmission + snd_ssthresh :: !Int + , snd_cwnd_prev :: !Int + , snd_ssthresh_prev :: !Int + , snd_recover :: !SeqLocal + -- some tags + , cantsndmore :: !Bool + , cantrcvmore :: !Bool + , bsd_cantconnect :: !Bool -- not very useful... + -- initialization parameters + , self_id :: !SocketID + , parent_id :: !SocketID + , local_addr :: !TCPAddr + , remote_addr :: !TCPAddr + , t_maxseg :: !Int + , t_advmss :: !(Maybe Int) + , tf_doing_ws :: !Bool + , tf_doing_tstmp :: !Bool + , tf_req_tstmp :: !Bool + , request_r_scale :: !(Maybe Int) + , snd_scale :: !Int + , rcv_scale :: !Int + , iss :: !SeqLocal + , irs :: !SeqForeign + -- other things i don't use for the moment + , sndurp :: !(Maybe Int) + , rcvurp :: !(Maybe Int) + , iobc :: !IOBC + , rcv_up :: !SeqForeign + , tf_needfin :: !Bool + } deriving (Show) + +data TCPSocket threadt = TCPSocket + { st :: !TCPState + , cb_time :: !TCBTiming + , cb_snd :: !TCBSending + , cb_rcv :: !TCBReceiving + , cb :: !TCBMisc + , sock_listen :: !SocketListen + -- suspended commands (threads) + , waiting_list :: ![(SockReq, SockRsp -> threadt)] + } + +instance Show (TCPSocket t) where + show (TCPSocket s cb1 cb2 cb3 cb4 lis wl) = + "TCPSocket state ="++(show s) ++ "\n" ++ + " " ++ (show cb1) ++ "\n" ++ + " " ++ (show cb2) ++ "\n" ++ + " " ++ (show cb3) ++ "\n" ++ + " " ++ (show cb4) ++ "\n" ++ + " " ++ (show lis) ++ "\n" ++ + " waiting: " ++ (show $ length wl) + +data Host threadt = Host + { sock_map :: !(Map SocketID (TCPSocket threadt)) + , output_queue :: ![IPMessage] + , ready_list :: ![threadt] + , ticks :: !Timestamp + , clock :: !Time + , next_timers :: !(Time,Time) -- fast timer, slow timer + , local_ports :: ![Port] + } + +empty_host :: Host t +empty_host = Host + { sock_map = Map.empty + , output_queue = [] + , ready_list = [] + , ticks = Timestamp 0 + , clock = 0 + , next_timers = (0,0) + , local_ports = [0..65535] + } + +update_host_time :: Time -> Host t -> Host t +update_host_time now h = h + { clock = now + } diff --git a/src/Network/TCP/Type/Syscall.hs b/src/Network/TCP/Type/Syscall.hs new file mode 100644 index 0000000..c4ca2c4 --- /dev/null +++ b/src/Network/TCP/Type/Syscall.hs @@ -0,0 +1,57 @@ +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.Type.Syscall +( SocketID +, TCPAddr +, Buffer (..) +, SockReq (..) +, SockRsp (..) +) + +where + +import Network.TCP.Type.Base + +data SockReq = SockConnect !IPAddr !TCPAddr -- create a client socket and connect (return: SockNew sock) + | SockListen !Port -- create a listening Socket (return: SockNew sock) + | SockAccept !SocketID -- accept connection from a listening socket (return: SockNew sock) + | SockSend !SocketID !Buffer -- (return: SockOK) + | SockRecv !SocketID -- (return: SockData) + | SockClose !SocketID -- (return: SockOK) + +data SockRsp = SockOK + | SockError !String + | SockNew !SocketID + | SockData !Buffer + deriving Show diff --git a/src/Network/TCP/Type/Timer.hs b/src/Network/TCP/Type/Timer.hs new file mode 100644 index 0000000..f365869 --- /dev/null +++ b/src/Network/TCP/Type/Timer.hs @@ -0,0 +1,68 @@ +{-- +Copyright (c) 2006, Peng Li + 2006, Stephan A. Zdancewic +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright owners nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--} + +module Network.TCP.Type.Timer + +where + +import Network.TCP.Type.Base + +data Timed a = Timed { timed_val :: a + , timed_exp :: Time + } deriving (Show, Eq) + +timed_expires :: Time -> Timed a -> Bool +timed_expires t (Timed x tm) = t >= tm + +timer_expires :: Time -> Time -> Bool +timer_expires t tm = t >= tm + +maybe_timed_expires :: Time -> Maybe (Timed a) -> Bool +maybe_timed_expires _ Nothing = False +maybe_timed_expires curr_time (Just t) = timed_expires curr_time t + +maybe_timer_expires :: Time -> Maybe Time -> Bool +maybe_timer_expires _ Nothing = False +maybe_timer_expires curr_time (Just t) = curr_time >= t + +type TimeWindow a = Maybe (Timed a) + +timewindow_open :: Time -> TimeWindow a -> Bool +timewindow_open = maybe_timed_expires + +timewindow_val :: Time -> TimeWindow a -> Maybe a +timewindow_val t Nothing = Nothing +timewindow_val t (Just tmd) = + if timed_expires t tmd then Nothing else Just (timed_val tmd) + +