Skip to content

Commit

Permalink
Continuing the great refactoring: POSTing to FileUploadRequest can pr…
Browse files Browse the repository at this point in the history
…int out the request body
  • Loading branch information
ajdavis committed Oct 16, 2011
1 parent a9eb9db commit 1c14472
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 172 deletions.
4 changes: 3 additions & 1 deletion syncsend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from twisted.internet import protocol, reactor
from twisted.web import http

from upload import HTTPFileUploadChannel

#
#class FileProxy:
# def __init__(self, content):
Expand Down Expand Up @@ -198,7 +200,7 @@ class SyncSendHttp(http.HTTPChannel):
requestFactory = SyncSendRequest

class SyncSendHttpFactory(http.HTTPFactory):
protocol = SyncSendHttp
protocol = HTTPFileUploadChannel

if __name__ == "__main__":
from twisted.internet import reactor
Expand Down
232 changes: 61 additions & 171 deletions upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from urllib import unquote

from twisted.web.http_headers import _DictHeaders, Headers
from twisted.web.http import protocol_version, datetimeToString, toChunk, RESPONSES, Request
from twisted.web.http import protocol_version, datetimeToString, toChunk, RESPONSES, Request, _IdentityTransferDecoder

class FileDownloadRequest(Request):
pass
Expand All @@ -37,7 +37,7 @@ class FileUploadRequest:
"""
A HTTP request for uploading files. Copied and adapted from twisted.web.http.Request.
Many simplifications over Twisted's Request: no queueing, no cookies, no authentication, no SSL
Many simplifications over Twisted's Request: no cookies, no authentication, no SSL
"""
implements(interfaces.IConsumer)

Expand All @@ -54,18 +54,29 @@ class FileUploadRequest:
args = None
path = None
content = None
queued = False
_disconnected = False

def __init__(self, channel):
def sendData(self, data):
print data # TODO

def __init__(self, channel, queued):
"""
@param channel: the channel we're connected to.
@param queued: are we in the request queue, or can we start writing to
the transport?
"""
self.notifications = []
self.channel = channel
self.queued = queued
self.requestHeaders = Headers()
self.received_cookies = {}
self.responseHeaders = Headers()
self.transport = self.channel.transport

if queued:
self.transport = StringTransport()
else:
self.transport = self.channel.transport


def __setattr__(self, name, value):
Expand Down Expand Up @@ -105,6 +116,33 @@ def _cleanup(self):
d.callback(None)
self.notifications = []

def noLongerQueued(self):
"""
Notify the object that it is no longer queued.
We start writing whatever data we have to the transport, etc.
This method is not intended for users.
"""
if not self.queued:
raise RuntimeError, "noLongerQueued() got called unnecessarily."

self.queued = 0

# set transport to real one and send any buffer data
data = self.transport.getvalue()
self.transport = self.channel.transport
if data:
self.transport.write(data)

# if we have producer, register it with transport
if (self.producer is not None) and not self.finished:
self.transport.registerProducer(self.producer, self.streamingProducer)

# if we're finished, clean up
if self.finished:
self._cleanup()

def gotLength(self, length):
"""
Called when HTTP channel got length of content in this request.
Expand Down Expand Up @@ -137,12 +175,18 @@ def registerProducer(self, producer, streaming):

self.streamingProducer = streaming
self.producer = producer
self.transport.registerProducer(producer, streaming)

if self.queued:
if streaming:
producer.pauseProducing()
else:
self.transport.registerProducer(producer, streaming)
def unregisterProducer(self):
"""
Unregister the producer.
"""
if not self.queued:
self.transport.unregisterProducer()
self.producer = None

# private http response methods
Expand Down Expand Up @@ -204,7 +248,7 @@ def finish(self):

# log request
if hasattr(self.channel, "factory"):
self.channel.factory.log(self)
self.channel.factory.log(self) # TODO

self.finished = 1
if not self.queued:
Expand Down Expand Up @@ -335,33 +379,6 @@ def getHost(self):
"""
return self.host

def setHost(self, host, port, ssl=0):
"""
Change the host and port the request thinks it's using.
This method is useful for working with reverse HTTP proxies (e.g.
both Squid and Apache's mod_proxy can do this), when the address
the HTTP client is using is different than the one we're listening on.
For example, Apache may be listening on https://www.example.com, and then
forwarding requests to http://localhost:8080, but we don't want HTML produced
by Twisted to say 'http://localhost:8080', they should say 'https://www.example.com',
so we do::
request.setHost('www.example.com', 443, ssl=1)
@type host: C{str}
@param host: The value to which to change the host header.
@type ssl: C{bool}
@param ssl: A flag which, if C{True}, indicates that the request is
considered secure (if C{True}, L{isSecure} will return C{True}).
"""
self._forceSSL = ssl
self.requestHeaders.setRawHeaders("host", [host])
self.host = address.IPv4Address("TCP", host, port)


def getClientIP(self):
"""
Return the IP address of the client who submitted this request.
Expand All @@ -374,76 +391,6 @@ def getClientIP(self):
else:
return None

def isSecure(self):
"""
Return True if this request is using a secure transport.
Normally this method returns True if this request's HTTPChannel
instance is using a transport that implements ISSLTransport.
This will also return True if setHost() has been called
with ssl=True.
@returns: True if this request is secure
@rtype: C{bool}
"""
if self._forceSSL:
return True
transport = getattr(getattr(self, 'channel', None), 'transport', None)
if interfaces.ISSLTransport(transport, None) is not None:
return True
return False

def _authorize(self):
# Authorization, (mostly) per the RFC
try:
authh = self.getHeader("Authorization")
if not authh:
self.user = self.password = ''
return
bas, upw = authh.split()
if bas.lower() != "basic":
raise ValueError
upw = base64.decodestring(upw)
self.user, self.password = upw.split(':', 1)
except (binascii.Error, ValueError):
self.user = self.password = ""
except:
log.err()
self.user = self.password = ""

def getUser(self):
"""
Return the HTTP user sent with this request, if any.
If no user was supplied, return the empty string.
@returns: the HTTP user, if any
@rtype: C{str}
"""
try:
return self.user
except:
pass
self._authorize()
return self.user

def getPassword(self):
"""
Return the HTTP password sent with this request, if any.
If no password was supplied, return the empty string.
@returns: the HTTP password, if any
@rtype: C{str}
"""
try:
return self.password
except:
pass
self._authorize()
return self.password

def getClient(self):
if self.client.type != 'TCP':
return None
Expand All @@ -458,6 +405,15 @@ def getClient(self):
return name
return names[0]

def requestReceived(self, command, path, version):
assert command == 'POST'

self.client = self.channel.transport.getPeer()
self.host = self.channel.transport.getHost()
self.clientproto = version

self.setResponseCode(200)
self.finish()

def connectionLost(self, reason):
"""
Expand Down Expand Up @@ -486,7 +442,6 @@ class HTTPFileUploadChannel(basic.LineReceiver, policies.TimeoutMixin):
maxHeaders = 500 # max number of headers allowed per request

length = 0
persistent = 1
__header = ''
__first_line = 1
__content = None
Expand All @@ -510,19 +465,12 @@ def lineReceived(self, line):
self.resetTimeout()

if self.__first_line:
# if this connection is not persistent, drop any data which
# the client (illegally) sent after the last request.
if not self.persistent:
self.dataReceived = self.lineReceived = lambda *args: None
return

# IE sends an extraneous empty line (\r\n) after a POST request;
# eat up such a line, but only ONCE
if not line and self.__first_line == 1:
self.__first_line = 2
return


self.__first_line = 0
parts = line.split()
if len(parts) != 3:
Expand All @@ -540,7 +488,7 @@ def lineReceived(self, line):
self._version = version

# create a new Request object
request = FileUploadRequest(self) if self._command == 'POST' else FileDownloadRequest(self)
request = FileUploadRequest(self, False) if self._command == 'POST' else FileDownloadRequest(self, False)
self.requests.append(request)
elif line == '':
if self.__header:
Expand Down Expand Up @@ -625,73 +573,15 @@ def rawDataReceived(self, data):

def allHeadersReceived(self):
req = self.requests[-1]
req.parseCookies()
self.persistent = self.checkPersistence(req, self._version)
req.gotLength(self.length)


def checkPersistence(self, request, version):
"""
Check if the channel should close or not.
@param request: The request most recently received over this channel
against which checks will be made to determine if this connection
can remain open after a matching response is returned.
@type version: C{str}
@param version: The version of the request.
@rtype: C{bool}
@return: A flag which, if C{True}, indicates that this connection may
remain open to receive another request; if C{False}, the connection
must be closed in order to indicate the completion of the response
to C{request}.
"""
connection = request.requestHeaders.getRawHeaders('connection')
if connection:
tokens = map(str.lower, connection[0].split(' '))
else:
tokens = []

# HTTP 1.0 persistent connection support is currently disabled,
# since we need a way to disable pipelining. HTTP 1.0 can't do
# pipelining since we can't know in advance if we'll have a
# content-length header, if we don't have the header we need to close the
# connection. In HTTP 1.1 this is not an issue since we use chunked
# encoding if content-length is not available.

#if version == "HTTP/1.0":
# if 'keep-alive' in tokens:
# request.setHeader('connection', 'Keep-Alive')
# return 1
# else:
# return 0
if version == "HTTP/1.1":
if 'close' in tokens:
request.responseHeaders.setRawHeaders('connection', ['close'])
return False
else:
return True
else:
return False


def requestDone(self, request):
"""
Called by first request in queue when it is done.
"""
if request != self.requests[0]: raise TypeError
del self.requests[0]

if self.persistent:
# notify next request it can start writing
if self.requests:
self.requests[0].noLongerQueued()
else:
if self._savedTimeOut:
self.setTimeout(self._savedTimeOut)
else:
self.transport.loseConnection()
self.transport.loseConnection()

def timeoutConnection(self):
log.msg("Timing out client: %s" % str(self.transport.getPeer()))
Expand Down

0 comments on commit 1c14472

Please sign in to comment.