# MAML Notebook

This notebook walks through the training process of MAML (Model-Agnostic Meta-Learning).

## 1. Load libraries


In [None]:
%cd '/content/drive/MyDrive/Meta_learning_research/Notebooks/'
%pip install wandb
import os
import wandb
import numpy as np
import datetime
import tensorflow as tf
import matplotlib.pyplot as plt
from data_util import MetaDataLoader

/content/drive/MyDrive/Meta_learning_research/Notebooks


## 2. Prepare training data

Each episode is one location (task).

In [None]:
# Data configuration for the meta-learning experiment.
data_dir = './samples/'  # Replace with the path to your directory containing numpy files
# Locations (tasks) from which the meta-training episodes are drawn.
locations_meta_training = ['Alexander', 'Rowancreek']
# Location reserved for meta-testing; it is not used by the cells shown in this section.
locations_meta_testing = ['Covington']
num_samples_per_location = 25  # Configure the number of samples per location
num_episodes = 10  # Number of episodes
normalization_type='-1' # the lower end of the normalized range  ("0" or "-1")
# MetaDataLoader comes from the project-local data_util module.
# NOTE(review): how it uses these three arguments is not visible here - confirm against data_util.
data_loader = MetaDataLoader(data_dir, num_samples_per_location, normalization_type)

In [None]:
# Create multi episodes for meta-training.
# maml_model later reads each episode as a mapping with the keys
# "support_set_data", "support_set_labels", "query_set_data" and "query_set_labels".
# NOTE(review): judging from the printed output, each episode is drawn from one
# training location - confirm against data_util.
meta_train_episodes = data_loader.create_multi_episodes(num_episodes, locations_meta_training)

Alexander
Alexander
Rowancreek
Rowancreek
Alexander
Rowancreek
Alexander
Rowancreek
Rowancreek
Alexander


## 3. Define the model

We are using a simple U-Net model with 700K trainable variables.

### Define the Dice loss (1 − Dice coefficient, a differentiable soft version of the F1-score)

In [None]:
# Dice coefficient (similar to F1 score but differentiable)
def dice_coefficient(y_true, y_pred):
    """Return the soft Dice coefficient between ``y_true`` and ``y_pred``.

    Both inputs are flattened to 1-D and cast to float32, so the score is
    computed jointly over every element rather than per sample.

    Args:
        y_true: Ground-truth tensor (any shape).
        y_pred: Prediction tensor with the same number of elements.

    Returns:
        Scalar float32 tensor ``(2 * sum(t * p) + eps) / (sum(t) + sum(p) + eps)``.
    """
    eps = 1e-6  # Keeps the ratio defined when both inputs sum to zero
    flat_true = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
    flat_pred = tf.cast(tf.reshape(y_pred, [-1]), tf.float32)
    overlap = tf.reduce_sum(flat_true * flat_pred)
    numerator = 2. * overlap + eps
    denominator = tf.reduce_sum(flat_true) + tf.reduce_sum(flat_pred) + eps
    return numerator / denominator

# Dice loss to be minimized
def dice_loss(y_true, y_pred):
    """Return ``1 - dice_coefficient(y_true, y_pred)`` so that better overlap gives a lower loss."""
    overlap_score = dice_coefficient(y_true, y_pred)
    return 1 - overlap_score

### Define Attention U-net model

In [None]:
from tensorflow.keras.layers import ( Input, Conv2D, BatchNormalization, Activation, Conv2DTranspose,
                                        MaxPooling2D, Layer, add, multiply, GlobalAveragePooling2D,
                                        Dense, Reshape, Multiply )
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K

class ChannelAttention(Layer):
    """Channel attention in the squeeze-and-excitation style for channels-last feature maps.

    The input is averaged over its two spatial axes, passed through a two-layer
    dense bottleneck (ReLU, then sigmoid), and the resulting per-channel weights
    rescale the input. The output has the same shape as the input.

    Args:
        reduction_ratio: Divisor applied to the channel count to size the
            bottleneck layer.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, reduction_ratio=8, **kwargs):
        super(ChannelAttention, self).__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        """Create the two dense sublayers once the channel count is known."""
        channel = input_shape[-1]
        # Keep at least one bottleneck unit; a plain floor division yields 0 units
        # (an invalid Dense layer) when channel < reduction_ratio.
        bottleneck_units = max(1, channel // self.reduction_ratio)
        self.fc1 = Dense(bottleneck_units, activation='relu', kernel_initializer='he_normal', use_bias=True, bias_initializer='zeros')
        self.fc2 = Dense(channel, activation='sigmoid', kernel_initializer='he_normal', use_bias=True, bias_initializer='zeros')
        super(ChannelAttention, self).build(input_shape)

    def call(self, inputs):
        """Rescale every channel of ``inputs`` by its learned attention weight."""
        # Global average over the spatial axes, kept as (batch, 1, 1, channels).
        # Plain backend ops are used so that no new Keras layer objects are
        # instantiated on each forward pass.
        squeezed = K.mean(inputs, axis=[1, 2], keepdims=True)
        channel_weights = self.fc2(self.fc1(squeezed))
        # Broadcasts the (batch, 1, 1, channels) weights over height and width.
        return inputs * channel_weights

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        """Return the layer config including ``reduction_ratio`` for serialization."""
        config = super(ChannelAttention, self).get_config()
        config.update({'reduction_ratio': self.reduction_ratio})
        return config

class SpatialAttention(Layer):
    """Spatial attention for channels-last feature maps.

    The channel-wise mean and max of the input are concatenated into a
    two-channel map, convolved down to a single sigmoid-activated channel, and
    that map rescales the input at every spatial position. The output has the
    same shape as the input.

    Args:
        kernel_size: Kernel size of the attention convolution.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, kernel_size=7, **kwargs):
        super(SpatialAttention, self).__init__(**kwargs)
        self.kernel_size = kernel_size

    def build(self, input_shape):
        """Create the single-filter attention convolution."""
        self.conv = Conv2D(1, self.kernel_size, padding='same', activation='sigmoid', kernel_initializer='he_normal', use_bias=False)
        super(SpatialAttention, self).build(input_shape)

    def call(self, inputs):
        """Rescale ``inputs`` by a learned (batch, H, W, 1) attention map."""
        avg_pool = K.mean(inputs, axis=3, keepdims=True)
        max_pool = K.max(inputs, axis=3, keepdims=True)
        concat = K.concatenate([avg_pool, max_pool], axis=3)
        attention = self.conv(concat)
        # Broadcast multiply instead of the functional `multiply` helper, which
        # would instantiate a new Multiply layer on every forward pass.
        return inputs * attention

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        """Return the layer config including ``kernel_size`` for serialization."""
        config = super(SpatialAttention, self).get_config()
        config.update({'kernel_size': self.kernel_size})
        return config

class AttentionUnet:
    """Builder for a U-Net with five pooling stages and attention-gated skip connections.

    Args:
        img_width: Height and width of the square input image. Must be divisible
            by 32 so the five 2x2 poolings line up with the five 2x upsamplings.
        input_channels: Number of input channels.
        output_mask_channels: Number of channels of the sigmoid output mask.
        filters: Filter count of the first level; it doubles at every deeper level.
        last_dropout: Dropout rate applied at the end of the final decoder block.
    """

    def __init__(self, img_width=224, input_channels=8, output_mask_channels=1, filters=32, last_dropout=0.2):
        self.img_width = img_width
        self.input_channels = input_channels
        self.output_mask_channels = output_mask_channels
        self.filters = filters
        self.last_dropout = last_dropout

    def residual_cnn_block(self, x, size, dropout=0.0, batch_norm=True):
        """Apply two 3x3 conv -> (batch norm) -> ReLU stages, then optional dropout.

        Despite its name, the block contains no shortcut connection.

        Args:
            x: Input feature map.
            size: Number of filters of both convolutions.
            dropout: Dropout rate applied after the second activation; no
                Dropout layer is added when it is 0.
            batch_norm: Whether to insert BatchNormalization after each convolution.

        Returns:
            The transformed feature map.
        """
        conv = Conv2D(size, (3, 3), padding='same')(x)
        if batch_norm:
            conv = BatchNormalization()(conv)
        conv = Activation('relu')(conv)
        conv = Conv2D(size, (3, 3), padding='same')(conv)
        if batch_norm:
            conv = BatchNormalization()(conv)
        conv = Activation('relu')(conv)
        # The `dropout` argument used to be accepted but silently ignored, so
        # `last_dropout` never had any effect on the model.
        if dropout and dropout > 0:
            conv = tf.keras.layers.Dropout(dropout)(conv)
        return conv

    def attention_up_and_concatenate(self, inputs, attention_type):
        """Upsample the deeper feature map and add it to the attended skip connection.

        Note that the two maps are summed (``add``), not concatenated.

        Args:
            inputs: Pair ``(g, x)`` where ``g`` is the deeper (lower-resolution)
                map and ``x`` is the skip connection at twice the resolution.
            attention_type: ``'spatial'`` or ``'channel'``.

        Returns:
            The element-wise sum of the upsampled ``g`` and the attended ``x``.

        Raises:
            ValueError: If ``attention_type`` is not one of the supported values.
        """
        g, x = inputs
        if attention_type == 'spatial':
            attention_layer = SpatialAttention()
        elif attention_type == 'channel':
            attention_layer = ChannelAttention()
        else:
            # Previously an unknown type fell through to an UnboundLocalError.
            raise ValueError(
                f"Unknown attention_type {attention_type!r}; expected 'spatial' or 'channel'."
            )
        x = attention_layer(x)
        inter_channel = x.get_shape().as_list()[3]
        # Strided transposed convolution doubles the resolution of g and matches
        # its channel count to the skip connection so the two can be added.
        g = Conv2DTranspose(inter_channel, (3,3), strides=(2, 2), padding='same')(g)
        return add([g, x])

    def build_model(self):
        """Assemble and return the Keras ``Model`` named "AttentionUnet".

        Raises:
            ValueError: If ``img_width`` is not divisible by 32.
        """
        if self.img_width % 32 != 0:
            # Otherwise a pooled size is odd at some level and the upsampled map
            # no longer matches its skip connection in `add`.
            raise ValueError(
                f"img_width must be divisible by 32 (five 2x2 poolings), got {self.img_width}."
            )

        inputs = Input((self.img_width, self.img_width, self.input_channels))
        filters = self.filters

        # Downsampling path (the numeric suffixes are the sizes for img_width=224)
        conv_224 = self.residual_cnn_block(inputs, filters)
        pool_112 = MaxPooling2D(pool_size=(2, 2))(conv_224)
        conv_112 = self.residual_cnn_block(pool_112, filters * 2)
        pool_56 = MaxPooling2D(pool_size=(2, 2))(conv_112)
        conv_56 = self.residual_cnn_block(pool_56, filters * 4)
        pool_28 = MaxPooling2D(pool_size=(2, 2))(conv_56)
        conv_28 = self.residual_cnn_block(pool_28, filters * 8)
        pool_14 = MaxPooling2D(pool_size=(2, 2))(conv_28)
        conv_14 = self.residual_cnn_block(pool_14, filters * 16)
        pool_7 = MaxPooling2D(pool_size=(2, 2))(conv_14)
        conv_7 = self.residual_cnn_block(pool_7, filters * 32)

        # Upsampling path: spatial attention on the two deepest skips,
        # channel attention on the three shallower ones.
        up_14 = self.attention_up_and_concatenate([conv_7, conv_14], 'spatial')
        up_conv_14 = self.residual_cnn_block(up_14, filters * 16)
        up_28 = self.attention_up_and_concatenate([up_conv_14, conv_28], 'spatial')
        up_conv_28 = self.residual_cnn_block(up_28, filters * 8)
        up_56 = self.attention_up_and_concatenate([up_conv_28, conv_56], 'channel')
        up_conv_56 = self.residual_cnn_block(up_56, filters * 4)
        up_112 = self.attention_up_and_concatenate([up_conv_56, conv_112], 'channel')
        up_conv_112 = self.residual_cnn_block(up_112, filters * 2)
        up_224 = self.attention_up_and_concatenate([up_conv_112, conv_224], 'channel')
        up_conv_224 = self.residual_cnn_block(up_224, filters, dropout=self.last_dropout)

        # Output layer
        conv_final = Conv2D(self.output_mask_channels, (1, 1), activation='sigmoid')(up_conv_224)

        # Create model
        model = Model(inputs, conv_final, name="AttentionUnet")
        return model


### Define Simple U-net model

In [None]:
import tensorflow as tf

class SimpleUNet:
    """Builder for a small one-level U-Net: encoder, bottleneck and decoder with one skip.

    Args:
        input_shape: Shape of a single input sample (height, width, channels).
        num_classes: Number of channels of the sigmoid output.
    """

    def __init__(self, input_shape=(224, 224, 8), num_classes=1):
        self.input_shape = input_shape
        self.num_classes = num_classes

    def build_model(self):
        """Assemble and return the Keras model."""
        layers = tf.keras.layers

        def conv3x3(n_filters):
            # Every 3x3 convolution in this network shares the same settings.
            return layers.Conv2D(n_filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')

        net_input = tf.keras.Input(shape=self.input_shape)

        # Encoder: conv -> dropout -> conv, then halve the resolution
        encoder = conv3x3(16)(net_input)
        encoder = layers.Dropout(0.1)(encoder)
        encoder = conv3x3(16)(encoder)
        pooled = layers.MaxPooling2D((2, 2))(encoder)

        # Bottleneck at half resolution
        bottleneck = conv3x3(64)(pooled)
        bottleneck = layers.Dropout(0.2)(bottleneck)
        bottleneck = conv3x3(64)(bottleneck)

        # Decoder: upsample and merge with the encoder features (skip connection)
        upsampled = layers.Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(bottleneck)
        merged = layers.concatenate([upsampled, encoder])
        decoder = conv3x3(16)(merged)
        decoder = layers.Dropout(0.1)(decoder)
        decoder = conv3x3(16)(decoder)

        # Per-pixel sigmoid output
        net_output = layers.Conv2D(self.num_classes, (1, 1), activation='sigmoid')(decoder)

        return tf.keras.Model(inputs=net_input, outputs=net_output)

## 4. Define the training process (MAML)

### Training with a constant LR


In [None]:

def train_task_model(model, inputs, outputs, learning_rate=0.001, momentum=0.5, optimizer=None):
    """Run one Dice-loss gradient step on ``model`` in place.

    Args:
        model: Keras model to update.
        inputs: Batch of input samples fed to the model.
        outputs: Ground-truth masks matching ``inputs``.
        learning_rate: SGD learning rate, used only when ``optimizer`` is None.
        momentum: SGD momentum, used only when ``optimizer`` is None.
        optimizer: Optional optimizer to reuse across calls. Momentum only takes
            effect when the same optimizer instance is passed to successive
            steps; a freshly created optimizer starts with zero velocity.

    Returns:
        Tuple ``(model, loss)`` with the updated model and the loss tensor
        computed before the update.
    """
    if optimizer is None:
        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=momentum)
    with tf.GradientTape() as tape:
        # NOTE(review): the model is called without `training=True`, so Dropout /
        # BatchNorm layers presumably run in inference mode here - confirm intended.
        predictions = model(inputs)
        loss = dice_loss(outputs, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return model, loss  # Return the updated model

def maml_model(base_model, episodes, meta_lr=0.001, inner_lr=0.0035, momentum=0.9, meta_batch_size=1, inner_steps=10, epochs=10):
    """Meta-train ``base_model`` with a first-order MAML loop and constant learning rates.

    For every episode a copy of the base model is adapted on the support set
    with SGD, the Dice loss of the adapted copy is measured on the query set,
    and the gradient of that loss with respect to the adapted weights is
    collected. The per-episode gradients are averaged and applied to the base
    model with Adam. The gradient is not propagated through the inner steps, so
    this is the first-order approximation of MAML.

    Note: a later cell defines functions with the same names (dynamic-LR
    variant); whichever cell was executed last is the one in effect.

    Args:
        base_model: Keras model whose weights are meta-learned (updated in place).
        episodes: Sequence of mappings with the keys "support_set_data",
            "support_set_labels", "query_set_data" and "query_set_labels".
        meta_lr: Learning rate of the outer (Adam) optimizer.
        inner_lr: Learning rate of the inner (SGD) optimizer.
        momentum: Momentum of the inner optimizer.
        meta_batch_size: Number of passes over all episodes per epoch; each pass
            ends with one meta-update.
        inner_steps: Number of adaptation steps per episode.
        epochs: Number of epochs.

    Returns:
        The meta-trained ``base_model``.
    """
    meta_optimizer = tf.keras.optimizers.Adam(learning_rate=meta_lr)

    # One scratch copy is enough: its weights are reset from the base model at
    # the start of every episode instead of re-cloning the architecture each time.
    model_copy = tf.keras.models.clone_model(base_model)

    for epoch in range(epochs):
        task_losses = []
        for batch_index in range(meta_batch_size):
            task_gradients = []
            for episode_index, episode in enumerate(episodes):
                # Reset the task model to the current meta-parameters
                model_copy.set_weights(base_model.get_weights())

                print(f"Epoch {epoch + 1}, Meta-batch {batch_index + 1}, Starting training on episode {episode_index + 1}")

                # Inner loop: task-specific adaptation on the support set. The
                # optimizer is created once per episode so its momentum buffers
                # persist across the inner steps.
                inner_optimizer = tf.keras.optimizers.SGD(learning_rate=inner_lr, momentum=momentum)
                support_data = episode["support_set_data"]
                support_labels = episode["support_set_labels"]
                for step in range(inner_steps):
                    model_copy, inner_loss = train_task_model(model_copy, support_data, support_labels, optimizer=inner_optimizer)
                    print(f" -- Inner step {step + 1}, Loss {inner_loss}")

                # Evaluate the adapted model on the query set and, in the same
                # forward pass, record the gradient used for the meta-update.
                query_data = episode["query_set_data"]
                query_labels = episode["query_set_labels"]
                with tf.GradientTape() as meta_tape:
                    val_loss = dice_loss(query_labels, model_copy(query_data))
                gradients = meta_tape.gradient(val_loss, model_copy.trainable_variables)
                task_losses.append(val_loss.numpy())  # Record validation loss
                print(f" -- Validation loss after adapting to episode {episode_index + 1}: {val_loss.numpy()}")

                # The copy's variables are in the same order as the base model's,
                # so the gradients map onto the base model by position. Entries
                # may be None for variables that do not influence the loss.
                task_gradients.append(gradients)

            # Outer loop: update the base model with the mean gradient over all tasks
            if task_gradients:
                gradients_to_apply = []
                for var_index, variable in enumerate(base_model.trainable_variables):
                    grads = [task[var_index] for task in task_gradients if task[var_index] is not None]
                    if grads:  # Skip variables for which every task gradient is None
                        gradients_to_apply.append((tf.reduce_mean(tf.stack(grads), axis=0), variable))
                if gradients_to_apply:
                    meta_optimizer.apply_gradients(gradients_to_apply)

        print(f"Epoch {epoch + 1} completed, Mean Validation Loss across all episodes: {tf.reduce_mean(task_losses)}")

    return base_model

### Training with a dynamic LR



In [None]:
def train_task_model(model, inputs, outputs, optimizer):
    """Apply one Dice-loss gradient step to ``model`` using ``optimizer``.

    Args:
        model: Keras model to update in place.
        inputs: Batch of input samples fed to the model.
        outputs: Ground-truth masks matching ``inputs``.
        optimizer: Optimizer that applies the gradients.

    Returns:
        Tuple ``(model, loss)`` where ``loss`` is the pre-update loss as a NumPy value.
    """
    with tf.GradientTape() as step_tape:
        step_loss = dice_loss(outputs, model(inputs))
    trainable = model.trainable_variables
    step_grads = step_tape.gradient(step_loss, trainable)
    optimizer.apply_gradients(zip(step_grads, trainable))
    return model, step_loss.numpy()

def maml_model(base_model, episodes, initial_meta_lr=0.001, initial_inner_lr=0.001, decay_steps=1000, decay_rate=0.96, meta_batch_size=1, inner_steps=1, epochs=500, patience=15, save_path='models'):
    """Meta-train ``base_model`` with first-order MAML, decaying learning rates and early stopping.

    For every episode a copy of the base model is adapted on the support set
    with SGD, the Dice loss of the adapted copy is measured on the query set,
    and the gradient of that loss with respect to the adapted weights is
    collected. The per-episode gradients are averaged and applied to the base
    model with Adam. The gradient is not propagated through the inner steps
    (first-order approximation). Metrics are logged to Weights & Biases and the
    base model is saved whenever the mean query loss of an epoch improves.

    Args:
        base_model: Keras model whose weights are meta-learned (updated in place).
        episodes: Sequence of mappings with the keys "support_set_data",
            "support_set_labels", "query_set_data" and "query_set_labels".
        initial_meta_lr: Initial learning rate of the outer (Adam) optimizer.
        initial_inner_lr: Initial learning rate of the inner (SGD) optimizer.
        decay_steps: Step interval of the staircase exponential decay.
        decay_rate: Multiplicative decay factor of both schedules.
        meta_batch_size: Number of passes over all episodes per epoch; each pass
            ends with one meta-update.
        inner_steps: Number of adaptation steps per episode.
        epochs: Maximum number of epochs.
        patience: Number of consecutive epochs without improvement after which
            training stops.
        save_path: Directory in which the best model is saved.

    Returns:
        Tuple ``(base_model, name)`` with the meta-trained model and the run
        name, which is also the saved model's name inside ``save_path``.
    """
    date_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    name = f'maml_{inner_steps}_{epochs}_{meta_batch_size}_{date_time}_best_model'
    model_path = os.path.join(save_path, name)
    print(f'Initialize the Training process: {name}')

    # Initialize WandB
    wandb.init(project="maml_experiment",
               name = name,
               config={
                    "initial_meta_lr": initial_meta_lr,
                    "initial_inner_lr": initial_inner_lr,
                    "decay_steps": decay_steps,
                    "decay_rate": decay_rate,
                    "meta_batch_size": meta_batch_size,
                    "inner_steps": inner_steps,
                    "epochs": epochs,
                    "patience": patience,
                    "model_path": model_path,
                    "model_type": "unet"
                })

    try:
        # Define learning rate schedules
        inner_lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_inner_lr, decay_steps, decay_rate, staircase=True)
        meta_lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_meta_lr, decay_steps, decay_rate, staircase=True)
        meta_optimizer = tf.keras.optimizers.Adam(learning_rate=meta_lr_schedule)

        best_loss = float('inf')
        no_improvement_count = 0  # Counter to track the number of epochs without improvement

        # One scratch copy is enough: its weights are reset from the base model at
        # the start of every episode instead of re-cloning the architecture each time.
        model_copy = tf.keras.models.clone_model(base_model)

        for epoch in range(epochs):
            task_losses = []
            for _ in range(meta_batch_size):
                task_gradients = []
                for episode_index, episode in enumerate(episodes):
                    # Reset the task model to the current meta-parameters
                    model_copy.set_weights(base_model.get_weights())

                    # The inner learning rate is fixed per episode, taken from the
                    # schedule at the global episode counter.
                    inner_optimizer = tf.keras.optimizers.SGD(learning_rate=inner_lr_schedule(epoch * len(episodes) + episode_index))

                    # Inner loop: Task-specific adjustments with dynamic learning rate
                    support_data = episode["support_set_data"]
                    support_labels = episode["support_set_labels"]
                    episode_losses = []
                    for step in range(inner_steps):
                        model_copy, loss = train_task_model(model_copy, support_data, support_labels, inner_optimizer)
                        episode_losses.append(loss)

                    # Evaluate the adapted model on the query set and, in the same
                    # forward pass, record the gradient used for the meta-update.
                    query_data = episode["query_set_data"]
                    query_labels = episode["query_set_labels"]
                    with tf.GradientTape() as meta_tape:
                        val_loss = dice_loss(query_labels, model_copy(query_data))
                    gradients = meta_tape.gradient(val_loss, model_copy.trainable_variables)
                    task_losses.append(val_loss.numpy())

                    # Log plain Python floats rather than tensors
                    wandb.log({
                        "epoch": epoch,
                        "episode": episode_index,
                        "eps_loss": float(tf.reduce_mean(episode_losses)),
                        "eps_val_loss": float(val_loss.numpy())
                    })

                    # The copy's variables are in the same order as the base model's,
                    # so the gradients map onto the base model by position. Entries
                    # may be None for variables that do not influence the loss.
                    task_gradients.append(gradients)

                # Outer loop: update the base model with the mean gradient over all tasks
                if task_gradients:
                    gradients_to_apply = []
                    for var_index, variable in enumerate(base_model.trainable_variables):
                        grads = [task[var_index] for task in task_gradients if task[var_index] is not None]
                        if grads:  # Skip variables for which every task gradient is None
                            gradients_to_apply.append((tf.reduce_mean(tf.stack(grads), axis=0), variable))
                    if gradients_to_apply:
                        meta_optimizer.apply_gradients(gradients_to_apply)

            mean_loss = float(tf.reduce_mean(task_losses))
            wandb.log({
                "epoch": epoch,
                "mean_val_loss": mean_loss
            })

            # Early stopping and model saving
            if mean_loss < best_loss:
                best_loss = mean_loss
                no_improvement_count = 0
                base_model.save(model_path)  # Save the best model
                print(f"Saved new best model with validation loss: {best_loss}")

            else:
                no_improvement_count += 1
                if no_improvement_count >= patience:
                    print(f"No improvement for {patience} consecutive epochs, stopping training.")
                    break  # Stop training if no improvement in 'patience' number of epochs

            print(f"Epoch {epoch + 1} completed, Mean Validation Loss across all episodes: {mean_loss}")
    except BaseException:
        # Close the W&B run even when training fails or is interrupted, and mark
        # it as failed, before re-raising the original error.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()
    return base_model, name

## 5. Train the model with the Meta-training set.

In [None]:
# Example usage: build a fresh SimpleUNet and print its layer summary.
unet_model = SimpleUNet(input_shape=(224, 224, 8))
model = unet_model.build_model()
model.summary()

# Dynamic LR: this call matches the signature of the maml_model defined in the
# "dynamic LR" cell, so that cell must be the one executed last (the constant-LR
# cell defines functions with the same names).
# maml_model updates `model` in place and returns that same object, so
# `adapted_model` and `model` refer to the same meta-trained network.
adapted_model, name = maml_model(model, meta_train_episodes, initial_meta_lr=0.001, initial_inner_lr=0.001, decay_steps=1000, decay_rate=0.96, meta_batch_size=1, inner_steps=1, epochs=500)
print("model name:", name)

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 224, 224, 8)]        0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 224, 224, 16)         1168      ['input_1[0][0]']             
                                                                                                  
 dropout (Dropout)           (None, 224, 224, 16)         0         ['conv2d[0][0]']              
                                                                                                  
 conv2d_1 (Conv2D)           (None, 224, 224, 16)         2320      ['dropout[0][0]']             
                                                                                              

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc




Saved new best model with validation loss: 0.9409248232841492
Epoch 1 completed, Mean Validation Loss across all episodes: 0.9409248232841492




Saved new best model with validation loss: 0.9346744418144226
Epoch 2 completed, Mean Validation Loss across all episodes: 0.9346744418144226




Saved new best model with validation loss: 0.927566647529602
Epoch 3 completed, Mean Validation Loss across all episodes: 0.927566647529602




Saved new best model with validation loss: 0.9184465408325195
Epoch 4 completed, Mean Validation Loss across all episodes: 0.9184465408325195




Saved new best model with validation loss: 0.9053918123245239
Epoch 5 completed, Mean Validation Loss across all episodes: 0.9053918123245239




Saved new best model with validation loss: 0.8863885998725891
Epoch 6 completed, Mean Validation Loss across all episodes: 0.8863885998725891




Saved new best model with validation loss: 0.8591081500053406
Epoch 7 completed, Mean Validation Loss across all episodes: 0.8591081500053406




Saved new best model with validation loss: 0.8217463493347168
Epoch 8 completed, Mean Validation Loss across all episodes: 0.8217463493347168




Saved new best model with validation loss: 0.7717688679695129
Epoch 9 completed, Mean Validation Loss across all episodes: 0.7717688679695129




Saved new best model with validation loss: 0.7100585699081421
Epoch 10 completed, Mean Validation Loss across all episodes: 0.7100585699081421




Saved new best model with validation loss: 0.6444177031517029
Epoch 11 completed, Mean Validation Loss across all episodes: 0.6444177031517029




Saved new best model with validation loss: 0.5865862369537354
Epoch 12 completed, Mean Validation Loss across all episodes: 0.5865862369537354




Saved new best model with validation loss: 0.5432615280151367
Epoch 13 completed, Mean Validation Loss across all episodes: 0.5432615280151367




Saved new best model with validation loss: 0.5119611024856567
Epoch 14 completed, Mean Validation Loss across all episodes: 0.5119611024856567




Saved new best model with validation loss: 0.48743677139282227
Epoch 15 completed, Mean Validation Loss across all episodes: 0.48743677139282227




Saved new best model with validation loss: 0.46768999099731445
Epoch 16 completed, Mean Validation Loss across all episodes: 0.46768999099731445




Saved new best model with validation loss: 0.4495214819908142
Epoch 17 completed, Mean Validation Loss across all episodes: 0.4495214819908142




Saved new best model with validation loss: 0.43353453278541565
Epoch 18 completed, Mean Validation Loss across all episodes: 0.43353453278541565




Saved new best model with validation loss: 0.4186418056488037
Epoch 19 completed, Mean Validation Loss across all episodes: 0.4186418056488037




Saved new best model with validation loss: 0.40500983595848083
Epoch 20 completed, Mean Validation Loss across all episodes: 0.40500983595848083




Saved new best model with validation loss: 0.391934871673584
Epoch 21 completed, Mean Validation Loss across all episodes: 0.391934871673584




Saved new best model with validation loss: 0.37920087575912476
Epoch 22 completed, Mean Validation Loss across all episodes: 0.37920087575912476




Saved new best model with validation loss: 0.3673875331878662
Epoch 23 completed, Mean Validation Loss across all episodes: 0.3673875331878662




Saved new best model with validation loss: 0.3570271134376526
Epoch 24 completed, Mean Validation Loss across all episodes: 0.3570271134376526




Saved new best model with validation loss: 0.35265636444091797
Epoch 25 completed, Mean Validation Loss across all episodes: 0.35265636444091797




Saved new best model with validation loss: 0.3481772840023041
Epoch 26 completed, Mean Validation Loss across all episodes: 0.3481772840023041




Saved new best model with validation loss: 0.33559709787368774
Epoch 27 completed, Mean Validation Loss across all episodes: 0.33559709787368774
Epoch 28 completed, Mean Validation Loss across all episodes: 0.3368086516857147




Saved new best model with validation loss: 0.3290213942527771
Epoch 29 completed, Mean Validation Loss across all episodes: 0.3290213942527771




Saved new best model with validation loss: 0.3280598521232605
Epoch 30 completed, Mean Validation Loss across all episodes: 0.3280598521232605




Saved new best model with validation loss: 0.3234453797340393
Epoch 31 completed, Mean Validation Loss across all episodes: 0.3234453797340393




Saved new best model with validation loss: 0.31802257895469666
Epoch 32 completed, Mean Validation Loss across all episodes: 0.31802257895469666




Saved new best model with validation loss: 0.31749358773231506
Epoch 33 completed, Mean Validation Loss across all episodes: 0.31749358773231506




Saved new best model with validation loss: 0.30921268463134766
Epoch 34 completed, Mean Validation Loss across all episodes: 0.30921268463134766




Saved new best model with validation loss: 0.3073142468929291
Epoch 35 completed, Mean Validation Loss across all episodes: 0.3073142468929291




Saved new best model with validation loss: 0.302977979183197
Epoch 36 completed, Mean Validation Loss across all episodes: 0.302977979183197




Saved new best model with validation loss: 0.297488272190094
Epoch 37 completed, Mean Validation Loss across all episodes: 0.297488272190094




Saved new best model with validation loss: 0.29652664065361023
Epoch 38 completed, Mean Validation Loss across all episodes: 0.29652664065361023




Saved new best model with validation loss: 0.29122599959373474
Epoch 39 completed, Mean Validation Loss across all episodes: 0.29122599959373474




Saved new best model with validation loss: 0.28808027505874634
Epoch 40 completed, Mean Validation Loss across all episodes: 0.28808027505874634




Saved new best model with validation loss: 0.28707608580589294
Epoch 41 completed, Mean Validation Loss across all episodes: 0.28707608580589294




Saved new best model with validation loss: 0.2828981280326843
Epoch 42 completed, Mean Validation Loss across all episodes: 0.2828981280326843




Saved new best model with validation loss: 0.28056350350379944
Epoch 43 completed, Mean Validation Loss across all episodes: 0.28056350350379944




Saved new best model with validation loss: 0.280374139547348
Epoch 44 completed, Mean Validation Loss across all episodes: 0.280374139547348




Saved new best model with validation loss: 0.2780155837535858
Epoch 45 completed, Mean Validation Loss across all episodes: 0.2780155837535858




Saved new best model with validation loss: 0.2735399901866913
Epoch 46 completed, Mean Validation Loss across all episodes: 0.2735399901866913




Saved new best model with validation loss: 0.27156561613082886
Epoch 47 completed, Mean Validation Loss across all episodes: 0.27156561613082886
Epoch 48 completed, Mean Validation Loss across all episodes: 0.27172595262527466




Saved new best model with validation loss: 0.27119654417037964
Epoch 49 completed, Mean Validation Loss across all episodes: 0.27119654417037964




Saved new best model with validation loss: 0.2672566771507263
Epoch 50 completed, Mean Validation Loss across all episodes: 0.2672566771507263




Saved new best model with validation loss: 0.2641837000846863
Epoch 51 completed, Mean Validation Loss across all episodes: 0.2641837000846863




Saved new best model with validation loss: 0.26362118124961853
Epoch 52 completed, Mean Validation Loss across all episodes: 0.26362118124961853




Saved new best model with validation loss: 0.26258474588394165
Epoch 53 completed, Mean Validation Loss across all episodes: 0.26258474588394165




Saved new best model with validation loss: 0.26001009345054626
Epoch 54 completed, Mean Validation Loss across all episodes: 0.26001009345054626




Saved new best model with validation loss: 0.2575299143791199
Epoch 55 completed, Mean Validation Loss across all episodes: 0.2575299143791199




Saved new best model with validation loss: 0.25636112689971924
Epoch 56 completed, Mean Validation Loss across all episodes: 0.25636112689971924




Saved new best model with validation loss: 0.2560012936592102
Epoch 57 completed, Mean Validation Loss across all episodes: 0.2560012936592102




Saved new best model with validation loss: 0.255760133266449
Epoch 58 completed, Mean Validation Loss across all episodes: 0.255760133266449




Saved new best model with validation loss: 0.25311988592147827
Epoch 59 completed, Mean Validation Loss across all episodes: 0.25311988592147827




Saved new best model with validation loss: 0.25070610642433167
Epoch 60 completed, Mean Validation Loss across all episodes: 0.25070610642433167




Saved new best model with validation loss: 0.24800975620746613
Epoch 61 completed, Mean Validation Loss across all episodes: 0.24800975620746613




Saved new best model with validation loss: 0.24623259902000427
Epoch 62 completed, Mean Validation Loss across all episodes: 0.24623259902000427




Saved new best model with validation loss: 0.24514245986938477
Epoch 63 completed, Mean Validation Loss across all episodes: 0.24514245986938477




Saved new best model with validation loss: 0.24478352069854736
Epoch 64 completed, Mean Validation Loss across all episodes: 0.24478352069854736
Epoch 65 completed, Mean Validation Loss across all episodes: 0.24675822257995605
Epoch 66 completed, Mean Validation Loss across all episodes: 0.24690179526805878
Epoch 67 completed, Mean Validation Loss across all episodes: 0.24504438042640686




Saved new best model with validation loss: 0.2392890900373459
Epoch 68 completed, Mean Validation Loss across all episodes: 0.2392890900373459




Saved new best model with validation loss: 0.2376699000597
Epoch 69 completed, Mean Validation Loss across all episodes: 0.2376699000597
Epoch 70 completed, Mean Validation Loss across all episodes: 0.23952636122703552
Epoch 71 completed, Mean Validation Loss across all episodes: 0.23962163925170898




Saved new best model with validation loss: 0.2366502583026886
Epoch 72 completed, Mean Validation Loss across all episodes: 0.2366502583026886




Saved new best model with validation loss: 0.23287372291088104
Epoch 73 completed, Mean Validation Loss across all episodes: 0.23287372291088104
Epoch 74 completed, Mean Validation Loss across all episodes: 0.2338755577802658
Epoch 75 completed, Mean Validation Loss across all episodes: 0.2372680902481079
Epoch 76 completed, Mean Validation Loss across all episodes: 0.2348136454820633




Saved new best model with validation loss: 0.22886362671852112
Epoch 77 completed, Mean Validation Loss across all episodes: 0.22886362671852112
Epoch 78 completed, Mean Validation Loss across all episodes: 0.2347916066646576
Epoch 79 completed, Mean Validation Loss across all episodes: 0.23572103679180145




Saved new best model with validation loss: 0.22875580191612244
Epoch 80 completed, Mean Validation Loss across all episodes: 0.22875580191612244
Epoch 81 completed, Mean Validation Loss across all episodes: 0.2313677817583084




Saved new best model with validation loss: 0.22841370105743408
Epoch 82 completed, Mean Validation Loss across all episodes: 0.22841370105743408




Saved new best model with validation loss: 0.22653570771217346
Epoch 83 completed, Mean Validation Loss across all episodes: 0.22653570771217346
Epoch 84 completed, Mean Validation Loss across all episodes: 0.23161093890666962




Saved new best model with validation loss: 0.22568681836128235
Epoch 85 completed, Mean Validation Loss across all episodes: 0.22568681836128235




Saved new best model with validation loss: 0.22444219887256622
Epoch 86 completed, Mean Validation Loss across all episodes: 0.22444219887256622
Epoch 87 completed, Mean Validation Loss across all episodes: 0.2270773947238922




Saved new best model with validation loss: 0.22182908654212952
Epoch 88 completed, Mean Validation Loss across all episodes: 0.22182908654212952




Saved new best model with validation loss: 0.2205473929643631
Epoch 89 completed, Mean Validation Loss across all episodes: 0.2205473929643631
Epoch 90 completed, Mean Validation Loss across all episodes: 0.22148850560188293




Saved new best model with validation loss: 0.21824367344379425
Epoch 91 completed, Mean Validation Loss across all episodes: 0.21824367344379425
Epoch 92 completed, Mean Validation Loss across all episodes: 0.2200905829668045
Epoch 93 completed, Mean Validation Loss across all episodes: 0.21951131522655487




Saved new best model with validation loss: 0.21562501788139343
Epoch 94 completed, Mean Validation Loss across all episodes: 0.21562501788139343
Epoch 95 completed, Mean Validation Loss across all episodes: 0.2174651175737381
Epoch 96 completed, Mean Validation Loss across all episodes: 0.21698060631752014




Saved new best model with validation loss: 0.21346597373485565
Epoch 97 completed, Mean Validation Loss across all episodes: 0.21346597373485565
Epoch 98 completed, Mean Validation Loss across all episodes: 0.2143508940935135
Epoch 99 completed, Mean Validation Loss across all episodes: 0.21586613357067108




Saved new best model with validation loss: 0.21105608344078064
Epoch 100 completed, Mean Validation Loss across all episodes: 0.21105608344078064
Epoch 101 completed, Mean Validation Loss across all episodes: 0.21129897236824036
Epoch 102 completed, Mean Validation Loss across all episodes: 0.21386435627937317




Saved new best model with validation loss: 0.21104462444782257
Epoch 103 completed, Mean Validation Loss across all episodes: 0.21104462444782257




Saved new best model with validation loss: 0.20779645442962646
Epoch 104 completed, Mean Validation Loss across all episodes: 0.20779645442962646
Epoch 105 completed, Mean Validation Loss across all episodes: 0.2085815966129303
Epoch 106 completed, Mean Validation Loss across all episodes: 0.20876696705818176




Saved new best model with validation loss: 0.20778293907642365
Epoch 107 completed, Mean Validation Loss across all episodes: 0.20778293907642365




Saved new best model with validation loss: 0.20510800182819366
Epoch 108 completed, Mean Validation Loss across all episodes: 0.20510800182819366




Saved new best model with validation loss: 0.20495395362377167
Epoch 109 completed, Mean Validation Loss across all episodes: 0.20495395362377167
Epoch 110 completed, Mean Validation Loss across all episodes: 0.20558643341064453
Epoch 111 completed, Mean Validation Loss across all episodes: 0.2055923044681549




Saved new best model with validation loss: 0.20446249842643738
Epoch 112 completed, Mean Validation Loss across all episodes: 0.20446249842643738




Saved new best model with validation loss: 0.20260246098041534
Epoch 113 completed, Mean Validation Loss across all episodes: 0.20260246098041534




Saved new best model with validation loss: 0.20104193687438965
Epoch 114 completed, Mean Validation Loss across all episodes: 0.20104193687438965




Saved new best model with validation loss: 0.2003001719713211
Epoch 115 completed, Mean Validation Loss across all episodes: 0.2003001719713211




Saved new best model with validation loss: 0.19992713630199432
Epoch 116 completed, Mean Validation Loss across all episodes: 0.19992713630199432
Epoch 117 completed, Mean Validation Loss across all episodes: 0.19998955726623535
Epoch 118 completed, Mean Validation Loss across all episodes: 0.20074394345283508
Epoch 119 completed, Mean Validation Loss across all episodes: 0.2027205228805542
Epoch 120 completed, Mean Validation Loss across all episodes: 0.20427381992340088
Epoch 121 completed, Mean Validation Loss across all episodes: 0.20093555748462677




Saved new best model with validation loss: 0.19695548713207245
Epoch 122 completed, Mean Validation Loss across all episodes: 0.19695548713207245




Saved new best model with validation loss: 0.19564706087112427
Epoch 123 completed, Mean Validation Loss across all episodes: 0.19564706087112427
Epoch 124 completed, Mean Validation Loss across all episodes: 0.19689524173736572
Epoch 125 completed, Mean Validation Loss across all episodes: 0.1990009993314743
Epoch 126 completed, Mean Validation Loss across all episodes: 0.19714152812957764




Saved new best model with validation loss: 0.19457419216632843
Epoch 127 completed, Mean Validation Loss across all episodes: 0.19457419216632843




Saved new best model with validation loss: 0.19284473359584808
Epoch 128 completed, Mean Validation Loss across all episodes: 0.19284473359584808
Epoch 129 completed, Mean Validation Loss across all episodes: 0.19330359995365143
Epoch 130 completed, Mean Validation Loss across all episodes: 0.1950249969959259
Epoch 131 completed, Mean Validation Loss across all episodes: 0.19546332955360413
Epoch 132 completed, Mean Validation Loss across all episodes: 0.19441866874694824




Saved new best model with validation loss: 0.1917916238307953
Epoch 133 completed, Mean Validation Loss across all episodes: 0.1917916238307953




Saved new best model with validation loss: 0.19007551670074463
Epoch 134 completed, Mean Validation Loss across all episodes: 0.19007551670074463




Saved new best model with validation loss: 0.18909801542758942
Epoch 135 completed, Mean Validation Loss across all episodes: 0.18909801542758942




Saved new best model with validation loss: 0.1887800097465515
Epoch 136 completed, Mean Validation Loss across all episodes: 0.1887800097465515
Epoch 137 completed, Mean Validation Loss across all episodes: 0.18930158019065857
Epoch 138 completed, Mean Validation Loss across all episodes: 0.19240231812000275
Epoch 139 completed, Mean Validation Loss across all episodes: 0.19896148145198822
Epoch 140 completed, Mean Validation Loss across all episodes: 0.19741074740886688
Epoch 141 completed, Mean Validation Loss across all episodes: 0.1895127296447754




Saved new best model with validation loss: 0.18596377968788147
Epoch 142 completed, Mean Validation Loss across all episodes: 0.18596377968788147
Epoch 143 completed, Mean Validation Loss across all episodes: 0.1904224455356598
Epoch 144 completed, Mean Validation Loss across all episodes: 0.19447986781597137
Epoch 145 completed, Mean Validation Loss across all episodes: 0.18936045467853546




Saved new best model with validation loss: 0.1839856505393982
Epoch 146 completed, Mean Validation Loss across all episodes: 0.1839856505393982
Epoch 147 completed, Mean Validation Loss across all episodes: 0.1890573501586914
Epoch 148 completed, Mean Validation Loss across all episodes: 0.19281251728534698
Epoch 149 completed, Mean Validation Loss across all episodes: 0.18725790083408356




Saved new best model with validation loss: 0.1834757775068283
Epoch 150 completed, Mean Validation Loss across all episodes: 0.1834757775068283
Epoch 151 completed, Mean Validation Loss across all episodes: 0.19248972833156586
Epoch 152 completed, Mean Validation Loss across all episodes: 0.1890803873538971
Epoch 153 completed, Mean Validation Loss across all episodes: 0.18387146294116974
Epoch 154 completed, Mean Validation Loss across all episodes: 0.1841268241405487
Epoch 155 completed, Mean Validation Loss across all episodes: 0.18768349289894104




Saved new best model with validation loss: 0.18271894752979279
Epoch 156 completed, Mean Validation Loss across all episodes: 0.18271894752979279




Saved new best model with validation loss: 0.180808424949646
Epoch 157 completed, Mean Validation Loss across all episodes: 0.180808424949646
Epoch 158 completed, Mean Validation Loss across all episodes: 0.1810915172100067
Epoch 159 completed, Mean Validation Loss across all episodes: 0.18088987469673157




Saved new best model with validation loss: 0.17903189361095428
Epoch 160 completed, Mean Validation Loss across all episodes: 0.17903189361095428




Saved new best model with validation loss: 0.17816786468029022
Epoch 161 completed, Mean Validation Loss across all episodes: 0.17816786468029022




Saved new best model with validation loss: 0.17750227451324463
Epoch 162 completed, Mean Validation Loss across all episodes: 0.17750227451324463
Epoch 163 completed, Mean Validation Loss across all episodes: 0.1779015064239502
Epoch 164 completed, Mean Validation Loss across all episodes: 0.1777302324771881
Epoch 165 completed, Mean Validation Loss across all episodes: 0.1788996458053589
Epoch 166 completed, Mean Validation Loss across all episodes: 0.1782594919204712
Epoch 167 completed, Mean Validation Loss across all episodes: 0.17849493026733398
Epoch 168 completed, Mean Validation Loss across all episodes: 0.1780681610107422
Epoch 169 completed, Mean Validation Loss across all episodes: 0.1778966188430786




Saved new best model with validation loss: 0.17598707973957062
Epoch 170 completed, Mean Validation Loss across all episodes: 0.17598707973957062




Saved new best model with validation loss: 0.17504481971263885
Epoch 171 completed, Mean Validation Loss across all episodes: 0.17504481971263885




Saved new best model with validation loss: 0.174386665225029
Epoch 172 completed, Mean Validation Loss across all episodes: 0.174386665225029




Saved new best model with validation loss: 0.1733415424823761
Epoch 173 completed, Mean Validation Loss across all episodes: 0.1733415424823761




Saved new best model with validation loss: 0.17225894331932068
Epoch 174 completed, Mean Validation Loss across all episodes: 0.17225894331932068
Epoch 175 completed, Mean Validation Loss across all episodes: 0.172311469912529
Epoch 176 completed, Mean Validation Loss across all episodes: 0.17545829713344574
Epoch 177 completed, Mean Validation Loss across all episodes: 0.18274717032909393
Epoch 178 completed, Mean Validation Loss across all episodes: 0.18666215240955353
Epoch 179 completed, Mean Validation Loss across all episodes: 0.17679545283317566




Saved new best model with validation loss: 0.17019779980182648
Epoch 180 completed, Mean Validation Loss across all episodes: 0.17019779980182648
Epoch 181 completed, Mean Validation Loss across all episodes: 0.17554374039173126
Epoch 182 completed, Mean Validation Loss across all episodes: 0.18127447366714478
Epoch 183 completed, Mean Validation Loss across all episodes: 0.17577557265758514
Epoch 184 completed, Mean Validation Loss across all episodes: 0.17023026943206787
Epoch 185 completed, Mean Validation Loss across all episodes: 0.18100816011428833
Epoch 186 completed, Mean Validation Loss across all episodes: 0.17251701653003693




Saved new best model with validation loss: 0.1696205586194992
Epoch 187 completed, Mean Validation Loss across all episodes: 0.1696205586194992
Epoch 188 completed, Mean Validation Loss across all episodes: 0.17315752804279327
Epoch 189 completed, Mean Validation Loss across all episodes: 0.17115117609500885




Saved new best model with validation loss: 0.166734978556633
Epoch 190 completed, Mean Validation Loss across all episodes: 0.166734978556633
Epoch 191 completed, Mean Validation Loss across all episodes: 0.16799786686897278
Epoch 192 completed, Mean Validation Loss across all episodes: 0.170289546251297




Saved new best model with validation loss: 0.16659432649612427
Epoch 193 completed, Mean Validation Loss across all episodes: 0.16659432649612427




Saved new best model with validation loss: 0.16490554809570312
Epoch 194 completed, Mean Validation Loss across all episodes: 0.16490554809570312
Epoch 195 completed, Mean Validation Loss across all episodes: 0.1654222309589386
Epoch 196 completed, Mean Validation Loss across all episodes: 0.16638043522834778
Epoch 197 completed, Mean Validation Loss across all episodes: 0.16589495539665222


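## 6. Adapt the model to the target area
After we get the model that is meta-trained on Alexander and Rowancreek using MAML, we adapt it to the target location (Covington) on its support set and evaluate it on its query set.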

In [None]:
def adapt_to_new_task(base_model, support_data, support_labels, inner_lr=0.001, inner_steps=1):
    """Fine-tune a copy of ``base_model`` on one task's support set.

    The architecture of ``base_model`` is cloned and its current weights are
    copied into the clone; all gradient updates are applied to the clone.

    Args:
        base_model: Keras model providing the architecture and starting weights.
        support_data: Inputs fed to the model, in a single call, at every inner step.
        support_labels: Targets compared with the predictions through ``dice_loss``.
        inner_lr: Learning rate of the Adam optimizer used for the adaptation.
        inner_steps: Number of gradient updates to perform.

    Returns:
        The adapted clone of ``base_model``.
    """
    adapted = tf.keras.models.clone_model(base_model)
    adapted.set_weights(base_model.get_weights())
    adam = tf.keras.optimizers.Adam(learning_rate=inner_lr)

    for i in range(inner_steps):
        # Record the forward pass so the loss can be differentiated.
        with tf.GradientTape() as tape:
            loss = dice_loss(support_labels, adapted(support_data))
        weights = adapted.trainable_variables
        grads = tape.gradient(loss, weights)
        adam.apply_gradients(zip(grads, weights))
        # The reported loss is the one computed *before* this step's update.
        print(f"Inner Step: {i+1}, Loss: {loss.numpy()}")
    return adapted

def evaluate_adapted_model(model, query_data, query_labels):
    """Compute the Dice loss of ``model`` on a query set.

    Args:
        model: Callable model applied to ``query_data``.
        query_data: Inputs of the query set.
        query_labels: Targets compared with the predictions through ``dice_loss``.

    Returns:
        The loss tensor converted with ``.numpy()``.
    """
    return dice_loss(query_labels, model(query_data)).numpy()

In [None]:
# Build the meta-testing episodes from the location that was not used for meta-training.
locations_meta_testing = ['Covington']
# Alternative location lists kept for quick experiments:
# locations_meta_testing = ['Rowancreek']
# locations_meta_training = ['Alexander', 'Rowancreek']
# The first argument is the number of episodes to create (4 here).
meta_test_episodes = data_loader.create_multi_episodes(4, locations_meta_testing)


Covington
Covington
Covington
Covington


In [None]:
name = "maml_1_500_1_20240503_201907_best_model"
base_model = tf.keras.models.load_model('models/'+name)

# Adapt the meta-trained model to every meta-test episode and keep the
# adaptation with the lowest loss on its own query set. Previously the model
# of the *last* episode was saved, whatever its loss was.
best_adapted_model = None
best_test_loss = float('inf')
for test_episode in meta_test_episodes:
    adapted_model = adapt_to_new_task(base_model, test_episode["support_set_data"], test_episode["support_set_labels"], inner_lr=0.001, inner_steps=400)
    test_loss = evaluate_adapted_model(adapted_model, test_episode["query_set_data"], test_episode["query_set_labels"])
    print(f"Test Loss after adaptation: {test_loss}")
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        best_adapted_model = adapted_model

if best_adapted_model is None:
    # Happens when there are no episodes, or when every loss was NaN.
    raise ValueError("No adapted model was selected: meta_test_episodes is empty or every test loss was NaN.")

# Later cells predict with `adapted_model`, so point it at the model that is saved.
adapted_model = best_adapted_model
adapted_model.save('models/'+name+'/adapted_covington')  # Save the best model



Inner Step: 1, Loss: 0.7070581912994385
Inner Step: 2, Loss: 0.6860436201095581
Inner Step: 3, Loss: 0.648269772529602
Inner Step: 4, Loss: 0.616706132888794
Inner Step: 5, Loss: 0.6115401983261108
Inner Step: 6, Loss: 0.6031354665756226
Inner Step: 7, Loss: 0.588828980922699
Inner Step: 8, Loss: 0.576479434967041
Inner Step: 9, Loss: 0.5720885992050171
Inner Step: 10, Loss: 0.5728148221969604
Inner Step: 11, Loss: 0.5644044876098633
Inner Step: 12, Loss: 0.555587887763977
Inner Step: 13, Loss: 0.5521668195724487
Inner Step: 14, Loss: 0.5511040687561035
Inner Step: 15, Loss: 0.5485281944274902
Inner Step: 16, Loss: 0.5442174077033997
Inner Step: 17, Loss: 0.5417673587799072
Inner Step: 18, Loss: 0.5420837998390198
Inner Step: 19, Loss: 0.5413864850997925
Inner Step: 20, Loss: 0.5384175777435303
Inner Step: 21, Loss: 0.536340594291687
Inner Step: 22, Loss: 0.5360579490661621
Inner Step: 23, Loss: 0.5344541072845459
Inner Step: 24, Loss: 0.5314226746559143
Inner Step: 25, Loss: 0.5293649

In [None]:
%cd '/content/drive/MyDrive/Meta_learning_research/Notebooks/'
input_data = './samples/'
location = "Covington"  # was misspelled "Covinton"

# Build the path from the variables above instead of repeating an absolute
# literal. The relative path resolves against the notebook directory selected
# by the %cd at the top of this cell.
test_data_path = os.path.join(input_data, location, 'bottom_half_test_data.npy')
X_test = np.load(test_data_path).astype(np.float32)
# Clamp negative values to zero (presumably no-data markers; TODO confirm).
X_test[X_test < 0] = 0

/content/drive/MyDrive/Meta_learning_research/Notebooks


In [None]:
# This normalization_type was defined at the top of the notebook for the data loader.
# Accepted values:
#   '0'    -> scale with (x - 0) / 255      (range [0, 1] when raw values lie in [0, 255])
#   '-1'   -> scale with 2 * (x - 0) / 255 - 1  (range [-1, 1] under the same assumption)
#   'none' -> use the data unchanged
if normalization_type == 'none':
    X_test_norm = X_test
elif normalization_type in ('0', '-1'):
    data_min = 0
    data_max = 255
    X_test_unit = (X_test - data_min) / (data_max - data_min)
    if normalization_type == '0':
        X_test_norm = X_test_unit
    else:
        X_test_norm = 2 * X_test_unit - 1
else:
    # The message lists the values this cell really accepts.
    raise ValueError(
        f"Unsupported normalization type {normalization_type!r}. Choose '0', '-1' or 'none'."
    )

In [None]:
# Run the adapted model on the normalized test tiles.
prediction = adapted_model.predict(X_test_norm)

# Compute the output path once, then save the array and report where it went.
prediction_path = '/content/drive/MyDrive/Meta_learning_research/Notebooks/predicts/'+name+".npy"
np.save(prediction_path, prediction)
print(prediction_path)

/content/drive/MyDrive/Meta_learning_research/Notebooks/predicts/maml_1_500_1_20240503_201907_best_model.npy


In [None]:
# Re-print the prediction path built from the current value of `name`.
# NOTE(review): the stored output of this cell shows a different run name than
# the `name` assigned above, so it looks like it came from an earlier run; re-run to refresh.
print('/content/drive/MyDrive/Meta_learning_research/Notebooks/predicts/'+name+".npy")

/content/drive/MyDrive/Meta_learning_research/Notebooks/predicts/maml_1_500_1_20240430_013450_best_model.npy


# Visualization for the input data

In [None]:
def visualize_images(image_stack, labels, num_images=25):
    """
    Visualizes images with their respective channels and labels.

    One row is drawn per image: every channel in its own grayscale panel,
    followed by the label image in the last column.

    Args:
    - image_stack (numpy.ndarray): An array of shape (N, H, W, C), e.g. (N, 224, 224, 8), where N is the number of images.
    - labels (numpy.ndarray): An array of labels of shape (N, H, W).
    - num_images (int): Maximum number of images to visualize. Fewer rows are drawn
      when the stack holds fewer images; nothing is drawn when there are none.
    """
    num_rows = min(num_images, image_stack.shape[0])
    if num_rows <= 0:
        return  # nothing to draw; plt.subplots rejects a grid with zero rows

    num_channels = image_stack.shape[-1]
    num_cols = num_channels + 1  # one panel per channel plus one for the label

    # squeeze=False keeps `axs` two-dimensional even when a single row is requested,
    # so the axs[row, col] indexing below always works.
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols,
                            figsize=(2 * num_cols, 2 * num_rows), squeeze=False)

    for i in range(num_rows):
        # Each channel is shown separately with the grayscale colormap
        for ch in range(num_channels):
            axs[i, ch].imshow(image_stack[i, :, :, ch], cmap='gray', aspect='auto')
            axs[i, ch].axis('off')  # Turn off axis
        # Adding the label image in the last column
        axs[i, num_channels].imshow(labels[i], cmap='gray', aspect='auto')
        axs[i, num_channels].axis('off')

    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    plt.show()

In [None]:
# Plot the support set of the selected meta-test episode.
eps_idx = 0
visualize_images(
    image_stack=meta_test_episodes[eps_idx]["support_set_data"],
    labels=meta_test_episodes[eps_idx]["support_set_labels"],
    num_images=25,
)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Plot the query set of the selected meta-test episode.
eps_idx = 0
visualize_images(
    image_stack=meta_test_episodes[eps_idx]["query_set_data"],
    labels=meta_test_episodes[eps_idx]["query_set_labels"],
    num_images=25,
)