# google driveのマウント

In [None]:
# Mount Google Drive at /content/drive (Colab-only helper); the data, model
# and result paths defined later in this notebook all live under that mount.
from google.colab import drive
drive.mount('/content/drive')

# 準備

## ライブラリのインポート

In [None]:
!pip install mir_eval

Collecting mir_eval
  Downloading mir_eval-0.6.tar.gz (87 kB)
[?25l[K     |███▊                            | 10 kB 31.8 MB/s eta 0:00:01[K     |███████▌                        | 20 kB 20.9 MB/s eta 0:00:01[K     |███████████▏                    | 30 kB 16.0 MB/s eta 0:00:01[K     |███████████████                 | 40 kB 14.0 MB/s eta 0:00:01[K     |██████████████████▋             | 51 kB 8.8 MB/s eta 0:00:01[K     |██████████████████████▍         | 61 kB 8.5 MB/s eta 0:00:01[K     |██████████████████████████      | 71 kB 9.3 MB/s eta 0:00:01[K     |█████████████████████████████▉  | 81 kB 10.3 MB/s eta 0:00:01[K     |████████████████████████████████| 87 kB 4.9 MB/s 
Building wheels for collected packages: mir-eval
  Building wheel for mir-eval (setup.py) ... [?25l[?25hdone
  Created wheel for mir-eval: filename=mir_eval-0.6-py3-none-any.whl size=96515 sha256=b26730170a8a87708133d3afdfdf5f54ae861a0002141126489e34601fa2735b
  Stored in directory: /root/.cache/pip/whee

In [None]:
import os
import numpy as np
import librosa
import math
import mir_eval
from tqdm import tqdm
import xml.dom.minidom
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pickle
import matplotlib.pyplot as plt
import warnings
import sys
from collections import defaultdict
import glob

## グローバル変数

In [None]:
# Small constant added before sqrt/log to keep them finite.
# NOTE: 10e-6 equals 1e-5, not 1e-6.
EPS = 10e-6

# Sample rate passed to librosa.load for both training sources and test audio.
SR = 8000
SR_WAV = SR
# Excerpt length in seconds and the matching number of samples at SR.
DURATION = 4
NSP_SRC = SR * DURATION
# FFT size and hop (half the FFT size).
N_FFT = 1024
HOP = 1024 // 2


# Google Drive locations (require the drive to be mounted at /content/drive).
EVALDATA_PATH = "/content/drive/My Drive/data_evals"
DATA_PATH = "/content/drive/My Drive/data_drum_sources"
STEM_PATH1 = "/content/drive/My Drive/stems_synth"
STEM_PATH2 = "/content/drive/My Drive/stems_GMD"
MODEL_PATH = "/content/drive/My Drive/saved_models"
RESULT_PATH = "/content/drive/My Drive/results"

USE_CUDA = torch.cuda.is_available()

# Newly created tensors default to CPU float32.
torch.set_default_tensor_type('torch.FloatTensor')

# NOTE: DEVICE is a torch.device on GPU machines but the plain string 'cpu'
# otherwise; both forms are accepted by the `device=` arguments used below.
if USE_CUDA:
  DEVICE = torch.device('cuda:0')
else:
  DEVICE = 'cpu'


# Matching NumPy / PyTorch floating-point dtypes.
NPDTYPE = np.float32
TCDTYPE = torch.get_default_dtype()

# Drum note names; also the "<note>" part of the sample file names
# ("<kit index>)<note>.wav") read by load_drum_srcs. The order defines the
# channel order of the model's activations.
DRUM_NAMES = ["KD_KD", "SD_SD", "HH_CHH", "HH_OHH", "HH_PHH", "TT_HIT",
              "TT_MHT", "TT_HFT", "CY_RDC", "CY_RDB", "CY_CRC", "CY_CHC",
              "CY_SPC", "OT_TMB", "OT_CL", "OT_CB"]
# get_drumsets loads kits 1 .. N_DRUM_VSTS - 1 (the upper bound is exclusive).
N_DRUM_VSTS = 12

## 変数(モデル)

In [None]:
# Model
kernel_size = 3  # must be odd (padding is kernel_size // 2 in Unet)
n_channel = 50  # channel width of the U-Net and d_model of the transformer
scale_r = 2  # pooling / upsampling factor per U-Net stage
activation = "elu"
enc_layers = 10  # number of U-Net encoder (conv + pool) stages
dec_layers = 7  # number of U-Net decoder (conv + upsample) stages

tran_layers = 3  # number of transformer encoder layers
tran_heads = 5  # attention heads per layer
d_ffn = 200  # hidden width of the transformer feed-forward block
dropout = 0.1

sparsemax_lst = 32  # window length (in frames) for the time-axis sparsemax

# Training
batch_size = 4
loss_domains = ["spectrum"]  # melgram, stft, l1_reg
n_mels = 0
learning_rate = 0.00025
metrics = ["mae"]  # mse
source_norm = "sqrsum"  # no, abssum, sqrsum
l1_reg_lambda = 0.003
cqt_bins = 12  # used as both n_bins and bins_per_octave of every PseudoCqt
compare_after_hpss = False

## ドラム音のロード

In [None]:
# Load one drum kit (one sample per drum note)
def load_drum_srcs(idx=1):
  """Load kit number ``idx`` as a tensor with one waveform per DRUM_NAMES entry.

  Every note is read from ``DATA_PATH/"<idx>)<note name>.wav"`` with librosa,
  resampled to ``SR_WAV`` and truncated to at most one second.

  Args:
    idx: kit index used as the file-name prefix.

  Returns:
    ``torch.from_numpy`` of the stacked waveforms, in DRUM_NAMES order.
    Presumably shaped (len(DRUM_NAMES), n_samples) when all files yield the
    same length -- TODO confirm against the sample files.
  """
  max_seconds = 1.0
  waveforms = [
      librosa.load(os.path.join(DATA_PATH, "%d)%s.wav" % (idx, note_name)),
                   sr=SR_WAV, duration=max_seconds)[0]
      for note_name in DRUM_NAMES
  ]
  return torch.from_numpy(np.array(waveforms))


# Load the drum sounds used during training

def get_drumsets(norm):
  """Build the DrumSourceSet used for training.

  Kits ``1 .. N_DRUM_VSTS - 1`` are loaded with ``load_drum_srcs``; every
  waveform is scaled by a gain chosen by ``norm`` and then reversed in time
  (``torch.flip`` along axis 0) before being stored.

  Args:
    norm: "abssum" (gain 400 / sum|x|), "sqrsum" (gain 10 / sqrt(sum x^2)),
      anything else leaves the waveform unscaled.

  Returns:
    DrumSourceSet mapping each DRUM_NAMES entry to a tuple of waveforms.
  """
  def _gain(x):
    if norm == "abssum":
      return 400 * 1.0 / (x.abs().sum())
    if norm == "sqrsum":
      return 10 * 1.0 / x.pow(2).sum().sqrt()
    return 1

  per_note = {note_name: [] for note_name in DRUM_NAMES}
  for kit_idx in range(1, N_DRUM_VSTS):
    for waveform, note_name in zip(load_drum_srcs(kit_idx), DRUM_NAMES):
      per_note[note_name].append(_gain(waveform) * waveform)

  reversed_notes = {
      note_name: tuple(torch.flip(waveform, dims=(0,)) for waveform in waveforms)
      for note_name, waveforms in per_note.items()
  }
  return DrumSourceSet(notes=reversed_notes)


class DrumSourceSet(object):
  """Collection of drum samples grouped by note name.

  Args:
    notes: mapping ``note name -> sequence of samples``; a shallow copy is kept.
    reverse: stored as ``self.reverse``; not otherwise used by this class.
  """

  def __init__(self, notes, reverse=True):
    self.reverse = reverse
    self.notes = notes.copy()
    self.n_notes = len(notes)
    self.note_names = list(self.notes.keys())
    # Number of available samples for every note.
    self.n_vsts = {name: len(samples) for name, samples in self.notes.items()}

  def __getitem__(self, note):
    """Return all samples stored for ``note``."""
    return self.notes[note]

  def random_choice(self, note):
    """Return one sample of ``note`` picked uniformly via ``np.random.choice``."""
    picked = np.random.choice(self.n_vsts[note])
    return self.notes[note][picked]

  def __str__(self):
    summary = "n_notes: %d, n_vsts:" % self.n_notes
    return summary + str(self.n_vsts)

## モデル

In [None]:
class DrumTranModel(nn.Module):
  """Analysis-by-synthesis drum transcription model.

  Pipeline of ``forward``: U-Net features -> transformer encoder (one channel
  per drum note) -> sparsemax activations -> zero insertion (upsampling) ->
  convolution with drum samples -> sum of all tracks.

  Args:
    inst_srcs: test drum sources, kept as ``self.test_inst_srcs``.
    inst_names: names matching ``inst_srcs``, kept as ``self.test_inst_names``.
    drum_sets: DrumSourceSet providing ``n_notes`` and the training samples.
  """

  def __init__(self, inst_srcs, inst_names, drum_sets):
    super().__init__()
    self.test_inst_srcs = inst_srcs
    self.test_inst_names = inst_names
    self.drum_sets = drum_sets
    self.n_notes = self.drum_sets.n_notes

    # Submodules are created in a fixed order; it defines parameter order.
    self.unet = Unet(nn.Conv1d, nn.MaxPool1d)
    # Alternatives tried for this stage: Recurrenter, Convoluter
    self.recurrenter = TransformerEncoder(tran_layers, n_channel, tran_heads,
                                          d_ffn, dropout, self.n_notes)
    self.sparsemax_lst = sparsemax_lst
    # Alternatives tried: SequentialSparsemax, SoftSoftSeq, SoftSoftMul
    self.double_sparsemax = MultiplySparsemax(self.sparsemax_lst)
    self.zero_inserter = ZeroInserter(self.unet.sr_ratio)
    self.synthesizer = FastDrumSynthesizer(self.n_notes, self.drum_sets)
    self.mixer = Mixer()

    # Parameter-count debugging (disabled):
    # print('NUM_PARAM overall:', count_parameters(self))
    # print('             unet:', count_parameters(self.unet))
    # print('      recurrenter:', count_parameters(self.recurrenter))
    # print('       sparsemaxs:', count_parameters(self.double_sparsemax))
    # print('      synthesizer:', count_parameters(self.synthesizer))

  def forward(self, x):
    """Transcribe and re-synthesise ``x``.

    Args:
      x: waveform batch; axis 1 is treated as time and is right-padded with
        zeros to a multiple of ``self.unet.compress_ratio``.

    Returns:
      ``(x, x_hat, y_hat)``: the (possibly padded) input, the re-synthesised
      mix, and the upsampled per-note activations.
    """
    ratio = self.unet.compress_ratio
    remainder = x.shape[1] % ratio
    if remainder != 0:
      x = F.pad(x, (0, ratio - remainder))

    features = self.unet(x)
    dense_y = self.recurrenter(features)
    sparse_y = self.double_sparsemax(dense_y)
    y_hat = self.zero_inserter(sparse_y)
    tracks = self.synthesizer(y_hat)
    x_hat = self.mixer(tracks)
    return x, x_hat, y_hat

## モデルモジュール - Unet

In [None]:
class Unet(nn.Module):
  """U-Net style encoder/decoder with fewer decoder than encoder stages.

  Sizes come from the module-level hyperparameters (``n_channel``,
  ``kernel_size``, ``scale_r``, ``enc_layers``, ``dec_layers``).

  Args:
    conv: convolution layer class, called as
      ``conv(in_ch, out_ch, kernel, stride, padding, bias=...)``.
    mp: pooling layer class, called as ``mp(scale_r)``.

  Attributes:
    sr_ratio: ``scale_r ** (enc_layers - dec_layers)``.
    compress_ratio: ``scale_r ** enc_layers``; DrumTranModel pads its input
      to a multiple of this.
  """

  def __init__(self, conv, mp):
    super().__init__()
    self.n_channel = n_channel
    self.kernel_size = kernel_size
    self.padding = self.kernel_size // 2
    self.act = F.elu  # F.relu, F.leaky_relu
    self.scale_r = scale_r
    self.enc_layers = enc_layers
    self.dec_layers = dec_layers
    self.sr_ratio = self.scale_r ** (self.enc_layers - self.dec_layers)
    self.compress_ratio = self.scale_r ** self.enc_layers

    width = self.n_channel
    first_ch = min(128, width)
    stride = 1
    use_bias = False  # True

    def make_conv(in_ch, out_ch):
      return conv(in_ch, out_ch, self.kernel_size, stride, self.padding, bias=use_bias)

    # Layers are built in the same order as they are registered.
    self.d_conv0 = make_conv(1, first_ch)

    encoder = [make_conv(first_ch, width)]
    for _ in range(self.enc_layers - 1):
      encoder.append(make_conv(width, width))
    self.d_convs = nn.ModuleList(encoder)

    self.pools = nn.ModuleList([mp(self.scale_r) for _ in range(self.enc_layers)])

    # Every decoder conv after the first sees the skip connection too (2x channels).
    decoder = [make_conv(width, width)]
    for _ in range(self.dec_layers - 1):
      decoder.append(make_conv(2 * width, width))
    self.u_convs = nn.ModuleList(decoder)

    # The last conv keeps the layer's default bias setting.
    self.last_conv = conv(2 * width, width, self.kernel_size, stride, self.padding)

  def forward(self, x):
    """Return activated features for ``x`` after inserting a channel axis at dim 1."""
    hidden = self.act(self.d_conv0(torch.unsqueeze(x, 1)))

    skips = []
    for pool, conv in zip(self.pools, self.d_convs):
      hidden = pool(self.act(conv(hidden)))
      skips.append(hidden)

    y = skips.pop()
    for conv in self.u_convs:
      y = self.act(conv(y))
      # 4-D tensors (2-D convs) need "bilinear", 3-D tensors "linear".
      interp_mode = "bilinear" if y.dim() == 4 else "linear"
      y = F.interpolate(y, scale_factor=self.scale_r,
                        mode=interp_mode, align_corners=False)
      skip = skips.pop()
      y = torch.cat((y, skip), dim=1)

    return self.act(self.last_conv(y))

## モデルモジュール - Transformer

In [None]:
class TransformerEncoder(nn.Module):
  """Positional encoding + stacked encoder layers + linear projection.

  Args:
    n_layers: number of TransformerEncoderLayer blocks.
    d_model: feature width of the input.
    n_heads: attention heads per layer.
    d_feedfoward: hidden width of each feed-forward block.
    dropout: forwarded to every encoder layer.
    n_insts: number of output channels.
  """

  def __init__(self, n_layers, d_model, n_heads, d_feedfoward, dropout, n_insts):
    super().__init__()
    self.pos_encoder = PositionalEncoding(d_model)
    blocks = []
    for _ in range(n_layers):
      blocks.append(TransformerEncoderLayer(d_model, n_heads, d_feedfoward, dropout))
    self.layers = nn.ModuleList(blocks)
    self.linear = nn.Linear(d_model, n_insts)

  def forward(self, src):
    """Swap axes 1 and 2, encode, project, and swap the axes back."""
    hidden = self.pos_encoder(src.transpose(1, 2))
    for layer in self.layers:
      hidden = layer(hidden)
    return self.linear(hidden).transpose(1, 2)


class TransformerEncoderLayer(nn.Module):
  """Self-attention with a residual connection, LayerNorm, then feed-forward.

  Args:
    d_model: feature width; must be a positive multiple of ``n_heads`` so the
      concatenated heads (``n_heads * (d_model // n_heads)``) match the
      residual input.
    n_heads: number of attention heads.
    d_feedforward: hidden width of the feed-forward block.
    dropout: accepted for interface compatibility; this layer does not apply
      dropout.

  Raises:
    ValueError: if ``d_model`` is not divisible by ``n_heads``. Previously
      such a configuration only failed later, with a shape error in forward.
  """

  def __init__(self, d_model, n_heads, d_feedforward, dropout):
    super().__init__()
    if n_heads <= 0 or d_model % n_heads != 0:
      raise ValueError(
          "d_model (%d) must be a positive multiple of n_heads (%d)"
          % (d_model, n_heads))
    d_k = d_model // n_heads
    self.mul_h_attention = MultiHeadAttention(n_heads, d_model, d_k)
    self.norm = nn.LayerNorm(d_model)
    self.feedforward = nn.Sequential(nn.Linear(d_model, d_feedforward),
                                     nn.ReLU(),
                                     nn.Linear(d_feedforward, d_model))

  def forward(self, src):
    src = src + self.mul_h_attention(src)
    src = self.norm(src)
    # No residual connection around the feed-forward block.
    src = self.feedforward(src)
    return src


class MultiHeadAttention(nn.Module):
  """Run ``n_heads`` independent AttentionHead modules and concatenate them.

  Args:
    n_heads: number of heads.
    d_model: input feature width.
    d_k: output width of every head.
  """

  def __init__(self, n_heads, d_model, d_k):
    super().__init__()
    head_modules = []
    for _ in range(n_heads):
      head_modules.append(AttentionHead(d_model, d_k))
    self.heads = nn.ModuleList(head_modules)

  def forward(self, src):
    """Return the head outputs concatenated along the last axis."""
    per_head = [head(src) for head in self.heads]
    return torch.cat(per_head, dim=-1)


class AttentionHead(nn.Module):
  """Single scaled dot-product self-attention head.

  Args:
    d_model: input feature width.
    d_k: width of the query/key/value projections (and of the output).
  """

  def __init__(self, d_model, d_k):
    super().__init__()
    self.q = nn.Linear(d_model, d_k)
    self.k = nn.Linear(d_model, d_k)
    self.v = nn.Linear(d_model, d_k)
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, src):
    """Attend over axis 1 of a 3-D ``src`` (``bmm`` requires 3-D tensors)."""
    queries = self.q(src)
    keys = self.k(src)
    values = self.v(src)
    scores = torch.bmm(queries, keys.transpose(1, 2))
    # Scale by sqrt(d_k) before the softmax.
    weights = self.softmax(scores / (queries.size(-1) ** 0.5))
    return torch.bmm(weights, values)


class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal position codes to a sequence, then apply dropout.

  Args:
    d_model: feature width of the input (last axis).
    max_len: number of positions precomputed; inputs longer than this along
      axis 1 cannot be encoded.
    dropout: dropout probability applied after adding the codes.
  """

  def __init__(self, d_model, max_len=100000, dropout=0.1):
    super().__init__()
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0)/d_model))
    angles = position * div_term
    pe = torch.zeros(1, max_len, d_model)
    pe[0, :, 0::2] = torch.sin(angles)
    # For odd d_model the odd-index slice has one column fewer than div_term;
    # truncating keeps even widths unchanged and makes odd widths work.
    pe[0, :, 1::2] = torch.cos(angles)[:, :d_model // 2]
    self.register_buffer("pe", pe)
    self.dropout = nn.Dropout(p=dropout)

  def forward(self, x):
    """Return ``dropout(x + pe[:, :x.size(1)])``; axis 1 of ``x`` is the position axis."""
    x = x + self.pe[:, :x.size(1), :]
    x = self.dropout(x)
    return x

## モデルモジュール - sparsemax

In [None]:
class MultiplySparsemax(nn.Module):
  """Multiply a sparsemax across instruments by a windowed normalisation over time.

  Channels 2..4 (the three "HH_*" entries of DRUM_NAMES) use a softmax inside
  each time window; every other channel uses sparsemax.

  Args:
    sparsemax_lst: window length along the time axis.
  """

  def __init__(self, sparsemax_lst=64):
    super().__init__()
    self.lst = sparsemax_lst
    self.sparsemax_inst = Sparsemax(dim=-1)
    self.sparsemax_time = Sparsemax(dim=-1)
    self.softmax_time = nn.Softmax(dim=-1)

  def forward(self, midis_out):
    """Map dense activations ``(batch, n_insts, time)`` to sparse ones of the same shape."""
    batch, n_insts, n_frames = midis_out.shape
    window = self.lst
    len_pad = (window - n_frames % window) % window
    padded_len = n_frames + len_pad
    n_windows = padded_len // window
    n_others = n_insts - 3

    # Right-pad with zeros so the time axis splits into whole windows.
    midis_out = F.pad(midis_out, [0, len_pad])

    # Competition between instruments at every frame.
    across_insts = self.sparsemax_inst(midis_out.transpose(1, 2)).transpose(1, 2)

    # Hi-hat channels: softmax inside every window.
    hh = midis_out[:, 2:5, :].reshape(batch, 3, n_windows, window)
    hh_time = self.softmax_time(hh).reshape(batch, 3, padded_len)

    # Remaining channels: sparsemax inside every window.
    others = torch.cat((midis_out[:, :2, :], midis_out[:, 5:, :]), 1)
    others = others.reshape(batch, n_others, n_windows, window)
    others_time = self.sparsemax_time(others).reshape(batch, n_others, padded_len)

    # Restore the original channel order.
    along_time = torch.cat((others_time[:, :2, :], hh_time, others_time[:, 2:, :]), 1)

    return across_insts[:, :, :n_frames] * along_time[:, :, :n_frames]


class Sparsemax(nn.Module):
  """Sparsemax: Euclidean projection of logits onto the probability simplex.

  Args:
    dim: axis along which the projection is applied. Any axis is supported;
      it is moved to the last position internally.
  """

  def __init__(self, dim=-1):
    super().__init__()
    self.dim = dim

  def forward(self, input):
    """Return the sparsemax of ``input`` along ``self.dim`` (same shape and dtype).

    The flattened 2-D result is also kept in ``self.output`` for ``backward``.
    """
    # Move the target axis last so flattening keeps each logit vector in one
    # row (a no-op for dim=-1).
    moved = input.transpose(self.dim, -1)
    moved_size = moved.size()
    logits = moved.contiguous().view(-1, moved_size[-1])

    dim = 1
    number_of_logits = logits.size(dim)

    # Shift by the row maximum for numerical stability.
    logits = logits - torch.max(logits, dim=dim, keepdim=True)[0].expand_as(logits)

    zs = torch.sort(input=logits, dim=dim, descending=True)[0]
    # 1..K in the input's own dtype (was: the global default dtype).
    ks = torch.arange(start=1, end=number_of_logits + 1, device=logits.device,
                      dtype=logits.dtype).view(1, -1)
    ks = ks.expand_as(zs)

    # Support size k = max{j : 1 + j * z_(j) > sum_{i<=j} z_(i)}.
    bound = 1 + ks * zs
    cumulative_sum_zs = torch.cumsum(zs, dim)
    is_gt = torch.gt(bound, cumulative_sum_zs).to(logits.dtype)
    k = torch.max(is_gt * ks, dim, keepdim=True)[0]

    zs_sparse = is_gt * zs

    # Threshold tau such that the clipped values sum to one.
    taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
    taus = taus.expand_as(logits)

    self.output = torch.max(torch.zeros_like(logits), logits - taus)

    return self.output.view(moved_size).transpose(self.dim, -1)

  def backward(self, grad_output):
    """Manual sparsemax gradient for the last ``forward`` call.

    NOTE: autograd never calls this method on an ``nn.Module``; gradients in
    training come from differentiating ``forward``. ``grad_output`` must match
    the 2-D shape of ``self.output``.
    """
    dim = 1

    nonzeros = torch.ne(self.output, 0)
    # keepdim=True so the per-row mean broadcasts over the row.
    support_mean = (torch.sum(grad_output * nonzeros, dim=dim, keepdim=True)
                    / torch.sum(nonzeros, dim=dim, keepdim=True))
    self.grad_input = nonzeros * (grad_output - support_mean.expand_as(grad_output))

    return self.grad_input

## モデルモジュール - 0インサーター

In [None]:
class ZeroInserter(nn.Module):
  """Upsample along time by inserting zeros after every sample.

  Args:
    insertion_rate: upsampling factor; each input sample is followed by
      ``insertion_rate - 1`` zeros.

  Raises:
    ValueError: if ``insertion_rate`` is smaller than 1.
  """

  def __init__(self, insertion_rate):
    super().__init__()
    if insertion_rate < 1:
      raise ValueError("insertion_rate must be >= 1, got %r" % (insertion_rate,))
    self.insertion_rate = insertion_rate

  def forward(self, downsampled_y):
    """Map ``(batch, ch, time)`` to ``(batch, ch, insertion_rate * time)``.

    All channels are processed at once: a trailing axis of length
    ``insertion_rate`` is zero-padded after each sample and then flattened.
    The zeros take the input's dtype and device.
    """
    batch, ch, time = downsampled_y.shape
    rate = self.insertion_rate
    stuffed = F.pad(downsampled_y.unsqueeze(-1), (0, rate - 1))
    return stuffed.reshape(batch, ch, rate * time)

## モデルモジュール - ドラムシンセサイザー

In [None]:
class FastDrumSynthesizer(nn.Module):
  """Render one audio track per drum note from impulse-like activations.

  Args:
    n_notes: number of activation channels to render.
    drum_sets: sample collection offering ``random_choice(note_name)``.
  """

  def __init__(self, n_notes, drum_sets):
    super().__init__()
    self.n_notes = n_notes
    self.drum_sets = drum_sets

  def forward(self, midis):
    """Convolve each channel of ``midis`` with a randomly drawn sample.

    One sample per DRUM_NAMES entry is drawn first (in that order); each is
    flipped along time before being passed to ``fast_conv1d``.

    Returns:
      The tracks concatenated along axis 1.
    """
    target_device = midis[0].device
    kit = [self.drum_sets.random_choice(note_name).to(target_device)
           for note_name in DRUM_NAMES]

    tracks = [
        fast_conv1d(midis[:, ch:ch + 1, :],
                    torch.flip(kit[ch].expand(1, 1, -1), dims=(2,)))
        for ch in range(self.n_notes)
    ]
    return torch.cat(tracks, dim=1)


# def complex_mul(t1, t2):
#   if t1.dim() != t2.dim():
#     raise ValueError('dim mismatch in complex_mul, {} and {}'.format(t1.dim(), t2.dim()))

#   if t1.dim() == 2:
#     r1, i1 = t1[:, 0], t1[:, 1]
#     r2, i2 = t2[:, 0], t2[:, 1]
#   elif t1.dim() == 3:
#     r1, i1 = t1[:, :, 0], t1[:, :, 1]
#     r2, i2 = t2[:, :, 0], t2[:, :, 1]
#   elif t1.dim() == 4:
#     r1, i1 = t1[:, :, :, 0], t1[:, :, :, 1]
#     r2, i2 = t2[:, :, :, 0], t2[:, :, :, 1]
#   else:
#     raise NotImplementedError
  
#   return torch.stack([r1*r2-i1*i2, r1*i2+i1*r2], dim=-1)


# def _rfft(x, signal_ndim=1, normalized=False, onesided=True):
#   odd_shape1 = (x.shape[1] % 2 != 0)
#   x_shape = x.shape
#   x = torch.fft.rfft(x)
#   x = torch.cat([x.real.unsqueeze(dim=2), x.imag.unsqueeze(dim=2)], dim=2)
#   if onesided == False:
#     _x = x[:, 1:, :].flip(dims=[1]).clone() if odd_shape1 else x[:, 1:-1, :].flip(dims=[1]).clone()
#     _x[:,:,1] = -1 * _x[:,:,1]
#     x = torch.cat([x, _x], dim=1)
#   if normalized == True:
#     p = 1
#     for i in x_shape:
#       p *= i
#     x /= math.sqrt(p)
#   return x


# def _irfft(x, signal_sizes, signal_ndim=1, normalized=False, onesided=True):
#   x_shape = x.shape
#   if onesided == False:
#     res_shape1 = x.shape[1]
#     x = x[:,:(x.shape[1] // 2 + 1),:]
#     x = torch.complex(x[:,:,0].float(), x[:,:,1].float())
#     x = torch.fft.irfft(x, n=res_shape1)
#   else:
#     x = torch.complex(x[:,:,0].float(), x[:,:,1].float())
#     x = torch.fft.irfft(x)
#   if normalized == True:
#     p = 1
#     for i in x_shape:
#       p *= i
#     x /= math.sqrt(p)
#   assert signal_sizes == x.shape[1], "こらっ！"
#   return x


def fast_conv1d(signal, kernel):
  """Causal linear convolution of ``signal`` with ``kernel`` via FFT overlap-add.

  Args:
    signal: tensor of shape ``(batch, 1, L_sig)``.
    kernel: tensor with ``L_I`` elements in total (any shape; it is flattened).

  Returns:
    Tensor of shape ``(batch, 1, L_sig)`` in the signal's dtype and device:
    the first ``L_sig`` samples of the full convolution.
  """
  batch, ch, L_sig = signal.shape
  assert ch == 1
  if L_sig == 0:
    # Nothing to convolve (previously an IndexError on offsets[-1]).
    return signal.new_zeros(batch, 1, 0)

  kernel = kernel.reshape(1, -1)
  L_I = kernel.shape[1]
  # FFT size: a power of two at least twice the kernel length.
  L_F = 2 << (L_I - 1).bit_length()
  # Samples consumed per block so block + kernel - 1 fits into L_F.
  L_S = L_F - L_I + 1

  # rfft(n=L_F) zero-pads to the FFT size, replacing the manual concatenation.
  FDir = torch.fft.rfft(kernel, n=L_F)

  offsets = range(0, L_sig, L_S)
  result = signal.new_zeros(batch, 1, offsets[-1] + L_F)

  for idx_fr in offsets:
    block = signal[:, 0, idx_fr:idx_fr + L_S]
    # "ortho" on both transforms cancels out; the kernel spectrum is unscaled.
    spectrum = torch.fft.rfft(block, n=L_F, norm="ortho")
    conved_slice = torch.fft.irfft(spectrum * FDir, n=L_F, norm="ortho")
    # Overlap-add: the tail of every block spills into the next one.
    result[:, 0, idx_fr:idx_fr + L_F] += conved_slice

  return result[:, :, :L_sig]

## モデルモジュール - ミキサー

In [None]:
class Mixer(nn.Module):
  """Sum instrument tracks into one mix."""

  def forward(self, tracks, group_by=None):
    """Sum ``tracks`` over axis 1.

    Args:
      tracks: tensor whose axis 1 indexes the tracks.
      group_by: optional index (e.g. a list) selecting which tracks to sum;
        any falsy value sums all tracks.
    """
    selected = tracks[:, group_by, :] if group_by else tracks
    return selected.sum(dim=1)

## モデル訓練



In [None]:
class Trainer(object):
  """Train and evaluate a model with reconstruction losses.

  The model is called as ``model(mixes) -> (mix, est_mix, est_irs)``. Which
  losses are used is decided by the module-level ``loss_domains``; the metric
  functions by ``metrics``.

  Args:
    model: the ``nn.Module`` to optimise (Adam, learning rate ``learning_rate``).
  """

  def __init__(self, model):
    lr = learning_rate
    self.model = model
    self.loss_histories = {'training': None, 'test': None}
    n_bins = cqt_bins
    # One analyser per fmin: 32.703195 doubled five times, plus one at 2000.
    self.cqters = {"c1": PseudoCqt(SR_WAV, 64, 1 * 32.703195, n_bins, n_bins),
                   "c2": PseudoCqt(SR_WAV, 64, 2 * 32.703195, n_bins, n_bins),
                   "c3": PseudoCqt(SR_WAV, 64, 4 * 32.703195, n_bins, n_bins),
                   "c4": PseudoCqt(SR_WAV, 64, 8 * 32.703195, n_bins, n_bins),
                   "c5": PseudoCqt(SR_WAV, 64, 16 * 32.703195, n_bins, n_bins),
                   "c6": PseudoCqt(SR_WAV, 64, 32 * 32.703195, n_bins, n_bins),
                   "c7": PseudoCqt(SR_WAV, 64, 2000, n_bins, n_bins)}
                   # "c8": PseudoCqt(SR_WAV, 64, 4000, n_bins, n_bins)}
    self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
    self.loss_domains = loss_domains
    metrics_dict = {'mae': F.l1_loss, 'mse': F.mse_loss}
    self.metric_funcs = [metrics_dict[m] for m in metrics]
    self.metric_names = metrics
    self.src_norm = source_norm
    self.hpss = compare_after_hpss
    # Weight of the optional L1 penalty, captured like the other hyperparameters.
    self.l1_reg_lambda = l1_reg_lambda

    if 'melgram' in self.loss_domains:
      self.mel_n_fft = 1024
      self.mel_fb = nn.Parameter(torch.from_numpy(librosa.filters.mel(
          sr=SR, n_fft=self.mel_n_fft, n_mels=n_mels, fmin=0.0, fmax=SR / 2.0)).type(TCDTYPE))
    if 'stft' in self.loss_domains:
      self.n_fft = 1024

    ddfs = [get_ddf_smt(), get_ddf_mdb()]
    self.evalers = [Evaluator(self.model, ddf, device=DEVICE) for ddf in ddfs]
    self.scores = {ddf.name: np.zeros((0, 4)) for ddf in ddfs}

  def train(self, epoch, max_epoch, tr_loader):
    """Run ``epoch`` more epochs, saving a checkpoint after each one.

    Args:
      epoch: number of epochs to run now.
      max_epoch: number of epochs already completed; checkpoints are numbered
        ``max_epoch + 1 ...``.
      tr_loader: iterable (with ``len``) yielding batches of mixes.
    """
    for i in range(epoch):
      cur_epoch = i + max_epoch + 1
      print()
      print("{}epoch目".format(cur_epoch))
      self.train_epoch(tr_loader)
      torch.save(self.model.state_dict(), MODEL_PATH+f"/{cur_epoch}epoch_model.pth")
      # self.evaluate(result_subfolder=f"{cur_epoch}epoch")

  def train_epoch(self, tr_loader):
    """Run one optimisation pass over ``tr_loader``, showing losses in a progress bar."""
    def _loss_a_batch(mixes):
      self.optimizer.zero_grad()
      mixes = mixes.to(DEVICE, non_blocking=True)
      mix, est_mix, est_irs = self.model(mixes)

      losses = self._compute_loss(mix, est_mix, est_irs)
      # Total loss is the plain sum of all terms; None when there are none.
      loss = None
      for term in losses.values():
        loss = term if loss is None else loss + term

      return losses, loss

    def _train_a_batch(mixes):
      losses, loss = _loss_a_batch(mixes)
      if loss is None:
        # Every term was dropped (e.g. all NaN): skip the update instead of
        # crashing on None.backward().
        print('no valid loss term for this batch, optimizer step skipped')
        return losses
      loss.backward()
      self.optimizer.step()
      return losses

    self.model.train()
    bar = tqdm(enumerate(tr_loader), total=len(tr_loader))
    # NOTE(review): accumulated per epoch but not read or returned anywhere
    # in this class.
    accum_losses = defaultdict(lambda: 0.)
    for batch_i, mixes in bar:
      losses = _train_a_batch(mixes)
      desc = ' '.join('{}:{:4.2f}'.format(k, v) for (k, v) in losses.items())
      bar.set_description(desc)
      for key in losses:
        accum_losses[key] += dcnp(losses[key])


  def evaluate(self, max_epoch):
    """Evaluate on every dataset and store the results.

    Results go to ``RESULT_PATH/"<max_epoch>epoch"``: pickled F-scores
    (``f1_scores_<name>.pkl``) and the estimated activations
    (``est_irs_<name>.npz``) per dataset.
    """
    with torch.no_grad():
      result_subfolder = f"{max_epoch}epoch"

      self.model.eval()
      result_sub_path = os.path.join(RESULT_PATH, result_subfolder)
      os.makedirs(result_sub_path, exist_ok=True)
      for evaler in self.evalers:
        ddf_name = evaler.ddf.name
        print(ddf_name)
        path = os.path.join(result_sub_path, 'f1_scores_%s.pkl' % ddf_name)

        evaler.predict(verbose=True)
        evaler.pickpeaks(pickpeak_fix, verbose=True)
        evaler.mir_eval()
        evaler.save_and_print_result(result_sub_path)

        with open(path, 'wb') as f_write:
          pickle.dump(evaler.f_scores, f_write)
        np_path = os.path.join(result_sub_path, 'est_irs_%s.npz' % ddf_name)
        np.savez_compressed(np_path, *evaler.midis)

  def _compute_loss(self, mixes, est_mixes, est_impulses):
    """Return a dict of named loss terms for the enabled loss domains."""
    losses = defaultdict(lambda: 0.)
    if self.hpss:
      mixes = pss_src(mixes)
      est_mixes = pss_src(est_mixes)

    if 'spectrum' in self.loss_domains:
      self._compute_spectrum_loss(mixes, est_mixes, losses=losses)
    if 'melgram' in self.loss_domains:
      self._compute_melgram_loss(mixes, est_mixes, losses=losses, weight=1.0)
    if 'stft' in self.loss_domains:
      self._compute_stft_loss(mixes, est_mixes, losses=losses, weight=1.0)
    if 'l1_reg' in self.loss_domains:
      # Fix: this used to read self.args["l1_reg_lambda"], but the class
      # never defines `args`, so enabling 'l1_reg' raised AttributeError.
      losses['l1_reg'] = norm_losses(est_impulses, p=1, weight=self.l1_reg_lambda)

    return losses

  def _compute_melgram_loss(self, mix, est_mix, losses, weight=1.0):
    # `weight` is accepted but not forwarded to loss_melgram.
    loss_melgram(mix, est_mix, self.mel_fb.to(mix.device),
                 self.mel_n_fft, self.metric_funcs,
                 self.metric_names, losses)

  def _compute_stft_loss(self, mix, est_mix, losses, weight=1.0):
    # `weight` is accepted but not forwarded to loss_stft.
    loss_stft(mix, est_mix, self.n_fft, self.metric_funcs,
              self.metric_names, losses)

  def _compute_spectrum_loss(self, mix, est_mix, losses, weight=1.0, perceptual=False):
    """Add ``weight * 4 * metric(cqt(mix), cqt(est_mix))`` per analyser and metric.

    Terms are stored in ``losses`` under ``<analyser key><metric name>``;
    NaN terms are reported and left out.
    """
    for cqter_key in self.cqters:
      cqter = self.cqters[cqter_key]
      est_cqt = cqter(est_mix)
      org_cqt = cqter(mix)

      for metric_name, metric in zip(self.metric_names, self.metric_funcs):
        loss = weight * 4 * metric(org_cqt, est_cqt)
        if not torch.isnan(loss):
          losses[cqter_key + metric_name] = loss
        else:
          print('%s was NaN, it is not added to the total loss' % cqter_key)

def dcnp(torch_array):
  """Detach a tensor from the graph, move it to the CPU and return it as a NumPy array."""
  detached = torch_array.detach()
  on_cpu = detached.cpu()
  return on_cpu.numpy()

## 訓練モジュール - CQT

In [None]:
# Alias of a double-underscore (private) helper of librosa.constantq used by
# PseudoCqt to build the filter basis. NOTE(review): private API -- it may be
# missing or renamed in other librosa versions; verify against the installed one.
cqt_filter_fft = librosa.constantq.__cqt_filter_fft

class PseudoCqt:
  """Approximate log-magnitude CQT: STFT magnitudes projected onto a CQT filter basis.

  The filter basis comes from ``cqt_filter_fft``; its absolute value is kept
  as a dense tensor on DEVICE.

  Args:
    sr, fmin, n_bins, bins_per_octave, filter_scale, norm, sparsity, window:
      forwarded to ``cqt_filter_fft``.
    hop_length: hop of the STFT (also forwarded to ``cqt_filter_fft``).
    tuning: accepted but unused.
    scale: stored, not otherwise used by this class.
    pad_mode: stored, not passed to ``torch.stft``.
  """

  def __init__(self, sr=22050, hop_length=512, fmin=None, n_bins=84, bins_per_octave=12,
               tuning=0.0, filter_scale=1, norm=1, sparsity=0.01, window='hann', scale=True,
               pad_mode='reflect'):

    fft_basis, n_fft, _ = cqt_filter_fft(sr, fmin, n_bins, bins_per_octave,
                                         filter_scale, norm, sparsity,
                                         hop_length=hop_length, window=window)

    self.fft_basis = torch.tensor(np.array(np.abs(fft_basis.todense())), dtype=TCDTYPE,
                                  device=DEVICE)

    self.fmin = fmin
    self.fmax = fmin * 2 ** (float(n_bins) / bins_per_octave)
    self.n_fft = n_fft
    self.hop_length = hop_length
    self.pad_mode = pad_mode
    self.scale = scale
    # Analysis window: a Hann window a quarter of n_fft long, centred in an
    # otherwise zero frame.
    win = torch.zeros((self.n_fft,), device=DEVICE)
    win[self.n_fft // 2 - self.n_fft // 8:self.n_fft // 2 + self.n_fft // 8] = torch.hann_window(self.n_fft // 4)
    self.window = win

  def __call__(self, y):
    return self.forward(y)

  def forward(self, y):
    """Return the log-scaled pseudo-CQT magnitudes of waveform tensor ``y``."""
    # return_complex=False is deprecated; view_as_real of the complex result
    # gives the same (real, imag) layout, so the values are unchanged.
    spec = torch.stft(y, self.n_fft,
                      hop_length=self.hop_length,
                      window=self.window, return_complex=True)
    mag_stfts = torch.view_as_real(spec).pow(2).sum(-1)
    # EPS keeps the square root differentiable at zero magnitude.
    mag_stfts = torch.sqrt(mag_stfts + EPS)
    mag_melgrams = torch.matmul(self.fft_basis, mag_stfts)

    mag_melgrams /= torch.tensor(np.sqrt(self.n_fft), device=y.device)
    return to_log(mag_melgrams)

def to_log(mag_specgrams):
  """Return ``log10(x + EPS) - log10(EPS)``, so that a zero magnitude maps to 0."""
  floor = torch.tensor(EPS, device=mag_specgrams.device)
  shifted = torch.log10(mag_specgrams + EPS)
  return shifted - torch.log10(floor)

## 評価データセットの取得

In [None]:
# Fetch IDMT-SMT-DRUMS
def get_ddf_smt():
  """Return the TestData for ``EVALDATA_PATH/SMT_DRUMS`` (name 'smt', no label remapping)."""
  root = os.path.join(EVALDATA_PATH, 'SMT_DRUMS')
  return TestData(root, 'smt', label_map=None,
                  ann_folder='annotations', audio_folder='audio')


# Fetch ENST-DRUMS
def get_ddf_enst():
  """Return the TestData for the ENST wet-mix set (name 'enst').

  Annotation labels '0', '1', '2' are renamed to 'KD', 'SD', 'HH'.
  """
  root = os.path.join(EVALDATA_PATH, 'ENST_DTP(wet_mix-minus_one)')
  numeric_to_drum = {'0': 'KD', '1': 'SD', '2': 'HH'}
  return TestData(root, 'enst', label_map=numeric_to_drum,
                  ann_folder='annotations', audio_folder='audio')

# Fetch MDB Drums
def get_ddf_mdb():
  """Return the TestData for ``EVALDATA_PATH/MDB_Drums`` (name 'mdb', no label remapping).

  Annotations are read from 'annotations/class' and audio from 'audio/drum_only'.
  """
  root = os.path.join(EVALDATA_PATH, 'MDB_Drums')
  return TestData(root, 'mdb', label_map=None,
                  ann_folder='annotations/class', audio_folder='audio/drum_only')


# Container for an evaluation dataset
class TestData:
  """Iterable over ``(audio, onsets)`` pairs of one evaluation dataset.

  Annotation (.txt) and audio (.wav) files are each sorted by name and paired
  by position, so both folders must hold the same number of files.

  Args:
    path: dataset root directory.
    name: short dataset name.
    label_map: optional mapping used to rename annotation labels.
    ann_folder: annotation sub-folder of ``path``.
    audio_folder: audio sub-folder of ``path``.
  """

  def __init__(self, path, name, label_map=None, ann_folder='annotations', audio_folder='audio'):
    self.path = path
    self.name = name
    self.label_map = label_map
    self.ann_folder = ann_folder
    self.audio_folder = audio_folder
    self.anno_fns, self.audio_fns = [], []
    self._scan_files()

  def _scan_files(self):
    """Collect the sorted annotation and audio file names and check the counts match."""
    anno_dir = os.path.join(self.path, self.ann_folder)
    self.anno_fns = sorted(fn for fn in os.listdir(anno_dir) if fn.endswith('.txt'))

    audio_dir = os.path.join(self.path, self.audio_folder)
    self.audio_fns = sorted(fn for fn in os.listdir(audio_dir) if fn.endswith('.wav'))

    assert len(self.anno_fns) == len(self.audio_fns), 'The number of files should be equal but %d and %d' % (len(self.anno_fns), len(self.audio_fns))
    self.n_files = len(self.audio_fns)

  def __iter__(self):
    # The object is its own iterator: iteration state lives in self.n.
    self.n = 0
    return self

  def __len__(self):
    return self.n_files

  def __next__(self):
    """Load the next file pair: audio resampled to SR and a dict label -> onset times."""
    if self.n >= self.n_files:
      raise StopIteration

    audio_path = os.path.join(self.path, self.audio_folder, self.audio_fns[self.n])
    anno_path = os.path.join(self.path, self.ann_folder, self.anno_fns[self.n])

    src, _ = librosa.load(audio_path, sr=SR)
    onsets_tuple = mir_eval.io.load_labeled_events(anno_path)
    onsets_dict = rename_key(read_annotations_multilabel(onsets_tuple), self.label_map)
    self.n += 1
    return src, onsets_dict


# Print each distinct warning only the first time it occurs.
warnings.filterwarnings(action='once')


def process_annotation(txtpath_r, txtpath_w, label_map, delimiter='\t'):
    """Rewrite an annotation file, optionally renaming its labels.

    Every non-blank input line must hold exactly two fields,
    ``<time><delimiter><label>``; surrounding whitespace of both fields is
    stripped. Blank lines (such as a trailing empty line) are skipped instead
    of raising.

    Args:
        txtpath_r (str): path to load the annotation
        txtpath_w (str): path to write
        label_map (dict): a dict, like {0:'KD', 1:'SD'} to process, for example;
            None keeps the labels unchanged
        delimiter (str): delimiter between events

    Raises:
        ValueError: if a non-blank line does not split into exactly two fields.
        KeyError: if a label is missing from ``label_map``.
    """
    with open(txtpath_r, 'r') as f_r, open(txtpath_w, 'w') as f_w:
        for line in f_r:
            if not line.strip():
                continue

            t, old_label = line.rstrip('\n').split(delimiter)
            t, old_label = t.strip(), old_label.strip()
            new_label = old_label if label_map is None else label_map[old_label]
            f_w.write(delimiter.join([t, new_label]))
            f_w.write('\n')


def read_annotations_multilabel(onsets_tuple):
  """Group onset times by label.

  Args:
    onsets_tuple: pair ``(times, labels)`` of equal-length sequences.

  Returns:
    Dict mapping every distinct label to a NumPy array of its times, in
    input order.
  """
  grouped = {label: [] for label in set(onsets_tuple[1])}
  for onset_time, label in zip(*onsets_tuple):
    grouped[label].append(onset_time)
  return {label: np.array(times) for label, times in grouped.items()}


def rename_key(old_dict, key_map=None):
  """Rename the keys of ``old_dict`` in place according to ``key_map``.

  Args:
    old_dict: dict to update; it is modified and returned.
    key_map: mapping ``old key -> new key``; keys without an entry are kept.
      ``None`` returns ``old_dict`` untouched.

  Returns:
    The same ``old_dict`` object with renamed keys.
  """
  if key_map is None:
    return old_dict
  # Build the renamed mapping first: inserting and deleting keys while
  # iterating the same dict raises RuntimeError and mishandles swapped labels.
  renamed = {(key_map[key] if key in key_map else key): value
             for key, value in old_dict.items()}
  old_dict.clear()
  old_dict.update(renamed)
  return old_dict


def pickpeak_fix(impulse):
  """Pick onset times (seconds) from one activation curve with fixed settings.

  ``impulse`` is normalised by its maximum IN PLACE, then passed to
  ``librosa.util.peak_pick`` with windows expressed as fractions of SR.

  Args:
    impulse: 1-D NumPy array of activations (modified in place).

  Returns:
    Array of peak times in seconds; empty when ``impulse`` has no positive
    value (it is then left untouched instead of being divided by zero).
  """
  div_max, div_avg, div_wait, div_thre = 20, 10, 16, 4

  peak = impulse.max()
  if not peak > 0:
    return np.zeros(0, dtype=float)

  impulse /= peak
  # Keyword arguments: newer librosa releases make these keyword-only, and
  # older ones accept them as well.
  peak_idxs = librosa.util.peak_pick(impulse,
                                     pre_max=SR // div_max, post_max=SR // div_max,
                                     pre_avg=SR // div_avg, post_avg=SR // div_avg,
                                     delta=1.0 / div_thre, wait=SR // div_wait)

  return librosa.samples_to_time(peak_idxs, sr=SR)

## 評価

In [None]:
class Evaluator(object):
  """Run a transcription model over a dataset and score the detected onsets.

  The methods feed each other through instance attributes:
  ``predict`` fills ``self.midis``, ``pickpeaks`` turns them into
  ``self.est_onsets`` / ``self.ref_onsets``, ``mir_eval`` computes
  ``self.f_scores`` and ``save_and_print_result`` reports them.
  ``illustrate`` / ``illustrate_one`` plot the per-song results.
  """

  def __init__(self, model, ddf, device='cpu'):
    """
    Args:
      model: network; element 2 of its ``forward`` output is used as the
        estimated per-component activations.
      ddf: dataset supporting ``len()`` and iteration over
        ``(src, onsets_dict)`` pairs. Its ``name``, ``audio_fns`` and
        ``anno_fns`` attributes are read when saving results and figures.
      device: device the model and its inputs are moved to.
    """
    self.device = device
    self.model = model.to(self.device)
    self.ddf = ddf
    self.component_names = ['KD', 'SD', 'HH']

    self.lst = 1024  # inputs are zero-padded to a multiple of this length
    self.max_nsp = self.lst * 100  # maximum number of samples per forward pass
    self.midis = [None] * len(self.ddf)
    self.est_onsets = [None] * len(self.ddf)
    self.ref_onsets = [None] * len(self.ddf)
    self.reset_data()

  def reset_data(self):
    """Reset the component count and clear the accumulated scores."""
    self.ndc = self.n_drum_components = len(self.component_names)
    self.f_scores = {k: [] for k in self.component_names}

  def save_and_print_result(self, path):
    """Print per-component mean/std of the scores and save them as .npy files.

    Args:
      path: directory receiving ``<ddf.name>_mean_score_<key>.npy`` and
        ``<ddf.name>_std_score_<key>.npy`` for every component.
    """
    # f_scores always has one (possibly empty) list per component, so check
    # the lists themselves rather than comparing the dict against {}.
    if not any(len(scores) > 0 for scores in self.f_scores.values()):
      print('self.f_scores is blank, so nothing to print.')
      return

    print('Means of F/P/R, Stds of F/P/R')
    for key, scores in self.f_scores.items():
      songs_score = np.array(scores)
      mean_score, std_score = np.mean(songs_score, axis=0), np.std(songs_score, axis=0)
      print(key, mean_score, std_score)
      np.save(os.path.join(path, f"{self.ddf.name}_mean_score_{key}.npy"), arr=mean_score)
      np.save(os.path.join(path, f"{self.ddf.name}_std_score_{key}.npy"), arr=std_score)

  def predict(self, verbose=False):
    """Run the model on every song and store the activations in ``self.midis``.

    Each song is zero-padded to a multiple of ``self.lst``, peak-normalized
    and processed in chunks of at most ``self.max_nsp`` samples whose outputs
    are concatenated along the time axis.

    Args:
      verbose: show a tqdm progress bar when True.
    """
    def send_pred_reduce(src):
      """Forward one chunk and reduce the output to three component rows."""
      src = torch.tensor(src[np.newaxis, :], dtype=TCDTYPE).to(self.device)
      # Inference only: no autograd graph is needed for the result.
      with torch.no_grad():
        ret = self.model.forward(src)
      est_irs = ret[2][0].detach().cpu().numpy()  # first (only) batch item
      # Rows 0 and 1 are kept as they are; rows 2..4 are summed into the
      # third component.
      return np.stack([est_irs[0], est_irs[1], est_irs[2:5].sum(axis=0)], axis=0).astype(np.float32)

    ddf_iter = iter(self.ddf)
    if verbose:
      bar = tqdm(enumerate(ddf_iter), total=len(self.ddf), desc='predicting..')
    else:
      bar = enumerate(ddf_iter)

    for song_idx, (src, onsets_dict) in bar:
      # prepare - make it multiple of lst
      len_pad = (self.lst - len(src) % self.lst) % self.lst
      if len_pad != 0:
        src = np.concatenate((src, np.zeros(len_pad, )), axis=0)
      src = src.astype(NPDTYPE)
      # Peak-normalize; a silent signal is left as is instead of becoming NaN.
      peak = np.abs(src).max()
      if peak > 0:
        src = src / peak

      # Do the prediction chunk by chunk (a short song is a single chunk) and
      # join the pieces once at the end.
      chunks = [send_pred_reduce(src[begin: begin + self.max_nsp])
                for begin in range(0, len(src), self.max_nsp)]
      self.midis[song_idx] = np.concatenate(chunks, axis=1).astype(np.float32)

  def pickpeaks(self, pp_func, verbose=False, **kwargs):
    """Convert the stored activations into onset times, song by song.

    Fills ``self.est_onsets`` (from ``pp_func``) and ``self.ref_onsets``
    (from the dataset annotations; components without annotation get an empty
    array). Requires ``predict`` to have filled ``self.midis``.

    Args:
      pp_func: callable taking one component's 1-D activation and returning
        its onset positions.
      verbose: show a tqdm progress bar when True.
      **kwargs: accepted but not used by this method.
    """
    ddf_iter = iter(self.ddf)
    if verbose:
      # Use the dataset length, as predict() does; an iterator is not
      # guaranteed to support len().
      bar = tqdm(enumerate(ddf_iter), total=len(self.ddf), desc='picking peaks...')
    else:
      bar = enumerate(ddf_iter)

    for song_idx, (src, onsets_dict) in bar:
      est_onset_song = []
      ref_onset_song = []
      for i, key in enumerate(self.component_names):
        est_onset_song.append(pp_func(self.midis[song_idx][i]))  # onset positions
        ref_onset_song.append(onsets_dict[key] if key in onsets_dict else np.array([]))
      self.est_onsets[song_idx] = np.array(est_onset_song, dtype=object)
      self.ref_onsets[song_idx] = np.array(ref_onset_song, dtype=object)

  def mir_eval(self):
    """Score every song and component with ``mir_eval.onset.f_measure``.

    Previous scores are cleared first. Requires ``pickpeaks`` to have been
    run for all songs.
    """
    self.reset_data()
    for ref_onset, est_onset in zip(self.ref_onsets, self.est_onsets):
      for i, key in enumerate(self.component_names):
        f_score = mir_eval.onset.f_measure(ref_onset[i], est_onset[i])  # F, P, R
        self.f_scores[key].append(f_score)

  @staticmethod
  def _plot_wave(signal):
    """Draw ``signal`` as a waveform on the current matplotlib axes."""
    # Imported here because the file-level imports only bring in the
    # top-level ``librosa`` package.
    import librosa.display
    # ``waveplot`` was removed from newer librosa releases in favour of
    # ``waveshow``; keep using ``waveplot`` wherever it still exists.
    plot = getattr(librosa.display, 'waveplot', None) or librosa.display.waveshow
    plot(signal)

  def illustrate_one(self, song_idx, img_folder, verbose=False):
    """Plot activations, picked peaks and reference onsets of one song.

    The figure is a 3x3 grid (one column per component) saved as
    ``<img_folder>/<anno file name>.png``.

    Args:
      song_idx: index of the song in the dataset.
      img_folder: directory where the image is written.
      verbose: print the number of reference/estimated onsets per component.

    Returns:
      None. Nothing is drawn when no prediction is stored for the song.
    """
    midis = self.midis[song_idx]
    est_onset = self.est_onsets[song_idx]
    ref_onset = self.ref_onsets[song_idx]
    if midis is None:
      if verbose:
        print('none...')
      return None

    # FIGURE 1 (top row): raw activations
    plt.figure(figsize=(15, 3))
    for i in range(3):
      plt.subplot(3, 3, i + 1)
      self._plot_wave(midis[i])
      title = self.component_names[i] + ' est_irs'
      if i == 0:
        title += ' ' + str(song_idx) + ' ' + self.ddf.audio_fns[song_idx]
      plt.title(title)

    for i, key in enumerate(self.component_names):  # KD, SD, HH
      # FIGURE 2 (middle row): unit impulses at the estimated onsets
      plt.subplot(3, 3, i + 4)
      tmp = np.zeros_like(midis[i])
      np.put(tmp, librosa.time_to_samples(est_onset[i], sr=SR), np.ones(len(est_onset[i])))
      self._plot_wave(tmp)
      plt.title('after peak picking')
      # FIGURE 3 (bottom row): unit impulses at the reference onsets
      plt.subplot(3, 3, i + 7)
      tmp = np.zeros_like(midis[i])
      np.put(tmp, librosa.time_to_samples(ref_onset[i], sr=SR), np.ones(len(ref_onset[i])))
      self._plot_wave(tmp)
      plt.title('reference')
      if verbose:
        print('-%s: %3.0d %3.0d' % (key, len(ref_onset[i]), len(est_onset[i])), end='   ')

    # Save once, after all nine panels are drawn.
    plt.savefig(img_folder + '/' + self.ddf.anno_fns[song_idx] + '.png')
    if verbose:
      print('')

  def illustrate(self, img_folder):
    """Call ``illustrate_one`` for every song, with a progress bar."""
    bar = tqdm(range(len(self.ddf)), total=len(self.ddf), desc='drawing..')
    for song_idx in bar:
      self.illustrate_one(song_idx, img_folder)

## 学習データセット

In [None]:
class DrumDataset(Dataset):
  """Dataset over every file found directly inside two directories.

  Each item is the audio of one file, loaded as mono at ``SR_WAV`` for at most
  ``DURATION`` seconds, normalized with ``librosa.util.normalize`` and
  returned as a torch tensor.
  """

  def __init__(self, dir1, dir2):
    """
    Args:
      dir1: first directory; all of its entries are listed.
      dir2: second directory; its entries are appended after those of dir1.
    """
    super().__init__()
    self.files = [*glob.glob(dir1 + "/*"), *glob.glob(dir2 + "/*")]

  def __len__(self):
    return len(self.files)

  def __getitem__(self, idx):
    audio, _ = librosa.load(self.files[idx], sr=SR_WAV, mono=True, duration=DURATION)
    normalized = librosa.util.normalize(audio, axis=0)
    return torch.from_numpy(normalized)


def load_audio_file(filename, sample_rate=None, num_channels=None,
                    channel=None, start=None, stop=None, dtype=None,
                    replayagain_mode=None, replayagain_preamp=0.0):
    """Load a mono audio file and optionally cut out a time range.

    The whole file is decoded with ``librosa.core.load`` first; the excerpt is
    sliced from the decoded signal afterwards.

    Args:
        filename: path of the audio file.
        sample_rate: target sample rate passed to librosa as ``sr``.
        start: start of the excerpt in seconds, or None for the beginning.
        stop: end of the excerpt in seconds, or None for the end of the file.
            Values past the end are clamped to the signal length.
        num_channels, channel, dtype, replayagain_mode, replayagain_preamp:
            accepted but not used by this implementation.

    Returns:
        Tuple ``(signal, file_sample_rate)`` with the values returned by
        librosa, the signal being cut to ``[start, stop)`` when requested.
    """
    signal, file_sample_rate = librosa.core.load(filename, sr=sample_rate, mono=True)
    if start is None and stop is None:
        return signal, file_sample_rate

    # Convert the requested times from seconds to sample indices.
    first = None if start is None else int(start * file_sample_rate)
    last = None if stop is None else min(len(signal), int(stop * file_sample_rate))
    return signal[first: last], file_sample_rate


class TxtSrcDataset(Dataset):
    """Base class for datasets described by a text file listing audio sources.

    Subclasses implement the four hooks ``_read_error``, ``_read_txt``,
    ``_read_audio`` and ``_line_to_readpath``. ``_read_txt`` is called from
    the constructor and is expected to fill ``self.lines``; the dataset length
    is the number of those lines.
    """

    def __init__(self, txt_path, src_path, duration, sr_wav, ext='mp3'):
        """
        Args:
            txt_path: path of the text file describing the items.
            src_path: location of the audio sources.
            duration: value handed to ``_read_audio`` as ``duration``.
            sr_wav: sample rate setting, stored as ``self.sr_wav``.
            ext: audio file extension, stored as ``self.ext``.
        """
        super().__init__()
        self.txt_path, self.src_path = txt_path, src_path
        self.duration, self.sr_wav = duration, sr_wav
        self.ext = ext
        # Must exist before _read_txt() runs, since that hook populates it.
        self.lines = []
        self._read_txt()

    # --- hooks to be provided by subclasses ---------------------------------

    def _read_error(self, size):
        raise NotImplementedError()

    def _read_txt(self):
        raise NotImplementedError()

    def _read_audio(self, path, duration, file_dura):
        raise NotImplementedError()

    def _line_to_readpath(self, idx):
        raise NotImplementedError()

    # --- Dataset protocol -----------------------------------------------------

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        """Return ``(mix, zeros, zeros)``; both zero tensors have mix's shape."""
        path, file_dura = self._line_to_readpath(idx)
        mix = self._read_audio(path, duration=self.duration, file_dura=file_dura)
        placeholders = (torch.zeros(mix.shape), torch.zeros(mix.shape))
        return (torch.from_numpy(mix),) + placeholders


class TxtDrumstemDataset(TxtSrcDataset):
    """Text-file-based dataset for drum stems.

    Every line of the text file is ``<filename>\\t<file duration in seconds>``.
    Only files whose (integer-truncated) duration exceeds ``duration + 1`` are
    kept. Each item is a randomly positioned excerpt of ``duration`` seconds.
    """
    def __init__(self, *args, **kwargs):
        super(TxtDrumstemDataset, self).__init__(*args, **kwargs)

    def _read_error(self, size):
        """Return low-level uniform noise in [-0.01, 0.01) of shape ``size``."""
        return np.random.uniform(-0.01, 0.01, size=size).astype(NPDTYPE)

    def _read_txt(self):
        """Collect the listing lines of all files that are long enough."""
        with open(self.txt_path) as f_read:
            for line in f_read:
                line = line.rstrip('\n')
                # Blank lines (e.g. a trailing one) have no duration field.
                if not line.strip():
                    continue
                if int(float(line.split('\t')[1])) > self.duration + 1:
                    self.lines.append(line)

    def _read_audio(self, path, duration, file_dura):
        """Load a random ``duration``-second excerpt of ``path``.

        Args:
            path: audio file to read.
            duration: excerpt length in seconds.
            file_dura: total length of the file in seconds.

        Returns:
            1-D array of ``int(SR_WAV * duration)`` samples. A full-length
            excerpt is normalized; a short excerpt is returned un-normalized
            and padded with noise; on a read error only noise is returned.

        NOTE(review): the sample rate used here is the module-level ``SR_WAV``,
        not ``self.sr_wav`` - confirm they are always meant to be equal.
        """
        n_target = int(SR_WAV * duration)
        start = np.random.choice(int(file_dura - duration))  # [second]
        try:
            src, _ = load_audio_file(path, sample_rate=SR_WAV, dtype=NPDTYPE, num_channels=1,
                                     start=start, stop=start + duration)

            if len(src) < n_target:
                # Pad up to the length implied by `duration` (not the global
                # NSP_SRC), so items keep a consistent size for any duration.
                padding = self._read_error(size=(n_target - len(src),))
                return np.concatenate((src, padding), axis=0)
            return librosa.util.normalize(src, axis=0)

        except Exception as e:
            # Deliberate best effort: a broken file must not stop training, so
            # report it and substitute noise of the expected length.
            sys.stderr.write('AUDIO READ ERROR (%s): %s\n' % (path, e))
            return self._read_error(size=(n_target,))

    def _line_to_readpath(self, idx):
        """Return ``(full path, file duration in seconds)`` for item ``idx``."""
        filename, file_duration = self.lines[idx].split('\t')
        return '%s/%s' % (self.src_path, filename), float(file_duration)

# main(train・eval同)

In [None]:
# Silence warnings coming from matplotlib.
warnings.filterwarnings('ignore', module='matplotlib')

#torch.multiprocessing.set_start_method('spawn', force=True)

torch.backends.cudnn.benchmark = True

inst_srcs = load_drum_srcs(idx=N_DRUM_VSTS)
inst_names = DRUM_NAMES

drum_sets = get_drumsets(source_norm)
model = DrumTranModel(inst_srcs, inst_names, drum_sets)

# Resume from the most recent checkpoint, if any. Checkpoints are the files
# named "<N>epoch_model.pth"; anything else in the folder (notebook
# checkpoints, other artifacts) is ignored instead of breaking int().
CKPT_SUFFIX = "epoch_model.pth"
saved_models = os.listdir(MODEL_PATH)
max_epoch = 0
model_epochs = [int(name[:-len(CKPT_SUFFIX)]) for name in saved_models
                if name.endswith(CKPT_SUFFIX) and name[:-len(CKPT_SUFFIX)].isdigit()]
if model_epochs:
  max_epoch = max(model_epochs)
  # map_location lets a checkpoint saved on one device (e.g. GPU) be loaded
  # when the current runtime uses another (e.g. CPU only).
  model.load_state_dict(torch.load(os.path.join(MODEL_PATH, f"{max_epoch}{CKPT_SUFFIX}"),
                                   map_location=DEVICE))
  print(f"load {max_epoch}epoch_model")
model = model.to(DEVICE)
trainer = Trainer(model)

# Value passed as the first argument of trainer.train() on every round and
# added to max_epoch afterwards (presumably the number of epochs per round -
# verify against Trainer.train).
epoch = 5
# Numeric suffix of the first pair of stem directories to use.
start_num = 1

# Ten rounds; round i reads the directories STEM_PATH1<i+start_num> and
# STEM_PATH2<i+start_num>.
for i in range(10):
  drumstem_dataset = DrumDataset(STEM_PATH1+str(i+start_num), STEM_PATH2+str(i+start_num))
  tr_params = {"batch_size": batch_size, "shuffle": True, "num_workers": 2, "pin_memory":True, "drop_last": True}
  train_loader = DataLoader(drumstem_dataset, **tr_params)
  trainer.train(epoch, max_epoch, train_loader)
  # NOTE(review): evaluate() receives max_epoch as it was BEFORE this round's
  # increment - confirm that this is the intended value.
  trainer.evaluate(max_epoch)
  max_epoch += epoch