diff --git a/api/tests/test_views.py b/api/tests/test_views.py
index 74cb3156..c091964c 100644
--- a/api/tests/test_views.py
+++ b/api/tests/test_views.py
@@ -1,10 +1,10 @@
-import io
 import json
-import csv
 from datetime import timedelta
 
 from django.test import TestCase, RequestFactory, mock
 from django.contrib.auth import get_user_model
+from django.core import exceptions
+from django.http import HttpResponse
 
 from rest_framework import exceptions
 
@@ -17,7 +17,7 @@
     ScoreSetFactory, ExperimentFactory, ExperimentSetFactory
 )
 
-from variant.factories import dna_hgvs, protein_hgvs
+from variant.factories import VariantFactory
 from variant.models import Variant
 
 from .. import views
 
@@ -268,6 +268,148 @@ def test_list_includes_private_when_authenticated(self):
         self.assertContains(response, instance2.urn)
 
 
+class TestFormatCSVRows(TestCase):
+    def test_dicts_include_urn(self):
+        vs = [VariantFactory() for _ in range(5)]
+        rows = views.format_csv_rows(
+            vs, columns=['score', 'urn', ],
+            dtype=constants.variant_score_data
+        )
+        for v, row in zip(vs, rows):
+            self.assertEqual(v.urn, row['urn'])
+
+    def test_dicts_include_nt_hgvs(self):
+        vs = [VariantFactory() for _ in range(5)]
+        rows = views.format_csv_rows(
+            vs, columns=['score', constants.hgvs_nt_column, ],
+            dtype=constants.variant_score_data
+        )
+        for v, row in zip(vs, rows):
+            self.assertEqual(v.hgvs_nt, row[constants.hgvs_nt_column])
+
+    def test_dicts_include_pro_hgvs(self):
+        vs = [VariantFactory() for _ in range(5)]
+        rows = views.format_csv_rows(
+            vs, columns=['score', constants.hgvs_pro_column, ],
+            dtype=constants.variant_score_data
+        )
+        for v, row in zip(vs, rows):
+            self.assertEqual(v.hgvs_pro, row[constants.hgvs_pro_column])
+
+    def test_dicts_include_data_columns_as_strings(self):
+        vs = [VariantFactory(data={
+            constants.variant_score_data: {'score': 1, 'se': 2.12}})
+            for _ in range(5)
+        ]
+        rows = views.format_csv_rows(
+            vs, columns=['score', 'se', ],
+            dtype=constants.variant_score_data
+        )
+        for v, row in zip(vs, rows):
+            self.assertEqual('1', row['score'])
+            self.assertEqual('2.12', row['se'])
+
+
+class TestValidateRequest(TestCase):
+    def setUp(self):
+        self.factory = RequestFactory()
+        self.user = UserFactory()
+        self.request = self.factory.get('/')
+        self.request.user = self.user
+        self.instance = ScoreSetFactory(private=True)
+
+    def test_returns_404_response_when_urn_model_not_found(self):
+        response = views.validate_request(self.request, 'urn')
+        self.assertEqual(response.status_code, 404)
+
+    @mock.patch('api.views.authenticate', return_value=(None, None))
+    def test_calls_authenticate(self, patch):
+        views.validate_request(self.request, self.instance.urn)
+        patch.assert_called()
+
+    @mock.patch('api.views.check_permission')
+    def test_calls_check_permission(self, patch):
+        views.validate_request(self.request, self.instance.urn)
+        patch.assert_called()
+
+
+class TestFormatResponse(TestCase):
+    def setUp(self):
+        self.user = UserFactory()
+        self.instance = ScoreSetFactory(private=True)
+        self.response = HttpResponse(content_type='text/csv')
+
+    def test_adds_comments_to_response(self):
+        response = views.format_response(
+            self.response, self.instance, dtype='scores')
+        content = response.content.decode()
+        self.assertIn("# URN: {}".format(self.instance.urn), content)
+        self.assertIn("# Downloaded (UTC):", content)
+        self.assertIn("# Licence: {}".format(
+            self.instance.licence.long_name), content)
+        self.assertIn("# Licence URL: {}".format(
+            self.instance.licence.link), content)
+
+    def test_raises_valueerror_unknown_dtype(self):
+        with self.assertRaises(ValueError):
+            views.format_response(self.response, self.instance, dtype='---')
+
+    @mock.patch("api.views.format_csv_rows")
+    def test_calls_format_csv_correct_call_dtype_is_scores(self, patch):
+        self.instance.dataset_columns = {
+            constants.score_columns: ['score', 'se']}
+        self.instance.save()
+        for i in range(5):
+            data = {constants.variant_score_data: {'score': i, 'se': 2*i}}
+            VariantFactory(scoreset=self.instance, data=data)
+
+        _ = views.format_response(
+            self.response, self.instance, dtype='scores')
+
+        called_dtype = patch.call_args[1]['dtype']
+        called_columns = patch.call_args[1]['columns']
+        expected_columns = ['urn'] + self.instance.score_columns
+        self.assertEqual(called_dtype, constants.variant_score_data)
+        self.assertListEqual(called_columns, expected_columns)
+
+    @mock.patch("api.views.format_csv_rows")
+    def test_calls_format_csv_correct_call_dtype_is_counts(self, patch):
+        self.instance.dataset_columns = {
+            constants.count_columns: ['count', 'se']}
+        self.instance.save()
+        for i in range(5):
+            data = {constants.variant_count_data: {'count': i, 'se': 2*i}}
+            VariantFactory(scoreset=self.instance, data=data)
+
+        _ = views.format_response(
+            self.response, self.instance, dtype='counts')
+
+        called_dtype = patch.call_args[1]['dtype']
+        called_columns = patch.call_args[1]['columns']
+        expected_columns = ['urn'] + self.instance.count_columns
+        self.assertEqual(called_dtype, constants.variant_count_data)
+        self.assertListEqual(called_columns, expected_columns)
+
+    @mock.patch("api.views.format_csv_rows")
+    def test_returns_empty_csv_when_no_additional_columns_present(self, patch):
+        _ = views.format_response(
+            self.response, self.instance, dtype='scores')
+        patch.assert_not_called()
+
+    def test_double_quotes_column_values_containing_commas(self):
+        self.instance.dataset_columns = {
+            constants.score_columns: ['hello,world', ]}
+
+        for i in range(5):
+            data = {constants.variant_score_data: {'hello,world': i}}
+            VariantFactory(scoreset=self.instance, data=data)
+
+        response = views.format_response(
+            self.response, self.instance, dtype='scores')
+        content = response.content.decode()
+        self.assertIn('"hello,world"', content)
+
+
 class TestScoreSetAPIViews(TestCase):
     factory = ScoreSetFactory
     url = 'scoresets'
@@ -388,183 +530,24 @@ def test_OK_private_download_meta_when_authenticated(self):
 
     def test_404_not_found(self):
         response = self.client.get("/api/scoresets/dddd/")
         self.assertEqual(response.status_code, 404)
-
-    def test_can_download_scores(self):
-        scs = self.factory()
-        scs = publish_dataset(scs)
-        scs.refresh_from_db()
-        scs.dataset_columns = {
-            constants.score_columns: ["score"],
-            constants.count_columns: ["count"]
-        }
-        scs.save()
-        variant = Variant.objects.create(
-            hgvs_nt=dna_hgvs[0], hgvs_pro=protein_hgvs[0],
-            scoreset=scs, data={
-                constants.variant_score_data: {"score": "1"},
-                constants.variant_count_data: {"count": "1"}
-            }
-        )
-
-        response = self.client.get("/api/scoresets/{}/scores/".format(scs.urn))
-        rows = list(
-            csv.reader(
-                io.TextIOWrapper(
-                    io.BytesIO(response.content), encoding='utf-8')))
-
-        header = [constants.hgvs_nt_column, constants.hgvs_pro_column, 'score']
-        data = [variant.hgvs_nt, variant.hgvs_pro, '1']
-        self.assertEqual(rows, [header, data])
-
-    def test_comma_in_value_enclosed_by_quotes(self):
-        scs = self.factory()
-        scs = publish_dataset(scs)
-        scs.refresh_from_db()
-        scs.dataset_columns = {
-            constants.score_columns: ["score"],
-            constants.count_columns: ["count,count"]
-        }
-        scs.save(save_parents=True)
-        variant = Variant.objects.create(
-            hgvs_nt=dna_hgvs[0], hgvs_pro=protein_hgvs[0],
-            scoreset=scs, data={
-                constants.variant_score_data: {"score": "1"},
-                constants.variant_count_data: {"count,count": "4"}
-            }
-        )
-        response = self.client.get("/api/scoresets/{}/counts/".format(scs.urn))
-        rows = list(
-            csv.reader(
-                io.TextIOWrapper(
-                    io.BytesIO(response.content), encoding='utf-8')))
-
-        header = [constants.hgvs_nt_column, constants.hgvs_pro_column, 'count,count']
-        data = [variant.hgvs_nt, variant.hgvs_pro, '4']
-        self.assertEqual(rows, [header, data])
-
-    def test_can_download_counts(self):
-        scs = self.factory()
-        scs = publish_dataset(scs)
-        scs.refresh_from_db()
-        scs.dataset_columns = {
-            constants.score_columns: ["score"],
-            constants.count_columns: ["count"]
-        }
-        scs.save(save_parents=True)
-        variant = Variant.objects.create(
-            hgvs_nt=dna_hgvs[0], hgvs_pro=protein_hgvs[0],
-            scoreset=scs, data={
-                constants.variant_score_data: {"score": "1"},
-                constants.variant_count_data: {"count": "4"}
-            }
-        )
-        response = self.client.get("/api/scoresets/{}/counts/".format(scs.urn))
-        rows = list(
-            csv.reader(
-                io.TextIOWrapper(
-                    io.BytesIO(response.content), encoding='utf-8')))
-
-        header = [constants.hgvs_nt_column, constants.hgvs_pro_column, 'count']
-        data = [variant.hgvs_nt, variant.hgvs_pro, '4']
-        self.assertEqual(rows, [header, data])
-
-    def test_none_hgvs_written_as_blank(self):
-        scs = self.factory()
-        scs = publish_dataset(scs)
-        scs.refresh_from_db()
-        scs.dataset_columns = {
-            constants.score_columns: ["score"],
-            constants.count_columns: ["count"]
-        }
-        scs.save(save_parents=True)
-        variant = Variant.objects.create(
-            hgvs_nt=dna_hgvs[0], hgvs_pro=None,
-            scoreset=scs,
-            data={
-                constants.variant_score_data: {"score": "1"},
-                constants.variant_count_data: {"count": "4"}
-            }
-        )
-        response = self.client.get("/api/scoresets/{}/scores/".format(scs.urn))
-        rows = list(
-            csv.reader(
-                io.TextIOWrapper(
-                    io.BytesIO(response.content), encoding='utf-8')))
-
-        header = [constants.hgvs_nt_column, constants.hgvs_pro_column, 'score']
-        data = [variant.hgvs_nt, '', '1']
-        self.assertEqual(rows, [header, data])
-
-    def test_no_variants_empty_file(self):
-        scs = self.factory()
-        scs = publish_dataset(scs)
-        scs.refresh_from_db()
-        scs.dataset_columns = {
-            constants.score_columns: ["score"],
-            constants.count_columns: ["count"]
-        }
-        scs.save(save_parents=True)
-        scs.children.delete()
-
-        response = self.client.get("/api/scoresets/{}/scores/".format(scs.urn))
-        rows = list(
-            csv.reader(
-                io.TextIOWrapper(
-                    io.BytesIO(response.content), encoding='utf-8')))
-        self.assertEqual(rows, [])
-
-        response = self.client.get("/api/scoresets/{}/counts/".format(scs.urn))
-        rows = list(
-            csv.reader(
-                io.TextIOWrapper(
-                    io.BytesIO(response.content), encoding='utf-8')))
-        self.assertEqual(rows, [])
-
-    def test_empty_scores_returns_empty_file(self):
-        scs = self.factory()
-        scs = publish_dataset(scs)
-        scs.refresh_from_db()
-        scs.dataset_columns = {
-            constants.score_columns: [],
-            constants.count_columns: ['count']
-        }
-        scs.save(save_parents=True)
-        _ = Variant.objects.create(
-            hgvs_nt=dna_hgvs[0], hgvs_pro=protein_hgvs[0],
-            scoreset=scs, data={
-                constants.variant_score_data: {},
-                constants.variant_count_data: {"count": "4"}
-            }
-        )
-        response = self.client.get("/api/scoresets/{}/scores/".format(scs.urn))
-        rows = list(
-            csv.reader(
-                io.TextIOWrapper(
-                    io.BytesIO(response.content), encoding='utf-8')))
-        self.assertEqual(rows, [])
-
-    def test_empty_counts_returns_empty_file(self):
-        scs = self.factory()
-        scs = publish_dataset(scs)
-        scs.refresh_from_db()
-        scs.dataset_columns = {
-            constants.score_columns: ["score"],
-            constants.count_columns: []
-        }
-        scs.save(save_parents=True)
-        _ = Variant.objects.create(
-            hgvs_nt=dna_hgvs[0], hgvs_pro=protein_hgvs[0],
-            scoreset=scs, data={
-                constants.variant_score_data: {"score": "1"},
-                constants.variant_count_data: {}
-            }
-        )
-        response = self.client.get("/api/scoresets/{}/counts/".format(scs.urn))
-        rows = list(
-            csv.reader(
-                io.TextIOWrapper(
-                    io.BytesIO(response.content), encoding='utf-8')))
-        self.assertEqual(rows, [])
+
+    @mock.patch("api.views.format_response")
+    def test_calls_format_response_with_dtype_scores(self, patch):
+        request = RequestFactory().get('/')
+        request.user = UserFactory()
+        instance = self.factory(private=False)
+        instance.add_viewers(request.user)
+        views.scoreset_score_data(request, instance.urn)
+        self.assertEqual(patch.call_args[1]['dtype'], 'scores')
+
+    @mock.patch("api.views.format_response")
+    def test_calls_format_response_with_dtype_counts(self, patch):
+        request = RequestFactory().get('/')
+        request.user = UserFactory()
+        instance = self.factory(private=False)
+        instance.add_viewers(request.user)
+        views.scoreset_count_data(request, instance.urn)
+        self.assertEqual(patch.call_args[1]['dtype'], 'counts')
 
     def test_can_download_metadata(self):
         scs = self.factory(private=False)
diff --git a/api/views.py b/api/views.py
index f26cd7b5..a628736f 100644
--- a/api/views.py
+++ b/api/views.py
@@ -1,4 +1,6 @@
+import string
 import csv
+from datetime import datetime
 
 from rest_framework import viewsets, exceptions
 
@@ -59,6 +61,9 @@ def check_permission(instance, user=None):
 #                        ViewSet CBVs for list/detail views
 # --------------------------------------------------------------------------- #
 class AuthenticatedViewSet(viewsets.ReadOnlyModelViewSet):
+    user = None
+    auth_token = None
+
     def dispatch(self, request, *args, **kwargs):
         try:
             self.user, self.auth_token = authenticate(request)
@@ -123,6 +128,23 @@ class UserViewset(AuthenticatedViewSet):
 #                          File download FBVs
 # --------------------------------------------------------------------------- #
 def validate_request(request, urn):
+    """
+    Validates an incoming request using the token in the auth header or checks
+    session authentication. Also checks that the urn exists.
+
+    Returns a JSON response on any error.
+
+    Parameters
+    ----------
+    request : object
+        Incoming request object.
+    urn : str
+        URN of the scoreset.
+
+    Returns
+    -------
+    `ScoreSet` or `JsonResponse`
+    """
     try:
         if not ScoreSet.objects.filter(urn=urn).count():
             raise exceptions.NotFound()
@@ -136,80 +158,121 @@ def validate_request(request, urn):
         return JsonResponse({'detail': e.detail}, status=e.status_code)
 
 
-@cache_page(60 * 15)  # 15 minute cache
-def scoreset_score_data(request, urn):
-    response = HttpResponse(content_type='text/csv')
-    response['Content-Disposition'] = \
-        'attachment; filename="{}_scores.csv"'.format(urn)
+def format_csv_rows(variants, columns, dtype):
+    """
+    Formats each variant into a dictionary row containing the keys specified
+    in `columns`.
+
+    Parameters
+    ----------
+    variants : list[`variant.models.Variant`]
+        List of variants.
+    columns : list[str]
+        Columns to serialize.
+    dtype : str
+        Either `constants.variant_score_data` or `constants.variant_count_data`.
+
+    Returns
+    -------
+    list[dict]
+    """
+    rowdicts = []
+    for variant in variants:
+        data = {}
+        for column_key in columns:
+            if column_key == constants.hgvs_nt_column:
+                data[column_key] = str(variant.hgvs_nt)
+            elif column_key == constants.hgvs_pro_column:
+                data[column_key] = str(variant.hgvs_pro)
+            elif column_key == 'urn':
+                data[column_key] = str(variant.urn)
+            else:
+                data[column_key] = str(variant.data[dtype][column_key])
+        rowdicts.append(data)
+    return rowdicts
 
-    instance_or_response = validate_request(request, urn)
-    if not isinstance(instance_or_response, ScoreSet):
-        return instance_or_response
-    scoreset = instance_or_response
-    order_by = 'id'  # scoreset.primary_hgvs_column
-    variants = scoreset.children.order_by('{}'.format(order_by))
-    columns = scoreset.score_columns
-    type_column = constants.variant_score_data
+
+def urn_number(variant):
+    number = variant.urn.split('#')[-1]
+    if not str.isdigit(number):
+        return 0
+    return int(number)
+
 
-    # hgvs_nt and hgvs_pro present by default, hence <= 2
-    if not variants or len(columns) <= 2:
+def format_response(response, scoreset, dtype):
+    """
+    Writes the CSV response by formatting each variant into a row including
+    the columns `hgvs_nt`, `hgvs_pro`, `urn` and other uploaded columns.
+
+    Parameters
+    ----------
+    response : `HttpResponse`
+        Response object to write to.
+    scoreset : `dataset.models.scoreset.ScoreSet`
+        The scoreset requested.
+    dtype : str
+        The type of data requested. Either 'scores' or 'counts'.
+
+    Returns
+    -------
+    `HttpResponse`
+    """
+    response.writelines([
+        "# URN: {}\n".format(scoreset.urn),
+        "# Downloaded (UTC): {}\n".format(datetime.utcnow()),
+        "# Licence: {}\n".format(scoreset.licence.long_name),
+        "# Licence URL: {}\n".format(scoreset.licence.link),
+    ])
+
+    variants = sorted(
+        scoreset.children.all(), key=lambda v: urn_number(v))
+
+    if dtype == 'scores':
+        columns = ['urn', ] + scoreset.score_columns
+        type_column = constants.variant_score_data
+    elif dtype == 'counts':
+        columns = ['urn', ] + scoreset.count_columns
+        type_column = constants.variant_count_data
+    else:
+        raise ValueError(
+            "Unknown variant dtype {}. Expected "
+            "either 'scores' or 'counts'.".format(dtype))
+
+    # 'hgvs_nt', 'hgvs_pro', 'urn' are present by default, hence <= 3
+    if not variants or len(columns) <= 3:
         return response
-
+    rows = format_csv_rows(variants, columns=columns, dtype=type_column)
     writer = csv.DictWriter(
         response, fieldnames=columns, quoting=csv.QUOTE_MINIMAL)
     writer.writeheader()
-    writer.writerows(_format_csv_rows(variants, columns, type_column))
+    writer.writerows(rows)
     return response
 
 
-@cache_page(60 * 15)  # 15 minute cache
-def scoreset_count_data(request, urn):
+def scoreset_score_data(request, urn):
     response = HttpResponse(content_type='text/csv')
     response['Content-Disposition'] = \
-        'attachment; filename="{}_counts.csv"'.format(urn)
-
-    instance_or_response = validate_request(request, urn)
-    if not isinstance(instance_or_response, ScoreSet):
-        return instance_or_response
-
-    scoreset = instance_or_response
-    order_by = 'id'  # scoreset.primary_hgvs_column
-    variants = scoreset.children.order_by('{}'.format(order_by))
-    columns = scoreset.count_columns
-    type_column = constants.variant_count_data
-
-    # hgvs_nt and hgvs_pro present by default, hence <= 2
-    if not variants or len(columns) <= 2:
-        return response
-
-    writer = csv.DictWriter(
-        response, fieldnames=columns, quoting=csv.QUOTE_MINIMAL)
-    writer.writeheader()
-    writer.writerows(_format_csv_rows(variants, columns, type_column))
-    return response
+        'attachment; filename="{}_scores.csv"'.format(urn)
+    scoreset = validate_request(request, urn)
+    if not isinstance(scoreset, ScoreSet):
+        return scoreset  # Invalid request, return response.
 
-
-def _format_csv_rows(variants, columns, type_column):
-    rowdicts = []
-    for variant in variants:
-        data = {}
-        for column_key in columns:
-            if column_key == constants.hgvs_nt_column:
-                data[column_key] = variant.hgvs_nt
-            elif column_key == constants.hgvs_pro_column:
-                data[column_key] = variant.hgvs_pro
-            else:
-                data[column_key] = str(variant.data[type_column][column_key])
-        rowdicts.append(data)
-    return rowdicts
+    return format_response(response, scoreset, dtype='scores')
 
 
+def scoreset_count_data(request, urn):
+    response = HttpResponse(content_type='text/csv')
+    response['Content-Disposition'] = \
+        'attachment; filename="{}_counts.csv"'.format(urn)
+    scoreset = validate_request(request, urn)
+    if not isinstance(scoreset, ScoreSet):
+        return scoreset  # Invalid request, return response.
+    return format_response(response, scoreset, dtype='counts')
 
 
-@cache_page(60 * 15)  # 24 hour cache
 def scoreset_metadata(request, urn):
     instance_or_response = validate_request(request, urn)
     if not isinstance(instance_or_response, ScoreSet):
         return instance_or_response
-
     scoreset = instance_or_response
     return JsonResponse(scoreset.extra_metadata, status=200)
diff --git a/core/tests/test_reversion.py b/core/tests/test_reversion.py
index b7450b0c..8a33bc42 100644
--- a/core/tests/test_reversion.py
+++ b/core/tests/test_reversion.py
@@ -1,8 +1,10 @@
+import json
 import reversion
 from reversion.models import Version
 
 from django.test import TestCase
 
+from main.models import Licence
 from accounts.factories import UserFactory
 
 from metadata.factories import (
@@ -18,6 +20,27 @@
 
 
 class TestVersionControl(TestCase):
+    def test_saves_licence(self):
+        l1 = Licence.get_cc0()
+        l2 = Licence.get_cc_by()
+
+        instance = factories.ScoreSetFactory()
+        instance.licence = l1
+        instance.save()
+        track_changes(instance, None)
+        version_1 = json.loads(
+            Version.objects.order_by('id').first()
+            .serialized_data)[0]['fields']
+
+        instance.licence = l2
+        instance.save()
+        track_changes(instance, None)
+        version_2 = json.loads(
+            Version.objects.order_by('id').last()
+            .serialized_data)[0]['fields']
+
+        self.assertEqual(version_1['licence'], l1.id)
+        self.assertEqual(version_2['licence'], l2.id)
 
     def test_new_version_NOT_created_when_there_are_no_changes(self):
         instance = factories.ScoreSetWithTargetFactory()
diff --git a/core/utilities/versioning.py b/core/utilities/versioning.py
index cee79f4f..83c6d987 100644
--- a/core/utilities/versioning.py
+++ b/core/utilities/versioning.py
@@ -1,6 +1,8 @@
 import reversion
 from reversion.models import Version
 
+from core.models import TimeStampedModel
+
 from django.db import transaction
 from django.db.models import Manager
 from django.utils import timezone
@@ -14,16 +16,28 @@ def track_changes(instance, user=None):
     if len(versions) < 1:
         comments.append("{} created first revision.".format(user))
     else:
-        prev_version = versions[0] # Recent version is always first
+        prev_version = versions[0]  # Recent version is always first
         for field in instance.tracked_fields():
-            p_field = prev_version.field_dict.get(field)
+            if field == 'licence':
+                p_field = prev_version.field_dict.get(
+                    '{}_{}'.format(field, 'id'))
+            else:
+                p_field = prev_version.field_dict.get(field)
+
+            # Only compare the ID fields of database model instances
             n_field = getattr(instance, field)
             if isinstance(n_field, Manager):
                 n_field = [i.id for i in n_field.all()]
+            elif isinstance(n_field, TimeStampedModel):
+                n_field = n_field.id
+
+            # Sort the id lists to compare them if applicable
+            # (for ManyToMany fields)
             if isinstance(p_field, (list, set, tuple)):
                 p_field = sorted(p_field)
             if isinstance(n_field, (list, set, tuple)):
                 n_field = sorted(n_field)
+
             if p_field != n_field:
                 comments.append(
                     "{} edited {} field {}".format(user, klass, field))
diff --git a/data/main/site_info.json b/data/main/site_info.json
index 6ed3f0ff..5489af85 100644
--- a/data/main/site_info.json
+++ b/data/main/site_info.json
@@ -6,5 +6,5 @@
   "md_privacy": "",
   "md_terms": "",
   "md_usage_guide": "",
-  "version": "1.2.4-alpha"
+  "version": "1.3.0-alpha"
 }
diff --git a/dataset/factories.py b/dataset/factories.py
index d9eda0db..304a52dd 100644
--- a/dataset/factories.py
+++ b/dataset/factories.py
@@ -11,11 +11,13 @@
 import factory.faker
 from factory.django import DjangoModelFactory
 
+from main.models import Licence
 from metadata.factories import (
     KeywordFactory, SraIdentifierFactory, DoiIdentifierFactory,
     PubmedIdentifierFactory
 )
 
+from .constants import success
 from .models.base import DatasetModel
 from .models.experimentset import ExperimentSet
 from .models.experiment import Experiment
@@ -35,6 +37,7 @@ class Meta:
     short_description = factory.faker.Faker('text', max_nb_chars=1000)
     extra_metadata = {"foo": "bar"}
     private = True
+    processing_state = success
 
     @factory.post_generation
     def keywords(self, create, extracted, **kwargs):
@@ -100,6 +103,15 @@ class Meta:
     experiment = factory.SubFactory(ExperimentFactory)
     dataset_columns = default_dataset()
     replaces = None
+    licence = None
+
+    @factory.post_generation
+    def licence(self, created, extracted, **kwargs):
+        if not created:
+            return self
+        self.licence = Licence.get_default()
+        self.licence.save()
+        return self
 
 
 class ScoreSetWithTargetFactory(ScoreSetFactory):
diff --git a/dataset/forms/scoreset.py b/dataset/forms/scoreset.py
index ad112a0c..3d29c515 100644
--- a/dataset/forms/scoreset.py
+++ b/dataset/forms/scoreset.py
@@ -489,5 +489,4 @@ def __init__(self, *args, **kwargs):
         self.fields.pop('score_data')
         self.fields.pop('count_data')
         self.fields.pop('meta_data')
-        self.fields.pop('licence')
         self.fields.pop('replaces')
diff --git a/dataset/models/scoreset.py b/dataset/models/scoreset.py
index 354562d3..3a6ff25a 100644
--- a/dataset/models/scoreset.py
+++ b/dataset/models/scoreset.py
@@ -188,7 +188,11 @@ def save(self, *args, **kwargs):
         if self.licence is None:
             self.licence = Licence.get_default()
         return super().save(*args, **kwargs)
-
+
+    @classmethod
+    def tracked_fields(cls):
+        return super().tracked_fields() + ('licence', )
+
     # Variant related methods
     # ---------------------------------------------------------------------- #
     @property
@@ -200,11 +204,11 @@ def variant_count(self):
         return self.variants.count()
 
     def delete_variants(self):
-        if self.has_variants:
-            self.variants.all().delete()
-            self.dataset_columns = default_dataset()
-            self.last_child_value = 0
-            self.save()
+        self.variants.all().delete()
+        self.dataset_columns = default_dataset()
+        self.last_child_value = 0
+        self.save()
+        return self
 
     def get_target(self):
         if not hasattr(self, 'target'):
@@ -397,7 +401,6 @@ def get_error_message(self):
         return 'An error occured during processing. Please contact support.'
 
-
 # --------------------------------------------------------------------------- #
 #                            Post Save
 # --------------------------------------------------------------------------- #
diff --git a/dataset/tests/test_forms_edit_scoreset.py b/dataset/tests/test_forms_edit_scoreset.py
index 2dc92f3f..f90ee5f5 100644
--- a/dataset/tests/test_forms_edit_scoreset.py
+++ b/dataset/tests/test_forms_edit_scoreset.py
@@ -66,14 +66,14 @@ def test_cannot_save_popped_field(self):
         obj = ScoreSetFactory(replaces=replaced)
         for i in range(5):
             VariantFactory(scoreset=obj)
+        Licence.populate()
 
         old_experiment = obj.experiment
-        old_licence = obj.licence
         old_replaces = obj.previous_version
         old_variants = obj.children
 
         data, files = self.make_post_data()
-        data['licence'] = Licence.get_cc0()
+        data['licence'] = Licence.get_cc0().pk
         data['replaces'] = ScoreSetFactory(experiment=exp).pk
         form = ScoreSetEditForm(
             data=data, files=files, user=self.user, instance=obj
@@ -83,4 +83,5 @@ def test_cannot_save_popped_field(self):
         self.assertEqual(instance.children.count(), old_variants.count())
         self.assertEqual(instance.experiment, old_experiment)
         self.assertEqual(instance.previous_version, old_replaces)
-        self.assertEqual(instance.licence, old_licence)
\ No newline at end of file
+
+        self.assertEqual(instance.licence, Licence.get_cc0())  # new Licence
diff --git a/dataset/tests/test_tasks.py b/dataset/tests/test_tasks.py
index 21c732e3..db12c9c2 100644
--- a/dataset/tests/test_tasks.py
+++ b/dataset/tests/test_tasks.py
@@ -132,10 +132,17 @@ def test_converts_nan_hgvs_to_none(self):
 
     def test_create_variants_resets_dataset_columns(self):
         self.scoreset.dataset_columns = default_dataset()
-        create_variants.apply(kwargs=self.mock_kwargs())
+        create_variants.run(**self.mock_kwargs())
         self.scoreset.refresh_from_db()
         self.assertEqual(self.scoreset.dataset_columns, self.dataset_columns)
 
+    def test_create_variants_sets_last_child_value_to_zero(self):
+        self.scoreset.last_child_value = 100
+        self.scoreset.save()
+        create_variants.run(**self.mock_kwargs())
+        self.scoreset.refresh_from_db()
+        self.assertEqual(self.scoreset.last_child_value, 1)
+
 
 class TestPublishScoresetTask(TestCase):
     def setUp(self):
@@ -176,14 +183,14 @@ def test_resets_public_to_false_if_failed(self):
         self.assertTrue(scoreset.parent.parent.private)
 
     def test_propagates_modified(self):
-        publish_scoreset.apply(kwargs=self.mock_kwargs())
+        publish_scoreset.run(**self.mock_kwargs())
         scoreset = ScoreSet.objects.first()
         self.assertEqual(scoreset.modified_by, self.user)
         self.assertEqual(scoreset.parent.modified_by, self.user)
         self.assertEqual(scoreset.parent.parent.modified_by, self.user)
 
     def test_propagates_public(self):
-        publish_scoreset.apply(kwargs=self.mock_kwargs())
+        publish_scoreset.run(**self.mock_kwargs())
         scoreset = ScoreSet.objects.first()
         self.assertEqual(scoreset.private, False)
         self.assertEqual(scoreset.parent.private, False)
@@ -192,7 +199,7 @@ def test_propagates_public(self):
     def test_publish_assigns_new_public_urns(self):
         var = VariantFactory(scoreset=self.scoreset)
         self.assertFalse(var.has_public_urn)
-        publish_scoreset.apply(kwargs=self.mock_kwargs())
+        publish_scoreset.run(**self.mock_kwargs())
         var.refresh_from_db()
         self.scoreset.refresh_from_db()
         self.assertTrue(var.has_public_urn)
diff --git a/dataset/tests/test_views_scoreset.py b/dataset/tests/test_views_scoreset.py
index ab22e078..eee5d2dc 100644
--- a/dataset/tests/test_views_scoreset.py
+++ b/dataset/tests/test_views_scoreset.py
@@ -285,7 +285,7 @@ def test_correct_tamplate_when_logged_in(self):
         )
         response = self.client.get(self.path)
         self.assertTemplateUsed(response, self.template)
-    
+
     def test_redirects_to_profile_after_success(self):
         data = self.post_data.copy()
         exp1 = ExperimentFactory()
diff --git a/main/management/commands/createtestentries.py b/main/management/commands/createtestentries.py
index fb82a8e2..dd5ed497 100644
--- a/main/management/commands/createtestentries.py
+++ b/main/management/commands/createtestentries.py
@@ -31,6 +31,8 @@ def handle(self, *args, **kwargs):
         for _ in [False, True]:
             # Configure the scoreset first.
             scoreset = factories.ScoreSetFactory(experiment=instance)
+            for _ in range(100):
+                VariantFactory(scoreset=scoreset)
             target = genome_factories.TargetGeneFactory(scoreset=scoreset)
             genomes = genome_models.ReferenceGenome.objects.all()
             genome_factories.ReferenceMapFactory(
@@ -65,9 +67,6 @@ def handle(self, *args, **kwargs):
 
             scoreset.set_modified_by(user, propagate=True)
             scoreset.set_created_by(user, propagate=True)
-
-            for _ in range(100):
-                VariantFactory(scoreset=scoreset)
             scoreset.save()
 
             experiment = scoreset.parent
diff --git a/main/models.py b/main/models.py
index 369c1a8a..2a1e0ec0 100644
--- a/main/models.py
+++ b/main/models.py
@@ -287,7 +287,7 @@ def get_default(cls):
     def get_cc0(cls):
         try:
             licence = cls.objects.get(short_name="CC0")
-        except ObjectDoesNotExist:
+        except cls.DoesNotExist:
             licence = cls.objects.create(
                 short_name="CC0",
                 long_name="CC0 (Public domain)",
@@ -306,7 +306,7 @@ def get_cc0(cls):
     def get_cc_by_nc_sa(cls):
         try:
             licence = cls.objects.get(short_name="CC BY-NC-SA 4.0")
-        except ObjectDoesNotExist:
+        except cls.DoesNotExist:
             licence = cls.objects.create(
                 short_name="CC BY-NC-SA 4.0",
                 long_name="CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike)",
@@ -325,7 +325,7 @@ def get_cc_by_nc_sa(cls):
     def get_cc_by(cls):
         try:
             licence = cls.objects.get(short_name="CC BY 4.0")
-        except ObjectDoesNotExist:
+        except cls.DoesNotExist:
             licence = cls.objects.create(
                 short_name="CC BY 4.0",
                 long_name="CC BY 4.0 (Attribution)",
diff --git a/requirements/base.txt b/requirements/base.txt
index 14d1799b..fdf07fe5 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -23,4 +23,4 @@ psycopg2-binary==2.7.4
 coverage==4.5.1
 idutils==1.1.0
 hg+https://bitbucket.org/metapub/metapub@default
-git+https://github.com/FowlerLab/hgvs-patterns.git
+git+https://github.com/VariantEffect/hgvs-patterns.git
diff --git a/variant/models.py b/variant/models.py
index ba40579e..c994839a 100644
--- a/variant/models.py
+++ b/variant/models.py
@@ -202,4 +202,4 @@ def count_data(self):
             elif column == constants.hgvs_pro_column:
                 yield self.hgvs_pro
             else:
-                yield self.data[constants.variant_count_data][column]
\ No newline at end of file
+                yield self.data[constants.variant_count_data][column]
diff --git a/variant/tests/test_models.py b/variant/tests/test_models.py
index 3ff85293..90a6f860 100644
--- a/variant/tests/test_models.py
+++ b/variant/tests/test_models.py
@@ -1,13 +1,14 @@
 from django.db import IntegrityError
 from django.core.exceptions import ValidationError
-from django.test import TestCase
+from django.test import TestCase, mock
 
+from dataset.models.scoreset import ScoreSet
 import dataset.constants as constants
 from dataset.factories import ScoreSetFactory
 from dataset.utilities import publish_dataset
 
 from ..factories import VariantFactory
-from ..models import assign_public_urn
+from ..models import assign_public_urn, Variant
 
 from urn.validators import MAVEDB_VARIANT_URN_RE
 
@@ -68,6 +69,67 @@ def test_hgvs_property_returns_pro_if_nt_column_not_defined(self):
         obj = VariantFactory(hgvs_nt=None)
         self.assertEqual(obj.hgvs, obj.hgvs_pro)
 
+    def test_bulk_create_urns_creates_sequential_urns(self):
+        parent = ScoreSetFactory()
+        urns = Variant.bulk_create_urns(10, parent)
+        for i, urn in enumerate(urns):
+            number = int(urn.split('#')[-1])
+            self.assertEqual(number, i + 1)
+
+    def test_bulk_create_urns_updates_parent_last_child_value(self):
+        parent = ScoreSetFactory()
+        Variant.bulk_create_urns(10, parent)
+        self.assertEqual(parent.last_child_value, 10)
+
+    @mock.patch.object(Variant, 'bulk_create_urns', return_value=['', ])
+    def test_bulk_create_calls_bulk_create_urns_with_correct_args(self, patch):
+        parent = ScoreSetFactory()
+        column = constants.required_score_column
+        variant_kwargs_list = [{
+            'hgvs_nt': 'c.1A>G', 'hgvs_pro': 'p.G4Y',
+            'data': dict({
+                constants.variant_score_data: {column: 1},
+                constants.variant_count_data: {}, }),
+        }, {
+            'hgvs_nt': 'c.2A>G', 'hgvs_pro': 'p.G5Y',
+            'data': dict({
+                constants.variant_score_data: {column: 2},
+                constants.variant_count_data: {}, }),
+        },
+        ]
+        _ = Variant.bulk_create(parent, variant_kwargs_list)
+        patch.assert_called_with(*(2, parent))
+
+    def test_bulk_create_creates_variants_with_kwargs(self):
+        parent = ScoreSetFactory()
+        column = constants.required_score_column
+        variant_kwargs_list = [{
+            'hgvs_nt': 'c.1A>G', 'hgvs_pro': 'p.G4Y',
+            'data': dict({
+                constants.variant_score_data: {column: 1},
+                constants.variant_count_data: {}, }),
+        }, {
+            'hgvs_nt': 'c.2A>G', 'hgvs_pro': 'p.G5Y',
+            'data': dict({
+                constants.variant_score_data: {column: 2},
+                constants.variant_count_data: {}, }),
+        },
+        ]
+        count = Variant.bulk_create(parent, variant_kwargs_list)
+        self.assertEqual(count, 2)
+
+        parent.refresh_from_db()
+        variants = parent.variants.order_by('urn')
+        self.assertEqual(variants[0].urn, '{}#{}'.format(parent.urn, 1))
+        self.assertEqual(variants[0].hgvs_nt, 'c.1A>G')
+        self.assertEqual(variants[0].hgvs_pro, 'p.G4Y')
+        self.assertDictEqual(variants[0].data, variant_kwargs_list[0]['data'])
+
+        self.assertEqual(variants[1].urn, '{}#{}'.format(parent.urn, 2))
+        self.assertEqual(variants[1].hgvs_nt, 'c.2A>G')
+        self.assertEqual(variants[1].hgvs_pro, 'p.G5Y')
+        self.assertDictEqual(variants[1].data, variant_kwargs_list[1]['data'])
+
 
 class TestAssignPublicUrn(TestCase):
     def setUp(self):
diff --git a/variant/tests/test_utilities.py b/variant/tests/test_utilities.py
index 479fb46b..6468bd14 100644
--- a/variant/tests/test_utilities.py
+++ b/variant/tests/test_utilities.py
@@ -9,6 +9,94 @@
 from .. import utilities
 
 
+class TestSplitVariant(TestCase):
+    def test_split_hgvs_singular_list_non_multi_variant(self):
+        self.assertListEqual(
+            ['c.100A>G'], utilities.split_variant('c.100A>G')[1])
+
+    def test_split_hgvs_returns_list_of_single_variants(self):
+        self.assertListEqual(
+            ['c.100A>G', 'c.101A>G'],
+            utilities.split_variant('c.[100A>G;101A>G]')[1]
+        )
+
+    def test_returns_prefix(self):
+        for p in 'cgmpn':
+            self.assertEqual(
+                p, utilities.split_variant(
+                    '{}.[100A>G;101A>G]'.format(p))[0]
+            )
+
+
+class TestJoinVariants(TestCase):
+    def test_passes_on_special(self):
+        self.assertEqual(
+            utilities.join_variants('_wt', None), '_wt')
+        self.assertEqual(
+            utilities.join_variants('_sy', None), '_sy')
+        self.assertEqual(
+            utilities.join_variants(['_wt'], None), '_wt')
+        self.assertEqual(
+            utilities.join_variants(['_sy'], None), '_sy')
+
+    def test_returns_single_variant(self):
+        self.assertEqual(
+            utilities.join_variants(['1A>G'], 'c'), 'c.1A>G')
+
+    def test_clips_prefix(self):
+        self.assertEqual(
+            utilities.join_variants(['c.1A>G'], 'c'), 'c.1A>G')
+
+    def test_returns_multi(self):
+        self.assertEqual(
+            utilities.join_variants(['c.1A>G', 'c.2A>G'], 'c'), 'c.[1A>G;2A>G]')
+
+    def test_returns_none_empty_list(self):
+        self.assertEqual(
+            utilities.join_variants([], None), None)
+
+
+class TestFormatVariant(TestCase):
+    def test_strips_white_space(self):
+        self.assertEqual(utilities.format_variant(' c.1A>G '), 'c.1A>G')
+
+    def test_passes_on_special(self):
+        self.assertEqual(utilities.format_variant('_wt'), '_wt')
+        self.assertEqual(utilities.format_variant('_sy'), '_sy')
+
+    def test_passes_on_none(self):
+        self.assertIsNone(utilities.format_variant(None))
+
+    def test_replaces_triple_q_with_X_in_protein_variant(self):
+        self.assertEqual(utilities.format_variant('p.G4???'), 'p.G4X')
+        self.assertEqual(
+            utilities.format_variant('p.[G4???;G3???]'), 'p.[G4X;G3X]')
+
+    def test_replaces_Xaa_with_X_in_protein_variant(self):
+        self.assertEqual(utilities.format_variant('p.G4Xaa'), 'p.G4X')
+        self.assertEqual(
+            utilities.format_variant('p.[G4XaaXaa;G3Xaa]'), 'p.[G4XX;G3X]')
+
+    def test_replaces_X_with_N_in_dna_variant(self):
+        for p in 'cgnm':
+            self.assertEqual(
+                utilities.format_variant(
+                    '{}.100A>X'.format(p)), '{}.100A>N'.format(p)
+            )
+        for p in 'cgnm':
+            self.assertEqual(
+                utilities.format_variant('{}.[1A>X;1_2delinsXXX]'.format(p)),
+                '{}.[1A>N;1_2delinsNNN]'.format(p)
+            )
+
+    def test_replaces_X_with_N_in_rna_variant(self):
+        self.assertEqual(utilities.format_variant('r.100a>x'), 'r.100a>n')
+        self.assertEqual(
+            utilities.format_variant('r.[1a>x;1_2delinsxxx]'),
+            'r.[1a>n;1_2delinsnnn]'
+        )
+
+
 class TestCreateVariantAttrsUtility(TestCase):
     @staticmethod
     def fixture_data(nt_score=('c.1A>G', 'c.2A>G'),
diff --git a/variant/tests/test_validators.py b/variant/tests/test_validators.py
index 14e31eb2..58d0d63a 100644
--- a/variant/tests/test_validators.py
+++ b/variant/tests/test_validators.py
@@ -277,6 +277,47 @@ def test_allows_wt_and_sy(self):
         non_hgvs_cols, _, df = validate_variant_rows(BytesIO(data.encode()))
         self.assertEqual(df[constants.hgvs_nt_column].values[0], wt)
         self.assertEqual(df[constants.hgvs_pro_column].values[0], sy)
+
+    def test_converts_triple_q_to_single_q_in_protein_sub(self):
+        data = "{},{}\n{},1.0".format(
+            constants.hgvs_pro_column,
+            required_score_column, 'p.Gly4???'
+        )
+        non_hgvs_cols, _, df = validate_variant_rows(BytesIO(data.encode()))
+        self.assertEqual(df[constants.hgvs_nt_column].values[0], None)
+        self.assertEqual(df[constants.hgvs_pro_column].values[0], 'p.Gly4X')
+
+    def test_converts_triple_x_to_single_n_rna_dna(self):
+        data = "{},{}\n{},1.0\n{},2.0".format(
+            constants.hgvs_nt_column,
+            required_score_column, 'c.1A>X', 'r.1a>x'
+        )
+        non_hgvs_cols, _, df = validate_variant_rows(BytesIO(data.encode()))
+        self.assertEqual(df[constants.hgvs_nt_column].values[0], 'c.1A>N')
+        self.assertEqual(df[constants.hgvs_nt_column].values[1], 'r.1a>n')
+
+    def test_converts_triple_q_to_single_q_in_protein_multi_sub(self):
+        data = "{},{}\n{},1.0".format(
+            constants.hgvs_pro_column,
+            required_score_column, 'p.[Gly4???;Asp2???]'
+        )
+        non_hgvs_cols, _, df = validate_variant_rows(BytesIO(data.encode()))
+        self.assertEqual(df[constants.hgvs_nt_column].values[0], None)
+        self.assertEqual(
+            df[constants.hgvs_pro_column].values[0], 'p.[Gly4X;Asp2X]')
+
+    def test_converts_triple_x_to_single_n_in_multi_rna_dna(self):
+        data = "{},{}\n{},1.0\n{},2.0".format(
+            constants.hgvs_nt_column,
+            required_score_column,
+            'n.[1A>X;1_2delinsXXX]',
+            'r.[1a>x;1_2insxxx]'
+        )
+        non_hgvs_cols, _, df = validate_variant_rows(BytesIO(data.encode()))
+        self.assertEqual(
+            df[constants.hgvs_nt_column].values[0], 'n.[1A>N;1_2delinsNNN]')
+        self.assertEqual(
+            df[constants.hgvs_nt_column].values[1], 'r.[1a>n;1_2insnnn]')
 
     def test_parses_numeric_column_values_into_float(self):
         hgvs = generate_hgvs()
diff --git a/variant/utilities.py b/variant/utilities.py
index 09d64435..1edf883e 100644
--- a/variant/utilities.py
+++ b/variant/utilities.py
@@ -1,9 +1,101 @@
+import re
+
 import pandas as pd
 import numpy as np
 from pandas.testing import assert_index_equal
 
+from hgvsp import protein, dna, rna
+
 from core.utilities import is_null
 
+from .constants import wildtype, synonymous
+
+
+def split_variant(variant):
+    """
+    Splits a multi-variant `HGVS` string into a list of single variants. If
+    a single variant string is provided, it is returned as a singular `list`.
+
+    Parameters
+    ----------
+    variant : str
+        A valid single or multi-variant `HGVS` string.
+
+    Returns
+    -------
+    tuple[str, list[str]]
+        The variant prefix and a list of single `HGVS` strings.
+    """
+    prefix = variant[0]
+    if len(variant.split(';')) > 1:
+        return prefix, ['{}.{}'.format(prefix, e.strip())
+                        for e in variant[3:-1].split(';')]
+    return prefix, [variant]
+
+
+def join_variants(variants, prefix):
+    """
+    Joins a list of single variant events into a multi-variant HGVS string.
+
+    Parameters
+    ----------
+    variants : union[str, list[str]]
+        A list of valid single-variant `HGVS` strings, or a single string.
+    prefix : str
+        HGVS prefix.
+
+    Returns
+    -------
+    str
+    """
+    if isinstance(variants, str):
+        return variants
+
+    if len(variants) == 1 and variants[0] in (wildtype, synonymous):
+        return variants[0]
+
+    if len(variants) == 1:
+        return '{}.{}'.format(
+            prefix,
+            variants[0].replace('{}.'.format(prefix), '')
+        )
+    elif len(variants) > 1:
+        return '{}.[{}]'.format(
+            prefix,
+            ';'.join([v.replace('{}.'.format(prefix), '') for v in variants])
+        )
+    else:
+        return None
+
+
+def format_variant(variant):
+    """
+    Replaces runs of `?` and the `Xaa` code with `X` in protein variants, and
+    `X`/`x` with `N`/`n` in nucleotide/RNA variants, to be compliant with the
+    `hgvs` biocommons package.
+
+    Parameters
+    ----------
+    variant : str, optional.
+        HGVS_ formatted string.
+
+    Returns
+    -------
+    str
+    """
+    if is_null(variant):
+        return None
+
+    variant = variant.strip()
+    if 'p.' in variant:
+        variant, _ = re.subn(r'\?+', 'X', variant)
+        variant, _ = re.subn(r'Xaa', 'X', variant)
+    elif 'g.' in variant or 'n.' in variant or \
+            'c.' in variant or 'm.' in variant:
+        variant, _ = re.subn(r'X', 'N', variant)
+    elif 'r.' in variant:
+        variant, _ = re.subn(r'x', 'n', variant)
+    return variant
+
 
 def convert_df_to_variant_records(scores, counts=None, index=None):
     """
diff --git a/variant/validators/__init__.py b/variant/validators/__init__.py
index 6140877b..0defb8f5 100644
--- a/variant/validators/__init__.py
+++ b/variant/validators/__init__.py
@@ -21,6 +21,8 @@
 from .hgvs import validate_multi_variant, \
     validate_single_variant, validate_nt_variant, validate_pro_variant
 
+from .. import utilities
+
 
 def validate_hgvs_nt_uniqueness(df):
     """Validate that hgvs columns only define a variant once."""
@@ -192,6 +194,13 @@ def validate_variant_rows(file):
         primary_hgvs_column = hgvs_nt_column
 
     # Check that the primary column is fully specified.
+    if defines_nt_hgvs:
+        df[hgvs_nt_column] = df.loc[:, hgvs_nt_column].\
+            apply(utilities.format_variant)
+    if defines_p_hgvs:
+        df[hgvs_pro_column] = df.loc[:, hgvs_pro_column].\
+            apply(utilities.format_variant)
+
     null_primary = df.loc[:, primary_hgvs_column].apply(is_null)
     if any(null_primary):
        raise ValidationError(
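
Reviewer note (not part of the patch): a minimal, illustrative sketch of how the new variant/utilities helpers added above are expected to behave, mirroring the added tests in variant/tests/test_utilities.py. The import path assumes the Django project is on the path; nothing here is required to apply the diff.

    # Illustrative only -- behaviour is pinned down by the new unit tests.
    from variant import utilities

    # format_variant normalises placeholder residues/bases for the hgvs package.
    assert utilities.format_variant('p.Gly4???') == 'p.Gly4X'
    assert utilities.format_variant('c.1A>X') == 'c.1A>N'
    assert utilities.format_variant('r.1a>x') == 'r.1a>n'

    # split_variant returns the prefix plus the individual events of a
    # multi-variant string; join_variants reverses the operation.
    prefix, events = utilities.split_variant('c.[100A>G;101A>G]')
    assert (prefix, events) == ('c', ['c.100A>G', 'c.101A>G'])
    assert utilities.join_variants(events, prefix) == 'c.[100A>G;101A>G]'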