In [1]:
import os
import shutil
from glob import glob
import numpy as np
import pandas as pd
from tqdm import tqdm
import wandb

import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
# from sklearn.metrics import confusion_matrix, classification_report, accuracy_score


# Organize dataset folders

### X_Dataset: split into bird_call and background

total : 20000  
bird_call : 10017  
background : 9983

In [None]:
# Sort every clip listed in the metadata CSV into a per-class folder
# (<dataRoot>/bird_call or <dataRoot>/background) according to its 'hasbird' flag.
dataRoot = "C:/Users/CHAM/Desktop/2021_class/에너지환경통계/EES_project/dataset"
metaDF = pd.read_csv(dataRoot + '/BirdVoxDCASE20k_csvpublic.csv')

# shutil.move does not create missing parent folders, so make them up front.
for class_dir in ("bird_call", "background"):
    os.makedirs(f"{dataRoot}/{class_dir}", exist_ok=True)

# Iterate over the two needed columns directly instead of building a Series per row.
for itemid, hasbird in zip(metaDF['itemid'], metaDF['hasbird']):
    src_path = f"{dataRoot}/wav/{itemid}.wav"
    if not os.path.exists(src_path):
        # Already moved by an earlier run of this cell (or absent from the wav
        # folder); skipping keeps the cell safe to re-run.
        continue

    class_dir = "bird_call" if hasbird == 1 else "background"
    dist_path = f"{dataRoot}/{class_dir}/{itemid}.wav"
    shutil.move(src_path, dist_path)

### X_Dataset: split into train and val

train  
bird_call : 8014   
background : 7986

val  
bird_call : 2003  
background : 1997

In [None]:
# Split each class folder 80/20 into <dataRoot>/train/<class> and <dataRoot>/val/<class>.
# glob returns files in an OS-dependent order; sorting makes the seeded draw below
# select the same validation files on every machine.
bird_call_list = sorted(glob(dataRoot + '/bird_call/*.wav'))
background_list = sorted(glob(dataRoot + '/background/*.wav'))

np.random.seed(len(bird_call_list))
val_bird_call = np.random.choice(bird_call_list, round(len(bird_call_list)*0.2), replace=False)

np.random.seed(len(background_list))
val_background = np.random.choice(background_list, round(len(background_list)*0.2), replace=False)


def _move_files(files, dest_dir):
    """Move every path in `files` into `dest_dir`, creating the folder if needed."""
    os.makedirs(dest_dir, exist_ok=True)
    for file in files:
        fileName = os.path.basename(file)
        shutil.move(file, dest_dir + '/' + fileName)


for category, all_files, val_files in (
    ('bird_call', bird_call_list, val_bird_call),
    ('background', background_list, val_background),
):
    held_out = set(val_files)
    train_files = [file for file in all_files if file not in held_out]

    _move_files(val_files, dataRoot + '/val/' + category)
    # The dataset class reads "<data_root>/<split>/<category>/*.wav", so the
    # remaining 80% must live under train/ rather than stay in the class folder.
    _move_files(train_files, dataRoot + '/train/' + category)


# Dataset

In [12]:
class BirdVoxDataset(Dataset):
    """Audio dataset yielding ``(mel_spectrogram, class_index)`` pairs.

    Files are discovered with the pattern
    ``{config['data_root']}/{split}/{category}/*.wav`` for every entry of
    ``config['categories']``. The class index of a file is the position of the
    category folder it was found in.

    Args:
        config: dict providing 'seg', 'categories', 'origin_Fs', 'Fs', 'n_fft',
            'n_mels' and 'data_root'.
        split: name of the sub-folder of the data root to read (e.g. 'train', 'val').
    """

    def __init__(self, config, split):
        self.seg = config['seg']
        self.categories = config['categories']
        self.num_class = len(self.categories)
        # Resample from the configured source rate, then compute a mel spectrogram.
        # NOTE(review): the sample rate reported by each file is not checked
        # against config['origin_Fs'] — confirm all clips share that rate.
        self.transforms = nn.Sequential(
            T.Resample(config['origin_Fs'], config['Fs']),
            T.MelSpectrogram(sample_rate=config['Fs'], n_fft=config['n_fft'], n_mels=config['n_mels'])
        )

        # Record the label while scanning each category folder. Deriving it later
        # by substring-matching the full path gives wrong labels whenever the data
        # root (or a file name) happens to contain a category name.
        self.file_list = []
        self.labels = []
        for label, category in enumerate(self.categories):
            # Sorted so that index -> file is stable across runs and platforms.
            paths = sorted(glob(f"{config['data_root']}/{split}/{category}/*.wav"))
            self.file_list.extend(paths)
            self.labels.extend([label] * len(paths))

    def __len__(self):
        """Return the number of audio files found for this split."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Load one clip and return ``(transformed_audio, label)``.

        Multi-channel clips are averaged to a single channel; the channel
        dimension is kept so mono and multi-channel files produce tensors of the
        same rank and can be batched together.
        """
        audio, sample_rate = torchaudio.load(self.file_list[index])
        if audio.size(0) > 1:
            audio = audio.mean(dim=0, keepdim=True)
        audio = self.transforms(audio)

        return audio, self.labels[index]

# Model

# Train

input size: torch.Size([16, 1, 64, 313]), labels: torch.Size([16])

In [8]:
def train():
    """Build the run configuration and the train/validation data loaders.

    Takes no arguments and returns nothing. The configuration is a plain dict
    (accessed with ``config['key']``), and the datasets are read from
    ``config['data_root']`` via ``BirdVoxDataset``.
    """
    # ============== Config Init ==============
    config = {
        "dataset" : "BirdVox-DCASE-20k",
        "data_root" : "C:/Users/CHAM/Desktop/2021_class/EES/EES_project/dataset",
        "categories" : ["background", "bird_call"],
        "seg" : 1,
        "origin_Fs" : 44100,
        "Fs" : 16000,
        "n_fft" : 1024,
        "n_mels" : 64,
        'model_name' : 'ResNet_attention',
        'drop_rate' : 0.1,
        'epochs': 50,
        'h_test' : 5,
        'batch_size': 16,
        'learning_rate': 0.01,
        'h_stepsize' : 2,
        'h_decay' : 0.95,
        # The comma after 'sgd' was missing, which made the whole dict a SyntaxError.
        'optimizer': 'sgd',
        'device' : torch.device("cuda" if torch.cuda.is_available() else "cpu")
    }

    #     wandb.init(config=config, project='EES_proejct', entity='jaekyeong')
    #     wandb.run.name = config_defaults['model_name']


    # ============== Data Load ==============
    train_set = BirdVoxDataset(config, split='train')
    train_loader = DataLoader(train_set, batch_size=config['batch_size'], num_workers=0, shuffle=True)

    val_set = BirdVoxDataset(config, split='val')
    # Shuffling adds nothing at evaluation time; a fixed order keeps validation
    # passes comparable from run to run.
    test_loader = DataLoader(val_set, batch_size=config['batch_size'], num_workers=0, shuffle=False)

In [13]:
# Run train() as defined above: it builds the config dict and both data loaders.
train()

input size: torch.Size([16, 1, 64, 313]), labels: torch.Size([16])


In [None]:
# ============== Model Init ==============
# NOTE(review): every statement in this cell is indented one level, so it looks like
# a continuation of a function body (presumably train()) and will not run as a
# standalone cell — confirm where it belongs.
# NOTE(review): the cell reads names that are not defined in the visible part of the
# notebook: models, device, trainset, testset, save_model, model_save_path and
# print_report. It also uses attribute access on config (config.model_name, ...),
# whereas train() builds a plain dict — presumably this expects wandb.config; verify.
    
    model_config = {
        # 'Simple_CNN' : models.CNN(config.num_classes).to(device)
        # 'CNN_v2' : models.CNN_v2(config.num_classes).to(device),
        # 'ResNet18' : models.resnet18().to(device),
        # 'ResNet50' : models.resnet50().to(device),
        # 'DenseNet201' : models.densenet201(drop_rate=float(config.drop_rate), num_classes=int(config.num_classes)).to(device),
        'ResNet_attention' : models.ResidualNet(depth=101, num_classes=trainset.num_class, att_type='CBAM').to(device)      # att_type : CBAM or BAM
    }
    # Only the 'ResNet_attention' entry is active; the model is looked up by name.
    model = model_config[config.model_name]
    model = nn.DataParallel(model)
    loss_fn = nn.CrossEntropyLoss()

   # Define the optimizer
   # NOTE(review): optim stays undefined if config.optimizer is neither 'sgd' nor 'adam'.
    if config.optimizer=='sgd':
      optim = torch.optim.SGD(model.parameters(),lr=config.learning_rate, momentum=0.9)
    elif config.optimizer=='adam':
      optim = torch.optim.Adam(model.parameters(),lr=config.learning_rate)
    
    # StepLR positional arguments are (optimizer, step_size, gamma).
    scheduler = torch.optim.lr_scheduler.StepLR(optim, config.h_stepsize, config.h_decay)

    wandb.watch(model)

    # ============== Train & Test ==============
    # Labels/predictions of the best-accuracy evaluation are kept for the final confusion matrix.
    best_labels, best_y_hat, best_acc = [], [], 0.
    for epoch in range(config.epochs):
        # ========> Train
        model.train()

        # total_loss is the SUM of the per-batch losses over the epoch (not a mean);
        # running_loss is assigned here but never used in this cell.
        total_loss, running_loss = 0.0,  0.0
        for batch_idx, data in enumerate(tqdm(train_loader, desc=f'{epoch + 1:2d}')):
            optim.zero_grad()
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)

            loss = loss_fn(outputs, labels)
            loss.backward()
            optim.step()
            total_loss += loss.item()
        
        # logging
        wandb.log({"Train Loss" : total_loss, "global_step" : epoch+1})
        print(f'[{(epoch+1):2d}/{config.epochs:4d}] loss: {total_loss:.6f}')
        
        # The learning-rate schedule advances once per epoch.
        scheduler.step()

        # ========> Test
        total_loss = 0.0
        # Evaluate on every epoch where (epoch + 1) is a multiple of config.h_test.
        if(epoch % config.h_test == config.h_test-1):
            model.eval()
            # save_model is not defined in the visible notebook — presumably it writes
            # a checkpoint; note it is called before the evaluation below. TODO confirm.
            save_model(model_save_path, config.model_name, model, epoch)
            
            y_hat_total, labels_total, total_loss = [], [], 0.0
            with torch.no_grad():
                for i, data in enumerate(test_loader):
                    inputs, labels = data[0].to(device), data[1].to(device)
                    labels_total += labels.tolist()
                    outputs = model(inputs)
                    loss = loss_fn(outputs, labels)
                    total_loss += loss.item()

                    # Predicted class = index of the largest output along dim 1.
                    y_hat = torch.argmax(outputs, dim=1)
                    y_hat_total += y_hat.tolist()
                
                # The return value of print_report is treated as the accuracy
                # (compared with best_acc, logged multiplied by 100).
                acc = print_report(labels_total, y_hat_total, total_loss, testset.categories)
       
                # Best Accuracy Check
                if acc > best_acc:
                    best_acc = acc
                    best_labels = labels_total
                    best_y_hat = y_hat_total
        
                wandb.log({"Test Acc": 100.*acc, 
                    "Test Loss": total_loss,
                    "learning_rate": optim.param_groups[0]['lr'], 
                    "global_step" : epoch+1})
            
    # After all epochs: log a confusion matrix built from the best-accuracy evaluation.
    wandb.log({"conf_mat" : wandb.sklearn.plot_confusion_matrix(best_labels, best_y_hat, testset.categories)})