# EMSN 2.0 - Vocalization Training v2
## Verbeterde modellen met SpecAugment, Residual CNN & HiDrive sync

### Verbeteringen t.o.v. v1 (ultimate):
- **SpecAugment** - Frequency & time masking op spectrogrammen
- **Meer augmentatie** - Volume scaling, achtergrondgeluid, pink noise
- **Residual CNN** - Skip connections + SE attention
- **Focal Loss** - Beter voor zeldzame klassen (alarm)
- **Eigen data pipeline** - BirdNET-Pi detecties als extra trainingsdata
- **Train/Val/Test split** - Aparte testset voor eerlijke evaluatie
- **Confusion matrix** - Per-soort analyse van type-verwarring
- **Temperature scaling** - Gekalibreerde confidence scores
- **HiDrive auto-upload** - Modellen direct naar Strato cloud (993 GB vrij)

### Gebruik:
1. Runtime > Change runtime type > **GPU** (A100 aanbevolen)
2. High-RAM: **Aan**
3. Plak je HiDrive SSH key in cel 3
4. **Run All** - modellen worden automatisch geüpload naar HiDrive
5. Na afloop: download naar Pi met het commando in cel 12

In [None]:
# Cel 1: GPU & system check
# Reports GPU / RAM, configures cuDNN / TF32 for reproducibility and picks
# GPU_TYPE and BATCH_SIZE, which later cells read.
!nvidia-smi

import torch
import gc
import psutil

# Release cached GPU memory and Python garbage left over from earlier runs.
torch.cuda.empty_cache()
gc.collect()

print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

# Total system RAM in decimal GB (1e9 bytes); 20 GB is the "High-RAM" threshold used here.
ram_gb = psutil.virtual_memory().total / 1e9
print(f"RAM: {ram_gb:.1f} GB")
if ram_gb < 20:
    print("\u26a0\ufe0f Low RAM! Enable High-RAM in Runtime settings")
else:
    print("\u2705 High RAM beschikbaar")

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_mem:.1f} GB")

    # Stability settings: disable TF32 matmuls/convolutions and cuDNN autotuning
    # and request deterministic cuDNN kernels, trading speed for reproducibility.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Batch size per GPU family, matched on a substring of the device name.
    # Any GPU not matched below is labelled 'T4' and gets the smallest GPU batch.
    # NOTE(review): TF32 is disabled above, so the A100 runs without its TF32
    # speed-up despite the "maximum performance" message.
    if 'A100' in gpu_name:
        GPU_TYPE = 'A100'
        BATCH_SIZE = 64
        print(f"\n\U0001f680 A100 gedetecteerd - Maximum performance mode")
    elif 'V100' in gpu_name:
        GPU_TYPE = 'V100'
        BATCH_SIZE = 48
    elif 'L4' in gpu_name:
        GPU_TYPE = 'L4'
        BATCH_SIZE = 48
    else:
        GPU_TYPE = 'T4'
        BATCH_SIZE = 32
else:
    # CPU fallback: training will work but be very slow.
    GPU_TYPE = 'CPU'
    BATCH_SIZE = 16
    print("\u26a0\ufe0f Geen GPU!")

In [None]:
# Cel 2: Install Python dependencies and rclone
# NOTE(review): packages are unpinned, so results can change between runs.
!pip install librosa scikit-learn scikit-image matplotlib tqdm requests -q

# rclone is used for the HiDrive sync.
# The install script is downloaded and piped straight into bash; the `beta`
# argument is passed to install.sh (presumably selects the beta channel — confirm).
# The `||` branch runs whenever the script exits non-zero, not only when rclone
# is already installed.
!curl -s https://rclone.org/install.sh | bash -s beta 2>/dev/null || echo "rclone al geinstalleerd"
!rclone version | head -1

# Printed unconditionally; check the output above for real install errors.
print("\u2705 Dependencies + rclone ge\u00efnstalleerd")

In [None]:
# Cel 3: HiDrive SFTP verbinding + Opslag configuratie
#
# Stap 1: Plak je HiDrive SSH private key hieronder (id_ed25519_hidrive)
# Stap 2: Draai deze cel - rclone wordt geconfigureerd en verbinding getest
#
import os
import time
import json
from pathlib import Path
from datetime import datetime

# === HIDRIVE SSH KEY ===
# Paste the contents of ~/.ssh/id_ed25519_hidrive here (keep the BEGIN/END lines).
# NOTE(review): the private key becomes part of the notebook source and of any
# saved copy; clear it before sharing or committing the notebook.
HIDRIVE_SSH_KEY = """-----BEGIN OPENSSH PRIVATE KEY-----
PLAK_HIER_JE_KEY
-----END OPENSSH PRIVATE KEY-----"""

# Write the key to disk, readable/writable by the owner only (mode 0600).
ssh_dir = Path('/root/.ssh')
ssh_dir.mkdir(exist_ok=True)
key_path = ssh_dir / 'id_ed25519_hidrive'
key_path.write_text(HIDRIVE_SSH_KEY.strip() + '\n')
key_path.chmod(0o600)

# Add the server's host key to known_hosts so the connection does not stop at an
# interactive host-key prompt.
# NOTE(review): this trusts whatever key the server presents on first use, and
# every re-run of this cell appends another copy of the entry.
!ssh-keyscan sftp.hidrive.strato.com >> /root/.ssh/known_hosts 2>/dev/null

# rclone remote named "hidrive": SFTP login using the key file written above.
rclone_config = """[hidrive]
type = sftp
host = sftp.hidrive.strato.com
user = ronnyclouddisk
key_file = /root/.ssh/id_ed25519_hidrive
shell_type = unix
"""
rclone_dir = Path('/root/.config/rclone')
rclone_dir.mkdir(parents=True, exist_ok=True)
(rclone_dir / 'rclone.conf').write_text(rclone_config)

# Remote (HiDrive) paths: model files go to .../v2, result files to .../results.
HIDRIVE_BASE = '/users/ronnyclouddisk/emsn-backups/vocalization-models'
HIDRIVE_MODELS = f'{HIDRIVE_BASE}/v2'
HIDRIVE_RESULTS = f'{HIDRIVE_BASE}/results'

# Local working paths under /content.
# NOTE(review): despite the name DRIVE_BASE, nothing in this cell mounts Google
# Drive, so these files live only as long as the runtime does.
DRIVE_BASE = '/content/EMSN-Vocalization'
MODELS_DIR = f'{DRIVE_BASE}/models'
AUDIO_DIR = f'{DRIVE_BASE}/audio'
OWN_DATA_DIR = f'{DRIVE_BASE}/own_data'

os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(AUDIO_DIR, exist_ok=True)
os.makedirs(OWN_DATA_DIR, exist_ok=True)

# Create the remote directories; this is also the first real connection test.
!rclone mkdir hidrive:{HIDRIVE_MODELS}
!rclone mkdir hidrive:{HIDRIVE_RESULTS}

# Show the quota/usage reported by the remote (stderr included, so errors show up too).
print("HiDrive verbinding testen...")
result = !rclone about hidrive: 2>&1
for line in result:
    print(f"  {line}")

# Count the files already present in the remote v2 models folder.
existing = !rclone ls hidrive:{HIDRIVE_MODELS} 2>/dev/null | wc -l
print(f"\nBestaande v2 modellen op HiDrive: {existing[0].strip()}")

# === V2 CONFIGURATION ===
VERSION = '2025_v2'

# Training parameters
EPOCHS = 60            # maximum number of training epochs
LEARNING_RATE = 0.001
MIN_LR = 0.00001       # presumably the scheduler's lower LR bound — confirm in the training loop
PATIENCE = 12          # presumably early-stopping patience in epochs — confirm
WEIGHT_DECAY = 0.01

# Data parameters
MAX_RECORDINGS_PER_TYPE = 50      # presumably the cap on Xeno-canto recordings per vocalisation type
MAX_SEGMENTS_PER_RECORDING = 5    # presumably the number of 3 s segments used per recording
NUM_WORKERS = 4
MAX_CONCURRENT_DOWNLOADS = 10

# Augmentation
USE_AUGMENTATION = True
AUGMENTATION_FACTOR = 3   # number of audio-augmented copies kept per segment

# SpecAugment parameters: number of masks and maximum mask size
# (frequency masks in mel bins, time masks in spectrogram frames).
FREQ_MASK_PARAM = 15
TIME_MASK_PARAM = 20
NUM_FREQ_MASKS = 2
NUM_TIME_MASKS = 2

# Focal Loss
# NOTE(review): FocalLoss (cell 8) treats its `alpha` as per-class weights;
# nothing visible here shows how this scalar is used — confirm.
FOCAL_ALPHA = 0.25
FOCAL_GAMMA = 2.0

# Xeno-canto API key.
# NOTE(review): a real API key is stored here in plain text. Rotate it and load
# it from an environment variable or a Colab secret instead of hard-coding it.
XC_API_KEY = '14258afd1c8a8e055387d012f2620e20f59ef3a2'

# Upload helpers, meant to be called after each species has been trained.
def upload_model_to_hidrive(model_path):
    """Upload one model file to HiDrive, directly after training.

    Runs ``rclone copy <model_path> hidrive:{HIDRIVE_MODELS}/``.

    Args:
        model_path: local path (str or Path) of the model file.

    Returns:
        True if the file exists and rclone exited with status 0; False if the
        file is missing, rclone could not be started, or the copy failed.
    """
    import subprocess

    model_path = Path(model_path)
    if not model_path.exists():
        return False
    # An argument list (no shell) avoids quoting problems with spaces or shell
    # metacharacters in the path, and exposes rclone's real exit status.
    cmd = ['rclone', 'copy', str(model_path), f'hidrive:{HIDRIVE_MODELS}/', '--progress']
    try:
        returncode = subprocess.run(cmd, check=False).returncode
    except OSError as err:  # e.g. rclone binary not installed
        print(f"\u26a0\ufe0f Upload mislukt ({type(err).__name__}): {model_path.name}")
        return False
    if returncode != 0:
        print(f"\u26a0\ufe0f Upload mislukt (exit {returncode}): {model_path.name}")
        return False
    return True

def upload_results_to_hidrive():
    """Upload the results CSVs, confusion JSON and plot to HiDrive.

    Each of the known result files that exists under DRIVE_BASE is copied to
    ``hidrive:{HIDRIVE_RESULTS}/`` with rclone. Missing files are skipped;
    failed uploads are reported but do not raise.

    Returns:
        List of file names that were uploaded successfully.
    """
    import subprocess

    uploaded = []
    for name in ('results_v2.csv', 'confusions_v2.json', 'results_v2.png', 'checkpoint_v2.csv'):
        src = Path(DRIVE_BASE) / name
        if not src.exists():
            continue
        cmd = ['rclone', 'copy', str(src), f'hidrive:{HIDRIVE_RESULTS}/']
        try:
            returncode = subprocess.run(cmd, check=False).returncode
        except OSError as err:  # e.g. rclone binary not installed
            print(f"\u26a0\ufe0f Upload mislukt ({type(err).__name__}): {name}")
            continue
        if returncode == 0:
            uploaded.append(name)
        else:
            print(f"\u26a0\ufe0f Upload mislukt (exit {returncode}): {name}")
    return uploaded

# Summary of the active configuration, one setting per line.
print(
    "\n\U0001f4ca EMSN VOCALIZATION TRAINING v2",
    "=" * 50,
    f"   GPU: {GPU_TYPE} | Batch: {BATCH_SIZE}",
    f"   Epochs: {EPOCHS} | Patience: {PATIENCE}",
    f"   Recordings/type: {MAX_RECORDINGS_PER_TYPE}",
    f"   Augmentation: {AUGMENTATION_FACTOR}x + SpecAugment",
    f"   Loss: Focal (gamma={FOCAL_GAMMA})",
    "   Architecture: Residual CNN",
    "   Opslag: HiDrive SFTP (auto-upload)",
    f"   Version: {VERSION}",
    "=" * 50,
    sep="\n",
)

In [None]:
# Cel 4: Upload own data (optional)
#
# Upload your own labelled audio files exported from BirdNET-Pi, as a ZIP.
# Expected layout: own_data/{species}/{song|call|alarm}/*.mp3
#
# Example:
#   own_data/vink/song/Vink-95-2026-01-15-birdnet-08:30:00.mp3
#   own_data/vink/call/Vink-88-2026-01-15-birdnet-09:15:00.mp3
#
# Tip: export verified detections from BirdNET-Pi and label them.

from google.colab import files as colab_files
import zipfile

UPLOAD_OWN_DATA = False  # Zet op True om eigen data te uploaden

if UPLOAD_OWN_DATA:
    print("\U0001f4e4 Upload een ZIP met eigen gelabelde audio...")
    print("Verwachte structuur: {soortnaam}/{song|call|alarm}/*.mp3")
    print()
    uploaded = colab_files.upload()

    for filename, data in uploaded.items():
        # Accept the extension case-insensitively (e.g. "DATA.ZIP"), and report
        # other files instead of silently ignoring them.
        if not filename.lower().endswith('.zip'):
            print(f"\u26a0\ufe0f {filename} overgeslagen: geen .zip bestand")
            continue
        zip_path = f'/content/{filename}'
        with open(zip_path, 'wb') as f:
            f.write(data)
        try:
            with zipfile.ZipFile(zip_path, 'r') as z:
                z.extractall(OWN_DATA_DIR)
        except zipfile.BadZipFile:
            # A corrupt upload should not abort the remaining files.
            print(f"\u274c {filename} is geen geldig ZIP bestand")
            continue
        print(f"\u2705 {filename} uitgepakt naar {OWN_DATA_DIR}")

    # List the .mp3 counts per species / vocalisation-type directory.
    # NOTE(review): a ZIP whose entries sit inside an extra top-level folder is
    # extracted as own_data/<folder>/... and listed under that folder's name.
    own_data_path = Path(OWN_DATA_DIR)
    for species_dir in sorted(own_data_path.iterdir()):
        if species_dir.is_dir():
            for voc_dir in sorted(species_dir.iterdir()):
                if voc_dir.is_dir():
                    count = len(list(voc_dir.glob('*.mp3')))
                    if count > 0:
                        print(f"  {species_dir.name}/{voc_dir.name}: {count} bestanden")
else:
    print("\u2139\ufe0f Eigen data overgeslagen (UPLOAD_OWN_DATA = False)")
    print("Zet UPLOAD_OWN_DATA = True om BirdNET-Pi audio toe te voegen")

In [None]:
# Cel 5: Species to train.
# Each entry is (Dutch name, scientific name, directory name); the directory
# name is used for the per-species audio folder. The list currently holds 197
# entries (the header used to say 217); the authoritative count is printed at
# the bottom of this cell.
# NOTE(review): search_xeno_canto only uses the first two words of the
# scientific name, so entries that share genus + species query the same
# recordings (see the notes on the affected lines).

ALL_SPECIES = [
    # A
    ("Aalscholver", "Phalacrocorax carbo", "aalscholver"),
    ("Appelvink", "Coccothraustes coccothraustes", "appelvink"),
    # B
    ("Baardman", "Panurus biarmicus", "baardman"),
    ("Barmsijs", "Acanthis flammea", "barmsijs"),
    ("Beflijster", "Turdus torquatus", "beflijster"),
    ("Bergeend", "Tadorna tadorna", "bergeend"),
    ("Bijeneter", "Merops apiaster", "bijeneter"),
    ("Blauwborst", "Luscinia svecica", "blauwborst"),
    ("Blauwe Kiekendief", "Circus cyaneus", "blauwe_kiekendief"),
    ("Blauwe Reiger", "Ardea cinerea", "blauwe_reiger"),
    ("Boerenzwaluw", "Hirundo rustica", "boerenzwaluw"),
    ("Bokje", "Lymnocryptes minimus", "bokje"),
    ("Bontbekplevier", "Charadrius hiaticula", "bontbekplevier"),
    ("Bonte Kraai", "Corvus cornix", "bonte_kraai"),
    ("Bonte Strandloper", "Calidris alpina", "bonte_strandloper"),
    ("Bonte Vliegenvanger", "Ficedula hypoleuca", "bonte_vliegenvanger"),
    ("Boomklever", "Sitta europaea", "boomklever"),
    ("Boomkruiper", "Certhia brachydactyla", "boomkruiper"),
    ("Boomleeuwerik", "Lullula arborea", "boomleeuwerik"),
    ("Boompieper", "Anthus trivialis", "boompieper"),
    ("Boomvalk", "Falco subbuteo", "boomvalk"),
    ("Bosrietzanger", "Acrocephalus palustris", "bosrietzanger"),
    ("Bosruiter", "Tringa glareola", "bosruiter"),
    ("Bosuil", "Strix aluco", "bosuil"),
    ("Braamsluiper", "Curruca curruca", "braamsluiper"),
    ("Brandgans", "Branta leucopsis", "brandgans"),
    ("Brilduiker", "Bucephala clangula", "brilduiker"),
    ("Bruine Kiekendief", "Circus aeruginosus", "bruine_kiekendief"),
    ("Buidelmees", "Remiz pendulinus", "buidelmees"),
    ("Buizerd", "Buteo buteo", "buizerd"),
    # C
    ("Canadese Gans", "Branta canadensis", "canadese_gans"),  # same taxon as "Grote Canadese Gans"
    ("Cetti's Zanger", "Cettia cetti", "cettis_zanger"),
    ("Citroenkanarie", "Crithagra citrinelloides", "citroenkanarie"),
    # D
    ("Dodaars", "Tachybaptus ruficollis", "dodaars"),
    ("Draaihals", "Jynx torquilla", "draaihals"),
    ("Drieteenstrandloper", "Calidris alba", "drieteenstrandloper"),
    ("Dwergstern", "Sternula albifrons", "dwergstern"),
    # E
    ("Eider", "Somateria mollissima", "eider"),
    ("Ekster", "Pica pica", "ekster"),
    ("Europese Kanarie", "Serinus serinus", "europese_kanarie"),
    # F
    ("Fazant", "Phasianus colchicus", "fazant"),
    ("Fitis", "Phylloscopus trochilus", "fitis"),
    ("Flamingo", "Phoenicopterus roseus", "flamingo"),
    ("Fluiter", "Phylloscopus sibilatrix", "fluiter"),
    ("Fuut", "Podiceps cristatus", "fuut"),
    # G
    ("Gaai", "Garrulus glandarius", "gaai"),
    ("Geelgors", "Emberiza citrinella", "geelgors"),
    ("Gekraagde Roodstaart", "Phoenicurus phoenicurus", "gekraagde_roodstaart"),
    ("Gele Kwikstaart", "Motacilla flava", "gele_kwikstaart"),
    ("Gierzwaluw", "Apus apus", "gierzwaluw"),
    ("Glanskop", "Poecile palustris", "glanskop"),
    ("Goudhaan", "Regulus regulus", "goudhaan"),
    ("Goudplevier", "Pluvialis apricaria", "goudplevier"),
    ("Goudvink", "Pyrrhula pyrrhula", "goudvink"),
    ("Grasmus", "Curruca communis", "grasmus"),
    ("Graspieper", "Anthus pratensis", "graspieper"),
    ("Graszanger", "Cisticola juncidis", "graszanger"),
    ("Grauwe Gans", "Anser anser", "grauwe_gans"),
    ("Grauwe Kiekendief", "Circus pygargus", "grauwe_kiekendief"),
    ("Grauwe Klauwier", "Lanius collurio", "grauwe_klauwier"),
    ("Grauwe Vliegenvanger", "Muscicapa striata", "grauwe_vliegenvanger"),
    ("Groene Specht", "Picus viridis", "groene_specht"),
    ("Groenling", "Chloris chloris", "groenling"),
    ("Groenpootruiter", "Tringa nebularia", "groenpootruiter"),
    ("Grote Bonte Specht", "Dendrocopos major", "grote_bonte_specht"),
    ("Grote Canadese Gans", "Branta canadensis", "grote_canadese_gans"),  # NOTE(review): duplicates "Canadese Gans" taxon
    ("Grote Gele Kwikstaart", "Motacilla cinerea", "grote_gele_kwikstaart"),
    ("Grote Karekiet", "Acrocephalus arundinaceus", "grote_karekiet"),
    ("Grote Lijster", "Turdus viscivorus", "grote_lijster"),
    ("Grote Mantelmeeuw", "Larus marinus", "grote_mantelmeeuw"),
    ("Grote Zaagbek", "Mergus merganser", "grote_zaagbek"),
    ("Grote Zilverreiger", "Ardea alba", "grote_zilverreiger"),
    ("Grutto", "Limosa limosa", "grutto"),
    # H
    ("Haakbek", "Pinicola enucleator", "haakbek"),
    ("Havik", "Accipiter gentilis", "havik"),
    ("Heggenmus", "Prunella modularis", "heggenmus"),
    ("Holenduif", "Columba oenas", "holenduif"),
    ("Hop", "Upupa epops", "hop"),
    ("Houtduif", "Columba palumbus", "houtduif"),
    ("Houtsnip", "Scolopax rusticola", "houtsnip"),
    ("Huismus", "Passer domesticus", "huismus"),
    ("Huiszwaluw", "Delichon urbicum", "huiszwaluw"),
    # I
    ("IJsvogel", "Alcedo atthis", "ijsvogel"),
    # K
    ("Kanoetstrandloper", "Calidris canutus", "kanoetstrandloper"),
    ("Kauw", "Coloeus monedula", "kauw"),
    ("Keep", "Fringilla montifringilla", "keep"),
    ("Kerkuil", "Tyto alba", "kerkuil"),
    ("Kievit", "Vanellus vanellus", "kievit"),
    ("Klapekster", "Lanius excubitor", "klapekster"),
    ("Kleine Bonte Specht", "Dryobates minor", "kleine_bonte_specht"),
    ("Kleine Karekiet", "Acrocephalus scirpaceus", "kleine_karekiet"),
    ("Kleine Mantelmeeuw", "Larus fuscus", "kleine_mantelmeeuw"),
    ("Kleine Rietgans", "Anser brachyrhynchus", "kleine_rietgans"),
    ("Kleine Strandloper", "Calidris minuta", "kleine_strandloper"),
    ("Kleine Zilverreiger", "Egretta garzetta", "kleine_zilverreiger"),
    ("Kleine Zwaan", "Cygnus columbianus", "kleine_zwaan"),
    ("Kluut", "Recurvirostra avosetta", "kluut"),
    ("Kneu", "Linaria cannabina", "kneu"),
    ("Knobbelzwaan", "Cygnus olor", "knobbelzwaan"),
    ("Koekoek", "Cuculus canorus", "koekoek"),
    ("Kokmeeuw", "Chroicocephalus ridibundus", "kokmeeuw"),
    ("Kolgans", "Anser albifrons", "kolgans"),
    ("Koolmees", "Parus major", "koolmees"),
    ("Koperwiek", "Turdus iliacus", "koperwiek"),
    ("Kraanvogel", "Grus grus", "kraanvogel"),
    ("Krakeend", "Mareca strepera", "krakeend"),
    ("Kramsvogel", "Turdus pilaris", "kramsvogel"),
    ("Kruisbek", "Loxia curvirostra", "kruisbek"),
    ("Kuifeend", "Aythya fuligula", "kuifeend"),
    ("Kuifmees", "Lophophanes cristatus", "kuifmees"),
    ("Kwak", "Nycticorax nycticorax", "kwak"),
    ("Kwartel", "Coturnix coturnix", "kwartel"),
    ("Kwartelkoning", "Crex crex", "kwartelkoning"),
    # M
    ("Mandarijneend", "Aix galericulata", "mandarijneend"),
    ("Matkop", "Poecile montanus", "matkop"),
    ("Meerkoet", "Fulica atra", "meerkoet"),
    ("Merel", "Turdus merula", "merel"),
    ("Middelste Zaagbek", "Mergus serrator", "middelste_zaagbek"),
    # N
    ("Nachtegaal", "Luscinia megarhynchos", "nachtegaal"),
    ("Nachtzwaluw", "Caprimulgus europaeus", "nachtzwaluw"),
    ("Nijlgans", "Alopochen aegyptiaca", "nijlgans"),
    ("Nonnetje", "Mergellus albellus", "nonnetje"),
    # O
    ("Oehoe", "Bubo bubo", "oehoe"),
    ("Oeverloper", "Actitis hypoleucos", "oeverloper"),
    ("Oeverzwaluw", "Riparia riparia", "oeverzwaluw"),
    ("Ooievaar", "Ciconia ciconia", "ooievaar"),
    # P
    ("Paapje", "Saxicola rubetra", "paapje"),
    ("Patrijs", "Perdix perdix", "patrijs"),
    ("Pestvogel", "Bombycilla garrulus", "pestvogel"),
    ("Pijlstaart", "Anas acuta", "pijlstaart"),
    ("Pimpelmees", "Cyanistes caeruleus", "pimpelmees"),
    ("Porseleinhoen", "Porzana porzana", "porseleinhoen"),
    ("Putter", "Carduelis carduelis", "putter"),
    # R
    ("Raaf", "Corvus corax", "raaf"),
    ("Ransuil", "Asio otus", "ransuil"),
    ("Regenwulp", "Numenius phaeopus", "regenwulp"),
    ("Rietgors", "Emberiza schoeniclus", "rietgors"),
    ("Rietzanger", "Acrocephalus schoenobaenus", "rietzanger"),
    ("Rode Wouw", "Milvus milvus", "rode_wouw"),
    ("Roek", "Corvus frugilegus", "roek"),
    ("Roerdomp", "Botaurus stellaris", "roerdomp"),
    ("Roodborst", "Erithacus rubecula", "roodborst"),
    ("Roodborsttapuit", "Saxicola rubicola", "roodborsttapuit"),
    ("Roodhalsfuut", "Podiceps grisegena", "roodhalsfuut"),
    ("Rosse Grutto", "Limosa lapponica", "rosse_grutto"),
    ("Rotsduif", "Columba livia", "rotsduif"),  # same Xeno-canto query as "Stadsduif"
    # S
    ("Scharrelaar", "Coracias garrulus", "scharrelaar"),
    ("Scholekster", "Haematopus ostralegus", "scholekster"),
    ("Sijs", "Spinus spinus", "sijs"),
    ("Slechtvalk", "Falco peregrinus", "slechtvalk"),
    ("Slobeend", "Spatula clypeata", "slobeend"),
    ("Smelleken", "Falco columbarius", "smelleken"),
    ("Smient", "Mareca penelope", "smient"),
    ("Snor", "Locustella luscinioides", "snor"),
    ("Sperwer", "Accipiter nisus", "sperwer"),
    ("Spotvogel", "Hippolais icterina", "spotvogel"),
    ("Spreeuw", "Sturnus vulgaris", "spreeuw"),
    ("Sprinkhaanzanger", "Locustella naevia", "sprinkhaanzanger"),
    ("Staartmees", "Aegithalos caudatus", "staartmees"),
    ("Stadsduif", "Columba livia domestica", "stadsduif"),  # NOTE(review): queried as "Columba livia", same as "Rotsduif"
    ("Steenloper", "Arenaria interpres", "steenloper"),
    ("Steenuil", "Athene noctua", "steenuil"),
    ("Stormmeeuw", "Larus canus", "stormmeeuw"),
    # T
    ("Tafeleend", "Aythya ferina", "tafeleend"),
    ("Taigaboomkruiper", "Certhia familiaris", "taigaboomkruiper"),
    ("Tapuit", "Oenanthe oenanthe", "tapuit"),
    ("Tjiftjaf", "Phylloscopus collybita", "tjiftjaf"),
    ("Toendrarietgans", "Anser serrirostris", "toendrarietgans"),
    ("Torenvalk", "Falco tinnunculus", "torenvalk"),
    ("Tuinfluiter", "Sylvia borin", "tuinfluiter"),
    ("Tureluur", "Tringa totanus", "tureluur"),
    ("Turkse Tortel", "Streptopelia decaocto", "turkse_tortel"),
    # V
    ("Veldleeuwerik", "Alauda arvensis", "veldleeuwerik"),
    ("Velduil", "Asio flammeus", "velduil"),
    ("Vink", "Fringilla coelebs", "vink"),
    ("Visdief", "Sterna hirundo", "visdief"),
    ("Vuurgoudhaan", "Regulus ignicapilla", "vuurgoudhaan"),
    # W
    ("Waterhoen", "Gallinula chloropus", "waterhoen"),
    ("Waterral", "Rallus aquaticus", "waterral"),
    ("Watersnip", "Gallinago gallinago", "watersnip"),
    ("Wielewaal", "Oriolus oriolus", "wielewaal"),
    ("Wilde Eend", "Anas platyrhynchos", "wilde_eend"),
    ("Wilde Zwaan", "Cygnus cygnus", "wilde_zwaan"),
    ("Winterkoning", "Troglodytes troglodytes", "winterkoning"),
    ("Wintertaling", "Anas crecca", "wintertaling"),
    ("Witgat", "Tringa ochropus", "witgat"),
    ("Witte Kwikstaart", "Motacilla alba", "witte_kwikstaart"),
    ("Wulp", "Numenius arquata", "wulp"),
    # Z
    ("Zanglijster", "Turdus philomelos", "zanglijster"),
    ("Zilvermeeuw", "Larus argentatus", "zilvermeeuw"),
    ("Zomertortel", "Streptopelia turtur", "zomertortel"),
    ("Zwarte Kraai", "Corvus corone", "zwarte_kraai"),
    ("Zwarte Mees", "Periparus ater", "zwarte_mees"),
    ("Zwarte Roodstaart", "Phoenicurus ochruros", "zwarte_roodstaart"),
    ("Zwarte Ruiter", "Tringa erythropus", "zwarte_ruiter"),
    ("Zwarte Specht", "Dryocopus martius", "zwarte_specht"),
    ("Zwartkop", "Sylvia atricapilla", "zwartkop"),
]

# Authoritative species count.
print(f"Te trainen: {len(ALL_SPECIES)} soorten")

In [None]:
# Cel 6: Xeno-canto API + Download functies
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed


def search_xeno_canto(scientific_name, voc_type='song', max_results=100):
    """Search the Xeno-canto API v3 for quality-A recordings of one vocalisation type.

    Args:
        scientific_name: binomial name, e.g. "Parus major". Only the first two
            words (genus, species) are used; extra words are ignored.
        voc_type: Xeno-canto type filter, e.g. 'song', 'call' or 'alarm call'.
        max_results: maximum number of recordings returned.

    Returns:
        List of recording dicts from the API (at most ``max_results``), or []
        when the name has fewer than two words or the request fails.
    """
    parts = scientific_name.split()
    if len(parts) < 2:
        return []

    genus, species = parts[0].lower(), parts[1].lower()
    # Multi-word types such as "alarm call" must be quoted in the query language.
    type_query = f'type:"{voc_type}"' if ' ' in voc_type else f'type:{voc_type}'

    # Quality A = best recordings.
    query = f'gen:{genus} sp:{species} {type_query} q:A'

    try:
        # params= lets requests URL-encode the spaces, quotes and colons in the query.
        response = requests.get(
            'https://xeno-canto.org/api/3/recordings',
            params={'query': query, 'key': XC_API_KEY},
            timeout=30,
        )
    except requests.RequestException as err:
        # Only the exception type is printed: its message may contain the URL,
        # including the API key.
        print(f"\u26a0\ufe0f XC fout ({type(err).__name__})", end=' ')
        return []

    if response.status_code != 200:
        print(f"\u26a0\ufe0f XC HTTP {response.status_code}", end=' ')
        return []

    try:
        recordings = response.json().get('recordings', [])
    except ValueError:  # body was not valid JSON
        print("\u26a0\ufe0f XC: ongeldige JSON", end=' ')
        return []
    return recordings[:max_results]


def download_single(args):
    recording, output_dir = args
    xc_id = recording['id']
    file_url = recording.get('file', '')

    if not file_url:
        return None

    if file_url.startswith('//'):
        file_url = 'https:' + file_url
    elif not file_url.startswith('http'):
        file_url = 'https://xeno-canto.org' + file_url

    output_path = output_dir / f"XC{xc_id}.mp3"

    if output_path.exists():
        return output_path

    try:
        response = requests.get(file_url, timeout=60)
        if response.status_code == 200:
            with open(output_path, 'wb') as f:
                f.write(response.content)
            return output_path
    except Exception:
        pass
    return None


def download_recordings_parallel(recordings, output_dir, max_workers=10):
    """Download Xeno-canto recordings concurrently with a thread pool.

    Args:
        recordings: iterable of recording dicts from the API.
        output_dir: target directory (str or Path); created if missing.
        max_workers: number of concurrent download threads.

    Returns:
        Local paths of the files that are available, in the same order as
        ``recordings``. Failed downloads are left out. Returning input order
        rather than completion order keeps downstream processing reproducible
        from run to run.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    jobs = [(rec, output_dir) for rec in recordings]
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(download_single, jobs))

    return [path for path in results if path]


# Confirms that the download helpers in this cell have been defined.
print("\u2705 Download functies geladen")

In [None]:
# Cel 7: Spectrogram generatie + SpecAugment + Audio Augmentation
import librosa
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from functools import partial

# Audio / spectrogram parameters shared by the functions in this cell.
SAMPLE_RATE = 48000       # Hz; audio is resampled to this rate when loaded
N_MELS = 128              # number of mel bands (spectrogram height before resizing)
N_FFT = 2048              # FFT window length in samples
HOP_LENGTH = 512          # samples between frames (3 s * 48000 / 512 ~ 281 hops per segment)
FMIN = 500                # Hz, lowest mel filter frequency
FMAX = 8000               # Hz, highest mel filter frequency
SEGMENT_DURATION = 3.0    # seconds of audio per training example


# --- SpecAugment: masking op spectrogrammen ---

def spec_augment(mel_spec, num_freq_masks=None, freq_mask_param=None,
                 num_time_masks=None, time_mask_param=None):
    """
    SpecAugment: frequency and time masking on a 2-D spectrogram.
    Paper: Park et al. 2019 - "SpecAugment: A Simple Data Augmentation Method
    for Automatic Speech Recognition"

    Args:
        mel_spec: 2-D array of shape (n_mels, n_frames). Not modified; a
            masked copy is returned.
        num_freq_masks: number of horizontal (frequency) bands to mask.
            None -> NUM_FREQ_MASKS.
        freq_mask_param: maximum band height in bins; each height is drawn
            uniformly from [0, freq_mask_param]. None -> FREQ_MASK_PARAM.
        num_time_masks: number of vertical (time) bands to mask.
            None -> NUM_TIME_MASKS.
        time_mask_param: maximum band width in frames. None -> TIME_MASK_PARAM.

    Returns:
        The masked copy. Masked cells are set to 0.0, which is the minimum of
        the min-max normalised spectrograms built by audio_to_spectrogram.
    """
    # Resolve the defaults at call time, so edits in the config cell take
    # effect without re-running this cell.
    if num_freq_masks is None:
        num_freq_masks = NUM_FREQ_MASKS
    if freq_mask_param is None:
        freq_mask_param = FREQ_MASK_PARAM
    if num_time_masks is None:
        num_time_masks = NUM_TIME_MASKS
    if time_mask_param is None:
        time_mask_param = TIME_MASK_PARAM

    spec = mel_spec.copy()
    n_mels, n_frames = spec.shape

    # Frequency masking: horizontal bands.
    for _ in range(num_freq_masks):
        f = np.random.randint(0, freq_mask_param + 1)
        # randint's upper bound is exclusive; +1 lets a band of height f start
        # at n_mels - f, so the last bin can also be masked.
        f0 = np.random.randint(0, max(1, n_mels - f + 1))
        spec[f0:f0 + f, :] = 0.0

    # Time masking: vertical bands.
    for _ in range(num_time_masks):
        t = np.random.randint(0, time_mask_param + 1)
        t0 = np.random.randint(0, max(1, n_frames - t + 1))
        spec[:, t0:t0 + t] = 0.0

    return spec


# --- Audio augmentation functies ---

def augment_audio(audio, sr):
    """
    Return augmented copies of one mono audio segment, in this order:

      1. pitch shift by a random +/-1..3 semitones        (omitted if librosa fails)
      2. time stretch by a random factor 0.85-1.15,
         cropped / zero-padded back to len(audio)         (omitted if librosa fails)
      3. volume change of a random -6..+6 dB, clipped to [-1, 1]
      4. added pink noise at a random 15-25 dB SNR, clipped to [-1, 1]
      5. the pitch-shifted copy plus half-strength pink noise
                                                          (only if step 1 succeeded)

    Callers that slice the result depend on this order.
    NOTE(review): process_single_audio keeps only the first AUGMENTATION_FACTOR
    items (3 in the config cell), so with that setting items 4-5 are normally
    computed but never used.
    """
    augmented = []
    target_len = len(audio)
    pitched = None

    # 1. Pitch shift
    try:
        n_steps = np.random.choice([-3, -2, -1, 1, 2, 3])
        pitched = librosa.effects.pitch_shift(audio, sr=sr, n_steps=n_steps)
        augmented.append(pitched)
    except Exception:
        pass  # best-effort augmentation; skip this variant on failure

    # 2. Time stretch, then restore the original length
    try:
        rate = np.random.uniform(0.85, 1.15)
        stretched = librosa.effects.time_stretch(audio, rate=rate)
        if len(stretched) > target_len:
            stretched = stretched[:target_len]
        else:
            stretched = np.pad(stretched, (0, target_len - len(stretched)))
        augmented.append(stretched)
    except Exception:
        pass  # best-effort augmentation; skip this variant on failure

    # 3. Volume scaling
    gain_db = np.random.uniform(-6, 6)
    gain_linear = 10 ** (gain_db / 20)
    augmented.append(np.clip(audio * gain_linear, -1.0, 1.0))

    # 4. Pink noise, scaled so noise power = signal power / 10**(SNR/10)
    pink = _generate_pink_noise(target_len)
    snr_db = np.random.uniform(15, 25)
    signal_power = np.mean(audio ** 2) + 1e-10
    noise_power = signal_power / (10 ** (snr_db / 10))
    noise = pink * np.sqrt(noise_power / (np.mean(pink ** 2) + 1e-10))
    augmented.append(np.clip(audio + noise, -1.0, 1.0).astype(np.float32))

    # 5. Pitch shift + noise. Previously this used augmented[0] behind an
    # always-true guard, so a failed pitch shift produced "stretch/volume +
    # noise" under this label. Now it is only added if the pitch shift worked.
    if pitched is not None:
        augmented.append(np.clip(pitched + noise * 0.5, -1.0, 1.0).astype(np.float32))

    return augmented


def _generate_pink_noise(n_samples):
    """Genereer pink noise (1/f) - realistischer dan wit/Gaussian."""
    white = np.random.randn(n_samples)
    fft = np.fft.rfft(white)
    freqs = np.fft.rfftfreq(n_samples)
    freqs[0] = 1  # Voorkom deling door 0
    fft = fft / np.sqrt(freqs)
    pink = np.fft.irfft(fft, n=n_samples)
    pink = pink / (np.max(np.abs(pink)) + 1e-10)
    return pink.astype(np.float32)


# --- Spectrogram conversie ---

def audio_to_spectrogram(audio, sr=SAMPLE_RATE, apply_spec_augment=False):
    """Turn an audio segment into a 128x128 float32 mel spectrogram in [0, 1].

    Steps: mel power spectrogram (N_MELS bands, FMIN-FMAX Hz), dB relative to
    the segment maximum, per-segment min-max scaling, optional SpecAugment,
    then a resize to 128x128 if needed.
    SpecAugment runs before the resize, so its time-mask widths are counted in
    original frames, not in the final 128 columns.
    """
    power = librosa.feature.melspectrogram(
        y=audio, sr=sr,
        n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH,
        fmin=FMIN, fmax=FMAX,
    )
    db = librosa.power_to_db(power, ref=np.max)
    lo, hi = db.min(), db.max()
    scaled = (db - lo) / (hi - lo + 1e-8)

    if apply_spec_augment:
        scaled = spec_augment(scaled)

    target_shape = (128, 128)
    if scaled.shape != target_shape:
        from skimage.transform import resize
        scaled = resize(scaled, target_shape, anti_aliasing=True)

    return scaled.astype(np.float32)


def process_single_audio(audio_path, max_segments=5, use_augmentation=True):
    """Convert one audio file into a list of float32 spectrograms.

    The file is loaded as mono at SAMPLE_RATE and cut into consecutive
    SEGMENT_DURATION-second segments; at most ``max_segments`` segments are
    used. A trailing piece shorter than half a segment is skipped; a longer one
    is zero-padded to full length.

    Per segment the output contains:
      * always: the plain spectrogram;
      * only when ``use_augmentation`` is True: one SpecAugment copy, plus up to
        AUGMENTATION_FACTOR audio-augmented copies (each with SpecAugment at a
        50% chance).

    Returns [] when the file cannot be loaded.

    NOTE(review): all copies come from the same segment. If the caller splits
    train/val/test after this step, copies of one recording can land in
    different splits — confirm that the split is done per recording.
    """
    try:
        audio, _ = librosa.load(str(audio_path), sr=SAMPLE_RATE, mono=True)
    except Exception:
        # Best-effort: unreadable or corrupt files are skipped, not fatal.
        return []

    segment_samples = int(SEGMENT_DURATION * SAMPLE_RATE)
    spectrograms = []
    segments_processed = 0

    for i in range(0, len(audio), segment_samples):
        if segments_processed >= max_segments:
            break

        segment = audio[i:i + segment_samples]
        if len(segment) < segment_samples // 2:
            continue
        if len(segment) < segment_samples:
            segment = np.pad(segment, (0, segment_samples - len(segment)))

        # Original, unaugmented spectrogram.
        spectrograms.append(audio_to_spectrogram(segment, apply_spec_augment=False))

        # Previously the SpecAugment copy was added even with
        # use_augmentation=False, so "no augmentation" still produced augmented
        # data. All augmentation is now gated on the flag.
        if use_augmentation:
            spectrograms.append(audio_to_spectrogram(segment, apply_spec_augment=True))
            for aug_seg in augment_audio(segment, SAMPLE_RATE)[:AUGMENTATION_FACTOR]:
                apply_sa = np.random.random() > 0.5
                spectrograms.append(audio_to_spectrogram(aug_seg, apply_spec_augment=apply_sa))

        segments_processed += 1

    return spectrograms


def process_audio_files_parallel(audio_paths, max_segments=5, max_workers=4,
                                  use_augmentation=True):
    """Run process_single_audio over many files in a thread pool.

    Returns one flat list of spectrograms, grouped per file in the same
    order as ``audio_paths``.
    """
    worker = partial(process_single_audio, max_segments=max_segments,
                     use_augmentation=use_augmentation)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        per_file = list(pool.map(worker, audio_paths))
    return [spec for file_specs in per_file for spec in file_specs]


# Confirms that the spectrogram / augmentation helpers in this cell have been defined.
print("\u2705 Spectrogram + SpecAugment + Audio Augmentation geladen")

In [None]:
# Cel 8: Residual CNN Model + Focal Loss
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an additive shortcut.

    The first convolution applies ``stride``. The shortcut is the identity when
    the shape does not change, otherwise a strided 1x1 conv + BatchNorm that
    projects the input to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Submodule names and registration order determine the state_dict keys
        # of saved models; keep them stable.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout2d(0.2)

        needs_projection = stride != 1 or in_channels != out_channels
        self.skip = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        shortcut = self.skip(x)
        h = self.dropout(F.relu(self.bn1(self.conv1(x))))
        h = self.bn2(self.conv2(h))
        return F.relu(h + shortcut)


class VocalizationResNet(nn.Module):
    """
    Residual CNN that maps a 1-channel spectrogram to vocalisation-type logits
    (``num_classes`` outputs, 3 by default).

    Pipeline: 7x7 stride-2 stem + 2x2 max-pool -> three residual blocks
    (32->64, 64->128 with stride 2, 128->256 with stride 2) ->
    squeeze-and-excitation channel gating -> global average pooling ->
    256 -> 128 -> num_classes head with dropout.
    For a (B, 1, 128, 128) input the residual stack outputs (B, 256, 8, 8).
    """

    def __init__(self, num_classes=3):
        super().__init__()

        # Stem: downsamples the input 4x (stride-2 conv, then 2x2 max-pool).
        self.stem = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Residual stages: 32 -> 64 -> 128 -> 256 channels.
        self.layer1 = ResidualBlock(32, 64, stride=1)
        self.layer2 = ResidualBlock(64, 128, stride=2)
        self.layer3 = ResidualBlock(128, 256, stride=2)

        # Squeeze-and-excitation: one sigmoid gate per channel, with a 4x
        # bottleneck (256 -> 64 -> 256).
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.Sigmoid()
        )

        # Global average pooling followed by the MLP classifier head.
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        features = self.layer3(self.layer2(self.layer1(self.stem(x))))
        # (B, 256) gates broadcast over the spatial dimensions.
        gates = self.se(features)[:, :, None, None]
        pooled = self.gap(features * gates)
        return self.classifier(pooled)


class FocalLoss(nn.Module):
    """
    Focal loss for imbalanced classes (Lin et al. 2017, "Focal Loss for Dense
    Object Detection").

        loss_i = alpha[y_i] * (1 - p_t,i) ** gamma * CE_i,   averaged over the batch

    where p_t is the predicted probability of the true class and CE_i the
    unweighted cross-entropy. Hard examples (low p_t) get more weight; alpha
    adds optional per-class weighting, e.g. for rare alarm calls.
    """

    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        # None (no weighting), a 1-D sequence/tensor of per-class weights, or a
        # single number applied to every class.
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Unweighted CE, so that exp(-ce) is exactly p_t. The previous version
        # passed the class weights into cross_entropy: then ce = -alpha_t*log(p_t)
        # and exp(-ce) = p_t ** alpha_t, which distorts the focusing term.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        p_t = torch.exp(-ce_loss)
        focal_loss = ((1 - p_t) ** self.gamma) * ce_loss

        if self.alpha is not None:
            if isinstance(self.alpha, (int, float)):
                focal_loss = focal_loss * self.alpha
            else:
                # as_tensor also moves the weights to the loss' device, so they
                # need not be on the right device when the loss is built.
                alpha = torch.as_tensor(self.alpha, dtype=focal_loss.dtype,
                                        device=focal_loss.device)
                focal_loss = alpha[targets] * focal_loss

        return focal_loss.mean()


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\u2705 Residual CNN + Focal Loss klaar voor {device}")

# Smoke test: one forward pass on a random batch to check the output shape and
# report the model size. no_grad: a shape check needs no autograd graph.
test_model = VocalizationResNet(num_classes=3).to(device)
test_input = torch.randn(4, 1, 128, 128).to(device)
with torch.no_grad():
    test_output = test_model(test_input)
params = sum(p.numel() for p in test_model.parameters())
print(f"   Parameters: {params:,} ({params/1e6:.1f}M)")
print(f"   Output shape: {test_output.shape}")
del test_model, test_input, test_output
del test_model, test_input, test_output

In [None]:
# Cel 9: Training Pipeline v2 met alle verbeteringen
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from sklearn.metrics import confusion_matrix


def compute_class_weights(y, num_classes, device):
    """Return inverse-frequency class weights for the focal loss.

    Each class gets weight 1 / count. A class with no samples is treated as
    having one, which avoids division by zero. The weights are then rescaled
    so they sum to ``num_classes``. Returns a float32 tensor on ``device``.
    """
    per_class = np.bincount(y, minlength=num_classes).astype(np.float32)
    safe_counts = np.clip(per_class, 1, None)  # no division by zero
    inverse = np.reciprocal(safe_counts)
    scaled = inverse / inverse.sum() * num_classes
    return torch.tensor(scaled, dtype=torch.float32, device=device)


def calibrate_temperature(model, val_loader, device, temperatures=None):
    """
    Temperature scaling for calibrated confidence scores.

    Collects the model's logits on ``val_loader`` and grid-searches the
    temperature T for which softmax(logits / T) has the lowest negative
    log-likelihood on the labels.

    Args:
        model: classifier returning logits; it is put in eval mode.
        val_loader: iterable of (inputs, labels) batches.
        device: device the inputs are moved to for the forward pass.
        temperatures: optional iterable of candidate temperatures. Defaults to
            0.5, 0.6, ..., 4.9.

    Returns:
        The best temperature as a builtin ``float``. This keeps checkpoints
        free of numpy scalars, which ``torch.load(weights_only=True)``
        rejects. Returns 1.0 (no scaling) when ``val_loader`` yields no
        batches.
    """
    if temperatures is None:
        # Integer stepping avoids the float drift of np.arange(0.5, 5.0, 0.1)
        # (e.g. 0.6000000000000001); same 45 candidates.
        temperatures = [round(0.5 + 0.1 * k, 1) for k in range(45)]

    model.eval()
    logits_list, labels_list = [], []
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            logits_list.append(model(X_batch.to(device)).cpu())
            labels_list.append(y_batch.cpu())

    if not logits_list:
        return 1.0  # nothing to calibrate on: leave logits unscaled

    logits_all = torch.cat(logits_list)
    labels_all = torch.cat(labels_list)

    best_t = 1.0
    best_nll = float('inf')
    for t in temperatures:
        nll = F.cross_entropy(logits_all / t, labels_all).item()
        if nll < best_nll:  # strict: on ties the earliest candidate wins
            best_nll = nll
            best_t = float(t)

    return best_t


def _specs_from_files(audio_files, use_augmentation):
    """Convert audio files to spectrograms with the shared parallel pipeline."""
    return process_audio_files_parallel(
        audio_files,
        max_segments=MAX_SEGMENTS_PER_RECORDING,
        max_workers=NUM_WORKERS,
        use_augmentation=use_augmentation,
    )


def _collect_xeno_canto_specs(scientific_name, audio_dir):
    """Download Xeno-canto recordings per vocalization type; return (specs, labels).

    Labels: 0 = song, 1 = call, 2 = alarm call. Files are stored under
    ``audio_dir/<type with '_' for spaces>``; progress is printed per type.
    """
    specs_all, labels_all = [], []
    for voc_type, label in (('song', 0), ('call', 1), ('alarm call', 2)):
        print(f"  \U0001f4e5 {voc_type}...", end=' ')
        recordings = search_xeno_canto(scientific_name, voc_type,
                                       max_results=MAX_RECORDINGS_PER_TYPE)
        if not recordings:
            print("0 gevonden")
            continue

        audio_files = download_recordings_parallel(
            recordings[:MAX_RECORDINGS_PER_TYPE],
            audio_dir / voc_type.replace(' ', '_'),
            max_workers=MAX_CONCURRENT_DOWNLOADS
        )
        print(f"{len(audio_files)} files", end=' ')
        if not audio_files:
            print()
            continue

        specs = _specs_from_files(audio_files, USE_AUGMENTATION)
        specs_all.extend(specs)
        labels_all.extend([label] * len(specs))
        print(f"\u2192 {len(specs)} specs")
    return specs_all, labels_all


def _collect_own_specs(dirname):
    """Spectrograms from the project's own recordings, if present; (specs, labels).

    Reads ``OWN_DATA_DIR/<dirname>/{song,call,alarm}/*.mp3`` with the same
    0/1/2 labels as the Xeno-canto data. These files are always augmented,
    independent of USE_AUGMENTATION.
    """
    specs_all, labels_all = [], []
    species_dir = Path(OWN_DATA_DIR) / dirname
    if not species_dir.exists():
        return specs_all, labels_all

    for voc_name, label in (('song', 0), ('call', 1), ('alarm', 2)):
        type_dir = species_dir / voc_name
        if not type_dir.exists():
            continue
        own_files = list(type_dir.glob('*.mp3'))
        if not own_files:
            continue
        specs = _specs_from_files(own_files, True)
        specs_all.extend(specs)
        labels_all.extend([label] * len(specs))

    if specs_all:
        print(f"  \U0001f3af Eigen data: +{len(specs_all)} specs")
    return specs_all, labels_all


def _make_tensor_loader(X, y, shuffle=False):
    """Build a DataLoader from spectrogram and label arrays.

    A channel axis is inserted at dim 1 (presumably X is (N, H, W), giving
    (N, 1, H, W) batches for the CNN -- confirm against the audio pipeline).
    """
    dataset = TensorDataset(torch.FloatTensor(X).unsqueeze(1),
                            torch.LongTensor(y))
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle,
                      num_workers=NUM_WORKERS, pin_memory=True)


def train_species_v2(dutch_name, scientific_name, dirname):
    """
    Train, evaluate and save one per-species vocalization classifier (v2).

    Steps:
      1-2. Collect Xeno-canto spectrograms (song / call / alarm call) and
           the project's own recordings.
      3.   Make a stratified train/val/test split (about 70/15/15).
      4-5. Train VocalizationResNet with focal loss, AdamW and cosine warm
           restarts. Early stopping is on validation accuracy.
      6.   Evaluate the best weights on the test set and build a confusion
           matrix.
      7.   Fit a softmax temperature on the validation set.
      8.   Save a checkpoint under ``{DRIVE_BASE}/models``.

    Returns:
        (test_accuracy, status, confusion_dict).

        On success, status is 'success' and confusion_dict maps
        "true>pred" to a count. On failure, test_accuracy is None,
        confusion_dict is {} and status is one of 'insufficient_data',
        'single_class', 'insufficient_class_data', 'cuda_error' or
        'training_failed'. Non-CUDA RuntimeErrors are re-raised.
    """
    print(f"\n{'='*60}")
    print(f"\U0001f426 {dutch_name} ({scientific_name})")
    print(f"{'='*60}")

    start_time = time.time()
    audio_dir = Path(f'{DRIVE_BASE}/audio/{dirname}')

    # --- Steps 1-2: Xeno-canto data, then the project's own recordings ---
    X_all, y_all = _collect_xeno_canto_specs(scientific_name, audio_dir)
    own_X, own_y = _collect_own_specs(dirname)
    X_all.extend(own_X)
    y_all.extend(own_y)

    if len(X_all) < 30:
        print(f"  \u26a0\ufe0f Te weinig data ({len(X_all)})")
        return None, 'insufficient_data', {}

    X = np.array(X_all)
    y = np.array(y_all)

    # Remap the labels that actually occur to 0..num_classes-1
    unique_labels = np.unique(y)
    num_classes = len(unique_labels)

    if num_classes < 2:
        print(f"  \u26a0\ufe0f Slechts 1 klasse")
        return None, 'single_class', {}

    # Builtin ints (not numpy.int64) so the saved checkpoint can be read with
    # torch.load(weights_only=True), the default since PyTorch 2.6.
    label_map = {int(old): new for new, old in enumerate(unique_labels)}
    y_remapped = np.array([label_map[int(l)] for l in y])

    all_class_names = ['song', 'call', 'alarm']
    class_names = [all_class_names[l] for l in unique_labels]

    counts = np.bincount(y_remapped, minlength=num_classes)
    class_dist = {name: int(n) for name, n in zip(class_names, counts)}
    print(f"  \U0001f4ca {len(X)} specs: {class_dist}")

    # --- Step 3: stratified train/val/test split (70/15/15) ---
    # NOTE(review): the split is per spectrogram, not per recording. Segments
    # (and augmented copies) of one recording can end up in both train and
    # test, which likely inflates val/test accuracy. A grouped split would
    # need process_audio_files_parallel to report each spec's source file.
    try:
        X_trainval, X_test, y_trainval, y_test = train_test_split(
            X, y_remapped, test_size=0.15, random_state=42, stratify=y_remapped
        )
        X_train, X_val, y_train, y_val = train_test_split(
            X_trainval, y_trainval, test_size=0.176,  # 0.176 * 0.85 ~= 0.15
            random_state=42, stratify=y_trainval
        )
    except ValueError as e:
        # Stratification fails when a class is too small to appear in every split
        print(f"  \u26a0\ufe0f Te weinig data per klasse voor split: {e}")
        return None, 'insufficient_class_data', {}
    print(f"  Split: train={len(X_train)} val={len(X_val)} test={len(X_test)}")

    train_loader = _make_tensor_loader(X_train, y_train, shuffle=True)
    val_loader = _make_tensor_loader(X_val, y_val)
    test_loader = _make_tensor_loader(X_test, y_test)

    # --- Step 4: model + focal loss + optimizer + LR schedule ---
    model = VocalizationResNet(num_classes=num_classes).to(device)
    class_weights = compute_class_weights(y_train, num_classes, device)
    criterion = FocalLoss(alpha=class_weights, gamma=FOCAL_GAMMA)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE,
                                  weight_decay=WEIGHT_DECAY)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2,
                                            eta_min=MIN_LR)

    # --- Step 5: training loop, early stopping on validation accuracy ---
    best_acc = 0
    best_state = None
    patience_counter = 0

    try:
        for epoch in range(EPOCHS):
            model.train()
            for X_batch, y_batch in train_loader:
                X_batch = X_batch.to(device, non_blocking=True)
                y_batch = y_batch.to(device, non_blocking=True)

                optimizer.zero_grad()
                loss = criterion(model(X_batch), y_batch)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            # Advance the cosine schedule by one epoch. The previous
            # scheduler.step(epoch) re-applied the LR of the epoch that had
            # just finished, so the schedule lagged one epoch behind.
            scheduler.step()

            # Validation
            model.eval()
            val_correct = 0
            with torch.no_grad():
                for X_batch, y_batch in val_loader:
                    X_batch = X_batch.to(device, non_blocking=True)
                    y_batch = y_batch.to(device, non_blocking=True)
                    val_correct += (model(X_batch).argmax(1) == y_batch).sum().item()

            val_acc = val_correct / len(y_val)

            if val_acc > best_acc:
                best_acc = val_acc
                # CPU copy so later GPU failures cannot corrupt the best weights
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= PATIENCE:
                print(f"  \u23f9\ufe0f Early stop @ epoch {epoch+1}")
                break

    except RuntimeError as e:
        if 'CUDA' not in str(e):
            raise
        print(f"  \u26a0\ufe0f CUDA error")
        torch.cuda.empty_cache()
        gc.collect()
        if best_state is None:
            return None, 'cuda_error', {}
        # Otherwise continue with the best weights seen before the error.

    if best_state is None:
        return None, 'training_failed', {}

    # --- Step 6: test evaluation with the best (validation) weights ---
    model.load_state_dict(best_state)  # copies into the on-device parameters
    model.eval()

    test_preds, test_labels = [], []
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            outputs = model(X_batch.to(device))
            test_preds.extend(outputs.argmax(1).cpu().numpy())
            test_labels.extend(y_batch.numpy())

    test_acc = np.mean(np.array(test_preds) == np.array(test_labels))

    # Confusion matrix as {"true>pred": count}, non-zero cells only
    cm = confusion_matrix(test_labels, test_preds, labels=list(range(num_classes)))
    cm_dict = {
        f"{true_name}>{pred_name}": int(cm[i, j])
        for i, true_name in enumerate(class_names)
        for j, pred_name in enumerate(class_names)
        if cm[i, j] > 0
    }

    # --- Step 7: temperature scaling on the validation set ---
    # float(): keep the checkpoint free of numpy scalars (weights_only loading).
    temperature = float(calibrate_temperature(model, val_loader, device))

    # --- Step 8: save the checkpoint ---
    model_path = Path(f'{DRIVE_BASE}/models/{dirname}_cnn_{VERSION}.pt')
    torch.save({
        'model_state_dict': best_state,
        'num_classes': num_classes,
        'class_names': class_names,
        'label_map': label_map,
        'accuracy': best_acc,
        'test_accuracy': float(test_acc),
        'temperature': temperature,
        'confusion_matrix': cm_dict,
        'class_distribution': class_dist,
        'species_name': dutch_name,
        'scientific_name': scientific_name,
        'version': VERSION,
        'architecture': 'ResNet_SE',
    }, model_path)

    del model, train_loader, val_loader, test_loader
    torch.cuda.empty_cache()
    gc.collect()

    elapsed = time.time() - start_time
    print(f"  \u2705 {model_path.name} | Val: {best_acc:.1%} | Test: {test_acc:.1%} | T={temperature:.2f} | {elapsed:.0f}s")

    # Show only the actual confusions (true != pred)
    confusions = {k: v for k, v in cm_dict.items() if k.split('>')[0] != k.split('>')[1]}
    if confusions:
        print(f"  \U0001f500 Verwarringen: {confusions}")

    return test_acc, 'success', cm_dict


print("\u2705 V2 Training pipeline geladen")

In [None]:
# Cel 10: Start Training (met auto-upload naar HiDrive)
# Train every species in ALL_SPECIES. Each successful model is uploaded to
# HiDrive right away, and results + confusion data are checkpointed every 20
# species.
import pandas as pd

results = []
all_confusions = {}
upload_failures = []  # species whose model upload to HiDrive failed
start_all = time.time()


def _save_progress_v2():
    """Write the current results and confusion data to DRIVE_BASE."""
    pd.DataFrame(results).to_csv(f'{DRIVE_BASE}/checkpoint_v2.csv', index=False)
    with open(f'{DRIVE_BASE}/confusions_v2.json', 'w') as f:
        json.dump(all_confusions, f, indent=2)


def _try_upload_results_v2():
    """Upload results to HiDrive; warn instead of aborting the run on failure."""
    try:
        upload_results_to_hidrive()
        return True
    except Exception as e:
        print(f"  \u26a0\ufe0f HiDrive upload mislukt: {str(e)[:80]}")
        return False


print(f"{'='*60}")
print(f"\U0001f680 EMSN VOCALIZATION TRAINING v2")
print(f"{'='*60}")
print(f"Start: {datetime.now().strftime('%H:%M:%S')}")
print(f"Soorten: {len(ALL_SPECIES)}")
print(f"GPU: {GPU_TYPE} | Architecture: ResNet+SE")
print(f"Augmentation: {AUGMENTATION_FACTOR}x + SpecAugment")
print(f"Loss: Focal (gamma={FOCAL_GAMMA})")
print(f"Opslag: HiDrive (auto-upload per model)")
print(f"{'='*60}")

successful = 0
failed = 0

for i, (dutch, scientific, dirname) in enumerate(ALL_SPECIES):
    try:
        acc, status, cm = train_species_v2(dutch, scientific, dirname)
    except Exception as e:
        print(f"  \u274c Error: {str(e)[:80]}")
        acc, status, cm = None, 'error', {}

    # Exactly one result row per species
    results.append({
        'species': dutch,
        'scientific': scientific,
        'test_accuracy': acc,
        'status': status
    })
    if cm:
        all_confusions[dutch] = cm

    if status == 'success':
        successful += 1
        # Upload right away. A failed upload is reported separately and is not
        # recorded as a failed training. Previously it added a second 'error'
        # row and counted the species as both successful and failed.
        model_file = f'{DRIVE_BASE}/models/{dirname}_cnn_{VERSION}.pt'
        try:
            upload_model_to_hidrive(model_file)
        except Exception as e:
            print(f"  \u26a0\ufe0f Upload mislukt: {str(e)[:80]}")
            upload_failures.append(dutch)
    else:
        failed += 1

    # Every 20 species: save results locally and sync them to HiDrive
    if (i + 1) % 20 == 0:
        _save_progress_v2()
        synced = _try_upload_results_v2()
        sync_note = "HiDrive synced" if synced else "HiDrive sync mislukt"
        elapsed = time.time() - start_all
        eta = (elapsed / (i + 1)) * (len(ALL_SPECIES) - i - 1)
        print(f"\n  \U0001f4be [{i+1}/{len(ALL_SPECIES)}] \u2705{successful} \u274c{failed} | {sync_note} | ETA: {eta/60:.0f}min\n")

# Upload the final results
elapsed_all = time.time() - start_all
print(f"\n{'='*60}")
print(f"\U0001f3c1 TRAINING VOLTOOID!")
print(f"{'='*60}")
print(f"Tijd: {elapsed_all/60:.1f} minuten")
print(f"Succesvol: {successful}/{len(ALL_SPECIES)}")
print(f"Mislukt: {failed}/{len(ALL_SPECIES)}")
print(f"\nFinale upload naar HiDrive...")
# Refresh the checkpoint files first. upload_results_to_hidrive presumably
# uploads them (confirm), and otherwise they can be up to 19 species stale.
_save_progress_v2()
results_uploaded = _try_upload_results_v2()
if results_uploaded and not upload_failures:
    print(f"\u2705 Alle modellen staan op HiDrive!")
else:
    print(f"\u26a0\ufe0f Niet alles staat op HiDrive; mislukte model-uploads: {upload_failures}")

In [None]:
# Cel 11: Resultaten & Analyse
import pandas as pd
import matplotlib.pyplot as plt

# Explicit columns keep `df` valid when `results` is empty. Without them,
# pd.DataFrame([]) has no columns and the df['status'] lookups below would
# raise KeyError.
df = pd.DataFrame(results, columns=['species', 'scientific', 'test_accuracy', 'status'])
df.to_csv(f'{DRIVE_BASE}/results_v2.csv', index=False)

# Save the confusion data
with open(f'{DRIVE_BASE}/confusions_v2.json', 'w') as f:
    json.dump(all_confusions, f, indent=2)

successful_df = df[df['status'] == 'success']

print(f"\n{'='*60}")
print(f"\U0001f4ca RESULTATEN v2")
print(f"{'='*60}")
print(f"Getraind: {len(successful_df)}/{len(df)}")

if len(successful_df) > 0:
    accs = successful_df['test_accuracy']
    print(f"\nTest Accuracy:")
    print(f"  Gemiddeld: {accs.mean():.1%}")
    print(f"  Mediaan:   {accs.median():.1%}")
    print(f"  Min:       {accs.min():.1%}")
    print(f"  Max:       {accs.max():.1%}")
    print(f"  >90%:      {(accs > 0.9).sum()} soorten")
    print(f"  >80%:      {(accs > 0.8).sum()} soorten")
    print(f"  <50%:      {(accs < 0.5).sum()} soorten")

    # Top 10 and bottom 10 by test accuracy
    print(f"\n\U0001f3c6 Top 10 (beste modellen):")
    for _, row in successful_df.nlargest(10, 'test_accuracy').iterrows():
        print(f"  {row['test_accuracy']:.1%} - {row['species']}")

    print(f"\n\u26a0\ufe0f Bottom 10 (meeste verbetering nodig):")
    for _, row in successful_df.nsmallest(10, 'test_accuracy').iterrows():
        print(f"  {row['test_accuracy']:.1%} - {row['species']}")

    # Species with the highest share of off-diagonal ("true>pred", true != pred)
    # test predictions
    print(f"\n\U0001f500 Meest verwarde types:")
    confusion_scores = []
    for species, cm in all_confusions.items():
        total = sum(cm.values())
        errors = sum(v for k, v in cm.items() if k.split('>')[0] != k.split('>')[1])
        if total > 0:
            confusion_scores.append((species, errors / total, errors, total))
    confusion_scores.sort(key=lambda x: x[1], reverse=True)
    for species, rate, errors, total in confusion_scores[:10]:
        print(f"  {rate:.0%} verward ({errors}/{total}) - {species}")

    # Histograms
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    axes[0].hist(accs, bins=20, edgecolor='black', color='steelblue')
    axes[0].axvline(accs.mean(), color='red', linestyle='--', label=f'Gem: {accs.mean():.1%}')
    axes[0].axvline(accs.median(), color='orange', linestyle='--', label=f'Med: {accs.median():.1%}')
    axes[0].set_xlabel('Test Accuracy')
    axes[0].set_ylabel('Aantal soorten')
    axes[0].set_title('Accuracy Verdeling v2')
    axes[0].legend()

    # Confusion-rate histogram
    if confusion_scores:
        rates = [s[1] for s in confusion_scores]
        axes[1].hist(rates, bins=20, edgecolor='black', color='coral')
        axes[1].set_xlabel('Verwarrings-percentage')
        axes[1].set_ylabel('Aantal soorten')
        axes[1].set_title('Type Verwarring per Soort')

    plt.tight_layout()
    plt.savefig(f'{DRIVE_BASE}/results_v2.png', dpi=100)
    plt.show()

# Failed species, grouped by status
failed_df = df[df['status'] != 'success']
if len(failed_df) > 0:
    print(f"\n\u274c Mislukt ({len(failed_df)}):")
    for status, group in failed_df.groupby('status'):
        print(f"  {status}: {', '.join(group['species'].tolist())}")

In [None]:
# Cel 12: HiDrive status & download naar Pi
#
# Modellen zijn al automatisch ge-upload tijdens training.
# Deze cel toont wat er op HiDrive staat en hoe je ze naar de Pi haalt.
#

print(f"{'='*60}")
print(f"\U0001f4c1 HIDRIVE MODEL STATUS")
print(f"{'='*60}")

# Toon modellen op HiDrive
print(f"\nModellen op HiDrive ({HIDRIVE_MODELS}):")
!rclone ls hidrive:{HIDRIVE_MODELS} 2>/dev/null | wc -l | xargs -I{} echo "  Totaal: {} bestanden"
!rclone size hidrive:{HIDRIVE_MODELS} 2>/dev/null

print(f"\nResultaten op HiDrive ({HIDRIVE_RESULTS}):")
!rclone ls hidrive:{HIDRIVE_RESULTS} 2>/dev/null

print(f"\n{'='*60}")
print(f"DOWNLOAD NAAR PI")
print(f"{'='*60}")
print(f"""
Draai dit commando op je Pi (emsn2-zolder) om de v2 modellen op te halen:

  # Maak map aan
  mkdir -p /mnt/nas-docker/emsn-vocalization/data/models

  # Download v2 modellen van HiDrive naar NAS
  rclone copy hidrive:{HIDRIVE_MODELS}/ \\
    /mnt/nas-docker/emsn-vocalization/data/models/ \\
    --progress --sftp-key-file ~/.ssh/id_ed25519_hidrive

  # Of configureer rclone eerst:
  # rclone config (type=sftp, host=sftp.hidrive.strato.com,
  #                user=ronnyclouddisk, key_file=~/.ssh/id_ed25519_hidrive)

  # Herstart vocalization enricher om nieuwe modellen te laden
  sudo systemctl restart vocalization-enricher

Klaar! De v2 modellen worden automatisch gebruikt (hogere prioriteit).
""")