Visualize MPAS mesh in python #281

Open
bradyrx opened this issue Jan 27, 2020 · 28 comments

@bradyrx
Contributor

bradyrx commented Jan 27, 2020

I've been working on a method to visualize MPAS-O output on a static plot in python as an alternative to the more time-consuming route of rendering viz in ParaView. I'm looking for a way to faithfully represent the grid without regridding.

It seems like datashader/holoviews/geoviews is the way to go, but I'm having trouble getting it to work fully. I'm wondering if anyone has used this route or has a solution.

Method 1: Datashader points and shade:

from colorcet import cwr
import datashader
import datashader.transfer_functions as tf
import pandas as pd
import xarray as xr

# Single time slice of air-sea CO2 flux.
ds = xr.open_dataset('MPAS-O.fgco2.nc')
# One row per cell center: coordinates (radians) and the value to aggregate
verts = pd.DataFrame({'x': ds.lonCell, 'y': ds.latCell, 'z': ds.fgco2})

cvs = datashader.Canvas(plot_height=400, plot_width=400)
tf.shade(cvs.points(verts, 'x', 'y', agg=datashader.mean('z')), cmap=cwr)

[Screenshot: datashader point rendering of fgco2]

This produces a map by coloring each pixel with an RGB value. This isn't a faithful representation of the grid, and I haven't found an easy way to plot this on a map projection.
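A prerequisite for any map projection is getting the coordinates into the form plotting tools usually expect. MPAS stores `lonCell`/`latCell` in radians, with longitude in [0, 2π). A minimal numpy sketch of converting to degrees and wrapping longitude into [-180, 180) (the helper name `to_plot_lonlat` is hypothetical):

```python
import numpy as np

def to_plot_lonlat(lon_rad, lat_rad):
    """Convert MPAS lon/lat (radians) to degrees, with longitude in [-180, 180)."""
    lon = np.rad2deg(np.asarray(lon_rad, dtype=float))
    lat = np.rad2deg(np.asarray(lat_rad, dtype=float))
    # Shift longitudes from [0, 360) into [-180, 180)
    lon = (lon + 180.0) % 360.0 - 180.0
    return lon, lat

# Toy cell centers in radians, standing in for ds.lonCell / ds.latCell
lon, lat = to_plot_lonlat([0.0, 0.5 * np.pi, 1.5 * np.pi],
                          [0.0, np.pi / 4, -np.pi / 2])
```

Whether to wrap at all depends on the target: a Pacific-centered plot would keep [0, 360).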

Method 2: Trimesh

This is ultimately what I'd like to get working. This can easily be drawn on a map projection using geoviews once it works.

import datashader
import datashader.transfer_functions as tf
import pandas as pd
import xarray as xr

ds = xr.open_dataset('MPAS-O.fgco2.nc')
mesh = xr.open_dataset('oRRS30to10v3.171128.nc')

# Set up table with coordinates for each `nCell` index as well as the data being plotted ('z')
verts = pd.DataFrame({'x': ds.lonCell, 'y': ds.latCell, 'z': ds.fgco2})

# Construct indices for each triangle.
cell0 = mesh.nCells  # Base cell
cell1 = mesh.cellsOnCell.isel(maxEdges=1) - 1  # NE cell
cell2 = mesh.cellsOnCell.isel(maxEdges=2) - 1  # N cell

# Points to the triangle vertices (`verts`) for each triangle being plotted.
tris = pd.DataFrame({'v0': cell0, 'v1': cell1, 'v2': cell2})

cvs = datashader.Canvas(plot_height=400, plot_width=400)
tf.shade(cvs.trimesh(verts, tris))

[Screenshot: trimesh output using one triangle per cell (cell plus two neighbors)]

I tried also including all possible triangles, without success.

# Pairs of neighboring cell indices to construct triangles from.
PAIRS = [
    (0,1),
    (1,2),
    (2,3),
    (3,4),
    (4,5),
    (5,0)
]

cell0 = mesh.nCells
df = pd.DataFrame()
for idx in PAIRS:
    cell1 = mesh.cellsOnCell.isel(maxEdges=idx[0]) - 1
    cell2 = mesh.cellsOnCell.isel(maxEdges=idx[1]) - 1
    tris = pd.DataFrame({'v0': cell0,
                         'v1': cell1,
                         'v2': cell2})
    df = pd.concat([df, tris], ignore_index=True)

cvs = datashader.Canvas(plot_height=400,plot_width=400)
tf.shade(cvs.trimesh(verts, df), cmap=cwr)

[Screenshot: trimesh output using all six neighbor-pair triangles per cell]

Any thoughts? The tris argument should be a three-column list of pointers to indices in verts to form a non-overlapping set of all triangles on the mesh.
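With the `PAIRS` loop, each interior triangle of the dual mesh is likely generated once from each of its three cells, with its vertices in rotated order, so the table is not non-overlapping. One way to deduplicate is sketched below, assuming 0-based indices; it keeps the first-seen vertex order of each triangle so its winding is not changed (the helper name is hypothetical):

```python
import numpy as np

def unique_triangles(tris):
    """Remove duplicate triangles, keeping each one's first-seen vertex order.

    tris: (n, 3) array of 0-based indices into the vertex table. The same
    triangle may appear several times with its vertices in a different order.
    """
    tris = np.asarray(tris, dtype=int)
    # Sorting within a row gives a key that ignores vertex order
    keys = np.sort(tris, axis=1)
    _, first = np.unique(keys, axis=0, return_index=True)
    # Return the original rows (original winding), in their original order
    return tris[np.sort(first)]

# The same triangle generated from each of its three cells, plus another one
tris = np.array([[0, 1, 2], [1, 2, 0], [2, 0, 1], [3, 4, 5]])
deduped = unique_triangles(tris)
```

Rows containing `-1` (missing neighbors) are not touched here and would still need to be filtered out.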

Resources

https://datashader.org/user_guide/Trimesh.html

https://holoviews.org/reference/elements/bokeh/TriMesh.html

https://twitter.com/oceanographer/status/1126803420579545094

@bradyrx
Contributor Author

bradyrx commented Jan 27, 2020

CC @xylar, @pwolfram

Also, I can host these example files in a Google Cloud Storage bucket or similar. This should be straightforward to assess with any MPAS-O variable on any mesh.

@xylar
Collaborator

xylar commented Jan 27, 2020

@bradyrx, I'd love to look at this more sometime soon but don't have time today. For now, I figured I'd point you to some viz stuff I have for another context that might be helpful. Here, I create a patch collection to reuse for many plots:
https://github.com/MPAS-Dev/MPAS-Model/blob/3e4d773acc4e4c76627124e1319e921dee8a908d/testing_and_setup/compass/ocean/isomip_plus/viz/plot.py#L667-L687
Here is an example of plotting that patch collection:
https://github.com/MPAS-Dev/MPAS-Model/blob/3e4d773acc4e4c76627124e1319e921dee8a908d/testing_and_setup/compass/ocean/isomip_plus/viz/plot.py#L583-L614
This method shows hexagons (etc.) of constant shade, which is what I wanted. Here's an example plot with apologies for the rainbow color map:
[Image: example plot made with the patch-collection approach]

@bradyrx
Contributor Author

bradyrx commented Jan 28, 2020

Thanks @xylar for giving this some thought and your working solution here. I'll give that a try as a potential avenue, but it would be good for us to find some solutions with e.g. holoviews to add to user documentation. I'll try your solution when I have a moment, but need to focus on analysis for now.

Here's a consolidated notebook with my attempts and the example output pulled from Google Storage. I have some running threads on gitter/twitter with folks familiar with this sort of thing so I'll point them here.

Notebook: https://nbviewer.jupyter.org/gist/bradyrx/fe329f47e9a85fb0adacefa113178b44

@bradyrx
Contributor Author

bradyrx commented Jan 28, 2020

I've updated it with a working version. I had to drop the partial triangles that contain -1 (missing-neighbor) indices. This still isn't exactly what I want but is heading in the right direction:

https://gist.github.com/bradyrx/fe329f47e9a85fb0adacefa113178b44

@xylar
Collaborator

xylar commented Jan 28, 2020

@bradyrx, it looks like you're essentially plotting the dual mesh (with cell centers as the vertices of triangles), is that correct? That's also what ParaView does but it drives me nuts for the reasons that I think you have found, too, that it leaves boundary gaps in the mesh that, in my view, shouldn't be there. I'd personally only be interested in pursuing a viz approach on the native grid if it plots cells as full polygons rather than the triangles on the dual mesh.
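For reference, the dual-mesh triangles can be built without looping over neighbor pairs. Assuming `vertexDegree = 3` (as on MPAS-Ocean meshes), each MPAS vertex sits at the corner of exactly three cells, so `cellsOnVertex` already lists one triangle per vertex. Triangles touching a missing cell have to be dropped, which is exactly what leaves the boundary gaps described above. A minimal sketch (hypothetical helper name):

```python
import numpy as np

def dual_triangles(cells_on_vertex):
    """One dual-mesh triangle per MPAS vertex, from its three surrounding cells.

    cells_on_vertex: (nVertices, 3) array in MPAS's 1-based convention, where
    0 marks a missing cell (e.g. land). Returns 0-based cell indices for the
    vertices whose three cells all exist.
    """
    tris = np.asarray(cells_on_vertex, dtype=int) - 1
    complete = np.all(tris >= 0, axis=1)
    return tris[complete]

# Toy connectivity: the middle vertex touches only two ocean cells
cells_on_vertex = np.array([[1, 2, 3], [2, 3, 0], [3, 4, 2]])
tris = dual_triangles(cells_on_vertex)
```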

@milenaveneziani
Contributor

Just a quick note to say that I am following this and that it's great you are working on it, @bradyrx.
I also do not like to tinker with ParaView too much, so my simple solution at the moment is to just do a scatter plot.

@bradyrx
Contributor Author

bradyrx commented Jan 28, 2020

> @bradyrx, it looks like you're essentially plotting the dual mesh (with cell centers as the vertices of triangles), is that correct? That's also what ParaView does but it drives me nuts for the reasons that I think you have found, too, that it leaves boundary gaps in the mesh that, in my view, shouldn't be there.

Correct -- I'm following this approach since TriMesh is the only unstructured mesh option I've found in plotting packages so far. But I agree that it's not perfect.

> I'd personally only be interested in pursuing a viz approach on the native grid if it plots cells as full polygons rather than the triangles on the dual mesh.

Agreed. In short, I want a static solution in Python to plot the full polygons as we do in ParaView. The VTK extractor you wrote is awesome, but it requires spending time working in the GUI. In theory you could write some macros or use a VTK package in Python, though. But stylistically I would also like to be able to leverage something like matplotlib/cartopy to make plots for papers, for instance. It's really nice to use the matplotlib API to zoom into a region, tile a bunch of small multiples, add the coastline, colorbar, etc. Making that look nice from ParaView is a lot of manual work. (But note that I use the ParaView route a lot to make nice movies, etc., and it will continue to be useful.)

Glad to hear it @milenaveneziani! Hopefully we can find a path forward here. CC @andsery005 who is working on this at NCAR for MPAS-A as well.

@milenaveneziani
Contributor

Yes, I agree that it would be great to use cartopy. Here is a zoomed-in scatter plot I made recently for the Southern Ocean using cartopy. The only thing missing is labeling of meridians and parallels, but that is coming in the next release of cartopy:

[Image: cartopy scatter plot of Southern Ocean salinity from a GMPAS-IAF run on the oEC60to30v3 mesh (salinitySH_depth0010_20200122)]

@bradyrx
Contributor Author

bradyrx commented Jan 28, 2020

@milenaveneziani, that looks great. I think that's probably the best solution until we can find one that faithfully depicts the polygons. I had used scatter in the past for global checks, but hadn't tried zooming in. I imagine you just tweak the marker size until the dots don't overlap?

@milenaveneziani
Contributor

> I imagine you just tweak the marker size until it doesn't overlap?

Exactly, and that's the problem. First, every time you plot a different mesh, you have to redo the tweaking. Second, the mesh resolution isn't uniform, of course, so cells/dots are bound to overlap where resolution is higher.
So it's just a temporary solution, for sure.

@bradyrx
Contributor Author

bradyrx commented Jan 28, 2020

I spent a little bit of time with @xylar's patches approach:

import copy

import fsspec
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection

import numpy
import xarray as xr

def _compute_cell_patches(dsMesh, cmap):
    patches = []
    nVerticesOnCell = dsMesh.nEdgesOnCell.values
    # MPAS connectivity is 1-based; convert to 0-based indices
    verticesOnCell = dsMesh.verticesOnCell.values - 1
    xVertex = dsMesh.xVertex.values
    yVertex = dsMesh.yVertex.values
    for iCell in range(dsMesh.sizes['nCells']):
        nVert = nVerticesOnCell[iCell]
        vertexIndices = verticesOnCell[iCell, :nVert]
        vertices = numpy.zeros((nVert, 2))
        # Vertex coordinates converted from m to km
        vertices[:, 0] = 1e-3*xVertex[vertexIndices]
        vertices[:, 1] = 1e-3*yVertex[vertexIndices]

        polygon = Polygon(vertices, closed=True)
        patches.append(polygon)

    p = PatchCollection(patches, cmap=cmap, alpha=1.)

    return p

dsMesh = xr.open_dataset('oRRS30to10v3.171128.nc')
ds = xr.open_zarr(fsspec.get_mapper('gcs://unstructured_mesh/mpas-o.fgco2')).load()

# Runtime ~2.5min
oceanPatches = _compute_cell_patches(dsMesh, 'RdBu')
# Copy since plotting seems to do some overwriting.
patches = copy.copy(oceanPatches)

# Set up specifically for fgco2 (one value per cell)
patches.set_array(ds.fgco2.values)
patches.set_edgecolor('face')
patches.set_clim(vmin=-0.0001, vmax=0.0001)

# plot
plt.figure(figsize=[9, 3])
ax = plt.subplot(111)
ax.add_collection(patches)
plt.colorbar(patches)
plt.axis([0, 500, 0, 1000])
ax.set_aspect('equal')
ax.autoscale(tight=True)

[Screenshot: fgco2 rendered with the patch-collection approach]

This looks like it's going down the right road, but it does take 2.5 minutes to render the patch collection. Perhaps it can be saved out and reloaded for the viz? Also I need to figure out how to encode the lat/lon information in the patches collection so I can project, zoom in, etc.
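A plausible (unprofiled) culprit for the 2.5 minutes is the Python loop creating one `Polygon` object per cell. A sketch of a vectorized alternative: pad each cell's vertex list to `maxEdges` by repeating its last valid vertex, gather all corners with one numpy indexing step, and hand the resulting `(nCells, maxEdges, 2)` array to matplotlib's `PolyCollection`. The helper name is hypothetical, and any speedup on oRRS30to10 has not been measured here.

```python
import numpy as np
from matplotlib.collections import PolyCollection

def cell_polygons(n_edges_on_cell, vertices_on_cell, x_vertex, y_vertex):
    """Build an (nCells, maxEdges, 2) array of polygon corners without a Python loop.

    vertices_on_cell is 1-based (MPAS convention). Entries beyond
    nEdgesOnCell are replaced by the cell's last valid vertex, which adds
    zero-length edges but keeps the array rectangular.
    """
    n_edges = np.asarray(n_edges_on_cell, dtype=int)
    voc = np.asarray(vertices_on_cell, dtype=int) - 1
    max_edges = voc.shape[1]
    # For each cell, column j maps to min(j, nEdges - 1)
    cols = np.minimum(np.arange(max_edges)[None, :], n_edges[:, None] - 1)
    voc = np.take_along_axis(voc, cols, axis=1)
    x = np.asarray(x_vertex, dtype=float)[voc]
    y = np.asarray(y_vertex, dtype=float)[voc]
    return np.stack([x, y], axis=-1)

# Toy mesh with maxEdges = 4: a square cell and a triangular cell
x_vertex = np.array([0.0, 1.0, 1.0, 0.0, 2.0])
y_vertex = np.array([0.0, 0.0, 1.0, 1.0, 0.5])
n_edges_on_cell = np.array([4, 3])
vertices_on_cell = np.array([[1, 2, 3, 4], [2, 5, 3, 0]])

verts = cell_polygons(n_edges_on_cell, vertices_on_cell, x_vertex, y_vertex)
coll = PolyCollection(verts, closed=True, edgecolors="face", cmap="RdBu")
coll.set_array(np.array([-1.0, 1.0]))  # one value per cell, like fgco2
```

The collection would then be added with `ax.add_collection(coll)` followed by `ax.autoscale_view()`; `set_clim` and `plt.colorbar` work on it the same way as on a `PatchCollection`.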

@xylar
Collaborator

xylar commented Jan 28, 2020

Okay, thanks for looking into it. That does sound too slow to be useful. I've only used it for idealized, small test cases.

@xylar xylar closed this as completed Jan 28, 2020
@xylar xylar reopened this Jan 28, 2020
@milenaveneziani
Contributor

As @bradyrx says, we could save it, though, right? Since it only depends on the mesh, 2.5 min for the 30to10 doesn't sound like much to me if it's done only once.

@bradyrx
Contributor Author

bradyrx commented Jan 28, 2020

I'm not sure if there is a way to save out matplotlib PatchCollections. A quick search returned no meaningful results.
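Rather than serializing the `PatchCollection` itself, one option is to cache only the mesh-dependent part, the polygon corner coordinates, as a plain NumPy file and rebuild a collection from it for each field. This sketch assumes the corners are already padded to a rectangular `(nCells, maxEdges, 2)` array (for example by repeating each cell's last vertex); the file name and helper names are hypothetical, and whether rebuilding is fast enough for oRRS30to10 is untested.

```python
import os
import tempfile

import numpy as np
from matplotlib.collections import PolyCollection

def save_polygons(path, verts):
    """Cache mesh-only polygon corners, shape (nCells, maxEdges, 2), to a .npz file."""
    np.savez_compressed(path, verts=np.asarray(verts, dtype=float))

def load_collection(path, field, cmap="RdBu"):
    """Rebuild a PolyCollection from cached corners and color it by a cell field."""
    with np.load(path) as data:
        verts = data["verts"]
    coll = PolyCollection(verts, closed=True, edgecolors="face", cmap=cmap)
    coll.set_array(np.asarray(field, dtype=float))
    return coll

# Toy cache: two unit squares side by side
verts = np.array([[[0, 0], [1, 0], [1, 1], [0, 1]],
                  [[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=float)
path = os.path.join(tempfile.mkdtemp(), "mesh_polygons.npz")
save_polygons(path, verts)
coll = load_collection(path, field=[0.5, -0.5])
```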

For the lat/lon stuff, I think one might need to just change this:

xVertex = dsMesh.xVertex.values
yVertex = dsMesh.yVertex.values

to this:

xVertex = numpy.rad2deg(dsMesh.lonVertex.values)
yVertex = numpy.rad2deg(dsMesh.latVertex.values)
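Switching to `lonVertex`/`latVertex` has two side effects worth noting: the `1e-3*` factor in `_compute_cell_patches` (meters to kilometers) would also need to go, and cells straddling the 0/360 seam will have vertices on both edges, producing polygons that stretch across the whole plot. A sketch of one common fix, assuming longitudes in degrees: shift each vertex longitude to lie within 180° of its own cell center (hypothetical helper name).

```python
import numpy as np

def unwrap_cell_lon(lon_vertex_on_cell, lon_cell):
    """Shift each cell's vertex longitudes (degrees) to within 180 deg of its center.

    lon_vertex_on_cell: (nCells, maxEdges) vertex longitudes gathered per cell
    lon_cell: (nCells,) cell-center longitudes
    """
    lon_v = np.asarray(lon_vertex_on_cell, dtype=float)
    lon_c = np.asarray(lon_cell, dtype=float)[:, None]
    # Signed offset from the cell center, wrapped into [-180, 180)
    return lon_c + (lon_v - lon_c + 180.0) % 360.0 - 180.0

# A cell centered just west of the prime meridian with vertices on both sides
lon_v = unwrap_cell_lon([[359.0, 0.5, 0.2]], [359.5])
```

Cells fixed this way extend slightly past 360, so on a 0 to 360 axis they would also need a copy shifted by -360 to fill the western edge.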

@xylar
Collaborator

xylar commented Jan 31, 2020

@bradyrx, I'm going to give this some thought over the weekend. I'll let you know what I come up with.

@xylar
Collaborator

xylar commented Feb 1, 2020

@bradyrx, I played with this a bit more. My ignorance of datashader, holoviews and geoviews knows no bounds, and I typically prefer to work outside of Jupyter notebooks because I'm usually designing viz to be reused from the command line. But these packages seem really hard to use outside of a notebook, especially with bokeh.

That being said, I had some success working from your Method 2 above:

import holoviews as hv
import geoviews as gv
import xarray
import numpy
import pandas as pd
import datashader
import datashader.transfer_functions as tf
from datashader import utils

# possibly unnecessary
gv.extension('matplotlib', 'bokeh')
gv.output(dpi=300, fig='png')
hv.output(backend='matplotlib')

# a small example MPAS-O initial condition I happen to be working with
ds = xarray.open_dataset('/home/xylar/data/mpas/test_bedmachine/ocean/'
                         'global_ocean/QU240wISC/init/initial_state/'
                         'initial_state.nc')

maxEdges = ds.sizes['maxEdges']
nCells = ds.sizes['nCells']
nVertices = ds.sizes['nVertices']
nEdgesOnCell = ds.nEdgesOnCell.values
verticesOnCell = ds.verticesOnCell.values - 1

lonVertex = numpy.rad2deg(ds.lonVertex.values)
latVertex = numpy.rad2deg(ds.latVertex.values)
lonCell = numpy.rad2deg(ds.lonCell.values)
latCell = numpy.rad2deg(ds.latCell.values)

# an example field to plot
sst = ds.temperature.isel(Time=0, nVertLevels=0).values

# Repeat last vertex on cell to get to maxEdges
# Doing this creates a bunch of extra zero-area triangles but makes indexing and
# looping a lot more efficient
for iVertex in range(1, maxEdges):
    mask = nEdgesOnCell <= iVertex
    verticesOnCell[mask, iVertex] = verticesOnCell[mask, iVertex-1]

# Represent each cell by maxEdges triangles.  Each cell (and therefore all
# maxEdges triangles for a given cell) needs its own vertices because color
# values are defined at vertices in TriMesh but at cell centers in MPAS-Ocean
lon = numpy.zeros((nCells, maxEdges, 3))
lat = numpy.zeros((nCells, maxEdges, 3))
pointToCell = numpy.zeros((nCells, maxEdges, 3), int)
cellIndices = numpy.arange(nCells)
for iVertex in range(maxEdges):
    # Each triangle is formed from the current vertex, the "next" vertex around 
    # the cell and the cell center
    nextVertex = numpy.mod(iVertex+1, maxEdges)
    for point in range(3):
        # pointToCell is the index of the cell corresponding to a given point 
        # (triangle vertex)
        pointToCell[:, iVertex, point] = cellIndices
    v1 = verticesOnCell[:, iVertex]
    v2 = verticesOnCell[:, nextVertex]
    lon[:, iVertex, 0] = lonVertex[v1]
    lon[:, iVertex, 1] = lonVertex[v2]
    lon[:, iVertex, 2] = lonCell

    lat[:, iVertex, 0] = latVertex[v1]
    lat[:, iVertex, 1] = latVertex[v2]
    lat[:, iVertex, 2] = latCell

# It was convenient to construct these as multidimensional arrays but now
# ravel them
lon = lon.ravel()
lat = lat.ravel()
pointToCell = pointToCell.ravel()
# Since we want different shading for adjacent triangles, they never
# share vertices. Thus, there is a point for each triangle vertex and the
# connectivity is just an ascending list of points
nPoints = 3*maxEdges*nCells
tris = numpy.arange(nPoints).reshape(nCells*maxEdges, 3)

# convert to the format expected by trimesh
# each triangle vertex takes the value of the cell it belongs to
verts = pd.DataFrame({'x': lon, 'y': lat, 'z': sst[pointToCell]})
tris = pd.DataFrame({'v0': tris[:, 0], 'v1': tris[:, 1], 'v2': tris[:, 2]})
cvs = datashader.Canvas(plot_height=800, plot_width=800)
tf.shade(cvs.trimesh(verts, tris))

[Image: QU240wISC sea surface temperature rendered as constant-valued triangles per cell]

@xylar
Collaborator

xylar commented Feb 1, 2020

@bradyrx, if you can take what I've done and figure out how to store a reusable trimesh (e.g. for plotting different points in time or different fields) and also how to use this properly with geoviews projections, nice continent outlines, colorbars, etc., I would be very keen to have that.
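In the per-cell triangle construction above, only the `z` column depends on the field; `lon`, `lat`, `pointToCell` and the triangle table depend only on the mesh. A sketch of caching that geometry once and re-indexing each new field (a time slice, another variable). The class name is hypothetical, and the intended use, `cvs.trimesh(mesh.verts(field), mesh.tris)`, has not been tested here.

```python
import numpy as np
import pandas as pd

class CellTriMesh:
    """Hypothetical helper: cache the field-independent trimesh geometry.

    lon, lat and point_to_cell describe every triangle vertex (one entry per
    triangle corner); tris is the (nTris, 3) connectivity into those vertices.
    """

    def __init__(self, lon, lat, point_to_cell, tris):
        self.lon = np.asarray(lon, dtype=float)
        self.lat = np.asarray(lat, dtype=float)
        self.point_to_cell = np.asarray(point_to_cell, dtype=int)
        self.tris = pd.DataFrame(np.asarray(tris, dtype=int),
                                 columns=["v0", "v1", "v2"])

    def verts(self, cell_field):
        """Vertex table (x, y, z) for one cell-centered field."""
        z = np.asarray(cell_field, dtype=float)[self.point_to_cell]
        return pd.DataFrame({"x": self.lon, "y": self.lat, "z": z})

# Toy geometry: two triangles, each belonging to a different cell
mesh = CellTriMesh(lon=[0, 1, 0, 1, 2, 1], lat=[0, 0, 1, 0, 0, 1],
                   point_to_cell=[0, 0, 0, 1, 1, 1],
                   tris=[[0, 1, 2], [3, 4, 5]])
table = mesh.verts([10.0, 20.0])
```

The geometry arrays themselves could be written to disk (e.g. with `numpy.savez`) so the mesh-dependent setup runs once per mesh.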

@xylar
Collaborator

xylar commented Feb 1, 2020

I'm going to try another version that interpolates the given field between cell centers but still includes the triangles (with constant field value) that correspond to boundary edges. This may be more aesthetically pleasing and less computationally intensive, but I think both versions would ideally be available.

@bradyrx
Contributor Author

bradyrx commented Feb 2, 2020

That looks great @xylar! Yes of course the key is to get this onto a map projection with geoviews. The pyviz/pyviz gitter is very active and focuses on this viz stack (which I agree is hard to work with as a new user). So I can play around with getting this onto a projection then ask them for guidance if stuck.

I have to focus the next two weeks on getting some plots/analysis ready for Ocean Sciences, so I likely will have to wait til after the meeting to get back on this.

@xylar
Collaborator

xylar commented Feb 2, 2020

@bradyrx, good to know about pyviz/pyviz. I'll check it out when I have time.

@xylar
Collaborator

xylar commented Feb 2, 2020

In the meantime, I figured out a potential approach to viz that's smoother but still renders polygons according to the MPAS-Ocean mesh. It's substantially more complex:

import holoviews as hv
import geoviews as gv
import xarray
import numpy
import pandas as pd
import datashader
import datashader.transfer_functions as tf
from datashader import utils

gv.extension('matplotlib', 'bokeh')
gv.output(dpi=300, fig='png')
hv.output(backend='matplotlib')

ds = xarray.open_dataset('/home/xylar/data/mpas/test_bedmachine/ocean/'
                         'global_ocean/QU240wISC/init/initial_state/'
                         'initial_state.nc')

maxEdges = ds.sizes['maxEdges']
nCells = ds.sizes['nCells']
nVertices = ds.sizes['nVertices']
nPoints = 3*maxEdges*nCells
nEdgesOnCell = ds.nEdgesOnCell.values
verticesOnCell = ds.verticesOnCell.values - 1
cellsOnVertex = ds.cellsOnVertex.values - 1
cellsOnEdge = ds.cellsOnEdge.values - 1
verticesOnEdge = ds.verticesOnEdge.values - 1

lonVertex = numpy.rad2deg(ds.lonVertex.values)
latVertex = numpy.rad2deg(ds.latVertex.values)
lonCell = numpy.rad2deg(ds.lonCell.values)
latCell = numpy.rad2deg(ds.latCell.values)

sst = ds.temperature.isel(Time=0, nVertLevels=0).values

nCellsOnVertex = numpy.zeros(nVertices, int)
for index in range(3):
    mask = cellsOnVertex[:, index] > -1
    nCellsOnVertex[mask] += 1
    
fullVertices = nCellsOnVertex == 3
pairVertices = nCellsOnVertex == 2
nFullVertices = numpy.count_nonzero(fullVertices)
nPairVertices = numpy.count_nonzero(pairVertices)
boundaryEdges = cellsOnEdge[:, 1] == -1
nBoundaryEdges = numpy.count_nonzero(boundaryEdges)

nVerts = nCells + nVertices
nTris = nFullVertices + nPairVertices + nBoundaryEdges

lon = numpy.zeros(nVerts)
lat = numpy.zeros(nVerts)
tris = numpy.zeros((nTris, 3), int)

# with this approach, some triangle vertices are not at cell centers.  Their
# values will need to be a weighted average of the valid cells adjacent to them.
vertCells = numpy.zeros((nVerts, 3), int)
cellIndices = numpy.arange(nCells)
vertCells[0:nCells, 0] = cellIndices
vertCells[nCells:nCells+nVertices, :] = cellsOnVertex
vertCellWeights = numpy.zeros((nVerts, 3))
# for cells, only the first index is used and gets a weight of 1
vertCellWeights[:, 0] = 1.
# for vertices, we want a simple average of all valid neighboring cells
for index in range(3):
    weight = 1./nCellsOnVertex*(cellsOnVertex[:, index] > -1)
    vertCellWeights[nCells:nCells+nVertices, index] = weight

# First, the vertices -- we include all vertices, not just the boundary ones
# for simplicity of computing the connectivity later
lon[0:nCells] = lonCell
lat[0:nCells] = latCell
lon[nCells:nCells+nVertices] = lonVertex
lat[nCells:nCells+nVertices] = latVertex

# Now, the connectivity of triangles
# First, the "full" vertices
tris[0:nFullVertices, :] = cellsOnVertex[fullVertices, :]

# Second, triangles with two cells and one vertex
offset = nFullVertices
vertexIndices = numpy.arange(nVertices)
trisPairVertices = cellsOnVertex[pairVertices, :]
mask = trisPairVertices == -1
trisPairVertices[mask] = nCells + vertexIndices[pairVertices]
tris[offset:offset+nPairVertices, :] = trisPairVertices

# Finally, triangles with the vertices on a boundary edge and one cell center
offset = nFullVertices + nPairVertices
tris[offset:offset+nBoundaryEdges, 0:2] = nCells + verticesOnEdge[boundaryEdges,:]
tris[offset:offset+nBoundaryEdges, 2] = cellsOnEdge[boundaryEdges, 0]

# this is where the weighted averaging comes in
sst_vert = numpy.sum(vertCellWeights*sst[vertCells], axis=1)

verts = pd.DataFrame({'x': lon, 'y': lat, 'z': sst_vert})

tris_pd = pd.DataFrame({'v0': tris[:, 0], 'v1': tris[:, 1], 'v2': tris[:, 2]})

cvs = datashader.Canvas(plot_height=1600, plot_width=1600)
tf.shade(cvs.trimesh(verts, tris_pd))

[Image: QU240wISC sea surface temperature rendered with interpolated triangles]

Some triangles are at the periodic boundary but they don't render because they have the wrong winding (a nice coincidence). If one cared about that, it would be necessary to make periodic copies of triangles that cross the boundary and fix the longitude where it's problematic.
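The periodic copies mentioned above could look roughly like this sketch. It assumes longitudes in degrees within [0, 360) (as produced by `numpy.rad2deg` on MPAS coordinates) and treats any triangle spanning more than 180° of longitude as crossing the seam. Each such triangle is replaced by two copies with their own vertices, one shifted east and one west, preserving vertex order (winding). The helper name is hypothetical.

```python
import numpy as np

def add_periodic_copies(lon, lat, z, tris, period=360.0):
    """Replace seam-crossing triangles with an east-shifted and a west-shifted copy."""
    lon, lat, z = (np.asarray(a, dtype=float) for a in (lon, lat, z))
    tris = np.asarray(tris, dtype=int)
    tri_lon = lon[tris]
    crosses = tri_lon.max(axis=1) - tri_lon.min(axis=1) > period / 2
    flat = tris[crosses].ravel()  # corners of crossing triangles, in order
    n_new = flat.size
    # East copy: western corners move by +period; west copy: eastern corners by -period
    lon_east = np.where(lon[flat] < period / 2, lon[flat] + period, lon[flat])
    lon_west = np.where(lon[flat] >= period / 2, lon[flat] - period, lon[flat])
    n = lon.size
    new_tris = np.concatenate([
        tris[~crosses],
        (n + np.arange(n_new)).reshape(-1, 3),
        (n + n_new + np.arange(n_new)).reshape(-1, 3),
    ])
    new_lon = np.concatenate([lon, lon_east, lon_west])
    new_lat = np.concatenate([lat, lat[flat], lat[flat]])
    new_z = np.concatenate([z, z[flat], z[flat]])
    return new_lon, new_lat, new_z, new_tris

# Toy mesh: the first triangle straddles the 0/360 seam, the second does not
lon, lat, z, tris = add_periodic_copies(
    lon=[359.0, 1.0, 0.5, 10.0, 11.0], lat=[0.0, 0.0, 1.0, 0.0, 1.0],
    z=[1.0, 2.0, 3.0, 4.0, 5.0], tris=[[0, 1, 2], [1, 3, 4]])
```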

@bradyrx
Contributor Author

bradyrx commented Mar 3, 2020

I may not be able to look at this for a bit, but I want to keep it open so we can continue to think about it. I think @xylar got us a long way there. The next goal is to figure out how to actually get it on a projection with zooming, tick labeling, etc. I'm not sure how straightforward that will be, though.
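Until the geoviews route is worked out, one dependency-free option is to project the triangle vertices before calling `Canvas.trimesh`, since datashader rasterizes whatever x/y it is given. This sketch uses the standard spherical orthographic projection formulas and drops triangles with any vertex on the far side of the globe; cartopy's `CRS.transform_points` would be the more general route for other projections. Function names are hypothetical, and coastlines and graticules are not handled.

```python
import numpy as np

def orthographic(lon_deg, lat_deg, lon0_deg=0.0, lat0_deg=0.0, radius=1.0):
    """Orthographic projection centered on (lon0, lat0); also returns visibility."""
    lon, lat = np.deg2rad(lon_deg), np.deg2rad(lat_deg)
    lon0, lat0 = np.deg2rad(lon0_deg), np.deg2rad(lat0_deg)
    dlon = lon - lon0
    x = radius * np.cos(lat) * np.sin(dlon)
    y = radius * (np.cos(lat0) * np.sin(lat)
                  - np.sin(lat0) * np.cos(lat) * np.cos(dlon))
    # Cosine of angular distance from the center; negative means far side
    cos_c = np.sin(lat0) * np.sin(lat) + np.cos(lat0) * np.cos(lat) * np.cos(dlon)
    return x, y, cos_c >= 0

def project_trimesh(lon, lat, tris, **proj_kwargs):
    """Project trimesh vertices and drop triangles with any hidden vertex."""
    x, y, visible = orthographic(np.asarray(lon, dtype=float),
                                 np.asarray(lat, dtype=float), **proj_kwargs)
    tris = np.asarray(tris, dtype=int)
    return x, y, tris[visible[tris].all(axis=1)]

x, y, kept = project_trimesh([0.0, 80.0, 180.0, 10.0], [0.0, 0.0, 0.0, 10.0],
                             [[0, 1, 2], [0, 1, 3]])
```

The projected `x`, `y` would replace the lon/lat columns of the vertex table, and `kept` the triangle table.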

@xylar
Collaborator

xylar commented Mar 3, 2020

Thanks for keeping this alive.

@bradyrx
Contributor Author

bradyrx commented Mar 17, 2020

I just came across this due to some activity at MPAS-QuickViz: https://github.com/MPAS-Dev/MPAS-QuickViz/blob/master/ocean/plotting_library/simple_plotting.py. plot_poly looks like a similar approach to what we're doing here with patches, but more concise. Are you familiar with this, @xylar? Any thoughts on this, @pwolfram?

There's no docstring at all so I'm a little lost as to what is supposed to be input here. It also is going to loop through every cell so it's probably not feasible for 30to10, as we talked about with the patch approach here.

@pwolfram
Contributor

It is for point-based output. We have other scripts that we've used for patch-based visualization in python, but largely I'm using ParaView these days for that type of output and then having specific diagnostics produce plots.

@bradyrx
Contributor Author

bradyrx commented Mar 17, 2020

@pwolfram, I think the key here is to be able to interactively make visualizations of the native mesh through python. I.e. we don't want to regrid to make pcolor/contour plots with matplotlib. And while paraview is super useful, it takes a lot longer to polish up a nice figure. It's tough for exploratory work.

@xylar
Collaborator

xylar commented Jul 3, 2020

I'm planning to work on this again over the weekend. I've been working on MPAS-Dev/MPAS-Analysis#586 and realized that I basically need to create a set of triangles that would be needed for this viz in order to do a proper job at transects on the MPAS-Ocean grid, too.

@xylar
Collaborator

xylar commented Feb 3, 2023

Three years after the original post, here's a potential answer:
UXARRAY/uxarray#214
