In [None]:
# Make sure that when no in docker, packages are installed
import os
RUN_PLATFORM = os.getenv("RUN_PLATFORM")
if RUN_PLATFORM == "LOCAL":
	!pip install -r requirements.txt
	data_path = "./data/"
elif RUN_PLATFORM == "KAGGLE":
	data_path = ""
metadata_path = os.path.join(self.auddata_pathio_dir, "train_metadata.csv")
audio_dir = os.path.join(self.auddata_pathio_dir, "train_audio")

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline
	# avoid rare bugs when image won't plot
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf

import pytorch_lightning as pl
from pytorch_lightning import Trainer

# Load the .env file into the process environment (wandb credentials etc.).
import os.path
if not os.path.isfile(".env"):
	raise FileNotFoundError(".env must be created by copying .env.template and filling out the values")
from dotenv import load_dotenv
# Load exactly the file checked above instead of letting python-dotenv search for one.
load_dotenv(".env")

import wandb
wandb_api_key = os.getenv("WANDB_API_KEY")
wandb_username = os.getenv("WANDB_USERNAME")
#wandb.login(key=wandb_api_key) # TODO remove comment when needed

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms

# make it not fail in terminal execution
from IPython.display import display
from IPython.display import Audio

from types import SimpleNamespace
	# make dict to class
import pathlib
import platform

from sklearn.model_selection import train_test_split

wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: asztrikx (asztrikx-budapesti-m-szaki-s-gazdas-gtudom-nyi-egyetem). Use `wandb login --relogin` to force relogin
wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\asztr\_netrc


## Hyperparameter configuration, set random states

In [None]:
# hyperparameters
import random

random_state = 42 # seed for every RNG below and for the train/val/test splits
test_val_size = 0.2 # fraction of the whole dataset held out for test and validation together
test_size = 0.5 # fraction of that hold-out used as the test set
num_workers = 2 # number of workers for data loading
batch_size = 128

eval_only = True
is_score_on = True
visualize_and_analyze = True

# Faster exec. on newer GPUs: lets float32 matmuls use a lower internal precision
torch.set_float32_matmul_precision('medium')

# Fix random early on: Python's own RNG as well as numpy and torch (CPU + all GPUs)
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)
torch.cuda.manual_seed_all(random_state);

## Custom Dataset

In [None]:
class BirdsDataset(Dataset):
	"""Map-style dataset over the rows of the bird-audio metadata CSV.

	Each item is ``(audio, label)``: ``audio`` is what ``audio_loader`` returns for
	``os.path.join(audio_dir, filename)`` (after ``transform``), and ``label`` is a dict
	with the row's ``primary_label``, ``latitude``, ``longitude`` and ``author``
	(after ``target_transform``).

	Args:
		metadata_csv: anything ``pd.read_csv`` accepts; must provide the columns
			``filename``, ``primary_label``, ``latitude``, ``longitude`` and ``author``.
		audio_dir: directory the ``filename`` entries are joined onto.
		transform: optional callable applied to the loaded audio.
		target_transform: optional callable applied to the label dict.
		audio_loader: callable ``path -> audio``. When ``None`` (the default),
			``__getitem__`` raises ``NotImplementedError`` because no audio backend
			has been chosen yet.
	"""

	# override
	def __init__(self, metadata_csv, audio_dir, transform = None, target_transform = None, audio_loader = None):
		self.df = pd.read_csv(metadata_csv)
		self.audio_dir = audio_dir
		self.transform = transform
		self.target_transform = target_transform
		self.audio_loader = audio_loader

	# override
	def __len__(self):
		return len(self.df)

	# override
	def __getitem__(self, index):
		# fetch the row once instead of re-indexing the frame for every field
		row = self.df.iloc[index]
		audio_path = os.path.join(self.audio_dir, row["filename"])
		if self.audio_loader is None:
			# TODO find library to load 32kHZ ogg audio and pass it as audio_loader
			raise NotImplementedError(
				f"BirdsDataset has no audio_loader configured; cannot load {audio_path!r}"
			)
		audio = self.audio_loader(audio_path)

		label = {
			"label": row["primary_label"],
			"latitude": row["latitude"],
			"longitude": row["longitude"],
			"author": row["author"],
		}

		if self.transform is not None:
			audio = self.transform(audio)
		if self.target_transform is not None:
			label = self.target_transform(label)
		return audio, label

## Data module

In [None]:
class BirdsDataModule(pl.LightningDataModule):
	"""LightningDataModule serving train/val/test splits of the bird-audio dataset.

	Relies on module-level configuration from earlier cells: ``metadata_path`` and
	``audio_dir`` (dataset location), ``test_val_size`` and ``test_size`` (split
	fractions), ``random_state`` (split seed) and ``num_workers`` (loader workers).

	Args:
		batch_size: number of samples per batch in every dataloader.
	"""

	def __init__(self, batch_size):
		super().__init__()
		self.batch_size = batch_size
		self.base_transform = transforms.Compose([
			# TODO
		])
		self.reverse_transform = transforms.Compose([
			# TODO
		])

		# The same data without and with base_transform applied.
		self.dataset_notransform = BirdsDataset(metadata_path, audio_dir, transform = None)
		self.dataset = BirdsDataset(metadata_path, audio_dir, transform = self.base_transform)

	# override
	def setup(self, stage=None):
		"""Split the dataset into train / validation / test subsets.

		Only row indices are split; the subsets are ``torch.utils.data.Subset`` views,
		so no audio is read here. ``test_val_size`` of the rows is held out, and
		``test_size`` of that hold-out becomes the test set, the rest validation.
		"""
		# Passing the Dataset object itself to train_test_split would make sklearn
		# index every item, i.e. load the whole dataset eagerly into Python lists.
		all_indices = np.arange(len(self.dataset))
		train_idx, test_val_idx = train_test_split(all_indices, test_size=test_val_size, random_state=random_state)
		# train_test_split returns (rest, held_out): the test_size share is the test set
		val_idx, test_idx = train_test_split(test_val_idx, test_size=test_size, random_state=random_state)
		self.train_dataset = torch.utils.data.Subset(self.dataset, train_idx.tolist())
		self.val_dataset = torch.utils.data.Subset(self.dataset, val_idx.tolist())
		self.test_dataset = torch.utils.data.Subset(self.dataset, test_idx.tolist())

	def _make_dataloader(self, dataset, shuffle, drop_last):
		"""Build a DataLoader with the settings shared by every split."""
		return DataLoader(
			dataset,
			batch_size=self.batch_size,
			num_workers=num_workers,
			pin_memory=True,
			drop_last=drop_last,
			shuffle=shuffle,
			# avoid recreating workers after every fast epoch; PyTorch rejects
			# persistent_workers=True when num_workers == 0
			persistent_workers=num_workers > 0,
		)

	# override
	def train_dataloader(self):
		# reshuffle every epoch; dropping the ragged last batch keeps batch sizes uniform
		return self._make_dataloader(self.train_dataset, shuffle=True, drop_last=True)

	# override
	def val_dataloader(self):
		# evaluate on every sample: no shuffling and no dropped last batch
		return self._make_dataloader(self.val_dataset, shuffle=False, drop_last=False)

	# override
	def test_dataloader(self):
		# evaluate on every sample: no shuffling and no dropped last batch
		return self._make_dataloader(self.test_dataset, shuffle=False, drop_last=False)

# Build the data module and run its hooks right away so the train/val/test splits
# exist on birds_dm for the following cells.
birds_dm = BirdsDataModule(batch_size=batch_size)
birds_dm.prepare_data()
birds_dm.setup(stage=None)

## Data visualization

## Data analysis

## Metrics

## Baseline

## Model