In [1]:
import os
import sys
import shutil
sys.path.append(os.path.dirname(os.getcwd()))

In [2]:
from utils.audio_utils import convert_mp3_to_wav, get_audio_list, convert_mp3_to_wav
from utils.paths import p

In [3]:
from utils.ensemble import ensemble_files

In [4]:
from utils.inference import proc_folder
import logging
from pathlib import Path

In [5]:
# Dataset sub-folder name; later cells reuse it when building output paths.
dir_name = 'music'
mp3_root = p.Datasets / dir_name
# MP3 inputs as returned by get_audio_list for the dataset folder.
mp3_files = get_audio_list(mp3_root, audio_type='.mp3')

In [6]:
# Number of MP3 files handled per batch.
chunk_size = 100
# Ceiling division: how many batches are needed to cover every MP3 file.
num_chunks = -(-len(mp3_files) // chunk_size)

In [7]:
# Both values are forwarded unchanged to proc_folder in the batch loop.
# NOTE(review): device_ids presumably selects the GPU(s) to run on — confirm.
device_ids = [1]
# NOTE(review): presumably controls whether an instrumental stem is produced — confirm.
extract_instrumental = False

In [8]:
# Batch indices to process in this run: 16 through 29 inclusive.
chunk_list = list(range(16, 30))

In [9]:
# Running count of MP3 files handled so far; incremented by the batch loop.
total_processed_files: int = 0

In [10]:
# File-based logging for the batch run. filemode='w' truncates the log file
# when the configuration is applied.
log_file = p.Logs / 'process.log'
# Create the log directory up front: basicConfig opens the file immediately
# and fails if the parent directory is missing.
log_file.parent.mkdir(parents=True, exist_ok=True)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename=log_file, filemode='w')
logger = logging.getLogger(__name__)

In [11]:
def get_matching_file_pairs(root_dirs, audio_type='.wav'):
    """Group audio files that share a file name across the given directories.

    Each directory in ``root_dirs`` is searched recursively for files whose
    name ends with ``audio_type``. Files are grouped by base name, and only
    names that were found more than once are kept.

    Args:
        root_dirs: Iterable of directory paths (str or Path) to search.
        audio_type: File-name suffix to match, e.g. ``'.wav'``.

    Returns:
        A list of dicts, one per shared file name, of the form
        ``{'files': [<path str>, ...], 'output': <file name>}``.
    """
    pattern = f'*{audio_type}'
    paths_by_name = {}
    for directory in root_dirs:
        for found in Path(directory).rglob(pattern):
            paths_by_name.setdefault(found.name, []).append(str(found))

    pairs = []
    for name, matches in paths_by_name.items():
        # A name seen only once has nothing to be combined with.
        if len(matches) > 1:
            pairs.append({'files': matches, 'output': name})
    return pairs

In [12]:
# Per-batch pipeline: convert MP3 -> WAV, run both separation models on the
# WAVs, ensemble the two models' outputs into MP3, then delete intermediates.
for i in chunk_list:
    files = mp3_files[i*chunk_size:(i+1)*chunk_size]
    if not files:
        # chunk_list is hand-picked, so an index may point past the end of
        # mp3_files; there is nothing to convert or separate in that case.
        logger.warning(f"Batch {i}: No MP3 files in this batch, skipping")
        print(f"Batch {i}: No MP3 files in this batch, skipping")
        continue

    temp_dir = p.Datasets / 'wav' / dir_name / f'batch_{i}'
    temp_dir.mkdir(parents=True, exist_ok=True)

    convert_mp3_to_wav(files, temp_dir)
    logger.info(f"Batch {i}: Converted {len(files)} MP3 files to WAV")
    print(f"Batch {i}: Converted {len(files)} MP3 files to WAV")

    # Temporary WAV inputs of this batch. This name must not be reused below,
    # because the cleanup step after the ensemble iterates over it.
    wav_files = get_audio_list(temp_dir, audio_type='.wav')
    input_folder = str(temp_dir)

    # (model_type, config_path, checkpoint_path, output_dir) per model.
    model_configs = [
        ('htdemucs', 'Configs/htdemucs_config.yaml', 'Results/model_htdemucs_ep_8_sdr_12.8897.ckpt', f'Results/htdemucs/wav/{dir_name}/batch_{i}'),
        ('mdx23c', 'Configs/mdx23c_config.yaml', 'Results/mdx23c/model_mdx23c_ep_3_sdr_12.9580.ckpt', f'Results/mdx23c/wav/{dir_name}/batch_{i}')
    ]

    for model_type, config_path, start_check_point, store_dir in model_configs:
        store_dir = Path(store_dir)
        store_dir.mkdir(parents=True, exist_ok=True)
        proc_folder(model_type, config_path, start_check_point, input_folder, store_dir, device_ids, extract_instrumental)
        logger.info(f"Batch {i}: Processed with model {model_type}")
        print(f"Batch {i}: Processed with model {model_type}")

        # Drop the '_other' outputs so only the remaining stems are ensembled.
        # A separate name is used here so wav_files keeps the temp inputs.
        separated_files = get_audio_list(store_dir)
        for separated_file in separated_files:
            if '_other' in separated_file.name and separated_file.is_file():
                separated_file.unlink()
                logger.info(f"Deleted {separated_file} because it contains '_other'")

    # NOTE(review): these must resolve to the same directories as the
    # store_dir entries in model_configs above — confirm p.Results matches
    # the relative 'Results/' prefix.
    root_dirs = [p.Results / 'htdemucs' / 'wav' / dir_name / f'batch_{i}', p.Results / 'mdx23c' / 'wav' / dir_name / f'batch_{i}']

    file_pairs = get_matching_file_pairs(root_dirs)
    ensemble_files(file_pairs, algorithm='avg_wave', output_dir=p.Results/"ensemble"/f"batch_{i}", output_type='.mp3')
    logger.info(f"Batch {i}: Ensemble completed")
    print(f"Batch {i}: Ensemble completed")

    # Delete the temporary WAV files
    for wav_file in wav_files:
        file_path = temp_dir / wav_file
        if file_path.is_file():
            file_path.unlink()
            logger.info(f"Deleted {file_path}")

    # Delete temp_dir
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
        logger.info(f"Deleted directory {temp_dir}")

    # Delete the per-model output directories now that the ensemble is written.
    for root_dir in root_dirs:
        if root_dir.exists():
            shutil.rmtree(root_dir)
            logger.info(f"Deleted directory {root_dir}")

    total_processed_files += len(files)
    logger.info(f"Total processed files so far: {total_processed_files}")
    print(f"Total processed files so far: {total_processed_files}")

logger.info(f"Processing completed. Total processed files: {total_processed_files}")
print(f"Processing completed. Total processed files: {total_processed_files}")

Batch 16: Converted 100 MP3 files to WAV
Start from checkpoint: Results/model_htdemucs_ep_8_sdr_12.8897.ckpt
Instruments: ['vocals', 'other']
Total files found: 100


100%|██████████| 100/100 [23:42<00:00, 14.22s/it, track=임창정 - 그 사람을 아나요.wav]                    wav]


Elapsed time: 1423.48 sec
Batch 16: Processed with model htdemucs
Start from checkpoint: Results/mdx23c/model_mdx23c_ep_3_sdr_12.9580.ckpt
Instruments: ['vocals', 'other']
Total files found: 100


100%|██████████| 100/100 [1:35:45<00:00, 57.46s/it, track=임창정 - 그 사람을 아나요.wav]                  wav]  


Elapsed time: 5746.97 sec
Batch 16: Processed with model mdx23c


KeyboardInterrupt: 

In [13]:
# Manual completion of the batch whose loop iteration was interrupted:
# runs the ensemble and the cleanup steps for that batch only.
# Depends on `i` and `temp_dir` still being in the kernel from the loop cell.
root_dirs = [p.Results / 'htdemucs' / 'wav' / dir_name / f'batch_{i}', p.Results / 'mdx23c' / 'wav' / dir_name / f'batch_{i}']

file_pairs = get_matching_file_pairs(root_dirs)
ensemble_files(file_pairs, algorithm='avg_wave', output_dir=p.Results/"ensemble"/f"batch_{i}", output_type='.mp3')
logger.info(f"Batch {i}: Ensemble completed")
print(f"Batch {i}: Ensemble completed")

# Delete the temporary WAV files.
# Re-list them from temp_dir instead of trusting the leftover `wav_files`,
# which the loop cell may have rebound to a model's output listing.
temp_wav_files = get_audio_list(temp_dir, audio_type='.wav') if temp_dir.exists() else []
for wav_file in temp_wav_files:
    file_path = temp_dir / wav_file
    if file_path.is_file():
        file_path.unlink()
        logger.info(f"Deleted {file_path}")

# Delete temp_dir
if temp_dir.exists():
    shutil.rmtree(temp_dir)
    logger.info(f"Deleted directory {temp_dir}")

# Delete the per-model output directories now that the ensemble is written.
for root_dir in root_dirs:
    if root_dir.exists():
        shutil.rmtree(root_dir)
        logger.info(f"Deleted directory {root_dir}")


Processing files: 100%|██████████| 100/100 [10:34<00:00,  6.34s/it]


Batch 16: Ensemble completed
