In [None]:
# Import the autoreload extension
%load_ext autoreload

# Set autoreload to reload all modules (except those excluded by %aimport)
%autoreload 2

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0" #for GPU inference
from glob import glob
import pickle as pickle
from gen_utils import bass_trans_ev_model_tf, generate_bass_ev_trans_tf, create_onehot_enc
import numpy as np

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [35]:
# Segmentation / filtering limits shared by the inference cells below.
dec_seq_length = 773  # encoder inputs are zero-padded to this length; also passed to the model and generator
thr_min_tokens = 50   # segments with fewer tokens (header included) are skipped
thr_max_tokens = 800  # segments with more tokens (header included) are skipped
thr_measures = 16     # measures per segment; tracks with fewer measures are skipped
# NOTE(review): thr_max_tokens > dec_seq_length. If create_onehot_enc yields one
# entry per token, a 774-800 token segment is not padded and exceeds
# dec_seq_length -- confirm the model accepts that.

In [36]:
# Load the encoders pickle used for one-hot encoding.
# The encoders pickle is created during pre-processing.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
# os.path.join keeps the path valid on every OS; the raw Windows-style
# backslash string only resolved correctly on Windows.
encoders_trans = os.path.join('..', '..', 'data', 'processed', 'bass_encoders_cp.pickle')

with open(encoders_trans, 'rb') as handle:
    TransEncoders = pickle.load(handle)
# Per the original note, presumably [Encoder_RG, Decoder_Bass] -- not checked here.

In [41]:
# Load the inference transformer. You may download a pre-trained model based
# on the paper; see the instructions in ReadME.md.
trans_bass_hb = bass_trans_ev_model_tf(TransEncoders, dec_seq_length)

# Sampling temperature passed to every generation call below.
temperature = 0.9

# Input: text token files of the rhythm guitar only (4 header lines followed by
# measure tokens). Tracks longer than thr_measures measures are processed in
# thr_measures-measure segments by the inference cells below.
# sorted() makes the processing order deterministic (glob order is arbitrary).
inp_path = sorted(glob('./tokens_in/*.txt'))

# Output folder (kept as a string with a trailing '/': callers concatenate to it).
# Created up front so the later writes cannot fail on a missing directory.
out_path = './tokens_out/'
os.makedirs(out_path, exist_ok=True)

1929 1177
Loading Hybrid Music Transformer
Latest checkpoint restored!


## Inference on selected text token files

In [43]:
def get_measure_indices(lines):
    """Locate every measure boundary in a token list.

    Args:
        lines: iterable of already-stripped token strings.

    Returns:
        list[int]: ascending positions ``i`` for which the token equals
        ``"new_measure"``.
    """
    indices = []
    for position, token in enumerate(lines):
        if token == "new_measure":
            indices.append(position)
    return indices

for trk in inp_path:
    # Track name = file name without directory or extension. os.path uses the
    # platform's separator; splitting on '\\' only worked for Windows paths and
    # left the whole path inside the name elsewhere.
    trk_name = os.path.splitext(os.path.basename(trk))[0]
    print('Generating..', trk_name)

    # '<name>_rythmic' -> '<name>_with_bass_<temperature>.txt'. Without the
    # '_rythmic' marker the suffix is appended, so the output still ends in .txt.
    out_suffix = f'_with_bass_{temperature}.txt'
    if '_rythmic' in trk_name:
        save_path = out_path + trk_name.replace('_rythmic', out_suffix)
    else:
        save_path = out_path + trk_name + out_suffix

    # PREPROCESSING (get the sequence of tokens)
    with open(trk, 'r') as rg_file:
        rg_lines = [line.strip() for line in rg_file]

    # The first 4 lines are the header; everything after is measure content.
    rg_header = rg_lines[:4]
    rg_content = rg_lines[4:]

    # One entry per measure: the position of its 'new_measure' token in rg_content.
    rg_measures = get_measure_indices(rg_content)
    max_measures = len(rg_measures)

    if max_measures < thr_measures:
        print(f"Less than {thr_measures} measures in {trk_name}")
        continue

    generated_tokens = []
    first_segment = True  # the first *kept* segment also keeps its header tokens

    # Consecutive, non-overlapping segments of up to thr_measures measures (the
    # last one may be shorter); their generated outputs are concatenated.
    for start_idx in range(0, max_measures, thr_measures):
        end_idx = min(start_idx + thr_measures, max_measures)

        print(f"Processing measures {start_idx} - {end_idx}")

        # rg_measures has exactly max_measures entries, so the final segment has
        # no following 'new_measure' to stop at: it runs to the end of the
        # content. (Indexing rg_measures[end_idx] there raised IndexError.)
        seg_start = rg_measures[start_idx]
        seg_end = rg_measures[end_idx] if end_idx < max_measures else len(rg_content)
        sequence_input = rg_header + rg_content[seg_start:seg_end]

        # Ensure token count (header included) is within limits
        if not (thr_min_tokens <= len(sequence_input) <= thr_max_tokens):
            print(f"Token count {len(sequence_input)} not in range {thr_min_tokens} - {thr_max_tokens}")
            continue

        # Encode the tokens and zero-pad up to dec_seq_length (no padding is
        # added when the encoding is already that long or longer).
        Enc_Input = create_onehot_enc(sequence_input, TransEncoders)
        Enc_Input = Enc_Input + [0] * (dec_seq_length - len(Enc_Input))

        # Generate bass sequence
        bass_HB = generate_bass_ev_trans_tf(trans_bass_hb, TransEncoders, temperature, Enc_Input, dec_seq_length=dec_seq_length)

        if first_segment:
            generated_tokens.extend(bass_HB)  # keep the header in the first segment
            first_segment = False
        else:
            # Drop everything before the first 'new_measure' (the repeated header).
            # NOTE(review): trailing tokens of earlier segments (e.g. an end marker,
            # if the model emits one) stay mid-sequence -- confirm this is intended.
            first_measure_idx = next((i for i, token in enumerate(bass_HB) if token == "new_measure"), None)
            if first_measure_idx is None:
                print(f"No 'new_measure' token generated for measures {start_idx} - {end_idx}; segment dropped")
            else:
                generated_tokens.extend(bass_HB[first_measure_idx:])

    # Don't leave an empty output file when every segment was filtered out.
    if not generated_tokens:
        print(f"No segment of {trk_name} was generated; nothing saved")
        continue

    # Save the final sequence in a single file, one token per line
    with open(save_path, 'w') as f:
        f.writelines(f"{token}\n" for token in generated_tokens)

    print(f"Saved generated bass sequence to {save_path}")


Generating.. Herbie Hancock-Saturday Night_rythmic
Processing measures 0 - 16


  0%|          | 0/773 [00:06<?, ?it/s]


KeyboardInterrupt: 

In [None]:
def get_measure_indices(lines):
    """Return the ascending positions of all ``"new_measure"`` tokens in ``lines``.

    NOTE: same contract as the definition in the earlier inference cell; it is
    redefined here so this cell can run without that one.
    """
    return list(pos for pos, tok in enumerate(lines) if tok == "new_measure")

for trk in inp_path:
    # File name without directory or extension. os.path uses the platform's
    # separator; splitting on '\\' only worked for Windows paths.
    trk_name = os.path.splitext(os.path.basename(trk))[0]
    print('Generating..', trk_name)

    # PREPROCESSING (get the sequence of tokens)
    with open(trk, 'r') as rg_file:
        rg_lines = [line.strip() for line in rg_file]

    # The first 4 lines are the header; everything after is measure content.
    rg_header = rg_lines[:4]
    rg_content = rg_lines[4:]

    # One entry per measure: the position of its 'new_measure' token in rg_content.
    rg_measures = get_measure_indices(rg_content)
    max_measures = len(rg_measures)

    # If there aren't enough measures, skip this song
    if max_measures < thr_measures:
        print("Less than", thr_measures, "measures in", trk_name)
        continue

    # Consecutive, non-overlapping windows of exactly thr_measures measures, each
    # saved to its own file; a trailing partial window is not generated.
    for start_idx in range(0, max_measures, thr_measures):
        end_idx = start_idx + thr_measures

        print("Processing measures", start_idx, "-", end_idx)

        # end_idx == max_measures is still a complete window (the old `>=` test
        # dropped it, so a song of exactly thr_measures measures produced nothing).
        if end_idx > max_measures:
            break

        # Build the file name from trk_name for every window. The old code
        # overwrote save_path in place, so from the second window on every window
        # resolved to the first window's file and was skipped as already generated.
        out_suffix = '_with_bass_' + str(temperature) + '_' + str(start_idx) + '_' + str(end_idx) + '.txt'
        if '_rythmic' in trk_name:
            save_path = out_path + trk_name.replace('_rythmic', out_suffix)
        else:
            save_path = out_path + trk_name + out_suffix

        # Check before generating: decoding is slow, so don't redo finished windows.
        if os.path.exists(save_path):
            print('Already generated..', trk_name)
            continue

        # The last complete window has no following 'new_measure' to stop at,
        # so it runs to the end of the content.
        seg_end = rg_measures[end_idx] if end_idx < max_measures else len(rg_content)
        rg_sequence = rg_header + rg_content[rg_measures[start_idx]:seg_end]

        # Ensure the sequence (header included) has an acceptable number of tokens
        rg_token_count = len(rg_sequence)
        if not (thr_min_tokens <= rg_token_count <= thr_max_tokens):
            print("Token count", rg_token_count, "not in range", thr_min_tokens, "-", thr_max_tokens)
            continue

        # Encode the tokens, then zero-pad up to dec_seq_length (no padding is
        # added when the encoding is already that long or longer).
        Enc_Input = create_onehot_enc(rg_sequence, TransEncoders)
        Enc_Input = Enc_Input + [0] * (dec_seq_length - len(Enc_Input))

        # call generation functions
        bass_HB = generate_bass_ev_trans_tf(trans_bass_hb, TransEncoders, temperature, Enc_Input, dec_seq_length=dec_seq_length)

        # save token files to be passed to the tokens2gp5 algorithm
        with open(save_path, 'w') as f:
            for token in bass_HB:
                f.write("%s\n" % token)
    

Generating.. ACDC - Highway to Hell_rythmic
Processing measures 0 - 16


  0%|          | 0/773 [00:02<?, ?it/s]


KeyboardInterrupt: 

## Inference on the test set


In [26]:
# Pre-processed test set (pickle written during pre-processing).
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
# os.path.join keeps the path valid on every OS (the raw backslash path was Windows-only).
test_path = os.path.join('..', '..', 'data', 'processed', 'test_set_streams_16_8_800_50.pickle')

with open(test_path, 'rb') as handle:
    testSet = pickle.load(handle)

# Stack the per-sequence encoder inputs into one int64 array (one row per sequence).
enc_input_test = np.stack(testSet['Encoder_Input']).astype(np.int64)

# Decoding is very slow (the recorded progress bar shows ~7.6 s per step over
# 773 steps for a single sequence). Set to an int to run only the first N
# sequences; None runs the whole test set.
max_test_samples = None
test_inputs = enc_input_test if max_test_samples is None else enc_input_test[:max_test_samples]

output_test = []
for Enc_Input in test_inputs:
    # call generation functions
    bass_HB = generate_bass_ev_trans_tf(trans_bass_hb, TransEncoders, temperature, Enc_Input, dec_seq_length=dec_seq_length)
    output_test.append(bass_HB)

# output_test holds one generated token sequence per processed test-set row,
# in the same order as test_inputs.

  0%|          | 3/773 [00:22<1:37:42,  7.61s/it]


KeyboardInterrupt: 

In [34]:
# Print the layer-by-layer structure and parameter counts of the loaded model.
trans_bass_hb.summary()

Model: "hybrid_transformer_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 blstm_encoder_7 (BLSTMEncod  multiple                 16196432  
 er)                                                             
                                                                 
 word_decoder_7 (WordDecoder  multiple                 4457728   
 )                                                               
                                                                 
 dense_343 (Dense)           multiple                  227354    
                                                                 
Total params: 20,881,514
Trainable params: 20,881,514
Non-trainable params: 0
_________________________________________________________________
