In [23]:
import os
import sys
import numpy as np
import pandas as pd
import librosa
import tqdm
from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.utils.data.sampler import SubsetRandomSampler

import matplotlib.pyplot as plt

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
sys.path.append('..')

In [4]:
test_df = pd.read_csv('../data/sample_submission.csv')

In [5]:
test_df.head()

Unnamed: 0,fname,label
0,00063640.wav,Laughter Hi-Hat Flute
1,0013a1db.wav,Laughter Hi-Hat Flute
2,002bb878.wav,Laughter Hi-Hat Flute
3,002d392d.wav,Laughter Hi-Hat Flute
4,00326aa9.wav,Laughter Hi-Hat Flute


In [6]:
labels = np.load('../labels.npy')

In [7]:
labels

array(['Acoustic_guitar', 'Applause', 'Bark', 'Bass_drum',
       'Burping_or_eructation', 'Bus', 'Cello', 'Chime', 'Clarinet',
       'Computer_keyboard', 'Cough', 'Cowbell', 'Double_bass',
       'Drawer_open_or_close', 'Electric_piano', 'Fart',
       'Finger_snapping', 'Fireworks', 'Flute', 'Glockenspiel', 'Gong',
       'Gunshot_or_gunfire', 'Harmonica', 'Hi-hat', 'Keys_jangling',
       'Knock', 'Laughter', 'Meow', 'Microwave_oven', 'Oboe', 'Saxophone',
       'Scissors', 'Shatter', 'Snare_drum', 'Squeak', 'Tambourine',
       'Tearing', 'Telephone', 'Trumpet', 'Violin_or_fiddle', 'Writing'],
      dtype=object)

In [22]:
cuda = torch.cuda.is_available()
if cuda:
    print('cuda available!')

device = torch.device('cuda' if cuda else 'cpu')
num_workers = 0

In [8]:
from model import LeNet
from dataset import AudioDataset

In [9]:
model = LeNet(len(labels))

In [10]:
model.load_state_dict(torch.load('../logs/log.1/epoch075-1.799-0.651.pth',
                                 map_location=lambda storage, loc: storage))

In [47]:
# テストなので最初の400テストデータのみ評価
test_dataset = AudioDataset(test_df[:400], '../data/audio_test', test=True)

In [48]:
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    512,
    shuffle=False
)

In [39]:
data = iter(test_loader).next()

In [40]:
print(data.size())

torch.Size([128, 1, 64, 401])


In [41]:
result = model(data)

In [42]:
result.size()

torch.Size([128, 41])

In [43]:
model.eval()

LeNet(
  (block1): Sequential(
    (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block2): Sequential(
    (0): Conv2d(20, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=25220, out_features=1000, bias=True)
  (fc2): Linear(in_features=1000, out_features=41, bias=True)
)

In [44]:
len(test_loader.dataset)

400

In [50]:
predictions = []
with torch.no_grad():
    for batch_idx, data in enumerate(test_loader):
        data = data.to(device)
        output = model(data)
        predictions.append(output)

# 41クラスの予測結果を縦方向に結合
predictions = torch.cat(predictions, dim=0)
print(predictions.size())
np.save('predictions.npy', predictions.numpy())

torch.Size([400, 41])


In [54]:
predictions = torch.from_numpy(np.load('predictions.npy'))

In [55]:
predictions[0]

tensor([ -6.0687,  -8.1937,  -8.3637,  -6.1772,  -6.2904,  -8.0602,
         -7.5709,  -7.7106, -10.4121,  -3.9519,  -2.6194,  -6.3658,
         -9.1989,  -3.1801,  -9.4000,  -5.8844,  -1.2346,  -3.2125,
         -6.5601, -10.4873,  -6.0737,  -4.1005,  -9.8429,  -0.9729,
         -3.2759,  -2.9223,  -6.4541,  -6.9585,  -2.1553, -11.9581,
         -9.1059,  -1.9610,  -0.7883,  -3.8649,  -4.7348,  -3.6549,
         -0.2592,  -6.6816,  -1.3779,  -8.0430,  -2.9970])

In [62]:
value, index = predictions[0].topk(3)

In [63]:
value

tensor([-0.2592, -0.7883, -0.9729])

In [64]:
predictions[0].max()

tensor(-0.2592)

In [66]:
index

tensor([ 36,  32,  23])

In [68]:
labels[index]

array(['Tearing', 'Shatter', 'Hi-hat'], dtype=object)

In [69]:
labels[36]

'Tearing'

In [70]:
labels[32]

'Shatter'

In [71]:
labels[23]

'Hi-hat'

In [72]:
_, indices = predictions.topk(3)

In [73]:
indices.size()

torch.Size([400, 3])

In [91]:
predicted_labels = labels[indices]

In [92]:
predicted_labels.shape

(400, 3)

In [93]:
predicted_labels = [' '.join(lst) for lst in predicted_labels]
predicted_labels

['Tearing Shatter Hi-hat',
 'Meow Flute Oboe',
 'Fireworks Bass_drum Gunshot_or_gunfire',
 'Bass_drum Knock Double_bass',
 'Oboe Meow Bark',
 'Bass_drum Knock Gunshot_or_gunfire',
 'Squeak Telephone Violin_or_fiddle',
 'Gong Acoustic_guitar Electric_piano',
 'Clarinet Flute Telephone',
 'Saxophone Cello Violin_or_fiddle',
 'Cello Flute Saxophone',
 'Clarinet Saxophone Flute',
 'Chime Keys_jangling Glockenspiel',
 'Cello Saxophone Double_bass',
 'Violin_or_fiddle Trumpet Clarinet',
 'Acoustic_guitar Gong Double_bass',
 'Flute Gong Electric_piano',
 'Bass_drum Flute Gong',
 'Clarinet Saxophone Flute',
 'Bass_drum Snare_drum Electric_piano',
 'Clarinet Flute Saxophone',
 'Clarinet Saxophone Flute',
 'Saxophone Violin_or_fiddle Clarinet',
 'Hi-hat Shatter Gong',
 'Cello Acoustic_guitar Flute',
 'Keys_jangling Shatter Microwave_oven',
 'Trumpet Knock Fireworks',
 'Shatter Tambourine Keys_jangling',
 'Clarinet Flute Oboe',
 'Hi-hat Tearing Shatter',
 'Cough Acoustic_guitar Laughter',
 'Cello

In [80]:
test_df['label'][:400]

0      Laughter Hi-Hat Flute
1      Laughter Hi-Hat Flute
2      Laughter Hi-Hat Flute
3      Laughter Hi-Hat Flute
4      Laughter Hi-Hat Flute
5      Laughter Hi-Hat Flute
6      Laughter Hi-Hat Flute
7      Laughter Hi-Hat Flute
8      Laughter Hi-Hat Flute
9      Laughter Hi-Hat Flute
10     Laughter Hi-Hat Flute
11     Laughter Hi-Hat Flute
12     Laughter Hi-Hat Flute
13     Laughter Hi-Hat Flute
14     Laughter Hi-Hat Flute
15     Laughter Hi-Hat Flute
16     Laughter Hi-Hat Flute
17     Laughter Hi-Hat Flute
18     Laughter Hi-Hat Flute
19     Laughter Hi-Hat Flute
20     Laughter Hi-Hat Flute
21     Laughter Hi-Hat Flute
22     Laughter Hi-Hat Flute
23     Laughter Hi-Hat Flute
24     Laughter Hi-Hat Flute
25     Laughter Hi-Hat Flute
26     Laughter Hi-Hat Flute
27     Laughter Hi-Hat Flute
28     Laughter Hi-Hat Flute
29     Laughter Hi-Hat Flute
               ...          
370    Laughter Hi-Hat Flute
371    Laughter Hi-Hat Flute
372    Laughter Hi-Hat Flute
373    Laughte

In [94]:
test_df['label'][:400] = predicted_labels

In [95]:
test_df['label'][:400]

0                           Tearing Shatter Hi-hat
1                                  Meow Flute Oboe
2           Fireworks Bass_drum Gunshot_or_gunfire
3                      Bass_drum Knock Double_bass
4                                   Oboe Meow Bark
5               Bass_drum Knock Gunshot_or_gunfire
6                Squeak Telephone Violin_or_fiddle
7              Gong Acoustic_guitar Electric_piano
8                         Clarinet Flute Telephone
9                 Saxophone Cello Violin_or_fiddle
10                           Cello Flute Saxophone
11                        Clarinet Saxophone Flute
12                Chime Keys_jangling Glockenspiel
13                     Cello Saxophone Double_bass
14               Violin_or_fiddle Trumpet Clarinet
15                Acoustic_guitar Gong Double_bass
16                       Flute Gong Electric_piano
17                            Bass_drum Flute Gong
18                        Clarinet Saxophone Flute
19             Bass_drum Snare_

In [None]:
q