In [1]:
!pip install wandb
!pip install torchaudio
!pip install editdistance==0.3.1

You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
Collecting editdistance==0.3.1
  Downloading editdistance-0.3.1.tar.gz (19 kB)
Building wheels for collected packages: editdistance
  Building wheel for editdistance (setup.py) ... [?25ldone
[?25h  Created wheel for editdistance: filename=editdistance-0.3.1-cp37-cp37m-linux_x86_64.whl size=212630 sha256=9eca6eb881dc4936347806cf5be2c916efb539cf7f1eab773c4721941924639c
  Stored in directory: /root/.cache/pip/wheels/a9/9e/6f/0c07a94bbfae707c540b9cd2d7be284e0bc02ecd1234a3b6ed
Successfully built editdistance
Installing collected packages: editdistance
Successfully installed editdistance-0.3.1
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


In [2]:
!pip install torch_optimizer

Collecting torch_optimizer
  Downloading torch_optimizer-0.0.1a16-py3-none-any.whl (51 kB)
[K     |████████████████████████████████| 51 kB 288 kB/s eta 0:00:011
Collecting pytorch-ranger>=0.1.1
  Downloading pytorch_ranger-0.1.1-py3-none-any.whl (14 kB)
Installing collected packages: pytorch-ranger, torch-optimizer
Successfully installed pytorch-ranger-0.1.1 torch-optimizer-0.0.1a16
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


In [4]:
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchaudio
import pandas as pd
import os
from torch.optim import Adam
from torch.utils.data import Subset, Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import torch_optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR

In [5]:
import string
# Audio front-end settings.
SR = 16000          # sample rate handed to LogMelSpectrogram
N_MELS = 80         # number of mel bins (also the acoustic model's input channels)
AUDIO_LEN = 365472  # clips with more samples than this are filtered out in preprocess_targets

# Names of the preprocessed label files.
TRAIN_DS = 'cv-valid-train.csv'
DEV_DS = 'cv-valid-dev.csv'

# Index 0 is the blank symbol ('<>' in CHAR_VOCAB, '' when decoding),
# indices 1..26 are 'a'..'z' and index 27 is the space character.
CHAR_VOCAB = {symbol: index
              for index, symbol in enumerate(['<>', *string.ascii_lowercase, ' '])}
TOK_VOCAB = dict(enumerate(['', *string.ascii_lowercase, ' ']))
ALPHABET = np.array(['', *string.ascii_lowercase, ' '])

In [6]:
def count_parameters(model):
    """Return the total element count of all parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [7]:
#https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Every time the monitored loss improves, a checkpoint is written to
    ``checkpoint`` with ``torch.save``.
    """
    def __init__(self, checkpoint, patience=7, verbose=False, delta=0, min_loss=np.inf,
                 optimizer=None, scheduler=None):
        """
        :param
            checkpoint (str): Path the checkpoint dictionary is written to.
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            min_loss (float): Best loss seen so far (e.g. from a resumed run).
                            Default: inf, i.e. the first call always saves.
            optimizer: Optional optimizer whose state is stored in the checkpoint.
            scheduler: Optional LR scheduler whose state is stored in the checkpoint.
                       When either is omitted, a module-level variable of the same
                       name is used if one exists (the previous implicit behaviour);
                       otherwise that entry is left out of the checkpoint.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None if min_loss == np.inf else -min_loss
        self.early_stop = False
        self.val_loss_min = min_loss
        self.delta = delta
        self.checkpoint = checkpoint
        self.optimizer = optimizer
        self.scheduler = scheduler

    def __call__(self, val_loss, model, epoch):
        """Record one validation result; save on improvement, otherwise count towards patience."""
        # Scores are negated losses so that "higher is better".
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch)
            self.counter = 0

    @staticmethod
    def _resolve(explicit, global_name):
        """Return ``explicit`` if given, else the module-level object called ``global_name`` (or None)."""
        if explicit is not None:
            return explicit
        return globals().get(global_name)

    def save_checkpoint(self, val_loss, model, epoch):
        """
        Saves model when validation loss decrease.
        """
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')

        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
        }
        # Previously these were read from undeclared globals, which raised
        # NameError whenever the notebook had not defined them yet.
        optimizer = self._resolve(self.optimizer, 'optimizer')
        if optimizer is not None:
            state['optimizer_state_dict'] = optimizer.state_dict()
        scheduler = self._resolve(self.scheduler, 'scheduler')
        if scheduler is not None:
            state['scheduler_state_dict'] = scheduler.state_dict()
        state['loss'] = val_loss

        torch.save(state, self.checkpoint)

        self.val_loss_min = val_loss

In [8]:
class TSC(nn.Module):
    """Depthwise 1-D convolution, then a kernel-size-1 convolution, then batch norm."""

    def __init__(self, kernel_size, in_channels, out_channels, n_groups=1,
                 dilation=1, stride=1):
        super().__init__()
        # groups=in_channels makes the first convolution act on each channel separately.
        depthwise = nn.Conv1d(in_channels, in_channels, kernel_size,
                              stride=stride,
                              padding=dilation * kernel_size // 2,
                              dilation=dilation,
                              groups=in_channels)
        pointwise = nn.Conv1d(in_channels, out_channels, 1, groups=n_groups)
        norm = nn.BatchNorm1d(out_channels)
        # Keep the attribute name and layer order: state_dict keys depend on them.
        self.tsc = nn.Sequential(depthwise, pointwise, norm)

    def forward(self, x):
        return self.tsc(x)


class TSCActivated(nn.Module):
    """A TSC module whose output goes through a ReLU."""

    def __init__(self, kernel_size, in_channels, out_channels, n_groups=1,
                 dilation=1, stride=1):
        super().__init__()
        self.tsc = TSC(kernel_size=kernel_size,
                       in_channels=in_channels,
                       out_channels=out_channels,
                       n_groups=n_groups,
                       dilation=dilation,
                       stride=stride)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.activation(self.tsc(x))


class TSCBlock(nn.Module):
    """Residual stack of TSC layers.

    The main path is a TSCActivated layer mapping in_channels -> out_channels,
    ``n_blocks - 2`` further TSCActivated layers and a final un-activated TSC.
    The skip path is a kernel-size-1 convolution plus batch norm; the two are
    summed and passed through ReLU. With ``is_intermediate=True`` the block
    takes ``out_channels`` as its input width.
    """

    def __init__(self, n_blocks, kernel_size, in_channels, out_channels,
                 n_groups=1, is_intermediate=False):
        super().__init__()
        if is_intermediate:
            in_channels = out_channels
        self.n_blocks = n_blocks

        # Layers are created in the same order as before so that parameter
        # initialisation and state_dict keys are unchanged.
        stages = [TSCActivated(kernel_size, in_channels, out_channels, n_groups)]
        for _ in range(1, self.n_blocks - 1):
            stages.append(TSCActivated(kernel_size, out_channels, out_channels, n_groups))
        stages.append(TSC(kernel_size, out_channels, out_channels, n_groups))
        self.tsc_list = nn.ModuleList(stages)

        self.pnt_wise_conv = nn.Conv1d(in_channels, out_channels, kernel_size=1, groups=n_groups)
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        skip = self.bn(self.pnt_wise_conv(x))
        main = x
        for stage in self.tsc_list:
            main = stage(main)
        return self.relu(main + skip)


class ConvBlock(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU.

    Parameters
    ----------
    kernel_size, in_channels, out_channels, dilation, stride
        Forwarded to ``nn.Conv1d``. Padding is ``dilation * (kernel_size // 2)``,
        which equals the previous ``kernel_size // 2`` when ``dilation == 1`` and
        keeps the sequence length for odd kernels at stride 1 for any dilation.
    """

    def __init__(self, kernel_size, in_channels, out_channels, dilation=1, stride=1):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size,
                      padding=dilation * (kernel_size // 2), dilation=dilation,
                      stride=stride),
            # Conv1d yields (batch, channels, time); BatchNorm2d rejects 3-D
            # input, so the 1-D variant is required here.
            nn.BatchNorm1d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        """Apply the block to ``x`` and return the activated feature map."""
        x = self.conv(x)
        return x


class Debug(nn.Module):
    """Pass-through module that prints the input's shape and a fixed message."""

    def __init__(self, msg=''):
        super(Debug, self).__init__()
        self.msg = msg

    def forward(self, x):
        shape = x.shape
        print(f'{shape}\n{self.msg}')
        return x


class QuarzNet(nn.Module):
    """Acoustic model assembled from a configuration dictionary.

    ``config`` must provide:
      * 'c1', 'c2', 'c3' -- keyword arguments for ``TSCActivated``;
      * 'b1' .. 'b5'     -- keyword arguments for ``TSCBlock``;
      * 'c4'             -- keyword arguments for the final ``nn.Conv1d``;
      * 's'              -- how many times each of the five blocks is repeated.

    The network ends with ``LogSoftmax`` over dimension 1.
    """

    _BLOCK_KEYS = ('b1', 'b2', 'b3', 'b4', 'b5')

    def __init__(self, config):
        super(QuarzNet, self).__init__()
        self.config = config
        self.s = config['s']

        layers = [TSCActivated(**config['c1'])]
        for key in self._BLOCK_KEYS:
            # The repeat count comes from the config, not from a notebook global
            # named `s`. Only the first repeat changes the channel count;
            # later repeats are out_channels -> out_channels.
            layers.extend(
                TSCBlock(**config[key], is_intermediate=bool(i)) for i in range(self.s)
            )
        layers.extend([
            TSCActivated(**config['c2']),
            TSCActivated(**config['c3']),
            nn.Conv1d(**config['c4']),
            nn.LogSoftmax(dim=1),
        ])
        # Same layer order as before, so existing checkpoints still load.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return log-probabilities produced by the sequential stack for ``x``."""
        x = self.net(x)
        return x

In [9]:
def make_param_dict(names, params):
    """Pair each name with the value at the same position and return the mapping."""
    return dict(zip(names, params))

# One output class per CHAR_VOCAB entry (blank included).
n_labels = len(CHAR_VOCAB)

# Plain convolution layers: values are listed in the order given by c_names.
c_names = ['kernel_size', 'in_channels', 'out_channels', 'dilation', 'stride']
c1 = [33, N_MELS, 256, 1, 2]
c2 = [87, 512, 512, 2, 1]
c3 = [1, 512, 1024, 1, 1]
c4 = [1, 1024, n_labels, 1, 1]

# Residual blocks: values are listed in the order given by b_names.
b_names = ['n_blocks', 'kernel_size', 'in_channels', 'out_channels', 'n_groups']
n_groups = 1
b1 = [5, 33, 256, 256, n_groups]
b2 = [5, 39, 256, 256, n_groups]
b3 = [5, 51, 256, 512, n_groups]
b4 = [5, 63, 512, 512, n_groups]
b5 = [5, 75, 512, 512, n_groups]
# Number of repeats of each residual block.
s = 1

config = dict(
    c1=make_param_dict(c_names, c1),
    b1=make_param_dict(b_names, b1),
    b2=make_param_dict(b_names, b2),
    b3=make_param_dict(b_names, b3),
    b4=make_param_dict(b_names, b4),
    b5=make_param_dict(b_names, b5),
    c2=make_param_dict(c_names, c2),
    c3=make_param_dict(c_names, c3),
    c4=make_param_dict(c_names, c4),
    s=s,
)

In [10]:
class CharRNN(nn.Module):
    """Character-level LSTM language model.

    Inputs are one-hot vectors over ``tokens``; the output is one row of
    un-normalised scores over ``tokens`` per input time step.
    """

    def __init__(self, tokens, n_hidden=612, n_layers=4,
                               drop_prob=0.5, lr=0.001):
        """
        :param tokens: sequence of characters; its order defines the class indices.
        :param n_hidden: LSTM hidden size.
        :param n_layers: number of stacked LSTM layers.
        :param drop_prob: dropout probability (inside the LSTM and before the linear layer).
        :param lr: stored on the instance only; not used by this class.
        """
        super().__init__()
        self.drop_prob = drop_prob
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.lr = lr
        self.chars = tokens
        self.int2char = dict(enumerate(self.chars))
        self.char2int = {ch: ii for ii, ch in self.int2char.items()}
        self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers,
                            dropout=drop_prob, batch_first=True)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(n_hidden, len(self.chars))

    def forward(self, x, hidden):
        ''' Forward pass through the network.
            These inputs are x, and the hidden/cell state `hidden`.

            Returns ``(out, hidden)`` where ``out`` has one row per
            (batch element, time step) and ``len(self.chars)`` columns. '''
        r_output, hidden = self.lstm(x, hidden)
        out = self.dropout(r_output)
        # Flatten batch and time so the linear layer sees one row per step.
        out = out.contiguous().view(-1, self.n_hidden)
        out = self.fc(out)

        # return the final output and the hidden state
        return out, hidden

    def init_hidden(self, batch_size):
        ''' Initializes hidden state.

            Returns a ``(h0, c0)`` pair of zero tensors shaped
            ``(n_layers, batch_size, n_hidden)``. They are created with
            ``new_zeros`` on a model parameter, so they share the model's
            dtype and device; the old code relied on an undefined global
            ``train_on_gpu`` to pick the device. '''
        weight = next(self.parameters()).data
        shape = (self.n_layers, batch_size, self.n_hidden)
        hidden = (weight.new_zeros(shape), weight.new_zeros(shape))
        return hidden

    
def one_hot_encode(arr, n_labels):
    """One-hot encode an integer array.

    :param arr: array-like of class indices in ``[0, n_labels)``; any number of dimensions.
    :param n_labels: size of the one-hot dimension.
    :return: float32 array of shape ``arr.shape + (n_labels,)``.
    """
    arr = np.asarray(arr)
    # arr.size works for any rank; np.multiply(*arr.shape) only accepted 2-D input.
    one_hot = np.zeros((arr.size, n_labels), dtype=np.float32)
    one_hot[np.arange(arr.size), arr.flatten()] = 1.
    one_hot = one_hot.reshape((*arr.shape, n_labels))

    return one_hot

In [11]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Character inventory of the language model; the position in this tuple is the class index.
chars = tuple(["'"] + list(string.ascii_lowercase) + [' '])

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
with open('../input/simple-lm/rnn_25_epoch.pt', 'rb') as f:
    checkpoint = torch.load(f, map_location=device)

# Hidden size and depth are read from the checkpoint so they match the saved weights.
loaded = CharRNN(chars, n_hidden=checkpoint['n_hidden'], n_layers=checkpoint['n_layers'])
loaded.load_state_dict(checkpoint['state_dict'])
loaded.to(device)

print(loaded)

CharRNN(
  (lstm): LSTM(28, 512, num_layers=4, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=512, out_features=28, bias=True)
)


In [12]:
def predict(net, char, h=None, top_k=None):
    ''' Given a character, compute the distribution over the next character.

        :param net: model exposing ``char2int`` and ``chars`` and callable as ``net(inputs, h)``.
        :param char: a single character present in ``net.char2int``.
        :param h: optional ``(hidden, cell)`` state from a previous call; ``None``
                  lets the LSTM start from its default zero state.
        :param top_k: accepted for compatibility; currently unused.
        :return: ``(p, h)`` -- softmax probabilities on the CPU and the new state.
    '''
    x = np.array([[net.char2int[char]]])
    x = one_hot_encode(x, len(net.chars))
    inputs = torch.from_numpy(x)

    inputs = inputs.to(device)
    # Detach the carried state from any earlier graph. The default h=None used
    # to crash here because None is not iterable.
    if h is not None:
        h = tuple([each.data for each in h])
    out, h = net(inputs, h)
    p = F.softmax(out, dim=1).data
    p = p.cpu()
    return p, h

In [13]:
def get_lm_prob(net, text):
    """Score ``text`` with the character language model.

    The text is extended with a trailing ``"'"`` and the function multiplies,
    for every adjacent pair of characters, the model's probability of the
    second character given everything fed so far. The first character itself
    is only used as input and is not scored.

    :param net: language model (see ``CharRNN``) or ``None``.
    :param text: string made of characters present in ``net.char2int``.
    :return: the product of the per-character probabilities as a float;
             ``0.0`` when ``net`` is ``None``.
    """
    if net is None:
        return 0.0
    net.eval()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    prob = 1.0
    # presumably "'" acts as the end-of-text symbol of the LM -- TODO confirm
    text += "'"
    h = None
    # Pure inference: without no_grad an autograd graph would be kept alive
    # across the whole recurrent chain for every scored string.
    with torch.no_grad():
        for in_c, out_c in zip(text[:-1], text[1:]):
            x = np.array([[net.char2int[in_c]]])
            x = one_hot_encode(x, len(net.chars))
            inputs = torch.from_numpy(x)

            inputs = inputs.to(device)
            out, h = net(inputs, h)
            p = F.softmax(out, dim=1).squeeze()
            prob *= p[net.char2int[out_c]].cpu().item()
    return prob

In [14]:
# Sanity check: compare the LM score of a real word with a one-letter corruption of it.
for word in ('hello', 'hella'):
    print(get_lm_prob(loaded, word))

1.0455580190219525e-08
6.207382350092737e-10


In [15]:
# Build the acoustic model and report its size, overall and per top-level layer.
model = QuarzNet(config)

total_trainable = count_parameters(model)
print("Total number of trainable parameters:", total_trainable)

for module in model.net.children():
    print(count_parameters(module))

Total number of trainable parameters: 6742460
23968
441344
449024
1439744
1745920
1776640
308736
528384
28700
0


In [16]:
def clean(text):
    """Remove every character listed in ``string.punctuation`` and lowercase the rest."""
    drop_table = dict.fromkeys(map(ord, string.punctuation))
    return text.translate(drop_table).lower()

def measure_len(root, folder, filename):
    """Load ``root/folder/filename`` and return the size of dimension 1 of the loaded tensor."""
    audio_path = os.path.join(root, folder, filename)
    waveform, _ = torchaudio.load_wav(audio_path)
    return waveform.shape[1]

def preprocess_targets(root, folder, lblpath):
    """Build the cleaned label file used by the datasets.

    Reads ``root/lblpath``, drops rows without text, cleans the transcripts,
    measures every clip in ``root/folder`` and keeps only the clips with at
    most ``AUDIO_LEN`` samples. The result (columns ``filename``,
    ``cleaned_text``, ``audio_len``) is written to ``lblpath`` relative to the
    current working directory.
    """
    target_file = os.path.join(root, lblpath)
    targets = pd.read_csv(target_file)
    # dropna leaves gaps in the index. Assigning Series to a freshly indexed
    # frame aligns on that index, which shifted transcripts onto the wrong
    # files and produced NaN rows. Plain arrays are positional and avoid this.
    targets = targets.dropna(subset=['text']).reset_index(drop=True)
    new_targets = pd.DataFrame({
        'filename': targets['filename'].values,
        'cleaned_text': targets['text'].apply(clean).values,
        'audio_len': targets['filename'].apply(lambda x: measure_len(root, folder, x)).values,
    })
    new_targets = new_targets[new_targets['audio_len'] <= AUDIO_LEN]
    new_targets.to_csv(lblpath, index=False)

In [17]:
# One-off label preprocessing for the dev and train splits. The code is wrapped
# in a string literal, so running this cell does not execute it.
'''
dev_folder = 'cv-valid-dev'
preprocess_targets('../input/common-voice/', dev_folder, DEV_DS)
train_folder = 'cv-valid-train'
preprocess_targets('../input/common-voice/', train_folder, TRAIN_DS)
'''

"\ndev_folder = 'cv-valid-dev'\npreprocess_targets('../input/common-voice/', dev_folder, DEV_DS)\ntrain_folder = 'cv-valid-train'\npreprocess_targets('../input/common-voice/', train_folder, TRAIN_DS)\n"

In [18]:
# Disabled follow-up filter that drops rows whose cleaned text equals 'undefined'.
# The code is wrapped in a string literal, so running this cell does not execute it.
# NOTE(review): as written, to_csv is called without index=False, so re-enabling
# it would add an index column to the files.
'''
#audio_len = targets['audio_len'].quantile(0.95)
targets = pd.read_csv('cv-valid-dev.csv')
print(len(targets.index))
targets.head()
#targets = targets[targets['audio_len'] < AUDIO_LEN]
targets = targets[targets['cleaned_text'] != 'undefined']
targets.to_csv('cv-valid-dev.csv')

targets = pd.read_csv('cv-valid-train.csv')
print(len(targets.index))
targets.head()
#targets = targets[targets['audio_len'] < AUDIO_LEN]
targets = targets[targets['cleaned_text'] != 'undefined']
targets.to_csv('cv-valid-train.csv')
'''

"\n#audio_len = targets['audio_len'].quantile(0.95)\ntargets = pd.read_csv('cv-valid-dev.csv')\nprint(len(targets.index))\ntargets.head()\n#targets = targets[targets['audio_len'] < AUDIO_LEN]\ntargets = targets[targets['cleaned_text'] != 'undefined']\ntargets.to_csv('cv-valid-dev.csv')\n\ntargets = pd.read_csv('cv-valid-train.csv')\nprint(len(targets.index))\ntargets.head()\n#targets = targets[targets['audio_len'] < AUDIO_LEN]\ntargets = targets[targets['cleaned_text'] != 'undefined']\ntargets.to_csv('cv-valid-train.csv')\n"

In [19]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchaudio

from torchvision.transforms import Compose

#https://github.com/toshiks/number_recognizer/blob/master/app/dataset/preprocessing.py

class LogMelSpectrogram(nn.Module):
    """
    Turn raw audio into a standardised log-mel spectrogram.

    A small constant is added before the logarithm so that zero-energy bins do
    not become -inf; the result is shifted and scaled by the mean and standard
    deviation taken over the whole tensor.
    """
    def __init__(self, sample_rate: int = 16000, n_mels: int = 128):
        super().__init__()
        mel_kwargs = dict(sample_rate=sample_rate, n_mels=n_mels,
                          n_fft=1024, hop_length=256,
                          f_min=0, f_max=8000)
        self.transform = torchaudio.transforms.MelSpectrogram(**mel_kwargs)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        mel = self.transform(waveform)
        log_mel = (mel + 1e-9).log()
        centered = log_mel - log_mel.mean()
        return centered / (log_mel.std() + 1e-9)


class MelAug(nn.Module):
    """Spectrogram augmentation: frequency masking followed by time masking."""

    def __init__(self):
        super().__init__()
        freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=15)
        time_mask = torchaudio.transforms.TimeMasking(time_mask_param=15)
        self.transforms = nn.Sequential(freq_mask, time_mask)

    def forward(self, melspec: torch.Tensor) -> torch.Tensor:
        augmented = self.transforms(melspec)
        return augmented

class WavAug(nn.Module):
    """Waveform augmentation: random volume followed by random fade-in/fade-out."""

    def __init__(self):
        super().__init__()

    def forward(self, wav):
        fade_const = 20
        # Draw order (gain first, then fade) is kept so seeded runs reproduce.
        gain = torch.rand((1,)).item()
        fade = torch.randint(low=0, high=fade_const, size=(1,)).item()
        # Resampling from 48000 to SR is intentionally left disabled:
        # torchaudio.transforms.Resample(48000, SR)
        scaled = torchaudio.transforms.Vol(gain)(wav)
        # The fade-in and fade-out lengths always add up to fade_const.
        return torchaudio.transforms.Fade(fade, fade_const - fade)(scaled)


In [20]:
import torch.nn as nn
from torchvision.datasets import DatasetFolder
from torch.utils.data import Dataset, ConcatDataset, Subset, DataLoader
from torch.utils.data import WeightedRandomSampler
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
import os

class CommonVoiceDataset(Dataset):
    """Audio/transcript pairs described by a label CSV.

    The CSV at ``lblpath`` must provide the columns ``filename`` (relative to
    ``root``) and ``cleaned_text``.
    """

    def __init__(self, root, lblpath, transform=None):
        """
        :param root: directory that contains the audio files.
        :param lblpath: path of the label CSV.
        :param transform: optional callable applied to the loaded waveform.
        """
        # `super(CommonVoiceDataset).__init__()` built an unbound super object
        # and never reached the base-class initializer.
        super().__init__()
        self.root = root
        self.transform = transform
        meta = pd.read_csv(lblpath)
        self.files = meta.filename.values
        self.targets = meta.cleaned_text.values

    def __getitem__(self, idx):
        """Return ``(waveform, n_model_frames, target_ids, target_len)`` for item ``idx``."""
        filepath = os.path.join(self.root, self.files[idx])
        wav, _sr = torchaudio.load_wav(filepath)
        if self.transform is not None:
            wav = self.transform(wav)
        wav = wav.squeeze()
        target = [CHAR_VOCAB[c] for c in self.targets[idx].lower()]
        n_frames = wav.shape[0] // 256 + 1  # hop_length = 256
        # The frame count is halved before being returned as the input length.
        # presumably this mirrors the stride-2 first layer of the acoustic model -- TODO confirm
        return wav, n_frames // 2, torch.tensor(target, dtype=torch.int), len(target)

    def __len__(self):
        return len(self.files)


def collate_fn(batch):
    """Zero-pad a list of ``(input, input_len, target, target_len)`` items into batch tensors."""
    waves, wave_lens, labels, label_lens = zip(*batch)

    def as_int32(lengths):
        return torch.Tensor(lengths).type(torch.int32)

    padded_waves = pad_sequence(waves, batch_first=True)
    padded_labels = pad_sequence(labels, batch_first=True)
    return padded_waves, as_int32(wave_lens), padded_labels, as_int32(label_lens)


def make_loader(root, lblpath, transform=None, bs=512, train=True):
    """Create a DataLoader over ``CommonVoiceDataset(root, lblpath, transform)``.

    :param root: directory with the audio files.
    :param lblpath: label CSV; must contain ``audio_len`` when ``train`` is True.
    :param transform: optional waveform transform forwarded to the dataset.
    :param bs: batch size.
    :param train: when True, samples are drawn (with replacement) with weight
                  ``SR * 5 / audio_len``, i.e. inversely proportional to clip
                  length. When False, the data is visited once, in file order.
    :return: the configured ``DataLoader`` (incomplete last batches are dropped).
    """
    dataset = CommonVoiceDataset(root, lblpath, transform)
    sampler = None
    if train:
        meta = pd.read_csv(lblpath)
        weights = SR * 5 / meta['audio_len'].values
        sampler = WeightedRandomSampler(weights, num_samples=len(weights))
    # Evaluation used to go through a uniform WeightedRandomSampler as well.
    # That draws with replacement, so validation saw some clips repeatedly and
    # skipped others, and the metrics changed from run to run.
    loader = DataLoader(dataset, batch_size=bs, num_workers=0, pin_memory=True,
                        collate_fn=collate_fn, drop_last=True, sampler=sampler)
    return loader

In [21]:
import editdistance

def wer(pred, lbl):
    """Word error rate of ``pred`` against ``lbl``, in percent.

    Both strings are split on whitespace; the word-level edit distance is
    divided by the number of reference words. For an empty reference the
    result is 0.0 when the prediction is empty too and 100.0 otherwise
    (previously this raised ZeroDivisionError).
    """
    lbl_tok = lbl.split()
    pred_tok = pred.split()
    if not lbl_tok:
        return 0.0 if not pred_tok else 100.0
    return editdistance.eval(pred_tok, lbl_tok) / len(lbl_tok) * 100

def cer(pred, lbl):
    """Character error rate of ``pred`` against ``lbl``, in percent.

    The reference is stripped of surrounding whitespace; the character-level
    edit distance is divided by the reference length. For an empty reference
    the result is 0.0 when the prediction is empty too and 100.0 otherwise
    (previously this raised ZeroDivisionError).
    """
    lbl = lbl.strip()
    if not lbl:
        return 0.0 if not pred else 100.0
    return editdistance.eval(pred, lbl) / len(lbl) * 100

In [22]:
def to_text(pred, target):
    """Decode a frame-wise index sequence and a target tensor into strings.

    ``pred`` is a 1-D integer array: runs of equal indices are collapsed to one
    occurrence and the survivors are mapped through ``ALPHABET`` (index 0 maps
    to the empty string). ``target`` is a tensor of indices mapped through
    ``ALPHABET`` as is.
    """
    following = np.append(pred[1:], 0)
    run_ends = pred != following  # True on the last frame of every run
    collapsed = pred[run_ends]
    text_pred = ''.join(ALPHABET[collapsed].squeeze().tolist())

    target_ids = target.squeeze().cpu().numpy()
    text_target = ''.join(ALPHABET[target_ids].tolist())
    return text_pred, text_target
    

In [23]:
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, f1_score, classification_report
from tqdm.notebook import tqdm
from itertools import groupby

def remove_dups(text_list):
    """Collapse consecutive equal elements of a tensor, returning one element per run."""
    values = text_list.cpu().detach()
    return [key for key, _run in groupby(values)]

def train(epochs, model, optimizer, scheduler, device, early_stopping,
          train_loader, valid_loader=None, grad_acum=1, criterion=nn.CTCLoss()):
    """Train ``model`` with CTC loss, with optional validation after every epoch.

    :param epochs: maximum number of epochs.
    :param model: network mapping features to per-frame log-probabilities; its
                  output is permuted with ``(2, 0, 1)`` before the loss.
    :param optimizer: optimizer updating ``model``.
    :param scheduler: LR scheduler, stepped once per epoch. A
                      ``ReduceLROnPlateau`` receives the mean validation loss;
                      any other scheduler is stepped without arguments.
    :param device: device used for the model inputs and targets.
    :param early_stopping: callable ``(val_loss, model, epoch)`` exposing ``early_stop``.
    :param train_loader: yields ``(input, input_lengths, target, target_lengths)``.
    :param valid_loader: optional loader with the same batch layout.
    :param grad_acum: number of batches whose gradients are accumulated per optimizer step.
    :param criterion: loss called as ``criterion(out, target, input_lengths, target_lengths)``.

    Running losses, validation WER/CER and a table of example predictions are
    logged to wandb.
    """
    featurizer = LogMelSpectrogram(SR, N_MELS).to(device)
    # Masking augmentation is for training only; validation uses clean features.
    augment = MelAug().to(device)
    clip = 15
    val_table = wandb.Table(columns=["Epoch", "Predicted Text", "True Text"])
    for epoch in range(epochs):
        optimizer.zero_grad()
        tr_loss, val_loss = 0, 0
        train_wer, train_cer = 0, 0
        tr_steps, val_steps = 0, 0
        n_train = 0
        model.train()
        for batch in tqdm(train_loader):
            train_input, input_lengths, target, target_lengths = batch
            target = target.to(device, non_blocking=True)
            input_lengths = input_lengths.to(device, non_blocking=True)
            target_lengths = target_lengths.to(device, non_blocking=True)
            X = augment(featurizer(train_input.to(device)))
            out = model(X).permute(2, 0, 1)
            loss = criterion(out, target, input_lengths, target_lengths)
            loss.backward()
            tr_loss += loss.item()
            wandb.log({'loss/train' : tr_loss / (tr_steps + 1)})
            tr_steps += 1
            if (tr_steps % grad_acum) == 0:
                # Clip after backward and right before the update; clipping
                # before backward acted on stale or empty gradients.
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
                optimizer.step()
                optimizer.zero_grad()

            # (time, batch) indices; no squeeze so a batch of one still works.
            pred = torch.argmax(out, dim=2).cpu().detach().numpy()
            for pred_el, target_el in zip(pred.transpose(1, 0), target):
                text_pred, text_target = to_text(pred_el, target_el)
                train_wer += wer(text_pred, text_target)
                train_cer += cer(text_pred, text_target)
                n_train += 1

        if n_train:
            print(f'train wer: {train_wer / n_train}, train_cer: {train_cer / n_train}')

        mean_val_loss = None
        if valid_loader is not None:
            val_cer, val_wer = 0, 0
            n_val = 0
            model.eval()
            with torch.no_grad():
                for batch in tqdm(valid_loader):
                    val_input, input_lengths, target, target_lengths = batch
                    target = target.to(device, non_blocking=True)
                    input_lengths = input_lengths.to(device, non_blocking=True)
                    target_lengths = target_lengths.to(device, non_blocking=True)
                    X = featurizer(val_input.to(device))
                    out = model(X).permute(2, 0, 1)
                    loss = criterion(out, target, input_lengths, target_lengths)

                    val_loss += loss.item()
                    pred = torch.argmax(out, dim=2).cpu().numpy()
                    for i, (pred_el, target_el) in enumerate(zip(pred.transpose(1, 0), target)):
                        text_pred, text_target = to_text(pred_el, target_el)
                        val_wer += wer(text_pred, text_target)
                        val_cer += cer(text_pred, text_target)
                        n_val += 1
                        # Keep one example from each of the first five batches.
                        if i == 0 and val_steps < 5:
                            val_table.add_data(epoch, text_pred, text_target)
                            print(f'prediction: {text_pred}\nlabel: {text_target}')
                    val_steps += 1
                    wandb.log({'loss/val' : val_loss / val_steps})
            wandb.log({'wer/val' : val_wer / max(n_val, 1)})
            wandb.log({'cer/val' : val_cer / max(n_val, 1)})
            mean_val_loss = val_loss / max(val_steps, 1)
            early_stopping(mean_val_loss, model, epoch)

        # Epoch-based schedulers treat a positional argument to step() as an
        # epoch index, so the loss is only passed to ReduceLROnPlateau.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            if mean_val_loss is not None:
                scheduler.step(mean_val_loss)
        else:
            scheduler.step()

        if valid_loader is not None and early_stopping.early_stop:
            print("Early stopping")
            break
    wandb.log({"val examples": val_table})

In [24]:
from itertools import islice

lr = 1e-2
epochs = 10
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
wandb.init(project='dla hw2', name='CommonVoice weighted sampling')
model = QuarzNet(config).to(device)
wandb.watch(model)
optimizer = torch_optimizer.NovoGrad(
                        model.parameters(),
                        lr=lr,
                        betas=(0.8, 0.5),
                        weight_decay=0.001,
)

# Resume from a previous run.
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
checkpoint = torch.load('../input/checkpoint/checkpoint (2)', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0, last_epoch=-1)
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
early_stopping = EarlyStopping(checkpoint='./checkpoint', patience=10, verbose=True)
wav_aug = WavAug()
train_loader = make_loader('../input/common-voice/cv-valid-train', '../input/checkpoint/cv-valid-train.csv',
                           transform=wav_aug, bs=96, train=True)
# The dev split is evaluated without the random gain/fade augmentation, so
# validation metrics are comparable between epochs and runs.
dev_loader = make_loader('../input/common-voice/cv-valid-dev', '../input/checkpoint/cv-valid-dev.csv',
                         transform=None, bs=1, train=False)

[34m[1mwandb[0m: Currently logged in as: [33marinaruck[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.8 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [25]:
from collections import defaultdict
import heapq

# Index -> character table used for decoding; index 0 (the blank) maps to ''.
ALPHABET = np.array([''] + list(string.ascii_lowercase) + [' '])

def update_pred(pred, c):
    """Append ``c`` to ``pred`` unless ``c`` repeats the last character of ``pred``.

    An empty ``c`` (the blank) leaves ``pred`` unchanged.
    NOTE(review): because the prefix does not remember an intervening blank,
    this decoder cannot emit the same letter twice in a row.
    """
    if pred != '' and c == pred[-1]:
        return pred
    return pred + c

def beam_search(probs, lm_model, beam_width=256, alpha=1, beta=1, gamma=0.1):
    """Beam-search decoding over per-frame class probabilities.

    :param probs: tensor of probabilities for a single utterance whose last
                  dimension indexes ``ALPHABET``; every other dimension is
                  flattened into the frame axis, so ``(T, C)`` and
                  ``(T, 1, C)`` are both accepted.
    :param lm_model: language model handed to ``get_lm_prob``, or ``None``.
    :param beam_width: number of prefixes kept after every frame.
    :param alpha: weight of the acoustic probability.
    :param beta: weight of the language-model score.
    :param gamma: weight of the prefix length.
    :return: the best-scoring prefix, stripped of surrounding whitespace
             (``''`` when ``probs`` has no frames).

    Every extension of a prefix with score ``prob`` contributes
    ``prob * (alpha * p + beta * lm + gamma * len(candidate))`` to the
    candidate's score; contributions to the same candidate are summed.
    """
    # reshape instead of squeeze, so a single-frame input keeps its frame axis.
    frames = probs.reshape(-1, probs.shape[-1]).tolist()
    # The LM term is zero without a model or with beta == 0: skip the calls.
    use_lm = lm_model is not None and beta != 0
    # The LM score depends only on the candidate string, so it is memoised
    # across frames and prefixes instead of re-running the network each time.
    lm_cache = {}
    beam = {'' : 1.0}
    for frame in frames:
        curr_beam = defaultdict(float)
        for prefix, prob in beam.items():
            for c, p in enumerate(frame):
                pred = update_pred(prefix, ALPHABET[c])
                if use_lm:
                    if pred not in lm_cache:
                        lm_cache[pred] = get_lm_prob(lm_model, pred)
                    lm_score = lm_cache[pred]
                else:
                    lm_score = 0.0
                curr_beam[pred] += prob * (alpha * p + beta * lm_score + gamma * len(pred))
        beam_items = heapq.nlargest(beam_width, list(curr_beam.items()), key=lambda x: x[1])
        beam = {k: v for k, v in beam_items}

    # Selected after the loop so that an input with zero frames still returns.
    best_pred, _best_prob = max(beam.items(), key=lambda x: x[1])
    return best_pred.strip()

In [26]:
# Compare greedy (argmax) decoding with beam search without a language model
# on the first 32 dev clips.
val_dataset = CommonVoiceDataset('../input/common-voice/cv-valid-dev',
                                 '../input/checkpoint/cv-valid-dev.csv')
val_subset = Subset(val_dataset, list(range(32)))
subset_loader = DataLoader(val_subset, batch_size=1, num_workers=0, pin_memory=True,
                           collate_fn=collate_fn, drop_last=True)

# Feature extraction only: no augmentation during evaluation.
process = nn.Sequential(LogMelSpectrogram(SR, N_MELS).to(device))

table = wandb.Table(columns=["Argmax Text", "Beam search Text", "True Text"])
val_cer, val_wer = 0, 0
val_steps = 0
model.eval()
with torch.no_grad():
    for batch in tqdm(subset_loader):
        val_input, input_lengths, target, target_lengths = batch
        target = target.to(device, non_blocking=True)
        input_lengths = input_lengths.to(device, non_blocking=True)
        target_lengths = target_lengths.to(device, non_blocking=True)

        X = process(val_input.to(device))
        out = model(X).permute(2, 0, 1)

        # Greedy path.
        pred = torch.argmax(out, dim=2).squeeze().cpu().detach().numpy()
        text_pred, text_target = to_text(pred, target)
        # Beam-search path; the model emits log-probabilities, hence the exp.
        text_beam = beam_search(torch.exp(out), None, beam_width=64, alpha=1, beta=0, gamma=0)

        val_wer += wer(text_beam, text_target)
        val_cer += cer(text_beam, text_target)
        if val_steps < 5:
            table.add_data(text_pred, text_beam, text_target)
            print(f'argmax prediction: {text_pred}\nbeam seach prediction: {text_beam}\nlabel: {text_target}')
        val_steps += 1

print(f'wer: {val_wer / val_steps}, cer: {val_cer / val_steps}')
wandb.log({"examples": table})

HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))

argmax prediction: be careful lith you pronostications had the stranger
beam seach prediction: be careful lith you prognostications had the stranger
label: be careful with your prognostications said the stranger
argmax prediction: anfie sheud ut u ee the plize o efleee un
beam seach prediction: ined fie sheud ut u e the plize o efle une
label: then why should they be surprised when they see one
argmax prediction: a oung arab waked the ledde own ad package anded and geited the englishman
beam seach prediction: a oung arab waked the leade downd ad package anded and geieted the englishman
label: a young arab also loaded down with baggage entered and greeted the englishman
argmax prediction: i felhtd that everything thi owoed would ee de twroeeed
beam seach prediction: i felht that everything thi owed would e de twroed
label: i thought that everything i owned would be destroyed
argmax prediction: he woved abant invisible but every one could hear him
beam seach prediction: he woved aband in

In [None]:
# Compare greedy (argmax) decoding with beam search rescored by the character
# language model on the first 3 dev clips.
val_dataset = CommonVoiceDataset('../input/common-voice/cv-valid-dev',
                                 '../input/checkpoint/cv-valid-dev.csv')
val_subset = Subset(val_dataset, list(range(3)))
subset_loader = DataLoader(val_subset, batch_size=1, num_workers=0, pin_memory=True,
                           collate_fn=collate_fn, drop_last=True)

# Feature extraction only: no augmentation during evaluation.
process = nn.Sequential(LogMelSpectrogram(SR, N_MELS).to(device))

table = wandb.Table(columns=["Argmax Text", "Beam search Text", "True Text"])
val_cer, val_wer = 0, 0
val_steps = 0
model.eval()
with torch.no_grad():
    for batch in tqdm(subset_loader):
        val_input, input_lengths, target, target_lengths = batch
        target = target.to(device, non_blocking=True)
        input_lengths = input_lengths.to(device, non_blocking=True)
        target_lengths = target_lengths.to(device, non_blocking=True)

        X = process(val_input.to(device))
        out = model(X).permute(2, 0, 1)

        # Greedy path.
        pred = torch.argmax(out, dim=2).squeeze().cpu().detach().numpy()
        text_pred, text_target = to_text(pred, target)
        # Beam-search path with LM rescoring; the model emits log-probabilities, hence the exp.
        text_beam = beam_search(torch.exp(out), loaded, beam_width=16, alpha=1, beta=0.5, gamma=1e-3)

        val_wer += wer(text_beam, text_target)
        val_cer += cer(text_beam, text_target)
        if val_steps < 5:
            table.add_data(text_pred, text_beam, text_target)
            print(f'argmax prediction: {text_pred}\nbeam seach + lm prediction: {text_beam}\nlabel: {text_target}')
        val_steps += 1

print(f'wer: {val_wer / val_steps}, cer: {val_cer / val_steps}')
wandb.log({"examples": table})

HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

argmax prediction: be careful lith you pronostications had the stranger
beam seach + lm prediction: tions had the stranger
label: be careful with your prognostications said the stranger


In [None]:
# Launch training with the model, optimizer, scheduler, early stopping and
# loaders created in the setup cell; dev_loader is used for validation.
train(epochs, model, optimizer, scheduler, device, early_stopping, train_loader, dev_loader)