# **Cornell BirdCall audio Recognition using JAX/FLAX**

### **Import necessary modules and setup Jupyter notebook environment**

In [2]:
import os, sys, time, math, gc, functools
import random, librosa, cv2, requests
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from typing import Any
from tqdm import tqdm, tqdm_notebook
from tqdm.auto import tqdm
tqdm.pandas()
%matplotlib inline

import tensorflow as tf
from tensorflow.keras.utils import to_categorical

import torch, torchvision
import torch.nn as nn
import torch.nn.init as init
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset, random_split, DataLoader

import jax, optax, jax.nn
import flax.linen as nn
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
from jax.config import config


# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
def fxn():
    warnings.warn("deprecated", DeprecationWarning)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fxn()
warnings.filterwarnings("ignore")

seed = 1234
np.random.seed(seed)

# to suppress warnings caused by cuda version
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

### **TPU Detection & Configuration**

In [3]:
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    tpu = None
    
if tpu:
    if 'TPU_DRIVER_MODE' not in globals():
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
        
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU: ', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected!')
    
print(jax.devices())

No TPU detected!
[CpuDevice(id=0)]


### **Load & Preprocess dataset**

In [4]:
# seeding function for reproducibility
def seed_everything(seed):
    random.seed(seed)
    os.environ["PYTHONAHSHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [5]:
ROOT = "../input/birdsong-recognition"
os.listdir(ROOT)

['example_test_audio_metadata.csv',
 'sample_submission.csv',
 'example_test_audio',
 'train_audio',
 'train.csv',
 'test.csv',
 'example_test_audio_summary.csv']

In [6]:
df = pd.read_csv(os.path.join(ROOT, 'train.csv'))[['ebird_code', 'filename', 'duration']]
df['path'] = ROOT + 'train_audio/' + df['ebird_code'] + "/" + df['filename']
df.head()

Unnamed: 0,ebird_code,filename,duration,path
0,aldfly,XC134874.mp3,25,../input/birdsong-recognitiontrain_audio/aldfl...
1,aldfly,XC135454.mp3,36,../input/birdsong-recognitiontrain_audio/aldfl...
2,aldfly,XC135455.mp3,39,../input/birdsong-recognitiontrain_audio/aldfl...
3,aldfly,XC135456.mp3,33,../input/birdsong-recognitiontrain_audio/aldfl...
4,aldfly,XC135457.mp3,36,../input/birdsong-recognitiontrain_audio/aldfl...


In [7]:
SEED = 42
FRAC = 0.2    #
SR = 44100
MAXLEN = 60
N_MELS = 128

seed_everything(seed)
device = torch.device('cpu')

classes = set(random.sample(df['ebird_code'].unique().tolist(), 15))
print(classes)

{'comnig', 'whtspa', 'wesblu', 'baleag', 'brthum', 'pinwar', 'herthr', 'btywar', 'pilwoo', 'banswa', 'bulori', 'stejay', 'amecro', 'canwar', 'amerob'}


In [8]:
df = df[df.ebird_code.apply(lambda x: x in classes)].reset_index(drop=True)
keys = set(df.ebird_code)
values = np.arange(0, len(keys))
code_dict = dict(zip(sorted(keys), values))
df['label'] = df['ebird_code'].apply(lambda x: code_dict[x])
df.head()

Unnamed: 0,ebird_code,filename,duration,path,label
0,amecro,XC109768.mp3,16,../input/birdsong-recognitiontrain_audio/amecr...,0
1,amecro,XC112598.mp3,126,../input/birdsong-recognitiontrain_audio/amecr...,0
2,amecro,XC112829.mp3,135,../input/birdsong-recognitiontrain_audio/amecr...,0
3,amecro,XC114550.mp3,17,../input/birdsong-recognitiontrain_audio/amecr...,0
4,amecro,XC114551.mp3,11,../input/birdsong-recognitiontrain_audio/amecr...,0


In [9]:
class BirdSoundDataset(Dataset):
    """Bird Sound dataset."""

    def __init__(self, df, transform = None):
        """
        Args:
            df (pd.DataFrame): must have ['path', 'label'] columns
        """
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)
    
    
    def loadMP3(self, path, duration):
        """
        Args:
            path: path of the audio file 
        Returns:
            mels: Melspectrogram of the given audio file 
        """
        try:
            duration=5
            samples = SR* duration
            audio, _ = librosa.load(path, sr=SR)
            
            if 0 < len(audio):
                audio, _ = librosa.effects.trim(audio)
            if len(audio) > samples: # long enough
                audio = audio[0:0+samples]
            else: # pad blank
                padding = samples - len(audio)
                offset = padding // 2
                y = np.pad(audio, (offset, samples - len(audio) - offset), 'constant')

            mels = librosa.feature.melspectrogram(y=audio, sr=SR,n_mels=N_MELS, hop_length = 347,n_fft = N_MELS *20,fmin = 20, fmax = SR//2)
            mels = librosa.power_to_db(mels).astype(np.float32)
            mels = mels.transpose()
            eps = 0.001
            if np.std(mels) != 0:
                mels = (mels - np.mean(mels)) / np.std(mels)
            else:
                mels = (mels - np.mean(mels)) / eps
            return mels
            
        except Exception as e:
            print("Error encountered while parsing file: ", path, e)
            mels = np.zeros((N_MELS, MAXLEN*SR//347), dtype=np.float32)
            return mels
            

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        path = self.df['path'].iloc[idx]
    
        duration=5
        if os.path.exists("./"+path.split('/')[-1]+".npy"):
            mels = np.load("./"+path.split('/')[-1]+".npy")
        else:
            
            mels = self.loadMP3(path, duration)
            np.save("./"+path.split('/')[-1]+".npy", mels)
        label  = self.df['label'].iloc[idx]
        mels = np.resize(mels,(636,128,1))
        return mels, label

In [10]:
# Dividing the dataset into train and validation sets
df = df.sample(frac=1).reset_index(drop=True)
train_len = int(len(df) * (1-FRAC))
train_df = df.iloc[:train_len]
valid_df = df.iloc[train_len:]
train_df.shape, valid_df.shape

((1080, 5), (271, 5))

In [11]:
# prepare data loaders 
#NUM_TPUS = jax.device_count()
BATCH_SIZE = 32

train_loader = torch.utils.data.DataLoader(BirdSoundDataset(train_df),
                                           batch_size=BATCH_SIZE, 
                                           num_workers=0, 
                                           shuffle=True, 
                                           drop_last = True)

valid_loader = torch.utils.data.DataLoader(BirdSoundDataset(valid_df), 
                                           batch_size=BATCH_SIZE, 
                                           num_workers=0, 
                                           shuffle=True, 
                                           drop_last = True)

len(train_loader), len(valid_loader)

(33, 8)

In [12]:
(image_batch, label_batch) = next(iter(train_loader))
print(image_batch.shape)
print(label_batch.shape)

torch.Size([32, 636, 128, 1])
torch.Size([32])


In [None]:
NUM_TPUS = jax.device_count()

def copy_dataset_to_devices(dataset, devices, num_reps=1):
    sharded_images = []
    sharded_labels = []
    for _ in range(num_reps):
        for image_batch, label_batch in tqdm(dataset, ncols=100):
            image_batch = image_batch.detach().cpu().numpy()
            image_batches = np.split(image_batch, NUM_TPUS, axis = 0)
            sharded_device_images = jax.device_put_sharded(image_batches, devices)
            sharded_images.append(sharded_device_images)

            label_batch = label_batch.detach().cpu().numpy()
            label_batches = np.split(label_batch, NUM_TPUS, axis = 0)
            sharded_device_labels = jax.device_put_sharded(label_batches, devices)
            sharded_labels.append(sharded_device_labels)

    return sharded_images, sharded_labels

devices = jax.local_devices()
sharded_training_images, sharded_training_labels = copy_dataset_to_devices(train_loader, devices, num_reps=10)

  0%|                                                                        | 0/33 [00:00<?, ?it/s]

  0%|                                                                        | 0/33 [00:00<?, ?it/s]

  0%|                                                                        | 0/33 [00:00<?, ?it/s]

  0%|                                                                        | 0/33 [00:00<?, ?it/s]

  0%|                                                                        | 0/33 [00:00<?, ?it/s]

  0%|                                                                        | 0/33 [00:00<?, ?it/s]

  0%|                                                                        | 0/33 [00:00<?, ?it/s]

  0%|                                                                        | 0/33 [00:00<?, ?it/s]

  0%|                                                                        | 0/33 [00:00<?, ?it/s]

  0%|                                                                        | 0/33 [00:00<?, ?it/s]

In [None]:
NUM_CLASSES = 15 
class VGG19(nn.Module):
    @nn.compact
    def __call__(self, x, training):
        x = self._stack(x, 64, training)
        x = self._stack(x, 64, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    
        x = self._stack(x, 128, training)
        x = self._stack(x, 128, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = self._stack(x, 256, training)
        x = self._stack(x, 256, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))    

        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))    
    
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))  

        x = x.reshape((x.shape[0], -1))

        x = nn.Dense(features=4096)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.5, deterministic=not training)(x)

        x = nn.Dense(features=4096)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.5, deterministic=not training)(x)
    
        x = nn.Dense(features=NUM_CLASSES)(x)
        return x
  
    @staticmethod
    def _stack(x, features, training, dropout=None):
        x = nn.Conv(features=features, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        return x

In [None]:
def average_metrics(metrics):
    '''
    Takes the list of dictionaries of the form k: v, and returns a dictionary
     of the form k: (average of the v).
    '''
    return {k: np.mean([metric[k] for metric in metrics])
        for k in metrics[0]}

def train(initial_network_state, num_epochs):
    '''
    Training the model from the given state, returns the state along with the training accuracies
    '''
    training_accuracies = []
    state = initial_network_state
    for i in range(num_epochs):
        batch_metrics = []
        for (image_batch, label_batch) in tqdm(zip(sharded_training_images,
                                               sharded_training_labels),
                                           total=len(sharded_training_images),
                                           ncols=100):
            state, metrics = train_batch(state, image_batch, label_batch)
            batch_metrics.append(metrics)
        train_metrics = average_metrics(batch_metrics)
        print(f'Epoch {i+1} done.', flush=True)
        print(f'  Loss: {train_metrics["loss"]:.4f}, '
          + f'accuracy: {train_metrics["accuracy"]:.4f}', flush=True)
        training_accuracies.append(train_metrics["accuracy"])
    return state, training_accuracies

In [None]:
class VGGState(train_state.TrainState):
    rng: Any
    batch_stats: Any
  
    @classmethod
    def create(cls, apply_fn, params, tx, rng, batch_stats):
        opt_state = tx.init(params)
        state = cls(0, apply_fn, params, tx, opt_state, rng, batch_stats)
        return state
  
    @classmethod
    def update_rng(cls, state, rng):
        return VGGState.create(state.apply_fn, state.params, state.tx, rng,
                           state.batch_stats)
  
    @classmethod
    def update_batch_stats(cls, state, batch_stats):
        return VGGState.create(state.apply_fn, state.params, state.tx,
                           state.rng, batch_stats)

In [None]:
def accuracy(logits, labels):
    '''
    Calcualtes the accuracy using the given logits and labels
    '''
    return jnp.mean(jnp.argmax(logits, -1) == labels)

def cross_entropy(logits, labels):
    '''
    Cross Entropy error between the logits and labels
    '''
    one_hot_labels = jax.nn.one_hot(labels, NUM_CLASSES)
    cross_entropy = optax.softmax_cross_entropy(logits, one_hot_labels)
    return jnp.mean(cross_entropy)

def training_loss(image_batch, label_batch, rng, batch_stats, params):
    '''
    Calculates the training loss 
    '''
    logits, batch_stats = VGG19().apply({'params': params,
                                       'batch_stats': batch_stats},
                                      image_batch, 
                                      training=True,
                                      rngs={'dropout': rng},
                                      mutable=['batch_stats'])
    loss = cross_entropy(logits, label_batch)
    return loss, (logits, batch_stats)

In [None]:
@functools.partial(jax.pmap, axis_name='tpu')
def train_batch(state, image_batch, label_batch):
    '''
    Training a single batch and returns loss and the accuracy
    '''
    rng, subrng = jax.random.split(state.rng)
    batch_loss_fn = functools.partial(training_loss, image_batch, label_batch,
                                    subrng, state.batch_stats)
    (batch_loss, (logits, batch_stats)), grads = \
    jax.value_and_grad(batch_loss_fn, has_aux=True)(state.params)
  
    gradsum = jax.lax.psum(grads, axis_name='tpu')

    state = state.apply_gradients(grads=gradsum)
    state = state.update_batch_stats(state, batch_stats['batch_stats'])
    state = state.update_rng(state, rng)

    batch_accuracy = accuracy(logits=logits, labels=label_batch)
    batch_accuracy_sum = jax.lax.pmean(batch_accuracy, axis_name='tpu')
    batch_loss = jax.lax.psum(batch_loss, axis_name='tpu')
    stats = {'loss': batch_loss, 'accuracy': batch_accuracy_sum}  

    return state, stats

In [None]:
def create_train_state(rng, dummy_image_batch):
    net = VGG19()
    params = net.init({'params': rng, 'dropout': rng}, dummy_image_batch, True)
    tx = optax.adam(learning_rate=0.01)
    state = VGGState.create(net.apply, params['params'], tx, rng,
                          params['batch_stats'])
    return state

In [None]:
rng = jax.random.PRNGKey(42)
rngs = np.broadcast_to(rng, (NUM_TPUS,) + rng.shape)
some_dummy_image_batch = sharded_training_images[0]
state = jax.pmap(create_train_state, axis_name='tpu')(rngs,some_dummy_image_batch)

In [None]:
start = time.time()
final_state, training_accuracies = train(state, num_epochs=25)
print("Total time: ", time.time() - start, "seconds")

In [None]:
# Plot the Accuracy 
plt.plot(training_accuracies)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()