In [None]:
import re
import pandas as pd
import numpy as np
import yaml
import torch
import torchaudio
from pathlib import Path
from transformers import AutoTokenizer, AutoModel, Wav2Vec2Processor, Wav2Vec2Model
from tqdm import tqdm

## Carregamento e Transformação

In [None]:
def to_seconds(t):
    """Convert an ``HH:MM:SS,mmm`` timestamp string into seconds as a float.

    The string is split on ``:`` into exactly three fields and the last field
    is split on ``,`` into seconds and milliseconds; any other layout raises
    ``ValueError`` from the unpacking or from ``int``.
    """
    hours, minutes, tail = t.split(":")
    seconds, millis = tail.split(",")
    # Sum the integer part first, then add the fractional milliseconds.
    whole = int(hours) * 3600 + int(minutes) * 60 + int(seconds)
    return whole + int(millis) / 1000

def to_snake_case(name: str) -> str:
    """Normalise a label (such as a column name) to lower-case with underscores.

    Surrounding whitespace is trimmed, every character that is neither a word
    character nor whitespace is removed, each run of whitespace becomes a
    single ``_``, and the result is lower-cased. Existing capital letters are
    not split into separate words.
    """
    without_punct = re.sub(r"[^\w\s]", "", name.strip())
    return re.sub(r"\s+", "_", without_punct).lower()

def prepare_df(df: pd.DataFrame):
    """Derive timing columns, drop unused columns and snake_case the headers.

    Args:
        df: Frame containing at least the columns ``StartTime``, ``EndTime``,
            ``Sr No.``, ``Season``, ``Episode`` and ``Sentiment``. The two
            time columns must hold strings accepted by ``to_seconds``.

    Returns:
        A new DataFrame with ``start_s``, ``end_s`` and ``duration_s`` added
        (seconds, as computed by ``to_seconds``), the six columns listed in
        the ``drop`` call removed, and every remaining column name passed
        through ``to_snake_case``. The input frame is left untouched.

    Raises:
        KeyError: if any required column is missing.
    """
    start_s = df["StartTime"].apply(to_seconds)
    end_s = df["EndTime"].apply(to_seconds)

    # ``assign`` returns a new frame, so the caller's DataFrame is not mutated.
    prepared = df.assign(
        start_s=start_s,
        end_s=end_s,
        duration_s=end_s - start_s,
    )

    prepared = prepared.drop(
        columns=["Sr No.", "StartTime", "EndTime", "Season", "Episode", "Sentiment"]
    )

    prepared.columns = [to_snake_case(c) for c in prepared.columns]

    return prepared

In [None]:
# Read the training annotations CSV and normalise it with prepare_df.
train_df = prepare_df(df=pd.read_csv("./data/train_sent_emo.csv"))

## Funções base

In [None]:
def load_config(path):
    """Parse a YAML file and return its content.

    Args:
        path: Location of the YAML file (anything ``open`` accepts).

    Returns:
        Whatever ``yaml.safe_load`` produces for the document.

    Raises:
        OSError: if the file cannot be opened.
    """
    # Pin the encoding so parsing does not depend on the platform's locale default.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

In [None]:
class BaseExtractor:
    """Base class for extractors that store one ``.npy`` file per utterance.

    Instantiating the class creates ``save_dir`` (and any missing parents).
    """

    def __init__(self, save_dir):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def save(self, dialogue_id, utterance_id, vector):
        """Write ``vector`` to ``<save_dir>/<dialogue_id>_<utterance_id>.npy``.

        Args:
            dialogue_id: Value formatted into the first part of the file name.
            utterance_id: Value formatted into the second part of the file name.
            vector: Array-like object handed to ``np.save``.

        Returns:
            The path of the written file, as a string.
        """
        path = self.save_dir / f"{dialogue_id}_{utterance_id}.npy"

        # The progress helpers in this notebook treat an existing file as
        # "already extracted", so a half-written file must never appear under
        # the final name. Write to a sibling temp file, then rename it over.
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            with open(tmp_path, "wb") as fh:
                np.save(fh, vector)
            tmp_path.replace(path)
        except BaseException:
            # Also covers KeyboardInterrupt: remove the partial file, then re-raise.
            if tmp_path.exists():
                tmp_path.unlink()
            raise
        return str(path)

## Features de Texto

In [None]:
class TextExtractor(BaseExtractor):
    """Text embedder built on a pretrained tokenizer/encoder pair.

    The tokenizer and encoder are both loaded from ``model_name``; the encoder
    is moved to ``device``. ``save_dir`` is handled by ``BaseExtractor``.
    """

    def __init__(self, model_name, save_dir, device="cpu"):
        super().__init__(save_dir)
        self.device = device
        self.tok = AutoTokenizer.from_pretrained(model_name)
        encoder = AutoModel.from_pretrained(model_name)
        self.model = encoder.to(device)

    def extract(self, text):
        """Return the first-token ([CLS] position) hidden state for ``text``.

        The input is tokenized with truncation and padding enabled, run through
        the encoder without gradient tracking, and the position-0 vector of the
        last hidden state is squeezed and returned as a NumPy array.
        """
        encoded = self.tok(text, return_tensors="pt", truncation=True, padding=True)
        encoded = encoded.to(self.device)
        with torch.no_grad():
            hidden = self.model(**encoded).last_hidden_state
            first_token = hidden[:, 0, :]  # [CLS]
        return first_token.squeeze().cpu().numpy()

In [None]:
def extract_text_with_progress(df, extractor, save_dir):
    """
    Extract features for all utterances, resuming if interrupted.

    - Skips rows whose ``<dialogue_id>_<utterance_id>.npy`` already exists in ``save_dir``
    - Displays tqdm progress bar

    Args:
        df: Frame with ``dialogue_id``, ``utterance_id`` and ``utterance`` columns.
        extractor: Object exposing ``extract(text)`` and
            ``save(dialogue_id, utterance_id, vector)``.
        save_dir: Directory checked for existing feature files. Writing is
            delegated to ``extractor.save``, so this should be the directory
            the extractor writes to; otherwise nothing is ever seen as done.

    Returns:
        None. Progress and a summary are printed.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    total = len(df)

    def _feature_path(row):
        # Same naming scheme the extractor's ``save`` is called with below.
        return save_dir / f"{row['dialogue_id']}_{row['utterance_id']}.npy"

    # Check the exact files expected for this frame. Counting every ``*.npy``
    # in the directory would wrongly report "done" when the directory also
    # holds files that belong to other rows or another dataset split.
    if all(_feature_path(row).exists() for _, row in df.iterrows()):
        print(f"✅ Already processed {len(df)} utterances. Skipping.")
        return

    processed = 0
    for _, row in tqdm(df.iterrows(), total=total, desc="Extracting", ncols=80):
        if not _feature_path(row).exists():
            vec = extractor.extract(row["utterance"])
            extractor.save(row["dialogue_id"], row["utterance_id"], vec)

        # Rows that were already on disk count as processed too.
        processed += 1

    print(f"\n✅ Completed: {processed}/{total} utterances processed.")
    print(f"Features saved in: {save_dir}")

### RoBERTa


In [None]:
# Read configs/text/roberta.yaml, build the text extractor from it, then run
# the resumable extraction over the training frame.
cfg = load_config(path="configs/text/roberta.yaml")
extractor = TextExtractor(
    model_name=cfg["model_name"],
    save_dir=cfg["save_dir"],
    device=cfg["device"],
)

extract_text_with_progress(df=train_df, extractor=extractor, save_dir=cfg["save_dir"])

### DistilBERT

In [None]:
# Read configs/text/distilbert.yaml, build the text extractor from it, then run
# the resumable extraction over the training frame.
cfg = load_config(path="configs/text/distilbert.yaml")
extractor = TextExtractor(
    model_name=cfg["model_name"],
    save_dir=cfg["save_dir"],
    device=cfg["device"],
)

extract_text_with_progress(df=train_df, extractor=extractor, save_dir=cfg["save_dir"])

### MPNet

In [None]:
# Read configs/text/mpnet.yaml, build the text extractor from it, then run
# the resumable extraction over the training frame.
cfg = load_config(path="configs/text/mpnet.yaml")
extractor = TextExtractor(
    model_name=cfg["model_name"],
    save_dir=cfg["save_dir"],
    device=cfg["device"],
)

extract_text_with_progress(df=train_df, extractor=extractor, save_dir=cfg["save_dir"])

## Features de Áudio

In [None]:
class AudioExtractor(BaseExtractor):
    """Turn one audio file into a single feature vector.

    The backend is selected from ``model_name`` (compared case-insensitively):

    - ``"mfcc"``: ``torchaudio.transforms.MFCC``
    - ``"melspectrogram"``: ``torchaudio.transforms.MelSpectrogram``
    - any name containing ``"wav2vec"``: a pretrained ``Wav2Vec2Model`` with
      its ``Wav2Vec2Processor``

    ``cfg`` is a mapping read with ``.get``. Keys used: ``sample_rate``
    (default 16000), ``n_mfcc``, ``n_fft``, ``hop_length``, ``n_mels`` and
    ``max_len_seconds`` (default 30).

    Raises:
        ValueError: if ``model_name`` matches none of the backends above.
    """

    def __init__(self, model_name, save_dir, cfg):
        super().__init__(save_dir)
        # The lower-cased copy is used for backend dispatch only.
        self.model_name = model_name.lower()
        self.cfg = cfg
        self.sample_rate = cfg.get("sample_rate", 16000)

        # --- MFCC ---
        if self.model_name == "mfcc":
            self.extractor = torchaudio.transforms.MFCC(
                sample_rate=self.sample_rate,
                n_mfcc=cfg.get("n_mfcc", 13),
                melkwargs={
                    "n_fft": cfg.get("n_fft", 400),
                    "hop_length": cfg.get("hop_length", 160),
                    "n_mels": cfg.get("n_mels", 23),
                },
            )

        # --- MelSpectrogram ---
        elif self.model_name == "melspectrogram":
            self.extractor = torchaudio.transforms.MelSpectrogram(
                sample_rate=self.sample_rate,
                n_fft=cfg.get("n_fft", 400),
                hop_length=cfg.get("hop_length", 160),
                n_mels=cfg.get("n_mels", 64),
            )

        # --- Wav2Vec2 / Wav2Vec2-like ---
        elif "wav2vec" in self.model_name:
            # Load with the identifier exactly as given: it may be a local
            # checkpoint directory, and paths are case-sensitive on many
            # filesystems, so the lower-cased copy must not be used here.
            self.processor = Wav2Vec2Processor.from_pretrained(
                model_name, use_safetensors=True
            )
            self.model = Wav2Vec2Model.from_pretrained(model_name)
            self.model.eval()
        else:
            raise ValueError(f"Unsupported model: {self.model_name}")

    # ----------------------------------------------------------------------

    def extract(self, audio_path: str):
        """Load ``audio_path`` and return its feature vector as a NumPy array.

        The audio is down-mixed to one channel, resampled to
        ``self.sample_rate`` when needed, and truncated to
        ``max_len_seconds``. MFCC / mel features are averaged over the last
        axis of the transform output; Wav2Vec2 hidden states are averaged
        over axis 1 of ``last_hidden_state``.
        """
        # Load the audio directly from .mp4 (requires ffmpeg in system or Conda)
        waveform, sr = torchaudio.load(audio_path)

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample if necessary
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sr, self.sample_rate
            )

        # Ensure consistent max duration (avoid huge clips)
        max_len_seconds = self.cfg.get("max_len_seconds", 30)
        max_len_samples = int(max_len_seconds * self.sample_rate)
        if waveform.shape[1] > max_len_samples:
            waveform = waveform[:, :max_len_samples]

        # -------- MFCC or MelSpectrogram --------
        if self.model_name in ["mfcc", "melspectrogram"]:
            feat = self.extractor(waveform)
            vec = feat.mean(dim=-1).squeeze().numpy()
            return vec

        elif "wav2vec" in self.model_name:
            with torch.no_grad():
                # Hand the processor a NumPy array, its documented input type,
                # rather than a torch tensor.
                inputs = self.processor(
                    waveform.squeeze().numpy(),
                    sampling_rate=self.sample_rate,
                    return_tensors="pt",
                    padding=True,
                )
                outputs = self.model(**inputs).last_hidden_state
                vec = outputs.mean(dim=1).squeeze().cpu().numpy()
            return vec

In [None]:
def extract_audio_with_progress(df: pd.DataFrame, extractor: AudioExtractor, save_dir: str, audio_dir: str):
    """
    Extract and save audio features row by row, skipping work already on disk.

    For each row, the clip ``<audio_dir>/dia<dialogue_id>_utt<utterance_id>.mp4``
    is passed to ``extractor.extract`` and the result to ``extractor.save``,
    unless ``<save_dir>/<dialogue_id>_<utterance_id>.npy`` already exists.
    A failure on one clip is printed and the loop moves on; such rows are not
    counted in the final summary.
    """
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    total = len(df)
    done = 0

    progress = tqdm(df.iterrows(), total=total, desc="Audio extracting", ncols=80)
    for _, row in progress:
        dialogue_id = row["dialogue_id"]
        utterance_id = row["utterance_id"]

        feature_file = out_dir / f"{dialogue_id}_{utterance_id}.npy"
        if not feature_file.exists():
            clip_path = f"{audio_dir}/dia{dialogue_id}_utt{utterance_id}.mp4"
            try:
                features = extractor.extract(clip_path)
                extractor.save(dialogue_id, utterance_id, features)
            except Exception as e:
                # Best effort: report the clip that failed and keep going.
                print(f"⚠️ Error on {clip_path}: {e}")
                continue

        # Reached for both freshly extracted rows and rows found on disk.
        done += 1

    print(f"\n✅ Completed {done}/{total} utterances.")

### MFCC

In [None]:
# Read configs/audio/mfcc.yaml, build the audio extractor from it, then run
# the resumable extraction over the training clips.
cfg = load_config(path="configs/audio/mfcc.yaml")
extractor = AudioExtractor(
    model_name=cfg["model_name"],
    save_dir=cfg["save_dir"],
    cfg=cfg,
)

extract_audio_with_progress(
    df=train_df,
    extractor=extractor,
    save_dir=cfg["save_dir"],
    audio_dir="./data/train/train_splits",
)

### MelSpectrogram

In [None]:
# Read configs/audio/melspec.yaml, build the audio extractor from it, then run
# the resumable extraction over the training clips.
cfg = load_config(path="configs/audio/melspec.yaml")
extractor = AudioExtractor(
    model_name=cfg["model_name"],
    save_dir=cfg["save_dir"],
    cfg=cfg,
)

extract_audio_with_progress(
    df=train_df,
    extractor=extractor,
    save_dir=cfg["save_dir"],
    audio_dir="./data/train/train_splits",
)

### Wav2Vec2

In [None]:
# Read configs/audio/wav2vec2.yaml, build the audio extractor from it, then run
# the resumable extraction over the training clips.
cfg = load_config(path="configs/audio/wav2vec2.yaml")
extractor = AudioExtractor(
    model_name=cfg["model_name"],
    save_dir=cfg["save_dir"],
    cfg=cfg,
)

extract_audio_with_progress(
    df=train_df,
    extractor=extractor,
    save_dir=cfg["save_dir"],
    audio_dir="./data/train/train_splits",
)