In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%cd /calc/baronig/Projects/cLSNN/romain_fork/SE-adlif
from IPython.display import Audio
from omegaconf import DictConfig, omegaconf
import torch
import torchaudio
from datasets.audio_compress import LibriTTS
from datasets.utils.diskcache import DiskCachedDataset
from models.pl_module_compress import MLPSNN
import os
import plotly.express as px
import plotly.graph_objects as go
print(f"number of cuda devices: {torch.cuda.device_count()}")

/home/baronig/Projects/cLSNN/romain_fork/SE-adlif


  from .autonotebook import tqdm as notebook_tqdm


number of cuda devices: 1


In [47]:
@torch.compiler.disable
def create_audio_example(cfg_path, ckpt_path, example_idx,return_spike_proba):
    cfg = omegaconf.OmegaConf.load(cfg_path)
    cfg.unroll_factor = 1
    # cfg.compile=False
    # cfg.decoder.l_out.cell = 'li'
    # dataset = LibriTTS('/home/romain/datasets/LibriTTS/', sampling_freq=16_000, sample_length=-1, prediction_delay=cfg.dataset.prediction_delay)
    # dataset = LibriTTS(save_to='/calc/baronig/data_sets/LibriTTS/', cache_path='/calc/baronig/data_sets/LibriTTS/', sampling_freq=16_000, sample_length=-1)
    print("loading dataset")
    dataset = LibriTTS(save_to='/calc/baronig/data_sets/LibriTTS/', cache_path='/scratch/baronig/cache/librispeech', sampling_freq=16_000, sample_length=-1)
    print("loading checkpoint")
    ckpt = torch.load(ckpt_path, map_location='cuda:0')
    ckpt_state_dict = ckpt['state_dict']
    model = MLPSNN(cfg)
    model.load_state_dict(ckpt_state_dict)
    model.to("cuda:0")
    model.eval()
    source_waveform = []
    prediction_waveform = []
    spike_probs = []
    bottleneck_spikes = []
    for idx in example_idx:
        print(f"Processing example {idx}")
        inputs, *rest = dataset[idx]
        source_waveform.append(inputs.cpu().numpy().squeeze())
        with torch.no_grad():
            if return_spike_proba:
                states, pred = model.forward_with_states(inputs.to("cuda:0").unsqueeze(0))
                # spike_prob = torch.mean(states[1][1]).item().cpu().numpy()
                spike_prob = torch.mean(states[1][1]).item()
                spike_probs.append(spike_prob)
                bottleneck_spikes.append(states[2][1][0,:,:cfg.encoder.l2.num_out_neuron].cpu().numpy())
            else:
                pred = model(inputs.to("cuda:0").unsqueeze(0))
            pred = pred[:, cfg.dataset.prediction_delay:]
            print(f"pred shape: {pred.shape}")
            pred_wave = model.loss.generate_wave(pred)
        prediction_waveform.append(pred_wave.cpu().numpy().squeeze())
    return source_waveform, prediction_waveform, spike_probs, dataset, bottleneck_spikes


def initialize_dataset(cfg_path, ckpt_path):
    cfg = omegaconf.OmegaConf.load(cfg_path)
    dataset = LibriTTS(save_to='/calc/baronig/data_sets/LibriTTS/', cache_path='/scratch/baronig/cache/librispeech', sampling_freq=16_000, sample_length=-1)
    return dataset

In [15]:
def remove_prefix(text, prefix):
    if text.startswith(prefix):
        return "model." + text[len(prefix) :]
    return text


def repair_checkpoint(path):
    ckpt = torch.load(path)
    in_state_dict = ckpt["state_dict"]
    # in_state_dict = ckpt
    pairings = [
        (src_key, remove_prefix(src_key, "model._orig_mod."))
        for src_key in in_state_dict.keys()
    ]
    if all(src_key == dest_key for src_key, dest_key in pairings):
        return  # Do not write checkpoint if no need to repair!
    out_state_dict = {}
    for src_key, dest_key in pairings:
        print(f"{src_key}  ==>  {dest_key}")
        out_state_dict[dest_key] = in_state_dict[src_key]
    ckpt["state_dict"] = out_state_dict
    torch.save(ckpt, path)

In [41]:
# ckp_path = "/calc/baronig/Projects/sim_results/adlif_rebuttal/compression_task/hydra/2024-12-11/14-42-29/ckpt/last.ckpt"
# ckp_path = "/calc/baronig/Projects/sim_results/adlif_rebuttal/compression_task/hydra/2025-01-14/17-25-47/ckpt/epoch=16-step=66402.ckpt"
# ckp_path = "/calc/baronig/Projects/sim_results/adlif_rebuttal/compression_task/hydra/2025-01-16/18-28-21/ckpt/epoch=1372-step=107094.ckpt"
ckp_path = "/calc/baronig/Projects/sim_results/adlif_rebuttal/compression_task/hydra/2025-01-17/08-46-26/ckpt/epoch=5409-step=421980.ckpt"
# repair_checkpoint(ckp_path)
cfg_path = os.path.join(os.path.dirname(ckp_path), "..", ".hydra", "config.yaml")

dataset = initialize_dataset(cfg_path, ckp_path)

directory already exist
Wave files are resampled to 16000Hz
Chunk map loaded from /scratch/baronig/cache/librispeech/debug/16000/-1_map.pkl
Waves files are splited into -1 samples length


In [16]:
import matplotlib.pyplot as plt
import librosa.display
import torchaudio.transforms as T
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None, name=None):
    if ax is None:
        _, ax = plt.subplots(1, 1)
    if title is not None:
        ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest")
    # ax.imshow(specgram, origin="lower", aspect="auto", interpolation="nearest")
    # plt.savefig(name)
    return ax

In [42]:

def plot_sample_spectrogram(sample, dataset, name):
    n_fft = 1024
    win_length = None
    hop_length = 512
    n_mels = 128
    sample_rate = 16000

    mel_spectrogram = T.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm="slaney",
        n_mels=n_mels,
        # mel_scale="htk",
        mel_scale="slaney",
    )
    print(sample.shape)

    melspec = mel_spectrogram(sample.squeeze())
    plot_spectrogram(melspec, title="MelSpectrogram - torchaudio", ylabel="mel freq", name=name)
    return melspec 


In [48]:
# repair_checkpoint(ckp_path)
source_waveform, prediction_waveform, spike_probs, dataset, bottleneck_spikes = create_audio_example(cfg_path, ckp_path, [0], True)

loading dataset
directory already exist
Wave files are resampled to 16000Hz
Chunk map loaded from /scratch/baronig/cache/librispeech/debug/16000/-1_map.pkl
Waves files are splited into -1 samples length
loading checkpoint
{'random_seed': 42, 'logdir': '/calc/baronig/Projects/sim_results/adlif_rebuttal/compression_task', 'datadir': '/calc/baronig/data_sets', 'cachedir': '/scratch/baronig/cache/librispeech', 'device': 'cuda:0', 'dataset': {'_target_': 'datasets.audio_compress.CompressLibri', 'name': 'Compress libri task', 'required_model_size': 'small', 'data_path': '${datadir}', 'cache_path': '${cachedir}', 'max_sample': 10000, 'sampling_freq': 16000, 'sample_length': 512, 'prediction_delay': 20, 'zero_input_proba': 0.0, 'batch_size': 128, 'num_workers': 8, 'fits_into_ram': True, 'num_classes': 1, 'normalization': '-1_1'}, 'enc_l1_neurons': 300, 'enc_l2_neurons': 300, 'dec_l1_neurons': 300, 'dec_l2_neurons': 300, 'dec_lout_neurons': 256, 'bottleneck_neurons': 64, 'main_a_range': [0, 5],


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



pred shape: torch.Size([1, 44940, 256])


In [17]:

import numpy as np
ds_spec = plot_sample_spectrogram(dataset[0][0], dataset, name="ds1.png")
rec_spec = plot_sample_spectrogram(torch.tensor(prediction_waveform[0]).unsqueeze(-1), dataset, name="rec1.png")
plt.show()

plot_spectrogram(np.abs(librosa.power_to_db(ds_spec) - librosa.power_to_db((rec_spec))), title="MelSpectrogram - torchaudio", ylabel="mel freq")

torch.Size([44960, 1])
torch.Size([44940, 1])


  plt.show()


<Axes: title={'center': 'MelSpectrogram - torchaudio'}, ylabel='mel freq'>

In [18]:
Audio(prediction_waveform[0], rate=16_000)

In [83]:
fig = px.line(y=source_waveform[0], title="Source waveform")
fig.add_trace(go.Scatter(y=prediction_waveform[0], mode='lines', name='Prediction waveform'))
fig.show()
# fig.write_image("waveform.pdf")
fr = 30000
to = 31000
# plot with matplotlib
plt.figure(figsize=(3.3, 0.8))
# plt.plot(source_waveform[0][fr:to])
plt.plot(source_waveform[0][fr:to]) #, color=get_sequential_colors()[0])
# plt.plot(prediction_waveform[0][fr:to])
plt.plot(prediction_waveform[0][fr:to])
# remove top and right spines
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
# add legend
plt.legend(['source', 'prediction'])
# set transparent background
plt.gca().patch.set_alpha(0)
# set tick size
plt.xticks(fontsize=7)
plt.yticks(fontsize=7)
# set tick font to arial
plt.xticks(fontname='Arial')
plt.yticks(fontname='Arial')

# plt.show()
# set fig size
plt.show()
plt.savefig("waveform_comparison.svg")

for source in [True,False]:
    fr = 29600
    to =  fr + 500
    # plot with matplotlib
    plt.figure(figsize=(0.5, 0.5), dpi=100)
    # plt.plot(source_waveform[0][fr:to])
    if source:
        plt.plot(source_waveform[0][fr:to], color="black", linewidth=0.5)
    else:
        plt.plot(prediction_waveform[0][fr:to], color="black", linewidth=0.5)
    # plt.plot(prediction_waveform[0][fr:to])
    # plt.plot(prediction_waveform[0][fr:to])
    # remove everything except the plotted line. remove axes and ticks
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['left'].set_visible(False)
    plt.gca().spines['bottom'].set_visible(False)
    plt.gca().get_xaxis().set_visible(False)
    plt.gca().get_yaxis().set_visible(False)

    # tight layout


    # Make the figure and axes fully transparent
    plt.gcf().patch.set_alpha(0)  # Makes the figure background transparent
    plt.gca().patch.set_alpha(0)   # Makes the axes background transparent

    # plt.show()
    # set fig size
    plt.show()
    # plt.tight_layout()
    # Tighten the axis to the data
    plt.axis('tight')

    # Turn off axes for a clean look
    plt.axis('off')

    # Save the figure with no whitespace and transparent background

    plt.savefig(f"waveform_small_source_{source}.svg", transparent=True, bbox_inches='tight', pad_inches=0 )


# plot spikes as plt eventplot
fig = plt.figure(figsize=(0.839, 0.521), dpi=100)
print(fig.get_size_inches())  # Verify the figure size
event_times = [np.where(row == 1)[0] for row in bottleneck_spikes[0][fr:to].T[:16]]
# Plot the spikes
plt.eventplot(event_times, orientation="horizontal", linelengths=0.5, colors='black', linewidths=0.3)
# set font to Arial with 7pt

# remove all borders, ticks, axes, etc
plt.axis('off')

plt.gcf().patch.set_alpha(0)  # Makes the figure background transparent
plt.gca().patch.set_alpha(0)   # Makes the axes background transparent
plt.gca().set_position([0, 0, 1, 1])
# plt.tight_layout()
plt.savefig("spikes.svg", transparent=True, pad_inches=0 , dpi=100)


# Plot as matplotlib eventplot
# fig, ax = plt.subplots()
# fig.set_size_inches(220/100, 80/100)
# # Transparent background
# fig.patch.set_alpha(0)
# ax.set_facecolor('none')

# # Plot the events with vertical lines and rasterize them
# lines = ax.eventplot(sorted_grouped_events, linelengths=5, linewidths=0.3, rasterized=True)

# # Remove top and right border
# ax.spines['top'].set_visible(False)
# ax.spines['right'].set_visible(False)

# # Ensure axis labels are visible
# plt.tight_layout(pad=0)

# # Set font to Arial with 7pt
# plt.rc('font', family='arial', size=7)

# # Multiply tick values by 1000 to get ms
# plt.xticks([0, 200000, 400000, 600000], ["0", "200", "400", "600"])
# plt.yticks([1, 350, 700])

# # Rasterize only the event lines, not the entire plot
# for line in lines:
#     line.set_rasterized(True)

# ax.margins(x=0)
# ax.margins(y=0)
# ax.set_ylim([0, 700])

# # Display the plot
# plt.show()



# save:

[0.839 0.521]



FigureCanvasAgg is non-interactive, and thus cannot be shown


FigureCanvasAgg is non-interactive, and thus cannot be shown

