diff --git a/airflow/providers/apache/cassandra/hooks/cassandra.py b/airflow/providers/apache/cassandra/hooks/cassandra.py index 71aea789a12a1..9e52f08aea28d 100644 --- a/airflow/providers/apache/cassandra/hooks/cassandra.py +++ b/airflow/providers/apache/cassandra/hooks/cassandra.py @@ -113,6 +113,10 @@ def __init__(self, cassandra_conn_id: str = 'cassandra_default'): if ssl_options: conn_config['ssl_options'] = ssl_options + protocol_version = conn.extra_dejson.get('protocol_version', None) + if protocol_version: + conn_config['protocol_version'] = protocol_version + self.cluster = Cluster(**conn_config) self.keyspace = conn.schema self.session = None diff --git a/docs/howto/connection/cassandra.rst b/docs/howto/connection/cassandra.rst index 69146738eb476..521c91819a966 100644 --- a/docs/howto/connection/cassandra.rst +++ b/docs/howto/connection/cassandra.rst @@ -53,6 +53,7 @@ Extra (optional) ``RoundRobinPolicy``, ``DCAwareRoundRobinPolicy``, ``WhiteListRoundRobinPolicy`` and ``TokenAwarePolicy``. ``RoundRobinPolicy`` is the default load balancing policy. * ``load_balancing_policy_args`` - This parameter specifies the arguments for the load balancing policy being used. * ``cql_version`` - This parameter specifies the CQL version of cassandra. + * ``protocol_version`` - This parameter specifies the maximum version of the native protocol to use. * ``ssl_options`` - This parameter specifies the details related to SSL, if it's enabled in Cassandra. diff --git a/tests/providers/apache/cassandra/hooks/test_cassandra.py b/tests/providers/apache/cassandra/hooks/test_cassandra.py index 06cf665918249..20da13fb7a6c5 100644 --- a/tests/providers/apache/cassandra/hooks/test_cassandra.py +++ b/tests/providers/apache/cassandra/hooks/test_cassandra.py @@ -43,7 +43,7 @@ def setUp(self): host='host-1,host-2', port='9042', schema='test_keyspace', - extra='{"load_balancing_policy":"TokenAwarePolicy"}', + extra='{"load_balancing_policy":"TokenAwarePolicy","protocol_version":4}', ) ) db.merge_conn( @@ -84,6 +84,7 @@ def test_get_conn(self): cluster = hook.get_cluster() self.assertEqual(cluster.contact_points, ['host-1', 'host-2']) self.assertEqual(cluster.port, 9042) + self.assertEqual(cluster.protocol_version, 4) self.assertTrue(isinstance(cluster.load_balancing_policy, TokenAwarePolicy)) def test_get_lb_policy_with_no_args(self):