In [None]:
from datasets import Dataset, load_from_disk
import pandas as pd
import numpy as np
from datetime import datetime, date
from pathlib import Path
import os
from tempfile import TemporaryDirectory

In [None]:
# Directory of the full dataset: default target of `save` and the location the
# dataset is loaded from further down.
data_dir = "../../data/xeno_canto"
# Default output directory of `generate_and_save_train_data`.
train_data_dir = "../../data/training_data"

In [None]:
def save(dataset, path=data_dir):
    """Write ``dataset`` to ``path``, replacing whatever is stored there.

    The dataset is written to a temporary staging directory first and only
    swapped into place afterwards, so the previous contents of ``path`` stay
    on disk while ``save_to_disk`` is running.

    Args:
        dataset: Object exposing ``save_to_disk(path, max_shard_size=...)``.
        path: Target directory (``str`` or ``Path``). Defaults to ``data_dir``.
            It does not have to exist yet.
    """
    import shutil

    path = Path(path)
    with TemporaryDirectory() as temp_dir:
        # Stage in a sub-directory instead of moving the temporary directory
        # itself, so the context manager still finds its directory on exit.
        staged = Path(temp_dir) / "dataset"
        staged.mkdir()
        dataset.save_to_disk(staged, max_shard_size="100MB")

        # Clear the old target. It may be missing (first save) or contain
        # nested directories, neither of which glob + os.remove + rmdir handles.
        if path.is_dir():
            shutil.rmtree(path)
        elif path.exists():
            path.unlink()
        path.parent.mkdir(parents=True, exist_ok=True)

        # shutil.move falls back to copy + delete when the temporary directory
        # and the target live on different filesystems; a plain rename fails.
        shutil.move(str(staged), str(path))

In [None]:
def generate_and_save_train_data(data: Dataset, path=train_data_dir):
    """Filter ``data`` down to the training subset, persist it and return it.

    A record is kept when it is flagged as available and its duration is
    either unknown (``None``) or at most 300 seconds. After filtering, only
    the five columns listed in ``train_columns`` are retained. The result is
    written to ``path`` via ``save`` and also returned.
    """
    train_columns = ["file", "sci_name", "duration", "simple_label", "natural_label"]

    def is_train_record(record):
        duration = record["duration"]
        available = record["available"]
        # unknown durations pass; known ones must be <= 5 minutes
        short_enough = True if duration is None else duration <= 300
        return bool(short_enough and available)

    train_data = data.filter(is_train_record).select_columns(train_columns)
    save(train_data, path)
    return train_data

In [None]:
# Load the dataset stored under `data_dir`.
data = Dataset.load_from_disk(data_dir)
# Bare expression: the notebook renders the dataset summary as cell output.
data

Dataset({
    features: ['file', 'available', 'type', 'duration', 'date', 'elevation', 'name_eng', 'time_label', 'time', 'remarks', 'longitude', 'background', 'latitude', 'country', 'sci_name', 'simple_label', 'full_output', 'corrupt', 'natural_label'],
    num_rows: 691930
})

In [None]:
# Display the current state of `data` (features and row count).
data

In [27]:
def get_simple_label_from_record(record):
    """Build the compact, semicolon-separated label for one recording.

    Fields are emitted in a fixed order: scientific name, time-of-day label,
    calendar week, year, country, duration, rating, then every sound type and
    every background species. Missing scalar fields become empty strings so
    the leading positions stay stable; a missing name becomes ``"Unknown"``.

    Args:
        record: Mapping read via ``.get`` for the keys ``sci_name``,
            ``country``, ``date``, ``time_label``, ``duration``, ``remarks``
            (mapping with ``rating``), ``type`` (list of strings) and
            ``background`` (list of mappings with ``name_sci``). ``date`` may
            be an ISO string or a ``date``/``datetime`` object.

    Returns:
        ``{"simple_label": <label string>}``.
    """
    from datetime import datetime

    name = record.get("sci_name")
    name = "Unknown" if name is None else name

    country = record.get("country")
    country = "" if country is None else country

    recorded_on = record.get("date")
    if isinstance(recorded_on, str):
        # `date_mapper` shows that the raw data contains strings such as
        # "2020-00-00" which fromisoformat rejects; treat those as unknown.
        try:
            recorded_on = datetime.fromisoformat(recorded_on)
        except ValueError:
            recorded_on = None
    # After `date_mapper` has run the column holds date objects, which support
    # strftime directly and must not be passed to fromisoformat.
    if recorded_on is not None:
        calendar_week = recorded_on.strftime('week %V')
        year = recorded_on.strftime('%Y')
    else:
        calendar_week = year = ""

    time_label = record.get("time_label")
    time_label = "" if time_label is None else time_label

    duration = record.get("duration")
    duration = "" if duration is None else f"{duration} seconds"

    rating = (record.get("remarks") or {}).get("rating")
    # keep the slot but avoid the literal text "None of 5"
    rating = "" if rating is None else f"{rating} of 5"

    call_types = list(record.get("type") or [])

    background_birds = [bird.get("name_sci") for bird in (record.get("background") or [])]
    # a background entry without a name would break the join below
    background_birds = [bird for bird in background_birds if bird is not None]

    label = [name, time_label, calendar_week, year, country, duration, rating]
    label.extend(call_types)
    label.extend(background_birds)

    return {"simple_label": "; ".join(label)}

if __name__ == "__main__":
    # Write the `simple_label` column for every record, using 8 worker processes.
    data = data.map(get_simple_label_from_record, num_proc=8)

Map (num_proc=8):   0%|          | 0/691930 [00:00<?, ? examples/s]

### Testing label length

In [13]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-base")


def _is_training_candidate(rec):
    """Same rule as the training-data filter: available, and duration unknown or <= 300 s."""
    duration = rec["duration"]
    # A None duration is kept, as in `generate_and_save_train_data`;
    # comparing None <= 300 directly would raise TypeError.
    return bool(rec["available"]) and (duration is None or duration <= 300)


def _token_lengths(dataset, column):
    """Return a pandas Series with the tokenizer's encoded length of ``column`` per record."""
    counted = dataset.map(lambda rec: dict(out=len(tokenizer.encode(rec[column]))))
    return counted.with_format("pandas")["out"]


filtered = data.filter(_is_training_candidate).select_columns(["simple_label", "natural_label"])
simple_label_len = _token_lengths(filtered, "simple_label")
natural_label_len = _token_lengths(filtered, "natural_label")

print(simple_label_len.describe())
print(natural_label_len.describe())

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


Filter:   0%|          | 0/691930 [00:00<?, ? examples/s]

Map:   0%|          | 0/665295 [00:00<?, ? examples/s]

Map:   0%|          | 0/665295 [00:00<?, ? examples/s]

count    665295.000000
mean         35.422464
std          13.379283
min          18.000000
25%          28.000000
50%          30.000000
75%          38.000000
max         257.000000
Name: out, dtype: float64
count    665295.000000
mean         44.527987
std          18.038706
min          16.000000
25%          34.000000
50%          37.000000
75%          51.000000
max         324.000000
Name: out, dtype: float64


In [14]:
# Report, per threshold, the share and the count of labels whose token length
# exceeds it - first as fractions ("Relative"), then as counts ("Absolute").
label_lengths = {"Simple": simple_label_len, "Natural": natural_label_len}
thresholds = [64, 128, 256]

print("Relative")
for length in thresholds:
    for kind, lengths in label_lengths.items():
        print(f"% {kind} labels with length over {length}: {(lengths > length).mean():.5f}")

print("Absolute")
for length in thresholds:
    for kind, lengths in label_lengths.items():
        print(f"# {kind} labels with length over {length}: {(lengths > length).sum()}")

Relative
% Simple labels with length over 64: 0.04775
% Natural labels with length over 64: 0.12535
% Simple labels with length over 128: 0.00064
% Natural labels with length over 128: 0.00342
% Simple labels with length over 256: 0.00000
% Natural labels with length over 256: 0.00001
Absolute
# Simple labels with length over 64: 31768
# Natural labels with length over 64: 83398
# Simple labels with length over 128: 427
# Natural labels with length over 128: 2275
# Simple labels with length over 256: 1
# Natural labels with length over 256: 5


In [51]:
def date_mapper(record):
    """Normalise the ``date`` field of a record to a ``datetime.date`` (or None).

    Accepted inputs:
      * ``None`` -> ``None``
      * ``date`` / ``datetime`` objects -> the corresponding ``date``; this
        makes the mapper safe to re-run on its own output
      * ISO strings -> parsed with ``datetime.fromisoformat``
      * ``"Y-M-D"`` strings that fromisoformat rejects -> year 0 yields
        ``None``; a zero month or day is replaced by 1

    Args:
        record: Mapping with a ``"date"`` key.

    Returns:
        ``{"date": <date or None>}``.

    Raises:
        ValueError: if a string is neither ISO nor three integer parts
            separated by ``-``, or the parts do not form a valid date.
    """
    from datetime import date, datetime

    output_col = "date"
    raw = record["date"]

    if raw is None:
        return {output_col: None}
    # datetime is a subclass of date, so test it first
    if isinstance(raw, datetime):
        return {output_col: raw.date()}
    if isinstance(raw, date):
        return {output_col: raw}

    try:
        return {output_col: datetime.fromisoformat(raw).date()}
    except ValueError:
        year, month, day = (int(part) for part in raw.split("-"))
        if year == 0:
            return {output_col: None}
        # zero month/day: presumably placeholders for an unknown component,
        # mapped to the first month / first day
        return {output_col: date(year, month or 1, day or 1)}
    
if __name__ == "__main__":
    # Overwrite the `date` column with the normalised values, using 8 worker processes.
    data = data.map(date_mapper, num_proc=8)

Map (num_proc=8):   0%|          | 0/691930 [00:00<?, ? examples/s]

In [29]:
def time_label_mapper(record):
    """Derive a coarse time-of-day label from the record's ``time`` string.

    Hours 5-11 map to "morning", 12-16 to "afternoon", 17-20 to "evening" and
    everything else to "night". A ``None`` time yields a ``None`` label.

    Returns:
        ``{"time_label": <label or None>}``.
    """
    from datetime import time as timeutil

    raw_time = record["time"]
    if raw_time is None:
        return {"time_label": None}

    hour = timeutil.fromisoformat(raw_time).hour

    # half-open hour ranges [start, end); anything outside them is night
    day_parts = (
        (5, 12, "morning"),
        (12, 17, "afternoon"),
        (17, 21, "evening"),
    )
    for start, end, part in day_parts:
        if start <= hour < end:
            return {"time_label": part}
    return {"time_label": "night"}

if __name__ == "__main__":
    # Write the `time_label` column for every record, using 8 worker processes.
    data = data.map(time_label_mapper, num_proc=8)

Map (num_proc=8):   0%|          | 0/691930 [00:00<?, ? examples/s]

In [15]:
def get_natural_label_from_record(record):
    """Build a natural-language description of one recording.

    The label consists of up to five sentences, joined by single spaces:
    where/when the bird was recorded, the duration, the sound types, the
    quality rating and the background species. A sentence (or clause) is
    omitted when its underlying field is missing.

    Args:
        record: Mapping read via ``.get`` for the keys ``sci_name``,
            ``country``, ``date``, ``time_label``, ``duration``, ``remarks``
            (mapping with ``rating``), ``type`` (list of strings) and
            ``background`` (list of mappings with ``name_sci``). ``date`` may
            be an ISO string or a ``date``/``datetime`` object.

    Returns:
        ``{"natural_label": <label string>}``.
    """
    from datetime import datetime

    def get_background_combined_str_from_record(record):
        """Return "a X, a Y and a Z" for the background species, or None if there are none."""
        names = [bird.get("name_sci") for bird in (record.get("background") or [])]
        with_article = [f"a {bird_name}" for bird_name in names if bird_name]
        if not with_article:
            return None
        if len(with_article) == 1:
            return with_article[0]
        # commas between all but the last entry, "and" before the last one
        return ", ".join(with_article[:-1]) + " and " + with_article[-1]

    # The record-level column is "sci_name" (as used by the simple label);
    # "name_sci" is the key of the background entries only.
    name = record.get("sci_name")

    background_combined_str = get_background_combined_str_from_record(record)

    country = record.get("country")

    recorded_on = record.get("date")
    if isinstance(recorded_on, str):
        # `date_mapper` shows that the raw data contains strings such as
        # "2020-00-00" which fromisoformat rejects; treat those as unknown.
        try:
            recorded_on = datetime.fromisoformat(recorded_on)
        except ValueError:
            recorded_on = None
    # date objects (the output of `date_mapper`) support strftime directly
    month = recorded_on.strftime('%B') if recorded_on is not None else None
    time_label = record.get("time_label")

    duration = record.get("duration")
    rating = (record.get("remarks") or {}).get("rating")

    call_type = ", ".join(record.get("type") or [])

    opening = f"A {name} " if name is not None else "An unknown bird "
    opening += "was recorded"
    opening += f" in {country}" if country is not None else ""
    opening += f" in {month}" if month is not None else ""
    opening += f" at {time_label}" if time_label is not None else ""
    opening += "."

    sentences = [opening]
    if duration is not None:
        sentences.append(f"The recording is {duration} seconds long.")
    if call_type != "":
        sentences.append(f"The sound is described as {call_type}.")
    if rating is not None:
        sentences.append(f"The sound quality is {rating} out of 5.")
    if background_combined_str is not None:
        sentences.append(f"In the background there is {background_combined_str}.")

    # Joining with a space guarantees a separator between every pair of
    # sentences and leaves no trailing whitespace.
    return {"natural_label": " ".join(sentences)}

if __name__ == "__main__":
    # Write the `natural_label` column for every record, using 8 worker processes.
    data = data.map(get_natural_label_from_record, num_proc=8)

Map (num_proc=8):   0%|          | 0/691930 [00:00<?, ? examples/s]

In [11]:
# Keywords searched for (as lowercase substrings) in a record's `type` entries.
# Every keyword is listed once so that a match is reported once.
sound_types = ["song","call","wingbeat","rattle","wings","wingclap",
               "duet","juvenile","adult","clapping","wing beats",
               "begging","imitation","knocks","drumming","wing flapping",
               "buzz","trill","clappering","chatter",
               "conflict",
               ]

def type_mapper(record):
    """Reduce a record's ``type`` entries to the known sound-type keywords.

    All entries are joined and lower-cased; every keyword from ``sound_types``
    that occurs as a substring is kept, in the order of ``sound_types``.

    Args:
        record: Mapping whose ``"type"`` value is a list of strings; a missing
            or ``None`` value is treated as an empty list.

    Returns:
        ``{"type": <list of matched keywords>}``.
    """
    type_list = record.get("type") or []
    type_str = ";".join(type_list).lower()
    return {"type": [sound_type for sound_type in sound_types if sound_type in type_str]}

# Replace each record's `type` list with the sound-type keywords found in it.
data = data.map(type_mapper)

Map:   0%|          | 0/691930 [00:00<?, ? examples/s]