In [1]:
import os
import shutil
import importlib
import copy
import glob
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time
import torch
import torchaudio
from tqdm import tqdm  # progress bar
import phaselocknet_model
import util

# Re-import the project modules so that edits made on disk are picked up
# without restarting the kernel.
for _module in (phaselocknet_model, util):
    importlib.reload(_module)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [2]:
# Path of the HDF5 stimulus file handed to util.HDF5Dataset.
# NOTE(review): absolute, machine-specific path — consider a configurable data dir.
regex_filenames = "/media/marmoset/data/code_package/phaselock/localization/bandwidth_dependency_stim.hdf5"

# Target sampling rate: 50 kHz when the path contains "localization", 20 kHz otherwise.
if "localization" in regex_filenames:
    sr = 50000
else:
    sr = 20000

# Rate used as the resampler's source frequency further down.
# NOTE(review): presumably the native rate of the stored stimuli — confirm against the file.
sr_src = 44100

dataset = util.HDF5Dataset(regex_filenames)

In [15]:
# Model directory; the config, architecture and checkpoint file names below
# are all passed together with it.
dir_model = "../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch09"

# Build the model from its config / architecture description.
model, config_model = phaselocknet_model.get_model(dir_model=dir_model, fn_config="config.json", fn_arch="arch.json")

# Restore the weights of the perceptual sub-model from the torch checkpoint
# (rather than from a tensorflow checkpoint).
util.load_model_checkpoint(model=model.perceptual_model, dir_model=dir_model, fn_ckpt="ckpt_BEST.pt", weights_only=True)

# Inference only: move to the selected device and leave training mode.
model.to(device)
model.eval()
assert not model.training

[get_model] dir_model='../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch09'
[get_model] |__ input_shape=[2, 60000, 2]
[get_model] |__ config_random_slice={'size': [50, 10000], 'buffer': [0, 1000]}
[load_model_checkpoint] ../phaselocknet_torch/models/sound_localization/simplified_IHC3000/arch09/ckpt_BEST.pt


In [16]:
# Extract one embedding per stimulus by tapping an intermediate layer of the
# model with a forward hook, then save the stacked embeddings to disk.

# Resamples each channel from sr_src to the sampling rate sr.
resampler = torchaudio.transforms.Resample(orig_freq=sr_src, new_freq=sr)

# Qualified name of the submodule whose forward output is captured.
tap_name = "perceptual_model.body.fc_intermediate"
# tap_name = "perceptual_model.head.label_loc_int.fc_output"
handle = model.get_submodule(tap_name)  # the tapped module itself; the removable hook handle is `hook_handle`
features = {}


def hook(_m, _in, output):
    """Forward hook: store the tapped module's output under ``tap_name``."""
    features[tap_name] = output.detach()


all_feats, all_azim, all_bandwidth, all_frequency, all_freq_ref = [], [], [], [], []
hook_handle = handle.register_forward_hook(hook)
try:
    model.eval()
    with torch.inference_mode():
        for i in tqdm(range(len(dataset))):
            sample = dataset[i]  # dict per sample
            # Add a leading batch axis; the code below treats the result as (1, T, C).
            x_raw = torch.tensor(sample["signal"], dtype=torch.float32)[None, ...]
            # Resample every channel separately along time, then re-stack on the last axis.
            x_pre = torch.stack([resampler(x_raw[..., ch]) for ch in range(x_raw.shape[-1])], dim=-1)
            x = util.pad_or_trim_to_len(x_pre, n=model.input_shape[1], dim=1)
            # Drop the previous capture so a hook that did not fire raises a
            # KeyError instead of silently reusing stale features.
            features.clear()
            _ = model(x.to(device))
            # Remove the batch axis and move to host memory (presumably (1, 512) -> (512,)).
            feat = features[tap_name].squeeze(0).cpu().numpy()
            all_feats.append(feat)
            # Per-sample stimulus metadata, kept in the same order as all_feats.
            # NOTE(review): the original inline notes ("0 to 375 smooth change",
            # "0 1 2 3 4", "376 words in speech") look partly copied from another
            # notebook — confirm the real value ranges before relying on them.
            all_azim.append(sample["foreground_azimuth"])
            all_bandwidth.append(sample["bandwidth"])
            all_frequency.append(sample["f"])
            all_freq_ref.append(sample["f_ref"])
finally:
    # Always detach the hook, even if the loop fails or is interrupted;
    # otherwise re-running this cell would stack additional hooks on the module.
    hook_handle.remove()

all_feats_np = np.stack(all_feats, axis=0)  # (N, feature_dim)
assert all_feats_np.shape[0] == len(dataset), "expected one embedding per dataset sample"
np.save("bandW_all_emb_penultimate_512.npy", all_feats_np)
# np.save("bandW_all_azim.npy", np.array(all_azim))
# np.save("bandW_all_bandwidth.npy", np.array(all_bandwidth))
# np.save("bandW_all_frequency.npy", np.array(all_frequency))
# np.save("bandW_all_freq_ref.npy", np.array(all_freq_ref))


100%|█████████████████████████████████████| 55500/55500 [25:02<00:00, 36.94it/s]
