In [1]:
import os

import av
import cv2
import numpy as np
import pandas as pd

from tqdm.notebook import tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import albumentations as A

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from transformers import AutoProcessor, AutoModel

In [2]:
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

In [3]:
def apply_video_augmentations(video, transform):
    targets={'image': video[0]}
    for i in range(1, video.shape[0]):
        targets[f'image{i}'] = video[i]
    transformed = transform(**targets)
    transformed = np.concatenate(
        [np.expand_dims(transformed['image'], axis=0)] 
        + [np.expand_dims(transformed[f'image{i}'], axis=0) for i in range(1, video.shape[0])]
    )
    return transformed

In [4]:
def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    converted_len = int(clip_len * frame_sample_rate)
    end_idx = converted_len
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices

In [5]:
batch_size = 16
root_dir = 'UCF-101/UCF-101/'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# Dataset preparation

In [6]:
labels = [i for i in os.listdir(root_dir) if i[0] != '.']
labels2id = {label:i for i, label in enumerate(labels)}

In [7]:
train = []
for label in tqdm(labels):
    for video_name in os.listdir(f'{root_dir}/{label}'):
        container = av.open(f'{root_dir}/{label}/{video_name}')
        if container.streams.video[0].frames > 75:
            train.append({
                'label': label,
                'video_path': f'{root_dir}/{label}/{video_name}'
            })
train = pd.DataFrame(train)

  0%|          | 0/101 [00:00<?, ?it/s]

In [8]:
train.label.value_counts()

PlayingDhol          164
PlayingCello         164
HorseRiding          163
BoxingPunchingBag    162
Drumming             161
                    ... 
BodyWeightSquats      90
JavelinThrow          82
BlowingCandles        68
BasketballDunk        57
PushUps               54
Name: label, Length: 101, dtype: int64

In [9]:
train['label_id'] = train.label.map(labels2id)

In [10]:
X_train, X_val, _, _ = train_test_split(train, train['label'])

In [11]:
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=0.5)
], additional_targets={
    f'image{i}': 'image'
    for i in range(1, 8)
})

In [12]:
class ActionDataset(Dataset):

    def __init__(self, meta, transform=None):
        self.meta = meta
        self.transform = transform

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()


        file_path = self.meta['video_path'].iloc[idx]
        container = av.open(file_path)
        indices = sample_frame_indices(clip_len=8, frame_sample_rate=5, seg_len=container.streams.video[0].frames)
        try:
            indices = sample_frame_indices(clip_len=8, frame_sample_rate=5, seg_len=container.streams.video[0].frames)
        except Exception:
            indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
        if indices.shape[0] < 8:
            indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
            
        video = read_video_pyav(container, indices)
        while video.shape[0] < 8:
            video = np.vstack([video, video[-1:]])

        if self.transform:
            transformed = apply_video_augmentations(video, self.transform)
            

        inputs = processor(
            text=[''],
            videos=list(video),
            return_tensors="pt",
            padding=True,
        )
        for i in inputs:
            inputs[i] = inputs[i][0]

        return inputs, self.meta['label_id'].iloc[idx]

In [13]:
train_dataset = ActionDataset(meta=X_train, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=15)

val_dataset = ActionDataset(meta=X_val)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=15)

# Load model

In [14]:
processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
model.to(device)
classifier = nn.Linear(512, len(labels))
classifier.to(device)

Linear(in_features=512, out_features=101, bias=True)

# Frozen XClip training 

In [15]:
for param in model.parameters():
    param.requires_grad = False

In [16]:
epochs = 5
freeze_epochs = 5
model_lr = 1e-5
classifier_lr = 1e-3

model_optimizer = optim.AdamW(model.parameters(), model_lr)
classifier_optimizer = torch.optim.AdamW(classifier.parameters(), lr=classifier_lr)

criterion = nn.CrossEntropyLoss()

In [17]:
for epoch in range(epochs):

    model.eval()
    classifier.train()     

    train_loss = []
    for i, (batch, targets) in enumerate(tqdm(train_dataloader, desc=f"Epoch: {epoch}")):
        classifier_optimizer.zero_grad()

        batch = batch.to(device)
        targets = targets.to(device)

        with torch.no_grad():
            outputs = model(**batch)
        logits = classifier(outputs.video_embeds)

        loss = criterion(logits, targets) 
        loss.backward()
        classifier_optimizer.step()

        train_loss.append(loss.item())

    print('Training loss:', np.mean(train_loss))

    model.eval() 
    classifier.eval()    

    val_loss = []
    val_targets = []
    val_preds = []
    for i, (batch, targets) in enumerate(tqdm(val_dataloader, desc=f"Epoch: {epoch}")):
        with torch.no_grad():

            batch = batch.to(device)
            targets = targets.to(device)

            outputs = model(**batch)
            logits = classifier(outputs.video_embeds)

            loss = criterion(logits, targets) 

            val_loss.append(loss.item())
            val_targets.extend(targets.cpu().numpy())
            val_preds.extend(logits.argmax(axis=1).cpu().numpy())

    print('Val loss:', np.mean(val_loss))
    print('F1:', f1_score(val_targets, val_preds, average='macro'))

Epoch: 0:   0%|          | 0/586 [00:00<?, ?it/s]

Training loss: 3.748993673422231


Epoch: 0:   0%|          | 0/196 [00:00<?, ?it/s]

Val loss: 2.9145336138958835
F1: 0.8244219464627169


Epoch: 1:   0%|          | 0/586 [00:00<?, ?it/s]

Training loss: 2.2512337396576134


Epoch: 1:   0%|          | 0/196 [00:00<?, ?it/s]

Val loss: 1.690573743411473
F1: 0.8721783842761593


Epoch: 2:   0%|          | 0/586 [00:00<?, ?it/s]

Training loss: 1.3135608704220312


Epoch: 2:   0%|          | 0/196 [00:00<?, ?it/s]

Val loss: 1.028557603456536
F1: 0.9075624404063354


Epoch: 3:   0%|          | 0/586 [00:00<?, ?it/s]

Training loss: 0.8378346111790719


Epoch: 3:   0%|          | 0/196 [00:00<?, ?it/s]

Val loss: 0.7018090010601647
F1: 0.9284045562234075


Epoch: 4:   0%|          | 0/586 [00:00<?, ?it/s]

Training loss: 0.5966130399459865


Epoch: 4:   0%|          | 0/196 [00:00<?, ?it/s]

Val loss: 0.5270408716584954
F1: 0.935476840782045


# Full XClip training

In [18]:
for param in model.parameters():
    param.requires_grad = True
for param in model.text_model.parameters():
    param.requires_grad = False

In [19]:
for epoch in range(epochs):

    model.train() 
    classifier.train()     

    train_loss = []
    for i, (batch, targets) in enumerate(tqdm(train_dataloader, desc=f"Epoch: {epoch}")):
        model_optimizer.zero_grad()
        classifier_optimizer.zero_grad()

        batch = batch.to(device)
        targets = targets.to(device)

        outputs = model(**batch)
        logits = classifier(outputs.video_embeds)

        loss = criterion(logits, targets) 
        loss.backward()
        model_optimizer.step()
        classifier_optimizer.step()

        train_loss.append(loss.item())

    print('Training loss:', np.mean(train_loss))

    model.eval()
    classifier.eval()   

    val_loss = []
    val_targets = []
    val_preds = []
    for i, (batch, targets) in enumerate(tqdm(val_dataloader, desc=f"Epoch: {epoch}")):
        with torch.no_grad():

            batch = batch.to(device)
            targets = targets.to(device)

            outputs = model(**batch)
            logits = classifier(outputs.video_embeds)

            loss = criterion(logits, targets) 

            val_loss.append(loss.item())
            val_targets.extend(targets.cpu().numpy())
            val_preds.extend(logits.argmax(axis=1).cpu().numpy())           

    print('Val loss:', np.mean(val_loss))
    print('F1:', f1_score(val_targets, val_preds, average='macro'))

Epoch: 0:   0%|          | 0/586 [00:00<?, ?it/s]

Training loss: 0.27012621462599407


Epoch: 0:   0%|          | 0/196 [00:00<?, ?it/s]

Val loss: 0.1985938476816732
F1: 0.9514175167728979


Epoch: 1:   0%|          | 0/586 [00:00<?, ?it/s]

Training loss: 0.06677514677326936


Epoch: 1:   0%|          | 0/196 [00:00<?, ?it/s]

Val loss: 0.13393537319094248
F1: 0.964959740182343


Epoch: 2:   0%|          | 0/586 [00:00<?, ?it/s]

Training loss: 0.03211416319050785


Epoch: 2:   0%|          | 0/196 [00:00<?, ?it/s]

Val loss: 0.08779511839740586
F1: 0.9757769098558563


Epoch: 3:   0%|          | 0/586 [00:00<?, ?it/s]

Training loss: 0.0136671471234678


Epoch: 3:   0%|          | 0/196 [00:00<?, ?it/s]

Val loss: 0.07810192471085002
F1: 0.9786765262877234


Epoch: 4:   0%|          | 0/586 [00:00<?, ?it/s]

Training loss: 0.10124871646508625


Epoch: 4:   0%|          | 0/196 [00:00<?, ?it/s]

Val loss: 0.1280776294642033
F1: 0.965067701623843
