In [781]:
SAMPLE_DURATION = 2.0       # how long should a sample be?
SAMPLES_PER_GENRE = 200     # lets have this many samples for every genre
MAX_SAMPLES_PER_FILE = 11   # don't oversample 1 beat
RANDOM_SEED = 42            # keep it reproducible

In [782]:
import pandas as pd

data = pd.read_csv('data_reduced.csv', encoding="latin-1")
data.head()

# only use beats longer than SAMPLE_DURATION, not fills
beat_mask = (data['beat_type'] == 1) & (data.duration >= SAMPLE_DURATION)
data = data[beat_mask]

In [783]:
data.columns

Index(['Unnamed: 0', 'drummer', 'session', 'id', 'style', 'simplified_style',
       'bpm', 'beat_type', 'time_signature', 'midi_filename', 'audio_filename',
       'duration', 'split', 'onset_env_mean', 'onset_env_std', 'mfcc_mean',
       'mfcc_std', 'spectral_flux_mean', 'spectral_flux_std',
       'spectral_contrast_mean', 'spectral_contrast_std', 'tonnetz_mean',
       'tonnetz_std', 'rms_mean', 'rms_std', 'spectral_centroid_mean',
       'spectral_centroid_std', 'spectral_bandwidth_mean',
       'spectral_bandwidth_std', 'spectral_flatness_mean',
       'spectral_flatness_std', 'tempogram_mean', 'tempogram_std'],
      dtype='object')

In [784]:
# all the features need to be recalculated for the sliced audio data, so we drop them
columns_to_drop = ['Unnamed: 0', 'onset_env_mean', 'onset_env_std', 'mfcc_mean',
       'mfcc_std', 'spectral_flux_mean', 'spectral_flux_std',
       'spectral_contrast_mean', 'spectral_contrast_std', 'tonnetz_mean',
       'tonnetz_std', 'rms_mean', 'rms_std', 'spectral_centroid_mean',
       'spectral_centroid_std', 'spectral_bandwidth_mean',
       'spectral_bandwidth_std', 'spectral_flatness_mean',
       'spectral_flatness_std', 'tempogram_mean', 'tempogram_std']

data = data.drop(columns_to_drop, axis=1)
print(data.simplified_style.value_counts())
data.describe()

simplified_style
rock      179
funk       82
latin      59
jazz       52
hiphop     34
pop        20
Name: count, dtype: int64


Unnamed: 0,bpm,beat_type,duration
count,426.0,426.0,426.0
mean,109.589202,1.0,86.707123
std,28.059731,0.0,84.930078
min,50.0,1.0,3.03125
25%,93.0,1.0,23.990625
50%,104.0,1.0,54.098745
75%,120.0,1.0,136.830771
max,290.0,1.0,611.564048


In [785]:
# Add start and end sample to the data. All original entries get the first 2 seconds of the audio file.
data["start"] = 0
data["end"] = 2
data["times_sampled"] = 1
data.head()

Unnamed: 0,drummer,session,id,style,simplified_style,bpm,beat_type,time_signature,midi_filename,audio_filename,duration,split,start,end,times_sampled
0,drummer1,drummer1/eval_session,drummer1/eval_session/1,funk/groove1,funk,138,1,4-4,drummer1/eval_session/1_funk-groove1_138_beat_...,drummer1/eval_session/1_funk-groove1_138_beat_...,27.872308,test,0,2,1
1,drummer1,drummer1/eval_session,drummer1/eval_session/10,soul/groove10,funk,102,1,4-4,drummer1/eval_session/10_soul-groove10_102_bea...,drummer1/eval_session/10_soul-groove10_102_bea...,37.691158,test,0,2,1
2,drummer1,drummer1/eval_session,drummer1/eval_session/2,funk/groove2,funk,105,1,4-4,drummer1/eval_session/2_funk-groove2_105_beat_...,drummer1/eval_session/2_funk-groove2_105_beat_...,36.351218,test,0,2,1
3,drummer1,drummer1/eval_session,drummer1/eval_session/3,soul/groove3,funk,86,1,4-4,drummer1/eval_session/3_soul-groove3_86_beat_4...,drummer1/eval_session/3_soul-groove3_86_beat_4...,44.716543,test,0,2,1
4,drummer1,drummer1/eval_session,drummer1/eval_session/4,soul/groove4,funk,80,1,4-4,drummer1/eval_session/4_soul-groove4_80_beat_4...,drummer1/eval_session/4_soul-groove4_80_beat_4...,47.9875,test,0,2,1


In [786]:
def pick_area_to_sample(audio_filename):
    # get all samples of that file so far
    samples = data[data.audio_filename == audio_filename]
    
    if len(samples) == 0:
        return None
    
    if len(samples) == 1:
        return (samples.iloc[0].end, samples.iloc[0].duration)

    samples = samples.sort_values(by='start', ascending=True)
    
    # get all unused 'inbetween' areas, that are longer than 2 seconds
    free_areas = []
    for i in range(1, len(samples)):
        start_of_free_area = samples.iloc[i-1].end
        
        if i == len(samples) - 1:
            end_of_free_area = samples.iloc[i].duration
        else:
            end_of_free_area = samples.iloc[i].start
        
        if start_of_free_area + 2 < end_of_free_area:
            free_areas.append((start_of_free_area, end_of_free_area))
    
    # pick a random free area
    if len(free_areas) == 0:
        return None
    else:
        return free_areas[random.randint(0, len(free_areas)-1)]
    
pick_area_to_sample(data.audio_filename[0])

(2, 27.872308)

In [787]:
import random
random.seed(RANDOM_SEED)

data["possible_samples"] = (data.duration / SAMPLE_DURATION)
data.possible_samples = data.possible_samples - 1.0

for style in data.simplified_style.unique():
    samples = data[data.simplified_style == style]
    
    while len(samples) < SAMPLES_PER_GENRE:
        samples = data[data.simplified_style == style]
        possible_samples = samples[(samples.possible_samples >= 1.0) & (samples.times_sampled < MAX_SAMPLES_PER_FILE)]
        
        # pick random file
        index = random.randint(0, len(possible_samples)-1)
        filename = possible_samples.iloc[index].audio_filename
        
        # pick an unused area to sample 2 seconds from
        unused_area = pick_area_to_sample(possible_samples.iloc[index].audio_filename)
        if(unused_area == None):
            data.loc[(data.audio_filename == filename), "possible_samples"] = -1.0
            # try another file
            continue
        
        new_sample = possible_samples.iloc[index].copy()
        new_sample.start = random.uniform(unused_area[0], unused_area[1] - 2.0)
        new_sample.end = new_sample.start + 2.0
        #new_sample.times_sampled += 1
        data = pd.concat([data, new_sample.to_frame().T], ignore_index=True)
        
        # subtract 1 from the column possible_samples from all entries of a Dataframe where the column audio_filename is equal to 'test'
        data.loc[(data.audio_filename == filename), "possible_samples"] = data.loc[(data.audio_filename == filename), "possible_samples"] - 1.0
        data.loc[(data.audio_filename == filename), "times_sampled"] = data.loc[(data.audio_filename == filename), "times_sampled"] + 1.0

In [788]:
print(data.simplified_style.value_counts())

simplified_style
funk      201
hiphop    201
pop       201
rock      201
latin     201
jazz      201
Name: count, dtype: int64


In [789]:
# most sampled files
data.sort_values(by='times_sampled', ascending=False).audio_filename.unique()

array(['drummer7/session3/23_pop-soft_83_beat_4-4.wav',
       'drummer7/session3/22_pop-soft_83_beat_4-4.wav',
       'drummer1/session1/101_dance-disco_120_beat_4-4.wav',
       'drummer7/session2/80_country_78_beat_4-4.wav',
       'drummer7/session2/97_pop_142_beat_4-4.wav',
       'drummer7/session3/11_pop-soft_83_beat_4-4.wav',
       'drummer7/session2/96_pop_142_beat_4-4.wav',
       'drummer7/eval_session/7_pop-groove7_138_beat_4-4.wav',
       'drummer5/eval_session/7_pop-groove7_138_beat_4-4.wav',
       'drummer1/session2/10_country_114_beat_4-4.wav',
       'drummer5/session2/21_latin-brazilian-ijexa_108_beat_4-4.wav',
       'drummer1/session3/7_dance-disco_120_beat_4-4.wav',
       'drummer7/session2/104_pop_132_beat_4-4.wav',
       'drummer8/session1/14_hiphop_94_beat_4-4.wav',
       'drummer8/eval_session/7_pop-groove7_138_beat_4-4.wav',
       'drummer7/session2/100_pop_142_beat_4-4.wav',
       'drummer8/eval_session/6_hiphop-groove6_87_beat_4-4.wav',
       'drumm

In [790]:
data.times_sampled.value_counts().sort_index()

times_sampled
1       218
2.0     160
3.0      84
4.0      52
5.0     110
6.0      90
7.0      21
8.0      80
9.0      63
10.0     20
11.0    308
Name: count, dtype: int64

In [791]:
# almost completly sampled files
data[data.possible_samples == -1].audio_filename.unique()

array([], dtype=object)

In [792]:
# check if no two start positions exist in the same file
data.start.value_counts()

start
0.000000      426
59.590658       1
139.390476      1
29.908664       1
41.288381       1
             ... 
14.663295       1
16.618516       1
25.946823       1
6.014909        1
41.314932       1
Name: count, Length: 781, dtype: int64

In [793]:
data.to_csv('data_balanced.csv', index=False)