# In [None]:
import os
import time
import csv
import resource
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import librosa
import librosa.display
from IPython.display import Audio
from pytube import YouTube
from pydub import AudioSegment
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (
    Input, Conv1D, MaxPooling1D, Conv1DTranspose, Concatenate,
    Add, Activation, Multiply, Dropout, BatchNormalization
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import he_normal
from tensorflow.keras.callbacks import (
    EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, LearningRateScheduler
)


def load_data(
    dataset_path,
    duration=30,
    *,
    mixture_files=(
        'mixed_original.mp3',
        'mixed_with_noise.mp3',
        'mixed_with_pitch_shift.mp3',
        'mixed_with_speed_change.mp3',
    ),
    stem_files=('fiddle.mp3', 'flute.mp3', 'xylophone.mp3'),
):
    """Load ``(mixture, instruments)`` training pairs from a folder-per-song dataset.

    Every sub-directory of *dataset_path* that contains all of *mixture_files*
    and *stem_files* contributes one pair per mixture file; any other entry is
    skipped. Only the first *duration* seconds of each file are read. Pairs in
    which the mixture or any stem is not exactly ``duration * sr`` samples long
    are dropped.

    Sub-directories are visited in sorted order so the output, and therefore any
    seeded split of it, is reproducible (``os.listdir`` order is arbitrary).

    The sample rate is taken from the first mixture loaded, and every later file
    is loaded at that rate. All pairs therefore share one rate, and the returned
    ``sr`` is valid for all of them.

    NOTE(review): the pitch-shifted and speed-changed mixtures are paired with the
    unmodified stems. If those augmentations were applied to the mix only, input
    and targets are not aligned in time/pitch -- confirm this is intended.

    Args:
        dataset_path: Directory containing one sub-directory per song.
        duration: Seconds of audio to read from the start of each file.
        mixture_files: Mixture file names paired with the stems (keyword-only).
        stem_files: Stem file names. Their order defines the channel order of
            ``instruments`` (keyword-only).

    Returns:
        ``(tracks, sr)``. ``tracks`` is a list of ``(mixture, instruments)``
        tuples, where ``mixture`` has shape ``(duration * sr, 1)`` and
        ``instruments`` has shape ``(duration * sr, len(stem_files))``. ``sr``
        is the sample rate shared by all of them. Pairs from the same folder
        share one ``instruments`` array; treat it as read-only.

    Raises:
        ValueError: if no usable pair was found. The previous version failed
            with an UnboundLocalError on ``sr`` when no complete folder existed.
    """
    required_files = set(mixture_files) | set(stem_files)
    tracks = []
    sr = None  # None = native rate; fixed by the first load, then reused

    for folder in sorted(os.listdir(dataset_path)):
        folder_path = os.path.join(dataset_path, folder)
        if not os.path.isdir(folder_path):
            continue
        if not required_files.issubset(os.listdir(folder_path)):
            continue

        instruments = None
        for mixed_file in mixture_files:
            mixture, sr = librosa.load(
                os.path.join(folder_path, mixed_file), sr=sr, duration=duration
            )
            expected_len = duration * sr

            if instruments is None:
                # Load the stems once per folder, at the mixture's sample rate
                # (previously they were reloaded for every mixture variant).
                stems = [
                    librosa.load(
                        os.path.join(folder_path, name), sr=sr, duration=duration
                    )[0]
                    for name in stem_files
                ]
                if any(len(stem) != expected_len for stem in stems):
                    break  # short stems invalidate every pair from this folder
                instruments = np.stack(stems, axis=-1)

            if len(mixture) != expected_len:
                continue  # this mixture variant is too short; try the next one
            tracks.append((mixture.reshape(-1, 1), instruments))

    if not tracks:
        raise ValueError(
            f'No complete {duration}-second examples found under {dataset_path!r}'
        )
    return tracks, sr

# Root directory holding one sub-folder per mixture/stem set.
dataset_path = '/content/DSTHAI/content/DSTHAI'
tracks, sample_rate = load_data(dataset_path, duration=30)


def split_dataset(tracks, split_ratio=0.8):
    """Randomly partition *tracks* into a training list and a validation list.

    ``split_ratio`` is passed to scikit-learn as ``train_size``: a fraction of
    the items, or an absolute count if it is an int. The split is seeded with
    ``random_state=42``, so it repeats for the same input order.

    NOTE(review): ``load_data`` emits several mixture variants per song folder.
    Variants of one song can therefore end up on both sides of this split, which
    would likely inflate validation scores. A split grouped by folder would avoid
    this -- confirm whether the leakage is acceptable.
    """
    splits = train_test_split(tracks, train_size=split_ratio, random_state=42)
    train_part, val_part = splits
    return train_part, val_part


train_tracks, val_tracks = split_dataset(tracks, split_ratio=0.8)

def data_generator(tracks, batch_size, shuffle=False, seed=None):
    """Yield ``(mixtures, instruments)`` batches from *tracks* indefinitely.

    Each pass walks the tracks in consecutive slices of *batch_size*; the last
    slice of a pass may be smaller. After a pass the generator starts over. The
    consumer decides how many batches make up one epoch, e.g. ``model.fit``
    with ``steps_per_epoch``.

    Args:
        tracks: Sequence of ``(mixture, instruments)`` pairs. All mixtures must
            share one shape, and all instrument arrays must share another.
        batch_size: Maximum number of pairs per batch (must be >= 1).
        shuffle: If True, visit the tracks in a new random order on every pass.
            The default, False, keeps the fixed original order.
        seed: Seed for the shuffling RNG; ignored when *shuffle* is False.

    Yields:
        ``(mixtures, instruments)``: the pairs' arrays stacked along a new
        leading batch axis.

    Raises:
        ValueError: on the first ``next()`` if *tracks* is empty or *batch_size*
            is < 1. Previously an empty *tracks* made the generator loop forever
            without yielding.
    """
    n_tracks = len(tracks)
    if n_tracks == 0:
        raise ValueError('data_generator needs at least one track')
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    rng = np.random.default_rng(seed)
    while True:
        order = rng.permutation(n_tracks) if shuffle else range(n_tracks)
        for start in range(0, n_tracks, batch_size):
            batch = [tracks[i] for i in order[start:start + batch_size]]
            # np.stack fails loudly on mismatched shapes instead of building a
            # ragged object array.
            mixtures = np.stack([mixture for mixture, _ in batch])
            instruments = np.stack([stems for _, stems in batch])
            yield mixtures, instruments

# Endless batch generators for model.fit. The number of batches that forms one
# epoch is set by the steps_per_epoch / validation_steps values passed to fit.
batch_size = 12
train_generator, val_generator = (
    data_generator(split, batch_size) for split in (train_tracks, val_tracks)
)

def create_wave_unet_model(
    input_shape,
    *,
    encoder_filters=(64, 128, 256),
    bottleneck_filters=512,
    num_sources=3,
    down_kernel_size=15,
    up_kernel_size=5,
    pool_size=2,
):
    """Build a small 1-D U-Net (Wave-U-Net style) that maps a mixture to sources.

    Encoder: for each entry of *encoder_filters*, a ReLU ``Conv1D`` (kernel
    *down_kernel_size*, 'same' padding) followed by ``MaxPooling1D(pool_size)``.
    The pre-pooling activations are kept as skip connections.
    Bottleneck: a ReLU ``Conv1D`` with *bottleneck_filters* filters, also using
    *down_kernel_size*.
    Decoder: for each encoder level in reverse, a ReLU ``Conv1DTranspose``
    (kernel *up_kernel_size*, stride *pool_size*, He-normal init),
    concatenated with the matching skip connection.
    Output: a linear 1x1 ``Conv1D`` with *num_sources* channels.

    The defaults reproduce the original hard-coded architecture, and layers are
    created in the same order.

    Length constraint: ``MaxPooling1D`` uses 'valid' padding (it floors), while
    a strided 'same' ``Conv1DTranspose`` multiplies the length by the stride
    exactly. The time dimension must therefore be divisible by
    ``pool_size ** len(encoder_filters)`` (8 with the defaults); otherwise the
    skip concatenation sees mismatched lengths. With a variable-length input
    (``None``) this can only fail at run time.

    Args:
        input_shape: ``(time_steps, channels)``; ``time_steps`` may be None.
        encoder_filters: Filters per encoder level. The decoder uses them reversed.
        bottleneck_filters: Filters of the bottleneck convolution.
        num_sources: Number of output channels (one per separated source).
        down_kernel_size: Kernel size of the encoder and bottleneck convolutions.
        up_kernel_size: Kernel size of the transposed convolutions.
        pool_size: Pooling factor, also used as the upsampling stride.

    Returns:
        An uncompiled ``tf.keras.Model``.

    Raises:
        ValueError: if a static time dimension violates the length constraint.
    """
    depth = len(encoder_filters)
    factor = pool_size ** depth
    time_steps = input_shape[0]
    if time_steps is not None and time_steps % factor != 0:
        raise ValueError(
            f'input length {time_steps} must be divisible by {factor} '
            f'(pool_size ** number of encoder levels)'
        )

    def downsampling_block(x, filters, kernel_size, pool):
        conv = Conv1D(filters, kernel_size, activation='relu', padding='same')(x)
        downsampled = MaxPooling1D(pool)(conv)
        return downsampled, conv

    def upsampling_block(x, skip_connection, filters, kernel_size, upsample_size):
        upsampled = Conv1DTranspose(
            filters,
            kernel_size,
            strides=upsample_size,
            activation='relu',
            padding='same',
            kernel_initializer=he_normal(),
        )(x)
        return Concatenate()([upsampled, skip_connection])

    inputs = Input(shape=input_shape)
    x = inputs
    skip_connections = []

    # Encoder
    for filters in encoder_filters:
        x, skip = downsampling_block(x, filters, down_kernel_size, pool_size)
        skip_connections.append(skip)

    # Bottleneck
    x = Conv1D(bottleneck_filters, down_kernel_size, activation='relu', padding='same')(x)

    # Decoder: mirror the encoder, pairing each level with its skip connection.
    for filters, skip in zip(reversed(tuple(encoder_filters)), reversed(skip_connections)):
        x = upsampling_block(x, skip, filters, up_kernel_size, pool_size)

    outputs = Conv1D(num_sources, 1, activation='linear', padding='same')(x)

    return Model(inputs, outputs)

# The time axis is None so the model accepts clips of any length, with a
# single input channel.
input_shape = (None, 1)
model = create_wave_unet_model(input_shape)

adam_optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.004,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-07,
)
# MSE is the training objective; MAE is tracked as an extra metric. The CSV
# logging callback reads it back from the per-epoch logs.
model.compile(
    optimizer=adam_optimizer,
    loss='mse',
    metrics=[tf.keras.metrics.MeanAbsoluteError()],
)
model.summary()

class keepDataCSV(tf.keras.callbacks.Callback):
    """Keras callback that writes one row of training statistics per epoch.

    Rows go to ``<save_path>/training_stats.csv``. The directory and file are
    created when the callback is constructed. The file is flushed after every
    row, so the log survives an interrupted run.

    Columns: 1-based epoch number, training and validation loss, training and
    validation MAE, learning rate, wall-clock seconds since ``on_train_begin``,
    peak resident set size of this process, and user+system CPU seconds of this
    process.

    The file stays open across ``fit`` calls; whoever owns the callback closes
    ``csv_file``.

    Args:
        save_path: Directory for ``training_stats.csv``.
    """

    def __init__(self, save_path):
        # Initialise the Keras base class; the original skipped this.
        super().__init__()
        self.save_path = save_path
        self.start_time = None  # set in on_train_begin
        # Previously a missing directory raised FileNotFoundError here.
        os.makedirs(save_path, exist_ok=True)
        # The csv module requires newline='' on files it writes to.
        self.csv_file = open(
            os.path.join(save_path, 'training_stats.csv'), 'w', newline=''
        )
        self.csv_writer = csv.writer(self.csv_file)
        self.csv_writer.writerow(['Epoch Number', 'Training Loss', 'Validate Loss', 'MAE', 'Validate MAE', 'Learning Rate', 'Elapsed Time', 'RAM', 'CPU'])

    def on_train_begin(self, logs=None):
        """Record the wall-clock time at which training starts."""
        self.start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Append this epoch's metrics and resource usage to the CSV file."""
        logs = logs or {}  # replaces the shared mutable default ``logs={}``
        usage = resource.getrusage(resource.RUSAGE_SELF)  # one consistent snapshot
        # ru_maxrss is the peak RSS. It is in kilobytes on Linux (bytes on
        # macOS), so this column is GiB on Linux.
        ram_usage = usage.ru_maxrss / 1024 ** 2
        cpu_usage = usage.ru_utime + usage.ru_stime
        elapsed = None if self.start_time is None else time.time() - self.start_time

        self.csv_writer.writerow([
            epoch + 1,
            logs.get('loss'),
            logs.get('val_loss'),
            logs.get('mean_absolute_error'),
            logs.get('val_mean_absolute_error'),
            self._learning_rate(logs),
            elapsed,
            ram_usage,
            cpu_usage,
        ])
        self.csv_file.flush()

    def _learning_rate(self, logs):
        """Return the current learning rate, or None if it cannot be determined.

        First uses a value another callback put in *logs* ('lr', or presumably
        'learning_rate' in newer Keras). Otherwise reads the optimizer directly;
        the original only checked 'lr' and logged None without it.
        """
        for key in ('lr', 'learning_rate'):
            if logs.get(key) is not None:
                return logs[key]
        try:
            return float(np.asarray(self.model.optimizer.learning_rate))
        except (AttributeError, TypeError, ValueError):
            return None

# Directory that receives the per-epoch CSV log and the periodic checkpoints.
save_path = '/content/ModelComplete'
# Create it up front: both the CSV logger and ModelCheckpoint write into it.
os.makedirs(save_path, exist_ok=True)

# Per-epoch statistics logger. It gets its own name so the class keepDataCSV is
# not shadowed; the old rebinding made a re-run of this cell fail with
# "'keepDataCSV' object is not callable".
csv_logger = keepDataCSV(save_path)

epochs = 20  # Adjust as needed
# The generators yield a final partial batch, so one full pass over a split is
# ceil(len / batch_size) batches. Ceiling division keeps every epoch aligned to
# exactly one pass over the *training* split (the old code used all tracks).
# It also stays >= 1 when a split has fewer than batch_size tracks.
steps_per_epoch = (len(train_tracks) + batch_size - 1) // batch_size
validation_steps = (len(val_tracks) + batch_size - 1) // batch_size

# Save every 5 epochs. `period` is deprecated and rejected by newer Keras; an
# integer save_freq counts batches, so 5 epochs = 5 * steps_per_epoch batches.
checkpoint = ModelCheckpoint(
    os.path.join(save_path, 'model_{epoch:02d}.h5'),
    save_freq=5 * steps_per_epoch,
)

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.25, patience=3, min_lr=1e-6)

early_stopping = EarlyStopping(monitor='val_loss', patience=6, restore_best_weights=True)

try:
    model.fit(
        train_generator,
        epochs=epochs,
        # reduce_lr replaces the undefined `lr_scheduler` (a NameError). It is
        # listed before csv_logger so that any learning-rate entry it records in
        # the epoch logs is visible to the logger.
        callbacks=[early_stopping, reduce_lr, checkpoint, csv_logger],
        steps_per_epoch=steps_per_epoch,
        validation_data=val_generator,
        validation_steps=validation_steps,
    )
finally:
    # Close the CSV even if training is interrupted, so no rows are lost.
    csv_logger.csv_file.close()

model.save('model.h5')