Skip to content

Commit

Permalink
feat(model): add num_frames and receptive_field to segmentation m…
Browse files Browse the repository at this point in the history
…odels

Co-authored-by: Bilal Rahou <Bilal-Rahou@users.noreply.github.com>
  • Loading branch information
hbredin and Bilal-Rahou committed Dec 15, 2023
1 parent 6580e6c commit 66dd72b
Show file tree
Hide file tree
Showing 6 changed files with 359 additions and 2 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,11 @@
# Changelog

## develop branch

### New features

- feat(model): add `num_frames` and `receptive_field` to segmentation models

## Version 3.1.1 (2023-12-01)

### TL;DR
Expand Down
89 changes: 87 additions & 2 deletions pyannote/audio/models/blocks/sincnet.py
Expand Up @@ -28,17 +28,21 @@
import torch.nn as nn
import torch.nn.functional as F
from asteroid_filterbanks import Encoder, ParamSincFB
from pyannote.core import SlidingWindow

from pyannote.audio.utils.frame import conv1d_num_frames, conv1d_receptive_field_size


class SincNet(nn.Module):
def __init__(self, sample_rate: int = 16000, stride: int = 1):
super().__init__()

if sample_rate != 16000:
raise NotImplementedError("PyanNet only supports 16kHz audio for now.")
raise NotImplementedError("SincNet only supports 16kHz audio for now.")
# TODO: add support for other sample rate. it should be enough to multiply
# kernel_size by (sample_rate / 16000). but this needs to be double-checked.

self.sample_rate = sample_rate
self.stride = stride

self.wav_norm1d = nn.InstanceNorm1d(1, affine=True)
Expand Down Expand Up @@ -70,6 +74,88 @@ def __init__(self, sample_rate: int = 16000, stride: int = 1):
self.pool1d.append(nn.MaxPool1d(3, stride=3, padding=0, dilation=1))
self.norm1d.append(nn.InstanceNorm1d(60, affine=True))

def num_frames(self, num_samples: int) -> int:
"""Compute number of output frames for a given number of input samples
Parameters
----------
num_samples : int
Number of input samples
Returns
-------
num_frames : int
Number of output frames
"""

kernel_size = [251, 3, 5, 3, 5, 3]
stride = [self.stride, 3, 1, 3, 1, 3]
padding = [0, 0, 0, 0, 0, 0]
dilation = [1, 1, 1, 1, 1, 1]

num_frames = num_samples
for k, s, p, d in zip(kernel_size, stride, padding, dilation):
num_frames = conv1d_num_frames(
num_frames, kernel_size=k, stride=s, padding=p, dilation=d
)

return num_frames

def receptive_field_size(self, num_frames: int = 1) -> int:
"""Compute receptive field size
Parameters
----------
num_frames : int, optional
Number of frames in the output signal
Returns
-------
receptive_field_size : int
Receptive field size
"""

kernel_size = [251, 3, 5, 3, 5, 3]
stride = [self.stride, 3, 1, 3, 1, 3]
padding = [0, 0, 0, 0, 0, 0]
dilation = [1, 1, 1, 1, 1, 1]

receptive_field_size = num_frames
for k, s, p, d in reversed(list(zip(kernel_size, stride, padding, dilation))):
receptive_field_size = conv1d_receptive_field_size(
num_frames=receptive_field_size,
kernel_size=k,
stride=s,
padding=p,
dilation=d,
)

return receptive_field_size

def receptive_field(self) -> SlidingWindow:
"""Compute receptive field
Returns
-------
receptive field : SlidingWindow
Source
------
https://distill.pub/2019/computing-receptive-fields/
"""

# duration of the receptive field of each output frame
duration = self.receptive_field_size() / self.sample_rate

# step between the receptive field region of two consecutive output frames
step = (
self.receptive_field_size(num_frames=2)
- self.receptive_field_size(num_frames=1)
) / self.sample_rate

return SlidingWindow(start=0.0, duration=duration, step=step)

def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
"""Pass forward
Expand All @@ -83,7 +169,6 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
for c, (conv1d, pool1d, norm1d) in enumerate(
zip(self.conv1d, self.pool1d, self.norm1d)
):

outputs = conv1d(outputs)

# https://github.com/mravanelli/SincNet/issues/4
Expand Down
31 changes: 31 additions & 0 deletions pyannote/audio/models/segmentation/PyanNet.py
Expand Up @@ -27,6 +27,7 @@
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from pyannote.core import SlidingWindow
from pyannote.core.utils.generators import pairwise

from pyannote.audio.core.model import Model
Expand Down Expand Up @@ -157,6 +158,36 @@ def build(self):
self.classifier = nn.Linear(in_features, out_features)
self.activation = self.default_activation()

def num_frames(self, num_samples: int) -> int:
"""Compute number of output frames for a given number of input samples
Parameters
----------
num_samples : int
Number of input samples
Returns
-------
num_frames : int
Number of output frames
"""

return self.sincnet.num_frames(num_samples)

def receptive_field(self) -> SlidingWindow:
"""Compute receptive field
Returns
-------
receptive field : SlidingWindow
Source
------
https://distill.pub/2019/computing-receptive-fields/
"""
return self.sincnet.receptive_field()

def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
"""Pass forward
Expand Down
78 changes: 78 additions & 0 deletions pyannote/audio/models/segmentation/SSeRiouSS.py
Expand Up @@ -27,10 +27,12 @@
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from pyannote.core import SlidingWindow
from pyannote.core.utils.generators import pairwise

from pyannote.audio.core.model import Model
from pyannote.audio.core.task import Task
from pyannote.audio.utils.frame import conv1d_num_frames, conv1d_receptive_field_size
from pyannote.audio.utils.params import merge_dict


Expand Down Expand Up @@ -191,6 +193,82 @@ def build(self):
self.classifier = nn.Linear(in_features, out_features)
self.activation = self.default_activation()

def num_frames(self, num_samples: int) -> int:
"""Compute number of output frames for a given number of input samples
Parameters
----------
num_samples : int
Number of input samples
Returns
-------
num_frames : int
Number of output frames
"""

num_frames = num_samples
for conv_layer in self.wav2vec.feature_extractor.conv_layers:
num_frames = conv1d_num_frames(
num_frames,
kernel_size=conv_layer.kernel_size,
stride=conv_layer.stride,
padding=conv_layer.conv.padding[0],
dilation=conv_layer.conv.dilation[0],
)

return num_frames

def receptive_field_size(self, num_frames: int = 1) -> int:
"""Compute receptive field size
Parameters
----------
num_frames : int, optional
Number of frames in the output signal
Returns
-------
receptive_field_size : int
Receptive field size
"""

receptive_field_size = num_frames
for conv_layer in reversed(self.wav2vec.feature_extractor.conv_layers):
receptive_field_size = conv1d_receptive_field_size(
num_frames=receptive_field_size,
kernel_size=conv_layer.kernel_size,
stride=conv_layer.stride,
padding=conv_layer.conv.padding[0],
dilation=conv_layer.conv.dilation[0],
)

return receptive_field_size

def receptive_field(self) -> SlidingWindow:
"""Compute receptive field
Returns
-------
receptive field : SlidingWindow
Source
------
https://distill.pub/2019/computing-receptive-fields/
"""

# duration of the receptive field of each output frame
duration = self.receptive_field_size() / self.hparams.sample_rate

# step between the receptive field region of two consecutive output frames
step = (
self.receptive_field_size(num_frames=2)
- self.receptive_field_size(num_frames=1)
) / self.hparams.sample_rate

return SlidingWindow(start=0.0, duration=duration, step=step)

def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
"""Pass forward
Expand Down
77 changes: 77 additions & 0 deletions pyannote/audio/models/segmentation/debug.py
Expand Up @@ -26,6 +26,7 @@
import torch
import torch.nn as nn
from einops import rearrange
from pyannote.core import SlidingWindow
from torchaudio.transforms import MFCC

from pyannote.audio.core.model import Model
Expand Down Expand Up @@ -57,6 +58,82 @@ def __init__(
bidirectional=True,
)

def num_frames(self, num_samples: int) -> int:
"""Compute number of output frames for a given number of input samples
Parameters
----------
num_samples : int
Number of input samples
Returns
-------
num_frames : int
Number of output frames
Source
------
https://pytorch.org/docs/stable/generated/torch.stft.html#torch.stft
"""

hop_length = self.mfcc.MelSpectrogram.spectrogram.hop_length
n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft
center = self.mfcc.MelSpectrogram.spectrogram.center
return (
1 + num_samples // hop_length
if center
else 1 + (num_samples - n_fft) // hop_length
)

def receptive_field_size(self, num_frames: int = 1) -> int:
"""Compute receptive field size
Parameters
----------
num_frames : int, optional
Number of frames in the output signal
Returns
-------
receptive_field_size : int
Receptive field size
"""

hop_length = self.mfcc.MelSpectrogram.spectrogram.hop_length
n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft
center = self.mfcc.MelSpectrogram.spectrogram.center

if center:
return (num_frames - 1) * hop_length
else:
return (num_frames - 1) * hop_length + n_fft

def receptive_field(self) -> SlidingWindow:
"""Compute receptive field
Returns
-------
receptive field : SlidingWindow
Source
------
https://distill.pub/2019/computing-receptive-fields/
"""

# duration of the receptive field of each output frame
duration = (
self.mfcc.MelSpectrogram.spectrogram.win_length / self.hparams.sample_rate
)

# step between the receptive field region of two consecutive output frames
step = (
self.mfcc.MelSpectrogram.spectrogram.hop_length / self.hparams.sample_rate
)

return SlidingWindow(start=0.0, duration=duration, step=step)

def build(self):
# define task-dependent layers

Expand Down

0 comments on commit 66dd72b

Please sign in to comment.