In [1]:
!pip install -q sentence-transformers==2.7.0 transformers==4.45.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m171.5/171.5 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m28.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.3/564.3 kB[0m [31m23.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m67.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m81.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m66.5 MB/s[0m et

In [2]:
import os, math, json, random, warnings, time
from pathlib import Path
from typing import List, Tuple, Dict, Union

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

import torchaudio
from sklearn.model_selection import train_test_split
from transformers import AutoConfig, AutoModel, TrainingArguments, Trainer
from tqdm.auto import tqdm

warnings.filterwarnings("ignore")

2025-10-20 19:41:15.055014: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760989275.275823      37 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760989275.340010      37 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Input dataset directories (searched for an `audio/` folder by _find_audio_root)
# and the directory where this notebook writes checkpoints, caches and submissions.
TRAIN_DIR = "/kaggle/input/vseros-audio-task/train_data"
TEST_DIR  = "/kaggle/input/vseros-audio-task/test_data"
OUT_DIR   = "/kaggle/working/out_kws"

# Model id passed to AutoConfig/AutoModel.from_pretrained for the backbone.
BACKBONE = "jonatasgrosman/wav2vec2-large-xlsr-53-russian"

In [4]:
# Training hyperparameters (consumed by TrainingArguments and the datasets below).
EPOCHS       = 15     # num_train_epochs
BATCH_SIZE   = 32     # per-device batch size for both training and evaluation
LR           = 1e-5   # learning_rate
SEG_DUR      = 2.0    # training segment duration in seconds (converted to samples at 16 kHz)
POS_MARGIN   = 0.2    # seconds added on each side of the word bounds when labelling frames
NEG_RATIO    = 1.0    # negatives sampled per positive in the training dataset
VAL_SPLIT    = 0.15   # fraction of files held out for the dev split
WINDOW_SEC   = 1.5    # sliding-window length in seconds used when scoring whole files
HOP_SEC      = 0.25   # sliding-window hop in seconds used when scoring whole files
SEED         = 42     # seed for Python/NumPy/PyTorch RNGs, the splits and the Trainer

In [5]:
# Set to True to run on CPU even when a CUDA device is available.
FORCE_CPU = False

In [6]:
# Clearing CUDA_VISIBLE_DEVICES is meant to hide all GPUs from this process.
# NOTE(review): torch is already imported at this point — confirm the variable still takes effect.
if FORCE_CPU:
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Pick CUDA only when it is available and CPU mode was not requested.
DEVICE = torch.device("cuda" if torch.cuda.is_available() and not FORCE_CPU else "cpu")
# Mixed precision and pinned host memory are switched on only for CUDA runs.
AMP_ENABLED = (DEVICE.type == "cuda")
PIN_MEMORY = AMP_ENABLED

In [7]:
def autocast_ctx():
    """Return a CUDA float16 autocast context that is active only when AMP_ENABLED is true."""
    amp_kwargs = {
        "device_type": "cuda",
        "dtype": torch.float16,
        "enabled": AMP_ENABLED,
    }
    return torch.autocast(**amp_kwargs)

print("Device:", DEVICE, "| AMP:", AMP_ENABLED, "| pin_memory:", PIN_MEMORY)

Device: cpu | AMP: False | pin_memory: False


In [8]:
def set_seed(seed=SEED):
    """Seed Python's `random`, NumPy and PyTorch (plus every CUDA device when CUDA is available)."""
    # Same seeding order as before: random, numpy, torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed(SEED)


In [9]:
def _find_audio_root(p: Union[str, Path]) -> Path:
    p = Path(p)
    if (p/"audio").exists() and list((p/"audio").glob("*.opus")):
        return p
    for sub in p.rglob("audio"):
        if list(sub.glob("*.opus")):
            return sub.parent
    raise FileNotFoundError(f"Не нашёл audio/*.opus внутри {p}")

def list_valid_opus_fast(root_dir: Path, check_size: bool = False, min_bytes: int = 2048):
    """List the usable `.opus` files inside `<root_dir>/audio` as sorted path strings.

    Entries whose name starts with `._` are skipped. When `check_size` is true,
    files smaller than `min_bytes` (or whose size cannot be read) are skipped too.
    A one-line summary with kept/skipped counts is printed.
    """
    audio_dir = Path(root_dir) / "audio"

    def is_too_small(file_path: Path) -> bool:
        # A file that cannot be stat'ed is treated like an undersized one.
        try:
            return file_path.stat().st_size < min_bytes
        except Exception:
            return True

    keep, skipped = [], 0
    for name in os.listdir(audio_dir):
        if not name.endswith(".opus"):
            continue
        candidate = audio_dir / name
        if name.startswith("._") or (check_size and is_too_small(candidate)):
            skipped += 1
        else:
            keep.append(str(candidate))
    keep.sort()
    print(f"{audio_dir}: kept {len(keep)} files, skipped {skipped} (._*{', size<'+str(min_bytes) if check_size else ''})")
    return keep

def safe_load_audio_16k(path: Union[str, Path], sr_target: int = 16000) -> torch.Tensor:
    """Load an audio file as a mono, peak-normalised float32 waveform at `sr_target` Hz.

    The file is read with torchaudio; if that fails for any reason the function
    falls back to soundfile + librosa. Multi-channel audio is averaged to mono.

    Args:
        path: path to the audio file.
        sr_target: sample rate the waveform is resampled to.

    Returns:
        1-D float32 tensor scaled so that its peak absolute value is 1 (left
        unscaled when the signal is all zeros). An empty recording is returned
        as an empty tensor instead of raising.
    """
    path = str(path)
    try:
        wav, sr = torchaudio.load(path)
        if wav.dim() == 2 and wav.size(0) > 1:
            # Down-mix multi-channel audio by averaging the channels.
            wav = wav.mean(dim=0, keepdim=True)
        wav = wav.squeeze(0)
        if sr != sr_target:
            wav = torchaudio.functional.resample(wav, sr, sr_target)
        wav = wav.float()
    except Exception:
        # Deliberate best-effort fallback: any torchaudio failure is retried with soundfile/librosa.
        import soundfile as sf, librosa
        data, sr = sf.read(path, dtype="float32", always_2d=False)
        if isinstance(data, np.ndarray) and data.ndim > 1:
            data = data.mean(axis=1)
        if sr != sr_target:
            data = librosa.resample(data, orig_sr=sr, target_sr=sr_target)
        wav = torch.from_numpy(data.astype(np.float32))
    # A full reduction over an empty tensor raises, so zero-length audio is returned as is.
    if wav.numel() == 0:
        return wav
    peak = float(wav.abs().max())
    if peak > 0:
        wav = wav / peak
    return wav

def build_pos_neg_lists(train_files: List[str], word_bounds: Dict[str, List[float]]):
    """Split files into positives (with word bounds) and negatives (without).

    A file is positive when its stem is a key of `word_bounds`; it is then
    stored as `(path, (start, end))` with both bounds converted to float.
    Every other file is appended to the negative list unchanged.

    Returns:
        `(pos, neg)` — the list of positive tuples and the list of negative paths.
    """
    positives, negatives = [], []
    for file_path in train_files:
        key = Path(file_path).stem
        if key not in word_bounds:
            negatives.append(file_path)
            continue
        start, end = word_bounds[key]
        positives.append((file_path, (float(start), float(end))))
    return positives, negatives

def get_conv_params_from_config(cfg):
    """Return `(strides, kernels)` lists read from `cfg.conv_stride` / `cfg.conv_kernel`.

    When an attribute is missing, the seven-layer defaults below are used instead.
    """
    default_strides = [5, 2, 2, 2, 2, 2, 2]
    default_kernels = [10, 3, 3, 3, 3, 2, 2]
    strides = list(getattr(cfg, "conv_stride", default_strides))
    kernels = list(getattr(cfg, "conv_kernel", default_kernels))
    return strides, kernels

def feat_len_from_samples(n_samples: int, strides, kernels) -> int:
    """Output length after applying the conv stack to `n_samples` input samples.

    Each (kernel, stride) pair maps a length L to floor((L - kernel) / stride + 1).
    Returns 0 as soon as any layer would produce a non-positive length.
    """
    length = n_samples
    for layer in zip(kernels, strides):
        kernel, stride = layer
        length = math.floor((length - kernel) / stride + 1)
        if length <= 0:
            return 0
    return length

def ensure_length(wav: torch.Tensor, target_len: int) -> torch.Tensor:
    """Make `wav` exactly `target_len` long: right-pad with zeros or truncate.

    The input tensor is returned unchanged when it already has the requested length.
    """
    missing = target_len - wav.numel()
    if missing > 0:
        return F.pad(wav, (0, missing))
    if missing < 0:
        return wav[:target_len]
    return wav

def pick_positive_window(T, sr, seg_size, bounds, context_frac=0.5):
    """Choose a `[left, right)` sample window of up to `seg_size` samples around a word.

    Args:
        T: number of samples in the recording.
        sr: sample rate used to convert `bounds` from seconds to samples.
        seg_size: requested window length in samples.
        bounds: `(start_sec, end_sec)` of the word.
        context_frac: mean of the random share of spare context placed before the word.

    Returns:
        `(left, right)` sample indices. If the word is at least `seg_size` long the
        window is centred on it. Otherwise the spare samples are split randomly
        before/after the word (one draw from `np.random.normal`), then the window
        is shifted/clamped into `[0, T]`; it is shorter than `seg_size` when the
        recording itself is shorter.
    """
    def to_sample(t):
        # Seconds -> sample index, clamped to the recording.
        return max(0, min(T, int(round(t * sr))))

    word_start, word_end = (to_sample(t) for t in bounds)
    word_len = max(1, word_end - word_start)

    if word_len >= seg_size:
        center = (word_start + word_end) // 2
        lo = max(0, center - seg_size // 2)
        hi = min(T, lo + seg_size)
        return hi - seg_size, hi

    slack = seg_size - word_len
    alpha = float(np.clip(np.random.normal(loc=context_frac, scale=0.15), 0, 1))
    ctx_before = int(alpha * slack)
    lo = word_start - ctx_before
    hi = word_end + (slack - ctx_before)
    if lo < 0:
        # Slide right by the overshoot, without running past the end.
        hi = min(T, hi - lo)
        lo = 0
    if hi > T:
        # Slide left by the overshoot, without running before the start.
        lo = max(0, lo - (hi - T))
        hi = T
    if hi - lo != seg_size:
        hi = min(T, lo + seg_size)
        lo = max(0, hi - seg_size)
    return lo, hi

def pick_negative_window(T, seg_size):
    """Pick a uniformly random `seg_size`-sample window inside a `T`-sample recording.

    When the recording is not longer than `seg_size`, the window `(0, min(T, seg_size))`
    is returned and no random number is drawn.
    """
    if T > seg_size:
        start = int(np.random.randint(0, T - seg_size + 1))
        return start, start + seg_size
    return 0, min(T, seg_size)

In [10]:
class SegmentKWSDataset(Dataset):
    """Fixed-length audio segments with per-frame keyword labels.

    The index holds every positive item once plus `ceil(n_pos * neg_ratio)`
    negatives drawn with replacement, shuffled with a seeded RandomState.
    Each item is a dict with `input_values` (waveform of `seg_size_samples`
    samples), `labels` (float32, one value per output frame) and `frame_mask`
    (bool, all True). The crop window itself is re-drawn on every access.

    Raises:
        RuntimeError: if `neg_items` is empty or the segment yields no frames.
    """

    def __init__(self, pos_items, neg_items, seg_size_samples, sr, conv_strides, conv_kernels,
                 pos_margin_sec=0.2, neg_ratio=1.0, seed=SEED):
        super().__init__()
        self.pos_items = pos_items
        self.neg_items = neg_items
        self.sr = sr
        self.seg = int(seg_size_samples)
        self.margin = float(pos_margin_sec)
        self.strides = conv_strides
        self.kernels = conv_kernels

        rng = np.random.RandomState(seed)
        num_pos = len(pos_items)
        num_neg = int(math.ceil(num_pos * neg_ratio))
        if not len(neg_items):
            raise RuntimeError("Нет отрицательных примеров.")
        sampled_neg = rng.randint(0, len(neg_items), size=num_neg).tolist()
        entries = [("pos", i) for i in range(num_pos)]
        entries.extend(("neg", j) for j in sampled_neg)
        rng.shuffle(entries)
        self.index = entries

        # Number of output frames for one full segment.
        self.frames = feat_len_from_samples(self.seg, self.strides, self.kernels)
        if self.frames <= 0:
            raise RuntimeError("frames<=0 — проверь SEG_DUR или conv-параметры")

    def __len__(self):
        return len(self.index)

    def _samples_to_frames(self, n):
        # Frame count produced by the conv stack for `n` samples.
        return feat_len_from_samples(n, self.strides, self.kernels)

    def __getitem__(self, i):
        kind, idx = self.index[i]
        labels = torch.zeros(self.frames, dtype=torch.float32)
        frame_mask = torch.ones(self.frames, dtype=torch.bool)

        if kind != "pos":
            # Negative: random crop, labels stay all-zero.
            wav = safe_load_audio_16k(self.neg_items[idx])
            left, right = pick_negative_window(wav.numel(), self.seg)
            segment = ensure_length(wav[left:right], self.seg)
        else:
            path, (t0, t1) = self.pos_items[idx]
            wav = safe_load_audio_16k(path)
            left, right = pick_positive_window(wav.numel(), self.sr, self.seg, (t0, t1))
            segment = ensure_length(wav[left:right], self.seg)

            # Word bounds (widened by the margin) relative to the crop start, in seconds.
            seg_start = left / float(self.sr)
            rel_begin = max(0.0, t0 - self.margin - seg_start)
            rel_end = max(0.0, t1 + self.margin - seg_start)
            first_sample = max(0, min(self.seg - 1, int(np.floor(rel_begin * self.sr))))
            last_sample = max(0, min(self.seg, int(np.ceil(rel_end * self.sr))))
            frame_a = max(0, min(self.frames, self._samples_to_frames(first_sample)))
            frame_b = max(0, min(self.frames, self._samples_to_frames(last_sample)))
            if frame_b <= frame_a:
                # Always try to mark at least one frame as positive.
                frame_b = min(self.frames, frame_a + 1)
            labels[frame_a:frame_b] = 1.0

        return {"input_values": segment, "labels": labels, "frame_mask": frame_mask}

def collate_segments(batch):
    """Stack the per-item tensors of a batch along a new leading dimension.

    Returns a dict with the keys `input_values`, `labels` and `frame_mask`.
    """
    collated = {}
    for key in ("input_values", "labels", "frame_mask"):
        collated[key] = torch.stack([item[key] for item in batch], 0)
    return collated

In [11]:
class Wav2Vec2KWS(nn.Module):
    """Pretrained backbone followed by a small 1-D conv head that emits one logit per frame.

    The backbone's `feature_extractor` parameters (when that attribute exists)
    are frozen; everything else stays trainable.
    """

    def __init__(self, backbone_name=BACKBONE, dropout=0.1):
        super().__init__()
        self.config = AutoConfig.from_pretrained(backbone_name)
        self.backbone = AutoModel.from_pretrained(backbone_name)
        hidden = self.config.hidden_size
        head_layers = [
            nn.Conv1d(hidden, hidden, 3, padding=1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden, 1, 1),
        ]
        self.head = nn.Sequential(*head_layers)
        if hasattr(self.backbone, "feature_extractor"):
            for param in self.backbone.feature_extractor.parameters():
                param.requires_grad_(False)

    def forward(self, input_values):
        """Return frame logits — presumably shaped (batch, frames); confirm against the backbone output."""
        backbone_out = self.backbone(input_values=input_values, output_hidden_states=False)
        # Conv1d expects channels before time, hence the transpose.
        features = backbone_out.last_hidden_state.transpose(1, 2)
        return self.head(features).squeeze(1)

In [12]:
class KwsTrainer(Trainer):
    """Trainer that optimises a masked, positively-weighted frame-level BCE loss."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute BCE-with-logits over the frames selected by `frame_mask`.

        Args:
            model: module called as `model(**inputs)` after the label fields are removed.
            inputs: batch dict; `labels` and `frame_mask` are popped from it.
            return_outputs: when true, also return `{"logits": logits}`.
            num_items_in_batch: accepted (and ignored) so the override keeps working
                with Trainer versions that pass this keyword to `compute_loss`.

        Returns:
            The scalar loss, or `(loss, {"logits": logits})` when `return_outputs` is true.
        """
        labels = inputs.pop("labels")
        mask   = inputs.pop("frame_mask")
        logits = model(**inputs)
        loss = F.binary_cross_entropy_with_logits(
            logits[mask], labels[mask],
            # Positive frames are weighted twice as much as negative ones.
            pos_weight=torch.tensor(2.0, device=logits.device)
        )
        return (loss, {"logits": logits}) if return_outputs else loss

In [13]:
# Locate the directories that hold an `audio/` folder with .opus files.
train_root = _find_audio_root(TRAIN_DIR)
test_root  = _find_audio_root(TEST_DIR)

# Sorted audio paths; "._*" entries are skipped and the size check is disabled here.
train_files = list_valid_opus_fast(train_root, check_size=False)
test_files  = list_valid_opus_fast(test_root,  check_size=False)

# Word bounds are keyed by file stem; files without an entry become negatives below.
with open(train_root / "word_bounds.json", "r", encoding="utf-8") as f:
    word_bounds = json.load(f)

pos_items, neg_items = build_pos_neg_lists(train_files, word_bounds)
# Positives and negatives are split independently with the same fraction and seed.
pos_tr, pos_dev = train_test_split(pos_items, test_size=VAL_SPLIT, random_state=SEED)
neg_tr, neg_dev = train_test_split(neg_items, test_size=VAL_SPLIT, random_state=SEED)

# Training segment length in samples at 16 kHz.
sr = 16000
seg_len = int(round(SEG_DUR * sr))

# Only the config is loaded here: its conv strides/kernels are needed to compute frame counts.
_tmp_cfg = AutoConfig.from_pretrained(BACKBONE)
strides, kernels = get_conv_params_from_config(_tmp_cfg)

# The dev dataset always uses one negative per positive and a different sampling seed.
train_ds = SegmentKWSDataset(pos_tr, neg_tr, seg_len, sr, strides, kernels,
                             pos_margin_sec=POS_MARGIN, neg_ratio=NEG_RATIO, seed=SEED)
dev_ds   = SegmentKWSDataset(pos_dev, neg_dev, seg_len, sr, strides, kernels,
                             pos_margin_sec=POS_MARGIN, neg_ratio=1.0, seed=SEED+1)

/kaggle/input/vseros-audio-task/train_data/train_opus/audio: kept 90000 files, skipped 90000 (._*)
/kaggle/input/vseros-audio-task/test_data/test_opus/audio: kept 27000 files, skipped 27000 (._*)


config.json: 0.00B [00:00, ?B/s]

In [None]:
# Build the model from the pretrained backbone and move it to the selected device.
model = Wav2Vec2KWS(BACKBONE).to(DEVICE)

# Evaluate and checkpoint once per epoch; the checkpoint with the lowest eval_loss is reloaded at the end.
# remove_unused_columns=False and label_names keep the custom "frame_mask" field available to compute_loss.
# NOTE(review): `evaluation_strategy` is reported as deprecated in favour of `eval_strategy` in newer
# transformers releases — confirm before upgrading past the pinned version.
args = TrainingArguments(
    output_dir=OUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LR,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    lr_scheduler_type="cosine",
    warmup_ratio=0.06,
    save_total_limit=1,
    fp16=AMP_ENABLED,
    dataloader_pin_memory=PIN_MEMORY,
    dataloader_num_workers=4,
    logging_steps=50,
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    remove_unused_columns=False,
    label_names=["labels","frame_mask"],
    weight_decay=0.01,
    seed=SEED,
)

# Custom trainer (masked BCE loss) with the segment collate function.
trainer = KwsTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=dev_ds,
    data_collator=collate_segments,
)

trainer.train()

pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss


In [None]:
# Output directory as a Path, created if missing (used for score caches, thresholds and submissions).
OUT_PATH = Path(OUT_DIR); OUT_PATH.mkdir(parents=True, exist_ok=True)
# Default number of windows per forward pass when scoring a file (larger on CUDA).
WIN_BATCH = 128 if DEVICE.type == "cuda" else 64

@torch.no_grad()
def score_file_batched(model: nn.Module, wav: torch.Tensor,
                       win_sec=WINDOW_SEC, hop_sec=HOP_SEC, sr=16000,
                       batch_windows: int = WIN_BATCH) -> float:
    """Maximum keyword probability over all sliding windows of a recording.

    Args:
        model: module mapping a (windows, samples) batch to per-frame logits; it is
            switched to eval mode and is the model that is actually scored.
        wav: 1-D waveform.
        win_sec: window length in seconds.
        hop_sec: hop between window starts in seconds.
        sr: sample rate of `wav`.
        batch_windows: number of windows per forward pass.

    Returns:
        The largest sigmoid probability over every frame of every window
        (0.0 is the floor). Recordings shorter than one window are zero-padded.
    """
    model.eval()
    T = wav.numel()
    win = int(round(win_sec * sr))
    # A hop that rounds to zero would make range() raise; advance by at least one sample.
    hop = max(1, int(round(hop_sec * sr)))
    if T <= win:
        starts = [0]
    else:
        starts = list(range(0, T - win + 1, hop))
        # Add a final window flush with the end so the tail of the recording is scored too.
        if starts[-1] + win < T:
            starts.append(T - win)

    device = next(model.parameters()).device
    best = 0.0
    for i in range(0, len(starts), batch_windows):
        b_st = starts[i:i + batch_windows]
        xb = torch.stack([ensure_length(wav[s:s + win], win) for s in b_st], 0).to(device)
        with autocast_ctx():
            # Score with the model passed in, not with a notebook-level global.
            logits = model(xb)
            probs  = torch.sigmoid(logits).amax(dim=1)
        b_max = float(probs.max().item())
        if b_max > best:
            best = b_max
    return best

@torch.no_grad()
def collect_dev_scores_fast(model, pos_files, neg_files):
    """Score every dev recording and return `(pos_scores, neg_scores)` as float32 arrays.

    `pos_files` holds `(path, bounds)` tuples (bounds are ignored here);
    `neg_files` holds plain paths.
    """
    def score(path):
        return score_file_batched(model, safe_load_audio_16k(path))

    pos_scores = [score(path) for path, _ in tqdm(pos_files, desc=f"DEV pos (W={WINDOW_SEC})")]
    neg_scores = [score(path) for path in tqdm(neg_files, desc=f"DEV neg (W={WINDOW_SEC})")]
    return np.asarray(pos_scores, np.float32), np.asarray(neg_scores, np.float32)

def threshold_search_fast(pos_scores, neg_scores, far_cap=None,
                          lo=0.05, hi=0.995, steps=191):
    """Grid-search the decision threshold that maximises the harmonic mean of TPR and TNR.

    Args:
        pos_scores: 1-D array of scores for positive recordings.
        neg_scores: 1-D array of scores for negative recordings.
        far_cap: optional upper bound on the false-accept rate; when no threshold
            satisfies it, the unconstrained best threshold is used.
        lo, hi, steps: define the float32 threshold grid `linspace(lo, hi, steps)`.

    Returns:
        `(thr, stats)` where `stats` has the keys `score`, `FAR` and `FRR`.
        Ties are resolved in favour of the lowest threshold.
    """
    thresholds = np.linspace(lo, hi, steps, dtype=np.float32)
    column = thresholds[:, None]
    n_pos, n_neg = len(pos_scores), len(neg_scores)

    # A recording is predicted positive when its score reaches the threshold.
    true_pos = (pos_scores[None, :] >= column).sum(1)
    false_pos = (neg_scores[None, :] >= column).sum(1)
    far = false_pos / max(1, n_neg)
    frr = (n_pos - true_pos) / max(1, n_pos)
    tpr = np.clip(1.0 - frr, 1e-9, None)
    tnr = np.clip(1.0 - far, 1e-9, None)
    score = 2.0 / (1.0 / tpr + 1.0 / tnr)

    candidates = np.arange(len(score))
    if far_cap is not None:
        allowed = np.flatnonzero(far <= float(far_cap))
        if allowed.size:
            candidates = allowed
    idx = int(candidates[int(np.argmax(score[candidates]))])

    stats = {"score": float(score[idx]), "FAR": float(far[idx]), "FRR": float(frr[idx])}
    return float(thresholds[idx]), stats

In [None]:
# Score every dev recording with sliding windows and persist the raw scores for later re-use.
print("\n→ scoring DEV (batched windows)…")
pos_scores, neg_scores = collect_dev_scores_fast(trainer.model, pos_dev, neg_dev)
np.save(OUT_PATH/"dev_pos_scores.npy", pos_scores)
np.save(OUT_PATH/"dev_neg_scores.npy", neg_scores)

# Pick the decision threshold on the dev scores (no cap on the false-accept rate).
thr, best = threshold_search_fast(pos_scores, neg_scores, far_cap=None)
print(f"[DEV] thr={thr:.6f}  score={best['score']:.4f}  FAR={best['FAR']:.4f}  FRR={best['FRR']:.4f}")

In [None]:
# Persist the chosen threshold together with its dev statistics and the window settings used for scoring.
with open(OUT_PATH/"dev_thresholds.json","w") as w:
    json.dump({"thr": thr, "stats": best,
               "window_sec": WINDOW_SEC, "hop_sec": HOP_SEC}, w, indent=2)


In [None]:
def infer_test_and_cache(model, files, cache_dir: Path = OUT_PATH, reuse_if_exists: bool = True):
    """Score the test files with `model`, caching ids and scores as `.npy` files.

    Args:
        model: model passed to `score_file_batched` for every file.
        files: audio file paths; the id of a file is its stem.
        cache_dir: directory holding `test_ids.npy` / `test_scores.npy`.
        reuse_if_exists: when true and both cache files exist, return the cached
            values instead of recomputing them.

    Returns:
        `(ids, scores)` — a list of ids and a float32 array of per-file scores.

    Note:
        The cache is not keyed on the model or the file list. A warning is
        printed when the cached ids differ from `files`, but the cache is still
        returned; pass `reuse_if_exists=False` to force a recompute.
    """
    cache_dir = Path(cache_dir)
    ids_path    = cache_dir / "test_ids.npy"
    scores_path = cache_dir / "test_scores.npy"

    if reuse_if_exists and ids_path.exists() and scores_path.exists():
        # NOTE: allow_pickle is needed because ids are stored as an object array;
        # only load caches written by this notebook.
        ids    = np.load(ids_path, allow_pickle=True).tolist()
        scores = np.load(scores_path).astype(np.float32)
        if ids != [Path(f).stem for f in files] or len(scores) != len(ids):
            print("! test cache does not match the requested files; "
                  "pass reuse_if_exists=False to recompute")
        print("✓ loaded test cache:", ids_path.name, "|", scores_path.name)
        return ids, scores

    ids, scores = [], []
    t0 = time.time()
    for f in tqdm(files, desc=f"TEST (batched, W={WINDOW_SEC})"):
        ids.append(Path(f).stem)
        # Score with the model passed in, not with a notebook-level global.
        scores.append(score_file_batched(model, safe_load_audio_16k(f)))
    elapsed = time.time() - t0
    print(f"✓ TEST throughput: {len(files)/max(1e-6,elapsed):.2f} files/sec")

    scores = np.asarray(scores, np.float32)
    cache_dir.mkdir(parents=True, exist_ok=True)
    np.save(ids_path, np.asarray(ids, dtype=object))
    np.save(scores_path, scores)
    print("✓ saved test cache:", ids_path.name, "|", scores_path.name)
    return ids, scores

In [None]:
# Score the test set (or reuse the cached scores) and binarise with the dev-selected threshold.
ids, test_scores = infer_test_and_cache(trainer.model, test_files, OUT_PATH, reuse_if_exists=True)
labels = (test_scores >= thr).astype(np.int32)

# The threshold is embedded in the file name so submissions made with different thresholds do not overwrite each other.
sub_path = OUT_PATH / f"submission_thr_{thr:.6f}.csv"
pd.DataFrame({"id": ids, "label": labels}).to_csv(sub_path, index=False)
print("✓ saved submission:", sub_path)

In [None]:
def make_submission_from_cache(new_thr: float,
                               ids_path=OUT_PATH/"test_ids.npy",
                               scores_path=OUT_PATH/"test_scores.npy",
                               out_path=None):
    """Write a submission CSV from cached test scores using a new threshold.

    Args:
        new_thr: decision threshold; a file is labelled 1 when its score is >= it.
        ids_path: `.npy` file with the cached ids.
        scores_path: `.npy` file with the cached scores.
        out_path: destination CSV; defaults to a name under OUT_PATH that embeds the threshold.
    """
    cached_ids = np.load(ids_path, allow_pickle=True)
    cached_scores = np.load(scores_path).astype(np.float32)
    if out_path:
        destination = Path(out_path)
    else:
        destination = Path(OUT_PATH / f"submission_from_cache_thr_{float(new_thr):.6f}.csv")
    predicted = (cached_scores >= float(new_thr)).astype(np.int32)
    submission = pd.DataFrame({"id": cached_ids, "label": predicted})
    submission.to_csv(destination, index=False)
    print("✓ saved:", destination)