In [1]:
import os
from random import randint
from collections import defaultdict
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('dark_background')

import torchaudio
import torch
import torch.nn as nn
import torch.nn.functional as F

from hparams import Hparam
from data.dataset import SpeechDataset

In [2]:
# Build the hyperparameter object from ./cpc_config.yaml. Later cells read
# config.model.*, config.data.path and config.train.* from it.
config = Hparam('./cpc_config.yaml')

In [3]:
class CPCModel(nn.Module):
    """Contrastive Predictive Coding model operating on raw 1-D signals.

    Pipeline:
      1. ``convolutions``: five strided ``Conv1d -> ReLU -> BatchNorm1d``
         stages turn the single-channel input into latent frames ``z``.
      2. ``autoregressor``: a GRU reads the frames in time order and emits a
         context vector per frame.
      3. ``coupling_transforms``: one 1x1 convolution per prediction step
         maps the context back to latent space; the result is scored against
         every latent frame with a dot product.

    Args:
        config: object exposing ``config.model.predict_steps`` (number of
            coupling transforms, i.e. number of returned losses),
            ``config.model.conv_channels`` (encoder width) and
            ``config.model.context_size`` (GRU hidden size).

    Note:
        The coupling transforms take ``context_size`` input channels. When
        ``context_size == conv_channels`` the parameter shapes (and therefore
        the state-dict layout) are the same as before this change.
    """

    def __init__(self, config):
        super(CPCModel, self).__init__()
        strides = [5, 4, 2, 2, 2]
        kernels = [10, 8, 4, 4, 4]
        padding = [2, 2, 2, 2, 1]
        self.predict_steps = config.model.predict_steps
        channels = config.model.conv_channels
        context_size = config.model.context_size

        # Encoder: the first stage consumes a single input channel, every
        # later stage consumes the previous stage's `channels` outputs.
        layers = []
        in_channels = 1
        for kernel, stride, pad in zip(kernels, strides, padding):
            layers.append(nn.Conv1d(in_channels, channels, kernel, stride, pad))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(channels))
            in_channels = channels
        self.convolutions = nn.Sequential(*layers)

        self.autoregressor = nn.GRU(channels,
                                    context_size,
                                    batch_first=True)

        # One linear (1x1 conv) predictor per step, mapping context vectors
        # into the latent space of the encoder.
        self.coupling_transforms = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Conv1d(
                    context_size, channels, kernel_size=1)
            )
            for _ in range(self.predict_steps)
        ])

    def forward(self, x):
        """Compute one contrastive loss per prediction step.

        Args:
            x: tensor of shape ``(batch, 1, samples)`` (the first convolution
                has a single input channel).

        Returns:
            list of ``predict_steps`` scalar tensors. For step ``i`` the
            logits matrix holds the dot product of every latent frame with
            every prediction; entry ``(t, s)`` is labelled positive when
            ``t == s + i`` and negative otherwise, and the loss is the
            binary cross-entropy over the whole matrix.
        """
        z = self.convolutions(x)                       # (B, C, T)

        # The GRU is batch_first and expects (B, T, C). Swap the axes rather
        # than reinterpreting memory with view(), which would interleave
        # channel and time values.
        frames = z.transpose(1, 2).contiguous()        # (B, T, C)
        ctx, _ = self.autoregressor(frames)            # (B, T, context)
        ctx = ctx.transpose(1, 2).contiguous()         # (B, context, T)

        # https://github.com/ex4sperans/freesound-classification/blob/master/networks/cpc.py
        losses = []
        for i, transform in enumerate(self.coupling_transforms):
            # Predictions are made from the context so that the
            # autoregressor actually takes part in the objective.
            estimated = transform(ctx)                 # (B, C, T)

            logits = torch.bmm(frames, estimated)      # (B, T, T)

            # Identity shifted down by i rows: frame s + i is the positive
            # for the prediction made at frame s. Built on the logits'
            # device so the model also runs without CUDA.
            labels = torch.eye(logits.size(2) - i,
                               device=logits.device,
                               dtype=logits.dtype)
            labels = F.pad(labels, (0, i, i, 0))
            labels = labels.unsqueeze(0).expand_as(logits)

            losses.append(F.binary_cross_entropy_with_logits(logits, labels))

        return losses

## Workbench

## Model

In [4]:
# Dataset built from config.data.path. The cells below index items as
# ds[k][1] and unpack batches as (speakers, utters).
ds = SpeechDataset(config.data.path)

In [5]:
# Instantiate the model and move its parameters to the GPU (requires CUDA).
model = CPCModel(config).cuda()

# Model summary

In [6]:
from torchsummaryX import summary

In [14]:
# Take the second field of dataset item 1, prepend two singleton axes
# (batch, channel) and print a per-layer summary of the model on the GPU.
t = ds[1][1][None, None]
f = summary(model, t.cuda())

                                    Kernel Shape    Output Shape     Params  \
Layer                                                                         
0_convolutions.Conv1d_0             [1, 512, 10]  [1, 512, 4095]     5.632k   
1_convolutions.ReLU_1                          -  [1, 512, 4095]          -   
2_convolutions.BatchNorm1d_2               [512]  [1, 512, 4095]     1.024k   
3_convolutions.Conv1d_3            [512, 512, 8]  [1, 512, 1023]  2.097664M   
4_convolutions.ReLU_4                          -  [1, 512, 1023]          -   
5_convolutions.BatchNorm1d_5               [512]  [1, 512, 1023]     1.024k   
6_convolutions.Conv1d_6            [512, 512, 4]   [1, 512, 512]  1.049088M   
7_convolutions.ReLU_7                          -   [1, 512, 512]          -   
8_convolutions.BatchNorm1d_8               [512]   [1, 512, 512]     1.024k   
9_convolutions.Conv1d_9            [512, 512, 4]   [1, 512, 257]  1.049088M   
10_convolutions.ReLU_10                        -   [

# Training

In [8]:
from tensorboardX import SummaryWriter

In [9]:
# Training fixtures: TensorBoard logger, shuffled mini-batch loader, Adam
# optimiser over all model parameters, and a global step counter.
writer = SummaryWriter()
dl = torch.utils.data.DataLoader(
    ds,
    batch_size=config.train.batch_size,
    shuffle=True,
)
opt = torch.optim.Adam(params=model.parameters(), lr=config.train.lr)
abs_step = 0  # incremented once per batch; used as the TensorBoard x-axis

In [10]:
# Main training loop.
# Epochs are numbered from 1 so that the `save_every` check and the checkpoint
# file names use human-readable epoch numbers; the upper bound is inclusive so
# that exactly config.train.epochs epochs are run.
device = next(model.parameters()).device  # feed batches to wherever the model lives

for e in range(1, config.train.epochs + 1):
    print('Epoch %2d started' % e)

    for speakers, utters in dl:
        opt.zero_grad()

        # Add the channel axis expected by the first Conv1d: (B, T) -> (B, 1, T).
        losses = model(utters.unsqueeze(1).to(device))

        # Log every prediction step separately.
        for j, loss in enumerate(losses):
            writer.add_scalar('Train/loss_%d' % j, loss.item(), abs_step)
        abs_step += 1

        # A single backward pass over the summed loss yields the same
        # gradients as calling backward() on each loss in turn, without
        # having to keep the graph alive via retain_graph=True.
        torch.stack(losses).sum().backward()
        opt.step()

    if e % config.train.save_every == 0:
        torch.save(model.state_dict(), config.train.save_name + '_%d_epoch.pt' % e)

Epoch  1 started


KeyboardInterrupt: 

In [None]:
# Scratch cell (never executed; `lds` is not used anywhere in the visible notebook).
# NOTE(review): the dataset root is a hard-coded absolute path; consider moving
# it into the config. Also confirm that this directory is the root expected by
# torchaudio (the folder that contains the extracted archive), not the subset
# folder itself.
lds = torchaudio.datasets.LIBRISPEECH('/data/deepvk/librispeech/train-clean-100/', url='train-clean-100')