In [2]:
import numpy as np
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import pickle
import os

# Torch device string that the CLIP model, the processor outputs and the
# stacked embedding tensors are moved to via `.to(device)`.
device = "cpu"

In [3]:
# Folder that holds the four `.npy` inputs used by this notebook. It can be
# redirected without editing the notebook by setting the WEN_2018_DATA_DIR
# environment variable; when that is unset the original location is used, so
# the resulting paths are unchanged.
DATA_DIR = os.environ.get(
    'WEN_2018_DATA_DIR',
    '/Users/aahd/Library/CloudStorage/OneDrive-UniversityofSouthampton/year_4/Deep Learning/cw/DL_MindVideo/wen_2018',
)

# Individual input files, all resolved against the single DATA_DIR above.
video_frames_path = os.path.join(DATA_DIR, 'video_test_256_3hz.npy')
segment_ids_path = os.path.join(DATA_DIR, 'test_seg_id_3hz.npy')
text_ids_path = os.path.join(DATA_DIR, 'text_test_256_3hz.npy')
fmri_ids_path = os.path.join(DATA_DIR, 'fmri_test_subject1.npy')

In [4]:
def load_clip_model(model_name="openai/clip-vit-base-patch32"):
    """Load a pretrained CLIP model together with its matching processor.

    Args:
        model_name: Checkpoint identifier handed to both ``from_pretrained``
            calls. Defaults to the checkpoint this notebook has always used,
            so an argument-less ``load_clip_model()`` behaves as before.

    Returns:
        A ``(model, processor)`` tuple. The model has been moved to the
        module-level ``device``; the processor is returned as loaded.
    """
    # One name feeds both loaders so the model and its processor cannot be
    # loaded from two different checkpoints by accident.
    model = CLIPModel.from_pretrained(model_name).to(device)
    processor = CLIPProcessor.from_pretrained(model_name)
    return model, processor

def get_clip_embeddings(model, processor, image_tensor, text_caption):
    """Embed a single image/caption pair with CLIP.

    Args:
        model: Callable that accepts the processor output as keyword arguments
            and returns an object exposing ``image_embeds`` and ``text_embeds``.
        processor: Callable that builds the model inputs from ``text`` and
            ``images``; its result must support ``.to(device)``.
        image_tensor: The image passed to the processor as ``images`` (the
            caller in this notebook passes a PIL image despite the name).
        text_caption: Caption string; it is wrapped in a one-element list
            before being given to the processor.

    Returns:
        ``(image_embedding, text_embedding)``: two tensors detached from
        autograd, with all size-1 dimensions squeezed away.
    """
    inputs = processor(text=[text_caption], images=image_tensor, return_tensors="pt", padding=True).to(device)

    # Pure inference: disabling gradient tracking avoids building an autograd
    # graph that would be thrown away immediately by the detach below.
    with torch.no_grad():
        outputs = model(**inputs)

    # detach() is kept so the result never carries autograd history, even if
    # the model hands back tensors that already required gradients.
    image_embedding = outputs.image_embeds.squeeze().detach()
    text_embedding = outputs.text_embeds.squeeze().detach()
    return image_embedding, text_embedding

def process_embeddings(video_frames, text_ids, model, processor):
    """Compute one CLIP image embedding and one text embedding per video segment.

    ``video_frames`` is indexed as ``video_frames[i, j, 0]``: the first two
    axes enumerate the segments and the third selects a frame inside a
    segment. Only frame 0 of each segment is embedded.

    NOTE(review): ``text_ids`` is accepted but not used. Every segment is
    paired with the literal caption "Sample text", so all rows of the returned
    text tensor come from that same caption. Wiring in the real captions
    needs the layout of ``text_ids``, which is not visible here.

    Args:
        video_frames: Array with at least three axes, indexed as described
            above; each selected frame must be accepted by ``Image.fromarray``.
        text_ids: Currently ignored (see note above).
        model: CLIP model forwarded to ``get_clip_embeddings``.
        processor: CLIP processor forwarded to ``get_clip_embeddings``.

    Returns:
        ``(text_embeddings, image_embeddings)`` -- text first, image second --
        each the per-segment embeddings stacked along a new leading axis in
        row-major ``(i, j)`` order.
    """
    image_embeddings_list = []
    text_embeddings_list = []
    total_segments = video_frames.shape[0] * video_frames.shape[1]

    # The context manager closes the progress bar even if embedding a segment
    # raises, so an aborted run does not leave a dangling bar behind.
    with tqdm(total=total_segments, desc="Processing Video and Text Pairs") as progress_bar:
        # np.ndindex walks (i, j) in the same order as two nested range loops.
        for i, j in np.ndindex(*video_frames.shape[:2]):
            frame = video_frames[i, j, 0]  # Taking the first frame of each segment
            text = "Sample text"  # Placeholder for actual text extraction logic
            pil_image = Image.fromarray(frame)
            image_embedding, text_embedding = get_clip_embeddings(model, processor, pil_image, text)

            image_embeddings_list.append(image_embedding)
            text_embeddings_list.append(text_embedding)

            progress_bar.update(1)

    image_embeddings_tensor = torch.stack(image_embeddings_list).to(device)
    text_embeddings_tensor = torch.stack(text_embeddings_list).to(device)

    return text_embeddings_tensor, image_embeddings_tensor

def plot_similarity(similarity_matrix, title):
    """Render ``similarity_matrix`` as a viridis heatmap on a new 10x8 figure.

    The tensor is detached and copied to host memory before being handed to
    seaborn; ``title`` becomes the figure title. The figure is shown
    immediately and nothing is returned.
    """
    heat_values = similarity_matrix.detach().cpu().numpy()

    plt.figure(figsize=(10, 8))
    sns.heatmap(heat_values, cmap='viridis')
    plt.title(title)
    plt.show()


# Read the three .npy inputs from disk, in this order: video frames, text
# data, segment ids.
video_frames, text_ids, segment_ids = (
    np.load(npy_path)
    for npy_path in (video_frames_path, text_ids_path, segment_ids_path)
)

# Instantiate the CLIP model and its processor once for the whole notebook.
model, processor = load_clip_model()




In [5]:
# Analysis
# Embed every video segment. process_embeddings returns the text tensor first
# and the image (video-frame) tensor second, matching the unpacking below.
text_embed, video_embed = process_embeddings(video_frames, text_ids, model, processor)


Processing Video and Text Pairs: 100%|██████████| 1200/1200 [02:30<00:00,  7.96it/s]


In [6]:
# Persist both embedding tensors as pickle files in the working directory:
# the text embeddings first, then the video embeddings.
for pickle_name, embedding in (
    ('text_embedding.pkl', text_embed),
    ('video_embedding.pkl', video_embed),
):
    with open(pickle_name, 'wb') as f:
        pickle.dump(embedding, f)
    
