In [1]:
#
# A wrapper script that trains the SELDnet. The training stops when the early stopping metric - SELD error stops improving.
#

import os
import sys

sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plot
import cls_feature_class
import cls_data_generator
import seldnet_model
import parameters
import time
from time import gmtime, strftime
import torch
import torch.nn as nn
import torch.optim as optim
plot.switch_backend('agg')
from IPython import embed
from cls_compute_seld_results import ComputeSELDResults, reshape_3Dto2D
from SELD_evaluation_metrics import distance_between_cartesian_coordinates
import seldnet_model 



In [2]:
def get_accdoa_labels(accdoa_in, nb_classes):
    """Derive per-class activity flags from a single-track ACCDOA output.

    The last axis of ``accdoa_in`` is read as three consecutive groups: the
    first ``nb_classes`` entries are x, the next ``nb_classes`` are y and
    everything from ``2*nb_classes`` onwards is z.

    Args:
        accdoa_in:  3-D numpy array, indexed as [:, :, component*class]
        nb_classes: scalar, number of classes per Cartesian component
    Return:
        sed:        boolean array, True where the (x, y, z) vector length exceeds 0.5
        accdoa_in:  the input array itself, returned unchanged as the DOA output
    """
    x_comp = accdoa_in[:, :, :nb_classes]
    y_comp = accdoa_in[:, :, nb_classes:2*nb_classes]
    z_comp = accdoa_in[:, :, 2*nb_classes:]

    # A class counts as active when its vector norm is above 0.5.
    vector_norm = np.sqrt(x_comp**2 + y_comp**2 + z_comp**2)
    sed = vector_norm > 0.5

    return sed, accdoa_in


def get_multi_accdoa_labels(accdoa_in, nb_classes):
    """
    Split a three-track multi-ACCDOA output into per-track activity and DOA arrays.

    Args:
        accdoa_in:  [batch_size, frames, num_track*num_axis*num_class=3*3*12]
        nb_classes: scalar
    Return:
        (sed0, doa0, sed1, doa1, sed2, doa2) where
        sedX:       [batch_size, frames, num_class=12]
        doaX:       [batch_size, frames, num_axis*num_class=3*12]
    """
    nb_tracks = 3
    per_track = []
    for track in range(nb_tracks):
        # Each track occupies 3*nb_classes consecutive entries: x block, y block, z block.
        start = 3 * track * nb_classes
        # The final track is sliced to the end of the axis rather than to a fixed stop.
        stop = None if track == nb_tracks - 1 else start + 3 * nb_classes

        x_comp = accdoa_in[:, :, start:start + nb_classes]
        y_comp = accdoa_in[:, :, start + nb_classes:start + 2 * nb_classes]
        z_comp = accdoa_in[:, :, start + 2 * nb_classes:stop]

        # A class is active on this track when its vector norm is above 0.5.
        sed = np.sqrt(x_comp**2 + y_comp**2 + z_comp**2) > 0.5
        doa = accdoa_in[:, :, start:stop]
        per_track.extend((sed, doa))

    return tuple(per_track)


def determine_similar_location(sed_pred0, sed_pred1, doa_pred0, doa_pred1, class_cnt, thresh_unify, nb_classes):
    """Tell whether two tracks report the same class at (nearly) the same location.

    Args:
        sed_pred0, sed_pred1: activity of ``class_cnt`` on each of the two tracks
        doa_pred0, doa_pred1: per-frame DOA vectors of each track, indexed as
                              [class_cnt], [class_cnt + nb_classes], [class_cnt + 2*nb_classes]
        class_cnt:            class index being compared
        thresh_unify:         distance below which the two estimates count as similar
        nb_classes:           stride between the three coordinate groups
    Return:
        1 when both tracks are active and the value returned by
        distance_between_cartesian_coordinates is below ``thresh_unify``, else 0.
    """
    # Nothing to compare unless the class is active on both tracks.
    if not ((sed_pred0 == 1) and (sed_pred1 == 1)):
        return 0

    first_idx = class_cnt
    second_idx = class_cnt + 1*nb_classes
    third_idx = class_cnt + 2*nb_classes
    separation = distance_between_cartesian_coordinates(
        doa_pred0[first_idx], doa_pred0[second_idx], doa_pred0[third_idx],
        doa_pred1[first_idx], doa_pred1[second_idx], doa_pred1[third_idx])

    return 1 if separation < thresh_unify else 0


def test_epoch(data_generator, model, criterion, dcase_output_folder, params, device):
    # Number of frames for a 60 second audio with 100ms hop length = 600 frames
    # Number of frames in one batch (batch_size* sequence_length) consists of all the 600 frames above with zero padding in the remaining frames
    test_filelist = data_generator.get_filelist()

    nb_test_batches, test_loss = 0, 0.
    model.eval()
    file_cnt = 0
    with torch.no_grad():
        for data, target in data_generator.generate():
            # load one batch of data
            data, target = torch.tensor(data).to(device).float(), torch.tensor(target).to(device).float()

            # process the batch of data based on chosen mode
            output = model(data)
            loss = criterion(output, target)
            if params['multi_accdoa'] is True:
                sed_pred0, doa_pred0, sed_pred1, doa_pred1, sed_pred2, doa_pred2 = get_multi_accdoa_labels(output.detach().cpu().numpy(), params['unique_classes'])
                sed_pred0 = reshape_3Dto2D(sed_pred0)
                doa_pred0 = reshape_3Dto2D(doa_pred0)
                sed_pred1 = reshape_3Dto2D(sed_pred1)
                doa_pred1 = reshape_3Dto2D(doa_pred1)
                sed_pred2 = reshape_3Dto2D(sed_pred2)
                doa_pred2 = reshape_3Dto2D(doa_pred2)
            else:
                sed_pred, doa_pred = get_accdoa_labels(output.detach().cpu().numpy(), params['unique_classes'])
                sed_pred = reshape_3Dto2D(sed_pred)
                doa_pred = reshape_3Dto2D(doa_pred)

            # dump SELD results to the correspondin file
            output_file = os.path.join(dcase_output_folder, test_filelist[file_cnt].replace('.npy', '.csv'))
            file_cnt += 1
            output_dict = {}
            if params['multi_accdoa'] is True:
                for frame_cnt in range(sed_pred0.shape[0]):
                    for class_cnt in range(sed_pred0.shape[1]):
                        # determine whether track0 is similar to track1
                        flag_0sim1 = determine_similar_location(sed_pred0[frame_cnt][class_cnt], sed_pred1[frame_cnt][class_cnt], doa_pred0[frame_cnt], doa_pred1[frame_cnt], class_cnt, params['thresh_unify'], params['unique_classes'])
                        flag_1sim2 = determine_similar_location(sed_pred1[frame_cnt][class_cnt], sed_pred2[frame_cnt][class_cnt], doa_pred1[frame_cnt], doa_pred2[frame_cnt], class_cnt, params['thresh_unify'], params['unique_classes'])
                        flag_2sim0 = determine_similar_location(sed_pred2[frame_cnt][class_cnt], sed_pred0[frame_cnt][class_cnt], doa_pred2[frame_cnt], doa_pred0[frame_cnt], class_cnt, params['thresh_unify'], params['unique_classes'])
                        # unify or not unify according to flag
                        if flag_0sim1 + flag_1sim2 + flag_2sim0 == 0:
                            if sed_pred0[frame_cnt][class_cnt]>0.5:
                                if frame_cnt not in output_dict:
                                    output_dict[frame_cnt] = []
                                output_dict[frame_cnt].append([class_cnt, doa_pred0[frame_cnt][class_cnt], doa_pred0[frame_cnt][class_cnt+params['unique_classes']], doa_pred0[frame_cnt][class_cnt+2*params['unique_classes']]])
                            if sed_pred1[frame_cnt][class_cnt]>0.5:
                                if frame_cnt not in output_dict:
                                    output_dict[frame_cnt] = []
                                output_dict[frame_cnt].append([class_cnt, doa_pred1[frame_cnt][class_cnt], doa_pred1[frame_cnt][class_cnt+params['unique_classes']], doa_pred1[frame_cnt][class_cnt+2*params['unique_classes']]])
                            if sed_pred2[frame_cnt][class_cnt]>0.5:
                                if frame_cnt not in output_dict:
                                    output_dict[frame_cnt] = []
                                output_dict[frame_cnt].append([class_cnt, doa_pred2[frame_cnt][class_cnt], doa_pred2[frame_cnt][class_cnt+params['unique_classes']], doa_pred2[frame_cnt][class_cnt+2*params['unique_classes']]])
                        elif flag_0sim1 + flag_1sim2 + flag_2sim0 == 1:
                            if frame_cnt not in output_dict:
                                output_dict[frame_cnt] = []
                            if flag_0sim1:
                                if sed_pred2[frame_cnt][class_cnt]>0.5:
                                    output_dict[frame_cnt].append([class_cnt, doa_pred2[frame_cnt][class_cnt], doa_pred2[frame_cnt][class_cnt+params['unique_classes']], doa_pred2[frame_cnt][class_cnt+2*params['unique_classes']]])
                                doa_pred_fc = (doa_pred0[frame_cnt] + doa_pred1[frame_cnt]) / 2
                                output_dict[frame_cnt].append([class_cnt, doa_pred_fc[class_cnt], doa_pred_fc[class_cnt+params['unique_classes']], doa_pred_fc[class_cnt+2*params['unique_classes']]])
                            elif flag_1sim2:
                                if sed_pred0[frame_cnt][class_cnt]>0.5:
                                    output_dict[frame_cnt].append([class_cnt, doa_pred0[frame_cnt][class_cnt], doa_pred0[frame_cnt][class_cnt+params['unique_classes']], doa_pred0[frame_cnt][class_cnt+2*params['unique_classes']]])
                                doa_pred_fc = (doa_pred1[frame_cnt] + doa_pred2[frame_cnt]) / 2
                                output_dict[frame_cnt].append([class_cnt, doa_pred_fc[class_cnt], doa_pred_fc[class_cnt+params['unique_classes']], doa_pred_fc[class_cnt+2*params['unique_classes']]])
                            elif flag_2sim0:
                                if sed_pred1[frame_cnt][class_cnt]>0.5:
                                    output_dict[frame_cnt].append([class_cnt, doa_pred1[frame_cnt][class_cnt], doa_pred1[frame_cnt][class_cnt+params['unique_classes']], doa_pred1[frame_cnt][class_cnt+2*params['unique_classes']]])
                                doa_pred_fc = (doa_pred2[frame_cnt] + doa_pred0[frame_cnt]) / 2
                                output_dict[frame_cnt].append([class_cnt, doa_pred_fc[class_cnt], doa_pred_fc[class_cnt+params['unique_classes']], doa_pred_fc[class_cnt+2*params['unique_classes']]])
                        elif flag_0sim1 + flag_1sim2 + flag_2sim0 >= 2:
                            if frame_cnt not in output_dict:
                                output_dict[frame_cnt] = []
                            doa_pred_fc = (doa_pred0[frame_cnt] + doa_pred1[frame_cnt] + doa_pred2[frame_cnt]) / 3
                            output_dict[frame_cnt].append([class_cnt, doa_pred_fc[class_cnt], doa_pred_fc[class_cnt+params['unique_classes']], doa_pred_fc[class_cnt+2*params['unique_classes']]])
            else:
                for frame_cnt in range(sed_pred.shape[0]):
                    for class_cnt in range(sed_pred.shape[1]):
                        if sed_pred[frame_cnt][class_cnt]>0.5:
                            if frame_cnt not in output_dict:
                                output_dict[frame_cnt] = []
                            output_dict[frame_cnt].append([class_cnt, doa_pred[frame_cnt][class_cnt], doa_pred[frame_cnt][class_cnt+params['unique_classes']], doa_pred[frame_cnt][class_cnt+2*params['unique_classes']]]) 
            data_generator.write_output_format_file(output_file, output_dict)

            test_loss += loss.item()
            nb_test_batches += 1
            if params['quick_test'] and nb_test_batches == 4:
                break


        test_loss /= nb_test_batches

    return test_loss


def train_epoch(data_generator, optimizer, model, criterion, params, device):
    """Train ``model`` for one pass over ``data_generator`` and return the mean batch loss.

    Args:
        data_generator: object whose generate() yields (data, target) pairs accepted by torch.tensor
        optimizer:      optimiser; zero_grad() and step() are called once per batch
        model:          network, switched to train mode and called on each batch
        criterion:      loss callable taking (output, target)
        params:         dict; reads 'quick_test' (stop after 4 batches)
        device:         torch device the tensors are moved to
    Return:
        Mean loss over the processed batches, or NaN when the generator yielded no batch
        (previously this raised ZeroDivisionError).
    """
    nb_train_batches, train_loss = 0, 0.
    model.train()
    for data, target in data_generator.generate():
        # load one batch of data
        data, target = torch.tensor(data).to(device).float(), torch.tensor(target).to(device).float()
        optimizer.zero_grad()

        # process the batch of data based on chosen mode
        output = model(data)

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        nb_train_batches += 1
        if params['quick_test'] and nb_train_batches == 4:
            break

    # The mean of zero batches is undefined; report NaN instead of dividing by zero.
    if nb_train_batches == 0:
        return float('nan')

    return train_loss / nb_train_batches

In [3]:
# Stand-in for the command line of the original script: ['', <task-id>, <job-id>].
argv = ['', '1', 'test_data_load']
print(argv)
if len(argv) != 3:
    # Usage banner, emitted as one write; the text matches the line-by-line original.
    print('\n'.join([
        '\n\n',
        '-------------------------------------------------------------------------------------------------------',
        'The code expected two optional inputs',
        '\t>> python seld.py <task-id> <job-id>',
        '\t\t<task-id> is used to choose the user-defined parameter set from parameter.py',
        'Using default inputs for now',
        '\t\t<job-id> is a unique identifier which is used for output filenames (models, training plots). '
        'You can use any number or string for this.',
        '-------------------------------------------------------------------------------------------------------',
        '\n\n',
    ]))

# Pick the compute device.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# NOTE(review): anomaly detection is a debugging aid that slows autograd - confirm it should stay on.
torch.autograd.set_detect_anomaly(True)

print("Using GPU" if use_cuda else "Using CPU")

# use parameter set defined by user
task_id = argv[1] if len(argv) >= 2 else '1'
params = parameters.get_params(task_id)
# Force the short run regardless of what the parameter set says.
params['quick_test'] = True

job_id = argv[-1] if len(argv) >= 3 else 1

# Training setup: fold assignment depends on the year found in the dataset path.
train_splits = val_splits = test_splits = None
if params['mode'] == 'dev':
    if '2020' in params['dataset_dir']:
        train_splits, val_splits, test_splits = [[3, 4, 5, 6]], [2], [1]

    elif '2021' in params['dataset_dir']:
        train_splits, val_splits, test_splits = [[1, 2, 3, 4]], [5], [6]

    elif '2022' in params['dataset_dir']:
        train_splits, val_splits, test_splits = [[1, 2, 3]], [[4]], [[4]]

    else:
        # No known year in the path: fall back to the custom fold layout.
        print('Using Train 1, Val 2, Test 3,4')
        train_splits, val_splits, test_splits = [[1]], [[2]], [[3, 4]]

['', '1', 'test_data_load']
Using CPU
SET: 1
USING DEFAULT PARAMETERS

Loading data from /scratch/ci411/SELD_Datasets/TNR_0518
	quick_test: False
	finetune_mode: False
	pretrained_model_weights: None
	dataset_dir: /scratch/ci411/SELD_Datasets/TNR_0518
	feat_label_dir: /scratch/ci411/DCASE_GEN/seld_features/TNR_0518
	model_dir: models/
	dcase_output_dir: results/
	mode: dev
	dataset: mic
	fs: 24000
	hop_len_s: 0.02
	label_hop_len_s: 0.1
	max_audio_len_s: 60
	nb_mel_bins: 64
	use_salsalite: False
	fmin_doa_salsalite: 50
	fmax_doa_salsalite: 2000
	fmax_spectra_salsalite: 9000
	multi_accdoa: False
	thresh_unify: 15
	label_sequence_length: 50
	batch_size: 256
	dropout_rate: 0.05
	nb_cnn2d_filt: 64
	f_pool_size: [4, 4, 2]
	nb_rnn_layers: 2
	rnn_size: 128
	self_attn: False
	nb_heads: 4
	nb_fnn_layers: 1
	fnn_size: 128
	nb_epochs: 100
	lr: 0.001
	average: macro
	lad_doa_thresh: 20
	feature_sequence_length: 250
	t_pool_size: [5, 1, 1]
	patience: 100
	unique_classes: 14
Using Train 1, Val 2, Tes

In [4]:
# The original script looped over enumerate(test_splits); this notebook handles the first split only.
split_cnt = 0
split = test_splits[0]

# Unique name for the run: the feature tag is only refined for the 'mic' dataset.
if params['dataset'] == 'mic':
    if params['use_salsalite']:
        loc_feat = '{}_salsa'.format(params['dataset'])
    else:
        loc_feat = '{}_gcc'.format(params['dataset'])
else:
    loc_feat = params['dataset']
loc_output = 'accdoa' if not params['multi_accdoa'] else 'multiaccdoa'

cls_feature_class.create_folder(params['model_dir'])
unique_name = '{}_{}_{}_split{}_{}_{}'.format(
    task_id,
    job_id,
    params['mode'],
    split_cnt,
    loc_output,
    loc_feat,
)
model_name = os.path.join(params['model_dir'], unique_name) + '_model.h5'
print("unique_name: {}\n".format(unique_name))

# Load train and validation data
print('Loading training dataset:')
data_gen_train = cls_data_generator.DataGenerator(params=params, split=train_splits[split_cnt])

print('Loading validation dataset:')
data_gen_val = cls_data_generator.DataGenerator(
    params=params,
    split=val_splits[split_cnt],
    shuffle=False,
    per_file=True,
)

# Collect i/o data size and load model configuration
data_in, data_out = data_gen_train.get_data_sizes()
model = seldnet_model.CRNN(data_in, data_out, params).to(device)

if params['finetune_mode']:
    print('Running in finetuning mode. Initializing the model to the weights - {}'.format(params['pretrained_model_weights']))
    pretrained_state = torch.load(params['pretrained_model_weights'], map_location='cpu')
    model.load_state_dict(pretrained_state)
    del pretrained_state  # keep the notebook namespace identical to the original cell

# Dump results in DCASE output format for calculating final scores.
# The UTC timestamp keeps folders of repeated runs apart.
dcase_output_val_folder = os.path.join(
    params['dcase_output_dir'],
    '{}_{}_val'.format(unique_name, strftime("%Y%m%d%H%M%S", gmtime())),
)
cls_feature_class.delete_and_create_folder(dcase_output_val_folder)
print('Dumping recording-wise val results in: {}'.format(dcase_output_val_folder))

# Initialize evaluation metric class
score_obj = ComputeSELDResults(params)

# Training bookkeeping, initialised to the starting values of each tracked metric.
best_val_epoch = -1
best_ER = 1.
best_F = 0.
best_LE = 180.
best_LR = 0.
best_seld_scr = 9999
patience_cnt = 0

# quick_test shortens the run to two epochs.
nb_epoch = params['nb_epochs'] if not params['quick_test'] else 2
optimizer = optim.Adam(model.parameters(), lr=params['lr'])
# Multi-track output uses the project's ADPIT loss; single-track uses plain MSE.
criterion = seldnet_model.MSELoss_ADPIT() if params['multi_accdoa'] is True else nn.MSELoss()


unique_name: 1_test_data_load_dev_split0_accdoa_mic_gcc

Loading training dataset:
Computing some stats about the dataset in /scratch/ci411/DCASE_GEN/seld_features/TNR_0518/mic_dev_norm
	Datagen_mode: dev, nb_files: 900, nb_classes:14
	nb_frames_file: 3000, feat_len: 64, nb_ch: 10, label_len:56

	Dataset: mic, split: [1]
	batch_size: 256, feat_seq_len: 250, label_seq_len: 50, shuffle: True
	Total batches in dataset: 42
	label_dir: /scratch/ci411/DCASE_GEN/seld_features/TNR_0518/mic_dev_label
 	feat_dir: /scratch/ci411/DCASE_GEN/seld_features/TNR_0518/mic_dev_norm

Loading validation dataset:
Computing some stats about the dataset in /scratch/ci411/DCASE_GEN/seld_features/TNR_0518/mic_dev_norm
	Datagen_mode: dev, nb_files: 300, nb_classes:14
	nb_frames_file: 3000, feat_len: 64, nb_ch: 10, label_len:56

	Dataset: mic, split: [2]
	batch_size: 12, feat_seq_len: 250, label_seq_len: 50, shuffle: False
	Total batches in dataset: 300
	label_dir: /scratch/ci411/DCASE_GEN/seld_features/TNR_0518/

In [10]:
# Folder that receives the per-recording test predictions in DCASE output format.
# The UTC timestamp in the name keeps repeated runs apart.
dcase_output_test_folder = os.path.join(
    params['dcase_output_dir'],
    '{}_{}_test'.format(unique_name, strftime("%Y%m%d%H%M%S", gmtime())),
)
cls_feature_class.delete_and_create_folder(dcase_output_test_folder)
print('Dumping recording-wise test results in: {}'.format(dcase_output_test_folder))

# Test data is read in a fixed order, with the generator's per_file option enabled.
data_gen_test = cls_data_generator.DataGenerator(
    params=params,
    split=test_splits[split_cnt],
    shuffle=False,
    per_file=True,
)



Dumping recording-wise test results in: results/1_test_data_load_dev_split0_accdoa_mic_gcc_20230524165022_test
Computing some stats about the dataset in /scratch/ci411/DCASE_GEN/seld_features/TNR_0518/mic_dev_norm
	Datagen_mode: dev, nb_files: 121, nb_classes:14
	nb_frames_file: 17957, feat_len: 64, nb_ch: 10, label_len:56

	Dataset: mic, split: [3, 4]
	batch_size: 72, feat_seq_len: 250, label_seq_len: 50, shuffle: False
	Total batches in dataset: 121
	label_dir: /scratch/ci411/DCASE_GEN/seld_features/TNR_0518/mic_dev_label
 	feat_dir: /scratch/ci411/DCASE_GEN/seld_features/TNR_0518/mic_dev_norm



In [12]:
# The scorer's jackknife option is left off for this run.
use_jackknife = False

# NOTE(review): no training loop is visible between model creation and this call -
# confirm the model holds the weights you intend to evaluate.
test_loss = test_epoch(
    data_gen_test, model, criterion, dcase_output_test_folder, params, device
)

# Score the predictions that test_epoch wrote to dcase_output_test_folder.
(test_ER, test_F, test_LE, test_LR,
 test_seld_scr, classwise_test_scr) = score_obj.get_SELD_Results(
    dcase_output_test_folder, is_jackknife=use_jackknife
)

In [None]:
for sample