In [1]:
import os
from pathlib import Path
import datetime as dt

import json
import xarray as xr
import numpy as np
import pandas as pd
from netCDF4 import Dataset as netcdf_dataset

import cmocean
import matplotlib as mpl

import matplotlib.pyplot as plt
import cartopy.crs as ccrs

from matplotlib.colors import BoundaryNorm
from matplotlib.ticker import MaxNLocator
from matplotlib.transforms import offset_copy

from math import floor
import imageio


# Logging is commented out for now to keep things simple.
#import logging

In [2]:
class Settings(object):
    """Singleton holding the settings loaded from a JSON file.

    The first ``Settings()`` call reads the file named by ``_fname`` and
    creates the directories listed under the ``data_dir`` and ``out_dir``
    keys. Every later call returns the same, already-loaded instance.
    """

    _instance = None  # the shared instance, set once loading has succeeded
    _fname = "test_settings.json"  # settings file read on first use
    _data = {}  # class-level default; set_config stores the loaded dict on the instance

    def set_config(self, _fname):
        """Load settings from the JSON file ``_fname`` and create its directories.

        Raises:
            OSError: if the file cannot be opened or a directory cannot be created.
            json.JSONDecodeError: if the file is not valid JSON.
            KeyError: if ``data_dir`` or ``out_dir`` is missing from the file.
        """
        with open(_fname) as json_file:
            self._data = json.load(json_file)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self._data['data_dir'], exist_ok=True)
        os.makedirs(self._data['out_dir'], exist_ok=True)

    def __new__(cls):
        """Return the shared instance, loading the configuration on first use."""
        if cls._instance is None:
            instance = super(Settings, cls).__new__(cls)
            # Publish the instance only after a successful load; otherwise a
            # failed first call would leave behind a singleton with no settings.
            instance.set_config(cls._fname)
            cls._instance = instance
        return cls._instance

    def get(self, atr):
        """Return the setting stored under ``atr``; raises KeyError if absent."""
        return self._data[atr]

    def set_attribute(self, atr, value):
        """Store ``value`` under ``atr`` in memory (the JSON file is not rewritten)."""
        self._data[atr] = value

In [3]:
# file_list is a list of NetCDF file paths
def get_min_max(wave_dict, file_list):
    """Return the overall (min, max) of the given wave variables across files.

    Both bounds start at 0, so the returned range always contains 0: the
    minimum is never above 0 and the maximum is never below 0.

    Args:
        wave_dict: mapping whose keys are the NetCDF variable names to scan.
        file_list: iterable of file paths. Each one is joined onto the
            'data_dir' setting with os.path.join, which leaves absolute
            paths unchanged.

    Returns:
        Tuple ``(wave_min, wave_max)``.
    """
    wave_min, wave_max = 0, 0
    for f_p in file_list:
        # The context manager closes the file even when a variable is missing
        # or a read fails part-way through.
        with netcdf_dataset(os.path.join(Settings().get('data_dir'), f_p)) as dataset:
            for wave in wave_dict:
                # NOTE(review): only the first time step (index 0) is scanned,
                # while draw_frames renders every time step with this range.
                # Confirm whether later time steps should be included.
                wave_data = dataset.variables[wave][0, :, :]
                data_min = wave_data.min()
                data_max = wave_data.max()
                if data_min < wave_min:
                    wave_min = data_min
                if data_max > wave_max:
                    wave_max = data_max
    return (wave_min, wave_max)


In [4]:
def make_ticks(w_min, w_max):
    """Build integer colour-bar tick positions and their text labels.

    Ticks run from ``floor(w_min)`` up to at most ``floor(w_max)``. The step
    is 1, or 2 when the range ``w_max - w_min`` exceeds 10. Spacing is always
    even, so no tick is duplicated whatever the value of ``w_min``.

    Args:
        w_min: lower end of the value range.
        w_max: upper end of the value range.

    Returns:
        Tuple ``(ticks, labels)``: an integer numpy array of tick positions
        and a list with one label string per tick.
    """
    step = 2 if (w_max - w_min) > 10 else 1
    # arange excludes its stop value, hence the +1 to allow floor(w_max) itself.
    ticks = np.arange(floor(w_min), floor(w_max) + 1, step, dtype=int)
    labels = []
    for t in ticks:
        if t == 0:
            labels.append('flat')
        elif t == 1:
            labels.append('1 meter')
        else:
            labels.append(str(t) + ' meters')
    return ticks, labels

In [5]:
# Convert a time index to a formatted datetime string, e.g. '2021.01.04 00h:00'
def get_time_string(i_time, time_values=None):
    """Format the time at index ``i_time`` as ``'%Y.%m.%d %Hh:%M'``.

    Args:
        i_time: index into ``time_values``.
        time_values: indexable whose items expose the time value as ``.data``.
            When omitted, the module-level name ``time_xr`` is used, which
            must then exist at call time (NameError otherwise).

    Returns:
        The formatted time string.
    """
    if time_values is None:
        # Backward-compatible fallback to the global the original relied on.
        time_values = time_xr
    return pd.to_datetime(str(time_values[i_time].data)).strftime('%Y.%m.%d %Hh:%M')

In [6]:
def draw_single_frame(lats, lons, wave_tn, wave, levels, norm, frame_n, time_stamp_n):
    """Render one filled-contour map frame and save it as a PNG in 'out_dir'.

    The map extent, marker position and text label all come from Settings.
    The projection is read from the module-level name ``proj``.

    Args:
        lats: latitude coordinates passed to ``contourf``.
        lons: longitude coordinates passed to ``contourf``.
        wave_tn: the data to contour for this frame.
        wave: dict with keys 'var' (file-name prefix), 'name' (title text),
            'min' and 'max' (range used for the colour-bar ticks).
        levels: contour levels passed to ``contourf``.
        norm: accepted for interface compatibility; not used by this function.
        frame_n: frame number, zero-padded to two digits in the file name.
        time_stamp_n: time string appended to the plot title.
    """
    cmap = cmocean.cm.tempo
    # NOTE(review): the coordinates are shifted by half of dx/dy, presumably
    # to move from cell corners to cell centres. Confirm the grid spacing.
    dx, dy = 0.05, 0.05
    stgs = Settings()

    fig = plt.figure(figsize=(12, 6))
    try:
        ax = fig.add_subplot(1, 1, 1, projection=proj)
        ax.set_extent([stgs.get('min_lon'), stgs.get('max_lon'),
                       stgs.get('min_lat'), stgs.get('max_lat')], crs=proj)
        gr = ax.gridlines(draw_labels=True)
        gr.right_labels = False
        ax.coastlines()

        # Offset the label 25 dots to the left of its geodetic anchor point.
        geodetic_transform = ccrs.Geodetic()._as_mpl_transform(ax)
        text_transform = offset_copy(geodetic_transform, units='dots', x=-25)

        ax.plot(stgs.get('point_lon'), stgs.get('point_lat'), marker='o', color='red',
                markersize=8, alpha=0.7, transform=ccrs.Geodetic())

        ax.text(stgs.get('text_lon'), stgs.get('text_lat'), stgs.get('text_label'),
                verticalalignment='bottom', horizontalalignment='left',
                transform=text_transform,
                bbox=dict(facecolor='white', alpha=0.5, boxstyle='round'))

        ax.set_title(wave['name'] + " " + time_stamp_n, fontsize=16)
        contour_img = ax.contourf(lons + dx/2., lats + dy/2., wave_tn,
                                  levels=levels, cmap=cmap)

        ticks, labels = make_ticks(wave['min'], wave['max'])
        cbar = fig.colorbar(contour_img, ax=ax, ticks=ticks)
        cbar.ax.set_yticklabels(labels)

        frame_fname = wave['var'] + "_" + (str(frame_n)).zfill(2) + ".png"
        frame_path = os.path.join(stgs.get('out_dir'), frame_fname)
        fig.savefig(frame_path)
    finally:
        # Close only this figure, and do so even when drawing or saving fails,
        # so other open figures are untouched and nothing leaks across frames.
        plt.close(fig)



In [7]:
def draw_frames(wave_dict, w_min, w_max, file_paths):
    """Draw one PNG frame per wave variable and time step for every file.

    Frame numbers continue across files: each file's frames start where the
    previous file's frames ended.

    Args:
        wave_dict: mapping of NetCDF variable name to the title used in plots.
        w_min: lower end of the value range shared by all frames.
        w_max: upper end of the value range shared by all frames.
        file_paths: iterable of NetCDF file paths, opened exactly as given.
            NOTE(review): get_min_max joins its paths onto 'data_dir' while
            this function does not. Confirm that relative paths are intended
            to resolve differently.
    """
    frame_index = 0
    cmap = cmocean.cm.tempo
    levels = MaxNLocator(nbins=20).tick_values(w_min, w_max)
    norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)

    for f_p in file_paths:
        # Both handles are context-managed so they are closed even if a
        # variable is missing or drawing a frame fails.
        with netcdf_dataset(f_p) as dataset, xr.open_dataset(f_p) as dataset_xr:
            lats = dataset.variables['latitude'][:]
            lons = dataset.variables['longitude'][:]
            n_times = dataset.variables['time'][:].size

            # The title timestamps do not depend on the wave variable, so
            # format them once per file instead of once per frame.
            time_xr = dataset_xr['time']
            time_strings = [
                pd.to_datetime(str(time_xr[i].data)).strftime('%Y.%m.%d %Hh:%M')
                for i in range(n_times)
            ]

            for wave_var, wave_name in wave_dict.items():
                wave = {"var": wave_var, "name": wave_name, "min": w_min, "max": w_max}
                for i in range(n_times):
                    wave_tn = dataset.variables[wave_var][i, :, :]
                    draw_single_frame(lats, lons, wave_tn, wave, levels, norm,
                                      frame_index + i, time_strings[i])

        frame_index += n_times

In [8]:
def create_animations(wave_dict):
    """Assemble the saved PNG frames of each wave variable into a GIF.

    Frames are read from the 'out_dir' setting. Only files named
    ``<wave_var>_<number>.png`` are used, and they are ordered by that
    number, so frame 100 follows frame 99 and frames of another variable
    sharing the same name prefix are not mixed in.

    Args:
        wave_dict: mapping whose keys are the wave variable names.

    Returns:
        List of the written GIF paths, one per key of ``wave_dict``.
    """
    anim_paths = []
    for wave_var in wave_dict:
        frames_dir = Path(Settings().get('out_dir'))
        prefix = wave_var + "_"

        numbered_frames = []
        for frame in frames_dir.glob(wave_var + "*" + ".png"):
            stem = frame.stem
            number = stem[len(prefix):]
            # Keep only "<wave_var>_<digits>.png"; the glob alone would also
            # match e.g. "<wave_var>0_00.png" from a different variable.
            if stem.startswith(prefix) and number.isdecimal():
                numbered_frames.append((int(number), frame.name, frame))
        # Numeric order first; the file name breaks ties deterministically.
        numbered_frames.sort(key=lambda item: (item[0], item[1]))

        frame_list = [imageio.imread(frame) for _, _, frame in numbered_frames]

        anim_fname = wave_var + '_anim.gif'
        # NOTE(review): the GIF is written to <cwd>/testdata rather than to a
        # configured directory. Confirm this is intended.
        anim_path = os.path.join(os.getcwd(), 'testdata', anim_fname)
        imageio.mimwrite(anim_path, frame_list, duration=0.5, loop=10)
        anim_paths.append(anim_path)
    return anim_paths

In [9]:
# Wave variables to animate: NetCDF variable name -> title shown on the plots.
wave_dict = {
    "VHM0_WW": "significant wind wave height",
    "VHM0_SW1": "significant primary swell wave height",
    "VHM0_SW2": "significant secondary swell wave height",
}

# Map projection; draw_single_frame reads this module-level name.
proj = ccrs.PlateCarree()

function_start_time = dt.datetime.now()

# A single hard-coded input file keeps this run simple.
files = [str(Path.cwd() / 'testdata' / 'testfile1.nc')]

# Compute the value range first so that every frame shares one colour scale,
# then render the frames and stitch them into one GIF per variable.
min_wave, max_wave = get_min_max(wave_dict, files)
draw_frames(wave_dict, min_wave, max_wave, files)
anim_paths = create_animations(wave_dict)

print(anim_paths)
print(min_wave, max_wave)


['/home/anna/annaCode/CleanSwellVisualisations/testdata/VHM0_WW_anim.gif', '/home/anna/annaCode/CleanSwellVisualisations/testdata/VHM0_SW1_anim.gif', '/home/anna/annaCode/CleanSwellVisualisations/testdata/VHM0_SW2_anim.gif']
0 3.8
