<a href="https://colab.research.google.com/github/KW-plato/PrimateComms/blob/main/predictor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Demo script: predicting chimpanzee call types from detector output**

In [None]:
%pip install pydub

In [None]:
import os
import numpy as np
import pandas as pd
import pydub
from scipy import signal
from scipy.io import wavfile
import torch
from torch.utils.data import DataLoader
from handlers.datahandler import AudioSpectDataset
from models.chimp_model import ChimpCallClassifier

In [None]:
"""
Helper functions
"""
def add_processed_clip(recording, segment, t1, t2, len, spec, label, clips=None):
    """Append one processed-clip record to the list of clips to be predicted.

    Args:
        recording: Name of the source recording file.
        segment: "<begin> - <end>" string identifying the detector segment
            the clip was cut from.
        t1: Start time of the clip in seconds.
        t2: End time of the clip in seconds.
        len: Clip duration in seconds, stored under the ``'len (s)'`` key.
            The parameter keeps its historical name for compatibility; it
            only shadows the builtin inside this function.
        spec: Path of the saved spectrogram file, or ``"Unavailable"``.
        label: Call-type label or a comment explaining why the clip was not
            processed. Stored as ``str`` so that exception objects passed as
            comments can be compared, formatted and written like any label.
        clips: List to append the record to. Defaults to the module-level
            ``spec_list`` used by the rest of the notebook.
    """
    target = spec_list if clips is None else clips
    target.append({
        'recording': recording,
        'segment': segment,
        't1': t1,
        't2': t2,
        'len (s)': len,
        'spectrogram': spec,
        'label': str(label),
    })

#Pretty prints the predicted call labels
def pretty_print(df):
    print("{0:*^80s}".format(df.iloc[1]['recording']))
    print("{0:^15s} {1:^15s} {2:^10s} {3:<40s}".format("Start Time","End Time","Duration","Call Type/Comment"))
    for i,row in df.iterrows():
        print("{0:^15.4f} {1:^15.4f} {2:^10.4f} {3:<40s}".format(row['t1'], row['t2'], row['len (s)'], row['label']))



In [None]:
# Directory holding the demo inputs: the detector output, the recording and the
# trained model file. Set the CHIMP_DATA_DIR environment variable to use a
# different location; otherwise the original demo path is used.
datasrc = os.environ.get("CHIMP_DATA_DIR", "/Techspace/Chimp/data/Demo")
detector_output = "ARU18_20120410_090000.txt"  # tab-separated detector output table
audio_file = "ARU18_20120410_090000.wav"       # recording clipped at the detector's timestamps
# Filled by add_processed_clip with one record per processed clip.
spec_list = []

In [None]:
# Create the temporary directories that hold the per-segment wav clips and the
# spectrograms for this recording.
# os.path.splitext removes only the extension. str.rstrip(".wav") would instead strip
# any trailing '.', 'w', 'a', 'v' characters and mangle names such as "java.wav" -> "j".
recording_name = os.path.splitext(audio_file)[0]
temp_dir = os.path.join(datasrc, "temp")
temp_wavs = os.path.join(temp_dir, "wav", recording_name)
temp_specs = os.path.join(temp_dir, "specs", recording_name)
for directory in (temp_dir, temp_wavs, temp_specs):
    os.makedirs(directory, exist_ok=True)

In [None]:
# Load the detector's tab-separated, UTF-16 encoded output table and the full recording.
detector_path = os.path.join(datasrc, detector_output)
audio_path = os.path.join(datasrc, audio_file)
chimp_calls = pd.read_csv(detector_path, delimiter='\t', encoding='utf-16')
newAudio = pydub.AudioSegment.from_wav(audio_path)

In [None]:
#The contents of the detector's output: the table whose "Begin Time (s)" /
#"End Time (s)" columns drive the clipping in the next cell
display(chimp_calls)

In [None]:
# Reads the start and end timestamps of each chimp vocalization from the detector's output.
# Each segment is clipped from the recording and cut into 0.5 s windows. Each window is
# converted into a spectrogram and stored in the temporary directory. Every window (or the
# reason it could not be processed) is recorded via add_processed_clip.
def _window_starts(clip_len, win):
    """Return the sample offsets of the ``win``-sample windows that cover a clip.

    Windows are laid back to back from offset 0. If they do not end exactly at
    ``clip_len``, one extra window aligned to the end of the clip is appended so the
    tail is covered (it overlaps the previous window). No window is emitted twice, and
    offsets are clamped at 0, so a clip shorter than ``win`` yields one window at 0.
    """
    starts = list(range(0, clip_len - win + 1, win))
    if clip_len % win:
        starts.append(max(clip_len - win, 0))
    return starts


for r, row in chimp_calls.iterrows():
    begin_s = row["Begin Time (s)"]
    end_s = row["End Time (s)"]
    segment = "{} - {}".format(begin_s, end_s)
    dur = end_s - begin_s
    if dur >= 0.5:
        # pydub slices AudioSegments in milliseconds
        start = begin_s * 1000
        end = end_s * 1000
        finalAudio = newAudio[start:end]
        if finalAudio.channels > 1:
            finalAudio = finalAudio.set_channels(1)
        filename = os.path.join(temp_wavs, str(start) + '_' + str(end) + '.wav')
        finalAudio.export(filename, format='wav')
        sampling_rate, input_clip = wavfile.read(filename)
        if sampling_rate != 44100:
            input_clip = signal.resample(input_clip, int(len(input_clip) * 44100 / sampling_rate))
            sampling_rate = 44100
        clip_len = len(input_clip)
        win = int(0.5 * sampling_rate)
        for x1 in _window_starts(clip_len, win):
            x2 = x1 + win
            # t1 is computed from the window's actual start offset, so the end-aligned
            # tail window reports its real start time and duration.
            t1 = begin_s + x1 / sampling_rate
            # A window reaching the end of the clip is stamped with the detector's end time.
            t2 = end_s if x2 >= clip_len else begin_s + x2 / sampling_rate
            window_len = t2 - t1
            piece = input_clip[x1:x2]
            try:
                _, _, spectrogram = signal.spectrogram(
                    x=piece,
                    fs=sampling_rate,
                    nfft=512,
                    noverlap=427,
                    detrend=False,
                    window=signal.get_window('hamming', 512)
                )
                # Shape produced by a full 0.5 s window at 44.1 kHz with these settings
                # (257 frequency bins for nfft=512, 254 frames with a hop of 85 samples).
                if spectrogram.shape == (257, 254):
                    temp = os.path.join(temp_specs, str(r + 1) + "_" + str(x1) + ".spec")
                    with open(temp, 'wb') as f_save:
                        np.save(f_save, spectrogram)
                    add_processed_clip(audio_file, segment, t1, t2, window_len, temp, 'undecoded')
                else:
                    add_processed_clip(audio_file, segment, t1, t2, window_len, "Unavailable", "specgram shape {}".format(spectrogram.shape))
            except Exception as e:
                # Best effort: record the failure as the clip's comment (as text) and keep going.
                add_processed_clip(audio_file, segment, t1, t2, window_len, "Unavailable", str(e))
    else:
        add_processed_clip(audio_file, segment, begin_s, end_s, dur, "Unavailable", "Segment smaller than 500ms")

In [None]:
#Feeds the vocalization segments into the trained model
#Stores the prediction for the entire recording
if spec_list:
    df = pd.DataFrame(spec_list, columns=['recording','segment','t1','t2', 'len (s)', 'spectrogram', 'label'])
    # Only clips with a saved spectrogram ('undecoded') go to the model. reset_index keeps
    # each such clip's row position in `df` in the 'index' column.
    df_train = df.loc[df['label'] == 'undecoded',['spectrogram', 'label']].reset_index()

    pred_data = AudioSpectDataset(df_train['spectrogram'].to_list(), df_train['label'].to_list())

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The checkpoint is loaded onto the CPU; the built model is moved to `device` below.
    saved_model = torch.load(os.path.join(datasrc, "SPECGRAM_final_other.pth"), map_location=torch.device('cpu'))
    label_dict = saved_model['labels']
    classifier = ChimpCallClassifier(
        num_labels=len(label_dict),
        spectrogram_shape=saved_model['spectrogram_shape'],
        dropout=saved_model['dropout']
    ).float()

    classifier.load_state_dict(saved_model['model'])
    # Weights must live on the same device as the inputs. Without this, a CUDA machine sends
    # the inputs to the GPU while the model stays on the CPU and the forward pass fails.
    classifier.to(device)
    classifier.eval()
    # shuffle=False keeps batches in df_train order, so batch i belongs to df_train row i.
    dataloaders_test = DataLoader(pred_data, batch_size=1, shuffle=False)

    with torch.no_grad():
        for i, samples in enumerate(dataloaders_test):
            spect = samples['spectrogram'].to(device)
            outputs = classifier(spect)
            # .item() extracts the single class index of the one-element batch as a Python int
            # (int() on a 1-element ndarray is deprecated in recent NumPy).
            pred = torch.argmax(outputs, dim=1).item()
            calltype = label_dict[pred]
            pos = df_train.at[df_train.index[i],'index']
            df.at[df.index[pos],'label'] = calltype

    df = df[['recording','segment','t1','t2', 'len (s)', 'label']]
    # os.path.splitext removes only the extension; str.rstrip(".wav") strips a character set.
    df.to_csv(os.path.join(datasrc, os.path.splitext(audio_file)[0] + '.csv'), index=False, header=True)
else:
    print("No chimp call clips found in input")


In [None]:
#Show the predicted call type labels for the entire recording
# `df` is only created by the prediction cell when clips were found, so guard on the
# same condition instead of failing with a NameError.
if spec_list:
    pretty_print(df)
else:
    print("No chimp call clips found in input")