In [1]:
import os
from random import randint
from collections import defaultdict
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('dark_background')

import torchaudio
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from data.dataset import SpeechDataset

In [3]:
# Root directory handed to SpeechDataset. Set the LIBRISPEECH_DEV_CLEAN environment
# variable to point elsewhere instead of editing this absolute path; the default is unchanged.
data = os.environ.get('LIBRISPEECH_DEV_CLEAN', '/data/deepvk/librispeech/dev_clean/dev-clean/')

In [4]:
class CPCModel(nn.Module):
    """Contrastive-predictive-coding style encoder for raw waveforms.

    Five strided Conv1d -> ReLU -> BatchNorm1d layers map a waveform with one
    input channel to a (batch, 512, frames) feature map. ``forward`` scores
    every pair of frames with one learned bilinear form per shift k = 0..11
    and returns one binary cross-entropy loss per shift. For shift k, the
    positive pairs are (frame t + k, frame t).

    NOTE(review): the GRU context ``ctx`` is computed but not used by any
    loss, so the autoregressor receives no gradient. Canonical CPC predicts
    future frames from the context. Switching to that would need 256 -> 512
    coupling transforms, which changes the checkpoint format, so it is left
    as-is here.
    """

    def __init__(self):
        super(CPCModel, self).__init__()
        strides = [5, 4, 2, 2, 2]
        kernels = [10, 8, 4, 4, 4]
        padding = [2, 2, 2, 2, 1]
        self.N = 5  # number of conv layers in the encoder

        layers = []
        for i in range(self.N):
            in_channels = 1 if i == 0 else 512  # first layer sees the single-channel waveform
            layers.append(nn.Conv1d(in_channels, 512, kernels[i], strides[i], padding[i]))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(512))
        self.convolutions = nn.Sequential(*layers)

        self.autoregressor = nn.GRU(512, 256, batch_first=True)

        # One 1x1 conv (a learned 512x512 linear map applied per frame) per prediction shift.
        self.coupling_transforms = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Conv1d(
                    512, 512, kernel_size=1)
            )
            for _ in range(12)
        ])

    def forward(self, x):
        """Compute the per-shift contrastive losses for a batch of waveforms.

        Args:
            x: waveform tensor. The first Conv1d has one input channel, so this
               is expected to be (batch, 1, samples).

        Returns:
            A list of 12 scalar tensors, one binary cross-entropy loss per shift.
        """
        # (batch, 512, frames)
        feats = self.convolutions(x)

        # The GRU wants (batch, frames, 512). ``view`` would reinterpret the
        # channel-major memory and mix channels with time, so transpose instead.
        z_seq = feats.permute(0, 2, 1).contiguous()
        ctx, state = self.autoregressor(z_seq)  # ctx currently unused (see class NOTE)

        n_frames = feats.size(2)
        frame_idx = torch.arange(n_frames, device=feats.device)
        # offsets[r, c] = r - c; the positives for shift k are the cells where r - c == k.
        offsets = frame_idx.unsqueeze(1) - frame_idx.unsqueeze(0)

        # https://github.com/ex4sperans/freesound-classification/blob/master/networks/cpc.py
        losses = []
        for k, transform in enumerate(self.coupling_transforms):
            estimated = transform(feats)  # (batch, 512, frames)
            # logits[b, r, c] = <z_r, W_k z_c>  -> (batch, frames, frames)
            logits = torch.bmm(feats.transpose(1, 2), estimated)

            # Built from index offsets, so the target lives on the input's device
            # and stays valid when k >= frames (an all-zero target, no positives).
            labels = (offsets == k).to(logits.dtype)
            labels = labels.unsqueeze(0).expand_as(logits)

            losses.append(F.binary_cross_entropy_with_logits(logits, labels))

        return losses

## Workbench

## Model

In [5]:
# Dataset rooted at the `data` directory. The training loop unpacks each DataLoader
# batch as (speakers, utters); presumably each item is (speaker, waveform) — confirm in data.dataset.
ds = SpeechDataset(data)

In [6]:
# Put the model on the GPU when one is available; otherwise fall back to CPU
# instead of failing outright on a CPU-only machine.
model = CPCModel().to('cuda' if torch.cuda.is_available() else 'cpu')

## Model summary

In [7]:
from torchsummaryX import summary

In [8]:
#t = ds[1][1].unsqueeze(0).unsqueeze(0)

#f = summary(model, t)

## Training

In [9]:
from tensorboardX import SummaryWriter

In [10]:
abs_step = 0  # global step across all epochs, used as the TensorBoard x-axis
# Shuffled loader over `ds` with batch size 15.
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=15, shuffle=True)
# Adam over every model parameter, plus a TensorBoard writer (no log dir given).
opt = torch.optim.Adam(params=model.parameters(), lr=2e-4)
writer = SummaryWriter()

In [11]:
# Train on whatever device the model lives on, rather than hard-coding .cuda().
device = next(model.parameters()).device
model.train()

for e in range(1, 10):
    print('Epoch %2d started' % e)

    for i, batch in enumerate(dl):
        opt.zero_grad()

        speakers, utters = batch
        # Add the channel axis the first Conv1d expects
        # (assumes utters is (batch, samples) — TODO confirm against SpeechDataset).
        losses = model(utters.unsqueeze(1).to(device))

        for j, loss in enumerate(losses):
            writer.add_scalar('Train/loss_%d' % j, loss.item(), abs_step)

        # One backward over the summed per-shift losses accumulates the same
        # gradients as a backward(retain_graph=True) per loss, but the graph is
        # freed afterwards instead of being kept alive.
        torch.stack(losses).sum().backward()
        abs_step += 1

        opt.step()

Epoch  1 started
Epoch  2 started
Epoch  3 started


KeyboardInterrupt: 

In [None]:
# Persist only the learned weights (the state dict, not the module object) to the file 'cpc'.
torch.save(obj=model.state_dict(), f='cpc')

In [None]:
# LibriSpeech train-clean-100 via torchaudio. No `download=` flag is passed,
# so torchaudio's default applies.
lds = torchaudio.datasets.LIBRISPEECH(
    root='/data/deepvk/librispeech/train-clean-100/',
    url='train-clean-100',
)

In [1]:
import yaml

In [11]:
# Load the experiment config. The context manager closes the file handle, which the
# inline open() previously left to the garbage collector. Reading as UTF-8 avoids
# depending on the platform's locale encoding.
with open('conf.yaml', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

[{'sports': ['soccer',
   'football',
   'basketball',
   'cricket',
   'hockey',
   'table tennis']},
 {'countries': ['Pakistan',
   'USA',
   'India',
   'China',
   'Germany',
   'France',
   'Spain']}]