In [1]:
import os
import logging
import pickle
import numpy as np
import soundfile as sf
import torchaudio
import torch
import datetime
import random

from typing import List, Optional, Tuple, Dict
from torch import Tensor, nn, optim
from torchvggish import vggish
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset
from torchvggish.vggish_input import waveform_to_examples
from tqdm.auto import tqdm
from abc import ABC, abstractmethod

# Seed Python's, NumPy's and PyTorch's (CPU and all-CUDA-device) generators with one value.
SEED = 0
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)
del _seed_rng

# Use the GPU when PyTorch reports one as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


## Preprocessing Data

In [2]:
import argparse
import glob
import logging
import os
import pickle
import random

import pandas as pd
import soundfile as sf

from moviepy.editor import VideoFileClip
from sklearn.model_selection import train_test_split

# Emit INFO-level records with a timestamped format so the summary at the end of the cell is shown.
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

# Raw string: "\L" and "\I" are not valid escape sequences, and a plain literal triggers a
# SyntaxWarning on recent Python versions. The resulting path is unchanged.
data_root = r"D:\LSTM\IEMOCAP_full_release"
# Utterances with fewer audio frames than this are skipped (0 keeps everything).
ignore_length = 0

session_id = list(range(1, 6))

samples = []
labels = []
# Maps the emotion codes found in the evaluation files onto three class ids.
# Utterances whose code is not listed here are dropped.
iemocap2label = {
    "neu": 0,
    "fru": 0,
    "exc": 1,
    "hap": 1,
    "sur": 1,
    "fea": 2,
    "ang": 2,
    "sad": 2,
    "dis": 2,
}

for sess_id in tqdm(session_id):
    sess_path = os.path.join(data_root, "Session{}".format(sess_id))
    sess_audio_root = os.path.join(sess_path, "sentences/wav")
    sess_label_root = os.path.join(sess_path, "dialog/EmoEvaluation")
    for label_path in glob.glob(os.path.join(sess_label_root, "*.txt")):
        with open(label_path, "r") as f:
            for line in f:
                # Only lines that open with "[" are parsed as utterance records.
                if not line.startswith("["):
                    continue
                fields = line.split()
                utterance_id = fields[3]
                # The wav folder name is the utterance id minus its last five characters.
                wav_folder = utterance_id[:-5]
                wav_path = os.path.join(sess_audio_root, wav_folder, utterance_id + ".wav")
                wav_data, _ = sf.read(wav_path, dtype="int16")
                # Ignore samples with length < ignore_length
                if len(wav_data) < ignore_length:
                    logging.warning(f"Ignoring sample {wav_path} with length {len(wav_data)}")
                    continue
                emo = iemocap2label.get(fields[4], None)
                if emo is not None:
                    samples.append((wav_path, emo))
                    labels.append(emo)

if not samples:
    # Without this guard the zip(*...) unpacking below fails with an unhelpful message.
    raise ValueError(f"No labelled utterances were found under {data_root}")

# Shuffle and split
temp = list(zip(samples, labels))
random.Random(0).shuffle(temp)
samples, labels = zip(*temp)
train_samples, test_samples, _, _ = train_test_split(samples, labels, test_size=0.2, random_state=0)

# Save data
output_dir = "IEMOCAP_preprocessed"
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, "train.pkl"), "wb") as f:
    pickle.dump(train_samples, f)
with open(os.path.join(output_dir, "test.pkl"), "wb") as f:
    pickle.dump(test_samples, f)

logging.info(f"Train samples: {len(train_samples)}")
logging.info(f"Test samples: {len(test_samples)}")
logging.info(f"Saved to {output_dir}")
logging.info("Preprocessing finished successfully")

100%|██████████| 5/5 [00:01<00:00,  2.52it/s]
2023-10-26 11:56:34,318 - root - INFO - Train samples: 6023
2023-10-26 11:56:34,319 - root - INFO - Test samples: 1506
2023-10-26 11:56:34,319 - root - INFO - Saved to IEMOCAP_preprocessed
2023-10-26 11:56:34,320 - root - INFO - Preprocessing finished successfully


## Config

In [3]:
class Base(ABC):
    """Abstract base that turns constructor keyword arguments into attributes.

    Each ``name=value`` pair passed to ``__init__`` is stored on the instance
    under ``name``. Subclasses have to provide ``show`` and ``save``.
    """

    def __init__(self, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    @abstractmethod
    def show(self):
        """Abstract hook; this base implementation does nothing."""

    @abstractmethod
    def save(self):
        """Abstract hook; this base implementation does nothing."""

class BaseConfig(Base):
    """Configuration object that can log itself, write itself to ``opt.log`` and read it back."""

    def __init__(self, **kwargs):
        super(BaseConfig, self).__init__(**kwargs)

    def show(self):
        """Log every instance attribute as ``key: value`` at INFO level."""
        for key, value in self.__dict__.items():
            logging.info(f"{key}: {value}")

    def save(self, opt: Optional["BaseConfig"] = None):
        """Write the attributes of ``opt`` to ``<opt.checkpoint_dir>/opt.log`` and log them.

        Args:
            opt: Object whose attributes are dumped; it must expose ``checkpoint_dir``.
                Defaults to ``self`` so that ``cfg.save()`` and ``cfg.save(cfg)`` are equivalent.
        """
        if opt is None:
            opt = self

        message = "\n"
        for k, v in sorted(vars(opt).items()):
            message += f"{str(k):>30}: {str(v):<40}\n"

        os.makedirs(opt.checkpoint_dir, exist_ok=True)
        out_opt = os.path.join(opt.checkpoint_dir, "opt.log")
        with open(out_opt, "w") as opt_file:
            opt_file.write(message)
            opt_file.write("\n")

        logging.info(message)

    @staticmethod
    def _decode_value(value: str):
        """Convert one textual value from ``opt.log`` back into a Python scalar.

        Recognises ``True``/``False``/``None``, quoted strings, integers and floats
        (including negative numbers and exponent notation such as ``1e-05``).
        Anything else is returned as the stripped string.
        """
        value = value.strip()
        literals = {"True": True, "False": False, "None": None}
        if value in literals:
            return literals[value]
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
            return value[1:-1]
        # Require at least one digit and no underscore so that words such as "nan"/"inf"
        # or identifiers like "1_000" are not silently turned into numbers.
        if "_" not in value and any(ch.isdigit() for ch in value):
            for cast in (int, float):
                try:
                    return cast(value)
                except ValueError:
                    pass
        return value

    def load(self, opt_path: str):
        """Read a file written by ``save`` and set every ``key: value`` pair as an attribute.

        Args:
            opt_path: Path to the ``opt.log`` file.
        """
        with open(opt_path, "r") as f:
            raw_lines = f.read().split("\n")

        data_dict = {}
        for raw_line in raw_lines:
            if not raw_line.strip():
                continue
            # Split on the FIRST colon only: values such as Windows paths ("d:\...") contain colons.
            key, separator, value = raw_line.partition(":")
            key, value = key.strip(), value.strip()
            if not separator or not key:
                logging.warning(f"Skipping malformed config line: {raw_line!r}")
                continue
            if value.startswith("[") and value.endswith("]"):
                inner = value[1:-1].strip()
                # "[]" must decode to an empty list, not to [""].
                value = [self._decode_value(x) for x in inner.split(",")] if inner else []
            else:
                value = self._decode_value(value)
            data_dict[key] = value

        # Attributes are only assigned once the whole file has been parsed.
        for key, value in data_dict.items():
            setattr(self, key, value)


class Config(BaseConfig):
    """Default experiment settings; any of them can be overridden through keyword arguments."""

    # Base
    def __init__(self, **kwargs):
        super(Config, self).__init__(**kwargs)
        # set_args installs the defaults first and then re-applies the overrides on top of them.
        self.set_args(**kwargs)

    def set_args(self, **kwargs):
        """Assign all default settings, then apply ``kwargs`` as overrides.

        ``name`` is derived from the final (post-override) ``fusion_head_output_type`` and
        ``audio_encoder_type`` unless a ``name`` override is given explicitly.
        """
        # Training settings
        self.num_epochs: int = 250
        self.checkpoint_dir: str = "checkpoints/AudioOnly_v2_notebook"
        self.save_all_states: bool = True
        self.save_best_val: bool = True
        self.max_to_keep: int = 1
        self.save_freq: int = 4000
        self.batch_size: int = 1

        self.loss_type: str = "CrossEntropyLoss"

        # Learning rate
        self.learning_rate: float = 0.01
        self.learning_rate_step_size: int = 20
        self.learning_rate_gamma: float = 0.1

        # Dataset
        self.data_name: str = "IEMOCAPAudio"  # [IEMOCAP, ESD, MELD, IEMOCAPAudio]
        self.data_root: str = "IEMOCAP_preprocessed"  # folder contains train.pkl and test.pkl

        # use for training with batch size > 1
        self.audio_max_length: int = 546220

        # Model
        self.num_classes: int = 3
        self.dropout: float = 0.5
        self.model_type: str = "AudioOnly_v2"  # [MMSERA, AudioOnly, TextOnly, SERVER]
        self.audio_encoder_type: str = "lstm"  # [vggish, panns, hubert_base, wav2vec2_base, wavlm_base, lstm]
        self.audio_encoder_dim: int = 512  # 2048 - panns, 128 - vggish, 768 - hubert_base,wav2vec2_base, 512 - wavlm_base
        self.audio_norm_type: str = "layer_norm"  # [layer_norm, min_max, None]
        self.audio_unfreeze: bool = True

        self.fusion_head_output_type: str = "cls"  # [cls, mean, max]

        self.linear_layer_last_dim: int = 64

        # For LSTM
        self.lstm_hidden_size: int = 512  # should be the same as audio_encoder_dim
        self.lstm_num_layers: int = 2

        for key, value in kwargs.items():
            setattr(self, key, value)

        # Computed after the overrides so the name reflects the settings actually in use.
        if "name" not in kwargs:
            self.name = f"{self.fusion_head_output_type}_{self.audio_encoder_type}"


## Dataset

In [4]:
class IEMOCAPAudioDataset(Dataset):
    def __init__(
        self,
        path: str = "path/to/data.pkl",
        audio_max_length: int = 546220,
        audio_encoder_type: str = "vggish",
    ):
        """Audio-only dataset backed by a pickled list of ``(wav_path, label)`` pairs.

        Args:
            path (str, optional): Path to the pickle file. Defaults to "path/to/data.pkl".
            audio_max_length (int, optional): Number of audio frames every clip is padded or
                truncated to. Defaults to 546220. None for no padding and truncation.
            audio_encoder_type (str, optional): Selects the audio post-processing: "vggish"
                converts the waveform to VGGish examples, "panns" keeps the native sample rate,
                anything else is resampled to 16 kHz. Defaults to "vggish".
        """
        super(IEMOCAPAudioDataset, self).__init__()
        # NOTE: pickle can execute arbitrary code while loading; only open files you produced yourself.
        with open(path, "rb") as file:
            self.data_list = pickle.load(file)
        self.audio_max_length = audio_max_length
        self.audio_encoder_type = audio_encoder_type

    def __paudio__(self, file_path: str) -> torch.Tensor:
        """Load one wav file and return it as a float32 tensor prepared for the configured encoder."""
        wav_data, sr = sf.read(file_path, dtype="int16")
        samples = wav_data / 32768.0  # Convert to [-1.0, +1.0]
        # NOTE(review): padding/truncation act on axis 0 only, i.e. this assumes mono audio - confirm.
        if self.audio_max_length is not None:
            if samples.shape[0] < self.audio_max_length:
                samples = np.pad(samples, (0, self.audio_max_length - samples.shape[0]), "constant")
            else:
                samples = samples[: self.audio_max_length]

        if self.audio_encoder_type == "vggish":
            samples = waveform_to_examples(samples, sr, return_tensor=False)  # num_samples, 96, 64
            samples = np.expand_dims(samples, axis=1)  # num_samples, 1, 96, 64
            return torch.from_numpy(samples.astype(np.float32))

        waveform = torch.from_numpy(samples.astype(np.float32))
        target_sr = 16000
        if self.audio_encoder_type != "panns" and sr != target_sr:
            # torchaudio's resample works on tensors; handing it the NumPy array only
            # worked when the file was already at the target rate.
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)
        return waveform

    def __plabel__(self, label: int) -> torch.Tensor:
        """Wrap an integer class id in a tensor."""
        return torch.tensor(label)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(audio, label)`` for the entry at ``index``."""
        audio_path, label = self.data_list[index]
        samples = self.__paudio__(audio_path)
        label = self.__plabel__(label)
        return samples, label


def build_train_test_dataset(opt: Config):
    """Create the train and test ``DataLoader`` objects described by ``opt``.

    The dataset class is looked up from ``opt.data_name``; ``train.pkl`` and ``test.pkl``
    are read from ``opt.data_root``. Training clips are padded/truncated to
    ``opt.audio_max_length`` unless ``opt.batch_size`` is 1; test clips are never
    padded and are served one at a time without shuffling.

    Returns:
        tuple: ``(train_dataloader, test_dataloader)``.

    Raises:
        NotImplementedError: if ``opt.data_name`` has no registered dataset class.
    """
    DATASET_MAP = {
        "IEMOCAPAudio": IEMOCAPAudioDataset,
    }

    if opt.data_name not in DATASET_MAP:
        raise NotImplementedError(
            "Dataset {} is not implemented, list of available datasets: {}".format(opt.data_name, DATASET_MAP.keys())
        )
    dataset_cls = DATASET_MAP[opt.data_name]

    # A fixed clip length is only needed when several clips are stacked into one batch.
    configured_length = opt.audio_max_length
    train_length = None if opt.batch_size == 1 else configured_length

    train_path = os.path.join(opt.data_root, "train.pkl")
    training_data = dataset_cls(train_path, train_length, opt.audio_encoder_type)

    test_path = os.path.join(opt.data_root, "test.pkl")
    test_data = dataset_cls(test_path, None, opt.audio_encoder_type)

    loaders = (
        DataLoader(training_data, batch_size=opt.batch_size, shuffle=True),
        DataLoader(test_data, batch_size=1, shuffle=False),
    )
    return loaders


## Model

In [5]:
class LayerNorm(nn.LayerNorm):
    """Layer normalisation applied over the second-to-last axis.

    The last two axes are swapped, ``layer_norm`` is applied with this module's
    parameters, and the axes are swapped back, so the output has the input's shape.
    """

    def forward(self, input: Tensor) -> Tensor:
        normalised = nn.functional.layer_norm(
            input.transpose(-2, -1),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return normalised.transpose(-2, -1)


class ConvLayerBlock(nn.Module):
    """One ``Conv1d`` -> optional normalisation -> GELU stage of the feature extractor."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        bias: bool,
        layer_norm: Optional[nn.Module],
    ):
        super().__init__()
        # Kept on the module because forward() needs them to update the valid lengths.
        self.kernel_size, self.stride = kernel_size, stride
        self.layer_norm = layer_norm
        conv_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
        )
        self.conv = nn.Conv1d(**conv_kwargs)

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Apply the block and, when given, shrink the per-sample valid lengths accordingly.

        Args:
            x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
            length (Tensor or None, optional): Shape ``[batch, ]``.
        Returns:
            Tensor: Shape ``[batch, out_channels, out_frames]``.
            Optional[Tensor]: Shape ``[batch, ]``; ``None`` when ``length`` was ``None``.
        """
        activated = self.conv(x)
        if self.layer_norm is not None:
            activated = self.layer_norm(activated)
        activated = nn.functional.gelu(activated)

        if length is None:
            return activated, length

        # floor((length - kernel_size) / stride) + 1 frames survive the convolution.
        out_length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1
        # Inputs shorter than the kernel would give a negative count; floor those at zero.
        out_length = torch.maximum(out_length, torch.zeros_like(out_length))
        return activated, out_length


class FeatureExtractor(nn.Module):
    """Runs a waveform batch through a stack of convolution layers.

    Args:
        conv_layers (nn.ModuleList):
            Layers applied in order; each is called as ``layer(x, length)``
            and must return the updated ``(x, length)`` pair.
    """

    def __init__(
        self,
        conv_layers: nn.ModuleList,
    ):
        super().__init__()
        self.conv_layers = conv_layers

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor):
                Batch of audio, shape: ``[batch, time]``.
            length (Tensor or None, optional):
                Valid length of each input sample. shape: ``[batch, ]``.

        Returns:
            Tensor:
                Extracted features, shape: ``[batch, frame, feature]``
            Optional[Tensor]:
                Valid length of each output sample. shape: ``[batch, ]``.

        Raises:
            ValueError: if ``x`` is not two-dimensional.
        """
        is_batched_waveform = x.ndim == 2
        if not is_batched_waveform:
            raise ValueError(f"Expected the input Tensor to be 2D (batch, time). Found: {list(x.shape)}")

        # Insert a singleton channel axis: (batch, time) -> (batch, 1, time).
        features, valid_length = x[:, None, :], length
        for conv_layer in self.conv_layers:
            features, valid_length = conv_layer(features, valid_length)  # (batch, feature, frame)
        # Put the frame axis before the feature axis: (batch, frame, feature).
        return features.permute(0, 2, 1), valid_length


################################################################################
def get_feature_extractor(
    norm_mode: str,
    shapes: List[Tuple[int, int, int]],
    bias: bool,
) -> FeatureExtractor:
    """Assemble a ``FeatureExtractor`` from a list of convolution specifications.

    Args:
        norm_mode (str):
            ``"group_norm"`` or ``"layer_norm"``.
            With ``"group_norm"`` only the first block is normalised (a ``GroupNorm``
            with one group per channel); with ``"layer_norm"`` every block gets a
            ``LayerNorm``. Corresponds to fairseq's "extractor_mode"
            ("group_norm" for the Base arch, "layer_norm" for the Large arch).
        shapes (list of tuple of int):
            One ``(output_channel, kernel_size, stride)`` triple per block, in order.
            Corresponds to fairseq's "conv_feature_layers"; the usual value is
            ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2``.
        bias (bool):
            Whether the convolutions carry a bias term. Corresponds to fairseq's
            "conv_bias" (False for the Base arch, True for the Large arch).

    Raises:
        ValueError: if ``norm_mode`` is neither of the two accepted values.

    For wav2vec2-base, use
        extractor_mode:  group_norm
        extractor_conv_layer_config:  [(512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 2, 2), (512, 2, 2)]
        extractor_conv_bias:  False
    """
    if norm_mode != "group_norm" and norm_mode != "layer_norm":
        raise ValueError("Invalid norm mode")

    def _normalization_for(position: int, channels: int) -> Optional[nn.Module]:
        # Pick the normalisation module (if any) for the block at ``position``.
        if norm_mode == "layer_norm":
            return LayerNorm(normalized_shape=channels, elementwise_affine=True)
        if position == 0:
            return nn.GroupNorm(num_groups=channels, num_channels=channels, affine=True)
        return None

    conv_stack = nn.ModuleList()
    previous_channels = 1  # the raw waveform enters as a single channel
    for position, (channels, kernel, step) in enumerate(shapes):
        conv_stack.append(
            ConvLayerBlock(
                in_channels=previous_channels,
                out_channels=channels,
                kernel_size=kernel,
                stride=step,
                bias=bias,
                layer_norm=_normalization_for(position, channels),
            )
        )
        previous_channels = channels
    return FeatureExtractor(conv_stack)


In [6]:
# class VGGish(nn.Module):
#     def __init__(self):
#         super(VGGish, self).__init__()
#         self.vggish = vggish()

#     def forward(self, x):
#         out = []
#         for i in range(x.size(0)):
#             out.append(self.vggish(x[i]))
#         x = torch.stack(out, axis=0)
#         if len(x.size()) == 2:
#             x = x.unsqueeze(1)
#         return x


# def build_vggish_encoder(opt: Config) -> nn.Module:
#     """A function to build vggish encoder"""
#     return VGGish()

# class HuBertBase(nn.Module):
#     def __init__(self, **kwargs):
#         super(HuBertBase, self).__init__(**kwargs)
#         bundle = torchaudio.pipelines.HUBERT_BASE
#         self.model = bundle.get_model()

#     def forward(self, x):
#         features, _ = self.model(x)
#         return features

# def build_hubert_base_encoder(opt: Config) -> nn.Module:
#     """A function to build hubert encoder"""
#     return HuBertBase()


# class Wav2Vec2Base(nn.Module):
#     def __init__(self, **kwargs):
#         super(Wav2Vec2Base, self).__init__(**kwargs)
#         bundle = torchaudio.pipelines.WAV2VEC2_BASE
#         self.model = bundle.get_model()

#     def forward(self, x):
#         features, _ = self.model(x)
#         return features


# def build_wav2vec2_base_encoder(opt: Config) -> nn.Module:
#     return Wav2Vec2Base()


# class WavlmBase(nn.Module):
#     def __init__(self, **kwargs):
#         super(WavlmBase, self).__init__(**kwargs)
#         bundle = torchaudio.pipelines.WAVLM_BASE
#         self.model = bundle.get_model()

#     def forward(self, x):
#         features, _ = self.model(x)
#         return features


# def build_wavlm_base_encoder(opt: Config) -> nn.Module:
#     return WavlmBase()

class LSTM(nn.Module):
    """Feature module followed by an LSTM; yields the LSTM output at the final time step."""

    def __init__(self, feature_module, input_size=512, hidden_size=512, num_layers=2, **kwargs):
        super(LSTM, self).__init__(**kwargs)
        self.feature_module = feature_module
        lstm_options = dict(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.lstm = nn.LSTM(**lstm_options)

    def forward(self, x):
        # The feature module is called without valid lengths; its second return value is unused.
        features = self.feature_module(x, None)[0]
        sequence_output = self.lstm(features)[0]
        # take only the last output
        return sequence_output[:, -1, :]


def build_lstm_encoder(opt: Config) -> nn.Module:
    """Build the convolutional feature extractor + LSTM audio encoder.

    The extractor uses the wav2vec2-base convolution layout and is initialised from a
    saved state dict. The file location is read from ``opt.feature_extractor_path`` when
    that attribute exists and otherwise falls back to the original hard-coded path.

    Args:
        opt (Config): Supplies ``lstm_hidden_size`` and ``lstm_num_layers``.

    Returns:
        nn.Module: The assembled ``LSTM`` encoder.
    """
    extractor_mode = "group_norm"
    extractor_conv_layer_config = [(512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 2, 2), (512, 2, 2)]
    extractor_conv_bias = False
    feature_extractor = get_feature_extractor(extractor_mode, extractor_conv_layer_config, extractor_conv_bias)
    feature_extractor.to("cpu")

    weights_path = getattr(opt, "feature_extractor_path", "D:/LSTM/feature_extractor_wav2vec_base.pth")
    # map_location keeps loading working on CPU-only machines even if the tensors were saved from a GPU.
    state_dict = torch.load(weights_path, map_location="cpu")
    feature_extractor.load_state_dict(state_dict)

    # The LSTM consumes the channel count produced by the last convolution block.
    feature_dim = extractor_conv_layer_config[-1][0]
    model = LSTM(
        feature_extractor,
        input_size=feature_dim,
        hidden_size=opt.lstm_hidden_size,
        num_layers=opt.lstm_num_layers,
    )

    return model


def build_audio_encoder(opt: Config) -> nn.Module:
    """Instantiate the audio encoder named by ``opt.audio_encoder_type``.

    Args:
        opt (Config): Configuration; ``audio_encoder_type`` selects the builder and the
            whole object is forwarded to it.

    Returns:
        nn.Module: Audio encoder
    """
    encoder_type = opt.audio_encoder_type
    builders = {
        # "vggish": build_vggish_encoder,
        # "hubert_base": build_hubert_base_encoder,
        # "wav2vec2_base": build_wav2vec2_base_encoder,
        # "wavlm_base": build_wavlm_base_encoder,
        "lstm": build_lstm_encoder,
    }
    assert encoder_type in builders, f"Invalid audio encoder type: {encoder_type}"
    build_fn = builders[encoder_type]
    return build_fn(opt)


In [7]:
class AudioOnly_v2(nn.Module):
    def __init__(
        self,
        opt: Config,
        device: str = "cpu",
    ):
        """Speech Emotion Recognition with Audio Only

        Args:
            opt (Config): Config object
            device (str, optional): The device the audio encoder is moved to. Defaults to "cpu".
        """
        super(AudioOnly_v2, self).__init__()

        # Audio module
        self.audio_encoder = build_audio_encoder(opt)
        self.audio_encoder.to(device)
        # Freeze/Unfreeze the audio module
        for param in self.audio_encoder.parameters():
            param.requires_grad = opt.audio_unfreeze

        self.dropout = nn.Dropout(opt.dropout)
        self.linear1 = nn.Linear(opt.audio_encoder_dim, 256)
        self.linear2 = nn.Linear(256, 64)
        # NOTE: the attribute name (sic) is part of the saved state-dict keys, so it is kept as is.
        self.classifer = nn.Linear(64, opt.num_classes)
        self.fusion_head_output_type = opt.fusion_head_output_type

    def forward(self, audio: torch.Tensor):
        """Encode ``audio``, pool the encoder output and return the class logits.

        Args:
            audio (torch.Tensor): Batch handed to the audio encoder unchanged.

        Returns:
            torch.Tensor: Logits of shape ``(batch, num_classes)``.

        Raises:
            ValueError: if ``fusion_head_output_type`` is not "cls", "mean" or "max".
        """
        # Audio processing
        audio_embeddings = self.audio_encoder(audio)

        # A bare feature vector (dim,) becomes a batch of one: (1, dim)
        if audio_embeddings.dim() == 1:
            audio_embeddings = audio_embeddings.unsqueeze(0)

        # (batch, dim) -> (batch, 1, dim). The sequence axis is inserted AFTER the batch axis;
        # inserting it in front would treat the batch as a sequence and keep only one sample.
        if audio_embeddings.dim() == 2:
            audio_embeddings = audio_embeddings.unsqueeze(1)

        # Get classification output
        if self.fusion_head_output_type == "cls":
            audio_embeddings = audio_embeddings[:, 0, :]
        elif self.fusion_head_output_type == "mean":
            audio_embeddings = audio_embeddings.mean(dim=1)
        elif self.fusion_head_output_type == "max":
            # max(dim=...) returns (values, indices); only the values are features.
            audio_embeddings = audio_embeddings.max(dim=1).values
        else:
            raise ValueError("Invalid fusion head output type")

        # Classification head
        x = self.linear1(audio_embeddings)
        x = self.linear2(x)
        x = self.dropout(x)
        out = self.classifer(x)

        return out


## Train

In [8]:
# Training setup: builds the config, model, checkpoint paths, data loaders, optimizer,
# LR scheduler and loss that the training-loop cell below relies on.
opt = Config()

logging.info("Initializing model...")
# The model is constructed with its default device and then moved to the selected one.
model = AudioOnly_v2(opt)
model.to(device)

logging.info("Initializing checkpoint directory and dataset...")
# Prepare the checkpoint directory: <abs checkpoint_dir>/<opt.name>/<timestamp>
opt.checkpoint_dir = checkpoint_dir = os.path.join(
    os.path.abspath(opt.checkpoint_dir),
    opt.name,
    datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
)
weight_dir = os.path.join(checkpoint_dir, "weights")
os.makedirs(weight_dir, exist_ok=True)
# Dump the resolved configuration to <checkpoint_dir>/opt.log and to the log output.
opt.save(opt)
weight_path = os.path.join(weight_dir, "latest_ckpt.pth")
weight_best_path = os.path.join(weight_dir, "best_ckpt.pth")

# Build dataset (despite the *_ds names, these are DataLoader objects)
train_ds, test_ds = build_train_test_dataset(opt)

# Build optimizer and criterion
optimizer = optim.SGD(params=model.parameters(), lr=opt.learning_rate)
# The scheduler is optional: it is only created when a step size is configured.
lr_scheduler = None
if opt.learning_rate_step_size is not None:
    lr_scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=opt.learning_rate_step_size, gamma=opt.learning_rate_gamma
    )

# Define loss function
loss_fn = CrossEntropyLoss()

2023-10-26 11:57:17,635 - root - INFO - Initializing model...
2023-10-26 11:57:17,757 - root - INFO - Initializing checkpoint directory and dataset...
2023-10-26 11:57:17,761 - root - INFO - 
             audio_encoder_dim: 512                                     
            audio_encoder_type: lstm                                    
              audio_max_length: 546220                                  
               audio_norm_type: layer_norm                              
                audio_unfreeze: True                                    
                    batch_size: 1                                       
                checkpoint_dir: d:\LSTM\checkpoints\AudioOnly_v2_notebook\cls_lstm\20231026-115717
                     data_name: IEMOCAPAudio                            
                     data_root: IEMOCAP_preprocessed                    
                       dropout: 0.5                                     
       fusion_head_output_type: cls                 

In [9]:
# Training loop: one pass over the training loader per epoch, followed by a full
# evaluation on the test loader. The best model by validation accuracy is saved separately.
logging.info("Start training...")
step = 0
best_acc = -999999
for epoch in range(1, opt.num_epochs + 1):
    # Re-enable dropout; the validation pass below switches the model to eval mode.
    model.train()
    with tqdm(total=len(train_ds), ascii=True) as pbar:
        for audio, label in train_ds:
            # Training step
            step += 1
            audio = audio.to(device)
            label = label.to(device)
            optimizer.zero_grad()

            out = model(audio)
            loss = loss_fn(out, label)
            loss.backward()
            optimizer.step()
            if step % opt.save_freq == 0:
                torch.save(model.state_dict(), weight_path)
                logging.info("Saved checkpoint to {} at iter {}, epoch {}".format(weight_path, step, epoch))

            acc = (out.argmax(dim=1) == label).float().mean()
            # Add logs, update progress bar
            postfix = "Loss: {}, Acc: {}".format(loss.item(), acc.item())
            pbar.set_description(postfix)
            pbar.update(1)

    # Step the scheduler once per epoch. Stepping it per batch would apply the decay every
    # `step_size` iterations and drive the learning rate to ~0 early in the first epoch.
    if lr_scheduler is not None:
        lr_scheduler.step()

    # Validation: dropout off and no autograd graph, so metrics are deterministic and cheap.
    model.eval()
    val_loss = []
    val_acc = []
    with torch.no_grad():
        for audio, label in tqdm(test_ds):
            audio = audio.to(device)
            label = label.to(device)
            out = model(audio)
            loss = loss_fn(out, label)
            acc = (out.argmax(dim=1) == label).float().mean()
            val_loss.append(loss.item())
            val_acc.append(acc.item())

    mean_val_loss = np.mean(val_loss)
    mean_val_acc = np.mean(val_acc)
    print("Val loss: {}, Val acc: {}".format(mean_val_loss, mean_val_acc))
    if mean_val_acc > best_acc:
        print("Model improved from {} to {}".format(best_acc, mean_val_acc))
        best_acc = mean_val_acc
        torch.save(model.state_dict(), weight_best_path)

2023-10-26 11:57:43,989 - root - INFO - Start training...
Loss: 1.2081342935562134, Acc: 0.0:   0%|          | 4/6023 [00:03<1:23:18,  1.20it/s]


KeyboardInterrupt: 