In [4]:
%load_ext autoreload
%autoreload 2

In [30]:
from datasets import load_dataset
from copy import deepcopy
from datasets.formatting.formatting import LazyRow, LazyBatch

In [60]:
path = "DBD-research-group/BirdSet"
name = "XCM"
cache_dir = f"data/{path}/{name}"
trust_remote_code =  True
num_proc =  5

In [61]:
ds = load_dataset(path, name, cache_dir=cache_dir, trust_remote_code=True)

In [62]:
class AudioLengthFilter:
    def __init__(self, min_len: int = 0, max_len: int = 10):
        self.min_len = min_len
        self.max_len = max_len
    
    def __call__(self, batch):
        return self.min_len <= batch["end_time"] - batch["start_time"] <= self.max_len

In [107]:
class AudioSegmenting:
    def __init__(self, segment_length: int = 10, max_segments: int = 3):
        self.segment_length = segment_length
        self.max_segments = max_segments
        
    def __call__(self, batch):
        new_batch = {k:[] for k in batch.keys()}
        # iterate over all rows of batch
        for b_idx in range(len(batch["filepath"])):
            # skip audios with to long length
            if batch["length"][b_idx] > self.segment_length * self.max_segments:
                continue
            
            # add all keys to new_batch
            for key in batch.keys():
                # add duplicates if length is over 10, seconds
                for i in range((batch["length"][b_idx] // self.segment_length) + 1):
                    if key == "start_time":
                        new_batch[key] += [i * self.segment_length]
                    elif key == "end_time":
                        new_batch[key] += [min((i+1) * self.segment_length, batch["length"][b_idx])]
                    else:
                        new_batch[key] += [batch[key][b_idx]]
        return new_batch

In [104]:
new_ds = ds["train"]
new_ds = new_ds.map(AudioSegmenting(10), batched=True, batch_size=500)
print("ds length", len(new_ds), "from:", len(ds["train"]))
new_ds = new_ds.filter(AudioLengthFilter(5))

ds length 81025 from: 80012


In [105]:
len(new_ds)

61498

In [106]:
new_ds[:10]["start_time"], new_ds[:10]["end_time"]

([0.0, 10.0, 20.0, 0.0, 0.0, 10.0, 0.0, 10.0, 0.0, 10.0],
 [10.0, 20.0, 29.0, 10.0, 10.0, 15.0, 10.0, 18.0, 10.0, 20.0])