In [1]:
import torch
import torchaudio

################################################################################
###          (please add 'export KALDI_ROOT=<your_path>' in your $HOME/.profile)
###          (or run as: KALDI_ROOT=<your_path> python <your_script>.py)
################################################################################

  '"sox" backend is being deprecated. '


In [2]:
import numpy as np
np_rng = np.random.default_rng(1)
import pandas as pd


import urllib.parse
from IPython.display import display, Markdown

import os

from lidbox.meta import (
    common_voice,
    generate_label2target,
    verify_integrity,
    read_audio_durations,
    random_oversampling_on_split
)


train = pd.read_csv("/tf/datasets/train.tsv", sep="\t")
test = pd.read_csv("/tf/datasets/test.tsv", sep="\t")
dev = pd.read_csv("/tf/datasets/dev.tsv", sep="\t")

train["path"] = train["path"].apply(lambda x: x[:-3] + "mp3")
test["path"] = test["path"].apply(lambda x: x[:-3] + "mp3")
dev["path"] = dev["path"].apply(lambda x: x[:-3] + "mp3")

train["split"] = "train"
test["split"] = "test"
dev["split"] = "dev"
#test = test.sample(30000, replace=False)
meta = pd.concat([train, test, dev])


In [3]:
meta.loc[meta["locale"] != "kz", "path"] = "/tf/datasets/data_untar/cv-corpus-6.1-2020-12-11/" +  meta.loc[meta["locale"] != "kz"]["locale"] + "/clips/" + meta.loc[meta["locale"] != "kz"]["path"]
targets = {"kz": 0, "ru": 1, "en":2, "other":3}
meta["target"] = meta["locale"]
meta.loc[(meta["locale"] != "kz") & (meta["locale"] != "ru") & (meta["locale"]!="en"), "target"] = "other"
meta = meta.loc[meta["path"] != "/tf/datasets/data_untar/cv-corpus-6.1-2020-12-11/kz/clips/5f590a130a73c.mp3"]
meta = meta.loc[meta["path"] != "/tf/datasets/data_untar/cv-corpus-6.1-2020-12-11/kz/clips/5ef9bd9ba7029.mp3"]

meta["id"] = meta["Unnamed: 0"].apply(str)
meta["target"] = meta["target"].map(targets)

meta

2021-06-16 02:55:50.369 I numexpr.utils: Note: detected 96 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
2021-06-16 02:55:50.370 I numexpr.utils: Note: NumExpr detected 96 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2021-06-16 02:55:50.371 I numexpr.utils: NumExpr defaulting to 8 threads.


Unnamed: 0.1,Unnamed: 0,path,locale,split,target,id
0,1486,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,ru,train,1,1486
1,56701,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,kz,train,0,56701
2,3364,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,ru,train,1,3364
3,110475,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,rw,train,3,110475
4,45384,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,en,train,2,45384
...,...,...,...,...,...,...
30751,2677,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,ru,dev,1,2677
30752,881,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,ru,dev,1,881
30753,68709,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,kz,dev,0,68709
30754,249025,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,en,dev,2,249025


In [29]:
from torch.utils.data import Dataset
import random
import math

def _get_sample(path, resample=None):
  effects = [
    ["remix", "1"]
  ]
  if resample:
    effects.append(["rate", f'{resample}'])
  return torchaudio.sox_effects.apply_effects_file(path, effects=effects)

SAMPLE_RIR_PATH = os.path.join(os.getcwd(), "rir.wav")

def get_rir_sample(*, resample=None, processed=False):
    rir_raw, sample_rate = _get_sample(SAMPLE_RIR_PATH, resample=resample)
    if not processed:
        return rir_raw, sample_rate
    rir = rir_raw[:, int(sample_rate*1.01):int(sample_rate*1.3)]
    rir = rir / torch.norm(rir, p=2)
    rir = torch.flip(rir, [1])
    return rir, sample_rate

class AudiosDataset(Dataset):
    def __init__(self, paths=None, targets=None, augment=False) -> None:
        self.paths = paths
        self.targets = targets
        self.augment = augment
        self.rir = get_rir_sample()[0]
        
    
    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int) -> dict:
        y, sr = torchaudio.load(self.paths.iloc[idx], normalization=True)
        """
        if self.augment:
            effects = [
                    ["lowpass", "-1", "300"], 
                    ["speed", f"{random.uniform(0.7, 1.3)}"],  # change speed
                  ]
            y, sr = torchaudio.sox_effects.apply_effects_tensor(
                y, sr, effects)
        #
            
        if self.augment:
            # augment sound in order to imitate the room change
            rir = self.rir[:, int(16000*1.1):int(16000*1.3)]
            rir = rir / torch.norm(rir, p=2)
            rir = torch.flip(rir, [1])
            y = torch.nn.functional.conv1d(y[None, ...], rir[None, ...])[0]
        
        """
        y = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(y)
        y = torchaudio.transforms.Vad(sample_rate = 16000)(y)

        """
        if self.augment:
            y = (0.5)*torch.randn(y.shape)
        """
        # convert to spectogram
        spectogram = torchaudio.transforms.MelSpectrogram()(y)
        #spectogram = torch.log(spectogram + 1e-5)
        melspectogram_db = torchaudio.transforms.AmplitudeToDB()(spectogram)
        
        #Make sure all spectrograms are the same size
        fixed_length = 12 * (16000//200)
        
        if melspectogram_db.shape[2] < fixed_length:
            melspectogram_db = torch.nn.functional.pad(
              melspectogram_db, (0, fixed_length - melspectogram_db.shape[2]))
        else:
            melspectogram_db = melspectogram_db[:, :, :fixed_length]
        
        spectogram = melspectogram_db
        
        if self.augment:

            spectogram = torchaudio.transforms.FrequencyMasking(100)(spectogram)
            spectogram = torchaudio.transforms.TimeMasking(100)(spectogram)
        
        # returning result
        result = {"spec": spectogram, "target":self.targets.iloc[idx]}

        return result

In [30]:
ds = AudiosDataset(meta["path"], meta["target"])

In [31]:
from torch.utils.data import TensorDataset, DataLoader
train_ds = AudiosDataset(meta.loc[meta["split"]=="train"]["path"], meta.loc[meta["split"]=="train"]["target"], augment=True)
val_ds = AudiosDataset(meta.loc[meta["split"]=="dev"]["path"], meta.loc[meta["split"]=="dev"]["target"])
test_ds = AudiosDataset(meta.loc[meta["split"]=="test"]["path"], meta.loc[meta["split"]=="test"]["target"])

In [32]:
for i in train_ds:
    print(i)
    break

{'spec': tensor([[[-100.0000, -100.0000, -100.0000,  ...,    0.0000,    0.0000,
             0.0000],
         [   4.6951,   12.7690,    8.0257,  ...,    0.0000,    0.0000,
             0.0000],
         [  12.0064,   20.0803,   15.3369,  ...,    0.0000,    0.0000,
             0.0000],
         ...,
         [  22.3348,   23.4628,   23.1085,  ...,    0.0000,    0.0000,
             0.0000],
         [  23.0709,   18.7987,   23.5499,  ...,    0.0000,    0.0000,
             0.0000],
         [  17.9024,   21.4223,   22.0901,  ...,    0.0000,    0.0000,
             0.0000]]]), 'target': 1.0}


  "At least one mel filterbank has all zero values. "
  normalized, onesided, return_complex)
  normalized, onesided, return_complex)


In [33]:
batch_size = 32
num_workers = 10
loaders = {
    "train": DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    ),
    "valid": DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=True,
    ),
    "test":DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=True,
    ),
}

In [16]:
from torchvision import  models
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
model = models.resnet34(pretrained=True)
model.conv1=nn.Conv2d(1, model.conv1.out_channels, 
                      kernel_size=model.conv1.kernel_size[0], 
                      stride=model.conv1.stride[0], 
                      padding=model.conv1.padding[0])
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 4)

In [183]:
from tqdm import tqdm

def train(model, opt, scheduler, loss_fn, epochs, data_tr, data_val, max_stable=5):
    best_val_loss = 1e9
    counter = 0
    for epoch in range(epochs):
        #tic = time()
        print('* Epoch %d/%d' % (epoch+1, epochs))

        avg_loss = 0
        model.train()  # train mode
        for batch in tqdm(data_tr):
            loss = 0
            # data to device
            X_batch, Y_batch = batch["spec"], batch["target"]
            #print(X_batch.shape)
            X_batch = X_batch.to(DEVICE)
            Y_batch = Y_batch.to(DEVICE)
            # set parameter gradients to zero
            opt.zero_grad()
            # forward
            Y_pred = model(X_batch)
            #print(Y_pred)
            loss = loss_fn(Y_pred, Y_batch)# forward-pass
            loss.backward()  # backward-pass
            opt.step()  # update weights
            if not scheduler is None:
                scheduler.step()
            # calculate loss to show the user
            avg_loss += loss / len(data_tr)
      #  toc = time()
        print('loss: %f' % avg_loss)
        # show intermediate results
        model.eval()  # testing mode
        val_loss = 0
        print("start validation")
        for v_b in tqdm(data_val):
            X_val, Y_val = v_b["spec"], v_b["target"]
            Y_hat = model(X_val.to(DEVICE)).detach().cpu()# detach and put into cpu
            val_loss += loss_fn(Y_hat, Y_val)
        val_loss /= len(data_val)
        print( f"validation loss: {val_loss}")
        if val_loss <= best_val_loss and val_loss > 0:
            counter = 0
            print("Save new model!")
            best_val_loss = val_loss
            torch.save(model.state_dict(), f'best_model.h5')
            best_patn = f'{epoch}_{best_val_loss}.h5'
        else:
            counter += 1
        if counter == max_stable:
            break

In [184]:
DEVICE = 'cuda'
max_epochs = 100
model = model.to(DEVICE)
#torch.cuda.empty_cache()
loss_fn =  nn.CrossEntropyLoss()
optimaizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimaizer, T_0=5, T_mult=1, eta_min=1e-8, last_epoch=-1)
train(model, optimaizer,scheduler, loss_fn, max_epochs, loaders["train"], loaders["valid"])

  0%|          | 0/2576 [00:00<?, ?it/s]

* Epoch 1/100


100%|██████████| 2576/2576 [07:11<00:00,  5.97it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.654073
start validation


100%|██████████| 961/961 [02:37<00:00,  6.09it/s]


validation loss: 0.8601071238517761
Save new model!
* Epoch 2/100


100%|██████████| 2576/2576 [07:37<00:00,  5.64it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.526455
start validation


100%|██████████| 961/961 [02:29<00:00,  6.44it/s]


validation loss: 0.3954550325870514
Save new model!
* Epoch 3/100


100%|██████████| 2576/2576 [06:54<00:00,  6.21it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.469784
start validation


100%|██████████| 961/961 [02:47<00:00,  5.73it/s]
  0%|          | 0/2576 [00:00<?, ?it/s]

validation loss: 0.5551520586013794
* Epoch 4/100


100%|██████████| 2576/2576 [07:28<00:00,  5.74it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.428406
start validation


100%|██████████| 961/961 [02:35<00:00,  6.19it/s]


validation loss: 0.5290905833244324
* Epoch 5/100


100%|██████████| 2576/2576 [07:22<00:00,  5.82it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.398306
start validation


100%|██████████| 961/961 [02:34<00:00,  6.21it/s]


validation loss: 0.38806942105293274
Save new model!


  0%|          | 0/2576 [00:00<?, ?it/s]

* Epoch 6/100


100%|██████████| 2576/2576 [06:55<00:00,  6.19it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.371820
start validation


100%|██████████| 961/961 [02:41<00:00,  5.97it/s]


validation loss: 0.4634988307952881
* Epoch 7/100


100%|██████████| 2576/2576 [07:26<00:00,  5.77it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.349526
start validation


100%|██████████| 961/961 [02:36<00:00,  6.15it/s]


validation loss: 0.3907879889011383
* Epoch 8/100


100%|██████████| 2576/2576 [07:26<00:00,  5.77it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.327548
start validation


100%|██████████| 961/961 [02:37<00:00,  6.09it/s]


validation loss: 0.2315412014722824
Save new model!


  0%|          | 0/2576 [00:00<?, ?it/s]

* Epoch 9/100


100%|██████████| 2576/2576 [07:14<00:00,  5.93it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.310639
start validation


100%|██████████| 961/961 [02:30<00:00,  6.37it/s]
  0%|          | 0/2576 [00:00<?, ?it/s]

validation loss: 0.29240575432777405
* Epoch 10/100


100%|██████████| 2576/2576 [07:18<00:00,  5.87it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.295653
start validation


100%|██████████| 961/961 [02:38<00:00,  6.08it/s]


validation loss: 0.5641764402389526
* Epoch 11/100


100%|██████████| 2576/2576 [07:22<00:00,  5.83it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.275224
start validation


100%|██████████| 961/961 [02:34<00:00,  6.23it/s]


validation loss: 0.5109254121780396
* Epoch 12/100


100%|██████████| 2576/2576 [07:15<00:00,  5.92it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.262520
start validation


100%|██████████| 961/961 [02:33<00:00,  6.25it/s]


validation loss: 0.23857124149799347
* Epoch 13/100


100%|██████████| 2576/2576 [07:13<00:00,  5.95it/s]
  0%|          | 0/961 [00:00<?, ?it/s]

loss: 0.250046
start validation


100%|██████████| 961/961 [02:35<00:00,  6.17it/s]


validation loss: 0.26080694794654846


In [17]:
model.load_state_dict(torch.load("best_model.h5"))

<All keys matched successfully>

In [186]:
true_labels = []
predicted_labels = []
for batch in loaders["test"]:
    model.eval()
    prediction = model(batch["spec"].to(DEVICE)).detach().cpu()
    predicted_labels.extend(torch.argmax(prediction, dim=1).tolist())
    true_labels.extend(batch["target"].tolist())


In [188]:
from sklearn.metrics import classification_report

report = classification_report(true_labels, predicted_labels, target_names=list(targets.keys()), labels=range(4))
print(report)



              precision    recall  f1-score   support

          kz       1.00      1.00      1.00     17337
          ru       0.79      0.88      0.83     10374
          en       0.82      0.87      0.85     12956
       other       0.84      0.73      0.78     15077

    accuracy                           0.87     55744
   macro avg       0.86      0.87      0.86     55744
weighted avg       0.88      0.87      0.87     55744



## Testing on VOX

In [18]:
import numpy as np
np_rng = np.random.default_rng(1)


import urllib.parse
from IPython.display import display, Markdown

import os

from lidbox.meta import (
    common_voice,
    generate_label2target,
    verify_integrity,
    read_audio_durations,
    random_oversampling_on_split
)


test = pd.read_csv("/tf/datasets/new_test.tsv", sep="\t")

test["path"] = test["path"].apply(lambda x: x[:-3] + "mp3")

test["split"] = "test"
meta = pd.concat([train, test, dev])


In [19]:
meta.loc[((meta["locale"] != "kz") & ~(((meta["split"] == "dev") | (meta["split"] == "test")) & ((meta["locale"] == "ru") | (meta["locale"] == "kz") | (meta["locale"] == "en")))), "path"] = "/tf/datasets/data_untar/cv-corpus-6.1-2020-12-11/" + meta.loc[((meta["locale"] != "kz") & ~(((meta["split"] == "dev") | (meta["split"] == "test")) & ((meta["locale"] == "ru") | (meta["locale"] == "kz") | (meta["locale"] == "en"))))]["locale"]  + "/clips/" + meta.loc[((meta["locale"] != "kz") & ~(((meta["split"] == "dev") | (meta["split"] == "test")) & ((meta["locale"] == "ru") | (meta["locale"] == "kz") | (meta["locale"] == "en"))))]["path"]
targets = {"kz": 0, "ru": 1, "en":2, "other":3}
meta["target"] = meta["locale"]
meta.loc[(meta["locale"] != "kz") & (meta["locale"] != "ru") & (meta["locale"]!="en"), "target"] = "other"
meta = meta.loc[meta["path"] != "/tf/datasets/data_untar/cv-corpus-6.1-2020-12-11/kz/clips/5f590a130a73c.mp3"]
meta = meta.loc[meta["path"] != "/tf/datasets/data_untar/cv-corpus-6.1-2020-12-11/kz/clips/5ef9bd9ba7029.mp3"]

meta["id"] = meta["Unnamed: 0"].apply(str)
meta["target"] = meta["target"].map(targets)


workdir = "/tf/datasets/transformer"


In [20]:
meta.loc[(meta["split"] == "test") & (meta["locale"] == "ru"), "path"] = meta.loc[(meta["split"] == "test") & (meta["locale"] == "ru")]["path"].apply(lambda x: f"/tf/datasets/vox/ru_test/{x}")
meta.loc[(meta["split"] == "test") & (meta["locale"] == "ru"), "path"]
meta.loc[(meta["split"] == "test") & (meta["locale"] == "kz"), "path"] = meta.loc[(meta["split"] == "test") & (meta["locale"] == "kz")]["path"].apply(lambda x: f"/tf/datasets/vox/kz_test/{x}")
meta.loc[(meta["split"] == "test") & (meta["locale"] == "kz"), "path"] 
meta.loc[(meta["split"] == "test") & (meta["locale"] == "en"), "path"] = meta.loc[(meta["split"] == "test") & (meta["locale"] == "en")]["path"].apply(lambda x: f"/tf/datasets/vox/en_test/{x}")
meta.loc[(meta["split"] == "test") & (meta["locale"] == "en"), "path"] 

0       /tf/datasets/vox/en_test/shrDRhToGpY__U__S133-...
1       /tf/datasets/vox/en_test/mzfg0RGJnV8__U__S123-...
2       /tf/datasets/vox/en_test/-_PPCH3y0eE__U__S1---...
3       /tf/datasets/vox/en_test/DQMxvGYyu6Q__U__S0---...
4       /tf/datasets/vox/en_test/x4lfSc7PrB0__U__S0---...
                              ...                        
9995    /tf/datasets/vox/en_test/KLiy94kfZI4__U__S133-...
9996    /tf/datasets/vox/en_test/YTlliEr5LOA__U__S113-...
9997    /tf/datasets/vox/en_test/bSs0gNq6Kkc__U__S0---...
9998    /tf/datasets/vox/en_test/Da7c-BY6MDA__U__S2---...
9999    /tf/datasets/vox/en_test/VWvPndMo1F8__U__S24--...
Name: path, Length: 10000, dtype: object

In [21]:
meta.loc[meta["split"]=="test", "Unnamed: 0"] = meta.loc[meta["split"]=="test"]["path"]

In [22]:
meta["id"] = meta["Unnamed: 0"].apply(str)

In [23]:
meta.loc[meta["split"] == "test", "id"] = meta.loc[meta["split"] == "test"]["path"]

In [24]:
meta = meta.set_index("Unnamed: 0")
meta.loc[meta["split"]=="test"]

Unnamed: 0_level_0,path,locale,split,target,id
Unnamed: 0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
/tf/datasets/vox/en_test/shrDRhToGpY__U__S133---0944.430-0958.260.mp3,/tf/datasets/vox/en_test/shrDRhToGpY__U__S133-...,en,test,2,/tf/datasets/vox/en_test/shrDRhToGpY__U__S133-...
/tf/datasets/vox/en_test/mzfg0RGJnV8__U__S123---0427.020-0444.670.mp3,/tf/datasets/vox/en_test/mzfg0RGJnV8__U__S123-...,en,test,2,/tf/datasets/vox/en_test/mzfg0RGJnV8__U__S123-...
/tf/datasets/vox/en_test/-_PPCH3y0eE__U__S1---0398.760-0403.940.mp3,/tf/datasets/vox/en_test/-_PPCH3y0eE__U__S1---...,en,test,2,/tf/datasets/vox/en_test/-_PPCH3y0eE__U__S1---...
/tf/datasets/vox/en_test/DQMxvGYyu6Q__U__S0---1473.480-1485.720.mp3,/tf/datasets/vox/en_test/DQMxvGYyu6Q__U__S0---...,en,test,2,/tf/datasets/vox/en_test/DQMxvGYyu6Q__U__S0---...
/tf/datasets/vox/en_test/x4lfSc7PrB0__U__S0---0125.230-0140.900.mp3,/tf/datasets/vox/en_test/x4lfSc7PrB0__U__S0---...,en,test,2,/tf/datasets/vox/en_test/x4lfSc7PrB0__U__S0---...
...,...,...,...,...,...
/tf/datasets/data_untar/cv-corpus-6.1-2020-12-11/it/clips/common_voice_it_20015623.mp3,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,it,test,3,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...
/tf/datasets/data_untar/cv-corpus-6.1-2020-12-11/uk/clips/common_voice_uk_23554602.mp3,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,uk,test,3,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...
/tf/datasets/data_untar/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_20416266.mp3,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,tr,test,3,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...
/tf/datasets/data_untar/cv-corpus-6.1-2020-12-11/it/clips/common_voice_it_20263173.mp3,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...,it,test,3,/tf/datasets/data_untar/cv-corpus-6.1-2020-12-...


In [25]:
meta.loc[meta["split"] == "test"] = meta.loc[(meta["split"] == "test") & (meta["target"] != 3)] 

In [26]:
meta.loc[meta["split"]=="test", "id"] = meta.loc[meta["split"]=="test"]["path"]
meta.loc[meta["split"]=="test"] 

Unnamed: 0_level_0,path,locale,split,target,id
Unnamed: 0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
/tf/datasets/vox/en_test/shrDRhToGpY__U__S133---0944.430-0958.260.mp3,/tf/datasets/vox/en_test/shrDRhToGpY__U__S133-...,en,test,2.0,/tf/datasets/vox/en_test/shrDRhToGpY__U__S133-...
/tf/datasets/vox/en_test/mzfg0RGJnV8__U__S123---0427.020-0444.670.mp3,/tf/datasets/vox/en_test/mzfg0RGJnV8__U__S123-...,en,test,2.0,/tf/datasets/vox/en_test/mzfg0RGJnV8__U__S123-...
/tf/datasets/vox/en_test/-_PPCH3y0eE__U__S1---0398.760-0403.940.mp3,/tf/datasets/vox/en_test/-_PPCH3y0eE__U__S1---...,en,test,2.0,/tf/datasets/vox/en_test/-_PPCH3y0eE__U__S1---...
/tf/datasets/vox/en_test/DQMxvGYyu6Q__U__S0---1473.480-1485.720.mp3,/tf/datasets/vox/en_test/DQMxvGYyu6Q__U__S0---...,en,test,2.0,/tf/datasets/vox/en_test/DQMxvGYyu6Q__U__S0---...
/tf/datasets/vox/en_test/x4lfSc7PrB0__U__S0---0125.230-0140.900.mp3,/tf/datasets/vox/en_test/x4lfSc7PrB0__U__S0---...,en,test,2.0,/tf/datasets/vox/en_test/x4lfSc7PrB0__U__S0---...
...,...,...,...,...,...
/tf/datasets/vox/kz_test/rCpb0p_lyxI__U__S25---0107.830-0127.780.mp3,/tf/datasets/vox/kz_test/rCpb0p_lyxI__U__S25--...,kz,test,0.0,/tf/datasets/vox/kz_test/rCpb0p_lyxI__U__S25--...
/tf/datasets/vox/kz_test/BkLVX9wf2YI__U__S26---0236.830-0241.550.mp3,/tf/datasets/vox/kz_test/BkLVX9wf2YI__U__S26--...,kz,test,0.0,/tf/datasets/vox/kz_test/BkLVX9wf2YI__U__S26--...
/tf/datasets/vox/kz_test/RqdH-JD8TpM__U__S78---0466.720-0470.860.mp3,/tf/datasets/vox/kz_test/RqdH-JD8TpM__U__S78--...,kz,test,0.0,/tf/datasets/vox/kz_test/RqdH-JD8TpM__U__S78--...
/tf/datasets/vox/kz_test/oCjW4Jy6azE__U__S110---0669.320-0675.220.mp3,/tf/datasets/vox/kz_test/oCjW4Jy6azE__U__S110-...,kz,test,0.0,/tf/datasets/vox/kz_test/oCjW4Jy6azE__U__S110-...


In [27]:
newtest_ds = AudiosDataset(meta.loc[meta["split"]=="test"]["path"], meta.loc[meta["split"]=="test"]["target"])

In [38]:
loaders["new_test"] = DataLoader(
        newtest_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=True,
    )
DEVICE = 'cuda'
model = model.to(DEVICE)

In [39]:
true_labels = []
predicted_labels = []
for batch in loaders["new_test"]:
    model.eval()
    prediction = model(batch["spec"].to(DEVICE)).detach().cpu()
    predicted_labels.extend(torch.argmax(prediction, dim=1).tolist())
    true_labels.extend(batch["target"].tolist())


In [40]:
from sklearn.metrics import classification_report

report = classification_report(true_labels, predicted_labels, target_names=list(targets.keys()), labels=range(4))
print(report)



              precision    recall  f1-score   support

          kz       0.35      0.67      0.46     13925
          ru       0.33      0.01      0.02     12107
          en       0.26      0.12      0.17     10000
       other       0.00      0.00      0.00         0

    accuracy                           0.30     36032
   macro avg       0.23      0.20      0.16     36032
weighted avg       0.32      0.30      0.23     36032



  _warn_prf(average, modifier, msg_start, len(result))
