Skip to content

Commit

Permalink
Merge pull request #37 from TUW-GEO/netcdf_2D_shape
Browse files Browse the repository at this point in the history
Improvements to netcdf storage
  • Loading branch information
cpaulik committed May 11, 2016
2 parents 1bbae05 + 8ae99eb commit e372f27
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 58 deletions.
4 changes: 3 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
Changelog
=========

v0.1.10
v0.1.xx
=======

- fix bug in storing/loading grids with shape attribute set.
- change equality check of grids to be more flexible. Now only a match of the
tuples gpi, lon, lat, cell is checked. The order no longer matters.

v0.1.9
======
Expand Down
20 changes: 15 additions & 5 deletions pygeogrids/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,19 +624,26 @@ def __eq__(self, other):
Returns True if grids are equal.
"""
# only test to certain significance for float variables
# grids are assumed to be the same if the gpi, lon, lat tuples are the
# same
idx_gpi = np.argsort(self.gpis)
idx_gpi_other = np.argsort(other.gpis)
gpisame = np.all(self.gpis[idx_gpi] == other.gpis[idx_gpi_other])
try:
nptest.assert_allclose(self.arrlon, other.arrlon)
nptest.assert_allclose(self.arrlon[idx_gpi],
other.arrlon[idx_gpi_other])
lonsame = True
except AssertionError:
lonsame = False
try:
nptest.assert_allclose(self.arrlat, other.arrlat)
nptest.assert_allclose(self.arrlat[idx_gpi],
other.arrlat[idx_gpi_other])
latsame = True
except AssertionError:
latsame = False
gpisame = np.all(self.gpis == other.gpis)
if self.subset is not None and other.subset is not None:
subsetsame = np.all(self.subset == other.subset)
subsetsame = np.all(
sorted(self.gpis[self.subset]) == sorted(other.gpis[other.subset]))
elif self.subset is None and other.subset is None:
subsetsame = True
else:
Expand Down Expand Up @@ -944,7 +951,10 @@ def __eq__(self, other):
Returns true if equal.
"""
basicsame = super(CellGrid, self).__eq__(other)
cellsame = np.all(self.arrcell == other.arrcell)
idx_gpi = np.argsort(self.gpis)
idx_gpi_other = np.argsort(other.gpis)
cellsame = np.all(self.arrcell[idx_gpi]
== other.arrcell[idx_gpi_other])
return np.all([basicsame, cellsame])


Expand Down
149 changes: 117 additions & 32 deletions pygeogrids/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,15 @@
import numpy as np
import os
from datetime import datetime

from pygeogrids import CellGrid, BasicGrid


def save_lonlat(filename, arrlon, arrlat, geodatum, arrcell=None,
gpis=None, subsets={}, global_attrs=None):
gpis=None, subsets={}, global_attrs=None,
format='NETCDF4',
zlib=False,
complevel=4,
shuffle=True):
"""
saves grid information to netCDF file
Expand All @@ -65,11 +68,21 @@ def save_lonlat(filename, arrlon, arrlat, geodatum, arrcell=None,
'meaning': 'water, land'}}
global_attrs : dict, optional
if given will be written as global attributes into netCDF file
format: string, optional
choose either from one of these NetCDF formats
'NETCDF4'
'NETCDF4_CLASSIC'
'NETCDF3_CLASSIC'
'NETCDF3_64BIT_OFFSET'
zlib: boolean, optional
see netCDF documentation
shuffle: boolean, optional
see netCDF documentation
complevel: int, optional
see netCDF documentation
"""

nc_name = filename

with Dataset(nc_name, 'w', format='NETCDF4') as ncfile:
with Dataset(filename, 'w', format=format) as ncfile:

if (global_attrs is not None and 'shape' in global_attrs and
type(global_attrs['shape']) is not int and
Expand All @@ -79,35 +92,48 @@ def save_lonlat(filename, arrlon, arrlat, geodatum, arrcell=None,
lonsize = global_attrs['shape'][0]
ncfile.createDimension("lat", latsize)
ncfile.createDimension("lon", lonsize)
arrlat = np.unique(arrlat)[::-1] # sorts arrlat descending
arrlon = np.unique(arrlon)

gpisize = global_attrs['shape'][0] * global_attrs['shape'][1]
if gpis is None:
gpivalues = np.arange(gpisize,
dtype=np.int32).reshape(latsize,
lonsize)
gpivalues = gpivalues[::-1]
else:
gpivalues = gpis.reshape(latsize, lonsize)

lons = arrlon.reshape(latsize, lonsize)
lats = arrlat.reshape(latsize, lonsize)
# sort arrlon, arrlat and gpis
arrlon_sorted, arrlat_sorted, gpivalues = sort_for_netcdf(
lons, lats, gpivalues)

# sorts arrlat descending
arrlat_store = np.unique(arrlat_sorted)[::-1]
arrlon_store = np.unique(arrlon_sorted)

else:
ncfile.createDimension("gp", arrlon.size)
gpisize = arrlon.size
if gpis is None:
gpivalues = np.arange(arrlon.size, dtype=np.int32)
else:
gpivalues = gpis
arrlon_store = arrlon
arrlat_store = arrlat

dim = list(ncfile.dimensions.keys())

crs = ncfile.createVariable('crs', np.dtype('int32').char)
crs = ncfile.createVariable('crs', np.dtype('int32').char,
shuffle=shuffle,
zlib=zlib, complevel=complevel)
setattr(crs, 'grid_mapping_name', 'latitude_longitude')
setattr(crs, 'longitude_of_prime_meridian', 0.)
setattr(crs, 'semi_major_axis', geodatum.geod.a)
setattr(crs, 'inverse_flattening', 1. / geodatum.geod.f)
setattr(crs, 'ellipsoid_name', geodatum.name)

gpi = ncfile.createVariable('gpi', np.dtype('int32').char, dim)
gpi = ncfile.createVariable('gpi', np.dtype('int32').char, dim,
shuffle=shuffle,
zlib=zlib, complevel=complevel)

if gpis is None:
gpi[:] = gpivalues
Expand All @@ -123,8 +149,10 @@ def save_lonlat(filename, arrlon, arrlat, geodatum, arrcell=None,
gpidirect = 0x0b

latitude = ncfile.createVariable('lat', np.dtype('float64').char,
dim[0])
latitude[:] = arrlat
dim[0],
shuffle=shuffle,
zlib=zlib, complevel=complevel)
latitude[:] = arrlat_store
setattr(latitude, 'long_name', 'Latitude')
setattr(latitude, 'units', 'degree_north')
setattr(latitude, 'standard_name', 'latitude')
Expand All @@ -135,19 +163,25 @@ def save_lonlat(filename, arrlon, arrlat, geodatum, arrcell=None,
else:
londim = dim[0]
longitude = ncfile.createVariable('lon', np.dtype('float64').char,
londim)
longitude[:] = arrlon
londim,
shuffle=shuffle,
zlib=zlib, complevel=complevel)
longitude[:] = arrlon_store
setattr(longitude, 'long_name', 'Longitude')
setattr(longitude, 'units', 'degree_east')
setattr(longitude, 'standard_name', 'longitude')
setattr(longitude, 'valid_range', [-180.0, 180.0])

if arrcell is not None:
cell = ncfile.createVariable('cell', np.dtype('int16').char, dim)
cell = ncfile.createVariable('cell', np.dtype('int16').char,
dim,
shuffle=shuffle,
zlib=zlib, complevel=complevel)

if len(dim) == 2:
arrcell = arrcell.reshape(latsize,
lonsize)
_, _, arrcell = sort_for_netcdf(lons, lats, arrcell)
cell[:] = arrcell
setattr(cell, 'long_name', 'Cell')
setattr(cell, 'units', '')
Expand All @@ -156,7 +190,9 @@ def save_lonlat(filename, arrlon, arrlat, geodatum, arrcell=None,
if subsets:
for subset_name in subsets.keys():
flag = ncfile.createVariable(subset_name, np.dtype('int8').char,
dim)
dim,
shuffle=shuffle,
zlib=zlib, complevel=complevel)

# create flag array based on shape of data
lf = np.zeros_like(gpivalues)
Expand All @@ -165,6 +201,7 @@ def save_lonlat(filename, arrlon, arrlat, geodatum, arrcell=None,
lf[subsets[subset_name]['points']] = 1
if len(dim) == 2:
lf = lf.reshape(latsize, lonsize)
_, _, lf = sort_for_netcdf(lons, lats, lf)

flag[:] = lf
setattr(flag, 'long_name', subset_name)
Expand All @@ -188,12 +225,63 @@ def save_lonlat(filename, arrlon, arrlat, geodatum, arrcell=None,
}

ncfile.setncatts(attr)
if global_attrs is not None and type(global_attrs) is dict:

if global_attrs is not None:
ncfile.setncatts(global_attrs)


def sort_for_netcdf(lons, lats, values):
    """
    Sort a 2D array for storage in a netCDF file.

    This means that the latitudes are stored in descending order along
    the first axis (e.g. 90 to -90) and the longitudes in ascending
    order along the second axis (e.g. -180 to 180).
    Input arrays have to have shape (latdim, londim), which would mean
    (18, 36) for a global 10 degree grid.

    NOTE(review): the row assignment relies on every latitude value
    occurring exactly ``londim`` times (regular grid), so that a global
    descending latitude sort followed by a reshape puts one latitude
    band into each row - confirm for irregular input.

    Parameters
    ----------
    lons: numpy.ndarray
        2D numpy array (or array-like) of longitudes
    lats: numpy.ndarray
        2D numpy array (or array-like) of latitudes
    values: numpy.ndarray
        2D numpy array (or array-like) of values to sort

    Returns
    -------
    lons: numpy.ndarray
        2D numpy array of longitudes, sorted
    lats: numpy.ndarray
        2D numpy array of latitudes, sorted
    values: numpy.ndarray
        2D numpy array of values, reordered like lons and lats
    """
    # accept plain nested lists as well as ndarrays; inputs are not modified
    lons = np.asarray(lons)
    lats = np.asarray(lats)
    values = np.asarray(values)

    grid_shape = lons.shape
    col_index = np.arange(grid_shape[1])
    row_index = np.arange(grid_shape[0])[:, None]

    # global ordering of all points by descending latitude
    flat_lats = lats.flatten()
    lat_desc = np.argsort(flat_lats)[::-1]

    # per column ordering of the rows so latitude decreases along axis 0
    row_order = np.argsort(flat_lats[lat_desc].reshape(lats.shape),
                           axis=0)[::-1]

    def _order_rows(array):
        # bring an array into latitude-descending row order
        return array.flatten()[lat_desc].reshape(grid_shape)[row_order,
                                                             col_index]

    # per row ordering of the columns so longitude increases along axis 1
    col_order = np.argsort(_order_rows(lons), axis=1)

    def _reorder(array):
        # apply the row ordering followed by the column ordering
        return _order_rows(array)[row_index, col_order]

    return _reorder(lons), _reorder(lats), _reorder(values)


def save_grid(filename, grid, subset_name='subset_flag',
subset_meaning="water, land", global_attrs=None):
subset_meaning='water land', global_attrs=None):
"""
save a BasicGrid or CellGrid to netCDF
it is assumed that a subset should be used as land_points
Expand All @@ -218,17 +306,18 @@ def save_grid(filename, grid, subset_name='subset_flag',
except AttributeError:
arrcell = None

if grid.gpidirect is True:
gpis = None
else:
gpis = grid.gpis
gpis = grid.gpis

if grid.shape is not None:
if global_attrs is None:
global_attrs = {}
global_attrs['shape'] = grid.shape

subsets = {subset_name: {'points': grid.subset, 'meaning': subset_meaning}}
if grid.subset is not None:
subsets = {subset_name: {
'points': grid.subset, 'meaning': subset_meaning}}
else:
subsets = None

save_lonlat(filename, grid.arrlon, grid.arrlat, grid.geodatum,
arrcell=arrcell, gpis=gpis, subsets=subsets,
Expand Down Expand Up @@ -258,11 +347,7 @@ def load_grid(filename, subset_flag='subset_flag'):
if 'cell' in nc_data.variables.keys():
arrcell = nc_data.variables['cell'][:].flatten()

# determine if gpis are in order or custom order
if nc_data.gpidirect == 0x1b:
gpis = None # gpis can be calculated through np.arange..
else:
gpis = nc_data.variables['gpi'][:].flatten()
gpis = nc_data.variables['gpi'][:].flatten()

shape = None
if hasattr(nc_data, 'shape'):
Expand All @@ -287,9 +372,9 @@ def load_grid(filename, subset_flag='subset_flag'):
# check if grid has regular shape
if len(shape) == 2:
lons, lats = np.meshgrid(nc_data.variables['lon'][:],
nc_data.variables['lat'][::-1])
lons = lons.flatten('F')
lats = lats.flatten('F')
nc_data.variables['lat'][:])
lons = lons.flatten()
lats = lats.flatten()

if subset_flag in nc_data.variables.keys():
subset = np.where(
Expand Down

0 comments on commit e372f27

Please sign in to comment.