In [1]:
from IPython.display import clear_output

!pip install pytorch_lightning transformers

clear_output()

In [2]:
import torch
import pandas as pd
import pytorch_lightning as pl
import numpy as np
import torchmetrics

from google.colab import drive
drive.mount('/content/drive')

csv3_path = '/content/drive/MyDrive/의현/감정 분류를 위한 대화 음성 데이터셋/4차년도.csv'
wav_path = '/content/drive/MyDrive/의현/감정 분류를 위한 대화 음성 데이터셋/4차년도.zip'
!mkdir ./wav/
!cp -r "$csv3_path" ./
!cp -r "$wav_path" ./wav/

text_model_path = '/content/drive/MyDrive/의현/text_best.ckpt'
speech_model_path = '/content/drive/MyDrive/의현/speech_best.ckpt'
!mkdir ./trained/
!cp -r "$text_model_path" ./trained/
!cp -r "$speech_model_path" ./trained/

%cd /content/drive/MyDrive/의현

from speech_audio_classification import RequestsDataset, AudioModel
from sklearn.preprocessing import LabelEncoder
from speech_classification_bert import TranscriptData, Classifier
from sklearn.metrics import recall_score

%cd /content

drive.flush_and_unmount()

Mounted at /content/drive
/content/drive/MyDrive/의현
/content


In [3]:
class MultiModalData(torch.utils.data.Dataset):
    def __init__(self,df, le = None, target = '상황'):
        if le is None:
            self.le = LabelEncoder()
            self.le.fit(df[target])
        else:
            self.le = le
        self.audio_dataset = RequestsDataset(df, data_path='./file/',le=le,target=target)
        self.text_dataset = TranscriptData(df, le=le,target=target)

    def __len__(self):
        return len(self.audio_dataset)

    def __getitem__(self, idx):
        audio,labels = self.audio_dataset[idx]
        text, label_text = self.text_dataset[idx]
        assert (label_text == labels).all()
        return audio, text, labels

In [13]:
class LateFusion(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.audio_model = AudioModel.load_from_checkpoint('./trained/speech_best.ckpt',num_classes=5)
        self.text_model = Classifier.load_from_checkpoint('./trained/text_best.ckpt')
        self.proj = torch.nn.Linear(773,512)
        self.dropout = torch.nn.Dropout(0.2)
        self.classifier = torch.nn.Linear(512, 5)
        self.preds = np.array([])
        self.labels = np.array([])
        self.train_acc = torchmetrics.Accuracy(num_classes=5, average='macro', task="multiclass")
        self.valid_acc = torchmetrics.Accuracy(num_classes=5, average='macro', task="multiclass")

    def freeze(self):
        for param in self.audio_model.parameters():
            param.requires_grad = False
        for param in self.text_model.parameters():
            param.requires_grad = False

    def forward(self, audio_input, text_input):
        audio_emb = self.audio_model.forward(audio_input)
        text_emb = self.text_model.model(**text_input)
        text_emb = text_emb.last_hidden_state[:,0]
        x = torch.cat([audio_emb,text_emb], dim=1)
        x = self.dropout(x)
        x = self.proj(x)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.classifier(x)
        x = torch.softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx):
        audio_input, text_input, y = batch
        y_hat = self(audio_input, text_input)
        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(y_hat, y)
        self.log('train_acc', self.train_acc(y_hat.argmax(dim=1), y), prog_bar = True)
        self.log('train_loss', loss, prog_bar = True)
        return loss

    def test_step(self, batch, batch_idx):
        audio_input, text_input, y = batch
        y_hat = self(audio_input, text_input)
        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(y_hat, y)
        self.log('test_acc', self.valid_acc(y_hat.argmax(dim=1), y), prog_bar = True)
        self.log('test_loss', loss, prog_bar = True)
        self.preds = np.append(self.preds, y_hat.argmax(dim=1).cpu().numpy())
        self.labels = np.append(self.labels, y.cpu().numpy())
        return loss

    def validation_step(self, batch, batch_idx):
        audio_input, text_input, y = batch
        y_hat = self(audio_input, text_input)
        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(y_hat, y)
        self.log('val_acc', self.valid_acc(y_hat.argmax(dim=1), y), prog_bar = True)
        self.log('val_loss', loss, prog_bar = True)
        self.preds = np.append(self.preds, y_hat.argmax(dim=1).cpu().numpy())
        self.labels = np.append(self.labels, y.cpu().numpy())
        return loss

    def on_validation_epoch_end(self):
        self.log('val_recall', recall_score(self.labels, self.preds, average='macro'))
        self.preds = np.array([])
        self.labels = np.array([])

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=5e-5)
        return optimizer

In [5]:
import os
import shutil
from sys import platform
from glob import glob
!unzip './wav/4차년도.zip' -d file/
clear_output()

In [6]:
csv3_data_path = './4차년도.csv'
csv3 = pd.read_csv(csv3_data_path, encoding = 'CP949')
data_path = './file/'

In [7]:
from sklearn.model_selection import train_test_split

train_size = 0.80

train, val = train_test_split(csv3, train_size = train_size, stratify=csv3['상황'], random_state=77)

In [14]:
if __name__ == '__main__':
    train_df = train
    dev_df = val

    train_dataset = MultiModalData(train_df)
    dev_dataset = MultiModalData(dev_df, le=train_dataset.le)

    train_loader = torch.utils.data.DataLoader(train_dataset, num_workers = 4, batch_size=8, shuffle=True)
    dev_loader = torch.utils.data.DataLoader(dev_dataset, num_workers = 4, batch_size=8, shuffle=False)

    model = LateFusion()
    checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath='late_ckpts',monitor='val_recall',save_top_k=1,mode='max')
    logger = pl.loggers.TensorBoardLogger(save_dir='logs/')
    trainer = pl.Trainer(
        devices = 'auto',
        accelerator='gpu',
        callbacks=[checkpoint_callback],
        max_epochs=5,
        precision=16,
        logger=logger
    )

    model.freeze()
    trainer.fit(model, train_loader, dev_loader)

Loading and processing audio
File ./file/5e3161c65807b852d9e032af.wav does not exist.
File ./file/5e2ad4145807b852d9e020d9.wav does not exist.
File ./file/5e32924e5807b852d9e03894.wav does not exist.
File ./file/5e3292825807b852d9e0389a.wav does not exist.
File ./file/5e33a9d35807b852d9e050f4.wav does not exist.
File ./file/5e298c085807b852d9e01a12.wav does not exist.
File ./file/5e2ad43e5807b852d9e020dc.wav does not exist.
File ./file/5e2998b85807b852d9e01b02.wav does not exist.
File ./file/5e33638b5807b852d9e04aeb.wav does not exist.
File ./file/5e298bc45807b852d9e01a10.wav does not exist.
File ./file/5e298b9f5807b852d9e01a0f.wav does not exist.
File ./file/5e298bdc5807b852d9e01a11.wav does not exist.
File ./file/5e2979c25807b852d9e018d5.wav does not exist.
File ./file/5e31622f5807b852d9e032ba.wav does not exist.
File ./file/5e3292655807b852d9e03896.wav does not exist.



  0%|          | 0/11684 [00:00<?, ?it/s][A
  2%|▏         | 231/11684 [00:00<00:04, 2302.81it/s][A
  4%|▍         | 478/11684 [00:00<00:04, 2395.75it/s][A
  6%|▌         | 718/11684 [00:00<00:04, 2394.50it/s][A
  8%|▊         | 962/11684 [00:00<00:04, 2409.58it/s][A
 10%|█         | 1206/11684 [00:00<00:04, 2419.18it/s][A
 12%|█▏        | 1457/11684 [00:00<00:04, 2448.08it/s][A
 15%|█▍        | 1708/11684 [00:00<00:04, 2466.18it/s][A
 17%|█▋        | 1955/11684 [00:00<00:03, 2453.66it/s][A
 19%|█▉        | 2201/11684 [00:00<00:03, 2437.56it/s][A
 21%|██        | 2454/11684 [00:01<00:03, 2463.14it/s][A
 23%|██▎       | 2709/11684 [00:01<00:03, 2489.41it/s][A
 25%|██▌       | 2958/11684 [00:01<00:03, 2484.92it/s][A
 27%|██▋       | 3207/11684 [00:01<00:03, 2389.54it/s][A
 30%|██▉       | 3450/11684 [00:01<00:03, 2398.94it/s][A
 32%|███▏      | 3691/11684 [00:01<00:03, 2370.75it/s][A
 34%|███▍      | 3944/11684 [00:01<00:03, 2416.06it/s][A
 36%|███▌      | 4195/11684 [0

Loading and processing audio
File ./file/5e315dca5807b852d9e03275.wav does not exist.



  0%|          | 0/2922 [00:00<?, ?it/s][A
  9%|▊         | 252/2922 [00:00<00:01, 2515.36it/s][A
 17%|█▋        | 504/2922 [00:00<00:00, 2502.42it/s][A
 26%|██▌       | 762/2922 [00:00<00:00, 2535.50it/s][A
 35%|███▍      | 1016/2922 [00:00<00:00, 2494.51it/s][A
 44%|████▎     | 1276/2922 [00:00<00:00, 2529.45it/s][A
 52%|█████▏    | 1530/2922 [00:00<00:00, 2528.83it/s][A
 61%|██████    | 1783/2922 [00:00<00:00, 2469.13it/s][A
 70%|██████▉   | 2036/2922 [00:00<00:00, 2485.38it/s][A
 79%|███████▊  | 2295/2922 [00:00<00:00, 2515.04it/s][A
 88%|████████▊ | 2558/2922 [00:01<00:00, 2547.84it/s][A
100%|██████████| 2922/2922 [00:01<00:00, 2493.49it/s]

100%|██████████| 2922/2922 [00:00<00:00, 55883.25it/s]
Some weights of the model checkpoint at kresnik/wav2vec2-large-xlsr-korean were not used when initializing Wav2Vec2Model: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with a

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


In [15]:
checkpoint_path = trainer.checkpoint_callback.best_model_path
checkpoint_path

'/content/late_ckpts/epoch=4-step=7295.ckpt'

In [16]:
trainer.test(ckpt_path=checkpoint_path, dataloaders=[dev_loader])

INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/late_ckpts/epoch=4-step=7295.ckpt
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at /content/late_ckpts/epoch=4-step=7295.ckpt


Testing: 0it [00:00, ?it/s]

[{'test_acc': 0.5817513465881348, 'test_loss': 1.2031283378601074}]

In [17]:
from google.colab import drive
drive.mount('/content/drive')

import shutil

checkpoint_path = trainer.checkpoint_callback.best_model_path
drive_path = "/content/drive/MyDrive/의현/multimodal_best.ckpt"

shutil.copy(checkpoint_path, drive_path)
drive.flush_and_unmount()

Mounted at /content/drive
