In [None]:
  !pip install librosa
!pip install mir_eval
!pip install numpy
!pip install pytorch_lightning
!pip install torch
!pip install torchaudio

Collecting mir_eval
  Downloading mir_eval-0.7.tar.gz (90 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.7/90.7 kB[0m [31m1.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: mir_eval
  Building wheel for mir_eval (setup.py) ... [?25l[?25hdone
  Created wheel for mir_eval: filename=mir_eval-0.7-py3-none-any.whl size=100703 sha256=33d8fef3108154eae9bd756736141b7662d3f258b126e116aa94fae65284e101
  Stored in directory: /root/.cache/pip/wheels/3e/2f/0d/dda9c4c77a170e21356b6afa2f7d9bb078338634ba05d94e3f
Successfully built mir_eval
Installing collected packages: mir_eval
Successfully installed mir_eval-0.7
Collecting pytorch_lightning
  Downloading pytorch_lightning-2.1.2-py3-none-any.whl (776 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m776.9/776.9 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading

In [None]:
from glob import glob
import librosa
import mir_eval
import numpy as np
import os
import pytorch_lightning as pl
from random import shuffle
import torch as th
import torch.nn as nn
import torch.utils.data as tud
import torchaudio
import scipy.signal

In [None]:
## CREATE TRAIN/VAL/TEST SPLITS IN DATA ##
# Writes train.txt / val.txt / test.txt into audio_dir (one path per line,
# relative to audio_dir) using an ~80/10/10 split drawn from a seeded generator.

audio_dir = "./data/BallroomData/"

# Sorted: glob() order depends on the filesystem, so without sorting the seeded
# permutation below would not reproduce the same split on another machine.
# Note: without recursive=True, "**" matches exactly one directory level.
audio_files = sorted(glob(os.path.join(audio_dir, "**", "*.wav")))
print(f"# of Audio files: {len(audio_files)}")

g = th.Generator().manual_seed(2147483647)


def _split_off(items, n_holdout, generator):
    """Randomly permute `items` and split off the last `n_holdout` of them.

    Returns (kept, held_out). Slicing at an explicit boundary index (instead of
    `perm[:-n]`) keeps the split correct when n_holdout == 0, where `perm[:-0]`
    would be empty and every item would be dropped.
    """
    perm = th.randperm(len(items), generator=generator).tolist()
    boundary = len(perm) - n_holdout
    kept = [items[i] for i in perm[:boundary]]
    held_out = [items[i] for i in perm[boundary:]]
    return kept, held_out


def _write_index(path, files, root):
    """Write one path per line to `path`, each relative to `root`."""
    with open(path, "w") as f:
        for file in files:
            f.write(f"{os.path.relpath(file, start=root)}\n")


# test split: 10% of all files
test_size = int(len(audio_files) * 0.1)
temp_audio_files, test_audio_files = _split_off(audio_files, test_size, g)
print(f"# of Test files: {len(test_audio_files)}")

# train / val split: 1/9 of the remaining 90% ~= 10% of all files
val_size = int(len(temp_audio_files) * 0.11111)
train_audio_files, val_audio_files = _split_off(temp_audio_files, val_size, g)
print(f"# of Train files: {len(train_audio_files)}")
print(f"# of Val files: {len(val_audio_files)}")

train_file_path = os.path.join(audio_dir, "train.txt")
val_file_path = os.path.join(audio_dir, "val.txt")
test_file_path = os.path.join(audio_dir, "test.txt")

_write_index(train_file_path, train_audio_files, audio_dir)
_write_index(val_file_path, val_audio_files, audio_dir)
_write_index(test_file_path, test_audio_files, audio_dir)

In [None]:
##HARMONICS TFT BLOCK##

def hz_to_midi(hz):
    """Convert frequency in Hz to a (fractional) MIDI note number; 440 Hz -> 69.

    Accepts a torch tensor, NumPy array or Python number (midi_to_hz, its
    inverse, already accepts all of these). The result is a torch tensor.
    """
    hz = th.as_tensor(hz)
    return 12 * (th.log2(hz) - np.log2(440.0)) + 69


def midi_to_hz(midi):
    """Convert MIDI note number(s) to frequency in Hz (MIDI 69 = 440 Hz).

    Pure arithmetic, so it works elementwise on Python numbers, NumPy arrays
    and torch tensors alike.
    """
    semitones_from_a4 = (midi - 69.0) / 12.0
    return 440.0 * (2.0 ** semitones_from_a4)


def note_to_midi(note):
    """Convert a note name such as 'C1' or 'A4' to its MIDI number (delegates to librosa)."""
    midi_number = librosa.core.note_to_midi(note)
    return midi_number


def hz_to_note(hz):
    """Convert a frequency in Hz to a note name string (delegates to librosa)."""
    note_name = librosa.core.hz_to_note(hz)
    return note_name


def initialize_filterbank(sample_rate, n_harmonic, semitone_scale):
    """Centre frequencies of the harmonic filterbank.

    Fundamentals are equally spaced on the MIDI scale, `semitone_scale` per
    semitone, from C1 up to (excluding) the note obtained by converting
    sample_rate / (2 * n_harmonic) to a note name and back. That grid of
    fundamentals is then multiplied by 1, 2, ..., n_harmonic.

    Returns:
        harmonic_hz: 1-D array of length n_harmonic * level, harmonic-major
            (all fundamentals x1, then all x2, ...).
        level: number of fundamentals per harmonic.
    """
    low_midi = note_to_midi('C1')
    # Highest fundamental: the n_harmonic-th harmonic should stay near Nyquist.
    high_midi = note_to_midi(hz_to_note(sample_rate / (2 * n_harmonic)))

    level = (high_midi - low_midi) * semitone_scale
    midi_grid = np.linspace(low_midi, high_midi, level + 1)
    fundamentals_hz = midi_to_hz(midi_grid[:-1])  # drop the top edge

    harmonic_hz = np.concatenate(
        [fundamentals_hz * multiple for multiple in range(1, n_harmonic + 1)]
    )
    return harmonic_hz, level


class HarmonicSTFT(nn.Module):
    """
    Trainable harmonic filters as implemented by Minz Won.

    Paper: https://ccrma.stanford.edu/~urinieto/MARL/publications/ICASSP2020_Won.pdf
    Code: https://github.com/minzwon/data-driven-harmonic-filters
    Pretrained: https://github.com/minzwon/sota-music-tagging-models/tree/master/training

    Computes a linear-frequency power spectrogram and weights it with
    triangular band-pass filters centred on every harmonic of a MIDI-spaced
    set of fundamentals (see initialize_filterbank). The result is converted
    to dB and shaped (batch, n_harmonic, level, time).

    Args:
        sample_rate: sample rate of the input waveform, in Hz.
        n_fft, win_length, hop_length, pad, power, normalized: forwarded to
            torchaudio.transforms.Spectrogram.
        n_harmonic: number of harmonics per fundamental.
        semitone_scale: number of fundamentals per semitone.
        bw_Q: initial quality factor; filter bandwidth is
            (bw_alpha * f0 + bw_beta) / bw_Q.
        learn_bw: 'only_Q' makes bw_Q a trainable nn.Parameter; 'fix' or None
            keeps it a constant tensor.
        checkpoint: optional path to a saved state dict. Entries whose key
            contains 'hstft.' are loaded with that prefix removed.

    Raises:
        ValueError: if learn_bw is not 'only_Q', 'fix' or None.
    """

    def __init__(self,
                 sample_rate=16000,
                 n_fft=513,
                 win_length=None,
                 hop_length=None,
                 pad=0,
                 power=2,
                 normalized=False,
                 n_harmonic=6,
                 semitone_scale=2,
                 bw_Q=1.0,
                 learn_bw=None,
                 checkpoint=None):
        super(HarmonicSTFT, self).__init__()

        # Parameters
        self.sample_rate = sample_rate
        self.n_harmonic = n_harmonic
        self.bw_alpha = 0.1079
        self.bw_beta = 24.7

        # Linear-frequency spectrogram. The filterbank places its triangles on
        # linearly spaced FFT bins (see to_device), so the input must not be
        # mel-scaled; this also matches the module layout of Won's checkpoints.
        self.spec = torchaudio.transforms.Spectrogram(n_fft=n_fft, win_length=win_length,
                                                      hop_length=hop_length, pad=pad,
                                                      window_fn=th.hann_window,
                                                      power=power, normalized=normalized, wkwargs=None)
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

        # Initialize the filterbank. Equally spaced in MIDI scale.
        harmonic_hz, self.level = initialize_filterbank(
            sample_rate, n_harmonic, semitone_scale)

        # Center frequencies to tensor
        self.f0 = th.tensor(harmonic_hz.astype('float32'))

        # Bandwidth parameters. None (the default) used to leave bw_Q
        # undefined, making forward() fail; it now means "fixed".
        bw_Q_init = th.tensor(np.array([bw_Q]).astype('float32'))
        if learn_bw == 'only_Q':
            self.bw_Q = nn.Parameter(bw_Q_init)
        elif learn_bw in (None, 'fix'):
            self.bw_Q = bw_Q_init
        else:
            raise ValueError(
                f"learn_bw must be 'only_Q', 'fix' or None, got {learn_bw!r}")

        if checkpoint is not None:
            # Load onto CPU so GPU-saved checkpoints also work on CPU-only
            # machines; load_state_dict copies values into this module.
            state_dict = th.load(checkpoint, map_location="cpu")
            hstft_state_dict = {k.replace('hstft.', ''): v for k,
                                v in state_dict.items() if 'hstft.' in k}
            self.load_state_dict(hstft_state_dict)

    def get_harmonic_fb(self):
        """Triangular filterbank of shape (n_bins, n_band).

        Column j rises linearly from 0 at f0[j] - bw/2 to 1 at f0[j] and back
        to 0 at f0[j] + bw/2, where bw = (bw_alpha * f0 + bw_beta) / bw_Q.
        Requires to_device() to have set fft_bins and zero.
        """
        # bandwidth
        bw = (self.bw_alpha * self.f0 + self.bw_beta) / self.bw_Q
        bw = bw.unsqueeze(0)  # (1, n_band)
        f0 = self.f0.unsqueeze(0)  # (1, n_band)
        fft_bins = self.fft_bins.unsqueeze(1)  # (n_bins, 1)

        up_slope = th.matmul(fft_bins, (2/bw)) + 1 - (2 * f0 / bw)
        down_slope = th.matmul(fft_bins, (-2/bw)) + 1 + (2 * f0 / bw)
        fb = th.max(self.zero, th.min(down_slope, up_slope))
        return fb

    def to_device(self, device, n_bins):
        """Move the non-parameter tensors to `device` and build `n_bins` FFT bin centre frequencies (0 .. sample_rate // 2)."""
        self.f0 = self.f0.to(device)
        # A learnable bw_Q is an nn.Parameter that already moves with the
        # module; assigning a plain tensor to that attribute would raise.
        if not isinstance(self.bw_Q, nn.Parameter):
            self.bw_Q = self.bw_Q.to(device)
        # fft bins
        self.fft_bins = th.linspace(0, self.sample_rate//2, n_bins)
        self.fft_bins = self.fft_bins.to(device)
        self.zero = th.zeros(1)
        self.zero = self.zero.to(device)

    def forward(self, waveform):
        """Harmonic spectrogram in dB, shape (batch, n_harmonic, level, time).

        NOTE(review): waveform is presumably (batch, samples); the code needs
        the spectrogram to be 3-D (batch, freq, frames).
        """
        # stft
        spectrogram = self.spec(waveform)
        # to device
        self.to_device(waveform.device, spectrogram.size(1))
        # triangle filter
        harmonic_fb = self.get_harmonic_fb()
        harmonic_spec = th.matmul(
            spectrogram.transpose(1, 2), harmonic_fb).transpose(1, 2)
        # (batch, channel, length) -> (batch, harmonic, f0, length)
        b, c, l = harmonic_spec.size()
        harmonic_spec = harmonic_spec.view(b, self.n_harmonic, self.level, l)
        # amplitude to db
        harmonic_spec = self.amplitude_to_db(harmonic_spec)
        return harmonic_spec

In [None]:
##NETWORKS BLOCK##

class Res2DMaxPoolModule(nn.Module):
    """Residual block: conv3x3-BN-ReLU-conv3x3-BN plus a skip connection, then ReLU and max pooling.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output. When it differs from
            in_channels, the skip path gets its own conv3x3 + BN projection.
        pooling: MaxPool2d kernel (and stride), either an int or a
            (freq, time) pair / list.
    """

    def __init__(self, in_channels, out_channels, pooling=2):
        super(Res2DMaxPoolModule, self).__init__()
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn_1 = nn.BatchNorm2d(out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn_2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        # MaxPool2d takes an int or a pair directly; calling tuple() on the
        # int default (pooling=2) used to raise TypeError.
        kernel = tuple(pooling) if isinstance(pooling, (list, tuple)) else pooling
        self.mp = nn.MaxPool2d(kernel)

        # residual
        self.diff = False
        if in_channels != out_channels:
            self.conv_3 = nn.Conv2d(
                in_channels, out_channels, 3, padding=1)
            self.bn_3 = nn.BatchNorm2d(out_channels)
            self.diff = True

    def forward(self, x):
        out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
        if self.diff:
            # project the skip path to out_channels
            x = self.bn_3(self.conv_3(x))
        out = x + out
        out = self.mp(self.relu(out))
        return out


class ResFrontEnd(nn.Module):
    """
    Adapted from Minz Won ResNet implementation.

    Original code: https://github.com/minzwon/semi-supervised-music-tagging-transformer/blob/master/src/modules.py

    Input batch-norm followed by three Res2DMaxPoolModule stages; stage i
    pools by (freq_pooling[i], time_pooling[i]).
    """
    def __init__(self, in_channels, out_channels, freq_pooling, time_pooling):
        super(ResFrontEnd, self).__init__()
        self.input_bn = nn.BatchNorm2d(in_channels)
        stage_pools = list(zip(freq_pooling, time_pooling))
        self.layer1 = Res2DMaxPoolModule(
            in_channels, out_channels, pooling=stage_pools[0])
        self.layer2 = Res2DMaxPoolModule(
            out_channels, out_channels, pooling=stage_pools[1])
        self.layer3 = Res2DMaxPoolModule(
            out_channels, out_channels, pooling=stage_pools[2])

    def forward(self, hcqt):
        """
        Inputs:
            hcqt: 4-D tensor with `in_channels` entries on dim 1 (required by
                input_bn); SpecTNT passes its features as [B, K, F, T].

        Outputs:
            out: same layout, out_channels on dim 1 and dims 2/3 reduced by
                the product of the per-stage pooling factors.
        """
        features = self.input_bn(hcqt)
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        return features


class SpecTNTBlock(nn.Module):
    """
    One SpecTNT block: a spectral Transformer encoder layer followed by a
    temporal Transformer encoder layer.

    The two paths exchange information through row 0 of the spectral
    embedding (the frequency class token, FCT). Before the spectral encoder,
    the temporal embedding is projected D -> K and added to that row only.
    After the encoder, that row is projected K -> D and added back to the
    temporal embedding.

    NOTE(review): the spectral encoder's sequence axis is the K channels
    (each token an (F+1)-vector) and the temporal encoder's sequence axis is
    the D embedding dims (each token a T-vector). Confirm this axis choice
    is intended.

    Args:
        n_channels: K, channels per frequency row.
        n_frequencies: F, frequency rows excluding the FCT row.
        n_times: T, time frames (one more is added when use_tct).
        spectral_dmodel, spectral_nheads, spectral_dimff: spectral encoder sizes.
        temporal_dmodel, temporal_nheads, temporal_dimff: temporal encoder sizes.
        embed_dim: D, width of the temporal embedding.
        dropout: dropout of both encoder layers.
        use_tct: whether inputs carry an extra leading temporal-class-token frame.
    """
    def __init__(
        self, n_channels, n_frequencies, n_times,
        spectral_dmodel, spectral_nheads, spectral_dimff,
        temporal_dmodel, temporal_nheads, temporal_dimff,
        embed_dim, dropout, use_tct
    ):
        super().__init__()

        self.D = embed_dim
        self.F = n_frequencies
        self.K = n_channels
        self.T = n_times

        # TCT: Temporal Class Token
        # (SpecTNTModule prepends one extra time step when it is enabled)
        if use_tct:
            self.T += 1

        # Shared frequency-time linear layers
        # (D_to_K writes the temporal embedding into the FCT row, K_to_D reads it back)
        self.D_to_K = nn.Linear(self.D, self.K)
        self.K_to_D = nn.Linear(self.K, self.D)

        # Spectral Transformer Encoder
        # (operates on [B*T, K, F+1]: projects F+1 -> spectral_dmodel and back)
        self.spectral_linear_in = nn.Linear(self.F+1, spectral_dmodel)
        self.spectral_encoder_layer = nn.TransformerEncoderLayer(
            d_model=spectral_dmodel, nhead=spectral_nheads, dim_feedforward=spectral_dimff, dropout=dropout, batch_first=True, activation="gelu", norm_first=True)
        self.spectral_linear_out = nn.Linear(spectral_dmodel, self.F+1)

        # Temporal Transformer Encoder
        # (operates on [B, D, T]: projects T -> temporal_dmodel and back)
        self.temporal_linear_in = nn.Linear(self.T, temporal_dmodel)
        self.temporal_encoder_layer = nn.TransformerEncoderLayer(
            d_model=temporal_dmodel, nhead=temporal_nheads, dim_feedforward=temporal_dimff, dropout=dropout, batch_first=True, activation="gelu", norm_first=True)
        self.temporal_linear_out = nn.Linear(temporal_dmodel, self.T)

    def forward(self, spec_in, temp_in):
        """
        Inputs:
            spec_in: spectral embedding input [B, T, F+1, K]
            temp_in: temporal embedding input [B, T, 1, D]

        Outputs:
            spec_out: spectral embedding output [B, T, F+1, K]
            temp_out: temporal embedding output [B, T, 1, D]
        """
        # Element-wise addition between TE and FCT
        # (D_to_K(temp_in) is [B, T, 1, K]; padding F zero rows after it
        # restricts the addition to row 0, the FCT row)
        spec_in = spec_in +             nn.functional.pad(self.D_to_K(temp_in), (0, 0, 0, self.F))

        # Spectral Transformer
        spec_in = spec_in.flatten(0, 1).transpose(1, 2)  # [B*T, K, F+1]
        emb = self.spectral_linear_in(spec_in)  # [B*T, K, spectral_dmodel]
        spec_enc_out = self.spectral_encoder_layer(
            emb)  # [B*T, K, spectral_dmodel]
        spec_out = self.spectral_linear_out(spec_enc_out)  # [B*T, K, F+1]
        spec_out = spec_out.view(-1, self.T, self.K,
                                 self.F+1).transpose(2, 3)  # [B, T, F+1, K]

        # FCT slicing (first row) + back to D
        temp_in = temp_in + self.K_to_D(spec_out[:, :, :1, :])  # [B, T, 1, D]

        # Temporal Transformer
        temp_in = temp_in.permute(0, 2, 3, 1).flatten(0, 1)  # [B, D, T]
        emb = self.temporal_linear_in(temp_in)  # [B, D, temporal_dmodel]
        temp_enc_out = self.temporal_encoder_layer(
            emb)  # [B, D, temporal_dmodel]
        temp_out = self.temporal_linear_out(temp_enc_out)  # [B, D, T]
        temp_out = temp_out.unsqueeze(1).permute(0, 3, 1, 2)  # [B, T, 1, D]

        return spec_out, temp_out


class SpecTNTModule(nn.Module):
    """Learnable tokens/embeddings plus a stack of SpecTNTBlock layers.

    Parameters (created in this order):
        fct: frequency class token, zeros [1, T, 1, K]
        fpe: frequency positional encoding, zeros [1, 1, F+1, K]
        tct: temporal class token, zeros [1, 1, 1, D] (None unless use_tct)
        te:  temporal embedding, uniform random [1, T, 1, D]
    """
    def __init__(
        self, n_channels, n_frequencies, n_times,
        spectral_dmodel, spectral_nheads, spectral_dimff,
        temporal_dmodel, temporal_nheads, temporal_dimff,
        embed_dim, n_blocks, dropout, use_tct
    ):
        super().__init__()

        n_chan, n_freq, n_time, emb_dim = n_channels, n_frequencies, n_times, embed_dim

        self.fct = nn.Parameter(th.zeros(1, n_time, 1, n_chan))
        self.fpe = nn.Parameter(th.zeros(1, 1, n_freq + 1, n_chan))
        self.tct = nn.Parameter(th.zeros(1, 1, 1, emb_dim)) if use_tct else None
        self.te = nn.Parameter(th.rand(1, n_time, 1, emb_dim))

        block_args = (
            n_channels, n_frequencies, n_times,
            spectral_dmodel, spectral_nheads, spectral_dimff,
            temporal_dmodel, temporal_nheads, temporal_dimff,
            embed_dim, dropout, use_tct,
        )
        self.spectnt_blocks = nn.ModuleList(
            SpecTNTBlock(*block_args) for _ in range(n_blocks)
        )

    def forward(self, x):
        """
        Input:
            x: [B, T, F, K]

        Output:
            spec_emb: [B, T, F+1, K]  (T+1 when use_tct)
            temp_emb: [B, T, 1, D]    (T+1 when use_tct)
        """
        n_items = len(x)

        # FCT as the first row, then add the positional encoding.
        spec_emb = th.cat([self.fct.repeat(n_items, 1, 1, 1), x], dim=2) + self.fpe
        temp_emb = self.te.repeat(n_items, 1, 1, 1)

        if self.tct is not None:
            # Leading time step: zeros on the spectral side, the TCT on the temporal side.
            spec_emb = nn.functional.pad(spec_emb, (0, 0, 0, 0, 1, 0))
            temp_emb = th.cat([self.tct.repeat(n_items, 1, 1, 1), temp_emb], dim=1)

        for block in self.spectnt_blocks:
            spec_emb, temp_emb = block(spec_emb, temp_emb)

        return spec_emb, temp_emb


class SpecTNT(nn.Module):
    """Front-end model + SpecTNTModule + linear classifier.

    forward() returns [B, n_classes] logits read from the temporal class
    token when use_tct, otherwise per-frame logits [B, T, n_classes].
    """
    def __init__(
        self, fe_model,
        n_channels, n_frequencies, n_times,
        spectral_dmodel, spectral_nheads, spectral_dimff,
        temporal_dmodel, temporal_nheads, temporal_dimff,
        embed_dim, n_blocks, dropout, use_tct, n_classes
    ):
        super().__init__()

        self.use_tct = use_tct
        self.fe_model = fe_model
        self.main_model = SpecTNTModule(
            n_channels, n_frequencies, n_times,
            spectral_dmodel, spectral_nheads, spectral_dimff,
            temporal_dmodel, temporal_nheads, temporal_dimff,
            embed_dim, n_blocks, dropout, use_tct,
        )
        self.linear_out = nn.Linear(embed_dim, n_classes)

    def forward(self, features):
        """
        Input:
            features: [B, K, F, T], or 3-D without the channel axis

        Output:
            logits:
                - [B, n_classes] if use_tct
                - [B, T, n_classes] otherwise
        """
        if features.dim() == 3:
            features = features.unsqueeze(1)  # add the missing channel axis

        # front-end [B, ^K, ^F, ^T] -> [B, ^T, ^F, ^K] for the SpecTNT stack
        frames = self.fe_model(features).permute(0, 3, 2, 1)
        _, temporal = self.main_model(frames)  # [B, T, 1, D]

        pooled = temporal[:, 0, 0, :] if self.use_tct else temporal[:, :, 0, :]
        return self.linear_out(pooled)

In [None]:
class BaseModel(pl.LightningModule):
    """Shared LightningModule plumbing.

    Stores the feature extractor, network, optimizer, optional LR scheduler,
    loss and datamodule, and builds the output activation used at inference.

    Args:
        activation_fn: "softmax" (over dim 2, which assumes logits shaped
            [B, T, n_classes]) or "sigmoid".

    Raises:
        ValueError: for any other activation_fn. An unknown value previously
            left `self.activation` unset and only failed at validation time.
    """
    def __init__(self, feature_extractor, net, optimizer, lr_scheduler, criterion, datamodule, activation_fn):
        super().__init__()

        self.feature_extractor = feature_extractor
        self.net = net
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.criterion = criterion
        self.datamodule = datamodule

        if activation_fn == "softmax":
            self.activation = nn.Softmax(dim=2)
        elif activation_fn == "sigmoid":
            self.activation = nn.Sigmoid()
        else:
            raise ValueError(
                f"activation_fn must be 'softmax' or 'sigmoid', got {activation_fn!r}")

    def configure_optimizers(self):
        """Return the pre-built optimizer, plus the scheduler monitored on val_loss when one was given."""
        if self.lr_scheduler is None:
            return {"optimizer": self.optimizer}
        else:
            return {"optimizer": self.optimizer, "lr_scheduler": self.lr_scheduler, "monitor": "val_loss"}

    @staticmethod
    def _classname(obj, lower=True):
        """Name of a class/function (via __name__) or of an instance's class, lower-cased by default."""
        if hasattr(obj, '__name__'):
            name = obj.__name__
        else:
            name = obj.__class__.__name__
        return name.lower() if lower else name


class BeatEstimator(BaseModel):
    """Frame-wise beat / downbeat tracker.

    Output classes follow BallroomDataset's one-hot targets:
    0 = beat, 1 = downbeat, 2 = no beat.
    """
    def __init__(self, feature_extractor, net, optimizer, lr_scheduler, criterion, datamodule, activation_fn):
        super().__init__(
            feature_extractor,
            net,
            optimizer,
            lr_scheduler,
            criterion,
            datamodule,
            activation_fn
        )

        # Output frames per second: one STFT hop per frame, further reduced by
        # the front-end's time pooling (same factors the dataset uses for targets).
        self.target_fps = datamodule.sample_rate / (datamodule.hop_length * datamodule.pooling_shrinking)

    def training_step(self, batch, batch_idx):
        """Loss over all frames of a batch of fixed-length chunks; logged as train_loss."""
        losses = {}
        x, y = batch['audio'], batch['targets']
        features = self.feature_extractor(x)
        logits = self.net(features)
        losses['train_loss'] = self.criterion(
            logits.flatten(end_dim=1), y.flatten(end_dim=1))
        self.log_dict(losses, on_step=False, on_epoch=True)
        return losses['train_loss']

    def validation_step(self, batch, batch_idx):
        """Evaluate one whole track.

        Only item 0 of the batch is read, so the validation loader is expected
        to use batch_size=1. The track is cut into consecutive full-length
        chunks of input_length seconds and run through the model in
        mini-batches. val_loss and mir_eval beat / downbeat F-measures are
        logged. Returns val_loss, or None when the track is shorter than a
        single chunk.
        """
        audio = batch['audio'][0]
        targets = batch['targets'][0].cpu()
        ref_beats = batch['beats'][0].cpu().numpy()
        ref_downbeats = batch['downbeats'][0].cpu().numpy()

        chunk_len = int(self.datamodule.input_length * self.datamodule.sample_rate)
        batch_size = self.datamodule.batch_size

        # Keep every full-length chunk. The former `[:-1]` always dropped the
        # last piece, even when it was a complete chunk.
        chunks = [c for c in audio.split(chunk_len) if c.shape[0] == chunk_len]
        if not chunks:
            return None  # shorter than one model input: nothing to evaluate
        audio_chunks = th.stack(chunks)

        # Inference loop (collect per-batch results, concatenate once)
        logits_parts, probs_parts = [], []
        with th.no_grad():
            for batch_audio in audio_chunks.split(batch_size):
                features = self.feature_extractor(batch_audio)
                logits = self.net(features)
                probs = self.activation(logits)
                logits_parts.append(logits.flatten(end_dim=1).cpu())
                probs_parts.append(probs.flatten(end_dim=1).cpu())
        logits_all = th.cat(logits_parts, dim=0)
        probs_all = th.cat(probs_parts, dim=0)

        # Postprocessing. Reference beats include the downbeats, and downbeat
        # frames are labelled class 1 only, so estimated beats must cover
        # classes 0 and 1.
        frame_classes = probs_all.argmax(dim=1)
        beat_frames = th.where((frame_classes == 0) | (frame_classes == 1))[0]
        downbeat_frames = th.where(frame_classes == 1)[0]
        est_beats = (beat_frames / self.target_fps).numpy()
        est_downbeats = (downbeat_frames / self.target_fps).numpy()

        # Eval: compare only frames present in both predictions and targets.
        n_frames = min(len(logits_all), len(targets))
        losses = {
            'val_loss': self.criterion(logits_all[:n_frames], targets[:n_frames]),
            'beats_f_measure': mir_eval.beat.f_measure(ref_beats, est_beats),
            'downbeats_f_measure': mir_eval.beat.f_measure(ref_downbeats, est_downbeats),
        }
        self.log_dict(losses, on_step=False, on_epoch=True)
        return losses['val_loss']

In [None]:
##DUMMY DATASETS BLOCK##

class DummyBeatDataset(tud.Dataset):
    """All-zero stand-in for a beat dataset, for smoke-testing the pipeline.

    train items: 'audio' of input_length * sample_rate samples and 'targets'
    of shape (target_nframes, 3).
    validation/test items: 10x longer audio/targets plus fixed references,
    'beats' every 0.5 s and 'downbeats' every 2 s (both below 50 s).
    Split sizes: 80 train, 10 validation, 10 test.
    """

    def __init__(self, sample_rate, input_length, hop_length, time_shrinking, mode):
        self.sample_rate = sample_rate
        self.input_length = input_length

        frames_per_second = sample_rate / (hop_length * time_shrinking)
        self.target_fps = frames_per_second
        self.target_nframes = int(input_length * frames_per_second)

        assert mode in ["train", "validation", "test"]
        self.mode = mode

    def __len__(self):
        split_sizes = {"train": 80, "validation": 10, "test": 10}
        return split_sizes.get(self.mode)

    def __getitem__(self, i):
        if self.mode == "train":
            n_samples = self.input_length * self.sample_rate
            return {
                'audio': th.zeros(n_samples),
                'targets': th.zeros(self.target_nframes, 3)
            }
        if self.mode in ["validation", "test"]:
            n_samples = 10 * self.input_length * self.sample_rate
            return {
                'audio': th.zeros(n_samples),
                'targets': th.zeros(10 * self.target_nframes, 3),
                'beats': th.arange(0, 50, 0.5),
                'downbeats': th.arange(0, 50, 2.)
            }
        return None


class DummyBeatDataModule(pl.LightningDataModule):
    """LightningDataModule serving DummyBeatDataset splits.

    Train batches use `batch_size`. Validation/test use batch size 1 because
    their items are whole, 10x longer examples.
    """
    def __init__(self, batch_size, n_workers, pin_memory, sample_rate, input_length, hop_length, time_shrinking):
        # Run LightningDataModule's initialiser so the state it sets up exists;
        # the explicit attribute overrides below are kept as before.
        super().__init__()
        self.batch_size = batch_size
        self.n_workers = n_workers
        self.pin_memory = pin_memory
        self.sample_rate = sample_rate
        self.input_length = input_length
        self.hop_length = hop_length
        self.time_shrinking = time_shrinking
        # BeatEstimator reads `datamodule.pooling_shrinking`; expose the same
        # factor under that name so it also works with this module.
        self.pooling_shrinking = time_shrinking
        self._log_hyperparams = False
        self.allow_zero_length_dataloader_with_multiple_devices = True

    def setup(self, stage):
        """Build all three splits regardless of `stage`."""
        common = (self.sample_rate, self.input_length, self.hop_length, self.time_shrinking)
        self.train_set = DummyBeatDataset(*common, "train")
        self.val_set = DummyBeatDataset(*common, "validation")
        self.test_set = DummyBeatDataset(*common, "test")

    # NOTE(review): Lightning treats prepare_data_per_node as a boolean
    # attribute; an instance value set by the base initialiser presumably
    # shadows this method. Kept for backward compatibility.
    def prepare_data_per_node(self):
        return None

    def _loader(self, dataset, batch_size, do_shuffle):
        """DataLoader sharing this module's pin_memory / n_workers settings."""
        return tud.DataLoader(dataset,
                              batch_size=batch_size,
                              pin_memory=self.pin_memory,
                              shuffle=do_shuffle,
                              num_workers=self.n_workers)

    def train_dataloader(self):
        return self._loader(self.train_set, self.batch_size, True)

    def val_dataloader(self):
        return self._loader(self.val_set, 1, False)

    def test_dataloader(self):
        return self._loader(self.test_set, 1, False)

In [None]:
## BALLROOM DATASETS ##

class BallroomDataset(tud.Dataset):
    """Ballroom beat / downbeat dataset.

    The file list for `mode` is read from <audio_dir>/{train,val,test}.txt
    (one path per line, relative to audio_dir). Annotations come from
    <annotation_dir>/<audio stem>.beats, whose lines are
    "<time in seconds> <beat position in bar>" (position 1 = downbeat).

    Targets are one-hot per output frame: [1, 0, 0] beat, [0, 1, 0] downbeat,
    [0, 0, 1] no beat. An output frame spans one STFT hop times
    pooling_shrinking.

    - train: items are `input_length`-second chunks (one per whole-second
      offset), randomly augmented.
    - validation / test: items are whole files plus reference beat and
      downbeat times.
    """

    def __init__(self, audio_dir, annotation_dir, sample_rate, input_length, hop_length, fft_win_length, pooling_shrinking, mode):
        super(BallroomDataset, self).__init__()

        assert mode in ["train", "validation", "test"]
        self.mode = mode

        self.sample_rate = sample_rate
        self.input_length = input_length
        self.hop_length = hop_length
        self.fft_win_length = fft_win_length
        self.pooling_shrinking = pooling_shrinking

        # audio files for this split (blank lines ignored)
        index_name = {"train": "train.txt", "validation": "val.txt", "test": "test.txt"}[mode]
        with open(os.path.join(audio_dir, index_name), "r") as f:
            self.audio_files = [os.path.join(audio_dir, line.strip()) for line in f if line.strip()]

        # matching annotation file for every audio file
        self.ann_files = [
            os.path.join(annotation_dir, os.path.splitext(os.path.basename(audio_file))[0] + ".beats")
            for audio_file in self.audio_files
        ]

        # (index into self.audio_files, offset in whole seconds from the start of the file)
        self.audio_chunks = []
        if self.mode == "train":
            for i, audio_file in enumerate(self.audio_files):
                duration = librosa.get_duration(path=audio_file)
                # every whole-second offset that still leaves input_length seconds of audio
                self.audio_chunks.extend(
                    (i, offset) for offset in range(int(duration - self.input_length) + 1))
            shuffle(self.audio_chunks)
        else:
            self.audio_chunks = [(i, 0) for i in range(len(self.audio_files))]

    def __len__(self):
        return len(self.audio_chunks)

    def __getitem__(self, i):
        idx, chunk_offset_in_seconds = self.audio_chunks[i]

        waveform = self._load_waveform(self.audio_files[idx])
        nsamples = waveform.shape[0]

        beat_offsets_in_seconds, beats = self._load_annotations(self.ann_files[idx])
        targets, samples_per_frame = self._make_targets(nsamples, beat_offsets_in_seconds, beats)

        if self.mode == "train":
            # audio chunk, zero-padded if the file ends early
            chunk_length_in_samples = int(self.input_length * self.sample_rate)
            chunk_offset_in_samples = int(chunk_offset_in_seconds * self.sample_rate)
            audio_chunk = waveform[chunk_offset_in_samples:chunk_offset_in_samples + chunk_length_in_samples]
            missing_samples = chunk_length_in_samples - audio_chunk.shape[0]
            if missing_samples > 0:
                audio_chunk = th.nn.functional.pad(audio_chunk, (0, missing_samples))

            # targets chunk, padded with 'non-beat' rows so every item has the same length
            chunk_length_in_frames = self._n_frames(chunk_length_in_samples)
            chunk_offset_in_frames = int(chunk_offset_in_samples // samples_per_frame)
            targets_chunk = targets[chunk_offset_in_frames:chunk_offset_in_frames + chunk_length_in_frames]
            missing_frames = chunk_length_in_frames - targets_chunk.shape[0]
            if missing_frames > 0:
                filler = th.zeros(missing_frames, 3)
                filler[:, 2] = 1
                targets_chunk = th.cat([targets_chunk, filler], dim=0)

            audio, target = self.apply_augmentations(audio_chunk, targets_chunk)

            return {
                'audio': audio,
                'targets': target
            }
        elif self.mode in ["validation", "test"]:
            downbeat_offsets_in_seconds = [
                t for t, position in zip(beat_offsets_in_seconds, beats) if position == 1
            ]

            return {
                'audio': waveform,
                'targets': targets,
                'beats': th.Tensor(beat_offsets_in_seconds),
                'downbeats': th.Tensor(downbeat_offsets_in_seconds)
            }

    def _n_frames(self, nsamples):
        """Number of output frames for `nsamples` samples (STFT frames reduced by pooling_shrinking)."""
        return int((1 + (nsamples - self.fft_win_length) // self.hop_length) // self.pooling_shrinking)

    def _load_waveform(self, audio_file):
        """Load `audio_file` as a 1-D float tensor at self.sample_rate.

        Channels are averaged (flattening would concatenate stereo channels
        end to end), and the audio is resampled when the file's rate differs,
        since all sample/frame arithmetic uses self.sample_rate.
        """
        waveform, file_sample_rate = torchaudio.load(audio_file)
        waveform = waveform.mean(dim=0)
        if file_sample_rate != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, file_sample_rate, self.sample_rate)
        return waveform

    @staticmethod
    def _load_annotations(ann_file):
        """Parse a .beats file into (times in seconds, beat positions in bar).

        Fields may be separated by any whitespace; blank lines are skipped.
        """
        times, positions = [], []
        with open(ann_file, "r") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                times.append(float(fields[0]))
                positions.append(int(fields[1]))
        return times, positions

    def _make_targets(self, nsamples, beat_offsets_in_seconds, beats):
        """One-hot [nframes, 3] targets for a whole file, plus samples per frame.

        Beats falling outside [0, nframes) are ignored; previously negative
        frame indices wrapped around to the end. A later annotation landing on
        the same frame overwrites an earlier one.
        """
        nframes = self._n_frames(nsamples)
        samples_per_frame = nsamples / nframes

        targets = th.zeros(nframes, 3)
        targets[:, 2] = 1  # init to 'non-beat'
        for offset_in_seconds, position in zip(beat_offsets_in_seconds, beats):
            frame = int(offset_in_seconds * self.sample_rate // samples_per_frame)
            if 0 <= frame < nframes:
                targets[frame] = 0.0
                targets[frame, 1 if position == 1 else 0] = 1.0  # downbeat / beat
        return targets, samples_per_frame

    def apply_augmentations(self, audio, target):
        """Randomly augment one training chunk.

        Args:
            audio: 1-D waveform at self.sample_rate (not modified in place).
            target: [frames, 3] one-hot targets (beat, downbeat, no beat).

        Returns:
            (audio, target) with unchanged shapes. Audio is float32 and
            peak-normalised (left unscaled when silent). Targets stay one-hot.

        Uses NumPy's and torch's global RNGs. Applied independently:
        phase inversion (50%), 10% dropout span (5%), target jitter (30%),
        low-pass (10%), high-pass (10%), peaking EQ (15%), white noise (5%),
        tanh distortion (20%).
        """
        audio = audio.clone()
        target = target.clone()
        # approximate samples per output frame
        frame_hop = self.hop_length * self.pooling_shrinking

        # phase inversion
        if np.random.rand() < 0.5:
            audio = -audio

        # drop a contiguous 10% of the chunk; covered frames become 'non-beat'
        if np.random.rand() < 0.05:
            n_samples = audio.shape[-1]
            zero_size = int(n_samples * 0.1)
            if n_samples - zero_size - 1 > 0:
                start = np.random.randint(n_samples - zero_size - 1)
                stop = start + zero_size
                audio[start:stop] = 0
                target[start // frame_hop:stop // frame_hop] = th.tensor([0.0, 0.0, 1.0])

        # jitter beat / downbeat positions (Gaussian, sigma = half of 45 ms)
        if np.random.rand() < 0.3:
            target = self._shift_targets(target)

        # low-pass, cutoff 4-8 kHz
        if np.random.rand() < 0.1:
            cutoff = (np.random.rand() * 4000) + 4000
            audio = self._butter_filter(audio, cutoff, "lowpass")

        # high-pass, cutoff 20-1020 Hz
        if np.random.rand() < 0.1:
            cutoff = (np.random.rand() * 1000) + 20
            audio = self._butter_filter(audio, cutoff, "highpass")

        # peaking EQ (replaces an unimported sox equalizer)
        if np.random.rand() < 0.15:
            freq = (np.random.rand() * 8000) + 60
            q = (np.random.rand() * 7.0) + 0.1
            gain_db = np.random.normal(0.0, 6)
            audio = self._peaking_eq(audio, freq, q, gain_db)

        # white noise at -12 to -32 dB
        if np.random.rand() < 0.05:
            wn = (th.rand(audio.shape) * 2) - 1
            noise_gain = float(10 ** ((-(np.random.rand() * 20) - 12) / 20))
            audio = audio + (noise_gain * wn)

        # tanh distortion with 0-12 dB drive
        if np.random.rand() < 0.2:
            drive = float(10 ** ((np.random.rand() * 12) / 20))
            audio = th.tanh(drive * audio)

        # peak-normalise, guarding against silent chunks (division by zero)
        audio = audio.float()
        if audio.numel() > 0:
            peak = audio.abs().max()
            if peak > 0:
                audio = audio / peak

        return audio, target

    def _shift_targets(self, target):
        """Move each beat and downbeat by a random number of frames.

        Shifts are normal(0, max_shift / 2) frames, where max_shift is 45 ms at
        the output frame rate, truncated toward zero. Positions landing outside
        the chunk are dropped. Downbeats win collisions with beats.
        """
        n_frames = target.shape[0]
        fps = self.sample_rate / (self.hop_length * self.pooling_shrinking)
        max_shift = int(0.045 * fps)
        if max_shift == 0:
            return target  # frame rate too low for a 45 ms jitter to move anything

        beat_idx = (target[:, 0] == 1).nonzero(as_tuple=True)[0]
        dbeat_idx = (target[:, 1] == 1).nonzero(as_tuple=True)[0]
        beat_idx = beat_idx + th.normal(0.0, max_shift / 2, size=tuple(beat_idx.shape)).long()
        dbeat_idx = dbeat_idx + th.normal(0.0, max_shift / 2, size=tuple(dbeat_idx.shape)).long()
        beat_idx = beat_idx[(beat_idx >= 0) & (beat_idx < n_frames)]
        dbeat_idx = dbeat_idx[(dbeat_idx >= 0) & (dbeat_idx < n_frames)]

        shifted = th.zeros_like(target)
        shifted[:, 2] = 1
        shifted[beat_idx] = th.tensor([1.0, 0.0, 0.0])
        shifted[dbeat_idx] = th.tensor([0.0, 1.0, 0.0])
        return shifted

    def _butter_filter(self, audio, cutoff, btype):
        """2nd-order Butterworth filter at self.sample_rate; no-op unless cutoff < Nyquist."""
        if cutoff >= self.sample_rate / 2:
            return audio
        sos = scipy.signal.butter(2, cutoff, btype=btype, fs=self.sample_rate, output='sos')
        filtered = scipy.signal.sosfilt(sos, audio.numpy())
        return th.from_numpy(filtered.astype('float32'))

    def _peaking_eq(self, audio, freq, q, gain_db):
        """RBJ-cookbook peaking biquad at `freq` Hz; no-op unless freq < Nyquist."""
        if freq >= self.sample_rate / 2:
            return audio
        amp = 10 ** (gain_db / 40)
        w0 = 2 * np.pi * freq / self.sample_rate
        alpha = np.sin(w0) / (2 * q)
        cos_w0 = np.cos(w0)
        b = [1 + alpha * amp, -2 * cos_w0, 1 - alpha * amp]
        a = [1 + alpha / amp, -2 * cos_w0, 1 - alpha / amp]
        filtered = scipy.signal.lfilter(b, a, audio.numpy())
        return th.from_numpy(filtered.astype('float32'))

class BallroomDataModule(pl.LightningDataModule):
    """Lightning data module for the Ballroom beat-tracking dataset.

    Builds one ``BallroomDataset`` per split ("train", "validation", "test")
    from identical directory and feature settings, and wraps each one in a
    ``torch.utils.data.DataLoader``.

    Args:
        batch_size: Batch size of the training loader. The validation and
            test loaders always use a batch size of 1.
        n_workers: ``num_workers`` for every DataLoader.
        pin_memory: ``pin_memory`` for every DataLoader.
        audio_dir: Passed to ``BallroomDataset`` as ``audio_dir``.
        annotation_dir: Passed to ``BallroomDataset`` as ``annotation_dir``.
        sample_rate, input_length, hop_length, fft_win_length,
        pooling_shrinking: Forwarded unchanged to ``BallroomDataset``.
    """

    def __init__(self, batch_size, n_workers, pin_memory, audio_dir, annotation_dir, sample_rate, input_length, hop_length, fft_win_length, pooling_shrinking):
        # Initialise Lightning's base state (trainer reference, hook flags,
        # hyperparameter bookkeeping) before setting our own attributes.
        super().__init__()
        self.batch_size = batch_size
        self.n_workers = n_workers
        self.pin_memory = pin_memory
        self.audio_dir = audio_dir
        self.annotation_dir = annotation_dir
        self.sample_rate = sample_rate
        self.input_length = input_length
        self.hop_length = hop_length
        self.fft_win_length = fft_win_length
        self.pooling_shrinking = pooling_shrinking
        # Lightning flags kept at the values this notebook has always used:
        # do not log this module's hyperparameters, and tolerate empty
        # dataloaders in multi-device runs.
        self._log_hyperparams = False
        self.allow_zero_length_dataloader_with_multiple_devices = True
        # Lightning reads this as a bool: call ``prepare_data`` on local
        # rank 0 of every node rather than only on global rank 0.
        self.prepare_data_per_node = True

    def _make_dataset(self, mode):
        """Build a ``BallroomDataset`` for ``mode`` with this module's shared settings."""
        return BallroomDataset(
            audio_dir=self.audio_dir,
            annotation_dir=self.annotation_dir,
            sample_rate=self.sample_rate,
            input_length=self.input_length,
            hop_length=self.hop_length,
            fft_win_length=self.fft_win_length,
            pooling_shrinking=self.pooling_shrinking,
            mode=mode,
        )

    def setup(self, stage=None):
        """Create the train, validation and test datasets.

        ``stage`` is accepted to match Lightning's hook signature but is
        ignored: all three splits are always built.
        """
        self.train_set = self._make_dataset("train")
        self.val_set = self._make_dataset("validation")
        self.test_set = self._make_dataset("test")

    def _make_loader(self, dataset, batch_size, shuffle):
        """Wrap ``dataset`` in a DataLoader using this module's worker/pinning settings."""
        return tud.DataLoader(dataset,
                              batch_size=batch_size,
                              pin_memory=self.pin_memory,
                              shuffle=shuffle,
                              num_workers=self.n_workers)

    def train_dataloader(self):
        """Shuffled training loader with ``batch_size`` examples per batch."""
        return self._make_loader(self.train_set, self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Unshuffled validation loader, one example per batch."""
        return self._make_loader(self.val_set, 1, shuffle=False)

    def test_dataloader(self):
        """Unshuffled test loader, one example per batch.

        ``DataLoader`` has no ``augment`` argument, so it is not passed here.
        Whether augmentation runs is decided inside ``BallroomDataset``,
        presumably via ``mode="test"`` (TODO confirm).
        """
        return self._make_loader(self.test_set, 1, shuffle=False)

In [None]:
# Train a SpecTNT beat estimator on the Ballroom dataset.
SEED = 42
AUDIO_DIR = "./data/BallroomData/"
ANNOTATION_DIR = "./data/BallroomAnnotations/"

# Seed python/numpy/torch; workers=True also seeds each DataLoader worker,
# so numpy-based augmentation is reproducible and differs across workers.
pl.seed_everything(SEED, workers=True)

# Create the logger first so it can actually be handed to the Trainer;
# otherwise the Trainer silently uses its default logger instead.
logger = pl.loggers.tensorboard.TensorBoardLogger(name="", save_dir="Logger")

# max_steps=3 appears to be a smoke-test setting; raise it for a real run.
# log_every_n_steps=1: an epoch has fewer batches (40 in the recorded run)
# than Lightning's default logging interval of 50.
trainer = pl.Trainer(precision=32, accumulate_grad_batches=16, check_val_every_n_epoch=2,
                     max_steps=3, logger=logger, log_every_n_steps=1)

# NOTE(review): the feature extractor is configured for 16 kHz while the
# datamodule uses sample_rate=44100. n_times=215 matches 5 s * 44100 / 256 / 4,
# so the network seems sized for 44.1 kHz input -- confirm HarmonicSTFT's
# sample_rate is intended.
feature_extractor = HarmonicSTFT(sample_rate=16000, n_fft=512, n_harmonic=6, semitone_scale=2, learn_bw='only_Q')
fe_model = ResFrontEnd(in_channels=6, out_channels=256, freq_pooling=[2, 2, 2], time_pooling=[2, 2, 1])
net = SpecTNT(fe_model=fe_model, n_channels=256, n_frequencies=16, n_times=215, embed_dim=128,
              spectral_dmodel=64, spectral_nheads=4, spectral_dimff=64,
              temporal_dmodel=256, temporal_nheads=8, temporal_dimff=256,
              n_blocks=5, dropout=0.15, use_tct=False, n_classes=3)
optimizer = th.optim.AdamW(params=net.parameters())
criterion = th.nn.CrossEntropyLoss()

datamodule = BallroomDataModule(batch_size=2, n_workers=4, pin_memory=False, audio_dir=AUDIO_DIR,
                                annotation_dir=ANNOTATION_DIR, sample_rate=44100, input_length=5,
                                hop_length=256, fft_win_length=512, pooling_shrinking=4)
model = BeatEstimator(feature_extractor=feature_extractor, net=net, optimizer=optimizer,
                      lr_scheduler=None, criterion=criterion, datamodule=datamodule, activation_fn='softmax')

trainer.fit(model=model, datamodule=datamodule)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name              | Type             | Params
-------------------------------------------------------
0 | feature_extractor | HarmonicSTFT     | 1     
1 | net               | SpecTNT          | 5.7 M 
2 | criterion         | CrossEntropyLoss | 0     
3 | activation        | Softmax          | 0     
-------------------------------------------------------
5.7 M     Trainable params
0         Non-trainable params
5.7 M     Total params
22.656    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (40) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=3` reached.
