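This notebook investigates the features learned by TiSSNet: convolution filters, their outputs (feature maps), gradient-based heatmaps, and optimization of an input image toward a target output.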

In [None]:
import numpy as np
import scipy
import scipy.stats
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras import backend as K
from keras import Model
from keras.src.optimizers import Adam
from matplotlib import pyplot as plt

from utils.training.data_loading import load_spectro
from utils.training.keras_models import TiSSNet

## Parameters

In [None]:
epoch = 22  # epoch checkpoint that we want to load
# Path of the TiSSNet checkpoint to load. The epoch is zero-padded to 4 digits
# (epoch 22 -> cp-0022.ckpt). Single braces are required here: doubled braces in an
# f-string would produce the literal text "{epoch:04d}" instead of the epoch number.
CHECKPOINT = f"../../../../data/model_saves/TiSSNet/cp-{epoch:04d}.ckpt"
SIZE = (128, 186)  # shape of the input images (rows, columns)

model = TiSSNet(SIZE)
model.load_weights(CHECKPOINT)
model.summary()

## Initialization: listing the conv layers

In [None]:
# Keep only the convolution layers (max-pooling, flattening, ... are skipped)
# and print the name and kernel shape of each one.
conv_layers = []
for layer in model.layers:
    if 'conv' in layer.name:
        filters, biases = layer.get_weights()
        conv_layers.append(layer)
        print(layer.name, filters.shape)

## Filters visualization

In [None]:
block_id = 0  # block we want to inspect (blocks are layers where data keeps the same shape)
conv_id = 1  # conv layer we want to inspect in the chosen block (from 0 to 2)
n_filters, n_channels = 16, 16  # nb of filters to inspect, and nb of input channels shown for each
# reminder : a layer with n filters has n channels as output, but maybe more or less as input

layer = 3*block_id+conv_id  # assumes 3 conv layers per block (conv_id in 0..2)
# Keras Conv2D kernels are laid out (kernel_h, kernel_w, in_channels, filters)
filters = conv_layers[layer].get_weights()[0]

# never ask for more filters / channels than the kernel actually has
n_rows = min(n_filters, filters.shape[3])
n_cols = min(n_channels, filters.shape[2])

# normalize all weights of the layer to [0, 1]; guard against a constant kernel (0/0)
f_min, f_max = filters.min(), filters.max()
f_range = f_max - f_min
filters = (filters - f_min) / f_range if f_range > 0 else np.zeros_like(filters)

# one row per filter, one column per input channel
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 20), squeeze=False)
for i in range(n_rows):
    for j in range(n_cols):
        ax = axes[i, j]
        ax.set_xticks([])
        ax.set_yticks([])
        # fixed vmin/vmax so every panel shares the global gray scale computed above
        # (otherwise imshow rescales each panel independently and the normalization is lost)
        ax.imshow(filters[:, :, j, i], cmap='gray', vmin=0, vmax=1)
# show the figure
plt.show()

## Filter outputs visualization (feature maps)

In [None]:
# Spectrogram used by the next cells; ".png" is a placeholder to replace with a real file.
img_path = ".png"
img = load_spectro(img_path, SIZE, 1)  # prepared the same way as the model's inputs, presumably

# display the loaded spectrogram
plt.imshow(img, cmap="inferno")

In [None]:
block_id = 0  # block we want to inspect (blocks are layers where data keeps the same shape)
conv_id = 0  # conv layer we want to inspect in the chosen block (from 0 to 2)
square = 2  # side of the figure. square^2 convolution outputs will be shown

layer = 3*block_id+conv_id  # assumes 3 conv layers per block (conv_id in 0..2)
# sub-model stopping at the chosen conv layer: predict() returns that layer's activations
temp_model = Model(inputs=model.inputs, outputs=conv_layers[layer].output)
feature_maps = temp_model.predict(img.numpy().reshape(1, SIZE[0], SIZE[1], 1))

# do not index past the last channel when square**2 exceeds the layer's filter count
n_maps = min(square * square, feature_maps.shape[-1])
# display aspect: 1 for block 0, 2 for block 1, then x4 for each further block.
# Presumably compensates for pooling shrinking one axis faster than the other,
# so the maps keep the input's proportions — TODO confirm against the TiSSNet architecture.
aspect = 2**(min(block_id, 1)) * 4**(max(block_id - 1, 0))

fig, axes = plt.subplots(square, square, figsize=(16, 12), squeeze=False)
for map_idx, ax in enumerate(axes.flat):  # row-major, same order as the original plot index
    ax.set_xticks([])
    ax.set_yticks([])
    if map_idx < n_maps:
        ax.imshow(feature_maps[0, :, :, map_idx], cmap='gray', aspect=aspect)
    else:
        ax.axis('off')  # unused panel
fig.tight_layout()
# file name follows the inspected layer (layer 0 -> conv1_features.png)
fig.savefig(f"../../../../data/figures/conv{layer + 1}_features.png", dpi=150, bbox_inches='tight')
plt.show()

## Heatmaps: contribution of each feature-map pixel to the output (activation × gradient)

In [None]:
block_id = 0  # block we want to inspect (blocks are layers where data keeps the same shape)
conv_id = 0  # conv layer we want to inspect in the chosen block (from 0 to 2)
square = 2  # side of the figure. square^2 convolution outputs will be shown

layer = 3*block_id+conv_id  # assumes 3 conv layers per block (conv_id in 0..2)

# model returning both the chosen conv layer's activations and the final prediction
heatmap_model = Model(inputs=model.inputs, outputs=[conv_layers[layer].output, model.output])
with tf.GradientTape() as gtape:
    conv_output, predictions = heatmap_model(img.numpy().reshape(1, SIZE[0], SIZE[1], 1))
    # score to explain: the highest value of the (first) output vector
    target_idx = int(tf.argmax(predictions[0]))
    score = predictions[0, target_idx]
# computed outside the recording context so the tape does not record its own gradient ops
grads = gtape.gradient(score, conv_output)

conv_np = conv_output.numpy()
grads_np = grads.numpy()
# do not index past the last channel when square**2 exceeds the layer's filter count
n_maps = min(square * square, conv_np.shape[-1])

# plot the resulting heatmaps
fig, axes = plt.subplots(square, square, figsize=(16, 12), squeeze=False)
for map_idx, ax in enumerate(axes.flat):
    ax.axis('off')
    if map_idx >= n_maps:
        continue
    # multiply the output of the filter by its gradient, keeping positive contributions only.
    # Uses the map's own spatial shape, which is smaller than SIZE after pooling in deeper blocks.
    heatmap = np.maximum(grads_np[0, :, :, map_idx] * conv_np[0, :, :, map_idx], 0)
    max_heat = heatmap.max()
    if max_heat > 0:  # an all-zero map would otherwise become NaN (0/0)
        heatmap = heatmap / max_heat
    ax.imshow(heatmap, cmap='jet')
fig.tight_layout()
# file name follows the inspected layer (layer 0 -> conv1_heatmaps.png)
fig.savefig(f"../../../../data/figures/conv{layer + 1}_heatmaps.png", dpi=150, bbox_inches='tight')
plt.show()

## Output fitting
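Starting from an input image and a target output, iteratively modify the image by gradient descent so that the model's output gets as close as possible (in mean squared error) to the target. The model weights stay fixed; only the image is optimized.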

In [None]:
# we define a gaussian that the model will try to match
mean = 60
var = 4
# now choose an input image from which the optimization process will start
start_img_path = ".png"

law = scipy.stats.norm(mean, var)
expected_output = [law.pdf(i) for i in range(SIZE[1])]

# plot the output we want the network to give
plt.plot(expected_output)
plt.xlim(0, SIZE[1])
plt.show()


start_img = load_spectro(start_img_path, SIZE, 1)

# show the input image the network starts with
plt.imshow(start_img, cmap="inferno")
plt.show()

In [None]:
# optimizer
# optimizer acting on the input image (the model weights are never passed to it)
optimizer = Adam(learning_rate=0.1)
num_iterations = 1000

# assign the image as a tf variable for optimization
input_to_optimize = tf.Variable(np.copy(start_img).reshape((1, SIZE[0], SIZE[1], 1)))

# define the loss function as the difference between the actual output and the wanted output
@tf.function
def loss():
    """Mean squared error between the model's output for `input_to_optimize` and `expected_output`."""
    output = model(input_to_optimize, training=False)
    return tf.keras.losses.mean_squared_error(expected_output, output)

# optimization loop
for _ in (pbar := tqdm(range(num_iterations))):
    # only the input image is watched: no gradient bookkeeping for the model weights
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(input_to_optimize)
        output = model(input_to_optimize, training=False)
        loss_value = tf.keras.losses.mean_squared_error(expected_output, output)

    # gradient and update are done outside the tape context so they are not recorded
    grad = tape.gradient(loss_value, input_to_optimize)
    optimizer.apply_gradients(zip([grad], [input_to_optimize]))

    # show the (scalar) loss in the progress bar instead of a full tensor repr
    pbar.set_postfix(loss=f"{float(tf.reduce_mean(loss_value)):.3e}")

print("Final loss:", loss().numpy())

In [None]:
# take a look at the modified image
new_img = input_to_optimize.numpy().reshape((128,186,1))
plt.imshow(new_img, cmap="inferno")
plt.show()
# plot the current output of the network given this image as input, together with the wanted output
plt.plot(expected_output, label="objective")
plt.plot(model.predict(input_to_optimize).reshape(SIZE[1]), label="optimization result")
plt.legend(loc="upper right")
plt.show()