Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
branch: master
542 lines (452 sloc) 20.024 kb
from __future__ import unicode_literals
from operator import attrgetter
from django import VERSION
from django.contrib.contenttypes.models import ContentType
from django.db import models, router
from django.db.models.fields import Field
from django.db.models.fields.related import (add_lazy_relation, ManyToManyRel,
OneToOneRel, RelatedField)
if VERSION < (1, 8):
# related.py was removed in Django 1.8
# Depending on how Django was updated, related.py could still exist
# on the users system even on Django 1.8+, so we check the Django
# version before importing it to make sure this doesn't get imported
# accidentally.
from django.db.models.related import RelatedObject
else:
RelatedObject = None
from django.utils import six
from django.utils.text import capfirst
from django.utils.translation import ugettext_lazy as _
from taggit.forms import TagField
from taggit.models import GenericTaggedItemBase, TaggedItem
from taggit.utils import _get_field, require_instance_manager
try:
from django.contrib.contenttypes.fields import GenericRelation
except ImportError: # django < 1.7
from django.contrib.contenttypes.generic import GenericRelation
try:
from django.db.models.query_utils import PathInfo
except ImportError: # Django < 1.8
try:
from django.db.models.related import PathInfo
except ImportError:
pass # PathInfo is not used on Django < 1.6
def _model_name(model):
    """
    Return the model name stored on ``model._meta``.

    Django < 1.7 exposes it as ``Options.module_name``; newer releases use
    ``Options.model_name``. Only the attribute matching the running Django
    version is read.
    """
    meta = model._meta
    return meta.model_name if VERSION >= (1, 7) else meta.module_name
class TaggableRel(ManyToManyRel):
    """
    Relation object attached to a ``TaggableManager`` field as ``field.rel``.

    ``ManyToManyRel.__init__`` is not called; the attributes read elsewhere
    are assigned directly instead.
    """

    def __init__(self, field, related_name, through, to=None):
        self.to = to
        self.related_name = related_name
        self.limit_choices_to = {}
        self.symmetrical = True
        self.multiple = True
        # The through model is only recorded on Django >= 1.7.
        if VERSION >= (1, 7):
            self.through = through
        else:
            self.through = None
        self.field = field

    def get_joining_columns(self):
        """Return the owning field's reverse joining columns."""
        owner = self.field
        return owner.get_reverse_joining_columns()

    def get_extra_restriction(self, where_class, alias, related_alias):
        """Delegate to the owning field with the two aliases swapped."""
        owner = self.field
        return owner.get_extra_restriction(where_class, related_alias, alias)
class ExtraJoinRestriction(object):
    """
    An extra restriction used for contenttype restriction in joins.

    Renders ``<alias>.<col> = %s`` when exactly one content type id is held,
    otherwise ``<alias>.<col> IN (%s,...)`` with one placeholder per id. The
    ids are returned as the SQL parameters.
    """

    def __init__(self, alias, col, content_types):
        self.alias = alias
        self.col = col
        self.content_types = content_types

    def as_sql(self, qn, connection):
        """Return ``(sql, params)``; ``qn`` quotes identifiers."""
        column_ref = "%s.%s" % (qn(self.alias), qn(self.col))
        count = len(self.content_types)
        if count != 1:
            placeholders = ",".join("%s" for _ in range(count))
            sql = "%s IN (%s)" % (column_ref, placeholders)
        else:
            sql = "%s = %%s" % column_ref
        return sql, self.content_types

    def relabel_aliases(self, change_map):
        """Rename ``self.alias`` if it appears as a key in ``change_map``."""
        current = self.alias
        self.alias = change_map.get(current, current)

    def clone(self):
        """Return a copy holding a shallow copy of the content type ids."""
        copied_ids = self.content_types[:]
        return type(self)(self.alias, self.col, copied_ids)
class _TaggableManager(models.Manager):
    """
    Manager returned when a ``TaggableManager`` field is accessed.

    ``instance`` is the model instance the manager is bound to, or ``None``
    when the field is read from the model class. Methods decorated with
    ``require_instance_manager`` presumably refuse to run without a bound
    instance (NOTE(review): behaviour lives in ``taggit.utils`` — confirm).
    """

    def __init__(self, through, model, instance, prefetch_cache_name):
        # through: the tagged-item model linking tags to tagged objects.
        # model: the model class that declares the TaggableManager field.
        # instance: the bound model instance, or None.
        # prefetch_cache_name: key looked up in instance._prefetched_objects_cache.
        self.through = through
        self.model = model
        self.instance = instance
        self.prefetch_cache_name = prefetch_cache_name
        self._db = None

    def is_cached(self, instance):
        """Return True if tags for ``instance`` were already prefetched."""
        return self.prefetch_cache_name in instance._prefetched_objects_cache

    def get_queryset(self, extra_filters=None):
        """
        Return the tags for the bound instance (or for the model).

        A prefetched result is returned when one exists. Otherwise
        ``through.tags_for`` is queried, with ``extra_filters`` passed on as
        keyword arguments.
        """
        try:
            return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
        except (AttributeError, KeyError):
            # No instance, no prefetch cache, or nothing cached under our name.
            kwargs = extra_filters if extra_filters else {}
            return self.through.tags_for(self.model, self.instance, **kwargs)

    def get_prefetch_queryset(self, instances, queryset=None):
        """
        Build the single query used by ``prefetch_related`` for ``instances``.

        Each returned tag carries ``_prefetch_related_val``, the tagged
        object's key column from the through table. That lets the results be
        matched back to the instance they belong to.

        Returns ``(queryset, tag-key getter, instance-key getter, False,
        cache name)``.

        Raises ``ValueError`` if a custom ``queryset`` is supplied.
        """
        if queryset is not None:
            raise ValueError("Custom queryset can't be used for this lookup.")

        instance = instances[0]
        from django.db import connections
        db = self._db or router.db_for_read(instance.__class__, instance=instance)

        # Generic through models point at objects via ``object_id``; others
        # via a ``content_object`` foreign key.
        fieldname = ('object_id' if issubclass(self.through, GenericTaggedItemBase)
                     else 'content_object')
        fk = self.through._meta.get_field(fieldname)
        query = {
            '%s__%s__in' % (self.through.tag_relname(), fk.name):
                set(obj._get_pk_val() for obj in instances)
        }
        join_table = self.through._meta.db_table
        source_col = fk.column
        connection = connections[db]
        qn = connection.ops.quote_name
        qs = self.get_queryset(query).using(db).extra(
            select={
                '_prefetch_related_val': '%s.%s' % (qn(join_table), qn(source_col))
            }
        )
        return (qs,
                attrgetter('_prefetch_related_val'),
                lambda obj: obj._get_pk_val(),
                False,
                self.prefetch_cache_name)

    # Django < 1.6 uses the previous name of query_set
    get_query_set = get_queryset
    get_prefetch_query_set = get_prefetch_queryset

    def _lookup_kwargs(self):
        """Return the through-model filter kwargs selecting the bound instance's rows."""
        return self.through.lookup_kwargs(self.instance)

    @require_instance_manager
    def add(self, *tags):
        """
        Attach ``tags`` to the bound instance.

        Each item may be a tag model instance or a tag name string. Names with
        no matching tag row are created. Through rows are fetched with
        ``get_or_create``, so tags already attached are not duplicated.

        Raises ``ValueError`` for items of any other type.
        """
        tag_model = self.through.tag_model()
        str_tags = set()
        tag_objs = set()
        for t in tags:
            if isinstance(t, tag_model):
                tag_objs.add(t)
            elif isinstance(t, six.string_types):
                str_tags.add(t)
            else:
                # Name the expected tag class itself; ``type(tag_model)``
                # would only report its metaclass.
                raise ValueError("Cannot add {0} ({1}). Expected {2} or str.".format(
                    t, type(t), tag_model))

        # If str_tags has 0 elements Django actually optimizes that to not do a
        # query.
        existing = tag_model.objects.filter(
            name__in=str_tags
        )
        tag_objs.update(existing)
        # ``existing`` was evaluated just above, so this reuses its result cache.
        existing_names = set(t.name for t in existing)
        for new_tag in str_tags - existing_names:
            tag_objs.add(tag_model.objects.create(name=new_tag))

        lookup_kwargs = self._lookup_kwargs()
        for tag in tag_objs:
            self.through.objects.get_or_create(tag=tag, **lookup_kwargs)

    @require_instance_manager
    def names(self):
        """Return a flat ``values_list`` of the bound instance's tag names."""
        return self.get_queryset().values_list('name', flat=True)

    @require_instance_manager
    def slugs(self):
        """Return a flat ``values_list`` of the bound instance's tag slugs."""
        return self.get_queryset().values_list('slug', flat=True)

    @require_instance_manager
    def set(self, *tags):
        """Replace all of the bound instance's tags with ``tags``."""
        self.clear()
        self.add(*tags)

    @require_instance_manager
    def remove(self, *tags):
        """Delete the bound instance's through rows whose tag name is in ``tags``."""
        self.through.objects.filter(**self._lookup_kwargs()).filter(
            tag__name__in=tags).delete()

    @require_instance_manager
    def clear(self):
        """Delete every through row of the bound instance."""
        self.through.objects.filter(**self._lookup_kwargs()).delete()

    def most_common(self):
        """
        Return the tags annotated with ``num_times``.

        ``num_times`` is a ``Count`` over the through relation, and the tags
        are ordered by it, descending.
        """
        return self.get_queryset().annotate(
            num_times=models.Count(self.through.tag_relname())
        ).order_by('-num_times')

    @require_instance_manager
    def similar_objects(self):
        """
        Return other objects that share tags with the bound instance.

        Objects are ordered by the number of shared tags, descending. Each
        returned object gets a ``similar_tags`` attribute holding that count.
        """
        lookup_kwargs = self._lookup_kwargs()
        lookup_keys = sorted(lookup_kwargs)
        qs = self.through.objects.values(*six.iterkeys(lookup_kwargs))
        qs = qs.annotate(n=models.Count('pk'))
        qs = qs.exclude(**lookup_kwargs)
        qs = qs.filter(tag__in=self.all())
        qs = qs.order_by('-n')

        # TODO: This all feels like a bit of a hack.
        items = {}
        if len(lookup_keys) == 1:
            # A single lookup key: a plain foreign key from the through model.
            # Read the values under the actual key name instead of assuming
            # it is "content_object".
            # Can we do this without a second query by using a select_related()
            # somehow?
            key = lookup_keys[0]
            f = _get_field(self.through, key)
            objs = f.rel.to._default_manager.filter(**{
                "%s__in" % f.rel.field_name: [r[key] for r in qs]
            })
            for obj in objs:
                items[(getattr(obj, f.rel.field_name),)] = obj
        else:
            # Generic through model: group object ids per content type, then
            # load each model's objects with one query.
            preload = {}
            for result in qs:
                preload.setdefault(result['content_type'], set()).add(result["object_id"])
            for ct, obj_ids in preload.items():
                ct = ContentType.objects.get_for_id(ct)
                for obj in ct.model_class()._default_manager.filter(pk__in=obj_ids):
                    items[(ct.pk, obj.pk)] = obj

        results = []
        for result in qs:
            obj = items[
                tuple(result[k] for k in lookup_keys)
            ]
            obj.similar_tags = result["n"]
            results.append(obj)
        return results

    # _TaggableManager needs to be hashable but BaseManagers in Django 1.8+ overrides
    # the __eq__ method which makes the default __hash__ method disappear.
    # This checks if the __hash__ attribute is None, and if so, it reinstates the original method.
    if models.Manager.__hash__ is None:
        __hash__ = object.__hash__
class TaggableManager(RelatedField, Field):
    """
    Model field that exposes a model's tags through a manager.

    The field has no column of its own (``db_type`` returns ``None``). Tags
    are stored via the ``through`` model, which defaults to ``TaggedItem``.
    Reading the attribute returns ``self.manager`` (``_TaggableManager`` by
    default) bound to the instance.
    """

    # Field flags
    many_to_many = True
    many_to_one = False
    one_to_many = False
    one_to_one = False

    _related_name_counter = 0

    def __init__(self, verbose_name=_("Tags"),
                 help_text=_("A comma-separated list of tags."),
                 through=None, blank=False, related_name=None, to=None,
                 manager=_TaggableManager):
        self.through = through or TaggedItem
        self.swappable = False
        self.manager = manager

        rel = TaggableRel(self, related_name, self.through, to=to)

        Field.__init__(
            self,
            verbose_name=verbose_name,
            help_text=help_text,
            blank=blank,
            null=True,
            serialize=False,
            rel=rel,
        )
        # NOTE: `to` is stored on ``rel.to``. When it is omitted,
        # ``post_through_setup`` fills it in from the through model's ``tag``
        # field. When it is a string it is resolved lazily in
        # ``contribute_to_class``.

    def __get__(self, instance, model):
        """
        Return a manager bound to ``instance``, or unbound when accessed on the class.

        Raises ``ValueError`` if ``instance`` has no primary key yet.
        """
        if instance is not None and instance.pk is None:
            raise ValueError("%s objects need to have a primary key value "
                             "before you can access their tags." % model.__name__)
        manager = self.manager(
            through=self.through,
            model=model,
            instance=instance,
            prefetch_cache_name=self.name
        )
        return manager

    def deconstruct(self):
        """
        Deconstruct the object, used with migrations.
        """
        name, path, args, kwargs = super(TaggableManager, self).deconstruct()
        # Remove forced kwargs. ``pop`` tolerates a parent implementation
        # that already omitted them.
        for kwarg in ('serialize', 'null'):
            kwargs.pop(kwarg, None)
        # Add arguments related to relations.
        # Ref: https://github.com/alex/django-taggit/issues/206#issuecomment-37578676
        if isinstance(self.rel.through, six.string_types):
            kwargs['through'] = self.rel.through
        elif not self.rel.through._meta.auto_created:
            kwargs['through'] = "%s.%s" % (self.rel.through._meta.app_label, self.rel.through._meta.object_name)

        if isinstance(self.rel.to, six.string_types):
            kwargs['to'] = self.rel.to
        else:
            kwargs['to'] = '%s.%s' % (self.rel.to._meta.app_label, self.rel.to._meta.object_name)

        return name, path, args, kwargs

    def contribute_to_class(self, cls, name):
        """
        Install the field on ``cls`` under ``name``.

        String references to the tag model (``rel.to``) or to the through
        model are resolved lazily. Once the through model is known,
        ``post_through_setup`` runs.
        """
        if VERSION < (1, 7):
            self.name = self.column = self.attname = name
        else:
            self.set_attributes_from_name(name)
        self.model = cls

        cls._meta.add_field(self)
        setattr(cls, name, self)
        if not cls._meta.abstract:
            if isinstance(self.rel.to, six.string_types):
                def resolve_related_class(field, model, cls):
                    field.rel.to = model
                add_lazy_relation(cls, self, self.rel.to, resolve_related_class)
            if isinstance(self.through, six.string_types):
                def resolve_related_class(field, model, cls):
                    self.through = model
                    self.rel.through = model
                    self.post_through_setup(cls)
                add_lazy_relation(
                    cls, self, self.through, resolve_related_class
                )
            else:
                self.post_through_setup(cls)

    def get_internal_type(self):
        return 'ManyToManyField'

    def __lt__(self, other):
        """
        Required by contribute_to_class: Django uses bisect for ordered
        class contribution, and bisect requires an orderable type on py3.
        """
        return False

    def post_through_setup(self, cls):
        """
        Finish configuring the field once ``self.through`` is a model class.

        This step:
        - decides whether the through model is generic (``use_gfk``);
        - fills in ``rel.to`` from the through model's ``tag`` field when unset;
        - adds a ``tagged_items`` ``GenericRelation`` for generic through models.

        Raises ``ValueError`` if another ``TaggableManager`` on ``cls`` uses
        the same through model.
        """
        if RelatedObject is not None:  # Django < 1.8
            self.related = RelatedObject(cls, self.model, self)

        self.use_gfk = (
            self.through is None or issubclass(self.through, GenericTaggedItemBase)
        )
        if not self.rel.to:
            self.rel.to = self.through._meta.get_field("tag").rel.to
        if RelatedObject is not None:  # Django < 1.8
            self.related = RelatedObject(self.through, cls, self)
        if self.use_gfk:
            tagged_items = GenericRelation(self.through)
            tagged_items.contribute_to_class(cls, 'tagged_items')

        for rel in cls._meta.local_many_to_many:
            if rel == self or not isinstance(rel, TaggableManager):
                continue
            if rel.through == self.through:
                raise ValueError('You can\'t have two TaggableManagers with the'
                                 ' same through model.')

    def save_form_data(self, instance, value):
        """Replace ``instance``'s tags with the items in ``value``."""
        getattr(instance, self.name).set(*value)

    def formfield(self, form_class=TagField, **kwargs):
        """Return a ``form_class`` (``TagField``) form field; ``kwargs`` override defaults."""
        defaults = {
            "label": capfirst(self.verbose_name),
            "help_text": self.help_text,
            "required": not self.blank
        }
        defaults.update(kwargs)
        return form_class(**defaults)

    def value_from_object(self, instance):
        """Return ``instance``'s through rows, or an empty queryset when its pk is falsy."""
        if instance.pk:
            return self.through.objects.filter(**self.through.lookup_kwargs(instance))
        return self.through.objects.none()

    def related_query_name(self):
        return _model_name(self.model)

    def m2m_reverse_name(self):
        return _get_field(self.through, 'tag').column

    def m2m_reverse_field_name(self):
        return _get_field(self.through, 'tag').name

    def m2m_target_field_name(self):
        return self.model._meta.pk.name

    def m2m_reverse_target_field_name(self):
        return self.rel.to._meta.pk.name

    def m2m_column_name(self):
        if self.use_gfk:
            return self.through._meta.virtual_fields[0].fk_field
        return self.through._meta.get_field('content_object').column

    def db_type(self, connection=None):
        # No column on the model's own table.
        return None

    def m2m_db_table(self):
        return self.through._meta.db_table

    def bulk_related_objects(self, new_objs, using):
        return []

    def extra_filters(self, pieces, pos, negate):
        """
        Return extra ``(lookup, value)`` filters limiting generic tagged
        items to the content types of ``self.model`` and its subclasses.

        Returns nothing when ``negate`` is true or the through model is not
        generic.
        """
        if negate or not self.use_gfk:
            return []
        prefix = "__".join(["tagged_items"] + pieces[:pos - 2])
        get = ContentType.objects.get_for_model
        cts = [get(obj) for obj in _get_subclasses(self.model)]
        if len(cts) == 1:
            return [("%s__content_type" % prefix, cts[0])]
        return [("%s__content_type__in" % prefix, cts)]

    def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias):
        """
        Return ``(sql, params)`` restricting the join to the content types
        of ``self.model`` and its subclasses.

        ``sql`` starts with ``" AND "``. The restriction applies to
        ``rhs_alias`` when it equals the through table's default alias,
        otherwise to ``lhs_alias``.
        """
        model_name = _model_name(self.through)
        if rhs_alias == '%s_%s' % (self.through._meta.app_label, model_name):
            alias_to_join = rhs_alias
        else:
            alias_to_join = lhs_alias
        extra_col = _get_field(self.through, 'content_type').column
        content_type_ids = [ContentType.objects.get_for_model(subclass).pk for
                            subclass in _get_subclasses(self.model)]
        if len(content_type_ids) == 1:
            content_type_id = content_type_ids[0]
            extra_where = " AND %s.%s = %%s" % (qn(alias_to_join),
                                                qn(extra_col))
            params = [content_type_id]
        else:
            extra_where = " AND %s.%s IN (%s)" % (qn(alias_to_join),
                                                  qn(extra_col),
                                                  ','.join(['%s'] *
                                                           len(content_type_ids)))
            params = content_type_ids
        return extra_where, params

    # This and all the methods till the end of class are only used in django >= 1.6
    def _get_mm_case_path_info(self, direct=False):
        """Join path through the ``content_object`` and ``tag`` foreign keys (non-generic through)."""
        pathinfos = []
        linkfield1 = _get_field(self.through, 'content_object')
        linkfield2 = _get_field(self.through, self.m2m_reverse_field_name())
        if direct:
            join1infos = linkfield1.get_reverse_path_info()
            join2infos = linkfield2.get_path_info()
        else:
            join1infos = linkfield2.get_reverse_path_info()
            join2infos = linkfield1.get_path_info()
        pathinfos.extend(join1infos)
        pathinfos.extend(join2infos)
        return pathinfos

    def _get_gfk_case_path_info(self, direct=False):
        """Join path through ``object_id`` and the ``tag`` foreign key (generic through)."""
        pathinfos = []
        from_field = self.model._meta.pk
        opts = self.through._meta
        object_id_field = _get_field(self.through, 'object_id')
        linkfield = _get_field(self.through, self.m2m_reverse_field_name())
        if direct:
            join1infos = [PathInfo(self.model._meta, opts, [from_field], self.rel, True, False)]
            join2infos = linkfield.get_path_info()
        else:
            join1infos = linkfield.get_reverse_path_info()
            join2infos = [PathInfo(opts, self.model._meta, [object_id_field], self, True, False)]
        pathinfos.extend(join1infos)
        pathinfos.extend(join2infos)
        return pathinfos

    def get_path_info(self):
        if self.use_gfk:
            return self._get_gfk_case_path_info(direct=True)
        else:
            return self._get_mm_case_path_info(direct=True)

    def get_reverse_path_info(self):
        if self.use_gfk:
            return self._get_gfk_case_path_info(direct=False)
        else:
            return self._get_mm_case_path_info(direct=False)

    def get_joining_columns(self, reverse_join=False):
        """
        Return the column pairs joining the through table's ``object_id``
        with the model's primary-key column.
        """
        # Use the real pk column instead of assuming it is named "id", so
        # models with a custom primary key join correctly.
        pk_column = self.model._meta.pk.column
        if reverse_join:
            return ((pk_column, "object_id"),)
        else:
            return (("object_id", pk_column),)

    def get_extra_restriction(self, where_class, alias, related_alias):
        """Restrict ``related_alias``'s content type column to ``self.model`` and its subclasses."""
        extra_col = _get_field(self.through, 'content_type').column
        content_type_ids = [ContentType.objects.get_for_model(subclass).pk
                            for subclass in _get_subclasses(self.model)]
        return ExtraJoinRestriction(related_alias, extra_col, content_type_ids)

    def get_reverse_joining_columns(self):
        return self.get_joining_columns(reverse_join=True)

    @property
    def related_fields(self):
        return [(_get_field(self.through, 'object_id'), self.model._meta.pk)]

    @property
    def foreign_related_fields(self):
        return [self.related_fields[0][1]]
def _get_subclasses(model):
    """
    Return ``[model]`` followed, recursively, by every model reached through
    a one-to-one reverse relation whose ``rel.parent_link`` is set.

    These are presumably multi-table-inheritance children. On Django < 1.8
    the reverse relations are ``RelatedObject`` instances; on newer versions
    they are ``OneToOneRel``.
    """
    if VERSION < (1, 8):
        fields = (_get_field(model, fname) for fname in model._meta.get_all_field_names())
    else:
        fields = model._meta.get_fields()

    found = [model]
    for field in fields:
        if RelatedObject:
            # < Django 1.8
            is_candidate = isinstance(field, RelatedObject)
        else:
            # Django 1.8 +
            is_candidate = isinstance(field, OneToOneRel)
        if not is_candidate or not getattr(field.field.rel, "parent_link", None):
            continue
        child = field.model if RelatedObject else field.related_model
        found.extend(_get_subclasses(child))
    return found
# ``TaggableManager`` defines ``__lt__`` so Django can bisect fields during
# class contribution. On Python 3, ``total_ordering`` derives the remaining
# rich comparisons from it.
# The standard library's ``functools.total_ordering`` is used directly rather
# than Django's compatibility re-export. The Python 3 guard is kept so that
# Python 2 behaviour is unchanged.
if six.PY3:
    from functools import total_ordering
    TaggableManager = total_ordering(TaggableManager)
Jump to Line
Something went wrong with that request. Please try again.