Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ void Connection::maybe_set_keyspace(ResponseMessage* response) {
if (response->opcode() == CQL_OPCODE_RESULT) {
ResultResponse* result = static_cast<ResultResponse*>(response->response_body().get());
if (result->kind() == CASS_RESULT_KIND_SET_KEYSPACE) {
keyspace_ = result->keyspace().to_string();
keyspace_ = result->quoted_keyspace();
}
}
}
Expand Down
1 change: 0 additions & 1 deletion src/hash_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class CaseInsensitiveHashTable : public Allocated {

private:
size_t index_mask_;
size_t count_;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused variable. Unrelated.

SmallVector<T*, 32> index_;
EntryVec entries_;

Expand Down
2 changes: 1 addition & 1 deletion src/pooled_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ChainedSetKeyspaceCallback : public SimpleRequestCallback {
class SetKeyspaceRequest : public QueryRequest {
public:
SetKeyspaceRequest(const String& keyspace, uint64_t request_timeout_ms)
: QueryRequest("USE \"" + keyspace + "\"") {
: QueryRequest("USE " + keyspace) {
set_request_timeout_ms(request_timeout_ms);
}
};
Expand Down
4 changes: 2 additions & 2 deletions src/request_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ void RequestExecution::notify_result_metadata_changed(const Request* request,
if (result_response->protocol_version().supports_set_keyspace() && !request->keyspace().empty()) {
keyspace = request->keyspace();
} else {
keyspace = result_response->keyspace().to_string();
keyspace = result_response->quoted_keyspace();
}

if (request->opcode() == CQL_OPCODE_EXECUTE && result_response->kind() == CASS_RESULT_KIND_ROWS) {
Expand Down Expand Up @@ -531,7 +531,7 @@ void RequestExecution::on_result_response(Connection* connection, ResponseMessag

case CASS_RESULT_KIND_SET_KEYSPACE:
// The response is set after the keyspace is propagated to all threads.
request_handler_->notify_keyspace_changed(result->keyspace().to_string(), current_host_,
request_handler_->notify_keyspace_changed(result->quoted_keyspace(), current_host_,
response->response_body());
break;

Expand Down
5 changes: 5 additions & 0 deletions src/result_response.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ class ResultResponse : public Response {
StringRef keyspace() const { return keyspace_; }
StringRef table() const { return table_; }

String quoted_keyspace() const {
String temp(keyspace_.to_string());
return escape_id(temp);
}

bool metadata_changed() { return new_metadata_id_.size() > 0; }
StringRef new_metadata_id() const { return new_metadata_id_; }

Expand Down
2 changes: 1 addition & 1 deletion src/statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ Statement::Statement(const Prepared* prepared)
// If the keyspace wasn't explictly set then attempt to set it using the
// prepared statement's result metadata.
if (keyspace().empty()) {
set_keyspace(prepared->result()->keyspace().to_string());
set_keyspace(prepared->result()->quoted_keyspace());
}
}

Expand Down
45 changes: 9 additions & 36 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,38 +111,22 @@ String& trim(String& str) {
return str;
}

static bool is_word_char(int c) {
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_';
}
static bool is_lowercase(const String& str) {
if (str.empty()) return true;

static bool is_lower_word_char(int c) {
return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '_';
}
char c = str[0];
if (!(c >= 'a' && c <= 'z')) return false;

bool is_valid_cql_id(const String& str) {
for (String::const_iterator i = str.begin(), end = str.end(); i != end; ++i) {
if (!is_word_char(*i)) {
for (String::const_iterator it = str.begin() + 1, end = str.end(); it != end; ++it) {
char c = *it;
if (!((c >= '0' && c <= '9') || (c == '_') || (c >= 'a' && c <= 'z'))) {
return false;
}
}
return true;
}

bool is_valid_lower_cql_id(const String& str) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was implemented incorrectly.

if (str.empty() || !is_lower_word_char(str[0])) {
return false;
}
if (str.size() > 1) {
for (String::const_iterator i = str.begin() + 1, end = str.end(); i != end; ++i) {
if (!is_lower_word_char(*i)) {
return false;
}
}
}
return true;
}

String& quote_id(String& str) {
static String& quote_id(String& str) {
String temp(str);
str.clear();
str.push_back('"');
Expand All @@ -159,18 +143,7 @@ String& quote_id(String& str) {
return str;
}

String& escape_id(String& str) { return is_valid_lower_cql_id(str) ? str : quote_id(str); }

String& to_cql_id(String& str) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to_cql_id() was dead code.

if (is_valid_cql_id(str)) {
std::transform(str.begin(), str.end(), str.begin(), tolower);
return str;
}
if (str.length() > 2 && str[0] == '"' && str[str.length() - 1] == '"') {
return str.erase(str.length() - 1, 1).erase(0, 1);
}
return str;
}
String& escape_id(String& str) { return is_lowercase(str) ? str : quote_id(str); }

int32_t get_pid() {
#if (defined(WIN32) || defined(_WIN32))
Expand Down
4 changes: 0 additions & 4 deletions src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ String implode(const Vector<String>& vec, const char delimiter = ',');

String& trim(String& str);

bool is_valid_cql_id(const String& str);

String& to_cql_id(String& str);

String& escape_id(String& str);

inline size_t num_leading_zeros(int64_t value) {
Expand Down
82 changes: 82 additions & 0 deletions tests/src/integration/tests/test_use_keyspace.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
Copyright (c) DataStax, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "integration.hpp"

#include <locale>

/**
* "USE <keyspace>" case-sensitive tests
*/
class UseKeyspaceCaseSensitiveTests : public Integration {
public:
UseKeyspaceCaseSensitiveTests() {}

// Make a case-sensitive keyspace capitalizing the first char and wrapping in double quotes
virtual std::string default_keyspace() {
std::string temp(Integration::default_keyspace());
temp[0] = std::toupper(temp[0]);
return "\"" + temp + "\"";
}
Comment on lines +29 to +33
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Smart!


virtual void SetUp() {
Integration::SetUp();
session_.execute(
format_string(CASSANDRA_KEY_VALUE_TABLE_FORMAT, table_name_.c_str(), "int", "int"));
session_.execute(
format_string(CASSANDRA_KEY_VALUE_INSERT_FORMAT, table_name_.c_str(), "1", "2"));
}
};

/**
* Verify that case-sensitive keyspaces work when connecting a session with a keyspace.
*/
CASSANDRA_INTEGRATION_TEST_F(UseKeyspaceCaseSensitiveTests, ConnectWithKeyspace) {
CHECK_FAILURE;
Session session = default_cluster().connect(keyspace_name_);

Result result =
session.execute(format_string(CASSANDRA_SELECT_VALUE_FORMAT, table_name_.c_str(), "1"));

Row row = result.first_row();
EXPECT_EQ(row.column_by_name<Integer>("value"), Integer(2));
}

/**
* Verify that case-sensitive keyspaces work with "USE <keyspace>".
*/
CASSANDRA_INTEGRATION_TEST_F(UseKeyspaceCaseSensitiveTests, UseKeyspace) {
CHECK_FAILURE;
Session session = default_cluster().connect();

{ // Expect failure there's no keyspace set
Result result =
session.execute(format_string(CASSANDRA_SELECT_VALUE_FORMAT, table_name_.c_str(), "1"),
CASS_CONSISTENCY_ONE, false, false);

EXPECT_EQ(result.error_code(), CASS_ERROR_SERVER_INVALID_QUERY);
}

session.execute("USE " + keyspace_name_);

{ // Success
Result result =
session.execute(format_string(CASSANDRA_SELECT_VALUE_FORMAT, table_name_.c_str(), "1"));

Row row = result.first_row();
EXPECT_EQ(row.column_by_name<Integer>("value"), Integer(2));
}
}
33 changes: 22 additions & 11 deletions tests/src/unit/mockssandra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "memory.hpp"
#include "scoped_lock.hpp"
#include "tracing_data_handler.hpp" // For tracing query
#include "utils.hpp"
#include "uuids.hpp"

#include <openssl/bio.h>
Expand All @@ -34,10 +35,12 @@
#endif

using datastax::internal::bind_callback;
using datastax::internal::escape_id;
using datastax::internal::Map;
using datastax::internal::Memory;
using datastax::internal::OStringStream;
using datastax::internal::ScopedMutex;
using datastax::internal::trim;
using datastax::internal::core::UuidGen;

#define SSL_BUF_SIZE 8192
Expand Down Expand Up @@ -1357,6 +1360,10 @@ Action::Builder& Action::Builder::use_keyspace(const String& keyspace) {
return execute((new UseKeyspace(keyspace)));
}

Action::Builder& Action::Builder::use_keyspace(const Vector<String>& keyspaces) {
return execute((new UseKeyspace(keyspaces)));
}

Action::Builder& Action::Builder::plaintext_auth(const String& username, const String& password) {
return execute((new PlaintextAuth(username, password)));
}
Expand Down Expand Up @@ -1807,18 +1814,22 @@ void UseKeyspace::on_run(Request* request) const {
String query;
QueryParameters params;
if (request->decode_query(&query, &params)) {
query.erase(0, query.find_first_not_of(" \t"));
if (query.substr(0, 3) == "USE" || query.substr(0, 3) == "use") {
query.erase(0, 3);
query.erase(0, query.find_first_not_of(" \t"));
if (query.substr(0, keyspace.size()) == keyspace) {
String body;
encode_int32(RESULT_SET_KEYSPACE, &body);
encode_string(keyspace, &body);
request->write(OPCODE_RESULT, body);
} else {
request->error(ERROR_INVALID_QUERY, "Keyspace '" + keyspace + "' does not exist");
trim(query);
if (query.compare(0, 3, "USE") == 0 || query.compare(0, 3, "use") == 0) {
String keyspace(query.substr(query.find_first_not_of(" \t", 3)));
for (Vector<String>::const_iterator it = keyspaces.begin(), end = keyspaces.end(); it != end;
++it) {
String temp(*it);
if (keyspace == escape_id(temp)) {
String body;
encode_int32(RESULT_SET_KEYSPACE, &body);
encode_string(*it, &body);
request->client()->set_keyspace(*it);
request->write(OPCODE_RESULT, body);
return;
}
}
request->error(ERROR_INVALID_QUERY, "Keyspace '" + keyspace + "' does not exist");
} else {
run_next(request);
}
Expand Down
12 changes: 9 additions & 3 deletions tests/src/unit/mockssandra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ struct Action {
Builder& system_traces();

Builder& use_keyspace(const String& keyspace);
Builder& use_keyspace(const Vector<String>& keyspaces);
Builder& plaintext_auth(const String& username = "cassandra",
const String& password = "cassandra");

Expand Down Expand Up @@ -844,10 +845,11 @@ struct SystemTraces : public Action {
};

struct UseKeyspace : public Action {
UseKeyspace(const String& keyspace)
: keyspace(keyspace) {}
UseKeyspace(const String& keyspace) { keyspaces.push_back(keyspace); }
UseKeyspace(const Vector<String>& keyspaces)
: keyspaces(keyspaces) {}
virtual void on_run(Request* request) const;
String keyspace;
Vector<String> keyspaces;
};

struct PlaintextAuth : public Action {
Expand Down Expand Up @@ -1018,8 +1020,12 @@ class ClientConnection : public internal::ClientConnection {
const Options& options() const { return options_; }
void set_options(const Options& options) { options_ = options; }

const String& keyspace() const { return keyspace_; }
void set_keyspace(const String& keyspace) { keyspace_ = keyspace; }

private:
ProtocolHandler handler_;
String keyspace_;
const Cluster* cluster_;
int protocol_version_;
bool is_registered_for_events_;
Expand Down
Loading