In [0]:
!pip install  torchaudio

Collecting torchaudio
[?25l  Downloading https://files.pythonhosted.org/packages/9c/7d/8e01e21175dd2c9bb1b7e014e0c56cdd02618e2db5bebb4f52f6fdf253cb/torchaudio-0.5.0-cp36-cp36m-manylinux1_x86_64.whl (3.2MB)
[K     |████████████████████████████████| 3.2MB 4.2MB/s 
Installing collected packages: torchaudio
Successfully installed torchaudio-0.5.0


In [0]:
!pip install efficientnet_pytorch

Collecting efficientnet_pytorch
  Downloading https://files.pythonhosted.org/packages/b8/cb/0309a6e3d404862ae4bc017f89645cf150ac94c14c88ef81d215c8e52925/efficientnet_pytorch-0.6.3.tar.gz
Building wheels for collected packages: efficientnet-pytorch
  Building wheel for efficientnet-pytorch (setup.py) ... [?25l[?25hdone
  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.6.3-cp36-none-any.whl size=12422 sha256=df3e2cf37592d829e05e2b6f829289d6df7ab336edd35946b27621137dd0d6aa
  Stored in directory: /root/.cache/pip/wheels/42/1e/a9/2a578ba9ad04e776e80bf0f70d8a7f4c29ec0718b92d8f6ccd
Successfully built efficientnet-pytorch
Installing collected packages: efficientnet-pytorch
Successfully installed efficientnet-pytorch-0.6.3


In [0]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchaudio
import torchvision
from torchaudio import transforms
from efficientnet_pytorch import EfficientNet
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from tqdm import tqdm
import numpy as np
from google.colab import drive
import random
import torchvision.models as models

In [0]:
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
# path to your train/test/meta folders
DATA_PATH = '/content/gdrive/My Drive/Data_Kaggle/'

# names of valuable files/folders
train_meta_fname = 'train.csv'
test_meta_fname = 'sample_submission.csv'
train_data_folder = 'audio_train/train'
test_data_folder = 'audio_test/test'

In [0]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.backends.cudnn.deterministic = True

In [0]:
df_train = pd.read_csv(os.path.join(DATA_PATH, train_meta_fname))
df_test = pd.read_csv(os.path.join(DATA_PATH, test_meta_fname))
df_train.head(2)

Unnamed: 0,fname,label
0,8bcbcc394ba64fe85ed4.wav,Finger_snapping
1,00d77b917e241afa06f1.wav,Squeak


In [0]:
n_classes = df_train.label.nunique()
print(n_classes)
classes_dict = {cl:i for i,cl in enumerate(df_train.label.unique())}
df_train['label_encoded'] = df_train.label.map(classes_dict)
df_train.head()

41


Unnamed: 0,fname,label,label_encoded
0,8bcbcc394ba64fe85ed4.wav,Finger_snapping,0
1,00d77b917e241afa06f1.wav,Squeak,1
2,17bb93b73b8e79234cb3.wav,Electric_piano,2
3,7d5c7a40a936136da55e.wav,Harmonica,3
4,17e0ee7565a33d6c2326.wav,Snare_drum,4


In [0]:
# https://github.com/lukemelas/EfficientNet-PyTorch
class BaseLineModel(nn.Module):
    
    def __init__(self, sample_rate=16000, n_classes=41):
        super().__init__()
        self.ms = torchaudio.transforms.MelSpectrogram(sample_rate)
#         self.bn1 = nn.BatchNorm2d(1)
        
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=3, padding=1)
        self.cnn3 = nn.Conv2d(in_channels=10, out_channels=3, kernel_size=3, padding=1)
        
        self.features = EfficientNet.from_pretrained('efficientnet-b4') 
        # use it as features
#         for param in self.features.parameters():
#             param.requires_grad = False
            
        self.lin1 = nn.Linear(1000, 333)
        
        self.lin2 = nn.Linear(333, 111)
                
        self.lin3 = nn.Linear(111, n_classes)
        
    def forward(self, x):
        x = self.ms(x)
#         x = self.bn1(x)
                
        x = F.relu(self.cnn1(x))
        x = F.relu(self.cnn3(x))
        
        x = self.features(x)

        x = x.view(x.shape[0], -1)
        x = F.relu(x)
        
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = self.lin3(x)
        return x
    
    def inference(self, x):
        x = self.forward(x)
        x = F.softmax(x)
        return x

In [0]:
def sample_or_pad(waveform, wav_len=48000):
    m, n = waveform.shape
    if n < wav_len:
        padded_wav = torch.zeros(1, wav_len)
        padded_wav[:, :n] = waveform
        return padded_wav
    elif n > wav_len:
        offset = np.random.randint(0, n - wav_len)
        sampled_wav = waveform[:, offset:offset+wav_len]
        return sampled_wav
    else:
        return waveform
        
class EventDetectionDataset(Dataset):
    def __init__(self, data_path, x, y=None):
        self.x = x
        self.y = y
        self.data_path = data_path
    
    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        path2wav = os.path.join(self.data_path, self.x[idx])
        waveform, sample_rate = torchaudio.load(path2wav, normalization=True)
        waveform = sample_or_pad(waveform)
        if self.y is not None:
            return waveform, self.y[idx]
        return waveform

In [0]:
X_train, X_val, y_train, y_val = train_test_split(df_train.fname.values, df_train.label_encoded.values, 
                                                  test_size=0.2, random_state=42)
train_loader = DataLoader(
                        EventDetectionDataset(os.path.join(DATA_PATH, train_data_folder), X_train, y_train),
                        batch_size=41
                )
val_loader = DataLoader(
                        EventDetectionDataset(os.path.join(DATA_PATH, train_data_folder), X_val, y_val),
                        batch_size=41
                )
test_loader = DataLoader(
                        EventDetectionDataset(os.path.join(DATA_PATH, test_data_folder), df_test.fname.values, None),
                        batch_size=41, shuffle=False
                )

In [0]:
def eval_model(model, eval_dataset):
    model.eval()
    forecast, true_labs = [], []
    with torch.no_grad():
        for wavs, labs in tqdm(eval_dataset):
            wavs, labs = wavs.cuda(), labs.detach().numpy()
            true_labs.append(labs)
            outputs = model.inference(wavs)
            
            outputs = outputs.detach().cpu().numpy().argmax(axis=1)
            forecast.append(outputs)
    forecast = [x for sublist in forecast for x in sublist]
    true_labs = [x for sublist in true_labs for x in sublist]
    return f1_score(forecast, true_labs, average='macro')

In [0]:
criterion = nn.CrossEntropyLoss()
model = BaseLineModel()
model = model.cuda()
lr = 1e-3

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to /root/.cache/torch/checkpoints/efficientnet-b4-6ed6700e.pth


HBox(children=(IntProgress(value=0, max=77999237), HTML(value='')))


Loaded pretrained weights for efficientnet-b4


In [0]:
n_epoch = 100
best_f1 = 0
for epoch in range(n_epoch):
    model.train()
    for wavs, labs in tqdm(train_loader):
        optimizer.zero_grad()
        wavs, labs = wavs.cuda(), labs.cuda()
        outputs = model(wavs)
        loss = criterion(outputs, labs)
        loss.backward()
        optimizer.step()
#     if epoch % 10 == 0:
    f1 = eval_model(model, val_loader)
    f1_train = eval_model(model, train_loader)
    print(f'epoch: {epoch}, f1_test: {f1}, f1_train: {f1_train}')
    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), '/content/gdrive/My Drive/Data_Kaggle/baseline_29.pt')
        
    lr = lr * 0.95
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

100%|██████████| 111/111 [1:02:21<00:00, 33.71s/it]
100%|██████████| 28/28 [15:46<00:00, 33.81s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 0, f1_test: 0.1555526095544899, f1_train: 0.16822269445258298


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]


epoch: 1, f1_test: 0.2989444541753176, f1_train: 0.36431531486563246


100%|██████████| 111/111 [03:15<00:00,  1.77s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]


epoch: 2, f1_test: 0.4137619768374662, f1_train: 0.44523406742761806


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 3, f1_test: 0.440418012491293, f1_train: 0.5071463892002593


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 4, f1_test: 0.5436288719671727, f1_train: 0.632857142939209


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 5, f1_test: 0.5366584102794376, f1_train: 0.6252480861105303


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 6, f1_test: 0.5211035690790522, f1_train: 0.6246032805929391


100%|██████████| 111/111 [03:15<00:00,  1.77s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 7, f1_test: 0.5575642976159477, f1_train: 0.6802689595012301


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 8, f1_test: 0.5180308994061359, f1_train: 0.6678776504846891


100%|██████████| 111/111 [03:15<00:00,  1.77s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]


epoch: 9, f1_test: 0.5976878812609907, f1_train: 0.7649121012260492


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 10, f1_test: 0.6024945277019642, f1_train: 0.7625092962577597


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]


epoch: 11, f1_test: 0.6165805507101172, f1_train: 0.7880980655620298


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 12, f1_test: 0.6388695157914288, f1_train: 0.8257091379104404


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 13, f1_test: 0.6661734103715944, f1_train: 0.8678809244869338


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 14, f1_test: 0.6587353704494493, f1_train: 0.8632358837991981


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 15, f1_test: 0.6126580025687158, f1_train: 0.8194426648461186


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 16, f1_test: 0.6542532524334046, f1_train: 0.8705135921544243


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 17, f1_test: 0.6783251380690789, f1_train: 0.9045037838795499


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]


epoch: 18, f1_test: 0.6894970919012824, f1_train: 0.9092738492334431


100%|██████████| 111/111 [03:16<00:00,  1.77s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 19, f1_test: 0.6781182365989147, f1_train: 0.8918592916365795


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 20, f1_test: 0.6747537463497492, f1_train: 0.8928699279570487


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 21, f1_test: 0.6600537925727665, f1_train: 0.9034251383379976


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 22, f1_test: 0.6721543243317895, f1_train: 0.9200174869742576


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 23, f1_test: 0.6954243972149051, f1_train: 0.9212674191593491


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 24, f1_test: 0.7015820963448423, f1_train: 0.9381011574371843


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 25, f1_test: 0.687599465274299, f1_train: 0.9328721208457076


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 26, f1_test: 0.6960724033198109, f1_train: 0.9419583775814684


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 27, f1_test: 0.7159606482702882, f1_train: 0.944656732122106


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 28, f1_test: 0.7177466573463358, f1_train: 0.9535418658036054


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 29, f1_test: 0.720566762334659, f1_train: 0.9512780362204571


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 30, f1_test: 0.7158127925592408, f1_train: 0.9463024508946125


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 31, f1_test: 0.7000926290737319, f1_train: 0.9545148873738886


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 32, f1_test: 0.7188674170211725, f1_train: 0.9618464336461798


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 33, f1_test: 0.691549785598421, f1_train: 0.9432145353329887


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 34, f1_test: 0.7328895172468209, f1_train: 0.9548617161778532


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 35, f1_test: 0.6885097600503541, f1_train: 0.948972887425306


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 36, f1_test: 0.7015409897154261, f1_train: 0.9474869886125625


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 37, f1_test: 0.6915540683902466, f1_train: 0.9598719190464736


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 38, f1_test: 0.7089971951602378, f1_train: 0.9506161918078448


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 39, f1_test: 0.7097793738180576, f1_train: 0.9693988544524653


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]


epoch: 40, f1_test: 0.7390800706358124, f1_train: 0.9656754992990411


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 41, f1_test: 0.7090630088574128, f1_train: 0.9673535675377095


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 42, f1_test: 0.7039601837369118, f1_train: 0.9629438768028263


100%|██████████| 111/111 [03:15<00:00,  1.77s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 43, f1_test: 0.7174879033492867, f1_train: 0.9634625449425204


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 44, f1_test: 0.7120385029280111, f1_train: 0.966851822973846


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 45, f1_test: 0.7131018990159076, f1_train: 0.9683664933391827


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 46, f1_test: 0.7166889752849956, f1_train: 0.9689916638760573


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:09<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 47, f1_test: 0.7218022609715281, f1_train: 0.9696611367919846


100%|██████████| 111/111 [03:15<00:00,  1.76s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 48, f1_test: 0.7108288710746691, f1_train: 0.9635349154274695


100%|██████████| 111/111 [03:15<00:00,  1.77s/it]
100%|██████████| 28/28 [00:32<00:00,  1.16s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 49, f1_test: 0.7212155442778252, f1_train: 0.970549228105814


100%|██████████| 111/111 [03:15<00:00,  1.77s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 50, f1_test: 0.7205182148332383, f1_train: 0.9747282164168938


100%|██████████| 111/111 [03:15<00:00,  1.77s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 51, f1_test: 0.7245575626399291, f1_train: 0.9728820738573742


100%|██████████| 111/111 [03:16<00:00,  1.77s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 52, f1_test: 0.7051787768493486, f1_train: 0.9717958579281138


100%|██████████| 111/111 [03:16<00:00,  1.77s/it]
100%|██████████| 28/28 [00:32<00:00,  1.17s/it]
100%|██████████| 111/111 [02:10<00:00,  1.17s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 53, f1_test: 0.7075978639682597, f1_train: 0.9721627675259596


 11%|█         | 12/111 [00:20<02:42,  1.64s/it]

In [0]:
# make a model
model_name = '/content/gdrive/My Drive/Data_Kaggle/baseline_29.pt'
model = BaseLineModel()
model.load_state_dict(torch.load(os.path.join('..', model_name)))
model.eval()
z=[]
forecast = []
with torch.no_grad():
    for i_batch, sample_batched in enumerate(test_loader):
        print(i_batch)
        wavs = sample_batched
        #wavs = wavs.cuda()
        outputs = model.inference(wavs)
        z.append(outputs) #записываем вероятности
        outputs = outputs.detach().cpu().numpy().argmax(axis=1)
        forecast.append(outputs)
forecast = [x for sublist in forecast for x in sublist]
decoder = {classes_dict[cl]:cl for cl in classes_dict}
forecast = pd.Series(forecast).map(decoder)
df_test['label'] = forecast
df_test.to_csv('/content/gdrive/My Drive/Data_Kaggle/baseline_29.csv', index=None)

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to /root/.cache/torch/checkpoints/efficientnet-b4-6ed6700e.pth


HBox(children=(IntProgress(value=0, max=77999237), HTML(value='')))


Loaded pretrained weights for efficientnet-b4
0




1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92


In [0]:
probability = []
for i in range(len(z)):
    for j in range(z[i].shape[0]):
      probability.append(z[i][j].numpy())

df_test['label'] = probability
df_test.to_csv('/content/gdrive/My Drive/Data_Kaggle/baseline_29_probability.csv', index=None)