In [1]:
import numpy as np
import librosa
import os
import json
import av
import h5py
import matplotlib.pyplot as plt
import torch.optim as optim
from sklearn.model_selection import train_test_split
from collections import defaultdict

import torch
import torchvision
import torchaudio
import torch.nn as nn
import torch.nn.functional as F

from torchvision.io import read_video
import torchvision.models.video as video_models
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.models.video import r3d_18, R3D_18_Weights
from torchvision.transforms import Normalize

In [2]:
# Set seed for reproducibility
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN behaviour.

    Args:
        seed: value handed to every RNG seeding call (default 42).
    """
    import random
    from torch.backends import cudnn

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn.deterministic = True
    cudnn.benchmark = False

# Normalization constants
# Per-channel RGB mean/std for the audio branch: the standard ImageNet statistics,
# matching the ImageNet-pretrained ResNet-18 weights that AudioResNet loads. They are
# applied to spectrograms after those are replicated to 3 channels in collate_fn.
AUDIO_MEAN = [0.485, 0.456, 0.406]
AUDIO_STD = [0.229, 0.224, 0.225]
# Per-channel RGB mean/std for the video branch: the Kinetics-400 statistics that
# torchvision documents for its pretrained video models (VideoModel uses R3D-18 KINETICS400_V1).
VIDEO_MEAN = [0.43216, 0.394666, 0.37645]
VIDEO_STD = [0.22803, 0.22145, 0.216989]

In [3]:
# Dataset
class AudioVideoDataset(Dataset):
    """Map-style dataset yielding ``(spectrogram, video, label, key)`` tuples.

    Only the entries of ``data_list`` selected by ``indices`` are kept, in the
    order given by ``indices``. Each entry must unpack as
    ``(wav_path, video_path, label, key)``.

    Args:
        data_list: sequence of manifest entries.
        indices: positions in ``data_list`` to include.
        spec_dir: directory holding precomputed ``spec_{key}.npy`` files.
        fixed_samples: audio window length forwarded to ``load_aligned_pair``.
        fps: frame rate forwarded to ``load_aligned_pair``.
    """

    def __init__(self, data_list, indices, spec_dir='spectrograms', fixed_samples=16000, fps=25):
        self.data = list(map(data_list.__getitem__, indices))
        self.spec_dir = spec_dir
        self.fixed_samples = fixed_samples
        self.fps = fps

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load the spectrogram (as float32) and the aligned video clip for entry ``idx``."""
        wav_path, video_path, label, key = self.data[idx]
        spec_file = os.path.join(self.spec_dir, f'spec_{key}.npy')
        spectrogram = torch.from_numpy(np.load(spec_file)).float()
        # load_aligned_pair returns (None, video); only the video half is needed.
        clip = load_aligned_pair(wav_path, video_path, self.fixed_samples, self.fps)[1]
        return spectrogram, clip, label, key

# Video loading
def load_aligned_pair(wav_path, video_path, fixed_samples=16000, fps=25, sample_rate=16000):
    """Load the video frames that line up with a fixed-length audio window.

    The audio window spans ``fixed_samples`` samples at ``sample_rate`` Hz, which
    corresponds to ``floor(fixed_samples * fps / sample_rate)`` video frames. The
    decoded clip is truncated to that many frames or zero-padded at the end, and
    every frame is resized to 112x112.

    Args:
        wav_path: audio file path; unused here (no audio is loaded).
        video_path: path of the video file to decode.
        fixed_samples: audio window length in samples.
        fps: frame rate used to turn the window duration into a frame count.
        sample_rate: audio sample rate in Hz (default 16000, the value that was
            previously hard-coded).

    Returns:
        ``(None, video)``, where ``video`` is a float tensor of shape
        ``(C, T, 112, 112)`` with values scaled to [0, 1].
    """
    # Floor the exact product instead of truncating a float product, which can
    # land just below an integer.
    video_frames = int(fixed_samples * fps // sample_rate)
    video, _, _ = read_video(video_path, pts_unit='sec')  # (T, H, W, C) uint8
    video = video.float() / 255.0
    n_frames = video.shape[0]
    if n_frames > video_frames:
        video = video[:video_frames]
    elif n_frames < video_frames:
        # Append zero frames along time (dim 0). F.pad consumes pad pairs starting
        # from the LAST dim, so a 3-pair tuple on (T, H, W, C) would pad H instead
        # of T, leaving short clips with too few frames (and torch.stack failing).
        padding = video.new_zeros((video_frames - n_frames, *video.shape[1:]))
        video = torch.cat([video, padding], dim=0)
    # interpolate() resizes the trailing two dims, so move channels ahead of H, W.
    video = video.permute(0, 3, 1, 2)  # (T, C, H, W)
    video = F.interpolate(video, size=(112, 112), mode='bilinear', align_corners=False)
    video = video.permute(1, 0, 2, 3)  # (C, T, 112, 112)
    return None, video

# Collate
def collate_fn(batch):
    """Stack ``(spec, video, label, key)`` samples into normalized tensors on one device.

    Each spectrogram (2-D per sample, given the ``unsqueeze(1)`` + 2-D resize below)
    is resized to 224x224, replicated to 3 channels and normalized with
    AUDIO_MEAN/AUDIO_STD. Videos are stacked into a 5-D batch whose dim 1 is the
    channel axis and normalized with VIDEO_MEAN/VIDEO_STD.

    Returns:
        ``(specs, videos, labels, keys)``. The first three are on ``cuda:0`` when
        CUDA is available, otherwise on CPU. ``keys`` is a tuple.

    NOTE(review): creating CUDA tensors inside collate_fn is only safe with
    DataLoader ``num_workers=0``.
    """
    spec_list, video_list, label_list, key_tuple = zip(*batch)
    target = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    spec_batch = torch.stack(spec_list).unsqueeze(1).to(target)
    spec_batch = F.interpolate(spec_batch, size=(224, 224), mode='bilinear', align_corners=False)
    spec_batch = spec_batch.repeat(1, 3, 1, 1)
    audio_mean = torch.tensor(AUDIO_MEAN, device=target).view(1, 3, 1, 1)
    audio_std = torch.tensor(AUDIO_STD, device=target).view(1, 3, 1, 1)
    spec_batch = (spec_batch - audio_mean) / audio_std

    video_batch = torch.stack(video_list).to(target)
    video_mean = torch.tensor(VIDEO_MEAN, device=target).view(1, 3, 1, 1, 1)
    video_std = torch.tensor(VIDEO_STD, device=target).view(1, 3, 1, 1, 1)
    video_batch = (video_batch - video_mean) / video_std

    label_batch = torch.tensor(label_list, dtype=torch.long).to(target)
    return spec_batch, video_batch, label_batch, key_tuple

# Audio model
class AudioResNet(nn.Module):
    """ImageNet-pretrained ResNet-18 whose ``fc`` is replaced by Dropout(0.5) + Linear.

    Args:
        num_classes: number of output logits of the new head (default 4).
    """

    def __init__(self, num_classes=4):
        super().__init__()
        self.resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        head_in = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(head_in, num_classes))

    def forward(self, x):
        """Return class logits from the full network, including the new head."""
        logits = self.resnet(x)
        return logits

    def get_embedding(self, x):
        """Return pooled backbone features, flattened to ``(batch, features)``.

        Runs every ResNet stage up to and including ``avgpool`` and skips ``fc``,
        so the dropout + linear head is not applied.
        """
        backbone = self.resnet
        for stage in ('conv1', 'bn1', 'relu', 'maxpool',
                      'layer1', 'layer2', 'layer3', 'layer4', 'avgpool'):
            x = getattr(backbone, stage)(x)
        return torch.flatten(x, 1)

# Video model
class VideoModel(nn.Module):
    """Kinetics-400-pretrained R3D-18 whose ``fc`` is replaced by Dropout(0.5) + Linear.

    Args:
        num_classes: number of output logits of the new head (default 4).
    """

    def __init__(self, num_classes=4):
        super().__init__()
        self.resnet3d = r3d_18(weights=R3D_18_Weights.KINETICS400_V1)
        head_in = self.resnet3d.fc.in_features
        self.resnet3d.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(head_in, num_classes))

    def forward(self, x):
        """Return class logits from the full 3-D network, including the new head."""
        logits = self.resnet3d(x)
        return logits

    def get_embedding(self, x):
        """Return pooled backbone features, flattened to ``(batch, features)``.

        Runs ``stem`` through ``avgpool`` and skips ``fc``, so the dropout + linear
        head is not applied.
        """
        backbone = self.resnet3d
        for stage in ('stem', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool'):
            x = getattr(backbone, stage)(x)
        return torch.flatten(x, 1)

# Save batch
def save_batch(audio_embeds, video_embeds, metadata, batch_num, output_dir):
    """Write one chunk of embeddings and metadata under ``output_dir/batch_{batch_num}/``.

    For each item, three files named after ``meta['key']`` are written:
    ``audio_embed_{key}.npy``, ``video_embed_{key}.npy`` and ``meta_{key}.json``.
    The three input sequences are walked in lockstep; ``zip`` stops at the
    shortest one. Existing files with the same names are overwritten.
    """
    target_dir = os.path.join(output_dir, f'batch_{batch_num}')
    os.makedirs(target_dir, exist_ok=True)
    for audio_vec, video_vec, meta in zip(audio_embeds, video_embeds, metadata):
        key = meta['key']
        audio_file = os.path.join(target_dir, f'audio_embed_{key}.npy')
        video_file = os.path.join(target_dir, f'video_embed_{key}.npy')
        meta_file = os.path.join(target_dir, f'meta_{key}.json')
        np.save(audio_file, audio_vec)
        np.save(video_file, video_vec)
        with open(meta_file, 'w') as handle:
            json.dump(meta, handle)

# Manifest parsing
def _load_manifest(file_path, audio_dir='audio', video_dir='video', spec_dir='spectrograms'):
    """Parse a JSON-lines manifest into ``(wav_path, video_path, label, key)`` tuples.

    Each non-blank line must be a JSON object with ``key``, ``wav_path`` and
    ``label``. Media paths are re-rooted: the basename of ``wav_path`` is joined
    onto ``audio_dir``, and the same stem with an ``.avi`` extension is joined
    onto ``video_dir``. Lines that are not valid JSON, or that lack or have
    unusable fields, are reported and skipped. Entries whose wav, video or
    ``spec_{key}.npy`` file does not exist are dropped and counted.

    Returns:
        ``(entries, n_missing)``: the usable tuples in file order, and the number
        of well-formed entries dropped because a file was missing.
    """
    entries = []
    n_missing = 0
    with open(file_path, 'r') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError as e:
                # "Extra data" means content follows the first JSON value on the
                # line (e.g. two records without a newline); the whole line is lost.
                print(f"[Line {line_num}] Skipping malformed JSON: {e}")
                continue
            try:
                key = entry['key']
                wav_name = os.path.basename(entry['wav_path'])
                label = int(entry['label'])
            except (KeyError, TypeError, ValueError) as e:
                print(f"[Line {line_num}] Skipping entry with missing/invalid field: {e!r}")
                continue
            wav_path = os.path.join(audio_dir, wav_name)
            video_path = os.path.join(video_dir, os.path.splitext(wav_name)[0] + '.avi')
            spec_path = os.path.join(spec_dir, f'spec_{key}.npy')
            if all(os.path.exists(p) for p in (wav_path, video_path, spec_path)):
                entries.append((wav_path, video_path, label, key))
            else:
                n_missing += 1
    return entries, n_missing


# Main extraction function
def extract_finetuned_embeddings(audio_model, video_model, file_path='data_copy.list', output_dir='features_finetuned', batch_size=16, device='cuda:0'):
    """Embed every usable manifest entry with both fine-tuned encoders and save the results.

    Both models are switched to eval mode, and embeddings come from their
    ``get_embedding`` methods under ``torch.no_grad()``. Results are written by
    ``save_batch`` in chunks of ``batch_size`` items to ``output_dir/batch_{n}/``.

    Args:
        audio_model: model exposing ``get_embedding`` for spectrogram batches.
        video_model: model exposing ``get_embedding`` for video batches.
        file_path: JSON-lines manifest (see ``_load_manifest``).
        output_dir: root directory for the saved chunks (created if missing).
        batch_size: DataLoader batch size, and also the number of items per saved chunk.
        device: kept for backward compatibility. Inputs are moved to the device
            of each model's parameters, which always matches the model.
    """
    audio_model.eval()
    video_model.eval()
    audio_device = next(audio_model.parameters()).device
    video_device = next(video_model.parameters()).device

    data_list, n_missing = _load_manifest(file_path)
    if n_missing:
        print(f"Skipped {n_missing} entries with missing audio/video/spectrogram files.")

    dataset = AudioVideoDataset(data_list, list(range(len(data_list))))
    # shuffle=False: loader order == data_list order, which the metadata lookup relies on.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False, collate_fn=collate_fn)

    os.makedirs(output_dir, exist_ok=True)
    batch_num = 0
    batch_audio_embeds, batch_video_embeds, batch_metadata = [], [], []
    processed = 0

    with torch.no_grad():
        for specs, videos, labels, keys in loader:
            # One host transfer per batch instead of one per item.
            audio_emb = audio_model.get_embedding(specs.to(audio_device)).cpu().numpy()
            video_emb = video_model.get_embedding(videos.to(video_device)).cpu().numpy()
            for a_vec, v_vec, label, key in zip(audio_emb, video_emb, labels.tolist(), keys):
                # `processed` is the global position of this item in data_list.
                wav_path, video_path, _, entry_key = data_list[processed]
                assert entry_key == key, f"loader/data_list order mismatch at item {processed}"
                batch_audio_embeds.append(a_vec)
                batch_video_embeds.append(v_vec)
                batch_metadata.append({
                    'key': key,
                    'label': label,
                    'wav_path': wav_path,
                    'video_path': video_path
                })
                processed += 1
                if processed % 200 == 0:
                    print(f"Processed {processed} items...")
                if len(batch_audio_embeds) >= batch_size:
                    save_batch(batch_audio_embeds, batch_video_embeds, batch_metadata, batch_num, output_dir)
                    batch_audio_embeds, batch_video_embeds, batch_metadata = [], [], []
                    batch_num += 1

    # Flush the final, possibly partial, chunk.
    if batch_audio_embeds:
        save_batch(batch_audio_embeds, batch_video_embeds, batch_metadata, batch_num, output_dir)

    print(f"Finished. Total processed: {processed}")

# Entry point
def main():
    """Build both models, restore their fine-tuned weights and extract embeddings."""
    set_seed()
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    def restore(model, checkpoint):
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        # On torch >= 1.13, weights_only=True restricts loading to tensors/state dicts.
        model.load_state_dict(torch.load(checkpoint, map_location=device))

    audio_model = AudioResNet().to(device)
    video_model = VideoModel().to(device)
    restore(audio_model, 'audio_finetuned.pth')
    restore(video_model, 'video_finetuned.pth')
    extract_finetuned_embeddings(
        audio_model,
        video_model,
        file_path='data_copy.list',
        output_dir='features_finetuned',
        batch_size=16,
        device=device,
    )


if __name__ == '__main__':
    main()


[Line 1919] Skipping malformed JSON: Extra data: line 1 column 120 (char 119)
Processed 200 items...
Processed 400 items...
Processed 600 items...
Processed 800 items...
Processed 1000 items...
Processed 1200 items...
Processed 1400 items...
Processed 1600 items...
Processed 1800 items...
Processed 2000 items...
Processed 2200 items...
Processed 2400 items...
Processed 2600 items...
Processed 2800 items...
Processed 3000 items...
Processed 3200 items...
Processed 3400 items...
Processed 3600 items...
Processed 3800 items...
Processed 4000 items...
Processed 4200 items...
Processed 4400 items...
Processed 4600 items...
Processed 4800 items...
Processed 5000 items...
Processed 5200 items...
Processed 5400 items...
Processed 5600 items...
Processed 5800 items...
Processed 6000 items...
Processed 6200 items...
Processed 6400 items...
Processed 6600 items...
Processed 6800 items...
Processed 7000 items...
Processed 7200 items...
Processed 7400 items...
Processed 7600 items...
Processed 7800