In [10]:
from IPython.display import Audio

import torch
import torch.nn as nn
import torchaudio
from utils import *
from feature_encoder import ConvFeatureExtractionModel
import math

# Seed torch's global RNG so the random weight initialisation in later cells is repeatable.
torch.manual_seed(0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Load the Gould WTC training clips, downmix each to mono, normalise, and stack
# them into a single (n_clips, n_samples) tensor `x`. torch.cat requires every
# clip to have the same number of samples.
N_CLIPS = 10

clips = []
clip_sample_rates = set()
for i in range(1, N_CLIPS + 1):
    path = f"data/mp3_train_files/Gould/Gould - WTC_clip_{i}.mp3"
    waveform, sample_rate = torchaudio.load(path)
    clip_sample_rates.add(sample_rate)
    # torchaudio.load returns (channels, frames); average the channels but keep
    # the axis so each clip is (1, frames) and can be concatenated on dim 0.
    waveform = waveform.mean(dim=0, keepdim=True)
    clips.append(normalize_tensor(waveform))

# `sample_rate` is reused by later cells (sample -> ms conversion), so it must be
# valid for every clip, not just the last one loaded.
assert len(clip_sample_rates) == 1, f"clips have mixed sample rates: {sorted(clip_sample_rates)}"

x = torch.cat(clips)
x.shape

torch.Size([10, 220500])

In [11]:
#plot_waveform(x, sample_rate)

In [4]:
# Channel width used by every conv layer below. Past the first layer, conv params
# grow roughly quadratically with `dim` (doubling dim ~quadruples params).
dim = 512

# Layer spec handed to ConvFeatureExtractionModel. Tuples are presumably
# (channels, kernel_size, stride) — NOTE(review): confirm against feature_encoder.
# Built from `dim` so that changing `dim` actually changes the model.
conv_feature_layers = [(dim, 10, 5)] + [(dim, 3, 2)] * 5 + [(dim, 2, 2)] * 2
receptive_field = calculate_receptive_field(conv_feature_layers)
print(f"{receptive_field=} samples")
# sample_rate / 1000 = samples per millisecond.
print(f"{round(receptive_field / (sample_rate / 1000), 3)} ms")

receptive_field=800 samples
18.141 ms


In [5]:
# Instantiate the feature encoder and sanity-check its parameter count and I/O shapes.
conv = ConvFeatureExtractionModel(conv_feature_layers)
print('num params: ', conv.params())
print('input shape=', x.shape)
# Shape check only: disable autograd so the full batch's intermediate activations
# are not all retained for a backward pass that never happens.
with torch.no_grad():
    print('output shape= ', conv(x).shape)

num params:  4986880
input shape= torch.Size([10, 220500])
output shape=  torch.Size([10, 512, 344])


In [9]:
class SamePad(nn.Module):
    """Trim trailing steps along axis 2 of a 3-D tensor after a padded Conv1d.

    Presumably the input is (batch, channels, time) — TODO confirm at call sites.

    * non-causal: a conv padded by ``kernel_size // 2`` on both sides produces one
      extra frame when ``kernel_size`` is even, so one frame is removed; odd
      kernels need no trimming.
    * causal: ``kernel_size - 1`` trailing frames are removed (this matches a
      conv padded by ``kernel_size - 1`` — confirm the caller pads that way).

    Args:
        kernel_size: kernel size of the preceding convolution.
        causal: choose the causal trimming rule instead of the symmetric one.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        # Number of trailing frames to drop from the last axis.
        self.remove = kernel_size - 1 if causal else int(kernel_size % 2 == 0)

    def forward(self, x):
        # Nothing to trim: hand back the input object unchanged.
        if self.remove <= 0:
            return x
        return x[:, :, : -self.remove]


def make_conv_pos(e, k, g):
    """Build a weight-normalised grouped Conv1d -> SamePad -> GELU stack.

    The name suggests a convolutional positional-encoding layer; the recipe
    appears to mirror fairseq's wav2vec 2.0 ``make_conv_pos`` — NOTE(review):
    confirm intended use.

    Args:
        e: number of input and output channels of the convolution.
        k: kernel size. The conv pads ``k // 2`` on each side and SamePad(k)
            trims the extra trailing frame that an even ``k`` produces.
        g: number of convolution groups (nn.Conv1d requires it to divide ``e``).

    Returns:
        nn.Sequential(weight-normed Conv1d, SamePad(k), GELU()).
    """
    conv = nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g)

    # Init std = sqrt(4 * (1 - p) / (k * e)); the dropout term p is pinned to 0,
    # so the factor is 1. Bias starts at zero.
    p_drop = 0
    weight_std = math.sqrt((4 * (1.0 - p_drop)) / (k * e))
    nn.init.normal_(conv.weight, mean=0, std=weight_std)
    nn.init.constant_(conv.bias, 0)

    # Weight norm with dim=2: one learned magnitude per kernel tap of the
    # (out, in/groups, k) weight.
    # NOTE(review): torch.nn.utils.weight_norm is deprecated in favour of
    # torch.nn.utils.parametrizations.weight_norm, but switching renames the
    # state_dict entries (weight_g / weight_v), breaking existing checkpoints.
    return nn.Sequential(
        nn.utils.weight_norm(conv, name="weight", dim=2),
        SamePad(k),
        nn.GELU(),
    )

In [12]:
# Scratch layer spec; the expression below reads back the first field of its last entry.
# NOTE(review): this list has 4 copies of (512, 3, 2) whereas `conv_feature_layers`
# has 5 — confirm which spec is intended before relying on `d`.
d = [(512, 10, 5), *[(512, 3, 2)] * 4, *[(512, 2, 2)] * 2]

d[-1][0]

512