Skip to content

Commit

Permalink
Add email to User
Browse files Browse the repository at this point in the history
  • Loading branch information
csirmazbendeguz committed Apr 19, 2024
1 parent 2aa74f4 commit 6a26b19
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 44 deletions.
3 changes: 3 additions & 0 deletions django/db/models/fields/composite.py
Expand Up @@ -151,6 +151,9 @@ def __iter__(self):
def cached_col(self):
return Cols(self.model._meta.db_table, self.fields, self)

def get_col(self, alias, output_field=None):
return self.cached_col

def get_lookup(self, lookup_name):
if lookup_name == "exact":
return TupleExact
Expand Down
1 change: 1 addition & 0 deletions tests/composite_pk/models/tenant.py
Expand Up @@ -13,6 +13,7 @@ class Tenant(models.Model):
class User(models.Model):
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
id = ID(unique=True)
email = models.EmailField()

class Meta:
primary_key = ("tenant_id", "id")
Expand Down
96 changes: 57 additions & 39 deletions tests/composite_pk/test_create.py
Expand Up @@ -22,90 +22,105 @@ def setUpTestData(cls):
@unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test")
def test_create_user_in_sqlite(self):
test_cases = [
({"tenant": self.tenant, "id": 2412}, 2412),
({"tenant_id": self.tenant.id, "id": 5316}, 5316),
({"pk": (self.tenant.id, 7424)}, 7424),
{"tenant": self.tenant, "id": 2412, "email": "user2412@example.com"},
{"tenant_id": self.tenant.id, "id": 5316, "email": "user5316@example.com"},
{"pk": (self.tenant.id, 7424), "email": "user7424@example.com"},
]

for fields, user_id in test_cases:
with self.subTest(fields=fields, user_id=user_id):
for fields in test_cases:
user = User(**fields)
self.assertIsNotNone(user.id)
self.assertIsNotNone(user.email)

with self.subTest(fields=fields):
with CaptureQueriesContext(connection) as context:
obj = User.objects.create(**fields)

self.assertEqual(obj.tenant_id, self.tenant.id)
self.assertEqual(obj.id, user_id)
self.assertEqual(obj.pk, (self.tenant.id, user_id))
self.assertEqual(obj.id, user.id)
self.assertEqual(obj.pk, (self.tenant.id, user.id))
self.assertEqual(obj.email, user.email)
self.assertEqual(len(context.captured_queries), 1)
u = User._meta.db_table
self.assertEqual(
context.captured_queries[0]["sql"],
f'INSERT INTO "{u}" ("tenant_id", "id") '
f"VALUES ({self.tenant.id}, {user_id})",
f'INSERT INTO "{u}" ("tenant_id", "id", "email") '
f"VALUES ({self.tenant.id}, {user.id}, '{user.email}')",
)

@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test")
def test_create_user_in_postgresql(self):
test_cases = [
({"tenant": self.tenant, "id": 5231}, 5231),
({"tenant_id": self.tenant.id, "id": 6123}, 6123),
({"pk": (self.tenant.id, 3513)}, 3513),
{"tenant": self.tenant, "id": 5231, "email": "user5231@example.com"},
{"tenant_id": self.tenant.id, "id": 6123, "email": "user6123@example.com"},
{"pk": (self.tenant.id, 3513), "email": "user3513@example.com"},
]

for fields, user_id in test_cases:
with self.subTest(fields=fields, user_id=user_id):
for fields in test_cases:
user = User(**fields)
self.assertIsNotNone(user.id)
self.assertIsNotNone(user.email)

with self.subTest(fields=fields):
with CaptureQueriesContext(connection) as context:
obj = User.objects.create(**fields)

self.assertEqual(obj.tenant_id, self.tenant.id)
self.assertEqual(obj.id, user_id)
self.assertEqual(obj.pk, (self.tenant.id, user_id))
self.assertEqual(obj.id, user.id)
self.assertEqual(obj.pk, (self.tenant.id, user.id))
self.assertEqual(obj.email, user.email)
self.assertEqual(len(context.captured_queries), 1)
u = User._meta.db_table
self.assertEqual(
context.captured_queries[0]["sql"],
f'INSERT INTO "{u}" ("tenant_id", "id") '
f"VALUES ({self.tenant.id}, {user_id}) "
f'INSERT INTO "{u}" ("tenant_id", "id", "email") '
f"VALUES ({self.tenant.id}, {user.id}, '{user.email}') "
f'RETURNING "{u}"."id"',
)

@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test")
def test_create_user_with_autofield_in_postgresql(self):
test_cases = [
{"tenant": self.tenant},
{"tenant_id": self.tenant.id},
{"tenant": self.tenant, "email": "user1111@example.com"},
{"tenant_id": self.tenant.id, "email": "user2222@example.com"},
]

for fields in test_cases:
user = User(**fields)
self.assertIsNotNone(user.email)

with CaptureQueriesContext(connection) as context:
obj = User.objects.create(**fields)

self.assertEqual(obj.tenant_id, self.tenant.id)
self.assertIsInstance(obj.id, int)
self.assertGreater(obj.id, 0)
self.assertEqual(obj.pk, (self.tenant.id, obj.id))
self.assertEqual(obj.email, user.email)
self.assertEqual(len(context.captured_queries), 1)
u = User._meta.db_table
self.assertEqual(
context.captured_queries[0]["sql"],
f'INSERT INTO "{u}" ("tenant_id") '
f"VALUES ({self.tenant.id}) "
f'INSERT INTO "{u}" ("tenant_id", "email") '
f"VALUES ({self.tenant.id}, '{user.email}') "
f'RETURNING "{u}"."id"',
)

def test_save_user(self):
user = User(tenant=self.tenant, id=9241)
user = User(tenant=self.tenant, id=9241, email="user9241@example.com")
user.save()
self.assertEqual(user.tenant_id, self.tenant.id)
self.assertEqual(user.tenant, self.tenant)
self.assertEqual(user.id, 9241)
self.assertEqual(user.pk, (self.tenant.id, 9241))
self.assertEqual(user.email, "user9241@example.com")

@unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test")
def test_bulk_create_users_in_sqlite(self):
objs = [
User(tenant=self.tenant, id=8291),
User(tenant_id=self.tenant.id, id=4021),
User(pk=(self.tenant.id, 8214)),
User(tenant=self.tenant, id=8291, email="user8291@example.com"),
User(tenant_id=self.tenant.id, id=4021, email="user4021@example.com"),
User(pk=(self.tenant.id, 8214), email="user8214@example.com"),
]

with CaptureQueriesContext(connection) as context:
Expand All @@ -125,19 +140,20 @@ def test_bulk_create_users_in_sqlite(self):
u = User._meta.db_table
self.assertEqual(
context.captured_queries[0]["sql"],
f'INSERT INTO "{u}" ("tenant_id", "id") '
f"VALUES ({self.tenant.id}, 8291), ({self.tenant.id}, 4021), "
f"({self.tenant.id}, 8214)",
f'INSERT INTO "{u}" ("tenant_id", "id", "email") '
f"VALUES ({self.tenant.id}, 8291, 'user8291@example.com'), "
f"({self.tenant.id}, 4021, 'user4021@example.com'), "
f"({self.tenant.id}, 8214, 'user8214@example.com')",
)

@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test")
def test_bulk_create_users_in_postgresql(self):
objs = [
User(tenant=self.tenant, id=8361),
User(tenant_id=self.tenant.id, id=2819),
User(pk=(self.tenant.id, 9136)),
User(tenant=self.tenant),
User(tenant_id=self.tenant.id),
User(tenant=self.tenant, id=8361, email="user8361@example.com"),
User(tenant_id=self.tenant.id, id=2819, email="user2819@example.com"),
User(pk=(self.tenant.id, 9136), email="user9136@example.com"),
User(tenant=self.tenant, email="user1111@example.com"),
User(tenant_id=self.tenant.id, email="user2222@example.com"),
]

with CaptureQueriesContext(connection) as context:
Expand Down Expand Up @@ -165,15 +181,17 @@ def test_bulk_create_users_in_postgresql(self):
u = User._meta.db_table
self.assertEqual(
context.captured_queries[0]["sql"],
f'INSERT INTO "{u}" ("tenant_id", "id") '
f"VALUES ({self.tenant.id}, 8361), ({self.tenant.id}, 2819), "
f"({self.tenant.id}, 9136) "
f'INSERT INTO "{u}" ("tenant_id", "id", "email") '
f"VALUES ({self.tenant.id}, 8361, 'user8361@example.com'), "
f"({self.tenant.id}, 2819, 'user2819@example.com'), "
f"({self.tenant.id}, 9136, 'user9136@example.com') "
f'RETURNING "{u}"."id"',
)
self.assertEqual(
context.captured_queries[1]["sql"],
f'INSERT INTO "{u}" ("tenant_id") '
f"VALUES ({self.tenant.id}), ({self.tenant.id}) "
f'INSERT INTO "{u}" ("tenant_id", "email") '
f"VALUES ({self.tenant.id}, 'user1111@example.com'), "
f"({self.tenant.id}, 'user2222@example.com') "
f'RETURNING "{u}"."id"',
)

Expand Down
4 changes: 2 additions & 2 deletions tests/composite_pk/test_delete.py
Expand Up @@ -76,7 +76,7 @@ def test_delete_tenant_by_pk(self):

def test_delete_user_by_id(self):
with CaptureQueriesContext(connection) as context:
result = User.objects.filter(id=self.user.id).delete()
result = User.objects.only("pk").filter(id=self.user.id).delete()

self.assertEqual(
result, (2, {"composite_pk.User": 1, "composite_pk.Comment": 1})
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_delete_user_by_id(self):

def test_delete_user_by_pk(self):
with CaptureQueriesContext(connection) as context:
result = User.objects.filter(pk=self.user.pk).delete()
result = User.objects.only("pk").filter(pk=self.user.pk).delete()

self.assertEqual(
result, (2, {"composite_pk.User": 1, "composite_pk.Comment": 1})
Expand Down
4 changes: 2 additions & 2 deletions tests/composite_pk/test_get.py
Expand Up @@ -51,7 +51,7 @@ def test_get_user_by_pk(self):
for lookup in test_cases:
with self.subTest(lookup=lookup):
with CaptureQueriesContext(connection) as context:
obj = User.objects.get(**lookup)
obj = User.objects.only("pk").get(**lookup)

self.assertEqual(obj, self.user)
self.assertEqual(len(context.captured_queries), 1)
Expand All @@ -78,7 +78,7 @@ def test_get_user_by_field(self):
for lookup, column, value in test_cases:
with self.subTest(lookup=lookup, column=column, value=value):
with CaptureQueriesContext(connection) as context:
obj = User.objects.get(**lookup)
obj = User.objects.only("pk").get(**lookup)

self.assertEqual(obj, self.user)
self.assertEqual(len(context.captured_queries), 1)
Expand Down
16 changes: 16 additions & 0 deletions tests/composite_pk/test_update.py
Expand Up @@ -105,3 +105,19 @@ def test_update_or_create_user_by_pk(self):
self.assertEqual(user.pk, self.user.pk)
self.assertEqual(user.tenant_id, self.tenant.id)
self.assertEqual(user.id, self.user.id)

# def test_update_comment(self):
# with CaptureQueriesContext(connection) as context:
# result = Comment.objects.filter(user__tenant__id=self.tenant.id).update(
# id=8341
# )
#
# if connection.vendor in ("sqlite", "postgresql"):
# u = User._meta.db_table
# self.assertEqual(
# context.captured_queries[0]["sql"],
# f'UPDATE "{u}" '
# 'SET "id" = 8341 '
# f'WHERE ("{u}"."tenant_id" = {self.tenant.id} '
# f'AND "{u}"."id" = {self.user.id})',
# )
2 changes: 1 addition & 1 deletion tests/composite_pk/tests.py
Expand Up @@ -53,7 +53,7 @@ def test_pk_updated_if_field_updated(self):

def test_composite_pk_in_fields(self):
user_fields = {f.name for f in User._meta.get_fields()}
self.assertEqual(user_fields, {"id", "tenant", "primary_key"})
self.assertEqual(user_fields, {"id", "tenant", "primary_key", "email"})

comment_fields = {f.name for f in Comment._meta.get_fields()}
self.assertEqual(
Expand Down

0 comments on commit 6a26b19

Please sign in to comment.