In [None]:
from google.colab import files
# Interactively upload local files into the Colab working directory. The recorded
# output shows resample.py and utils.py being saved; they provide the `resample`
# and `utils` modules imported by the model cell. `uploaded` keeps whatever
# files.upload() returns.
uploaded = files.upload()

Saving resample.py to resample.py
Saving utils.py to utils.py


In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive.
# NOTE(review): none of the cells shown here read from /content/drive —
# presumably used later for data or checkpoints; confirm or drop this cell.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import math
import time
import torch as th
from torch import nn
from torch.nn import functional as F
from resample import downsample2, upsample2
from utils import capture_init

#### Long Short-Term Memory (LSTM) Recurrent Neural Network
- This LSTM will later be swapped for other recurrent architectures, such as GRUs, bidirectional RNNs (BRNNs), echo state networks (ESNs), and LSTMs with peephole connections, among others.
- The goal is to experiment with all of them and to try building other architectures for training the model.

In [None]:
class BLSTM(nn.Module):
  """Stacked, optionally bidirectional LSTM whose output feature size equals `dim`.

  A bidirectional LSTM concatenates both directions (2 * dim features per
  step), so in that case a linear layer projects the output back to `dim`.

  Args:
    dim: input feature size and LSTM hidden size.
    layers: number of stacked LSTM layers.
    bi: whether the LSTM is bidirectional.
  """

  def __init__(self, dim, layers=2, bi=True):
    super().__init__()
    self.lstm = nn.LSTM(bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim)
    self.linear = None
    if bi:
      # Project the concatenated forward/backward features (2 * dim) back to dim.
      self.linear = nn.Linear(2 * dim, dim)

  # `forward` must be a method of the class (not nested inside __init__),
  # otherwise nn.Module.__call__ has no forward to dispatch to.
  def forward(self, x, hidden=None):
    """Run the LSTM, and the projection when bidirectional, over `x`.

    Args:
      x: input of shape (time, batch, dim); nn.LSTM is built with the default
        batch_first=False.
      hidden: optional initial (h_0, c_0) state forwarded to nn.LSTM.

    Returns:
      (output, hidden): output has shape (time, batch, dim); hidden is the
      final (h_n, c_n) state returned by nn.LSTM.
    """
    x, hidden = self.lstm(x, hidden)
    if self.linear is not None:
      x = self.linear(x)
    return x, hidden

In [None]:

def rescale_conv(conv, reference):
  """Rescale a convolution's weight (and bias, if any) in place.

  With `std` the standard deviation of the weight tensor, weight and bias are
  both divided by sqrt(std / reference), which moves the weight std to
  sqrt(std * reference), i.e. toward `reference`.

  Args:
    conv: module exposing a `weight` tensor and a `bias` that may be None.
    reference: target scale for the weight standard deviation.
  """
  # Compute the factor once, before any parameter is modified.
  factor = (conv.weight.std().detach() / reference) ** 0.5
  targets = [conv.weight]
  if conv.bias is not None:
    targets.append(conv.bias)
  for param in targets:
    param.data /= factor

def rescale_module(module, reference):
  """Apply `rescale_conv` with `reference` to every Conv1d / ConvTranspose1d in `module`.

  `module.modules()` yields `module` itself plus all of its descendants, so a
  convolution passed in directly is rescaled as well.
  """
  conv_types = (nn.Conv1d, nn.ConvTranspose1d)
  for conv in filter(lambda sub: isinstance(sub, conv_types), module.modules()):
    rescale_conv(conv, reference)

#### Demucs Architecture

In [None]:
class Demucs(nn.Module):
  """Demucs waveform model: strided-conv encoder, LSTM bottleneck, transposed-conv decoder.

  The input is optionally normalized by the std of its mono mix, padded to
  `valid_length`, upsampled by `resample`, encoded by `depth` layers, passed
  through an LSTM, and decoded by `depth` mirrored layers. Each encoder output
  is added back before the matching decoder layer (U-Net style skips). The
  result is downsampled and cropped to the input length.

  Args:
    chin: number of input audio channels.
    chout: number of output audio channels.
    hidden: channels of the first encoder layer.
    depth: number of encoder / decoder layers.
    kernel_size: kernel size of each strided (transposed) convolution.
    stride: stride of each strided (transposed) convolution.
    causal: if True the LSTM is unidirectional, otherwise bidirectional.
    resample: internal up/down-sampling factor; must be 1, 2 or 4.
    growth: channel multiplier applied after each layer.
    max_hidden: upper bound on the channel count of any layer.
    normalize: divide the input by the std of its mono mix and multiply the
      output back by it.
    glu: use GLU (with doubled 1x1-conv output channels) instead of ReLU after
      the 1x1 convolutions.
    rescale: if truthy, the reference std passed to `rescale_module` at init.
    floor: added to the std before dividing, to avoid division by zero.
    sample_rate: stored on the instance; not used by this class's code.

  Raises:
    ValueError: if `resample` is not 1, 2 or 4.
  """

  @capture_init
  def __init__(self, chin=1, chout=1, hidden=48, depth=5, kernel_size=8, stride=4, causal=True, resample=4, growth=2, max_hidden=10_1000,
               normalize=True, glu=True, rescale=0.1, floor=1e-3, sample_rate=16_000):
    # NOTE(review): `max_hidden=10_1000` equals 101_000; the digit grouping
    # suggests 10_000 was intended. Left unchanged because it is a public
    # default — confirm and fix deliberately.
    super().__init__()
    if resample not in [1, 2, 4]:
      raise ValueError("Resample should be 1, 2, or 4.")

    self.chin = chin
    self.chout = chout
    self.hidden = hidden
    self.depth = depth
    self.kernel_size = kernel_size
    self.stride = stride
    self.causal = causal
    self.floor = floor
    # Must be spelled `resample`: valid_length, total_stride and forward all
    # read `self.resample` (a misspelled attribute made them raise AttributeError).
    self.resample = resample
    self.normalize = normalize
    self.sample_rate = sample_rate

    self.encoder = nn.ModuleList()
    self.decoder = nn.ModuleList()
    # GLU halves its input along dim 1, so the 1x1 convs emit twice the channels
    # when it is used. GLU/ReLU have no parameters, so one shared instance is fine.
    activation = nn.GLU(1) if glu else nn.ReLU()
    ch_scale = 2 if glu else 1

    for index in range(depth):
      # Encoder layer: strided conv -> ReLU -> 1x1 conv -> activation.
      encode = [
          nn.Conv1d(chin, hidden, kernel_size, stride),
          nn.ReLU(),
          nn.Conv1d(hidden, hidden * ch_scale, 1), activation,
      ]
      self.encoder.append(nn.Sequential(*encode))

      # Mirrored decoder layer: 1x1 conv -> activation -> transposed conv.
      # The index-0 layer runs last and emits the model output, so it gets no ReLU.
      decode = [
          nn.Conv1d(hidden, ch_scale * hidden, 1), activation,
          nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
      ]
      if index > 0:
        decode.append(nn.ReLU())
      # Prepend so the decoder list runs from the deepest layer outward.
      self.decoder.insert(0, nn.Sequential(*decode))
      chout = hidden
      chin = hidden
      hidden = min(int(growth * hidden), max_hidden)

    # After the loop, `chin` is the channel count of the deepest encoder output.
    self.lstm = BLSTM(chin, bi=not causal)
    if rescale:
      rescale_module(self, reference=rescale)

  def valid_length(self, length):
    """Return the length that `forward` pads an input of `length` samples to.

    The length is scaled by `resample` and pushed through the `depth` encoder
    convolutions, rounding up so no samples are left over. It is then pushed
    back through the transposed convolutions and scaled down again. The result
    is never shorter than `length`.

    Args:
      length: number of input samples.

    Returns:
      The padded length, as an int.
    """
    length = math.ceil(length * self.resample)
    for _ in range(self.depth):
      length = math.ceil((length - self.kernel_size) / self.stride) + 1
      length = max(length, 1)
    for _ in range(self.depth):
      length = (length - 1) * self.stride + self.kernel_size
    length = int(math.ceil(length / self.resample))
    return int(length)

  @property
  def total_stride(self):
    """Overall encoder stride in input samples: stride ** depth // resample."""
    return self.stride ** self.depth // self.resample

  def forward(self, mix):
    """Run the model on a batch of waveforms.

    Args:
      mix: tensor of shape (batch, channels, time); a 2-D (batch, time) input
        gets a channel dimension inserted at dim 1.

    Returns:
      The decoded signal, cropped to the input's time length. When `normalize`
      is set, it is multiplied back by the input std.
    """
    if mix.dim() == 2:
      mix = mix.unsqueeze(1)

    if self.normalize:
      mono = mix.mean(dim=1, keepdim=True)
      std = mono.std(dim=-1, keepdim=True)
      mix = mix / (self.floor + std)
    else:
      std = 1
    length = mix.shape[-1]
    x = mix
    # Pad on the right so every strided conv consumes the whole signal.
    x = F.pad(x, (0, self.valid_length(length) - length))
    if self.resample == 2:
      x = upsample2(x)
    elif self.resample == 4:
      x = upsample2(x)
      x = upsample2(x)
    skips = []

    for encode in self.encoder:
      x = encode(x)
      skips.append(x)
    # (batch, channels, time) -> (time, batch, channels), the nn.LSTM default layout.
    x = x.permute(2, 0, 1)
    x, _ = self.lstm(x)
    x = x.permute(1, 2, 0)
    for decode in self.decoder:
      # Add the matching encoder output, trimmed to the current time length.
      skip = skips.pop(-1)
      x = x + skip[..., :x.shape[-1]]
      x = decode(x)
    if self.resample == 2:
      x = downsample2(x)
    elif self.resample == 4:
      x = downsample2(x)
      x = downsample2(x)

    # Drop the padding added for valid_length.
    x = x[..., :length]
    return std * x


#### DemucsStreamer

In [None]:
class DemucsStreamer:
  """Streaming wrapper around a `Demucs` model.

  The constructor appears incomplete: it stores the model, `dry` and
  `resample_lookahead`, and initializes empty LSTM / conv state slots.
  """
  def __init__(self, demucs, dry=0, num_frames=1, resample_lookahead=64, resample_buffer=256):
    # Device of the model's first parameter.
    device = next(iter(demucs.parameters())).device
    self.demucs = demucs
    # Recurrent and convolutional state; None here, presumably filled in during streaming.
    self.lstm_state = None
    self.conv_state = None
    self.dry = dry
    self.resample_lookahead = resample_lookahead
    # Clamp the buffer so it never exceeds the model's total stride.
    # NOTE(review): `device`, `num_frames` and the clamped `resample_buffer` are
    # never stored or used in the code shown — the constructor looks unfinished; confirm.
    resample_buffer = min(demucs.total_stride, resample_buffer)
   ############################################################################

