In [None]:
import torch
import torchvision
import skimage.io as io
import clip
from PIL import Image
import pickle
import os
from tqdm import tqdm
import pandas as pd
import numpy as np

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
# The model below is only used for inference (feature extraction).
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# clip.load() only accepts the canonical names from clip.available_models()
# (e.g. "ViT-B/32") or a path to a checkpoint file; "ViT-B_32" is neither and
# makes clip.load raise.  The underscore form is only used in file names.
clip_model_type = "ViT-B/32"
out_path = "./data/Video_CLIP_ViT-B_32.pkl"  # NOTE(review): not referenced in the code shown here
csv_path = "./Charades/Charades_v1_train.csv"  # annotation CSV; its "id" column names the videos
video_path = "./video"  # directory holding the <id>.mp4 files
frames_no = 10  # number of frames sampled from each video
num_video_batch = 5  # number of videos encoded together per forward pass
model, preprocess = clip.load(clip_model_type, device=device, jit=False)

def extract_frames(video_file, n_frames=None):
    """Sample evenly spaced frames from a video file as PIL images.

    Args:
        video_file: path to a video readable by ``torchvision.io.read_video``.
        n_frames: how many frames to sample; defaults to the module-level
            ``frames_no``.

    Returns:
        A list of exactly ``n_frames`` ``PIL.Image`` objects spread from the
        first to the last decoded frame. Indices repeat if the clip has fewer
        than ``n_frames`` frames. Returns an empty list if the video decodes
        to no frames or ``n_frames`` is not positive.
    """
    if n_frames is None:
        n_frames = frames_no

    # read_video returns (video, audio, info).  `info` only carries fps
    # metadata ("video_fps"/"audio_fps"), not a duration, so the frame count
    # must come from the decoded tensor.  With the default output_format
    # ("THWC") `video` is shaped (T, H, W, C).
    video, _, _ = torchvision.io.read_video(video_file, pts_unit="sec")
    total = video.shape[0]
    if total == 0 or n_frames <= 0:
        return []

    # Spread the samples over the real length of the clip instead of assuming
    # a 60 s video (the old fixed stride under-sampled short clips and could
    # be 0, and range() rejected the float frame count).
    indices = np.linspace(0, total - 1, num=n_frames).round().astype(int)

    # ToPILImage expects channel-first tensors, but these frames are
    # channel-last -- the layout PIL.Image.fromarray takes directly.
    return [Image.fromarray(video[int(idx)].numpy()) for idx in indices]

def preprocess_frames(frames):
    """Apply CLIP's image preprocessing to frames and batch them on ``device``.

    This does not run the encoder; it only prepares the input batch.

    Args:
        frames: sequence of PIL images, e.g. the output of ``extract_frames``.

    Returns:
        A tensor with one entry per frame along dim 0, stacked from the
        per-frame ``preprocess`` outputs and moved to ``device``.

    Raises:
        ValueError: if ``frames`` is empty (there is nothing to stack).
    """
    if len(frames) == 0:
        raise ValueError("preprocess_frames() received no frames")

    # CLIP's `preprocess` yields a single-image (C, H, W) tensor (CLIP's own
    # example adds .unsqueeze(0) only to form a batch of one).  Stacking the
    # raw outputs gives the (N, C, H, W) batch encode_image expects; the
    # previous unsqueeze-then-stack produced an invalid (N, 1, C, H, W).
    batch = torch.stack([preprocess(frame) for frame in frames])
    return batch.to(device)
    
def CLIPtokenise(tensors, encode_fn=None):
    """Encode several per-video frame batches in a single forward pass.

    The batches are concatenated along the frame axis (dim 0), encoded
    together, and the resulting features are split back into one tensor per
    video.

    Args:
        tensors: sequence of frame batches, one per video, with frames along
            dim 0 (as produced by ``preprocess_frames``). The frame count may
            differ between videos.
        encode_fn: callable that maps the concatenated batch to per-frame
            features with the frame axis first. Defaults to the loaded CLIP
            model's ``encode_image``.

    Returns:
        A tuple with one tensor per input. The i-th tensor holds the feature
        rows for the frames of ``tensors[i]``. Returns an empty tuple when
        ``tensors`` is empty.
    """
    if len(tensors) == 0:
        return ()
    if encode_fn is None:
        encode_fn = model.encode_image

    # Frames live on dim 0, so that is the axis to concatenate and later split
    # on (splitting the encoder output on dim 2 failed: it has no such axis).
    frame_counts = [t.shape[0] for t in tensors]
    joined = torch.cat(list(tensors), dim=0)

    with torch.no_grad():
        image_features = encode_fn(joined)

    return torch.split(image_features, frame_counts, dim=0)

# Load the Charades annotation CSV; only the "id" column is used here.
df = pd.read_csv(csv_path)
ids = df["id"]

# torch.save does not create missing directories.
os.makedirs("./video_tensors", exist_ok=True)

# Encode the videos num_video_batch at a time and write one feature tensor per
# video to ./video_tensors/<id>.pt.
for start in tqdm(range(0, len(ids), num_video_batch)):
    # Convert to a plain list: indexing a sliced Series with batch_ids[j] is
    # label-based and raised KeyError for every batch after the first.
    batch_ids = ids.iloc[start:start + num_video_batch].tolist()

    batch = [
        preprocess_frames(extract_frames(f"{video_path}/{video_id}.mp4"))
        for video_id in batch_ids
    ]

    tokenised_batch = CLIPtokenise(batch)

    for video_id, features in zip(batch_ids, tokenised_batch):
        # torch.split returns views into the whole batch's storage, and
        # torch.save would serialise all of it; clone keeps only this video's
        # rows, and .cpu() lets the file load on a machine without a GPU.
        torch.save(features.clone().cpu(), f"./video_tensors/{video_id}.pt")

