diff --git a/zinnia/comparison.py b/zinnia/comparison.py index 9557338dc..39120addb 100644 --- a/zinnia/comparison.py +++ b/zinnia/comparison.py @@ -6,9 +6,12 @@ from django.utils import six from django.core.cache import caches from django.utils.html import strip_tags +from django.utils.functional import cached_property from django.core.cache import InvalidCacheBackendError +from zinnia.models.entry import Entry from zinnia.settings import STOP_WORDS +from zinnia.settings import COMPARISON_FIELDS PUNCTUATION = dict.fromkeys( @@ -55,13 +58,18 @@ class VectorBuilder(object): """ Build a list of vectors based on datasets. """ + fields = None + queryset = None - def __init__(self, queryset, fields): - self.clustered_model = ClusteredModel(queryset, fields) + def __init__(self, **kwargs): + self.fields = kwargs.pop('fields', self.fields) + self.queryset = kwargs.pop('queryset', self.queryset) + self.clustered_model = ClusteredModel(self.queryset, self.fields) - def build_dataset(self): + @cached_property + def columns_dataset(self): """ - Generate the whole dataset. + Generate the columns and the whole dataset. """ data = {} words_total = {} @@ -83,30 +91,45 @@ def build_dataset(self): for word in columns] return columns, dataset - def columns_dataset(self): - """ - Cache system for columns and dataset. - """ - cache = get_comparison_cache() - columns_dataset = cache.get('vectors') - if not columns_dataset: - columns_dataset = self.build_dataset() - cache.set('vectors', columns_dataset) - return columns_dataset - @property def columns(self): """ Access to columns. """ - return self.columns_dataset()[0] + return self.columns_dataset[0] @property def dataset(self): """ Access to dataset. """ - return self.columns_dataset()[1] + return self.columns_dataset[1] + + +class CachedVectorBuilder(VectorBuilder): + """ + Cached version of VectorBuilder. + """ + + @property + def columns_dataset(self): + """ + Implement high level cache system for columns and dataset. + """ + cache = get_comparison_cache() + columns_dataset = cache.get('vectors') + if not columns_dataset: + columns_dataset = super(CachedVectorBuilder, self).columns_dataset + cache.set('vectors', columns_dataset) + return columns_dataset + + +class EntryPublishedVectorBuilder(CachedVectorBuilder): + """ + Vector builder for published entries. + """ + queryset = Entry.published + fields = COMPARISON_FIELDS def pearson_score(list1, list2): diff --git a/zinnia/templatetags/zinnia.py b/zinnia/templatetags/zinnia.py index e603d2111..cc534342a 100644 --- a/zinnia/templatetags/zinnia.py +++ b/zinnia/templatetags/zinnia.py @@ -34,16 +34,12 @@ from ..managers import tags_published from ..flags import PINGBACK, TRACKBACK from ..settings import PROTOCOL -from ..settings import COMPARISON_FIELDS -from ..comparison import VectorBuilder from ..comparison import compute_related from ..comparison import get_comparison_cache +from ..comparison import EntryPublishedVectorBuilder from ..calendar import Calendar from ..breadcrumbs import retrieve_breadcrumbs -register = Library() - -VECTORS = VectorBuilder(Entry.published, COMPARISON_FIELDS) WIDONT_REGEXP = re.compile( r'\s+(\S+\s*)$') @@ -52,6 +48,8 @@ END_PUNCTUATION_WIDONT_REGEXP = re.compile( r'\s+([?!]+\s*)$') +register = Library() + @register.inclusion_tag('zinnia/tags/dummy.html', takes_context=True) def get_categories(context, template='zinnia/tags/categories.html'): @@ -149,8 +147,9 @@ def get_similar_entries(context, number=5, cache_related = cache.get('related_entries', {}) if cache_key not in cache_related: + vectors = EntryPublishedVectorBuilder() related_entry_pks = compute_related( - entry.pk, VECTORS.dataset)[:number] + entry.pk, vectors.dataset)[:number] related_entries = sorted( Entry.objects.filter(pk__in=related_entry_pks), key=lambda x: related_entry_pks.index(x.pk)) diff --git a/zinnia/tests/test_comparison.py b/zinnia/tests/test_comparison.py index 434f41e94..7743ccfc0 100644 --- a/zinnia/tests/test_comparison.py +++ b/zinnia/tests/test_comparison.py @@ -6,7 +6,6 @@ from zinnia.comparison import compute_related from zinnia.comparison import VectorBuilder from zinnia.comparison import ClusteredModel -from zinnia.comparison import get_comparison_cache from zinnia.signals import disconnect_entry_signals @@ -33,10 +32,8 @@ def test_clustered_model(self): ' entry 2 content 2 zinnia test'])) def test_vector_builder(self): - cache = get_comparison_cache() - cache.delete('vectors') - vectors = VectorBuilder(Entry.objects.all(), - ['title', 'excerpt', 'content']) + vectors = VectorBuilder(queryset=Entry.objects.all(), + fields=['title', 'excerpt', 'content']) self.assertEqual(vectors.dataset, {}) self.assertEqual(vectors.columns, []) params = {'title': 'My entry 1', 'content': @@ -47,9 +44,8 @@ def test_vector_builder(self): 'My second entry', 'slug': 'my-entry-2'} e2 = Entry.objects.create(**params) - self.assertEqual(vectors.dataset, {}) - self.assertEqual(vectors.columns, []) - cache.delete('vectors') + vectors = VectorBuilder(queryset=Entry.objects.all(), + fields=['title', 'excerpt', 'content']) self.assertEqual(sorted(vectors.columns), sorted( ['1', '2', 'content', 'entry'])) self.assertEqual(sorted(vectors.dataset[e1.pk]), [0, 1, 1, 1]) diff --git a/zinnia/tests/test_templatetags.py b/zinnia/tests/test_templatetags.py index aef989c0d..5d317a6ab 100644 --- a/zinnia/tests/test_templatetags.py +++ b/zinnia/tests/test_templatetags.py @@ -29,7 +29,6 @@ from zinnia.signals import disconnect_entry_signals from zinnia.signals import disconnect_discussion_signals from zinnia.signals import flush_similar_cache_handler -from zinnia.signals import ENTRY_PS_FLUSH_SIMILAR_CACHE from zinnia.templatetags.zinnia import widont from zinnia.templatetags.zinnia import week_number from zinnia.templatetags.zinnia import get_authors @@ -253,7 +252,7 @@ def test_get_popular_entries(self): def test_get_similar_entries(self): post_save.connect( flush_similar_cache_handler, sender=Entry, - dispatch_uid=ENTRY_PS_FLUSH_SIMILAR_CACHE) + dispatch_uid='flush_cache') self.publish_entry() source_context = Context({'object': self.entry}) with self.assertNumQueries(0): @@ -285,19 +284,16 @@ def test_get_similar_entries(self): third_entry = Entry.objects.create(**params) third_entry.sites.add(self.site) - source_context = Context({'entry': second_entry}) with self.assertNumQueries(2): context = get_similar_entries(source_context, 3, 'custom_template.html') self.assertEqual(len(context['entries']), 2) - self.assertEqual(context['entries'][0].pk, third_entry.pk) + self.assertEqual(context['entries'][0].pk, second_entry.pk) self.assertEqual(context['template'], 'custom_template.html') with self.assertNumQueries(0): - context = get_similar_entries(source_context, 3, - 'custom_template.html') + context = get_similar_entries(source_context, 3) post_save.disconnect( - sender=Entry, - dispatch_uid=ENTRY_PS_FLUSH_SIMILAR_CACHE) + sender=Entry, dispatch_uid='flush_cache') def test_get_archives_entries(self): with self.assertNumQueries(0):