Skip to content

Commit

Permalink
Use sql alchemy to fetch a scalar for the max tunnel id
Browse files Browse the repository at this point in the history
Bug 1174998

Change-Id: I77203c81c2c910a1a601416efc4567f1320a2eef
  • Loading branch information
salv-orlando committed May 1, 2013
1 parent dbae68c commit bde296d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
20 changes: 6 additions & 14 deletions quantum/plugins/openvswitch/ovs_db_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# @author: Bob Kukura, Red Hat, Inc.

from sqlalchemy.orm import exc
from sqlalchemy.sql import func

from quantum.common import exceptions as q_exc
import quantum.db.api as db
Expand Down Expand Up @@ -365,18 +366,9 @@ def get_tunnel_endpoints():


def _generate_tunnel_id(session):
try:
# TODO(rpodolyaka): Query.all() can't raise the NoResultNound exception
# Fix this later along with other identical cases.
tunnels = session.query(ovs_models_v2.TunnelEndpoint).all()
except exc.NoResultFound:
return 0
tunnel_ids = ([tunnel['id'] for tunnel in tunnels])
if tunnel_ids:
id = max(tunnel_ids)
else:
id = 0
return id + 1
max_tunnel_id = session.query(
func.max(ovs_models_v2.TunnelEndpoint.id)).scalar() or 0
return max_tunnel_id + 1


def add_tunnel_endpoint(ip):
Expand All @@ -385,8 +377,8 @@ def add_tunnel_endpoint(ip):
tunnel = (session.query(ovs_models_v2.TunnelEndpoint).
filter_by(ip_address=ip).with_lockmode('update').one())
except exc.NoResultFound:
id = _generate_tunnel_id(session)
tunnel = ovs_models_v2.TunnelEndpoint(ip, id)
tunnel_id = _generate_tunnel_id(session)
tunnel = ovs_models_v2.TunnelEndpoint(ip, tunnel_id)
session.add(tunnel)
session.flush()
return tunnel
8 changes: 8 additions & 0 deletions quantum/tests/unit/openvswitch/test_ovs_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,14 @@ def test_tunnel_pool(self):
for tunnel_id in tunnel_ids:
ovs_db_v2.release_tunnel(self.session, tunnel_id, TUNNEL_RANGES)

def test_add_tunnel_endpoints(self):
tun_1 = ovs_db_v2.add_tunnel_endpoint('192.168.0.1')
tun_2 = ovs_db_v2.add_tunnel_endpoint('192.168.0.2')
self.assertEquals(1, tun_1.id)
self.assertEquals('192.168.0.1', tun_1.ip_address)
self.assertEquals(2, tun_2.id)
self.assertEquals('192.168.0.2', tun_2.ip_address)

def test_specific_tunnel_inside_pool(self):
tunnel_id = TUN_MIN + 5
self.assertFalse(ovs_db_v2.get_tunnel_allocation(tunnel_id).allocated)
Expand Down

0 comments on commit bde296d

Please sign in to comment.