In [1]:
# Check that a .dat file of currents can still be read in.
# NOTE(review): believed to be the real measured MSS — TODO confirm.
import os
import pathlib

# Using this one for now, since it's the "real" one that Laura applied the denoising to.
# The path can be overridden with the CURRENTS_DAT_PATH environment variable so the
# notebook can be run on machines that don't have this exact directory layout.
filepath = pathlib.Path(
    os.environ.get(
        "CURRENTS_DAT_PATH",
        "/home/mh19137/geog_rdsf/data/projects/SING/richard_stuff/Table2/currents/dtu18_eigen-6c4_do0280_rr0004_cs.dat",
    )
)
# Fail early with a readable message rather than a bare AssertionError
assert filepath.exists(), f"Data file not found: {filepath}"

In [None]:
from current_denoising.generation import ioutils

data = ioutils.read_currents(filepath)

# Every cell below indexes data.shape[0] / data.shape[1] and draws data with imshow,
# so it has to be a 2D grid
assert data.ndim == 2, f"Expected a 2D grid of currents, got shape {data.shape}"

In [None]:
import numpy as np

# Treat exact zeros as "no data": overwrite them with NaN, in place, so they are
# drawn as blank pixels by the imshow plots that follow.
# NOTE(review): presumably zeros mark land / missing values in the .dat file — confirm.
data[np.equal(data, 0)] = np.nan

In [None]:
import matplotlib.pyplot as plt

from current_denoising.plotting import maps

# Shared imshow settings, reused by every map / tile plot below so colour scales match
imshow_kw = {"origin": "lower", "vmin": 0, "vmax": 1.4, "cmap": "turbo"}

# Pixel coordinates; the image spans from the first to the last lat/long value
lat, long = maps.lat_long_grid(data.shape)
extent = [long[0], long[-1], lat[0], lat[-1]]

fig, axis = plt.subplots(1, 1, figsize=(12, 6))

# Give imshow the extent directly rather than patching it in afterwards
im = axis.imshow(data, extent=extent, **imshow_kw)

axis.set_title(filepath.stem)
axis.set_xlabel("Longitude")
axis.set_ylabel("Latitude")
fig.colorbar(im)

fig.tight_layout()

In [None]:
from current_denoising.generation import ioutils

tile_size = 32

# Fixed seed so the example patches are reproducible. This generator is reused
# (not re-seeded) by the filtered extraction further down the notebook.
rng = np.random.default_rng(1234)

# 16 random tiles with no RMS or latitude filtering (both limits infinite),
# together with the (row, column) index of each tile's first pixel
tiles, indices = ioutils.extract_tiles(
    rng,
    data,
    tile_size=tile_size,
    num_tiles=16,
    max_latitude=np.inf,
    max_rms=np.inf,
    return_indices=True,
)

In [None]:
def plot_tiles(tiles, indices, *, nrows: int = 4, ncols: int = 4) -> plt.Figure:
    """
    Plot tiles on a grid of subplots, each positioned at its lat/long extent.

    Uses the notebook globals ``imshow_kw``, ``tile_size``, ``lat`` and ``long``.
    Only the first ``nrows * ncols`` tiles are drawn; extra tiles are ignored and
    any spare axes are left empty.

    :param tiles: the 2D tiles to plot
    :param indices: (row, column) index of each tile's first pixel in the full grid
    :param nrows: number of rows of subplots
    :param ncols: number of columns of subplots
    :returns: the figure containing the plots
    """
    # squeeze=False so `axes` is always a 2D array, even for a 1x1 grid
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False
    )

    for axis, tile, (y, x) in zip(axes.flat, tiles, indices):
        im = axis.imshow(tile, **imshow_kw)
        # A tile covers pixels x .. x + tile_size - 1. Indexing long/lat at
        # x + tile_size points one pixel past the tile, and runs off the end of the
        # coordinate arrays for tiles touching the right/top edge of the grid.
        # This matches the full-map extent convention long[0] .. long[-1].
        # NOTE(review): assumes lat/long hold one coordinate per pixel — confirm.
        im.set_extent(
            [long[x], long[x + tile_size - 1], lat[y], lat[y + tile_size - 1]]
        )

    return fig


fig = plot_tiles(tiles, indices)
# Trailing semicolon stops the notebook echoing the Text object returned by suptitle
fig.suptitle("Example patches");

In [None]:
"""Plot the RMS of every possible patch in the image"""

# tqdm.auto uses the notebook widget when available and falls back to a text bar otherwise
from tqdm.auto import tqdm

# Start positions of every candidate tile. Rows come from ioutils' own helper (no
# latitude limit here); columns cover every position where a whole tile fits.
lat_indices = np.arange(*ioutils._included_indices(data.shape[0], tile_size, np.inf))
long_indices = np.arange(0, data.shape[1] - tile_size + 1)

# Each tile's RMS is stored at the tile's centre pixel; NaN where no tile is centred
tile_rms = np.full(data.shape, np.nan)
half_tile = tile_size // 2

# Context manager so the progress bar is always closed, even if the loop fails
with tqdm(total=len(lat_indices) * len(long_indices)) as pbar:
    for i in lat_indices:
        for j in long_indices:
            tile = ioutils._tile(data, (i, j), tile_size)
            tile_rms[i + half_tile, j + half_tile] = ioutils._tile_rms(tile)
            pbar.update(1)

In [None]:
fig, axes = plt.subplots(2, 1, figsize=(12, 12))

# Top: map of tile RMS, drawn on the same colour scale as the currents map
im = axes[0].imshow(tile_rms, **imshow_kw)
im.set_extent([long[0], long[-1], lat[0], lat[-1]])

# Bottom: histogram of tile RMS. tile_rms is NaN wherever no tile is centred,
# so drop those values explicitly rather than handing NaNs to hist
valid_rms = tile_rms[~np.isnan(tile_rms)]
_, bins, patches = axes[1].hist(valid_rms, bins=100)

# RMS cut-off used to reject tiles further down. NB: this is a quantile in [0, 1];
# it is only displayed as a percentile in the title
percentile = 0.50
threshold = np.nanquantile(tile_rms.ravel(), percentile)
axes[1].axvline(threshold, color="red")

# Colour each bar with the colour its bin centre has in the RMS map
bin_centres = 0.5 * (bins[1:] + bins[:-1])
for c, p in zip(bin_centres, patches):
    p.set_facecolor(im.cmap(im.norm(c)))
    p.set_edgecolor("none")

axes[0].set_title(f"Tile RMS, {tile_size=}")

# The mathtext piece is a raw string: "\m" in a plain string literal is an invalid
# escape sequence (DeprecationWarning, SyntaxWarning from Python 3.12)
axes[1].set_title(
    f"{percentile*100:.0f}"
    r"$^{\mathrm{th}}$ percentile indicated"
    f" ({threshold:.2f}m/s)"
)
axes[1].set_xlabel("Tile RMS")
axes[1].set_ylabel("Count")

# Trailing semicolon: don't echo the Colorbar object in the notebook output
fig.colorbar(im, ax=axes[0]);

In [None]:
# Tiles, rejecting the ones with high latitude or high RMS:
# keep tiles whose RMS is below the threshold chosen above and whose latitude is
# within the cut-off (presumably |latitude| <= latitude_threshold — see
# ioutils.extract_tiles for the exact rule)
latitude_threshold = 64.0

# NOTE: rng is the generator seeded earlier and already advanced by the first
# extraction, so these tiles depend on that cell having been run exactly once
tiles, indices = ioutils.extract_tiles(
    rng,
    data,
    num_tiles=1024,
    max_rms=threshold,
    # Use the named threshold rather than repeating the literal, so the latitude
    # lines drawn on the map below always match the filter actually applied
    max_latitude=latitude_threshold,
    tile_size=tile_size,
    return_indices=True,
)

In [None]:
# Plot them on a world map, as a treat
from matplotlib import colors


def clear2black_cmap() -> colors.Colormap:
    """
    Build a two-colour colormap going from fully transparent to opaque black.

    Low values are drawn transparent (white with alpha 0) and high values
    opaque black, so it can be layered underneath other images.

    :returns: a ListedColormap named "clear2black"
    """
    transparent = colors.to_rgba("white", alpha=0)
    opaque_black = colors.to_rgba("black", alpha=1)
    return colors.ListedColormap([transparent, opaque_black], "clear2black")


# Paint the extracted tiles onto an all-NaN grid the size of the data; where tiles
# overlap, later tiles overwrite earlier ones
tile_grid = np.full(data.shape, np.nan)

for tile, (y, x) in zip(tiles, indices):
    tile_grid[y : y + tile_size, x : x + tile_size] = tile

fig, axis = plt.subplots(1, 1, figsize=(12, 6))

# Background layer: black wherever the input has data (non-NaN), transparent elsewhere
axis.imshow(
    ~np.isnan(data),
    origin="lower",
    cmap=clear2black_cmap(),
    extent=(long[0], long[-1], lat[0], lat[-1]),
)
# Extracted tiles on top, on the same colour scale as the earlier plots
im = axis.imshow(tile_grid, **imshow_kw)

im.set_extent([long[0], long[-1], lat[0], lat[-1]])

# Mark the latitude cut-off used when extracting the tiles
for t in (latitude_threshold, -latitude_threshold):
    axis.axhline(t, color="r", linestyle="--")
# x = -198 is a hand-picked label position.
# NOTE(review): presumably just west of the map edge, assuming longitudes span ~-180..180.
axis.text(-198, -latitude_threshold, f"-{latitude_threshold}" + r"$\degree$", color="r")

axis.set_xlabel("Longitude")
axis.set_ylabel("Latitude")

fig.colorbar(im, ax=axis)
fig.suptitle(f"Extracted patches; patch RMS {threshold=:.2f}m/s")
fig.tight_layout()