In [1]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import glob
import random
from data_loader import full_load_map, data_dir, load_map, Note
from concurrent.futures import ProcessPoolExecutor
import json
import shutil
import os
import traceback

In [2]:
random_seed = 1470258369

random.seed(random_seed)
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

In [3]:
def data_generator_multi_process(map_folders):
    map_folders = [map_folder.decode('UTF-8') for map_folder in map_folders]
    max_workers = 18
    items_in_queue = max_workers * 5
    queued_maps = items_in_queue
    cancel = False
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        map_tasks = [(executor.submit(full_load_map, map_folder), map_folder) for map_folder in map_folders[:items_in_queue]]
        while len(map_tasks) > 0:
            map_task, map_folder = map_tasks.pop(0)
            try:
                if cancel:
                    map_task.cancel()
                    continue
                results = map_task.result()
                for result in results:
                    x_context_prev_audio, x_context_prev_notes, x_context_audio, y_context_notes, z_timing_counts, z_note_counts, z_note_pos_counts, z_acc_prediction, z_speed_prediction = result
                    yield (x_context_prev_audio), (x_context_prev_notes), (x_context_audio), z_timing_counts, z_note_counts/20, z_note_pos_counts/10, z_acc_prediction, z_speed_prediction, (y_context_notes)
            except InterruptedError as ke:
                cancel = True
            except Exception as exc:
                if str(exc) != "'_version'" and str(exc) != 'not v2':
                    print(map_folder)
                    print(exc)
                    traceback.print_exc()
            finally:
                if not cancel:
                    queued_maps += 1
                    if queued_maps < len(map_folders):
                        map_tasks.append((executor.submit(full_load_map, map_folders[queued_maps]), map_folders[queued_maps]))

In [4]:
def create_ds_for_files(map_folders, batch_size, name, cache=False, shuffle=False):
    ds = tf.data.Dataset.from_generator(data_generator_multi_process, args=[map_folders], output_signature=(
        tf.TensorSpec(shape=(None, 2, 87, 129), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 2, 40, 25), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 2, 87, 129), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 1), dtype=tf.int32),
        tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 2, 40, 25), dtype=tf.float32),
        # tf.TensorSpec(shape=(None, 1025, 44), dtype=tf.float32),
        # tf.TensorSpec(shape=(None, 35), dtype=tf.float32),
    ))
    ds = ds.flat_map(lambda x1, x2, x3, x4, x5, x6, x7, x8, y: tf.data.Dataset.from_tensor_slices((x1, x2, x3, x4, x5, x6, x7, x8, y)))
    ds = ds.prefetch(20000)

    if cache:
        # ds = ds.cache()
        ds = ds.cache(f"./somethingsomething/{name}")
    if shuffle:
        ds = ds.shuffle(25000, reshuffle_each_iteration=True)
        # ds = ds.shuffle(len([v for v in ds]), reshuffle_each_iteration=True)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(256)
    return ds

In [5]:
# I currently cache the entire dataset, since the data loading part is quite compute intensive. Added a limit of 50 maps to avoid running out of ram on a test run.
maps = [path.replace("\\", "/") for path in glob.glob("../data/maps/*")]
random.shuffle(maps)
maps = maps[:50]

In [6]:
batch_size = 128
val_split = 0.1
train_ds = create_ds_for_files(maps[int(len(maps)*val_split):], batch_size, "train", True, True)
val_ds = create_ds_for_files(maps[:int(len(maps)*val_split)], batch_size, "val", True, False)

In [7]:
# preload the dataset into cache to keep the data loading errors away from the training logs
discard_val = 0
for v in tqdm(train_ds):
    discard_val = discard_val + 1

for v in tqdm(val_ds):
    discard_val = discard_val + 1

In [8]:
# Per position loss balance, might not be relevent with the updatd note format
# # with tf.device('/CPU:0'):
# out_value_counts = tf.constant([0]*49, dtype=tf.float32)
# for v_batch in tqdm(train_ds):
#     out_value_counts = out_value_counts + tf.reduce_sum(v_batch[-1], axis=[0, 1])
# # for i, ovc in enumerate(out_value_counts):
# #     print(f"{(i-1)} {ovc}")
# note_poss_loss_balance = tf.expand_dims(tf.expand_dims(1/tf.maximum(out_value_counts[1:] * (1/np.max(out_value_counts[1:])*10), 1), 0), 0)
# note_poss_loss_balance

In [9]:
# tf.config.run_functions_eagerly(False)

In [22]:
def audio_block():
    input_audio = tf.keras.Input(shape=(87, 129, 1), dtype="float32")
    l = input_audio
    l = tf.keras.layers.Conv2D(128, 5, activation="relu", padding="same")(l)
    l = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding="same")(l)
    l = tf.keras.layers.Conv2D(128, 3, activation="relu")(l)
    l = tf.keras.layers.MaxPooling2D(pool_size=(1, 2))(l)
    l = tf.keras.layers.Conv2D(128, 3, activation="relu")(l)
    l = tf.keras.layers.MaxPooling2D(pool_size=(1, 2))(l)
    l = tf.keras.layers.Reshape((40, -1))(l)
    l = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(128, activation="tanh"))(l)
    return tf.keras.Model(input_audio, l)

In [23]:
def audio_block_stereo():
    l_audio_block = audio_block()
    
    input_audio = tf.keras.Input(shape=(2, 87, 129, 1), dtype="float32")
    l = input_audio
    l = tf.keras.layers.TimeDistributed(l_audio_block)(l)
    l = tf.keras.layers.Reshape((40, -1))(l)
    l = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True))(l)
    l = tf.keras.layers.LSTM(64, return_sequences=True)(l)
    return tf.keras.Model(input_audio, l)

In [24]:
def note_positioning_block():
    input_timings = tf.keras.Input(shape=(2, 40, 1), dtype="float32")
    input_features = tf.keras.Input(shape=(40, 128), dtype="float32")
    
    l_timings = tf.keras.layers.Reshape((40, -1))(input_timings)

    l = tf.keras.layers.Concatenate(axis=2)([l_timings, input_features])
    l = tf.keras.layers.LSTM(256, return_sequences=True)(l)
    l = tf.keras.layers.LSTM(256, return_sequences=True)(l)
    l_pos_out = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(24, activation="sigmoid"))(l)
    l = tf.keras.layers.Concatenate(axis=2)([l_pos_out, l])
    l = tf.keras.layers.LSTM(128, return_sequences=True)(l)
    l = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(256, activation="relu"))(l)
    l_angle_out = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(24, activation="relu"))(l)
    
    l_pos_out = tf.keras.layers.Reshape((2, 40, -1))(l_pos_out)
    l_angle_out = tf.keras.layers.Reshape((2, 40, -1))(l_angle_out)
    
    return tf.keras.Model([input_timings, input_features], [l_pos_out, l_angle_out])

In [25]:
def make_model():
    input_prev_audio = tf.keras.Input(shape=(2, 87, 129, 1), dtype="float32")
    input_prev_notes = tf.keras.Input(shape=(2, 40, 25), dtype="float32")
    input_audio = tf.keras.Input(shape=(2, 87, 129, 1), dtype="float32")
    input_acc_prediction = tf.keras.Input(shape=(1), dtype="float32")
    input_speed_prediction = tf.keras.Input(shape=(1), dtype="float32")
    input_y_note_timings = tf.keras.Input(shape=(2, 40, 1), dtype="float32")

    audio_l = audio_block_stereo()

    l_prev_audio = audio_l(input_prev_audio)
    l_audio = audio_l(input_audio)
    
    l_prev_notes = tf.keras.layers.Reshape((40, -1))(input_prev_notes)
    
    l_prev = tf.keras.layers.Concatenate(axis=2)([l_prev_audio, l_prev_notes])
    l_prev = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True))(l_prev)
    l_prev = tf.keras.layers.LSTM(64)(l_prev)
    
    l_input_acc_prediction = tf.keras.layers.RepeatVector(40)(input_acc_prediction)
    l_input_speed_prediction = tf.keras.layers.RepeatVector(40)(input_speed_prediction)
    l_prev = tf.keras.layers.RepeatVector(40)(l_prev)
    l = tf.keras.layers.Concatenate(axis=2)([l_audio, l_prev, l_input_acc_prediction, l_input_speed_prediction])
    l = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128, return_sequences=True))(l)
    l = tf.keras.layers.LSTM(128, return_sequences=True)(l)
    l_timings_out = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(2, activation="sigmoid"))(l)
    l_timings_out = tf.keras.layers.Reshape((2, 40, -1))(l_timings_out)

    note_positioning_l = note_positioning_block()
    
    l_pos_out, l_angle_out = note_positioning_l([input_y_note_timings, l])

    model = tf.keras.Model(inputs = [input_prev_audio, input_prev_notes, input_audio, input_acc_prediction, input_speed_prediction, input_y_note_timings], outputs = [l_timings_out, l_pos_out, l_angle_out])
    return model

In [26]:
model = make_model()
model.summary()

Model: "model_7"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_13 (InputLayer)          [(None, 2, 87, 129,  0           []                               
                                 1)]                                                              
                                                                                                  
 input_11 (InputLayer)          [(None, 2, 87, 129,  0           []                               
                                 1)]                                                              
                                                                                                  
 input_12 (InputLayer)          [(None, 2, 40, 25)]  0           []                               
                                                                                            

In [27]:
# visualize the model
# tf.keras.utils.plot_model(model, show_shapes=True, show_layer_names=True, expand_nested=True)

In [59]:
optimizer = tf.keras.optimizers.Adam()

In [60]:
train_timing_loss_metric = tf.keras.metrics.Mean(name='train_timing_loss')
train_positioning_loss_metric = tf.keras.metrics.Mean(name='train_positioning_loss')
train_y_positioning_loss_metric = tf.keras.metrics.Mean(name='train_y_positioning_loss')
train_y_positioning_angle_loss_metric = tf.keras.metrics.Mean(name='train_y_positioning_angle_loss')
train_loss_metric = tf.keras.metrics.Mean(name='train_loss')

val_timing_loss_metric = tf.keras.metrics.Mean(name='val_timing_loss')
val_positioning_loss_metric = tf.keras.metrics.Mean(name='val_positioning_loss')
val_y_positioning_loss_metric = tf.keras.metrics.Mean(name='val_y_positioning_loss')
val_y_positioning_angle_loss_metric = tf.keras.metrics.Mean(name='val_y_positioning_angle_loss')
val_loss_metric = tf.keras.metrics.Mean(name='val_loss')

loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)

### Custom loss
Calculates 2 separate values:
 - timing loss - simple loss based on when the notes were placed
 - positioning loss - loss based on the position and direction of the placed note, adjusted for the number of notes that appear in different positions to avoid a massive bias towards placing most commonly appearing notes

In [65]:
@tf.function
def custom_loss(y, predictions):
    timing_predictions, pos_predictions, angle_predictions = predictions
    
    org_timing_loss_matrix = tf.square(y[:, :, :, :1] - timing_predictions) * (y[:, :, :, :1] + 0.069)
    timing_loss_matrix = org_timing_loss_matrix
    timing_loss = tf.reduce_mean(timing_loss_matrix) * 6.9
    
    # positioning_loss_matrix = tf.square(y[:, :, 1:] - positioning_predictions) * (y[:, :, :1]) * (y[:, :, 1:] * note_poss_loss_balance + 0.0169)
    # positioning_loss = tf.reduce_sum(positioning_loss_matrix) / tf.reduce_sum(y[:, :, 1:]) * 0.5
    
    y_positioning_loss_matrix = tf.square(y[:, :, :, 1::2] - pos_predictions) * (y[:, :, :, :1]) * (y[:, :, :, 1::2] + 0.069)
    y_positioning_loss = tf.reduce_sum(y_positioning_loss_matrix) / tf.reduce_sum(y[:, :, :, 1::2]) * 0.69
    
    y_positioning_angle_loss_matrix = tf.square(tf.minimum(tf.abs(y[:, :, :, 2::2] - angle_predictions), tf.minimum(tf.abs(y[:, :, :, 2::2] - angle_predictions + 1), tf.abs(y[:, :, :, 2::2] - angle_predictions - 1)))) * (y[:, :, :, 1::2])
    y_positioning_angle_loss = tf.reduce_sum(y_positioning_angle_loss_matrix) / tf.reduce_sum(y[:, :, :, 1::2]) * 6.9
    
    loss = timing_loss + y_positioning_loss + y_positioning_angle_loss
    return timing_loss, y_positioning_loss, y_positioning_angle_loss, loss

In [66]:
@tf.function
def train_step(model, optimizer, data):
    x1, x2, x3, x4, x5, x6, x7, x8, y = data
    
    with tf.GradientTape() as tape:
        predictions = model([x1, x2, x3, x7, x8, y[:, :, :, :1]], training=True)
        # predictions = model([x1, x2, x3, x4, x5, x6, x7, x8, y[:, :, :1]], training=True)
        # with tf.device('/CPU:0'):
        timing_loss, y_positioning_loss, y_positioning_angle_loss, loss = custom_loss(y, predictions)
        
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_timing_loss_metric(timing_loss)
    # train_positioning_loss_metric(positioning_loss)
    train_y_positioning_loss_metric(y_positioning_loss)
    train_y_positioning_angle_loss_metric(y_positioning_angle_loss)
    train_loss_metric(loss)

@tf.function
def val_step(model, data):
    x1, x2, x3, x4, x5, x6, x7, x8, y = data
    
    predictions = model([x1, x2, x3, x7, x8, y[:, :, :, :1]], training=False)
    # predictions = model([x1, x2, x3, x4, x5, x6, x7, x8, y[:, :, :1]], training=False)
    # with tf.device('/CPU:0'):
    timing_loss, y_positioning_loss, y_positioning_angle_loss, loss = custom_loss(y, predictions)
        
    val_timing_loss_metric(timing_loss)
    # val_positioning_loss_metric(positioning_loss)
    val_y_positioning_loss_metric(y_positioning_loss)
    val_y_positioning_angle_loss_metric(y_positioning_angle_loss)
    val_loss_metric(loss)

### Results validation

- specify the correct folder with maps for which you would want to generate the map
- add maps that you want to use for testing, better to avoid using the maps that already exist in the training dataset to avoid false positives of AI learning a specific map
- add an Expert diff if doesn't exist, currently hardcoded to just override the Expert diff to avoid setting up all the metadata

In [67]:
def full_validation(timing_threshhold, epoch, acc_prediction, speed_prediction):
    base_validation_path = "./validation"
    os.makedirs(base_validation_path, exist_ok=True)
    
    validation_map = "mazule"
    validation_map = "isaidthat_DA42AF71F4CA5AD280C3F69BCA0BD6C6D1CDA06E"
    validation_map = "DA42AF71F4CA5AD280C3F69BCA0BD6C6D1CDA06E"
    validation_map_dir = f"{base_validation_path}/{validation_map}"
    (song_data, segment_duration), diffs = load_map(validation_map_dir)
    
    
    angle_to_direction = {
        180:0,
        0:1,
        90:2,
        270:3,
        135:4,
        225:5,
        45:6,
        315:7
    }
    direction_to_angle = {
        0: 180,
        1: 0,
        2: 90,
        3: 270,
        4: 135,
        5: 225,
        6: 45,
        7: 315
    }


    def get_note_angle(direction):
        return direction_to_angle[direction] / 360
    
    def get_note_direction(angle):
        angle = int(angle * 360)
        angle = (angle - angle % 45) % 360
        return angle_to_direction[angle]

    def validate_model(song_data, segment_duration, timing_threshhold, positioning_threshhold, intensity_1, intensity_2, acc_prediction, speed_prediction, note_pos_count):
        context_length = 1
        prediction_note_count = context_length * 40
        prediction_note_time_length = context_length / prediction_note_count

        context_steps = int(context_length / segment_duration) + 1
        step_size = context_steps
        
        generated_notes = []
        max_val_timing = 0
        max_val_positioning = 0
        
        zero_notes = 1
        one_notes = 1
        
        prev_note_segment = ([[0]*25 for i in range(prediction_note_count)], [[0]*25 for i in range(prediction_note_count)])
        prev_audio_segment = song_data[:, :context_steps, :]
        with tqdm(range(context_steps, song_data.shape[1] - context_steps, step_size)) as _tqdm:
          for i in _tqdm:
            curr_time = i * segment_duration
            
            x_context_prev_audio = prev_audio_segment
            x_context_prev_notes = prev_note_segment
            x_context_audio = song_data[:, i:i+context_steps, :]
            timing_prediction, placement_prediction, placement_angle_prediction = model([np.array([x_context_prev_audio]), np.array([x_context_prev_notes]), np.array([x_context_audio]), np.array([acc_prediction + (random.random() - 0.5) * 0.1]), np.array([speed_prediction + (random.random() - 0.5) * 0.1]), np.array([[[[0]]*40, [[0]]*40]])], training=False)
            timing_prediction = tf.where(timing_prediction > timing_threshhold, 1, 0)
            _timing_prediction, placement_prediction, placement_angle_prediction = model([np.array([x_context_prev_audio]), np.array([x_context_prev_notes]), np.array([x_context_audio]), np.array([acc_prediction + (random.random() - 0.5) * 0.1]), np.array([speed_prediction + (random.random() - 0.5) * 0.1]), timing_prediction], training=False)
            timing_prediction = np.array(timing_prediction[0])
            placement_prediction = np.array(placement_prediction[0])
            placement_angle_prediction = np.array(placement_angle_prediction[0])
            
            
            x_context_prev_audio = x_context_audio
            prev_note_segment = ([[0]*25 for i in range(prediction_note_count)], [[0]*25 for i in range(prediction_note_count)])

            # I use them to find values that would generate a reasonable number of notes.
            # Small adjustments to the model and it's loss function can significantly shift the actual number values.
            # if max_val_timing < np.max(timing_prediction):
            #     max_val_timing = np.max(timing_prediction)
            #     print(f"max_timing: {max_val_timing}")
                
                # if max_val_positioning < np.max(positioning_prediction):
                #     max_val_positioning = np.max(positioning_prediction)
                #     print(f"max_positioning: {max_val_positioning}")
            
            for j in range(prediction_note_count):
                for color in range(2):
                    curr_note_time = curr_time + j * prediction_note_time_length
                    prediction_timing = timing_prediction[color][j][0]
                    if prediction_timing < timing_threshhold:
                        continue
                    prediction_positioning = placement_prediction[color][j]
                    prediction_angle = placement_angle_prediction[color][j]
                    
                    # prediction_positioning[:12] = prediction_positioning[:12] * (one_notes / (zero_notes + one_notes + 0.00001))
                    # prediction_positioning[12:] = prediction_positioning[12:] * (zero_notes / (zero_notes + one_notes + 0.00001))

                    # Place only 1 note per timing or to place many notes. More than 1 note can get super repetitive more easily, but both are of quite poor quality so far.
                    max_one_note_per_placement = True
                    if max_one_note_per_placement:
                        note_prediction_iter = np.argmax(prediction_positioning)
                        prediction_positioning_enumerated = [(note_prediction_iter, prediction_positioning[note_prediction_iter], prediction_angle[note_prediction_iter])]
                    else:
                        prediction_positioning_enumerated = [(i, note_prediction, prediction_angle[i]) for i, note_prediction in enumerate(prediction_positioning) if note_prediction > positioning_threshhold]

                    for note_prediction_iter, note_prediction, angle_prediction in prediction_positioning_enumerated:
                            line_layer = note_prediction_iter % 3
                            line_index = int(note_prediction_iter / 3) % 4
                            direction = get_note_direction(angle_prediction)
                            generated_notes.append(Note(curr_note_time, line_index, line_layer, color, direction))
                            if color == 0:
                                zero_notes += 1
                            else:
                                one_notes += 1
                            prev_note_segment[color][j][0] = 1
                            prev_note_segment[color][j][1 + note_prediction_iter * 2] = 1
                            prev_note_segment[color][j][1 + note_prediction_iter * 2 + 1] = get_note_angle(direction)
            if len(generated_notes) > 0:
                average_notes_per_second = len(generated_notes)/generated_notes[-1].time
            else:
                average_notes_per_second = -1
            _tqdm.set_postfix(average_notes_per_second=average_notes_per_second)
        
        generated_notes.sort(key=lambda note: note.time)
        return generated_notes

    intensity_timings_per_second = 7 # model input for number of correct timings per second
    intensity_notes_per_second = intensity_timings_per_second # model input for sum of '1's in the prediction segment. Increasing this value should result in more stacks and sliders.
    note_pos_count = 3
    generated_notes = validate_model(song_data, segment_duration, timing_threshhold=timing_threshhold, positioning_threshhold=0.45, intensity_1=intensity_timings_per_second, intensity_2=intensity_notes_per_second/20, acc_prediction=acc_prediction, speed_prediction=speed_prediction, note_pos_count=note_pos_count/10)

    if len(generated_notes) > 0:
        average_notes_per_second = len(generated_notes)/generated_notes[-1].time
    else:
        average_notes_per_second = -1
    
    with open(validation_map_dir + "/Info.dat", "rb") as f:
        info_json = json.load(f)
        bpm = info_json["_beatsPerMinute"]
        
    with open(validation_map_dir + "/ExpertStandard.dat", "rb") as f:
        diff_json = json.load(f)

    diff_json["_notes"] = [{"_time": note.time / 60 * bpm, "_lineIndex": int(note.lineIndex), "_lineLayer": int(note.lineLayer), "_type": int(note.type), "_cutDirection": int(note.direction)} for note in generated_notes]
    if len(diff_json["_notes"]) == 0:
        diff_json["_notes"] = [{"_time": 1, "_lineIndex": 0, "_lineLayer": 0, "_cutDirection": 0, "_type": 0}]
    with open(validation_map_dir + "/ExpertStandard.dat", "w") as f:
        json.dump(diff_json, f)
        
    shutil.make_archive(f"{validation_map_dir}q{epoch}q{timing_threshhold}q{average_notes_per_second}q{acc_prediction}q{speed_prediction}", 'zip', validation_map_dir)

In [68]:
for epoch in range(0, 15):
    train_timing_loss_metric.reset_states()
    train_positioning_loss_metric.reset_states()
    train_y_positioning_loss_metric.reset_states()
    train_y_positioning_angle_loss_metric.reset_states()
    train_loss_metric.reset_states()
    
    val_timing_loss_metric.reset_states()
    val_positioning_loss_metric.reset_states()
    val_y_positioning_loss_metric.reset_states()
    val_y_positioning_angle_loss_metric.reset_states()
    val_loss_metric.reset_states()

    if epoch == 0:
        optimizer.learning_rate.assign(0.001)
    elif epoch == 2:
        optimizer.learning_rate.assign(0.00025)
    elif epoch == 4:
        optimizer.learning_rate.assign(0.0001)
    elif epoch == 6:
        optimizer.learning_rate.assign(0.000025)
    elif epoch == 8:
        optimizer.learning_rate.assign(0.00001)
    elif epoch == 11:
        optimizer.learning_rate.assign(0.0000025)
    
    with tqdm(train_ds.enumerate(), unit="batch") as _tqdm:
        _tqdm.set_description(f"Epoch train: {epoch}")
        for step, data in _tqdm:
            train_step(model, optimizer, data)
            _tqdm.set_postfix(
                timing_loss=train_timing_loss_metric.result().numpy(),
                positioning_loss=train_positioning_loss_metric.result().numpy(),
                positioning_y_loss=train_y_positioning_loss_metric.result().numpy(),
                positioning_y_angle_loss=train_y_positioning_angle_loss_metric.result().numpy(),
                loss=train_loss_metric.result().numpy(),
            )
    
    if 'val_ds' in locals() or 'val_ds' in globals():
        with tqdm(val_ds.enumerate(), unit="batch") as _tqdm:
            _tqdm.set_description(f"Epoch val: {epoch}")
            for step, data in _tqdm:
                val_step(model, data)
                _tqdm.set_postfix(
                    timing_loss=val_timing_loss_metric.result().numpy(),
                    positioning_loss=val_positioning_loss_metric.result().numpy(),
                    positioning_y_loss=val_y_positioning_loss_metric.result().numpy(),
                    positioning_y_angle_loss=val_y_positioning_angle_loss_metric.result().numpy(),
                    loss=val_loss_metric.result().numpy(),
                )
    
    full_validation(0.825 + (random.random() - 0.5)*0.2, epoch, 0.775 + (random.random() - 0.5)*0.25, 0.35 + (random.random() - 0.5)*0.25)
    full_validation(0.825 + (random.random() - 0.5)*0.2, epoch, 0.775 + (random.random() - 0.5)*0.25, 0.35 + (random.random() - 0.5)*0.25)

Epoch train: 2: : 317batch [01:35,  4.91batch/s, loss=0.453, positioning_loss=0, positioning_y_angle_loss=0.176, positioning_y_loss=0.174, timing_loss=0.103]

Epoch train: 2: : 328batch [01:37,  4.64batch/s, loss=0.452, positioning_loss=0, positioning_y_angle_loss=0.176, positioning_y_loss=0.174, timing_loss=0.102]