Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Add -N (--auto-nets) option for auto-discovering subnets.

Now if you do

	./sshuttle -Nr username@myservername

It'll automatically route the "local" subnets (ie., stuff in the routing
table) from myservername.  This is (hopefully a reasonable default setting
for most people.
  • Loading branch information...
commit 7043195043d5a1885235833804ae7f90404e4a46 1 parent 77935bd
apenwarr authored May 07, 2010
2  assembler.py
@@ -7,7 +7,7 @@
7 7
     if name:
8 8
         nbytes = int(sys.stdin.readline())
9 9
         if verbosity >= 2:
10  
-            sys.stderr.write('remote assembling %r (%d bytes)\n'
  10
+            sys.stderr.write('server: assembling %r (%d bytes)\n'
11 11
                              % (name, nbytes))
12 12
         content = z.decompress(sys.stdin.read(nbytes))
13 13
         exec compile(content, name, "exec")
32  client.py
@@ -22,11 +22,11 @@ def original_dst(sock):
22 22
 class FirewallClient:
23 23
     def __init__(self, port, subnets):
24 24
         self.port = port
  25
+        self.auto_nets = []
25 26
         self.subnets = subnets
26  
-        subnets_str = ['%s/%d' % (ip,width) for ip,width in subnets]
27 27
         argvbase = ([sys.argv[0]] +
28 28
                     ['-v'] * (helpers.verbose or 0) +
29  
-                    ['--firewall', str(port)] + subnets_str)
  29
+                    ['--firewall', str(port)])
30 30
         argv_tries = [
31 31
             ['sudo'] + argvbase,
32 32
             ['su', '-c', ' '.join(argvbase)],
@@ -66,6 +66,9 @@ def check(self):
66 66
             raise Fatal('%r returned %d' % (self.argv, rv))
67 67
 
68 68
     def start(self):
  69
+        self.pfile.write('ROUTES\n')
  70
+        for (ip,width) in self.subnets+self.auto_nets:
  71
+            self.pfile.write('%s,%d\n' % (ip, width))
69 72
         self.pfile.write('GO\n')
70 73
         self.pfile.flush()
71 74
         line = self.pfile.readline()
@@ -80,7 +83,7 @@ def done(self):
80 83
             raise Fatal('cleanup: %r returned %d' % (self.argv, rv))
81 84
 
82 85
 
83  
-def _main(listener, fw, use_server, remotename):
  86
+def _main(listener, fw, use_server, remotename, auto_nets):
84 87
     handlers = []
85 88
     if use_server:
86 89
         if helpers.verbose >= 1:
@@ -102,9 +105,22 @@ def _main(listener, fw, use_server, remotename):
102 105
             raise Fatal('expected server init string %r; got %r'
103 106
                             % (expected, initstring))
104 107
 
105  
-    # we definitely want to do this *after* starting ssh, or we might end
106  
-    # up intercepting the ssh connection!
107  
-    fw.start()
  108
+    def onroutes(routestr):
  109
+        if auto_nets:
  110
+            for line in routestr.strip().split('\n'):
  111
+                (ip,width) = line.split(',', 1)
  112
+                fw.auto_nets.append((ip,int(width)))
  113
+
  114
+        # we definitely want to do this *after* starting ssh, or we might end
  115
+        # up intercepting the ssh connection!
  116
+        #
  117
+        # Moreover, now that we have the --auto-nets option, we have to wait
  118
+        # for the server to send us that message anyway.  Even if we haven't
  119
+        # set --auto-nets, we might as well wait for the message first, then
  120
+        # ignore its contents.
  121
+        mux.got_routes = None
  122
+        fw.start()
  123
+    mux.got_routes = onroutes
108 124
 
109 125
     def onaccept():
110 126
         sock,srcip = listener.accept()
@@ -149,7 +165,7 @@ def onaccept():
149 165
             mux.check_fullness()
150 166
 
151 167
 
152  
-def main(listenip, use_server, remotename, subnets):
  168
+def main(listenip, use_server, remotename, auto_nets, subnets):
153 169
     debug1('Starting sshuttle proxy.\n')
154 170
     listener = socket.socket()
155 171
     listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -179,6 +195,6 @@ def main(listenip, use_server, remotename, subnets):
179 195
     fw = FirewallClient(listenip[1], subnets)
180 196
     
181 197
     try:
182  
-        return _main(listener, fw, use_server, remotename)
  198
+        return _main(listener, fw, use_server, remotename, auto_nets)
183 199
     finally:
184 200
         fw.done()
19  firewall.py
@@ -140,7 +140,7 @@ def program_exists(name):
140 140
 # exit.  In case that fails, it's not the end of the world; future runs will
141 141
 # supercede it in the transproxy list, at least, so the leftover rules
142 142
 # are hopefully harmless.
143  
-def main(port, subnets):
  143
+def main(port):
144 144
     assert(port > 0)
145 145
     assert(port <= 65535)
146 146
 
@@ -173,8 +173,21 @@ def main(port, subnets):
173 173
     line = sys.stdin.readline(128)
174 174
     if not line:
175 175
         return  # parent died; nothing to do
176  
-    if line != 'GO\n':
177  
-        raise Fatal('firewall: expected GO but got %r' % line)
  176
+
  177
+    subnets = []
  178
+    if line != 'ROUTES\n':
  179
+        raise Fatal('firewall: expected ROUTES but got %r' % line)
  180
+    while 1:
  181
+        line = sys.stdin.readline(128)
  182
+        if not line:
  183
+            raise Fatal('firewall: expected route but got %r' % line)
  184
+        elif line == 'GO\n':
  185
+            break
  186
+        try:
  187
+            (ip,width) = line.strip().split(',', 1)
  188
+        except:
  189
+            raise Fatal('firewall: expected route or GO but got %r' % line)
  190
+        subnets.append((ip, int(width)))
178 191
     try:
179 192
         if line:
180 193
             debug1('firewall manager: starting transproxy.\n')
13  main.py
@@ -50,6 +50,7 @@ def parse_ipport(s):
50 50
 sshuttle --server
51 51
 --
52 52
 l,listen=  transproxy to this ip address and port number [default=0]
  53
+N,auto-nets automatically determine subnets to route
53 54
 r,remote=  ssh hostname (and optional username) of remote sshuttle server
54 55
 v,verbose  increase debug message verbosity
55 56
 noserver   don't use a separate server process (mostly for debugging)
@@ -65,19 +66,19 @@ def parse_ipport(s):
65 66
     if opt.server:
66 67
         sys.exit(server.main())
67 68
     elif opt.firewall:
68  
-        if len(extra) < 1:
69  
-            o.fatal('at least one argument expected')
70  
-        sys.exit(firewall.main(int(extra[0]),
71  
-                               parse_subnets(extra[1:])))
  69
+        if len(extra) != 1:
  70
+            o.fatal('exactly one argument expected')
  71
+        sys.exit(firewall.main(int(extra[0])))
72 72
     else:
73  
-        if len(extra) < 1:
74  
-            o.fatal('at least one subnet expected')
  73
+        if len(extra) < 1 and not opt.auto_nets:
  74
+            o.fatal('at least one subnet (or -N) expected')
75 75
         remotename = opt.remote
76 76
         if remotename == '' or remotename == '-':
77 77
             remotename = None
78 78
         sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'),
79 79
                              not opt.noserver,
80 80
                              remotename,
  81
+                             opt.auto_nets,
81 82
                              parse_subnets(extra)))
82 83
 except Fatal, e:
83 84
     log('fatal: %s\n' % e)
73  server.py
... ...
@@ -1,15 +1,83 @@
1  
-import struct, socket, select
  1
+import re, struct, socket, select, subprocess
2 2
 if not globals().get('skip_imports'):
3 3
     import ssnet, helpers
4 4
     from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
5 5
     from helpers import *
6 6
 
7 7
 
  8
+def _ipmatch(ipstr):
  9
+    if ipstr == 'default':
  10
+        ipstr = '0.0.0.0/0'
  11
+    m = re.match(r'^(\d+(\.\d+(\.\d+(\.\d+)?)?)?)(?:/(\d+))?$', ipstr)
  12
+    if m:
  13
+        g = m.groups()
  14
+        ips = g[0]
  15
+        width = int(g[4] or 32)
  16
+        if g[1] == None:
  17
+            ips += '.0.0.0'
  18
+            width = min(width, 8)
  19
+        elif g[2] == None:
  20
+            ips += '.0.0'
  21
+            width = min(width, 16)
  22
+        elif g[3] == None:
  23
+            ips += '.0'
  24
+            width = min(width, 24)
  25
+        return (struct.unpack('!I', socket.inet_aton(ips))[0], width)
  26
+
  27
+
  28
+def _ipstr(ip, width):
  29
+    if width >= 32:
  30
+        return ip
  31
+    else:
  32
+        return "%s/%d" % (ip, width)
  33
+
  34
+
  35
+def _maskbits(netmask):
  36
+    if not netmask:
  37
+        return 32
  38
+    for i in range(32):
  39
+        if netmask[0] & (1<<i):
  40
+            return 32-i
  41
+    return 0
  42
+
  43
+
  44
+def _list_routes():
  45
+    argv = ['netstat', '-rn']
  46
+    p = subprocess.Popen(argv, stdout=subprocess.PIPE)
  47
+    routes = []
  48
+    for line in p.stdout:
  49
+        cols = re.split(r'\s+', line)
  50
+        ipw = _ipmatch(cols[0])
  51
+        if not ipw:
  52
+            continue  # some lines won't be parseable; never mind
  53
+        maskw = _ipmatch(cols[2])  # linux only
  54
+        mask = _maskbits(maskw)   # returns 32 if maskw is null
  55
+        width = min(ipw[1], mask)
  56
+        ip = ipw[0] & (((1<<width)-1) << (32-width))
  57
+        routes.append((socket.inet_ntoa(struct.pack('!I', ip)), width))
  58
+    rv = p.wait()
  59
+    if rv != 0:
  60
+        raise Fatal('%r returned %d' % (argv, rv))
  61
+    return routes
  62
+
  63
+
  64
+def list_routes():
  65
+    for (ip,width) in _list_routes():
  66
+        if not ip.startswith('0.') and not ip.startswith('127.'):
  67
+            yield (ip,width)
  68
+        
  69
+
  70
+
8 71
 def main():
9 72
     if helpers.verbose >= 1:
10 73
         helpers.logprefix = ' s: '
11 74
     else:
12 75
         helpers.logprefix = 'server: '
  76
+
  77
+    routes = list(list_routes())
  78
+    debug1('available routes:\n')
  79
+    for r in routes:
  80
+        debug1('  %s/%d\n' % r)
13 81
         
14 82
     # synchronization header
15 83
     sys.stdout.write('SSHUTTLE0001')
@@ -21,6 +89,9 @@ def main():
21 89
               socket.fromfd(sys.stdout.fileno(),
22 90
                             socket.AF_INET, socket.SOCK_STREAM))
23 91
     handlers.append(mux)
  92
+    routepkt = ''.join('%s,%d\n' % r
  93
+                       for r in routes)
  94
+    mux.send(0, ssnet.CMD_ROUTES, routepkt)
24 95
 
25 96
     def new_channel(channel, data):
26 97
         (dstip,dstport) = data.split(',', 1)
14  ssnet.py
@@ -12,6 +12,7 @@
12 12
 CMD_CLOSE = 0x4204
13 13
 CMD_EOF = 0x4205
14 14
 CMD_DATA = 0x4206
  15
+CMD_ROUTES = 0x4207
15 16
 
16 17
 cmd_to_name = {
17 18
     CMD_EXIT: 'EXIT',
@@ -21,6 +22,7 @@
21 22
     CMD_CLOSE: 'CLOSE',
22 23
     CMD_EOF: 'EOF',
23 24
     CMD_DATA: 'DATA',
  25
+    CMD_ROUTES: 'ROUTES',
24 26
 }
25 27
     
26 28
 
@@ -220,7 +222,7 @@ def __init__(self, rsock, wsock):
220 222
         Handler.__init__(self, [rsock, wsock])
221 223
         self.rsock = rsock
222 224
         self.wsock = wsock
223  
-        self.new_channel = None
  225
+        self.new_channel = self.got_routes = None
224 226
         self.channels = {}
225 227
         self.chani = 0
226 228
         self.want = 0
@@ -259,12 +261,13 @@ def send(self, channel, cmd, data):
259 261
         p = struct.pack('!ccHHH', 'S', 'S', channel, cmd, len(data)) + data
260 262
         self.outbuf.append(p)
261 263
         debug2(' > channel=%d cmd=%s len=%d (fullness=%d)\n'
262  
-               % (channel, cmd_to_name[cmd], len(data), self.fullness))
  264
+               % (channel, cmd_to_name.get(cmd,hex(cmd)),
  265
+                  len(data), self.fullness))
263 266
         self.fullness += len(data)
264 267
 
265 268
     def got_packet(self, channel, cmd, data):
266 269
         debug2('<  channel=%d cmd=%s len=%d\n' 
267  
-               % (channel, cmd_to_name[cmd], len(data)))
  270
+               % (channel, cmd_to_name.get(cmd,hex(cmd)), len(data)))
268 271
         if cmd == CMD_PING:
269 272
             self.send(0, CMD_PONG, data)
270 273
         elif cmd == CMD_PONG:
@@ -277,6 +280,11 @@ def got_packet(self, channel, cmd, data):
277 280
             assert(not self.channels.get(channel))
278 281
             if self.new_channel:
279 282
                 self.new_channel(channel, data)
  283
+        elif cmd == CMD_ROUTES:
  284
+            if self.got_routes:
  285
+                self.got_routes(data)
  286
+            else:
  287
+                raise Exception('weird: got CMD_ROUTES without got_routes?')
280 288
         else:
281 289
             callback = self.channels[channel]
282 290
             callback(cmd, data)

0 notes on commit 7043195

Please sign in to comment.
Something went wrong with that request. Please try again.