In [1]:
import os
import random
import pickle
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
#from transformers import RobertaTokenizer, RobertaModel, Wav2Vec2FeatureExtractor
#from transformers.models.wavlm import WavLMModel

In [2]:
class PrecomputedEmbeddingsDataset(Dataset):
    """Dataset that eagerly loads every ``.pt`` file of a directory into memory.

    Each item is the tuple ``(text_sequence, audio_sequence, label)`` read from
    one file. Files are visited in sorted filename order so that indices are
    stable across machines and runs.
    """

    def __init__(self, embeddings_dir):
        """
        Args:
            embeddings_dir: directory whose ``.pt`` files are loaded.

        dictionary in .pt:
          {
            'text_sequence': [num_sentences, text_dim]
            'audio_sequence': [num_sentences, audio_dim]
            'label': int
          }
        """
        self.embeddings_dir = embeddings_dir
        self.data = []
        self._load_data()

    def _load_data(self):
        """Populate ``self.data`` from the ``.pt`` files in ``self.embeddings_dir``.

        A file that cannot be loaded, or that lacks one of the expected keys,
        is reported on stdout and skipped instead of aborting the whole load.
        """
        # Start from a clean list so that calling this again does not duplicate sessions.
        self.data = []

        # os.listdir returns entries in arbitrary order; sort for reproducibility.
        all_files = sorted(
            f for f in os.listdir(self.embeddings_dir) if f.endswith('.pt')
        )
        if not all_files:
            print(f"No .pt files found in {self.embeddings_dir}")

        for fname in all_files:
            file_path = os.path.join(self.embeddings_dir, fname)
            try:
                # NOTE(review): torch.load unpickles the file; only point this at
                # embedding files you produced yourself. Consider passing
                # weights_only=True once the minimum supported torch version
                # is known to accept it.
                saved_data = torch.load(file_path, map_location='cpu')
                text_sequence = saved_data['text_sequence']
                audio_sequence = saved_data['audio_sequence']
                label = saved_data['label']
                self.data.append((text_sequence, audio_sequence, label))
            except Exception as e:
                # Best effort: one bad file should not prevent using the rest.
                print(f"Error loading {file_path}: {e}")

        print(f"Loaded {len(self.data)} sessions from {self.embeddings_dir}")

    def __len__(self):
        """Return the number of successfully loaded sessions."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tuple ``(text_sequence, audio_sequence, label)`` at ``idx``."""
        return self.data[idx]

def _pad_with_mask(seqs):
    """Zero-pad a list of sequences along dim 0 and build the matching 0/1 mask.

    Returns ``(padded, mask)`` where ``padded`` stacks the sequences on a new
    leading batch dimension (keeping their dtype) and ``mask`` is a ``torch.long``
    tensor of shape ``[len(seqs), max_len]`` with 1 on real positions.
    """
    if not seqs:
        # pad_sequence rejects an empty list; keep the empty-batch result explicit.
        return torch.empty(0), torch.zeros(0, 0, dtype=torch.long)

    lengths = torch.tensor([seq.size(0) for seq in seqs], dtype=torch.long)
    # pad_sequence allocates the output from the first sequence, so the
    # embeddings' dtype is preserved instead of being promoted to float32.
    padded = pad_sequence(seqs, batch_first=True)
    positions = torch.arange(padded.size(1)).unsqueeze(0)
    mask = (positions < lengths.unsqueeze(1)).long()
    return padded, mask


def collate_fn(batch):
    """Collate ``(text_sequence, audio_sequence, label)`` items into batch tensors.

    Text and audio sequences are padded independently, each to the longest
    sequence of its own modality within the batch.

    Args:
        batch: list of ``(text_sequence, audio_sequence, label)`` tuples.

    Returns:
        ``(padded_text, padded_audio, text_mask, audio_mask, labels)`` where the
        masks are ``torch.long`` tensors holding 1 for real positions and 0 for
        padding, and ``labels`` is a ``torch.long`` tensor.
    """
    text_seqs = [item[0] for item in batch]   # text embeddings
    audio_seqs = [item[1] for item in batch]  # audio embeddings
    labels = [item[2] for item in batch]

    padded_text, text_mask = _pad_with_mask(text_seqs)
    padded_audio, audio_mask = _pad_with_mask(audio_seqs)
    labels = torch.tensor(labels, dtype=torch.long)

    return padded_text, padded_audio, text_mask, audio_mask, labels

In [7]:
# Root folder containing the 'train', 'validate' and 'test' embedding splits.
# Set the DD_EMBEDDINGS_ROOT environment variable to run on another machine;
# when it is unset, the original local path is used.
EMBEDDINGS_ROOT = os.environ.get(
    'DD_EMBEDDINGS_ROOT',
    '/home/popsatorn/Desktop/DD_FinalProject/Embeddings_Base',
)

train_embedding_dir = os.path.join(EMBEDDINGS_ROOT, 'train')
val_embedding_dir = os.path.join(EMBEDDINGS_ROOT, 'validate')
test_embedding_dir = os.path.join(EMBEDDINGS_ROOT, 'test')

In [9]:
# Training data is reshuffled every epoch.
train_dataset = PrecomputedEmbeddingsDataset(train_embedding_dir)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# Evaluation loaders keep a fixed order so per-sample outputs are reproducible.
val_dataset = PrecomputedEmbeddingsDataset(val_embedding_dir)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# test_loader must wrap test_dataset (not val_dataset) so that test metrics
# are computed on the held-out test split.
test_dataset = PrecomputedEmbeddingsDataset(test_embedding_dir)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

  saved_data = torch.load(file_path, map_location='cpu')


Loaded 127 sessions from /home/popsatorn/Desktop/DD_FinalProject/Embeddings_Base/train
Loaded 30 sessions from /home/popsatorn/Desktop/DD_FinalProject/Embeddings_Base/validate
Loaded 31 sessions from /home/popsatorn/Desktop/DD_FinalProject/Embeddings_Base/test
