Skip to content

Commit

Permalink
add ipv6 support from msoulier#98
Browse files Browse the repository at this point in the history
  • Loading branch information
9001 committed Feb 17, 2024
1 parent bef1ea2 commit b3e3c39
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 11 deletions.
5 changes: 4 additions & 1 deletion partftpy/TftpClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


import logging
import socket
import types

from .TftpContexts import TftpContextClientDownload, TftpContextClientUpload
Expand All @@ -22,14 +23,15 @@ class TftpClient(TftpSession):
download can be initiated via the download() method, or an upload via the
upload() method."""

def __init__(self, host, port=69, options={}, localip=""):
def __init__(self, host, port=69, options={}, localip="", af_family=socket.AF_INET):
TftpSession.__init__(self)
self.context = None
self.host = host
self.iport = port
self.filename = None
self.options = options
self.localip = localip
self.af_family = af_family
if "blksize" in self.options:
size = self.options["blksize"]
tftpassert(int == type(size), "blksize must be an int")
Expand Down Expand Up @@ -71,6 +73,7 @@ def download(
timeout,
retries=retries,
localip=self.localip,
af_family=self.af_family,
ports=ports
)
self.context.start()
Expand Down
23 changes: 17 additions & 6 deletions partftpy/TftpContexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,15 @@ def add_dup(self, pkt):
class TftpContext(object):
"""The base class of the contexts."""

def __init__(self, host, port, timeout, retries=DEF_TIMEOUT_RETRIES, localip="", ports=None):
def __init__(self, host, port, timeout, retries=DEF_TIMEOUT_RETRIES, localip="", af_family=socket.AF_INET, ports=None):
"""Constructor for the base context, setting shared instance
variables."""
self.file_to_transfer = None
self.fileobj = None
self.options = None
self.packethook = None
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.af_family = af_family
self.sock = socket.socket(af_family, socket.SOCK_DGRAM)
for n in ports or [0]:
try:
if localip != "":
Expand All @@ -110,6 +111,7 @@ def __init__(self, host, port, timeout, retries=DEF_TIMEOUT_RETRIES, localip="",
self.next_block = 0
self.factory = TftpPacketFactory()
# Note, setting the host will also set self.address, as it's a property.
self.address = ""
self.host = host
self.port = port
# The port associated with the TID
Expand Down Expand Up @@ -170,7 +172,12 @@ def sethost(self, host):
of the host that is set.
"""
self.__host = host
self.address = socket.gethostbyname(host)
if self.af_family == socket.AF_INET:
self.address = socket.gethostbyname(host)
elif self.af_family == socket.AF_INET6:
self.address = socket.getaddrinfo(host, 0)[0][4][0]
else:
raise ValueError("af_family is not supported")

host = property(gethost, sethost)

Expand All @@ -191,7 +198,9 @@ def cycle(self):
something, and dispatch appropriate action to that response.
"""
try:
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
buffer, rai = self.sock.recvfrom(MAX_BLKSIZE)
raddress = rai[0]
rport = rai[1]
except socket.timeout:
log.warning("Timeout waiting for traffic, retrying...")
raise TftpTimeout("Timed-out waiting for traffic")
Expand Down Expand Up @@ -247,9 +256,10 @@ def __init__(
dyn_file_func=None,
upload_open=None,
retries=DEF_TIMEOUT_RETRIES,
af_family=socket.AF_INET,
ports=None,
):
TftpContext.__init__(self, host, port, timeout, retries, ports=ports)
TftpContext.__init__(self, host, port, timeout, retries, af_family=af_family, ports=ports)
# At this point we have no idea if this is a download or an upload. We
# need to let the start state determine that.
self.state = TftpStateServerStart(self)
Expand Down Expand Up @@ -305,9 +315,10 @@ def __init__(
timeout,
retries=DEF_TIMEOUT_RETRIES,
localip="",
af_family=socket.AF_INET,
ports=None,
):
TftpContext.__init__(self, host, port, timeout, retries, localip, ports)
TftpContext.__init__(self, host, port, timeout, retries, localip, af_family, ports)
self.file_to_transfer = filename
self.options = options
self.packethook = packethook
Expand Down
13 changes: 9 additions & 4 deletions partftpy/TftpServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def listen(
listenport=DEF_TFTP_PORT,
timeout=SOCK_TIMEOUT,
retries=DEF_TIMEOUT_RETRIES,
af_family=socket.AF_INET,
ports=None,
):
"""Start a server listening on the supplied interface and port. This
Expand All @@ -93,12 +94,13 @@ def listen(
tftp_factory = TftpPacketFactory()

listenip = listenip or "0.0.0.0"
log.info("listening @ %s:%s", listenip, listenport)
ip_str = listenip if af_family == socket.AF_INET else "[%s]" % (listenip,)
log.info("listening @ %s:%s", ip_str, listenport)
try:
# FIXME - sockets should be non-blocking
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock = socket.socket(af_family, socket.SOCK_DGRAM)
self.sock.bind((listenip, listenport))
_, self.listenport = self.sock.getsockname()
self.listenport = self.sock.getsockname()[1]
except OSError as err:
# Reraise it for now.
raise err
Expand Down Expand Up @@ -153,7 +155,9 @@ def listen(
# Is the traffic on the main server socket? ie. new session?
if readysock == self.sock:
log.debug("Data ready on our main socket")
buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE)
buffer, rai = self.sock.recvfrom(MAX_BLKSIZE)
raddress = rai[0]
rport = rai[1]

log.debug("Read %d bytes", len(buffer))

Expand All @@ -177,6 +181,7 @@ def listen(
self.dyn_file_func,
self.upload_open,
retries=retries,
af_family=af_family,
ports=ports,
)
try:
Expand Down

0 comments on commit b3e3c39

Please sign in to comment.