In [None]:
import sys
sys.path.append('../')

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras

from tensorflow.keras.callbacks import ReduceLROnPlateau
from src.losses import dice_loss
from models.cnn_model_factory import get_no_downsample_cnn_model

# Seed the Python, NumPy and TensorFlow RNGs up front so Restart & Run All is repeatable.
import random

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [3]:
# Open every split read-only with mmap_mode='r' so the arrays are memory-mapped
# rather than read fully into RAM up front.
train_x_mmap, train_y_mmap, test_x_mmap, test_y_mmap = (
    np.load(npy_path, mmap_mode='r')
    for npy_path in (
        '../data/train_x.npy',
        '../data/train_y.npy',
        '../data/test_x.npy',
        '../data/test_y.npy',
    )
)

In [None]:
def resize_data_and_labels(x, y, reshape_size, method='bilinear', label_method='bilinear'):
    """Resize an image batch and its label masks to a common spatial size.

    Parameters:
    - x: image batch passed straight to ``tf.image.resize``; presumably
      (N, H, W, C) — TODO confirm against ../data/*_x.npy.
    - y: label batch. A trailing channel axis is appended unless ``y`` is already
      4-D, so both channel-less masks and masks with an explicit channel work.
    - reshape_size: target ``(height, width)``.
    - method: interpolation used for the images.
    - label_method: interpolation used for the labels. The default ('bilinear')
      reproduces the previous behaviour and can produce fractional label values;
      pass 'nearest' to keep label values unchanged (e.g. binary masks stay binary).

    Returns:
    - (x_resized, y_resized) exactly as returned by ``tf.image.resize``.
    """
    x_resized = tf.image.resize(x, reshape_size, method=method)

    # tf.image.resize needs a channel axis; only add one when it is missing,
    # otherwise labels that already carry it would become 5-D and be rejected.
    y_with_channel = y if np.ndim(y) == 4 else y[..., np.newaxis]
    y_resized = tf.image.resize(y_with_channel, reshape_size, method=label_method)

    return x_resized, y_resized

# Target spatial size (height, width) applied to images and labels of both splits.
RESIZE_SHAPE = (256, 256)

# NOTE: the *_mmap names are rebound here to the outputs of resize_data_and_labels
# (presumably in-memory tf tensors), so after this cell they are no longer memory-mapped.
train_x_mmap, train_y_mmap = resize_data_and_labels(
    x=train_x_mmap,
    y=train_y_mmap,
    reshape_size=RESIZE_SHAPE,
)
test_x_mmap, test_y_mmap = resize_data_and_labels(
    x=test_x_mmap,
    y=test_y_mmap,
    reshape_size=RESIZE_SHAPE,
)

In [5]:
def get_global_normalization_mean_std(data):
    """Compute per-channel mean and standard deviation pooled over axes 0, 1 and 2.

    For 4-D input (presumably (N, H, W, C) — TODO confirm) this pools over samples
    and pixels, leaving one statistic per channel. Both results keep the reduced
    axes as size-1 dimensions so they broadcast against ``data``.

    Any channel whose standard deviation is exactly zero gets a std of 1.0 instead,
    so dividing by it later cannot produce inf/NaN.
    """
    pooled_axes = (0, 1, 2)
    channel_mean = np.mean(data, axis=pooled_axes, keepdims=True)
    channel_spread = np.std(data, axis=pooled_axes, keepdims=True)

    # Replace zero spread with unit scale; ones_like keeps the dtype of the std result.
    safe_spread = np.where(channel_spread == 0, np.ones_like(channel_spread), channel_spread)
    return channel_mean, safe_spread

# Normalization statistics are computed on the resized training images (train_x_mmap
# was rebound by the resize cell). Note they cover ALL training images, while the
# training dataset is later filtered to plume-containing samples only.
NORM_MEAN, NORM_STD = get_global_normalization_mean_std(train_x_mmap)

In [6]:
# Generator function for training data
def train_generator(norm_mean, norm_std):
    def generator():
        for i in range(len(train_x_mmap)):
            image = train_x_mmap[i]
            label = train_y_mmap[i]
            
            # Resize using TensorFlow
            # image_resized = tf.image.resize(image, resize_shape)
            # label_resized = tf.image.resize(label, resize_shape)
            
            yield (image - norm_mean) / norm_std, label
    return generator

# Generator function for testing data
def test_generator(norm_mean, norm_std):
    def generator():
        for i in range(len(test_x_mmap)):
            image = test_x_mmap[i]
            label = test_y_mmap[i]
            
            # Resize using TensorFlow
            # image_resized = tf.image.resize(image, resize_shape)
            # label_resized = tf.image.resize(label, resize_shape)
            
            yield (image - norm_mean) / norm_std, label
    return generator

# Element spec shared by both datasets: one normalized image and its label mask.
# NOTE(review): the 16 input channels are hard-coded here and in the model cell —
# confirm this matches the last axis of ../data/*_x.npy.
DATASET_SIGNATURE = (
    tf.TensorSpec(shape=RESIZE_SHAPE + (16, ), dtype=tf.float32),  # Images
    tf.TensorSpec(shape=RESIZE_SHAPE + (1, ), dtype=tf.float32),   # Labels
)

# NORM_MEAN / NORM_STD keep a leading size-1 sample axis; [0] drops it so the stats
# broadcast against a single image.
train_dataset = tf.data.Dataset.from_generator(
    generator=train_generator(NORM_MEAN[0], NORM_STD[0]),
    output_signature=DATASET_SIGNATURE,
)

test_dataset = tf.data.Dataset.from_generator(
    generator=test_generator(NORM_MEAN[0], NORM_STD[0]),
    output_signature=DATASET_SIGNATURE,
)

# limit train dataset to images with plumes (due to severe class imbalance)
train_dataset = train_dataset.filter(lambda x, y: tf.reduce_any(y > 0.0))

batch_size = 32
# Order matters:
# - cache() must come BEFORE shuffle(). Caching after shuffle freezes the first
#   epoch's order and replays it every epoch, so the data would never be reshuffled.
# - prefetch() goes last so it overlaps the whole input pipeline with training.
train_dataset = (
    train_dataset
    .cache()
    .shuffle(buffer_size=1000)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
# The test set is not shuffled, so predictions stay aligned with test_x_mmap / test_y_mmap.
test_dataset = test_dataset.cache().batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [7]:
# Build the CNN; its input shape must match the image TensorSpec used by the datasets.
model = get_no_downsample_cnn_model(
    output_channels=1,                 # one output channel, matching the (H, W, 1) label spec
    loss_weight=5.0,                   # semantics defined in models.cnn_model_factory
    input_shape=RESIZE_SHAPE + (16,),  # (height, width, 16 input channels)
)

In [None]:
# Train for a fixed 200 epochs, lowering the learning rate when val_loss plateaus.
# NOTE(review): the test split doubles as validation data here, so the LR schedule is
# driven by test loss and reported test metrics are no longer fully held out. Consider
# carving a separate validation split from the training data.
history = model.fit(
    train_dataset,
    validation_data=test_dataset,  # a tf.data.Dataset of (x, y) batches; no tuple wrapper
    epochs=200,
    callbacks=[
        ReduceLROnPlateau(
            monitor='val_loss',
            patience=75,  # epochs without val_loss improvement before the LR is reduced
            verbose=1,
        ),
    ],
)

In [None]:
# Predict on the (unshuffled) test pipeline. Because the test generator walks indices in
# order and test_dataset is never shuffled, row i of test_predictions corresponds to
# test_x_mmap[i] / test_y_mmap[i] — the visualization cell relies on this alignment.
test_predictions = model.predict(test_dataset)

In [10]:
import matplotlib.pyplot as plt
from ipywidgets import interact
# %matplotlib notebook


def view_predictions_and_labels(test_predictions, test_y_resized, images):
    """
    Visualizes the prediction, ground truth and first input channel side by side,
    with a slider to select the sample index along the 0th axis.

    Parameters:
    - test_predictions: A NumPy array or TensorFlow tensor of predicted masks, indexed as
      [sample, :, :, channel]; only channel 0 is shown (colour range fixed to [0, 1]).
    - test_y_resized: A NumPy array or TensorFlow tensor of ground-truth masks, indexed the
      same way as test_predictions; only channel 0 is shown (colour range fixed to [0, 1]).
    - images: A NumPy array or TensorFlow tensor of input images, indexed as
      [sample, :, :, channel]; only channel 0 is shown, auto-scaled.

    Raises:
    - ValueError: if there are no predictions, or the three inputs do not contain the
      same number of samples (the slider would otherwise index past the shorter one).
    """
    # np.asarray passes ndarrays (including memmaps) through and converts eager TF tensors.
    test_predictions = np.asarray(test_predictions)
    test_y_resized = np.asarray(test_y_resized)
    images = np.asarray(images)

    n_samples = test_predictions.shape[0]
    if n_samples == 0:
        raise ValueError("test_predictions is empty; nothing to display.")
    if test_y_resized.shape[0] != n_samples or images.shape[0] != n_samples:
        raise ValueError(
            f"Sample counts differ: predictions={n_samples}, "
            f"labels={test_y_resized.shape[0]}, images={images.shape[0]}"
        )

    def update(index):
        """Draw the three panels for sample ``index``."""
        fig, (ax_pred, ax_true, ax_img) = plt.subplots(1, 3, figsize=(10, 5))

        ax_pred.imshow(test_predictions[index, :, :, 0], cmap="viridis", vmin=0, vmax=1)
        ax_pred.set_title("Prediction")

        ax_true.imshow(test_y_resized[index, :, :, 0], cmap="viridis", vmin=0, vmax=1)
        ax_true.set_title("Ground Truth")

        ax_img.imshow(images[index, :, :, 0], cmap="viridis")
        ax_img.set_title("First channel")

        for ax in (ax_pred, ax_true, ax_img):
            ax.axis("off")

        # NOTE(review): dice_loss comes from src.losses; whether it returns the Dice
        # coefficient or a loss such as 1 - Dice is not visible here — confirm the label.
        fig.suptitle(f'Dice: {dice_loss(test_y_resized[index], test_predictions[index])}')
        fig.tight_layout()

    # Create the interactive slider over every valid sample index.
    interact(update, index=(0, n_samples - 1))

In [None]:
# Ground truth is thresholded at 0.5 for display, presumably because resizing can leave
# fractional label values — NOTE(review): confirm against the interpolation used in the
# resize cell. test_x_mmap here holds the resized but un-normalized images
# (normalization only happens inside the dataset generators).
view_predictions_and_labels(test_predictions, tf.cast((test_y_mmap > 0.5), tf.float32), test_x_mmap)