In [1]:
# # Server
# %cd notebooks

In [2]:
import warnings
# Suppress every Python warning for the rest of the kernel session to keep cell output compact.
# NOTE(review): this also hides deprecation/runtime warnings raised by libraries used later.
warnings.filterwarnings('ignore')

## Config

In [3]:
# Directory with the training audio. File names are appended to it by plain string
# concatenation, so the trailing slash is required.
train_audio_dir = 'wav_train/'
# JSON file loaded into `train_data`; a clip id gets label 1 when it is present in it.
train_json_file = 'vseross/A/train_opus/word_bounds.json'

# Directory with the audio files to run inference on.
test_audio_dir = 'wav_test/'

In [None]:
# Model identifier passed to AutoProcessor.from_pretrained and
# AutoModelForSpeechSeq2Seq.from_pretrained.
# NOTE(review): left empty here — set it to a real model id or checkpoint directory
# before running the cells that load the processor and the model.
model_id = ""
# model_id = "./whisper228"

In [5]:
# Worker-process count used by every DataLoader in this notebook.
num_workers = 0
# Batch size of the train and validation DataLoaders.
batch_size = 4

In [6]:
# Number of StratifiedKFold folds.
n_splits = 5
# Training epochs per fold.
num_epochs = 1

In [7]:
# Adam learning rates: one parameter group for the encoder loaded with from_pretrained,
# one for the classification head.
encoder_lr = 1e-5
head_lr = 3e-4

## Data

In [8]:
import os
import json
import torch
import torchaudio
import numpy as np
import pandas as pd
from transformers import AutoProcessor
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold

### Read

In [9]:
# Load the annotation file; the notebook only tests whether a clip id is present in it.
# JSON is UTF-8 by specification, so do not rely on the platform's default encoding.
with open(train_json_file, encoding='utf-8') as annotations_file:
    train_data = json.load(annotations_file)

In [10]:
# Build the training table from the .wav files found in train_audio_dir.
# - os.path.splitext strips the extension safely (the previous rfind('.') slice chopped the
#   last character off names without a dot).
# - Non-.wav entries are skipped because audio_path is always built with a '.wav' suffix.
# - os.listdir order is arbitrary; sorting keeps the unshuffled K-fold split reproducible.
train_ids = sorted(
    os.path.splitext(file_name)[0]
    for file_name in os.listdir(train_audio_dir)
    if file_name.endswith('.wav')
)
train_df = pd.DataFrame({'id': train_ids})
# Label is 1 when the clip id is present in the annotation data, 0 otherwise.
train_df['label'] = train_df['id'].apply(lambda sample_id: sample_id in train_data).astype(int)
train_df['audio_path'] = train_audio_dir + train_df['id'] + '.wav'

In [11]:
train_df

Unnamed: 0,id,label,audio_path
0,0000020974966386734278978724803273402299,1,wav_train/000002097496638673427897872480327340...
1,0000080688513237533207488522718221342719,1,wav_train/000008068851323753320748852271822134...
2,0000168205906481357162070811351230430612,0,wav_train/000016820590648135716207081135123043...
3,0000283353856125472766020523736784290228,1,wav_train/000028335385612547276602052373678429...
4,0000794176926750993969039908245303376577,1,wav_train/000079417692675099396903990824530337...
...,...,...,...
89995,9999772820163578754413020293817284118809,1,wav_train/999977282016357875441302029381728411...
89996,9999824806978415903658512660966171219703,1,wav_train/999982480697841590365851266096617121...
89997,9999835390111041135328983182516306075861,0,wav_train/999983539011104113532898318251630607...
89998,9999934605001241089979469065796737895987,1,wav_train/999993460500124108997946906579673789...


### Split

In [12]:
# Stratified K-fold splitter. `shuffle` is left at its default, so fold membership
# follows the row order of the frame passed to `split`.
splitter = StratifiedKFold(n_splits=n_splits)

### Dataset

In [13]:
class AudioClassificationDataset(Dataset):
    """Map-style dataset that loads an audio file and converts it to model input features.

    Args:
        audio_paths: indexable collection of audio file paths.
        labels: optional indexable collection of integer class labels aligned with
            ``audio_paths``; when ``None`` the items carry no ``"labels"`` entry.
        sampling_rate: rate (Hz) the audio is resampled to before feature extraction.

    Each item is a dict with ``"input_features"`` (the processor output with its leading
    batch dimension removed) and, when labels were given, ``"labels"`` (a scalar
    ``torch.long`` tensor).
    """

    # Loaded once, when the class body is executed, and shared by all instances.
    processor = AutoProcessor.from_pretrained(model_id)

    def __init__(self, audio_paths, labels=None, sampling_rate=16_000):
        self.audio_paths = audio_paths
        self.labels = labels
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        audio_path = self.audio_paths[idx]
        waveform, sr = torchaudio.load(audio_path)  # (channels, frames)
        if sr != self.sampling_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)

        # Downmix to a single channel. For mono files this equals squeeze(0); for
        # multi-channel files squeeze(0) was a no-op and the processor received a 2-D
        # array instead of one waveform.
        mono_waveform = waveform.mean(dim=0)

        inputs = self.processor(
            mono_waveform.numpy(),
            sampling_rate=self.sampling_rate,
            return_tensors="pt"
        )

        # The processor returns a batch of one; drop the batch dimension so the
        # DataLoader can stack items itself.
        inputs_dict = {"input_features": inputs.input_features.squeeze(0)}
        if self.labels is not None:
            inputs_dict["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)

        return inputs_dict

## Model

In [14]:
import torch
import torch.nn as nn
from transformers import AutoModelForSpeechSeq2Seq

In [15]:
class AudioClassificationModel(nn.Module):
    """Two-class audio classifier: pretrained speech encoder + mean pooling + linear head.

    ``forward`` returns a dict with ``"logits"`` and ``"loss"``; the loss is the
    cross-entropy against ``labels`` and is ``None`` when no labels are passed.
    """

    def __init__(self):
        super().__init__()
        # Load the full seq2seq checkpoint but keep only its encoder.
        seq2seq = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
        self.encoder = seq2seq.model.encoder
        # The head input width is read from the encoder's final layer norm.
        hidden_size = self.encoder.layer_norm.normalized_shape[0]
        self.head = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(hidden_size, 2)
        )

    def forward(self, input_features, labels=None):
        # (batch, seq_len, hidden)
        hidden_states = self.encoder(input_features).last_hidden_state

        # Global pooling: average over the time axis.
        logits = self.head(hidden_states.mean(dim=1))

        if labels is None:
            loss = None
        else:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return {"loss": loss, "logits": logits}

## Train

In [None]:
import os
from sklearn.metrics import f1_score
from tqdm.auto import tqdm
import torch

os.makedirs('models', exist_ok=True)


def _run_epoch(model, loader, device, dtype, desc, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, macro_f1)``.

    With an ``optimizer`` the model is put in train mode and updated after every batch;
    without one it is put in eval mode and evaluated with gradients disabled.
    """
    is_train = optimizer is not None
    model.train(is_train)

    losses = []
    all_preds, all_targets = [], []
    progress_bar = tqdm(loader, desc=desc)

    with torch.set_grad_enabled(is_train):
        for batch in progress_bar:
            input_features = batch['input_features'].to(device=device, dtype=dtype)
            labels = batch['labels'].to(device=device)

            if is_train:
                optimizer.zero_grad()
            outputs = model(input_features, labels=labels)
            loss = outputs['loss']
            if is_train:
                loss.backward()
                optimizer.step()

            # Read the scalar once; every .item() call synchronises with the device.
            loss_value = loss.item()
            losses.append(loss_value)
            all_preds.extend(outputs['logits'].argmax(dim=1).detach().cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

            progress_bar.set_postfix(loss=loss_value)

    mean_loss = sum(losses) / len(losses)
    macro_f1 = f1_score(all_targets, all_preds, average='macro')
    return mean_loss, macro_f1


def _make_loader(audio_paths, labels, train):
    """Build a DataLoader; the training loader shuffles and drops the last partial batch."""
    return DataLoader(
        AudioClassificationDataset(audio_paths, labels),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=train,
        drop_last=train,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=(device == 'cuda')
    )


# Save
models = []
fit_results = []
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float32

# Train
for i, (train_index, valid_index) in enumerate(splitter.split(train_df, train_df['label'])):
    print(f"Split: {i + 1}", end="\n\n")

    audio_paths = train_df['audio_path'].values
    labels_all = train_df['label'].values

    train_loader = _make_loader(audio_paths[train_index], labels_all[train_index], train=True)
    valid_loader = _make_loader(audio_paths[valid_index], labels_all[valid_index], train=False)

    model = AudioClassificationModel()
    model.to(device=device, dtype=dtype)
    # Separate learning rates for the pretrained encoder and the freshly initialised head.
    optimizer = torch.optim.Adam([
        {'params': model.encoder.parameters(), 'lr': encoder_lr},
        {'params': model.head.parameters(), 'lr': head_lr}
    ])

    fit_result = {
        'train_losses': [],
        'train_f1': [],
        'valid_losses': [],
        'valid_f1': []
    }

    for epoch in range(1, num_epochs + 1):
        # === TRAIN ===
        train_loss, train_f1 = _run_epoch(
            model, train_loader, device, dtype,
            desc=f"Epoch {epoch} [Train]", optimizer=optimizer
        )
        fit_result['train_losses'].append(train_loss)
        fit_result['train_f1'].append(train_f1)
        print(f"Train loss: {train_loss:.4f} | F1 (macro): {train_f1:.4f}")

        # === VALID ===
        valid_loss, valid_f1 = _run_epoch(
            model, valid_loader, device, dtype,
            desc=f"Epoch {epoch} [Valid]"
        )
        fit_result['valid_losses'].append(valid_loss)
        fit_result['valid_f1'].append(valid_f1)
        print(f"Valid loss: {valid_loss:.4f} | F1 (macro): {valid_f1:.4f}\n")

    fit_results.append(fit_result)
    models.append(model)

    path = f'models/split_{i + 1}.pt'
    torch.save(model.state_dict(), path)
    # Only the first fold is trained; remove this break to train all n_splits folds.
    break

In [16]:
# Re-declared so inference can run in a fresh kernel without executing the training cell.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float32

# The training cell writes its checkpoint to models/split_1.pt; fall back to the working
# directory so a checkpoint copied there by hand still loads.
checkpoint_path = os.path.join('models', 'split_1.pt')
if not os.path.exists(checkpoint_path):
    checkpoint_path = 'split_1.pt'

model = AudioClassificationModel()
model.to(device=device, dtype=dtype)
# load_state_dict copies into the already-placed parameters, so no second .to() is needed.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

In [17]:
class AudioTestDataset(Dataset):
    """Map-style dataset for unlabeled audio files used at inference time.

    Args:
        audio_paths: indexable collection of audio file paths.
        sampling_rate: rate (Hz) the audio is resampled to before feature extraction.

    Each item is a dict with ``"input_features"`` (the processor output with its leading
    batch dimension removed) and ``"id"`` (the file name without its extension).
    """

    # Loaded once, when the class body is executed, and shared by all instances.
    processor = AutoProcessor.from_pretrained(model_id)

    def __init__(self, audio_paths, sampling_rate=16_000):
        self.audio_paths = audio_paths
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        audio_path = self.audio_paths[idx]
        waveform, sr = torchaudio.load(audio_path)  # (channels, frames)
        if sr != self.sampling_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)

        # Downmix to a single channel. For mono files this equals squeeze(0); for
        # multi-channel files squeeze(0) was a no-op and the processor received a 2-D
        # array instead of one waveform.
        mono_waveform = waveform.mean(dim=0)

        inputs = self.processor(
            mono_waveform.numpy(),
            sampling_rate=self.sampling_rate,
            return_tensors="pt"
        )

        # Strip only the extension; str.replace('.wav', '') would also remove a '.wav'
        # occurring elsewhere in the file name.
        clip_id = os.path.splitext(os.path.basename(audio_path))[0]
        return {"input_features": inputs.input_features.squeeze(0), "id": clip_id}

In [19]:
# Batch size used by the test DataLoader.
batch_size_test = 8

In [20]:
# os.listdir returns entries in arbitrary order; sort so the submission rows come out
# in the same order on every run and machine.
test_audio_paths = sorted(
    os.path.join(test_audio_dir, file_name)
    for file_name in os.listdir(test_audio_dir)
    if file_name.endswith('.wav')
)
test_loader = DataLoader(
    AudioTestDataset(test_audio_paths),
    batch_size=batch_size_test,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

In [None]:
import os
from sklearn.metrics import f1_score
from tqdm.auto import tqdm
import torch

# Accumulators: one predicted class and one clip id per test file, in loader order.
all_preds = []
all_ids = []

with torch.no_grad():
    for test_batch in tqdm(test_loader, desc="Test Inference"):
        features = test_batch['input_features'].to(device=device, dtype=dtype)
        # No labels are passed, so the returned loss is None and only the logits are used.
        batch_logits = model(features)['logits']

        # Predicted class = index of the largest logit.
        all_preds.extend(batch_logits.argmax(dim=1).detach().cpu().numpy())
        all_ids.extend(test_batch['id'])

Test Inference:   0%|          | 0/3375 [00:00<?, ?it/s]

In [None]:
# Two-column submission table (clip id, predicted class), written without the index column.
submission = pd.DataFrame({'id': all_ids, 'label': all_preds})
submission.to_csv('submission.csv', index=False)