# Singing Voice Separation

- https://qiita.com/xiao_ming/items/88826e576b87141c4909

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [83]:
class UNet(nn.Module):
    """Six-level convolutional U-Net with skip connections.

    The encoder applies six stride-2 ``Conv2d`` layers (1 -> 16 -> 32 -> 64
    -> 128 -> 256 -> 512 channels), each halving height and width. The
    decoder mirrors it with six stride-2 ``ConvTranspose2d`` layers, each
    doubling height and width; from the second decoder layer on, the input
    is the channel-wise concatenation of the previous decoder output and the
    encoder activation of matching resolution. The final layer is passed
    through a sigmoid, so every output value lies in (0, 1).

    Input is expected as ``(N, 1, H, W)``. ``H`` and ``W`` should be
    multiples of 64 (2**6) so every stage halves evenly, the skip
    connections line up, and the output has the same shape as the input.
    """

    def __init__(self):
        super().__init__()

        # Encoder: each layer halves H and W (kernel 4, stride 2, padding 1).
        self.conv1 = nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.bn6 = nn.BatchNorm2d(512)

        # Decoder: each layer doubles H and W. From deconv2 onward the input
        # channel count is twice the previous output because the matching
        # encoder activation is concatenated on the channel axis.
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.debn1 = nn.BatchNorm2d(256)
        self.deconv2 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)
        self.debn2 = nn.BatchNorm2d(128)
        self.deconv3 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)
        self.debn3 = nn.BatchNorm2d(64)
        self.deconv4 = nn.ConvTranspose2d(128, 32, kernel_size=4, stride=2, padding=1)
        self.debn4 = nn.BatchNorm2d(32)
        self.deconv5 = nn.ConvTranspose2d(64, 16, kernel_size=4, stride=2, padding=1)
        self.debn5 = nn.BatchNorm2d(16)
        self.deconv6 = nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1)

    def _dropout(self, t):
        """Apply p=0.5 dropout to ``t`` only while the module is in training mode."""
        # F.dropout defaults to training=True; without passing self.training
        # it would keep dropping activations after model.eval().
        return F.dropout(t, training=self.training)

    def forward(self, x):
        """Run the network on ``x`` of shape ``(N, 1, H, W)``.

        Returns a tensor with one channel and values in (0, 1). Its spatial
        size equals the input's when ``H`` and ``W`` are multiples of 64.
        """
        # encoder
        h1 = F.leaky_relu(self.bn1(self.conv1(x)))
        h2 = F.leaky_relu(self.bn2(self.conv2(h1)))
        h3 = F.leaky_relu(self.bn3(self.conv3(h2)))
        h4 = F.leaky_relu(self.bn4(self.conv4(h3)))
        h5 = F.leaky_relu(self.bn5(self.conv5(h4)))
        h6 = F.leaky_relu(self.bn6(self.conv6(h5)))

        # decoder: dropout on the first three layers only
        dh = F.relu(self._dropout(self.debn1(self.deconv1(h6))))
        dh = F.relu(self._dropout(self.debn2(self.deconv2(torch.cat((dh, h5), dim=1)))))
        dh = F.relu(self._dropout(self.debn3(self.deconv3(torch.cat((dh, h4), dim=1)))))
        dh = F.relu(self.debn4(self.deconv4(torch.cat((dh, h3), dim=1))))
        dh = F.relu(self.debn5(self.deconv5(torch.cat((dh, h2), dim=1))))
        # No batch norm on the last layer; sigmoid bounds the output to (0, 1).
        dh = torch.sigmoid(self.deconv6(torch.cat((dh, h1), dim=1)))

        return dh

In [84]:
# Instantiate the network with freshly initialised weights (no checkpoint is loaded here).
model = UNet()

In [85]:
# Uniform-random dummy input: one example, one channel, 512 x 128
# (both spatial sizes are multiples of 64).
x = torch.rand((1, 1, 512, 128))
x.shape

torch.Size([1, 1, 512, 128])

In [86]:
# Smoke-test forward pass. eval() has not been called in the cells shown, so
# the module is presumably still in its default training mode here.
out = model(x)

In [87]:
# Display the output shape to compare against the input shape shown earlier.
out.shape

torch.Size([1, 1, 512, 128])