Skip to content

Commit

Permalink
Pass host, port to transport factory, update demo
Browse files Browse the repository at this point in the history
  • Loading branch information
thobbs committed Sep 25, 2012
1 parent aa75158 commit 4475ff2
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 20 deletions.
39 changes: 33 additions & 6 deletions pycassa/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,19 @@
DEFAULT_SERVER = 'localhost:9160'
DEFAULT_PORT = 9160


def default_transport_factory(tsocket, host, port):
"""
Returns a normal :class:`TFramedTransport` instance wrapping `tsocket`.
"""
return TTransport.TFramedTransport(tsocket)


class Connection(Cassandra.Client):
"""Encapsulation of a client session."""

def __init__(self, keyspace, server, timeout=None,
credentials=None, transport_factory=TTransport.TFramedTransport):
credentials=None, transport_factory=default_transport_factory):
self.keyspace = None
self.server = server
server = server.split(':')
Expand All @@ -28,7 +36,7 @@ def __init__(self, keyspace, server, timeout=None,
socket = TSocket.TSocket(host, int(port))
if timeout is not None:
socket.setTimeout(timeout * 1000.0)
self.transport = transport_factory(socket)
self.transport = transport_factory(socket, host, port)
protocol = TBinaryProtocol.TBinaryProtocolAccelerated(self.transport)
Cassandra.Client.__init__(self, protocol)
self.transport.open()
Expand Down Expand Up @@ -153,10 +161,29 @@ def cstringio_refill(self, prefix, reqlen):
return self.__rbuf


def make_sasl_transport_factory(*sasl_args, **sasl_kwargs):
def make_sasl_transport_factory(credential_factory):
"""
A convenience function for creating a SASL transport factory.
`credential_factory` should be a function taking two args: `host` and
`port`. It should return a ``dict`` of kwargs that will be passed
to :func:`puresasl.client.SASLClient.__init__()`.
Example usage::
>>> def make_credentials(host, port):
... return {'sasl_host': host,
... 'sasl_service': 'cassandra',
... 'mechanism': 'GSSAPI'}
>>>
>>> factory = make_sasl_transport_factory(make_credentials)
>>> pool = ConnectionPool(..., transport_factory=factory)
"""

def transport_factory(tsocket):
sasl_transport = TSaslClientTransport(tsocket, *sasl_args, **sasl_kwargs)
def sasl_transport_factory(tsocket, host, port):
sasl_kwargs = credential_factory(host, port)
sasl_transport = TSaslClientTransport(tsocket, **sasl_kwargs)
return TTransport.TFramedTransport(sasl_transport)

return transport_factory
return sasl_transport_factory
16 changes: 10 additions & 6 deletions pycassa/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import Queue

from thrift import Thrift
from thrift.transport.TTransport import TTransportException, TFramedTransport
from connection import Connection
from thrift.transport.TTransport import TTransportException
from connection import Connection, default_transport_factory
from logging.pool_logger import PoolLogger
from util import as_interface
from cassandra.ttypes import TimedOutException, UnavailableException
Expand Down Expand Up @@ -250,10 +250,14 @@ def _set_max_overflow(self, max_overflow):
If multiple pools are in use for different purposes, setting `logging_name` will
help individual pools to be identified in the logs. """

transport_factory = TFramedTransport
transport_factory = default_transport_factory
""" A function that creates the transport for each connection in the pool.
This function should take one argument, a TSocket object for the transport
to wrap. By default, this is ``TTransport.TFramedTransport``. """
This function should take three arguments: `tsocket`, a TSocket object for the
transport, `host`, the host the connection is being made to, and `port`,
the destination port.
By default, this is function is :func:`~connection.default_transport_factory`.
"""

def __init__(self, keyspace,
server_list=['localhost:9160'],
Expand All @@ -262,7 +266,7 @@ def __init__(self, keyspace,
use_threadlocal=True,
pool_size=5,
prefill=True,
transport_factory=TFramedTransport,
transport_factory=default_transport_factory,
**kwargs):
"""
All connections in the pool will be opened to `keyspace`.
Expand Down
23 changes: 15 additions & 8 deletions sasl_demo.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
#!/usr/bin/python

from pycassa.pool import ConnectionPool
from pycassa.columnfamily import ColumnFamily
from pycassa.connection import make_sasl_transport_factory
from pycassa.system_manager import SystemManager

transport_factory = make_sasl_transport_factory(
sasl_host='thobbs-laptop',
sasl_service='host',
mechanism='GSSAPI'
)
def make_creds(host, port):
# typically, you would use the passed-in host, but my kerberos test setup
# is not that sophisticated
return {'sasl_host': 'thobbs-laptop',
'sasl_service': 'host',
'mechanism': 'GSSAPI'}

transport_factory = make_sasl_transport_factory(make_creds)

sysman = SystemManager(transport_factory=transport_factory)
sysman.create_keyspace('Keyspace1', 'SimpleStrategy', {'replication_factor': '1'})
sysman.create_column_family('Keyspace1', 'CF1')
if 'Keyspace1' not in sysman.list_keyspaces():
sysman.create_keyspace('Keyspace1', 'SimpleStrategy', {'replication_factor': '1'})
sysman.create_column_family('Keyspace1', 'Standard1')
sysman.close()

pool = ConnectionPool('Keyspace1', transport_factory=transport_factory)
cf = ColumnFamily(pool, 'CF1')
cf = ColumnFamily(pool, 'Standard1')

for i in range(100):
cf.insert('key%d' % i, {'col': 'val'})
Expand Down

0 comments on commit 4475ff2

Please sign in to comment.