In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from imblearn.over_sampling import SMOTE  # SMOTE 임포트
from tqdm import tqdm
import os

# 사용자 정의 데이터셋 클래스
class AudioDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# VATT 기반 단일 음성 모델
class VATTEmotionClassifier(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=256, output_dim=6):
        super(VATTEmotionClassifier, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),  # 추가적인 은닉층
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim)
        )
        
    def forward(self, x):
        return self.model(x)

# Z-스코어 정규화 함수
def standardize_data(X):
    mean = np.mean(X, axis=0)
    std = np.std(X, axis=0)
    return (X - mean) / std

# CSV 데이터 로드 및 전처리
def load_data(csv_path):
    df = pd.read_csv(csv_path)
    X = df.iloc[:, 1:-1].values
    y = df['emotion'].values
    return X, y

# 모델 훈련 및 성능 평가
def train_and_evaluate_model(X, y, batch_size=16, epochs=100, learning_rate=0.001, patience=5):
    X = standardize_data(X)  # 데이터 정규화 수행
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
    # SMOTE 적용
    smote = SMOTE(random_state=42)
    X_train_resampled, y_train_resampled = smote.fit_resample(X_train, y_train)
    
    train_dataset = AudioDataset(X_train_resampled, y_train_resampled)
    test_dataset = AudioDataset(X_test, y_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VATTEmotionClassifier().to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    best_accuracy = 0.0
    epochs_without_improvement = 0  # Early stopping을 위한 카운터
    
    # 모델 훈련
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        
        with tqdm(train_loader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch [{epoch+1}/{epochs}]")
            for inputs, labels in tepoch:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
                
                tepoch.set_postfix(loss=running_loss/len(train_loader), accuracy=100. * correct/total)

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100. * correct/total:.2f}%")

        # 검증 데이터에 대한 성능 평가
        model.eval()
        y_pred = []
        y_true = []
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                y_pred.extend(predicted.cpu().numpy())
                y_true.extend(labels.numpy())

        accuracy = accuracy_score(y_true, y_pred)
        print(f"Validation Accuracy: {accuracy:.4f}")

        # Early Stopping
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            epochs_without_improvement = 0
            # 모델 상태 저장
            torch.save(model.state_dict(), 'vatt_audio_final.pth')
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping activated after {epoch + 1} epochs without improvement.")
                break

    # 최종 평가
    model.eval()
    y_pred = []
    y_true = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            y_pred.extend(predicted.cpu().numpy())
            y_true.extend(labels.numpy())

    accuracy = accuracy_score(y_true, y_pred)
    print(f"Test Accuracy: {accuracy:.4f}")
    print("Classification Report:")
    print(classification_report(y_true, y_pred))

def main():
    csv_path = 'vatt_audio_features_encoded.csv'
    X, y = load_data(csv_path)
    train_and_evaluate_model(X, y)

if __name__ == "__main__":
    main()


Epoch [1/100]: 100%|█████████████████████████████████| 4209/4209 [00:17<00:00, 235.85batch/s, accuracy=36.4, loss=1.55]


Epoch [1/100], Loss: 1.5502, Accuracy: 36.40%
Validation Accuracy: 0.3125


Epoch [2/100]: 100%|█████████████████████████████████| 4209/4209 [00:17<00:00, 246.94batch/s, accuracy=54.1, loss=1.17]


Epoch [2/100], Loss: 1.1683, Accuracy: 54.14%
Validation Accuracy: 0.4261


Epoch [3/100]: 100%|█████████████████████████████████| 4209/4209 [00:17<00:00, 242.81batch/s, accuracy=63.3, loss=0.95]


Epoch [3/100], Loss: 0.9502, Accuracy: 63.33%
Validation Accuracy: 0.4613


Epoch [4/100]: 100%|████████████████████████████████| 4209/4209 [00:17<00:00, 245.79batch/s, accuracy=69.5, loss=0.793]


Epoch [4/100], Loss: 0.7934, Accuracy: 69.53%
Validation Accuracy: 0.4635


Epoch [5/100]: 100%|█████████████████████████████████| 4209/4209 [00:17<00:00, 235.73batch/s, accuracy=73.9, loss=0.68]


Epoch [5/100], Loss: 0.6800, Accuracy: 73.92%
Validation Accuracy: 0.4767


Epoch [6/100]: 100%|████████████████████████████████| 4209/4209 [00:17<00:00, 241.60batch/s, accuracy=77.4, loss=0.592]


Epoch [6/100], Loss: 0.5924, Accuracy: 77.42%
Validation Accuracy: 0.4809


Epoch [7/100]: 100%|██████████████████████████████████| 4209/4209 [00:18<00:00, 224.65batch/s, accuracy=80, loss=0.523]


Epoch [7/100], Loss: 0.5228, Accuracy: 80.04%
Validation Accuracy: 0.4889


Epoch [8/100]: 100%|████████████████████████████████| 4209/4209 [00:19<00:00, 218.59batch/s, accuracy=82.2, loss=0.464]


Epoch [8/100], Loss: 0.4636, Accuracy: 82.15%
Validation Accuracy: 0.4807


Epoch [9/100]: 100%|██████████████████████████████████| 4209/4209 [00:19<00:00, 218.16batch/s, accuracy=84, loss=0.417]


Epoch [9/100], Loss: 0.4169, Accuracy: 83.96%
Validation Accuracy: 0.4956


Epoch [10/100]: 100%|███████████████████████████████| 4209/4209 [00:18<00:00, 225.33batch/s, accuracy=85.4, loss=0.376]


Epoch [10/100], Loss: 0.3763, Accuracy: 85.37%
Validation Accuracy: 0.4961


Epoch [11/100]: 100%|████████████████████████████████| 4209/4209 [00:17<00:00, 241.61batch/s, accuracy=86.9, loss=0.34]


Epoch [11/100], Loss: 0.3401, Accuracy: 86.92%
Validation Accuracy: 0.4933


Epoch [12/100]: 100%|███████████████████████████████| 4209/4209 [00:18<00:00, 222.54batch/s, accuracy=88.1, loss=0.307]


Epoch [12/100], Loss: 0.3072, Accuracy: 88.11%
Validation Accuracy: 0.4825


Epoch [13/100]: 100%|███████████████████████████████| 4209/4209 [00:18<00:00, 232.50batch/s, accuracy=88.8, loss=0.286]


Epoch [13/100], Loss: 0.2860, Accuracy: 88.85%
Validation Accuracy: 0.5012


Epoch [14/100]: 100%|███████████████████████████████| 4209/4209 [00:17<00:00, 235.82batch/s, accuracy=90.1, loss=0.258]


Epoch [14/100], Loss: 0.2577, Accuracy: 90.12%
Validation Accuracy: 0.4884


Epoch [15/100]: 100%|████████████████████████████████| 4209/4209 [00:17<00:00, 236.97batch/s, accuracy=90.9, loss=0.24]


Epoch [15/100], Loss: 0.2395, Accuracy: 90.91%
Validation Accuracy: 0.5117


Epoch [16/100]: 100%|███████████████████████████████| 4209/4209 [00:17<00:00, 239.13batch/s, accuracy=91.6, loss=0.221]


Epoch [16/100], Loss: 0.2211, Accuracy: 91.57%
Validation Accuracy: 0.4936


Epoch [17/100]: 100%|███████████████████████████████| 4209/4209 [00:16<00:00, 250.40batch/s, accuracy=92.3, loss=0.204]


Epoch [17/100], Loss: 0.2039, Accuracy: 92.27%
Validation Accuracy: 0.5063


Epoch [18/100]: 100%|███████████████████████████████| 4209/4209 [00:17<00:00, 239.37batch/s, accuracy=92.8, loss=0.191]


Epoch [18/100], Loss: 0.1909, Accuracy: 92.84%
Validation Accuracy: 0.5014


Epoch [19/100]: 100%|███████████████████████████████| 4209/4209 [00:17<00:00, 234.11batch/s, accuracy=93.3, loss=0.178]


Epoch [19/100], Loss: 0.1784, Accuracy: 93.25%
Validation Accuracy: 0.4966


Epoch [20/100]: 100%|███████████████████████████████| 4209/4209 [00:17<00:00, 242.90batch/s, accuracy=93.7, loss=0.167]


Epoch [20/100], Loss: 0.1669, Accuracy: 93.74%
Validation Accuracy: 0.5014
Early stopping activated after 20 epochs without improvement.
Test Accuracy: 0.5014
Classification Report:
              precision    recall  f1-score   support

           0       0.48      0.52      0.50      2355
           1       0.35      0.29      0.31       951
           2       0.37      0.35      0.36       815
           3       0.56      0.47      0.51       901
           4       0.58      0.63      0.60      2764
           5       0.50      0.38      0.43       357

    accuracy                           0.50      8143
   macro avg       0.47      0.44      0.45      8143
weighted avg       0.50      0.50      0.50      8143

