diff --git a/src/ngamsCore/ngamsLib/ngamsDbNgasSubscribers.py b/src/ngamsCore/ngamsLib/ngamsDbNgasSubscribers.py index 6bc39322..cbb0a68d 100644 --- a/src/ngamsCore/ngamsLib/ngamsDbNgasSubscribers.py +++ b/src/ngamsCore/ngamsLib/ngamsDbNgasSubscribers.py @@ -37,6 +37,7 @@ from . import ngamsDbCore from .ngamsCore import fromiso8601 +from .ngamsSubscriber import ngamsSubscriber class ngamsDbNgasSubscribers(ngamsDbCore.ngamsDbCore): @@ -132,6 +133,11 @@ def _getSubscriberInfo(self, tx, subscrId=None, hostId=None, portNo=-1): def insertSubscriberEntry(self, sub_obj): + """ + Inserts the new subscription object into the NGAS subscription table. + If an object with the same subscription ID exists, its contents are + returned; otherwise the given object is returned. + """ hostId = sub_obj.getHostId() portNo = sub_obj.getPortNo() @@ -157,8 +163,15 @@ def insertSubscriberEntry(self, sub_obj): filterPlugIn, filterPlugInPars, \ lastFileIngDate, concurrent_threads) - self.query2(sql, args = vals) - self.triggerEvents() + # If a subscriber with the given ID already exists return that + with self.transaction() as tx: + existing_sub = self._getSubscriberInfo(tx, subscrId=subscrId) + if existing_sub: + sub_obj = ngamsSubscriber().unpackSqlResult(existing_sub[0]) + else: + tx.execute(sql, vals) + self.triggerEvents() + return sub_obj def updateSubscriberEntry(self, sub_obj): diff --git a/src/ngamsServer/ngamsServer/commands/subscribe.py b/src/ngamsServer/ngamsServer/commands/subscribe.py index d733bbac..0b00d460 100644 --- a/src/ngamsServer/ngamsServer/commands/subscribe.py +++ b/src/ngamsServer/ngamsServer/commands/subscribe.py @@ -52,10 +52,16 @@ def addSubscriber(srvObj, subscrObj): Returns: Void. """ - srvObj.getDb().insertSubscriberEntry(subscrObj) - #subscrObj.write(srvObj.getDb()) - - srvObj.registerSubscriber(subscrObj) + subscr_in_db = srvObj.getDb().insertSubscriberEntry(subscrObj) + if subscr_in_db is subscrObj: + srvObj.registerSubscriber(subscrObj) + # Trigger the Data Susbcription Thread to make it check if there are + # files to deliver to the new Subscriber. + srvObj.addSubscriptionInfo([], [subscrObj]).triggerSubscriptionThread() + elif subscr_in_db == subscrObj: + return 'equal' + else: + return 'unequal' def handleCmd(srvObj, @@ -137,10 +143,10 @@ def handleCmd(srvObj, subscrObj.setLastFileIngDate(lastIngDate) # Register the Subscriber. - addSubscriber(srvObj, subscrObj) - - # Trigger the Data Susbcription Thread to make it check if there are - # files to deliver to the new Subscriber. - srvObj.addSubscriptionInfo([], [subscrObj]).triggerSubscriptionThread() + existence_test = addSubscriber(srvObj, subscrObj) + if existence_test == 'equal': + return 201, "Identical subscription with ID '%s' existed" % (id,) + elif existence_test == 'unequal': + return 409, "Different subscription with ID '%s' existed" % (id,) return "Handled SUBSCRIBE command" diff --git a/test/test_subscription.py b/test/test_subscription.py index b92461de..40e26b8a 100644 --- a/test/test_subscription.py +++ b/test/test_subscription.py @@ -42,8 +42,8 @@ import requests import trustme -from ngamsLib import ngamsHttpUtils, ngamsSubscriber -from ngamsLib.ngamsCore import getHostName +from ngamsLib import ngamsHttpUtils, ngamsDb, ngamsSubscriber +from ngamsLib.ngamsCore import getHostName, toiso8601 from .ngamsTestLib import ngamsTestSuite, tmp_path, genTmpFilename try: @@ -437,6 +437,39 @@ def test_subscription_equality(self): subs2 = ngamsSubscriber.ngamsSubscriber(url=URL, subscrId='my-id') self.assertEqual(subs1, subs2) + # Store in DB, check equality holds + cfg = self.env_aware_cfg() + self.point_to_sqlite_database(cfg, tmp_path('ngas.sqlite')) + db = ngamsDb.from_config(cfg, maxpool=1) + with contextlib.closing(db): + db_subs1 = db.insertSubscriberEntry(subs1) + db_subs2 = db.insertSubscriberEntry(subs2) + self.assertIs(subs1, db_subs1) + self.assertEqual(subs1, db_subs2) + self.assertEqual(subs2, db_subs2) + + def test_duplicate_subscription(self): + """ + Test that creating multiple subscriptions with the same ID results in + different HTTP codes returned to the client + """ + self.prepExtSrv() + URL = 'http://127.0.0.1:1234/path' + NOW = time.time() + START_DATE = toiso8601(NOW, local=True) + def subscribe(url=URL, start_date=START_DATE): + return self.client.subscribe(url=url, startDate=start_date, + pars=[['subscr_id', 'my-id']]) + + status = subscribe() + self.assertEqual(status.http_status, 200) + status = subscribe() + self.assertEqual(status.http_status, 201) + status = subscribe(url=URL + '/subpath') + self.assertEqual(status.http_status, 409) + status = subscribe(start_date=toiso8601(NOW + 1, local=True)) + self.assertEqual(status.http_status, 409) + def upload_subscription_files(self, start_port, end_port, pars=[]): # Initial archiving self.qarchive(start_port, 'src/SmallFile.fits', mimeType='application/octet-stream', pars=pars)