## Set up paths and imports

In [None]:
import os

import torch
from torchvision import transforms
import matplotlib.pyplot as plt

# If "./notebooks" is not visible from the current working directory, assume the
# kernel was started inside notebooks/ and move up one level so that `import src...`
# and the relative paths used later (./models, datasets/...) resolve.
# NOTE(review): presumably the parent is the project root — confirm.
# `%cd` is an IPython magic, so this cell only runs under IPython/Jupyter.
if not os.path.exists("./notebooks"):
    %cd ..

import src.model
from PIL import Image
from src.data_processing import load_mean_std
from src.config import DATASET_DIR
from src.dataset_analysis import plot_spectrogram

## Load the Model and a Sample Image

In [None]:
from src.dataset import prepare_dataset_loaders

# Run on the GPU when available; the model is placed on this device below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Restore the trained weights. `map_location` lets a checkpoint that was saved
# on a GPU load on a CPU-only machine (and vice versa); `weights_only=True`
# restricts unpickling to tensors and plain containers.
name = "OriginalSizeCNN"
model = src.model.OriginalSizeCNN()
model_path = f"./models/{name}.pth"
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.to(device)
# Inference mode: disables dropout and freezes batch-norm statistics (if the
# model has any) so the visualized activations reflect the trained network.
model.eval()

# Normalization statistics read from the dataset's scaling_params.json.
mean, std = load_mean_std(f"{DATASET_DIR}/scaling_params.json")

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Load one test spectrogram as single-channel grayscale ("L") and normalize it
# with the loaded mean/std. ToTensor on an "L" image yields a (1, H, W) tensor,
# i.e. there is no batch dimension.
image_path = "datasets/test/f8_script1_ipad_office1_25_clip.png"
with Image.open(image_path) as img:
    sample_image = transform(img.convert("L"))
plot_spectrogram(plt.imread(image_path), "sample_image")

# Alternative sample source: the first item of a test DataLoader built with
# prepare_dataset_loaders (note: `test_loader` is not created in this notebook).
#sample_image = test_loader.dataset[0][0]


## Get the Feature Map of the First Convolutional Layer

In [None]:
# Wrap the trained model to obtain the activations of the layer named "conv1".
# NOTE(review): behavior of src.model.ModelWithLayerOutput (including how it
# uses `.device`) is assumed from its name and usage — confirm.
c1_layer = src.model.ModelWithLayerOutput(model, "conv1")
c1_layer.device = device
# No gradients are needed for visualization. `.cpu()` is required before
# `.numpy()` whenever the activations live on a CUDA device (no-op on CPU).
with torch.no_grad():
    output = c1_layer(sample_image).detach().cpu().numpy()

## Visualize feature map

In [None]:
def plot_feature_map(feature_map, max_grid):
    """Plot the leading channels of a feature map on a square grid.

    Args:
        feature_map: Array indexed as ``feature_map[channel, row, col]``
            (channels first), e.g. the NumPy activations of one conv layer.
        max_grid: Number of rows and columns of the grid; at most
            ``max_grid ** 2`` channels are drawn, in channel order,
            filling the grid row by row.

    When the feature map has fewer channels than grid cells, the remaining
    cells are left blank instead of raising ``IndexError``.
    """
    n_channels = feature_map.shape[0]
    n_shown = min(max_grid ** 2, n_channels)

    # squeeze=False keeps `ax` a 2-D array of Axes even when max_grid == 1.
    fig, ax = plt.subplots(max_grid, max_grid, figsize=(7, 7), squeeze=False)

    # ax.flat walks the grid row-major: (0,0), (0,1), ..., (1,0), ...
    for channel_idx, axis in enumerate(ax.flat):
        if channel_idx < n_shown:
            axis.imshow(feature_map[channel_idx, :, :])
        axis.axis('off')

    fig.suptitle(f'Feature Map - Displaying {n_shown} of {n_channels} Channels')
    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()

In [None]:
# Show the conv1 activations on a 4 x 4 grid (up to 16 channels).
plot_feature_map(feature_map=output, max_grid=4)

## Compare with the Feature Maps of the Third Convolutional Layer

In [None]:
# Same extraction for the layer named "conv3"; its channels are shown on an
# 8 x 8 grid (up to 64 channels).
c3_layer = src.model.ModelWithLayerOutput(model, "conv3")
c3_layer.device = device
# No gradients are needed for visualization. `.cpu()` is required before
# `.numpy()` whenever the activations live on a CUDA device (no-op on CPU).
with torch.no_grad():
    output = c3_layer(sample_image).detach().cpu().numpy()
plot_feature_map(output, 8)