In [1]:
from keyword_spotter.dataset import ZindiAudioDataset
from keyword_spotter.models import PalSolModel
from keyword_spotter.lightning import PalSolClassifier, ZindiDataModule
import pandas as pd
from pathlib import Path
from torchsummary import summary
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torch
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import train_test_split

In [2]:
pl.seed_everything(14300631)

14300631

In [3]:
data_dir = Path('../data')
df = pd.read_csv(data_dir / 'train.csv')

In [4]:
train_df, val_df = train_test_split(df, stratify=df['label'], test_size=0.3, random_state=42)

In [5]:
train_ds = ZindiAudioDataset(
    54243,
    22050,
    [data_dir / path for path in train_df['fn'].values],
    train_df['label'].values,
)
val_ds = ZindiAudioDataset(
    54243,
    22050,
    [data_dir / path for path in val_df['fn'].values],
    val_df['label'].values,
)

HBox(children=(FloatProgress(value=0.0, description='Loading files', max=776.0, style=ProgressStyle(descriptio…




HBox(children=(FloatProgress(value=0.0, description='Loading files', max=333.0, style=ProgressStyle(descriptio…




In [6]:
train_loader = DataLoader(train_ds, batch_size=128)
val_loader = DataLoader(val_ds, batch_size=128)

In [7]:
clf = PalSolClassifier(num_classes=len(train_ds.le.classes_), sample_rate=22050)

In [8]:
from pytorch_lightning.loggers import NeptuneLogger

neptune_logger = NeptuneLogger(
    project_name='astromid/zindi-keyword-spotter',
    experiment_name='test',
    params={'max_epochs': 50},
    api_key='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiNGJmN2RiYjMtZmNjZS00YmQyLWI3MjItZTk2YzMwMmY3YjI0In0='
)

https://ui.neptune.ai/astromid/zindi-keyword-spotter/e/ZIN-5
NeptuneLogger will work in online mode


In [9]:
trainer = pl.Trainer(gpus=2, max_epochs=50, logger=neptune_logger, distributed_backend='dp')

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0,1]


In [10]:
trainer.fit(clf, train_loader, val_loader)


  | Name  | Type        | Params
--------------------------------------
0 | model | PalSolModel | 935 K 


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..



1

In [11]:
sub = pd.read_csv('../data/SampleSubmission.csv')

In [12]:
test_paths = [data_dir / path for path in sub['fn'].values]

In [13]:
test_ds = ZindiAudioDataset(54243, 22050, test_paths)

HBox(children=(FloatProgress(value=0.0, description='Loading files', max=1017.0, style=ProgressStyle(descripti…




In [20]:
clf = clf.eval()

In [25]:
probs = []
with torch.no_grad():
    for item in tqdm(test_ds):
        logits = clf(item.unsqueeze(0))
        probs.append(torch.nn.functional.softmax(logits).numpy())

100%|██████████| 1017/1017 [00:16<00:00, 61.14it/s]


In [26]:
probs_matrix = np.vstack(probs)
new_sub = pd.DataFrame(probs_matrix)
new_sub.columns = train_ds.le.classes_
new_sub.insert(0, 'fn', sub['fn'])

In [27]:
new_sub.head()

Unnamed: 0,fn,Pump,Spinach,abalimi,afukirira,agriculture,akammwanyi,akamonde,akasaanyi,akatunda,...,suckers,sugarcane,sukumawiki,super grow,sweet potatoes,tomatoes,vegetables,watermelon,weeding,worm
0,audio_files/00118N3.wav,0.0035,0.000339,0.000998,0.006869,0.045642,0.001877,0.008848,0.000271,0.001311,...,0.000241,0.007257,0.001243,0.002021,0.000697,0.008931,0.007274,0.005158,0.000463,0.000365
1,audio_files/00P0NMV.wav,0.000253,0.000166,0.028981,0.010111,0.000531,0.004518,0.004797,0.003955,0.000172,...,0.000579,0.000713,0.000143,0.000268,0.000185,0.000227,0.00073,5.6e-05,3.8e-05,2.9e-05
2,audio_files/01QEEZI.wav,8e-05,0.001246,0.004954,5.7e-05,0.000139,0.00243,0.010121,0.000153,0.001402,...,4.2e-05,0.000912,0.000401,0.00014,0.000495,0.000652,0.00028,0.001127,0.008813,0.000176
3,audio_files/037YAED.wav,0.009099,0.000524,0.000482,0.000372,0.017476,0.001958,0.003198,2.3e-05,0.001647,...,0.007704,0.001175,0.001348,0.007324,0.000914,0.00074,0.001629,0.000309,0.009961,0.002679
4,audio_files/0382N0Y.wav,0.000286,0.000124,0.00499,0.001029,0.000211,0.000686,0.000966,0.000881,0.001358,...,0.000622,0.000568,0.000526,0.000951,0.000436,0.000588,0.000531,0.006216,0.00053,0.000503


In [28]:
new_sub.to_csv('zin_5.csv', index=False)