In [1]:
import numpy as np
import xarray as xr
import pickle
import napari
from mpl_toolkits.basemap import Basemap
from tqdm.notebook import tqdm
import OpenVisus as ov
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import plotly.graph_objects as go
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import os
from matplotlib.widgets import Slider, RadioButtons
import matplotlib.colors as mcolors
import matplotlib
import logging
import vtk
from vtk.util.numpy_support import numpy_to_vtk
# Switch matplotlib to the non-interactive Agg backend: figures in this notebook
# are written to files with savefig rather than displayed.
matplotlib.use('Agg')

# presumably read by OpenVisus as the location of its local data cache — TODO confirm
# against the OpenVisus documentation. The directory name suggests it is disposable.
os.environ['VISUS_CACHE']= "./visus_can_be_deleted"

### Part 0: Preprocessing the lon and lat locations for later use

In [2]:
def downsample_latlon(latitudes, longitudes, target_shape):
    """Subsample 2-D latitude/longitude grids by striding towards ``target_shape``.

    The stride along each axis is the integer ratio between the grid size and
    the requested size; both strides are derived from ``latitudes.shape`` and
    applied to both arrays.

    Parameters:
    latitudes: 2-D array-like supporting ``.shape`` and strided slicing.
    longitudes: 2-D array-like indexed with the same strides as ``latitudes``.
    target_shape: (rows, cols) of the data the coordinates should line up with.

    Returns:
    tuple: ``(latitudes_downsampled, longitudes_downsampled)``.

    The result only has exactly ``target_shape`` when the grid size is an
    integer multiple of it, so callers should still compare shapes.
    """
    # Clamp to 1: when the target is larger than the grid the integer ratio is
    # 0, and a slice step of 0 raises ValueError. With a stride of 1 the grid
    # is returned unchanged and the caller's shape check reports the mismatch.
    factor_lat = max(1, latitudes.shape[0] // target_shape[0])
    factor_lon = max(1, latitudes.shape[1] // target_shape[1])
    latitudes_downsampled = latitudes[::factor_lat, ::factor_lon]
    longitudes_downsampled = longitudes[::factor_lat, ::factor_lon]
    return latitudes_downsampled, longitudes_downsampled

In [3]:
# Open the file holding the 2-D latitude/longitude grids and keep handles to
# the 'lats' and 'lons' variables for the later cells.
temp = xr.open_dataset('geos_c1440_lats_lons_2D.nc')
latitudes, longitudes = temp['lats'], temp['lons']

lat_shape, lon_shape = latitudes.shape, longitudes.shape
total_rows, total_cols = latitudes.shape

# Tile the grid as 2 rows x 3 columns of equally sized blocks.
faces_per_row, faces_per_col = 3, 2
face_rows = total_rows // faces_per_col
face_cols = total_cols // faces_per_row

lat_faces, lon_faces = [], []

# Walk the tiles in row-major order (all columns of the first row, then the second row).
for face_index in range(faces_per_col * faces_per_row):
    row, col = divmod(face_index, faces_per_row)

    start_row, start_col = row * face_rows, col * face_cols
    end_row, end_col = start_row + face_rows, start_col + face_cols

    row_span = slice(start_row, end_row)
    col_span = slice(start_col, end_col)

    lat_face = latitudes[row_span, col_span]
    lon_face = longitudes[row_span, col_span]

    lat_faces.append(lat_face)
    lon_faces.append(lon_face)


### Part 1: Iterate to generate images by frame

In [None]:
%%time
def downsample_latlon(latitudes, longitudes, target_shape):
    """Subsample 2-D latitude/longitude grids by striding towards ``target_shape``.

    NOTE: this repeats the definition from the Part 0 cell and shadows it once
    this cell runs; keep the two in sync (or delete this copy).

    Parameters:
    latitudes: 2-D array-like supporting ``.shape`` and strided slicing.
    longitudes: 2-D array-like indexed with the same strides as ``latitudes``.
    target_shape: (rows, cols) of the data the coordinates should line up with.

    Returns:
    tuple: ``(latitudes_downsampled, longitudes_downsampled)``. The result only
    has exactly ``target_shape`` when the grid size is an integer multiple of
    it, so callers should still compare shapes.
    """
    # Clamp to 1: when the target is larger than the grid the integer ratio is
    # 0, and a slice step of 0 raises ValueError. With a stride of 1 the grid
    # is returned unchanged and the caller's shape check reports the mismatch.
    factor_lat = max(1, latitudes.shape[0] // target_shape[0])
    factor_lon = max(1, latitudes.shape[1] // target_shape[1])
    latitudes_downsampled = latitudes[::factor_lat, ::factor_lon]
    longitudes_downsampled = longitudes[::factor_lat, ::factor_lon]
    return latitudes_downsampled, longitudes_downsampled

# Render one PNG per stored time step: the first depth level of every face is
# scattered onto a global PlateCarree map and saved under '<variable>_images/'.
#
# NOTE(review): this cell reads `total_faces`, `variable`, `vmin` and `vmax`
# without assigning them. `total_faces` and `variable` are assigned by the
# Part 2-1 loading cell further down; no assignment of `vmin`/`vmax` is visible
# in this notebook — confirm they are defined before running this cell.

num_faces = 6
face_size = latitudes.shape[0] // num_faces

# The coordinate grids are the same for every frame, so pull the per-face
# NumPy arrays out once instead of once per frame.
face_coords = [
    (
        latitudes[i * face_size:(i + 1) * face_size, :].values,
        longitudes[i * face_size:(i + 1) * face_size, :].values,
    )
    for i in range(num_faces)
]

# Create the output directory once rather than on every iteration.
os.makedirs(f'{variable}_images', exist_ok=True)

for t in tqdm(range(len(total_faces[0])), desc="Saving frames"):
    data_faces = [face_series[t] for face_series in total_faces]

    all_lats = []
    all_lons = []
    all_vals = []

    for i in range(num_faces):
        lat_face, lon_face = face_coords[i]
        data_face = data_faces[i][0, :, :]  # shape: (y, x)

        lat_ds, lon_ds = downsample_latlon(lat_face, lon_face, data_face.shape)

        if lat_ds.shape != data_face.shape:
            print(f"Skipping face {i} due to shape mismatch: {lat_ds.shape} vs {data_face.shape}")
            continue

        all_lats.append(lat_ds.flatten())
        all_lons.append(lon_ds.flatten())
        all_vals.append(data_face.flatten())

    # np.concatenate raises on an empty list; if every face was skipped there
    # is nothing to draw, so skip the frame before a figure is created.
    if not all_vals:
        print(f"Skipping frame {t}: no face matched the lat/lon grid")
        continue

    merged_lat = np.concatenate(all_lats)
    merged_lon = np.concatenate(all_lons)
    merged_vals = np.concatenate(all_vals)

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(projection=ccrs.PlateCarree())
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    ax.set_global()

    scatter = ax.scatter(
        merged_lon,
        merged_lat,
        c=merged_vals,
        cmap='cividis',
        s=1,
        alpha=0.5,
        transform=ccrs.PlateCarree(),
        vmin=vmin,
        vmax=vmax
    )

    gl = ax.gridlines(draw_labels=True, crs=ccrs.PlateCarree())
    # `top_labels` / `right_labels` are the current Gridliner attribute names;
    # the older `xlabels_top` / `ylabels_right` spellings are deprecated.
    gl.top_labels = False
    gl.right_labels = False
    gl.xlines = False
    gl.ylines = False
    gl.xlabel_style = {'size': 10}
    gl.ylabel_style = {'size': 10}
    fig.colorbar(scatter, ax=ax, orientation='horizontal', pad=0.05)

    fig.savefig(f'{variable}_images/frame_{t:03d}.png')
    # Close each figure so memory does not grow with the number of frames.
    plt.close(fig)
    

### Part 2: Preprocessing data for Paraview visualization
#### 2-1: Load each data attribute over the complete set of timesteps

In [33]:
# For every variable: download all six faces at a reduced resolution for every
# 24th stored time step, flatten each (depth, time) slice across the faces, and
# pickle the result as '<variable>.pkl'.
#
# Layout of each pickled dict:
#   '<variable> lat' / '<variable> lon' -> 1-D coordinates of the kept points
#   '<variable> <d> <t>'                -> 1-D values for depth index d, time index t
variables = ['U', 'V', 'W', 'T', 'FCLD']#,'P']
num_faces = 6

# One .idx dataset per (variable, face); only the variable name and the face
# index differ between the URLs.
url_template = (
    "https://maritime.sealstorage.io/api/v0/s3/utah/nasa/dyamond/GEOS/GEOS_{upper}/"
    "{lower}_face_{face}_depth_52_time_0_10269.idx"
    "?access_key=any&secret_key=any"
    "&endpoint_url=https://maritime.sealstorage.io/api/v0/s3&cached=arco"
)

for variable in variables:
    collection = {}

    data_file = [
        url_template.format(upper=variable.upper(), lower=variable.lower(), face=face)
        for face in range(num_faces)
    ]
    total_faces = []

    for actual_file_path in tqdm(data_file, desc=f"Loading {variable}"):
        db = ov.LoadDataset(actual_file_path)
        data=[]
        for i in tqdm(range(0, len(db.getTimesteps()), 24), desc="Weekly timesteps", leave=False):
            database = db.read(time=i,quality=-7)
            data.append(database)
        total_faces.append(data)

    depth = total_faces[0][0].shape[0]
    face_size = latitudes.shape[0] // num_faces

    # Faces whose downsampled lat/lon grid matched the data shape. Filled on
    # the first (t, d) pass and reused afterwards so that every stored value
    # array has the same length and ordering as the stored lat/lon arrays.
    kept_faces = None

    for t in tqdm(range(len(total_faces[0])), desc="Time:"):
        # The per-face arrays for this time step do not depend on the depth.
        data_faces = [face_series[t] for face_series in total_faces]

        for d in range(depth):
            if kept_faces is None:
                kept_faces = []
                all_lats = []
                all_lons = []

                for i in range(num_faces):
                    lat_face = latitudes[i * face_size:(i + 1) * face_size, :].values
                    lon_face = longitudes[i * face_size:(i + 1) * face_size, :].values
                    data_face = data_faces[i][d, :, :]
                    lat_ds, lon_ds = downsample_latlon(lat_face, lon_face, data_face.shape)
                    if lat_ds.shape != data_face.shape:
                        print(f"Skipping face {i} due to shape mismatch: {lat_ds.shape} vs {data_face.shape}")
                        continue
                    kept_faces.append(i)
                    all_lats.append(lat_ds.flatten())
                    all_lons.append(lon_ds.flatten())

                merged_lat = np.concatenate(all_lats)
                merged_lon = np.concatenate(all_lons)
                collection[f'{variable} lat'] = merged_lat
                collection[f'{variable} lon'] = merged_lon

            # np.concatenate already returns a fresh array, so no extra copy
            # is needed before storing it.
            merged_vals = np.concatenate(
                [data_faces[i][d, :, :].flatten() for i in kept_faces]
            )
            collection[f'{variable} {d} {t}'] = merged_vals

    with open(f'{variable}.pkl', 'wb') as f:
        pickle.dump(collection, f)

    print(f'{variable} done!')

Loading U:   0%|          | 0/6 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Time::   0%|          | 0/428 [00:00<?, ?it/s]

U done!


Loading V:   0%|          | 0/6 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Time::   0%|          | 0/428 [00:00<?, ?it/s]

V done!


Loading W:   0%|          | 0/6 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Time::   0%|          | 0/428 [00:00<?, ?it/s]

W done!


Loading T:   0%|          | 0/6 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Time::   0%|          | 0/428 [00:00<?, ?it/s]

T done!


Loading FCLD:   0%|          | 0/6 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Weekly timesteps:   0%|          | 0/428 [00:00<?, ?it/s]

Time::   0%|          | 0/428 [00:00<?, ?it/s]

FCLD done!


#### 2-2: Assemble the variables that share the same shape into one set of .vtp files (one file per timestep)

In [38]:
def add_vertex_cells(polydata):
    """
    Add vertex cells to a vtkPolyData object to ensure each point is treated as an individual cell.

    One single-point cell is created per point, in point-id order, and the
    resulting cell array is installed with ``SetVerts`` (replacing any vertex
    cells the polydata already had). The polydata is modified in place and
    nothing is returned.

    Parameters:
    polydata (vtk.vtkPolyData): The polydata object to which vertex cells will be added.
    """
    num_points = polydata.GetNumberOfPoints()
    vertices = vtk.vtkCellArray()
    for point_id in range(num_points):
        # Declare a one-point cell and then supply its point id. This writes
        # the same connectivity as inserting a vtkVertex, without allocating
        # a temporary vtkVertex object for every point.
        vertices.InsertNextCell(1)
        vertices.InsertCellPoint(point_id)
    polydata.SetVerts(vertices)

In [None]:
# Build one .vtp point cloud per time step holding the U and V variables.
# Point coordinates are (lat, lon, depth index); each variable becomes a named
# point-data array.

# Define variables
variables = ['U', 'V']
# Number of time steps stored per variable in the pickles (the '<var> <d> <t>'
# keys use t in 0..427).
num_timesteps = 428
output_dir = './vtkp'
collections = {}

# Load all variable data once.
# NOTE: pickle.load can execute arbitrary code; only load pickles written by
# this notebook.
for var in variables:
    with open(f'{var}.pkl', 'rb') as file:
        collections[var] = pickle.load(file)

# Extract latitude and longitude from the 'U' collection
lat = collections['U']['U lat']
lon = collections['U']['U lon']

# Determine dimensions: each collection holds two coordinate entries plus one
# entry per (depth, time) pair.
num_depths = (len(collections['U']) - 2) // num_timesteps
num_points = len(lat) * num_depths
num_vars = len(variables)

# Initialize holder array: columns are lat, lon, depth index, then one column
# per variable.
holder = np.full((num_points, 3 + num_vars), np.nan)

# Populate static latitude, longitude, and depth information
for d in range(num_depths):
    start_idx = d * len(lat)
    end_idx = (d + 1) * len(lat)
    holder[start_idx:end_idx, 0] = lat
    holder[start_idx:end_idx, 1] = lon
    holder[start_idx:end_idx, 2] = d

# The XML writer does not create missing directories, so make sure the target
# directory exists before the first write.
os.makedirs(output_dir, exist_ok=True)

# Process each time step
for t in tqdm(range(0, num_timesteps), desc="Processing time steps", leave=True):
    # Populate variable data for the current time step
    for v, var in enumerate(variables):
        collection = collections[var]
        for d in range(num_depths):
            start_idx = d * len(lat)
            end_idx = (d + 1) * len(lat)
            holder[start_idx:end_idx, v + 3] = collection[f'{var} {d} {t}']

    # Create VTK points
    vtk_points = vtk.vtkPoints()
    vtk_points.SetData(numpy_to_vtk(holder[:, :3], deep=True))

    # Create VTK polydata
    polydata = vtk.vtkPolyData()
    polydata.SetPoints(vtk_points)

    # Give every point its own vertex cell, as the other export cell in this
    # notebook does, so the points form renderable geometry.
    add_vertex_cells(polydata)

    # Add variable data to polydata
    for v, var in enumerate(variables):
        var_data = holder[:, v + 3]
        vtk_array = numpy_to_vtk(var_data, deep=True)
        vtk_array.SetName(var)
        polydata.GetPointData().AddArray(vtk_array)

    # Write to VTK file
    output_path = f"{output_dir}/series2_{t}.vtp"
    writer = vtk.vtkXMLPolyDataWriter()
    writer.SetFileName(output_path)
    writer.SetInputData(polydata)
    # Write() reports failure through its return value rather than raising.
    if writer.Write() != 1:
        raise RuntimeError(f"Failed to write {output_path}")

In [39]:
# Build one .vtp point cloud per time step holding the P, W, FCLD and T
# variables. Point coordinates are (lat, lon, depth index); each variable
# becomes a named point-data array.
#
# NOTE(review): the loading cell's variable list has 'P' commented out, so
# 'P.pkl' has to exist from an earlier run — confirm before running this cell.
variables = ['P', 'W', 'FCLD', 'T']
# Number of time steps stored per variable in the pickles (the '<var> <d> <t>'
# keys use t in 0..427).
num_timesteps = 428
output_dir = './vtkp'
collections = {}

# NOTE: pickle.load can execute arbitrary code; only load pickles written by
# this notebook.
for var in variables:
    with open(f'{var}.pkl', 'rb') as file:
        collections[var] = pickle.load(file)

# Extract latitude and longitude from the 'P' collection
lat = collections['P']['P lat']
lon = collections['P']['P lon']

# Determine dimensions: each collection holds two coordinate entries plus one
# entry per (depth, time) pair.
num_depths = (len(collections['P']) - 2) // num_timesteps
num_points = len(lat) * num_depths
num_vars = len(variables)

# Initialize holder array: columns are lat, lon, depth index, then one column
# per variable.
holder = np.full((num_points, 3 + num_vars), np.nan)

# Populate static latitude, longitude, and depth information
for d in range(num_depths):
    start_idx = d * len(lat)
    end_idx = (d + 1) * len(lat)
    holder[start_idx:end_idx, 0] = lat
    holder[start_idx:end_idx, 1] = lon
    holder[start_idx:end_idx, 2] = d

# The XML writer does not create missing directories, so make sure the target
# directory exists before the first write.
os.makedirs(output_dir, exist_ok=True)

# Process each time step
for t in tqdm(range(num_timesteps), desc="Processing time steps", leave=True):
    # Populate variable data for the current time step
    for v, var in enumerate(variables):
        collection = collections[var]
        for d in range(num_depths):
            start_idx = d * len(lat)
            end_idx = (d + 1) * len(lat)
            holder[start_idx:end_idx, v + 3] = collection[f'{var} {d} {t}']

    # Create VTK points
    vtk_points = vtk.vtkPoints()
    vtk_points.SetData(numpy_to_vtk(holder[:, :3], deep=True))

    # Create VTK polydata
    polydata = vtk.vtkPolyData()
    polydata.SetPoints(vtk_points)

    add_vertex_cells(polydata)

    # Add variable data to polydata
    for v, var in enumerate(variables):
        var_data = holder[:, v + 3]
        vtk_array = numpy_to_vtk(var_data, deep=True)
        vtk_array.SetName(var)
        polydata.GetPointData().AddArray(vtk_array)

    # Write to VTK file
    output_path = f"{output_dir}/series1_{t}.vtp"
    writer = vtk.vtkXMLPolyDataWriter()
    writer.SetFileName(output_path)
    writer.SetInputData(polydata)
    # Write() reports failure through its return value rather than raising.
    if writer.Write() != 1:
        raise RuntimeError(f"Failed to write {output_path}")

Processing time steps:   0%|          | 0/428 [00:00<?, ?it/s]