In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
cd "/content/drive/MyDrive/bird2/code"

/content/drive/MyDrive/bird2/code


In [5]:
import numpy as np
import librosa as lb
import soundfile as sf
import pandas as pd
import cv2
from pathlib import Path
import re

import torch
from torch import nn
from  torch.utils.data import Dataset, DataLoader

from tqdm.notebook import tqdm

import time
#from resnest.torch import resnest50

In [6]:
!pip install git+https://github.com/rwightman/pytorch-image-models.git

Collecting git+https://github.com/rwightman/pytorch-image-models.git
  Cloning https://github.com/rwightman/pytorch-image-models.git to /tmp/pip-req-build-_pttr23r
  Running command git clone -q https://github.com/rwightman/pytorch-image-models.git /tmp/pip-req-build-_pttr23r
Building wheels for collected packages: timm
  Building wheel for timm (setup.py) ... [?25l[?25hdone
  Created wheel for timm: filename=timm-0.4.11-cp37-none-any.whl size=372534 sha256=e7e8eafe1a03ee409891595ee05ee05798cc09bce5942f803b4f40aab6fa8a47
  Stored in directory: /tmp/pip-ephem-wheel-cache-js1crim1/wheels/20/b8/27/66bb141495c14daa67474754678277959ca333a352dab313a5
Successfully built timm
Installing collected packages: timm
Successfully installed timm-0.4.11


In [7]:
import timm

# Configs

In [8]:
NUM_CLASSES = 397
SR = 32_000
DURATION = 5
THRESH = 0.25
#THRESH = 0.3


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

TEST_AUDIO_ROOT = Path("../input/birdclef-2021/test_soundscapes")
SAMPLE_SUB_PATH = "../input/birdclef-2021/sample_submission.csv"
TARGET_PATH = None
    
if not len(list(TEST_AUDIO_ROOT.glob("*.ogg"))):
    TEST_AUDIO_ROOT = Path("../input/birdclef-2021/train_soundscapes")
    SAMPLE_SUB_PATH = None
    # SAMPLE_SUB_PATH = "../input/birdclef-2021/sample_submission.csv"
    TARGET_PATH = Path("../input/birdclef-2021/train_soundscape_labels.csv")

DEVICE: cuda


# Data

In [9]:
class MelSpecComputer:
    def __init__(self, sr, n_mels, fmin, fmax, **kwargs):
        self.sr = sr
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax
        kwargs["n_fft"] = kwargs.get("n_fft", self.sr//10)
        kwargs["hop_length"] = kwargs.get("hop_length", self.sr//(10*4))
        self.kwargs = kwargs

    def __call__(self, y):

        melspec = lb.feature.melspectrogram(
            y, sr=self.sr, n_mels=self.n_mels, fmin=self.fmin, fmax=self.fmax, **self.kwargs,
        )

        melspec = lb.power_to_db(melspec).astype(np.float32)
        return melspec

In [10]:
def mono_to_color(X, eps=1e-6, mean=None, std=None):
    mean = mean or X.mean()
    std = std or X.std()
    X = (X - mean) / (std + eps)
    
    _min, _max = X.min(), X.max()

    if (_max - _min) > eps:
        V = np.clip(X, _min, _max)
        V = 255 * (V - _min) / (_max - _min)
        V = V.astype(np.uint8)
    else:
        V = np.zeros_like(X, dtype=np.uint8)

    return V

def crop_or_pad(y, length):
    if len(y) < length:
        y = np.concatenate([y, length - np.zeros(len(y))])
    elif len(y) > length:
        y = y[:length]
    return y

In [11]:
# swin用
#RESIZE = [256, 562]
RESIZE = None

def resize(image, size=None):
    if size is not None:
        #print("1",image.shape)
        image = image.transpose((1, 2, 0))
        #print("2",image.shape)
        image = cv2.resize(image, (size[0], size[1]))
        image = image.transpose((2, 0, 1))
        #print("3",image.shape)

    return image

In [12]:
class BirdCLEFDataset(Dataset):
    def __init__(self, data, sr=SR, n_mels=128, fmin=0, fmax=None, duration=DURATION, step=None, res_type="kaiser_fast", resample=True):
        
        self.data = data
        
        self.sr = sr
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax or self.sr//2

        self.duration = duration
        self.audio_length = self.duration*self.sr
        self.step = step or self.audio_length
        
        self.res_type = res_type
        self.resample = resample

        self.mel_spec_computer = MelSpecComputer(sr=self.sr, n_mels=self.n_mels, fmin=self.fmin,
                                                 fmax=self.fmax)
    def __len__(self):
        return len(self.data)
    
    @staticmethod
    def normalize(image):
        image = image.astype("float32", copy=False) / 255.0
        image = np.stack([image, image, image])
        return image
    
    def audio_to_image(self, audio):
        #print("audio", audio.shape)
        melspec = self.mel_spec_computer(audio) 
        #print("melspec", melspec.shape)
        image = mono_to_color(melspec)
        image = self.normalize(image)
        return image

    def read_file(self, filepath):
        audio, orig_sr = sf.read(filepath, dtype="float32")

        if self.resample and orig_sr != self.sr:
            audio = lb.resample(audio, orig_sr, self.sr, res_type=self.res_type)
          
        audios = []
        for i in range(self.audio_length, len(audio) + self.step, self.step):
            start = max(0, i - self.audio_length)
            end = start + self.audio_length
            audios.append(audio[start:end])
            
        if len(audios[-1]) < self.audio_length:
            audios = audios[:-1]
            
        images = [self.audio_to_image(audio) for audio in audios]
        # swin ------
        images = []
        for audio in audios:
            image = self.audio_to_image(audio)
            image = resize(image, size=RESIZE)
            images.append(image)
        # ------------
        images = np.stack(images)
        
        return images
    
        
    def __getitem__(self, idx):
        return self.read_file(self.data.loc[idx, "filepath"])

In [13]:
data = pd.DataFrame(
     [(path.stem, *path.stem.split("_"), path) for path in Path(TEST_AUDIO_ROOT).glob("*.ogg")],
    columns = ["filename", "id", "site", "date", "filepath"]
)
print(data.shape)
data.head()

(20, 5)


Unnamed: 0,filename,id,site,date,filepath
0,10534_SSW_20170429,10534,SSW,20170429,../input/birdclef-2021/train_soundscapes/10534...
1,11254_COR_20190904,11254,COR,20190904,../input/birdclef-2021/train_soundscapes/11254...
2,20152_SSW_20170805,20152,SSW,20170805,../input/birdclef-2021/train_soundscapes/20152...
3,18003_COR_20190904,18003,COR,20190904,../input/birdclef-2021/train_soundscapes/18003...
4,14473_SSW_20170701,14473,SSW,20170701,../input/birdclef-2021/train_soundscapes/14473...


In [14]:
df_train = pd.read_csv("../input/birdclef-2021/train_metadata.csv")

LABEL_IDS = {label: label_id for label_id,label in enumerate(sorted(df_train["primary_label"].unique()))}
INV_LABEL_IDS = {val: key for key,val in LABEL_IDS.items()}

# Inference

In [15]:
test_data = BirdCLEFDataset(data=data)
len(test_data), test_data[0].shape

(20, (120, 3, 128, 201))

In [16]:
# net = timm.create_model("resnest50d", pretrained=False)
# print(net)

In [17]:
def load_net(checkpoint_path, model_name, num_classes=NUM_CLASSES):
    
    net = timm.create_model(model_name, pretrained=False)


    
    if "efficientnet" in model_name:
        net.classifier = nn.Linear(net.classifier.in_features, num_classes)
    elif "dens" in model_name:
        net.classifier = nn.Linear(net.classifier.in_features, num_classes)
    elif "nfnet" in model_name or "rexnet" in model_name:
        net.head.fc = nn.Linear(net.head.fc.in_features, num_classes)
    else:
        net.fc = nn.Linear(net.fc.in_features, num_classes)
    

    dummy_device = torch.device("cpu")
    d = torch.load(checkpoint_path, map_location=dummy_device)
    for key in list(d.keys()):
        d[key.replace("model.", "")] = d.pop(key)
    net.load_state_dict(d)
    net = net.to(DEVICE)
    net = net.eval()
    return net

In [18]:

checkpoint_paths = [
    # Path("./weights/0523_1200_resnest50d_4s2x40d_sr32000_d7_v1_v1/birdclef_resnest50d_4s2x40d_fold0_epoch_14_f1_val_06718_20210523051253.pth"),
    # Path("./weights/0523_1600_eca_nfnet_l0_sr32000_d7_v1_v1/birdclef_eca_nfnet_l0_fold0_epoch_14_f1_val_06533_20210523083616.pth"),
    # Path("./weights/0523_1100_tf_eff_b5_ns/birdclef_tf_efficientnet_b5_ns_fold1_epoch_12_f1_val_06412_20210522112214.pth"),
    #Path("./weights/0519_1430_tf_eff_b4/birdclef_tf_efficientnet_b4_fold0_epoch_11_f1_val_07519_20210519052502.pth"),
    Path("./weights/0520_1630_resnest50d/birdclef_resnest50d_fold0_epoch_11_f1_val_06515_20210520072735.pth"),
    # Path("./weights/0524_0100_swsl_resnet50_sr32000_d7_v1_v1/birdclef_swsl_resnet50_fold0_epoch_14_f1_val_05959_20210523164922.pth"),
    # Path("./weights/0524_1800_resnext50_32x4d_sr32000_d7_v1_v1/birdclef_resnext50_32x4d_fold3_epoch_11_f1_val_06291_20210524084908.pth"),
    # #Path("./weights/0524_1500_resnest50d_sr32000_d7_v1_v1/birdclef_resnest50d_fold2_epoch_11_f1_val_06542_20210524053705.pth"),
    # Path("./weights/0524_2100_tf_efficientnet_b0_ns_sr32000_d7_v1_v1/birdclef_tf_efficientnet_b0_ns_fold4_epoch_11_f1_val_05319_20210524122743.pth"),
    Path("./weights/0524_2200_densenet121_sr32000_d7_v1_v1/birdclef_densenet121_fold3_epoch_11_f1_val_05970_20210524130955.pth"),
    Path("./weights/0524_2330_densenet201_sr32000_d7_v1_v1/birdclef_densenet201_fold2_epoch_14_f1_val_06470_20210524143220.pth"),
    # #Path("./weights/0529_1930_resnest50d_sr32000_d7_v1_v1/birdclef_resnest50d_fold0_epoch_14_f1_val_07033_20210529101941.pth"),
    # #Path("./weights/0531_1630_resnest50d_sr32000_d7_v1_v1/birdclef_resnest50d_fold1_epoch_11_f1_val_05810_20210531071642.pth"),
    # #Path("./weights/0531_1900_resnest50d_sr32000_d7_v1_v1/birdclef_resnest50d_fold1_epoch_11_f1_val_06369_20210531100013.pth"),
    #Path("./weights/0531_2200_resnest50d_sr32000_d7_v1_v1/birdclef_resnest50d_fold1_epoch_11_f1_val_06930_20210531130530.pth"),
    # #Path("./weights/0601_1400_tf_efficientnet_b4_sr32000_d7_v1_v1/birdclef_tf_efficientnet_b4_fold2_epoch_11_f1_val_06780_20210601050127.pth"),
    # #Path("./weights/0601_1900_efficientnetv2_rw_m_sr32000_d7_v1_v1/birdclef_efficientnetv2_rw_m_fold3_epoch_11_f1_val_06767_20210601101238.pth"),
    Path("./weights/0602_0000_rexnet_200_sr32000_d7_v1_v1/birdclef_rexnet_200_fold4_epoch_11_f1_val_06798_20210601150743.pth"),
    #Path("./weights/0602_0600_rexnet_150_sr32000_d7_v1_v1/birdclef_rexnet_150_fold1_epoch_11_f1_val_06706_20210601171934.pth"),
    
    
]

model_names = [
              # "resnest50d_4s2x40d",
              # "eca_nfnet_l0",
              # "tf_efficientnet_b5_ns",
              #"tf_efficientnet_b4",
              "resnest50d",
              #  "swsl_resnet50",
              #  "resnext50_32x4d",
              #  #"resnest50d",
              #  "tf_efficientnet_b0_ns",
               "densenet121",
               "densenet201",
              #  #"resnest50d",
              #  #"resnest50d",
              #  #"resnest50d",
               # "resnest50d",
              #  #"tf_efficientnet_b4",
 #              "efficientnetv2_rw_m",
                "rexnet_200",
               #"rexnet_150",
]

count = 0
nets = []
for checkpoint_path in checkpoint_paths:
    nets.append(load_net(checkpoint_path.as_posix(), model_name=model_names[count]))
    count += 1

In [19]:
# @torch.no_grad()
# def get_thresh_preds(out, thresh=None):
#     thresh = thresh or THRESH
#     o = (-out).argsort(1)

#     test = out > thresh
#     print("test", test)
#     print("o test", o[test])
#     npreds = (out > thresh).sum(1)
#     preds = []
#     for oo, npred in zip(o, npreds):
#         preds.append(oo[:npred].cpu().numpy().tolist())
        
#     return preds

In [20]:
@torch.no_grad()
def get_thresh_preds(out, thresh=None):
    thresh = thresh or THRESH
    o = (-out).argsort(1)

    test = out > thresh

    npreds = (out > thresh).sum(1)

    preds = []
    for oo, npred in zip(o, npreds):
        preds.append(oo[:npred].cpu().numpy().tolist())
        
        
    return preds

In [21]:
def get_bird_names(preds):
    bird_names = []
    for pred in preds:
        if not pred:
            bird_names.append("nocall")
        else:
            bird_names.append(" ".join([INV_LABEL_IDS[bird_id] for bird_id in pred]))
    #print("bird_names", bird_names)
    return bird_names

In [22]:
def predict(nets, test_data, names=True):
    preds = []
    with torch.no_grad():
        for idx in  tqdm(list(range(len(test_data)))):
            xb = torch.from_numpy(test_data[idx]).to(DEVICE)
            pred = 0.
            for net in nets:
                o = net(xb)
                o = torch.sigmoid(o)

                pred += o

            pred /= len(nets)
            #print("pred.shape", pred.shape)
            if names:
                pred = get_bird_names(get_thresh_preds(pred))

            preds.append(pred)
    return preds

In [23]:
pred_probas = predict(nets, test_data, names=False)
print(len(pred_probas))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


20


In [24]:
len(pred_probas[0])

120

In [25]:
def get_bird_names_org(pred_probas):
  all_bird_names = []
  for pred_proba in pred_probas:  # ogg 20
    #print("ogg")
    ogg_bird_names = []
    for pb in pred_proba: # 区切りごと 120
      #print("区切り")
      bird_names = []
      yp = np.where(pb.cpu().numpy() > THRESH)  # 閾値を超えたNo (array([72]),)
      for y in yp:  # [72]
        if len(y) == 0:
          #print("nocall")
          bird_names.append("nocall")
        else:
          for m in y:  # 閾値を超えた数だけループ
            bird_names.append(INV_LABEL_IDS[m])
        ogg_bird_names.append(bird_names)
        #print(bird_names)
    all_bird_names.append(ogg_bird_names)
  return all_bird_names
    
    

In [26]:
THRESH = 0.3

In [27]:
preds = get_bird_names_org(pred_probas)

In [28]:

# preds = [get_bird_names(get_thresh_preds(pred, thresh=THRESH)) for pred in pred_probas]
# preds[:2]

In [29]:
# def preds_as_df(data, preds):
#     sub = {
#         "row_id": [],
#         "birds": [],
#     }
    
#     for row, pred in zip(data.itertuples(False), preds):
#         row_id = [f"{row.id}_{row.site}_{5*i}" for i in range(1, len(pred)+1)]
#         sub["birds"] += pred
#         sub["row_id"] += row_id
        
#     sub = pd.DataFrame(sub)
    
#     if SAMPLE_SUB_PATH:
#         sample_sub = pd.read_csv(SAMPLE_SUB_PATH, usecols=["row_id"])
#         sub = sample_sub.merge(sub, on="row_id", how="left")
#         sub["birds"] = sub["birds"].fillna("nocall")
#     return sub

In [30]:
print(len(preds))
print(len(preds[0]))
print(len(preds[0][0]))


20
120
1


In [31]:
# org
def preds_as_df(data, preds):
    sub = {
        "row_id": [],
        "birds": [],
    }
    count = 0
    
    for row in data.itertuples(False):  # ogg
        row_id = [f"{row.id}_{row.site}_{5*i}" for i in range(1, len(preds[count])+1)]
        sub["row_id"] += row_id
    
    for i in range(len(preds)):  # ogg
        yp = []
        for j in range(len(preds[i])): # 区切り
              y = ""
              #print(preds[i][j])
              for m in preds[i][j]:  # 鳴き声
                y += m
                y += " "
              
              yp.append(y)
    
        sub["birds"] += yp
    
    sub = pd.DataFrame(sub)
    
    if SAMPLE_SUB_PATH:
        sample_sub = pd.read_csv(SAMPLE_SUB_PATH, usecols=["row_id"])
        sub = sample_sub.merge(sub, on="row_id", how="left")
        sub["birds"] = sub["birds"].fillna("nocall")
    return sub

In [32]:

sub = preds_as_df(data, preds)
print(sub.shape)
sub

(2400, 2)


Unnamed: 0,row_id,birds
0,10534_SSW_5,nocall
1,10534_SSW_10,nocall
2,10534_SSW_15,nocall
3,10534_SSW_20,nocall
4,10534_SSW_25,nocall
...,...,...
2395,7954_COR_580,nocall
2396,7954_COR_585,nocall
2397,7954_COR_590,nocall
2398,7954_COR_595,nocall


In [33]:
#sub.to_csv("submission.csv", index=False)

# Small validation

In [34]:
def get_metrics(s_true, s_pred):
    s_true = set(s_true.split())
    s_pred = set(s_pred.split())
    n, n_true, n_pred = len(s_true.intersection(s_pred)), len(s_true), len(s_pred)
    
    prec = n/n_pred
    rec = n/n_true
    f1 = 2*prec*rec/(prec + rec) if prec + rec else 0
    
    return {"f1": f1, "prec": prec, "rec": rec, "n_true": n_true, "n_pred": n_pred, "n": n}

In [35]:
if TARGET_PATH:
    sub_target = pd.read_csv(TARGET_PATH)
    
    sub_target = sub_target.merge(sub, how="left", on="row_id")
    
    
    print(sub_target["birds_x"].notnull().sum(), sub_target["birds_x"].notnull().sum())
    assert sub_target["birds_x"].notnull().all()
    assert sub_target["birds_y"].notnull().all()
    
    df_metrics = pd.DataFrame([get_metrics(s_true, s_pred) for s_true, s_pred in zip(sub_target.birds_x, sub_target.birds_y)])
    
    print(df_metrics.mean())

2400 2400
f1        0.670444
prec      0.674167
rec       0.669285
n_true    1.130000
n_pred    1.002500
n         0.675000
dtype: float64


In [36]:
sub_target[sub_target.birds_y != "nocall"]

Unnamed: 0,row_id,site,audio_id,seconds,birds_x,birds_y
0,7019_COR_5,COR,7019,5,nocall,nocall
1,7019_COR_10,COR,7019,10,nocall,nocall
2,7019_COR_15,COR,7019,15,nocall,nocall
3,7019_COR_20,COR,7019,20,nocall,nocall
4,7019_COR_25,COR,7019,25,nocall,nocall
...,...,...,...,...,...,...
2395,54955_SSW_580,SSW,54955,580,nocall,nocall
2396,54955_SSW_585,SSW,54955,585,grycat,nocall
2397,54955_SSW_590,SSW,54955,590,grycat,nocall
2398,54955_SSW_595,SSW,54955,595,nocall,nocall


In [37]:
sub_target[sub_target.birds_x != "nocall"]

Unnamed: 0,row_id,site,audio_id,seconds,birds_x,birds_y
240,11254_COR_5,COR,11254,5,rubwre1,wbwwre1
242,11254_COR_15,COR,11254,15,rubwre1,wbwwre1
244,11254_COR_25,COR,11254,25,rubwre1,rubwre1 wbwwre1
267,11254_COR_140,COR,11254,140,obnthr1,nocall
268,11254_COR_145,COR,11254,145,obnthr1,nocall
...,...,...,...,...,...,...
2391,54955_SSW_560,SSW,54955,560,grycat,nocall
2393,54955_SSW_570,SSW,54955,570,grycat,nocall
2394,54955_SSW_575,SSW,54955,575,chswar,nocall
2396,54955_SSW_585,SSW,54955,585,grycat,nocall
