In [1]:
# Install into the kernel's own environment (%pip, not !pip), quietly, with pinned
# versions so Restart & Run All reproduces the same stack. torch 2.7.0 is the version
# the original install log resolved; torchaudio is pinned to its matching release.
%pip install -q torch==2.7.0 torchaudio==2.7.0

Collecting torch
  Downloading torch-2.7.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting filelock (from torch)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2025.3.2-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collec

In [3]:
import os
import json
import torch
import torchaudio
import numpy as np

In [4]:
# Load and align pair (audio only)
def load_aligned_pair(wav_path, sr=16000, fixed_samples=16000, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load an audio file, resample it to ``sr`` and force it to ``fixed_samples`` samples.

    Args:
        wav_path: Path to an audio file readable by ``torchaudio.load``.
        sr: Target sample rate in Hz; the clip is resampled only if its native rate differs.
        fixed_samples: Output length in samples along the last (time) axis. Longer clips
            keep only their first ``fixed_samples`` samples; shorter clips are right-padded
            with zeros.
        device: Device the returned tensor is moved to (CUDA when available at definition time).

    Returns:
        The waveform tensor as returned by ``torchaudio.load`` (channels-first by default),
        with its last axis trimmed/padded to ``fixed_samples``, on ``device``.

    Raises:
        RuntimeError: If loading or processing fails. The underlying exception is chained
            as ``__cause__`` so the real error stays visible in tracebacks.
    """
    try:
        audio, sample_rate = torchaudio.load(wav_path)
        if sample_rate != sr:
            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=sr)(audio)
        n_samples = audio.shape[-1]
        if n_samples > fixed_samples:
            audio = audio[..., :fixed_samples]
        elif n_samples < fixed_samples:
            # F.pad's (left, right) pair applies to the last axis; pad value defaults to 0.
            audio = torch.nn.functional.pad(audio, (0, fixed_samples - n_samples))
        return audio.to(device)
    except Exception as e:
        # Chain the original exception so its type/traceback is not lost.
        raise RuntimeError(f"Error loading {wav_path}: {e}") from e

# Audio to spectrogram
def audio_to_spectrogram(audio, sr=16000, n_fft=512, hop_length=256, n_mels=128, fmax=8000, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Convert a waveform into a peak-normalised log-mel spectrogram.

    Args:
        audio: Waveform tensor, either ``(channels, samples)`` or ``(samples,)``.
        sr: Sample rate of ``audio`` in Hz.
        n_fft: FFT size for the underlying STFT.
        hop_length: Hop between successive STFT frames, in samples.
        n_mels: Number of mel bands.
        fmax: Upper frequency of the mel filterbank, in Hz.
        device: Unused; kept for backward compatibility. The transform is moved to
            ``audio.device`` instead.

    Returns:
        Tensor of shape ``(n_mels, frames)`` in dB, shifted so its maximum is 0 dB.
        If ``audio`` has a leading channel axis, the mel power is averaged over it.
    """
    spec_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, f_max=fmax
    ).to(audio.device)
    # MelSpectrogram keeps leading dims: (..., n_mels, frames).
    mel = spec_transform(audio)
    # Average over the leading (channel) axis only when one exists. An unconditional
    # mean(dim=0) would collapse the mel axis for 1-D (samples,) input.
    if mel.dim() > 2:
        mel = mel.mean(dim=0)
    # MelSpectrogram's default power=2.0 yields a power spectrogram, hence stype='power'.
    mel_db = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=None)(mel)
    return mel_db - mel_db.max()

# Precompute spectrograms
def _read_manifest(file_path, audio_dir='audio', video_dir='video'):
    """Parse a JSON-lines manifest, keeping entries whose audio and ``.avi`` video both exist.

    Only the basename of each entry's ``wav_path`` is used. Audio is looked up under
    ``audio_dir`` and the same-stem ``.avi`` file under ``video_dir``. Blank lines are
    ignored. Invalid JSON and entries with missing or ill-typed fields are reported and
    skipped rather than aborting the run.
    """
    entries = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            try:
                entry = json.loads(stripped)
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON line: {line}")
                continue
            try:
                wav_name = os.path.basename(entry['wav_path'])
                wav_path = os.path.join(audio_dir, wav_name)
                video_path = os.path.join(video_dir, os.path.splitext(wav_name)[0] + '.avi')
                if os.path.exists(wav_path) and os.path.exists(video_path):
                    entries.append({
                        'key': entry['key'],
                        'wav_path': wav_path,
                        'video_path': video_path,
                        'label': int(entry['label']),
                        'Frenchay': entry['Frenchay'],
                    })
            except (KeyError, TypeError, ValueError) as e:
                # Missing field, non-object JSON value, or non-integer label.
                print(f"Skipping malformed entry ({e!r}): {line}")
    return entries


def precompute_spectrograms(file_path='data_copy.list', output_dir='spectrograms', sr=16000, fixed_samples=16000, device='cuda' if torch.cuda.is_available() else 'cpu', audio_dir='audio', video_dir='video'):
    """Compute and save a log-mel spectrogram (``spec_<key>.npy``) for every usable manifest entry.

    Args:
        file_path: JSON-lines manifest; each line needs ``key``, ``wav_path``, ``label``
            and ``Frenchay`` fields.
        output_dir: Directory for the ``.npy`` files; created if missing.
        sr: Target sample rate passed to ``load_aligned_pair``/``audio_to_spectrogram``.
        fixed_samples: Clip length in samples passed to ``load_aligned_pair``.
        device: Device used for loading and spectrogram computation.
        audio_dir: Directory searched for each entry's wav file (by basename).
        video_dir: Directory that must contain the matching ``.avi`` for an entry to be used.

    Per-clip failures are printed and skipped. The final summary reports how many
    spectrograms were actually written out of the usable entries.
    """
    os.makedirs(output_dir, exist_ok=True)

    entries = _read_manifest(file_path, audio_dir=audio_dir, video_dir=video_dir)

    n_saved = 0
    for idx, entry in enumerate(entries):
        try:
            audio = load_aligned_pair(entry['wav_path'], sr=sr, fixed_samples=fixed_samples, device=device)
            spec = audio_to_spectrogram(audio, sr=sr, n_fft=512, hop_length=256, n_mels=128, fmax=8000, device=device)
            np.save(os.path.join(output_dir, f'spec_{entry["key"]}.npy'), spec.cpu().numpy())
            n_saved += 1
            if idx % 100 == 0:
                print(f"Precomputed {idx} spectrograms")
        except Exception as e:
            print(f"Error precomputing spectrogram for {entry['wav_path']}: {e}")
            continue
    print(f"Finished precomputing spectrograms for {n_saved}/{len(entries)} clips")

if __name__ == "__main__":
    # Full precompute pass using the default manifest and output locations.
    precompute_spectrograms(file_path='data_copy.list', output_dir='spectrograms')

Skipping invalid JSON line: {"key": "S_M_00051_G2_task1_4_S00004", "wav_path": "data/S_M_00051_G2_task1_4_S00004.wav", "label": 1, "Frenchay": 114}4{"key": "S_M_00020_G4_task1_5_S00007", "wav_path": "data/S_M_00020_G4_task1_5_S00007.wav", "label": 1, "Frenchay": 92}





Precomputed 0 spectrograms
Precomputed 100 spectrograms
Precomputed 200 spectrograms
Precomputed 300 spectrograms
Precomputed 400 spectrograms
Precomputed 500 spectrograms
Precomputed 600 spectrograms
Precomputed 700 spectrograms
Precomputed 800 spectrograms
Precomputed 900 spectrograms
Precomputed 1000 spectrograms
Precomputed 1100 spectrograms
Precomputed 1200 spectrograms
Precomputed 1300 spectrograms
Precomputed 1400 spectrograms
Precomputed 1500 spectrograms
Precomputed 1600 spectrograms
Precomputed 1700 spectrograms
Precomputed 1800 spectrograms
Precomputed 1900 spectrograms
Precomputed 2000 spectrograms
Precomputed 2100 spectrograms
Precomputed 2200 spectrograms
Precomputed 2300 spectrograms
Precomputed 2400 spectrograms
Precomputed 2500 spectrograms
Precomputed 2600 spectrograms
Precomputed 2700 spectrograms
Precomputed 2800 spectrograms
Precomputed 2900 spectrograms
Precomputed 3000 spectrograms
Precomputed 3100 spectrograms
Precomputed 3200 spectrograms
Precomputed 3300 spect