In [None]:
import pycochleagram.cochleagram as cgram
from pycochleagram import utils

import json
import sys
import os 
from os import listdir

import matplotlib.pyplot as plt
import numpy as np

from sklearn.model_selection import KFold
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms

from sound_dataset import SoundDataset_kfold
from AE_architectures import AE_RNN

In [None]:
"""Compute the normalised mean squared error (NMSE) and plot reconstructions for a given trained model."""

In [None]:
# Directory that holds the cochleagram arrays (*.npy). Machine-specific: change before running.
path="/Users/ariellerosinski/My Drive/UCL/MSc/Project/Code/full_coch" #path to datasets (needs to be changed)

sound_files=listdir(path)

# Keep every "full*.npy" array except full_forest_coch.npy, which is explicitly skipped.
# NOTE(review): os.listdir returns entries in arbitrary order, so the window indices derived
# from `files` can differ between machines. Sorting would make them reproducible but would
# also change those indices, so confirm against any previously saved results before doing so.
file_name=[
    entry for entry in sound_files
    if entry.endswith(".npy") and entry.startswith("full") and entry != "full_forest_coch.npy"
]

# Full paths, index-aligned with file_name (get_windows looks names up by position).
files=[os.path.join(path, entry) for entry in file_name]

k_folds=5
# shuffle=False: indices are split into consecutive, unshuffled folds.
kfold = KFold(n_splits=k_folds, shuffle=False)
foldid=0


def get_windows(files, kfold, fold_id, file_name, split_size=100, norm=True, training=True):
    """Load each array, keep one side of one k-fold split, and cut it into fixed-width windows.

    Parameters
    ----------
    files : sequence of str
        Paths to 2-D ``.npy`` arrays. The second axis is the one that is split
        into folds and windowed (presumably time -- TODO confirm).
    kfold : object
        Anything with a ``split(X)`` method yielding ``(train_ids, test_ids)``
        pairs, e.g. ``sklearn.model_selection.KFold``.
    fold_id : int
        Index of the split to use.
    file_name : sequence of str
        Base names, index-aligned with ``files``; only used to pick the
        normalisation divisor.
    split_size : int
        Number of columns per window. A trailing remainder shorter than this
        is dropped.
    norm : bool
        When true, divide each array by a per-file constant: 2500 for names
        ending in ``"coch.npy"``, 1000 for ``"full_pennington_david.npy"``,
        5000 otherwise.
    training : bool
        True selects the train indices of the split, False the test indices.

    Returns
    -------
    list of numpy.ndarray
        Windows of shape ``(n_rows, split_size)``, ordered by file and then
        left to right. Empty when ``fold_id`` matches no split.
    """
    windows=[]
    for i, file_path in enumerate(files):
        coch=np.load(file_path)

        # Folds are computed along the second axis, hence the transpose.
        for fold, (train_ids, test_ids) in enumerate(kfold.split(coch.T)):
            if fold != fold_id:
                continue

            if norm:
                if file_name[i].endswith("coch.npy"):
                    divisor=2500
                elif file_name[i] == "full_pennington_david.npy":
                    divisor=1000
                else:
                    divisor=5000
                coch=coch/divisor

            indices=train_ids if training else test_ids
            selected=coch.T[indices].T

            # Non-overlapping windows; the integer division discards any partial tail.
            n_windows=selected.shape[1]//split_size
            windows.extend(
                selected[:, k*split_size:(k+1)*split_size] for k in range(n_windows)
            )
            break  # only one split can match fold_id, no need to generate the rest

    return windows

In [None]:
#Evaluation for a saved, trained model (need to alter save paths)
loss_function=nn.MSELoss()   # not used by the NMSE computation in this cell

hidden_size=32
time_lag=99
burn_in=0
foldid=0
n_last_steps=30   # number of final time steps used for the overall NMSE

device = torch.device('cpu')

# Held-out (test) windows of fold `foldid`, fed one window per batch and in a fixed order.
test_windows=get_windows(files, kfold, foldid, file_name, norm=True, training=False)
test_data=SoundDataset_kfold(test_windows, transform = transforms.Compose([transforms.ToTensor(),]),uint8=False )
loader = DataLoader(test_data,batch_size=1,shuffle=False)

model = AE_RNN(time_lag=time_lag, burn_in=burn_in, hidden_size=hidden_size).float()
model_path="/path_to_.pt_model"
model.load_state_dict(torch.load(model_path, map_location=device))
# Evaluation mode: disables dropout and uses running batch-norm statistics if the model
# contains such layers (no effect otherwise).
model.eval()

ls_image = []
ls_reconstructed = []

# Gradients are not needed for evaluation; no_grad avoids building the autograd graph.
with torch.no_grad():
    for batch in loader:
        image = batch[0].float()

        # Zero initial hidden state, sized from the batch dimension of the input.
        initialization=torch.zeros((1,image.shape[0],hidden_size))

        reconstructed=model.forward_train(image,initialization,device)

        ls_image.append(image)
        ls_reconstructed.append(reconstructed.detach())

# NOTE(review): the final window is excluded by [:-1]. With batch_size=1 every batch has the
# same shape, so this looks like a leftover from a larger batch size -- confirm before
# changing, since it alters the number of evaluated windows.
image_all=np.vstack(ls_image[:-1]).squeeze()
reconstructed_all=np.vstack(ls_reconstructed[:-1])
reconstructed_all_flipped=np.flip(reconstructed_all,axis=2)                     #flip because autoencoder reconstructs from most to least recent versus original image (least to more recent)
reconstructed_comp=np.swapaxes(reconstructed_all_flipped, -1, -2).squeeze()

#Normalize by the mean square of the original cochleagrams over the same time steps, across the whole test set
target_last=image_all[:,:,-n_last_steps:]
MSE=((target_last-reconstructed_comp[:,:,-n_last_steps:])**2).mean()
norm = (target_last**2).mean()
mse_norm = MSE/norm


#NMSE as a function of time: reconstruction column t is compared with input column t+1
ls_loss_fun_time=[]
for t in range(reconstructed_comp.shape[-1]):
    target_t=image_all[:,:,t+1]
    mse_t=((target_t-reconstructed_comp[:,:,t])**2).mean()
    ls_loss_fun_time.append(mse_t/(target_t**2).mean())


np.save("/save path/image_all", image_all)
np.save("/save path/reconstructed_all_flipped", reconstructed_all_flipped)
np.save("/save path/mse_norm", mse_norm)
np.save("/save path/loss_fun_time", np.array(ls_loss_fun_time))

In [None]:
#Template for plotting cochleagrams
#Note 1: example for 3 image ids, but more examples could be plotted, e.g., randomly selected ids to evaluate model reconstructions
#Note 2: this only compares the original cochleagram with reconstructions from one model; to integrate other models (e.g., time scrambled), add a title to `cols` and a matching list of images to `ls_panels`
#Note 3: the plotting code below would be similar for reconstruction of the last 30 time steps, except for small changes e.g., tick_labels = [0, 100, 200, 300, 400, 500, 600], ls_img[i][:,-30:] instead of ls_img[i][:,:]

cochleagrams = image_all
ae_reconstructions = np.load("/Users/ariellerosinski/My Drive/UCL/MSc/Project/Thesis/reconstruction_analyses/ae_reconstructed_all_flipped_3.npy").squeeze()
ae_reconstructions = np.swapaxes(ae_reconstructions, -1, -2)

# Indices into the test-set windows of the examples to display.
bird_id = 317
frog_id = 53
squirrel_id =903

ls_img = [cochleagrams[bird_id],cochleagrams[frog_id], cochleagrams[squirrel_id]]
ls_ae_recons = [ae_reconstructions[bird_id], ae_reconstructions[frog_id], ae_reconstructions[squirrel_id]]

# Frequency (Hz) used to label each of the 81 cochleagram rows on the y axis.
freq= [35.09722644, 50.0, 65.74421619, 82.37738465, 99.94969756, 118.51418101, 138.1268551, 158.84690291, 180.73684916, 203.86274881, 228.29438645, 254.10548686, 281.37393747, 310.1820234, 340.61667578, 372.76973403, 406.73822305, 442.62464596, 480.53729343, 520.59057046, 562.90534158, 607.60929561, 654.83733097, 704.73196269, 757.44375255, 813.13176334, 871.9640389, 934.11811118, 999.78153594, 1069.15245879, 1142.44021301, 1219.86595135, 1301.66331326, 1388.07913001, 1479.37416949, 1575.8239231, 1677.71943707, 1785.36819074, 1899.09502439, 2019.24311947, 2146.1750342, 2280.27379762, 2421.94406542, 2571.61334099, 2729.7332655, 2896.78098077, 3073.26056903, 3259.70457412, 3456.67560841, 3664.7680506, 3884.60983927, 4116.86436779, 4362.23248612, 4621.45461572, 4895.31298385, 5184.63398397, 5490.29066953, 5813.20538843, 6154.35256633, 6514.76164703, 6895.52019897, 7297.77719703, 7722.74648967, 8171.71046186, 8646.02390477, 9147.118104, 9676.50515859, 10235.78254397, 10826.63793245, 11450.85428595, 12110.31523625, 12807.01076902, 13543.04322881, 14320.63366307, 15142.12852432, 16010.00675089, 16926.88724725, 17895.53678687, 18918.87836118, 20000.0, 21142.16408994]
ls_freq_2f=["%.2f" % value for value in freq]

# One title and one list of images per column, in the same order.
cols = ["Input cochleagram", "Reconstruction"]
ls_panels = [ls_img, ls_ae_recons]

n_plots=len(ls_img)
n_cols=len(cols)
fontsize=12

# Tick setup is identical for every row, so it is computed once.
default_y_ticks = range(len(freq))
default_xticks = range(99)
num_ticks = 5
tick_positions = np.linspace(0, len(default_xticks) - 1, num_ticks)
tick_labels = [0, 500, 1000, 1500, 2000]

# squeeze=False keeps axs two-dimensional even when a single example is plotted.
fig, axs = plt.subplots(n_plots, n_cols, figsize=(15,8), squeeze=False)
for i in range(n_plots):
    is_bottom_row = i == (n_plots-1)
    for j, panels in enumerate(ls_panels):
        ax = axs[i, j]
        is_first_col = j == 0

        ax.imshow(panels[i][:,:],aspect="auto", cmap='magma')                #could set vmin=0,vmax=0.7
        # Red marker at column 70; presumably the start of the last 30 time steps -- confirm.
        ax.axvline(x=70, color='red')

        # Frequency axis is labelled on the first column only.
        if is_first_col:
            ax.set_yticks(default_y_ticks,ls_freq_2f,fontsize=fontsize)
            ax.locator_params(axis='y', nbins=4)
            ax.set_ylabel("Frequency [Hz]",fontsize=fontsize)

        # Time axis is labelled on the bottom row only.
        if is_bottom_row:
            ax.set_xticks(tick_positions, tick_labels,fontsize=fontsize)
            ax.set_xlabel("Time [ms]",fontsize=fontsize)

        ax.tick_params(bottom=is_bottom_row, top=False, left=is_first_col, right=False,
                       labelbottom=is_bottom_row, labeltop=False, labelleft=is_first_col, labelright=False)

for ax, col in zip(axs[0], cols):
    ax.set_title(col,fontsize=13)
plt.subplots_adjust(hspace=0.3)