<a href="https://colab.research.google.com/github/MartijnRoozendaal1/vvr-prediction/blob/main/Feature_Extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Setup

In [None]:
# Install required libraries
!pip install torch torchvision

# Import necessary libraries
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm


Load Processed Frames

In [None]:
# Preprocessed frames were saved as a pickled dict with np.save; .item() unwraps the
# 0-d object array np.load returns. The extraction cells below treat it as
# {video_id: iterable of per-frame dicts, each holding an 'aligned_frame' array}.
# NOTE(review): assumes Google Drive is already mounted at /content/drive; no mount step
# is visible here. allow_pickle=True can execute code from the file — load trusted data only.
processed_frames_path = '/content/drive/MyDrive/Filtered_Infrared_Frames_FAN.npy'

frames_archive = np.load(processed_frames_path, allow_pickle=True)
processed_frames = frames_archive.item()



ResNet-152 Feature Extraction

In [None]:
# Load ImageNet-pretrained ResNet-152 and remove its final fully-connected layer, so the
# network ends at its global pooling stage and emits one feature vector per image.
# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the checkpoint
# `pretrained=True` resolved to, so the extracted features are the same.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

resnet152 = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
resnet_feature_dim = resnet152.fc.in_features  # width of the vector the removed fc layer consumed
resnet152 = torch.nn.Sequential(*list(resnet152.children())[:-1])  # Exclude final layer
resnet152 = resnet152.to(device).eval()

# Transformation for ResNet-152 (resize, tensor conversion, ImageNet mean/std normalization)
transform_resnet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Frames per forward pass. Batching (plus the GPU when available) replaces the original
# one-frame-at-a-time CPU loop; lower this if the device runs out of memory.
batch_size = 64

# Extract features: {video_id: float32 array of shape (num_frames, resnet_feature_dim)}
resnet_features = {}
with torch.no_grad():
    for video_id, frames in tqdm(processed_frames.items(), desc="Extracting ResNet Features"):
        frame_list = list(frames)  # makes batch slicing work for whatever container was pickled
        batch_outputs = []
        for start in range(0, len(frame_list), batch_size):
            # NOTE(review): assumes 'aligned_frame' is an array Image.fromarray handles
            # (e.g. uint8 grayscale/RGB); float frames may not convert as intended — confirm dtype.
            batch = torch.stack([
                transform_resnet(Image.fromarray(frame_data['aligned_frame']).convert('RGB'))
                for frame_data in frame_list[start:start + batch_size]
            ]).to(device)
            # Pooled output -> (B, D); flatten keeps the batch axis even when B == 1.
            batch_outputs.append(resnet152(batch).flatten(start_dim=1).cpu().numpy())
        # Videos with no frames get a (0, D) array rather than np.array([]) of shape (0,).
        resnet_features[video_id] = (
            np.concatenate(batch_outputs)
            if batch_outputs
            else np.empty((0, resnet_feature_dim), dtype=np.float32)
        )

# Save features
resnet_features_path = '/content/drive/MyDrive/ResNet152_Features.npy'
np.save(resnet_features_path, resnet_features)
print(f"ResNet-152 features saved to: {resnet_features_path}")


VGG-16 Feature Extraction

In [None]:
# Load ImageNet-pretrained VGG-16, keep only its convolutional `features` stack, and add
# Global Average Pooling (GAP) so each image yields one vector per final conv channel.
# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the checkpoint
# `pretrained=True` resolved to, so the extracted features are the same.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
vgg16.eval()
vgg16_gap = torch.nn.Sequential(
    *list(vgg16.features),
    torch.nn.AdaptiveAvgPool2d((1, 1))  # Reduce features to (512,)
)
# A freshly built Sequential starts in training mode; put the wrapper itself in eval mode too.
vgg16_gap = vgg16_gap.to(device).eval()
# Output channels of the last conv layer = length of each pooled feature vector.
vgg_feature_dim = [m for m in vgg16.features if isinstance(m, torch.nn.Conv2d)][-1].out_channels

# Transformation for VGG-16 (resize, tensor conversion, ImageNet mean/std normalization)
transform_vgg = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Frames per forward pass. Batching (plus the GPU when available) replaces the original
# one-frame-at-a-time CPU loop; lower this if the device runs out of memory.
batch_size = 64

# Extract features: {video_id: float32 array of shape (num_frames, vgg_feature_dim)}
vgg_features = {}
with torch.no_grad():
    for video_id, frames in tqdm(processed_frames.items(), desc="Extracting VGG Features"):
        frame_list = list(frames)  # makes batch slicing work for whatever container was pickled
        batch_outputs = []
        for start in range(0, len(frame_list), batch_size):
            # NOTE(review): assumes 'aligned_frame' is an array Image.fromarray handles
            # (e.g. uint8 grayscale/RGB); float frames may not convert as intended — confirm dtype.
            batch = torch.stack([
                transform_vgg(Image.fromarray(frame_data['aligned_frame']).convert('RGB'))
                for frame_data in frame_list[start:start + batch_size]
            ]).to(device)
            # (B, C, 1, 1) -> (B, C); flatten keeps the batch axis even when B == 1.
            batch_outputs.append(vgg16_gap(batch).flatten(start_dim=1).cpu().numpy())
        # Videos with no frames get a (0, D) array rather than np.array([]) of shape (0,).
        vgg_features[video_id] = (
            np.concatenate(batch_outputs)
            if batch_outputs
            else np.empty((0, vgg_feature_dim), dtype=np.float32)
        )

# Save features
vgg_features_path = '/content/drive/MyDrive/VGG16_Features_512.npy'
np.save(vgg_features_path, vgg_features)
print(f"VGG-16 features saved to: {vgg_features_path}")


Verify Feature Extraction

In [None]:
# Reload the saved feature files so the checks below run against what is actually on disk.
extracted_resnet_features = np.load(resnet_features_path, allow_pickle=True).item()
extracted_vgg_features = np.load(vgg_features_path, allow_pickle=True).item()

# Check for missing features
missing_resnet = set(processed_frames.keys()) - set(extracted_resnet_features.keys())
missing_vgg = set(processed_frames.keys()) - set(extracted_vgg_features.keys())

print(f"Videos missing ResNet-152 features: {len(missing_resnet)}")
print(f"Videos missing VGG-16 features: {len(missing_vgg)}")


def _shape_mismatches(features_by_video):
    """Return ids of videos whose features are not a 2-D array with one row per frame.

    The key check above passes by construction (both feature dicts are built by
    iterating processed_frames), so this additionally confirms that every video's
    feature array has shape (number of input frames, D).
    """
    return [
        video_id
        for video_id, features in features_by_video.items()
        if video_id in processed_frames
        and (np.ndim(features) != 2 or len(features) != len(processed_frames[video_id]))
    ]


resnet_shape_mismatches = _shape_mismatches(extracted_resnet_features)
vgg_shape_mismatches = _shape_mismatches(extracted_vgg_features)
print(f"Videos with ResNet-152 frame-count/shape mismatches: {len(resnet_shape_mismatches)}")
print(f"Videos with VGG-16 frame-count/shape mismatches: {len(vgg_shape_mismatches)}")
