In [9]:
import matplotlib.pyplot as plt
import torchaudio
import os
from glob import glob
import yaml

def plot_spec(specs, titles):
    """Draw several spectrograms side by side in a single-row figure.

    Args:
        specs: non-empty sequence of 2-D tensors (each is detached and moved to
            CPU before plotting). Presumably (freq, time) as produced by
            ``SPEC.log_spec`` -- shown with ``origin="lower"`` so row 0 is at
            the bottom.
        titles: one title per entry in ``specs`` (extra titles are ignored).

    Returns:
        The matplotlib Figure; the caller is responsible for saving/closing it.

    Raises:
        ValueError: if ``specs`` is empty or there are fewer titles than specs.
    """
    if len(specs) == 0:
        raise ValueError("plot_spec needs at least one spectrogram")
    if len(titles) < len(specs):
        raise ValueError(
            f"got {len(specs)} spectrograms but only {len(titles)} titles")

    # squeeze=False always yields a 2-D array of Axes; with the default,
    # a single subplot comes back as a bare Axes and indexing it fails.
    fig, axs = plt.subplots(1, len(specs), squeeze=False)
    for ax, spec, title in zip(axs[0], specs, titles):
        ax.imshow(spec.detach().cpu().numpy(),
                  origin="lower", aspect="auto", cmap='magma')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
    fig.tight_layout(pad=0)
    return fig

class SPEC():
    """Base-2 log spectrogram helper built on ``torchaudio.transforms.Spectrogram``.

    Args:
        n_fft: FFT size passed to the torchaudio transform.
        overlap: fraction of each frame shared with the next one (typically in
            [0, 1)). The hop length is ``floor(n_fft * (1 - overlap))`` samples,
            e.g. 128 for the default ``n_fft=512, overlap=0.75``.
        eps: small constant added before the log so all-zero bins stay finite.

    Raises:
        ValueError: if the resulting hop length is smaller than one sample.
    """

    def __init__(self, n_fft, overlap=0.75, eps=1e-7):
        super().__init__()
        self.n_fft = n_fft
        self.eps = eps
        # Validated before building the transform so bad settings fail here,
        # not later inside the STFT.
        self.hop_length = self._hop_length_for(n_fft, overlap)
        self.spec = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft, hop_length=self.hop_length)

    @staticmethod
    def _hop_length_for(n_fft, overlap):
        """Return the hop (in samples) for an ``n_fft`` frame with fractional ``overlap``."""
        # Rounding to 6 decimals first strips float noise before truncating:
        # 100 * (1 - 0.9) == 9.999999999999998 would otherwise floor to 9.
        hop = int(round(n_fft * (1 - overlap), 6))
        if hop < 1:
            raise ValueError(
                f"hop length must be >= 1 sample, got {hop} "
                f"(n_fft={n_fft!r}, overlap={overlap!r})")
        return hop

    def log_spec(self, x):
        """Return ``log2(spectrogram(x) + eps)`` for waveform tensor ``x``."""
        S = self.spec(x)
        log_S = (S + self.eps).log2()
        return log_S


def _load_recording(sspath):
    """Read the mic, Force and optional metadata files from one recording directory.

    Returns:
        (audio, audio_sr, force, force_sr, gain, pad) where ``gain`` and ``pad``
        fall back to ``[1, 1]`` and ``[0, 0]`` when no metadata provides them.

    Raises:
        ValueError: if the directory has no "mic" or no "Force" file.
    """
    audio = force = None
    audio_sr = force_sr = None
    gain = [1, 1]
    pad = [0, 0]
    # sorted() makes "last matching file wins" deterministic across filesystems.
    for filename in sorted(os.listdir(sspath)):
        filedir = os.path.join(sspath, filename)
        if "mic" in filename:
            audio, audio_sr = torchaudio.load(filedir)
        if "Force" in filename:
            force, force_sr = torchaudio.load(filedir)
        if "metadata" in filename:
            # Binary mode lets PyYAML detect the stream encoding itself;
            # `or {}` covers an empty metadata file (safe_load returns None).
            with open(filedir, "rb") as f:
                yaml_data = yaml.safe_load(f) or {}
            gain = yaml_data.get("gain", gain)
            pad = yaml_data.get("pad", pad)

    # Without these checks a missing file would silently reuse the previous
    # recording's tensors (or raise UnboundLocalError on the first one).
    if audio is None:
        raise ValueError(f"no 'mic' audio file found in {sspath!r}")
    if force is None:
        raise ValueError(f"no 'Force' file found in {sspath!r}")
    return audio, audio_sr, force, force_sr, gain, pad


def load_audio(audio_dir):
    """Load every recording stored as a subdirectory of ``audio_dir``.

    Each subdirectory must contain a file whose name includes "mic" and one
    whose name includes "Force"; a file whose name includes "metadata" is
    optional YAML with ``gain`` ([force, mic]) and ``pad`` ([force, mic]) lists.
    Gain is applied with ``torchaudio.functional.gain`` and the first
    ``pad`` seconds (pad is multiplied by the sample rate) are trimmed from
    each signal. Subdirectories are processed in sorted order.

    Returns:
        (audios, forces, sr): lists of first-channel 1-D tensors, one per
        recording, and the mic sample rate shared by all recordings.

    Raises:
        ValueError: if there are no recording subdirectories, a recording is
            missing its mic/Force file, or mic sample rates differ.
    """
    subdirs = sorted(
        p for p in glob(os.path.join(audio_dir, "*")) if os.path.isdir(p))
    if not subdirs:
        raise ValueError(f"no recording subdirectories found in {audio_dir!r}")

    audios = []
    forces = []
    sr = None
    for sspath in subdirs:
        audio, audio_sr, force, force_sr, gain, pad = _load_recording(sspath)
        # A single returned sr can only describe the audios if they agree.
        if sr is None:
            sr = audio_sr
        elif audio_sr != sr:
            raise ValueError(
                f"mic sample rate {audio_sr} in {sspath!r} differs from {sr}")

        # NOTE(review): torchaudio.functional.gain takes its amount in dB, so
        # the default gain of 1 boosts by +1 dB rather than leaving the signal
        # unchanged -- confirm that is intended.
        force = torchaudio.functional.gain(force, gain[0])
        audio = torchaudio.functional.gain(audio, gain[1])
        # Each signal is trimmed with its own sample rate; int() allows
        # fractional pad values.
        force = force[:, int(pad[0] * force_sr):]
        audio = audio[:, int(pad[1] * audio_sr):]
        audios.append(audio[0])  # only use the first channel
        forces.append(force[0])  # only use the first channel
    return audios, forces, sr

# Driver: compute a log-spectrogram of the first half second of the first
# recording returned by load_audio for each numbered data folder, then save
# them as PNG figures with SPECS_PER_FIGURE panels each under OUTPUT_DIR.
N_RECORDINGS = 100      # data folders are numbered 1..N_RECORDINGS
SPECS_PER_FIGURE = 5    # spectrogram panels per saved image
OUTPUT_DIR = 'specs'

spec = SPEC(512)

audio_specs = []
titles = []
for i in range(N_RECORDINGS):
    audio_dir = f'../data/audio_data/{i+1}/audio'
    audios, forces, sr = load_audio(audio_dir)
    # sr // 2 samples == the first half second of the recording.
    audio_spec = spec.log_spec(audios[0][:sr // 2])
    audio_specs.append(audio_spec)
    titles.append(f'audio {i+1}')

# savefig does not create missing directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Chunk by the actual number of spectrograms instead of assuming exactly 100.
for fig_idx, start in enumerate(range(0, len(audio_specs), SPECS_PER_FIGURE)):
    stop = start + SPECS_PER_FIGURE
    fig = plot_spec(audio_specs[start:stop], titles[start:stop])
    fig.savefig(os.path.join(OUTPUT_DIR, f'audio_spec_{fig_idx}.png'), dpi=300)
    plt.close(fig)