In [None]:
import os
import re
import requests
import subprocess
from pathlib import Path
from tqdm import tqdm


def get_direct_file_link(mailru_file_url: str) -> str:
    """
    Преобразует публичную ссылку вида:
        https://cloud.mail.ru/public/<key>/<subkey>/<filename>
    в прямую ссылку на CDN, по которой можно скачать файл через wget или requests.

    Возвращает прямую ссылку для скачивания.
    """
    resp = requests.get(mailru_file_url)
    if resp.status_code != 200:
        raise RuntimeError(f"Ошибка {resp.status_code} при запросе {mailru_file_url}")

    match = re.search(r'dispatcher.*?weblink_get.*?url":"(.*?)"', resp.text)
    if not match:
        raise RuntimeError("Не удалось найти CDN ссылку в HTML Mail.ru")

    base_url = match.group(1)
    parts = mailru_file_url.strip("/").split("/")[-3:]
    return f"{base_url}/{parts[0]}/{parts[1]}/{parts[2]}"


def download_from_mailru(file_url: str, local_name: str, force: bool = False, show_progress: bool = True):
    """
    Скачивает файл с Mail.ru по публичной ссылке.

    Args:
        file_url: ссылка на файл в облаке Mail.ru.
        local_name: имя файла для сохранения.
        force: если True — перекачивает даже если файл уже есть.
        show_progress: показывать ли прогресс-бар.
    """
    local_path = Path(local_name)
    if local_path.exists() and not force:
        print(f"Файл {local_name} уже существует, пропускаем скачивание.")
        return

    direct = get_direct_file_link(file_url)
    print(f"Скачиваем {file_url} → {local_name}")

    with requests.get(direct, stream=True) as r:
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        block_size = 8192
        with open(local_name, "wb") as f, tqdm(
            total=total_size,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            desc=f"Downloading {local_name}",
            disable=not show_progress,
        ) as bar:
            for chunk in r.iter_content(block_size):
                f.write(chunk)
                bar.update(len(chunk))

    print(f"Файл {local_name} успешно скачан ({os.path.getsize(local_name)/1e6:.1f} MB).")

In [None]:
# Ссылки на данные по задаче
train_link = "https://cloud.mail.ru/public/Gsyr/8VxmbhAaZ/train_data.tar"
test_link  = "https://cloud.mail.ru/public/Gsyr/8VxmbhAaZ/test_data.tar"

In [None]:
# Если скорость загрузки низкая — это может быть связано с CDN.
# Попробуйте перезапустить ячейку: при новом соединении может попасться другой узел CDN,
# и загрузка обычно проходит быстрее (2-3 минуты при нормальном узле).
download_from_mailru(train_link, "train_data.tar")
download_from_mailru(test_link, "test_data.tar")

Скачиваем https://cloud.mail.ru/public/Gsyr/8VxmbhAaZ/train_data.tar → train_data.tar


Downloading train_data.tar: 100%|██████████| 2.30G/2.30G [10:06<00:00, 4.07MB/s]


Файл train_data.tar успешно скачан (2472.0 MB).
Скачиваем https://cloud.mail.ru/public/Gsyr/8VxmbhAaZ/test_data.tar → test_data.tar


Downloading test_data.tar: 100%|██████████| 707M/707M [01:12<00:00, 10.2MB/s]

Файл test_data.tar успешно скачан (741.0 MB).





In [None]:
# Распаковка
subprocess.run(["tar", "xf", "train_data.tar"], check=True)
subprocess.run(["tar", "xf", "test_data.tar"], check=True)
print("Готово.")

Готово.


In [None]:
folder_path = '/content/test_opus/audio'
files_before = len([f for f in Path(folder_path).rglob('*') if f.is_file()])
print(f"Файлов до удаления: {files_before}")

!find {folder_path} -type f -name "._*" -delete

files_after = len([f for f in Path(folder_path).rglob('*') if f.is_file()])
deleted = files_before - files_after

print(f"Удалено файлов: {deleted}")
print(f"Файлов после удаления: {files_after}")

Файлов до удаления: 54000
Удалено файлов: 27000
Файлов после удаления: 27000


In [None]:
folder_path = '/content/train_opus/audio'
files_before = len([f for f in Path(folder_path).rglob('*') if f.is_file()])
print(f"Файлов до удаления: {files_before}")

!find {folder_path} -type f -name "._*" -delete

files_after = len([f for f in Path(folder_path).rglob('*') if f.is_file()])
deleted = files_before - files_after

print(f"Удалено файлов: {deleted}")
print(f"Файлов после удаления: {files_after}")

Файлов до удаления: 180000
Удалено файлов: 90000
Файлов после удаления: 90000


## Import & pipeline starting

In [None]:
import os
import gc
import json
import random
from pathlib import Path
from typing import List, Tuple, Dict, Optional
from collections import defaultdict
import pickle
import shutil

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from scipy.stats import hmean

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import librosa

from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor

!pip -q install torch-audiomentations
from torch_audiomentations import Compose, Gain, PolarityInversion, AddColoredNoise, Shift, HighPassFilter, LowPassFilter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/59.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.6/59.6 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.5/48.5 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for julius (setup.py) ... [?25l[?25hdone
Using device: cuda


In [None]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()

## Data preparing

In [None]:
train_audio_path = '/content/train_opus/audio'
test_audio_path = '/content/test_opus/audio'

word_bounds_path = '/content/train_opus/word_bounds.json'

In [None]:
train_files = sorted(list(Path(train_audio_path).glob('*.opus')))
test_files = sorted(list(Path(test_audio_path).glob('*.opus')))

with open(word_bounds_path, 'r') as f:
    word_bounds = json.load(f)

print(f'Train files: {len(train_files)}')
print(f'Test files: {len(test_files)}')
print(f'Word bounds entries: {len(word_bounds)}')

pos_count = len([f for f in train_files if f.stem in word_bounds])
neg_count = len(train_files) - pos_count
print(f'Positive samples: {pos_count} ({pos_count/len(train_files)*100:.1f}%)')
print(f'Negative samples: {neg_count} ({neg_count/len(train_files)*100:.1f}%)')

Train files: 90000
Test files: 27000
Word bounds entries: 45000
Positive samples: 45000 (50.0%)
Negative samples: 45000 (50.0%)


In [None]:
pos_items = []
neg_items = []

for fpath in train_files:
    fid = fpath.stem
    if fid in word_bounds:
        start, end = word_bounds[fid]
        pos_items.append((str(fpath), float(start), float(end)))
    else:
        neg_items.append((str(fpath), None, None))

val_split = 0.12
pos_train, pos_val = train_test_split(pos_items, test_size=val_split, random_state=42)
neg_train, neg_val = train_test_split(neg_items, test_size=val_split, random_state=42)

train_items = pos_train + neg_train
val_items = pos_val + neg_val

random.shuffle(train_items)
random.shuffle(val_items)

print(f'Train: {len(train_items)} (pos: {len(pos_train)}, neg: {len(neg_train)})')
print(f'Val: {len(val_items)} (pos: {len(pos_val)}, neg: {len(neg_val)})')

Train: 79200 (pos: 39600, neg: 39600)
Val: 10800 (pos: 5400, neg: 5400)


## Audio utils & augmentations

In [None]:
def load_audio(path: str, sr: int = 16000) -> np.ndarray:
    try:
        wav, orig_sr = torchaudio.load(path)
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)
        wav = wav.squeeze(0)
        if orig_sr != sr:
            wav = torchaudio.functional.resample(wav, orig_sr, sr)
        wav = wav.numpy()
    except:
        try:
            wav, orig_sr = librosa.load(path, sr=sr, mono=True)
        except Exception as e:
            print(f'Error loading {path}: {e}')
            return np.zeros(sr, dtype=np.float32)

    return wav.astype(np.float32)

def normalize_audio(wav: np.ndarray, method='peak') -> np.ndarray:
    if method == 'peak':
        peak = np.abs(wav).max()
        if peak > 1e-8:
            wav = wav / peak
    elif method == 'rms':
        rms = np.sqrt(np.mean(wav ** 2))
        if rms > 1e-8:
            wav = wav / (rms * 10)
            wav = np.clip(wav, -1, 1)
    return wav

In [None]:
import numpy as np
from tqdm import tqdm

def analyze_audio_lengths(file_list, sr=16000):
    """Анализ длин аудиофайлов"""
    durations = []

    for path in tqdm(file_list, desc="Analyzing"):
        wav = load_audio(path, sr)
        durations.append(len(wav) / sr)

    durations = np.array(durations)

    print(f"{'='*60}")
    print("СТАТИСТИКА ДЛИТЕЛЬНОСТИ АУДИО")
    print(f"{'='*60}")
    print(f"Количество файлов: {len(durations)}")
    print(f"Минимум: {durations.min():.2f} сек")
    print(f"Максимум: {durations.max():.2f} сек")
    print(f"Среднее: {durations.mean():.2f} сек")
    print(f"Медиана: {np.median(durations):.2f} сек")
    print(f"\nПерцентили:")
    print(f"  50%: {np.percentile(durations, 50):.2f} сек")
    print(f"  75%: {np.percentile(durations, 75):.2f} сек")
    print(f"  90%: {np.percentile(durations, 90):.2f} сек")
    print(f"  95%: {np.percentile(durations, 95):.2f} сек")
    print(f"  99%: {np.percentile(durations, 99):.2f} сек")

    # Сколько файлов короче 6 секунд
    short = (durations < 6.0).sum()
    print(f"\nФайлов короче 6 сек: {short} ({short/len(durations)*100:.1f}%)")

    return durations

# Анализ train
print("TRAIN:")
train_durations = analyze_audio_lengths(random.sample(train_files, 10000))

print("\n" + "="*60)

# Анализ test
print("\nTEST:")
test_durations = analyze_audio_lengths(random.sample(test_files, 10000))

TRAIN:


Analyzing: 100%|██████████| 10000/10000 [02:48<00:00, 59.18it/s]


СТАТИСТИКА ДЛИТЕЛЬНОСТИ АУДИО
Количество файлов: 10000
Минимум: 4.00 сек
Максимум: 4.00 сек
Среднее: 4.00 сек
Медиана: 4.00 сек

Перцентили:
  50%: 4.00 сек
  75%: 4.00 сек
  90%: 4.00 сек
  95%: 4.00 сек
  99%: 4.00 сек

Файлов короче 6 сек: 10000 (100.0%)


TEST:


Analyzing: 100%|██████████| 10000/10000 [02:49<00:00, 59.14it/s]

СТАТИСТИКА ДЛИТЕЛЬНОСТИ АУДИО
Количество файлов: 10000
Минимум: 4.00 сек
Максимум: 4.00 сек
Среднее: 4.00 сек
Медиана: 4.00 сек

Перцентили:
  50%: 4.00 сек
  75%: 4.00 сек
  90%: 4.00 сек
  95%: 4.00 сек
  99%: 4.00 сек

Файлов короче 6 сек: 10000 (100.0%)





In [None]:
# class AudioAugmentation:
#     def __init__(self, sample_rate=16000):
#         self.sample_rate = sample_rate

#         self.augment = Compose(
#             transforms=[
#                 Gain(min_gain_in_db=-15.0, max_gain_in_db=7.5, p=0.6),
#                 PolarityInversion(p=0.6),
#                 AddColoredNoise(min_snr_in_db=2.5, max_snr_in_db=30.0, min_f_decay=-2.0, max_f_decay=2.0, p=0.6),
#                 Shift(min_shift=-0.75, max_shift=0.75, shift_unit="fraction", rollover=True, p=0.6),
#                 HighPassFilter(min_cutoff_freq=20.0, max_cutoff_freq=400.0, p=0.4),
#                 LowPassFilter(min_cutoff_freq=2000.0, max_cutoff_freq=7500.0, p=0.4),
#             ]
#         )

#     def __call__(self, wav):
#         if isinstance(wav, np.ndarray):
#             wav = torch.from_numpy(wav).float()

#         wav = wav.unsqueeze(0).unsqueeze(0)
#         augmented = self.augment(wav, sample_rate=self.sample_rate)
#         augmented = augmented.squeeze(0).squeeze(0)

#         if isinstance(augmented, torch.Tensor):
#             augmented = augmented.numpy()

#         return augmented.astype(np.float32)

# augmentation = AudioAugmentation(sample_rate=16000)

# class AudioAugmentation:
#     def __init__(self, sample_rate=16000):
#         self.sample_rate = sample_rate

#         self.augment = Compose(
#             transforms=[
#                 # Gain: умеренный диапазон
#                 Gain(min_gain_in_db=-10.0, max_gain_in_db=6.0, p=0.5),

#                 # Polarity: безопасно
#                 PolarityInversion(p=0.5),

#                 # Noise: не слишком грязный
#                 AddColoredNoise(
#                     min_snr_in_db=8.0,  # не ниже 8 dB
#                     max_snr_in_db=35.0,
#                     min_f_decay=-2.0,
#                     max_f_decay=2.0,
#                     p=0.5
#                 ),

#                 # Shift: КРИТИЧНО - не больше 20%!
#                 Shift(
#                     min_shift=-0.2,  # ❗ НЕ -0.75!
#                     max_shift=0.2,   # ❗ НЕ 0.75!
#                     shift_unit="fraction",
#                     rollover=True,
#                     p=0.4
#                 ),

#                 # Фильтры: умеренно
#                 HighPassFilter(min_cutoff_freq=20.0, max_cutoff_freq=300.0, p=0.3),
#                 LowPassFilter(min_cutoff_freq=2500.0, max_cutoff_freq=7500.0, p=0.3),
#             ]
#         )

#     def __call__(self, wav):
#         if isinstance(wav, np.ndarray):
#             wav = torch.from_numpy(wav).float()

#         wav = wav.unsqueeze(0).unsqueeze(0)
#         augmented = self.augment(wav, sample_rate=self.sample_rate)
#         augmented = augmented.squeeze(0).squeeze(0)

#         if isinstance(augmented, torch.Tensor):
#             augmented = augmented.numpy()

#         return augmented.astype(np.float32)

# augmentation = AudioAugmentation(sample_rate=16000)

from torch_audiomentations import Compose, Gain, PolarityInversion, AddColoredNoise, Shift, HighPassFilter, LowPassFilter

class AudioAugmentation:
    """
    Мощные но адекватные аугментации
    Комбинация torch_audiomentations + librosa для максимальной эффективности
    """
    def __init__(self, sample_rate=16000, use_heavy_aug=True):
        self.sample_rate = sample_rate
        self.use_heavy_aug = use_heavy_aug

        # Базовые аугментации через torch_audiomentations (быстро на GPU)
        self.augment = Compose(
            transforms=[
                # Gain: чуть шире чем у советчика
                Gain(
                    min_gain_in_db=-12.0,  # 0.25x громкость
                    max_gain_in_db=8.0,    # 2.5x громкость
                    p=0.6
                ),

                # Polarity Inversion: безопасная аугментация
                PolarityInversion(p=0.5),

                # Colored Noise: чуть агрессивнее советчика
                AddColoredNoise(
                    min_snr_in_db=10.0,    # чуть ниже чем 15 у советчика
                    max_snr_in_db=35.0,    # выше чем 30
                    min_f_decay=-2.0,      # розовый шум
                    max_f_decay=2.0,       # коричневый шум
                    p=0.6
                ),

                # Time Shift: компромисс между 0.1 и 0.2
                Shift(
                    min_shift=-0.15,       # 15% влево
                    max_shift=0.15,        # 15% вправо
                    shift_unit="fraction",
                    rollover=True,
                    p=0.5
                ),

                # High Pass Filter: убираем низкие частоты
                HighPassFilter(
                    min_cutoff_freq=20.0,
                    max_cutoff_freq=400.0,  # чуть агрессивнее
                    p=0.35
                ),

                # Low Pass Filter: убираем высокие частоты
                LowPassFilter(
                    min_cutoff_freq=2000.0,
                    max_cutoff_freq=7500.0,
                    p=0.35
                ),
            ]
        )

    def apply_time_stretch(self, wav, rate=0.15):
        """Time stretching через librosa (медленно, но эффективно)"""
        if random.random() > 0.4:  # 40% вероятность
            try:
                stretch_factor = random.uniform(1.0 - rate, 1.0 + rate)

                if isinstance(wav, torch.Tensor):
                    wav_np = wav.cpu().numpy()
                else:
                    wav_np = wav

                stretched = librosa.effects.time_stretch(wav_np, rate=stretch_factor)

                # Приводим к исходной длине
                if len(stretched) > len(wav_np):
                    stretched = stretched[:len(wav_np)]
                else:
                    stretched = np.pad(stretched, (0, max(0, len(wav_np) - len(stretched))))

                return stretched
            except:
                return wav if isinstance(wav, np.ndarray) else wav.cpu().numpy()

        return wav if isinstance(wav, np.ndarray) else wav.cpu().numpy()

    def apply_pitch_shift(self, wav, n_steps_range=(-2, 2)):
        """Pitch shifting через librosa"""
        if random.random() > 0.5:  # 50% вероятность
            try:
                n_steps = random.uniform(n_steps_range[0], n_steps_range[1])

                if abs(n_steps) < 0.1:
                    return wav

                if isinstance(wav, torch.Tensor):
                    wav_np = wav.cpu().numpy()
                else:
                    wav_np = wav

                shifted = librosa.effects.pitch_shift(
                    wav_np,
                    sr=self.sample_rate,
                    n_steps=n_steps
                )

                return shifted
            except:
                return wav if isinstance(wav, np.ndarray) else wav.cpu().numpy()

        return wav if isinstance(wav, np.ndarray) else wav.cpu().numpy()

    def apply_background_noise(self, wav, noise_level_range=(0.005, 0.03)):
        """Добавление фонового шума"""
        if random.random() > 0.6:  # 40% вероятность
            if isinstance(wav, torch.Tensor):
                wav_np = wav.cpu().numpy()
            else:
                wav_np = wav

            std = np.std(wav_np)
            if std < 1e-6:
                return wav_np

            noise_level = random.uniform(*noise_level_range)
            noise = np.random.randn(len(wav_np)) * noise_level * std

            return wav_np + noise

        return wav if isinstance(wav, np.ndarray) else wav.cpu().numpy()

    def __call__(self, wav):
        """
        Применение всех аугментаций

        Args:
            wav: numpy array или torch tensor

        Returns:
            numpy array float32
        """
        # Конвертируем в torch для torch_audiomentations
        if isinstance(wav, np.ndarray):
            wav_torch = torch.from_numpy(wav).float()
        else:
            wav_torch = wav.float()

        # Применяем torch_audiomentations (быстро)
        wav_torch = wav_torch.unsqueeze(0).unsqueeze(0)  # [1, 1, T]
        augmented = self.augment(wav_torch, sample_rate=self.sample_rate)
        augmented = augmented.squeeze(0).squeeze(0)  # [T]

        # Конвертируем в numpy для librosa
        augmented_np = augmented.cpu().numpy()

        # Применяем дополнительные аугментации если нужно
        if self.use_heavy_aug:
            # Time stretch (медленная аугментация)
            augmented_np = self.apply_time_stretch(augmented_np, rate=0.15)

            # Pitch shift (медленная аугментация)
            augmented_np = self.apply_pitch_shift(augmented_np, n_steps_range=(-1.5, 1.5))

            # Background noise
            augmented_np = self.apply_background_noise(augmented_np, noise_level_range=(0.01, 0.04))

        return augmented_np.astype(np.float32)

# Создание объекта аугментации
augmentation = AudioAugmentation(sample_rate=16000, use_heavy_aug=True)

## Datasets & utils

In [None]:
class KWSDataset(Dataset):
    def __init__(self, items, segment_samples, sr=16000, augment=None):
        self.items = items
        self.segment_samples = segment_samples
        self.sr = sr
        self.augment = augment

    def __len__(self):
        return len(self.items)

    def extract_positive_segment(self, wav, start_sec, end_sec):
        start_idx = int(start_sec * self.sr)
        end_idx = int(end_sec * self.sr)

        phrase_len = end_idx - start_idx

        if phrase_len >= self.segment_samples:
            center = (start_idx + end_idx) // 2
            left = max(0, center - self.segment_samples // 2)
            right = min(len(wav), left + self.segment_samples)
            left = right - self.segment_samples
        else:
            context_total = self.segment_samples - phrase_len
            context_left = random.randint(0, context_total)
            context_right = context_total - context_left

            left = max(0, start_idx - context_left)
            right = min(len(wav), end_idx + context_right)

            if right - left < self.segment_samples:
                if left == 0:
                    right = min(len(wav), left + self.segment_samples)
                else:
                    left = max(0, right - self.segment_samples)

        segment = wav[left:right]

        if len(segment) < self.segment_samples:
            pad = self.segment_samples - len(segment)
            segment = np.pad(segment, (0, pad), mode='constant')
        elif len(segment) > self.segment_samples:
            segment = segment[:self.segment_samples]

        return segment

    def extract_negative_segment(self, wav):
        if len(wav) <= self.segment_samples:
            segment = wav
            if len(segment) < self.segment_samples:
                pad = self.segment_samples - len(segment)
                segment = np.pad(segment, (0, pad), mode='constant')
        else:
            start = random.randint(0, len(wav) - self.segment_samples)
            segment = wav[start:start + self.segment_samples]

        return segment

    def __getitem__(self, idx):
        path, start_sec, end_sec = self.items[idx]

        wav = load_audio(path, self.sr)
        wav = normalize_audio(wav)

        label = 1 if start_sec is not None else 0

        if label == 1:
            segment = self.extract_positive_segment(wav, start_sec, end_sec)
        else:
            segment = self.extract_negative_segment(wav)

        if self.augment is not None:
            segment = self.augment(segment)

        segment = normalize_audio(segment)

        return segment, label

In [None]:
class KWSTestDatasetWindowed(Dataset):
    def __init__(self, file_paths, window_samples, hop_samples, sr=16000):
        self.file_paths = file_paths
        self.window_samples = window_samples
        self.hop_samples = hop_samples
        self.sr = sr

        self.samples = []

        print("Preparing test windows...")
        for path in tqdm(file_paths):
            file_id = Path(path).stem
            wav = load_audio(path, sr)
            wav = normalize_audio(wav)

            if len(wav) <= window_samples:
                if len(wav) < window_samples:
                    pad = window_samples - len(wav)
                    wav = np.pad(wav, (0, pad), mode='constant')
                self.samples.append((wav, file_id, 0))
            else:
                num_windows = (len(wav) - window_samples) // hop_samples + 1

                for i in range(num_windows):
                    start = i * hop_samples
                    end = start + window_samples

                    if end > len(wav):
                        start = len(wav) - window_samples
                        end = len(wav)

                    segment = wav[start:end]
                    self.samples.append((segment, file_id, i))

                    if end >= len(wav):
                        break

        print(f"Total windows created: {len(self.samples)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        segment, file_id, window_idx = self.samples[idx]
        return segment, file_id, window_idx

In [None]:
MODEL_NAME = 'internalhell/wav2vec2-large-ru-5ep'
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)

# def collate_train_val(batch):
#     wavs, labels = zip(*batch)

#     inputs = feature_extractor(
#         list(wavs),
#         sampling_rate=16000,
#         return_tensors="pt",
#         padding=True,
#         max_length=24000,
#         truncation=True
#     )

#     labels = torch.tensor(labels, dtype=torch.long)
#     return inputs.input_values, labels

# def collate_test_windows(batch):
#     segments, file_ids, window_indices = zip(*batch)

#     inputs = feature_extractor(
#         list(segments),
#         sampling_rate=16000,
#         return_tensors="pt",
#         padding=True,
#         max_length=24000,
#         truncation=True
#     )

#     return inputs.input_values, list(file_ids), list(window_indices)

def collate_train_val(batch):
    wavs, labels = zip(*batch)

    inputs = feature_extractor(
        list(wavs),
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
        max_length=64000,  # 6 СЕКУНД!
        truncation=True
    )

    labels = torch.tensor(labels, dtype=torch.long)
    return inputs.input_values, labels

def collate_test_windows(batch):
    segments, file_ids, window_indices = zip(*batch)

    inputs = feature_extractor(
        list(segments),
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
        max_length=64000,  # 6 СЕКУНД!
        truncation=True
    )

    return inputs.input_values, list(file_ids), list(window_indices)

preprocessor_config.json:   0%|          | 0.00/256 [00:00<?, ?B/s]

In [None]:
sample_rate = 16000
segment_duration = 4
segment_samples = int(sample_rate * segment_duration)

train_dataset = KWSDataset(
    train_items,
    segment_samples=segment_samples,
    sr=sample_rate,
    augment=augmentation
)

val_dataset = KWSDataset(
    val_items,
    segment_samples=segment_samples,
    sr=sample_rate,
    augment=None
)

batch_size = 64
num_workers = 2

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=collate_train_val,
    pin_memory=False
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate_train_val,
    pin_memory=False
)

print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')

Train batches: 1238
Val batches: 169


In [None]:
test_window_duration = 4
test_window_samples = int(sample_rate * test_window_duration)
test_hop_duration = 1
test_hop_samples = int(sample_rate * test_hop_duration)

test_dataset_windowed = KWSTestDatasetWindowed(
    [str(f) for f in test_files],
    window_samples=test_window_samples,
    hop_samples=test_hop_samples,
    sr=sample_rate
)

test_loader_windowed = DataLoader(
    test_dataset_windowed,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_test_windows,
    pin_memory=False
)

print(f'Test batches: {len(test_loader_windowed)}')

Preparing test windows...


  0%|          | 0/27000 [00:00<?, ?it/s]

Total windows created: 27000
Test batches: 422


## Model & training utils

In [None]:
def calculate_metrics(preds, labels, num_pos, num_neg):
    correct = (preds == labels).sum()
    total = len(labels)
    accuracy = correct / total if total > 0 else 0

    tp = ((preds == 1) & (labels == 1)).sum()
    fp = ((preds == 1) & (labels == 0)).sum()
    fn = ((preds == 0) & (labels == 1)).sum()
    tn = ((preds == 0) & (labels == 0)).sum()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    frr = fn / num_pos if num_pos > 0 else 0
    far = fp / num_neg if num_neg > 0 else 0

    score_1_frr = 1 - frr
    score_1_far = 1 - far

    if score_1_frr > 0 and score_1_far > 0:
        competition_score = hmean([score_1_frr, score_1_far])
    else:
        competition_score = 0.0

    return {
        'accuracy': accuracy,
        'f1': f1,
        'competition_score': competition_score,
        'tp': tp,
        'fp': fp,
        'fn': fn,
        'tn': tn
    }

In [None]:
def train_epoch(model, loader, optimizer, criterion, device, num_pos, num_neg):
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []

    pbar = tqdm(loader, desc='Training', leave=False)
    for input_values, labels in pbar:
        input_values = input_values.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        logits = model(input_values)
        loss = criterion(logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_loss = total_loss / len(loader)
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    metrics = calculate_metrics(all_preds, all_labels, num_pos, num_neg)
    metrics['loss'] = avg_loss

    return metrics

# def train_epoch(model, loader, optimizer, criterion, device, num_pos, num_neg, accumulation_steps=1):
#     model.train()
#     total_loss = 0
#     all_preds = []
#     all_labels = []

#     optimizer.zero_grad()

#     pbar = tqdm(loader, desc='Training', leave=False)
#     for batch_idx, (input_values, labels) in enumerate(pbar):
#         input_values = input_values.to(device)
#         labels = labels.to(device)

#         logits = model(input_values)
#         loss = criterion(logits, labels)
#         loss = loss / accumulation_steps  # нормализуем лосс

#         loss.backward()

#         if (batch_idx + 1) % accumulation_steps == 0:
#             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#             optimizer.step()
#             optimizer.zero_grad()

#         total_loss += loss.item() * accumulation_steps

#         preds = torch.argmax(logits, dim=1)
#         all_preds.extend(preds.cpu().numpy())
#         all_labels.extend(labels.cpu().numpy())

#         pbar.set_postfix({'loss': f'{loss.item() * accumulation_steps:.4f}'})

#     if (batch_idx + 1) % accumulation_steps != 0:
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#         optimizer.step()
#         optimizer.zero_grad()

#     avg_loss = total_loss / len(loader)
#     all_preds = np.array(all_preds)
#     all_labels = np.array(all_labels)
#     metrics = calculate_metrics(all_preds, all_labels, num_pos, num_neg)
#     metrics['loss'] = avg_loss

#     return metrics

In [None]:
def validate(model, loader, criterion, device, num_pos, num_neg):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    pbar = tqdm(loader, desc='Validation', leave=False)
    with torch.no_grad():
        for input_values, labels in pbar:
            input_values = input_values.to(device)
            labels = labels.to(device)

            logits = model(input_values)
            loss = criterion(logits, labels)

            total_loss += loss.item()

            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_loss = total_loss / len(loader)
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    metrics = calculate_metrics(all_preds, all_labels, num_pos, num_neg)
    metrics['loss'] = avg_loss

    return metrics

In [None]:
class AttentionPooling(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.Tanh(),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, hidden_states):
        attention_weights = self.attention(hidden_states)
        attention_weights = F.softmax(attention_weights, dim=1)
        pooled = torch.sum(hidden_states * attention_weights, dim=1)
        return pooled

In [None]:
class Wav2Vec2ForKWS(nn.Module):
    def __init__(self, model_name, num_labels=2, freeze_feature_extractor=True, freeze_encoder_layers=0):
        super().__init__()

        self.wav2vec2 = Wav2Vec2Model.from_pretrained(model_name)

        if freeze_feature_extractor:
            for param in self.wav2vec2.feature_extractor.parameters():
                param.requires_grad = False

        if freeze_encoder_layers > 0:
            for layer in self.wav2vec2.encoder.layers[:freeze_encoder_layers]:
                for param in layer.parameters():
                    param.requires_grad = False

        hidden_size = self.wav2vec2.config.hidden_size

        self.attention_pooling = AttentionPooling(hidden_size)

        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_labels)
        )

    def forward(self, input_values):
        outputs = self.wav2vec2(input_values)
        hidden_states = outputs.last_hidden_state

        pooled = self.attention_pooling(hidden_states)

        logits = self.classifier(pooled)
        return logits

## Training

In [None]:
freeze_feature_extractor = True
freeze_encoder_layers = 6

model = Wav2Vec2ForKWS(
    model_name=MODEL_NAME,
    num_labels=2,
    freeze_feature_extractor=freeze_feature_extractor,
    freeze_encoder_layers=freeze_encoder_layers
).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')
print(f'Trainable percentage: {trainable_params / total_params:.2f}')

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Total parameters: 316,261,635
Trainable parameters: 236,474,115
Trainable percentage: 0.75


In [None]:
# num_epochs = 5
# learning_rate = 3e-4
# weight_decay = 0.0125

# optimizer = torch.optim.AdamW(
#     filter(lambda p: p.requires_grad, model.parameters()),
#     lr=learning_rate,
#     weight_decay=weight_decay
# )

# criterion = nn.CrossEntropyLoss()

# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#     optimizer,
#     T_max=num_epochs,
#     eta_min=6e-7
# )

def get_optimizer_params(model, base_lr=3e-4):
    """
    Разные LR для разных частей модели
    """
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]

    optimizer_params = [
        # # Feature Extractor - ОЧЕНЬ маленький LR
        # {
        #     "params": [p for n, p in model.wav2vec2.feature_extractor.named_parameters()
        #               if p.requires_grad and not any(nd in n for nd in no_decay)],
        #     "lr": base_lr * 0.01,  # 3e-6
        #     "weight_decay": 0.01,
        # },
        # {
        #     "params": [p for n, p in model.wav2vec2.feature_extractor.named_parameters()
        #               if p.requires_grad and any(nd in n for nd in no_decay)],
        #     "lr": base_lr * 0.01,  # 3e-6
        #     "weight_decay": 0.0,
        # },

        # Encoder Layers - маленький LR
        {
            "params": [p for n, p in model.wav2vec2.encoder.named_parameters()
                      if p.requires_grad and not any(nd in n for nd in no_decay)],
            "lr": base_lr * 0.5,  # 3e-5
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.wav2vec2.encoder.named_parameters()
                      if p.requires_grad and any(nd in n for nd in no_decay)],
            "lr": base_lr * 0.5,  # 3e-5
            "weight_decay": 0.0,
        },

        # Classifier - полный LR
        {
            "params": [p for n, p in model.classifier.named_parameters()
                      if not any(nd in n for nd in no_decay)],
            "lr": base_lr,  # 3e-4
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.classifier.named_parameters()
                      if any(nd in n for nd in no_decay)],
            "lr": base_lr,  # 3e-4
            "weight_decay": 0.0,
        },
    ]

    return optimizer_params

# Использование
num_epochs = 5
base_lr = 3e-4

optimizer_params = get_optimizer_params(model, base_lr=base_lr)

optimizer = torch.optim.AdamW(optimizer_params)

criterion = nn.CrossEntropyLoss()

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=1e-7
)

print(f"Feature Extractor LR: {base_lr * 0.01}")
print(f"Encoder LR: {base_lr * 0.1}")
print(f"Classifier LR: {base_lr}")

Feature Extractor LR: 2.9999999999999997e-06
Encoder LR: 2.9999999999999997e-05
Classifier LR: 0.0003


In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
checkpoint_dir = MODEL_NAME[MODEL_NAME.find('/')+1:]
os.makedirs(checkpoint_dir, exist_ok=True)

history = {
    'train_loss': [], 'train_acc': [], 'train_f1': [], 'train_score': [],
    'val_loss': [], 'val_acc': [], 'val_f1': [], 'val_score': []
}

best_score = 0
best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')

train_num_pos = len(pos_train)
train_num_neg = len(neg_train)
val_num_pos = len(pos_val)
val_num_neg = len(neg_val)

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch+1}/{num_epochs}')

    train_metrics = train_epoch(
        model, train_loader, optimizer, criterion, device,
        train_num_pos, train_num_neg
    )
    val_metrics = validate(
        model, val_loader, criterion, device,
        val_num_pos, val_num_neg
    )

    history['train_loss'].append(train_metrics['loss'])
    history['train_acc'].append(train_metrics['accuracy'])
    history['train_f1'].append(train_metrics['f1'])
    history['train_score'].append(train_metrics['competition_score'])
    history['val_loss'].append(val_metrics['loss'])
    history['val_acc'].append(val_metrics['accuracy'])
    history['val_f1'].append(val_metrics['f1'])
    history['val_score'].append(val_metrics['competition_score'])

    print(f"\nTrain Metrics:")
    print(f"  Loss: {train_metrics['loss']:.4f} | Acc: {train_metrics['accuracy']:.4f} | F1: {train_metrics['f1']:.4f}")
    print(f"  Competition Score: {train_metrics['competition_score']:.4f}")

    print(f"\nVal Metrics:")
    print(f"  Loss: {val_metrics['loss']:.4f} | Acc: {val_metrics['accuracy']:.4f} | F1: {val_metrics['f1']:.4f}")
    print(f"  Competition Score: {val_metrics['competition_score']:.4f}")

    scheduler.step()

    torch.cuda.empty_cache()

    checkpoint_path = os.path.join(checkpoint_dir, f'epoch_{epoch+1}.pth')
    torch.save(model.state_dict(), checkpoint_path)
    print(f'\nCheckpoint saved: {checkpoint_path}')

    if val_metrics['competition_score'] > best_score:
        best_score = val_metrics['competition_score']
        torch.save(model.state_dict(), best_model_path)
        print(f'Best model saved with Competition Score: {best_score:.4f}')

print(f'\nTraining completed!')
print(f'Best validation Competition Score: {best_score:.4f}')


Epoch 1/5


Training:   0%|          | 0/1238 [00:00<?, ?it/s]

Validation:   0%|          | 0/169 [00:00<?, ?it/s]


Train Metrics:
  Loss: 0.2899 | Acc: 0.8789 | F1: 0.8793
  Competition Score: 0.8788

Val Metrics:
  Loss: 0.0921 | Acc: 0.9678 | F1: 0.9682
  Competition Score: 0.9676

Checkpoint saved: wav2vec2-large-ru-5ep/epoch_1.pth
Best model saved with Competition Score: 0.9676

Epoch 2/5


Training:   0%|          | 0/1238 [00:00<?, ?it/s]

Validation:   0%|          | 0/169 [00:00<?, ?it/s]


Train Metrics:
  Loss: 0.2243 | Acc: 0.9102 | F1: 0.9104
  Competition Score: 0.9102

Val Metrics:
  Loss: 0.0862 | Acc: 0.9716 | F1: 0.9720
  Competition Score: 0.9713

Checkpoint saved: wav2vec2-large-ru-5ep/epoch_2.pth
Best model saved with Competition Score: 0.9713

Epoch 3/5


Training:   0%|          | 0/1238 [00:00<?, ?it/s]

Validation:   0%|          | 0/169 [00:00<?, ?it/s]


Train Metrics:
  Loss: 0.1951 | Acc: 0.9225 | F1: 0.9227
  Competition Score: 0.9225

Val Metrics:
  Loss: 0.0762 | Acc: 0.9728 | F1: 0.9732
  Competition Score: 0.9726

Checkpoint saved: wav2vec2-large-ru-5ep/epoch_3.pth
Best model saved with Competition Score: 0.9726

Epoch 4/5


Training:   0%|          | 0/1238 [00:00<?, ?it/s]

Validation:   0%|          | 0/169 [00:00<?, ?it/s]


Train Metrics:
  Loss: 0.1715 | Acc: 0.9330 | F1: 0.9333
  Competition Score: 0.9330

Val Metrics:
  Loss: 0.0701 | Acc: 0.9750 | F1: 0.9753
  Competition Score: 0.9749

Checkpoint saved: wav2vec2-large-ru-5ep/epoch_4.pth
Best model saved with Competition Score: 0.9749

Epoch 5/5


Training:   0%|          | 0/1238 [00:00<?, ?it/s]

Validation:   0%|          | 0/169 [00:00<?, ?it/s]


Train Metrics:
  Loss: 0.1580 | Acc: 0.9370 | F1: 0.9372
  Competition Score: 0.9370

Val Metrics:
  Loss: 0.0690 | Acc: 0.9770 | F1: 0.9771
  Competition Score: 0.9770

Checkpoint saved: wav2vec2-large-ru-5ep/epoch_5.pth
Best model saved with Competition Score: 0.9770

Training completed!
Best validation Competition Score: 0.9770


In [None]:
best_sd = torch.load(best_model_path)
model.load_state_dict(best_sd)
model.eval()
print()




## Inference utils

In [None]:
def generate_predictions(all_predictions, aggregation='max', threshold=0.5, submission_name=None):
    predictions = []

    for file_id, probs_list in all_predictions.items():
        if aggregation == 'max':
            final_prob = max(probs_list)
        elif aggregation == 'mean':
            final_prob = np.mean(probs_list)
        elif aggregation == 'quantile_75':
            final_prob = np.percentile(probs_list, 75)
        elif aggregation == 'quantile_90':
            final_prob = np.percentile(probs_list, 90)
        elif aggregation == 'quantile_95':
            final_prob = np.percentile(probs_list, 95)
        else:
            final_prob = max(probs_list)

        predictions.append({
            'id': file_id,
            'prob': final_prob
        })

    submission_df = pd.DataFrame(predictions)
    submission_df['label'] = (submission_df['prob'] >= threshold).astype(int)

    if submission_name is not None:
        submission_final = submission_df[['id', 'label']]
        submission_final.to_csv(submission_name, index=False)
        print(f"✓ Submission saved: {submission_name}")

    return submission_df

In [None]:
def calculate_score_from_predictions(submission_df, ground_truth, num_pos, num_neg):
    """
    Вычисляет метрику соревнования по предсказаниям

    Args:
        submission_df: DataFrame с колонками ['id', 'label']
        ground_truth: dict {file_id: true_label}
        num_pos: количество позитивных примеров
        num_neg: количество негативных примеров

    Returns:
        dict с метриками
    """
    preds = []
    labels = []

    for _, row in submission_df.iterrows():
        file_id = row['id']
        if file_id in ground_truth:
            preds.append(row['label'])
            labels.append(ground_truth[file_id])

    preds = np.array(preds)
    labels = np.array(labels)

    metrics = calculate_metrics(preds, labels, num_pos, num_neg)

    return metrics

## Test aggregations & threshold on validation

In [None]:
val_dataset_windowed = KWSTestDatasetWindowed(
    [item[0] for item in val_items],
    window_samples=test_window_samples,
    hop_samples=test_hop_samples,
    sr=sample_rate
)

val_loader_windowed = DataLoader(
    val_dataset_windowed,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_test_windows,
    pin_memory=True
)

Preparing test windows...


  0%|          | 0/10800 [00:00<?, ?it/s]

Total windows created: 10800


In [None]:
val_all_predictions = defaultdict(list)

with torch.no_grad(), torch.cuda.amp.autocast():
    for input_values, file_ids, window_indices in tqdm(val_loader_windowed, desc='Validation inference'):
        input_values = input_values.to(device)

        logits = model(input_values)
        probs = F.softmax(logits, dim=1)
        probs_pos = probs[:, 1].cpu().numpy()

        for file_id, prob in zip(file_ids, probs_pos):
            val_all_predictions[file_id].append(float(prob))

val_ground_truth = {}
for path, start_sec, end_sec in val_items:
    file_id = Path(path).stem
    label = 1 if start_sec is not None else 0
    val_ground_truth[file_id] = label

print(f"✓ Validation predictions ready for {len(val_all_predictions)} files")

Validation inference:   0%|          | 0/169 [00:00<?, ?it/s]

✓ Validation predictions ready for 10800 files


In [None]:
aggregation_methods = ['max', 'mean', 'quantile_75', 'quantile_90', 'quantile_95']
thresholds = np.arange(0.05, 0.96, 0.05)

tuning_results = []

for agg_method in aggregation_methods:
    print(f"\nTesting aggregation: {agg_method}")

    best_threshold = 0.5
    best_score = 0

    for threshold in tqdm(thresholds, desc=f"  {agg_method}", leave=False):
        submission_df = generate_predictions(
            val_all_predictions,
            aggregation=agg_method,
            threshold=threshold,
            submission_name=None
        )

        metrics = calculate_score_from_predictions(
            submission_df,
            val_ground_truth,
            val_num_pos,
            val_num_neg
        )

        tuning_results.append({
            'aggregation': agg_method,
            'threshold': round(threshold, 2),
            'competition_score': metrics['competition_score'],
            'accuracy': metrics['accuracy'],
            'f1': metrics['f1']
        })

        if metrics['competition_score'] > best_score:
            best_score = metrics['competition_score']
            best_threshold = threshold

    print(f"  Best: threshold={best_threshold:.2f}, score={best_score:.4f}")

tuning_df = pd.DataFrame(tuning_results)
tuning_df.to_csv('tuning_results.csv', index=False)

print("\n✓ Tuning complete! Results saved to 'tuning_results.csv'")
tuning_df.head(45)


Testing aggregation: max


  max:   0%|          | 0/19 [00:00<?, ?it/s]

  Best: threshold=0.55, score=0.9774

Testing aggregation: mean


  mean:   0%|          | 0/19 [00:00<?, ?it/s]

  Best: threshold=0.55, score=0.9774

Testing aggregation: quantile_75


  quantile_75:   0%|          | 0/19 [00:00<?, ?it/s]

  Best: threshold=0.55, score=0.9774

Testing aggregation: quantile_90


  quantile_90:   0%|          | 0/19 [00:00<?, ?it/s]

  Best: threshold=0.55, score=0.9774

Testing aggregation: quantile_95


  quantile_95:   0%|          | 0/19 [00:00<?, ?it/s]

  Best: threshold=0.55, score=0.9774

✓ Tuning complete! Results saved to 'tuning_results.csv'


Unnamed: 0,aggregation,threshold,competition_score,accuracy,f1
0,max,0.05,0.960671,0.961852,0.963096
1,max,0.1,0.96722,0.96787,0.968657
2,max,0.15,0.969927,0.97037,0.970972
3,max,0.2,0.97276,0.973056,0.973505
4,max,0.25,0.974056,0.974259,0.974617
5,max,0.3,0.975235,0.97537,0.97565
6,max,0.35,0.97603,0.976111,0.976322
7,max,0.4,0.977083,0.97713,0.977283
8,max,0.45,0.977109,0.97713,0.977233
9,max,0.5,0.977027,0.977037,0.977109


## Inference (sliding window)

In [None]:
test_all_predictions = defaultdict(list)

with torch.no_grad(), torch.cuda.amp.autocast():
    for input_values, file_ids, window_indices in tqdm(test_loader_windowed, desc='Predicting on test'):
        input_values = input_values.to(device)

        logits = model(input_values)
        probs = F.softmax(logits, dim=1)
        probs_pos = probs[:, 1].cpu().numpy()

        for file_id, prob in zip(file_ids, probs_pos):
            test_all_predictions[file_id].append(float(prob))

Predicting on test:   0%|          | 0/422 [00:00<?, ?it/s]

In [None]:
pickle_filename = 'test_all_predictions.pkl'

with open(pickle_filename, 'wb') as f:
    pickle.dump(dict(test_all_predictions), f)

pickle_size_mb = os.path.getsize(pickle_filename) / (1024 * 1024)
print(f"✓ Predictions saved: {pickle_filename} ({pickle_size_mb:.2f} MB)")
print(f"✓ Total files: {len(test_all_predictions)}")

# with open('test_all_predictions.pkl', 'rb') as f:
#     test_all_predictions = pickle.load(f)

# print(f"✓ Loaded predictions for {len(test_all_predictions)} files")

✓ Predictions saved: test_all_predictions.pkl (1.42 MB)
✓ Total files: 27000


In [None]:
aggregation_methods = ['max', 'mean', 'quantile_75', 'quantile_90', 'quantile_95']
thresholds = np.arange(0.05, 0.96, 0.05)

submission_dir = 'submissions'
os.makedirs(submission_dir, exist_ok=True)

submission_log = []

print(f"Generating {len(aggregation_methods) * len(thresholds)} submissions...")

for agg_method in tqdm(aggregation_methods, desc="Aggregations"):
    for threshold in thresholds:
        threshold_rounded = round(threshold, 2)

        filename = f"{submission_dir}/submission_{agg_method}_th{threshold_rounded:.2f}.csv"

        submission_df = generate_predictions(
            test_all_predictions,
            aggregation=agg_method,
            threshold=threshold_rounded,
            submission_name=filename
        )

        pos_count = submission_df['label'].sum()
        neg_count = len(submission_df) - pos_count

        submission_log.append({
            'aggregation': agg_method,
            'threshold': threshold_rounded,
            'filename': Path(filename).name,
            'positive_count': pos_count,
            'negative_count': neg_count,
            'positive_ratio': pos_count / len(submission_df)
        })

submission_log_df = pd.DataFrame(submission_log)
submission_log_df.to_csv(f'{submission_dir}/submission_log.csv', index=False)

print(f"✓ Generated {len(submission_log)} submissions in '{submission_dir}/'")

Generating 95 submissions...


Aggregations:   0%|          | 0/5 [00:00<?, ?it/s]

✓ Submission saved: submissions/submission_max_th0.05.csv
✓ Submission saved: submissions/submission_max_th0.10.csv
✓ Submission saved: submissions/submission_max_th0.15.csv
✓ Submission saved: submissions/submission_max_th0.20.csv
✓ Submission saved: submissions/submission_max_th0.25.csv
✓ Submission saved: submissions/submission_max_th0.30.csv
✓ Submission saved: submissions/submission_max_th0.35.csv
✓ Submission saved: submissions/submission_max_th0.40.csv
✓ Submission saved: submissions/submission_max_th0.45.csv
✓ Submission saved: submissions/submission_max_th0.50.csv
✓ Submission saved: submissions/submission_max_th0.55.csv
✓ Submission saved: submissions/submission_max_th0.60.csv
✓ Submission saved: submissions/submission_max_th0.65.csv
✓ Submission saved: submissions/submission_max_th0.70.csv
✓ Submission saved: submissions/submission_max_th0.75.csv
✓ Submission saved: submissions/submission_max_th0.80.csv
✓ Submission saved: submissions/submission_max_th0.85.csv
✓ Submission s

In [None]:
pd.read_csv('submissions/submission_log.csv')

Unnamed: 0,aggregation,threshold,filename,positive_count,negative_count,positive_ratio
0,max,0.05,submission_max_th0.05.csv,14405,12595,0.533519
1,max,0.10,submission_max_th0.10.csv,14199,12801,0.525889
2,max,0.15,submission_max_th0.15.csv,14064,12936,0.520889
3,max,0.20,submission_max_th0.20.csv,13972,13028,0.517481
4,max,0.25,submission_max_th0.25.csv,13893,13107,0.514556
...,...,...,...,...,...,...
90,quantile_95,0.75,submission_quantile_95_th0.75.csv,13306,13694,0.492815
91,quantile_95,0.80,submission_quantile_95_th0.80.csv,13212,13788,0.489333
92,quantile_95,0.85,submission_quantile_95_th0.85.csv,13106,13894,0.485407
93,quantile_95,0.90,submission_quantile_95_th0.90.csv,12909,14091,0.478111


In [None]:
print("Generating balanced (50/50) submissions...")

balanced_submissions = []
total_samples = len(test_all_predictions)
target_positive = total_samples // 2

for agg_method in aggregation_methods:
    all_probs = []
    for file_id, probs_list in test_all_predictions.items():
        if agg_method == 'max':
            final_prob = max(probs_list)
        elif agg_method == 'mean':
            final_prob = np.mean(probs_list)
        elif agg_method == 'quantile_75':
            final_prob = np.percentile(probs_list, 75)
        elif agg_method == 'quantile_90':
            final_prob = np.percentile(probs_list, 90)
        elif agg_method == 'quantile_95':
            final_prob = np.percentile(probs_list, 95)
        else:
            final_prob = max(probs_list)

        all_probs.append(final_prob)

    all_probs_sorted = sorted(all_probs, reverse=True)

    balanced_threshold = all_probs_sorted[target_positive]
    balanced_threshold = round(balanced_threshold, 3)

    filename = f"{submission_dir}/submission_{agg_method}_balanced50.csv"

    submission_df = generate_predictions(
        test_all_predictions,
        aggregation=agg_method,
        threshold=balanced_threshold,
        submission_name=filename
    )

    pos_count = submission_df['label'].sum()
    neg_count = len(submission_df) - pos_count

    balanced_submissions.append({
        'aggregation': agg_method,
        'threshold': balanced_threshold,
        'filename': Path(filename).name,
        'positive_count': pos_count,
        'negative_count': neg_count,
        'positive_ratio': pos_count / len(submission_df)
    })

    print(f"  {agg_method}: th={balanced_threshold:.3f}, pos={pos_count}, neg={neg_count}")

balanced_df = pd.DataFrame(balanced_submissions)
balanced_df.to_csv(f'{submission_dir}/balanced_submissions.csv', index=False)

print(f"\n✓ Generated {len(balanced_submissions)} balanced submissions")

Generating balanced (50/50) submissions...
✓ Submission saved: submissions/submission_max_balanced50.csv
  max: th=0.597, pos=13500, neg=13500
✓ Submission saved: submissions/submission_mean_balanced50.csv
  mean: th=0.597, pos=13500, neg=13500
✓ Submission saved: submissions/submission_quantile_75_balanced50.csv
  quantile_75: th=0.597, pos=13500, neg=13500
✓ Submission saved: submissions/submission_quantile_90_balanced50.csv
  quantile_90: th=0.597, pos=13500, neg=13500
✓ Submission saved: submissions/submission_quantile_95_balanced50.csv
  quantile_95: th=0.597, pos=13500, neg=13500

✓ Generated 5 balanced submissions


In [None]:
import shutil

shutil.make_archive('submissions', 'zip', submission_dir)

zip_size_mb = os.path.getsize('submissions.zip') / (1024 * 1024)
print(f"✓ Archive: submissions.zip ({zip_size_mb:.2f} MB)")

✓ Archive: submissions.zip (50.74 MB)
