In [28]:
from dataset import ZindiAudioDataset
from models import PalSolModel
from lightning import PalSolClassifier
import pandas as pd
from pathlib import Path
from torchsummary import summary
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torch
from tqdm.auto import tqdm
import numpy as np


In [2]:
# Root of the competition data; every path below is built from it so the
# notebook can be relocated by changing a single line.
data_dir = Path('../data')
df = pd.read_csv(data_dir / 'train.csv')
# 'fn' holds paths relative to data_dir (e.g. audio_files/xxxx.wav).
paths = [data_dir / fn for fn in df['fn']]
labels = df['label'].values

In [3]:
# Labelled training dataset; construction eagerly loads every audio file
# (see the "Loading files" progress bar below).
# NOTE(review): 54243 and 22050 look like a fixed clip length in samples and
# the audio sample rate — confirm against ZindiAudioDataset's signature.
ds = ZindiAudioDataset(
    54243,
    22050,
    paths,
    labels,
)

Loading files: 100%|██████████| 1109/1109 [00:11<00:00, 100.53it/s]


In [4]:
# Seed every RNG Lightning knows about (python, numpy, torch) before anything
# stochastic happens — e.g. the shuffling below and, presumably, the weight
# init in the classifier cell — so Restart & Run All is reproducible.
pl.seed_everything(42)
# shuffle=True: without it every epoch iterates train.csv in file order, so
# batches inherit whatever ordering the CSV happens to have.
loader = DataLoader(ds, batch_size=32, shuffle=True)

In [5]:
# One output per class known to the training set's label encoder.
# NOTE(review): sample_rate presumably has to match the 22050 the datasets were
# built with — confirm in PalSolClassifier.
num_label_classes = len(ds.le.classes_)
clf = PalSolClassifier(num_classes=num_label_classes, sample_rate=22050)

In [8]:
from pytorch_lightning.loggers import NeptuneLogger

# Online Neptune run for this experiment. No API token is passed in code.
# NOTE(review): presumably read from the NEPTUNE_API_TOKEN env var — confirm.
# max_epochs is only logged here; the Trainer cell sets its own value, keep them in sync.
logged_params = dict(max_epochs=5)
neptune_logger = NeptuneLogger(
    project_name='astromid/sandbox',
    experiment_name='test',
    params=logged_params,
)

psutil is not installed. You will not be able to abort this experiment from the UI.
psutil is not installed. Hardware metrics will not be collected.
https://ui.neptune.ai/astromid/sandbox/e/SAN-1
NeptuneLogger will work in online mode


In [9]:
# Use the GPU when one is visible, otherwise fall back to CPU (gpus=None) so
# the notebook still runs end to end on a machine without CUDA.
# max_epochs must stay in sync with the value logged in the NeptuneLogger params.
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else None,
    max_epochs=5,
    logger=neptune_logger,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [10]:
# Train the classifier. The trailing ';' hides fit()'s return value, which
# otherwise shows up in the notebook as a stray `1`.
trainer.fit(clf, loader);


  | Name  | Type        | Params
--------------------------------------
0 | model | PalSolModel | 935 K 
Epoch 4: 100%|██████████| 35/35 [00:19<00:00,  1.76it/s, loss=3.985, v_num=AN-1]Saving latest checkpoint..
Epoch 4: 100%|██████████| 35/35 [00:19<00:00,  1.75it/s, loss=3.985, v_num=AN-1]


1

In [11]:
# Sample submission: defines the test clip order ('fn') the output must follow.
sub = pd.read_csv(data_dir / 'SampleSubmission.csv')

In [18]:
# Full paths (under data_dir) to the test clips, in SampleSubmission row order.
test_paths = [data_dir.joinpath(fn) for fn in sub['fn']]

In [22]:
# Unlabelled test dataset, built with the same two leading arguments as the
# training dataset so clips are preprocessed identically.
test_ds = ZindiAudioDataset(
    54243,
    22050,
    test_paths,
)

Loading files: 100%|██████████| 1017/1017 [00:08<00:00, 114.60it/s]


In [26]:
# Put the module in inference mode (affects layers such as dropout / batch norm,
# if the model has any). eval() works in place and returns the same object;
# ';' keeps the notebook from printing the module repr.
clf.eval();

In [51]:
# Class probabilities for every test clip, in test_ds order (which follows
# sub['fn'], so rows line up with the submission).
# Send inputs to wherever the model's weights live (CPU or GPU after fit).
device = next(clf.parameters()).device
probs = []
with torch.no_grad():
    for item in tqdm(test_ds):
        logits = clf(item.unsqueeze(0).to(device))
        # Explicit dim: implicit-dim softmax is deprecated and warns. For 1-D/2-D
        # logits (presumably (1, num_classes) here) the implicit choice was the
        # last dim too, so results are unchanged.
        # .cpu() so .numpy() also works when the model is on the GPU.
        probs.append(torch.nn.functional.softmax(logits, dim=-1).cpu().numpy())

100%|██████████| 1017/1017 [01:06<00:00, 15.38it/s]


In [52]:
# One row per test clip, one column per class. Column i is labelled with
# label-encoder class i (assumes the model's output index i was trained on
# encoded label i).
probs_matrix = np.vstack(probs)
new_sub = pd.DataFrame(probs_matrix, columns=ds.le.classes_)

In [43]:
# Sanity check only. Do NOT add 'fn' here: the column-naming cell below assigns
# exactly one name per class and fails on an extra column, and the insert cell
# further down adds 'fn' at position 0 itself.
assert len(new_sub) == len(sub), 'prediction rows must match SampleSubmission rows'
new_sub.shape

(1017, 193)

In [54]:
# Label the probability columns with the encoder's class names.
class_names = ds.le.classes_
new_sub.columns = class_names

In [55]:
# Preview the first five rows of the labelled probability table.
new_sub.iloc[:5]

Unnamed: 0,Pump,Spinach,abalimi,afukirira,agriculture,akammwanyi,akamonde,akasaanyi,akatunda,akatungulu,...,suckers,sugarcane,sukumawiki,super grow,sweet potatoes,tomatoes,vegetables,watermelon,weeding,worm
0,0.012175,0.002343,0.008473,0.003873,0.011243,0.004418,0.002169,0.010284,0.004583,0.004816,...,0.007943,0.002807,0.002287,0.003243,0.001718,0.004745,0.003433,0.002362,0.003323,0.006075
1,0.004885,0.002677,0.008709,0.020518,0.005466,0.011951,0.017236,0.009939,0.006564,0.013447,...,0.00554,0.002088,0.002369,0.001096,0.001761,0.002434,0.002804,0.001064,0.001397,0.003076
2,0.001335,0.002567,0.005726,0.003267,0.00129,0.001728,0.004658,0.0008,0.003714,0.005663,...,0.001452,0.003533,0.002494,0.003664,0.003812,0.003583,0.004573,0.004816,0.004851,0.001428
3,0.049555,0.000507,0.002672,0.002595,0.013128,0.016373,0.010434,0.008203,0.002858,0.005017,...,0.030116,0.001026,0.000938,0.001808,0.000691,0.001015,0.001118,0.000548,0.000327,0.013761
4,0.002854,0.002721,0.006294,0.003672,0.005224,0.002702,0.00355,0.002318,0.005443,0.004982,...,0.003256,0.003643,0.002264,0.004063,0.002669,0.004179,0.004302,0.004263,0.003964,0.003377


In [56]:
# Put the clip id first, as in SampleSubmission. Rebuilt with concat instead of
# DataFrame.insert so re-running the cell does not raise
# "cannot insert fn, already exists". Rows align on the shared 0..N-1 index.
# TODO(review): if the scorer is column-order sensitive, also reorder the class
# columns to match sub.columns.
new_sub = pd.concat(
    [sub[['fn']], new_sub.drop(columns='fn', errors='ignore')],
    axis=1,
)

In [57]:
# Preview the final submission layout (fn first, then one column per class).
new_sub.head(n=5)

Unnamed: 0,fn,Pump,Spinach,abalimi,afukirira,agriculture,akammwanyi,akamonde,akasaanyi,akatunda,...,suckers,sugarcane,sukumawiki,super grow,sweet potatoes,tomatoes,vegetables,watermelon,weeding,worm
0,audio_files/00118N3.wav,0.012175,0.002343,0.008473,0.003873,0.011243,0.004418,0.002169,0.010284,0.004583,...,0.007943,0.002807,0.002287,0.003243,0.001718,0.004745,0.003433,0.002362,0.003323,0.006075
1,audio_files/00P0NMV.wav,0.004885,0.002677,0.008709,0.020518,0.005466,0.011951,0.017236,0.009939,0.006564,...,0.00554,0.002088,0.002369,0.001096,0.001761,0.002434,0.002804,0.001064,0.001397,0.003076
2,audio_files/01QEEZI.wav,0.001335,0.002567,0.005726,0.003267,0.00129,0.001728,0.004658,0.0008,0.003714,...,0.001452,0.003533,0.002494,0.003664,0.003812,0.003583,0.004573,0.004816,0.004851,0.001428
3,audio_files/037YAED.wav,0.049555,0.000507,0.002672,0.002595,0.013128,0.016373,0.010434,0.008203,0.002858,...,0.030116,0.001026,0.000938,0.001808,0.000691,0.001015,0.001118,0.000548,0.000327,0.013761
4,audio_files/0382N0Y.wav,0.002854,0.002721,0.006294,0.003672,0.005224,0.002702,0.00355,0.002318,0.005443,...,0.003256,0.003643,0.002264,0.004063,0.002669,0.004179,0.004302,0.004263,0.003964,0.003377


In [58]:
# Write the submission without the pandas row index.
new_sub.to_csv(path_or_buf='test_sub.csv', index=False)