**How to run?**


1.   Run the first 2 cells as they are
2.   Change the "file_path" in cell 3 to the desired file path



In [3]:
import torch
import torchvision
from torchvision import transforms
from torchvision.models import vision_transformer
from torchvision.models.video import r3d_18
import imageio
import cv2
import numpy as np
from PIL import Image

def load_pretrained_3d_cnn():
    # Load pre-trained 3D CNN model (e.g., ResNet3D)
    model = r3d_18(pretrained=True)
    model.eval()
    return model

# Load the Vision Transformer model
def load_vit_model():
    model = vision_transformer.vit_b_16(pretrained=True)
    model.eval()
    return model

In [8]:

# Function to extract features using the Vision Transformer
def extract_features_with_transformer(image_path, model):
    image = Image.open(image_path).convert('RGB')  # Ensure the image is treated as RGB
    image = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])(image).unsqueeze(0)

    with torch.no_grad():
        # Use the standard forward method
        features = model(image)

    return features.squeeze().numpy()



def extract_features_with_3d_cnn(video_path, model, num_frames_to_select=120):
    try:
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Ensure that num_frames_to_select does not exceed the total number of frames
        num_frames_to_select = min(num_frames_to_select, total_frames)

        # Select frames
        selected_frames = np.linspace(0, total_frames - 1, num_frames_to_select, dtype=int)

        features_list = []
        for frame_idx in selected_frames:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()

            if not ret:
                break

            # Resize each frame to match the input size of the 3D CNN model
            frame_resized = cv2.resize(frame, (112, 112))

            # Convert frame to RGB format (assuming it's in BGR)
            frame_resized_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)

            # Convert frame to tensor
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
            frame_tensor = transform(frame_resized_rgb).unsqueeze(0)

            # Extend frame tensor along the temporal dimension
            frame_tensor_3d = frame_tensor.repeat(1, 3, 1, 1, 1)  # Repeat channels to match 3D CNN input

            # Extract features from the 3D CNN model for each frame
            with torch.no_grad():
                features_frame = model(frame_tensor_3d)

            features_list.append(features_frame)

        cap.release()

        if not features_list:
            print("No frames found in the video.")
            return None

        # Aggregate features from all frames
        features_tensor = torch.stack(features_list)
        aggregated_features = features_tensor.mean(dim=0)
        return aggregated_features.squeeze().numpy()

    except Exception as e:
        print(f"Error reading video frames: {e}")
        return None


# Load the 3D CNN model outside the loop
model_3d_cnn = load_pretrained_3d_cnn()

model_vit = load_vit_model()


def normalize_array(input_array):
    mean = np.mean(input_array)
    std = np.std(input_array)
    normalized_array = (input_array - mean) / std
    return normalized_array


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:05<00:00, 59.2MB/s]


In [10]:
import os

def extract_features_from_file(file_path):
    _, file_extension = os.path.splitext(file_path)

    if file_extension.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.gif']:
        # Image file
        return extract_features_with_transformer(file_path, model_vit)
    elif file_extension.lower() in ['.mp4', '.avi', '.mov', '.mkv', 'webm']:
        # Video file
        return extract_features_with_3d_cnn(file_path, model_3d_cnn, num_frames_to_select=120)
    elif file_extension.lower() in ['.gif']:
        # GIF file
        return extract_features_with_transformer(file_path, model_vit)
    else:
        print(f"Unsupported file type: {file_extension}")
        return None

# Example usage:
file_path = '/content/cat-video.mp4'

features = extract_features_from_file(file_path)

# Haha, totally unnecessary
features = normalize_array(features)

if features is not None:
    print("Features:")
    print(features, features.shape)
else:
    print("Error extracting features.")


Features:
[ 8.01601231e-01 -1.60894185e-01  8.31473351e-01 -4.52724360e-02
 -2.88302839e-01  8.96757543e-01  3.48300189e-01 -1.02607048e+00
 -4.92886364e-01 -2.58406550e-01  2.57356435e-01  6.91429913e-01
 -1.23900902e+00  1.21721423e+00 -8.08469236e-01 -1.45828581e+00
  1.52600813e+00  6.08039677e-01 -4.88237977e-01  5.91618657e-01
  1.36491835e+00  5.25291637e-02 -3.73319805e-01 -2.66194582e-01
  1.15501627e-01 -1.49223804e+00  1.75424084e-01 -1.24165580e-01
 -1.24127519e+00 -1.69156015e-01  4.61939633e-01  1.74247757e-01
  1.79955649e+00 -7.40244627e-01  1.67008817e+00  3.33798349e-01
  1.77910888e+00  1.00968170e+00 -3.56542438e-01 -2.98465520e-01
 -1.26642513e+00 -3.93640429e-01 -3.79487008e-01 -1.18483305e-01
 -2.97408193e-01  9.42761362e-01  9.50235605e-01 -3.08271796e-01
  1.56880841e-01 -6.14339173e-01 -9.58047032e-01 -1.47849226e+00
  3.69023204e-01  5.66570222e-01 -1.23563230e+00  8.68745685e-01
  2.31979012e+00  1.07317019e+00 -3.37347865e-01 -3.55733365e-01
 -8.69588256e-0