In [10]:
import torch
import torchvision
import clip
from PIL import Image
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import pickle

# Pick the compute device for CLIP inference: CUDA if present, otherwise Apple MPS,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")



Using cuda device


In [11]:
# --- Configuration ---------------------------------------------------------
# CLIP image-encoder variant passed to clip.load.
clip_model_type = "ViT-B/32"

# Input / output locations.
csv_path = "./Charades/Charades_v1_train.csv"   # annotation CSV (read for its "id" and "descriptions" columns)
video_path = "./video"                           # directory containing <id>.mp4 files
pt_path = "./video_tensors"                      # one <id>.pt embedding tensor per video is written here

# Sampling and batching.
frames_no = 40          # frames sampled per 60 s of video (see extract_frames), not a fixed count per clip
num_video_batch = 100   # videos whose frames are concatenated into one CLIP forward pass

model, preprocess = clip.load(clip_model_type, device=device, jit=False)

In [3]:
# Decode a video and sample frames from it at a fixed temporal density.
def extract_frames(video_file, frames_per_minute=None):
    """Decode ``video_file`` and return frames sampled at a fixed rate as PIL images.

    A frame is kept every ``interval`` decoded frames, where
    ``interval = int(fps * 60 / frames_per_minute)``. ``frames_per_minute``
    is therefore the number of frames sampled per 60 seconds of video, not a
    fixed number of frames per clip: shorter videos yield fewer frames.

    Args:
        video_file: path of the video to decode.
        frames_per_minute: sampling density; defaults to the global ``frames_no``.

    Returns:
        list of PIL images (empty if the decoded video has no frames).

    Raises:
        Whatever ``torchvision.io.read_video`` raises for unreadable files, and
        KeyError if the container reports no ``video_fps``.
    """
    if frames_per_minute is None:
        frames_per_minute = frames_no

    # Only the visual stream is used; the audio track is ignored.
    video, _audio, info = torchvision.io.read_video(video_file, output_format='TCHW')

    frame_rate = info["video_fps"]
    # Clamp to 1: for a low frame rate (or a high sampling density) the int()
    # truncates to 0, and range() raises ValueError on a zero step.
    interval = max(1, int(frame_rate * 60 / frames_per_minute))

    to_pil = torchvision.transforms.ToPILImage()
    # video[i] is one (C, H, W) frame; squeeze(0) only has an effect for
    # single-channel frames, which ToPILImage then treats as greyscale.
    return [to_pil(video[i].squeeze(0)) for i in range(0, len(video), interval)]

In [4]:
# Apply CLIP's image preprocessing to each frame and stack the results along a new
# leading dimension (the CLIP encoder itself is NOT run here).
def preprocess_frames(frames, transform=None, target_device=None):
    """Preprocess a list of frames into a single batched tensor.

    Args:
        frames: sequence of frames accepted by ``transform`` (PIL images when
            using CLIP's ``preprocess``).
        transform: per-frame transform returning a tensor; defaults to the
            global ``preprocess`` returned by ``clip.load``.
        target_device: device for the returned batch; defaults to the global
            ``device``.

    Returns:
        Tensor of shape ``(len(frames), *transform(frame).shape)`` on
        ``target_device``.

    Raises:
        ValueError: if ``frames`` is empty (nothing to batch).
    """
    if len(frames) == 0:
        raise ValueError("preprocess_frames received no frames")
    if transform is None:
        transform = preprocess
    if target_device is None:
        target_device = device

    # Build the batch on the host and transfer it once, rather than copying every
    # frame to the device individually.
    return torch.stack([transform(frame) for frame in frames]).to(target_device)


In [5]:
# Encode several per-video frame batches with a single CLIP forward pass, then
# split the embeddings back into one tensor per video.
def CLIPtokenise(tensors, encode_fn=None, verbose=True):
    """Run the CLIP image encoder over many videos at once.

    Despite the name, no text tokenisation happens: the per-video batches are
    concatenated along dim 0, encoded once under ``torch.no_grad()``, and the
    resulting embeddings are split back using each input's row count.

    Args:
        tensors: sequence of tensors, one per video, each with its frames along
            dim 0 (as produced by ``preprocess_frames``).
        encode_fn: image encoder; defaults to the global ``model.encode_image``.
        verbose: if True, print the shape of the concatenated feature tensor.

    Returns:
        tuple of tensors in the same order as ``tensors``; the i-th has as many
        rows as ``tensors[i]``. An empty tuple if ``tensors`` is empty.
    """
    if len(tensors) == 0:
        # torch.cat([]) raises; a batch in which every video failed has nothing to encode.
        return ()
    if encode_fn is None:
        encode_fn = model.encode_image

    # Number of frames contributed by each video, used to undo the concatenation.
    frame_counts = [t.shape[0] for t in tensors]
    joined = torch.cat(tensors, dim=0)

    with torch.no_grad():
        # Features stay on whatever device the encoder produced them on.
        image_features = encode_fn(joined)

    if verbose:
        print(image_features.shape)

    return torch.split(image_features, frame_counts, dim=0)




In [6]:
# Load the video ids, then extract, preprocess and CLIP-encode frames in batches of
# `num_video_batch` videos, saving one <id>.pt embedding tensor per video.
df = pd.read_csv(csv_path)
ids = df["id"]
errors = []

# torch.save does not create missing directories.
os.makedirs(pt_path, exist_ok=True)

for i in tqdm(range(0, len(ids), num_video_batch)):
    batch_ids = ids.iloc[i:i + num_video_batch].tolist()

    batch = []        # preprocessed frame stacks of the videos that loaded
    loaded_ids = []   # ids aligned one-to-one with `batch`

    for video_id in batch_ids:
        try:
            video_file = os.path.join(video_path, video_id + ".mp4")
            frames = preprocess_frames(extract_frames(video_file))
        except Exception as e:
            # Best effort: record the failure and continue with the rest of the batch.
            errors.append(f"video {video_id} error {e}")
        else:
            batch.append(frames)
            loaded_ids.append(video_id)

    if not batch:
        # Every video in this batch failed; there is nothing to encode or save.
        continue

    tokenised_batch = CLIPtokenise(batch)

    # CLIPtokenise returns embeddings in input order, so they pair up with loaded_ids.
    for video_id, features in zip(loaded_ids, tokenised_batch):
        torch.save(features, os.path.join(pt_path, f"{video_id}.pt"))

        




torch.Size([2050, 512])


  1%|▏         | 1/80 [02:02<2:40:48, 122.14s/it]

torch.Size([2053, 512])


  2%|▎         | 2/80 [04:02<2:37:35, 121.23s/it]

torch.Size([2071, 512])


  4%|▍         | 3/80 [06:06<2:36:49, 122.21s/it]

torch.Size([1978, 512])


  5%|▌         | 4/80 [08:01<2:31:36, 119.69s/it]

torch.Size([2072, 512])


  6%|▋         | 5/80 [10:06<2:31:53, 121.51s/it]

torch.Size([2039, 512])


  8%|▊         | 6/80 [19:44<5:41:25, 276.82s/it]

torch.Size([2173, 512])


  9%|▉         | 7/80 [21:46<4:35:10, 226.17s/it]

torch.Size([2126, 512])


 10%|█         | 8/80 [23:45<3:50:10, 191.81s/it]

torch.Size([1982, 512])


 11%|█▏        | 9/80 [25:32<3:15:33, 165.27s/it]

torch.Size([1961, 512])


 12%|█▎        | 10/80 [27:21<2:52:48, 148.11s/it]

torch.Size([2084, 512])


 14%|█▍        | 11/80 [29:21<2:40:22, 139.45s/it]

torch.Size([2048, 512])


 15%|█▌        | 12/80 [31:22<2:31:40, 133.83s/it]

torch.Size([2099, 512])


 16%|█▋        | 13/80 [33:30<2:27:30, 132.10s/it]

torch.Size([2142, 512])


 18%|█▊        | 14/80 [35:49<2:27:39, 134.23s/it]

torch.Size([2115, 512])


 19%|█▉        | 15/80 [37:55<2:22:43, 131.75s/it]

torch.Size([2039, 512])


 20%|██        | 16/80 [40:05<2:19:47, 131.06s/it]

torch.Size([2000, 512])


 21%|██▏       | 17/80 [42:55<2:29:51, 142.72s/it]

torch.Size([2091, 512])


 22%|██▎       | 18/80 [45:35<2:32:51, 147.92s/it]

torch.Size([2197, 512])


 24%|██▍       | 19/80 [47:51<2:26:48, 144.40s/it]

torch.Size([2066, 512])


 25%|██▌       | 20/80 [1:24:09<12:35:04, 755.08s/it]

torch.Size([2008, 512])


 26%|██▋       | 21/80 [1:26:16<9:17:08, 566.58s/it] 

torch.Size([2092, 512])


 28%|██▊       | 22/80 [1:28:30<7:01:58, 436.53s/it]

torch.Size([2031, 512])


 29%|██▉       | 23/80 [1:30:36<5:26:27, 343.64s/it]

torch.Size([2031, 512])


 30%|███       | 24/80 [1:32:49<4:21:37, 280.31s/it]

torch.Size([2037, 512])


 31%|███▏      | 25/80 [1:35:07<3:37:54, 237.71s/it]

torch.Size([2141, 512])


 32%|███▎      | 26/80 [1:37:18<3:05:03, 205.62s/it]

torch.Size([2207, 512])


 34%|███▍      | 27/80 [1:39:40<2:44:36, 186.34s/it]

torch.Size([2078, 512])


 35%|███▌      | 28/80 [1:41:56<2:28:24, 171.24s/it]

torch.Size([2188, 512])


 36%|███▋      | 29/80 [1:44:15<2:17:20, 161.57s/it]

torch.Size([2089, 512])


 38%|███▊      | 30/80 [1:46:29<2:07:45, 153.31s/it]

torch.Size([2181, 512])


 39%|███▉      | 31/80 [1:48:57<2:04:00, 151.85s/it]

torch.Size([2178, 512])


 40%|████      | 32/80 [1:51:24<2:00:21, 150.44s/it]

torch.Size([2027, 512])


 41%|████▏     | 33/80 [1:53:42<1:54:52, 146.66s/it]

torch.Size([2020, 512])


 42%|████▎     | 34/80 [1:55:56<1:49:31, 142.86s/it]

torch.Size([2039, 512])


 44%|████▍     | 35/80 [1:58:18<1:46:59, 142.65s/it]

torch.Size([2033, 512])


 45%|████▌     | 36/80 [2:00:29<1:42:01, 139.13s/it]

torch.Size([2133, 512])


 46%|████▋     | 37/80 [2:02:56<1:41:29, 141.61s/it]

torch.Size([2072, 512])


 48%|████▊     | 38/80 [2:05:45<1:44:47, 149.71s/it]

torch.Size([2036, 512])


 49%|████▉     | 39/80 [2:09:09<1:53:19, 165.84s/it]

torch.Size([2219, 512])


 50%|█████     | 40/80 [2:13:09<2:05:22, 188.07s/it]

torch.Size([2110, 512])


 51%|█████▏    | 41/80 [2:16:42<2:07:10, 195.66s/it]

torch.Size([2126, 512])


 52%|█████▎    | 42/80 [2:20:16<2:07:26, 201.22s/it]

torch.Size([2040, 512])


 54%|█████▍    | 43/80 [2:23:32<2:03:09, 199.70s/it]

torch.Size([2020, 512])


 55%|█████▌    | 44/80 [2:26:39<1:57:30, 195.85s/it]

torch.Size([2116, 512])


 56%|█████▋    | 45/80 [2:30:01<1:55:23, 197.81s/it]

torch.Size([2135, 512])


 57%|█████▊    | 46/80 [2:33:38<1:55:15, 203.41s/it]

torch.Size([2118, 512])


 59%|█████▉    | 47/80 [2:37:00<1:51:42, 203.11s/it]

torch.Size([1976, 512])


 60%|██████    | 48/80 [2:40:05<1:45:21, 197.55s/it]

torch.Size([2041, 512])


 61%|██████▏   | 49/80 [2:43:30<1:43:11, 199.72s/it]

torch.Size([2103, 512])


 62%|██████▎   | 50/80 [2:47:00<1:41:29, 203.00s/it]

torch.Size([2171, 512])


 64%|██████▍   | 51/80 [2:50:34<1:39:37, 206.12s/it]

torch.Size([2089, 512])


 65%|██████▌   | 52/80 [2:53:54<1:35:20, 204.29s/it]

torch.Size([2090, 512])


 66%|██████▋   | 53/80 [2:57:10<1:30:47, 201.75s/it]

torch.Size([2067, 512])


 68%|██████▊   | 54/80 [3:00:29<1:27:05, 201.00s/it]

torch.Size([2041, 512])


 69%|██████▉   | 55/80 [3:03:29<1:21:05, 194.60s/it]

torch.Size([2018, 512])


 70%|███████   | 56/80 [3:06:32<1:16:33, 191.41s/it]

torch.Size([2058, 512])


 71%|███████▏  | 57/80 [3:09:39<1:12:45, 189.82s/it]

torch.Size([2056, 512])


 72%|███████▎  | 58/80 [3:12:54<1:10:12, 191.47s/it]

torch.Size([2001, 512])


 74%|███████▍  | 59/80 [3:16:00<1:06:27, 189.89s/it]

torch.Size([2025, 512])


 75%|███████▌  | 60/80 [3:19:09<1:03:10, 189.54s/it]

torch.Size([2170, 512])


 76%|███████▋  | 61/80 [3:22:40<1:02:02, 195.92s/it]

torch.Size([2105, 512])


 78%|███████▊  | 62/80 [3:26:11<1:00:12, 200.67s/it]

torch.Size([1959, 512])


 79%|███████▉  | 63/80 [3:29:11<55:01, 194.23s/it]  

torch.Size([1974, 512])


 80%|████████  | 64/80 [3:32:08<50:24, 189.05s/it]

torch.Size([2014, 512])


 81%|████████▏ | 65/80 [3:35:18<47:21, 189.42s/it]

torch.Size([2038, 512])


 82%|████████▎ | 66/80 [3:38:30<44:24, 190.33s/it]

torch.Size([1992, 512])


 84%|████████▍ | 67/80 [3:41:38<41:02, 189.40s/it]

torch.Size([1990, 512])


 85%|████████▌ | 68/80 [3:44:57<38:30, 192.55s/it]

torch.Size([2061, 512])


 86%|████████▋ | 69/80 [3:48:26<36:12, 197.50s/it]

torch.Size([2044, 512])


 88%|████████▊ | 70/80 [3:51:46<33:01, 198.18s/it]

torch.Size([2098, 512])


 89%|████████▉ | 71/80 [3:55:18<30:19, 202.15s/it]

torch.Size([2001, 512])


 90%|█████████ | 72/80 [3:58:20<26:09, 196.23s/it]

torch.Size([2009, 512])


 91%|█████████▏| 73/80 [4:01:42<23:05, 197.91s/it]

torch.Size([2082, 512])


 92%|█████████▎| 74/80 [4:04:56<19:40, 196.73s/it]

torch.Size([2001, 512])


 94%|█████████▍| 75/80 [4:08:17<16:30, 198.07s/it]

torch.Size([2132, 512])


 95%|█████████▌| 76/80 [4:11:01<12:30, 187.70s/it]

torch.Size([2076, 512])


 96%|█████████▋| 77/80 [4:13:28<08:46, 175.55s/it]

torch.Size([2079, 512])


 98%|█████████▊| 78/80 [4:15:53<05:32, 166.31s/it]

torch.Size([2119, 512])


 99%|█████████▉| 79/80 [4:18:27<02:42, 162.67s/it]

torch.Size([1707, 512])


100%|██████████| 80/80 [4:20:29<00:00, 195.36s/it]


In [None]:
# Gather the per-video CLIP frame embeddings saved above, together with each
# video's caption, into a single pickle.
df = pd.read_csv(csv_path)

all_embeddings = []
all_captions = []
frame_counts = []   # rows contributed by each video, aligned with all_captions
missing = []

for video_id, caption in tqdm(zip(df["id"], df["descriptions"]), total=len(df)):
    path = os.path.join(pt_path, video_id + ".pt")
    if not os.path.exists(path):
        # Videos that failed extraction have no .pt file; skip their caption too so
        # embeddings and captions stay aligned.
        missing.append(video_id)
        continue
    # map_location="cpu": the .pt files hold tensors saved from `device`; keeping the
    # pickle on the CPU lets it be unpickled on a machine without that device.
    embedding = torch.load(path, map_location="cpu")
    all_embeddings.append(embedding)
    all_captions.append(caption)
    frame_counts.append(embedding.shape[0])

if missing:
    print(f"Skipped {len(missing)} videos without a .pt file: {missing[:10]}")

# torch.cat flattens every video's frames into one tensor, so row i does NOT
# correspond to caption i. `frame_counts` recovers the per-video slices, e.g.
# torch.split(clip_embedding, frame_counts).
clip_embedding = torch.cat(all_embeddings, dim=0)

# Save the final dictionary as a pickle file.
output_path = "./clip_caption/clip_caption.pkl"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as f:
    pickle.dump(
        {
            "clip_embedding": clip_embedding,
            "captions": all_captions,
            "frame_counts": frame_counts,
        },
        f,
    )


In [7]:
# Videos that raised during frame extraction / preprocessing in the batch loop.
failed_videos = errors
print(failed_videos)


[]


In [12]:
# Sanity check: encoding a single frame on its own should reproduce the embedding
# stored by the batched pipeline (cosine similarity ~1, Euclidean distance ~0).
sample_id = "S6MPZ"

video, _audio, _info = torchvision.io.read_video(
    os.path.join(video_path, sample_id + ".mp4"), output_format='TCHW'
)
to_pil = torchvision.transforms.ToPILImage()
image = preprocess(to_pil(video[0].squeeze(0))).unsqueeze(0).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)

# Load onto the same device as image_features so the two can be compared directly.
tensor = torch.load(os.path.join(pt_path, sample_id + ".pt"), map_location=device)

print(tensor)
first_frame = tensor[0].unsqueeze(0)
cos_sim = torch.nn.functional.cosine_similarity(image_features, first_frame)
euc_dist = torch.nn.functional.pairwise_distance(image_features, first_frame)

print(cos_sim)
print(euc_dist)




tensor([[ 0.0497,  0.1740, -0.0112,  ...,  1.2744, -0.3362,  0.2883],
        [ 0.0991,  0.2061,  0.0660,  ...,  1.6748, -0.4888,  0.2273],
        [ 0.0687,  0.2087, -0.2421,  ...,  1.5742, -0.4629,  0.1131],
        ...,
        [ 0.2361,  0.0197,  0.1917,  ...,  1.0273, -0.1343,  0.1251],
        [ 0.2163,  0.0526,  0.2136,  ...,  0.9785, -0.1228,  0.2971],
        [-0.0357,  0.1965, -0.0172,  ...,  1.1807, -0.1632,  0.2786]],
       device='cuda:0', dtype=torch.float16)
tensor([1.], device='cuda:0', dtype=torch.float16)
tensor([0.0168], device='cuda:0', dtype=torch.float16)
