In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import sklearn
from sklearn import mixture
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import time
import timedelta
import sys
import random
import cv2
import importlib
import argparse
import glob

# Project root holding the gnezdilnice_* helper modules -- change accordingly.
# os.path.join('') evaluates to '', and an empty entry on sys.path stands for the
# current working directory.
EOL = os.path.join('')
sys.path.append(EOL)

# Project-local modules, importable once EOL is on sys.path.
import gnezdilnice_train
import gnezdilnice_plotters
import gnezdilnice_dataloader
import gnezdilnice_models
import gnezdilnice_utils

# Widen the notebook cells to 80% of the browser window.
from IPython.display import display, HTML
_cell_width_css = "<style>.container { width:80% !important; }</style>"
display(HTML(_cell_width_css))

### Define paths

In [None]:
GNEZDILNICE = os.path.join(EOL)
# Each sub-folder path is passed through check_folder (presumably it makes sure the
# folder exists -- confirm in gnezdilnice_utils), and its return value becomes the constant.
DATA, RESULTS, OUTPUTS = (
    gnezdilnice_utils.check_folder(os.path.join(GNEZDILNICE, sub_folder))
    for sub_folder in ('data', 'results', 'outputs')
)

### Set up parameters

In [None]:
def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    argparse's ``type=bool`` calls ``bool(text)``, which is True for every non-empty
    string (including 'False'), and ``type=str`` keeps the raw text, which is truthy
    as well. Values that are already bools (the defaults) are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('--dataset', default='gnezdilnice_spectrograms_50-1450_128', type=str)
parser.add_argument('--model', default='resnet9', type=str)
parser.add_argument('--method', default='base', type=str)
parser.add_argument('--split', default='cv', type=str, help='type of experiment (cross-validation, loso)')
# Boolean switches are parsed with _str2bool so that "--faint False" really means False.
parser.add_argument('--faint', default=False, type=_str2bool, help='are faint buzzes included')
parser.add_argument('--high', default=False, type=_str2bool, help='are high buzzes included')
parser.add_argument('--seed_data', default=3, type=int, help='dataset seed when selecting folds')
parser.add_argument('--num_epochs', default=100, type=int)
parser.add_argument('--num_steps', default=100, type=int)
parser.add_argument('--batch_size', default=128, type=int, help='train batchsize')
parser.add_argument('--lr_max', default=0.01, type=float, help='maximum allowed learning rate')
parser.add_argument('--use_sched', default=True, type=_str2bool, help='whether to use learning rate scheduler')
parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay (L2 penalty)')
parser.add_argument('--grad_clip', default=0.1, type=float, help='gradient clipping to prevent exploding gradients')
parser.add_argument('--depth', default=0, type=int)
parser.add_argument('--DATA', default = DATA, type=str, help='path to data')
parser.add_argument('--RESULTS', default = RESULTS, type=str, help='path to results')
parser.add_argument('--seed_fix', default = 4, type=int, help='seed to fix the numpy, random, torch, and os random seeds')
# Dummy option: the Jupyter kernel is launched with "-f <connection file>", which
# parse_args() would otherwise reject as an unknown argument.
parser.add_argument('-f')
args = parser.parse_args()

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device set to: {device.type}')

### Load dataset

In [None]:
%%time
# NOTE: this overrides any --dataset value given on the command line.
args.dataset = 'gnezdilnice_spectrograms_50-1450_128'
DATASET = os.path.join(DATA, args.dataset + '.dat')
# Training data dictionary; later cells pass it to gnezdilnice_train.train_model.
dataset = gnezdilnice_utils.file2dict(DATASET)
print(f'Dataset {DATASET} has been loaded')

### Helper functions

In [None]:
def load_model(args, device):
    """Build a ResNet9 classifier and load the trained weights of the run described by ``args``.

    The weights are read from ``<results_dir(args)>/model.pth``.

    Args:
        args: experiment namespace; only used here to locate the results directory.
            NOTE(review): ``args.model`` is not consulted -- a ResNet9 is always built.
        device: torch device the model is moved to and the weights are mapped onto.

    Returns:
        The ``nn.DataParallel``-wrapped model in eval mode.
    """
    results_args = gnezdilnice_utils.results_dir(args)
    model_path = os.path.join(results_args, 'model.pth')
    model = gnezdilnice_models.ResNet9(num_classes = 2)
    # Wrap before loading: the checkpoint presumably stores the DataParallel
    # ('module.'-prefixed) state dict -- confirm against gnezdilnice_train.
    model = nn.DataParallel(model)  # sets to all available cuda devices
    model.to(device)
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model

def predict_buzz_probabilities_rec(data, seed_datas, device, data_min=-1.891314, batch_size=64):
    """Score half-filled windows of a recording with every per-seed model.

    Window ``i`` (for each spectrogram ``data[i]``) keeps the first 64 columns of
    ``data[i]`` and fills the remaining columns with ``data_min``. One extra, final
    window keeps columns 64 onward of the last spectrogram and fills its first 64
    columns with ``data_min``. All N + 1 windows are scored by each model returned by
    ``load_model`` for the seeds in ``seed_datas``. The class-1 (buzz) softmax
    probability of every window is appended twice to that seed's output list.

    NOTE(review): reads and mutates the module-level ``args`` (``args.seed_data`` is
    left at the last seed), because ``load_model`` locates the checkpoint from it.

    Args:
        data: array-like of spectrograms, assumed (n_spectrograms, height, width);
            the fixed split at column 64 presumably assumes width 128 -- TODO confirm.
        seed_datas: iterable of data seeds whose models form the ensemble.
        device: torch device the batches are moved to.
        data_min: padding value (hard-coded constant, described originally as the
            minimum of the whole dataset).
        batch_size: number of windows per forward pass (bounds memory use).

    Returns:
        (mean_array, std_array, per_seed): mean and std over seeds (arrays of length
        2 * (n_spectrograms + 1)) and the per-seed lists they were computed from.
    """
    data_rec = torch.as_tensor(np.asarray(data), dtype=torch.float32)
    n_rec, height, width = data_rec.shape
    split_col = 64

    windows = torch.full((n_rec + 1, height, width), data_min, dtype=torch.float32)
    windows[:-1, :, :split_col] = data_rec[:, :, :split_col]
    windows[-1, :, split_col:] = data_rec[-1, :, split_col:]
    windows = windows[:, None, :, :]  # add channel dim -> (n_rec + 1, 1, height, width)

    pred_proba_persec_allseeds = []
    for seed_data in seed_datas:
        args.seed_data = seed_data  # load_model builds the results path from args
        model = load_model(args, device)
        probs = []
        with torch.no_grad():
            # Iterate over all n_rec + 1 windows; slicing past the end is safe, so the
            # extra final window is always scored and no empty batch is produced.
            for start in range(0, windows.shape[0], batch_size):
                output = model(windows[start:start + batch_size].to(device))
                probs.extend(F.softmax(output, dim=1)[:, 1].cpu().numpy())  # buzz probability
        # Each window's probability is emitted twice in the per-seed trace.
        pred_proba_persec_allseeds.append([p for p in probs for _ in (0, 1)])

    mean_array = np.mean(pred_proba_persec_allseeds, axis=0)
    std_array = np.std(pred_proba_persec_allseeds, axis=0)
    return mean_array, std_array, pred_proba_persec_allseeds

def plot_buzz_probablity(pp_mean, pp_std, RECORDING, args, save=True):
    """Plot a recording's buzz probability trace with a +/-1 std band clipped to [0, 1].

    Args:
        pp_mean: sequence of mean buzz probabilities, one per time step.
        pp_std: sequence of standard deviations, same length as ``pp_mean``.
        RECORDING: recording name, used as the plot title and in the output filename.
        args: experiment namespace; its fields are embedded in the PDF filename.
        save: when True, write the figure as a PDF into ``OUTPUTS``.
    """
    color = 'orange'
    pp_mean = np.asarray(pp_mean, dtype=float)
    pp_std = np.asarray(pp_std, dtype=float)
    n_steps = len(pp_mean)

    # Draw each value at the centre of its step (i + 0.5) so that the mean line and its
    # band share the same x positions (the line used to be drawn half a step to the left).
    # Presumably one step per second, as the 'Time [s]' axis and 5-minute ticks imply.
    times = np.arange(n_steps) + 0.5
    upper = np.minimum(1, pp_mean + pp_std)
    lower = np.maximum(0, pp_mean - pp_std)

    fig, ax = plt.subplots(figsize=(25, 2))
    ax.plot(times, pp_mean, color=color)
    ax.fill_between(times, lower, upper, alpha=0.25, color=color)
    ax.set_xlim(left=0, right=n_steps + 1)
    ax.set_title(RECORDING)
    ax.set_ylabel('Buzz probability')
    ax.set_xlabel('Time [s]')
    ax.set_xticks(np.arange(0, n_steps, 60*5))  # one tick every 300 steps
    ax.grid()
    fig.tight_layout()
    if save:
        FN = os.path.join(OUTPUTS, f'buzzproba_{args.dataset}_{args.method}_fnt={args.faint}_high={args.high}_epochs={args.num_epochs}_bs={args.batch_size}_lrmax={args.lr_max}_{RECORDING}.pdf')
        fig.savefig(FN, dpi=600)
    plt.show()
    plt.close(fig)

def predict_labels(data, seed_datas, device, batch_size=64):
    """Predict buzz probabilities and 0/1 labels for each spectrogram with a seed ensemble.

    For every seed in ``seed_datas`` the model returned by ``load_model`` scores all
    spectrograms. The class-1 (buzz) softmax probabilities are averaged over seeds and
    thresholded at 0.5.

    NOTE(review): reads and mutates the module-level ``args`` (``args.seed_data`` is
    left at the last seed), because ``load_model`` locates the checkpoint from it.

    Args:
        data: array-like of spectrograms, assumed (n_spectrograms, height, width) --
            a channel axis is inserted before feeding the model.
        seed_datas: iterable of data seeds whose models form the ensemble.
        device: torch device the batches are moved to.
        batch_size: number of spectrograms per forward pass (bounds memory use).

    Returns:
        (mean_array, std_array, pred_array): per-spectrogram mean and std of the buzz
        probability over seeds, and a list of labels (1 where mean >= 0.5, else 0).
    """
    # Cast to float32 so float64 input does not clash with float32 model weights.
    inputs = torch.as_tensor(np.asarray(data), dtype=torch.float32)[:, None, :, :]
    per_seed_probs = []
    for seed_data in seed_datas:
        args.seed_data = seed_data  # load_model builds the results path from args
        model = load_model(args, device)
        probs = []
        with torch.no_grad():
            for start in range(0, inputs.shape[0], batch_size):
                output = model(inputs[start:start + batch_size].to(device))
                probs.extend(F.softmax(output, dim=1)[:, 1].cpu().numpy())  # buzz probability
        per_seed_probs.append(probs)
    mean_array = np.mean(per_seed_probs, axis=0)
    std_array = np.std(per_seed_probs, axis=0)
    pred_array = [0 if value < 0.5 else 1 for value in mean_array]
    return mean_array, std_array, pred_array

### Cross-validation approach

In [None]:
# Cross-validation: one training run per data seed with the settings below;
# seeds that experiment_already_done reports as finished are skipped.
cv_settings = {
    'split': 'cv',
    'batch_size': 64,
    'num_epochs': 10,
    'lr_max': 0.001,
    'faint': False,
    'high': False,
    'method': 'base',
}
for key, value in cv_settings.items():
    setattr(args, key, value)

for seed in [1, 2, 3, 4, 5]:
    args.seed_data = seed
    if not gnezdilnice_utils.experiment_already_done(args):
        gnezdilnice_train.train_model(args, dataset, device)
    else:
        print('Already done:', gnezdilnice_utils.results_dir(args))

### Leave-One-Location-Out Approach

In [None]:
# Leave-one-location-out: one training run per data seed; seeds that
# experiment_already_done reports as finished are skipped.
args.split = 'loc'

args.batch_size = 64
args.num_epochs = 10
args.lr_max = 0.001
args.faint = False
args.high = False
# Set explicitly so this cell does not depend on a value left in args by an earlier cell.
args.method = 'base'

for seed in [1, 2, 3]:
    args.seed_data = seed
    if gnezdilnice_utils.experiment_already_done(args):
        print('Already done:', gnezdilnice_utils.results_dir(args))
        continue
    gnezdilnice_train.train_model(args, dataset, device)

### Train on all data

In [None]:
# Final models: train on all data with hyperparameters selected from the CV experiments,
# one model per data seed. Like the CV and LOC cells, seeds that experiment_already_done
# reports as finished are skipped, so Restart & Run All does not retrain every model.
args.model = 'resnet9'
args.split = 'alltrain'

args.batch_size = 64
args.num_epochs = 10
args.lr_max = 0.001
args.faint = False
args.high = False

args.method = 'base'
for seed_data in [1, 2, 3, 4, 5]:
    args.seed_data = seed_data
    if gnezdilnice_utils.experiment_already_done(args):
        print('Already done:', gnezdilnice_utils.results_dir(args))
        continue
    gnezdilnice_train.train_model(args, dataset, device)

### Validation: make predictions on the validation data using the models trained on all data (one per data seed)

In [None]:
# Settings of the models trained on all data; they must match the training cell so that
# load_model resolves the same results directories.
args.dataset = 'gnezdilnice_spectrograms_50-1450_128'
args.model = 'resnet9'
args.split = 'alltrain'
args.batch_size = 64
args.num_epochs = 10
args.lr_max = 0.001
args.faint = False
args.high = False
args.method = 'base'
seed_datas = [1, 2, 3, 4, 5]

VALID_DATAS = [
    'gnezdilnice_validation-20240314140023-001_crop_spectrograms_50-1450_128.dat',
    'gnezdilnice_validation-20240314145532-002_crop_spectrograms_50-1450_128.dat',
]

for VALID_DATA in VALID_DATAS:
    valid_name = VALID_DATA.split(".dat")[0]
    DICT = os.path.join(OUTPUTS, f'output_{valid_name}.pkl')
    print(f'Getting output for: {valid_name}')
    DATASET = os.path.join(DATA, VALID_DATA)
    outputs = {}
    print(f'\t{DATASET}')
    # Use a separate name so the training `dataset` loaded earlier is not overwritten;
    # otherwise re-running a training cell would silently train on validation data.
    valid_dataset = gnezdilnice_utils.file2dict(DATASET)
    data = np.array(valid_dataset['data'])
    location_rec = np.array(valid_dataset['location'], dtype='str')[0]
    # The first 'datetime' entry is split on a space into date and time parts.
    datetime_parts = np.array(valid_dataset['datetime'], dtype='str')[0].split(' ')
    date_rec = datetime_parts[0]
    time_rec = datetime_parts[1]
    # Ensemble prediction over all seeds, then plot (not saved).
    pp_mean, pp_std, _ = predict_labels(data, seed_datas, device)
    plot_buzz_probablity(pp_mean, pp_std, valid_name, args, save=False)
    # One output file per validation recording, keyed by its date.
    outputs[date_rec] = {
        'location': location_rec,
        'time': time_rec,
        'buzz_mean': pp_mean,
        'buzz_std': pp_std,
    }
    gnezdilnice_utils.save_dict(outputs, DICT)
    print(f'Saved outputs to: {DICT}')