This notebook assumes that you have downloaded CIDRZ data available at https://www.kaggle.com/datasets/googlehealthai/google-health-ai?resource=download

More specifically, for this example, we will use audio recordings from the `Chainda South Phone B` and `Kanyama Phone B` directories, and metadata from `Metadata and Codebook`

In [None]:
from collections import Counter
import concurrent.futures
import io
import os
import zipfile

# copybara:strip_begin(Internal imports)
from colabtools import drive
# copybara:strip_end

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import scipy.io.wavfile as wavfile
from scipy.signal import resample
from sklearn import metrics
import tensorflow as tf



In [None]:
# Load the raw bytes of the `Metadata and Codebook` zip archive into
# `metadata_zip_bytes`; the CSVs are parsed from these bytes in a later cell.
# copybara:strip_begin(Internal imports)
metadata_zip_bytes = drive.LoadFile(file_id='1bE5hlnBpUsgfC-5GOtjTuR2z4ZPcCkUe')
# copybara:strip_end_and_replace_begin
# # This cell assumes that you have downloaded the metadata zip file at `/path/metadata.zip`
# # Please update the path accordingly.
# with open('/path/metadata.zip', 'rb') as zfile:
#   metadata_zip_bytes = zfile.read()
# copybara:replace_end


In [None]:


def read_csvs_from_zip(zip_bytes):
    """Reads every CSV file of an in-memory zip archive into a DataFrame.

    Members are read straight from the archive (nothing is extracted to disk).
    A member that cannot be parsed is reported on stdout and skipped, so one
    bad file does not prevent the others from being loaded.

    Args:
        zip_bytes: The bytes representing the zip file.

    Returns:
        A dictionary where keys are the member names inside the archive and
        values are pandas DataFrames. Returns an empty dictionary if the zip
        file is invalid or contains no CSVs.
    """
    csv_data = {}
    try:
        # The context manager guarantees that the archive handle is closed,
        # including when an unexpected error escapes the loop.
        with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
            for info in zf.infolist():
                # Case-insensitive so that `FILE.CSV` is picked up as well.
                if not info.filename.lower().endswith(".csv"):
                    continue
                try:
                    with zf.open(info) as csvfile:
                        csv_data[info.filename] = pd.read_csv(csvfile)
                except pd.errors.ParserError as e:  # Malformed CSV content
                    print(f"Error reading CSV {info.filename}: {e}")
                except Exception as e:  # Best effort: report and keep going
                    print(f"Error processing file {info.filename}: {e}")
    except zipfile.BadZipFile:
        print("Error: Invalid zip file")
        return {}

    return csv_data


# Parse every CSV of the metadata archive; `dfs` maps archive member name -> DataFrame.
dfs = read_csvs_from_zip(metadata_zip_bytes)


In [None]:
# List the CSV files that were found in the metadata archive.
dfs.keys()

In [None]:
# Display the `Google_Health_AI_Final_Codebook.csv` table.
dfs['Metadata and Codebook/Google_Health_AI_Final_Codebook.csv']


In [None]:
# Display the `GHAI_Final_Data_2023.csv` table; its `barcode` and `tb_decision`
# columns are used further down to build the labels.
dfs['Metadata and Codebook/GHAI_Final_Data_2023.csv']


In [None]:
# NOTE(review): this cell displays the same `GHAI_Final_Data_2023.csv` table as
# the cell right before it; one of the two can probably be removed.
dfs['Metadata and Codebook/GHAI_Final_Data_2023.csv']


# Read audio

* converts to mono
* resamples to 16 kHz
* represents the audio as NumPy arrays

In [None]:
# Load the raw bytes of the two audio archives: `Chainda South Phone B`
# (used as evaluation data below) and `Kanyama Phone B` (used as training data).
# copybara:strip_begin(Internal imports)
chainda_B_zip_bytes = drive.LoadFile(
    file_id='1PpQz7KuZKzX47lImrj5mZzQRYS3t7N8s'
)
kanyama_B_zip_bytes = drive.LoadFile(file_id='1DCJk8wnZQdSBaj-uiEr4HRgVkGntHb3Q')
# copybara:strip_end_and_replace_begin
# # This cell assumes that you have downloaded the audio zip files at `/path/chainda_B.zip`
# # and `/path/kanyama_B.zip`
# # Please update the paths accordingly.
# with open('/path/chainda_B.zip', 'rb') as zfile:
#   chainda_B_zip_bytes = zfile.read()
# with open('/path/kanyama_B.zip', 'rb') as zfile:
#   kanyama_B_zip_bytes = zfile.read()
# copybara:replace_end



In [None]:


def process_zipped_wavs(zip_bytes: bytes) -> dict[str, np.ndarray]:
  """Decodes every WAV file of a zip archive to mono float audio at 16kHz.

  For each `.wav` member the audio is:
    1. converted to floating point in [-1, 1) (8/16/32-bit integer PCM is
       rescaled, floating point input is cast to float32),
    2. averaged across channels if it has more than one channel,
    3. resampled to 16kHz if it was recorded at another rate.

  The full recording is kept (no cropping). A member that cannot be decoded
  is reported on stdout and skipped.

  Args:
    zip_bytes: Bytes representing the zip file.

  Returns:
    A dictionary where keys are the member names (with spaces replaced by
    underscores) and values are 1-D NumPy arrays containing the processed
    audio. Returns an empty dictionary if the zip file is invalid or no WAV
    files are found.
  """
  target_rate = 16000
  wav_data = {}
  try:
    # The context manager guarantees that the archive handle is closed.
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
      for info in zf.infolist():
        # Case-insensitive so that `FILE.WAV` is picked up as well.
        if not info.filename.lower().endswith(".wav"):
          continue
        try:
          with zf.open(info) as wav_file:
            # Hand scipy a plain in-memory buffer: it needs to seek, which is
            # cheap on BytesIO but costly on a compressed zip member.
            rate, data = wavfile.read(io.BytesIO(wav_file.read()))

          if data.dtype == np.int16:
            data = data.astype(np.float32) / 32768.0
          elif data.dtype == np.int32:
            data = data.astype(np.float32) / 2147483648.0
          elif data.dtype == np.uint8:
            # 8-bit PCM is unsigned and centred on 128.
            data = (data.astype(np.float32) - 128.0) / 128.0
          elif data.dtype != np.float32:
            data = data.astype(np.float32)

          # Handle multi-channel WAV files (e.g., stereo)
          if data.ndim > 1:
            data = np.mean(data, axis=1)

          # Resample to 16kHz
          if rate != target_rate:
            num_samples_new = int(len(data) * target_rate / rate)
            data = resample(data, num_samples_new)

          wav_data[info.filename.replace(' ', '_')] = data

        except Exception as e:  # Best effort: report and keep going
          print(f"Error reading WAV file {info.filename}: {e}")
  except zipfile.BadZipFile:  # Handle invalid zip files
    print("Error: Invalid zip file")
    return {}

  return wav_data



In [None]:
# Decode the `Kanyama Phone B` recordings (used as training data below) and
# report the number of samples of each one.
processed_wavs = process_zipped_wavs(kanyama_B_zip_bytes)

for wav_name in processed_wavs:
  print(f"File: {wav_name}, Shape: {processed_wavs[wav_name].shape}")

In [None]:
# Decode the `Chainda South Phone B` recordings (used as evaluation data below).
processed_wavs_test = process_zipped_wavs(chainda_B_zip_bytes)

In [None]:
# Plot every training waveform. Each figure is titled with its file name so it
# can be traced back to a recording, and is closed after display so that
# figures do not accumulate in memory.
for wav_name, waveform in processed_wavs.items():
  fig, ax = plt.subplots()
  ax.plot(waveform)
  ax.set(title=wav_name, xlabel='Sample index (16 kHz)', ylabel='Amplitude')
  plt.show()
  plt.close(fig)

# Process audio

## Extract final sequence of coughs

As part of the CIDRZ protocol, participants are asked to cough once, then to cough once more later, and finally to cough repeatedly.

In https://arxiv.org/abs/2403.02522, we found that using the final sequence of coughs resulted in better performance, hypothetically because this "forced" sequence of coughs also elicits involuntary coughs, which have been shown to be more predictive of disease status in https://www.science.org/doi/10.1126/sciadv.adi0282.

In our experiments, we had access to a cough detector, which outputs a score between 0 and 1 indicating how likely a 2s 16kHz audio clip is to contain a cough event. Since we do not have access to this model here, we will use a simple heuristic to extract the final sequence of coughs from the audio files.

In [None]:

def compute_spectrogram(
    audio: np.ndarray | tf.Tensor,
    frame_length: int = 400,
    frame_step: int = 160,
    ):
  """Computes the magnitude of the short-time Fourier transform of `audio`.

  Args:
    audio: Audio with at most 2 dimensions. A 2-dimensional input is averaged
      over its second axis before the transform (presumably the channel axis;
      verify against callers).
    frame_length: Length of each STFT frame, in samples. Also used as the FFT
      length.
    frame_step: Number of samples between the starts of consecutive frames.

  Returns:
    The absolute value of the STFT, as returned by `tf.abs`.

  Raises:
    NotImplementedError: If `audio` has more than 2 dimensions.
  """
  rank = len(audio.shape)
  if rank > 2:
    raise NotImplementedError(
        f'`audio` should have at most 2 dimensions but had {len(audio.shape)}')
  if rank == 2:
    audio = np.mean(audio, axis=1)
  return tf.abs(
      tf.signal.stft(
          audio,
          frame_length=frame_length,
          frame_step=frame_step,
          fft_length=frame_length,
      )
  )


def compute_loudness(
    audio: np.ndarray | tf.Tensor,
    sample_rate: float = 16000.0,
) -> np.ndarray:
  """Computes a per-frame loudness curve, in dB.

  The audio is cut into 25 ms frames taken every 10 ms. For each frame, the
  magnitudes of the linear spectrogram are summed over frequency and the sum
  is converted to decibels with `20 * log10(.)`.

  Args:
    audio: Array of shape [num_timesteps] representing a raw wav
      file.
    sample_rate: The sample rate of the input audio.

  Returns:
    An array of shape [num_frames] representing the loudness, with one value
    per 10 ms frame.
  """
  frame_step = int(sample_rate) // 100  # 10 ms
  frame_length = 25 * int(sample_rate) // 1000  # 25 ms
  linear_spectrogram = compute_spectrogram(
      # `np.asarray` (rather than `.astype`) also accepts the `tf.Tensor`
      # input allowed by the signature.
      np.asarray(audio, dtype=np.float32),
      frame_length=frame_length,
      frame_step=frame_step,
  )
  sum_amplitude = np.sum(linear_spectrogram, axis=1)
  # Floor fully silent frames so that log10 yields a finite (very low) value
  # instead of -inf and a divide-by-zero warning. Non-silent frames are
  # unaffected.
  sum_amplitude = np.maximum(sum_amplitude, np.finfo(np.float32).tiny)
  loudness_db_timeseries = 20 * np.log10(sum_amplitude)
  return np.asarray(loudness_db_timeseries)


In [None]:
# Loudness curve of every training recording, keyed by file name.
loudness = {
    wav_name: compute_loudness(waveform)
    for wav_name, waveform in processed_wavs.items()
}

In [None]:
# Loudness curve of every evaluation recording, keyed by file name.
loudness_test = {
    wav_name: compute_loudness(waveform)
    for wav_name, waveform in processed_wavs_test.items()
}

In [None]:
LOUDNESS_THRESHOLD = 42

# Preview the loudness curves of the first 22 recordings only, each with the
# threshold drawn as a dashed horizontal line.
for loudness_series in list(loudness.values())[:22]:
  plt.plot(loudness_series)
  plt.axhline(LOUDNESS_THRESHOLD, c='k', linestyle='--')
  plt.show()

We can see that the peaks above `LOUDNESS_THRESHOLD` most likely correspond to coughs, and we want to extract the final ones.

In [None]:

def extract_final_loud_clips_information(
    loudness: np.ndarray,
    min_peak_height: float = LOUDNESS_THRESHOLD,
    window_size: int = 200,
    window_step: int = 100,
    number_of_peaks: int = 5,
    hop_length: int = 160,
) -> list[dict[str, np.ndarray | int]]:
  """Extracts the final loud windows (likely coughs) of a loudness timeseries.

  A window of `window_size` frames slides backwards from the end of the
  recording in steps of `window_step` frames. A window is kept when its
  maximum exceeds `min_peak_height`. The scan stops as soon as
  `number_of_peaks` windows were kept. With the default sizes consecutive
  windows overlap by half, so one loud event can be picked twice.

  Only complete windows are considered: the scan also stops once a window
  would start before the beginning of the recording. Without that guard the
  slice would be empty or wrap around, and `np.max` would raise `ValueError`
  for every recording with fewer than `number_of_peaks` loud windows.

  Args:
    loudness: Array of shape [num_timesteps] representing the loudness.
    min_peak_height: Minimal amplitude of a peak to be considered a likely cough
    window_size: Size of the window, in loudness frames. With one frame every
      10 ms, 100 corresponds to 1s.
    window_step: Step of the window, in loudness frames. With one frame every
      10 ms, 100 corresponds to 1s.
    number_of_peaks: Maximum number of windows to extract.
    hop_length: Number of audio samples between two consecutive loudness
      frames; used to convert frame indices back to sample indices.

  Returns:
    A list of at most `number_of_peaks` dictionaries, ordered from the end of
    the recording backwards, each with keys:
      'window': the loudness values of the window,
      'start': index of the first audio sample of the window,
      'end': index one past the last audio sample of the window.
  """
  picked_windows = []
  for i in range(loudness.size // window_step):
    end = loudness.size - i * window_step
    start = end - window_size
    if start < 0:
      # No complete window is left before the start of the recording.
      break
    window = loudness[start:end]
    if np.max(window) > min_peak_height:
      picked_windows.append({
          'window': window,
          # Convert frame indices back to the initial temporal scale.
          'start': hop_length * start,
          'end': hop_length * end,
      })
    if len(picked_windows) >= number_of_peaks:
      break
  return picked_windows


In [None]:
# Cut the final loud windows out of every training recording.
final_loud_clips_data = {}
audio_clips_per_participant = {}

for i, (pid, series) in enumerate(loudness.items()):
  try:
    final_loud_clips_data[pid] = extract_final_loud_clips_information(series)
  except Exception as e:
    # Best effort: report the recording and move on. `Exception` (rather than
    # a bare `except`) keeps Ctrl-C / kernel interrupts working.
    print(f'Exception for {pid}: {e}')
    continue
  wav = processed_wavs[pid]
  audio_clips = []
  for clip in final_loud_clips_data[pid]:
    start = clip['start']
    end = clip['end']
    audio_clips.append(wav[start:end])
    # Visual sanity check on the first three recordings only.
    if i < 3:
      fig, axes = plt.subplots(nrows=1, ncols=2)
      axes[0].plot(clip['window'])
      print(pid)
      print(start / wav.size, end / wav.size)
      axes[1].plot(wav[start: end])
      plt.show()
  if not audio_clips:
    # An empty entry cannot be concatenated with the 2-D clip arrays of the
    # other participants, so leave such recordings out.
    print(f'No loud clip found for {pid}; skipping')
    continue
  audio_clips_per_participant[pid] = audio_clips


In [None]:
# Cut the final loud windows out of every evaluation recording.
final_loud_clips_data_test = {}
audio_clips_per_participant_test = {}

for i, (pid, series) in enumerate(loudness_test.items()):
  try:
    final_loud_clips_data_test[pid] = extract_final_loud_clips_information(series)
  except Exception as e:
    # Best effort: report the recording and move on. `Exception` (rather than
    # a bare `except`) keeps Ctrl-C / kernel interrupts working.
    print(f'Exception for {pid}: {e}')
    continue
  wav = processed_wavs_test[pid]
  audio_clips = []
  for clip in final_loud_clips_data_test[pid]:
    start = clip['start']
    end = clip['end']
    audio_clips.append(wav[start:end])
    # Visual sanity check on the first three recordings only.
    if i < 3:
      print(pid)
      fig, axes = plt.subplots(nrows=1, ncols=2)
      axes[0].plot(clip['window'])
      print(start / wav.size, end / wav.size)
      axes[1].plot(wav[start: end])
      plt.show()
  if not audio_clips:
    # An empty entry cannot be concatenated with the 2-D clip arrays of the
    # other participants, so leave such recordings out.
    print(f'No loud clip found for {pid}; skipping')
    continue
  audio_clips_per_participant_test[pid] = audio_clips


The JSON credentials file referenced in the cell below is created by running the following command (when using a service account):

```
gcloud auth application-default login --impersonate-service-account SERVICE_ACCT
```

or the following command

```
gcloud auth application-default login
```

to authenticate with your own account.

This assumes that you have first [installed](https://cloud.google.com/sdk/docs/install) `gcloud` CLI and created a service account (see [[1]](https://cloud.google.com/iam/docs/service-account-overview), [[2]](https://cloud.google.com/iam/docs/service-accounts-create)) (identified by `SERVICE_ACCT` above)

In [None]:
# Placeholder path: replace it with the location of your credentials JSON file.
# The variable must be set before `api_utils` / `eval_utils` are imported.
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/path/to/your/credentials/json/file'


In [None]:
# Environment variable `GOOGLE_APPLICATION_CREDENTIALS` must be set for these
# imports to work.
import api_utils
import eval_utils


In [None]:
# Stack the clips of all training participants into one array (one row per clip).
audio_clips = np.concatenate(list(audio_clips_per_participant.values()))
print(audio_clips.shape)

In [None]:
# Stack the clips of all evaluation participants into one array (one row per clip).
audio_clips_test = np.concatenate(list(audio_clips_per_participant_test.values()))
print(audio_clips_test.shape)

In [None]:
# Embed the training clips, 4 clips per request, with up to 10 requests in
# flight.
batch_size = 4
# Kept as a plain list: the last batch may be shorter than the others, and
# there may be a single batch only, so the batches cannot always be stacked.
batches = [
    audio_clips[k: k + batch_size]
    for k in range(0, len(audio_clips), batch_size)
]
print(len(batches))
if batches:
  print(batches[-1].shape)

responses_by_batch = {}
failed_batches = []
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
  futures = {
      executor.submit(
          api_utils.make_prediction_with_exponential_backoff,
          api_utils.RAW_AUDIO_ENDPOINT_PATH,
          batch,
      ): batch_idx
      for batch_idx, batch in enumerate(batches)
  }
  for future in concurrent.futures.as_completed(futures):
    batch_idx = futures[future]
    try:
      responses_by_batch[batch_idx] = future.result()
    except Exception as e:
      failed_batches.append(batch_idx)
      print(f"An error occurred for batch {batch_idx}:", e)

# The embeddings must stay aligned with the per-clip labels built later on, so
# a missing batch cannot be silently dropped.
if failed_batches:
  raise RuntimeError(
      f'{len(failed_batches)} batch(es) failed: {sorted(failed_batches)}. '
      'Re-run this cell.'
  )

# Requests complete in arbitrary order; restore the original clip order.
responses = [responses_by_batch[k] for k in sorted(responses_by_batch)]

In [None]:
# Embed the evaluation clips, 4 clips per request, with up to 10 requests in
# flight.
batch_size = 4
# Kept as a plain list: the last batch may be shorter than the others, and
# there may be a single batch only, so the batches cannot always be stacked.
batches = [
    audio_clips_test[k: k + batch_size]
    for k in range(0, len(audio_clips_test), batch_size)
]
print(len(batches))
if batches:
  print(batches[-1].shape)

responses_by_batch_test = {}
failed_batches_test = []
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
  futures = {
      executor.submit(
          api_utils.make_prediction_with_exponential_backoff,
          api_utils.RAW_AUDIO_ENDPOINT_PATH,
          batch,
      ): batch_idx
      for batch_idx, batch in enumerate(batches)
  }
  for future in concurrent.futures.as_completed(futures):
    batch_idx = futures[future]
    try:
      responses_by_batch_test[batch_idx] = future.result()
    except Exception as e:
      failed_batches_test.append(batch_idx)
      print(f"An error occurred for batch {batch_idx}:", e)

# The embeddings must stay aligned with the per-clip labels built later on, so
# a missing batch cannot be silently dropped.
if failed_batches_test:
  raise RuntimeError(
      f'{len(failed_batches_test)} batch(es) failed: '
      f'{sorted(failed_batches_test)}. Re-run this cell.'
  )

# Requests complete in arbitrary order; restore the original clip order.
responses_test = [
    responses_by_batch_test[k] for k in sorted(responses_by_batch_test)
]

In [None]:
# Join the per-batch responses along the first axis, giving one row per training clip.
embeddings = np.concatenate(responses, axis=0)
embeddings.shape

In [None]:
# Join the per-batch responses along the first axis, giving one row per evaluation clip.
embeddings_test = np.concatenate(responses_test, axis=0)
embeddings_test.shape

# Train linear probes

## Fetch labels

In [None]:
# Map every barcode of the metadata table to its `tb_decision` label.
label_per_barcode = (
    dfs['Metadata and Codebook/GHAI_Final_Data_2023.csv']
    .set_index('barcode')['tb_decision']
    .to_dict()
)
# Keys of `audio_clips_per_participant` are file names such as
# `XX-XXX-XX.wav`, possibly prefixed by a directory. The barcode is the file
# name without directory and extension. Recordings whose barcode is not in the
# metadata get the label `None`.
label_per_participant_id = {
    k: label_per_barcode.get(os.path.splitext(os.path.basename(k))[0])
    for k in audio_clips_per_participant
}


In [None]:
# Repeat each participant id once per clip so that ids and labels line up with
# the rows of the clip array.
pid_runs = [
    [pid] * len(clips) for pid, clips in audio_clips_per_participant.items()
]
participant_ids = np.concatenate(pid_runs)
labels = list(map(label_per_participant_id.__getitem__, participant_ids))

In [None]:
# Number of per-clip training labels.
len(labels)

In [None]:
# Distribution of the per-clip training labels.
Counter(labels)

In [None]:
# Keys of `audio_clips_per_participant_test` are file names such as
# `YYYYYYYYYY/XX-XXX-XX.wav`. The barcode is the file name without directory
# and extension; this also works for flat or more deeply nested names.
# Recordings whose barcode is not in the metadata get the label `None`.
label_per_participant_id_test = {
    k: label_per_barcode.get(os.path.splitext(os.path.basename(k))[0])
    for k in audio_clips_per_participant_test
}
# Repeat each participant id once per clip so that ids and labels line up with
# the rows of the clip array.
participant_ids_test = np.concatenate(
    [[pid] * len(clips) for pid, clips in audio_clips_per_participant_test.items()]
)
labels_test = [label_per_participant_id_test[pid] for pid in participant_ids_test]

In [None]:
# Distribution of the per-clip evaluation labels.
Counter(labels_test)

## Train using participant-level cross-validation
Training data comes from `Kanyama Phone B`

In [None]:
# Train on data from `Kanyama Phone B`
labels = np.array(labels)
participant_ids = np.array(participant_ids)
# `w` selects the clips that have a usable label. `pd.notna` rejects both
# `None` (barcode absent from the metadata) and NaN (empty `tb_decision`
# cell), neither of which can be cast to `int` for training.
w = np.asarray(pd.notna(labels), dtype=bool)

In [None]:
# Fit the linear probe on the labelled training clips only.
train_features = embeddings[w]
train_labels = labels[w].astype(int)
train_participant_ids = participant_ids[w]

probe = eval_utils.train_linear_probe_with_participant_level_crossval(
    features=train_features,
    labels=train_labels,
    participant_ids=train_participant_ids,
    n_folds=5,
    use_sgd_classifier=True,
    stratify_per_label=True,
)

## Evaluate
Eval data comes from `Chainda South Phone B`

In [None]:
# ROCAUC per recording

labels_test = np.array(labels_test)
# `w_test` selects the clips that have a usable label. `pd.notna` rejects both
# `None` (barcode absent from the metadata) and NaN (empty `tb_decision`
# cell), neither of which can be cast to `int`.
w_test = np.asarray(pd.notna(labels_test), dtype=bool)

metrics.roc_auc_score(
    y_true=labels_test[w_test].astype(int),
    y_score=probe.predict_proba(embeddings_test[w_test])[:, 1],
    )

In [None]:
# ROCAUC per participant

# One row per labelled evaluation clip.
per_clip_scores = pd.DataFrame({
    'label': labels_test[w_test].astype(int),
    'score': probe.predict_proba(embeddings_test[w_test])[:, 1],
    'id': participant_ids_test[w_test],
})
# Aggregate per participant: the maximum of the clip labels and the mean of
# the clip scores.
score_df = per_clip_scores.groupby('id').agg({'label': 'max', 'score': 'mean'})

metrics.roc_auc_score(
    y_true=score_df.label.values,
    y_score=score_df.score.values,
)
