In [1]:
import os
import shutil
import importlib
import copy
import glob
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time
import torch
import torchaudio
from tqdm import tqdm  # progress bar
import phaselocknet_model
import util
import h5py
from util import get_hdf5_dataset_key_list

# Re-import the project modules so on-disk edits are picked up without
# restarting the kernel.
importlib.reload(phaselocknet_model)
importlib.reload(util)
# importlib.reload() does not refresh names that were bound with
# `from util import ...`; rebind the helper so it refers to the reloaded module
# rather than the stale pre-reload function object.
get_hdf5_dataset_key_list = util.get_hdf5_dataset_key_list

# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [2]:
# Stimulus file plus the two sampling rates used for resampling later on:
# `sr_src` is the rate the stimuli are treated as having, `sr` is the target
# rate (50 kHz when the path contains "localization", 20 kHz otherwise).
regex_filenames = "/media/marmoset/data/code_package/phaselock/localization/spectral_smoothing_stim.hdf5"
sr_src = 44100
sr = 50000 if "localization" in regex_filenames else 20000

# Select the HDF5 datasets handed to the loader: every dataset with a numeric
# dtype, except 'sr', which is dropped explicitly.
# NOTE(review): dataset lengths are not compared here; 'sr' is presumably
# excluded because it does not have one entry per stimulus -- confirm.
with h5py.File(regex_filenames, "r") as f:
    all_keys = get_hdf5_dataset_key_list(f)
    keys = [
        name
        for name in all_keys
        if np.issubdtype(f[name].dtype, np.number) and name != "sr"
    ]

print("Using keys:", keys)

dataset = util.HDF5Dataset(regex_filenames, keys=keys)

Using keys: ['dbspl', 'foreground_azimuth', 'foreground_elevation', 'foreground_wav_clip_pos', 'idx', 'index', 'noise_high', 'noise_low', 'signal', 'smoothed']


In [3]:
# One model directory per architecture (arch01 ... arch10).
dir_model_all = [
    f"../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch{i:02d}"
    for i in range(1, 11)
]

# Name of the submodule whose output is recorded for every stimulus.
tap_name = "perceptual_model.body.fc_intermediate"
# tap_name = "perceptual_model.head.label_loc_int.fc_output"

# Only stimuli at these azimuths are processed (front and back at median plane).
median_plane_azimuths = [0, 180]

# The resampler depends only on the two sampling rates, so it is built once
# instead of once per model.
resampler = torchaudio.transforms.Resample(orig_freq=sr_src, new_freq=sr)

# Scratch dict written by the forward hook on every forward pass.
features = {}


def hook(_m, _in, output):
    """Forward hook: store the tapped submodule's output under `tap_name`."""
    features[tap_name] = output.detach()


# `m` stays a 0-based index because it is used (as m+1) in the output filenames.
for m in tqdm(range(len(dir_model_all))):
    dir_model = dir_model_all[m]

    model, config_model = phaselocknet_model.get_model(
        dir_model=dir_model,
        fn_config="config.json",
        fn_arch="arch.json",
    )
    # Load model weights from torch instead of tensorflow checkpoint
    util.load_model_checkpoint(
        model=model.perceptual_model,
        dir_model=dir_model,
        fn_ckpt="ckpt_BEST.pt",
        weights_only=True,
    )
    model.to(device)
    model.eval()  # inference only: switch off training-mode behaviour
    assert not model.training

    all_feats, all_azimuth, all_elevation, all_smoothed = [], [], [], []
    hook_handle = model.get_submodule(tap_name).register_forward_hook(hook)
    try:
        with torch.inference_mode():
            for i in range(len(dataset)):
                sample = dataset[i]  # dict per sample
                if not np.isin(sample["foreground_azimuth"], median_plane_azimuths):
                    continue
                x_raw = torch.tensor(sample["signal"], dtype=torch.float32)[None, ...]  # (1, T, C)
                # Resample each channel separately, then restore the channel axis.
                x_pre = torch.stack(
                    [resampler(x_raw[..., ch]) for ch in range(x_raw.shape[-1])],
                    dim=-1,
                )
                x = util.pad_or_trim_to_len(x_pre, n=model.input_shape[1], dim=1)
                # Clear first so a stale activation can never be read back if
                # the hook does not fire for this sample.
                features.clear()
                _ = model(x.to(device))
                feat = features[tap_name].squeeze(0).cpu().numpy()  # torch.Size([1, 512]) to array(512,)
                all_feats.append(feat)
                all_azimuth.append(sample["foreground_azimuth"])  # 0 or 180 after the filter above
                all_elevation.append(sample["foreground_elevation"])  # 300 305...355 0 5...55 60
                all_smoothed.append(sample["smoothed"])  # inf, -inf, 0, 0.005, 0.01, to 0.1 (23 in total)
    finally:
        # Always detach the hook, even if a sample raised part-way through.
        hook_handle.remove()

    # Persist features and their labels for this architecture.
    all_feats_np = np.stack(all_feats, axis=0)  # (N, 512)
    np.save(f"sm_spectral_IHC3000_A{m+1:02d}_all_emb_penultimate_512.npy", all_feats_np)
    np.save(f"sm_spectral_IHC3000_A{m+1:02d}_all_azimuth.npy", np.array(all_azimuth))
    np.save(f"sm_spectral_IHC3000_A{m+1:02d}_all_elevation.npy", np.array(all_elevation))
    np.save(f"sm_spectral_IHC3000_A{m+1:02d}_all_smoothed.npy", np.array(all_smoothed))


  0%|                                                    | 0/10 [00:00<?, ?it/s]

[get_model] dir_model='../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch01'
[get_model] |__ input_shape=[2, 60000, 2]
[get_model] |__ config_random_slice={'size': [50, 10000], 'buffer': [0, 1000]}
[load_model_checkpoint] ../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch01/ckpt_BEST.pt


 10%|████                                     | 1/10 [06:42<1:00:25, 402.79s/it]

[get_model] dir_model='../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch02'
[get_model] |__ input_shape=[2, 60000, 2]
[get_model] |__ config_random_slice={'size': [50, 10000], 'buffer': [0, 1000]}
[load_model_checkpoint] ../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch02/ckpt_BEST.pt


 20%|████████▌                                  | 2/10 [13:20<53:19, 399.90s/it]

[get_model] dir_model='../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch03'
[get_model] |__ input_shape=[2, 60000, 2]
[get_model] |__ config_random_slice={'size': [50, 10000], 'buffer': [0, 1000]}
[load_model_checkpoint] ../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch03/ckpt_BEST.pt


 30%|████████████▉                              | 3/10 [19:58<46:33, 399.05s/it]

[get_model] dir_model='../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch04'
[get_model] |__ input_shape=[2, 60000, 2]
[get_model] |__ config_random_slice={'size': [50, 10000], 'buffer': [0, 1000]}
[load_model_checkpoint] ../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch04/ckpt_BEST.pt


 40%|█████████████████▏                         | 4/10 [26:51<40:26, 404.46s/it]

[get_model] dir_model='../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch05'
[get_model] |__ input_shape=[2, 60000, 2]
[get_model] |__ config_random_slice={'size': [50, 10000], 'buffer': [0, 1000]}
[load_model_checkpoint] ../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch05/ckpt_BEST.pt


 50%|█████████████████████▌                     | 5/10 [33:33<33:38, 403.74s/it]

[get_model] dir_model='../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch06'
[get_model] |__ input_shape=[2, 60000, 2]
[get_model] |__ config_random_slice={'size': [50, 10000], 'buffer': [0, 1000]}
[load_model_checkpoint] ../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch06/ckpt_BEST.pt


 60%|█████████████████████████▊                 | 6/10 [40:04<26:36, 399.16s/it]

[get_model] dir_model='../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch07'
[get_model] |__ input_shape=[2, 60000, 2]
[get_model] |__ config_random_slice={'size': [50, 10000], 'buffer': [0, 1000]}
[load_model_checkpoint] ../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch07/ckpt_BEST.pt


 70%|██████████████████████████████             | 7/10 [46:57<20:11, 403.78s/it]

[get_model] dir_model='../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch08'
[get_model] |__ input_shape=[2, 60000, 2]
[get_model] |__ config_random_slice={'size': [50, 10000], 'buffer': [0, 1000]}
[load_model_checkpoint] ../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch08/ckpt_BEST.pt


 80%|██████████████████████████████████▍        | 8/10 [53:38<13:25, 402.77s/it]

[get_model] dir_model='../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch09'
[get_model] |__ input_shape=[2, 60000, 2]
[get_model] |__ config_random_slice={'size': [50, 10000], 'buffer': [0, 1000]}
[load_model_checkpoint] ../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch09/ckpt_BEST.pt


 90%|████████████████████████████████████▉    | 9/10 [1:00:29<06:45, 405.39s/it]

[get_model] dir_model='../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch10'
[get_model] |__ input_shape=[2, 60000, 2]
[get_model] |__ config_random_slice={'size': [50, 10000], 'buffer': [0, 1000]}
[load_model_checkpoint] ../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch10/ckpt_BEST.pt


100%|████████████████████████████████████████| 10/10 [1:07:01<00:00, 402.10s/it]


In [4]:
# Shape of the feature matrix left over from the final iteration of the model
# loop (i.e. the last architecture only), as (n_selected_stimuli, n_features).
all_feats_np.shape

(2340, 512)

In [5]:
# Display the embedding filename pattern; `m` is the last 0-based index left
# behind by the model loop, so this shows the name written for the final model.
f"sm_spectral_IHC3000_A{m+1:02d}_all_emb_penultimate_512.npy"

'sm_spectral_IHC3000_A10_all_emb_penultimate_512.npy'