In [None]:
import collections
import math
import os

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

from osgeo import osr
import cartopy.crs as ccrs

%matplotlib inline

In [None]:
# Path (or OPeNDAP URL) of the h31v10 tile.  Set the LVMC_H31V10 environment
# variable to use another location without editing the notebook.
fname = os.environ.get('LVMC_H31V10', 'C:/UserData/hatfielz/data/LVMC_2017_h31v10.nc')
ds = xr.open_dataset(fname, chunks=dict(time=1))
ds.lvmc_mean.isel(time=-1).plot.imshow(robust=True)
ds

In [None]:
# Affine geotransform in GDAL's order (as passed to SetGeoTransform):
#   X = origin_x + col * pixel_width + row * x_rot
#   Y = origin_y + col * y_rot       + row * pixel_height
# The typename matches the variable name so repr() and pickling work.
AffineGeoTransform = collections.namedtuple(
    'AffineGeoTransform', ['origin_x', 'pixel_width', 'x_rot',
                           'origin_y', 'y_rot', 'pixel_height'])
RasterShape = collections.namedtuple('RasterShape', ['time', 'y', 'x'])
LonLat = collections.namedtuple('LonLat', ['lon', 'lat'])  # X,Y order

In [None]:
def reproj_bounding_box(from_sr, to_sr, old_geot, arr_shape, out_res_degrees=0.005):
    """Reproject a raster's bounding box onto a regular lon/lat output grid.

    It turns out this is substantially harder than it sounds.
    - All corners must be calculated, as reprojection will skew the box.
      (fortunately convex edges stay within the box of corners...)
    - Bounding boxes which cross the dateline need very careful handling.
      Latitude is OK, but Longitude is tricky.

    Args:
        from_sr: osr.SpatialReference of the source raster.
        to_sr: osr.SpatialReference of the geographic target system.
            Transformed points are read as (x, y) == (lon, lat), i.e.
            traditional GIS axis order (under GDAL >= 3 this requires
            OAMS_TRADITIONAL_GIS_ORDER on the spatial references).
        old_geot: AffineGeoTransform of the source raster.
        arr_shape: (ysize, xsize) shape of the source raster.
        out_res_degrees: output pixel size in degrees.

    Returns:
        (new_geot, new_shape): an AffineGeoTransform whose edges lie on whole
        multiples of ``out_res_degrees``, and the (rows, cols) == (lat, lon)
        shape of the output array.
    """
    transform = osr.CoordinateTransformation(from_sr, to_sr).TransformPoint
    inv_transform = osr.CoordinateTransformation(to_sr, from_sr).TransformPoint

    # make a list of corners, and transform them into lon/lat coordinates
    ysize, xsize = arr_shape  # unintuitive order, but correct!
    far_x = old_geot.origin_x + old_geot.pixel_width * xsize
    far_y = old_geot.origin_y + old_geot.pixel_height * ysize
    corners = [(old_geot.origin_x, old_geot.origin_y),
               (old_geot.origin_x, far_y),
               (far_x, old_geot.origin_y),
               (far_x, far_y)]
    llcorners = [LonLat(*transform(x, y)[:2]) for x, y in corners]

    # Reprojecting modis tiles should give consistent latitudes, well within
    # one degree given precision limits
    lat_min, _lat_small, _lat_big, lat_max = sorted(c.lat for c in llcorners)
    assert _lat_small - lat_min < 1 and lat_max - _lat_big < 1

    # A 'good' coordinate has x-coord error less than 1Km (no dateline)
    # and not near the prime meridian (reliable demihemisphere detection).
    # The round trip feeds each corner back in (x, y) == (lon, lat) order.
    check = [inv_transform(ll.lon, ll.lat)[:2] for ll in llcorners]
    is_ok = [abs(x1 - x2) < 1000 and abs(ll.lon) > 1
             for (x1, _), ll, (x2, _) in zip(corners, llcorners, check)]
    good_coords = [coord for coord, ok in zip(llcorners, is_ok) if ok]
    assert good_coords, "Every data-containing tile has >= 1 corner with data"

    # Now clip the longitude bounding box to avoid crossing the dateline:
    # corners that failed the round trip are replaced by the dateline on the
    # same side as the first good corner.
    lon_bound = -180 if good_coords[0].lon < 0 else 180
    lons = sorted(coord.lon if ok else lon_bound
                  for coord, ok in zip(llcorners, is_ok))
    lon_min, lon_max = lons[0], lons[-1]

    # Expand outwards so the grid is integer-aligned.  Working in integer
    # pixel indices (instead of float `%`) keeps every tile on exactly the
    # same grid, gives an exact pixel count, and does not add a spurious
    # extra row/column when an edge already lies on the grid.
    row_top = math.ceil(lat_max / out_res_degrees)
    row_bottom = math.floor(lat_min / out_res_degrees)
    col_left = math.floor(lon_min / out_res_degrees)
    col_right = math.ceil(lon_max / out_res_degrees)

    return (AffineGeoTransform(col_left * out_res_degrees, out_res_degrees, 0,
                               row_top * out_res_degrees, 0, -out_res_degrees),
            # Remember, our array dimensions are ordered Y,X == Lat,Lon
            (row_top - row_bottom, col_right - col_left))

In [None]:
def project_array_to_latlon(array, geot, wkt_str):
    """Reproject a tile from Modis Sinusoidal to WGS84 Lat/Lon coordinates.

    Metadata is handled by the calling function.

    Args:
        array: 2-D array ordered (y, x) in the source projection.
        geot: AffineGeoTransform of ``array`` in that projection.
        wkt_str: WKT string describing the source spatial reference.

    Returns:
        (latlon_array, ll_geot): the bilinearly resampled float64 array
        (NaN wherever GDAL wrote no value) and its AffineGeoTransform.
    """
    from osgeo import gdal, gdal_array, osr
    assert isinstance(geot, AffineGeoTransform)

    def array_to_raster(array, geot, wkt):
        """Wrap a 2-D array in a single-band in-memory GDAL dataset."""
        ysize, xsize = array.shape  # unintuitive order, but correct!
        dataset = gdal.GetDriverByName('MEM').Create(
            '', xsize, ysize,
            eType=gdal_array.NumericTypeCodeToGDALTypeCode(array.dtype))
        dataset.SetGeoTransform(geot)
        dataset.SetProjection(wkt)
        dataset.GetRasterBand(1).WriteArray(array)
        return dataset

    input_data = array_to_raster(array, geot, wkt_str)

    # Set up the reference systems and transformation
    from_sr = osr.SpatialReference()
    from_sr.ImportFromWkt(wkt_str)
    to_sr = osr.SpatialReference()
    to_sr.SetWellKnownGeogCS("WGS84")
    # GDAL >= 3 follows the authority axis order for WGS84 (lat, lon) in
    # TransformPoint; reproj_bounding_box reads results as (lon, lat).
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        for sr in (from_sr, to_sr):
            sr.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

    # Get new geotransform and create a NaN-filled destination raster, so
    # pixels outside the source footprint stay NaN.
    ll_geot, new_shape = reproj_bounding_box(from_sr, to_sr, geot, array.shape)
    dest_arr = np.full(new_shape, np.nan)
    dest = array_to_raster(dest_arr, ll_geot, to_sr.ExportToWkt())

    # Perform the projection/resampling
    gdal.ReprojectImage(
        input_data, dest,
        from_sr.ExportToWkt(), to_sr.ExportToWkt(),
        gdal.GRA_Bilinear)
    return dest.GetRasterBand(1).ReadAsArray(), ll_geot

In [None]:
def convert_xr_dataset(ds, var='lvmc_mean', time_index=-1):
    """Convert one time step of a sinusoidal MODIS variable to WGS84 lat/lon.

    Args:
        ds: Dataset containing ``var`` (with a ``time`` dimension) and a
            ``sinusoidal`` variable whose ``GeoTransform`` and ``spatial_ref``
            attributes describe the tile's projection.
        var: name of the data variable to reproject.
        time_index: integer position along ``time`` to reproject
            (default: the last step).

    Returns:
        2-D DataArray named ``var`` with ``latitude``/``longitude`` coordinates
        at pixel centres.
    """
    # TODO: handle the whole dataset instead of one timestep of one variable
    geot = AffineGeoTransform(
        *[float(v) for v in ds.sinusoidal.GeoTransform.split()])
    arr, new_geot = project_array_to_latlon(
        ds[var].isel(time=time_index).values, geot, ds.sinusoidal.spatial_ref)

    # GDAL geotransforms locate the outer corner of the first pixel, but xarray
    # and matplotlib treat coordinate labels as pixel centres, hence the +0.5.
    # Rounding makes labels from adjacent tiles bit-identical so merge and
    # combine_first align them instead of interleaving near-duplicate values.
    lat = new_geot.origin_y + (np.arange(arr.shape[0]) + 0.5) * new_geot.pixel_height
    lon = new_geot.origin_x + (np.arange(arr.shape[1]) + 0.5) * new_geot.pixel_width
    coords = {'latitude': np.round(lat, 9), 'longitude': np.round(lon, 9)}
    # `encoding=` is no longer accepted by the DataArray constructor.
    return xr.DataArray(arr, coords=coords, dims=('latitude', 'longitude'), name=var)

In [None]:
# Reproject the h31v10 tile to lat/lon and take a quick look at it.
out = convert_xr_dataset(ds=ds)
out.plot.imshow(robust=True)
out

In [None]:
# Reprojected h31v10 tile on a Plate Carree map with coastlines for reference.
fig, ax = plt.subplots(subplot_kw=dict(projection=ccrs.PlateCarree()))
ax.coastlines()
out.plot.imshow(transform=ccrs.PlateCarree(), ax=ax, robust=True)
ax.set_title('LVMC h31v10, reprojected to lat/lon');

In [None]:
# Neighbouring tile h30v10, read remotely from the NCI THREDDS server.
ds2 = xr.open_dataset(
    'http://dapds00.nci.org.au/thredds/dodsC/ub8/au/FMC/LVMC/LVMC_2017_h30v10.nc'
)
out2 = convert_xr_dataset(ds=ds2)

In [None]:
# Merge the two reprojected tiles into one Dataset.  join/compat are spelled
# out because recent xarray releases warn that their defaults will change;
# join='exact' would refuse to merge tiles whose coordinates differ.
merged = xr.merge([out, out2], join='outer', compat='no_conflicts')
merged

In [None]:
# Both tiles after xr.merge, shown on a map with coastlines.
fig, ax = plt.subplots(subplot_kw=dict(projection=ccrs.PlateCarree()))
ax.coastlines()
merged.lvmc_mean.plot.imshow(transform=ccrs.PlateCarree(), ax=ax, robust=True)
ax.set_title('LVMC h30v10 + h31v10 via xr.merge');

In [None]:
# Reprojected h30v10 tile on its own, for comparison with the mosaics.
fig, ax = plt.subplots(subplot_kw=dict(projection=ccrs.PlateCarree()))
ax.coastlines()
out2.plot.imshow(transform=ccrs.PlateCarree(), ax=ax, robust=True)
ax.set_title('LVMC h30v10, reprojected to lat/lon');

In [None]:
# combine_first keeps `out`'s values and fills its gaps (NaN or missing
# coordinates) from `out2`, over the union of both grids.
merged = out.combine_first(out2)
fig, ax = plt.subplots(subplot_kw=dict(projection=ccrs.PlateCarree()))
ax.coastlines()
merged.plot.imshow(transform=ccrs.PlateCarree(), ax=ax, robust=True)
ax.set_title('LVMC h31v10 + h30v10 via combine_first')
merged

In [None]:
out['longitude'] % 0.005  # offset of each h31v10 label from the 0.005 deg grid

In [None]:
out2['longitude'] % 0.005  # offset of each h30v10 label from the 0.005 deg grid

In [None]:
out.coords['latitude']

In [None]:
# Mosaic the two sinusoidal tiles *before* reprojecting.
# NOTE - must proceed west-to-east to get correct sinusoidal origin coord:
# combine_first keeps the caller's (ds2, h30v10) `sinusoidal` attributes.
in_both = (
    ds2.isel(time=slice(-2, None))
    .combine_first(ds.isel(time=slice(-2, None)))
    .isel(time=slice(1, 2))
)
in_both

In [None]:
# Reproject the two-tile sinusoidal mosaic in a single pass.
merged = convert_xr_dataset(ds=in_both)
merged

In [None]:
# Mosaic reprojected as one raster, on a map with coastlines.
fig, ax = plt.subplots(subplot_kw=dict(projection=ccrs.PlateCarree()))
ax.coastlines()
merged.plot.imshow(transform=ccrs.PlateCarree(), ax=ax, robust=True)
ax.set_title('LVMC h30v10 + h31v10, mosaicked before reprojection');

In [None]:
in_both['lvmc_mean'].plot.imshow(col='time', robust=True)