In [1]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print("using", device)

# Fix the RNG seed so weight initialisation is repeatable across runs.
torch.manual_seed(777)
if device == 'cuda':
    # Seed the current CUDA device's generator as well.
    torch.cuda.manual_seed(777)

  from .autonotebook import tqdm as notebook_tqdm


using cuda


In [2]:
class Wavenet(torch.nn.Module):
    """Stack of dilated, gated residual layers followed by a softmax output head.

    Args:
        class_num: Number of channels of the input and of the output
            distribution (in-channels of the first conv, out-channels of the
            last one).
        hidden_channels: Channel width used by every internal layer.
        cond_channels: Channel count of the conditioning tensor handed to each
            residual layer.
        layer_num: Number of residual layers per repeat; layer ``j`` uses
            dilation ``2 ** j``.
        repeat_num: How many times the ``layer_num``-deep dilation stack is
            repeated.
    """

    def __init__(self, class_num, hidden_channels, cond_channels, layer_num, repeat_num):
        super().__init__()
        self.class_num = class_num
        self.hidden_channels = hidden_channels
        self.cond_channels = cond_channels
        self.layer_num = layer_num
        self.repeat_num = repeat_num

        # 1x1 input projection. The attribute name (sic, "casual") is kept so
        # that existing state_dict keys remain loadable.
        self.casual_conv = torch.nn.Conv1d(class_num, hidden_channels, kernel_size=1)

        # repeat_num stacks, each with dilations 1, 2, 4, ..., 2**(layer_num - 1).
        self.residual_list = torch.nn.ModuleList()
        for _ in range(repeat_num):
            for j in range(layer_num):
                self.residual_list.append(dil_and_act(hidden_channels, cond_channels, 2 ** j))

        self.relu = torch.nn.ReLU()
        self.out_conv1 = torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size=1)
        self.out_conv2 = torch.nn.Conv1d(hidden_channels, class_num, kernel_size=1)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x, cond):
        """Run the network and return per-step class probabilities.

        Args:
            x: Input tensor fed to a ``Conv1d`` with ``class_num`` in-channels,
                i.e. laid out as (batch, class_num, time).
            cond: Conditioning tensor, passed unchanged to every residual layer.

        Returns:
            Tensor with ``class_num`` channels on dim 1, normalised with a
            softmax over that dim.
        """
        x = self.casual_conv(x)

        # Accumulator for the skip connections. It is allocated from `x` so it
        # always lives on the same device and has the same dtype as the
        # activations, instead of depending on a notebook-level `device` global.
        skip = torch.zeros_like(x)

        for layer in self.residual_list:
            y = layer(x, cond)
            skip = skip + y   # skip path: sum of every layer's output
            x = x + y         # residual path: input to the next layer

        z = self.relu(skip)
        z = self.out_conv1(z)
        # Apply the non-linearity to the out_conv1 result. Using `skip` here
        # (as before) silently discarded out_conv1 altogether.
        z = self.relu(z)
        z = self.out_conv2(z)
        return self.softmax(z)


    

class dil_and_act(torch.nn.Module):
    """Dilated convolution followed by a gated (tanh * sigmoid) activation.

    Both the signal and the conditioning tensor are zero-padded on the left by
    ``dilation`` samples and passed through kernel-size-2 dilated convolutions,
    so the output has the same length as the input. A final 1x1 convolution
    mixes the gated result.

    Args:
        hidden_channels: Channel count of the signal input and of the output.
        cond_channels: Channel count of the conditioning input.
        dilation: Dilation of the kernel-size-2 convolutions; also the amount
            of left padding.
    """

    def __init__(self, hidden_channels, cond_channels, dilation):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.cond_channels = cond_channels

        # Pad only on the left so the convolutions never see later samples.
        self.pad = torch.nn.ConstantPad1d((dilation, 0), 0)
        self.dilation = dilation

        self.tanh = torch.nn.Tanh()
        self.sig = torch.nn.Sigmoid()

        # Settings shared by the four dilated convolutions below.
        dilated = {"kernel_size": 2, "dilation": dilation}
        self.dil_to_tanh = torch.nn.Conv1d(hidden_channels, hidden_channels, **dilated)
        self.dil_to_sig = torch.nn.Conv1d(hidden_channels, hidden_channels, **dilated)
        self.con_to_tanh = torch.nn.Conv1d(cond_channels, hidden_channels, **dilated)
        self.con_to_sig = torch.nn.Conv1d(cond_channels, hidden_channels, **dilated)

        self.out_conv = torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size=1)

    def forward(self, x, cond):
        """Return ``out_conv(tanh(filter) * sigmoid(gate))`` for ``x`` and ``cond``."""
        signal = self.pad(x)
        context = self.pad(cond)

        filter_branch = self.tanh(self.dil_to_tanh(signal) + self.con_to_tanh(context))
        gate_branch = self.sig(self.dil_to_sig(signal) + self.con_to_sig(context))

        return self.out_conv(filter_branch * gate_branch)
        

In [3]:
# Build the model with every hyper-parameter spelled out, then move it to the
# device chosen at the top of the notebook.
net = Wavenet(
    class_num=256,
    hidden_channels=128,
    cond_channels=128,
    layer_num=10,
    repeat_num=2,
)
net = net.to(device)

tensor([[[0.0036, 0.0036, 0.0034,  ..., 0.0028, 0.0027, 0.0029],
         [0.0029, 0.0027, 0.0028,  ..., 0.0029, 0.0028, 0.0024],
         [0.0052, 0.0046, 0.0043,  ..., 0.0048, 0.0048, 0.0043],
         ...,
         [0.0037, 0.0035, 0.0038,  ..., 0.0033, 0.0039, 0.0037],
         [0.0034, 0.0035, 0.0031,  ..., 0.0026, 0.0028, 0.0031],
         [0.0033, 0.0037, 0.0038,  ..., 0.0038, 0.0033, 0.0029]],

        [[0.0035, 0.0037, 0.0037,  ..., 0.0028, 0.0032, 0.0028],
         [0.0028, 0.0030, 0.0028,  ..., 0.0030, 0.0029, 0.0029],
         [0.0047, 0.0047, 0.0043,  ..., 0.0048, 0.0045, 0.0043],
         ...,
         [0.0039, 0.0037, 0.0032,  ..., 0.0036, 0.0035, 0.0037],
         [0.0033, 0.0035, 0.0033,  ..., 0.0035, 0.0024, 0.0028],
         [0.0037, 0.0035, 0.0037,  ..., 0.0032, 0.0029, 0.0034]],

        [[0.0037, 0.0039, 0.0033,  ..., 0.0029, 0.0030, 0.0033],
         [0.0031, 0.0031, 0.0027,  ..., 0.0027, 0.0029, 0.0025],
         [0.0046, 0.0042, 0.0042,  ..., 0.0044, 0.0042, 0.