1. Работа с данными: вырезаем куски с ключевым словом (keyword), sliding-window инференс на test и т. д. (в принципе это уже есть в бейзлайне), подготовка входа через processor модели (для wav2vec2 это нормализация сырого сигнала, а не лог-мел спектрограммы) + аугментации на train.
2. Сплит на train/val + подсчёт кастомной метрики во время обучения.
3. Берём мощный трансформер wav2vec2 (энкодер + классификатор) с простой кросс-энтропией — должно нормально обучиться, а дальше докрутить.
4. Блендинг/стекинг + подбор порогов на валидации/LB.

In [1]:
import os
import gc
import json
import random
from pathlib import Path
from typing import List, Tuple, Dict, Optional
from collections import defaultdict
import pickle
import shutil

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from scipy.stats import hmean

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import librosa

from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor

!pip -q install torch-audiomentations
from torch_audiomentations import Compose, Gain, PolarityInversion, AddColoredNoise, Shift, HighPassFilter, LowPassFilter

# Run on CUDA when it is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

2025-10-18 16:09:17.283692: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760803757.473975      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760803757.530277      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.6/59.6 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.5/48.5 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m88.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m73.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m41.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m2.0 MB/s[

In [2]:
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs with ``seed``.

    Also switches cuDNN to deterministic kernels and disables autotuning, trading
    some speed for run-to-run reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed()

## Data preparation

In [3]:
# Kaggle input locations: train/test audio folders and the JSON with keyword time bounds.
train_audio_path = '/kaggle/input/vseros-2a/train_audio/train_audio'
test_audio_path = '/kaggle/input/vseros-2a/test_audio/test_audio'
word_bounds_path = '/kaggle/input/vseros-2a/word_bounds.json'

In [4]:
# Collect the .opus clips; sorting gives a deterministic file order.
train_files = sorted(list(Path(train_audio_path).glob('*.opus')))
test_files = sorted(list(Path(test_audio_path).glob('*.opus')))

# word_bounds maps a train file stem -> (start, end) of the keyword in seconds.
# Train files without an entry are treated as negatives (no keyword).
with open(word_bounds_path, 'r') as f:
    word_bounds = json.load(f)

print(f'Train files: {len(train_files)}')
print(f'Test files: {len(test_files)}')
print(f'Word bounds entries: {len(word_bounds)}')

# Class balance of the training set (positive = stem present in word_bounds).
pos_count = len([f for f in train_files if f.stem in word_bounds])
neg_count = len(train_files) - pos_count
print(f'Positive samples: {pos_count} ({pos_count/len(train_files)*100:.1f}%)')
print(f'Negative samples: {neg_count} ({neg_count/len(train_files)*100:.1f}%)')

Train files: 90000
Test files: 27000
Word bounds entries: 45000
Positive samples: 45000 (50.0%)
Negative samples: 45000 (50.0%)


In [5]:
# Items are (path, start_sec, end_sec); negatives carry (None, None) bounds.
pos_items = []
neg_items = []

for fpath in train_files:
    fid = fpath.stem
    if fid in word_bounds:
        start, end = word_bounds[fid]
        pos_items.append((str(fpath), float(start), float(end)))
    else:
        neg_items.append((str(fpath), None, None))

# Split each class separately (10% to val) so train and val keep the same class ratio.
val_split = 0.1
pos_train, pos_val = train_test_split(pos_items, test_size=val_split, random_state=42)
neg_train, neg_val = train_test_split(neg_items, test_size=val_split, random_state=42)

train_items = pos_train + neg_train
val_items = pos_val + neg_val

# Interleave the two classes (uses the seeded `random` module).
random.shuffle(train_items)
random.shuffle(val_items)

print(f'Train: {len(train_items)} (pos: {len(pos_train)}, neg: {len(neg_train)})')
print(f'Val: {len(val_items)} (pos: {len(pos_val)}, neg: {len(neg_val)})')

Train: 81000 (pos: 40500, neg: 40500)
Val: 9000 (pos: 4500, neg: 4500)


## Audio utils & augmentations

In [6]:
def load_audio(path: str, sr: int = 16000) -> np.ndarray:
    """Load an audio file as a mono float32 waveform at sample rate ``sr``.

    torchaudio is tried first (multi-channel audio is averaged to mono and
    resampled if needed). If it fails, librosa is used as a fallback. If both
    fail, the error is printed and one second (``sr`` samples) of silence is
    returned, so a single unreadable file does not abort a whole run.
    """
    try:
        wav, orig_sr = torchaudio.load(path)
        # Downmix multi-channel audio by averaging the channels.
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)
        wav = wav.squeeze(0)
        if orig_sr != sr:
            wav = torchaudio.functional.resample(wav, orig_sr, sr)
        wav = wav.numpy()
    except Exception:
        # `except Exception` rather than a bare `except:` so KeyboardInterrupt /
        # SystemExit still stop the notebook instead of silently falling back.
        try:
            # librosa already resamples to `sr`; its returned rate is not needed.
            wav, _ = librosa.load(path, sr=sr, mono=True)
        except Exception as e:
            print(f'Error loading {path}: {e}')
            return np.zeros(sr, dtype=np.float32)

    return wav.astype(np.float32)

def normalize_audio(wav: np.ndarray, method='peak') -> np.ndarray:
    """Rescale a waveform by its peak or RMS level.

    Args:
        wav: 1-D float waveform.
        method: 'peak' scales so that max |x| == 1; 'rms' scales to an RMS of 0.1
            and clips to [-1, 1]. Any other value leaves ``wav`` unchanged.

    Returns:
        A rescaled copy, or ``wav`` itself when it is empty, (near-)silent
        (level <= 1e-8), or ``method`` is not recognised.
    """
    if wav.size == 0:
        # np.abs(wav).max() raises on an empty array; there is nothing to scale.
        return wav
    if method == 'peak':
        peak = np.abs(wav).max()
        if peak > 1e-8:
            wav = wav / peak
    elif method == 'rms':
        rms = np.sqrt(np.mean(wav ** 2))
        if rms > 1e-8:
            # Dividing by 10 * rms gives a target RMS of 0.1, leaving headroom before clipping.
            wav = wav / (rms * 10)
            wav = np.clip(wav, -1, 1)
    return wav

In [7]:
class AudioAugmentation:
    """Random waveform augmentations (torch-audiomentations) for a single 1-D crop.

    Accepts a numpy array or a torch tensor and returns a float32 numpy array.
    """

    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate

        pipeline = [
            # NOTE(review): KWSDataset peak-normalises the crop right after
            # augmentation, which largely cancels this gain change.
            Gain(min_gain_in_db=-15.0, max_gain_in_db=5.0, p=0.5),
            PolarityInversion(p=0.5),
            AddColoredNoise(min_snr_in_db=5.0, max_snr_in_db=30.0, min_f_decay=-2.0, max_f_decay=2.0, p=0.5),
            # rollover=True wraps audio shifted past one edge around to the other;
            # for positive crops this can split the keyword across the boundary.
            Shift(min_shift=-0.5, max_shift=0.5, shift_unit="fraction", rollover=True, p=0.5),
            HighPassFilter(min_cutoff_freq=20.0, max_cutoff_freq=400.0, p=0.3),
            LowPassFilter(min_cutoff_freq=2000.0, max_cutoff_freq=7500.0, p=0.3),
        ]
        self.augment = Compose(transforms=pipeline)

    def __call__(self, wav):
        signal = torch.from_numpy(wav).float() if isinstance(wav, np.ndarray) else wav

        # torch-audiomentations works on (batch, channels, samples) tensors.
        result = self.augment(signal[None, None], sample_rate=self.sample_rate)[0, 0]

        if isinstance(result, torch.Tensor):
            result = result.numpy()

        return result.astype(np.float32)


augmentation = AudioAugmentation(sample_rate=16000)

## Datasets & utils

In [8]:
class KWSDataset(Dataset):
    """Fixed-length training/validation crops for keyword spotting.

    Each item is ``(path, start_sec, end_sec)``; ``start_sec is None`` marks a
    negative (keyword-free) file. ``__getitem__`` returns ``(segment, label)``,
    where ``segment`` is a float waveform of exactly ``segment_samples`` samples
    and ``label`` is 1 for positives and 0 for negatives.

    Crop positions are drawn with the ``random`` module, so they are not fixed
    across calls, even when ``augment`` is None.
    """

    def __init__(self, items, segment_samples, sr=16000, augment=None):
        self.items = items
        self.segment_samples = segment_samples
        self.sr = sr
        # Optional callable applied to the cropped waveform (only the train set passes one here).
        self.augment = augment

    def __len__(self):
        return len(self.items)

    def _fit_length(self, segment):
        """Right-pad with zeros or truncate ``segment`` to exactly ``segment_samples`` samples."""
        if len(segment) < self.segment_samples:
            pad = self.segment_samples - len(segment)
            return np.pad(segment, (0, pad), mode='constant')
        return segment[:self.segment_samples]

    def extract_positive_segment(self, wav, start_sec, end_sec):
        """Crop ``segment_samples`` samples around the keyword at [start_sec, end_sec].

        A keyword at least as long as the crop is centre-cropped (its edges are lost).
        A shorter keyword gets the remaining context split randomly between its
        left and right side, which acts as a time-shift augmentation.
        """
        start_idx = int(start_sec * self.sr)
        end_idx = int(end_sec * self.sr)

        phrase_len = end_idx - start_idx

        if phrase_len >= self.segment_samples:
            center = (start_idx + end_idx) // 2
            left = max(0, center - self.segment_samples // 2)
            right = min(len(wav), left + self.segment_samples)
            # Slide the window back if it hit the end of the file, but never before
            # sample 0: a negative `left` would make the slice count from the end of
            # the array and return the wrong audio.
            left = max(0, right - self.segment_samples)
        else:
            context_total = self.segment_samples - phrase_len
            context_left = random.randint(0, context_total)
            context_right = context_total - context_left

            left = max(0, start_idx - context_left)
            right = min(len(wav), end_idx + context_right)

            # Clipping at a file edge shortened the window; extend it on the other side.
            if right - left < self.segment_samples:
                if left == 0:
                    right = min(len(wav), left + self.segment_samples)
                else:
                    left = max(0, right - self.segment_samples)

        return self._fit_length(wav[left:right])

    def extract_negative_segment(self, wav):
        """Random ``segment_samples`` crop of a negative waveform (zero-padded if shorter)."""
        if len(wav) <= self.segment_samples:
            return self._fit_length(wav)
        start = random.randint(0, len(wav) - self.segment_samples)
        return wav[start:start + self.segment_samples]

    def __getitem__(self, idx):
        path, start_sec, end_sec = self.items[idx]

        wav = load_audio(path, self.sr)
        wav = normalize_audio(wav)

        label = 1 if start_sec is not None else 0

        if label == 1:
            segment = self.extract_positive_segment(wav, start_sec, end_sec)
        else:
            segment = self.extract_negative_segment(wav)

        if self.augment is not None:
            segment = self.augment(segment)

        # Re-normalise the final crop (cropping/augmentation change its peak).
        segment = normalize_audio(segment)

        return segment, label

In [9]:
class KWSTestDatasetWindowed(Dataset):
    """Sliding-window samples over a list of audio files, for inference.

    Every file is loaded and peak-normalised up front. Each sample is
    ``(segment, file_id, window_idx)`` where ``segment`` has exactly
    ``window_samples`` samples. Windows start every ``hop_samples``; if that grid
    does not end exactly at the end of the file, one extra window aligned to the
    end is added so the tail of the recording is also scored. Files no longer
    than one window are zero-padded to a single window.

    Segments are numpy views into each file's array, so all decoded audio stays
    in memory for the lifetime of the dataset.
    """

    def __init__(self, file_paths, window_samples, hop_samples, sr=16000):
        self.file_paths = file_paths
        self.window_samples = window_samples
        self.hop_samples = hop_samples
        self.sr = sr

        self.samples = []

        print("Preparing test windows...")
        for path in tqdm(file_paths):
            file_id = Path(path).stem
            wav = load_audio(path, sr)
            wav = normalize_audio(wav)

            if len(wav) < window_samples:
                wav = np.pad(wav, (0, window_samples - len(wav)), mode='constant')

            starts = self._window_starts(len(wav), window_samples, hop_samples)
            for window_idx, start in enumerate(starts):
                self.samples.append((wav[start:start + window_samples], file_id, window_idx))

        print(f"Total windows created: {len(self.samples)}")

    @staticmethod
    def _window_starts(num_samples, window_samples, hop_samples):
        """Start offsets of windows covering ``num_samples`` samples, last one end-aligned."""
        if num_samples <= window_samples:
            return [0]
        last_start = num_samples - window_samples
        starts = list(range(0, last_start + 1, hop_samples))
        # The hop grid can stop short of the end; add a final window flush with the end.
        if starts[-1] != last_start:
            starts.append(last_start)
        return starts

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        segment, file_id, window_idx = self.samples[idx]
        return segment, file_id, window_idx

In [10]:
MODEL_NAME = 'UrukHan/wav2vec2-russian'
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)


def _encode_waveforms(waveforms):
    """Run the wav2vec2 feature extractor on a batch of 16 kHz waveforms.

    Uses padding=True, max_length=24000 and truncation=True, and returns the
    ``input_values`` tensor.
    """
    encoded = feature_extractor(
        list(waveforms),
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
        max_length=24000,
        truncation=True,
    )
    return encoded.input_values


def collate_train_val(batch):
    """Collate ``(waveform, label)`` pairs into ``(input_values, labels)`` tensors."""
    waveforms, targets = zip(*batch)
    input_values = _encode_waveforms(waveforms)
    return input_values, torch.tensor(targets, dtype=torch.long)


def collate_test_windows(batch):
    """Collate ``(segment, file_id, window_idx)`` triples into ``(input_values, file_ids, window_indices)``."""
    segments, file_ids, window_indices = zip(*batch)
    return _encode_waveforms(segments), list(file_ids), list(window_indices)

preprocessor_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

In [11]:
# 1.5 s crops at 16 kHz -> 24000 samples (matches max_length in the collate functions).
sample_rate = 16000
segment_duration = 1.5
segment_samples = int(sample_rate * segment_duration)

# Augmentation is applied to training crops only.
train_dataset = KWSDataset(
    train_items, 
    segment_samples=segment_samples,
    sr=sample_rate,
    augment=augmentation
)

# No augmentation, but crop positions are still drawn randomly inside KWSDataset.
val_dataset = KWSDataset(
    val_items,
    segment_samples=segment_samples,
    sr=sample_rate,
    augment=None
)

batch_size = 8
num_workers = 2

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=collate_train_val,
    pin_memory=False
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate_train_val,
    pin_memory=False
)

print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')

Train batches: 10125
Val batches: 1125


In [12]:
# Test-time sliding windows: same 1.5 s length as the training crops, 0.5 s hop.
test_window_duration = 1.5
test_window_samples = int(sample_rate * test_window_duration)
test_hop_duration = 0.5
test_hop_samples = int(sample_rate * test_hop_duration)

test_dataset_windowed = KWSTestDatasetWindowed(
    [str(f) for f in test_files],
    window_samples=test_window_samples,
    hop_samples=test_hop_samples,
    sr=sample_rate
)

# shuffle=False keeps the windows of each file together and in order.
test_loader_windowed = DataLoader(
    test_dataset_windowed,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_test_windows,
    pin_memory=False
)

print(f'Test batches: {len(test_loader_windowed)}')

Preparing test windows...


  0%|          | 0/27000 [00:00<?, ?it/s]

Total windows created: 162000
Test batches: 20250


## Model & training utils

In [13]:
def calculate_metrics(preds, labels, num_pos, num_neg):
    """Compute binary-classification metrics and the notebook's competition score.

    Args:
        preds: predicted labels (0/1); any 1-D array-like (list or numpy array).
        labels: ground-truth labels (0/1), same length as ``preds``.
        num_pos: number of positive examples, used as the FRR denominator.
        num_neg: number of negative examples, used as the FAR denominator.

    Returns:
        dict with 'accuracy', 'f1', 'competition_score' (harmonic mean of 1-FRR
        and 1-FAR, or 0.0 if either is non-positive) and the 'tp'/'fp'/'fn'/'tn'
        counts.
    """
    # The element-wise comparisons below need arrays; plain lists would compare as a whole.
    preds = np.asarray(preds)
    labels = np.asarray(labels)

    correct = (preds == labels).sum()
    total = len(labels)
    accuracy = correct / total if total > 0 else 0

    tp = ((preds == 1) & (labels == 1)).sum()
    fp = ((preds == 1) & (labels == 0)).sum()
    fn = ((preds == 0) & (labels == 1)).sum()
    tn = ((preds == 0) & (labels == 0)).sum()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    # False rejection / false acceptance rates.
    frr = fn / num_pos if num_pos > 0 else 0
    far = fp / num_neg if num_neg > 0 else 0

    score_1_frr = 1 - frr
    score_1_far = 1 - far

    # hmean requires strictly positive inputs.
    if score_1_frr > 0 and score_1_far > 0:
        competition_score = hmean([score_1_frr, score_1_far])
    else:
        competition_score = 0.0

    return {
        'accuracy': accuracy,
        'f1': f1,
        'competition_score': competition_score,
        'tp': tp,
        'fp': fp,
        'fn': fn,
        'tn': tn
    }

In [14]:
# def train_epoch(model, loader, optimizer, criterion, device, num_pos, num_neg):
#     model.train()
#     total_loss = 0
#     all_preds = []
#     all_labels = []
    
#     pbar = tqdm(loader, desc='Training', leave=False)
#     for input_values, labels in pbar:
#         input_values = input_values.to(device)
#         labels = labels.to(device)
        
#         optimizer.zero_grad()
        
#         logits = model(input_values)
#         loss = criterion(logits, labels)
        
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#         optimizer.step()
        
#         total_loss += loss.item()
        
#         preds = torch.argmax(logits, dim=1)
#         all_preds.extend(preds.cpu().numpy())
#         all_labels.extend(labels.cpu().numpy())
        
#         pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
#     avg_loss = total_loss / len(loader)
#     all_preds = np.array(all_preds)
#     all_labels = np.array(all_labels)
#     metrics = calculate_metrics(all_preds, all_labels, num_pos, num_neg)
#     metrics['loss'] = avg_loss
    
#     return metrics

def train_epoch(model, loader, optimizer, criterion, device, num_pos, num_neg, accumulation_steps=1):
    """Train ``model`` for one pass over ``loader`` with gradient accumulation.

    The gradients of ``accumulation_steps`` consecutive batches are accumulated
    (each loss is divided by ``accumulation_steps``), clipped to a global norm of
    1.0, and applied in a single optimizer step. A trailing group of fewer batches
    is still applied at the end of the epoch (its gradient is correspondingly smaller).

    Returns:
        dict from calculate_metrics on this epoch's training predictions, plus
        'loss': the mean unscaled per-batch loss (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []

    optimizer.zero_grad()

    pbar = tqdm(loader, desc='Training', leave=False)
    for batch_idx, (input_values, labels) in enumerate(pbar):
        input_values = input_values.to(device)
        labels = labels.to(device)

        logits = model(input_values)
        loss = criterion(logits, labels)
        # Scale so the accumulated gradient is the mean over the accumulation window.
        (loss / accumulation_steps).backward()

        if (batch_idx + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    # Apply gradients left over from an incomplete final accumulation window.
    # Counting batches (instead of reusing `batch_idx`) also works for an empty loader.
    if num_batches % accumulation_steps != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / num_batches if num_batches else 0.0
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    metrics = calculate_metrics(all_preds, all_labels, num_pos, num_neg)
    metrics['loss'] = avg_loss

    return metrics

In [15]:
def validate(model, loader, criterion, device, num_pos, num_neg):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns the dict from calculate_metrics (argmax predictions vs labels) with
    the mean per-batch loss added under 'loss'.
    """
    model.eval()
    running_loss = 0
    predicted, targets = [], []

    progress = tqdm(loader, desc='Validation', leave=False)
    with torch.no_grad():
        for batch_inputs, batch_labels in progress:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            batch_logits = model(batch_inputs)
            batch_loss = criterion(batch_logits, batch_labels)
            running_loss += batch_loss.item()

            predicted.extend(batch_logits.argmax(dim=1).cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

            progress.set_postfix({'loss': f'{batch_loss.item():.4f}'})

    mean_loss = running_loss / len(loader)
    metrics = calculate_metrics(np.array(predicted), np.array(targets), num_pos, num_neg)
    metrics['loss'] = mean_loss
    return metrics

In [16]:
# class AttentionPooling(nn.Module):
#     def __init__(self, hidden_size):
#         super().__init__()
#         self.attention = nn.Sequential(
#             nn.Linear(hidden_size, hidden_size // 2),
#             nn.Tanh(),
#             nn.Linear(hidden_size // 2, 1)
#         )
    
#     def forward(self, hidden_states):
#         attention_weights = self.attention(hidden_states)
#         attention_weights = F.softmax(attention_weights, dim=1)
#         pooled = torch.sum(hidden_states * attention_weights, dim=1)
#         return pooled

In [17]:
class Wav2Vec2ForKWS(nn.Module):
    """wav2vec2 encoder + mean pooling over time + a small MLP classification head.

    Args:
        model_name: checkpoint id/path passed to ``Wav2Vec2Model.from_pretrained``.
        num_labels: number of output logits.
        freeze_feature_extractor: if True, the convolutional feature encoder
            (``wav2vec2.feature_extractor``) receives no gradients.
        freeze_encoder_layers: number of leading transformer layers to freeze.
    """

    def __init__(self, model_name, num_labels=2, freeze_feature_extractor=True, freeze_encoder_layers=0):
        super().__init__()

        self.wav2vec2 = Wav2Vec2Model.from_pretrained(model_name)

        frozen = []
        if freeze_feature_extractor:
            frozen.append(self.wav2vec2.feature_extractor)
        if freeze_encoder_layers > 0:
            frozen.extend(self.wav2vec2.encoder.layers[:freeze_encoder_layers])
        for module in frozen:
            for param in module.parameters():
                param.requires_grad = False

        width = self.wav2vec2.config.hidden_size

        # Layer order is part of the checkpoint format: the Linear weights are
        # saved as classifier.1 / classifier.4 / classifier.7. Default PyTorch
        # initialisation is used for the head.
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(width, 256),
            nn.Tanh(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.Tanh(),
            nn.Dropout(0.1),
            nn.Linear(128, num_labels)
        )

    def forward(self, input_values):
        """Return ``(batch, num_labels)`` logits for a batch of feature-extractor outputs."""
        frames = self.wav2vec2(input_values).last_hidden_state
        # Average over the time axis (inputs in this notebook are fixed-length, unpadded).
        return self.classifier(frames.mean(dim=1))

## Training

In [18]:
freeze_feature_extractor = True
# Freeze the first 18 transformer layers in addition to the conv feature encoder;
# every other module (later layers, the classification head, ...) stays trainable.
freeze_encoder_layers = 18

model = Wav2Vec2ForKWS(
    model_name=MODEL_NAME,
    num_labels=2,
    freeze_feature_extractor=freeze_feature_extractor,
    freeze_encoder_layers=freeze_encoder_layers
).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')
# ':.2%' prints a real percentage (e.g. 26.86%); ':.2f' printed the raw fraction (0.27).
print(f'Trainable percentage: {trainable_params / total_params:.2%}')

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Total parameters: 315,734,274
Trainable parameters: 84,792,066
Trainable percentage: 0.27


In [19]:
num_epochs = 4
learning_rate = 3e-4
weight_decay = 0.01

batch_size_effective = 64
# Tie the per-step batch size to the DataLoader's, so the accumulation factor
# stays correct if `batch_size` is changed in the loader cell.
batch_size_actual = batch_size
accumulation_steps = max(1, batch_size_effective // batch_size_actual)  # 64 // 8 = 8 with the current loaders

# Only parameters left trainable after freezing are optimised.
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=learning_rate,
    weight_decay=weight_decay
)

criterion = nn.CrossEntropyLoss()

# Cosine decay over num_epochs; the training loop calls scheduler.step() once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=5e-7
)

In [20]:
# Release unreferenced Python objects and cached CUDA memory before training.
gc.collect()
torch.cuda.empty_cache()

In [21]:
# Checkpoints go to a folder named after the model (the part after the first '/').
checkpoint_dir = MODEL_NAME[MODEL_NAME.find('/')+1:]
os.makedirs(checkpoint_dir, exist_ok=True)

history = {
    'train_loss': [], 'train_acc': [], 'train_f1': [], 'train_score': [],
    'val_loss': [], 'val_acc': [], 'val_f1': [], 'val_score': []
}

best_score = 0
best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')

# Class counts are the FRR/FAR denominators used by calculate_metrics.
train_num_pos = len(pos_train)
train_num_neg = len(neg_train)
val_num_pos = len(pos_val)
val_num_neg = len(neg_val)

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch+1}/{num_epochs}')
    
    train_metrics = train_epoch(
        model, train_loader, optimizer, criterion, device,
        train_num_pos, train_num_neg,
        accumulation_steps=accumulation_steps  # gradient accumulation to reach the effective batch size
    )
    val_metrics = validate(
        model, val_loader, criterion, device,
        val_num_pos, val_num_neg
    )
    
    history['train_loss'].append(train_metrics['loss'])
    history['train_acc'].append(train_metrics['accuracy'])
    history['train_f1'].append(train_metrics['f1'])
    history['train_score'].append(train_metrics['competition_score'])
    history['val_loss'].append(val_metrics['loss'])
    history['val_acc'].append(val_metrics['accuracy'])
    history['val_f1'].append(val_metrics['f1'])
    history['val_score'].append(val_metrics['competition_score'])
    
    print(f"\nTrain Metrics:")
    print(f"  Loss: {train_metrics['loss']:.4f} | Acc: {train_metrics['accuracy']:.4f} | F1: {train_metrics['f1']:.4f}")
    print(f"  Competition Score: {train_metrics['competition_score']:.4f}")
    
    print(f"\nVal Metrics:")
    print(f"  Loss: {val_metrics['loss']:.4f} | Acc: {val_metrics['accuracy']:.4f} | F1: {val_metrics['f1']:.4f}")
    print(f"  Competition Score: {val_metrics['competition_score']:.4f}")
    
    # The cosine LR schedule advances once per epoch.
    scheduler.step()
    
    torch.cuda.empty_cache()
    
    # Every epoch's full state_dict (frozen weights included) is written to disk.
    checkpoint_path = os.path.join(checkpoint_dir, f'epoch_{epoch+1}.pth')
    torch.save(model.state_dict(), checkpoint_path)
    print(f'\nCheckpoint saved: {checkpoint_path}')
    
    # Keep a separate copy of the epoch with the best validation competition score.
    if val_metrics['competition_score'] > best_score:
        best_score = val_metrics['competition_score']
        torch.save(model.state_dict(), best_model_path)
        print(f'Best model saved with Competition Score: {best_score:.4f}')

print(f'\nTraining completed!')
print(f'Best validation Competition Score: {best_score:.4f}')


Epoch 1/4


Training:   0%|          | 0/10125 [00:00<?, ?it/s]

In [None]:
# Restore the weights of the best epoch (highest validation competition score).
# map_location lets the checkpoint load even if it was saved on a different device.
best_sd = torch.load(best_model_path, map_location=device)
model.load_state_dict(best_sd)
model.eval()
print()

## Inference utils

In [None]:
def _aggregate_window_probs(probs_list, aggregation):
    """Collapse one file's per-window positive probabilities into a single score.

    Supported: 'max', 'mean', 'quantile_75', 'quantile_90', 'quantile_95';
    any other name falls back to 'max'.
    """
    if aggregation == 'mean':
        return np.mean(probs_list)
    if aggregation == 'quantile_75':
        return np.percentile(probs_list, 75)
    if aggregation == 'quantile_90':
        return np.percentile(probs_list, 90)
    if aggregation == 'quantile_95':
        return np.percentile(probs_list, 95)
    return max(probs_list)


def generate_predictions(all_predictions, aggregation='max', threshold=0.5, submission_name=None):
    """Turn per-window probabilities into per-file labels.

    Args:
        all_predictions: mapping file_id -> list of per-window positive probabilities.
        aggregation: how windows are combined per file (see _aggregate_window_probs).
        threshold: files with aggregated prob >= threshold get label 1.
        submission_name: if given, the 'id'/'label' columns are written there as CSV.

    Returns:
        DataFrame with columns 'id', 'prob', 'label' (one row per file; empty
        input gives an empty frame with those columns).
    """
    rows = [
        {'id': file_id, 'prob': _aggregate_window_probs(probs_list, aggregation)}
        for file_id, probs_list in all_predictions.items()
    ]
    # Explicit columns keep 'prob' addressable even when there are no files.
    submission_df = pd.DataFrame(rows, columns=['id', 'prob'])
    submission_df['label'] = (submission_df['prob'] >= threshold).astype(int)

    if submission_name is not None:
        submission_final = submission_df[['id', 'label']]
        submission_final.to_csv(submission_name, index=False)
        print(f"✓ Submission saved: {submission_name}")

    return submission_df

In [None]:
def calculate_score_from_predictions(submission_df, ground_truth, num_pos, num_neg):
    """Compute metrics for a submission against known labels.

    Args:
        submission_df: DataFrame with 'id' and 'label' columns.
        ground_truth: dict {file_id: true_label}; rows whose id is not a key are ignored.
        num_pos: number of positive examples (FRR denominator).
        num_neg: number of negative examples (FAR denominator).

    Returns:
        dict of metrics from calculate_metrics.
    """
    # Vectorised filter + lookup instead of a per-row iterrows loop; this is called
    # once per (aggregation, threshold) pair during tuning.
    known = submission_df['id'].isin(list(ground_truth))
    matched = submission_df.loc[known]

    preds = matched['label'].to_numpy()
    labels = matched['id'].map(ground_truth).to_numpy()

    metrics = calculate_metrics(preds, labels, num_pos, num_neg)

    return metrics

## Test aggregations & threshold on validation

In [None]:
# Windowed view of the validation files, built exactly like the test windows, so
# window aggregation and thresholds can be tuned on data with known labels.
val_dataset_windowed = KWSTestDatasetWindowed(
    [item[0] for item in val_items],
    window_samples=test_window_samples,
    hop_samples=test_hop_samples,
    sr=sample_rate
)

val_loader_windowed = DataLoader(
    val_dataset_windowed,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_test_windows,
    pin_memory=True
)

In [None]:
val_all_predictions = defaultdict(list)

# torch.autocast replaces the deprecated torch.cuda.amp.autocast; fp16 autocast is
# enabled only on CUDA, so a CPU run is unaffected.
with torch.no_grad(), torch.autocast(device_type='cuda', enabled=device.type == 'cuda'):
    for input_values, file_ids, window_indices in tqdm(val_loader_windowed, desc='Validation inference'):
        input_values = input_values.to(device)

        logits = model(input_values)
        # Probability of class 1 (keyword present) for every window.
        probs_pos = F.softmax(logits, dim=1)[:, 1].cpu().numpy()

        for file_id, prob in zip(file_ids, probs_pos):
            val_all_predictions[file_id].append(float(prob))

# File-level labels: 1 if the item has keyword bounds, else 0.
val_ground_truth = {
    Path(path).stem: int(start_sec is not None)
    for path, start_sec, _ in val_items
}

print(f"✓ Validation predictions ready for {len(val_all_predictions)} files")

In [None]:
aggregation_methods = ['max', 'mean', 'quantile_75', 'quantile_90', 'quantile_95']
thresholds = np.arange(0.05, 0.96, 0.05)

tuning_results = []

for agg_method in aggregation_methods:
    print(f"\nTesting aggregation: {agg_method}")

    # Per-method bests get their own names, so the training cell's `best_score`
    # (validation score of the saved best checkpoint) is not overwritten.
    method_best_threshold = 0.5
    method_best_score = 0

    for raw_threshold in tqdm(thresholds, desc=f"  {agg_method}", leave=False):
        # Round exactly like the test submission grid, so tuned and submitted thresholds match.
        threshold = round(float(raw_threshold), 2)

        submission_df = generate_predictions(
            val_all_predictions,
            aggregation=agg_method,
            threshold=threshold,
            submission_name=None
        )

        metrics = calculate_score_from_predictions(
            submission_df,
            val_ground_truth,
            val_num_pos,
            val_num_neg
        )

        tuning_results.append({
            'aggregation': agg_method,
            'threshold': threshold,
            'competition_score': metrics['competition_score'],
            'accuracy': metrics['accuracy'],
            'f1': metrics['f1']
        })

        if metrics['competition_score'] > method_best_score:
            method_best_score = metrics['competition_score']
            method_best_threshold = threshold

    print(f"  Best: threshold={method_best_threshold:.2f}, score={method_best_score:.4f}")

tuning_df = pd.DataFrame(tuning_results)
tuning_df.to_csv('tuning_results.csv', index=False)

# Single best (aggregation, threshold) pair across all methods.
best_row = tuning_df.loc[tuning_df['competition_score'].idxmax()]
print(f"\nOverall best: aggregation={best_row['aggregation']}, "
      f"threshold={best_row['threshold']:.2f}, score={best_row['competition_score']:.4f}")

print("\n✓ Tuning complete! Results saved to 'tuning_results.csv'")
tuning_df.head(45)

## Inference (sliding window)

In [None]:
test_all_predictions = defaultdict(list)

# Same inference path as the validation tuning: fp16 autocast on CUDA only
# (torch.autocast replaces the deprecated torch.cuda.amp.autocast).
with torch.no_grad(), torch.autocast(device_type='cuda', enabled=device.type == 'cuda'):
    for input_values, file_ids, window_indices in tqdm(test_loader_windowed, desc='Predicting on test'):
        input_values = input_values.to(device)

        logits = model(input_values)
        # Probability of class 1 (keyword present) per window, grouped by file below.
        probs_pos = F.softmax(logits, dim=1)[:, 1].cpu().numpy()

        for file_id, prob in zip(file_ids, probs_pos):
            test_all_predictions[file_id].append(float(prob))

In [None]:
# Persist the raw per-window test probabilities so aggregation and thresholds can be
# re-tuned later without re-running inference.
pickle_filename = 'test_all_predictions.pkl'

with open(pickle_filename, 'wb') as f:
    pickle.dump(dict(test_all_predictions), f)

pickle_size_mb = os.path.getsize(pickle_filename) / (1024 * 1024)
print(f"✓ Predictions saved: {pickle_filename} ({pickle_size_mb:.2f} MB)")
print(f"✓ Total files: {len(test_all_predictions)}")

# To reload in a later session (only unpickle files you produced yourself):
# with open('test_all_predictions.pkl', 'rb') as f:
#     test_all_predictions = pickle.load(f)

# print(f"✓ Loaded predictions for {len(test_all_predictions)} files")

In [None]:
# Write one test submission per (aggregation, threshold) pair into submissions/,
# plus a CSV log with the predicted label counts of each file.
aggregation_methods = ['max', 'mean', 'quantile_75', 'quantile_90', 'quantile_95']
thresholds = np.arange(0.05, 0.96, 0.05)

submission_dir = 'submissions'
os.makedirs(submission_dir, exist_ok=True)

submission_log = []

print(f"Generating {len(aggregation_methods) * len(thresholds)} submissions...")

for agg_method in tqdm(aggregation_methods, desc="Aggregations"):
    for threshold in thresholds:
        # Rounding removes float noise from np.arange (e.g. 0.15000000000000002).
        threshold_rounded = round(threshold, 2)
        
        filename = f"{submission_dir}/submission_{agg_method}_th{threshold_rounded:.2f}.csv"
        
        submission_df = generate_predictions(
            test_all_predictions,
            aggregation=agg_method,
            threshold=threshold_rounded,
            submission_name=filename
        )
        
        pos_count = submission_df['label'].sum()
        neg_count = len(submission_df) - pos_count
        
        submission_log.append({
            'aggregation': agg_method,
            'threshold': threshold_rounded,
            'filename': Path(filename).name,
            'positive_count': pos_count,
            'negative_count': neg_count,
            'positive_ratio': pos_count / len(submission_df)
        })

submission_log_df = pd.DataFrame(submission_log)
submission_log_df.to_csv(f'{submission_dir}/submission_log.csv', index=False)

print(f"✓ Generated {len(submission_log)} submissions in '{submission_dir}/'")

In [None]:
# Display the label counts of every generated submission.
pd.read_csv('submissions/submission_log.csv')

In [None]:
print("Generating balanced (50/50) submissions...")

balanced_submissions = []
total_samples = len(test_all_predictions)
target_positive = total_samples // 2

for agg_method in aggregation_methods:
    # Aggregate per-file probabilities through generate_predictions, so these
    # submissions use exactly the same aggregation code as the threshold grid.
    agg_probs = generate_predictions(
        test_all_predictions,
        aggregation=agg_method,
        threshold=0.5,
        submission_name=None
    )['prob'].to_numpy(dtype=float)
    probs_desc = np.sort(agg_probs)[::-1]

    # Threshold at the target_positive-th largest probability, so that
    # `prob >= threshold` marks exactly target_positive files positive (more only
    # on exact ties). Indexing at target_positive would admit one extra file, and
    # the value is not rounded: rounding can move the cut across a large group of
    # near-identical (e.g. near-0 or near-1) probabilities.
    if target_positive > 0:
        balanced_threshold = float(probs_desc[target_positive - 1])
    else:
        balanced_threshold = float('inf')  # fewer than 2 files: nothing is labelled positive

    filename = f"{submission_dir}/submission_{agg_method}_balanced50.csv"

    submission_df = generate_predictions(
        test_all_predictions,
        aggregation=agg_method,
        threshold=balanced_threshold,
        submission_name=filename
    )

    pos_count = submission_df['label'].sum()
    neg_count = len(submission_df) - pos_count

    balanced_submissions.append({
        'aggregation': agg_method,
        'threshold': balanced_threshold,
        'filename': Path(filename).name,
        'positive_count': pos_count,
        'negative_count': neg_count,
        'positive_ratio': pos_count / len(submission_df)
    })

    print(f"  {agg_method}: th={balanced_threshold:.3f}, pos={pos_count}, neg={neg_count}")

balanced_df = pd.DataFrame(balanced_submissions)
balanced_df.to_csv(f'{submission_dir}/balanced_submissions.csv', index=False)

print(f"\n✓ Generated {len(balanced_submissions)} balanced submissions")

In [None]:
# shutil is already imported in the first cell; the re-import lets this cell run on its own.
import shutil

# Bundle every CSV in submissions/ into submissions.zip for download.
shutil.make_archive('submissions', 'zip', submission_dir)

zip_size_mb = os.path.getsize('submissions.zip') / (1024 * 1024)
print(f"✓ Archive: submissions.zip ({zip_size_mb:.2f} MB)")