Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ Explore the documentation to learn about all features:

* [Materialized views](/materialized_views)

* [Indexes](/indexes)
* [ConditionalUniqueIndex](/indexes/#conditional-unique-index)

## Installation

1. Install the package from PyPi:
Expand Down
46 changes: 46 additions & 0 deletions docs/indexes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
## Conditional Unique Index

The `ConditionalUniqueIndex` lets you create partial unique indexes in case you ever need `unique together` constraints
on nullable columns.

e.g.

Before:

```
from django.db import models

class Model(models.Model):
class Meta:
unique_together = ['a', 'b'']

a = models.ForeignKey('some_model', null=True)
b = models.ForeignKey('some_other_model')

# Works like a charm!
b = B()
Model.objects.create(a=None, b=b)
Model.objects.create(a=None, b=b)
```

After:

```
from django.db import models
from from psqlextra.indexes import ConditionalUniqueIndex

class Model(models.Model):
class Meta:
indexes = [
ConditionalUniqueIndex(fields=['a', 'b'], condition='"a" IS NOT NULL'),
ConditionalUniqueIndex(fields=['b'], condition='"a" IS NULL')
]

a = models.ForeignKey('some_model', null=True)
b = models.ForeignKey('some_other_model')

# Integrity Error!
b = B()
Model.objects.create(a=None, b=b)
Model.objects.create(a=None, b=b)
```
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ pages:
- HStore: hstore.md
- Signals: signals.md
- Materialized Views: materialized_views.md
- Indexes: indexes.md
5 changes: 5 additions & 0 deletions psqlextra/indexes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .conditional_unique_index import ConditionalUniqueIndex

__all__ = [
'ConditionalUniqueIndex'
]
39 changes: 39 additions & 0 deletions psqlextra/indexes/conditional_unique_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from django.db.models.indexes import Index


class ConditionalUniqueIndex(Index):
"""
Creates a partial unique index based on a given condition.

Useful, for example, if you need unique combination of foreign keys, but you might want to include
NULL as a valid value. In that case, you can just use:
>>> class Meta:
... indexes = [
... ConditionalUniqueIndex(fields=['a', 'b', 'c'], condition='"c" IS NOT NULL'),
... ConditionalUniqueIndex(fields=['a', 'b'], condition='"c" IS NULL')
... ]
"""

sql_create_index = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s WHERE %(condition)s"

def __init__(self, condition: str, fields=[], name=None):
"""Initializes a new instance of :see:ConditionalUniqueIndex."""

super().__init__(fields=fields, name=name)
self.condition = condition

def create_sql(self, model, schema_editor, using=''):
"""Creates the actual SQL used when applying the migration."""

sql_create_index = self.sql_create_index
sql_parameters = {
**Index.get_sql_create_template_values(self, model, schema_editor, using),
'condition': self.condition
}
return sql_create_index % sql_parameters

def deconstruct(self):
"""Serializes the :see:ConditionalUniqueIndex for the migrations file."""
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
path = path.replace('django.db.models.indexes', 'django.db.models')
return path, (), {'fields': self.fields, 'name': self.name, 'condition': self.condition}
18 changes: 15 additions & 3 deletions tests/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,17 +357,29 @@ def make_migrations(self):
self.project_state = new_project_state
return migration

def migrate(self):
"""Executes the recorded migrations."""
def migrate(self, *filters: List[str]):
"""
Executes the recorded migrations.

Arguments:
filters: List of strings to filter SQL statements on.

Returns:
The filtered calls of every migration
"""

calls_for_migrations = []
while len(self.migrations) > 0:
migration = self.migrations.pop()

with connection.schema_editor() as schema_editor:
with filtered_schema_editor(*filters) as (schema_editor, calls):
migration_executor = MigrationExecutor(schema_editor.connection)
migration_executor.apply_migration(
self.project_state, migration
)
calls_for_migrations.append(calls)

return calls_for_migrations

def _generate_random_name(self):
return str(uuid.uuid4()).replace('-', '')[:8]
75 changes: 75 additions & 0 deletions tests/test_conditional_unique_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest

from psqlextra.indexes import ConditionalUniqueIndex
from .migrations import MigrationSimulator

from django.db import models, IntegrityError, transaction
from django.db.migrations import AddIndex, CreateModel


def test_deconstruct():
"""Tests whether the :see:HStoreField's deconstruct()
method works properly."""

original_kwargs = dict(condition='field IS NULL', name='great_index', fields=['field', 'build'])
_, _, new_kwargs = ConditionalUniqueIndex(**original_kwargs).deconstruct()

for key, value in original_kwargs.items():
assert new_kwargs[key] == value


def test_migrations():
"""Tests whether the migrations are properly generated and executed."""

simulator = MigrationSimulator()

Model = simulator.define_model(
fields={
'id': models.IntegerField(primary_key=True),
'name': models.CharField(max_length=255, null=True),
'other_name': models.CharField(max_length=255)
},
meta_options={
'indexes': [
ConditionalUniqueIndex(
fields=['name', 'other_name'],
condition='"name" IS NOT NULL',
name='index1'
),
ConditionalUniqueIndex(
fields=['other_name'],
condition='"name" IS NULL',
name='index2'
)
]
}
)

migration = simulator.make_migrations()
assert len(migration.operations) == 3

operations = migration.operations
assert isinstance(operations[0], CreateModel)

for operation in operations[1:]:
assert isinstance(operation, AddIndex)

calls = [call[0] for _, call, _ in simulator.migrate('CREATE UNIQUE INDEX')[0]['CREATE UNIQUE INDEX']]

db_table = Model._meta.db_table
assert calls[0] == 'CREATE UNIQUE INDEX "index1" ON "{0}" ("name", "other_name") WHERE "name" IS NOT NULL'.format(
db_table
)
assert calls[1] == 'CREATE UNIQUE INDEX "index2" ON "{0}" ("other_name") WHERE "name" IS NULL'.format(
db_table
)

with transaction.atomic():
Model.objects.create(id=1, name="name", other_name="other_name")
with pytest.raises(IntegrityError):
Model.objects.create(id=2, name="name", other_name="other_name")

with transaction.atomic():
Model.objects.create(id=1, name=None, other_name="other_name")
with pytest.raises(IntegrityError):
Model.objects.create(id=2, name=None, other_name="other_name")