In [None]:
%load_ext autoreload 
%autoreload 2

import yaml
import json
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import DataLoader, ConcatDataset
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from fusanet_utils.datasets.external import get_label_transforms
from fusanet_utils.datasets.simulated import SimulatedPoliphonic
from fusanet_utils.datasets.fusa import FUSA_dataset
from fusanet_utils.transforms import Collate_and_transform
    
preds_soft, labels, distances, places, names = [], [], [], [], []
experiment_path = Path('../experiments/Poliphonic-PANN-sed-no-pretrained/')
print(experiment_path)
categories = json.load(open(str(experiment_path / 'index_to_name.json')))
model = torch.load(str(experiment_path / 'model.pt'))
model.eval()
params = yaml.safe_load(open(str(experiment_path / 'params.yaml')))

import torchaudio
import pandas as pd
from torch.utils.data import Dataset
from typing import Tuple

In [None]:
# SINGAPURA label taxonomy: category id -> {subcategory id -> class name}.
# Both ids are list positions below; the trailing comment gives the category id.
singapura_classes = {
    category_id: dict(enumerate(class_names))
    for category_id, class_names in enumerate([
        ['Other', 'Screeching', 'Plastic crinkling', 'Cleaning', 'Gear'],  # 0
        ['Engine (other)', 'Small engine', 'Medium engine', 'Large engine'],  # 1
        ['Machinery impact (other)', 'Rock drill', 'Jackhammer', 'Hoe ram', 'Pile driver'],  # 2
        ['Non-machinery impact (other)', 'Glass breaking', 'Car crash', 'Explosion'],  # 3
        ['Powered saw (other)', 'Chainsaw', 'Small/medium rotating saw', 'Large rotating saw'],  # 4
        ['Alert signal (other)', 'Car horn', 'Car alarm', 'Siren', 'Reverse beeper'],  # 5
        ['Music (other)', 'Stationary music', 'Mobile music'],  # 6
        ['Human voice (other)', 'Talking', 'Shouting', 'Large crowd', 'Amplified speech', 'Singing'],  # 7
        ['Human movement (other)', 'Footsteps', 'Clapping'],  # 8
        ['Animal (other)', 'Dog barking', 'Bird chirping', 'Insect chirping'],  # 9
        ['Water (other)', 'Hose pump'],  # 10
        ['Weather (other)', 'Rain', 'Thunder', 'Wind'],  # 11
        ['Brake (other)', 'Friction brake', 'Exhaust brake'],  # 12
        ['Train (other)', 'Electric train'],  # 13
    ])
}


def translate_singapura_classes(singapura_class):
    """Return the class name for a SINGAPURA label code such as ``'1-2'``.

    The code has the form ``'<category>-<subcategory>'``; both parts are parsed
    as ints and looked up in ``singapura_classes``. ``'13-2'`` has no entry in
    the table and is folded into subcategory 0 of category 13 ('Train (other)').
    Any other unknown code raises ``KeyError``.
    """
    parts = singapura_class.split('-')
    category, subcategory = int(parts[0]), int(parts[1])
    if (category, subcategory) == (13, 2):
        subcategory = 0
    return singapura_classes[category][subcategory]

In [None]:
class SINGAPURA(Dataset):
    """Index of SINGAPURA recordings paired with their event annotations.

    Every CSV under ``<dataset_path>/labels_public`` is read and its rows are
    grouped by ``filename``. Each group whose audio file exists under
    ``<dataset_path>/labelled/<subfolder>/`` becomes one item; missing files
    are printed and skipped, so ``file_list`` and ``label_list`` stay aligned.

    Items are ``(audio_path, labels)`` tuples. ``labels`` is a DataFrame with
    columns ``class``, ``proximity``, ``start (s)`` and ``end (s)``. ``class``
    holds ``label_transforms[name]`` for SINGAPURA class names present in the
    transforms and NaN for all others.
    """

    def __init__(self, categories, repo_path='../', dataset_path='../datasets/SINGAPURA'):
        """Build the file/label index.

        Args:
            categories: stored as ``self.categories``; not used while indexing.
            repo_path: repository root passed to ``get_label_transforms``.
            dataset_path: SINGAPURA root containing ``labels_public/`` and
                ``labelled/``.
        """
        label_transforms = get_label_transforms(repo_path, "Singapura")
        self.file_list, self.label_list = [], []
        self.categories = categories
        dataset_root = Path(dataset_path)
        csv_paths = list((dataset_root / 'labels_public').rglob('*.csv'))
        for csv_path in tqdm(csv_paths):
            annotations = pd.read_csv(csv_path)
            # One item per recording: labels are built inside this loop so every
            # appended file gets exactly the events annotated for it.
            for file_name, events in annotations.groupby('filename'):
                audio_path = dataset_root / 'labelled' / self._audio_subfolder(file_name) / file_name
                if not audio_path.exists():
                    print(audio_path)
                    continue
                self.file_list.append(audio_path)
                self.label_list.append(self._format_labels(events, label_transforms))

    @staticmethod
    def _audio_subfolder(file_name):
        """Return the ``labelled/`` sub-folder holding ``file_name``.

        It is the text after the first ``']['`` up to the next ``'T'``
        (presumably the recording date — confirm against the dataset layout).
        """
        return file_name.split('][')[1].split('T')[0]

    @staticmethod
    def _format_labels(events, label_transforms):
        """Convert one recording's raw annotation rows into the label frame.

        Selects and renames the event columns, translates SINGAPURA codes to
        class names, then replaces each name by ``label_transforms[name]``
        (NaN when the name is not in ``label_transforms``). ``events`` is not
        modified.
        """
        class_names = events["event_label"].map(translate_singapura_classes)
        known = class_names.apply(lambda name: name in label_transforms)
        encoded = class_names.loc[known].apply(lambda name: label_transforms[name])
        return (
            events[["event_label", "proximity", "onset", "offset"]]
            .rename(columns={"event_label": "class", "onset": "start (s)", "offset": "end (s)"})
            # assign aligns on the index, so rows dropped from `encoded` become NaN.
            .assign(**{"class": encoded})
        )

    def __getitem__(self, idx: int) -> Tuple[Path, pd.DataFrame]:
        """Return ``(audio_path, labels)`` for item ``idx``."""
        return (self.file_list[idx], self.label_list[idx])

    def __len__(self) -> int:
        """Number of indexed recordings."""
        return len(self.file_list)

In [None]:
# Index the SINGAPURA recordings, wrap them in the feature-extracting FUSA dataset
# and batch them using the feature parameters stored in the experiment's params.yaml.
dataset = SINGAPURA(list(categories.values()))

fusa_dataset = FUSA_dataset(
    ConcatDataset([dataset]),
    feature_params=params["features"],
)
fusa_loader = DataLoader(
    fusa_dataset,
    batch_size=10,
    shuffle=False,  # keep batch order aligned with dataset order
    num_workers=2,
    pin_memory=True,
    collate_fn=Collate_and_transform(params["features"]),
)

In [None]:
# Run the model over every batch and stack the per-batch outputs.
# The stacked arrays are appended to the accumulator lists created in the first
# cell, so re-running this cell appends a second copy of the results.
preds_model, labels_model, file_names = [], [], []
with torch.no_grad():
    for sample in tqdm(fusa_loader):
        # .cpu() is a no-op for CPU tensors and keeps .numpy() valid if the model
        # or collate function ever places tensors on a GPU.
        preds_model.append(model(sample).cpu().numpy())
        labels_model.append(sample['label'].cpu().numpy())
        file_names.append(sample['filename'])
preds_soft.append(np.concatenate(preds_model))
labels.append(np.concatenate(labels_model))
names.append(np.concatenate(file_names))