Skip to content

Commit

Permalink
Merge pull request #97 from sethdenner/primary-key-fixes
Browse files Browse the repository at this point in the history
Primary Key Fixes
  • Loading branch information
sethdenner committed Dec 14, 2016
2 parents cacf9a1 + ddfd5b8 commit 6798eb0
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 68 deletions.
63 changes: 39 additions & 24 deletions djangocassandra/db/fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uuid

from math import log10, floor
from datetime import datetime
from django.utils.six import with_metaclass
from django.db.models import (
Expand All @@ -14,21 +15,28 @@

from cassandra.cqlengine.functions import Token

from .values import PrimaryKeyValue


class DateTimeField(DjangoDateTimeField):
def get_prep_value(self, value):
# Hack cassandra truncates microseconds to milliseconds.
value = datetime(
value.year,
value.month,
value.day,
value.hour,
value.minute,
value.second,
int(str(
value.microsecond
)[:-3] + "000"),
value.tzinfo
)
microsecond = value.microsecond
if 0 != microsecond:
microsecond = round(
value.microsecond,
-int(floor(log10(abs(value.microsecond)))) + 2
)
value = datetime(
value.year,
value.month,
value.day,
value.hour,
value.minute,
value.second,
int(microsecond),
value.tzinfo
)

return super(DateTimeField, self).get_prep_value(value)

Expand Down Expand Up @@ -167,7 +175,11 @@ def get_prep_value(self, value):
):
return value

return uuid.UUID(value)
try:
return uuid.UUID(value)

except:
return value

def get_internal_type(self):
return 'FieldUUID'
Expand Down Expand Up @@ -222,14 +234,17 @@ def to_python(
return value

def get_prep_value(self, value):
value = super(AutoField, self).get_prep_value(value)
if (
value is None or
isinstance(value, uuid.UUID)
):
return value

return uuid.UUID(value)
try:
return uuid.UUID(value)

except:
return value

def get_internal_type(self):
return 'AutoFieldUUID'
Expand Down Expand Up @@ -301,10 +316,17 @@ def __init__(self, *args, **kwargs):
)

def get_prep_value(self, value):
return value
if isinstance(value, PrimaryKeyValue):
return value

elif isinstance(self, ForeignKey):
return self.related_field.get_prep_value(value)

else:
return super(PrimaryKeyField, self).get_prep_value(value)

def get_internal_type(self):
if ForeignKey in self.__class__.__bases__:
if isinstance(self, ForeignKey):
return "ForeignKey"

return super(
Expand All @@ -319,10 +341,3 @@ def get_internal_type(self):
**self.field_kwargs
)
)

def get_internal_type(self):
import pdb; pdb.set_trace()
return self.field_class(
*self.field_args,
**self.field_kwargs
).get_internal_type()
31 changes: 7 additions & 24 deletions djangocassandra/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@
FieldDoesNotExist
)

from .fields import (
TokenPartitionKeyField,
PrimaryKeyField
)

from .fields import TokenPartitionKeyField
from .values import PrimaryKeyValue
from .query import QuerySet


Expand Down Expand Up @@ -59,22 +56,6 @@ def create(
return instance


class PrimaryKeyValue(OrderedDict):
def __int__(self):
return self.__hash__()

def __hash__(self):
return hash(self.to_tuple())

def __str__(self):
return str(self.to_tuple())

def __unicode__(self):
return unicode(self.to_tuple())

def to_tuple(self):
return tuple(self.iteritems())


class ColumnFamilyModel(DjangoModel):
class Meta:
Expand Down Expand Up @@ -124,10 +105,12 @@ def _get_pk_val(self, meta=None):
for key in all_keys:
field = self._meta.get_field_by_name(key)[0]
if isinstance(field, ForeignKey):
pk_value[key] = getattr(self, key + "_id")
key += "_id"

else:
pk_value[key] = getattr(self, key)
pk_value[key] = field.get_prep_value(getattr(
self,
key
))

return pk_value

Expand Down
19 changes: 19 additions & 0 deletions djangocassandra/db/values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from collections import OrderedDict


class PrimaryKeyValue(OrderedDict):
def __int__(self):
return self.__hash__()

def __hash__(self):
return hash(self.to_tuple())

def __str__(self):
return str(self.to_tuple())

def __unicode__(self):
return unicode(self.to_tuple())

def to_tuple(self):
return tuple(self.iteritems())

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='djangocassandra',
version='0.7.3',
version='0.7.4',
description='Cassandra support for the Django web framework',
long_description=(
'The Cassandra database backend for Django has been '
Expand Down
12 changes: 5 additions & 7 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import datetime

from uuid import UUID

from django.db.models import (
DO_NOTHING,
Model,
Expand Down Expand Up @@ -175,16 +177,14 @@ class ClusterPrimaryKeyModel(ColumnFamilyModel):
class Cassandra:
clustering_keys = ['field_2', 'field_3']

field_1 = CharField(
primary_key=True,
max_length=32
field_1 = FieldUUID(
primary_key=True
)
field_2 = CharField(max_length=32)
field_3 = CharField(max_length=32)
data = CharField(max_length=64)

def auto_populate(self):
self.field_1 = random_string(32)
self.field_2 = random_string(32)
self.field_3 = random_string(32)
self.data = random_string(64)
Expand Down Expand Up @@ -297,7 +297,6 @@ class Cassandra:
default=datetime.datetime.utcnow
)


class ForeignPartitionKeyModel(ColumnFamilyModel):
class Cassandra:
partition_keys = [
Expand All @@ -309,8 +308,7 @@ class Cassandra:

related = PrimaryKeyField(
field_class=ForeignKey,
field_args=[ClusterPrimaryKeyModel],
field_kwargs={"primary_key": True}
field_args=[ClusterPrimaryKeyModel]
)
created = DateTimeField(
default=datetime.datetime.utcnow
Expand Down
9 changes: 9 additions & 0 deletions tests/test_columnfamily.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,15 @@ def test_order_by_efficient(self):
len(results)
)

for i in instances:
i.delete()

all_instances = ForeignPartitionKeyModel.objects.all()
self.assertEqual(
0,
len(all_instances)
)


class TestDictFieldModel(TestCase):
def setUp(self):
Expand Down
28 changes: 16 additions & 12 deletions tests/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
import uuid

from unittest import TestCase

Expand Down Expand Up @@ -119,29 +120,32 @@ def setUp(self):
)

manager = ClusterPrimaryKeyModel.objects

self.uuid0 = str(uuid.uuid4())
manager.create(
field_1='aaaa',
field_1=self.uuid0,
field_2='aaaa',
field_3='bbbb',
data='Foo'
)

manager.create(
field_1='aaaa',
field_1=self.uuid0,
field_2='bbbb',
field_3='cccc',
data='Tao'
)

self.uuid1 = str(uuid.uuid4())
manager.create(
field_1='bbbb',
field_1=self.uuid1,
field_2='aaaa',
field_3='aaaa',
data='Bar'
)

manager.create(
field_1='bbbb',
field_1=self.uuid1,
field_2='bbbb',
field_3='aaaa',
data='Lel'
Expand Down Expand Up @@ -182,9 +186,9 @@ def test_pk_filter(self):
manager = ClusterPrimaryKeyModel.objects
all_rows = list(manager.all())

filtered_rows = list(manager.filter(field_1='bbbb'))
filtered_rows = list(manager.filter(field_1=self.uuid1))

filtered_rows_inmem = [r for r in all_rows if r.pk == 'bbbb']
filtered_rows_inmem = [r for r in all_rows if r.pk == self.uuid1]

self.assertEqual(
len(filtered_rows),
Expand All @@ -203,7 +207,7 @@ def test_clustering_key_filter(self):

with warnings.catch_warnings(record=True) as w:
filtered_rows = list(manager.filter(
field_1='bbbb',
field_1=self.uuid1,
field_2='aaaa',
field_3='aaaa'
))
Expand All @@ -215,7 +219,7 @@ def test_clustering_key_filter(self):

with warnings.catch_warnings(record=True) as w:
filtered_rows = list(manager.filter(
field_1='bbbb'
field_1=self.uuid1
).filter(
field_2='aaaa'
).filter(
Expand All @@ -229,7 +233,7 @@ def test_clustering_key_filter(self):

filtered_rows_inmem = [
r for r in all_rows if
r.field_1 == 'bbbb' and
r.field_1 == self.uuid1 and
r.field_2 == 'aaaa' and
r.field_3 == 'aaaa'
]
Expand All @@ -247,20 +251,20 @@ def test_clustering_key_filter(self):

def test_orderby(self):
manager = ClusterPrimaryKeyModel.objects
filtered_rows = list(manager.filter(field_1='bbbb'))
filtered_rows = list(manager.filter(field_1=self.uuid1))

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')

filtered_rows_ordered = list(
manager.filter(
field_1='bbbb'
field_1=self.uuid1
).order_by('field_2')
)

filtered_rows_ordered_desc = list(
manager.filter(
field_1='bbbb'
field_1=self.uuid1
).order_by('-field_2')
)

Expand Down

0 comments on commit 6798eb0

Please sign in to comment.