Skip to content

Commit

Permalink
netbox.py inventory plugin - revision
Browse files Browse the repository at this point in the history
  • Loading branch information
Yannis100 committed Oct 15, 2019
1 parent fb49b8b commit 7d999e7
Showing 1 changed file with 93 additions and 37 deletions.
130 changes: 93 additions & 37 deletions lib/ansible/plugins/inventory/netbox.py
Expand Up @@ -12,11 +12,14 @@
- Remy Leone (@sieben)
- Anthony Ruhier (@Anthony25)
- Nikhil Singh Baliyan (@nikkytub)
- Sander Steffann (@steffann)
- Yannis Ansermoz (@Yannis100)
short_description: NetBox inventory source
description:
- Get inventory hosts from NetBox
extends_documentation_fragment:
- constructed
- inventory_cache
options:
plugin:
description: token that ensures this is a source file for the 'netbox' plugin.
Expand Down Expand Up @@ -144,12 +147,12 @@
from threading import Thread
from itertools import chain

from ansible.plugins.inventory import BaseInventoryPlugin, Constructable
from ansible.plugins.inventory import BaseInventoryPlugin, Constructable, Cacheable
from ansible.module_utils.ansible_release import __version__ as ansible_version
from ansible.errors import AnsibleError
from ansible.module_utils._text import to_text
from ansible.module_utils.urls import open_url
from ansible.module_utils.six.moves.urllib.parse import urljoin, urlencode
from ansible.module_utils.six.moves.urllib.parse import urlencode
from ansible.module_utils.compat.ipaddress import ip_interface

ALLOWED_DEVICE_QUERY_PARAMETERS = (
Expand Down Expand Up @@ -186,21 +189,62 @@
)


class InventoryModule(BaseInventoryPlugin, Constructable):
class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
NAME = 'netbox'

def _fetch_information(self, url):
response = open_url(url, headers=self.headers, timeout=self.timeout, validate_certs=self.validate_certs)

try:
raw_data = to_text(response.read(), errors='surrogate_or_strict')
except UnicodeError:
raise AnsibleError("Incorrect encoding of fetched payload from NetBox API.")

try:
return json.loads(raw_data)
except ValueError:
raise AnsibleError("Incorrect JSON payload: %s" % raw_data)
# response = open_url(url, headers=self.headers, timeout=self.timeout, validate_certs=self.validate_certs)
#
# try:
# raw_data = to_text(response.read(), errors='surrogate_or_strict')
# except UnicodeError:
# raise AnsibleError("Incorrect encoding of fetched payload from NetBox API.")
#
# try:
# return json.loads(raw_data)
# except ValueError:
# raise AnsibleError("Incorrect JSON payload: %s" % raw_data)

results = None
cache_key = self.get_cache_key(url)

# get the user's cache option to see if we should save the cache if it is changing
user_cache_setting = self.get_option('cache')

# read if the user has caching enabled and the cache isn't being refreshed
attempt_to_read_cache = user_cache_setting and self.use_cache

# attempt to read the cache if inventory isn't being refreshed and the user has caching enabled
if attempt_to_read_cache:
try:
results = self._cache[cache_key]
need_to_fetch = False
except KeyError:
# occurs if the cache_key is not in the cache or if the cache_key expired
# we need to fetch the URL now
need_to_fetch = True
else:
# not reading from cache so do fetch
need_to_fetch = True

if need_to_fetch:
self.display.v("Fetching: " + url)
response = open_url(url, headers=self.headers, timeout=self.timeout, validate_certs=self.validate_certs)
try:
raw_data = to_text(response.read(), errors='surrogate_or_strict')
except UnicodeError:
raise AnsibleError("Incorrect encoding of fetched payload from NetBox API.")

try:
results = json.loads(raw_data)
except ValueError:
raise AnsibleError("Incorrect JSON payload: %s" % raw_data)

# put result in cache if enabled
if user_cache_setting:
self._cache[cache_key] = results

return results

def get_resource_list(self, api_url):
"""Retrieves resource list from netbox API.
Expand Down Expand Up @@ -291,10 +335,11 @@ def extract_device_role(self, host):

def extract_config_context(self, host):
try:
if self.config_context:
url = self.api_endpoint + "/api/dcim/devices/" + str(host["id"])
device_lookup = self._fetch_information(url)
return device_lookup["config_context"]
# if self.config_context:
# url = self.api_endpoint + "/api/dcim/devices/" + str(host["id"])
# device_lookup = self._fetch_information(url)
# return device_lookup["config_context"]
return host["config_context"]
except Exception:
return

Expand All @@ -316,11 +361,11 @@ def extract_interfaces(self, host):
url = self.api_endpoint + "/api/dcim/interfaces/?limit=0&device_id=" + str(host["id"])
interfaces_lookup = self._fetch_information(url)
wanted_keys = ['description', 'enabled', 'lag', 'name', 'mode', 'tagged_vlans', 'untagged_vlan', 'tags', 'form_factor', 'ip']
interfaces_short = []
interfaces_short = dict()
for interface_lookup in interfaces_lookup['results']:
if interface_lookup['count_ipaddresses'] > 0:
interface_lookup['ip'] = self.extract_interface_ip(interface_lookup['id'])
interfaces_short.append(dict((k, interface_lookup[k]) for k in wanted_keys if k in interface_lookup))
interfaces_short[interface_lookup['name']] = (dict((k, interface_lookup[k]) for k in wanted_keys if k in interface_lookup))
return interfaces_short
except Exception:
return
Expand Down Expand Up @@ -410,8 +455,9 @@ def refresh_manufacturers_lookup(self):
self.manufacturers_lookup = dict((manufacturer["id"], manufacturer["name"]) for manufacturer in manufacturers)
self.manufacturers_slug_lookup = dict((manufacturer["id"], manufacturer["slug"]) for manufacturer in manufacturers)

def refresh_lookups(self):
lookup_processes = (
@property
def lookup_processes(self):
return [
self.refresh_sites_lookup,
self.refresh_regions_lookup,
self.refresh_tenants_lookup,
Expand All @@ -420,10 +466,11 @@ def refresh_lookups(self):
self.refresh_platforms_lookup,
self.refresh_device_types_lookup,
self.refresh_manufacturers_lookup,
)
]

def refresh_lookups(self):
thread_list = []
for p in lookup_processes:
for p in self.lookup_processes:
t = Thread(target=p)
thread_list.append(t)
t.start()
Expand All @@ -450,8 +497,14 @@ def refresh_url(self):
if self.query_filters:
query_parameters.extend(filter(lambda x: x,
map(self.validate_query_parameters, self.query_filters)))
self.device_url = self.api_endpoint + "/api/dcim/devices/?" + urlencode(query_parameters)
self.virtual_machines_url = self.api_endpoint + "/api/virtualization/virtual-machines/?" + urlencode(query_parameters)
# self.device_url = self.api_endpoint + "/api/dcim/devices/?" + urlencode(query_parameters)
# self.virtual_machines_url = self.api_endpoint + "/api/virtualization/virtual-machines/?" + urlencode(query_parameters)
if self.config_context:
self.device_url = self.api_endpoint + "/api/dcim/devices/?" + urlencode(query_parameters)
self.virtual_machines_url = self.api_endpoint + "/api/virtualization/virtual-machines/?" + urlencode(query_parameters)
else:
self.device_url = self.api_endpoint + "/api/dcim/devices/?" + urlencode(query_parameters) + "&exclude=config_context"
self.virtual_machines_url = self.api_endpoint + "/api/virtualization/virtual-machines/?" + urlencode(query_parameters) + "&exclude=config_context"

def fetch_hosts(self):
return chain(
Expand All @@ -475,7 +528,8 @@ def add_host_to_groups(self, host, hostname):
for sub_group in sub_groups:
group_name = "_".join([group[:self.substr], sub_group])
self.inventory.add_group(group=group_name)
self.inventory.add_host(group=group_name, host=hostname)
if group != "regions":
self.inventory.add_host(group=group_name, host=hostname)

def _fill_host_variables(self, host, hostname):
for attribute, extractor in self.group_extractors.items():
Expand Down Expand Up @@ -514,20 +568,21 @@ def main(self):

for host in hosts_list:
hostname = self.extract_name(host=host)
self.inventory.add_host(host=hostname)
self._fill_host_variables(host=host, hostname=hostname)
if 'virtual_chassis' in host and ((host['virtual_chassis'] is not None and host['virtual_chassis']['master']['id'] == host['id']) or host['virtual_chassis'] is None):
self.inventory.add_host(host=hostname)
self._fill_host_variables(host=host, hostname=hostname)

strict = self.get_option("strict")
strict = self.get_option("strict")

# Composed variables
self._set_composite_vars(self.get_option('compose'), host, hostname, strict=strict)
# Composed variables
self._set_composite_vars(self.get_option('compose'), host, hostname, strict=strict)

# Complex groups based on jinja2 conditionals, hosts that meet the conditional are added to group
self._add_host_to_composed_groups(self.get_option('groups'), host, hostname, strict=strict)
# Complex groups based on jinja2 conditionals, hosts that meet the conditional are added to group
self._add_host_to_composed_groups(self.get_option('groups'), host, hostname, strict=strict)

# Create groups based on variable values and add the corresponding hosts to it
self._add_host_to_keyed_groups(self.get_option('keyed_groups'), host, hostname, strict=strict)
self.add_host_to_groups(host=host, hostname=hostname)
# Create groups based on variable values and add the corresponding hosts to it
self._add_host_to_keyed_groups(self.get_option('keyed_groups'), host, hostname, strict=strict)
self.add_host_to_groups(host=host, hostname=hostname)

self.sgroup_by = set(self.group_by)
if "regions" in self.sgroup_by:
Expand All @@ -547,6 +602,7 @@ def main(self):
def parse(self, inventory, loader, path, cache=True):
super(InventoryModule, self).parse(inventory, loader, path)
self._read_config_data(path=path)
self.use_cache = cache

# Netbox access
token = self.get_option("token")
Expand Down

0 comments on commit 7d999e7

Please sign in to comment.