Added unit tests
agoodm committed Jul 25, 2016
1 parent 198de48 commit 98a67d130ad9b2c946f18dba90da6b68c51594de
Showing 1 changed file with 168 additions and 0 deletions.
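For context, a minimal sketch of the DatasetLoader API these tests exercise, reconstructed from the test code below (the config keys, method names, and the built-in 'local' data_source all appear in the tests; the file path and loader function name here are placeholders):

from ocw.dataset_loader import DatasetLoader

# Every dataset config must declare its data_source; 'local' reads a netCDF file.
config = {'data_source': 'local',
          'file_path': '/path/to/dataset.nc',
          'variable_name': 'value'}

# First argument is the reference config; the second is one target config or a list.
loader = DatasetLoader(config, [config, config])
loader.load_datasets()
print(loader.reference_dataset.origin['source'])  # 'local'

# Custom sources can be registered before loading; some_loader_function is
# any callable that returns a Dataset (hypothetical name):
# loader.add_source_loader('foo', some_loader_function)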
@@ -0,0 +1,168 @@
import unittest
import os
import copy
import netCDF4
import numpy as np
from ocw.dataset import Dataset
from ocw.dataset_loader import DatasetLoader

class TestDatasetLoader(unittest.TestCase):
    def setUp(self):
        # Read netCDF file
        self.file_path = create_netcdf_object()
        self.netCDF_file = netCDF4.Dataset(self.file_path, 'r')
        self.latitudes = self.netCDF_file.variables['latitude'][:]
        self.longitudes = self.netCDF_file.variables['longitude'][:]
        self.times = self.netCDF_file.variables['time'][:]
        self.alt_lats = self.netCDF_file.variables['alt_lat'][:]
        self.alt_lons = self.netCDF_file.variables['alt_lon'][:]
        self.values = self.netCDF_file.variables['value'][:]
        self.values2 = self.values + 1

        # Set up config
        self.reference_config = {'data_source': 'local',
                                 'file_path': self.file_path,
                                 'variable_name': 'value'}
        self.target_config = copy.deepcopy(self.reference_config)
        self.no_data_source_config = {'file_path': self.file_path,
                                      'variable_name': 'value'}
        self.new_data_source_config = {'data_source': 'foo',
                                       'lats': self.latitudes,
                                       'lons': self.longitudes,
                                       'times': self.times,
                                       'values': self.values2,
                                       'variable': 'value'}

    def tearDown(self):
        os.remove(self.file_path)

    def testInputHasDataSource(self):
        '''
        Make sure input data source is specified for each dataset to be loaded
        '''
        with self.assertRaises(KeyError):
            self.loader = DatasetLoader(self.reference_config,
                                        self.no_data_source_config)

    def testReferenceHasDataSource(self):
        '''
        Make sure ref data source is specified for each dataset to be loaded
        '''
        with self.assertRaises(KeyError):
            self.loader = DatasetLoader(self.reference_config,
                                        self.target_config)
            self.loader.set_reference(**self.no_data_source_config)

    def testTargetHasDataSource(self):
        '''
        Make sure target data source is specified for each dataset to be loaded
        '''
        with self.assertRaises(KeyError):
            self.loader = DatasetLoader(self.reference_config,
                                        self.target_config)
            self.loader.add_target(**self.no_data_source_config)

    def testNewDataSource(self):
        '''
        Ensures that custom data source loaders can be added
        '''
        self.loader = DatasetLoader(self.new_data_source_config,
                                    self.target_config)

        # Here the data_source "foo" represents the Dataset constructor
        self.loader.add_source_loader('foo', build_dataset)
        self.loader.load_datasets()
        self.assertEqual(self.loader.reference_dataset.origin['source'],
                         'foo')
        np.testing.assert_array_equal(self.loader.reference_dataset.values,
                                      self.values2)

    def testExistingDataSource(self):
        '''
        Ensures that an existing data source loader ('local') can be used
        '''
        self.loader = DatasetLoader(self.reference_config,
                                    self.target_config)
        self.loader.load_datasets()
        self.assertEqual(self.loader.reference_dataset.origin['source'],
                         'local')
        np.testing.assert_array_equal(self.loader.reference_dataset.values,
                                      self.values)

    def testMultipleTargets(self):
        '''
        Test for when multiple target dataset configs are specified
        '''
        self.loader = DatasetLoader(self.reference_config,
                                    [self.target_config,
                                     self.new_data_source_config])

        # Here the data_source "foo" represents the Dataset constructor
        self.loader.add_source_loader('foo', build_dataset)
        self.loader.load_datasets()
        self.assertEqual(self.loader.target_datasets[0].origin['source'],
                         'local')
        self.assertEqual(self.loader.target_datasets[1].origin['source'],
                         'foo')
        np.testing.assert_array_equal(self.loader.target_datasets[0].values,
                                      self.values)
        np.testing.assert_array_equal(self.loader.target_datasets[1].values,
                                      self.values2)

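# Module-level helpers: build_dataset acts as the loader for the fictitious
# 'foo' source, and create_netcdf_object writes the temporary fixture file
# read in setUp().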
def build_dataset(*args, **kwargs):
    '''
    Wrapper around the Dataset constructor for the fictitious 'foo' data_source.
    '''
    origin = {'source': 'foo'}
    return Dataset(*args, origin=origin, **kwargs)

def create_netcdf_object():
    # Create a temporary netCDF file
    file_path = '/tmp/temporaryNetcdf.nc'
    netCDF_file = netCDF4.Dataset(file_path, 'w', format='NETCDF4')
    # Create dimensions
    netCDF_file.createDimension('lat_dim', 5)
    netCDF_file.createDimension('lon_dim', 5)
    netCDF_file.createDimension('time_dim', 3)
    # Create variables
    latitudes = netCDF_file.createVariable('latitude', 'd', ('lat_dim',))
    longitudes = netCDF_file.createVariable('longitude', 'd', ('lon_dim',))
    times = netCDF_file.createVariable('time', 'd', ('time_dim',))
    # Unusual variable names to test optional arguments for Dataset constructor
    alt_lats = netCDF_file.createVariable('alt_lat', 'd', ('lat_dim',))
    alt_lons = netCDF_file.createVariable('alt_lon', 'd', ('lon_dim',))
    alt_times = netCDF_file.createVariable('alt_time', 'd', ('time_dim',))
    values = netCDF_file.createVariable('value', 'd',
                                        ('time_dim', 'lat_dim', 'lon_dim'))

    # Latitudes and longitudes: five values each
    latitudes_data = np.arange(5.)
    longitudes_data = np.arange(150., 155.)
    # Three months of data
    times_data = np.arange(3)
    # Create 75 values (3 times x 5 lats x 5 lons)
    values_data = np.arange(75.)
    # Reshape values to a 3D array (time, lats, lons)
    values_data = values_data.reshape(len(times_data), len(latitudes_data),
                                      len(longitudes_data))

    # Write the data to the netCDF file
    latitudes[:] = latitudes_data
    longitudes[:] = longitudes_data
    times[:] = times_data
    alt_lats[:] = latitudes_data + 10
    alt_lons[:] = longitudes_data - 10
    alt_times[:] = times_data
    values[:] = values_data
    # Assign units to the time and value variables
    netCDF_file.variables['time'].units = 'months since 2001-01-01 00:00:00'
    netCDF_file.variables['alt_time'].units = 'months since 2001-04-01 00:00:00'
    netCDF_file.variables['value'].units = 'foo_units'
    netCDF_file.close()
    return file_path

if __name__ == '__main__':
    unittest.main()

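Since the module calls unittest.main() under __main__, running the file directly executes all five tests. The suite can also be driven programmatically (assuming the module is saved as test_dataset_loader.py on the import path):

import unittest

# Load every TestDatasetLoader test from the module and run it verbosely.
suite = unittest.TestLoader().loadTestsFromName('test_dataset_loader')
unittest.TextTestRunner(verbosity=2).run(suite)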