In [3]:
import torch

def random_masking(x, mask_ratio):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    """
    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))
    
    noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
    
    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore

In [4]:
x = torch.rand(2, 10, 16)
x_masked, mask, ids_restore = random_masking(x, 0.2)
print(x_masked.shape)
print(mask.shape)
print(ids_restore.shape)

torch.Size([2, 8, 16])
torch.Size([2, 10])
torch.Size([2, 10])


In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class RandomProjectionVectorQuantizer(nn.Module):
    DIST_FN_LIST = ["l2", "cosine"]
    def __init__(
        self,
        feat_dim: int,
        hidden_dim: int,
        num_classes: int,
        num_books: int,
        dist_fn: str = "cosine",
        time_first: bool = False,
        freeze: bool = True,
    ):
        """Vector quantization using random projection

         Args:
            dim: input dimension (channels)
            num_classes: number of quantized vectors per group
            num_groups: number of codebooks to use
            vq_dim: dimensionality of the resulting quantized vector
            time_first: if true, expect input in BxTxC format, otherwise in BxCxT
            activation: what activation to use (should be a module).
        
        """
        super().__init__()

        if dist_fn not in self.DIST_FN_LIST:
            raise ValueError(f"Unknown distance function {dist_fn}, must be one of {self.DIST_FN_LIST}")

        self.feat_dim = feat_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_books = num_books
        self.dist_fn = dist_fn
        self.time_first = time_first

        # (B, T, D) -> (B, T, num_books, hidden_dim)
        self.proj = nn.Linear(self.feat_dim, self.num_books*self.hidden_dim, bias=False).requires_grad_(not freeze)
        torch.nn.init.xavier_normal_(self.proj.weight)
    
        # (num_books, num_classes, hid_dim)
        codebooks = nn.Parameter(torch.FloatTensor(self.num_books, self.num_classes, self.hidden_dim)).requires_grad_(not freeze)
        torch.nn.init.normal_(codebooks, mean=0, std=1)
        self.codebooks = F.normalize(codebooks, dim=-1)
        
    def forward(self, x):
        """
        Args:
            x: (B, T, D) or (B, D, T)
        Returns:
            xq: (B, N, T, D) or (B, N, D, T)
            xid: (B, N, T)
        """
        if not self.time_first:
            # (B, D, T) -> (B, T, D)
            x = x.transpose(1, 2)

        B, T, _ = x.size()

        # (B, T, D) -> (B, T, num_books*hidden_dim)
        x = self.proj(x)

        # (B, T, num_books*hidden_dim) -> (B, T, num_books, hidden_dim)
        x = F.normalize(x.view(B, T, self.num_books, self.hidden_dim), dim=-1)

        # get tokens (xid) of shape (B, T, num_books)
        if self.dist_fn == "cosine":
            # (B, T, num_books, hidden_dim) -> (B, T, num_books, num_classes)
            xid = torch.einsum('btdh,dch->btdc', x, self.codebooks)
            # (B, T, num_books, num_classes) -> (B, T, num_books)
            xid = xid.max(dim=-1)[1]
        elif self.dist_fn == "l2":
            # (B, T, num_books, hidden_dim) -> (B, T, num_books, hidden_dim, num_classes)
            xid = x.unsqueeze(-1) - self.codebooks.transpose(1,2).unsqueeze(0).unsqueeze(0)
            xid = xid.norm(dim=-2).argmin(dim=-1)
        else:
            raise ValueError(f"Unknown distance function {self.dist_fn}, must be one of {self.DIST_FN_LIST}")
        
        
        # xid2: (B, T, num_books) -> (B, T, num_books)
        xid2 = xid + self.num_classes*torch.arange(self.num_books, device=xid.device).unsqueeze(0).unsqueeze(0)
        # xid2: (B, T, num_books) -> (B*num_books, T)
        xid2 = xid2.transpose(1,2).contiguous().view(-1, T)
        
        # get quantized vector (xq) of shape (B, T, hidden_dim, num_books)
        # codebook: (num_books, num_classes, hidden_dim) -> (num_books*num_classes, hidden_dim)
        xq = F.embedding(xid2.view(-1), self.codebooks.view(-1, self.hidden_dim)).view(B, T, self.hidden_dim, self.num_books)
    
        if not self.time_first:
            # (B, T, D) -> (B, D, T)
            xq = xq.transpose(1, 2)
        return xq, xid

quantizer = RandomProjectionVectorQuantizer(16, 8, 6, 3, time_first=True)
x = torch.rand(2, 5, 16)
xid = quantizer(x)
print(xid.shape)
# quantizer = RandomProjectionVectorQuantizer(16, 8, 6, 3, time_first=True, dist_fn="l2")
# x = torch.rand(2, 10, 16)
# xid = quantizer(x)
# print(xid.shape)

torch.Size([2, 5, 8, 3])
tensor(0.) tensor(25.1454)
torch.Size([2, 5, 8, 3])
torch.Size([2, 5, 3])


In [47]:
t1 = torch.zeros(2, 4, 3)
t2 = torch.arange(3, device=xid.device).unsqueeze(0).unsqueeze(0)
print(t1)
print(t1+t2)

tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
tensor([[[0., 1., 2.],
         [0., 1., 2.],
         [0., 1., 2.],
         [0., 1., 2.]],

        [[0., 1., 2.],
         [0., 1., 2.],
         [0., 1., 2.],
         [0., 1., 2.]]])


In [56]:
class MLMLoss(nn.Module):
    def __init__(self, combine_time_steps: int = 1, mask_threshold: float = 0.8,):
        super().__init__()
        self.nll_loss = nn.NLLLoss()
        self.combine_time_steps = combine_time_steps
        self.mask_threshold = mask_threshold

    def forward(self, spec_masks, decoder_outputs, targets, decoder_lengths=None, target_lengths=None, masks=None):

        if masks is None:
            masks = spec_masks

        # B,D,T -> B,T,D
        masks = masks.transpose(1, 2)

        masks = masks.reshape(masks.shape[0], masks.shape[1] // self.combine_time_steps, -1)
        masks = masks.mean(-1) > self.mask_threshold

        out_masked_only = decoder_outputs[masks]
        targets = F.pad(targets, (0, masks.shape[-1] - targets.shape[-1]))
        targets_masked_only = targets[masks]

        loss = self.nll_loss(out_masked_only, targets_masked_only)
        loss = torch.mean(loss)

        return loss



In [60]:
t1 = torch.arange(12).reshape(2, 6)
m1 = torch.tensor([[0, 1, 1, 0, 1, 0], [1, 1, 0, 1, 0, 1]]).bool()
t2 = t1[m1]
t2.shape

torch.Size([7])

In [1]:
import os
from pathlib import Path

def process_audio_file_pathlib(audio_file, manifest_file, data_dir=None):
    audio_file = Path(audio_file)

    if (len(str(audio_file)) < 255) and not audio_file.is_absolute() and not audio_file.is_file():
        # If audio_file is not available and the path is not absolute, the full path is assumed
        # to be relative to the manifest file parent directory or data directory.

        # resolve the data directory
        if data_dir is None:
            data_dir = Path(manifest_file).parent.as_posix()

        # assume audio_file path is relative to data_dir
        audio_file_path = Path(data_dir, audio_file)
        return audio_file_path.as_posix()
    return audio_file.as_posix()

def process_audio_file_os(audio_file, manifest_file, data_dir=None):
    
    if (len(str(audio_file)) < 255) and not os.path.isabs(audio_file) and not os.path.exists(audio_file):
        # If audio_file is not available and the path is not absolute, the full path is assumed
        # to be relative to the manifest file parent directory or data directory.

        # resolve the data directory
        if data_dir is None:
            data_dir = os.path.dirname(manifest_file)
            
        # assume audio_file path is relative to data_dir
        audio_file_path = os.path.join(data_dir, audio_file)
        return audio_file_path
    return audio_file

In [2]:

def test_pathlib(num=50000):
    for _ in range(num):
        process_audio_file_pathlib("audios/test.wav", "/a/b/c/d/test.json")

def test_os_lib(num=50000):
    for _ in range(num):
        process_audio_file_os("audios/test.wav", "/a/b/c/d/test.json")


In [5]:
%timeit test_pathlib(500000)

6.96 s ± 24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
%timeit test_os_lib(500000)

1.64 s ± 9.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:

def test_pathlib2(num=50000):
    for _ in range(num):
        process_audio_file_pathlib("/a/audios/test.wav", "/a/b/c/d/test.json")

def test_os_lib2(num=50000):
    for _ in range(num):
        process_audio_file_os("/a/audios/test.wav", "/a/b/c/d/test.json")

In [8]:
%timeit test_pathlib2(500000)

1.95 s ± 11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%timeit test_os_lib2(500000)

405 ms ± 1.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
