In [1]:
# -*- coding: UTF-8 -*-
# Local modules
import os
import sys
import argparse
# 3rd-Party Modules
import numpy as np
import pickle as pk
import pandas as pd
from tqdm import tqdm
import glob
import librosa
import copy
import logging
import time 

# PyTorch Modules
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader
import torch.optim as optim
from transformers import AutoModel
import importlib
# Self-Written Modules
sys.path.append('../')
from benchmark import net
from benchmark import utils
from torch.utils.data import WeightedRandomSampler
from transformers import RobertaTokenizer, RobertaModel


# parser = argparse.ArgumentParser()
# parser.add_argument("--seed", type=int, default=7)
# # parser.add_argument("--ssl_type", type=str, default="wavlm-large")
# # parser.add_argument("--batch_size", type=int, default=32)
# # parser.add_argument("--accumulation_steps", type=int, default=1)
# # parser.add_argument("--epochs", type=int, default=10)
# # parser.add_argument("--lr", type=float, default=0.001)
# # parser.add_argument("--model_path", type=str, default="./temp")
# parser.add_argument("--config_path", type=str, default="./configs/config_cat.json")
# # parser.add_argument("--head_dim", type=int, default=1024)

# # parser.add_argument("--pooling_type", type=str, default="AttentiveStatisticsPooling")
# args = parser.parse_args()

import json
from collections import defaultdict
# config_path = "configs/config_cat.json"
# Load the experiment configuration (data paths + hyper-parameters) from JSON.
config_path = '../configs/config_cat_wavlmbase_robertabase.json'
with open(config_path, "r") as f:
    config = json.load(f)

# Locations of the audio files, the transcription table and the label table.
audio_path = config["wav_dir"]
text_path = config["txt_dir"]
label_path = config["label_path"]

# Resolve the configured SSL model name through the project helper.
SSL_TYPE = utils.get_ssl_type(config['ssl_type'])
# Identity comparison with None and a plain string message: the previous
# `print(...)` message evaluated to None, so the AssertionError was empty.
assert SSL_TYPE is not None, "Invalid SSL type!"

BATCH_SIZE = config['batch_size']
ACCUMULATION_STEP = config['accum_step']
# The per-step batch size is BATCH_SIZE // ACCUMULATION_STEP, so the division
# must be exact and the step count strictly positive.
assert (ACCUMULATION_STEP > 0) and (BATCH_SIZE % ACCUMULATION_STEP == 0)
EPOCHS= config['epochs']
LR=config['lr']

# Output directory for logs and saved statistics; created if missing.
MODEL_PATH = config['model_path']
os.makedirs(MODEL_PATH, exist_ok=True)

HEAD_DIM = config['head_dim']
POOLING_TYPE = config['pooling_type']
WC = config["weight_decay"]
DROPOUT = config["dropout_head"]
USE_TIMBRE_PERTURB = config['use_timbre_perturb']
TP_PROB = config['tp_prob']
# utils.set_deterministic(args.seed)
# SSL_TYPE = utils.get_ssl_type(args.ssl_type)
# assert SSL_TYPE != None, print("Invalid SSL type!")
# BATCH_SIZE = args.batch_size
# ACCUMULATION_STEP = args.accumulation_steps
# assert (ACCUMULATION_STEP > 0) and (BATCH_SIZE % ACCUMULATION_STEP == 0)
# EPOCHS=args.epochs
# LR=args.lr
# MODEL_PATH = args.model_path
# os.makedirs(MODEL_PATH, exist_ok=True)


# Configure root logging: every record is written both to a timestamped log
# file inside MODEL_PATH and to the stream handler (shown in the notebook).
log_file_name = '%s-%d.log' % ('loggingtxt', time.time())
log_file_handler = logging.FileHandler(os.path.join(MODEL_PATH, log_file_name))
console_handler = logging.StreamHandler()

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[log_file_handler, console_handler],
)

# Root logger used by the rest of the notebook.
logger = logging.getLogger()



# print(config["use_balanced_batch"])
# Optional config switches with their defaults when the key is absent.
# dict.get replaces the previous bare `except:` blocks, which would also have
# silently hidden unrelated errors instead of only a missing key.
balanced_batch = config.get("use_balanced_batch", False)
normalize_wav = config.get("normalize_wav", True)

# Record the effective experiment settings at the top of the log.
logger.info(f"Starting an experimento in model path = {MODEL_PATH}")
logger.info(f"Using ssl = {SSL_TYPE} LR = {LR} Epochs = {EPOCHS} Batch size = {BATCH_SIZE} Accum steps = {ACCUMULATION_STEP}")
logger.info(f"Using balanced batch = {balanced_batch}")
logger.info(f"Using normalize wav = {normalize_wav}")
logger.info(f"Using Timbre Perturbation = {USE_TIMBRE_PERTURB}")


import pandas as pd
import numpy as np

# Load the label table and the transcription table, then attach each
# transcription to its label row via the shared 'FileName' column.
label_df = pd.read_csv(label_path)
text_df = pd.read_csv(text_path)
df = label_df.merge(text_df, on = 'FileName', how = 'left')
# Class statistics are computed on the training split only.
train_df = df[df['Split_Set'] == 'Train']

# Classes (emotions); this order defines the order of the weight tensor.
classes = ['Angry', 'Sad', 'Happy', 'Surprise', 'Fear', 'Disgust', 'Contempt', 'Neutral']

# Per-class frequency = column sum over the training rows.
class_frequencies = train_df[classes].sum().to_dict()
# Total number of training rows.
total_samples = len(train_df)
# Inverse-frequency weights: total / (n_classes * freq); a class that never
# occurs gets weight 0 instead of a division by zero.
class_weights = {cls: total_samples / (len(classes) * freq) if freq != 0 else 0 for cls, freq in class_frequencies.items()}
print(class_weights)
# Convert to a list in the order of `classes`.
weights_list = [class_weights[cls] for cls in classes]
# Put the weights on the GPU when one is available, otherwise on the CPU.
# The device used to be hard-coded to 'cuda', which fails on CPU-only hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class_weights_tensor = torch.tensor(weights_list, device=device, dtype=torch.float)
print(class_weights_tensor)

logger.info(f"Class weights: {class_weights_tensor}")

# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained RoBERTa-base tokenizer and encoder.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
text_model = RobertaModel.from_pretrained("roberta-base")

# Switch the encoder to evaluation mode, then move it to the chosen device.
text_model.eval()
text_model.to(device)

# Build one combined dataset (audio, labels, utterance ids, text) and one
# DataLoader per split, keyed by split name.
total_dataset=dict()
total_dataloader=dict()
for dtype in ["train", "dev"]:
    cur_utts, cur_labs = utils.load_cat_emo_label(label_path, dtype)
    cur_wavs = utils.load_audio(audio_path, cur_utts)
    if dtype == "train":
        cur_wav_set = utils.WavSet(cur_wavs, normalize_wav=normalize_wav, use_tp = USE_TIMBRE_PERTURB, tp_prob= TP_PROB)
        # Persist the training normalisation statistics for later reuse.
        cur_wav_set.save_norm_stat(MODEL_PATH+"/train_norm_stat.pkl")
        # NOTE(review): assumes the 'Train' rows of `df` are in the same order
        # as `cur_utts` — confirm against utils.load_cat_emo_label.
        cur_txt_set = utils.TxtSet(df[df["Split_Set"] == 'Train'].transcription.tolist(), tokenizer)
    else:
        if dtype == "dev":
            # Reuse the statistics held by the training WavSet (first member
            # of the combined training dataset).
            wav_mean = total_dataset["train"].datasets[0].wav_mean
            wav_std = total_dataset["train"].datasets[0].wav_std
        elif dtype == "test":
            # Not reached while the loop only iterates over train/dev.
            wav_mean, wav_std = utils.load_norm_stat(MODEL_PATH+"/train_norm_stat.pkl")
        cur_wav_set = utils.WavSet(cur_wavs, wav_mean=wav_mean, wav_std=wav_std, normalize_wav=normalize_wav)
        cur_txt_set = utils.TxtSet(df[df["Split_Set"] == 'Development'].transcription.tolist(), tokenizer)
    ########################################################
    # Training uses the per-accumulation-step batch size; evaluation runs one
    # item at a time and is never shuffled.
    cur_bs = BATCH_SIZE // ACCUMULATION_STEP if dtype == "train" else 1
    is_shuffle = dtype == "train"
    ########################################################
    cur_emo_set = utils.CAT_EmoSet(cur_labs)
    total_dataset[dtype] = utils.CombinedSet([cur_wav_set, cur_emo_set, cur_utts, cur_txt_set])

    if balanced_batch and dtype == "train":
        logger.info('Using balanced batch')
        class_frequencies = train_df[classes].sum().to_dict()
        # Class weight = 1 / sqrt(frequency); classes that never occur get 0.
        class_weights_ = {cls: 1/np.sqrt(freq) if freq != 0 else 0 for cls, freq in class_frequencies.items()}
        weights_list_ = [class_weights_[cls] for cls in classes]
        # Kept on the CPU: the sampler only needs it to draw indices.
        class_weights_tensor_ = torch.tensor(weights_list_, dtype=torch.float)
        logger.info(f'Using balanced batch. Weights = {class_weights_tensor_}')

        # WeightedRandomSampler expects ONE weight PER SAMPLE and draws indices
        # in [0, len(weights)). Passing the 8 class weights directly would only
        # ever sample the first 8 items, so expand them to per-sample weights:
        # each row's weight is its label vector dotted with the class weights.
        # NOTE(review): assumes `train_df` rows are aligned with the training
        # dataset order (same assumption as the transcription list above).
        train_label_matrix = torch.tensor(train_df[classes].to_numpy(dtype=np.float32))
        sample_weights = train_label_matrix @ class_weights_tensor_
        if len(sample_weights) != len(total_dataset[dtype]):
            raise ValueError(
                f"Got {len(sample_weights)} sample weights for "
                f"{len(total_dataset[dtype])} training items; label table and dataset are misaligned."
            )

        sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(total_dataset[dtype]),
            replacement=True
        )
        total_dataloader[dtype] = DataLoader(
            total_dataset[dtype], batch_size=cur_bs, sampler=sampler,
            pin_memory=True, num_workers=4,
            collate_fn=utils.collate_fn_txt_wav_lab_mask
        )
    else:
        total_dataloader[dtype] = DataLoader(
            total_dataset[dtype], batch_size=cur_bs, shuffle=is_shuffle,
            pin_memory=True, num_workers=4,
            collate_fn=utils.collate_fn_txt_wav_lab_mask
        )

2024-12-26 23:29:13,389 - INFO - Starting an experimento in model path = ./experiments/baseline_wavlmbase_robertabase
2024-12-26 23:29:13,391 - INFO - Using ssl = microsoft/wavlm-base LR = 1e-05 Epochs = 20 Batch size = 32 Accum steps = 2
2024-12-26 23:29:13,392 - INFO - Using balanced batch = False
2024-12-26 23:29:13,393 - INFO - Using normalize wav = False
2024-12-26 23:29:13,394 - INFO - Using Timbre Perturbation = False


{'Angry': 1.2440944881889764, 'Sad': 1.327941642879797, 'Happy': 0.5009271998564335, 'Surprise': 2.840569877883311, 'Fear': 7.476785714285715, 'Disgust': 5.847765363128492, 'Contempt': 3.356312625250501, 'Neutral': 0.28635912868036795}


2024-12-26 23:29:15,586 - INFO - Class weights: tensor([1.2441, 1.3279, 0.5009, 2.8406, 7.4768, 5.8478, 3.3563, 0.2864],
       device='cuda:0')


tensor([1.2441, 1.3279, 0.5009, 2.8406, 7.4768, 5.8478, 3.3563, 0.2864],
       device='cuda:0')


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


00%|████████████████████████████████████████████████████████████████████████████| 25258/25258 [01:48<00:00, 233.27it/s]

In [7]:
# Pull a single collated mini-batch from the training DataLoader for manual inspection.
batch = next(iter(total_dataloader['train']))

In [10]:
# Display element 5 of the collated batch. The shown output is a 0/1 tensor —
# presumably a padding/attention mask; confirm against utils.collate_fn_txt_wav_lab_mask.
batch[5]

tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])

In [3]:
# Fetch the first item straight from a dataset (no collation).
# NOTE(review): relies on `dtype` leaking from the dataset-building loop (its
# last value there is "dev") and rebinds `batch` to a single item, not a batch.
batch = next(iter(total_dataset[dtype]))

In [4]:
# Display element 3 of the item. Position 3 in the CombinedSet list is the
# text set, so this is presumably its (token ids, attention mask) pair — TODO confirm.
batch[3]

(tensor([   0,    8,   38, 1266,    5, 1530,    6,  235,  116,  407,  734,    2,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,