diff --git a/aa_stripe/management/commands/sync_all_cards_for_all_customers.py b/aa_stripe/management/commands/sync_all_cards_for_all_customers.py new file mode 100644 index 0000000..4b507ee --- /dev/null +++ b/aa_stripe/management/commands/sync_all_cards_for_all_customers.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +from time import sleep + +import stripe +from django.conf import settings +from django.core.management.base import BaseCommand + +from aa_stripe.models import StripeCard, StripeCustomer +from aa_stripe.settings import stripe_settings + + +class Command(BaseCommand): + help = "Sync all cards for all customers with Stripe API" + + def add_arguments(self, parser): + min_customer_id = StripeCustomer.objects.order_by("id").first().id + max_customer_id = StripeCustomer.objects.order_by("-id").first().id + parser.add_argument( + '--max_customer_id', + nargs='?', + const=max_customer_id, + default=max_customer_id, + type=int, + dest='max_id', + help='Id of the last customer for whom to update cards', + ) + parser.add_argument( + '--min_customer_id', + nargs='?', + const=min_customer_id, + default=min_customer_id, + type=int, + dest='min_id', + help='Id of the first customer for whom to update cards', + ) + + def handle(self, *args, **options): + stripe.api_key = stripe_settings.API_KEY + + counts = {"created": 0, "updated": 0, "deleted": 0} + processed_users_set = set() + + for customer in StripeCustomer.objects.filter( + is_active=True, id__range=(options['min_id'], options['max_id'])).order_by("-id"): + if customer.user.id in processed_users_set: + continue + + customer_from_stripe = stripe.Customer.retrieve(customer.stripe_customer_id) + actual_cards = customer_from_stripe.sources.all(object="card") + actual_cards_map = {c.id: c for c in actual_cards} + actual_cards_set = set(actual_cards_map) + our_cards = StripeCard.objects.filter(customer=customer) + our_cards_set = set(our_cards.values_list('stripe_card_id', flat=True)) + our_deleted_cards = StripeCard.objects.deleted().filter(customer=customer) + our_deleted_cards_set = set(our_deleted_cards.values_list('stripe_card_id', flat=True)) + + stripe_deleted_cards = our_cards_set - actual_cards_set + for card_id in stripe_deleted_cards: + card = our_cards.get(stripe_card_id=card_id) + card.is_deleted = True + card.save() + + undelete_cards = actual_cards_set & our_deleted_cards_set + for card_id in undelete_cards: + card = our_deleted_cards.get(stripe_card_id=card_id) + card.is_deleted = False + card.save() + + update_cards = our_cards_set & actual_cards_set + for card_id in update_cards: + card = our_cards.get(stripe_card_id=card_id) + card.update_from_stripe_card(actual_cards_map[card_id]) + + created_cards = actual_cards_set - (our_cards_set | our_deleted_cards_set) + for card_id in created_cards: + card = StripeCard(customer=customer) + card.update_from_stripe_card(actual_cards_map[card_id]) + + stripe_defaut_source = customer_from_stripe.default_source + our_defaut_source = customer.default_card.stripe_card_id + if stripe_defaut_source != our_defaut_source: + customer.default_card = StripeCard.objects.get(stripe_card_id=stripe_defaut_source) + customer.save() + + counts["created"] += len(created_cards) + counts["deleted"] += len(stripe_deleted_cards) + counts["updated"] += len(update_cards) + len(undelete_cards) + processed_users_set.add(customer.user.id) + + print("Processed customer with id: {}".format(customer.id)) + if not settings.TESTING: + sleep(0.25) + + if options.get("verbosity") > 1: + print("Cards created: {created}, updated: {updated}, deleted: {deleted}".format(**counts)) diff --git a/aa_stripe/models.py b/aa_stripe/models.py index f9eb831..75f66a3 100644 --- a/aa_stripe/models.py +++ b/aa_stripe/models.py @@ -108,6 +108,13 @@ def _retrieve_from_stripe(self, set_deleted=False): if set_deleted: self.is_deleted = True + def _set_fields_from_stripe_object(self, stripe_object): + self.stripe_card_id = stripe_object["id"] + self.last4 = stripe_object["last4"] + self.exp_month = stripe_object["exp_month"] + self.exp_year = stripe_object["exp_year"] + self.stripe_response = stripe_object + def create_at_stripe(self): if self.is_created_at_stripe: raise StripeMethodNotAllowed() @@ -116,11 +123,10 @@ def create_at_stripe(self): customer = stripe.Customer.retrieve(self.customer.stripe_customer_id) card = customer.sources.create(source=self.stripe_js_response["source"]) - self.stripe_card_id = card["id"] - self.last4 = card["last4"] - self.exp_month = card["exp_month"] - self.exp_year = card["exp_year"] - self.stripe_response = card + return self.update_from_stripe_card(card) + + def update_from_stripe_card(self, card): + self._set_fields_from_stripe_object(card) self.is_created_at_stripe = True self.save() return self diff --git a/tests/test_cards.py b/tests/test_cards.py index b7dfa74..89e523a 100644 --- a/tests/test_cards.py +++ b/tests/test_cards.py @@ -1,17 +1,19 @@ from datetime import datetime from functools import partial +from itertools import product from random import randint from uuid import uuid4 import requests_mock import simplejson as json +from django.core.management import call_command from rest_framework.reverse import reverse from tests.test_utils import BaseTestCase -from aa_stripe.models import StripeCard +from aa_stripe.models import StripeCard, StripeCustomer -class TestCards(BaseTestCase): +class BaseCardsTestCase(BaseTestCase): _last4 = partial(randint, 1000, 9999) _exp_month = partial(randint, 1, 12) _todays_year = datetime.utcnow().year @@ -20,6 +22,12 @@ class TestCards(BaseTestCase): def _stripe_card_id(self): return "card_{}".format(uuid4().hex[:24]) + def _stripe_card_fingerprint(self): + return uuid4().hex[:16] + + +class TestCards(BaseCardsTestCase): + def _get_successful_retrive_stripe_customer_response(self, id, default_source=None): return { "id": id, @@ -368,12 +376,10 @@ def test_update_card(self): [{ "text": json.dumps(self._get_successful_retrive_stripe_customer_response(customer_id)) }]) - m.register_uri( - "POST", "https://api.stripe.com/v1/customers/{}".format(customer_id), - [{ - "text": - json.dumps(self._get_successful_retrive_stripe_customer_response(customer_id, updated_card_id)) - }]) + m.register_uri("POST", "https://api.stripe.com/v1/customers/{}".format(customer_id), [{ + "text": + json.dumps(self._get_successful_retrive_stripe_customer_response(customer_id, updated_card_id)) + }]) m.register_uri("GET", "https://api.stripe.com/v1/customers/{}/sources/{}".format( customer_id, updated_card_id), [{ "text": @@ -396,3 +402,236 @@ def test_update_card(self): data = {"stripe_token": "tok_amex", "set_default": set_default} response = self.client.patch(url, data, format="json") self.assertEqual(response.status_code, 403) + + +class TestCardsCommands(BaseCardsTestCase): + + def _get_empty_sources_retrive_customer_stripe_response(self, customer_id): + return json.loads('''{{ + "id": "{0}", + "object": "customer", + "account_balance": 0, + "created": 1515510882, + "currency": "usd", + "default_source": null, + "delinquent": false, + "description": null, + "discount": null, + "email": null, + "livemode": false, + "metadata": {{}}, + "shipping": null, + "sources": {{ + "object": "list", + "data": [], + "has_more": false, + "total_count": 0, + "url": "/v1/customers/{0}/sources" + }}, + "subscriptions": {{ + "object": "list", + "data": [], + "has_more": false, + "total_count": 0, + "url": "/v1/customers/{0}/subscriptions" + }} + }}'''.format(customer_id)) + + def _get_retrive_card_stripe_response(self, card_id, customer_id=None): + if customer_id: + return json.loads('''{{ + "id": "{0}", + "object": "card", + "address_city": null, + "address_country": null, + "address_line1": null, + "address_line1_check": null, + "address_line2": null, + "address_state": null, + "address_zip": null, + "address_zip_check": null, + "brand": "Visa", + "country": "US", + "customer": "{1}", + "cvc_check": null, + "dynamic_last4": null, + "exp_month": {2}, + "exp_year": {3}, + "fingerprint": "{4}", + "funding": "credit", + "last4": "{5}", + "metadata": {{}}, + "name": null, + "tokenization_method": null + }}'''.format(card_id, customer_id, self._exp_month(), self._exp_year(), self._stripe_card_fingerprint(), + self._last4())) + + return json.loads('''{{ + "error": {{ + "type": "invalid_request_error", + "message": "No such source: {}", + "param": "id" + }} + }}'''.format(card_id)) + + def setUp(self): + m = requests_mock.Mocker() + cards_created = (0, 2) + cards_not_changed = (1,) + cards_updated = (0, 3) + cards_deleted = (0, 1) + cards_we_deleted = (0, 1) + self.cards_cases = [ + case for case in product(cards_created, cards_not_changed, cards_updated, cards_deleted, cards_we_deleted) + ] + self.customer_id_case_map = {} + self.test_cases = {} + self.test_update_cases = {c: {} for c in self.cards_cases if c[2]} + self.all_cards_count_in_db = sum([c[1] + c[2] + c[3] + c[4] for c in self.cards_cases]) + self.all_not_deleted_cards_count_in_db = sum([c[1] + c[2] + c[3] for c in self.cards_cases]) + for i, case in enumerate(self.cards_cases): + self._create_user(i) + self._create_customer() + self.customer_id_case_map[self.customer.id] = case + customer_id = self.customer.stripe_customer_id + created_card_ids = [self._stripe_card_id() for r in range(case[0])] + not_changed_card_ids = [self._stripe_card_id() for r in range(case[1])] + updated_card_ids = [self._stripe_card_id() for r in range(case[2])] + deleted_card_ids = [self._stripe_card_id() for r in range(case[3])] + we_deleted_card_ids = [self._stripe_card_id() for r in range(case[4])] + + cards_at_stripe = created_card_ids + updated_card_ids + not_changed_card_ids + we_deleted_card_ids + cards_in_database = updated_card_ids + not_changed_card_ids + deleted_card_ids + we_deleted_card_ids + default_card_id = (updated_card_ids + not_changed_card_ids)[0] + swap_default_card = case[0] and case[2] or case[2] and case[3] + new_default_card = (created_card_ids + updated_card_ids)[1] if swap_default_card else default_card_id + + self.test_cases[case] = (customer_id, created_card_ids, not_changed_card_ids, updated_card_ids, + deleted_card_ids, we_deleted_card_ids, default_card_id, new_default_card) + + customer_response = self._get_empty_sources_retrive_customer_stripe_response(customer_id) + customer_response["default_source"] = new_default_card + customer_response["sources"]["total_count"] = len(cards_at_stripe) + for card_id in cards_at_stripe: + card_response = self._get_retrive_card_stripe_response(card_id, customer_id) + customer_response["sources"]["data"].append(card_response) + + if card_id in cards_in_database: + if card_id in updated_card_ids: + old_exp_month = self._exp_month() + old_exp_year = self._exp_year() + self.test_update_cases[case][card_id] = (card_response["exp_month"], card_response["exp_year"]) + self._create_card( + stripe_card_id=card_id, + is_default=card_id == default_card_id, + last4=card_response["last4"], + exp_month=old_exp_month, + exp_year=old_exp_year) + else: + self._create_card( + stripe_card_id=card_id, + is_default=card_id == default_card_id, + last4=card_response["last4"], + exp_month=card_response["exp_month"], + exp_year=card_response["exp_year"], + is_deleted=card_id in we_deleted_card_ids) + + m.register_uri( + "GET", + "https://api.stripe.com/v1/customers/{}/sources?object=card".format(customer_id), + status_code=200, + text=json.dumps(customer_response["sources"])) + m.register_uri( + "GET", + "https://api.stripe.com/v1/customers/{}".format(customer_id), + status_code=200, + text=json.dumps(customer_response)) + + for card_id in deleted_card_ids: + self._create_card(stripe_card_id=card_id) + m.register_uri( + "GET", + "https://api.stripe.com/v1/customers/{}/sources/{}".format(customer_id, card_id), + status_code=404, + text=json.dumps(self._get_retrive_card_stripe_response(card_id))) + + m.start() + self.addCleanup(m.stop) + + def _get_cards_counts_after_command_call(self, cases): + cards_count_in_database_after_command_call = self.all_cards_count_in_db + sum([c[0] for c in cases]) + cards_not_deleted_count_in_database_after_command_call = self.all_not_deleted_cards_count_in_db + sum( + [c[0] + c[4] for c in cases]) - sum([c[3] for c in cases]) + return (cards_count_in_database_after_command_call, cards_not_deleted_count_in_database_after_command_call) + + def test_sync_all_cards_for_all_customers_command(self): + cards_counts_after_command_call = self._get_cards_counts_after_command_call(self.cards_cases) + + self.assertEqual(StripeCard.objects.all_with_deleted().count(), self.all_cards_count_in_db) + self.assertEqual(StripeCard.objects.count(), self.all_not_deleted_cards_count_in_db) + call_command("sync_all_cards_for_all_customers") + self.assertEqual(StripeCard.objects.all_with_deleted().count(), cards_counts_after_command_call[0]) + self.assertEqual(StripeCard.objects.count(), cards_counts_after_command_call[1]) + + for case in self.cards_cases: + test_case = self.test_cases[case] + customer = StripeCustomer.objects.get(stripe_customer_id=test_case[0]) + # created on stripe + if len(test_case[1]): + self.assertTrue(StripeCard.objects.filter(customer=customer, stripe_card_id__in=test_case[1]).exists()) + # not changed + self.assertTrue(StripeCard.objects.filter(customer=customer, stripe_card_id__in=test_case[2]).exists()) + # updated + if len(test_case[3]): + updated_cards = StripeCard.objects.filter(customer=customer, stripe_card_id__in=test_case[3]) + self.assertTrue(updated_cards.exists()) + for update_case_card_id in test_case[3]: + update_case = self.test_update_cases[case][update_case_card_id] + updated_card = updated_cards.get(stripe_card_id=update_case_card_id) + self.assertEqual(updated_card.exp_month, update_case[0]) + self.assertEqual(updated_card.exp_year, update_case[1]) + # deleted on stripe + if len(test_case[4]): + self.assertTrue(StripeCard.objects.deleted().filter(customer=customer, + stripe_card_id__in=test_case[4]).exists()) + # we deleted, should be restored + if len(test_case[5]): + self.assertTrue(StripeCard.objects.filter(customer=customer, stripe_card_id__in=test_case[5]).exists()) + # changed default card on stripe + if test_case[6] != test_case[7]: + self.assertEqual(customer.default_card.stripe_card_id, test_case[7]) + else: + self.assertEqual(customer.default_card.stripe_card_id, test_case[6]) + + def test_sync_cards_for_customers_command_with_max_argument(self): + max_customer_id = max(self.customer_id_case_map) // 2 + cases_to_run = [self.customer_id_case_map[k] for k in self.customer_id_case_map if k <= max_customer_id] + cards_counts_after_command_call = self._get_cards_counts_after_command_call(cases_to_run) + + call_command("sync_all_cards_for_all_customers", max_customer_id=max_customer_id) + self.assertEqual(StripeCard.objects.all_with_deleted().count(), cards_counts_after_command_call[0]) + self.assertEqual(StripeCard.objects.count(), cards_counts_after_command_call[1]) + + def test_sync_cards_for_customers_command_with_min_argument(self): + min_customer_id = max(self.customer_id_case_map) // 3 + cases_to_run = [self.customer_id_case_map[k] for k in self.customer_id_case_map if k >= min_customer_id] + cards_counts_after_command_call = self._get_cards_counts_after_command_call(cases_to_run) + + call_command("sync_all_cards_for_all_customers", min_customer_id=min_customer_id) + self.assertEqual(StripeCard.objects.all_with_deleted().count(), cards_counts_after_command_call[0]) + self.assertEqual(StripeCard.objects.count(), cards_counts_after_command_call[1]) + + def test_sync_cards_for_customers_command_with_min_and_max_argument(self): + min_customer_id = max(self.customer_id_case_map) // 4 + max_customer_id = min_customer_id * 2 + cases_to_run = [ + self.customer_id_case_map[k] + for k in self.customer_id_case_map + if k >= min_customer_id and k <= max_customer_id + ] + cards_counts_after_command_call = self._get_cards_counts_after_command_call(cases_to_run) + + call_command( + "sync_all_cards_for_all_customers", max_customer_id=max_customer_id, min_customer_id=min_customer_id) + self.assertEqual(StripeCard.objects.all_with_deleted().count(), cards_counts_after_command_call[0]) + self.assertEqual(StripeCard.objects.count(), cards_counts_after_command_call[1]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 65e7a36..13b03c5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -70,14 +70,16 @@ def _create_card(self, set_self=True, last4="4242", exp_month=1, - exp_year=2025): + exp_year=2025, + is_deleted=False): card = StripeCard.objects.create( customer=customer or self.customer, last4=last4, exp_month=exp_month, exp_year=exp_year, stripe_card_id=stripe_card_id or "card_{}".format(uuid4().hex), - is_created_at_stripe=True) + is_created_at_stripe=True, + is_deleted=is_deleted) if is_default: card.customer.default_card = card card.customer.save()