In [1]:
import os
from datetime import datetime as dt
from random import randint
from collections import defaultdict
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('dark_background')

import torchaudio
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from tensorboardX import SummaryWriter

from hparams import Hparam
from data.dataset import SpeechDataset
from CPC.model import CPCModel
from speaker_cpc_model import SpeakerClassificationCPC

In [2]:
# Load the run settings from the CPC YAML configuration file.
config = Hparam('./CPC/config.yaml')
# config.train.device = 'cpu'  # uncomment to override the configured device


def gettime():
    """Return the current local wall-clock time as an 'HH:MM:SS' string."""
    return dt.now().strftime('%H:%M:%S')

## workbench

## model

In [3]:
# Build the speech dataset from the path stored in the config.
ds = SpeechDataset(config.data.path)

# Seed the 80/20 split so the same examples land in train/test on every run.
# Without a fixed seed, restarting the kernel reshuffles the partition, and a
# checkpoint trained on one split would later be evaluated on examples it has
# already seen.
SPLIT_SEED = 0
train_ds, test_ds = train_test_split(ds, test_size=0.2, random_state=SPLIT_SEED)

100%|██████████| 2703/2703 [00:05<00:00, 529.05it/s]


In [5]:
batch_size = 11

# Training batches are reshuffled every epoch.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
# The evaluation loader is NOT shuffled: combined with drop_last=True, a
# shuffled loader would discard a different tail of examples on every pass,
# so the test metrics of successive epochs would be computed on different
# subsets and would not be comparable.
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False, drop_last=True)

# TensorBoard writer and the global step counters used as the x-axis of the
# train and test curves.
writer = SummaryWriter()
train_step = 0
test_step = 0

# CPC backbone wrapped by the speaker-classification model; both are moved to
# the configured device before the pretrained CPC weights are loaded.
model_cpc = CPCModel(config).to(config.train.device)
# NOTE(review): 512 is presumably the size of the CPC features fed to the
# classifier -- confirm against SpeakerClassificationCPC's signature.
model = SpeakerClassificationCPC(model_cpc, 512, ds.n_speakers).to(config.train.device)
model.load_cpc_checkpoint('checkpoints/cpc_model_35_epoch.pt')

# Adadelta with its default hyper-parameters over every model parameter;
# cross-entropy over integer speaker labels.
opt = torch.optim.Adadelta(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

In [7]:
# Run 19 epochs (range(1, 20)). Each epoch makes one optimisation pass over
# train_dl followed by one gradient-free pass over test_dl. Loss and accuracy
# are written to TensorBoard per batch, indexed by the global step counters.
for e in range(1, 20):
    print('epoch %d' % e)
    print('  train')
    # Put the network in training mode so BatchNorm layers normalise with
    # batch statistics and update their running estimates.
    # NOTE(review): if load_cpc_checkpoint deliberately leaves the CPC
    # backbone in eval mode, this call re-enables training mode for it -- confirm.
    model.train()
    for batch in train_dl:
        opt.zero_grad()

        speakers, utters = batch
        # CrossEntropyLoss consumes integer class indices, so the labels are
        # only cast to long rather than one-hot encoded.
        speakers = speakers.long().to(config.train.device)
        # Insert a dimension at index 1 of the utterance batch before the
        # forward pass (presumably the channel axis -- TODO confirm).
        logits = model(utters.unsqueeze(1).to(config.train.device))
        # Remove dimension 0 of the output when it has size 1.
        logits = logits.squeeze(0)
        loss = criterion(logits, speakers)

        loss.backward()
        opt.step()

        writer.add_scalar('loss/train', loss.item(), train_step)
        writer.add_scalar('accuracy/train', SpeakerClassificationCPC.get_acc(logits, speakers).item(), train_step)

        train_step += 1

    print('  test')
    # Switch to inference mode: BatchNorm layers now use their running
    # statistics and are no longer updated from held-out data.
    model.eval()
    with torch.no_grad():
        for batch in test_dl:
            speakers, utters = batch
            speakers = speakers.long().to(config.train.device)
            logits = model(utters.unsqueeze(1).to(config.train.device))
            logits = logits.squeeze(0)
            loss = criterion(logits, speakers)

            writer.add_scalar('loss/test', loss.item(), test_step)
            writer.add_scalar('accuracy/test', SpeakerClassificationCPC.get_acc(logits, speakers).item(), test_step)

            test_step += 1

epoch 1
  train
  test
epoch 2
  train
  test
epoch 3
  train
  test
epoch 4
  train
  test
epoch 5
  train
  test
epoch 6
  train
  test
epoch 7
  train
  test
epoch 8
  train
  test
epoch 9
  train
  test
epoch 10
  train
  test
epoch 11
  train


KeyboardInterrupt: 

In [6]:
# Turn on gradient tracking for the CPC backbone's parameters so the
# optimiser updates them too. The call's return value is this cell's
# displayed output.
# NOTE(review): the execution counters show this cell ran (In [6]) before the
# training loop (In [7]) although it sits after it in the notebook; a
# top-to-bottom "Run All" would therefore train before this line takes effect.
model.cpc_model.requires_grad_(True)

CPCModel(
  (convolutions): Sequential(
    (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), padding=(2,))
    (1): ReLU()
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv1d(512, 512, kernel_size=(8,), stride=(4,), padding=(2,))
    (4): ReLU()
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv1d(512, 512, kernel_size=(4,), stride=(2,), padding=(2,))
    (7): ReLU()
    (8): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): Conv1d(512, 512, kernel_size=(4,), stride=(2,), padding=(2,))
    (10): ReLU()
    (11): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): Conv1d(512, 512, kernel_size=(4,), stride=(2,), padding=(1,))
    (13): ReLU()
    (14): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (autoregressor): GRU(512, 256, batch_first=True)
  (coupling_tran

In [None]:
# Restore the saved speaker-classifier weights. map_location moves the stored
# tensors onto the configured device, so a checkpoint written on a GPU can
# also be loaded on a CPU-only machine (torch.load otherwise tries to restore
# each tensor to the device it was saved from).
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
model.load_state_dict(torch.load('speakers_clf.pt', map_location=config.train.device))