In [1]:
import os
import shutil
import json
import time

import requests

import random

import av
import cv2
import numpy as np
import pandas as pd

from tqdm import tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import albumentations as A

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from transformers import AutoProcessor, AutoModel, pipeline

2024-06-11 10:24:08.143381: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-06-11 10:24:08.711011: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/vladimir/.virtualenvs/ml/lib/python3.10/site-packages/cv2/../../lib64:
2024-06-11 10:24:08.711060: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/vladimir/.virtualenvs/ml/lib/python3.10/site-pa

In [2]:
# Load the video metadata records. Later cells read a 'link' and a 'description'
# entry from every record. The context manager closes the file handle as soon as
# parsing is done, and the explicit encoding keeps non-ASCII descriptions from
# depending on the platform's default locale.
with open('videollava.json', encoding='utf-8') as metadata_file:
    data = json.load(metadata_file)

In [3]:
# Turn off the tokenizers library's internal parallelism for this process.
# NOTE(review): presumably set because the DataLoaders below fork worker processes — confirm.
os.environ.update({'TOKENIZERS_PARALLELISM': 'false'})

In [4]:
# Number of clips per DataLoader batch.
batch_size = 8

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [5]:
# Shuffle the records in place with a dedicated, seeded generator so that the
# train/validation split derived from this ordering is identical after every
# "Restart Kernel -> Run All". The global `random` state is left untouched.
SHUFFLE_SEED = 42
random.Random(SHUFFLE_SEED).shuffle(data)

In [6]:
# Hold out the last 100 (already shuffled) records for validation; everything
# before them is used for training.
X_val = data[-100:]
X_train = data[:-100]

In [7]:
def apply_video_augmentations(video, transform):
    """Run every frame of a clip through a single ``transform`` call.

    Frame 0 is passed as the ``image`` keyword and frame ``k`` (k >= 1) as
    ``image{k}``, so the whole clip is handed to ``transform`` at once.
    NOTE(review): presumably this lets an albumentations pipeline with
    ``additional_targets`` apply the same random parameters to all frames — confirm.

    Args:
        video: Clip indexed by frame along axis 0; must expose ``.shape``.
        transform: Callable accepting the keywords described above and returning
            a mapping that contains the same keys.

    Returns:
        np.ndarray holding the transformed frames, stacked along a new leading
        axis in their original order.
    """
    frame_keys = ['image'] + [f'image{k}' for k in range(1, video.shape[0])]
    result = transform(**{key: video[pos] for pos, key in enumerate(frame_keys)})
    return np.concatenate([np.expand_dims(result[key], axis=0) for key in frame_keys])

def read_video_pyav(container, indices):
    """Decode the frames of video stream 0 whose positions appear in ``indices``.

    Decoding starts from the beginning of the container and stops once the
    position passes ``indices[-1]``. A position listed more than once is still
    returned only once, and positions beyond the end of the stream are simply
    missing from the result.

    Args:
        container: Object exposing ``seek(offset)`` and ``decode(video=0)``
            (used here with containers returned by ``av.open``).
        indices: Non-empty sequence of frame positions; the first and last
            entries bound the range that is inspected.

    Returns:
        np.ndarray stacking the selected frames (each converted with
        ``to_ndarray(format="rgb24")``) along axis 0.

    Raises:
        ValueError: from ``np.stack`` when no frame was selected.
    """
    container.seek(0)
    first_wanted, last_wanted = indices[0], indices[-1]

    picked = []
    for position, frame in enumerate(container.decode(video=0)):
        if position > last_wanted:
            # Nothing further can be requested; stop decoding early.
            break
        if position < first_wanted:
            continue
        if position in indices:
            picked.append(frame.to_ndarray(format="rgb24"))
    return np.stack(picked)

def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """Pick evenly spaced frame positions starting at frame 0.

    The sampled span is ``int(clip_len * frame_sample_rate)`` frames. While that
    span is not shorter than ``seg_len`` the clip length is reduced by one
    (never below 1) and the span recomputed, so fewer than ``clip_len``
    positions may be returned.

    Args:
        clip_len: Requested number of frames.
        frame_sample_rate: Spacing factor between sampled frames.
        seg_len: Number of frames available in the video.

    Returns:
        np.ndarray of dtype int64 with positions in ``[0, span - 1]``.
    """
    span = int(clip_len * frame_sample_rate)
    while clip_len > 1 and span >= seg_len:
        clip_len -= 1
        span = int(clip_len * frame_sample_rate)

    positions = np.linspace(0, span, num=clip_len)
    # linspace includes the end point `span`; clamp it back to the last valid position.
    return np.clip(positions, 0, span - 1).astype(np.int64)

In [8]:
from FlagEmbedding import BGEM3FlagModel

# Text encoder whose dense vectors serve as the regression targets in the
# training loops; it is kept on the CPU in full precision.
# Setting use_fp16 to True speeds up computation with a slight performance degradation.
bge = BGEM3FlagModel(
    'BAAI/bge-m3',
    device='cpu',
    use_fp16=False,
)

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [9]:
# X-CLIP preprocessing pipeline and backbone, both loaded from the same checkpoint.
processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch16")
model = AutoModel.from_pretrained("microsoft/xclip-base-patch16").to(device)

# Bias-free linear map from the 512-wide `video_embeds` produced by the backbone
# to the 1024-wide space in which they are compared with the text embeddings.
projector = nn.Linear(in_features=512, out_features=1024, bias=False).to(device)
projector

Linear(in_features=512, out_features=1024, bias=False)

In [10]:
# Augmentation pipeline for a clip: frame 0 is the regular `image` target and
# `image1` .. `image7` are registered as extra image targets, which is what
# `apply_video_augmentations` passes for an 8-frame clip.
transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=0.5),
    ],
    additional_targets=dict.fromkeys((f'image{i}' for i in range(1, 8)), 'image'),
)

In [11]:
def download_video(video_url):
    """Download a video into the local ``videos/`` cache and return its file name.

    The cache file name is made of the components of ``video_url.split('/')``
    from index 4 onwards, joined with underscores. When that file already
    exists it is reused and no request is made.

    Compared with a plain "GET and write" this function:
      * creates ``videos/`` when it is missing,
      * rejects HTTP error responses instead of caching their body as a video,
      * writes to a temporary file and renames it into place, so an interrupted
        download never leaves a truncated file that later looks like a cache hit.

    Args:
        video_url: URL of the video file.

    Returns:
        The file name relative to ``videos/``, or ``''`` when no name can be
        derived from the URL or the download fails for any reason.
    """
    partial_path = None
    try:
        filename = '_'.join(video_url.split('/')[4:])
        if not filename:
            return ''

        target_path = f'videos/{filename}'
        if os.path.exists(target_path):
            return filename

        os.makedirs('videos', exist_ok=True)

        response = requests.get(video_url, timeout=300)
        # Without this check a 4xx/5xx body would be stored and cached forever.
        response.raise_for_status()

        # The PID suffix keeps parallel DataLoader workers from sharing a temp file.
        partial_path = f'{target_path}.{os.getpid()}.part'
        with open(partial_path, 'wb') as file:
            file.write(response.content)
        os.replace(partial_path, target_path)
        return filename
    except Exception as exc:
        # Best effort by design: report the problem and signal failure with ''.
        print(f'Failed to download {video_url}: {exc}')
        if partial_path is not None:
            try:
                os.remove(partial_path)
            except OSError:
                pass
        return ''

In [12]:
class VideoDataset(Dataset):
    """Dataset yielding ``(processor inputs, description)`` pairs for video clips.

    Every record of ``meta`` must provide a ``'link'`` entry (video URL) and a
    ``'description'`` entry (text). A clip is sampled with a random frame step,
    padded to ``clip_len`` frames by repeating its last frame, optionally
    augmented, and converted with the module-level ``processor``.

    Args:
        meta: Sequence of metadata records.
        transform: Optional callable handed to ``apply_video_augmentations``.
            It has to accept one keyword per frame (``image``, ``image1``, ...).
        clip_len: Number of frames per clip (default 8).
    """

    def __init__(self, meta, transform=None, clip_len=8):
        self.meta = meta
        self.transform = transform
        self.clip_len = clip_len

    def __len__(self):
        return len(self.meta)

    def _load_clip(self, idx):
        """Download and decode the clip of record ``idx``, padded to ``clip_len`` frames."""
        file_path = download_video(self.meta[idx]['link'])
        # The context manager closes the container (and its file handle) even
        # when decoding fails; otherwise every sample would leak one open file.
        with av.open(f'videos/{file_path}') as container:
            frame_sample_rate = random.randint(5, 50)
            indices = sample_frame_indices(
                clip_len=self.clip_len,
                frame_sample_rate=frame_sample_rate,
                seg_len=container.streams.video[0].frames,
            )
            video = read_video_pyav(container, indices)

        missing = self.clip_len - video.shape[0]
        if missing > 0:
            # Short clips are padded by repeating the last decoded frame.
            video = np.concatenate([video, np.repeat(video[-1:], missing, axis=0)])
        return video

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # A record whose video cannot be fetched or decoded is replaced by a
        # randomly chosen other record; the returned description follows it.
        while True:
            try:
                video = self._load_clip(idx)
            except Exception as e:
                print(e)
                # randrange excludes len(self.meta), so the new index is always valid
                # (randint(0, len(self.meta)) could pick one past the end).
                idx = random.randrange(len(self.meta))
                continue

            break

        if self.transform:
            video = apply_video_augmentations(video, self.transform)

        inputs = processor(
            text='',
            videos=list(video),
            return_tensors="pt",
        )

        # Drop the leading batch axis added by the processor; batching is left
        # to the DataLoader.
        for key in list(inputs.keys()):
            inputs[key] = inputs[key][0]

        return inputs, self.meta[idx]['description']

In [13]:
# Training clips go through the random augmentation pipeline.
train_dataset = VideoDataset(meta=X_train, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

# Validation clips are NOT augmented: random flips and brightness/contrast changes
# would add noise to the validation loss and make epochs harder to compare.
# (The frame step is still drawn at random inside VideoDataset.)
val_dataset = VideoDataset(meta=X_val, transform=None)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

In [14]:
# Number of passes over the training data for the loop that follows.
epochs = 5
# Step size for the X-CLIP backbone parameters.
lr = 1e-5
# Step size for the linear projector (100x the backbone's).
projector_lr = 1e-3

# Two independent AdamW optimizers so backbone and projector can use different
# learning rates.
optimizer = optim.AdamW(model.parameters(), lr=lr)
projector_optimizer = optim.AdamW(projector.parameters(), lr=projector_lr)


In [15]:
# Freeze every backbone parameter so that the next training loop only updates
# the projector.
for param in model.parameters():
    param.requires_grad_(False)

In [17]:
for epoch in range(epochs):

    # The backbone stays in eval mode; only the projector is in training mode.
    model.eval()
    projector.train()

    train_loss = []
    for batch, texts in tqdm(train_dataloader, desc=f"Epoch: {epoch}"):
        optimizer.zero_grad()
        projector_optimizer.zero_grad()

        # Target vectors: dense text embeddings of the clip descriptions.
        # If such a long max_length is not needed, a smaller value speeds up encoding.
        dense_vecs = bge.encode(texts, batch_size=1, max_length=1024)['dense_vecs']
        text_embeddings = torch.tensor(dense_vecs, device=device)

        batch = batch.to(device)
        outputs = projector(model(**batch).video_embeds)

        # Cosine distance between projected video embeddings and text embeddings.
        loss = (1 - F.cosine_similarity(outputs, text_embeddings)).mean()
        loss.backward()
        optimizer.step()
        projector_optimizer.step()

        train_loss.append(loss.item())

    print('Training loss:', np.mean(train_loss))

    # ---- validation pass ----
    model.eval()
    projector.eval()

    val_loss = []
    for batch, texts in tqdm(val_dataloader, desc=f"Epoch: {epoch}"):
        dense_vecs = bge.encode(texts, batch_size=1, max_length=1024)['dense_vecs']
        text_embeddings = torch.tensor(dense_vecs, device=device)

        batch = batch.to(device)

        with torch.no_grad():
            outputs = projector(model(**batch).video_embeds)
            loss = (1 - F.cosine_similarity(outputs, text_embeddings)).mean()

        val_loss.append(loss.item())

    print('Val loss:', np.mean(val_loss))
    # The projector checkpoint is overwritten after every epoch.
    torch.save(projector.state_dict(), 'projector.pth')

  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 0:   2%|████                                                                                                                                                                                       | 27/1237 [01:11<47:45,  2.37s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 0:   4%|███████▎                                                                                                                                                                                   | 48/123

Training loss: 0.18587193483292527


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
Epoch: 0:   8%|██████████████▌                                                                                                                                                                               | 1/13 [00:08<01:39,  8.30s/it]deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:29<00:00,  2.23s/it]


Val loss: 0.17087138157624465


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 1:   0%|▎                                                                                                                                                                                         | 2/1237 [00:11<1:44:38,  5.08s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 1:   0%|▊                                                                                                                                                                                         | 5/1237 [00:19<1:06:15,  3.23s/it]deprecated pixel format used, make sure you did set range correctly
Epoch: 1:   1%|█▏                                                  

Training loss: 0.16552247198580153


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 1:   8%|██████████████▌                                                                                                                                                                               | 1/13 [00:08<01:41,  8.46s/it]deprecated pixel format used, make sure you did set range correctly
Epoch: 1:  15%|█████████████████████████████▏                                                                                                                                                                | 2/13 [00:10<00:54,  4.93s/it]deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

Val loss: 0.1671923857468825


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 2:   3%|█████▋                                                                                                                                                                                     | 38/1237 [01:40<47:41,  2.39s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 6 more times)
Epoch: 2:   6%|███████████▏                                                                                                                                                                               | 74/1237 [03:09<47:25,  2.45s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly

Training loss: 0.16121034536784756


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 3:   8%|██████████████▌                                                                                                                                                                               | 1/13 [00:08<01:40,  8.40s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.34s/it]


Val loss: 0.16633060001409972


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 4:   0%|▏                                                                                                                                                                                         | 1/1237 [00:08<2:52:31,  8.38s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 4:   1%|█▋                                                                                                                                                                                         | 11/1237 [00:34<50:47,  2.49s/it]deprecated pixel format used, make sure you did set range correctly
Epoch: 4:   1%|██▍                                                 

Training loss: 0.16013166070490376


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:29<00:00,  2.24s/it]

Val loss: 0.16649101903805366





In [18]:
# Make the backbone trainable again, except for the parameters of its text
# sub-model, which stay frozen.
frozen_param_ids = {id(param) for param in model.text_model.parameters()}
for param in model.parameters():
    param.requires_grad_(id(param) not in frozen_param_ids)

In [19]:
# Number of epochs for the fine-tuning loop that follows (the earlier stage used 5).
epochs = 15

In [None]:
for epoch in range(epochs):

    # Both the backbone and the projector are in training mode in this stage.
    model.train()
    projector.train()

    train_loss = []
    for batch, texts in tqdm(train_dataloader, desc=f"Epoch: {epoch}"):
        optimizer.zero_grad()
        projector_optimizer.zero_grad()

        # Target vectors: dense text embeddings of the clip descriptions.
        # If such a long max_length is not needed, a smaller value speeds up encoding.
        dense_vecs = bge.encode(texts, batch_size=1, max_length=1024)['dense_vecs']
        text_embeddings = torch.tensor(dense_vecs, device=device)

        batch = batch.to(device)
        outputs = projector(model(**batch).video_embeds)

        # Cosine distance between projected video embeddings and text embeddings.
        loss = (1 - F.cosine_similarity(outputs, text_embeddings)).mean()
        loss.backward()
        optimizer.step()
        projector_optimizer.step()

        train_loss.append(loss.item())

    print('Training loss:', np.mean(train_loss))

    # ---- validation pass ----
    model.eval()
    projector.eval()

    val_loss = []
    for batch, texts in tqdm(val_dataloader, desc=f"Epoch: {epoch}"):
        dense_vecs = bge.encode(texts, batch_size=1, max_length=1024)['dense_vecs']
        text_embeddings = torch.tensor(dense_vecs, device=device)

        batch = batch.to(device)

        with torch.no_grad():
            outputs = projector(model(**batch).video_embeds)
            loss = (1 - F.cosine_similarity(outputs, text_embeddings)).mean()

        val_loss.append(loss.item())

    print('Val loss:', np.mean(val_loss))

  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 0:   2%|███▍                                                                                                                                                                                       | 23/1237 [01:09<55:58,  2.77s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 0:   3%|█████▏                                                                                                                                                                                     | 34/123

Training loss: 0.15413371657857802


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:32<00:00,  2.47s/it]


Val loss: 0.15706428656211266


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 1:   5%|█████████                                                                                                                                                                                  | 60/1237 [02:59<56:21,  2.87s/it]deprecated pixel format used, make sure you did set range correctly
Epoch: 1:   5%|█████████▉                                                                                                                                                                                 | 66/1237 [03:16<56:11,  2.88s/it]deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 1:   8%|███████████████▊                                                                                                        

Training loss: 0.14330358873742574


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 1:   8%|██████████████▌                                                                                                                                                                               | 1/13 [00:09<01:48,  9.03s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:31<00:00,  2.43s/it]


Val loss: 0.15243607530227074


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 2:   3%|████▊                                                                                                                                                                                      | 32/1237 [01:40<57:26,  2.86s/it]deprecated pixel format used, make sure you did set range correctly
Epoch: 2:   3%|████▉                                                                                                                                                                                      | 33/1237 [01:42<57:58,  2.89s/it]deprecated pixel format used, make sure you did set range correctly

Training loss: 0.13575205273727495


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:31<00:00,  2.41s/it]


Val loss: 0.1535979165480687


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 3:   1%|██▍                                                                                                                                                                                      | 16/1237 [00:55<1:00:02,  2.95s/it]deprecated pixel format used, make sure you did set range correctly
Epoch: 3:   2%|███▎                                                                                                                                                                                       | 22/1237 [01:13<58:55,  2.91s/it]deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly

Training loss: 0.1286671114671394


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 3:   8%|██████████████▌                                                                                                                                                                               | 1/13 [00:08<01:40,  8.33s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:31<00:00,  2.39s/it]


Val loss: 0.15498206592523134


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 4:   3%|████▊                                                                                                                                                                                      | 32/1237 [01:41<57:35,  2.87s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 4:   3%|█████▉                                                                                                                                                                                     | 39/1237 [02:01<54:59,  2.75s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly

Training loss: 0.12269820566850029


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 4:   8%|██████████████▌                                                                                                                                                                               | 1/13 [00:08<01:36,  8.04s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:31<00:00,  2.41s/it]


Val loss: 0.15670576462378868


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 5:   1%|█▋                                                                                                                                                                                       | 11/1237 [00:40<1:01:48,  3.02s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 5:   3%|█████▎                                                                                                                                                                                     | 35/1237 [01:49<54:44,  2.73s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly

Training loss: 0.11646499637162001


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:31<00:00,  2.46s/it]


Val loss: 0.15532627587135023


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 6:   2%|███▉                                                                                                                                                                                       | 26/1237 [01:23<58:23,  2.89s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 3 more times)
Epoch: 6:   4%|██████▊                                                                                                                                                                                    | 45/123

Training loss: 0.11045196177791759


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 6: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:29<00:00,  2.27s/it]


Val loss: 0.1590092984529642


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 7:   3%|█████▋                                                                                                                                                                                     | 38/1237 [01:56<54:28,  2.73s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 2 more times)
Epoch: 7:   5%|█████████▊                                                                                                                                                                                 | 65/123

Training loss: 0.10546020397324667


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 7:   8%|██████████████▌                                                                                                                                                                               | 1/13 [00:09<01:49,  9.15s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 7: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.38s/it]


Val loss: 0.15536455466197088


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 8:   0%|▌                                                                                                                                                                                         | 4/1237 [00:19<1:23:28,  4.06s/it]deprecated pixel format used, make sure you did set range correctly
Epoch: 8:   1%|█▎                                                                                                                                                                                        | 9/1237 [00:34<1:03:25,  3.10s/it]deprecated pixel format used, make sure you did set range correctly

Training loss: 0.09937054191082836


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 8:   8%|██████████████▌                                                                                                                                                                               | 1/13 [00:09<01:49,  9.10s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 8: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:30<00:00,  2.37s/it]


Val loss: 0.16068212000223306


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 9:   1%|█                                                                                                                                                                                         | 7/1237 [00:29<1:07:46,  3.31s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 9:   5%|█████████▋                                                                                                                                                                                 | 64/1237 [03:10<54:13,  2.77s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly

Training loss: 0.09490222948254745


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [01:03<00:00,  4.86s/it]


Val loss: 0.15545902000023767


  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
  return torch.tensor(value)
Epoch: 10:   1%|█▉                                                                                                                                                                                      | 13/1237 [01:16<1:29:52,  4.41s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly
 (repeated 7 more times)
Epoch: 10:   1%|██▏                                                                                                                                                                                     | 15/1237 [01:26<1:35:45,  4.70s/it]deprecated pixel format used, make sure you did set range correctly
deprecated pixel format used, make sure you did set range correctly

In [None]:
# Persist the training artefacts.
# The model and the processor are both written with `save_pretrained` into the
# same directory, so the two can be reloaded from one location.
for hf_component in (model, processor):
    hf_component.save_pretrained('finetuned-xclip-base-patch16')

# The projector is stored separately: only its state dict is serialized, to a
# relative path, so an instance of the projector has to be constructed before
# these weights can be loaded back.
projector_weights = projector.state_dict()
torch.save(projector_weights, 'projector.pth')