In [1]:
import os
import json
from dataclasses import dataclass
from typing import Dict, List, Any, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset
from sklearn.utils.class_weight import compute_class_weight

import av  # pip install av
from transformers import (
    AutoImageProcessor,
    VideoMAEForVideoClassification,
    TrainingArguments,
    Trainer,
)

import evaluate  # pip install evaluate

# load environment variables with dotenv
from dotenv import load_dotenv
load_dotenv()


  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
from train import *
import os


In [3]:
# Checkpoint identifier handed to `from_pretrained` below: the pretrained
# VideoMAE base model (self-supervised on K400).
# NOTE(review): no dataset directory is configured in this cell.
model_name = "MCG-NJU/videomae-base"

# Image processor that ships with the same checkpoint.
image_processor = AutoImageProcessor.from_pretrained(model_name)

# Video classifier built on the checkpoint. The label maps and the number of
# output classes all come from LABEL2ID / ID2LABEL.
model = VideoMAEForVideoClassification.from_pretrained(
    model_name,
    id2label=ID2LABEL,
    label2id=LABEL2ID,
    num_labels=len(LABEL2ID),
)


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Some weights of VideoMAEForVideoClassification were not initialized from the model checkpoint at MCG-NJU/videomae-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Move the model onto the GPU when CUDA is available, and report which
# device is in use either way.
cuda_ready = torch.cuda.is_available()
if cuda_ready:
    model.to("cuda")
print("Using CUDA" if cuda_ready else "Using CPU")

Using CUDA


In [16]:
def _split_clip_paths(root, train_frac=0.8, val_frac=0.1):
    """Partition the clip files under ``root`` into train/val/test path lists.

    ``root`` is scanned for one sub-directory per class label; every entry
    inside a label directory is treated as one clip file. The split is done
    per label so each split keeps roughly the same class proportions.

    Directory listings are sorted before slicing: ``os.listdir`` returns
    entries in arbitrary order, so without sorting the membership of each
    split could differ between machines or runs. No shuffling is applied.

    Args:
        root: Directory that holds the per-label sub-directories.
        train_frac: Fraction of each label's clips assigned to train.
        val_frac: Fraction of each label's clips assigned to validation;
            everything left over goes to test.

    Returns:
        A ``(train_paths, val_paths, test_paths)`` tuple of lists of
        ``"<root>/<label>/<file>"`` strings ("/"-separated on every platform,
        because labels are recovered later with ``path.split("/")``).
    """
    train, val, test = [], [], []
    for label in sorted(os.listdir(root)):
        label_dir = f"{root}/{label}"
        # Skip stray files (e.g. OS metadata) sitting next to the label folders.
        if not os.path.isdir(label_dir):
            continue
        paths = [f"{label_dir}/{name}" for name in sorted(os.listdir(label_dir))]
        # int() truncates, which equals floor for these non-negative values.
        train_end = int(len(paths) * train_frac)
        val_end = train_end + int(len(paths) * val_frac)
        train.extend(paths[:train_end])
        val.extend(paths[train_end:val_end])
        test.extend(paths[val_end:])
    return train, val, test


class BoxingDataset(Dataset):
    """Clips stored as ``.npy`` files, exposed as one of three fixed splits.

    The file lists for all splits are computed once, when the class is
    defined, and are shared by every instance through the class attributes
    ``train_paths``, ``val_paths`` and ``test_paths``.

    Each item is a dict with:
        * ``"pixel_values"``: float32 tensor, normalised and transposed with
          axes ``(0, 3, 1, 2)``.
        * ``"labels"``: scalar ``torch.long`` tensor holding
          ``LABEL2ID[<name of the clip's parent directory>]``.
    """

    # ImageNet per-channel statistics, applied after scaling pixels to 0-1.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # Root directory of the preprocessed clips (one sub-directory per label).
    DATA_ROOT = "preprocessed_clips_2"

    # 80 % / 10 % / remainder split, computed per label.
    train_paths, val_paths, test_paths = _split_clip_paths(DATA_ROOT)

    _SPLIT_NAMES = ("train", "val", "test")

    def __init__(self, split: str):
        """Select which split this instance serves.

        Args:
            split: One of ``"train"``, ``"val"`` or ``"test"``.

        Raises:
            ValueError: If ``split`` is not a known split name.
        """
        self.split = split
        # Fail fast on a typo instead of at the first __len__/__getitem__ call.
        self._paths()

    def _paths(self):
        """Return the path list backing ``self.split``.

        Raises:
            ValueError: If ``self.split`` is not a known split name.
        """
        if self.split not in self._SPLIT_NAMES:
            raise ValueError(f"Unknown split: {self.split}")
        return getattr(self, f"{self.split}_paths")

    def __len__(self):
        """Number of clips in the selected split."""
        return len(self._paths())

    def __getitem__(self, idx):
        """Load, normalise and label the clip at position ``idx``.

        Raises:
            ValueError: If ``self.split`` is not a known split name.
        """
        path = self._paths()[idx]

        # assumes the file holds a (T, H, W, C) array with 0-255 pixel values
        # -- TODO confirm against the preprocessing script
        clip = np.load(path)

        # Convert to float and scale to 0-1.
        clip = clip.astype(np.float32) / 255.0

        # ImageNet mean/std, broadcast over the trailing channel axis.
        clip = (clip - self.mean) / self.std

        # Reorder to (T, C, H, W).
        clip = clip.transpose(0, 3, 1, 2)

        # Convert to tensor.
        clip = torch.from_numpy(clip)

        return {
            "pixel_values": clip,
            # The label is the name of the directory that contains the clip.
            "labels": torch.tensor(LABEL2ID[path.split("/")[-2]], dtype=torch.long),
        }



In [17]:
from collections import Counter 
def label_from_path(p):
    """Return the second-to-last "/"-separated component of ``p``.

    For a path such as ``root/LABEL/clip.npy`` this is ``LABEL``, i.e. the
    name of the directory that directly contains the file.
    """
    segments = p.split("/")
    return segments[-2]

# Per-label clip counts in the training split.
Counter(label_from_path(clip_path) for clip_path in BoxingDataset.train_paths)


Counter({'LHHP': 2097,
         'RHHP': 1077,
         'LHMP': 993,
         'RHMP': 624,
         'LHBlP': 476,
         'RHBP': 234,
         'RHBlP': 228,
         'LHBP': 227})

In [18]:
# Per-label clip counts in the validation split.
Counter(label_from_path(clip_path) for clip_path in BoxingDataset.val_paths)

Counter({'LHHP': 262,
         'RHHP': 134,
         'LHMP': 124,
         'RHMP': 78,
         'LHBlP': 59,
         'RHBP': 29,
         'LHBP': 28,
         'RHBlP': 28})

In [7]:
# One dataset object per split, built in train / val / test order.
train_dataset, val_dataset, test_dataset = (
    BoxingDataset(split=split_name) for split_name in ("train", "val", "test")
)

In [None]:

# FACTS used batch_size=4, grad_accum=2, warmup_ratio=0.1, epochs=10
# Learning rate is not rendered in the HTML; start with 1e-4 and tune around it.

# fp16 mixed precision needs an accelerator: TrainingArguments rejects
# fp16=True on a CPU-only machine. Gate it on CUDA so the notebook still runs
# when the device check above reports "Using CPU".
use_fp16 = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir="./facts-boxing-videomae",
    # Evaluate and checkpoint on the same step schedule; the intervals must
    # line up for load_best_model_at_end to work.
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    logging_steps=100,
    save_total_limit=2,
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,  # effective batch size 8
    warmup_ratio=0.1,
    learning_rate=1e-4,
    weight_decay=0.05,
    fp16=use_fp16,
    # Restore the checkpoint with the highest validation accuracy after training.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    report_to="wandb",  # alternatives: "tensorboard", "none"
    dataloader_num_workers=4,        # load clips in parallel worker processes
    dataloader_pin_memory=True,      # faster CPU->GPU transfer
    dataloader_prefetch_factor=2,
)

data_collator = VideoDataCollator()

# Integer class id of every training clip, taken from the clip's parent
# directory name. Only consumed by the (currently commented-out)
# class-weighting and weighted-sampling experiments.
train_labels = [LABEL2ID[path.split("/")[-2]] for path in BoxingDataset.train_paths]

# class_weights = compute_class_weight(
#     class_weight='balanced',
#     classes=np.arange(len(LABEL2ID)),
#     y=np.array(train_labels)  # Ensure it's a numpy array
# )
# class_weights = torch.tensor(class_weights, dtype=torch.float32)


In [9]:
# class_counts = np.bincount(train_labels)
# sample_weights = 1.0 / class_counts[train_labels]

# sampler = torch.utils.data.WeightedRandomSampler(
#     sample_weights,                                              
#     len(sample_weights), 
#     replacement=True
# )
# sampler


In [10]:
import random
import matplotlib.pyplot as plt
from collections import defaultdict

# sample_counts = defaultdict(int)
# for i in range(10000):
#     sampled_idx = random.choices(train_labels, weights=sample_weights, k=1)[0]
#     sample_counts[sampled_idx] += 1

# # use matplotlib to plot the distribution

# plt.bar(sample_counts.keys(), sample_counts.values())
# plt.show()

In [11]:
# class UniformSamplerTrainer(Trainer):
#     def __init__(self, *args, train_sampler=None, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.train_sampler = train_sampler
        
#     def _get_train_sampler(self, train_dataset: Dataset | None = None):
#         if train_dataset is None:
#             train_dataset = self.train_dataset
            
#         if train_dataset is None or not has_length(train_dataset):
#             return None
        
#         if self.train_sampler is not None:
#             return self.train_sampler
        
#         return super()._get_train_sampler(train_dataset)

# Stock Trainer: the custom-sampler subclass above is commented out, so the
# default training sampler is used.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)


In [None]:
# Label for this run, exported through the WANDB_NAME environment variable
# before training starts (presumably picked up by the wandb reporter -- confirm).
wandb_run_name = "New Preprocessing - Sanity Check"
os.environ["WANDB_NAME"] = wandb_run_name

# Run the training loop; the returned TrainOutput is shown as the cell result.
trainer.train()


[34m[1mwandb[0m: Currently logged in as: [33mnkosik11[0m ([33mnkosik11-hobby[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss


TrainOutput(global_step=150, training_loss=0.7302090231577555, metrics={'train_runtime': 28.0478, 'train_samples_per_second': 40.645, 'train_steps_per_second': 5.348, 'total_flos': 1.4205896420386406e+18, 'train_loss': 0.7302090231577555, 'epoch': 30.0})



In [13]:

# Score the model on the held-out test split and print the metric dict.
test_metrics = trainer.evaluate(eval_dataset=test_dataset)
print("Test metrics:", test_metrics)


Test metrics: {'eval_loss': 0.046778421849012375, 'eval_accuracy': 0.9736842105263158, 'eval_macro_f1': 0.9707792207792207, 'eval_f1_LHHP': 1.0, 'eval_precision_LHHP': 1.0, 'eval_recall_LHHP': 1.0, 'eval_f1_RHHP': 1.0, 'eval_precision_RHHP': 1.0, 'eval_recall_RHHP': 1.0, 'eval_f1_LHMP': 1.0, 'eval_precision_LHMP': 1.0, 'eval_recall_LHMP': 1.0, 'eval_f1_RHMP': 0.8571428571428571, 'eval_precision_RHMP': 1.0, 'eval_recall_RHMP': 0.75, 'eval_f1_LHBlP': 1.0, 'eval_precision_LHBlP': 1.0, 'eval_recall_LHBlP': 1.0, 'eval_f1_RHBlP': 1.0, 'eval_precision_RHBlP': 1.0, 'eval_recall_RHBlP': 1.0, 'eval_f1_LHBP': 1.0, 'eval_precision_LHBP': 1.0, 'eval_recall_LHBP': 1.0, 'eval_f1_RHBP': 0.9090909090909091, 'eval_precision_RHBP': 0.8333333333333334, 'eval_recall_RHBP': 1.0, 'eval_runtime': 0.8248, 'eval_samples_per_second': 46.07, 'eval_steps_per_second': 6.062, 'epoch': 30.0}
