In [1]:
from IPython.display import clear_output

!pip install pytorch_lightning transformers

import pytorch_lightning as pl
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
from scipy.stats import spearmanr
import torchmetrics
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.metrics import recall_score
from sklearn.preprocessing import LabelEncoder
import librosa
import pickle
import os

tqdm.pandas()

clear_output()

In [2]:
import os

class Dataset_Generation(torch.utils.data.Dataset):
    @classmethod
    def get_le(cls,df,target='상황'):
        df[target] = df[target].apply(lambda x: x.lower().strip())
        le = LabelEncoder()
        le.fit(df[target])
        return le

    def get_labels(self):
        return self.labels

    def __init__(self, df, data_path, target='상황', max_sec=10, sr=16000,
                 le = None, truncate=True, test=False):

        self.processor = Wav2Vec2FeatureExtractor.from_pretrained('kresnik/wav2vec2-large-xlsr-korean')

        self.test = test
        self.truncate = truncate
        self.files = df['wav_id'].apply(lambda x: os.path.join(data_path, f'{x}.wav')).copy()

        df[target] = df[target].apply(lambda x: x.lower().strip())
        if le is None:
            self.le = LabelEncoder()
            self.labels = self.le.fit_transform(df[target].values)
        else:
            self.le = le
            self.labels = self.le.transform(df[target].values)
        self.maxlen = max_sec * sr

        self.audio_files = []
        for file_path in self.files:
            if os.path.exists(file_path):
                audio = librosa.load(file_path,sr=sr)[0]
                audio_processed = self.processor(audio, sampling_rate=sr, return_tensors="pt", padding=True).input_values.squeeze(0)
                self.audio_files.append(audio_processed)
            else:
                print(f"{file_path} 파일은 존재하지 않습니다.")
        self.files['audio'] = self.audio_files


    def __len__(self):
        return len(self.files['audio'])

    def __getitem__(self, idx):
        audio = self.files['audio'][idx]
        if not self.truncate:
            return audio, self.labels[idx]
        if (audio.shape[0] > self.maxlen):
            start = np.random.randint(audio.shape[0] - self.maxlen)
            audio = audio[start:start+self.maxlen]
        else:
            audio = torch.cat((audio, torch.zeros(self.maxlen - audio.shape[0])))
        if not self.test:
            return audio, self.labels[idx]
        else:
            return audio

In [3]:
class AudioModel(pl.LightningModule):
    def __init__(self,num_classes = 7, ckpt="kresnik/wav2vec2-large-xlsr-korean"):
        super().__init__()
        model = Wav2Vec2Model.from_pretrained(ckpt)
        model.encoder.layers = model.encoder.layers[:3]
        self.model = model
        self.model.feature_extractor._freeze_parameters()
        self.layer_weights = torch.nn.Parameter(torch.ones(4))
        self.linear = torch.nn.Linear(1024*2, num_classes)
        self.dropout = torch.nn.Dropout(0.2)
        self.preds = []
        self.labels = []

    def compute_features(self, x):
        x = self.model(input_values=x, output_hidden_states=True).hidden_states
        x = torch.stack(x,dim=1)
        weights = torch.nn.functional.softmax(self.layer_weights, dim=-1)
        mean_x = x.mean(dim = 2)
        std_x = x.std(dim = 2)
        x = torch.cat((mean_x, std_x), dim=-1)
        x = (x * weights.view(-1,4,1)).sum(dim=1)
        return x

    def forward(self, x):
        x = self.compute_features(x)
        x = self.dropout(x)
        x = self.linear(x)
        x = torch.softmax(x,dim=-1)
        return x

    def training_step(self, batch,batch_idx):
        x,y = batch
        logits = self.forward(x)
        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(logits,y)
        self.log('train_loss', loss,sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x,y = batch
        logits = self.forward(x)
        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(logits,y)
        self.log('val_loss', loss,sync_dist=True)
        logits = torch.sigmoid(logits)
        preds = logits.argmax(dim=-1).detach().cpu().numpy()
        self.preds.append(preds)
        self.labels.append(y.detach().cpu().numpy())
        return loss

    def on_validation_epoch_end(self):
        self.preds = np.concatenate(self.preds)
        self.labels = np.concatenate(self.labels)
        self.log('val_recall', recall_score(self.labels,self.preds,average='macro'), sync_dist=True)
        self.preds = []
        self.labels = []

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=5e-5)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': torch.optim.lr_scheduler.LinearLR(optimizer, 0.01, 1,total_iters=100),
                'interval': 'step',
            },
            'monitor': 'val_recall',
            'interval': 'epoch'
        }

In [9]:
from sklearn.model_selection import train_test_split

train_dataset, test_val = train_test_split(csv, train_size = 0.80, stratify=csv['상황'], random_state=77)
test_dataset, val_dataset = train_test_split(test_val, train_size = 0.50, stratify=test_val['상황'], random_state=77)

In [11]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=3,
    mode='min',
    verbose=True
)

checkpoint = ModelCheckpoint(
    monitor="val_loss", mode="min", save_weights_only=True
)

In [12]:
import os
from IPython.display import clear_output

model = AudioModel()
train_dataset = Dataset_Generation(train_dataset, data_path, max_sec = 10)
le = train_dataset.le
val_dataset = Dataset_Generation(val_dataset, data_path, max_sec = 10,le = le)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, num_workers=8, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=8, shuffle=False)

trainer = Trainer(
    accelerator="gpu",
    devices="auto",
    precision=16,
    log_every_n_steps = 100,
    max_epochs=20,
    gradient_clip_val=0,
    accumulate_grad_batches=1,
    val_check_interval=1.0,
    callbacks=[checkpoint, LearningRateMonitor("step"), early_stopping],
)

clear_output()

In [13]:
trainer.fit(model, train_loader, val_loader)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/optimizer.py:375: Found unsupported keys in the optimizer configuration: {'interval'}
INFO:pytorch_lightning.callbacks.model_summary:
  | Name         | Type          | Params
-----------------------------------------------
0 | model        | Wav2Vec2Model | 50.9 M
1 | linear       | Linear        | 14.3 K
2 | dropout      | Dropout       | 0     
  | other params | n/a           | 4     
-----------------------------------------------
46.7 M    Trainable params
4.2 M     Non-trainable params
50.9 M    Total params
203.729   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved. New best score: 1.872


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.060 >= min_delta = 0.0. New best score: 1.812
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [18]:
from google.colab import drive
drive.mount('/content/drive')

import shutil

checkpoint_path = trainer.checkpoint_callback.best_model_path
drive_path = "/content/drive/MyDrive/의현/speech_best.ckpt"

shutil.copy(checkpoint_path, drive_path)
drive.flush_and_unmount()

Mounted at /content/drive
