In [1]:
import os
import pickle

import clip
import numpy as np
import pandas as pd
import torch
import torchvision
from tqdm import tqdm


# Choose the compute backend: prefer CUDA, fall back to Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")



Using cuda device


In [2]:
# Name of the CLIP variant handed to clip.load below.
clip_model_type = "ViT-B/32"
# CSV read with pandas later on; its "id" and "descriptions" columns are used.
csv_path = "./Charades/Charades_v1_train.csv"
# Directory holding the source videos, looked up as <id>.mp4.
video_path = "./video" 
# Directory where one <id>.pt file of image features is saved per video.
pt_path = "./video_tensors_120"
# Sampling density: preprocess_frames derives its frame stride as
# int(video_fps * 60 / frames_no).
frames_no = 120 
# Number of videos grouped together for a single CLIP encode_image call.
num_video_batch = 25
# `model` encodes images; `preprocess` is the transform applied to each PIL frame.
model, preprocess = clip.load(clip_model_type, device=device, jit=False)

In [3]:
#extract frames
def preprocess_frames(video_file):
    """Decode a video and return its sampled frames as one preprocessed batch.

    Frames are taken at a fixed stride of ``int(video_fps * 60 / frames_no)``
    (never less than 1), each selected frame is converted to a PIL image and
    passed through the CLIP ``preprocess`` transform, and the results are
    stacked along a new leading dimension.

    Args:
        video_file: Path handed to ``torchvision.io.read_video``.

    Returns:
        A tensor on ``device`` whose first dimension indexes the sampled frames.

    Raises:
        ValueError: If no frames could be decoded from ``video_file``.
    """
    video, _audio, info = torchvision.io.read_video(video_file, output_format='TCHW')

    num_frames = len(video)
    if num_frames == 0:
        # Fail with a clear message instead of an opaque torch.cat error later.
        raise ValueError(f"no frames decoded from {video_file}")

    frame_rate = info["video_fps"]
    # Clamp to 1: a low frame rate (or a large frames_no) would otherwise give
    # a stride of 0, which range() rejects.
    interval = max(1, int(frame_rate * 60 / frames_no))

    to_pil = torchvision.transforms.ToPILImage()
    frames = [
        preprocess(to_pil(video[i].squeeze(0)))
        for i in range(0, num_frames, interval)
    ]

    # Stack first and transfer once rather than moving every frame separately.
    return torch.stack(frames, dim=0).to(device)

In [4]:
# Encode several per-video frame stacks with CLIP in a single forward pass
def CLIPtokenise(tensors):
    """Run CLIP's image encoder over a list of frame stacks at once.

    The stacks are concatenated along dim 0, encoded together with
    ``model.encode_image`` under ``torch.no_grad()``, and the resulting
    feature rows are split back into one chunk per input stack.

    Args:
        tensors: Sequence of tensors; the size of each tensor's first
            dimension is the number of rows returned for that entry.

    Returns:
        A tuple with one feature tensor per input, in input order. An empty
        tuple is returned when ``tensors`` is empty.
    """
    # Nothing to encode (e.g. every video in the batch failed to load).
    if len(tensors) == 0:
        return ()

    # Remember how many rows belong to each input so the output can be split.
    frame_counts = [t.shape[0] for t in tensors]

    joined = torch.cat(tensors, dim=0)

    with torch.no_grad():
        image_features = model.encode_image(joined)

    print(image_features.shape)

    # Split the combined features back into per-video stacks.
    return torch.split(image_features, frame_counts, dim=0)




In [5]:
# load captions and video names
df = pd.read_csv(csv_path)
ids = df["id"]
# NOTE(review): hard-coded offset skips the first 3750 ids — presumably this
# resumes an earlier partial run; confirm before a full re-run.
ids = ids[3750:].copy()
errors = []

# torch.save does not create missing directories.
os.makedirs(pt_path, exist_ok=True)

# loop over the ids with a batch size step
for i in tqdm(range(0, len(ids), num_video_batch)):
    # get the current batch of ids
    batch_ids = ids[i:i+num_video_batch].reset_index(drop=True)

    batch = []
    loaded_ids = []  # ids whose frames were read successfully, in batch order

    for video_id in batch_ids:
        try:
            video_file = os.path.join(video_path, video_id + ".mp4")
            batch.append(preprocess_frames(video_file))
            loaded_ids.append(video_id)
        except Exception as e:
            # Record the failure and keep going with the remaining videos.
            errors.append(f"video {video_id} error {e}")

    # Every video in this batch failed: there is nothing to encode or save.
    if not batch:
        continue

    tokenised_batch = CLIPtokenise(batch)

    for video_id, features in zip(loaded_ids, tokenised_batch):
        # torch.split returns views of the whole batch; torch.save serialises
        # the entire shared storage, so clone to store only this video's rows.
        torch.save(features.clone(), f"{pt_path}/{video_id}.pt")

        


  0%|          | 0/170 [00:00<?, ?it/s]



torch.Size([1455, 512])


  1%|          | 1/170 [00:50<2:22:18, 50.52s/it]

torch.Size([1503, 512])


  1%|          | 2/170 [01:36<2:14:12, 47.93s/it]

torch.Size([1515, 512])


  2%|▏         | 3/170 [02:22<2:10:41, 46.95s/it]

torch.Size([1548, 512])


  2%|▏         | 4/170 [03:10<2:10:40, 47.23s/it]

torch.Size([1599, 512])


  3%|▎         | 5/170 [03:57<2:10:19, 47.39s/it]

torch.Size([1506, 512])


  4%|▎         | 6/170 [04:40<2:05:05, 45.76s/it]

torch.Size([1477, 512])


  4%|▍         | 7/170 [05:28<2:06:04, 46.41s/it]

torch.Size([1936, 512])


  5%|▍         | 8/170 [06:32<2:21:07, 52.27s/it]

torch.Size([1666, 512])


  5%|▌         | 9/170 [07:34<2:27:47, 55.08s/it]

torch.Size([1660, 512])


  6%|▌         | 10/170 [08:26<2:24:15, 54.10s/it]

torch.Size([1592, 512])


  6%|▋         | 11/170 [09:12<2:16:58, 51.69s/it]

torch.Size([1685, 512])


  7%|▋         | 12/170 [10:08<2:20:07, 53.21s/it]

torch.Size([1570, 512])


  8%|▊         | 13/170 [10:54<2:13:10, 50.89s/it]

torch.Size([1554, 512])


  8%|▊         | 14/170 [11:44<2:11:18, 50.50s/it]

torch.Size([1612, 512])


  9%|▉         | 15/170 [12:37<2:12:36, 51.33s/it]

torch.Size([1602, 512])


 10%|█         | 17/170 [14:07<2:01:24, 47.61s/it]

torch.Size([1624, 512])


 11%|█         | 18/170 [14:51<1:57:44, 46.48s/it]

torch.Size([1599, 512])
torch.Size([1451, 512])


 11%|█         | 19/170 [15:31<1:52:21, 44.65s/it]

torch.Size([1652, 512])


 12%|█▏        | 20/170 [16:21<1:55:21, 46.14s/it]

torch.Size([1586, 512])


 12%|█▏        | 21/170 [17:09<1:56:23, 46.87s/it]

torch.Size([1530, 512])


 13%|█▎        | 22/170 [17:48<1:49:51, 44.54s/it]

torch.Size([1552, 512])


 14%|█▎        | 23/170 [18:35<1:50:37, 45.15s/it]

torch.Size([1579, 512])


 14%|█▍        | 24/170 [19:22<1:51:14, 45.72s/it]

torch.Size([1496, 512])


 15%|█▍        | 25/170 [20:04<1:47:44, 44.58s/it]

torch.Size([1523, 512])


 15%|█▌        | 26/170 [20:51<1:49:00, 45.42s/it]

torch.Size([1469, 512])


 16%|█▌        | 27/170 [21:41<1:51:01, 46.58s/it]

torch.Size([1574, 512])


 16%|█▋        | 28/170 [22:32<1:53:16, 47.86s/it]

torch.Size([1603, 512])


 18%|█▊        | 30/170 [24:15<1:56:24, 49.89s/it]

torch.Size([1804, 512])
torch.Size([1900, 512])


 18%|█▊        | 31/170 [25:16<2:03:35, 53.35s/it]

torch.Size([1591, 512])


 19%|█▉        | 32/170 [26:06<2:00:17, 52.30s/it]

torch.Size([1529, 512])


 19%|█▉        | 33/170 [26:50<1:53:54, 49.88s/it]

torch.Size([1450, 512])


 20%|██        | 34/170 [27:34<1:48:34, 47.90s/it]

torch.Size([1559, 512])


 21%|██        | 35/170 [28:23<1:48:21, 48.16s/it]

torch.Size([1737, 512])


 21%|██        | 36/170 [29:15<1:50:07, 49.31s/it]

torch.Size([1609, 512])


 22%|██▏       | 37/170 [30:06<1:50:53, 50.03s/it]

torch.Size([1541, 512])


 22%|██▏       | 38/170 [30:55<1:48:54, 49.51s/it]

torch.Size([1587, 512])


 23%|██▎       | 39/170 [31:39<1:45:01, 48.10s/it]

torch.Size([1337, 512])


 24%|██▎       | 40/170 [32:21<1:39:52, 46.10s/it]

torch.Size([1560, 512])


 24%|██▍       | 41/170 [33:08<1:39:37, 46.34s/it]

torch.Size([1560, 512])


 25%|██▍       | 42/170 [33:59<1:42:05, 47.86s/it]

torch.Size([1597, 512])


 25%|██▌       | 43/170 [34:51<1:43:58, 49.12s/it]

torch.Size([1557, 512])


 26%|██▌       | 44/170 [35:41<1:43:52, 49.47s/it]

torch.Size([1439, 512])


 26%|██▋       | 45/170 [36:26<1:40:16, 48.13s/it]

torch.Size([1590, 512])


 27%|██▋       | 46/170 [37:14<1:39:24, 48.10s/it]

torch.Size([1609, 512])


 28%|██▊       | 48/170 [39:06<1:45:49, 52.04s/it]

torch.Size([1633, 512])
torch.Size([1612, 512])


 29%|██▉       | 49/170 [39:57<1:44:16, 51.71s/it]

torch.Size([1572, 512])


 30%|███       | 51/170 [41:42<1:44:03, 52.47s/it]

torch.Size([1834, 512])
torch.Size([1622, 512])


 31%|███       | 52/170 [42:35<1:43:26, 52.60s/it]

torch.Size([1553, 512])


 31%|███       | 53/170 [43:27<1:42:20, 52.48s/it]

torch.Size([1607, 512])


 32%|███▏      | 54/170 [44:21<1:42:16, 52.90s/it]

torch.Size([1618, 512])


 32%|███▏      | 55/170 [45:15<1:42:17, 53.37s/it]

torch.Size([1504, 512])


 33%|███▎      | 56/170 [46:05<1:39:20, 52.28s/it]

torch.Size([1592, 512])


 34%|███▎      | 57/170 [46:58<1:38:51, 52.49s/it]

torch.Size([1643, 512])


 34%|███▍      | 58/170 [47:45<1:34:46, 50.78s/it]

torch.Size([1537, 512])


 35%|███▍      | 59/170 [48:33<1:32:38, 50.08s/it]

torch.Size([1518, 512])


 35%|███▌      | 60/170 [49:18<1:28:54, 48.49s/it]

torch.Size([1726, 512])


 36%|███▌      | 61/170 [50:16<1:33:32, 51.49s/it]

torch.Size([1578, 512])


 36%|███▋      | 62/170 [51:09<1:33:24, 51.89s/it]

torch.Size([1670, 512])


 37%|███▋      | 63/170 [52:05<1:34:33, 53.02s/it]

torch.Size([1463, 512])


 38%|███▊      | 64/170 [52:55<1:31:52, 52.01s/it]

torch.Size([1636, 512])


 38%|███▊      | 65/170 [53:53<1:34:22, 53.93s/it]

torch.Size([1499, 512])


 39%|███▉      | 66/170 [54:45<1:32:30, 53.37s/it]

torch.Size([1537, 512])


 39%|███▉      | 67/170 [55:27<1:25:51, 50.02s/it]

torch.Size([1656, 512])


 40%|████      | 68/170 [56:24<1:28:15, 51.91s/it]

torch.Size([1470, 512])


 41%|████      | 69/170 [57:12<1:25:40, 50.90s/it]

torch.Size([1525, 512])


 41%|████      | 70/170 [58:00<1:23:07, 49.87s/it]

torch.Size([1665, 512])


 42%|████▏     | 71/170 [58:57<1:26:12, 52.25s/it]

torch.Size([1508, 512])


 42%|████▏     | 72/170 [59:36<1:18:31, 48.08s/it]

torch.Size([1496, 512])


 43%|████▎     | 73/170 [1:00:09<1:10:43, 43.75s/it]

torch.Size([1452, 512])


 44%|████▎     | 74/170 [1:00:40<1:03:50, 39.90s/it]

torch.Size([1647, 512])


 44%|████▍     | 75/170 [1:01:24<1:05:06, 41.12s/it]

torch.Size([1499, 512])


 45%|████▍     | 76/170 [1:02:12<1:07:31, 43.10s/it]

torch.Size([1559, 512])


 45%|████▌     | 77/170 [1:02:58<1:08:16, 44.05s/it]

torch.Size([1547, 512])


 46%|████▋     | 79/170 [1:04:33<1:08:52, 45.41s/it]

torch.Size([1432, 512])
torch.Size([1676, 512])


 47%|████▋     | 80/170 [1:05:25<1:10:50, 47.23s/it]

torch.Size([1490, 512])


 48%|████▊     | 81/170 [1:06:16<1:11:40, 48.32s/it]

torch.Size([1629, 512])


 49%|████▉     | 83/170 [1:07:59<1:12:03, 49.69s/it]

torch.Size([1512, 512])
torch.Size([1505, 512])


 50%|█████     | 85/170 [1:09:42<1:12:46, 51.37s/it]

torch.Size([1610, 512])
torch.Size([1452, 512])


 51%|█████     | 87/170 [1:11:23<1:10:47, 51.17s/it]

torch.Size([1508, 512])
torch.Size([1574, 512])


 52%|█████▏    | 88/170 [1:12:22<1:13:21, 53.68s/it]

torch.Size([1517, 512])


 52%|█████▏    | 89/170 [1:13:11<1:10:13, 52.01s/it]

torch.Size([1548, 512])


 53%|█████▎    | 90/170 [1:14:03<1:09:25, 52.07s/it]

torch.Size([1648, 512])


 54%|█████▎    | 91/170 [1:15:01<1:11:09, 54.04s/it]

torch.Size([1569, 512])


 54%|█████▍    | 92/170 [1:15:53<1:09:21, 53.35s/it]

torch.Size([1788, 512])


 55%|█████▍    | 93/170 [1:16:56<1:12:16, 56.31s/it]

torch.Size([1591, 512])


 55%|█████▌    | 94/170 [1:17:52<1:10:55, 55.99s/it]

torch.Size([1653, 512])


 56%|█████▌    | 95/170 [1:18:47<1:09:40, 55.74s/it]

torch.Size([1540, 512])


 56%|█████▋    | 96/170 [1:19:39<1:07:19, 54.58s/it]

torch.Size([1591, 512])


 57%|█████▋    | 97/170 [1:20:38<1:07:58, 55.87s/it]

torch.Size([1612, 512])


 58%|█████▊    | 98/170 [1:21:34<1:07:13, 56.02s/it]

torch.Size([1568, 512])


 58%|█████▊    | 99/170 [1:22:21<1:03:17, 53.48s/it]

torch.Size([1420, 512])


 59%|█████▉    | 100/170 [1:23:08<1:00:07, 51.54s/it]

torch.Size([1417, 512])


 59%|█████▉    | 101/170 [1:23:56<57:53, 50.34s/it]  

torch.Size([1553, 512])


 60%|██████    | 102/170 [1:24:48<57:36, 50.84s/it]

torch.Size([1521, 512])


 61%|██████    | 103/170 [1:25:37<56:11, 50.31s/it]

torch.Size([1505, 512])


 61%|██████    | 104/170 [1:26:28<55:42, 50.64s/it]

torch.Size([1411, 512])


 62%|██████▏   | 105/170 [1:27:11<52:11, 48.18s/it]

torch.Size([1553, 512])


 62%|██████▏   | 106/170 [1:28:02<52:16, 49.01s/it]

torch.Size([1672, 512])


 64%|██████▎   | 108/170 [1:30:00<55:36, 53.81s/it]

torch.Size([1436, 512])
torch.Size([1504, 512])


 64%|██████▍   | 109/170 [1:30:46<52:25, 51.56s/it]

torch.Size([1538, 512])


 65%|██████▍   | 110/170 [1:31:30<49:01, 49.03s/it]

torch.Size([1571, 512])


 65%|██████▌   | 111/170 [1:32:10<45:44, 46.51s/it]

torch.Size([1440, 512])


 66%|██████▌   | 112/170 [1:32:40<40:04, 41.46s/it]

torch.Size([1522, 512])


 66%|██████▋   | 113/170 [1:33:13<36:59, 38.95s/it]

torch.Size([1649, 512])


 67%|██████▋   | 114/170 [1:33:54<36:53, 39.53s/it]

torch.Size([1485, 512])


 68%|██████▊   | 115/170 [1:34:33<36:00, 39.28s/it]

torch.Size([1537, 512])


 68%|██████▊   | 116/170 [1:35:09<34:30, 38.35s/it]

torch.Size([1474, 512])


 69%|██████▉   | 117/170 [1:35:45<33:25, 37.84s/it]

torch.Size([1550, 512])


 69%|██████▉   | 118/170 [1:36:22<32:33, 37.57s/it]

torch.Size([1468, 512])


 70%|███████   | 119/170 [1:37:03<32:40, 38.44s/it]

torch.Size([1636, 512])


 71%|███████   | 120/170 [1:37:47<33:33, 40.27s/it]

torch.Size([1460, 512])


 71%|███████   | 121/170 [1:38:22<31:36, 38.70s/it]

torch.Size([1458, 512])


 72%|███████▏  | 122/170 [1:38:59<30:24, 38.00s/it]

torch.Size([1586, 512])


 72%|███████▏  | 123/170 [1:39:38<30:08, 38.47s/it]

torch.Size([1511, 512])


 73%|███████▎  | 124/170 [1:40:14<28:55, 37.73s/it]

torch.Size([1684, 512])


 74%|███████▎  | 125/170 [1:41:01<30:17, 40.38s/it]

torch.Size([1462, 512])


 74%|███████▍  | 126/170 [1:41:39<29:05, 39.66s/it]

torch.Size([1469, 512])


 75%|███████▍  | 127/170 [1:42:20<28:42, 40.06s/it]

torch.Size([1548, 512])


 75%|███████▌  | 128/170 [1:42:57<27:20, 39.07s/it]

torch.Size([1534, 512])


 76%|███████▌  | 129/170 [1:43:30<25:36, 37.48s/it]

torch.Size([1644, 512])


 76%|███████▋  | 130/170 [1:44:15<26:30, 39.76s/it]

torch.Size([1499, 512])


 77%|███████▋  | 131/170 [1:44:51<25:01, 38.49s/it]

torch.Size([1540, 512])


 78%|███████▊  | 132/170 [1:45:28<24:07, 38.10s/it]

torch.Size([1430, 512])


 78%|███████▊  | 133/170 [1:46:02<22:39, 36.75s/it]

torch.Size([1924, 512])


 79%|███████▉  | 134/170 [1:46:56<25:08, 41.90s/it]

torch.Size([1544, 512])


 79%|███████▉  | 135/170 [1:47:53<27:12, 46.63s/it]

torch.Size([1561, 512])


 80%|████████  | 136/170 [1:48:47<27:35, 48.69s/it]

torch.Size([1549, 512])


 81%|████████  | 137/170 [1:49:39<27:19, 49.68s/it]

torch.Size([1420, 512])


 81%|████████  | 138/170 [1:50:24<25:43, 48.25s/it]

torch.Size([1524, 512])


 82%|████████▏ | 139/170 [1:51:18<25:53, 50.12s/it]

torch.Size([1436, 512])


 83%|████████▎ | 141/170 [1:52:58<24:22, 50.41s/it]

torch.Size([1576, 512])
torch.Size([1580, 512])


 84%|████████▎ | 142/170 [1:53:38<21:58, 47.09s/it]

torch.Size([1529, 512])


 84%|████████▍ | 143/170 [1:54:13<19:38, 43.66s/it]

torch.Size([1551, 512])


 85%|████████▌ | 145/170 [1:55:32<17:21, 41.66s/it]

torch.Size([1666, 512])
torch.Size([1559, 512])


 86%|████████▌ | 146/170 [1:56:11<16:19, 40.81s/it]

torch.Size([1500, 512])


 86%|████████▋ | 147/170 [1:57:00<16:38, 43.39s/it]

torch.Size([1453, 512])


 87%|████████▋ | 148/170 [1:57:49<16:30, 45.04s/it]

torch.Size([1520, 512])


 88%|████████▊ | 149/170 [1:58:37<16:03, 45.87s/it]

torch.Size([1580, 512])


 88%|████████▊ | 150/170 [1:59:35<16:30, 49.55s/it]

torch.Size([1508, 512])


 89%|████████▉ | 151/170 [2:00:28<16:01, 50.60s/it]

torch.Size([1543, 512])


 89%|████████▉ | 152/170 [2:01:18<15:05, 50.31s/it]

torch.Size([1660, 512])


 90%|█████████ | 153/170 [2:02:18<15:08, 53.43s/it]

torch.Size([1769, 512])


 91%|█████████ | 154/170 [2:03:22<15:02, 56.39s/it]

torch.Size([1621, 512])


 91%|█████████ | 155/170 [2:04:12<13:37, 54.51s/it]

torch.Size([1583, 512])


 92%|█████████▏| 156/170 [2:04:55<11:55, 51.14s/it]

torch.Size([1549, 512])


 92%|█████████▏| 157/170 [2:05:34<10:18, 47.60s/it]

torch.Size([1512, 512])


 93%|█████████▎| 158/170 [2:06:12<08:55, 44.62s/it]

torch.Size([1645, 512])


 94%|█████████▎| 159/170 [2:06:54<08:02, 43.89s/it]

torch.Size([1560, 512])


 94%|█████████▍| 160/170 [2:07:38<07:18, 43.87s/it]

torch.Size([1770, 512])


 95%|█████████▍| 161/170 [2:08:25<06:42, 44.69s/it]

torch.Size([1344, 512])


 95%|█████████▌| 162/170 [2:08:56<05:25, 40.65s/it]

torch.Size([1637, 512])


 96%|█████████▌| 163/170 [2:09:42<04:56, 42.38s/it]

torch.Size([1649, 512])


 96%|█████████▋| 164/170 [2:10:31<04:25, 44.26s/it]

torch.Size([1642, 512])


 98%|█████████▊| 166/170 [2:11:52<02:47, 41.97s/it]

torch.Size([1540, 512])
torch.Size([1624, 512])


 98%|█████████▊| 167/170 [2:12:36<02:07, 42.62s/it]

torch.Size([1533, 512])


 99%|█████████▉| 168/170 [2:13:14<01:22, 41.21s/it]

torch.Size([1424, 512])


 99%|█████████▉| 169/170 [2:13:49<00:39, 39.20s/it]

torch.Size([610, 512])


100%|██████████| 170/170 [2:14:04<00:00, 47.32s/it]


In [None]:
df = pd.read_csv(csv_path)
# Initialize empty lists to store the embeddings and captions
all_embeddings = []
all_captions = []
# ids with no saved .pt file (e.g. videos that failed during extraction)
missing_ids = []

for video_id, caption in tqdm(zip(df["id"], df["descriptions"]), total=len(df)):
    path = os.path.join(pt_path, video_id + ".pt")
    if not os.path.exists(path):
        # Skip instead of crashing so captions stay aligned with the
        # embeddings that were actually loaded.
        missing_ids.append(video_id)
        continue
    # load the saved CLIP features for this video
    all_embeddings.append(torch.load(path))
    all_captions.append(caption)

if missing_ids:
    print(f"skipped {len(missing_ids)} videos without a saved embedding")

# Concatenate the embeddings into a single tensor
# NOTE(review): each loaded tensor may hold several rows per video while
# `all_captions` has one entry per video — confirm how consumers realign them.
clip_embedding = torch.cat(all_embeddings, dim=0)

# Save the final dictionary as a pickle file
output_path = "./clip_caption/clip_caption.pkl"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as f:
    pickle.dump({"clip_embedding": clip_embedding, "captions": all_captions}, f)


In [None]:
# Show the "video <id> error <exception>" messages collected by the extraction loop.
print(errors)


In [None]:
# Sanity check: encode the first frame of one video on its own and compare it
# with the first row of the features stored for the same video.
video, audio, info = torchvision.io.read_video("./video/S6MPZ.mp4", output_format='TCHW')
to_pil = torchvision.transforms.ToPILImage()
first_frame = to_pil(video[0].squeeze(0))
image = preprocess(first_frame).unsqueeze(0).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)

# NOTE(review): this reads from "./video_tensors" while the extraction cell
# writes to pt_path ("./video_tensors_120") — confirm this is intentional.
tensor = torch.load("./video_tensors/S6MPZ.pt")
print(tensor)

# First stored row, reshaped to a batch of one to match image_features.
stored_first = tensor[0].unsqueeze(0)
cos_sim = torch.nn.functional.cosine_similarity(image_features, stored_first)
euc_dist = torch.nn.functional.pairwise_distance(image_features, stored_first)

print(cos_sim)
print(euc_dist)
