# Imports

In [1]:
import numpy as np
import pandas as pd
import os
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import random

from transformers import BertConfig, BertModel, Trainer, TrainingArguments

gc.collect()

2025-05-20 10:41:16.726479: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747737676.749083     194 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747737676.756519     194 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


90

# Loading data

In [2]:
train_df = pd.read_parquet('/kaggle/input/small-music-genre-dataset/dataset/train.parquet')
val_df = pd.read_parquet('/kaggle/input/small-music-genre-dataset/dataset/val.parquet')
test_df = pd.read_parquet('/kaggle/input/small-music-genre-dataset/dataset/test.parquet')
train_df.head()

Unnamed: 0,main_genre,spec_path
0,classical,./dataset/spectrograms/332.npy
1,ska,./dataset/spectrograms/2882.npy
2,country,./dataset/spectrograms/421.npy
3,afrobeat,./dataset/spectrograms/79.npy
4,psychedelic,./dataset/spectrograms/41.npy


In [3]:
def fix_path(x):
    return x.replace('.', '/kaggle/input/small-music-genre-dataset', 1)

train_df['spec_path'] = train_df['spec_path'].apply(fix_path)
val_df['spec_path'] = val_df['spec_path'].apply(fix_path)
test_df['spec_path'] = test_df['spec_path'].apply(fix_path)

# Dataset and Model classes

### Dataset def

In [4]:
class SpectrogramDataset(Dataset):
    def __init__(self, 
                dataframe, 
                genre2label,
                transform=None,
                # Параметры аугментаций
                time_mask_param=15,
                freq_mask_param=8,
                noise_level=0.005,
                mixup_alpha=0.4):
        
        self.df = dataframe.reset_index(drop=True)
        self.genre2label = genre2label
        self.transform = transform
        
        # Параметры аугментаций
        self.time_mask_param = time_mask_param
        self.freq_mask_param = freq_mask_param
        self.noise_level = noise_level
        self.mixup_alpha = mixup_alpha

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Основная загрузка данных
        row = self.df.iloc[idx]
        spec = np.load(row['spec_path']).astype(np.float32)
        attention_mask = torch.ones(spec.shape[0], dtype=torch.long)
        label = self.genre2label[row['main_genre']]

        # Аугментации (применяются с вероятностью 70%)
        if self.transform and random.random() < 0.7:
            # 1. Временное маскирование
            spec = self.time_mask(spec)
            
            # 2. Частотное маскирование
            spec = self.frequency_mask(spec)
            
            # 3. Добавление шума
            spec = self.add_gaussian_noise(spec)
            
            # 4. Случайный временной сдвиг
            spec, attention_mask = self.random_time_shift(spec, attention_mask)
            
            # 5. Mixup аугментация
            if random.random() < 0.3:  # 30% вероятность
                spec, label = self.mixup(spec, label, idx)

        return {
            "spectrogram": torch.tensor(spec, dtype=torch.float32),
            "attention_mask": attention_mask,
            "labels": torch.tensor(label, dtype=torch.long)
        }

    def time_mask(self, spec):
        max_mask_length = int(spec.shape[0] * 0.2)  # Макс. 20% длины
        mask_length = random.randint(1, min(self.time_mask_param, max_mask_length))
        mask_start = random.randint(0, spec.shape[0] - mask_length)
        spec[mask_start:mask_start+mask_length, :] = 0
        return spec

    def frequency_mask(self, spec):
        mask_length = random.randint(1, self.freq_mask_param)
        mask_start = random.randint(0, spec.shape[1] - mask_length)
        spec[:, mask_start:mask_start+mask_length] = 0
        return spec

    def add_gaussian_noise(self, spec):
        noise = np.random.normal(0, self.noise_level, spec.shape)
        return np.clip(spec + noise, 0, 1)

    def random_time_shift(self, spec, mask):
        shift = random.randint(-int(spec.shape[0]*0.1), int(spec.shape[0]*0.1))
        if shift > 0:
            spec = np.pad(spec, ((shift,0), (0,0)), mode='constant')[:-shift]
            mask = torch.cat([mask[shift:], torch.zeros(shift, dtype=torch.long)])
        elif shift < 0:
            spec = np.pad(spec, ((0,-shift), (0,0)), mode='constant')[-shift:]
            mask = torch.cat([torch.zeros(-shift, dtype=torch.long), mask[:shift]])
        return spec, mask

    def mixup(self, spec1, label1, idx1):
        idx2 = random.randint(0, len(self)-1)
        row2 = self.df.iloc[idx2]
        spec2 = np.load(row2['spec_path']).astype(np.float32)
        label2 = self.genre2label[row2['main_genre']]
        
        lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
        mixed_spec = lam * spec1 + (1 - lam) * spec2
        mixed_label = lam * label1 + (1 - lam) * label2
        
        return mixed_spec, mixed_label

    def random_scaling(self, spec):
        scale_factor = random.choice([
            lambda x: x * random.uniform(0.8, 1.2),      # Амплитудное масштабирование
            lambda x: np.log1p(x * random.uniform(0.5, 2)),  # Нелинейное преобразование
            lambda x: x ** random.uniform(0.5, 1.5)      # Степенное преобразование
        ])
        return np.clip(scale_factor(spec), 0, 1)


### Model def

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel

class SpectrogramBertClassifier(nn.Module):
    def __init__(self, input_feature_dim, num_labels, mlm_prob, bert_config):
        super().__init__()
        self.mlm_prob = mlm_prob
        self.num_labels = num_labels
        self.bert_config = bert_config

        # Parameters for masking patches (tweak these if needed)
        self.min_patch_len = 5
        self.max_patch_len = 25
        
        # Projection of spectrograms
        self.projection = nn.Linear(input_feature_dim, bert_config.hidden_size)
        self.norm = nn.LayerNorm(bert_config.hidden_size)
        
        # BERT components
        self.bert = BertModel(bert_config)
        # Инициализируем mask token: можно добавить небольшое отклонение
        self.mask_token = nn.Parameter(torch.zeros(bert_config.hidden_size))
        
        # CLS token как обучаемый параметр с контролируемой инициализацией
        self.cls_token = nn.Parameter(torch.randn(1, 1, bert_config.hidden_size))
        nn.init.normal_(self.cls_token, std=bert_config.initializer_range)
        
        # Упрощённая классификационная голова
        self.classifier = nn.Linear(bert_config.hidden_size, num_labels)
        self.classifier.weight.data.normal_(mean=0.0, std=bert_config.initializer_range)
        self.classifier.bias.data.zero_()

        # MLM head (если требуется)
        self.mlm_norm = nn.LayerNorm(bert_config.hidden_size)
        self.mlm_head = nn.Linear(bert_config.hidden_size, input_feature_dim)
        self.loss_fct = nn.CrossEntropyLoss()

    def _create_patch_mask(self, batch_size, seq_len, attention_mask):
        """Улучшенное маскирование с защитой границ"""
        device = attention_mask.device
        mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device)
        
        for b in range(batch_size):
            valid_indices = torch.where(attention_mask[b])[0]
            if len(valid_indices) < self.min_patch_len:
                continue
                
            # Динамическое количество патчей
            max_possible = len(valid_indices) // self.min_patch_len
            num_patches = min(max(1, int(self.mlm_prob * seq_len / self.min_patch_len)), max_possible)
            
            for _ in range(num_patches):
                patch_len = torch.randint(self.min_patch_len, self.max_patch_len + 1, (1,)).item()
                # Если оставшихся валидных индексов меньше, чем длина патча – пропускаем итерацию
                if len(valid_indices) < patch_len:
                    continue
                    
                start = torch.randint(0, len(valid_indices) - patch_len + 1, (1,)).item()
                start_idx = valid_indices[start].item()
                end_idx = min(start_idx + patch_len, seq_len)
                
                # Защита первых/последних 10% кадров
                if start_idx < seq_len * 0.1 or end_idx > seq_len * 0.9:
                    continue
                    
                mask[b, start_idx:end_idx] = True
                
        return mask

    def forward(self, spectrogram, attention_mask, labels=None):
        B, T, _ = spectrogram.size()
        
        # 1. Проекция и нормализация
        projected = self.norm(self.projection(spectrogram))
        
        # Если требуются позиционные энкодинги, их можно добавить здесь
        
        # 2. Добавление CLS токена
        cls_tokens = self.cls_token.expand(B, -1, -1)
        # Собираем эмбеддинги: CLS токен + проекции
        inputs_embeds = torch.cat([cls_tokens, projected], dim=1)
        
        # 3. Создание маски для патчей и применение mask token
        patch_mask = self._create_patch_mask(B, T, attention_mask)
        # Получаем эмбеддинги без CLS токена, делаем клон для безопасного in-place изменения
        embeds_without_cls = inputs_embeds[:, 1:].clone()
        # Применяем маску: там, где patch_mask==True, заменяем на mask_token
        embeds_without_cls[patch_mask] = self.mask_token
        # Собираем итоговый тензор с обновленными эмбеддингами
        inputs_embeds = torch.cat([cls_tokens, embeds_without_cls], dim=1)
        
        # 4. Расширяем маску внимания для CLS токена
        extended_mask = torch.cat([
            torch.ones(B, 1, device=attention_mask.device),
            attention_mask
        ], dim=1)
        
        # 5. Пропуск через BERT
        outputs = self.bert(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_mask
        )
        
        
        hidden_states = outputs.last_hidden_state[:, 1:]
        mlm_predictions = self.mlm_head(self.mlm_norm(hidden_states))

        # for r^2 metrics
        mlm_targets = spectrogram[patch_mask]
        mlm_preds = mlm_predictions[patch_mask]

        cls_output = outputs.last_hidden_state[:, 0]
        classification_logits = self.classifier(cls_output)
        
        # 7. Вычисление потерь
        loss_dict = {}
        if labels is not None:

            # Расчет accuracy
            predicted_labels = torch.argmax(classification_logits, dim=-1)
            accuracy = (predicted_labels == labels).float().mean().cpu().detach().numpy()
            
            classification_loss = self.loss_fct(classification_logits, labels)
            mlm_loss = F.mse_loss(mlm_predictions[patch_mask], spectrogram[patch_mask])

            # R² Calculation
            ss_res = torch.sum((mlm_targets - mlm_preds)**2)
            ss_tot = torch.sum((mlm_targets - torch.mean(mlm_targets))**2)
            r_squared = 1.0 - (ss_res / (ss_tot + 1e-8)).cpu().detach().numpy()  # Добавлен epsilon для стабильности
            
            
            loss = 0.5 * classification_loss + 0.5 * mlm_loss 
            
            loss_dict = {
                "loss": loss,
                "classification_loss": classification_loss,
                "mlm_loss": mlm_loss,
                "accuracy": accuracy,
                "r_squared": r_squared
            }
    
        return {
            **loss_dict,
            "cls_embedding": cls_output.detach(),
            "classification_logits": classification_logits,
            "mlm_predictions": mlm_predictions,
            "patch_mask": patch_mask
        }


# Training

### Data prep

In [6]:
# mapping from genre to integer label for training
unique_genres = pd.concat([train_df['main_genre'], val_df['main_genre'], test_df['main_genre']]).unique()
genre2label = {genre: idx for idx, genre in enumerate(unique_genres)}
num_labels = len(genre2label)


train_dataset = SpectrogramDataset(
    dataframe=train_df,
    genre2label=genre2label,
    transform=True,
    time_mask_param=2,    # Макс. длина временной маски
    freq_mask_param=2,    # Макс. ширина частотной маски
    noise_level=0.01,      
    mixup_alpha=0.1       
)
val_dataset = SpectrogramDataset(val_df, genre2label)
test_dataset = SpectrogramDataset(test_df, genre2label)

In [7]:
# input feature dimension (F).
example_spec = np.load(train_df.loc[4, 'spec_path']).astype(np.float32)
input_feature_dim = example_spec.shape[1]

print(input_feature_dim)

1292


### Device choice

In [8]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    

### Model config and init

In [9]:
from transformers import BertConfig

custom_config = BertConfig(
    vocab_size=1,  
    hidden_size=768,
    num_hidden_layers=8,
    num_attention_heads=8,
    intermediate_size=2048,
    max_position_embeddings=2048,
)


model = SpectrogramBertClassifier(
    input_feature_dim=input_feature_dim,
    num_labels=num_labels,
    mlm_prob=0.15,
    bert_config=custom_config
)


model.to(device)
print("Using device:", device)


def data_collator(features):
    batch = {}
    # padding
    batch["spectrogram"] = torch.nn.utils.rnn.pad_sequence(
        [f["spectrogram"] for f in features], batch_first=True, padding_value=0.0
    )
    batch["attention_mask"] = torch.nn.utils.rnn.pad_sequence(
        [f["attention_mask"] for f in features], batch_first=True, padding_value=0
    )
    batch["labels"] = torch.stack([f["labels"] for f in features])
    return batch

Using device: cuda


### WANDB init

In [10]:
import wandb
wandb.login(key='f9e64e1618bc54c2c1fae6cfae56e2905541c9cb')

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mostgot[0m ([33mkomandakomanda[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [11]:
import os
os.environ["WANDB_PROJECT"]="audio-bert-final"

### HF Trainer setup and training

In [12]:
import numpy as np
import math
import torch
from transformers import Trainer, TrainingArguments

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = outputs.get("loss")
        return (loss, outputs) if return_outputs else loss

    def training_step(self, model, inputs, p):
        model.train()
        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
        loss.backward()

        # Логирование метрик для классификации и MLM
        logs = {
            "train_loss": loss.item(),
            "train_mlm_loss": outputs.get("mlm_loss", 0).item(),
            "train_classification_loss": outputs.get("classification_loss", 0).item(),
            "train_accuracy": outputs.get("accuracy", 0),
            "train_r_squared": outputs.get("r_squared", 0)
        }
        self.log(logs)
        return loss.detach()

    def evaluation_step(self, model, inputs):
        with torch.no_grad():
            outputs = model(**inputs)
        
        return {
            "loss": outputs.get("loss").item() if outputs.get("loss") is not None else None,
            "mlm_loss": outputs.get("mlm_loss", 0).item(),
            "classification_loss": outputs.get("classification_loss", 0).item(),
            "accuracy": outputs.get("accuracy", 0),
            "r_squared": outputs.get("r_squared", 0)
        }

    def evaluate(self, eval_dataset=None, metric_key_prefix="eval", **kwargs):
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        
        metric_keys = ["loss", "mlm_loss", "classification_loss", "accuracy", "r_squared"]
        all_metrics = {k: [] for k in metric_keys}

        for inputs in eval_dataloader:
            batch_metrics = self.evaluation_step(self.model, inputs)
            for k in metric_keys:
                if batch_metrics[k] is not None:
                    all_metrics[k].append(batch_metrics[k])

        # Агрегация метрик, исключая NaN
        eval_metrics = {}
        for k in metric_keys:
            values = [v for v in all_metrics[k] if not math.isnan(v)]
            if values:
                eval_metrics[f"{metric_key_prefix}_{k}"] = np.mean(values)
        
        self.log(eval_metrics)
        return eval_metrics


from transformers import TrainerCallback

class DistributionMonitor(TrainerCallback):
    def on_evaluate(self, args, state, control, **kwargs):
        model.eval()
        with torch.no_grad():
            sample = next(iter(val_dataloader))
            outputs = model(**sample)
            
            # CLS vs среднее проекций
            plt.figure(figsize=(12,5))
            plt.subplot(121)
            plt.hist(outputs["cls_embedding"].cpu().numpy().flatten(), bins=50, alpha=0.5, label='CLS')
            plt.hist(outputs["mlm_predictions"].cpu().numpy().flatten(), bins=50, alpha=0.5, label='MLM')
            plt.legend()
            
            # Градиентные распределения
            plt.subplot(122)
            grads = [p.grad.cpu().numpy().flatten() for p in model.parameters() if p.grad is not None]
            plt.hist(np.concatenate(grads), bins=100, log=True)
            plt.title("Gradient Distribution")
            plt.show()


training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=15,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    logging_steps=3,
    eval_steps=3,
    learning_rate=5e-4,
    weight_decay=0.01,
    eval_strategy='steps',
    report_to="wandb",
    run_name="audio-bert-classification-acc-r2-2",
)

# Создаем кастомный тренер
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
)

trainer.add_callback(DistributionMonitor())

# Запускаем обучение
trainer.train()


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Step,Training Loss,Validation Loss


TypeError: Object of type ndarray is not JSON serializable

# Evaluating work

In [None]:
train_hist_df = pd.DataFrame(trainer.state.log_history)

train_hist_df

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.lineplot(train_hist_df[['step', 'loss', 'eval_loss']].set_index('step'))

In [None]:
sns.lineplot(train_hist_df[['step', 'eval_cls_loss', 'eval_mlm_loss']].set_index('step'))

# CLS embedding analysis

In [None]:
! ls results/checkpoint-770

In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder

### 3 genres

In [None]:
genres = 3
loaders = []

for i in range(genres):
    tmp_df = test_df[test_df['main_genre'] == test_df['main_genre'].unique()[i]]
    dataset = SpectrogramDataset(tmp_df, genre2label)
    loader = DataLoader(dataset, batch_size=8, collate_fn=data_collator)
    loaders.append(loader)


all_cls, all_labels = [], []

for loader in loaders:
    # Сбор эмбеддингов и меток
    
    for batch in tqdm(loader):
        
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch['labels']
        batch.pop('labels')
    
        outputs = model(**batch)
        all_cls.append(outputs["cls_embedding"].cpu())
        all_labels.append(labels.cpu())

all_cls = torch.cat(all_cls)
all_labels = torch.cat(all_labels)
    
# Визуализация через PCA
    
pca = PCA(n_components=2)
cls_2d = pca.fit_transform(all_cls.detach().numpy())
    
plt.scatter(cls_2d[:, 0], cls_2d[:, 1], c=all_labels.numpy(), cmap='viridis')
plt.title("PCA [CLS] Embeddings")
plt.show()

In [None]:
! pip install umap-learn -q


In [None]:
import umap
import matplotlib.pyplot as plt
import numpy as np

# Преобразование тензора в numpy array (предполагаем, что all_cls и all_labels уже определены)
cls_embeddings = all_cls.detach().numpy()
labels = all_labels.numpy()

# Инициализация и применение UMAP с оптимальными параметрами
umap_reducer = umap.UMAP(
    n_components=2,
    n_neighbors=15,          # Оптимально для баланса локальных/глобальных структур
    min_dist=0.1,            # Плотность кластеров
    metric='cosine',         # Лучше всего для эмбеддингов
    random_state=42          # Репроизводимость
)
cls_2d_umap = umap_reducer.fit_transform(cls_embeddings)

# Настройка визуализации
plt.figure(figsize=(10, 8))
scatter = plt.scatter(
    cls_2d_umap[:, 0], 
    cls_2d_umap[:, 1],
    c=labels,
    cmap='viridis',
    s=25,                    # Размер точек
    alpha=0.7,               # Прозрачность
    edgecolor='none'         # Убираем границы точек
)

# Дополнительные элементы графика
plt.title('UMAP Projection of CLS Embeddings', pad=20, fontsize=14)
plt.xlabel('UMAP Dimension 1', labelpad=10)
plt.ylabel('UMAP Dimension 2', labelpad=10)
plt.grid(alpha=0.3)          # Сетка с прозрачностью

# Цветовая легенда
cbar = plt.colorbar(scatter, pad=0.01)
cbar.set_label('Class Labels', rotation=270, labelpad=20)

# Оптимизация расположения
plt.tight_layout()
plt.show()

In [None]:
X_train.shape