<a href="https://colab.research.google.com/github/Taewon-Park/Dacon/blob/main/%EA%B0%90%EC%A0%95%EC%9D%B8%EC%8B%9D(Audio).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Requirements
import os
import numpy as np
import pandas as pd
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from tqdm.auto import tqdm
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from transformers.optimization import AdamW, get_constant_schedule_with_warmup
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, StochasticWeightAveraging
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer, AutoFeatureExtractor, HubertForSequenceClassification, AutoConfig

In [None]:
def accuracy(preds, labels):
  """Return the fraction of predictions that equal their labels.

  Args:
    preds: Tensor of predicted class ids.
    labels: Tensor of reference class ids, compared element-wise to ``preds``.

  Returns:
    A scalar float tensor: the mean of the element-wise equality mask.
  """
  matches = preds == labels
  # Cast the boolean mask to float so that the mean is defined.
  return matches.float().mean()


def getAudios(df):
  """Load the waveform of every row in ``df`` into memory.

  Args:
    df: DataFrame with a ``'path'`` column holding audio file paths.

  Returns:
    A list with one loaded waveform per row, in row order. Each file is
    loaded through ``librosa.load`` with ``sr=SAMPLING_RATE``; the sample
    rate that ``librosa.load`` returns alongside the waveform is discarded.
  """
  rows = tqdm(df.iterrows(), total=len(df))
  # librosa.load returns (waveform, sample_rate); only the waveform is kept.
  return [librosa.load(row['path'], sr=SAMPLING_RATE)[0] for _, row in rows]


class MyDataset(Dataset):
  """Map-style dataset that pairs raw audio clips with integer labels.

  Feature extraction is done lazily: the extractor is called on a single
  clip each time an item is requested.
  """

  def __init__(self, audio, audio_feature_extractor, label = None, sampling_rate = None):
    """Store the clips, their labels and the feature extractor.

    Args:
      audio: Indexable collection of raw audio clips.
      audio_feature_extractor: Callable invoked as
        ``extractor(raw_speech=..., return_tensors='np', sampling_rate=...)``.
        Its result must be indexable by ``'input_values'`` and
        ``'attention_mask'``.
      label: Optional labels, one per clip. When omitted (e.g. for inference)
        every clip gets the placeholder label ``0``.
      sampling_rate: Optional sampling rate passed to the extractor. When
        omitted, the module-level ``SAMPLING_RATE`` is read at item access.
    """
    if label is None:
      # Placeholder labels keep the item schema identical for unlabeled data.
      label = [0] * len(audio)
    self.label = np.array(label).astype(np.int64)
    self.audio = audio
    self.audio_feature_extractor = audio_feature_extractor
    self.sampling_rate = sampling_rate

  def __len__(self):
    """Return the number of labels (one per clip)."""
    return len(self.label)

  def __getitem__(self, idx):
    """Return the label and extracted features of the clip at ``idx``.

    Returns:
      A dict with keys ``'label'``, ``'audio_values'`` and
      ``'audio_attn_mask'``. The two feature entries are the first (and only)
      row of the extractor's ``'input_values'`` and ``'attention_mask'``.
    """
    sampling_rate = self.sampling_rate
    if sampling_rate is None:
      # Resolved lazily so the global only has to exist once items are read.
      sampling_rate = SAMPLING_RATE

    # Use the extractor stored on this instance, not a same-named global.
    audio_feature = self.audio_feature_extractor(
        raw_speech = self.audio[idx],
        return_tensors = 'np',
        sampling_rate = sampling_rate,
    )

    return {
        'label' : self.label[idx],
        # The extractor returns a batch of size one; drop the batch axis.
        'audio_values' : audio_feature['input_values'][0],
        'audio_attn_mask' : audio_feature['attention_mask'][0],
    }


def collate_fn(samples):
  """Collate dataset items of varying length into one zero-padded batch.

  Args:
    samples: Sequence of dicts with keys ``'label'``, ``'audio_values'`` and
      ``'audio_attn_mask'``; the two feature entries are 1-D sequences whose
      length may differ between samples.

  Returns:
    A dict with the same three keys. ``'label'`` is a 1-D tensor with one
    entry per sample; ``'audio_values'`` and ``'audio_attn_mask'`` are
    batch-first tensors right-padded with zeros to the longest sample, so
    padded positions carry a mask value of 0.
  """
  batch_labels = torch.tensor([sample['label'] for sample in samples])
  batch_audio_values = pad_sequence(
      [torch.tensor(sample['audio_values']) for sample in samples],
      batch_first = True,
  )
  batch_audio_attn_masks = pad_sequence(
      [torch.tensor(sample['audio_attn_mask']) for sample in samples],
      batch_first = True,
  )

  return {
      'label' : batch_labels,
      'audio_values' : batch_audio_values,
      'audio_attn_mask' : batch_audio_attn_masks,
  }


class MyLitModel(pl.LightningModule):
  def __init__(self, audio_model_name, num_labels, n_layers=1, projector=True, classifier=True, dropout=0.07, lr_decay=1):
    """Load the pretrained audio classifier with a uniform dropout rate.

    Args:
      audio_model_name: Name or path passed to ``from_pretrained`` for both
        the config and the ``HubertForSequenceClassification`` weights.
      num_labels: Accepted but not read by this constructor.
      n_layers, projector, classifier: Forwarded unchanged to
        ``self._do_reinit`` (defined outside this method).
      dropout: Value written to every dropout field of the config below.
      lr_decay: Stored on the instance as ``self.lr_decay``.
    """
    super().__init__()
    self.config = AutoConfig.from_pretrained(audio_model_name)

    # Override all dropout knobs of the config with the same rate before the
    # model is instantiated from it.
    dropout_fields = (
        'activation_dropout',
        'attention_dropout',
        'final_dropout',
        'hidden_dropout',
        'hidden_dropout_prob',
    )
    for field_name in dropout_fields:
      setattr(self.config, field_name, dropout)

    self.audio_model = HubertForSequenceClassification.from_pretrained(
        audio_model_name,
        config=self.config,
    )
    self.lr_decay = lr_decay
    self._do_reinit(n_layers, projector, classifier)

  def forward(self, audio_values, audio_attn_mask):
    """Run the audio model and fold its logits into six output scores.

    Each output column is the sum of two columns of the underlying model's
    logits, ``i`` and ``i + 7``, for ``i`` taken in the order 0, 2, 5, 1, 4, 3.
    Columns 6 and 13 (if present) are never read.
    NOTE(review): presumably the pretrained head has 14 classes arranged as
    two groups of 7 — confirm against the checkpoint's label map.

    Args:
      audio_values: Passed to the model as ``input_values``.
      audio_attn_mask: Passed to the model as ``attention_mask``.

    Returns:
      Tensor whose last dimension has size 6.
    """
    raw_logits = self.audio_model(
        input_values = audio_values,
        attention_mask = audio_attn_mask,
    ).logits

    # Output position k is built from source columns order[k] and order[k] + 7.
    column_order = (0, 2, 5, 1, 4, 3)
    merged = [raw_logits[:, col] + raw_logits[:, col + 7] for col in column_order]
    return torch.stack(merged, dim = -1)

  def training_step(self, batch, batch_idx):
    """Compute the cross-entropy loss for one batch and log loss and accuracy.

    Args:
      batch: Dict with keys ``'audio_values'``, ``'audio_attn_mask'`` and
        ``'label'``.
      batch_idx: Index of the batch; unused here.

    Returns:
      The cross-entropy loss between the model's logits and the labels.
    """
    targets = batch['label']
    logits = self(batch['audio_values'], batch['audio_attn_mask'])

    loss = F.cross_entropy(logits, targets)
    acc = accuracy(torch.argmax(logits, dim=1), targets)

    # Both metrics are reported per step and aggregated per epoch.
    for metric_name, metric_value in (('train_loss', loss), ('train_acc', acc)):
      self.log(metric_name, metric_value, on_step = True, on_epoch = True, prog_bar = True, logger = True)

    return loss

def validation_step(self, batch, batch_idx):
  audio_values = batch['audio_values']
  audio_attn_mask = batch['audio_attn_mask']
  labels = batch['label']

  logits = self(audio_values, audio_attn_mask)
  loss = nn.CrossEntropyLoss()(logits, labels)