In [1]:
import numpy as np
import pandas as pd

import torch

from torch import nn
from torch.nn import functional as F

In [5]:
# Integer ramp 0..999 used by the scratch cells that follow.
a = np.arange(0, 1000)

In [8]:
# Average difference between elements that sit 10 positions apart in `a`.
np.mean(a[10:] - a[:-10])

10.0

In [None]:
# `.sort()` on an ndarray or a list sorts in place and returns None, so this
# cell displays nothing. NOTE(review): `a` is rebound to a plain list in a
# later cell, so which object gets sorted depends on execution order.
a.sort()

In [11]:
# Ten integers drawn from [0, 5) with the legacy global RNG (unseeded, so the
# displayed values change from run to run).
np.random.randint(0, 5, size=10)

array([0, 0, 0, 3, 1, 3, 4, 4, 3, 1])

In [104]:
# 128-wide trailing mean of a 0..999 ramp; the first 127 entries have no full
# window (NaN) and are replaced by 0.
(
    pd.Series(np.arange(1000))
    .rolling(window=128)
    .mean()
    .fillna(value=0)
)

0        0.0
1        0.0
2        0.0
3        0.0
4        0.0
       ...  
995    931.5
996    932.5
997    933.5
998    934.5
999    935.5
Length: 1000, dtype: float64

In [110]:
# Ten zeros followed by ten ones as a single 1-D array.
np.concatenate((np.zeros(10), np.ones(10)))

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1.])

In [277]:
def local_shuffle(values, keys, window, seed):
    """Reorder ``values`` by sorting on randomly jittered copies of ``keys``.

    Each key is perturbed by a uniform offset in ``[-kam, kam]``, where ``kam``
    is the key spread across the surrounding window, and ``values`` are then
    returned in the order of the perturbed keys.

    Args:
        values: Items to reorder, paired positionally with ``keys``. They may
            be any objects; they are never compared with each other.
        keys: 1-D sequence of numeric keys. The jitter range is built from key
            differences, which must be non-negative for ``Generator.uniform``,
            so keys are expected in non-decreasing order.
        window: Half-width, in elements, of the neighbourhood used to size the
            jitter. Must satisfy ``1 <= window`` and ``2 * window <= len(keys)``.
        seed: Seed for the PCG64 bit generator, making the result reproducible.

    Returns:
        A list with the items of ``values`` in shuffled order. As with ``zip``,
        only the first ``min(len(values), len(keys))`` pairs take part.

    Raises:
        ValueError: If ``window`` is incompatible with the number of keys.
    """
    g = np.random.Generator(np.random.PCG64(seed))

    keys = np.asarray(keys)
    n = len(keys)
    if n == 0:
        return []
    if window < 1 or n < 2 * window:
        raise ValueError(
            f"window must satisfy 1 <= window <= len(keys) // 2; "
            f"got window={window} for {n} keys"
        )

    # Per-element key spread: a centred span of 2 * window elements in the
    # interior, and a one-sided span of `window` elements at both ends.
    kam_middle = keys[2 * window:] - keys[:-(2 * window)]
    kam_prefix = keys[window: 2 * window] - keys[:window]
    kam_suffix = keys[-window:] - keys[-2 * window: -window]
    kam = np.hstack([kam_prefix, kam_middle, kam_suffix])
    new_keys = keys + g.uniform(-kam, kam, size=keys.shape)

    # Sort on the jittered keys only. Sorting (key, value) tuples would compare
    # the values themselves whenever two keys tie, which raises for unorderable
    # values and otherwise quietly orders tied items by value.
    order = np.argsort(new_keys, kind="stable")

    values = list(values)
    limit = len(values)
    # The bound check reproduces zip's truncation when `values` is shorter.
    return [values[i] for i in order if i < limit]


# Shuffle 0..999 against sorted random keys; the displayed array shows each
# value staying near its original position.
np.array(
    local_shuffle(
        np.arange(1000),
        np.sort(np.random.rand(1000)),
        window=64,
        seed=2,
    )
)

array([ 64,  86,  84,  80,  66,  97,   7,   3,  89,  93,  11,  88,   6,
       106,  28,  29, 110, 117,   0,  19,  49,   1,   8,  48,  30, 120,
        77,  20,  98,  24,  47,  63, 141,  18,  59,  12,  14, 133,  60,
        61, 111,  45,  99, 107,  35,  10,  81,   4,  21,  26,  85,  39,
         9,  15, 140,  54,  13,  36,   5,  58,  51,  43,  17, 148, 155,
       103,   2,  46,  92, 101, 121, 187,  37,  34,  67,  27,  41,  23,
       137, 184,  32,  95,  56,  68,  44,  94,  22, 144, 182, 196, 162,
        70,  50,  16,  75,  62, 194,  91,  33, 151,  25,  31, 145,  42,
       139,  38, 175, 215, 100,  57,  79,  40,  73,  53, 123, 161, 112,
        52, 183,  69,  74,  96, 129,  55, 124, 170, 169, 153, 234, 172,
       229,  78, 108, 204, 193, 206, 232, 191, 219,  72, 102, 201, 218,
       146, 163, 127, 198, 270,  65, 119, 233, 258, 230,  83, 156, 265,
       209, 138, 135, 214, 113, 122, 171, 131,  87, 276, 225, 228, 203,
       247, 249, 211,  71, 250, 237, 109, 178,  82, 244, 157, 26

In [51]:
# Sorting an already-ordered list leaves it unchanged.
a = sorted([1, 2, 3])
a

[1, 2, 3]

In [2]:
def merge(tensors, value=0, dtype=None):
    """Stack tensors of differing first-dimension length into one batch.

    Every tensor is right-padded along dim 0 with ``value`` up to the longest
    first dimension in ``tensors``; the padded tensors are then stacked along a
    new leading axis.

    Args:
        tensors: Non-empty sequence of tensors with at least one dimension.
        value: Fill value for the padded region.
        dtype: Output dtype; defaults to the dtype of ``tensors[0]``.

    Returns:
        Tensor of shape ``(len(tensors), max_len, ...)`` with dtype ``dtype``.
    """
    if dtype is None:
        dtype = tensors[0].dtype
    longest = max(t.shape[0] for t in tensors)

    def _pad_rows(t):
        # F.pad takes (left, right) pairs starting from the LAST dimension, so
        # the final entry is the right-hand padding of dim 0.
        widths = [0] * (2 * len(t.shape))
        widths[-1] = longest - t.shape[0]
        return F.pad(t, pad=widths, value=value)

    stacked = torch.stack([_pad_rows(t) for t in tensors])
    return stacked.to(dtype=dtype)

In [3]:
# Toy batch: three sequences of five token ids, plus a repeat count per token.
text = torch.tensor(
    [
        [0, 2, 0, 3, 0],
        [0, 5, 0, 6, 0],
        [0, 9, 0, 31, 31],
    ],
    dtype=torch.long,
)
dur = torch.tensor(
    [
        [0, 2, 10, 1, 17],
        [3, 1, 0, 1, 0],
        [4, 4, 4, 0, 0],
    ],
    dtype=torch.long,
)
# All-ones mask with the same shape and dtype as `text`.
mask = torch.ones_like(text)
text.shape, mask.shape, dur.shape

(torch.Size([3, 5]), torch.Size([3, 5]), torch.Size([3, 5]))

In [93]:
class PolySpanEmb(nn.Module):
    """Per-frame embedding that blends a "left" and a "right" token embedding.

    Each token is repeated ``dur`` times. Every resulting frame gets a convex
    combination of the embeddings of its left and right token ids, with the
    left weight falling linearly from ``d / (d + 1)`` to ``1 / (d + 1)`` across
    a span of ``d`` frames.

    NOTE(review): the ``1::2`` indexing suggests that odd positions hold the
    actual symbols and even positions hold separators which interpolate between
    their neighbours -- confirm against the data pipeline.
    """

    def __init__(self, n_vocab, d_emb, pad_id):
        """Create the token embedding table.

        Args:
            n_vocab: Number of token ids.
            d_emb: Embedding dimension.
            pad_id: Token id used as ``padding_idx`` of the embedding.
        """
        super().__init__()

        self._emb = nn.Embedding(n_vocab, d_emb, padding_idx=pad_id)

    def forward(self, text, mask, dur):
        """Embed ``text`` expanded to frame resolution.

        Args:
            text: Integer tensor ``(batch, n_tokens)`` of token ids.
            mask: Accepted for interface compatibility; not used here.
            dur: Integer tensor ``(batch, n_tokens)`` of non-negative repeat
                counts, one per token.

        Returns:
            Float tensor ``(batch, max_frames, d_emb)``, where ``max_frames`` is
            the largest per-row sum of ``dur``. Frames past a row's own length
            use the padding id, so they are all-zero while the padding row of
            the embedding is zero.
        """
        # Pad short rows with the padding id rather than 0: id 0 is an ordinary
        # vocabulary entry and would leak its embedding into the padded frames.
        pad_id = self._emb.padding_idx
        if pad_id is None:
            pad_id = 0

        lefts, rights = self._generate_sides(text)
        lefts = self._emb(self._generate_text_rep(lefts, dur, pad_value=pad_id))
        rights = self._emb(self._generate_text_rep(rights, dur, pad_value=pad_id))

        # Trailing axis so the per-frame weight broadcasts over the embedding dim.
        left_c = self._generate_left_c(dur).unsqueeze(-1)

        x = left_c * lefts + (1 - left_c) * rights

        return x

    def _generate_sides(self, text):
        """Return the left-neighbour and right-neighbour token ids per position.

        Even positions take the previous / next token (the padding id at the
        sequence edges); odd positions keep their own token on both sides.
        """
        lefts = F.pad(text[:, :-1], [1, 0, 0, 0], value=self._emb.padding_idx)
        lefts[:, 1::2] = text[:, 1::2]
        rights = F.pad(text[:, 1:], [0, 1, 0, 0], value=self._emb.padding_idx)
        rights[:, 1::2] = text[:, 1::2]

        return lefts, rights

    def _generate_text_rep(self, text, dur, pad_value=0):
        """Repeat each element of ``text`` ``dur`` times, row by row.

        Rows end up with different lengths, so they are right-padded with
        ``pad_value`` to the longest row and stacked into one tensor.
        """
        text_rep = []
        for t, d in zip(text, dur):
            text_rep.append(torch.repeat_interleave(t, d))

        text_rep = merge(text_rep, value=pad_value)

        return text_rep

    def _generate_left_c(self, dur):
        """Compute the per-frame weight of the left embedding.

        For the k-th frame (1-based) of a span of length ``d`` the weight is
        ``1 - k / (d + 1)``. Padded frames get weight 1.
        """
        # Start frame of every token: exclusive cumulative sum of durations.
        x = F.pad(torch.cumsum(dur, dim=-1)[:, :-1], [1, 0], value=0)
        pos_cm = self._generate_text_rep(x, dur)
        # 1 on real frames, 0 on padded frames.
        mask = self._generate_text_rep(torch.ones_like(dur), dur)
        # 1-based frame index within the row.
        ones_cm = torch.cumsum(mask, dim=1)
        # Span length + 1, so the weight never reaches 0 or 1 inside a span.
        totals = self._generate_text_rep(dur, dur) + 1

        left_c = 1 - ((ones_cm - pos_cm) * mask).float() / totals

        return left_c


# Run the embedding on the toy batch: vocabulary of 32 ids, 128-dim vectors,
# id 31 as padding.
pe = PolySpanEmb(n_vocab=32, d_emb=128, pad_id=31)
x = pe(text=text, mask=mask, dur=dur)
x.shape

torch.Size([3, 30, 128])

In [6]:
# Random stand-in tensor with three axes of sizes 128, 160 and 1000.
mel = torch.rand(128, 160, 1000)
mel.shape

torch.Size([128, 160, 1000])

In [15]:
# Row-major reshape to 80 rows per item; the last axis is inferred and absorbs
# the remaining elements.
o = mel.reshape(mel.size(0), 80, -1)
o.shape

torch.Size([128, 80, 2000])

In [23]:
# Small 0/1 tensor, one row per sequence.
m = torch.tensor(
    [
        [1, 1, 0],  # mel_len was 4, becomes 2
        [1, 0, 0],  # was 2, becomes 1; was 3
    ]
)
m

tensor([[1, 1, 0],
        [1, 0, 0]])

In [25]:
# Twice the row sum, minus 1 whenever that row sum is odd.
2 * m.sum(dim=-1) - m.sum(dim=-1).remainder(2)

tensor([4, 1])

In [16]:
# All 160 entries along axis 1 for the first item at index 0 of the last axis.
mel[0][:, 0]

tensor([0.5463, 0.5883, 0.8318, 0.0917, 0.1764, 0.1711, 0.7402, 0.2401, 0.6802,
        0.4101, 0.3471, 0.1783, 0.6591, 0.5719, 0.5398, 0.4631, 0.0854, 0.0559,
        0.0923, 0.5494, 0.1166, 0.0413, 0.3319, 0.1617, 0.6404, 0.5383, 0.5452,
        0.6088, 0.1368, 0.1795, 0.8656, 0.0528, 0.3815, 0.8936, 0.5145, 0.9768,
        0.2792, 0.9253, 0.7345, 0.4433, 0.7745, 0.5315, 0.6038, 0.9462, 0.9117,
        0.1661, 0.5549, 0.9857, 0.9305, 0.9008, 0.1744, 0.0595, 0.9134, 0.3502,
        0.0490, 0.9336, 0.1684, 0.1998, 0.8651, 0.6965, 0.5803, 0.6699, 0.3815,
        0.8126, 0.9338, 0.8119, 0.9597, 0.8481, 0.9223, 0.2737, 0.5776, 0.4658,
        0.9778, 0.1504, 0.6733, 0.1040, 0.4254, 0.7694, 0.1430, 0.5904, 0.7740,
        0.8999, 0.0431, 0.1385, 0.3678, 0.4122, 0.5445, 0.4348, 0.9733, 0.4544,
        0.4975, 0.8677, 0.3433, 0.9270, 0.0162, 0.6738, 0.6966, 0.3388, 0.9226,
        0.7075, 0.8833, 0.8069, 0.8924, 0.4891, 0.7229, 0.5195, 0.8639, 0.1736,
        0.7689, 0.2291, 0.6417, 0.0925, 

In [17]:
# First two entries of the last axis of the reshaped tensor, for comparison
# with the column of `mel` shown in the previous cell.
o[0][:, :2]

tensor([[0.5463, 0.8759],
        [0.8318, 0.0649],
        [0.1764, 0.1373],
        [0.7402, 0.4410],
        [0.6802, 0.0086],
        [0.3471, 0.7125],
        [0.6591, 0.1337],
        [0.5398, 0.2723],
        [0.0854, 0.8415],
        [0.0923, 0.8433],
        [0.1166, 0.6130],
        [0.3319, 0.3897],
        [0.6404, 0.3436],
        [0.5452, 0.3287],
        [0.1368, 0.3839],
        [0.8656, 0.1534],
        [0.3815, 0.0053],
        [0.5145, 0.4779],
        [0.2792, 0.0587],
        [0.7345, 0.3611],
        [0.7745, 0.5071],
        [0.6038, 0.4110],
        [0.9117, 0.9358],
        [0.5549, 0.2507],
        [0.9305, 0.2080],
        [0.1744, 0.1966],
        [0.9134, 0.2663],
        [0.0490, 0.5377],
        [0.1684, 0.0659],
        [0.8651, 0.5655],
        [0.5803, 0.8360],
        [0.3815, 0.5139],
        [0.9338, 0.1661],
        [0.9597, 0.1514],
        [0.9223, 0.9137],
        [0.5776, 0.9467],
        [0.9778, 0.4361],
        [0.6733, 0.3399],
        [0.4