In [1]:
from datasets import DatasetFactory, Category, Label
from datasets import SupportedSources as Sources
from datasets.SupportedSources import SupportedSourceTypes as SST
from utils.audio.features import get_feature_transformer, FeatureType
from utils import config
from models.panns import CNN10
from utils.pytorch import Trainer, Tester
from torch.utils.data import DataLoader
from torch import nn
import os
import utils.wlog as wlog

# Run configuration shared by the cells below.
epochs = 50  # number of train/evaluate iterations of the training loop
SR = 44100  # passed to create_dataset as target_sr (presumably Hz — confirm)
DURATION = 10  # passed to create_dataset as duration (presumably seconds — confirm)
BATCH_SIZE = 32  # batch size used for both the train and the test DataLoader
ACCURACY_THRESHOLD = 0.5  # a model file is saved only when test accuracy exceeds this
feature_type = FeatureType.MEL_SPECTROGRAM  # selects the transformer and its config.features entry


def get_extractor():
    """Create the feature transformer for the notebook-level ``feature_type``.

    The keyword arguments are looked up in ``config.features`` under the
    feature type's ``name``.

    Returns:
        Whatever ``get_feature_transformer`` returns for that configuration.
    """
    transformer_kwargs = config.features[feature_type.name]
    return get_feature_transformer(feature_type, **transformer_kwargs)

In [2]:
import wandb
import datetime
# Authenticate against Weights & Biases with the key from the environment.
# A missing WANDB_API_KEY raises KeyError here instead of failing mid-run.
wandb.login(key=os.environ['WANDB_API_KEY'], force=True)

# Take a single timestamp so the run name and the save directory always
# describe the same moment (two separate now() calls can straddle midnight).
run_started_at = datetime.datetime.now()
time_point = run_started_at.strftime("%Y%m%d_%H%M%S")

wandb.init(
    project='ESIL-Noise.Model-S',
    name=time_point,
)

# Model files saved by the training loop go under a per-day directory.
save_dir = f'training/{run_started_at.strftime("%Y-%m-%d")}'
# exist_ok=True avoids the check-then-create race and keeps the cell re-runnable.
os.makedirs(save_dir, exist_ok=True)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mreviy[0m ([33mesil[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\Vi\.netrc


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

In [3]:
# Child sources of the NATURE data source that are grouped under one
# combined label named '鸟叫'.
bird_call_names = [
    '北红尾鸲叫声', '叉尾太阳鸟叫声', '大鹰鹃叫声', '强脚树莺叫声',
    '普通夜鹰叫声', '棕颈钩嘴鹛叫声', '淡脚柳莺叫声',
]
bird_label = Label(
    name='鸟叫',
    sources=Sources.get_data_source(SST.NATURE).get_childs(bird_call_names),
)

# Two individual NATURE children plus every child of the TRAFFIC source.
thunder_source = Sources.get_data_source(SST.NATURE).get_child("雷声")
frog_source = Sources.get_data_source(SST.NATURE).get_child("蛙声")
traffic_sources = Sources.get_data_source(SST.TRAFFIC).childs

# NOTE: the following cell rebinds `category`, so this label set is only in
# effect when that cell is skipped.
labels = [thunder_source, frog_source, bird_label] + traffic_sources
category = Category('Noise', labels)

In [4]:
# NOTE(review): this rebinds `category`, replacing the label set built in the
# previous cell with only the first two children of the NATURE source.
# Looks like a quick-experiment override — confirm it is intended before a
# full training run.
category = Category("Noise", Sources.get_data_source(SST.NATURE).childs[:2])

In [5]:
# Audio settings shared by the train and test splits.
dataset_options = {"target_sr": SR, "duration": DURATION}

dataset_factory = DatasetFactory(category)

# get_extractor() is called once per split, as two separate calls.
train_data = dataset_factory.create_dataset(
    train=True, extractor=get_extractor(), **dataset_options
)
test_data = dataset_factory.create_dataset(
    train=False, extractor=get_extractor(), **dataset_options
)

# Log the category's label description to the run and derive the class
# count from it.
labels_info = category.labels_info
wandb.log({"Labels": labels_info})
num_classes = len(labels_info)

In [6]:

def get_dataloader():
    """Wrap the notebook-level datasets in ``DataLoader`` objects.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. Both loaders use
        ``BATCH_SIZE``; only the training loader shuffles.
    """
    return (
        DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True),
        DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False),
    )

def get_model(input_size):
    """Instantiate a ``CNN10`` for the notebook-level ``num_classes``.

    Args:
        input_size: Forwarded unchanged as ``CNN10``'s ``input_size``.

    Returns:
        The newly constructed ``CNN10`` instance.
    """
    return CNN10(num_class=num_classes, input_size=input_size)
    
def get_file_path(index):
    """Return the file name (no directory part) of a test-set audio item.

    Args:
        index: Item index handed to ``test_data._get_audio_path``.
    """
    audio_path = test_data._get_audio_path(index)
    return os.path.basename(audio_path)

def get_label_name(index):
    """Return the ``name`` of the label that ``category`` holds at ``index``."""
    label = category.get_label(index)
    return label.name

In [7]:
import tqdm

# Build the loaders, then size the model from one test sample.
train_dataloader, test_dataloader = get_dataloader()
# assumes test_data[0] is an item whose first element is the feature tensor,
# and that axis 2 of that tensor is what CNN10 expects as input_size —
# TODO confirm against the dataset and model definitions
input_size=test_data[0][0].shape[2]
print(f"{input_size=}")
model = get_model(input_size=input_size)
loss = nn.CrossEntropyLoss()
# The trainer is built with automatic mixed precision switched off.
trainer = Trainer(
    model, 
    train_dataloader,
    loss,
    using_amp=False
)
# Presumably restores saved model/optimizer state to resume training — confirm
# against Trainer.reload_trainer before enabling:
# trainer.reload_trainer(model_path='test_model.pth', optimizer_path='test_params.pth')


input_size=862


In [8]:
# metrics = None
# tester = Tester.from_trainer(trainer, test_dataloader, num_classes, accuracy=0.0 if metrics is None else metrics.accuracy)
# metrics, bad_cases = tester.test_an_epoch(get_file_path=get_file_path, tqdm_instance=None)

In [9]:
# No evaluation has run yet, so the first Tester is seeded with accuracy 0.0.
metrics = None

tqdm_instance = tqdm.tqdm(range(epochs))
for i in tqdm_instance:
    # One training pass; the progress bar is handed in so the trainer can
    # update it.
    train_loss = trainer.train_an_epoch(tqdm_instance=tqdm_instance)

    # Each epoch gets a fresh Tester seeded with the previous epoch's accuracy.
    previous_accuracy = 0.0 if metrics is None else metrics.accuracy
    tester = Tester.from_trainer(
        trainer, test_dataloader, num_classes, accuracy=previous_accuracy
    )
    metrics, bad_cases = tester.test_an_epoch(
        get_file_path=get_file_path,
        get_label_name=get_label_name,
        tqdm_instance=tqdm_instance,
    )

    # Save a model file for every epoch whose test accuracy clears the
    # threshold. The file name carries the zero-based loop index `i`, while
    # the logged "Epoch" value below is one-based.
    if metrics.accuracy > ACCURACY_THRESHOLD:
        tester.save_model(os.path.join(save_dir, f"model_{i}_{metrics.accuracy:.2f}.pth"))

    log_info = {"Epoch": i + 1, "train_loss": train_loss}
    # Only metrics that were actually computed (non-None) are logged.
    log_info.update(
        {key: value for key, value in metrics.model_dump().items() if value is not None}
    )
    # Log the bad-case table in the same call as the scalars: every separate
    # wandb.log() call advances the W&B step, so two calls per epoch would put
    # the table on a different step than the metrics it belongs to.
    log_info["bad_cases"] = wlog.df2table(bad_cases)
    wandb.log(log_info)

[valid] accuracy: 0.6000, loss: 0.6677: 100%|██████████| 1/1 [02:23<00:00, 143.32s/it]


In [10]:
# End the W&B run that was started with wandb.init() earlier in the notebook.
wandb.finish()

VBox(children=(Label(value='0.031 MB of 0.031 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Epoch,▁
accuracy,▁
auc,▁
f1_score_micro,▁
loss,▁
precision,▁
recall,▁
train_loss,▁

0,1
Epoch,1.0
accuracy,0.6
auc,0.69716
f1_score_micro,0.6
loss,0.66766
precision,0.72857
recall,0.6
train_loss,0.69909
