In [None]:
import os
import collections
from etils import epath
from IPython.display import HTML
import matplotlib.pyplot as plt
from ml_collections import config_dict
import numpy as np
from scipy.io import wavfile
import tensorflow as tf
import tqdm

# from chirp.inference import colab_utils
# colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)

from chirp import audio_utils
from chirp import config_utils
from chirp import path_utils
from chirp.inference import embed_lib
from chirp.inference import models
from chirp.inference import tf_examples
from chirp.models import metrics
from chirp.inference.search import bootstrap
from chirp.inference.search import search
from chirp.inference.search import display
from chirp.inference.classify import classify
from chirp.inference.classify import data_lib

# Gambiarrinha pra rodar o modelo: https://stackoverflow.com/questions/71153492/invalid-argument-error-graph-execution-error
from tensorflow.keras.optimizers.legacy import Adam


# List the devices TensorFlow can use. On a GPU runtime a GPU entry should
# appear next to the CPU.
physical_devices = tf.config.list_physical_devices()
physical_devices

# Necessary components

- `refernce_audio_path` : path to example file used for embedding search
- `unlabeled_audio_path`: path to field data recordings

# Global parameters

**Project folder:**

In [None]:
# Root folder of the tutorial project. To override it without editing the
# notebook, set the PERCH_TUTORIAL_DIR environment variable. Otherwise the
# original author's path is used.
parent_folder = os.environ.get(
    'PERCH_TUTORIAL_DIR',
    '/Users/lviotti/Library/CloudStorage/Dropbox/Work/Kitzes/projects/active-learning-nb/td-tutorial',
)

**Reference example** to be used for querying similar examples in unlabeled data.


In [None]:
# Reference (query) recording, stored directly under the project folder.
reference_filename = 'XC771930-Wood Thrush.mp3'
refernce_audio_path = f'{parent_folder}/{reference_filename}'  #@param

**Classes:**

In [None]:
# Labels offered when annotating search results. 'Unknown' collects everything
# that is not the target species.
target_classes = [
    'Wood Thrush',
    'Unknown',
]  #@param

**Unlabeled data path and `glob` pattern:**

[glob](https://docs.python.org/3/library/glob.html) string to be used for finding unlabeled audio files in your `shared_data_folder` and nested directories if any.

In [None]:
# Folder holding the shared tutorial data. Keep the trailing slash: later paths
# are built from it by plain string concatenation.
shared_data_folder = f'{parent_folder}/CCAI 2023 Google Perch Tutorial Shared Data/'

# Glob matching the unlabeled field recordings that will be embedded.
unlabeled_audio_pattern = f'{shared_data_folder}Powdermill Embeddings/Recording_4/Recording_4_Segment_2*.wav'  #@param

**Create embeddings** – set this to `True` to compute embeddings for the unlabeled audio. If you have already created them, leave it as `False` to reuse the existing embeddings.

In [None]:
# True: (re)compute embeddings for the unlabeled audio.
# False: reuse the embeddings already written to the output folder.
create_embeddings: bool = False

**Distance parameters:** 

In [None]:
# How many results the top-k search returns.
top_k = 10  # @param {type:"number"}

# Score function used by the Perch search code. Options:
#   'euclidean' - standard Euclidean distance
#   'cosine'    - cosine similarity
#   'mip'       - maximum inner product
metric = 'euclidean'  #@param['euclidean', 'mip', 'cosine']

# Score the search aims for. Aiming away from the best possible score lets us
# probe near a 'classifier boundary' instead of only the closest matches.
# Use None for a plain best-results search.
target_score = 3.4  #@param

Outputs folder:


In [None]:
# All generated artifacts (embeddings, labeled clips, CSVs) go under this
# folder. Keep the trailing slash: sub-paths are appended by concatenation.
output_directory = '{}/CCAI 2023 Google Perch Tutorial Outputs/'.format(parent_folder)


Other configs:

In [None]:
config = config_dict.ConfigDict()

# The location of the pre-trained model.
model_path = shared_data_folder + 'bird-vocalization-classifier/'

# Specify a directory where the embeddings will be written.
embedding_output_dir = output_directory + 'raw_embeddings/' #@param

config.output_dir = embedding_output_dir
config.source_file_patterns = [unlabeled_audio_pattern]

# The path to an empty directory where the generated labeled samples will be
# placed. Each labeled sample will be placed into a subdirectory corresponding
# to the target class that we select for that sample.
labeled_data_path = output_directory + 'labeled_outputs/'  #@param

# Create the output folders. The configuration itself is written in a later
# cell. makedirs/mkdir with exist_ok make this cell safe to re-run, and
# makedirs also creates missing parent folders (os.mkdir did not).
os.makedirs(output_directory, exist_ok=True)
output_dir = epath.Path(config.output_dir)
output_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Model specific parameters: PLEASE DO NOT CHANGE THE CODE IN THIS CELL.
# NOTE(review): `config.embed_fn_config` is assigned below *before*
# `embed_fn_config.model_config` is set, and `file_id_depth` is later set via
# `config.embed_fn_config`. This presumably relies on ConfigDict storing the
# nested ConfigDict by reference rather than copying it, so keep the statement
# order as is.
embed_fn_config = config_dict.ConfigDict()
# Key selecting the embedding model implementation (presumably looked up in the
# chirp model registry; confirm in chirp.inference.models).
embed_fn_config.model_key = 'taxonomy_model_tf'
model_config = config_dict.ConfigDict()

# The length, in seconds, of each "chunk" (window) of audio that is embedded.
model_config.window_size_s = 5.0

# The hop size (aka model 'stride') is the offset in seconds between successive
# chunks of audio. When hop_size is equal to window size, the chunks of audio
# will not overlap at all. Choosing a smaller hop_size (a common choice is half
# of the window_size) may be useful for capturing interesting data points that
# correspond to audio on the boundary between two windows. However, a smaller
# hop size may also lead to a larger embedding dataset because each instant of
# audio is now present in multiple windows. As a consequence, you might need to
# "de-dupe" your matches since multiple embedded data points may correspond to
# the same snippet of raw audio.
model_config.hop_size_s = 5.0

# All audio in this tutorial is resampled to 32 kHz.
model_config.sample_rate = 32000

# The location of the pre-trained model.
model_config.model_path = model_path

# Only write embeddings to reduce size. The Perch codebase supports serializing
# a variety of metadata along with the embeddings, but for the purposes of this
# tutorial we will not need to make use of those features.
embed_fn_config.write_embeddings = True
embed_fn_config.write_logits = False
embed_fn_config.write_separated_audio = False
embed_fn_config.write_raw_audio = False

config.embed_fn_config = embed_fn_config
embed_fn_config.model_config = model_config

# We have functionality to break large inputs up into smaller chunks;
# this is especially helpful for dealing with long files or very large datasets.
# Get in touch if you think you may need this. Here -1 presumably disables
# sharding, i.e. one shard per whole file (confirm in embed_lib).
config.shard_len_s = -1
config.num_shards_per_file = -1

# Number of parent directories to include in the filename. This allows us to
# process raw audio that lives in multiple directories.
config.embed_fn_config.file_id_depth = 1

# Number of TFRecord files to create.
config.tf_record_shards = 1

### Write the configuration to JSON to ensure consistency with later stages of the pipeline

In [None]:
# Write the config as JSON next to the embeddings. Later runs can then reuse
# the same embeddings and check which configuration produced them.
embed_lib.maybe_write_config(config, output_dir)

# Build the SourceInfo list: metadata describing how the audio files matching
# `config.source_file_patterns` are partitioned into shards, and which file
# each shard came from. With num_shards_per_file = shard_len_s = -1 the files
# are presumably not split at all (TODO confirm against embed_lib). The 5 s
# windowing happens later, inside the embedding model.
patterns = config.source_file_patterns
shards_per_file = config.num_shards_per_file
shard_length_s = config.shard_len_s
source_infos = embed_lib.create_source_infos(patterns, shards_per_file, shard_length_s)
print(f'Constructed {len(source_infos)} source infos.')

### Load the pre-trained embedding model

We will apply *transfer learning* by using a model pre-trained on an avian bioacoustics dataset ([Xeno-Canto](https://xeno-canto.org)) to compute the embeddings we will search over. This pre-trained model allows us to leverage a rich learned representation so that we do not need to train a custom model for the specific species we search for.

The pre-trained model we use is open-source and available on [TFHub](https://tfhub.dev/google/bird-vocalization-classifier/4).

In [None]:
# Build the EmbedFn wrapper around the generic bird classifier and load its
# weights.
embed_fn = embed_lib.EmbedFn(**config.embed_fn_config)
print('\n\nLoading model(s)...')
embed_fn.setup()

# Smoke test: embed one window of silence to check that the model runs.
print('\n\nTest-run of model...')
num_window_samples = int(model_config.sample_rate * model_config.window_size_s)
silence = np.zeros([num_window_samples])
embed_fn.embedding_model.embed(silence)
print('Setup complete!')

<a name=embed_data></a>
## Generate Embeddings

### Process the search dataset

In [None]:
# Audio is loaded on multiple threads before embedding. This is usually
# faster, but can fail if any audio file is corrupt.
#
# `audio_iterator` yields one audio array per SourceInfo (they are zipped
# together in the next cell), presumably read from the shard's offset for
# `config.shard_len_s` seconds.
embed_fn.min_audio_s = 1.0
record_file = (output_dir / 'embeddings.tfrecord').as_posix()
succ, fail = 0, 0

shard_paths = [info.filepath for info in source_infos]
shard_offsets_s = [info.shard_num * info.shard_len_s for info in source_infos]
audio_iterator = audio_utils.multi_load_audio_window(
    filepaths=shard_paths,
    offsets=shard_offsets_s,
    sample_rate=model_config.sample_rate,
    window_size_s=config.shard_len_s,
)

### Embed the search dataset

We are ready to generate the embeddings for the raw audio (this only runs when `create_embeddings` is `True`).  This cell iterates over the `audio_iterator` created in the previous cell and creates a point (vector) in *embedding space* for each 5-second chunk of raw audio.  We write these embeddings to files in the `embedding_output_dir` directory specified above, and then create a `ds` variable that is a handle on the resulting TFRecordDataset object.

Writing the embeddings to file is useful because for large datasets, this embedding step can take minutes or hours, and we don't want to have to repeatedly regenerate the embeddings.


In [None]:
if create_embeddings:
  # Reset the counters here so that re-running this cell reports counts for
  # this run only. NOTE: `audio_iterator` is presumably single-pass, so re-run
  # the previous cell before re-running this one.
  succ, fail = 0, 0
  with tf_examples.EmbeddingsTFRecordMultiWriter(
      output_dir=output_dir, num_files=config.tf_record_shards) as file_writer:
    for source_info, audio in tqdm.tqdm(
        zip(source_infos, audio_iterator), total=len(source_infos)):
      if audio.shape[0] < embed_fn.min_audio_s * model_config.sample_rate:
        # Skip audio shorter than embed_fn.min_audio_s (not counted as a failure).
        continue
      file_id = source_info.file_id(config.embed_fn_config.file_id_depth)
      offset_s = source_info.shard_num * source_info.shard_len_s
      example = embed_fn.audio_to_example(file_id, offset_s, audio)
      if example is None:
        fail += 1
        continue
      file_writer.write(example.SerializeToString())
      succ += 1
    file_writer.flush()
  print(f'\n\nSuccessfully processed {succ} source_infos, failed {fail} times.')

  # Sanity check: read the written embeddings back and inspect the first one.
  fns = list(output_dir.glob('embeddings-*'))
  ds = tf.data.TFRecordDataset(fns)
  parser = tf_examples.get_example_parser()
  ds = ds.map(parser)
  for ex in ds.as_numpy_iterator():
    print('Recording filename:', ex['filename'])
    print('Shape of the embedding:', ex['embedding'].shape)
    break

<a name=similarity-search></a>
## Audio Similarity Search

In the previous section we generated a set of embeddings for our search corpus.  In this section we will embed a Wood Thrush query and 'search' within the search corpus for examples that are similar to it, and we'll manually label each sample as 'Wood Thrush' or 'Unknown'.  These labeled samples then get saved and will be used as training data for the linear classifier in the next section.

Recall that an *embedding* is simply a vector (point) in some high-dimensional space.  By default, the similarity search we implement relies on the Euclidean distance between two vectors (see the `metric` parameter). Other metrics can be used to compare two vectors, such as cosine similarity or Maximum Inner Product (MIP).

**Note:** In practice, we end up looping through this section repeatedly with different query samples (if we have them) and different search parameters.  That helps us generate a robust training set.

### Prepare the query vector

We're now ready to create the 'query', which uses a known Wood Thrush vocalization.  In general, this query is the audio that you're looking for in your search corpus.

As mentioned above, there are massive repositories of avian vocalizations available for download from [eBird.org](https://www.ebird.org) or [Xeno-Canto](https://xeno-canto.org/).

These audio snippets are visualized as [spectrograms](https://en.wikipedia.org/wiki/Spectrogram), which show how the frequency content of a signal changes over time.

In [None]:

# Load the reference recording, resampled to the model's sample rate, and plot
# the mel spectrogram of the whole file.
reference_audio = audio_utils.load_audio(refernce_audio_path, model_config.sample_rate)
audio = reference_audio
display.plot_audio_melspec(audio, model_config.sample_rate)

#### Select the specific query window

The full audio sample in the previous cell might be longer than 5 seconds, so we need to find a specific window within the full clip to use (recall we're always working with 5 second audio samples).  Listen to the audio output above and find a 5 second window with a nice, clear vocalization, and note the elapsed time in seconds.

In [None]:
# The reference recording may be longer than the 5 s window the model embeds,
# so pick the window's start time by hand. A value of 2 works well for the
# XC771930-Wood Thrush file.
start_s = 2  #@param

print('Selected audio window:')
window_samples = model_config.window_size_s * model_config.sample_rate
st = int(start_s * model_config.sample_rate)
end = int(st + window_samples)
# If the window would run past the end of the file, slide it back so that it
# ends on the last sample, but never let it start before sample 0.
if end > audio.shape[0]:
  end = audio.shape[0]
  st = max(0, int(end - window_samples))
audio_window = audio[st:end]
display.plot_audio_melspec(audio_window, model_config.sample_rate)

query_audio = audio_window
sep_outputs = None

#### Embed the query

Here we run the query audio through the embedding function to generate the embedding vector for the query.

While it isn't a requirement for this tutorial, depending on the quality of the source audio, we can add in an additional preprocessing step using **sound separation**. In particular, [it's been shown](https://arxiv.org/abs/2110.03209) that unsupervised sound separation can help improve bird classification by automatically reducing the presence of overlapping vocalizations, background noise, etc. to increase the signal-to-noise ratio (SNR).

In [None]:
# Embed the query window. `.embeddings` is sliced as [:, 0, :], which keeps
# index 0 of the middle axis (presumably (frames, channels, dim); TODO confirm).
query = query_audio
query_outputs = embed_fn.embedding_model.embed(query)
embedded_query = query_outputs.embeddings[:, 0, :]

### Create a TensorFlow Dataset (TFDS) wrapper over the embeddings

This is a technical step that wraps our search corpus in a TFDS object to allow us to use some convenient built-in features.

If you would like to use the larger search corpus from the whole Powdermill dataset, set `use_precomputed_embeddings = True` in the cell below; `embeddings_dir` will then point to `shared_data_folder + 'Powdermill Embeddings/'`.  If you would like to first use the
smaller search corpus that you generated above, you can come back to this cell later, re-run it with the updated setting, and then re-run all subsequent cells.

In [None]:
#@title Refresh pre-computed embeddings

use_precomputed_embeddings = False # @param {type:"boolean"}
if use_precomputed_embeddings:
  # Use the embeddings from the entire Powdermill dataset.
  # This dataset contains a much larger number of embeddings,
  # so it might be more interesting to explore.
  embeddings_dir = shared_data_folder + 'Powdermill Embeddings/'
else:
  # Use the embedded dataset that we created above.
  embeddings_dir = embedding_output_dir

bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_config(
    embeddings_path=embeddings_dir,
    annotated_path=labeled_data_path
)

# Point the config at the raw Powdermill audio *before* building the state.
# BootstrapState presumably reads `config.audio_globs` when it is constructed,
# which is what provides `project_state.source_map` (used to fetch the audio
# for search results). Setting the globs first lets us construct the state
# once. Previously it was built, its config was patched, and it was built again.
bootstrap_config.audio_globs = [f'{shared_data_folder}Powdermill Embeddings/*/*.wav']

project_state = bootstrap.BootstrapState(
    bootstrap_config, embedding_model=embed_fn.embedding_model)

embeddings_ds = project_state.create_embeddings_dataset()

### <a name=top-k-search></a>Run top-k search using a comparison metric

In this cell we run a "nearest neighbor" search over the search corpus to find embeddings that are closest to the query embedding according to our chosen metric.  These will correspond to snippets of Powdermill audio that sound similar to our Wood Thrush query.

The `target_score` variable is a param that allows us to surface not just the closest matches, but rather matches that lie some fixed distance away.  When `target_score` is set to `None` (or `0`), the search will return the closest matches.

Recall that our goal in this section is to generate training data for a linear classifier. In order to train a robust model, we want this training dataset to contain both obvious/easy examples as well as not-so-obvious examples. If in our search we just looked for the closest possible matches, we would likely only find easy examples. The `target_score` param allows us to look for examples that might be less obvious (because they are farther away in embedding space).

**Important note:** You will likely need to come back to this cell and produce more top-k instances by changing the value of `target_score`. We explain how to choose useful values for `target_score` in the section [Choose a distance score](#choosing-target-score).

In [None]:
# Top-k search of the corpus for embeddings near `embedded_query`, scored with
# `metric` and aimed at `target_score`. `all_scores` holds the scores of the
# searched embeddings and is used for the histogram further below.
results, all_scores = search.search_embeddings_parallel(
    embeddings_ds,
    embedded_query,
    hop_size_s=model_config.hop_size_s,
    top_k=top_k,
    target_score=target_score,
    score_fn=metric,
    random_sample=False,
)

### User-in-the-loop data labeling (**requires user interaction**)

The cell below displays the search results in a user-interface in the following format:
* image: a plot visualization of the audio search result (Mel spectrogram, frequency in Hz over time)
* a playback of the audio sample itself
* metrics and metadata: `rank` position, `source file` of the recording segment, `offset_s` (in seconds) from the recording, and the search `score` (i.e. similarity with the query)
* candidate labels for the sample

**Instructions to the user:** <br>
For each search result presented below, select the label ('Wood Thrush' or 'Unknown') you think is the closest fit to the sound you hear in the file. In this tutorial, the label options are Wood Thrush or Unknown.

#### Quick guide to the Wood Thrush's vocalizations

**Songs**
---------
The Wood Thrush's easily recognized, "flute-like" **ee-oh-lay** is the middle phrase of a three-part song. If you recognize this distinctive middle portion in a retrieved sample, you can rely on it as an anchor when annotating.

**Calls**
---------
If you notice a staccato-like **bup-bup-bup**, this is a mild distress call that "rises in pitch and grows louder and more complex...until it becomes a distinctive **pit-pit-pit** alarm."

**For examples of the Wood Thrush's vocalizations**, please check out the sound examples from the All About Birds [Macaulay Library here](https://www.allaboutbirds.org/guide/Wood_Thrush/sounds).

In [None]:
# Show each search result with one checkbox per entry of `target_classes`,
# so it can be labeled by hand.
display.display_search_results(
    results,
    model_config.sample_rate,
    project_state.source_map,
    checkbox_labels=target_classes,
    max_workers=5,
)

### Write the user-annotated search results to file

This cell saves the annotations you generated in the previous step.  It writes data to the `labeled_data_path` location that was specified above.

In [None]:
def write_labeled_data(search_results, labeled_data_path: str, sample_rate: int):
  """Write the user-labeled search results to the labeled data collection.

  Each result with at least one checked label widget is written as a WAV file
  into `<labeled_data_path>/<label>/`, once per checked label. Existing files
  are not overwritten; they are counted and reported as duplicates.

  Args:
    search_results: Iterable of search results. Each result must expose
      `filename`, `timestamp_offset`, `audio` and `label_widgets` (widgets
      with `description` and `value`).
    labeled_data_path: Root folder of the labeled data collection.
    sample_rate: Sample rate (Hz) of `r.audio`, written into the WAV header.
  """
  labeled_data_path = epath.Path(labeled_data_path)
  counts = collections.defaultdict(int)
  duplicates = collections.defaultdict(int)
  for r in search_results:
    labels = [ch.description for ch in r.label_widgets if ch.value]
    if not labels:
      continue
    # `stem` is correct even for extension-less names. The old slice
    # name[:-len(extension)] produced '' when the extension was empty.
    # The clip is always written as WAV, so it always gets a .wav extension.
    source_stem = epath.Path(r.filename).stem
    output_filename = f'{source_stem}___{r.timestamp_offset}.wav'
    for label in labels:
      output_path = labeled_data_path / label
      output_path.mkdir(parents=True, exist_ok=True)
      output_filepath = output_path / output_filename
      if output_filepath.exists():
        duplicates[label] += 1
        continue
      counts[label] += 1
      with output_filepath.open('wb') as f:
        wavfile.write(f, sample_rate, np.float32(r.audio))
  for label, count in counts.items():
    print(f'Wrote {count} examples for label {label}')
  for label, count in duplicates.items():
    print(f'Not saving {count} duplicates for label {label}')


write_labeled_data(results, labeled_data_path, model_config.sample_rate)

In [None]:
# List the tutorial images (img/score_hist.png is embedded in the markdown
# below). The guard keeps a fresh run from another working directory from
# crashing with FileNotFoundError.
img_dir = 'img'
img_files = sorted(os.listdir(img_dir)) if os.path.isdir(img_dir) else []
img_files

### <a name=choosing-target-score></a> Choose a distance score 

**Goal:** Select a distance value small enough that the returned examples lie close to the query example in embedding space.



In [None]:
print('Your current score is: {}'.format(target_score))

**Tutorial instructions:**

During the top-k search step above, the Perch code also computed and saved the distances to *every* point in the search corpus.  The actual numerical values of these distances are hard to interpret, but the relative values are very useful.  In the following cell we plot a histogram of this set of distances to help us conceptualize the geometry of our embedded dataset.  This histogram helps us find and tune our values for the `target_score` variable in the top-k search.

A typical histogram will appear to fit some vaguely-normal looking distribution, possibly skewed left with a heavy tail.  While there is no prescriptive formula for finding useful values of `target_score`, the Perch team has found that good choices tend to lie near the left-hand 'hockey-stick' point of the distribution.  For example, in the following histogram, you might try playing with values somewhere in the range of 2.8 to 3.1:

<!-- ![search](https://drive.google.com/uc?id=1bLc2XDTqutihg4wJSkCfB2DiT4Dpr4UO) -->
![title](./img/score_hist.png)


These tend to correspond to examples that are faint, or have background noise, or are otherwise not especially obvious.  Annotating these examples and adding them to the training set is very important because they help the linear model discriminate better on these less-clear "boundary" points.  

In [None]:
# Histogram of the search scores, with the top-k hits marked as red ticks
# along the x axis.
densities, _, _ = plt.hist(all_scores, bins=128, density=True)
peak_density = np.max(densities)
hit_scores = [hit.score for hit in results.search_results]
plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',
            color='r', alpha=0.5)
plt.xlabel(metric)
plt.ylabel('density')

if target_score is not None:
  # Dotted red line at the target score, plus the fraction of scores below it.
  plt.plot([target_score, target_score], [0.0, peak_density], 'r:')
  hit_percentage = (all_scores < target_score).mean()
  print(f'score < target_score percentage : {hit_percentage:5.3f}')

# Dotted green line at the smallest observed score.
min_score = np.min(all_scores)
plt.plot([min_score, min_score], [0.0, peak_density], 'g:')

plt.show()

<a name=active-learning></a>
## Train a Linear Classifier

In the last stage, we labeled samples from the search dataset as matches (by similarity comparison with the queries) for each of our target classes. We will now train a simple linear model using those bootstrapped (labeled) data points from the search dataset.

**Important:** in order to be able to train the linear model, we need ~10-12 examples from each class, in our case 'Wood Thrush' and 'Unknown'. If you encounter an error in this section when training the model, you likely did not generate a sufficient amount of labeled data. Please go back, choose a new value for the `target_score` parameter, and label more search results.

In [None]:
# @title Load and embed the search-annotated dataset { vertical-output: true }

# Load the labeled clips from `labeled_data_path` (one sub-folder per label)
# and embed them with the model used for the search corpus, mean-pooled over
# time. target_sample_rate=-2 is a sentinel interpreted by data_lib (its
# meaning is not visible here; confirm in chirp.inference.classify.data_lib).
merged = data_lib.MergedDataset.from_folder_of_folders(
    base_dir=labeled_data_path,
    embedding_model=project_state.embedding_model,
    time_pooling='mean',
    load_audio=False,
    target_sample_rate=-2,
    audio_file_pattern='*',
    embedding_config_hash=bootstrap_config.embedding_config_hash(),
)

# Summarize the label distribution.
lbl_counts = np.sum(merged.data['label_hot'], axis=0)
num_present = (lbl_counts > 0).sum()
print('num classes :', num_present)
print('mean ex / class :', lbl_counts.sum() / num_present)
# Empty classes are pushed up by 1e6 so they never win the minimum.
padded_counts = lbl_counts + (lbl_counts == 0) * 1e6
print('min ex / class :', padded_counts.min())

### Train a simple linear model using the labeled embeddings

We use the following hyperparameters which should serve as reasonably performing defaults to train this linear model (classifier):

- `batch_size`: 12
- `num_epochs`: 128
- `num_hiddens`: -1 (no hidden layer, i.e. a purely linear model; a positive value adds a hidden layer of that size)
- `learning_rate`: 0.001

Additionally, we compute the following metrics to measure the "goodness" of the trained model:
- `acc`: overall accuracy
- `auc_roc`: AUC-ROC, the area under the receiver operating characteristic curve
- `cmAP`: mean average precision averaged across species
- `maps`: mean average precision for each class

In [None]:
# How the labeled data is split for training. Set exactly one of:
#   train_ratio              - presumably the fraction of examples used for training
#   train_examples_per_class - number of random training examples per class
# With only a few labeled samples, set train_ratio = None and use a
# train_examples_per_class smaller than the smallest class.
train_ratio = None  #@param
train_examples_per_class = 5  #@param

# Number of models trained with different random seeds. Training several
# models shows how robust the result is. One is enough for this tutorial, but
# try 4 or 12 to compare the models.
num_seeds = 1  #@param

# Classifier training hyperparameters (reasonable defaults).
batch_size, num_epochs = 12, 128
num_hiddens = -1  # Values <= 0 select a purely linear model in the training cell.
learning_rate = 1e-3

In [None]:
# Train `num_seeds` classifiers, one per random seed, and collect each run's
# metrics. After the loop, `model` is the classifier from the last seed.
# NOTE: this rebinds the name `metrics`, shadowing the `chirp.models.metrics`
# module imported at the top of the notebook.
metrics = collections.defaultdict(list)
for seed in tqdm.tqdm(range(num_seeds)):
  model = (
      classify.get_two_layer_model(num_hiddens, merged.embedding_dim, merged.num_classes)
      if num_hiddens > 0
      else classify.get_linear_model(merged.embedding_dim, merged.num_classes)
  )
  run_metrics = classify.train_embedding_model(
      model, merged, train_ratio, train_examples_per_class,
      num_epochs, seed, batch_size, learning_rate)
  for key, value in (
      ('acc', run_metrics.top1_accuracy),
      ('auc_roc', run_metrics.auc_roc),
      ('cmap', run_metrics.cmap_value),
      ('maps', run_metrics.class_maps),
      ('test_logits', run_metrics.test_logits),
  ):
    metrics[key].append(value)

#### Compute the average metrics and print the model performance

In the previous cells, the `num_seeds` param controls how many times we train a model.  Each time we train a model there is some randomness in which data points we choose from our labeled data, as well as in the model's initialization.  We can get a sense of how robust our classifier is by training multiple times and looking at the summary statistics computed by the following cell.  A low `auc_roc` value (i.e., less than 0.9 or so) probably indicates that we should generate more training data.

In [None]:
def _print_array_stats(title, values):
  """Print one line of summary statistics (shape, min, mean, max, std).

  Local stand-in for `colab_utils.prstats`. The colab_utils import is
  commented out at the top of the notebook, so calling it raised NameError.
  """
  values = np.asarray(values)
  if values.size == 0:
    print(f'{title:>16s} : shape: {values.shape} (empty)')
    return
  print(
      f'{title:>16s} : shape: {values.shape}  min: {values.min():6.2f}  '
      f'mean: {values.mean():6.2f}  max: {values.max():6.2f}  '
      f'std: {values.std():6.2f}'
  )


mean_acc = np.mean(metrics['acc'])
mean_auc = np.mean(metrics['auc_roc'])
mean_cmap = np.mean(metrics['cmap'])
# Per-class average precision (`class_maps`), averaged over every seed rather
# than read from whichever `run_metrics` the training loop left behind.
mean_class_maps = np.mean(metrics['maps'], axis=0)
# Merge the test_logits of all seeds into a single array per class index.
test_logits = {
    k: np.concatenate([logits[k] for logits in metrics['test_logits']])
    for k in metrics['test_logits'][0].keys()
}

print(f'acc:{mean_acc:5.2f}, auc_roc:{mean_auc:5.2f}, cmap:{mean_cmap:5.2f}')
for lbl, class_map in zip(merged.labels, mean_class_maps):
  if np.isnan(class_map):
    continue
  # This is the class's mean average precision, not its AUC-ROC.
  print(f'\n{lbl:8s}, map:{class_map:5.2f}')
  _print_array_stats(f'test_logits({lbl})',
                     test_logits[merged.labels.index(lbl)])

### The Active Learning Loop: Generating more training examples

It is possible, and even likely, that your first pass through the data yielded a model that could be improved, possibly with a low AUC-ROC, or poor performance in one particular class.  To improve the model, we need to generate additional interesting labeled examples from each class.  There are two ways we can proceed:

1. Go back to the audio search step and either use a new query example (if you have one), or play with the `target_score` param in the top-k search cell.
2. Now that we have trained a linear model, we can examine that model's outputs to help us surface new examples from the search corpus.

Both methods are perfectly reasonable, there is no right answer here!  If you'd like try the second method, proceed to the next cell.

### Generate new samples using logit scores

Most models output some notion of 'activation' per class, which can be thought of as a form of confidence score.  When you run a new sample through a model, the resulting activation value is called a **logit**.  In the next cell we will surface new examples by looking for points in the search corpus that yield high logit scores for a given target class.  Intuitively, these are points that our linear model is very excited about, and therefore they are likely to be relevant to our task.

This method is similar to our top-k search method above.  We can input a `target_logit`, which will be used to surface points with a logit score that is close to the `target_logit`.  In a later cell, we show a histogram of all logit scores to indicate what the useful values of `target_logit` might be.  The first time through this cell, set `target_logit = None` to find matches with the highest logit.

Another good choice is `target_logit = 0.0`. This corresponds to the data where the model is most uncertain - labeling this data is a common active learning strategy called **margin sampling** and can help produce good models quickly.




In [None]:
# Class whose logits drive the search; must be one of `target_classes`.
target_class = 'Wood Thrush'  #@param
# Logit value to aim for: the results closest to it are shown.
# None returns the highest-logit examples instead.
target_logit = None  #@param
# How many results to display.
num_results = 10  #@param

# Score the corpus with the trained classifier and keep the `num_results`
# examples whose `target_class` logit is highest (or nearest `target_logit`).
# `all_logits` presumably holds the target-class logit of every example
# searched; it is histogrammed in the next cell.
target_class_idx = merged.labels.index(target_class)
results, all_logits = search.classifer_search_embeddings_parallel(
    embeddings_classifier=model,
    target_index=target_class_idx,
    embeddings_dataset=embeddings_ds,
    hop_size_s=model_config.hop_size_s,
    target_score=target_logit,
    top_k=num_results,
)

In [None]:
# Histogram of the classifier's target-class logits over the corpus.
# plt.hist returns (counts, bin_edges, patches). The counts, not the bin edges,
# give the height of the marker line.
densities, _, _ = plt.hist(all_logits, bins=128, density=True)
plt.xlabel(f'{target_class} logit')
plt.ylabel('density')
# For a log-scale y axis, add: plt.yscale('log')
if target_logit is not None:
  # Dotted red line at the requested target logit. Skipped when the search
  # targets the highest logits (target_logit is None).
  plt.plot([target_logit, target_logit], [0.0, np.max(densities)], 'r:')
plt.show()

In [None]:
#@title Display results for the target label { vertical-output: true }

# Copy the labels into a new tuple so that appending extra options below can
# never modify `merged.labels` in place (which `+=` would do if it were a list).
display_labels = tuple(merged.labels)

extra_labels = []  #@param
for label in extra_labels:
  # Check against display_labels so each extra label is added only once.
  if label not in display_labels:
    display_labels += (label,)
if 'Unknown' not in display_labels:
  display_labels += ('Unknown',)

display.display_search_results(
    results, model_config.sample_rate,
    project_state.source_map,
    checkbox_labels=display_labels,
    max_workers=5)

#### Add selected results to the labeled data

As before, once we've annotated the examples from the previous cell, we'll save them in the `labeled_data_path` folder.

In [None]:
# Save the newly annotated examples with the same `write_labeled_data` helper
# used after the first (similarity-search) labeling pass. Both passes then
# name, de-duplicate and report the saved clips identically.
write_labeled_data(results.search_results, labeled_data_path, model_config.sample_rate)

### Repeat!

We now have two distinct methods for generating training data.  Keep looping through those two methods and training new linear models until you have a model you're happy with (a high AUC-ROC score is a good indicator of model quality).

**Note:** If you want to start fresh, you can delete all of your previously labeled samples by deleting all of the files and folders in `labeled_data_path`.

### Write the trained model's classification results (inference) to CSV

Usually the purpose of creating a model in the first place is to bulk-process many hours of raw audio.  In this cell, we'll run our linear model over the entire search corpus.  The output will be a CSV containing the results.  The `threshold` parameter is the minimum logit score that will get recorded (ie, samples with low logit scores are simply omitted from the results).  You can tune this score to generate different output CSVs at different confidence scores.

In [None]:
# { vertical-output: true }

# Only (timestamp, label) pairs whose logit is strictly above this value are
# written to the CSV.
threshold = 1.0  #@param
# os.path.join avoids the doubled '/' that f'{output_directory}/...' produced,
# because output_directory already ends with '/'.
output_filepath = os.path.join(output_directory, 'content', 'inference.csv')  #@param


def classify_batch(batch):
  """Add per-timestep classifier logits to one embeddings-dataset element.

  Runs the trained `model` on every embedding vector in
  `batch[tf_examples.EMBEDDING]` (assumed rank 3: [time, channel, dim]). It
  takes the maximum logit over the channel axis and stores the
  [time, num_classes] result under `batch['logits']`.
  """
  embeddings = batch[tf_examples.EMBEDDING]
  shape = tf.shape(embeddings)
  num_frames, num_channels, embedding_dim = shape[0], shape[1], shape[-1]
  flat_logits = model(tf.reshape(embeddings, [-1, embedding_dim]))
  num_classes = tf.shape(flat_logits)[-1]
  per_channel_logits = tf.reshape(
      flat_logits, [num_frames, num_channels, num_classes])
  # Keep the strongest response across channels.
  batch['logits'] = tf.reduce_max(per_channel_logits, axis=-2)
  return batch


# Attach classifier logits to every element of the embeddings dataset.
inference_ds = embeddings_ds.map(
    classify_batch, num_parallel_calls=tf.data.AUTOTUNE
)
# NOTE: this rebinds `all_logits` (previously the classifier-search logits) to
# the full [num_timesteps, num_classes] logits matrix used by the t-SNE plot.
all_embeddings = []
all_logits = []

# open() does not create missing folders, so make sure the CSV's folder exists.
output_parent = os.path.dirname(output_filepath)
if output_parent:
  os.makedirs(output_parent, exist_ok=True)

with open(output_filepath, 'w') as f:
  # Write column headers.
  headers = ['filename', 'timestamp_s', 'label', 'logit']
  f.write(', '.join(headers) + '\n')
  for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
    all_embeddings.append(ex['embedding'])
    all_logits.append(ex['logits'])
    for t in range(ex['logits'].shape[0]):
      for i, label in enumerate(merged.labels):
        if ex['logits'][t, i] > threshold:
          # Timestep t starts t hops after the example's own timestamp.
          offset = ex['timestamp_s'] + t * model_config.hop_size_s

          logit = '{:.2f}'.format(ex['logits'][t, i])
          row = [ex['filename'].decode('utf-8'),
                 '{:.2f}'.format(offset),
                 label, logit]
          f.write(', '.join(row) + '\n')

all_embeddings = np.concatenate(all_embeddings, axis=0)
all_logits = np.concatenate(all_logits, axis=0)

## Visualization of results

We will now use [t-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding) to produce a scatter plot of all of the embeddings.
t-SNE is a technique for reducing high-dimensional data to two dimensions
for visualization. It tries to keep points which are close in the embedding
space close in the two-dimensional space.

We will also use the output probabilities from the classifier to color the
resulting scatter plot.

In [None]:
# { vertical-output: true }

import sklearn
from sklearn.manifold import TSNE
import matplotlib.patches as mpatches
from matplotlib import colormaps

def sigmoid(x):
  return 1 / (1 + np.exp(-x))

pca = sklearn.decomposition.PCA(256)
reduced = pca.fit_transform(all_embeddings.squeeze())
tsne = TSNE()
%time ts = tsne.fit_transform(reduced)

# Plot the t-SNE embedding.
plt.figure(figsize=(12,12))
plt.title('Powdermill Embeddings', fontsize=24)
plt.xticks([]),plt.yticks([])
cmap = colormaps['plasma']
cs = cmap(sigmoid(all_logits[:, 1]))
plt.scatter(ts[:, 0], ts[:, 1], c=cs, alpha=0.25, marker='o')

patches = []
for i, pr in enumerate((0.0, 0.5, 1.0)):
  patch = mpatches.Patch(color=cmap(pr), label=f'P(woothr)={pr:5.2f}', alpha=0.25)
  patches.append(patch)
plt.legend(handles=patches)