<a href="https://colab.research.google.com/github/Abalo39/Machine_learning_at_scale/blob/main/perch_hoplite/agile/2_agile_modeling_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install hoplite and TF 2.20
!pip install --upgrade pip
!pip install git+https://github.com/google-research/perch-hoplite.git
!pip install tensorflow~=2.20.0

In [None]:
# Imports
import os
import shutil

from etils import epath
from matplotlib import pyplot as plt
import numpy as np

# Hoplite imports
from perch_hoplite.agile import audio_loader
from perch_hoplite.agile import classifier
from perch_hoplite.agile import classifier_data
from perch_hoplite.agile import embedding_display
from perch_hoplite.agile import source_info
from perch_hoplite.db import brutalism
from perch_hoplite.db import score_functions
from perch_hoplite.db import search_results
from perch_hoplite.db import sqlite_usearch_impl
from perch_hoplite.zoo import model_configs

In [None]:
# Imports
import os
from matplotlib import pyplot as plt
import numpy as np
from etils import epath

# Hoplite imports
from perch_hoplite.agile import audio_loader
from perch_hoplite.agile import classifier
from perch_hoplite.agile import classifier_data
from perch_hoplite.agile import embedding_display
from perch_hoplite.agile import source_info
from perch_hoplite.db  import brutalism
from perch_hoplite.db import score_functions
from perch_hoplite.db  import search_results
from perch_hoplite.db import sqlite_usearch_impl
from perch_hoplite.zoo import model_configs

In [None]:
# Copy hoplite database locally.
intaka_path = epath.Path('gs://chirp-public-bucket/soundscapes/intaka')

# Local directory holding the database files. The DB is opened from this
# directory below, so the downloads are written here explicitly instead of
# relying on the notebook's current working directory.
db_path = '/content'
os.makedirs(db_path, exist_ok=True)

# Stream the remote files in fixed-size chunks so that a multi-gigabyte file
# is never held in memory all at once.
copy_chunk_bytes = 16 * 1024 * 1024

# Copy the hoplite database to the Colab local storage.
# The usearch.index file is large, so this may take a minute.
print("Copying database files...")
for fp in intaka_path.glob('hoplite*'):
  print(f"Copying {fp}...")
  with fp.open('rb') as f, open(os.path.join(db_path, fp.name), 'wb') as g:
    shutil.copyfileobj(f, g, copy_chunk_bytes)

print("Copying search index (this may take a moment)...")
with (intaka_path / 'usearch.index').open('rb') as f:
  with open(os.path.join(db_path, 'usearch.index'), 'wb') as g:
    shutil.copyfileobj(f, g, copy_chunk_bytes)

# Update the DB with override values.
# We point the db to the google cloud bucket containing the audio files.
db = sqlite_usearch_impl.SQLiteUsearchDB.create(db_path)

# Update Model Config: switch the stored config to the perch_v2_cpu model.
model_cfg = db.get_metadata('model_config')
model_cfg.model_config.tfhub_path = 'google/bird-vocalization-classifier/tensorFlow2/perch_v2_cpu'
model_cfg.model_config.tfhub_version = 1
db.insert_metadata('model_config', model_cfg)

# Update Audio Sources Config: only the first audio glob's base path is
# rewritten to the public bucket.
sources_cfg = db.get_metadata('audio_sources')
sources_cfg.audio_globs[0]['base_path'] = 'gs://chirp-public-bucket/soundscapes/intaka'
db.insert_metadata('audio_sources', sources_cfg)

# Persist the metadata overrides.
db.commit()
print("SUCCESS: Database setup complete.")

In [None]:
# Initialize Model and Audio Loader
annotator_id = 'linnaeus'  # Identifier for your labels

# Step 1: read the model and audio-source configuration stored in the DB.
db_model_config = db.get_metadata('model_config')
embed_config = db.get_metadata('audio_sources')

# Step 2: look up the model class by its key and build it from its config.
print("Initializing model...")
model_class = model_configs.get_model_class(db_model_config.model_key)
embedding_model = model_class.from_config(db_model_config.model_config)

# Step 3: build the audio loader. Models that do not expose `window_size_s`
# fall back to a 5-second window.
audio_sources = source_info.AudioSources.from_config_dict(embed_config)
window_size_s = getattr(embedding_model, 'window_size_s', 5.0)

audio_filepath_loader = audio_loader.make_filepath_loader(
    audio_sources=audio_sources,
    window_size_s=window_size_s,
    sample_rate_hz=embedding_model.sample_rate,
)
print("SUCCESS: Model and Loader initialized.")

In [None]:
# Load query audio.
# The query is Xeno-Canto recording xc746686 (Red-eyed Dove), chosen for the
# Intaka session; `query_label` is the label attached to it.
query_uri = 'xc746686'
query_label = 'redeye'

print(f"Loading query: {query_uri} ({query_label})")
# Take a 5-second window from the start of the recording at 32 kHz.
query = embedding_display.QueryDisplay(
    uri=query_uri,
    offset_s=0.0,
    window_size_s=5.0,
    sample_rate_hz=32000,
)
_ = query.display_interactive()

In [None]:
# Embed the Query and Search.
num_results = 50

# Run the selected query window through the model and keep the embedding at
# index [0, 0] of the returned embeddings array.
query_embedding = (
    embedding_model.embed(query.get_audio_window()).embeddings[0, 0]
)

# Query the database index, then collect every (key, distance) pair into a
# top-k results container.
print(f"Searching for {num_results} matches...")
ann_matches = db.ui.search(query_embedding, count=num_results)
results = search_results.TopKSearchResults(top_k=num_results)
for k, d in zip(ann_matches.keys, ann_matches.distances):
  hit = search_results.SearchResult(k, d)
  results.update(hit)

print(f"Found {len(results.search_results)} results.")

In [None]:
# Display Results.
# Build a display group from the search results; audio for each result is
# fetched through the filepath loader.
display_results = embedding_display.EmbeddingDisplayGroup.from_search_results(
    results,
    db,
    sample_rate_hz=32000,
    frame_rate=100,
    audio_loader=audio_filepath_loader,
)
# Show the results, offering the query label as the positive label.
display_results.display(positive_labels=[query_label])

In [None]:
#@title Load model and connect to database. { vertical-output: true }

#@markdown Location of database containing audio embeddings.
db_path = '/content'  #@param {type:'string'}
#@markdown Identifier (eg, name) to attach to labels produced during validation.
annotator_id = 'linnaeus'  #@param {type:'string'}

# Connect to the database and read back the stored configuration.
db = sqlite_usearch_impl.SQLiteUsearchDB.create(db_path)
db_model_config = db.get_metadata('model_config')
embed_config = db.get_metadata('audio_sources')

# Build the embedding model from the stored model key and config.
model_class = model_configs.get_model_class(db_model_config.model_key)
embedding_model = model_class.from_config(db_model_config.model_config)

# Build the audio loader; a model without `window_size_s` gets a 5-second
# window.
audio_sources = source_info.AudioSources.from_config_dict(embed_config)
window_size_s = getattr(embedding_model, 'window_size_s', 5.0)
audio_filepath_loader = audio_loader.make_filepath_loader(
    audio_sources=audio_sources,
    window_size_s=window_size_s,
    sample_rate_hz=embedding_model.sample_rate,
)

In [None]:
#@title Load model and connect to database. { vertical-output: true }

#@markdown Location of database containing audio embeddings.
db_path = '/content'  #@param {type:'string'}
#@markdown Identifier (eg, name) to attach to labels produced during validation.
annotator_id = 'linnaeus'  #@param {type:'string'}

# Database handle plus the two metadata entries needed below.
db = sqlite_usearch_impl.SQLiteUsearchDB.create(db_path)
db_model_config = db.get_metadata('model_config')
embed_config = db.get_metadata('audio_sources')

# Embedding model, resolved from the model key recorded in the database.
model_class = model_configs.get_model_class(db_model_config.model_key)
embedding_model = model_class.from_config(db_model_config.model_config)

# Audio loader. The window length comes from the model when it exposes one,
# otherwise 5 seconds is used.
audio_sources = source_info.AudioSources.from_config_dict(embed_config)
window_size_s = (
    embedding_model.window_size_s
    if hasattr(embedding_model, 'window_size_s')
    else 5.0
)
audio_filepath_loader = audio_loader.make_filepath_loader(
    audio_sources=audio_sources,
    window_size_s=window_size_s,
    sample_rate_hz=embedding_model.sample_rate,
)

# Search

In [None]:
#@title Load query audio. { vertical-output: true }

#@markdown The `query_uri` can be a URL, filepath, or Xeno-Canto ID
#@markdown (like `xc777802`, containing an Eastern Whipbird (`easwhi1`)).
query_uri = 'xc777802'  #@param {type:'string'}
query_label = 'easwhi1'  #@param {type:'string'}

# Take a 5-second window from the start of the recording at 32 kHz and show
# the interactive query display.
query = embedding_display.QueryDisplay(
    uri=query_uri,
    offset_s=0.0,
    window_size_s=5.0,
    sample_rate_hz=32000,
)
_ = query.display_interactive()

In [None]:
# Copy hoplite database locally.
from etils import epath
import os
from perch_hoplite.db import sqlite_usearch_impl
intaka_path = epath.Path('gs://chirp-public-bucket/soundscapes/intaka')

# Copy the hoplite database to the Colab local storage.
# The usearch.index file is 3.6Gb, so takes a couple minutes to download.
# Files are placed in `/content` directory.
for fp in intaka_path.glob('hoplite*'):
  print(fp)
  with fp.open('rb') as f:
    with open(fp.name, 'wb') as g:
      g.write(f.read())
with (intaka_path / 'usearch.index').open('rb') as f:
  print(intaka_path / 'usearch.index')
  with open('usearch.index', 'wb') as g:
    %time g.write(f.read())

# Update the DB with some override values...
# Need to use the perch_v2_cpu model, and point the db to the google cloud
# bucket containing the data.
db_path = '/content'
db = sqlite_usearch_impl.SQLiteUsearchDB.create(db_path)

model_cfg = db.get_metadata('model_config')
model_cfg.model_config.tfhub_path = 'google/bird-vocalization-classifier/tensorFlow2/perch_v2_cpu'
model_cfg.model_config.tfhub_version = 1
db.insert_metadata('model_config', model_cfg)

sources_cfg = db.get_metadata('audio_sources')
sources_cfg.audio_globs[0]['base_path'] = 'gs://chirp-public-bucket/soundscapes/intaka'
db.insert_metadata('audio_sources', sources_cfg)

db.commit()