Skip to content

Commit

Permalink
Assert no traversing on unloaded entity field (#969)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvorisek committed May 27, 2022
1 parent 31310a3 commit 05339b9
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 46 deletions.
21 changes: 19 additions & 2 deletions src/Model/Join.php
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,17 @@ public function importModel(Model $model, array $defaults = [])
}
*/

/**
* @param mixed $value
*/
protected function assertReferenceIdNotNull($value): void
{
if ($value === null) {
throw (new Exception('Unable to join on null value'))
->addMoreInfo('value', $value);
}
}

/**
* @return mixed
*
Expand All @@ -434,6 +445,8 @@ protected function getId(Model $entity)
*/
protected function setId(Model $entity, $id): void
{
$this->assertReferenceIdNotNull($id);

$this->idByOid[spl_object_id($entity)] = $id;
}

Expand Down Expand Up @@ -530,12 +543,14 @@ public function afterInsert(Model $entity): void
return;
}

$this->setSaveBufferValue($entity, $this->foreign_field, $this->hasJoin() ? $this->getJoin()->getId($entity) : $entity->getId()); // TODO needed? from array persistence
$id = $this->hasJoin() ? $this->getJoin()->getId($entity) : $entity->getId();
$this->assertReferenceIdNotNull($id);
$this->setSaveBufferValue($entity, $this->foreign_field, $id); // TODO needed? from array persistence

$foreignModel = $this->getForeignModel();
$foreignEntity = $foreignModel->createEntity()
->setMulti($this->getAndUnsetSaveBuffer($entity))
->set($this->foreign_field, $this->hasJoin() ? $this->getJoin()->getId($entity) : $entity->getId());
->set($this->foreign_field, $id);
$foreignEntity->save();

$this->setId($entity, $entity->getId()); // TODO why is this here? it seems to be not needed
Expand All @@ -553,6 +568,7 @@ public function beforeUpdate(Model $entity, array &$data): void

$foreignModel = $this->getForeignModel();
$foreignId = $this->reverse ? $entity->getId() : $entity->get($this->master_field);
$this->assertReferenceIdNotNull($foreignId);
$saveBuffer = $this->getAndUnsetSaveBuffer($entity);
$foreignModel->atomic(function () use ($foreignModel, $foreignId, $saveBuffer) {
$foreignModel = (clone $foreignModel)->addCondition($this->foreign_field, $foreignId);
Expand All @@ -573,6 +589,7 @@ public function doDelete(Model $entity): void

$foreignModel = $this->getForeignModel();
$foreignId = $this->reverse ? $entity->getId() : $entity->get($this->master_field);
$this->assertReferenceIdNotNull($foreignId);
$foreignModel->atomic(function () use ($foreignModel, $foreignId) {
$foreignModel = (clone $foreignModel)->addCondition($this->foreign_field, $foreignId);
foreach ($foreignModel as $foreignEntity) {
Expand Down
11 changes: 11 additions & 0 deletions src/Reference.php
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ public function setOwner(object $owner)
return $this->_setOwner($owner);
}

/**
* @param mixed $value
*/
protected function assertReferenceValueNotNull($value): void
{
if ($value === null) {
throw (new Exception('Unable to traverse on null value'))
->addMoreInfo('value', $value);
}
}

protected function getOurFieldName(): string
{
return $this->our_field ?: $this->getOurModel(null)->id_field;
Expand Down
7 changes: 5 additions & 2 deletions src/Reference/HasMany.php
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,13 @@ protected function getOurFieldValueForRefCondition(Model $ourModel)
{
$ourModel = $this->getOurModel($ourModel);

if ($ourModel->isEntity() && $ourModel->isLoaded()) {
return $this->our_field
if ($ourModel->isEntity()) {
$res = $this->our_field
? $ourModel->get($this->our_field)
: $ourModel->getId();
$this->assertReferenceValueNotNull($res);

return $res;
}

// create expression based on existing conditions
Expand Down
30 changes: 15 additions & 15 deletions src/Reference/HasOne.php
Original file line number Diff line number Diff line change
Expand Up @@ -88,30 +88,30 @@ public function ref(Model $ourModel, array $defaults = []): Model
});

if ($ourModel->isEntity()) {
if ($ourValue = $this->getOurFieldValue($ourModel)) {
// if our model is loaded, then try to load referenced model
if ($this->their_field) {
$theirModel = $theirModel->tryLoadBy($this->their_field, $ourValue);
} else {
$theirModel = $theirModel->tryLoad($ourValue);
}
$ourValue = $this->getOurFieldValue($ourModel);
$this->assertReferenceValueNotNull($ourValue);

if ($this->their_field) {
$theirModel = $theirModel->tryLoadBy($this->their_field, $ourValue);
} else {
$theirModel = $theirModel->createEntity();
$theirModel = $theirModel->tryLoad($ourValue);
}
}

// their model will be reloaded after saving our model to reflect changes in referenced fields
$theirModel->getModel(true)->reload_after_save = false;

$this->onHookToTheirModel($theirModel, Model::HOOK_AFTER_SAVE, function (Model $theirModel) use ($ourModel) {
$theirValue = $this->their_field ? $theirModel->get($this->their_field) : $theirModel->getId();
if ($ourModel->isEntity()) {
$this->onHookToTheirModel($theirModel, Model::HOOK_AFTER_SAVE, function (Model $theirModel) use ($ourModel) {
$theirValue = $this->their_field ? $theirModel->get($this->their_field) : $theirModel->getId();

if (!$this->getOurField()->compare($this->getOurFieldValue($ourModel), $theirValue)) {
$ourModel->set($this->getOurFieldName(), $theirValue)->save();
}
if (!$this->getOurField()->compare($this->getOurFieldValue($ourModel), $theirValue)) {
$ourModel->set($this->getOurFieldName(), $theirValue)->save();
}

$theirModel->reload();
});
$theirModel->reload();
});
}

return $theirModel;
}
Expand Down
24 changes: 9 additions & 15 deletions src/Reference/HasOneSql.php
Original file line number Diff line number Diff line change
Expand Up @@ -117,23 +117,17 @@ public function ref(Model $ourModel, array $defaults = []): Model

$theirFieldName = $this->their_field ?? $theirModel->id_field; // TODO why not $this->getTheirFieldName() ?
// At this point the reference
// if our_field is the id_field and is being used in the reference
// we should persist the relation in condtition
// example - $model->load(1)->ref('refLink')->import($rows);
if ($ourModel->isEntity() && $ourModel->isLoaded() && !$theirModel->isLoaded()) {
if ($ourModel->id_field === $this->getOurFieldName()) {
return $theirModel->getModel()
->addCondition($theirFieldName, $this->getOurFieldValue($ourModel));
}
if ($ourModel->isEntity()) {
$theirModel->getModel()
->addCondition($theirFieldName, $this->getOurFieldValue($ourModel));
} else {
// handles the deep traversal using an expression
$ourFieldExpression = $ourModel->action('field', [$this->getOurField()]);

$theirModel->getModel(true)
->addCondition($theirFieldName, $ourFieldExpression);
}

// handles the deep traversal using an expression
$ourFieldExpression = $ourModel->action('field', [$this->getOurField()]);

$theirModel->getModel(true)
->addCondition($theirFieldName, $ourFieldExpression);

return $theirModel;
}

Expand Down
44 changes: 33 additions & 11 deletions tests/ReferenceSqlTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

namespace Atk4\Data\Tests;

use Atk4\Data\Exception;
use Atk4\Data\Model;
use Atk4\Data\Schema\TestCase;
use Doctrine\DBAL\Platforms\PostgreSQLPlatform;
Expand Down Expand Up @@ -365,7 +366,7 @@ public function testOtherAggregates(): void
$this->assertNull($ll->get('chicken5'));
}

public function testReferenceHasOneTraversing(): void
protected function setupDbForTraversing(): Model
{
$this->setDb([
'user' => [
Expand Down Expand Up @@ -396,23 +397,44 @@ public function testReferenceHasOneTraversing(): void

$company->hasMany('Orders', ['model' => $order]);

$user = $user->load(1);
return $user;
}

$firstUserOrders = $user->ref('Company')->ref('Orders');
$firstUserOrders->setOrder('id');
public function testReferenceHasOneTraversing(): void
{
$user = $this->setupDbForTraversing();
$userEntity = $user->load(1);

$this->assertSameExportUnordered([
['id' => 1, 'company_id' => '1', 'description' => 'Vinny Company Order 1', 'amount' => 50.0],
['id' => 3, 'company_id' => '1', 'description' => 'Vinny Company Order 2', 'amount' => 15.0],
], $firstUserOrders->export());

$user->unload();
], $userEntity->ref('Company')->ref('Orders')->export());

$this->assertSameExportUnordered([
['id' => 1, 'company_id' => '1', 'description' => 'Vinny Company Order 1', 'amount' => 50.0],
['id' => 2, 'company_id' => '2', 'description' => 'Zoe Company Order', 'amount' => 10.0],
['id' => 3, 'company_id' => '1', 'description' => 'Vinny Company Order 2', 'amount' => 15.0],
], $user->ref('Company')->ref('Orders')->setOrder('id')->export());
], $userEntity->getModel()->ref('Company')->ref('Orders')->export());
}

public function testUnloadedEntityTraversingHasOnedEx(): void
{
$user = $this->setupDbForTraversing();
$userEntity = $user->createEntity();

$this->expectException(Exception::class);
$this->expectExceptionMessage('Unable to traverse on null value');
$userEntity->ref('Company');
}

public function testUnloadedEntityTraversingHasManyEx(): void
{
$user = $this->setupDbForTraversing();
$companyEntity = $user->ref('Company')->createEntity();

$this->expectException(Exception::class);
$this->expectExceptionMessage('Unable to traverse on null value');
$companyEntity->ref('Orders');
}

public function testReferenceHook(): void
Expand Down Expand Up @@ -443,14 +465,14 @@ public function testReferenceHook(): void
$uu = $u->load(2);
$this->assertNull($uu->get('address'));
$this->assertNull($uu->get('contact_id'));
$this->assertNull($uu->ref('contact_id')->get('address'));

$uu = $u->load(3);
$this->assertSame('Joe contact', $uu->get('address'));
$this->assertSame('Joe contact', $uu->ref('contact_id')->get('address'));

$uu = $u->load(2);
$uu->ref('contact_id')->save(['address' => 'Peters new contact']);
$cc = $uu->getModel()->ref('contact_id')->createEntity()->save(['address' => 'Peters new contact']);
$uu->set('contact_id', $cc->getId());

$this->assertNotNull($uu->get('contact_id'));
$this->assertSame('Peters new contact', $uu->ref('contact_id')->get('address'));
Expand Down Expand Up @@ -486,7 +508,7 @@ public function testIdFieldReferenceOurFieldCase(): void
$p->hasOne('Stadium', ['model' => $s, 'our_field' => 'id', 'their_field' => 'player_id']);

$p = $p->load(2);
$p->ref('Stadium')->import([['name' => 'Nou camp nou']]);
$p->ref('Stadium')->getModel()->import([['name' => 'Nou camp nou']]);
$this->assertSame('Nou camp nou', $p->ref('Stadium')->get('name'));
$this->assertSame(2, $p->ref('Stadium')->get('player_id'));
}
Expand Down
2 changes: 1 addition & 1 deletion tests/ReferenceTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public function testModelProperty(): void
$user = $user->createEntity();
$user->setId(1);
$user->getModel()->hasOne('order_id', ['model' => [Model::class, 'table' => 'order']]);
$o = $user->ref('order_id');
$o = $user->getModel()->ref('order_id');
$this->assertSame('order', $o->table);
}

Expand Down

0 comments on commit 05339b9

Please sign in to comment.