In [1]:
!pip install av
!pip install torch
!pip install librosa
!pip install os
!pip install json
!pip install torchvision
!pip install torchaudio
!pip install h5py

Collecting av
  Downloading av-14.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.7 kB)
Downloading av-14.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (35.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m35.4/35.4 MB[0m [31m102.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: av
Successfully installed av-14.3.0
Collecting torch
  Downloading torch-2.7.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting filelock (from torch)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2025.3.2-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.

In [1]:
import numpy as np
import librosa
import os
import json
import av
import h5py
import matplotlib.pyplot as plt
import torch.optim as optim
from sklearn.model_selection import train_test_split
from collections import defaultdict
from collections import Counter

import torch
import torchvision
import torchaudio
import torch.nn as nn
import torch.nn.functional as F

from torchvision.io import read_video
import torchvision.models.video as video_models
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.models.video import r3d_18, R3D_18_Weights
from sklearn.model_selection import StratifiedShuffleSplit

In [2]:
# Check if CUDA (NVIDIA GPU) is available
gpu_available = torch.cuda.is_available()
print(f"CUDA Available: {gpu_available}")

if gpu_available:
    # Get the GPU name
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    
    # Get the number of GPUs
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    
    # Get the current GPU memory usage
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
    print(f"GPU Memory Reserved: {torch.cuda.memory_reserved(0) / 1024**2:.2f} MB")
    
    # Test a simple tensor operation on GPU
    x = torch.rand(3, 3).cuda()
    print("Tensor successfully created on GPU:", x)
else:
    print("CUDA is not available. Running on CPU.")

CUDA Available: True
GPU Name: NVIDIA H100 80GB HBM3
Number of GPUs: 1
GPU Memory Allocated: 0.00 MB
GPU Memory Reserved: 0.00 MB
Tensor successfully created on GPU: tensor([[0.0941, 0.4517, 0.0067],
        [0.8339, 0.7428, 0.6903],
        [0.2682, 0.0352, 0.9358]], device='cuda:0')


In [3]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False

# data split

In [4]:
# Extract participant ID from the key
def extract_participant_id(key):
    parts = key.split('_')
    if len(parts) < 3:
        return None
    if 'repeat' in parts:
        return parts[1] + '_' + parts[2] + '_' + parts[3]
    else:
        return parts[0] + '_' + parts[1] + '_' + parts[2]

# Stratified split at participant level (since each participant has only one label)
def stratified_participant_split(participant_to_label, train_ratio=0.7, val_ratio=0.15, random_state=42):
    participants = list(participant_to_label.keys())
    labels = [participant_to_label[pid] for pid in participants]

    sss1 = StratifiedShuffleSplit(n_splits=1, train_size=train_ratio, random_state=random_state)
    train_idx, temp_idx = next(sss1.split(participants, labels))

    temp_participants = [participants[i] for i in temp_idx]
    temp_labels = [labels[i] for i in temp_idx]

    sss2 = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=random_state)
    val_idx, test_idx = next(sss2.split(temp_participants, temp_labels))

    train_pids = [participants[i] for i in train_idx]
    val_pids = [temp_participants[i] for i in val_idx]
    test_pids = [temp_participants[i] for i in test_idx]

    return train_pids, val_pids, test_pids

# Print label distribution
def print_label_distribution(data, split_name="Split"):
    labels = [entry[2] for entry in data]
    label_counter = Counter(labels)
    total = sum(label_counter.values())

    print(f"\nLabel distribution in {split_name}:")
    for label in sorted(label_counter):
        count = label_counter[label]
        percentage = (count / total) * 100
        print(f"  Class {label}: {count} samples ({percentage:.2f}%)")

# Load data and perform participant-disjoint stratified split
def load_data_and_split(file_path='data_copy.list'):
    data_list = []
    participant_id_list = []
    participant_to_label = {}
    pid_to_indices = defaultdict(list)
    skipped = 0
    loaded = 0

    with open(file_path, 'r') as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip():
                continue
            try:
                entry = json.loads(line.strip())
                if 'key' not in entry or 'wav_path' not in entry or 'label' not in entry:
                    skipped += 1
                    continue

                key = entry['key']
                wav_path = os.path.join('audio', os.path.basename(entry['wav_path']))
                video_path = os.path.join('video', os.path.splitext(os.path.basename(entry['wav_path']))[0] + '.avi')
                label = int(entry['label'])

                if not os.path.exists(wav_path) or not os.path.exists(video_path):
                    skipped += 1
                    continue

                participant_id = extract_participant_id(key)
                if participant_id is None:
                    skipped += 1
                    continue

                # Make sure all entries of the same participant have the same label
                if participant_id in participant_to_label:
                    if participant_to_label[participant_id] != label:
                        print(f"Label mismatch for participant {participant_id} at line {line_num}")
                        skipped += 1
                        continue
                else:
                    participant_to_label[participant_id] = label

                index = len(data_list)
                data_list.append((wav_path, video_path, label, key))
                participant_id_list.append(participant_id)
                pid_to_indices[participant_id].append(index)
                loaded += 1

            except json.JSONDecodeError:
                skipped += 1

    if not data_list:
        raise ValueError("No valid data loaded.")

    print(f"Loaded {loaded} valid samples, skipped {skipped}.")
    print(f"Unique participants: {len(participant_to_label)}")

    # Stratified participant split
    train_pids, val_pids, test_pids = stratified_participant_split(participant_to_label)

    train_indices = [i for pid in train_pids for i in pid_to_indices[pid]]
    val_indices = [i for pid in val_pids for i in pid_to_indices[pid]]
    test_indices = [i for pid in test_pids for i in pid_to_indices[pid]]

    np.save('train_indices.npy', train_indices)
    np.save('val_indices.npy', val_indices)
    np.save('test_indices.npy', test_indices)

    train_data = [data_list[i] for i in train_indices]
    val_data = [data_list[i] for i in val_indices]
    test_data = [data_list[i] for i in test_indices]

    print_label_distribution(train_data, "Train")
    print_label_distribution(val_data, "Validation")
    print_label_distribution(test_data, "Test")

    return train_data, val_data, test_data, train_indices, val_indices, test_indices

# Entry point
def main(file_path='data_copy.list'):
    return load_data_and_split(file_path)

if __name__ == "__main__":
    train_data, val_data, test_data, train_indices, val_indices, test_indices = main()

Loaded 60476 valid samples, skipped 6712.
Unique participants: 85

Label distribution in Train:
  Class 0: 9044 samples (21.67%)
  Class 1: 19959 samples (47.82%)
  Class 2: 7405 samples (17.74%)
  Class 3: 5326 samples (12.76%)

Label distribution in Validation:
  Class 0: 808 samples (8.54%)
  Class 1: 4580 samples (48.39%)
  Class 2: 2597 samples (27.44%)
  Class 3: 1480 samples (15.64%)

Label distribution in Test:
  Class 0: 2058 samples (22.18%)
  Class 1: 4650 samples (50.12%)
  Class 2: 2423 samples (26.12%)
  Class 3: 146 samples (1.57%)


# function loading

In [5]:
# Normalization constants
AUDIO_MEAN = [0.485, 0.456, 0.406]
AUDIO_STD = [0.229, 0.224, 0.225]

VIDEO_MEAN = [0.43216, 0.394666, 0.37645]
VIDEO_STD = [0.22803, 0.22145, 0.216989]

In [6]:
# Dataset for fine-tuning
class AudioVideoDataset(Dataset):
    def __init__(self, data_list, indices, spec_dir='spectrograms', fixed_samples=16000, fps=25):
        self.data = [data_list[i] for i in indices]  # Subset by indices
        self.spec_dir = spec_dir
        self.fixed_samples = fixed_samples
        self.fps = fps

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        wav_path, video_path, label, key = self.data[idx]
        # Load precomputed spectrogram
        spec_path = os.path.join(self.spec_dir, f'spec_{key}.npy')
        try:
            spec = np.load(spec_path)
            spec = torch.from_numpy(spec).float()  # Keep on CPU
        except FileNotFoundError:
            raise FileNotFoundError(f"Spectrogram not found: {spec_path}")
        
        # Load video
        _, video = load_aligned_pair(wav_path, video_path, self.fixed_samples, self.fps)
        return spec, video, label, key

# Load aligned pair (video only)
def load_aligned_pair(wav_path, video_path, fixed_samples=16000, fps=25):
    try:
        audio_duration = fixed_samples / 16000  # sr=16000
        video_frames = int(audio_duration * fps)
        video, _, info = read_video(video_path, pts_unit='sec')
        video = video.float() / 255.0
        if video.shape[0] > video_frames:
            video = video[:video_frames]
        elif video.shape[0] < video_frames:
            video = torch.nn.functional.pad(video, (0, 0, 0, 0, 0, video_frames - video.shape[0]))
        
        video = video.permute(0, 3, 1, 2)
        video = torch.nn.functional.interpolate(video, size=(112, 112), mode='bilinear', align_corners=False)
        video = video.permute(1, 0, 2, 3)
        return None, video  # Keep on CPU
    except Exception as e:
        raise RuntimeError(f"Error loading {video_path}: {e}")

def collate_fn(batch):
    specs, videos, labels, keys = zip(*batch)
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Audio preprocessing
    specs = torch.stack([spec for spec in specs]).to(device)  # [B, 1, H, W]
    specs = specs.unsqueeze(1)
    specs = torch.nn.functional.interpolate(specs, size=(224, 224), mode='bilinear', align_corners=False)
    specs = specs.repeat(1, 3, 1, 1)
    mean = torch.tensor(AUDIO_MEAN).view(1, 3, 1, 1).to(device)
    std = torch.tensor(AUDIO_STD).view(1, 3, 1, 1).to(device)
    specs = (specs - mean) / std

    # Video preprocessing
    videos = torch.stack([video for video in videos]).to(device)
    mean = torch.tensor(VIDEO_MEAN).view(1, 3, 1, 1, 1).to(device)
    std = torch.tensor(VIDEO_STD).view(1, 3, 1, 1, 1).to(device)
    videos = (videos - mean) / std

    labels = torch.tensor(labels, dtype=torch.long).to(device)
    return specs, videos, labels, keys

class AudioResNet(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        self.resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.resnet(x)

    def get_embedding(self, x):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)
        x = self.resnet.avgpool(x)
        x = torch.flatten(x, 1)
        return x

class VideoModel(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        self.resnet3d = r3d_18(weights=R3D_18_Weights.KINETICS400_V1)
        in_features = self.resnet3d.fc.in_features
        self.resnet3d.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.resnet3d(x)

    def get_embedding(self, x):
        x = self.resnet3d.stem(x)
        x = self.resnet3d.layer1(x)
        x = self.resnet3d.layer2(x)
        x = self.resnet3d.layer3(x)
        x = self.resnet3d.layer4(x)
        x = self.resnet3d.avgpool(x)
        x = torch.flatten(x, 1)
        return x

# train val - 2d reset &unfreeze -- pariticipant specific -- trianing samples

In [7]:
def load_data_list(file_path='data_copy.list'):
    """
    Load data_copy.list, skipping clips without audio, video, or spectrogram paths.
    """
    data_list = []
    skipped = 0
    loaded = 0
    
    with open(file_path, 'r') as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip():
                continue
            try:
                entry = json.loads(line.strip())
                if 'key' not in entry or 'wav_path' not in entry or 'label' not in entry:
                    print(f"Key error at line {line_num}: Missing 'key', 'wav_path', or 'label', line content: {line.strip()}")
                    skipped += 1
                    continue
                key = entry['key']
                wav_path = os.path.join('audio', os.path.basename(entry['wav_path']))
                video_path = os.path.join('video', os.path.splitext(os.path.basename(entry['wav_path']))[0] + '.avi')
                label = int(entry['label'])
                if label not in [0, 1, 2, 3]:
                    print(f"Invalid label {label} at line {line_num}: {key}, skipping.")
                    skipped += 1
                    continue
                
                # Check if audio, video, and spectrogram paths exist
                spec_path = os.path.join('spectrograms', f'spec_{key}.npy')
                if not os.path.exists(wav_path):
                    #print(f"Missing audio file at line {line_num}: {wav_path}, skipping.")
                    skipped += 1
                    continue
                if not os.path.exists(video_path):
                    #print(f"Missing video file at line {line_num}: {video_path}, skipping.")
                    skipped += 1
                    continue
                if not os.path.exists(spec_path):
                    #print(f"Missing spectrogram at line {line_num}: {spec_path}, skipping.")
                    skipped += 1
                    continue
                
                data_list.append((wav_path, video_path, label, key))
                loaded += 1
            except json.JSONDecodeError as e:
                print(f"JSON error at line {line_num}: {e}, line content: {line.strip()}")
                skipped += 1
                continue
            except KeyError as e:
                print(f"Key error at line {line_num}: Missing {e}, line content: {line.strip()}")
                skipped += 1
                continue
    
    print(f"Loaded {loaded} valid clips, skipped {skipped} entries.")
    return data_list

def fine_tune_models(audio_model, video_model, train_loader, val_loader, num_epochs=10, device='cuda:0'):
    audio_model = audio_model.to(device)
    video_model = video_model.to(device)

    # Freeze all layers
    for param in audio_model.parameters():
        param.requires_grad = False
    for param in video_model.parameters():
        param.requires_grad = False

    # Unfreeze only layer4 and fc
    for name, param in audio_model.named_parameters():
        if 'layer4' in name or 'fc' in name:
            param.requires_grad = True
    for name, param in video_model.named_parameters():
        if 'layer4' in name or 'fc' in name:
            param.requires_grad = True

    # Optimizers with small LR
    audio_optimizer = optim.Adam(filter(lambda p: p.requires_grad, audio_model.parameters()), lr=1e-5, weight_decay=1e-4)
    video_optimizer = optim.Adam(filter(lambda p: p.requires_grad, video_model.parameters()), lr=1e-5, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss().to(device)

    for epoch in range(num_epochs):
        audio_model.train()
        video_model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0

        for specs, videos, labels, _ in train_loader:
            # Audio
            audio_optimizer.zero_grad()
            audio_outputs = audio_model(specs)
            audio_loss = criterion(audio_outputs, labels)
            audio_loss.backward()
            audio_optimizer.step()

            # Video
            video_optimizer.zero_grad()
            video_outputs = video_model(videos)
            video_loss = criterion(video_outputs, labels)
            video_loss.backward()
            video_optimizer.step()

            # Accumulate metrics
            batch_size = specs.size(0)
            train_total += batch_size
            train_loss += (audio_loss.item() + video_loss.item()) * batch_size
            _, audio_pred = torch.max(audio_outputs, 1)
            _, video_pred = torch.max(video_outputs, 1)
            train_correct += (audio_pred == labels).sum().item() + (video_pred == labels).sum().item()

        train_loss /= (2 * train_total)
        train_acc = train_correct / (2 * train_total) * 100

        # Validation
        audio_model.eval()
        video_model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0

        with torch.no_grad():
            for specs, videos, labels, _ in val_loader:
                audio_outputs = audio_model(specs)
                video_outputs = video_model(videos)
                audio_loss = criterion(audio_outputs, labels)
                video_loss = criterion(video_outputs, labels)
                batch_size = specs.size(0)
                val_total += batch_size
                val_loss += (audio_loss.item() + video_loss.item()) * batch_size
                _, audio_pred = torch.max(audio_outputs, 1)
                _, video_pred = torch.max(video_outputs, 1)
                val_correct += (audio_pred == labels).sum().item() + (video_pred == labels).sum().item()

        val_loss /= (2 * val_total)
        val_acc = val_correct / (2 * val_total) * 100

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        torch.cuda.empty_cache()

    torch.save(audio_model.state_dict(), 'audio_finetuned.pth')
    torch.save(video_model.state_dict(), 'video_finetuned.pth')
    print("Saved fine-tuned models.")
    return audio_model, video_model

def main(file_path='data_copy.list', num_epochs=10, batch_size=16):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}, GPU count: {torch.cuda.device_count()}")
    
    # Load indices
    train_indices = np.load('train_indices.npy')
    val_indices = np.load('val_indices.npy')
    
    # Load data list
    data_list = load_data_list(file_path=file_path)
    
    # Verify indices
    max_index = len(data_list) - 1
    if max(train_indices) > max_index or max(val_indices) > max_index:
        raise ValueError(f"Indices exceed data_list length ({max_index}). Re-run split_dataset.py with {file_path}.")
    
    # Prepare datasets
    train_dataset = AudioVideoDataset(data_list, train_indices, spec_dir='spectrograms')
    val_dataset = AudioVideoDataset(data_list, val_indices, spec_dir='spectrograms')
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False, collate_fn=collate_fn)
    
    # Initialize models
    audio_model = AudioResNet(num_classes=4)
    video_model = VideoModel(num_classes=4)
    
    # Fine-tune models
    print("Starting fine-tuning...")
    audio_model, video_model = fine_tune_models(audio_model, video_model, train_loader, val_loader, num_epochs=num_epochs, device=device)

if __name__ == "__main__":
    main()

Using device: cuda:0, GPU count: 1
JSON error at line 1919: Extra data: line 1 column 120 (char 119), line content: {"key": "S_M_00051_G2_task1_4_S00004", "wav_path": "data/S_M_00051_G2_task1_4_S00004.wav", "label": 1, "Frenchay": 114}4{"key": "S_M_00020_G4_task1_5_S00007", "wav_path": "data/S_M_00020_G4_task1_5_S00007.wav", "label": 1, "Frenchay": 92}
Loaded 60476 valid clips, skipped 6712 entries.
Starting fine-tuning...
Epoch 1/10, Train Loss: 0.4234, Train Acc: 83.58%, Val Loss: 1.9138, Val Acc: 44.50%
Epoch 2/10, Train Loss: 0.2122, Train Acc: 92.08%, Val Loss: 2.0609, Val Acc: 46.11%
Epoch 3/10, Train Loss: 0.1440, Train Acc: 94.73%, Val Loss: 2.2706, Val Acc: 45.96%
Epoch 4/10, Train Loss: 0.1043, Train Acc: 96.33%, Val Loss: 2.4306, Val Acc: 44.81%
Epoch 5/10, Train Loss: 0.0769, Train Acc: 97.33%, Val Loss: 2.3639, Val Acc: 46.79%
Epoch 6/10, Train Loss: 0.0576, Train Acc: 98.09%, Val Loss: 2.5523, Val Acc: 46.20%
Epoch 7/10, Train Loss: 0.0425, Train Acc: 98.64%, Val Loss: 2.

# embedding extraction 

In [6]:
# Save batch of embeddings and metadata
def save_batch(audio_embeds, video_embeds, metadata, batch_num, output_dir):
    batch_dir = os.path.join(output_dir, f'batch_{batch_num}')
    os.makedirs(batch_dir, exist_ok=True)
    
    for i, (audio_embed, video_embed, meta) in enumerate(zip(audio_embeds, video_embeds, metadata)):
        key = meta['key']
        audio_path = os.path.join(batch_dir, f'audio_embed_{key}.npy')
        video_path = os.path.join(batch_dir, f'video_embed_{key}.npy')
        np.save(audio_path, audio_embed)
        np.save(video_path, video_embed)
        meta_path = os.path.join(batch_dir, f'meta_{key}.json')
        with open(meta_path, 'w') as f:
            json.dump(meta, f)

def extract_finetuned_embeddings(audio_model, video_model, file_path='data_copy.list', output_dir='features_finetuned', batch_size=10, device='cuda:0'):
    audio_model.eval()
    video_model.eval()

    # Load data
    data_list = []
    skipped = 0
    loaded = 0
    with open(file_path, 'r') as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip():
                continue
            try:
                entry = json.loads(line.strip())
                wav_path = os.path.join('audio', os.path.basename(entry['wav_path']))
                video_path = os.path.join('video', os.path.splitext(os.path.basename(entry['wav_path']))[0] + '.avi')
                if os.path.exists(wav_path) and os.path.exists(video_path):
                    label = int(entry['label'])  # Labels: 0, 1, 2, 3
                    if label not in [0, 1, 2, 3]:
                        print(f"Warning: Invalid label {label} in {wav_path}, skipping.")
                        skipped += 1
                        continue
                    key = entry['key']
                    spec_path = os.path.join('spectrograms', f'spec_{key}.npy')
                    if not os.path.exists(spec_path):
                        print(f"Warning: Missing spectrogram {spec_path}, skipping.")
                        skipped += 1
                        continue
                    data_list.append((wav_path, video_path, label, key))
                    loaded += 1
                else:
                    #print(f"Warning: Missing {wav_path} or {video_path}, skipping.")
                    skipped += 1
                    continue
            except json.JSONDecodeError as e:
                print(f"JSON error at line {line_num}: {e}, line content: {line.strip()}")
                skipped += 1
                continue
            except KeyError as e:
                print(f"Key error at line {line_num}: Missing {e}, line content: {line.strip()}")
                skipped += 1
                continue

    if not data_list:
        raise ValueError("No valid data loaded from data_copy.list. Check file format and spectrograms.")

    print(f"Loaded {loaded} valid entries, skipped {skipped} entries.")

    # Prepare dataset
    dataset = AudioVideoDataset(data_list, spec_dir='spectrograms')
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False, collate_fn=collate_fn)

    # Extract embeddings
    os.makedirs(output_dir, exist_ok=True)
    batch_num = 0
    batch_audio_specs = []
    batch_video_embeds = []
    batch_metadata = []
    processed = 0

    with torch.no_grad():
        for specs, videos, labels, keys in data_loader:
            # Extract embeddings
            audio_embeds = audio_model.get_embedding(specs)  # [batch, 512]
            video_embeds = video_model.get_embedding(videos)  # [batch, 512]
            # Collect batch
            for i in range(specs.size(0)):
                batch_audio_specs.append(audio_embeds[i].cpu().numpy())
                batch_video_embeds.append(video_embeds[i].cpu().numpy())
                batch_metadata.append({
                    'key': keys[i],
                    'label': labels[i].item(),
                    'wav_path': data_list[batch_num * batch_size + i][0],
                    'video_path': data_list[batch_num * batch_size + i][1]
                })
                processed += 1
                if processed % 100 == 0:
                    print(f"Processed {processed} items so far.")
                if len(batch_audio_specs) >= batch_size:
                    save_batch(batch_audio_specs, batch_video_embeds, batch_metadata, batch_num, output_dir)
                    batch_audio_specs, batch_video_embeds, batch_metadata = [], [], []
                    batch_num += 1
            torch.cuda.empty_cache()

    if batch_audio_specs:
        save_batch(batch_audio_specs, batch_video_embeds, batch_metadata, batch_num, output_dir)

    print(f"Finished. Processed: {processed}, Skipped: {skipped}")
    return skipped

def main(file_path='data_copy.list', audio_model_path='audio_finetuned.pth', video_model_path='video_finetuned.pth', num_classes=4):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}, GPU count: {torch.cuda.device_count()}")

    # Load models
    audio_model = AudioResNet(num_classes=num_classes).to(device)
    video_model = VideoModel(num_classes=num_classes).to(device)
    audio_model.load_state_dict(torch.load(audio_model_path, map_location=device))
    video_model.load_state_dict(torch.load(video_model_path, map_location=device))
    print(f"Loaded fine-tuned models: {audio_model_path}, {video_model_path}")

    # Extract embeddings
    print("Extracting fine-tuned embeddings...")
    extract_finetuned_embeddings(audio_model, video_model, file_path=file_path, output_dir='features_finetuned', batch_size=10, device=device)

if __name__ == "__main__":
    main()

Using device: cuda:0, GPU count: 1
Loaded fine-tuned models: audio_finetuned.pth, video_finetuned.pth
Extracting fine-tuned embeddings...
JSON error at line 1919: Extra data: line 1 column 120 (char 119), line content: {"key": "S_M_00051_G2_task1_4_S00004", "wav_path": "data/S_M_00051_G2_task1_4_S00004.wav", "label": 1, "Frenchay": 114}4{"key": "S_M_00020_G4_task1_5_S00007", "wav_path": "data/S_M_00020_G4_task1_5_S00007.wav", "label": 1, "Frenchay": 92}
Loaded 60476 valid entries, skipped 6712 entries.


TypeError: AudioVideoDataset.__init__() missing 1 required positional argument: 'indices'

In [6]:
def load_first_batch_embeddings(batch_dir='features_finetuned/batch_0', batch_size=10, device='cuda:0'):
    """
    Load the first batch of extracted 512-D embeddings and metadata.
    
    Args:
        batch_dir (str): Directory containing batch_0 embeddings and metadata.
        batch_size (int): Number of samples in the batch (default: 10).
        device (str): Device to load tensors onto (default: 'cuda:0').
    
    Returns:
        audio_embeds (torch.Tensor): [batch_size, 512] audio embeddings.
        video_embeds (torch.Tensor): [batch_size, 512] video embeddings.
        labels (torch.Tensor): [batch_size] labels.
        keys (list): List of keys for the batch.
    """
    if not os.path.exists(batch_dir):
        raise FileNotFoundError(f"Batch directory {batch_dir} does not exist. Run extract_embeddings_512.py first.")

    audio_embeds = []
    video_embeds = []
    labels = []
    keys = []
    
    # Get all .json files in batch_0
    meta_files = [f for f in os.listdir(batch_dir) if f.startswith('meta_') and f.endswith('.json')]
    if not meta_files:
        raise FileNotFoundError(f"No metadata files found in {batch_dir}.")
    
    # Load up to batch_size samples
    for meta_file in meta_files[:batch_size]:
        meta_path = os.path.join(batch_dir, meta_file)
        key = meta_file.replace('meta_', '').replace('.json', '')
        
        # Load metadata
        try:
            with open(meta_path, 'r') as f:
                meta = json.load(f)
            if 'key' not in meta or 'label' not in meta:
                print(f"Warning: Invalid metadata in {meta_path}, skipping.")
                continue
        except json.JSONDecodeError as e:
            print(f"Error reading {meta_path}: {e}, skipping.")
            continue
        
        # Load embeddings
        audio_path = os.path.join(batch_dir, f'audio_embed_{key}.npy')
        video_path = os.path.join(batch_dir, f'video_embed_{key}.npy')
        
        if not os.path.exists(audio_path) or not os.path.exists(video_path):
            print(f"Warning: Missing {audio_path} or {video_path}, skipping.")
            continue
        
        try:
            audio_embed = np.load(audio_path)  # [512]
            video_embed = np.load(video_path)  # [512]
            if audio_embed.shape != (512,) or video_embed.shape != (512,):
                print(f"Warning: Invalid shape in {audio_path} ({audio_embed.shape}) or {video_path} ({video_embed.shape}), skipping.")
                continue
        except Exception as e:
            print(f"Error loading {audio_path} or {video_path}: {e}, skipping.")
            continue
        
        audio_embeds.append(audio_embed)
        video_embeds.append(video_embed)
        labels.append(meta['label'])
        keys.append(key)
    
    if not audio_embeds:
        raise ValueError(f"No valid embeddings loaded from {batch_dir}.")
    
    # Convert to tensors
    audio_embeds = torch.tensor(np.stack(audio_embeds), dtype=torch.float32).to(device)  # [batch_size, 512]
    video_embeds = torch.tensor(np.stack(video_embeds), dtype=torch.float32).to(device)  # [batch_size, 512]
    labels = torch.tensor(labels, dtype=torch.long).to(device)  # [batch_size]
    
    # Print batch details
    print(f"\nFirst batch loaded successfully from {batch_dir}:")
    print(f"Audio embeddings shape: {audio_embeds.shape}")  # Expected: [10, 512]
    print(f"Video embeddings shape: {video_embeds.shape}")  # Expected: [10, 512]
    print(f"Labels shape: {labels.shape}, values: {labels.tolist()}")  # Expected: [10]
    print(f"Keys: {keys}")
    print(f"Sample audio embedding (first 5 values, key={keys[0]}): {audio_embeds[0, :5].tolist()}")
    print(f"Sample video embedding (first 5 values, key={keys[0]}): {video_embeds[0, :5].tolist()}")
    
    return audio_embeds, video_embeds, labels, keys

def main(batch_dir='features_finetuned/batch_0', batch_size=10):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}, GPU count: {torch.cuda.device_count()}")
    
    # Load first batch
    print("Loading first batch of embeddings...")
    audio_embeds, video_embeds, labels, keys = load_first_batch_embeddings(batch_dir=batch_dir, batch_size=batch_size, device=device)

if __name__ == "__main__":
    main()

Using device: cuda:0, GPU count: 1
Loading first batch of embeddings...

First batch loaded successfully from features_finetuned/batch_0:
Audio embeddings shape: torch.Size([10, 512])
Video embeddings shape: torch.Size([10, 512])
Labels shape: torch.Size([10]), values: [0, 3, 3, 1, 1, 1, 1, 0, 0, 1]
Keys: ['N_M_10003_G4_task1_4_S00004', 'S_F_00010_G4_task6_1_S00009', 'S_M_00013_G4_task2_1_S00002', 'S_M_00004_G5_task4_3_S00002', 'S_M_00044_G2_task7_2_S00008', 'S_M_00004_G2_task4_3_S00011', 'S_M_00051_G2_task2_1_S00004', 'N_M_10015_G2_task1_4_S00002', 'N_M_10010_G3_task8_1_S00001', 'S_M_00051_G2_task4_1_S00000']
Sample audio embedding (first 5 values, key=N_M_10003_G4_task1_4_S00004): [0.04829590395092964, 0.029215507209300995, 0.038779713213443756, 0.6907779574394226, 0.5916787385940552]
Sample video embedding (first 5 values, key=N_M_10003_G4_task1_4_S00004): [0.0010360876331105828, 0.9481890201568604, 0.25489696860313416, 0.04081109166145325, 0.014767770655453205]


# no traning - 2dResNet & unfreeze (original 2.0) -- all samples  

In [4]:
# Load and align pair to fixed length
def load_aligned_pair(wav_path, video_path, sr=16000, fixed_samples=16000, fps=25, device='cpu'):
    try:
        # Load audio
        audio, sample_rate = torchaudio.load(wav_path)
        if sample_rate != sr:
            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=sr)(audio)
        if audio.shape[1] > fixed_samples:
            audio = audio[:, :fixed_samples]
        elif audio.shape[1] < fixed_samples:
            audio = torch.nn.functional.pad(audio, (0, fixed_samples - audio.shape[1]))

        # Calculate target number of video frames
        audio_duration = fixed_samples / sr  # 1s
        video_frames = int(audio_duration * fps)  # 25 frames

        # Load video
        video, _, info = read_video(video_path, pts_unit='sec')
        video = video.float() / 255.0  # Normalize to [0, 1]

        # Align video frames
        if video.shape[0] > video_frames:
            video = video[:video_frames]
        elif video.shape[0] < video_frames:
            video = torch.nn.functional.pad(video, (0, 0, 0, 0, 0, video_frames - video.shape[0]))

        # Resize video to 112x112 (standard for 3D ResNet)
        video = video.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
        video = torch.nn.functional.interpolate(video, size=(112, 112), mode='bilinear', align_corners=False)

        # Reorder dimensions for 3D ResNet: [T, C, H, W] -> [C, T, H, W]
        video = video.permute(1, 0, 2, 3)  # [C, T, H, W], e.g., [3, 25, 112, 112]

        return audio.to(device), video.to(device)
    except Exception as e:
        raise RuntimeError(f"Error loading pair {wav_path} and {video_path}: {e}")

# Audio to spectrogram with dB conversion and normalization
def audio_to_spectrogram(audio, sr=16000, n_fft=512, hop_length=256, n_mels=128, fmax=8000, device='cuda' if torch.cuda.is_available() else 'cpu'):
    spec_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, f_max=fmax
    ).to(audio.device)

    # Compute mel-spectrogram
    mel = spec_transform(audio)  # [channels, n_mels, time_frames], e.g., [1, 128, 62]
    mel = mel.mean(dim=0)  # Average across channels, e.g., [128, 62]

    # Convert to dB scale
    mel_db = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=None)(mel)  # [128, 62]

    # Normalize by subtracting the maximum value
    mel_db -= mel_db.max()  # [128, 62]
    
    # Ensure exactly 62 time frames
    target_frames = 62
    if mel_db.shape[1] > target_frames:
        mel_db = mel_db[:, :target_frames]  # Trim
    elif mel_db.shape[1] < target_frames:
        mel_db = F.pad(mel_db, (0, target_frames - mel_db.shape[1]))  # Pad

    # Resize to [224, 224]
    mel_db = mel_db.unsqueeze(0).unsqueeze(0)  # [1, 1, 128, 62]
    mel_db = F.interpolate(mel_db, size=(224, 224), mode='bilinear', align_corners=False)  # [1, 1, 224, 224]
    mel_db = mel_db.squeeze(0)  # [1, 224, 224]
    mel_db = mel_db.repeat(3, 1, 1)  # [3, 224, 224]

    return mel_db

# 2D ResNet for audio spectrograms
class AudioResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        # Keep original conv1 for [224, 224]
        # Output 128D
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, 128)

    def forward(self, x):
        return self.resnet(x)

# VideoModel with half layers unfrozen
class VideoModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet3d = r3d_18(weights='KINETICS400_V1')
        in_features = self.resnet3d.fc.in_features  # 512
        self.resnet3d.fc = nn.Linear(in_features, 512)  # Output 512
        # Unfreeze layer3, layer4
        for param in self.resnet3d.parameters():
            param.requires_grad = False
        unfreeze_layers = ['layer3', 'layer4']
        for name, child in self.resnet3d.named_children():
            if name in unfreeze_layers:
                for param in child.parameters():
                    param.requires_grad = True

    def forward(self, x):
        return self.resnet3d(x)

# Original save_batch function
def save_batch(audio_specs, video_embeds, metadata, batch_num, output_dir):
    with h5py.File(os.path.join(output_dir, f'features_batch_{batch_num}.h5'), 'w') as f:
        f.create_dataset('audio_embeddings', data=np.array(audio_specs), compression='gzip')
        f.create_dataset('video_embeddings', data=np.array(video_embeds), compression='gzip')
        metadata_dtype = [('key', 'S50'), ('label', 'S10'), ('Frenchay', 'f4'), ('wav_path', 'S100'), ('video_path', 'S100')]
        f.create_dataset('metadata', data=np.array([(m['key'], m['label'], m['Frenchay'], m['wav_path'], m['video_path']) for m in metadata], dtype=metadata_dtype))

# Process and save to HDF5 with audio and video embeddings
def process_and_save_to_hdf5(file_path='data_copy.list', output_dir='features_updated_224', sr=16000, fixed_samples=16000, fps=25, batch_size=10, device='cuda' if torch.cuda.is_available() else 'cpu'):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    skipped = 0
    processed = 0

    # Initialize models
    #audio_cnn = AudioCNN().to(device)
    audio_model = AudioResNet().to(device)
    video_model = VideoModel().to(device)
    #audio_cnn.eval()  # Use pre-computed BN stats
    audio_model.eval()
    video_model.eval()

    batch_audio_specs = []
    batch_video_embeds = []
    batch_metadata = []

    with open(file_path, 'r') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                entry = json.loads(line.strip())
                wav_path = os.path.join('audio', os.path.basename(entry['wav_path']))
                video_path = os.path.join('video', os.path.splitext(os.path.basename(entry['wav_path']))[0] + '.avi')

                if not os.path.exists(wav_path) or not os.path.exists(video_path):
                    #print(f"Warning: Missing {wav_path} or {video_path}, skipping pair.")
                    skipped += 1
                    continue

                audio, video = load_aligned_pair(
                    wav_path, video_path, sr=sr, fixed_samples=fixed_samples, fps=fps, device=device
                )

                # Audio processing
                spec = audio_to_spectrogram(
                    audio, sr=sr, n_fft=512, hop_length=256, n_mels=128, fmax=8000, device=device
                )  # [128, 62]
                if spec.shape != torch.Size([3, 224, 224]):
                    raise ValueError(f"Expected spectrogram shape [3, 224, 224], got {spec.shape}")
                with torch.no_grad():
                    spec_embedding = audio_model(spec.unsqueeze(0)).detach().cpu().numpy()  # [1, 128]

                # Video processing
                with torch.no_grad():
                    video_embed = video_model(video.unsqueeze(0)).detach().cpu().numpy()  # [1, 512]

                batch_audio_specs.append(spec_embedding)
                batch_video_embeds.append(video_embed)
                batch_metadata.append({
                    'key': entry['key'],
                    'label': entry['label'],
                    'Frenchay': entry['Frenchay'],
                    'wav_path': wav_path,
                    'video_path': video_path
                })

                processed += 1
                if processed % 100 == 0:
                    print(f"Processed {processed} items so far.")

                if len(batch_audio_specs) >= batch_size:
                    batch_num = processed // batch_size
                    save_batch(batch_audio_specs, batch_video_embeds, batch_metadata, batch_num, output_dir)
                    batch_audio_specs, batch_video_embeds, batch_metadata = [], [], []

                del audio, video, spec, spec_embedding, video_embed
                torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error processing {wav_path}: {e}")
                skipped += 1
                continue

        if batch_audio_specs:
            batch_num = (processed // batch_size) + 1
            save_batch(batch_audio_specs, batch_video_embeds, batch_metadata, batch_num, output_dir)

    print(f"Finished. Processed: {processed}, Skipped: {skipped}")
    return skipped

if __name__ == "__main__":
    process_and_save_to_hdf5(batch_size=10)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/ucloud/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 163MB/s] 
Downloading: "https://download.pytorch.org/models/r3d_18-b3b3357e.pth" to /home/ucloud/.cache/torch/hub/checkpoints/r3d_18-b3b3357e.pth
100%|██████████| 127M/127M [00:00<00:00, 196MB/s]  


Processed 100 items so far.
Processed 200 items so far.
Processed 300 items so far.
Processed 400 items so far.
Processed 500 items so far.
Processed 600 items so far.
Processed 700 items so far.
Processed 800 items so far.
Processed 900 items so far.
Processed 1000 items so far.
Processed 1100 items so far.
Processed 1200 items so far.
Processed 1300 items so far.
Processed 1400 items so far.
Processed 1500 items so far.
Processed 1600 items so far.
Processed 1700 items so far.
Error processing audio/S_F_00050_G1_task6_1_S00004.wav: Extra data: line 1 column 120 (char 119)
Processed 1800 items so far.
Processed 1900 items so far.
Processed 2000 items so far.
Processed 2100 items so far.
Processed 2200 items so far.
Processed 2300 items so far.
Processed 2400 items so far.
Processed 2500 items so far.
Processed 2600 items so far.
Processed 2700 items so far.
Processed 2800 items so far.
Processed 2900 items so far.
Processed 3000 items so far.
Processed 3100 items so far.
Processed 320

# tran val split version 3 -- precompute audio spectrograms - 10000 samples

In [4]:
def load_aligned_pair(wav_path, video_path, sr=16000, fixed_samples=16000, fps=25):
    try:
        audio_duration = fixed_samples / sr
        video_frames = int(audio_duration * fps)
        video, _, info = read_video(video_path, pts_unit='sec')
        video = video.float() / 255.0
        if video.shape[0] > video_frames:
            video = video[:video_frames]
        elif video.shape[0] < video_frames:
            video = torch.nn.functional.pad(video, (0, 0, 0, 0, 0, video_frames - video.shape[0]))
        video = video.permute(0, 3, 1, 2)
        video = torch.nn.functional.interpolate(video, size=(112, 112), mode='bilinear', align_corners=False)
        video = video.permute(1, 0, 2, 3)
        return video
    except Exception as e:
        raise RuntimeError(f"Error loading {video_path}: {e}")

class AudioCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(32 * 20 * 15, 128)  # For [80, 63]

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.bn1(self.conv1(x))))
        x = self.pool(nn.functional.relu(self.bn2(self.conv2(x))))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class VideoModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet3d = video_models.r3d_18(pretrained=True)
        in_features = self.resnet3d.fc.in_features
        self.resnet3d.fc = nn.Linear(in_features, 4)
        for param in self.resnet3d.parameters():
            param.requires_grad = False
        unfreeze_layers = ['layer3', 'layer4']
        for name, child in self.resnet3d.named_children():
            if name in unfreeze_layers:
                for param in child.parameters():
                    param.requires_grad = True

    def forward(self, x):
        return self.resnet3d(x)

class AVPairDataset(Dataset):
    def __init__(self, entries, spec_dir='spectrograms_10000', sr=16000, fixed_samples=16000, fps=25):
        self.entries = entries
        self.spec_dir = spec_dir
        self.sr = sr
        self.fixed_samples = fixed_samples
        self.fps = fps

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        entry = self.entries[idx]
        try:
            spec_path = os.path.join(self.spec_dir, f'spec_{entry["key"]}.npy')
            spec = torch.from_numpy(np.load(spec_path))
            video = load_aligned_pair(entry['wav_path'], entry['video_path'], self.sr, self.fixed_samples, self.fps)
            label = torch.tensor(entry['label'], dtype=torch.long)
            return spec, video, label
        except Exception as e:
            print(f"Error loading {entry['wav_path']} or spectrogram: {e}")
            return None, None, None

def plot_metrics(audio_metrics, video_metrics, num_epochs):
    epochs = range(1, num_epochs + 1)
    plt.figure(figsize=(12, 10))
    plt.subplot(2, 2, 1)
    plt.plot(epochs, audio_metrics['train_loss'], label='Train Loss', marker='o')
    plt.plot(epochs, audio_metrics['val_loss'], label='Val Loss', marker='o')
    plt.title('AudioCNN Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.subplot(2, 2, 2)
    plt.plot(epochs, audio_metrics['train_acc'], label='Train Acc', marker='o')
    plt.plot(epochs, audio_metrics['val_acc'], label='Val Acc', marker='o')
    plt.title('AudioCNN Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)
    plt.subplot(2, 2, 3)
    plt.plot(epochs, video_metrics['train_loss'], label='Train Loss', marker='o')
    plt.plot(epochs, video_metrics['val_loss'], label='Val Loss', marker='o')
    plt.title('VideoModel Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.subplot(2, 2, 4)
    plt.plot(epochs, video_metrics['train_acc'], label='Train Acc', marker='o')
    plt.plot(epochs, video_metrics['val_acc'], label='Val Acc', marker='o')
    plt.title('VideoModel Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('training_metrics.png')
    plt.show()

def train_av_models(file_path='data_copy.list', sr=16000, fixed_samples=16000, fps=25, num_samples=10000, device='cuda' if torch.cuda.is_available() else 'cpu'):
    all_entries = []
    with open(file_path, 'r') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                entry = json.loads(line.strip())
                wav_path = os.path.join('audio', os.path.basename(entry['wav_path']))
                video_path = os.path.join('video', os.path.splitext(os.path.basename(entry['wav_path']))[0] + '.avi')
                if os.path.exists(wav_path) and os.path.exists(video_path):
                    all_entries.append({
                        'key': entry['key'],
                        'wav_path': wav_path,
                        'video_path': video_path,
                        'label': int(entry['label']),
                        'Frenchay': entry['Frenchay']
                    })
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON line: {line}")
                continue

    print(f"Found {len(all_entries)} valid entries in data.list")
    
    if len(all_entries) > num_samples:
        _, selected_entries = train_test_split(
            all_entries, 
            test_size=num_samples, 
            stratify=[e['label'] for e in all_entries], 
            random_state=42
        )
        all_entries = selected_entries
        print(f"Selected {len(all_entries)} samples with stratification")
    else:
        print(f"Warning: Only {len(all_entries)} entries available, using all")

    spec_dir = 'spectrograms_10000'
    valid_entries = []
    for entry in all_entries:
        spec_path = os.path.join(spec_dir, f'spec_{entry["key"]}.npy')
        if os.path.exists(spec_path):
            valid_entries.append(entry)
        else:
            print(f"Skipping {entry['wav_path']}: spectrogram not found")
    all_entries = valid_entries
    print(f"Found {len(all_entries)} entries with precomputed spectrograms")

    train_entries, val_entries = train_test_split(all_entries, test_size=0.15, stratify=[e['label'] for e in all_entries], random_state=42)
    print(f"Training entries: {len(train_entries)}, Validation entries: {len(val_entries)}")

    train_dataset = AVPairDataset(train_entries, spec_dir='spectrograms_10000', sr=sr, fixed_samples=fixed_samples, fps=fps)
    val_dataset = AVPairDataset(val_entries, spec_dir='spectrograms_10000', sr=sr, fixed_samples=fixed_samples, fps=fps)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)

    audio_cnn = AudioCNN().to(device)
    video_model = VideoModel().to(device)

    audio_optimizer = torch.optim.Adam(audio_cnn.parameters(), lr=0.001)
    video_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, video_model.parameters()), lr=0.0001)
    criterion = nn.CrossEntropyLoss()

    audio_metrics = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    video_metrics = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    best_audio_acc = 0.0
    best_video_acc = 0.0

    num_epochs = 5
    print("Training AudioCNN...")
    for epoch in range(num_epochs):
        audio_cnn.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        for batch_idx, (spec, _, labels) in enumerate(train_loader):
            if spec is None or labels is None:
                continue
            if batch_idx == 0:
                print(f"Spectrogram shape: {spec.shape}")  # Debug
            spec, labels = spec.to(device), labels.to(device)
            audio_optimizer.zero_grad()
            outputs = audio_cnn(spec.unsqueeze(1))
            loss = criterion(outputs, labels)
            loss.backward()
            audio_optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_loss /= len(train_loader)
        train_acc = 100 * train_correct / train_total

        audio_cnn.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for spec, _, labels in val_loader:
                if spec is None or labels is None:
                    continue
                spec, labels = spec.to(device), labels.to(device)
                outputs = audio_cnn(spec.unsqueeze(1))
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_acc = 100 * val_correct / val_total

        audio_metrics['train_loss'].append(train_loss)
        audio_metrics['val_loss'].append(val_loss)
        audio_metrics['train_acc'].append(train_acc)
        audio_metrics['val_acc'].append(val_acc)

        print(f"Audio Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        if val_acc > best_audio_acc:
            best_audio_acc = val_acc
            torch.save(audio_cnn.state_dict(), 'best_audio_cnn.pth')
            print("Saved best AudioCNN model")

    print("Training VideoModel...")
    for epoch in range(num_epochs):
        video_model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        for _, video, labels in train_loader:
            if video is None or labels is None:
                continue
            video, labels = video.to(device), labels.to(device)
            video_optimizer.zero_grad()
            outputs = video_model(video)
            loss = criterion(outputs, labels)
            loss.backward()
            video_optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_loss /= len(train_loader)
        train_acc = 100 * train_correct / train_total

        video_model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for _, video, labels in val_loader:
                if video is None or labels is None:
                    continue
                video, labels = video.to(device), labels.to(device)
                outputs = video_model(video)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_acc = 100 * val_correct / val_total

        video_metrics['train_loss'].append(train_loss)
        video_metrics['val_loss'].append(val_loss)
        video_metrics['train_acc'].append(train_acc)
        video_metrics['val_acc'].append(val_acc)

        print(f"Video Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        if val_acc > best_video_acc:
            best_video_acc = val_acc
            torch.save(video_model.state_dict(), 'best_video_model.pth')
            print("Saved best VideoModel")

    plot_metrics(audio_metrics, video_metrics, num_epochs)

if __name__ == "__main__":
    train_av_models(num_samples=10000)

Skipping invalid JSON line: {"key": "S_M_00051_G2_task1_4_S00004", "wav_path": "data/S_M_00051_G2_task1_4_S00004.wav", "label": 1, "Frenchay": 114}4{"key": "S_M_00020_G4_task1_5_S00007", "wav_path": "data/S_M_00020_G4_task1_5_S00007.wav", "label": 1, "Frenchay": 92}



KeyboardInterrupt: 

# tran val split version 3 -- precompute audio spectrograms

In [None]:
# Load video
def load_aligned_pair(wav_path, video_path, sr=16000, fixed_samples=16000, fps=25, device='cpu'):
    try:
        audio_duration = fixed_samples / sr
        video_frames = int(audio_duration * fps)
        video, _, info = read_video(video_path, pts_unit='sec')
        video = video.float() / 255.0
        if video.shape[0] > video_frames:
            video = video[:video_frames]
        elif video.shape[0] < video_frames:
            video = torch.nn.functional.pad(video, (0, 0, 0, 0, 0, video_frames - video.shape[0]))
        video = video.permute(0, 3, 1, 2)
        video = torch.nn.functional.interpolate(video, size=(112, 112), mode='bilinear', align_corners=False)
        video = video.permute(1, 0, 2, 3)
        return video.to(device)
    except Exception as e:
        raise RuntimeError(f"Error loading {video_path}: {e}")

# AudioCNN
class AudioCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(32 * 32 * 15, 128)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.bn1(self.conv1(x))))
        x = self.pool(nn.functional.relu(self.bn2(self.conv2(x))))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# VideoModel
class VideoModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet3d = video_models.r3d_18(pretrained=True)
        in_features = self.resnet3d.fc.in_features
        self.resnet3d.fc = nn.Linear(in_features, 4)
        for param in self.resnet3d.parameters():
            param.requires_grad = False
        unfreeze_layers = ['layer3', 'layer4']
        for name, child in self.resnet3d.named_children():
            if name in unfreeze_layers:
                for param in child.parameters():
                    param.requires_grad = True

    def forward(self, x):
        return self.resnet3d(x)

# Custom Dataset
class AVPairDataset(Dataset):
    def __init__(self, entries, spec_dir='spectrograms', sr=16000, fixed_samples=16000, fps=25, device='cpu'):  # CPU
        self.entries = entries
        self.spec_dir = spec_dir
        self.sr = sr
        self.fixed_samples = fixed_samples
        self.fps = fps
        self.device = device

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        entry = self.entries[idx]
        try:
            spec_path = os.path.join(self.spec_dir, f'spec_{entry["key"]}.npy')
            spec = torch.from_numpy(np.load(spec_path))  # CPU
            video = load_aligned_pair(entry['wav_path'], entry['video_path'], self.sr, self.fixed_samples, self.fps, device='cpu')  # CPU
            label = torch.tensor(entry['label'], dtype=torch.long)
            return spec, video, label
        except Exception as e:
            print(f"Error loading {entry['wav_path']} or spectrogram: {e}")
            return None, None, None

# Plotting function
def plot_metrics(audio_metrics, video_metrics, num_epochs):
    epochs = range(1, num_epochs + 1)
    plt.figure(figsize=(12, 10))
    plt.subplot(2, 2, 1)
    plt.plot(epochs, audio_metrics['train_loss'], label='Train Loss', marker='o')
    plt.plot(epochs, audio_metrics['val_loss'], label='Val Loss', marker='o')
    plt.title('AudioCNN Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.subplot(2, 2, 2)
    plt.plot(epochs, audio_metrics['train_acc'], label='Train Acc', marker='o')
    plt.plot(epochs, audio_metrics['val_acc'], label='Val Acc', marker='o')
    plt.title('AudioCNN Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)
    plt.subplot(2, 2, 3)
    plt.plot(epochs, video_metrics['train_loss'], label='Train Loss', marker='o')
    plt.plot(epochs, video_metrics['val_loss'], label='Val Loss', marker='o')
    plt.title('VideoModel Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.subplot(2, 2, 4)
    plt.plot(epochs, video_metrics['train_acc'], label='Train Acc', marker='o')
    plt.plot(epochs, video_metrics['val_acc'], label='Val Acc', marker='o')
    plt.title('VideoModel Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('training_metrics.png')
    plt.show()

def train_av_models(file_path='data_copy.list', sr=16000, fixed_samples=16000, fps=25, device='cuda' if torch.cuda.is_available() else 'cpu'):
    all_entries = []
    with open(file_path, 'r') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                entry = json.loads(line.strip())
                wav_path = os.path.join('audio', os.path.basename(entry['wav_path']))
                video_path = os.path.join('video', os.path.splitext(os.path.basename(entry['wav_path']))[0] + '.avi')
                if os.path.exists(wav_path) and os.path.exists(video_path):
                    all_entries.append({
                        'key': entry['key'],
                        'wav_path': wav_path,
                        'video_path': video_path,
                        'label': int(entry['label']),
                        'Frenchay': entry['Frenchay']
                    })
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON line: {line}")
                continue

    train_entries, val_entries = train_test_split(all_entries, test_size=0.15, stratify=[e['label'] for e in all_entries], random_state=42)
    print(f"Training entries: {len(train_entries)}, Validation entries: {len(val_entries)}")

    train_dataset = AVPairDataset(train_entries, spec_dir='spectrograms', sr=sr, fixed_samples=fixed_samples, fps=fps, device='cpu')
    val_dataset = AVPairDataset(val_entries, spec_dir='spectrograms', sr=sr, fixed_samples=fixed_samples, fps=fps, device='cpu')
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)

    audio_cnn = AudioCNN().to(device)
    video_model = VideoModel().to(device)

    audio_optimizer = torch.optim.Adam(audio_cnn.parameters(), lr=0.001)
    video_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, video_model.parameters()), lr=0.0001)
    criterion = nn.CrossEntropyLoss()

    audio_metrics = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    video_metrics = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    best_audio_acc = 0.0
    best_video_acc = 0.0

    num_epochs = 5
    print("Training AudioCNN...")
    for epoch in range(num_epochs):
        audio_cnn.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        for spec, _, labels in train_loader:
            if spec is None or labels is None:
                continue
            spec, labels = spec.to(device), labels.to(device)
            audio_optimizer.zero_grad()
            outputs = audio_cnn(spec.unsqueeze(1))
            loss = criterion(outputs, labels)
            loss.backward()
            audio_optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_loss /= len(train_loader)
        train_acc = 100 * train_correct / train_total

        audio_cnn.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for spec, _, labels in val_loader:
                if spec is None or labels is None:
                    continue
                spec, labels = spec.to(device), labels.to(device)
                outputs = audio_cnn(spec.unsqueeze(1))
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_acc = 100 * val_correct / val_total

        audio_metrics['train_loss'].append(train_loss)
        audio_metrics['val_loss'].append(val_loss)
        audio_metrics['train_acc'].append(train_acc)
        audio_metrics['val_acc'].append(val_acc)

        print(f"Audio Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        if val_acc > best_audio_acc:
            best_audio_acc = val_acc
            torch.save(audio_cnn.state_dict(), 'best_audio_cnn.pth')
            print("Saved best AudioCNN model")

    print("Training VideoModel...")
    for epoch in range(num_epochs):
        video_model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        for _, video, labels in train_loader:
            if video is None or labels is None:
                continue
            video, labels = video.to(device), labels.to(device)
            video_optimizer.zero_grad()
            outputs = video_model(video)
            loss = criterion(outputs, labels)
            loss.backward()
            video_optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_loss /= len(train_loader)
        train_acc = 100 * train_correct / train_total

        video_model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for _, video, labels in val_loader:
                if video is None or labels is None:
                    continue
                video, labels = video.to(device), labels.to(device)
                outputs = video_model(video)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_acc = 100 * val_correct / val_total

        video_metrics['train_loss'].append(train_loss)
        video_metrics['val_loss'].append(val_loss)
        video_metrics['train_acc'].append(train_acc)
        video_metrics['val_acc'].append(val_acc)

        print(f"Video Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        if val_acc > best_video_acc:
            best_video_acc = val_acc
            torch.save(video_model.state_dict(), 'best_video_model.pth')
            print("Saved best VideoModel")

    plot_metrics(audio_metrics, video_metrics, num_epochs)

if __name__ == "__main__":
    train_av_models()

Skipping invalid JSON line: {"key": "S_M_00051_G2_task1_4_S00004", "wav_path": "data/S_M_00051_G2_task1_4_S00004.wav", "label": 1, "Frenchay": 114}4{"key": "S_M_00020_G4_task1_5_S00007", "wav_path": "data/S_M_00020_G4_task1_5_S00007.wav", "label": 1, "Frenchay": 92}

Training entries: 51404, Validation entries: 9072




Training AudioCNN...


# train val split version 2 -- too slow

In [None]:
# Load and align pair (unchanged)
def load_aligned_pair(wav_path, video_path, sr=16000, fixed_samples=16000, fps=25, device='cpu'):
    try:
        audio, sample_rate = torchaudio.load(wav_path)
        if sample_rate != sr:
            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=sr)(audio)
        if audio.shape[1] > fixed_samples:
            audio = audio[:, :fixed_samples]
        elif audio.shape[1] < fixed_samples:
            audio = torch.nn.functional.pad(audio, (0, fixed_samples - audio.shape[1]))

        audio_duration = fixed_samples / sr
        video_frames = int(audio_duration * fps)

        video, _, info = read_video(video_path, pts_unit='sec')
        video = video.float() / 255.0

        if video.shape[0] > video_frames:
            video = video[:video_frames]
        elif video.shape[0] < video_frames:
            video = torch.nn.functional.pad(video, (0, 0, 0, 0, 0, video_frames - video.shape[0]))

        video = video.permute(0, 3, 1, 2)
        video = torch.nn.functional.interpolate(video, size=(112, 112), mode='bilinear', align_corners=False)
        video = video.permute(1, 0, 2, 3)

        return audio.to(device), video.to(device)
    except Exception as e:
        raise RuntimeError(f"Error loading pair {wav_path} and {video_path}: {e}")

# Audio to spectrogram (unchanged)
def audio_to_spectrogram(audio, sr=16000, n_fft=512, hop_length=256, n_mels=128, fmax=8000, device='cuda' if torch.cuda.is_available() else 'cpu'):
    spec_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, f_max=fmax
    ).to(audio.device)

    mel = spec_transform(audio)
    mel = mel.mean(dim=0)
    mel_db = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=None)(mel)
    mel_db -= mel_db.max()
    return mel_db

# AudioCNN (unchanged)
class AudioCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(32 * 32 * 15, 128)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.bn1(self.conv1(x))))
        x = self.pool(nn.functional.relu(self.bn2(self.conv2(x))))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# VideoModel
class VideoModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet3d = video_models.r3d_18(pretrained=True)
        in_features = self.resnet3d.fc.in_features
        self.resnet3d.fc = nn.Linear(in_features, 4)  # 4 classes for training
        
        for param in self.resnet3d.parameters():
            param.requires_grad = False
        
        unfreeze_layers = ['layer3', 'layer4']
        for name, child in self.resnet3d.named_children():
            if name in unfreeze_layers:
                for param in child.parameters():
                    param.requires_grad = True

    def forward(self, x):
        return self.resnet3d(x)

# Custom Dataset
class AVPairDataset(Dataset):
    def __init__(self, entries, sr=16000, fixed_samples=16000, fps=25, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.entries = entries
        self.sr = sr
        self.fixed_samples = fixed_samples
        self.fps = fps
        self.device = device

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        entry = self.entries[idx]
        try:
            audio, video = load_aligned_pair(entry['wav_path'], entry['video_path'], self.sr, self.fixed_samples, self.fps, self.device)
            spec = audio_to_spectrogram(audio, sr=self.sr, n_fft=512, hop_length=256, n_mels=128, fmax=8000, device=self.device)
            label = torch.tensor(entry['label'], dtype=torch.long)
            return spec, video, label
        except Exception as e:
            print(f"Error loading {entry['wav_path']}: {e}")
            return None, None, None

# Plotting function
def plot_metrics(audio_metrics, video_metrics, num_epochs):
    epochs = range(1, num_epochs + 1)
    
    plt.figure(figsize=(12, 10))
    
    # Audio Loss
    plt.subplot(2, 2, 1)
    plt.plot(epochs, audio_metrics['train_loss'], label='Train Loss', marker='o')
    plt.plot(epochs, audio_metrics['val_loss'], label='Val Loss', marker='o')
    plt.title('AudioCNN Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    # Audio Accuracy
    plt.subplot(2, 2, 2)
    plt.plot(epochs, audio_metrics['train_acc'], label='Train Acc', marker='o')
    plt.plot(epochs, audio_metrics['val_acc'], label='Val Acc', marker='o')
    plt.title('AudioCNN Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)
    
    # Video Loss
    plt.subplot(2, 2, 3)
    plt.plot(epochs, video_metrics['train_loss'], label='Train Loss', marker='o')
    plt.plot(epochs, video_metrics['val_loss'], label='Val Loss', marker='o')
    plt.title('VideoModel Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    # Video Accuracy
    plt.subplot(2, 2, 4)
    plt.plot(epochs, video_metrics['train_acc'], label='Train Acc', marker='o')
    plt.plot(epochs, video_metrics['val_acc'], label='Val Acc', marker='o')
    plt.title('VideoModel Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig('training_metrics.png')
    plt.show()

def train_av_models(file_path='data_copy.list', sr=16000, fixed_samples=16000, fps=25, device='cuda' if torch.cuda.is_available() else 'cpu'):
    # Load entries
    all_entries = []
    with open(file_path, 'r') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                entry = json.loads(line.strip())
                wav_path = os.path.join('audio', os.path.basename(entry['wav_path']))
                video_path = os.path.join('video', os.path.splitext(os.path.basename(entry['wav_path']))[0] + '.avi')
                if os.path.exists(wav_path) and os.path.exists(video_path):
                    all_entries.append({
                        'key': entry['key'],
                        'wav_path': wav_path,
                        'video_path': video_path,
                        'label': int(entry['label']),
                        'Frenchay': entry['Frenchay']
                    })
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON line: {line}")
                continue

    # Train-val split
    train_entries, val_entries = train_test_split(all_entries, test_size=0.15, stratify=[e['label'] for e in all_entries], random_state=42)
    print(f"Training entries: {len(train_entries)}, Validation entries: {len(val_entries)}")

    # Create datasets
    train_dataset = AVPairDataset(train_entries, sr, fixed_samples, fps, device)
    val_dataset = AVPairDataset(val_entries, sr, fixed_samples, fps, device)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)

    # Initialize models
    audio_cnn = AudioCNN().to(device)
    video_model = VideoModel().to(device)

    # Setup optimizers and loss
    audio_optimizer = torch.optim.Adam(audio_cnn.parameters(), lr=0.001)
    video_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, video_model.parameters()), lr=0.0001)
    criterion = nn.CrossEntropyLoss()

    # Metrics storage
    audio_metrics = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    video_metrics = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    best_audio_acc = 0.0
    best_video_acc = 0.0

    # Train audio model
    num_epochs = 5
    print("Training AudioCNN...")
    for epoch in range(num_epochs):
        audio_cnn.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        for spec, _, labels in train_loader:
            if spec is None or labels is None:
                continue
            spec, labels = spec.to(device), labels.to(device)
            audio_optimizer.zero_grad()
            outputs = audio_cnn(spec.unsqueeze(1))
            loss = criterion(outputs, labels)
            loss.backward()
            audio_optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_loss /= len(train_loader)
        train_acc = 100 * train_correct / train_total

        audio_cnn.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for spec, _, labels in val_loader:
                if spec is None or labels is None:
                    continue
                spec, labels = spec.to(device), labels.to(device)
                outputs = audio_cnn(spec.unsqueeze(1))
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_acc = 100 * val_correct / val_total

        audio_metrics['train_loss'].append(train_loss)
        audio_metrics['val_loss'].append(val_loss)
        audio_metrics['train_acc'].append(train_acc)
        audio_metrics['val_acc'].append(val_acc)

        print(f"Audio Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        if val_acc > best_audio_acc:
            best_audio_acc = val_acc
            torch.save(audio_cnn.state_dict(), 'best_audio_cnn.pth')
            print("Saved best AudioCNN model")

    # Train video model
    print("Training VideoModel...")
    for epoch in range(num_epochs):
        video_model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        for _, video, labels in train_loader:
            if video is None or labels is None:
                continue
            video, labels = video.to(device), labels.to(device)
            video_optimizer.zero_grad()
            outputs = video_model(video)
            loss = criterion(outputs, labels)
            loss.backward()
            video_optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_loss /= len(train_loader)
        train_acc = 100 * train_correct / train_total

        video_model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for _, video, labels in val_loader:
                if video is None or labels is None:
                    continue
                video, labels = video.to(device), labels.to(device)
                outputs = video_model(video)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_acc = 100 * val_correct / val_total

        video_metrics['train_loss'].append(train_loss)
        video_metrics['val_loss'].append(val_loss)
        video_metrics['train_acc'].append(train_acc)
        video_metrics['val_acc'].append(val_acc)

        print(f"Video Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        if val_acc > best_video_acc:
            best_video_acc = val_acc
            torch.save(video_model.state_dict(), 'best_video_model.pth')
            print("Saved best VideoModel")

    # Plot metrics
    plot_metrics(audio_metrics, video_metrics, num_epochs)

if __name__ == "__main__":
    train_av_models()

Skipping invalid JSON line: {"key": "S_M_00051_G2_task1_4_S00004", "wav_path": "data/S_M_00051_G2_task1_4_S00004.wav", "label": 1, "Frenchay": 114}4{"key": "S_M_00020_G4_task1_5_S00007", "wav_path": "data/S_M_00020_G4_task1_5_S00007.wav", "label": 1, "Frenchay": 92}

Training entries: 51404, Validation entries: 9072


Downloading: "https://download.pytorch.org/models/r3d_18-b3b3357e.pth" to /home/ucloud/.cache/torch/hub/checkpoints/r3d_18-b3b3357e.pth
100%|██████████| 127M/127M [00:00<00:00, 370MB/s] 


Training AudioCNN...


# original one

In [7]:
# Load and align pair to fixed length
def load_aligned_pair(wav_path, video_path, sr=16000, fixed_samples=16000, fps=25, device='cpu'):
    try:
        # Load audio
        audio, sample_rate = torchaudio.load(wav_path)
        if sample_rate != sr:
            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=sr)(audio)
        if audio.shape[1] > fixed_samples:
            audio = audio[:, :fixed_samples]
        elif audio.shape[1] < fixed_samples:
            audio = torch.nn.functional.pad(audio, (0, fixed_samples - audio.shape[1]))

        # Calculate target number of video frames
        audio_duration = fixed_samples / sr  # 1s
        video_frames = int(audio_duration * fps)  # 25 frames

        # Load video
        video, _, info = read_video(video_path, pts_unit='sec')
        video = video.float() / 255.0  # Normalize to [0, 1]

        # Align video frames
        if video.shape[0] > video_frames:
            video = video[:video_frames]
        elif video.shape[0] < video_frames:
            video = torch.nn.functional.pad(video, (0, 0, 0, 0, 0, video_frames - video.shape[0]))

        # Resize video to 112x112 (standard for 3D ResNet)
        video = video.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
        video = torch.nn.functional.interpolate(video, size=(112, 112), mode='bilinear', align_corners=False)

        # Reorder dimensions for 3D ResNet: [T, C, H, W] -> [C, T, H, W]
        video = video.permute(1, 0, 2, 3)  # [C, T, H, W], e.g., [3, 25, 112, 112]

        return audio.to(device), video.to(device)
    except Exception as e:
        raise RuntimeError(f"Error loading pair {wav_path} and {video_path}: {e}")

# Audio to spectrogram with dB conversion and normalization
def audio_to_spectrogram(audio, sr=16000, n_fft=512, hop_length=256, n_mels=128, fmax=8000, device='cuda' if torch.cuda.is_available() else 'cpu'):
    spec_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, f_max=fmax
    ).to(audio.device)

    # Compute mel-spectrogram
    mel = spec_transform(audio)  # [channels, n_mels, time_frames], e.g., [1, 128, 62]
    mel = mel.mean(dim=0)  # Average across channels, e.g., [128, 62]

    # Convert to dB scale
    mel_db = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=None)(mel)  # [128, 62]

    # Normalize by subtracting the maximum value
    mel_db -= mel_db.max()  # [128, 62]

    return mel_db

#updated version with batchnorm
class AudioCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)  # BN for 16 channels
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)  # BN for 32 channels
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(32 * 32 * 15, 128)  # For [128, 62] input: [32, 32, 15] after pooling

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.bn1(self.conv1(x))))  # [1, 128, 62] -> [16, 64, 31]
        x = self.pool(nn.functional.relu(self.bn2(self.conv2(x))))  # [16, 64, 31] -> [32, 32, 15]
        x = x.view(x.size(0), -1)  # Flatten: [32 * 32 * 15 = 15360]
        x = self.fc(x)  # [15360] -> [128]
        return x
    
# Process and save to HDF5 with video embeddings
def process_and_save_to_hdf5(file_path='data.list', output_dir='features', sr=16000, fixed_samples=16000, fps=25, batch_size=10, device='cuda' if torch.cuda.is_available() else 'cpu'):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    #n = 0
    skipped = 0
    processed = 0
    batch_audio_specs = []
    batch_video_embeds = []
    batch_metadata = []

    # Initialize models
    audio_cnn = AudioCNN().to(device)
    video_model = r3d_18(pretrained=True).to(device)  # 3D ResNet-18

    # Replace the FC layer to output [512] instead of [400]
    in_features = video_model.fc.in_features  # Should be 512
    video_model.fc = torch.nn.Linear(in_features, 512).to(device)  # Change output to 512
    video_model.eval()  # Set to evaluation mode

    with open(file_path, 'r') as f:
        for line in f:
            #n += 1
            #if n > 10:
                #break

            if not line.strip():
                continue
            try:
                entry = json.loads(line.strip())
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON line: {line}")
                skipped += 1
                continue

            key = entry['key']
            wav_path = os.path.join('audio', os.path.basename(entry['wav_path']))
            video_path = os.path.join('video', os.path.splitext(os.path.basename(entry['wav_path']))[0] + '.avi')

            if not os.path.exists(wav_path) or not os.path.exists(video_path):
                print(f"Warning: Missing {wav_path} or {video_path}, skipping pair.")
                skipped += 1
                continue

            try:
                audio, video = load_aligned_pair(wav_path, video_path, sr, fixed_samples, fps, device)

                # Audio processing
                spec = audio_to_spectrogram(audio, sr=sr, n_fft=512, hop_length=256, n_mels=128, fmax=8000, device=device)  # [128, 62]
                spec_embedding = audio_cnn(spec.unsqueeze(0).unsqueeze(0)).detach().cpu().numpy()  # [1, 128]

                # Video processing with 3D ResNet
                with torch.no_grad():
                    video_embed = video_model(video.unsqueeze(0)).detach().cpu().numpy()  # [1, 512]

                batch_audio_specs.append(spec_embedding)
                batch_video_embeds.append(video_embed)
                batch_metadata.append({
                    'key': key,
                    'label': entry['label'],
                    'Frenchay': entry['Frenchay'],
                    'wav_path': wav_path,
                    'video_path': video_path
                })

                processed += 1
                if processed % 100 == 0:
                    print(f"Processed {processed} items so far.")

                if len(batch_audio_specs) >= batch_size:
                    batch_num = processed // batch_size
                    save_batch(batch_audio_specs, batch_video_embeds, batch_metadata, batch_num, output_dir)
                    batch_audio_specs, batch_video_embeds, batch_metadata = [], [], []

                del audio, video, spec, spec_embedding, video_embed
                torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error processing {wav_path}: {e}")
                skipped += 1
                continue

        if batch_audio_specs:
            batch_num = (processed // batch_size) + 1
            save_batch(batch_audio_specs, batch_video_embeds, batch_metadata, batch_num, output_dir)

    print(f"Finished. Processed: {processed}, Skipped: {skipped}")
    return skipped

def save_batch(audio_specs, video_embeds, metadata, batch_num, output_dir):
    with h5py.File(os.path.join(output_dir, f'features_batch_{batch_num}.h5'), 'w') as f:
        f.create_dataset('audio_embeddings', data=np.array(audio_specs), compression='gzip')
        f.create_dataset('video_embeddings', data=np.array(video_embeds), compression='gzip')
        metadata_dtype = [('key', 'S50'), ('label', 'S10'), ('Frenchay', 'f4'), ('wav_path', 'S100'), ('video_path', 'S100')]
        f.create_dataset('metadata', data=np.array([(m['key'], m['label'], m['Frenchay'], m['wav_path'], m['video_path']) for m in metadata], dtype=metadata_dtype))



In [8]:
if __name__ == "__main__":
    process_and_save_to_hdf5(file_path='data_copy.list', output_dir='features')

Downloading: "https://download.pytorch.org/models/r3d_18-b3b3357e.pth" to /home/ucloud/.cache/torch/hub/checkpoints/r3d_18-b3b3357e.pth
100%|██████████| 127M/127M [00:00<00:00, 254MB/s]  


Processed 100 items so far.
Processed 200 items so far.
Processed 300 items so far.
Processed 400 items so far.
Processed 500 items so far.
Processed 600 items so far.
Processed 700 items so far.
Processed 800 items so far.
Processed 900 items so far.
Processed 1000 items so far.
Processed 1100 items so far.
Processed 1200 items so far.
Processed 1300 items so far.
Processed 1400 items so far.
Processed 1500 items so far.
Processed 1600 items so far.
Processed 1700 items so far.
Skipping invalid JSON line: {"key": "S_M_00051_G2_task1_4_S00004", "wav_path": "data/S_M_00051_G2_task1_4_S00004.wav", "label": 1, "Frenchay": 114}4{"key": "S_M_00020_G4_task1_5_S00007", "wav_path": "data/S_M_00020_G4_task1_5_S00007.wav", "label": 1, "Frenchay": 92}

Processed 1800 items so far.
Processed 1900 items so far.
Processed 2000 items so far.
Processed 2100 items so far.
Processed 2200 items so far.
Processed 2300 items so far.
Processed 2400 items so far.
Processed 2500 items so far.
Processed 2600 i

In [9]:
# Path to the first HDF5 file
hdf5_dir = 'features'  # Directory where HDF5 files are stored
hdf5_path = os.path.join(hdf5_dir, 'features_batch_1.h5')

# Load the first batch
with h5py.File(hdf5_path, 'r') as f:
    # Load audio embeddings
    audio_embeddings = f['audio_embeddings'][:]  # Shape: [10, 1, 128]
    # Load video embeddings
    video_embeddings = f['video_embeddings'][:]  # Shape: [10, 1, 512]
    # Load metadata
    metadata = f['metadata'][:]

# Remove the singleton dimension for easier inspection
audio_embeddings = audio_embeddings.squeeze(1)  # Shape: [10, 128]
video_embeddings = video_embeddings.squeeze(1)  # Shape: [10, 512]

# Print shapes and a few examples
print("Audio Embeddings Shape:", audio_embeddings.shape)
print("Video Embeddings Shape:", video_embeddings.shape)
print("\nFirst Audio Embedding (clip 0):", audio_embeddings[0][:5], "...")  # First 5 values
print("First Video Embedding (clip 0):", video_embeddings[0][:5], "...")  # First 5 values

# Print metadata for all clips in the batch
print("\nMetadata for the batch:")
for i, meta in enumerate(metadata):
    print(f"Clip {i}:")
    print(f"  Key: {meta['key'].decode('utf-8')}")
    print(f"  Label: {meta['label'].decode('utf-8')}")
    print(f"  Frenchay: {meta['Frenchay']}")
    print(f"  Wav Path: {meta['wav_path'].decode('utf-8')}")
    print(f"  Video Path: {meta['video_path'].decode('utf-8')}")

Audio Embeddings Shape: (10, 128)
Video Embeddings Shape: (10, 512)

First Audio Embedding (clip 0): [-1.8582482 -1.2904636  4.4079876  6.493519   3.718699 ] ...
First Video Embedding (clip 0): [-0.23355688 -0.62655675  0.37452427 -0.09834773 -0.40472502] ...

Metadata for the batch:
Clip 0:
  Key: S_M_00051_G2_task4_1_S00000
  Label: 1
  Frenchay: 114.0
  Wav Path: audio/S_M_00051_G2_task4_1_S00000.wav
  Video Path: video/S_M_00051_G2_task4_1_S00000.avi
Clip 1:
  Key: S_M_00004_G2_task4_3_S00011
  Label: 1
  Frenchay: 95.0
  Wav Path: audio/S_M_00004_G2_task4_3_S00011.wav
  Video Path: video/S_M_00004_G2_task4_3_S00011.avi
Clip 2:
  Key: N_M_10015_G2_task1_4_S00002
  Label: 0
  Frenchay: 116.0
  Wav Path: audio/N_M_10015_G2_task1_4_S00002.wav
  Video Path: video/N_M_10015_G2_task1_4_S00002.avi
Clip 3:
  Key: N_M_10010_G3_task8_1_S00001
  Label: 0
  Frenchay: 116.0
  Wav Path: audio/N_M_10010_G3_task8_1_S00001.wav
  Video Path: video/N_M_10010_G3_task8_1_S00001.avi
Clip 4:
  Key: S_M_0

# --- other early tries - failed 1

In [6]:
def load_audio(wav_path, sr=16000, device='cpu'):
    audio, sample_rate = torchaudio.load(wav_path)
    if sample_rate != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=sr).to(device)
        audio = resampler(audio)
    return audio.to(device)


def load_video(video_path, device='cpu'):
    try:
        video, _, info = read_video(video_path, pts_unit='sec')
        video = video.permute(0, 3, 1, 2).float() / 255.0  # [T, H, W, C] → [T, C, H, W]
        return video.to(device)
    except Exception as e:
        print(f"Error loading video {video_path}: {e}")
        return None


def load_audio_video_from_list(file_path='data.list', sr=16000, device='cuda' if torch.cuda.is_available() else 'cpu'):
    audio_data = []
    video_data = []
    metadata = []
    n=0
    with open(file_path, 'r') as f:
        for line in f:

            if not line.strip():
                continue
            try:
                entry = json.loads(line.strip())
            except json.decoder.JSONDecodeError:
                print(f"Error decoding JSON: {line}")
                continue


            key = entry['key']
            original_wav_path = entry['wav_path']  # e.g., 'data/S_M_00013_G4_task2_1_S00002.wav'
            label = entry['label']
            frenchay = entry['Frenchay']

            # Build actual paths
            filename = os.path.basename(original_wav_path)  # Just the filename
            wav_path = os.path.join('audio', filename)
            video_filename = os.path.splitext(filename)[0] + '.avi'
            video_path = os.path.join('video', video_filename)
            
            if not os.path.exists(wav_path):
                print(f"Warning: {wav_path} not found, skipping.")
                continue
            k=0
            if not os.path.exists(video_path):
                #print(f"Warning: {video_path} not found, skipping.")
                k+=1
                continue

            try:
                audio = load_audio(wav_path, sr, device)
                video = load_video(video_path, device)

                if video is None:
                    continue

                audio_data.append(audio)
                video_data.append(video)
                metadata.append({
                    'key': key,
                    'label': label,
                    'Frenchay': frenchay,
                    'wav_path': wav_path,
                    'video_path': video_path
                })

                
                # Clear GPU cache after processing each video to manage memory
                torch.cuda.empty_cache()

                # Optional: print the progress for debugging purposes
                n += 1
                if n % 100 == 0:
                    print(f"Processed {n} items so far.")

            except Exception as e:
                print(f"Error processing {filename}: {e}")

    return audio_data, video_data, metadata,k

In [5]:
# Load the audio and video data
file_path = 'data_copy.list'
audio_data, video_data, metadata,k = load_audio_video_from_list(file_path, sr=16000, device='cuda' if torch.cuda.is_available() else 'cpu')

NameError: name 'load_audio_video_from_list' is not defined

In [15]:
#print the shape of the first audio_data and first video data
print(f"Audio data shape: {audio_data[0].shape}")
print(f"Video data shape: {video_data[0].shape}")


Audio data shape: torch.Size([1, 19568])
Video data shape: torch.Size([76, 3, 96, 96])


# --- other early tries - failed 2

In [9]:
# Load audio and video functions (unchanged)
def load_audio(wav_path, sr=16000, device='cpu'):
    audio, sample_rate = torchaudio.load(wav_path)
    if sample_rate != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=sr).to(device)
        audio = resampler(audio)
    return audio.to(device)

def load_video(video_path, device='cpu'):
    try:
        video, _, info = read_video(video_path, pts_unit='sec')
        video = video.permute(0, 3, 1, 2).float() / 255.0  # [T, H, W, C] → [T, C, H, W]
        return video.to(device)
    except Exception as e:
        print(f"Error loading video {video_path}: {e}")
        return None

# Dummy feature extraction function (replace with your actual model/logic)
def extract_features(audio, video, device):
    # Example: return dummy features (replace with real feature extraction)
    audio_feature = torch.mean(audio, dim=-1).cpu().numpy()  # Mean of audio signal
    video_feature = torch.mean(video, dim=0).flatten().cpu().numpy()  # Mean over time, flattened
    return audio_feature, video_feature

# Process and save to NumPy files in batches
def process_and_save_to_numpy(file_path='data.list', output_dir='features', sr=16000, batch_size=10, device='cuda' if torch.cuda.is_available() else 'cpu'):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)  # Create output directory if it doesn't exist
    
    skipped = 0
    processed = 0
    batch_audio_features = []
    batch_video_features = []
    batch_metadata = []

    with open(file_path, 'r') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                entry = json.loads(line.strip())
            except json.decoder.JSONDecodeError:
                print(f"Error decoding JSON: {line}")
                continue

            key = entry['key']
            original_wav_path = entry['wav_path']
            label = entry['label']
            frenchay = entry['Frenchay']

            # Build actual paths
            filename = os.path.basename(original_wav_path)
            wav_path = os.path.join('audio', filename)
            video_filename = os.path.splitext(filename)[0] + '.avi'
            video_path = os.path.join('video', video_filename)

            if not os.path.exists(wav_path):
                print(f"Warning: {wav_path} not found, skipping.")
                continue
            if not os.path.exists(video_path):
                #print(f"Warning: {video_path} not found, skipping.")
                skipped += 1
                continue

            try:
                # Load audio and video
                audio = load_audio(wav_path, sr, device)
                video = load_video(video_path, device)
                if video is None:
                    continue

                # Extract features
                audio_feature, video_feature = extract_features(audio, video, device)

                # Append to batch
                batch_audio_features.append(audio_feature)
                batch_video_features.append(video_feature)
                batch_metadata.append({
                    'key': key,
                    'label': label,
                    'Frenchay': frenchay,
                    'wav_path': wav_path,
                    'video_path': video_path
                })

                processed += 1
                if processed % 100 == 0:
                    print(f"Processed {processed} items so far.")

                # Save batch when it reaches batch_size
                if len(batch_audio_features) >= batch_size:
                    batch_num = processed // batch_size
                    save_batch(batch_audio_features, batch_video_features, batch_metadata, batch_num, output_dir)
                    # Clear batch lists
                    batch_audio_features = []
                    batch_video_features = []
                    batch_metadata = []

                # Clear GPU memory
                del audio, video
                torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error processing {filename}: {e}")

        # Save any remaining items in the last batch
        if batch_audio_features:
            batch_num = (processed // batch_size) + 1
            save_batch(batch_audio_features, batch_video_features, batch_metadata, batch_num, output_dir)

    print(f"Finished processing. Total processed: {processed}, Skipped: {skipped}")
    return skipped

# Helper function to save a batch to NumPy files
def save_batch(audio_features, video_features, metadata, batch_num, output_dir):
    # Convert lists to NumPy arrays
    audio_features_np = np.array(audio_features, dtype=object)  # Use object dtype if shapes vary
    video_features_np = np.array(video_features, dtype=object)
    
    # Save to .npy files
    np.save(os.path.join(output_dir, f'audio_features_batch_{batch_num}.npy'), audio_features_np)
    np.save(os.path.join(output_dir, f'video_features_batch_{batch_num}.npy'), video_features_np)
    
    # Save metadata as a separate .npy file (or JSON if preferred)
    np.save(os.path.join(output_dir, f'metadata_batch_{batch_num}.npy'), np.array(metadata, dtype=object))
    #print(f"Saved batch {batch_num} to {output_dir}")



In [None]:
# Run the function
if __name__ == "__main__":
    process_and_save_to_numpy(file_path='data_copy.list', output_dir='features')

In [11]:
batch_num = 1
audio_features = np.load(f'features/audio_features_batch_{batch_num}.npy', allow_pickle=True)
video_features = np.load(f'features/video_features_batch_{batch_num}.npy', allow_pickle=True)
metadata = np.load(f'features/metadata_batch_{batch_num}.npy', allow_pickle=True)

for i in range(len(metadata)):
    print(f"Key: {metadata[i]['key']}, Audio feature shape: {audio_features[i].shape}, Video feature shape: {video_features[i].shape}")

Key: S_M_00051_G2_task4_1_S00000, Audio feature shape: (1,), Video feature shape: (27648,)
Key: S_M_00004_G2_task4_3_S00011, Audio feature shape: (1,), Video feature shape: (27648,)
Key: N_M_10015_G2_task1_4_S00002, Audio feature shape: (1,), Video feature shape: (27648,)
Key: N_M_10010_G3_task8_1_S00001, Audio feature shape: (1,), Video feature shape: (27648,)
Key: S_M_00013_G4_task2_1_S00002, Audio feature shape: (1,), Video feature shape: (27648,)
Key: S_M_00004_G5_task4_3_S00002, Audio feature shape: (1,), Video feature shape: (27648,)
Key: S_M_00044_G2_task7_2_S00008, Audio feature shape: (1,), Video feature shape: (27648,)
Key: S_M_00051_G2_task2_1_S00004, Audio feature shape: (1,), Video feature shape: (27648,)
Key: S_F_00010_G4_task6_1_S00009, Audio feature shape: (1,), Video feature shape: (27648,)
Key: N_M_10003_G4_task1_4_S00004, Audio feature shape: (1,), Video feature shape: (27648,)


In [14]:
print(audio_features[0])

[9.6315641712863e-05]


In [17]:
print(video_features[0].shape)

(27648,)


# --- other early tries - failed 3

In [3]:
# Load audio and video functions (unchanged)
def load_audio(wav_path, sr=16000, device='cpu'):
    audio, sample_rate = torchaudio.load(wav_path)
    if sample_rate != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=sr).to(device)
        audio = resampler(audio)
    return audio.to(device)

def load_video(video_path, device='cpu'):
    try:
        video, _, info = read_video(video_path, pts_unit='sec')
        video = video.permute(0, 3, 1, 2).float() / 255.0  # [T, H, W, C] → [T, C, H, W]
        return video.to(device)
    except Exception as e:
        print(f"Error loading video {video_path}: {e}")
        return None

# Process and save to HDF5 files in batches
def process_and_save_to_hdf5(file_path='data.list', output_dir='raw_features', sr=16000, batch_size=10, device='cuda' if torch.cuda.is_available() else 'cpu'):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)  # Create output directory if it doesn’t exist
    
    skipped = 0
    processed = 0
    batch_audio = []
    batch_video = []
    batch_metadata = []

    with open(file_path, 'r') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                entry = json.loads(line.strip())
            except json.decoder.JSONDecodeError:
                print(f"Error decoding JSON: {line}")
                continue

            key = entry['key']
            original_wav_path = entry['wav_path']
            label = entry['label']
            frenchay = entry['Frenchay']

            # Build actual paths
            filename = os.path.basename(original_wav_path)
            wav_path = os.path.join('audio', filename)
            video_filename = os.path.splitext(filename)[0] + '.avi'
            video_path = os.path.join('video', video_filename)

            # Check if both files exist; skip if either is missing
            if not os.path.exists(wav_path):
                #print(f"Warning: {wav_path} not found, skipping pair.")
                skipped += 1
                continue
            if not os.path.exists(video_path):
                #print(f"Warning: {video_path} not found, skipping pair.")
                skipped += 1
                continue

            try:
                # Load audio and video
                audio = load_audio(wav_path, sr, device)
                video = load_video(video_path, device)
                if video is None:
                    print(f"Warning: Failed to load {video_path}, skipping pair.")
                    skipped += 1
                    continue

                # Move tensors to CPU and convert to NumPy for saving
                audio_np = audio.cpu().numpy()  # Shape: [channels, time]
                video_np = video.cpu().numpy()  # Shape: [T, C, H, W]

                # Append to batch
                batch_audio.append(audio_np)
                batch_video.append(video_np)
                batch_metadata.append({
                    'key': key,
                    'label': label,
                    'Frenchay': frenchay,
                    'wav_path': wav_path,
                    'video_path': video_path
                })

                processed += 1
                if processed % 100 == 0:
                    print(f"Processed {processed} items so far.")

                # Save batch when it reaches batch_size
                if len(batch_audio) >= batch_size:
                    batch_num = processed // batch_size
                    save_batch(batch_audio, batch_video, batch_metadata, batch_num, output_dir)
                    # Clear batch lists
                    batch_audio = []
                    batch_video = []
                    batch_metadata = []

                # Clear GPU memory
                del audio, video, audio_np, video_np
                torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error processing {filename}: {e}")
                skipped += 1

        # Save any remaining items in the last batch
        if batch_audio:
            batch_num = (processed // batch_size) + 1
            save_batch(batch_audio, batch_video, batch_metadata, batch_num, output_dir)

    print(f"Finished processing. Total processed: {processed}, Skipped: {skipped}")
    return skipped

# Helper function to save a batch to an HDF5 file
def save_batch(audio_data, video_data, metadata, batch_num, output_dir):
    hdf5_path = os.path.join(output_dir, f'features_batch_{batch_num}.h5')
    with h5py.File(hdf5_path, 'w') as f:
        # Create groups for audio and video
        audio_group = f.create_group('audio')
        video_group = f.create_group('video')

        # Save each audio and video sample in the batch
        for i, (audio, video) in enumerate(zip(audio_data, video_data)):
            audio_group.create_dataset(str(i), data=audio, compression='gzip')
            video_group.create_dataset(str(i), data=video, compression='gzip')

        # Save metadata as a dataset
        # Convert metadata dicts to a structured array for HDF5 compatibility
        metadata_dtype = [('key', 'S50'), ('label', 'S10'), ('Frenchay', 'f4'), ('wav_path', 'S100'), ('video_path', 'S100')]
        metadata_array = np.array(
            [(m['key'], m['label'], m['Frenchay'], m['wav_path'], m['video_path']) for m in metadata],
            dtype=metadata_dtype
        )
        f.create_dataset('metadata', data=metadata_array)

    #print(f"Saved batch {batch_num} to {hdf5_path}")



In [None]:
# Run the function
if __name__ == "__main__":
    process_and_save_to_hdf5(file_path='data_copy.list', output_dir='raw_features')

Processed 100 items so far.
Processed 200 items so far.
Processed 300 items so far.
Processed 400 items so far.
Processed 500 items so far.
Processed 600 items so far.
Processed 700 items so far.


In [6]:
# Load a specific batch
batch_num = 20
hdf5_path = f'raw_features/features_batch_{batch_num}.h5'
with h5py.File(hdf5_path, 'r') as f:
    # Access audio and video for a specific sample
    audio_0 = f['audio']['0'][:]  # NumPy array, e.g., [1, 16000]
    video_0 = f['video']['0'][:]  # NumPy array, e.g., [100, 3, 224, 224]
    
    # Convert back to PyTorch tensors if needed
    audio_tensor = torch.from_numpy(audio_0)
    video_tensor = torch.from_numpy(video_0)
    
    # Access metadata
    metadata = f['metadata'][:]
    print(f"Sample 0: Key: {metadata['key'][0].decode()}, Label: {metadata['label'][0].decode()}, Frenchay: {metadata['Frenchay'][0]}")
    print(f"Audio shape: {audio_tensor.shape}, Video shape: {video_tensor.shape}")

Sample 0: Key: S_M_00047_G4_task4_1_S00005, Label: 1, Frenchay: 91.0
Audio shape: torch.Size([1, 15744]), Video shape: torch.Size([69, 3, 96, 96])
