# Train head over embeddings

In [1]:
import torch
import torchvision
import pytorchvideo

In [2]:
# Release cached CUDA allocator memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()

# Prefer the second GPU when the host has one, otherwise the first GPU,
# otherwise the CPU. Asking for "cuda:1" on a single-GPU machine would fail
# as soon as a module or tensor is moved to the device.
# To force CPU execution, set device = "cpu" after this block.
if not torch.cuda.is_available():
    device = "cpu"
elif torch.cuda.device_count() > 1:
    device = "cuda:1"
else:
    device = "cuda:0"
print("Device: " + device)
print(f"Devices count: {torch.cuda.device_count()}")

Device: cuda:1
Devices count: 2


In [3]:
from pathlib import Path

In [4]:
import json
import pandas
import numpy

In [5]:
from tqdm import tqdm

In [6]:
from misc.utils_mvit import *
from misc.embeddings import *



In [7]:
# Locations of the raw inputs, all relative to the ./data directory.
data_path = Path("./data")
videos_path = data_path.joinpath("train_dataset")
metadata_path = data_path.joinpath("train.csv")

In [8]:
# Collect whatever walk_directory yields for the videos directory, in sorted order.
filenames = sorted(walk_directory(videos_path))

## Model

In [9]:
# Name of the torchvision video architecture used as the frozen backbone.
# Other registered variants (e.g. "mvit_v1_b") can be substituted here.
model_name = "mvit_v2_s"
# Look the pretrained weights up from the architecture name so that changing
# model_name cannot leave a mismatched, hard-coded weights enum behind.
backbone_weights = torchvision.models.get_model_weights(model_name).KINETICS400_V1
backbone = getattr(torchvision.models.video, model_name)(weights=backbone_weights)

In [10]:
# Move the backbone to the selected device and switch it to evaluation mode.
backbone = backbone.to(device).eval()

In [11]:
# Projection head: a linear map from 400 input features to 64 outputs,
# followed by batch normalisation without learnable scale/shift.
_head_layers = [
    torch.nn.Linear(in_features=400, out_features=64),
    torch.nn.BatchNorm1d(num_features=64, affine=False),
]
head = torch.nn.Sequential(*_head_layers)
head = head.to(device)

In [12]:
# Combine backbone, head and input preprocessing into one encoder module.
# The preprocessing preset is taken from the weights that belong to
# `model_name`, so it always matches the backbone actually in use (the
# previous code took the preset from the MViT v1-B weights while loading a
# v2-S backbone).
# NOTE(review): the two presets may currently be identical - confirm; this
# change is about keeping them tied together.
_preprocessing = torchvision.models.get_model_weights(model_name).KINETICS400_V1.transforms()
encoder = VideoEncoder(
    backbone,
    head,
    _preprocessing,
).to(device)

## Load embeddings

In [13]:
# Per-model directory under the data root, named after the backbone.
embeddings_path = data_path.joinpath(model_name)

In [14]:
# Saved embeddings tensor and the CSV with a "uuid" column that accompanies it.
embeddings_path_torch = embeddings_path.joinpath("new_embeddings.pt")
embeddings_uuid_path = embeddings_path.joinpath("new_embeddings_uuid.csv")

In [15]:
# Training metadata (kept with the default integer row index) and the uuid
# listing stored next to the embeddings.
metadata_train = pandas.read_csv(filepath_or_buffer=metadata_path)
embeddings_uuid = pandas.read_csv(filepath_or_buffer=embeddings_uuid_path)

In [16]:
# Imported here explicitly: no visible cell imports defaultdict by name, so it
# would otherwise have to leak in through one of the star imports.
from collections import defaultdict

# Row position in the uuid listing -> uuid.
# presumably row i of this listing corresponds to row i of the embeddings tensor - TODO confirm
id_to_uuid = embeddings_uuid["uuid"].to_numpy()
# uuid -> row position; any uuid missing from the listing resolves to -1 so it
# can be filtered out after mapping.
uuid_to_id = defaultdict(lambda: -1, {uuid: position for position, uuid in enumerate(id_to_uuid)})

In [17]:
# Display the metadata table (pandas elides the middle rows in the rendered output).
metadata_train

Unnamed: 0,created,uuid,link,is_duplicate,duplicate_for,is_hard
0,2024-06-01 00:05:43,23fac2f2-7f00-48cb-b3ac-aac8caa3b6b4,https://s3.ritm.media/yappy-db-duplicates/23fa...,False,,False
1,2024-06-01 00:11:01,2fa37210-3c25-4a87-88f2-1242c2c8a699,https://s3.ritm.media/yappy-db-duplicates/2fa3...,False,,False
2,2024-06-01 00:13:20,31cc33d5-95de-4799-ad01-87c8498d1bde,https://s3.ritm.media/yappy-db-duplicates/31cc...,False,,False
3,2024-06-01 00:27:23,03abd0ec-609e-4eea-9f2a-b6b7442bc881,https://s3.ritm.media/yappy-db-duplicates/03ab...,False,,False
4,2024-06-01 00:30:23,22ee0045-004b-4c7e-98f2-77e5e02e2f15,https://s3.ritm.media/yappy-db-duplicates/22ee...,False,,False
...,...,...,...,...,...,...
22758,2024-09-12 13:46:57,0efe756a-e965-40c1-94db-de7f3e6649a9,https://s3.ritm.media/yappy-db-duplicates/0efe...,True,131846f3-6f5c-497a-a2fa-95cfb3929301,False
22759,2024-09-12 14:46:13,caec3b94-e356-4576-b00a-515e0df1dfc3,https://s3.ritm.media/yappy-db-duplicates/caec...,True,3b5eb15a-c6d7-4214-8dd6-c029564ff11d,False
22760,2024-09-13 09:08:42,c5b69151-f240-4e27-a5c9-c41f79a167e9,https://s3.ritm.media/yappy-db-duplicates/c5b6...,True,17ecc94a-f28a-40d5-b438-86b6e82a2fef,False
22761,2024-09-13 14:52:21,6d3233b6-f8de-49ba-8697-bb30dbf825f7,https://s3.ritm.media/yappy-db-duplicates/6d32...,True,1838f7a7-ef2a-4141-a125-90fb5bf0c5a2,False


In [18]:
# Embeddings may exist for videos that have no row in the metadata table.
# The comparison has to use the "uuid" column: the table is read without an
# index column, so its index holds row numbers and can never match a uuid.
# Sorting makes the order of the appended rows deterministic.
_not_in_index = sorted(set(id_to_uuid) - set(metadata_train["uuid"]))

# One placeholder row per missing uuid: fixed dummy timestamp, empty link,
# not a duplicate, not hard. The real uuid goes into the "uuid" column.
# assumes metadata_train has exactly these six columns in this order - TODO confirm
_dummy_data = pandas.DataFrame(
    [['2020-06-01 00:05:43', uuid, '', False, numpy.nan, False] for uuid in _not_in_index],
    index=_not_in_index,
    columns=metadata_train.columns,
)
# Skip the concat when nothing is missing, which also makes re-running this
# cell a no-op instead of appending the placeholders again.
if _not_in_index:
    metadata_train = pandas.concat([metadata_train, _dummy_data], axis=0)

In [19]:
# Always materialise the tensor on the CPU: a file saved from a GPU would
# otherwise be restored onto that same device and fail on hosts without it.
# Later cells move the data to `device` explicitly.
# weights_only=True restricts unpickling to tensors and plain containers,
# since torch.load is pickle-based.
embeddings = torch.load(embeddings_path_torch, map_location="cpu", weights_only=True)

In [20]:
# Rows flagged as duplicates, reduced to the (uuid, duplicate_for) pair.
# A single .loc selection plus an explicit copy replaces the chained
# df[mask][cols] indexing, so the column assignments in the next cell operate
# on an independent frame rather than on a possible view of metadata_train.
_is_duplicate = metadata_train["is_duplicate"] == True
positive_pairs = metadata_train.loc[_is_duplicate, ["uuid", "duplicate_for"]].copy()

In [21]:
# Translate both uuid columns to positions via uuid_to_id (unknown uuids -> -1).
for _column in ("uuid", "duplicate_for"):
    positive_pairs[_column] = positive_pairs[_column].map(uuid_to_id)

In [22]:
# Keep only the pairs where both sides mapped to a non-negative position.
_both_known = positive_pairs.ge(0).all(axis="columns")
positive_pairs = positive_pairs[_both_known]

In [23]:
# (rows, columns) of the pair table after dropping pairs with an unmapped uuid.
positive_pairs.shape

(280, 2)

## Train

In [24]:
# Build the contrastive dataset from the embeddings and all positive pairs.
# n_negatives=32 is forwarded to the dataset as-is; presumably the number of
# negatives drawn per positive pair - TODO confirm against the dataset class.
_pair_indices = positive_pairs.to_numpy()
train_dataset = ContrastiveDuplicatesDataset(embeddings, _pair_indices, n_negatives=32)

In [25]:
# Two loaders over the same dataset with the same batch size: a shuffled one
# for optimisation and an ordered one for evaluation.
# NOTE(review): the evaluation loader wraps train_dataset, so whatever is
# measured on it is measured on training data - confirm this is intended.
_loader_options = dict(batch_size=32)
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_options)
test_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=False, **_loader_options)

In [26]:
# Run train_head for 100 epochs on the selected device and keep its return value.
history = train_head(
    encoder,
    train_dataloader,
    test_dataloader,
    device=device,
    n_epochs=100,
)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.85it/s]


In [27]:
# The head contains BatchNorm1d. In training mode it would normalise with the
# statistics of this particular batch and update its running statistics, so
# switch to evaluation mode before projecting.
# NOTE(review): train_head is not visible here; if it already leaves the head
# in eval mode this call is a harmless no-op.
encoder.head.eval()
with torch.no_grad():
    # Project all stored embeddings through the head and bring the result back to the CPU.
    new_embeddings = encoder.head(embeddings.to(device)).detach().cpu()

In [28]:
# Persist the projected embeddings in the per-model embeddings directory.
new_embeddings_path_torch = embeddings_path.joinpath("headed_embeddings.pt")
torch.save(obj=new_embeddings, f=new_embeddings_path_torch)