In [1]:
# -*- coding: UTF-8 -*-
# Local modules
import os
import sys
import argparse
# 3rd-Party Modules
import numpy as np
import pickle as pk
import pandas as pd
from tqdm import tqdm
import glob
import librosa
import copy
import logging
import time 

# PyTorch Modules
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader
import torch.optim as optim
from transformers import AutoModel
import importlib
# Self-Written Modules
# sys.path.append(os.getcwd())
sys.path.append(os.getcwd())
sys.path.append('../')
from benchmark import net
from benchmark import utils
from benchmark.utils.dataset import dataset as func
from torch.utils.data import WeightedRandomSampler

In [12]:

import json
from collections import defaultdict
# --- Experiment configuration --------------------------------------------------
config_path = '../configs/config_cat_nonorm_timbreperturb.json'
with open(config_path, "r") as f:
    config = json.load(f)
audio_path = config["wav_dir"]
label_path = config["label_path"]

SSL_TYPE = utils.get_ssl_type(config['ssl_type'])
# The assert message must be a value. Passing print(...) only printed on failure
# and then raised an AssertionError whose message was None.
assert SSL_TYPE is not None, "Invalid SSL type!"

# Training hyper-parameters (batch_size, accum_step, epochs, lr) are not read
# here; they were commented out in this notebook.
MODEL_PATH = config['model_path']
os.makedirs(MODEL_PATH, exist_ok=True)
HEAD_DIM = config['head_dim']
POOLING_TYPE = config['pooling_type']

USE_TIMBRE_PERTURB = config['use_timbre_perturb']

# Optional keys. dict.get only falls back when the key is missing, whereas the
# previous bare `except:` would have hidden any other error behind the default.
balanced_batch = config.get("use_balanced_batch", False)
normalize_wav = config.get("normalize_wav", True)

import pandas as pd
import numpy as np

# --- Class weights from the training split --------------------------------------
df = pd.read_csv(label_path)
train_df = df[df['Split_Set'] == 'Train']

# Emotion label columns in the CSV and their one-letter codes, in the same order.
classes = ['Angry', 'Sad', 'Happy', 'Surprise', 'Fear', 'Disgust', 'Contempt', 'Neutral']
classes_ = ['A', 'S', 'H', 'U', 'F', 'D', 'C', 'N']
assert len(classes) == len(classes_), "classes and classes_ must align one-to-one"

# argmax index -> one-letter class code
map_argmax = dict(enumerate(classes_))

# Per-class label mass (column sums) and number of training rows.
class_frequencies = train_df[classes].sum().to_dict()
total_samples = len(train_df)

# "Balanced" weights: total / (n_classes * freq). Classes with zero mass get weight 0.
class_weights = {
    cls: total_samples / (len(classes) * freq) if freq != 0 else 0
    for cls, freq in class_frequencies.items()
}
print(class_weights)

# Weights as a tensor in the fixed `classes` order. Fall back to CPU so the
# notebook still runs on a machine without a GPU.
weights_list = [class_weights[cls] for cls in classes]
class_weights_tensor = torch.tensor(
    weights_list,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    dtype=torch.float,
)
print(class_weights_tensor)

# import json
# from collections import defaultdict
# config_path = "configs/config_cat.json"
# with open(config_path, "r") as f:
#     config = json.load(f)
# audio_path = config["wav_dir"]
# label_path = config["label_path"]

# --- Dev-split dataset / dataloader ---------------------------------------------
dtype = "dev"  # split name passed to utils.load_cat_emo_label (not a numpy dtype)
# Normalisation statistics file; presumably written by the timbre-perturbation
# baseline training run named in the path — confirm it matches this config.
NORM_STAT_PATH = '../experiments/baseline_wavlmbase_nonormwav_timbreperturb/train_norm_stat.pkl'

total_dataset = dict()
total_dataloader = dict()

# Load labels and audio once (load_cat_emo_label was previously called twice).
cur_utts, cur_labs = utils.load_cat_emo_label(label_path, dtype)
cur_wavs = utils.load_audio(audio_path, cur_utts)
wav_mean, wav_std = utils.load_norm_stat(NORM_STAT_PATH)

cur_emo_set = utils.CAT_EmoSet(cur_labs)
# use_tp=True is deliberate here: the following cells inspect and listen to dev
# audio with timbre perturbation enabled. Proper evaluation never uses tp, so an
# evaluation pipeline should pass use_tp=False.
cur_wav_set = utils.WavSet(
    cur_wavs, wav_mean=wav_mean, wav_std=wav_std,
    normalize_wav=normalize_wav, use_tp=True,
)
total_dataset[dtype] = utils.CombinedSet([cur_wav_set, cur_emo_set, cur_utts])
total_dataloader[dtype] = DataLoader(
    total_dataset[dtype], batch_size=1, shuffle=False,
    pin_memory=True, num_workers=4,
    collate_fn=utils.collate_fn_wav_lab_mask,
)

{'Angry': 1.2440944881889764, 'Sad': 1.327941642879797, 'Happy': 0.5009271998564335, 'Surprise': 2.840569877883311, 'Fear': 7.476785714285715, 'Disgust': 5.847765363128492, 'Contempt': 3.356312625250501, 'Neutral': 0.28635912868036795}
tensor([1.2441, 1.3279, 0.5009, 2.8406, 7.4768, 5.8478, 3.3563, 0.2864],
       device='cuda:0')



00%|████████████████████████████████████████████████████████████████████████████| 25258/25258 [01:07<00:00, 375.89it/s]

In [13]:
# Is timbre perturbation enabled on the waveform sub-dataset
# (the first component handed to CombinedSet)?
wav_component = total_dataset[dtype].datasets[0]
wav_component.use_tp

True

In [14]:
# Fetch one dev example with normal indexing rather than an explicit dunder call.
SAMPLE_INDEX = 200  # dev-set example to inspect
data = total_dataset[dtype][SAMPLE_INDEX]
# Element 0 unpacks into the waveform and a second value `i`.
# NOTE(review): `i` is not used afterwards — confirm what WavSet returns there.
audio, i = data[0]

In [15]:
# Play the selected dev clip; rate=16000 means the waveform is treated as 16 kHz.
# Import the display submodule explicitly instead of relying on `import IPython`
# having already loaded IPython.display.
from IPython.display import Audio
Audio(audio, rate=16000)


In [16]:
# Replays the same clip as the previous cell (duplicate cell; candidate for removal).
import IPython
player = IPython.display.Audio(data=audio, rate=16000)
player


In [17]:
# Raw sample values of the selected clip (numpy elides the middle of long arrays).
clip_preview = audio
clip_preview

array([1.34755422e-06, 1.29163480e-05, 3.21448201e-05, ...,
       2.86184322e-04, 1.48646295e-04, 5.10936246e-05])

In [18]:
# Fixed timbre perturbation of the same clip, played in the next cell for an A/B listen.
audio2 = func.fixed_timbre_perturb(
    audio,
    formant_rate=5,
)

In [19]:
# Play the perturbed clip at the same 16 kHz rate used for the original clip.
# Import the display submodule explicitly instead of relying on `import IPython`
# having already loaded IPython.display.
from IPython.display import Audio
Audio(audio2, rate=16000)


In [32]:
# --- Class-balanced sampling over the dev split ----------------------------------
# (An earlier variant used 1/sqrt(freq) weights and logged them via `logger`,
#  which is not defined in this notebook.)

# Inverse-frequency weights from training-split label mass, rescaled so they average 1.
class_frequencies = train_df[classes].sum().to_dict()
class_weights = {cls: 1 / freq if freq != 0 else 0 for cls, freq in class_frequencies.items()}
factor = len(class_weights) / sum(class_weights.values())
class_weights = {cls: w * factor for cls, w in class_weights.items()}

val_df = df[df['Split_Set'] == 'Development']

# Per-sample weight = weight of the row's argmax class. This is vectorised instead
# of calling .iloc[i].idxmax() once per row in a Python loop.
sample_weights = val_df[classes].idxmax(axis=1).map(class_weights).tolist()

# The sampler draws positional indices into total_dataset[dtype] using these
# weights, so the two must line up. NOTE(review): this assumes
# load_cat_emo_label(..., "dev") yields the 'Development' rows in CSV order — confirm.
assert len(sample_weights) == len(total_dataset[dtype]), (
    f"{len(sample_weights)} sample weights vs {len(total_dataset[dtype])} dataset items"
)

SAMPLER_SEED = 0  # fixed so the sampled batches are reproducible across runs
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(total_dataset[dtype]),
    replacement=True,
    generator=torch.Generator().manual_seed(SAMPLER_SEED),
)
total_dataloader[dtype] = DataLoader(
    total_dataset[dtype], batch_size=32, sampler=sampler,
    pin_memory=True, num_workers=4,
    collate_fn=utils.collate_fn_wav_lab_mask,
)

In [34]:
# Class names in label-column order; index k in the batch check corresponds to classes[k].
list(classes)

['Angry', 'Sad', 'Happy', 'Surprise', 'Fear', 'Disgust', 'Contempt', 'Neutral']

In [33]:
# Sanity-check the balanced sampler: per-class share (%) of the argmax label in
# each batch. The break after batch 10 means batches 0..10 are printed.
for j, batch in enumerate(total_dataloader[dtype]):
    print(f'#### BATCH {j} ####')

    y = batch[1].max(dim=1)[1]  # argmax class index per sample
    counts = torch.bincount(y, minlength=len(classes))
    for i in range(len(classes)):
        # Divide by the real batch length, not a hard-coded 32, so a short final
        # batch or a different batch_size is reported correctly.
        cnt = 100 * (counts[i].item() / len(y))
        print(f"Class {i} cnt = {cnt}")

    if j >= 10:
        break

#### BATCH 0 ####
Class 0 cnt = 12.5
Class 1 cnt = 9.375
Class 2 cnt = 15.625
Class 3 cnt = 9.375
Class 4 cnt = 9.375
Class 5 cnt = 9.375
Class 6 cnt = 18.75
Class 7 cnt = 15.625
#### BATCH 1 ####
Class 0 cnt = 21.875
Class 1 cnt = 6.25
Class 2 cnt = 9.375
Class 3 cnt = 6.25
Class 4 cnt = 15.625
Class 5 cnt = 6.25
Class 6 cnt = 21.875
Class 7 cnt = 12.5
#### BATCH 2 ####
Class 0 cnt = 12.5
Class 1 cnt = 18.75
Class 2 cnt = 9.375
Class 3 cnt = 12.5
Class 4 cnt = 0.0
Class 5 cnt = 9.375
Class 6 cnt = 25.0
Class 7 cnt = 12.5
#### BATCH 3 ####
Class 0 cnt = 25.0
Class 1 cnt = 6.25
Class 2 cnt = 9.375
Class 3 cnt = 9.375
Class 4 cnt = 9.375
Class 5 cnt = 9.375
Class 6 cnt = 18.75
Class 7 cnt = 12.5
#### BATCH 4 ####
Class 0 cnt = 31.25
Class 1 cnt = 6.25
Class 2 cnt = 3.125
Class 3 cnt = 12.5
Class 4 cnt = 0.0
Class 5 cnt = 21.875
Class 6 cnt = 25.0
Class 7 cnt = 0.0
#### BATCH 5 ####
Class 0 cnt = 21.875
Class 1 cnt = 15.625
Class 2 cnt = 6.25
Class 3 cnt = 15.625
Class 4 cnt = 9.375
Class