# Analyze customer's data

In [None]:
import numpy as np

data_path = "data/Z_RADA_C_BABJ_20231213034816_P_ACHN.QREF.20231213.034200.bin.txt"
data = np.loadtxt(data_path)

data.shape

In [None]:
data[0][100]

In [None]:
np.unique(data)

In [None]:
import torch

a =  torch.rand((2, 8 + 4, 1, 256, 256))
a0 = a[0, :4, :, :]
a1 = a[0, 4:, :, :]
a0.shape

In [None]:
import bz2
import struct
import zipfile
import numpy as np
import os

def read_bin(path):
    """Read one bz2-compressed radar product file into a 2-D grid.

    The file is treated as a 256-byte header followed by a bz2-compressed
    payload of native 16-bit signed integers stored row by row.  Two native
    ``int`` header fields are used: the column count at byte offset 148 and
    the row count at byte offset 152.  (Bytes 124-140 hold four further
    ints that earlier code read as south/west/north/east edges scaled by
    1/1000; they were never used and are no longer decoded.)

    Args:
        path: Path of the ``.bin`` file to read.

    Returns:
        A list of ``nY`` rows, each a list of ``nX`` floats.  Every raw
        value is divided by 10.0 and negative raw values are clamped to 0
        first.

    Raises:
        struct.error: If the header or the decompressed payload is shorter
            than the declared grid requires.
        OSError: If the file cannot be opened or the payload is not a valid
            bz2 stream.
    """
    # Read-only access is enough; "rb+" needlessly required write permission.
    with open(path, 'rb') as f:
        bt = f.read()

    nX, nY = struct.unpack_from('2i', bt, 148)

    s = bz2.decompress(bt[256:])

    # Decode the whole grid with a single call instead of one struct.unpack
    # per pixel.  Non-positive dimensions produce an empty grid, as before.
    count = max(nX, 0) * max(nY, 0)
    values = struct.unpack_from(f'{count}h', s)

    return [
        [max(raw, 0) / 10.0 for raw in values[row * nX:(row + 1) * nX]]
        for row in range(nY)
    ]
    
def read_zuimei_data(data_folder, num_frames=24, minute_step=6):
    """Load a time-ordered stack of radar frames from a folder of ``.bin`` files.

    Only files ending in ``.bin`` are considered.  The fifth
    underscore-separated field of each name is parsed as an integer sort
    key; its characters ``[-4:-2]`` are parsed as a number (presumably the
    minutes of a YYYYMMDDhhmmss timestamp -- TODO confirm against the
    naming convention) and the file is kept only when that number is a
    multiple of ``minute_step``.

    Args:
        data_folder: Directory containing the ``.bin`` files.
        num_frames: Maximum number of frames to load, taken from the start
            of the sorted list.  ``None`` loads every matching file.
            Defaults to 24, the previously hard-coded value.
        minute_step: Keep only files whose parsed two-digit field is
            divisible by this value.  Defaults to 6, as before.

    Returns:
        ``np.ndarray`` built from the frames returned by ``read_bin``;
        shape ``(0,)`` when no file matches.

    Raises:
        IndexError / ValueError: If a ``.bin`` filename does not have a
            numeric fifth underscore-separated field.
    """
    filenames = [
        filename
        for filename in os.listdir(data_folder)
        if filename.endswith('.bin')
        and int(filename.split('_')[4][-4:-2]) % minute_step == 0
    ]
    filenames.sort(key=lambda x: int(x.split('_')[4]))

    data = [
        read_bin(os.path.join(data_folder, filename))
        for filename in filenames[:num_frames]
    ]
    return np.array(data)

In [None]:
data_folder = 'data/zuimei_precipitation'

frames = read_zuimei_data(data_folder)
print(frames.shape)

In [None]:
from tqdm import tqdm

image = frames[0]
crop_size = 256

def compute_integral_image(image):
    """Compute the integral image (summed-area table) of a 2-D array.

    Entry ``[i, j]`` of the result is the sum of ``image[:i + 1, :j + 1]``.
    Pixels are converted to ``np.uint64`` before summation, exactly as the
    previous element-wise implementation did, so fractional values are
    truncated and the table is only meaningful for non-negative input.

    The table is built with two cumulative sums instead of a Python-level
    double loop, which gives the same values in a fraction of the time.

    Args:
        image: 2-D ``np.ndarray``.

    Returns:
        ``np.uint64`` array with the same shape as ``image``.

    Raises:
        ValueError: If ``image`` is not two-dimensional.
    """
    if image.ndim != 2:
        raise ValueError(
            f"expected a 2-D image, got an array with {image.ndim} dimension(s)")

    as_uint = image.astype(np.uint64)
    column_sums = np.cumsum(as_uint, axis=0, dtype=np.uint64)
    return np.cumsum(column_sums, axis=1, dtype=np.uint64)

def get_window_sum(integral_image, top_left, bottom_right):
    """Return the sum of the pixels inside a window, using an integral image.

    Args:
        integral_image: 2-D summed-area table, where entry ``[i, j]`` is the
            sum of all pixels at rows ``<= i`` and columns ``<= j``.
        top_left: ``(top, left)`` corner of the window, inclusive.
        bottom_right: ``(bottom, right)`` corner of the window, inclusive.

    Returns:
        The window sum, as a scalar of the integral image's dtype.
    """
    top, left = top_left
    bottom, right = bottom_right

    # Add the doubly-subtracted corner back *before* subtracting the two
    # strips.  For a table of non-negative pixels this keeps every
    # intermediate value non-negative, so unsigned tables no longer wrap
    # around (and emit overflow warnings) midway through the computation.
    window_sum = integral_image[bottom, right]
    if top > 0 and left > 0:
        window_sum = window_sum + integral_image[top - 1, left - 1]
    if top > 0:
        window_sum = window_sum - integral_image[top - 1, right]
    if left > 0:
        window_sum = window_sum - integral_image[bottom, left - 1]

    return window_sum

def crop_max_non_zero_optimized(image, crop_size=256):
    """Crop the ``crop_size`` x ``crop_size`` window with the largest pixel sum.

    Note that, despite the name, the score of a window is the sum of its
    entries in the integral image returned by ``compute_integral_image``,
    not a count of non-zero pixels.  Ties are resolved in favour of the
    first window in row-major order.  When the image is smaller than the
    window in either dimension no window is scored and the crop starts at
    ``(0, 0)`` (and is therefore smaller than ``crop_size``).

    All window sums are evaluated at once with array slicing rather than
    one Python-level call per window position.

    Args:
        image: 2-D ``np.ndarray``.
        crop_size: Side length of the square window, in pixels.

    Returns:
        ``(cropped_image, max_pos)`` where ``max_pos`` is the ``(row, col)``
        of the window's top-left corner as Python ints.

    Raises:
        ValueError: If ``crop_size`` is not positive.
    """
    if crop_size <= 0:
        raise ValueError(f"crop_size must be positive, got {crop_size}")

    integral_image = compute_integral_image(image)
    rows, cols = image.shape

    # Number of valid window positions along each axis.
    n_rows = rows - crop_size + 1
    n_cols = cols - crop_size + 1

    max_pos = (0, 0)
    if n_rows > 0 and n_cols > 0:
        # A leading row/column of zeros removes the top == 0 / left == 0
        # special cases from the four-corner formula.
        padded = np.zeros((rows + 1, cols + 1), dtype=integral_image.dtype)
        padded[1:, 1:] = integral_image
        del integral_image  # release the unpadded copy before allocating sums

        near_r, far_r = slice(0, n_rows), slice(crop_size, crop_size + n_rows)
        near_c, far_c = slice(0, n_cols), slice(crop_size, crop_size + n_cols)

        # Additions first so unsigned intermediates never dip below zero.
        window_sums = padded[far_r, far_c] + padded[near_r, near_c]
        window_sums -= padded[near_r, far_c]
        window_sums -= padded[far_r, near_c]

        # argmax returns the first maximum in row-major order, matching the
        # strict ">" comparison of a row-by-row scan.
        best = np.unravel_index(np.argmax(window_sums), window_sums.shape)
        max_pos = (int(best[0]), int(best[1]))

    cropped_image = image[max_pos[0]:max_pos[0] + crop_size,
                          max_pos[1]:max_pos[1] + crop_size]

    return cropped_image, max_pos

# Test with the sample image
cropped_image, max_pos = crop_max_non_zero_optimized(image)
cropped_image.shape, np.count_nonzero(cropped_image)  # Display the shape of the cropped image and the number of non-zero values

In [None]:
cropped_frames = frames[:, max_pos[0]:max_pos[0]+crop_size, max_pos[1]:max_pos[1]+crop_size]
cropped_frames.shape

In [None]:
import matplotlib.pyplot as plt

# Create a figure and axis
fig, ax = plt.subplots(figsize=(8, 6))

# Plot the data as an image
im = ax.imshow(cropped_frames[6], vmin=0, vmax=60, cmap='jet')  # viridis 

# Add a colorbar beside the figure
cbar = fig.colorbar(im, ax=ax, shrink=1)

# Add a label to the colorbar
cbar.set_label('Precipitation (mm/h)', rotation=270, labelpad=20)
# plt.imshow(cropped_frames[0, ...], vmin=0, vmax=10, cmap='jet')
plt.show()

In [None]:
path = 'data/zuimei_precipitation/Z_RADA_C_BABJ_20240429012413_P_DOR_ACHN_OHP06_20240429_011800.bin'

precipitation = read_bin(path)
precipitation = np.array(precipitation)
precipitation.shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Generate a sample NumPy array
np.random.seed(42)
# data = np.random.randn(100, 200)

# Flatten the array to 1D
data_flat = precipitation.flatten()

filtered_data = data_flat[data_flat >= 0]

# Plot the histogram
plt.hist(filtered_data, bins=50, density=True, alpha=0.6)
plt.title('Data Distribution')
plt.xlabel('Value')
plt.ylabel('Density')
plt.show()

# Test loading MRMS dataset

In [None]:
# !pip install -U "huggingface_hub[cli]"

In [None]:
# !git config --global credential.helper store

In [None]:
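# NOTE: never commit access tokens. The token previously pasted here should be revoked; export HF_TOKEN in the environment instead.
# !huggingface-cli login --token "$HF_TOKEN" --add-to-git-credential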

In [None]:
# !pip install xarray
# !pip install -U tensorflow~=2.10.0
# !pip install numpy==1.23.4

In [None]:
from datasets import load_dataset

dataset = load_dataset("openclimatefix/mrms", "default", split='test', use_auth_token=True) # streaming=True, 

In [None]:
# !pip install -q pyarrow

In [None]:
import pandas as pd
import pyarrow.parquet as pq

test_path = 'data/dgmr-mrms/data/train-00000-of-58896-a1855e9018799bba.parquet'
parquet_file = pq.ParquetFile(test_path)

# Read the data from the file
data = parquet_file.read()

# Convert the data to a Pandas DataFrame
import pandas as pd
df = data.to_pandas()

# Print the first few rows of the DataFrame
print(df.head())

In [None]:
df.loc[0, 'precipitation_rate'][0].shape

In [None]:
df.loc[0, 'timestamps'].shape

In [None]:
df.loc[0, 'latitude'].shape

In [None]:
df.loc[0, 'longitude'].shape

# Test loading NIMROD UK data

In [None]:
import os
import tensorflow as tf

print(tf.__version__)

_FEATURES = {name: tf.io.FixedLenFeature([], dtype)
             for name, dtype in [
               ("radar", tf.string), ("sample_prob", tf.float32),
               ("osgb_extent_top", tf.int64), ("osgb_extent_left", tf.int64),
               ("osgb_extent_bottom", tf.int64), ("osgb_extent_right", tf.int64),
               ("end_time_timestamp", tf.int64),
             ]}

_SHAPE_BY_SPLIT_VARIANT = {
    ("train", "random_crops_256"): (24, 256, 256, 1),
    ("valid", "subsampled_tiles_256_20min_stride"): (24, 256, 256, 1),
    ("test", "full_frame_20min_stride"): (24, 1536, 1280, 1),
    ("test", "subsampled_overlapping_padded_tiles_512_20min_stride"): (24, 512, 512, 1),
}

_MM_PER_HOUR_INCREMENT = 1/32.
_MAX_MM_PER_HOUR = 128.
_INT16_MASK_VALUE = -1


def parse_and_preprocess_row(row, split, variant):
    """Parse one serialized example and decode its radar frames.

    ``row`` is parsed with the module-level ``_FEATURES`` spec.  The raw
    ``"radar"`` bytes are removed from the parsed dict, decoded as int16 and
    reshaped to the shape registered for ``(split, variant)`` in
    ``_SHAPE_BY_SPLIT_VARIANT``.

    Args:
        row: Serialized example(s) accepted by ``tf.io.parse_example``.
        split: Dataset split name used for the shape lookup.
        variant: Dataset variant name used for the shape lookup.

    Returns:
        The parsed feature dict, without ``"radar"``, plus:
          ``"radar_frames"``: float32 tensor, the int16 values multiplied by
            ``_MM_PER_HOUR_INCREMENT`` and clipped to the range
            ``[_INT16_MASK_VALUE * _MM_PER_HOUR_INCREMENT, _MAX_MM_PER_HOUR]``.
          ``"radar_mask"``: bool tensor, True where the raw int16 value
            differs from ``_INT16_MASK_VALUE``.

    Raises:
        KeyError: If ``(split, variant)`` has no registered shape.
    """
    parsed = tf.io.parse_example(row, _FEATURES)
    frames_shape = _SHAPE_BY_SPLIT_VARIANT[(split, variant)]

    raw_frames = tf.reshape(
        tf.io.decode_raw(parsed.pop("radar"), tf.int16), frames_shape)

    lower_bound = _INT16_MASK_VALUE * _MM_PER_HOUR_INCREMENT
    scaled = tf.cast(raw_frames, tf.float32) * _MM_PER_HOUR_INCREMENT

    parsed["radar_frames"] = tf.clip_by_value(
        scaled, lower_bound, _MAX_MM_PER_HOUR)
    parsed["radar_mask"] = tf.not_equal(raw_frames, _INT16_MASK_VALUE)
    return parsed


def reader(root_dir, split="train", variant="random_crops_256",
           shuffle_files=False, max_shards=2):
    """Reader for open-source nowcasting datasets.

    Args:
    root_dir: Directory that contains the ``<split>/<variant>/*.tfrecord.gz``
      shard files.
    split: Which yearly split of the dataset to use:
      "train": Data from 2016 - 2018, excluding the first day of each month.
      "valid": Data from 2016 - 2018, only the first day of the month.
      "test": Data from 2019.
    variant: Which variant to use. The available variants depend on the split:
      "random_crops_256": Available for the training split. 24x256x256 pixel
        crops, sampled with a bias towards crops containing rainfall. Crops at
        all spatial and temporal offsets were able to be sampled, some crops may
        overlap.
      "subsampled_tiles_256_20min_stride": Available for the validation set.
        Non-spatially-overlapping 24x256x256 pixel crops, subsampled from a
        regular spatial grid with stride 256x256 pixels, and a temporal stride
        of 20mins (4 timesteps at 5 minute resolution). Sampling favours crops
        containing rainfall.
      "subsampled_overlapping_padded_tiles_512_20min_stride": Available for the
        test set. Overlapping 24x512x512 pixel crops, subsampled from a
        regular spatial grid with stride 64x64 pixels, and a temporal stride
        of 20mins (4 timesteps at 5 minute resolution). Subsampling favours
        crops containing rainfall.
        These crops include extra spatial context for a fairer evaluation of
        the PySTEPS baseline, which benefits from this extra context. Our other
        models only use the central 256x256 pixels of these crops.
      "full_frame_20min_stride": Available for the test set. Includes full
        frames at 24x1536x1280 pixels, every 20 minutes with no additional
        subsampling.
    shuffle_files: Whether to shuffle the shard files of the dataset
      non-deterministically before interleaving them. Recommended for the
      training set to improve mixing and read performance (since
      non-deterministic parallel interleave is then enabled).
    max_shards: Maximum number of shard files to read, taken from the start
      of the glob result. Defaults to 2, which preserves the truncation this
      function has always applied (useful for quick experiments); pass None
      to read every shard.

    Returns:
    A tf.data.Dataset whose rows are dicts with the following keys:

    "radar_frames": Shape TxHxWx1, float32. Radar-based estimates of
      ground-level precipitation, in units of mm/hr. Pixels which are masked
      will take on a value of -1/32 and should be excluded from use as
      evaluation targets. The coordinate reference system used is OSGB36, with
      a spatial resolution of 1000 OSGB36 coordinate units (approximately equal
      to 1km). The temporal resolution is 5 minutes.
    "radar_mask": Shape TxHxWx1, bool. A binary mask which is False
      for pixels that are unobserved / unable to be inferred from radar
      measurements (e.g. due to being too far from a radar site). This mask
      is usually static over time, but occasionally a whole radar site will
      drop in or out resulting in large changes to the mask, and more localised
      changes can happen too.
    "sample_prob": Scalar float. The probability with which the row was
      sampled from the overall pool available for sampling, as described above
      under 'variants'. We use importance weights proportional to 1/sample_prob
      when computing metrics on the validation and test set, to reduce bias due
      to the subsampling.
    "end_time_timestamp": Scalar int64. A timestamp for the final frame in
      the example, in seconds since the UNIX epoch (1970-01-01 00:00:00 UTC).
    "osgb_extent_left", "osgb_extent_right", "osgb_extent_top",
    "osgb_extent_bottom":
      Scalar int64s. Spatial extent for the crop in the OSGB36 coordinate
      reference system.

    Raises:
    FileNotFoundError: If no shard file matches the expected glob pattern.
    """
    shards_glob = os.path.join(root_dir, split, variant, "*.tfrecord.gz")
    shard_paths = tf.io.gfile.glob(shards_glob)
    if max_shards is not None:
        shard_paths = shard_paths[:max_shards]
    if not shard_paths:
        # An empty list would otherwise surface later as an obscure dtype
        # error when building the interleaved dataset.
        raise FileNotFoundError(f"No shard files match {shards_glob!r}")

    shards_dataset = tf.data.Dataset.from_tensor_slices(shard_paths)
    if shuffle_files:
        shards_dataset = shards_dataset.shuffle(buffer_size=len(shard_paths))
    return (
      shards_dataset
      .interleave(lambda x: tf.data.TFRecordDataset(x, compression_type="GZIP"),
                  num_parallel_calls=tf.data.AUTOTUNE,
                  deterministic=not shuffle_files)
      .map(lambda row: parse_and_preprocess_row(row, split, variant),
           num_parallel_calls=tf.data.AUTOTUNE)
      # Do your own subsequent repeat, shuffle, batch, prefetch etc as required.
    )

In [None]:
from datasets import load_dataset

# dataset = load_dataset("openclimatefix/nimrod-uk-1km", "sample", use_auth_token=True) # streaming=True, 

root_dir = 'data/nimrod-uk-1km/20200718'

# dataset = reader(root_dir, split='train', variant='random_crops_256')  # 'full_frame_20min_stride'

data_path = "/home/ec2-user/SageMaker/efs/Projects/skillful_nowcasting/data/nimrod-uk-1km"
split = "train"
dataset = load_dataset(data_path, split=split, streaming=True)

In [None]:
type(dataset)

In [None]:
example = next(iter(dataset))
example

In [None]:
NUM_INPUT_FRAMES = 4
NUM_TARGET_FRAMES = 18

def extract_input_and_target_frames(radar_frames):
    """Split a frame sequence into conditioning inputs and prediction targets.

    Targets are the last ``NUM_TARGET_FRAMES`` frames of the sequence and
    inputs are the ``NUM_INPUT_FRAMES`` frames immediately before them.

    Args:
        radar_frames: Sequence sliceable along its first (time) axis.

    Returns:
        ``(input_frames, target_frames)`` slices of ``radar_frames``.
    """
    # We align our targets to the end of the window, and inputs precede targets.
    targets_start = -NUM_TARGET_FRAMES
    input_frames = radar_frames[targets_start - NUM_INPUT_FRAMES:targets_start]
    target_frames = radar_frames[targets_start:]
    return input_frames, target_frames

def collate_fn(data):
    """Turn dataset example(s) into channels-first input/target frame arrays.

    Each example's ``'radar_frames'`` entry is split with
    ``extract_input_and_target_frames`` and its axes are reordered from
    ``(T, H, W, C)`` to ``(T, C, H, W)``.

    Args:
        data: Either a single example (a mapping with a ``'radar_frames'``
            entry) or an iterable of such examples.  A
            ``torch.utils.data.DataLoader`` passes a list of examples to its
            ``collate_fn`` even when ``batch_size=1``, which the previous
            version could not index.

    Returns:
        ``(inputs, targets)``.  For a single mapping these are the
        ``(T, C, H, W)`` arrays of that example; for an iterable of
        examples they are stacked to ``(B, T, C, H, W)``.

    Raises:
        ValueError: If an empty iterable of examples is given.
    """
    from collections.abc import Mapping

    single_example = isinstance(data, Mapping)
    examples = [data] if single_example else list(data)

    inputs, targets = [], []
    for example in examples:
        input_frames, target_frames = extract_input_and_target_frames(
            example['radar_frames'])
        inputs.append(np.moveaxis(input_frames, [0, 1, 2, 3], [0, 2, 3, 1]))
        targets.append(np.moveaxis(target_frames, [0, 1, 2, 3], [0, 2, 3, 1]))

    if single_example:
        return inputs[0], targets[0]
    return np.stack(inputs), np.stack(targets)

In [None]:
import torch

train_dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=False,
    collate_fn=collate_fn,
    batch_size=1,
    num_workers=4,
)

In [None]:
batch = iter(train_dataloader)
print(batch)

In [None]:
help(dataset)

In [None]:
dataset.dataset_size

In [None]:
# for i, row in enumerate(dataset):
#     print(i)

In [None]:
row = next(iter(dataset))

In [None]:
{k: (v.dtype, v.shape) for k, v in row.items()}

In [None]:
{k: v.numpy() for k, v in row.items() if v.shape.ndims == 0}

In [None]:
row['radar_frames'].shape

In [None]:
NUM_INPUT_FRAMES = 4
# The sequences used in this notebook hold 24 frames, so 4 inputs leave room
# for 20 targets.  The previous value of 24 asked for 28 frames in total and
# made the input slice silently empty.
NUM_TARGET_FRAMES = 20

def extract_input_and_target_frames(radar_frames):
    """Split a frame tensor into batched, channels-first inputs and targets.

    Targets are the last ``NUM_TARGET_FRAMES`` frames and inputs are the
    ``NUM_INPUT_FRAMES`` frames immediately before them.  Each part gets a
    leading batch axis of size 1 and is permuted from ``(1, T, H, W, C)`` to
    ``(1, T, C, H, W)``.

    Args:
        radar_frames: ``torch.Tensor`` of shape ``(T, H, W, C)``.

    Returns:
        ``(input_frames, target_frames)`` with shapes
        ``(1, NUM_INPUT_FRAMES, C, H, W)`` and
        ``(1, NUM_TARGET_FRAMES, C, H, W)``.

    Raises:
        ValueError: If ``radar_frames`` has fewer than
            ``NUM_INPUT_FRAMES + NUM_TARGET_FRAMES`` frames; negative slicing
            would otherwise return truncated or empty tensors without
            complaint.
    """
    required = NUM_INPUT_FRAMES + NUM_TARGET_FRAMES
    available = radar_frames.shape[0]
    if available < required:
        raise ValueError(
            f"need at least {required} frames "
            f"({NUM_INPUT_FRAMES} input + {NUM_TARGET_FRAMES} target), "
            f"got {available}")

    def _batched_channels_first(frames):
        # (T, H, W, C) -> (1, T, H, W, C) -> (1, T, C, H, W)
        return frames.unsqueeze(0).permute(0, 1, 4, 2, 3)

    # We align our targets to the end of the window, and inputs precede targets.
    input_frames = _batched_channels_first(
        radar_frames[-required:-NUM_TARGET_FRAMES])
    target_frames = _batched_channels_first(
        radar_frames[-NUM_TARGET_FRAMES:])

    return input_frames, target_frames

In [None]:
import torch

# radar_frames = torch.from_numpy(row['radar_frames'].numpy())
radar_frames = torch.from_numpy(cropped_frames[...,np.newaxis])
print(radar_frames.dtype)
radar_frames = radar_frames.to(torch.float32)
print(radar_frames.dtype)
type(radar_frames)

In [None]:
radar_frames.shape

In [None]:
input_frames, target_frames = extract_input_and_target_frames(radar_frames)
print(input_frames.shape)
print(target_frames.shape)

In [None]:
input_frames.dtype

# Loading radar data

In [1]:
import random
import torch
import os
import numpy as np

num_input = 4
num_target = 20

def revert_back_numpy_array(byte_array, size=(24, 256, 256), dtype=np.float32, source_dtype=np.float32):
    """Rebuild a NumPy array from its flattened raw-byte representation.

    Args:
        byte_array: Bytes-like object (or anything ``bytearray`` accepts)
            holding the flattened array data.
        size: Shape to restore; an int or a tuple of ints.
        dtype: dtype of the returned array.
        source_dtype: dtype used to interpret the raw bytes.

    Returns:
        A new array of shape ``size`` and type ``dtype``.

    Raises:
        ValueError: If the byte length does not fit ``source_dtype`` or the
            element count does not match ``size``.
    """
    # Copy into a mutable buffer, then view it as a flat array of source_dtype.
    raw_buffer = bytearray(byte_array)
    flat_values = np.frombuffer(raw_buffer, dtype=source_dtype)

    # Restore the requested shape and convert to the output dtype.
    return flat_values.reshape(size).astype(dtype)

def collate_fn(examples):
    """Collate webdataset examples into input/target frame tensors.

    For every example one of its two stored crops is chosen with equal
    probability -- ``"cropped_frames_max_nonzero"`` or
    ``"cropped_frames_random"`` -- and decoded as a ``(24, 256, 256)``
    float32 array.  The first ``num_input`` frames become the inputs and the
    following ``num_target`` frames the targets.

    Only the chosen crop is decoded.  Earlier code also decoded the other
    crop and the ``"max_pos"`` / ``"random_pos"`` fields for every example
    and then discarded them; the position decode additionally cast to uint8,
    which cannot represent offsets above 255.

    Args:
        examples: List of dicts whose crop entries hold raw float32 bytes.

    Returns:
        ``(inputs_tensor, targets_tensor)`` of shapes
        ``(B, num_input, 1, 256, 256)`` and ``(B, num_target, 1, 256, 256)``.

    Raises:
        ValueError: If ``examples`` is empty or a crop has the wrong size.
        KeyError: If the chosen crop key is missing from an example.
    """
    inputs, targets = [], []
    for example in examples:
        # One draw per example, as before, so seeded runs pick the same crops.
        if random.random() < 0.5:
            crop_key = "cropped_frames_max_nonzero"
        else:
            crop_key = "cropped_frames_random"

        frames = revert_back_numpy_array(
            example[crop_key], size=(24, 256, 256), dtype=np.float32)

        inputs.append(frames[:num_input, ...])
        targets.append(frames[num_input:num_input + num_target, ...])

    # Insert the channel axis after the time axis: (B, T, H, W) -> (B, T, 1, H, W).
    inputs_tensor = torch.from_numpy(np.stack(inputs)).unsqueeze(2)
    targets_tensor = torch.from_numpy(np.stack(targets)).unsqueeze(2)

    return inputs_tensor, targets_tensor

In [2]:
from datasets import load_dataset

data_dir = "./data/zuimei-radar-cropped/"

train_dataset = load_dataset("webdataset", 
                    data_files={"train": os.path.join(data_dir,"*.tar")}, 
                    split="train", 
                    streaming=True)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=1,
        num_workers=1,
    )

In [4]:
for step, batch in enumerate(train_dataloader):
    if step >= 1:
        break
    print(batch)

(tensor([[[[[ 0.0000,  0.0000,  0.0000,  ..., 21.5000, 24.0000, 24.5000],
           [ 0.0000,  0.0000,  0.0000,  ..., 23.5000, 23.5000, 22.0000],
           [ 0.0000,  0.0000,  0.0000,  ..., 23.0000, 24.0000, 23.0000],
           ...,
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],


         [[[ 0.0000,  0.0000,  0.0000,  ..., 22.0000, 24.5000, 25.0000],
           [ 0.0000,  0.0000,  0.0000,  ..., 24.0000, 25.0000, 24.0000],
           [ 0.0000,  0.0000,  0.0000,  ..., 24.5000, 24.5000, 23.5000],
           ...,
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],


         [[[ 0.0000,  0.0000,  0.0000,  ..., 25.5000, 25.8000, 24.3000],
          

In [5]:
batch[0].shape

torch.Size([1, 4, 1, 256, 256])

In [6]:
batch[1].shape

torch.Size([1, 20, 1, 256, 256])

In [8]:
input_frames = batch[0]
target_frames = batch[1]
print(input_frames.shape, target_frames.shape)

torch.Size([1, 4, 1, 256, 256]) torch.Size([1, 20, 1, 256, 256])


# Test DGMR model for inference

In [9]:
import os
from dgmr import DGMR, Sampler, Generator, Discriminator, LatentConditioningStack, ContextConditioningStack

model_folder = "models/"

model = DGMR.from_pretrained(model_folder+"dgmr")
sampler = Sampler.from_pretrained(model_folder+"dgmr-sampler")
discriminator = Discriminator.from_pretrained(model_folder+"dgmr-discriminator")
latent_stack = LatentConditioningStack.from_pretrained(model_folder+"dgmr-latent-conditioning-stack")
context_stack = ContextConditioningStack.from_pretrained(model_folder+"dgmr-context-conditioning-stack")
generator = Generator(conditioning_stack=context_stack, latent_stack=latent_stack, sampler=sampler)

Loading weights from local directory
Loading weights from local directory
Loading weights from local directory
Loading weights from local directory
Loading weights from local directory


In [10]:
# help(model)

In [11]:
with torch.no_grad():
    pred_frames = model(input_frames)

print(pred_frames.shape)

torch.Size([1, 20, 1, 256, 256])


In [12]:
# from dgmr import DGMR
# import torch.nn.functional as F
# import torch

# model = DGMR(
#         forecast_steps=18,
#         input_channels=1,
#         output_shape=256,
#         latent_channels=384,
#         context_channels=192,
#         num_samples=3,
#     )

# x = torch.rand((2, 4, 1, 256, 256))

# y = torch.rand((2, 4, 1, 128, 128))
# loss = F.mse_loss(y, out)
# loss.backward()

In [13]:
# pred_frames = pred_frames.squeeze(0).permute(0, 2, 3, 1)
# pred_frames.shape

target_pred_frames = torch.cat((target_frames, pred_frames), dim=0)
target_pred_frames.shape

torch.Size([2, 20, 1, 256, 256])

In [14]:
def horizontally_concatenate_batch(samples):
    """Lay the samples of a batch side by side along the width axis.

    Args:
        samples: Tensor of shape ``(N, T, C, H, W)``.

    Returns:
        Tensor of shape ``(T, H, N * W, C)`` in which sample ``k`` occupies
        columns ``k * W`` to ``(k + 1) * W - 1``.
    """
    batch, steps, channels, height, width = samples.shape
    # (N, T, C, H, W) -> (T, H, N, W, C), then merge N and W into one axis.
    time_major = samples.permute(1, 3, 0, 4, 2)
    return time_major.reshape(steps, height, batch * width, channels)

# Visualize results

In [15]:
import datetime
import os

import cartopy
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import shapely.geometry as sgeom
import tensorflow as tf
# import tensorflow_hub

matplotlib.rc('animation', html='jshtml')

def plot_animation(field, figsize=None, vmin=0, vmax=10, cmap="jet", **imshow_args):
    """Animate the frames of ``field`` as an image sequence.

    Args:
        field: Array indexed as ``field[frame, ..., 0]``; the first axis is
            the frame index and the last axis is indexed at 0.
        figsize: Passed to ``plt.figure``.
        vmin, vmax, cmap: Colour scaling passed to ``imshow``.
        **imshow_args: Extra keyword arguments forwarded to ``imshow``.

    Returns:
        A ``matplotlib.animation.FuncAnimation`` with one frame per entry
        along the first axis of ``field`` and a 24 ms frame interval.
    """
    fig = plt.figure(figsize=figsize)
    ax = plt.axes()
    ax.set_axis_off()
    # Close the pyplot figure so no extra static axes appear below the animation.
    plt.close()

    first_frame = field[0, ..., 0]
    img = ax.imshow(first_frame, vmin=vmin, vmax=vmax, cmap=cmap, **imshow_args)

    def show_frame(index):
        img.set_data(field[index, ..., 0])
        return (img,)

    num_frames = field.shape[0]
    return animation.FuncAnimation(
        fig, show_frame, frames=num_frames, interval=24, blit=False)


class ExtendedOSGB(cartopy.crs.OSGB):
    """OSGB projection with the extended bounding box used by MET office radar data."""

    def __init__(self):
        super().__init__(approx=False)

    @property
    def x_limits(self):
        # (min, max) extent along x.
        return (-405000, 1320000)

    @property
    def y_limits(self):
        # (min, max) extent along y.
        return (-625000, 1550000)

    @property
    def boundary(self):
        # Closed rectangle through the four corners of the x/y limits.
        x_min, x_max = self.x_limits
        y_min, y_max = self.y_limits
        corners = [(x_min, y_min), (x_min, y_max), (x_max, y_max), (x_max, y_min)]
        return sgeom.LinearRing(corners + corners[:1])


def plot_rows_on_map(rows, field_name="radar_frames", timestep=0, num_rows=None,
                     cbar_label=None, **imshow_kwargs):
    """Draw one timestep of several batched dataset rows on a single map.

    Args:
        rows: Dict of batched tensors containing the ``osgb_extent_*`` keys
            and ``field_name``; each value is indexed by row along axis 0.
        field_name: Key of the field to draw, indexed as
            ``rows[field_name][row, timestep, ..., 0]``.
        timestep: Frame index to draw for every row.
        num_rows: Number of rows to draw.  Defaults to the size of the first
            axis of the first value in ``rows``.
        cbar_label: If given, a colorbar with this label is added for the
            last image drawn.
        **imshow_kwargs: Extra keyword arguments forwarded to ``imshow``.

    Returns:
        The matplotlib figure.
    """
    fig = plt.figure(figsize=(10, 10))
    axes = fig.add_subplot(1, 1, 1, projection=ExtendedOSGB())
    if num_rows is None:
        num_rows = next(iter(rows.values())).shape[0]

    im = None  # stays None when there is nothing to draw
    for b in range(num_rows):
        extent = (rows["osgb_extent_left"][b].numpy(),
                  rows["osgb_extent_right"][b].numpy(),
                  rows["osgb_extent_bottom"][b].numpy(),
                  rows["osgb_extent_top"][b].numpy())
        im = axes.imshow(rows[field_name][b, timestep, ..., 0].numpy(),
                    extent=extent, **imshow_kwargs)

    axes.set_xlim(*axes.projection.x_limits)
    axes.set_ylim(*axes.projection.y_limits)
    axes.set_facecolor("black")
    axes.gridlines(alpha=0.5)
    axes.coastlines(resolution="50m", color="white")
    # Without the guard, num_rows == 0 plus a cbar_label hit an unbound `im`.
    if cbar_label and im is not None:
        cbar = fig.colorbar(im)
        cbar.set_label(cbar_label)
    return fig


def plot_animation_on_map(row):
    """Animate the radar frames of one dataset row on a map.

    Args:
        row: Dict with ``"radar_frames"`` (indexed as ``[frame, ..., 0]``)
            and the four ``osgb_extent_*`` entries, each exposing
            ``.numpy()``.

    Returns:
        A ``matplotlib.animation.FuncAnimation`` with one frame per entry
        along the first axis of ``row["radar_frames"]``.
    """
    fig = plt.figure(figsize=(10, 10))
    axes = fig.add_subplot(1, 1, 1, projection=ExtendedOSGB())
    plt.close() # Prevents extra axes being plotted below animation

    axes.gridlines(alpha=0.5)
    axes.coastlines(resolution="50m", color="white")

    extent_keys = ("osgb_extent_left", "osgb_extent_right",
                   "osgb_extent_bottom", "osgb_extent_top")
    extent = tuple(row[key].numpy() for key in extent_keys)

    frames = row["radar_frames"]
    img = axes.imshow(
      frames[0, ..., 0].numpy(),
      extent=extent, vmin=0, vmax=15., cmap="jet")

    cbar = fig.colorbar(img)
    cbar.set_label("Precipitation, mm/hr")

    def animate(i):
        img.set_data(frames[i, ..., 0].numpy())
        # Return the updated artist, as plot_animation does; the previous
        # version returned the None result of set_data wrapped in a tuple.
        return (img,)

    return animation.FuncAnimation(
      fig, animate, frames=frames.shape[0],
      interval=24, blit=False)


def plot_mask_on_map(row):
    """Draw the first frame of a row's radar mask on a map.

    Args:
        row: Dict with ``"radar_mask"`` (indexed as ``[0, ..., 0]``) and the
            four ``osgb_extent_*`` entries, each exposing ``.numpy()``.

    Returns:
        None; the image is drawn on a newly created figure.
    """
    figure = plt.figure(figsize=(10, 10))
    map_axes = figure.add_subplot(1, 1, 1, projection=ExtendedOSGB())
    map_axes.gridlines(alpha=0.5)
    map_axes.coastlines(resolution="50m", color="black")

    # Extent order is (left, right, bottom, top).
    left = row["osgb_extent_left"].numpy()
    right = row["osgb_extent_right"].numpy()
    bottom = row["osgb_extent_bottom"].numpy()
    top = row["osgb_extent_top"].numpy()

    first_mask = row["radar_mask"][0, ..., 0].numpy()
    map_axes.imshow(
        first_mask,
        extent=(left, right, bottom, top), vmin=0, vmax=1, cmap="viridis")

2024-05-10 07:29:17.062251: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
plot_animation(horizontally_concatenate_batch(target_pred_frames).numpy(), figsize=(10, 5), vmax=60)

In [None]:
plot_animation(row["radar_frames"].numpy())

In [None]:
plt.imshow(row["radar_mask"][0, ..., 0].numpy(), vmin=0, vmax=1);

In [None]:
dataset = reader(root_dir=root_dir, split="test", variant="full_frame_20min_stride")
full_frame_test_set_row = next(iter(dataset))

In [None]:
plot_animation_on_map(full_frame_test_set_row)

In [None]:
plot_mask_on_map(full_frame_test_set_row)

In [None]:
BATCH_SIZE = 60
dataset = reader(root_dir=root_dir, split="train", variant="random_crops_256")
rows = next(iter(dataset.batch(BATCH_SIZE)))

In [None]:
plot_rows_on_map(rows, field_name="radar_frames", num_rows=10, vmin=0, vmax=15.,
                 cmap="jet", cbar_label="Precipitation, mm/hr");

In [None]:
plot_rows_on_map(rows, field_name="radar_mask", vmin=0, vmax=1, alpha=0.5, cmap="spring");