In [1]:
import gc
import os
import re

import bitsandbytes as bnb
import librosa
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn

from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.calibration import calibration_curve
from sklearn.metrics import roc_auc_score, mean_squared_error
from sklearn.model_selection import StratifiedKFold
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import HubertForSequenceClassification, AutoFeatureExtractor, AutoConfig

In [2]:
# Constants
DATA_DIR = ''  # Adjust this path as necessary; it is joined in front of every path read from the CSV
PREPROC_DIR = './preproc'
SUBMISSION_DIR = './submission'
MODEL_DIR = './model'  # passed to ModelCheckpoint as the checkpoint directory
SAMPLING_RATE = 16000  # used both when loading audio with librosa and when calling the feature extractor
SEED = 42  # used for seed_everything and as the StratifiedKFold random_state
N_FOLD = 20  # number of StratifiedKFold splits
BATCH_SIZE = 2  # training batch size; the validation loader uses BATCH_SIZE * 2
NUM_LABELS = 2  # width of the multi-label target vector and of the classifier head
AUDIO_MODEL_NAME = 'abhishtagatya/hubert-base-960h-asv19-deepfake'  # Hugging Face id for config, model and feature extractor

In [3]:
# Utility functions
def accuracy(preds, labels):
    """Return the fraction of elements where ``preds`` equals ``labels``.

    The element-wise equality result is cast to float and averaged over all
    elements, so the inputs must support ``==`` and the result must expose
    ``.float()`` (e.g. torch tensors).
    """
    is_match = preds == labels
    return is_match.float().mean()


def getAudios(df):
    """Load the audio file referenced by each row of ``df``, skipping failures.

    Every ``row['path']`` is read with ``librosa.load`` at ``SAMPLING_RATE``.
    A missing file or any other loading error is printed and the row is
    skipped; loading never aborts on a bad file.

    Args:
        df: DataFrame with a ``path`` column.

    Returns:
        A tuple ``(audios, valid_indices)``: the loaded waveforms and, in the
        same order, the index values yielded by ``df.iterrows()`` for the rows
        that loaded successfully.
    """
    loaded = []
    kept = []
    progress = tqdm(df.iterrows(), total=len(df))
    for row_label, record in progress:
        try:
            waveform, _ = librosa.load(record['path'], sr=SAMPLING_RATE)
        except FileNotFoundError:
            print(f"File not found: {record['path']}. Skipping.")
            continue
        except Exception as e:
            print(f"Error loading {record['path']}: {e}. Skipping.")
            continue
        # Only reached when the file decoded without error.
        loaded.append(waveform)
        kept.append(row_label)
    return loaded, kept


# Dataset class
class MyDataset(Dataset):
    """Dataset pairing raw audio with a float32 label row per sample.

    Feature extraction is performed lazily in ``__getitem__`` by calling
    ``audio_feature_extractor`` on one waveform at a time.
    """

    def __init__(self, audio, audio_feature_extractor, labels=None):
        """Store the audio, the extractor and the labels.

        Args:
            audio: indexable collection of waveforms.
            audio_feature_extractor: callable invoked with ``raw_speech``,
                ``return_tensors`` and ``sampling_rate`` keyword arguments.
            labels: optional per-sample label rows; when omitted, an all-zero
                row of width ``NUM_LABELS`` is used for every sample.
        """
        self.audio = audio
        self.audio_feature_extractor = audio_feature_extractor
        if labels is None:
            # Placeholder targets so unlabeled data can flow through the same pipeline.
            zero_row = [0] * NUM_LABELS
            labels = [list(zero_row) for _ in range(len(audio))]
        self.labels = np.array(labels).astype(np.float32)

    def __len__(self):
        """Number of samples, taken from the label array."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the label and extracted features of sample ``idx`` as a dict."""
        target = self.labels[idx]
        features = self.audio_feature_extractor(
            raw_speech=self.audio[idx],
            return_tensors='np',
            sampling_rate=SAMPLING_RATE,
        )
        # The extractor output is indexed with [0] to take the first (only) entry.
        return {
            'label': target,
            'audio_values': features['input_values'][0],
            'audio_attn_mask': features['attention_mask'][0],
        }


# Collate function
def collate_fn(samples):
    """Merge a list of dataset items into one padded batch.

    Args:
        samples: list of dicts with ``label``, ``audio_values`` and
            ``audio_attn_mask`` entries, as produced by ``MyDataset.__getitem__``.

    Returns:
        Dict with the same three keys. Labels are stacked into one tensor;
        audio values and attention masks are padded to the longest sequence
        in the batch with ``pad_sequence(batch_first=True)``.
    """
    label_rows = [sample['label'] for sample in samples]
    value_seqs = [torch.tensor(sample['audio_values']) for sample in samples]
    mask_seqs = [torch.tensor(sample['audio_attn_mask']) for sample in samples]

    return {
        # Stack through numpy first so the label tensor is built from one array.
        'label': torch.tensor(np.array(label_rows)),
        'audio_values': pad_sequence(value_seqs, batch_first=True),
        'audio_attn_mask': pad_sequence(mask_seqs, batch_first=True),
    }


def expected_calibration_error(y_true, y_prob, n_bins=10):
    """Expected calibration error over ``n_bins`` equal-width probability bins.

    For every non-empty bin the absolute gap between the observed positive
    rate and the mean predicted probability is weighted by the share of
    samples in that bin; the weighted gaps are summed.

    Bin membership, bin means and bin weights are all derived from the same
    binning (right-closed bins via ``searchsorted``, the rule scikit-learn's
    ``calibration_curve`` uses for ``strategy='uniform'``). The previous
    version took the weights from ``np.histogram`` (left-closed bins) and the
    means from ``calibration_curve``, so a probability lying exactly on a bin
    edge was counted in one bin but averaged in another, and the two arrays
    could be mis-paired or silently truncated.

    Args:
        y_true: binary targets (0/1), one per sample.
        y_prob: predicted probabilities in [0, 1], one per sample.
        n_bins: number of equal-width bins on [0, 1].

    Returns:
        The ECE as a NumPy float.

    Raises:
        ValueError: on empty input, mismatched lengths, ``n_bins < 1`` or
            probabilities outside [0, 1].
    """
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_prob = np.asarray(y_prob, dtype=np.float64).ravel()

    if n_bins < 1:
        raise ValueError("n_bins must be at least 1.")
    if y_true.shape[0] != y_prob.shape[0]:
        raise ValueError("y_true and y_prob must have the same number of samples.")
    if y_prob.size == 0:
        raise ValueError("y_prob is empty; ECE is undefined.")
    if y_prob.min() < 0 or y_prob.max() > 1:
        raise ValueError("y_prob has values outside [0, 1].")

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Interior edges only: values map to bin ids 0 .. n_bins - 1.
    bin_ids = np.searchsorted(edges[1:-1], y_prob)

    counts = np.bincount(bin_ids, minlength=n_bins)
    positives = np.bincount(bin_ids, weights=y_true, minlength=n_bins)
    confidence = np.bincount(bin_ids, weights=y_prob, minlength=n_bins)

    occupied = counts > 0
    frac_positive = positives[occupied] / counts[occupied]
    mean_confidence = confidence[occupied] / counts[occupied]
    bin_weights = counts[occupied] / y_prob.size

    return np.sum(bin_weights * np.abs(frac_positive - mean_confidence))


def auc_brier_ece(labels, preds):
    """Compute column-averaged AUC, Brier score, ECE and a combined score.

    Each column of ``labels``/``preds`` is scored independently and the
    per-column values are averaged. If ``roc_auc_score`` raises ``ValueError``
    for a column, that column's AUC is recorded as 0.0.

    Args:
        labels: 2-D array of targets, one column per label.
        preds: 2-D array of predicted probabilities with the same layout.

    Returns:
        ``(mean_auc, mean_brier, mean_ece, combined_score)`` where
        ``combined_score = 0.5 * (1 - mean_auc) + 0.25 * mean_brier + 0.25 * mean_ece``.
    """
    per_column = {'auc': [], 'brier': [], 'ece': []}

    for col in range(labels.shape[1]):
        target_col, prob_col = labels[:, col], preds[:, col]

        try:
            col_auc = roc_auc_score(target_col, prob_col)
        except ValueError:
            col_auc = 0.0
        per_column['auc'].append(col_auc)

        # Brier score is the mean squared error between targets and probabilities.
        per_column['brier'].append(mean_squared_error(target_col, prob_col))
        per_column['ece'].append(expected_calibration_error(target_col, prob_col))

    mean_auc = np.mean(per_column['auc'])
    mean_brier = np.mean(per_column['brier'])
    mean_ece = np.mean(per_column['ece'])

    combined_score = 0.5 * (1 - mean_auc) + 0.25 * mean_brier + 0.25 * mean_ece
    return mean_auc, mean_brier, mean_ece, combined_score


# Lightning Model class
class MyLitModel(pl.LightningModule):
    """Lightning wrapper around ``HubertForSequenceClassification`` for multi-label training.

    Training uses ``MultiLabelSoftMarginLoss`` on the raw logits; probabilities
    are obtained with a sigmoid. The optimizer is 8-bit Adam with layer-wise
    learning-rate decay (LLRD).

    Logged metrics: ``train_loss``, ``train_combined_score`` (per batch),
    ``val_loss`` and ``val_combined_score`` (once per validation run, computed
    over all validation predictions of that run).
    """

    # Matches the index of a transformer layer in a parameter name such as
    # "...encoder.layers.7.attention.k_proj.weight" (HuBERT naming). The optional
    # "s" also accepts the BERT-style "encoder.layer.7." spelling.
    _ENCODER_LAYER_RE = re.compile(r'encoder\.layers?\.(\d+)\.')

    def __init__(self, audio_model_name, num_labels, n_layers=1, projector=True, classifier=True, dropout=0.07,
                 lr_decay=1):
        """Build the model.

        Args:
            audio_model_name: Hugging Face model id or path for config and weights.
            num_labels: number of output logits.
            n_layers: number of top encoder layers to re-initialize.
            projector: re-initialize the projector head if True.
            classifier: re-initialize the classifier head if True.
            dropout: value written to every dropout field of the config.
            lr_decay: layer-wise learning-rate decay factor (1 disables decay).
        """
        super().__init__()
        self.config = AutoConfig.from_pretrained(audio_model_name, num_labels=num_labels)
        for dropout_field in ('activation_dropout', 'attention_dropout', 'final_dropout',
                              'hidden_dropout', 'hidden_dropout_prob'):
            setattr(self.config, dropout_field, dropout)
        self.audio_model = HubertForSequenceClassification.from_pretrained(audio_model_name, config=self.config)
        self.lr_decay = lr_decay

        # Buffers for the predictions/targets of the current validation run.
        self._val_probs = []
        self._val_targets = []

        self._do_reinit(n_layers, projector, classifier)

    def forward(self, audio_values, audio_attn_mask):
        """Return the classification logits for a padded batch of audio."""
        logits = self.audio_model(input_values=audio_values, attention_mask=audio_attn_mask).logits
        return logits

    def _shared_step(self, batch):
        """Run the forward pass and loss shared by training and validation."""
        labels = batch['label']
        logits = self(batch['audio_values'], batch['audio_attn_mask'])
        # Functional form of nn.MultiLabelSoftMarginLoss with default arguments.
        loss = nn.functional.multilabel_soft_margin_loss(logits, labels)
        return logits, labels, loss

    @staticmethod
    def _combined_score(probs, targets):
        """Combined AUC/Brier/ECE score, or NaN when it cannot be computed.

        NaN is used instead of 0.0 because 0.0 is the *best* value of this
        lower-is-better score and would make a failed computation look perfect.
        """
        try:
            return float(auc_brier_ece(targets, probs)[3])
        except ValueError:
            return float('nan')

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it with the per-batch combined score.

        The per-batch combined score is computed on very few samples and is
        only a rough progress indicator.
        """
        logits, labels, loss = self._shared_step(batch)

        # Cast to float32 so mixed-precision outputs reach numpy/sklearn in full precision.
        preds = torch.sigmoid(logits).detach().float().cpu().numpy()
        true_labels = labels.detach().float().cpu().numpy()
        train_combined_score = self._combined_score(preds, true_labels)

        self.log('train_combined_score', train_combined_score, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def on_validation_epoch_start(self):
        """Drop predictions left over from a previous validation run."""
        self._val_probs.clear()
        self._val_targets.clear()

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and stash predictions for the epoch-level metric.

        AUC and ECE are ranking/binning statistics: averaging them over tiny
        batches does not give the dataset-level value, so the combined score
        is computed in ``on_validation_epoch_end`` instead.
        """
        logits, labels, loss = self._shared_step(batch)

        self._val_probs.append(torch.sigmoid(logits).detach().float().cpu())
        self._val_targets.append(labels.detach().float().cpu())

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def on_validation_epoch_end(self):
        """Compute and log ``val_combined_score`` over the whole validation run.

        Aggregation is per process; no cross-device gathering is performed.
        """
        if self._val_probs:
            preds = torch.cat(self._val_probs).numpy()
            true_labels = torch.cat(self._val_targets).numpy()
            val_combined_score = self._combined_score(preds, true_labels)
        else:
            val_combined_score = float('nan')

        self._val_probs.clear()
        self._val_targets.clear()

        self.log('val_combined_score', val_combined_score, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return sigmoid probabilities for a batch."""
        audio_values = batch['audio_values']
        audio_attn_mask = batch['audio_attn_mask']

        logits = self(audio_values, audio_attn_mask)
        probs = torch.sigmoid(logits)

        return probs

    def configure_optimizers(self):
        """Create the 8-bit Adam optimizer over the LLRD parameter groups."""
        lr = 1e-5
        layer_decay = self.lr_decay
        weight_decay = 0.01
        llrd_params = self._get_llrd_params(lr=lr, layer_decay=layer_decay, weight_decay=weight_decay)
        optimizer = bnb.optim.Adam8bit(llrd_params)
        return optimizer

    def _get_llrd_params(self, lr, layer_decay, weight_decay):
        """Build one optimizer parameter group per parameter with layer-wise LR decay.

        Grouping rules, checked in this order on the parameter name:
          * contains ``bias`` or ``layer_norm``: base ``lr``, no weight decay;
          * contains ``emb`` or ``feature``: ``lr * layer_decay ** (n_layers + 1)``;
          * belongs to encoder layer ``i``: ``lr * layer_decay ** (n_layers - i)``,
            so the top layer is decayed once and the bottom layer ``n_layers`` times;
          * anything else (e.g. the heads): base ``lr``.

        Every parameter ends up in exactly one group. The earlier version
        searched for ``encoder.layer.{i}``, which never occurs in HuBERT
        parameter names (``encoder.layers.{i}``), so encoder-layer weights were
        left out of the optimizer altogether; it also used ``layer_decay ** (i + 1)``,
        giving the lowest layer the highest learning rate.
        """
        n_layers = self.audio_model.config.num_hidden_layers
        llrd_params = []
        for name, value in self.named_parameters():
            if ('bias' in name) or ('layer_norm' in name):
                group_lr, group_wd = lr, 0.0
            elif ('emb' in name) or ('feature' in name):
                group_lr, group_wd = lr * (layer_decay ** (n_layers + 1)), weight_decay
            else:
                layer_match = self._ENCODER_LAYER_RE.search(name)
                if layer_match is not None:
                    # Parse the index as a number so "layers.1." never matches "layers.10.".
                    depth_from_top = n_layers - int(layer_match.group(1))
                    group_lr = lr * (layer_decay ** depth_from_top)
                else:
                    group_lr = lr
                group_wd = weight_decay
            llrd_params.append({"params": value, "lr": group_lr, "weight_decay": group_wd})
        return llrd_params

    def _do_reinit(self, n_layers=0, projector=True, classifier=True):
        """Re-initialize the heads and the top ``n_layers`` encoder layers."""
        if projector:
            self.audio_model.projector.apply(self._init_weight_and_bias)
        if classifier:
            self.audio_model.classifier.apply(self._init_weight_and_bias)

        for n in range(n_layers):
            # Negative indexing walks the encoder from the last layer downwards.
            self.audio_model.hubert.encoder.layers[-(n + 1)].apply(self._init_weight_and_bias)

    def _init_weight_and_bias(self, module):
        """Reset a Linear or LayerNorm module to its initial distribution.

        Linear weights are drawn from N(0, initializer_range) and biases are
        zeroed; LayerNorm is reset to weight 1 and bias 0. Other module types
        are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.audio_model.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

In [4]:
seed_everything(SEED)

# Audio feature extraction: load the feature extractor published with the pretrained model
audio_feature_extractor = AutoFeatureExtractor.from_pretrained(AUDIO_MODEL_NAME)
# MyDataset.__getitem__ reads 'attention_mask' from the extractor output, so request it explicitly
audio_feature_extractor.return_attention_mask = True

# Load the data
train_df = pd.read_csv('./shuffled_audio_final.csv')
print(f"Train DataFrame shape: {train_df.shape}")

Seed set to 42


Train DataFrame shape: (99999, 2)


In [5]:
# Prefix each path from the CSV with DATA_DIR. The column is overwritten in place, so
# re-running this cell joins DATA_DIR again (a no-op while DATA_DIR is '').
train_df['path'] = train_df['path'].apply(lambda x: os.path.join(DATA_DIR, x))

In [6]:
# Convert the single integer label into a two-element multi-label vector.
# Ids 1 and 4 -> [0, 1]; ids 2 and 5 -> [1, 0]; id 3 -> [1, 1]; anything else -> [0, 0].
LABEL_TO_MULTI_HOT = {1: [0, 1], 4: [0, 1], 2: [1, 0], 5: [1, 0], 3: [1, 1]}


def to_multi_hot(label):
    """Map one raw label id to its multi-label vector.

    Values that are already sequences are returned unchanged (as a list), which
    keeps this cell idempotent: previously a second run compared the already
    converted lists against the integer ids and reset every label to [0, 0].
    """
    if isinstance(label, (list, tuple, np.ndarray)):
        return list(label)
    # Copy the template so rows never share one mutable list.
    return list(LABEL_TO_MULTI_HOT.get(label, [0, 0]))


train_df['label'] = train_df['label'].apply(to_multi_hot)

In [7]:
# Decode every audio file into memory; unreadable files are reported and skipped.
train_audios, valid_indices = getAudios(train_df)
print(f"Number of valid train audios: {len(train_audios)}")
# getAudios returns index *labels* (taken from DataFrame.iterrows), so rows must be selected
# by label with .loc; positional .iloc is only equivalent for a default RangeIndex.
train_df = train_df.loc[valid_indices].reset_index(drop=True)
assert len(train_df) == len(train_audios), "train_df rows and loaded audios are out of sync"
train_labels = np.array(train_df['label'].tolist())


  0%|          | 0/99999 [00:00<?, ?it/s]

  audio, _ = librosa.load(row['path'], sr=SAMPLING_RATE)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


File not found: ./audio_final/3_26_488.ogg. Skipping.
File not found: ./audio_final/3_26_494.ogg. Skipping.
File not found: ./train/ATF1P.ogg. Skipping.
File not found: ./audio_final/3_26_476.ogg. Skipping.
File not found: ./audio_final/3_26_477.ogg. Skipping.
File not found: ./audio_final/3_26_474.ogg. Skipping.
File not found: ./audio_final/3_26_497.ogg. Skipping.
File not found: ./audio_final/3_26_471.ogg. Skipping.
File not found: ./audio_final/3_26_496.ogg. Skipping.
File not found: ./audio_final/3_26_490.ogg. Skipping.
File not found: ./audio_final/3_26_485.ogg. Skipping.
File not found: ./audio_final/3_26_489.ogg. Skipping.
File not found: ./audio_final/3_26_473.ogg. Skipping.
File not found: ./audio_final/3_26_484.ogg. Skipping.
File not found: ./audio_final/3_26_492.ogg. Skipping.
File not found: ./audio_final/3_26_486.ogg. Skipping.
File not found: ./audio_final/3_26_487.ogg. Skipping.
File not found: ./audio_final/3_26_481.ogg. Skipping.
File not found: ./audio_final/3_26_49

In [None]:
# First fold to train in this run; folds with a smaller index are skipped
# (presumably because they were trained in an earlier session - confirm before changing).
START_FOLD = 6

skf = StratifiedKFold(n_splits=N_FOLD, shuffle=True, random_state=SEED)
# NOTE(review): argmax over the two label columns yields stratum 0 for [0, 0], [1, 0] and [1, 1]
# alike, so only [0, 1] is stratified separately. Left unchanged here so that fold membership
# stays identical to the folds skipped above.
for fold_idx, (train_indices, val_indices) in enumerate(
        skf.split(train_labels, train_labels.argmax(axis=1))):

    if fold_idx < START_FOLD:
        continue

    print(
        f"Fold {fold_idx}: Train indices length: {len(train_indices)}, Validation indices length: {len(val_indices)}")
    train_fold_audios = [train_audios[train_index] for train_index in train_indices]
    val_fold_audios = [train_audios[val_index] for val_index in val_indices]

    train_fold_labels = train_labels[train_indices]
    val_fold_labels = train_labels[val_indices]
    train_fold_ds = MyDataset(train_fold_audios, audio_feature_extractor, train_fold_labels)
    val_fold_ds = MyDataset(val_fold_audios, audio_feature_extractor, val_fold_labels)
    # Reshuffle the training samples every epoch; without shuffle=True each epoch
    # replays exactly the same batch sequence.
    train_fold_dl = DataLoader(train_fold_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    val_fold_dl = DataLoader(val_fold_ds, batch_size=BATCH_SIZE * 2, collate_fn=collate_fn)

    # Keep only the checkpoint with the lowest val_combined_score (lower is better).
    checkpoint_acc_callback = ModelCheckpoint(
        monitor='val_combined_score',
        dirpath=MODEL_DIR,
        filename=f'fold_{fold_idx}' + '_{epoch:02d}-{val_combined_score:.4f}',
        save_top_k=1,
        mode='min'
    )

    my_lit_model = MyLitModel(
        audio_model_name=AUDIO_MODEL_NAME,
        num_labels=NUM_LABELS,
        n_layers=1, projector=True, classifier=True, dropout=0.07, lr_decay=0.8
    )

    trainer = pl.Trainer(
        accelerator='cuda',
        max_epochs=3,
        precision='16-mixed',
        val_check_interval=0.1,
        callbacks=[checkpoint_acc_callback],
        accumulate_grad_batches=2
        # batch_size * accumulate_grad_batches is the effective batch size
        # (VRAM usage is still determined by batch_size alone).
    )

    print(f"Starting training for fold {fold_idx}...")
    trainer.fit(my_lit_model, train_fold_dl, val_fold_dl)
    print(f"Training completed for fold {fold_idx}.")

    # The trainer keeps its own reference to the model, so deleting only the model
    # would leave it alive while the next fold's model is being created.
    del trainer, my_lit_model, train_fold_dl, val_fold_dl
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

Fold 6: Train indices length: 94971, Validation indices length: 4999


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
C:\Users\shsmc\Downloads\wavemotion\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch

Starting training for fold 6...



  | Name        | Type                            | Params | Mode
-----------------------------------------------------------------------
0 | audio_model | HubertForSequenceClassification | 94.6 M | eval
-----------------------------------------------------------------------
94.6 M    Trainable params
0         Non-trainable params
94.6 M    Total params
378.276   Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

C:\Users\shsmc\Downloads\wavemotion\venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
  attn_output = torch.nn.functional.scaled_dot_product_attention(
C:\Users\shsmc\Downloads\wavemotion\venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …